diff --git a/.bazelrc b/.bazelrc index a37a042fdf73c..ce8406b58aaab 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,10 @@ build --copt=-isystem --copt=bazel-out/k8-fastbuild-cpu-only/bin # rules_cuda configuration build:cpu-only --@rules_cuda//cuda:enable_cuda=False +# Definition of --config=shell +# interactive shell immediately before execution +build:shell --run_under="//tools/bazel_tools:shellwrap" + # Disable all warnings for external repositories. We don't care about # their warnings. build --per_file_copt=^external/@-w @@ -103,6 +107,9 @@ build --per_file_copt='//:aten/src/ATen/RegisterNestedTensorCPU\.cpp$'@-Wno-erro build --per_file_copt='//:aten/src/ATen/RegisterQuantizedCPU\.cpp$'@-Wno-error=unused-function build --per_file_copt='//:aten/src/ATen/RegisterSparseCPU\.cpp$'@-Wno-error=unused-function build --per_file_copt='//:aten/src/ATen/RegisterSparseCsrCPU\.cpp$'@-Wno-error=unused-function +build --per_file_copt='//:aten/src/ATen/RegisterNestedTensorMeta\.cpp$'@-Wno-error=unused-function +build --per_file_copt='//:aten/src/ATen/RegisterSparseMeta\.cpp$'@-Wno-error=unused-function +build --per_file_copt='//:aten/src/ATen/RegisterQuantizedMeta\.cpp$'@-Wno-error=unused-function build --per_file_copt='//:aten/src/ATen/RegisterZeroTensor\.cpp$'@-Wno-error=unused-function build --per_file_copt='//:torch/csrc/lazy/generated/RegisterAutogradLazy\.cpp$'@-Wno-error=unused-function build --per_file_copt='//:torch/csrc/lazy/generated/RegisterLazy\.cpp$'@-Wno-error=unused-function diff --git a/.buckconfig.oss b/.buckconfig.oss index a289ddd82a146..6a214726272fd 100644 --- a/.buckconfig.oss +++ b/.buckconfig.oss @@ -3,9 +3,11 @@ [buildfile] name = BUCK.oss + includes = //tools/build_defs/select.bzl [repositories] bazel_skylib = third_party/bazel-skylib/ + ovr_config = . [download] in_build = true @@ -13,6 +15,11 @@ [cxx] cxxflags = -std=c++17 should_remap_host_platform = true + cpp = /usr/bin/clang + cc = /usr/bin/clang + cxx = /usr/bin/clang++ + cxxpp = /usr/bin/clang++ + ld = /usr/bin/clang++ [project] default_flavors_mode=all diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 09756135fe64c..4ea80ab4f79da 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -75,6 +75,7 @@ def child_constructor(self): "vulkan": VulkanConfigNode, "parallel_tbb": ParallelTBBConfigNode, "crossref": CrossRefConfigNode, + "dynamo": DynamoConfigNode, "parallel_native": ParallelNativeConfigNode, "onnx": ONNXConfigNode, "libtorch": LibTorchConfigNode, @@ -179,6 +180,14 @@ def child_constructor(self): return ImportantConfigNode +class DynamoConfigNode(TreeConfigNode): + def init2(self, node_name): + self.props["is_dynamo"] = node_name + + def child_constructor(self): + return ImportantConfigNode + + class ParallelNativeConfigNode(TreeConfigNode): def modify_label(self, label): return "PARALLELNATIVE=" + str(label) diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 0eb7b5ec5210f..76e87b07c1889 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -240,6 +240,7 @@ def instantiate_configs(only_slow_gradcheck): is_xla = fc.find_prop("is_xla") or False is_asan = fc.find_prop("is_asan") or False is_crossref = fc.find_prop("is_crossref") or False + is_dynamo = fc.find_prop("is_dynamo") or False is_onnx = fc.find_prop("is_onnx") or False is_pure_torch = fc.find_prop("is_pure_torch") or False is_vulkan = fc.find_prop("is_vulkan") or False @@ -286,6 +287,9 @@ def instantiate_configs(only_slow_gradcheck): if is_crossref: parms_list_ignored_for_docker_image.append("crossref") + if is_dynamo: + parms_list_ignored_for_docker_image.append("dynamo") + if is_onnx: parms_list.append("onnx") python_version = fc.find_prop("pyver") diff --git a/.circleci/docker/android/build.gradle b/.circleci/docker/android/build.gradle index 734e8eab8d9d8..66b936326b72c 100644 --- a/.circleci/docker/android/build.gradle +++ b/.circleci/docker/android/build.gradle @@ -53,7 +53,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' implementation 'com.google.code.findbugs:jsr305:3.0.1' - implementation 'com.facebook.soloader:nativeloader:0.10.1' + implementation 'com.facebook.soloader:nativeloader:0.10.4' implementation 'junit:junit:' + rootProject.junitVersion implementation 'androidx.test:core:' + rootProject.coreVersion diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 42b1747678dab..ee785bbc95039 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -139,9 +139,19 @@ case "$image" in KATEX=yes ;; pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7) - CUDA_VERSION=11.6.0 + CUDA_VERSION=11.6.2 CUDNN_VERSION=8 - ANACONDA_PYTHON_VERSION=3.7 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=7 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + ;; + pytorch-linux-bionic-cuda11.7-cudnn8-py3-gcc7) + CUDA_VERSION=11.7.0 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=7 PROTOBUF=yes DB=yes @@ -176,6 +186,13 @@ case "$image" in DB=yes VISION=yes ;; + pytorch-linux-focal-py3-clang10-onnx) + ANACONDA_PYTHON_VERSION=3.7 + CLANG_VERSION=10 + PROTOBUF=yes + DB=yes + VISION=yes + ;; pytorch-linux-xenial-py3-clang5-android-ndk-r19c) ANACONDA_PYTHON_VERSION=3.7 CLANG_VERSION=5.0 @@ -227,21 +244,21 @@ case "$image" in DB=yes VISION=yes ;; - pytorch-linux-bionic-rocm5.0-py3.7) + pytorch-linux-focal-rocm5.1-py3.7) ANACONDA_PYTHON_VERSION=3.7 GCC_VERSION=9 PROTOBUF=yes DB=yes VISION=yes - ROCM_VERSION=5.0 + ROCM_VERSION=5.1.1 ;; - pytorch-linux-bionic-rocm5.1-py3.7) + pytorch-linux-focal-rocm5.2-py3.7) ANACONDA_PYTHON_VERSION=3.7 GCC_VERSION=9 PROTOBUF=yes DB=yes VISION=yes - ROCM_VERSION=5.1.1 + ROCM_VERSION=5.2 ;; pytorch-linux-focal-py3.7-gcc7) ANACONDA_PYTHON_VERSION=3.7 @@ -261,6 +278,15 @@ case "$image" in DB=yes VISION=yes ;; + pytorch-linux-jammy-cuda11.7-cudnn8-py3.8-clang12) + ANACONDA_PYTHON_VERSION=3.8 + CUDA_VERSION=11.7 + CUDNN_VERSION=8 + CLANG_VERSION=12 + PROTOBUF=yes + DB=yes + VISION=yes + ;; *) # Catch-all for builds that are not hardcoded. PROTOBUF=yes diff --git a/.circleci/docker/build_docker.sh b/.circleci/docker/build_docker.sh index 23d985f1701dd..3e936b3a447b6 100755 --- a/.circleci/docker/build_docker.sh +++ b/.circleci/docker/build_docker.sh @@ -18,6 +18,7 @@ tag="${DOCKER_TAG}" registry="308535385114.dkr.ecr.us-east-1.amazonaws.com" image="${registry}/pytorch/${IMAGE_NAME}" +ghcr_image="ghcr.io/pytorch/ci-image" login() { aws ecr get-authorization-token --region us-east-1 --output text --query 'authorizationData[].authorizationToken' | @@ -48,7 +49,19 @@ fi # Only push if `DOCKER_SKIP_PUSH` = false if [ "${DOCKER_SKIP_PUSH:-true}" = "false" ]; then - docker push "${image}:${tag}" + # Only push if docker image doesn't exist already. + # ECR image tags are immutable so this will avoid pushing if only just testing if the docker jobs work + # NOTE: The only workflow that should push these images should be the docker-builds.yml workflow + if ! docker manifest inspect "${image}:${tag}" >/dev/null 2>/dev/null; then + docker push "${image}:${tag}" + fi + + if [ "${PUSH_GHCR_IMAGE:-}" = "true" ]; then + # Push docker image to the ghcr.io + echo $GHCR_PAT | docker login ghcr.io -u pytorch --password-stdin + docker tag "${image}:${tag}" "${ghcr_image}:${IMAGE_NAME}-${tag}" + docker push "${ghcr_image}:${IMAGE_NAME}-${tag}" + fi fi if [ -z "${DOCKER_SKIP_S3_UPLOAD:-}" ]; then diff --git a/.circleci/docker/centos-rocm/Dockerfile b/.circleci/docker/centos-rocm/Dockerfile index e0ef9e3296fe8..7c7708d416fe1 100644 --- a/.circleci/docker/centos-rocm/Dockerfile +++ b/.circleci/docker/centos-rocm/Dockerfile @@ -12,7 +12,7 @@ ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} # Install common dependencies (so that this step can be cached separately) ARG EC2 -ADD ./common/install_base.sh install_base.sh +COPY ./common/install_base.sh install_base.sh RUN bash ./install_base.sh && rm install_base.sh # Update CentOS git version @@ -23,52 +23,52 @@ RUN yum install -y git # Install devtoolset ARG DEVTOOLSET_VERSION -ADD ./common/install_devtoolset.sh install_devtoolset.sh +COPY ./common/install_devtoolset.sh install_devtoolset.sh RUN bash ./install_devtoolset.sh && rm install_devtoolset.sh ENV BASH_ENV "/etc/profile" # (optional) Install non-default glibc version ARG GLIBC_VERSION -ADD ./common/install_glibc.sh install_glibc.sh +COPY ./common/install_glibc.sh install_glibc.sh RUN if [ -n "${GLIBC_VERSION}" ]; then bash ./install_glibc.sh; fi RUN rm install_glibc.sh # Install user -ADD ./common/install_user.sh install_user.sh +COPY ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh # Install conda and other packages (e.g., numpy, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION -ADD requirements-ci.txt /opt/conda/requirements-ci.txt -ADD ./common/install_conda.sh install_conda.sh +COPY requirements-ci.txt /opt/conda/requirements-ci.txt +COPY ./common/install_conda.sh install_conda.sh RUN bash ./install_conda.sh && rm install_conda.sh RUN rm /opt/conda/requirements-ci.txt # (optional) Install protobuf for ONNX ARG PROTOBUF -ADD ./common/install_protobuf.sh install_protobuf.sh +COPY ./common/install_protobuf.sh install_protobuf.sh RUN if [ -n "${PROTOBUF}" ]; then bash ./install_protobuf.sh; fi RUN rm install_protobuf.sh ENV INSTALLED_PROTOBUF ${PROTOBUF} # (optional) Install database packages like LMDB and LevelDB ARG DB -ADD ./common/install_db.sh install_db.sh +COPY ./common/install_db.sh install_db.sh RUN if [ -n "${DB}" ]; then bash ./install_db.sh; fi RUN rm install_db.sh ENV INSTALLED_DB ${DB} # (optional) Install vision packages like OpenCV and ffmpeg ARG VISION -ADD ./common/install_vision.sh install_vision.sh +COPY ./common/install_vision.sh install_vision.sh RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi RUN rm install_vision.sh ENV INSTALLED_VISION ${VISION} # Install rocm ARG ROCM_VERSION -ADD ./common/install_rocm.sh install_rocm.sh +COPY ./common/install_rocm.sh install_rocm.sh RUN bash ./install_rocm.sh RUN rm install_rocm.sh ENV PATH /opt/rocm/bin:$PATH @@ -82,18 +82,18 @@ ENV LC_ALL en_US.utf8 # (optional) Install non-default CMake version ARG CMAKE_VERSION -ADD ./common/install_cmake.sh install_cmake.sh +COPY ./common/install_cmake.sh install_cmake.sh RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi RUN rm install_cmake.sh # (optional) Install non-default Ninja version ARG NINJA_VERSION -ADD ./common/install_ninja.sh install_ninja.sh +COPY ./common/install_ninja.sh install_ninja.sh RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi RUN rm install_ninja.sh # Install ccache/sccache (do this last, so we get priority in PATH) -ADD ./common/install_cache.sh install_cache.sh +COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH RUN bash ./install_cache.sh && rm install_cache.sh diff --git a/.circleci/docker/common/install_base.sh b/.circleci/docker/common/install_base.sh index d08dda324bed8..26ca9d79cedeb 100755 --- a/.circleci/docker/common/install_base.sh +++ b/.circleci/docker/common/install_base.sh @@ -24,9 +24,11 @@ install_ubuntu() { fi if [[ "$CLANG_VERSION" == 12 ]]; then - libomp_dev="libomp-12-dev" + maybe_libomp_dev="libomp-12-dev" + elif [[ "$CLANG_VERSION" == 10 ]]; then + maybe_libomp_dev="libomp-10-dev" else - libomp_dev="" + maybe_libomp_dev="" fi # TODO: Remove this once nvidia package repos are back online @@ -60,11 +62,12 @@ install_ubuntu() { libjpeg-dev \ libasound2-dev \ libsndfile-dev \ - ${libomp_dev} \ + ${maybe_libomp_dev} \ software-properties-common \ wget \ sudo \ - vim + vim \ + jq # Should resolve issues related to various apt package repository cert issues # see: https://github.com/pytorch/pytorch/issues/65931 diff --git a/.circleci/docker/common/install_clang.sh b/.circleci/docker/common/install_clang.sh index 3753c32ca8f8a..9d67030a9bbda 100755 --- a/.circleci/docker/common/install_clang.sh +++ b/.circleci/docker/common/install_clang.sh @@ -13,6 +13,9 @@ if [ -n "$CLANG_VERSION" ]; then sudo apt-get install -y --no-install-recommends gpg-agent wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-${CLANG_VERSION} main" + elif [[ $UBUNTU_VERSION == 22.04 ]]; then + # work around ubuntu apt-get conflicts + sudo apt-get -y -f install fi sudo apt-get update diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index cf8c108dd1aad..49afcb5aef423 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -73,19 +73,21 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then } # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README - # DO NOT install cmake here as it would install a version newer than 3.10, but - # we want to pin to version 3.10. - if [ "$ANACONDA_PYTHON_VERSION" = "3.9" ]; then + # DO NOT install cmake here as it would install a version newer than 3.13, but + # we want to pin to version 3.13. + CONDA_COMMON_DEPS="astunparse pyyaml mkl=2022.0.1 mkl-include=2022.0.1 setuptools cffi future six" + if [ "$ANACONDA_PYTHON_VERSION" = "3.10" ]; then # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source - conda_install numpy=1.19.2 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 + conda_install numpy=1.21.2 ${CONDA_COMMON_DEPS} llvmdev=8.0.0 + elif [ "$ANACONDA_PYTHON_VERSION" = "3.9" ]; then + # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source + conda_install numpy=1.19.2 ${CONDA_COMMON_DEPS} llvmdev=8.0.0 elif [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source - conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 - elif [ "$ANACONDA_PYTHON_VERSION" = "3.7" ]; then - # DO NOT install dataclasses if installing python-3.7, since its part of python-3.7 core packages - conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six typing_extensions + conda_install numpy=1.18.5 ${CONDA_COMMON_DEPS} llvmdev=8.0.0 else - conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six dataclasses typing_extensions + # Install `typing_extensions` for 3.7 + conda_install numpy=1.18.5 ${CONDA_COMMON_DEPS} typing_extensions fi # Magma package names are concatenation of CUDA major and minor ignoring revision diff --git a/.circleci/docker/common/install_rocm.sh b/.circleci/docker/common/install_rocm.sh index 4cda40bbdca52..ceebd7d606713 100644 --- a/.circleci/docker/common/install_rocm.sh +++ b/.circleci/docker/common/install_rocm.sh @@ -35,7 +35,7 @@ ver() { } # Map ROCm version to AMDGPU version -declare -A AMDGPU_VERSIONS=( ["4.5.2"]="21.40.2" ["5.0"]="21.50" ["5.1.1"]="22.10.1" ) +declare -A AMDGPU_VERSIONS=( ["5.0"]="21.50" ["5.1.1"]="22.10.1" ["5.2"]="22.20" ) install_ubuntu() { apt-get update diff --git a/.circleci/docker/requirements-ci.txt b/.circleci/docker/requirements-ci.txt index 2cb73341db78e..451bd39467c37 100644 --- a/.circleci/docker/requirements-ci.txt +++ b/.circleci/docker/requirements-ci.txt @@ -41,7 +41,7 @@ flatbuffers==2.0 #Pinned versions: #test that import: -hypothesis==4.53.2 +hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests #Pinned versions: 3.44.6, 4.53.2 @@ -86,11 +86,11 @@ mypy==0.960 #Pinned versions: 0.960 #test that import: test_typing.py, test_type_hints.py -#networkx +networkx==2.6.3 #Description: creation, manipulation, and study of #the structure, dynamics, and functions of complex networks -#Pinned versions: 2.0 -#test that import: +#Pinned versions: 2.6.3 (latest version that works with Python 3.7+) +#test that import: functorch #ninja #Description: build system. Note that it install from @@ -100,6 +100,7 @@ mypy==0.960 numba==0.49.0 ; python_version < "3.9" numba==0.54.1 ; python_version == "3.9" +numba==0.55.2 ; python_version == "3.10" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 #test that import: test_numba_integration.py @@ -143,6 +144,16 @@ pytest #Pinned versions: #test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py +pytest-xdist +#Description: plugin for running pytest in parallel +#Pinned versions: +#test that import: + +pytest-rerunfailures +#Description: plugin for rerunning tests in pytest +#Pinned versions: +#test that import: + #pytest-benchmark #Description: fixture for benchmarking code #Pinned versions: 3.2.3 @@ -178,7 +189,8 @@ scikit-image #Pinned versions: 0.20.3 #test that import: -scipy==1.6.3 +scipy==1.6.3 ; python_version < "3.10" +scipy==1.8.1 ; python_version == "3.10" # Pin SciPy because of failing distribution tests (see #60347) #Description: scientific python #Pinned versions: 1.6.3 diff --git a/.circleci/docker/ubuntu-cuda/Dockerfile b/.circleci/docker/ubuntu-cuda/Dockerfile index fe4aa0b5833ac..f7674987a0c3e 100644 --- a/.circleci/docker/ubuntu-cuda/Dockerfile +++ b/.circleci/docker/ubuntu-cuda/Dockerfile @@ -11,81 +11,85 @@ ENV DEBIAN_FRONTEND noninteractive # Install common dependencies (so that this step can be cached separately) ARG EC2 -ADD ./common/install_base.sh install_base.sh +COPY ./common/install_base.sh install_base.sh RUN bash ./install_base.sh && rm install_base.sh # Install user -ADD ./common/install_user.sh install_user.sh +COPY ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh # Install katex ARG KATEX -ADD ./common/install_docs_reqs.sh install_docs_reqs.sh +COPY ./common/install_docs_reqs.sh install_docs_reqs.sh RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh # Install conda and other packages (e.g., numpy, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION -ADD requirements-ci.txt /opt/conda/requirements-ci.txt -ADD ./common/install_conda.sh install_conda.sh +COPY requirements-ci.txt /opt/conda/requirements-ci.txt +COPY ./common/install_conda.sh install_conda.sh RUN bash ./install_conda.sh && rm install_conda.sh RUN rm /opt/conda/requirements-ci.txt # Install gcc ARG GCC_VERSION -ADD ./common/install_gcc.sh install_gcc.sh +COPY ./common/install_gcc.sh install_gcc.sh RUN bash ./install_gcc.sh && rm install_gcc.sh # Install clang ARG CLANG_VERSION -ADD ./common/install_clang.sh install_clang.sh +COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh # (optional) Install protobuf for ONNX ARG PROTOBUF -ADD ./common/install_protobuf.sh install_protobuf.sh +COPY ./common/install_protobuf.sh install_protobuf.sh RUN if [ -n "${PROTOBUF}" ]; then bash ./install_protobuf.sh; fi RUN rm install_protobuf.sh ENV INSTALLED_PROTOBUF ${PROTOBUF} # (optional) Install database packages like LMDB and LevelDB ARG DB -ADD ./common/install_db.sh install_db.sh +COPY ./common/install_db.sh install_db.sh RUN if [ -n "${DB}" ]; then bash ./install_db.sh; fi RUN rm install_db.sh ENV INSTALLED_DB ${DB} # (optional) Install vision packages like OpenCV and ffmpeg ARG VISION -ADD ./common/install_vision.sh install_vision.sh +COPY ./common/install_vision.sh install_vision.sh RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi RUN rm install_vision.sh ENV INSTALLED_VISION ${VISION} -ADD ./common/install_openssl.sh install_openssl.sh +COPY ./common/install_openssl.sh install_openssl.sh ENV OPENSSL_ROOT_DIR /opt/openssl RUN bash ./install_openssl.sh ENV OPENSSL_DIR /opt/openssl # (optional) Install non-default CMake version ARG CMAKE_VERSION -ADD ./common/install_cmake.sh install_cmake.sh +COPY ./common/install_cmake.sh install_cmake.sh RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi RUN rm install_cmake.sh # Install ccache/sccache (do this last, so we get priority in PATH) -ADD ./common/install_cache.sh install_cache.sh +COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH +# See https://github.com/pytorch/pytorch/issues/82174 +# TODO(sdym@fb.com): +# check if this is needed after full off Xenial migration +ENV CARGO_NET_GIT_FETCH_WITH_CLI true RUN bash ./install_cache.sh && rm install_cache.sh ENV CMAKE_CUDA_COMPILER_LAUNCHER=/opt/cache/bin/sccache # Add jni.h for java host build -ADD ./common/install_jni.sh install_jni.sh -ADD ./java/jni.h jni.h +COPY ./common/install_jni.sh install_jni.sh +COPY ./java/jni.h jni.h RUN bash ./install_jni.sh && rm install_jni.sh # Install Open MPI for CUDA -ADD ./common/install_openmpi.sh install_openmpi.sh +COPY ./common/install_openmpi.sh install_openmpi.sh RUN if [ -n "${CUDA_VERSION}" ]; then bash install_openmpi.sh; fi RUN rm install_openmpi.sh @@ -103,7 +107,7 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm # Install CUDNN ARG CUDNN_VERSION -ADD ./common/install_cudnn.sh install_cudnn.sh +COPY ./common/install_cudnn.sh install_cudnn.sh RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi RUN rm install_cudnn.sh diff --git a/.circleci/docker/ubuntu-rocm/Dockerfile b/.circleci/docker/ubuntu-rocm/Dockerfile index 2605928763633..a994b2e52f236 100644 --- a/.circleci/docker/ubuntu-rocm/Dockerfile +++ b/.circleci/docker/ubuntu-rocm/Dockerfile @@ -12,56 +12,56 @@ ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} # Install common dependencies (so that this step can be cached separately) ARG EC2 -ADD ./common/install_base.sh install_base.sh +COPY ./common/install_base.sh install_base.sh RUN bash ./install_base.sh && rm install_base.sh # Install clang ARG LLVMDEV ARG CLANG_VERSION -ADD ./common/install_clang.sh install_clang.sh +COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh # Install user -ADD ./common/install_user.sh install_user.sh +COPY ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh # Install conda and other packages (e.g., numpy, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION -ADD requirements-ci.txt /opt/conda/requirements-ci.txt -ADD ./common/install_conda.sh install_conda.sh +COPY requirements-ci.txt /opt/conda/requirements-ci.txt +COPY ./common/install_conda.sh install_conda.sh RUN bash ./install_conda.sh && rm install_conda.sh RUN rm /opt/conda/requirements-ci.txt # Install gcc ARG GCC_VERSION -ADD ./common/install_gcc.sh install_gcc.sh +COPY ./common/install_gcc.sh install_gcc.sh RUN bash ./install_gcc.sh && rm install_gcc.sh # (optional) Install protobuf for ONNX ARG PROTOBUF -ADD ./common/install_protobuf.sh install_protobuf.sh +COPY ./common/install_protobuf.sh install_protobuf.sh RUN if [ -n "${PROTOBUF}" ]; then bash ./install_protobuf.sh; fi RUN rm install_protobuf.sh ENV INSTALLED_PROTOBUF ${PROTOBUF} # (optional) Install database packages like LMDB and LevelDB ARG DB -ADD ./common/install_db.sh install_db.sh +COPY ./common/install_db.sh install_db.sh RUN if [ -n "${DB}" ]; then bash ./install_db.sh; fi RUN rm install_db.sh ENV INSTALLED_DB ${DB} # (optional) Install vision packages like OpenCV and ffmpeg ARG VISION -ADD ./common/install_vision.sh install_vision.sh +COPY ./common/install_vision.sh install_vision.sh RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi RUN rm install_vision.sh ENV INSTALLED_VISION ${VISION} # Install rocm ARG ROCM_VERSION -ADD ./common/install_rocm.sh install_rocm.sh +COPY ./common/install_rocm.sh install_rocm.sh RUN bash ./install_rocm.sh RUN rm install_rocm.sh ENV PATH /opt/rocm/bin:$PATH @@ -75,18 +75,18 @@ ENV LC_ALL C.UTF-8 # (optional) Install non-default CMake version ARG CMAKE_VERSION -ADD ./common/install_cmake.sh install_cmake.sh +COPY ./common/install_cmake.sh install_cmake.sh RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi RUN rm install_cmake.sh # (optional) Install non-default Ninja version ARG NINJA_VERSION -ADD ./common/install_ninja.sh install_ninja.sh +COPY ./common/install_ninja.sh install_ninja.sh RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi RUN rm install_ninja.sh # Install ccache/sccache (do this last, so we get priority in PATH) -ADD ./common/install_cache.sh install_cache.sh +COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH RUN bash ./install_cache.sh && rm install_cache.sh diff --git a/.circleci/docker/ubuntu/Dockerfile b/.circleci/docker/ubuntu/Dockerfile index 8e51102d0eaa3..22592534c20f0 100644 --- a/.circleci/docker/ubuntu/Dockerfile +++ b/.circleci/docker/ubuntu/Dockerfile @@ -10,45 +10,45 @@ ARG CLANG_VERSION # Install common dependencies (so that this step can be cached separately) ARG EC2 -ADD ./common/install_base.sh install_base.sh +COPY ./common/install_base.sh install_base.sh RUN bash ./install_base.sh && rm install_base.sh # Install clang ARG LLVMDEV -ADD ./common/install_clang.sh install_clang.sh +COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh # (optional) Install thrift. ARG THRIFT -ADD ./common/install_thrift.sh install_thrift.sh +COPY ./common/install_thrift.sh install_thrift.sh RUN if [ -n "${THRIFT}" ]; then bash ./install_thrift.sh; fi RUN rm install_thrift.sh ENV INSTALLED_THRIFT ${THRIFT} # Install user -ADD ./common/install_user.sh install_user.sh +COPY ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh # Install katex ARG KATEX -ADD ./common/install_docs_reqs.sh install_docs_reqs.sh +COPY ./common/install_docs_reqs.sh install_docs_reqs.sh RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh # Install conda and other packages (e.g., numpy, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION -ADD requirements-ci.txt /opt/conda/requirements-ci.txt -ADD ./common/install_conda.sh install_conda.sh +COPY requirements-ci.txt /opt/conda/requirements-ci.txt +COPY ./common/install_conda.sh install_conda.sh RUN bash ./install_conda.sh && rm install_conda.sh RUN rm /opt/conda/requirements-ci.txt # Install gcc ARG GCC_VERSION -ADD ./common/install_gcc.sh install_gcc.sh +COPY ./common/install_gcc.sh install_gcc.sh RUN bash ./install_gcc.sh && rm install_gcc.sh # Install lcov for C++ code coverage -ADD ./common/install_lcov.sh install_lcov.sh +COPY ./common/install_lcov.sh install_lcov.sh RUN bash ./install_lcov.sh && rm install_lcov.sh # Install cuda and cudnn @@ -60,21 +60,21 @@ ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH # (optional) Install protobuf for ONNX ARG PROTOBUF -ADD ./common/install_protobuf.sh install_protobuf.sh +COPY ./common/install_protobuf.sh install_protobuf.sh RUN if [ -n "${PROTOBUF}" ]; then bash ./install_protobuf.sh; fi RUN rm install_protobuf.sh ENV INSTALLED_PROTOBUF ${PROTOBUF} # (optional) Install database packages like LMDB and LevelDB ARG DB -ADD ./common/install_db.sh install_db.sh +COPY ./common/install_db.sh install_db.sh RUN if [ -n "${DB}" ]; then bash ./install_db.sh; fi RUN rm install_db.sh ENV INSTALLED_DB ${DB} # (optional) Install vision packages like OpenCV and ffmpeg ARG VISION -ADD ./common/install_vision.sh install_vision.sh +COPY ./common/install_vision.sh install_vision.sh RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi RUN rm install_vision.sh ENV INSTALLED_VISION ${VISION} @@ -83,9 +83,9 @@ ENV INSTALLED_VISION ${VISION} ARG ANDROID ARG ANDROID_NDK ARG GRADLE_VERSION -ADD ./common/install_android.sh install_android.sh -ADD ./android/AndroidManifest.xml AndroidManifest.xml -ADD ./android/build.gradle build.gradle +COPY ./common/install_android.sh install_android.sh +COPY ./android/AndroidManifest.xml AndroidManifest.xml +COPY ./android/build.gradle build.gradle RUN if [ -n "${ANDROID}" ]; then bash ./install_android.sh; fi RUN rm install_android.sh RUN rm AndroidManifest.xml @@ -94,46 +94,50 @@ ENV INSTALLED_ANDROID ${ANDROID} # (optional) Install Vulkan SDK ARG VULKAN_SDK_VERSION -ADD ./common/install_vulkan_sdk.sh install_vulkan_sdk.sh +COPY ./common/install_vulkan_sdk.sh install_vulkan_sdk.sh RUN if [ -n "${VULKAN_SDK_VERSION}" ]; then bash ./install_vulkan_sdk.sh; fi RUN rm install_vulkan_sdk.sh # (optional) Install swiftshader ARG SWIFTSHADER -ADD ./common/install_swiftshader.sh install_swiftshader.sh +COPY ./common/install_swiftshader.sh install_swiftshader.sh RUN if [ -n "${SWIFTSHADER}" ]; then bash ./install_swiftshader.sh; fi RUN rm install_swiftshader.sh # (optional) Install non-default CMake version ARG CMAKE_VERSION -ADD ./common/install_cmake.sh install_cmake.sh +COPY ./common/install_cmake.sh install_cmake.sh RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi RUN rm install_cmake.sh # (optional) Install non-default Ninja version ARG NINJA_VERSION -ADD ./common/install_ninja.sh install_ninja.sh +COPY ./common/install_ninja.sh install_ninja.sh RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi RUN rm install_ninja.sh -ADD ./common/install_openssl.sh install_openssl.sh +COPY ./common/install_openssl.sh install_openssl.sh RUN bash ./install_openssl.sh ENV OPENSSL_ROOT_DIR /opt/openssl ENV OPENSSL_DIR /opt/openssl RUN rm install_openssl.sh # Install ccache/sccache (do this last, so we get priority in PATH) -ADD ./common/install_cache.sh install_cache.sh +COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH +# See https://github.com/pytorch/pytorch/issues/82174 +# TODO(sdym@fb.com): +# check if this is needed after full off Xenial migration +ENV CARGO_NET_GIT_FETCH_WITH_CLI true RUN bash ./install_cache.sh && rm install_cache.sh # Add jni.h for java host build -ADD ./common/install_jni.sh install_jni.sh -ADD ./java/jni.h jni.h +COPY ./common/install_jni.sh install_jni.sh +COPY ./java/jni.h jni.h RUN bash ./install_jni.sh && rm install_jni.sh # Install Open MPI for CUDA -ADD ./common/install_openmpi.sh install_openmpi.sh +COPY ./common/install_openmpi.sh install_openmpi.sh RUN if [ -n "${CUDA_VERSION}" ]; then bash install_openmpi.sh; fi RUN rm install_openmpi.sh diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index e69515017f4f0..6e34b3e1e5f41 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -89,7 +89,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then cu_ver="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4}" CUDA_PACKAGE="cudatoolkit" - if [[ "$DESIRED_CUDA" == "cu116" ]]; then + if [[ "$DESIRED_CUDA" == "cu116" || "$DESIRED_CUDA" == "cu117" ]]; then CUDA_PACKAGE="cuda" fi diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index e6500b8d9c93d..be77e6483b7e2 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -6,7 +6,7 @@ mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" export CUDA_VERSION="${DESIRED_CUDA/cu/}" export USE_SCCACHE=1 -export SCCACHE_BUCKET=ossci-compiler-cache-windows +export SCCACHE_BUCKET=ossci-compiler-cache export SCCACHE_IGNORE_SERVER_IO_ERROR=1 export VC_YEAR=2019 diff --git a/.circleci/scripts/build_android_gradle.sh b/.circleci/scripts/build_android_gradle.sh index 4be715b4dbfee..598e9cd0a6bd2 100755 --- a/.circleci/scripts/build_android_gradle.sh +++ b/.circleci/scripts/build_android_gradle.sh @@ -78,7 +78,7 @@ if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then GRADLE_PARAMS+=" -PABI_FILTERS=x86" fi -if [ -n "{GRADLE_OFFLINE:-}" ]; then +if [ -n "${GRADLE_OFFLINE:-}" ]; then GRADLE_PARAMS+=" --offline" fi diff --git a/.circleci/scripts/setup_ci_environment.sh b/.circleci/scripts/setup_ci_environment.sh index 17baae08cda42..8ac4f5b43a9a2 100755 --- a/.circleci/scripts/setup_ci_environment.sh +++ b/.circleci/scripts/setup_ci_environment.sh @@ -32,7 +32,7 @@ if ! command -v aws >/dev/null; then fi if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then - DRIVER_FN="NVIDIA-Linux-x86_64-510.60.02.run" + DRIVER_FN="NVIDIA-Linux-x86_64-515.57.run" wget "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN" sudo /bin/bash "$DRIVER_FN" -s --no-drm || (sudo cat /var/log/nvidia-installer.log && false) nvidia-smi diff --git a/.circleci/scripts/windows_cuda_install.sh b/.circleci/scripts/windows_cuda_install.sh index f06a2b0ab0963..b75129b7e4929 100644 --- a/.circleci/scripts/windows_cuda_install.sh +++ b/.circleci/scripts/windows_cuda_install.sh @@ -14,6 +14,11 @@ case ${CUDA_VERSION} in cuda_installer_name="cuda_11.6.0_511.23_windows" cuda_install_packages="thrust_11.6 nvcc_11.6 cuobjdump_11.6 nvprune_11.6 nvprof_11.6 cupti_11.6 cublas_11.6 cublas_dev_11.6 cudart_11.6 cufft_11.6 cufft_dev_11.6 curand_11.6 curand_dev_11.6 cusolver_11.6 cusolver_dev_11.6 cusparse_11.6 cusparse_dev_11.6 npp_11.6 npp_dev_11.6 nvrtc_11.6 nvrtc_dev_11.6 nvml_dev_11.6" ;; + 11.7) + cuda_installer_name="cuda_11.7.0_516.01_windows" + cuda_install_packages="thrust_11.7 nvcc_11.7 cuobjdump_11.7 nvprune_11.7 nvprof_11.7 cupti_11.7 cublas_11.7 cublas_dev_11.7 cudart_11.7 cufft_11.7 cufft_dev_11.7 curand_11.7 curand_dev_11.7 cusolver_11.7 cusolver_dev_11.7 cusparse_11.7 cusparse_dev_11.7 npp_11.7 npp_dev_11.7 nvrtc_11.7 nvrtc_dev_11.7 nvml_dev_11.7" + ;; + *) echo "CUDA_VERSION $CUDA_VERSION is not supported yet" exit 1 diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index a815008ee1e0f..763bc950fc4be 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -16,6 +16,10 @@ case ${CUDA_VERSION} in # Use cudnn8.3 with hard-coded cuda11.5 version cudnn_file_name="cudnn-windows-x86_64-8.3.2.44_cuda11.5-archive" ;; + 11.7) + # Use cudnn8.3 with hard-coded cuda11.5 version + cudnn_file_name="cudnn-windows-x86_64-8.3.2.44_cuda11.5-archive" + ;; *) echo "CUDA_VERSION: ${CUDA_VERSION} not supported yet" exit 1 diff --git a/.flake8 b/.flake8 index fd251010b0446..75abc4d19049f 100644 --- a/.flake8 +++ b/.flake8 @@ -22,6 +22,9 @@ exclude = ./docs/caffe2, ./docs/cpp/src, ./docs/src, + ./functorch/docs, + ./functorch/examples, + ./functorch/notebooks, ./scripts, ./test/generated_type_hints_smoketest.py, ./third_party, diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 7d428014cd79c..fc203f1e0d6ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1 +1,8 @@ -Fixes #ISSUE_NUMBER +### Description + + +### Issue + + +### Testing + diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index d1e25f698f141..4b5afb13f3675 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -12,5 +12,7 @@ self-hosted-runner: - windows.8xlarge.nvidia.gpu - bm-runner - linux.rocm.gpu + - macos-m1-12 + - macos-12-xl - macos-12 - macos12.3-m1 diff --git a/.github/actions/build-android/action.yml b/.github/actions/build-android/action.yml index 9016eb1beff58..5233b62cef0e3 100644 --- a/.github/actions/build-android/action.yml +++ b/.github/actions/build-android/action.yml @@ -37,7 +37,6 @@ runs: shell: bash env: BRANCH: ${{ inputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-build-and-test BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-${{ inputs.arch-for-build-env }}-build" AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} @@ -51,7 +50,6 @@ runs: export container_name container_name=$(docker run \ -e BUILD_ENVIRONMENT \ - -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ diff --git a/.github/actions/calculate-docker-image/action.yml b/.github/actions/calculate-docker-image/action.yml index db4f3d40cca82..7215bf84e987c 100644 --- a/.github/actions/calculate-docker-image/action.yml +++ b/.github/actions/calculate-docker-image/action.yml @@ -17,9 +17,16 @@ inputs: pull: description: If set to any value, run `docker pull`` on the calculated image. required: false + skip_push: + description: If set to true value, skip will be pushed, default is to skip so that pushing will be explicit + required: false + default: "true" force_push: description: If set to any value, always run the push required: false + push-ghcr-image: + description: If set to any value, push docker image to the ghcr.io. + required: false outputs: docker-image: @@ -34,7 +41,7 @@ runs: id: calculate-tag env: IS_XLA: ${{ inputs.xla == 'true' && 'true' || '' }} - XLA_IMAGE_TAG: v0.2 + XLA_IMAGE_TAG: v0.4 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/${{ inputs.docker-image-name }} run: | if [ -n "${IS_XLA}" ]; then @@ -96,8 +103,11 @@ runs: env: IMAGE_NAME: ${{inputs.docker-image-name}} DOCKER_SKIP_S3_UPLOAD: "1" - DOCKER_SKIP_PUSH: ${{ steps.check.outputs.skip_push || 'false' }} + # Skip push if we don't need it, or if specified in the inputs + DOCKER_SKIP_PUSH: ${{ steps.check.outputs.skip_push || inputs.skip_push }} DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker-tag }} + PUSH_GHCR_IMAGE: ${{ inputs.push-ghcr-image }} + GHCR_PAT: ${{ env.GHCR_PAT }} working-directory: .circleci/docker shell: bash run: | diff --git a/.github/actions/pull-docker-image/action.yml b/.github/actions/pull-docker-image/action.yml index 6ec67088d7a0e..75e8baf6f2c9f 100644 --- a/.github/actions/pull-docker-image/action.yml +++ b/.github/actions/pull-docker-image/action.yml @@ -17,6 +17,7 @@ runs: run: | retry () { "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") } # ignore output since only exit code is used for conditional - if docker inspect --type=image "${DOCKER_IMAGE}" >/dev/null 2>/dev/null; then + # only pull docker image if it's not available locally + if ! docker inspect --type=image "${DOCKER_IMAGE}" >/dev/null 2>/dev/null; then retry docker pull "${DOCKER_IMAGE}" fi diff --git a/.github/actions/test-pytorch-binary/action.yml b/.github/actions/test-pytorch-binary/action.yml new file mode 100644 index 0000000000000..bc2c546f57b28 --- /dev/null +++ b/.github/actions/test-pytorch-binary/action.yml @@ -0,0 +1,41 @@ +name: Test pytorch binary + +description: Pulls the docker image and tests the pytorch binary using it. All env variable referenced in the "Test PyTorch binary" step must be set in the GITHUB_ENV file + +runs: + using: composite + steps: + - name: Test PyTorch binary + shell: bash + run: | + set -x + # shellcheck disable=SC2086,SC2090 + container_name=$(docker run \ + ${GPU_FLAG:-} \ + -e BINARY_ENV_FILE \ + -e BUILDER_ROOT \ + -e BUILD_ENVIRONMENT \ + -e BUILD_SPLIT_CUDA \ + -e DESIRED_CUDA \ + -e DESIRED_DEVTOOLSET \ + -e DESIRED_PYTHON \ + -e GITHUB_ACTIONS \ + -e GPU_ARCH_TYPE \ + -e GPU_ARCH_VERSION \ + -e LIBTORCH_VARIANT \ + -e PACKAGE_TYPE \ + -e PYTORCH_FINAL_PACKAGE_DIR \ + -e PYTORCH_ROOT \ + -e SKIP_ALL_TESTS \ + --tty \ + --detach \ + -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ + -v "${GITHUB_WORKSPACE}/builder:/builder" \ + -v "${RUNNER_TEMP}/artifacts:/final_pkgs" \ + -w / \ + "${DOCKER_IMAGE}" + ) + docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" + # Generate test script + docker exec -t -w "${PYTORCH_ROOT}" -e OUTPUT_SCRIPT="/run.sh" "${container_name}" bash -c "bash .circleci/scripts/binary_linux_test.sh" + docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash -x /run.sh" diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index b335f12206d70..35e249ea96be3 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -36,6 +36,20 @@ runs: rm -f test-reports-*.zip zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml' + - name: Zip usage log for upload + if: runner.os != 'Windows' && !inputs.use-gha + shell: bash + env: + FILE_SUFFIX: ${{ inputs.file-suffix }} + run: | + # Remove any previous test reports if they exist + rm -f usage-log-*.zip + # this workflow is also run in bazel build test, but we dont generate usage reports for it + # so check to see if the file exists first + if [ -f 'usage_log.txt' ]; then + zip "usage-log-${FILE_SUFFIX}.zip" 'usage_log.txt' + fi + # Windows zip - name: Zip JSONs for upload if: runner.os == 'Windows' && !inputs.use-gha @@ -55,6 +69,15 @@ runs: # -ir => recursive include all files in pattern 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' + - name: Zip usage log for upload + if: runner.os == 'Windows' && !inputs.use-gha + shell: powershell + env: + FILE_SUFFIX: ${{ inputs.file-suffix }} + run: | + # -ir => recursive include all files in pattern + 7z a "usage-log-$Env:FILE_SUFFIX.zip" 'usage_log.txt' + # S3 upload - name: Store Test Downloaded JSONs on S3 uses: seemethere/upload-artifact-s3@v5 @@ -76,6 +99,16 @@ runs: if-no-files-found: error path: test-reports-*.zip + - name: Store Usage Logs on S3 + uses: seemethere/upload-artifact-s3@v5 + if: ${{ !inputs.use-gha }} + with: + s3-prefix: | + ${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact + retention-days: 14 + if-no-files-found: ignore + path: usage-log-*.zip + # GHA upload - name: Store Test Downloaded JSONs on Github uses: actions/upload-artifact@v2 diff --git a/.github/ci_commit_pins/torchdynamo.txt b/.github/ci_commit_pins/torchdynamo.txt index 5c726523ee65c..4bca66289a606 100644 --- a/.github/ci_commit_pins/torchdynamo.txt +++ b/.github/ci_commit_pins/torchdynamo.txt @@ -1 +1 @@ -dbb83776c5de185a7a8e2dbafa41d4dc9552d2f9 +a43631c54014b2e68a09b39658cbf515875394f6 diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index c5251d96a3220..7406c775afed2 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -8a2dc6f22ac4389ccba8859aa1e1cb14f1ee53db +1a1d509c8e6578584e7e9e4bd442654bf39149c8 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 260d4f68f3c24..6f0f5eab8182e 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -02346af955653a179b896eef5475e569ab8d4229 +73c64a55fb096f1e132029d3decbb6f4e532cc7b diff --git a/.github/merge_rules.json b/.github/merge_rules.json index f056c89cb6ef7..704e1a5d96509 100644 --- a/.github/merge_rules.json +++ b/.github/merge_rules.json @@ -3,11 +3,12 @@ "name": "ONNX exporter", "patterns": [ ".jenkins/caffe2/*", - "scripts/onnx/**", + "aten/src/ATen/core/interned_strings.h", "docs/source/onnx.rst", - "test/onnx/**", + "docs/source/scripts/onnx/**", + "scripts/onnx/**", "test/jit/test_export_modes.py", - "aten/src/ATen/core/interned_strings.h", + "test/onnx/**", "tools/onnx/**", "torch/_C/__init__.pyi.in", "torch/csrc/jit/passes/onnx.*", @@ -51,32 +52,32 @@ }, { "name": "CI Pinned Hashes", - "patterns": [".github/ci_commit_pins/**"], + "patterns": [ + ".github/ci_commit_pins/vision.txt", + ".github/ci_commit_pins/torchdynamo.txt" + ], "approved_by": ["pytorchbot", "ezyang", "pytorch/pytorch-dev-infra"], "mandatory_checks_name": [ "Facebook CLA Check", "Lint", - "linux-docs / build-docs (cpp)", - "linux-docs / build-docs (python)", - "win-vs2019-cpu-py3 / build", - "win-vs2019-cuda11.3-py3 / build", - "linux-bionic-py3.7-clang9 / build", - "linux-xenial-py3.7-clang7-onnx / build", - "linux-xenial-py3.7-clang7-asan / build", - "linux-vulkan-bionic-py3.7-clang9 / build", - "linux-xenial-cuda11.3-py3.7-gcc7 / build", - "linux-bionic-cuda11.3-py3.7-clang9 / build", - "linux-xenial-py3-clang5-mobile-build / build", - "linux-xenial-py3-clang5-mobile-custom-build-static / build", - "pytorch-xla-linux-bionic-py3.7-clang8 / build", - "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", - "linux-focal-py3.7-gcc7 / build" + "pull" + ] + }, + { + "name": "XLA hash pin update", + "patterns": [".github/ci_commit_pins/xla.txt"], + "approved_by": ["pytorchbot", "ezyang", "pytorch/pytorch-dev-infra"], + "mandatory_checks_name": [ + "Facebook CLA Check", + "Lint", + "pull / linux-bionic-py3_7-clang8-xla / build", + "pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)" ] }, { "name": "Documentation", "patterns": ["docs/**", "torch/*docs.py"], - "approved_by": ["mruberry", "ngimel", "janeyx99"], + "approved_by": ["mruberry", "ngimel", "janeyx99", "svekars"], "mandatory_checks_name": [ "Facebook CLA Check", "Lint", @@ -98,12 +99,12 @@ "patterns": [ "aten/src/ATen/native/cuda/linalg/**", "aten/src/ATen/LinalgBackend.h", - "aten/src/ATen/native/**/*LinearAlgebra*", + "aten/src/ATen/native/**LinearAlgebra*", "docs/source/linalg.rst", "torch/linalg/**", "torch/_linalg_utils.py", - "torch/**/python_linalg_functions.*", - "torch/**/linalg.h", + "torch/**python_linalg_functions.*", + "torch/**linalg.h", "tools/autograd/templates/python_linalg_functions.cpp", "test/test_linalg.py" ], @@ -124,7 +125,7 @@ "docs/source/fft.rst", "torch/fft/**", "torch/csrc/api/include/torch/fft.h", - "torch/**/python_fft_functions.*", + "torch/**python_fft_functions.*", "tools/autograd/templates/python_fft_functions.cpp", "test/cpp/api/fft.cpp" ], @@ -141,18 +142,18 @@ "benchmarks/sparse", "c10/util/sparse_bitset.h", "docs/source/sparse.rst", - "torch/**/sparse/**", - "torch/**/*sparse*", + "torch/**sparse/**", + "torch/**sparse*", "torch/optim/sparse*", "torch/ao/nn/sparse/**", - "torch/utils/benchmark/**/*sparse*", + "torch/utils/benchmark/**sparse*", "aten/src/ATen/native/ao_sparse/**", "aten/src/ATen/native/sparse/**", - "aten/src/ATen/**/*Sparse*", + "aten/src/ATen/**Sparse*", "aten/src/ATen/*Sparse*", "torch/_masked/**", "test/*_masked*", - "test/**/*sparse*" + "test/**sparse*" ], "approved_by": ["nikitaved", "cpuhrsch", "pearu", "IvanYashchuk"], "mandatory_checks_name": [ @@ -165,6 +166,7 @@ "name": "MPS", "patterns": [ "test/test_mps.py", + "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/mps/**", "aten/src/ATen/native/mps/**" ], @@ -175,6 +177,19 @@ "pull" ] }, + { + "name": "Distributions", + "patterns": [ + "torch/distributions/**", + "test/distributions/**" + ], + "approved_by": ["fritzo", "neerajprad", "alicanb", "vishwakftw"], + "mandatory_checks_name": [ + "Facebook CLA Check", + "Lint", + "pull" + ] + }, { "name": "Distributed", "patterns": [ diff --git a/.github/scripts/print_latest_commits.py b/.github/scripts/fetch_latest_green_commit.py similarity index 93% rename from .github/scripts/print_latest_commits.py rename to .github/scripts/fetch_latest_green_commit.py index c83ee3cbe8f81..c9bb4830ab722 100644 --- a/.github/scripts/print_latest_commits.py +++ b/.github/scripts/fetch_latest_green_commit.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Dict, List, NamedTuple, Tuple from gitutils import _check_output @@ -5,6 +6,9 @@ import os import re +def eprint(msg: str) -> None: + print(msg, file=sys.stderr) + class WorkflowCheck(NamedTuple): workflowName: str name: str @@ -68,7 +72,6 @@ def isGreen(commit: str, results: Dict[str, Any]) -> Tuple[bool, str]: "trunk": False, "lint": False, "linux-binary": False, - "android-tests": False, "windows-binary": False, } @@ -92,8 +95,13 @@ def isGreen(commit: str, results: Dict[str, Any]) -> Tuple[bool, str]: def get_latest_green_commit(commits: List[str], results: Dict[str, Any]) -> Any: for commit in commits: - if isGreen(commit, results)[0]: + eprint(f"Checking {commit}") + is_green, msg = isGreen(commit, results) + if is_green: + eprint("GREEN") return commit + else: + eprint("RED: " + msg) return None def main() -> None: diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 1d81f72edd8e0..4549a16f7a808 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -13,7 +13,7 @@ from typing import Dict, List, Tuple, Optional -CUDA_ARCHES = ["10.2", "11.3", "11.6"] +CUDA_ARCHES = ["10.2", "11.3", "11.6", "11.7"] ROCM_ARCHES = ["5.0", "5.1.1"] @@ -183,10 +183,15 @@ def generate_wheels_matrix(os: str, if python_versions is None: # Define default python version - python_versions = FULL_PYTHON_VERSIONS + python_versions = list(FULL_PYTHON_VERSIONS) if os == "macos-arm64": python_versions = list_without(python_versions, ["3.7"]) + if os == "linux": + # NOTE: We only build 3.11 wheel on linux as 3.11 is not + # available on conda right now + python_versions.append("3.11") + if arches is None: # Define default compute archivectures arches = ["cpu"] @@ -201,6 +206,9 @@ def generate_wheels_matrix(os: str, for arch_version in arches: gpu_arch_type = arch_type(arch_version) gpu_arch_version = "" if arch_version == "cpu" else arch_version + # Skip rocm 3.11 binaries for now as the docker image are not correct + if python_version == "3.11" and gpu_arch_type == "rocm": + continue ret.append( { "python_version": python_version, diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py index 4c43fc251fb10..aa64fe15387e8 100644 --- a/.github/scripts/gitutils.py +++ b/.github/scripts/gitutils.py @@ -305,8 +305,8 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any: """ pattern is glob-like, i.e. the only special sequences it has are: - ? - matches single character - - * - matches any non-folder separator characters - - ** - matches any characters + - * - matches any non-folder separator characters or no character + - ** - matches any characters or no character Assuming that patterns are free of braces and backslashes the only character that needs to be escaped are dot and plus """ @@ -324,9 +324,9 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any: elif c == "*": if pattern_.peek() == "*": next(pattern_) - rc += ".+" + rc += ".*" else: - rc += "[^/]+" + rc += "[^/]*" else: rc += c rc += ")" diff --git a/.github/scripts/gql_mocks.json b/.github/scripts/gql_mocks.json index 641eeafa19089..b146600f936af 100644 --- a/.github/scripts/gql_mocks.json +++ b/.github/scripts/gql_mocks.json @@ -1,5 +1,5 @@ { - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=73811 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=73811 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -401,6 +401,15 @@ "startCursor": "Y3Vyc29yOnYyOpHOP6yFeQ==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + } + ] } } } @@ -2018,11 +2027,11 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=31093 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=31093 owner=pytorch": { "data": { "repository": { "pullRequest": { - "closed": false, + "closed": true, "isCrossRepository": true, "author": { "login": "mingxiaoh" @@ -2525,6 +2534,30 @@ "startCursor": "Y3Vyc29yOnYyOpHOKCmhXQ==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Stale" + } + } + ] } } } @@ -3013,254 +3046,964 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=76118 owner=pytorch": { + "query_sha=2dc8bfb6750c4a2402124dc53123d266427c0b92d06add20e3221b57a0f5268f commit=6882717f73deffb692219ccd1fd6db258d8ed684 name=pytorch owner=pytorch": { "data": { "repository": { - "pullRequest": { - "closed": false, - "isCrossRepository": false, - "author": { - "login": "malfet" - }, - "title": "Dummy change with lots of commits", - "body": "Draft PR with 100+ commits, to test mergebot ", - "headRefName": "malfet/pr-with-lots-of-commits", - "headRepository": { - "nameWithOwner": "pytorch/pytorch" - }, - "baseRefName": "master", - "baseRepository": { - "nameWithOwner": "pytorch/pytorch", - "isPrivate": false, - "defaultBranchRef": { - "name": "master" - } - }, - "mergeCommit": null, - "commits_with_authors": { - "nodes": [ - { - "commit": { - "author": { - "user": { - "login": "malfet" - }, - "email": "nshulga@fb.com", - "name": "Nikita Shulga" - }, - "oid": "3067f2240afc7a29dc348000aa19eccbd9772303" - } - }, - { - "commit": { - "author": { - "user": { - "login": "andrewor14" - }, - "email": "andrewor@fb.com", - "name": "Andrew Or" + "object": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "2f655b71f70c496c4e645f6cdb27d7bb7e825701" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625272" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hng=" }, { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 }, - "oid": "0c6dcaa7f58a19c42a530f4ee14bb6f0f03ca9fb" - } - }, - { - "commit": { - "author": { - "user": { - "login": "dzdang" - }, - "email": "dzdang@umich.edu", - "name": "dzdang" + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "cad11c563d41ebcffb1683fe1f1288b8157413b3" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625297" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hpE=" }, { - "commit": { - "author": { - "user": { - "login": "alanwaketan" - }, - "email": "jwtan@fb.com", - "name": "Jiewen Tan" + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 }, - "oid": "4dfd0875a68d87fccb5ad0d81692db480043b86e" - } - }, - { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "2d37e74690582a4a26890e4c8b98f1f80e589c82" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625308" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hpw=" }, { - "commit": { - "author": { - "user": { - "login": "alanwaketan" - }, - "email": "jwtan@fb.com", - "name": "Jiewen Tan" + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 }, - "oid": "d4aee60947e1a3ef23c7c42990621e0746fdd0a8" - } - }, - { - "commit": { - "author": { - "user": { - "login": "peterbell10" - }, - "email": "peterbell10@live.co.uk", - "name": "Peter Bell" + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "aac6204bf710beb5e50a383d426ae6222396335a" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625328" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hrA=" }, { - "commit": { - "author": { - "user": { - "login": "dzdang" - }, - "email": "dzdang@umich.edu", - "name": "dzdang" + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 }, - "oid": "4b0362cab884584c24f5834b3874f5f357f56b5d" - } - }, - { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "7536df613cbc645a9e68e6a3b0a8450753260fd1" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625347" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hsM=" }, { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 }, - "oid": "20a50cb966d28d7bf82924adf781cf72a01ef90e" - } - }, - { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } }, - "oid": "486387e8644afb46edff5aa5925b55c8119f67f0" - } + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625357" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hs0=" }, { - "commit": { - "author": { - "user": { - "login": "dzdang" - }, - "email": "dzdang@umich.edu", - "name": "dzdang" + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 }, - "oid": "acb9d78b9b732d3667b881727e6ed9f92a8c549f" - } - }, - { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "workflowRun": { + "workflow": { + "name": "Lint" + } }, - "oid": "683bb7959a5b973f8470c081ad02e8fc508e784a" - } - }, - { - "commit": { - "author": { - "user": { - "login": "qihqi" - }, - "email": "qihan@fb.com", - "name": "Han Qi" + "checkRuns": { + "nodes": [ + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257521878?check_suite_focus=true" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257521941?check_suite_focus=true" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522171?check_suite_focus=true" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522418?check_suite_focus=true" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522648?check_suite_focus=true" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522731?check_suite_focus=true" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522798?check_suite_focus=true" + }, + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523046?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbCVA2Y=", + "hasNextPage": false + } }, - "oid": "a870cb40af65adf0b77d55f6b554d7093d284d7a" - } + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625464" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1hzg=" }, { - "commit": { - "author": { - "user": { - "login": "Krovatkin" - }, - "email": "korovaikon@gmail.com", - "name": "Nikolay Korovaiko" + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 }, - "oid": "70793b9f328ddf52cc86336104c3a064c8582ef4" - } - }, - { - "commit": { - "author": { - "user": { - "login": "suo" - }, - "email": "suo@fb.com", - "name": "Michael Suo" + "workflowRun": { + "workflow": { + "name": "trunk" + } }, - "oid": "f70b31f62b1c5159eef2725484b175983517c88c" - } - }, - { - "commit": { - "author": { - "user": { - "login": "dagitses" - }, - "email": "mikeyd@fb.com", - "name": "Michael Andreas Dagitses" + "checkRuns": { + "nodes": [ + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522494?check_suite_focus=true" + }, + { + "name": "android-emulator-build-test / build-and-test", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522741?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522887?check_suite_focus=true" + }, + { + "name": "macos-10-15-py3-arm64 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523057?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523301?check_suite_focus=true" + }, + { + "name": "ios-12-5-1-x86-64 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523681?check_suite_focus=true" + }, + { + "name": "libtorch-linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523926?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524141?check_suite_focus=true" + }, + { + "name": "libtorch-linux-xenial-cuda10.2-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524423?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9-slow / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524568?check_suite_focus=true" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524710?check_suite_focus=true" + }, + { + "name": "macos-10-15-py3-lite-interpreter-x86-64 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524925?check_suite_focus=true" + }, + { + "name": "macos-11-py3-x86-64 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525196?check_suite_focus=true" + }, + { + "name": "caffe2-linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525344?check_suite_focus=true" + }, + { + "name": "parallelnative-linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525621?check_suite_focus=true" + }, + { + "name": "parallelnative-linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257748822?check_suite_focus=true" + }, + { + "name": "parallelnative-linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257748937?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9-slow / test (slow, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257940181?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996123?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996266?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (slow, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996436?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (nogpu_AVX512, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996598?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996687?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (jit_legacy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996800?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996869?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda10.2-py3.9-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257996947?check_suite_focus=true" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258043565?check_suite_focus=true" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258043644?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (default, 1, 5, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258043840?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (default, 2, 5, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258043904?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (default, 3, 5, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258043967?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (default, 4, 5, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258044051?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (default, 5, 5, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258044125?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258044194?check_suite_focus=true" + }, + { + "name": "macos-12.3-py3.8-arm64-test / Run MPS tests", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258358668?check_suite_focus=true" + }, + { + "name": "macos-11-py3-x86-64 / test (default, 1, 2, macos-12)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258757994?check_suite_focus=true" + }, + { + "name": "macos-11-py3-x86-64 / test (default, 2, 2, macos-12)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258758076?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbCn27w=", + "hasNextPage": false + } }, - "oid": "04d3ec1db60defe1c6904bf77e9f8dfa87dc0b63" - } + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625556" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1h5Q=" }, { - "commit": { - "author": { - "user": null, - "email": "mruberry@devfair044.h1.fair", - "name": "Mike Ruberry" + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 }, - "oid": "46b754a55b63e3168ad5854ad412c124934b675d" + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-rocm5.1-py3.7", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522250?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522456?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522650?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257522894?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523070?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523312?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523709?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257523936?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524138?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524427?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524554?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524720?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257524938?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525212?check_suite_focus=true" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525332?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525623?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525714?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257525946?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257526187?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257526402?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257526593?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257688277?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257759879?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257760015?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257760116?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257760245?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257760346?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257760456?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257909951?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257909994?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257912956?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934535?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934615?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934714?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934784?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934866?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257934975?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257935092?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257935201?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943077?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943146?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943200?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943268?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943319?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257943373?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257960183?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7257960282?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258020141?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258020221?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258020306?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbCcmdI=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/6882717f73deffb692219ccd1fd6db258d8ed684/checks?check_suite_id=7280625557" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbH1h5U=" + } + ], + "pageInfo": { + "hasNextPage": false + } + } + } + } + } + }, + "query_sha=23d6a47e5fd875c42231779040ec1d35d0042b502c9142cb0d33d6f65d58fead commit=6882717f73deffb692219ccd1fd6db258d8ed684 cr_cursor=Y3Vyc29yOnYyOpHPAAAAAbCcmdI= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAbH1h5Q= name=pytorch owner=pytorch": { + "data": { + "repository": { + "object": { + "oid": "6882717f73deffb692219ccd1fd6db258d8ed684", + "checkSuites": { + "nodes": [ + { + "checkRuns": { + "nodes": [ + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258020388?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258020493?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7258219463?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbCfo8c=", + "hasNextPage": false + } } - }, + } + ] + } + } + } + } + }, + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=76118 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "malfet" + }, + "title": "Dummy change with lots of commits", + "body": "Draft PR with 100+ commits, to test mergebot ", + "headRefName": "malfet/pr-with-lots-of-commits", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ { "commit": { "author": { "user": { - "login": "robieta" + "login": "malfet" }, - "email": "taylorrobie@fb.com", - "name": "Taylor Robie" + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "3067f2240afc7a29dc348000aa19eccbd9772303" + } + }, + { + "commit": { + "author": { + "user": { + "login": "andrewor14" + }, + "email": "andrewor@fb.com", + "name": "Andrew Or" + }, + "oid": "2f655b71f70c496c4e645f6cdb27d7bb7e825701" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "0c6dcaa7f58a19c42a530f4ee14bb6f0f03ca9fb" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "cad11c563d41ebcffb1683fe1f1288b8157413b3" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "jwtan@fb.com", + "name": "Jiewen Tan" + }, + "oid": "4dfd0875a68d87fccb5ad0d81692db480043b86e" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "2d37e74690582a4a26890e4c8b98f1f80e589c82" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "jwtan@fb.com", + "name": "Jiewen Tan" + }, + "oid": "d4aee60947e1a3ef23c7c42990621e0746fdd0a8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "peterbell10" + }, + "email": "peterbell10@live.co.uk", + "name": "Peter Bell" + }, + "oid": "aac6204bf710beb5e50a383d426ae6222396335a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "4b0362cab884584c24f5834b3874f5f357f56b5d" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "7536df613cbc645a9e68e6a3b0a8450753260fd1" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "20a50cb966d28d7bf82924adf781cf72a01ef90e" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "486387e8644afb46edff5aa5925b55c8119f67f0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "acb9d78b9b732d3667b881727e6ed9f92a8c549f" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "683bb7959a5b973f8470c081ad02e8fc508e784a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "qihqi" + }, + "email": "qihan@fb.com", + "name": "Han Qi" + }, + "oid": "a870cb40af65adf0b77d55f6b554d7093d284d7a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "Krovatkin" + }, + "email": "korovaikon@gmail.com", + "name": "Nikolay Korovaiko" + }, + "oid": "70793b9f328ddf52cc86336104c3a064c8582ef4" + } + }, + { + "commit": { + "author": { + "user": { + "login": "suo" + }, + "email": "suo@fb.com", + "name": "Michael Suo" + }, + "oid": "f70b31f62b1c5159eef2725484b175983517c88c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dagitses" + }, + "email": "mikeyd@fb.com", + "name": "Michael Andreas Dagitses" + }, + "oid": "04d3ec1db60defe1c6904bf77e9f8dfa87dc0b63" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "46b754a55b63e3168ad5854ad412c124934b675d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "robieta" + }, + "email": "taylorrobie@fb.com", + "name": "Taylor Robie" }, "oid": "13df69e13ee571fdd716139419a00aec47ade7d6" } @@ -5045,15 +5788,6 @@ }, "comments": { "nodes": [ - { - "bodyText": "Merge failed due to Matched rule superuser, but it was not reviewed yet by any of:hongxiayang,janeyx99,mehdimashayekhi,tvalentius,yidawang-oss, ...", - "author": { - "login": "pytorchmergebot" - }, - "authorAssociation": "MEMBER", - "editor": null, - "databaseId": 1104214220 - }, { "bodyText": "Merge failed due to Matched rule superuser, but it was not reviewed yet by any of:zou3519,abhikrish,mehtanirav,wconstab,lc0, ...", "author": { @@ -5089,12 +5823,35 @@ "authorAssociation": "MEMBER", "editor": null, "databaseId": 1104379712 + }, + { + "bodyText": "Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale. Feel free to remove the Stale label if you feel this was a mistake. If you are unable to remove the Stale label please contact a maintainer in order to do so. If you want the bot to never mark this PR stale again, add the no-stale label.Stale pull requests will automatically be closed after 30 days of inactivity.", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1160658699 } ], "pageInfo": { - "startCursor": "Y3Vyc29yOnYyOpHOQdD4zA==", + "startCursor": "Y3Vyc29yOnYyOpHOQdD9Sg==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Stale" + } + } + ] } } } @@ -5482,7 +6239,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=76123 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=76123 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -7019,6 +7776,20 @@ "startCursor": "Y3Vyc29yOnYyOpHOQqri9w==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: distributed" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] } } } @@ -7094,7 +7865,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=71759 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=71759 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -7915,21 +8686,50 @@ "startCursor": "Y3Vyc29yOnYyOpHOPoR4Lg==", "hasPreviousPage": true } - } - } - } - } - }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=75095 owner=pytorch": { - "data": { - "repository": { - "pullRequest": { - "closed": true, - "isCrossRepository": false, - "author": { - "login": "mruberry" }, - "title": "Initial prims, references, and test architecture for them", + "labels": { + "edges": [ + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "release notes: nn" + } + }, + { + "node": { + "name": "topic: performance" + } + } + ] + } + } + } + } + }, + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=75095 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "mruberry" + }, + "title": "Initial prims, references, and test architecture for them", "body": "This PR adds an initial set of experimental primitive operations and Python references that reimplement existing PyTorch operations using them. See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 for additional context.\r\n\r\nThe following experimental primitives are added:\r\n\r\n- Elementwise unary prims -- abs, acos, acosh, asin, atan, cos, cosh, bessel_i0e, bessel_i1e, cbrt, ceil, digamma, erf, erf_inv, erfc, exp, expm1, floor, igamma, igammac, is_finite, lgamma, log, log1p, neg, reciprocal, round, sign, sinh, sqrt, square, tan. \r\n- Elementwise binary prims -- add, atan2, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, div, eq, ge, gt, le, lt, max, min, mul, ne, nextafter, pow, rsqrt, shift_left, shift_right_arithmetic\r\n- View prims -- brodcast_in_dim, collapse_view, split_dim, squeeze\r\n- Shape prims -- collapse, concatenate, reshape\r\n- Conditional prims -- select\r\n- Data conversion & movement prims -- convert_element_type, device_put\r\n- Inplace prims -- copy_to, resize\r\n\r\nThese primitives do not add any new functionality to PyTorch, but are intended to be the semantic building blocks for reference operators. We have tried to make them consistent with the operations in [jax.lax](https://jax.readthedocs.io/en/latest/jax.lax.html) where possible (because PyTorch prefers being consistent with other frameworks), although there are key differences between these prims and operations in jax.lax. Most notably is that these prims model view semantics and inplace operations.\r\n\r\nIn addition to these primitives the following elementwise binary Python references are added:\r\n\r\n- Elementwise binary Python references -- add, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_right_shift, bitwise_xor, eq, float_power, ge, gt, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide\r\n- Conditional Python references - where\r\n- Data conversion & movement references - copy_to\r\n\r\nA Python reference implements the same behavior as its corresponding PyTorch operator (excepting slight numerical differences, bug fixes, and in some cases additional features). \r\n\r\nThe start of an OpInfo-based test architecture for these references is also included in this PR. A new list, `python_ref_db`, is added to `common_methods_invocations.py`. This list introduces the new `ElementwiseBinaryPythonRefInfo`, which inherits input arguments from the original operators' OpInfo, allows them to be overridden, and then constructs the OpInfo for the Python reference using the (potentially modified) arguments. OpInfo-based tests can opt-into testing references by including this new list in the Sequence passed to the `@ops` decorator. \r\n\r\ncc @ngimel @csarofeen @kevinstephano @Lezcano ", "headRefName": "prims_and_references", "headRepository": { @@ -9329,12 +10129,31 @@ "startCursor": "Y3Vyc29yOnYyOpHOQebHmg==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "topic: not user facing" + } + }, + { + "node": { + "name": "module: primTorch" + } + } + ] } } } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=77700 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=77700 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -10075,6 +10894,20 @@ "startCursor": "Y3Vyc29yOnYyOpHOQ1FKZg==", "hasPreviousPage": false } + }, + "labels": { + "edges": [ + { + "node": { + "name": "Merged" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] } } } @@ -10136,7 +10969,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=68111 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=68111 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -11821,6 +12654,40 @@ "startCursor": "Y3Vyc29yOnYyOpHOQAuLsw==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: jit" + } + }, + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Reverted" + } + }, + { + "node": { + "name": "intel priority" + } + } + ] } } } @@ -12016,7 +12883,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=73969 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=73969 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -12523,6 +13390,20 @@ "startCursor": "Y3Vyc29yOnYyOpHOP11MjQ==", "hasPreviousPage": false } + }, + "labels": { + "edges": [ + { + "node": { + "name": "fb-exported" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] } } } @@ -13054,7 +13935,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=73099 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=73099 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -13841,6 +14722,35 @@ "startCursor": "Y3Vyc29yOnYyOpHOPniAWQ==", "hasPreviousPage": true } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: jit" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "release notes: onnx" + } + }, + { + "node": { + "name": "topic: bug fixes" + } + } + ] } } } @@ -14048,7 +14958,7 @@ } } }, - "query_sha=c926aeea1daf714aadc93dc2eeecdc836af2d998be52090bfad6da33284a35fe name=pytorch number=74649 owner=pytorch": { + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=74649 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -14673,6 +15583,15 @@ "startCursor": "Y3Vyc29yOnYyOpHOQDAOUg==", "hasPreviousPage": false } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + } + ] } } } @@ -14753,6 +15672,9 @@ { "login": "shoumikhin" }, + { + "login": "huydhn" + }, { "login": "teytaud" }, @@ -14792,6 +15714,9 @@ { "login": "yzhao30" }, + { + "login": "rmaz" + }, { "login": "bearzx" }, @@ -14873,6 +15798,9 @@ { "login": "jfix71" }, + { + "login": "atuljangra" + }, { "login": "idning" }, @@ -14888,6 +15816,9 @@ { "login": "radkris-git" }, + { + "login": "xunnanxu" + }, { "login": "javier-m" }, @@ -14971,35 +15902,35 @@ }, { "login": "smessmer" - }, - { - "login": "ananthsub" - }, - { - "login": "d1jang" - }, - { - "login": "firstprayer" - }, - { - "login": "malfet" } ], "pageInfo": { "hasNextPage": true, - "endCursor": "Y3Vyc29yOnYyOpHOACVwFA==" + "endCursor": "Y3Vyc29yOnYyOpHOACQ5JQ==" } } } } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOACVwFA== name=metamates org=pytorch": { + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOACQ5JQ== name=metamates org=pytorch": { "data": { "organization": { "team": { "members": { "nodes": [ + { + "login": "ananthsub" + }, + { + "login": "d1jang" + }, + { + "login": "firstprayer" + }, + { + "login": "malfet" + }, { "login": "fegin" }, @@ -15048,9 +15979,6 @@ { "login": "saketh-are" }, - { - "login": "jessebrizzi" - }, { "login": "msaroufim" }, @@ -15066,6 +15994,9 @@ { "login": "hlin09" }, + { + "login": "hudeven" + }, { "login": "terrychenism" }, @@ -15087,6 +16018,9 @@ { "login": "desertfire" }, + { + "login": "YosuaMichael" + }, { "login": "banitag1" }, @@ -15103,7 +16037,7 @@ "login": "bilalsal" }, { - "login": "jaceyca" + "login": "DanilBaibak" }, { "login": "serhaty" @@ -15123,9 +16057,6 @@ { "login": "superzgc" }, - { - "login": "tenpercent" - }, { "login": "bertmaher" }, @@ -15138,6 +16069,12 @@ { "login": "jiayisuse" }, + { + "login": "bochko" + }, + { + "login": "jeanschmidt" + }, { "login": "bradleyhd" }, @@ -15219,9 +16156,6 @@ { "login": "gqchen" }, - { - "login": "jayleverett" - }, { "login": "george-qi" }, @@ -15284,38 +16218,38 @@ }, { "login": "fduwjj" - }, - { - "login": "frank-wei" - }, - { - "login": "esqu1" - }, - { - "login": "prabhat00155" - }, - { - "login": "Gamrix" - }, - { - "login": "QuentinDuval" } ], "pageInfo": { "hasNextPage": true, - "endCursor": "Y3Vyc29yOnYyOpHOAHEcNg==" + "endCursor": "Y3Vyc29yOnYyOpHOAGncmA==" } } } } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOAHEcNg== name=metamates org=pytorch": { + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOAGncmA== name=metamates org=pytorch": { "data": { "organization": { "team": { "members": { "nodes": [ + { + "login": "frank-wei" + }, + { + "login": "esqu1" + }, + { + "login": "prabhat00155" + }, + { + "login": "Gamrix" + }, + { + "login": "QuentinDuval" + }, { "login": "atalman" }, @@ -15361,9 +16295,6 @@ { "login": "Jack-Khuu" }, - { - "login": "alanwaketan" - }, { "login": "mehtanirav" }, @@ -15379,9 +16310,6 @@ { "login": "muntaqim" }, - { - "login": "dennysem" - }, { "login": "ymao1993" }, @@ -15427,9 +16355,15 @@ { "login": "pritamdamania87" }, + { + "login": "psavla2" + }, { "login": "rahxephon89" }, + { + "login": "migeed-z" + }, { "login": "iseeyuan" }, @@ -15496,9 +16430,6 @@ { "login": "robieta" }, - { - "login": "amirhmk" - }, { "login": "davidxili" }, @@ -15603,35 +16534,35 @@ }, { "login": "mrshenli" - }, - { - "login": "lena-kashtelyan" - }, - { - "login": "brad-mengchi" - }, - { - "login": "kimishpatel" - }, - { - "login": "aaronenyeshi" } ], "pageInfo": { "hasNextPage": true, - "endCursor": "Y3Vyc29yOnYyOpHOAQyXPg==" + "endCursor": "Y3Vyc29yOnYyOpHOAQNk0w==" } } } } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOAQyXPg== name=metamates org=pytorch": { + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOAQNk0w== name=metamates org=pytorch": { "data": { "organization": { "team": { "members": { "nodes": [ + { + "login": "lena-kashtelyan" + }, + { + "login": "brad-mengchi" + }, + { + "login": "kimishpatel" + }, + { + "login": "aaronenyeshi" + }, { "login": "shajrawi" }, @@ -15867,10 +16798,10 @@ "login": "mengwa41" }, { - "login": "hx89" + "login": "YulunW" }, { - "login": "kiukchung" + "login": "hx89" }, { "login": "hanhsienhuang" @@ -15919,35 +16850,35 @@ }, { "login": "shunting314" - }, - { - "login": "edward-io" - }, - { - "login": "sean-ngo" - }, - { - "login": "bzinodev" - }, - { - "login": "skim0514" } ], "pageInfo": { "hasNextPage": true, - "endCursor": "Y3Vyc29yOnYyOpHOA0w3oQ==" + "endCursor": "Y3Vyc29yOnYyOpHOAyJyuA==" } } } } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOA0w3oQ== name=metamates org=pytorch": { + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=Y3Vyc29yOnYyOpHOAyJyuA== name=metamates org=pytorch": { "data": { "organization": { "team": { "members": { "nodes": [ + { + "login": "edward-io" + }, + { + "login": "sean-ngo" + }, + { + "login": "bzinodev" + }, + { + "login": "skim0514" + }, { "login": "xcheng16" }, @@ -16053,6 +16984,9 @@ { "login": "yuguo68" }, + { + "login": "c-odrin" + }, { "login": "chowarfb" }, @@ -16094,77 +17028,37 @@ }, { "login": "anirbanr-fb-r2p" + }, + { + "login": "kirklandsign" + }, + { + "login": "o-hanna" + }, + { + "login": "izaitsevfb" + }, + { + "login": "weiwangmeta" } ], "pageInfo": { "hasNextPage": false, - "endCursor": "Y3Vyc29yOnYyOpHOBkbBhA==" + "endCursor": "Y3Vyc29yOnYyOpHOBoQSVA==" } } } } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=None name=pytorch-dev-infra org=pytorch": { + "query_sha=0a34acb829d8aca9dd28a8ba388dfa52f6ecdde7e903ace1caabdcfaba87de98 cursor=MTAw name=pytorch number=76118 owner=pytorch": { "data": { - "organization": { - "team": { - "members": { + "repository": { + "pullRequest": { + "files": { "nodes": [ { - "login": "kit1980" - }, - { - "login": "b0noI" - }, - { - "login": "seemethere" - }, - { - "login": "malfet" - }, - { - "login": "tenpercent" - }, - { - "login": "ZainRizvi" - }, - { - "login": "atalman" - }, - { - "login": "akshayParashar1995" - }, - { - "login": "osalpekar" - }, - { - "login": "swang392" - }, - { - "login": "janeyx99" - }, - { - "login": "clee2000" - } - ], - "pageInfo": { - "hasNextPage": false, - "endCursor": "Y3Vyc29yOnYyOpHOAqnOlw==" - } - } - } - } - } - }, - "query_sha=0a34acb829d8aca9dd28a8ba388dfa52f6ecdde7e903ace1caabdcfaba87de98 cursor=MTAw name=pytorch number=76118 owner=pytorch": { - "data": { - "repository": { - "pullRequest": { - "files": { - "nodes": [ - { - "path": "docs/source/quantization.rst" + "path": "docs/source/quantization.rst" }, { "path": "docs/source/scripts/build_quantization_configs.py" @@ -16856,100 +17750,2222 @@ "path": "torch/fx/experimental/meta_tracer.py" }, { - "path": "torch/fx/graph.py" + "path": "torch/fx/graph.py" + }, + { + "path": "torch/jit/_shape_functions.py" + }, + { + "path": "torch/nn/parallel/_replicated_tensor_ddp_interop.py" + }, + { + "path": "torch/nn/parallel/_replicated_tensor_ddp_utils.py" + }, + { + "path": "torch/nn/parallel/distributed.py" + }, + { + "path": "torch/nn/utils/_expanded_weights/__init__.py" + }, + { + "path": "torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py" + }, + { + "path": "torch/onnx/symbolic_opset11.py" + }, + { + "path": "torch/onnx/symbolic_opset12.py" + }, + { + "path": "torch/onnx/symbolic_opset9.py" + }, + { + "path": "torch/optim/adagrad.py" + }, + { + "path": "torch/optim/lr_scheduler.py" + }, + { + "path": "torch/overrides.py" + }, + { + "path": "torch/quantization/fx/pattern_utils.py" + }, + { + "path": "torch/quantization/fx/quantization_patterns.py" + }, + { + "path": "torch/quantization/fx/quantization_types.py" + }, + { + "path": "torch/return_types.py" + }, + { + "path": "torch/testing/_internal/common_device_type.py" + }, + { + "path": "torch/testing/_internal/common_distributed.py" + }, + { + "path": "torch/testing/_internal/common_fx2trt.py" + }, + { + "path": "torch/testing/_internal/common_methods_invocations.py" + }, + { + "path": "torch/testing/_internal/common_utils.py" + }, + { + "path": "torch/testing/_internal/composite_compliance.py" + }, + { + "path": "torch/testing/_internal/distributed/distributed_test.py" + }, + { + "path": "torch/testing/_internal/jit_metaprogramming_utils.py" + }, + { + "path": "torch/utils/cpp_extension.py" + }, + { + "path": "torch/utils/data/datapipes/_typing.py" + }, + { + "path": "torch/utils/model_dump/__init__.py" + } + ], + "pageInfo": { + "endCursor": "MzQ4", + "hasNextPage": false + } + } + } + } + } + }, + "query_sha=4c16925415d1fcc12ac0f5f7ce73b8e6122997d2f51c4c2757c2543e6493c60d cr_cursor=Y3Vyc29yOnYyOpHPAAAAAWuVD9M= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAXEsRtE= name=pytorch number=76118 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "commits": { + "nodes": [ + { + "commit": { + "oid": "5696e8357cf38f852ef3d680381513e26f202371", + "checkSuites": { + "nodes": [ + { + "checkRuns": { + "nodes": [ + { + "name": "win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/6099898412?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuVECw=", + "hasNextPage": false + } + } + } + ] + } + } + } + ] + } + } + } + } + }, + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=None name=pytorch-dev-infra org=pytorch": { + "data": { + "organization": { + "team": { + "members": { + "nodes": [ + { + "login": "kit1980" + }, + { + "login": "huydhn" + }, + { + "login": "b0noI" + }, + { + "login": "seemethere" + }, + { + "login": "malfet" + }, + { + "login": "DanilBaibak" + }, + { + "login": "ZainRizvi" + }, + { + "login": "jeanschmidt" + }, + { + "login": "atalman" + }, + { + "login": "mehtanirav" + }, + { + "login": "osalpekar" + }, + { + "login": "swang392" + }, + { + "login": "janeyx99" + }, + { + "login": "clee2000" + }, + { + "login": "izaitsevfb" + }, + { + "login": "weiwangmeta" + } + ], + "pageInfo": { + "hasNextPage": false, + "endCursor": "Y3Vyc29yOnYyOpHOBoQSVA==" + } + } + } + } + } + }, + "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=None name=qwertyuiop org=pytorch": { + "data": { + "organization": { + "team": null + } + } + }, + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=82169 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "ezyang" + }, + "title": "Move test_dtypes so it runs later", + "body": "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):\n* __->__ #82169\n\nThe error messages it gives are very unhelpful (because a failure\ngets translated into \"dtype was not supported\" rather than the\nactual backtrace), so I'd rather get error messages about this after\nI've tested basic functionality.\n\nSigned-off-by: Edward Z. Yang ", + "headRefName": "gh/ezyang/1279/head", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "gh/ezyang/1279/base", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "cef34da55a59da5a32494bff218ccd4978b659d3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "83ad7e73a07111ac1d85e931d14360cc22c01edd" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "28140e4008289251b695385acfb48ac7a47cd49c" + } + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + }, + "totalCount": 3 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543705427?check_suite_focus=true" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543705796?check_suite_focus=true" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543705914?check_suite_focus=true" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706071?check_suite_focus=true" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706300?check_suite_focus=true" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706581?check_suite_focus=true" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706911?check_suite_focus=true" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543707223?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGj1lc=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696649" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc8k=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696651" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc8s=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543705420?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGjz0w=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696656" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc9A=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696660" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc9Q=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696715" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdAs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706290?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706587?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543706915?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543707231?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543707459?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543707794?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543708127?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543708379?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543708606?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543709052?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543709309?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543709535?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543709809?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543709986?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543710238?check_suite_focus=true" + }, + { + "name": "linux-focal-rocm5.2-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543710467?check_suite_focus=true" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543710675?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543710925?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543711166?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7543711347?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544378552?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544378697?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544378800?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544378922?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544379063?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544379177?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544379274?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544414957?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544415089?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418146?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418325?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418649?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418760?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418892?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544418988?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544419111?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544419210?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544419367?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544420236?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544427790?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544526201?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544526466?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544526651?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544526810?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544526939?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544790873?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544790983?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791069?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791145?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791233?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcG0YME=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696836" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdIQ=" + }, + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGjyQg=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696896" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdMA=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697185" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdeE=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697205" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdfU=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697224" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdgg=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "pushedDate": "2022-07-27T15:34:17Z", + "oid": "28140e4008289251b695385acfb48ac7a47cd49c" + } + } + ] + }, + "changedFiles": 1, + "files": { + "nodes": [ + { + "path": "test/test_ops.py" + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "zou3519" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "Chillee" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNy0yNVQxNDo0NTozNS0wNzowMLkyMDIyLTA3LTI1VDE0OjQ1OjM1LTA3OjAwzj6XYmg=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "@pytorchbot merge -f FORCE", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197107402 + }, + { + "bodyText": "You need to provide a reason for using force merge, in the format @pytorchbot merge -f '[CATEGORY] Explanation'. With [CATEGORY] being one the following:\nEMERGENCY - an emergency fix to quickly address an issue\nMINOR - a minor fix such as cleaning locally unused variables, which shouldn't break anything\nPRE_TESTED - a previous CI run tested everything and you've only added minor changes like fixing lint\nOTHER - something not covered above", + "author": { + "login": "pytorch-bot" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1197107439 + }, + { + "bodyText": "@pytorchbot merge -f \"[OTHER] normal land failed twice already\"", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197108130 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197119348 + }, + { + "bodyText": "Hey @ezyang.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1197120095 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOR1poyg==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "Merged" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=4c16925415d1fcc12ac0f5f7ce73b8e6122997d2f51c4c2757c2543e6493c60d cr_cursor=Y3Vyc29yOnYyOpHPAAAAAcG0YME= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAcHRdAs= name=pytorch number=82169 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "commits": { + "nodes": [ + { + "commit": { + "oid": "28140e4008289251b695385acfb48ac7a47cd49c", + "checkSuites": { + "nodes": [ + { + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791308?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (functorch, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791418?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544791778?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544877177?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544877276?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (functorch, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7544877367?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcG1sTc=", + "hasNextPage": false + } + } + } + ] + } + } + } + ] + } + } + } + } + }, + "query_sha=4fa42dda073cf7ac75b2bbf595a8ef67b6dfff4bd248668750ff33ea913bf75f cursor=Y3Vyc29yOnYyOpHPAAAAAcHRdgg= name=pytorch number=82169 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "commits": { + "nodes": [ + { + "commit": { + "oid": "28140e4008289251b695385acfb48ac7a47cd49c", + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697240" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdhg=" + }, + { + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697255" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdic=" + } + ], + "pageInfo": { + "hasNextPage": false + } + } + } + } + ] + } + } + } + } + }, + "query_sha=0e2a29eda6405cea4c9de20fb80ae7924910e17272a7b251040182e7d8c390e0 name=pytorch number=79694 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": false, + "isCrossRepository": true, + "author": { + "login": "kshitij12345" + }, + "title": "[complex] conv_transpose1d", + "body": "Reference: https://github.com/pytorch/pytorch/issues/71108", + "headRefName": "develop/complex/conv_transpose1d", + "headRepository": { + "nameWithOwner": "kshitij12345/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "d1ea948e65ac6d31ad056287ab65d38ecc68b30d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "b4ba1db9a3a71bd8c03158dcd1b68711360633d8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "655a4220beae163bfe578f0318a130df01ec05d6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "Kshiteej K" + }, + "oid": "8181716be7a8005eb13ad5c3f2e1279ed1c60aff" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "9e5ca3663e7471786eeebebfdf84aea5d761712f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "9c110f39bcdc4e56386b6f9c4e2c082c8940ade6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "49315e79d0eee8008e2a74575c6fc0f6a9531ee4" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "728752480760226270c374a0acc08e28b9b133f3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "ffe43399d6f60ef7844523a5f465c11d9a67062f" + } + } + ], + "pageInfo": { + "endCursor": "OQ", + "hasNextPage": false + }, + "totalCount": 9 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAboNCRo=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7428002306" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbq-UgI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574264?check_suite_focus=true" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574600?check_suite_focus=true" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574693?check_suite_focus=true" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574832?check_suite_focus=true" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575043?check_suite_focus=true" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575297?check_suite_focus=true" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575617?check_suite_focus=true" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575807?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbqojb8=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7437320797" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbtMgl0=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574246?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbqoh6Y=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7437320800" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbtMgmA=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426574798?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575118?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575476?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575622?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426575875?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426576118?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426576360?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426576522?check_suite_focus=true" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426576694?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426576858?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426577069?check_suite_focus=true" + }, + { + "name": "linux-focal-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426577340?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426577507?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426577677?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426577906?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426578065?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426578285?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426578423?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426578533?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426578766?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426768328?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426768494?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426768635?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426768797?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426768904?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426769059?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426769221?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426794528?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426794681?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426794811?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426794965?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426795132?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426795278?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426795396?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426815145?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426815265?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426818878?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426857383?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426857577?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426857720?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426857893?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426858145?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426883486?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426949849?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426950005?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426950152?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426950337?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426950460?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426950568?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7426961175?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbqubxc=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7437320828" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbtMgnw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453692770?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbxGU2I=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7463496300" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbzb6mw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453692736?check_suite_focus=true" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453693139?check_suite_focus=true" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453693588?check_suite_focus=true" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453693942?check_suite_focus=true" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694270?check_suite_focus=true" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694519?check_suite_focus=true" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694654?check_suite_focus=true" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694759?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbxGWyc=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7463496306" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbzb6nI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453693883?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694269?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694482?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453694773?check_suite_focus=true" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453695048?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453695376?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453695572?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453695789?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453696094?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453696262?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453696440?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453696619?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453696913?check_suite_focus=true" + }, + { + "name": "linux-focal-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453697192?check_suite_focus=true" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453697504?check_suite_focus=true" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453697701?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453697927?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453698388?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453698629?check_suite_focus=true" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453698800?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453870481?check_suite_focus=true" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453870600?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453870806?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453870899?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453871006?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453871108?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453871214?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453871379?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453877423?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453877577?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453877679?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453877783?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453877932?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453878058?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453878178?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453882847?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453882949?check_suite_focus=true" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453888149?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453922173?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453922275?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453922371?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453922449?check_suite_focus=true" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453922527?check_suite_focus=true" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7453931393?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454011679?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454011783?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454011866?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454011976?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454012075?check_suite_focus=true" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454012177?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbxLMxE=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/ffe43399d6f60ef7844523a5f465c11d9a67062f/checks?check_suite_id=7463496361" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAbzb6qk=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "pushedDate": "2022-07-19T19:21:58Z", + "oid": "ffe43399d6f60ef7844523a5f465c11d9a67062f" + } + } + ] + }, + "changedFiles": 3, + "files": { + "nodes": [ + { + "path": "aten/src/ATen/native/Convolution.cpp" }, { - "path": "torch/jit/_shape_functions.py" + "path": "torch/testing/_internal/common_methods_invocations.py" }, { - "path": "torch/nn/parallel/_replicated_tensor_ddp_interop.py" - }, + "path": "torch/testing/_internal/common_modules.py" + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ { - "path": "torch/nn/parallel/_replicated_tensor_ddp_utils.py" - }, + "author": { + "login": "ngimel" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNy0xOVQxMDowNzo1NC0wNzowMLkyMDIyLTA3LTE5VDEwOjA3OjU0LTA3OjAwzj43QcY=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ { - "path": "torch/nn/parallel/distributed.py" + "bodyText": "@pytorchbot revert -m \"breaking internal builds\" -c \"ghfirst\"", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191328013 }, { - "path": "torch/nn/utils/_expanded_weights/__init__.py" + "bodyText": "@pytorchbot revert -m \"breaking internal builds\" -c \"ghfirst\"", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191329792 }, { - "path": "torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py" + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191330586 }, { - "path": "torch/onnx/symbolic_opset11.py" + "bodyText": "@kshitij12345 your PR has been successfully reverted.", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191330690 }, { - "path": "torch/onnx/symbolic_opset12.py" - }, + "bodyText": "@jeanschmidt which test is it failing on? I tried running the test_eager_transforms in functorch but couldn't reproduce it.", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1193667568 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHORwI5DQ==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ { - "path": "torch/onnx/symbolic_opset9.py" + "node": { + "name": "open source" + } }, { - "path": "torch/optim/adagrad.py" + "node": { + "name": "Merged" + } }, { - "path": "torch/optim/lr_scheduler.py" + "node": { + "name": "cla signed" + } }, { - "path": "torch/overrides.py" - }, + "node": { + "name": "Reverted" + } + } + ] + } + } + } + } + }, + "query_sha=62ce809793481ce6ddce6e1a19d9b0761755ff0ff75decaf8a79419eaf793110 cursor=Y3Vyc29yOnYyOpHORwI5DQ== name=pytorch number=79694 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ { - "path": "torch/quantization/fx/pattern_utils.py" + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/79694\n\ud83d\udcc4 \u00a0Preview Python docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\u2753Need help or want to give feedback on the CI? Visit our office hours\n\n\u2705 No Failures (0 Pending)\nAs of commit ffe4339 (more details on the Dr. CI page):\nExpand to see more\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1157454523 }, { - "path": "torch/quantization/fx/quantization_patterns.py" + "bodyText": "Unable to reproduce jit failure locally (will skip the test)\nCI Failure : https://github.com/pytorch/pytorch/runs/6926187074?check_suite_focus=true#step:9:20230\npytest test/test_ops_jit.py -k test_variant_consistency_jit_nn_functional_conv_transpose1d_cpu_complex64 -v\n=============================================================== test session starts ===============================================================\nplatform linux -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 -- /home/kshiteej/.conda/envs/pytorch-cuda-dev/bin/python\ncachedir: .pytest_cache\nhypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/kshiteej/Pytorch/pytorch_complex_convolution.py/.hypothesis/examples')\nrootdir: /home/kshiteej/Pytorch/pytorch_complex_convolution.py, configfile: pytest.ini\nplugins: hypothesis-6.23.2, repeat-0.9.1\ncollected 1976 items / 1975 deselected / 1 selected \n\ntest/test_ops_jit.py::TestJitCPU::test_variant_consistency_jit_nn_functional_conv_transpose1d_cpu_complex64 PASSED [100%]\n\n================================================================ warnings summary =================================================================\n../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_cuda.py:9\n /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_cuda.py:9: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives\n from distutils.version import LooseVersion\n\n../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:91\n /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:91: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.\n warnings.warn(\n\n-- Docs: https://docs.pytest.org/en/stable/warnings.html\n================================================= 1 passed, 1975 deselected, 2 warnings in 4.90s =================================================", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "kshitij12345" + }, + "databaseId": 1186949486 }, { - "path": "torch/quantization/fx/quantization_types.py" + "bodyText": "@pytorchbot merge", + "author": { + "login": "ngimel" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189347786 }, { - "path": "torch/return_types.py" + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189350009 }, { - "path": "torch/testing/_internal/common_device_type.py" + "bodyText": "Hey @kshitij12345.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1189350932 }, { - "path": "torch/testing/_internal/common_distributed.py" + "bodyText": "@pytorchbot revert -m \"broke slow test https://github.com/pytorch/pytorch/runs/7414560957?check_suite_focus=true#step:9:31516\" -c \"nosignal\"", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1189459845 }, { - "path": "torch/testing/_internal/common_fx2trt.py" + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189460926 }, { - "path": "torch/testing/_internal/common_methods_invocations.py" + "bodyText": "Will not revert as @kshitij12345 is not a MEMBER, but COLLABORATOR", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189460942 }, { - "path": "torch/testing/_internal/common_utils.py" + "bodyText": "@pytorchbot revert -m \"broke slow test https://github.com/pytorch/pytorch/runs/7414560957?check_suite_focus=true#step:9:31516\" -c \"nosignal\"", + "author": { + "login": "anjali411" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189529734 }, { - "path": "torch/testing/_internal/composite_compliance.py" + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189530756 }, { - "path": "torch/testing/_internal/distributed/distributed_test.py" + "bodyText": "@kshitij12345 your PR has been successfully reverted.", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189530831 }, { - "path": "torch/testing/_internal/jit_metaprogramming_utils.py" + "bodyText": "@pytorchbot merge -g", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1190070141 }, { - "path": "torch/utils/cpp_extension.py" + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1190071424 }, { - "path": "torch/utils/data/datapipes/_typing.py" + "bodyText": "Hey @kshitij12345.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1190258272 }, { - "path": "torch/utils/model_dump/__init__.py" + "bodyText": "commit is breaking internal builds/tests https://pastebin.com/HX4RUusH (pytorch/functorch/test:test_eager_transforms)", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191327616 } ], "pageInfo": { - "endCursor": "MzQ4", - "hasNextPage": false + "startCursor": "Y3Vyc29yOnYyOpHORP1auw==", + "hasPreviousPage": false } } } } } }, - "query_sha=4c16925415d1fcc12ac0f5f7ce73b8e6122997d2f51c4c2757c2543e6493c60d cr_cursor=Y3Vyc29yOnYyOpHPAAAAAWuVD9M= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAXEsRtE= name=pytorch number=76118 owner=pytorch": { + "query_sha=4c16925415d1fcc12ac0f5f7ce73b8e6122997d2f51c4c2757c2543e6493c60d cr_cursor=Y3Vyc29yOnYyOpHPAAAAAbqubxc= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAbtMgmA= name=pytorch number=79694 owner=pytorch": { "data": { "repository": { "pullRequest": { @@ -16957,20 +19973,25 @@ "nodes": [ { "commit": { - "oid": "5696e8357cf38f852ef3d680381513e26f202371", + "oid": "ffe43399d6f60ef7844523a5f465c11d9a67062f", "checkSuites": { "nodes": [ { "checkRuns": { "nodes": [ { - "name": "win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge)", + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", "conclusion": "SUCCESS", - "detailsUrl": "https://github.com/pytorch/pytorch/runs/6099898412?check_suite_focus=true" + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7427036779?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7427036925?check_suite_focus=true" } ], "pageInfo": { - "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuVECw=", + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbqvlv0=", "hasNextPage": false } } @@ -16985,10 +20006,49 @@ } } }, - "query_sha=a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5 cursor=None name=qwertyuiop org=pytorch": { + "query_sha=4c16925415d1fcc12ac0f5f7ce73b8e6122997d2f51c4c2757c2543e6493c60d cr_cursor=Y3Vyc29yOnYyOpHPAAAAAbxLMxE= cs_cursor=Y3Vyc29yOnYyOpHPAAAAAbzb6nI= name=pytorch number=79694 owner=pytorch": { "data": { - "organization": { - "team": null + "repository": { + "pullRequest": { + "commits": { + "nodes": [ + { + "commit": { + "oid": "ffe43399d6f60ef7844523a5f465c11d9a67062f", + "checkSuites": { + "nodes": [ + { + "checkRuns": { + "nodes": [ + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454025911?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454189584?check_suite_focus=true" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/7454189772?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAbxN6Mw=", + "hasNextPage": false + } + } + } + ] + } + } + } + ] + } + } } } } diff --git a/.github/scripts/test_print_latest_commits.py b/.github/scripts/test_fetch_latest_green_commit.py similarity index 83% rename from .github/scripts/test_print_latest_commits.py rename to .github/scripts/test_fetch_latest_green_commit.py index 3095f4ccdc8ee..2f84658e63944 100644 --- a/.github/scripts/test_print_latest_commits.py +++ b/.github/scripts/test_fetch_latest_green_commit.py @@ -1,6 +1,6 @@ from unittest import TestCase, main, mock from typing import Any, List, Dict -from print_latest_commits import isGreen, WorkflowCheck +from fetch_latest_green_commit import isGreen, WorkflowCheck workflowNames = [ "pull", @@ -37,13 +37,13 @@ def make_test_checks(self) -> List[Dict[str, Any]]: return workflow_checks class TestPrintCommits(TestCase): - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_all_successful(self, mock_get_commit_results: Any) -> None: "Test with workflows are successful" workflow_checks = mock_get_commit_results() self.assertTrue(isGreen("sha", workflow_checks)[0]) - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_necessary_successful(self, mock_get_commit_results: Any) -> None: "Test with necessary workflows are successful" workflow_checks = mock_get_commit_results() @@ -54,7 +54,7 @@ def test_necessary_successful(self, mock_get_commit_results: Any) -> None: workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[12], "failed") self.assertTrue(isGreen("sha", workflow_checks)[0]) - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_necessary_skipped(self, mock_get_commit_results: Any) -> None: "Test with necessary job (ex: pull) skipped" workflow_checks = mock_get_commit_results() @@ -62,7 +62,7 @@ def test_necessary_skipped(self, mock_get_commit_results: Any) -> None: result = isGreen("sha", workflow_checks) self.assertTrue(result[0]) - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_skippable_skipped(self, mock_get_commit_results: Any) -> None: "Test with skippable jobs (periodic and docker-release-builds skipped" workflow_checks = mock_get_commit_results() @@ -70,7 +70,7 @@ def test_skippable_skipped(self, mock_get_commit_results: Any) -> None: workflow_checks = set_workflow_job_status(workflow_checks, "docker-release-builds", "skipped") self.assertTrue(isGreen("sha", workflow_checks)) - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_necessary_failed(self, mock_get_commit_results: Any) -> None: "Test with necessary job (ex: Lint) failed" workflow_checks = mock_get_commit_results() @@ -79,7 +79,7 @@ def test_necessary_failed(self, mock_get_commit_results: Any) -> None: self.assertFalse(result[0]) self.assertEqual(result[1], "Lint checks were not successful") - @mock.patch('print_latest_commits.get_commit_results', return_value=TestChecks().make_test_checks()) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks()) def test_skippable_failed(self, mock_get_commit_results: Any) -> None: "Test with skippable job (ex: docker-release-builds) failing" workflow_checks = mock_get_commit_results() @@ -89,13 +89,13 @@ def test_skippable_failed(self, mock_get_commit_results: Any) -> None: self.assertFalse(result[0]) self.assertEqual(result[1], "docker-release-builds checks were not successful") - @mock.patch('print_latest_commits.get_commit_results', return_value={}) + @mock.patch('fetch_latest_green_commit.get_commit_results', return_value={}) def test_no_workflows(self, mock_get_commit_results: Any) -> None: "Test with missing workflows" workflow_checks = mock_get_commit_results() result = isGreen("sha", workflow_checks) self.assertFalse(result[0]) - self.assertEqual(result[1], "missing required workflows: pull, trunk, lint, linux-binary, android-tests, windows-binary") + self.assertEqual(result[1], "missing required workflows: pull, trunk, lint, linux-binary, windows-binary") if __name__ == "__main__": main() diff --git a/.github/scripts/test_gitutils.py b/.github/scripts/test_gitutils.py index 80a6e148e1eeb..78696771d993f 100644 --- a/.github/scripts/test_gitutils.py +++ b/.github/scripts/test_gitutils.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from gitutils import PeekableIterator +from gitutils import PeekableIterator, patterns_to_regex from unittest import TestCase, main class TestPeekableIterator(TestCase): @@ -22,6 +22,18 @@ def test_peek(self, input_: str = "abcdef") -> None: self.assertTrue(iter_.peek() is None) +class TestPattern(TestCase): + def test_double_asterisks(self) -> None: + allowed_patterns = [ + "aten/src/ATen/native/**LinearAlgebra*", + ] + patterns_re = patterns_to_regex(allowed_patterns) + fnames = [ + "aten/src/ATen/native/LinearAlgebra.cpp", + "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp"] + for filename in fnames: + self.assertTrue(patterns_re.match(filename)) + if __name__ == '__main__': main() diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index cc07f35c29c0e..af3faf8cd0948 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -12,9 +12,12 @@ from hashlib import sha256 from trymerge import (find_matching_merge_rule, + get_land_checkrun_conclusions, + validate_land_time_checks, gh_graphql, gh_get_team_members, read_merge_rules, + validate_revert, GitHubPR, MergeRule, MandatoryChecksMissingError, @@ -74,6 +77,7 @@ def __init__(self) -> None: self.comment_id = 0 self.on_mandatory = False self.on_green = False + self.land_checks = False self.reason = 'this is for testing' return Object() @@ -90,6 +94,7 @@ def mock_merge(pr_num: int, repo: GitRepo, comment_id: Optional[int] = None, mandatory_only: bool = False, on_green: bool = False, + land_checks: bool = False, timeout_minutes: int = 400, stale_pr_days: int = 3) -> None: pass @@ -117,16 +122,25 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule approved_by=["pytorch/metamates"], mandatory_checks_name=["Lint", "Facebook CLA Check", - "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "pull / linux-xenial-cuda11.3-py3.7-gcc7 / build", ], ), ] +class DummyGitRepo(GitRepo): + def __init__(self) -> None: + super().__init__(get_git_repo_dir(), get_git_remote_name()) + + def commits_resolving_gh_pr(self, pr_num: int) -> List[str]: + return ["FakeCommitSha"] + + def commit_message(self, ref: str) -> str: + return "super awsome commit message" class TestGitHubPR(TestCase): def test_merge_rules_valid(self) -> None: "Test that merge_rules.json can be parsed" - repo = GitRepo(get_git_repo_dir(), get_git_repo_dir()) + repo = DummyGitRepo() self.assertGreater(len(read_merge_rules(repo, "pytorch", "pytorch")), 1) @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) @@ -134,7 +148,7 @@ def test_merge_rules_valid(self) -> None: def test_match_rules(self, mocked_gql: Any, mocked_rmr: Any) -> None: "Tests that PR passes merge rules" pr = GitHubPR("pytorch", "pytorch", 77700) - repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) + repo = DummyGitRepo() self.assertTrue(find_matching_merge_rule(pr, repo) is not None) @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) @@ -142,7 +156,7 @@ def test_match_rules(self, mocked_gql: Any, mocked_rmr: Any) -> None: def test_lint_fails(self, mocked_gql: Any, mocked_rmr: Any) -> None: "Tests that PR fails mandatory lint check" pr = GitHubPR("pytorch", "pytorch", 74649) - repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) + repo = DummyGitRepo() self.assertRaises(RuntimeError, lambda: find_matching_merge_rule(pr, repo)) @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) @@ -189,7 +203,7 @@ def test_internal_changes(self, mocked_gql: Any) -> None: def test_checksuites_pagination(self, mocked_gql: Any) -> None: "Tests that PR with lots of checksuits can be fetched" pr = GitHubPR("pytorch", "pytorch", 73811) - self.assertGreater(len(pr.get_checkrun_conclusions()), 0) + self.assertEqual(len(pr.get_checkrun_conclusions()), 104) @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) def test_comments_pagination(self, mocked_gql: Any) -> None: @@ -232,7 +246,7 @@ def test_pending_status_check(self, mocked_gql: Any, mocked_read_merge_rules: An """ Tests that PR with nonexistent/pending status checks fails with the right reason. """ pr = GitHubPR("pytorch", "pytorch", 76118) - repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) + repo = DummyGitRepo() self.assertRaisesRegex(MandatoryChecksMissingError, ".*are pending/not yet run.*", lambda: find_matching_merge_rule(pr, repo)) @@ -253,7 +267,33 @@ def test_get_checkruns_many_runs(self, mocked_gql: Any) -> None: """ pr = GitHubPR("pytorch", "pytorch", 77700) conclusions = pr.get_checkrun_conclusions() - self.assertTrue("linux-docs / build-docs (cpp)" in conclusions.keys()) + self.assertEqual(len(conclusions), 83) + self.assertTrue("pull / linux-docs / build-docs (cpp)" in conclusions.keys()) + + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + def test_cancelled_gets_ignored(self, mocked_gql: Any) -> None: + """ Tests that cancelled workflow does not override existing successfull status + """ + pr = GitHubPR("pytorch", "pytorch", 82169) + conclusions = pr.get_checkrun_conclusions() + self.assertTrue("Lint" in conclusions.keys()) + self.assertEqual(conclusions["Lint"][0], "SUCCESS") + + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + def test_get_many_land_checks(self, mocked_gql: Any) -> None: + """ Tests that all checkruns can be fetched for a commit + """ + conclusions = get_land_checkrun_conclusions('pytorch', 'pytorch', '6882717f73deffb692219ccd1fd6db258d8ed684') + self.assertEqual(len(conclusions), 101) + self.assertTrue("pull / linux-docs / build-docs (cpp)" in conclusions.keys()) + + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + def test_failed_land_checks(self, mocked_gql: Any) -> None: + """ Tests that PR with Land Checks fail with a RunTime error + """ + self.assertRaisesRegex(RuntimeError, + ".*Failed to merge; some land checks failed.*", + lambda: validate_land_time_checks('pytorch', 'pytorch', '6882717f73deffb692219ccd1fd6db258d8ed684')) @mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info()) @mock.patch('trymerge.parse_args', return_value=mock_parse_args(True, False)) @@ -273,6 +313,7 @@ def test_main_force(self, mock_merge: Any, mock_parse_args: Any, mock_gh_get_inf force=True, comment_id=mock.ANY, on_green=False, + land_checks=False, mandatory_only=False) @mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info()) @@ -286,7 +327,15 @@ def test_main_merge(self, mock_merge: Any, mock_parse_args: Any, mock_gh_get_inf force=False, comment_id=mock.ANY, on_green=False, + land_checks=False, mandatory_only=False) + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + def test_revert_rules(self, mock_gql: Any) -> None: + """ Tests that reverts from collaborators are allowed """ + pr = GitHubPR("pytorch", "pytorch", 79694) + repo = DummyGitRepo() + self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845)) + if __name__ == "__main__": main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 44ebe50dea295..9e23869cb3804 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from urllib.request import urlopen, Request from urllib.error import HTTPError -from typing import cast, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Iterable, Pattern, cast, Any, Callable, Dict, List, Optional, Tuple, Union from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo from functools import lru_cache from warnings import warn @@ -113,7 +113,7 @@ mergeCommit { oid } - commits_with_authors:commits(first: 100) { + commits_with_authors: commits(first: 100) { ...CommitAuthors totalCount } @@ -139,7 +139,7 @@ } } reviews(last: 100) { - ...PRReviews + ...PRReviews } comments(last: 5) { nodes { @@ -158,6 +158,13 @@ hasPreviousPage } } + labels(first: 100) { + edges { + node { + name + } + } + } } } } @@ -200,6 +207,62 @@ } """ +GH_GET_COMMIT_CHECKSUITES = GH_CHECKSUITES_FRAGMENT + """ +query ($owner: String!, $name: String!, $commit: String) { + repository(name: $name, owner: $owner) { + object(expression: $commit) { + ... on Commit { + checkSuites { + ...PRCheckSuites + } + } + } + } +} +""" + +GH_GET_COMMIT_NEXT_CHECKSUITES = GH_CHECKSUITES_FRAGMENT + """ +query ($owner: String!, $name: String!, $commit: String, $cursor: String!) { + repository(name: $name, owner: $owner) { + object(expression: $commit) { + ... on Commit { + oid + checkSuites(first: 10, after: $cursor) { + ...PRCheckSuites + } + } + } + } +} +""" + +GH_GET_COMMIT_NEXT_CHECK_RUNS = """ +query ($owner: String!, $name: String!, $cs_cursor: String, $cr_cursor: String!, $commit: String) { + repository(name: $name, owner: $owner) { + object(expression: $commit) { + ... on Commit { + oid + checkSuites(first: 1, after: $cs_cursor) { + nodes { + checkRuns(first: 100, after: $cr_cursor) { + nodes { + name + conclusion + detailsUrl + } + pageInfo { + endCursor + hasNextPage + } + } + } + } + } + } + } +} +""" + GH_GET_PR_NEXT_CHECK_RUNS = """ query ($owner: String!, $name: String!, $number: Int!, $cs_cursor: String, $cr_cursor: String!) { repository(name: $name, owner: $owner) { @@ -231,7 +294,6 @@ } """ - GH_GET_PR_PREV_COMMENTS = """ query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) { repository(name: $name, owner: $owner) { @@ -302,16 +364,16 @@ """ RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$") -RE_GHSTACK_SOURCE_ID = re.compile(r'^ghstack-source-id: (.+)\n?', re.MULTILINE) +RE_GHSTACK_DESC = re.compile(r'Stack.*:\r?\n(\* [^\r\n]+\r?\n)+', re.MULTILINE) RE_PULL_REQUEST_RESOLVED = re.compile( r'Pull Request resolved: ' r'https://github.com/(?P[^/]+)/(?P[^/]+)/pull/(?P[0-9]+)', re.MULTILINE ) -RE_REVERT_CMD = re.compile(r"@pytorch(merge|)bot\s+revert\s+this") -RE_REVERT_CMD_CLI = re.compile(r"@pytorch(merge|)bot\s+revert\s+(-m.*-c.*|-c.*-m.*)") RE_DIFF_REV = re.compile(r'^Differential Revision:.+?(D[0-9]+)', re.MULTILINE) - +CIFLOW_LABEL = re.compile(r"^ciflow/.+") +CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk") +BOT_COMMANDS_WIKI = 'https://github.com/pytorch/pytorch/wiki/Bot-commands' def _fetch_url(url: str, *, headers: Optional[Dict[str, str]] = None, @@ -332,7 +394,6 @@ def _fetch_url(url: str, *, print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}") raise - def fetch_json(url: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: @@ -341,13 +402,27 @@ def fetch_json(url: str, url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()) return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load)) +def fetch_json_dict(url: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None) -> Dict[str, Any] : + headers = {'Accept': 'application/vnd.github.v3+json'} + if params is not None and len(params) > 0: + url += '?' + '&'.join(f"{name}={urllib.parse.quote(str(val))}" for name, val in params.items()) + return cast(Dict[str, Any], _fetch_url(url, headers=headers, data=data, reader=json.load)) -def gh_post_comment(org: str, project: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]: +def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]: if dry_run: print(comment) return [] - return fetch_json(f'https://api.github.com/repos/{org}/{project}/issues/{pr_num}/comments', - data={"body": comment}) + return fetch_json(url, data={"body": comment}) + + +def gh_post_pr_comment(org: str, project: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]: + return _gh_post_comment(f'https://api.github.com/repos/{org}/{project}/issues/{pr_num}/comments', comment, dry_run) + + +def gh_post_commit_comment(org: str, project: str, sha: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]: + return _gh_post_comment(f'https://api.github.com/repos/{org}/{project}/commits/{sha}/comments', comment, dry_run) def gh_add_labels(org: str, project: str, pr_num: int, labels: Union[str, List[str]]) -> None: @@ -366,6 +441,9 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any: rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no) return rc["data"]["repository"]["pullRequest"] +def gh_get_land_check_info(org: str, proj: str, commit: str) -> Any: + rc = gh_graphql(GH_GET_COMMIT_CHECKSUITES, name=proj, owner=org, commit=commit) + return rc["data"]["repository"]["object"] @lru_cache(maxsize=None) def gh_get_team_members(org: str, name: str) -> List[str]: @@ -381,6 +459,55 @@ def gh_get_team_members(org: str, name: str) -> List[str]: rc += [member["login"] for member in team_members["nodes"]] return rc +def get_check_run_name_prefix(workflow_run: Any) -> str: + if workflow_run is None: + return "" + else: + return f'{workflow_run["workflow"]["name"]} / ' + + +def add_workflow_conclusions( + checksuites: Any, + get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any], + get_next_checksuites: Callable[[Any], Any] +) -> Dict[str, Tuple[str, str]]: + conclusions = {} + + def add_conclusions(edges: Any) -> None: + for edge_idx, edge in enumerate(edges): + node = edge["node"] + workflow_run = node["workflowRun"] + checkruns = node["checkRuns"] + if workflow_run is not None: + workflow_name = workflow_run["workflow"]["name"] + workflow_conclusion = node["conclusion"] + # Do not override existing status with cancelled + if workflow_conclusion == "CANCELLED" and workflow_name in conclusions: + continue + conclusions[workflow_name] = (workflow_conclusion, node["url"]) + has_failing_check = False + while checkruns is not None: + for checkrun_node in checkruns["nodes"]: + if checkrun_node["conclusion"] == 'FAILURE': + has_failing_check = True + conclusions[f'{get_check_run_name_prefix(workflow_run)}{checkrun_node["name"]}'] = ( + checkrun_node["conclusion"], checkrun_node["detailsUrl"] + ) + if bool(checkruns["pageInfo"]["hasNextPage"]): + checkruns = get_next_checkruns_page(edges, edge_idx, checkruns) + else: + checkruns = None + # Github doesn't set conclusion to failure if a job is still pending + if workflow_run is not None and has_failing_check: + conclusions[workflow_run["workflow"]["name"]] = ("FAILURE", node["url"]) + + add_conclusions(checksuites["edges"]) + while bool(checksuites["pageInfo"]["hasNextPage"]): + checksuites = get_next_checksuites(checksuites) + add_conclusions(checksuites["edges"]) + + return conclusions + def parse_args() -> Any: from argparse import ArgumentParser @@ -388,6 +515,7 @@ def parse_args() -> Any: parser.add_argument("--dry-run", action="store_true") parser.add_argument("--on-green", action="store_true") parser.add_argument("--on-mandatory", action="store_true") + parser.add_argument("--land-checks", action="store_true") parser.add_argument("--revert", action="store_true") parser.add_argument("--force", action="store_true") parser.add_argument("--comment-id", type=int) @@ -421,6 +549,7 @@ def __init__(self, org: str, project: str, pr_num: int) -> None: self.pr_num = pr_num self.info = gh_get_pr_info(org, project, pr_num) self.changed_files: Optional[List[str]] = None + self.labels: Optional[List[str]] = None self.conclusions: Optional[Dict[str, Tuple[str, str]]] = None self.comments: Optional[List[GitHubComment]] = None self._authors: Optional[List[Tuple[str, str]]] = None @@ -542,38 +671,31 @@ def get_committer_login(self, num: int = 0) -> str: def get_committer_author(self, num: int = 0) -> str: return self._fetch_authors()[num][1] + def get_labels(self) -> List[str]: + if self.labels is not None: + return self.labels + labels = [node['node']['name'] for node in self.info["labels"]["edges"]] if "labels" in self.info else [] + self.labels = labels + return self.labels + def get_checkrun_conclusions(self) -> Dict[str, Tuple[str, str]]: """ Returns dict of checkrun -> [conclusion, url] """ if self.conclusions is not None: return self.conclusions orig_last_commit = self.info["commits"]["nodes"][-1]["commit"] - checksuites = orig_last_commit["checkSuites"] - conclusions = {} - - def add_conclusions(edges: List[Dict[str, Dict[str, Any]]]) -> None: - for edge_idx, edge in enumerate(edges): - node = edge["node"] - workflow_run = node["workflowRun"] - checkruns = node["checkRuns"] - if workflow_run is not None: - conclusions[workflow_run["workflow"]["name"]] = (node["conclusion"], node["url"]) - while checkruns is not None: - for checkrun_node in checkruns["nodes"]: - conclusions[checkrun_node["name"]] = (checkrun_node["conclusion"], checkrun_node["detailsUrl"]) - if bool(checkruns["pageInfo"]["hasNextPage"]): - rc = gh_graphql(GH_GET_PR_NEXT_CHECK_RUNS, - name=self.project, - owner=self.org, - number=self.pr_num, - cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None, - cr_cursor=checkruns["pageInfo"]["endCursor"]) - last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][-1]["commit"] - checkruns = last_commit["checkSuites"]["nodes"][-1]["checkRuns"] - else: - checkruns = None - add_conclusions(checksuites["edges"]) - while bool(checksuites["pageInfo"]["hasNextPage"]): + def get_pr_next_check_runs(edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any) -> Any: + rc = gh_graphql(GH_GET_PR_NEXT_CHECK_RUNS, + name=self.project, + owner=self.org, + number=self.pr_num, + cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None, + cr_cursor=checkruns["pageInfo"]["endCursor"]) + last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][-1]["commit"] + checkruns = last_commit["checkSuites"]["nodes"][-1]["checkRuns"] + return checkruns + + def get_pr_next_checksuites(checksuites: Any) -> Any: rc = gh_graphql(GH_GET_PR_NEXT_CHECKSUITES, name=self.project, owner=self.org, @@ -583,10 +705,12 @@ def add_conclusions(edges: List[Dict[str, Dict[str, Any]]]) -> None: last_commit = info["commits"]["nodes"][-1]["commit"] if last_commit["oid"] != orig_last_commit["oid"]: raise RuntimeError("Last commit changed on PR") - checksuites = last_commit["checkSuites"] - add_conclusions(checksuites["edges"]) - self.conclusions = conclusions - return conclusions + return last_commit["checkSuites"] + + checksuites = orig_last_commit["checkSuites"] + + self.conclusions = add_workflow_conclusions(checksuites, get_pr_next_check_runs, get_pr_next_checksuites) + return self.conclusions def get_authors(self) -> Dict[str, str]: rc = {} @@ -683,7 +807,6 @@ def has_internal_changes(self) -> bool: def merge_ghstack_into(self, repo: GitRepo, force: bool, comment_id: Optional[int] = None) -> None: assert self.is_ghstack_pr() - approved_by = self.get_approved_by() # For ghstack, cherry-pick commits based from origin orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', self.head_ref())}" rev_list = repo.revlist(f"{self.default_branch()}..{orig_ref}") @@ -695,33 +818,53 @@ def merge_ghstack_into(self, repo: GitRepo, force: bool, comment_id: Optional[in if self.org != m.group('owner') or self.project != m.group('repo'): raise RuntimeError(f"PR {m.group('number')} resolved to wrong owner/repo pair") pr_num = int(m.group('number')) + commit_msg = self.gen_commit_message(filter_ghstack=True) if pr_num != self.pr_num: pr = GitHubPR(self.org, self.project, pr_num) if pr.is_closed(): print(f"Skipping {idx+1} of {len(rev_list)} PR (#{pr_num}) as its already been merged") continue - approved_by = pr.get_approved_by() + commit_msg = pr.gen_commit_message(filter_ghstack=True) # Raises exception if matching rule is not found find_matching_merge_rule(pr, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id)) - # Adding the url here makes it clickable within the Github UI - approved_by_urls = ', '.join(prefix_with_github_url(login) for login in approved_by) repo.cherry_pick(rev) - msg = re.sub(RE_GHSTACK_SOURCE_ID, "", msg) - msg += f"\nApproved by: {approved_by_urls}\n" - repo.amend_commit_message(msg) - - def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = False, comment_id: Optional[int] = None) -> None: + repo.amend_commit_message(commit_msg) + + def gen_commit_message(self, filter_ghstack: bool = False) -> str: + """ Fetches title and body from PR description + adds reviewed by, pull request resolved and optionally + filters out ghstack info """ + # Adding the url here makes it clickable within the Github UI + approved_by_urls = ', '.join(prefix_with_github_url(login) for login in self.get_approved_by()) + msg = self.get_title() + f" (#{self.pr_num})\n\n" + msg += self.get_body() if not filter_ghstack else re.sub(RE_GHSTACK_DESC, "", self.get_body()) + msg += f"\nPull Request resolved: {self.get_pr_url()}\n" + msg += f"Approved by: {approved_by_urls}\n" + return msg + + def merge_into(self, repo: GitRepo, *, + force: bool = False, + dry_run: bool = False, + comment_id: Optional[int] = None) -> None: # Raises exception if matching rule is not found find_matching_merge_rule(self, repo, force=force, skip_internal_checks=can_skip_internal_checks(self, comment_id)) - if repo.current_branch() != self.default_branch(): - repo.checkout(self.default_branch()) + self.merge_changes(repo, force, comment_id) + + repo.push(self.default_branch(), dry_run) + if not dry_run: + gh_add_labels(self.org, self.project, self.pr_num, ["merged"]) + + def merge_changes(self, + repo: GitRepo, + force: bool = False, + comment_id: Optional[int] = None, + branch: Optional[str] = None) -> None: + branch_to_merge_into = self.default_branch() if branch is None else branch + if repo.current_branch() != branch_to_merge_into: + repo.checkout(branch_to_merge_into) if not self.is_ghstack_pr(): - # Adding the url here makes it clickable within the Github UI - approved_by_urls = ', '.join(prefix_with_github_url(login) for login in self.get_approved_by()) - msg = self.get_title() + f" (#{self.pr_num})\n\n" + self.get_body() - msg += f"\nPull Request resolved: {self.get_pr_url()}\n" - msg += f"Approved by: {approved_by_urls}\n" + msg = self.gen_commit_message() pr_branch_name = f"__pull-request-{self.pr_num}__init__" repo.fetch(f"pull/{self.pr_num}/head", pr_branch_name) repo._run_git("merge", "--squash", pr_branch_name) @@ -729,16 +872,34 @@ def merge_into(self, repo: GitRepo, *, force: bool = False, dry_run: bool = Fals else: self.merge_ghstack_into(repo, force, comment_id=comment_id) - repo.push(self.default_branch(), dry_run) - gh_post_comment(self.org, self.project, self.pr_num, - f"@{self.get_pr_creator_login()} your PR has been successfully merged.", dry_run) - if not dry_run: - gh_add_labels(self.org, self.project, self.pr_num, ["merged"]) + def create_land_time_check_branch(self, + repo: GitRepo, + branch: str, + force: bool = False, + comment_id: Optional[int] = None,) -> str: + self.merge_changes(repo, branch=branch, force=force, comment_id=comment_id) + land_check_branch = f'landchecks/{self.pr_num}' + try: + repo._run_git('branch', "-D", land_check_branch) + except Exception: + pass + repo._run_git('checkout', "-b", land_check_branch) + repo._run_git('push', '-u', 'origin', land_check_branch, '--force') + commit = repo.get_commit('HEAD').commit_hash + gh_post_pr_comment(self.org, self.project, self.pr_num, + '@pytorchbot successfully started a merge and created land time checks.' + + f' See merge status [here]({os.getenv("GH_RUN_URL")}) ' + + f'and [land check]({BOT_COMMANDS_WIKI}) ' + f'progress [here](https://hud.pytorch.org/{self.org}/{self.project}/commit/{commit}).') + return commit class MandatoryChecksMissingError(Exception): pass +class PostCommentError(Exception): + pass + @dataclass class MergeRule: @@ -822,21 +983,10 @@ def find_matching_merge_rule(pr: GitHubPR, reject_reason = (f"Matched rule {rule_name}, but PR #{pr.pr_num} was not reviewed yet by any of: " + f"{', '.join(list(rule_approvers_set)[:5])}{', ...' if len(rule_approvers_set) > 5 else ''}") continue - if rule.mandatory_checks_name is not None: - pending_checks: List[Tuple[str, Optional[str]]] = [] - failed_checks: List[Tuple[str, Optional[str]]] = [] - checks = pr.get_checkrun_conclusions() - # HACK: We don't want to skip CLA check, even when forced - for checkname in filter(lambda x: force is False or "CLA Check" in x, rule.mandatory_checks_name): - if checkname not in checks: - pending_checks.append((checkname, None)) - elif checks[checkname][0] is None: - pending_checks.append((checkname, checks[checkname][1])) - elif checks[checkname][0] != 'SUCCESS': - failed_checks.append((checkname, checks[checkname][1])) - - def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: - return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) + mandatory_checks = rule.mandatory_checks_name if rule.mandatory_checks_name is not None else [] + checks = pr.get_checkrun_conclusions() + required_checks = filter(lambda x: force is False or "CLA Check" in x, mandatory_checks) + [pending_checks, failed_checks] = categorize_checks(checks, required_checks) if len(failed_checks) > 0: if reject_reason_score < 30000: @@ -858,6 +1008,35 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: raise RuntimeError(reject_reason) +def get_land_checkrun_conclusions(org: str, project: str, commit: str) -> Dict[str, Tuple[str, str]]: + + def get_commit_next_check_runs(edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any) -> Any: + rc = gh_graphql(GH_GET_COMMIT_NEXT_CHECK_RUNS, + name=project, + owner=org, + cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None, + cr_cursor=checkruns["pageInfo"]["endCursor"], + commit=commit) + return rc["data"]["repository"]["object"]["checkSuites"]["nodes"][-1]["checkRuns"] + + def get_commit_next_checksuites(checksuites: Any) -> Any: + rc = gh_graphql(GH_GET_COMMIT_NEXT_CHECKSUITES, + name=project, + owner=org, + commit=commit, + cursor=checksuites["edges"][-1]["cursor"]) + info = rc["data"]["repository"]["object"] + return info["checkSuites"] + + land_check_info = gh_get_land_check_info(org, project, commit) + checksuites = land_check_info["checkSuites"] + + return add_workflow_conclusions(checksuites, get_commit_next_check_runs, get_commit_next_checksuites) + + +def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: + return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) + def pr_get_checks_with_lambda(pr: GitHubPR, status_check: Callable[[Optional[str]], bool]) -> List[Tuple[str, str]]: checks = pr.get_checkrun_conclusions() return [(name, status[1]) for name, status in checks.items() if status_check(status[0])] @@ -868,28 +1047,25 @@ def pr_get_pending_checks(pr: GitHubPR) -> List[Tuple[str, str]]: def pr_get_failed_checks(pr: GitHubPR) -> List[Tuple[str, str]]: - return pr_get_checks_with_lambda(pr, lambda x: x == "FAILURE") + return pr_get_checks_with_lambda(pr, lambda x: x in ["FAILURE", "STARTUP_FAILURE"]) -def try_revert(repo: GitRepo, pr: GitHubPR, *, - dry_run: bool = False, - comment_id: Optional[int] = None, - reason: Optional[str] = None) -> None: - def post_comment(msg: str) -> None: - gh_post_comment(pr.org, pr.project, pr.pr_num, msg, dry_run=dry_run) - if not pr.is_closed(): - return post_comment(f"Can't revert open PR #{pr.pr_num}") +def validate_revert(repo: GitRepo, pr: GitHubPR, *, + comment_id: Optional[int] = None) -> Tuple[str, str]: comment = pr.get_last_comment() if comment_id is None else pr.get_comment_by_id(comment_id) - if not RE_REVERT_CMD.match(comment.body_text) and not RE_REVERT_CMD_CLI.match(comment.body_text): - raise RuntimeError(f"Comment {comment.body_text} does not seem to be a valid revert command") if comment.editor_login is not None: - return post_comment("Don't want to revert based on edited command") + raise PostCommentError("Don't want to revert based on edited command") author_association = comment.author_association author_login = comment.author_login + allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"] # For some reason, one can not be a member of private repo, only CONTRIBUTOR - expected_association = "CONTRIBUTOR" if pr.is_base_repo_private() else "MEMBER" - if author_association != expected_association and author_association != "OWNER": - return post_comment(f"Will not revert as @{author_login} is not a {expected_association}, but {author_association}") + if pr.is_base_repo_private(): + allowed_reverters.append("CONTRIBUTOR") + if author_association not in allowed_reverters: + raise PostCommentError(( + f"Will not revert as @{author_login} is not one of " + f"[{', '.join(allowed_reverters)}], but instead is {author_association}." + )) skip_internal_checks = can_skip_internal_checks(pr, comment_id) # Raises exception if matching rule is not found, but ignores all status checks @@ -898,29 +1074,43 @@ def post_comment(msg: str) -> None: if commit_sha is None: commits = repo.commits_resolving_gh_pr(pr.pr_num) if len(commits) == 0: - raise RuntimeError("Can't find any commits resolving PR") + raise PostCommentError("Can't find any commits resolving PR") commit_sha = commits[0] msg = repo.commit_message(commit_sha) rc = RE_DIFF_REV.search(msg) if rc is not None and not can_skip_internal_checks: - raise RuntimeError(f"Can't revert PR that was landed via phabricator as {rc.group(1)}") + raise PostCommentError(f"Can't revert PR that was landed via phabricator as {rc.group(1)}") + return (author_login, commit_sha) + + +def try_revert(repo: GitRepo, pr: GitHubPR, *, + dry_run: bool = False, + comment_id: Optional[int] = None, + reason: Optional[str] = None) -> None: + def post_comment(msg: str) -> None: + gh_post_pr_comment(pr.org, pr.project, pr.pr_num, msg, dry_run=dry_run) + try: + author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id) + except PostCommentError as e: + return post_comment(str(e)) + revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}" + revert_msg += f" due to {reason}\n" if reason is not None else "\n" repo.checkout(pr.default_branch()) repo.revert(commit_sha) msg = repo.commit_message("HEAD") msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg) - msg += f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}" - msg += f" due to {reason}\n" if reason is not None else "\n" + msg += revert_msg repo.amend_commit_message(msg) repo.push(pr.default_branch(), dry_run) post_comment(f"@{pr.get_pr_creator_login()} your PR has been successfully reverted.") if not dry_run: gh_add_labels(pr.org, pr.project, pr.pr_num, ["reverted"]) + gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg) def prefix_with_github_url(suffix_str: str) -> str: return f"https://github.com/{suffix_str}" - def check_for_sev(org: str, project: str, force: bool) -> None: if force: return @@ -941,6 +1131,35 @@ def check_for_sev(org: str, project: str, force: bool) -> None: ) return +def validate_land_time_checks(org: str, project: str, commit: str) -> None: + checks = get_land_checkrun_conclusions(org, project, commit) + if(len(checks) == 0): + raise MandatoryChecksMissingError("Refusing to merge as land check(s) are not yet run") + + [pending_checks, failed_checks] = categorize_checks(checks, checks) + + if len(failed_checks) > 0: + raise RuntimeError(f"Failed to merge; some land checks failed: {checks_to_str(failed_checks)}") + if len(pending_checks) > 0: + raise MandatoryChecksMissingError(f"Refusing to merge as land check(s) {checks_to_str(pending_checks)} are not yet run") + +def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool: + return len(list(filter(pattern.match, labels))) > 0 + +def categorize_checks(check_runs: Dict[str, Tuple[str, str]], + required_checks: Iterable[str]) -> Tuple[List[Tuple[str, Optional[str]]], List[Tuple[str, Optional[str]]]]: + pending_checks: List[Tuple[str, Optional[str]]] = [] + failed_checks: List[Tuple[str, Optional[str]]] = [] + for checkname in required_checks: + if checkname not in check_runs: + pending_checks.append((checkname, None)) + elif check_runs[checkname][0] is None: + pending_checks.append((checkname, check_runs[checkname][1])) + elif (check_runs[checkname][0].upper() != 'SUCCESS' + and check_runs[checkname][0].upper() != 'SKIPPED' + and check_runs[checkname][0].upper() != 'NEUTRAL'): + failed_checks.append((checkname, check_runs[checkname][1])) + return (pending_checks, failed_checks) def merge(pr_num: int, repo: GitRepo, dry_run: bool = False, @@ -948,6 +1167,7 @@ def merge(pr_num: int, repo: GitRepo, comment_id: Optional[int] = None, mandatory_only: bool = False, on_green: bool = False, + land_checks: bool = False, timeout_minutes: int = 400, stale_pr_days: int = 3) -> None: repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) @@ -961,6 +1181,9 @@ def merge(pr_num: int, repo: GitRepo, if (datetime.utcnow() - pr.last_pushed_at()).days > stale_pr_days: raise RuntimeError("This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again.") + if land_checks: + land_check_commit = pr.create_land_time_check_branch(repo, 'viable/strict', force=force, comment_id=comment_id) + start_time = time.time() last_exception = '' elapsed_time = 0.0 @@ -976,12 +1199,23 @@ def merge(pr_num: int, repo: GitRepo, find_matching_merge_rule(pr, repo) pending = pr_get_pending_checks(pr) failing = pr_get_failed_checks(pr) + + # HACK until GitHub will be better about surfacing those + startup_failures = pr_get_checks_with_lambda(pr, lambda x: x == "STARTUP_FAILURE") + if len(startup_failures) > 0: + raise RuntimeError(f"{len(failing)} STARTUP failures reported, please check workflows syntax! " + + ' ,'.join(f"[{x[0]}]({x[1]})" for x in startup_failures[:5])) + # END of HACK + if (not mandatory_only and on_green) and len(failing) > 0: raise RuntimeError(f"{len(failing)} additional jobs have failed, first few of them are: " + ' ,'.join(f"[{x[0]}]({x[1]})" for x in failing[:5])) if (not mandatory_only and on_green) and len(pending) > 0: raise MandatoryChecksMissingError(f"Still waiting for {len(pending)} additional jobs to finish, " + f"first few of them are: {' ,'.join(x[0] for x in pending[:5])}") + if land_checks: + validate_land_time_checks(org, project, land_check_commit) + return pr.merge_into(repo, dry_run=dry_run, force=force, comment_id=comment_id) except MandatoryChecksMissingError as ex: last_exception = str(ex) @@ -999,19 +1233,25 @@ def main() -> None: repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) org, project = repo.gh_owner_and_name() pr = GitHubPR(org, project, args.pr_num) + land_checks = args.land_checks and not has_label(pr.get_labels(), CIFLOW_TRUNK_LABEL) def handle_exception(e: Exception, msg: str = "Merge failed") -> None: msg += f" due to {e}" run_url = os.getenv("GH_RUN_URL") if run_url is not None: msg += f"\nRaised by {run_url}" - gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run) + if land_checks: + msg += (" If you believe this is an error, you can use the old behavior with `@pytorchbot merge -g`" + + ' (optionally with the "ciflow/trunk" to get land signals)' + + ' or use `@pytorchbot merge -f "some reason here"`.' + + f" For more information, see the [bot wiki]({BOT_COMMANDS_WIKI}).") + gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run) import traceback traceback.print_exc() - - msg = f"@pytorchbot successfully started a {'revert' if args.revert else 'merge'} job." - msg += f" Check the current status [here]({os.getenv('GH_RUN_URL')})" - gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run) + if not land_checks: + msg = f"@pytorchbot successfully started a {'revert' if args.revert else 'merge'} job." + msg += f" Check the current status [here]({os.getenv('GH_RUN_URL')})" + gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run) if args.revert: try: @@ -1021,24 +1261,25 @@ def handle_exception(e: Exception, msg: str = "Merge failed") -> None: return if pr.is_closed(): - gh_post_comment(org, project, args.pr_num, f"Can't merge closed PR #{args.pr_num}", dry_run=args.dry_run) + gh_post_pr_comment(org, project, args.pr_num, f"Can't merge closed PR #{args.pr_num}", dry_run=args.dry_run) return if pr.is_cross_repo() and pr.is_ghstack_pr(): - gh_post_comment(org, project, args.pr_num, "Cross-repo ghstack merges are not supported", dry_run=args.dry_run) + gh_post_pr_comment(org, project, args.pr_num, "Cross-repo ghstack merges are not supported", dry_run=args.dry_run) return try: + on_green = args.on_green or has_label(pr.get_labels(), CIFLOW_LABEL) merge(args.pr_num, repo, dry_run=args.dry_run, force=args.force, comment_id=args.comment_id, - on_green=args.on_green, - mandatory_only=args.on_mandatory) + on_green=on_green, + mandatory_only=args.on_mandatory, + land_checks=land_checks) except Exception as e: handle_exception(e) - if __name__ == "__main__": main() diff --git a/.github/scripts/tryrebase.py b/.github/scripts/tryrebase.py index aa5cabd8d7f88..1b69f653e525a 100755 --- a/.github/scripts/tryrebase.py +++ b/.github/scripts/tryrebase.py @@ -6,7 +6,7 @@ import re from typing import Any from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo -from trymerge import gh_post_comment, GitHubPR +from trymerge import gh_post_pr_comment as gh_post_comment, GitHubPR def parse_args() -> Any: @@ -49,6 +49,12 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: repo.fetch(orig_ref, orig_ref) repo._run_git("rebase", onto_branch, orig_ref) + # steal the identity of the committer of the commit on the orig branch + email = repo._run_git("log", orig_ref, "--pretty=format:%ae", "-1") + name = repo._run_git("log", orig_ref, "--pretty=format:%an", "-1") + repo._run_git("config", "--global", "user.name", name) + repo._run_git("config", "--global", "user.email", email) + os.environ["OAUTH_TOKEN"] = os.environ["GITHUB_TOKEN"] with open('.ghstackrc', 'w+') as f: f.write('[ghstack]\n' + diff --git a/.github/scripts/update_commit_hashes.py b/.github/scripts/update_commit_hashes.py index 617b42e77a54c..5dad5877ca4ae 100644 --- a/.github/scripts/update_commit_hashes.py +++ b/.github/scripts/update_commit_hashes.py @@ -11,18 +11,24 @@ def git_api( - url: str, params: Dict[str, str], post: bool = False, token: str = MERGEBOT_TOKEN + url: str, params: Dict[str, str], type: str = "get", token: str = MERGEBOT_TOKEN ) -> Any: headers = { "Accept": "application/vnd.github.v3+json", "Authorization": f"token {token}", } - if post: + if type == "post": return requests.post( f"https://api.github.com{url}", data=json.dumps(params), headers=headers, ).json() + elif type == "patch": + return requests.patch( + f"https://api.github.com{url}", + data=json.dumps(params), + headers=headers, + ).json() else: return requests.get( f"https://api.github.com{url}", @@ -46,7 +52,7 @@ def make_pr(repo_name: str, branch_name: str) -> Any: "body": "This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/" + f".github/workflows/_update-commit-hash.yml).\nUpdate the pinned {repo_name} hash.", } - response = git_api(f"/repos/{OWNER}/{REPO}/pulls", params, post=True) + response = git_api(f"/repos/{OWNER}/{REPO}/pulls", params, type="post") print(f"made pr {response['html_url']}") return response["number"] @@ -57,22 +63,47 @@ def approve_pr(pr_number: str) -> None: git_api( f"/repos/{OWNER}/{REPO}/pulls/{pr_number}/reviews", params, - post=True, + type="post", token=PYTORCHBOT_TOKEN, ) -def make_comment(pr_number: str) -> None: - params = {"body": "@pytorchbot merge -g"} +def make_comment(pr_number: str, msg: str) -> None: + params = {"body": msg} # comment with pytorchbot because pytorchmergebot gets ignored git_api( f"/repos/{OWNER}/{REPO}/issues/{pr_number}/comments", params, - post=True, + type="post", token=PYTORCHBOT_TOKEN, ) +def close_pr(pr_number: str) -> None: + params = {"state": "closed"} + git_api( + f"/repos/{OWNER}/{REPO}/pulls/{pr_number}", + params, + type="patch", + ) + + +def is_newer_hash(new_hash: str, old_hash: str, repo_name: str) -> bool: + def _get_date(hash: str) -> int: + # this git command prints the unix timestamp of the hash + return int( + subprocess.run( + f"git show --no-patch --no-notes --pretty=%ct {hash}".split(), + capture_output=True, + cwd=f"{repo_name}", + ) + .stdout.decode("utf-8") + .strip() + ) + + return _get_date(new_hash) > _get_date(old_hash) + + def main() -> None: args = parse_args() @@ -87,22 +118,28 @@ def main() -> None: if response["total_count"] != 0: # pr does exist pr_num = response["items"][0]["number"] + link = response["items"][0]["html_url"] response = git_api(f"/repos/{OWNER}/{REPO}/pulls/{pr_num}", {}) branch_name = response["head"]["ref"] - print(f"pr does exist, number is {pr_num}, branch name is {branch_name}") - - # update file - hash = subprocess.run( - f"git rev-parse {args.branch}".split(), - capture_output=True, - cwd=f"{args.repo_name}", - ).stdout.decode("utf-8") - with open(f".github/ci_commit_pins/{args.repo_name}.txt", "w") as f: - f.write(hash) - git_diff = subprocess.run( - f"git diff --exit-code .github/ci_commit_pins/{args.repo_name}.txt".split() + print( + f"pr does exist, number is {pr_num}, branch name is {branch_name}, link is {link}" + ) + + hash = ( + subprocess.run( + f"git rev-parse {args.branch}".split(), + capture_output=True, + cwd=f"{args.repo_name}", + ) + .stdout.decode("utf-8") + .strip() ) - if git_diff.returncode == 1: + with open(f".github/ci_commit_pins/{args.repo_name}.txt", "r+") as f: + old_hash = f.read().strip() + f.seek(0) + f.truncate() + f.write(f"{hash}\n") + if is_newer_hash(hash, old_hash, args.repo_name): # if there was an update, push to branch subprocess.run(f"git checkout -b {branch_name}".split()) subprocess.run(f"git add .github/ci_commit_pins/{args.repo_name}.txt".split()) @@ -115,9 +152,16 @@ def main() -> None: # no existing pr, so make a new one and approve it pr_num = make_pr(args.repo_name, branch_name) approve_pr(pr_num) - if pr_num is not None: # comment to merge if all checks are green - make_comment(pr_num) + make_comment(pr_num, "@pytorchbot merge -g") + else: + print( + f"tried to update from old hash: {old_hash} to new hash: {hash} but the old hash seems to be newer, not creating pr" + ) + if pr_num is not None: + make_comment(pr_num, "closing pr as the current hash seems up to date") + close_pr(pr_num) + print(f"closing PR {pr_num}") if __name__ == "__main__": diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 798247fcb49e6..f0f3e3a430f7d 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -13,7 +13,7 @@ {%- macro concurrency(build_environment) -%} concurrency: - group: !{{ build_environment }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: !{{ build_environment }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true {%- endmacro -%} @@ -91,7 +91,6 @@ on: AWS_DEFAULT_REGION: us-east-1 GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: !{{ build_environment }}-test PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} TAG: ${{ steps.parse-ref.outputs.tag }} diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 67f1caf6957cb..2879da9dad9c2 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -48,79 +48,36 @@ jobs: {%- for config in build_configs %} !{{ config["build_name"] }}-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: linux.4xlarge - timeout-minutes: !{{ common.timeout_minutes }} - !{{ upload.binary_env(config) }} - steps: - !{{ common.setup_ec2_linux() }} - !{{ common.checkout(deep_clone=False, directory="pytorch") }} - !{{ common.checkout(deep_clone=False, directory="builder", repository="pytorch/builder", branch=common.builder_branch) }} -{%- if config["gpu_arch_type"] == 'cuda' and config["gpu_arch_version"].startswith('11') %} - - name: Set BUILD_SPLIT_CUDA - run: | - echo "BUILD_SPLIT_CUDA='ON'" >> "$GITHUB_ENV" -{%- endif %} - - name: Pull Docker image - run: | - !{{ common.add_retry_to_env() }} - retry docker pull "${DOCKER_IMAGE}" - - name: Build PyTorch binary - run: | - set -x - mkdir -p artifacts/ - container_name=$(docker run \ - -e BINARY_ENV_FILE \ - -e BUILDER_ROOT \ - -e BUILD_ENVIRONMENT \ - -e BUILD_SPLIT_CUDA \ - -e DESIRED_CUDA \ - -e DESIRED_DEVTOOLSET \ - -e DESIRED_PYTHON \ - -e GITHUB_ACTIONS \ - -e GPU_ARCH_TYPE \ - -e GPU_ARCH_VERSION \ - -e LIBTORCH_VARIANT \ - -e PACKAGE_TYPE \ - -e PYTORCH_FINAL_PACKAGE_DIR \ - -e PYTORCH_ROOT \ - -e SKIP_ALL_TESTS \ - --tty \ - --detach \ - -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ - -v "${GITHUB_WORKSPACE}/builder:/builder" \ - -v "${RUNNER_TEMP}/artifacts:/artifacts" \ - -w / \ - "${DOCKER_IMAGE}" - ) - docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" - docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /builder/!{{ config["package_type"] }}/build.sh" - !{{ common.chown_dir("${RUNNER_TEMP}/artifacts") }} - - uses: !{{ common.upload_artifact_s3_action }} - with: - name: !{{ config["build_name"] }} - retention-days: 14 - if-no-files-found: error - path: - ${{ runner.temp }}/artifacts/* - !{{ common.teardown_ec2_linux("pytorch/") }} + uses: ./.github/workflows/_binary-build-linux.yml + with:!{{ upload.binary_env_as_input(config) }} + build_name: !{{ config["build_name"] }} + build_environment: !{{ build_environment }} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + !{{ config["build_name"] }}-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: !{{ config["build_name"] }}-build -{%- if config["gpu_arch_type"] == "rocm" %} - runs-on: linux.rocm.gpu -{%- elif config["gpu_arch_type"] == "cuda" %} - runs-on: linux.4xlarge.nvidia.gpu +{%- if config["gpu_arch_type"] != "rocm" %} + uses: ./.github/workflows/_binary-test-linux.yml + with:!{{ upload.binary_env_as_input(config) }} + build_name: !{{ config["build_name"] }} + build_environment: !{{ build_environment }} + {%- if config["gpu_arch_type"] == "rocm" %} + runs_on: linux.rocm.gpu + {%- elif config["gpu_arch_type"] == "cuda" %} + runs_on: linux.4xlarge.nvidia.gpu + {%- else %} + runs_on: linux.4xlarge + {%- endif %} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} {%- else %} - runs-on: linux.4xlarge -{%- endif %} + runs-on: linux.rocm.gpu timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config) }} steps: -{%- if config["gpu_arch_type"] == "rocm" %} !{{ common.setup_rocm_linux() }} -{%- else %} - !{{ common.setup_ec2_linux() }} -{%- endif %} - uses: !{{ common.download_artifact_s3_action }} name: Download Build Artifacts with: @@ -128,66 +85,19 @@ jobs: path: "${{ runner.temp }}/artifacts/" !{{ common.checkout(deep_clone=False, directory="pytorch") }} !{{ common.checkout(deep_clone=False, directory="builder", repository="pytorch/builder", branch=common.builder_branch) }} -{%- if config["gpu_arch_type"] == "rocm" %} - name: ROCm set GPU_FLAG run: | echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" -{%- elif config["gpu_arch_type"] == "cuda" %} - - uses: nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - set -ex - pushd pytorch - bash .github/scripts/install_nvidia_utils_linux.sh - echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" - popd -{%- endif %} - name: Pull Docker image - run: | - !{{ common.add_retry_to_env() }} - retry docker pull "${DOCKER_IMAGE}" - - name: Test PyTorch binary - run: | - set -x - # shellcheck disable=SC2086,SC2090 - container_name=$(docker run \ - ${GPU_FLAG:-} \ - -e BINARY_ENV_FILE \ - -e BUILDER_ROOT \ - -e BUILD_ENVIRONMENT \ - -e BUILD_SPLIT_CUDA \ - -e DESIRED_CUDA \ - -e DESIRED_DEVTOOLSET \ - -e DESIRED_PYTHON \ - -e GITHUB_ACTIONS \ - -e GPU_ARCH_TYPE \ - -e GPU_ARCH_VERSION \ - -e LIBTORCH_VARIANT \ - -e PACKAGE_TYPE \ - -e PYTORCH_FINAL_PACKAGE_DIR \ - -e PYTORCH_ROOT \ - -e SKIP_ALL_TESTS \ - --tty \ - --detach \ - -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ - -v "${GITHUB_WORKSPACE}/builder:/builder" \ - -v "${RUNNER_TEMP}/artifacts:/final_pkgs" \ - -w / \ - "${DOCKER_IMAGE}" - ) - docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" - # Generate test script - docker exec -t -w "${PYTORCH_ROOT}" -e OUTPUT_SCRIPT="/run.sh" "${container_name}" bash -c "bash .circleci/scripts/binary_linux_test.sh" - docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash -x /run.sh" -{%- if config["gpu_arch_type"] == "rocm" %} + uses: ./pytorch/.github/actions/pull-docker-image + with: + docker-image: !{{ config["container_image"] }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary !{{ common.teardown_rocm_linux() }} -{%- else %} - !{{ common.teardown_ec2_linux("pytorch/") }} {%- endif %} - {%- if branches == "nightly" %} + +{%- if branches == "nightly" %} !{{ upload.upload_binaries(config) }} - {%- endif %} +{%- endif %} {%- endfor %} diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 2674cb26fb980..64bc3653e8de8 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -61,7 +61,7 @@ jobs: {%- if config["package_type"] == "libtorch" %} runs-on: macos-10.15 {%- else %} - runs-on: macos-12 + runs-on: macos-12-xl {%- endif %} {%- if config["package_type"] == "libtorch" %} # libtorch builds take a long time on github hosted runners diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index 58b877bc48c9e..f62e90cc3c45b 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -1,10 +1,16 @@ {% import 'common.yml.j2' as common %} {%- macro binary_env(config, is_windows=False) -%} - env: + env:!{{ binary_env_as_input(config, is_windows, True) }} +{%- endmacro %} + +{%- macro binary_env_as_input(config, is_windows=False, include_skip_tests=False) -%} {%- if is_windows %} PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder +{%- else %} + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder {%- endif %} PACKAGE_TYPE: !{{ config["package_type"] }} # TODO: This is a legacy variable that we eventually want to get rid of in @@ -14,23 +20,25 @@ GPU_ARCH_VERSION: !{{ config["gpu_arch_version"] }} {%- endif %} GPU_ARCH_TYPE: !{{ config["gpu_arch_type"] }} +{%- if include_skip_tests %} + SKIP_ALL_TESTS: 1 +{%- endif %} {%- if not is_windows %} DOCKER_IMAGE: !{{ config["container_image"] }} {%- endif %} - SKIP_ALL_TESTS: 1 {%- if config["package_type"] == "libtorch" %} -{%- if config["libtorch_config"] %} + {%- if config["libtorch_config"] %} LIBTORCH_CONFIG: !{{ config["libtorch_config"] }} -{%- endif %} + {%- endif %} LIBTORCH_VARIANT: !{{ config["libtorch_variant"] }} -{%- if config["devtoolset"] %} + {%- if config["devtoolset"] %} DESIRED_DEVTOOLSET: !{{ config["devtoolset"] }} -{%- endif %} -{%- if is_windows %} + {%- endif %} + {%- if is_windows %} # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.7" -{%- endif %} + {%- endif %} {%- else %} DESIRED_PYTHON: "!{{ config["python_version"] }}" {%- endif %} @@ -39,60 +47,21 @@ {%- macro upload_binaries(config, is_windows=False, has_test=True, use_s3=True) -%} !{{ config["build_name"] }}-upload: # Uploading - runs-on: linux.2xlarge # self hosted runner to download ec2 artifacts if: ${{ github.repository_owner == 'pytorch' }} {%- if has_test %} needs: !{{ config["build_name"] }}-test {%- else %} needs: !{{ config["build_name"] }}-build {%- endif %} - !{{ binary_env(config, is_windows) }} - steps: - !{{ common.setup_ec2_linux() }} - - name: Clone pytorch/pytorch - uses: actions/checkout@v2 -{%- if use_s3 %} - - uses: !{{ common.download_artifact_s3_action }} -{%- else %} - - uses: actions/download-artifact@v2 -{%- endif %} - name: Download Build Artifacts - with: - name: !{{ config["build_name"] }} - path: "${{ runner.temp }}/artifacts/" - - name: Set DRY_RUN (only for tagged pushes) - if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || (startsWith(github.event.ref, 'refs/tags/') && !startsWith(github.event.ref, 'refs/tags/ciflow/'))) }} - run: | - echo "DRY_RUN=disabled" >> "$GITHUB_ENV" - - name: Set UPLOAD_CHANNEL (only for tagged pushes) - if: ${{ github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') && !startsWith(github.event.ref, 'refs/tags/ciflow/') }} - run: | - # reference ends with an RC suffix - if [[ ${GITHUB_REF_NAME} = *-rc[0-9]* ]]; then - echo "UPLOAD_CHANNEL=test" >> "$GITHUB_ENV" - fi - - name: Upload binaries - env: - PKG_DIR: "${{ runner.temp }}/artifacts" - UPLOAD_SUBFOLDER: "${{ env.DESIRED_CUDA }}" - # When running these on pull_request events these should be blank - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} - ANACONDA_API_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - run: | - docker run --rm -i \ - -e ANACONDA_API_TOKEN \ - -e AWS_ACCESS_KEY_ID \ - -e AWS_SECRET_ACCESS_KEY \ - -e DRY_RUN \ - -e PACKAGE_TYPE \ - -e PKG_DIR=/artifacts \ - -e UPLOAD_CHANNEL \ - -e UPLOAD_SUBFOLDER \ - -v "${RUNNER_TEMP}/artifacts:/artifacts" \ - -v "${GITHUB_WORKSPACE}:/v" \ - -w /v \ - 308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/miniconda3:4.10.3 \ - bash -c '.circleci/scripts/binary_upload.sh' - !{{ common.teardown_ec2_linux() }} -{%- endmacro -%} + with:!{{ binary_env_as_input(config, is_windows) }} + build_name: !{{ config["build_name"] }} + {%- if not use_s3 %} + use_s3: False + {%- endif %} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} + aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml +{%- endmacro %} diff --git a/.github/workflows/_android-build-test.yml b/.github/workflows/_android-build-test.yml index d79b98eeb2a13..4d3e07826eaeb 100644 --- a/.github/workflows/_android-build-test.yml +++ b/.github/workflows/_android-build-test.yml @@ -11,6 +11,13 @@ on: required: true type: string description: Name of the base docker image to build with. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -57,7 +64,6 @@ jobs: - name: Build env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-build-and-test TORCH_CUDA_ARCH_LIST: 5.2 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -67,9 +73,6 @@ jobs: # 1) Not shareable: it's custom selective build, which is different from default libtorch mobile build; # 2) Not parallelizable by architecture: it only builds libtorch for one architecture; - echo "DOCKER_IMAGE: ${DOCKER_IMAGE}" - time docker pull "${DOCKER_IMAGE}" >/dev/null - export BUILD_LITE_INTERPRETER BUILD_LITE_INTERPRETER="1" if [[ "${BUILD_ENVIRONMENT}" == *"full-jit" ]]; then @@ -79,7 +82,6 @@ jobs: git submodule sync && git submodule update -q --init --recursive --depth 1 --jobs 0 export id id=$(docker run -e BUILD_ENVIRONMENT \ - -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e SKIP_SCCACHE_INITIALIZATION=1 \ diff --git a/.github/workflows/_android-full-build-test.yml b/.github/workflows/_android-full-build-test.yml index a042ac24d04e5..efc66846db7a3 100644 --- a/.github/workflows/_android-full-build-test.yml +++ b/.github/workflows/_android-full-build-test.yml @@ -11,6 +11,13 @@ on: required: true type: string description: Name of the base docker image to build with. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: SONATYPE_NEXUS_USERNAME: diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index 3fb29593260e6..06786d237f073 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -11,6 +11,13 @@ on: required: true type: string description: Name of the base docker image to build with. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -68,7 +75,6 @@ jobs: env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-build-and-test # TODO duplicated AWS_DEFAULT_REGION: us-east-1 SHA1: ${{ github.event.pull_request.head.sha || github.sha }} @@ -80,7 +86,6 @@ jobs: # detached container should get cleaned up by teardown_ec2_linux container_name=$(docker run \ -e BUILD_ENVIRONMENT \ - -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e SKIP_SCCACHE_INITIALIZATION=1 \ @@ -104,7 +109,6 @@ jobs: # Time out the test phase after 3.5 hours timeout-minutes: 210 env: - JOB_BASE_NAME: ${{ inputs.build-environment }}-build-and-test BUILD_ENVIRONMENT: ${{ inputs.build-environment }} PR_NUMBER: ${{ github.event.pull_request.number }} BRANCH: ${{ steps.parse-ref.outputs.branch }} @@ -117,8 +121,21 @@ jobs: run: | # detached container should get cleaned up by teardown_ec2_linux export SHARD_NUMBER=0 + COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") - export COMMIT_MESSAGES + + # sanitize the input commit message and PR body here: + # + # trim all new lines from commit messages + PR_BODY to avoid issues with batch environment + # variable copying. see https://github.com/pytorch/pytorch/pull/80043#issuecomment-1167796028 + COMMIT_MESSAGES="${COMMIT_MESSAGES//[$'\n\r']}" + PR_BODY="${PR_BODY//[$'\n\r']}" + + # then trim all special characters like single and double quotes to avoid unescaped inputs to + # wreak havoc internally + export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}" + export PR_BODY="${PR_BODY//[\'\"]}" + # TODO: Stop building test binaries as part of the build phase # Make sure we copy test results from bazel-testlogs symlink to # a regular directory ./test/test-reports @@ -128,9 +145,10 @@ jobs: -e GIT_DEFAULT_BRANCH="$GIT_DEFAULT_BRANCH" \ -e SHARD_NUMBER \ -e NUM_TEST_SHARDS \ - -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e PR_BODY \ + -e COMMIT_MESSAGES \ -e PYTORCH_RETRY_TEST_CASES \ -e PYTORCH_OVERRIDE_FLAKY_SIGNAL \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ @@ -163,7 +181,6 @@ jobs: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_RETRY_TEST_CASES: 1 PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1 diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml new file mode 100644 index 0000000000000..b1b88a5b32f80 --- /dev/null +++ b/.github/workflows/_binary-build-linux.yml @@ -0,0 +1,237 @@ +name: linux-binary-build + +on: + workflow_call: + inputs: + build_name: + required: true + type: string + description: The build's name + build_environment: + required: true + type: string + description: The build environment + PYTORCH_ROOT: + required: true + type: string + description: Root directory for the pytorch/pytorch repository + BUILDER_ROOT: + required: true + type: string + description: Root directory for the pytorch/builder repository + PACKAGE_TYPE: + required: true + type: string + description: Package type + DESIRED_CUDA: + required: true + type: string + description: Desired Cuda version + GPU_ARCH_VERSION: + required: false + type: string + description: GPU Arch version + GPU_ARCH_TYPE: + required: true + type: string + description: GPU Arch type + DOCKER_IMAGE: + required: true + type: string + description: Docker image to use + LIBTORCH_CONFIG: + required: false + type: string + description: Desired libtorch config (for libtorch builds only) + LIBTORCH_VARIANT: + required: false + type: string + description: Desired libtorch variant (for libtorch builds only) + DESIRED_DEVTOOLSET: + required: false + type: string + description: Desired dev toolset + DESIRED_PYTHON: + required: false + type: string + description: Desired python version + secrets: + github-token: + required: true + description: Github Token + +jobs: + build: + runs-on: linux.4xlarge + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ inputs.PYTORCH_ROOT }} + BUILDER_ROOT: ${{ inputs.BUILDER_ROOT }} + PACKAGE_TYPE: ${{ inputs.PACKAGE_TYPE }} + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: ${{ inputs.DESIRED_CUDA }} + GPU_ARCH_VERSION: ${{ inputs.GPU_ARCH_VERSION }} + GPU_ARCH_TYPE: ${{ inputs.GPU_ARCH_TYPE }} + DOCKER_IMAGE: ${{ inputs.DOCKER_IMAGE }} + SKIP_ALL_TESTS: 1 + LIBTORCH_CONFIG: ${{ inputs.LIBTORCH_CONFIG }} + LIBTORCH_VARIANT: ${{ inputs.LIBTORCH_VARIANT }} + DESIRED_DEVTOOLSET: ${{ inputs.DESIRED_DEVTOOLSET }} + DESIRED_PYTHON: ${{ inputs.DESIRED_PYTHON }} + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: ${{ inputs.build_environment }} + GITHUB_TOKEN: ${{ secrets.github-token }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + steps: + - name: Make the env permanent during this workflow (but not the secrets) + shell: bash + run: | + { + echo "PYTORCH_ROOT=${{ env.PYTORCH_ROOT }}" + echo "BUILDER_ROOT=${{ env.BUILDER_ROOT }}" + echo "PACKAGE_TYPE=${{ env.PACKAGE_TYPE }}" + + echo "DESIRED_CUDA=${{ env.DESIRED_CUDA }}" + echo "GPU_ARCH_VERSION=${{ env.GPU_ARCH_VERSION }}" + echo "GPU_ARCH_TYPE=${{ env.GPU_ARCH_TYPE }}" + echo "DOCKER_IMAGE=${{ env.DOCKER_IMAGE }}" + echo "SKIP_ALL_TESTS=${{ env.SKIP_ALL_TESTS }}" + echo "LIBTORCH_CONFIG=${{ env.LIBTORCH_CONFIG }}" + echo "LIBTORCH_VARIANT=${{ env.LIBTORCH_VARIANT }}" + echo "DESIRED_DEVTOOLSET=${{ env.DESIRED_DEVTOOLSET }}" + echo "DESIRED_PYTHON=${{ env.DESIRED_PYTHON }}" + + echo "ALPINE_IMAGE=${{ env.ALPINE_IMAGE }}" + echo "ANACONDA_USER=${{ env.ANACONDA_USER }}" + echo "AWS_DEFAULT_REGION=${{ env.AWS_DEFAULT_REGION }}" + echo "BINARY_ENV_FILE=${{ env.BINARY_ENV_FILE }}" + echo "BUILD_ENVIRONMENT=${{ env.BUILD_ENVIRONMENT }}" + echo "PR_NUMBER=${{ env.PR_NUMBER }}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + echo "SHA1=${{ env.SHA1 }}" + } >> "${GITHUB_ENV} }}" + - name: List the env + shell: bash + run: env + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + - name: Setup Linux + uses: ./.github/actions/setup-linux + - name: Chown workspace + uses: ./.github/actions/chown-workspace + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: ./.github/actions/setup-ssh + with: + github-secret: ${{ secrets.github-token }} + - name: Clean workspace + shell: bash + run: | + rm -rf "${GITHUB_WORKSPACE}" + mkdir "${GITHUB_WORKSPACE}" + + - name: Checkout PyTorch to pytorch dir + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + + - name: Checkout pytorch/builder to builder dir + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + + - name: Set BUILD_SPLIT_CUDA + if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' && startsWith(inputs.GPU_ARCH_VERSION, '11') }} + shell: bash + run: | + echo "BUILD_SPLIT_CUDA='ON'" >> "$GITHUB_ENV" + - name: Pull Docker image + run: | + retry () { + "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") + } + retry docker pull "${DOCKER_IMAGE}" + - name: Build PyTorch binary + run: | + set -x + mkdir -p artifacts/ + container_name=$(docker run \ + -e BINARY_ENV_FILE \ + -e BUILDER_ROOT \ + -e BUILD_ENVIRONMENT \ + -e BUILD_SPLIT_CUDA \ + -e DESIRED_CUDA \ + -e DESIRED_DEVTOOLSET \ + -e DESIRED_PYTHON \ + -e GITHUB_ACTIONS \ + -e GPU_ARCH_TYPE \ + -e GPU_ARCH_VERSION \ + -e LIBTORCH_VARIANT \ + -e PACKAGE_TYPE \ + -e PYTORCH_FINAL_PACKAGE_DIR \ + -e PYTORCH_ROOT \ + -e SKIP_ALL_TESTS \ + --tty \ + --detach \ + -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ + -v "${GITHUB_WORKSPACE}/builder:/builder" \ + -v "${RUNNER_TEMP}/artifacts:/artifacts" \ + -w / \ + "${DOCKER_IMAGE}" + ) + docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" + docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /builder/${{ inputs.PACKAGE_TYPE }}/build.sh" + - name: Chown artifacts + if: always() + shell: bash + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "${RUNNER_TEMP}/artifacts:/v" -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + + - uses: seemethere/upload-artifact-s3@v5 + with: + name: ${{ inputs.build_name }} + retention-days: 14 + if-no-files-found: error + path: + ${{ runner.temp }}/artifacts/* + + - name: Hold runner for 2 hours or until ssh sessions have drained + working-directory: pytorch/ + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Chown workspace + if: always() + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml new file mode 100644 index 0000000000000..5c29288b82462 --- /dev/null +++ b/.github/workflows/_binary-test-linux.yml @@ -0,0 +1,212 @@ +name: linux-binary-test + +on: + workflow_call: + inputs: + build_name: + required: true + type: string + description: The build's name + build_environment: + required: true + type: string + description: The build environment + PYTORCH_ROOT: + required: true + type: string + description: Root directory for the pytorch/pytorch repository + BUILDER_ROOT: + required: true + type: string + description: Root directory for the pytorch/builder repository + PACKAGE_TYPE: + required: true + type: string + description: Package type + DESIRED_CUDA: + required: true + type: string + description: Desired Cuda version + GPU_ARCH_VERSION: + required: false + type: string + description: GPU Arch version + GPU_ARCH_TYPE: + required: true + type: string + description: GPU Arch type + DOCKER_IMAGE: + required: true + type: string + description: Docker image to use + LIBTORCH_CONFIG: + required: false + type: string + description: Desired libtorch config (for libtorch builds only) + LIBTORCH_VARIANT: + required: false + type: string + description: Desired libtorch variant (for libtorch builds only) + DESIRED_DEVTOOLSET: + required: false + type: string + description: Desired dev toolset + DESIRED_PYTHON: + required: false + type: string + description: Desired python version + runs_on: + required: true + type: string + description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, and linux.rocm.gpu + secrets: + github-token: + required: true + description: Github Token + +jobs: + build: + runs-on: ${{ inputs.runs_on }} + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ inputs.PYTORCH_ROOT }} + BUILDER_ROOT: ${{ inputs.BUILDER_ROOT }} + PACKAGE_TYPE: ${{ inputs.PACKAGE_TYPE }} + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: ${{ inputs.DESIRED_CUDA }} + GPU_ARCH_VERSION: ${{ inputs.GPU_ARCH_VERSION }} + GPU_ARCH_TYPE: ${{ inputs.GPU_ARCH_TYPE }} + DOCKER_IMAGE: ${{ inputs.DOCKER_IMAGE }} + SKIP_ALL_TESTS: 1 + LIBTORCH_CONFIG: ${{ inputs.LIBTORCH_CONFIG }} + LIBTORCH_VARIANT: ${{ inputs.LIBTORCH_VARIANT }} + DESIRED_DEVTOOLSET: ${{ inputs.DESIRED_DEVTOOLSET }} + DESIRED_PYTHON: ${{ inputs.DESIRED_PYTHON }} + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: ${{ inputs.build_environment }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + steps: + - name: Make the env permanent during this workflow (but not the secrets) + shell: bash + run: | + { + echo "PYTORCH_ROOT=${{ env.PYTORCH_ROOT }}" + echo "BUILDER_ROOT=${{ env.BUILDER_ROOT }}" + echo "PACKAGE_TYPE=${{ env.PACKAGE_TYPE }}" + + echo "DESIRED_CUDA=${{ env.DESIRED_CUDA }}" + echo "GPU_ARCH_VERSION=${{ env.GPU_ARCH_VERSION }}" + echo "GPU_ARCH_TYPE=${{ env.GPU_ARCH_TYPE }}" + echo "DOCKER_IMAGE=${{ env.DOCKER_IMAGE }}" + echo "SKIP_ALL_TESTS=${{ env.SKIP_ALL_TESTS }}" + echo "LIBTORCH_CONFIG=${{ env.LIBTORCH_CONFIG }}" + echo "LIBTORCH_VARIANT=${{ env.LIBTORCH_VARIANT }}" + echo "DESIRED_DEVTOOLSET=${{ env.DESIRED_DEVTOOLSET }}" + echo "DESIRED_PYTHON=${{ env.DESIRED_PYTHON }}" + + echo "ALPINE_IMAGE=${{ env.ALPINE_IMAGE }}" + echo "ANACONDA_USER=${{ env.ANACONDA_USER }}" + echo "AWS_DEFAULT_REGION=${{ env.AWS_DEFAULT_REGION }}" + echo "BINARY_ENV_FILE=${{ env.BINARY_ENV_FILE }}" + echo "BUILD_ENVIRONMENT=${{ env.BUILD_ENVIRONMENT }}" + echo "PR_NUMBER=${{ env.PR_NUMBER }}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + echo "SHA1=${{ env.SHA1 }}" + } >> "${GITHUB_ENV} }}" + + # Setup the environment + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + - name: Setup Linux + uses: ./.github/actions/setup-linux + - name: Chown workspace + uses: ./.github/actions/chown-workspace + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: ./.github/actions/setup-ssh + with: + github-secret: ${{ secrets.github-token }} + - name: Clean workspace + shell: bash + run: | + rm -rf "${GITHUB_WORKSPACE}" + mkdir "${GITHUB_WORKSPACE}" + + - uses: seemethere/download-artifact-s3@v4 + name: Download Build Artifacts + with: + name: ${{ inputs.build_name }} + path: "${{ runner.temp }}/artifacts/" + + + - name: Checkout PyTorch to pytorch dir + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + + - name: Checkout pytorch/builder to builder dir + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + + - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG + uses: nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a + if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' }} + with: + timeout_minutes: 10 + max_attempts: 3 + command: | + set -ex + pushd pytorch + bash .github/scripts/install_nvidia_utils_linux.sh + echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" + popd + + - name: Pull Docker image + uses: ./pytorch/.github/actions/pull-docker-image + with: + docker-image: ${{ inputs.DOCKER_IMAGE }} + + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + + - name: Hold runner for 2 hours or until ssh sessions have drained + working-directory: pytorch/ + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Chown workspace + if: always() + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml new file mode 100644 index 0000000000000..cf47de9ccf212 --- /dev/null +++ b/.github/workflows/_binary-upload.yml @@ -0,0 +1,178 @@ +name: upload + +on: + workflow_call: + inputs: + build_name: + required: true + type: string + description: The build's name + use_s3: + type: boolean + default: true + description: If true, will download artifacts from s3. Otherwise will use the default github artifact download action + PYTORCH_ROOT: + required: false + type: string + description: Root directory for the pytorch/pytorch repository. Not actually needed, but currently passing it in since we pass in the same inputs to the reusable workflows of all binary builds + BUILDER_ROOT: + required: false + type: string + description: Root directory for the pytorch/builder repository. Not actually needed, but currently passing it in since we pass in the same inputs to the reusable workflows of all binary builds + PACKAGE_TYPE: + required: true + type: string + description: Package type + DESIRED_CUDA: + required: true + type: string + description: Desired Cuda version + GPU_ARCH_VERSION: + required: false + type: string + description: GPU Arch version + GPU_ARCH_TYPE: + required: true + type: string + description: GPU Arch type + DOCKER_IMAGE: + required: false + type: string + description: Docker image to use + LIBTORCH_CONFIG: + required: false + type: string + description: Desired libtorch config (for libtorch builds only) + LIBTORCH_VARIANT: + required: false + type: string + description: Desired libtorch variant (for libtorch builds only) + DESIRED_DEVTOOLSET: + required: false + type: string + description: Desired dev toolset + DESIRED_PYTHON: + required: false + type: string + description: Desired python version + secrets: + github-token: + required: true + description: Github Token + aws-access-key-id: + required: true + description: AWS access key id + aws-pytorch-uploader-secret-access-key: + required: true + description: AWS secret access key + conda-pytorchbot-token: + required: true + description: Conda PyTorchBot token +jobs: + build: + runs-on: linux.2xlarge + env: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: ${{ inputs.PACKAGE_TYPE }} + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: ${{ inputs.DESIRED_CUDA }} + GPU_ARCH_VERSION: ${{ inputs.GPU_ARCH_VERSION }} + GPU_ARCH_TYPE: ${{ inputs.GPU_ARCH_TYPE }} + DOCKER_IMAGE: ${{ inputs.DOCKER_IMAGE }} + SKIP_ALL_TESTS: 1 + LIBTORCH_CONFIG: ${{ inputs.LIBTORCH_CONFIG }} + LIBTORCH_VARIANT: ${{ inputs.LIBTORCH_VARIANT }} + DESIRED_DEVTOOLSET: ${{ inputs.DESIRED_DEVTOOLSET }} + DESIRED_PYTHON: ${{ inputs.DESIRED_PYTHON }} + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + GITHUB_TOKEN: ${{ secrets.github-token }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + steps: + - name: List the env + shell: bash + run: env + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + - name: Setup Linux + uses: ./.github/actions/setup-linux + - name: Chown workspace + uses: ./.github/actions/chown-workspace + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: ./.github/actions/setup-ssh + with: + github-secret: ${{ secrets.github-token }} + + - name: Download Build Artifacts with S3 + uses: seemethere/download-artifact-s3@v4 + if: ${{ inputs.use_s3 }} + with: + name: ${{ inputs.build_name }} + path: "${{ runner.temp }}/artifacts/" + + - name: Download Build Artifacts without S3 + uses: actions/download-artifact@v2 + if: ${{ !inputs.use_s3 }} + with: + name: ${{ inputs.build_name }} + path: "${{ runner.temp }}/artifacts/" + + - name: Set DRY_RUN (only for tagged pushes) + if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || (startsWith(github.event.ref, 'refs/tags/') && !startsWith(github.event.ref, 'refs/tags/ciflow/'))) }} + run: | + echo "DRY_RUN=disabled" >> "$GITHUB_ENV" + - name: Set UPLOAD_CHANNEL (only for tagged pushes) + if: ${{ github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') && !startsWith(github.event.ref, 'refs/tags/ciflow/') }} + run: | + # reference ends with an RC suffix + if [[ ${GITHUB_REF_NAME} = *-rc[0-9]* ]]; then + echo "UPLOAD_CHANNEL=test" >> "$GITHUB_ENV" + fi + - name: Upload binaries + env: + PKG_DIR: "${{ runner.temp }}/artifacts" + UPLOAD_SUBFOLDER: "${{ env.DESIRED_CUDA }}" + # When running these on pull_request events these should be blank + AWS_ACCESS_KEY_ID: ${{ secrets.aws-access-key-id }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.aws-pytorch-uploader-secret-access-key }} + ANACONDA_API_TOKEN: ${{ secrets.conda-pytorchbot-token }} + run: | + docker run --rm -i \ + -e ANACONDA_API_TOKEN \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -e DRY_RUN \ + -e PACKAGE_TYPE \ + -e PKG_DIR=/artifacts \ + -e UPLOAD_CHANNEL \ + -e UPLOAD_SUBFOLDER \ + -v "${RUNNER_TEMP}/artifacts:/artifacts" \ + -v "${GITHUB_WORKSPACE}:/v" \ + -w /v \ + 308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/miniconda3:4.10.3 \ + bash -c '.circleci/scripts/binary_upload.sh' + + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Chown workspace + if: always() + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml index ba9761e3254c1..ae7f7517e2eda 100644 --- a/.github/workflows/_buck-build-test.yml +++ b/.github/workflows/_buck-build-test.yml @@ -10,8 +10,6 @@ defaults: jobs: buck-build-test: runs-on: ubuntu-latest - env: - JOB_BASE_NAME: ubuntu-latest-buck steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -64,21 +62,17 @@ jobs: command: | sh scripts/buck_setup.sh - - name: Build glog + - name: Build tools run: | - buck build third_party:glog + buck build tools: --keep-going - - name: Build C10 + - name: Run tools tests run: | - buck build c10:c10 - - - name: Build cpuinfo - run: | - buck build third_party:cpuinfo + buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test - - name: Build pthreadpool + - name: Build c10 run: | - buck build third_party:pthreadpool + buck build c10:c10 - name: Build XNNPACK run: | @@ -86,7 +80,11 @@ jobs: - name: Build QNNPACK run: | - buck build aten/src/ATen/native/quantized/cpu/qnnpack/... --keep-going + buck build aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack + + - name: Test QNNPACK + run: | + buck test aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack_test - name: Build aten_cpu run: | @@ -96,9 +94,9 @@ jobs: run: | buck build :torch_mobile_core - - name: Build torch_mobile_all_ops + - name: Build pt_ops_full run: | - buck build :torch_mobile_all_ops + buck build :pt_ops_full - name: Build mobile benchmark run: | diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index 9ba5e0eb686d9..de28790f8c5e9 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -16,6 +16,18 @@ on: type: boolean default: false description: If set, push the docs to the docs website. + run-doxygen: + required: false + type: boolean + default: false + description: If set, will enable C++ API doc generation using doxygen / breathe / exhale. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: GH_PYTORCHBOT_TOKEN: @@ -68,6 +80,7 @@ jobs: WITH_PUSH: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }} DOCKER_IMAGE: ${{ inputs.docker-image }} DOCS_TYPE: ${{ matrix.docs_type }} + RUN_DOXYGEN: ${{ inputs.run-doxygen }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} run: | set -ex @@ -84,6 +97,7 @@ jobs: -e SHA1="$GITHUB_SHA" \ -e DOCS_VERSION="${target}" \ -e DOCS_TYPE \ + -e RUN_DOXYGEN \ -e WITH_PUSH \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index c8098b5a8dbbd..189e21d210e59 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -15,6 +15,13 @@ on: required: true type: string description: Which iOS arch to build for. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: IOS_CERT_KEY_2022: @@ -41,10 +48,9 @@ jobs: # NOTE: These builds will not run successfully without running on `pytorch/pytorch` due to the limitations # of accessing secrets from forked pull requests and IOS' dependency on secrets for their build/test if: github.repository_owner == 'pytorch' - runs-on: macos-10.15 + runs-on: macos-12 timeout-minutes: 240 env: - JOB_BASE_NAME: ${{ inputs.build-environment }}-build IOS_CERT_KEY_2022: ${{ secrets.IOS_CERT_KEY_2022 }} IOS_CERT_SECRET: ${{ secrets.IOS_CERT_SECRET }} IOS_DEV_TEAM_ID: ${{ secrets.IOS_DEV_TEAM_ID }} @@ -108,6 +114,7 @@ jobs: cd ios/TestApp # install fastlane sudo gem install bundler && bundle install + bundle update fastlane # install certificates echo "${IOS_CERT_KEY_2022}" >> cert.txt base64 --decode cert.txt -o Certificates.p12 @@ -151,6 +158,7 @@ jobs: run: | # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" + # use the pytorch nightly build to generate models pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html # generate models for differnet backends cd "${GITHUB_WORKSPACE}/ios/TestApp/benchmark" @@ -172,13 +180,13 @@ jobs: ruby setup.rb fi cd "${GITHUB_WORKSPACE}/ios/TestApp" - instruments -s -devices + # instruments -s -devices if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then if [ "${USE_COREML_DELEGATE}" == 1 ]; then - fastlane scan --only_testing TestAppTests/TestAppTests/testCoreML + bundle exec fastlane scan --only_testing TestAppTests/TestAppTests/testCoreML else - fastlane scan --skip_testing TestAppTests/TestAppTests/testCoreML + bundle exec fastlane scan --skip_testing TestAppTests/TestAppTests/testCoreML fi else - fastlane scan --only_testing TestAppTests/TestAppTests/testFullJIT + bundle exec fastlane scan --only_testing TestAppTests/TestAppTests/testFullJIT fi diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 9a9cae06f4e08..09a400c4d502c 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -21,6 +21,13 @@ on: type: boolean default: false description: If set, build in debug mode. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. outputs: docker-image: @@ -85,12 +92,12 @@ jobs: env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-build # TODO duplicated AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} TORCH_CUDA_ARCH_LIST: 5.2 @@ -102,13 +109,13 @@ jobs: # detached container should get cleaned up by teardown_ec2_linux container_name=$(docker run \ -e BUILD_ENVIRONMENT \ - -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ -e SHA1 \ -e BRANCH \ -e SCCACHE_BUCKET \ + -e SCCACHE_S3_KEY_PREFIX \ -e XLA_CUDA \ -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e SKIP_SCCACHE_INITIALIZATION=1 \ diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index ac1ea422952c0..aa81647c53fcf 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -15,6 +15,13 @@ on: required: true type: string description: Docker image to run in. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -56,6 +63,15 @@ jobs: bash .github/scripts/install_nvidia_utils_linux.sh echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" + - name: Start monitoring script + id: monitor-script + shell: bash + run: | + python3 -m pip install psutil==5.9.1 + python3 -m pip install pynvml==11.4.1 + python3 -m tools.stats.monitor > usage_log.txt 2>&1 & + echo "::set-output name=monitor-script-pid::${!}" + - name: Download build artifacts uses: ./.github/actions/download-build-artifacts with: @@ -75,7 +91,6 @@ jobs: BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }} PYTORCH_RETRY_TEST_CASES: 1 PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1 - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} @@ -98,7 +113,18 @@ jobs: fi COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") - export COMMIT_MESSAGES + + # sanitize the input commit message and PR body here: + # + # trim all new lines from commit messages + PR_BODY to avoid issues with batch environment + # variable copying. see https://github.com/pytorch/pytorch/pull/80043#issuecomment-1167796028 + COMMIT_MESSAGES="${COMMIT_MESSAGES//[$'\n\r']}" + PR_BODY="${PR_BODY//[$'\n\r']}" + + # then trim all special characters like single and double quotes to avoid unescaped inputs to + # wreak havoc internally + export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}" + export PR_BODY="${PR_BODY//[\'\"]}" # detached container should get cleaned up by teardown_ec2_linux # TODO: Stop building test binaries as part of the build phase @@ -115,7 +141,6 @@ jobs: -e AWS_DEFAULT_REGION \ -e IN_WHEEL_TEST \ -e SHARD_NUMBER \ - -e JOB_BASE_NAME \ -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PR_BODY \ @@ -150,6 +175,14 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Stop monitoring script + if: always() && steps.monitor-script.outputs.monitor-script-pid + shell: bash + env: + MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }} + run: | + kill "$MONITOR_SCRIPT_PID" + - name: Upload test artifacts uses: ./.github/actions/upload-test-artifacts if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') @@ -171,7 +204,6 @@ jobs: AWS_DEFAULT_REGION: us-east-1 GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index d9747a3ebbddc..f17bd649c7131 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -20,6 +20,13 @@ on: type: string default: "" description: What xcode version to build with. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: MACOS_SCCACHE_S3_ACCESS_KEY_ID: @@ -40,7 +47,6 @@ jobs: if: github.repository_owner == 'pytorch' runs-on: ${{ inputs.runner-type }} env: - JOB_BASE_NAME: ${{ inputs.build-environment }} # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -77,6 +83,7 @@ jobs: sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" + echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" - name: Get workflow job id id: get-job-id @@ -95,7 +102,7 @@ jobs: - name: Archive artifacts into zip if: inputs.build-generates-artifacts run: | - zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json + zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json - name: Store PyTorch Build Artifacts on GHA uses: actions/upload-artifact@v2 diff --git a/.github/workflows/_mac-test-arm64.yml b/.github/workflows/_mac-test-arm64.yml index e8896df3b7d68..14502a32ad684 100644 --- a/.github/workflows/_mac-test-arm64.yml +++ b/.github/workflows/_mac-test-arm64.yml @@ -7,12 +7,19 @@ on: required: true type: string description: Top-level label for what's being built/tested. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. jobs: run_mps_test: name: "Run MPS tests" - runs-on: macos12.3-m1 + runs-on: macos-m1-12 steps: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -40,7 +47,7 @@ jobs: # shellcheck disable=SC1090 . ~/miniconda3/etc/profile.d/conda.sh set -ex - conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest + conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest pyyaml # As wheels are cross-compiled they are reported as x86_64 ones ORIG_WHLNAME=$(ls -1 dist/*.whl); ARM_WHLNAME=${ORIG_WHLNAME/x86_64/arm64}; mv ${ORIG_WHLNAME} ${ARM_WHLNAME} conda run -p "${ENV_NAME}" python3 -mpip install dist/*.whl diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 651f291ad80c9..e919bef85a67a 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -11,6 +11,13 @@ on: required: true type: string description: JSON description of what test configs to run. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: @@ -37,7 +44,6 @@ jobs: env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} @@ -49,6 +55,15 @@ jobs: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + - name: Start monitoring script + id: monitor-script + shell: bash + run: | + python3 -m pip install psutil==5.9.1 + python3 -m pip install pynvml==11.4.1 + python3 -m tools.stats.monitor > usage_log.txt 2>&1 & + echo "::set-output name=monitor-script-pid::${!}" + - name: Download build artifacts uses: ./.github/actions/download-build-artifacts with: @@ -76,7 +91,19 @@ jobs: id: test run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") - export COMMIT_MESSAGES + + # sanitize the input commit message and PR body here: + # + # trim all new lines from commit messages + PR_BODY to avoid issues with batch environment + # variable copying. see https://github.com/pytorch/pytorch/pull/80043#issuecomment-1167796028 + COMMIT_MESSAGES="${COMMIT_MESSAGES//[$'\n\r']}" + PR_BODY="${PR_BODY//[$'\n\r']}" + + # then trim all special characters like single and double quotes to avoid unescaped inputs to + # wreak havoc internally + export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}" + export PR_BODY="${PR_BODY//[\'\"]}" + python3 -mpip install dist/*.whl .jenkins/pytorch/macos-test.sh @@ -87,6 +114,14 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Stop monitoring script + if: always() && steps.monitor-script.outputs.monitor-script-pid + shell: bash + env: + MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }} + run: | + kill "$MONITOR_SCRIPT_PID" + - name: Upload test artifacts uses: ./.github/actions/upload-test-artifacts if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') @@ -100,7 +135,6 @@ jobs: AWS_DEFAULT_REGION: us-east-1 GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 61ecf9937f395..b5550fdda7f0a 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -19,6 +19,13 @@ on: required: true type: string description: Docker image to run in. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: @@ -55,6 +62,15 @@ jobs: with: docker-image: ${{ inputs.docker-image }} + - name: Start monitoring script + id: monitor-script + shell: bash + run: | + python3 -m pip install psutil==5.9.1 + python3 -m pip install pynvml==11.4.1 + python3 -m tools.stats.monitor > usage_log.txt 2>&1 & + echo "::set-output name=monitor-script-pid::${!}" + - name: Download build artifacts uses: ./.github/actions/download-build-artifacts with: @@ -73,7 +89,6 @@ jobs: SHA1: ${{ github.event.pull_request.head.sha || github.sha }} PYTORCH_RETRY_TEST_CASES: 1 PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1 - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} @@ -94,7 +109,18 @@ jobs: fi COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") - export COMMIT_MESSAGES + + # sanitize the input commit message and PR body here: + # + # trim all new lines from commit messages + PR_BODY to avoid issues with batch environment + # variable copying. see https://github.com/pytorch/pytorch/pull/80043#issuecomment-1167796028 + COMMIT_MESSAGES="${COMMIT_MESSAGES//[$'\n\r']}" + PR_BODY="${PR_BODY//[$'\n\r']}" + + # then trim all special characters like single and double quotes to avoid unescaped inputs to + # wreak havoc internally + export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}" + export PR_BODY="${PR_BODY//[\'\"]}" # detached container should get cleaned up by teardown_ec2_linux # TODO: Stop building test binaries as part of the build phase @@ -110,7 +136,6 @@ jobs: -e AWS_DEFAULT_REGION \ -e IN_WHEEL_TEST \ -e SHARD_NUMBER \ - -e JOB_BASE_NAME \ -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PR_BODY \ @@ -151,6 +176,14 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Stop monitoring script + if: always() && steps.monitor-script.outputs.monitor-script-pid + shell: bash + env: + MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }} + run: | + kill "$MONITOR_SCRIPT_PID" + - name: Upload test artifacts uses: ./.github/actions/upload-test-artifacts if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') @@ -164,7 +197,6 @@ jobs: AWS_DEFAULT_REGION: us-east-1 GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} diff --git a/.github/workflows/run_android_tests.yml b/.github/workflows/_run_android_tests.yml similarity index 70% rename from .github/workflows/run_android_tests.yml rename to .github/workflows/_run_android_tests.yml index 5b29734756365..273ec2db81aed 100644 --- a/.github/workflows/run_android_tests.yml +++ b/.github/workflows/_run_android_tests.yml @@ -1,33 +1,15 @@ name: android-tests on: - push: - tags: - # Trigger on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - - 'ciflow/trunk/*' - - 'ciflow/android/*' - branches: - - master - - main - - release/* - workflow_dispatch: - -concurrency: - group: run-android-tests-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true + workflow_call: defaults: run: shell: bash -e -l {0} jobs: - build-and-test: runs-on: ubuntu-latest - env: - JOB_BASE_NAME: ubuntu-latest-android-tests steps: - name: Setup miniconda uses: conda-incubator/setup-miniconda@v2 @@ -61,10 +43,10 @@ jobs: ANDROID_ROOT="/usr/local/lib/android" ANDROID_SDK_ROOT="${ANDROID_ROOT}/sdk" SDKMANAGER="${ANDROID_SDK_ROOT}/cmdline-tools/latest/bin/sdkmanager" - echo "y" | $SDKMANAGER "ndk;21.4.7075529" + echo "y" | ${SDKMANAGER} "ndk;21.4.7075529" export ANDROID_NDK="${ANDROID_SDK_ROOT}/ndk-bundle" - ln -sfn $ANDROID_SDK_ROOT/ndk/21.4.7075529 $ANDROID_NDK + ln -sfn ${ANDROID_SDK_ROOT}/ndk/21.4.7075529 ${ANDROID_NDK} echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}" ./scripts/build_pytorch_android.sh x86 diff --git a/.github/workflows/_update-commit-hash.yml b/.github/workflows/_update-commit-hash.yml index 541e164b479ef..42e12d9dca9f6 100644 --- a/.github/workflows/_update-commit-hash.yml +++ b/.github/workflows/_update-commit-hash.yml @@ -36,7 +36,7 @@ jobs: - name: Checkout shell: bash run: | - git clone https://github.com/pytorch/${{ inputs.repo-name }}.git --depth=1 --quiet + git clone https://github.com/pytorch/${{ inputs.repo-name }}.git --quiet - name: Check if there already exists a PR shell: bash diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 6e0c423cc259d..fb2195fafce66 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -16,6 +16,13 @@ on: type: boolean default: false description: If set, build in debug mode. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -26,8 +33,6 @@ jobs: if: github.repository_owner == 'pytorch' runs-on: [self-hosted, windows.4xlarge] timeout-minutes: 240 - env: - JOB_BASE_NAME: ${{ inputs.build-environment }}-build steps: # [see note: pytorch repo ref] - name: Checkout PyTorch @@ -67,6 +72,7 @@ jobs: CUDA_VERSION: ${{ inputs.cuda-version }} PYTHON_VERSION: "3.8" SCCACHE_BUCKET: "ossci-compiler-cache" + SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} VC_PRODUCT: "BuildTools" VC_VERSION: "" VC_YEAR: "2019" @@ -87,7 +93,7 @@ jobs: with: retention-days: 14 if-no-files-found: error - name: ${{ env.BUILD_ENVIRONMENT }} + name: ${{ inputs.build-environment }} path: C:\${{ github.run_id }}\build-results - name: Upload sccache stats diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index af54526d8588f..560c0fe84e1d4 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -15,6 +15,13 @@ on: required: true type: string description: JSON description of what test configs to run. + sync-tag: + required: false + type: string + default: "" + description: | + If this is set, our linter will use this to make sure that every other + job with the same `sync-tag` is identical. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -45,10 +52,19 @@ jobs: with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Start monitoring script + id: monitor-script + shell: bash + run: | + python3 -m pip install psutil==5.9.1 + python3 -m pip install pynvml==11.4.1 + python3 -m tools.stats.monitor > usage_log.txt 2>&1 & + echo "::set-output name=monitor-script-pid::${!}" + - name: Download PyTorch Build Artifacts uses: seemethere/download-artifact-s3@v4 with: - name: ${{ env.BUILD_ENVIRONMENT }} + name: ${{ inputs.build-environment }} path: C:\${{ github.run_id }}\build-results - name: Check build-results folder @@ -79,12 +95,23 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} TEST_CONFIG: ${{ matrix.config }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test PR_BODY: ${{ github.event.pull_request.body }} TORCH_CUDA_ARCH_LIST: "7.0" run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") - export COMMIT_MESSAGES + + # sanitize the input commit message and PR body here: + # + # trim all new lines from commit messages + PR_BODY to avoid issues with batch environment + # variable copying. see https://github.com/pytorch/pytorch/pull/80043#issuecomment-1167796028 + COMMIT_MESSAGES="${COMMIT_MESSAGES//[$'\n\r']}" + PR_BODY="${PR_BODY//[$'\n\r']}" + + # then trim all special characters like single and double quotes to avoid unescaped inputs to + # wreak havoc internally + export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}" + export PR_BODY="${PR_BODY//[\'\"]}" + .jenkins/pytorch/win-test.sh - name: Get workflow job id @@ -94,6 +121,14 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Stop monitoring script + if: always() && steps.monitor-script.outputs.monitor-script-pid + shell: bash + env: + MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }} + run: | + kill "$MONITOR_SCRIPT_PID" + - name: Upload test artifacts uses: ./.github/actions/upload-test-artifacts if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') @@ -110,7 +145,6 @@ jobs: AWS_DEFAULT_REGION: us-east-1 GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: ${{ inputs.build-environment }}-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5e03703511c79..002f25561c358 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,6 +7,7 @@ on: - master - main - release/* + - landchecks/* workflow_dispatch: jobs: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index c95d276d3565d..133aa56865c70 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -32,5 +32,6 @@ jobs: build-environment: linux-focal-py3.7-gcc7 docker-image: ${{ needs.docs-build.outputs.docker-image }} push: true + run-doxygen: true secrets: GH_PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 5f87921a0fa36..0e3e565deb914 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -9,38 +9,10 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: - linux-bionic-cuda11_6-py3_7-gcc7-build: - name: linux-bionic-cuda11.6-py3.7-gcc7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-bionic-cuda11.6-py3.7-gcc7 - docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - - linux-bionic-cuda11_6-py3_7-gcc7-test: - name: linux-bionic-cuda11.6-py3.7-gcc7 - uses: ./.github/workflows/_linux-test.yml - needs: linux-bionic-cuda11_6-py3_7-gcc7-build - with: - build-environment: linux-bionic-cuda11.6-py3.7-gcc7 - docker-image: ${{ needs.linux-bionic-cuda11_6-py3_7-gcc7-build.outputs.docker-image }} - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - libtorch-linux-bionic-cuda11_6-py3_7-gcc7-build: - name: libtorch-linux-bionic-cuda11.6-py3.7-gcc7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: libtorch-linux-bionic-cuda11.6-py3.7-gcc7 - docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - build-generates-artifacts: false - linux-xenial-cuda10_2-py3-gcc7-slow-gradcheck-build: name: linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck uses: ./.github/workflows/_linux-build.yml @@ -61,20 +33,20 @@ jobs: { config: "default", shard: 2, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, ]} - linux-bionic-rocm5_1-py3_7-slow-build: - name: linux-bionic-rocm5.1-py3.7-slow + linux-focal-rocm5_2-py3_7-slow-build: + name: linux-focal-rocm5.2-py3.7-slow uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image-name: pytorch-linux-bionic-rocm5.1-py3.7 + build-environment: linux-focal-rocm5.2-py3.7 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 - linux-bionic-rocm5_1-py3_7-slow-test: - name: linux-bionic-rocm5.1-py3.7-slow + linux-focal-rocm5_2-py3_7-slow-test: + name: linux-focal-rocm5.2-py3.7-slow uses: ./.github/workflows/_rocm-test.yml - needs: linux-bionic-rocm5_1-py3_7-slow-build + needs: linux-focal-rocm5_2-py3_7-slow-build with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image: ${{ needs.linux-bionic-rocm5_1-py3_7-slow-build.outputs.docker-image }} + build-environment: linux-focal-rocm5.2-py3.7 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-slow-build.outputs.docker-image }} test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, @@ -83,20 +55,20 @@ jobs: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} - linux-bionic-rocm5_1-py3_7-distributed-build: - name: linux-bionic-rocm5.1-py3.7-distributed + linux-focal-rocm5_2-py3_7-distributed-build: + name: linux-focal-rocm5.2-py3.7-distributed uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image-name: pytorch-linux-bionic-rocm5.1-py3.7 + build-environment: linux-focal-rocm5.2-py3.7 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 - linux-bionic-rocm5_1-py3_7-distributed-test: - name: linux-bionic-rocm5.1-py3.7-distributed + linux-focal-rocm5_2-py3_7-distributed-test: + name: linux-focal-rocm5.2-py3.7-distributed uses: ./.github/workflows/_rocm-test.yml - needs: linux-bionic-rocm5_1-py3_7-distributed-build + needs: linux-focal-rocm5_2-py3_7-distributed-build with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image: ${{ needs.linux-bionic-rocm5_1-py3_7-distributed-build.outputs.docker-image }} + build-environment: linux-focal-rocm5.2-py3.7 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-distributed-build.outputs.docker-image }} test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, @@ -125,21 +97,21 @@ jobs: { config: "multigpu", shard: 1, num_shards: 1, runner: "linux.16xlarge.nvidia.gpu" }, ]} - linux-xenial-cuda11_3-py3_7-gcc7-debug-build: - name: linux-xenial-cuda11.3-py3.7-gcc7-debug + linux-bionic-cuda11_6-py3_7-gcc7-debug-build: + name: linux-bionic-cuda11.6-py3.7-gcc7-debug uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-xenial-cuda11.3-py3.7-gcc7-debug - docker-image-name: pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 + build-environment: linux-bionic-cuda11.6-py3.7-gcc7-debug + docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 build-with-debug: true - linux-xenial-cuda11_3-py3_7-gcc7-debug-test: - name: linux-xenial-cuda11.3-py3.7-gcc7-debug + linux-bionic-cuda11_6-py3_7-gcc7-debug-test: + name: linux-bionic-cuda11.6-py3.7-gcc7-debug uses: ./.github/workflows/_linux-test.yml - needs: linux-xenial-cuda11_3-py3_7-gcc7-debug-build + needs: linux-bionic-cuda11_6-py3_7-gcc7-debug-build with: - build-environment: linux-xenial-cuda11.3-py3.7-gcc7-debug - docker-image: ${{ needs.linux-xenial-cuda11_3-py3_7-gcc7-debug-build.outputs.docker-image }} + build-environment: linux-bionic-cuda11.6-py3.7-gcc7-debug + docker-image: ${{ needs.linux-bionic-cuda11_6-py3_7-gcc7-debug-build.outputs.docker-image }} test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, @@ -148,26 +120,18 @@ jobs: { config: "default", shard: 4, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, ]} - win-vs2019-cuda11_6-py3-build: - name: win-vs2019-cuda11.6-py3 - uses: ./.github/workflows/_win-build.yml - with: - build-environment: win-vs2019-cuda11.6-py3 - cuda-version: "11.6" - - win-vs2019-cuda11_6-py3-test: - name: win-vs2019-cuda11.6-py3 - uses: ./.github/workflows/_win-test.yml - needs: win-vs2019-cuda11_6-py3-build + ios-12-5-1-x86-64-coreml: + name: ios-12-5-1-x86-64-coreml + uses: ./.github/workflows/_ios-build-test.yml with: - build-environment: win-vs2019-cuda11.6-py3 - cuda-version: "11.6" - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, - ]} + build-environment: ios-12-5-1-x86-64-coreml + ios-platform: SIMULATOR + ios-arch: x86_64 + secrets: + IOS_CERT_KEY_2022: ${{ secrets.IOS_CERT_KEY_2022 }} + IOS_CERT_SECRET: ${{ secrets.IOS_CERT_SECRET}} + IOS_DEV_TEAM_ID: ${{ secrets.IOS_DEV_TEAM_ID}} + IOS_SIGN_KEY_2022: ${{ secrets.IOS_SIGN_KEY_2022 }} ios-12-5-1-arm64: name: ios-12-5-1-arm64 @@ -220,3 +184,7 @@ jobs: IOS_CERT_SECRET: ${{ secrets.IOS_CERT_SECRET}} IOS_DEV_TEAM_ID: ${{ secrets.IOS_DEV_TEAM_ID}} IOS_SIGN_KEY_2022: ${{ secrets.IOS_SIGN_KEY_2022 }} + + buck-build-test: + name: buck-build-test + uses: ./.github/workflows/_buck-build-test.yml diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index dc5141f859d5b..a31111ecf885f 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -6,12 +6,13 @@ on: - master - main - release/* + - landchecks/* tags: - ciflow/trunk/* workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: @@ -61,28 +62,22 @@ jobs: { include: [ { config: "default", shard: 1, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "nogpu_NO_AVX", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "functorch", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, { config: "distributed", shard: 1, num_shards: 2, runner: "linux.8xlarge.nvidia.gpu" }, { config: "distributed", shard: 2, num_shards: 2, runner: "linux.8xlarge.nvidia.gpu" }, ]} - libtorch-linux-xenial-cuda10_2-py3_7-gcc7-build: - name: libtorch-linux-xenial-cuda10.2-py3.7-gcc7 + libtorch-linux-bionic-cuda11_6-py3_7-gcc7-build: + name: libtorch-linux-bionic-cuda11.6-py3.7-gcc7 uses: ./.github/workflows/_linux-build.yml with: - build-environment: libtorch-linux-xenial-cuda10.2-py3.7-gcc7 - docker-image-name: pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 - build-generates-artifacts: false - - libtorch-linux-xenial-cuda11_3-py3_7-gcc7-build: - name: libtorch-linux-xenial-cuda11.3-py3.7-gcc7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: libtorch-linux-xenial-cuda11.3-py3.7-gcc7 - docker-image-name: pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 + build-environment: libtorch-linux-bionic-cuda11.6-py3.7-gcc7 + docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 build-generates-artifacts: false # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated @@ -138,92 +133,80 @@ jobs: IOS_DEV_TEAM_ID: ${{ secrets.IOS_DEV_TEAM_ID}} IOS_SIGN_KEY_2022: ${{ secrets.IOS_SIGN_KEY_2022 }} - ios-12-5-1-x86-64-coreml: - name: ios-12-5-1-x86-64-coreml - uses: ./.github/workflows/_ios-build-test.yml - with: - build-environment: ios-12-5-1-x86-64-coreml - ios-platform: SIMULATOR - ios-arch: x86_64 - secrets: - IOS_CERT_KEY_2022: ${{ secrets.IOS_CERT_KEY_2022 }} - IOS_CERT_SECRET: ${{ secrets.IOS_CERT_SECRET}} - IOS_DEV_TEAM_ID: ${{ secrets.IOS_DEV_TEAM_ID}} - IOS_SIGN_KEY_2022: ${{ secrets.IOS_SIGN_KEY_2022 }} - - macos-11-py3-x86-64-build: - name: macos-11-py3-x86-64 + macos-12-py3-x86-64-build: + name: macos-12-py3-x86-64 uses: ./.github/workflows/_mac-build.yml with: - build-environment: macos-11-py3-x86-64 + build-environment: macos-12-py3-x86-64 xcode-version: "13.3.1" - runner-type: macos-12 + runner-type: macos-12-xl build-generates-artifacts: true secrets: MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} - macos-11-py3-x86-64-test: - name: macos-11-py3-x86-64 + macos-12-py3-x86-64-test: + name: macos-12-py3-x86-64 uses: ./.github/workflows/_mac-test.yml - needs: macos-11-py3-x86-64-build + needs: macos-12-py3-x86-64-build with: - build-environment: macos-11-py3-x86-64 + build-environment: macos-12-py3-x86-64 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 2, runner: "macos-12" }, { config: "default", shard: 2, num_shards: 2, runner: "macos-12" }, + { config: "functorch", shard: 1, num_shards: 1, runner: "macos-12" }, ]} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} - macos-10-15-py3-lite-interpreter-x86-64: - name: macos-10-15-py3-lite-interpreter-x86-64 + macos-12-py3-x86-64-lite-interpreter-build-test: + name: macos-12-py3-x86-64-lite-interpreter uses: ./.github/workflows/_mac-build.yml with: - build-environment: macos-10-15-py3-lite-interpreter-x86-64 - xcode-version: "12" - runner-type: macos-10.15 + build-environment: macos-12-py3-lite-interpreter-x86-64 + xcode-version: "13.3.1" + runner-type: macos-12-xl build-generates-artifacts: false secrets: MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} - macos-10-15-py3-arm64: - name: macos-10-15-py3-arm64 + macos-12-py3-arm64-build: + name: macos-12-py3-arm64 uses: ./.github/workflows/_mac-build.yml with: - build-environment: macos-10-15-py3-arm64 + build-environment: macos-12-py3-arm64 xcode-version: "13.3.1" - runner-type: macos-12 + runner-type: macos-12-xl build-generates-artifacts: true secrets: MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} - macos-12-3-py38-arm64-test: - name: macos-12.3-py3.8-arm64-test + macos-12-py3-arm64-mps-test: + name: macos-12-py3-arm64 uses: ./.github/workflows/_mac-test-arm64.yml - needs: macos-10-15-py3-arm64 + needs: macos-12-py3-arm64-build with: - build-environment: macos-10-15-py3-arm64 + build-environment: macos-12-py3-arm64 - # please ensure that this and its corresponding job in pull.yml are in sync - win-vs2019-cuda11_3-py3-build: - name: win-vs2019-cuda11.3-py3 + win-vs2019-cuda11_6-py3-build: + name: win-vs2019-cuda11.6-py3 uses: ./.github/workflows/_win-build.yml with: - build-environment: win-vs2019-cuda11.3-py3 - cuda-version: "11.3" + build-environment: win-vs2019-cuda11.6-py3 + cuda-version: "11.6" + sync-tag: win-cuda-build - win-vs2019-cuda11_3-py3-test: - name: win-vs2019-cuda11.3-py3 + win-vs2019-cuda11_6-py3-test: + name: win-vs2019-cuda11.6-py3 uses: ./.github/workflows/_win-test.yml - needs: win-vs2019-cuda11_3-py3-build + needs: win-vs2019-cuda11_6-py3-build with: - build-environment: win-vs2019-cuda11.3-py3 - cuda-version: "11.3" + build-environment: win-vs2019-cuda11.6-py3 + cuda-version: "11.6" test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, @@ -231,23 +214,25 @@ jobs: { config: "default", shard: 3, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "functorch", shard: 1, num_shards: 1, runner: "windows.8xlarge.nvidia.gpu" }, { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, ]} - linux-bionic-rocm5_1-py3_7-build: - name: linux-bionic-rocm5.1-py3.7 + linux-focal-rocm5_2-py3_7-build: + name: linux-focal-rocm5.2-py3.7 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image-name: pytorch-linux-bionic-rocm5.1-py3.7 + build-environment: linux-focal-rocm5.2-py3.7 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 + sync-tag: rocm-build - linux-bionic-rocm5_1-py3_7-test: - name: linux-bionic-rocm5.1-py3.7 + linux-focal-rocm5_2-py3_7-test: + name: linux-focal-rocm5.2-py3.7 uses: ./.github/workflows/_rocm-test.yml - needs: linux-bionic-rocm5_1-py3_7-build + needs: linux-focal-rocm5_2-py3_7-build with: - build-environment: linux-bionic-rocm5.1-py3.7 - docker-image: ${{ needs.linux-bionic-rocm5_1-py3_7-build.outputs.docker-image }} + build-environment: linux-focal-rocm5.2-py3.7 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-build.outputs.docker-image }} test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, @@ -257,6 +242,6 @@ jobs: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} - buck-build-test: - name: buck-build-test - uses: ./.github/workflows/_buck-build-test.yml + android-emulator-build-test: + name: android-emulator-build-test + uses: ./.github/workflows/_run_android_tests.yml diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 99cea22e3f271..8db7b0c97c5c9 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -31,6 +31,7 @@ jobs: GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} FORCE: ${{ github.event.client_payload.force}} ON_GREEN: ${{ github.event.client_payload.on_green}} + LAND_CHECKS: ${{ github.event.client_payload.land_checks }} COMMENT_ID: ${{ github.event.client_payload.comment_id }} run: | set -ex @@ -42,6 +43,8 @@ jobs: fi elif [ -n "${ON_GREEN}" ]; then python3 .github/scripts/trymerge.py --on-green "${PR_NUM}" + elif [ -n "${LAND_CHECKS}" ]; then + python3 .github/scripts/trymerge.py --land-checks "${PR_NUM}" elif [ -n "${COMMENT_ID}" ]; then python3 .github/scripts/trymerge.py --comment-id "${COMMENT_ID}" "${PR_NUM}" else diff --git a/.github/workflows/update-commit-hashes.yml b/.github/workflows/update-commit-hashes.yml index 3c3c61ca58259..6c72492d93ac2 100644 --- a/.github/workflows/update-commit-hashes.yml +++ b/.github/workflows/update-commit-hashes.yml @@ -2,10 +2,10 @@ name: update-commit-hashes on: schedule: - # Every day at 12:37am - # Choose a random time near midnight because it may be delayed if there are high loads + # Every day at 7:37am UTC = 12:27am PST + # Choose a random time near midnight PST because it may be delayed if there are high loads # See https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule - - cron: 37 0 * * * + - cron: 37 7 * * * workflow_dispatch: jobs: diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 2a7292719634d..872d8f5c14285 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -35,11 +35,12 @@ jobs: env: ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} run: | - output=$(python3 .github/scripts/print_latest_commits.py) + output=$(python3 .github/scripts/fetch_latest_green_commit.py) echo "::set-output name=latest_viable_sha::$output" id: get-latest-commit - name: Push SHA to viable/strict branch + if: steps.get-latest-commit.outputs.latest_viable_sha != 'None' env: GITHUB_TOKEN: ${{ secrets.MERGEBOT_TOKEN }} run: | diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index f69572bd1dbe2..b649aac2c7c50 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -7,8 +7,26 @@ on: - completed jobs: + # the conclusion field in the github context is sometimes null + # solution adapted from https://github.com/community/community/discussions/21090#discussioncomment-3226271 + get_workflow_conclusion: + runs-on: ubuntu-latest + outputs: + conclusion: ${{ fromJson(steps.get_conclusion.outputs.data).conclusion }} + steps: + - name: Get workflow run conclusion + uses: octokit/request-action@v2.1.0 + id: get_conclusion + with: + route: GET /repos/${{ github.repository }}/actions/runs/${{ github.event.workflow_run.id }}/attempts/${{ github.event.workflow_run.run_attempt }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + upload-test-stats: - if: github.event.workflow_run.conclusion == 'success' || github.event.workflow_run.conclusion == 'failure' + needs: get_workflow_conclusion + if: + github.event.workflow_run.conclusion == 'success' || github.event.workflow_run.conclusion == 'failure' || + needs.get_workflow_conclusion.outputs.conclusion == 'success' || needs.get_workflow_conclusion.outputs.conclusion == 'failure' runs-on: [self-hosted, linux.2xlarge] name: Upload test stats for ${{ github.event.workflow_run.id }}, attempt ${{ github.event.workflow_run.run_attempt }} @@ -34,9 +52,10 @@ jobs: WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} WORKFLOW_URL: ${{ github.event.workflow_run.html_url }} + HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} run: | echo "${WORKFLOW_URL}" - python3 -m tools.stats.upload_test_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" + python3 -m tools.stats.upload_test_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --head-branch "${HEAD_BRANCH}" python3 -m tools.stats.upload_sccache_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" check-api-rate: diff --git a/.gitignore b/.gitignore index 814b14f60befb..88d472b456f4a 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ docs/cpp/source/html/ docs/cpp/source/latex/ docs/source/generated/ log +usage_log.txt test-reports/ test/.coverage test/.hypothesis/ diff --git a/.gitmodules b/.gitmodules index 7ceeff397076e..538967d317641 100644 --- a/.gitmodules +++ b/.gitmodules @@ -139,6 +139,12 @@ [submodule "third_party/pocketfft"] path = third_party/pocketfft url = https://github.com/mreineck/pocketfft +[submodule "third_party/ittapi"] + path = third_party/ittapi + url = https://github.com/intel/ittapi.git [submodule "third_party/flatbuffers"] path = third_party/flatbuffers url = https://github.com/google/flatbuffers.git +[submodule "third_party/nlohmann"] + path = third_party/nlohmann + url = https://github.com/nlohmann/json.git diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index e605f19b5a8d5..6016911941b5d 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -172,7 +172,7 @@ fi # ONNX tests # ############## if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then - pip install -q --user "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" + pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.11.0 # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21. # We don't actually need it for our tests, but it's imported if it's present, so uninstall. diff --git a/.jenkins/pytorch/build-asan.sh b/.jenkins/pytorch/build-asan.sh index 963a062fa1e65..d46f4bd2a6854 100755 --- a/.jenkins/pytorch/build-asan.sh +++ b/.jenkins/pytorch/build-asan.sh @@ -6,10 +6,14 @@ # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" echo "Clang version:" clang --version +python tools/stats/export_test_times.py + # detect_leaks=0: Python is very leaky, so we need suppress it # symbolize=1: Gives us much better errors when things go wrong export ASAN_OPTIONS=detect_leaks=0:detect_stack_use_after_return=1:symbolize=1:detect_odr_violation=0 diff --git a/.jenkins/pytorch/build-mobile.sh b/.jenkins/pytorch/build-mobile.sh index 85d8280597e2c..1f253ff58c03d 100755 --- a/.jenkins/pytorch/build-mobile.sh +++ b/.jenkins/pytorch/build-mobile.sh @@ -8,6 +8,8 @@ set -eu -o pipefail # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" # Install torch & torchvision - used to download & trace test model. # Ideally we should use the libtorch built on the PR so that backward diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index f4513e2918c75..d442a4ebd41c2 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -8,6 +8,8 @@ set -ex # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" if [[ "$BUILD_ENVIRONMENT" == *-clang7-asan* ]]; then exec "$(dirname "${BASH_SOURCE[0]}")/build-asan.sh" "$@" @@ -158,11 +160,6 @@ fi # Target only our CI GPU machine's CUDA arch to speed up the build export TORCH_CUDA_ARCH_LIST="5.2" -# Add sm_75 support for the Linux CUDA 11.1 cuDNN 8 CircleCI build -if [[ "$BUILD_ENVIRONMENT" == *xenial-cuda11.1*build ]]; then - export TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST";7.5" -fi - if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then export CC=clang export CXX=clang++ @@ -302,7 +299,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then # export test times so that potential sharded tests that'll branch off this build will use consistent data # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build - python test/run_test.py --export-past-test-times + python tools/stats/export_test_times.py fi print_sccache_stats diff --git a/.jenkins/pytorch/common-build.sh b/.jenkins/pytorch/common-build.sh new file mode 100644 index 0000000000000..4f21d9c678748 --- /dev/null +++ b/.jenkins/pytorch/common-build.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Required environment variables: +# $BUILD_ENVIRONMENT (should be set by your Docker image) + +if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then + # Save the absolute path in case later we chdir (as occurs in the gpu perf test) + script_dir="$( cd "$(dirname "${BASH_SOURCE[0]}")" || exit ; pwd -P )" + + if which sccache > /dev/null; then + # Save sccache logs to file + sccache --stop-server > /dev/null 2>&1 || true + rm -f ~/sccache_error.log || true + if [[ -n "${SKIP_SCCACHE_INITIALIZATION:-}" ]]; then + # sccache --start-server seems to hang forever on self hosted runners for GHA + # so let's just go ahead and skip the --start-server altogether since it seems + # as though sccache still gets used even when the sscache server isn't started + # explicitly + echo "Skipping sccache server initialization, setting environment variables" + export SCCACHE_IDLE_TIMEOUT=1200 + export SCCACHE_ERROR_LOG=~/sccache_error.log + export RUST_LOG=sccache::server=error + elif [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then + SCCACHE_ERROR_LOG=~/sccache_error.log SCCACHE_IDLE_TIMEOUT=0 sccache --start-server + else + # increasing SCCACHE_IDLE_TIMEOUT so that extension_backend_test.cpp can build after this PR: + # https://github.com/pytorch/pytorch/pull/16645 + SCCACHE_ERROR_LOG=~/sccache_error.log SCCACHE_IDLE_TIMEOUT=1200 RUST_LOG=sccache::server=error sccache --start-server + fi + + # Report sccache stats for easier debugging + sccache --zero-stats + function sccache_epilogue() { + echo "::group::Sccache Compilation Log" + echo '=================== sccache compilation log ===================' + python "$script_dir/print_sccache_log.py" ~/sccache_error.log 2>/dev/null + echo '=========== If your build fails, please take a look at the log above for possible reasons ===========' + sccache --show-stats + sccache --stop-server || true + echo "::endgroup::" + } + + trap_add sccache_epilogue EXIT + fi + + if which ccache > /dev/null; then + # Report ccache stats for easier debugging + ccache --zero-stats + ccache --show-stats + function ccache_epilogue() { + ccache --show-stats + } + trap_add ccache_epilogue EXIT + fi +fi diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index c7d4139bc3287..c71acc7e66cfe 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -5,9 +5,6 @@ source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" set -ex -# Save the SCRIPT_DIR absolute path in case later we chdir (as occurs in the gpu perf test) -SCRIPT_DIR="$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )" - # Required environment variables: # $BUILD_ENVIRONMENT (should be set by your Docker image) @@ -22,63 +19,6 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then export HSA_FORCE_FINE_GRAIN_PCIE=1 fi -# This token is used by a parser on Jenkins logs for determining -# if a failure is a legitimate problem, or a problem with the build -# system; to find out more, grep for this string in ossci-job-dsl. -echo "ENTERED_USER_LAND" - -trap_add cleanup EXIT - -if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then - if which sccache > /dev/null; then - # Save sccache logs to file - sccache --stop-server > /dev/null 2>&1 || true - rm -f ~/sccache_error.log || true - if [[ -n "${SKIP_SCCACHE_INITIALIZATION:-}" ]]; then - # sccache --start-server seems to hang forever on self hosted runners for GHA - # so let's just go ahead and skip the --start-server altogether since it seems - # as though sccache still gets used even when the sscache server isn't started - # explicitly - echo "Skipping sccache server initialization, setting environment variables" - export SCCACHE_IDLE_TIMEOUT=1200 - export SCCACHE_ERROR_LOG=~/sccache_error.log - export RUST_LOG=sccache::server=error - elif [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then - SCCACHE_ERROR_LOG=~/sccache_error.log SCCACHE_IDLE_TIMEOUT=0 sccache --start-server - else - # increasing SCCACHE_IDLE_TIMEOUT so that extension_backend_test.cpp can build after this PR: - # https://github.com/pytorch/pytorch/pull/16645 - SCCACHE_ERROR_LOG=~/sccache_error.log SCCACHE_IDLE_TIMEOUT=1200 RUST_LOG=sccache::server=error sccache --start-server - fi - - # Report sccache stats for easier debugging - sccache --zero-stats - function sccache_epilogue() { - echo "::group::Sccache Compilation Log" - echo '=================== sccache compilation log ===================' - python "$SCRIPT_DIR/print_sccache_log.py" ~/sccache_error.log 2>/dev/null - echo '=========== If your build fails, please take a look at the log above for possible reasons ===========' - sccache --show-stats - sccache --stop-server || true - echo "::endgroup::" - } - - if [[ "${JOB_BASE_NAME}" == *-build ]]; then - trap_add sccache_epilogue EXIT - fi - fi - - if which ccache > /dev/null; then - # Report ccache stats for easier debugging - ccache --zero-stats - ccache --show-stats - function ccache_epilogue() { - ccache --show-stats - } - trap_add ccache_epilogue EXIT - fi -fi - # TODO: Renable libtorch testing for MacOS, see https://github.com/pytorch/pytorch/issues/62598 # shellcheck disable=SC2034 BUILD_TEST_LIBTORCH=0 diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index 5acd7c49f189a..0584ddab9e2a0 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -35,19 +35,6 @@ trap_add() { # inherit them unless the trace attribute is set declare -f -t trap_add -# NB: define this function before set -x, so that we don't -# pollute the log with a premature EXITED_USER_LAND ;) -function cleanup { - # Note that if you've exited user land, then CI will conclude that - # any failure is the CI's fault. So we MUST only output this - # string - retcode=$? - set +x - if [ $retcode -eq 0 ]; then - echo "EXITED_USER_LAND" - fi -} - function assert_git_not_dirty() { # TODO: we should add an option to `build_amd.py` that reverts the repo to # an unmodified state. @@ -111,7 +98,7 @@ function get_pinned_commit() { function install_torchvision() { local commit commit=$(get_pinned_commit vision) - pip_install --user "git+https://github.com/pytorch/vision.git@${commit}" + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" } function checkout_install_torchvision() { @@ -152,13 +139,22 @@ function checkout_install_torchdynamo() { popd } +function install_functorch() { + pushd functorch + time python setup.py develop + popd +} + +function test_functorch() { + python test/run_test.py --functorch --verbose +} + function print_sccache_stats() { echo 'PyTorch Build Statistics' sccache --show-stats if [[ -n "${OUR_GITHUB_JOB_ID}" ]]; then - sccache --show-stats \ - | python -m tools.stats.sccache_stats_to_json \ + sccache --show-stats --stats-format json | jq .stats \ > "sccache-stats-${BUILD_ENVIRONMENT}-${OUR_GITHUB_JOB_ID}.json" else echo "env var OUR_GITHUB_JOB_ID not set, will not write sccache stats to json" diff --git a/.jenkins/pytorch/docs-test.sh b/.jenkins/pytorch/docs-test.sh index 374dae28e325f..d57d78bc7f7b0 100755 --- a/.jenkins/pytorch/docs-test.sh +++ b/.jenkins/pytorch/docs-test.sh @@ -5,6 +5,6 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch docs" -cd "${SCRIPT_DIR}/../../docs" +cd docs pip_install -r requirements.txt make doctest diff --git a/.jenkins/pytorch/macos-build.sh b/.jenkins/pytorch/macos-build.sh index 7c1854346efdf..db33e2dedf95b 100755 --- a/.jenkins/pytorch/macos-build.sh +++ b/.jenkins/pytorch/macos-build.sh @@ -3,6 +3,8 @@ # shellcheck disable=SC2034 # shellcheck source=./macos-common.sh source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" # Build PyTorch if [ -z "${CI}" ]; then @@ -71,4 +73,6 @@ if which sccache > /dev/null; then print_sccache_stats fi +python tools/stats/export_test_times.py + assert_git_not_dirty diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index cd2f3694c9a3e..1b15fab1ed205 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -10,7 +10,9 @@ pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" p # TODO move this to docker # Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \ - pytest + pytest \ + pytest-xdist \ + pytest-rerunfailures if [ -z "${CI}" ]; then rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch* @@ -167,7 +169,10 @@ test_dynamo() { popd } -if [[ $NUM_TEST_SHARDS -gt 1 ]]; then +if [[ "${TEST_CONFIG}" == *functorch* ]]; then + install_functorch + test_functorch +elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then test_python_shard "${SHARD_NUMBER}" if [[ "${SHARD_NUMBER}" == 1 ]]; then test_libtorch diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index bd57f67f4b057..b476d25250791 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -15,11 +15,6 @@ BUILD_DIR="build" BUILD_RENAMED_DIR="build_renamed" BUILD_BIN_DIR="$BUILD_DIR"/bin -# GHA has test config defined for the test job, so we need to add them. -if [[ -n "${TEST_CONFIG}" ]]; then - BUILD_ENVIRONMENT="${BUILD_ENVIRONMENT}-${TEST_CONFIG}" -fi - # Get fully qualified path using realpath if [[ "$BUILD_ENVIRONMENT" != *bazel* ]]; then CUSTOM_TEST_ARTIFACT_BUILD_DIR=$(realpath "${CUSTOM_TEST_ARTIFACT_BUILD_DIR:-"build/custom_test_artifacts"}") @@ -38,16 +33,16 @@ export LANG=C.UTF-8 PR_NUMBER=${PR_NUMBER:-${CIRCLE_PR_NUMBER:-}} -if [[ $TEST_CONFIG == 'default' ]]; then +if [[ "$TEST_CONFIG" == 'default' ]]; then export CUDA_VISIBLE_DEVICES=0 export HIP_VISIBLE_DEVICES=0 fi -if [[ $TEST_CONFIG == 'distributed' ]] && [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then +if [[ "$TEST_CONFIG" == 'distributed' ]] && [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then export HIP_VISIBLE_DEVICES=0,1 fi -if [[ "$BUILD_ENVIRONMENT" == *-slow-* || $TEST_CONFIG == 'slow' ]]; then +if [[ "$TEST_CONFIG" == 'slow' ]]; then export PYTORCH_TEST_WITH_SLOW=1 export PYTORCH_TEST_SKIP_FAST=1 fi @@ -67,10 +62,15 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then export BUILD_SPLIT_CUDA=ON fi -if [[ "$BUILD_ENVIRONMENT" == *crossref* ]]; then +if [[ "$TEST_CONFIG" == *crossref* ]]; then export PYTORCH_TEST_WITH_CROSSREF=1 fi +if [[ "$TEST_CONFIG" == *dynamo* ]]; then + export PYTORCH_TEST_WITH_DYNAMO=1 +fi + +# TODO: this condition is never true, need to fix this. if [[ -n "$PR_NUMBER" ]] && [[ -z "$CI_MASTER" || "$CI_MASTER" == "false" ]]; then # skip expensive checks when on PR and CI_MASTER flag is not set export PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK=1 @@ -82,12 +82,6 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # Print GPU info rocminfo rocminfo | grep -E 'Name:.*\sgfx|Marketing' - - # Manually set NUM_TEST_SHARDS since Jenkins doesn't do it - # TODO: Can remove this once ROCm migration from Jenkins to GHA is complete. - if [[ -z "${GITHUB_ACTIONS}" ]]; then - export NUM_TEST_SHARDS=2 - fi fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then @@ -153,11 +147,9 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)") fi -if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then - export ATEN_CPU_CAPABILITY=default -elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then +if [[ $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then export ATEN_CPU_CAPABILITY=default -elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX512' ]]; then +elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi @@ -180,6 +172,33 @@ test_python() { assert_git_not_dirty } + +test_dynamo_shard() { + if [[ -z "$NUM_TEST_SHARDS" ]]; then + echo "NUM_TEST_SHARDS must be defined to run a Python test shard" + exit 1 + fi + time python test/run_test.py \ + --exclude-jit-executor \ + --exclude-distributed-tests \ + --exclude \ + test_autograd \ + test_proxy_tensor \ + test_quantization \ + test_public_bindings \ + test_dataloader \ + test_reductions \ + test_namedtensor \ + test_namedtuple_return_api \ + test_profiler \ + test_profiler_tree \ + test_overrides \ + test_python_dispatch \ + --shard "$1" "$NUM_TEST_SHARDS" \ + --verbose + assert_git_not_dirty +} + test_python_gloo_with_tls() { source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh" assert_git_not_dirty @@ -277,7 +296,7 @@ test_libtorch() { fi # Run Lazy Tensor cpp tests - if [[ "$BUILD_ENVIRONMENT" == *cuda* && "$BUILD_ENVIRONMENT" != *nogpu* ]]; then + if [[ "$BUILD_ENVIRONMENT" == *cuda* && "$TEST_CONFIG" != *nogpu* ]]; then LTC_TS_CUDA=1 "$TORCH_BIN_DIR"/test_lazy --gtest_output=xml:$TEST_REPORTS_DIR/test_lazy.xml else "$TORCH_BIN_DIR"/test_lazy --gtest_output=xml:$TEST_REPORTS_DIR/test_lazy.xml @@ -289,6 +308,8 @@ test_libtorch() { # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy. OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml + + # TODO: this condition is never (BUILD_ENVIRONMENT doesn't start with pytorch-), need to fix this. if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3* ]]; then if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* ]]; then # TODO: Consider to run static_runtime_test from $TORCH_BIN_DIR (may need modify build script) @@ -310,12 +331,12 @@ test_vulkan() { if [[ "$BUILD_ENVIRONMENT" == *vulkan* ]]; then ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_TEST_DIR" ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_TEST_DIR" - export VK_ICD_FILENAMES=/var/lib/jenkins/swiftshader/build/Linux/vk_swiftshader_icd.json + export VK_ICD_FILENAMES=/var/lib/jenkins/swiftshader/swiftshader/build/Linux/vk_swiftshader_icd.json # NB: the ending test_vulkan must match the current function name for the current # test reporting process (in print_test_stats.py) to function as expected. TEST_REPORTS_DIR=test/test-reports/cpp-vulkan/test_vulkan mkdir -p $TEST_REPORTS_DIR - "$TORCH_TEST_DIR"/vulkan_api_test --gtest_output=xml:$TEST_REPORTS_DIR/vulkan_test.xml + LD_LIBRARY_PATH=/var/lib/jenkins/swiftshader/swiftshader/build/Linux/ "$TORCH_TEST_DIR"/vulkan_api_test --gtest_output=xml:$TEST_REPORTS_DIR/vulkan_test.xml fi } @@ -424,8 +445,12 @@ test_torch_function_benchmark() { } build_xla() { + # xla test needs sccache setup. + # shellcheck source=./common-build.sh + source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" + XLA_DIR=xla - USE_CACHE=0 + USE_CACHE=1 clone_pytorch_xla # shellcheck disable=SC1091 source "xla/.circleci/common.sh" @@ -438,6 +463,10 @@ build_xla() { } test_xla() { + # xla test needs sccache setup. + # shellcheck source=./common-build.sh + source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" + clone_pytorch_xla # shellcheck disable=SC1091 source "./xla/.circleci/common.sh" @@ -447,32 +476,49 @@ test_xla() { } # Do NOT run this test before any other tests, like test_python_shard, etc. -# Because this function uninstalls the torch built from branch, and install -# nightly version. +# Because this function uninstalls the torch built from branch and installs +# the torch built on its base commit. test_forward_backward_compatibility() { set -x + REPO_DIR=$(pwd) + if [[ "${BASE_SHA}" == "${SHA1}" ]]; then + echo "On trunk, we should compare schemas with torch built from the parent commit" + SHA_TO_COMPARE=$(git rev-parse "${SHA1}"^) + else + echo "On pull, we should compare schemas with torch built from the merge base" + SHA_TO_COMPARE=$(git merge-base "${SHA1}" "${BASE_SHA}") + fi + export SHA_TO_COMPARE + # create a dummy ts model at this version python test/create_dummy_torchscript_model.py /tmp/model_new.pt - REPO_DIR=$(pwd) - pushd test/forward_backward_compatibility python -m venv venv # shellcheck disable=SC1091 . venv/bin/activate - # install the nightly before the base commit -- fallback to most recent nightly in case of error - VERSION=$(cat "${REPO_DIR}/version.txt") - DATE_OF_BASE=$(git show -s --format=%cd --date=short "${BASE_SHA}") - pip_install --pre "torch<${VERSION::-2}.dev${DATE_OF_BASE//-/}" -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html || \ - pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + + # build torch at the base commit to generate a base function schema for comparison + git reset --hard "${SHA_TO_COMPARE}" + echo "::group::Installing Torch From Base Commit" + pip install -r requirements.txt + # shellcheck source=./common-build.sh + source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" + python setup.py bdist_wheel --bdist-dir="base_bdist_tmp" --dist-dir="base_dist" + python -mpip install base_dist/*.whl + echo "::endgroup::" + + pushd test/forward_backward_compatibility pip show torch python dump_all_function_schemas.py --filename nightly_schemas.txt - # FC: verify newmodel can be load with old code. + + git reset --hard "${SHA1}" + # FC: verify new model can be load with old code. if ! python ../load_torchscript_model.py /tmp/model_new.pt; then echo "FC check failed: new model cannot be load in old code" return 1 fi python ../create_dummy_torchscript_model.py /tmp/model_old.pt deactivate - rm -r venv + rm -r "${REPO_DIR}/venv" "${REPO_DIR}/base_dist" pip show torch python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt # BC: verify old model can be load with new code @@ -488,6 +534,10 @@ test_forward_backward_compatibility() { test_bazel() { set -e + # bazel test needs sccache setup. + # shellcheck source=./common-build.sh + source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" + get_bazel # Test //c10/... without Google flags and logging libraries. The @@ -500,7 +550,7 @@ test_bazel() { } test_benchmarks() { - if [[ "$BUILD_ENVIRONMENT" == *cuda* && "$BUILD_ENVIRONMENT" != *nogpu* && $TEST_CONFIG != *nogpu* ]]; then + if [[ "$BUILD_ENVIRONMENT" == *cuda* && $TEST_CONFIG != *nogpu* ]]; then pip_install --user "pytest-benchmark==3.2.3" pip_install --user "requests" BENCHMARK_DATA="benchmarks/.data" @@ -564,10 +614,10 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") fi -if [[ "${BUILD_ENVIRONMENT}" == *deploy* ]]; then +if [[ "${TEST_CONFIG}" == *deploy* ]]; then install_torchdynamo test_torch_deploy -elif [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then +elif [[ "${TEST_CONFIG}" == *backward* ]]; then test_forward_backward_compatibility # Do NOT add tests after bc check tests, see its comment. elif [[ "${TEST_CONFIG}" == *xla* ]]; then @@ -575,18 +625,29 @@ elif [[ "${TEST_CONFIG}" == *xla* ]]; then install_torchdynamo build_xla test_xla -elif [[ "${BUILD_ENVIRONMENT}" == *jit_legacy-test || "${JOB_BASE_NAME}" == *jit_legacy-test || $TEST_CONFIG == 'jit_legacy' ]]; then +elif [[ "$TEST_CONFIG" == 'jit_legacy' ]]; then test_python_legacy_jit elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then # TODO: run some C++ tests echo "no-op at the moment" -elif [[ "${BUILD_ENVIRONMENT}" == *distributed* || "${JOB_BASE_NAME}" == *distributed* ]]; then +elif [[ "$TEST_CONFIG" == distributed ]]; then install_torchdynamo test_distributed # Only run RPC C++ tests on the first shard if [[ "${SHARD_NUMBER}" == 1 ]]; then test_rpc fi +elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then + test_without_numpy + install_torchvision + install_torchdynamo + test_dynamo_shard 1 + test_aten +elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then + install_torchvision + checkout_install_torchdynamo + test_dynamo_shard 2 + test_dynamo elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then test_without_numpy install_torchvision @@ -602,20 +663,21 @@ elif [[ "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then test_custom_script_ops test_custom_backend test_torch_function_benchmark - test_dynamo elif [[ "${SHARD_NUMBER}" -gt 2 ]]; then # Handle arbitrary number of shards install_torchdynamo test_python_shard "$SHARD_NUMBER" elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then - # TODO: re-enable vulkan test - echo "no-op at the moment" + test_vulkan elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then test_bazel elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then test_libtorch elif [[ "${TEST_CONFIG}" = docs_test ]]; then test_docs_test +elif [[ "${TEST_CONFIG}" == *functorch* ]]; then + install_functorch + test_functorch else install_torchvision install_torchdynamo diff --git a/.jenkins/pytorch/win-build.sh b/.jenkins/pytorch/win-build.sh index 47011e4a107be..2de3022d40c2f 100755 --- a/.jenkins/pytorch/win-build.sh +++ b/.jenkins/pytorch/win-build.sh @@ -12,6 +12,8 @@ fi SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) # shellcheck source=./common.sh source "$SCRIPT_PARENT_DIR/common.sh" +# shellcheck source=./common-build.sh +source "$SCRIPT_PARENT_DIR/common-build.sh" IMAGE_COMMIT_ID=$(git rev-parse HEAD) export IMAGE_COMMIT_ID diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index 839a28fba408e..b954430734b02 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -147,12 +147,13 @@ python setup.py install --cmake && sccache --show-stats && ( if not errorlevel 0 exit /b :: export test times so that potential sharded tests that'll branch off this build will use consistent data - python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times.json + python tools/stats/export_test_times.py + copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%" :: Also save build/.ninja_log as an artifact copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\" ) ) -sccache --show-stats | python tools/stats/sccache_stats_to_json.py > sccache-stats-%BUILD_ENVIRONMENT%-%OUR_GITHUB_JOB_ID%.json +sccache --show-stats --stats-format json | jq .stats > sccache-stats-%BUILD_ENVIRONMENT%-%OUR_GITHUB_JOB_ID%.json sccache --stop-server diff --git a/.jenkins/pytorch/win-test-helpers/install_test_functorch.bat b/.jenkins/pytorch/win-test-helpers/install_test_functorch.bat new file mode 100644 index 0000000000000..7679bffbc70e7 --- /dev/null +++ b/.jenkins/pytorch/win-test-helpers/install_test_functorch.bat @@ -0,0 +1,32 @@ +call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat +:: exit the batch once there's an error +if not errorlevel 0 ( + echo "setup pytorch env failed" + echo %errorlevel% + exit /b +) + +pushd functorch +echo "Install functorch" +:: --no-deps because for some reason, on windows, `torch` isn't found in +:: `pip list` despite being installed. With just `python setup.py develop`, +:: setuptools explicitly checks for the existence of torch and can't find it. +python setup.py develop --no-deps +popd +if ERRORLEVEL 1 goto fail + +echo "Installing test dependencies" +pip install networkx +if errorlevel 1 exit /b + +echo "Test functorch" +pushd test +python run_test.py --functorch --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose +popd +if ERRORLEVEL 1 goto fail + +:eof +exit /b 0 + +:fail +exit /b 1 diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat index 6578486312452..54b954a0503f1 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat @@ -19,7 +19,7 @@ if "%INSTALL_FRESH_CONDA%"=="1" ( call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3 if "%INSTALL_FRESH_CONDA%"=="1" ( - call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 libuv + call conda install -y -q python=%PYTHON_VERSION% numpy"<1.23" cffi pyyaml boto3 libuv if errorlevel 1 exit /b if not errorlevel 0 exit /b call conda install -y -q -c conda-forge cmake=3.22.3 diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index c7f3e1b6a6140..90725b7666a33 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -36,7 +36,7 @@ popd ======= :: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 -pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest +pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures if errorlevel 1 exit /b if not errorlevel 0 exit /b diff --git a/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat b/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat index 146638a533041..c18151d65c023 100644 --- a/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat +++ b/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat @@ -1,7 +1,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat echo Copying over test times file -copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%" +copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%" pushd test diff --git a/.jenkins/pytorch/win-test-helpers/test_python_shard.bat b/.jenkins/pytorch/win-test-helpers/test_python_shard.bat index ccc615f67f31e..5313bc0078d5f 100644 --- a/.jenkins/pytorch/win-test-helpers/test_python_shard.bat +++ b/.jenkins/pytorch/win-test-helpers/test_python_shard.bat @@ -22,7 +22,7 @@ if "%SHARD_NUMBER%" == "1" ( ) echo Copying over test times file -copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%" +copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%" echo Run nn tests python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index 7455ebe8e7607..dc28521204878 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -64,12 +64,15 @@ run_tests() { fi done - "$SCRIPT_HELPERS_DIR"/test_python_shard.bat - if [[ ( -z "${JOB_BASE_NAME}" || "${JOB_BASE_NAME}" == *-test ) && $NUM_TEST_SHARDS -eq 1 ]]; then + if [[ "${TEST_CONFIG}" == *functorch* ]]; then + "$SCRIPT_HELPERS_DIR"/install_test_functorch.bat + elif [[ $NUM_TEST_SHARDS -eq 1 ]]; then + "$SCRIPT_HELPERS_DIR"/test_python_shard.bat "$SCRIPT_HELPERS_DIR"/test_custom_script_ops.bat "$SCRIPT_HELPERS_DIR"/test_custom_backend.bat "$SCRIPT_HELPERS_DIR"/test_libtorch.bat else + "$SCRIPT_HELPERS_DIR"/test_python_shard.bat if [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then "$SCRIPT_HELPERS_DIR"/test_libtorch.bat if [[ "${USE_CUDA}" == "1" ]]; then diff --git a/.lintrunner.toml b/.lintrunner.toml index b6613e8ba45b6..02b02d1aaf06e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -9,6 +9,9 @@ exclude_patterns = [ 'docs/caffe2/**', 'docs/cpp/src/**', 'docs/src/**', + 'functorch/docs/**', + 'functorch/examples/**', + 'functorch/notebooks/**', 'scripts/**', 'test/generated_type_hints_smoketest.py', 'third_party/**', @@ -16,6 +19,7 @@ exclude_patterns = [ 'torch/lib/**', 'venv/**', '**/*.pyi', + 'tools/test/test_selective_build.py', ] command = [ 'python3', @@ -42,6 +46,8 @@ init_command = [ code = 'CLANGFORMAT' include_patterns = [ 'aten/src/ATen/*.h', + 'aten/src/ATen/native/vulkan/**/*.h', + 'aten/src/ATen/native/vulkan/**/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', 'torch/csrc/**/*.h', @@ -50,6 +56,7 @@ include_patterns = [ 'test/cpp/**/*.cpp', ] exclude_patterns = [ + 'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h', 'c10/util/strong_type.h', 'torch/csrc/jit/serialization/mobile_bytecode_generated.h', ] @@ -118,7 +125,7 @@ init_command = [ '--dry-run={{DRYRUN}}', 'numpy==1.21.6', 'expecttest==0.1.3', - 'mypy==0.950', + 'mypy==0.960', 'types-requests==2.27.25', 'types-six==1.16.15', 'types-PyYAML==6.0.7', @@ -137,6 +144,7 @@ include_patterns = [ '.github/**/*.py', 'benchmarks/instruction_counts/**/*.py', 'tools/**/*.py', + 'torchgen/**/*.py', 'torch/utils/_pytree.py', 'torch/utils/benchmark/utils/common.py', 'torch/utils/benchmark/utils/timer.py', @@ -145,6 +153,10 @@ include_patterns = [ exclude_patterns = [ # (linbinyu) copied from internal repo 'tools/code_analyzer/gen_operators_yaml.py', + 'tools/gen_vulkan_spv.py', + 'tools/test/gen_operators_yaml_test.py', + 'tools/test/gen_oplist_test.py', + 'tools/test/test_selective_build.py', ] command = [ 'python3', @@ -212,7 +224,9 @@ command = [ [[linter]] code = 'TYPEIGNORE' include_patterns = ['**/*.py', '**/*.pyi'] -exclude_patterns = ['test/test_jit.py'] +exclude_patterns = [ + 'test/test_jit.py', +] command = [ 'python3', 'tools/linter/adapters/grep_linter.py', @@ -300,6 +314,7 @@ exclude_patterns = [ '**/contrib/**', '**/*.diff', 'third_party/**', + 'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h', 'test/cpp/jit/upgrader_models/*.ptl', 'test/cpp/jit/upgrader_models/*.ptl.ff', ] @@ -327,6 +342,7 @@ exclude_patterns = [ 'third_party/**', '**/.gitattributes', '**/.gitmodules', + 'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h', 'test/cpp/jit/upgrader_models/*.ptl', 'test/cpp/jit/upgrader_models/*.ptl.ff', '.lintrunner.toml', @@ -334,6 +350,7 @@ exclude_patterns = [ command = [ 'python3', 'tools/linter/adapters/grep_linter.py', + # @lint-ignore TXT2 '--pattern= ', '--linter-name=TABS', '--error-name=saw some tabs', @@ -354,6 +371,7 @@ include_patterns = [ ] exclude_patterns = [ 'aten/src/ATen/native/quantized/cpu/qnnpack/**', + 'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h', 'torch/csrc/jit/serialization/mobile_bytecode_generated.h', ] command = [ @@ -370,6 +388,73 @@ command = [ '@{{PATHSFILE}}' ] +[[linter]] +code = 'PYBIND11_INCLUDE' +include_patterns = [ + '**/*.cpp', + '**/*.h', +] +exclude_patterns = [ + 'torch/csrc/utils/pybind.h', + 'torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp', + 'caffe2/**/*', +] +command = [ + 'python3', + 'tools/linter/adapters/grep_linter.py', + '--pattern=#include ', + '--linter-name=PYBIND11_INCLUDE', + '--match-first-only', + '--error-name=direct include of pybind11', + # https://stackoverflow.com/a/33416489/23845 + # NB: this won't work if the pybind11 include is on the first line; + # but that's fine because it will just mean the lint will still fail + # after applying the change and you will have to fix it manually + '--replace-pattern=1,/(#include \n\1/', + """--error-description=\ + This #include directly includes pybind11 without also including \ + #include ; this means some important \ + specializations may not be included.\ + """, + '--', + '@{{PATHSFILE}}' +] + +[[linter]] +code = 'PYBIND11_SPECIALIZATION' +include_patterns = [ + '**/*.cpp', + '**/*.h', +] +exclude_patterns = [ + # The place for all orphan specializations + 'torch/csrc/utils/pybind.h', + # These specializations are non-orphan + 'torch/csrc/distributed/c10d/init.cpp', + 'torch/csrc/jit/python/pybind.h', + # These are safe to exclude as they do not have Python + 'c10/**/*', +] +command = [ + 'python3', + 'tools/linter/adapters/grep_linter.py', + '--pattern=PYBIND11_DECLARE_HOLDER_TYPE', + '--linter-name=PYBIND11_SPECIALIZATION', + '--error-name=pybind11 specialization in non-standard location', + """--error-description=\ + This pybind11 specialization (PYBIND11_DECLARE_HOLDER_TYPE) should \ + be placed in torch/csrc/utils/pybind.h so that it is guaranteed to be \ + included at any site that may potentially make use of it via py::cast. \ + If your specialization is in the same header file as the definition \ + of the holder type, you can ignore this lint by adding your header to \ + the exclude_patterns for this lint in .lintrunner.toml. For more \ + information see https://github.com/pybind/pybind11/issues/4099 \ + """, + '--', + '@{{PATHSFILE}}' +] + [[linter]] code = 'PYPIDEP' include_patterns = ['.github/**'] @@ -553,28 +638,104 @@ command = [ ] [[linter]] -code = 'BLACK' +code = 'CALL_ONCE' include_patterns = [ - 'torchgen/**/*.py', + 'c10/**', + 'aten/**', + 'torch/csrc/**', +] +exclude_patterns = [ + 'c10/util/CallOnce.h', +] +command = [ + 'python3', + 'tools/linter/adapters/grep_linter.py', + '--pattern=std::call_once', + '--linter-name=CALL_ONCE', + '--error-name=invalid call_once', + '--replace-pattern=s/std::call_once/c10::call_once/', + """--error-description=\ + Use of std::call_once is forbidden and should be replaced with c10::call_once\ + """, + '--', + '@{{PATHSFILE}}' +] + +[[linter]] +code = 'ONCE_FLAG' +include_patterns = [ + 'c10/**', + 'aten/**', + 'torch/csrc/**', +] +command = [ + 'python3', + 'tools/linter/adapters/grep_linter.py', + '--pattern=std::once_flag', + '--linter-name=ONCE_FLAG', + '--error-name=invalid once_flag', + '--replace-pattern=s/std::once_flag/c10::once_flag/', + """--error-description=\ + Use of std::once_flag is forbidden and should be replaced with c10::once_flag\ + """, + '--', + '@{{PATHSFILE}}' +] + +[[linter]] +code = 'WORKFLOWSYNC' +include_patterns = [ + '.github/workflows/pull.yml', + '.github/workflows/trunk.yml', + '.github/workflows/periodic.yml', +] +command = [ + 'python3', + 'tools/linter/adapters/workflow_consistency_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'PyYAML==6.0', +] + +# Black + usort +[[linter]] +code = 'UFMT' +include_patterns = [ + 'test/onnx/**/*.py', + 'test/test_dynamo_cudagraphs.py', 'tools/**/*.py', - 'torch/package/**/*.py', 'torch/onnx/**/*.py', - 'torch/_refs/**/*.py', + 'torch/package/**/*.py', + 'torch/_decomp/**/*.py', + 'torch/_lazy/**/*.py', + 'torch/_masked/**/*.py', 'torch/_prims/**/*.py', - 'torch/_meta_registrations.py', - 'test/onnx/**/*.py', + 'torch/_refs/**/*.py', + 'torch/_subclasses/**/*.py', + 'torch/_*.py', + 'torchgen/**/*.py', ] command = [ 'python3', - 'tools/linter/adapters/black_linter.py', + 'tools/linter/adapters/ufmt_linter.py', '--', '@{{PATHSFILE}}' ] +exclude_patterns = [ + 'tools/gen_vulkan_spv.py', + 'torch/__init__.py', # Skip this file to format because it's part of the public API +] init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - '--no-binary', 'black==22.3.0', + 'ufmt==1.3.3', + 'usort==1.0.2', ] is_formatter = true diff --git a/BUCK.oss b/BUCK.oss index 125868dc9ec5c..13d5518801a24 100644 --- a/BUCK.oss +++ b/BUCK.oss @@ -1,34 +1,17 @@ load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load( - ":build_variables.bzl", - "aten_cpu_source_list", - "aten_native_source_list", - "core_sources_common", - "jit_core_headers", - "jit_core_sources", - "libtorch_profiler_sources", -) -load( - ":pt_defs.oss.bzl", - "USED_PT_BACKENDS", - "build_aten_cpu", - "gen_aten_files", - "gen_aten_libtorch_files", - "get_aten_codegen_extra_params", - "get_pt_compiler_flags", - "get_pt_preprocessor_flags", + ":pt_ops.bzl", "pt_operator_library", - "get_pt_ops_deps", - "aten_ufunc_generated_all_cpu_sources", - "TEMPLATE_SOURCE_LIST", ) load(":buckbuild.bzl", "define_buck_targets", + "get_pt_operator_registry_dict", ) +# define shared buck targets define_buck_targets() +# define OSS only targets cxx_library( name = "pthreadpool", srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'], @@ -76,183 +59,6 @@ cxx_library( visibility = ['PUBLIC'], ) -cxx_library( - name = "common_core", - srcs = ['caffe2/core/common.cc'], - deps = [':caffe2_headers', '//c10:c10'], - exported_deps = [], - compiler_flags = ['-frtti', '-Os', '-Wno-unknown-pragmas', '-Wno-write-strings', '-Wno-unused-variable', '-Wno-unused-function', '-Wno-deprecated-declarations', '-Wno-shadow', '-Wno-global-constructors', '-Wno-missing-prototypes', '-std=gnu++17'], - preferred_linkage = "static", - header_namespace = "caffe2", - headers = [], - link_whole = True, - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - -build_aten_cpu( - name = "aten_cpu", - srcs = jit_core_sources + - aten_cpu_source_list + [ - # Generated - ":gen_aten[Functions.cpp]", - ":gen_aten[Operators_0.cpp]", - ":gen_aten[Operators_1.cpp]", - ":gen_aten[Operators_2.cpp]", - ":gen_aten[Operators_3.cpp]", - ":gen_aten[Operators_4.cpp]", - ":gen_aten[core/ATenOpList.cpp]", - ":gen_aten[core/TensorMethods.cpp]", - ] + [ - # Needed by ATen/native/EmbeddingBag.cpp - "caffe2/perfkernels/embedding_lookup_idx.cc", - ], -) - -fb_xplat_genrule( - name = "generate_aten_config", - srcs = [ - "aten/src/ATen/Config.h.in", - ], - cmd = " ".join([ - "sed", - "-e 's/@AT_MKLDNN_ENABLED@/ATEN_MKLDNN_ENABLED_FBXPLAT/g'", - "-e 's/@AT_MKL_ENABLED@/ATEN_MKL_ENABLED_FBXPLAT/g'", - "-e 's/@AT_MKL_SEQUENTIAL@/ATEN_MKL_SEQUENTIAL_FBXPLAT/g'", - "-e 's/@AT_FFTW_ENABLED@/0/g'", - "-e 's/@AT_POCKETFFT_ENABLED@/0/g'", - "-e 's/@AT_NNPACK_ENABLED@/ATEN_NNPACK_ENABLED_FBXPLAT/g'", - "-e 's/@CAFFE2_STATIC_LINK_CUDA_INT@/CAFFE2_STATIC_LINK_CUDA_FBXPLAT/g'", - "-e 's/@AT_BUILD_WITH_BLAS@/USE_BLAS_FBXPLAT/g'", - "-e 's/@AT_PARALLEL_OPENMP@/AT_PARALLEL_OPENMP_FBXPLAT/g'", - "-e 's/@AT_PARALLEL_NATIVE@/AT_PARALLEL_NATIVE_FBXPLAT/g'", - "-e 's/@AT_PARALLEL_NATIVE_TBB@/AT_PARALLEL_NATIVE_TBB_FBXPLAT/g'", - "-e 's/@AT_BUILD_WITH_LAPACK@/USE_LAPACK_FBXPLAT/g'", - "-e 's/@AT_BLAS_F2C@/AT_BLAS_F2C_FBXPLAT/g'", - "-e 's/@AT_BLAS_USE_CBLAS_DOT@/AT_BLAS_USE_CBLAS_DOT_FBXPLAT/g'", - "aten/src/ATen/Config.h.in > $OUT/Config.h", - ]), - outs = { - "Config.h": ["Config.h"], - }, - default_outs = ["."], -) - -gen_aten_files( - name = "gen_aten", - extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS), - visibility = ["PUBLIC"], -) - -ATEN_EXPORTED_HEADERS = { - "CPUFunctions.h": ":gen_aten[CPUFunctions.h]", - "CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]", - "CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]", - "CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]", - "CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]", - "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]", - "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]", - "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]", - "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]", - "Functions.h": ":gen_aten[Functions.h]", - "MethodOperators.h": ":gen_aten[MethodOperators.h]", - "NativeFunctions.h": ":gen_aten[NativeFunctions.h]", - "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", - "Operators.h": ":gen_aten[Operators.h]", - "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", - "core/TensorBody.h": ":gen_aten[core/TensorBody.h]", - "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", - "core/enum_tag.h": ":gen_aten[core/enum_tag.h]", -} - -cxx_library( - name = "generated_aten_headers_cpu", - header_namespace = "ATen", - exported_headers = ATEN_EXPORTED_HEADERS, -) -gen_aten_libtorch_files(name = "gen_aten_libtorch") - -cxx_library( - name = "torch_mobile_observer", - srcs = [ - "torch/csrc/jit/mobile/observer.cpp", - #"torch/fb/observers/MobileObserverUtil.cpp", - ], - header_namespace = "", - exported_headers = subdir_glob( - [ - ("", "torch/csrc/jit/mobile/observer.h"), - #("", "torch/fb/observers/ObserverUtil.h"), - #("", "torch/fb/observers/MobileObserverUtil.h"), - ], - ), - visibility = ["PUBLIC"], - deps = [ - "//c10:c10", - ], -) - -python_library( - name = "aten_code_template", - srcs = subdir_glob([ - ("aten", "src/ATen/code_template.py"), - ]), - base_module = "", - visibility = ["PUBLIC"], -) - - -cxx_library( - name = "torch_common", - srcs = core_sources_common, - compiler_flags = get_pt_compiler_flags(), - exported_preprocessor_flags = get_pt_preprocessor_flags(), - link_whole = True, - visibility = ["PUBLIC"], - deps = [ - ":aten_cpu", - ":generated-autograd-headers", - ":torch_headers", - "//third_party:glog", - "//c10:c10", - ], -) - - -cxx_library( - name = "torch_mobile_deserialize_common", - srcs = [ - "torch/csrc/jit/mobile/parse_bytecode.cpp", - "torch/csrc/jit/mobile/parse_operators.cpp", - "torch/csrc/jit/mobile/upgrader_mobile.cpp", - "torch/csrc/jit/serialization/import_read.cpp", - "torch/csrc/jit/serialization/unpickler.cpp", - ], - header_namespace = "", - exported_headers = [ - "torch/csrc/jit/serialization/import_read.h", - "torch/csrc/jit/serialization/unpickler.h", - ], - compiler_flags = get_pt_compiler_flags(), - link_whole = True, - linker_flags = [ - "-Wl,--no-as-needed", - ], - visibility = ["PUBLIC"], - exported_deps = [ - ":aten_cpu", - ":caffe2_headers", - ":caffe2_serialize", - ":torch_common", - ":torch_headers", - ":torch_mobile_headers", - ":torch_mobile_module", - ":torch_mobile_observer", - "//third_party:glog", - "//c10:c10", - ], -) - cxx_library( name = "caffe2_serialize", srcs = [ @@ -270,161 +76,19 @@ cxx_library( ], ) -cxx_library( - name = "torch_mobile_deserialize", - srcs = [ - "torch/csrc/jit/mobile/import.cpp", - ], - header_namespace = "", - exported_headers = [ - "torch/csrc/jit/mobile/import.h", - ], - compiler_flags = get_pt_compiler_flags(), - link_whole = True, - linker_flags = [ - "-Wl,--no-as-needed", - ], - visibility = ["PUBLIC"], - exported_deps = [ - ":aten_cpu", - ":caffe2_headers", - ":caffe2_serialize", - ":torch_common", - ":torch_headers", - ":torch_mobile_headers", - ":torch_mobile_module", - ":torch_mobile_observer", - "//third_party:glog", - "//c10:c10", - ":torch_mobile_deserialize_common", - ], -) - -cxx_library( - name = "torch_mobile_module", - srcs = [ - "torch/csrc/jit/mobile/function.cpp", - "torch/csrc/jit/mobile/interpreter.cpp", - "torch/csrc/jit/mobile/module.cpp", - ], - header_namespace = "", - exported_headers = [], - compiler_flags = get_pt_compiler_flags(), - link_whole = True, - linker_flags = [ - "-Wl,--no-as-needed", - ], - visibility = ["PUBLIC"], - exported_deps = [ - ":aten_cpu", - ":caffe2_headers", - ":torch_common", - ":torch_headers", - ":torch_mobile_headers", - ":torch_mobile_observer", - "//third_party:glog", - "//c10:c10", - ], -) - -cxx_library( - name = "torch_mobile_core", - srcs = [], - header_namespace = "", - exported_headers = [], - compiler_flags = get_pt_compiler_flags(), - exported_preprocessor_flags = get_pt_preprocessor_flags(), - link_whole = True, - linker_flags = [ - "-Wl,--no-as-needed", - # "-ldl", - ], - visibility = ["PUBLIC"], - deps = [ - ":generated-autograd-headers", - ":torch_mobile_observer", - ":torch_mobile_headers", - ], - exported_deps = [ - ":aten_cpu", - ":torch_common", - ":torch_mobile_deserialize", - ], -) - pt_operator_library( name = "torch_mobile_ops_full_dev", - check_decl = False, include_all_operators = True, ) cxx_library( - name = "torch_mobile_all_ops", - visibility = ["PUBLIC"], - deps = get_pt_ops_deps( + name = "pt_ops_full", + **get_pt_operator_registry_dict( name = "pt_ops_full", - train = False, deps = [ ":torch_mobile_ops_full_dev", ], - enable_flatbuffer = False, - ), -) - -python_library( - name = "gen_oplist_lib", - srcs = subdir_glob([ - ("tools/code_analyzer", "gen_oplist.py"), - ("tools/code_analyzer", "gen_op_registration_allowlist.py"), - ]), - base_module = "", - deps = [ - "//third_party:pyyaml", - "//tools/lite_interpreter:gen_selected_mobile_ops_header", - "//torchgen:torchgen", - ], -) - -python_binary( - name = "gen_oplist", - main_module = "gen_oplist", - visibility = ["PUBLIC"], - deps = [ - ":gen_oplist_lib", - ], -) - -python_library( - name = "gen_operators_yaml_lib", - srcs = subdir_glob([ - ("tools/code_analyzer", "gen_operators_yaml.py"), - ("tools/code_analyzer", "gen_op_registration_allowlist.py"), - ]), - base_module = "", - deps = [ - "//third_party:pyyaml", - "//torchgen:torchgen", - ], -) - -python_binary( - name = "gen_aten_bin", - main_module = "torchgen.gen", - visibility = [ - "PUBLIC", - ], - deps = [ - "//torchgen:torchgen", - ], -) - -python_binary( - name = "gen_operators_yaml", - main_module = "gen_operators_yaml", - visibility = ["PUBLIC"], - deps = [ - ":gen_operators_yaml_lib", - ], + ) ) cxx_binary( @@ -452,16 +116,7 @@ cxx_binary( ], deps = [ ":torch_mobile_core", - ":torch_mobile_all_ops", + ":pt_ops_full", "//c10:c10", ], ) - -filegroup( - name = "templated_selective_build_srcs", - # NB: no glob here, there are generated targets in this list! - srcs = glob(TEMPLATE_SOURCE_LIST) + aten_ufunc_generated_all_cpu_sources(":gen_aten[{}]"), - visibility = [ - "PUBLIC", - ], -) diff --git a/BUILD.bazel b/BUILD.bazel index 9849c796722a7..823a59bb63b75 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -50,6 +50,9 @@ generated_cpu_cpp = [ "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp", "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp", "aten/src/ATen/RegisterMeta.cpp", + "aten/src/ATen/RegisterSparseMeta.cpp", + "aten/src/ATen/RegisterQuantizedMeta.cpp", + "aten/src/ATen/RegisterNestedTensorMeta.cpp", "aten/src/ATen/RegisterSchema.cpp", "aten/src/ATen/CPUFunctions.h", "aten/src/ATen/CPUFunctions_inl.h", @@ -76,6 +79,7 @@ generated_cpu_cpp = [ "aten/src/ATen/MethodOperators.h", "aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/RegistrationDeclarations.h", + "aten/src/ATen/VmapGeneratedPlumbing.h", "aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/TensorBody.h", @@ -1357,7 +1361,7 @@ cc_library( includes = [ "caffe2/contrib/aten", "caffe2/core/nomnigraph/include", - "third_party/miniz-2.0.8", + "third_party/miniz-2.1.0", ], visibility = ["//visibility:public"], deps = [ @@ -1424,7 +1428,7 @@ cc_library( ":caffe2_perfkernels_avx2", ":caffe2_perfkernels_avx512", ":caffe2_protos", - "//third_party/miniz-2.0.8:miniz", + "//third_party/miniz-2.1.0:miniz", "@com_google_protobuf//:protobuf", "@eigen", "@fbgemm//:fbgemm_src_headers", @@ -1605,6 +1609,7 @@ cc_library( "torch/csrc/distributed", "torch/lib", "torch/lib/libshm", + "third_party/kineto/libkineto/include", ], visibility = ["//visibility:public"], deps = [ @@ -1656,6 +1661,7 @@ cc_library( deps = [ ":caffe2", ":torch_headers", + "@kineto", ] + if_cuda([ ":torch_distributed_cuda", "@cuda//:nvToolsExt", diff --git a/CMakeLists.txt b/CMakeLists.txt index 698c56a46a1da..38a430ee7287c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,13 +37,17 @@ if(env_cxx_standard GREATER -1) WARNING "C++ standard version definition detected in environment variable." "PyTorch requires -std=c++14. Please remove -std=c++ settings in your environment.") endif() -set(CMAKE_CXX_STANDARD 14) -set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_C_STANDARD 11 CACHE STRING "The C standard whose features are requested to build this target.") if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) set(CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + else() + # Please note this is required in order to ensure compatibility between gcc 9 and gcc 7 + # This could be removed when all Linux PyTorch binary builds are compiled by the same toolchain again + string(APPEND CMAKE_CXX_FLAGS " -fabi-version=11") endif() endif() @@ -291,6 +295,10 @@ if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER}) endif() option(USE_ZMQ "Use ZMQ" OFF) option(USE_ZSTD "Use ZSTD" OFF) +# Ensure that an ITT build is the default for x86 CPUs +cmake_dependent_option( + USE_ITT "Use Intel(R) VTune Profiler ITT functionality" ON + "CPU_INTEL" OFF) # Ensure that an MKLDNN build is the default for x86 CPUs # but optional for AArch64 (dependent on -DUSE_MKLDNN). cmake_dependent_option( @@ -308,6 +316,14 @@ option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option( USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) +cmake_dependent_option( + USE_UCC "Use UCC. Only available if USE_DISTRIBUTED is on." OFF + "USE_DISTRIBUTED" OFF) +cmake_dependent_option( + USE_SYSTEM_UCC "Use system-wide UCC" OFF + "USE_UCC" OFF) +cmake_dependent_option( + USE_C10D_UCC "USE C10D UCC" ON "USE_DISTRIBUTED;USE_UCC" OFF) cmake_dependent_option( USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) @@ -784,6 +800,11 @@ if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -Wall") string(APPEND CMAKE_CXX_FLAGS " -Wextra") string(APPEND CMAKE_CXX_FLAGS " -Werror=return-type") + if(NOT USE_CUDNN) + # Temporary fix to ignore non virtual dtor error if cudnn is used. A + # separate PR to cudnn_frontend is needed to address this later on + string(APPEND CMAKE_CXX_FLAGS " -Werror=non-virtual-dtor") + endif() string(APPEND CMAKE_CXX_FLAGS " -Wno-missing-field-initializers") string(APPEND CMAKE_CXX_FLAGS " -Wno-type-limits") string(APPEND CMAKE_CXX_FLAGS " -Wno-array-bounds") @@ -840,6 +861,7 @@ if(NOT MSVC) # These flags are not available in GCC-4.8.5. Set only when using clang. # Compared against https://gcc.gnu.org/onlinedocs/gcc-4.8.5/gcc/Option-Summary.html if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + string(APPEND CMAKE_CXX_FLAGS " -Wconstant-conversion") string(APPEND CMAKE_CXX_FLAGS " -Wno-invalid-partial-specialization") string(APPEND CMAKE_CXX_FLAGS " -Wno-typedef-redefinition") string(APPEND CMAKE_CXX_FLAGS " -Wno-unknown-warning-option") @@ -1089,6 +1111,7 @@ if(BUILD_SHARED_LIBS) ${PROJECT_SOURCE_DIR}/cmake/public/protobuf.cmake ${PROJECT_SOURCE_DIR}/cmake/public/threads.cmake ${PROJECT_SOURCE_DIR}/cmake/public/utils.cmake + ${PROJECT_SOURCE_DIR}/cmake/public/LoadHIP.cmake DESTINATION share/cmake/Caffe2/public COMPONENT dev) install(DIRECTORY diff --git a/CODEOWNERS b/CODEOWNERS index a38c5df38700f..1bb8efe9de0b9 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -13,7 +13,6 @@ /test/test_public_bindings.py @albanD /test/allowlist_for_publicAPI.json @albanD @anjali411 /docs/source/conf.py @albanD -/aten/src/ATen/native/native_functions.yaml @bdhirsh /aten/src/ATen/native/tags.yaml @anjali411 # Tensorpipe RPC Agent. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 20ee44b2ee087..7b4a1246d002d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -443,7 +443,7 @@ These are the docs that you see over at [our docs website](https://pytorch.org/d - **Developer facing documentation**: Developer facing documentation is spread around our READMEs in our codebase and in the [PyTorch Developer Wiki](https://pytorch.org/wiki). -If you're interested in adding new developer docs, please read this [page on the wiki](https://github.com/pytorch/pytorch/wiki/Where-or-how-should-I-add-documentation%3F) on our best practices for where to put it. +If you're interested in adding new developer docs, please read this [page on the wiki](https://github.com/pytorch/pytorch/wiki/Where-or-how-should-I-add-documentation) on our best practices for where to put it. The rest of this section is about user-facing documentation. @@ -467,12 +467,20 @@ pip install -r requirements.txt # Or if you prefer an uncontaminated global executable environment or do not want to go through the node configuration: # npm install katex && export PATH="$PATH:$(pwd)/node_modules/.bin" ``` +> Note: if you installed `nodejs` with a different package manager (e.g., +`conda`) then `npm` will probably install a version of `katex` that is not +compatible with your version of `nodejs` and doc builds will fail. +A combination of versions that is known to work is `node@6.13.1` and +`katex@0.13.18`. To install the latter with `npm` you can run +```npm install -g katex@0.13.18``` + > Note that if you are a Facebook employee using a devserver, yarn may be more convenient to install katex: ```bash yarn global add katex ``` +> If a specific version is required you can use for example `yarn global add katex@0.13.18`. 3. Generate the documentation HTML files. The generated files will be in `docs/build/html`. diff --git a/Dockerfile b/Dockerfile index a8dc7f141685d..1bd522a624067 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,11 +28,13 @@ ENV PATH /opt/conda/bin:$PATH FROM dev-base as conda ARG PYTHON_VERSION=3.8 +COPY requirements.txt . RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ chmod +x ~/miniconda.sh && \ ~/miniconda.sh -b -p /opt/conda && \ rm ~/miniconda.sh && \ - /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda-build pyyaml numpy ipython && \ + /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \ + /opt/conda/bin/python -mpip install -r requirements.txt && \ /opt/conda/bin/conda clean -ya FROM dev-base as submodule-update diff --git a/README.md b/README.md index ad71c42cbcc97..10d14d354cc82 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ Commands to install binaries via Conda or pip wheels are on our website: [https: #### NVIDIA Jetson Platforms -Python wheels for NVIDIA's Jetson Nano, Jetson TX2, and Jetson AGX Xavier are provided [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-10-now-available/72048) and the L4T container is published [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch) +Python wheels for NVIDIA's Jetson Nano, Jetson TX1/TX2, Jetson Xavier NX/AGX, and Jetson AGX Orin are provided [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-10-now-available/72048) and the L4T container is published [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch) They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) and [@ptrblck](https://github.com/ptrblck) are maintaining them. @@ -380,6 +380,13 @@ You can then build the documentation by running `make ` from the If you get a katex error run `npm install katex`. If it persists, try `npm install -g katex` +> Note: if you installed `nodejs` with a different package manager (e.g., +`conda`) then `npm` will probably install a version of `katex` that is not +compatible with your version of `nodejs` and doc builds will fail. +A combination of versions that is known to work is `node@6.13.1` and +`katex@0.13.18`. To install the latter with `npm` you can run +```npm install -g katex@0.13.18``` + ### Previous Versions Installation instructions and binaries for previous PyTorch versions may be found diff --git a/WORKSPACE b/WORKSPACE index 96eea42c342ca..d26dfca5a3336 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -126,6 +126,12 @@ new_local_repository( path = "third_party/fmt", ) +new_local_repository( + name = "kineto", + build_file = "//third_party:kineto.BUILD", + path = "third_party/kineto", +) + new_patched_local_repository( name = "tbb", patches = [ diff --git a/android/README.md b/android/README.md index d1d6bcd6aa3b2..d1b40ff50f374 100644 --- a/android/README.md +++ b/android/README.md @@ -111,12 +111,12 @@ dependencies { implementation(name:'pytorch_android', ext:'aar') implementation(name:'pytorch_android_torchvision', ext:'aar') ... - implementation 'com.facebook.soloader:nativeloader:0.10.1' + implementation 'com.facebook.soloader:nativeloader:0.10.4' implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' } ``` We also have to add all transitive dependencies of our aars. -As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.1'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them. +As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.4'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them. (In case of using maven dependencies they are added automatically from `pom.xml`). You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly. diff --git a/android/build.gradle b/android/build.gradle index 3fd8f3b0d9c47..cd3755883f924 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -13,7 +13,7 @@ allprojects { junitVersion = "4.12" fbjniJavaOnlyVersion = "0.2.2" - soLoaderNativeLoaderVersion = "0.10.1" + soLoaderNativeLoaderVersion = "0.10.4" } repositories { diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index 4d1bf714a45f1..ad2647c2f4df6 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -14,7 +14,7 @@ endif() include(GNUInstallDirs) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) message(STATUS "ANDROID_STL:${ANDROID_STL}") diff --git a/android/pytorch_android_torchvision/CMakeLists.txt b/android/pytorch_android_torchvision/CMakeLists.txt index 788e09bcc8e99..08de7cebde491 100644 --- a/android/pytorch_android_torchvision/CMakeLists.txt +++ b/android/pytorch_android_torchvision/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.4.1) project(pytorch_vision_jni CXX) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) diff --git a/android/test_app/app/CMakeLists.txt b/android/test_app/app/CMakeLists.txt index 1094223b4c99a..457ccbe189bd7 100644 --- a/android/test_app/app/CMakeLists.txt +++ b/android/test_app/app/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.4.1) set(PROJECT_NAME pytorch_testapp_jni) project(${PROJECT_NAME} CXX) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) set(build_DIR ${CMAKE_SOURCE_DIR}/build) diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle index 573e8cdd27406..d726e6424d88c 100644 --- a/android/test_app/app/build.gradle +++ b/android/test_app/app/build.gradle @@ -139,7 +139,7 @@ tasks.all { task -> dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' - implementation 'com.facebook.soloader:nativeloader:0.10.1' + implementation 'com.facebook.soloader:nativeloader:0.10.4' localImplementation project(':pytorch_android') localImplementation project(':pytorch_android_torchvision') @@ -154,7 +154,7 @@ dependencies { aarImplementation(name:'pytorch_android', ext:'aar') aarImplementation(name:'pytorch_android_torchvision', ext:'aar') - aarImplementation 'com.facebook.soloader:nativeloader:0.10.1' + aarImplementation 'com.facebook.soloader:nativeloader:0.10.4' aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' def camerax_version = "1.0.0-alpha05" diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index e4b5f7028d241..a269f82fa8176 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace at { @@ -56,18 +57,21 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { return dim == 0 || dim == -1; } -Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional dtype) { - // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail - // and instead returns a new scalar tensor (this also happens for dim=-1) - // If the following happens: - // >>> x = torch.randn(B0) # the per-examples are all scalars - // >>> vmap(partial(torch.sum, dim=0), x) - // then we replicate the behavior of sum(scalar_tensor, dim=0). - if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { - return self.clone(); +Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional dtype) { + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail + // and instead returns a new scalar tensor (this also happens for dim=-1) + // If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(partial(torch.sum, dim=0), x) + // then we replicate the behavior of sum(scalar_tensor, dim=0). + if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { + return self.clone(); + } } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); - auto dims_physical = self_physical.getPhysicalDims(dims); + auto dims_physical = self_physical.getPhysicalDims(opt_dims); auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype); return self_physical.getPhysicalToLogicalMap().apply(result); } @@ -181,10 +185,13 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) return self_physical.getPhysicalToLogicalMap().apply(result); } -Tensor expand_batching_rule_symint(const Tensor& self, SymIntArrayRef psize, bool implicit) { - return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit); +Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) { + return self.expand(asIntArrayRefSlow(psize), implicit); } +Tensor sum_symint_batching_rule(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional opt_dtype) { + return input_t.sum(c10::asIntArrayRefSlow(dim), keepdim, opt_dtype); +} std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); @@ -465,6 +472,10 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { return self_physical.getPhysicalToLogicalMap().apply(result); } +Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) { + return self.view(asIntArrayRefSlow(size)); +} + Tensor view_as_complex_batching_rule(const Tensor& self) { // guard against the user passing in a batch of scalar tensors with batch // size equal to 2. @@ -995,6 +1006,16 @@ Tensor new_empty_batching_rule( return physical_view.getPhysicalToLogicalMap().apply(result); } +Tensor new_empty_symint_batching_rule( + const Tensor& self, + c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + Tensor new_empty_strided_batching_rule( const Tensor& self, IntArrayRef size, @@ -1079,6 +1100,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule); m.impl("sum.dim_IntList", sum_batching_rule); + m.impl("sum.SymInt", sum_symint_batching_rule); m.impl("is_complex", native::is_complex); // inplace operations @@ -1093,7 +1115,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("tensor_split.indices", tensor_split_indices_batching_rule); m.impl("diagonal", diagonal_batching_rule); m.impl("expand", expand_batching_rule); - m.impl("expand.SymInt", expand_batching_rule_symint); + m.impl("expand.SymInt", expand_symint_batching_rule); m.impl("expand_as", native::expand_as); // composite wrt autograd m.impl("movedim.intlist", movedim_batching_rule); m.impl("movedim.int", static_cast(native::movedim)); // composite wrt autograd @@ -1122,6 +1144,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("unfold", unfold_batching_rule); m.impl("unsqueeze", unsqueeze_batching_rule); m.impl("view", view_batching_rule); + m.impl("view.SymInt", view_symint_batching_rule); m.impl("view_as", native::view_as); // composite wrt autograd // clamp operations @@ -1260,6 +1283,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { // Tensor.new_* operators m.impl("new_empty", new_empty_batching_rule); + m.impl("new_empty.SymInt", new_empty_symint_batching_rule); m.impl("new_empty_strided", new_empty_strided_batching_rule); m.impl("new_zeros", new_zeros_batching_rule); diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 87f3e0688ab4d..286d59f3e97d6 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -121,7 +121,7 @@ file(GLOB native_ao_sparse_h "native/ao_sparse/cpu/*.h" "native/ao_sparse/quantized/*.h" "native/ao_sparse/quantized/cpu/*.h") -file(GLOB native_quantized_h "native/quantized/*.h" "native/quantized/cpu/*.h", "native/quantized/cudnn/*.h") +file(GLOB native_quantized_h "native/quantized/*.h" "native/quantized/cpu/*.h" "native/quantized/cudnn/*.h") file(GLOB native_cpu_h "native/cpu/*.h") file(GLOB native_cuda_cu "native/cuda/*.cu") diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index e1c46a710686a..4e8c9cae04f73 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -144,6 +144,14 @@ void Context::setBenchmarkCuDNN(bool b) { benchmark_cudnn = b; } +int Context::benchmarkLimitCuDNN() const { + return benchmark_limit_cudnn; +} + +void Context::setBenchmarkLimitCuDNN(int b) { + benchmark_limit_cudnn = b; +} + bool Context::allowTF32CuBLAS() const { static bool allow_tf32_cublas_override = c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true; return allow_tf32_cublas_override || float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 2665e893e99f9..8f3928376473d 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -99,10 +100,10 @@ class TORCH_API Context { // defined in header so that getNonVariableType has ability to inline // call_once check. getNonVariableType is called fairly frequently void lazyInitCUDA() { - std::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); }); + c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); }); } void lazyInitHIP() { - std::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); }); + c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); }); } static const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); @@ -120,6 +121,8 @@ class TORCH_API Context { void setUserEnabledMkldnn(bool e); bool benchmarkCuDNN() const; void setBenchmarkCuDNN(bool); + int benchmarkLimitCuDNN() const; + void setBenchmarkLimitCuDNN(int); bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); @@ -244,8 +247,8 @@ class TORCH_API Context { } } static bool checkCuBLASConfigDeterministic(); - std::once_flag thc_init; - std::once_flag thh_init; + c10::once_flag thc_init; + c10::once_flag thh_init; bool enabled_cudnn = true; bool deterministic_cudnn = false; bool _deterministic_algorithms = false; @@ -253,6 +256,7 @@ class TORCH_API Context { bool benchmark_cudnn = false; Float32MatmulPrecision float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; + int benchmark_limit_cudnn = 10; bool allow_tf32_cudnn = true; bool allow_fp16_reduction_cublas = true; bool enabled_mkldnn = true; diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 7339b1f92a9e0..08d41126a1619 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -47,132 +47,64 @@ TORCH_API void record_kernel_function_dtype(std::string name); #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) #endif +// Avoid if_constexpr if possble, as it's more expensive to compile #if defined __cpp_if_constexpr -#define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ - case enum_type: { \ - if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \ - AT_ERROR( \ - "dtype '", \ - toString(enum_type), \ - "' not selected for kernel tag ", \ - #NAME); \ - } \ - using HINT = type; \ - return __VA_ARGS__(); \ - } -#else -#define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ - case enum_type: { \ - at::guts::if_constexpr<( \ - !at::should_include_kernel_dtype(NAME, enum_type))>([] { \ - AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \ - }); \ - using HINT = type; \ - return __VA_ARGS__(); \ - } +#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \ + do { \ + if constexpr (!at::should_include_kernel_dtype( \ + at_dispatch_name, enum_type)) { \ + AT_ERROR( \ + "dtype '", \ + toString(enum_type), \ + "' not selected for kernel tag ", \ + at_dispatch_name); \ + } \ + } while (0) +#else // defined __cpp_if_constexpr +#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \ + at::guts::if_constexpr([&] { \ + AT_ERROR( \ + "dtype '", \ + toString(enum_type), \ + "' not selected for kernel tag ", \ + at_dispatch_name); \ + }) #endif -#define AT_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ - AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, scalar_t, __VA_ARGS__) - -// Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused -// attribute in the type aliasing context. Keep name long and verbose to avoid -// macro collisions. -#if defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10010 -#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND -#else -#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED -#endif // defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10010 - -#if defined __cpp_if_constexpr -#define AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, enum_type, type, underlying_enum, underlying_type, ...) \ - case enum_type: { \ - if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \ - AT_ERROR( \ - "dtype '", \ - toString(enum_type), \ - "' not selected for kernel tag ", \ - #NAME); \ - } \ - using scalar_t = type; \ - using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - toUnderlying(enum_type); \ - (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \ - /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \ - return __VA_ARGS__(); \ - } -#else -#define AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, enum_type, type, underlying_enum, underlying_type, ...) \ - case enum_type: { \ - at::guts::if_constexpr<( \ - !at::should_include_kernel_dtype(NAME, enum_type))>([] { \ - AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \ - }); \ - using scalar_t = type; \ - using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - toUnderlying(enum_type); \ - (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \ - /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \ - return __VA_ARGS__(); \ +#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using HINT = c10::impl::ScalarTypeToCPPTypeT; \ + return __VA_ARGS__(); \ } -#endif -#if defined __cpp_if_constexpr -#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \ - case enum_type: { \ - if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \ - AT_ERROR( \ - "dtype '", \ - toString(enum_type), \ - "' not selected for kernel tag ", \ - #NAME); \ - } \ - using scalar_t = type; \ - using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - toUnderlying(enum_type); \ - C10_UNUSED int bit_width = bitwidth; \ - C10_UNUSED int64_t quant_min = qmin; \ - C10_UNUSED int64_t quant_max = qmax; \ - (void)bit_width; /* Suppress unused variable warning */ \ - (void)quant_min; /* Suppress unused variable warning */ \ - (void)quant_max; /* Suppress unused variable warning */ \ - return __VA_ARGS__(); \ +#define AT_DISPATCH_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) + +#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t C10_UNUSED = typename scalar_t::underlying; \ + const auto& SCALAR_TYPE C10_UNUSED = enum_type; \ + const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \ + return __VA_ARGS__(); \ } -#else -#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \ - case enum_type: { \ - at::guts::if_constexpr<( \ - !at::should_include_kernel_dtype(NAME, enum_type))>([] { \ - AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \ - }); \ - using scalar_t = type; \ - using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ - toUnderlying(enum_type); \ - int bit_width = bitwidth; \ - int64_t quant_min = qmin; \ - int64_t quant_max = qmax; \ - (void)bit_width; /* Suppress unused variable warning */ \ - (void)quant_min; /* Suppress unused variable warning */ \ - (void)quant_max; /* Suppress unused variable warning */ \ - return __VA_ARGS__(); \ + +#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + enum_type, scalar_type, bitwidth, qmin, qmax, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t C10_UNUSED = typename scalar_t::underlying; \ + const auto& SCALAR_TYPE C10_UNUSED = enum_type; \ + const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \ + C10_UNUSED int bit_width = bitwidth; \ + C10_UNUSED int64_t quant_min = qmin; \ + C10_UNUSED int64_t quant_max = qmax; \ + return __VA_ARGS__(); \ } -#endif namespace detail { @@ -210,7 +142,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] { // // Your code here, with 'scalar_t' now defined to // // be the dtype in question -// }) +// }); // // There are many variations of this macro, so it's important to // understand exactly /which/ dtypes you want to get instantiated, as @@ -261,721 +193,330 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} // functions. There is no risk of missing out on any code, so // it's mostly a risk of a Type-2 error, and not a Type-1 error. // +// Switch-like syntax: +// ------------------- +// There is also a switch-case like syntax which is useful if a kernel +// needs to be specialized for particular scalar types +// +// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name", +// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { +// op_integral(iter); +// }) +// AT_DISPATCH_CASE_FLOATING_TYPES([&] { +// op_floating(iter); +// }) +// AT_DISPATCH_CASE(kBool, [&] { +// op_bool(iter); +// }) +// ); +// +// For each AT_DISPATCH_FOO macro, there is a corresponding +// AT_DISPATCH_CASE_FOO macro which can be used inside of an +// AT_DISPATCH_SWITCH block. // NB: the the_type variable is not used, but we have kept it for // backwards compatibility. It's probably not used by anyone though; // but we're just being safe (and it doesn't hurt.) Note we must // use it to shut up warnings about unused store. -#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_TYPES_AND2( \ - SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \ - SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \ - SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE3, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ +#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ [&] { \ const auto& the_type = TYPE; \ + constexpr const char* at_dispatch_name = NAME; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \ switch (_st) { \ - AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ - AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \ - AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \ + __VA_ARGS__ \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR( \ + '"', \ + at_dispatch_name, \ + "\" not implemented for '", \ + toString(_st), \ + "'"); \ } \ }() -#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ - AT_QINT_PRIVATE_CASE_TYPE( \ - NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - }() - -#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, \ - at::kQInt8, \ - at::qint8, \ - int8_t, \ - CHAR_BIT, \ - SCHAR_MIN, \ - SCHAR_MAX, \ - __VA_ARGS__) \ - AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, \ - at::kQUInt8, \ - at::quint8, \ - uint8_t, \ - CHAR_BIT, \ - 0, \ - UCHAR_MAX, \ - __VA_ARGS__) \ - AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, \ - at::kQInt32, \ - at::qint32, \ - int, \ - CHAR_BIT * sizeof(int), \ - INT_MIN, \ - INT_MAX, \ - __VA_ARGS__) \ - AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \ - AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - NAME, at::kQUInt2x4, at::quint2x4, uint8_t, 2, 0, 3, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() +#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES(...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__) + +#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) + +#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt32, \ + at::qint32, \ + CHAR_BIT * sizeof(int), \ + INT_MIN, \ + INT_MAX, \ + __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__) + +#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ - SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE3, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE3, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op*/ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexFloat, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - at::ScalarType::ComplexDouble, \ - c10::complex, \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE1, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE2, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE3, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - NAME, \ - SCALARTYPE4, \ - decltype(c10::impl::ScalarTypeToCPPType::t), \ - __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() - -#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_index_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _it = ::detail::scalar_type(the_index_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it) \ - switch (_it) { \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - NAME, at::ScalarType::Int, int32_t, index_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \ - } \ - }() + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Int, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, index_t, __VA_ARGS__)) // ---------------------------------------------------------------------------- // DEPRECATED MACROS, DON'T USE THESE // ---------------------------------------------------------------------------- -#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \ - [&] { \ - detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() +#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \ + detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index d720a43aa51e3..caf2a4e653c86 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -275,6 +275,71 @@ TensorBase empty_meta( return empty_meta(size, dtype, memory_format_opt); } +TensorBase empty_symint_meta( + SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt +) { + auto device = device_or_default(device_opt); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta); + // NB: because there is no SparseMeta (yet), non-strided layout is + // exerciseable + TORCH_CHECK_NOT_IMPLEMENTED( + layout_or_default(layout_opt) == Layout::Strided, + "non-strided meta tensors not supported yet" + ); + + auto scalar_type = dtype_or_default(dtype_opt); + auto *allocator = GetAllocator(kMeta); + constexpr c10::DispatchKeySet meta_dks(c10::DispatchKey::Meta); + // TODO: do this. Note that naive implementation will choke on truly + // unknown sizes without on the fly reasoning + // at::detail::check_size_nonnegative(size); + at::detail::raise_warning_for_complex_half(scalar_type); + caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type); + SymInt size_bytes = dtype.itemsize(); + for (auto s : size) { + size_bytes = size_bytes * s; + } + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator, + /*resizeable=*/true); + + auto tensor = detail::make_tensor_base( + std::move(storage_impl), meta_dks, dtype); + + int64_t dim = size.size(); + std::vector strides; + strides.resize(dim); + + // TODO: Move this into TensorImpl + auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); + switch (memory_format) { + case MemoryFormat::Contiguous: { + if (dim > 0) { + const auto last_idx = dim - 1; + strides.at(last_idx) = 1; + for (auto i = last_idx - 1; i >= 0; --i) { + // TODO: max with 1 + strides.at(i) = strides.at(i+1) * size.at(i+1); + } + } + break; + } + default: + TORCH_CHECK(0, "other memory format not implemented yet"); + } + + tensor.unsafeGetTensorImpl()->set_sym_sizes_and_strides(size, strides); + + return tensor; +} + TensorBase empty_meta( IntArrayRef size, const TensorOptions &options) { return at::detail::empty_meta( diff --git a/aten/src/ATen/EmptyTensor.h b/aten/src/ATen/EmptyTensor.h index fdcc30730cf5e..06a33601a1549 100644 --- a/aten/src/ATen/EmptyTensor.h +++ b/aten/src/ATen/EmptyTensor.h @@ -87,6 +87,14 @@ TORCH_API TensorBase empty_meta( c10::optional pin_memory_opt, c10::optional memory_format_opt); +TORCH_API TensorBase empty_symint_meta( + SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt); + TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options); TORCH_API TensorBase diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index c7b9e907c489b..7a81076a7dd01 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -437,17 +437,16 @@ inline std::vector expand_outplace(TensorList to_expand) { return result; } -// Sums `tensor` repeatedly to produce a tensor of shape `shape`. -// Precondition: is_expandable_to(shape, tensor.sizes()) must be true static inline Tensor sum_to( Tensor tensor, - const IntArrayRef shape, + const c10::SymIntArrayRef shape, bool always_return_non_view = false) { if (shape.size() == 0) { return tensor.sum(); } - c10::SmallVector reduce_dims; - const at::IntArrayRef sizes = tensor.sizes(); + + auto sizes = tensor.sym_sizes(); + c10::SmallVector reduce_dims; const int64_t leading_dims = sizes.size() - shape.size(); for (const auto i : c10::irange(leading_dims)) { reduce_dims.push_back(i); @@ -457,29 +456,44 @@ static inline Tensor sum_to( reduce_dims.push_back(i); } } + if (!reduce_dims.empty()) { - tensor = tensor.sum(reduce_dims, /*keepdim=*/true); + tensor = tensor.sum_symint(reduce_dims, /*keepdim=*/true); } + if (always_return_non_view) { // This is only actually used by the functionalization pass. // We want to be able to guarantee that this function doesn't return a view // of the input. - return leading_dims > 0 ? at::view_copy(tensor, shape) : tensor.clone(); + return leading_dims > 0 ? at::view_copy_symint(tensor, shape) + : tensor.clone(); } else { - return leading_dims > 0 ? tensor.view(shape) : tensor; + return leading_dims > 0 ? tensor.view_symint(shape) : tensor; } } -// True if `shape` can be broadcasted to `desired` -static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { +// Sums `tensor` repeatedly to produce a tensor of shape `shape`. +// Precondition: is_expandable_to(shape, tensor.sizes()) must be true +static inline Tensor sum_to( + Tensor tensor, + const IntArrayRef shape, + bool always_return_non_view = false) { + auto sym_size = c10::SymIntArrayRef( + reinterpret_cast(shape.data()), shape.size()); + return sum_to(tensor, sym_size, always_return_non_view); +} + +static inline bool is_expandable_to( + SymIntArrayRef shape, + c10::SymIntArrayRef desired) { size_t ndim = shape.size(); size_t target_dim = desired.size(); if (ndim > target_dim) { return false; } for (const auto i : c10::irange(ndim)) { - int64_t size = shape[ndim - i - 1]; - int64_t target = desired[target_dim - i - 1]; + auto size = shape[ndim - i - 1]; + auto target = desired[target_dim - i - 1]; if (size != target && size != 1) { return false; } @@ -487,4 +501,12 @@ static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { return true; } +static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { + auto sym_shape = c10::SymIntArrayRef( + reinterpret_cast(shape.data()), shape.size()); + auto sym_desired = c10::SymIntArrayRef( + reinterpret_cast(desired.data()), desired.size()); + return is_expandable_to(sym_shape, sym_desired); +} + } // namespace at diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 1aa948138febc..471c74a73c952 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -172,6 +172,10 @@ Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& return mutated_view; } +Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { + return mutated_view; +} + Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional start, c10::optional end, int64_t step) { // Pessimism: we can't reapply views for slice_scatter. return base.slice_scatter(mutated_view, dim, start, end, step); @@ -295,6 +299,14 @@ Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& m } } +Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) { + if (reapply_views) { + return mutated_view.view_symint(base.sym_sizes()); + } else { + return at::view_copy_symint(mutated_view, base.sym_sizes()); + } +} + Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) { if (reapply_views) { return mutated_view.view(base.scalar_type()); diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 65e47e5f19d1e..a8c58466a052c 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -9,6 +9,12 @@ #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + namespace at { void FunctionalTensorWrapper::set_constructor_metadata() { @@ -22,11 +28,27 @@ void FunctionalTensorWrapper::set_constructor_metadata() { refresh_numel(); refresh_contiguous(); storage_access_should_throw_ = false; + // In general, the sizes/stride metadata on a tensor can change as it is mutated, + // and these changes need to be reflected in the metadata of the wrapper. + set_allow_tensor_metadata_change(true); key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set(); // All of the keys corresponding to functorch transforms should not be copied over. // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect // to participate in the functorch transforms. - key_set_ = key_set_ - c10::functorch_transforms_ks; + key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks; + // For better error handling, + // we also don't want our wrapper tensor to be able to dispatch directly + // to a backend kernel. + // Dispatching directly to e.g. a CPU kernel would always segfault, + // because wrapper tensors don't have any real data. + // (This should never happen because we should always hit a functionalization kernel, + // but can help make bugs less nasty). + // Here, we defensively remove any backend keys from the wrapper's keyset. + // We don't want to remove actual backend bits though (say we're redispatching to autograd; + // we need to know if we're dispatching to AutogradCPU or AutogradXLA). + // Instead, it's sufficient to remove the `Dense` dispatch key, + // which prevents us from accidentally trying to directly run a CPU/CUDA kernel. + key_set_ = key_set_.remove(c10::DispatchKey::Dense); } FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value) @@ -180,10 +202,18 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) { value_ = other; // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor. // We need to propagate that metadata mutation to the wrapper (new size). - set_sizes_and_strides(value_.sizes(), value_.strides()); - set_storage_offset(value_.storage_offset()); + if (sizes() != value_.sizes() || strides() != value_.strides()) { + set_sizes_and_strides(value_.sizes(), value_.strides()); + } + if (storage_offset() != value_.storage_offset()) { + set_storage_offset(value_.storage_offset()); + } if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) { - value_ = value_.to(c10::TensorOptions().dtype(dtype()).layout(layout())); + // .to() should not re-entrantly go through functionalization. + at::AutoDispatchSkipFunctionalize guard; + // and we want _to_copy() to show up in the graph, not the composite .to() operator + // (this can happen if autograd has already run by the time we enter this code) + value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout())); } } @@ -260,6 +290,44 @@ const char* FunctionalTensorWrapper::tensorimpl_type_name() const { return "FunctionalTensorWrapper"; } +template +c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + if (key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this); + if (r) { + r->set_version_counter(std::forward(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + } + auto impl = c10::make_intrusive(value_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/std::forward(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; +} + +c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + return shallow_copy_and_detach_core( + version_counter, allow_tensor_metadata_change); +} + +c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + return shallow_copy_and_detach_core( + std::move(version_counter), allow_tensor_metadata_change); +} + at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const { return value_.unsafeGetTensorImpl()->sizes(); } @@ -275,6 +343,9 @@ int64_t FunctionalTensorWrapper::numel_custom() const { bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const { return value_.unsafeGetTensorImpl()->is_contiguous(); } +c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes() const { + return value_.unsafeGetTensorImpl()->sym_sizes(); +} c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { return value_.unsafeGetTensorImpl()->sym_sizes(); } @@ -329,7 +400,7 @@ std::vector to_functional_tensor(const TensorList& t_list) { Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) { // Note [Wrapped Numbers <> Functionalization] - if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) { return tensor; } if (isFunctionalTensor(tensor)) { @@ -452,43 +523,43 @@ bool isFunctionalTensor(const c10::optional& t) { } } +// For lists that have a mix of functional and nonfunctional tensors, +// functionalization machinery should just unwrap the functional wrappers +// and leave the ordinary tensors alone. bool isFunctionalTensor(const c10::List& t_list) { if (t_list.size() == 0) return false; - bool any_functional = isFunctionalTensor(t_list[0]); - for (const auto i : c10::irange(1, t_list.size())) { - auto curr_functional = isFunctionalTensor(t_list[i]); - TORCH_INTERNAL_ASSERT( - curr_functional == any_functional, - "Functionalization encountered a list of tensors where some are functional", - "and some are not, which is not currently unsupported."); + auto functional_count = 0; + for (const auto i : c10::irange(t_list.size())) { + if (!t_list[i].defined()) continue; + if (isFunctionalTensor(t_list[i])) { + ++functional_count; + } } - return any_functional; + return functional_count > 0; } bool isFunctionalTensor(const c10::List>& t_list) { if (t_list.size() == 0) return false; - bool any_functional = isFunctionalTensor(t_list[0]); - for (const auto i : c10::irange(1, t_list.size())) { - auto curr_functional = isFunctionalTensor(t_list[i]); - TORCH_INTERNAL_ASSERT( - curr_functional == any_functional, - "Functionalization encountered a list of tensors where some are functional", - "and some are not, which is not currently unsupported."); + auto functional_count = 0; + for (const auto i : c10::irange(t_list.size())) { + if (!t_list[i].has_value() || !t_list[i]->defined()) continue; + if (isFunctionalTensor(t_list[i])) { + ++functional_count; + } } - return any_functional; + return functional_count > 0; } bool isFunctionalTensor(const c10::ArrayRef t_list) { if (t_list.size() == 0) return false; - bool any_functional = isFunctionalTensor(t_list[0]); - for (const auto i : c10::irange(1, t_list.size())) { - auto curr_functional = isFunctionalTensor(t_list[i]); - TORCH_INTERNAL_ASSERT( - curr_functional == any_functional, - "Functionalization encountered a list of tensors where some are functional", - "and some are not, which is not currently unsupported."); + auto functional_count = 0; + for (const auto i : c10::irange(t_list.size())) { + if (!t_list[i].defined()) continue; + if (isFunctionalTensor(t_list[i])) { + ++functional_count; + } } - return any_functional; + return functional_count > 0; } Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { @@ -552,5 +623,93 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) { } } // namespace impl + + +// Given an **out-of-place** op that might internally call view/inplace ops, +// This function will "functionalize" it. +// That is, it will call the operator, but removing any intermediate views/mutations +// that are performed inside of it. +// This is useful for LTC/XLA, which would like to re-use some of our composite kernels +// from pytorch core but not have to worry about the view ops that they might call. +// e.g. at::block_diag +void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_arguments = schema.arguments().size(); + const auto arguments_begin = stack->size() - num_arguments; + auto arguments = torch::jit::last(stack, num_arguments); + + // Wrap all tensor-like inputs into FunctionalTensorWrappers. + // When we re-invoke the dispatcher, this will automatically enable the functionalization pass. + for (uint64_t idx = 0; idx < num_arguments; ++idx) { + const auto& ivalue = arguments[idx]; + if (ivalue.isTensor()) { + auto t = ivalue.toTensor(); + if (t.defined()) { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t), + "The composite op functionalization fallback expects its inputs all not to be functional tensors"); + auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); + (*stack)[arguments_begin + idx] = t_new; + } + } else if (ivalue.isTensorList()) { + auto tensors = ivalue.toTensorList(); + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors), + "The composite op functionalization fallback expects its inputs all not to be functional tensors"); + auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors)); + (*stack)[arguments_begin + idx] = t_new; + } else if (ivalue.isOptionalTensorList()) { + auto opt_tensors = ivalue.toOptionalTensorList(); + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors), + "The composite op functionalization fallback expects its inputs all not to be functional tensors"); + auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors)); + (*stack)[arguments_begin + idx] = t_new; + } + } + + { + // Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap + // the output in a functional tensor based on TLS. + // In this code, we're re-entrantly entering functionalization in the same call-stack, + // so we need to manually fix up TLS as if it hadn't already been called. + auto curr_tls = c10::impl::tls_local_dispatch_key_set(); + auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet(); + tls_reenable_functionalize.set_included(curr_tls.included_); + tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize)); + c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize); + // So, we should probably provide a way to directly call a kernel registered to + // the `CompositeExplicitAutograd` key. + // We can't do that today, so this should be a reasonably good proxy + // (It won't work in cases where an op has both a CompositeExplicitAutograd kernel + // AND a dedicated meta kernel, but that probably shouldn't ever happen). + op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack); + } + + const auto num_returns = schema.returns().size(); + const auto returns_begin = stack->size() - num_returns; + auto returns = torch::jit::last(stack, num_returns); + + for (const auto idx : c10::irange(num_returns)) { + const auto& ivalue = returns[idx]; + if (ivalue.isTensor()) { + auto t = ivalue.toTensor(); + if (!t.defined()) continue; + at::functionalization::impl::sync(t); + auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); + (*stack)[returns_begin + idx] = t_new; + } else if (ivalue.isTensorList()) { + auto tensors = ivalue.toTensorList(); + at::functionalization::impl::sync(tensors); + auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors)); + (*stack)[returns_begin + idx] = t_new; + } else if (ivalue.isOptionalTensorList()) { + auto opt_tensors = ivalue.toOptionalTensorList(); + at::functionalization::impl::sync(opt_tensors); + auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors)); + (*stack)[returns_begin + idx] = t_new; + } + } +} + + + } // namespace functionalization } // namespace at diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index 91bde5086ecb6..c5c0339fc1bfe 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -4,6 +4,9 @@ #include #include #include +#include +#include +#include #include @@ -120,6 +123,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // See Note[resize_() in functionalization pass] void maybe_replace_storage(const Tensor& other); + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + ~FunctionalTensorWrapper() override = default; // FunctionalTensorWrapper overrides all custom size/stride function, @@ -130,6 +141,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { int64_t dim_custom() const override; int64_t numel_custom() const override; bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymIntArrayRef sym_sizes() const override; c10::SymIntArrayRef sym_sizes_custom() const override; private: @@ -137,6 +149,16 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { void set_constructor_metadata(); functionalization::FunctionalStorageImpl* functional_storage_impl() const; + // This is used to re-implement shallow_copy_and_detach for + // FunctionalTensorWrapper. The implementation is identical, but we just need + // to return a subclass instead of a plain TensorImpl. + // TODO: maybe it's possible to arrange for that to happen automatically + // without an override here? + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + // Note that value is not taken by reference: internally, the wrapper will // change the value tensor that it points to over time. Tensor value_; @@ -251,5 +273,36 @@ class TORCH_API FunctionalizationReapplyViewsGuard { }; } // namespace impl + +// Helper function to call an out-of-place composite aten kernel that may use +// mutations / views internally, and functionalize them. +TORCH_API void functionalize_op_helper( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +template +struct _functionalize_aten_op final {}; + +template +struct _functionalize_aten_op final { + static ReturnType call(ParameterTypes... args) { + auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow( + (const char*)Op::name, (const char*)Op::overload_name) + .typed(); + + return c10::impl::BoxedKernelWrapper::call( + c10::BoxedKernel::makeFromFunction(), + op, + // BoxedKernelWrapper knows to ignore this keyset argument, + // because functionalize_op_helper doesn't take in a DispatchKeySet + c10::DispatchKeySet(), + args...); + } +}; + +template +using functionalize_aten_op = _functionalize_aten_op; + } // namespace functionalization } // namespace at diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index f48b887f4198f..25c81165f8830 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include #include #include @@ -10,11 +12,16 @@ #include #include #else +#include #include +#include +#include +#include #include #include #include #include +#include #endif namespace { @@ -32,7 +39,7 @@ namespace { if (ivalue.isTensor()) { any_tensor_inputs = true; auto t = ivalue.toTensor(); - if (at::functionalization::impl::isFunctionalTensor(t)) { + if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) { any_functional_inputs = true; at::functionalization::impl::sync(t); auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); @@ -73,6 +80,7 @@ namespace { const auto& ivalue = returns[idx]; if (ivalue.isTensor() && should_wrap_outputs) { auto t = ivalue.toTensor(); + if (!t.defined()) continue; auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); (*stack)[returns_begin + idx] = t_new; } else if (ivalue.isTensorList() && should_wrap_outputs) { @@ -126,7 +134,7 @@ const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, at::Tensor tmp_output; { at::AutoDispatchSkipFunctionalize guard; - tmp_output = at::resize_functional(self_, size, memory_format); + tmp_output = at::resize(self_, size, memory_format); } auto itemsize = self.dtype().itemsize(); @@ -168,7 +176,116 @@ const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, at::Tensor lift_functionalize(const at::Tensor & self) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self)); - return at::functionalization::impl::to_functional_tensor(self); + at::AutoDispatchSkipFunctionalize guard; + auto out = at::lift(self); + return at::functionalization::impl::to_functional_tensor(out); +} + +at::Tensor lift_fresh_functionalize(const at::Tensor & self) { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self)); + at::AutoDispatchSkipFunctionalize guard; + auto out = at::lift_fresh(self); + return at::functionalization::impl::to_functional_tensor(out); +} + +at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self)); + at::AutoDispatchSkipFunctionalize guard; + auto out = at::lift_fresh_copy(self); + return at::functionalization::impl::to_functional_tensor(out); +} + +bool device_opted_into_functionalization(c10::Device self_device, c10::optional tgt_device) { + // If the target device is empty, then the output tensor should be on the same device as the input + auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device; + return real_tgt_device.type() == c10::DeviceType::XLA || real_tgt_device.type() == c10::DeviceType::Lazy; +} + +// note I only need this because the to.dtype/to.dtype_layout overload calls this, so we skip the op above. +// We should probably get rid of this though. +at::Tensor _to_copy_functionalize( + const at::Tensor & self, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + bool non_blocking, + c10::optional memory_format) { + at::Tensor self_; + if (at::functionalization::impl::isFunctionalTensor(self)) { + // sync any pending updates + at::functionalization::impl::sync(self); + // pass the unwrapped tensor to the backend + self_ = at::functionalization::impl::from_functional_tensor(self); + } else { + self_ = self; + } + + at::AutoDispatchSkipFunctionalize guard; + auto out = at::_to_copy(self_, dtype, layout, device, pin_memory, non_blocking, memory_format); + + // Special case: if the Functionalize key is not in TLS, we assume that we're running + // on a lazy backend (LTC). + // In that case, if we're copying to a non-functionalize-enabled device, + // then the functionalization pass should "end". We need to sync any updates on the input + // tensor, but we shouldn't wrap the output. + if (!c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) { + if (!device_opted_into_functionalization(self.device(), device)) { + return out; + } + } + return at::functionalization::impl::to_functional_tensor(out); +} + + +// Why is _unsafe_view special-cased here? +// Basically just to satisfy autograd's debug asserts. +// The situation: +// - _unsafe_view's autograd kernel has debug asserts to confirm +// that the input and output alias storage. +// - _unsafe_view's schema in native_functions.yaml +// does not contain alias annotations, so it advertises as non-aliasing. +// - functionalization will then treat _unsafe_view like a non-aliasing op. +// Specifically, autograd will redispatch to functionalization's +// boxed fallback kernel, which creates a new FunctionalTensorWrapper output +// that does **not** alias storage with the input, tripping the assert. +// The kernel written here just manually re-ifies the aliasing relationship. +// +// Another way to handle this would be to fix unsafe_view's alias annotations +// in native_functions.yaml, but I think this would be a pessimization. +// The idea with _unsafe_view is that you're guaranteed that the input +// is a temporary, and don't actually have to worry about propagating +// mutations between the input and output. +at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef size) { + if (!at::functionalization::impl::isFunctionalTensor(self)) { + at::AutoDispatchSkipFunctionalize guard; + return at::_unsafe_view(self, size); + } + + auto self_ = at::functionalization::impl::from_functional_tensor(self); + at::Tensor tmp_output; + { + at::AutoDispatchSkipFunctionalize guard; + tmp_output = at::_unsafe_view(self_, size); + } + + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor { + return at::_unsafe_view(base, size); + }, + [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor { + return at::_unsafe_view(mutated_view, base.sizes()); + } + ); + + auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, view_meta); + // See Note [Propagating strides in the functionalization pass] + // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view) + auto inferred_size = at::infer_size_dv(size, self.numel()); + auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size); + TORCH_INTERNAL_ASSERT(stride.has_value()); + out.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride.value()); + return out; } TORCH_LIBRARY_IMPL(_, Functionalize, m) { @@ -178,4 +295,8 @@ TORCH_LIBRARY_IMPL(_, Functionalize, m) { TORCH_LIBRARY_IMPL(aten, Functionalize, m) { m.impl("resize_", TORCH_FN(resize__functionalization)); m.impl("lift", TORCH_FN(lift_functionalize)); + m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize)); + m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy)); + m.impl("_to_copy", TORCH_FN(_to_copy_functionalize)); + m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize)); } diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index e9d5e02d747b7..077e9e742fc77 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -9,8 +9,10 @@ namespace at { namespace native { inline std::vector construct_opt_sizes(const at::Tensor& sizes) { + // torch.tensor([]) is considered to have `dim() = 1` and `size(0) = 0` + // torch.nested_tensor([]) should also has `dim() = 1` and `size(0) = 0` if (sizes.dim() == 0) { - return std::vector(); + return std::vector({0}); } TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2); std::vector result(1, sizes.sizes()[0]); @@ -35,16 +37,91 @@ inline std::vector construct_opt_sizes(const at::Tensor& sizes) { return result; } +// assume contiguous, we can construct stride from size +inline at::Tensor construct_nested_stride_tensor(const at::Tensor& sizes) { + // empty `sizes` means empty nested tensor, so return empty strides + if (sizes.dim() == 0) { + return sizes; + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2); + int64_t orig_dim = sizes.size(1); + // `sizes`.sizes() = ntensors x 0 means empty but shaped `sizes` + // in this case strides is also empty but shaped + if (orig_dim == 0) { + return sizes; + } + at::Tensor strides = sizes.new_empty(sizes.sizes()); + const int64_t* sizes_ptr = sizes.data_ptr(); + int64_t* strides_ptr = strides.data_ptr(); + for (int64_t i = 0; i < sizes.size(0); i++) { + strides_ptr[orig_dim - 1] = 1; + int64_t product = sizes_ptr[orig_dim - 1]; + for (int64_t j = orig_dim - 2; j >= 0; j--) { + strides_ptr[j] = product; + product *= sizes_ptr[j]; + } + sizes_ptr += orig_dim; + strides_ptr += orig_dim; + } + return strides; +} + +// assume contiguous, we can construct offsets from size +inline std::vector construct_offsets(const at::Tensor& sizes) { + // empty `sizes` means empty nested tensor, so return empty strides + if (sizes.dim() == 0) { + return std::vector(); + } + int64_t ntensors = sizes.size(0), + orig_dim = sizes.size(1); + std::vector offsets(ntensors); + // nesting scalars has easy offsets + if (orig_dim == 0) { + std::iota(offsets.begin(), offsets.end(), 0); + return offsets; + } + const int64_t* sizes_ptr = sizes.data_ptr(); + offsets[0] = 0; + for (int64_t i = 0; i < ntensors - 1; i++) { + int64_t row_product = sizes_ptr[0]; + for (int64_t j = 1; j < orig_dim; j++) { + row_product *= sizes_ptr[j]; + } + offsets[i + 1] = offsets[i] + row_product; + sizes_ptr += orig_dim; + } + return offsets; +} + +// [Note: Nested Tensor Autograd] The Nested Tensor key is a functionality +// key and therefore getAutogradRelatedKeySetFromBackend will return the +// wrong autograd key. For this specific impl we make sure to register the +// correct Autograd key which is AutogradNestedTensor +c10::DispatchKeySet generate_nested_key_set(at::Tensor buffer) { + c10::DispatchKeySet key_set = + (c10::DispatchKeySet(DispatchKey::NestedTensor) | + c10::DispatchKeySet( + buffer.is_cuda() ? BackendComponent::CUDABit + : BackendComponent::CPUBit)); + + // Add AutogradNestedTensor specific keys + key_set = key_set | inplace_or_view_ks | autograd_nested; + return key_set; +} + NestedTensorImpl::NestedTensorImpl( at::Tensor buffer, - at::Tensor nested_size_tensor) + at::Tensor nested_size_tensor, + at::Tensor nested_stride_tensor, + const std::vector& offsets) : TensorImpl( - (c10::DispatchKeySet(DispatchKey::NestedTensor) | - c10::DispatchKeySet(buffer.is_cuda() ? BackendComponent::CUDABit : BackendComponent::CPUBit)), + generate_nested_key_set(buffer), buffer.dtype(), buffer.device()), buffer_(std::move(buffer)), nested_size_tensor_(std::move(nested_size_tensor)), + nested_stride_tensor_(std::move(nested_stride_tensor)), + offsets_(offsets), opt_sizes_(construct_opt_sizes(nested_size_tensor_)) { TORCH_WARN_ONCE( @@ -54,10 +131,27 @@ NestedTensorImpl::NestedTensorImpl( TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous()); int64_t size_dim = nested_size_tensor_.dim(); TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2); + TORCH_INTERNAL_ASSERT(nested_stride_tensor_.is_contiguous()); + TORCH_INTERNAL_ASSERT(nested_stride_tensor_.dim() == size_dim); + TORCH_INTERNAL_ASSERT(nested_stride_tensor_.sizes() == nested_size_tensor_.sizes()); + TORCH_INTERNAL_ASSERT((size_dim == 0 && (int64_t)offsets_.empty()) + || (size_dim == 2 && nested_size_tensor_.size(0) == (int64_t)offsets_.size())); refresh_dim(); set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes); } +// assume contiguous, `nested_stride_tensor` and `offsets` +// can be infered from `nested_size_tensor` +NestedTensorImpl::NestedTensorImpl( + at::Tensor buffer, + at::Tensor nested_size_tensor) + : NestedTensorImpl( + buffer, + nested_size_tensor, + construct_nested_stride_tensor(nested_size_tensor), + construct_offsets(nested_size_tensor)) +{} + void NestedTensorImpl::refresh_dim() { const auto my_dim = nested_size_tensor_.dim() ? nested_size_tensor_.sizes()[1] + 1 : 1; sizes_and_strides_.resize(my_dim); @@ -67,9 +161,32 @@ void NestedTensorImpl::refresh_dim() { int64_t NestedTensorImpl::dim_custom() const { return dim_default(); } + +// Currently sizes and strides assume contiguous int64_t NestedTensorImpl::numel_custom() const { - TORCH_CHECK(false, "numel is disabled."); + if (nested_size_tensor_.dim() == 0) { + return 0; + } + constexpr auto numel_max = std::min( + static_cast(std::numeric_limits::max()), + static_cast(std::numeric_limits::max())); + + const auto nt_dim = nested_size_tensor_.size(1); + const int64_t* sizes_ptr = nested_size_tensor_.data_ptr(); + uint64_t num_elements{0}; + + for (const auto i : c10::irange(nested_size_tensor_.size(0))) { + uint64_t n = 1; + const auto start{sizes_ptr + i * nt_dim}; + const auto end{start + nt_dim}; + bool overflows = c10::safe_multiplies_u64(start, end, &n); + num_elements += n; + overflows |= (num_elements > numel_max); + TORCH_CHECK(!overflows, "numel: integer multiplication overflow"); + } + return static_cast(num_elements); } + bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const { TORCH_CHECK(false, "is_contiguous is disabled."); } diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index 26e76aad22e42..47f6c1516b9d5 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -11,6 +11,13 @@ namespace at { namespace native { struct TORCH_API NestedTensorImpl : public c10::TensorImpl { + explicit NestedTensorImpl( + at::Tensor buffer, + at::Tensor nested_size_tensor, + at::Tensor nested_stride_tensor, + const std::vector& offsets); + // assume contiguous, `nested_stride_tensor` and `offsets` + // can be infered from `nested_size_tensor` explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor); // TODO: don't expose private implementation details like this; in @@ -19,6 +26,13 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { const Tensor& get_nested_size_tensor() const { return nested_size_tensor_; } + // TODO: don't expose private implementation details like this + const Tensor& get_nested_stride_tensor() const { + return nested_stride_tensor_; + } + const std::vector& get_offsets() const { + return offsets_; + } // Returns nullopt if the ith dimension is irregular. The ith dimension // of a NestedTensor is regular if the unbound tensors match in // size at the (i-1)th dimension. @@ -30,6 +44,16 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { return opt_sizes_[d]; } + int64_t size(int64_t d) const { + c10::optional optional_size = this->opt_size(d); + TORCH_CHECK( + optional_size.has_value(), + "Given dimension ", + d, + " is irregular and does not have a size."); + return *optional_size; + } + const at::Tensor& get_buffer() const { return buffer_; } @@ -41,6 +65,12 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { // with real implementations int64_t numel_custom() const override; bool is_contiguous_custom(MemoryFormat) const override; + int64_t size_custom(int64_t d) const override { + return this->size(d); + } + c10::SymInt sym_size_custom(int64_t d) const override { + return c10::SymInt{this->size(d)}; + } IntArrayRef sizes_custom() const override; c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymIntArrayRef sym_sizes() const override; @@ -55,8 +85,23 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { void refresh_dim(); at::Tensor buffer_; - const at::Tensor nested_size_tensor_; + const at::Tensor nested_size_tensor_, nested_stride_tensor_; + // The starting positions of the underlying tensors in contiguous buffer + // i.e. the buffer memory offsets to get the underlying tensors + // The reason to keep this metadata is that, without strong enough constraint + // it cannot be derived from `nested_size_tensor_` + // and `nested_stride_tensor_`: + // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2] + // this can happen e.g. after slicing a nested tensor + // 2. when multiple tensors share a same memory + // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2] + // Some strong enough constraints are: + // 1. every underlying tensor is contiguous in memory + // && nesting in ascending order + std::vector offsets_; // NOTE: -1 here means the size is missing + // TODO: maybe we can remove this metadata since + // we can compute it from `nested_size_tensor_` std::vector opt_sizes_; }; @@ -74,11 +119,60 @@ inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) { return static_cast(tensor.unsafeGetTensorImpl()); } -// TODO: real implementation once we support strides. -inline bool nested_tensor_impl_is_contiguous( - const NestedTensorImpl* nt, - at::MemoryFormat memory_format = MemoryFormat::Contiguous) { - return memory_format == MemoryFormat::Contiguous; +inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) { + int64_t ntensors = nt->size(0); + if (ntensors == 0) { + return true; + } + const Tensor &sizemat = nt->get_nested_size_tensor(), + &stridemat = nt->get_nested_stride_tensor(); + const auto& offsets = nt->get_offsets(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars + if (orig_dim == 0) { + // each scalar must be contiguous + // if there is blanck memory between underlying scalars + for (int64_t i = 0; i < ntensors; i++) { + if (offsets[i] != i) { + return false; + } + } + } + // nesting tensors + else { + // if any underlying tensor is noncontiguous + const int64_t *sizemat_ptr = sizemat.data_ptr(), + *stridemat_ptr = stridemat.data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + if (stridemat_ptr[orig_dim - 1] != 1) { + return false; + } + int64_t product = sizemat_ptr[orig_dim - 1]; + for (int64_t j = orig_dim - 2; j >= 0; j--) { + if (stridemat_ptr[j] != product) { + return false; + } + product *= sizemat_ptr[j]; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + // if there is blanck memory between underlying tensors + if (offsets[0] != 0) { + return false; + } + sizemat_ptr = sizemat.data_ptr(); + stridemat_ptr = stridemat.data_ptr(); + for (int64_t i = 1; i < ntensors; i++) { + if (offsets[i] != offsets[i - 1] + *sizemat_ptr * *stridemat_ptr) { + return false; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + } + // everything is fine + return true; } inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) { diff --git a/aten/src/ATen/OpMathType.h b/aten/src/ATen/OpMathType.h index b01a706e7c2b4..f08e420692569 100644 --- a/aten/src/ATen/OpMathType.h +++ b/aten/src/ATen/OpMathType.h @@ -30,7 +30,7 @@ using opmath_type = typename OpMathType::type; namespace { -c10::ScalarType toOpMathType(const c10::ScalarType type) { +inline c10::ScalarType toOpMathType(const c10::ScalarType type) { switch (type) { #define DEFINE_CASE(scalar_t, TypeNum) \ case ScalarType::TypeNum: \ diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h index 9624b987ba6b5..878c465962b86 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.h +++ b/aten/src/ATen/SparseCsrTensorImpl.h @@ -57,6 +57,23 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl { return col_indices_.size(-1); } + inline int64_t batch_dim() const noexcept { + return crow_indices_.dim() - 1; + } + + inline int64_t sparse_dim() const noexcept { + return 2; + } + + inline int64_t dense_dim() const noexcept { + return values_.dim() - batch_dim() - block_dim() - 1; + } + + private: + inline int64_t block_dim() const noexcept { + return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0); + } + protected: IntArrayRef strides_custom() const override; diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h index daec805bc58d8..24b5ae47df7d6 100644 --- a/aten/src/ATen/SparseCsrTensorUtils.h +++ b/aten/src/ATen/SparseCsrTensorUtils.h @@ -181,6 +181,38 @@ inline std::string plainIndicesName(Layout layout) { [&] { return "row_indices"; }); } +inline std::string compressedDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "row"; + case kSparseCsc: + return "column"; + case kSparseBsr: + return "row block"; + case kSparseBsc: + return "column block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline std::string plainDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "column"; + case kSparseCsc: + return "row"; + case kSparseBsr: + return "column block"; + case kSparseBsc: + return "row block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + inline int rowDimension(Layout layout, IntArrayRef size) { return size.size() - (isCompressedRow(layout) ? 2 : 1); } diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 9fa5c70a409bd..03999da97312b 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -7,15 +7,10 @@ namespace at { namespace { DeviceType sparseTensorSetToDeviceType(DispatchKeySet key_set) { - if (key_set.has(DispatchKey::SparseCPU)) { - return kCPU; - } else if (key_set.has(DispatchKey::SparseXPU)) { - return kXPU; - } else if (key_set.has(DispatchKey::SparseCUDA)) { - return kCUDA; - } else { - AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", key_set); - } + auto k = c10::highestPriorityBackendTypeId(key_set); + TORCH_CHECK(c10::toFunctionalityKey(k) == DispatchKey::Sparse, + "cannot create sparse tensor with non sparse dispatch key ", k); + return c10::dispatchKeyToDeviceType(k); } } diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 6f0c76729b912..f1eedfa83ef9b 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -640,7 +640,7 @@ static inline Tensor get_item( tensorIndices, disable_slice_optimization, self_device, - *self_sizes); + self_sizes); if (tensorIndices.empty()) { if (sliced.is_same(self)) { // ensure we return a shallow copy for things like x[...] diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 907ec8c5c57d4..a4715a2caabb3 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -946,6 +946,8 @@ void TensorIteratorBase::build_ternary_op( const TensorBase& out, const TensorBase& a, const TensorBase& b, const TensorBase& c) { build(TensorIteratorConfig() + .promote_inputs_to_common_dtype(true) + .enforce_safe_casting_to_output(true) .add_owned_output(out) .add_owned_input(a) .add_owned_input(b) diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/VmapTransforms.cpp index 4bda903545fdf..20c792f73709b 100644 --- a/aten/src/ATen/VmapTransforms.cpp +++ b/aten/src/ATen/VmapTransforms.cpp @@ -55,13 +55,20 @@ int64_t VmapPhysicalView::numLogicalDims() const { return /*physical*/tensor_.dim() - numBatchDims(); } -VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const { +VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const { auto logical_ndim = numLogicalDims(); // NB: fmap doesn't have a SmallVector variant, so we don't use it here. VmapDimVector result; result.reserve(logical_ndim); - for (auto dim : logical_dims) { - result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims()); + if (opt_logical_dims.has_value()) { + auto logical_dims = opt_logical_dims.value(); + for (auto dim : logical_dims) { + result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims()); + } + } else { + for (int64_t dim = 0; dim < logical_ndim; dim++) { + result.push_back(dim + numBatchDims()); + } } return result; } diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index 89190265a9563..53e476e2243fa 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView { // This is because the size of levels tell us that the first two dimensions // of `tensor_` are batch dimensions, so a logical dim of `n` is actually // a physical dim of `n + 2`. - VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; + VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const; int64_t getPhysicalDim(int64_t logical_dim) const; // Returns a VmapPhysicalToLogicalMap object. This can be used for diff --git a/aten/src/ATen/WrapDimUtilsMulti.h b/aten/src/ATen/WrapDimUtilsMulti.h index 64975e9a43c47..c1899bea872c9 100644 --- a/aten/src/ATen/WrapDimUtilsMulti.h +++ b/aten/src/ATen/WrapDimUtilsMulti.h @@ -14,7 +14,7 @@ namespace at { constexpr size_t dim_bitset_size = 64; static inline std::bitset dim_list_to_bitset( - IntArrayRef dims, + OptionalIntArrayRef opt_dims, int64_t ndims) { TORCH_CHECK( ndims <= (int64_t)dim_bitset_size, @@ -22,11 +22,21 @@ static inline std::bitset dim_list_to_bitset( dim_bitset_size, " dims are supported"); std::bitset seen; - for (const auto i : c10::irange(dims.size())) { - size_t dim = maybe_wrap_dim(dims[i], ndims); - TORCH_CHECK( - !seen[dim], "dim ", dim, " appears multiple times in the list of dims"); - seen[dim] = true; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + for (const auto i : c10::irange(dims.size())) { + size_t dim = maybe_wrap_dim(dims[i], ndims); + TORCH_CHECK( + !seen[dim], + "dim ", + dim, + " appears multiple times in the list of dims"); + seen[dim] = true; + } + } else { + for (int64_t dim = 0; dim < ndims; dim++) { + seen[dim] = true; + } } return seen; } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 1ebcf775a548d..da0a87b02d1d0 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { // KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype) // KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype) - KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, OptionalIntArrayRef, bool, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype) // fp32_append_dtype // The fp32_append_dtype wrapper overrides implicit promotion behavior. @@ -483,11 +483,15 @@ TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } + TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { // lower_precision_fp cast policy KERNEL_CPU(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(conv1d), "conv1d.padding", Tensor (const Tensor&, const Tensor&, const c10::optional&, IntArrayRef, c10::string_view, IntArrayRef, int64_t groups), lower_precision_fp) KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(conv2d), "conv2d.padding", Tensor (const Tensor&, const Tensor&, const c10::optional&, IntArrayRef, c10::string_view, IntArrayRef, int64_t groups), lower_precision_fp) KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) + KERNEL_CPU(ADD_NS(conv3d), "conv3d.padding", Tensor (const Tensor&, const Tensor&, const c10::optional&, IntArrayRef, c10::string_view, IntArrayRef, int64_t groups), lower_precision_fp) KERNEL_CPU(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) KERNEL_CPU(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) KERNEL_CPU(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index bb675939b27c6..a9ae2f12c4dd0 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -467,6 +467,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("sum.IntList_out", CppFunction::makeFallthrough()); m.impl("sum.dim_DimnameList", CppFunction::makeFallthrough()); m.impl("sum.dim_IntList", CppFunction::makeFallthrough()); + m.impl("sum.SymInt", CppFunction::makeFallthrough()); m.impl("t", CppFunction::makeFallthrough()); m.impl("tan", CppFunction::makeFallthrough()); m.impl("tan.out", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index f9f3d6ff7f830..37b46ae15a3c0 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -53,7 +53,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch const auto& maybe_torch_dispatch_mode_state = at::impl::TorchDispatchModeTLS::get_state(); if (maybe_torch_dispatch_mode_state) { - maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack, maybe_torch_dispatch_mode_state); + maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack); return; } @@ -69,7 +69,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { if (ivalue.isTensor()) { auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { - interpreter->dispatch(op, stack, nullptr); + interpreter->dispatch(op, stack); return; } } else if (ivalue.isTensorList()) { @@ -78,7 +78,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { for (const auto& nv : ivalue.toListRef()) { auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { - interpreter->dispatch(op, stack, nullptr); + interpreter->dispatch(op, stack); return; } } diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 094981478577a..ca9c8b5f245a0 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -157,18 +158,11 @@ class TORCH_API TensorBase { } c10::SymInt sym_size(int64_t dim) const { - const auto sizes = this->sym_sizes(); - const auto ndim = static_cast(sizes.size()); - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; - + return impl_->sym_size(dim); } int64_t size(int64_t dim) const { - const auto sizes = this->sizes(); - const auto ndim = static_cast(sizes.size()); - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; + return impl_->size(dim); } int64_t stride(int64_t dim) const { @@ -352,12 +346,12 @@ class TORCH_API TensorBase { } /// Returns a `Tensor`'s layout. - Layout layout() const noexcept { + Layout layout() const { return impl_->layout(); } /// Returns a `Tensor`'s dtype (`TypeMeta`). - caffe2::TypeMeta dtype() const noexcept { + caffe2::TypeMeta dtype() const { return impl_->dtype(); } @@ -725,9 +719,9 @@ class TORCH_API TensorBase { //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template - using hook_return_void_t = std::enable_if_t::type>::value, unsigned>; + using hook_return_void_t = std::enable_if_t>::value, unsigned>; template - using hook_return_var_t = std::enable_if_t::type, TensorBase>::value, unsigned>; + using hook_return_var_t = std::enable_if_t, TensorBase>::value, unsigned>; /// Registers a backward hook. /// @@ -897,58 +891,7 @@ struct MaybeOwnedTraits { }; template <> -struct ExclusivelyOwnedTraits { - using repr_type = at::TensorBase; - using pointer_type = at::TensorBase*; - using const_pointer_type = const at::TensorBase*; - - static repr_type nullRepr() { - return at::TensorBase(); - } - - template - static repr_type createInPlace(Args&&... args) { - return at::TensorBase(std::forward(args)...); - } - - static repr_type moveToRepr(at::TensorBase&& x) { - return std::move(x); - } - - static void destroyOwned(at::TensorBase& x) { - TensorImpl*const toDestroy = x.unsafeReleaseTensorImpl(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); - // May be 0 because UndefinedTensorImpl doesn't get its refcount - // incremented. - const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined), - "ExclusivelyOwned destroyed with isUndefined ", isUndefined, " and refcount ", toDestroy->refcount_, ", expected 1 or, if isUndefined, 0!"); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - toDestroy->weakcount_ == 1 || (toDestroy->weakcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()), - "ExclusivelyOwned destroyed with isUndefined ", isUndefined, " and weakcount ", toDestroy->weakcount_, ", expected 1 or, if isUndefined, 0!"); - if (!isUndefined) { -#ifndef NDEBUG - // Needed to pass the debug assertions in ~intrusive_ptr_target. - toDestroy->refcount_ = 0; - toDestroy->weakcount_ = 0; -#endif - delete toDestroy; - } - } - - static at::TensorBase take(at::TensorBase& x) { - return std::move(x); - } - - static pointer_type getImpl(repr_type& x) { - return &x; - } - - static const_pointer_type getImpl(const repr_type& x) { - return &x; - } -}; +struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; } // namespace c10 namespace at { diff --git a/aten/src/ATen/core/boxing/BoxedKernel.h b/aten/src/ATen/core/boxing/BoxedKernel.h new file mode 100644 index 0000000000000..829031f423eb2 --- /dev/null +++ b/aten/src/ATen/core/boxing/BoxedKernel.h @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +struct IValue; +using Stack = std::vector; + +class OperatorHandle; +class KernelFunction; + +// This kernel implements the behavior of falling through to the next available +// registered dispatch key. The implementation of this function is FAST; it is +// no overhead to fallthrough to the next key. See cpp file for some more +// implementation notes; notably, this does NOT actually go through the +// boxing/unboxing codepath. +TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + +// Note [Ambiguity in AutogradOther kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This error-reporting kernel is registered to the AutogradOther entry in the +// dispatch table when there is both a CompositeImplicitAutograd kernel and a +// backend kernel for ANY backend that maps to AutogradOther. To see why +// this is necessary in the AutogradOther case, it's helpful to first see +// why everything works out fine for a backend that has a reserved Autograd +// entry (see rule 2.2 in [Note] DispatchTable computation): +// +// CPU AutogradCPU +// reg? registers with... +// ------------------------------------------------- +// y Autograd registration takes precedence +// over CompositeImplicitAutograd. +// This is good, because the CPU specific backend +// implementation is more specialized and typically better; +// if we used the composite, we would bypass it. +// (NB: the Autograd key is guaranteed to exist because +// the autograd codegen requires it!) +// +// n CompositeImplicitAutograd takes precedence. +// This is also good, because the Autograd +// registration (if it exists) would try to redispatch +// to the (non-existent) CPU implementation; by +// using the composite, we ensure the operator +// actually works. +// +// As you can see, when we have a specific Autograd key (AutogradCPU), we can +// decide whether or not to use the CompositeImplicitAutograd kernel or the +// Autograd kernel based on whether or not the backend kernel exists. +// +// However, for AutogradOther (which is the catchall autograd kernel for +// everything that doesn't have a specific Autograd key), we can't do this +// trick because there isn't any unique backend to peek at to disambiguate; +// if there are some backends that have implementations they prefer Autograd, +// but unimplemented backends would prefer CompositeImplicitAutograd. Rather +// than arbitrarily pick one or the other, we just register a kernel that raises +// an error and let the user decide how to proceed. +TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + +// Note [named_not_supported_kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This kernel implements reporting an error message saying that named tensor is +// not supported. This kernel doesn't rely on the Stack, and so it is special +// cased in the dispatcher to be triggered before we attempt boxing (so we can +// give a good error message in cases when boxing is not supported). When +// boxing is universally supported this can be removed. +[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + +/** + * BoxedKernel is similar to a std::function storing a boxed kernel. + */ +class TORCH_API BoxedKernel final { +public: + // This is how boxed kernels are actually stored + // + // Note [Plumbing Keys Through The Dispatcher] + // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS) + // upon every dispatch call into order to compute which kernel to dispatch to. + // + // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores + // to have a first argument of type DispatchKeySet. + // + // What are the invariants of the DispatchKeySet when it gets passed to a kernel? + // - All keys to the left of the current dispatch key have been masked out. + // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer) + // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments + // are still in the set. + // + // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches: + // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will + // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher + // upon redispatching. + // + // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature + // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples. + // + // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h. + // See Note [Plumbing Keys Through The Dispatcher 2] for details. + using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + // This is the public API for how boxed kernels are defined + using BoxedKernelFunction = void(const OperatorHandle&, Stack*); + using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*); + + BoxedKernel(); + + // Fast path for dispatch to allow not touching the boxed kernel in + // the common case where unboxed is available. + bool isValid() const; + bool isFallthrough() const; + + /** + * Call the function with boxed arguments. + */ + void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const; + + /** + * Create a KernelFunction from a boxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>(); + */ + template + static BoxedKernel makeFromFunction(); + + /** + * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none) + * See Note [Plumbing Keys Through The Dispatcher] for details. + */ + template + static BoxedKernel makeFromFunction(); + + /** + * Create a KernelFunction from a boxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} + * > }; + * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique()); + */ + template + static BoxedKernel makeFromFunctor(std::unique_ptr kernelFunctor); + + + static BoxedKernel makeFallthrough(); + static BoxedKernel makeAmbiguousAutogradOther(); + static BoxedKernel makeNamedNotSupported(); + +private: + + friend class KernelFunction; + + template + static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack); + + template + static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack); + + explicit BoxedKernel(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func); + + OperatorKernel* getFunctor() const; + InternalBoxedKernelFunction* getFnPtr() const; + + c10::intrusive_ptr functor_; + InternalBoxedKernelFunction* boxed_kernel_func_; +}; + +} // namespace c10 + +#include diff --git a/aten/src/ATen/core/boxing/BoxedKernel_impl.h b/aten/src/ATen/core/boxing/BoxedKernel_impl.h new file mode 100644 index 0000000000000..421b85cca3ec5 --- /dev/null +++ b/aten/src/ATen/core/boxing/BoxedKernel_impl.h @@ -0,0 +1,99 @@ +#pragma once + +namespace c10 { + +inline BoxedKernel::BoxedKernel() + : functor_() +, boxed_kernel_func_(nullptr) +{} + +inline BoxedKernel::BoxedKernel(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func) +: functor_(std::move(functor)) +, boxed_kernel_func_(boxed_kernel_func) +{} + +template +inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) { + // Note that we're dropping the DispatchKeySet argument. + // See Note [Plumbing Keys Through The Dispatcher 2] for details. + func(opHandle, stack); +} + +template +inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) { + // See Note [Plumbing Keys Through The Dispatcher 2] for details. + func(opHandle, ks, stack); +} + +inline bool BoxedKernel::isValid() const { + return boxed_kernel_func_ != nullptr; +} + +inline bool BoxedKernel::isFallthrough() const { + return boxed_kernel_func_ == &fallthrough_kernel; +} + +inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + boxed_kernel_func_ != nullptr, + "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel." + ); + (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function + ); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function + ); +} + +inline BoxedKernel BoxedKernel::makeFallthrough() { + return BoxedKernel( + nullptr, // no functor_ object + &fallthrough_kernel + ); +} + +inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() { + return BoxedKernel( + nullptr, // no functor_ object + &ambiguous_autogradother_kernel + ); +} + +inline BoxedKernel BoxedKernel::makeNamedNotSupported() { + return BoxedKernel( + nullptr, // no functor_ object + &named_not_supported_kernel + ); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr kernelFunctor) { + static_assert(std::is_base_of::value, "Tried to call BoxedKernel::makeFromFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + return BoxedKernel( + std::move(kernelFunctor), + [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) { + (*static_cast(kernel))(op, ks, stack); + } + ); +} + +inline OperatorKernel* BoxedKernel::getFunctor() const { + return functor_.get(); +} +inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const { + return boxed_kernel_func_; +} + +} // namespace c10 diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp index b0bc48f7b2564..90bf14aa14726 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction.cpp @@ -44,10 +44,11 @@ void named_not_supported_kernel(OperatorKernel*, const OperatorHandle& op, Dispa // single line summary of state std::string KernelFunction::dumpState() const { std::ostringstream oss; - if (boxed_kernel_func_ == fallthrough_kernel) { + auto boxed_kernel_fn = boxed_kernel_func_.getFnPtr(); + if (boxed_kernel_fn == fallthrough_kernel) { oss << "fallthrough "; } - if (boxed_kernel_func_) { + if (boxed_kernel_fn) { oss << "boxed "; } if (unboxed_kernel_func_) { @@ -57,7 +58,7 @@ std::string KernelFunction::dumpState() const { } bool KernelFunction::_equalsBoxedAndUnboxed(const KernelFunction& other) const { - return boxed_kernel_func_ == other.boxed_kernel_func_ && + return boxed_kernel_func_.getFnPtr() == other.boxed_kernel_func_.getFnPtr() && unboxed_kernel_func_ == other.unboxed_kernel_func_; } diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h index 9441f47e0d63a..8ab34e95046ab 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.h +++ b/aten/src/ATen/core/boxing/KernelFunction.h @@ -1,6 +1,8 @@ #pragma once +#include #include +#include #include #include @@ -10,62 +12,7 @@ using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack class OperatorHandle; struct OperatorKernel; - -// This kernel implements the behavior of falling through to the next available -// registered dispatch key. The implementation of this function is FAST; it is -// no overhead to fallthrough to the next key. See cpp file for some more -// implementation notes; notably, this does NOT actually go through the -// boxing/unboxing codepath. -TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); - -// Note [Ambiguity in AutogradOther kernel] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// This error-reporting kernel is registered to the AutogradOther entry in the -// dispatch table when there is both a CompositeImplicitAutograd kernel and a -// backend kernel for ANY backend that maps to AutogradOther. To see why -// this is necessary in the AutogradOther case, it's helpful to first see -// why everything works out fine for a backend that has a reserved Autograd -// entry (see rule 2.2 in [Note] DispatchTable computation): -// -// CPU AutogradCPU -// reg? registers with... -// ------------------------------------------------- -// y Autograd registration takes precedence -// over CompositeImplicitAutograd. -// This is good, because the CPU specific backend -// implementation is more specialized and typically better; -// if we used the composite, we would bypass it. -// (NB: the Autograd key is guaranteed to exist because -// the autograd codegen requires it!) -// -// n CompositeImplicitAutograd takes precedence. -// This is also good, because the Autograd -// registration (if it exists) would try to redispatch -// to the (non-existent) CPU implementation; by -// using the composite, we ensure the operator -// actually works. -// -// As you can see, when we have a specific Autograd key (AutogradCPU), we can -// decide whether or not to use the CompositeImplicitAutograd kernel or the -// Autograd kernel based on whether or not the backend kernel exists. -// -// However, for AutogradOther (which is the catchall autograd kernel for -// everything that doesn't have a specific Autograd key), we can't do this -// trick because there isn't any unique backend to peek at to disambiguate; -// if there are some backends that have implementations they prefer Autograd, -// but unimplemented backends would prefer CompositeImplicitAutograd. Rather -// than arbitrarily pick one or the other, we just register a kernel that raises -// an error and let the user decide how to proceed. -TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); - -// Note [named_not_supported_kernel] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// This kernel implements reporting an error message saying that named tensor is -// not supported. This kernel doesn't rely on the Stack, and so it is special -// cased in the dispatcher to be triggered before we attempt boxing (so we can -// give a good error message in cases when boxing is not supported). When -// boxing is universally supported this can be removed. -[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); +class KernelFunction; /** * KernelFunction is similar to std::function but stores a kernel function. @@ -75,35 +22,9 @@ TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHan */ class TORCH_API KernelFunction final { public: - // This is how boxed kernels are actually stored - // - // Note [Plumbing Keys Through The Dispatcher] - // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS) - // upon every dispatch call into order to compute which kernel to dispatch to. - // - // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores - // to have a first argument of type DispatchKeySet. - // - // What are the invariants of the DispatchKeySet when it gets passed to a kernel? - // - All keys to the left of the current dispatch key have been masked out. - // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer) - // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments - // are still in the set. - // - // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches: - // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will - // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher - // upon redispatching. - // - // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature - // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples. - // - // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h. - // See Note [Plumbing Keys Through The Dispatcher 2] for details. - using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); - // This is the public API for how boxed kernels are defined - using BoxedKernelFunction = void(const OperatorHandle&, Stack*); - using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*); + using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction; + using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction; + using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys; KernelFunction(); @@ -155,6 +76,11 @@ class TORCH_API KernelFunction final { template Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const; + /** + * Create a KernelFunction from a BoxedKernel. + */ + static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn); + /** * Create a KernelFunction from a boxed function. * @@ -234,12 +160,6 @@ class TORCH_API KernelFunction final { static KernelFunction makeAmbiguousAutogradOther(); static KernelFunction makeNamedNotSupported(); - template - static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack); - - template - static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack); - /** * Create a KernelFunction from an unboxed lambda. * @@ -259,13 +179,15 @@ class TORCH_API KernelFunction final { private: - explicit KernelFunction(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func); - - OperatorKernel* getFunctor_() const; - - c10::intrusive_ptr functor_; + explicit KernelFunction( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func, + void* unboxed_kernel_func); + explicit KernelFunction( + BoxedKernel boxed_fn, + void* unboxed_kernel_func); - InternalBoxedKernelFunction* boxed_kernel_func_; + BoxedKernel boxed_kernel_func_; void* unboxed_kernel_func_; }; diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index 01537c2dc4710..c33175e4b99ab 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -6,48 +6,34 @@ namespace c10 { inline KernelFunction::KernelFunction() - : functor_() -, boxed_kernel_func_(nullptr) -, unboxed_kernel_func_(nullptr) + : boxed_kernel_func_() + , unboxed_kernel_func_(nullptr) {} inline KernelFunction::KernelFunction(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func) -: functor_(std::move(functor)) -, boxed_kernel_func_(boxed_kernel_func) -, unboxed_kernel_func_(unboxed_kernel_func) + : boxed_kernel_func_(std::move(functor), boxed_kernel_func) + , unboxed_kernel_func_(unboxed_kernel_func) {} -template -inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) { - // Note that we're dropping the DispatchKeySet argument. - // See Note [Plumbing Keys Through The Dispatcher 2] for details. - func(opHandle, stack); -} +inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func) + : boxed_kernel_func_(std::move(boxed_fn)) + , unboxed_kernel_func_(unboxed_kernel_func) +{} inline bool KernelFunction::isValidUnboxed() const { - return unboxed_kernel_func_ != nullptr; -} - -template -inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) { - // See Note [Plumbing Keys Through The Dispatcher 2] for details. - func(opHandle, ks, stack); + return unboxed_kernel_func_ != nullptr; } inline bool KernelFunction::isValid() const { - return boxed_kernel_func_ != nullptr; + return boxed_kernel_func_.isValid(); } inline bool KernelFunction::isFallthrough() const { - return boxed_kernel_func_ == &fallthrough_kernel; + return boxed_kernel_func_.isFallthrough(); } inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - boxed_kernel_func_ != nullptr, - "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction." - ); - (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack); + boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); } template @@ -64,63 +50,48 @@ C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, Di // want callers to explicitly specify the Args. if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { - return callUnboxedKernelFunction(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward(args)...); + auto *functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction( + unboxed_kernel_func_, functor, dispatchKeySet, std::forward(args)...); } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - boxed_kernel_func_ != nullptr, - "Tried to call KernelFunction::call() on an uninitialized KernelFunction." - ); - return impl::BoxedKernelWrapper::call( boxed_kernel_func_, - functor_.get(), opHandle, dispatchKeySet, std::forward(args)... ); } +inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) { + return KernelFunction(std::move(boxed_fn), nullptr); // no unboxed function pointer +} + template inline KernelFunction KernelFunction::makeFromBoxedFunction() { - return KernelFunction( - nullptr, // no functor_ object - &make_boxed_function, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); } template inline KernelFunction KernelFunction::makeFromBoxedFunction() { - return KernelFunction( - nullptr, // no functor_ object - &make_boxed_function, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); } inline KernelFunction KernelFunction::makeFallthrough() { - return KernelFunction( - nullptr, // no functor_ object - &fallthrough_kernel, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFallthrough()); } inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { - return KernelFunction( - nullptr, // no functor_ object - &ambiguous_autogradother_kernel, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeAmbiguousAutogradOther()); } inline KernelFunction KernelFunction::makeNamedNotSupported() { - return KernelFunction( - nullptr, // no functor_ object - &named_not_supported_kernel, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeNamedNotSupported()); } template @@ -140,14 +111,8 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr kernelFunctor) { - static_assert(std::is_base_of::value, "Tried to call KernelFunction::makeFromBoxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); - return KernelFunction( - std::move(kernelFunctor), - [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) { - (*static_cast(kernel))(op, ks, stack); - }, - nullptr // no unboxed function pointer - ); + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); } template diff --git a/aten/src/ATen/core/boxing/OperatorKernel.h b/aten/src/ATen/core/boxing/OperatorKernel.h new file mode 100644 index 0000000000000..ac4f06a91c474 --- /dev/null +++ b/aten/src/ATen/core/boxing/OperatorKernel.h @@ -0,0 +1,27 @@ +#pragma once +#include + +namespace c10 { + +/** + * Inherit from OperatorKernel to implement a c10 kernel. + * + * Example: + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * + * The kernel class is allowed to have members but these are equivalent + * to global variables. The kernel implementation is responsible for + * preventing race conditions on them. + * + * See below for how to register this kernel with PyTorch. + */ +struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target { + virtual ~OperatorKernel() = default; +}; + +} // namespace c10 diff --git a/aten/src/ATen/core/boxing/impl/boxing.h b/aten/src/ATen/core/boxing/impl/boxing.h index b16f01798c6c7..ccac9ebe8f61b 100644 --- a/aten/src/ATen/core/boxing/impl/boxing.h +++ b/aten/src/ATen/core/boxing/impl/boxing.h @@ -4,9 +4,10 @@ // i.e. how to make a vector from a set of concrete arguments. #include +#include #include -#include +#include #include @@ -217,14 +218,13 @@ struct BoxedKernelWrapper< > > { static Result call( - KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, - OperatorKernel* functor, + const BoxedKernel& boxed_kernel_func, const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args ) { torch::jit::Stack stack = boxArgs(std::forward(args)...); - (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); return guts::if_constexpr::value>( [&] (auto delay_check) { @@ -258,14 +258,13 @@ struct BoxedKernelWrapper< std::enable_if_t::value, void> > { static at::Tensor& call( - KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, - OperatorKernel* functor, + const BoxedKernel& boxed_kernel_func, const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, at::Tensor& outArg, OtherArgs... otherArgs ) { torch::jit::Stack stack = boxArgs(outArg, std::forward(otherArgs)...); - (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( stack.size() == 1, "Boxed kernel was expected to return a single value on the stack, ", @@ -285,14 +284,13 @@ struct BoxedKernelWrapper< std::enable_if_t::value, void> > { static const at::Tensor& call( - KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, - OperatorKernel* functor, + const BoxedKernel& boxed_kernel_func, const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, const at::Tensor& outArg, OtherArgs... otherArgs ) { torch::jit::Stack stack = boxArgs(outArg, otherArgs...); - (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( stack.size() == 1, "Boxed kernel was expected to return a single value on the stack, ", @@ -323,14 +321,13 @@ struct BoxedKernelWrapper< > > { static at::Tensor& call( - KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, - OperatorKernel* functor, + const BoxedKernel& boxed_kernel_func, const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, FirstArg firstArg, RestArgs... restArgs ) { torch::jit::Stack stack = boxArgs(std::forward(firstArg), std::forward(restArgs)...); - (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( stack.size() == 1, "Boxed kernel was expected to return a single value on the stack, ", @@ -360,8 +357,7 @@ struct BoxedKernelWrapper< > > { static Result call( - KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, - OperatorKernel* functor, + const BoxedKernel& boxed_kernel_func, const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args @@ -370,7 +366,7 @@ struct BoxedKernelWrapper< constexpr int RetCount = std::tuple_size(); torch::jit::Stack stack = boxArgs(std::forward(args)...); - (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( stack.size() == RetCount, "Boxed kernel was expected to return ", RetCount, " values on the stack, ", diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 2b2228bb944da..0a28330a0bfb5 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -64,27 +65,6 @@ class OperatorHandle; * the expected operator signature at each call site. */ -/** - * Inherit from OperatorKernel to implement a c10 kernel. - * - * Example: - * > namespace { - * > class my_kernel_cpu final : public c10::OperatorKernel { - * > public: - * > Tensor operator()(Tensor a, Tensor b) {...} - * > }; - * > } - * - * The kernel class is allowed to have members but these are equivalent - * to global variables. The kernel implementation is responsible for - * preventing race conditions on them. - * - * See below for how to register this kernel with PyTorch. - */ -struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target { - virtual ~OperatorKernel() = default; -}; - namespace impl { // supported_primitive_arg_types defines which primitive types we allow in // kernel functions as arguments or returns. diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 3385bff63a11a..667eefdcc5ab8 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -4,6 +4,11 @@ namespace c10 { +bool show_dispatch_trace() { + static char const* temp = getenv("TORCH_SHOW_DISPATCH_TRACE"); + return temp != nullptr; +} + namespace detail { class RegistrationListenerList final { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 3ab0619c05ae9..83d1738da423b 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -18,6 +18,8 @@ namespace c10 { +TORCH_API bool show_dispatch_trace(); + class TORCH_API OperatorHandle; template class TypedOperatorHandle; @@ -557,6 +559,11 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor() .template getDispatchKeySetUnboxed(args...); +#ifndef NDEBUG + if (show_dispatch_trace()) { + std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; + } +#endif const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); @@ -572,6 +579,11 @@ template inline Return Dispatcher::redispatch(const TypedOperatorHandle& op, DispatchKeySet currentDispatchKeySet, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 // do not use RecordFunction on redispatch +#ifndef NDEBUG + if (show_dispatch_trace()) { + std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; + } +#endif const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet); return kernel.template call(op, currentDispatchKeySet, std::forward(args)...); } @@ -580,6 +592,11 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const // note: this doesn't need the mutex because write operations on the list keep iterators intact. const auto& entry = op.operatorDef_->op; auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); +#ifndef NDEBUG + if (show_dispatch_trace()) { + std::cerr << "[callBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; + } +#endif const auto& kernel = entry.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); @@ -606,6 +623,11 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const { // note: this doesn't need the mutex because write operations on the list keep iterators intact. const auto& entry = op.operatorDef_->op; +#ifndef NDEBUG + if (show_dispatch_trace()) { + std::cerr << "[redispatchBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; + } +#endif const auto& kernel = entry.lookup(dispatchKeySet); return kernel.callBoxed(op, dispatchKeySet, stack); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 1d94fce904086..afcf552fdecda 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -431,7 +431,7 @@ std::string OperatorEntry::listAllDispatchKeys() const { if (has_kernels) { str << ", "; } - str << static_cast(iter); + str << k; has_kernels = true; } str << "]"; diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp index a4319f03132cc..a3a10862178c8 100644 --- a/aten/src/ATen/core/function_schema.cpp +++ b/aten/src/ATen/core/function_schema.cpp @@ -1,6 +1,7 @@ #include #include +#include namespace c10 { @@ -8,4 +9,156 @@ void FunctionSchema::dump() const { std::cout << *this << "\n"; } +const std::vector& FunctionSchema::getCorrectList(SchemaArgType type) const { + if (type == SchemaArgType::input) { + return arguments(); + } else { + return returns(); + } } + +bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional &lhs, const c10::optional &rhs) const { + if (!lhs || !rhs) { + return false; + } + for (const TypePtr& lhsType : *lhs) { + for (const TypePtr& rhsType : *rhs) { + if (lhsType == rhsType) { + return true; + } + } + } + return false; +} + +c10::optional FunctionSchema::getAliasTypeSetContainedTypes(const c10::optional &aliasTypeSet) const { + if (!aliasTypeSet) { + return c10::nullopt; + } + std::unordered_set containedTypes; + std::stack typeStack; + // Push all 1st level contained types into the stack. + for (const TypePtr& type: *aliasTypeSet) { + for (const TypePtr& containedType : type->containedTypes()){ + typeStack.push(containedType); + } + } + + // process all further level contained types. + while (!typeStack.empty()) { + TypePtr current = typeStack.top(); + typeStack.pop(); + if (!containedTypes.count(current)) { + for (const TypePtr& containedType : current->containedTypes()) { + typeStack.push(containedType); + } + } + containedTypes.insert(current); + } + + return AliasTypeSet(containedTypes.begin(), containedTypes.end()); +} + +c10::optional FunctionSchema::mapTypeToAliasTypeSet(const TypePtr& type) const { + switch(type->kind()) { + case TypeKind::ListType: + case TypeKind::DictType: + case TypeKind::ClassType: + case TypeKind::TensorType: + return AliasTypeSet {c10::unshapedType(type)}; + case TypeKind::UnionType: { + AliasTypeSet mutable_types; + for (const TypePtr& inner : + type->expectRef().containedTypes()) { + if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) { + mutable_types.insert( + mutable_types.end(), + (*maybe_inner_types).begin(), + (*maybe_inner_types).end()); + } + } + if (mutable_types.size() == 0) { + return c10::nullopt; + } + return mutable_types; + } + case TypeKind::AnyType: + return {AliasTypeSet{type}}; + case TypeKind::OptionalType: { + auto inner = type->castRaw()->getElementType(); + return mapTypeToAliasTypeSet(inner); + } + case TypeKind::TupleType: { + AliasTypeSet mutable_types; + for (const TypePtr& inner : type->expectRef().elements()) { + if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) { + mutable_types.insert( + mutable_types.end(), + (*maybe_inner_types).begin(), + (*maybe_inner_types).end()); + } + } + if (mutable_types.size() == 0) { + return c10::nullopt; + } + return {AliasTypeSet{TupleType::create(mutable_types)}}; + } + default: + return c10::nullopt; + } +} + +bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const { + TORCH_INTERNAL_ASSERT( + (lhs.index < getCorrectList(lhs.type).size()), + "Invalid index for schema."); + TORCH_INTERNAL_ASSERT( + (rhs.index < getCorrectList(rhs.type).size()), + "Invalid index for schema."); + + const Argument lhsArg = getCorrectList(lhs.type)[lhs.index]; + const Argument rhsArg = getCorrectList(rhs.type)[rhs.index]; + + c10::optional lhsTypes = mapTypeToAliasTypeSet(lhsArg.type()); + c10::optional rhsTypes = mapTypeToAliasTypeSet(rhsArg.type()); + + // Check to see if lhs and rhs have the same alias set + if (canAliasTypeSetsAlias(lhsTypes, rhsTypes)) { + if (lhsArg.alias_info() && rhsArg.alias_info()) { + for (const auto& lhsSet : lhsArg.alias_info()->afterSets()) { + for (const auto& rhsSet : rhsArg.alias_info()->afterSets()) { + if (lhsSet == rhsSet) { + return true; + } + } + } + } + } + + return false; +} + +bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional) const { + bool may_alias_result = may_alias(lhs, rhs); + if (may_alias_result) { + return true; + } + + const c10::Argument lhsArg = getCorrectList(lhs.type)[lhs.index]; + const c10::Argument rhsArg = getCorrectList(rhs.type)[rhs.index]; + c10::optional lhsTypes = mapTypeToAliasTypeSet(lhsArg.type()); + c10::optional rhsTypes = mapTypeToAliasTypeSet(rhsArg.type()); + c10::optional lhsContainedTypes = getAliasTypeSetContainedTypes(lhsTypes); + c10::optional rhsContainedTypes = getAliasTypeSetContainedTypes(rhsTypes); + + // Checks if one side is wildcard and the other side is a container of the same type + bool lhsWildcard = lhsArg.alias_info() && lhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(lhsTypes, rhsContainedTypes); + bool rhsWildcard = rhsArg.alias_info() && rhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(rhsTypes, lhsContainedTypes); + + if (bidirectional) { + return lhsWildcard || rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes); + } else { + return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes); + } +} +} // namespace c10 diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 6680d2543e26a..77fdb20f6516a 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -20,6 +20,8 @@ namespace c10 { struct Argument; struct FunctionSchema; +using AliasTypeSet = std::vector; + bool operator==(const Argument& lhs, const Argument& rhs); struct Argument { @@ -202,9 +204,24 @@ inline bool operator!=(const Argument& lhs, const Argument& rhs) { return !(lhs == rhs); } +enum struct TORCH_API SchemaArgType { input, output }; + +/** + * struct SchemaArgument + * + * Structure used to represent arguments or returns for a schema. + */ +struct TORCH_API SchemaArgument { + SchemaArgType type; + size_t index; + bool operator==(const SchemaArgument& rhs) const { + return type == rhs.type && index == rhs.index; + } +}; + bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs); -struct FunctionSchema { +struct TORCH_API FunctionSchema { FunctionSchema( std::string name, std::string overload_name, @@ -352,6 +369,13 @@ struct FunctionSchema { bool is_varret() const { return is_varret_; } + bool is_aliasing(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo; + } bool is_mutable() const { return std::any_of( arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) { @@ -359,6 +383,49 @@ struct FunctionSchema { return aliasInfo && aliasInfo->isWrite(); }); } + bool is_mutable(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo && aliasInfo->isWrite(); + } + bool is_mutable(c10::string_view name) const { + c10::optional index = argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index != c10::nullopt, "Schema has no argument named ", name); + + return is_mutable({c10::SchemaArgType::input, static_cast(*index)}); + } + + // Returns whether lhs and rhs may alias directly. + // This does not account for cases where lhs or rhs are a container that + // may contain elements that alias the other argument. + // FunctionSchema::may_contain_alias will include that functionality. + bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const; + + // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container + // that may contain elements that alias the other argument. + // bidirectional = false only returns whether lhs may contain an alias of rhs + // while bidirectional = true returns both directions. + bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const; + + // Returns whether the two AliasTypeSets contain any similarities + // ie: whether the two type sets can alias. + bool canAliasTypeSetsAlias(const c10::optional &lhs, const c10::optional &rhs) const; + + // Recursively Finds all contained types within the AliasTypeSet. + c10::optional getAliasTypeSetContainedTypes(const c10::optional &aliasTypeSet) const; + + // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp. + // Used to map types to a type such that all types that can alias will be mapped to the same type. + // For example, calling this method on 'Optional[List[int]]' is the same as calling this method + // on 'List[int]'. + c10::optional mapTypeToAliasTypeSet(const TypePtr& type) const; + + // Returns either arguments() or returns() depending on the SchemaArgType + // output => returns(), input => arguments() + const std::vector& getCorrectList(SchemaArgType type) const; c10::optional argumentIndexWithName(c10::string_view name) const { for (const auto i : c10::irange(arguments().size())) { @@ -547,4 +614,15 @@ inline std::string toString(const FunctionSchema& schema) { } // namespace c10 +namespace std { +template<> + struct hash { + size_t operator()(const c10::SchemaArgument& arg) const + { + return c10::hash_combine(std::hash()(arg.index), std::hash()(static_cast(arg.type))); + } + }; +} // namespace std + + #include // IWYU pragma: keep diff --git a/aten/src/ATen/core/interned_strings.cpp b/aten/src/ATen/core/interned_strings.cpp index f9956b3d8f9ad..0ad87c21c837f 100644 --- a/aten/src/ATen/core/interned_strings.cpp +++ b/aten/src/ATen/core/interned_strings.cpp @@ -140,6 +140,7 @@ bool Symbol::is_attr() const { return ns() == namespaces::attr; } bool Symbol::is_aten() const { return ns() == namespaces::aten; } bool Symbol::is_cuda() const { return ns() == namespaces::cuda; } bool Symbol::is_prim() const { return ns() == namespaces::prim; } +bool Symbol::is_prims() const { return ns() == namespaces::prims; } bool Symbol::is_onnx() const { return ns() == namespaces::onnx; } bool Symbol::is_user() const { return ns() == namespaces::user; } bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; } diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 26f48ee341d61..8a195128b4d2c 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -14,6 +14,7 @@ namespace c10 { #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ + _(namespaces, prims) \ _(namespaces, aten) \ _(namespaces, cuda) \ _(namespaces, onnx) \ diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index e9a0caecc5d65..8d0199b3c9546 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -555,8 +555,14 @@ struct TORCH_API IValue final { payload.u.as_int = i; } - IValue(c10::SymInt i) : tag(Tag::SymInt) { - payload.u.as_int = i.data(); + IValue(c10::SymInt i) { + if (i.is_symbolic()) { + tag = Tag::SymInt; + payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release(); + } else { + tag = Tag::Int; + payload.u.as_int = i.as_int_unchecked(); + } } IValue(c10::SymIntArrayRef v); @@ -565,9 +571,7 @@ struct TORCH_API IValue final { return Tag::SymInt == tag; } - c10::SymInt toSymInt() const { - return c10::SymInt(payload.u.as_int); - } + c10::SymInt toSymInt() const; // allow you to pass literals (3, 4) without ambiguity IValue(int32_t i) : IValue(static_cast(i)) {} @@ -1072,7 +1076,7 @@ struct TORCH_API IValue final { case Tag::Int: return false; case Tag::SymInt: - return false; + return true; case Tag::Bool: return false; case Tag::Tuple: diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 3d0f361df26ed..301b448b834eb 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -215,6 +215,14 @@ inline at::Generator IValue::toGenerator() const& { AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); return at::Generator(toIntrusivePtr()); } +inline c10::SymInt IValue::toSymInt() const { + AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); + if (isSymInt()) { + return c10::SymInt::toSymInt(toIntrusivePtr()); + } else { + return c10::SymInt(payload.u.as_int); + } +} namespace ivalue { @@ -991,7 +999,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { cb = std::move(callback)](Future& parentFut) mutable { try { guts::if_constexpr::type, + typename c10::invoke_result_t, IValueWithStorages>::value>( [&](auto identity) { IValue value; diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 42fad24762942..50b27a0e8fd8b 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1309,7 +1309,7 @@ struct TORCH_API SymIntType : public Type { return "SymInt"; } std::string annotation_str_impl(TypePrinter printer = nullptr) const override { - // TODO: will become a Union[SymbolicIntNode|int] in the near future + // TODO: will become a Union[SymIntNodeImpl|int] in the near future return "int"; } static const TypeKind Kind = TypeKind::SymIntType; diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 2e1c84db867ba..6fee9fe0a1138 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -247,7 +247,7 @@ struct TORCH_API Type { // nvcc; see comment in destroy() below. struct SharedPtrWrapper { SharedPtrWrapper(std::shared_ptr &&x) - : repr_(x) {} + : repr_(std::move(x)) {} std::shared_ptr repr_; }; union Repr { diff --git a/aten/src/ATen/core/symbol.h b/aten/src/ATen/core/symbol.h index 26493151b8abe..c06c261c3dd3c 100644 --- a/aten/src/ATen/core/symbol.h +++ b/aten/src/ATen/core/symbol.h @@ -81,6 +81,7 @@ struct TORCH_API Symbol { bool is_aten() const; bool is_cuda() const; bool is_prim() const; + bool is_prims() const; bool is_onnx() const; bool is_user() const; bool is_caffe2() const; diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index 24b8818d2a8d8..9d39142aad91c 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -1,5 +1,36 @@ +#pragma once + #if defined(CPU_CAPABILITY_AVX512) #include #else #include #endif + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline Vectorized convert_to_bool(Vectorized x) { + __at_align__ bool buffer[x.size()]; + x.ne(Vectorized(0)).store(buffer); + + Vectorized ret; + static_assert(x.size() == ret.size(), ""); + std::memcpy(ret, buffer, ret.size() * sizeof(bool)); + return ret; +} + +template <> +inline Vectorized Vectorized::loadu(const void* ptr) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr)); +} + +template <> +inline Vectorized Vectorized::loadu(const void* ptr, int64_t count) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr, count)); +} + +}}} // namespace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 1189cc05de12a..079b289ef8c30 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -71,97 +71,123 @@ struct AtomicFPOp { } }; -template -struct AtomicAddIntegerImpl; - -template -struct AtomicAddIntegerImpl { - inline __device__ void operator()(T *address, T val) { - size_t offset = (size_t)address & 3; - uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); - uint32_t old = *address_as_ui; - uint32_t shift = offset * 8; - uint32_t old_byte; - uint32_t newval; - uint32_t assumed; - - do { - assumed = old; - old_byte = (old >> shift) & 0xff; - // preserve size in initial cast. Casting directly to uint32_t pads - // negative signed values with 1's (e.g. signed -1 = unsigned ~0). - newval = static_cast(val + static_cast(old_byte)); - newval = (old & ~(0x000000ff << shift)) | (newval << shift); - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); - } +#define ATOMIC_INTEGER_IMPL(NAME) \ +template \ +struct Atomic##NAME##IntegerImpl; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(func(val, static_cast(old_byte))); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 2; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + bool is_32_align = offset; \ + uint32_t old = *address_as_ui; \ + uint32_t old_bytes; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_bytes = is_32_align ? old >> 16 : old & 0xffff; \ + newval = static_cast(func(val, static_cast(old_bytes))); \ + newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + uint32_t * address_as_ui = (uint32_t *) (address); \ + uint32_t old = *address_as_ui; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + unsigned long long * address_as_ui = (unsigned long long *) (address); \ + unsigned long long old = *address_as_ui; \ + unsigned long long newval; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ }; -template -struct AtomicAddIntegerImpl { - inline __device__ void operator()(T *address, T val) { - size_t offset = (size_t)address & 2; - uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); - bool is_32_align = offset; - uint32_t old = *address_as_ui; - uint32_t old_bytes; - uint32_t newval; - uint32_t assumed; - do { - assumed = old; - old_bytes = is_32_align ? old >> 16 : old & 0xffff; - // preserve size in initial cast. Casting directly to uint32_t pads - // negative signed values with 1's (e.g. signed -1 = unsigned ~0). - newval = static_cast(val + static_cast(old_bytes)); - newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); - } -}; +# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \ +static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \ +Atomic##NAME##IntegerImpl()(address, \ + val, \ + [](DTYPE a, DTYPE b) { \ + return OP; \ + }); \ +} \ -template -struct AtomicAddIntegerImpl { - inline __device__ void operator()(T *address, T val) { - uint32_t * address_as_ui = (uint32_t *) (address); - uint32_t old = *address_as_ui; - uint32_t newval; - uint32_t assumed; - - do { - assumed = old; - newval = static_cast(val + static_cast(old)); - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); - } -}; - -template -struct AtomicAddIntegerImpl { - inline __device__ void operator()(T *address, T val) { - unsigned long long * address_as_ui = (unsigned long long *) (address); - unsigned long long old = *address_as_ui; - unsigned long long newval; - unsigned long long assumed; - - do { - assumed = old; - newval = static_cast(val + static_cast(old)); - old = atomicCAS(address_as_ui, assumed, newval); - } while (assumed != old); - } -}; +ATOMIC_INTEGER_IMPL(Add) +// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64) static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) { - AtomicAddIntegerImpl()(address, val); + AtomicAddIntegerImpl()(address, + val, + [](uint8_t a, uint8_t b) { + return a + b; + }); } static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) { - AtomicAddIntegerImpl()(address, val); + AtomicAddIntegerImpl()(address, + val, + [](int8_t a, int8_t b) { + return a + b; + }); } static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) { - AtomicAddIntegerImpl()(address, val); + AtomicAddIntegerImpl()(address, + val, + [](int16_t a, int16_t b) { + return a + b; + }); } static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) { @@ -172,7 +198,11 @@ static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) { #if defined(USE_ROCM) __atomic_fetch_add(address, val, __ATOMIC_RELAXED); #else - AtomicAddIntegerImpl()(address, val); + AtomicAddIntegerImpl()(address, + val, + [](int64_t a, int64_t b) { + return a + b; + }); #endif } @@ -308,6 +338,13 @@ static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { // Atomic multiplication implementation. +ATOMIC_INTEGER_IMPL(Mul) +GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int16_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int32_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int64_t) + inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) { return AtomicFPOp()(address, val, [](at::Half bsum, at::Half val) { @@ -362,6 +399,13 @@ __host__ __device__ T safe_max(T a, T b) { return max; } +ATOMIC_INTEGER_IMPL(Max) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t) + inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) { return AtomicFPOp()(address, val, [](at::Half bsum, at::Half val) { @@ -415,6 +459,13 @@ __host__ __device__ T safe_min(T a, T b) { return min; } +ATOMIC_INTEGER_IMPL(Min) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t) + inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) { return AtomicFPOp()(address, val, [](at::Half bsum, at::Half val) { diff --git a/aten/src/ATen/cuda/CUDAContext.cpp b/aten/src/ATen/cuda/CUDAContext.cpp index 1751128f1a881..98fa9a5f6dd2b 100644 --- a/aten/src/ATen/cuda/CUDAContext.cpp +++ b/aten/src/ATen/cuda/CUDAContext.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -11,8 +12,8 @@ namespace at { namespace cuda { namespace { DeviceIndex num_gpus = -1; -std::once_flag init_flag; -std::deque device_flags; +c10::once_flag init_flag; +std::deque device_flags; std::vector device_properties; void initCUDAContextVectors() { @@ -44,15 +45,15 @@ cudaDeviceProp* getCurrentDeviceProperties() { } cudaDeviceProp* getDeviceProperties(int64_t device) { - std::call_once(init_flag, initCUDAContextVectors); + c10::call_once(init_flag, initCUDAContextVectors); if (device == -1) device = c10::cuda::current_device(); AT_ASSERT(device >= 0 && device < num_gpus); - std::call_once(device_flags[device], initDeviceProperty, device); + c10::call_once(device_flags[device], initDeviceProperty, device); return &device_properties[device]; } bool canDeviceAccessPeer(int64_t device, int64_t peer_device) { - std::call_once(init_flag, initCUDAContextVectors); + c10::call_once(init_flag, initCUDAContextVectors); if (device == -1) device = c10::cuda::current_device(); AT_ASSERT(device >= 0 && device < num_gpus); AT_ASSERT(peer_device >= 0 && peer_device < num_gpus); diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index 0f8f8b0eb5288..0cac5d6da2d54 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace at { @@ -12,13 +13,13 @@ namespace detail { namespace { // Ensures we only call cudaGetDeviceCount only once. -static std::once_flag num_gpu_init_flag; +static c10::once_flag num_gpu_init_flag; // Total number of gpus in the system. static int64_t num_gpus; // Ensures default_gens_cuda is initialized once. -static std::deque cuda_gens_init_flag; +static std::deque cuda_gens_init_flag; // Default, global CUDA generators, one per GPU. static std::vector default_gens_cuda; @@ -44,14 +45,14 @@ static void initCUDAGenVector(){ * cuda device. */ const Generator& getDefaultCUDAGenerator(DeviceIndex device_index) { - std::call_once(num_gpu_init_flag, initCUDAGenVector); + c10::call_once(num_gpu_init_flag, initCUDAGenVector); DeviceIndex idx = device_index; if (idx == -1) { idx = c10::cuda::current_device(); } else { TORCH_CHECK(idx >= 0 && idx < num_gpus); } - std::call_once(cuda_gens_init_flag[idx], [&] { + c10::call_once(cuda_gens_init_flag[idx], [&] { default_gens_cuda[idx] = make_generator(idx); default_gens_cuda[idx].seed(); }); @@ -62,7 +63,7 @@ const Generator& getDefaultCUDAGenerator(DeviceIndex device_index) { * Utility to create a CUDAGeneratorImpl. Returns a shared_ptr */ Generator createCUDAGenerator(DeviceIndex device_index) { - std::call_once(num_gpu_init_flag, initCUDAGenVector); + c10::call_once(num_gpu_init_flag, initCUDAGenVector); DeviceIndex idx = device_index; if (idx == -1) { idx = c10::cuda::current_device(); diff --git a/aten/src/ATen/cuda/cub-RadixSortKeys.cu b/aten/src/ATen/cuda/cub-RadixSortKeys.cu new file mode 100644 index 0000000000000..330ab350d130a --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortKeys.cu @@ -0,0 +1,59 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at { +namespace cuda { +namespace cub { + +template +void radix_sort_keys( + const key_t* keys_in, + key_t* keys_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \ + template void radix_sort_keys( \ + const scalar_t* keys_in, \ + scalar_t* keys_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) + +} // namespace cub +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs.cu b/aten/src/ATen/cuda/cub-RadixSortPairs.cu new file mode 100644 index 0000000000000..3c28a7141cf26 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs.cu @@ -0,0 +1,93 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at { +namespace cuda { +namespace cub { +namespace detail { + +template +void radix_sort_pairs_impl( + const key_t* keys_in, + key_t* keys_out, + const OpaqueType* values_in, + OpaqueType* values_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(n * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ + template void radix_sort_pairs_impl( \ + const key_t* keys_in, \ + key_t* keys_out, \ + const OpaqueType* values_in, \ + OpaqueType* values_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) + +#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ + AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) + +// BFloat16 Radix sort is supported from ROCm 4.5 onwards +#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500) +AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) +#endif + +} // namespace detail + +} // namespace cub +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/cub.cu b/aten/src/ATen/cuda/cub.cu index bf3216eee6dad..61aa7747e1999 100644 --- a/aten/src/ATen/cuda/cub.cu +++ b/aten/src/ATen/cuda/cub.cu @@ -5,118 +5,6 @@ namespace at { namespace cuda { namespace cub { -namespace detail { - -template -void radix_sort_pairs_impl( - const key_t *keys_in, key_t *keys_out, - const OpaqueType *values_in, OpaqueType *values_out, - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) { - TORCH_CHECK(n <= std::numeric_limits::max(), - "cub sort does not support sorting more than INT_MAX elements"); - using key_t_ = typename detail::cuda_type::type; - - auto allocator = c10::cuda::CUDACachingAllocator::get(); - c10::DataPtr keys_out_owner; - - if (keys_out == nullptr) { - keys_out_owner = allocator->allocate(n * sizeof(key_t)); - keys_out = reinterpret_cast(keys_out_owner.get()); - } - - const key_t_ *keys_in_ = reinterpret_cast(keys_in); - key_t_ *keys_out_ = reinterpret_cast(keys_out); - - if (descending) { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, - keys_in_, keys_out_, values_in, values_out, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } else { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, - keys_in_, keys_out_, values_in, values_out, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } -} - -#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ - template void radix_sort_pairs_impl( \ - const key_t *keys_in, key_t *keys_out, \ - const OpaqueType *values_in, \ - OpaqueType *values_out, \ - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit); - -AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) - -#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ - AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) - -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) - -// BFloat16 Radix sort is supported from ROCm 4.5 onwards -#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500) -AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) -#endif - -} // namespace detail - -template -void radix_sort_keys( - const key_t *keys_in, key_t *keys_out, - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) { - TORCH_CHECK(n <= std::numeric_limits::max(), - "cub sort does not support sorting more than INT_MAX elements"); - using key_t_ = typename detail::cuda_type::type; - - const key_t_ *keys_in_ = reinterpret_cast(keys_in); - key_t_ *keys_out_ = reinterpret_cast(keys_out); - - if (descending) { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending, - keys_in_, keys_out_, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } else { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys, - keys_in_, keys_out_, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } -} - -template -void unique(const scalar_t *input, scalar_t *output, int64_t *num_selected_out, int64_t num_items) { - TORCH_CHECK(num_items <= std::numeric_limits::max(), - "cub unique does not support more than INT_MAX elements"); - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, - input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); -} - -template -void run_length_encode(const scalar_t *input, scalar_t *output, int64_t *counts_out, - int64_t *length_out, int64_t num_items) { - TORCH_CHECK(num_items <= std::numeric_limits::max(), - "cub run_length_encode does not support more than INT_MAX elements"); - CUB_WRAPPER( - NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, - input, output, counts_out, length_out, num_items, - at::cuda::getCurrentCUDAStream()); -} - -#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \ - template void radix_sort_keys( \ - const scalar_t *keys_in, scalar_t *keys_out, int64_t n, \ - bool descending, int64_t begin_bit, int64_t end_bit); \ - template void unique( \ - const scalar_t *input, scalar_t *output, \ - int64_t *num_selected_out, int64_t num_items); \ - template void run_length_encode( \ - const scalar_t *input, scalar_t *output, int64_t *counts_out, \ - int64_t *length_out, int64_t n); - -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) namespace { template @@ -145,7 +33,20 @@ void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_ template void exclusive_sum_in_common_type(const int32_t *input, int32_t *output, int64_t num_items); template void exclusive_sum_in_common_type(const int64_t *input, int64_t *output, int64_t num_items); -template void exclusive_sum_in_common_type(const bool *input, int64_t *output, int64_t num_items); -template void exclusive_sum_in_common_type(const uint8_t *input, int64_t *output, int64_t num_items); + +namespace { +struct CountMaskOp { + __device__ int64_t operator() (const uint8_t &x) const { + return x != 0; + } +}; +} + +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) { + CountMaskOp op{}; + auto iter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator< + bool, decltype(op), decltype(mask)>(mask, op); + exclusive_scan(iter, output_idx, SumOp{}, int64_t{0}, n); +} }}} // namespace at::cuda::cub diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index abe2e9272014f..7ac10378b0bcd 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -385,4 +385,36 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT #endif +template +void unique(InputIteratorT input, OutputIteratorT output, + NumSelectedIteratorT num_selected_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, + input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); +} + +template +void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out, + LengthOutputIteratorT length_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub run_length_encode does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, + input, output, counts_out, length_out, num_items, + at::cuda::getCurrentCUDAStream()); +} + +template +void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub reduce does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce, + input, output, num_items, op, init, + at::cuda::getCurrentCUDAStream()); + +} + }}} // namespace at::cuda::cub diff --git a/aten/src/ATen/cuda/cub.h b/aten/src/ATen/cuda/cub.h index 85e0ff210831d..2e6a808d6f510 100644 --- a/aten/src/ATen/cuda/cub.h +++ b/aten/src/ATen/cuda/cub.h @@ -62,14 +62,6 @@ void radix_sort_keys( const key_t *keys_in, key_t *keys_out, int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8); -template -void unique(const scalar_t *input, scalar_t *output, - int64_t *num_selected_out, int64_t num_items); - -template -void run_length_encode(const scalar_t *input, scalar_t *output, int64_t *counts_out, - int64_t *length_out, int64_t n); - // NOTE: Intermediate sums will be truncated to input_t precision template void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n); @@ -88,4 +80,10 @@ void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { return exclusive_sum_in_common_type(input, output, n); } +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n); +inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) { + return mask_exclusive_sum( + reinterpret_cast(mask), output_idx, n); +} + }}} // namespace at::cuda::cub diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 93a23ec6a7309..ea335180259e1 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -266,6 +266,20 @@ bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const { #endif } +bool CUDAHooks::supportsBFloat16ConvolutionWithCuDNNv8() const { +#if AT_CUDNN_ENABLED() + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + // Check for Volta cores + if (prop->major >= 8) { + return true; + } else { + return false; + } +#else + return false; +#endif +} + long CUDAHooks::versionCuDNN() const { #if AT_CUDNN_ENABLED() return CUDNN_VERSION; diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 5aa2721170ed4..d53276ab3bbac 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -39,6 +39,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool compiledWithMIOpen() const override; bool supportsDilatedConvolutionWithCuDNN() const override; bool supportsDepthwiseConvolutionWithCuDNN() const override; + bool supportsBFloat16ConvolutionWithCuDNNv8() const override; bool hasCUDART() const override; long versionCUDART() const override; long versionCuDNN() const override; diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index e720994e9249d..0e75c1842cbb7 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -3,7 +3,6 @@ #include #include #include -#include // for std::call_once namespace at { namespace cuda { diff --git a/aten/src/ATen/detail/CUDAHooksInterface.cpp b/aten/src/ATen/detail/CUDAHooksInterface.cpp index 775994cfbebdd..6f1198ea25527 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.cpp +++ b/aten/src/ATen/detail/CUDAHooksInterface.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -35,8 +36,8 @@ const CUDAHooksInterface& getCUDAHooks() { // needing a lock, be careful; it doesn't look like Registry.h is thread // safe...) #if !defined C10_MOBILE - static std::once_flag once; - std::call_once(once, [] { + static c10::once_flag once; + c10::call_once(once, [] { cuda_hooks = CUDAHooksRegistry()->Create("CUDAHooks", CUDAHooksArgs{}).release(); if (!cuda_hooks) { cuda_hooks = new CUDAHooksInterface(); diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 1303b9f8c8bf8..7ba8f68d94b20 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -147,6 +147,10 @@ struct TORCH_API CUDAHooksInterface { return false; } + virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const { + return false; + } + virtual long versionCuDNN() const { TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/HIPHooksInterface.cpp b/aten/src/ATen/detail/HIPHooksInterface.cpp index 27fe6d1dc3b80..0ae903c4d6a51 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.cpp +++ b/aten/src/ATen/detail/HIPHooksInterface.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -13,8 +14,8 @@ namespace detail { const HIPHooksInterface& getHIPHooks() { static std::unique_ptr hip_hooks; #if !defined C10_MOBILE - static std::once_flag once; - std::call_once(once, [] { + static c10::once_flag once; + c10::call_once(once, [] { hip_hooks = HIPHooksRegistry()->Create("HIPHooks", HIPHooksArgs{}); if (!hip_hooks) { hip_hooks = diff --git a/aten/src/ATen/detail/ORTHooksInterface.cpp b/aten/src/ATen/detail/ORTHooksInterface.cpp index 33f70935a04d0..79d28bbb8d3d5 100644 --- a/aten/src/ATen/detail/ORTHooksInterface.cpp +++ b/aten/src/ATen/detail/ORTHooksInterface.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -12,8 +13,8 @@ namespace detail { // See getCUDAHooks for some more commentary const ORTHooksInterface& getORTHooks() { static std::unique_ptr ort_hooks; - static std::once_flag once; - std::call_once(once, [] { + static c10::once_flag once; + c10::call_once(once, [] { ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {}); if (!ort_hooks) { ort_hooks = diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 6e6bf9ffe8d70..c0ef7476196be 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -99,7 +99,14 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI } DeviceIndex deviceCount() const noexcept override { int deviceCnt; - C10_HIP_CHECK(hipGetDeviceCount(&deviceCnt)); + hipError_t _err; + _err = hipGetDeviceCount(&deviceCnt); +#if defined(USE_ROCM) && (ROCM_VERSION < 50201) + if(_err == hipErrorInvalidDevice) + return 0; +#endif + if(_err != hipErrorNoDevice && _err != hipSuccess) + C10_HIP_CHECK(_err); return deviceCnt; } diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index 72c0024807255..ee8712d227ce7 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -63,13 +63,18 @@ struct HeapBlock; struct BufferBlock { id buffer; - size_t size; + size_t size; // size after alignment + size_t requested_size; // requested size (before alignment) + // buffer shape is used for retrieving base of views in cached graphs + std::vector shape; bool in_use; HeapBlock* heap; id_t buf_id; - BufferBlock(size_t Size, const id Buffer = nullptr, HeapBlock* Heap = nullptr, id_t BufID = 0) : - buffer(Buffer), size(Size), in_use(false), heap(Heap), buf_id(BufID) { } + BufferBlock(size_t Size, size_t RequestedSize = 0, const id Buffer = nullptr, + HeapBlock* Heap = nullptr, id_t BufID = 0) : + buffer(Buffer), size(Size), requested_size(RequestedSize), + in_use(false), heap(Heap), buf_id(BufID) { } static bool Comparator(const BufferBlock* a, const BufferBlock* b) { return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer; @@ -193,6 +198,9 @@ class MPSHeapAllocatorImpl void Free(void* ptr); void EmptyCache(); bool isSharedBuffer(void* ptr); + ssize_t getRequestedBufferSize(void* ptr); + void setBufferShape(void* ptr, const IntArrayRef& shape); + IntArrayRef getBufferShape(void* ptr); inline id Device() const { return m_device; } void enable_debug_info() { m_enable_debug_info = true; } diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 873a78fffce8e..2433acbc050b2 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -58,7 +58,7 @@ TORCH_INTERNAL_ASSERT(buffer); // insert heap after a buffer was created on it to update the order of heap's set p.pool->heaps.insert(heap); - p.buffer_block = new BufferBlock(p.size(), buffer, heap, m_allocated_buffers.size() + 1); + p.buffer_block = new BufferBlock(p.size(), p.requested_size, buffer, heap, m_allocated_buffers.size() + 1); m_allocated_buffers[p.buffer_block->buffer] = p.buffer_block; m_total_allocated_memory += p.size(); @@ -66,7 +66,8 @@ std::cerr << "Allocated " << (p.pool->is_shared ? "shared" : "private") << " buffer #" << p.buffer_block->buf_id - << " with aligned size " << format_size(p.size()) + << " of size " << format_size(p.size()) + << " at " << p.buffer_block->buffer << " (requested size: " << format_size(p.requested_size) << ", heap size: " << format_size(heap->size.available) << ", total allocated: " << format_size(m_total_allocated_memory) << ")\n"; @@ -92,7 +93,8 @@ std::cerr << "Reusing " << (p.pool->is_shared ? "shared" : "private") << " buffer #" << p.buffer_block->buf_id - << " with aligned size " << format_size(p.buffer_block->size) + << " of size " << format_size(p.buffer_block->size) + << " at " << p.buffer_block->buffer << " (requested size: " << format_size(p.requested_size) << ")\n"; } return true; @@ -129,6 +131,7 @@ TORCH_INTERNAL_ASSERT(buffer_block->in_use); trigger_memory_callbacks(buffer_block, IMpsAllocatorCallback::EventType::FREED); buffer_block->in_use = false; + buffer_block->shape.clear(); // reset shape BufferPool *pool = buffer_block->heap->pool; // Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into. TORCH_INTERNAL_ASSERT(pool->buffers.insert(buffer_block).second); @@ -136,8 +139,7 @@ BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(void* ptr) { - id buf = __builtin_bit_cast(id, ptr); - auto it = m_allocated_buffers.find(buf); + auto it = m_allocated_buffers.find(ptr); if (it == m_allocated_buffers.end()) return nullptr; @@ -159,6 +161,40 @@ return buffer_block && buffer_block->heap->pool->is_shared; } +ssize_t MPSHeapAllocatorImpl::getRequestedBufferSize(void* ptr) +{ + std::lock_guard lock(m_mutex); + + BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + if (buffer_block) + return (ssize_t) buffer_block->requested_size; + // this indicates the passed buffer pointer wasn't found + return -1; +} + +void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) +{ + std::lock_guard lock(m_mutex); + + BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr); + // note that the IntArrayRef doesn't own the underlying data, and the backing + // memory for shape data must persist as long as the buffer is in use. + // So we need to copy to vector. + buffer_block->shape = shape.vec(); +} + +IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) +{ + std::lock_guard lock(m_mutex); + + BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + if (buffer_block && buffer_block->shape.size() > 0) + return IntArrayRef{buffer_block->shape}; + + return IntArrayRef(); +} + void MPSHeapAllocatorImpl::Free(void* ptr) { std::lock_guard lock(m_mutex); @@ -350,6 +386,19 @@ static bool isEnvVarEnabled(const char *envvar) { return &_getPrivateAllocator(); } +// TODO: create MPSHooks interface and move these there. +ssize_t get_requested_buffer_size(void* ptr) { + return _getAllocImpl().getRequestedBufferSize(ptr); +} + +void set_buffer_shape(void* ptr, const IntArrayRef& shape) { + _getAllocImpl().setBufferShape(ptr, shape); +} + +IntArrayRef get_buffer_shape(void* ptr) { + return _getAllocImpl().getBufferShape(ptr); +}; + } // namespace mps namespace native { @@ -380,5 +429,4 @@ Tensor _pin_memory_mps(const Tensor& self, c10::optional device) } } // namespace native - } // namespace at diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index b48de11ec1ae1..2775100666494 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -1,15 +1,17 @@ // Copyright © 2022 Apple Inc. +#include + #include namespace at { namespace mps { static std::unique_ptr mps_device; -static std::once_flag mpsdev_init; +static c10::once_flag mpsdev_init; MPSDevice* MPSDevice::getInstance() { - std::call_once(mpsdev_init, [] { + c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr(new MPSDevice()); }); return mps_device.get(); diff --git a/aten/src/ATen/mps/MPSStream.h b/aten/src/ATen/mps/MPSStream.h index 1c19c42b7d774..d4e6172954da3 100644 --- a/aten/src/ATen/mps/MPSStream.h +++ b/aten/src/ATen/mps/MPSStream.h @@ -38,10 +38,18 @@ namespace mps { // MPSStream //----------------------------------------------------------------- +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer +}; + class TORCH_API MPSStream { public: enum Unchecked { UNCHECKED }; + /// Construct a MPSStream from a Stream. This construction is checked, /// and will raise an error if the Stream is not, in fact, a MPS stream. explicit MPSStream(Stream stream); @@ -50,12 +58,18 @@ class TORCH_API MPSStream MTLCommandQueue_t commandQueue() const { return _commandQueue; }; dispatch_queue_t queue() const { return _serialQueue; } - MTLCommandBuffer_t commandBuffer(); + MPSCommandBuffer* commandBuffer(); void commit(bool flush); void commitAndWait(); - void synchronize(); - + void commitAndContinue(); + void synchronize(SyncType syncType); + void fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE); + void copy(id srcBuffer, id dstBuffer, + size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType = SyncType::NONE); + void copy_and_sync(id srcBuffer, id dstBuffer, + size_t length, size_t srcOffset, size_t dstOffset, bool non_blocking); void flush(); + void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE); /// Get the MPS device index that this stream is associated with. c10::DeviceIndex device_index() const { return _stream.device_index(); } @@ -70,7 +84,8 @@ class TORCH_API MPSStream private: Stream _stream; MTLCommandQueue_t _commandQueue = nil; - MTLCommandBuffer_t _commandBuffer = nil; + MPSCommandBuffer* _commandBuffer = nil; + MPSGraphExecutionDescriptor *_executionDescriptor = nil; void _flush(bool commitAndWait) const; dispatch_queue_t _serialQueue = nullptr; diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index 7d1d346f17556..948d5723cad91 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -5,6 +5,8 @@ namespace at { namespace mps { +#define USE_MPSCOMMANDBUFFER 1 + //----------------------------------------------------------------- // MPSStream //----------------------------------------------------------------- @@ -12,38 +14,56 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) { _commandQueue = [MPSDevice::getInstance()->device() newCommandQueue]; TORCH_CHECK(_stream.device_type() == DeviceType::MPS); - _serialQueue = dispatch_queue_create("metal gpu stream", NULL); + _serialQueue = dispatch_queue_create("metal gpu stream", nullptr); + _executionDescriptor = [MPSGraphExecutionDescriptor new]; + _executionDescriptor.completionHandler = ^(NSDictionary * resultsDictionary, + NSError * _Nullable error) { }; } MPSStream::~MPSStream() { - [_commandQueue autorelease]; + [_commandQueue release]; _commandQueue = nil; + [_executionDescriptor release]; assert(_commandBuffer == nil); } -id MPSStream::commandBuffer() { +MPSCommandBuffer* MPSStream::commandBuffer() { if (!_commandBuffer) { - _commandBuffer = - [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain; + _commandBuffer = [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain; } return _commandBuffer; } -void MPSStream::synchronize() { - dispatch_sync(queue(), ^() { - @autoreleasepool { - commandBuffer(); +void MPSStream::synchronize(SyncType syncType) { + if (!_commandBuffer) + return; + switch(syncType) { + case SyncType::NONE: + // typically in GPU to GPU copies we won't commit explicitly + break; + case SyncType::COMMIT: + flush(); + break; + case SyncType::COMMIT_AND_WAIT: commitAndWait(); - } - }); + break; + case SyncType::COMMIT_AND_CONTINUE: + commitAndContinue(); + break; + } } void MPSStream::commit(bool doFlush) { +#if USE_MPSCOMMANDBUFFER + [commandBuffer() commitAndContinue]; +#else if (doFlush) { flush(); } +#endif } void MPSStream::commitAndWait() { @@ -54,6 +74,11 @@ _commandBuffer = nil; } +void MPSStream::commitAndContinue() { + assert(_commandBuffer); + [_commandBuffer commitAndContinue]; +} + void MPSStream::flush() { if (_commandBuffer) { [_commandBuffer commit]; @@ -71,6 +96,67 @@ [_commandBuffer release]; } +void MPSStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) +{ + TORCH_INTERNAL_ASSERT(length >= offset); + if (length == 0) return; + dispatch_sync(_serialQueue, ^() { + @autoreleasepool { + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder fillBuffer:buffer + range:NSMakeRange(offset, length) + value:value]; + [blitEncoder endEncoding]; + synchronize(syncType); + } + }); +} + +void MPSStream::copy(id srcBuffer, id dstBuffer, + size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) { + dispatch_sync(_serialQueue, ^() { + @autoreleasepool { + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder copyFromBuffer:srcBuffer + sourceOffset:(NSUInteger)srcOffset + toBuffer:dstBuffer + destinationOffset:(NSUInteger)dstOffset + size:(NSUInteger)length]; + [blitEncoder endEncoding]; + synchronize(syncType); + } + }); +} + +void MPSStream::copy_and_sync(id srcBuffer, id dstBuffer, size_t length, + size_t srcOffset, size_t dstOffset, bool non_blocking) { + copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset, + !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT); +} + +void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { + dispatch_sync(_serialQueue, ^() { +#if USE_MPSCOMMANDBUFFER + [mpsGraph encodeToCommandBuffer:commandBuffer() + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:_executionDescriptor]; + // mostly the syncType is NONE, but in some cases we may want to sync and wait (e.g., gatherViewTensor) + synchronize(syncType); +#else + commit(true); + [mpsGraph runAsyncWithMTLCommandQueue:_commandQueue + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:_executionDescriptor]; +#endif + }); +} + //----------------------------------------------------------------- // MPSStreamImpl //----------------------------------------------------------------- diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index f40c4aa3e8239..97f504b85dd19 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -617,6 +617,7 @@ TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) { int64_t weight_num = weight_.numel(); Tensor result = at::empty_like(self, self.suggest_memory_format()); + TORCH_INTERNAL_ASSERT(weight_.defined()); if (weight_num != 1) { int64_t input_ndim = self.dim(); @@ -636,10 +637,12 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) { // All elements go into the channel dimension DimVector sizes(ndim, 1), strides(ndim, 0); auto as_nd = [&](const Tensor& t) { - TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0)); + TORCH_CHECK( + t.dim() == 1 || t.dim() == 0, + "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", t.dim()); if (ndim >= 2) { - sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1; - strides[1] = t.dim() == 1 ? t.strides()[0] : 0; + sizes[1] = t.dim() == 1 ? t.size(0) : 1; + strides[1] = t.dim() == 1 ? t.stride(0) : 0; return t.as_strided(sizes, strides); } return t.as_strided(sizes, strides); @@ -648,11 +651,9 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) { if (self.scalar_type() == ScalarType::BFloat16) { auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16)); w_bf16.copy_(weight_); - w = weight_.defined() ? as_nd(w_bf16) : - at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU); + w = as_nd(w_bf16); } else { - w = weight_.defined() ? as_nd(weight_) : - at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU); + w = as_nd(weight_); } auto iter = TensorIteratorConfig() diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 8e68af040e5cc..b2dc974f5a3b8 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -411,9 +411,10 @@ TORCH_META_FUNC(triangular_solve)(const Tensor& self, const Tensor& A, bool uppe } } -TORCH_META_FUNC(_linalg_solve)(const Tensor& A, - const Tensor& B, - bool left) { +TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A, + const Tensor& B, + bool left, + bool check_errors) { // dtype at::native::checkFloatingOrComplex(A, "linalg.solve"); TORCH_CHECK(A.scalar_type() == B.scalar_type(), @@ -446,8 +447,11 @@ TORCH_META_FUNC(_linalg_solve)(const Tensor& A, auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true); set_output_strided(1, shape, LU_strides, A.options(), {}); - // Pivots + // pivots set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt)); + + // info + set_output_contiguous(3, shape.slice(0, ndim - 2), A.options().dtype(kInt)); } TORCH_META_FUNC(linalg_lu_factor_ex)(const Tensor& A, bool pivot, bool check_errors) { @@ -504,10 +508,27 @@ TORCH_META_FUNC(linalg_lu_solve)(const Tensor& LU, set_output_strided(0, B_broadcast_size, result_strides, B.options(), {}); } +TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A, + bool upper, + bool check_errors) { + at::native::squareCheckInputs(A, "linalg.cholesky"); + at::native::checkFloatingOrComplex(A, "linalg.cholesky"); + + auto A_shape = A.sizes(); + auto ndim = A_shape.size(); + + // L + auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true); + set_output_strided(0, A_shape, L_strides, A.options(), {}); + + // info + set_output_contiguous(1, A_shape.slice(0, ndim - 2), A.options().dtype(ScalarType::Int)); +} TORCH_META_FUNC(linalg_qr)(const Tensor& A, c10::string_view mode) { at::native::checkIsMatrix(A, "linalg.qr"); + at::native::checkFloatingOrComplex(A, "linalg.qr"); bool compute_q, reduced_mode; std::tie(compute_q, reduced_mode) = at::native::_parse_qr_mode(mode); @@ -538,6 +559,7 @@ TORCH_META_FUNC(_linalg_svd)(const Tensor& A, bool compute_uv, c10::optional driver) { at::native::checkIsMatrix(A, "linalg.svd"); + at::native::checkFloatingOrComplex(A, "linalg.svd"); auto sizes = A.sizes().vec(); const auto m = sizes.cend()[-2]; @@ -1409,17 +1431,81 @@ template<> void blasTriangularSolve(char side, char uplo, char trans, cha #endif void _linalg_check_errors( - const Tensor& info, + const Tensor& infos, const c10::string_view api_name, bool is_matrix) { - if (info.is_meta()) { + TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt); + TORCH_INTERNAL_ASSERT(infos.is_contiguous()); + if (infos.is_meta()) { + return; + } + + // If it's all zeros, we return early. + // We optimise for the most likely case. + if (C10_LIKELY(!infos.any().item())) { return; } + + int32_t info; + std::string batch_str; if (is_matrix) { - singleCheckErrors(info.item(), api_name); + info = infos.item(); + // batch_str needn't be set for matrices } else { - batchCheckErrors(info, api_name); + // Find the first non-zero info + auto infos_cpu = infos.to(at::kCPU); + auto ptr = infos_cpu.data_ptr(); + auto n = infos.numel(); + auto info_ptr = std::find_if(ptr, ptr + n, [](int32_t x) { return x != 0; }); + info = *info_ptr; + batch_str = ": (Batch element " + std::to_string(std::distance(ptr, info_ptr)) + ")"; + } + + if (info < 0) { + // Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values + // Previously, it would return `info` > 0, but now it returns `info` = -4 + // OpenBLAS 0.3.15+ uses the Reference LAPACK 3.10+. + // MKL 2022.0+ uses the Reference LAPACK 3.10+. + // Older version of MKL and OpenBLAS follow the old behavior (return `info` > 0). + // Here we check for the case where `info` is -4 and raise an error + if (api_name.find("svd") != api_name.npos) { + TORCH_CHECK_LINALG(info != -4, api_name, batch_str, + ": The algorithm failed to converge because the input matrix contained non-finite values."); + } + TORCH_INTERNAL_ASSERT(false, api_name, batch_str, + ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library."); + } else if (info > 0) { + if (api_name.find("inv") != api_name.npos) { + // inv, inverse, cholesky_inverse, etc. + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular."); + } else if (api_name.find("solve") != api_name.npos) { + // solve, linalg_solve, cholesky_solve, etc. + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The solver failed because the input matrix is singular."); + } else if (api_name.find("cholesky") != api_name.npos) { + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite)."); + } else if (api_name.find("svd") != api_name.npos) { + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ")."); + } else if (api_name.find("eig") != api_name.npos || api_name.find("syevd") != api_name.npos) { + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ")."); + } else if (api_name.find("lstsq") != api_name.npos) { + TORCH_CHECK_LINALG(false, api_name, batch_str, + ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ")."); + } else if (api_name.find("lu_factor") != api_name.npos) { + TORCH_CHECK(false, api_name, batch_str, + ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. " + "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or " + "linalg.lu_factor_ex(A, pivot)"); + } else { + TORCH_INTERNAL_ASSERT(false, api_name, ": Unknown error code: ", info, "."); + } } + // We should never reach this point as info was non-zero + TORCH_INTERNAL_ASSERT(false); } bool _requires_fw_or_bw_grad(const Tensor& input) { @@ -1649,7 +1735,7 @@ std::tuple linalg_inv_ex(const Tensor& input, bool check_errors) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector& infos) { +static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, Tensor& infos) { #if !AT_BUILD_WITH_LAPACK() AT_ERROR("cholesky_solve: LAPACK library not found in compilation"); #else @@ -1657,6 +1743,7 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector(); auto b_data = b.data_ptr(); + auto infos_data = infos.data_ptr(); auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); auto batch_size = batchCount(A); @@ -1670,7 +1757,7 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector(uplo, n, nrhs, A_working_ptr, ldab, b_working_ptr, ldab, &info); - infos[i] = info; + infos_data[i] = info; if (info != 0) { return; } @@ -1681,16 +1768,12 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector infos(batchCount(self), 0); + auto infos = at::zeros({batchCount(self)}, self.options().dtype(kInt)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_solve_cpu", [&]{ apply_cholesky_solve(self_working_copy, A_working_copy, upper, infos); }); - if (self.dim() > 2) { - batchCheckErrors(infos, "cholesky_solve_cpu"); - } else { - singleCheckErrors(infos[0], "cholesky_solve_cpu"); - } + at::_linalg_check_errors(infos, "cholesky_solve_cpu", self.dim() == 2); return self_working_copy; } @@ -1774,132 +1857,57 @@ Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) { return result; } -void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info, bool upper) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-1) == input.size(-2)); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.scalar_type() == input.scalar_type()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.device() == input.device()); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.scalar_type() == at::kInt); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.device() == input.device()); - - // if result has no elements we can modify it - if (result.numel() == 0) { - at::native::resize_as_(result, input.mT(), MemoryFormat::Contiguous); - result.transpose_(-2, -1); - } - - // result tensor must be in batched column major order (Fortran contiguous) - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.mT().is_contiguous()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(input.sizes())); - - // cholesky_stub (apply_cholesky) performs calculations in-place and result must be a copy of input - result.copy_(input); - - // if info has no elements we can modify it - auto expected_info_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2); // input.shape[:-2] - if (info.numel() == 0) { - info.resize_(expected_info_shape); +TORCH_IMPL_FUNC(linalg_cholesky_ex_out)(const Tensor& A, + bool upper, + bool check_errors, + const Tensor& L, + const Tensor& info) { + // Nothing to do there + if (L.numel() == 0) { + info.zero_(); + return; } + const auto cpu = A.device() == kCPU; - // info must be contiguous - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.is_contiguous()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.sizes().equals(expected_info_shape)); - info.fill_(0); - - cholesky_stub(result.device().type(), result, info, upper); - - if (upper) { - result.triu_(); + // We can perform this optimisation just on CPU as it fails for MAGMA + // due to some bug + if (cpu) { + if (upper) { + at::triu_out(const_cast(L), A); + } else { + at::tril_out(const_cast(L), A); + } } else { - result.tril_(); - } -} - -std::tuple linalg_cholesky_ex_out(const Tensor& input, bool upper, bool check_errors, Tensor& L, Tensor& info) { - squareCheckInputs(input, "linalg.cholesky_ex"); - checkSameDevice("torch.linalg.cholesky_ex", L, input, "L"); - checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L"); - checkSameDevice("torch.linalg.cholesky_ex", info, input, "info"); - - // Do not allow type promotion for the `info` tensor, it must be of Int dtype - // Int is used because current interface to LAPACK and its CUDA implementation use "int" type. - // https://github.com/pytorch/pytorch/pull/56724#discussion_r618916774 - ScalarType info_output_type = ScalarType::Int; - TORCH_CHECK( - info.scalar_type() == info_output_type, - "torch.linalg.cholesky_ex: ", - "Expected info to have ", info_output_type, " dtype, but got info with dtype ", info.scalar_type()); - - bool L_input_same_type = (L.scalar_type() == input.scalar_type()); - bool L_equal_expected_shape = L.sizes().equals(input.sizes()); - bool is_L_batched_column_major = false; - if (L.dim() >= 2) { - is_L_batched_column_major = L.mT().is_contiguous(); + L.copy_(A); } - // if L is not empty and not in batched column major format - bool copy_needed = (L.numel() != 0 && !is_L_batched_column_major); - copy_needed |= (L.numel() != 0 && !L_equal_expected_shape); // or L does not have the expected shape - copy_needed |= !L_input_same_type; // or L does not have the same dtype as input - // we have to allocate a temporary tensor + cholesky_stub(L.device().type(), L, info, upper); - // similar conditions for info tensor - auto expected_info_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2); // input.shape[:-2] - copy_needed |= (info.numel() != 0 && !info.is_contiguous()); - copy_needed |= (info.numel() != 0 && !(info.sizes().equals(expected_info_shape))); // or L does not have the expected shape - - if (copy_needed) { - Tensor L_tmp = at::empty({0}, input.options()); - Tensor info_tmp = at::empty({0}, input.options().dtype(kInt)); - linalg_cholesky_out_info(input, L_tmp, info_tmp, upper); - at::native::resize_output(L, L_tmp.sizes()); - L.copy_(L_tmp); - at::native::resize_output(info, info_tmp.sizes()); - info.copy_(info_tmp); - } else { - // use "out" tensors' memory directly - linalg_cholesky_out_info(input, L, info, upper); + if (!cpu) { + if (upper) { + L.triu_(); + } else { + L.tril_(); + } } if (check_errors) { - at::_linalg_check_errors(info, "torch.linalg.cholesky_ex", input.dim() == 2); + at::_linalg_check_errors(info, "linalg.cholesky_ex", A.dim() == 2); } - - return std::tuple(L, info); -} - -std::tuple linalg_cholesky_ex(const Tensor& input, bool upper, bool check_errors) { - Tensor L = at::empty({0}, input.options()); - Tensor info = at::empty({0}, input.options().dtype(kInt)); - std::tie(L, info) = at::native::linalg_cholesky_ex_out(input, upper, check_errors, L, info); - return std::make_tuple(L, info); } -Tensor linalg_cholesky(const Tensor &self, bool upper) { - Tensor result, info; - std::tie(result, info) = at::linalg_cholesky_ex(self, upper, /*check_errors=*/false); - - // we pass check_errors=false above and do the check here - // so that the name of the function is correct in the error message - at::_linalg_check_errors(info, "torch.linalg_cholesky", self.dim() == 2); - return result; +Tensor linalg_cholesky(const Tensor& A, bool upper) { + Tensor L, info; + std::tie(L, info) = at::linalg_cholesky_ex(A, upper, /*check_errors=*/false); + at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2); + return L; } -Tensor& linalg_cholesky_out(const Tensor &self, bool upper, Tensor &result) { - // linalg_cholesky_ex_outf includes these checks, but we do it here - // so that the name of the function is correct in the error message - checkSameDevice("torch.linalg.cholesky", result, self); - checkLinalgCompatibleDtype("torch.linalg.cholesky", result, self); - - Tensor info = at::empty({0}, self.options().dtype(kInt)); - std::tie(result, info) = at::linalg_cholesky_ex_outf(self, upper, /*check_errors=*/false, result, info); - - // we pass check_errors=false above and do the check here - // so that the name of the function is correct in the error message - at::_linalg_check_errors(info, "torch.linalg.cholesky", self.dim() == 2); - return result; +Tensor& linalg_cholesky_out(const Tensor& A, bool upper, Tensor& L) { + auto info = at::empty({0}, A.options().dtype(kInt)); + at::linalg_cholesky_ex_out(L, info, A, upper, /*check_errors=*/false); + at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2); + return L; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1982,24 +1990,25 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Auxiliary function that returns the LU decomposition to use it in the backward -TORCH_IMPL_FUNC(_linalg_solve_out)(const Tensor& A, - const Tensor& B, - bool left, - const Tensor& result, - const Tensor& LU, - const Tensor& pivots) { +TORCH_IMPL_FUNC(_linalg_solve_ex_out)(const Tensor& A, + const Tensor& B, + bool left, + bool check_errors, + const Tensor& result, + const Tensor& LU, + const Tensor& pivots, + const Tensor& info) { // Possible optimization: Compute the LU factorization of A^T if A is contiguous // Then we solve A^T X = B with adjoint=True // This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor const bool use_A_T = A.is_contiguous() && !A.is_complex(); - auto info = at::empty({0}, A.options().dtype(kInt)); at::linalg_lu_factor_ex_out(const_cast(LU), const_cast(pivots), const_cast(info), - use_A_T ? A.mT() : A, - /*pivot=*/true, - /*check_errors=*/false); - at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2); + use_A_T ? A.mT() : A); + if (check_errors) { + at::_linalg_check_errors(info, "torch.linalg.solve_ex", A.dim() == 2); + } // [numpy-compat] Handle vectors on the rhs const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B); @@ -2008,22 +2017,45 @@ TORCH_IMPL_FUNC(_linalg_solve_out)(const Tensor& A, at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T); } +std::tuple linalg_solve_ex_out(const Tensor& A, + const Tensor& B, + bool left, + bool check_errors, + Tensor& result, + Tensor& info) { + auto LU = B.new_empty({0}); + auto pivots = B.new_empty({0}, kInt); + at::_linalg_solve_ex_out(result, LU, pivots, info, A, B, left, check_errors); + return std::tie(result, info); +} + +// We implement linalg_solve_ex as a composite function of _linalg_solve +std::tuple linalg_solve_ex(const Tensor& A, + const Tensor& B, + bool left, + bool check_errors) { + Tensor result, LU, pivots, info; + std::tie(result, LU, pivots, info) = at::_linalg_solve_ex(A, B, left, check_errors); + return std::make_tuple(std::move(result), std::move(info)); +} + Tensor& linalg_solve_out(const Tensor& A, const Tensor& B, bool left, Tensor& result) { - - auto LU = at::empty({0}, A.options()); - auto pivots = at::empty({0}, A.options().dtype(kInt)); - at::_linalg_solve_out(result, LU, pivots, A, B, left); + auto info = B.new_empty({0}, kInt); + at::linalg_solve_ex_out(result, info, A, B, left); + at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2); return result; } -// We implement linalg_solve as a composite function of _linalg_solve Tensor linalg_solve(const Tensor& A, const Tensor& B, bool left) { - return std::get<0>(at::_linalg_solve(A, B, left)); + Tensor result, info; + std::tie(result, info) = at::linalg_solve_ex(A, B, left); + at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2); + return result; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2868,7 +2900,7 @@ Tensor& linalg_eigvalsh_out(const Tensor& A, c10::string_view uplo, Tensor& L) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool upper, std::vector& infos) { +static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool upper, int* infos) { #if !AT_BUILD_WITH_LAPACK() AT_ERROR("symeig: LAPACK library not found in compilation"); #else @@ -2920,7 +2952,7 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool } std::tuple _symeig_helper_cpu(const Tensor& self, bool eigenvectors, bool upper) { - std::vector infos(batchCount(self), 0); + auto infos = at::zeros({batchCount(self)}, self.options().dtype(kInt)); auto self_sizes = self.sizes().vec(); self_sizes.pop_back(); @@ -2933,14 +2965,10 @@ std::tuple _symeig_helper_cpu(const Tensor& self, bool eigenvect auto self_working_copy = cloneBatchedColumnMajor(self); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "symeig_cpu", [&]{ - apply_symeig(self_working_copy, eigvals, eigenvectors, upper, infos); + apply_symeig(self_working_copy, eigvals, eigenvectors, upper, infos.data_ptr()); }); - if (self.dim() > 2) { - batchCheckErrors(infos, "symeig_cpu"); - } else { - singleCheckErrors(infos[0], "symeig_cpu"); - } + at::_linalg_check_errors(infos, "symeig", self.dim() == 2); if (eigenvectors) { return std::tuple(eigvals, self_working_copy); } else { @@ -4059,94 +4087,6 @@ std::tuple legacy_lstsq_out( return std::tuple(B_out, A_out); } -Tensor _det_lu_based_helper_backward_helper( - const Tensor& det_grad, - const Tensor& det, - const Tensor& self, - const Tensor& lu, - const Tensor& pivs -) { - auto eps = at::native::_get_epsilon(c10::toRealValueType(self.scalar_type())); - auto n = self.size(-1); - auto eps_tensor = at::tensor(eps, self.options()); - auto condition_diagonal = [&](const Tensor& x) { - auto x_diag = x.diagonal(0, -2, -1); - auto x_diag_conditioned = at::where( - x_diag == 0.0, - eps_tensor, - x_diag - ); - x_diag.copy_(x_diag_conditioned); - }; - - // create a matrix d := (det_grad * det.conj()) I - // NOTE: we do not use the shorter version - // auto d = at::zeros_like(self); - // d.diagonal(0, -2, -1).copy_((det_grad * det.conj()).unsqueeze(-1)); - // to avoid in-place operations to eliminate potential issues with Vmap - auto det_expanded_sizes = det.sizes().vec(); - det_expanded_sizes.push_back(n); - auto d_diag = det_grad * det.conj(); - auto d = at::diag_embed(d_diag.unsqueeze(-1).expand(det_expanded_sizes)); - // make sure that d is Fortran-contiguous. The transposition is sufficient as d is a diagonal square matrix - d = d.mT(); - - if (self.device().type() == at::kCPU) { - // we want to condition the diagonal of the lu Tensor, but it is not allowed - // to modify arguments of backward functions in-place, hence the cloning. - auto lu_clone = lu.clone(); - condition_diagonal(lu_clone); - - auto trans = self.is_complex() ? TransposeType::ConjTranspose : TransposeType::Transpose; - - // d is modified in-place and will contain the result - lu_solve_stub(self.device().type(), lu_clone, pivs, d, trans); - return d; - } - // lu_solve is less stable than two triangular_solve for CUDA tensors. - else { - Tensor p, l, u; - std::tie(p, l, u) = at::lu_unpack(lu, pivs, /*unpack_data=*/true, /*unpack_pivots=*/true); - - if (self.is_complex()) { - // Tensors u_h and l_h should be physically conjugated prior to applying kernel stubs, - // as .conj() is lazy and will not materialize conjugated output. - l.conj_physical_(); - u.conj_physical_(); - } - - // triangular_solve_stub performs operations in-place. - // Tensor d will contain the result - condition_diagonal(u); - - // Solve u^h x = d - // note that d = c I for some scalar c, hence - // d u_h^{-1} = c I u_h^{-1} = u_h^{-1} c I = u_h^{-1} d. - // NOTE: u is contigious and upper-triangular, - // but from the Fortran respective it is lower-triangular and already transposed. - // Since u is conjugated in-place in the code above, it is sufficient - // to just run triangular_solve with upper=false. - triangular_solve_stub( - self.device().type(), u, d, - /*left=*/true, - /*upper=*/false, - /*transpose=*/TransposeType::NoTranspose, - /*unitriangular=*/false); - - // After this operation d will contain a row-wise permuted grad wrt to self - // The same notes as for the system involving u apply here. - triangular_solve_stub( - self.device().type(), l, d, - /*left=*/true, - /*upper=*/true, - /*transpose=*/TransposeType::NoTranspose, - /*unitriangular=*/true); - - // multiply by p to restore the row order - return at::matmul(p, d); - } -} - DEFINE_DISPATCH(ldl_factor_stub); TORCH_IMPL_FUNC(linalg_ldl_factor_ex_out) @@ -4232,6 +4172,39 @@ TORCH_IMPL_FUNC(linalg_ldl_solve_out) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) { + checkFloatingOrComplex(x, "linalg.vecdot"); + TORCH_CHECK(x.scalar_type() == y.scalar_type(), + "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ", + x.scalar_type(), " and y of type ", y.scalar_type(), " instead"); + // out checks + TORCH_CHECK(out.scalar_type() == x.scalar_type(), + "linalg.vecdot: Expected out of dtype", x.scalar_type(), + " but found ", out.scalar_type()); + checkSameDevice("linalg.vecdot", x, out); + + // Computes x^H y + if (x.dim() == 1 && y.dim() == 1) { + at::native::resize_output(out, {}); + return at::vdot_out(out, x, y); + } else { + return at::sum_out(out, x.conj() * y, /*dim=*/dim); + } +} + +Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) { + checkFloatingOrComplex(x, "linalg.vecdot"); + TORCH_CHECK(x.scalar_type() == y.scalar_type(), + "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ", + x.scalar_type(), " and y of type ", y.scalar_type(), " instead"); + // Computes x^H y + if (x.dim() == 1 && y.dim() == 1) { + return at::vdot(x, y); + } else { + return x.conj().mul(y).sum(/*dim=*/dim); + } +} + /* Solves the matrix equation AX = B for A triangular. 'left' If true solves AX = B, if false solves XA = B diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 54a8e9199228d..5b18dbe2d5fad 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -128,7 +128,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) } template -void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int64_t* info_ptr) { +void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int* info_ptr) { #if !AT_BUILD_WITH_LAPACK() TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); @@ -158,16 +158,14 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec // call lapackEig once to get the optimal size for work data scalar_t wkopt; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int info; lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info); + nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, info_ptr); int lwork = std::max(1, real_impl(wkopt)); // call again to do the actual work Tensor work = at::empty({lwork}, self.dtype()); lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork_data, &info); - *info_ptr = info; + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork_data, info_ptr); } #endif } @@ -200,12 +198,12 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector : Tensor(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t info; + auto infos = at::zeros({}, self.options().dtype(kInt)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{ - apply_eig(self_, eigenvectors, vals_, vecs_, &info); + apply_eig(self_, eigenvectors, vals_, vecs_, infos.data_ptr()); }); // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) - singleCheckErrors(info, "eig_cpu"); + at::_linalg_check_errors(infos, "eig", /*is_matrix*/true); return std::tuple(vals_, vecs_); } diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 0ce4336f04930..13593a3379498 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -9,6 +10,9 @@ #include #if AT_BUILD_WITH_BLAS() +#if C10_IOS +#include +#else extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc); extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc); extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); @@ -16,14 +20,11 @@ extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void #ifdef BLAS_HAS_SBGEMM extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, - const decltype(c10::impl::ScalarTypeToCPPType::t) *a, int *lda, - const decltype(c10::impl::ScalarTypeToCPPType::t) *b, int *ldb, + const at::BFloat16 *a, int *lda, + const at::BFloat16 *b, int *ldb, float *beta, float *c, int *ldc); #endif // BLAS_HAS_SBGEMM -#endif // AT_BUILD_WITH_BLAS() - -#if AT_BUILD_WITH_BLAS() extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy); extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy); extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy); @@ -33,7 +34,8 @@ extern "C" void daxpy_(int *n, double *a, const double *x, int *incx, double *y, extern "C" void saxpy_(int *n, float *a, const float *x, int *incx, float *y, int *incy); extern "C" void caxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy); extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy); -#endif // AT_BUILD_WITH_BLAS() +#endif // C10_IOS +#endif // AT_BUILD_WITH_BLAS #ifdef USE_FBGEMM #include @@ -97,6 +99,17 @@ fbgemm::matrix_op_t to_fbgemm(TransposeType trans) { } #endif // USE_FBGEMM +#if (AT_BUILD_WITH_BLAS() && C10_IOS) +CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) { + switch (trans) { + case TransposeType::Transpose: return CblasTrans; + case TransposeType::NoTranspose: return CblasNoTrans; + case TransposeType::ConjTranspose: return CblasConjTrans; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} +#endif + } // namespace (anonymous) DEFINE_DISPATCH(gemm_stub); @@ -113,8 +126,20 @@ void gemm( #if AT_BUILD_WITH_BLAS() if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; - char transa_ = to_blas(transa), transb_ = to_blas(transb); double alpha_ = alpha, beta_ = beta; + #if C10_IOS + CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); + CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); + cblas_dgemm(CblasColMajor, + transa_, transb_, + m_, n_, k_, + alpha_, + a, lda_, + b, ldb_, + beta_, + c, ldc_); + #else + char transa_ = to_blas(transa), transb_ = to_blas(transb); dgemm_( &transa_, &transb_, &m_, &n_, &k_, @@ -123,6 +148,7 @@ void gemm( b, &ldb_, &beta_, c, &ldc_); + #endif return; } #endif @@ -143,8 +169,20 @@ void gemm( #if AT_BUILD_WITH_BLAS() if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; - char transa_ = to_blas(transa), transb_ = to_blas(transb); float alpha_ = alpha, beta_ = beta; + #if C10_IOS + CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); + CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); + cblas_sgemm(CblasColMajor, + transa_, transb_, + m_, n_, k_, + alpha_, + a, lda_, + b, ldb_, + beta_, + c, ldc_); + #else + char transa_ = to_blas(transa), transb_ = to_blas(transb); sgemm_( &transa_, &transb_, &m_, &n_, &k_, @@ -153,6 +191,7 @@ void gemm( b, &ldb_, &beta_, c, &ldc_); + #endif return; } #endif @@ -173,8 +212,20 @@ void gemm( #if AT_BUILD_WITH_BLAS() if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; - char transa_ = to_blas(transa), transb_ = to_blas(transb); c10::complex alpha_ = alpha, beta_ = beta; + #if C10_IOS + CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); + CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); + cblas_zgemm(CblasColMajor, + transa_, transb_, + m_, n_, k_, + &alpha_, + a, lda_, + b, ldb_, + &beta_, + c, ldc_); + #else + char transa_ = to_blas(transa), transb_ = to_blas(transb); zgemm_( &transa_, &transb_, &m_, &n_, &k_, @@ -183,6 +234,7 @@ void gemm( b, &ldb_, &beta_, c, &ldc_); + #endif return; } #endif @@ -203,8 +255,20 @@ void gemm( #if AT_BUILD_WITH_BLAS() if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; - char transa_ = to_blas(transa), transb_ = to_blas(transb); c10::complex alpha_ = alpha, beta_ = beta; + #if C10_IOS + CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); + CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); + cblas_cgemm(CblasColMajor, + transa_, transb_, + m_, n_, k_, + &alpha_, + a, lda_, + b, ldb_, + &beta_, + c, ldc_); + #else + char transa_ = to_blas(transa), transb_ = to_blas(transb); cgemm_( &transa_, &transb_, &m_, &n_, &k_, @@ -213,6 +277,7 @@ void gemm( b, &ldb_, &beta_, c, &ldc_); + #endif return; } #endif @@ -224,19 +289,19 @@ void gemm( void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, - const decltype(c10::impl::ScalarTypeToCPPType::t) alpha, - const decltype(c10::impl::ScalarTypeToCPPType::t) *a, int64_t lda, - const decltype(c10::impl::ScalarTypeToCPPType::t) *b, int64_t ldb, - const decltype(c10::impl::ScalarTypeToCPPType::t) beta, - decltype(c10::impl::ScalarTypeToCPPType::t) *c, int64_t ldc) { + const float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + const float beta, + at::BFloat16 *c, int64_t ldc) { internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM) if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; char transa_ = to_blas(transa), transb_ = to_blas(transb); - // alpha and beta and C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back. - float alpha_ = (float) alpha, beta_ = (float) beta; + float alpha_ = alpha, beta_ = beta; int c_size = n_ * ldc_; + // C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back. std::vector float_v(c, c + c_size); sbgemm_(&transa_, &transb_, &m_, &n_, &k_, @@ -246,10 +311,15 @@ void gemm( &beta_, float_v.data(), &ldc_); for (auto cv: float_v) { - *(c++) = static_cast<_bfloat16_t>(cv); + *(c++) = c10::convert(cv); } return; } +#endif +#if AT_MKLDNN_ENABLED() + if (mkldnn_bf16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { + return; + } #endif gemm_stub( at::kCPU, at::kBFloat16, @@ -461,7 +531,11 @@ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_daxpy(i_n, a, x, i_incx, y, i_incy); + #else daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -482,7 +556,11 @@ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t in int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_saxpy(i_n, a, x, i_incx, y, i_incy); + #else saxpy_(&i_n, &a, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -503,7 +581,11 @@ void axpy(int64_t n, c10::complex a, const c10::complex *x, int6 int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy); + #else zaxpy_(&i_n, &a, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -524,7 +606,11 @@ void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_ int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_caxpy(i_n, &a, x, i_incx, y, i_incy); + #else caxpy_(&i_n, &a, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -546,7 +632,11 @@ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) { int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_dcopy(i_n, x, i_incx, y, i_incy); + #else dcopy_(&i_n, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -566,7 +656,11 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) { int i_n = (int)n; int i_incx = (int)incx; int i_incy = (int)incy; + #if C10_IOS + cblas_scopy(i_n, x, i_incx, y, i_incy); + #else scopy_(&i_n, x, &i_incx, y, &i_incy); + #endif return; } #endif @@ -586,7 +680,11 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *x, int64_t incx, c10::complex #include #include #include @@ -33,10 +34,10 @@ template void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, - scalar_t alpha, + at::opmath_type alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + at::opmath_type beta, scalar_t *c, int64_t ldc) { internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); gemm_stub( @@ -62,17 +63,14 @@ void gemm( float beta, float *c, int64_t ldc); -#ifdef BLAS_HAS_SBGEMM -using _bfloat16_t = decltype(c10::impl::ScalarTypeToCPPType::t); void gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, - _bfloat16_t alpha, - const _bfloat16_t *a, int64_t lda, - const _bfloat16_t *b, int64_t ldb, - _bfloat16_t beta, - _bfloat16_t *c, int64_t ldc); -#endif // BLAS_HAS_SBGEMM + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc); void gemm( TransposeType transa, TransposeType transb, diff --git a/aten/src/ATen/native/CPUFallback.h b/aten/src/ATen/native/CPUFallback.h index b8c54b0177486..91f1f08c11845 100644 --- a/aten/src/ATen/native/CPUFallback.h +++ b/aten/src/ATen/native/CPUFallback.h @@ -32,8 +32,7 @@ struct _call_fallback_fn final { //.findSchemaOrThrow("a", "b") .typed(); return c10::impl::BoxedKernelWrapper::call( - c10::KernelFunction::make_boxed_function, - nullptr, + c10::BoxedKernel::makeFromFunction(), op, c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset. //std::forward(args...) diff --git a/aten/src/ATen/native/ChanelShuffle.cpp b/aten/src/ATen/native/ChanelShuffle.cpp index 0eb109c42e509..7def359e7056a 100644 --- a/aten/src/ATen/native/ChanelShuffle.cpp +++ b/aten/src/ATen/native/ChanelShuffle.cpp @@ -38,11 +38,12 @@ Tensor channel_shuffle(const Tensor& self, int64_t groups) { #if defined(C10_MOBILE) && defined(USE_XNNPACK) if (self.is_contiguous(MemoryFormat::ChannelsLast) && xnnpack::use_channel_shuffle(self, groups)) { - return xnnpack::channel_shuffle(self, groups); + auto output = self.numel() == 0 ? self : xnnpack::channel_shuffle(self, groups); + return output; } #endif - auto output = at::native_channel_shuffle(self, groups); + auto output = self.numel() == 0 ? self : at::native_channel_shuffle(self, groups); return namedinference::propagate_names_if_nonempty( output, self.has_names() ? self.names() : at::ArrayRef{}); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index a6127a53577b9..9f2d8efbd6181 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -194,8 +194,10 @@ auto ConvParams::use_cudnn(const at::Tensor& input, const at::Tensor& weight) co if (!input.is_cuda() || !cudnn_enabled) { return false; } - if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { - return at::native::cudnnv8_enabled_check_debug(); + if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { + if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { + return false; + } } if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { // bypass dilation checks for channels_last convolution @@ -875,7 +877,7 @@ static Tensor convolution_same( if (symmetric_padding) { // All backends handle symmetric padding natively DimVector output_padding(static_cast(dim)); - return native::convolution(input, weight, bias, stride, padding_l, dilation, + return at::convolution(input, weight, bias, stride, padding_l, dilation, false, output_padding, groups); } @@ -913,7 +915,7 @@ Tensor _convolution_mode( } else if (padding == "valid") { // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) const int64_t padding_[] = {0}; - return at::native::convolution( + return at::convolution( input, weight, bias, stride, padding_, dilation, false, padding_, groups); } TORCH_CHECK(false, "Invalid padding string: '", padding, "'"); @@ -1067,7 +1069,7 @@ ConvBackend select_conv_backend( // Expand 1d -> 2d. // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { // avoid accidentally going through NHWC for permuted 3d input. input = input.contiguous(); params.view1d_as_2d(); @@ -1309,7 +1311,7 @@ at::Tensor _convolution( // Expand 1d -> 2d. // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { // avoid accidentally going through NHWC for permuted 3d input. input = input.contiguous(); params.view1d_as_2d(); @@ -1480,7 +1482,7 @@ at::Tensor _convolution( break; } - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { output = view3d(output); } @@ -1504,7 +1506,7 @@ std::tuple convolution_backward_overrideable( const Tensor& grad_output, const Tensor& input, const Tensor& weight, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups, std::array output_mask) { - AT_ERROR("You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); + TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); return std::tuple( at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT), at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT), @@ -1815,7 +1817,7 @@ std::tuple convolution_backward( // Expand 1d -> 2d. // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { // avoid accidentally going through NHWC for permuted 3d input. input = input.contiguous(); params.view1d_as_2d(); @@ -2031,12 +2033,12 @@ std::tuple convolution_backward( // Convert 2D inputs back to 1D for backends that don't natively support 1D // spatial inputs. if (output_mask[0]) { - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { backend_grad_input = view3d(backend_grad_input); } } if (output_mask[1]) { - if (k == 3 && !input.is_mkldnn()) { + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { backend_grad_weight = view3d(backend_grad_weight); } } diff --git a/aten/src/ATen/native/ConvolutionMM2d.cpp b/aten/src/ATen/native/ConvolutionMM2d.cpp index 1837a0d838ea0..d93166a1e343d 100644 --- a/aten/src/ATen/native/ConvolutionMM2d.cpp +++ b/aten/src/ATen/native/ConvolutionMM2d.cpp @@ -19,7 +19,8 @@ static Tensor compute_columns2d( const Tensor& input, IntArrayRef padding, IntArrayRef stride, - IntArrayRef kernel_size) { + IntArrayRef kernel_size, + bool is_channels_last) { const int64_t kernel_height = kernel_size[0]; const int64_t kernel_width = kernel_size[1]; const int64_t pad_height = padding[0]; @@ -33,8 +34,6 @@ static Tensor compute_columns2d( const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1; const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1; - bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast; - Tensor columns; if ((kernel_height == 1) && (stride_height == 1) && (pad_height == 0) && (kernel_width == 1) && (stride_width == 1) && (pad_width == 0)) { @@ -504,7 +503,7 @@ static void slow_conv2d_backward_weight_out_cpu_template( true); auto grad_output = grad_output_.contiguous(memory_format); - Tensor finput = compute_columns2d(input, padding, stride, kernel_size); + Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last); const int64_t batch_size = input.size(0); @@ -571,7 +570,7 @@ Tensor& slow_conv2d_forward_out_cpu( const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1; const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1; - Tensor finput = compute_columns2d(input, padding, stride, kernel_size); + Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last); output.resize_({batch_size, n_output_plane, output_height, output_width}, memory_format); if (bias.defined()) { output.copy_(bias.reshape({-1, 1, 1})); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 46c0d48d8a7b2..d4b5c74c3bf34 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -261,13 +261,18 @@ Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) { // That might be fine for functorch (which already doesn't preserve strides in vmap), // but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA. auto intermediate = src.to(self, non_blocking); - // Unfortunately, copy()'s decomposition involves view ops. - // To preserve the functionalization pass semantics of "maybe reapply views", - // we need to manually do that here. - if (at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) { - return intermediate.expand(self.sizes()); - } else { + // We can't use expand() here. Why? + // The contract for copy_() is that the output tensor has the same amount of storage as the original tensor. + // e.g. This should work: + // a = torch.ones(4, 4) + // b = torch.ones(1, 4) + // c = torch.ones(4, 4) + // torch.ops.aten.copy(a, b).add_(c) + // We don't want to emit an extra copy every time though, so we only do it if the shapes are different. + if (self.sizes() != intermediate.sizes()) { return at::expand_copy(intermediate, self.sizes()); + } else { + return intermediate; } } diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp index 66bfce55646e1..4b3e43da1147b 100644 --- a/aten/src/ATen/native/Cross.cpp +++ b/aten/src/ATen/native/Cross.cpp @@ -27,7 +27,7 @@ namespace native { DEFINE_DISPATCH(cross_stub); -int64_t _default_cross_dim(const c10::optional &dimension, IntArrayRef sizes) { +int64_t _default_cross_dim(const c10::optional &dimension, SymIntArrayRef sizes) { // If dimension is not given, it defaults to the first dimension found with the size 3. // Note that this behaviour might be unexpected. // _default_cross_dim is called internally inside the cross implementation to calculate @@ -45,15 +45,16 @@ int64_t _default_cross_dim(const c10::optional &dimension, IntArrayRef } Tensor cross(const Tensor & input, const Tensor & other, const c10::optional dimension) { - auto dim = _default_cross_dim(dimension, input.sizes()); + auto dim = _default_cross_dim(dimension, input.sym_sizes()); return at::linalg_cross(input, other, dim); } Tensor & cross_out(const Tensor & input, const Tensor & other, const c10::optional dimension, Tensor & out) { - auto dim = _default_cross_dim(dimension, input.sizes()); + auto dim = _default_cross_dim(dimension, input.sym_sizes()); return at::linalg_cross_out(out, input, other, dim); } + TORCH_IMPL_FUNC(linalg_cross_out) (const Tensor & input, const Tensor & other, const int64_t dim, const Tensor & out) { auto out_size = infer_size(input.sizes(), other.sizes()); diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 327d7def16d74..a91448c3da72e 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -39,7 +39,9 @@ static CPUCapability compute_cpu_capability() { #if !defined(__powerpc__) && !defined(__s390x__) if (cpuinfo_initialize()) { -#ifdef HAVE_AVX512_CPU_DEFINITION + // AVX512 can be slower then AVX2, so lets keep it as opt-in + // see https://github.com/pytorch/pytorch/issues/80252 +#if defined(HAVE_AVX512_CPU_DEFINITION) && false // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in // versions 9 & beyond. So, we want to ensure that only releases built with // supported compilers on supported hardware return CPU Capability AVX512, @@ -120,6 +122,12 @@ void* DispatchStubImpl::get_call_ptr( TORCH_INTERNAL_ASSERT(hip_dispatch_ptr, "DispatchStub: missing HIP kernel"); return hip_dispatch_ptr; +#if defined(USE_MPS) + case DeviceType::MPS: + TORCH_INTERNAL_ASSERT(mps_dispatch_ptr, "DispatchStub: missing MPS kernel"); + return mps_dispatch_ptr; +#endif + default: AT_ERROR("DispatchStub: unsupported device type", device_type); } diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index bd04b4df9a95f..6e71b5bb5881b 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -118,6 +118,7 @@ struct TORCH_API DispatchStubImpl { std::atomic cpu_dispatch_ptr{nullptr}; void* cuda_dispatch_ptr = nullptr; void* hip_dispatch_ptr = nullptr; + void* mps_dispatch_ptr = nullptr; #endif }; @@ -165,6 +166,10 @@ struct DispatchStub { impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); } + void set_mps_dispatch_ptr(FnPtr fn_ptr) { + impl.mps_dispatch_ptr = reinterpret_cast(fn_ptr); + } + static TORCH_API FnPtr DEFAULT; #ifdef HAVE_AVX512_CPU_DEFINITION static TORCH_API FnPtr AVX512; @@ -190,6 +195,13 @@ struct RegisterCUDADispatch { } }; +template +struct RegisterMPSDispatch { + RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { + stub.set_mps_dispatch_ptr(value); + } +}; + template struct RegisterHIPDispatch { RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { @@ -259,6 +271,9 @@ struct RegisterHIPDispatch { #define REGISTER_HIP_DISPATCH(name, fn) \ static RegisterHIPDispatch name ## __register(name, fn); +#define REGISTER_MPS_DISPATCH(name, fn) \ + static RegisterMPSDispatch name ## __register(name, fn); + // NB: This macro must be used in an actual 'cu' file; if you try using // it from a 'cpp' file it will not work! #if defined(__CUDACC__) @@ -268,6 +283,9 @@ struct RegisterHIPDispatch { // is HIP in the PyTorch HIPify build. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) +#elif defined(__OBJC__) && defined(USE_MPS) +// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel +#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn) #elif defined(CPU_CAPABILITY) #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #define REGISTER_NO_AVX512_DISPATCH(name) \ diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index f7795bc9bedf1..15e2be8c8f271 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -204,7 +204,6 @@ Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::opt at::native::resize_output(output, shape); normal_impl_(output, 0, 1, gen); // CUDA NB: addcmul_out copies the tensor to be added into the output. - // Please look at aten/src/THC/generic/THCTensorMathPointwise.cu // The previous function here was addcmul_out(output, mean_tensor, output, std, 1); // The third argument is not a constant reference and hence the samples in output are overwritten. // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std @@ -219,7 +218,6 @@ Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c at::native::resize_output(output, shape); normal_impl_(output, 0, 1, gen); // CUDA NB: addcmul_out copies the tensor to be added into the output. - // Please look at aten/src/THC/generic/THCTensorMathPointwise.cu // The previous function here was addcmul_out(output, mean, output, std, 1); // The third argument is not a constant reference and hence the samples in output are overwritten. // Consequently, the computation performed is mean + mean * std instead of mean + output * std diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index b23a18a8376a6..962c010614422 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -328,6 +328,11 @@ Tensor normal_meta(const Tensor& mean, const Tensor& std, c10::optional(mean, std, gen); } +// functional variant, only used by the functionalization pass. +Tensor normal_functional(const Tensor& self, double mean, double std, c10::optional generator) { + return self.clone().normal_(mean, std, generator); +} + // ==================================================== Random ======================================================== template diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 6d8cea26f52e5..17094bf9082d5 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -1271,11 +1272,18 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_, Tensor offset2bag_; if (indices.numel() != 0 && offset2bag.numel() == 0) { - offset2bag_ = at::zeros( - {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] + offset2bag_ = offsets.new_zeros( + {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] make_offset2bag(offsets, offset2bag_); - offset2bag_.resize_({indices.size(0)}); + // For Composite Compliance, if `offset2bag_` is CCT + // then we can't call `resize_`. Instead we call `narrow` + // to slice the tensor. + if (isTensorSubclassLike(offset2bag_)) { + offset2bag_ = offset2bag_.narrow(0, 0, indices.size(0)); + } else { + offset2bag_.resize_({indices.size(0)}); + } } else { auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index e1aa6f41319ba..7abbc1e333441 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -40,7 +40,7 @@ namespace detail { " for ", isComplexType(self_dtype) ? "complex" : "real", " inputs, but got ", dtype); TORCH_CHECK(promoteTypes(self_dtype, dtype) == dtype, name, ": the dtype of the input ", "(", self_dtype, ") should be convertible ", - "without narrowing to the specified dtype (", dtype, ")."); + "without narrowing to the specified dtype (", dtype, ")"); } } } @@ -89,13 +89,14 @@ TORCH_META_FUNC(linalg_vector_norm)(const Tensor& self, const Scalar& scalar_ord // - We cannot reduce the whole tensor // - We cannot reduce over an empty dimension if (self.numel() == 0 && (ord < 0. || ord == INFINITY)) { - TORCH_CHECK(opt_dim.has_value(), + // dim=None or dim=() reduces the whole tensor + TORCH_CHECK(opt_dim.has_value() && opt_dim->size() != 0, "linalg.vector_norm cannot compute the ", scalar_ord, " norm on an empty ", "tensor because the operation does not have an identity"); for (auto dim_num : dim) { TORCH_CHECK(self.size(dim_num) != 0, - "linalg.vector_norm cannot compute the ", scalar_ord, " norm on an empty ", - "dimension because the operation does not have an identity"); + "linalg.vector_norm cannot compute the ", scalar_ord, " norm on the dimension ", dim_num , + "because this dimension is empty and the operation does not have an identity"); } } @@ -109,6 +110,47 @@ TORCH_META_FUNC(linalg_vector_norm)(const Tensor& self, const Scalar& scalar_ord set_output_raw_strided(0, shape, {}, options); } +TORCH_META_FUNC(_linalg_det)(const Tensor& A) { + at::native::squareCheckInputs(A, "linalg.det"); + at::native::checkFloatingOrComplex(A, "linalg.det"); + + auto shape = A.sizes(); + auto ndim = shape.size(); + + // det + set_output_contiguous(0, shape.slice(0, ndim - 2), A.options()); + + // LU + auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true); + set_output_strided(1, shape, LU_strides, A.options()); + + // pivots + set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt)); +} + +TORCH_META_FUNC(_linalg_slogdet)(const Tensor& A) { + at::native::squareCheckInputs(A, "linalg.slogdet"); + at::native::checkFloatingOrComplex(A, "linalg.slogdet", /*low_precision*/false); + + auto shape= A.sizes(); + auto ndim = shape.size(); + + auto shape_outputs = shape.slice(0, ndim - 2); + + // sign + set_output_contiguous(0, shape_outputs, A.options()); + + // logabsdet + set_output_contiguous(1, shape_outputs, A.options().dtype(toRealValueType(A.scalar_type()))); + + // LU + auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true); + set_output_strided(2, shape, LU_strides, A.options()); + + // pivots + set_output_contiguous(3, shape.slice(0, ndim - 1), A.options().dtype(kInt)); +} + template void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const c10::optional& self_baddbmm = nullopt) { TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); @@ -170,28 +212,42 @@ namespace native { DEFINE_DISPATCH(addr_stub); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.det ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // As P is a permutation matrix // det(P) = 1 if it's an even permutation and det(P) = -1 if it's an odd permutation -static inline Tensor _lu_det_P(const Tensor& lu, const Tensor& pivs) { - const auto n = lu.size(-1); - auto det_P = (at::arange(1, n + 1, pivs.options()) != pivs) +Tensor lu_det_P(const Tensor& pivots) { + return (at::arange(1, pivots.size(-1) + 1, pivots.options()) != pivots) .sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong) .fmod_(2) // take 0 to 1 and 1 to -1 .mul_(-2) .add_(1); - return det_P; } -// Given a pivoted LU factorization A = P L U, -// det(A) = det(P) * det(L) * det(U). -std::tuple _det_lu_based_helper(const Tensor& self) { - Tensor pivs, lu; - std::tie(lu, pivs, std::ignore) = at::linalg_lu_factor_ex(self); - const auto det_P = _lu_det_P(lu, pivs); - auto det = det_P * at::prod(lu.diagonal(0, -2 ,-1), /*dim=*/-1); +// Auxiliary function that returns the LU decomposition to use it in the backward +TORCH_IMPL_FUNC(_linalg_det_out)(const Tensor& A, const Tensor& result, const Tensor& LU, const Tensor& pivots) { + // info is an aux tensor + auto info = at::empty({0}, A.options().dtype(kInt)); + // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies + // Use the transpose of if A is contiguous since det(A^T) = det(A) + // We limit this to real matrices, but it could also be implemented for complex matrices + at::linalg_lu_factor_ex_out(const_cast(LU), const_cast(pivots), const_cast(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A); + + // det = det_P * prod(diag(LU)) + at::mul_out(const_cast(result), lu_det_P(pivots), at::prod(LU.diagonal(0, -2 ,-1), /*dim=*/-1)); +} - return std::make_tuple(std::move(det), std::move(lu), std::move(pivs)); +Tensor linalg_det(const Tensor& A) { + return std::get<0>(at::_linalg_det(A)); +} + +Tensor& linalg_det_out(const Tensor& A, Tensor& result) { + auto LU = at::empty({0}, A.options()); + auto pivots = at::empty({0}, A.options().dtype(kInt)); + at::_linalg_det_out(result, LU, pivots, A); + return result; } // torch.det, alias for torch.linalg.det @@ -199,90 +255,59 @@ Tensor det(const Tensor& self) { return at::linalg_det(self); } -Tensor linalg_det(const Tensor& self) { - squareCheckInputs(self, "linalg.det"); - checkFloatingOrComplex(self, "linalg.det"); +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.slogdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - return std::get<0>(at::_det_lu_based_helper(self)); -} +// Auxiliary function that returns the LU decomposition to use it in the backward +TORCH_IMPL_FUNC(_linalg_slogdet_out)(const Tensor& A, const Tensor& sign, const Tensor& logabsdet, const Tensor& LU, const Tensor& pivots) { + // info is an aux tensor + auto info = at::empty({0}, A.options().dtype(kInt)); + // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies + // Use the transpose of if A is contiguous since det(A^T) = det(A) + // We limit this to real matrices, but it could also be implemented for complex matrices + at::linalg_lu_factor_ex_out(const_cast(LU), const_cast(pivots), const_cast(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A); -Tensor& linalg_det_out(const Tensor& self, Tensor& out) { - checkSameDevice("torch.linalg.det", out, self, "out"); - checkLinalgCompatibleDtype("torch.linalg.det", out, self, "out"); + auto diag_U = LU.diagonal(0, -2, -1); + // sign + at::mul_out(const_cast(sign), diag_U.sgn().prod(-1), lu_det_P(pivots)); - IntArrayRef out_sizes(self.sizes().data(), self.dim() - 2); - at::native::resize_output(out, out_sizes); + // logabsdet + at::sum_out(const_cast(logabsdet), diag_U.abs().log_(), -1); +} - auto det = at::native::linalg_det(self); - out.copy_(det); - return out; +std::tuple linalg_slogdet(const Tensor& A) { + auto out = at::_linalg_slogdet(A); + return std::make_tuple(std::move(std::get<0>(out)), std::move(std::get<1>(out))); } -Tensor logdet(const Tensor& self) { - squareCheckInputs(self, "logdet"); - checkFloatingOrComplex(self, "logdet"); - - Tensor pivs, lu; - std::tie(lu, pivs, std::ignore) = at::linalg_lu_factor_ex(self); - const auto det_P = _lu_det_P(lu, pivs); - const auto diag_U = lu.diagonal(0, -2 ,-1); - const auto det_sign = diag_U.sign().prod(-1).mul_(det_P); - - // If det_sign > 0, diag_U.abs_().log_().sum(-1) gives logdet (this means U is not singular). - // If det_sign <= 0, then we get proper nan (when det < 0, i.e., det_sign) or -inf (when det = 0, i.e., U is singular). - // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)]. - Tensor logdet_vals = diag_U.abs_().log_().sum(-1); - if (self.dim() > 2) { - auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy()); - // NOLINTNEXTLINE(performance-move-const-arg) - logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options())); - } else if (det_sign.item() < 0) { - logdet_vals.fill_(NAN); - } - return logdet_vals; -} - -std::tuple linalg_slogdet(const Tensor& self) { - squareCheckInputs(self, "linalg.slogdet"); - ScalarType t = self.scalar_type(); - TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble, - "linalg.slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t); - - Tensor pivs, lu; - std::tie(lu, pivs, std::ignore) = at::linalg_lu_factor_ex(self); - const auto det_P = _lu_det_P(lu, pivs); - const auto diag_U = lu.diagonal(0, -2 ,-1); - const auto det_sign = diag_U.sgn().prod(-1).mul_(det_P); - // abslogdet_val is -inf if U is singular, in which case diag_U.abs_().log_().sum(-1) will return -inf. - // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)]. - // Since abslogdet_val cannot take nan, no special case handling is required. - // in-place abs is not supported for complex tensors - auto abslogdet_val = isComplexType(t) ? diag_U.abs().log_().sum(-1) : diag_U.abs_().log_().sum(-1); - return std::make_tuple(det_sign, abslogdet_val); +std::tuple linalg_slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) { + auto LU = at::empty({0}, A.options()); + auto pivots = at::empty({0}, A.options().dtype(kInt)); + at::_linalg_slogdet_out(sign, logabsdet, LU, pivots, A); + return std::tie(sign, logabsdet); } -// TODO: implement _out variant avoiding copy and using already allocated storage directly -std::tuple linalg_slogdet_out(const Tensor& input, Tensor& sign, Tensor& logabsdet) { - checkSameDevice("linalg.slogdet", sign, input, "sign"); - checkSameDevice("linalg.slogdet", logabsdet, input, "logabsdet"); - checkLinalgCompatibleDtype("linalg.slogdet", sign, input, "sign"); - ScalarType real_dtype = toRealValueType(input.scalar_type()); - // logabsdet is always real-valued here - checkLinalgCompatibleDtype("linalg.slogdet", logabsdet.scalar_type(), real_dtype, "logabsdet"); +// Alias +std::tuple slogdet(const Tensor& A) { + return at::linalg_slogdet(A); +} - Tensor sign_tmp, logabsdet_tmp; - std::tie(sign_tmp, logabsdet_tmp) = at::linalg_slogdet(input); +std::tuple slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) { + return at::linalg_slogdet_out(sign, logabsdet, A); +} - at::native::resize_output(sign, sign_tmp.sizes()); - sign.copy_(sign_tmp); - at::native::resize_output(logabsdet, logabsdet_tmp.sizes()); - logabsdet.copy_(logabsdet_tmp); +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ logdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - return std::tuple(sign, logabsdet); -} +Tensor logdet(const Tensor& A) { + squareCheckInputs(A, "logdet"); + checkFloatingOrComplex(A, "logdet", /*low_precision*/false); + Tensor sign, logabsdet; + std::tie(sign, logabsdet) = at::linalg_slogdet(A); -std::tuple slogdet(const Tensor& self) { - return at::linalg_slogdet(self); + if (A.is_complex()) { + return sign.log() + logabsdet; + } else { + return at::where(sign == -1., NAN, logabsdet); + } } namespace { @@ -1227,11 +1252,6 @@ static void addmm_impl_cpu_( result.copy_(self); } - if (use_mkldnn_bf16_matmul(m1, m2, result)){ - mkldnn_matmul(m1, m2, result, beta.to(), alpha.to()); - return; - } - bool transpose_c = false; Tensor c; @@ -1302,14 +1322,15 @@ static void addmm_impl_cpu_( AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, result.scalar_type(), "addmm_impl_cpu_", [&]{ + using opmath_t = at::opmath_type; at::native::cpublas::gemm( transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose, transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose, m, n, k, - alpha.to(), + alpha.to(), a.data_ptr(), lda, b.data_ptr(), ldb, - beta.to(), + beta.to(), c.data_ptr(), ldc); }); @@ -1739,7 +1760,8 @@ Tensor _matmul_impl( } else if (dim_tensor1 == 2 && dim_tensor2 == 1) { return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2); } else if (dim_tensor1 == 1 && dim_tensor2 == 2) { - return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1); + return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0) + : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0); } else if (dim_tensor1 == 2 && dim_tensor2 == 2) { return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2); } else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) { @@ -1859,8 +1881,22 @@ Tensor _matmul_impl( Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) { auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2); - at::Tensor unused; - auto result = at::native::_matmul_impl(unused, tensor1, tensor2); + at::Tensor result, unused; + // Note [is_nested check] + // We have 2 choices to support nested tensor matmul: + // 1. intercept here by is_nested check + // 2. add nested tensor dispatch key + // Although 1. is gross, we still choose 1. because we hesitate about 2.: + // * We tried 2. for reshape and it caused a weird autograd bug + // (see comment in reshape in TensorShape.cpp) + // * but 2. for linear works? + // TODO: use 2. after we make sure it is fine + if (tensor1.is_nested() || tensor2.is_nested()) { + result = at::_NestedTensor_GeneralizedBMM(tensor1, tensor2); + } + else { + result = at::native::_matmul_impl(unused, tensor1, tensor2); + } namedinference::propagate_names_if_nonempty(result, maybe_outnames); return result; } @@ -1998,6 +2034,19 @@ inline Tensor _blob_to_Tensor( return _move_memory_if_cuda_input(tensor, in); } +template +inline Tensor _linear_combination( + const Tensor& t, + std::initializer_list blob) { + // _blob_to_Tensor converts blob to a 2D tensor for _compute_linear_combination. + // If this tensor is of shape (1, *), the result of _compute_linear_combination + // is going to be of shape (1, *t.shape) so we squeeze(0) so that + // for any t with t.dim() >= 1: t.dim() == _compute_linear_combination(t, ...).dim(). + return at::native::_compute_linear_combination( + t, _blob_to_Tensor(blob, t)) + .squeeze(0); +} + // I + A Tensor compute_T1(const Tensor& A) { // 2 for {I, A} @@ -2029,15 +2078,15 @@ Tensor compute_T4(const Tensor& A) { // contains A^2 As.select(0, 2), // computes (I / 2 + A / 6 + A^2 / 24) - at::native::_compute_linear_combination( + _linear_combination( As.narrow(0, 0, 3), - _blob_to_Tensor({1 / 2.0, 1 / 6.0, 1 / 24.0}, A) + {1 / 2.0, 1 / 6.0, 1 / 24.0} ) ); // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24) - return at::native::_compute_linear_combination( - As, _blob_to_Tensor({1.0, 1.0, 0.0, 1.0}, A) + return _linear_combination( + As, {1.0, 1.0, 0.0, 1.0} ); } @@ -2064,10 +2113,10 @@ Tensor compute_T8(const Tensor& A) { view_out, // As.select(0, 2) = A^2 As.select(0, 2), - at::native::_compute_linear_combination( + _linear_combination( // extract {A, A^2} from As As.narrow(0, 1, 2), - _blob_to_Tensor({x1, x2}, A) + {x1, x2} ) ); @@ -2077,20 +2126,19 @@ Tensor compute_T8(const Tensor& A) { _matmul_impl( view_out, // x3 * A2 + A4 - at::native::_compute_linear_combination( + _linear_combination( As.narrow(0, 2, 2), - _blob_to_Tensor({x3, 1.0}, A) + {x3, 1.0} ), - at::native::_compute_linear_combination( + _linear_combination( As.narrow(0, 0, 4), - _blob_to_Tensor({x4, x5, x6, x7}, A) + {x4, x5, x6, x7} ) ); // return I + A + y2 * A2 + A8; - return at::native::_compute_linear_combination( - As, - _blob_to_Tensor({1.0, 1.0, y2, 0.0, 1.0}, A) + return _linear_combination( + As, {1.0, 1.0, y2, 0.0, 1.0} ); } @@ -2460,14 +2508,17 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar norm_stub(iter.device_type(), iter, ord); } -void _linalg_matrix_norm_checks(const Tensor& A, IntArrayRef dim, optional opt_dtype) { - at::native::checkFloatingOrComplex(A, "linalg.matrix_norm"); - TORCH_CHECK(A.dim() >= 2, - "linalg.matrix_norm: input tensor must be a matrix or a batch of matrices"); +void _linalg_matrix_norm_checks(const Tensor& A, std::vector& dim, optional opt_dtype, bool low_precision) { + // A + at::native::checkIsMatrix(A, "linalg.matrix_norm"); + at::native::checkFloatingOrComplex(A, "linalg.matrix_norm", /*low_precision*/low_precision); // dim - TORCH_CHECK(dim.size() == 2, "linalg.matrix_norm: dim must be a 2-tuple of ints"); - TORCH_CHECK(dim[0] != dim[1], "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead"); + TORCH_CHECK(dim.size() == 2, "linalg.matrix_norm: dim must be a 2-tuple. Got ", dim); + // wrap first to identify weird scenarios like A.ndim = 2, dim = (1, -1) + // dim is modified in place while wrapping it + maybe_wrap_dims(dim, A.dim()); + TORCH_CHECK(dim[0] != dim[1], "linalg.matrix_norm: dims must be different. Got (", dim[0], ", ", dim[1], ")"); // dtype at::detail::check_linalg_norm_dtype(opt_dtype, A.scalar_type(), "linalg.matrix_norm"); @@ -2479,14 +2530,14 @@ Tensor linalg_matrix_norm( IntArrayRef dim, bool keepdim, optional opt_dtype) { - _linalg_matrix_norm_checks(A, dim, opt_dtype); - + // Check ord first as it will be used in the dtype check of A auto ord = scalar_ord.toDouble(); auto abs_ord = std::abs(ord); TORCH_CHECK(abs_ord == 2. || abs_ord == 1. || abs_ord == INFINITY, "linalg.matrix_norm: Order ", ord, " not supported."); auto dim_ = dim.vec(); - maybe_wrap_dims(dim_, A.dim()); + // Check A, dim, and dtype + _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/abs_ord != 2.); auto max_min = [ord, keepdim](const Tensor& A, int64_t dim) { return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); }; if (abs_ord == 2.) { @@ -2539,16 +2590,18 @@ Tensor linalg_matrix_norm( IntArrayRef dim, bool keepdim, optional opt_dtype) { - _linalg_matrix_norm_checks(A, dim, opt_dtype); + // Check ord first as it will be used in the dtype check of A TORCH_CHECK(ord == "fro" || ord == "nuc", "linalg.matrix_norm: Order ", ord, " not supported."); - auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A; + auto dim_ = dim.vec(); + // Check A, dim, and dtype + _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/ord != "nuc"); if (ord == "fro") { - return at::linalg_vector_norm(A_, 2, dim, keepdim); + return at::linalg_vector_norm(A, 2, dim_, keepdim, opt_dtype); } else { // nuc - auto dim_ = dim.vec(); - maybe_wrap_dims(dim_, A_.dim()); + auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A; + // Move dims to the end auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A_.dim()); auto result = at::linalg_svdvals(A_.permute(permutation)).sum(-1, keepdim); diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 9f6f210859274..cbeb49fe81c6e 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -294,88 +294,16 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, c " but each b matrix is ", self.size(-2), " by ", self.size(-1)); } -static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name) { - TORCH_CHECK((at::isFloatingType(t.scalar_type()) || at::isComplexType(t.scalar_type())), - f_name, ": Expected a floating point or complex tensor as input. Got ", toString(t.scalar_type())); -} - -/* - * Given a info int, obtained after a single operation, this function check if the computation - * has been successful (info = 0) or not, and report in case of the latter. - */ -static inline void singleCheckErrors(int64_t info, const c10::string_view name, int64_t batch_id=-1) { - std::string batch_string{""}; - if (batch_id >= 0) { - batch_string = ": (Batch element " + std::to_string(batch_id) + ")"; - } - if (info < 0) { - // Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values - // Previously, it would return `info` > 0, but now it returns `info` = -4 - // OpenBLAS 0.3.15+ uses the Reference LAPACK 3.10+. - // MKL 2022.0+ uses the Reference LAPACK 3.10+. - // Older version of MKL and OpenBLAS follow the old behavior (return `info` > 0). - // Here we check for the case where `info` is -4 and raise an error - if (name.find("svd") != name.npos) { - TORCH_CHECK_LINALG(info != -4, name, batch_string, - ": The algorithm failed to converge because the input matrix contained non-finite values."); - } - TORCH_INTERNAL_ASSERT(false, name, batch_string, - ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library."); - } else if (info > 0) { - if (name.find("inv") != name.npos) { - // inv, inverse, cholesky_inverse, etc. - TORCH_CHECK_LINALG(false, name, batch_string, - ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular."); - } else if (name.find("solve") != name.npos) { - // solve, linalg_solve, cholesky_solve, etc. - TORCH_CHECK_LINALG(false, name, batch_string, - ": The solver failed because the input matrix is singular."); - } else if (name.find("cholesky") != name.npos) { - TORCH_CHECK_LINALG(false, name, batch_string, - ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite)."); - } else if (name.find("svd") != name.npos) { - TORCH_CHECK_LINALG(false, name, batch_string, - ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ")."); - } else if (name.find("eig") != name.npos || name.find("syevd") != name.npos) { - TORCH_CHECK_LINALG(false, name, batch_string, - ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ")."); - } else if (name.find("lstsq") != name.npos) { - TORCH_CHECK_LINALG(false, name, batch_string, - ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ")."); - } else if (name.find("lu_factor") != name.npos) { - TORCH_CHECK(false, name, batch_string, - ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. " - "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or " - "linalg.lu_factor_ex(A, pivot)"); - } else { - TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, "."); - } - } -} - -/* - * Given a vector of int64_t infos, obtained after a batch operations, - * this function checks if the computation over all these batches has been - * successful (info = 0) or not, and report in case of the latter. - */ -static inline void batchCheckErrors(const std::vector& infos, const c10::string_view name) { - for (const auto i : c10::irange(infos.size())) { - auto info = infos[i]; - singleCheckErrors(info, name, i); +static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { + auto dtype = t.scalar_type(); + TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)), + f_name, ": Expected a floating point or complex tensor as input. Got ", dtype); + if (!allow_low_precision_dtypes) { + TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble, + f_name, ": Low precision dtypes not supported. Got ", dtype); } } -/* - * This is an overloaded case of the previous function for a tensor of infos. - */ -static inline void batchCheckErrors(const Tensor& infos, const c10::string_view name) { - auto infos_cpu = infos.to(at::kCPU); - auto infos_data = infos_cpu.data_ptr(); - for (const auto i : c10::irange(infos.numel())) { - auto info = infos_data[i]; - singleCheckErrors(info, name, i); - } -} // Checks if all the Tensors in a TensorList are of the same dimensions static inline void checkAllSameDim(TensorList tensors, int64_t dim) { diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 784dd3cc5ed82..b5b7acb8ede27 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -187,46 +187,12 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten return apply_loss_reduction(output, reduction); } -Tensor _kl_div_log_target(const Tensor& input, const Tensor& target, int64_t reduction) { - auto output = at::exp(target) * (target - input); - return apply_loss_reduction(output, reduction); -} - -Tensor _kl_div_non_log_target(const Tensor& input, const Tensor& target, int64_t reduction) { - auto output_pos = target * (at::log(target) - input); - auto zeros = at::zeros_like(output_pos, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto output = at::where(target > 0, output_pos, zeros); - return apply_loss_reduction(output, reduction); -} - Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) { - return log_target ? _kl_div_log_target(input, target, reduction) - : _kl_div_non_log_target(input, target, reduction); -} - -Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) { - auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto grad_expand = grad.expand_as(input); - if (!log_target) { - auto iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(target) - .add_input(grad_expand) - .build(); - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() { - cpu_serial_kernel(iter, [](scalar_t target_val, scalar_t grad_val) -> scalar_t{ - return target_val > 0 ? -target_val * grad_val : 0; - }); - }); - } - else { - grad_input = -at::exp(target) * grad_expand; - } - - if (reduction == at::Reduction::Mean) { - return grad_input / input.numel(); - } - return grad_input; + TORCH_CHECK(!input.is_complex() && !target.is_complex(), + "kl_div: Complex inputs not supported.") + auto output = log_target ? at::exp(target) * (target - input) + : target * (at::log(target) - input); + return apply_loss_reduction(output, reduction); } Tensor binary_cross_entropy_cpu(const Tensor& input, const Tensor& target, const c10::optional& weight_opt, int64_t reduction) { @@ -351,49 +317,6 @@ Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& targe return apply_loss_reduction(loss, reduction); } -Tensor binary_cross_entropy_with_logits_backward( - const Tensor& grad, - const Tensor& input, - const Tensor& target, - const c10::optional& weight_opt, - const c10::optional& pos_weight_opt, - int64_t reduction) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - const Tensor& pos_weight = - c10::value_or_else(pos_weight_opt, [] { return Tensor(); }); - - Tensor grad_input; - auto hasSubclassTensors = at::areAnyTensorSubclassLike({grad, input, target}); - - // If there are subclassed tensors use the out of place version - if (pos_weight.defined()) { - // pos_weight might need to be broadcasted, thus mul(target) is not inplace. - auto t = pos_weight.mul(target); - grad_input = hasSubclassTensors - ? t.add(1).sub(target).mul(input.sigmoid()).sub(t).mul(grad) - : t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t).mul_(grad); - } else { - grad_input = hasSubclassTensors ? (input.sigmoid() - target).mul(grad) - : (input.sigmoid() - target).mul_(grad); - } - if (weight.defined()) { - if (at::areAnyTensorSubclassLike({grad_input, weight})) { - grad_input = grad_input.mul(weight); - } else { - grad_input.mul_(weight); - } - } - - if (reduction == at::Reduction::Mean) { - return grad_input / input.numel(); - } - - return grad_input; -} - Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction) { Tensor loss; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 9d6686413bc51..1fba036f35472 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -2177,6 +2177,185 @@ static inline C10_HOST_DEVICE T calc_log_ndtr(T x) { } } +template +static inline C10_HOST_DEVICE T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * M_PI; + + return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * std::sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > std::numeric_limits::epsilon()) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = std::abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; +} // T airy_ai(T x) + template static inline C10_HOST_DEVICE T bessel_j0_forward(T x) { static const T PP[] = { @@ -3314,6 +3493,161 @@ static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); } // modified_bessel_k1_forward(T x) +template +static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / std::sqrt(x); +} // T scaled_modified_bessel_k0_forward(T x) + +template +static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / std::sqrt(x)); +} // T scaled_modified_bessel_k1_forward(T x) + template static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { @@ -3526,4 +3860,17 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_w_forward(T x, T n) +template +static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { + if (std::isinf(x)) { + return T(0.0); + } + + if (std::abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return std::sin(x) / x; +} // T spherical_bessel_j0_forward(T x) + C10_CLANG_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/native/MathBitFallThroughLists.h b/aten/src/ATen/native/MathBitFallThroughLists.h index 97b0854d82d0a..025c25bcbe7b9 100644 --- a/aten/src/ATen/native/MathBitFallThroughLists.h +++ b/aten/src/ATen/native/MathBitFallThroughLists.h @@ -54,6 +54,7 @@ namespace at { #define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.SymInt", torch::CppFunction::makeFallthrough()); \ m.impl("empty.out", torch::CppFunction::makeFallthrough()); \ m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \ m.impl("full_like", torch::CppFunction::makeFallthrough()); \ diff --git a/aten/src/ATen/native/MathBitsFallback.h b/aten/src/ATen/native/MathBitsFallback.h index f16ad6b605f6d..4e9c2d9e98b18 100644 --- a/aten/src/ATen/native/MathBitsFallback.h +++ b/aten/src/ATen/native/MathBitsFallback.h @@ -98,7 +98,6 @@ struct MathOpFallback { continue; } auto tensor = std::move(ivalue).toTensor(); - TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), op_name, " fallback does not support meta tensors."); auto resolved_tensor = at::clone(tensor); if (mut_arg) { TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ", diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index db432171d1dcc..0b3bb3e04c7b9 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -24,6 +24,23 @@ Tensor empty_meta( size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } +Tensor empty_symint_meta( + SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt +) { + + auto opt_size = asIntArrayRefSlowOpt(size); + if (opt_size.has_value()) { + return at::detail::empty_meta(*opt_size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); + } + return at::detail::empty_symint_meta( + size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); +} + Tensor empty_strided_meta( IntArrayRef size, IntArrayRef stride, diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 11adc156f3655..3df0a0623e437 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -1,6 +1,8 @@ #include #include +#include + #include #if !AT_NNPACK_ENABLED() @@ -37,10 +39,10 @@ namespace at { namespace native { static bool init_nnpack() { - static std::once_flag once_; + static c10::once_flag once_; static bool nnpack_successfully_initialized_ = false; - std::call_once(once_, []() { + c10::call_once(once_, []() { const nnp_status nnpack_status = nnp_initialize(); nnpack_successfully_initialized_ = (nnp_status_success == nnpack_status); diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index feb8b89f1aabe..e5373cac4ad2f 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -489,7 +489,9 @@ std::tuple _batch_norm_impl_index( ||(input.size(0) <= 65535 && !training)) //spatial, eval && detail::getCUDAHooks().compiledWithCuDNN() && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN() - && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L); + && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L + && input.numel() < std::numeric_limits::max() // some cuDNN kernels have 32-bit indexing limitations + ); if (use_cudnn) { auto input_c = input.contiguous(input.suggest_memory_format()); diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 856493d6d2c2f..043e93e332a69 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -49,7 +49,7 @@ signature. `Tensor` or `Tensor?` must sometimes be annotated to indicate aliasing and mutability. In general annotations can be defined via the following four situations: - `Tensor(a)` - `a` is a set of Tensors that may alias to the same data. - - `Tensor(a!)` - `a` members of a may be written to thus mutating the underlying data. + - `Tensor(a!)` - members of `a` may be written to thus mutating the underlying data. - `Tensor!` - shorthand for Tensor(fresh\_identifier!) - `Tensor(a! -> a|b)` - Tensor is in set `a`, written to, and after the write is in set `a` AND `b`. For more details on when and why this needs to happen, please see the section on annotations. @@ -187,6 +187,18 @@ overload names, at most one overload is allowed to have an empty overload name. The declarations also support the following attributes. +**Namespaces.** User can register operators in different namespaces than `aten`, by simply putting custom namespaces before the function name. Currently nested namespace is not supported for function name. If not specified, all the functions will be registered in `aten` namespace. + +For example, suppose we are registering `my_op` into `custom` namespace, we can have: +``` +- func: custom::my_op(Tensor(a) self, ...) -> Tensor(a) + variants: function, method + dispatch: + CPU: my_op_cpu + CUDA: my_op_cuda +``` + +Note that we have a one-off `TORCH_LIBRARY` APIs to achieve the same goal of registering an operator in a custom namespace. Comparing with that API, having custom namespace in `native_functions.yaml` is useful in cases where the function does not really belong to ATen but is also widely used and it is preferred to have a shared place to register it. ### `variants` @@ -270,7 +282,14 @@ dispatch: This specifies the actual name of the function you want to dispatch to, so you can dispatch to different functions depending on which backend the passed tensors -belong to. If the dispatch table is omitted, we assume a default dispatch +belong to. Notice that custom namespaces is supported on these names, it's useful when the native function listed lives in a namespace other than the default `at::native`. Currently we support nested namespace with maximum level of 2. For example: +``` +dispatch: + CPU: custom::ns::func_cpu +``` +The example above hinted the native function can be found under `custom::ns::native` namespace (the trailing `::native` is added automatically). + +If the dispatch table is omitted, we assume a default dispatch table: ``` @@ -364,19 +383,22 @@ added if applicable), so that it's still available for other backends to use. If you implemented a native function in C++ and want to find out which dispatch keyword should be used in native_functions.yaml, please [follow steps in dispatch keywords](#choosing-the-right-dispatch-keyword) -### CompositeImplicitAutograd Compliance +### Composite Compliance + +Definition: a "composite function" is an Operator registered as +CompositeImplicitAutograd or a (Python or C++) function that consists of PyTorch +operations. Examples of the latter include backward formulas and forward-mode AD formulas. -Functions registered as CompositeImplicitAutograd MUST work for most, if not -all, backends. This means that we impose a set of constraints that make it more -difficult to write a CompositeImplicitAutograd function than writing regular -PyTorch code. +Composite functions defined in the PyTorch library MUST work for most, if not +all, backends/subclasses. This means that we impose a set of constraints that make it more +difficult to write composite functions inside PyTorch library code than users +writing PyTorch code. If you wish to do something that is banned (you may wish to do this for perf -reasons), please write a backwards formula for your operator so it is no longer -CompositeImplicitAutograd or hide parts of the operator in a new operator -that is not CompositeImplicitAutograd. +reasons), please write a backwards formula for your function so it is no longer +hide parts of the function in a new aten operator that is not CompositeImplicitAutograd. -CompositeImplicitAutograd operators must not: +Composite functions may not: - call `resize_` or moral equivalents. These are tricky to handle for many backends, like vmap and meta. - call `out=` operations. These are impossible to handle for vmap and can cause diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 49dd2c5de1ebd..b4eff5ed9e21f 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -115,7 +115,7 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te std::isfinite(static_cast(xend)), "unsupported range: ", xstart, " -> ", xend); TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + "upper bound and lower bound inconsistent with step sign"); int64_t size = static_cast(((xend - xstart) / xstep) + 1); if (result.numel() != size) { result.resize_({size}); @@ -143,12 +143,19 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te } Tensor& arange_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) { - AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "arange_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, result.scalar_type(), "arange_cpu", [&]() { using accscalar_t = at::acc_type; auto xstart = start.to(); auto xend = end.to(); auto xstep = step.to(); + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && + std::isfinite(static_cast(xend)), + "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), + "upper bound and larger bound inconsistent with step sign"); + // we use double precision for (start - end) / step // to compute size_d for consistency across devices. // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, @@ -159,20 +166,13 @@ Tensor& arange_out(const Scalar& start, const Scalar& end, const Scalar& step, T double size_d; // NOLINTNEXTLINE(bugprone-branch-clone) if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); + int64_t sgn = (xstep > 0) - (xstep < 0); + size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); } else { size_d = std::ceil(static_cast(end.to() - start.to()) / step.to()); } - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); - TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), "invalid size, possible overflow?"); diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 24a079d23c2f0..52ddcd83774ff 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -52,8 +52,6 @@ namespace meta { static ScalarType infer_dtype_from_optional( const Tensor& self, - IntArrayRef dim, - bool keepdim, const optional& opt_dtype, const Tensor& result) { // 'opt_dtype' has the priority for both cases. @@ -187,9 +185,9 @@ TORCH_META_FUNC(cumprod) } TORCH_META_FUNC2(sum, dim_IntList) -(const Tensor& self, IntArrayRef dim, bool keepdim, optional opt_dtype) { - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output()); - resize_reduction(*this, self, dim, keepdim, out_dtype); +(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { + auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output()); + resize_reduction(*this, self, opt_dim, keepdim, out_dtype); } TORCH_META_FUNC2(prod, dim_int) @@ -197,12 +195,12 @@ TORCH_META_FUNC2(prod, dim_int) int64_t dim, bool keepdim, c10::optional dtype) { - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output()); + auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output()); resize_reduction(*this, self, dim, keepdim, out_dtype); } TORCH_META_FUNC2(mean, dim) -(const Tensor& self, IntArrayRef dim, bool keepdim, optional opt_dtype) { +(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { auto in_dtype = at::native::get_dtype_from_self(self, opt_dtype, true); if (!at::isFloatingType(in_dtype) && !at::isComplexType(in_dtype)) { @@ -221,8 +219,8 @@ TORCH_META_FUNC2(mean, dim) "Got: ", dtype); } - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output()); - resize_reduction(*this, self, dim, keepdim, out_dtype); + auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output()); + resize_reduction(*this, self, opt_dim, keepdim, out_dtype); } ScalarType get_result_or_self_value_dtype( @@ -670,12 +668,12 @@ template void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data, int self_dim_size, int self_stride, int values_stride, int indices_stride) { Operation op; - T1 out = self_data[0]; + T1 out = c10::load(self_data); int idx = 0; for (const auto i : c10::irange(self_dim_size)) { - T1 curr_elem = self_data[i*self_stride]; + T1 curr_elem = c10::load(&self_data[i*self_stride]); if(isnan_(curr_elem) || (!isnan_(out) && op(curr_elem, out))) { - out = self_data[i*self_stride]; + out = curr_elem; idx = i; } values_data[i*values_stride] = out; @@ -1061,11 +1059,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, optional dty TORCH_IMPL_FUNC(sum_out) (const Tensor& self, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype, const Tensor& result) { - auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type()); + auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type()); if (iter.numel() == 0) { result.zero_(); } else { @@ -1081,6 +1079,10 @@ Tensor sum(const Tensor& self, DimnameList dim, bool keepdim, c10::optional opt_dtype) { + return at::sum(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype); +} + Tensor& sum_out(const Tensor& self, DimnameList dim, bool keepdim, optional opt_dtype, Tensor& result) { return at::sum_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype); @@ -1189,7 +1191,7 @@ Tensor& prod_out(const Tensor& self, Dimname dim, TORCH_IMPL_FUNC(mean_out) (const Tensor& self, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional opt_dtype, const Tensor& result) { @@ -1200,19 +1202,19 @@ TORCH_IMPL_FUNC(mean_out) // in lieu of the sum + divide implementation below. if (self.device().is_cpu()) { int64_t dim_prod = 1; - if (dim.size() == 0 || self.ndimension() == 0) { + if (!opt_dim.has_value() || opt_dim.value().size() == 0 || self.ndimension() == 0) { dim_prod = self.numel(); } else { + auto dim = opt_dim.value(); for (auto d : dim) { dim_prod *= self.size(d); } } auto& result_mut = const_cast(result); - at::sum_out(result_mut, self, dim, keepdim, dtype).div_(dim_prod); + at::sum_out(result_mut, self, opt_dim, keepdim, dtype).div_(dim_prod); } else { - DimVector dims(dim); auto iter = at::meta::make_reduction_from_out_ty( - self, result, dims, keepdim, dtype); + self, result, opt_dim, keepdim, dtype); if (iter.numel() == 0) { result.fill_(std::numeric_limits::quiet_NaN()); } else { @@ -1261,7 +1263,7 @@ Tensor nanmean( self.scalar_type()); const auto factor = at::native::isnan(self.detach()).logical_not_().sum(dim, keepdim); - return at::nansum(self, dim, keepdim, opt_dtype).div_(factor); + return at::nansum(self, dim, keepdim, opt_dtype).div(factor); } static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) { diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 8f63b9bd0b67a..7c73c85d4c2ff 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -110,12 +110,27 @@ static inline Tensor integer_upcast(const Tensor& self, optional dty using DimMask = TensorIterator::DimMask; -static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) { +static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) { + if (opt_dims.has_value()) { + return DimVector(opt_dims.value()); + } else { + std::vector all_dims(ndim); + std::iota(all_dims.begin(), all_dims.end(), 0); + return DimVector(all_dims); + } +} + +static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) { DimMask mask; - if (dims.empty()) { - mask = DimMask().flip(); + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + if (dims.empty()) { + mask = DimMask().flip(); + } else { + mask = at::dim_list_to_bitset(dims, ndim); + } } else { - mask = at::dim_list_to_bitset(dims, ndim); + mask = DimMask().flip(); } return mask; } @@ -320,10 +335,10 @@ static C10_UNUSED DimVector get_reduction_shape( static void resize_reduction( impl::MetaBase& meta, const Tensor& self, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType out_dtype) { - DimVector dims_(dims); + DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim()); maybe_wrap_dims(dims_, self.dim()); auto shape = get_reduction_shape(self, dims_, keepdim); meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); @@ -351,11 +366,11 @@ static void resize_reduction_with_indices( static TensorIterator make_reduction( const Tensor& self, const Tensor& result, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType in_dtype) { int64_t ndim = self.dim(); - auto mask = at::native::make_dim_mask(dims, ndim); + auto mask = at::native::make_dim_mask(opt_dims, ndim); auto viewed_result = at::native::review_reduce_result(result, ndim, mask, keepdim); if (self.scalar_type() == in_dtype) { @@ -389,7 +404,7 @@ static TensorIterator make_reduction( static C10_UNUSED TensorIterator make_reduction_from_out_ty( const Tensor& self, const Tensor& result, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType out_dtype) { // special case for type promotion in mixed precision, improves computational @@ -401,7 +416,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty( (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat); auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype; - return make_reduction(self, result, dims, keepdim, in_dtype); + return make_reduction(self, result, opt_dims, keepdim, in_dtype); } } // namespace meta diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index 5edde5e182134..43c7874e43722 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -337,13 +337,7 @@ TORCH_IMPL_FUNC(log_softmax_cpu_out) if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) { log_softmax_lastdim_kernel(kCPU, output, input_); } else { - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::BFloat16, input_.scalar_type(), "log_softmax", [&] { - host_softmax< - scalar_t, - true /* LogSoftMax */, - false /* MaskedSoftMax */>(output, input_, dim_); - }); + log_softmax_kernel(kCPU, output, input_, dim_); } } @@ -372,13 +366,7 @@ TORCH_IMPL_FUNC(softmax_backward_cpu_out) if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) { softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output); } else { - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::BFloat16, grad.scalar_type(), "softmax_backward", [&] { - host_softmax_backward< - scalar_t, - false /* LogSoftMax */, - false /* MaskedSoftmax */>(grad_input, grad_, output, dim_); - }); + softmax_backward_kernel(kCPU, grad_input, grad_, output, dim_); } } @@ -401,16 +389,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) ( if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) { log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_); } else { - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::BFloat16, - grad.scalar_type(), - "log_softmax_backward", - [&] { - host_softmax_backward< - scalar_t, - true /* LogSoftMax */, - false /* MaskedSoftMax */>(grad_input, grad_, output_, dim_); - }); + log_softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_); } } } @@ -551,6 +530,8 @@ DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel); DEFINE_DISPATCH(softmax_kernel); DEFINE_DISPATCH(log_softmax_kernel); +DEFINE_DISPATCH(softmax_backward_kernel); +DEFINE_DISPATCH(log_softmax_backward_kernel); Tensor softmax(const Tensor& self, Dimname dim, optional dtype) { return at::softmax(self, dimname_to_position(self, dim), dtype); diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 7396e9b3b91c8..18820973fd847 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -263,7 +264,16 @@ Tensor quantile_compute( // For nanquantile, compute ranks based on number of non-nan values. // If all values are nan, set rank to 0 so the quantile computed is nan. ranks = q * (sorted.isnan().logical_not_().sum(-1, true) - 1); - ranks.masked_fill_(ranks < 0, 0); + // For Composite Compliance, + // if `ranks` is `CCT` but it's tangent is a regular Tensor, + // then while computing jvp, we end calling `masked_fill_` + // on a regular Tensor with CCT args, so we call + // `masked_fill` instead. + if (isTensorSubclassLike(ranks) && ranks._fw_grad(/*level=*/0).defined()) { + ranks = ranks.masked_fill(ranks < 0, 0); + } else { + ranks.masked_fill_(ranks < 0, 0); + } } else { // For quantile, compute ranks based on reduction size. If there is nan // set rank to last index so the quantile computed will be nan. @@ -296,7 +306,23 @@ Tensor quantile_compute( // Interpolate to compute quantiles and store in values_below Tensor ranks_above = ranks.ceil_().toType(kLong); Tensor values_above = sorted.gather(-1, ranks_above); - values_below.lerp_(values_above, weights); + // For Composite Compliance, + // if either `values_below`, `values_above` or `weights` are a CCT + // or tangents of `value_above` and `weights` are a CCT, + // but if the tangent of `value_below` is a regular Tensor, + // then while computing jvp, we will end-up copying a `CCT`, + // into regular Tensor. So we use out-of-place variant of `lerp` + auto is_primal_cct = + areAnyTensorSubclassLike({values_below, values_above, weights}); + auto is_tangent_cct = areAnyTensorSubclassLike( + {values_above._fw_grad(/*level=*/0), weights._fw_grad(/*level=*/0)}); + if ((is_primal_cct || is_tangent_cct) && + values_below._fw_grad(/*level=*/0).defined() && + !isTensorSubclassLike(values_below._fw_grad(/*level=*/0))) { + values_below = values_below.lerp(values_above, weights); + } else { + values_below.lerp_(values_above, weights); + } } if (q.dim() == 0) { diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 9c0ebed7551aa..d6389608a9e36 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1088,7 +1088,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho // We need to trim the front padding away if centered const auto start = center ? n_fft / 2 : 0; - const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : (center ? - n_fft / 2 : -1); + const auto end = [&] () -> int64_t { + if (lengthOpt.has_value()) { + return start + *lengthOpt; + } + if (center) { + return -(n_fft / 2); + } + return expected_output_signal_len; + }(); y = y.slice(2, start, end, 1); window_envelop = window_envelop.slice(2, start, end, 1); diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 61a0ffdfe3036..951d9eeb18fa3 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -853,7 +853,7 @@ void index_reduce_func_impl( const SCATTER_GATHER_OP& op) { if (!result.is_same(self)) result.copy_(self); if (!include_self) { - AT_DISPATCH_FLOATING_TYPES_AND2( + AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce_func_exclude_input_init", [&] { scalar_t init_val; @@ -932,7 +932,11 @@ void index_reduce_func_impl( auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); counts.index_add_(dim, index, at::ones_like(source)); counts.masked_fill_(counts == 0, 1); - result.div_(counts); + if (result.is_floating_point() || result.is_complex()) { + result.div_(counts); + } else { + result.div_(counts, "floor"); + } } } else { @@ -940,7 +944,7 @@ void index_reduce_func_impl( auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, result.scalar_type(), "index_func_", [&result, &source, &dim, &index_contig, &numel, &op, &counts] { auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); @@ -983,7 +987,11 @@ void index_reduce_func_impl( }); if (op == SCATTER_GATHER_OP::REDUCE_MEAN) { counts.masked_fill_(counts == 0, 1); - result.div_(counts); + if (result.is_floating_point() || result.is_complex()) { + result.div_(counts); + } else { + result.div_(counts, "floor"); + } } } } diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index b3546f3a54134..02ccd133c7ee0 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -4,9 +4,9 @@ #include #include -#include #include #include +#include namespace at { namespace native { @@ -385,26 +385,55 @@ Tensor sparse_compressed_to_dense( return dst_transposed.transpose(batch_ndim, batch_ndim + 1); } if (self.layout() == kSparseBsr) { - TORCH_CHECK(self.dim() == 2, "Can only convert 2D SparseBsr to Strided."); - Tensor indices = at::_convert_indices_from_csr_to_coo( - self.crow_indices(), self.col_indices(), false, false); + auto crow_indices = self.crow_indices(); + auto col_indices = self.col_indices(); auto values = self.values(); - int64_t blocksize[2] = {values.size(-2), values.size(-1)}; - DimVector expanded_size( - {self.size(0) / blocksize[0], - self.size(1) / blocksize[1], - blocksize[0], - blocksize[1]}); - // We make use of COO dense dimensions here to use the COO to dense format - // conversion. - auto self_coo = - at::native::_sparse_coo_tensor_unsafe(indices, values, expanded_size) - .coalesce(); - auto dense = self_coo.to_dense(); - // Here we are untiling the result. - dense = dense.transpose(1, 2); - dense = dense.reshape({self.size(0), self.size(1)}); - return dense; + Tensor dense = at::zeros(self.sizes(), self.options().layout(kStrided)); + if (self.dim() == 2) { + // Pad shape so we can treat 2-d like batched, we will squeeze out the + // phantom batch dim at the end + crow_indices = crow_indices.unsqueeze(0); + col_indices = col_indices.unsqueeze(0); + values = values.unsqueeze(0); + dense = dense.unsqueeze(0); + } + if (self.dim() > 3) { + // Flatten batch dims + auto n_batch_dim = self.dim() - 2; + crow_indices = crow_indices.flatten(0, n_batch_dim); + col_indices = col_indices.flatten(0, n_batch_dim); + values = values.flatten(0, n_batch_dim); + dense = dense.flatten(0, n_batch_dim); + } + + // At this point everything has 3d shape either the batch dim was inserted, + // existed already or was flattened from multiple batch dims + std::array blocksize = {values.size(-2), values.size(-1)}; + auto n_batch = values.size(0); + // If we already had batch dim(s) and any of them were zero we can take the + // early exit. + if (n_batch == 0) { + return dense.reshape(self.sizes()); + } + // Due to early exit above this reshape should always be valid + dense = dense.reshape({n_batch, -1, values.size(-2), values.size(-1)}); + for (auto batch : c10::irange(n_batch)) { + Tensor batch_indices = at::_convert_indices_from_csr_to_coo( + crow_indices[batch], col_indices[batch], false, false); + auto batch_row_indices = batch_indices.select(0, 0); + auto batch_col_indices = batch_indices.select(0, 1); + auto offsets = batch_col_indices + + batch_row_indices * (self.size(-1) / blocksize[1]); + dense[batch].index_add_(0, offsets, values[batch]); + } + + // untile the result, NOTE: The final reshape uses the original self.sizes() + // which will squeeze out the extra batch dim if we put one in + return dense + .unflatten( + 1, {self.size(-2) / blocksize[0], self.size(-1) / blocksize[1]}) + .transpose(2, 3) + .reshape(self.sizes()); } return self.to_sparse().to_dense(); } @@ -564,6 +593,36 @@ Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) { .contiguous(); } +Tensor _batch_tile_tensor(const Tensor& self, IntArrayRef blocksize) { + if (self.dim() == 2) { + return _tile_tensor(self, blocksize); + } + auto n_batch_dim = self.dim() - 2; + // Same as _tile_tensor, just per matrix entry of self, if self is 3D. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0); + auto block_size_0 = self.size(-2) / blocksize[0]; + auto block_size_1 = self.size(-1) / blocksize[1]; + auto tiled_sizes = DimVector(self.sizes().slice(0, n_batch_dim)); + tiled_sizes.push_back(block_size_0); + tiled_sizes.push_back(blocksize[0]); + tiled_sizes.push_back(block_size_1); + tiled_sizes.push_back(blocksize[1]); + + return self.reshape(tiled_sizes).transpose(-3, -2).contiguous(); +} + +Tensor _mask_to_indices(const Tensor& mask) { + // This function returns a vector of the indices at which given + // boolean mask is True. at::nonzero can achieve the same, but + // we yet have to compare the performance difference. + TORCH_CHECK(mask.dim() == 1, "Currently _mask_to_indices only supports 1-d masks."); + TORCH_CHECK(mask.dtype() == at::kBool, "Expected mask to be of dtype bool."); + return at::native::arange( + mask.numel(), at::kLong, kStrided, mask.device()) + .masked_select(mask); +} + std::pair _not_zero_mask_to_col_row_indices( Tensor not_zero_mask, ScalarType index_dtype, @@ -574,7 +633,8 @@ std::pair _not_zero_mask_to_col_row_indices( .expand_as(not_zero_mask) .masked_select(not_zero_mask); auto row_indices = - at::native::arange(not_zero_mask.size(-2), index_dtype, kStrided, index_device) + at::native::arange( + not_zero_mask.size(-2), index_dtype, kStrided, index_device) .view({not_zero_mask.size(-2), 1}) .expand_as(not_zero_mask) .masked_select(not_zero_mask); @@ -582,44 +642,108 @@ std::pair _not_zero_mask_to_col_row_indices( } Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) { - TORCH_CHECK(self.dim() == 2, "Can only covert 2D Tensor to BSR."); TORCH_CHECK( blocksize[0] > 0 && blocksize[1] > 0, "blocksize needs to be non zero, but got ", blocksize); TORCH_CHECK( - self.size(0) % blocksize[0] == 0, - "Tensor size(0) ", - self.size(0), + self.size(-2) % blocksize[0] == 0, + "Tensor size(-2) ", + self.size(-2), " needs to be divisible by blocksize[0] ", blocksize[0]); TORCH_CHECK( - self.size(1) % blocksize[1] == 0, - "Tensor size(1) ", - self.size(1), + self.size(-1) % blocksize[1] == 0, + "Tensor size(-1) ", + self.size(-1), " needs to be divisible by blocksize[1] ", blocksize[1]); - auto block_size_0 = self.size(0) / blocksize[0]; - auto values = _tile_tensor(self, blocksize); - auto not_zero_mask = _tile_tensor((self != 0), blocksize); + auto block_size_0 = self.size(-2) / blocksize[0]; + auto n_batch_dim = self.dim() - 2; + + auto values = _batch_tile_tensor(self, blocksize); + auto not_zero_mask = _batch_tile_tensor((self != 0), blocksize); // Find tiles that have at least 1 non-zero value in them. not_zero_mask = not_zero_mask.any(-1).any(-1); + + if (n_batch_dim > 0) { + // for 3D input the mask is already flat along the batch dims, avoid + // creating unnessesary view + if (n_batch_dim > 1) { + // flatten out the batch dims for N-D input + not_zero_mask = not_zero_mask.flatten(0, n_batch_dim - 1); + } + TORCH_CHECK( + not_zero_mask.size(0) > 0, + "to_sparse_bsr: Expected product of batch dimensions to be non-zero."); + + // If the input is ND we assert that the same sparsity pattern + // is used across matrices. That means the same number of materialized + // values and *at the same location*. + // This requirement is not included in Pearu's blog post on BSR invariants. + // He specifically states that different batches may have different sparsity + // patterns as long as the number of specified elements is the same for all + // batches. + + auto not_zero_mask_0 = not_zero_mask.select(0, 0); + auto nse_per_batch = not_zero_mask_0.sum().repeat(not_zero_mask.size(0)); + TORCH_CHECK( + not_zero_mask.sum({-2, -1}).equal(nse_per_batch), + "Expect the same number of specified elements per batch."); + } + Tensor col_indices; Tensor row_indices; - std::tie(col_indices, row_indices) = - _not_zero_mask_to_col_row_indices(not_zero_mask, at::kLong, not_zero_mask.device()); - Tensor crow_indices = at::_convert_indices_from_coo_to_csr( - row_indices.view({-1}), block_size_0, false /* out_int32 */); - values = values.reshape({-1, values.size(-2), values.size(-1)}); - not_zero_mask = not_zero_mask.reshape({-1}); - // TODO: masked_select does not support some form of broadcasting, so we're - // using the mask to construct indices that are then passed into index_select. - // This isn't ideal. - values = values.index_select( - 0, - at::native::arange(not_zero_mask.numel(), at::kLong, kStrided, not_zero_mask.device()) - .masked_select(not_zero_mask)); + std::tie(col_indices, row_indices) = _not_zero_mask_to_col_row_indices( + not_zero_mask, at::kLong, not_zero_mask.device()); + Tensor crow_indices; + + if (n_batch_dim > 0) { + // reshape to put the (flattened) batch dims back in + col_indices = col_indices.reshape({not_zero_mask.size(0), -1}); + row_indices = row_indices.reshape({not_zero_mask.size(0), -1}); + crow_indices = at::empty( + {not_zero_mask.size(0), block_size_0 + 1}, col_indices.options()); + // For each batch compute crow_indices + for (auto batch : c10::irange(not_zero_mask.size(0))) { + Tensor batch_crow_indices = crow_indices[batch]; + at::_convert_indices_from_coo_to_csr_out( + batch_crow_indices, + row_indices[batch], + block_size_0, + false /* out_int32 */); + } + // At this point, we have constructed col_indices and crow_indices + // such that they are 2d with dim0 of length B = product(batchdims). We can + // now reshape them to the correct shapes. + auto batch_shape = self.sizes().slice(0, n_batch_dim); + crow_indices = crow_indices.unflatten(0, batch_shape); + col_indices = col_indices.unflatten(0, batch_shape); + + // Mask is also leading dim B, but we can't masked select wit it (see below) + // unless it is flat, then we can partially faltten values, index it along + // and unfold the result to batchdims + (nnz(per batch), ) + auto batch_sizes_nnz = DimVector(batch_shape); + batch_sizes_nnz.push_back(-1); // we can infer nnz + not_zero_mask = not_zero_mask.flatten(); + // TODO: masked_select does not support some form of broadcasting, so we're + // using the mask to construct indices that are then passed into + // index_select. This isn't ideal. + values = values.flatten(0, -3) + .index_select(0, _mask_to_indices(not_zero_mask)) + .unflatten(0, batch_sizes_nnz); + + } else { + crow_indices = at::_convert_indices_from_coo_to_csr( + row_indices.view({-1}), block_size_0, false /* out_int32 */); + not_zero_mask = not_zero_mask.reshape({-1}); + // TODO: masked_select does not support some form of broadcasting, so we're + // using the mask to construct indices that are then passed into + // index_select. This isn't ideal. + values = values.reshape({-1, values.size(-2), values.size(-1)}) + .index_select(0, _mask_to_indices(not_zero_mask)); + } return at::native::_sparse_bsr_tensor_unsafe( crow_indices, diff --git a/aten/src/ATen/native/TensorConversions.h b/aten/src/ATen/native/TensorConversions.h index 04f44f6f1980f..75a01ea0e7554 100644 --- a/aten/src/ATen/native/TensorConversions.h +++ b/aten/src/ATen/native/TensorConversions.h @@ -20,5 +20,6 @@ bool to_will_alias( Tensor to_meta(const Tensor& tensor); c10::optional to_meta(const c10::optional& tensor); std::vector to_meta(const at::TensorList& t_list); + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 4494ff16eb6be..7d112b9f415d4 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -17,12 +17,18 @@ #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif #include #include #include #include #include +#include namespace at { namespace native { @@ -176,6 +182,11 @@ Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::opt return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } +Tensor empty_symint_cpu(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + return at::native::empty_cpu(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); +} + Tensor empty( IntArrayRef size, c10::optional names, @@ -381,11 +392,22 @@ Tensor new_empty( c10::optional device_opt, c10::optional pin_memory_opt ) { + return self.new_empty_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +Tensor new_empty_symint( + const Tensor& self, + SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt + ) { auto dtype = dtype_opt.has_value() ? dtype_opt : optTypeMetaToScalarType(self.options().dtype_opt()); auto layout = layout_opt.has_value() ? layout_opt : self.options().layout_opt(); auto device = device_opt.has_value() ? device_opt : self.options().device_opt(); auto pin_memory = pin_memory_opt.has_value() ? pin_memory_opt : self.options().pinned_memory_opt(); - return at::empty(size, dtype, layout, device, pin_memory, c10::nullopt); + return at::empty_symint(size, dtype, layout, device, pin_memory, c10::nullopt); } Tensor new_empty_strided( @@ -411,7 +433,7 @@ Tensor eye(int64_t n, c10::optional device, c10::optional pin_memory) { // the default value of `m` equals to `n` - return native::eye(n, n, dtype, layout, device, pin_memory); + return at::eye(n, n, dtype, layout, device, pin_memory); } Tensor eye(int64_t n, int64_t m, @@ -436,6 +458,9 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) { TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); + + if (result.is_meta()) return result; + result.zero_(); int64_t sz = std::min(n, m); @@ -1073,6 +1098,14 @@ Tensor zeros(IntArrayRef size, return result.zero_(); } +Tensor zeros_symint(c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + return at::zeros(asIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + Tensor _efficientzerotensor(IntArrayRef size, c10::optional dtype, c10::optional layout, @@ -1118,6 +1151,7 @@ Tensor zeros_like( } else { res.sparse_resize_and_clear_(self.sizes(), self.sizes().size(), 0); } + res._coalesced_(true); return res; } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 72758f1e70bc2..bbecb346ce3ef 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -27,6 +27,7 @@ #include #include #include +#include namespace at { namespace meta { @@ -844,7 +845,7 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim Tensor expand_symint(const Tensor& self, c10::SymIntArrayRef packed_size, bool implicit) { auto size = asIntArrayRefSlow(packed_size); - return expand(self, size, implicit); + return self.expand(size, implicit); } Tensor expand(const Tensor& self, IntArrayRef size, bool /*unused*/) { @@ -927,7 +928,7 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri } Tensor narrow_copy_symint(const Tensor& self, int64_t dim, int64_t start, SymInt sym_length) { - return narrow_copy(self, dim, start, sym_length.expect_int()); + return self.narrow_copy(dim, start, sym_length.expect_int()); } Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) { @@ -1072,24 +1073,101 @@ Tensor narrow(const Tensor& self, int64_t dim, const Tensor& start, int64_t leng return at::narrow(self, dim, st, length); } +std::tuple> +_permute_size_stride_estimation(const Tensor& self, IntArrayRef dims) { + const auto ndim = self.dim(); + TORCH_CHECK(ndim == static_cast(dims.size()), + "permute(sparse_coo): number of dimensions in the tensor input ", + "does not match the length of the desired ordering of dimensions ", + "i.e. input.dim() = ", ndim, " is not equal to len(dims) = ", dims.size()); + + const auto is_strided_layout = self.options().layout() == at::kStrided; + const auto old_sizes = self.sizes(); + const auto old_strides = is_strided_layout ? self.strides() : IntArrayRef{}; + + auto new_sizes = DimVector(ndim); + auto new_strides = DimVector(is_strided_layout ? ndim : 0); + auto wrapped_dims = std::vector(ndim); + std::vector seen_dims(ndim); + + for (const auto i : c10::irange(ndim)) { + const auto d = maybe_wrap_dim(dims[i], ndim); + TORCH_CHECK(!seen_dims[d], + "permute(): duplicate dims are not allowed."); + seen_dims[d] = true; + wrapped_dims[i] = d; + new_sizes[i] = old_sizes[d]; + if (is_strided_layout) { + new_strides[i] = old_strides[d]; + } + } + + return std::make_tuple(new_sizes, new_strides, wrapped_dims); +} + Tensor permute(const Tensor& self, IntArrayRef dims) { - auto nDims = self.dim(); - TORCH_CHECK(dims.size() == (size_t)nDims, - "number of dims don't match in permute"); - auto oldSizes = self.sizes(); - auto oldStrides = self.strides(); - DimVector newSizes(nDims); - DimVector newStrides(nDims); - std::vector seen(nDims); - for (const auto i : c10::irange(nDims)) { - auto dim = maybe_wrap_dim(dims[i], nDims); - TORCH_CHECK(!seen[dim], - "repeated dim in permute"); - seen[dim] = true; - newSizes[i] = oldSizes[dim]; - newStrides[i] = oldStrides[dim]; - } - return self.as_strided(newSizes, newStrides); + DimVector new_sizes, new_strides; + std::vector _; + std::tie(new_sizes, new_strides, _) = _permute_size_stride_estimation(self, dims); + return self.as_strided(new_sizes, new_strides); +} + +Tensor permute_sparse_coo(const Tensor& self, IntArrayRef dims) { + DimVector new_sizes, _; + std::vector wrapped_dims; + std::tie(new_sizes, _, wrapped_dims) = _permute_size_stride_estimation(self, dims); + + const auto ndim = self.dim(); + const auto sparse_ndim = self.sparse_dim(); + const auto dense_ndim = self.dense_dim(); + + auto dims_id_perm = std::vector(ndim); + auto dims_sparse_dense_id_perm = std::vector(ndim); + for (const auto i : c10::irange(ndim)) { + dims_id_perm[i] = i; + dims_sparse_dense_id_perm[i] = wrapped_dims[i]; + } + std::sort(dims_sparse_dense_id_perm.begin(), dims_sparse_dense_id_perm.begin() + sparse_ndim); + std::sort(dims_sparse_dense_id_perm.begin() + sparse_ndim, dims_sparse_dense_id_perm.end()); + TORCH_CHECK(dims_sparse_dense_id_perm == dims_id_perm, + "permute(sparse_coo): transpositions between sparse and dense dimensions are not allowed.", + "Only transpositions within sparse and dense dimensions are supported."); + + const auto slice = [](std::vector v, size_t begin, size_t len) -> decltype(v) { + return std::vector{v.begin() + begin, v.begin() + begin + len}; + }; + + auto old_sparse_dims = slice(dims_id_perm, 0, sparse_ndim); + auto old_dense_dims = slice(dims_id_perm, sparse_ndim, ndim - sparse_ndim); + auto new_sparse_dims = slice(wrapped_dims, 0, sparse_ndim); + auto new_dense_dims = slice(wrapped_dims, sparse_ndim, ndim - sparse_ndim); + + auto old_indices = self._indices(); + auto old_values = self._values(); + + const auto new_indices = (new_sparse_dims == old_sparse_dims) + ? old_indices + : [&]() -> Tensor { + auto sparse_perm_tensor = at::from_blob(reinterpret_cast(new_sparse_dims.data()), + {sparse_ndim}, old_indices.options().device(at::kCPU)); + // creates new indices. It is possible to avoid that if COO + // is allowed to store a permutation vector. + return old_indices.index_select(0, sparse_perm_tensor.to(self.device().type())); + }(); + const auto new_values = (new_dense_dims == old_dense_dims) + ? old_values + : [&]() -> Tensor { + auto values_perm = std::vector(dense_ndim + 1); + for (const auto i : c10::irange(dense_ndim)) { + values_perm[i + 1] = new_dense_dims[i] - sparse_ndim + 1; + } + return old_values.permute(values_perm); + }(); + + const auto is_coalesced = self.is_coalesced() && (dims[0] == 0); + return _sparse_coo_tensor_with_dims_and_tensors( + sparse_ndim, dense_ndim, new_sizes, new_indices, new_values, self.options()) + ._coalesced_(is_coalesced); } Tensor repeat(const Tensor& self, IntArrayRef repeats) { @@ -1184,6 +1262,17 @@ Tensor alias_with_sizes_and_strides( } Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) { + // reshape has special autograd logic since it sometimes returns a view but sometimes does not + // we have to intercept here instead of using dispatcher + // otherwise we will see "autograd still running" kind of error in inference mode: + // * if we create a tensor in inference mode scope, + // then pass it to a inference mode decorated function, + // everything is fine + // * but if we create the input tensor not with inference mode, + // then errors like "Cannot set version_counter for inference tensor" arise + if (self.is_nested()) { + return at::_reshape_nested(self, proposed_shape); + } if (self.is_sparse()) { AT_ERROR("reshape is not implemented for sparse tensors"); } @@ -1252,7 +1341,8 @@ static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) { if (new_values.size(0) == 1) { return new_values[0]; } else { - return new_values.sum(0); + // sum promotes integral type to int64 when dtype is not specified. + return at::sum(new_values, 0, false, new_values.scalar_type()); } } else { auto dimIndices = (arange( @@ -2820,7 +2910,7 @@ static inline void handle_unflatten_exception(const std::runtime_error &e, } } -Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional names) { +Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional names) { dim = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty"); @@ -2859,8 +2949,12 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option return result; } +Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes) { + return native::unflatten_impl(self, dim, sizes, c10::nullopt); +} + Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) { - return native::unflatten(self, dimname_to_position(self, dim), sizes, names); + return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names); } Tensor view_as(const Tensor& self, const Tensor& other) { @@ -3028,6 +3122,11 @@ Tensor view(const Tensor& self, return view_impl(self, size); } +Tensor view_symint(const Tensor& self, + c10::SymIntArrayRef size) { + return self.view(c10::asIntArrayRefSlow(size)); +} + Tensor alias(const Tensor& self) { return alias_with_sizes_and_strides(self, self.sizes(), self.strides()); } @@ -3330,6 +3429,11 @@ at::Tensor lift(const at::Tensor& self) { return self; } +// See notes in native_functions.yaml +at::Tensor lift_fresh(const at::Tensor& self) { + return self; +} + at::Tensor& _fw_primal_copy_out(const at::Tensor & self, int64_t level, at::Tensor & out) { auto tmp = self._fw_primal(level); out.copy_(tmp); diff --git a/aten/src/ATen/native/TransposeType.h b/aten/src/ATen/native/TransposeType.h index 5353394a9dde0..8956bbc5bf928 100644 --- a/aten/src/ATen/native/TransposeType.h +++ b/aten/src/ATen/native/TransposeType.h @@ -12,7 +12,7 @@ enum class TransposeType { }; // Transforms TransposeType into the BLAS / LAPACK format -static char to_blas(TransposeType trans) { +static inline char to_blas(TransposeType trans) { switch (trans) { case TransposeType::Transpose: return 'T'; case TransposeType::NoTranspose: return 'N'; diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 4b0f6f73be0d9..160955a013505 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -71,6 +71,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr) CREATE_UNARY_FLOAT_META_FUNC(sqrt) CREATE_UNARY_FLOAT_META_FUNC(tan) CREATE_UNARY_FLOAT_META_FUNC(tanh) +CREATE_UNARY_FLOAT_META_FUNC(special_airy_ai) CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0) CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1) CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0) @@ -79,6 +80,9 @@ CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0) CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1) CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0) CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k1) +CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k0) +CREATE_UNARY_FLOAT_META_FUNC(special_scaled_modified_bessel_k1) +CREATE_UNARY_FLOAT_META_FUNC(special_spherical_bessel_j0) TORCH_META_FUNC(polygamma)(int64_t n, const Tensor& self) { TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n."); @@ -198,6 +202,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub) CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub) CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub) CREATE_UNARY_TORCH_IMPL_FUNC(trunc_out, trunc_stub) +CREATE_UNARY_TORCH_IMPL_FUNC(special_airy_ai_out, special_airy_ai_stub) CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub) CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub) CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub) @@ -206,6 +211,9 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_be CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub) CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub) CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k1_out, special_modified_bessel_k1_stub) +CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k0_out, special_scaled_modified_bessel_k0_stub) +CREATE_UNARY_TORCH_IMPL_FUNC(special_scaled_modified_bessel_k1_out, special_scaled_modified_bessel_k1_stub) +CREATE_UNARY_TORCH_IMPL_FUNC(special_spherical_bessel_j0_out, special_spherical_bessel_j0_stub) TORCH_IMPL_FUNC(round_decimals_out) (const Tensor& self, int64_t decimals, const Tensor& result) { @@ -473,10 +481,6 @@ Tensor& conj_physical_(Tensor& self) { // else returns a new negated tensor with neg bit set to 0 Tensor resolve_neg(const Tensor& self) { if (!self.is_neg()) { return self; } - // currently a tensor should never have both conj and neg bit set - // the only way to get an imag bit is complex_tensor.conj().imag but there's - // no intended designed mechanism to enter the complex world with this imag bit - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!self.is_conj()); // negation is materialized in `copy_()` that clone ultimately calls into return self.clone(); } @@ -485,10 +489,6 @@ Tensor resolve_neg(const Tensor& self) { // else returns a new negated tensor with neg bit set to 0 Tensor resolve_conj(const Tensor& self) { if (!self.is_conj()) { return self; } - // currently a tensor should never have both conj and neg bit set - // the only way to get an imag bit is complex_tensor.conj().imag but there's - // no intended designed mechanism to enter the complex world with this imag bit - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!self.is_neg()); // conjugation is materialized in `copy_()` that clone ultimately calls into return self.clone(); } @@ -879,6 +879,7 @@ DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-v DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(special_airy_ai_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) @@ -887,6 +888,9 @@ DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-av DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_DISPATCH(special_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(special_scaled_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(special_scaled_modified_bessel_k1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(special_spherical_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace native } // namespace at diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index d9786d1ebdb56..103e522fa35db 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -70,6 +70,7 @@ DECLARE_DISPATCH(unary_fn, tanh_stub); DECLARE_DISPATCH(unary_fn, trigamma_stub); DECLARE_DISPATCH(unary_fn, trunc_stub); DECLARE_DISPATCH(unary_fn, lgamma_stub); +DECLARE_DISPATCH(unary_fn, special_airy_ai_stub); DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub); DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub); DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub); @@ -78,6 +79,9 @@ DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub); DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub); DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub); DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub); +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub); +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub); +DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub); // NB: these are actually defined in Distribution DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional), bernoulli_tensor_stub); diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index dc066d99d46d4..f418611e08644 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -24,6 +23,53 @@ namespace native{ namespace { +// Extract the unique elements from [begin, end) into a new Tensor +template +Tensor unique_elements(const scalar_t* begin, const scalar_t* end, + bool sorted, const TensorOptions &options) { + // Create unordered set of elements + auto set = std::unordered_set(begin, end); + + // Write the output tensor + Tensor output = at::empty({static_cast(set.size())}, options); + scalar_t *output_data = output.data_ptr(); + std::copy(set.begin(), set.end(), output_data); + if (sorted) { + std::sort(output_data, output_data + set.size()); + } + return output; +} + +// Specialization for boolean inputs, since we can't construct a set +// directly from an array of bool as it won't handle invalid byte values. +// See NOTE [Loading boolean values] +Tensor unique_elements(const bool* begin, const bool* end, + bool /*sorted*/, const TensorOptions &options) { + // Instead of a set, track whether a value has been seen + std::array seen; + seen.fill(false); + + for (; begin != end; ++begin) { + seen[c10::load(begin)] = true; + if (seen[false] && seen[true]) { + break; + } + } + + // Write the output tensor + int64_t num_elem = seen[false] + seen[true]; + Tensor output = at::empty({num_elem}, options); + bool *output_data = output.data_ptr(); + + if (seen[false]) { + *output_data++ = false; + } + if (seen[true]) { + *output_data++ = true; + } + return output; +} + template std::tuple unique_cpu_template( const Tensor& self, @@ -33,19 +79,11 @@ std::tuple unique_cpu_template( const Tensor& input = self.contiguous(); const scalar_t* input_data = input.data_ptr(); int64_t numel = input.numel(); - Tensor output; Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong)); Tensor counts = at::empty({0}, self.options().dtype(kLong)); - std::unordered_set set(input_data, input_data + numel); - output = at::empty({static_cast(set.size())}, input.options()); - scalar_t *output_data = output.data_ptr(); - if (sorted) { - std::vector vec(set.begin(), set.end()); - std::sort(vec.begin(), vec.end()); - std::copy(vec.begin(), vec.end(), output_data); - } else { - std::copy(set.begin(), set.end(), output_data); - } + Tensor output = unique_elements(input_data, input_data + numel, + sorted, input.options()); + const scalar_t *output_data = output.data_ptr(); if (return_inverse || return_counts) { inverse_indices.resize_(input.sizes()); @@ -56,7 +94,8 @@ std::tuple unique_cpu_template( inverse_map[output_data[i]] = i; } for (const auto i : c10::irange(numel)) { - inverse_indices_data[i] = inverse_map[input_data[i]]; + const auto val = c10::load(&input_data[i]); + inverse_indices_data[i] = inverse_map[val]; } if (return_counts) { std::unordered_map counts_map; @@ -65,7 +104,8 @@ std::tuple unique_cpu_template( counts_map[output_data[i]] = 0; } for (const auto i : c10::irange(numel)) { - counts_map[input_data[i]] += 1; + const auto val = c10::load(&input_data[i]); + counts_map[val] += 1; } counts.resize_(output.sizes()); counts.fill_(0); @@ -98,7 +138,8 @@ std::tuple unique_consecutive_cpu_template( scalar_t *output_data = output.data_ptr(); int64_t *inverse_data = inverse_indices.data_ptr();; int64_t *counts_data = nullptr; - *output_data = *input_data; + scalar_t last_value = c10::load(input_data); + *output_data = last_value; if (return_counts) { counts.resize_({numel}); @@ -111,8 +152,10 @@ std::tuple unique_consecutive_cpu_template( inverse_data[0] = 0; } for (const auto i : c10::irange(1, numel)) { - if (input_data[i] != *p) { - *(++p) = input_data[i]; + const auto value = c10::load(&input_data[i]); + if (value != last_value) { + *(++p) = value; + last_value = value; if (return_counts) { *(q++) = i - last; last = i; @@ -208,8 +251,8 @@ std::tuple _unique_dim_cpu_template( std::sort(indices.begin(), indices.end(), [&](int64_t a, int64_t b) -> bool { for (const auto i : c10::irange(numel)) { - scalar_t lhs = input_flat_ptr[i + a * numel]; - scalar_t rhs = input_flat_ptr[i + b * numel]; + scalar_t lhs = c10::load(&input_flat_ptr[i + a * numel]); + scalar_t rhs = c10::load(&input_flat_ptr[i + b * numel]); if (lhs < rhs) { return true; } else if (lhs > rhs) { diff --git a/aten/src/ATen/native/WeightNorm.cpp b/aten/src/ATen/native/WeightNorm.cpp index b2229bdbf0d2b..bf258d80a0fb3 100644 --- a/aten/src/ATen/native/WeightNorm.cpp +++ b/aten/src/ATen/native/WeightNorm.cpp @@ -82,7 +82,10 @@ Tensor _weight_norm auto v = v_in.contiguous(); auto g = g_in.contiguous(); - bool can_use_fused = (dim == 0) || (dim == v.dim() - 1); + auto has_half_dtype = v.scalar_type() == at::ScalarType::Half + || g.scalar_type() == at::ScalarType::Half; + + bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1)); if (can_use_fused) { // weight_norm does not have a derivative defined for it, so this will route back through diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp index 07fc3d245fe20..2f1d8a3e7be98 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp @@ -14,53 +14,20 @@ int register_linear_params() { "sparse", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase")) .def_pickle( [](const c10::intrusive_ptr& params) - -> LinearPackedSerializationType { // __getstate__ - return params->unpack(); + -> BCSRSerializationType { // __getstate__ + return params->serialize(); }, - [](LinearPackedSerializationType state) + [](BCSRSerializationType state) -> c10::intrusive_ptr< LinearPackedParamsBase> { // __setstate__ - at::Tensor weight; - c10::optional bias; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t out_features_block_size, in_features_block_size; - weight = std::move(std::get<0>(state)); - bias = std::move(std::get<1>(state)); - out_features_block_size = std::get<2>(state)[0]; - in_features_block_size = std::get<2>(state)[1]; - #ifdef USE_FBGEMM if (at::globalContext().qEngine() == at::QEngine::FBGEMM) { - if (weight.scalar_type() == at::kQInt8) { - return PackedLinearWeight::prepack( - weight, - bias, - out_features_block_size, - in_features_block_size); - } else { - TORCH_CHECK( - false, - "Unsupported data type", - c10::toString(weight.scalar_type()), - " in serialized LinearPackedParams object!"); - } + return PackedLinearWeight::deserialize(state); } #endif // USE_FBGEMM #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK) { - if (weight.scalar_type() == at::kQInt8) { - return PackedLinearWeightQnnp::prepack( - weight, - bias, - out_features_block_size, - in_features_block_size); - } else { - TORCH_CHECK( - false, - "Unsupported data type", - c10::toString(weight.scalar_type()), - " in serialized LinearPackedParams object!"); - } + return PackedLinearWeightQnnp::deserialize(state); } #endif // USE_FBGEMM TORCH_CHECK(false, "Unknown qengine"); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h index e87b583074462..6943f9cdd3f5f 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -63,6 +63,11 @@ struct TORCH_API PackedLinearWeight LinearPackedSerializationType unpack() override; + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + c10::optional bias() override { return bias_; } diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h index d2c9cc37e43d2..57ebba85a0632 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h @@ -11,6 +11,31 @@ namespace sparse { using LinearPackedSerializationType = std::tuple, std::vector>; +#define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 1 + +using BCSRSerializationType = + std::tuple< + int64_t, // Serialization Version + c10::optional, // Bias + int64_t, // Out Features (Row) Block Size + int64_t, // In Features (Column) Block Size + at::Tensor, // Weight Scales (single element vector if per-tensor) (float) + at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t) + bool, // Quantization Scheme (true: per tensor, false: per channel) + at::Tensor, // Wrapper for Row Block Indices (int32_t) + at::Tensor, // Wrapper for Column Block Indices (int32_t) + at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t) + int64_t, // Number of Output Channels + int64_t // Number of Input Channels + >; + +using BCSR = + std::tuple< + std::vector, // Non-Zero Weight Values + std::vector, // Compressed Row Block Indices + std::vector // Column Block Indices + >; + struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { public: LinearPackedParamsBase( @@ -33,6 +58,8 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { virtual LinearPackedSerializationType unpack() = 0; + virtual BCSRSerializationType serialize() = 0; + virtual c10::optional bias() = 0; virtual void set_bias(const c10::optional& bias) { diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp new file mode 100644 index 0000000000000..24d24eee66ec9 --- /dev/null +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -0,0 +1,258 @@ +#include + +#ifdef USE_FBGEMM +#include +#endif +#ifdef USE_PYTORCH_QNNPACK +#include +#endif + +namespace ao { +namespace sparse { + +namespace { +const int64_t bias_index = 1; +const int64_t out_features_block_size_index = 2; +const int64_t in_features_block_size_index = 3; +const int64_t weight_scales_index = 4; +const int64_t weight_zero_point_index = 5; +const int64_t quantization_scheme_index = 6; +const int64_t row_block_indices_index = 7; +const int64_t col_block_indices_index = 8; +const int64_t weight_values_index = 9; +const int64_t num_output_channels_index = 10; +const int64_t num_input_channels_index = 11; + +template +std::vector unwrap_vector(at::Tensor tensor) { + std::vector vec(tensor.numel()); + TENSOR_DTYPE* tensor_data_ptr = tensor.data_ptr(); + std::copy(tensor_data_ptr, tensor_data_ptr + tensor.numel(), vec.data()); + return vec; +} + +#ifdef USE_FBGEMM +/** + * Adapted from Fbgemm BCSRMatrix::unpack, but with non-zero zero points and + * without tiling + * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L154 + */ +void unpack_bcsr( + int8_t* dst, + ao::sparse::BCSR bcsr, + const int64_t R, + const int64_t C, + const int64_t RB, + const int64_t CB, + const int8_t* zero_points, + const bool qscheme_per_tensor) { + const size_t ld = C; + // zero out destination + if (qscheme_per_tensor) { + memset(dst, zero_points[0], R * C * sizeof(int8_t)); + } else { + for (int64_t i = 0; i < R; i++) { + memset(dst + i * C, zero_points[i], C * sizeof(int8_t)); + } + } + const std::vector& weight_values = std::get<0>(bcsr); + const std::vector& row_indices = std::get<1>(bcsr); + const std::vector& col_indices = std::get<2>(bcsr); + int64_t rowBlocks = (R + RB - 1) / RB; + for (int64_t i = 0; i < rowBlocks; ++i) { + // For the current tile, rowBPtr starts from currentTileIdx + for (int64_t r = row_indices[i]; r < row_indices[i + 1]; ++r) { + int64_t curColIdx = col_indices[r]; + for (int64_t ib = 0; ib < RB; ++ib) { + for (int64_t jb = 0; jb < CB; ++jb) { + // Are we within bounds of destination matrix? + if ((i * RB + ib) < R && (curColIdx * CB + jb) < C) { + dst[(i * RB + ib) * ld + curColIdx * CB + jb] = + weight_values[r * RB * CB + ib * CB + jb]; + } + } + } + } + } +} +#endif // USE_FBGEMM +} // namespace + +#ifdef USE_FBGEMM + +c10::intrusive_ptr PackedLinearWeight::deserialize( + const BCSRSerializationType& serialized) { + const int64_t out_features_block_size = + std::get(serialized); + const int64_t in_features_block_size = + std::get(serialized); + const c10::QScheme q_scheme = std::get(serialized) + ? c10::kPerTensorAffine + : c10::kPerChannelAffine; + const int64_t output_channels = + std::get(serialized); + const int64_t input_channels = std::get(serialized); + // Unpack the untiled bcsr, then pack it in tiled form + at::Tensor weight_origin; + const at::Tensor weight_zero_points = + std::get(serialized); + if (q_scheme == c10::kPerTensorAffine) { + weight_origin = at::_empty_affine_quantized( + {output_channels, input_channels}, + at::device(c10::kCPU).dtype(c10::kQInt8), + std::get(serialized).data_ptr()[0], + weight_zero_points.data_ptr()[0]); + } else if (q_scheme == c10::kPerChannelAffine) { + weight_origin = at::_empty_per_channel_affine_quantized( + {output_channels, input_channels}, + std::get(serialized), + weight_zero_points, + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQInt8)); + } + + const at::Tensor loaded_weight_values = + std::get(serialized); + const uint8_t* loaded_weight_values_ptr = + loaded_weight_values.data_ptr(); + const int64_t loaded_weight_values_size = loaded_weight_values.numel(); + // Subtract 128 because we serialize as +128, which s best for + // minimizing memory footprint for QNNPack + std::vector weight_values(loaded_weight_values_size); + std::transform( + loaded_weight_values_ptr, + loaded_weight_values_ptr + loaded_weight_values_size, + weight_values.begin(), + [](uint8_t v) { + return static_cast(static_cast(v) - 128); + }); + + // Unpack as non backend specific untiled BCSR then pack as Fbgemm tiled BCSR + // because untiled Fbgemm BCSR currently doesn't exist + unpack_bcsr( + reinterpret_cast(weight_origin.data_ptr()), + ao::sparse::BCSR( + std::move(weight_values), + unwrap_vector( + std::get(serialized)), // Row Indices + unwrap_vector( + std::get(serialized))), // Col Indices + output_channels, + input_channels, + out_features_block_size, + in_features_block_size, + weight_zero_points.data_ptr(), + q_scheme == c10::kPerTensorAffine); + + return PackedLinearWeight::prepack( + weight_origin, + std::get(serialized), + out_features_block_size, + in_features_block_size); +} + +#endif // USE_FBGEMM + +#ifdef USE_PYTORCH_QNNPACK + +c10::intrusive_ptr PackedLinearWeightQnnp::deserialize( + const BCSRSerializationType& serialized) { + return c10::make_intrusive(serialized); +} + +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +PackedLinearWeightQnnp::PackedLinearWeightQnnp( + const BCSRSerializationType& serialized) + : LinearPackedParamsBase( + std::get(serialized), + std::get(serialized)), + orig_bias_(std::get(serialized)), + q_scheme_( + std::get(serialized) + ? c10::kPerTensorAffine + : c10::kPerChannelAffine), + output_channels_(std::get(serialized)), + input_channels_(std::get(serialized)) { + if (orig_bias_.has_value()) { + bias_ = orig_bias_.value(); + + TORCH_CHECK( + (bias_.ndimension() == 1 && bias_.size(0) == output_channels_), + "ao::sparse::qlinear_deserialize (qnnpack): Given weight of size ", + "{", + output_channels_, + ", ", + input_channels_, + "}", + ", expected bias to be 1-dimensional with ", + output_channels_, + " elements", + ", but got bias of size ", + bias_.sizes(), + " instead"); + } else { + bias_ = at::zeros(output_channels_, at::device(at::kCPU).dtype(at::kFloat)); + } + + // Pad amount (8) comes from make_zero_points_and_scales_tensor + // https://github.com/pytorch/pytorch/blob/f8c1acea1e78573c04cd18893c4abff9eea64b03/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h#L468 + const int64_t output_channels_padded = output_channels_ + 8; + + w_scales_ = at::empty( + {output_channels_padded}, at::device(at::kCPU).dtype(at::kFloat)); + float* w_scales_data_ptr = w_scales_.data_ptr(); + std::fill_n( + w_scales_data_ptr + output_channels_, + output_channels_padded - output_channels_, + 1); // Pad with 1 + + w_zero_points_ = + std::vector(output_channels_padded, 0); // Pad with 0; + + const float* w_scales_orig_data_ptr = + std::get(serialized).data_ptr(); + const int8_t* w_zp_orig_data_ptr = + std::get(serialized).data_ptr(); + + const std::function add_128 = [](int8_t v) { + return static_cast(static_cast(v) + 128); + }; + + if (q_scheme_ == at::kPerTensorAffine) { + std::fill_n(w_scales_data_ptr, output_channels_, w_scales_orig_data_ptr[0]); + std::fill_n( + w_zero_points_.begin(), output_channels_, w_zp_orig_data_ptr[0] + 128); + } else if (q_scheme_ == at::kPerChannelAffine) { + std::copy( + w_scales_orig_data_ptr, + w_scales_orig_data_ptr + output_channels_, + w_scales_data_ptr); + std::transform( + w_zp_orig_data_ptr, + w_zp_orig_data_ptr + output_channels_, + w_zero_points_.begin(), + add_128); + } else { + TORCH_CHECK(false, "Unsupported quantization scheme."); + } + + deserialized_bcsr_row_block_indices_ = + std::get(serialized); + deserialized_bcsr_col_block_indices_ = + std::get(serialized); + deserialized_bcsr_weight_values_ = std::get(serialized); + + bcsr_matrix_ = qnnpack::generateBlockCSRMatrix( + (uint32_t*)deserialized_bcsr_col_block_indices_.data_ptr(), + (uint32_t*)deserialized_bcsr_row_block_indices_.data_ptr(), + deserialized_bcsr_weight_values_.data_ptr(), + deserialized_bcsr_col_block_indices_.numel(), + deserialized_bcsr_row_block_indices_.numel(), + deserialized_bcsr_weight_values_.numel(), + out_features_block_size_, + in_features_block_size_); +} +#endif // USE_PYTORCH_QNNPACK + +} // namespace sparse +} // namespace ao diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index 46a07ffaa15a7..bd6f92c97c5e7 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -36,7 +36,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( const auto rows_input = c10::multiply_integers(input.sizes().begin(), input.sizes().end() - 1); const auto cols_input = static_cast(input.size(input.dim() - 1)); TORCH_CHECK( - cols_input == orig_weight_.size(1), + cols_input == input_channels_, "quantized_sparse_lienar: Input tensor's last and weight tensor's" " second dimension must match."); @@ -71,8 +71,8 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( pytorch_qnnp_operator_t sparse_linear_op{nullptr}; pytorch_qnnp_status status = pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8( - orig_weight_.size(1), - orig_weight_.size(0), + input_channels_, + output_channels_, q_input_contig.q_zero_point(), w_zero_points_.data(), bcsr_matrix_->col_indices.data(), @@ -111,8 +111,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( requantization_scales_.data(); std::vector out_sizes = input.sizes().vec(); - size_t rows_w = orig_weight_.size(0); - out_sizes.back() = rows_w; + out_sizes.back() = output_channels_; auto output = at::empty(out_sizes, input.options().dtype(at::kFloat)); @@ -124,7 +123,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( cols_input, /* num input channels */ bias_.data_ptr(), output.data_ptr(), - rows_w /* num output channels */); + output_channels_); TORCH_CHECK( status == pytorch_qnnp_status_success, "Failed to setup sparse linear operator on" diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp index 187ed4fd1404a..616ed9011e0cd 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp @@ -52,6 +52,12 @@ c10::intrusive_ptr PackedLinearWeight:: "The weight tensor for ao::sparse::qlinear_prepack (fbgemm) should" " be 2-dimensional."); + TORCH_CHECK( + out_features_block_size == 1 && in_features_block_size == 4, + "The out and in features block sizes for ao::sparse::qlinear_prepack", + " (fbgemm) should be 1 and 4 respectively (got ", out_features_block_size, + " and ", in_features_block_size, ")"); + auto N = weight.size(0); auto K = weight.size(1); @@ -138,29 +144,28 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( const c10::optional& bias, const int64_t out_features_block_size, const int64_t in_features_block_size) - : LinearPackedParamsBase( - out_features_block_size, - in_features_block_size), - orig_weight_(weight), - orig_bias_(bias) { + : LinearPackedParamsBase(out_features_block_size, in_features_block_size), + orig_bias_(bias), + q_scheme_(weight.qscheme()), + output_channels_(weight.size(0)), + input_channels_(weight.size(1)) { TORCH_CHECK( weight.dim() == 2, "ao::sparse::qlinear (qnnpack): Weight tensor rank should be == 2"); TORCH_CHECK(out_features_block_size > 0, "Row block size must be > 0."); TORCH_CHECK(in_features_block_size > 0, "Row block size must be > 0."); - int64_t rows_w = weight.size(0); if (bias.has_value()) { bias_ = bias.value(); } else { - bias_ = at::zeros(rows_w, weight.options().dtype(at::kFloat)); + bias_ = at::zeros(output_channels_, weight.options().dtype(at::kFloat)); } TORCH_CHECK( - (bias_.ndimension() == 1 && bias_.size(0) == rows_w), + (bias_.ndimension() == 1 && bias_.size(0) == output_channels_), "ao::sparse::qlinear_prepack (qnnpack): Given weight of size ", weight.sizes(), ", expected bias to be 1-dimensional with ", - rows_w, + output_channels_, " elements", ", but got bias of size ", bias_.sizes(), @@ -168,9 +173,8 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( // Given bias is supposed to be 1 dim, it is already contiguous, // but the weight might be non-contiguous. - at::Tensor weight_contig = orig_weight_.contiguous(); + at::Tensor weight_contig = weight.contiguous(); - q_scheme_ = orig_weight_.qscheme(); std::tie(w_zero_points_, w_scales_) = make_zero_points_and_scales_tensor(weight_contig); const float* weight_scales_data = w_scales_.data_ptr(); @@ -188,8 +192,8 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( } bcsr_matrix_ = qnnpack::generateBlockCSRMatrix( reinterpret_cast(qnnp_w_data), - orig_weight_.size(0), /* output_channels */ - orig_weight_.size(1), /* input_channels */ + output_channels_, + input_channels_, out_features_block_size, in_features_block_size, w_zero_points_.data()); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_serialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_serialize.cpp new file mode 100644 index 0000000000000..cacb2815a2a3a --- /dev/null +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_serialize.cpp @@ -0,0 +1,218 @@ +#include + +#ifdef USE_FBGEMM +#include +#endif +#ifdef USE_PYTORCH_QNNPACK +#include +#endif + +namespace ao { +namespace sparse { + +namespace { +/** + - Wrap a vector in a Tensor, copying data into its own data pointer. + - The type of vec is T& (not vector&) so this works with any vector-like + datastructure which has .data() and .size() + */ +template +at::Tensor wrap_vector(T& vec, c10::ScalarType dtype) { + at::Tensor t = at::empty( + {static_cast(vec.size())}, at::device(c10::kCPU).dtype(dtype)); + std::copy( + vec.data(), vec.data() + vec.size(), t.data_ptr()); + return t; +} + +#ifdef USE_FBGEMM +/** + * Adapted from Fbgemm BCSRMatrix::pack, but with zero points, without tiling, + * and without determining row_offsets + * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L84 + */ +ao::sparse::BCSR pack_bcsr( + const int8_t* src, + const int64_t R, + const int64_t C, + const int64_t RB, + const int64_t CB, + const int8_t* zero_points, + const bool qscheme_per_tensor) { + const size_t ld = C; + std::vector rowBPtr; + std::vector colBIdx; + std::vector values; + rowBPtr.push_back(0); + int64_t nnzb = 0; + int64_t rowBlocks = (R + RB - 1) / RB; + for (int64_t i = 0; i < rowBlocks; ++i) { + int64_t curCols = C; + int64_t curColBlocks = (curCols + CB - 1) / CB; + for (int64_t j = 0; j < curColBlocks; ++j) { + // is the whole block zero? + bool isCurrentBlockNonZero = false; + for (int64_t ib = 0; ib < RB; ++ib) { + // break if already found a non-zero element or + // out of bounds + if (isCurrentBlockNonZero || (i * RB + ib) >= R) { + break; + } + const int64_t curr_row = i * RB + ib; + const int8_t curr_row_zero_point = + qscheme_per_tensor ? zero_points[0] : zero_points[curr_row]; + for (int64_t jb = 0; jb < CB; ++jb) { + // within bound? + if ((j * CB + jb) >= C) { + continue; + } else { + if (src[curr_row * ld + j * CB + jb] != curr_row_zero_point) { + isCurrentBlockNonZero = true; + break; + } + } + } + } + if (isCurrentBlockNonZero) { + for (int64_t ib = 0; ib < RB; ++ib) { + for (int64_t jb = 0; jb < CB; ++jb) { + if ((i * RB + ib) >= R || (j * CB + jb) >= C) { + // zero fill + values.push_back(0); + } else { + int8_t val = src[(i * RB + ib) * ld + j * CB + jb]; + values.push_back(val); + } + } + } + colBIdx.push_back(static_cast(j)); + nnzb++; + } + } + rowBPtr.push_back(static_cast(nnzb)); + } + return ao::sparse::BCSR( + std::move(values), std::move(rowBPtr), std::move(colBIdx)); +} +#endif // USE_FBGEMM +} // namespace + +#ifdef USE_FBGEMM + +BCSRSerializationType PackedLinearWeight::serialize() { + // Get weights, row indices, and col indices in untiled form; + // unpack the tiled bcsr then pack it in untiled form + std::vector dense_weight_values = std::vector(w->R * w->C); + w->unpack(dense_weight_values.data()); + + const bool qscheme_per_tensor = (q_scheme == c10::kPerTensorAffine); + at::Tensor zero_points = wrap_vector(w_zp, c10::kChar); + + ao::sparse::BCSR untiled_bcsr = pack_bcsr( + dense_weight_values.data(), + w->R, + w->C, + w->RB, + w->CB, + zero_points.data_ptr(), + qscheme_per_tensor); + + std::vector& packed_weight_values = std::get<0>(untiled_bcsr); + // Add 128 to each weight value. This serialization format is best for + // minimizing memory footprint for QNNPack + + at::Tensor weight_values = at::empty( + {static_cast(packed_weight_values.size())}, + at::device(c10::kCPU).dtype(c10::kByte)); + std::transform( + packed_weight_values.begin(), + packed_weight_values.end(), + weight_values.data_ptr(), + [](int8_t v) { + return static_cast(static_cast(v) + 128); + }); + + return BCSRSerializationType( + SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION, + bias_, + out_features_block_size_, + in_features_block_size_, + wrap_vector(w_scale, c10::kFloat), + // Narrowing from int32_t to int8_t; this is okay because qint8 zero + // points are restricted to fit in bounds of int_8 + std::move(zero_points), + qscheme_per_tensor, + wrap_vector( + std::get<1>(untiled_bcsr), c10::kInt), // Row block indices + wrap_vector( + std::get<2>(untiled_bcsr), c10::kInt), // Col block indices + std::move(weight_values), + w->R, + w->C); +} + +#endif // USE_FBGEMM + +#ifdef USE_PYTORCH_QNNPACK + +BCSRSerializationType PackedLinearWeightQnnp::serialize() { + at::Tensor w_scales_compact; + at::Tensor w_zero_points_compact; + const float* w_scales_data_ptr = w_scales_.data_ptr(); + std::function subtract_128 = [](uint8_t v) { + return static_cast(static_cast(v) - 128); + }; + + if (q_scheme_ == at::kPerTensorAffine) { + w_scales_compact = at::empty({1}, at::device(c10::kCPU).dtype(c10::kFloat)); + w_zero_points_compact = + at::empty({1}, at::device(c10::kCPU).dtype(c10::kChar)); + + w_scales_compact.data_ptr()[0] = w_scales_data_ptr[0]; + w_zero_points_compact.data_ptr()[0] = + static_cast(static_cast(w_zero_points_[0]) - 128); + } else if (q_scheme_ == at::kPerChannelAffine) { + w_scales_compact = + at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kFloat)); + w_zero_points_compact = + at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kChar)); + + std::copy( + w_scales_data_ptr, + w_scales_data_ptr + + output_channels_, // Don't go to the end because of padding + w_scales_compact.data_ptr()); + + // Subtract 128 from each zero point, to reverse addition done during + // prepacking + std::transform( + w_zero_points_.begin(), + w_zero_points_.begin() + + output_channels_, // Don't go to the end because of padding + w_zero_points_compact.data_ptr(), + subtract_128); + } else { + TORCH_CHECK(false, "Unsupported quantization scheme."); + } + + return BCSRSerializationType( + SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION, + orig_bias_, + out_features_block_size_, + in_features_block_size_, + std::move(w_scales_compact), + std::move(w_zero_points_compact), + (q_scheme_ == c10::kPerTensorAffine), + wrap_vector( + bcsr_matrix_->row_values, c10::kInt), // Casting from uint32_t to int + wrap_vector( + bcsr_matrix_->col_indices, c10::kInt), // Casting from uint32_t to int + wrap_vector(bcsr_matrix_->values, c10::kByte), + output_channels_, + input_channels_); +} + +#endif // USE_PYTORCH_QNNPACK + +} // namespace sparse +} // namespace ao diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp index ec6e160b16c3e..c10cc40af4a20 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp @@ -14,35 +14,41 @@ int register_linear_params(); LinearPackedSerializationType PackedLinearWeight::unpack() { auto packW = w.get(); - int64_t N = static_cast(packW->R); - int64_t K = static_cast(packW->C); + const int64_t N = static_cast(packW->R); + const int64_t K = static_cast(packW->C); at::Tensor weight_origin; if (q_scheme == c10::kPerTensorAffine) { weight_origin = at::_empty_affine_quantized( {N, K}, at::device(c10::kCPU).dtype(c10::kQInt8), w_scale[0], w_zp[0]); } else if (q_scheme == c10::kPerChannelAffine) { - auto scales = at::from_blob( - w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat)); - auto zero_points = at::from_blob( - w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kInt)); + at::Tensor scales = at::empty( + {static_cast(w_scale.size())}, + at::device(c10::kCPU).dtype(c10::kFloat)); + std::copy(w_scale.begin(), w_scale.end(), scales.data_ptr()); + + at::Tensor zero_points = at::empty( + {static_cast(w_zp.size())}, + at::device(c10::kCPU).dtype(c10::kInt)); + std::copy(w_zp.begin(), w_zp.end(), zero_points.data_ptr()); weight_origin = at::_empty_per_channel_affine_quantized( {N, K}, - scales.toType(c10::kDouble), - zero_points.toType(c10::kLong), + scales, + zero_points, 0, // The output channel axis is 0 device(c10::kCPU).dtype(c10::kQInt8)); } - // TODO: uncomment once unpack is implemented for BCSRMatrix - // int8_t* weight_ptr_int8 = - // reinterpret_cast(weight_origin.data_ptr()); - // packW->unpack(weight_ptr_int8); - std::vector block_pattern( + int8_t* weight_ptr_int8 = + reinterpret_cast(weight_origin.data_ptr()); + + packW->unpack(weight_ptr_int8); + + const std::vector block_pattern( {out_features_block_size_, in_features_block_size_}); - return std::make_tuple(weight_origin, bias_, std::move(block_pattern)); + return std::make_tuple(std::move(weight_origin), bias_, block_pattern); } #endif // USE_FBGEMM @@ -50,9 +56,58 @@ LinearPackedSerializationType PackedLinearWeight::unpack() { #ifdef USE_PYTORCH_QNNPACK LinearPackedSerializationType PackedLinearWeightQnnp::unpack() { + const int64_t N = static_cast(output_channels_); + const int64_t K = static_cast(input_channels_); + + float* w_scales_ptr = w_scales_.data_ptr(); + + at::Tensor weight_origin; + if (q_scheme_ == c10::kPerTensorAffine) { + weight_origin = at::_empty_affine_quantized( + {N, K}, + at::device(c10::kCPU).dtype(c10::kQInt8), + w_scales_ptr[0], + w_zero_points_[0] - 128); + } else if (q_scheme_ == c10::kPerChannelAffine) { + at::Tensor scales = at::empty( + {static_cast(output_channels_)}, + at::device(c10::kCPU).dtype(c10::kFloat)); + std::copy( + w_scales_ptr, + w_scales_ptr + output_channels_, + scales.data_ptr()); + + at::Tensor zero_points = at::empty( + {static_cast(output_channels_)}, + at::device(c10::kCPU).dtype(c10::kInt)); + std::transform( + w_zero_points_.begin(), + w_zero_points_.begin() + output_channels_, + zero_points.data_ptr(), + [](uint8_t v) { return static_cast(v) - 128; }); + + weight_origin = at::_empty_per_channel_affine_quantized( + {N, K}, + scales, + zero_points, + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQInt8)); + } + + int8_t* weight_ptr_int8 = + reinterpret_cast(weight_origin.data_ptr()); + + bcsr_matrix_->unpack( + weight_ptr_int8, + output_channels_, + input_channels_, + w_zero_points_.data()); + std::vector block_pattern( {out_features_block_size_, in_features_block_size_}); - return std::make_tuple(orig_weight_, orig_bias_, std::move(block_pattern)); + + return std::make_tuple( + std::move(weight_origin), bias_, std::move(block_pattern)); } #endif // USE_FBGEMM @@ -67,7 +122,7 @@ class QLinearUnpackWeightInt8 final { } }; -TORCH_LIBRARY_IMPL(sparse, QuantizedCPU, m) { +TORCH_LIBRARY_IMPL(sparse, CatchAll, m) { m.impl( TORCH_SELECTIVE_NAME("sparse::qlinear_unpack"), TORCH_FN(QLinearUnpackWeightInt8::run)); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h index 48717aa59c7dd..098b862297fd5 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h @@ -17,7 +17,7 @@ namespace sparse { struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase { PackedLinearWeightQnnp(const at::Tensor& weight, const c10::optional& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */); - at::Tensor orig_weight_; + explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized); c10::optional orig_bias_; // Seperate copy of bias exist so that we can fill in zeros when // optional bias does not exist. This is to compy with qnnpack operator that @@ -32,6 +32,15 @@ struct TORCH_API PackedLinearWeightQnnp std::vector requantization_scales_; std::unique_ptr sparse_linear_op_{nullptr}; + int64_t output_channels_; + int64_t input_channels_; + // Deserialized Tensors are stored to maintain the lifetime of underlying + // BCSR data. + // These are left empty if PackedLinearWeightQnnp is created via prepacking + // rather than deserializing. + at::Tensor deserialized_bcsr_row_block_indices_; + at::Tensor deserialized_bcsr_col_block_indices_; + at::Tensor deserialized_bcsr_weight_values_; at::Tensor apply( const at::Tensor& input, @@ -53,6 +62,11 @@ struct TORCH_API PackedLinearWeightQnnp LinearPackedSerializationType unpack() override; + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + c10::optional bias() override { return orig_bias_; } diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 14fef621b10f7..6f3eac783ccda 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -558,14 +558,15 @@ void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) { AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "hardshrink_cpu", [&] { auto lambd_val = lambd.to(); + using Vec = Vectorized; cpu_kernel_vec( iter, [=](scalar_t self_val) { return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val; }, - [=](Vectorized self_val) { - return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val; + [=](Vec self_val) { + return Vec::blendv(self_val, Vec(0), (self_val >= -lambd_val) & (self_val <= lambd_val)); }); }); } diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 57fc6b7c56de3..cf12c392f8682 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -8,13 +8,13 @@ namespace native { namespace cpublas { namespace { -template -void scale_(int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda) { - if (alpha == scalar_t(1)) { +template +void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) { + if (alpha == opmath_t(1)) { return; // identity } - if (alpha == scalar_t(0)) { + if (alpha == opmath_t(0)) { for (const auto j : c10::irange(n)) { for (const auto i : c10::irange(m)) { a[j * lda + i] = scalar_t(0); @@ -31,13 +31,13 @@ void scale_(int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda) { } -template +template void gemm_notrans_( int64_t m, int64_t n, int64_t k, - scalar_t alpha, + opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + opmath_t beta, scalar_t *c, int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -45,7 +45,7 @@ void gemm_notrans_( // c += alpha * (a @ b) for (const auto l : c10::irange(k)) { for (const auto j : c10::irange(n)) { - scalar_t val = b[l + j * ldb] * alpha; + opmath_t val = b[l + j * ldb] * alpha; int64_t i_m = m / 4; for (const auto i_i : c10::irange(i_m)) { c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val; @@ -60,22 +60,22 @@ void gemm_notrans_( } } -template +template void gemm_transa_( int64_t m, int64_t n, int64_t k, - scalar_t alpha, + opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + opmath_t beta, scalar_t *c, int64_t ldc) { // c = alpha * (a.T @ b) + beta * c const scalar_t *a_ = a; for (const auto i : c10::irange(m)) { const scalar_t *b_ = b; for (const auto j : c10::irange(n)) { - scalar_t sum = 0; + opmath_t sum = 0; for (const auto l : c10::irange(k)) { - sum += a_[l]*b_[l]; + sum += static_cast(a_[l]) * static_cast(b_[l]); } b_ += ldb; if (beta == scalar_t(0)) @@ -87,13 +87,13 @@ void gemm_transa_( } } -template +template void gemm_transb_( int64_t m, int64_t n, int64_t k, - scalar_t alpha, + opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + opmath_t beta, scalar_t *c, int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -101,7 +101,7 @@ void gemm_transb_( // c += alpha * (a @ b.T) for (const auto l : c10::irange(k)) { for (const auto j : c10::irange(n)) { - scalar_t val = b[j + l * ldb] * alpha; + opmath_t val = b[j + l * ldb] * alpha; int64_t i_m = m / 4; for (const auto i_i : c10::irange(i_m)) { c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val; @@ -116,13 +116,13 @@ void gemm_transb_( } } -template +template void gemm_transab_( int64_t m, int64_t n, int64_t k, - scalar_t alpha, + opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + opmath_t beta, scalar_t *c, int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -133,29 +133,29 @@ void gemm_transab_( int64_t l_k = k / 4; for (const auto l_l : c10::irange(l_k)) { c[j * ldc + i] += a[i * lda + l_l * 4 + 0] // - * b[(l_l * 4 + 0) * ldb + j] * alpha; + * (b[(l_l * 4 + 0) * ldb + j] * alpha); c[j * ldc + i] += a[i * lda + l_l * 4 + 1] // - * b[(l_l * 4 + 1) * ldb + j] * alpha; + * (b[(l_l * 4 + 1) * ldb + j] * alpha); c[j * ldc + i] += a[i * lda + l_l * 4 + 2] // - * b[(l_l * 4 + 2) * ldb + j] * alpha; + * (b[(l_l * 4 + 2) * ldb + j] * alpha); c[j * ldc + i] += a[i * lda + l_l * 4 + 3] // - * b[(l_l * 4 + 3) * ldb + j] * alpha; + * (b[(l_l * 4 + 3) * ldb + j] * alpha); } int64_t l = l_k * 4; for (; l < k; l++) - c[j * ldc + i] += a[i * lda + l] * b[l * ldb + j] * alpha; + c[j * ldc + i] += a[i * lda + l] * (b[l * ldb + j] * alpha); } } } -template +template void gemm_core_( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, - scalar_t alpha, + opmath_t alpha, const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, - scalar_t beta, + opmath_t beta, scalar_t *c, int64_t ldc) { if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) { return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -180,12 +180,13 @@ void cpublas_gemm_impl( AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_gemm_impl", [&]{ + using opmath_t = at::opmath_type; gemm_core_( transa, transb, m, n, k, - alpha.to(), + alpha.to(), static_cast(a), lda, static_cast(b), ldb, - beta.to(), + beta.to(), static_cast(c), ldc); }); } @@ -201,7 +202,8 @@ void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const v } else { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl", [&] { - auto a = _a.to(); + using opmath_t = at::opmath_type; + auto a = _a.to(); auto x = static_cast(_x); auto y = static_cast(_y); int64_t i; diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 40a0c20b5ca8d..de1841d989c3b 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace at { namespace native { @@ -15,6 +17,159 @@ void conj_kernel(TensorIteratorBase &iter); namespace { +void float_bfloat16_copy_kernel(TensorIteratorBase &iter, bool requires_neg) { + auto strides_out = iter.strides(0); + auto strides_in = iter.strides(1); + auto shape = iter.shape(); + c10::SmallBuffer strides(2 * std::max(iter.ndim(), 2)); + auto get_strides = [](int64_t* strides, IntArrayRef strides_out, IntArrayRef strides_in, int64_t ndim) { + for (const auto dim : c10::irange(ndim)) { + for (const auto arg : c10::irange(2)) { + *strides++ = arg == 0? strides_out[dim] : strides_in[dim]; + } + } + // Always at least 2d strides to support 2d for_each loops + if (ndim < 2) { + std::fill_n(strides, (2 - ndim) * 2, 0); + } + }; + get_strides(strides.data(), strides_out, strides_in, iter.ndim()); + if ((iter.dtype(0) == kFloat) && (iter.dtype(1) == kBFloat16)) { + using dest_t = float; + using scalar_t = BFloat16; + using Vecd = Vectorized; + using Vecs = Vectorized; + c10::SmallBuffer ptrs(2); + dest_t* output_data = iter.tensor_base(0).data_ptr(); + scalar_t* input_data = iter.tensor_base(1).data_ptr(); + ptrs[0] = reinterpret_cast(output_data); + ptrs[1] = reinterpret_cast(input_data); + + int64_t grain_size = at::internal::GRAIN_SIZE; + + auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) { + std::array data; + std::copy_n(base, 2, data.data()); + const int64_t *outer_strides = &strides[2]; + + for (const auto it : c10::irange(size1)) { + (void)it; + Vecd dst_s; + if (strides_in[0] == 0) { + dst_s = Vecd(dest_t(*((scalar_t*)data[1]))); + if (requires_neg) { + dst_s = dst_s.neg(); + } + } + int64_t i = 0; + for (; i <= size0 - Vecs::size(); i += Vecs::size()) { + if (strides_in[0] != 0) { + Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t)); + Vecd data_vec0, data_vec1; + std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec); + if (requires_neg) { + data_vec0 = data_vec0.neg(); + data_vec1 = data_vec1.neg(); + } + data_vec0.store(data[0] + i * sizeof(dest_t)); + data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t)); + } else { + dst_s.store(data[0] + i * sizeof(dest_t)); + dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t)); + } + } + if (i < size0) { + if (strides_in[0] != 0) { + Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t), size0 - i); + Vecd data_vec0, data_vec1; + std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec); + if (requires_neg) { + data_vec0 = data_vec0.neg(); + data_vec1 = data_vec1.neg(); + } + data_vec0.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size())? Vecd::size() : (size0 - i)); + data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size())? (size0 - i - Vecd::size()) : 0); + } else { + dst_s.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size())? Vecd::size() : (size0 - i)); + dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size())? (size0 - i - Vecd::size()) : 0); + } + } + data[0] += outer_strides[0]; + data[1] += outer_strides[1]; + } + + }; + + parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) { + at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end}); + }); + } else if ((iter.dtype(0) == kBFloat16) && (iter.dtype(1) == kFloat)) { + using dest_t = BFloat16; + using scalar_t = float; + using Vecd = Vectorized; + using Vecs = Vectorized; + c10::SmallBuffer ptrs(2); + dest_t* output_data = iter.tensor_base(0).data_ptr(); + scalar_t* input_data = iter.tensor_base(1).data_ptr(); + + ptrs[0] = reinterpret_cast(output_data); + ptrs[1] = reinterpret_cast(input_data); + + int64_t grain_size = at::internal::GRAIN_SIZE; + + auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) { + std::array data; + std::copy_n(base, 2, data.data()); + const int64_t *outer_strides = &strides[2]; + + for (const auto it : c10::irange(size1)) { + (void)it; + Vecd dst_s; + if (strides_in[0] == 0) { + dst_s = Vecd(dest_t(*((scalar_t*)data[1]))); + if (requires_neg) { + dst_s = dst_s.neg(); + } + } + int64_t i = 0; + for (; i <= size0 - 2 * Vecs::size(); i += 2 * Vecs::size()) { + if (strides_in[0] != 0) { + Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(scalar_t)); + Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(scalar_t)); + auto data_vec = convert_float_bfloat16(data_vec0, data_vec1); + if (requires_neg) { + data_vec = data_vec.neg(); + } + data_vec.store(data[0] + i * sizeof(dest_t)); + } else { + dst_s.store(data[0] + i * sizeof(dest_t)); + } + + } + if (i < size0) { + if (strides_in[0] != 0) { + Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(scalar_t), ((size0 - i) > Vecs::size())? Vecs::size() : (size0 - i)); + Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(scalar_t), ((size0 - i) > Vecs::size())? (size0 - i - Vecs::size()) : 0); + auto data_vec = convert_float_bfloat16(data_vec0, data_vec1); + if (requires_neg) { + data_vec = data_vec.neg(); + } + data_vec.store(data[0] + i * sizeof(dest_t), size0 - i); + } else { + dst_s.store(data[0] + i * sizeof(dest_t), size0 - i); + } + } + data[0] += outer_strides[0]; + data[1] += outer_strides[1]; + } + + }; + parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) { + at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end}); + }); + } +} + void direct_copy_kernel(TensorIteratorBase &iter) { // TODO: we don't actually need separate instantiations per dtype; // we only need a separate instantiation per dtype size. This would @@ -78,8 +233,15 @@ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) { isComplexType(dtype) && (iter.tensor_base(0).is_conj() != iter.tensor_base(1).is_conj())); const bool requires_neg = (iter.tensor_base(0).is_neg() != iter.tensor_base(1).is_neg()); + auto strides_out = iter.strides(0); + auto strides_in = iter.strides(1); if (dtype == iter.dtype(1)) { copy_same_dtype(iter, requires_conj, requires_neg); + } else if (!requires_conj && ((iter.dtype(1) == kBFloat16 && iter.dtype(0) == kFloat && + sizeof(float) == strides_out[0] && (sizeof(BFloat16) == strides_in[0] || strides_in[0] == 0)) || + (iter.dtype(1) == kFloat && iter.dtype(0) == kBFloat16 && + sizeof(BFloat16) == strides_out[0] && (sizeof(float) == strides_in[0] || strides_in[0] == 0)))) { + float_bfloat16_copy_kernel(iter, requires_neg); } else { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] { using dest_t = scalar_t; diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp index d3be310e28024..602f8cd41b39c 100644 --- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp @@ -11,8 +11,8 @@ namespace native { namespace { static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) { - ScalarType dtype = iter.dtype(0); - if (iter.dtype() == kBFloat16) { + ScalarType dtype = iter.common_dtype(); + if (dtype == kBFloat16) { float float_val = value.to(); auto float_vec = Vectorized(float_val); cpu_kernel_vec( @@ -51,7 +51,7 @@ static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) { } static void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) { - ScalarType dtype = iter.dtype(0); + ScalarType dtype = iter.common_dtype(); if (dtype == kBFloat16) { float float_val = value.to(); auto float_vec = Vectorized(float_val); diff --git a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp index c92c68f9f4718..449f229e2d1b2 100644 --- a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp +++ b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp @@ -19,7 +19,7 @@ namespace { using namespace vec; static void arange_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_steps, const Scalar& scalar_step) { - AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "arange_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "arange_cpu", [&]() { using accscalar_t = at::acc_type; auto start = scalar_start.to(); auto steps = scalar_steps.to(); @@ -43,7 +43,7 @@ static void arange_kernel(TensorIterator& iter, const Scalar& scalar_start, cons } static void linspace_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_end, int64_t steps) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "linspace_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "linspace_cpu", [&]() { // step should be of double type for all integral types using step_t = std::conditional_t::value, double, scalar_t>; const scalar_t start = scalar_start.to(); diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 58f39f1566771..8fe94699503bb 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -218,7 +218,7 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { char *in = data[ntensors - 1]; int64_t stride = strides[ntensors - 1]; for (const auto i : c10::irange(size)) { - acc = ops.reduce(acc, *(data_t*)in, begin + i); + acc = ops.reduce(acc, c10::load(in), begin + i); in += stride; } }, {begin, end}); diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 908d4fc60b7b6..98e569dde1504 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -250,6 +250,413 @@ inline void _vec_host_softmax_backward_lastdim( }); } +template +inline void _vec_softmax_backward( + scalar_t* grad_input_data_base, + scalar_t* grad_output_data_base, + scalar_t* output_data_base, + int64_t outer_size, + int64_t inner_size, + int64_t dim_size) { + using Vec = vec::Vectorized; + int64_t outer_stride = dim_size * inner_size; + int64_t BLOCK_SIZE = 128 * 1024; + int64_t CHUNK_SIZE = std::max( + int64_t(BLOCK_SIZE / dim_size / sizeof(scalar_t)), (int64_t)Vec::size()); + CHUNK_SIZE = CHUNK_SIZE / Vec::size() * Vec::size(); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE); + parallel_for( + 0, outer_size * num_chunks, grain_size, [&](int64_t begin, int64_t end) { + // thread local temp buffer that holds vertical sum result + std::unique_ptr buffer(new scalar_t[CHUNK_SIZE]); + scalar_t* tmp_sum_data = buffer.get(); + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * CHUNK_SIZE; + int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); + + // init + Vec zero_vec = Vec(scalar_t(0)); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + zero_vec.store(tmp_sum_data + d0); + } + for (; d0 < size; d0++) { + tmp_sum_data[d0] = scalar_t(0); + } + + // compute sum of grad_output * output + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + scalar_t* grad_output_ptr = grad_output_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1); + Vec output_vec = Vec::loadu(output_ptr + d1); + Vec sum_vec = Vec::loadu(tmp_sum_data + d1); + sum_vec += grad_output_vec * output_vec; + sum_vec.store(tmp_sum_data + d1); + } + for (; d1 < size; d1++) { + tmp_sum_data[d1] += grad_output_ptr[d1] * output_ptr[d1]; + } + } + + // compute output * (grad_output - sum) + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + scalar_t* grad_output_ptr = grad_output_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + scalar_t* grad_input_ptr = grad_input_data_base + offset; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2); + Vec output_vec = Vec::loadu(output_ptr + d2); + Vec sum_vec = Vec::loadu(tmp_sum_data + d2); + Vec grad_input_vec = output_vec * (grad_output_vec - sum_vec); + grad_input_vec.store(grad_input_ptr + d2); + } + for (; d2 < size; d2++) { + grad_input_ptr[d2] = output_ptr[d2] * (grad_output_ptr[d2] - tmp_sum_data[d2]); + } + } + } + }); +} + +template <> +inline void _vec_softmax_backward( + BFloat16* grad_input_data_base, + BFloat16* grad_output_data_base, + BFloat16* output_data_base, + int64_t outer_size, + int64_t inner_size, + int64_t dim_size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t outer_stride = dim_size * inner_size; + int64_t BLOCK_SIZE = 128 * 1024; + int64_t CHUNK_SIZE = std::max( + int64_t(BLOCK_SIZE / dim_size / sizeof(BFloat16)), (int64_t)bVec::size()); + CHUNK_SIZE = CHUNK_SIZE / bVec::size() * bVec::size(); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE); + parallel_for( + 0, outer_size * num_chunks, grain_size, [&](int64_t begin, int64_t end) { + // thread local temp buffer that holds vertical sum result + std::unique_ptr buffer(new float[CHUNK_SIZE]); + float* tmp_sum_data = buffer.get(); + + // thread local buffer that holds grad_output and output data in float32 + std::unique_ptr grad_output_buffer( + new float[dim_size * CHUNK_SIZE]); + float* grad_output_buffer_data = grad_output_buffer.get(); + + std::unique_ptr output_buffer( + new float[dim_size * CHUNK_SIZE]); + float* output_buffer_data = output_buffer.get(); + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * CHUNK_SIZE; + int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); + + // init + fVec zero_fvec = fVec(float(0)); + int64_t d0 = 0; + for (; d0 < size - (size % bVec::size()); d0 += bVec::size()) { + zero_fvec.store(tmp_sum_data + d0); + zero_fvec.store(tmp_sum_data + d0 + fVec::size()); + } + for (; d0 < size; d0++) { + tmp_sum_data[d0] = float(0); + } + + // compute sum of grad_output * output + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + BFloat16* grad_output_ptr = grad_output_data_base + offset; + BFloat16* output_ptr = output_data_base + offset; + float* grad_output_buffer_ptr = + grad_output_buffer_data + dim_idx * CHUNK_SIZE; + float* output_buffer_ptr = + output_buffer_data + dim_idx * CHUNK_SIZE; + + int64_t d1 = 0; + for (; d1 < size - (size % bVec::size()); d1 += bVec::size()) { + bVec grad_output_bvec = bVec::loadu(grad_output_ptr + d1); + fVec grad_output_fvec0, grad_output_fvec1; + std::tie(grad_output_fvec0, grad_output_fvec1) = + convert_bfloat16_float(grad_output_bvec); + bVec output_bvec = bVec::loadu(output_ptr + d1); + fVec output_fvec0, output_fvec1; + std::tie(output_fvec0, output_fvec1) = + convert_bfloat16_float(output_bvec); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size()); + sum_fvec0 += grad_output_fvec0 * output_fvec0; + sum_fvec1 += grad_output_fvec1 * output_fvec1; + sum_fvec0.store(tmp_sum_data + d1); + sum_fvec1.store(tmp_sum_data + d1 + fVec::size()); + + // cache the 'converted' float grad_output and output + grad_output_fvec0.store(grad_output_buffer_ptr + d1); + grad_output_fvec1.store( + grad_output_buffer_ptr + d1 + fVec::size()); + output_fvec0.store(output_buffer_ptr + d1); + output_fvec1.store(output_buffer_ptr + d1 + fVec::size()); + } + for (; d1 < size; d1++) { + float grad_output_val = float(grad_output_ptr[d1]); + float output_val = float(output_ptr[d1]); + tmp_sum_data[d1] += grad_output_val * output_val; + grad_output_buffer_ptr[d1] = grad_output_val; + output_buffer_ptr[d1] = output_val; + } + } + + // compute output * (grad_output - sum) + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + BFloat16* grad_input_ptr = grad_input_data_base + + outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + float* grad_output_buffer_ptr = + grad_output_buffer_data + dim_idx * CHUNK_SIZE; + float* output_buffer_ptr = + output_buffer_data + dim_idx * CHUNK_SIZE; + + int64_t d2 = 0; + for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) { + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); + fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2); + fVec grad_output_fvec1 = + fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size()); + fVec output_fvec0 = fVec::loadu(output_buffer_ptr + d2); + fVec output_fvec1 = + fVec::loadu(output_buffer_ptr + d2 + fVec::size()); + fVec grad_input_fvec0 = + output_fvec0 * (grad_output_fvec0 - sum_fvec0); + fVec grad_input_fvec1 = + output_fvec1 * (grad_output_fvec1 - sum_fvec1); + bVec grad_input_bvec = + convert_float_bfloat16(grad_input_fvec0, grad_input_fvec1); + grad_input_bvec.store(grad_input_ptr + d2); + } + for (; d2 < size; d2++) { + grad_input_ptr[d2] = output_buffer_ptr[d2] * (grad_output_buffer_ptr[d2] - tmp_sum_data[d2]); + } + } + } + }); +} + +template +inline void _vec_log_softmax_backward( + scalar_t* grad_input_data_base, + scalar_t* grad_output_data_base, + scalar_t* output_data_base, + int64_t outer_size, + int64_t inner_size, + int64_t dim_size) { + using Vec = vec::Vectorized; + int64_t outer_stride = dim_size * inner_size; + int64_t BLOCK_SIZE = 128 * 1024; + int64_t CHUNK_SIZE = std::max( + int64_t(BLOCK_SIZE / dim_size / sizeof(scalar_t)), (int64_t)Vec::size()); + CHUNK_SIZE = CHUNK_SIZE / Vec::size() * Vec::size(); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE); + parallel_for( + 0, outer_size * num_chunks, grain_size, [&](int64_t begin, int64_t end) { + // thread local temp buffer that holds vertical sum result + std::unique_ptr buffer(new scalar_t[CHUNK_SIZE]); + scalar_t* tmp_sum_data = buffer.get(); + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * CHUNK_SIZE; + int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); + + // init + Vec zero_vec = Vec(scalar_t(0)); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + zero_vec.store(tmp_sum_data + d0); + } + for (; d0 < size; d0++) { + tmp_sum_data[d0] = scalar_t(0); + } + + // compute sum of grad_output + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + scalar_t* grad_output_ptr = grad_output_data_base + + outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec grad_output_vec = Vec::loadu(grad_output_ptr + d1); + Vec sum_vec = Vec::loadu(tmp_sum_data + d1); + sum_vec += grad_output_vec; + sum_vec.store(tmp_sum_data + d1); + } + for (; d1 < size; d1++) { + tmp_sum_data[d1] += grad_output_ptr[d1]; + } + } + + // compute grad_output - output.exp() * sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + scalar_t* grad_output_ptr = grad_output_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + scalar_t* grad_input_ptr = grad_input_data_base + offset; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec grad_output_vec = Vec::loadu(grad_output_ptr + d2); + Vec output_vec = Vec::loadu(output_ptr + d2); + Vec sum_vec = Vec::loadu(tmp_sum_data + d2); + Vec grad_input_vec = grad_output_vec - output_vec.exp() * sum_vec; + grad_input_vec.store(grad_input_ptr + d2); + } + for (; d2 < size; d2++) { + grad_input_ptr[d2] = grad_output_ptr[d2] - + std::exp(output_ptr[d2]) * tmp_sum_data[d2]; + } + } + } + }); +} + +template <> +inline void _vec_log_softmax_backward( + BFloat16* grad_input_data_base, + BFloat16* grad_output_data_base, + BFloat16* output_data_base, + int64_t outer_size, + int64_t inner_size, + int64_t dim_size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t outer_stride = dim_size * inner_size; + int64_t BLOCK_SIZE = 128 * 1024; + int64_t CHUNK_SIZE = std::max( + int64_t(BLOCK_SIZE / dim_size / sizeof(BFloat16)), (int64_t)bVec::size()); + CHUNK_SIZE = CHUNK_SIZE / bVec::size() * bVec::size(); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + int64_t grain_size = internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE); + parallel_for( + 0, outer_size * num_chunks, grain_size, [&](int64_t begin, int64_t end) { + // thread local temp buffer that holds vertical sum result + std::unique_ptr buffer(new float[CHUNK_SIZE]); + float* tmp_sum_data = buffer.get(); + + // thread local buffer that holds grad_output data in float32 + std::unique_ptr grad_output_buffer( + new float[dim_size * CHUNK_SIZE]); + float* grad_output_buffer_data = grad_output_buffer.get(); + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * CHUNK_SIZE; + int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); + + // init + fVec zero_fvec = fVec(float(0)); + int64_t d0 = 0; + for (; d0 < size - (size % bVec::size()); d0 += bVec::size()) { + zero_fvec.store(tmp_sum_data + d0); + zero_fvec.store(tmp_sum_data + d0 + fVec::size()); + } + for (; d0 < size; d0++) { + tmp_sum_data[d0] = float(0); + } + + // compute sum of grad_output + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + BFloat16* grad_output_ptr = grad_output_data_base + + outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + float* grad_output_buffer_ptr = + grad_output_buffer_data + dim_idx * CHUNK_SIZE; + + int64_t d1 = 0; + for (; d1 < size - (size % bVec::size()); d1 += bVec::size()) { + bVec grad_output_bvec = bVec::loadu(grad_output_ptr + d1); + fVec grad_output_fvec0, grad_output_fvec1; + std::tie(grad_output_fvec0, grad_output_fvec1) = + convert_bfloat16_float(grad_output_bvec); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d1); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d1 + fVec::size()); + sum_fvec0 += grad_output_fvec0; + sum_fvec1 += grad_output_fvec1; + sum_fvec0.store(tmp_sum_data + d1); + sum_fvec1.store(tmp_sum_data + d1 + fVec::size()); + + // cache the 'converted' float grad_output + grad_output_fvec0.store(grad_output_buffer_ptr + d1); + grad_output_fvec1.store( + grad_output_buffer_ptr + d1 + fVec::size()); + } + for (; d1 < size; d1++) { + float grad_output_val = float(grad_output_ptr[d1]); + tmp_sum_data[d1] += grad_output_val; + grad_output_buffer_ptr[d1] = grad_output_val; + } + } + + // compute grad_output - output.exp() * sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * outer_stride + dim_idx * inner_size + + inner_idx_begin; + BFloat16* output_ptr = output_data_base + offset; + BFloat16* grad_input_ptr = grad_input_data_base + offset; + float* grad_output_buffer_ptr = + grad_output_buffer_data + dim_idx * CHUNK_SIZE; + + int64_t d2 = 0; + for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) { + bVec output_bvec = bVec::loadu(output_ptr + d2); + fVec output_fvec0, output_fvec1; + std::tie(output_fvec0, output_fvec1) = + convert_bfloat16_float(output_bvec); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); + fVec grad_output_fvec0 = fVec::loadu(grad_output_buffer_ptr + d2); + fVec grad_output_fvec1 = + fVec::loadu(grad_output_buffer_ptr + d2 + fVec::size()); + fVec grad_input_fvec0 = + grad_output_fvec0 - output_fvec0.exp() * sum_fvec0; + fVec grad_input_fvec1 = + grad_output_fvec1 - output_fvec1.exp() * sum_fvec1; + bVec grad_input_bvec = + convert_float_bfloat16(grad_input_fvec0, grad_input_fvec1); + grad_input_bvec.store(grad_input_ptr + d2); + } + for (; d2 < size; d2++) { + grad_input_ptr[d2] = grad_output_buffer_ptr[d2] - + std::exp(float(output_ptr[d2])) * tmp_sum_data[d2]; + } + } + } + }); +} + template struct vec_host_softmax_lastdim { static void apply(const Tensor& output, const Tensor& input) { @@ -279,13 +686,13 @@ inline void _vec_softmax( using Vec_bf16 = vec::Vectorized; int64_t dim_stride = inner_size; int64_t outer_stride = dim_size * dim_stride; - int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); + int64_t grain_size = std::max(internal::GRAIN_SIZE / dim_size, (int64_t)1); int vectorized_step = Vec_bf16().size(); // Currently, we only support BFloat16 in this special implementation parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { int64_t idx = begin; - std::unique_ptr temp_vec_input(new float[dim_size*vectorized_step*2]()); - std::unique_ptr temp_vec_output(new float[dim_size*vectorized_step*2]()); + std::unique_ptr temp_vec_input(new float[dim_size*vectorized_step]()); + std::unique_ptr temp_vec_output(new float[dim_size*vectorized_step]()); float* temp_vec_input_data = temp_vec_input.get(); float* temp_vec_output_data = temp_vec_output.get(); while (idx < end) { @@ -299,37 +706,37 @@ inline void _vec_softmax( output_data_base + outer_idx * outer_stride + inner_idx; // Step 1: Get max Score Vec_bf16 max_vec_bf16 = Vec_bf16::loadu(input_data); - std::tuple, vec::Vectorized> convert_result = convert_bfloat16_float(max_vec_bf16); + std::tuple convert_result = convert_bfloat16_float(max_vec_bf16); Vec max_vec_o1 = std::get<0>(convert_result); Vec max_vec_o2 = std::get<1>(convert_result); std::get<0>(convert_result).store(temp_vec_input_data); - std::get<1>(convert_result).store(temp_vec_input_data + vectorized_step); + std::get<1>(convert_result).store(temp_vec_input_data + Vec().size()); for (const auto d : c10::irange(1, dim_size)) { Vec_bf16 input_vec_bf16 = Vec_bf16::loadu(input_data + d * dim_stride); convert_result = convert_bfloat16_float(input_vec_bf16); max_vec_o1 = vec::maximum(max_vec_o1, std::get<0>(convert_result)); max_vec_o2 = vec::maximum(max_vec_o2, std::get<1>(convert_result)); - std::get<0>(convert_result).store(temp_vec_input_data + d*vectorized_step*2); - std::get<1>(convert_result).store(temp_vec_input_data + d*vectorized_step*2 + vectorized_step); + std::get<0>(convert_result).store(temp_vec_input_data + d*vectorized_step); + std::get<1>(convert_result).store(temp_vec_input_data + d*vectorized_step + Vec().size()); } // Step2: Calculate sum Vec sum_vec_o1 = Vec(0.0); Vec sum_vec_o2 = Vec(0.0); for (const auto d : c10::irange(dim_size)) { - Vec output_vec_o1 = Vec::loadu(temp_vec_input_data + d*vectorized_step*2); - Vec output_vec_o2 = Vec::loadu(temp_vec_input_data + d*vectorized_step*2 + vectorized_step); + Vec output_vec_o1 = Vec::loadu(temp_vec_input_data + d*vectorized_step); + Vec output_vec_o2 = Vec::loadu(temp_vec_input_data + d*vectorized_step + Vec().size()); output_vec_o1 = (output_vec_o1 - max_vec_o1).exp(); output_vec_o2 = (output_vec_o2 - max_vec_o2).exp(); - output_vec_o1.store(temp_vec_output_data + d*vectorized_step*2); - output_vec_o2.store(temp_vec_output_data + d*vectorized_step*2 + vectorized_step); + output_vec_o1.store(temp_vec_output_data + d*vectorized_step); + output_vec_o2.store(temp_vec_output_data + d*vectorized_step + Vec().size()); sum_vec_o1 = sum_vec_o1 + output_vec_o1; sum_vec_o2 = sum_vec_o2 + output_vec_o2; } // Step3: Unify for (const auto d : c10::irange(dim_size)) { - Vec output_vec_o1 = Vec::loadu(temp_vec_output_data + d*vectorized_step*2); - Vec output_vec_o2 = Vec::loadu(temp_vec_output_data + d*vectorized_step*2 + vectorized_step); + Vec output_vec_o1 = Vec::loadu(temp_vec_output_data + d*vectorized_step); + Vec output_vec_o2 = Vec::loadu(temp_vec_output_data + d*vectorized_step + Vec().size()); output_vec_o1 = output_vec_o1/sum_vec_o1; output_vec_o2 = output_vec_o2/sum_vec_o2; Vec_bf16 output_vec_bf16 = convert_float_bfloat16(output_vec_o1, output_vec_o2); @@ -385,7 +792,7 @@ inline void _vec_softmax( using Vec = vec::Vectorized; int64_t dim_stride = inner_size; int64_t outer_stride = dim_size * dim_stride; - int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); + int64_t grain_size = std::max(internal::GRAIN_SIZE / dim_size, (int64_t)1); int vectorized_step = Vec().size(); parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { @@ -741,6 +1148,45 @@ struct vec_host_softmax_backward_lastdim { } }; +template +struct vec_host_softmax_backward { + static void apply( + const Tensor& grad_input, + const Tensor& grad, + const Tensor& output, + int64_t dim) { + int64_t outer_size = 1; + int64_t dim_size = grad.size(dim); + int64_t inner_size = 1; + for (const auto i : c10::irange(dim)) { + outer_size *= grad.size(i); + } + for (int64_t i = dim + 1; i < grad.dim(); ++i) { + inner_size *= grad.size(i); + } + scalar_t* grad_input_data_base = grad_input.data_ptr(); + scalar_t* grad_output_data_base = grad.data_ptr(); + scalar_t* output_data_base = output.data_ptr(); + if (LogSoftMax) { + _vec_log_softmax_backward( + grad_input_data_base, + grad_output_data_base, + output_data_base, + outer_size, + inner_size, + dim_size); + } else { + _vec_softmax_backward( + grad_input_data_base, + grad_output_data_base, + output_data_base, + outer_size, + inner_size, + dim_size); + } + } +}; + static void softmax_lastdim_kernel_impl( const Tensor& result, const Tensor& self) { @@ -795,6 +1241,36 @@ static void log_softmax_backward_lastdim_kernel_impl( }); } +static void softmax_backward_kernel_impl( + const Tensor& grad_input, + const Tensor& grad, + const Tensor& output, + int64_t dim) { + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, + grad.scalar_type(), + "softmax_backward_kernel_impl", + [&] { + vec_host_softmax_backward::apply( + grad_input, grad, output, dim); + }); +} + +static void log_softmax_backward_kernel_impl( + const Tensor& grad_input, + const Tensor& grad, + const Tensor& output, + int64_t dim) { + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, + grad.scalar_type(), + "log_softmax_backward_kernel_impl", + [&] { + vec_host_softmax_backward::apply( + grad_input, grad, output, dim); + }); +} + } // anonymous namespace REGISTER_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl); @@ -808,5 +1284,8 @@ REGISTER_DISPATCH( REGISTER_DISPATCH(softmax_kernel, &softmax_kernel_impl); REGISTER_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl); - +REGISTER_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl); +REGISTER_DISPATCH( + log_softmax_backward_kernel, + &log_softmax_backward_kernel_impl); }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/SoftmaxKernel.h b/aten/src/ATen/native/cpu/SoftmaxKernel.h index f9af739034542..ee9fac647ad62 100644 --- a/aten/src/ATen/native/cpu/SoftmaxKernel.h +++ b/aten/src/ATen/native/cpu/SoftmaxKernel.h @@ -17,8 +17,12 @@ DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel); DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel); using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t); +using backward_fn_with_dim = + void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t); + DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel); DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel); - +DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel); +DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel); } } diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index b756c6c46a7ed..fdbecbb65cdff 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -60,7 +60,8 @@ void _dim_apply( } }; - iter.for_each(loop); + int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size); + iter.for_each(loop, /*grain_size=*/grain_size); } ); } diff --git a/aten/src/ATen/native/cpu/SparseFactories.cpp b/aten/src/ATen/native/cpu/SparseFactories.cpp new file mode 100644 index 0000000000000..0b0f73e1844c1 --- /dev/null +++ b/aten/src/ATen/native/cpu/SparseFactories.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at { +namespace native { +using namespace at::sparse; + +namespace { +void _spdiags_kernel_cpu( + TensorIterator& iter, + const Tensor& diagonals, + Tensor& values, + Tensor& indices) { + auto* row_index_write_ptr = indices[0].data_ptr(); + auto* col_index_write_ptr = indices[1].data_ptr(); + const int64_t diagonals_read_stride = diagonals.stride(1); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + at::ScalarType::BFloat16, + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::ComplexHalf, + diagonals.scalar_type(), + "spdiags_cpu", + [&] { + auto* values_write_ptr = values.data_ptr(); + cpu_kernel( + iter, + [&](int64_t diag_index, + int64_t diag_offset, + int64_t out_offset, + int64_t n_out) -> int64_t { + if (n_out > 0) { + auto* rows_start = row_index_write_ptr + out_offset; + auto* cols_start = col_index_write_ptr + out_offset; + auto* vals_start = values_write_ptr + out_offset; + const int64_t first_col = std::max(diag_offset, 0); + const int64_t first_row = first_col - diag_offset; + auto* data_read = diagonals[diag_index].data_ptr() + + first_col * diagonals_read_stride; + for (int64_t i = 0; i < n_out; ++i) { + rows_start[i] = first_row + i; + cols_start[i] = first_col + i; + vals_start[i] = data_read[i * diagonals_read_stride]; + } + } + // dummy return + return 0; + }); + }); +} + +} // namespace + +REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu) + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 675c700c7cd41..903fef2f03312 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -113,7 +113,7 @@ static void min_kernel_impl( const scalar_t* self_data, auto self_dim_stride) { using value_t = typename c10::scalar_value_type::type; value_t (*zabs_)(scalar_t) = zabs; - scalar_t min_number = self_data[0]; + scalar_t min_number = c10::load(self_data); int64_t index = 0; for (const auto i : c10::irange(self_dim_size)) { scalar_t value = self_data[i * self_dim_stride]; @@ -146,10 +146,10 @@ static void max_kernel_impl( const scalar_t* self_data, auto self_dim_stride) { using value_t = typename c10::scalar_value_type::type; value_t (*zabs_)(scalar_t) = zabs; - scalar_t max_number = self_data[0]; + scalar_t max_number = c10::load(self_data); int64_t index = 0; for (const auto i : c10::irange(self_dim_size)) { - scalar_t value = self_data[i * self_dim_stride]; + scalar_t value = c10::load(&self_data[i * self_dim_stride]); if (!(zabs_(value) <= zabs_(max_number))) { max_number = value; index = i; @@ -182,10 +182,10 @@ static void aminmax_kernel( compare_base_kernel(min_result, max_result, self, wrap_dim, keepdim, [&] ( scalar_t* min_result_data, scalar_t* max_result_data, const scalar_t* self_data, auto self_dim_stride) { - scalar_t min_number = self_data[0]; - scalar_t max_number = self_data[0]; + scalar_t min_number = c10::load(self_data); + scalar_t max_number = min_number; for (const auto i : c10::irange(self_dim_size)) { - scalar_t value = self_data[i * self_dim_stride]; + scalar_t value = c10::load(&self_data[i * self_dim_stride]); // note: comparison is written this way to handle NaN correctly if (!(value >= min_number)) { min_number = value; @@ -257,7 +257,7 @@ static void mode_kernel_impl( int64_t max_freq = 0; for (const auto i : c10::irange(self_dim_size)) { - elements[i] = std::make_pair(self_data[i * self_dim_stride], i); + elements[i] = std::make_pair(c10::load(&self_data[i * self_dim_stride]), i); } // Even though, theoretically, we don't need to specify this lambda diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index ba81ffbcfd4cf..a53587e56da4b 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -300,14 +300,15 @@ void sign_kernel(TensorIteratorBase& iter){ static void signbit_kernel(TensorIteratorBase& iter){ // NOTE: signbit does not always support integral arguments. - if (at::isIntegralType(iter.input_dtype(), /*includeBool=*/false)) { - AT_DISPATCH_INTEGRAL_TYPES(iter.input_dtype(), "signbit_cpu", [&]() { - cpu_kernel(iter, [](scalar_t a) -> bool { return c10::is_negative(a); }); }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, ScalarType::Half, iter.input_dtype(), "signbit_cpu", [&]() { - using opmath_t = at::opmath_type; - cpu_kernel(iter, [](scalar_t a) -> bool { return std::signbit(opmath_t{a}); }); }); - } + AT_DISPATCH_SWITCH(iter.input_dtype(), "signbit_cpu", + AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { + cpu_kernel(iter, [](scalar_t a) -> bool { return c10::is_negative(a); }); + }) + AT_DISPATCH_CASE_FLOATING_TYPES_AND2(kBFloat16, ScalarType::Half, [&] { + using opmath_t = at::opmath_type; + cpu_kernel(iter, [](scalar_t a) -> bool { return std::signbit(opmath_t{a}); }); + }) + ); } static void sgn_kernel(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cpu/WeightNormKernel.cpp b/aten/src/ATen/native/cpu/WeightNormKernel.cpp index dfec0a49aeb1b..9dc6b52858050 100644 --- a/aten/src/ATen/native/cpu/WeightNormKernel.cpp +++ b/aten/src/ATen/native/cpu/WeightNormKernel.cpp @@ -328,8 +328,14 @@ void weight_norm_backward_last_dim_kernel( auto grad_v_data = grad_v.data_ptr(); auto grad_g_data = grad_g.data_ptr(); + // the temp buffer will be used twice: + // 1. vertical reduction from [M, N] to [T, N] + // 2. store the intermediate data of `sum`, `a` and `b`, + // so need to make sure it has at least 3 rows + // int num_threads = at::get_num_threads(); - Tensor buffer = at::empty({num_threads, N}, saved_norm.options()).zero_(); + int K = std::max(3, num_threads); + Tensor buffer = at::empty({K, N}, saved_norm.options()).zero_(); auto buffer_data = buffer.data_ptr(); // vertical parallel reduction @@ -351,6 +357,9 @@ void weight_norm_backward_last_dim_kernel( buffer_data[j] = sum; } + // reuse the 1st row of buffer to store the sum + // 2nd row to store coefficient a + // 3rd row to store coefficient b accscalar_t* per_dim_sum = buffer_data; accscalar_t* a = buffer_data + N; accscalar_t* b = buffer_data + 2 * N; diff --git a/aten/src/ATen/native/cpu/airy_ai.cpp b/aten/src/ATen/native/cpu/airy_ai.cpp new file mode 100644 index 0000000000000..1beb899fd0566 --- /dev/null +++ b/aten/src/ATen/native/cpu/airy_ai.cpp @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + namespace native { + inline namespace CPU_CAPABILITY { + static void airy_ai_kernel(TensorIteratorBase& iterator) { + TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2); + + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "airy_ai_cpu", [&]() { + cpu_kernel(iterator, [](scalar_t x) { + return airy_ai_forward(x); + }); + }); + } // airy_ai_kernel(TensorIteratorBase& iterator) + } // namespace CPU_CAPABILITY + + REGISTER_DISPATCH(special_airy_ai_stub, &CPU_CAPABILITY::airy_ai_kernel); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp index ad277a278fa2d..c00b764f08055 100644 --- a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp @@ -18,6 +18,7 @@ #else #include #include +#include #endif namespace at { namespace native { @@ -234,7 +235,7 @@ void batch_norm_cpu_collect_stats_channels_last_impl( // Normal size of C should fit in L1, otherwise consider blocking on C. // int num_threads = at::get_num_threads(); - Tensor buffer = at::empty({num_threads, n_channel}, input.options()).zero_(); + Tensor buffer = at::zeros({num_threads, n_channel}, input.options()); scalar_t* buffer_data = buffer.data_ptr(); // compute mean per input @@ -464,7 +465,7 @@ void batch_norm_cpu_backward_channels_last_impl(Tensor& grad_input, Tensor& grad // Second path: parallel along dim1 of the immediate buffer. // int num_threads = at::get_num_threads(); - Tensor buffer = at::empty({2, num_threads, n_channel}, input.options()).zero_(); + Tensor buffer = at::zeros({2, num_threads, n_channel}, input.options()); scalar_t* sum_data = buffer.data_ptr(); scalar_t* dotp_data = sum_data + num_threads * n_channel; @@ -811,7 +812,7 @@ inline void batch_norm_cpu_collect_stats_channels_last_internal( param_t* var_sum_data = var_sum.data_ptr(); int num_threads = at::get_num_threads(); - Tensor buffer = at::empty({num_threads, n_channel}, input.options().dtype(kFloat)).zero_(); + Tensor buffer = at::zeros({num_threads, n_channel}, input.options().dtype(kFloat)); float* buffer_data = buffer.data_ptr(); at::parallel_for(0, N, 1, [&](int64_t begin, int64_t end) { @@ -1064,7 +1065,7 @@ void batch_norm_cpu_backward_channels_last_internal(Tensor& grad_input, Tensor& } int num_threads = at::get_num_threads(); - Tensor buffer = at::empty({2, num_threads, n_channel}, input.options().dtype(kFloat)).zero_(); + Tensor buffer = at::zeros({2, num_threads, n_channel}, input.options().dtype(kFloat)); float* sum_data = buffer.data_ptr(); float* dotp_data = sum_data + num_threads * n_channel; diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp index e1af1658d1a34..f7104875b8247 100644 --- a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp @@ -35,7 +35,7 @@ void LayerNormKernelImplInternal( Tensor* rstd) { using T_ACC = vec::vec_scalar_t; using Vec = vec::Vectorized; - DCHECK_EQ(X.numel(), M * N); + TORCH_DCHECK_EQ(X.numel(), M * N); DCHECK(!gamma.defined() || gamma.numel() == N); DCHECK(!beta.defined() || beta.numel() == N); const T* X_data = X.data_ptr(); @@ -117,10 +117,10 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = vec::vec_scalar_t; using Vec = vec::Vectorized; - DCHECK_EQ(dY.numel(), M * N); - DCHECK_EQ(X.numel(), M * N); - DCHECK_EQ(mean.numel(), M); - DCHECK_EQ(rstd.numel(), M); + TORCH_DCHECK_EQ(dY.numel(), M * N); + TORCH_DCHECK_EQ(X.numel(), M * N); + TORCH_DCHECK_EQ(mean.numel(), M); + TORCH_DCHECK_EQ(rstd.numel(), M); DCHECK(!gamma.defined() || gamma.numel() == N); const T* dY_data = dY.template data_ptr(); const T* X_data = X.template data_ptr(); diff --git a/aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp b/aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp new file mode 100644 index 0000000000000..1c9d3ee6cacf3 --- /dev/null +++ b/aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + namespace native { + inline namespace CPU_CAPABILITY { + static void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iterator) { + TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2); + + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k0_cpu", [&]() { + cpu_kernel(iterator, [](scalar_t x) { + return scaled_modified_bessel_k0_forward(x); + }); + }); + } // scaled_modified_bessel_k0_kernel(TensorIteratorBase& iterator) + } // namespace CPU_CAPABILITY + + REGISTER_DISPATCH(special_scaled_modified_bessel_k0_stub, &CPU_CAPABILITY::scaled_modified_bessel_k0_kernel); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp b/aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp new file mode 100644 index 0000000000000..b11968a9019d7 --- /dev/null +++ b/aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + namespace native { + inline namespace CPU_CAPABILITY { + static void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iterator) { + TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2); + + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k1_cpu", [&]() { + cpu_kernel(iterator, [](scalar_t x) { + return scaled_modified_bessel_k1_forward(x); + }); + }); + } // scaled_modified_bessel_k1_kernel(TensorIteratorBase& iterator) + } // namespace CPU_CAPABILITY + + REGISTER_DISPATCH(special_scaled_modified_bessel_k1_stub, &CPU_CAPABILITY::scaled_modified_bessel_k1_kernel); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cpu/spherical_bessel_j0.cpp b/aten/src/ATen/native/cpu/spherical_bessel_j0.cpp new file mode 100644 index 0000000000000..aad58b8e09464 --- /dev/null +++ b/aten/src/ATen/native/cpu/spherical_bessel_j0.cpp @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + namespace native { + inline namespace CPU_CAPABILITY { + static void spherical_bessel_j0_kernel(TensorIteratorBase& iterator) { + TORCH_INTERNAL_ASSERT(iterator.ntensors() == 2); + + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "spherical_bessel_j0_cpu", [&]() { + cpu_kernel(iterator, [](scalar_t x) { + return spherical_bessel_j0_forward(x); + }); + }); + } // spherical_bessel_j0_kernel(TensorIteratorBase& iterator) + } // namespace CPU_CAPABILITY + + REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &CPU_CAPABILITY::spherical_bessel_j0_kernel); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/Activation.cpp b/aten/src/ATen/native/cuda/Activation.cpp index 55b397ca77f4d..4360f8b5c3efc 100644 --- a/aten/src/ATen/native/cuda/Activation.cpp +++ b/aten/src/ATen/native/cuda/Activation.cpp @@ -110,8 +110,13 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { TORCH_CHECK(weight.is_contiguous()); int64_t weight_num = weight.numel(); + int64_t weight_dim = weight.dim(); Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + TORCH_CHECK(weight_dim == 0 || weight_dim == 1, + "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", + weight_dim); + // case1: shared weight for all channels if (weight_num == 1) { auto iter = TensorIterator::unary_op(result, input); diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu deleted file mode 100644 index 33f0a2950a651..0000000000000 --- a/aten/src/ATen/native/cuda/Activation.cu +++ /dev/null @@ -1,710 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#define _USE_MATH_DEFINES - -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at { -namespace native { - -// ----------------------------------- -// glu forward -// ----------------------------------- -void glu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { - using opmath_t = at::opmath_type; - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t { - const opmath_t a = a_; - const opmath_t b = b_; - const opmath_t one = opmath_t(1); - const opmath_t sigmoid = one / (one + std::exp(-b)); - return a * sigmoid; - }); - }); -} - -// ----------------------------------- -// glu forward ad -// ----------------------------------- -void glu_jvp_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { - using opmath_t = at::opmath_type; - gpu_kernel(iter, [] GPU_LAMBDA ( - scalar_t res_, - scalar_t b_, - scalar_t da_, - scalar_t db_) -> scalar_t { - const opmath_t res = res_; - const opmath_t b = b_; - const opmath_t da = da_; - const opmath_t db = db_; - const opmath_t one = opmath_t(1); - - const opmath_t sig_b = one / (one + std::exp(-b)); - return ( - da * sig_b + res * (db - sig_b * db) - ); - }); - }); -} - -// ----------------------------------- -// glu backward -// ----------------------------------- - -// Byte offsets don't require multiplication by sizeof(T), so are slightly cheaper. -// For fixed offsets, this removes all penalty from 64-bit indexing. -template -__device__ T* byte_offset(T* ptr, int64_t offset) { - using byte_ptr_t = typename std::conditional< - std::is_const::value, const char*, char*>::type; - return reinterpret_cast( - reinterpret_cast(ptr) + offset - ); -} - -template -__global__ void glu_backward_kernel( - int numel, scalar_t* gI, const scalar_t* I, const scalar_t* gO, - OffsetCalc offset_calculator, - int64_t gI_byte_offset, int64_t I_byte_offset) { - using opmath_t = at::opmath_type; - - const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; - if (linear_index >= numel) { - return; - } - const auto offsets = offset_calculator.get(linear_index); - - // We explicitly iterate over the first half of the input tensor, and - // gI_byte_offset and I_byte_offset are the offsets to access the - // corresponding index in the second half of the tensor. - const opmath_t a = I[offsets[1]]; - const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset); - const opmath_t gO_val = gO[offsets[2]]; - - const auto one = opmath_t(1); - const opmath_t sigmoid = one / (one + std::exp(-b)); - - auto* gA = gI + offsets[0]; - *gA = sigmoid * gO_val; - - auto* gB = byte_offset(gA, gI_byte_offset); - *gB = (one - sigmoid) * sigmoid * gO_val * a; -} - -void launch_glu_backward_kernel(const TensorIteratorBase& iter, - int64_t gI_stride, int64_t I_stride) { - const auto N = iter.numel(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(N > 0 && N <= std::numeric_limits::max()); - const auto offset_calculator = make_element_offset_calculator<3>(iter); - constexpr int64_t block_size = 256; - const int64_t grid = (N + block_size - 1) / block_size; - const auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] { - auto gI = static_cast(iter.data_ptr(0)); - auto I = static_cast(iter.data_ptr(1)); - auto gO = static_cast(iter.data_ptr(2)); - glu_backward_kernel<<>>( - N, gI, I, gO, offset_calculator, - gI_stride * sizeof(scalar_t), I_stride * sizeof(scalar_t)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// log_sigmoid forward -// ----------------------------------- - -void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(), - "log_sigmoid_forward_cuda", [&] { - using opmath_t = at::opmath_type; - - gpu_kernel(iter, - [] GPU_LAMBDA (scalar_t in_) -> scalar_t { - const opmath_t in = in_; - const auto min = std::min(opmath_t(0), in); - const auto z = std::exp(-std::abs(in)); - return min - std::log1p(z); - }); - }); -} - -// ----------------------------------- -// log_sigmoid backward -// ----------------------------------- - -void log_sigmoid_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(), - "log_sigmoid_backward_cuda", [&] { - using opmath_t = at::opmath_type; - gpu_kernel(iter, - [] GPU_LAMBDA (scalar_t in_, scalar_t grad_out_) -> scalar_t { - const opmath_t in = in_; - const opmath_t grad_out = grad_out_; - - auto in_negative = in < opmath_t(0); - auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); - auto sign = in_negative ? opmath_t(1) : -opmath_t(1); - const auto z = std::exp(-std::abs(in)); - return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); - }); - }); -} - -// ----------------------------------- -// prelu forward -// ----------------------------------- -void launch_prelu_cuda_kernel_share_weights(TensorIteratorBase &iter, const TensorBase &weight) { - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - at::native::gpu_kernel(iter, - [weight_data] GPU_LAMBDA (scalar_t input_val) { - return (input_val > 0) ? input_val : *weight_data * input_val; - }); - }); -} - -template -__global__ void prelu_cuda_kernel_multi_weights( - scalar_t* result_data, - const scalar_t* input_data, - const scalar_t* weight_data, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - - // multiply values at each channel with weight[channel_index] - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val; -} - -void launch_prelu_cuda_kernel_multi_weights( - const TensorBase &result, const TensorBase &input, const TensorBase &weight) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] { - prelu_cuda_kernel_multi_weights - <<>>( - result.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// prelu backward -// ----------------------------------- -void launch_prelu_cuda_backward_kernel_share_weights( - TensorIteratorBase &iter, const TensorBase &weight) { - // N.B. `std::tuple` does not support `::operator=` on device code. - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_backward_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple { - scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out; - scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out; - return {input_grad, weight_grad_collector}; - }); - }); -} - -template -__global__ void prelu_cuda_backward_kernel_multi_weights( - const scalar_t* input_data, - const scalar_t* weight_data, - const scalar_t* grad_out_data, - scalar_t* input_grad_data, - scalar_t* weight_grad_collector, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - scalar_t grad_out_data_val = grad_out_data[linearId]; - input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val; - weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val; -} - -void launch_prelu_cuda_backward_kernel_multi_weights( - const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, - const TensorBase &input_grad, const TensorBase &weight_grad_collector) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] { - prelu_cuda_backward_kernel_multi_weights - <<>>( - input.data_ptr(), - weight.data_ptr(), - grad_out.data_ptr(), - input_grad.data_ptr(), - weight_grad_collector.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// hardshrink -// ----------------------------------- -void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardshrink_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t { - return (a >= -lambd && a <= lambd) ? scalar_t(0) : a; - }); - }); -} - -void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t { - return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); - }); - }); -} - -void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "shrink_backward_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val; - }); - }); -} - -void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) { - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, - iter.dtype(), "hardtanh_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto min_val = min.to(); - auto max_val = max.to(); - gpu_kernel(iter, [min_val, max_val]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop; - }); - }); -} - -void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "softplus_cuda", [&]() { - using opmath_t = at::opmath_type; - auto beta = beta_.to(); - auto threshold = threshold_.to(); - gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return (aop * beta) > threshold ? aop : (::log1p(std::exp(aop * beta))) / beta; - }); - }); -} - -void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "softplus_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto beta = beta_.to(); - auto threshold = threshold_.to(); - gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - opmath_t z = std::exp(bop * beta); - return (bop * beta) > threshold ? aop : aop * z / (z + opmath_t(1.)); - }); - }); -} - -template -void threshold_kernel_impl(TensorIteratorBase& iter, scalar_t threshold, scalar_t value) { - gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { - return x <= threshold ? value : other; - }); -} - -static void threshold_kernel_cuda(TensorIteratorBase& iter, const Scalar& threshold, const Scalar& value) { - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] { - threshold_kernel_impl(iter, threshold.to(), value.to()); - }); -} - -void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negcoef = alpha.to() * scale.to(); - auto poscoef = scale.to(); - auto negiptcoef = input_scale.to(); - gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return aop > 0 ? aop * poscoef : std::expm1(aop * negiptcoef) * negcoef; - }); - }); -} - -void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negcoef = alpha.to() * scale.to(); - auto poscoef = scale.to(); - auto negiptcoef = input_scale.to(); - gpu_kernel(iter, [negcoef, poscoef, negiptcoef, is_result]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - - if (is_result) { - return bop <= 0 ? aop * negiptcoef * (bop + negcoef) : aop * poscoef; - } else { - return bop <= 0 ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) : aop * poscoef; - } - }); - }); -} - -void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { - if (approximate == GeluType::Tanh) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); - constexpr opmath_t kKappa = 0.044715; - auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); - auto inner = kBeta * (static_cast(x) + kKappa * x_cube); - return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner)); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kAlpha = M_SQRT1_2; - return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); - }); - }); - } -} - -void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { - if (approximate == GeluType::Tanh) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); - constexpr opmath_t kKappa = 0.044715; - auto x_sq = static_cast(x) * static_cast(x); - auto x_cube = x_sq * static_cast(x); - auto inner = kBeta * (static_cast(x) + kKappa * x_cube); - auto tanh_inner = c10::cuda::compat::tanh(inner); - - auto left = opmath_t(0.5) * static_cast(x); - auto right = opmath_t(1) + tanh_inner; - - auto left_derivative = 0.5 * right; - - auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; - auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); - auto right_derivative = left * tanh_derivative * inner_derivative; - - return static_cast(dy) * (left_derivative + right_derivative); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); - constexpr opmath_t kAlpha = M_SQRT1_2; - const opmath_t cdf = - opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); - const opmath_t pdf = - c10::cuda::compat::exp( - opmath_t(-0.5) * static_cast(x) * static_cast(x)) * - kBeta; - return static_cast(dy) * (cdf + static_cast(x) * pdf); - }); - }); - } -} - -namespace { - -void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "leaky_relu_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negval = negval_.to(); - gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return aop > opmath_t(0) ? aop : aop * negval; - }); - }); -} - -void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "leaky_relu_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negval = negval_.to(); - gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - return aop > opmath_t(0) ? bop : bop * negval; - }); - }); -} - -void hardswish_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t one_sixth(1.0f / 6.0f); - const opmath_t three(3.0f); - const opmath_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - opmath_t x = static_cast(self_val); - return x * std::min(std::max(x + three, zero), six) * one_sixth; - }); - }); -} - -void hardswish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t three(3.0f); - const opmath_t neg_three(-3.0f); - const opmath_t one_half(0.5f); - gpu_kernel( - iter, - [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { - opmath_t grad_val = static_cast(grad_val_); - opmath_t self_val = static_cast(self_val_); - if (self_val < neg_three) { - return zero; - } else if (self_val <= three) { - return grad_val * ((self_val / three) + one_half); - } else { - return grad_val; - } - }); - }); -} - -void hardsigmoid_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "hardsigmoid_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t one_sixth(1.0f / 6.0f); - const opmath_t three(3.0f); - const opmath_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - opmath_t x = static_cast(self_val); - return std::min(std::max(x + three, zero), six) * one_sixth; - }); - }); -} - -void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "hardsigmoid_backward_cuda", - [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t three(3.0f); - const opmath_t neg_three(-3.0f); - const opmath_t one_sixth(1.0f / 6.0f); - gpu_kernel( - iter, - [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { - opmath_t grad_val = static_cast(grad_val_); - opmath_t self_val = static_cast(self_val_); - return (self_val > neg_three && self_val < three) - ? grad_val * one_sixth - : zero; - }); - }); -} - -void silu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "silu_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t x_acc = static_cast(x); - return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - }); - }); -} - -void silu_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "silu_backward_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t dy_acc = static_cast(dy); - const opmath_t x_acc = static_cast(x); - const opmath_t s_acc = - opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); - }); - }); -} - -void mish_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "mish_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t x_acc = static_cast(x); - return x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); - }); - }); -} - -void mish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "mish_backward_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t dy_acc = static_cast(dy); - const opmath_t x_acc = static_cast(x); - const opmath_t s_acc = - opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - const opmath_t t_acc = - c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); - return dy_acc * (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); - }); - }); -} - -} // namespace - -REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); -REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel); -REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); -REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel); -REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); -REGISTER_DISPATCH(elu_stub, &elu_kernel); -REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel); -REGISTER_DISPATCH(glu_stub, &glu_kernel); -REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel); -REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); -REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); -REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel); -REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); -REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); -REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); -REGISTER_DISPATCH(softplus_stub, &softplus_kernel); -REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); -REGISTER_DISPATCH(silu_stub, &silu_kernel); -REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); -REGISTER_DISPATCH(mish_stub, &mish_kernel); -REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); -REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda); - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu new file mode 100644 index 0000000000000..113e6da10eacd --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -0,0 +1,88 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void elu_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > 0 ? aop * poscoef + : std::expm1(aop * negiptcoef) * negcoef; + }); + }); +} + +void elu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef, is_result] GPU_LAMBDA( + scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + + if (is_result) { + return bop <= 0 ? aop * negiptcoef * (bop + negcoef) + : aop * poscoef; + } else { + return bop <= 0 + ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) + : aop * poscoef; + } + }); + }); +} +} // namespace + +REGISTER_DISPATCH(elu_stub, &elu_kernel); +REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu new file mode 100644 index 0000000000000..d3d7879d3b884 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -0,0 +1,90 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner)); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kAlpha = M_SQRT1_2; + return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + }); + }); + } +} + +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_sq = static_cast(x) * static_cast(x); + auto x_cube = x_sq * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + auto tanh_inner = c10::cuda::compat::tanh(inner); + + auto left = opmath_t(0.5) * static_cast(x); + auto right = opmath_t(1) + tanh_inner; + + auto left_derivative = 0.5 * right; + + auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; + auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); + auto right_derivative = left * tanh_derivative * inner_derivative; + + return static_cast(dy) * (left_derivative + right_derivative); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); + constexpr opmath_t kAlpha = M_SQRT1_2; + const opmath_t cdf = + opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + const opmath_t pdf = + c10::cuda::compat::exp( + opmath_t(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); + }); + }); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu new file mode 100644 index 0000000000000..740edbbf38ee2 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -0,0 +1,143 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// glu forward +// ----------------------------------- +void glu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t { + const opmath_t a = a_; + const opmath_t b = b_; + const opmath_t one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + return a * sigmoid; + }); + }); +} + +// ----------------------------------- +// glu forward ad +// ----------------------------------- +void glu_jvp_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, + [] GPU_LAMBDA( + scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_) + -> scalar_t { + const opmath_t res = res_; + const opmath_t b = b_; + const opmath_t da = da_; + const opmath_t db = db_; + const opmath_t one = opmath_t(1); + + const opmath_t sig_b = one / (one + std::exp(-b)); + return (da * sig_b + res * (db - sig_b * db)); + }); + }); +} + +// ----------------------------------- +// glu backward +// ----------------------------------- + +// Byte offsets don't require multiplication by sizeof(T), so are slightly +// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing. +template +__device__ T* byte_offset(T* ptr, int64_t offset) { + using byte_ptr_t = typename std:: + conditional::value, const char*, char*>::type; + return reinterpret_cast(reinterpret_cast(ptr) + offset); +} + +template +__global__ void glu_backward_kernel( + int numel, + scalar_t* gI, + const scalar_t* I, + const scalar_t* gO, + OffsetCalc offset_calculator, + int64_t gI_byte_offset, + int64_t I_byte_offset) { + using opmath_t = at::opmath_type; + + const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= numel) { + return; + } + const auto offsets = offset_calculator.get(linear_index); + + // We explicitly iterate over the first half of the input tensor, and + // gI_byte_offset and I_byte_offset are the offsets to access the + // corresponding index in the second half of the tensor. + const opmath_t a = I[offsets[1]]; + const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset); + const opmath_t gO_val = gO[offsets[2]]; + + const auto one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + + auto* gA = gI + offsets[0]; + *gA = sigmoid * gO_val; + + auto* gB = byte_offset(gA, gI_byte_offset); + *gB = (one - sigmoid) * sigmoid * gO_val * a; +} + +void launch_glu_backward_kernel( + const TensorIteratorBase& iter, + int64_t gI_stride, + int64_t I_stride) { + const auto N = iter.numel(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + N > 0 && N <= std::numeric_limits::max()); + const auto offset_calculator = make_element_offset_calculator<3>(iter); + constexpr int64_t block_size = 256; + const int64_t grid = (N + block_size - 1) / block_size; + const auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] { + auto gI = static_cast(iter.data_ptr(0)); + auto I = static_cast(iter.data_ptr(1)); + auto gO = static_cast(iter.data_ptr(2)); + glu_backward_kernel<<>>( + N, + gI, + I, + gO, + offset_calculator, + gI_stride * sizeof(scalar_t), + I_stride * sizeof(scalar_t)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +REGISTER_DISPATCH(glu_stub, &glu_kernel); +REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu new file mode 100644 index 0000000000000..ae2f6b11b8523 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -0,0 +1,41 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardshrink_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return (a >= -lambd && a <= lambd) ? scalar_t(0) : a; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu new file mode 100644 index 0000000000000..ceafa53b72f1c --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -0,0 +1,76 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardsigmoid_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_cuda", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel( + iter, + [zero, one_sixth, three, six] GPU_LAMBDA( + scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_sixth(1.0f / 6.0f); + gpu_kernel( + iter, + [zero, three, neg_three, one_sixth] GPU_LAMBDA( + scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + return (self_val > neg_three && self_val < three) + ? grad_val * one_sixth + : zero; + }); + }); +} + +} // namespace + +REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); +REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu new file mode 100644 index 0000000000000..7d952043ad872 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -0,0 +1,65 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardswish_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardswish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel); +REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu new file mode 100644 index 0000000000000..1ef3fdba2898f --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -0,0 +1,46 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardtanh_backward_kernel( + TensorIterator& iter, + const Scalar& min, + const Scalar& max) { + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::Half, iter.dtype(), "hardtanh_backward_cuda", [&]() { + using opmath_t = at::opmath_type; + auto min_val = min.to(); + auto max_val = max.to(); + gpu_kernel( + iter, + [min_val, max_val] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu new file mode 100644 index 0000000000000..c323aca1ca7fb --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel(iter, [negval] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > opmath_t(0) ? aop : aop * negval; + }); + }); +} + +void leaky_relu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel( + iter, [negval] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return aop > opmath_t(0) ? bop : bop * negval; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); +REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu new file mode 100644 index 0000000000000..131462467d3dd --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -0,0 +1,66 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] { + using opmath_t = at::opmath_type; + + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t { + const opmath_t in = in_; + const auto min = std::min(opmath_t(0), in); + const auto z = std::exp(-std::abs(in)); + return min - std::log1p(z); + }); + }); +} + +namespace { +// ----------------------------------- +// log_sigmoid backward +// ----------------------------------- +void log_sigmoid_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_cuda", [&] { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t { + const opmath_t in = in_; + const opmath_t grad_out = grad_out_; + + auto in_negative = in < opmath_t(0); + auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); + auto sign = in_negative ? opmath_t(1) : -opmath_t(1); + const auto z = std::exp(-std::abs(in)); + return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu new file mode 100644 index 0000000000000..70c058644f666 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -0,0 +1,66 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc * + c10::cuda::compat::tanh( + c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_backward_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + const opmath_t t_acc = c10::cuda::compat::tanh( + c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + return dy_acc * + (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(mish_stub, &mish_kernel); +REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationPreluKernel.cu b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu new file mode 100644 index 0000000000000..0d8f09714698e --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu @@ -0,0 +1,175 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// prelu forward +// ----------------------------------- +void launch_prelu_cuda_kernel_share_weights(TensorIteratorBase &iter, const TensorBase &weight) { + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_cuda", [&] { + const auto *weight_data = weight.data_ptr(); + at::native::gpu_kernel(iter, + [weight_data] GPU_LAMBDA (scalar_t input_val) { + return (input_val > 0) ? input_val : *weight_data * input_val; + }); + }); +} + +template +__global__ void prelu_cuda_kernel_multi_weights( + scalar_t* result_data, + const scalar_t* input_data, + const scalar_t* weight_data, + int64_t input_stride0, + int64_t input_stride1, + int64_t input_numel) { + + int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; + if (linearId >= input_numel) return; + + // multiply values at each channel with weight[channel_index] + int64_t channel = (linearId % input_stride0) / input_stride1; + scalar_t input_data_val = input_data[linearId]; + result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val; +} + +void launch_prelu_cuda_kernel_multi_weights( + const TensorBase &result, const TensorBase &input, const TensorBase &weight) { + int64_t input_ndim = input.dim(); + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + int64_t input_stride0 = 1, input_stride1 = 1; + + if (input_ndim > 1) { + channel_size = input.size(1); // channel is the 2nd dim of input + auto strides = input.strides(); + input_stride0 = strides[0]; + input_stride1 = strides[1]; + } + const int64_t weight_num = weight.numel(); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // config to run cuda kernel + int64_t input_numel = input.numel(); + const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); + dim3 grid; + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); + + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] { + prelu_cuda_kernel_multi_weights + <<>>( + result.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + input_stride0, + input_stride1, + input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +// ----------------------------------- +// prelu backward +// ----------------------------------- +void launch_prelu_cuda_backward_kernel_share_weights( + TensorIteratorBase &iter, const TensorBase &weight) { + // N.B. `std::tuple` does not support `::operator=` on device code. + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_backward_cuda", [&] { + const auto *weight_data = weight.data_ptr(); + gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple { + scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out; + scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out; + return {input_grad, weight_grad_collector}; + }); + }); +} + +template +__global__ void prelu_cuda_backward_kernel_multi_weights( + const scalar_t* input_data, + const scalar_t* weight_data, + const scalar_t* grad_out_data, + scalar_t* input_grad_data, + scalar_t* weight_grad_collector, + int64_t input_stride0, + int64_t input_stride1, + int64_t input_numel) { + + int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; + if (linearId >= input_numel) return; + int64_t channel = (linearId % input_stride0) / input_stride1; + scalar_t input_data_val = input_data[linearId]; + scalar_t grad_out_data_val = grad_out_data[linearId]; + input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val; + weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val; +} + +void launch_prelu_cuda_backward_kernel_multi_weights( + const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, + const TensorBase &input_grad, const TensorBase &weight_grad_collector) { + int64_t input_ndim = input.dim(); + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + int64_t input_stride0 = 1, input_stride1 = 1; + + if (input_ndim > 1) { + channel_size = input.size(1); // channel is the 2nd dim of input + auto strides = input.strides(); + input_stride0 = strides[0]; + input_stride1 = strides[1]; + } + const int64_t weight_num = weight.numel(); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // config to run cuda kernel + int64_t input_numel = input.numel(); + const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); + dim3 grid; + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); + + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] { + prelu_cuda_backward_kernel_multi_weights + <<>>( + input.data_ptr(), + weight.data_ptr(), + grad_out.data_ptr(), + input_grad.data_ptr(), + weight_grad_collector.data_ptr(), + input_stride0, + input_stride1, + input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu new file mode 100644 index 0000000000000..701b901e4f773 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void silu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + }); + }); +} + +void silu_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_backward_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(silu_stub, &silu_kernel); +REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu new file mode 100644 index 0000000000000..86c04221b24f0 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -0,0 +1,76 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return (aop * beta) > threshold + ? aop + : (::log1p(std::exp(aop * beta))) / beta; + }); + }); +} + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel( + iter, + [beta, threshold] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + opmath_t z = std::exp(bop * beta); + return (bop * beta) > threshold ? aop + : aop * z / (z + opmath_t(1.)); + }); + }); +} + +} // namespace + +REGISTER_DISPATCH(softplus_stub, &softplus_kernel); +REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu new file mode 100644 index 0000000000000..e21e3b94fac48 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -0,0 +1,60 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softshrink_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); + }); + }); +} + +void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "shrink_backward_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel( + iter, + [lambd] GPU_LAMBDA( + scalar_t grad_val, scalar_t self_val) -> scalar_t { + return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) + : grad_val; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel); +REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu new file mode 100644 index 0000000000000..86d8bbd528c8f --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -0,0 +1,54 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +void threshold_kernel_impl( + TensorIteratorBase& iter, + scalar_t threshold, + scalar_t value) { + gpu_kernel_with_scalars( + iter, [=] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { + return x <= threshold ? value : other; + }); +} + +static void threshold_kernel_cuda( + TensorIteratorBase& iter, + const Scalar& threshold, + const Scalar& value) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "threshold_cuda", + [&] { + threshold_kernel_impl( + iter, threshold.to(), value.to()); + }); +} + +} // namespace + +REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu new file mode 100644 index 0000000000000..13e9757b5f39d --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu @@ -0,0 +1,112 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +void div_floor_kernel_cuda(TensorIteratorBase& iter) { + // See NOTE: [Floor Division in Python] + const auto dtype = iter.common_dtype(); + if (dtype == kByte) { + // In the special case of unsigned integer division, floor division is + // equivalent to truncation division (since the signs of the divisor and + // dividend are always the same) + return div_trunc_kernel_cuda(iter); + } else if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + if (c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the + // remainder of the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + + return a / b; + }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { + using accscalar_t = at::acc_type; + auto b = iter.scalar_value(2); + if (C10_UNLIKELY(b == 0)) { + return div_true_kernel_cuda(iter); + } + + auto inv_b = accscalar_t(1.0) / b; + iter.remove_operand(2); + gpu_kernel(iter, [b, inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + auto mod = std::fmod(a, b); + auto div = (a - mod) * inv_b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::cuda::compat::copysign(scalar_t(0), a * inv_b); + } + return floordiv; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + if (C10_UNLIKELY(b == 0)) { + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::cuda::compat::copysign(scalar_t(0), a / b); + } + return floordiv; + }); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_floor_stub, &binary_internal::div_floor_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu new file mode 100644 index 0000000000000..642318d2239fb --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +const char div_name[] = "div_kernel"; +void div_true_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto div_string = jiterator_stringify( + template T div_kernel(T a, T b) { return a / b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, div_string); +#else + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, DivFunctor()); +#endif + return; + } + if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { + using opmath_t = at::opmath_type; + auto inv_b = opmath_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel( + iter, + BUnaryFunctor>( + MulFunctor(), inv_b)); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { + DivFunctor f; + gpu_kernel_with_scalars(iter, f); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu new file mode 100644 index 0000000000000..01a04b40cbc1a --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +void div_trunc_kernel_cuda(TensorIteratorBase& iter) { + auto dtype = iter.common_dtype(); + if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() { + gpu_kernel_with_scalars( + iter, + [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { + using accscalar_t = at::acc_type; + auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::trunc(a * inv_b); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return std::trunc(a / b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryInternal.h b/aten/src/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000..e098d32b114d6 --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,48 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. +#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace binary_internal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu deleted file mode 100644 index f3998bf9c2cd9..0000000000000 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ /dev/null @@ -1,222 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. - -namespace at { namespace native { - -template -struct DivFunctor { - __device__ scalar_t operator() (scalar_t a, scalar_t b) const { - return a / b; - } -}; - -template -struct MulFunctor { - __device__ T operator() (T a, T b) const { - return a * b; - } -}; - -// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] -template<> -struct MulFunctor { - __device__ bool operator() (bool a, bool b) const { - return a && b; - } -}; - -const char div_name[] = "div_kernel"; -void div_true_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (iter.common_dtype() == kComplexHalf) { - using scalar_t = c10::complex; - #if AT_USE_JITERATOR() - static const auto div_string = jiterator_stringify( - template - T div_kernel(T a, T b) { - return a / b; - } - ); - opmath_jitted_gpu_kernel_with_scalars(iter, div_string); - #else - using opmath_t = at::opmath_type; - opmath_gpu_kernel_with_scalars(iter, DivFunctor()); - #endif - return; - } - if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { - using opmath_t = at::opmath_type; - auto inv_b = opmath_t(1.0) / iter.scalar_value(2); - iter.remove_operand(2); - gpu_kernel(iter, BUnaryFunctor>( - MulFunctor(), inv_b)); - }); - } else { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { - DivFunctor f; - gpu_kernel_with_scalars(iter, f); - }); - } -} - -void div_trunc_kernel_cuda(TensorIteratorBase& iter) { - auto dtype = iter.common_dtype(); - if (isIntegralType(dtype, /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - return a / b; - }); - }); - } else if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { - using accscalar_t = at::acc_type; - auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); - iter.remove_operand(2); - gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t { - return std::trunc(a * inv_b); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - return std::trunc(a / b); - }); - }); - } -} - -void div_floor_kernel_cuda(TensorIteratorBase& iter) { - // See NOTE: [Floor Division in Python] - const auto dtype = iter.common_dtype(); - if (dtype == kByte) { - // In the special case of unsigned integer division, floor division is - // equivalent to truncation division (since the signs of the divisor and - // dividend are always the same) - return div_trunc_kernel_cuda(iter); - } else if (isIntegralType(dtype, /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (c10::signs_differ(a, b)) { - // Subtracts one from the results of truncation division if the - // divisor and dividend have different sign(bit)s and the remainder of - // the division is nonzero - const auto quot = a / b; - const auto rem = a % b; - return rem ? quot - 1 : quot; - } - - return a / b; - }); - }); - } else if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { - using accscalar_t = at::acc_type; - auto b = iter.scalar_value(2); - if (C10_UNLIKELY(b == 0)) { - return div_true_kernel_cuda(iter); - } - - auto inv_b = accscalar_t(1.0) / b; - iter.remove_operand(2); - gpu_kernel(iter, [b, inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t { - auto mod = std::fmod(a, b); - auto div = (a - mod) * inv_b; - if ((mod != 0) && (b < 0) != (mod < 0)) { - div -= scalar_t(1); - } - - scalar_t floordiv; - if (div != 0) { - floordiv = std::floor(div); - if (div - floordiv > scalar_t(0.5)) { - floordiv += scalar_t(1.0); - } - } else { - floordiv = c10::cuda::compat::copysign(scalar_t(0), a * inv_b); - } - return floordiv; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (C10_UNLIKELY(b == 0)) { - return a / b; - } - - auto mod = std::fmod(a, b); - auto div = (a - mod) / b; - if ((mod != 0) && (b < 0) != (mod < 0)) { - div -= scalar_t(1); - } - - scalar_t floordiv; - if (div != 0) { - floordiv = std::floor(div); - if (div - floordiv > scalar_t(0.5)) { - floordiv += scalar_t(1.0); - } - } else { - floordiv = c10::cuda::compat::copysign(scalar_t(0), a / b); - } - return floordiv; - }); - }); - } -} - -const char mul_name[] = "mul_kernel"; -void mul_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (common_dtype == kComplexHalf) { - using scalar_t = c10::complex; - #if AT_USE_JITERATOR() - static const auto mul_string = jiterator_stringify( - template - T mul_kernel(T a, T b) { - return a * b; - } - ); - opmath_jitted_gpu_kernel_with_scalars(iter, mul_string); - #else - using opmath_t = at::opmath_type; - opmath_symmetric_gpu_kernel_with_scalars(iter, MulFunctor()); - #endif - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { - using opmath_t = at::opmath_type; - opmath_symmetric_gpu_kernel_with_scalars(iter, MulFunctor()); - }); - } -} - -REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda); -REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel_cuda); -REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel_cuda); -REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryMulKernel.cu b/aten/src/ATen/native/cuda/BinaryMulKernel.cu new file mode 100644 index 0000000000000..b0b4f4886ab85 --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryMulKernel.cu @@ -0,0 +1,50 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { +namespace native { + +const char mul_name[] = "mul_kernel"; +void mul_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto mul_string = jiterator_stringify( + template T mul_kernel(T a, T b) { return a * b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, mul_string); +#else + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); + }); + } +} + +REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh index a235d005d3e42..830a3024a9839 100644 --- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh +++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh @@ -17,7 +17,10 @@ #include #include +#include +#include +#include #include #include #include @@ -25,8 +28,6 @@ namespace at { namespace native { -namespace { - template constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence seq) { constexpr auto size = seq.size(); @@ -45,111 +46,90 @@ constexpr auto tuple_to_array(std::tuple& extra_args) { return tuple_to_array_helper(extra_args, std::make_index_sequence{}); } -// Helper function to return a vector -// corresponding to the type of the arguments in parameter pack. -template -c10::SmallVector get_extra_args_typenames() { - return {at::cuda::jit::typeName()...}; +struct JittedVecKernelCache { + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + at::cuda::jit::NvrtcFunction vec1; + at::cuda::jit::NvrtcFunction vec2; + at::cuda::jit::NvrtcFunction vec4; +}; + +struct JittedKernelVariantCache { + JittedVecKernelCache vec; + at::cuda::jit::NvrtcFunction noncontiguous; + at::cuda::jit::NvrtcFunction dynamic_contiguous; + at::cuda::jit::NvrtcFunction dynamic_noncontiguous; +}; + +inline c10::SmallBuffer pack_kernel_args( + std::initializer_list args, + c10::ArrayRef extra_args) { + c10::SmallBuffer ret(args.size() + extra_args.size()); + std::copy(args.begin(), args.end(), ret.data()); + std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size()); + return ret; } -} // namespace - -template -static inline void launch_jitted_unrolled_kernel( - DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data, - inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous, - at::opmath_type scalar_val, - std::tuple extra_args) { + typename storer_t> +void launch_jitted_unrolled_kernel( + std::mutex &jiterator_mutex, + at::cuda::jit::NvrtcFunction &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int64_t N, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s, + bool contiguous, + at::cuda::jit::BinaryFuncVariant scalar_pos, + void* scalar_val, + c10::ArrayRef extra_args) { TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); //casting result to int is always safe, intermediate is int64 and won't overflow const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); - static std::mutex _jiterator_mutex; - static std::vector fns(c10::cuda::device_count()); - - at::cuda::jit::NvrtcFunction* fn_ptr = &fns[dev_idx]; - if (!fn_ptr->function) { - const std::lock_guard lock{_jiterator_mutex}; - if (!fn_ptr->function) { - constexpr int nInputs = array_t::size() - 1; - constexpr int nOutputs = 1; // fix me + if (!fn_cache.function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_cache.function) { constexpr bool dynamic_casting = !std::is_same() || !std::is_same(); - std::string string_name{name}; - std::string f_inputs_type_str = at::cuda::jit::typeName(); - std::string compute_type_str = at::cuda::jit::typeName>(); - std::string result_type_str = at::cuda::jit::typeName(); - c10::SmallVector extra_args_types = get_extra_args_typenames(); - auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, string_name, - f_inputs_type_str, compute_type_str, result_type_str, - contiguous, dynamic_casting, scalar_pos, extra_args_types); - *fn_ptr = at::cuda::jit::jit_pwise_function(code, name); + auto code = at::cuda::jit::generate_code( + desc, contiguous, dynamic_casting, scalar_pos); + fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name); } } - // pack args for kernel launch - constexpr int kernel_args = 7; - // size of `extra_args` is known at compile-time - constexpr auto extra_args_size = sizeof...(Args); - void* args[kernel_args + extra_args_size]; - args[0] = static_cast(&N); - args[1] = static_cast(&data); - args[2] = static_cast(&ic); - args[3] = static_cast(&oc); - args[4] = static_cast(&l); - args[5] = static_cast(&s); - args[6] = static_cast(&scalar_val); - - auto extra_args_array = tuple_to_array(extra_args); - for (const auto i : c10::irange(extra_args_size)) { - // since 7 slots are already filled in `args` - args[i + 7] = extra_args_array[i]; - } - at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u}, + auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); } -template< - char const *name, - typename result_type, - typename f_inputs_type, - int arity, - at::cuda::jit::BinaryFuncVariant scalar_pos, - typename array_t, typename ... Args> -static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data, -at::opmath_type scalar_val, std::tuple extra_args) { +template +void launch_jitted_vectorized_kernel( + std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data, + at::cuda::jit::BinaryFuncVariant scalar_pos, + void *scalar_val, c10::ArrayRef extra_args) { TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); // N is still int64_t for the computation, but it's always safe to cast result to int const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); - const int vec_size = memory::jitted_can_vectorize_up_to(data); + const int vec_size = at::cuda::jit::can_vectorize_up_to( + desc, c10::ArrayRef(data.data, data.size())); // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) // fn_ptr is set to the appropriate function based on the vec size and GPU used - // TODO: Memory use can probably be optimized by re-using kernels across GPUs with - // the same compute capability - static std::mutex _jiterator_mutex; - static std::vector fns4(c10::cuda::device_count()); - static std::vector fns2(c10::cuda::device_count()); - static std::vector fns1(c10::cuda::device_count()); - - at::cuda::jit::NvrtcFunction* fn_ptr; if (vec_size == 4) { - fn_ptr = &fns4[dev_idx]; + fn_ptr = &fn_cache.vec4; } else if (vec_size == 2) { - fn_ptr = &fns2[dev_idx]; + fn_ptr = &fn_cache.vec2; } else if (vec_size ==1) { - fn_ptr = &fns1[dev_idx]; + fn_ptr = &fn_cache.vec1; } else { TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel"); } @@ -157,94 +137,54 @@ at::opmath_type scalar_val, std::tuple extra_args) { bool vectorized = vec_size > 1; if (!fn_ptr->function) { - const std::lock_guard lock{_jiterator_mutex}; + const std::lock_guard lock{jiterator_mutex}; if (!fn_ptr->function) { // cache miss! // Generates program - constexpr int nInputs = array_t::size() - 1; - constexpr int nOutputs = 1; // fix me - std::string string_name{name}; - std::string f_inputs_type_str = at::cuda::jit::typeName(); - std::string compute_type_str = at::cuda::jit::typeName>(); - std::string result_type_str = at::cuda::jit::typeName(); - c10::SmallVector extra_args_types = get_extra_args_typenames(); - auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, string_name, - f_inputs_type_str, compute_type_str, result_type_str, - /*contiguous=*/true, /*dynamic_casting=*/false, - scalar_pos, - extra_args_types, - vectorized, vec_size); - std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name; + auto code = at::cuda::jit::generate_code( + desc, /*contiguous=*/true, /*dynamic_casting=*/false, + scalar_pos, vectorized, vec_size); + std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name; // Acquires the program *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name); } } - // size of `extra_args` is known at compile-time - constexpr auto extra_args_size = sizeof...(Args); - auto extra_args_array = tuple_to_array(extra_args); - if (vectorized) { - // pack args for kernel launch - constexpr int kernel_args = 3; - void* args[kernel_args + extra_args_size]; - args[0] = static_cast(&N); - args[1] = static_cast(&data); - args[2] = static_cast(&scalar_val); - - for (const auto i : c10::irange(extra_args_size)) { - // since 3 slots are already filled in `args` - args[i + 3] = extra_args_array[i]; - } - at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); } else { auto ic = TrivialOffsetCalculator(); auto oc = TrivialOffsetCalculator<1>(); auto l = memory::LoadWithoutCast(); auto s = memory::StoreWithoutCast(); - // pack args for kernel launch - constexpr int kernel_args = 7; - void* args[kernel_args + extra_args_size]; - args[0] = static_cast(&N); - args[1] = static_cast(&data); - args[2] = static_cast(&ic); - args[3] = static_cast(&oc); - args[4] = static_cast(&l); - args[5] = static_cast(&s); - args[6] = static_cast(&scalar_val); - - for (const auto i : c10::irange(extra_args_size)) { - // since 7 slots are already filled in `args` - args[i + 7] = extra_args_array[i]; - } - - at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + auto args = pack_kernel_args( + {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); } } -template < - char const* name, - typename result_type, - typename f_inputs_type, - int arity, - at::cuda::jit::BinaryFuncVariant scalar_pos = - at::cuda::jit::BinaryFuncVariant::NoScalar, - typename... Args> -void jitted_gpu_kernel_impl( +template +void jitted_gpu_kernel_generic( + std::mutex &jiterator_mutex, + JittedKernelVariantCache &cache, + const at::cuda::jit::KernelDescriptor &desc, + at::cuda::jit::BinaryFuncVariant scalar_pos, + c10::ArrayRef extra_args, TensorIteratorBase& iter, - const std::string& f, const bool dynamic_casting, - at::opmath_type scalar_val, - std::tuple extra_args) { + void *scalar_val) { TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); TORCH_INTERNAL_ASSERT(iter.ninputs() == arity); TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); constexpr int ntensors = arity + 1; at::detail::Array data; - for (auto i = decltype(ntensors){0}; i < ntensors; ++i) { + for (auto i : c10::irange(ntensors)) { data[i] = (char*)iter.data_ptr(i); } @@ -262,8 +202,9 @@ void jitted_gpu_kernel_impl( if (!dynamic_casting) { if (contiguous) { // Case 1: no dynamic casting and contiguous - launch_jitted_vectorized_kernel( - iter.device().index(), numel, f, data, scalar_val, extra_args); + launch_jitted_vectorized_kernel( + jiterator_mutex, cache.vec, desc, + numel, data, scalar_pos, scalar_val, extra_args); return; } @@ -272,9 +213,10 @@ void jitted_gpu_kernel_impl( auto output_offset_calculator = make_output_offset_calculator(iter); auto loader = memory::LoadWithoutCast(); auto storer = memory::StoreWithoutCast(); - launch_jitted_unrolled_kernel( - iter.device().index(), numel, f, data, input_offset_calculator, - output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.noncontiguous, desc, numel, data, + input_offset_calculator, output_offset_calculator, loader, + storer, contiguous, scalar_pos, scalar_val, extra_args); return; } @@ -291,18 +233,58 @@ void jitted_gpu_kernel_impl( // Case 3: dynamic casting and contiguous auto input_offset_calculator = TrivialOffsetCalculator(); auto output_offset_calculator = TrivialOffsetCalculator<1>(); - launch_jitted_unrolled_kernel( - iter.device().index(), numel, f, data, input_offset_calculator, - output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); return; } // Case 4: dynamic casting and noncontiguous auto input_offset_calculator = make_input_offset_calculator(iter); auto output_offset_calculator = make_output_offset_calculator(iter); - launch_jitted_unrolled_kernel( - iter.device().index(), numel, f, data, input_offset_calculator, - output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); +} + +// NOTE: static to reduce chances of name collision. +template < + char const* name, + typename result_type, + typename f_inputs_type, + int arity, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + typename... ExtraArgs> +static void jitted_gpu_kernel_impl( + TensorIteratorBase& iter, + const std::string &f, + const bool dynamic_casting, + at::opmath_type scalar_val, + std::tuple extra_args) { + + // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // the same compute capability + static std::mutex jiterator_mutex; + static std::vector device_caches(c10::cuda::device_count()); + + constexpr int nInputs = arity; + constexpr int nOutputs = 1; // TODO: Support more than 1 output + static const auto desc = at::cuda::jit::make_kernel_descriptor< + result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs); + + auto &cache = device_caches[iter.device().index()]; + auto extra_args_array = tuple_to_array(extra_args); + return jitted_gpu_kernel_generic( + jiterator_mutex, + cache, + desc, + scalar_pos, + extra_args_array, + iter, + dynamic_casting, + &scalar_val + ); } }} // at::native diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 57f04d481fc5c..4fb647e329d3c 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -173,9 +173,10 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) { Tensor dst_contig; Tensor src_contig; - // Type conversions are performed on the CPU for CPU-GPU copies and on - // the src device for GPU-GPU copies. - if (iter.device_type(0) == kCUDA) { + // If non_blocking is true - type conversions are performed on the GPU + // for CPU-GPU copies, otherwise type conversions are performed on the CPU. + // Type conversions are performed on the src device for GPU-GPU copies. + if (iter.device_type(0) == kCUDA || non_blocking) { dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT); src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous(); } else { diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu index 12817d5f66ea0..c3ed197acb929 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu @@ -35,22 +35,40 @@ __device__ inline int min(int a, int b) { template __global__ static void max_pool3d_with_indices_single_out_frame( scalar_t* inputData, - PackedTensorAccessor64 output, - PackedTensorAccessor64 indices, + scalar_t* outputData, + int64_t* indicesData, + int features, int itime, int iheight, int iwidth, + int obatch, int otime, int oheight, int owidth, int kT, int kH, int kW, int dT, int dH, int dW, int pT, int pH, int pW, int dilationT, int dilationH, int dilationW, - int offsetZ) + int offsetZ, + bool channels_last) { int oColumn = blockIdx.x * blockDim.x + threadIdx.x; - int oRow = blockIdx.y * blockDim.y + threadIdx.y; - int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time - int64_t slice = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature - // For int64_t data type, see https://github.com/pytorch/pytorch/issues/52822 + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = 0; + // used only for channels-first indexing + int64_t slice = 0; + // used only for channels-last indexing + int batch = 0; + int channel = 0; + if (!channels_last) { + // indexing order: batch, channel, time + oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature + } else { + // indexing order: batch, time, channel + channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel) + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time) + batch = slice / otime; + oFrame = slice % otime; + } - if (oRow < output.size(2) && oColumn < output.size(3)) + // For int64_t data type, see https://github.com/pytorch/pytorch/issues/52822 + if (oRow < oheight && oColumn < owidth && oFrame < otime && channel < features && batch < obatch) { int tStart = oFrame * dT - pT; int hStart = oRow * dH - pH; @@ -66,8 +84,14 @@ __global__ static void max_pool3d_with_indices_single_out_frame( while(wStart < 0) wStart += dilationW; - int maxIndex = tStart * iheight * iwidth + hStart * iwidth + wStart; - inputData += slice * itime * iheight * iwidth; + // maxIndex remains in "channels-first"/contiguous + int64_t maxIndex = tStart * iheight * iwidth + hStart * iwidth + wStart; + + if (!channels_last) { + inputData += (int64_t) slice * itime * iheight * iwidth; + } else { + inputData += ((int64_t) batch * itime * iheight * iwidth * features) + channel; + } scalar_t max = at::numeric_limits::lower_bound(); // -Infinity @@ -77,8 +101,14 @@ __global__ static void max_pool3d_with_indices_single_out_frame( { for (int w = wStart; w < wEnd; w += dilationW) { + scalar_t val; int index = t * iheight * iwidth + h * iwidth + w; - scalar_t val = inputData[index]; + if (!channels_last) { + val = inputData[index]; + } else { + int64_t index_channels_last = index*features; + val = inputData[index_channels_last]; + } if ((max < val) || at::_isnan(val)) { @@ -89,8 +119,14 @@ __global__ static void max_pool3d_with_indices_single_out_frame( } } - output[slice][oFrame][oRow][oColumn] = max; - indices[slice][oFrame][oRow][oColumn] = maxIndex; + int64_t out_index; + if (!channels_last) { + out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn; + } else { + out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel; + } + outputData[out_index] = max; + indicesData[out_index] = maxIndex; } } @@ -99,37 +135,50 @@ void max_pool3d_with_indices_out_frame( scalar_t* input_data, const Tensor& output, const Tensor& indices, - int totalZ, + int features, + int64_t totalZ, int itime, int iheight, int iwidth, - int otime, int oheight, int owidth, + int obatch, int otime, int oheight, int owidth, int kT, int kH, int kW, int dT, int dH, int dW, int pT, int pH, int pW, - int dilationT, int dilationH, int dilationW) + int dilationT, int dilationH, int dilationW, + bool channels_last) { int offsetZ = 0; - dim3 block(32, 8); + int threadX = 32; + int threadY = 8; + int threadZ = 1; + int stepZ = 65535; + if (channels_last) { + threadX = 2; + threadY = 4; + threadZ = 64; + } + dim3 block(threadX, threadY, threadZ); while (totalZ > 0) { dim3 grid(ceil_div(owidth, static_cast(block.x)), ceil_div(oheight, static_cast(block.y)), - totalZ > 65535 ? 65535 : totalZ); + totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast(threadZ))); max_pool3d_with_indices_single_out_frame <<>>( input_data, - output.packed_accessor64(), - indices.packed_accessor64(), + output.data_ptr(), + indices.data_ptr(), + features, itime, iheight, iwidth, + obatch, otime, oheight, owidth, kT, kH, kW, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW, - offsetZ); + offsetZ, channels_last); C10_CUDA_KERNEL_LAUNCH_CHECK(); - totalZ -= 65535; - offsetZ += 65535; + totalZ -= threadZ*stepZ; + offsetZ += threadZ*stepZ; } } @@ -138,25 +187,52 @@ void max_pool3d_with_indices_out_frame( template __global__ static void max_pool3d_with_indices_backward_single_out_frame( scalar_t *gradInputData, - PackedTensorAccessor64 gradOutput, - PackedTensorAccessor64 indices, + scalar_t *gradOutputData, + int64_t *indicesData, + int features, int itime, int iheight, int iwidth, - int dT, int dH, int dW, - int pT, int pH, int pW, - int dilationT, int dilationH, - int offsetZ) + int obatch, int otime, int oheight, int owidth, + int offsetZ, + bool channels_last) { int oColumn = blockIdx.x * blockDim.x + threadIdx.x; - int oRow = blockIdx.y * blockDim.y + threadIdx.y; - int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // output frame/time - int slice = (blockIdx.z + offsetZ) / gradOutput.size(1); // output slice/feature + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + + int oFrame = 0; + // used only for channels-first indexing + int64_t slice = 0; + // used only for channels-last indexing + int batch = 0; + int channel = 0; + if (!channels_last) { + // indexing order: batch, channel, time + oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature + } else { + // indexing order: batch, time, channel + channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel) + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time) + batch = slice / otime; + oFrame = slice % otime; + } - if (oRow < gradOutput.size(2) && oColumn < gradOutput.size(3)) + if (oRow < oheight && oColumn < owidth && oFrame < otime && batch < obatch && channel < features) { - int maxIndex = indices[slice][oFrame][oRow][oColumn]; + int64_t out_index; + if (!channels_last) { + out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn; + } else { + out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel; + } + int64_t maxIndex = indicesData[out_index]; if (maxIndex != -1) { - gpuAtomicAddNoReturn(&gradInputData[slice * itime * iheight * iwidth + maxIndex], - gradOutput[slice][oFrame][oRow][oColumn]); + if (!channels_last) { + gpuAtomicAddNoReturn(&gradInputData[(int64_t) slice * itime * iheight * iwidth + maxIndex], + gradOutputData[out_index]); + } else { + gpuAtomicAddNoReturn(&gradInputData[((int64_t) batch * itime * iheight * iwidth + maxIndex) * features + channel], + gradOutputData[out_index]); + } } } } @@ -166,35 +242,43 @@ void max_pool3d_with_indices_backward_out_frame( scalar_t *gradInputData, const Tensor& gradOutput, const Tensor& indices, + int features, int64_t totalZ, int itime, int iheight, int iwidth, - int oheight, int owidth, - int dT, int dH, int dW, - int pT, int pH, int pW, - int dilationT, int dilationH) + int obatch, int otime, int oheight, int owidth, + bool channels_last) { int offsetZ = 0; - dim3 block(32, 8); + int threadX = 32; + int threadY = 8; + int threadZ = 1; + int stepZ = 65535; + if (channels_last) { + threadX = 2; + threadY = 4; + threadZ = 64; + } + dim3 block(threadX, threadY, threadZ); while (totalZ > 0) { dim3 grid(ceil_div(owidth, static_cast(block.x)), ceil_div(oheight, static_cast(block.y)), - totalZ > 65535 ? 65535 : totalZ); + totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast(block.z))); max_pool3d_with_indices_backward_single_out_frame <<>>( gradInputData, - gradOutput.packed_accessor64(), - indices.packed_accessor64(), + gradOutput.data_ptr(), + indices.data_ptr(), + features, itime, iheight, iwidth, - dT, dH, dW, - pT, pH, pW, - dilationT, dilationH, - offsetZ); + obatch, otime, oheight, owidth, + offsetZ, + channels_last); C10_CUDA_KERNEL_LAUNCH_CHECK(); - totalZ -= 65535; - offsetZ += 65535; + totalZ -= threadZ*stepZ; + offsetZ += threadZ*stepZ; } } @@ -263,45 +347,65 @@ void max_pool3d_with_indices_out_cuda_template( otime, oheight, owidth, "max_pool3d_with_indices_out_cuda_template()"); + bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d; + Tensor _input = input; if (input.ndimension() == 4) { - output.resize_({ nslices, otime, oheight, owidth}); - indices.resize_({nslices, otime, oheight, owidth}); - } - else { - output.resize_({nbatch, nslices, otime, oheight, owidth}); - indices.resize_({nbatch, nslices, otime, oheight, owidth}); + Tensor input_channels_last_check = input.unsqueeze(0); + // work around buggy behavior of suggest_memory_format here where + // suggested format of unsqueezed tensor is contiguous while it is + // really only contiguous in ChannelsLast3d + channels_last = (!input_channels_last_check.is_contiguous()) && + input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d); + if (!channels_last) { + output.resize_({ nslices, otime, oheight, owidth}); + indices.resize_({nslices, otime, oheight, owidth}); + } else { + _input = input_channels_last_check; + output.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + indices.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + output = output.squeeze(0); + indices = indices.squeeze(0); + } + } else { + if (!channels_last) { + output.resize_({nbatch, nslices, otime, oheight, owidth}); + indices.resize_({nbatch, nslices, otime, oheight, owidth}); + } else { + output.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + indices.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + } } if (input.numel() == 0) { return; } - Tensor work_input = input.contiguous(); + Tensor work_input; Tensor work_output = output; - Tensor work_indices = indices; - if (input.ndimension() == 5) { - // Collapse batch and feature dimensions. - work_input = work_input.reshape({nbatch * nslices, itime, iheight, iwidth}); - work_output = work_output.reshape({nbatch * nslices, otime, oheight, owidth}); - work_indices = work_indices.reshape({nbatch * nslices, otime, oheight, owidth}); + if (!channels_last) { + work_input = input.contiguous(); + } else { + work_input = _input.contiguous(at::MemoryFormat::ChannelsLast3d); } + Tensor work_indices = indices; AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_pool3d_with_indices_out_frame", [&]{ scalar_t *input_data = work_input.data_ptr(); - int64_t totalZ = otime * nslices * nbatch; + const int64_t totalZ = otime * nslices * nbatch; max_pool3d_with_indices_out_frame( input_data, work_output, work_indices, + nslices, // features totalZ, itime, iheight, iwidth, - otime, oheight, owidth, + nbatch, otime, oheight, owidth, kT, kH, kW, dT, dH, dW, pT, pH, pW, - dilationT, dilationH, dilationW); + dilationT, dilationH, dilationW, channels_last); } ); } @@ -361,7 +465,24 @@ void max_pool3d_with_indices_backward_out_cuda_template( "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes()); // Resize and initialize result tensor. - gradInput.resize_as_(input); + bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d; + Tensor _input = input; + if (input.ndimension() == 4) { + Tensor input_channels_last_check = input.unsqueeze(0); + // work around buggy behavior of suggest_memory_format here where + // suggested format of unsqueezed tensor is contiguous while it is + // really only contiguous in ChannelsLast3d + channels_last = (!input_channels_last_check.is_contiguous()) && + input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d); + if (channels_last) { + _input = input_channels_last_check; + } + } + if (!channels_last) { + gradInput.resize_as_(input); + } else { + gradInput.resize_as_(_input, at::MemoryFormat::ChannelsLast3d); + } gradInput.zero_(); const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1; @@ -393,14 +514,19 @@ void max_pool3d_with_indices_backward_out_cuda_template( } Tensor work_grad_input = gradInput; - Tensor work_grad_output = gradOutput.contiguous(); - Tensor work_indices = indices.contiguous(); - - if (input.ndimension() == 5) { - // Collapse batch and feature dimensions. - work_grad_input = work_grad_input.reshape({nbatch * nslices, itime, iheight, iwidth}); - work_grad_output = work_grad_output.reshape({nbatch * nslices, otime, oheight, owidth}); - work_indices = work_indices.reshape({nbatch * nslices, otime, oheight, owidth}); + Tensor work_grad_output; + Tensor work_indices; + if (!channels_last) { + work_grad_output = gradOutput.contiguous(); + work_indices = indices.contiguous(); + } else { + if (input.ndimension() == 4) { + work_grad_output = gradOutput.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d); + work_indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d); + } else { + work_grad_output = gradOutput.contiguous(at::MemoryFormat::ChannelsLast3d); + work_indices = indices.contiguous(at::MemoryFormat::ChannelsLast3d); + } } AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), @@ -411,12 +537,11 @@ void max_pool3d_with_indices_backward_out_cuda_template( max_pool3d_with_indices_backward_out_frame( grad_input_data, work_grad_output, work_indices, + nslices, totalZ, itime, iheight, iwidth, - oheight, owidth, - dT, dH, dW, - pT, pH, pW, - dilationT, dilationH); + nbatch, otime, oheight, owidth, + channels_last); } ); } @@ -512,7 +637,7 @@ Tensor max_pool3d_with_indices_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::zeros_like(input, input.suggest_memory_format()); max_pool3d_with_indices_backward_out_cuda_template( gradInput, gradOutput, diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 8a241cabcd2d3..568669e10fdfe 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -98,10 +98,9 @@ __global__ void embedding_backward_feature_kernel // then finishes by adding the accumulated buffer to dst_row in grad_weight. if(dst_row != padding_idx && src_row < n) // Per-warp exit condition, safe with ballot_sync { - int match_found_this_thread = - (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]); - if(threadIdx.x >= n_this_chunk) - match_found_this_thread = 0; + int match_found_this_thread = 0; + if(threadIdx.x < n_this_chunk) + match_found_this_thread = (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]); #if defined(USE_ROCM) unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread); int first_remaining_peer = __ffsll(matchmask) - 1; diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index c226bb6a1858d..dee39b40e91e1 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -347,7 +347,7 @@ void masked_scatter_cuda_impl( auto maskPrefixSum_data = maskPrefixSum.data_ptr(); auto mask_data = mask_cont.data_ptr(); - at::cuda::cub::exclusive_sum_in_common_type( + at::cuda::cub::mask_exclusive_sum( mask_data, maskPrefixSum_data, mask_numel); // Asynchronously check that the number of `1` elements present in the mask diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index e10ea18cfa414..4720da4bd1124 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -545,11 +545,13 @@ bool indexShouldBeMajor(cuda::detail::TensorInfo &info, } void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) { - if (!result.is_same(self)) result.copy_(self); + if (!result.is_same(self)) { + result.copy_(self); + } // Scalars are treated as 1-d tensor - Tensor self_ = (result.dim() == 0) ? result.view(1) : result; - Tensor source_ = (source.dim() == 0) ? source.view(1) : source; + const Tensor self_ = (result.dim() == 0) ? result.view(1) : result; + const Tensor source_ = (source.dim() == 0) ? source.view(1) : source; TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" ); @@ -571,19 +573,19 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c // total size of the tensor ignoring dimension `dim`; // -the number of index we are choosing, which is the total size // of the tensor `index`. - ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_); - ptrdiff_t sourceTotalSize = source.numel(); - int64_t selfAddDimSize = self_.size(dim); - ptrdiff_t numIndex = index.numel(); - int64_t selfNumel = self_.numel(); + const ptrdiff_t sliceSize = getSliceSize(self_, dim, index, source_); + const ptrdiff_t sourceTotalSize = source.numel(); + const int64_t selfAddDimSize = self_.size(dim); + const ptrdiff_t numIndex = index.numel(); + const int64_t selfNumel = self_.numel(); if (sliceSize == 0) { return; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - bool indContig = index.is_contiguous(); + const bool indContig = index.is_contiguous(); - int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ indexFuncSmallIndex \ @@ -604,25 +606,25 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c selfAddDimSize, selfNumel, reduce_add, alpha_value); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); - dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); - dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); + const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + const dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); - dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); - dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128)); + const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); + const dim3 largeIndexBlock(std::min(sourceTotalSize, (ptrdiff_t)128)); if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(source) && cuda::detail::canUse32BitIndexMath(index)) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_add", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::ComplexHalf, result.scalar_type(), "index_add", [&] { cuda::detail::TensorInfo selfInfo = cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); + const int selfAddDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfAddDim); - auto alpha_value = alpha.to(); + const auto alpha_value = alpha.to(); AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { auto sourceInfo = cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); + const int sourceAddDim = sourceInfo.collapseDims(dim); sourceInfo.reduceDim(sourceAddDim); auto indexInfo = @@ -642,7 +644,7 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); } } else { - bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); + const bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); @@ -668,13 +670,13 @@ void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, c AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { cuda::detail::TensorInfo selfInfo = cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); + const int selfAddDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfAddDim); - auto alpha_value = alpha.to(); + const auto alpha_value = alpha.to(); cuda::detail::TensorInfo sourceInfo = cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); + const int sourceAddDim = sourceInfo.collapseDims(dim); sourceInfo.reduceDim(sourceAddDim); AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { @@ -714,7 +716,7 @@ void index_reduce_func_cuda_impl( TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims"); if (!include_self) { - AT_DISPATCH_FLOATING_TYPES_AND2( + AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce_func_cuda_exclude_input_init", [&] { scalar_t init_val; @@ -786,7 +788,7 @@ void index_reduce_func_cuda_impl( if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(source) && cuda::detail::canUse32BitIndexMath(index)) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_reduce", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_reduce", [&] { cuda::detail::TensorInfo selfInfo = cuda::detail::getTensorInfo(self_); int selfReduceDim = selfInfo.collapseDims(dim); @@ -838,7 +840,7 @@ void index_reduce_func_cuda_impl( }); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce", [&] { cuda::detail::TensorInfo selfInfo = cuda::detail::getTensorInfo(self_); int selfReduceDim = selfInfo.collapseDims(dim); @@ -886,7 +888,11 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out) auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); counts.index_add_(dim, index, at::ones_like(source)); counts.masked_fill_(counts == 0, 1); - result.div_(counts); + if (result.is_floating_point() || result.is_complex()) { + result.div_(counts); + } else { + result.div_(counts, "floor"); + } } else if (reduce == "amax") { index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MAXIMUM, reduce_maximum, result); } else if (reduce == "amin") { diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 9f54bf392924e..280a5046ef06d 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -1,3 +1,4 @@ +#define TORCH_ASSERT_NO_OPERATORS #include #include #include diff --git a/aten/src/ATen/native/cuda/Loss.cu b/aten/src/ATen/native/cuda/Loss.cu index 0f60c673d26bb..2b5d17b9547ed 100644 --- a/aten/src/ATen/native/cuda/Loss.cu +++ b/aten/src/ATen/native/cuda/Loss.cu @@ -60,32 +60,6 @@ void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor& namespace at { namespace native { -Tensor kl_div_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) { - auto grad_input = at::empty_like(input); - if (!log_target) { - TensorIterator iter = TensorIteratorConfig() - .add_output(grad_input) - .add_input(target) - .add_input(grad) - .build(); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "kl_div_backward_cuda", [&]() { - scalar_t inv = (reduction == at::Reduction::Mean) ? scalar_t(1.0 / input.numel()) : scalar_t(1.0); - gpu_kernel(iter, - [inv] GPU_LAMBDA (scalar_t target_val, scalar_t grad_val) { - return (target_val > 0) ? scalar_t(-target_val * grad_val * inv) : scalar_t(0.0); - }); - }); - } - else { - grad_input = -at::exp(target) * grad; - if (reduction == at::Reduction::Mean) { - grad_input /= input.numel(); - } - } - - return grad_input; -} - Tensor binary_cross_entropy_cuda(const Tensor& input, const Tensor& target, const c10::optional& weight_opt, int64_t reduction) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); @@ -178,17 +152,10 @@ namespace { constexpr int NLL_LOSS_THREADS = 32; -#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - at::ScalarType _it = TYPE; \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it) \ - switch (_it) { \ - AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Byte, uint8_t, index_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__)\ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \ - } \ - }() +#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Byte, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Long, index_t, __VA_ARGS__)) template __global__ void nll_loss_forward_no_reduce_cuda_kernel( diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 3b4ca34cf15e6..c2a4870d7f5dd 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -1265,6 +1265,187 @@ const auto erfcx_string = jiterator_stringify( } ); // erfcx_string +const auto airy_ai_string = jiterator_stringify( + template + T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (isinf(x)) { + return NAN; + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * sqrt(-x) / T(3.0) + T(0.25) * T(3.14159265358979323846); + + return T(5.64189583547756286948e-01) / sqrt(sqrt(-x)) * (sin(t) * (T(1.0) + z * z * afn / afd) - cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * sqrt(sqrt(x)) * exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > T(1.11022302462515654042e-16)) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; + } // T airy_ai(T x) +); // airy_ai_string + const auto bessel_j0_string = jiterator_stringify( template T bessel_j0_forward(T x) { @@ -2354,6 +2535,85 @@ const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_str } // modified_bessel_k0_forward(T x) ); // modified_bessel_k0_string +const auto scaled_modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x)) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / sqrt(x); + } // T scaled_modified_bessel_k0_forward(T x) +); // scaled_modified_bessel_k0_string + const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( template T modified_bessel_k1_forward(T x) { @@ -2434,6 +2694,86 @@ const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_str } // modified_bessel_k1_forward(T x) ); // modified_bessel_k1_string +const auto scaled_modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / sqrt(x)); + } // T scaled_modified_bessel_k1_forward(T x) +); // scaled_modified_bessel_k1_string + const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( template T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { @@ -2654,6 +2994,21 @@ const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( } // shifted_chebyshev_polynomial_w_forward(T x, T n) ); // shifted_chebyshev_polynomial_w_string +const auto spherical_bessel_j0_string = jiterator_stringify( + template + T spherical_bessel_j0_forward(T x) { + if (isinf(x)) { + return T(0.0); + } + + if (abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return sin(x) / x; + } // T spherical_bessel_j0_forward(T x) +); // spherical_bessel_j0_string + #else // !AT_USE_JITERATOR() -- kernels must be precompiled template diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index 409354bb3cb8e..355db3439d07b 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -382,19 +382,4 @@ inline int can_vectorize_up_to(array_t pointers) { return result; } -// jitted version of the above -// See Note [Jiterator], this relies on the assumptions enumerated there -template -inline int jitted_can_vectorize_up_to(array_t pointers) { - // Deals with output - int result = can_vectorize_up_to(pointers[0]); - - // Incorporates input(s) - for (auto i = decltype(arity){1}; i < (arity + 1); ++i) { - result = std::min(result, can_vectorize_up_to(pointers[i])); - } - - return result; -} - }}} // namespace at::native::memory diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index e7b2372a18dad..3b27ebfc7d922 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -1,4 +1,4 @@ -// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include @@ -20,18 +20,12 @@ #include #include #include +#include #include #include #include #endif -// TODO: Doesn't exist in this branch -#if 0 -#include -#else -#include -#endif - namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index 4f308d0847dcb..9958d4c9b8144 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -123,12 +123,14 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc for (int it = 0; it < WARP_ITERATIONS; ++it) { if (is_masked) { int idx = it*WARP_SIZE; - if (!is_transformer_mask) { - idx += i*element_count; - } - if (!mask[idx]) { - max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - is_meaningful_max = true; + if ((idx + local_idx) < element_count) { + if (!is_transformer_mask) { + idx += i*element_count; + } + if (!mask[idx]) { + max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + is_meaningful_max = true; + } } } else { max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it]; @@ -156,22 +158,28 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc } } else { int idx = it*WARP_SIZE; + bool valid = (idx + local_idx) < element_count; if (!is_transformer_mask) { idx += i*element_count; } - - if (!mask[idx]) { - if (is_log_softmax) { - sum[i] += std::exp(elements[i][it] - max_value[i]); + if (valid) { + if (!mask[idx]) { + if (is_log_softmax) { + sum[i] += std::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = std::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } } else { - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; + if (!is_log_softmax) { + // Masked values are treated as -infinity, and std::exp(-infinity) is 0. + elements[i][it] = 0; + } } } else { - if (!is_log_softmax) { - // Masked values are treated as -infinity, and std::exp(-infinity) is 0. - elements[i][it] = 0; - } + if (!is_log_softmax) { + elements[i][it] = 0.; + } } } } diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu index b1c4a2ae4b411..a0a976ef141fa 100644 --- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu +++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu @@ -13,9 +13,12 @@ namespace at { namespace native { const char addcmul_name[] = "addcmul"; void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { - auto dtype = iter.dtype(); + auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { - #if AT_USE_JITERATOR() + // When using Jiterator, addcmul and addcdiv kernels get stuck during a + // promotion test on CUDA 11.3, so only enable that from CUDA 11.5: + // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209 + #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050 AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() { auto alpha = value.to(); static const auto addcmul_string = jiterator_stringify( @@ -55,9 +58,12 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { // return a + alpha * (b / static_cast(c)); const char addcdiv_name[] = "addcdiv"; void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { - auto dtype = iter.dtype(); + auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { - #if AT_USE_JITERATOR() + // When using Jiterator, addcmul and addcdiv kernels get stuck during a + // promotion test on CUDA 11.3, so only enable that from CUDA 11.5: + // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209 + #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050 AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() { auto alpha = value.to(); static const auto addcdiv_string = diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 55981ac1ad8e3..17aa4099dc532 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -218,6 +218,13 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st auto xend = end.to(); auto xstep = step.to(); + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && + std::isfinite(static_cast(xend)), + "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), + "upper bound and larger bound inconsistent with step sign"); + // we use double precision for (start - end) / step // to compute size_d for consistency across devices. // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, @@ -227,20 +234,13 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st // the corner-case we do want to take into account is int64_t, which has higher precision than double double size_d; if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); + int64_t sgn = (xstep > 0) - (xstep < 0); + size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); } else { size_d = std::ceil(static_cast(end.to() - start.to()) / step.to()); } - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); - TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), "invalid size, possible overflow?"); int64_t size = static_cast(size_d); diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index a614542cd9281..34e99ae57a59d 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -69,6 +69,10 @@ struct mnt_wrapper >{ static constexpr int MAX_NUM_THREADS = 256; }; +constexpr int max_reduce_threads(c10::ScalarType type) { + return type == kComplexDouble ? 256 : 512; +} + struct ReduceConfig { static constexpr int BLOCK_X = 0; static constexpr int BLOCK_Y = 1; @@ -896,43 +900,37 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) } } -template -static void launch_jitted_reduce_kernel(DeviceIndex idx, const ReduceConfig& config, -R& reduction, const std::string& func) { - constexpr int max_threads = mnt_wrapper::MAX_NUM_THREADS; +inline void launch_jitted_reduce_kernel( + std::mutex &jiterator_mutex, + std::array &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int vt0, const ReduceConfig& config, void *reduction) { dim3 block = config.block(); dim3 grid = config.grid(); - static std::mutex _jiterator_mutex; - static std::vector> fns(c10::cuda::device_count()); int shared_memory = config.shared_memory_size(); at::cuda::jit::NvrtcFunction* fn_ptr; switch(config.output_vec_size) { case 4: - fn_ptr = &fns[idx][0]; + fn_ptr = &fn_cache[0]; break; case 2: - fn_ptr = &fns[idx][1]; + fn_ptr = &fn_cache[1]; break; default: - fn_ptr = &fns[idx][2]; + fn_ptr = &fn_cache[2]; } if (!fn_ptr->function) { - std::string f_inputs_type_str = at::cuda::jit::typeName(); - std::string accum_type_str = at::cuda::jit::typeName>(); - std::string result_type_str = at::cuda::jit::typeName(); - int max_threads_codegen = max_threads/config.output_vec_size; - auto code = at::cuda::jit::generate_reduction_code(1, func, name, vt0, - f_inputs_type_str, accum_type_str, result_type_str, - true, false, config.output_vec_size, max_threads_codegen); - - *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_"+std::string(name)); + int max_threads_codegen = + max_reduce_threads(desc.f_inputs_type) / config.output_vec_size; + auto code = at::cuda::jit::generate_reduction_code( + desc, vt0, true, false, config.output_vec_size, max_threads_codegen); + *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name); } constexpr int kernel_args = 1; void* args[kernel_args]; - args[0] = static_cast(&reduction); + args[0] = reduction; at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory); } @@ -1311,9 +1309,17 @@ inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& fu reduce.accumulate = iter.should_accumulate(); reduce.final_output = iter.is_final_output(); - launch_jitted_reduce_kernel(iter.device().index(), - config, reduce, func); + constexpr int nInputs = 1; + constexpr int nOutputs = 1; + static auto desc = at::cuda::jit::make_kernel_descriptor< + out_scalar_t, scalar_t>(name, func, nInputs, nOutputs); + + static std::mutex jiterator_mutex; + static std::vector> fn_cache(c10::cuda::device_count()); + auto &cache = fn_cache[iter.device().index()]; + + launch_jitted_reduce_kernel( + jiterator_mutex, cache, desc, vt0, config, &reduce); } }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu new file mode 100644 index 0000000000000..292404cb36acb --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -0,0 +1,51 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); +} + +void aminmax_allreduce_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { + _min_max_values_kernel_cuda_impl(iter); + }); +} + +void aminmax_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); + }); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu new file mode 100644 index 0000000000000..fd8e071cd5c8d --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void argmax_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); +}; + +void argmax_kernel_cuda(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. + if (iter.dtype(1) == kHalf) { + argmax_kernel_cuda_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmax_kernel_cuda_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { + argmax_kernel_cuda_impl(iter); + }); + } +} + +REGISTER_DISPATCH(argmax_stub, &argmax_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu b/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu new file mode 100644 index 0000000000000..20eb736e49457 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void argmin_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMinOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), 0)); +}; + +void argmin_kernel_cuda(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. + if (iter.dtype(1) == kHalf) { + argmin_kernel_cuda_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmin_kernel_cuda_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { + argmin_kernel_cuda_impl(iter); + }); + } +} + +REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu new file mode 100644 index 0000000000000..a5363838ee257 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +struct MaxNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a > b) ? a : b; + } +}; + +template +void max_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + func_wrapper(MaxNanFunctor()), + at::numeric_limits::lower_bound()); +} + +void max_values_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { + max_values_kernel_cuda_impl(iter); + }); +} + +void max_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { + gpu_reduce_kernel( + iter, + MaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); + }); +} + +void max_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { + max_values_kernel_cuda_impl(iter); + }); +} + +REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu deleted file mode 100644 index db9e9b5e60aff..0000000000000 --- a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu +++ /dev/null @@ -1,168 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - - -namespace at { namespace native { - -template -struct MaxNanFunctor { - __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { - return (at::_isnan(a) || a > b) ? a : b; - } -}; - -template -void max_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper (MaxNanFunctor()), - at::numeric_limits::lower_bound()); -} - -template -struct MinNanFunctor { - __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { - return (at::_isnan(a) || a < b) ? a : b; - } -}; - -template -void min_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper (MinNanFunctor()), - at::numeric_limits::upper_bound()); -} - -void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { - max_values_kernel_cuda_impl(iter); - }); -} - -void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { - min_values_kernel_cuda_impl(iter); - }); -} - -template -void argmax_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, - ArgMaxOps{}, - thrust::pair(at::numeric_limits::lower_bound(), 0)); -}; - -template -void argmin_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, - ArgMinOps{}, - thrust::pair(at::numeric_limits::upper_bound(), 0)); -}; - -void argmax_kernel_cuda(TensorIterator& iter) { - // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, - // we can convert float16 & bfloat16 to float and do all the operations in float. - if (iter.dtype(1) == kHalf) { - argmax_kernel_cuda_impl(iter); - } else if (iter.dtype(1) == kBFloat16) { - argmax_kernel_cuda_impl(iter); - } else { - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { - argmax_kernel_cuda_impl(iter); - }); - } -} - -void argmin_kernel_cuda(TensorIterator& iter) { - // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, - // we can convert float16 & bfloat16 to float and do all the operations in float. - if (iter.dtype(1) == kHalf) { - argmin_kernel_cuda_impl(iter); - } else if (iter.dtype(1) == kBFloat16) { - argmin_kernel_cuda_impl(iter); - } else { - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { - argmin_kernel_cuda_impl(iter); - }); - } -} - -void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { - gpu_reduce_kernel( - iter, - MinOps{}, - thrust::pair(at::numeric_limits::upper_bound(), 0)); - }); -} - -void max_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { - gpu_reduce_kernel( - iter, - MaxOps{}, - thrust::pair(at::numeric_limits::lower_bound(), 0)); - }); -} - -void aminmax_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { - gpu_reduce_kernel( - iter, - MinMaxOps{}, - thrust::pair( - at::numeric_limits::upper_bound(), - at::numeric_limits::lower_bound() - ) - ); - }); -} - -void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { - min_values_kernel_cuda_impl(iter); - }); -} - -void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { - max_values_kernel_cuda_impl(iter); - }); -} - -template -void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, MinMaxOps{}, thrust::pair( - at::numeric_limits::upper_bound(), - at::numeric_limits::lower_bound() - )); -} - -void aminmax_allreduce_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { - _min_max_values_kernel_cuda_impl(iter); - }); -} - -REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda); -REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda); -REGISTER_DISPATCH(argmax_stub, &argmax_kernel_cuda); -REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu new file mode 100644 index 0000000000000..54d0f8499e541 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -0,0 +1,58 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +namespace at { namespace native { + +template +struct MinNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a < b) ? a : b; + } +}; + +template +void min_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper (MinNanFunctor()), + at::numeric_limits::upper_bound()); +} + +void min_values_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { + min_values_kernel_cuda_impl(iter); + }); +} + +void min_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { + gpu_reduce_kernel( + iter, + MinOps{}, + thrust::pair(at::numeric_limits::upper_bound(), 0)); + }); +} + +void min_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { + min_values_kernel_cuda_impl(iter); + }); +} + +REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/RenormKernel.cu b/aten/src/ATen/native/cuda/RenormKernel.cu index dc96fb0e0c1d1..59f31daa381e1 100644 --- a/aten/src/ATen/native/cuda/RenormKernel.cu +++ b/aten/src/ATen/native/cuda/RenormKernel.cu @@ -1,3 +1,4 @@ +#define TORCH_ASSERT_NO_OPERATORS #include #include #include diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index d967ffd0354df..21d6cc1600fac 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -1,3 +1,4 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include diff --git a/aten/src/ATen/native/cuda/RreluWithNoise.cu b/aten/src/ATen/native/cuda/RreluWithNoise.cu index 3b2435d3dae42..762098ab77703 100644 --- a/aten/src/ATen/native/cuda/RreluWithNoise.cu +++ b/aten/src/ATen/native/cuda/RreluWithNoise.cu @@ -60,7 +60,7 @@ __global__ void rrelu_with_noise_cuda_kernel( noise[li] = r; } else { output[li] = input[li]; - noise[li] = static_cast(0); + noise[li] = static_cast(1); } } __syncthreads(); @@ -155,7 +155,7 @@ Tensor& rrelu_with_noise_out_cuda(const Tensor& self, checkAllSameGPU("rrelu_with_noise_out_cuda", {self_arg, noise_arg, output_arg}); if (training) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "rrelu_with_noise_out_cuda", [&] { _rrelu_with_noise_cuda_train( output, self, noise, lower, upper, generator); diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 8461aa4cd8e3e..559c203046761 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -319,7 +319,7 @@ struct cuda_scatter_gather_base_kernel { auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride; - AT_DISPATCH_FLOATING_TYPES_AND2( + AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cuda_scatter_gather_base_kernel_func", [&] { @@ -450,7 +450,7 @@ struct cuda_scatter_fill_base_kernel { auto index_size = ensure_nonempty_size(self, dim); auto index_stride = ensure_nonempty_stride(self, dim); - AT_DISPATCH_FLOATING_TYPES_AND2( + AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] { diff --git a/aten/src/ATen/native/cuda/Sort.cpp b/aten/src/ATen/native/cuda/Sort.cpp index fd74d35974f31..7d4ff50645f1d 100644 --- a/aten/src/ATen/native/cuda/Sort.cpp +++ b/aten/src/ATen/native/cuda/Sort.cpp @@ -23,20 +23,6 @@ namespace at { namespace native { -// We perform a segmented sort in cub with inputs that have -// more than 1024/2048 elements along the selected dimension. -// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace). -bool should_use_small_sort(const TensorBase &self, int64_t dim) { - int64_t nsort = self.sizes()[dim]; - int64_t threshold; - if (self.scalar_type() == kLong || self.scalar_type() == kDouble) { - threshold = 1024; - } else { - threshold = 2048; - } - return nsort <= threshold; -} - std::vector infer_dense_strides_dim_last(const Tensor & self, int64_t dim); void fillSliceWithIndex(const Tensor& t, int dim) { @@ -84,7 +70,7 @@ void sort_cuda_kernel( "Sort currently does not support complex dtypes on CUDA."); // use inplace algorithm for smaller input sizes without stable=True - if (should_use_small_sort(self, dim) && !stable) { + if (should_use_small_sort(self, dim)) { // from thc: sorted->values, indices->indices, input->self fillSliceWithIndex(indices, dim); @@ -93,7 +79,7 @@ void sort_cuda_kernel( // Sort using our in-place k/v kernel that supports arbitrary // layout - sortKeyValueInplace(values, indices, dim, descending); + sortKeyValueInplace(values, indices, dim, descending, stable); return; } diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu index 5f21b1ceb7b5d..e5e3274fd69a9 100644 --- a/aten/src/ATen/native/cuda/Sort.cu +++ b/aten/src/ATen/native/cuda/Sort.cu @@ -28,12 +28,164 @@ static int minimum_grid_for_occupancy(T kernel, int max_block_size) { return minGridSize; } -// In alignment with default sort on a c++ map, this function -// will permute key and value tensors identically, and -// in such a way that the 'key' tensor is ordered numerically -void sortKeyValueInplace(const TensorBase& key, - const TensorBase& value, - int dim, bool dir) { +// For very small sorts, use bitonicSortKVInPlace which performs +// better because it can sort multiple arrays within the same block of +// threads, improving occupancy. +// +// TODO: cub in CUDA 11.6 has a WarpMergeSort primitive that could +// replace the bitonic sort here. +struct SmallBitonicSort { + template + void sort( + at::cuda::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + constexpr int sort_size = 32; + constexpr int max_block_y = 16; + constexpr int items_per_thread = 2; + static_assert(sort_size % items_per_thread == 0, ""); + constexpr int block_x = sort_size / items_per_thread; + + TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size); + + // Scale batch size down if the grid would be too small + const auto min_grid = minimum_grid_for_occupancy( + bitonicSortKVInPlace< + A, -1, block_x, max_block_y, + K, V, LTOp, IndexType>, + block_x * max_block_y); + const auto max_batch = std::max(IndexType{1}, keySlices / min_grid); + const int block_y = std::min(IndexType(max_block_y), max_batch); + dim3 block(block_x, block_y); + + dim3 grid; + const int grid_count = (keySlices + block_y - 1) / block_y; + TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid), + "Too many slices to sort"); + const auto stream = at::cuda::getCurrentCUDAStream(); + + if (descending) { + bitonicSortKVInPlace + <<>>( + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + GTOp()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + bitonicSortKVInPlace + <<>>( + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + LTOp()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } +}; + +// For medium sizes (32 < n <= 4096) use radixSortKVInplace for better +// performance than the bitonic sort kernel. +struct MediumRadixSort { + + template + void sort( + at::cuda::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + +#define HANDLE_CASE(SIZE, ITEMS_PER_THREAD) \ + fixed_size_sort( \ + keyInfo, \ + keySlices, \ + keySliceSize, \ + keySliceStride, \ + valueInfo, \ + valueSliceStride, \ + descending) + + int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize); + TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 4096); + switch (ceilPowerOf2) { + case 4096: + HANDLE_CASE(4096, 32); + break; + case 2048: + HANDLE_CASE(2048, 32); + break; + case 1024: + case 512: + case 256: + HANDLE_CASE(1024, 32); + break; + case 128: + case 64: + HANDLE_CASE(128, 4); + break; + case 32: + case 16: + case 8: + case 4: + case 2: + HANDLE_CASE(32, 2); + break; + case 1: + /* Nothing to do, data already sorted */ + break; + default: + TORCH_INTERNAL_ASSERT(false); + } +#undef HANDLE_CASE + + } + + template + void fixed_size_sort( + at::cuda::detail::TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending) { + static_assert(sort_size % items_per_thread == 0, ""); + constexpr int block = sort_size / items_per_thread; + dim3 grid; + TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid), + "Too many slices to sort"); + + const auto stream = at::cuda::getCurrentCUDAStream(); + radixSortKVInPlace + <<>>( + keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + descending); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +}; + +template +void sortCommon(Sorter sorter, const TensorBase &key, const TensorBase &value, + int dim, bool descending) { TORCH_CHECK(key.sizes() == value.sizes(), "Key tensor must have same size as value tensor"); int dims = value.dim(); @@ -49,93 +201,15 @@ void sortKeyValueInplace(const TensorBase& key, int64_t keySliceSize = key.size(dim); ptrdiff_t keySlices = inElements / keySliceSize; - // The amount of shared memory and block size is based on - // 2^ceil(lg(n)); we choose that sorting implementation for a given - // size. - int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize); - - // FIXME: We'd have to find some other trick with Thrust to perform a - // vectorized (key, value) sort by slice segment - TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 2048, "sortKeyValueInplace only works for sizes <= 2048 at present"); - - const auto stream = c10::cuda::getCurrentCUDAStream(); - -#define HANDLE_CASE(TYPE, A, SIZE, BATCH) \ - do { \ - constexpr int items_per_thread = 2; \ - static_assert(SIZE % items_per_thread == 0, ""); \ - constexpr int block_x = SIZE / items_per_thread; \ - constexpr int max_block_y = BATCH; \ - \ - /* Scale batch size down if the grid would be too small */ \ - const auto min_grid = minimum_grid_for_occupancy( \ - bitonicSortKVInPlace< \ - A, -1, block_x, max_block_y, \ - scalar_t, int64_t, LTOp, TYPE>, \ - block_x * max_block_y); \ - const auto max_batch = std::max(int64_t{1}, keySlices / min_grid); \ - const int block_y = std::min(int64_t{max_block_y}, max_batch); \ - dim3 block(block_x, block_y); \ - \ - dim3 grid; \ - const int grid_count = (keySlices + block_y - 1) / block_y; \ - TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid), \ - "Too many slices to sort"); \ - \ - if (dir) { \ - bitonicSortKVInPlace \ - <<>>( \ - keyInfo, \ - (TYPE) keySlices, \ - (TYPE) keySliceSize, \ - (TYPE) keyInfo.strides[collapseKeyDim], \ - valueInfo, \ - (TYPE) valueInfo.strides[collapseValueDim], \ - GTOp()); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - bitonicSortKVInPlace \ - <<>>( \ - keyInfo, \ - (TYPE) keySlices, \ - (TYPE) keySliceSize, \ - (TYPE) keyInfo.strides[collapseKeyDim], \ - valueInfo, \ - (TYPE) valueInfo.strides[collapseValueDim], \ - LTOp()); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } while (0) - -#define HANDLE_SORT_CASE(TYPE, A) \ - { \ - switch (ceilPowerOf2) { \ - case 2048: \ - HANDLE_CASE(TYPE, A, 2048, 1); \ - break; \ - case 1024: \ - case 512: \ - case 256: \ - HANDLE_CASE(TYPE, A, 1024, 1); \ - break; \ - case 128: \ - case 64: \ - HANDLE_CASE(TYPE, A, 128, 4); \ - break; \ - case 32: \ - case 16: \ - case 8: \ - case 4: \ - case 2: \ - HANDLE_CASE(TYPE, A, 32, 16); \ - break; \ - case 1: \ - /* Nothing to do, data already sorted */ \ - break; \ - default: \ - TORCH_INTERNAL_ASSERT(false); \ - } \ - } +#define HANDLE_SORT_CASE(TYPE, A) \ + sorter.template sort( \ + keyInfo, \ + (TYPE) keySlices, \ + (TYPE) keySliceSize, \ + (TYPE) keyInfo.strides[collapseKeyDim], \ + valueInfo, \ + (TYPE) valueInfo.strides[collapseValueDim], \ + descending) // The constructed key/value tensor info is used to select the slice // we are sorting on a per-block basis @@ -189,225 +263,21 @@ void sortKeyValueInplace(const TensorBase& key, HANDLE_SORT_CASE(uint64_t, -1); } }); -#undef HANDLE_CASE #undef HANDLE_SORT_CASE -#undef HANDLE_A_CASE } -namespace { - -struct offset_t { - int stride; - int begin; - __device__ int operator[](int i) { - return stride * (begin + i); +void sortKeyValueInplace( + const TensorBase& key, + const TensorBase& value, + int dim, + bool descending, + bool stable) { + if (!stable && key.size(dim) <= 32) { + // NOTE: Bitonic sort is unstable + sortCommon(SmallBitonicSort{}, key, value, dim, descending); + } else { + sortCommon(MediumRadixSort{}, key, value, dim, descending); } -}; - -} - -namespace { - -// Segmented sort by full sort algorithm:. -// Say we are sorting a (2, 3) tensor. We have in flattened form: -// values 0.4 1.2 5.3 6.2 1.3 2.3 -// indices 0 1 2 0 1 2 -// segment_id 0 0 0 1 1 1 - -// First we sort by values, globally: -// values 6.2 5.3 2.3 1.2 1.3 0.4 -// indices 0 2 2 1 1 0 -// segment_id 1 0 1 0 1 0 - -// Then we stable sort by segment id: -// values 5.3 1.2 0.4 6.2 2.3 1.3 -// indices 2 1 0 0 2 1 -// segment_id 0 0 0 1 1 1 - -// This method can only work if the slice we are sorting (`dim`) is -// innermost, and both values and indices are contiguous. We do this -// by re-arranging the input into this form as needed, which will -// unfortunately allocate memory if the request is not in this form. -// Vectorized sort is slower than iterated sort if the number of -// slices is small (since we're sorting twice, instead of invoking a -// smaller sort `numSlices` times), but the cub sort -// implementation here is a catch-all, so we're not looking for -// efficiency, but instead correctness. - -template -__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) { - CUDA_KERNEL_LOOP(i, nsegments * nsort) { - int segment = i / nsort; - int j = i % nsort; - - int offset = segment * nsort; - const scalar_t *in_ = in + offset; - scalar_t *out_ = out + offset; - int64_t *index_ = index + offset; - const int2 *i_s_ptr_ = i_s_ptr + offset; - - int idx = i_s_ptr_[j].y; - index_[j] = idx; - out_[j] = in_[idx]; - } -} - - -C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) -__global__ void fill_index_and_segment_kernel( - int2 *data, int numel, at::cuda::detail::IntDivider nsort_divider) { - CUDA_KERNEL_LOOP(idx, numel) { - auto div_mod = nsort_divider.divmod(idx); - auto segment = static_cast(div_mod.div); - auto sort = static_cast(div_mod.mod); - data[idx] = int2{segment, sort}; - } -} - -C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) -__global__ void fill_reverse_indices_kernel( - int64_t *data, int numel, at::cuda::detail::IntDivider nsort_divider) { - CUDA_KERNEL_LOOP(idx, numel) { - data[idx] = nsort_divider.mod(idx); - } -} - -template -inline void segmented_sort_large_segments( - const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending, - const scalar_t * self_ptr, scalar_t * values_ptr, int64_t * indices_ptr - ) { - using namespace at::cuda::detail; - auto allocator = at::cuda::getCUDADeviceAllocator(); - auto stream = at::cuda::getCurrentCUDAStream(); - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(nsort); - c10::DeviceArray indices(*allocator, nsort); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_reverse_indices_kernel<<>>( - indices.get(), nsort, nsort_divider); - const int64_t *initial_indices = indices.get(); - - for (auto i: c10::irange(nsegments)){ - at::cuda::cub::radix_sort_pairs( - self_ptr, values_ptr, initial_indices, indices_ptr, - nsort, descending); - indices_ptr += nsort; - self_ptr += nsort; - values_ptr += nsort; - } -} - -template -inline void segmented_sort_pairs_by_full_sort( - const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending, - const scalar_t *const self_ptr, scalar_t *const values_ptr, int64_t *const indices_ptr -) { - int64_t segment_bits = std::max(1L, static_cast(std::ceil(std::log2(nsegments)))); - - const auto numel = nsort * nsegments; - auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); - auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2)); - auto i_s_ptr = static_cast(indices_and_segment.get()); - - using namespace at::cuda::detail; - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(numel); - auto stream = c10::cuda::getCurrentCUDAStream(); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_index_and_segment_kernel<<>>( - i_s_ptr, numel, nsort_divider); - - auto indices_and_segment2 = cuda_allocator->allocate(nsegments * nsort * sizeof(int2)); - auto i_s_ptr2 = static_cast(indices_and_segment2.get()); - - at::cuda::cub::radix_sort_pairs( - self_ptr, nullptr, i_s_ptr, i_s_ptr2, - n, descending); - - TORCH_INTERNAL_ASSERT(segment_bits <= 32); - - // sort on lower 32bits, i.e. segment index - at::cuda::cub::radix_sort_keys( - reinterpret_cast(i_s_ptr2), reinterpret_cast(i_s_ptr), - n, false, 0, segment_bits); - - sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( - self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); -} - -template -void segmented_sort_pairs( - int64_t nsegments, int64_t nsort, int64_t n, bool descending, - const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr) { - const auto numel = nsort * nsegments; - auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); - auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t)); - int64_t *reverse_indices_ptr = static_cast(reverse_indices.get()); - - using namespace at::cuda::detail; - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(numel); - auto stream = c10::cuda::getCurrentCUDAStream(); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_reverse_indices_kernel<<>>( - reverse_indices_ptr, numel, nsort_divider); - - at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr, - reverse_indices_ptr, indices_ptr, n, nsegments, - offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending); -} - -} // namespace - -void launch_stable_sort_kernel( - const TensorBase &self, int64_t dim, bool descending, - const TensorBase &values, const TensorBase &indices) { - const auto numel = self.numel(); - if (numel == 0) { - return; - } - - int64_t numel_or_intmax = std::min(numel, static_cast(std::numeric_limits::max())); - int64_t nsort = self.size(dim); - int64_t nbatch = (numel_or_intmax / nsort) * nsort; - TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); - int64_t *indices_ptr = indices.data_ptr(); - -#if (defined(USE_ROCM) && ROCM_VERSION < 40500) - constexpr bool is_rocm_bf16_sort_unsupported = true; -#else - constexpr bool is_rocm_bf16_sort_unsupported = false; -#endif - - AT_DISPATCH_ALL_TYPES_AND3(kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&]{ - c10::guts::if_constexpr::value)>([&](auto _){ - const scalar_t *self_ptr = self.data_ptr(); - scalar_t *values_ptr = values.data_ptr(); - int64_t remaining = _(numel); - while (remaining > 0) { - int64_t n = std::min(remaining, nbatch); - int64_t nsegments = n / nsort; - - if (nsegments == 1 || nsort >= 1000000) { //rough heuristics where even a single sort occupies GPU - segmented_sort_large_segments( - nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } else if (nsegments < 128) { - segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } else { - segmented_sort_pairs(nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } - - remaining -= n; - self_ptr += n; - values_ptr += n; - indices_ptr += n; - } - }, [&](auto _){ TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm < 4.5"); }); - }); } }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Sort.h b/aten/src/ATen/native/cuda/Sort.h index ea4a47dcb9f2f..656b4ce2c2bba 100644 --- a/aten/src/ATen/native/cuda/Sort.h +++ b/aten/src/ATen/native/cuda/Sort.h @@ -1,22 +1,17 @@ #pragma once #include - -namespace at { -class TensorBase; -} +#include +#include namespace at { namespace native { -// Stable-sort self into values, and set indices to the -// inverse-permutation from values back to self. -// Output tensors must be pre-allocated and contiguous. -void launch_stable_sort_kernel(const TensorBase &self, int64_t dim, bool descending, - const TensorBase &values, const TensorBase &indices); +inline bool should_use_small_sort(const TensorBase &self, int64_t dim) { + return self.size(dim) <= 4096; +} -bool should_use_small_sort(const TensorBase &self, int64_t dim); -void sortKeyValueInplace(const TensorBase &key, - const TensorBase &value, - int dim, bool dir); +void sortKeyValueInplace( + const TensorBase &key, const TensorBase &value, int dim, + bool descending, bool stable=false); }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/SortStable.cu b/aten/src/ATen/native/cuda/SortStable.cu new file mode 100644 index 0000000000000..cf6ffb778e57a --- /dev/null +++ b/aten/src/ATen/native/cuda/SortStable.cu @@ -0,0 +1,299 @@ + +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace native { + +namespace { + +struct offset_t { + int stride; + int begin; + __device__ int operator[](int i) { + return stride * (begin + i); + } +}; +// Segmented sort by full sort algorithm:. +// Say we are sorting a (2, 3) tensor. We have in flattened form: +// values 0.4 1.2 5.3 6.2 1.3 2.3 +// indices 0 1 2 0 1 2 +// segment_id 0 0 0 1 1 1 + +// First we sort by values, globally: +// values 6.2 5.3 2.3 1.2 1.3 0.4 +// indices 0 2 2 1 1 0 +// segment_id 1 0 1 0 1 0 + +// Then we stable sort by segment id: +// values 5.3 1.2 0.4 6.2 2.3 1.3 +// indices 2 1 0 0 2 1 +// segment_id 0 0 0 1 1 1 + +// This method can only work if the slice we are sorting (`dim`) is +// innermost, and both values and indices are contiguous. We do this +// by re-arranging the input into this form as needed, which will +// unfortunately allocate memory if the request is not in this form. +// Vectorized sort is slower than iterated sort if the number of +// slices is small (since we're sorting twice, instead of invoking a +// smaller sort `numSlices` times), but the cub sort +// implementation here is a catch-all, so we're not looking for +// efficiency, but instead correctness. + +template +__global__ void sort_postprocess_kernel( + const scalar_t* in, + scalar_t* out, + int64_t* index, + const int2* i_s_ptr, + int nsegments, + int nsort) { + CUDA_KERNEL_LOOP(i, nsegments * nsort) { + int segment = i / nsort; + int j = i % nsort; + + int offset = segment * nsort; + const scalar_t* in_ = in + offset; + scalar_t* out_ = out + offset; + int64_t* index_ = index + offset; + const int2* i_s_ptr_ = i_s_ptr + offset; + + int idx = i_s_ptr_[j].y; + index_[j] = idx; + out_[j] = in_[idx]; + } +} + +C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) +__global__ void fill_index_and_segment_kernel( + int2* data, + int numel, + at::cuda::detail::IntDivider nsort_divider) { + CUDA_KERNEL_LOOP(idx, numel) { + auto div_mod = nsort_divider.divmod(idx); + auto segment = static_cast(div_mod.div); + auto sort = static_cast(div_mod.mod); + data[idx] = int2{segment, sort}; + } +} + +C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) +__global__ void fill_reverse_indices_kernel( + int64_t* data, + int numel, + at::cuda::detail::IntDivider nsort_divider) { + CUDA_KERNEL_LOOP(idx, numel) { + data[idx] = nsort_divider.mod(idx); + } +} + +template +inline void segmented_sort_large_segments( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + using namespace at::cuda::detail; + auto allocator = at::cuda::getCUDADeviceAllocator(); + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(nsort); + c10::DeviceArray indices(*allocator, nsort); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_reverse_indices_kernel<<>>( + indices.get(), nsort, nsort_divider); + const int64_t* initial_indices = indices.get(); + + for (auto i : c10::irange(nsegments)) { + at::cuda::cub::radix_sort_pairs( + self_ptr, values_ptr, initial_indices, indices_ptr, nsort, descending); + indices_ptr += nsort; + self_ptr += nsort; + values_ptr += nsort; + } +} + +template +inline void segmented_sort_pairs_by_full_sort( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* const self_ptr, + scalar_t* const values_ptr, + int64_t* const indices_ptr) { + int64_t segment_bits = std::max( + 1L, static_cast(std::ceil(std::log2(nsegments)))); + + const auto numel = nsort * nsegments; + auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); + auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2)); + auto i_s_ptr = static_cast(indices_and_segment.get()); + + using namespace at::cuda::detail; + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::cuda::getCurrentCUDAStream(); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_index_and_segment_kernel<<>>( + i_s_ptr, numel, nsort_divider); + + auto indices_and_segment2 = + cuda_allocator->allocate(nsegments * nsort * sizeof(int2)); + auto i_s_ptr2 = static_cast(indices_and_segment2.get()); + + at::cuda::cub::radix_sort_pairs( + self_ptr, nullptr, i_s_ptr, i_s_ptr2, n, descending); + + TORCH_INTERNAL_ASSERT(segment_bits <= 32); + + // sort on lower 32bits, i.e. segment index + at::cuda::cub::radix_sort_keys( + reinterpret_cast(i_s_ptr2), + reinterpret_cast(i_s_ptr), + n, + false, + 0, + segment_bits); + + sort_postprocess_kernel<<< + (n + 511) / 512, + 512, + 0, + at::cuda::getCurrentCUDAStream()>>>( + self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); +} + +template +void segmented_sort_pairs( + int64_t nsegments, + int64_t nsort, + int64_t n, + bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + const auto numel = nsort * nsegments; + auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); + auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t)); + int64_t* reverse_indices_ptr = static_cast(reverse_indices.get()); + + using namespace at::cuda::detail; + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::cuda::getCurrentCUDAStream(); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_reverse_indices_kernel<<>>( + reverse_indices_ptr, numel, nsort_divider); + + at::cuda::cub::segmented_sort_pairs( + self_ptr, + values_ptr, + reverse_indices_ptr, + indices_ptr, + n, + nsegments, + offset_t{(int)nsort, 0}, + offset_t{(int)nsort, 1}, + descending); +} + +} // namespace + +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices) { + const auto numel = self.numel(); + if (numel == 0) { + return; + } + + int64_t numel_or_intmax = + std::min(numel, static_cast(std::numeric_limits::max())); + int64_t nsort = self.size(dim); + int64_t nbatch = (numel_or_intmax / nsort) * nsort; + TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); + int64_t* indices_ptr = indices.data_ptr(); + +#if (defined(USE_ROCM) && ROCM_VERSION < 40500) + constexpr bool is_rocm_bf16_sort_unsupported = true; +#else + constexpr bool is_rocm_bf16_sort_unsupported = false; +#endif + + AT_DISPATCH_ALL_TYPES_AND3( + kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] { + c10::guts::if_constexpr::value)>( + [&](auto _) { + const scalar_t* self_ptr = self.data_ptr(); + scalar_t* values_ptr = values.data_ptr(); + int64_t remaining = _(numel); + while (remaining > 0) { + int64_t n = std::min(remaining, nbatch); + int64_t nsegments = n / nsort; + + if (nsegments == 1 || + nsort >= 1000000) { // rough heuristics where even a single + // sort occupies GPU + segmented_sort_large_segments( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else if (nsegments < 128) { + segmented_sort_pairs_by_full_sort( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else { + segmented_sort_pairs( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } + + remaining -= n; + self_ptr += n; + values_ptr += n; + indices_ptr += n; + } + }, + [&](auto _) { + TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm < 4.5"); + }); + }); +} +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortStable.h b/aten/src/ATen/native/cuda/SortStable.h new file mode 100644 index 0000000000000..039c4307c522c --- /dev/null +++ b/aten/src/ATen/native/cuda/SortStable.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh index 282a7ceb06b4b..a1d309ce709e2 100644 --- a/aten/src/ATen/native/cuda/SortUtils.cuh +++ b/aten/src/ATen/native/cuda/SortUtils.cuh @@ -2,10 +2,12 @@ #include #include +#include #include #include #include #include +#include namespace at { namespace native { @@ -151,4 +153,101 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo keys, } } +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + using key_t = typename at::cuda::cub::detail::cuda_type::type; + using LoadKeys = cub::BlockLoad; + using LoadValues = cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub::BlockStore; + using StoreValues = cub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! + if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + }} // at::native diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index f442c9c9f4e16..6e05908b2ccea 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -55,6 +55,10 @@ Tensor empty_cuda(IntArrayRef size, c10::optional dtype_opt, c10::op return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } +Tensor empty_symint_cuda(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + return at::native::empty_cuda(asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); +} + Tensor _efficientzerotensor_cuda(IntArrayRef size, c10::optional dtype, c10::optional layout, diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index ce76987e94e05..7d2371cba557d 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -7,10 +7,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -19,6 +21,112 @@ namespace at { namespace native { +template +struct ModeImpl { + std::tuple operator()( + scalar_t *iter_begin, + scalar_t *iter_end) { + at::cuda::ThrustAllocator thrust_allocator; + auto stream = at::cuda::getCurrentCUDAStream(); + auto policy = thrust::cuda::par(thrust_allocator).on(stream); + + const auto n_element = iter_end - iter_begin; + auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); + auto sort_buffer = c10::DeviceArray(*cuda_allocator, n_element); + auto sort_buffer_ptr = thrust::device_pointer_cast(sort_buffer.get()); + auto count_from_zero_iter = thrust::make_counting_iterator(int64_t{0}); + thrust::copy_n(policy, count_from_zero_iter, n_element, sort_buffer_ptr); + + + // Sort the input data. The original indices of the data are stored in + // sort_buffer_ptr + thrust::sort_by_key(policy, iter_begin, iter_end, sort_buffer_ptr); + + // Count # of unique elements via an inner product between adjacent elements. + // Add 1 if two neighboring element are not equal. + int unique = 1 + + thrust::inner_product( + policy, + iter_begin, + iter_end - 1, + iter_begin + 1, + 0, + thrust::plus(), + thrust::not_equal_to()); + + // Count frequency of each element + auto keys = c10::DeviceArray(*cuda_allocator, unique); + auto counts = c10::DeviceArray(*cuda_allocator, unique); + + auto keys_ptr = thrust::device_pointer_cast(keys.get()); + auto counts_ptr = thrust::device_pointer_cast(counts.get()); + + thrust::reduce_by_key( + policy, + iter_begin, + iter_end, + thrust::constant_iterator(1), + keys_ptr, + counts_ptr); + + // Find index of maximum count + auto it = thrust::max_element(policy, counts_ptr, counts_ptr + unique); + scalar_t mode = keys_ptr[it - counts_ptr]; + + // Find first index within which it occurs + auto position_iter = thrust::find(policy, iter_begin, iter_end, mode); + + // Translate to original non-sorted index + TORCH_INTERNAL_ASSERT(position_iter != iter_end); + int64_t index = sort_buffer_ptr[position_iter - iter_begin]; + return {mode, index}; + } +}; + +struct EqualsMode { + bool mode; + + C10_DEVICE bool operator()(const uint8_t x) { + return static_cast(x) == mode; + } +}; + +template <> +struct ModeImpl { + std::tuple operator()( + const bool *first, + const bool *last) { + at::cuda::ThrustAllocator thrust_allocator; + auto stream = at::cuda::getCurrentCUDAStream(); + auto policy = thrust::cuda::par(thrust_allocator).on(stream); + + // For bool, we can skip finding the unique elements since there + // are only two possible values. + + // See NOTE [Loading boolean values] + auto first_bytes = reinterpret_cast(first); + auto last_bytes = reinterpret_cast(last); + + const auto numel = last - first; + const auto num_true = thrust::count_if( + policy, + first_bytes, + last_bytes, + [] GPU_LAMBDA (uint8_t x) { + return static_cast(x); + } + ); + const auto num_false = (numel - num_true); + const auto mode = num_true > num_false; + + // Find first index within which it occurs + const auto position_iter = thrust::find_if( + policy, first_bytes, last_bytes, EqualsMode{mode}); + const int64_t index = position_iter - first_bytes; + return {mode, index}; + } +}; + template void calculate_mode( const TensorBase& values, @@ -26,9 +134,6 @@ void calculate_mode( const TensorBase& self, std::vector& position, int dim) { - at::cuda::ThrustAllocator thrust_allocator; - auto stream = at::cuda::getCurrentCUDAStream(); - auto policy = thrust::cuda::par(thrust_allocator).on(stream); TORCH_INTERNAL_ASSERT(self.is_contiguous()); @@ -47,53 +152,9 @@ void calculate_mode( scalar_t* iter_begin = data; scalar_t* iter_end = data + n_element; - auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); - auto sort_buffer = c10::DeviceArray(*cuda_allocator, n_element); - auto sort_buffer_ptr = thrust::device_pointer_cast(sort_buffer.get()); - auto count_from_zero_iter = thrust::make_counting_iterator(int64_t{0}); - thrust::copy_n(policy, count_from_zero_iter, n_element, sort_buffer_ptr); - - - // Sort the input data. The original indices of the data are stored in - // sort_buffer_ptr - thrust::sort_by_key(policy, iter_begin, iter_end, sort_buffer_ptr); - - // Count # of unique elements via an inner product between adjacent elements. - // Add 1 if two neighboring element are not equal. - int unique = 1 + - thrust::inner_product( - policy, - iter_begin, - iter_end - 1, - iter_begin + 1, - 0, - thrust::plus(), - thrust::not_equal_to()); - - // Count frequency of each element - auto keys = c10::DeviceArray(*cuda_allocator, unique); - auto counts = c10::DeviceArray(*cuda_allocator, unique); - - auto keys_ptr = thrust::device_pointer_cast(keys.get()); - auto counts_ptr = thrust::device_pointer_cast(counts.get()); - - thrust::reduce_by_key( - policy, - iter_begin, - iter_end, - thrust::constant_iterator(1), - keys_ptr, - counts_ptr); - - // Find index of maximum count - auto it = thrust::max_element(policy, counts_ptr, counts_ptr + unique); - scalar_t mode = keys_ptr[it - counts_ptr]; - - // Find first index within which it occurs - auto position_iter = thrust::find(policy, iter_begin, iter_end, mode); - - TORCH_INTERNAL_ASSERT(position_iter != iter_end); - int64_t index = sort_buffer_ptr[position_iter - iter_begin]; + scalar_t mode; + int64_t index; + std::tie(mode, index) = ModeImpl{}(iter_begin, iter_end); // Place mode, index in output scalar_t* values_data = values.data_ptr(); @@ -105,10 +166,11 @@ void calculate_mode( indices_data += ensure_nonempty_stride(indices, i) * pos; } + auto stream = at::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaMemcpyAsync( values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream)); //memcpy_and_sync will synchronize results - at::cuda::memcpy_and_sync(indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream); + at::cuda::memcpy_and_sync(indices_data, &index, sizeof(int64_t), cudaMemcpyHostToDevice, stream); } template diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cuh b/aten/src/ATen/native/cuda/TensorModeKernel.cuh index 93412ca36d6d1..c3220774ee202 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cuh +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cuh @@ -194,6 +194,9 @@ __device__ inline void bitonicSortKeys( // dimension as the innermost dim, such that we can get the particular slice for // a Tensor via its linear block dimension * the slice size. template +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070 +__launch_bounds__(1024, 1) +#endif __global__ void compute_mode( T* input, at::cuda::detail::TensorInfo values, @@ -229,10 +232,10 @@ __global__ void compute_mode( // Each thread loads up to two elements from the Tensor into shared memory if (tidx < sliceSize) { - smem[tidx] = input[linearOffset + tidx]; + smem[tidx] = c10::load(&input[linearOffset + tidx]); } if (stidx < sliceSize) { - smem[stidx] = input[linearOffset + stidx]; + smem[stidx] = c10::load(&input[linearOffset + stidx]); } // Next, we initialize a boolean region of the buffer, offset by the loaded @@ -393,11 +396,11 @@ __global__ void compute_mode( unsigned mode_index[2] = {0u, 0u}; if (tidx * 2 < sliceSize) { const unsigned idx = tidx * 2; - mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u; + mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; } if (tidx * 2 + 1 < sliceSize) { const unsigned idx = tidx * 2 + 1; - mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u; + mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; } struct MaxIndexOp { diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index 335d746294d0d..fb1c16a1ca126 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -135,8 +135,9 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { auto total_dims = in_tensor.dim(); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + at::ScalarType::ComplexHalf, in_tensor.scalar_type(), "roll_cuda", [&] { roll_cuda_kernel<<>>( diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu new file mode 100644 index 0000000000000..d27a74c19c960 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char acos_name[] = "acos"; +void acos_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto acos_string = jiterator_stringify( + template T acos(T a) { return std::acos(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "acos_name", [&]() { + jitted_gpu_kernel< + /*name=*/acos_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, acos_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "acos_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::acos(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "acos_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::acos(a); + }); + }); + } +} + +REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu new file mode 100644 index 0000000000000..f831e9e5b8710 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char acosh_name[] = "acosh"; +void acosh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if(at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto acosh_string = jiterator_stringify( + template + T acosh(T a) { + return std::acosh(a); + } + ); + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { + jitted_gpu_kernel< + /*name=*/ acosh_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, acosh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::acosh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + common_dtype, "acosh_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::acosh(a); + }); + }); + } +} + +REGISTER_DISPATCH(acosh_stub, &acosh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu new file mode 100644 index 0000000000000..fdabb67717741 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu @@ -0,0 +1,52 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char asin_name[] = "asin"; +void asin_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto asin_string = jiterator_stringify( + template T asin(T a) { return std::asin(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asin_name", [&]() { + jitted_gpu_kernel< + /*name=*/asin_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, asin_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asin_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::asin(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::asin(a); + }); + }); + } +} + +REGISTER_DISPATCH(asin_stub, &asin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu new file mode 100644 index 0000000000000..e1cf41b46db50 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char asinh_name[] = "asinh"; +void asinh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto asinh_string = jiterator_stringify( + template T asinh(T a) { return std::asinh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asinh_name", [&]() { + jitted_gpu_kernel< + /*name=*/asinh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, asinh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asinh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::asinh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "asinh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::asinh(a); + }); + }); + } +} + +REGISTER_DISPATCH(asinh_stub, &asinh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu new file mode 100644 index 0000000000000..4bbbb95d384f9 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char atan_name[] = "atan"; +void atan_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto atan_string = jiterator_stringify( + template + T atan(T a) { + return std::atan(a); + } + ); + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { + jitted_gpu_kernel< + /*name=*/ atan_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, atan_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::atan(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + common_dtype, "atan_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::atan(a); + }); + }); + } +} + +REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu new file mode 100644 index 0000000000000..461a0f042205d --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char atanh_name[] = "atanh"; +void atanh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto atanh_string = jiterator_stringify( + template T atanh(T a) { return std::atanh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "atanh_name", [&]() { + jitted_gpu_kernel< + /*name=*/atanh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, atanh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "atanh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::atanh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "atanh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::atanh(a); + }); + }); + } +} + +REGISTER_DISPATCH(atanh_stub, &atanh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu new file mode 100644 index 0000000000000..c246a1d332235 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char cos_name[] = "cos"; +void cos_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto cos_string = jiterator_stringify( + template T cos(T a) { return std::cos(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + jitted_gpu_kernel< + /*name=*/cos_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, cos_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::cos(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "cos_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); + }); + } +} + +REGISTER_DISPATCH(cos_stub, &cos_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu new file mode 100644 index 0000000000000..1d1f479843286 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char cosh_name[] = "cosh"; +void cosh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto cosh_string = jiterator_stringify( + template T cosh(T a) { return std::cosh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cosh_name", [&]() { + jitted_gpu_kernel< + /*name=*/cosh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, cosh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cosh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::cosh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "cosh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::cosh(a); + }); + }); + } +} + +REGISTER_DISPATCH(cosh_stub, &cosh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu deleted file mode 100644 index e223e1da7a236..0000000000000 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ /dev/null @@ -1,372 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at { namespace native { - -void acos_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - iter.common_dtype(), "acos_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::acos(a); - }); - }); -} - -const char asin_name[] = "asin"; -void asin_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto asin_string = jiterator_stringify( - template - T asin(T a) { - return std::asin(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() { - jitted_gpu_kernel< - /*name=*/ asin_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, asin_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::asin(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::asin(a); - }); - }); - } -} - -const char atan_name[] = "atan"; -void atan_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto atan_string = jiterator_stringify( - template - T atan(T a) { - return std::atan(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { - jitted_gpu_kernel< - /*name=*/ atan_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, atan_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::atan(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "atan_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::atan(a); - }); - }); - } -} - -const char sin_name[] = "sin"; -void sin_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto sin_string = jiterator_stringify( - template - T sin(T a) { - return std::sin(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() { - jitted_gpu_kernel< - /*name=*/ sin_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, sin_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::sin(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "sin_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::sin(a); - }); - }); - } -} - -const char cos_name[] = "cos"; -void cos_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto cos_string = jiterator_stringify( - template - T cos(T a) { - return std::cos(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() { - jitted_gpu_kernel< - /*name=*/ cos_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, cos_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::cos(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "cos_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::cos(a); - }); - }); - } -} - -const char sinh_name[] = "sinh"; -void sinh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto sinh_string = jiterator_stringify( - template - T sinh(T a) { - return std::sinh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ sinh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, sinh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::sinh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "sinh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::sinh(a); - }); - }); - } -} - -const char cosh_name[] = "cosh"; -void cosh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto cosh_string = jiterator_stringify( - template - T cosh(T a) { - return std::cosh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ cosh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, cosh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::cosh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "cosh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::cosh(a); - }); - }); - } -} - -const char tanh_name[] = "tanh"; -void tanh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto tanh_string = jiterator_stringify( - template - T tanh(T a) { - return std::tanh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ tanh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, tanh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::tanh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "tanh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::tanh(a); - }); - }); - } -} - -void acosh_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - iter.common_dtype(), "acosh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::acosh(a); - }); - }); -} - -void asinh_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - iter.common_dtype(), "asinh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::asinh(a); - }); - }); -} - -void atanh_kernel_cuda(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - iter.common_dtype(), "atanh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::atanh(a); - }); - }); -} - -const char tan_name[] = "tan"; -void tan_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto tan_string = jiterator_stringify( - template - T tan(T a) { - return std::tan(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() { - jitted_gpu_kernel< - /*name=*/ tan_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, tan_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::tan(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "tan_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::tan(a); - }); - }); - } -} - -REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda); -REGISTER_DISPATCH(acosh_stub, &acosh_kernel_cuda); -REGISTER_DISPATCH(asinh_stub, &asinh_kernel_cuda); -REGISTER_DISPATCH(atanh_stub, &atanh_kernel_cuda); -REGISTER_DISPATCH(asin_stub, &asin_kernel_cuda); -REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda); -REGISTER_DISPATCH(sin_stub, &sin_kernel_cuda); -REGISTER_DISPATCH(cos_stub, &cos_kernel_cuda); -REGISTER_DISPATCH(sinh_stub, &sinh_kernel_cuda); -REGISTER_DISPATCH(cosh_stub, &cosh_kernel_cuda); -REGISTER_DISPATCH(tanh_stub, &tanh_kernel_cuda); -REGISTER_DISPATCH(tan_stub, &tan_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu new file mode 100644 index 0000000000000..833ecdccc18c2 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char sin_name[] = "sin"; +void sin_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto sin_string = jiterator_stringify( + template T sin(T a) { return std::sin(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + jitted_gpu_kernel< + /*name=*/sin_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sin_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sin(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "sin_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sin(a); }); + }); + } +} + +REGISTER_DISPATCH(sin_stub, &sin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu new file mode 100644 index 0000000000000..fb806aa84b66e --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char sinh_name[] = "sinh"; +void sinh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto sinh_string = jiterator_stringify( + template T sinh(T a) { return std::sinh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sinh_name", [&]() { + jitted_gpu_kernel< + /*name=*/sinh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sinh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sinh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sinh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "sinh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::sinh(a); + }); + }); + } +} + +REGISTER_DISPATCH(sinh_stub, &sinh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu new file mode 100644 index 0000000000000..a57499b337237 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char tan_name[] = "tan"; +void tan_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto tan_string = jiterator_stringify( + template T tan(T a) { return std::tan(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tan_name", [&]() { + jitted_gpu_kernel< + /*name=*/tan_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, tan_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tan_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::tan(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "tan_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tan(a); }); + }); + } +} + +REGISTER_DISPATCH(tan_stub, &tan_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu new file mode 100644 index 0000000000000..ffaf36a028f61 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char tanh_name[] = "tanh"; +void tanh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto tanh_string = jiterator_stringify( + template T tanh(T a) { return std::tanh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tanh_name", [&]() { + jitted_gpu_kernel< + /*name=*/tanh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, tanh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tanh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::tanh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "tanh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::tanh(a); + }); + }); + } +} + +REGISTER_DISPATCH(tanh_stub, &tanh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 746bba7a66c5c..320534e6540c0 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -3,7 +3,8 @@ #include #include #include -#include + +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +34,6 @@ namespace native{ namespace { - template < typename policy_t, typename scalar_t, typename equal_t, typename not_equal_t @@ -48,7 +49,6 @@ std::tuple compute_unique( equal_t equal, not_equal_t not_equal ) { - // inverse indices Tensor inverse_indices; if (!return_inverse || num_inp == 0) { @@ -140,8 +140,8 @@ std::tuple unique_dim_cuda_template( thrust::sort(policy, indices_data, indices_data + num_inp, [=] __device__ (int64_t a, int64_t b) -> bool { for (int64_t i = 0; i < n; ++i) { - scalar_t lhs = input_flat_ptr[i + a * n]; - scalar_t rhs = input_flat_ptr[i + b * n]; + scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]); + scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]); if (lhs < rhs) { return true; } else if (lhs > rhs) { @@ -160,8 +160,8 @@ std::tuple unique_dim_cuda_template( return_inverse, return_counts, options, [=] __device__ (int64_t a, int64_t b) -> bool { for (int64_t i = 0; i < n; ++i) { - scalar_t lhs = input_flat_ptr[i + a * n]; - scalar_t rhs = input_flat_ptr[i + b * n]; + scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]); + scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]); if (lhs != rhs) { return false; } @@ -170,8 +170,8 @@ std::tuple unique_dim_cuda_template( }, [=] __device__ (int64_t a, int64_t b) -> int64_t { for (int64_t i = 0; i < n; ++i) { - scalar_t lhs = input_flat_ptr[i + a * n]; - scalar_t rhs = input_flat_ptr[i + b * n]; + scalar_t lhs = c10::load(&input_flat_ptr[i + a * n]); + scalar_t rhs = c10::load(&input_flat_ptr[i + b * n]); if (lhs != rhs) { return 1; } diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index cc19b96a77971..c5d40242221b9 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -4,7 +4,10 @@ #include #include #include -#include +#include + +#include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -19,10 +22,10 @@ namespace internal { namespace { -template +template __global__ void adjacent_difference_kernel( int64_t n, - const scalar_t* input, + InputIteratorT input, int* output) { CUDA_KERNEL_LOOP(i, n) { output[i] = i > 0 ? input[i] != input[i - 1] : 0; @@ -39,28 +42,44 @@ __global__ void scatter_kernel( } } +template +const scalar_t * wrap_input_iterator(const scalar_t *data) { + return data; +} + +struct LoadBoolOp { + __device__ bool operator()(uint8_t x) const { + return static_cast(x); + } +}; + +auto wrap_input_iterator(const bool *data) { + // See NOTE [Loading boolean values] + LoadBoolOp op; + return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( + reinterpret_cast(data), op); +} + // A variation of compute_unique (defined in Unique.cu) that doesn't allow // customizing equal and not_equal (CUB doesn't allow them). template -std::tuple compute_unique( +std::tuple compute_unique( const Tensor& sorted, const Tensor& sorted_indices, const bool return_inverse, const bool return_counts, const bool consecutive) { int64_t num_inp = sorted.numel(); - TORCH_CHECK( - num_inp <= INT_MAX, "num_inp ", num_inp, " is too big to for CUB"); auto options = sorted.options().dtype(kLong); - const scalar_t* data = sorted.data_ptr(); + auto data = wrap_input_iterator(sorted.data_ptr()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // inverse indices Tensor inverse_indices; - if (!return_inverse || num_inp == 0) { + if (!return_inverse) { inverse_indices = at::empty({0}, options); } else { - inverse_indices = at::empty({num_inp}, options); + inverse_indices = at::empty(sorted.sizes(), options); Tensor inv_loc = consecutive ? at::empty({num_inp}, options.dtype(kInt)) : inverse_indices; int* inv_loc_ptr = static_cast(inv_loc.data_ptr()); @@ -70,8 +89,8 @@ std::tuple compute_unique( int curDevice = -1; cudaGetDevice(&curDevice); cuda::getApplyGrid(num_inp, grid, curDevice); - adjacent_difference_kernel - <<>>(num_inp, data, inv_loc_ptr); + adjacent_difference_kernel<<>>( + num_inp, data, inv_loc_ptr); C10_CUDA_KERNEL_LAUNCH_CHECK(); Tensor inv_loc_out = @@ -114,8 +133,9 @@ std::tuple compute_unique( counts.resize_(num_out); } - return std::tuple( - data_out, inverse_indices, counts, num_out); + data_out.resize_(num_out); + return std::tuple( + data_out, inverse_indices, counts); } } // namespace @@ -123,54 +143,182 @@ std::tuple compute_unique( // This function (and compute_unique above) are defined in a separate file from // Unique.cu because for now ATen/cuda/cub.cuh can't be used together with // thrust in the same compilation unit. + template -std::tuple unique_cuda_template( - const Tensor& self, - const bool consecutive, - const bool return_inverse, - const bool return_counts) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +struct UniqueCub { + std::tuple operator() ( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto options = self.options().dtype(kLong); - int64_t num_inp = self.numel(); - Tensor sorted; - Tensor self_c = self.contiguous(); - if (consecutive) { - sorted = self_c; - } else { - sorted = at::empty({num_inp}, self.options()); + int64_t num_inp = self.numel(); + Tensor sorted; + if (consecutive) { + sorted = self; + } else { + sorted = at::empty(self.sizes(), self.options()); + } + scalar_t* sorted_data = sorted.data_ptr(); + + Tensor sorted_indices; + if (!return_inverse) { + if (!consecutive) { + cuda::cub::radix_sort_keys(self.data_ptr(), sorted_data, num_inp); + } + } else { + if (!consecutive) { + auto options = self.options().dtype(kLong); + Tensor range = at::arange(0, num_inp, options); + sorted_indices = at::empty({num_inp}, options); + cuda::cub::radix_sort_pairs( + self.data_ptr(), + sorted_data, + range.data_ptr(), + sorted_indices.data_ptr(), + num_inp); + } + } + + return compute_unique( + sorted, sorted_indices, return_inverse, return_counts, consecutive); } - scalar_t* sorted_data = sorted.data_ptr(); +}; - Tensor sorted_indices; - if (!return_inverse) { - if (!consecutive) { - cuda::cub::radix_sort_keys(self_c.data_ptr(), sorted_data, num_inp); +struct MapNumberOfTrueValues { + __device__ int operator()(uint8_t x) const { + return static_cast(x); + } +}; + +C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) +__global__ void unique_bool_write_inverse_indices( + const int numel, + const int *num_true_p, + const bool *self, + int64_t *inverse_indices_out) { + constexpr int false_idx = 0; + const int num_true = *num_true_p; + const int num_false = numel - num_true; + const int true_idx = num_false > 0; + + CUDA_KERNEL_LOOP(i, numel) { + const auto value = c10::load(&self[i]); + inverse_indices_out[i] = value ? true_idx : false_idx; + } +} + +C10_LAUNCH_BOUNDS_1(1) +__global__ void unique_bool_write_output( + const int numel, + const int *num_true_p, + bool *values_out, + int64_t *counts_out) { + constexpr int false_idx = 0; + const int num_true = *num_true_p; + const int num_false = numel - num_true; + const int true_idx = num_false > 0; + + if (blockIdx.x == 0 && threadIdx.x == 0) { + if (num_false > 0) { + values_out[false_idx] = false; + counts_out[false_idx] = num_false; } - } else { - if (!consecutive) { - Tensor range = at::arange(0, num_inp, options); - sorted_indices = at::empty({num_inp}, options); - cuda::cub::radix_sort_pairs( - self_c.data_ptr(), - sorted_data, - range.data_ptr(), - sorted_indices.data_ptr(), - num_inp); + if (num_true > 0) { + values_out[true_idx] = true; + counts_out[true_idx] = num_true; } } +} - Tensor output, inverse_indices, counts; - int64_t num_out; - std::tie(output, inverse_indices, counts, num_out) = compute_unique( - sorted, sorted_indices, return_inverse, return_counts, consecutive); - output.resize_(num_out); +template <> +struct UniqueCub { + + std::tuple operator() ( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int64_t num_inp = self.numel(); + + Tensor output, inverse_indices, counts; + if (consecutive) { + Tensor sorted_indices; + return compute_unique( + self, sorted_indices, return_inverse, return_counts, consecutive); + } + + // Instead of sorting, we use a reduction to find the number of + // true values and from that we can infer the number of false. + // If either has a count of zero, we omit it from the output. + auto allocator = at::cuda::getCUDADeviceAllocator(); + c10::DeviceArray tmp_num_true(*allocator, 1); + + const bool* self_data = self.data_ptr(); + MapNumberOfTrueValues op; + NO_ROCM(at_cuda_detail)::cub::TransformInputIterator + data_iter(reinterpret_cast(self_data), op); + at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp, + NO_ROCM(at_cuda_detail)::cub::Sum{}, 0); + + auto options = self.options(); + output = at::empty({2}, self.options()); + counts = at::empty({2}, options.dtype(kLong)); + + unique_bool_write_output<<<1, 1, 0, stream>>>( + num_inp, + tmp_num_true.get(), + output.data_ptr(), + counts.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); - if (return_inverse) { - inverse_indices.resize_(self.sizes()); + if (return_inverse) { + using namespace at::cuda::detail; + inverse_indices = at::empty(self.sizes(), options.dtype(kLong)); + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(num_inp); + unique_bool_write_inverse_indices<<>>( + num_inp, + tmp_num_true.get(), + self_data, + return_inverse ? inverse_indices.data_ptr() : nullptr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + + // Final sync to fix the ouput tensors shape + int num_true = 0; + at::cuda::memcpy_and_sync(&num_true, tmp_num_true.get(), sizeof(int), + cudaMemcpyDeviceToHost, stream); + const int num_false = num_inp - num_true; + const int num_out = ((num_true > 0) + (num_false > 0)); + output.resize_({num_out}); + counts.resize_({num_out}); + + return std::tuple(output, inverse_indices, counts); + } +}; + +template +std::tuple unique_cuda_template( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts) { + auto num_inp = self.numel(); + TORCH_CHECK( + num_inp <= INT_MAX, "num_inp ", num_inp, " is too big to for CUB"); + if (num_inp == 0) { + Tensor output = at::empty({0}, self.options()); + Tensor inverse_indices = at::empty(self.sizes(), self.options().dtype(kLong)); + Tensor counts = at::empty({0}, self.options().dtype(kLong)); + return std::tuple(output, inverse_indices, counts); } - return std::tuple(output, inverse_indices, counts); + auto self_c = self.expect_contiguous(); + return UniqueCub{}(*self_c, consecutive, return_inverse, return_counts); } #define INSTANTIATE_UNIQUE_CUDA_TEMPLATE(TYPE) \ diff --git a/aten/src/ATen/native/cuda/ValidateCompressedIndicesKernel.cu b/aten/src/ATen/native/cuda/ValidateCompressedIndicesKernel.cu new file mode 100644 index 0000000000000..cde8364703596 --- /dev/null +++ b/aten/src/ATen/native/cuda/ValidateCompressedIndicesKernel.cu @@ -0,0 +1,30 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +namespace at { +namespace native { + +namespace { + +template +struct CUDAKernelLauncher { + static void launch(TensorIteratorBase& iter, const func_t& f) { + gpu_kernel(iter, f); + } +}; + +} + +void _validate_compressed_sparse_indices_cuda( + const bool is_crow, + const Tensor& cidx, + const Tensor& idx, + const int64_t cdim, + const int64_t dim, + const int64_t nnz) { + validate_compressed_sparse_indices_kernel( + is_crow, cidx, idx, cdim, dim, nnz); +} + +}} diff --git a/aten/src/ATen/native/cuda/airy_ai.cu b/aten/src/ATen/native/cuda/airy_ai.cu new file mode 100644 index 0000000000000..335894807bf9d --- /dev/null +++ b/aten/src/ATen/native/cuda/airy_ai.cu @@ -0,0 +1,43 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + namespace native { + namespace { + const char airy_ai_name[] = "airy_ai_forward"; + + void airy_ai_kernel_cuda(TensorIteratorBase& iterator) { +#if AT_USE_JITERATOR() + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "airy_ai_cuda", [&]() { + jitted_gpu_kernel(iterator, airy_ai_string); + }); +#else + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "airy_ai_cuda", [&]() { + gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return airy_ai_forward(a); + }); + }); +#endif // AT_USE_JITERATOR() + } + } + + REGISTER_DISPATCH(special_airy_ai_stub, &airy_ai_kernel_cuda); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 10270bd2dbf1e..673ea9f476e46 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -726,6 +727,36 @@ void __inline__ initializeCudaContext() { } } +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + bool vectorized, + int vec_size, + bool return_by_ref) { + c10::SmallVector extra_args_typenames(desc.extra_args_types.size()); + for (auto i : c10::irange(extra_args_typenames.size())) { + extra_args_typenames[i] = typeName(desc.extra_args_types[i]); + } + + return generate_code( + desc.nInputs, + desc.nOutputs, + desc.f, + desc.name, + typeName(desc.f_inputs_type), + typeName(toOpMathType(desc.f_inputs_type)), + typeName(desc.result_type), + contiguous, + dynamic_casting, + scalar_pos, + extra_args_typenames, + vectorized, + vec_size, + return_by_ref); +} + //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh #define THREAD_WORK_SIZE 4 constexpr int thread_work_size = THREAD_WORK_SIZE; @@ -1056,6 +1087,31 @@ std::string load_code_template(const std::string& path) { return s; } +std::string generate_reduction_code( + const KernelDescriptor &desc, + int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen) { + TORCH_INTERNAL_ASSERT(desc.nInputs == 1); + TORCH_INTERNAL_ASSERT(desc.extra_args_types.size() == 0); + + return generate_reduction_code( + desc.nOutputs, + desc.f, + desc.name, + vt0, + typeName(desc.f_inputs_type), + typeName(toOpMathType(desc.f_inputs_type)), + typeName(desc.result_type), + contiguous, + vectorized, + vec_size, + max_threads_codegen + ); +} + std::string generate_reduction_code( int nOutputs, const std::string& func, diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h index 129ad3e0f3834..13aa723db2756 100644 --- a/aten/src/ATen/native/cuda/jit_utils.h +++ b/aten/src/ATen/native/cuda/jit_utils.h @@ -19,6 +19,70 @@ struct NvrtcFunction { CUfunction function = nullptr; }; +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + int result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + std::string generate_code( int nInputs, int nOutputs, @@ -35,6 +99,15 @@ std::string generate_code( int vec_size=0, bool return_by_ref=false); +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + std::string generate_reduction_code( int nOutputs, const std::string& func, @@ -48,6 +121,14 @@ std::string generate_reduction_code( int vec_size, int max_threads_codegen); +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + NvrtcFunction jit_pwise_function( const std::string& code, const std::string& kernel_name); @@ -108,17 +189,12 @@ template <> inline std::string typeName(){ } #define TYPE_NAME_CASE(ctype, scalartype) \ - case ScalarType::scalartype: return std::string(#ctype); + case ScalarType::scalartype: return typeName(); inline std::string typeName(ScalarType t) { switch (t) { - AT_FORALL_SCALAR_TYPES(TYPE_NAME_CASE) - case ScalarType::Bool : return "bool"; - case ScalarType::Half : return "at::Half"; - case ScalarType::BFloat16 : return "at::BFloat16"; - case ScalarType::ComplexFloat : return "std::complex"; - case ScalarType::ComplexDouble : return "std::complex"; - default: - TORCH_CHECK(false, "invalid type for jiterator"); + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); } } #undef TYPE_NAME_CASE diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index a327b7dc3ed7a..96d700c761ebf 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -940,7 +940,7 @@ std::tuple layer_norm_backward_cuda( c10::nullopt /* pin_memory */, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - if (M > 0) { + if (M > 0 && N > 0) { LayerNormBackwardKernelImpl( dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta); } diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 53944db0fb1a2..320c799f23bce 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #endif #if AT_MAGMA_ENABLED() @@ -2279,11 +2280,8 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec apply_magma_eigh(eigvals_working_copy, self_working_copy, infos, upper, eigenvectors); }); - if (self.dim() > 2) { - batchCheckErrors(infos, "symeig_cuda"); - } else { - singleCheckErrors(infos.item().toInt(), "symeig_cuda"); - } + at::_linalg_check_errors(infos, "symeig", self.dim() == 2); + if (eigenvectors) { return std::tuple(eigvals_working_copy.to(self.device()), self_working_copy); } else { @@ -2355,7 +2353,7 @@ REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); template static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs, - int64_t *info_ptr) { + int* info_ptr) { #if !AT_MAGMA_ENABLED() TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " "Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA."); @@ -2425,11 +2423,11 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector ? at::empty_strided({n, n}, {1, n}, options) : Tensor(); - int64_t info; + auto infos = at::zeros({}, self_working_copy.options().dtype(kInt)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{ - apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info); + apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, infos.data_ptr()); }); - singleCheckErrors(info, "eig_cuda"); + at::_linalg_check_errors(infos, "eig", /*is_matrix*/true); return std::tuple(out_eigvals, out_eigvecs); } @@ -2847,22 +2845,22 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor return; } - auto batch_size = batchCount(B); - auto m = LU.size(-2); - auto b2 = B.size(-1); + auto b = batchCount(B); + auto n = LU.size(-2); + auto k = B.size(-1); // magma implementation of LU solve cannot handle a b tensor with last dim > 1024 // See https://bitbucket.org/icl/magma/issues/19/dgesv_batched-dgetrs_batched-fails-for - bool over_batched_magma_dim_limit = b2 > 1024; + bool over_batched_magma_dim_limit = k > 1024; // heuristics determined from tests dicussed in https://github.com/pytorch/pytorch/pull/72935 // Computes X = U^{-1}L^{-1}P^T B via triangular solves // Helps mitigating the bugs in magma - auto lu_solve_triangular = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, const TransposeType trans) { + auto lu_solve_triangular = [n](const Tensor& LU, const Tensor& pivots, const Tensor& B, const TransposeType trans) { auto LU_ = maybe_expand_lu(B, LU); auto pivots_ = maybe_expand_pivots(B, pivots); // LAPACK / cublas / etc returns the permutation in an odd format // Here we transform it to a vector representing a permutation, i.e. a (batch of) vectors st. P(i) = j - auto perm = at::arange(m, pivots_->options().dtype(kLong)).expand(pivots_->sizes()).contiguous(); + auto perm = at::arange(n, pivots_->options().dtype(kLong)).expand(pivots_->sizes()).contiguous(); auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) .check_all_same_dtype(false) @@ -2871,14 +2869,14 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor .add_output(perm) .add_input(*pivots_) .build(); - unpack_pivots_stub(pivots_->device().type(), iter, m); + unpack_pivots_stub(pivots_->device().type(), iter, n); if (trans == TransposeType::NoTranspose) { // Get the inverse permutation // This is an insertion sort, and it's equivalent to // perm = at::argsort(perm); // but more parallelisable and O(n), exploiting that perm is a permutation - auto id_perm = at::arange(m, perm.options()).expand(perm.sizes()); + auto id_perm = at::arange(n, perm.options()).expand(perm.sizes()); auto inv_perm = perm.scatter(-1, perm, id_perm); // B1 = P^T @ B (must be done out-of-place as B is both source and target) auto B1 = B.scatter(-2, inv_perm.unsqueeze(-1).expand_as(B), B); @@ -2916,7 +2914,7 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor auto preferred_backend = at::globalContext().linalgPreferredBackend(); #ifdef USE_CUSOLVER if (preferred_backend == at::LinalgBackend::Cusolver) { - if (batch_size <= 2 && m >= 64) { + if (b <= 2 && n >= 64) { lu_solve_looped_cusolver(LU, pivots, B, trans); } else { lu_solve_batched_cublas_fn(LU, pivots, B, trans); @@ -2935,37 +2933,75 @@ static void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor return; } - // Summary: In most cases we use cublas / cusolver - // MAGMA is faster for large matrices, but it is buggy for trans != NoTranspose or for large batches. - // LU solve is fast in some cases when adjoint=True + // Heuristic + //if (n == k) { + // if (k <= 16) batched_cublas + // else solve_triag + //} else { + //if (n <= 8) { + // if (k >= 256 && NoTranspose) batched_magma + // else batched_cusolver + //} else if (n <= 32) { + // b <= 2 looped_cusolver + // k <= 8 batched_cusolver + // solve_triag + //} else if (n <= 64) { + // b <= 2 && (k <= 64 || adjoint) looped_cusolver + // k <= 8 batched_cusolver + // solve_triag + //} else if (n <= 128) { + // if (b <= 2 && k <= 2) looped_cusolver + // else if (k <= 2) batched_cusolver + // else solve_triag + //} else { // n > 128 + // solve_triag + //} + //} + + // Particular case when multiplying A^{-1}B where B is square + // In this case doing two triangular solves is almost always fastest + if (n == k) { #ifdef CUDART_VERSION -#ifdef USE_CUSOLVER - if (batch_size <= 2 && m >= 64) { - lu_solve_looped_cusolver(LU, pivots, B, trans); - return; - } -#endif // ifdef USE_CUSOLVER - if (trans != TransposeType::NoTranspose && m <= 2 && batch_size >= 128) { + if (n <= 16) { + lu_solve_batched_cublas_fn(LU, pivots, B, trans); + return; + } +#endif lu_solve_triangular(LU, pivots, B, trans); + return; } -#if AT_MAGMA_ENABLED() - else if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && m >= 256 && batch_size >= 128) { + +#ifdef CUDART_VERSION +#ifdef USE_CUSOLVER +if (n <= 8) { + if (use_magma_ && !over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && k >= 256) { lu_solve_batched_magma_fn(LU, pivots, B, trans); - } -#endif - else { + } else { lu_solve_batched_cublas_fn(LU, pivots, B, trans); } -#else - // If it's not buggy and it's faster than solve triangular (large matrix and large batch regime) - // we use batched_magma, otherwise, we resort to two triangular solves - // For trans != TransposeType::NoTranspose lu_solve_triangular is faster anyway - if (!over_batched_magma_dim_limit && trans == TransposeType::NoTranspose && m >= 256 && batch_size >= 128) { - lu_solve_batched_magma_fn(LU, pivots, B, trans); +} else if (n <= 64) { + if (b <= 2 && (k <= 64 || trans != TransposeType::NoTranspose || n <= 32)) { + lu_solve_looped_cusolver(LU, pivots, B, trans); + } else if (k <= 8) { + lu_solve_batched_cublas_fn(LU, pivots, B, trans); + } else { + lu_solve_triangular(LU, pivots, B, trans); } - else { +} else if (n <= 128) { + if (b <= 2 && k <= 2) { + lu_solve_looped_cusolver(LU, pivots, B, trans); + } else if (k <= 2) { + lu_solve_batched_cublas_fn(LU, pivots, B, trans); + } else { lu_solve_triangular(LU, pivots, B, trans); } +} else { // n > 128 + lu_solve_triangular(LU, pivots, B, trans); +} +#endif // ifdef USE_CUSOLVER +#else // No cublas or cusolver + // lu_solve_triangular is almost always best + lu_solve_triangular(LU, pivots, B, trans); #endif // ifdef CUDART_VERSION } diff --git a/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu b/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu new file mode 100644 index 0000000000000..4412ab439c145 --- /dev/null +++ b/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu @@ -0,0 +1,43 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + namespace native { + namespace { + const char scaled_modified_bessel_k0_name[] = "scaled_modified_bessel_k0_forward"; + + void scaled_modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) { +#if AT_USE_JITERATOR() + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k0_cuda", [&]() { + jitted_gpu_kernel(iterator, scaled_modified_bessel_k0_string); + }); +#else + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k0_cuda", [&]() { + gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return scaled_modified_bessel_k0_forward(a); + }); + }); +#endif // AT_USE_JITERATOR() + } + } + + REGISTER_DISPATCH(special_scaled_modified_bessel_k0_stub, &scaled_modified_bessel_k0_kernel_cuda); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu b/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu new file mode 100644 index 0000000000000..093d954e9afc7 --- /dev/null +++ b/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu @@ -0,0 +1,43 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + namespace native { + namespace { + const char scaled_modified_bessel_k1_name[] = "scaled_modified_bessel_k1_forward"; + + void scaled_modified_bessel_k1_kernel_cuda(TensorIteratorBase& iterator) { +#if AT_USE_JITERATOR() + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k1_cuda", [&]() { + jitted_gpu_kernel(iterator, scaled_modified_bessel_k1_string); + }); +#else + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "scaled_modified_bessel_k1_cuda", [&]() { + gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return scaled_modified_bessel_k1_forward(a); + }); + }); +#endif // AT_USE_JITERATOR() + } + } + + REGISTER_DISPATCH(special_scaled_modified_bessel_k1_stub, &scaled_modified_bessel_k1_kernel_cuda); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/spherical_bessel_j0.cu b/aten/src/ATen/native/cuda/spherical_bessel_j0.cu new file mode 100644 index 0000000000000..3e7df4a16ec9b --- /dev/null +++ b/aten/src/ATen/native/cuda/spherical_bessel_j0.cu @@ -0,0 +1,43 @@ +#define TORCH_ASSERT_NO_OPERATORS + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + namespace native { + namespace { + const char spherical_bessel_j0_name[] = "spherical_bessel_j0_forward"; + + void spherical_bessel_j0_kernel_cuda(TensorIteratorBase& iterator) { +#if AT_USE_JITERATOR() + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "spherical_bessel_j0_cuda", [&]() { + jitted_gpu_kernel(iterator, spherical_bessel_j0_string); + }); +#else + AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "spherical_bessel_j0_cuda", [&]() { + gpu_kernel(iterator, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return spherical_bessel_j0_forward(a); + }); + }); +#endif // AT_USE_JITERATOR() + } + } + + REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_cuda); + } // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index 6968548b0e0e5..9f921faf0320d 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -63,6 +63,7 @@ namespace at { namespace native { std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params) { out << "ConvolutionParams \n" + << " memory_format = " << params.memory_format << "\n" << " data_type = " << cudnnTypeToString(params.dataType) << "\n" << " padding = " << ArrayRef{params.padding} << "\n" << " stride = " << ArrayRef{params.stride} << "\n" @@ -83,7 +84,7 @@ void setConvolutionParams( ConvolutionParams* params, const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic, bool allow_tf32) { + int64_t groups, bool deterministic, bool allow_tf32, at::MemoryFormat memory_format) { cudnnDataType_t dataType = getCudnnDataType(input); memset(params, 0, sizeof(ConvolutionParams)); @@ -91,7 +92,7 @@ void setConvolutionParams( params->dataType = dataType; // ASSERT(weight.dim() == input.dim()) params->input_dim = input.dim(); - params->memory_format = input.suggest_memory_format(); + params->memory_format = memory_format; for (int i = 0; i != params->input_dim; ++i) { params->input_size[i] = (int) input.sizes()[i]; params->weight_size[i] = (int) weight.sizes()[i]; diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h index 9ee5bfb3f9e6c..fbcf667f40fc3 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.h +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -48,7 +48,7 @@ void setConvolutionParams( ConvolutionParams* params, const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic, bool allow_tf32); + int64_t groups, bool deterministic, bool allow_tf32, at::MemoryFormat memory_format); std::string repro_from_args(const ConvolutionParams& args); diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index a2ff4839a40cb..5225fff3bc234 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -628,8 +628,8 @@ void raw_cudnn_convolution_forward_out_32bit( ConvolutionArgs args{ input, output, weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32); at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight); + setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format); args.idesc.set(input, memory_format); args.wdesc.set(weight, memory_format, 0); args.odesc.set(output, memory_format); @@ -692,8 +692,8 @@ void raw_cudnn_convolution_backward_input_out_32bit( ConvolutionArgs args{ grad_input, grad_output, weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic, allow_tf32); at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(grad_input, weight); + setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format); args.idesc.set(grad_input, memory_format); args.wdesc.set(weight, memory_format, 0); args.odesc.set(grad_output, memory_format); @@ -755,8 +755,8 @@ void raw_cudnn_convolution_backward_weight_out_32bit( ConvolutionArgs args{ input, grad_output, grad_weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic, allow_tf32); at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, grad_weight); + setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format); args.idesc.set(input, memory_format); args.wdesc.set(grad_weight, memory_format, 0); args.odesc.set(grad_output, memory_format); @@ -868,6 +868,7 @@ void raw_cudnn_convolution_add_relu_out_v7( auto dataType = getCudnnDataType(input); ConvolutionArgs args{input, output, weight}; args.handle = getCudnnHandle(); + at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight); setConvolutionParams( &args.params, input, @@ -877,8 +878,8 @@ void raw_cudnn_convolution_add_relu_out_v7( dilation, groups, deterministic, - allow_tf32); - at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight); + allow_tf32, + memory_format); args.idesc.set(input, memory_format); args.wdesc.set(weight, memory_format, 0); args.odesc.set(output, memory_format); diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 843fb5297050f..7d5664b12cf51 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -152,7 +152,7 @@ BenchmarkCache benchmark_cache_fus // would not be a POD anymore. void setCacheKey(CacheKey& key, const cudnnBackendDescriptorType_t operation, const Tensor& y, const Tensor& x, const Tensor& w, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, int64_t groups, bool deterministic, bool allow_tf32) { memset(&key, 0, sizeof(key)); - setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32); + setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, x.suggest_memory_format()); key.operation = operation; key.x_alignment = getAlignment(x); key.y_alignment = getAlignment(y); @@ -161,7 +161,7 @@ void setCacheKey(CacheKey& key, const cudnnBackendDescriptorType_t operation, co void setCacheKeyFused(CacheKeyFused& key, const Tensor& y, const Tensor& x, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, int64_t groups, bool deterministic, bool allow_tf32) { memset(&key, 0, sizeof(key)); - setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32); + setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, x.suggest_memory_format()); key.x_alignment = getAlignment(x); key.y_alignment = getAlignment(y); key.w_alignment = getAlignment(w); @@ -344,7 +344,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera remove_invalid = true; } } - if (remove_invalid) { + if (remove_invalid || max_plans) { cudnn_frontend::executionPlans_t new_valid_plans; unsigned int plan_count = 0; for (auto &plan : valid_plans) { @@ -370,7 +370,8 @@ auto get_plans_from_find(const cudnnHandle_t handle, const cudnnBackendDescripto cudnn_frontend::executionPlans_t valid_plans; c10::DeviceGuard g(x.options().device()); at::DataPtr workspace_ptr; - generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr); + auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN(); + generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit); auto variantPack = cudnn_frontend::VariantPackBuilder() .setDataPointers(3, data_ptrs) .setUids(3, uids) @@ -400,7 +401,8 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle, cudnn_frontend::executionPlans_t valid_plans; c10::DeviceGuard g(x.options().device()); at::DataPtr workspace_ptr; - generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr); + auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN(); + generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit); auto variantPack = cudnn_frontend::VariantPackBuilder() .setDataPointers(5, data_ptrs) .setUids(5, uids) diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h index d7a76f77a314e..3fcc84173d396 100644 --- a/aten/src/ATen/native/metal/MetalShaders.h +++ b/aten/src/ATen/native/metal/MetalShaders.h @@ -421,6 +421,35 @@ kernel void hardswish(texture2d_array in_arr[[texture(0), fu } } +constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4); +constant bool hardshrink_is_tex = !hardshrink_is_arr; +kernel void hardshrink(texture2d_array in_arr[[texture(0), function_constant(hardshrink_is_arr)]], + texture2d in_tex[[texture(0), function_constant(hardshrink_is_tex)]], + texture2d_array out_arr[[texture(1), function_constant(hardshrink_is_arr)]], + texture2d out_tex[[texture(1), function_constant(hardshrink_is_tex)]], + ushort3 gid[[thread_position_in_grid]]) { + const ushort oH = ushort_arg_2; + const ushort oW = ushort_arg_3; + const half lambda = (half)float_arg_0; + if (gid.x >= oW || gid.y >= oH) { + return; + } + ushort2 gid_ = gid.xy; + if (hardshrink_is_arr) { + half4 value = in_arr.read(gid_, gid.z); + half4 mask1 = half4(value <= lambda); + half4 mask2 = half4(value >= -lambda); + half4 outval = (1 - mask1)*value + (1 - mask2)*value; + out_arr.write(outval, gid_, gid.z); + } else { + half4 value = in_tex.read(gid_); + half4 mask1 = half4(value <= lambda); + half4 mask2 = half4(value >= -lambda); + half4 outval = (1 - mask1)*value + (1 - mask2)*value; + out_tex.write(outval, gid_); + } +} + constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4); constant bool leaky_relu_is_tex = !leaky_relu_is_arr; kernel void leaky_relu(texture2d_array in_arr[[texture(0), function_constant(leaky_relu_is_arr)]], diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h index dac76535764ff..a8698a3daad58 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h @@ -42,6 +42,8 @@ bool test_sigmoid(); bool test_hardsigmoid(); bool test_hardswish_(); bool test_hardswish(); +bool test_hardshrink_(); +bool test_hardshrink(); bool test_leaky_relu_(); bool test_leaky_relu(); bool test_upsampling_nearest2d_vec(); diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index 38a4fdefe53d8..91e93858f56c7 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -25,6 +25,44 @@ bool checkRtol(const at::Tensor& diff, const std::vector inputs) { } return diff.abs().max().item() < (0.01 + 2e-2 * maxValue); } + +bool checkHardShrink(const at::Tensor& ref, const at::Tensor& out, const float clamp_thresh) { + float* ref_ptr = ref.data_ptr(); + float* out_ptr = out.data_ptr(); + float ref_max = ref.abs().max().item(); + float out_max = out.abs().max().item(); + float max_val = std::fmax(ref_max, out_max); + float kTolerance = 1e-2; + + float abs_clamp_thresh = std::abs(clamp_thresh); + + for (int i = 0; i < ref.numel(); ++i) { + float ref_val = ref_ptr[i]; + float out_val = out_ptr[i]; + + float abs_diff = std::abs(ref_val - out_val); + + // For values near the clamp threshold, results may be ambiguous. + float distance_from_thresh = std::abs(std::abs(ref_val) - abs_clamp_thresh); + if (distance_from_thresh < kTolerance * abs_clamp_thresh) { + if (out_val != 0.0f) { + if (abs_diff >= kTolerance * max_val) { + return false; + } + } + } + else if (std::abs(ref_val) < std::abs(abs_clamp_thresh)) { + if (out_val != 0.0f) { + return false; + } + } + else if (abs_diff >= kTolerance * max_val) { + return false; + } + } + return true; +} + bool almostEqual(const at::Tensor& a, const at::Tensor& b) { return checkRtol(a - b, {a, b}) && a.strides().vec() == b.strides().vec(); } @@ -274,6 +312,44 @@ bool test_hardswish() { }); } +bool test_hardshrink_() { + __block std::vector size{3, 3, 44, 44}; + bool result = true; + for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) { + bool b = TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = + (at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; + auto X2 = X.metal(); + auto Y1 = X.hardshrink(lambd_value); + auto Y2 = X2.hardshrink(lambd_value).cpu(); + return checkHardShrink(Y1, Y2, lambd_value); + }); + if (!b) { + result = false; + } + } + return result; +} + +bool test_hardshrink() { + __block std::vector size{3, 3, 44, 44}; + bool result = true; + for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) { + bool b = TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = + (at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; + auto X2 = X.metal(); + auto Y1 = at::hardshrink(X, lambd_value); + auto Y2 = at::hardshrink(X2, lambd_value).cpu(); + return checkHardShrink(Y1, Y2, lambd_value); + }); + if (!b) { + result = false; + } + } + return result; +} + bool test_leaky_relu_() { __block std::vector size{3, 3, 44, 44}; return TEST(size, __PRETTY_FUNCTION__, ^bool { diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm index 71f3558459c5e..5932240df2e17 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm @@ -70,6 +70,8 @@ - (void)registerTests { REG_TEST("test_hardsigmoid", test_hardsigmoid); REG_TEST("test_hardswish_", test_hardswish_); REG_TEST("test_hardswish", test_hardswish); + REG_TEST("test_hardshrink_", test_hardshrink_); + REG_TEST("test_hardshrink", test_hardshrink); REG_TEST("test_leaky_relu_", test_leaky_relu_); REG_TEST("test_leaky_relu", test_leaky_relu); REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec); diff --git a/aten/src/ATen/native/metal/ops/MetalHardshrink.mm b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm new file mode 100644 index 0000000000000..9727680704074 --- /dev/null +++ b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm @@ -0,0 +1,93 @@ +#include +#import +#import +#import +#import +#import +#import +#import +#import +#include + +namespace at { +namespace native { +namespace metal { + +using MetalTensorImpl = at::MetalTensorImpl; + +Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) { + float l = lambda.toFloat(); + MPSImage* X = imageFromTensor(input); + MetalCommandBuffer* commandBuffer = getCommandBuffer(input); + IntArrayRef outputSize = input.sizes(); + std::vector imageSize = computeImageSize(outputSize); + MPSImage* Y = createTemporaryImage(commandBuffer, imageSize); + id encoder = + [commandBuffer.buffer computeCommandEncoder]; + id state = + [[MetalContext sharedInstance] specializedPipelineState:"hardshrink" + Constants:@[ + @(X.numberOfImages), + @(X.featureChannels), + @(X.height), + @(X.width), + @(l) + ]]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl(); + MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle(); + implStorage.texture()->setImage(Y); + return input; +} + +Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { + float l = lambda.toFloat(); + MPSImage* X = imageFromTensor(input); + IntArrayRef outputSize = input.sizes(); + MetalTensorImplStorage mt{outputSize.vec()}; + MetalCommandBuffer* commandBuffer = getCommandBuffer(input); + mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer); + MPSImage* Y = mt.texture()->image(); + id encoder = + [commandBuffer.buffer computeCommandEncoder]; + id state = + [[MetalContext sharedInstance] specializedPipelineState:"hardshrink" + Constants:@[ + @(X.numberOfImages), + @(X.featureChannels), + @(X.height), + @(X.width), + @(l) + ]]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + + auto output = makeTensor(std::move(mt), input.options()); + return output; +} + +TORCH_LIBRARY_IMPL(aten, Metal, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_)); + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink)); +}; + +} +} +} diff --git a/aten/src/ATen/native/metal/ops/MetalReduce.mm b/aten/src/ATen/native/metal/ops/MetalReduce.mm index f2c1bc84287db..b0da375809b87 100644 --- a/aten/src/ATen/native/metal/ops/MetalReduce.mm +++ b/aten/src/ATen/native/metal/ops/MetalReduce.mm @@ -30,7 +30,7 @@ Tensor wrapper_mean_dim( const Tensor& input, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, c10::optional dtype) { if (@available(iOS 11.3, *)) { @@ -39,18 +39,21 @@ Tensor wrapper_mean_dim( TORCH_CHECK(imageSize.size() == 4); // TODO: [T87340633] Support reducing the batch dimension TORCH_CHECK(imageSize[0] == 1); - auto mask = make_dim_mask(dims, input.dim()); + auto mask = make_dim_mask(opt_dims, input.dim()); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); MPSImage* Y = nil; - for (int dim : dims) { - imageSize[dim] = 1; - MPSNNReduceUnary* kernel = kernelForReducedDim(dim); - if (kernel) { - Y = createTemporaryImage(commandBuffer, imageSize); - [kernel encodeToCommandBuffer:commandBuffer.buffer - sourceImage:X - destinationImage:Y]; - X = Y; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + for (int dim : dims) { + imageSize[dim] = 1; + MPSNNReduceUnary* kernel = kernelForReducedDim(dim); + if (kernel) { + Y = createTemporaryImage(commandBuffer, imageSize); + [kernel encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + X = Y; + } } } MetalTensorImplStorage mt{imageSize}; diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index a2489e42e1852..0096a1cda6743 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -43,6 +43,81 @@ REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub); namespace at { namespace native { +// follow check rules from native/Convolution.cpp without transpose supported +static void check_shape_forward(const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const IntArrayRef& padding, + const IntArrayRef& stride, + const IntArrayRef& dilation, + const int64_t groups) { +#define MKLDNN_CONV_ARG_CHECK(IT, OP) std::any_of(IT.begin(), IT.end(), [](auto x) { return x OP 0; }) + auto is_padding_neg = MKLDNN_CONV_ARG_CHECK(padding, <); + auto is_stride_nonpos = MKLDNN_CONV_ARG_CHECK(stride, <=); + auto is_dilation_nonpos = MKLDNN_CONV_ARG_CHECK(dilation, <=); +#undef MKLDNN_CONV_ARG_CHECK + TORCH_CHECK(!is_padding_neg, "negative padding is not supported"); + TORCH_CHECK(!is_stride_nonpos, "non-positive stride is not supported"); + TORCH_CHECK(!is_dilation_nonpos, "non-positive dilation is not supported"); + TORCH_CHECK(groups > 0, "non-positive groups is not supported"); + + int64_t k = input.ndimension(); + const IntArrayRef& weight_sizes = weight.sizes(); + int64_t weight_dim = weight_sizes.size(); + + TORCH_CHECK(weight_dim == k, + "Expected ", weight_dim, "-dimensional input for ", weight_dim, + "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ", + input.sizes(), " instead"); + TORCH_CHECK(weight_sizes[0] >= groups, + "Given groups=", groups, ", expected weight to be at least ", groups, + " at dimension 0, but got weight of size ", weight_sizes, " instead"); + TORCH_CHECK(weight_sizes[0] % groups == 0, + "Given groups=", groups, ", expected weight to be divisible by ", + groups, " at dimension 0, but got weight of size [", weight_sizes, + "] instead"); + TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups), + "Given groups=", groups, ", weight of size ", weight_sizes, + ", expected input", input.sizes(), " to have ", + (weight_sizes[1] * groups), " channels, but got ", input.size(1), + " channels instead"); + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]), + "Given weight of size ", weight_sizes, + ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements", + ", but got bias of size ", bias.sizes(), " instead"); + + std::vector input_shape; + std::vector kernel_shape; + bool kernel_size_correct = true; + + for (const auto i : c10::irange(2, k)) { + input_shape.push_back(input.size(i) + 2 * padding[i-2]); + // log new kernel size considering dilation + kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1); + if (input_shape.back() < kernel_shape.back()) { + kernel_size_correct = false; + } + } + + TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel"); + + if (!kernel_size_correct) { + // If kernel size is incorrect + std::ostringstream input_ss; + std::ostringstream kernel_ss; + std::string separator = ""; + + for (int i = 0, len = input_shape.size(); i < len; ++i) { + input_ss << separator << input_shape[i]; + kernel_ss << separator << kernel_shape[i]; + separator = " x "; + } + + TORCH_CHECK(false, "Calculated padded input size per channel: (", input_ss.str(), "). " + "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); + } +} + #define MKLDNNTensor(itensor, options) \ new_with_itensor_mkldnn( \ std::move(itensor), \ @@ -82,7 +157,8 @@ namespace at { namespace native { Tensor mkldnn_convolution( const Tensor& input, - const Tensor& weight, const c10::optional& bias_opt, + const Tensor& weight, + const c10::optional& bias_opt, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, @@ -96,6 +172,8 @@ Tensor mkldnn_convolution( "mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); } + check_shape_forward(input, weight, bias, padding, stride, dilation, groups); + bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast; auto output_sizes = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index 8bc001aa2ede5..c49885e3c8c1a 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -117,6 +117,10 @@ ideep::tensor itensor_from_tensor(const Tensor& tensor) { } } +int set_verbose(int level) { + return ideep::utils::set_verbose(level); +} + }} #endif // AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index d2c9f93569c97..a86d1c4b722c3 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -24,6 +24,9 @@ TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor); // Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor. TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor); +// Set MKLDNN verbose level +TORCH_API int set_verbose(int level); + }} #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index 87c0e31710499..e399e2143dea6 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -23,6 +23,17 @@ bool use_mkldnn_bf16_matmul( return false; } +bool mkldnn_bf16_gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const c10::BFloat16 *a, int64_t lda, + const c10::BFloat16 *b, int64_t ldb, + float beta, + c10::BFloat16 *c, int64_t ldc) { + return false; +} + } // namespace native } // namespace at @@ -34,6 +45,75 @@ bool use_mkldnn_bf16_matmul( namespace at { namespace native { +static bool use_mkldnn_bf16_matmul() { + return ( + at::globalContext().userEnabledMkldnn() && + mkldnn_bf16_device_check()); +} + +bool mkldnn_bf16_gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const c10::BFloat16 *a_data, int64_t lda, + const c10::BFloat16 *b_data, int64_t ldb, + float beta, + c10::BFloat16 *c_data, int64_t ldc) { + if (!use_mkldnn_bf16_matmul() || + (m * n * k <= 16 * 16 * 16) || + (alpha == 0.0f)) { + return false; + } + + ideep::attr_t op_attr; + // Use mkldnn post ops to perform the add. + if (beta != 0.0f) { + op_attr = ideep::attr_t::fuse_sum(); + } + + ideep::tensor::dims a_strides{{1, lda}}, b_strides{{1, ldb}}, c_strides{{1, ldc}}; + if (transa != TransposeType::NoTranspose) { + std::swap(a_strides[0], a_strides[1]); + } + if (transb != TransposeType::NoTranspose) { + std::swap(b_strides[0], b_strides[1]); + } + + ideep::tensor a({ + /*sizes=*/{m, k}, + ideep::tensor::data_type::bf16, + /*strides=*/a_strides}, + const_cast(a_data)); + ideep::tensor b({ + /*sizes=*/{k, n}, + ideep::tensor::data_type::bf16, + /*strides=*/b_strides}, + const_cast(b_data)); + ideep::tensor c({ + /*sizes=*/{m, n}, + ideep::tensor::data_type::bf16, + /*strides=*/c_strides}, + c_data); + + ideep::matmul_forward::compute( + a, b, c, alpha, beta, + ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr); + + if (c.get_data_handle() != c_data){ + // ideep will query onednn expect format of output + // if given output format is not expected, ideep will re-init an output buffer + // under this case, we need copy the re-inited buffer back to given buffer + ideep::tensor real_output({ + /*sizes=*/{m, n}, + ideep::tensor::data_type::bf16, + /*strides=*/c_strides}, + c_data); + c.reorder_to(real_output); + } + + return true; +} + void mkldnn_matmul( const Tensor &mat1, const Tensor &mat2, @@ -128,14 +208,13 @@ bool use_mkldnn_bf16_matmul( const Tensor& mat2, const Tensor& result) { return ( - at::globalContext().userEnabledMkldnn() && - mat1.scalar_type() == kBFloat16 && - mat2.scalar_type() == kBFloat16 && - (!result.defined() || result.scalar_type() == kBFloat16) && - mat1.numel() != 0 && - mat2.numel() != 0 && - mkldnn_bf16_device_check() && - checksize(mat1, mat2)); + use_mkldnn_bf16_matmul() && + mat1.scalar_type() == kBFloat16 && + mat2.scalar_type() == kBFloat16 && + (!result.defined() || result.scalar_type() == kBFloat16) && + mat1.numel() != 0 && + mat2.numel() != 0 && + checksize(mat1, mat2)); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h index 250f4f228420c..63426714933b2 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.h +++ b/aten/src/ATen/native/mkldnn/Matmul.h @@ -2,6 +2,7 @@ #include #include +#include // For TransposeType namespace at { namespace native { @@ -18,6 +19,16 @@ bool use_mkldnn_bf16_matmul( const Tensor& mat2, const Tensor& result_opt); +// Try running mkldnn optimized gemm, or returns false if naive gemm would be faster +bool mkldnn_bf16_gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const c10::BFloat16 *a, int64_t lda, + const c10::BFloat16 *b, int64_t ldb, + float beta, + c10::BFloat16 *c, int64_t ldc); + } } diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp index dc34281d25cac..a944d4db19b62 100644 --- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp +++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp @@ -2,6 +2,10 @@ namespace at { namespace native { +Tensor empty_symint_mkldnn(c10::SymIntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { + return at::native::empty_mkldnn(c10::asIntArrayRefSlow(sizes), dtype, layout, device, pin_memory, optional_memory_format); +} + #if AT_MKLDNN_ENABLED() Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index 7676edc1878aa..ec3c58eda77f1 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #if !AT_MKLDNN_ENABLED() @@ -86,3 +87,15 @@ Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) { } // namespace at #endif // AT_MKLDNN_ENABLED + + +namespace at { +namespace native { + + +Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) { + return mkldnn_view(self, c10::asIntArrayRefSlow(size)); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mkldnn/TensorShape.h b/aten/src/ATen/native/mkldnn/TensorShape.h index bbb8ea9bca063..92af7e2759347 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.h +++ b/aten/src/ATen/native/mkldnn/TensorShape.h @@ -1,12 +1,15 @@ #pragma once #include +#include namespace at { namespace native { Tensor mkldnn_view(const Tensor& self, IntArrayRef size); +Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size); + Tensor mkldnn_clone(const Tensor& self); } // namespace native diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 9f9fa2d7a3a59..32aede7fc5e0d 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -47,21 +47,25 @@ MPSDataType getMPSDataType(ScalarType scalar_type); MPSDataType getMPSScalarType(ScalarType scalar_type); std::string getMPSTypeString(ScalarType scalar_type); std::string getMPSShapeString(MPSShape* shape); -std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = true); +std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false); double getMPSScalarValue(const Tensor& t); std::string getArrayRefString(const IntArrayRef s); -std::string getStridedKey(const Tensor& self, const IntArrayRef sz, - const IntArrayRef strides, int64_t offset); -id gatherViewTensor(const at::Tensor& src, id s); +// use has_storage() on the returned tensor to determine if src actually is a view +Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst); +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output); MPSShape* getMPSShape(const Tensor& t); MPSShape* getMPSShape(IntArrayRef sizes); MPSShape* getMPSShape(c10::MaybeOwned t); +static inline id getMTLBufferStorage(const at::Tensor& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + class Placeholder { public: - Placeholder() : _placeholder(nullptr), _value(nullptr) {} - Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr) {} + Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr); MPSGraphTensor* getMPSGraphTensor() { return _placeholder; @@ -73,22 +77,10 @@ class Placeholder { return _value == nullptr; } - void allocateViewTensor(const at::Tensor& src) - { - assert (!_viewOutput.numel()); - _viewOutput = at::native::empty_mps( - src.sizes(), - src.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - } - private: MPSGraphTensor* _placeholder; MPSGraphTensorData* _value; - Tensor _viewOutput; + Tensor _tensor; }; void resize_tensor(Tensor* output); @@ -103,6 +95,7 @@ void printTensorNDArray(const Tensor& t); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); string get_mem_format_string(c10::MemoryFormat memory_format); @@ -117,12 +110,26 @@ struct MPSCachedGraph [_object release]; _object = nullptr; } + + template + inline T* as() { + return static_cast(this); + } + MPSGraph *graph() const { return (MPSGraph *)_object; } NSObject *object() const { return _object; } private: NSObject *_object = nullptr; }; +struct MPSUnaryCachedGraph : public MPSCachedGraph +{ + MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; + + // TODO: Improve the overall design of MPSGraphCache. // https://github.com/pytorch/pytorch/issues/77176 // Cache holding various keys mapped to graphs @@ -200,6 +207,11 @@ struct MPSGraphCache return result; } + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + void FindAndRemoveViewEntry(void* ptr) { // this may find multiple view entries with the same buffer pointers auto views_range = views_list.equal_range(ptr); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 72faad64d96b9..7e22e2c103a8f 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -60,41 +60,8 @@ return gen; } -std::string getStridedKey(const Tensor& self, const IntArrayRef sz, - const IntArrayRef strides, int64_t offset) { - // TODO: move storage_offset to a PlaceholderTensor and strides to a - // tensor too, to avoid too many cache entries. - return std::to_string((uintptr_t)self.storage().data()) + - ":" + mps::getArrayRefString(sz) + - ":" + mps::getArrayRefString(strides) + - ":" + std::to_string(offset) + - ":" + getMPSTypeString(self.scalar_type()); -} - -void runMPSGraph( - MPSStream* mpsStream, - MPSGraph* mpsGraph, - NSDictionary* feeds, - NSDictionary* results) { - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - mpsStream->commit(true); - id commandQueue = mpsStream->commandQueue(); - MPSGraphExecutionDescriptor *executionDescriptor = [[MPSGraphExecutionDescriptor new] autorelease]; - - executionDescriptor.completionHandler = ^(NSDictionary * resultsDictionary, - NSError * _Nullable error) { - }; - - [mpsGraph runAsyncWithMTLCommandQueue:commandQueue - feeds:feeds - targetOperations:nil - resultsDictionary:results - executionDescriptor:executionDescriptor]; - - } - }); +void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) { + mpsStream->executeMPSGraph(mpsGraph, feeds, results); } MPSDataType getMPSDataType(ScalarType scalar_type) { @@ -109,8 +76,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { return MPSDataTypeInt64; case ScalarType::Short: return MPSDataTypeInt16; - case ScalarType::Byte: + case ScalarType::Char: return MPSDataTypeInt8; + case ScalarType::Byte: + return MPSDataTypeUInt8; case ScalarType::Bool: return MPSDataTypeBool; case ScalarType::Double: @@ -136,8 +105,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return MPSDataTypeInt64; case ScalarType::Short: return MPSDataTypeInt16; - case ScalarType::Byte: + case ScalarType::Char: return MPSDataTypeInt8; + case ScalarType::Byte: + return MPSDataTypeUInt8; case ScalarType::Bool: return MPSDataTypeBool; default: @@ -149,19 +120,21 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Double: case ScalarType::Float: - return "MPSDataTypeFloat32"; + return "Float32"; case ScalarType::Half: - return "MPSDataTypeFloat16"; + return "Float16"; case ScalarType::Int: - return "MPSDataTypeInt32"; + return "Int32"; case ScalarType::Long: - return "MPSDataTypeInt64"; + return "Int64"; case ScalarType::Short: - return "MPSDataTypeInt16"; + return "Int16"; + case ScalarType::Char: + return "Int8"; case ScalarType::Byte: - return "MPSDataTypeInt8"; + return "UInt8"; case ScalarType::Bool: - return "MPSDataTypeBool"; + return "Bool"; default: return "Undefined"; } @@ -219,7 +192,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { { NSInteger sz_i = (i < sz) ? t.size(i) : 1; - NSNumber* number = [NSNumber numberWithInt:sz_i]; + NSNumber* number = [NSNumber numberWithInteger:sz_i]; numbers[i] = number; } return [NSArray arrayWithObjects:numbers count:sz_]; @@ -240,7 +213,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { { NSInteger sz_i = (i < sz) ? sizes[i] : 1; - NSNumber* number = [NSNumber numberWithInt:sz_i]; + NSNumber* number = [NSNumber numberWithInteger:sz_i]; numbers[i] = number; } return [NSArray arrayWithObjects:numbers count:sz_]; @@ -254,109 +227,38 @@ void printTensorNDArray(const Tensor& t) { auto selfDType = getMPSDataType(t.scalar_type()); // Initialize data - id selfBuf = __builtin_bit_cast(id, t.storage().data()); + id selfBuf = getMTLBufferStorage(t); MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape dataType:selfDType] autorelease]; [tdata printNDArray]; } -MPSCachedGraph* _getCachedGraph(const at::Tensor& src) { - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset()); - MPSCachedGraph* cachedGraph = cache_->LookUp(key); - - return cachedGraph; -} - -id _gatherViewTensor(const at::Tensor& src, id sourceBuffer, MPSCachedGraph* mpsCachedGraph, Tensor& output) { - TORCH_CHECK(mpsCachedGraph != nil); - - MPSStream* stream = getCurrentMPSStream(); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - - CachedGraph* cachedGraph = static_cast(mpsCachedGraph); - - @autoreleasepool { - MPSGraphTensor* inputTensor = cachedGraph->inputTensor_; - MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer - shape: [inputTensor shape] - dataType: [inputTensor dataType]] autorelease]; - id resultBuffer = __builtin_bit_cast(id, output.storage().data()); - MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer - shape: getMPSShape(src.sizes()) - dataType: getMPSDataType(src.scalar_type())] autorelease]; - NSDictionary* feeds = @{ - inputTensor : inputTensorData - }; - - NSDictionary* results = @{ - cachedGraph->outputTensor_ : outputTensorData - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - return resultBuffer; - } -} - -id gatherViewTensor(const at::Tensor& src, id sourceBuffer) { - MPSCachedGraph* mpsCachedGraph = _getCachedGraph(src); - if (mpsCachedGraph) { - Tensor output = at::native::empty_mps( - src.sizes(), - src.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - - _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output); - return __builtin_bit_cast(id, output.storage().data()); - } - - return nil; -} - -id gatherViewTensorWithAllocatedMem(const at::Tensor& src, id sourceBuffer, Tensor& output, MPSCachedGraph* mpsCachedGraph) { - TORCH_CHECK(mpsCachedGraph != nil); - - _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output); - return __builtin_bit_cast(id, output.storage().data()); -} - -Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape) +Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape) : _tensor(src) { - Tensor src_ = src; - TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!"); - // extract the pointer to MTLBuffer from the Tensor's storage - id srcBuf = __builtin_bit_cast(id, src.storage().data()); - if (src.is_view()) { - MPSCachedGraph* cachedGraph = _getCachedGraph(src); - if (cachedGraph) { - allocateViewTensor(src); - id gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput, cachedGraph); - if (gatherTensor) { - srcBuf = gatherTensor; - } - } else { - src_ = src.contiguous(); - srcBuf = __builtin_bit_cast(id, src_.storage().data()); + TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!"); + // extract the pointer to MTLBuffer from the Tensor's storage + id srcBuf = getMTLBufferStorage(src); + // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose()) + if (src.is_view() || !src.is_contiguous()) { + Tensor emptyShell = Tensor(); + // use "_tensor" from Placeholder to retain view's output during its usage in other ops + _tensor = gatherViewTensor(src, emptyShell); + if (!_tensor.has_storage()) { + // if we cannot gather, we make the the tensor contiguous implicitly, and keep + // it in placeholder to be able to retrieve it when we return from constructor + _tensor = src.contiguous(); } + srcBuf = getMTLBufferStorage(_tensor); } // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero. // if buffer size is zero in here, it's not a user error. It could be a missing check for // tensor.numel() == 0 in our internal implementations of ops. TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); - const MPSDataType mpsDataType = src_.dim() == 0 ? getMPSScalarType(src_.scalar_type()) : getMPSDataType(src_.scalar_type()); + const MPSDataType mpsDataType = _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type()); if (!mpsShape) - mpsShape = getMPSShape(src_); + mpsShape = getMPSShape(_tensor); _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:mpsShape @@ -373,7 +275,7 @@ void printTensorNDArray(const Tensor& t) { MPSGraphTensorData *result = nil; if (tensor.numel() > 0) { - id buf = __builtin_bit_cast(id, tensor.storage().data()); + id buf = getMTLBufferStorage(tensor); result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] @@ -416,6 +318,9 @@ void printTensorNDArray(const Tensor& t) { case MPSDataTypeInt8: v.i = scalar.to(); break; + case MPSDataTypeUInt8: + v.i = scalar.to(); + break; case MPSDataTypeBool: v.b = scalar.to(); break; @@ -458,6 +363,12 @@ void resize_tensor(Tensor* output) { name:nil]; } +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) { + return [mpsGraph placeholderWithShape:@[@1] + dataType:getMPSScalarType(scalar.type()) + name:nil]; +} + // this is meant to suppress the availability warning on castTensor // we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) { diff --git a/aten/src/ATen/native/mps/TensorFactory.cpp b/aten/src/ATen/native/mps/TensorFactory.cpp index 78899fc8fa3c3..d280da4d9c650 100644 --- a/aten/src/ATen/native/mps/TensorFactory.cpp +++ b/aten/src/ATen/native/mps/TensorFactory.cpp @@ -71,6 +71,17 @@ Tensor empty_mps( return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } +Tensor empty_symint_mps( + c10::SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt) { + + return at::native::empty_mps(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); +} + Tensor empty_strided_mps( IntArrayRef size, IntArrayRef stride, diff --git a/aten/src/ATen/native/mps/TensorFactory.h b/aten/src/ATen/native/mps/TensorFactory.h index cb7931deb6bc2..b669cce44fe6d 100644 --- a/aten/src/ATen/native/mps/TensorFactory.h +++ b/aten/src/ATen/native/mps/TensorFactory.h @@ -1,17 +1,10 @@ // Copyright © 2022 Apple Inc. -#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ - }() +#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)) diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index d765de498b3f5..b741276b45e01 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -18,24 +18,18 @@ Tensor relu_mps(const Tensor& self) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; Tensor output = at::empty_like(self); resize_tensor(&output); TORCH_CHECK(output.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { string key = "relu" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -79,24 +73,18 @@ Tensor relu_mps(const Tensor& self) { Tensor & relu_mps_(Tensor & self) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; // Inplace relu Tensor &output = self; TORCH_CHECK(output.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { string key = "relu_" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -141,15 +129,9 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(leaky_relu_out_mps) ( const Tensor& self, const Scalar& negative_slope, const Tensor& output) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(output.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream *stream = getCurrentMPSStream(); @@ -157,7 +139,7 @@ Tensor relu_mps(const Tensor& self) { @autoreleasepool { string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -185,7 +167,7 @@ Tensor relu_mps(const Tensor& self) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); @@ -296,25 +278,19 @@ Tensor relu_mps(const Tensor& self) { const bool half_to_float, const Tensor &out) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; if (self.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = at::mps::getCurrentMPSStream(); @autoreleasepool { string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -438,22 +414,16 @@ Tensor relu_mps(const Tensor& self) { const Tensor& self, const Tensor& output) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(output.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { string key = "sigmoid_out_mps" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -651,15 +621,9 @@ Tensor relu_mps(const Tensor& self) { const Scalar& value, const Tensor& result) { using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(self.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @@ -669,7 +633,7 @@ Tensor relu_mps(const Tensor& self) { to_string(threshold.to()) + ":" + to_string(value.to()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -871,7 +835,7 @@ Tensor relu_mps(const Tensor& self) { getMPSDataType(self.scalar_type()), getMPSShape(self)); - MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor); + MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor); outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor secondaryTensor:inputTensor name:nil]; @@ -1266,6 +1230,212 @@ void elu_variants_out_mps ( } +TORCH_IMPL_FUNC(glu_out_mps) ( + const Tensor& self, const int64_t dim, const Tensor& output + ) { + using namespace mps; + TORCH_CHECK(output.is_mps()); + + // Empty output + if(output.numel() == 0) + return; + + // this can't pass anyway because a 0-dimensional tensor has "size" 1, which + // can't be evenly halved, but give a nicer error message here. + TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); + auto wrap_dim = maybe_wrap_dim(dim, self.dim()); + const int64_t nIn = self.size(wrap_dim); + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", + wrap_dim, " is size ", nIn); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim);; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, + getMPSDataType(self.scalar_type()), + getMPSShape(self)); + NSArray * outputTensorsArray = [mpsGraph splitTensor:inputTensor + numSplits:2 + axis:wrap_dim + name:nil]; + MPSGraphTensor* firstHalf = outputTensorsArray[0]; + MPSGraphTensor* secondHalf = [mpsGraph sigmoidWithTensor:outputTensorsArray[1] + name:nil]; + + MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:firstHalf + secondaryTensor:secondHalf + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + + } + +} + +Tensor& glu_backward_mps_out ( + const Tensor& grad_output, const Tensor& self, const int64_t dim, Tensor& grad_input + ) { + using namespace mps; + + // Empty output + if(grad_input.numel() == 0) + return grad_input; + + // this can't pass anyway because a 0-dimensional tensor has "size" 1, which + // can't be evenly halved, but give a nicer error message here. + TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); + auto wrap_dim = maybe_wrap_dim(dim, self.dim()); + const int64_t nIn = self.size(wrap_dim); + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", + wrap_dim, " is size ", nIn); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *gradInputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, + getMPSDataType(self.scalar_type()), + getMPSShape(self)); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, + getMPSDataType(grad_output.scalar_type()), + getMPSShape(grad_output)); + NSArray * inputTensorsArray = [mpsGraph splitTensor:inputTensor + numSplits:2 + axis:wrap_dim + name:nil]; + + // first half + MPSGraphTensor* sigmoidOutputTensor = [mpsGraph sigmoidWithTensor:inputTensorsArray[1] + name:nil]; + MPSGraphTensor* firstHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : sigmoidOutputTensor + secondaryTensor : gradOutputTensor + name : nil]; + + // second half + MPSGraphTensor* one_val = [mpsGraph constantWithScalar:1.0 + shape:@[@1] + dataType:getMPSDataType(self.scalar_type())]; + + MPSGraphTensor* secondHalfOutputTensor = [mpsGraph subtractionWithPrimaryTensor : one_val + secondaryTensor : sigmoidOutputTensor + name : nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor + secondaryTensor : sigmoidOutputTensor + name : nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor + secondaryTensor : inputTensorsArray[0] + name : nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor + secondaryTensor : gradOutputTensor + name : nil]; + + MPSGraphTensor* outputTensor = [mpsGraph concatTensor : firstHalfOutputTensor + withTensor : secondHalfOutputTensor + dimension : wrap_dim + name : nil]; + newCachedGraph->gradInputTensor_ = outputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(), + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + + } + return grad_input; + +} + +Tensor glu_backward_mps (const Tensor& grad_output, + const Tensor& self, + const int64_t dim) { + + Tensor grad_input = at::native::empty_mps( + self.sizes(), + self.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + grad_input = glu_backward_mps_out(grad_output, self, dim, grad_input); + return grad_input; +} + + TORCH_IMPL_FUNC(softplus_out_mps) ( const Tensor& self, const Scalar& beta, @@ -1364,8 +1534,353 @@ void elu_variants_out_mps ( }; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } +} + +TORCH_IMPL_FUNC(softplus_backward_out_mps) ( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + const Tensor& grad_input +) { + using namespace mps; + TORCH_CHECK(self.is_mps()); + + // Empty output + if(grad_input.numel() == 0) + return; + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + @autoreleasepool { + string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}); + + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 + shape:@[@1] + dataType:getMPSDataType(self.scalar_type())]; + MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.to() + shape:@[@1] + dataType:getMPSDataType(self.scalar_type())]; + MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:betaTensor + name:nil]; + MPSGraphTensor* expBxTensor = [mpsGraph exponentWithTensor:bxTensor + name:nil]; + MPSGraphTensor* unitExpBxTensor = [mpsGraph additionWithPrimaryTensor:expBxTensor + secondaryTensor:unitTensor + name:nil]; + MPSGraphTensor* rTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor + secondaryTensor:expBxTensor + name:nil]; + rTensor = [mpsGraph divisionWithPrimaryTensor:rTensor + secondaryTensor:unitExpBxTensor + name:nil]; + MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() + shape:@[@1] + dataType:getMPSDataType(self.scalar_type())]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor + secondaryTensor:thresholdTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:gradOutputTensor + falsePredicateTensor:rTensor + name:nil]; + + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} + + +Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { + using namespace mps; + + int64_t weight_num = weight_.numel(); + Tensor result = at::empty_like(self, self.suggest_memory_format()); + TORCH_INTERNAL_ASSERT(weight_.defined()); + + if (result.numel() == 0){ + return result; + } + + TORCH_CHECK( + weight_.dim() == 1 || weight_.dim() == 0, + "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", weight_.dim() + ); + + int64_t input_ndim = self.dim(); + NSMutableArray * expand_dims = [NSMutableArray new]; + + if (weight_num != 1) { + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + if (input_ndim > 1) { + channel_size = self.size(1); // channel is the 2nd dim of input + } + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + for (const auto i : c10::irange(input_ndim)) { + if (i == 1) continue; + [expand_dims addObject:[NSNumber numberWithInt:i]]; + } + } + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *weightTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + NSString* expand_dims_key = [[expand_dims valueForKey:@"description"] componentsJoinedByString:@","]; + string key = "prelu_mps:" + getTensorsStringKey({self, weight_}) + string([expand_dims_key UTF8String]); + + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); + + MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 + shape:@[@1] + dataType:getMPSDataType(self.scalar_type())]; + MPSGraphTensor *reluTensor = [mpsGraph reLUWithTensor:inputTensor + name:nil]; + MPSGraphTensor *predicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor + secondaryTensor: zeroTensor + name: nil]; + MPSGraphTensor *weightedTensor = [mpsGraph selectWithPredicateTensor: predicateTensor + truePredicateTensor: inputTensor + falsePredicateTensor: zeroTensor + name: nil]; + if (weight_num != 1) { + MPSGraphTensor *expandedWeightTensor = [mpsGraph expandDimsOfTensor:weightTensor + axes:expand_dims + name:nil]; + weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor + secondaryTensor:expandedWeightTensor + name:nil]; + }else{ + weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor + secondaryTensor:weightTensor + name:nil]; + } + MPSGraphTensor *outputTensor = [mpsGraph additionWithPrimaryTensor:reluTensor + secondaryTensor:weightedTensor + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->weightTensor_ = weightTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return result; +} + +std::tuple prelu_backward_mps(const Tensor& grad_output, const Tensor& self, const Tensor& weight_) { + using namespace mps; + + int64_t weight_num = weight_.numel(); + NSMutableArray * reduce_dims = [NSMutableArray new]; + Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); + Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous); + + TORCH_CHECK( + weight_.dim() == 1 || weight_.dim() == 0, + "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", weight_.dim() + ); + + if (weight_num != 1) { + int64_t input_ndim = self.dim(); + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + if (input_ndim > 1) { + channel_size = self.size(1); // channel is the 2nd dim of input + } + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "." + ); + + for (const auto i : c10::irange(input_ndim)) { + if (i == 1) continue; + [reduce_dims addObject:[NSNumber numberWithInt:i]]; + } + } + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *weightTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + MPSGraphTensor *weightedGradTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + NSString* reduce_dims_key = [[reduce_dims valueForKey:@"description"] componentsJoinedByString:@","]; + string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_}) + ":" + string([reduce_dims_key UTF8String]); + + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + + MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); + + MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar: 0.0 + shape:@[@1] + dataType: inputTensor.dataType]; + MPSGraphTensor* weightedGradOutputTensor = nil; + if (weight_num != 1) { + MPSGraphTensor *expandedWeightTensor = [mpsGraph expandDimsOfTensor:weightTensor + axes:reduce_dims + name:nil]; + weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:expandedWeightTensor + secondaryTensor:gradOutputTensor + name:nil]; + } else { + weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor + secondaryTensor:gradOutputTensor + name:nil]; + } + MPSGraphTensor* inputGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:gradOutputTensor + name:nil]; + MPSGraphTensor *predicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor + secondaryTensor: zeroTensor + name: nil]; + MPSGraphTensor *outputTensor = [mpsGraph selectWithPredicateTensor: predicateTensor + truePredicateTensor: gradOutputTensor + falsePredicateTensor: weightedGradOutputTensor + name: nil]; + MPSGraphTensor *weightedGradTensor = [mpsGraph selectWithPredicateTensor: predicateTensor + truePredicateTensor: zeroTensor + falsePredicateTensor: inputGradOutputTensor + name: nil]; + weightedGradTensor = [mpsGraph reductionSumWithTensor:weightedGradTensor + axes:reduce_dims + name:nil]; + + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->weightTensor_ = weightTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->weightedGradTensor_ = weightedGradTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + Placeholder weightedGradPlaceholder = Placeholder(cachedGraph->weightedGradTensor_, weight_grad); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(), + weightedGradPlaceholder.getMPSGraphTensor() : weightedGradPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return std::tuple{grad_input, weight_grad}; } TORCH_IMPL_FUNC(silu_out_mps) ( diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 6f3f2e547ad19..6e325de38c830 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -25,7 +25,7 @@ // alpha is always 1.0 except when this function is called from add_sub_template() void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha, - const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock) + const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) { // it's possible to receive empty tensors here if (self.numel() == 0 || other.numel() == 0) { @@ -36,9 +36,26 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha const bool is_self_scalar = self.dim() == 0; const bool is_other_scalar = other.dim() == 0; + auto new_size = at::infer_size(self.sizes(), other.sizes()); + if (!output_.sizes().equals(new_size)) { + output_.resize_(new_size); + } + + Tensor output = output_; + bool needsCopyToOutput = false; + + if (!output_.is_contiguous()) { + output = output_.contiguous(); + needsCopyToOutput = true; + // else, determine if this is an in-place operation on a view output + } else if (output_.is_view() && (self.is_alias_of(output_) || other.is_alias_of(output_))) { + output = at::native::empty_mps(output_.sizes(), output_.scalar_type(), c10::nullopt, kMPS); + needsCopyToOutput = true; + } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({self, other}, /*use_scalar_value*/ false); + string key = op_name + getTensorsStringKey({self, other, output_}, /*use_scalar_value*/ false); BinaryOpCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -62,6 +79,11 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); } newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor); + // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor + // Output tensor should have been promoted but it remains an int32 tensor + if (output_.scalar_type() != common_dtype) { + newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type()); + } } return newCachedGraph; }); @@ -72,13 +94,13 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha Placeholder selfPlaceholder; Placeholder otherPlaceholder; - if (is_self_scalar) { + if (is_self_scalar && !self.is_mps()) { feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), getMPSScalarType(self.scalar_type())); } else { selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self); feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); } - if (is_other_scalar) { + if (is_other_scalar && !other.is_mps()) { feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), getMPSScalarType(other.scalar_type())); } else { otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other); @@ -89,11 +111,15 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha feeds[cachedGraph->alphaTensor] = getMPSGraphTensorFromScalar(mpsStream, alpha, getMPSScalarType(other.scalar_type())); } - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, needsCopyToOutput ? output : output_); NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results); + + if (needsCopyToOutput) { + output_.copy_(output); + } } } @@ -127,8 +153,10 @@ void div_mode_template(const Tensor& self, const Tensor& other, void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name) { - if (alpha.toDouble() == 0.0) + if (alpha.toDouble() == 0.0) { const_cast(output) = self.clone(); + return; + } const bool alpha_has_value = alpha.toDouble() != 1.0; if (alpha_has_value) { @@ -162,7 +190,18 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp } // namespace mps -#define CREATE_MPS_BINARY_OP_FUNC(func_out, func_stub, other_type) \ +#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \ +Tensor& func_out (const Tensor& self, const other_type& other, Tensor& output) { \ + mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ + ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ + MPSGraph* mpsGraph = cachedGraph->graph(); \ + return [mpsGraph func_stub##WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \ + secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \ + name:nil]; }); \ + return output; \ +} + +#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ @@ -173,7 +212,7 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp } // Boolean Ops require casting output to "MPSDataTypeBool" -#define CREATE_MPS_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ +#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ @@ -185,26 +224,30 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp } // Boolean Binary Ops -CREATE_MPS_BOOLEAN_OP_FUNC(eq_scalar_out_mps, equal, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(eq_tensor_out_mps, equal, Tensor); -CREATE_MPS_BOOLEAN_OP_FUNC(ne_scalar_out_mps, notEqual, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(ne_tensor_out_mps, notEqual, Tensor); -CREATE_MPS_BOOLEAN_OP_FUNC(le_scalar_out_mps, lessThanOrEqualTo, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(le_tensor_out_mps, lessThanOrEqualTo, Tensor); -CREATE_MPS_BOOLEAN_OP_FUNC(lt_scalar_out_mps, lessThan, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(lt_tensor_out_mps, lessThan, Tensor); -CREATE_MPS_BOOLEAN_OP_FUNC(ge_scalar_out_mps, greaterThanOrEqualTo, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(ge_tensor_out_mps, greaterThanOrEqualTo, Tensor); -CREATE_MPS_BOOLEAN_OP_FUNC(gt_scalar_out_mps, greaterThan, Scalar); -CREATE_MPS_BOOLEAN_OP_FUNC(gt_tensor_out_mps, greaterThan, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(eq_scalar_out_mps, equal, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(eq_tensor_out_mps, equal, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(ne_scalar_out_mps, notEqual, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(ne_tensor_out_mps, notEqual, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(le_scalar_out_mps, lessThanOrEqualTo, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(le_tensor_out_mps, lessThanOrEqualTo, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(lt_scalar_out_mps, lessThan, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(lt_tensor_out_mps, lessThan, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(ge_scalar_out_mps, greaterThanOrEqualTo, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(ge_tensor_out_mps, greaterThanOrEqualTo, Tensor); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(gt_scalar_out_mps, greaterThan, Scalar); +CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(gt_tensor_out_mps, greaterThan, Tensor); // Arithmetic Binary Ops -CREATE_MPS_BINARY_OP_FUNC(minimum_out_mps, minimum, Tensor); -CREATE_MPS_BINARY_OP_FUNC(maximum_out_mps, maximum, Tensor); -CREATE_MPS_BINARY_OP_FUNC(mul_out_mps, multiplication, Tensor); -CREATE_MPS_BINARY_OP_FUNC(pow_tensor_scalar_out_mps, power, Scalar); -CREATE_MPS_BINARY_OP_FUNC(pow_tensor_tensor_out_mps, power, Tensor); -CREATE_MPS_BINARY_OP_FUNC(atan2_mps_out, atan2, Tensor); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(minimum_out_mps, minimum, Tensor); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(maximum_out_mps, maximum, Tensor); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(mul_out_mps, multiplication, Tensor); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_scalar_out_mps, power, Scalar); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_tensor_out_mps, power, Tensor); +CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(atan2_mps_out, atan2, Tensor); + +CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor); +CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor); +CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor); TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& output) { diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 09e962b94f78e..0cfd7ccc2ff5b 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -1,17 +1,6 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include #include -#include - -#ifdef __OBJC__ -#include -#endif - -using namespace at::mps; namespace at { namespace native { @@ -34,13 +23,8 @@ MPSGraphCache *cache_ = MPSGraphCache::getInstance(); @autoreleasepool { + string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble()); - MPSShape* input_shape = getMPSShape(self); - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - - string key = "fill_scalar_mps_impl:" + getMPSTypeString(self.scalar_type()) - + ":" + string([ns_shape_key UTF8String]) - + ":" + to_string(value.toDouble()); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -50,12 +34,21 @@ @autoreleasepool{ MPSGraph *mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - + auto isBool = self.scalar_type() == c10::ScalarType::Bool; + auto dataType = (!isBool) ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8; + // constantWithScalar does not work for boolTypes on MacOS-12.[34] + // workaround by filing it as int8 tensor and than casting to bool + // See https://github.com/pytorch/pytorch/issues/82427 MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble() - shape:input_shape - dataType:getMPSScalarType(self.scalar_type())]; + shape:getMPSShape(self) + dataType:dataType]; MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil]; + if (isBool) { + outputTensor = [mpsGraph castTensor:outputTensor + toType:MPSDataTypeBool + name:@"constWithBool-workaround"]; + } newCachedGraph->outputTensor_ = outputTensor; } @@ -78,17 +71,37 @@ return self; } +// returns false if tensor cannot be filled with fillBuffer() +bool fill_mps_tensor_(Tensor& self, uint8_t value) { + if (self.is_contiguous()) { + MPSStream* stream = getCurrentMPSStream(); + auto storage_byte_offset = self.storage_offset() * self.itemsize(); + stream->fill(mps::getMTLBufferStorage(self), 0, self.nbytes(), storage_byte_offset); + return true; + } + return false; +} + Tensor& zero_mps_(Tensor& self) { - return at::native::fill_scalar_mps_impl(self, 0.0f); + // check if it's possible to use fillBuffer() to fill the Tensor's storage + if (fill_mps_tensor_(self, 0) == true) + return self; + return fill_scalar_mps_impl(self, 0.0f); } Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) { - return at::native::fill_scalar_mps_impl(self, value); + if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) + return self; + return fill_scalar_mps_impl(self, value); } Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) { TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions."); - return at::native::fill_scalar_mps_impl(self, value.item()); + Scalar scalar_value = value.item(); + if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) + return self; + return fill_scalar_mps_impl(self, scalar_value); } + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 02d38573edff3..3c2ab0d6c2f8b 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -12,149 +12,9 @@ #include #include + namespace at { namespace native { - -MPSGraphTensor* chainViewOperation(MPSGraph* mpsGraph, IntArrayRef size, - IntArrayRef stride, int64_t storage_offset, - MPSGraphTensor* inputTensor, const Tensor& self) { - MPSGraphTensor *outputTensor = nil; - const size_t shape_size = size.size(); - - @autoreleasepool { - std::vector sizeArray(shape_size); - const int64_t int_max = std::numeric_limits::max(); - for (int i = 0; i < shape_size; i++) { - TORCH_CHECK(size[i] <= int_max); - sizeArray[i] = static_cast(size[i]); - } - NSData* shapeData = [NSData dataWithBytes:sizeArray.data() - length:shape_size * sizeof(int32_t)]; - MPSGraphTensor* shapeTensor = [mpsGraph constantWithData:shapeData - shape:@[[NSNumber numberWithUnsignedInteger: shape_size]] - dataType:MPSDataTypeInt32]; - - MPSGraphTensor* storageOffsetTensor = [mpsGraph constantWithScalar:storage_offset - dataType:MPSDataTypeInt32]; - MPSGraphTensor* strideTensor = [mpsGraph constantWithScalar:stride[shape_size - 1] - dataType:MPSDataTypeInt32]; - MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:-1 - withShapeTensor:shapeTensor - name:nil]; - MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor - secondaryTensor:strideTensor - name:nil]; - MPSGraphTensor* indicesTensor = indexTensor; - // create stride Tensors for each rank of the input tensor - for (int i = 1; i < shape_size; i++) { - strideTensor = [mpsGraph constantWithScalar:stride[shape_size - i - 1] - dataType:MPSDataTypeInt32]; - MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) - withShapeTensor:shapeTensor - name:nil]; - MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor - secondaryTensor:strideTensor - name:nil]; - indicesTensor = [mpsGraph additionWithPrimaryTensor:indexTensor - secondaryTensor:indicesTensor - name:nil]; - } - indicesTensor = [mpsGraph additionWithPrimaryTensor:indicesTensor - secondaryTensor:storageOffsetTensor - name:nil]; - MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor:inputTensor - withShape:@[@-1] - name:nil]; - MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor:indicesTensor - withShape:@[@-1] - name:nil]; - // Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor - MPSGraphTensor *gatheredTensor = [mpsGraph gatherWithUpdatesTensor:reshapedInputTensor - indicesTensor:reshapedIndicesTensor - axis:0 - batchDimensions:0 - name:nil]; - // Reshape the data to desired size - outputTensor = [mpsGraph reshapeTensor:gatheredTensor - withShapeTensor:shapeTensor - name:nil]; - } - return outputTensor; -} - - -// There are few cases we need to consider: -// Here nodes are the Tensors and the edges are the operations performed on the -// Tensor. As a result of the operation performed we can have result as View -// Tensor (View T) or a Non view tensor (NonView T). The difference is if its -// mapped by the same underlying storage ptr or a new MTLBuffer was allocated. -// T = Tensor -// ---------- -// | Orig T | -// ---------- -// / | \ -// View T View T NonView T -// / / \ | -// View T / \ | -// | / \ | -// | / \ | -// | / \ | -// NonView T NonView T -Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, - IntArrayRef stride, - optional storage_offset_) { - using namespace mps; - // Use the size and stride to create a unique key - auto result = detail::make_tensor( - c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); - auto storage_offset = storage_offset_.value_or(self.storage_offset()); - setStrided(result, size, stride, storage_offset); - - // 0 sizes won't result in any change in the shape of the Tensor so we can - // skip it. Also if the memory is contiguous we don't need to do - // gather-scatter operations using graph. - if (size.size() > 0) { - - // If self itself was a view tensor, that means we need to chain the graphs - // else we will create a new entry in the cache - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = mps::getStridedKey(self, size, stride, storage_offset); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if (!cachedGraph) { - // Check if this stride operation is already performed on a strided tensor - auto origKey = getStridedKey(self, self.sizes(), self.strides(), self.storage_offset()); - auto origGraph = static_cast(cache_->LookUp(origKey)); - cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - // Self is the input tensor we are creating view of, which can also be a stride - // In later case, preserve shape from the original graph - auto shape = origGraph ? [origGraph->inputTensor_ shape] : getMPSShape(self); - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), shape); - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size, stride, - storage_offset, inputTensor, self); - } - return newCachedGraph; - }, self.storage().data()); - } - } - } - return result; -} - namespace mps { void* pageAlignedBlockPtr( @@ -173,15 +33,6 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, return (void*)alignedAddress; } -static bool copy_requires_temporaries(const Tensor& dst, const Tensor& src) { - bool same_dtype = src.dtype() == dst.dtype(); - if (same_dtype && src.is_contiguous() && dst.is_contiguous()) { - return false; - } else { - return true; - } -} - // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type(). // The shapes and dtypes are taken from dst and src, but their storage pointers are not used. void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, @@ -236,14 +87,10 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, } } -static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, - bool non_blocking) { - - using namespace mps; +static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) +{ id device = MPSDevice::getInstance()->device(); MPSStream* stream = getCurrentMPSStream(); - uint64_t size = src_.nbytes(); - if (size == 0) return dst_; Tensor dst; Tensor src; if (!dst_.is_contiguous()) { @@ -253,65 +100,51 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, } auto storage_byte_offset = src_.storage_offset() * src_.itemsize(); - id sourceBuffer = __builtin_bit_cast(id, src_.storage().data()); if (!src_.is_contiguous()) { - id gatherTensor = gatherViewTensor(src_, sourceBuffer); - if (gatherTensor) { - sourceBuffer = gatherTensor; + Tensor emptyShell = Tensor(); + src = gatherViewTensor(src_, emptyShell); + if (src.has_storage()) { storage_byte_offset = 0; } else { src = src_.expand_as(dst).contiguous(); - sourceBuffer = __builtin_bit_cast(id, src.storage().data()); storage_byte_offset = src.storage_offset() * src.itemsize(); } } else { src = src_; } - - void* host_dst = dst.storage().data(); - - if (sourceBuffer == nil) return dst_; - NSUInteger destOffset = dst.storage_offset() * dst.itemsize(); + id sourceBuffer = getMTLBufferStorage(src); + size_t src_total_size = src_.is_view() ? at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()) : + src.nbytes(); + size_t size_to_copy = src.nbytes(); // In case of dtype change, first convert src inplace if (src_.dtype() != dst_.dtype()) { - copy_cast_mps(dst_, src_, sourceBuffer, sourceBuffer); + copy_cast_mps(dst, src, sourceBuffer, sourceBuffer); + // Use the element size of dst to calculate the total size after casting + size_to_copy = (size_to_copy / src.element_size()) * dst.element_size(); } + // If there's anything wrong with source, we shouldn't return dst_ silently and must error out. + TORCH_INTERNAL_ASSERT(sourceBuffer && size_to_copy > 0); + TORCH_INTERNAL_ASSERT(src_total_size >= storage_byte_offset); + TORCH_INTERNAL_ASSERT(dst.nbytes() >= (dst.storage_offset() * dst.element_size())); + @autoreleasepool { MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared; NSUInteger alignedLength = 0; - void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)size, &alignedLength); + void* host_dst = dst.storage().data(); + void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_total_size, &alignedLength); id destBuffer = [device newBufferWithBytesNoCopy:alignedPtr length:alignedLength options:options deallocator:nil]; - destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr); + NSUInteger destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr); // 4 bytes alignment required on macos for blits. - TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request"); - - dispatch_sync(stream->queue(), ^() { - @autoreleasepool { - id commandBuffer = stream->commandBuffer(); - id blitEncoder = - [commandBuffer blitCommandEncoder]; - - [blitEncoder copyFromBuffer:sourceBuffer - sourceOffset:(NSUInteger)storage_byte_offset - toBuffer:destBuffer - destinationOffset:(NSUInteger)destOffset - size:(NSUInteger)size]; - [blitEncoder endEncoding]; - - if (non_blocking) { - stream->commit(true); - } else { - stream->commitAndWait(); - } - [destBuffer release]; - } - }); + TORCH_INTERNAL_ASSERT(destOffset % 4 == 0, "Unaligned blit request"); + + stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking); + [destBuffer release]; } if (!dst.is_same(dst_)) { dst_.copy_(dst, non_blocking); @@ -320,36 +153,42 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, return dst_; } -static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, - bool non_blocking) { +static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) +{ MPSStream* stream = getCurrentMPSStream(); Tensor dst; Tensor src; id device = MPSDevice::getInstance()->device(); auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize(); - id destBuffer = __builtin_bit_cast(id, dst_.storage().data()); - + id destBuffer = getMTLBufferStorage(dst_); + uint64_t src_total_size = 0; if (src_.is_view()) { src = src_.to(dst_.dtype()).expand_as(dst_).contiguous(); + // Get the actual size of a View (takes into account the storage offset) + // For View tensors, the storage offset can be bigger than what's being reported by nbytes + src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()); } else { src = src_; if (src.dtype() != dst_.dtype()) { // In case of dtype change, perform conversion on source device src = src.to(dst_.dtype()); } + src_total_size = src.nbytes(); } + const size_t size_to_copy = src.nbytes(); const void* host_src = src.storage().data(); - uint64_t size = src.nbytes(); + TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size())); + TORCH_INTERNAL_ASSERT(dst_.nbytes() >= dst_byte_offset); NSUInteger sourceOffset = 0; @autoreleasepool { MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared; NSUInteger alignedLength = 0; - void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size, &alignedLength); + void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength); id sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr length:alignedLength options:options @@ -358,25 +197,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, if (src_.is_view() || !src_.is_contiguous()) sourceOffset += src_.storage_offset() * src_.itemsize(); - dispatch_sync(stream->queue(), ^() { - @autoreleasepool { - id commandBuffer = stream->commandBuffer(); - id blitEncoder = - [commandBuffer blitCommandEncoder]; - - [blitEncoder copyFromBuffer:sourceBuffer - sourceOffset:(NSUInteger)sourceOffset - toBuffer:destBuffer - destinationOffset:(NSUInteger)dst_byte_offset - size:(NSUInteger)size]; - [blitEncoder endEncoding]; - if (non_blocking) { - stream->commit(true); - } else { - stream->commitAndWait(); - } - } - }); + stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking); [sourceBuffer release]; } @@ -385,68 +206,53 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, void copy_blit_mps(void* dst, const void* src, size_t size) { MPSStream* stream = getCurrentMPSStream(); - id sourceBuffer = (id)(src); - id destBuffer = (id)(dst); - dispatch_sync(stream->queue(), ^() { - @autoreleasepool { - id commandBuffer = stream->commandBuffer(); - id blitEncoder = - [commandBuffer blitCommandEncoder]; - - [blitEncoder copyFromBuffer:sourceBuffer - sourceOffset:0 - toBuffer:destBuffer - destinationOffset:0 - size:size]; - [blitEncoder endEncoding]; - stream->commitAndWait(); - } - }); + stream->copy_and_sync((id)(src), (id)(dst), size, 0, 0, true); } - -static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, - bool non_blocking) { - uint64_t size = src_.nbytes(); +static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) +{ auto src_byte_offset = src_.storage_offset() * src_.itemsize(); - id sourceBuffer = __builtin_bit_cast(id, src_.storage().data()); + auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize(); + + // If dst is contiguous and there is no byte offset, we can save directly the result of + // gather into dst. This reduces the overhead of doing an additional blit for most cases + bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset); Tensor src; + if (!src_.is_contiguous()) { - id gatherTensor = gatherViewTensor(src_, sourceBuffer); - if (gatherTensor) { - sourceBuffer = gatherTensor; + Tensor emptyShell = Tensor(); + src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell); + + if (src.has_storage()) { + if (returnGatherOutput) + return dst_; + src_byte_offset = 0; } else { src = src_.expand_as(dst_).contiguous(); - sourceBuffer = __builtin_bit_cast(id, src.storage().data()); src_byte_offset = src.storage_offset() * src.itemsize(); } } else { src = src_; } + // Scatter to `dst` if the memory is not contiguous + // If the memory is not contiguous, it means that the tensor has strides and we would not be + // able to do the copy using a single blit + if (!dst_.is_contiguous()) { + return scatterViewTensor(src, dst_); + } src._set_conj(src_.is_conj()); src._set_neg(src_.is_neg()); - auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize(); - id destBuffer = __builtin_bit_cast(id, dst_.storage().data()); - + id destBuffer = getMTLBufferStorage(dst_); + id sourceBuffer = getMTLBufferStorage(src); + const size_t src_size = src.nbytes(); if (src.dtype() == dst_.dtype()) { MPSStream* stream = getCurrentMPSStream(); - dispatch_sync(stream->queue(), ^() { - @autoreleasepool { - id commandBuffer = stream->commandBuffer(); - id blitEncoder = [commandBuffer blitCommandEncoder]; - [blitEncoder copyFromBuffer:sourceBuffer - sourceOffset:src_byte_offset - toBuffer:destBuffer - destinationOffset:dst_byte_offset - size:size]; - [blitEncoder endEncoding]; - stream->commitAndWait(); - } - }); + // for GPU to GPU copies we only encode to stream's command buffer (no flushing) + stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset); } else { - copy_cast_mps(dst_, src_, destBuffer, sourceBuffer); + copy_cast_mps(dst_, src, destBuffer, sourceBuffer); } return dst_; } diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 3acab4a1b9af6..a4b73bd75fb03 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -11,11 +11,9 @@ #include #include #include + namespace at { namespace native { -namespace templates { - -} Tensor& uniform_mps_(Tensor& input, double from, double to, c10::optional gen_) { @@ -127,6 +125,62 @@ return normal_mps_out(mean_t, std_t, gen, self); } +Tensor normal_mps(const Tensor& mean, double std, c10::optional gen) { + Tensor output = empty_mps( + mean.sizes(), + mean.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + Tensor std_t = empty_mps( + output.sizes(), + output.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + std_t.fill_(std); + + return normal_mps_out(mean, std_t, gen, output); +} + +Tensor normal_mps(double mean, const Tensor& std, c10::optional gen) { + Tensor output = empty_mps( + std.sizes(), + std.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + Tensor mean_t = empty_mps( + output.sizes(), + output.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + mean_t.fill_(mean); + + return normal_mps_out(mean_t, std, gen, output); +} + +Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen) { + auto shape = at::infer_size(mean.sizes(), std.sizes()); + + Tensor output = empty_mps( + shape, + mean.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + return normal_mps_out(mean, std, gen, output); +} + Tensor& normal_mps_out(const Tensor& mean, double std, c10::optional gen, Tensor& output) { TORCH_CHECK(std >= 0.0, "normal_mps_out expects std >= 0.0, but found std=", std); diff --git a/aten/src/ATen/native/mps/operations/Equal.cpp b/aten/src/ATen/native/mps/operations/Equal.cpp new file mode 100644 index 0000000000000..93c1be87e88a3 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Equal.cpp @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at { +namespace mps { +TORCH_API at::Tensor eq(const at::Tensor & self, const at::Tensor & other); +} // namespace +namespace native { + +bool mps_equal(const Tensor& self, const Tensor &src) { + if (!at::namedinference::are_names_equal( + self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) { + return false; + } + at::NoNamesGuard guard; + TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on " + "different devices. Got: ", self.device(), " and ", src.device()); + if (self.sizes() != src.sizes()) { + return false; + } + if (self.numel() == 0) { + return true; + } + return at::mps::eq(self, src).all().item().to(); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index b11ecd1bf8992..7c0d7544cf21a 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,183 @@ namespace at { namespace native { +Tensor flip_mps(const Tensor& self, IntArrayRef dims) { + using namespace mps; + + Tensor result = at::native::empty_mps( + self.sizes(), + self.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + auto total_dims = self.dim(); + // It wraps the dims and checks that there are no repeated dims + auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims); + NSMutableArray * ns_dims = [NSMutableArray new]; + + for (const auto i : c10::irange(total_dims)) { + if(flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) { + [ns_dims addObject:[NSNumber numberWithInt:i]]; + } + } + + // Nothing to do, we return fast + if (dims.size() == 0 || self.numel() <=1) { + result.copy_(self); + return result; + } + + MPSStream* stream = getCurrentMPSStream(); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","]; + // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph + string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor + axes:ns_dims + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); + + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + // Run the graph + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + + return result; + +} + +TORCH_IMPL_FUNC(index_add_mps_out)( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + const Scalar& alpha, + const Tensor& result) { + + using namespace mps; + MPSStream* stream = getCurrentMPSStream(); + dim = maybe_wrap_dim(dim, self.dim()); + auto numel = index.numel(); + auto alpha_f = alpha.to(); + + if (numel == 0) { + return; + } + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* indexTensor_ = nil; + MPSGraphTensor* sourceTensor_ = nil; + MPSGraphTensor* alphaTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + + string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); + MPSGraphTensor* sourceTensor = mpsGraphRankedPlaceHolder(mpsGraph, source); + MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder(mpsGraph, alpha_f); + MPSGraphTensor* inputSlice = [mpsGraph gatherWithUpdatesTensor:inputTensor + indicesTensor:indexTensor + axis:dim + batchDimensions:0 + name:nil]; + MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:sourceTensor + secondaryTensor:alphaTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:inputTensor + updatesTensor:alphaSourceSlice + indicesTensor:indexTensor + axis:dim + mode:MPSGraphScatterModeAdd + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->indexTensor_ = indexTensor; + newCachedGraph->sourceTensor_ = sourceTensor; + newCachedGraph->alphaTensor_ = alphaTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index); + Placeholder sourcePlaceholder = Placeholder(cachedGraph->sourceTensor_, source); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); + + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData(), + sourcePlaceholder.getMPSGraphTensor() : sourcePlaceholder.getMPSGraphTensorData(), + cachedGraph->alphaTensor_ : getMPSGraphTensorFromScalar(stream, alpha_f, MPSDataTypeFloat32) + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} + Tensor index_select_mps(const Tensor & self, int64_t dim, const Tensor & index) { diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index d9cad62ee27c4..a6710ea5fc2a5 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -25,6 +25,10 @@ Tensor _mps_linear( using namespace mps; + TORCH_CHECK(input.scalar_type() == ScalarType::Double + || input.scalar_type() == ScalarType::Float + || input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); + // See [Note: hacky wrapper removal for optional tensor] auto bias = bias_opt.has_value() ? c10::MaybeOwned::borrowed(*bias_opt) @@ -156,6 +160,10 @@ Tensor _mps_linear_backward_input( TORCH_CHECK(weight.device().is_mps() && weight.scalar_type() == kFloat, "mps_linear_backward: weight needs to be a dense tensor"); + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double + || grad_output.scalar_type() == ScalarType::Float + || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); + const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous(); struct CachedGraph : public mps::MPSCachedGraph @@ -232,6 +240,10 @@ Tensor _mps_linear_backward_input( TORCH_CHECK(grad_output.is_mps() && input.is_mps(), "_mps_linear_backward: grad_output and input needs to be mps layout"); + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double + || grad_output.scalar_type() == ScalarType::Float + || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); + struct CachedGraph : public mps::MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -242,9 +254,9 @@ Tensor _mps_linear_backward_input( MPSGraphTensor *biasTensor_ = nil; }; - auto grad_output_reshaped = grad_output.dim() > 2 ? + auto grad_output_reshaped = grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output; - auto input_reshaped = input.dim() > 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input; + auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input; TORCH_CHECK(grad_output_reshaped.is_mps()); TORCH_CHECK(input_reshaped.is_mps()); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 22e5fa822b36a..8b69c65c17fae 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -100,6 +100,9 @@ void prepare_matrices_for_broadcasting( Tensor& output) { using namespace mps; TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(self.scalar_type() == ScalarType::Double + || self.scalar_type() == ScalarType::Float + || self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs"); TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}}; checkAllSameGPU("mm", args); @@ -208,6 +211,9 @@ void prepare_matrices_for_broadcasting( TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(self.scalar_type() == ScalarType::Double + || self.scalar_type() == ScalarType::Float + || self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input"); TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -246,7 +252,7 @@ void prepare_matrices_for_broadcasting( bool transpose_mat1 = false; bool transpose_mat2 = false; - prepare_matrices_for_broadcasting(&bias, self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2); + prepare_matrices_for_broadcasting(&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2); struct CachedGraph : public mps::MPSCachedGraph { @@ -260,7 +266,7 @@ void prepare_matrices_for_broadcasting( mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { - string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, bias}) + string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2) + ":" + to_string(beta.toDouble()) + ":" + to_string(alpha.toDouble()); @@ -276,7 +282,7 @@ void prepare_matrices_for_broadcasting( MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor *biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, bias); + MPSGraphTensor *biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_); MPSGraphTensor* t1 = nil; MPSGraphTensor* t2 = nil; @@ -306,7 +312,7 @@ void prepare_matrices_for_broadcasting( // Intermediates for beta and alpha MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble() - dataType:getMPSScalarType(bias.scalar_type())]; + dataType:getMPSScalarType((*bias_).scalar_type())]; MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble() dataType:getMPSScalarType(self.scalar_type())]; @@ -340,7 +346,7 @@ void prepare_matrices_for_broadcasting( Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); - Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias); + Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, *bias_); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); NSDictionary* feeds = @{ @@ -366,6 +372,10 @@ void prepare_matrices_for_broadcasting( Tensor & result) { using namespace mps; + TORCH_CHECK(batch1.scalar_type() == ScalarType::Double + || batch1.scalar_type() == ScalarType::Float + || batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs"); + if (batch1.numel() == 0 || batch2.numel() == 0) { return result; } @@ -444,6 +454,10 @@ void prepare_matrices_for_broadcasting( TORCH_CHECK(batch2.is_mps()); TORCH_CHECK(result.is_mps()); + TORCH_CHECK(batch1.scalar_type() == ScalarType::Double + || batch1.scalar_type() == ScalarType::Float + || batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs"); + TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); TORCH_CHECK(batch1.size(0) == batch2.size(0), diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 7f08968bf07cc..454a9512c23ab 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -314,10 +314,10 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target, // NLLLoss void nllnd_loss_backward_impl( -Tensor& grad_input, +Tensor& grad_input_arg, const Tensor& grad_output, -const Tensor& input, -const Tensor& target, +const Tensor& input_arg, +const Tensor& target_arg, const Tensor& weight, int64_t reduction, int64_t ignore_index, @@ -325,7 +325,7 @@ void nllnd_loss_backward_impl( bool is2D) { // Empty output - if(grad_input.numel() == 0) + if(grad_input_arg.numel() == 0) return; MPSStream* stream = getCurrentMPSStream(); @@ -342,6 +342,10 @@ void nllnd_loss_backward_impl( MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; + auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; + auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg; + @autoreleasepool { auto numClasses = grad_input.sizes()[1]; @@ -472,24 +476,24 @@ void nllnd_loss_backward_impl( void nllnd_loss_forward_impl (Tensor& output, Tensor& total_weight, - const Tensor& input, - const Tensor& target, + const Tensor& input_arg, + const Tensor& target_arg, const Tensor& weight, int64_t reduction, int64_t ignore_index, bool is2D) { - std::vector reshapedTarget(target.sizes().begin(), target.sizes().end()); + std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); reshapedTarget.push_back(1); - Tensor batchSizeTensor = at::empty_like(input).resize_(IntArrayRef(1)); + Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1)); float batchVal = 1.0f; for(size_t i = 0; i < reshapedTarget.size(); ++i) batchVal *= reshapedTarget[i]; batchSizeTensor[0] = batchVal; if(reduction == Reduction::None) - output.resize_(target.sizes()); + output.resize_(target_arg.sizes()); if(reduction == Reduction::Sum) output.resize_({}); if(reduction == Reduction::Mean) @@ -516,6 +520,9 @@ void nllnd_loss_backward_impl( MPSStream* stream = getCurrentMPSStream(); + auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; + auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; + @autoreleasepool { bool isWeightsArrayValid = (weight.numel() > 0); @@ -764,12 +771,12 @@ void smooth_l1_loss_impl( MPSGraphTensor *mpsGraphOneTensor = [mpsGraph constantWithScalar: 1.0 dataType: inputTensor.dataType]; MPSGraphTensor *mpsGraphHalfTensor = [mpsGraph constantWithScalar: 0.5 - dataType: MPSDataTypeFloat32]; + dataType: inputTensor.dataType]; MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta - dataType: MPSDataTypeFloat32]; + dataType: inputTensor.dataType]; // 0.5 * beta MPSGraphTensor *halfTensorMulBetaTensor = [mpsGraph constantWithScalar: beta * 0.5 - dataType: MPSDataTypeFloat32]; + dataType: inputTensor.dataType]; // Calculating first part of the equation: // ln = 0.5(xn - yn)^2/beta, if |xn - yn| < beta @@ -1004,7 +1011,7 @@ void smooth_l1_loss_backward_impl( NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder .getMPSGraphTensorData() + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() }; NSDictionary* results = @{ gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() @@ -1035,6 +1042,244 @@ void smooth_l1_loss_backward_template( // APIs exposed to at::native scope +// HuberLoss + +Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output){ + string op_name = __func__; + using namespace mps; + TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") + TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") + TORCH_CHECK(output.is_mps()); + + if(reduction == Reduction::None) + output.resize_(target.sizes()); + if(reduction == Reduction::Sum) + output.resize_({}); + if(reduction == Reduction::Mean) + output.resize_({}); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + string key = op_name + ":" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" + getTensorsStringKey({input, target}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta + shape:@[@1] + dataType:MPSDataTypeFloat32]; + MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:.5f + shape:@[@1] + dataType:MPSDataTypeFloat32]; + + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor + secondaryTensor: targetTensor + name: nil]; + MPSGraphTensor* absDiffTensor = [mpsGraph absoluteWithTensor: diffTensor + name: nil]; + MPSGraphTensor* firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor: absDiffTensor + secondaryTensor: absDiffTensor + name: nil]; + firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor: firstCondTensor + secondaryTensor: halfTensor + name: nil]; + MPSGraphTensor* secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor: deltaTensor + secondaryTensor: halfTensor + name: nil]; + secondCondTensor = [mpsGraph subtractionWithPrimaryTensor: absDiffTensor + secondaryTensor: secondCondTensor + name: nil]; + secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor: deltaTensor + secondaryTensor: secondCondTensor + name: nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph lessThanOrEqualToWithPrimaryTensor:absDiffTensor + secondaryTensor:deltaTensor + name:nil] + truePredicateTensor: firstCondTensor + falsePredicateTensor: secondCondTensor + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->outputTensor_ = reduceTensor(outputTensor, reduction, mpsGraph, input.sizes().size()); + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + return output; +} + +Tensor huber_loss_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta) { + TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta."); + Tensor output = at::native::empty_mps( + input.sizes(), + input.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + return huber_loss_out_mps(input, target, reduction, delta, output); +} + +Tensor& huber_loss_backward_out_mps( + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double delta, + Tensor& grad_input) +{ + using namespace mps; + auto is_mean_reduction = reduction == Reduction::Mean; + auto input_numel = input.numel(); + + auto new_grad_output = grad_output.contiguous(); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *targetTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + }; + + MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + + MPSStream *stream= getCurrentMPSStream(); + + @autoreleasepool { + MPSShape* input_shape = getMPSShape(input); + NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; + + string key = "huber_loss_backward_out_mps:" + reductionToString(reduction) + ":" + + std::to_string(delta) + ":" + + [ns_shape_key UTF8String] + ":" + + getMPSTypeString(input.scalar_type()) + ":" + + getMPSTypeString(target.scalar_type()); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + // Initialize graph + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(new_grad_output.scalar_type()), getMPSShape(new_grad_output)); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target.scalar_type()), getMPSShape(target)); + MPSGraphTensor* isMeanReductionTensor = [mpsGraph constantWithScalar:is_mean_reduction + dataType:MPSDataTypeInt64]; // constant does not support MPSDataTypeBool + MPSGraphTensor* inputNumelTensor = [mpsGraph constantWithScalar:input_numel + dataType:getMPSDataType(new_grad_output.scalar_type())]; + + MPSGraphTensor* normGradOutputTensor = [mpsGraph selectWithPredicateTensor:isMeanReductionTensor + truePredicateTensor: [mpsGraph divisionWithPrimaryTensor:gradOutputTensor + secondaryTensor:inputNumelTensor + name:nil] + falsePredicateTensor: gradOutputTensor + name:nil]; + MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta + shape:getMPSShape(target) + dataType:MPSDataTypeFloat32]; + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor + secondaryTensor:targetTensor + name:nil]; + MPSGraphTensor* normGradOutputDeltaTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor + secondaryTensor:deltaTensor + name:nil]; + // first condition: (input - target) <= -delta + // formula: -norm * grad_output * delta + MPSGraphTensor* firstCondTensor = [mpsGraph negativeWithTensor: normGradOutputDeltaTensor + name: nil]; + // second condition: (input - target) >= delta + // formula: norm * grad_output * delta + MPSGraphTensor* secondCondTensor = normGradOutputDeltaTensor; + + // third condition: (input - target) within -delta to delta + // formula: norm * (input - target) * grad_output + MPSGraphTensor* thirdCondTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor + secondaryTensor:diffTensor + name:nil]; + + MPSGraphTensor* secondThirdTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph greaterThanOrEqualToWithPrimaryTensor:diffTensor + secondaryTensor:deltaTensor + name:nil] + truePredicateTensor: secondCondTensor + falsePredicateTensor: thirdCondTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph lessThanOrEqualToWithPrimaryTensor: diffTensor + secondaryTensor:[mpsGraph negativeWithTensor: deltaTensor + name: nil] + name:nil] + truePredicateTensor: firstCondTensor + falsePredicateTensor: secondThirdTensor + name:nil]; + + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + })); + } + + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, new_grad_output); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return grad_input; +} + // MSELoss TORCH_IMPL_FUNC(mse_loss_out_mps) ( const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& output) { diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 77a284963d6e1..adf10b16fbfa8 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -103,16 +103,9 @@ Tensor _mps_max_pool2d( outputHeight, outputWidth, memory_format); namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; CheckedFrom c = "mps_max_pool2d"; - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); Tensor output_t; @@ -161,7 +154,7 @@ Tensor _mps_max_pool2d( to_string(padW) + ":" + to_string(padH) + ":" + to_string(ceil_mode) + ":" + mem_format_key + mps::getTensorsStringKey({input_t}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -711,7 +704,7 @@ Tensor mps_max_pool2d_backward( to_string(ceil_mode) + ":" + mem_format_key + ":" + to_string(divisor_override_value) + mps::getTensorsStringKey({input}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index cc1ef6ab3fa12..67aeae4ca3cbe 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -13,13 +13,23 @@ namespace at { namespace native { -using namespace std; - enum StdVarType { STANDARD_VARIANCE, STANDARD_DEVIATION }; +enum MPSReductionType { + MAX, + MIN, + AMAX, + AMIN, + SUM, + PROD, + MEAN, + COUNT_NONZERO +}; + + void set_apparent_shapes(NSMutableArray * &apparent_out_shape, NSMutableArray * &apparent_in_shape, int64_t num_reduce_dims, @@ -74,13 +84,15 @@ void set_apparent_shapes(NSMutableArray * &apparent_out_shape, // Helper function to set the axes of reduction void set_axes(NSMutableArray * &axes, int64_t num_reduce_dims, - IntArrayRef& dim, + OptionalIntArrayRef opt_dim, int64_t num_input_dims) { if(num_reduce_dims == 0) { axes = [NSMutableArray arrayWithCapacity:1]; axes[0] = @0; } else { + TORCH_INTERNAL_ASSERT(opt_dim.has_value()); + IntArrayRef dim = opt_dim.value(); axes = [NSMutableArray arrayWithCapacity:num_reduce_dims]; for(int i = 0; i < num_reduce_dims; i++) { axes[i] = [NSNumber numberWithInt:maybe_wrap_dim(dim[i], num_input_dims)]; @@ -90,7 +102,7 @@ void set_axes(NSMutableArray * &axes, // Helper function to prepare axes and tensor shapes void set_axes_and_shapes(const Tensor& input_t, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, NSMutableArray * &axes, NSMutableArray * &apparent_input_shape, NSMutableArray * &apparent_output_shape, @@ -99,13 +111,13 @@ void set_axes_and_shapes(const Tensor& input_t, IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - int64_t num_reduce_dims = dims.size(); + int64_t num_reduce_dims = opt_dims.has_value() ? opt_dims.value().size() : 0; int64_t num_output_dims; num_output_dims = num_reduce_dims == 0 ? 1 : num_input_dims; // Reduction axes - set_axes(axes, num_reduce_dims, dims, input_shape.size()); + set_axes(axes, num_reduce_dims, opt_dims, input_shape.size()); // Shapes set_apparent_shapes(apparent_output_shape, @@ -127,19 +139,22 @@ void set_axes_and_shapes(const Tensor& input_t, void reduction_out_mps (const Tensor& input_t, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, const Tensor& output_t, - string reduction_type, - string func_name) { + MPSReductionType reduction_type, + const std::string& func_name) { IntArrayRef input_shape = input_t.sizes(); - for(int i = 0; i < dim.size(); i++) { - auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); - TORCH_CHECK(wrap_dim < input_shape.size(), - func_name+": reduction dim must be in the range of input shape") + if (opt_dim.has_value()) { + IntArrayRef dim = opt_dim.value(); + for(int i = 0; i < dim.size(); i++) { + auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); + TORCH_CHECK(wrap_dim < input_shape.size(), + func_name+": reduction dim must be in the range of input shape") + } } namespace native_mps = at::native::mps; @@ -149,17 +164,9 @@ void set_axes_and_shapes(const Tensor& input_t, NSMutableArray *apparent_output_shape = nil; NSMutableArray *output_shape = nil; - set_axes_and_shapes(input_t, dim, axes, apparent_input_shape, apparent_output_shape, output_shape); - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape); - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + auto cache_ = native_mps::MPSGraphCache::getInstance(); if (output_t.numel() == 0 || input_t.numel() == 0) { return; @@ -172,7 +179,8 @@ void set_axes_and_shapes(const Tensor& input_t, // TODO: Make this key proper NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","]; string key = func_name+":" + string([ns_key UTF8String]) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":" + native_mps::getMPSTypeString(output_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + using CachedGraph = native_mps::MPSUnaryCachedGraph; + auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -196,19 +204,19 @@ void set_axes_and_shapes(const Tensor& input_t, MPSGraphTensor* castOutputTensor = nil; - if(reduction_type == "sum") { + if(reduction_type == MPSReductionType::SUM) { castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:axes name:nil]; - } else if(reduction_type == "prod") { + } else if(reduction_type == MPSReductionType::PROD) { castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor axes:axes name:nil]; - } else if(reduction_type == "mean") { + } else if(reduction_type == MPSReductionType::MEAN) { castOutputTensor = [mpsGraph meanOfTensor:inputTensor axes:axes name:nil]; - } else if(reduction_type == "count_nonzero") { + } else if(reduction_type == MPSReductionType::COUNT_NONZERO) { MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0 dataType:castInputTensor.dataType]; @@ -220,11 +228,11 @@ void set_axes_and_shapes(const Tensor& input_t, axes:axes name:nil]; } - else if(reduction_type == "amax") { + else if(reduction_type == MPSReductionType::AMAX) { castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor axes:axes name:nil]; - } else if(reduction_type == "amin") { + } else if(reduction_type == MPSReductionType::AMIN) { castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor axes:axes name:nil]; @@ -244,7 +252,7 @@ void set_axes_and_shapes(const Tensor& input_t, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(); @@ -268,12 +276,12 @@ void set_axes_and_shapes(const Tensor& input_t, TORCH_IMPL_FUNC(sum_out_mps) (const Tensor& input_t, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "sum", "sum_out_mps"); + reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } TORCH_IMPL_FUNC(prod_out_mps) @@ -285,7 +293,7 @@ void set_axes_and_shapes(const Tensor& input_t, int64_t dims[1] = {dim}; - reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, "prod", "prod_out_mps"); + reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps"); } // Taken from ReduceOps.cpp @@ -309,7 +317,7 @@ inline ScalarType get_dtype_from_self( bool keepdim, const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps"); + reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); } TORCH_IMPL_FUNC(amin_out_mps) @@ -318,7 +326,7 @@ inline ScalarType get_dtype_from_self( bool keepdim, const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps"); + reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); } Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { @@ -338,7 +346,7 @@ Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { c10::nullopt, c10::nullopt); - reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast(output_t), "prod", "prod_mps"); + reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast(output_t), MPSReductionType::PROD, "prod_mps"); return output_t; } @@ -365,7 +373,7 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ c10::nullopt, c10::nullopt); - reduction_out_mps(self, dims, false, self.scalar_type(), const_cast(output_t), "count_nonzero", "count_nonzero_mps"); + reduction_out_mps(self, dims, false, self.scalar_type(), const_cast(output_t), MPSReductionType::COUNT_NONZERO, "count_nonzero_mps"); free(raw_output_shape); @@ -374,140 +382,12 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ TORCH_IMPL_FUNC(mean_out_mps) (const Tensor& input_t, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "mean", "mean_out_mps"); -} - -TORCH_IMPL_FUNC(argmax_out_mps) - (const Tensor& input_t, - c10::optional dim, - bool keepdim, - const Tensor& output_t) { - - namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - int64_t dim_; - - if (dim.has_value()) { - dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); - native::zero_numel_check_dims(input_t, dim_, "argmax()"); - } else { - TORCH_CHECK_INDEX( - input_t.numel() != 0, - "argmax()", ": Expected reduction dim to be specified for input.numel() == 0."); - // Since input will be flattened, take argmax along 0'th dimension - dim_ = 0; - } - - // Calculate the output shape according to keepdim=True - // If there is no dim argument, the input shape is flattened - IntArrayRef input_shape = input_t.sizes(); - int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_in_shape = nil; - NSMutableArray *apparent_out_shape = nil; - - if(dim.has_value()) { - apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; - for(int i = 0; i < num_input_dims; i++) { - if(dim_ == i) - apparent_out_shape[i] = @1; - else - apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; - } - } - else { - apparent_in_shape = [NSMutableArray arrayWithCapacity:1]; - int64_t num_in_elements = 1; - for(int i = 0; i < num_input_dims; i++) { - num_in_elements *= input_shape[i]; - } - apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements]; - - apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; - apparent_out_shape[0] = @1; - } - - if (output_t.numel() == 0) { - return; - } - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - string key = "argmax_out_mps:" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); - - MPSGraphTensor* castInputTensor = nil; - - if(input_t.scalar_type() != ScalarType::Float && - input_t.scalar_type() != ScalarType::Int && - input_t.scalar_type() != ScalarType::Half) - castInputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeFloat32 - name:@"castInputTensor"]; - else - castInputTensor = inputTensor; - - MPSGraphTensor* argmaxOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:@"argmax_out"]; - MPSGraphTensor* outputTensor = [mpsGraph castTensor:argmaxOutTensor - toType:MPSDataTypeInt64 - name:@"cast_out"]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - native_mps::Placeholder inputPlaceholder = native_mps::Placeholder(); - if(apparent_in_shape) - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape); - else - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - - auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - - } - + reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps"); } TORCH_IMPL_FUNC(norm_out_mps) @@ -529,13 +409,7 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ namespace native_mps = at::native::mps; CheckedFrom c = "norm_out_mps"; - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + using CachedGraph = native_mps::MPSUnaryCachedGraph; native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); @@ -576,7 +450,7 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info; - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -649,7 +523,7 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(); @@ -683,6 +557,7 @@ Tensor std_var_common_impl_mps( StdVarType stdVarType) { namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); @@ -704,15 +579,6 @@ Tensor std_var_common_impl_mps( const auto correction_value = use_correction ? correction.value() : false; int64_t correction_n = 1; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); int64_t num_output_dims = 0; @@ -856,9 +722,9 @@ Tensor std_var_common_impl_mps( string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":" + bessel_corrected; + string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected; - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -948,19 +814,12 @@ Tensor std_mps( const Tensor& output_t) { namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { return; } - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "any()"); @@ -983,7 +842,7 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = native_mps::getMPSShape(input_t); string key = string("any_out_mps:") + native_mps::getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -1026,7 +885,7 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1046,26 +905,19 @@ Tensor std_mps( TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t) { namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { return; } - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + auto cache_ = native_mps::MPSGraphCache::getInstance(); auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { MPSShape* input_t_shape = native_mps::getMPSShape(input_t); string key = string("any_all_out_mps:") + native_mps::getMPSShapeString(input_t_shape) +":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -1133,19 +985,12 @@ Tensor std_mps( const Tensor& output_t) { namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { return; } - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "all()"); @@ -1168,7 +1013,7 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = native_mps::getMPSShape(input_t); string key = string("all_out_mps:") + native_mps::getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -1211,7 +1056,7 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1231,18 +1076,11 @@ Tensor std_mps( TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t) { namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { return; } - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); auto stream = at::mps::getCurrentMPSStream(); @@ -1250,7 +1088,7 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = native_mps::getMPSShape(input_t); string key = string("all_all_out_mps:") + native_mps::getMPSShapeString(input_t_shape) +":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -1294,7 +1132,7 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1316,18 +1154,11 @@ Tensor std_mps( Tensor min_max_mps (const Tensor& input_t, - string reduction_type, - string func_name) { + MPSReductionType reduction_type, + const std::string& func_name) { namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + using CachedGraph = native_mps::MPSUnaryCachedGraph; native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); @@ -1350,7 +1181,7 @@ Tensor std_mps( @autoreleasepool { string key = func_name + mps::getTensorsStringKey(input_t); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { @@ -1365,11 +1196,11 @@ Tensor std_mps( MPSGraphTensor* outputTensor = nil; - if(reduction_type == "max") + if(reduction_type == MPSReductionType::MAX) outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor axes:@[@0] name:nil]; - else if(reduction_type == "min") + else if(reduction_type == MPSReductionType::MIN) outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor axes:@[@0] name:nil]; @@ -1403,13 +1234,13 @@ Tensor std_mps( // Max entire tensor into scalar result Tensor max_mps(const Tensor& input_t) { - return min_max_mps(input_t, "max", "max_mps"); + return min_max_mps(input_t, MPSReductionType::MAX, "max_mps"); } // Min entire tensor into scalar result Tensor min_mps(const Tensor& input_t) { - return min_max_mps(input_t, "min", "min_mps"); + return min_max_mps(input_t, MPSReductionType::MIN, "min_mps"); } void min_max_out_mps @@ -1418,8 +1249,8 @@ Tensor min_mps(const Tensor& input_t) { bool keepdim, const Tensor& output_t, const Tensor& indices_t, - string reduction_type, - string func_name) { + MPSReductionType reduction_type, + const std::string& func_name) { namespace native_mps = at::native::mps; @@ -1477,11 +1308,11 @@ Tensor min_mps(const Tensor& input_t) { MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); MPSGraphTensor* outputTensor = nil; - if(reduction_type == "max") + if(reduction_type == MPSReductionType::MAX) outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor axis:(NSInteger)dim_ name:nil]; - else if(reduction_type == "min") + else if(reduction_type == MPSReductionType::MIN) outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor axis:(NSInteger)dim_ name:nil]; @@ -1498,11 +1329,11 @@ Tensor min_mps(const Tensor& input_t) { castInputTensor = inputTensor; MPSGraphTensor* argreduceOutTensor = nil; - if(reduction_type == "max") + if(reduction_type == MPSReductionType::MAX) argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"]; - else if(reduction_type == "min") + else if(reduction_type == MPSReductionType::MIN) argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"]; @@ -1550,7 +1381,7 @@ Tensor min_mps(const Tensor& input_t) { int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "max()"); - min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "max", "max_out_mps"); + min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MAX, "max_out_mps"); } // Min out with dim @@ -1564,16 +1395,163 @@ Tensor min_mps(const Tensor& input_t) { int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "min()"); - min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "min", "min_out_mps"); + min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MIN, "min_out_mps"); } +void argmax_argmin_out_mps + (const Tensor& input_t, + c10::optional dim, + bool keepdim, + const Tensor& output_t, + MPSReductionType reduction_type, + const std::string& func_name) { + namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; + + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + + int64_t dim_; + + if (dim.has_value()) { + dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); + zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()"); + } else { + TORCH_CHECK_INDEX( + input_t.numel() != 0, + reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()" , ": Expected reduction dim to be specified for input.numel() == 0."); + // Since input will be flattened, take argmax or argmin along 0'th dimension + dim_ = 0; + } + + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_in_shape = nil; + NSMutableArray *apparent_out_shape = nil; + + if(dim.has_value()) { + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) + apparent_out_shape[i] = @1; + else + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; + } + } + else { + apparent_in_shape = [NSMutableArray arrayWithCapacity:1]; + int64_t num_in_elements = 1; + for(int i = 0; i < num_input_dims; i++) { + num_in_elements *= input_shape[i]; + } + apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements]; + + apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; + apparent_out_shape[0] = @1; + } + + if (output_t.numel() == 0) { + return; + } + + auto stream = at::mps::getCurrentMPSStream(); + + @autoreleasepool { + string key = func_name + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + + if(!cachedGraph) { + native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + + MPSGraphTensor* castInputTensor = nil; + MPSGraphTensor* argreduceOutTensor = nil; + + if(input_t.scalar_type() != ScalarType::Float && + input_t.scalar_type() != ScalarType::Int && + input_t.scalar_type() != ScalarType::Half) + castInputTensor = [mpsGraph castTensor:inputTensor + toType:MPSDataTypeFloat32 + name:@"castInputTensor"]; + else + castInputTensor = inputTensor; + + if (reduction_type == MPSReductionType::MAX) { + argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:nil]; + } + else { + argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:nil]; + } + MPSGraphTensor* outputTensor = [mpsGraph castTensor:argreduceOutTensor + toType:MPSDataTypeInt64 + name:@"castOutpuTensor"]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + native_mps::Placeholder inputPlaceholder = native_mps::Placeholder(); + if(apparent_in_shape) + inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape); + else + inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + + auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); + + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} + +TORCH_IMPL_FUNC(argmax_out_mps) + (const Tensor& input_t, + c10::optional dim, + bool keepdim, + const Tensor& output_t) { + + argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MAX, "argmax_out_mps"); +} + +TORCH_IMPL_FUNC(argmin_out_mps) + (const Tensor& input_t, + c10::optional dim, + bool keepdim, + const Tensor& output_t) { + + argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps"); +} + + // Min/Max with dim std::tuple min_max_mps (const Tensor& input_t, int64_t dim, bool keepdim, - string reduction_type, - string func_name) { + MPSReductionType reduction_type, + const std::string& func_name) { namespace native_mps = at::native::mps; @@ -1661,7 +1639,7 @@ Tensor min_mps(const Tensor& input_t) { int64_t dim, bool keepdim) { - return min_max_mps(input_t, dim, keepdim, "max", "max_mps"); + return min_max_mps(input_t, dim, keepdim, MPSReductionType::MAX, "max_mps"); } // Min with dim @@ -1670,9 +1648,8 @@ Tensor min_mps(const Tensor& input_t) { int64_t dim, bool keepdim) { - return min_max_mps(input_t, dim, keepdim, "min", "min_mps"); -} - + return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps"); } -} +} // native +} // at diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index e709455d84063..b9c465145ffeb 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -44,12 +44,21 @@ void set_apparent_shapes(NSMutableArray * input_shape, int64_t num_repeat_dims) { - // Set repeats_shape + bool repeat_empty = false; + if(num_repeat_dims == 0) { + num_repeat_dims = num_input_dims; + repeat_empty = true; + } + // Set repeats_shape repeats_shape = [NSMutableArray arrayWithCapacity:num_repeat_dims]; - for(int i = 0; i < num_repeat_dims; i++) - repeats_shape[i] = [NSNumber numberWithInt:repeats[i]]; + for(int i = 0; i < num_repeat_dims; i++) { + if(repeat_empty) + repeats_shape[i] = [NSNumber numberWithInteger:1]; + else + repeats_shape[i] = [NSNumber numberWithInteger:repeats[i]]; + } // If no extension of the shape is needed if(num_repeat_dims == num_input_dims) { @@ -115,7 +124,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { c10::nullopt); // Empty output - if(zero_tensor) + if(zero_tensor || output.numel() == 0) return output; auto stream = at::mps::getCurrentMPSStream(); diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm index a219d3f8172cb..0dd1bd6b47a21 100644 --- a/aten/src/ATen/native/mps/operations/RnnOps.mm +++ b/aten/src/ATen/native/mps/operations/RnnOps.mm @@ -153,6 +153,13 @@ name:nil]]; } + MPSGraphTensor* outputTensor = [outputs objectAtIndex:0]; + if (batch_first) { + outputTensor = [mpsGraph transposeTensor:outputTensor + dimension:0 + withDimension:1 + name:nil]; + } MPSGraphTensor* outputStates = [mpsGraph concatTensors:outputStateArray dimension:0 name:nil]; @@ -166,7 +173,7 @@ dimension:0 name:nil]; - std::vector outputTensors = {[outputs objectAtIndex:0], outputStates, outputCellStates, outputZStates, outputCellStatesFwd}; + std::vector outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd}; newCachedGraph->inputTensors_ = inputTensors; newCachedGraph->outputTensors_ = outputTensors; newCachedGraph->kernelWeightsList_ = kernelWeightsList; diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 393d4f1bc57a6..977f9f1ce3fae 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -21,7 +21,7 @@ // Pad operations (1D/2D/3D forward and backward) Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding, const c10::optional& grad_output_opt, - MPSGraphPaddingMode mode, const string op_name) + MPSGraphPaddingMode mode, double constantValue, const string op_name) { const int padding_size = (int) padding.size(); const int padding_dim = padding_size / 2; // either 1D, 2D, or 3D @@ -150,7 +150,7 @@ withPaddingMode:mode leftPadding:leftPadding rightPadding:rightPadding - constantValue:0 + constantValue:constantValue name:nil]; } else { newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); @@ -187,101 +187,116 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_mps) (const Tensor& input, IntArrayRef padding, const Tensor& output) { - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad1d_out_mps"); + mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, + MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps"); } TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps) (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad1d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, + MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps"); } TORCH_IMPL_FUNC(replication_pad1d_out_mps) (const Tensor& input, IntArrayRef padding, const Tensor& output) { - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad1d_out_mps"); + mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, + MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps"); } TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps) (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, MPSGraphPaddingModeClampToEdge, "replication_pad1d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, + MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps"); } // 2D Reflection and Replication Padding Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) { - return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__); + return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__); } Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) { Tensor output = at::empty({0}, input.options()); - return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__); + return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__); } Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__); } Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__); } TORCH_IMPL_FUNC(replication_pad2d_out_mps) (const Tensor& input, IntArrayRef padding, const Tensor& output) { - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad2d_out_mps"); + mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, + MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps"); } Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } // 3D Reflection and Replication Padding TORCH_IMPL_FUNC(reflection_pad3d_out_mps) (const Tensor& input, IntArrayRef padding, const Tensor& output) { - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad3d_out_mps"); + mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, + MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps"); } TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps) (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad3d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, + MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps"); } TORCH_IMPL_FUNC(replication_pad3d_out_mps) (const Tensor& input, IntArrayRef padding, const Tensor& output) { - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad3d_out_mps"); + mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, + MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps"); } Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__); + return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); +} + +// backward pass is exlicitly handled in autograd by negating the "pad" argument +Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) +{ + Tensor output = at::empty({0}, self.options()); + return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__); } // topk @@ -543,13 +558,22 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, } at::assert_no_internal_overlap(out); + // Indices of tensors to be skipped because they're empty + std::vector skipped_tensor_indices; + // Tensors to be read + std::vector input_tensors; + int tensor_idx = 0; for(const Tensor& t : materialized_inputs) { - if (should_skip(t)) { + if(t.numel() == 0 || should_skip(t)) { + skipped_tensor_indices.push_back(tensor_idx); + tensor_idx++; continue; } + input_tensors.push_back(&t); nDims = t.dim(); // TODO: Is this OK? notSkippedTensor = &t; + tensor_idx++; } // If all inputs are empty tensors, return an empty tensor @@ -623,9 +647,19 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + // Make string out of skipped tensor indices + string skipped_indices_string = ""; + for(int idx : skipped_tensor_indices) + skipped_indices_string += (std::to_string(idx)+","); + string input_types = ""; + for(const Tensor& tensor : materialized_inputs) + input_types += (getMPSTypeString(tensor.scalar_type())+","); + @autoreleasepool { string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs)) + ":" + to_string(inputs.size()) + + ":" + skipped_indices_string + + ":" + input_types + ":" + to_string(dimension); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -638,22 +672,44 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, newCachedGraph = new CachedGraph(mpsGraph); // Create placeholders - MPSGraphTensor* inputMPSGraphTensors[inputs.size()]; - - for(int i = 0; i < inputs.size(); i++) - inputMPSGraphTensors[i] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(result_type(inputs))); + auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); + MPSGraphTensor* inputMPSGraphTensors[len_tensor_array]; + MPSGraphTensor* castInputMPSGraphTensors[len_tensor_array]; + + int graph_tensor_idx = 0; + for(const Tensor* tensor : input_tensors) { + inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor->scalar_type()) ); + if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) { + castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx] + toType:MPSDataTypeFloat32 + name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]]; + } + else { + if(tensor->scalar_type() != result_type(inputs)) + castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx] + toType:getMPSDataType(result_type(inputs)) + name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]]; + else + castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx]; + } + graph_tensor_idx++; + } - auto inputTensorsArray = [NSArray arrayWithObjects:inputMPSGraphTensors - count:inputs.size()]; + auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors + count:len_tensor_array]; // Use concatTensors to concatenate MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray dimension:dimension // Maybe convert this from int64_t -> int32 name:nil]; - newCachedGraph->inputMPSGraphTensors_ = (MPSGraphTensor**)malloc(inputs.size() * sizeof(MPSGraphTensor*)); + newCachedGraph->inputMPSGraphTensors_ = (MPSGraphTensor**)malloc(len_tensor_array * sizeof(MPSGraphTensor*)); - for(int i = 0; i < inputs.size(); i++) + for(int i = 0; i < len_tensor_array; i++) newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i]; + if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) + outputTensor = [mpsGraph castTensor:outputTensor + toType:MPSDataTypeBool + name:@"outputTensor"]; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -663,16 +719,20 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, std::vector inputPlaceholders; int i = 0; + int t_idx = 0; for(const Tensor& tensor : materialized_inputs) { - Placeholder currentInputPlaceholder = Placeholder(cachedGraph->inputMPSGraphTensors_[i], tensor); - inputPlaceholders.push_back(currentInputPlaceholder); + if(std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) { + Placeholder currentInputPlaceholder = Placeholder(cachedGraph->inputMPSGraphTensors_[t_idx], tensor); + inputPlaceholders.push_back(currentInputPlaceholder); + t_idx++; + } i++; } Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; - for (int i = 0; i < inputs.size(); i++) { + for (int i = 0; i < inputPlaceholders.size(); i++) { feeds[(inputPlaceholders[i]).getMPSGraphTensor()] = (inputPlaceholders[i]).getMPSGraphTensorData(); } NSDictionary* results = @{ @@ -916,5 +976,77 @@ void upsample_out_mps(const Tensor& input, using namespace mps; upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeBilinear, align_corners); } + +void upsample1d_out_mps(const Tensor& input, + IntArrayRef output_size, + c10::optional scales, + const Tensor& output, + MPSGraphResizeMode requested_mode) +{ + // Get stream + using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + /* sizes */ + int64_t out_size = output_size[0]; + @autoreleasepool { + MPSShape* input_shape = getMPSShape(input); + NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; + string key = string("upsample_1d:") + mps::getMPSShapeString(input_shape) + ":" + + getMPSTypeString(input.scalar_type()) + + ":size" + to_string(out_size) + + ":mode" + to_string(requested_mode); + + CachedGraph* cachedGraph = cache_->LookUpAs(key); + if(!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape); + newCachedGraph->outputTensor_ = [mpsGraph resizeTensor:newCachedGraph->inputTensor_ + size:@[ @(out_size), @(1)] + mode:requested_mode + centerResult: true + alignCorners: true + layout: MPSGraphTensorNamedDataLayoutCHW + name:nil]; + } + return newCachedGraph; + })); + } + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } +} + + +TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + c10::optional scales, + const Tensor& output) +{ + using namespace mps; + upsample1d_out_mps(input, output_size, scales, output, MPSGraphResizeNearest); +} + + + + + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index 088886e0d12d7..2231a66fb3ac6 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -19,37 +19,36 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una if (!output.is_same_size(self)) { output.resize_(self.sizes()); } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - }; + // Empty tensor is noop + if (self.numel() == 0) { + return; + } MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = op_name + getTensorsStringKey({self}, /*use_scalar_value*/ false); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { - CachedGraph *newCachedGraph = nil; + MPSUnaryCachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* castTensor = newCachedGraph->inputTensor; + newCachedGraph = new MPSUnaryCachedGraph(mpsGraph); + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* castTensor = newCachedGraph->inputTensor_; // Integer input must be cast to float if output is float if (isIntegralType(self.scalar_type()) && isFloatingType(output.scalar_type())) { - castTensor = castMPSTensor(mpsGraph, newCachedGraph->inputTensor, output.scalar_type()); + castTensor = castMPSTensor(mpsGraph, newCachedGraph->inputTensor_, output.scalar_type()); } - newCachedGraph->outputTensor = unaryBlock(mpsGraph, castTensor); + newCachedGraph->outputTensor_ = unaryBlock(mpsGraph, castTensor); } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; @@ -126,45 +125,47 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute) +Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) +{ + auto bool_self = self.to(ScalarType::Bool); + mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor){ return [mpsGraph notWithTensor:inputTensor name:nil];}); + return output; +} + TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output) { using namespace mps; if (!output.is_same_size(self)) { output.resize_(self.sizes()); } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = string("log1p_out_mps") + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { - CachedGraph *newCachedGraph = nil; + MPSUnaryCachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + newCachedGraph = new MPSUnaryCachedGraph(mpsGraph); + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:getMPSShape(self) dataType:mps::getMPSDataType(self.scalar_type())]; - MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor + MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor_ secondaryTensor:oneTensor name:nil]; - newCachedGraph->outputTensor = [mpsGraph logarithmWithTensor:addedTensor + newCachedGraph->outputTensor_ = [mpsGraph logarithmWithTensor:addedTensor name:nil]; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = tmpCachedGraph->as(); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm new file mode 100644 index 0000000000000..4fa614ae6e2c6 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -0,0 +1,275 @@ +// Copyright © 2022 Apple Inc. + +#include +#include + +namespace at { + +// these are from MPSAllocator +namespace mps { + // to check the requested non-aligned size of an MTL buffer + ssize_t get_requested_buffer_size(void* ptr); + // to retrieve the shape of a base tensor from a view tensor + IntArrayRef get_buffer_shape(void* ptr); + // to set the shape of a base tensor from a view tensor + void set_buffer_shape(void* ptr, const IntArrayRef& shape); +} + +namespace native { +namespace mps { + +struct ViewCachedGraph : public MPSCachedGraph +{ + ViewCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor = nil; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* updatesTensor = nil; + MPSGraphTensor* storageOffsetTensor = nil; + std::vector strideTensors; +}; + +static std::string getStridedKey(const ScalarType& dtype, const IntArrayRef& base_shape, + const IntArrayRef& new_shape, bool is_scatter) +{ + return (is_scatter ? "scatter:" : "gather:") + getMPSTypeString(dtype) + "[" + + getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]"; +} + +// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op +static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output, + bool needsScatter, bool requires_sync = false) +{ + const id sourceBuffer = getMTLBufferStorage(src); + const id outputBuffer = getMTLBufferStorage(output); + + const IntArrayRef& strides = needsScatter ? output.strides() : src.strides(); + const IntArrayRef& sizes = needsScatter ? output.sizes() : src.sizes(); + const int64_t storage_offset = needsScatter ? output.storage_offset() : src.storage_offset(); + const MPSDataType inputType = [cachedGraph->inputTensor dataType]; + + MPSShape *inputShape = [cachedGraph->inputTensor shape]; + MPSShape *outputShape = needsScatter ? inputShape : getMPSShape(src); + + MPSStream* stream = getCurrentMPSStream(); + @autoreleasepool { + NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + // in case of scatter, we use output tensor as input buffer and write the results back to the source buffer + feeds[cachedGraph->inputTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: needsScatter ? outputBuffer : sourceBuffer + shape: inputShape + dataType: inputType] autorelease]; + if (needsScatter) { + feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer + shape: getMPSShape(src.numel()) + dataType: inputType] autorelease]; + } + feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, Scalar(storage_offset), MPSDataTypeInt32); + for (int i = 0; i < sizes.size(); i++) { + feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, Scalar(strides[i]), MPSDataTypeInt32); + } + // Workaround for MPSShaderLibrary bug + // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved + auto outputType = getMPSDataType(output.scalar_type()); + if (outputType == MPSDataTypeUInt8) { + outputType = MPSDataTypeInt8; + } + MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer + shape: outputShape + dataType: outputType] autorelease]; + NSDictionary* results = @{ + cachedGraph->outputTensor : outputTensorData + }; + stream->executeMPSGraph(cachedGraph->graph(), feeds, results, + requires_sync ? SyncType::COMMIT : SyncType::NONE); + } + return output; +} + +static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size, + const IntArrayRef& stride, int64_t offset, + const IntArrayRef& base_shape, bool needsScatter) +{ + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor *outputTensor = nil; + const size_t shape_size = size.size(); + + @autoreleasepool { + std::vector sizeArray(shape_size); + const int64_t int_max = std::numeric_limits::max(); + for (int i = 0; i < shape_size; i++) { + TORCH_CHECK(size[i] <= int_max); + sizeArray[i] = static_cast(size[i]); + } + NSData* shapeData = [NSData dataWithBytes: sizeArray.data() + length: shape_size * sizeof(int32_t)]; + MPSGraphTensor* shapeTensor = [mpsGraph constantWithData: shapeData + shape: @[[NSNumber numberWithUnsignedInteger: shape_size]] + dataType: MPSDataTypeInt32]; + MPSGraphTensor* indicesTensor = nil; + // create stride Tensors for each rank of the input tensor + for (int i = 0; i < shape_size; i++) { + MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis: (-i - 1) + withShapeTensor: shapeTensor + name: nil]; + MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1]; + MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor: rangeTensor + secondaryTensor: strideTensor + name: nil]; + if (!indicesTensor) { + indicesTensor = indexTensor; + } else { + indicesTensor = [mpsGraph additionWithPrimaryTensor: indexTensor + secondaryTensor: indicesTensor + name: nil]; + } + } + + indicesTensor = [mpsGraph additionWithPrimaryTensor: indicesTensor + secondaryTensor: cachedGraph->storageOffsetTensor + name: nil]; + MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: cachedGraph->inputTensor + withShape: @[@-1] + name: nil]; + MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor: indicesTensor + withShape: @[@-1] + name: nil]; + if (needsScatter) { + MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis: 0 + withDataTensor: reshapedInputTensor + updatesTensor: cachedGraph->updatesTensor + indicesTensor: reshapedIndicesTensor + mode: MPSGraphScatterModeSet + name: nil]; + outputTensor = [mpsGraph reshapeTensor: scatteredTensor + withShape: getMPSShape(base_shape) + name: nil]; + } else { + // Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor + MPSGraphTensor *gatheredTensor = [mpsGraph gatherWithUpdatesTensor: reshapedInputTensor + indicesTensor: reshapedIndicesTensor + axis: 0 + batchDimensions: 0 + name: nil]; + // Reshape the data to desired size + outputTensor = [mpsGraph reshapeTensor: gatheredTensor + withShapeTensor: shapeTensor + name: nil]; + } + } + return outputTensor; +} + +// There are few cases we need to consider: +// Here nodes are the Tensors and the edges are the operations performed on the +// Tensor. As a result of the operation performed we can have result as View +// Tensor (View T) or a Non view tensor (NonView T). The difference is if its +// mapped by the same underlying storage ptr or a new MTLBuffer was allocated. +// T = Tensor +// ---------- +// | Orig T | +// ---------- +// / | \ +// View T View T NonView T +// / / \ | +// View T / \ | +// | / \ | +// | / \ | +// | / \ | +// NonView T NonView T +static ViewCachedGraph* createViewGraph(const Tensor& self, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter) +{ + IntArrayRef base_shape = get_buffer_shape(self.storage().data()); + if (base_shape.size() == 0) { + // IntArrayRef wouldn't own the data, so we use a static storage + static const int64_t shape_1d = 1; + // self.sizes().size() could be zero + base_shape = self.sizes().size() ? self.sizes() : IntArrayRef(&shape_1d, 1); + // base_shape will be retained in MPSAllocator until buffer gets recycled + if (self.storage().data()) + set_buffer_shape(self.storage().data(), base_shape); + } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = getStridedKey(self.scalar_type(), base_shape, size, needsScatter); + ViewCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + ViewCachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new ViewCachedGraph(mpsGraph); + // Workaround for MPSShaderLibrary bug + // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved + auto inputType = getMPSScalarType(self.scalar_type()); + if (inputType == MPSDataTypeUInt8) { + inputType = MPSDataTypeInt8; + } + // Self is the input tensor we are creating view of + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); + newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]); + for (int i = 0; i < size.size(); i++) { + newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1])); + } + if (needsScatter) { + newCachedGraph->updatesTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); + } + newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter); + } + return newCachedGraph; + })); + } + return cachedGraph; + } +} + +Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) +{ + ViewCachedGraph* cachedGraph = nullptr; + + const IntArrayRef& base_shape = get_buffer_shape(src.storage().data()); + if (base_shape.size() > 0) { + string key = getStridedKey(src.scalar_type(), base_shape, src.sizes(), /*is_scatter*/ false); + cachedGraph = static_cast(MPSGraphCache::getInstance()->LookUp(key)); + } + // there are cases where gatherViewTensor() is called without having as_strided() called beforehand. + // this typically may come from copy_mps variants. In such cases, when the base_shape isn't found the + // callers would resort to make the tensor contiguous in an alternative code path. + if (!cachedGraph) { + return Tensor(); + } + + bool requires_sync = false; + Tensor output; + if (!dst.has_storage()) { + output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS); + requires_sync = true; + } + return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync); +} + +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) +{ + ViewCachedGraph* cachedGraph = createViewGraph(output, output.sizes(), output.strides(), + output.storage_offset(), /*needsScatter*/ true); + return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, /*requires_sync*/ true); +} + +} // namespace mps + +// implementation of as_strided() op +Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) +{ + auto storage_offset = storage_offset_.value_or(self.storage_offset()); + auto result = detail::make_tensor(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); + setStrided(result, size, stride, storage_offset); + + // 0 sizes won't result in any change in the shape of the Tensor so we can skip it. + if (size.size() > 0) + mps::createViewGraph(self, size, stride, storage_offset, /*needsScatter*/ false); + + return result; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d1b7bcb17481e..ab6d38e553d30 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -210,6 +210,7 @@ variants: function dispatch: CUDA: fused_dropout_cuda + tags: nondeterministic_seeded - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor variants: function @@ -221,6 +222,7 @@ dispatch: CPU: native_dropout_cpu CUDA: native_dropout_cuda + tags: nondeterministic_seeded - func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor dispatch: @@ -243,6 +245,7 @@ dispatch: CompositeImplicitAutograd: dropout NestedTensorCPU, NestedTensorCUDA: dropout_nested + tags: nondeterministic_seeded - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) dispatch: @@ -591,6 +594,8 @@ - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool variants: function, method + dispatch: + CompositeExplicitAutograd: allclose - func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator @@ -614,8 +619,12 @@ device_check: NoCheck # TensorIterator - func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange - func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange # Note [arange.start_step schema] # We want `arange.start_step` to be grouped up with `arange.start_out`, @@ -624,8 +633,12 @@ # We should probably just make "step" a defaultable param on arange.start, # and kill arange.start_step. - func: arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: arange - func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: arange_out - func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -660,6 +673,7 @@ structured: True dispatch: CPU, CUDA: argmin_out + MPS: argmin_out_mps - func: acosh(Tensor self) -> Tensor variants: function, method @@ -765,7 +779,7 @@ device_guard: False tags: inplace_view dispatch: - CompositeExplicitAutograd: as_strided_ + CompositeExplicitAutogradNonFunctional: as_strided_ - func: asin(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -872,8 +886,12 @@ SparseCsrCUDA: baddbmm_out_sparse_csr_cuda - func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: bartlett_window - func: bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: bartlett_window - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor @@ -891,6 +909,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: bernoulli + tags: nondeterministic_seeded - func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -905,7 +924,7 @@ dispatch: CPU, CUDA: bernoulli_ MPS: bernoulli_mps_ - autogen: bernoulli.Tensor_functional, bernoulli.Tensor_out + autogen: bernoulli.Tensor, bernoulli.Tensor_out - func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -923,6 +942,7 @@ - func: bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + tags: nondeterministic_seeded - func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor @@ -966,9 +986,6 @@ dispatch: CompositeExplicitAutograd: binary_cross_entropy_with_logits -- func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor - variants: function - - func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor variants: function, method dispatch: @@ -1040,6 +1057,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_not_out + MPS: logical_not_out_mps - func: logical_xor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -1057,6 +1075,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_xor_out + MPS: logical_xor_out_mps - func: logical_and(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -1074,6 +1093,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_and_out + MPS: logical_and_out_mps - func: logical_or(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -1091,10 +1111,15 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_or_out + MPS: logical_or_out_mps - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: blackman_window - func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: blackman_window - func: bmm(Tensor self, Tensor mat2) -> Tensor structured_delegate: bmm.out @@ -1102,6 +1127,7 @@ dispatch: SparseCPU: bmm_sparse_cpu SparseCUDA: bmm_sparse_cuda + NestedTensorCPU, NestedTensorCUDA: bmm_nested - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -1114,6 +1140,10 @@ SparseCUDA: bmm_out_sparse_cuda SparseCsrCUDA: bmm_out_sparse_csr_cuda +- func: _NestedTensor_GeneralizedBMM(Tensor self, Tensor mat2) -> Tensor + dispatch: + NestedTensorCPU, NestedTensorCUDA: _NestedTensor_GeneralizedBMM + - func: broadcast_tensors(Tensor[] tensors) -> Tensor[] device_check: NoCheck device_guard: False @@ -1365,6 +1395,7 @@ variants: function dispatch: CompositeExplicitAutograd: constant_pad_nd + MPS: constant_pad_nd_mps - func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) variants: method @@ -1900,7 +1931,7 @@ dispatch: CPU: embedding_renorm_cpu_ CUDA: embedding_renorm_cuda_ - autogen: embedding_renorm.functional, embedding_renorm.out + autogen: embedding_renorm, embedding_renorm.out - func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor @@ -1956,6 +1987,8 @@ - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: empty - func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor dispatch: @@ -1964,9 +1997,27 @@ MPS: empty_mps Meta: empty_meta MkldnnCPU: empty_mkldnn - SparseCPU, SparseCUDA: empty_sparse + SparseCPU, SparseCUDA, SparseMeta: empty_sparse SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed - QuantizedCPU, QuantizedCUDA: empty_unknown_quantized + QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized + +# all calls to empty() in python used to go through the symint overload +# even if all arguments were concerete integers. +# adding symint overloads of kernels for every dispatch key allowed us +# to skip redispatching to `empty.memory_format` and hit backend kernels directly +# we recently updated signature parsing to dispath `empty()` calls in python +# to `empty.SymInt` iff there's is a symint node argument +# hopefully, we could simplify this entry soon +- func: empty.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + CPU: empty_symint_cpu + CUDA: empty_symint_cuda + MPS: empty_symint_mps + Meta: empty_symint_meta + MkldnnCPU: empty_symint_mkldnn + SparseCPU, SparseCUDA, SparseMeta: empty_symint_sparse + SparseCsrCPU, SparseCsrCUDA: empty_symint_sparse_compressed + QuantizedCPU, QuantizedCUDA: empty_symint_unknown_quantized # We do not make new_empty a composite that calls into new_empty_strided, as the strided version # is significantly more difficult to implement by different backends @@ -1975,19 +2026,36 @@ dispatch: CompositeExplicitAutograd: new_empty +- func: new_empty.SymInt(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + dispatch: + CompositeExplicitAutograd: new_empty_symint + - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method dispatch: - CompositeExplicitAutograd: new_empty_strided + CompositeExplicitAutogradNonFunctional: new_empty_strided - func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_full - func: new_zeros(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_zeros - func: new_ones(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: new_ones # other overrides are to provide a more helpful error message that dtype is required - func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor @@ -2014,7 +2082,7 @@ MPS: resize_mps_ QuantizedCPU: quantized_resize_cpu_ SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_ - autogen: resize.functional, resize.out + autogen: resize, resize.out # This is a utility function to enable users to resize out tensor while registering kernels for out variants. # Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration @@ -2024,7 +2092,7 @@ variants: function dispatch: Meta: _resize_output_ - autogen: _resize_output.functional, _resize_output.out + autogen: _resize_output, _resize_output.out - func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor category_override: factory @@ -2042,7 +2110,7 @@ dispatch: CompositeExplicitAutograd: empty_like QuantizedCPU, QuantizedCUDA: empty_like_quantized - SparseCPU, SparseCUDA: empty_like_sparse_coo + SparseCPU, SparseCUDA, SparseMeta: empty_like_sparse_coo SparseCsrCPU, SparseCsrCUDA: empty_like_sparse_csr - func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2173,19 +2241,24 @@ device_check: NoCheck device_guard: False +# decomposes to eye.m - func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: eye - func: eye.m(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: eye - func: eye.out(int n, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: eye_out_cpu + CPU, Meta: eye_out_cpu CUDA: eye_out_cuda MPS: eye_out_mps - func: eye.m_out(int n, int m, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: eye_out_cpu + CPU, Meta: eye_out_cpu CUDA: eye_out_cuda MPS: eye_out_mps @@ -2201,11 +2274,11 @@ - func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) variants: function, method -- func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a) - variants: method +- func: unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a) + variants: function, method - func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a) - variants: method + variants: function, method - func: fill.Scalar(Tensor self, Scalar value) -> Tensor variants: function @@ -2312,12 +2385,22 @@ - func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: full - func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: full - func: full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: full_out - func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: full_like - func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -2404,22 +2487,40 @@ CUDA: grid_sampler_3d_backward_cuda - func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hann_window - func: hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hann_window - func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window - func: hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window - func: hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window - func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: hamming_window - func: kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window - func: kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window - func: kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: kaiser_window - func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor @@ -2473,6 +2574,13 @@ CPU: _fft_c2c_mkl_out CUDA: _fft_c2c_cufft_out +- func: _validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> () + device_check: NoCheck + variants: function + dispatch: + CPU: _validate_compressed_sparse_indices_cpu + CUDA: _validate_compressed_sparse_indices_cuda + - func: _cufft_get_plan_cache_size(int device_index) -> int - func: _cufft_get_plan_cache_max_size(int device_index) -> int @@ -2548,7 +2656,7 @@ dispatch: CPU, CUDA: _index_put_impl_ QuantizedCPU: _index_put_impl_quantized_cpu_ - autogen: _index_put_impl.functional, _index_put_impl.out + autogen: _index_put_impl, _index_put_impl.out - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor variants: function @@ -2665,13 +2773,6 @@ manual_cpp_binding: True - func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - dispatch: - CompositeExplicitAutograd: kl_div - -- func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - dispatch: - CPU: kl_div_backward_cpu - CUDA: kl_div_backward_cuda - func: kron(Tensor self, Tensor other) -> Tensor variants: function, method @@ -2727,9 +2828,18 @@ - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor python_module: nn + dispatch: + CompositeImplicitAutograd: linear + NestedTensorCPU, NestedTensorCUDA: nested_linear + +- func: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + NestedTensorCPU, NestedTensorCUDA: nested_linear_backward - func: linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn + dispatch: + CompositeExplicitAutograd: linear_out # TODO: Add this function to MPS dispatch key so that we avoid declaring it in # native_functions.yaml @@ -2793,6 +2903,8 @@ - func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: linspace - func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -2952,12 +3064,9 @@ dispatch: CompositeExplicitAutograd: xlogy_out -- func: logdet(Tensor self) -> Tensor - variants: function, method - dispatch: - CompositeExplicitAutograd: logdet - - func: logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: logspace - func: logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3029,7 +3138,8 @@ - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CompositeExplicitAutograd: logsumexp_out + # calls squeeze + CompositeExplicitAutogradNonFunctional: logsumexp_out - func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator @@ -3183,14 +3293,14 @@ dispatch: CompositeExplicitAutograd: mean -- func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor structured_delegate: mean.out device_check: NoCheck # TensorIterator variants: function, method dispatch: QuantizedCPU: mean_quantized_cpu -- func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) +- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) structured: True device_check: NoCheck # TensorIterator dispatch: @@ -3413,6 +3523,7 @@ dispatch: CompositeExplicitAutograd: mul SparseCsrCPU, SparseCsrCUDA: mul_scalar_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Scalar - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3420,6 +3531,7 @@ dispatch: CompositeExplicitAutograd: mul_ SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Scalar autogen: mul.Scalar_out # multiply, alias for mul @@ -3555,12 +3667,22 @@ - func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: ones - func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: ones - func: ones.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: ones_out - func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: ones_like - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor @@ -3596,6 +3718,7 @@ dispatch: CompositeExplicitAutograd: permute MPS: permute_mps + SparseCPU, SparseCUDA: permute_sparse_coo - func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) variants: function, method @@ -3638,12 +3761,12 @@ - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor dispatch: CPU: pixel_shuffle_cpu - CompositeExplicitAutograd: math_pixel_shuffle + CompositeExplicitAutogradNonFunctional: math_pixel_shuffle - func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor dispatch: CPU: pixel_unshuffle_cpu - CompositeExplicitAutograd: math_pixel_unshuffle + CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle - func: channel_shuffle(Tensor self, int groups) -> Tensor dispatch: @@ -3711,68 +3834,135 @@ CompositeExplicitAutograd: deg2rad_out - func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: scalar_tensor - func: rand.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: rand - func: rand.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: rand - func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: rand - func: rand.generator(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: rand - func: rand.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: rand_out - func: rand.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - func: rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: rand_like -- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: randint(int high, int[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint -- func: randint.generator(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: randint.generator(int high, int[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: randint -- func: randint.low(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: randint.low(int low, int high, int[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randint -- func: randint.low_generator(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: randint.low_generator(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: randint - func: randint.out(int high, int[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: randint_out - func: randint.generator_out(int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: randint_out - func: randint.low_out(int low, int high, int[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: randint_out - func: randint.low_generator_out(int low, int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: randint_out - func: randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like - func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like - func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randn - func: randn.generator(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: randn - func: randn.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: randn - func: randn.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: randn - func: randn.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) - func: randn.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - func: randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randn_like - func: randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: randperm - func: randperm.generator(int n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: randperm - func: randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: randperm_out - func: randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3780,8 +3970,12 @@ CUDA: randperm_out_cuda - func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: range - func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: range - func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3868,6 +4062,14 @@ device_check: NoCheck device_guard: False +- func: _reshape_nested(Tensor self, int[] shape) -> Tensor + dispatch: + NestedTensorCPU, NestedTensorCUDA: _reshape_nested + +- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor + dispatch: + NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward + # NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape. # They are not user-facing, hence the leading underscore. Please don't use it # anywhere else. @@ -3937,6 +4139,7 @@ - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor device_check: NoCheck # TensorIterator + tags: nondeterministic_seeded - func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3974,6 +4177,8 @@ MkldnnCPU: mkldnn_prelu CPU: prelu_cpu CUDA: prelu_cuda + MPS: prelu_mps + QuantizedCPU: prelu_quantized_cpu - func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) variants: function, method @@ -3981,6 +4186,7 @@ MkldnnCPU: mkldnn_prelu_backward CPU: prelu_backward_cpu CUDA: prelu_backward_cuda + MPS: prelu_backward_mps - func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) structured: True @@ -4089,7 +4295,7 @@ device_check: NoCheck device_guard: False dispatch: - CompositeExplicitAutograd: select_backward + CompositeExplicitAutogradNonFunctional: select_backward - func: selu(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -4340,17 +4546,15 @@ dispatch: CompositeExplicitAutograd: as_strided_scatter -- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) - variants: function, method - dispatch: - CompositeExplicitAutograd: slogdet - - func: smm(Tensor self, Tensor mat2) -> Tensor variants: function, method # softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. - func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor variants: function, method + dispatch: + CompositeImplicitAutograd: softmax + NestedTensorCPU, NestedTensorCUDA: softmax - func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) variants: function @@ -4364,6 +4568,7 @@ structured_delegate: _softmax.out dispatch: MkldnnCPU: mkldnn_softmax + NestedTensorCPU, NestedTensorCUDA: softmax_nested - func: _softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -4544,16 +4749,24 @@ CompositeExplicitAutograd: sum SparseCsrCPU, SparseCsrCUDA: sum_csr -- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +- func: sum.SymInt(Tensor self, SymInt[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + device_check: NoCheck # TensorIterator + variants: function, method + dispatch: + CompositeExplicitAutograd: sum_symint + +- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor structured_delegate: sum.IntList_out device_check: NoCheck # TensorIterator variants: function, method + dispatch: + NestedTensorCPU: NestedTensor_sum_dim_CPU - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor device_check: NoCheck # TensorIterator variants: function, method -- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) +- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) structured: True device_check: NoCheck # TensorIterator dispatch: @@ -4627,6 +4840,7 @@ dispatch: CPU, CUDA: std MPS: std_mps + QuantizedCPU: std_quantized_cpu - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator @@ -4657,6 +4871,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: std_out + QuantizedCPU: std_out_quantized_cpu - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator @@ -4820,6 +5035,7 @@ device_guard: False dispatch: CompositeExplicitAutograd: transpose + NestedTensorCPU, NestedTensorCUDA: transpose_nested - func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) variants: function, method @@ -4856,6 +5072,7 @@ variants: function, method dispatch: CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip + MPS: flip_mps - func: fliplr(Tensor self) -> Tensor variants: function, method @@ -4917,7 +5134,8 @@ - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor dispatch: - CompositeExplicitAutograd: _trilinear + # calls unsqueeze + CompositeExplicitAutogradNonFunctional: _trilinear - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor @@ -5140,6 +5358,8 @@ - func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck device_guard: False + dispatch: + CompositeExplicitAutograd: zeros - func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -5147,10 +5367,23 @@ CUDA: _efficientzerotensor_cuda - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: zeros + +- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: zeros_symint - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: zeros_out + SparseCPU, SparseCUDA, SparseMeta: zeros_out - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: zeros_like - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor variants: function @@ -5163,6 +5396,7 @@ dispatch: CPU: _s_gamma_cpu CUDA: _s_gamma_cuda + tags: nondeterministic_seeded - func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor dispatch: @@ -5180,12 +5414,14 @@ dispatch: CPU: _s_poisson_cpu CUDA: _s_poisson_cuda + tags: nondeterministic_seeded - func: binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor device_check: NoCheck # TensorIterator dispatch: CPU: _s_binomial_cpu CUDA: _s_binomial_cuda + tags: nondeterministic_seeded # When more variants get ported to native, this dispatch will get more # complicated @@ -5262,6 +5498,11 @@ SparseCPU: log_softmax_backward_sparse_cpu SparseCUDA: log_softmax_backward_sparse_cuda +- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor + python_module: sparse + dispatch: + CPU: spdiags + - func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor device_check: NoCheck # TensorIterator variants: function, method @@ -5371,7 +5612,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: resize_as_ - autogen: resize_as.functional, resize_as.out + autogen: resize_as, resize_as.out - func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) use_const_ref_for_mutable_tensors: True @@ -5379,7 +5620,7 @@ dispatch: SparseCPU, SparseCUDA: resize_as_sparse_ SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_ - autogen: resize_as_sparse.functional, resize_as_sparse.out + autogen: resize_as_sparse, resize_as_sparse.out - func: zero_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5388,10 +5629,10 @@ CPU, CUDA: zero_ MPS: zero_mps_ Meta: zero_meta_ - SparseCPU, SparseCUDA: zero_sparse_ + SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ SparseCsrCPU, SparseCsrCUDA: zero_sparse_csr_ MkldnnCPU: mkldnn_zero_ - autogen: zero.functional, zero.out + autogen: zero, zero.out - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5664,6 +5905,8 @@ - func: _sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + dispatch: + CompositeExplicitAutograd: sparse_coo_tensor - func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -5681,25 +5924,25 @@ - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor dispatch: - SparseCPU, SparseCUDA: new_with_dims_sparse + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor dispatch: - SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: - SparseCPU, SparseCUDA: sparse_resize_ - autogen: sparse_resize.functional, sparse_resize.out + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_ + autogen: sparse_resize, sparse_resize.out - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: - SparseCPU, SparseCUDA: sparse_resize_and_clear_ - autogen: sparse_resize_and_clear.functional, sparse_resize_and_clear.out + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_ + autogen: sparse_resize_and_clear, sparse_resize_and_clear.out - func: sparse_mask(Tensor self, Tensor mask) -> Tensor variants: method @@ -5727,7 +5970,8 @@ - func: sparse_dim(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA: sparse_dim_sparse + SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse + SparseCsrCPU, SparseCsrCUDA: sparse_dim_sparse_csr device_check: NoCheck device_guard: False @@ -5742,7 +5986,8 @@ - func: dense_dim(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA: dense_dim_sparse + SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse + SparseCsrCPU, SparseCsrCUDA: dense_dim_sparse_csr device_check: NoCheck device_guard: False @@ -5750,14 +5995,14 @@ - func: _dimV(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA: dense_dim_sparse + SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse device_check: NoCheck device_guard: False - func: _nnz(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA: _nnz_sparse + SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse SparseCsrCPU, SparseCsrCUDA: _nnz_sparse_csr device_check: NoCheck device_guard: False @@ -5778,21 +6023,21 @@ - func: is_coalesced(Tensor self) -> bool variants: method dispatch: - SparseCPU, SparseCUDA: is_coalesced_sparse + SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse device_check: NoCheck device_guard: False - func: _indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA: _indices_sparse + SparseCPU, SparseCUDA, SparseMeta: _indices_sparse device_check: NoCheck device_guard: False - func: _values(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA: _values_sparse + SparseCPU, SparseCUDA, SparseMeta: _values_sparse device_check: NoCheck device_guard: False @@ -5802,22 +6047,22 @@ - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) variants: method dispatch: - SparseCPU, SparseCUDA: _coalesced_sparse_ + SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_ device_check: NoCheck device_guard: False - autogen: _coalesced.functional, _coalesced.out + autogen: _coalesced, _coalesced.out - func: indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA: indices_sparse + SparseCPU, SparseCUDA, SparseMeta: indices_sparse device_check: NoCheck device_guard: False - func: values(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA: values_sparse + SparseCPU, SparseCUDA, SparseMeta: values_sparse SparseCsrCPU, SparseCsrCUDA: values_sparse_csr device_check: NoCheck device_guard: False @@ -5865,7 +6110,7 @@ variants: function dispatch: SparseCPU, SparseCUDA: copy_sparse_ - autogen: copy_sparse_to_sparse.functional, copy_sparse_to_sparse.out + autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out - func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] variants: function, method @@ -6074,7 +6319,7 @@ dispatch: CPU: fused_moving_avg_obs_fake_quant_cpu CUDA: fused_moving_avg_obs_fake_quant_cuda - autogen: _fused_moving_avg_obs_fq_helper.functional, _fused_moving_avg_obs_fq_helper.out + autogen: _fused_moving_avg_obs_fq_helper_functional, _fused_moving_avg_obs_fq_helper.out - func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) variants: function @@ -6268,7 +6513,7 @@ device_guard: False dispatch: CPU, CUDA, Meta, MPS: set_ - autogen: set.source_Storage_functional, set.source_Storage_out + autogen: set.source_Storage, set.source_Storage_out - func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) variants: method @@ -6279,7 +6524,7 @@ CUDA: set_storage_cuda_ MPS: set_storage_mps_ QuantizedCPU, QuantizedCUDA: set_storage_quantized_ - autogen: set.source_Storage_storage_offset_functional, set.source_Storage_storage_offset_out + autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out - func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) variants: method @@ -6292,7 +6537,7 @@ device_guard: False dispatch: CPU, CUDA, Meta, MPS: set_tensor_ - autogen: set.source_Tensor_functional, set.source_Tensor_out + autogen: set.source_Tensor, set.source_Tensor_out - func: set_(Tensor(a!) self) -> Tensor(a!) variants: method @@ -6301,15 +6546,32 @@ CUDA: set_cuda_ Meta: set_meta_ MPS: set_mps_ - autogen: set.functional, set.out + autogen: set, set.out +# Not making it CompositeImplicitAutograd because lift +# should be a primitive w.r.t. functorch + +# TODO: this should have a view annotation +# TODO: shouldn't be a method - func: lift(Tensor self) -> Tensor - variants: method dispatch: - # Not making it CompositeImplicitAutograd because lift - # should be a primitive w.r.t. functorch CompositeExplicitAutograd: lift +# lift_fresh is called with an argument that is guaranteed to be +# fresh (i.e., newly allocated). This is ONLY called from a +# torch.tensor call; if you FX trace a lift_fresh, you are obligated +# to convert this into a lift_fresh_copy (because FX will violate the +# freshness invariant when tracing). +- func: lift_fresh(Tensor(a) self) -> Tensor(a) + dispatch: + CompositeExplicitAutograd: lift_fresh + +# Like lift, but it clones the input. +- func: lift_fresh_copy(Tensor self) -> Tensor + tags: view_copy + dispatch: + CompositeExplicitAutograd: lift_fresh_copy + - func: is_set_to(Tensor self, Tensor tensor) -> bool variants: method device_check: NoCheck @@ -6371,6 +6633,14 @@ CUDA: masked_softmax_backward_cuda CPU: masked_softmax_backward_cpu +- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a) + variants: method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: view_symint + MkldnnCPU: mkldnn_view_symint + - func: view(Tensor(a) self, int[] size) -> Tensor(a) variants: method device_check: NoCheck @@ -6400,6 +6670,8 @@ - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor variants: function, method + dispatch: + CompositeExplicitAutograd: put - func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) structured: True @@ -6409,6 +6681,7 @@ dispatch: CPU: index_add_cpu_out CUDA: index_add_cuda_out + MPS: index_add_mps_out - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) structured_delegate: index_add.out @@ -6934,7 +7207,7 @@ CPU, CUDA: random_ Meta: random_meta_ MPS: random_mps_ - autogen: random.from_functional, random.from_out + autogen: random.from, random.from_out - func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6943,7 +7216,7 @@ CPU, CUDA: random_ Meta: random_meta_ MPS: random_mps_ - autogen: random.to_functional, random.to_out + autogen: random.to, random.to_out - func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6951,7 +7224,7 @@ dispatch: CPU, CUDA: random_ Meta: random_meta_ - autogen: random.functional, random.out + autogen: random, random.out - func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6960,21 +7233,21 @@ CPU, CUDA: uniform_ MPS: uniform_mps_ Meta: uniform_meta_ - autogen: uniform.functional, uniform.out + autogen: uniform, uniform.out - func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: cauchy_ - autogen: cauchy.functional, cauchy.out + autogen: cauchy, cauchy.out - func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CPU, CUDA: log_normal_ - autogen: log_normal.functional, log_normal.out + autogen: log_normal, log_normal.out - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6982,7 +7255,7 @@ dispatch: CPU, CUDA: exponential_ MPS: exponential_mps_ - autogen: exponential.functional, exponential.out + autogen: exponential, exponential.out - func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6991,7 +7264,7 @@ CPU, CUDA: geometric_ # wrappers for TH functions - autogen: geometric.functional, geometric.out + autogen: geometric, geometric.out - func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -7707,6 +7980,7 @@ variants: method, function dispatch: CPU, CUDA: multinomial + tags: nondeterministic_seeded - func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8296,6 +8570,7 @@ dispatch: CPU: cpu_equal CUDA: cuda_equal + MPS: mps_equal QuantizedCPU: equal_quantized_cpu - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) @@ -8375,7 +8650,15 @@ MPS: normal_mps_ Meta: normal_meta_ SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_ - autogen: normal.functional, normal.out + autogen: normal.out + +# Only used by the functionalization pass. +# Normally, the codegen would be able to generate a normal() NativeFunction, +# but we can't due to overload ambiguity with normal.Tensor_float. +- func: normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: normal_functional - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8386,8 +8669,9 @@ - func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor dispatch: CPU, CUDA: normal - #MPS: normal_mps + MPS: normal_mps Meta: normal_meta + tags: nondeterministic_seeded - func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8398,8 +8682,9 @@ - func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor dispatch: CPU, CUDA: normal + MPS: normal_mps Meta: normal_meta - #MPS: normal_mps + tags: nondeterministic_seeded - func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8410,12 +8695,17 @@ - func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor dispatch: CPU, CUDA: normal + MPS: normal_mps Meta: normal_meta - #MPS: normal_mps + tags: nondeterministic_seeded - func: normal.float_float(float mean, float std, int[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + dispatch: + CompositeExplicitAutograd: normal - func: normal.float_float_out(float mean, float std, int[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeExplicitAutograd: normal_out - func: alias(Tensor(a) self) -> Tensor(a) variants: method, function @@ -8426,13 +8716,13 @@ variants: function dispatch: CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ - autogen: _amp_foreach_non_finite_check_and_unscale.functional, _amp_foreach_non_finite_check_and_unscale.out + autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out - func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) variants: function dispatch: CUDA: _amp_update_scale_cuda_ - autogen: _amp_update_scale.functional, _amp_update_scale.out + autogen: _amp_update_scale, _amp_update_scale.out #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor #dispatch: @@ -8447,7 +8737,7 @@ #CUDA: cat_out_cuda #QuantizedCPU: cat_out_quantized_cpu -- func: _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] +- func: _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8460,9 +8750,9 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ - autogen: _foreach_add.Scalar_functional, _foreach_add.Scalar_out + autogen: _foreach_add.Scalar_out -- func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] +- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8475,9 +8765,9 @@ dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ - autogen: _foreach_sub.Scalar_functional, _foreach_sub.Scalar_out + autogen: _foreach_sub.Scalar_out -- func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] +- func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8490,9 +8780,9 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ - autogen: _foreach_mul.Scalar_functional, _foreach_mul.Scalar_out + autogen: _foreach_mul.Scalar_out -- func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] +- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8505,9 +8795,9 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ - autogen: _foreach_div.Scalar_functional, _foreach_div.Scalar_out + autogen: _foreach_div.Scalar_out -- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] +- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8520,9 +8810,9 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ - autogen: _foreach_add.List_functional, _foreach_add.List_out + autogen: _foreach_add.List_out -- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] +- func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8535,9 +8825,9 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ - autogen: _foreach_sub.List_functional, _foreach_sub.List_out + autogen: _foreach_sub.List_out -- func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] +- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8550,9 +8840,9 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ - autogen: _foreach_mul.List_functional, _foreach_mul.List_out + autogen: _foreach_mul.List_out -- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] +- func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8565,9 +8855,9 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ - autogen: _foreach_div.List_functional, _foreach_div.List_out + autogen: _foreach_div.List_out -- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] +- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8580,9 +8870,9 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ - autogen: _foreach_add.ScalarList_functional, _foreach_add.ScalarList_out + autogen: _foreach_add.ScalarList_out -- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] +- func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8595,9 +8885,9 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ - autogen: _foreach_sub.ScalarList_functional, _foreach_sub.ScalarList_out + autogen: _foreach_sub.ScalarList_out -- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] +- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8610,9 +8900,9 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ - autogen: _foreach_div.ScalarList_functional, _foreach_div.ScalarList_out + autogen: _foreach_div.ScalarList_out -- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[] +- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8625,9 +8915,9 @@ dispatch: CPU: foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ - autogen: _foreach_mul.ScalarList_functional, _foreach_mul.ScalarList_out + autogen: _foreach_mul.ScalarList_out -- func: _foreach_exp(Tensor[] tensors) -> Tensor[] +- func: _foreach_exp(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8640,7 +8930,7 @@ dispatch: CPU: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ - autogen: _foreach_zero.functional, _foreach_zero.out + autogen: _foreach_zero, _foreach_zero.out - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -8648,9 +8938,9 @@ dispatch: CPU: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ - autogen: _foreach_exp.functional, _foreach_exp.out + autogen: _foreach_exp.out -- func: _foreach_sqrt(Tensor[] tensors) -> Tensor[] +- func: _foreach_sqrt(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8663,9 +8953,9 @@ dispatch: CPU: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ - autogen: _foreach_sqrt.functional, _foreach_sqrt.out + autogen: _foreach_sqrt.out -- func: _foreach_abs(Tensor[] tensors) -> Tensor[] +- func: _foreach_abs(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8678,9 +8968,9 @@ dispatch: CPU: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ - autogen: _foreach_abs.functional, _foreach_abs.out + autogen: _foreach_abs.out -- func: _foreach_acos(Tensor[] tensors) -> Tensor[] +- func: _foreach_acos(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8693,9 +8983,9 @@ dispatch: CPU: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ - autogen: _foreach_acos.functional, _foreach_acos.out + autogen: _foreach_acos.out -- func: _foreach_asin(Tensor[] tensors) -> Tensor[] +- func: _foreach_asin(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8708,9 +8998,9 @@ dispatch: CPU: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ - autogen: _foreach_asin.functional, _foreach_asin.out + autogen: _foreach_asin.out -- func: _foreach_atan(Tensor[] tensors) -> Tensor[] +- func: _foreach_atan(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8723,9 +9013,9 @@ dispatch: CPU: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ - autogen: _foreach_atan.functional, _foreach_atan.out + autogen: _foreach_atan.out -- func: _foreach_ceil(Tensor[] tensors) -> Tensor[] +- func: _foreach_ceil(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8738,9 +9028,9 @@ dispatch: CPU: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ - autogen: _foreach_ceil.functional, _foreach_ceil.out + autogen: _foreach_ceil.out -- func: _foreach_cos(Tensor[] tensors) -> Tensor[] +- func: _foreach_cos(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8753,9 +9043,9 @@ dispatch: CPU: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ - autogen: _foreach_cos.functional, _foreach_cos.out + autogen: _foreach_cos.out -- func: _foreach_cosh(Tensor[] tensors) -> Tensor[] +- func: _foreach_cosh(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8768,9 +9058,9 @@ dispatch: CPU: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ - autogen: _foreach_cosh.functional, _foreach_cosh.out + autogen: _foreach_cosh.out -- func: _foreach_erf(Tensor[] tensors) -> Tensor[] +- func: _foreach_erf(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8783,9 +9073,9 @@ dispatch: CPU: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ - autogen: _foreach_erf.functional, _foreach_erf.out + autogen: _foreach_erf.out -- func: _foreach_erfc(Tensor[] tensors) -> Tensor[] +- func: _foreach_erfc(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8798,9 +9088,9 @@ dispatch: CPU: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ - autogen: _foreach_erfc.functional, _foreach_erfc.out + autogen: _foreach_erfc.out -- func: _foreach_expm1(Tensor[] tensors) -> Tensor[] +- func: _foreach_expm1(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8813,9 +9103,9 @@ dispatch: CPU: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ - autogen: _foreach_expm1.functional, _foreach_expm1.out + autogen: _foreach_expm1.out -- func: _foreach_floor(Tensor[] tensors) -> Tensor[] +- func: _foreach_floor(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8828,9 +9118,9 @@ dispatch: CPU: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ - autogen: _foreach_floor.functional, _foreach_floor.out + autogen: _foreach_floor.out -- func: _foreach_log(Tensor[] tensors) -> Tensor[] +- func: _foreach_log(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8843,9 +9133,9 @@ dispatch: CPU: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ - autogen: _foreach_log.functional, _foreach_log.out + autogen: _foreach_log.out -- func: _foreach_log10(Tensor[] tensors) -> Tensor[] +- func: _foreach_log10(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8858,9 +9148,9 @@ dispatch: CPU: foreach_tensor_log10_slow_ CUDA: foreach_tensor_log10_cuda_ - autogen: _foreach_log10.functional, _foreach_log10.out + autogen: _foreach_log10.out -- func: _foreach_log1p(Tensor[] tensors) -> Tensor[] +- func: _foreach_log1p(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8873,9 +9163,9 @@ dispatch: CPU: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ - autogen: _foreach_log1p.functional, _foreach_log1p.out + autogen: _foreach_log1p.out -- func: _foreach_log2(Tensor[] tensors) -> Tensor[] +- func: _foreach_log2(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8888,9 +9178,9 @@ dispatch: CPU: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ - autogen: _foreach_log2.functional, _foreach_log2.out + autogen: _foreach_log2.out -- func: _foreach_neg(Tensor[] tensors) -> Tensor[] +- func: _foreach_neg(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8903,9 +9193,9 @@ dispatch: CPU: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ - autogen: _foreach_neg.functional, _foreach_neg.out + autogen: _foreach_neg.out -- func: _foreach_tan(Tensor[] tensors) -> Tensor[] +- func: _foreach_tan(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8918,9 +9208,9 @@ dispatch: CPU: foreach_tensor_tan_slow_ CUDA: foreach_tensor_tan_cuda_ - autogen: _foreach_tan.functional, _foreach_tan.out + autogen: _foreach_tan.out -- func: _foreach_tanh(Tensor[] tensors) -> Tensor[] +- func: _foreach_tanh(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8933,9 +9223,9 @@ dispatch: CPU: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ - autogen: _foreach_tanh.functional, _foreach_tanh.out + autogen: _foreach_tanh.out -- func: _foreach_sin(Tensor[] tensors) -> Tensor[] +- func: _foreach_sin(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8948,9 +9238,9 @@ dispatch: CPU: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ - autogen: _foreach_sin.functional, _foreach_sin.out + autogen: _foreach_sin.out -- func: _foreach_sinh(Tensor[] tensors) -> Tensor[] +- func: _foreach_sinh(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8963,9 +9253,9 @@ dispatch: CPU: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ - autogen: _foreach_sinh.functional, _foreach_sinh.out + autogen: _foreach_sinh.out -- func: _foreach_round(Tensor[] tensors) -> Tensor[] +- func: _foreach_round(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8978,9 +9268,9 @@ dispatch: CPU: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ - autogen: _foreach_round.functional, _foreach_round.out + autogen: _foreach_round.out -- func: _foreach_lgamma(Tensor[] tensors) -> Tensor[] +- func: _foreach_lgamma(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -8993,9 +9283,9 @@ dispatch: CPU: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ - autogen: _foreach_lgamma.functional, _foreach_lgamma.out + autogen: _foreach_lgamma.out -- func: _foreach_frac(Tensor[] tensors) -> Tensor[] +- func: _foreach_frac(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -9008,9 +9298,9 @@ dispatch: CPU: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ - autogen: _foreach_frac.functional, _foreach_frac.out + autogen: _foreach_frac.out -- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[] +- func: _foreach_reciprocal(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -9023,9 +9313,9 @@ dispatch: CPU: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ - autogen: _foreach_reciprocal.functional, _foreach_reciprocal.out + autogen: _foreach_reciprocal.out -- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[] +- func: _foreach_sigmoid(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -9038,9 +9328,9 @@ dispatch: CPU: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ - autogen: _foreach_sigmoid.functional, _foreach_sigmoid.out + autogen: _foreach_sigmoid.out -- func: _foreach_trunc(Tensor[] tensors) -> Tensor[] +- func: _foreach_trunc(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -9053,7 +9343,7 @@ dispatch: CPU: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ - autogen: _foreach_trunc.functional, _foreach_trunc.out + autogen: _foreach_trunc.out - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -9061,7 +9351,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ - autogen: _foreach_addcdiv.Scalar_functional, _foreach_addcdiv.Scalar_out + autogen: _foreach_addcdiv.Scalar_out - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -9069,7 +9359,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ - autogen: _foreach_addcmul.Scalar_functional, _foreach_addcmul.Scalar_out + autogen: _foreach_addcmul.Scalar_out - func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -9077,7 +9367,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ - autogen: _foreach_addcdiv.ScalarList_functional, _foreach_addcdiv.ScalarList_out + autogen: _foreach_addcdiv.ScalarList_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -9085,51 +9375,51 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ - autogen: _foreach_addcmul.ScalarList_functional, _foreach_addcmul.ScalarList_out + autogen: _foreach_addcmul.ScalarList_out -- func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] +- func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_addcdiv_scalar_slow CUDA: foreach_tensor_addcdiv_scalar_cuda -- func: _foreach_addcmul.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] +- func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda -- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] +- func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda -- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] +- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda -- func: _foreach_maximum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] +- func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_maximum_slow CUDA: foreach_tensor_maximum_cuda -- func: _foreach_minimum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] +- func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: CPU: foreach_tensor_minimum_slow CUDA: foreach_tensor_minimum_cuda -- func: _foreach_norm.Scalar(Tensor[] tensors, Scalar ord=2) -> Tensor[] +- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: @@ -9376,16 +9666,19 @@ python_module: nn dispatch: CPU, CUDA: huber_loss_out + MPS: huber_loss_out_mps - func: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor python_module: nn dispatch: CPU, CUDA: huber_loss + MPS: huber_loss_mps - func: huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU, CUDA: huber_loss_backward_out + MPS: huber_loss_backward_out_mps - func: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor python_module: nn @@ -9449,6 +9742,7 @@ python_module: nn dispatch: CPU, CUDA: glu_out + MPS: glu_out_mps - func: glu(Tensor self, int dim=-1) -> Tensor structured_delegate: glu.out @@ -9460,12 +9754,14 @@ dispatch: CPU: glu_backward_cpu_out CUDA: glu_backward_cuda_out + MPS: glu_backward_mps_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor python_module: nn dispatch: CPU: glu_backward_cpu CUDA: glu_backward_cuda + MPS: glu_backward_mps - func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor python_module: nn @@ -9646,6 +9942,7 @@ dispatch: CPU: rrelu_with_noise_cpu CUDA: rrelu_with_noise_cuda + tags: nondeterministic_seeded - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor python_module: nn @@ -9678,6 +9975,7 @@ python_module: nn dispatch: CPU, CUDA: softplus_backward_out + MPS: softplus_backward_out_mps - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor structured_delegate: softplus_backward.grad_input @@ -10444,6 +10742,7 @@ dispatch: CPU: upsample_nearest1d_out_cpu CUDA: upsample_nearest1d_out_cuda + MPS: upsample_nearest1d_out_mps - func: _upsample_nearest_exact1d.out(Tensor self, int[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -10796,6 +11095,7 @@ dispatch: CompositeExplicitAutograd: isinf SparseCPU, SparseCUDA: isinf_sparse + SparseMeta: isinf_sparse_meta SparseCsrCPU, SparseCsrCUDA: isinf_sparse_csr - func: record_stream(Tensor(a!) self, Stream s) -> () @@ -11383,18 +11683,26 @@ - func: fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor python_module: fft variants: function + dispatch: + CompositeExplicitAutograd: fft_fftfreq - func: fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) python_module: fft variants: function + dispatch: + CompositeExplicitAutograd: fft_fftfreq_out - func: fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor python_module: fft variants: function + dispatch: + CompositeExplicitAutograd: fft_rfftfreq - func: fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) python_module: fft variants: function + dispatch: + CompositeExplicitAutograd: fft_rfftfreq_out - func: fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor python_module: fft @@ -11416,23 +11724,19 @@ # "_ex" stands for experimental - func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) python_module: linalg - variants: function - dispatch: - CPU, CUDA: linalg_cholesky_ex + structured_delegate: linalg_cholesky_ex.L - func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) python_module: linalg - variants: function + structured: True dispatch: CPU, CUDA: linalg_cholesky_ex_out - func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor python_module: linalg - variants: function - func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) python_module: linalg - variants: function - func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor python_module: linalg @@ -11497,27 +11801,25 @@ CPU, CUDA: linalg_lu_solve_out # linalg.det -- func: linalg_det(Tensor self) -> Tensor +- func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + structured_delegate: _linalg_det.result + +- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + structured: True + dispatch: + CPU, CUDA: _linalg_det_out + +- func: linalg_det(Tensor A) -> Tensor python_module: linalg variants: function -- func: linalg_det.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) python_module: linalg # torch.det, alias for torch.linalg.det - func: det(Tensor self) -> Tensor variants: function, method -- func: _det_lu_based_helper(Tensor self) -> (Tensor det, Tensor lu, Tensor pivs) - variants: function - dispatch: - CPU, CUDA: _det_lu_based_helper - -- func: _det_lu_based_helper_backward_helper(Tensor det_grad, Tensor det, Tensor self, Tensor lu, Tensor pivs) -> Tensor - variants: function - dispatch: - CPU, CUDA: _det_lu_based_helper_backward_helper - - func: linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info) structured_delegate: linalg_ldl_factor_ex.out python_module: linalg @@ -11572,22 +11874,41 @@ - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) python_module: linalg +- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor + python_module: linalg + variants: function + +- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + - func: linalg_matrix_exp(Tensor self) -> Tensor python_module: linalg variants: function dispatch: CPU, CUDA: linalg_matrix_exp -- func: linalg_slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) - python_module: linalg - variants: function +- func: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + structured_delegate: _linalg_slogdet.sign + +- func: _linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + structured: True dispatch: - CPU, CUDA: linalg_slogdet + CPU, CUDA: _linalg_slogdet_out -- func: linalg_slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) +- func: linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) python_module: linalg - dispatch: - CPU, CUDA: linalg_slogdet_out + +- func: linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + python_module: linalg + +- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + variants: function, method + +- func: slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + variants: function + +- func: logdet(Tensor self) -> Tensor + variants: function, method - func: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) python_module: linalg @@ -11644,19 +11965,21 @@ dispatch: CPU: _linalg_inv_out_helper_cpu CUDA: _linalg_inv_out_helper_cuda - autogen: _linalg_inv_out_helper.functional, _linalg_inv_out_helper.out + autogen: _linalg_inv_out_helper, _linalg_inv_out_helper.out - func: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info) python_module: linalg variants: function dispatch: - CompositeExplicitAutograd: linalg_inv_ex + # calls transpose_ + CompositeExplicitAutogradNonFunctional: linalg_inv_ex - func: linalg_inv_ex.inverse(Tensor self, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info) python_module: linalg variants: function dispatch: - CompositeExplicitAutograd: linalg_inv_ex_out + # calls transpose_ + CompositeExplicitAutogradNonFunctional: linalg_inv_ex_out - func: linalg_inv(Tensor self) -> Tensor python_module: linalg @@ -11768,7 +12091,9 @@ python_module: linalg variants: function dispatch: - CompositeExplicitAutograd: linalg_pinv + # calls svd, which calls mH() (view op) + # also calls narrow() + CompositeExplicitAutogradNonFunctional: linalg_pinv - func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) python_module: linalg @@ -11802,13 +12127,19 @@ python_module: linalg variants: function -- func: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots) - structured_delegate: _linalg_solve.result +- func: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + structured_delegate: _linalg_solve_ex.result -- func: _linalg_solve.result(Tensor A, Tensor B, *, bool left=True, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) +- func: _linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) structured: True dispatch: - CPU, CUDA: _linalg_solve_out + CPU, CUDA: _linalg_solve_ex_out + +- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) + python_module: linalg + +- func: linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) + python_module: linalg - func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor python_module: linalg @@ -12020,7 +12351,7 @@ - func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: expand_copy_SymInt + CompositeExplicitAutograd: expand_copy_SymInt tags: view_copy - func: permute_copy(Tensor self, int[] dims) -> Tensor @@ -12209,6 +12540,13 @@ CompositeExplicitAutograd: _neg_view_copy_out +- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: view_copy_SymInt + tags: view_copy + + - func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: @@ -12398,6 +12736,22 @@ dispatch: CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention +- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor) + variants: function + +- func: special_airy_ai(Tensor x) -> Tensor + python_module: special + structured_delegate: special_airy_ai.out + variants: function + +- func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_airy_ai_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: @@ -12808,6 +13162,32 @@ structured: True variants: function +- func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k0.out + variants: function + +- func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_scaled_modified_bessel_k0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + +- func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k1.out + variants: function + +- func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_scaled_modified_bessel_k1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + - func: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special @@ -12955,3 +13335,22 @@ device_check: NoCheck python_module: special variants: function + +- func: special_spherical_bessel_j0(Tensor x) -> Tensor + python_module: special + structured_delegate: special_spherical_bessel_j0.out + variants: function + +- func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: special_spherical_bessel_j0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + +# Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default +# within test/test_python_dispatch.py +- func: _foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor + dispatch: + CPU: foobar diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp new file mode 100644 index 0000000000000..600db24c03aa2 --- /dev/null +++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp @@ -0,0 +1,68 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +std::tuple nested_linear_backward( + const Tensor& input, + const Tensor& grad_output, + const Tensor& weight, + std::array output_mask) { + if (!grad_output.defined()) { + return std::tuple{Tensor(), Tensor(), Tensor()}; + } + Tensor grad_input, grad_weight, grad_bias; + auto* nt_grad_output = get_nested_tensor_impl(grad_output); + auto* nt_input = get_nested_tensor_impl(input); + TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr); + TORCH_INTERNAL_ASSERT(nt_input != nullptr); + TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_grad_output)); + auto grad_ouput_buffer = nt_grad_output->get_buffer(); + auto input_buffer = nt_input->get_buffer(); + + auto reshaped_grad = grad_ouput_buffer.reshape({-1, weight.size(0)}); + + if (output_mask[0]) { + auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1}); + auto grad_input_nt_size = nt_input->get_nested_size_tensor().clone(); + grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size); + } + if (output_mask[1]) { + grad_weight = + at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)})); + } + if (output_mask[2]) { + grad_bias = reshaped_grad.sum(0); + } + return std::tuple{grad_input, grad_weight, grad_bias}; +} + +Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) { + auto self_ptr = get_nested_tensor_impl(self); + // TODO: this is to reproduce self_ptr->opt_sizes_ + // if an accessor is provided in the future, can replace this + std::vector sizes; + for (int64_t i = 0; i < self_ptr->dim(); i++) { + c10::optional opt_size = self_ptr->opt_size(i); + if (opt_size.has_value()) { + sizes.push_back(*opt_size); + } + else { + sizes.push_back(-1); + } + } + return grad.reshape(sizes); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index a47cb124b0b0e..6c05986e2e61f 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -96,20 +96,10 @@ Tensor pad_tensor_to_shape( } } // namespace -at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_size_tensor) { - TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous."); - return at::detail::make_tensor( - std::move(buffer), std::move(nested_size_tensor)); -} - inline const at::Tensor& get_buffer(const at::Tensor& tensor) { return get_nested_tensor_impl(tensor)->get_buffer(); } - -// CPU only! -// TODO: The algorithm here can be optimized, right now it involves a lot of -// small tensor manipulations std::vector NestedTensor_unbind( const at::Tensor& self, int64_t dim) { @@ -119,22 +109,18 @@ std::vector NestedTensor_unbind( "got dimension ", dim, " instead."); - auto esizes = get_nested_size_tensor(self); - std::vector result_tensors; - if (esizes.dim() == 0) { + auto self_ptr = get_nested_tensor_impl(self); + int64_t ntensors = self_ptr->size(0); + std::vector result_tensors(ntensors); + if (ntensors == 0) { return result_tensors; } - auto esizes_chunks = esizes.unbind(0); - std::vector splits; - for (const auto i : c10::irange(esizes_chunks.size())) { - splits.push_back(esizes_chunks[i].prod().item()); - } - auto buffer_chunks = at::split_with_sizes(get_buffer(self), splits); - for (const auto i : c10::irange(buffer_chunks.size())) { - const auto& esize_chunk = esizes_chunks[i]; - result_tensors.push_back(buffer_chunks[i].view(IntArrayRef( - esize_chunk.data_ptr(), - esize_chunk.data_ptr() + esize_chunk.numel()))); + const at::Tensor& buffer = self_ptr->get_buffer(); + std::vector sizes = NestedTensor_get_sizes(self_ptr), + strides = NestedTensor_get_strides(self_ptr); + const std::vector& offsets = self_ptr->get_offsets(); + for (int64_t i = 0; i < ntensors; i++) { + result_tensors[i] = buffer.as_strided(sizes[i], strides[i], offsets[i]); } return result_tensors; } @@ -218,15 +204,6 @@ Tensor nested_tensor( c10::optional layout, c10::optional device, c10::optional pin_memory) { - TensorOptions options_ = - TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( - pin_memory); - - if (list.size() == 0) { - return wrap_buffer(ones({0}, dtype, layout, device), ones({})); - } - std::vector sizes; - std::vector flat_tensors; for (const auto i : c10::irange(list.size())) { if (i > 0) { int64_t dim_i = list[i].dim(); @@ -244,15 +221,13 @@ Tensor nested_tensor( i - 1, "."); } - // TODO: Remove call to contiguous once we support strides. - flat_tensors.push_back(list[i].reshape(-1).contiguous()); - sizes.push_back(tensor(c10::IntArrayRef(list[i].sizes()))); } - - TensorOptions options = flat_tensors[0].options().merge_in(options_); - - return wrap_buffer( - at::cat(flat_tensors).to(options), at::native::stack(sizes)); + return impl::wrap_tensor_node( + impl::TensorNode(list), + dtype, + layout, + device, + pin_memory); } int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt) { @@ -365,6 +340,11 @@ Tensor NestedTensor_to_padded_tensor_generic( const Tensor& t, double padding, OptionalIntArrayRef output_size) { + // TODO: support noncontiguous case + // error out for now + TORCH_CHECK( + nested_tensor_impl_is_contiguous(get_nested_tensor_impl(t)), + "for now to_padded_tensor only supports contiguous nested tensor"); // TODO: skipped optimization for case of all 1x1 tensors auto& nt = *get_nested_tensor_impl(t); auto max_size = NestedTensor_get_max_size(nt); @@ -508,6 +488,22 @@ Tensor NestedTensor_elementwise_Tensor( const Tensor& other, const std::string& op_name, Func f) { + // self is a scalar + if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { + auto other_impl = get_nested_tensor_impl(other); + return wrap_buffer( + f(self, other_impl->get_buffer()), + other_impl->get_nested_size_tensor().clone() + ); + } + // other is a scalar + if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { + auto self_impl = get_nested_tensor_impl(self); + return wrap_buffer( + f(self_impl->get_buffer(), other), + self_impl->get_nested_size_tensor().clone() + ); + } NestedTensorImpl* self_impl = nullptr; NestedTensorImpl* other_impl = nullptr; std::tie(self_impl, other_impl) = @@ -540,12 +536,29 @@ Tensor NestedTensor_mul_Tensor(const Tensor& self, const Tensor& other) { }); } +// Only usable on the C++ side; scalars are converted to tensors coming from Python. +Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) { + return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other)); +} + template Tensor& NestedTensor_elementwise__Tensor( Tensor& self, const Tensor& other, const std::string& op_name, Func f) { + // self is a scalar + if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { + auto other_impl = get_nested_tensor_impl(other); + f(self, other_impl->get_buffer()); + return self; + } + // other is a scalar + if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { + auto self_impl = get_nested_tensor_impl(self); + f(self_impl->get_buffer(), other); + return self; + } NestedTensorImpl* self_impl = nullptr; NestedTensorImpl* other_impl = nullptr; std::tie(self_impl, other_impl) = @@ -575,56 +588,102 @@ Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) { }); } -Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { +// Only usable on the C++ side; scalars are converted to tensors coming from Python. +Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) { + return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other)); +} + +// Very rudimentary sum_dim for prototyping with torch_scatter.segment_reduce. +Tensor NestedTensor_sum_dim_CPU( + const Tensor& self, + OptionalIntArrayRef opt_dims, + bool keepdim, + c10::optional dtype) { + // Only allow reductions across the last dim + auto dims = opt_dims.value_or(IntArrayRef{}); TORCH_CHECK( - dim == 0, - "NestedTensor can only be selected along dimension 0 ", - "got dimension ", dim, " instead." + dims.size() == 1, + "NestedTensor only allows reduction of a single dimension for now." ); - auto self_ptr = get_nested_tensor_impl(self); - // buffer contains the underlying data in a contiguous vector - const at::Tensor & buffer = self_ptr->get_buffer(); - int64_t numel = buffer.numel(); + auto dim = maybe_wrap_dim(dims[0], self.dim()); TORCH_CHECK( - numel > 0, - "cannot index an empty nested tensor." - ); - // nested_tensor[i] = i-th original tensor - int64_t ntensors = *(self_ptr->opt_size(0)); - int64_t positive_index = at::maybe_wrap_dim(index, ntensors); - // determine the memory segment of the i-th original tensor - Tensor sizemat = get_nested_size_tensor(self); - int64_t original_dim = sizemat.size(1); - const int64_t * sizemat_ptr = sizemat.data_ptr(); - // start of the segment - int64_t start = 0, sizemat_offset = 0; - for (int64_t i = 0; i < positive_index; i++) { - int64_t row_product = sizemat_ptr[sizemat_offset]; - sizemat_offset++; - for (int64_t j = 1; j < original_dim; j++) { - row_product *= sizemat_ptr[sizemat_offset]; - sizemat_offset++; - } - start += row_product; - } - // btw determine the shape of the i-th original tensor - IntArrayRef shape(sizemat_ptr + sizemat_offset, sizemat_ptr + sizemat_offset + original_dim); - // stop of the segment - int64_t stop; - if (positive_index == ntensors - 1) { - stop = numel; + dim == self.dim() - 1, + "NestedTensor can only be reduced across the last dimension for now ", + "got dimension ", + dim, + " instead."); + // Always keep reduced dim for now + // This is to avoid the case where the nested tensors are 1D and keepdim=False + // making the nested tensors -> elements (e.g. sum(nt([1, 2 ,3], [4, 5]), -1) -> nt(6, 9)) + TORCH_CHECK(keepdim, "NestedTensor always requires keepdim=True for now."); + // acc_dtype is not supported for now + TORCH_CHECK(!dtype, "NestedTensor does not support dtype argument for now."); + + auto nt_input = get_nested_tensor_impl(self); + TORCH_CHECK( + nested_tensor_impl_is_contiguous(nt_input), + "NestedTensor does not support reductions when the input is noncontiguous for now."); + int64_t ntensors = nt_input->size(0); + if (ntensors == 0) { + return self; } - else { - int64_t row_product = sizemat_ptr[sizemat_offset]; - sizemat_offset++; - for (int64_t j = 1; j < original_dim; j++) { - row_product *= sizemat_ptr[sizemat_offset]; - sizemat_offset++; + const Tensor& buffer = nt_input->get_buffer(); + + auto sizemat = nt_input->get_nested_size_tensor(); + // create output size tensor for keepdim=True + auto output_sizemat = sizemat.clone(); + output_sizemat.select(1, -1).fill_(1); + + auto num_segments = at::prod(output_sizemat, -1); + auto segment_lengths = sizemat.select(1, -1); + const int64_t new_numel = at::sum(num_segments).item(); + auto output_buffer = buffer.new_empty(IntArrayRef(new_numel)); + + // This logic assumes for now that + // (1) all the nested tensors are contiguous + // (2) the nested tensors are stored contiguously in the buffer + AT_DISPATCH_ALL_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, buffer.scalar_type(), "nested_sum_dim_cpu", [&]() { + auto* output_data = output_buffer.data_ptr(); + const auto* input_data = buffer.data_ptr(); + int64_t out_idx = 0, in_idx = 0; + for (const auto i : c10::irange(ntensors)) { + int64_t segments = num_segments[i].item(); + int64_t segment_length = segment_lengths[i].item(); + for (auto j = 0; j < segments; j++) { + scalar_t res = 0; + for (auto k = 0; k < segment_length; k++) { + res += input_data[in_idx]; + in_idx += 1; + } + output_data[out_idx] = res; + out_idx += 1; + } } - stop = start + row_product; - } - // extract the memory segment then reshape to the original shape - return buffer.slice(0, start, stop).view(shape); + }); + + return wrap_buffer(output_buffer, output_sizemat); +} + +Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { + auto self_ptr = get_nested_tensor_impl(self); + int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim()); + TORCH_CHECK( + positive_dim == 0, + "NestedTensor can only be selected along dimension 0 ", + "got dimension ", dim, " instead." + ); + int64_t ntensors = self_ptr->size(0); + TORCH_CHECK_INDEX( + index >= -ntensors && index < ntensors, + "index ", index, + " is out of bounds for dimension 0 with size ", ntensors); + int64_t positive_index = index < 0 ? index + ntensors : index; + const at::Tensor& buffer = self_ptr->get_buffer(); + std::vector sizes = NestedTensor_get_sizes(self_ptr), + strides = NestedTensor_get_strides(self_ptr); + const std::vector& offsets = self_ptr->get_offsets(); + return buffer.as_strided(sizes[positive_index], strides[positive_index], offsets[positive_index]); } Tensor clone_nested( @@ -649,10 +708,14 @@ at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){ Tensor dropout_nested(const Tensor& input, double p, bool train) { auto input_ptr = get_nested_tensor_impl(input); - const Tensor & input_buffer = input_ptr->get_buffer(), - sizemat = input_ptr->get_nested_size_tensor(); + const Tensor& input_buffer = input_ptr->get_buffer(), + & sizemat = input_ptr->get_nested_size_tensor(), + & stridemat = input_ptr->get_nested_stride_tensor(); + const std::vector& offsets = input_ptr->get_offsets(); Tensor output_buffer = at::dropout(input_buffer, p, train); - return wrap_buffer(output_buffer, sizemat.clone()); + // regular tensor dropout reuses input size and stride + // i.e. if input is not contiguous, then output is also discontiguous + return wrap_buffer(output_buffer, sizemat.clone(), stridemat.clone(), offsets); } Tensor& dropout_nested_(Tensor& input, double p, bool train) { @@ -661,5 +724,432 @@ Tensor& dropout_nested_(Tensor& input, double p, bool train) { return input; } +Tensor softmax_nested( + const Tensor& input, + const int64_t dim, + const bool half_to_float) { + auto input_ptr = get_nested_tensor_impl(input); + int64_t ntensors = input_ptr->size(0); + if (ntensors == 0) { + return input; + } + int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim()); + TORCH_CHECK( + positive_dim >= 1, + "Cannot apply softmax across nested dimension 0"); + // create a contiguous output + const Tensor& buffer = input_ptr->get_buffer(), + & sizemat = input_ptr->get_nested_size_tensor(); + Tensor output_buffer = buffer.new_empty(buffer.sizes()); + Tensor output = wrap_buffer(output_buffer, sizemat.clone()); + // call tensor softmax + // TODO: for cpu, maybe use `parallel_for` if benchmarks show necessity + // to do that, have to merge `aten/src/ATen/native/cpu/SoftMaxKernel.cpp/softmax_kernel` + // 1. it has `parallel_for` and we cannot multi-thread in multi-thread + // 2. cannot dispatch in multi-thread (in this case at::_softmax_out) + std::vector input_unbind = input.unbind(), + output_unbind = output.unbind(); + for (int64_t i = 0; i < ntensors; i++) { + at::_softmax_out( + output_unbind[i], + input_unbind[i], + positive_dim - 1, + half_to_float); + } + return output; +} + +Tensor bmm_nested(const Tensor& self, const Tensor& mat2) { + if (self.is_nested() && !mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); + } + else if (!self.is_nested() && mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); + } + // dispatcher should have guaranteed that at least one is nested + auto self_ptr = get_nested_tensor_impl(self); + auto mat2_ptr = get_nested_tensor_impl(mat2); + TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor"); + int64_t ntensors = self_ptr->size(0), + ntensors2 = mat2_ptr->size(0); + TORCH_CHECK(ntensors == ntensors2, + "Expected size for the 1st dimension of batch2 tensor to be: ", ntensors, + " but got: ", ntensors2, "."); + const Tensor& self_buffer = self_ptr->get_buffer(), + & mat2_buffer = mat2_ptr->get_buffer(); + std::vector self_sizes = NestedTensor_get_sizes(self_ptr), + mat2_sizes = NestedTensor_get_sizes(mat2_ptr), + self_strides = NestedTensor_get_strides(self_ptr), + mat2_strides = NestedTensor_get_strides(mat2_ptr); + const std::vector& self_offsets = self_ptr->get_offsets(), + & mat2_offsets = mat2_ptr->get_offsets(); + // create a contiguous output + int64_t out_numel = 0; + const Tensor& self_sizemat = self_ptr->get_nested_size_tensor(); + Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes()); + int64_t* out_sizemat_ptr = out_sizemat.data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_shape = self_sizes[i], + & mat2_shape = mat2_sizes[i]; + const int64_t& self_size0 = self_shape[0], & self_size1 = self_shape[1], + & mat2_size0 = mat2_shape[0], & mat2_size1 = mat2_shape[1]; + TORCH_CHECK(self_size1 == mat2_size0, + i, "-th nested matrices in batch cannot be multiplied (", + self_size0, "x", self_size1, " and ", + mat2_size0, "x", mat2_size1, ")"); + out_sizemat_ptr[0] = self_size0; + out_sizemat_ptr[1] = mat2_size1; + out_sizemat_ptr += 2; + out_numel += self_size0 * mat2_size1; + } + Tensor out_buffer = self_buffer.new_empty(out_numel); + Tensor output = wrap_buffer(out_buffer, out_sizemat); + // call tensor mm + // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient + // until we have specialized nested tensor bmm kernel + // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_` + // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl` + std::vector output_unbind = output.unbind(); + for (int64_t i = 0; i < ntensors; i++) { + at::mm_out(output_unbind[i], + self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]), + mat2_buffer.as_strided(mat2_sizes[i], mat2_strides[i], mat2_offsets[i])); + } + return output; +} + +// utilities support _NestedTensor_GeneralizedBMM +namespace { +inline std::tuple, Tensor> +_NestedTensor_GeneralizedBMM_BatchSizes_OutputMemory( + const std::vector& self_sizes, + const std::vector& mat2_sizes, + const c10::TensorOptions& buffer_op, + const c10::TensorOptions& sizemat_op) { + int64_t ntensors = self_sizes.size(), + ndims = self_sizes[0].size(); + std::vector batch_sizes(ntensors, 1); + Tensor sizemat = at::empty({ntensors, ndims}, sizemat_op); + int64_t* sizemat_ptr = sizemat.data_ptr(); + int64_t numel = 0; + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_size = self_sizes[i], + & mat2_size = mat2_sizes[i]; + int64_t& batch_size = batch_sizes[i]; + // batch dimensions + for (int64_t j = 0; j < ndims - 2; j++) { + const int64_t& self_sizej = self_size[j], + & mat2_sizej = mat2_size[j]; + TORCH_CHECK( + self_sizej == mat2_sizej, + "matmul: For nested tensors, no broadcasting is currently performed: ", + i, "-th nested matrices in batch at dimension ", j + 1, + " have mismatching sizes ", self_sizej, " and ", mat2_sizej); + sizemat_ptr[j] = self_sizej; + batch_size *= sizemat_ptr[j]; + } + // matrix multiplication dimensions + const int64_t& self_size0 = self_size[ndims - 2], & self_size1 = self_size[ndims - 1], + & mat2_size0 = mat2_size[ndims - 2], & mat2_size1 = mat2_size[ndims - 1]; + TORCH_CHECK( + self_size1 == mat2_size0, + "matmul: ", + i, "-th nested matrices in batch cannot be multiplied (", + self_size0, "x", self_size1, " and ", + mat2_size0, "x", mat2_size1, ")"); + sizemat_ptr[ndims - 2] = self_size0; + sizemat_ptr[ndims - 1] = mat2_size1; + sizemat_ptr += ndims; + numel += batch_size * self_size0 * mat2_size1; + } + Tensor buffer = at::empty(numel, buffer_op); + Tensor output = wrap_buffer(buffer, sizemat); + return std::make_tuple(batch_sizes, output); +} +} + +// This is a generalized batched matmul dedicated to nested tensors, +// where `self` and `mat2` have same number (>= 3) of dimensions. +// The last 2 dimensions will be considered as matrix dimensions, +// so they should be matrix-multiplicable. +// The leading dimensions are considered as batch dimensions, +// and since nested tensor does not support broadcasting for now, +// for each batch dimension `self` and `mat2` must have same size. +Tensor _NestedTensor_GeneralizedBMM(const Tensor& self, const Tensor& mat2) { + if (self.is_nested() && !mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); + } + else if (!self.is_nested() && mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); + } + // dispatcher should have guaranteed that at least one is nested + auto self_ptr = get_nested_tensor_impl(self), + mat2_ptr = get_nested_tensor_impl(mat2); + int64_t self_dim = self_ptr->dim(), + mat2_dim = mat2_ptr->dim(); + TORCH_CHECK( + self_dim >= 3, + "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: ", + self_dim); + TORCH_CHECK( + mat2_dim >= 3, + "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ", + mat2_dim); + TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have same rank"); + int64_t ntensors = self_ptr->size(0), + ntensors2 = mat2_ptr->size(0); + TORCH_CHECK(ntensors == ntensors2, + "matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors, + " but got: ", ntensors2, "."); + const Tensor& self_buffer = self_ptr->get_buffer(), + & mat2_buffer = mat2_ptr->get_buffer(); + std::vector self_sizes = NestedTensor_get_sizes(self_ptr), + mat2_sizes = NestedTensor_get_sizes(mat2_ptr), + self_strides = NestedTensor_get_strides(self_ptr), + mat2_strides = NestedTensor_get_strides(mat2_ptr); + const std::vector& self_offsets = self_ptr->get_offsets(), + & mat2_offsets = mat2_ptr->get_offsets(); + // create a contiguous output + std::vector batch_sizes; + Tensor output; + std::tie(batch_sizes, output) = _NestedTensor_GeneralizedBMM_BatchSizes_OutputMemory( + self_sizes, mat2_sizes, self_buffer.options(), self_ptr->get_nested_size_tensor().options()); + // call tensor matmul + // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient + // until we have specialized nested tensor bmm kernel + // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_` + // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl` + std::vector output_unbind = output.unbind(); + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_size = self_sizes[i], + & mat2_size = mat2_sizes[i]; + const int64_t& batch_size = batch_sizes[i]; + if (batch_size == 1) { + at::mm_out( + output_unbind[i], + self_buffer.as_strided(self_size, self_strides[i], self_offsets[i]), + mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i]) + ); + } + else { + at::bmm_out( + output_unbind[i], + self_buffer.as_strided(self_size, self_strides[i], self_offsets[i]) + .reshape({batch_size, self_size[self_dim - 1 - 2], self_size[self_dim - 1 - 1]}), + mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i]) + .reshape({batch_size, mat2_size[self_dim - 1 - 2], mat2_size[self_dim - 1 - 1]}) + ); + } + } + return output; +} + +Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) { + auto self_ptr = get_nested_tensor_impl(self); + // check input dimensions + int64_t ndims = self_ptr->dim(); + int64_t positive_dim0 = at::maybe_wrap_dim(dim0, ndims), + positive_dim1 = at::maybe_wrap_dim(dim1, ndims); + if (positive_dim0 == positive_dim1) { + return self; + } + TORCH_CHECK(positive_dim0 > 0 && positive_dim1 > 0, "Nested tensor dimension 0 cannot be transposed"); + // -- to exclude the implicit batch dimension + ndims--; + positive_dim0--; + positive_dim1--; + // transpose = switch `dim0` and `dim1` columns of `sizemat` and `stridemat` + const Tensor& sizemat = self_ptr->get_nested_size_tensor(), + & stridemat = self_ptr->get_nested_stride_tensor(); + Tensor column_indices = sizemat.new_empty(ndims); + int64_t* column_indices_ptr = column_indices.data_ptr(); + std::iota(column_indices_ptr, column_indices_ptr + ndims, 0); + column_indices_ptr[positive_dim0] = positive_dim1; + column_indices_ptr[positive_dim1] = positive_dim0; + // create transposed `sizemat` and `stridemat` + Tensor sizemat_transposed = at::index_select(sizemat, 1, column_indices), + stridemat_transposed = at::index_select(stridemat, 1, column_indices); + return wrap_buffer(self_ptr->get_buffer(), sizemat_transposed, stridemat_transposed, self_ptr->get_offsets()); +} + +// utilities supporting `_reshape_nested` +namespace { +// Args: +// sizes: the sizes of original nested tensor +// strides: the strides of original nested tensor +// proposed_shape: user proposed new shape +// op: the options for new size and stride matrices +// Returns: +// whether reshape as view is possible (i.e. old buffer can be reused) +// size matrix after reshape +// stride matrix after reshape (not fully populated if reshape as view is impossible) +inline std::tuple NestedTensor_reshape_size_stride( + const std::vector& sizes, + const std::vector& strides, + const IntArrayRef& proposed_shape, + const c10::TensorOptions& op) { + int64_t ntensors = sizes.size(), + ndims_underlying = sizes[0].size(), + ndims_underlying_reshaped = proposed_shape.size() - 1; + bool reshape_as_view = true; + Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op), + stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op); + int64_t* sizemat_reshaped_ptr = sizemat_reshaped.data_ptr(), + * stridemat_reshaped_ptr = stridemat_reshaped.data_ptr(); + for (int64_t itensor = 0; itensor < ntensors; itensor++) { + const IntArrayRef& size = sizes[itensor], + & stride = strides[itensor]; + // compute reshaped size + std::vector size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end()); + // some negative sizes remain to be infered + if (ndims_underlying < ndims_underlying_reshaped) { + // replace negative sizes for old dimensions with old sizes + int64_t numel = 1, numel_reshaped = 1; + for (int64_t idim = 0; idim < ndims_underlying; idim++) { + int64_t& size_reshaped = size_reshaped_vector[idim]; + TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped); + if (size_reshaped == -1) { + size_reshaped = size[idim]; + } + numel *= size[idim]; + numel_reshaped *= size_reshaped; + } + // infer negative size for new dimension + int64_t infer_index = -1; + for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) { + const int64_t& size_reshaped = size_reshaped_vector[idim]; + if (size_reshaped >= 0) { + numel_reshaped *= size_reshaped; + } + else if (size_reshaped == -1) { + if (infer_index > -1) { + throw std::runtime_error("only one dimension can be inferred"); + } + else { + infer_index = idim; + } + } + else { + AT_ERROR("invalid shape dimension ", size_reshaped); + } + } + // See Note [inference and inheritance semantics] + TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape"); + } + // all negative sizes can be replaced + else { + for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) { + int64_t& size_reshaped = size_reshaped_vector[idim]; + TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped); + if (size_reshaped == -1) { + size_reshaped = size[idim]; + } + } + } + IntArrayRef size_reshaped(size_reshaped_vector); + // compute reshaped stride + auto opt_stride_reshaped = at::detail::computeStride(size, stride, size_reshaped); + // reshape as view is possible + if (opt_stride_reshaped.has_value()) { + const IntArrayRef& stride_reshaped = *opt_stride_reshaped; + // fill reshaped size and stride into sizemat and stridemat + for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) { + sizemat_reshaped_ptr[idim] = size_reshaped[idim]; + stridemat_reshaped_ptr[idim] = stride_reshaped[idim]; + } + sizemat_reshaped_ptr += ndims_underlying_reshaped; + stridemat_reshaped_ptr += ndims_underlying_reshaped; + } + // reshape as view is impossible + else { + reshape_as_view = false; + // fill reshaped size into sizemat + for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) { + sizemat_reshaped_ptr[idim] = size_reshaped[idim]; + } + sizemat_reshaped_ptr += ndims_underlying_reshaped; + } + } + return std::make_tuple(reshape_as_view, sizemat_reshaped, stridemat_reshaped); +} + +// Args: +// nt_reshaped: the reshaped nested tensor to receive copies +// buffer: the original nested tensor buffer +// sizes: the original nested tensor sizes (may have gone through collapsing or splitting) +// strides: the original nested tensor strides (may have gone through collapsing or splitting) +// offsets: the original nested tensor offsets (may have gone through collapsing or splitting) +inline void NestedTensor_reshape_copy( + Tensor& nt_reshaped, + const Tensor& buffer, + const std::vector& sizes, + const std::vector& strides, + const std::vector& offsets) { + auto nt_reshaped_ptr = get_nested_tensor_impl(nt_reshaped); + const Tensor& buffer_reshaped = nt_reshaped_ptr->get_buffer(); + std::vector sizes_reshaped = NestedTensor_get_sizes(nt_reshaped_ptr), + strides_reshaped = NestedTensor_get_strides(nt_reshaped_ptr); + const std::vector& offsets_reshaped = nt_reshaped_ptr->get_offsets(); + for (int64_t i = 0; i < nt_reshaped_ptr->size(0); i++) { + buffer_reshaped.as_strided(sizes_reshaped[i], strides_reshaped[i], offsets_reshaped[i]).copy_( + // TODO: can we avoid allocating new memory for `buffer...reshape` + // I did not find anything like reshape_out + buffer.as_strided(sizes[i], strides[i], offsets[i]).reshape(sizes_reshaped[i])); + } +} +} + +// Special rules for reshape(nested tensor): +// 1. Only 1 regular dimension can be collapsed with +// or splitted from the implicit batch dimension +// 2. Instead of infering size, -1 means "inherit the old size", so: +// * negative size is legal for a ragged dimension +// * multiple sizes can be -1 +Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) { + TORCH_CHECK( + proposed_shape.size() > 0, + "shape '[]' is invalid for a nested tensor"); + auto self_ptr = get_nested_tensor_impl(self); + // basic information before reshaping + int64_t ntensors = self_ptr->size(0); + TORCH_CHECK( + ntensors > 0, + "empty nested tensor cannot be reshaped"); + // basic information after reshaping + int64_t ntensors_reshaped; + if (proposed_shape[0] >= 0) { + ntensors_reshaped = proposed_shape[0]; + } + else if (proposed_shape[0] == -1) { + ntensors_reshaped = ntensors; + } + else { + AT_ERROR("invalid shape dimension ", proposed_shape[0]); + } + TORCH_CHECK( + ntensors == ntensors_reshaped, + "for now reshape cannot change the implicit batch dimension"); + std::vector sizes = NestedTensor_get_sizes(self_ptr), + strides = NestedTensor_get_strides(self_ptr); + const std::vector& offsets = self_ptr->get_offsets(); + // reshaping underlying tensor dimensions does not change offset + // determine reshaped size and stride + const Tensor& buffer = self_ptr->get_buffer(), + & sizemat = self_ptr->get_nested_size_tensor(); + bool reshape_as_view; + Tensor sizemat_reshaped, stridemat_reshaped; + std::tie(reshape_as_view, sizemat_reshaped, stridemat_reshaped) = NestedTensor_reshape_size_stride( + sizes, strides, proposed_shape, sizemat.options()); + if (reshape_as_view) { + return wrap_buffer(buffer, sizemat_reshaped, stridemat_reshaped, offsets); + } + Tensor buffer_reshaped = buffer.new_empty(buffer.sizes()); + Tensor output = wrap_buffer(buffer_reshaped, sizemat_reshaped); + NestedTensor_reshape_copy(output, + buffer, sizes, strides, offsets); + return output; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorMath.h b/aten/src/ATen/native/nested/NestedTensorMath.h index 8f2919fc35b8e..b315a3b253df3 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.h +++ b/aten/src/ATen/native/nested/NestedTensorMath.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -12,9 +13,245 @@ struct NestedTensorImpl; // TODO: cache this and only do it once per NestedTensor int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt); -TORCH_API std::vector NestedTensor_get_max_size(const NestedTensorImpl& nt); +inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_size_tensor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + std::move(buffer), std::move(nested_size_tensor)); +} -TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding, OptionalIntArrayRef output_size); +inline at::Tensor wrap_buffer( + at::Tensor buffer, at::Tensor nested_size_tensor, + at::Tensor nested_stride_tensor, const std::vector& offsets) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + std::move(buffer), std::move(nested_size_tensor), + std::move(nested_stride_tensor), offsets); +} + +// The sizes of the underlying tensors +inline std::vector NestedTensor_get_sizes(const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector sizes(ntensors); + if (ntensors == 0) { + return sizes; + } + const Tensor& sizemat = self_ptr->get_nested_size_tensor(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars has empty sizes + if (orig_dim == 0) { + return sizes; + } + const int64_t* sizemat_ptr = sizemat.data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim); + sizemat_ptr += orig_dim; + } + return sizes; +} + +inline std::vector NestedTensor_get_sizes(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_sizes(self_ptr); +} + +// The strides of the underlying tensors +inline std::vector NestedTensor_get_strides(const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector strides(ntensors); + if (ntensors == 0) { + return strides; + } + const Tensor& stridemat = self_ptr->get_nested_stride_tensor(); + int64_t orig_dim = stridemat.size(1); + // nesting scalars has empty strides + if (orig_dim == 0) { + return strides; + } + const int64_t* stridemat_ptr = stridemat.data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim); + stridemat_ptr += orig_dim; + } + return strides; +} + +inline std::vector NestedTensor_get_strides(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_strides(self_ptr); +} + +TORCH_API std::vector NestedTensor_get_max_size( + const NestedTensorImpl& nt); + +TORCH_API Tensor NestedTensor_to_padded_tensor_generic( + const Tensor& t, + double padding, + OptionalIntArrayRef output_size); + +namespace impl { + +template +struct NestedNode { + NestedNode() = delete; + explicit NestedNode(std::vector&& children) + : _is_leaf(false), _children(children) {} + explicit NestedNode(TensorList children) + : _is_leaf(false), _children(children.vec()) {} + // NestedNode(NestedNode&) = delete; + // NestedNode(const NestedNode&) = delete; + // NestedNode& operator=(NestedNode) = delete; + explicit NestedNode(T payload) : _is_leaf(true), _payload(payload) {} + inline bool is_leaf() const { + return _is_leaf; + } + inline size_t degree() const { + return _children.size(); + } + inline const std::vector unbind() const { + return _children; + } + inline T children(size_t i) const { + return _children[i]; + } + inline const T& payload() const { + return _payload; + } + inline T& payload() { + return _payload; + } + + private: + bool _is_leaf; + std::vector _children; + T _payload; +}; + +using TensorNode = NestedNode; + +template +class _map; + +template +class _map> { + public: + static A function_one( + F&& fn, + const Args&... nested_node) { + return std::forward(fn)(nested_node...); + } + // NOTE: We must move F to avoid copying objects if it is a lambda with + // captures. + static NestedNode function( + F&& fn, + const NestedNode&... nested_node) { + size_t degree = 0; + bool all_leaf = true; + c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&all_leaf, °ree](auto n) { + all_leaf = all_leaf && (n.is_leaf()); + if (degree > 1 && n.degree() > 1) { + TORCH_CHECK(degree == n.degree(), "NestedNodes must match in degree."); + } + if (n.degree() > degree) { + degree = n.degree(); + } + return nullptr; + }); + // All NestedNodes just wrap regular objects. + if (all_leaf) { + return NestedNode(std::forward(fn)(nested_node.payload()...)); + } + // Some NestedNodes wrap regular Tensors, some NestedTensors and some other types. + std::vector result; + for (size_t i = 0; i < degree; i++) { + std::tuple children = c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&i](auto a) { + static_assert( + c10::guts::is_instantiation_of::value, + "Internal error."); + // Broadcast regular arguments across NestedTensor constituents. + // This could be a Tensor, integer or anything else really. + if (a.is_leaf()) { + return a.payload(); + } + // Broadcast NestedTensors with one constituent. + if (a.degree() == 1 && !a.is_leaf()) { + return a.children(0); + } + TORCH_CHECK(a.degree() > 0, "Internal assert."); + return a.children(i); + }); + c10::guts::apply( + [&result, &fn](Args... filtered) { + result.emplace_back(function_one(std::forward(fn), filtered...)); + }, + std::move(children)); + } + return NestedNode(std::move(result)); + } +}; + +// TODO: Add static assert to verify lambda arguments match nested_node types +template +static inline NestedNode< + typename c10::guts::infer_function_traits::type::return_type> +map(F&& fn, const NestedNode&... nested_node) { + return _map< + F, + typename c10::guts::infer_function_traits::type::return_type, + typename c10::guts::infer_function_traits::type::parameter_types>:: + function(std::forward(fn), nested_node...); +} + +inline TensorNode get_nested_tensor_structure(at::Tensor tensor) { + if (get_nested_tensor_impl_or_null(tensor) == nullptr) { + return TensorNode(std::move(tensor)); + } + return TensorNode(tensor.unbind()); +} + +inline Tensor wrap_tensor_node( + TensorNode tensor_node, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + TORCH_CHECK( + !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors."); + TensorOptions options_ = + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( + pin_memory); + if (tensor_node.degree() == 0) { + return wrap_buffer(ones({0}, dtype, layout, device), ones({})); + } + std::vector sizes; + std::vector flat_tensors; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back( + tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + + TensorOptions options = flat_tensors[0].options().merge_in(options_); + + return wrap_buffer( + at::cat(flat_tensors).to(options), at::native::stack(sizes)); +} + +} // namespace impl + +// This function is meant to ease rapid operator coverage for +// NestedTensor kernels. It is not meant to be efficient. Use it judiciously. +template +inline at::Tensor map_nested_tensor(F&& fn, A... a) { + return wrap_tensor_node( + impl::map(std::forward(fn), impl::get_nested_tensor_structure(a)...), + c10::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt); +} } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp index bfacd8f29ce41..d33decc224333 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp @@ -4,22 +4,76 @@ #include #include #include +#include +#include namespace at { namespace native { +namespace { -Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) { - auto* nt_self = get_nested_tensor_impl_or_null(self); - TORCH_CHECK(nt_self != nullptr); - TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_self)); - TORCH_CHECK(self.dim() == 3 && other.dim() == 2); - const auto last_dim = get_consistent_last_dim_of_nested_tensor(*nt_self); +inline void check_nested_tensor_matrix_constraints( + const Tensor& nested_tensor, + const Tensor& dense_matrix, + c10::string_view caller) { + auto* nt_input = get_nested_tensor_impl(nested_tensor); + TORCH_INTERNAL_ASSERT(nt_input != nullptr); + TORCH_CHECK( + !dense_matrix.is_nested(), + caller, + " does not support nested weight when input is a nested tensor.") + // TODO: support noncontiguous case + // error out for now TORCH_CHECK( - last_dim == other.sizes()[0], - "shape mismatch for NestedTensor matmul. NestedTensor last_dim: ", + nested_tensor_impl_is_contiguous(nt_input), + "for now linear only supports contiguous nested tensor"); + TORCH_CHECK( + nested_tensor.dim() == 3 && dense_matrix.dim() == 2, + caller, + " requires nested_tensor.dim == 3 and dense_matrix.dim == 2." + " Nested tensor dim: ", + nested_tensor.dim(), + ". Dense tensor dim: ", + dense_matrix.dim()); + const auto last_dim = get_consistent_last_dim_of_nested_tensor(*nt_input); + // We check check the second dimension for linear because it transposes before matrix multiply + int64_t dim_constraint = (caller == "Linear") ? 1 : 0; + auto dense_size = dense_matrix.size(dim_constraint); + TORCH_CHECK( + last_dim == dense_size, + "Shape mismatch for NestedTensor ", + caller, + ": Expected input's (a nested tensor) 'last_dim' to equal 'weight.size(", + dim_constraint, + "),", + " but got: last_dim = ", last_dim, - " vs. first dim of rhs: ", - other.sizes()[0]); + ", and weight.size(", + dim_constraint, + ") = ", + dense_size); +} +} // namespace + +Tensor nested_linear( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias_opt) { + check_nested_tensor_matrix_constraints(input, weight, c10::string_view{"Linear"}); + auto* nt_input = get_nested_tensor_impl(input); + const Tensor& input_buffer = nt_input->get_buffer(); + Tensor result_buffer = + at::linear(input_buffer.reshape({-1, weight.size(1)}), weight, bias_opt); + result_buffer = result_buffer.reshape({-1}); + int64_t weight_size_1 = weight.size(0); + Tensor new_sizes = nt_input->get_nested_size_tensor().clone(); + // Now the last entry in every row of new_sizes should be weight_size_1. + new_sizes.index_put_({at::indexing::Slice(), -1}, weight_size_1); + return wrap_buffer(result_buffer, new_sizes); +} + +Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) { + check_nested_tensor_matrix_constraints(self, other, c10::string_view{"Matmul"}); + auto* nt_self = get_nested_tensor_impl_or_null(self); const Tensor& self_buffer = nt_self->get_buffer(); Tensor result_buffer = at::mm(self_buffer.reshape({-1, other.sizes()[0]}), other); @@ -28,8 +82,7 @@ Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) { Tensor new_sizes = nt_self->get_nested_size_tensor().clone(); // Now the last entry in every row of new_sizes should be other_size_1. new_sizes.index_put_({at::indexing::Slice(), -1}, other_size_1); - return at::detail::make_tensor( - std::move(result_buffer), std::move(new_sizes)); + return wrap_buffer(result_buffer, new_sizes); } Tensor NestedTensor_times_Tensor_plus_Tensor_addmm( @@ -85,6 +138,45 @@ Tensor NestedTensor_add_NestedTensor_in_place( return self; } +void NestedTensor_softmax_dropout(const Tensor& query, Tensor& attn_scores) { + const auto* query_nt = get_nested_tensor_impl_or_null(query); + TORCH_INTERNAL_ASSERT(query_nt != nullptr); + TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(query_nt)); + + const Tensor& sizes = query_nt->get_nested_size_tensor(); + const auto num_tensors = sizes.sizes()[0]; + const auto max_seq_len = attn_scores.sizes()[2]; + + for (int64_t i = 0; i < num_tensors; i++) { + auto seq_len = sizes.index({i, 0}).item(); + auto subseq = attn_scores.index( + {i, + indexing::Slice(), + indexing::Slice(0, seq_len), + indexing::Slice(0, seq_len)}); + auto subscores = at::softmax(subseq, subseq.dim() - 1); + attn_scores.index_put_( + {i, + indexing::Slice(), + indexing::Slice(0, seq_len), + indexing::Slice(0, seq_len)}, + subscores); + attn_scores.index_put_( + {i, + indexing::Slice(), + indexing::Slice(0, seq_len), + indexing::Slice(seq_len, max_seq_len)}, + 0); + attn_scores.index_put_( + {i, + indexing::Slice(), + indexing::Slice(seq_len, max_seq_len), + indexing::Slice(0, max_seq_len)}, + 0); + } +} + + Tensor NestedTensor_batch_offsets_from_size_tensor( const Tensor& sizes, int64_t extra_elements) { @@ -137,6 +229,5 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional mask_dim, c } return result; } - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h index 77eb0145d6847..96ecfe91c3ddd 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -50,6 +50,8 @@ Tensor NestedTensor_from_padded_tensor_cpu( const Tensor& padded, const NestedTensorImpl& nt); +void NestedTensor_softmax_dropout(const Tensor& query, Tensor& attn_scores); + Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional mask_dim, c10::optional mask_dim_length); template diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index f1cf67676ced8..d89e5c5763d7f 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -135,7 +135,9 @@ Tensor NestedTensor_to_padded_tensor_cuda( (t.dtype() == at::kFloat || t.dtype() == at::kDouble || t.dtype() == at::kHalf)) { auto* nt_input = get_nested_tensor_impl(t); - TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_input)); + TORCH_CHECK( + nested_tensor_impl_is_contiguous(nt_input), + "for now to_padded_tensor only supports contiguous nested tensor"); const auto& nt_buffer = nt_input->get_buffer(); if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) && diff --git a/aten/src/ATen/native/prim_native_functions.cpp b/aten/src/ATen/native/prim_native_functions.cpp index cc77f04b585fe..8f82345c19058 100644 --- a/aten/src/ATen/native/prim_native_functions.cpp +++ b/aten/src/ATen/native/prim_native_functions.cpp @@ -22,5 +22,13 @@ bool is_nonzero(const Tensor& self) { } TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar"); } + + +// Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default +// within test/test_python_dispatch.py +Tensor foobar(const Tensor& self, bool arg1, bool arg2, bool arg3) { + return self; +} + } // namespace meta } // namespace at diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp index aa0fef5df9dc0..66c48f4ce7528 100644 --- a/aten/src/ATen/native/quantized/TensorFactories.cpp +++ b/aten/src/ATen/native/quantized/TensorFactories.cpp @@ -66,6 +66,16 @@ Tensor empty_per_channel_affine_quantized( quantizer); } +Tensor empty_symint_unknown_quantized( + c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional optional_memory_format) { + return at::native::empty_unknown_quantized(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); +} + Tensor empty_unknown_quantized( IntArrayRef size, c10::optional dtype, diff --git a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h index bfa1f1f775623..506f0e46e573f 100644 --- a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h +++ b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h @@ -163,6 +163,35 @@ using qnormalize_fn = void (*)( double /* eps */, Tensor* /* Y */); +using qmean_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* opt_dim */, + bool /* keepdim */, + c10::optional /* opt_dtype */, + Tensor& /* Y */); + +using qstd_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* dim */, + optional /* unbiased */, + bool /* keepdim */, + Tensor& /* Y */); + +using qnormalize_nhwc_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Tensor& /*qw*/); + DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub); DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub); DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub); @@ -186,6 +215,7 @@ DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub); DECLARE_DISPATCH(qdropout_fn, qdropout_stub); DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub); DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub); +DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub); DECLARE_DISPATCH(qrelu_fn, qrelu_stub); DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub); DECLARE_DISPATCH(qgelu_fn, qgelu_stub); @@ -194,6 +224,9 @@ DECLARE_DISPATCH(qtanh_fn, qtanh_stub); DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub); DECLARE_DISPATCH(qtopk_fn, qtopk_stub); DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub); +DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub); +DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub); +DECLARE_DISPATCH(qprelu_fn, qprelu_stub); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp index 7e9f98d4ec09c..e7f78b29bbf07 100644 --- a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp @@ -1,12 +1,45 @@ #include #include #include +#include #include #include #include namespace at { namespace native { + +DEFINE_DISPATCH(qmean_inner_dim_stub); +DEFINE_DISPATCH(qstd_inner_dim_stub); + +// If mean/std is taken in the innermost dims, the fast path can be used. +inline bool is_innnermost_dim( + const Tensor& self, + OptionalIntArrayRef opt_dim) { + if (!opt_dim.has_value()) { + return true; + } + auto dims = opt_dim.value().vec(); + auto ndim = self.dim(); + maybe_wrap_dims(dims, ndim); + std::sort(dims.begin(), dims.end(), std::greater()); + bool is_innermost = dims.empty() || dims[0] == ndim - 1; + for (size_t i = 1; i < dims.size(); ++i) { + is_innermost = is_innermost && (dims[i] == dims[i-1] - 1); + } + return is_innermost; +} + +inline bool is_mean_inner_dim_fast_path( + const Tensor& self, + OptionalIntArrayRef opt_dim, + c10::optional opt_dtype) { + bool is_fast_path = + is_innnermost_dim(self, opt_dim) && + (!opt_dtype.has_value() || opt_dtype.value() == self.scalar_type()); + return is_fast_path; +} + #ifdef USE_PYTORCH_QNNPACK Tensor qnnpack_mean(const Tensor& input, IntArrayRef dim, bool keepdim) { Tensor output; @@ -82,22 +115,31 @@ Tensor qnnpack_mean(const Tensor& input, IntArrayRef dim, bool keepdim) { #endif Tensor& mean_out_quantized_cpu( const Tensor& self, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional opt_dtype, Tensor& result) { #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK && - self.scalar_type() == kQUInt8 && - // QNNPACK currently is only supported for NCHW + dim=(2, 3) - // Remove these checks after generic version is implemented. - self.ndimension() == 4 && dim.size() == 2 && dim[0] == 2 && dim[1] == 3) { - result = qnnpack_mean(self, dim, keepdim); - return result; + self.scalar_type() == kQUInt8 && opt_dim.has_value()) { + auto dim = opt_dim.value(); + // QNNPACK currently is only supported for NCHW + dim=(2, 3) + // Remove these checks after generic version is implemented. + if (self.ndimension() == 4 && dim.size() == 2 && dim[0] == 2 && dim[1] == 3) { + result = qnnpack_mean(self, dim, keepdim); + return result; + } } #endif + + // Take average in the innermost dimensions + if (self.is_contiguous(c10::MemoryFormat::Contiguous) && + is_mean_inner_dim_fast_path(self, opt_dim, opt_dtype)) { + qmean_inner_dim_stub(self.device().type(), self, opt_dim, keepdim, opt_dtype, result); + return result; + } auto self_dequantized = self.dequantize(); - auto result_dequantized = at::mean(self_dequantized, dim, keepdim, opt_dtype); + auto result_dequantized = at::mean(self_dequantized, opt_dim, keepdim, opt_dtype); result = at::quantize_per_tensor( result_dequantized, self.q_scale(), @@ -108,11 +150,11 @@ Tensor& mean_out_quantized_cpu( Tensor mean_quantized_cpu( const Tensor& self, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, optional dtype) { Tensor result; - mean_out_quantized_cpu(self, dim, keepdim, dtype, result); + mean_out_quantized_cpu(self, opt_dim, keepdim, dtype, result); return result; } @@ -135,5 +177,79 @@ Tensor& mean_out_quantized_cpu( self, dimnames_to_positions(self, dim), keepdim, opt_dtype, result); } +// qstd +inline bool is_std_inner_dim_fast_path( + const Tensor& self, + OptionalIntArrayRef dim, + optional unbiased) { + // Do not enter fast path if there are too few elements + IntArrayRef dims = dim.has_value() ? dim.value() : IntArrayRef(); + auto all_dims = std::vector(self.dim()); + std::iota(all_dims.begin(), all_dims.end(), 0); + dims = dims.empty() ? all_dims : dims; + bool is_unbiased = unbiased.has_value() ? unbiased.value() : 0; + int64_t num_ele = 1; + for (auto d : dims) { + num_ele *= self.size(d); + } + if (num_ele == 1 && is_unbiased) { + return false; + } + return is_innnermost_dim(self, dims); +} + +Tensor& std_out_quantized_cpu( + const Tensor& self, + OptionalIntArrayRef dim, + optional unbiased, + bool keepdim, + Tensor& result) { + // Fast path + if (self.is_contiguous(c10::MemoryFormat::Contiguous) && + is_std_inner_dim_fast_path(self, dim, unbiased)) { + qstd_inner_dim_stub(self.device().type(), self, dim, unbiased, keepdim, result); + return result; + } + + // Reference path + auto self_dequantized = self.dequantize(); + auto result_dequantized = at::std(self_dequantized, dim, unbiased, keepdim); + result = at::quantize_per_tensor( + result_dequantized, + self.q_scale(), + self.q_zero_point(), + self.scalar_type()); + return result; +} + +Tensor std_quantized_cpu( + const Tensor& self, + OptionalIntArrayRef dim, + optional unbiased, + bool keepdim) { + Tensor result; + std_out_quantized_cpu(self, dim, unbiased, keepdim, result); + return result; +} + +Tensor std_quantized_cpu( + const Tensor& self, + DimnameList dim, + optional unbiased, + bool keepdim) { + return std_quantized_cpu( + self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +Tensor& std_out_quantized_cpu( + Tensor& result, + const Tensor& self, + DimnameList dim, + optional unbiased, + bool keepdim) { + return std_out_quantized_cpu( + self, dimnames_to_positions(self, dim), unbiased, keepdim, result); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp b/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp index dc038a0f4e3e0..b4a524566605e 100644 --- a/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp @@ -5,14 +5,16 @@ #include #include +#include + namespace at { namespace native { void initQNNPACK() { - static std::once_flag once; + static c10::once_flag once; static enum pytorch_qnnp_status qnnpackStatus = pytorch_qnnp_status_uninitialized; - std::call_once(once, []() { qnnpackStatus = pytorch_qnnp_initialize(); }); + c10::call_once(once, []() { qnnpackStatus = pytorch_qnnp_initialize(); }); TORCH_CHECK( qnnpackStatus == pytorch_qnnp_status_success, "failed to initialize QNNPACK"); diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index d16e4d6f4bfd2..b1bdaadaf5b33 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #ifdef USE_FBGEMM @@ -652,6 +653,81 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx, }); } +static void qprelu_out_kernel(Tensor& out, + const Tensor& qx, + const Tensor& qw) { + int32_t i_zp = static_cast(qx.q_zero_point()); + float i_scale = static_cast(qx.q_scale()); + + int32_t w_zp = static_cast(qw.q_zero_point()); + float w_scale = static_cast(qw.q_scale()); + + int32_t o_zp = static_cast(out.q_zero_point()); + float o_scale = static_cast(out.q_scale()); + float o_inv_scale = 1.0f / o_scale; + + float multiplier = i_scale * w_scale * o_inv_scale; + + int64_t input_ndim = qx.dim(); + TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed."); + + // Helper to convert 1d tensors or scalar tensor to an nd tensor that broadcasts with input + // All elements go into the channel dimension + DimVector sizes(input_ndim, 1), strides(input_ndim, 0); + auto as_nd = [&](const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0)); + sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1; + strides[1] = t.dim() == 1 ? t.strides()[0] : 0; + return t.as_strided(sizes, strides); + }; + + auto qw_nd = as_nd(qw); + + auto iter = TensorIteratorConfig() + .add_output(out) + .add_input(qx) + .add_input(qw_nd) + .build(); + + AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qprelu", [&] { + using qVec = Vectorized; + qVec i_zp_vec = qVec(static_cast(i_zp)); + qVec w_zp_vec = qVec(static_cast(w_zp)); + + // Quantized one as weight + auto qw_one = at::native::quantize_val(w_scale, w_zp, 1.0f); + qVec vec_qw_one = qVec(qw_one); + auto vec_qw_one_sub_zp = vec_qw_one.widening_subtract(w_zp_vec)[0]; + int32_t qw_one_sub_zp = qw_one.val_ - w_zp; + + cpu_kernel_vec( + iter, + [=](scalar_t val_qx, scalar_t val_qw) -> scalar_t { + int32_t qx_pos = std::max(static_cast(val_qx.val_), i_zp); + int32_t qx_neg = std::min(static_cast(val_qx.val_), i_zp); + int32_t qx_pos_sub_zp = qx_pos - i_zp; + int32_t qx_neg_sub_zp = qx_neg - i_zp; + int32_t qw_sub_zp = val_qw.val_ - w_zp; + auto qy_sub_zp = qx_pos_sub_zp * qw_one_sub_zp + qx_neg_sub_zp * qw_sub_zp; + return at::native::requantize_from_int( + multiplier, o_zp, qy_sub_zp); + }, + [=](qVec vec_qx, qVec vec_qw) -> qVec { + auto vec_qx_pos = vec_qx.maximum(i_zp_vec); + auto vec_qx_neg = vec_qx.minimum(i_zp_vec); + qVec::int_vec_return_type qx_pos_sub_zp = vec_qx_pos.widening_subtract(i_zp_vec); + qVec::int_vec_return_type qx_neg_sub_zp = vec_qx_neg.widening_subtract(i_zp_vec); + qVec::int_vec_return_type qw_sub_zp = vec_qw.widening_subtract(w_zp_vec); + qVec::int_vec_return_type qy_sub_zp; + for (const auto i : c10::irange(qVec::int_num_vecs())) { + qy_sub_zp[i] = qx_pos_sub_zp[i] * vec_qw_one_sub_zp + qx_neg_sub_zp[i] * qw_sub_zp[i]; + } + return qVec::requantize_from_int(qy_sub_zp, multiplier, o_zp); + }); + }); + +} + void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) { int64_t zero_point = qx.q_zero_point(); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) @@ -2727,6 +2803,399 @@ void quantized_normalize_kernel( }); } +void qmean_inner_dim_kernel( + const Tensor& self, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional opt_dtype, + Tensor& result) { + // 'opt_dtype' should be none or equal to that of input + ScalarType dtype = self.scalar_type(); + auto in_dims = self.sizes().vec(); + auto out_dims = in_dims; + bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty(); + size_t num_dims_to_squeeze = is_all_reduce ? self.dim() : opt_dim.value().size(); + int64_t M = 1; // Num of groups + int64_t N = 1; // Num of elements to take average of in each group + for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) { + M *= in_dims[i]; + } + for (size_t i = 0; i < num_dims_to_squeeze; ++i) { + auto idx = out_dims.size() - 1 - i; + N *= out_dims[idx]; + out_dims[idx] = 1; + } + if (!keepdim) { + out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end()); + } + result = at::_empty_affine_quantized( + out_dims, + at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()), + self.q_scale(), + self.q_zero_point(), + c10::nullopt); + + AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_mean_kernel_impl_cpu", [&]() { + scalar_t* X_data = self.data_ptr(); + scalar_t* Y_data = result.data_ptr(); + + at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) { + for (const auto i : c10::irange(start, end)) { + scalar_t* X_ptr = X_data + i * N; + scalar_t* Y_ptr = Y_data + i; + scalar_t::underlying* X_ptr_underlying = reinterpret_cast(X_ptr); + scalar_t::underlying* Y_ptr_underlying = reinterpret_cast(Y_ptr); + auto x_sum = hsum(X_ptr_underlying, N); + float y_float = static_cast(x_sum) / N; + *Y_ptr_underlying = std::nearbyint(y_float); + } + }); + }); +} + +void qstd_inner_dim_kernel( + const Tensor& self, + OptionalIntArrayRef dim, + optional unbiased, + bool keepdim, + Tensor& result) { + ScalarType dtype = self.scalar_type(); + auto in_dims = self.sizes().vec(); + auto out_dims = in_dims; + size_t num_dims_to_squeeze = dim.has_value() && !dim.value().empty() ? + dim.value().size() : + self.dim(); + int64_t M = 1; // Num of groups + int64_t N = 1; // Num of elements to take std of in each group + for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) { + M *= in_dims[i]; + } + for (size_t i = 0; i < num_dims_to_squeeze; ++i) { + auto idx = out_dims.size() - 1 - i; + N *= out_dims[idx]; + out_dims[idx] = 1; + } + if (!keepdim) { + out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end()); + } + int64_t den = N; // Denominator when computing mean and deviation + if (unbiased.has_value() && unbiased.value() == 1) { + den -= 1; + } + auto x_scale = self.q_scale(); + auto x_zp = self.q_zero_point(); + result = at::_empty_affine_quantized( + out_dims, + at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()), + x_scale, + x_zp, + c10::nullopt); + + AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_std_kernel_impl_cpu", [&]() { + scalar_t* X_data = self.data_ptr(); + scalar_t* Y_data = result.data_ptr(); + + at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) { + for (const auto i : c10::irange(start, end)) { + scalar_t* X_ptr = X_data + i * N; + scalar_t* Y_ptr = Y_data + i; + scalar_t::underlying* X_ptr_underlying = reinterpret_cast(X_ptr); + scalar_t::underlying* Y_ptr_underlying = reinterpret_cast(Y_ptr); + auto x_sum_shifted = hsum(X_ptr_underlying, N); + auto x_sum_sq_shifted = hsum_sq(X_ptr_underlying, N); + // Use double for intermediate variables to avoid accuracy issue + // Mean with zero point + double x_mean_shifted_div_scale_x = static_cast(x_sum_shifted) / N; + double x_mean_unbiased_shifted_div_scale_x = static_cast(x_sum_shifted) / den; + // variance / x_scale^2 + double x_var_div_scale_x_sq = + std::max(static_cast(x_sum_sq_shifted) / den - + 2 * x_mean_shifted_div_scale_x * x_mean_unbiased_shifted_div_scale_x + + x_mean_shifted_div_scale_x * x_mean_shifted_div_scale_x * N / den, (double)0.0); + double y_float = std::sqrt(x_var_div_scale_x_sq) * x_scale; + *Y_ptr_underlying = at::native::quantize_val( + x_scale, x_zp, y_float) + .val_; + } + }); + }); +} + +// For group norm of channels_last input +void quantized_groupnorm_nhwc_kernel( + const Tensor& X, // input tensor + const Tensor& gamma, // weight (optional) + const Tensor& beta, // bias (optional) + bool affine_per_channel, // must be true for group/instance norm + int num_channels, // only used if affine_per_channel is set + int num_groups, // only used if affine_per_channel is set + int64_t M, // number of groups = Bs * G + int64_t N, // number of elements in each group = C * H * W / G + double eps, + Tensor* Y) { + AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_norm_nhwc_kernel_impl_cpu", [&]() { + using qVec = vec::Vectorized; + using fVec = vec::Vectorized; + + int64_t G = num_groups; + int64_t Bs = M / G; + int64_t C = num_channels; + + TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X"); + TORCH_INTERNAL_ASSERT( + !gamma.defined() || + (!affine_per_channel && gamma.numel() == N) || + (affine_per_channel && gamma.numel() == C), + "Unexpected size of gamma"); + TORCH_INTERNAL_ASSERT( + !beta.defined() || + (!affine_per_channel && beta.numel() == N) || + (affine_per_channel && beta.numel() == C), + "Unexpected size of beta"); + + scalar_t* X_data = X.data_ptr(); + const float* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; + const float* beta_data = beta.defined() ? beta.data_ptr() : nullptr; + scalar_t* Y_data = Y->data_ptr(); + const bool gamma_null = gamma_data == nullptr; + const bool beta_null = beta_data == nullptr; + int64_t x_zp = X.q_zero_point(); + float x_scale = X.q_scale(); + fVec x_zp_vec((float)x_zp); + fVec one_vec(1.0f); + fVec zero_vec(0.0f); + float x_fake_scale = 1.0f; + fVec x_fake_scale_vec(x_fake_scale); + fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg(); + int64_t y_zp = Y->q_zero_point(); + float y_scale = Y->q_scale(); + float y_inv_scale = 1.0f / y_scale; + + constexpr int kFloatVLen = fVec::size(); + int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs(); + int64_t channels_per_group = C / G; + int64_t HxW = N / channels_per_group; + int64_t kNumIntVecInHxW = channels_per_group / kIntVLen; + int64_t kNonVecRemInHxW = channels_per_group % kIntVLen; + int64_t kNumIntVecOnChannel = C / kIntVLen; + int64_t kNonVecRemOnChannel = C % kIntVLen; + + // Buffer for x and x^2 + Tensor buffer = at::empty({M, 2 * channels_per_group}, X.options().dtype(at::kFloat)); + float* buffer_data = buffer.data_ptr(); + + // We can parallel in the following 2 impls: + // + // impl-1: parallel on N * G. Only need one omp session but memory access + // per thread is non-contiguous. + // + // impl-2: parallel on N * HxW. Memory access per thread is contiguous, + // but requires help of extra temp buffer of size {T, N, 2C}. + // + // Generally impl-2 has better performance when HxW is large enough + // The threshold is found by tests. + constexpr int64_t feature_map_threshold = 512; + if (HxW < feature_map_threshold) { + // Impl-1: Parallel for each group + // + // Parallel for each group, M = Bs * G + at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) { + int64_t n{0} /* batch index */, g{0} /* group index in each batch */; + data_index_init(begin, n, N, g, G); + for (const auto grpIdx : c10::irange(begin, end)) { // For each group + + // Step 1: calculate mean and variance. + int64_t l_sum_shifted = 0; + int64_t l_sum_sq_shifted = 0; + for (const auto hw : c10::irange(HxW)) { + scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hw * C; + scalar_t::underlying* X_ptr_underlying = reinterpret_cast(X_ptr); + l_sum_shifted += hsum(X_ptr_underlying, channels_per_group); + l_sum_sq_shifted += hsum_sq(X_ptr_underlying, channels_per_group); + } + + // mean(dqX) / scale_x + x_zp + float l_mean_shifted_div_scale_x = static_cast(l_sum_shifted) / N; + // mean(dqX) / scale_x + float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp; + // var(dqX) / scale_x^2 + float layer_var_div_scale_x_sq = + std::max(static_cast(l_sum_sq_shifted) / N - + l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f); + // scale_x / sqrt(var(dqX) + eps) + float scale_x_div_layer_std = x_scale / + std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps); + + // Step 2: calculate scale and bias + float* scale_ptr = buffer_data + grpIdx * 2 * channels_per_group; + float* bias_ptr = scale_ptr + channels_per_group; + for (const auto d : c10::irange(channels_per_group)) { + const int64_t chIdx = g * channels_per_group + d; + scale_ptr[d] = scale_x_div_layer_std * (gamma_null ? 1.0f : gamma_data[chIdx]); + bias_ptr[d] = -scale_ptr[d] * layer_mean_div_scale_x + (beta_null ? 0.0f : beta_data[chIdx]); + } + + // Step 3: applying scale and bias + for (const auto hwIdx : c10::irange(HxW)) { + const scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hwIdx * C; + scalar_t* Y_ptr = Y_data + n * N * G + g * channels_per_group + hwIdx * C; + // vectorized + for (const auto vecIdx : c10::irange(kNumIntVecInHxW)) { + int64_t vecStartIdx = vecIdx * kIntVLen; + auto qXVec = qVec::loadu(X_ptr + vecStartIdx); + auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec, + x_fake_scale_zp_neg_premul_vec); + for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) { + auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen); + auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen); + dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec; + } + qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale) + .store(Y_ptr + vecStartIdx); + } + // Remaining scalar + for (int64_t remIdx = kNumIntVecInHxW * kIntVLen; + remIdx < kNonVecRemInHxW + kNumIntVecInHxW * kIntVLen; + ++remIdx) { + auto qXVal = X_ptr[remIdx]; + float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal); + float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx]; + Y_ptr[remIdx] = at::native::quantize_val(y_scale, y_zp, dqY); + } + } // loop over HxW + + data_index_step(n, N, g, G); + } // for each group + }); // parallel_for + } else { // HxW > feature_map_threshold + // impl-2: parallel on Bs * HxW. + // + // Buffer for x and x^2 + // To avoid thread conflict, we use a temp buffer of {T, Bs, 2*C} + int num_threads = at::get_num_threads(); + Tensor buffer = at::empty({num_threads, Bs, 2 * C}, X.options().dtype(at::kFloat)).zero_(); + float* buffer_data = buffer.data_ptr(); + Tensor mean = at::empty(M, X.options().dtype(at::kFloat)); + float* mean_data = mean.data_ptr(); + Tensor rstd = at::empty(M, X.options().dtype(at::kFloat)); + float* rstd_data = rstd.data_ptr(); + + // Step 1: Accumulate on C dimension + at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + float* buffer_ptr = buffer_data + tid * Bs * 2 * C; + + int64_t n{0} /* batch index */, m{0} /* HxW index */; + data_index_init(begin, n, Bs, m, HxW); + for (const auto nhwIdx : c10::irange(begin, end)) { + float* mean_ptr = buffer_ptr + n * 2 * C; + float* rstd_ptr = mean_ptr + C; + scalar_t* X_ptr = X_data + nhwIdx * C; + scalar_t::underlying* X_ptr_underlying = reinterpret_cast(X_ptr); + for (int chIdx = 0; chIdx < C; ++chIdx) { + auto x = X_ptr_underlying[chIdx]; + mean_ptr[chIdx] += x; + rstd_ptr[chIdx] += x * x; + } + data_index_step(n, Bs, m, HxW); + } + }); + + // Step 2: Calculate mean and rstd + for (const auto n : c10::irange(Bs)) { + for (const auto g : c10::irange(G)) { + float mean_val{0}, rstd_val{0}; + for (const auto t : c10::irange(num_threads)) { + float* buffer_ptr = buffer_data + t * Bs * 2 * C + n * 2 * C; + for (const auto d : c10::irange(channels_per_group)) { + mean_val += buffer_ptr[g * channels_per_group + d]; + rstd_val += buffer_ptr[g * channels_per_group + d + C]; + } // for d + } // for t + + // mean / scale_x + x_zp + float l_mean_shifted_div_scale_x = mean_val / N; + // mean / scale_x + float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp; + // var / scale_x^2 + float layer_var_div_scale_x_sq = + std::max(rstd_val / N - + l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f); + // scale_x / sqrt(var + eps) + float scale_x_div_layer_std = x_scale / + std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps); + mean_data[n * G + g] = layer_mean_div_scale_x; + rstd_data[n * G + g] = scale_x_div_layer_std; + + } // for g + } // for n + + // Step 3: Calculate scale and bias + // + // We could fuse step 3 and 4 into a single session but this way is better: + // a. D might be too small for vectorization; + // b. Avoid duplicate caculation of scale/bias, each HxW plain share the same scale/bias + // + for (const auto n : c10::irange(Bs)) { + for (const auto g : c10::irange(G)) { + float* scale_ptr = buffer_data + n * 2 * C; + float* bias_ptr = scale_ptr + C; + float mean_val = mean_data[n * G + g]; + float rstd_val = rstd_data[n * G + g]; + for (const auto d : c10::irange(channels_per_group)) { + const int64_t chIdx = g * channels_per_group + d; + scale_ptr[chIdx] = rstd_val * (gamma_null ? 1.0f : gamma_data[chIdx]); + bias_ptr[chIdx] = -scale_ptr[chIdx] * mean_val + (beta_null ? 0.0f : beta_data[chIdx]); + } // for d + } // for g + } // for n + + // step-4: apply scale and bias + // + // Parallel on all the outer dimensions of Bs and HxW + // and vectorize on C. + // + at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) { + int64_t n{0}, m{0}; + data_index_init(begin, n, Bs, m, HxW); + for (const auto nhwIdx : c10::irange(begin, end)) { + const scalar_t* X_ptr = X_data + nhwIdx * C; + scalar_t* Y_ptr = Y_data + nhwIdx * C; + float* scale_ptr = buffer_data + n * 2 * C; + float* bias_ptr = scale_ptr + C; + // Vectorized + for (const auto vecIdx : c10::irange(kNumIntVecOnChannel)) { + int64_t vecStartIdx = vecIdx * kIntVLen; + auto qXVec = qVec::loadu(X_ptr + vecStartIdx); + auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec, + x_fake_scale_zp_neg_premul_vec); + for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) { + auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen); + auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen); + dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec; + } + qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale) + .store(Y_ptr + vecStartIdx); + } + // Remaining scalar + for (int64_t remIdx = kNumIntVecOnChannel * kIntVLen; + remIdx < kNonVecRemOnChannel + kNumIntVecOnChannel * kIntVLen; + ++remIdx) { + auto qXVal = X_ptr[remIdx]; + float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal); + float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx]; + Y_ptr[remIdx] = at::native::quantize_val(y_scale, y_zp, dqY); + } + + data_index_step(n, Bs, m, HxW); + } // for idx on nhw + }); // parallel_for on nhw + + } // if HxW > feature_map_threshold + + }); // AT_DISPATCH_QINT_TYPES +} + #ifdef USE_FBGEMM void quantize_tensor_per_tensor_affine_cpu( const Tensor& rtensor, @@ -3694,6 +4163,7 @@ REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub); REGISTER_NO_AVX512_DISPATCH(qmul_stub); REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub); REGISTER_NO_AVX512_DISPATCH(qrelu_stub); +REGISTER_NO_AVX512_DISPATCH(qprelu_stub); REGISTER_NO_AVX512_DISPATCH(qgelu_stub); REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub); REGISTER_NO_AVX512_DISPATCH(qtanh_stub); @@ -3704,11 +4174,14 @@ REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub); REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub); REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub); REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub); +REGISTER_NO_AVX512_DISPATCH(quantized_groupnorm_nhwc_stub); REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub); REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub); REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub); REGISTER_NO_AVX512_DISPATCH(masked_fill_kernel_quantized_stub); REGISTER_NO_AVX512_DISPATCH(index_put_kernel_quantized_stub); +REGISTER_NO_AVX512_DISPATCH(qmean_inner_dim_stub); +REGISTER_NO_AVX512_DISPATCH(qstd_inner_dim_stub); #else REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, &dequantize_tensor_per_channel_affine_cpu); @@ -3748,6 +4221,7 @@ REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel); REGISTER_DISPATCH(qmul_stub, &qmul_kernel); REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel); REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel); +REGISTER_DISPATCH(qprelu_stub, &qprelu_out_kernel); REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel); REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel); REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel); @@ -3765,6 +4239,7 @@ REGISTER_DISPATCH( quantize_tensor_per_channel_float_qparams_stub, &quantize_tensor_per_channel_float_qparams_cpu); REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel); +REGISTER_DISPATCH(quantized_groupnorm_nhwc_stub, &quantized_groupnorm_nhwc_kernel); REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub, &qupsample_bilinear2d_nhwc_kernel); REGISTER_DISPATCH( @@ -3779,6 +4254,8 @@ REGISTER_DISPATCH( REGISTER_DISPATCH( index_put_kernel_quantized_stub, &index_put_kernel_quantized_cpu); +REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel); +REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel); #endif // CPU_CAPABILITY_AVX512 && _WIN32 } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/BUCK.oss b/aten/src/ATen/native/quantized/cpu/qnnpack/BUCK.oss index 85abc6a609160..4580a6f7205be 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/BUCK.oss +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/BUCK.oss @@ -1,143 +1,4 @@ -load("//tools/build_defs:glob_defs.bzl", "subdir_glob") +load("//:buckbuild.bzl", "third_party") +load(":buckbuild.bzl", "define_qnnpack") -cxx_library( - name = "pytorch_qnnpack", - srcs = ['src/add.c', 'src/average-pooling.c', 'src/channel-shuffle.c', 'src/clamp.c', 'src/conv-prepack.cc', 'src/conv-run.cc', 'src/convolution.c', 'src/deconv-run.cc', 'src/deconvolution.c', 'src/fc-dynamic-run.cc', 'src/fc-prepack.cc', 'src/fc-run.cc', 'src/fully-connected.c', 'src/fully-connected-sparse.c', 'src/global-average-pooling.c', 'src/hardsigmoid.c', 'src/hardswish.c', 'src/indirection.c', 'src/init.c', 'src/leaky-relu.c', 'src/max-pooling.c', 'src/operator-delete.c', 'src/operator-run.c', 'src/pack_block_sparse.cc', 'src/sigmoid.c', 'src/softargmax.c', 'src/tanh.c'], - deps = [':qnnp_interface', ':ukernels_asm', ':ukernels_neon', ':ukernels_psimd', ':ukernels_scalar', ':ukernels_sse2', ':ukernels_sse41', ':ukernels_ssse3', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = ['//third_party:cpuinfo'], - compiler_flags = ['-O2', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION', '-Wno-deprecated-declarations'], - preferred_linkage = "static", - exported_headers = subdir_glob([("src", "qnnpack/*.h"),("include", "*.h"),]), - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['armv7', ['-mfpu=neon']], ['^android-armv7$', ['-marm', '-mfloat-abi=softfp']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_ssse3", - srcs = ['wrappers/requantization/gemmlowp-ssse3.c', 'wrappers/requantization/precise-ssse3.c', 'wrappers/requantization/q31-ssse3.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = [], - compiler_flags = ['-O3', '-ffast-math', '-Wno-error=unused-variable', '-Wno-shadow', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['86', ['-mssse3', '-mno-sse4']], ['osmeta', ['-mosmeta-no-restrict-sse']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_psimd", - srcs = ['src/requantization/fp32-psimd.c', 'src/requantization/precise-psimd.c', 'src/sgemm/6x8-psimd.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv', '//third_party:psimd'], - exported_deps = [], - compiler_flags = ['-O3', '-ffast-math', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['armv7', ['-mfpu=neon']], ['^android-armv7$', ['-marm', '-mfloat-abi=softfp']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_scalar", - srcs = ['src/requantization/fp32-scalar.c', 'src/requantization/gemmlowp-scalar.c', 'src/requantization/precise-scalar.c', 'src/requantization/q31-scalar.c', 'src/u8lut32norm/scalar.c', 'src/x8lut/scalar.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = [], - compiler_flags = ['-O2', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_asm", - srcs = ['wrappers/dummy.c', 'wrappers/hgemm/8x8-aarch32-neonfp16arith.S', 'wrappers/q8conv/4x8-aarch32-neon.S', 'wrappers/q8dwconv/up8x9-aarch32-neon.S', 'wrappers/q8dwconv/up8x9-aarch32-neon-per-channel.S', 'wrappers/q8gemm/4x8-aarch32-neon.S', 'wrappers/q8gemm/4x8-dq-aarch32-neon.S', 'wrappers/q8gemm/4x8c2-xzp-aarch32-neon.S', 'wrappers/q8gemm_sparse/4x4-packA-aarch32-neon.S', 'wrappers/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S', 'wrappers/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S', 'wrappers/q8gemm_sparse/8x4-packA-aarch64-neon.S', 'wrappers/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S', 'wrappers/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S', 'wrappers/q8conv/8x8-aarch64-neon.S', 'wrappers/q8gemm/8x8-aarch64-neon.S', 'wrappers/q8gemm/8x8-dq-aarch64-neon.S'], - deps = [], - exported_deps = [], - compiler_flags = ['-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['^iphoneos-armv7$', ['-mfpu=neon-vfpv4']], ['osmeta', ['-mfpu=neon-vfpv4']]], - platform_preprocessor_flags = [['android', ['-D__ELF__=1']], ['tizen', ['-D__ELF__=1']], ['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_sse41", - srcs = ['wrappers/requantization/gemmlowp-sse4.c', 'wrappers/requantization/precise-sse4.c', 'wrappers/requantization/q31-sse4.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = [], - compiler_flags = ['-O3', '-ffast-math', '-Wno-error=unused-variable', '-Wno-shadow', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['86', ['-msse4.1', '-mno-sse4.2']], ['osmeta', ['-mosmeta-no-restrict-sse']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_neon", - srcs = ['wrappers/q8avgpool/mp8x9p8q-neon.c', 'wrappers/q8avgpool/up8x9-neon.c', 'wrappers/q8avgpool/up8xm-neon.c', 'wrappers/q8conv/4x8-neon.c', 'wrappers/q8conv/8x8-neon.c', 'wrappers/q8dwconv/mp8x25-neon.c', 'wrappers/q8dwconv/mp8x25-neon-per-channel.c', 'wrappers/q8dwconv/mp8x27-neon.c', 'wrappers/q8dwconv/up8x9-neon.c', 'wrappers/q8dwconv/up8x9-neon-per-channel.c', 'wrappers/q8gavgpool/mp8x7p7q-neon.c', 'wrappers/q8gavgpool/up8x7-neon.c', 'wrappers/q8gavgpool/up8xm-neon.c', 'wrappers/q8gemm/4x-sumrows-neon.c', 'wrappers/q8gemm/4x8-dq-neon.c', 'wrappers/q8gemm/4x8-neon.c', 'wrappers/q8gemm/4x8c2-xzp-neon.c', 'wrappers/q8gemm/6x4-neon.c', 'wrappers/q8gemm/8x8-neon.c', 'wrappers/q8vadd/neon.c', 'wrappers/requantization/fp32-neon.c', 'wrappers/requantization/gemmlowp-neon.c', 'wrappers/requantization/precise-neon.c', 'wrappers/requantization/q31-neon.c', 'wrappers/sgemm/5x8-neon.c', 'wrappers/sgemm/6x8-neon.c', 'wrappers/u8clamp/neon.c', 'wrappers/u8maxpool/16x9p8q-neon.c', 'wrappers/u8maxpool/sub16-neon.c', 'wrappers/u8rmax/neon.c', 'wrappers/x8zip/x2-neon.c', 'wrappers/x8zip/x3-neon.c', 'wrappers/x8zip/x4-neon.c', 'wrappers/x8zip/xm-neon.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = [], - compiler_flags = ['-O3', '-ffast-math', '-Wno-error=unused-variable', '-Wno-shadow', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['armv7', ['-mfpu=neon']], ['^android-armv7$', ['-marm', '-mfloat-abi=softfp']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "ukernels_sse2", - srcs = ['wrappers/q8avgpool/mp8x9p8q-sse2.c', 'wrappers/q8avgpool/up8x9-sse2.c', 'wrappers/q8avgpool/up8xm-sse2.c', 'wrappers/q8conv/4x4c2-sse2.c', 'wrappers/q8dwconv/mp8x25-sse2.c', 'wrappers/q8dwconv/mp8x25-sse2-per-channel.c', 'wrappers/q8dwconv/mp8x27-sse2.c', 'wrappers/q8dwconv/up8x9-sse2.c', 'wrappers/q8dwconv/up8x9-sse2-per-channel.c', 'wrappers/q8gavgpool/mp8x7p7q-sse2.c', 'wrappers/q8gavgpool/up8x7-sse2.c', 'wrappers/q8gavgpool/up8xm-sse2.c', 'wrappers/q8gemm/2x4c8-sse2.c', 'wrappers/q8gemm/4x4c2-dq-sse2.c', 'wrappers/q8gemm/4x4c2-sse2.c', 'wrappers/q8gemm_sparse/8x4c1x4-packed-sse2.c', 'wrappers/q8vadd/sse2.c', 'wrappers/requantization/fp32-sse2.c', 'wrappers/requantization/gemmlowp-sse2.c', 'wrappers/requantization/precise-sse2.c', 'wrappers/requantization/q31-sse2.c', 'wrappers/u8clamp/sse2.c', 'wrappers/u8maxpool/16x9p8q-sse2.c', 'wrappers/u8maxpool/sub16-sse2.c', 'wrappers/u8rmax/sse2.c', 'wrappers/x8zip/x2-sse2.c', 'wrappers/x8zip/x3-sse2.c', 'wrappers/x8zip/x4-sse2.c', 'wrappers/x8zip/xm-sse2.c'], - deps = [':qnnp_interface', '//third_party:cpuinfo', '//third_party:FP16', '//third_party:FXdiv'], - exported_deps = [], - compiler_flags = ['-O3', '-ffast-math', '-Wno-error=unused-variable', '-Wno-shadow', '-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_compiler_flags = [['86', ['-msse2', '-mno-sse3']]], - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) - - -cxx_library( - name = "qnnp_interface", - srcs = [], - deps = ['//third_party:pthreadpool_header'], - exported_deps = [], - compiler_flags = ['-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION'], - preferred_linkage = "static", - header_namespace = "", - headers = subdir_glob([("src", "**/*.c"), ("src", "q8gemm_sparse/*.h"), ("src", "qnnpack/*.h"), ("src", "requantization/*.h")]), - link_whole = False, - platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], - visibility = ['PUBLIC'], -) +define_qnnpack(third_party) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl new file mode 100644 index 0000000000000..5c1c316678e10 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl @@ -0,0 +1,647 @@ +load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") +load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") +load("//tools/build_defs:glob_defs.bzl", "subdir_glob") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX") + +# Shared by internal and OSS BUCK +def define_qnnpack(third_party, labels = []): + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_scalar", + srcs = [ + "src/requantization/fp32-scalar.c", + "src/requantization/gemmlowp-scalar.c", + "src/requantization/precise-scalar.c", + "src/requantization/q31-scalar.c", + "src/u8lut32norm/scalar.c", + "src/x8lut/scalar.c", + ], + headers = subdir_glob([ + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_sse2", + srcs = [ + "wrappers/q8avgpool/mp8x9p8q-sse2.c", + "wrappers/q8avgpool/up8x9-sse2.c", + "wrappers/q8avgpool/up8xm-sse2.c", + "wrappers/q8conv/4x4c2-sse2.c", + "wrappers/q8dwconv/mp8x25-sse2.c", + "wrappers/q8dwconv/mp8x25-sse2-per-channel.c", + "wrappers/q8dwconv/mp8x27-sse2.c", + "wrappers/q8dwconv/up8x9-sse2.c", + "wrappers/q8dwconv/up8x9-sse2-per-channel.c", + "wrappers/q8gavgpool/mp8x7p7q-sse2.c", + "wrappers/q8gavgpool/up8x7-sse2.c", + "wrappers/q8gavgpool/up8xm-sse2.c", + "wrappers/q8gemm/2x4c8-sse2.c", + "wrappers/q8gemm/4x4c2-dq-sse2.c", + "wrappers/q8gemm/4x4c2-sse2.c", + "wrappers/q8gemm_sparse/8x4c1x4-packed-sse2.c", + "wrappers/q8vadd/sse2.c", + "wrappers/requantization/fp32-sse2.c", + "wrappers/requantization/gemmlowp-sse2.c", + "wrappers/requantization/precise-sse2.c", + "wrappers/requantization/q31-sse2.c", + "wrappers/u8clamp/sse2.c", + "wrappers/u8maxpool/16x9p8q-sse2.c", + "wrappers/u8maxpool/sub16-sse2.c", + "wrappers/u8rmax/sse2.c", + "wrappers/x8zip/x2-sse2.c", + "wrappers/x8zip/x3-sse2.c", + "wrappers/x8zip/x4-sse2.c", + "wrappers/x8zip/xm-sse2.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "q8gemm_sparse/*.h"), + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O3", + "-ffast-math", + "-Wno-error=unused-variable", + "-Wno-shadow", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + "86", + [ + "-msse2", + "-mno-sse3", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_ssse3", + srcs = [ + "wrappers/requantization/gemmlowp-ssse3.c", + "wrappers/requantization/precise-ssse3.c", + "wrappers/requantization/q31-ssse3.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O3", + "-ffast-math", + "-Wno-error=unused-variable", + "-Wno-shadow", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + "86", + [ + "-mssse3", + "-mno-sse4", + ], + ), + ( + # By default, osmeta compiler silently ignores -msseXX flags. + # This flag disables this behavior. + "osmeta", + [ + "-mosmeta-no-restrict-sse", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_sse41", + srcs = [ + "wrappers/requantization/gemmlowp-sse4.c", + "wrappers/requantization/precise-sse4.c", + "wrappers/requantization/q31-sse4.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O3", + "-ffast-math", + "-Wno-error=unused-variable", + "-Wno-shadow", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + "86", + [ + "-msse4.1", + "-mno-sse4.2", + ], + ), + ( + # By default, osmeta compiler silently ignores -msseXX flags. + # This flag disables this behavior. + "osmeta", + [ + "-mosmeta-no-restrict-sse", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "qnnp_interface", + headers = subdir_glob( + [ + ("include", "*.h"), + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ], + ), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + force_static = True, + labels = labels, + visibility = ["PUBLIC"], + deps = [ + third_party("pthreadpool_header"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "pytorch_qnnpack", + srcs = [ + "src/add.c", + "src/average-pooling.c", + "src/channel-shuffle.c", + "src/clamp.c", + "src/conv-prepack.cc", + "src/conv-run.cc", + "src/convolution.c", + "src/deconv-run.cc", + "src/deconvolution.c", + "src/fc-dynamic-run.cc", + "src/fc-prepack.cc", + "src/fc-run.cc", + "src/fully-connected.c", + "src/fully-connected-sparse.c", + "src/global-average-pooling.c", + "src/hardsigmoid.c", + "src/hardswish.c", + "src/indirection.c", + "src/init.c", + "src/leaky-relu.c", + "src/max-pooling.c", + "src/operator-delete.c", + "src/operator-run.c", + "src/pack_block_sparse.cc", + "src/sigmoid.c", + "src/softargmax.c", + "src/tanh.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "**/*.h"), + ("src", "qnnpack/*.h"), + ("include", "**/*.h"), + ]), + header_namespace = "", + exported_headers = subdir_glob([ + ("src", "qnnpack/*.h"), + ("include", "*.h"), + ]), + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = [ + "supermodule:android/default/pytorch", + "supermodule:ios/default/public.pytorch", + ], + platform_compiler_flags = [ + ( + "armv7", + [ + "-mfpu=neon", + ], + ), + ( + "^android-armv7$", + [ + "-marm", + "-mfloat-abi=softfp", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + ":ukernels_asm", + ":ukernels_neon", + ":ukernels_psimd", + ":ukernels_scalar", + ":ukernels_sse2", + ":ukernels_sse41", + ":ukernels_ssse3", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + exported_deps = [ + third_party("cpuinfo"), + ], + ) + + # Only ukernels implemented in C with ARM NEON intrinsics + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_neon", + srcs = [ + "wrappers/q8avgpool/mp8x9p8q-neon.c", + "wrappers/q8avgpool/up8x9-neon.c", + "wrappers/q8avgpool/up8xm-neon.c", + "wrappers/q8conv/4x8-neon.c", + "wrappers/q8conv/8x8-neon.c", + "wrappers/q8dwconv/mp8x25-neon.c", + "wrappers/q8dwconv/mp8x25-neon-per-channel.c", + "wrappers/q8dwconv/mp8x27-neon.c", + "wrappers/q8dwconv/up8x9-neon.c", + "wrappers/q8dwconv/up8x9-neon-per-channel.c", + "wrappers/q8gavgpool/mp8x7p7q-neon.c", + "wrappers/q8gavgpool/up8x7-neon.c", + "wrappers/q8gavgpool/up8xm-neon.c", + "wrappers/q8gemm/4x-sumrows-neon.c", + "wrappers/q8gemm/4x8-dq-neon.c", + "wrappers/q8gemm/4x8-neon.c", + "wrappers/q8gemm/4x8c2-xzp-neon.c", + "wrappers/q8gemm/6x4-neon.c", + "wrappers/q8gemm/8x8-neon.c", + "wrappers/q8vadd/neon.c", + "wrappers/requantization/fp32-neon.c", + "wrappers/requantization/gemmlowp-neon.c", + "wrappers/requantization/precise-neon.c", + "wrappers/requantization/q31-neon.c", + "wrappers/sgemm/5x8-neon.c", + "wrappers/sgemm/6x8-neon.c", + "wrappers/u8clamp/neon.c", + "wrappers/u8maxpool/16x9p8q-neon.c", + "wrappers/u8maxpool/sub16-neon.c", + "wrappers/u8rmax/neon.c", + "wrappers/x8zip/x2-neon.c", + "wrappers/x8zip/x3-neon.c", + "wrappers/x8zip/x4-neon.c", + "wrappers/x8zip/xm-neon.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "qnnpack/*.h"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O3", + "-ffast-math", + "-Wno-error=unused-variable", + "-Wno-shadow", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + "armv7", + [ + "-mfpu=neon", + ], + ), + ( + "^android-armv7$", + [ + "-marm", + "-mfloat-abi=softfp", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + ], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_asm", + srcs = [ + # Dummy empty source file to work around link error on x86-64 Android + # when static library contains no symbols. + "wrappers/dummy.c", + # AArch32 ukernels + "wrappers/hgemm/8x8-aarch32-neonfp16arith.S", + "wrappers/q8conv/4x8-aarch32-neon.S", + "wrappers/q8dwconv/up8x9-aarch32-neon.S", + "wrappers/q8dwconv/up8x9-aarch32-neon-per-channel.S", + "wrappers/q8gemm/4x8-aarch32-neon.S", + "wrappers/q8gemm/4x8-dq-aarch32-neon.S", + "wrappers/q8gemm/4x8c2-xzp-aarch32-neon.S", + "wrappers/q8gemm_sparse/4x4-packA-aarch32-neon.S", + "wrappers/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S", + "wrappers/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S", + "wrappers/q8gemm_sparse/8x4-packA-aarch64-neon.S", + "wrappers/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S", + "wrappers/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S", + # AArch64 ukernels + "wrappers/q8conv/8x8-aarch64-neon.S", + "wrappers/q8gemm/8x8-aarch64-neon.S", + "wrappers/q8gemm/8x8-dq-aarch64-neon.S", + ], + headers = subdir_glob([ + ("src", "qnnpack/assembly.h"), + ("src", "**/*.S"), + ("src", "requantization/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + # iOS assembler doesn't let us specify ISA in the assembly file, + # so this must be set to the highest version of ISA of any of the + # assembly functions + "^iphoneos-armv7$", + [ + "-mfpu=neon-vfpv4", + ], + ), + ( + "osmeta", + [ + "-mfpu=neon-vfpv4", + ], + ), + ], + platform_preprocessor_flags = [ + ( + "android", + [ + # Workaround for osmeta-android, which builds for ELF, but hides it + "-D__ELF__=1", + ], + ), + ( + "tizen", + [ + # Workaround for osmeta-tizen, which builds for ELF, but hides it + "-D__ELF__=1", + ], + ), + ], + visibility = ["PUBLIC"], + ) + + fb_xplat_cxx_library( + # @autodeps-skip + name = "ukernels_psimd", + srcs = [ + "src/requantization/fp32-psimd.c", + "src/requantization/precise-psimd.c", + "src/sgemm/6x8-psimd.c", + ], + headers = subdir_glob([ + ("src", "**/*.c"), + ("src", "qnnpack/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O3", + "-ffast-math", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + fbobjc_preprocessor_flags = [ + "-DQNNP_PRIVATE=", + "-DQNNP_INTERNAL=", + ], + force_static = True, + labels = labels, + platform_compiler_flags = [ + ( + "armv7", + [ + "-mfpu=neon", + ], + ), + ( + "^android-armv7$", + [ + "-marm", + "-mfloat-abi=softfp", + ], + ), + ], + visibility = ["PUBLIC"], + deps = [ + ":qnnp_interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + third_party("psimd"), + ], + ) + + fb_xplat_cxx_test( + # @autodeps-skip + fbandroid_use_instrumentation_test = True, + contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"], + platforms = (CXX, APPLE, ANDROID), + apple_sdks = (IOS, MACOSX), + name = "pytorch_qnnpack_test", + srcs = [ + "test/add.cc", + "test/average-pooling.cc", + "test/channel-shuffle.cc", + "test/clamp.cc", + "test/convolution.cc", + "test/deconvolution.cc", + "test/fully-connected.cc", + "test/fully-connected-sparse.cc", + "test/global-average-pooling.cc", + "test/hardsigmoid.cc", + "test/hardswish.cc", + "test/leaky-relu.cc", + "test/max-pooling.cc", + "test/q8avgpool.cc", + "test/q8conv.cc", + "test/q8dwconv.cc", + "test/q8gavgpool.cc", + "test/q8gemm_sparse.cc", + "test/q8vadd.cc", + "test/requantization.cc", + "test/sgemm.cc", + "test/sigmoid.cc", + "test/softargmax.cc", + "test/tanh.cc", + "test/u8clamp.cc", + "test/u8lut32norm.cc", + "test/u8maxpool.cc", + "test/u8rmax.cc", + "test/x8lut.cc", + "test/x8zip.cc", + ], + headers = { + "add-operator-tester.h": "test/add-operator-tester.h", + "average-pooling-operator-tester.h": "test/average-pooling-operator-tester.h", + "avgpool-microkernel-tester.h": "test/avgpool-microkernel-tester.h", + "channel-shuffle-operator-tester.h": "test/channel-shuffle-operator-tester.h", + "clamp-microkernel-tester.h": "test/clamp-microkernel-tester.h", + "clamp-operator-tester.h": "test/clamp-operator-tester.h", + "convolution-operator-tester.h": "test/convolution-operator-tester.h", + "deconvolution-operator-tester.h": "test/deconvolution-operator-tester.h", + "dwconv-microkernel-tester.h": "test/dwconv-microkernel-tester.h", + "fully-connected-operator-tester.h": "test/fully-connected-operator-tester.h", + "fully-connected-sparse-operator-tester.h": "test/fully-connected-sparse-operator-tester.h", + "gavgpool-microkernel-tester.h": "test/gavgpool-microkernel-tester.h", + "gemm-block-sparse-microkernel-tester.h": "test/gemm-block-sparse-microkernel-tester.h", + "gemm-microkernel-tester.h": "test/gemm-microkernel-tester.h", + "global-average-pooling-operator-tester.h": "test/global-average-pooling-operator-tester.h", + "hardsigmoid-operator-tester.h": "test/hardsigmoid-operator-tester.h", + "hardswish-operator-tester.h": "test/hardswish-operator-tester.h", + "leaky-relu-operator-tester.h": "test/leaky-relu-operator-tester.h", + "lut-microkernel-tester.h": "test/lut-microkernel-tester.h", + "lut-norm-microkernel-tester.h": "test/lut-norm-microkernel-tester.h", + "max-pooling-operator-tester.h": "test/max-pooling-operator-tester.h", + "maxpool-microkernel-tester.h": "test/maxpool-microkernel-tester.h", + "requantization-tester.h": "test/requantization-tester.h", + "rmax-microkernel-tester.h": "test/rmax-microkernel-tester.h", + "sigmoid-operator-tester.h": "test/sigmoid-operator-tester.h", + "softargmax-operator-tester.h": "test/softargmax-operator-tester.h", + "tanh-operator-tester.h": "test/tanh-operator-tester.h", + "test_utils.h": "test/test_utils.h", + "vadd-microkernel-tester.h": "test/vadd-microkernel-tester.h", + "zip-microkernel-tester.h": "test/zip-microkernel-tester.h", + }, + header_namespace = "", + compiler_flags = [ + "-fexceptions", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + ], + platform_linker_flags = [ + ( + "^linux.*$", + [ + "-Wl,--no-as-needed", + "-ldl", + "-pthread", + ], + ), + ], + env = { + # These tests fail in sandcastle since they leak memory. Disable LeakSanitizer. + "ASAN_OPTIONS": "detect_leaks=0", + }, + deps = [ + ":pytorch_qnnpack", + third_party("cpuinfo"), + third_party("FP16"), + third_party("pthreadpool"), + ], + ) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h index 62fdef2cdf9b2..bfaa19e564b42 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h @@ -7,10 +7,10 @@ */ #pragma once +#include #include #include #include -#include #ifndef _WIN32 #include @@ -20,19 +20,75 @@ namespace qnnpack { -typedef struct BCSRMatrix { +template +struct OwnedOrBorrowedVector { + using VECTOR_T = #ifndef _WIN32 - std::vector> col_indices; - std::vector> row_values; - std::vector> values; + std::vector>; #else - std::vector col_indices; - std::vector row_values; - std::vector values; + std::vector; #endif + + // Only one of owned_vec_data_ or borrowed_tuple_data_ will be meaningfully + // populated. + // A union could potentially be used here to reduce memory usage. + // std::variant is not used here because it causes internal build errors + // due to incompatibility. + VECTOR_T owned_vec_data_; + std::tuple borrowed_tuple_data_; + bool owned; + + VECTOR_T& vector() { + assert(owned); + return owned_vec_data_; + } + + uint32_t size() const { + if (owned) { + return owned_vec_data_.size(); + } else { + return std::get<1>(borrowed_tuple_data_); + } + } + + const T* data() const { + if (owned) { + return owned_vec_data_.data(); + } else { + return std::get<0>(borrowed_tuple_data_); + } + } + + const T& operator[](int i) const { + return data()[i]; + } + + OwnedOrBorrowedVector() : owned(true) {} + + OwnedOrBorrowedVector(T* data_ptr, const uint32_t size) + : borrowed_tuple_data_(std::tuple(data_ptr, size)), + owned(false) {} +}; + +typedef struct BCSRMatrix { + OwnedOrBorrowedVector col_indices; + OwnedOrBorrowedVector row_values; + OwnedOrBorrowedVector values; uint32_t col_block_size; // input features block size uint32_t row_block_size; // output features block size void print() const; + /* + * Unpack from BCSR to Dense + * - Each value and zero point converted to int8_t by subtracting 128 + * - num_rows and num_cols are dimensions of dense weight tensor + * - dst should be able to hold num_rows * num_cols elements + * - zero_points should hold num_rows zero points + */ + void unpack( + int8_t* dst, + const int64_t num_rows, + const int64_t num_cols, + const uint8_t* zero_points) const; } BCSRMatrix; std::unique_ptr generateBlockCSRMatrix( @@ -43,4 +99,14 @@ std::unique_ptr generateBlockCSRMatrix( const uint32_t col_block_size, const uint8_t* zero_points); +std::unique_ptr generateBlockCSRMatrix( + uint32_t* col_indices, + uint32_t* row_values, + uint8_t* values, + const int64_t col_indices_size, + const int64_t row_values_size, + const int64_t values_size, + const int64_t row_block_size, + const int64_t col_block_size); + } // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc index 6a6134023bfc8..c837f55cda855 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc @@ -5,8 +5,11 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include #include +#include #include +#include #include @@ -21,13 +24,17 @@ std::unique_ptr generateBlockCSRMatrix( assert(K > 0); std::unique_ptr bcsr_mat_ptr = std::make_unique(); auto& bcsr_mat = *bcsr_mat_ptr; + auto& row_values = bcsr_mat.row_values.vector(); + auto& col_indices = bcsr_mat.col_indices.vector(); + auto& values = bcsr_mat.values.vector(); + const uint32_t num_row_blocks = (N + row_block_size - 1) / row_block_size; // K must be > 0 const uint32_t num_col_blocks = (K + col_block_size - 1) / col_block_size; - bcsr_mat.row_values.reserve(num_row_blocks); + row_values.reserve(num_row_blocks); uint32_t num_nnz_blocks{0}; - bcsr_mat.row_values.push_back(num_nnz_blocks); + row_values.push_back(num_nnz_blocks); for (uint32_t i = 0; i < num_row_blocks; ++i) { for (uint32_t j = 0; j < num_col_blocks; ++j) { bool block_zero{true}; @@ -49,14 +56,14 @@ std::unique_ptr generateBlockCSRMatrix( } block_scanned: if (!block_zero) { - bcsr_mat.col_indices.push_back(j); + col_indices.push_back(j); num_nnz_blocks++; for (uint32_t ib = 0; ib < row_block_size; ++ib) { uint32_t row_index = i * row_block_size + ib; if PYTORCH_QNNP_UNLIKELY(row_index >= N) { for (; row_index < (num_row_blocks * row_block_size); row_index++) { for (uint32_t jb = 0; jb < col_block_size; ++jb) { - bcsr_mat.values.push_back(zero_points[N-1]); + values.push_back(zero_points[N-1]); } } break; @@ -64,39 +71,100 @@ std::unique_ptr generateBlockCSRMatrix( for (uint32_t jb = 0; jb < col_block_size; ++jb) { uint32_t col_index = j * col_block_size + jb; if PYTORCH_QNNP_UNLIKELY(col_index >= K) { - bcsr_mat.values.push_back(zero_points[row_index]); + values.push_back(zero_points[row_index]); } else { uint8_t val = *(a + row_index * K + col_index); - bcsr_mat.values.push_back(val); + values.push_back(val); } } } } } - bcsr_mat.row_values.push_back(num_nnz_blocks); + row_values.push_back(num_nnz_blocks); } bcsr_mat.row_block_size = row_block_size; bcsr_mat.col_block_size = col_block_size; return bcsr_mat_ptr; } +std::unique_ptr generateBlockCSRMatrix( + uint32_t* col_indices, + uint32_t* row_values, + uint8_t* values, + const int64_t col_indices_size, + const int64_t row_values_size, + const int64_t values_size, + const int64_t row_block_size, + const int64_t col_block_size) { + std::unique_ptr bcsr_mat_ptr = std::make_unique(); + BCSRMatrix& bcsr_mat = *bcsr_mat_ptr; + bcsr_mat.col_indices = + OwnedOrBorrowedVector(col_indices, col_indices_size); + bcsr_mat.row_values = + OwnedOrBorrowedVector(row_values, row_values_size); + bcsr_mat.values = OwnedOrBorrowedVector(values, values_size); + bcsr_mat.row_block_size = row_block_size; + bcsr_mat.col_block_size = col_block_size; + return bcsr_mat_ptr; +} + void BCSRMatrix::print() const { std::cout << "row block size:" << row_block_size << std::endl; std::cout << "col block size:" << col_block_size << std::endl; std::cout << "row ptr\n"; - for (const auto& t : row_values) { - std::cout << t << ", "; + for (int i = 0; i < row_values.size(); i++) { + std::cout << row_values[i] << ", "; } std::cout << std::endl; std::cout << "col indices\n"; - for (const auto& t : col_indices) { - std::cout << t << ", "; + for (int i = 0; i < col_indices.size(); i++) { + std::cout << col_indices[i] << ", "; } std::cout << std::endl; std::cout << "Actual values\n"; - for (const auto& t : values) { - std::cout << (uint32_t)t << ", "; + for (int i = 0; i < values.size(); i++) { + std::cout << (uint32_t)values[i] << ", "; } std::cout << std::endl; } + +void BCSRMatrix::unpack( + int8_t* dst, + const int64_t num_rows, + const int64_t num_cols, + const uint8_t* zero_points) const { + for (int64_t i = 0; i < num_rows; i++) { + memset( + dst + i * num_cols, + static_cast(static_cast(zero_points[i]) - 128), + num_cols * sizeof(int8_t)); + } + + const int64_t num_block_rows = static_cast(row_values.size()) - 1; + const int64_t block_size = (int64_t)row_block_size * col_block_size; + int64_t weight_values_num = 0; + for (int64_t block_row_num = 0; block_row_num < num_block_rows; + block_row_num++) { + const int64_t num_blocks_in_current_block_row = + row_values[block_row_num + 1] - row_values[block_row_num]; + for (int64_t k = 0; k < num_blocks_in_current_block_row; + k++) { // iterate over each block in the row + const int64_t block_start_row_num = block_row_num * row_block_size; + const int64_t block_start_col_num = + (int64_t)(col_indices[weight_values_num / block_size]) * + col_block_size; + for (int64_t l = 0; l < block_size; + l++) { // iterate over each value in the block + const int64_t row_num = block_start_row_num + l / col_block_size; + const int64_t col_num = block_start_col_num + l % col_block_size; + if (row_num < num_rows && col_num < num_cols) { + dst[row_num * num_cols + col_num] = static_cast( + static_cast(values[weight_values_num]) - 128); + } + weight_values_num++; + } + } + } +} + } // namsepace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h index 2ce93f37b23df..b8e6ce73853fb 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h @@ -356,7 +356,7 @@ class FullyConnectedOperatorTester { } for (size_t i = 0; i < batchSize(); i++) { for (size_t c = 0; c < outputChannels(); c++) { - ASSERT_EQ( + ASSERT_FLOAT_EQ( output_dynamic[i * outputChannels() + c], ((float)accumulators[i * outputChannels() + c] * requantization_scales[c]) + float(bias[c])) diff --git a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp index dd93d1a5fb461..ddfbad8917f74 100644 --- a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp @@ -13,6 +13,7 @@ namespace at { namespace native { DEFINE_DISPATCH(quantized_normalize_stub); +DEFINE_DISPATCH(quantized_groupnorm_nhwc_stub); Tensor quantized_layer_norm_impl( const Tensor& input, @@ -55,8 +56,11 @@ Tensor quantized_group_norm_impl( double eps, double output_scale, int64_t output_zero_point) { + bool is_channels_last = qx.is_contiguous(c10::MemoryFormat::ChannelsLast); + auto mem_layout = is_channels_last ? c10::MemoryFormat::ChannelsLast : + c10::MemoryFormat::Contiguous; - const auto& qx_contig = qx.contiguous(); + const auto& qx_contig = qx.contiguous(mem_layout); const auto& weight_contig = weight.contiguous(); const auto& bias_contig = bias.contiguous(); @@ -87,8 +91,13 @@ Tensor quantized_group_norm_impl( if (M > 0) { bool affine_per_channel = true; - quantized_normalize_stub(kCPU, qx_contig, weight_contig, bias_contig, - affine_per_channel, num_channels, num_groups, M, N, eps, &Y); + if (is_channels_last) { + quantized_groupnorm_nhwc_stub(kCPU, qx_contig, weight_contig, bias_contig, + affine_per_channel, num_channels, num_groups, M, N, eps, &Y); + } else { + quantized_normalize_stub(kCPU, qx_contig, weight_contig, bias_contig, + affine_per_channel, num_channels, num_groups, M, N, eps, &Y); + } } return Y; } diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index ad071d63dcdb0..e4ca887fb674b 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -17,6 +17,7 @@ namespace native { DEFINE_DISPATCH(qrelu_stub); DEFINE_DISPATCH(qrelu_leaky_stub); +DEFINE_DISPATCH(qprelu_stub); #ifdef USE_PYTORCH_QNNPACK Tensor qnnpack_relu(Tensor input) { @@ -134,6 +135,32 @@ Tensor& leaky_relu_quantized_cpu_(Tensor& self, const Scalar& negval) { return self; } +Tensor prelu_quantized_cpu_impl(const Tensor& self, const Tensor& weight, + double output_scale, int64_t output_zero_point) { + auto ndim = self.dim(); + // for ndim < 1 or > 5, go to reference path + if (ndim > 5 || ndim < 1) { + auto x = self.dequantize(); + auto y = at::prelu(x, weight); + return at::quantize_per_tensor(y, output_scale, output_zero_point, c10::kQUInt8); + } + + auto qy = at::_empty_affine_quantized(self.sizes(), + at::device(kCPU) + .dtype(self.scalar_type()), + output_scale, + output_zero_point, + self.suggest_memory_format()); + + qprelu_stub(self.device().type(), qy, self, weight); + + return qy; +} + +Tensor prelu_quantized_cpu(const Tensor& self, const Tensor& weight) { + return prelu_quantized_cpu_impl(self, weight, self.q_scale(), self.q_zero_point()); +} + namespace { Tensor quantized_relu6(const Tensor& qx) { Tensor qy; @@ -175,9 +202,17 @@ class QLeakyRelu final { } }; +class QPRelu final { + public: + static Tensor run(Tensor self, const Tensor& weight, double output_scale, int64_t output_zero_point) { + return prelu_quantized_cpu_impl(self, weight, output_scale, output_zero_point); + } +}; + TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::relu6"), TORCH_FN(QRelu6::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::leaky_relu"), TORCH_FN(QLeakyRelu::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::prelu"), TORCH_FN(QPRelu::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp b/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp index 4f3bb54c22151..f29f548fc758c 100644 --- a/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp @@ -94,7 +94,7 @@ Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) { TORCH_CHECK( status == pytorch_qnnp_status_success, "failed to create QNNPACK Softmax operator"); - CHECK_NOTNULL(softargmax); + TORCH_CHECK_NOTNULL(softargmax); status = pytorch_qnnp_setup_softargmax_nc_q8( softargmax, batch_size, input, input_stride, output, output_stride); diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp b/aten/src/ATen/native/quantized/cudnn/Conv.cpp index 08dc1ec519288..ca1c3e146684d 100644 --- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp @@ -125,7 +125,7 @@ void PackedConvWeightCudnn::apply_impl_helper(const at::Tensor& qua auto padding_vec = padding_.vec(); auto stride_vec = stride_.vec(); auto dilation_vec = dilation_.vec(); - setConvolutionParams(&key.params, input, maybe_padded_weight_, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32); + setConvolutionParams(&key.params, input, maybe_padded_weight_, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32, input.suggest_memory_format()); // operator datatype needs to be int32 for int8 convolution, but we can // set the datatype for output tensor to int32 or fp32 diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index ea5338d15871a..a6ac4b330b0f1 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -187,6 +187,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::prelu(Tensor qx, Tensor weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor")); } diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 77979f55647de..8d3a17a24ff82 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -28,11 +28,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -45,7 +47,9 @@ #include #include #include +#include #include +#include #endif namespace at { @@ -58,199 +62,170 @@ namespace { } // end anonymous namespace -void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) { +/* + Validate the arguments to sparse compressed (CSR, CSC, BSR, and BSC) + tensor factory functions. + + The CSR and BSR invariants for PyTorch are outlined in + + https://pearu.github.io/csr_tensor_invariants.html + https://pearu.github.io/bsr_tensor_invariants.html - // Layout must be Sparse Compressed + that in what follows are generalized for all sparse compressed + formats with support to batched and dense dimensions. +*/ + +void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) { + // Layout must be Sparse Compressed, 2.4 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{}); const std::string layout_name = layoutToString(layout, /*upper=*/ true); const std::string compressed_indices_name = compressedIndicesName(layout); const std::string plain_indices_name = plainIndicesName(layout); + const std::string compressed_dim_name = compressedDimName(layout); + const std::string plain_dim_name = plainDimName(layout); // Layout Invariants + // 2.1, 3.5 TORCH_CHECK( - plain_indices.layout() == kStrided && plain_indices.is_contiguous(), - "expected ", plain_indices_name, " to be a strided and contiguous tensor"); + plain_indices.layout() == kStrided && plain_indices.is_contiguous(), + "expected ", plain_indices_name, " to be a strided and contiguous tensor"); + // 2.2, 3.6 TORCH_CHECK( - compressed_indices.layout() == kStrided && compressed_indices.is_contiguous(), - "expected ", compressed_indices_name ," to be a strided and contiguous tensor"); + compressed_indices.layout() == kStrided && compressed_indices.is_contiguous(), + "expected ", compressed_indices_name ," to be a strided and contiguous tensor"); + // 2.3, partially 3.7 + // TODO: allow values be contiguous along both block dimensions when the format is BSR or BSC TORCH_CHECK( values.layout() == kStrided && values.is_contiguous(), "expected values to be a strided and contiguous tensor"); + const int base_ndim = 2; // corresponds to compressed and plain indices + const int batch_ndim = compressed_indices.dim() - 1; + const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( + layout, "validate_sparse_compressed_tensor_args", + [&] { return 0; }, [&] { return 2; }); + const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1; // Shape and Strides invariants - TORCH_CHECK( - size.size() >= 2, - "size of a batched ", layout_name, " tensor must have length >= 2, but got: ", - size.size()); - TORCH_CHECK( - compressed_indices.dim() >= 1, - compressed_indices_name, " must have dim >= 1 but got ", compressed_indices_name, ".dim() = ", - compressed_indices.dim()); - TORCH_CHECK( - plain_indices.dim() >= 1, - plain_indices_name, " must have dim >= 1 but got ", plain_indices_name, ".dim() = ", - plain_indices.dim()); - TORCH_CHECK( - values.dim() >= 1, - "values must have dim >= 1 but got values.dim() = ", - values.dim()); + // 3.2 TORCH_CHECK( - compressed_indices.dim() == plain_indices.dim(), - "number of dimensions of ", compressed_indices_name, " and ", plain_indices_name, " must be the same but got ", - compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively"); - - int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( - layout, "validate_sparse_compressed_tensor_args", - [&] { - TORCH_CHECK( - compressed_indices.dim() <= values.dim(), - "number of dimensions of indices (=", compressed_indices.dim(), - ") must be equal or less than the number of dimensions of values (=", values.dim(), ")"); - return 0; - }, - [&] { - TORCH_CHECK( - compressed_indices.dim() + 2 <= values.dim(), - "number of dimensions of indices (=", compressed_indices.dim(), - ") plus two must be equal or less than the number of dimensions of values (=", values.dim(), ")"); - return 2; - }); - int dense_ndim = values.dim() - compressed_indices.dim() - block_ndim; - TORCH_CHECK(dense_ndim == 0, "non-zero dense dimensions (=", dense_ndim, ") is not supported for ", layout, " layout"); + batch_ndim >= 0, + compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim()); - int batch_ndim = size.size() - 2 - dense_ndim; - TORCH_INTERNAL_ASSERT(block_ndim >= 0 && dense_ndim >=0 && batch_ndim >= 0); + // 3.3 + TORCH_CHECK( + compressed_indices.dim() == plain_indices.dim(), + compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ", + compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively"); + // 3.4 + TORCH_CHECK( + dense_ndim >= 0, + "values must have dimensionality > sum of batch and block dimensionalities (=", + batch_ndim, " + ", block_ndim, ") but got ", values.dim()); + // 3.1 TORCH_CHECK( - static_cast(compressed_indices.dim()) == size.size() - 1 - dense_ndim, - "number of dimensions of indices must be one less than the number of dimensions of the provided size", - " (minus the number of dense dimensions) but got ", - compressed_indices.dim(), " not equal to ", size.size(), " - 1 - ", dense_ndim); + static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, + "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); // For CSR/CSC formats, we define blocksize=(1, 1) so that checking // the sparse compressed tensor invariants can be unified with the // BSR/BSC invariants. + // 3.10 DimVector blocksize{ - (block_ndim == 2 ? std::max(1, values.sizes()[values.dim() - dense_ndim - 2]) : 1), - (block_ndim == 2 ? std::max(1, values.sizes()[values.dim() - dense_ndim - 1]) : 1), + (block_ndim == 2 ? std::max(1, values.size(batch_ndim + 1)) : 1), + (block_ndim == 2 ? std::max(1, values.size(batch_ndim + 2)) : 1), }; TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0); - int64_t numel_per_block = blocksize[0] * blocksize[1]; - int compressed_dim = compressedDimension(layout, size, dense_ndim); - int plain_dim = plainDimension(layout, size, dense_ndim); - - // All batch sizes must be the same - DimVector batch_size = DimVector(size.slice(0, batch_ndim)); - DimVector compressed_indices_batch_size = DimVector(compressed_indices.sizes().slice(0, compressed_indices.dim() - 1)); - DimVector plain_indices_batch_size = DimVector(plain_indices.sizes().slice(0, plain_indices.dim() - 1)); - DimVector values_batch_size = DimVector(values.sizes().slice(0, values.dim() - 1 - block_ndim - dense_ndim)); + // All batch sizes must be the same and consistent with tensor batchsize, 3.1, 3.8, 3.9, 3.10 + DimVector batchsize = DimVector(size.slice(0, batch_ndim)); + DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim)); + DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim)); + DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim)); + const int values_nnz = (values.numel() ? values.size(batch_ndim) : 0); + DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim)); + DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim)); TORCH_CHECK( - batch_size == compressed_indices_batch_size && - batch_size == plain_indices_batch_size && - batch_size == values_batch_size, - "all batch dimensions of the provided size (", batch_size, "), indices (", - compressed_indices_batch_size,", ", plain_indices_batch_size, "), and values (", - values_batch_size,") must be the same."); - - // A tensor constitutes of full blocks - for (int i=0; i= 1` - if (block_ndim == 2) { - TORCH_CHECK( - compressed_indices.size(-1) == (size[compressed_dim] / blocksize[compressed_dim - batch_ndim] + 1), - compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim), - "]/blocksize[", compressed_dim - batch_ndim, "] + 1 (that is ", - size[compressed_dim] / blocksize[compressed_dim - batch_ndim] + 1, "), but got: ", compressed_indices.size(-1)); - TORCH_CHECK( - plain_indices.numel() * numel_per_block == values.numel(), - "number of ", plain_indices_name, " elements must be the same as the number of blocks in values, but got ", - plain_indices_name, ".numel() * numel_per_block: ", plain_indices.numel() * numel_per_block, - ", values.numel(): ", values.numel(),", numel_per_block: ", numel_per_block); - } else { - TORCH_CHECK( - compressed_indices.size(-1) == (size[compressed_dim] + 1), - compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim), - "] + 1 (that is ", - size[compressed_dim] + 1, "), but got: ", compressed_indices.size(-1)); - TORCH_CHECK( - plain_indices.numel() == values.numel(), - "number of ", plain_indices_name, " elements must be the same number of elements, but got ", - plain_indices_name, ".numel(): ", plain_indices.numel(), - ", values.numel(): ", values.numel()); + // A tensor constitutes of full blocks, 3.1 + for (int i=0; i(); - auto batch_stride = compressed_indices_cpu.dim() >= 2 ? compressed_indices_cpu.stride(-2) : 0; - auto compressed_dims = (block_ndim == 0 ? size[compressed_dim] : size[compressed_dim] / blocksize[compressed_dim - batch_ndim]); - for (const auto batch_id : c10::irange(batchCount(compressed_indices_cpu))) { - TORCH_CHECK( - compressed_indices_data_ptr[batch_id*batch_stride] == 0, - "(Batch element ", batch_id, ") ", - ": 0th value of ", compressed_indices_name, " must be 0, but it is ", compressed_indices_data_ptr[batch_id*batch_stride]); - TORCH_CHECK( - compressed_indices_data_ptr[batch_id*batch_stride + compressed_indices.size(-1) - 1] == plain_indices.size(-1), - "(Batch element ", batch_id, ") ", - "last value of ", compressed_indices_name, " should be equal to the length of ", plain_indices_name, "."); - for (int i = 1; i <= compressed_dims; i++) { - TORCH_CHECK( - compressed_indices_data_ptr[batch_id*batch_stride + i - 1] <= compressed_indices_data_ptr[batch_id*batch_stride + i], - "(Batch element ", batch_id, ") ", - "at position i = ", i, ", the condition ", compressed_indices_name, "[i - 1] <= ", compressed_indices_name, "[i] fails, got ", - compressed_indices_data_ptr[batch_id*batch_stride + i - 1], " <= ", compressed_indices_data_ptr[batch_id*batch_stride + i]); - } - } - if (plain_indices.numel() > 0) { - TORCH_CHECK(0 <= plain_indices.min().item(), plain_indices_name, ".min() should be greater or equal to zero"); - TORCH_CHECK(size[plain_dim] > plain_indices.max().item(), "size[-", (size.size() - plain_dim),"] should be greater than ", plain_indices_name, ".max()"); - } - }); + if (plain_indices.numel() > 0) { + at::_validate_compressed_sparse_indices( + /*is_crow = */layout == kSparseCsr || layout == kSparseBsr, + compressed_indices, + plain_indices, + compressed_dim_size, + plain_dim_size, + values_nnz + ); + } // Device Invariants - TORCH_CHECK( - plain_indices.get_device() == compressed_indices.get_device(), - compressed_indices_name, " and ", plain_indices_name, " devices (", - compressed_indices.get_device(), - ", ", - plain_indices.get_device(), - ") must match"); - TORCH_CHECK( - compressed_indices.get_device() == values.get_device(), - "device of ", compressed_indices_name, " (", - compressed_indices.get_device(), - ") must match device of values (", - values.get_device(), - ")"); + // 4.1 TORCH_CHECK( values.device().type() == kCPU || values.device().type() == kCUDA, "device type of values (", values.device().type(), ") must be CPU or CUDA"); - + // 4.2, 4.3, 4.4 + TORCH_CHECK( + compressed_indices.get_device() == values.get_device(), + "device of ", compressed_indices_name, " (=", + compressed_indices.device(), + ") must match device of values (=", + values.device(), + ")"); + TORCH_CHECK( + compressed_indices.get_device() == plain_indices.get_device(), + "device of ", compressed_indices_name, " (=", + compressed_indices.device(), + ") must match device of ", plain_indices_name," (=", + plain_indices.device(), + ")"); } void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) { @@ -360,34 +335,54 @@ DimVector _estimate_sparse_compressed_tensor_size( const Tensor& plain_indices, const Tensor& values, Layout layout) { - int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; }); - int dense_ndim = values.dim() - compressed_indices.dim() - block_ndim; + const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; }); + const int base_ndim = 2; // corresponds to compressed and plain indices + const int batch_ndim = compressed_indices.dim() - 1; + const std::string compressed_indices_name = compressedIndicesName(layout); + const std::string plain_indices_name = plainIndicesName(layout); + TORCH_CHECK( + batch_ndim >= 0, + compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim()); + TORCH_CHECK( + compressed_indices.dim() == plain_indices.dim(), + compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ", + compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively"); + const int dense_ndim = values.dim() - batch_ndim - block_ndim - 1; + TORCH_CHECK( + dense_ndim >= 0, + "values must have dimensionality > sum of batch and block dimensionalities (=", + batch_ndim, " + ", block_ndim, ") but got ", values.dim()); DimVector blocksize{ - (block_ndim == 2 ? std::max(1, values.sizes()[values.dim() - dense_ndim - 2]) : 1), - (block_ndim == 2 ? std::max(1, values.sizes()[values.dim() - dense_ndim - 1]) : 1), + (block_ndim == 2 ? std::max(1, values.size(batch_ndim + 1)) : 1), + (block_ndim == 2 ? std::max(1, values.size(batch_ndim + 2)) : 1) }; - DimVector size = DimVector(IntArrayRef(plain_indices.sizes().data(), plain_indices.dim() - 1)); - int64_t compressed_dim = (plain_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0); - int64_t plain_dim = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size", - [&]() -> int64_t { - if (plain_indices.numel() > 0) { - return plain_indices.max().item() + 1; - } else { - return 0; - } - }); + DimVector size = DimVector(compressed_indices.sizes().slice(0, batch_ndim)); + int64_t compressed_dim_size = (compressed_indices.dim() > 0 && compressed_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0); + int64_t plain_dim_size = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size", + [&]() -> int64_t { + if (plain_indices.numel() > 0) { + return plain_indices.max().item() + 1; + } else { + return 0; + } + }); AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&]{ - size.push_back(compressed_dim * blocksize[0]); - size.push_back(plain_dim * blocksize[1]); + size.push_back(compressed_dim_size * blocksize[0]); + size.push_back(plain_dim_size * blocksize[1]); }, [&]{ - size.push_back(plain_dim * blocksize[0]); - size.push_back(compressed_dim * blocksize[1]); + size.push_back(plain_dim_size * blocksize[0]); + size.push_back(compressed_dim_size * blocksize[1]); }); for (int i=0; i(size.size()) == batch_ndim + base_ndim + dense_ndim, + "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); return size; } @@ -493,6 +488,16 @@ SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc) SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr) SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc) +Tensor empty_symint_sparse_compressed( + c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional optional_memory_format) { + return at::native::empty_sparse_compressed(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); +} + Tensor empty_sparse_compressed( IntArrayRef size, c10::optional dtype, @@ -630,6 +635,14 @@ Tensor row_indices_sparse_csr(const Tensor& self) { [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); }); } +int64_t sparse_dim_sparse_csr(const SparseCsrTensor& self) { + return get_sparse_csr_impl(self)->sparse_dim(); +} + +int64_t dense_dim_sparse_csr(const SparseCsrTensor& self) { + return get_sparse_csr_impl(self)->dense_dim(); +} + bool _is_same_size_as_sparse_csr( const SparseCsrTensor& self, const SparseCsrTensor& src) { @@ -692,6 +705,9 @@ Tensor empty_like_sparse_csr( .merge_in(options_) .merge_memory_format(optional_memory_format); + TORCH_CHECK(options.layout() == self.layout(), + "empty_like with different sparse layout is not supported (self is ", + self.layout(), " but you requested ", options.layout(), ")"); if (options.layout() == kSparseCsr) { auto result = at::native::_sparse_csr_tensor_unsafe( self.crow_indices().clone(), @@ -702,6 +718,37 @@ Tensor empty_like_sparse_csr( self.layout(), options.device()); return result; + } else if (options.layout() == kSparseCsc) { + auto result = at::native::_sparse_csc_tensor_unsafe( + self.ccol_indices().clone(), + self.row_indices().clone(), + at::empty(self.values().sizes(), options.layout(kStrided)), + self.sizes(), + optTypeMetaToScalarType(options.dtype()), + self.layout(), + options.device()); + return result; + } else if (options.layout() == kSparseBsr) { + auto result = at::native::_sparse_bsr_tensor_unsafe( + self.crow_indices().clone(), + self.col_indices().clone(), + at::empty(self.values().sizes(), options.layout(kStrided)), + self.sizes(), + optTypeMetaToScalarType(options.dtype()), + self.layout(), + options.device()); + + return result; + } else if (options.layout() == kSparseBsc) { + auto result = at::native::_sparse_bsc_tensor_unsafe( + self.ccol_indices().clone(), + self.row_indices().clone(), + at::empty(self.values().sizes(), options.layout(kStrided)), + self.sizes(), + optTypeMetaToScalarType(options.dtype()), + self.layout(), + options.device()); + return result; } else if (options.layout() == kStrided) { return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format); } else { @@ -710,13 +757,22 @@ Tensor empty_like_sparse_csr( } Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { - TORCH_INTERNAL_ASSERT(self.is_sparse_csr()); - TORCH_CHECK_INDEX(self.dim() != 0, "select() cannot be applied to a 0-dim tensor."); + TORCH_CHECK( + self.layout() == kSparseCsr || self.layout() == kSparseBsr, + "select(): currently only supports the SparseCsr and SparseBsr layout."); + TORCH_CHECK_INDEX( + self.dim() != 0, "select() cannot be applied to a 0-dim tensor."); dim = maybe_wrap_dim(dim, self.dim()); auto size = self.size(dim); if (index < -size || index >= size) { - TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ", - self.sizes(), " at dimension ", dim); + TORCH_CHECK_INDEX( + false, + "select(): index ", + index, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + dim); } if (index < 0) { index += size; @@ -730,6 +786,17 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { // Selecting batch dimension if (dim < self.dim() - 2) { + if (self.layout() == kSparseBsr) { + return at::native::_sparse_bsr_tensor_unsafe( + self.crow_indices().select(dim, index), + self.col_indices().select(dim, index), + self.values().select(dim, index), + new_sizes, + optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt()); + } return at::native::_sparse_csr_tensor_unsafe( self.crow_indices().select(dim, index), self.col_indices().select(dim, index), @@ -740,9 +807,15 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { options.device_opt(), options.pinned_memory_opt()); } else { - TORCH_CHECK(self.dim() == 2, "select(): selecting rows or columns is not implemented for batched sparse CSR tensors.") - // Converting to COO and calling select is slighly slower than operating on the CSR indices directly - // for constructing a COO vector, however current version is more readable and easier to understand. + TORCH_CHECK( + self.is_sparse_csr(), + "select(): selecting non-batch dimensions is currently only supported for CSR tensors."); + TORCH_CHECK( + self.dim() == 2, + "select(): selecting rows or columns is not implemented for batched sparse CSR tensors.") + // Converting to COO and calling select is slighly slower than operating on + // the CSR indices directly for constructing a COO vector, however current + // version is more readable and easier to understand. return self.to_sparse().select(dim, index); } } diff --git a/aten/src/ATen/native/sparse/SparseFactories.cpp b/aten/src/ATen/native/sparse/SparseFactories.cpp new file mode 100644 index 0000000000000..f0007747660e8 --- /dev/null +++ b/aten/src/ATen/native/sparse/SparseFactories.cpp @@ -0,0 +1,95 @@ +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at { +namespace native { + +DEFINE_DISPATCH(spdiags_kernel_stub); + +Tensor spdiags( + const Tensor& diagonals, + const Tensor& offsets, + IntArrayRef shape, + c10::optional layout) { + auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals; + TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix"); + TORCH_CHECK(shape.size() == 2, "Output shape must be 2d"); + auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets; + TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector"); + TORCH_CHECK( + diagonals_2d.size(0) == offsets_1d.size(0), + "Number of diagonals (", + diagonals_2d.size(0), + ") does not match the number of offsets (", + offsets_1d.size(0), + ")"); + if (layout) { + TORCH_CHECK( + (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) || + (*layout == Layout::SparseCsr), + "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ", + *layout); + } + TORCH_CHECK( + offsets_1d.scalar_type() == at::kLong, + "Offset Tensor must have dtype Long but got ", + offsets_1d.scalar_type()); + + TORCH_CHECK( + offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(), + "Offset tensor contains duplicate values"); + + auto nnz_per_diag = at::where( + offsets_1d.le(0), + offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)), + offsets_1d.add(-std::min(shape[1], diagonals_2d.size(1))).neg()); + + auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1); + const auto nnz = diagonals_2d.size(0) > 0 + ? nnz_per_diag_cumsum.select(-1, -1).item() + : int64_t{0}; + // Offsets into nnz for each diagonal + auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag); + // coo tensor guts + auto indices = at::empty({2, nnz}, offsets_1d.options()); + auto values = at::empty({nnz}, diagonals_2d.options()); + // We add this indexer to lookup the row of diagonals we are reading from at + // each iteration + const auto n_diag = offsets_1d.size(0); + Tensor diag_index = at::arange(n_diag, offsets_1d.options()); + // cpu_kernel requires an output + auto dummy = at::empty({1}, offsets_1d.options()).resize_({0}); + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .add_output(dummy) + .add_input(diag_index) + .add_input(offsets_1d) + .add_input(result_mem_offsets) + .add_input(nnz_per_diag) + .build(); + spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices); + auto result_coo = at::sparse_coo_tensor(indices, values, shape); + if (layout) { + if (*layout == Layout::SparseCsr) { + return result_coo.to_sparse_csr(); + } + if (*layout == Layout::SparseCsc) { + return result_coo.to_sparse_csc(); + } + } + return result_coo; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/sparse/SparseFactories.h b/aten/src/ATen/native/sparse/SparseFactories.h new file mode 100644 index 0000000000000..3fd68931878b6 --- /dev/null +++ b/aten/src/ATen/native/sparse/SparseFactories.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include +#include +#include + +namespace at { +namespace native { + +using spdiags_kernel_fn_t = + void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&); + +DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub); +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 088e1bcf234d3..a162689eb5fb9 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -132,12 +132,15 @@ SparseTensor new_sparse( c10::optional pin_memory) { AT_ASSERT(layout.has_value() && *layout == kSparse); DispatchKey dispatch_key; - if (device_or_default(device).is_cuda()) { - dispatch_key = DispatchKey::SparseCUDA; - } else if (device_or_default(device).is_xpu()) { - dispatch_key = DispatchKey::SparseXPU; - } else { - dispatch_key = DispatchKey::SparseCPU; + switch (device_or_default(device).type()) { +#define DO_CASE(device, _) \ + case DeviceType::device: \ + dispatch_key = DispatchKey::Sparse##device; \ + break; + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + TORCH_CHECK(false, "device type not supported for sparse ", device_or_default(device)) } return detail::make_tensor( DispatchKeySet(dispatch_key), @@ -204,6 +207,17 @@ Tensor empty_sparse( size.size(), 0, size, dtype, layout, device, pin_memory); } +/** Empty init **/ +Tensor empty_symint_sparse( + c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional optional_memory_format) { + return at::native::empty_sparse(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); +} + /* Shape init */ Tensor sparse_coo_tensor(IntArrayRef size, c10::optional dtype, diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 1e48ca1bfa6c3..ad98fcee2d5bb 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -1,3 +1,5 @@ +#include + #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include @@ -64,6 +66,7 @@ #include #include #include +#include #endif #include @@ -705,7 +708,15 @@ Tensor mul_sparse(const Tensor& self, const Tensor& other) { } Tensor& mul_sparse_(Tensor& self, const Tensor& other) { - return at::mul_out(self, self, other); // redispatch! + if (self.is_sparse()) { + return at::mul_out(self, self, other); // redispatch! + } + else { + const auto res = at::mul(self, other); + self.zero_(); + self.add_(res); + return self; + } } Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) { @@ -744,23 +755,184 @@ Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) { return at::mul_out(self, self, other); // redispatch! } -SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTensor& r) { - if (src_.dim() == 0) { - return mul_out_sparse_zerodim(r, t_, src_); - } else if (t_.dim() == 0) { - return mul_out_sparse_zerodim(r, src_, t_); +// A generic function to implement pointwise-like operations +// with index intersection between dense and sparse COO tensors. +// NOTE: op is always called as op(dense_values, sparse_values), +// so it is up to the user to supply right implementations for non-commutative +// operations. +template +Tensor& intersection_binary_op_sparse_dense_out( + const Tensor& d, + const SparseTensor& s_, + Tensor& res, + const char* const op_name, + const binary_func_t& op, + const bool coalesce = false) { + // compute broadcasted shape. + const auto res_shape = infer_size(d.sizes(), s_.sizes()); + + // Short-circuit if either s_ or d is empty. + if (!s_._nnz() || !s_.numel() || !d.numel()) { + const auto sparse_dim = static_cast(res_shape.size()); + const auto indices = at::empty({sparse_dim, 0}, s_._indices().options()); + const auto values = at::empty({0}, s_._values().options().dtype(res.scalar_type())); + get_sparse_impl(res)->raw_resize_(sparse_dim, /*dense_dim=*/0, /*shape=*/res_shape); + get_sparse_impl(res)->set_indices_and_values_unsafe(indices, values); + get_sparse_impl(res)->set_nnz_and_narrow(0); + return res._coalesced_(true); } - TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul operands have incompatible sizes"); + const auto d_dim = d.dim(); + const auto s_dim = s_.dim(); + + // Always coalesce when sparse broadcasts over dense, + // because new sparse dimensions are created and + // repeated indices have to be eliminated because of that. + const auto s = (coalesce || d_dim > s_dim) ? s_.coalesce() : s_; + + const auto sparse_dim = s.sparse_dim(); + const auto dense_dim = s.dense_dim(); + + const auto s_indices = s._indices(); + const auto s_values = s._values(); + + const auto apply_op = [&](const Tensor& d_filtered) -> Tensor& { + const auto res_indices = s_indices.clone(); + const auto res_values = op(d_filtered, s_values); + get_sparse_impl(res)->raw_resize_(sparse_dim, dense_dim, res_shape); + get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values); + get_sparse_impl(res)->set_nnz_and_narrow(s._nnz()); + return res._coalesced_(s.is_coalesced()); + }; + + // Easiest case: only dense dimensions intersect. + // This means only value tensors interact. + if (d_dim <= dense_dim) { + return apply_op(d); + } + + // Now we have intersection between sparse and dense dims. + const auto sparse_dim_intersec = std::min(sparse_dim, d_dim - dense_dim); + const auto d_start_dim_intersec = std::max(0, d_dim - s_dim); + const auto s_start_dim_intersec = std::max(0, s_dim - d_dim); + + // Index d with s_indices to find values which + // interact with s_values. + const auto d_filtered = [&]() -> Tensor { + using at::indexing::Slice; + using at::indexing::Ellipsis; + using at::indexing::TensorIndex; + + std::vector intersec_indices; + intersec_indices.reserve(d_dim); + + if (d_start_dim_intersec) { + intersec_indices.push_back(Ellipsis); + } + for (const auto i : c10::irange(sparse_dim_intersec)) { + const auto s_idx = s_start_dim_intersec + i; + intersec_indices.push_back(s_indices[s_idx]); + } + for (auto i = d_start_dim_intersec + sparse_dim_intersec; i < d_dim; ++i) { + intersec_indices.push_back(Slice()); + } + // we need to expand d in the dimensions it is being indexed into + // to avoid out of bound indices + const auto d_expanded_shape = std::vector( + res_shape.end() - d_dim, res_shape.end()); + return d.expand(d_expanded_shape).index(intersec_indices); + }(); + + // When dims match or sparse is "larger", the result nnz is the same, + // so only values get modified. + if (s_dim >= d_dim) { + return apply_op(d_filtered); + } + + // Otherwise nnz gets larger, and both indices and values need an update. + const auto d_batch_shape = d.sizes().slice(0, d_start_dim_intersec); + const auto d_batch_len = d_batch_shape.size(); + int64_t batch_count; + int64_t max_batch_dim; + std::tie(batch_count, max_batch_dim) = [&]() -> std::tuple { + int64_t batch_count = 1; + int64_t max_batch_dim = 0; + for (const auto& b : d_batch_shape) { + batch_count *= b; + max_batch_dim = std::max(b, max_batch_dim); + } + return std::make_tuple(batch_count, max_batch_dim); + }(); + + const auto res_sparse_dim = static_cast(d_batch_shape.size()) + sparse_dim; + const auto res_dense_dim = dense_dim; + const auto s_nnz = s._nnz(); + const auto res_nnz = batch_count * s_nnz; + auto res_values_shape = s_values.sizes().vec(); + res_values_shape[0] = res_nnz; + const auto res_values = op(d_filtered, s_values).reshape(res_values_shape); + const auto res_indices = [&]() -> Tensor { + const auto index_buffer = at::arange(max_batch_dim, s_indices.options()); + auto res_indices = at::empty({res_sparse_dim, res_nnz}, s_indices.options()); + // fill in indices corresponding to the "batch" dimensions of d. + int64_t n_repeat_interleave = res_nnz; + int n_repeat = 1; + for (const auto dim : c10::irange(d_batch_len)) { + const auto dim_size = d_batch_shape[dim]; + n_repeat_interleave /= dim_size; + // fill in indices corresponding to the "batch" dimension dim. + // Equivalent to res_indices[dim].copy_(repeat_interleave(dim_index, n_repeat_interleave).repeat(n_repeat)) + const std::initializer_list dim_index_expanded_shape = {n_repeat, dim_size, n_repeat_interleave}; + const auto dim_index = index_buffer.slice(-1, 0, dim_size); + const auto dim_index_expanded = dim_index.unsqueeze(0).unsqueeze_(-1).expand(dim_index_expanded_shape); + // NOTE: res_indices is contiguous, so view is safe + res_indices[dim].view(dim_index_expanded_shape).copy_(dim_index_expanded); + n_repeat *= dim_size; + } + // fill in indices corresponding to s_indices. + // Equivalent to res_indices_sparse.copy(s_indices.repeat({1, n_repeat}) + n_repeat = res_nnz / s_nnz; + auto res_indices_sparse = res_indices.narrow(0, d_batch_len, res_sparse_dim - d_batch_len); + const std::initializer_list s_indices_expanded_shape = {-1, n_repeat, s_nnz}; + const auto s_indices_expanded = s_indices.unsqueeze(1).expand(s_indices_expanded_shape); + res_indices_sparse.view(s_indices_expanded_shape).copy_(s_indices_expanded); + + return res_indices; + }(); + + get_sparse_impl(res)->raw_resize_(res_sparse_dim, res_dense_dim, res_shape); + get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values); + get_sparse_impl(res)->set_nnz_and_narrow(res_nnz); + // By design of index expansion and that s is coalesced, + // the result is also coalesced. + return res._coalesced_(true); +} + +Tensor& _mul_dense_sparse_out(const Tensor& d, const Tensor& s, Tensor& res) { + return intersection_binary_op_sparse_dense_out(d, s, res, "mul", [](const Tensor& a, const Tensor& b) -> Tensor { + return at::mul(a, b); + }); +} + +SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r) { AT_ASSERT(!t_.is_cuda()); // dispatch argument TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor"); TORCH_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor"); - TORCH_CHECK(src_.is_sparse(), "mul(sparse, dense) is not supported"); - TORCH_CHECK(t_.is_sparse(), "mul(dense, sparse) is not supported"); - TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same sizes, but ", t_.sizes(), " != ", src_.sizes()); - if (src_._nnz() == 0 || t_._nnz() == 0) { - r.resize_as_(src_); + // case mul(sparse, dense) + if (!src_.is_sparse()) { + return _mul_dense_sparse_out(src_, t_, r); + } + // case mul(dense, sparse) + if (!t_.is_sparse()) { + return _mul_dense_sparse_out(t_, src_, r); + } + + TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same sizes when both are sparse" + ", but ", t_.sizes(), " != ", src_.sizes()); + + if (!t_._nnz() || !src_._nnz()) { + r.resize_as_(t_); return r.zero_(); } diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.h b/aten/src/ATen/native/sparse/SparseTensorMath.h index 7ddeb3672402d..645e0e65e0605 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.h +++ b/aten/src/ATen/native/sparse/SparseTensorMath.h @@ -6,5 +6,6 @@ namespace at { namespace native { TORCH_API sparse::SparseTensor& mul_out_sparse_scalar(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Scalar& value); TORCH_API sparse::SparseTensor& mul_out_sparse_zerodim(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Tensor& value); +TORCH_API sparse::SparseTensor& _mul_dense_sparse_out(const Tensor& d, const Tensor& s, Tensor& res); }} diff --git a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp index ab56b82a8391f..d3da0f45200bf 100644 --- a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp +++ b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp @@ -183,6 +183,10 @@ COALESCED_UNARY_UFUNC_NO_INPLACE(isposinf); COALESCED_UNARY_UFUNC_FUNCTIONAL(isnan); COALESCED_UNARY_UFUNC_FUNCTIONAL(isinf); +Tensor isinf_sparse_meta(const Tensor& self) { + TORCH_CHECK_NOT_IMPLEMENTED(0, "nyi isinf for SparseMeta"); +} + Tensor nan_to_num_sparse( const Tensor &self, c10::optional nan, c10::optional posinf, c10::optional neginf) { diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h new file mode 100644 index 0000000000000..7d71d2104e5b0 --- /dev/null +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -0,0 +1,319 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define GPUCC +#define FUNCAPI __host__ __device__ +#define INLINE __forceinline__ +#define NAME "compressed_index_invariance_checks_cuda" +#else +#define FUNCAPI +#define INLINE inline +#define NAME "compressed_index_invariance_checks_cpu" +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +#define INVARIANT_CHECK_FUNC_API static INLINE FUNCAPI void + +namespace at { +namespace native { + +namespace { + +// NOTE: all the checks but the very last one are designed +// to work with vectors. +// To enable vectorization one would need to write a conversion +// Vec -> bool and make kernel launchers call into vectorized +// execution paths. + +// All the invariants are described in https://pearu.github.io/bsr_tensor_invariants.html +// NOTE: in the code we also use `cidx/idx` to refer to `compressed_indices/plain_indices` +// respectively. + +INVARIANT_CHECK_FUNC_API +_assert(const bool cond, const char* const message) { +#ifdef GPUCC + CUDA_KERNEL_ASSERT(cond && message); +#else + TORCH_CHECK(cond, message); +#endif +} + +enum class CDimName : bool { + CRow, + CCol +}; + +// Invariant 5.1 +// compressed_index[..., 0] == 0. +template +INVARIANT_CHECK_FUNC_API +_check_first_cidx_is_zero(const index_t& cidx, const index_t& zero) { + const bool invariant = cidx == zero; + if (cdim_name == CDimName::CRow) { + _assert(invariant, "`crow_indices[..., 0] == 0` is not satisfied."); + } + else { + _assert(invariant, "`ccol_indices[..., 0] == 0` is not satisfied."); + } +} + +// Invariant 5.2 +// compressed_index[..., -1] == nnz. +template +INVARIANT_CHECK_FUNC_API +_check_last_cidx_is_nnz(const index_t& cidx, const index_t& nnz) { + const bool invariant = cidx == nnz; + if (cdim_name == CDimName::CRow) { + _assert(invariant, "`crow_indices[..., -1] == nnz` is not satisfied."); + } + else { + _assert(invariant, "`ccol_indices[..., -1] == nnz` is not satisfied."); + } +} + +// Invariant 5.3 +// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim. +template +INVARIANT_CHECK_FUNC_API +_check_cidx_nondecreasing_locally_bounded_sequence( + const index_t& cidx, + const index_t& cidx_next, + const index_t& zero, + const index_t& dim) { + const auto s_cidx = cidx_next - cidx; + const bool invariant = zero <= s_cidx && s_cidx <= dim; + if (cdim_name == CDimName::CRow) { + _assert(invariant, + "`0 <= crow_indices[..., 1:] - crow_indices[..., :-1] <= ncols` is not satisfied."); + } + else { + _assert(invariant, + "`0 <= ccol_indices[..., 1:] - ccol_indices[..., :-1] <= nrows` is not satisfied."); + } +} + +// Invariants 5.4 and 5.5 +// 0 <= plain_index < plain_dim. +template +INVARIANT_CHECK_FUNC_API +_check_idx_bounds( + const index_t& idx, + const index_t& zero, + const index_t& dim) { + const bool invariant = zero <= idx && idx < dim; + if (cdim_name == CDimName::CRow) { + _assert(invariant, "`0 <= col_indices < ncols` is not satisfied."); + } + else { + _assert(invariant, "`0 <= row_indices < nrows` is not satisfied."); + } +} + +// Invariant 5.6 +// plain_indices[..., compressed_indices[..., i - 1]:compressed_indices[..., i]] +// for all i = 1, ..., compressed_dim +// are sorted and distinct along the last dimension values. +template +INVARIANT_CHECK_FUNC_API +_check_idx_sorted_distinct_vals_slices_with_cidx( + const index_t* RESTRICT ptr_idx_batch, + const index_t cidx, + const index_t cidx_next) { + // Note that ptr_idx_batch = &idx[batch_idx] and is contiguous. + const auto* RESTRICT slice_begin = ptr_idx_batch + cidx; + const auto* RESTRICT slice_end = ptr_idx_batch + cidx_next; + for (auto* RESTRICT curr = slice_begin + 1; curr < slice_end; ++curr) { + const auto invariant = *(curr - 1) < *curr; + if (cdim_name == CDimName::CRow) { + _assert(invariant, "`col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] " + "for all i = 1, ..., nrows " + "are sorted and distinct along the last dimension values` " + "is not satisfied."); + } + else { + _assert(invariant, "`row_indices[..., ccol_indices[..., i - 1]:ccol_indices[..., i]] " + "for all i = 1, ..., ncols " + "are sorted and distinct along the last dimension values` " + "is not satisfied."); + } + } +} + +static inline int64_t indexCount(IntArrayRef sizes) { + int64_t res = 1; + for (const auto& s : sizes) { + res *= s; + } + return res; +} + +template +struct EmptyVecKernel { + static void launch(TensorIteratorBase& iter, const func_t& f, const vec_func_t& vec_f) { + } +}; + +template +using DummyVec = scalar_t; + +template < + template class kernel_t, + template class vec_kernel_t> +struct KernelLauncher { + template + static void launch(TensorIteratorBase& iter, const func_t& f, const vec_func_t& vec_f) { + vec_kernel_t::launch(iter, f, vec_f); + } + + template + static void launch(TensorIteratorBase& iter, const func_t& f) { + kernel_t::launch(iter, f); + } +}; + +template < + CDimName cdim_name, + template class kernel_t, + template class vec_kernel_t = EmptyVecKernel, + template class Vec = DummyVec> +void _validate_compressed_sparse_indices_kernel( + const Tensor& cidx, + const Tensor& idx, + const int64_t cdim, + const int64_t dim, + const int64_t nnz) { + if (cdim_name == CDimName::CRow) { + TORCH_CHECK(cidx.size(-1) == cdim + 1, "crow_indices have wrong shape: ", + "crow_indices.shape[-1] = ", cidx.size(-1), " is not equal to ", + "nrows + 1 = ", cdim + 1); + TORCH_CHECK(idx.size(-1) == nnz, "col_indices have wrong shape: ", + "col_indices.shape[-1] = ", idx.size(-1), " is not equal to ", + "nnz = ", nnz); + } + else { + TORCH_CHECK(cidx.size(-1) == cdim + 1, "ccol_indices have wrong shape: ", + "ccol_indices.shape[-1] = ", cidx.size(-1), " is not equal to ", + "ncols + 1 = ", cdim + 1); + TORCH_CHECK(idx.size(-1) == nnz, "row_indices have wrong shape: ", + "row_indices.shape[-1] = ", idx.size(-1), " is not equal to ", + "nnz = ", nnz); + } + + using KernelLauncher = KernelLauncher; + + // For TensorIterator's output: no void lambdas. + const auto dummy = at::empty({1}, cidx.options()); + + // Invariants 5.4 and 5.5 + { + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .add_owned_output(dummy.expand_as(idx)) + .add_input(idx) + .build(); + + AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, dim] () { + const auto zero = index_t {0}; + KernelLauncher::launch(iter, + [zero, dim] FUNCAPI (index_t idx) -> index_t { + _check_idx_bounds(idx, zero, dim); + return 0; + } + ); + }); + } + + // Invariants 5.1, 5.2, 5.3, 5.6 + { + const auto cidx_first = cidx.slice(-1, 0, 1); + const auto cidx_last = cidx.slice(-1, cdim, cdim + 1); + + const auto cidx_curr = cidx.slice(-1, 0, cdim); + const auto cidx_next = cidx.slice(-1, 1, cdim + 1); + + const auto batch_dims = cidx.sizes().slice(0, cidx.dim() - 1); + const auto batch_count = indexCount(batch_dims); + const auto batch_idx = at::arange(batch_count, cidx.options()).view(batch_dims).unsqueeze_(-1); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .add_owned_output(dummy.expand_as(cidx_curr)) + .add_input(cidx_first) + .add_input(cidx_last) + .add_input(cidx_curr) + .add_input(cidx_next) + .add_input(batch_idx) + .build(); + + AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz] () { + const auto* RESTRICT ptr_idx = idx.data_ptr(); + const auto zero = index_t {0}; + KernelLauncher::launch(iter, + [zero, dim, nnz, ptr_idx] FUNCAPI ( + index_t cidx_first, + index_t cidx_last, + index_t cidx_curr, + index_t cidx_next, + index_t batch_idx) -> index_t { + // Invariant 5.1 + _check_first_cidx_is_zero(cidx_first, zero); + // Invariant 5.2 + _check_last_cidx_is_nnz(cidx_last, nnz); + // Invariant 5.3 + _check_cidx_nondecreasing_locally_bounded_sequence(cidx_curr, cidx_next, zero, dim); + // Invariant 5.6 + // NOTE: the implementation below is sync-less, but, unfortunately, + // work is not guaranteed to be well-balanced between different threads. + // idx is contiguous and of shape (..., nnz), so batches are multiples of nnz apart. + const auto* RESTRICT ptr_idx_batch = ptr_idx + batch_idx * nnz; + _check_idx_sorted_distinct_vals_slices_with_cidx( + ptr_idx_batch, cidx_curr, cidx_next); + return 0; + } + ); + }); + } +} + +template < + template class kernel_t, + template class vec_kernel_t = EmptyVecKernel, + template class Vec = DummyVec> +void validate_compressed_sparse_indices_kernel( + const bool is_crow, + const Tensor& cidx, + const Tensor& idx, + const int64_t cdim, + const int64_t dim, + const int64_t nnz) { + if (is_crow) { + _validate_compressed_sparse_indices_kernel( + cidx, idx, cdim, dim, nnz); + } + else { + _validate_compressed_sparse_indices_kernel( + cidx, idx, cdim, dim, nnz); + } +} + +} // anonymous namespace for invariance checkers and utilities + +}} diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp b/aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp new file mode 100644 index 0000000000000..e8e7c6293abd2 --- /dev/null +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp @@ -0,0 +1,46 @@ +#include +#include + +namespace at { +namespace native { + +namespace { + +template +struct CPUKernel { + static void launch(TensorIteratorBase& iter, const func_t& f) { + cpu_kernel(iter, f); + } +}; + +template +struct EmptyKernel { + static void launch(TensorIteratorBase& iter, const func_t& f) { + } +}; + +template +struct CPUVecKernel { + static void launch(TensorIteratorBase& iter, const func_t& f, const vec_func_t& vec_f) { + cpu_kernel_vec(iter, f, vec_f); + } +}; + +} + +void _validate_compressed_sparse_indices_cpu( + const bool is_crow, + const Tensor& cidx, + const Tensor& idx, + const int64_t cdim, + const int64_t dim, + const int64_t nnz) { + // Call into + // compressed_index_invariance_checks_kernel + // to enable vectorized checks once all the conditions for that are met, + // see ATen/native/sparse/CompressedIndexChecksCommon.h for more details. + validate_compressed_sparse_indices_kernel( + is_crow, cidx, idx, cdim, dim, nnz); +} + +}} diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu index 6bf96053bb23f..05cb9e06d90f3 100644 --- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu +++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu @@ -80,10 +80,18 @@ static int getNumThreads(int nElem) { return threadSizes[4]; } +int64_t get_nvalues(const IntArrayRef& sizes, int64_t sparse_dim) { + /* Return the number of entries in the dense part of a sparse tensor. + `sizes` is a vector of sparse tensor dimensions. + `sparse_dim` is the dimension of the sparse part of a sparse tensor. + */ + return c10::multiply_integers(sizes.begin() + sparse_dim, sizes.end()); +} + template __global__ void cuda_sparse_coo_softmax_kernel( int64_t* sorted_pool_indices, - int64_t size, + int64_t pool_size, int64_t* pool_sizes, int64_t* pool_offsets, int64_t nvalues, @@ -91,7 +99,7 @@ __global__ void cuda_sparse_coo_softmax_kernel( PackedTensorAccessor input_values_acc, PackedTensorAccessor output_values_acc) { /* - See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax for the CPU + See ATen/native/sparse/SoftMax.cpp:cpu_sparse_coo_softmax for the CPU implementation of the sparse softmax algorithm that this implementation is based on. */ @@ -103,7 +111,7 @@ __global__ void cuda_sparse_coo_softmax_kernel( int index = tid + blkid * blksz; int step = blksz * gridsz; - while (index < size) { + while (index < pool_size) { int64_t offset = pool_offsets[index]; int64_t* pool_indices = sorted_pool_indices + offset; int64_t pool_indices_size = pool_sizes[index]; @@ -153,7 +161,7 @@ __global__ void cuda_sparse_coo_softmax_backward_kernel( PackedTensorAccessor out_values_accessor, PackedTensorAccessor grad_values_accessor) { /* - See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax_backward for + See ATen/native/sparse/SoftMax.cpp:cpu_sparse_coo_softmax_backward for the CPU implementation of the sparse softmax backward algorithm that this implementation is based on. */ @@ -226,7 +234,7 @@ Tensor get_offsets( const IntArrayRef& sizes, const int64_t dim) { /* - See ATen/native/sparse/Softmax.cpp:get_offsets for the CPU + See ATen/native/sparse/SoftMax.cpp:get_offsets for the CPU implementation of get_offsets function that this implementation is based on. */ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -284,8 +292,8 @@ std::tuple compute_pool_max( Return pools of indices that align with the given dimension and the corresponding max values for each pool. - See ATen/native/sparse/Softmax.cpp:get_offsets and - ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax for the CPU + See ATen/native/sparse/SoftMax.cpp:get_offsets and + ATen/native/sparse/SoftMax.cpp:cpu_sparse_coo_softmax for the CPU implementation that this implementation is based on. */ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -379,7 +387,7 @@ void cuda_sparse_coo_softmax( const Tensor& input, const int64_t dim) { /* - See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax for the CPU + See ATen/native/sparse/SoftMax.cpp:cpu_sparse_coo_softmax for the CPU implementation of the sparse softmax algorithm that this implementation is based on. */ @@ -408,7 +416,7 @@ void cuda_sparse_coo_softmax( auto nnz = values.size(0); auto sizes = input.sizes(); - auto nvalues = values.numel() / nnz; + auto nvalues = get_nvalues(sizes, sparse_dim); /* Prepare accessors */ auto values_2 = values.view({nnz, nvalues}); @@ -429,17 +437,23 @@ void cuda_sparse_coo_softmax( int block_size = getNumThreads(pool_size); const int grid_size = (pool_size + block_size - 1) / block_size; - cuda_sparse_coo_softmax_kernel - <<>>( - sorted_indices.data_ptr(), - pool_size, - pool_sizes.data_ptr(), - pool_offsets.data_ptr(), - nvalues, - mx_buffer.data_ptr(), - values_accessor, - out_values_accessor); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + // If either nvalues or pool_size are zero, then cuda_sparse_coo_softmax_kernel + // won't actually perform any computation. Further, they will be + // invalid configuration parameters for the launch. So let's not + // launch a kernel unless both are non-zero. + if (nvalues > 0 && pool_size > 0) { + cuda_sparse_coo_softmax_kernel + <<>>( + sorted_indices.data_ptr(), + pool_size, + pool_sizes.data_ptr(), + pool_offsets.data_ptr(), + nvalues, + mx_buffer.data_ptr(), + values_accessor, + out_values_accessor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } template @@ -450,7 +464,7 @@ void cuda_sparse_coo_softmax_backward( const int64_t dim, ScalarType input_dtype) { /* - See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax_backward for + See ATen/native/sparse/SoftMax.cpp:cpu_sparse_coo_softmax_backward for the CPU implementation of the sparse softmax backward algorithm that this implementation is based on. */ @@ -522,7 +536,7 @@ void cuda_sparse_coo_softmax_backward( } auto nnz = values.size(0); - auto nvalues = values.numel() / nnz; + auto nvalues = get_nvalues(sizes, sparse_dim); auto values_2 = values.view({nnz, nvalues}); auto values_accessor = values_2.packed_accessor(); @@ -559,21 +573,23 @@ void cuda_sparse_coo_softmax_backward( int block_size = getNumThreads(pool_size); const int grid_size = (pool_size + block_size - 1) / block_size; - cuda_sparse_coo_softmax_backward_kernel - <<>>( - sorted_indices.data_ptr(), - pool_size, - pool_sizes.data_ptr(), - pool_offsets.data_ptr(), - nvalues, - grad_nnz, - grad_offsets.data_ptr(), - out_offsets.data_ptr(), - lower_bound_values.data_ptr(), - values_accessor, - out_values_accessor, - grad_values_accessor); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (nvalues > 0 && pool_size > 0) { + cuda_sparse_coo_softmax_backward_kernel + <<>>( + sorted_indices.data_ptr(), + pool_size, + pool_sizes.data_ptr(), + pool_offsets.data_ptr(), + nvalues, + grad_nnz, + grad_offsets.data_ptr(), + out_offsets.data_ptr(), + lower_bound_values.data_ptr(), + values_accessor, + out_values_accessor, + grad_values_accessor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } } // end anonymous namespace diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index a427c07a085b4..c88a9c6abfde6 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -188,9 +188,6 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { for (int64_t d = sparse_dim - 1; d >= 0; d--) { // NB: Not a select, so I can preserve the outer dimension Tensor indicesSlice = newIndices.narrow(0, d, 1); - // Note for the porting guide: THCTensor_(copy) does NOT do normal - // broadcasting logic; instead, it will blast the elements from one - // to the other so long as the numel is the same indicesSlice.copy_(indices1D); indices1D.divide_(self.size(d), "trunc"); indicesSlice.add_(indices1D, -self.size(d)); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 3b35457b66d27..bea7788e9d579 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -459,16 +459,20 @@ struct TensorMulOp { } }; -SparseTensor& mul_out_sparse_cuda(const SparseTensor& t_, const SparseTensor& src_, SparseTensor& r_) { - if (src_.dim() == 0) { - return mul_out_sparse_zerodim(r_, t_, src_); - } else if (t_.dim() == 0) { - return mul_out_sparse_zerodim(r_, src_, t_); +SparseTensor& mul_out_sparse_cuda(const Tensor& t_, const Tensor& src_, SparseTensor& r_) { + TORCH_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); + + // case mul(sparse, dense) + if (!src_.is_sparse()) { + return _mul_dense_sparse_out(src_, t_, r_); + } + // case mul(dense, sparse) + if (!t_.is_sparse()) { + return _mul_dense_sparse_out(t_, src_, r_); } TORCH_CHECK(t_.is_cuda(), "mul: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); TORCH_CHECK(cuda::check_device({r_, t_, src_})); TORCH_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same size, but ", t_.sizes(), " != ", src_.sizes()); @@ -495,7 +499,7 @@ SparseTensor& mul_out_sparse_cuda(const SparseTensor& t_, const SparseTensor& sr Tensor r_values_ = new_values_with_size_of(t_values_, max_nnz).zero_(); - int64_t valueSize = t_values_.stride(0); + int64_t valueSize = std::max(1, t_values_.stride(0)); const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), valueSize)); dim3 grid; int curDevice = -1; diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index b57c8c96a9059..39ff5de6f7c48 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -16,3 +16,11 @@ desc: | This tag indicates that the operator doesn't have an explicit entry in native_functions.yaml, and instead was generated automatically by the codegen. +- tag: nondeterministic_seeded + desc: | + This tag indicates if an operator is nondeterminstically seeded (ie is random) + such that the operator intentially produces different results when run twice on the same inputs. +- tag: nondeterministic_bitwise + desc: | + This tag indicates if an operator doesn't guarentee bitwise equivalence + across different runs of an operator with identical inputs. diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 3c0548f7fc392..67fa95c72aa2e 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -120,21 +120,11 @@ Tensor masked_softmax( c10::optional attn_mask, const Tensor& query) { if (query.is_nested() && !attn_mask) { - // TODO: maybe we could do better than generating a mask every time? - - attn_mask = NestedTensor_to_mask(query, 2, attn_scores.size(2)); - // TODO: CPU path does not support transformer mask yet. if (attn_scores.is_cpu()) { - attn_mask = attn_mask->view({-1, 1, 1, attn_scores.sizes()[3]}); - // 1 means skip, 0 means keep. - // want: - // 0,0 -> 0 - // 0,1 -> 1 - // 1,1 -> 1 - // so that's logical OR. - *attn_mask = *attn_mask | attn_mask->transpose(2, 3); - attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous(); + NestedTensor_softmax_dropout(query, attn_scores); + return attn_scores; } + attn_mask = NestedTensor_to_mask(query, 2, attn_scores.size(2)); attn_mask = attn_mask->to(query.device(), /*non-blocking=*/true); } if (attn_mask && attn_mask->dtype() != at::kBool) { @@ -668,5 +658,74 @@ std::tuple native_decoder_only_multi_head_attent return std::make_tuple(std::move(proj), std::move(qkt), std::move(k), std::move(v)); } +// Computes scaled dot product attention on query, key and value tensors, using +// an optional attention mask if passed, and applying dropout if a probability +// greater than 0.0 is specified. +// +// Args: +// query (Tensor): Query tensor; shape (N, ..., L, E) +// key (Tensor): Key tensor; shape (N, ..., S, E) +// value (Tensor): Value tensor; shape (N, ..., S, E) +// attn_mask (optional Tensor): Attention mask; shape (N, ..., L, S) or (L, S). Currently, only a boolean mask +// is supported, where a value of True indicates that the element *should* take part in attention. +// dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied +// need_attn_weights (bool): If true, the second return value will contain the attention weights used; +// otherwise, the second return value is unspecified +// is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set. +// TODO: Consider removing this flag before promoting this function to the public API. It's possible +// to get specialized support for causal masks (and other types of masking e.g. local attention / block +// sparse masks) via tensor subclassing, allowing for a leaner API. +// +// Returns a tuple containing: +// output (Tensor): Attention output; shape (N, ..., L, E) +// attn_weights (Tensor): Attention weighting; shape (N, ..., L, S) +// +// Shape legend: +// N: Batch size +// ...: Any number of other batch dimensions (optional) +// S: Source sequence length +// L: Target sequence length +// E: Embedding dimension +std::tuple _scaled_dot_product_attention( + const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { + auto attn_mask = attn_mask_; + TORCH_CHECK(!attn_mask.has_value() || attn_mask->dtype() == at::kBool, + "_scaled_dot_product_attention: Only boolean attention masks are currently supported, but found: ", + attn_mask->dtype()) + // Naive, composite implementation defined here. + const auto embed_size = query_.size(-1); + const auto query = query_ * (1. / ::sqrt(static_cast(embed_size))); + if (is_causal) { + TORCH_CHECK(!attn_mask.has_value(), + "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True"); + TORCH_CHECK(!query.is_nested() && !key.is_nested(), + "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True"); + + // Replace attn_mask with causal mask; lower triangular elements take part in attention. + const auto L = query.size(-2), S = key.size(-2); + attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril(); + } + if (attn_mask.has_value()) { + TORCH_CHECK(!query.is_nested() && !key.is_nested(), + "_scaled_dot_product_attention: Nested tensors for query / key are not supported " + "when an explicit attn_mask is set"); + // Convert boolean mask to additive mask; need to invert mask to indicate what to mask *out*. + auto new_attn_mask = at::zeros_like(*attn_mask, query.dtype()); + new_attn_mask.masked_fill_(attn_mask->logical_not(), -std::numeric_limits::infinity()); + attn_mask = new_attn_mask; + } + auto attn = at::matmul(query, key.transpose(-2, -1)); + if (attn_mask.has_value()) { + attn.add_(*attn_mask); + } + attn = at::softmax(attn, -1); + if (dropout_p > 0.0) { + at::dropout_(attn, dropout_p, true); + } + const auto output = at::matmul(attn, value); + return (need_attn_weights ? std::make_tuple(output, attn) : std::make_tuple(output, Tensor())); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/transformer.cpp b/aten/src/ATen/native/transformers/transformer.cpp index 7bfd7036bcc17..bba3adc9b2c4b 100644 --- a/aten/src/ATen/native/transformers/transformer.cpp +++ b/aten/src/ATen/native/transformers/transformer.cpp @@ -101,12 +101,15 @@ Tensor transformer_encoder_layer_forward( : src.clone(); } } - TORCH_CHECK(!norm_first, "norm_first is not supported yet"); const bool use_nested_tensor = src.is_nested(); - auto x = std::get<0>(native_multi_head_attention( - src, - src, - src, + Tensor x = src; + if (norm_first) { + x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor); + } + x = std::get<0>(native_multi_head_attention( + x, + x, + x, embed_dim, num_heads, qkv_weight, @@ -116,9 +119,15 @@ Tensor transformer_encoder_layer_forward( mask, false /* need_weights */)); add_in_place(x, src, use_nested_tensor); - x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor); + if (!norm_first) { + x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor); + } auto pre_ffn_res = x; + + if (norm_first) { + x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor); + } x = ffn( x, ffn_weight_1, @@ -128,7 +137,9 @@ Tensor transformer_encoder_layer_forward( use_gelu, /* add_norm* */ false); add_in_place(x, pre_ffn_res, use_nested_tensor); - x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor); + if (!norm_first) { + x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor); + } return x; } diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index df5b48bbdd2ba..2ef238c0bff00 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ b/aten/src/ATen/native/ts_native_functions.yaml @@ -59,7 +59,6 @@ full_codegen: - gt.Tensor - hardsigmoid - index_select - - kl_div_backward - le.Scalar - le.Tensor - leaky_relu @@ -85,6 +84,8 @@ full_codegen: - mm - mul.Tensor - mv + - native_batch_norm + - native_batch_norm_backward - native_dropout - native_dropout_backward - native_layer_norm @@ -100,9 +101,9 @@ full_codegen: - norm.ScalarOpt_dim - pow.Tensor_Scalar - pow.Tensor_Tensor - - random.functional - - random.from_functional - - random.to_functional + - random + - random.from + - random.to - reciprocal - relu - remainder.Tensor @@ -139,44 +140,71 @@ full_codegen: - upsample_bilinear2d_backward - upsample_nearest2d - upsample_nearest2d_backward - - zero.functional + - zero - narrow_copy.SymInt + - alias_copy + - as_strided_copy + - diagonal_copy + - expand_copy + - permute_copy + - _reshape_alias_copy + - select_copy.int + - detach_copy + - slice_copy.Tensor + # Not implemented yet because LTC codegen doesn't currently work + # for ops that return lists of tensors. + #- split_copy.Tensor + #- split_with_sizes_copy + #- unbind_copy.int + - squeeze_copy + - squeeze_copy.dim + - t_copy + - transpose_copy.int + - unsqueeze_copy + - view_copy + - view_copy.dtype + - unfold_copy + - select_scatter + - slice_scatter + - diagonal_scatter + - as_strided_scatter +ir_gen: + - selu supported: - - as_strided - - as_strided_ - clone - _copy_from - _copy_from_and_resize - - diagonal - empty.memory_format + - empty.SymInt - empty_strided - - expand - fill_.Scalar - - narrow - - native_batch_norm - - native_batch_norm_backward - normal_ - max_pool3d_with_indices - max_pool3d_with_indices_backward - - permute - - select.int - - slice.Tensor - - squeeze - - squeeze.dim - - squeeze_ - - squeeze_.dim - - t - - t_ - _to_copy - - transpose.int - - transpose_ - - unsqueeze - - unsqueeze_ - - view - - alias - _unsafe_view + - lift + - lift_fresh + # Below are all operators that are "composite" in core, + # but require us to explicitly re-enable functionalization in order to use them. + # Why? These operators are all CompositeExplicitAutograd, which mean that they run + # after functionalization, + # but their implementations call view operators (which we need to functionalize away). + - block_diag + - diagonal_backward + - slice_backward + - new_empty_strided + - narrow_copy + - pixel_shuffle + - pixel_unshuffle + - select_backward + - _trilinear + - linalg_inv_ex + - linalg_pinv.atol_rtol_tensor + - logsumexp.out autograd: - max_pool3d + - native_group_norm # Ops that don't have a native schema definitions and are dispatched within Lazy Tensor Core non_native: diff --git a/aten/src/ATen/native/verbose_wrapper.cpp b/aten/src/ATen/native/verbose_wrapper.cpp new file mode 100644 index 0000000000000..3d7607b9b5e89 --- /dev/null +++ b/aten/src/ATen/native/verbose_wrapper.cpp @@ -0,0 +1,32 @@ +#include +#include + +#if AT_MKL_ENABLED() +#include +#endif + +#if AT_MKLDNN_ENABLED() +#include +#endif + +namespace torch { +namespace verbose { + +TORCH_API int _mkl_set_verbose(int enable) { +#if AT_MKL_ENABLED() + return mkl_verbose(enable); +#else + return 0; +#endif +} + +TORCH_API int _mkldnn_set_verbose(int level) { +#if AT_MKLDNN_ENABLED() + return at::native::set_verbose(level); +#else + return 0; +#endif +} + +} // namespace verbose +} // namespace torch diff --git a/aten/src/ATen/native/verbose_wrapper.h b/aten/src/ATen/native/verbose_wrapper.h new file mode 100644 index 0000000000000..9c7ab363f7dd3 --- /dev/null +++ b/aten/src/ATen/native/verbose_wrapper.h @@ -0,0 +1,11 @@ +#ifndef VERBOSE_WRAPPER_H +#define VERBOSE_WRAPPER_H + +namespace torch { +namespace verbose { +int _mkl_set_verbose(int enable); +int _mkldnn_set_verbose(int level); +} // namespace verbose +} // namespace torch + +#endif // VERBOSE_WRAPPER_H diff --git a/aten/src/ATen/native/vulkan/VulkanGuardImpl.cpp b/aten/src/ATen/native/vulkan/VulkanGuardImpl.cpp index 03d70a54f68e6..6bca69ebc1601 100644 --- a/aten/src/ATen/native/vulkan/VulkanGuardImpl.cpp +++ b/aten/src/ATen/native/vulkan/VulkanGuardImpl.cpp @@ -54,8 +54,8 @@ struct VulkanGuardImpl final : public c10::impl::DeviceGuardImplInterface { bool queryEvent(void* event) const override { TORCH_CHECK(false, "VULKAN backend doesn't support events.") } - void destroyEvent(void* event, const DeviceIndex device_index) const - noexcept override {} + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override {} }; C10_REGISTER_GUARD_IMPL(Vulkan, VulkanGuardImpl); diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 05c5ce977cd18..3802befad9d25 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -23,8 +23,7 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { opaque_handle, sizes, false), - strides_(strides.vec()) { - } + strides_(strides.vec()) {} IntArrayRef strides_custom() const override { return strides_; diff --git a/aten/src/ATen/native/vulkan/api/Adapter.cpp b/aten/src/ATen/native/vulkan/api/Adapter.cpp index e319ae5be98cf..ff7311886e9a0 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.cpp +++ b/aten/src/ATen/native/vulkan/api/Adapter.cpp @@ -8,6 +8,51 @@ namespace native { namespace vulkan { namespace api { +PhysicalDevice::PhysicalDevice(const VkPhysicalDevice physical_device_handle) + : handle(physical_device_handle), + properties{}, + memory_properties{}, + queue_families{}, + num_compute_queues(0), + has_unified_memory(false), + has_timestamps(false), + timestamp_period(0) { + // Extract physical device properties + vkGetPhysicalDeviceProperties(handle, &properties); + vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties); + + has_timestamps = properties.limits.timestampComputeAndGraphics; + timestamp_period = properties.limits.timestampPeriod; + + // Check if there are any memory types have both the HOST_VISIBLE and the + // DEVICE_LOCAL property flags + const VkMemoryPropertyFlags unified_memory_flags = + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + for (const uint32_t i : c10::irange(memory_properties.memoryTypeCount)) { + if (memory_properties.memoryTypes[i].propertyFlags | unified_memory_flags) { + has_unified_memory = true; + break; + } + } + + uint32_t queue_family_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties( + handle, &queue_family_count, nullptr); + + queue_families.resize(queue_family_count); + vkGetPhysicalDeviceQueueFamilyProperties( + handle, &queue_family_count, queue_families.data()); + + // Find the total number of compute queues + for (const uint32_t family_i : c10::irange(queue_families.size())) { + const VkQueueFamilyProperties& properties = queue_families[family_i]; + // Check if this family has compute capability + if (properties.queueFlags & VK_QUEUE_COMPUTE_BIT) { + num_compute_queues += properties.queueCount; + } + } +} + namespace { void find_requested_device_extensions( @@ -37,12 +82,108 @@ void find_requested_device_extensions( } } -// +VkDevice create_logical_device( + const PhysicalDevice& physical_device, + const uint32_t num_queues_to_create, + std::vector& queues, + std::vector& queue_usage) { + // Find compute queues up to the requested number of queues + + std::vector queue_create_infos; + queue_create_infos.reserve(num_queues_to_create); + + std::vector> queues_to_get; + queues_to_get.reserve(num_queues_to_create); + + uint32_t remaining_queues = num_queues_to_create; + for (const uint32_t family_i : + c10::irange(physical_device.queue_families.size())) { + const VkQueueFamilyProperties& queue_properties = + physical_device.queue_families[family_i]; + // Check if this family has compute capability + if (queue_properties.queueFlags & VK_QUEUE_COMPUTE_BIT) { + const uint32_t queues_to_init = + std::min(remaining_queues, queue_properties.queueCount); + + const std::vector queue_priorities(queues_to_init, 1.0f); + queue_create_infos.push_back({ + VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + family_i, // queueFamilyIndex + queues_to_init, // queueCount + queue_priorities.data(), // pQueuePriorities + }); + + for (const uint32_t queue_i : c10::irange(queues_to_init)) { + // Use this to get the queue handle once device is created + queues_to_get.emplace_back(family_i, queue_i); + } + remaining_queues -= queues_to_init; + } + if (remaining_queues == 0) { + break; + } + } + + queues.reserve(queues_to_get.size()); + queue_usage.reserve(queues_to_get.size()); + + // Create the VkDevice + + std::vector requested_device_extensions{ +#ifdef VK_KHR_portability_subset + VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, +#endif /* VK_KHR_portability_subset */ + }; + + std::vector enabled_device_extensions; + find_requested_device_extensions( + physical_device.handle, + enabled_device_extensions, + requested_device_extensions); + + const VkDeviceCreateInfo device_create_info{ + VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + static_cast(queue_create_infos.size()), // queueCreateInfoCount + queue_create_infos.data(), // pQueueCreateInfos + 0u, // enabledLayerCount + nullptr, // ppEnabledLayerNames + static_cast( + enabled_device_extensions.size()), // enabledExtensionCount + enabled_device_extensions.data(), // ppEnabledExtensionNames + nullptr, // pEnabledFeatures + }; + + VkDevice handle; + VK_CHECK(vkCreateDevice( + physical_device.handle, &device_create_info, nullptr, &handle)); + +#ifdef USE_VULKAN_VOLK + volkLoadDevice(handle); +#endif /* USE_VULKAN_VOLK */ + + // Obtain handles for the created queues and initialize queue usage heuristic + + for (const std::pair& queue_idx : queues_to_get) { + VkQueue queue_handle = VK_NULL_HANDLE; + VkQueueFlags flags = + physical_device.queue_families[queue_idx.first].queueFlags; + vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle); + queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle}); + // Initial usage value + queue_usage.push_back(0); + } + + return handle; +} + // Print utils -// std::string get_device_type_str(const VkPhysicalDeviceType type) { - switch(type) { + switch (type) { case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU: return "INTEGRATED_GPU"; case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU: @@ -96,199 +237,53 @@ std::string get_queue_family_properties_str(const VkQueueFlags flags) { } // namespace -Adapter::Adapter(const VkPhysicalDevice handle, const uint32_t num_queues) - : physical_handle_(handle), - properties_{}, - memory_properties_{}, - queue_families_{}, - num_requested_queues_{num_queues}, - queue_usage_{}, - handle_(VK_NULL_HANDLE), - queues_{}, - num_compute_queues_{}, - has_unified_memory_{false}, - timestamp_compute_and_graphics_{false}, - timestamp_period_{0.f} { - // This should never happen, but double check to be safe - TORCH_CHECK( - VK_NULL_HANDLE != physical_handle_, - "Pytorch Vulkan Adapter: VK_NULL_HANDLE passed to Adapter constructor!") - - vkGetPhysicalDeviceProperties(physical_handle_, &properties_); - vkGetPhysicalDeviceMemoryProperties(physical_handle_, &memory_properties_); - - timestamp_compute_and_graphics_ = properties_.limits.timestampComputeAndGraphics; - timestamp_period_ = properties_.limits.timestampPeriod; - - // Check if there are any memory types have both the HOST_VISIBLE and the - // DEVICE_LOCAL property flags - const VkMemoryPropertyFlags unified_memory_flags = - VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - for (const uint32_t i : c10::irange(memory_properties_.memoryTypeCount)) { - if (memory_properties_.memoryTypes[i].propertyFlags | unified_memory_flags) { - has_unified_memory_ = true; - break; - } - } - - uint32_t queue_family_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties( - physical_handle_, &queue_family_count, nullptr); - - queue_families_.resize(queue_family_count); - vkGetPhysicalDeviceQueueFamilyProperties( - physical_handle_, &queue_family_count, queue_families_.data()); - - // Find the total number of compute queues - for (const uint32_t family_i : c10::irange(queue_families_.size())) { - const VkQueueFamilyProperties& properties = queue_families_[family_i]; - // Check if this family has compute capability - if (properties.queueFlags & VK_QUEUE_COMPUTE_BIT) { - num_compute_queues_ += properties.queueCount; - } - } +// +// DeviceHandle +// - queue_usage_.reserve(num_requested_queues_); - queues_.reserve(num_requested_queues_); -} +DeviceHandle::DeviceHandle(const VkDevice device) : handle_(device) {} -Adapter::Adapter(Adapter&& other) noexcept - : physical_handle_(other.physical_handle_), - properties_(other.properties_), - memory_properties_(other.memory_properties_), - queue_families_(std::move(other.queue_families_)), - num_requested_queues_(other.num_requested_queues_), - queue_usage_(std::move(other.queue_usage_)), - handle_(other.handle_), - queues_(std::move(other.queues_)), - num_compute_queues_(other.num_compute_queues_), - has_unified_memory_(other.has_unified_memory_), - timestamp_compute_and_graphics_(other.timestamp_compute_and_graphics_), - timestamp_period_(other.timestamp_period_) { - other.physical_handle_ = VK_NULL_HANDLE; +DeviceHandle::DeviceHandle(DeviceHandle&& other) noexcept + : handle_(other.handle_) { other.handle_ = VK_NULL_HANDLE; } -Adapter::~Adapter() { - if C10_LIKELY(VK_NULL_HANDLE == handle_) { +DeviceHandle::~DeviceHandle() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { return; } vkDestroyDevice(handle_, nullptr); - handle_ = VK_NULL_HANDLE; } -void Adapter::init_device() { - // It is possible that multiple threads will attempt to initialize the device - // simultaneously, so lock the mutex before initializing - std::lock_guard lock(mutex_); - - // Do not initialize the device if there are no compute queues available - TORCH_CHECK( - num_compute_queues_ > 0, - "Pytorch Vulkan Adapter: Cannot initialize Adapter as this device does not " - "have any queues that support compute!") - - // This device has already been initialized, no-op - if C10_LIKELY(VK_NULL_HANDLE != handle_) { - return; - } - - // - // Find compute queues up to the requested number of queues - // - - std::vector queue_create_infos; - queue_create_infos.reserve(num_requested_queues_); - - std::vector> queues_to_get; - queues_to_get.reserve(num_requested_queues_); - - uint32_t remaining_queues = num_requested_queues_; - for (const uint32_t family_i : c10::irange(queue_families_.size())) { - const VkQueueFamilyProperties& properties = queue_families_[family_i]; - // Check if this family has compute capability - if (properties.queueFlags & VK_QUEUE_COMPUTE_BIT) { - const uint32_t queues_to_init = std::min( - remaining_queues, properties.queueCount); - - const std::vector queue_priorities(queues_to_init, 1.0f); - queue_create_infos.push_back({ - VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - family_i, // queueFamilyIndex - queues_to_init, // queueCount - queue_priorities.data(), // pQueuePriorities - }); - - for (const uint32_t queue_i : c10::irange(queues_to_init)) { - // Use this to get the queue handle once device is created - queues_to_get.emplace_back(family_i, queue_i); - } - remaining_queues -= queues_to_init; - } - if (remaining_queues == 0) { - break; - } - } - - // - // Create the VkDevice - // - - std::vector requested_device_extensions { - #ifdef VK_KHR_portability_subset - VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, - #endif - }; - - std::vector enabled_device_extensions; - find_requested_device_extensions( - physical_handle_, enabled_device_extensions, requested_device_extensions); - - const VkDeviceCreateInfo device_create_info{ - VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - static_cast(queue_create_infos.size()), // queueCreateInfoCount - queue_create_infos.data(), // pQueueCreateInfos - 0u, // enabledLayerCount - nullptr, // ppEnabledLayerNames - static_cast(enabled_device_extensions.size()), // enabledExtensionCount - enabled_device_extensions.data(), // ppEnabledExtensionNames - nullptr, // pEnabledFeatures - }; - - const VkResult device_create_res = vkCreateDevice( - physical_handle_, &device_create_info, nullptr, &handle_); - // If device was not created successfully, ensure handle_ is invalid and throw - if (VK_SUCCESS != device_create_res) { - handle_ = VK_NULL_HANDLE; - VK_CHECK(device_create_res); - } - -#ifdef USE_VULKAN_VOLK - volkLoadDevice(handle_); -#endif - - // - // Obtain handles for the created queues and initialize queue usage heuristic - // +// +// Adapter +// - for (const std::pair& queue_idx : queues_to_get) { - VkQueue queue_handle = VK_NULL_HANDLE; - VkQueueFlags flags = queue_families_[queue_idx.first].queueFlags; - vkGetDeviceQueue( - handle_, queue_idx.first, queue_idx.second, &queue_handle); - queues_.push_back({queue_idx.first, queue_idx.second, flags, queue_handle}); - // Initial usage value - queue_usage_.push_back(0); - } -} +Adapter::Adapter( + const VkInstance instance, + const PhysicalDevice& physical_device, + const uint32_t num_queues) + : queue_usage_mutex_{}, + physical_device_(physical_device), + queues_{}, + queue_usage_{}, + queue_mutexes_{}, + instance_(instance), + device_(create_logical_device( + physical_device_, + num_queues, + queues_, + queue_usage_)), + shader_layout_cache_(device_.handle_), + shader_cache_(device_.handle_), + pipeline_layout_cache_(device_.handle_), + compute_pipeline_cache_(device_.handle_), + sampler_cache_(device_.handle_), + vma_(instance_, physical_device_.handle, device_.handle_) {} Adapter::Queue Adapter::request_queue() { // Lock the mutex as multiple threads can request a queue at the same time - std::lock_guard lock(mutex_); + std::lock_guard lock(queue_usage_mutex_); uint32_t min_usage = UINT32_MAX; uint32_t min_used_i = 0; @@ -307,36 +302,77 @@ void Adapter::return_queue(Adapter::Queue& compute_queue) { for (const uint32_t i : c10::irange(queues_.size())) { if ((queues_[i].family_index == compute_queue.family_index) && (queues_[i].queue_index == compute_queue.queue_index)) { - std::lock_guard lock(mutex_); + std::lock_guard lock(queue_usage_mutex_); queue_usage_[i] -= 1; break; } } } +void Adapter::submit_cmd( + const Adapter::Queue& device_queue, + const VkCommandBuffer cmd, + const VkFence fence) { + const VkSubmitInfo submit_info{ + VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType + nullptr, // pNext + 0u, // waitSemaphoreCount + nullptr, // pWaitSemaphores + nullptr, // pWaitDstStageMask + 1u, // commandBufferCount + &cmd, // pCommandBuffers + 0u, // signalSemaphoreCount + nullptr, // pSignalSemaphores + }; + + std::lock_guard queue_lock( + queue_mutexes_[device_queue.queue_index % NUM_QUEUE_MUTEXES]); + + VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence)); +} + +void Adapter::submit_cmds( + const Adapter::Queue& device_queue, + const std::vector& cmds, + const VkFence fence) { + const VkSubmitInfo submit_info{ + VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType + nullptr, // pNext + 0u, // waitSemaphoreCount + nullptr, // pWaitSemaphores + nullptr, // pWaitDstStageMask + utils::safe_downcast(cmds.size()), // commandBufferCount + cmds.data(), // pCommandBuffers + 0u, // signalSemaphoreCount + nullptr, // pSignalSemaphores + }; + + VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence)); +} + std::string Adapter::stringize() const { std::stringstream ss; - uint32_t v_major = VK_VERSION_MAJOR(properties_.apiVersion); - uint32_t v_minor = VK_VERSION_MINOR(properties_.apiVersion); - std::string device_type = get_device_type_str(properties_.deviceType); - VkPhysicalDeviceLimits limits = properties_.limits; + VkPhysicalDeviceProperties properties = physical_device_.properties; + uint32_t v_major = VK_VERSION_MAJOR(properties.apiVersion); + uint32_t v_minor = VK_VERSION_MINOR(properties.apiVersion); + std::string device_type = get_device_type_str(properties.deviceType); + VkPhysicalDeviceLimits limits = properties.limits; ss << "{" << std::endl; ss << " Physical Device Info {" << std::endl; ss << " apiVersion: " << v_major << "." << v_minor << std::endl; - ss << " driverversion: " << properties_.driverVersion << std::endl; + ss << " driverversion: " << properties.driverVersion << std::endl; ss << " deviceType: " << device_type << std::endl; - ss << " deviceName: " << properties_.deviceName << std::endl; + ss << " deviceName: " << properties.deviceName << std::endl; -#define PRINT_LIMIT_PROP(name) \ - ss << " " << std::left << std::setw(36) << #name << limits.name << std::endl; +#define PRINT_LIMIT_PROP(name) \ + ss << " " << std::left << std::setw(36) << #name << limits.name \ + << std::endl; -#define PRINT_LIMIT_PROP_VEC3(name) \ - ss << " " << std::left << std::setw(36) << #name \ - << limits.name[0] << "," \ - << limits.name[1] << "," \ - << limits.name[2] << std::endl; +#define PRINT_LIMIT_PROP_VEC3(name) \ + ss << " " << std::left << std::setw(36) << #name << limits.name[0] \ + << "," << limits.name[1] << "," << limits.name[2] << std::endl; ss << " Physical Device Limits {" << std::endl; PRINT_LIMIT_PROP(maxImageDimension1D); @@ -351,36 +387,43 @@ std::string Adapter::stringize() const { PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations); PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize); ss << " }" << std::endl; - ss << " }" << std::endl;; + ss << " }" << std::endl; + ; + + const VkPhysicalDeviceMemoryProperties& mem_props = + physical_device_.memory_properties; - const VkPhysicalDeviceMemoryProperties& mem_props = memory_properties_; ss << " Memory Info {" << std::endl; ss << " Memory Types [" << std::endl; for (const auto i : c10::irange(mem_props.memoryTypeCount)) { - ss << " " << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] " - << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags) - << std::endl; + ss << " " + << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] " + << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags) + << std::endl; } ss << " ]" << std::endl; ss << " Memory Heaps [" << std::endl; for (const auto i : c10::irange(mem_props.memoryHeapCount)) { - ss << " " << mem_props.memoryHeaps[i].size << std::endl; + ss << " " << mem_props.memoryHeaps[i].size << std::endl; } ss << " ]" << std::endl; ss << " }" << std::endl; ss << " Queue Families {" << std::endl; - for (const VkQueueFamilyProperties& queue_family_props : queue_families_) { - ss << " (" << queue_family_props.queueCount << " Queues) " - << get_queue_family_properties_str(queue_family_props.queueFlags) << std::endl; + for (const VkQueueFamilyProperties& queue_family_props : + physical_device_.queue_families) { + ss << " (" << queue_family_props.queueCount << " Queues) " + << get_queue_family_properties_str(queue_family_props.queueFlags) + << std::endl; } ss << " }" << std::endl; - ss << " VkDevice: " << handle_ << std::endl; + ss << " VkDevice: " << device_.handle_ << std::endl; ss << " Compute Queues [" << std::endl; for (const Adapter::Queue& compute_queue : queues_) { - ss << " Family " << compute_queue.family_index - << ", Queue " << compute_queue.queue_index - << ": " << compute_queue.handle << std::endl;; + ss << " Family " << compute_queue.family_index << ", Queue " + << compute_queue.queue_index << ": " << compute_queue.handle + << std::endl; + ; } ss << " ]" << std::endl; ss << "}"; diff --git a/aten/src/ATen/native/vulkan/api/Adapter.h b/aten/src/ATen/native/vulkan/api/Adapter.h index a7aa29cc5baa5..8a43a4581edba 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.h +++ b/aten/src/ATen/native/vulkan/api/Adapter.h @@ -3,16 +3,52 @@ #ifdef USE_VULKAN_API #include -#include +#include #include +#include #include -#include namespace at { namespace native { namespace vulkan { namespace api { +struct PhysicalDevice final { + // Handle + VkPhysicalDevice handle; + + // Properties obtained from Vulkan + VkPhysicalDeviceProperties properties; + VkPhysicalDeviceMemoryProperties memory_properties; + std::vector queue_families; + + // Metadata + uint32_t num_compute_queues; + bool has_unified_memory; + bool has_timestamps; + float timestamp_period; + + explicit PhysicalDevice(const VkPhysicalDevice); +}; + +class DeviceHandle final { + public: + explicit DeviceHandle(const VkDevice device); + + DeviceHandle(const DeviceHandle&) = delete; + DeviceHandle& operator=(const DeviceHandle&) = delete; + + DeviceHandle(DeviceHandle&&) noexcept; + DeviceHandle& operator=(DeviceHandle&&) = delete; + + ~DeviceHandle(); + + private: + VkDevice handle_; + + friend class Adapter; +}; + // // A Vulkan Adapter represents a logical device and all its properties. It // manages all relevant properties of the underlying physical device, a @@ -21,9 +57,10 @@ namespace api { // which points to the logical device object on the GPU. // // This class is primarily used by the Runtime class, which holds one Adapter -// instance for each physical device visible to the VkInstance. Upon construction, -// this class will populate the physical device properties, but will not create -// the logical device until specifically requested via the init_device() funtion. +// instance for each physical device visible to the VkInstance. Upon +// construction, this class will populate the physical device properties, but +// will not create the logical device until specifically requested via the +// init_device() funtion. // // init_device() will create the logical device and obtain the VkDevice handle // for it. It will also create a number of compute queues up to the amount @@ -36,17 +73,22 @@ namespace api { // separate threads) to run concurrently. // +#define NUM_QUEUE_MUTEXES 4 + class Adapter final { public: - explicit Adapter(const VkPhysicalDevice handle, const uint32_t num_queues); + explicit Adapter( + const VkInstance instance, + const PhysicalDevice& physical_device, + const uint32_t num_queues); Adapter(const Adapter&) = delete; Adapter& operator=(const Adapter&) = delete; - Adapter(Adapter&&) noexcept; + Adapter(Adapter&&) = delete; Adapter& operator=(Adapter&&) = delete; - ~Adapter(); + ~Adapter() = default; struct Queue { uint32_t family_index; @@ -56,61 +98,111 @@ class Adapter final { }; private: - // Use a mutex to manage resources held by this class since + // Use a mutex to manage queue usage info since // it can be accessed from multiple threads - std::mutex mutex_; - // Physical Device Properties - VkPhysicalDevice physical_handle_; - VkPhysicalDeviceProperties properties_; - VkPhysicalDeviceMemoryProperties memory_properties_; - std::vector queue_families_; + std::mutex queue_usage_mutex_; + // Physical Device Info + PhysicalDevice physical_device_; // Queue Management - uint32_t num_requested_queues_; + std::vector queues_; std::vector queue_usage_; + std::array queue_mutexes_; // Handles - VkDevice handle_; - std::vector queues_; - // Metadata - uint32_t num_compute_queues_; - bool has_unified_memory_; - bool timestamp_compute_and_graphics_; - float timestamp_period_; + VkInstance instance_; + DeviceHandle device_; + // Device-level resource caches + ShaderLayoutCache shader_layout_cache_; + ShaderCache shader_cache_; + PipelineLayoutCache pipeline_layout_cache_; + ComputePipelineCache compute_pipeline_cache_; + // Memory Management + SamplerCache sampler_cache_; + MemoryAllocator vma_; public: + // Physical Device metadata + inline VkPhysicalDevice physical_handle() const { - return physical_handle_; + return physical_device_.handle; } inline VkDevice device_handle() const { - return handle_; + return device_.handle_; } inline bool has_unified_memory() const { - return has_unified_memory_; + return physical_device_.has_unified_memory; } inline uint32_t num_compute_queues() const { - return num_compute_queues_; + return physical_device_.num_compute_queues; } inline bool timestamp_compute_and_graphics() const { - return timestamp_compute_and_graphics_; + return physical_device_.has_timestamps; } inline float timestamp_period() const { - return timestamp_period_; + return physical_device_.timestamp_period; } - void init_device(); + // Queue Management + Queue request_queue(); - void return_queue(Queue& compute_queue); + void return_queue(Queue&); + + // Caches + + inline ShaderLayoutCache& shader_layout_cache() { + return shader_layout_cache_; + } + + inline ShaderCache& shader_cache() { + return shader_cache_; + } + + inline PipelineLayoutCache& pipeline_layout_cache() { + return pipeline_layout_cache_; + } + + inline ComputePipelineCache& compute_pipeline_cache() { + return compute_pipeline_cache_; + } + + // Memory Allocation + + inline SamplerCache& sampler_cache() { + return sampler_cache_; + } + + inline MemoryAllocator& vma() { + return vma_; + } + + // Command Buffer Submission + + void submit_cmd( + const Queue&, + const VkCommandBuffer, + const VkFence fence = VK_NULL_HANDLE); + + void submit_cmds( + const Adapter::Queue&, + const std::vector&, + const VkFence fence = VK_NULL_HANDLE); + + // Miscellaneous - inline Shader::WorkGroup local_work_group_size() const { - return { 4u, 4u, 4u, }; + inline utils::uvec3 local_work_group_size() const { + return { + 4u, + 4u, + 4u, + }; } std::string stringize() const; - friend std::ostream& operator<<(std::ostream& os, const Adapter& adapter); + friend std::ostream& operator<<(std::ostream&, const Adapter&); }; } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Allocator.h b/aten/src/ATen/native/vulkan/api/Allocator.h index 022a8bcc91975..470eb07543c24 100644 --- a/aten/src/ATen/native/vulkan/api/Allocator.h +++ b/aten/src/ATen/native/vulkan/api/Allocator.h @@ -12,43 +12,45 @@ #define VMA_VULKAN_VERSION 1000000 #ifdef USE_VULKAN_WRAPPER - #define VMA_STATIC_VULKAN_FUNCTIONS 0 +#define VMA_STATIC_VULKAN_FUNCTIONS 0 #else - #define VMA_DYNAMIC_VULKAN_FUNCTIONS 0 -#endif +#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0 +#endif /* USE_VULKAN_WRAPPER */ #define VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE (32ull * 1024 * 1024) #define VMA_SMALL_HEAP_MAX_SIZE (256ull * 1024 * 1024) +#define VMA_STATS_STRING_ENABLED 0 + #ifdef DEBUG - #define VMA_DEBUG_ALIGNMENT 4096 - #define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY 0 - #define VMA_DEBUG_DETECT_CORRUPTION 1 - #define VMA_DEBUG_GLOBAL_MUTEX 1 - #define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 - #define VMA_DEBUG_MARGIN 64 - #define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY 256 - #define VMA_RECORDING_ENABLED 1 - - #define VMA_DEBUG_LOG(format, ...) - /* - #define VMA_DEBUG_LOG(format, ...) do { \ - printf(format, __VA_ARGS__); \ - printf("\n"); \ - } while(false) - */ +#define VMA_DEBUG_ALIGNMENT 4096 +#define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY 0 +#define VMA_DEBUG_DETECT_CORRUPTION 1 +#define VMA_DEBUG_GLOBAL_MUTEX 1 +#define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 +#define VMA_DEBUG_MARGIN 64 +#define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY 256 +#define VMA_RECORDING_ENABLED 1 + +#define VMA_DEBUG_LOG(format, ...) +/* +#define VMA_DEBUG_LOG(format, ...) do { \ + printf(format, __VA_ARGS__); \ + printf("\n"); \ +} while(false) +*/ #endif /* DEBUG */ #ifdef __clang__ - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wnullability-completeness" - #pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnullability-completeness" +#pragma clang diagnostic ignored "-Wunused-variable" #endif /* __clang__ */ #include #ifdef __clang__ - #pragma clang diagnostic pop +#pragma clang diagnostic pop #endif /* __clang__ */ #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Cache.h b/aten/src/ATen/native/vulkan/api/Cache.h deleted file mode 100644 index a93385088277d..0000000000000 --- a/aten/src/ATen/native/vulkan/api/Cache.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#ifdef USE_VULKAN_API - -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -// -// A generic cache for immutable Vulkan objects, when there will not be many -// instances of those objects required at runtime. The previous sentence puts -// two constraints on proper use of this cache: 1) First, the objects should -// preferably be immutable otherwise much care is required to synchronize -// their usage. 2) Second, this cache is only intended for objects that -// we will not have many instances of during the entire execution of the -// program, otherwise the cache must be _infrequently_ purged. Proper usage -// model for this cache is in direct contrast with Vulkan object pools, which -// indeed are required to be _frequently_ purged. That is an important -// distinction. -// - -template -class Cache final { - public: - explicit Cache(Factory factory); - Cache(const Cache&) = delete; - Cache& operator=(const Cache&) = delete; - Cache(Cache&&) = default; - Cache& operator=(Cache&&) = default; - ~Cache() = default; - - // Factory must have the following symbols defined. - - typedef typename Factory::Descriptor Descriptor; - typedef typename Factory::Handle Handle; - typedef typename Factory::Hasher Hasher; - - // Create or retrieve a resource. - // - // This operation is a simple cache lookup and returns the Handle corresponding - // to the descriptor if the object is already present in the cache. Otherwise, - // Factory is used to create the object, after which point the object is added - // to the cache. Regardless, this function returns with the object in the cache. - - auto retrieve(const Descriptor& descriptor); - - // Only call this function infrequently, if ever. This cache is only intended - // for immutable Vulkan objects of which a small finite instances are required - // at runtime. A good place to call this function is between model loads. - - void purge(); - - private: - struct Configuration final { - static constexpr uint32_t kReserve = 64u; - }; - - ska::flat_hash_map cache_; - Factory factory_; -}; - -// -// Impl -// - -template -inline Cache::Cache(Factory factory) - : factory_(std::move(factory)) { - cache_.reserve(Configuration::kReserve); -} - -template -inline auto Cache::retrieve( - const Descriptor& descriptor) { - auto iterator = cache_.find(descriptor); - if C10_UNLIKELY(cache_.cend() == iterator) { - iterator = cache_.insert({descriptor, factory_(descriptor)}).first; - } - - return iterator->second.get(); -} - -template -inline void Cache::purge() { - cache_.clear(); -} - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index 7f32681a6f5ba..b2c63ee4399f5 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include @@ -8,506 +8,323 @@ namespace at { namespace native { namespace vulkan { namespace api { -namespace { -std::mutex queue_mutex; - -VkCommandPool create_command_pool( - const VkDevice device, - const uint32_t queue_family_index) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); - - const VkCommandPoolCreateInfo command_pool_create_info{ - VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, - nullptr, - VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, - queue_family_index, - }; - - VkCommandPool command_pool{}; - VK_CHECK(vkCreateCommandPool( - device, - &command_pool_create_info, - nullptr, - &command_pool)); - - TORCH_CHECK( - command_pool, - "Invalid Vulkan command pool!"); - - return command_pool; +// +// CommandBuffer +// + +CommandBuffer::CommandBuffer( + const VkCommandBuffer handle, + const VkCommandBufferUsageFlags flags) + : handle_(handle), + flags_(flags), + state_(CommandBuffer::State::NEW), + bound_{} {} + +CommandBuffer::CommandBuffer(CommandBuffer&& other) noexcept + : handle_(other.handle_), + flags_(other.flags_), + state_(other.state_), + bound_(other.bound_) { + other.handle_ = VK_NULL_HANDLE; + other.bound_.reset(); + state_ = CommandBuffer::State::INVALID; } -void allocate_command_buffers( - const VkDevice device, - const VkCommandPool command_pool, - VkCommandBuffer* const command_buffers, - const uint32_t count) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_pool, - "Invalid Vulkan command pool!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffers && (count > 0u), - "Invalid usage!"); - - const VkCommandBufferAllocateInfo command_buffer_allocate_info{ - VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, - nullptr, - command_pool, - VK_COMMAND_BUFFER_LEVEL_PRIMARY, - count, - }; +CommandBuffer& CommandBuffer::operator=(CommandBuffer&& other) noexcept { + handle_ = other.handle_; + flags_ = other.flags_; + state_ = other.state_; + bound_ = other.bound_; - VK_CHECK(vkAllocateCommandBuffers( - device, - &command_buffer_allocate_info, - command_buffers)); -} - -} // namespace - -Command::Buffer::Buffer(const VkCommandBuffer command_buffer) - : command_buffer_(command_buffer) { -} - -Command::Buffer::Buffer(Buffer&& buffer) - : command_buffer_(std::move(buffer.command_buffer_)), - bound_(std::move(buffer.bound_)), - barriers_(std::move(buffer.barriers_)) { - buffer.invalidate(); -} - -Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { - if (&buffer != this) { - command_buffer_ = std::move(buffer.command_buffer_); - bound_ = std::move(buffer.bound_); - barriers_ = std::move(buffer.barriers_); - - buffer.invalidate(); - }; + other.handle_ = VK_NULL_HANDLE; + other.bound_.reset(); + other.state_ = CommandBuffer::State::INVALID; return *this; } -void Command::Buffer::Buffer::begin() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); - - const VkCommandBufferBeginInfo command_buffer_begin_info{ - VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, - nullptr, - VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, - nullptr, - }; - - VK_CHECK(vkBeginCommandBuffer( - command_buffer_, - &command_buffer_begin_info)); - - // Reset - bound_.reset(); - barriers_.reset(); -} +void CommandBuffer::begin() { + TORCH_CHECK( + state_ == CommandBuffer::State::NEW, + "Vulkan CommandBuffer: called begin() on a command buffer whose state " + "is not NEW."); -void Command::Buffer::Buffer::end() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); + const VkCommandBufferBeginInfo begin_info{ + VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, + nullptr, + flags_, + nullptr, + }; - VK_CHECK(vkEndCommandBuffer(command_buffer_)); + VK_CHECK(vkBeginCommandBuffer(handle_, &begin_info)); + state_ = CommandBuffer::State::RECORDING; } -void Command::Buffer::barrier(const Pipeline::Barrier& barrier) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); - - barriers_.stage.src |= barrier.stage.src; - barriers_.stage.dst |= barrier.stage.dst; - - barriers_.buffers.insert( - barriers_.buffers.end(), - barrier.buffers.begin(), - barrier.buffers.end()); +void CommandBuffer::end() { + TORCH_CHECK( + state_ == CommandBuffer::State::RECORDING, + "Vulkan CommandBuffer: called end() on a command buffer whose state " + "is not RECORDING."); - barriers_.images.insert( - barriers_.images.end(), - barrier.images.begin(), - barrier.images.end()); + VK_CHECK(vkEndCommandBuffer(handle_)); + state_ = CommandBuffer::State::READY; } -void Command::Buffer::bind(const Pipeline::Object& pipeline) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pipeline, - "Invalid Vulkan pipeline!"); +void CommandBuffer::bind_pipeline( + const VkPipeline pipeline, + const VkPipelineLayout pipeline_layout, + const utils::uvec3 local_workgroup_size) { + TORCH_CHECK( + state_ == CommandBuffer::State::RECORDING, + "Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state " + "is not RECORDING."); - if (pipeline.handle != bound_.pipeline.handle) { - vkCmdBindPipeline( - command_buffer_, - VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline.handle); + if (pipeline != bound_.pipeline) { + vkCmdBindPipeline(handle_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); bound_.pipeline = pipeline; } -} -void Command::Buffer::bind(const Descriptor::Set& set) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); + bound_.pipeline_layout = pipeline_layout; + bound_.local_workgroup_size = local_workgroup_size; - const VkDescriptorSet descriptor_set = set.handle(); + state_ = CommandBuffer::State::PIPELINE_BOUND; +} - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_set, - "Invalid Vulkan descriptor set!"); +void CommandBuffer::bind_descriptors(const VkDescriptorSet descriptors) { + TORCH_CHECK( + state_ == CommandBuffer::State::PIPELINE_BOUND, + "Vulkan CommandBuffer: called bind_descriptors() on a command buffer whose state " + "is not PIPELINE_BOUND."); - if (descriptor_set != bound_.descriptor_set) { + if (descriptors != bound_.descriptors) { vkCmdBindDescriptorSets( - command_buffer_, - VK_PIPELINE_BIND_POINT_COMPUTE, - bound_.pipeline.layout, - 0u, - 1u, - &descriptor_set, - 0u, - nullptr); - - bound_.descriptor_set = descriptor_set; + handle_, // commandBuffer + VK_PIPELINE_BIND_POINT_COMPUTE, // pipelineBindPoint + bound_.pipeline_layout, // layout + 0u, // firstSet + 1u, // descriptorSetCount + &descriptors, // pDescriptorSets + 0u, // dynamicOffsetCount + nullptr); // pDynamicOffsets } -} -void Command::Buffer::copy( - const Resource::Buffer::Object source, - const Resource::Buffer::Object destination) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); + bound_.descriptors = descriptors; - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - source, - "Invalid Vulkan source buffer!"); + state_ = CommandBuffer::State::DESCRIPTORS_BOUND; +} - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - destination, - "Invalid Vulkan destination buffer!"); +void CommandBuffer::insert_barrier(const PipelineBarrier& pipeline_barrier) { + TORCH_CHECK( + state_ == CommandBuffer::State::DESCRIPTORS_BOUND || + state_ == CommandBuffer::State::RECORDING, + "Vulkan CommandBuffer: called insert_barrier() on a command buffer whose state " + "is not DESCRIPTORS_BOUND or RECORDING."); - barrier(); + if (pipeline_barrier) { + c10::SmallVector buffer_memory_barriers; + for (const api::BufferMemoryBarrier& memory_barrier : + pipeline_barrier.buffers) { + buffer_memory_barriers.push_back(memory_barrier.handle); + } - const VkBufferCopy buffer_copy{ - 0u, - 0u, - std::min(source.range, destination.range), - }; + c10::SmallVector image_memory_barriers; + for (const api::ImageMemoryBarrier& memory_barrier : + pipeline_barrier.images) { + image_memory_barriers.push_back(memory_barrier.handle); + } - vkCmdCopyBuffer( - command_buffer_, - source.handle, - destination.handle, - 1u, - &buffer_copy); -} + vkCmdPipelineBarrier( + handle_, // commandBuffer + pipeline_barrier.stage.src, // srcStageMask + pipeline_barrier.stage.dst, // dstStageMask + 0u, // dependencyFlags + 0u, // memoryBarrierCount + nullptr, // pMemoryBarriers + buffer_memory_barriers.size(), // bufferMemoryBarrierCount + buffer_memory_barriers.data(), // pMemoryBarriers + image_memory_barriers.size(), // imageMemoryBarrierCount + image_memory_barriers.data()); // pImageMemoryBarriers + } -void Command::Buffer::dispatch( - const Shader::WorkGroup& global_work_group) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); + state_ = CommandBuffer::State::BARRIERS_INSERTED; +} - barrier(); +void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) { + TORCH_CHECK( + state_ == CommandBuffer::State::BARRIERS_INSERTED, + "Vulkan CommandBuffer: called dispatch() on a command buffer whose state " + "is not BARRIERS_INSERTED."); vkCmdDispatch( - command_buffer_, + handle_, utils::div_up( - global_work_group.data[0u], - bound_.pipeline.local_work_group.data[0u]), + global_workgroup_size.data[0u], bound_.local_workgroup_size.data[0u]), utils::div_up( - global_work_group.data[1u], - bound_.pipeline.local_work_group.data[1u]), + global_workgroup_size.data[1u], bound_.local_workgroup_size.data[1u]), utils::div_up( - global_work_group.data[2u], - bound_.pipeline.local_work_group.data[2u])); + global_workgroup_size.data[2u], + bound_.local_workgroup_size.data[2u])); + + state_ = CommandBuffer::State::RECORDING; } -void Command::Buffer::barrier() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer_, - "This command buffer is in an invalid state! " - "Potential reason: This command buffer is moved from."); +void CommandBuffer::copy_texture_to_texture( + const api::VulkanImage& source, + const api::VulkanImage& destination, + const api::utils::uvec3& copy_range, + const api::utils::uvec3& src_offset, + const api::utils::uvec3& dst_offset) { + TORCH_CHECK( + state_ == CommandBuffer::State::BARRIERS_INSERTED, + "Vulkan CommandBuffer: called copy_texture_to_texture() on a command buffer whose state " + "is not BARRIERS_INSERTED."); + + const VkImageSubresourceLayers src_subresource_layers{ + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // mipLevel + 0u, // baseArrayLayer + 1u, // layerCount + }; - if (barriers_.stage) { - c10::SmallVector buffer_memory_barriers; + const VkImageSubresourceLayers dst_subresource_layers{ + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // mipLevel + 0u, // baseArrayLayer + 1u, // layerCount + }; - for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { - buffer_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - barrier.object.offset, - barrier.object.range, - }); - } + const VkImageCopy copy_details{ + src_subresource_layers, // srcSubresource + create_offset3d(src_offset), // srcOffset + dst_subresource_layers, // dstSubresource + create_offset3d(dst_offset), // dstOffset + create_extent3d(copy_range), // extent + }; - c10::SmallVector image_memory_barriers; + vkCmdCopyImage( + handle_, + source.handle(), + source.layout(), + destination.handle(), + destination.layout(), + 1u, + ©_details); - for (const Resource::Image::Barrier& barrier : barriers_.images) { - image_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - barrier.layout.src, - barrier.layout.dst, - VK_QUEUE_FAMILY_IGNORED, - VK_QUEUE_FAMILY_IGNORED, - barrier.object.handle, - { - VK_IMAGE_ASPECT_COLOR_BIT, - 0u, - VK_REMAINING_MIP_LEVELS, - 0u, - VK_REMAINING_ARRAY_LAYERS, - }, - }); - } + state_ = CommandBuffer::State::RECORDING; +} - vkCmdPipelineBarrier( - command_buffer_, - barriers_.stage.src, - barriers_.stage.dst, - 0u, - 0u, - nullptr, - buffer_memory_barriers.size(), - buffer_memory_barriers.data(), - image_memory_barriers.size(), - image_memory_barriers.data()); - } +void CommandBuffer::write_timestamp( + const VkQueryPool querypool, + const uint32_t idx) const { + TORCH_CHECK( + state_ == CommandBuffer::State::RECORDING, + "Vulkan CommandBuffer: called write_timestamp() on a command buffer whose state " + "is not RECORDING."); - // Reset - barriers_.reset(); + vkCmdWriteTimestamp( + handle_, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, querypool, idx); } -void Command::Buffer::invalidate() { - command_buffer_ = VK_NULL_HANDLE; -} +void CommandBuffer::reset_querypool( + const VkQueryPool querypool, + const uint32_t first_idx, + const uint32_t count) const { + TORCH_CHECK( + state_ == CommandBuffer::State::RECORDING, + "Vulkan CommandBuffer: called reset_querypool() on a command buffer whose state " + "is not RECORDING."); -inline void Command::Buffer::Bound::reset() { - pipeline = {}; - descriptor_set = VK_NULL_HANDLE; + vkCmdResetQueryPool(handle_, querypool, first_idx, count); } -inline Command::Buffer::Barrier::Stage::operator bool() const { - return (0u != src) || (0u != dst); -} +VkCommandBuffer CommandBuffer::get_submit_handle() { + TORCH_CHECK( + state_ == CommandBuffer::State::READY, + "Vulkan CommandBuffer: called begin() on a command buffer whose state " + "is not READY."); -inline void Command::Buffer::Barrier::reset() { - stage = {}; - buffers.clear(); - images.clear(); -} + const VkCommandBuffer handle = handle_; -Command::Pool::Pool(const GPU& gpu) - : device_(gpu.device), - command_pool_( - create_command_pool(gpu.device, gpu.queue_family_index), - VK_DELETER(CommandPool)(device_)), - buffer_{} { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_pool_, - "Invalid Vulkan command pool!"); - - buffer_.pool.reserve(Configuration::kReserve); -} + handle_ = VK_NULL_HANDLE; + bound_.reset(); + state_ = CommandBuffer::State::SUBMITTED; -Command::Pool::Pool(Pool&& pool) - : device_(std::move(pool.device_)), - command_pool_(std::move(pool.command_pool_)), - buffer_(std::move(pool.buffer_)), - stream_(std::move(pool.stream_)) { - pool.invalidate(); + return handle; } -Command::Pool& Command::Pool::operator=(Pool&& pool) { - if (&pool != this) { - device_ = std::move(pool.device_); - command_pool_ = std::move(pool.command_pool_); - buffer_ = std::move(pool.buffer_); - stream_ = std::move(pool.stream_); +// +// CommandPool +// - pool.invalidate(); +CommandPool::CommandPool( + const VkDevice device, + const uint32_t queue_family_idx, + const CommandPoolConfig& config) + : device_(device), + queue_family_idx_(queue_family_idx), + pool_(VK_NULL_HANDLE), + config_(config), + mutex_{}, + buffers_{}, + in_use_(0u) { + const VkCommandPoolCreateInfo create_info{ + VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, + nullptr, + VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, + queue_family_idx_, }; - return *this; -} + VK_CHECK(vkCreateCommandPool(device_, &create_info, nullptr, &pool_)); -Command::Pool::~Pool() { - try { - if (device_ && command_pool_) { - purge(); - } - } - catch (const std::exception& e) { - TORCH_WARN( - "Vulkan: Command pool destructor raised an exception! Error: ", - e.what()); - } - catch (...) { - TORCH_WARN( - "Vulkan: Command pool destructor raised an exception! " - "Error: Unknown"); - } + // Pre-allocate some command buffers + allocate_new_batch(config_.cmdPoolInitialSize); } -Command::Buffer Command::Pool::allocate() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && command_pool_, - "This command pool is in an invalid state! " - "Potential reason: This command pool is moved from."); - - if (buffer_.pool.size() == buffer_.in_use) { - buffer_.pool.resize( - buffer_.pool.size() + - Configuration::kQuantum); - - allocate_command_buffers( - device_, - command_pool_.get(), - buffer_.pool.data() + buffer_.in_use, - Configuration::kQuantum); +CommandPool::~CommandPool() { + if (VK_NULL_HANDLE == pool_) { + return; } - - return Buffer(buffer_.pool[buffer_.in_use++]); + vkDestroyCommandPool(device_, pool_, nullptr); } -Command::Buffer& Command::Pool::stream() { - if (!stream_.buffer) { - stream_.buffer = allocate(); - stream_.buffer.begin(); - stream_.counter = 0u; - } - - return stream_.buffer; -} +CommandBuffer CommandPool::get_new_cmd() { + std::lock_guard lock(mutex_); -void Command::Pool::purge() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && command_pool_, - "This command pool is in an invalid state! " - "Potential reason: This command pool is moved from."); + // No-ops if there are command buffers available + allocate_new_batch(config_.cmdPoolBatchSize); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !stream_.buffer, - "Pending command buffer detected. Make sure all command buffers are " - "submitted to the queue for execution prior to reclaiming pool memory."); + const VkCommandBuffer handle = buffers_[in_use_]; - buffer_.in_use = 0u; - VK_CHECK(vkResetCommandPool(device_, command_pool_.get(), 0u)); + in_use_++; + return CommandBuffer(handle); } -void Command::Pool::submit( - const VkQueue queue, - const c10::ArrayRef buffers, - const Resource::Fence fence) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && command_pool_, - "This command pool is in an invalid state! " - "Potential reason: This command pool is moved from."); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - queue, - "Invalid Vulkan queue!"); - - c10::SmallVector command_buffers; - command_buffers.reserve(buffers.size()); - - for (const Buffer& buffer : buffers) { - VkCommandBuffer command_buffer = buffer.handle(); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_buffer, - "Invalid Vulkan command buffer!"); - - // Are we submitting our one and only command stream, or a regular command - // buffer whose scope is manually maintained by the user? Automatically - // maintain state and submission rate if the former. - - if (stream_.buffer.handle() == command_buffer) { - // Hand the stream off to the driver if: - // - The user has implictly signaled interest in the results via a fence. - // - We are over the submission cutoff. We don't want to starve the GPU. - - if (fence || (stream_.counter++ > Configuration::kSubmit)) { - stream_.buffer.end(); - stream_.buffer.invalidate(); - } - // Skip - Accumulate more calls prior to submission. - else { - command_buffer = VK_NULL_HANDLE; - } - } +void CommandPool::flush() { + std::lock_guard lock(mutex_); + VK_CHECK(vkResetCommandPool(device_, pool_, 0u)); + in_use_ = 0u; +} - if (command_buffer) { - command_buffers.push_back(command_buffer); - } +void CommandPool::allocate_new_batch(const uint32_t count) { + // No-ops if there are still command buffers availble + if (in_use_ < buffers_.size()) { + return; } - if (!command_buffers.empty()) { - const VkSubmitInfo submit_info{ - VK_STRUCTURE_TYPE_SUBMIT_INFO, - nullptr, - 0u, - nullptr, - nullptr, - utils::safe_downcast(command_buffers.size()), - command_buffers.data(), - 0u, - nullptr, - }; - - { - // vkQueueSubmit is not thread-safe, only one thread can push the commands at a time. - // (See https://vkguide.dev/docs/chapter-1/vulkan_command_flow/#vulkan-command-execution) - // The number of available queues depends on GPU. It could be 1 and we cannot assume we can create multiple queues. - // Thus, we need to avoid calling vkQueueSubmit from multiple threads at the same time. - // When running Vulkan backend in different threads without any locking mechanism, - // vkQueueSubmit will get the VK_ERROR_INITIALIZATION_FAILED(-3) error. - std::lock_guard guard(queue_mutex); - VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); - } - } -} + buffers_.resize(buffers_.size() + count); -void Command::Pool::invalidate() { - device_ = VK_NULL_HANDLE; - command_pool_.reset(); + const VkCommandBufferAllocateInfo allocate_info{ + VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, // sType + nullptr, // pNext + pool_, // commandPool + VK_COMMAND_BUFFER_LEVEL_PRIMARY, // level + count, // commandBufferCount + }; + + VK_CHECK(vkAllocateCommandBuffers( + device_, &allocate_info, buffers_.data() + in_use_)); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index 3b656a30868d2..6bb9f49e95656 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -14,127 +14,128 @@ namespace native { namespace vulkan { namespace api { -struct Command final { - class Pool; +class CommandBuffer final { + public: + explicit CommandBuffer( + const VkCommandBuffer, + const VkCommandBufferUsageFlags = + VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT); + + CommandBuffer(const CommandBuffer&) = delete; + CommandBuffer& operator=(const CommandBuffer&) = delete; + + CommandBuffer(CommandBuffer&&) noexcept; + CommandBuffer& operator=(CommandBuffer&&) noexcept; + + ~CommandBuffer() = default; + + // The lifecycle of a command buffer is as follows: + enum State { + INVALID, // Used to indicate the command buffer is moved from + NEW, // Set during constructor + RECORDING, // Set during call to begin(), dispatch(), and + // copy_texture_to_texture() + PIPELINE_BOUND, // Set during call to bind_pipeline() + DESCRIPTORS_BOUND, // Set during call to bind_descriptors() + BARRIERS_INSERTED, // Set during call to insert_barrier() + READY, // Set during call to end() + SUBMITTED, // Set during call to get_submit_handle() + }; - // - // Buffer - // + struct Bound { + VkPipeline pipeline; + VkPipelineLayout pipeline_layout; + utils::uvec3 local_workgroup_size; + VkDescriptorSet descriptors; + + explicit Bound() + : pipeline{VK_NULL_HANDLE}, + pipeline_layout{VK_NULL_HANDLE}, + local_workgroup_size{0u, 0u, 0u}, + descriptors{VK_NULL_HANDLE} {} + + inline void reset() { + pipeline = VK_NULL_HANDLE; + pipeline_layout = VK_NULL_HANDLE; + local_workgroup_size = {0u, 0u, 0u}; + descriptors = VK_NULL_HANDLE; + } + }; - class Buffer final { - public: - explicit Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; - Buffer(Buffer&&); - Buffer& operator=(Buffer&&); - ~Buffer() = default; + private: + VkCommandBuffer handle_; + VkCommandBufferUsageFlags flags_; + State state_; + Bound bound_; - operator bool() const; - VkCommandBuffer handle() const; + public: + void begin(); + void end(); - void begin(); - void end(); + void bind_pipeline( + const VkPipeline, + const VkPipelineLayout, + const utils::uvec3); + void bind_descriptors(const VkDescriptorSet); - void barrier(const Pipeline::Barrier& barrier); - void bind(const Pipeline::Object& pipeline); - void bind(const Descriptor::Set& set); - void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination); - void dispatch(const Shader::WorkGroup& global_work_group); + void insert_barrier(const PipelineBarrier& pipeline_barrier); + void dispatch(const utils::uvec3&); - private: - friend class Pool; + void copy_texture_to_texture( + const api::VulkanImage&, + const api::VulkanImage&, + const api::utils::uvec3&, + const api::utils::uvec3&, + const api::utils::uvec3&); - void barrier(); - void invalidate(); + void write_timestamp(const VkQueryPool, const uint32_t) const; + void reset_querypool(const VkQueryPool, const uint32_t, const uint32_t) const; - private: - VkCommandBuffer command_buffer_; + VkCommandBuffer get_submit_handle(); - struct Bound final { - Pipeline::Object pipeline; - VkDescriptorSet descriptor_set; + inline operator bool() const { + return VK_NULL_HANDLE != handle_; + } +}; - void reset(); - } bound_; +struct CommandPoolConfig final { + uint32_t cmdPoolInitialSize; + uint32_t cmdPoolBatchSize; +}; - struct Barrier final { - struct Stage final { - VkPipelineStageFlags src; - VkPipelineStageFlags dst; +class CommandPool final { + public: + explicit CommandPool( + const VkDevice, + const uint32_t, + const CommandPoolConfig&); - operator bool() const; - } stage; + CommandPool(const CommandPool&) = delete; + CommandPool& operator=(const CommandPool&) = delete; - c10::SmallVector buffers; - c10::SmallVector images; + CommandPool(CommandPool&&) = delete; + CommandPool& operator=(CommandPool&&) = delete; - void reset(); - } barriers_; - }; + ~CommandPool(); - // - // Pool - // - - class Pool final { - public: - explicit Pool(const GPU& gpu); - Pool(const Pool&) = delete; - Pool& operator=(const Pool&) = delete; - Pool(Pool&&); - Pool& operator=(Pool&&); - ~Pool(); - - Buffer allocate(); - Buffer& stream(); - void purge(); - - void submit( - VkQueue queue, - c10::ArrayRef buffers, - Resource::Fence fence = {}); - - private: - void invalidate(); - - private: - struct Configuration final { - static constexpr uint32_t kQuantum = 4u; - static constexpr uint32_t kReserve = 16u; - static constexpr uint32_t kSubmit = 16u; - }; - - VkDevice device_; - Handle command_pool_; - - struct { - std::vector pool; - size_t in_use; - } buffer_; - - struct { - Buffer buffer; - uint32_t counter; - } stream_; - } pool /* [thread_count] */; - - explicit Command(const GPU& gpu) - : pool(gpu) { - } -}; + private: + VkDevice device_; + uint32_t queue_family_idx_; + VkCommandPool pool_; + CommandPoolConfig config_; + // New Buffers + std::mutex mutex_; + std::vector buffers_; + size_t in_use_; -// -// Impl -// + public: + CommandBuffer get_new_cmd(); -inline Command::Buffer::operator bool() const { - return VK_NULL_HANDLE != command_buffer_; -} + void flush(); -inline VkCommandBuffer Command::Buffer::handle() const { - return command_buffer_; -} + private: + void allocate_new_batch(const uint32_t); +}; } // namespace api } // namespace vulkan diff --git a/aten/src/ATen/native/vulkan/api/Common.cpp b/aten/src/ATen/native/vulkan/api/Common.cpp deleted file mode 100644 index 8749d4b420e01..0000000000000 --- a/aten/src/ATen/native/vulkan/api/Common.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include - -#define VK_DELETER_DISPATCHABLE_DEFINE(Handle) \ - VK_DELETER_DISPATCHABLE_DECLARE(Handle) { \ - if (C10_LIKELY(VK_NULL_HANDLE != handle)) { \ - vkDestroy##Handle(handle, nullptr); \ - } \ - } - -#define VK_DELETER_NON_DISPATCHABLE_DEFINE(Handle) \ - destroy_##Handle::destroy_##Handle(const VkDevice device) \ - : device_(device) { \ - } \ - \ - void destroy_##Handle::operator()(const Vk##Handle handle) const { \ - if (C10_LIKELY(VK_NULL_HANDLE != handle)) { \ - vkDestroy##Handle(device_, handle, nullptr); \ - } \ - } - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -VK_DELETER_DISPATCHABLE_DEFINE(Instance); -VK_DELETER_DISPATCHABLE_DEFINE(Device); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Semaphore); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Fence); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Buffer); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Image); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Event); -VK_DELETER_NON_DISPATCHABLE_DEFINE(BufferView); -VK_DELETER_NON_DISPATCHABLE_DEFINE(ImageView); -VK_DELETER_NON_DISPATCHABLE_DEFINE(ShaderModule); -VK_DELETER_NON_DISPATCHABLE_DEFINE(PipelineCache); -VK_DELETER_NON_DISPATCHABLE_DEFINE(PipelineLayout); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Pipeline); -VK_DELETER_NON_DISPATCHABLE_DEFINE(DescriptorSetLayout); -VK_DELETER_NON_DISPATCHABLE_DEFINE(Sampler); -VK_DELETER_NON_DISPATCHABLE_DEFINE(DescriptorPool); -VK_DELETER_NON_DISPATCHABLE_DEFINE(CommandPool); - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h index 84bbeaa6f0e10..1fa268e63409b 100644 --- a/aten/src/ATen/native/vulkan/api/Common.h +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -4,39 +4,33 @@ #include +#include + #ifdef USE_VULKAN_SHADERC_RUNTIME #include -#define VK_KERNEL(name) \ - ::at::native::vulkan::api::Shader::Descriptor{ \ - name##_glsl, \ +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::ShaderSource { \ +#name, name##_glsl, \ } #else #include -#define VK_KERNEL(name) \ - ::at::native::vulkan::api::Shader::Descriptor{ \ - name##_spv, \ - name##_spv_len, \ +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::ShaderSource { \ +#name, name##_spv, name##_spv_len, name##_spv_layout \ } #endif /* USE_VULKAN_SHADERC_RUNTIME */ -#ifdef USE_VULKAN_WRAPPER -#ifdef USE_VULKAN_VOLK -#include -#else -#include -#endif /* USE_VULKAN_VOLK */ -#else -#include -#endif /* USE_VULKAN_WRAPPER */ - -#define VK_CHECK(function) \ - do { \ - const VkResult result = (function); \ - TORCH_CHECK( \ - VK_SUCCESS == result, \ - C10_STRINGIZE(__FILE__), " [", \ - C10_STRINGIZE(__LINE__), "] " \ - "VkResult:", result); \ +#define VK_CHECK(function) \ + do { \ + const VkResult result = (function); \ + TORCH_CHECK( \ + VK_SUCCESS == result, \ + C10_STRINGIZE(__FILE__), \ + " [", \ + C10_STRINGIZE(__LINE__), \ + "] " \ + "VkResult:", \ + result); \ } while (false) #define VK_CHECK_RELAXED(function) \ @@ -45,163 +39,4 @@ TORCH_CHECK(VK_SUCCESS <= result, "VkResult:", result); \ } while (false) -#define VK_DELETER(Handle) \ - at::native::vulkan::api::destroy_##Handle - -#define VK_DELETER_DISPATCHABLE_DECLARE(Handle) \ - void destroy_##Handle(const Vk##Handle handle) - -#define VK_DELETER_NON_DISPATCHABLE_DECLARE(Handle) \ - class destroy_##Handle final { \ - public: \ - explicit destroy_##Handle(const VkDevice device); \ - void operator()(const Vk##Handle handle) const; \ - private: \ - VkDevice device_; \ - }; - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -class Adapter; -struct Command; -class Context; -struct Descriptor; -struct Pipeline; -struct Resource; -class Runtime; -struct Shader; - -struct GPU final { - VkInstance instance; - const Adapter* adapter; - VkDevice device; - uint32_t queue_family_index; - VkQueue queue; -}; - -VK_DELETER_DISPATCHABLE_DECLARE(Instance); -VK_DELETER_DISPATCHABLE_DECLARE(Device); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Semaphore); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Fence); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Buffer); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Image); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Event); -VK_DELETER_NON_DISPATCHABLE_DECLARE(BufferView); -VK_DELETER_NON_DISPATCHABLE_DECLARE(ImageView); -VK_DELETER_NON_DISPATCHABLE_DECLARE(ShaderModule); -VK_DELETER_NON_DISPATCHABLE_DECLARE(PipelineCache); -VK_DELETER_NON_DISPATCHABLE_DECLARE(PipelineLayout); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Pipeline); -VK_DELETER_NON_DISPATCHABLE_DECLARE(DescriptorSetLayout); -VK_DELETER_NON_DISPATCHABLE_DECLARE(Sampler); -VK_DELETER_NON_DISPATCHABLE_DECLARE(DescriptorPool); -VK_DELETER_NON_DISPATCHABLE_DECLARE(CommandPool); - -// Vulkan objects are referenced via handles. The spec defines Vulkan handles -// under two categories: dispatchable and non-dispatchable. Dispatchable handles -// are required to be strongly typed as a result of being pointers to unique -// opaque types. Since dispatchable handles are pointers at the heart, -// std::unique_ptr can be used to manage their lifetime with a custom deleter. -// Non-dispatchable handles on the other hand, are not required to have strong -// types, and even though they default to the same implementation as dispatchable -// handles on some platforms - making the use of std::unique_ptr possible - they -// are only required by the spec to weakly aliases 64-bit integers which is the -// implementation some platforms default to. This makes the use of std::unique_ptr -// difficult since semantically unique_ptrs store pointers to their payload -// which is also what is passed onto the custom deleters. - -template -class Handle final { - public: - Handle(Type payload, Deleter deleter); - Handle(const Handle&) = delete; - Handle& operator=(const Handle&) = delete; - Handle(Handle&&); - Handle& operator=(Handle&&) &; - Handle& operator=(Handle&&) && = delete; - ~Handle(); - - operator bool() const; - Type get() const &; - Type get() const && = delete; - Type release(); - void reset(Type payload = kNull); - - private: - static constexpr Type kNull{}; - - private: - Type payload_; - Deleter deleter_; -}; - -// -// Impl -// - -template -constexpr Type Handle::kNull; - -template -inline Handle::Handle(const Type payload, Deleter deleter) - : payload_(payload), - deleter_(std::move(deleter)) { -} - -template -inline Handle::Handle(Handle&& handle) - : payload_(handle.release()), - deleter_(std::move(handle.deleter_)) { -} - -template -inline Handle& -Handle::operator=(Handle&& handle) & -{ - reset(handle.release()); - deleter_ = std::move(handle.deleter_); - return *this; -} - -template -inline Handle::~Handle() { - reset(); -} - -template -inline Handle::operator bool() const { - return get(); -} - -template -inline Type Handle::get() const & { - return payload_; -} - -template -inline Type Handle::release() { - const Type payload = payload_; - payload_ = kNull; - - return payload; -} - -template -inline void Handle::reset(Type payload) { - using std::swap; - swap(payload_, payload); - - if (kNull != payload) { - deleter_(payload); - } -} - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index 260d10dbe686e..4d7b3aa0d9877 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -9,140 +8,131 @@ namespace at { namespace native { namespace vulkan { namespace api { -namespace { - -VkDevice create_device( - const VkPhysicalDevice physical_device, - const uint32_t compute_queue_family_index) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - physical_device, - "Invalid Vulkan physical device!"); - - const float queue_priorities = 1.0f; - const VkDeviceQueueCreateInfo device_queue_create_info{ - VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, - nullptr, - 0u, - compute_queue_family_index, - 1u, - &queue_priorities, - }; - - uint32_t device_extension_properties_count = 0; - VK_CHECK(vkEnumerateDeviceExtensionProperties( - physical_device, - nullptr, - &device_extension_properties_count, - nullptr)); - - std::vector device_extension_properties( - device_extension_properties_count); - - VK_CHECK(vkEnumerateDeviceExtensionProperties( - physical_device, - nullptr, - &device_extension_properties_count, - device_extension_properties.data())); - - constexpr const char* const requested_device_extensions[]{ - #ifdef VK_KHR_portability_subset - // https://vulkan.lunarg.com/doc/view/1.2.162.0/mac/1.2-extensions/vkspec.html#VUID-VkDeviceCreateInfo-pProperties-04451 - VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, - #endif - }; - - std::vector enabled_device_extensions; - - for (const auto& requested_device_extension : requested_device_extensions) { - for (const auto& extension : device_extension_properties) { - if (strcmp(requested_device_extension, extension.extensionName) == 0) { - enabled_device_extensions.push_back(requested_device_extension); - break; - } - } - } - const VkDeviceCreateInfo device_create_info{ - VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, - nullptr, - 0u, - 1u, - &device_queue_create_info, - 0u, - nullptr, - static_cast(enabled_device_extensions.size()), - enabled_device_extensions.data(), - nullptr, - }; - - VkDevice device{}; - VK_CHECK(vkCreateDevice(physical_device, &device_create_info, nullptr, &device)); - TORCH_CHECK(device, "Invalid Vulkan device!"); - -#ifdef USE_VULKAN_WRAPPER -#ifdef USE_VULKAN_VOLK - volkLoadDevice(device); -#endif -#endif - - return device; +Context::Context(size_t adapter_i, const ContextConfig& config) + : config_(config), + // Important handles + adapter_p_(runtime()->get_adapter_p(adapter_i)), + device_(adapter_p_->device_handle()), + queue_(adapter_p_->request_queue()), + // Resource pools + command_pool_(device_, queue_.family_index, config_.cmdPoolConfig), + descriptor_pool_(device_, config_.descriptorPoolConfig), + fences_(device_), +// Diagnostics +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + querypool_(device_, config_.queryPoolConfig), +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + // Command buffer submission + cmd_mutex_{}, + cmd_(VK_NULL_HANDLE), + submit_count_{0u}, + // Memory Management + buffer_clearlist_mutex_{}, + buffers_to_clear_{}, + image_clearlist_mutex_{}, + images_to_clear_{} { +} + +Context::~Context() { + flush(); + // Let the device know the context is done with the queue + adapter_p_->return_queue(queue_); } -VkQueue acquire_queue( - const VkDevice device, - const uint32_t compute_queue_family_index) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); +DescriptorSet Context::submit_compute_prologue( + CommandBuffer& command_buffer, + const ShaderSource& shader_descriptor, + const utils::uvec3& local_workgroup_size) { + const VkDescriptorSetLayout shader_layout = + shader_layout_cache().retrieve(shader_descriptor.kernel_layout); + + const VkPipelineLayout pipeline_layout = + pipeline_layout_cache().retrieve(shader_layout); + + const VkPipeline pipeline = pipeline_cache().retrieve( + {pipeline_layout_cache().retrieve(shader_layout), + shader_cache().retrieve(shader_descriptor), + local_workgroup_size}); - VkQueue queue{}; - vkGetDeviceQueue(device, compute_queue_family_index, 0, &queue); - TORCH_CHECK(queue, "Invalid Vulkan queue!"); + command_buffer.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size); - return queue; + return descriptor_pool().get_descriptor_set( + shader_layout, shader_descriptor.kernel_layout); } -} // namespace +void Context::submit_compute_epilogue( + CommandBuffer& command_buffer, + const DescriptorSet& descriptors, + const PipelineBarrier& pipeline_barrier, + const utils::uvec3& global_workgroup_size) { + command_buffer.bind_descriptors(descriptors.get_bind_handle()); + command_buffer.insert_barrier(pipeline_barrier); -Context::Context(const VkInstance instance, size_t adapter_i) - : instance_(instance), - adapter_i_(adapter_i), - device_(runtime()->get_adapter(adapter_i).device_handle()), - queue_(runtime()->get_adapter(adapter_i).request_queue()), - shader_(gpu()), - pipeline_(gpu()), - threadcontext_(gpu()) { + command_buffer.dispatch(global_workgroup_size); } -Context::~Context() { - // Let the device know the context is done with the queue - runtime()->get_adapter(adapter_i_).return_queue(queue_); - // Do not call flush() since all per-thread objects will be destroyed as each thread exits +void Context::submit_texture_copy( + const PipelineBarrier& pipeline_barrier, + const api::VulkanImage& source, + const api::VulkanImage& destination, + const api::utils::uvec3& copy_range, + const api::utils::uvec3& src_offset, + const api::utils::uvec3& dst_offset, + const VkFence fence_handle) { + // Serialize recording to the shared command buffer. Do not initialize with a + // mutex just yet, since in some cases it will be externally managed. + std::unique_lock cmd_lock; + // Refer to comments in submit_compute_job for explanation. + if (fence_handle == VK_NULL_HANDLE) { + cmd_lock = std::unique_lock(cmd_mutex_); + } + + set_cmd(); + +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + uint32_t log_idx = querypool_.shader_profile_begin( + cmd_, + "copy_texture_to_texture", + create_extent3d({0, 0, 0}), + create_extent3d({0, 0, 0})); +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + + cmd_.insert_barrier(pipeline_barrier); + + cmd_.copy_texture_to_texture( + source, destination, copy_range, src_offset, dst_offset); + +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + querypool_.shader_profile_end(cmd_, log_idx); +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + + submit_count_++; + if (fence_handle != VK_NULL_HANDLE || + submit_count_ >= config_.cmdSubmitFrequency) { + submit_cmd_to_gpu(fence_handle); + } } -void Context::flush() { - VK_CHECK(vkQueueWaitIdle(queue())); +void Context::submit_cmd_to_gpu(const VkFence fence_handle) { + if (cmd_) { + cmd_.end(); + adapter_p_->submit_cmd(queue_, cmd_.get_submit_handle(), fence_handle); - resource().pool.purge(); - descriptor().pool.purge(); - command().pool.purge(); + submit_count_ = 0u; + } } -void Context::wait(const at::Tensor& src) { - // wait only if Vulkan tensor - if (at::kVulkan == src.device().type()) { - api::Command::Pool& command_pool = command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); +void Context::flush() { + VK_CHECK(vkQueueWaitIdle(queue())); - using Future = ops::vTensor::Future; - const ops::vTensor& v_src = ops::convert(src); - const Future v_src_future = v_src.host(command_buffer); + command_pool_.flush(); + descriptor_pool_.flush(); - // This wait() is a no-op if data is not out of sync. More often than - // not though, waits here are expected as the GPU catches up with - // compute submitted from CPU. - v_src_future.wait(); - } + std::lock_guard bufferlist_lock(buffer_clearlist_mutex_); + std::lock_guard imagelist_lock(image_clearlist_mutex_); + buffers_to_clear_.clear(); + images_to_clear_.clear(); } bool available() { @@ -152,21 +142,47 @@ bool available() { Context* context() { static const std::unique_ptr context([]() -> Context* { try { - return new Context(runtime()->instance(), runtime()->default_adapter_i()); - } - catch (const std::exception& e) { - TORCH_CHECK(false, "Vulkan: Failed to initialize context! Error: ", e.what()); - } - catch (...) { - TORCH_CHECK(false, "Vulkan: Failed to initialize context! Error: Unknown"); + const uint32_t submit_frequency = 16u; + + const CommandPoolConfig cmd_config{ + 32u, // cmdPoolInitialSize + 8u, // cmdPoolBatchSize + }; + + const DescriptorPoolConfig descriptor_pool_config{ + 1024u, // descriptorPoolMaxSets + 1024u, // descriptorUniformBufferCount + 1024u, // descriptorStorageBufferCount + 1024u, // descriptorCombinedSamplerCount + 1024u, // descriptorStorageImageCount + 32u, // descriptorPileSizes + }; + + const QueryPoolConfig query_pool_config{ + 4096u, // maxQueryCount + 256u, // initialReserveSize + }; + + const ContextConfig config{ + submit_frequency, // cmdSubmitFrequency + cmd_config, // cmdPoolConfig + descriptor_pool_config, // descriptorPoolConfig + query_pool_config, // queryPoolConfig + }; + + return new Context(runtime()->default_adapter_i(), config); + } catch (const std::exception& e) { + TORCH_CHECK( + false, "Vulkan: Failed to initialize context! Error: ", e.what()); + } catch (...) { + TORCH_CHECK( + false, "Vulkan: Failed to initialize context! Error: Unknown"); } return nullptr; }()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - context, - "Invalid Vulkan context!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(context, "Invalid Vulkan context!"); return context.get(); } @@ -182,41 +198,6 @@ struct VulkanImpl final : public at::vulkan::VulkanImplInterface { }; static at::vulkan::VulkanImplRegistrar g_vulkan_impl(new VulkanImpl()); -Descriptor::Set dispatch_prologue( - Command::Buffer& command_buffer, - const Shader::Layout::Signature& shader_layout_signature, - const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& local_work_group_size) { - Context* const context = api::context(); - Descriptor& descriptor = context->descriptor(); - Pipeline& pipeline = context->pipeline(); - Shader& shader = context->shader(); - - const Shader::Layout::Object shader_layout = - shader.layout.cache.retrieve({ - shader_layout_signature, - }); - - command_buffer.bind( - pipeline.cache.retrieve({ - pipeline.layout.cache.retrieve({ - shader_layout.handle, - }), - shader.cache.retrieve(shader_descriptor), - local_work_group_size, - })); - - return descriptor.pool.allocate(shader_layout); -} - -void dispatch_epilogue( - Command::Buffer& command_buffer, - const Descriptor::Set& descriptor_set, - const Shader::WorkGroup& global_work_group) { - command_buffer.bind(descriptor_set); - command_buffer.dispatch(global_work_group); -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h index 7b1bb85f92306..fbf4aae11376f 100644 --- a/aten/src/ATen/native/vulkan/api/Context.h +++ b/aten/src/ATen/native/vulkan/api/Context.h @@ -2,20 +2,28 @@ #ifdef USE_VULKAN_API -#include #include #include +#include #include #include +#include #include +#include #include -#include namespace at { namespace native { namespace vulkan { namespace api { +struct ContextConfig final { + uint32_t cmdSubmitFrequency; + CommandPoolConfig cmdPoolConfig; + DescriptorPoolConfig descriptorPoolConfig; + QueryPoolConfig queryPoolConfig; +}; + // // Vulkan Context holds onto all relevant Vulkan state as it pertains to our // use of Vulkan in PyTorch. A Context is associated with one, and only one, @@ -27,168 +35,291 @@ namespace api { class Context final { public: - explicit Context(const VkInstance instance, size_t adapter_i); + explicit Context(size_t adapter_i, const ContextConfig&); Context(const Context&) = delete; Context& operator=(const Context&) = delete; - Context(Context&&) = default; - Context& operator=(Context&&) = default; + Context(Context&&) = delete; + Context& operator=(Context&&) = delete; ~Context(); - GPU gpu(); - Command& command(); - Shader& shader(); - QueryPool& querypool(); - Pipeline& pipeline(); - Descriptor& descriptor(); - Resource& resource(); + private: + // Config + ContextConfig config_; + // Important handles + Adapter* adapter_p_; + VkDevice device_; + Adapter::Queue queue_; + // Resource Pools + CommandPool command_pool_; + DescriptorPool descriptor_pool_; + FencePool fences_; + // Diagnostics +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + QueryPool querypool_; +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + // Command buffers submission + std::mutex cmd_mutex_; + CommandBuffer cmd_; + uint32_t submit_count_; + // Memory Management + std::mutex buffer_clearlist_mutex_; + std::vector buffers_to_clear_; + std::mutex image_clearlist_mutex_; + std::vector images_to_clear_; - // GPU RPC - template - void dispatch( - Command::Buffer& command_buffer, - const Shader::Layout::Signature& shader_layout_signature, - const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& global_work_group, - const Shader::WorkGroup& local_work_group_size, - Arguments&&... arguments); - - // This function is expensive and its use consequential for performance. Only - // use this function for debugging or as a short term hack on way to a more - // performant solution. + public: + // Adapter access - void flush(); + inline Adapter* adapter_ptr() { + return adapter_p_; + } - // Use this function only for debugging and testing when you want to make sure - // all GPU operations get finished before calling flush(). Otherwise, it may crash. - void wait(const at::Tensor& src); + inline VkDevice device() { + return device_; + } - private: - VkDevice device(); - VkQueue queue(); + inline VkQueue queue() { + return queue_.handle; + } + + // Device Caches + + inline ShaderLayoutCache& shader_layout_cache() { + return adapter_ptr()->shader_layout_cache(); + } + + inline ShaderCache& shader_cache() { + return adapter_ptr()->shader_cache(); + } + + inline PipelineLayoutCache& pipeline_layout_cache() { + return adapter_ptr()->pipeline_layout_cache(); + } + + inline ComputePipelineCache& pipeline_cache() { + return adapter_ptr()->compute_pipeline_cache(); + } + + // Resource Pools + + inline DescriptorPool& descriptor_pool() { + return descriptor_pool_; + } + + inline FencePool& fences() { + return fences_; + } + + // Diagnostics + +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + inline QueryPool& querypool() { + return querypool_; + } + + inline void reset_querypool() { + set_cmd(); + querypool_.reset(cmd_); + } +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + + // Memory Management + void register_buffer_cleanup(VulkanBuffer& buffer) { + std::lock_guard bufferlist_lock(buffer_clearlist_mutex_); + buffers_to_clear_.emplace_back(std::move(buffer)); + } + + void register_image_cleanup(VulkanImage& image) { + std::lock_guard imagelist_lock(image_clearlist_mutex_); + images_to_clear_.emplace_back(std::move(image)); + } + + // GPU RPC + + inline std::unique_lock dispatch_lock() { + return std::unique_lock(cmd_mutex_); + } private: - // Construction and destruction order matters. Do not move members around. - VkInstance instance_; - size_t adapter_i_; - VkDevice device_; - Adapter::Queue queue_; - Shader shader_; - Pipeline pipeline_; - ThreadContext threadcontext_; -}; + inline void set_cmd() { + if (!cmd_) { + cmd_ = command_pool_.get_new_cmd(); + cmd_.begin(); + } + } + + DescriptorSet submit_compute_prologue( + CommandBuffer&, + const ShaderSource&, + const utils::uvec3&); + + void submit_compute_epilogue( + CommandBuffer&, + const DescriptorSet&, + const PipelineBarrier&, + const utils::uvec3&); -bool available(); + public: + template + void submit_compute_job( + const ShaderSource&, + const PipelineBarrier&, + const utils::uvec3&, + const utils::uvec3&, + const VkFence fence_handle, + Arguments&&...); + + void submit_texture_copy( + const PipelineBarrier& pipeline_barrier, + const api::VulkanImage&, + const api::VulkanImage&, + const api::utils::uvec3&, + const api::utils::uvec3&, + const api::utils::uvec3&, + const VkFence fence_handle); -// The global runtime is retrieved using this function, where it is declared as -// a static local variable. -Context* context(); + private: + void submit_cmd_to_gpu(const VkFence fence_handle = VK_NULL_HANDLE); -// -// Impl -// + public: + void flush(); +}; -inline GPU Context::gpu() { - // A GPU is simply a (physical device, logical device, device queue) trio. - const Adapter* p_adapter = runtime()->get_adapter_p(adapter_i_); - return { - instance_, - p_adapter, - device_, - queue_.family_index, - queue_.handle, - }; -} +class UniformParamsBuffer final { + private: + Context* context_p_; + VulkanBuffer vulkan_buffer_; -inline Shader& Context::shader() { - return shader_; -} + public: + template + UniformParamsBuffer(Context* context_p, const Block& block) + : context_p_(context_p), + vulkan_buffer_( + context_p_->adapter_ptr()->vma().create_params_buffer(block)) {} -inline Pipeline& Context::pipeline() { - return pipeline_; -} + UniformParamsBuffer(const UniformParamsBuffer&) = delete; + UniformParamsBuffer& operator=(const UniformParamsBuffer&) = delete; -inline Command& Context::command() { - return threadcontext_.command(); -} + UniformParamsBuffer(UniformParamsBuffer&&) = delete; + UniformParamsBuffer& operator=(UniformParamsBuffer&&) = delete; -inline Descriptor& Context::descriptor() { - return threadcontext_.descriptor(); -} + ~UniformParamsBuffer() { + context_p_->register_buffer_cleanup(vulkan_buffer_); + } -inline Resource& Context::resource() { - return threadcontext_.resource(); -} + VulkanBuffer& buffer() { + return vulkan_buffer_; + } +}; -inline QueryPool& Context::querypool() { - return threadcontext_.querypool(); -} +class StagingBuffer final { + private: + Context* context_p_; + VulkanBuffer vulkan_buffer_; -inline VkDevice Context::device() { - return device_; -} + public: + StagingBuffer( + Context* context_p, + const VkDeviceSize size, + const bool gpuonly = false) + : context_p_(context_p), + vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer( + size, + gpuonly)) {} + + StagingBuffer(const StagingBuffer&) = delete; + StagingBuffer& operator=(const StagingBuffer&) = delete; + + StagingBuffer(StagingBuffer&&) = delete; + StagingBuffer& operator=(StagingBuffer&&) = delete; + + ~StagingBuffer() { + context_p_->register_buffer_cleanup(vulkan_buffer_); + } + + VulkanBuffer& buffer() { + return vulkan_buffer_; + } +}; -inline VkQueue Context::queue() { - return queue_.handle; -} +bool available(); + +// The global runtime is retrieved using this function, where it is declared as +// a static local variable. +Context* context(); namespace detail { -template< - size_t...Indices, - typename ...Arguments> +template inline void bind( - Descriptor::Set& descriptor_set, + DescriptorSet& descriptor_set, const std::index_sequence, - Arguments&&...arguments) { + Arguments&&... arguments) { C10_UNUSED const int _[]{ - 0, - (descriptor_set.bind(Indices, std::forward(arguments)), 0)..., + 0, + (descriptor_set.bind(Indices, std::forward(arguments)), 0)..., }; } } // namespace detail -template -inline void Context::dispatch( - Command::Buffer& command_buffer, - const Shader::Layout::Signature& shader_layout_signature, - const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& global_work_group, - const Shader::WorkGroup& local_work_group_size, +template +inline void Context::submit_compute_job( + const ShaderSource& shader_descriptor, + const PipelineBarrier& pipeline_barrier, + const utils::uvec3& global_work_group, + const utils::uvec3& local_work_group_size, + const VkFence fence_handle, Arguments&&... arguments) { - // Forward declaration - Descriptor::Set dispatch_prologue( - Command::Buffer&, - const Shader::Layout::Signature&, - const Shader::Descriptor&, - const Shader::WorkGroup&); + // Serialize recording to the shared command buffer. Do not initialize with a + // mutex just yet, since in some cases it will be externally managed. + std::unique_lock cmd_lock; + // If a fence was passed, then assume that the host intends to sync with + // the GPU, implying there will be imminent calls to fence.wait() and flush(). + // We therefore assume the mutex is externally managed in this case, and the + // calling thread has already locked the mutex prior to calling the function, + // and will release the mutex manually after calling flush(). This will + // prevent more dispatches from being recorded until we have flushed the + // Context. + if (fence_handle == VK_NULL_HANDLE) { + cmd_lock = std::unique_lock(cmd_mutex_); + } + + set_cmd(); + +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + uint32_t log_idx = querypool_.shader_profile_begin( + cmd_, + shader_descriptor.kernel_name, + create_extent3d(global_work_group), + create_extent3d(local_work_group_size)); +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ // Factor out template parameter independent code to minimize code bloat. - Descriptor::Set descriptor_set = dispatch_prologue( - command_buffer, - shader_layout_signature, - shader_descriptor, - local_work_group_size); + DescriptorSet descriptor_set = + submit_compute_prologue(cmd_, shader_descriptor, local_work_group_size); detail::bind( descriptor_set, std::index_sequence_for{}, std::forward(arguments)...); - // Forward declaration - void dispatch_epilogue( - Command::Buffer&, - const Descriptor::Set&, - const Shader::WorkGroup&); - // Factor out template parameter independent code to minimize code bloat. - dispatch_epilogue( - command_buffer, - descriptor_set, - global_work_group); + submit_compute_epilogue( + cmd_, descriptor_set, pipeline_barrier, global_work_group); + +#ifdef USE_VULKAN_GPU_DIAGNOSTICS + querypool_.shader_profile_end(cmd_, log_idx); +#endif /* USE_VULKAN_GPU_DIAGNOSTICS */ + + submit_count_++; + if (fence_handle != VK_NULL_HANDLE || + submit_count_ >= config_.cmdSubmitFrequency) { + submit_cmd_to_gpu(fence_handle); + } } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp index 3c97cd815120c..37ba55ca1c363 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp +++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp @@ -5,408 +5,265 @@ namespace at { namespace native { namespace vulkan { namespace api { -namespace { - -VkDescriptorPool create_descriptor_pool(const VkDevice device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); - - const struct { - uint32_t capacity; - c10::SmallVector sizes; - } descriptor { - 1024u, - { - /* - Buffers - */ - { - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - 1024u, - }, - { - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - 1024u, - }, - - /* - Images - */ - - { - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - 1024u, - }, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - 1024u, - }, - }, - }; - - const VkDescriptorPoolCreateInfo descriptor_pool_create_info{ - VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, - nullptr, - 0u, - descriptor.capacity, - static_cast(descriptor.sizes.size()), - descriptor.sizes.data(), - }; - - VkDescriptorPool descriptor_pool{}; - VK_CHECK(vkCreateDescriptorPool( - device, - &descriptor_pool_create_info, - nullptr, - &descriptor_pool)); +// +// DescriptorSet +// - TORCH_CHECK( - descriptor_pool, - "Invalid Vulkan descriptor pool!"); - - return descriptor_pool; -} - -void allocate_descriptor_sets( +DescriptorSet::DescriptorSet( const VkDevice device, - const VkDescriptorPool descriptor_pool, - const VkDescriptorSetLayout descriptor_set_layout, - VkDescriptorSet* const descriptor_sets, - const uint32_t count) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_pool, - "Invalid Vulkan descriptor pool!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_set_layout, - "Invalid Vulkan descriptor set layout!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_sets && (count > 0u), - "Invalid usage!"); - - std::vector descriptor_set_layouts(count); - fill( - descriptor_set_layouts.begin(), - descriptor_set_layouts.end(), - descriptor_set_layout - ); - - const VkDescriptorSetAllocateInfo descriptor_set_allocate_info{ - VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, - nullptr, - descriptor_pool, - utils::safe_downcast(descriptor_set_layouts.size()), - descriptor_set_layouts.data(), - }; - - VK_CHECK(vkAllocateDescriptorSets( - device, - &descriptor_set_allocate_info, - descriptor_sets)); + const VkDescriptorSet handle, + const ShaderLayout::Signature& shader_layout_signature) + : device_(device), + handle_(handle), + shader_layout_signature_(shader_layout_signature), + bindings_{} {} + +DescriptorSet::DescriptorSet(DescriptorSet&& other) noexcept + : device_(other.device_), + handle_(other.handle_), + shader_layout_signature_(std::move(other.shader_layout_signature_)), + bindings_(std::move(other.bindings_)) { + other.handle_ = VK_NULL_HANDLE; } -} // namespace - -Descriptor::Set::Set( - const VkDevice device, - VkDescriptorSet descriptor_set, - const Shader::Layout::Signature& shader_layout_signature) - : device_(device), - descriptor_set_(descriptor_set), - shader_layout_signature_(shader_layout_signature), - bindings_{} { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_set_, - "Invalid Vulkan descriptor set!"); -} - -Descriptor::Set::Set(Set&& set) - : device_(std::move(set.device_)), - descriptor_set_(std::move(set.descriptor_set_)), - shader_layout_signature_(std::move(set.shader_layout_signature_)), - bindings_(std::move(set.bindings_)) { - set.invalidate(); -} +DescriptorSet& DescriptorSet::operator=(DescriptorSet&& other) noexcept { + device_ = other.device_; + handle_ = other.handle_; + shader_layout_signature_ = std::move(other.shader_layout_signature_); + bindings_ = std::move(other.bindings_); -Descriptor::Set& Descriptor::Set::operator=(Set&& set) { - if (&set != this) { - device_ = std::move(set.device_); - descriptor_set_ = std::move(set.descriptor_set_); - shader_layout_signature_ = std::move(set.shader_layout_signature_); - bindings_ = std::move(set.bindings_); - - set.invalidate(); - }; + other.handle_ = VK_NULL_HANDLE; return *this; } -Descriptor::Set& Descriptor::Set::bind( - const uint32_t binding, - const Resource::Buffer::Object& buffer) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); - - update({ - binding, - shader_layout_signature_[binding], +DescriptorSet& DescriptorSet::bind( + const uint32_t idx, + const VulkanBuffer& buffer) { + add_binding(DescriptorSet::ResourceBinding{ + idx, // binding_idx + shader_layout_signature_[idx], // descriptor_type + false, // is_image { - .buffer = { - buffer.handle, - buffer.offset, - buffer.range, - }, + // resource_info + .buffer_info = + { + buffer.handle(), // buffer + buffer.mem_offset(), // offset + buffer.mem_range(), // range + }, }, - }); + }); return *this; } -Descriptor::Set& Descriptor::Set::bind( - const uint32_t binding, - const Resource::Image::Object& image) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); - - update(Item{ - binding, - shader_layout_signature_[binding], +DescriptorSet& DescriptorSet::bind( + const uint32_t idx, + const VulkanImage& image) { + VkImageLayout binding_layout = image.layout(); + if (shader_layout_signature_[idx] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE) { + binding_layout = VK_IMAGE_LAYOUT_GENERAL; + } + + add_binding(DescriptorSet::ResourceBinding{ + idx, // binding_idx + shader_layout_signature_[idx], // descriptor_type + true, // is_image { - .image = { - image.sampler, - image.view, - [](const VkDescriptorType type, const VkImageLayout layout) { - return (VK_DESCRIPTOR_TYPE_STORAGE_IMAGE == type) ? - VK_IMAGE_LAYOUT_GENERAL : layout; - }(shader_layout_signature_[binding], image.layout), - }, + // resource_info + .image_info = + { + image.sampler(), // buffer + image.image_view(), // imageView + binding_layout, // imageLayout + }, }, - }); + }); return *this; } -VkDescriptorSet Descriptor::Set::handle() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); - - if (bindings_.dirty) { - const auto is_buffer = [](const VkDescriptorType type) { - switch (type) { - case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER: - case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER: - case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: - case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: - case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC: - case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: - return true; - - default: - return false; - } +VkDescriptorSet DescriptorSet::get_bind_handle() const { + c10::SmallVector write_descriptor_sets; + + for (const ResourceBinding& binding : bindings_) { + VkWriteDescriptorSet write{ + VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, // sType + nullptr, // pNext + handle_, // dstSet + binding.binding_idx, // dstBinding + 0u, // dstArrayElement + 1u, // descriptorCount + binding.descriptor_type, // descriptorType + nullptr, // pImageInfo + nullptr, // pBufferInfo + nullptr, // pTexelBufferView }; - const auto is_image = [](const VkDescriptorType type) { - switch (type) { - case VK_DESCRIPTOR_TYPE_SAMPLER: - case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: - case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE: - case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: - case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT: - return true; - - default: - return false; - } - }; - - c10::SmallVector write_descriptor_sets; - - for (const Item& item : bindings_.items) { - VkWriteDescriptorSet write{ - VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, - nullptr, - descriptor_set_, - item.binding, - 0u, - 1u, - item.type, - nullptr, - nullptr, - nullptr, - }; - - if (is_buffer(item.type)) { - write.pBufferInfo = &item.info.buffer; - } - else if (is_image(item.type)) { - write.pImageInfo = &item.info.image; - } - - write_descriptor_sets.emplace_back(write); + if (binding.is_image) { + write.pImageInfo = &binding.resource_info.image_info; + } else { + write.pBufferInfo = &binding.resource_info.buffer_info; } - vkUpdateDescriptorSets( - device_, - write_descriptor_sets.size(), - write_descriptor_sets.data(), - 0u, - nullptr); - - // Reset - bindings_.dirty = false; + write_descriptor_sets.emplace_back(write); } - return descriptor_set_; -} + vkUpdateDescriptorSets( + device_, + write_descriptor_sets.size(), + write_descriptor_sets.data(), + 0u, + nullptr); + + VkDescriptorSet ret = handle_; -void Descriptor::Set::invalidate() { - device_ = VK_NULL_HANDLE; - descriptor_set_ = VK_NULL_HANDLE; + return ret; } -void Descriptor::Set::update(const Item& item) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_set_, - "This descriptor set is in an invalid state! " - "Potential reason: This descriptor set is moved from."); - - const auto items_itr = std::find_if( - bindings_.items.begin(), - bindings_.items.end(), - [binding = item.binding](const Item& other) { - return other.binding == binding; +void DescriptorSet::add_binding(const ResourceBinding& binding) { + const auto bindings_itr = std::find_if( + bindings_.begin(), + bindings_.end(), + [binding_idx = binding.binding_idx](const ResourceBinding& other) { + return other.binding_idx == binding_idx; }); - if (bindings_.items.end() == items_itr) { - bindings_.items.emplace_back(item); - } - else { - *items_itr = item; + if (bindings_.end() == bindings_itr) { + bindings_.emplace_back(binding); + } else { + *bindings_itr = binding; } - - bindings_.dirty = true; } -Descriptor::Pool::Pool(const GPU& gpu) - : device_(gpu.device), - descriptor_pool_( - create_descriptor_pool(gpu.device), - VK_DELETER(DescriptorPool)(device_)) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); +// +// DescriptorSetPile +// - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_pool_, - "Invalid Vulkan descriptor pool!"); +DescriptorSetPile::DescriptorSetPile( + const uint32_t pile_size, + const VkDescriptorSetLayout descriptor_set_layout, + const VkDevice device, + const VkDescriptorPool descriptor_pool) + : pile_size_{pile_size}, + set_layout_{descriptor_set_layout}, + device_{device}, + pool_{descriptor_pool}, + descriptors_{}, + in_use_(0u) { + descriptors_.resize(pile_size_); + allocate_new_batch(); } -Descriptor::Pool::Pool(Pool&& pool) - : device_(std::move(pool.device_)), - descriptor_pool_(std::move(pool.descriptor_pool_)), - set_(std::move(pool.set_)) { - pool.invalidate(); +VkDescriptorSet DescriptorSetPile::get_descriptor_set() { + // No-ops if there are descriptor sets available + allocate_new_batch(); + + const VkDescriptorSet handle = descriptors_[in_use_]; + descriptors_[in_use_] = VK_NULL_HANDLE; + + in_use_++; + return handle; } -Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { - if (&pool != this) { - device_ = std::move(pool.device_); - descriptor_pool_ = std::move(pool.descriptor_pool_); - set_ = std::move(pool.set_); +void DescriptorSetPile::allocate_new_batch() { + // No-ops if there are still descriptor sets availble + if (in_use_ < descriptors_.size() && + descriptors_[in_use_] != VK_NULL_HANDLE) { + return; + } + + std::vector layouts(descriptors_.size()); + fill(layouts.begin(), layouts.end(), set_layout_); - pool.invalidate(); + const VkDescriptorSetAllocateInfo allocate_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, // sType + nullptr, // pNext + pool_, // descriptorPool + utils::safe_downcast(layouts.size()), // descriptorSetCount + layouts.data(), // pSetLayouts }; - return *this; -} + VK_CHECK( + vkAllocateDescriptorSets(device_, &allocate_info, descriptors_.data())); -Descriptor::Pool::~Pool() { - try { - if (device_ && descriptor_pool_) { - purge(); - } - } - catch (const std::exception& e) { - TORCH_WARN( - "Vulkan: Descriptor pool destructor raised an exception! Error: ", - e.what()); - } - catch (...) { - TORCH_WARN( - "Vulkan: Descriptor pool destructor raised an exception! " - "Error: Unknown"); - } + in_use_ = 0u; } -Descriptor::Set Descriptor::Pool::allocate( - const Shader::Layout::Object& shader_layout) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_pool_, - "This descriptor pool is in an invalid state! " - "Potential reason: This descriptor pool is moved from."); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - shader_layout, - "Invalid Vulkan shader layout!"); - - auto iterator = set_.layouts.find(shader_layout.handle); - if (set_.layouts.cend() == iterator) { - iterator = set_.layouts.insert({shader_layout.handle, {}}).first; - iterator->second.pool.reserve(Configuration::kReserve); - } +// +// DescriptorPool +// - auto& layout = iterator->second; +DescriptorPool::DescriptorPool( + const VkDevice device, + const DescriptorPoolConfig& config) + : device_(device), + pool_(VK_NULL_HANDLE), + config_(config), + mutex_{}, + piles_{} { + c10::SmallVector type_sizes{ + { + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + config_.descriptorUniformBufferCount, + }, + { + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + config_.descriptorStorageBufferCount, + }, + { + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + config_.descriptorCombinedSamplerCount, + }, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + config_.descriptorStorageBufferCount, + }, + }; - if (layout.pool.size() == layout.in_use) { - layout.pool.resize( - layout.pool.size() + - Configuration::kQuantum); + const VkDescriptorPoolCreateInfo create_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + config_.descriptorPoolMaxSets, // maxSets + static_cast(type_sizes.size()), // poolSizeCounts + type_sizes.data(), // pPoolSizes + }; - allocate_descriptor_sets( - device_, - descriptor_pool_.get(), - shader_layout.handle, - layout.pool.data() + layout.in_use, - Configuration::kQuantum); - } + VK_CHECK(vkCreateDescriptorPool(device_, &create_info, nullptr, &pool_)); +} - return Set( - device_, - layout.pool[layout.in_use++], - shader_layout.signature); +DescriptorPool::~DescriptorPool() { + if (VK_NULL_HANDLE == pool_) { + return; + } + vkDestroyDescriptorPool(device_, pool_, nullptr); } -void Descriptor::Pool::purge() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && descriptor_pool_, - "This descriptor pool is in an invalid state! " - "Potential reason: This descriptor pool is moved from."); +DescriptorSet DescriptorPool::get_descriptor_set( + const VkDescriptorSetLayout set_layout, + const ShaderLayout::Signature& signature) { + auto it = piles_.find(set_layout); + if (piles_.cend() == it) { + it = piles_ + .insert({ + set_layout, + DescriptorSetPile( + config_.descriptorPileSizes, set_layout, device_, pool_), + }) + .first; + } + + VkDescriptorSet handle = it->second.get_descriptor_set(); - VK_CHECK(vkResetDescriptorPool(device_, descriptor_pool_.get(), 0u)); - set_.layouts.clear(); + return DescriptorSet(device_, handle, signature); } -void Descriptor::Pool::invalidate() { - device_ = VK_NULL_HANDLE; - descriptor_pool_.reset(); +void DescriptorPool::flush() { + VK_CHECK(vkResetDescriptorPool(device_, pool_, 0u)); + piles_.clear(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h index 6c50a350d7f3b..5ef21e1f78df0 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.h +++ b/aten/src/ATen/native/vulkan/api/Descriptor.h @@ -11,136 +11,117 @@ namespace native { namespace vulkan { namespace api { -// -// This struct defines caches of descriptor pools, and descriptor sets allocated -// from those pools, intended to minimize redundant object reconstructions or -// accelerate unavoidable memory allocations, both at the cost of extra memory -// consumption. -// -// A descriptor set is logically an array of descriptors, each of which -// references a resource (i.e. buffers and images), in turn telling the core -// executing the shader, where in GPU, or GPU-accessible system, memory the said -// resource resides. -// -// To accelerate creation of the descriptor sets, modern graphics APIs allocate -// them from a pool, more elaborately referred to as descriptor pools, which do -// need to be purged frequently _after_ none of the descriptors the pools contain -// is in use by the GPU. Care must be taken that descriptors are not freed while -// they are in use by the pipeline, which considering the asynchronous nature of -// CPU-GPU interactions, can be anytime after the command is issued until it is -// fully executed by the GPU. -// -// As you can imagine, it is possible to have multiple descriptor pools, each of -// which is configured to house different types of descriptor sets with different -// allocation strategies. These descriptor pools themselves are fairly stable -// objects in that they theymself should not be created and destroyed frequently. -// That is the reason why we store them in a cache, which according to our usage -// of the term 'cache' in this implementatoin, is reserved for objects that are -// created infrequently and stabilize to a manageable number quickly over the -// lifetime of the program. -// -// Descriptor sets though, on the other hand, are allocated from pools which -// indeed does mean that the pools must be purged on a regular basis or else -// they will run out of free items. Again, this is in line with our usage of -// the term 'pool' in this implementation which we use to refer to a container -// of objects that is allocated out of and is required to be frequently purged. -// -// It is important to point out that for performance reasons, we intentionally -// do not free the descriptor sets individually, and instead opt to purge the -// pool in its totality, even though Vulkan supports the former usage pattern -// as well. This behavior is by design. -// - -struct Descriptor final { - // - // Set - // - - class Set final { - public: - Set( - VkDevice device, - VkDescriptorSet descriptor_set, - const Shader::Layout::Signature& shader_layout_signature); - Set(const Set&) = delete; - Set& operator=(const Set&) = delete; - Set(Set&&); - Set& operator=(Set&&); - ~Set() = default; - - Set& bind(uint32_t binding, const Resource::Buffer::Object& buffer); - Set& bind(uint32_t binding, const Resource::Image::Object& image); - - VkDescriptorSet handle() const; - - private: - void invalidate(); - - private: - struct Item final { - uint32_t binding; - VkDescriptorType type; - - union { - VkDescriptorBufferInfo buffer; - VkDescriptorImageInfo image; - } info; - }; - - void update(const Item& item); - - private: - VkDevice device_; - VkDescriptorSet descriptor_set_; - Shader::Layout::Signature shader_layout_signature_; - - struct { - c10::SmallVector items; - mutable bool dirty; - } bindings_; +class DescriptorSet final { + public: + explicit DescriptorSet( + const VkDevice, + const VkDescriptorSet, + const ShaderLayout::Signature&); + + DescriptorSet(const DescriptorSet&) = delete; + DescriptorSet& operator=(const DescriptorSet&) = delete; + + DescriptorSet(DescriptorSet&&) noexcept; + DescriptorSet& operator=(DescriptorSet&&) noexcept; + + ~DescriptorSet() = default; + + struct ResourceBinding final { + uint32_t binding_idx; + VkDescriptorType descriptor_type; + bool is_image; + + union { + VkDescriptorBufferInfo buffer_info; + VkDescriptorImageInfo image_info; + } resource_info; }; - // - // Pool - // - - class Pool final { - public: - explicit Pool(const GPU& gpu); - Pool(const Pool&) = delete; - Pool& operator=(const Pool&) = delete; - Pool(Pool&&); - Pool& operator=(Pool&&); - ~Pool(); - - Set allocate(const Shader::Layout::Object& shader_layout); - void purge(); - - private: - void invalidate(); - - private: - struct Configuration final { - static constexpr uint32_t kQuantum = 16u; - static constexpr uint32_t kReserve = 64u; - }; - - VkDevice device_; - Handle descriptor_pool_; - - struct { - struct Layout final { - std::vector pool; - size_t in_use; - }; - - ska::flat_hash_map layouts; - } set_; - } pool /* [thread_count] */; - - explicit Descriptor(const GPU& gpu) - : pool(gpu) { - } + private: + VkDevice device_; + VkDescriptorSet handle_; + ShaderLayout::Signature shader_layout_signature_; + c10::SmallVector bindings_; + + public: + DescriptorSet& bind(const uint32_t, const VulkanBuffer&); + DescriptorSet& bind(const uint32_t, const VulkanImage&); + + VkDescriptorSet get_bind_handle() const; + + private: + void add_binding(const ResourceBinding& resource); +}; + +class DescriptorSetPile final { + public: + DescriptorSetPile( + const uint32_t, + const VkDescriptorSetLayout, + const VkDevice, + const VkDescriptorPool); + + DescriptorSetPile(const DescriptorSetPile&) = delete; + DescriptorSetPile& operator=(const DescriptorSetPile&) = delete; + + DescriptorSetPile(DescriptorSetPile&&) = default; + DescriptorSetPile& operator=(DescriptorSetPile&&) = default; + + ~DescriptorSetPile() = default; + + private: + uint32_t pile_size_; + VkDescriptorSetLayout set_layout_; + VkDevice device_; + VkDescriptorPool pool_; + std::vector descriptors_; + size_t in_use_; + + public: + VkDescriptorSet get_descriptor_set(); + + private: + void allocate_new_batch(); +}; + +struct DescriptorPoolConfig final { + // Overall Pool capacity + uint32_t descriptorPoolMaxSets; + // DescriptorCounts by type + uint32_t descriptorUniformBufferCount; + uint32_t descriptorStorageBufferCount; + uint32_t descriptorCombinedSamplerCount; + uint32_t descriptorStorageImageCount; + // Pile size for pre-allocating descriptor sets + uint32_t descriptorPileSizes; +}; + +class DescriptorPool final { + public: + explicit DescriptorPool(const VkDevice, const DescriptorPoolConfig&); + + DescriptorPool(const DescriptorPool&) = delete; + DescriptorPool& operator=(const DescriptorPool&) = delete; + + DescriptorPool(DescriptorPool&&) = delete; + DescriptorPool& operator=(DescriptorPool&&) = delete; + + ~DescriptorPool(); + + private: + VkDevice device_; + VkDescriptorPool pool_; + DescriptorPoolConfig config_; + // New Descriptors + std::mutex mutex_; + ska::flat_hash_map piles_; + + public: + DescriptorSet get_descriptor_set( + const VkDescriptorSetLayout handle, + const ShaderLayout::Signature& signature); + + void flush(); }; } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Helper.cpp b/aten/src/ATen/native/vulkan/api/Helper.cpp deleted file mode 100644 index dba7d5aa8710b..0000000000000 --- a/aten/src/ATen/native/vulkan/api/Helper.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { -namespace helper { - -#ifdef USE_VULKAN_API - -void copy_texture_to_texture( - api::Command::Buffer& command_buffer, - api::Resource::Image::Object& src_image, - api::Resource::Image::Object& dst_image, - api::utils::uvec3 copy_extents, - api::utils::uvec3 src_offset, - api::utils::uvec3 dst_offset) { - VkImageCopy copy_info{}; - copy_info.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; - copy_info.srcSubresource.layerCount = 1; - copy_info.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; - copy_info.dstSubresource.layerCount = 1; - copy_info.extent.width = copy_extents.data[0u]; - copy_info.extent.height = copy_extents.data[1u]; - copy_info.extent.depth = copy_extents.data[2u]; - copy_info.srcOffset.x = src_offset.data[0u]; - copy_info.srcOffset.y = src_offset.data[1u]; - copy_info.srcOffset.z = src_offset.data[2u]; - copy_info.dstOffset.x = dst_offset.data[0u]; - copy_info.dstOffset.y = dst_offset.data[1u]; - copy_info.dstOffset.z = dst_offset.data[2u]; - - // To use vkCmdCopyImage, the stage of src & dst image must be set to vTensor::Stage::Transfer. - vkCmdCopyImage( - command_buffer.handle(), - src_image.handle, VK_IMAGE_LAYOUT_GENERAL, - dst_image.handle, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, - 1, - ©_info); -} - -#endif /* USE_VULKAN_API */ - -} // namespace helper -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Helper.h b/aten/src/ATen/native/vulkan/api/Helper.h deleted file mode 100644 index 60d8560974be2..0000000000000 --- a/aten/src/ATen/native/vulkan/api/Helper.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#ifdef USE_VULKAN_API - -#include -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { -namespace helper { -// -// Copy Texture -// - -void copy_texture_to_texture( - api::Command::Buffer& command_buffer, - api::Resource::Image::Object& src_image, - api::Resource::Image::Object& dst_image, - api::utils::uvec3 copy_extents, - api::utils::uvec3 src_offset, - api::utils::uvec3 dst_offset); - -} // namespace utils -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/OpProfiler.h b/aten/src/ATen/native/vulkan/api/OpProfiler.h deleted file mode 100644 index b38b5dc957294..0000000000000 --- a/aten/src/ATen/native/vulkan/api/OpProfiler.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#ifdef USE_VULKAN_API - -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -class OpProfiler final { - public: - explicit OpProfiler(Command::Buffer& buffer, QueryPool& querypool, const std::string& query_name) - : buffer_(buffer), - querypool_(querypool) { - query_index_ = querypool.begin(buffer_.handle(), query_name); - } - OpProfiler(const OpProfiler&) = delete; - OpProfiler(OpProfiler&&) = delete; - OpProfiler& operator=(const OpProfiler&) = delete; - OpProfiler& operator=(OpProfiler&&) = delete; - ~OpProfiler() { - querypool_.end(buffer_.handle(), query_index_); - } - -private: - Command::Buffer& buffer_; - QueryPool& querypool_; - int query_index_; -}; - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 89e85892ee0c3..dcf87ea9d43a5 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -5,172 +5,337 @@ namespace native { namespace vulkan { namespace api { -Pipeline::Layout::Factory::Factory(const GPU& gpu) - : device_(gpu.device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); +// +// Utility Functions +// + +VkAccessFlags vk_access( + const PipelineStageFlags stage, + const MemoryAccessFlags access) { + VkAccessFlags vk_access = 0u; + + if (access & MemoryAccessType::READ) { + if (stage & PipelineStage::COMPUTE) { + vk_access |= VK_ACCESS_SHADER_READ_BIT; + } + + if (stage & PipelineStage::HOST) { + vk_access |= VK_ACCESS_HOST_READ_BIT; + } + + if (stage & PipelineStage::TRANSFER) { + vk_access |= VK_ACCESS_TRANSFER_READ_BIT; + } + } + + if (access & MemoryAccessType::WRITE) { + if (stage & PipelineStage::COMPUTE) { + vk_access |= VK_ACCESS_SHADER_WRITE_BIT; + } + + if (stage & PipelineStage::HOST) { + vk_access |= VK_ACCESS_HOST_WRITE_BIT; + } + + if (stage & PipelineStage::TRANSFER) { + vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT; + } + } + + return vk_access; } -typename Pipeline::Layout::Factory::Handle Pipeline::Layout::Factory::operator()( - const Descriptor& descriptor) const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor.descriptor_set_layout, - "Invalid Vulkan descriptor set layout!"); +VkPipelineStageFlags vk_stage(const PipelineStageFlags stage) { + VkPipelineStageFlags vk_stage = 0u; - const VkPipelineLayoutCreateInfo pipeline_layout_create_info{ - VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, - nullptr, - 0u, - 1u, - &descriptor.descriptor_set_layout, - 0u, - nullptr, - }; + if (stage & PipelineStage::COMPUTE) { + vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; + } - VkPipelineLayout pipeline_layout{}; - VK_CHECK(vkCreatePipelineLayout( - device_, - &pipeline_layout_create_info, - nullptr, - &pipeline_layout)); + if (stage & PipelineStage::HOST) { + vk_stage |= VK_PIPELINE_STAGE_HOST_BIT; + } - TORCH_CHECK( - pipeline_layout, - "Invalid Vulkan pipeline layout!"); + if (stage & PipelineStage::TRANSFER) { + vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT; + } - return Handle{ - pipeline_layout, - Deleter(device_), - }; + return vk_stage; } -namespace { +VkImageLayout vk_layout( + const PipelineStageFlags stage, + const MemoryAccessFlags access) { + switch (stage) { + case PipelineStage::COMPUTE: + switch (access) { + case MemoryAccessType::READ: + return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL; + default: + return VK_IMAGE_LAYOUT_GENERAL; + } + break; -VkPipelineCache create_pipeline_cache(const VkDevice device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); + case PipelineStage::TRANSFER: + switch (access) { + case MemoryAccessType::READ: + return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL; - const VkPipelineCacheCreateInfo pipeline_cache_create_info{ - VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, - nullptr, - 0u, - 0u, - nullptr, - }; + case MemoryAccessType::WRITE: + return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL; - VkPipelineCache pipeline_cache{}; - VK_CHECK(vkCreatePipelineCache( - device, - &pipeline_cache_create_info, - nullptr, - &pipeline_cache)); + default: + TORCH_INTERNAL_ASSERT(false, "Invalid!"); + } + break; - TORCH_CHECK( - pipeline_cache, - "Invalid Vulkan pipeline cache!"); + default: + TORCH_INTERNAL_ASSERT(false, "Invalid!"); + } - return pipeline_cache; + return VK_IMAGE_LAYOUT_UNDEFINED; } -} // namespace +// +// PipelineLayout +// -Pipeline::Factory::Factory(const GPU& gpu) - : device_(gpu.device), - pipeline_cache_( - create_pipeline_cache(device_), - VK_DELETER(PipelineCache)(device_)) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); +PipelineLayout::PipelineLayout( + const VkDevice device, + const VkDescriptorSetLayout descriptor_layout) + : device_(device), handle_{VK_NULL_HANDLE} { + // TODO: Enable push constants + const VkPipelineLayoutCreateInfo pipeline_layout_create_info{ + VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + 1u, // setLayoutCount + &descriptor_layout, // pSetLayouts + 0u, // pushConstantRangeCount + nullptr, // pPushConstantRanges + }; - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pipeline_cache_, - "Invalid Vulkan pipeline cache!"); + VK_CHECK(vkCreatePipelineLayout( + device_, &pipeline_layout_create_info, nullptr, &handle_)); +} + +PipelineLayout::PipelineLayout(PipelineLayout&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; } -typename Pipeline::Factory::Handle Pipeline::Factory::operator()( - const Descriptor& descriptor) const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor.pipeline_layout, - "Invalid Vulkan pipeline layout!"); +PipelineLayout::~PipelineLayout() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroyPipelineLayout(device_, handle_, nullptr); + handle_ = VK_NULL_HANDLE; +} - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor.shader_module, - "Invalid Vulkan shader module!"); +void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkPipelineLayout tmp_handle = lhs.handle_; + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; + + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; +} + +// +// ComputePipeline +// + +ComputePipeline::ComputePipeline( + const VkDevice device, + const ComputePipeline::Descriptor& descriptor, + const VkPipelineCache pipeline_cache) + : device_(device), handle_{VK_NULL_HANDLE} { constexpr VkSpecializationMapEntry specialization_map_entires[3]{ - // X - { - 0u, - offsetof(Shader::WorkGroup, data[0u]), - sizeof(Shader::WorkGroup::data[0u]), - }, - // Y - { - 1u, - offsetof(Shader::WorkGroup, data[1u]), - sizeof(Shader::WorkGroup::data[1u]), - }, - // Z - { - 2u, - offsetof(Shader::WorkGroup, data[2u]), - sizeof(Shader::WorkGroup::data[2u]), - }, + // X + { + 0u, + offsetof(utils::uvec3, data[0u]), + sizeof(utils::uvec3::data[0u]), + }, + // Y + { + 1u, + offsetof(utils::uvec3, data[1u]), + sizeof(utils::uvec3::data[1u]), + }, + // Z + { + 2u, + offsetof(utils::uvec3, data[2u]), + sizeof(utils::uvec3::data[2u]), + }, }; const VkSpecializationInfo specialization_info{ - 3u, - specialization_map_entires, - sizeof(descriptor.local_work_group), - &descriptor.local_work_group, + 3u, // mapEntryCount + specialization_map_entires, // pMapEntries + sizeof(descriptor.local_work_group), // dataSize + &descriptor.local_work_group, // pData + }; + + const VkPipelineShaderStageCreateInfo shader_stage_create_info{ + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + VK_SHADER_STAGE_COMPUTE_BIT, // stage + descriptor.shader_module, // module + "main", // pName + &specialization_info, // pSpecializationInfo }; const VkComputePipelineCreateInfo compute_pipeline_create_info{ - VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, - nullptr, - 0u, - VkPipelineShaderStageCreateInfo{ - VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, - nullptr, - 0u, - VK_SHADER_STAGE_COMPUTE_BIT, - descriptor.shader_module, - "main", - &specialization_info, - }, - descriptor.pipeline_layout, - VK_NULL_HANDLE, - 0u, + VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + shader_stage_create_info, // stage + descriptor.pipeline_layout, // layout + VK_NULL_HANDLE, // basePipelineHandle + 0u, // basePipelineIndex }; - VkPipeline pipeline{}; VK_CHECK(vkCreateComputePipelines( device_, - pipeline_cache_.get(), + pipeline_cache, 1u, &compute_pipeline_create_info, nullptr, - &pipeline)); + &handle_)); +} + +ComputePipeline::ComputePipeline(ComputePipeline&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; +} + +ComputePipeline::~ComputePipeline() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroyPipeline(device_, handle_, nullptr); + handle_ = VK_NULL_HANDLE; +} + +void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkPipeline tmp_handle = lhs.handle_; + + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; + + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; +} + +bool operator==( + const ComputePipeline::Descriptor& _1, + const ComputePipeline::Descriptor& _2) { + return ( + _1.pipeline_layout == _2.pipeline_layout && + _1.shader_module == _2.shader_module && + _1.local_work_group == _2.local_work_group); +} + +// +// PipelineLayoutCache +// + +PipelineLayoutCache::PipelineLayoutCache(const VkDevice device) + : cache_mutex_{}, device_(device), cache_{} {} + +PipelineLayoutCache::PipelineLayoutCache(PipelineLayoutCache&& other) noexcept + : cache_mutex_{}, device_(other.device_) { + std::lock_guard lock(other.cache_mutex_); + cache_ = std::move(other.cache_); +} + +PipelineLayoutCache::~PipelineLayoutCache() { + purge(); +} - TORCH_CHECK( - pipeline, - "Invalid Vulkan pipeline!"); +VkPipelineLayout PipelineLayoutCache::retrieve( + const PipelineLayoutCache::Key& key) { + std::lock_guard lock(cache_mutex_); - return Handle{ - pipeline, - Deleter(device_), + auto it = cache_.find(key); + if C10_UNLIKELY (cache_.cend() == it) { + it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first; + } + + return it->second.handle(); +} + +void PipelineLayoutCache::purge() { + std::lock_guard lock(cache_mutex_); + cache_.clear(); +} + +// +// ComputePipelineCache +// + +ComputePipelineCache::ComputePipelineCache(const VkDevice device) + : cache_mutex_{}, + device_(device), + pipeline_cache_{VK_NULL_HANDLE}, + cache_{} { + const VkPipelineCacheCreateInfo pipeline_cache_create_info{ + VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + 0u, // initialDataSize + nullptr, // pInitialData }; + + VK_CHECK(vkCreatePipelineCache( + device, &pipeline_cache_create_info, nullptr, &pipeline_cache_)); } -Pipeline::Cache::Cache(Factory factory) - : cache_(std::move(factory)) { +ComputePipelineCache::ComputePipelineCache( + ComputePipelineCache&& other) noexcept + : cache_mutex_{}, + device_(other.device_), + pipeline_cache_(other.pipeline_cache_) { + std::lock_guard lock(other.cache_mutex_); + cache_ = std::move(other.cache_); + + other.pipeline_cache_ = VK_NULL_HANDLE; +} + +ComputePipelineCache::~ComputePipelineCache() { + purge(); + + if C10_LIKELY (VK_NULL_HANDLE == pipeline_cache_) { + return; + } + vkDestroyPipelineCache(device_, pipeline_cache_, nullptr); + pipeline_cache_ = VK_NULL_HANDLE; +} + +VkPipeline ComputePipelineCache::retrieve( + const ComputePipelineCache::Key& key) { + std::lock_guard lock(cache_mutex_); + + auto it = cache_.find(key); + if C10_UNLIKELY (cache_.cend() == it) { + it = cache_ + .insert( + {key, + ComputePipelineCache::Value(device_, key, pipeline_cache_)}) + .first; + } + + return it->second.handle(); } -void Pipeline::Cache::purge() { - cache_.purge(); +void ComputePipelineCache::purge() { + cache_.clear(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h index bbff4fa914a37..f53a414f32584 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.h +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -3,7 +3,6 @@ #ifdef USE_VULKAN_API #include -#include #include #include #include @@ -13,231 +12,177 @@ namespace native { namespace vulkan { namespace api { -// -// This struct defines pipeline, and pipeline layout, caches intended to minimize -// redundant object reconstructions at the cost of extra memory consumption. -// -// A Vulkan pipeline contains the entirety of states, as one coherent monolithic -// bundle, required to configure the GPU's execution pipeline. This usage -// pattern minimizes driver overhead, promotes pipeline state reuse, and is a -// departure from, and in direct contrast with, OpenGL's individually confiurable -// state machine. -// -// A Vulkan pipeline layout represents a sequence of Vulkan descriptor sets each -// having a specific layout, and deterimines the interface between all shader -// stages and shader resources. For more information on shaders and shader -// layouts check the description of at::navie::vulkan::api::Shader. -// -// This struct defines the facilities required to create, reuse, and destruct -// these Vulkan objects. -// +struct PipelineBarrier final { + struct Stages final { + VkPipelineStageFlags src; + VkPipelineStageFlags dst; + } stage; -struct Pipeline final { - // - // Barrier - // + c10::SmallVector buffers; + c10::SmallVector images; - struct Barrier final { - struct Stage final { - VkPipelineStageFlags src; - VkPipelineStageFlags dst; - } stage; + inline operator bool() const { + return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() || + !images.empty(); + } +}; - c10::SmallVector buffers; - c10::SmallVector images; +using PipelineStageFlags = uint8_t; - operator bool() const; - }; +enum PipelineStage : PipelineStageFlags { + NO_STAGE = 0u << 0u, + COMPUTE = 1u << 0u, + HOST = 1u << 1u, + TRANSFER = 1u << 2u, +}; - // - // Layout - // +VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags); +VkPipelineStageFlags vk_stage(const PipelineStageFlags); +VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags); - struct Layout final { - /* - Descriptor - */ +class PipelineLayout final { + public: + explicit PipelineLayout(const VkDevice, const VkDescriptorSetLayout); - struct Descriptor final { - VkDescriptorSetLayout descriptor_set_layout; - }; + PipelineLayout(const PipelineLayout&) = delete; + PipelineLayout& operator=(const PipelineLayout&) = delete; - /* - Factory - */ + PipelineLayout(PipelineLayout&&) noexcept; + PipelineLayout& operator=(PipelineLayout&&) = delete; - class Factory final { - public: - explicit Factory(const GPU& gpu); + ~PipelineLayout(); - typedef Layout::Descriptor Descriptor; - typedef VK_DELETER(PipelineLayout) Deleter; - typedef api::Handle Handle; + private: + VkDevice device_; + VkPipelineLayout handle_; - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; + public: + VkPipelineLayout handle() const { + return handle_; + } - Handle operator()(const Descriptor& descriptor) const; + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. + friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept; +}; - private: - VkDevice device_; - }; +class ComputePipeline final { + public: + struct Descriptor final { + VkPipelineLayout pipeline_layout; + VkShaderModule shader_module; + utils::uvec3 local_work_group; + }; - /* - Cache - */ + explicit ComputePipeline( + const VkDevice device, + const Descriptor& descriptor, + const VkPipelineCache pipeline_cache); - typedef api::Cache Cache; - Cache cache; + ComputePipeline(const ComputePipeline&) = delete; + ComputePipeline& operator=(const ComputePipeline&) = delete; - explicit Layout(const GPU& gpu) - : cache(Factory(gpu)) { - } - } layout; + ComputePipeline(ComputePipeline&&) noexcept; + ComputePipeline& operator=(ComputePipeline&&) = delete; - // - // Stage - // + ~ComputePipeline(); - struct Stage final { - typedef uint8_t Flags; + private: + VkDevice device_; + VkPipeline handle_; - enum Type : Flags { - None = 0u << 0u, - Compute = 1u << 0u, - Host = 1u << 1u, - Transfer = 1u << 2u, - }; - }; + public: + inline VkPipeline handle() const { + return handle_; + } - /* - Descriptor - */ + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. + friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept; +}; - struct Descriptor final { - VkPipelineLayout pipeline_layout; - VkShaderModule shader_module; - Shader::WorkGroup local_work_group; +class PipelineLayoutCache final { + public: + explicit PipelineLayoutCache(const VkDevice device); + + PipelineLayoutCache(const PipelineLayoutCache&) = delete; + PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete; + + PipelineLayoutCache(PipelineLayoutCache&&) noexcept; + PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete; + + ~PipelineLayoutCache(); + + using Key = VkDescriptorSetLayout; + using Value = PipelineLayout; + + struct Hasher { + inline size_t operator()( + const VkDescriptorSetLayout descriptor_layout) const { + return c10::get_hash(descriptor_layout); + } }; - /* - Factory - */ + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; - class Factory final { - public: - explicit Factory(const GPU& gpu); + VkDevice device_; + ska::flat_hash_map cache_; - typedef Pipeline::Descriptor Descriptor; - typedef VK_DELETER(Pipeline) Deleter; - typedef api::Handle Handle; + public: + VkPipelineLayout retrieve(const Key&); + void purge(); +}; - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; +class ComputePipelineCache final { + public: + explicit ComputePipelineCache(const VkDevice device); - Handle operator()(const Descriptor& descriptor) const; + ComputePipelineCache(const ComputePipelineCache&) = delete; + ComputePipelineCache& operator=(const ComputePipelineCache&) = delete; - private: - VkDevice device_; - api::Handle pipeline_cache_; - }; + ComputePipelineCache(ComputePipelineCache&&) noexcept; + ComputePipelineCache& operator=(ComputePipelineCache&&) = delete; - /* - Object - */ + ~ComputePipelineCache(); - struct Object final { - VkPipeline handle; - VkPipelineLayout layout; - Shader::WorkGroup local_work_group; + using Key = ComputePipeline::Descriptor; + using Value = ComputePipeline; - operator bool() const; + struct Hasher { + inline size_t operator()( + const ComputePipeline::Descriptor& descriptor) const { + return c10::get_hash( + descriptor.pipeline_layout, + descriptor.shader_module, + descriptor.local_work_group.data[0u], + descriptor.local_work_group.data[1u], + descriptor.local_work_group.data[2u]); + }; }; - /* - Cache - */ - - class Cache final { - public: - explicit Cache(Factory factory); - Cache(const Cache&) = delete; - Cache& operator=(const Cache&) = delete; - Cache(Cache&&) = default; - Cache& operator=(Cache&&) = default; - ~Cache() = default; - - Object retrieve(const Descriptor& descriptor); - void purge(); - - private: - api::Cache cache_; - } cache; - - explicit Pipeline(const GPU& gpu) - : layout(gpu), - cache(Factory(gpu)) { - } + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; + + VkDevice device_; + VkPipelineCache pipeline_cache_; + ska::flat_hash_map cache_; + + public: + VkPipeline retrieve(const Key&); + void purge(); }; // // Impl // -inline Pipeline::Barrier::operator bool() const { - return (0u != stage.src) || - (0u != stage.dst) || - !buffers.empty() || - !images.empty(); -} - -inline bool operator==( - const Pipeline::Layout::Descriptor& _1, - const Pipeline::Layout::Descriptor& _2) { - - return (_1.descriptor_set_layout == _2.descriptor_set_layout); -} - -inline size_t Pipeline::Layout::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - return c10::get_hash(descriptor.descriptor_set_layout); -} - -inline bool operator==( - const Pipeline::Descriptor& _1, - const Pipeline::Descriptor& _2) { - - return (_1.pipeline_layout == _2.pipeline_layout && \ - _1.shader_module == _2.shader_module && \ - _1.local_work_group == _2.local_work_group); -} - -inline size_t Pipeline::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - return c10::get_hash( - descriptor.pipeline_layout, - descriptor.shader_module, - descriptor.local_work_group.data[0u], - descriptor.local_work_group.data[1u], - descriptor.local_work_group.data[2u]); -} - -inline Pipeline::Object::operator bool() const { - return (VK_NULL_HANDLE != handle) && - (VK_NULL_HANDLE != layout); -} - -inline Pipeline::Object Pipeline::Cache::retrieve( - const Descriptor& descriptor) { - return { - cache_.retrieve(descriptor), - descriptor.pipeline_layout, - descriptor.local_work_group, - }; -} - } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/QueryPool.cpp b/aten/src/ATen/native/vulkan/api/QueryPool.cpp index 9e12e3be3e3fa..1ac8d2b7a21ae 100644 --- a/aten/src/ATen/native/vulkan/api/QueryPool.cpp +++ b/aten/src/ATen/native/vulkan/api/QueryPool.cpp @@ -1,117 +1,190 @@ #include +#include #include +#include + namespace at { namespace native { namespace vulkan { namespace api { -namespace { - -VkQueryPool create_query_pool(const VkDevice& device, const uint32_t queryCount) { - VkQueryPool queryPool{}; - VkQueryPoolCreateInfo info{}; - info.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; - info.queryType = VK_QUERY_TYPE_TIMESTAMP; - info.queryCount = queryCount; - VK_CHECK(vkCreateQueryPool(device, &info, nullptr, &queryPool)); - return queryPool; -}; - -void destroy_query_pool(const VkDevice& device, const VkQueryPool& querypool) { - if (VK_NULL_HANDLE != device && VK_NULL_HANDLE != querypool) { - vkDestroyQueryPool(device, querypool, nullptr); - } -} -} // namespace - -QueryPool::QueryPool(const VkDevice& device, const bool is_timestamps_supported, const float timestamp_period_us) - : device_(device), - is_timestamps_supported_(is_timestamps_supported), - timestamp_period_us_(timestamp_period_us), - querypool_(VK_NULL_HANDLE) { +QueryPool::QueryPool(const VkDevice device, const QueryPoolConfig& config) + : mutex_{}, + device_(device), + config_(config), + querypool_(VK_NULL_HANDLE), + shader_log_{}, + in_use_(0u) { + const VkQueryPoolCreateInfo info{ + VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + VK_QUERY_TYPE_TIMESTAMP, // queryType + config_.maxQueryCount, // queryCount + 0u, // pipelineStatistics + }; + + VK_CHECK(vkCreateQueryPool(device_, &info, nullptr, &querypool_)); + + shader_log_.reserve(config_.initialReserveSize); } QueryPool::~QueryPool() { - destroy_query_pool(device_, querypool_); - querypool_ = VK_NULL_HANDLE; - query_names_.clear(); + if (VK_NULL_HANDLE == querypool_) { + return; + } + vkDestroyQueryPool(device_, querypool_, nullptr); + shader_log_.clear(); } -bool QueryPool::is_enabled() const { - return VK_NULL_HANDLE != querypool_; +void QueryPool::reset(const CommandBuffer& cmd) { + std::lock_guard lock(mutex_); + cmd.reset_querypool(querypool_, 0u, in_use_); + in_use_ = 0u; + shader_log_.clear(); } -bool QueryPool::enable() { - TORCH_CHECK(VK_NULL_HANDLE == querypool_, "The query pool already exists."); - TORCH_CHECK(is_timestamps_supported_, "The device doesn't support for timestamps on all graphics and compute queues."); - querypool_ = create_query_pool(device_, Configuration::kMaxQueryCount); - return is_enabled(); +uint32_t QueryPool::write_timestamp(const CommandBuffer& cmd) { + TORCH_CHECK( + in_use_ < config_.maxQueryCount, + "Vulkan QueryPool: Exceeded the maximum number of queries " + "allowed by the queryPool (", + config_.maxQueryCount, + ")!"); + + cmd.write_timestamp(querypool_, in_use_); + + return in_use_++; } -std::vector QueryPool::disable(const bool waitfor_allqueries/* = true*/) { - auto out = result(waitfor_allqueries); - destroy_query_pool(device_, querypool_); - querypool_ = VK_NULL_HANDLE; - query_names_.clear(); - return out; +uint32_t QueryPool::shader_profile_begin( + const CommandBuffer& cmd, + const std::string& kernel_name, + const VkExtent3D global_workgroup_size, + const VkExtent3D local_workgroup_size) { + std::lock_guard lock(mutex_); + + uint32_t query_idx = write_timestamp(cmd); + + uint32_t log_idx = shader_log_.size(); + ShaderDuration log_entry{ + log_idx, + // Execution Properties + kernel_name, + global_workgroup_size, + local_workgroup_size, + // Query indexes + query_idx, // start query idx + UINT32_MAX, // end query idx + // Timings + 0u, // start time + 0u, // end time + 0u, // duration + }; + + shader_log_.emplace_back(log_entry); + + return log_idx; } -int QueryPool::begin(const VkCommandBuffer& commandBuffer, const std::string& query_name) { - if (VK_NULL_HANDLE == querypool_ || VK_NULL_HANDLE == commandBuffer) { - return -1; - } - auto newQueryIndex = static_cast(query_names_.size()); - TORCH_CHECK(newQueryIndex < Configuration::kMaxQueryCount, "The query index cannot exceed Configuration::kMaxQueryCount."); - query_names_.push_back(query_name); +void QueryPool::shader_profile_end( + const CommandBuffer& cmd, + const uint32_t log_idx) { + std::lock_guard lock(mutex_); + + uint32_t query_idx = write_timestamp(cmd); - vkCmdWriteTimestamp( - commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, querypool_, newQueryIndex * Configuration::kTimestampsPerQuery); - return static_cast(newQueryIndex); + shader_log_[log_idx].end_query_idx = query_idx; } -void QueryPool::end(const VkCommandBuffer& commandBuffer, const int queryIndex) { - if (VK_NULL_HANDLE == querypool_ || VK_NULL_HANDLE == commandBuffer) { - return; +void QueryPool::extract_results() { + std::lock_guard lock(mutex_); + + const VkQueryResultFlags flags = VK_QUERY_RESULT_64_BIT; + + std::vector query_data; + query_data.resize(in_use_); + + VK_CHECK(vkGetQueryPoolResults( + device_, + querypool_, + 0u, // firstQuery + in_use_, // queryCount + sizeof(uint64_t) * in_use_, // dataSize + query_data.data(), // pData + sizeof(uint64_t), // stride + flags)); // flags + + for (ShaderDuration& entry : shader_log_) { + entry.start_time_ns = query_data.at(entry.start_query_idx); + entry.end_time_ns = query_data.at(entry.end_query_idx); + + entry.execution_duration_ns = entry.end_time_ns - entry.start_time_ns; } - vkCmdWriteTimestamp( - commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, querypool_, static_cast(queryIndex) * Configuration::kTimestampsPerQuery + 1u); } -std::vector QueryPool::result(const bool waitfor_allqueries) const { - if (VK_NULL_HANDLE == querypool_) { - return std::vector {}; +std::ostream& operator<<(std::ostream& os, const VkExtent3D& extents) { + os << "{" << extents.width << ", " << extents.height << ", " << extents.depth + << "}"; + return os; +} + +std::string stringize(const VkExtent3D& extents) { + std::stringstream ss; + ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth + << "}"; + return ss.str(); +} + +std::string QueryPool::generate_string_report() { + std::lock_guard lock(mutex_); + + std::stringstream ss; + + int kernel_name_w = 25; + int global_size_w = 15; + int duration_w = 25; + + ss << std::left; + ss << std::setw(kernel_name_w) << "Kernel Name"; + ss << std::setw(global_size_w) << "Workgroup Size"; + ss << std::right << std::setw(duration_w) << "Duration (ns)"; + ss << std::endl; + + ss << std::left; + ss << std::setw(kernel_name_w) << "==========="; + ss << std::setw(global_size_w) << "=============="; + ss << std::right << std::setw(duration_w) << "==========="; + ss << std::endl; + + for (ShaderDuration& entry : shader_log_) { + std::chrono::duration exec_duration_ns( + entry.execution_duration_ns); + + ss << std::left; + ss << std::setw(kernel_name_w) << entry.kernel_name; + ss << std::setw(global_size_w) << stringize(entry.global_workgroup_size); + ss << std::right << std::setw(duration_w) << exec_duration_ns.count(); + ss << std::endl; } - std::vector perfInfo; - const VkQueryResultFlags flags = waitfor_allqueries ? (VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT) : VK_QUERY_RESULT_64_BIT; - std::array counter_data{}; - for (uint32_t queryIndex = 0u; queryIndex < query_names_.size(); ++queryIndex) { - const auto& query_name = query_names_[queryIndex]; - - // Grab the gpu timings (nanoseconds) - auto ret = vkGetQueryPoolResults(device_, querypool_, queryIndex * Configuration::kTimestampsPerQuery, Configuration::kTimestampsPerQuery, - sizeof(uint64_t) * counter_data.size(), counter_data.data(), sizeof(uint64_t), - flags); - if (ret != VK_SUCCESS) { - std::stringstream msg; - msg << "vkGetQueryPoolResults() for \"" << query_name << "\"" << " returned an error code " << ret << "."; - TORCH_WARN(msg.str()); - continue; - } + return ss.str(); +} - // Tally up GPU time - int64_t gpu_time_us = static_cast( - (static_cast(counter_data[1] - counter_data[0]) * - timestamp_period_us_) / 1'000.f); // convert ns to us - - perfInfo.emplace_back(QueryPool::PerfInfo { - query_name, - static_cast(static_cast(counter_data[0]) * timestamp_period_us_ / 1'000.f), - static_cast(static_cast(counter_data[1]) * timestamp_period_us_ / 1'000.f), - gpu_time_us }); - } - return perfInfo; +void QueryPool::print_results() { + std::cout << generate_string_report() << std::endl; +} + +uint64_t QueryPool::get_total_op_ns(std::string op_name) { + std::lock_guard lock(mutex_); + uint64_t sum = 0; + for (ShaderDuration& entry : shader_log_) { + if (entry.kernel_name == op_name) { + sum += entry.execution_duration_ns; + } + } + return sum; } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/QueryPool.h b/aten/src/ATen/native/vulkan/api/QueryPool.h index edabba7fa7050..12d9832e3dc46 100644 --- a/aten/src/ATen/native/vulkan/api/QueryPool.h +++ b/aten/src/ATen/native/vulkan/api/QueryPool.h @@ -5,48 +5,81 @@ #include #include #include +#include namespace at { namespace native { namespace vulkan { namespace api { +struct QueryPoolConfig final { + uint32_t maxQueryCount; + uint32_t initialReserveSize; +}; + +struct ShaderDuration final { + uint32_t idx; + + // Execution Properties + std::string kernel_name; + VkExtent3D global_workgroup_size; + VkExtent3D local_workgroup_size; + + // Query indexes + uint32_t start_query_idx; + uint32_t end_query_idx; + + // Timings + uint64_t start_time_ns; + uint64_t end_time_ns; + uint64_t execution_duration_ns; +}; + class QueryPool final { public: - explicit QueryPool(const VkDevice& device, const bool is_timestamps_supported, const float timestamp_period_us); + explicit QueryPool(const VkDevice, const QueryPoolConfig&); + QueryPool(const QueryPool&) = delete; - QueryPool(QueryPool&&) = default; QueryPool& operator=(const QueryPool&) = delete; - QueryPool& operator=(QueryPool&&) = default; + + QueryPool(QueryPool&&) = delete; + QueryPool& operator=(QueryPool&&) = delete; + ~QueryPool(); -public: - struct PerfInfo final { - std::string query_name; - int64_t start_time_us; - int64_t end_time_us; - int64_t execution_time_us; - }; - - struct Configuration final { - static constexpr uint32_t kTimestampsPerQuery = 2u; - static constexpr uint32_t kMaxQueryCount = 65536u; - }; - -public: - bool is_enabled() const; - bool enable(); - std::vector disable(const bool waitfor_allqueries = true); - int begin(const VkCommandBuffer& commandBuffer, const std::string& query_name); - void end(const VkCommandBuffer& commandBuffer, const int queryIndex); - std::vector result(const bool waitfor_allqueries) const; - -private: + private: + std::mutex mutex_; + VkDevice device_; - bool is_timestamps_supported_; - float timestamp_period_us_; + QueryPoolConfig config_; + VkQueryPool querypool_; - std::vector query_names_; + + std::vector shader_log_; + size_t in_use_; + + private: + uint32_t write_timestamp(const CommandBuffer&); + + std::string generate_string_report(); + + public: + inline bool is_enabled() const { + return VK_NULL_HANDLE != querypool_; + } + + void reset(const CommandBuffer&); + + uint32_t shader_profile_begin( + const CommandBuffer&, + const std::string&, + const VkExtent3D, + const VkExtent3D); + void shader_profile_end(const CommandBuffer&, const uint32_t); + + void extract_results(); + void print_results(); + uint64_t get_total_op_ns(std::string op_name); }; } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index 520ccc87d5336..82b98579e051f 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -1,129 +1,161 @@ -#include #include +#include namespace at { namespace native { namespace vulkan { namespace api { -namespace { -VmaAllocator create_allocator( - const VkInstance instance, - const VkPhysicalDevice physical_device, - const VkDevice device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - instance, - "Invalid Vulkan instance!"); +// +// Utility Functions +// - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - physical_device, - "Invalid Vulkan physical device!"); +VkFormat vk_format(const caffe2::TypeMeta dtype) { + switch (c10::typeMetaToScalarType(dtype)) { + case kFloat: +#ifdef USE_VULKAN_FP16_INFERENCE + return VK_FORMAT_R16G16B16A16_SFLOAT; +#else + return VK_FORMAT_R32G32B32A32_SFLOAT; +#endif /* USE_VULKAN_FP16_INFERENCE */ - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); + case c10::kQUInt8: + return VK_FORMAT_R8G8B8A8_UINT; - const VmaAllocatorCreateInfo allocator_create_info{ - 0u, - physical_device, - device, - 0u, - nullptr, - nullptr, - 1u, - nullptr, - nullptr, - nullptr, - instance, - VK_API_VERSION_1_0, + default: + TORCH_CHECK(false, "Vulkan tensor format not supported!"); + } + return VK_FORMAT_UNDEFINED; +} +// +// MemoryBarrier +// + +MemoryBarrier::MemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags) + : handle{ + VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType + nullptr, // pNext + src_access_flags, // srcAccessMask + dst_access_flags, // dstAccessMask + } {} + +// +// VulkanBuffer +// + +VulkanBuffer::VulkanBuffer() + : memory_properties_{}, + buffer_properties_{}, + allocator_(VK_NULL_HANDLE), + allocation_(VK_NULL_HANDLE), + handle_(VK_NULL_HANDLE) {} + +VulkanBuffer::VulkanBuffer( + const VmaAllocator vma_allocator, + const VkDeviceSize size, + const VulkanBuffer::MemoryProperties& mem_props) + : memory_properties_(mem_props), + buffer_properties_({ + size, + 0u, + size, + }), + allocator_(vma_allocator), + allocation_(VK_NULL_HANDLE), + handle_(VK_NULL_HANDLE) { + const VkBufferCreateInfo buffer_create_info{ + VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + size, // size + memory_properties_.buffer_usage, // usage + VK_SHARING_MODE_EXCLUSIVE, // sharingMode + 0u, // queueFamilyIndexCount + nullptr, // pQueueFamilyIndices }; - VmaAllocator allocator{}; - VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator)); - TORCH_CHECK(allocator, "Invalid VMA (Vulkan Memory Allocator) allocator!"); + // TODO: enable creation with a custom pool + VmaAllocationCreateInfo alloc_create_info{ + memory_properties_.create_flags, // flags + memory_properties_.memory_usage, // usage + memory_properties_.required_mem_flags, // requiredFlags + memory_properties_.preferred_mem_flags, // preferredFlags + 0u, // memoryTypeBits + VK_NULL_HANDLE, // pool + nullptr, // pUserData + 0.5f, // priority + }; - return allocator; + VK_CHECK(vmaCreateBuffer( + allocator_, + &buffer_create_info, + &alloc_create_info, + &handle_, + &allocation_, + nullptr)); } -VmaAllocationCreateInfo create_allocation_create_info( - const Resource::Memory::Descriptor& descriptor) { - return VmaAllocationCreateInfo{ - VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT, - descriptor.usage, - descriptor.required, - descriptor.preferred, - 0u, - VK_NULL_HANDLE, - nullptr, - 0.5f, - }; +VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) noexcept + : memory_properties_(other.memory_properties_), + buffer_properties_(other.buffer_properties_), + allocator_(other.allocator_), + allocation_(other.allocation_), + handle_(other.handle_) { + other.allocation_ = VK_NULL_HANDLE; + other.handle_ = VK_NULL_HANDLE; } -} // namespace +VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) noexcept { + const VmaAllocation tmp_allocation = allocation_; + const VkBuffer tmp_buffer = handle_; -void release_buffer(const Resource::Buffer& buffer) { - // Safe to pass null as buffer or allocation. - vmaDestroyBuffer( - buffer.memory.allocator, - buffer.object.handle, - buffer.memory.allocation); -} - -void release_image(const Resource::Image& image) { - // Sampler is an immutable object. Its lifetime is managed through the cache. + memory_properties_ = other.memory_properties_; + buffer_properties_ = other.buffer_properties_; + allocator_ = other.allocator_; + allocation_ = other.allocation_; + handle_ = other.handle_; - if (VK_NULL_HANDLE != image.object.view) { - VmaAllocatorInfo allocator_info{}; - vmaGetAllocatorInfo(image.memory.allocator, &allocator_info); - vkDestroyImageView(allocator_info.device, image.object.view, nullptr); - } + other.allocation_ = tmp_allocation; + other.handle_ = tmp_buffer; - // Safe to pass null as image or allocation. - vmaDestroyImage( - image.memory.allocator, - image.object.handle, - image.memory.allocation); + return *this; } -void* map( - const Resource::Memory& memory, - const Resource::Memory::Access::Flags access) { - void* data = nullptr; - VK_CHECK(vmaMapMemory(memory.allocator, memory.allocation, &data)); - - if (access & Resource::Memory::Access::Read) { - // Call will be ignored by implementation if the memory type this allocation - // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior - // we want. - VK_CHECK(vmaInvalidateAllocation( - memory.allocator, memory.allocation, 0u, VK_WHOLE_SIZE)); +VulkanBuffer::~VulkanBuffer() { + if (VK_NULL_HANDLE != handle_) { + vmaDestroyBuffer(allocator_, handle_, allocation_); } - - return data; } -Resource::Memory::Scope::Scope( - const VmaAllocator allocator, - const VmaAllocation allocation, - const Access::Flags access) - : allocator_(allocator), - allocation_(allocation), - access_(access) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - allocator, - "Invalid VMA (Vulkan Memory Allocator) allocator!"); +// +// MemoryMap +// - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - allocation, - "Invalid VMA (Vulkan Memory Allocator) allocation!"); +MemoryMap::MemoryMap(const VulkanBuffer& buffer, const uint8_t access) + : access_(access), + allocator_(buffer.vma_allocator()), + allocation_(buffer.allocation()), + data_(nullptr) { + VK_CHECK(vmaMapMemory(allocator_, allocation_, &data_)); } -void Resource::Memory::Scope::operator()(const void* const data) const { - if (C10_UNLIKELY(!data)) { +MemoryMap::MemoryMap(MemoryMap&& other) noexcept + : access_(other.access_), + allocator_(other.allocator_), + allocation_(other.allocation_), + data_(other.data_) { + other.allocation_ = VK_NULL_HANDLE; + other.data_ = nullptr; +} + +MemoryMap::~MemoryMap() { + if (C10_UNLIKELY(!data_)) { return; } - if (access_ & Access::Write) { + if (access_ & MemoryAccessType::WRITE) { // Call will be ignored by implementation if the memory type this allocation // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior // we want. @@ -133,558 +165,529 @@ void Resource::Memory::Scope::operator()(const void* const data) const { vmaUnmapMemory(allocator_, allocation_); } -Resource::Image::Sampler::Factory::Factory(const GPU& gpu) - : device_(gpu.device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); +void MemoryMap::invalidate() { + if (access_ & MemoryAccessType::READ) { + // Call will be ignored by implementation if the memory type this allocation + // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior + // we want. + VK_CHECK( + vmaInvalidateAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE)); + } +} + +// +// BufferMemoryBarrier +// + +BufferMemoryBarrier::BufferMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VulkanBuffer& buffer) + : handle{ + VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, // sType + nullptr, // pNext + src_access_flags, // srcAccessMask + dst_access_flags, // dstAccessMask + VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex + VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex + buffer.handle_, // buffer + buffer.buffer_properties_.mem_offset, // offset + buffer.buffer_properties_.mem_range, // size + } {} + +// +// ImageSampler +// + +bool operator==( + const ImageSampler::Properties& _1, + const ImageSampler::Properties& _2) { + return ( + _1.filter == _2.filter && _1.mipmap_mode == _2.mipmap_mode && + _1.address_mode == _2.address_mode && _1.border_color == _2.border_color); } -typename Resource::Image::Sampler::Factory::Handle -Resource::Image::Sampler::Factory::operator()( - const Descriptor& descriptor) const { +ImageSampler::ImageSampler( + const VkDevice device, + const ImageSampler::Properties& props) + : device_(device), handle_(VK_NULL_HANDLE) { const VkSamplerCreateInfo sampler_create_info{ - VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, - nullptr, - 0u, - descriptor.filter, - descriptor.filter, - descriptor.mipmap_mode, - descriptor.address_mode, - descriptor.address_mode, - descriptor.address_mode, - 0.0f, - VK_FALSE, - 1.0f, - VK_FALSE, - VK_COMPARE_OP_NEVER, - 0.0f, - VK_LOD_CLAMP_NONE, - descriptor.border, - VK_FALSE, + VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + props.filter, // magFilter + props.filter, // minFilter + props.mipmap_mode, // mipmapMode + props.address_mode, // addressModeU + props.address_mode, // addressModeV + props.address_mode, // addressModeW + 0.0f, // mipLodBias + VK_FALSE, // anisotropyEnable + 1.0f, // maxAnisotropy, + VK_FALSE, // compareEnable + VK_COMPARE_OP_NEVER, // compareOp + 0.0f, // minLod + VK_LOD_CLAMP_NONE, // maxLod + props.border_color, // borderColor + VK_FALSE, // unnormalizedCoordinates }; - VkSampler sampler{}; - VK_CHECK(vkCreateSampler( - device_, - &sampler_create_info, - nullptr, - &sampler)); - - TORCH_CHECK( - sampler, - "Invalid Vulkan image sampler!"); - - return Handle{ - sampler, - Deleter(device_), - }; + VK_CHECK(vkCreateSampler(device_, &sampler_create_info, nullptr, &handle_)); } -VkFence Resource::Fence::handle(const bool add_to_waitlist) const { - if (!pool) { - return VK_NULL_HANDLE; - } - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - id < pool->fence_.pool.size(), - "Invalid Vulkan fence!"); - - const VkFence fence = pool->fence_.pool[id].get(); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - fence, - "Invalid Vulkan fence!"); +ImageSampler::ImageSampler(ImageSampler&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; +} - if (add_to_waitlist) { - pool->fence_.waitlist.push_back(fence); +ImageSampler::~ImageSampler() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; } + vkDestroySampler(device_, handle_, nullptr); +} - return fence; +size_t ImageSampler::Hasher::operator()( + const ImageSampler::Properties& props) const { + return c10::get_hash( + props.filter, props.mipmap_mode, props.address_mode, props.border_color); } -void Resource::Fence::wait(const uint64_t timeout_nanoseconds) { - const VkFence fence = handle(/* add_to_waitlist = */ false); +void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkSampler tmp_handle = lhs.handle_; - const auto waitlist_itr = std::find( - pool->fence_.waitlist.cbegin(), - pool->fence_.waitlist.cend(), - fence); + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; - if (pool->fence_.waitlist.cend() != waitlist_itr) { - VK_CHECK(vkWaitForFences( - pool->device_, - 1u, - &fence, - VK_TRUE, - timeout_nanoseconds)); + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; +} - VK_CHECK(vkResetFences( - pool->device_, - 1u, - &fence)); +// +// VulkanImage +// + +VulkanImage::VulkanImage() + : memory_properties_{}, + image_properties_{}, + view_properties_{}, + sampler_properties_{}, + allocator_(VK_NULL_HANDLE), + allocation_(VK_NULL_HANDLE), + handles_{ + VK_NULL_HANDLE, + VK_NULL_HANDLE, + VK_NULL_HANDLE, + }, + layout_{} {} + +VulkanImage::VulkanImage( + const VmaAllocator vma_allocator, + const VkDevice device, + const MemoryProperties& mem_props, + const ImageProperties& image_props, + const ViewProperties& view_props, + const SamplerProperties& sampler_props, + const VkImageLayout layout, + const VkSampler sampler) + : memory_properties_(mem_props), + image_properties_(image_props), + view_properties_(view_props), + sampler_properties_(sampler_props), + allocator_(vma_allocator), + allocation_(VK_NULL_HANDLE), + handles_{ + VK_NULL_HANDLE, + VK_NULL_HANDLE, + sampler, + }, + layout_(layout) { + const VkImageCreateInfo image_create_info{ + VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + image_properties_.image_type, // imageType + image_properties_.image_format, // format + image_properties_.image_extents, // extents + 1u, // mipLevels + 1u, // arrayLayers + VK_SAMPLE_COUNT_1_BIT, // samples + VK_IMAGE_TILING_OPTIMAL, // tiling + memory_properties_.image_usage, // usage + VK_SHARING_MODE_EXCLUSIVE, // sharingMode + 0u, // queueFamilyIndexCount + nullptr, // pQueueFamilyIndices + layout_, // initialLayout + }; - pool->fence_.waitlist.erase(waitlist_itr); - } -} + // TODO: enable creation with a custom pool + const VmaAllocationCreateInfo alloc_create_info{ + memory_properties_.create_flags, // flags + memory_properties_.memory_usage, // usage + memory_properties_.required_mem_flags, // requiredFlags + memory_properties_.preferred_mem_flags, // preferredFlags + 0u, // memoryTypeBits + VK_NULL_HANDLE, // pool + nullptr, // pUserData + 0.5f, // priority + }; -namespace { + VK_CHECK(vmaCreateImage( + allocator_, + &image_create_info, + &alloc_create_info, + &(handles_.image), + &allocation_, + nullptr)); -class Linear final : public Resource::Pool::Policy { - public: - Linear( - VkDeviceSize block_size, - uint32_t min_block_count, - uint32_t max_block_count); + // Image View - virtual void enact( - VmaAllocator allocator, - const VkMemoryRequirements& memory_requirements, - VmaAllocationCreateInfo& allocation_create_info) override; + const VkComponentMapping component_mapping{ + VK_COMPONENT_SWIZZLE_IDENTITY, // r + VK_COMPONENT_SWIZZLE_IDENTITY, // g + VK_COMPONENT_SWIZZLE_IDENTITY, // b + VK_COMPONENT_SWIZZLE_IDENTITY, // a + }; - private: - struct Configuration final { - static constexpr uint32_t kReserve = 16u; + const VkImageSubresourceRange subresource_range{ + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // baseMipLevel + VK_REMAINING_MIP_LEVELS, // levelCount + 0u, // baseArrayLayer + VK_REMAINING_ARRAY_LAYERS, // layerCount }; - struct Entry final { - class Deleter final { - public: - explicit Deleter(VmaAllocator); - void operator()(VmaPool) const; + const VkImageViewCreateInfo image_view_create_info{ + VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + handles_.image, // image + view_properties_.view_type, // viewType + view_properties_.view_format, // format + component_mapping, // components + subresource_range, // subresourceRange + }; - private: - VmaAllocator allocator_; - }; + VK_CHECK(vkCreateImageView( + device, &image_view_create_info, nullptr, &(handles_.image_view))); +} - uint32_t memory_type_index; - Handle handle; - }; +VulkanImage::VulkanImage(VulkanImage&& other) noexcept + : memory_properties_(other.memory_properties_), + image_properties_(other.image_properties_), + view_properties_(other.view_properties_), + sampler_properties_(other.sampler_properties_), + allocator_(other.allocator_), + allocation_(other.allocation_), + handles_(other.handles_), + layout_(other.layout_) { + other.allocation_ = VK_NULL_HANDLE; + other.handles_.image = VK_NULL_HANDLE; + other.handles_.image_view = VK_NULL_HANDLE; + other.handles_.sampler = VK_NULL_HANDLE; +} - std::vector pools_; - - struct { - VkDeviceSize size; - uint32_t min; - uint32_t max; - } block_; -}; - -Linear::Entry::Deleter::Deleter(const VmaAllocator allocator) - : allocator_(allocator) { -} - -void Linear::Entry::Deleter::operator()(const VmaPool pool) const { - vmaDestroyPool(allocator_, pool); -} - -Linear::Linear( - const VkDeviceSize block_size, - const uint32_t min_block_count, - const uint32_t max_block_count) - : block_ { - block_size, - min_block_count, - max_block_count, - } { - pools_.reserve(Configuration::kReserve); -} - -void Linear::enact( - const VmaAllocator allocator, - const VkMemoryRequirements& memory_requirements, - VmaAllocationCreateInfo& allocation_create_info) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - allocator, - "Invalid VMA (Vulkan Memory Allocator) allocator!"); - - uint32_t memory_type_index = 0u; - VK_CHECK(vmaFindMemoryTypeIndex( - allocator, - memory_requirements.memoryTypeBits, - &allocation_create_info, - &memory_type_index)); - - auto pool_itr = std::find_if( - pools_.begin(), - pools_.end(), - [memory_type_index](const Entry& entry) { - return entry.memory_type_index == memory_type_index; - }); - - if (pools_.end() == pool_itr) { - const VmaPoolCreateInfo pool_create_info{ - memory_type_index, - VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT, - block_.size, - block_.min, - block_.max, - 0u, - }; - - VmaPool pool{}; - VK_CHECK(vmaCreatePool( - allocator, - &pool_create_info, - &pool)); - - TORCH_CHECK( - pool, - "Invalid VMA (Vulkan Memory Allocator) memory pool!"); - - pools_.push_back({ - memory_type_index, - { - pool, - Entry::Deleter(allocator), - }, - }); +VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { + const VmaAllocation tmp_allocation = allocation_; + const VkImage tmp_image = handles_.image; + const VkImageView tmp_image_view = handles_.image_view; - pool_itr = std::prev(pools_.end()); - } + memory_properties_ = other.memory_properties_; + image_properties_ = other.image_properties_; + view_properties_ = other.view_properties_; + sampler_properties_ = other.sampler_properties_; + allocator_ = other.allocator_; + allocation_ = other.allocation_; + handles_ = other.handles_; + layout_ = other.layout_; - allocation_create_info.pool = pool_itr->handle.get(); -} - -} // namespace - -std::unique_ptr Resource::Pool::Policy::linear( - const VkDeviceSize block_size, - const uint32_t min_block_count, - const uint32_t max_block_count) { - return std::make_unique( - block_size, - min_block_count, - max_block_count); -} - -Resource::Pool::Pool( - const GPU& gpu, - std::unique_ptr policy) - : device_(gpu.device), - allocator_( - create_allocator( - gpu.instance, - gpu.adapter->physical_handle(), - device_), - vmaDestroyAllocator), - memory_{ - std::move(policy), - }, - image_{ - .sampler = Image::Sampler{gpu}, - }, - fence_{} { - buffer_.pool.reserve(Configuration::kReserve); - image_.pool.reserve(Configuration::kReserve); - fence_.pool.reserve(Configuration::kReserve); -} - -Resource::Pool::Pool(Pool&& pool) - : device_(std::move(pool.device_)), - allocator_(std::move(pool.allocator_)), - memory_(std::move(pool.memory_)), - buffer_(std::move(pool.buffer_)), - image_(std::move(pool.image_)), - fence_(std::move(pool.fence_)) { - pool.invalidate(); -} - -Resource::Pool& Resource::Pool::operator=(Pool&& pool) { - if (&pool != this) { - device_ = std::move(pool.device_); - allocator_ = std::move(pool.allocator_); - memory_ = std::move(pool.memory_); - buffer_ = std::move(pool.buffer_); - image_ = std::move(pool.image_); - fence_ = std::move(pool.fence_); - - pool.invalidate(); - }; + other.allocation_ = tmp_allocation; + other.handles_.image = tmp_image; + other.handles_.image_view = tmp_image_view; return *this; } -Resource::Pool::~Pool() { - try { - if (device_ && allocator_) { - purge(); - } - } - catch (const std::exception& e) { - TORCH_WARN( - "Vulkan: Resource pool destructor raised an exception! Error: ", - e.what()); +VulkanImage::~VulkanImage() { + if (VK_NULL_HANDLE != handles_.image_view) { + VmaAllocatorInfo allocator_info{}; + vmaGetAllocatorInfo(allocator_, &allocator_info); + vkDestroyImageView(allocator_info.device, handles_.image_view, nullptr); } - catch (...) { - TORCH_WARN( - "Vulkan: Resource pool destructor raised an exception! " - "Error: Unknown"); + + if (VK_NULL_HANDLE != handles_.image) { + vmaDestroyImage(allocator_, handles_.image, allocation_); } } -Resource::Buffer Resource::Pool::create_buffer( - const Buffer::Descriptor& descriptor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && allocator_, - "This resource pool is in an invalid state! ", - "Potential reason: This resource pool is moved from."); +// +// ImageMemoryBarrier +// + +ImageMemoryBarrier::ImageMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VkImageLayout src_layout_flags, + const VkImageLayout dst_layout_flags, + const VulkanImage& image) + : handle{ + VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, // sType + nullptr, // pNext + src_access_flags, // srcAccessMask + dst_access_flags, // dstAccessMask + src_layout_flags, // oldLayout + dst_layout_flags, // newLayout + VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex + VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex + image.handles_.image, // image + { + // subresourceRange + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // baseMipLevel + VK_REMAINING_MIP_LEVELS, // levelCount + 0u, // baseArrayLayer + VK_REMAINING_ARRAY_LAYERS, // layerCount + }, + } {} + +// +// SamplerCache +// + +SamplerCache::SamplerCache(const VkDevice device) + : cache_mutex_{}, device_(device), cache_{} {} + +SamplerCache::SamplerCache(SamplerCache&& other) noexcept + : cache_mutex_{}, device_(other.device_) { + std::lock_guard lock(other.cache_mutex_); + cache_ = std::move(other.cache_); +} - const VkBufferCreateInfo buffer_create_info{ - VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, - nullptr, - 0u, - descriptor.size, - descriptor.usage.buffer, - VK_SHARING_MODE_EXCLUSIVE, - 0u, - nullptr, - }; +SamplerCache::~SamplerCache() { + purge(); +} - VkBuffer buffer{}; - VK_CHECK(vkCreateBuffer( - device_, - &buffer_create_info, - nullptr, - &buffer)); +VkSampler SamplerCache::retrieve(const SamplerCache::Key& key) { + std::lock_guard lock(cache_mutex_); - TORCH_CHECK( - buffer, - "Invalid Vulkan buffer!"); + auto it = cache_.find(key); + if C10_UNLIKELY (cache_.cend() == it) { + it = cache_.insert({key, SamplerCache::Value(device_, key)}).first; + } - VkMemoryRequirements memory_requirements{}; - vkGetBufferMemoryRequirements( - device_, - buffer, - &memory_requirements); + return it->second.handle(); +} - VmaAllocationCreateInfo allocation_create_info = - create_allocation_create_info(descriptor.usage.memory); +void SamplerCache::purge() { + cache_.clear(); +} - if (memory_.policy) { - memory_.policy->enact( - allocator_.get(), - memory_requirements, - allocation_create_info); - } +// +// MemoryAllocator +// - VmaAllocation allocation{}; - VK_CHECK(vmaAllocateMemory( - allocator_.get(), - &memory_requirements, - &allocation_create_info, - &allocation, - nullptr)); +MemoryAllocator::MemoryAllocator( + const VkInstance instance, + const VkPhysicalDevice physical_device, + const VkDevice device) + : instance_{}, + physical_device_(physical_device), + device_(device), + allocator_{VK_NULL_HANDLE} { + VmaVulkanFunctions vk_functions{}; + vk_functions.vkGetInstanceProcAddr = vkGetInstanceProcAddr; + vk_functions.vkGetDeviceProcAddr = vkGetDeviceProcAddr; - TORCH_CHECK( - allocation, - "Invalid VMA (Vulkan Memory Allocator) allocation!"); + const VmaAllocatorCreateInfo allocator_create_info{ + 0u, // flags + physical_device_, // physicalDevice + device_, // device + 0u, // preferredLargeHeapBlockSize + nullptr, // pAllocationCallbacks + nullptr, // pDeviceMemoryCallbacks + nullptr, // pHeapSizeLimit + &vk_functions, // pVulkanFunctions + instance, // instance + VK_API_VERSION_1_0, // vulkanApiVersion + nullptr, // pTypeExternalMemoryHandleTypes + }; - VK_CHECK(vmaBindBufferMemory( - allocator_.get(), - allocation, - buffer)); + VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator_)); +} - return Buffer{ - Buffer::Object{ - buffer, - 0u, - descriptor.size, - }, - Memory{ - allocator_.get(), - allocation, - }, - }; +MemoryAllocator::MemoryAllocator(MemoryAllocator&& other) noexcept + : instance_(other.instance_), + physical_device_(other.physical_device_), + device_(other.device_), + allocator_(other.allocator_) { + other.allocator_ = VK_NULL_HANDLE; + other.device_ = VK_NULL_HANDLE; + other.physical_device_ = VK_NULL_HANDLE; + other.instance_ = VK_NULL_HANDLE; } -void Resource::Pool::register_buffer_cleanup(const Resource::Buffer& buffer) { - buffer_.pool.emplace_back(buffer, &release_buffer); +MemoryAllocator::~MemoryAllocator() { + if C10_LIKELY (VK_NULL_HANDLE == allocator_) { + return; + } + vmaDestroyAllocator(allocator_); } -Resource::Image Resource::Pool::create_image( - const Image::Descriptor& descriptor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && allocator_, - "This resource pool is in an invalid state! ", - "Potential reason: This resource pool is moved from."); +VulkanImage MemoryAllocator::create_image3d( + const VkExtent3D& extents, + const VulkanImage::SamplerProperties& sampler_props, + const VkSampler sampler, + const caffe2::TypeMeta dtype, + bool allow_transfer) { + VkImageUsageFlags usage = + VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT; + if (allow_transfer) { + usage |= + (VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT); + } - const VkImageCreateInfo image_create_info{ - VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, - nullptr, - 0u, - descriptor.type, - descriptor.format, - descriptor.extent, - 1u, - 1u, - VK_SAMPLE_COUNT_1_BIT, - VK_IMAGE_TILING_OPTIMAL, - descriptor.usage.image, - VK_SHARING_MODE_EXCLUSIVE, - 0u, - nullptr, - VK_IMAGE_LAYOUT_UNDEFINED, + const VkFormat image_format = vk_format(dtype); + + const VulkanImage::MemoryProperties mem_props{ + DEFAULT_ALLOCATION_STRATEGY, + VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE, + 0u, + 0u, + usage, }; - VkImage image{}; - VK_CHECK(vkCreateImage( - device_, - &image_create_info, - nullptr, - &image)); + const VulkanImage::ImageProperties image_props{ + VK_IMAGE_TYPE_3D, + image_format, + extents, + }; + + const VulkanImage::ViewProperties view_props{ + VK_IMAGE_VIEW_TYPE_3D, + image_format, + }; - TORCH_CHECK( - image, - "Invalid Vulkan image!"); + const VkImageLayout initial_layout = VK_IMAGE_LAYOUT_UNDEFINED; - VkMemoryRequirements memory_requirements{}; - vkGetImageMemoryRequirements( + return VulkanImage( + allocator_, device_, - image, - &memory_requirements); + mem_props, + image_props, + view_props, + sampler_props, + initial_layout, + sampler); +} - VmaAllocationCreateInfo allocation_create_info = - create_allocation_create_info(descriptor.usage.memory); +VulkanBuffer MemoryAllocator::create_storage_buffer( + const VkDeviceSize size, + const bool gpu_only) { + const VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - if (memory_.policy) { - memory_.policy->enact( - allocator_.get(), - memory_requirements, - allocation_create_info); + VmaAllocationCreateFlags create_flags = DEFAULT_ALLOCATION_STRATEGY; + if (!gpu_only) { + create_flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT; } - VmaAllocation allocation{}; - VK_CHECK(vmaAllocateMemory( - allocator_.get(), - &memory_requirements, - &allocation_create_info, - &allocation, - nullptr)); + const VmaMemoryUsage vma_usage = + gpu_only ? VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE : VMA_MEMORY_USAGE_AUTO; - TORCH_CHECK( - allocation, - "Invalid VMA (Vulkan Memory Allocator) allocation!"); + const VkMemoryPropertyFlags required_mem_props = + gpu_only ? 0u : VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - VK_CHECK(vmaBindImageMemory( - allocator_.get(), - allocation, - image)); + const VkMemoryPropertyFlags preferred_mem_props = gpu_only + ? 0u + : VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | + VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - const VkImageViewCreateInfo image_view_create_info{ - VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, - nullptr, - 0u, - image, - descriptor.view.type, - descriptor.view.format, - { - VK_COMPONENT_SWIZZLE_IDENTITY, - VK_COMPONENT_SWIZZLE_IDENTITY, - VK_COMPONENT_SWIZZLE_IDENTITY, - VK_COMPONENT_SWIZZLE_IDENTITY, - }, - { - VK_IMAGE_ASPECT_COLOR_BIT, - 0u, - VK_REMAINING_MIP_LEVELS, - 0u, - VK_REMAINING_ARRAY_LAYERS, - }, - }; - - VkImageView view{}; - VK_CHECK(vkCreateImageView( - device_, - &image_view_create_info, - nullptr, - &view)); - - TORCH_CHECK( - view, - "Invalid Vulkan image view!"); - - return Image{ - Image::Object{ - image, - VK_IMAGE_LAYOUT_UNDEFINED, - view, - image_.sampler.cache.retrieve(descriptor.sampler), - }, - Memory{ - allocator_.get(), - allocation, - }, + const VulkanBuffer::MemoryProperties mem_props{ + create_flags, + vma_usage, + required_mem_props, + preferred_mem_props, + buffer_usage, }; -} -void Resource::Pool::register_image_cleanup(const Resource::Image& image) { - image_.pool.emplace_back(image, &release_image); + return VulkanBuffer(allocator_, size, mem_props); } -Resource::Fence Resource::Pool::fence() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && allocator_, - "This resource pool is in an invalid state! ", - "Potential reason: This resource pool is moved from."); - - if (fence_.pool.size() == fence_.in_use) { - const VkFenceCreateInfo fence_create_info{ - VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, - nullptr, +VulkanBuffer MemoryAllocator::create_staging_buffer(const VkDeviceSize size) { + const VulkanBuffer::MemoryProperties mem_props{ + DEFAULT_ALLOCATION_STRATEGY, + VMA_MEMORY_USAGE_AUTO_PREFER_HOST, + 0u, 0u, - }; + VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, + }; - VkFence fence{}; - VK_CHECK(vkCreateFence( - device_, - &fence_create_info, - nullptr, - &fence)); + return VulkanBuffer(allocator_, size, mem_props); +} - TORCH_CHECK( - fence, - "Invalid Vulkan fence!"); +// +// VulkanFence +// - fence_.pool.emplace_back(fence, VK_DELETER(Fence)(device_)); - } +VulkanFence::VulkanFence() + : device_(VK_NULL_HANDLE), handle_(VK_NULL_HANDLE), waiting_(false) {} - return Fence{ - this, - fence_.in_use++, +VulkanFence::VulkanFence(const VkDevice device) + : device_(device), handle_(VK_NULL_HANDLE), waiting_(VK_NULL_HANDLE) { + const VkFenceCreateInfo fence_create_info{ + VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags }; + + VK_CHECK(vkCreateFence(device_, &fence_create_info, nullptr, &handle_)); } -void Resource::Pool::purge() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_ && allocator_, - "This resource pool is in an invalid state! ", - "Potential reason: This resource pool is moved from."); +VulkanFence::VulkanFence(VulkanFence&& other) noexcept + : device_(other.device_), handle_(other.handle_), waiting_(other.waiting_) { + other.handle_ = VK_NULL_HANDLE; + other.waiting_ = false; +} - if (!fence_.waitlist.empty()) { - VK_CHECK(vkWaitForFences( - device_, - fence_.waitlist.size(), - fence_.waitlist.data(), - VK_TRUE, - UINT64_MAX)); +VulkanFence& VulkanFence::operator=(VulkanFence&& other) noexcept { + device_ = other.device_; + handle_ = other.handle_; + waiting_ = other.waiting_; - VK_CHECK(vkResetFences( - device_, - fence_.waitlist.size(), - fence_.waitlist.data())); + other.device_ = VK_NULL_HANDLE; + other.handle_ = VK_NULL_HANDLE; + other.waiting_ = false; - fence_.waitlist.clear(); - } + return *this; +} - fence_.in_use = 0u; - image_.pool.clear(); - buffer_.pool.clear(); +VulkanFence::~VulkanFence() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroyFence(device_, handle_, nullptr); } -void Resource::Pool::invalidate() { - device_ = VK_NULL_HANDLE; - allocator_.reset(); +void VulkanFence::wait() { + // if get_submit_handle() has not been called, then this will no-op + if (waiting_) { + VkResult fence_status = VK_NOT_READY; + // Run the wait in a loop to keep the CPU hot. A single call to + // vkWaitForFences with no timeout may cause the calling thread to be + // scheduled out. + do { + // The timeout (last) arg is in units of ns + fence_status = vkWaitForFences(device_, 1u, &handle_, VK_TRUE, 100000); + + TORCH_CHECK( + fence_status != VK_ERROR_DEVICE_LOST, + "Vulkan Fence: Device lost while waiting for fence!"); + } while (fence_status != VK_SUCCESS); + + VK_CHECK(vkResetFences(device_, 1u, &handle_)); + + waiting_ = false; + } } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 3c2b46c33984e..1efd907b3246d 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -2,459 +2,489 @@ #ifdef USE_VULKAN_API -#include #include -#include #include +#include + namespace at { namespace native { namespace vulkan { namespace api { -struct Resource final { - class Pool; +typedef uint8_t MemoryAccessFlags; - // - // Memory - // +VkFormat vk_format(const caffe2::TypeMeta dtype); - struct Memory final { - /* - Descriptor - */ +constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY = + VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT; - struct Descriptor final { - VmaMemoryUsage usage; - VkMemoryPropertyFlags /* optional */ required; - VkMemoryPropertyFlags /* optional */ preferred; - }; +enum MemoryAccessType : MemoryAccessFlags { + NONE = 0u << 0u, + READ = 1u << 0u, + WRITE = 1u << 1u, +}; - /* - Barrier - */ +struct MemoryBarrier final { + VkMemoryBarrier handle; - struct Barrier final { - VkAccessFlags src; - VkAccessFlags dst; - }; + MemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags); +}; - /* - Access - */ - - struct Access final { - typedef uint8_t Flags; - - enum Type : Flags { - None = 0u << 0u, - Read = 1u << 0u, - Write = 1u << 1u, - }; - - template - using Pointer = std::add_pointer_t< - std::conditional_t< - 0u != (access & Write), - Type, - std::add_const_t>>; - }; +class VulkanBuffer final { + public: + struct MemoryProperties final { + VmaAllocationCreateFlags create_flags; + + VmaMemoryUsage memory_usage; + VkMemoryPropertyFlags required_mem_flags; + VkMemoryPropertyFlags preferred_mem_flags; - class Scope; - template - using Handle = Handle; + VkBufferUsageFlags buffer_usage; + }; + + struct BufferProperties final { + VkDeviceSize size; + VkDeviceSize mem_offset; + VkDeviceSize mem_range; + }; - template< - typename Type, - typename Pointer = Access::Pointer> - Handle map() const &; + explicit VulkanBuffer(); - template< - typename Type, - Access::Flags kAccess, - typename Pointer = Access::Pointer> - Handle map() &; + explicit VulkanBuffer( + const VmaAllocator, + const VkDeviceSize, + const MemoryProperties&); - VmaAllocator allocator; - VmaAllocation allocation; + VulkanBuffer(const VulkanBuffer&) = delete; + VulkanBuffer& operator=(const VulkanBuffer&) = delete; - private: - // Intentionally disabed to ensure memory access is always properly - // encapsualted in a scoped map-unmap region. Allowing below overloads - // to be invoked on a temporary would open the door to the possibility - // of accessing the underlying memory out of the expected scope making - // for seemingly ineffective memory writes and hard to hunt down bugs. + VulkanBuffer(VulkanBuffer&&) noexcept; + VulkanBuffer& operator=(VulkanBuffer&&) noexcept; - template - Handle map() const && = delete; + ~VulkanBuffer(); - template - Handle map() && = delete; + struct Package final { + VkBuffer handle; + VkDeviceSize buffer_offset; + VkDeviceSize buffer_range; }; - // - // Buffer - // + friend struct BufferMemoryBarrier; - struct Buffer final { - /* - Descriptor - */ + private: + MemoryProperties memory_properties_; + BufferProperties buffer_properties_; + // The allocator object this was allocated from + VmaAllocator allocator_; + // Handles to the allocated memory + VmaAllocation allocation_; + VkBuffer handle_; - struct Descriptor final { - VkDeviceSize size; + public: + inline VmaAllocator vma_allocator() const { + return allocator_; + } - struct { - VkBufferUsageFlags buffer; - Memory::Descriptor memory; - } usage; - }; + inline VmaAllocation allocation() const { + return allocation_; + } - /* - Object - */ + inline VkBuffer handle() const { + return handle_; + } - struct Object final { - VkBuffer handle; - VkDeviceSize offset; - VkDeviceSize range; + inline VkDeviceSize mem_offset() const { + return buffer_properties_.mem_offset; + } - operator bool() const; - }; + inline VkDeviceSize mem_range() const { + return buffer_properties_.mem_range; + } - /* - Barrier - */ + operator bool() const { + return (allocation_ != VK_NULL_HANDLE); + } +}; - struct Barrier final { - Object object; - Memory::Barrier memory; - }; +class MemoryMap final { + public: + explicit MemoryMap( + const VulkanBuffer& buffer, + const MemoryAccessFlags access); - Object object; - Memory memory; + MemoryMap(const MemoryMap&) = delete; + MemoryMap& operator=(const MemoryMap&) = delete; - operator bool() const; - }; + MemoryMap(MemoryMap&&) noexcept; + MemoryMap& operator=(MemoryMap&&) = delete; - // - // Image - // + ~MemoryMap(); - struct Image final { - // - // Sampler - // + private: + uint8_t access_; + VmaAllocator allocator_; + VmaAllocation allocation_; + void* data_; - struct Sampler final { - /* - Descriptor - */ + public: + template + T* data() { + return reinterpret_cast(data_); + } - struct Descriptor final { - VkFilter filter; - VkSamplerMipmapMode mipmap_mode; - VkSamplerAddressMode address_mode; - VkBorderColor border; - }; + void invalidate(); +}; - /* - Factory - */ +struct BufferMemoryBarrier final { + VkBufferMemoryBarrier handle; - class Factory final { - public: - explicit Factory(const GPU& gpu); + BufferMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VulkanBuffer& buffer); +}; - typedef Sampler::Descriptor Descriptor; - typedef VK_DELETER(Sampler) Deleter; - typedef api::Handle Handle; +class ImageSampler final { + public: + struct Properties final { + VkFilter filter; + VkSamplerMipmapMode mipmap_mode; + VkSamplerAddressMode address_mode; + VkBorderColor border_color; + }; - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; + explicit ImageSampler(const VkDevice, const Properties&); - Handle operator()(const Descriptor& descriptor) const; + ImageSampler(const ImageSampler&) = delete; + ImageSampler& operator=(const ImageSampler&) = delete; - private: - VkDevice device_; - }; + ImageSampler(ImageSampler&&) noexcept; + ImageSampler& operator=(ImageSampler&&) = delete; - /* - Cache - */ + ~ImageSampler(); - typedef api::Cache Cache; - Cache cache; + private: + VkDevice device_; + VkSampler handle_; - explicit Sampler(const GPU& gpu) - : cache(Factory(gpu)) { - } - }; + public: + VkSampler handle() const { + return handle_; + } - /* - Descriptor - */ + struct Hasher { + size_t operator()(const Properties&) const; + }; - struct Descriptor final { - VkImageType type; - VkFormat format; - VkExtent3D extent; + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. + friend void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept; +}; - struct { - VkImageUsageFlags image; - Memory::Descriptor memory; - } usage; +class VulkanImage final { + public: + struct MemoryProperties final { + VmaAllocationCreateFlags create_flags; - struct { - VkImageViewType type; - VkFormat format; - } view; + VmaMemoryUsage memory_usage; + VkMemoryPropertyFlags required_mem_flags; + VkMemoryPropertyFlags preferred_mem_flags; - Sampler::Descriptor sampler; - }; + VkImageUsageFlags image_usage; + }; + + struct ImageProperties final { + VkImageType image_type; + VkFormat image_format; + VkExtent3D image_extents; + }; - /* - Object - */ + struct ViewProperties final { + VkImageViewType view_type; + VkFormat view_format; + }; - struct Object final { - VkImage handle; - VkImageLayout layout; - VkImageView view; - VkSampler sampler; + typedef ImageSampler::Properties SamplerProperties; - operator bool() const; - }; + struct Handles final { + VkImage image; + VkImageView image_view; + VkSampler sampler; + }; - /* - Barrier - */ + explicit VulkanImage(); - struct Barrier final { - Object object; - Memory::Barrier memory; + explicit VulkanImage( + const VmaAllocator, + const VkDevice, + const MemoryProperties&, + const ImageProperties&, + const ViewProperties&, + const SamplerProperties&, + const VkImageLayout layout, + const VkSampler); - struct { - VkImageLayout src; - VkImageLayout dst; - } layout; - }; + VulkanImage(const VulkanImage&) = delete; + VulkanImage& operator=(const VulkanImage&) = delete; + + VulkanImage(VulkanImage&&) noexcept; + VulkanImage& operator=(VulkanImage&&) noexcept; - Object object; - Memory memory; + ~VulkanImage(); - operator bool() const; + struct Package final { + VkImage handle; + VkImageLayout image_layout; + VkImageView image_view; + VkSampler image_sampler; }; - // - // Fence - // + friend struct ImageMemoryBarrier; - struct Fence final { - Pool* pool; - size_t id; + private: + MemoryProperties memory_properties_; + ImageProperties image_properties_; + ViewProperties view_properties_; + SamplerProperties sampler_properties_; + // The allocator object this was allocated from + VmaAllocator allocator_; + // Handles to the allocated memory + VmaAllocation allocation_; + Handles handles_; + // Layout + VkImageLayout layout_; - operator bool() const; - VkFence handle(bool add_to_waitlist = true) const; - void wait(uint64_t timeout_nanoseconds = UINT64_MAX); - }; + public: + inline VmaAllocator vma_allocator() const { + return allocator_; + } - // - // Pool - // - - class Pool final { - public: - class Policy { - public: - virtual ~Policy() = default; - - static std::unique_ptr linear( - VkDeviceSize block_size = VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE, - uint32_t min_block_count = 1u, - uint32_t max_block_count = UINT32_MAX); - - virtual void enact( - VmaAllocator allocator, - const VkMemoryRequirements& memory_requirements, - VmaAllocationCreateInfo& allocation_create_info) = 0; - }; + inline VmaAllocation allocation() const { + return allocation_; + } - explicit Pool(const GPU& gpu, std::unique_ptr = {}); - Pool(const Pool&) = delete; - Pool& operator=(const Pool&) = delete; - Pool(Pool&&); - Pool& operator=(Pool&&); - ~Pool(); + inline VkExtent3D extents() const { + return image_properties_.image_extents; + } - // Primary + inline VkImage handle() const { + return handles_.image; + } - Buffer create_buffer(const Buffer::Descriptor& descriptor); - void register_buffer_cleanup(const Buffer& buffer); - Image create_image(const Image::Descriptor& descriptor); - void register_image_cleanup(const Image& image); + inline VkImageView image_view() const { + return handles_.image_view; + } - Fence fence(); - void purge(); + inline VkSampler sampler() const { + return handles_.sampler; + } - // Helper + Package package() const { + return { + handles_.image, + layout_, + handles_.image_view, + handles_.sampler, + }; + } - template - Buffer uniform(const Block& block); + inline VkImageLayout layout() const { + return layout_; + } - private: - friend struct Fence; + inline void set_layout(const VkImageLayout layout) { + layout_ = layout; + } - void invalidate(); + inline operator bool() const { + return (allocation_ != VK_NULL_HANDLE); + } +}; - private: - struct Configuration final { - static constexpr uint32_t kReserve = 256u; - }; +struct ImageMemoryBarrier final { + VkImageMemoryBarrier handle; - VkDevice device_; - Handle allocator_; + ImageMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VkImageLayout src_layout_flags, + const VkImageLayout dst_layout_flags, + const VulkanImage& image); +}; - struct { - std::unique_ptr policy; - } memory_; +class SamplerCache final { + public: + explicit SamplerCache(const VkDevice device); - struct { - std::vector> pool; - } buffer_; + SamplerCache(const SamplerCache&) = delete; + SamplerCache& operator=(const SamplerCache&) = delete; - struct { - std::vector> pool; - Image::Sampler sampler; - } image_; + SamplerCache(SamplerCache&&) noexcept; + SamplerCache& operator=(SamplerCache&&) = delete; - struct { - std::vector> pool; - mutable std::vector waitlist; - size_t in_use; - } fence_; - } pool; + ~SamplerCache(); - explicit Resource(const GPU& gpu) - : pool(gpu, nullptr) { - } -}; + typedef ImageSampler::Properties Key; + typedef ImageSampler Value; + typedef ImageSampler::Hasher Hasher; -void release_buffer(const Resource::Buffer& buffer); + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; -void release_image(const Resource::Image& image); + VkDevice device_; + ska::flat_hash_map cache_; -// -// Impl -// + public: + VkSampler retrieve(const Key&); + void purge(); +}; -class Resource::Memory::Scope final { +class MemoryAllocator final { public: - Scope( - VmaAllocator allocator, - VmaAllocation allocation, - Access::Flags access); + explicit MemoryAllocator( + const VkInstance instance, + const VkPhysicalDevice physical_device, + const VkDevice device); - void operator()(const void* data) const; + MemoryAllocator(const MemoryAllocator&) = delete; + MemoryAllocator& operator=(const MemoryAllocator&) = delete; + + MemoryAllocator(MemoryAllocator&&) noexcept; + MemoryAllocator& operator=(MemoryAllocator&&) = delete; + + ~MemoryAllocator(); private: + VkInstance instance_; + VkPhysicalDevice physical_device_; + VkDevice device_; VmaAllocator allocator_; - VmaAllocation allocation_; - Access::Flags access_; + + public: + VulkanImage create_image3d( + const VkExtent3D&, + const VulkanImage::SamplerProperties&, + const VkSampler, + const caffe2::TypeMeta dtype, + const bool allow_transfer = false); + + VulkanBuffer create_storage_buffer( + const VkDeviceSize, + const bool gpu_only = true); + + VulkanBuffer create_staging_buffer(const VkDeviceSize); + + template + VulkanBuffer create_params_buffer(const Block& block); }; -template -inline Resource::Memory::Handle Resource::Memory::map() const & { - // Forward declaration - void* map(const Memory&, Access::Flags); +class VulkanFence final { + public: + // TODO: This is required for the lazy allocation pattern in api/Tensor. + // It will be disabled pending future refactors. + explicit VulkanFence(); - return Handle{ - reinterpret_cast(map(*this, Access::Read)), - Scope(allocator, allocation, Access::Read), - }; -} + explicit VulkanFence(const VkDevice); -template -inline Resource::Memory::Handle Resource::Memory::map() & { - // Forward declaration - void* map(const Memory&, Access::Flags); + VulkanFence(const VulkanFence&) = delete; + VulkanFence& operator=(const VulkanFence&) = delete; - static_assert( - (kAccess == Access::Read) || - (kAccess == Access::Write) || - (kAccess == (Access::Read | Access::Write)), - "Invalid memory access!"); + VulkanFence(VulkanFence&&) noexcept; + VulkanFence& operator=(VulkanFence&&) noexcept; - return Handle{ - reinterpret_cast(map(*this, kAccess)), - Scope(allocator, allocation, kAccess), - }; -} + ~VulkanFence(); -inline Resource::Buffer::Object::operator bool() const { - return VK_NULL_HANDLE != handle; -} + private: + VkDevice device_; + VkFence handle_; + bool waiting_; -inline Resource::Buffer::operator bool() const { - return object; -} + public: + // Used to get the handle for a queue submission. + VkFence get_submit_handle() { + if (handle_ != VK_NULL_HANDLE) { + // Indicate we are now waiting for this fence to be signaled + waiting_ = true; + } + return handle_; + } -inline bool operator==( - const Resource::Image::Sampler::Descriptor& _1, - const Resource::Image::Sampler::Descriptor& _2) { + VkFence handle() { + return handle_; + } - return (_1.filter == _2.filter && \ - _1.mipmap_mode == _2.mipmap_mode && \ - _1.address_mode == _2.address_mode && \ - _1.border == _2.border); -} + // Trigger a synchronous wait for the fence to be signaled + void wait(); -inline size_t Resource::Image::Sampler::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - return c10::get_hash( - descriptor.filter, - descriptor.mipmap_mode, - descriptor.address_mode, - descriptor.border); -} + bool waiting() const { + return waiting_; + } -inline Resource::Image::Object::operator bool() const { - return VK_NULL_HANDLE != handle; -} + operator bool() const { + return (VK_NULL_HANDLE != handle_); + } +}; -inline Resource::Image::operator bool() const { - return object; -} +// A pool to track created Fences and reuse ones that are available. +// Only intended to be modified by one thread at a time. +struct FencePool final { + VkDevice device_; -inline Resource::Fence::operator bool() const { - return pool; -} + std::stack pool_; + + explicit FencePool(const VkDevice device) : device_(device), pool_{} {} + + // Returns an rvalue reference to a fence, so that it can be moved + inline VulkanFence get_fence() { + if (pool_.empty()) { + VulkanFence new_fence = VulkanFence(device_); + return new_fence; + } + + VulkanFence top_fence = std::move(pool_.top()); + pool_.pop(); + + return top_fence; + } + + // Marks the fence as available + inline void return_fence(VulkanFence& fence) { + pool_.push(std::move(fence)); + } +}; + +// +// Impl +// + +template +inline VulkanBuffer MemoryAllocator::create_params_buffer(const Block& block) { + const VulkanBuffer::MemoryProperties mem_props{ + DEFAULT_ALLOCATION_STRATEGY | + VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT, + VMA_MEMORY_USAGE_AUTO, + 0u, + 0u, + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + }; -template -inline Resource::Buffer Resource::Pool::uniform(const Block& block) { - Buffer uniform = this->create_buffer({ - sizeof(Block), - { - VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, - { - VMA_MEMORY_USAGE_CPU_TO_GPU, - 0u, - 0u, - }, - }, - }); - this->register_buffer_cleanup(uniform); + VulkanBuffer uniform_buffer(allocator_, sizeof(Block), mem_props); + // Fill the uniform buffer with data in block { - Memory::Handle memory = uniform.memory.template map< - Block, - Memory::Access::Write>(); + MemoryMap mapping(uniform_buffer, MemoryAccessType::WRITE); + Block* data_ptr = mapping.template data(); - *memory.get() = block; + *data_ptr = block; } - return uniform; + return uniform_buffer; } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp index c925a0226f6ac..a1c460fa4dc97 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.cpp +++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp @@ -1,5 +1,4 @@ #include -#include namespace at { namespace native { @@ -8,13 +7,11 @@ namespace api { namespace { - void find_requested_layers_and_extensions( std::vector& enabled_layers, std::vector& enabled_extensions, const std::vector& requested_layers, const std::vector& requested_extensions) { - // Get supported instance layers uint32_t layer_count = 0; VK_CHECK(vkEnumerateInstanceLayerProperties(&layer_count, nullptr)); @@ -55,27 +52,27 @@ void find_requested_layers_and_extensions( VkInstance create_instance(const RuntimeConfiguration& config) { const VkApplicationInfo application_info{ - VK_STRUCTURE_TYPE_APPLICATION_INFO, // sType - nullptr, // pNext - "PyTorch Vulkan Backend", // pApplicationName - 0, // applicationVersion - nullptr, // pEngineName - 0, // engineVersion - VK_API_VERSION_1_0, // apiVersion + VK_STRUCTURE_TYPE_APPLICATION_INFO, // sType + nullptr, // pNext + "PyTorch Vulkan Backend", // pApplicationName + 0, // applicationVersion + nullptr, // pEngineName + 0, // engineVersion + VK_API_VERSION_1_0, // apiVersion }; std::vector enabled_layers; std::vector enabled_extensions; if (config.enableValidationMessages) { - std::vector requested_layers { - // "VK_LAYER_LUNARG_api_dump", - "VK_LAYER_KHRONOS_validation", + std::vector requested_layers{ + // "VK_LAYER_LUNARG_api_dump", + "VK_LAYER_KHRONOS_validation", }; - std::vector requested_extensions { - #ifdef VK_EXT_debug_report - VK_EXT_DEBUG_REPORT_EXTENSION_NAME, - #endif + std::vector requested_extensions{ +#ifdef VK_EXT_debug_report + VK_EXT_DEBUG_REPORT_EXTENSION_NAME, +#endif /* VK_EXT_debug_report */ }; find_requested_layers_and_extensions( @@ -86,14 +83,14 @@ VkInstance create_instance(const RuntimeConfiguration& config) { } const VkInstanceCreateInfo instance_create_info{ - VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - &application_info, // pApplicationInfo - static_cast(enabled_layers.size()), // enabledLayerCount - enabled_layers.data(), // ppEnabledLayerNames - static_cast(enabled_extensions.size()), // enabledExtensionCount - enabled_extensions.data(), // ppEnabledExtensionNames + VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + &application_info, // pApplicationInfo + static_cast(enabled_layers.size()), // enabledLayerCount + enabled_layers.data(), // ppEnabledLayerNames + static_cast(enabled_extensions.size()), // enabledExtensionCount + enabled_extensions.data(), // ppEnabledExtensionNames }; VkInstance instance{}; @@ -102,15 +99,15 @@ VkInstance create_instance(const RuntimeConfiguration& config) { #ifdef USE_VULKAN_VOLK volkLoadInstance(instance); -#endif +#endif /* USE_VULKAN_VOLK */ return instance; } -std::vector create_adapters(const VkInstance instance, - const uint32_t num_queues) { +std::vector create_physical_devices( + const VkInstance instance) { if (VK_NULL_HANDLE == instance) { - return std::vector(); + return std::vector(); } uint32_t device_count = 0; @@ -119,13 +116,13 @@ std::vector create_adapters(const VkInstance instance, std::vector devices(device_count); VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, devices.data())); - std::vector adapters; - adapters.reserve(device_count); + std::vector device_mappings; + device_mappings.reserve(device_count); for (const VkPhysicalDevice physical_device : devices) { - adapters.emplace_back(physical_device, num_queues); + device_mappings.emplace_back(PhysicalDevice(physical_device), -1); } - return adapters; + return device_mappings; } VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn( @@ -157,21 +154,21 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn( } VkDebugReportCallbackEXT create_debug_report_callback( - const VkInstance instance, const RuntimeConfiguration config) { + const VkInstance instance, + const RuntimeConfiguration config) { if (VK_NULL_HANDLE == instance || !config.enableValidationMessages) { return VkDebugReportCallbackEXT{}; } const VkDebugReportCallbackCreateInfoEXT debugReportCallbackCreateInfo{ - VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT, // sType - nullptr, // pNext - VK_DEBUG_REPORT_INFORMATION_BIT_EXT | - VK_DEBUG_REPORT_WARNING_BIT_EXT | - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | - VK_DEBUG_REPORT_ERROR_BIT_EXT | - VK_DEBUG_REPORT_DEBUG_BIT_EXT, // flags - debug_report_callback_fn, // pfnCallback - nullptr, // pUserData + VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT, // sType + nullptr, // pNext + VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | + VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | + VK_DEBUG_REPORT_ERROR_BIT_EXT | + VK_DEBUG_REPORT_DEBUG_BIT_EXT, // flags + debug_report_callback_fn, // pfnCallback + nullptr, // pUserData }; const auto vkCreateDebugReportCallbackEXT = @@ -189,9 +186,7 @@ VkDebugReportCallbackEXT create_debug_report_callback( nullptr, &debug_report_callback)); - TORCH_CHECK( - debug_report_callback, - "Invalid Vulkan debug report callback!"); + TORCH_CHECK(debug_report_callback, "Invalid Vulkan debug report callback!"); return debug_report_callback; } @@ -200,21 +195,22 @@ VkDebugReportCallbackEXT create_debug_report_callback( // Adapter selection methods // -uint32_t select_first(const std::vector& adapters) { - if (adapters.size() == 0) { - TORCH_WARN("Pytorch Vulkan Runtime: no device adapters are available for selection!"); - return adapters.size() + 1; // return out of range to signal invalidity +uint32_t select_first(const std::vector& devices) { + if (devices.size() == 0) { + TORCH_WARN( + "Pytorch Vulkan Runtime: no device devices are available for selection!"); + return devices.size() + 1; // return out of range to signal invalidity } // Select the first adapter that has compute capability - for (const uint32_t i : c10::irange(adapters.size())) { - if (adapters[i].num_compute_queues() > 0) { + for (const uint32_t i : c10::irange(devices.size())) { + if (devices[i].first.num_compute_queues > 0) { return i; } } - TORCH_WARN("Pytorch Vulkan Runtime: no device adapters support compute!"); - return adapters.size() + 1; + TORCH_WARN("Pytorch Vulkan Runtime: no device devices support compute!"); + return devices.size() + 1; } // @@ -241,30 +237,28 @@ std::unique_ptr init_global_vulkan_runtime() { const bool enableValidationMessages = #if defined(DEBUG) - true; + true; #else - false; + false; #endif /* DEBUG */ const bool initDefaultDevice = true; const uint32_t numRequestedQueues = 1; // TODO: raise this value - const RuntimeConfiguration default_config { - enableValidationMessages, - initDefaultDevice, - AdapterSelector::First, - numRequestedQueues, + const RuntimeConfiguration default_config{ + enableValidationMessages, + initDefaultDevice, + AdapterSelector::First, + numRequestedQueues, }; try { return std::make_unique(Runtime(default_config)); - } - catch (const std::exception& e) { + } catch (const std::exception& e) { TORCH_WARN( "Pytorch Vulkan Runtime: Failed to initialize the global vulkan runtime! " "The global vulkan runtime is invalid. Error: ", e.what()); - } - catch (...) { + } catch (...) { TORCH_WARN( "Pytorch Vulkan Runtime: Failed to initialize the global vulkan runtime! " "The global vulkan runtime is invalid. " @@ -277,23 +271,26 @@ std::unique_ptr init_global_vulkan_runtime() { } // namespace Runtime::Runtime(const RuntimeConfiguration config) - : instance_(create_instance(config)), - adapters_(create_adapters(instance_, config.numRequestedQueues)), - default_adapter_i_{}, - debug_report_callback_(create_debug_report_callback(instance_, config)) { + : config_(config), + instance_(create_instance(config_)), + device_mappings_(create_physical_devices(instance_)), + adapters_{}, + default_adapter_i_(UINT32_MAX), + debug_report_callback_(create_debug_report_callback(instance_, config_)) { + // List of adapters will never exceed the number of physical devices + adapters_.reserve(device_mappings_.size()); + if (config.initDefaultDevice) { try { - switch(config.defaultSelector) { + switch (config.defaultSelector) { case AdapterSelector::First: - default_adapter_i_ = init_adapter(select_first); + default_adapter_i_ = create_adapter(select_first); } - } - catch (const std::exception& e) { + } catch (const std::exception& e) { TORCH_WARN( "Pytorch Vulkan Runtime: Could not initialize default device! Error: ", e.what()); - } - catch (...) { + } catch (...) { TORCH_WARN( "Pytorch Vulkan Runtime: Could not initialize default device! Error: " "Unknown."); @@ -302,26 +299,27 @@ Runtime::Runtime(const RuntimeConfiguration config) } Runtime::~Runtime() { - if C10_LIKELY(VK_NULL_HANDLE == instance_) { + if C10_LIKELY (VK_NULL_HANDLE == instance_) { return; } - // Clear adapters list to trigger device destruction before destroying VkInstance + // Clear adapters list to trigger device destruction before destroying + // VkInstance adapters_.clear(); - // Instance must be destroyed last as its used to destroy the debug report callback. + // Instance must be destroyed last as its used to destroy the debug report + // callback. if (debug_report_callback_) { const auto vkDestroyDebugReportCallbackEXT = - (PFN_vkDestroyDebugReportCallbackEXT)vkGetInstanceProcAddr( - instance_, "vkDestroyDebugReportCallbackEXT"); + (PFN_vkDestroyDebugReportCallbackEXT)vkGetInstanceProcAddr( + instance_, "vkDestroyDebugReportCallbackEXT"); - TORCH_CHECK( - vkDestroyDebugReportCallbackEXT, - "Pytorch Vulkan Runtime: Could not load vkDestroyDebugReportCallbackEXT " - "when destroying debug_report_callback_"); + TORCH_CHECK( + vkDestroyDebugReportCallbackEXT, + "Pytorch Vulkan Runtime: Could not load vkDestroyDebugReportCallbackEXT " + "when destroying debug_report_callback_"); - vkDestroyDebugReportCallbackEXT( - instance_, debug_report_callback_, nullptr); + vkDestroyDebugReportCallbackEXT(instance_, debug_report_callback_, nullptr); debug_report_callback_ = {}; } @@ -331,29 +329,40 @@ Runtime::~Runtime() { } Runtime::Runtime(Runtime&& other) noexcept - : instance_(other.instance_), - adapters_(std::move(other.adapters_)), - default_adapter_i_(other.default_adapter_i_), - debug_report_callback_(other.debug_report_callback_) { + : config_(other.config_), + instance_(other.instance_), + adapters_(std::move(other.adapters_)), + default_adapter_i_(other.default_adapter_i_), + debug_report_callback_(other.debug_report_callback_) { other.instance_ = VK_NULL_HANDLE; other.debug_report_callback_ = {}; } -uint32_t Runtime::init_adapter(const Selector& selector) { +uint32_t Runtime::create_adapter(const Selector& selector) { TORCH_CHECK( - adapters_.size() > 0, + device_mappings_.size() > 0, "Pytorch Vulkan Runtime: Could not initialize adapter because no " "devices were found by the Vulkan instance."); - uint32_t i = selector(adapters_); + uint32_t physical_device_i = selector(device_mappings_); TORCH_CHECK( - i < adapters_.size(), + physical_device_i < device_mappings_.size(), "Pytorch Vulkan Runtime: no suitable device adapter was selected! " "Device could not be initialized"); - adapters_[i].init_device(); + Runtime::DeviceMapping& device_mapping = device_mappings_[physical_device_i]; + // If an Adapter has already been created, return that + int32_t adapter_i = device_mapping.second; + if (adapter_i >= 0) { + return adapter_i; + } + // Otherwise, create an adapter for the selected physical device + adapter_i = utils::safe_downcast(adapters_.size()); + adapters_.emplace_back( + new Adapter(instance_, device_mapping.first, config_.numRequestedQueues)); + device_mapping.second = adapter_i; - return i; + return adapter_i; } Runtime* runtime() { @@ -361,7 +370,8 @@ Runtime* runtime() { // non-static function to ensure it has external linkage. If it were a global // static variable there would be one copy per translation unit that includes // Runtime.h as it would have internal linkage. - static const std::unique_ptr p_runtime = init_global_vulkan_runtime(); + static const std::unique_ptr p_runtime = + init_global_vulkan_runtime(); TORCH_CHECK( p_runtime, "Pytorch Vulkan Runtime: The global runtime could not be retrieved " diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h index 140c0869d6279..c6f2799a09cb3 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.h +++ b/aten/src/ATen/native/vulkan/api/Runtime.h @@ -2,6 +2,7 @@ #ifdef USE_VULKAN_API +#include #include namespace at { @@ -32,9 +33,10 @@ struct RuntimeConfiguration final { class Runtime final { public: - explicit Runtime(const RuntimeConfiguration config); + explicit Runtime(const RuntimeConfiguration); - // Do not allow copying. There should be only one global instance of this class. + // Do not allow copying. There should be only one global instance of this + // class. Runtime(const Runtime&) = delete; Runtime& operator=(const Runtime&) = delete; @@ -43,9 +45,16 @@ class Runtime final { ~Runtime(); + using DeviceMapping = std::pair; + using AdapterPtr = std::unique_ptr; + private: + RuntimeConfiguration config_; + VkInstance instance_; - std::vector adapters_; + + std::vector device_mappings_; + std::vector adapters_; uint32_t default_adapter_i_; VkDebugReportCallbackEXT debug_report_callback_; @@ -59,30 +68,25 @@ class Runtime final { TORCH_CHECK( default_adapter_i_ >= 0 && default_adapter_i_ < adapters_.size(), "Pytorch Vulkan Runtime: Default device adapter is not set correctly!"); - return &adapters_[default_adapter_i_]; - } - - inline Adapter& get_adapter() { - TORCH_CHECK( - default_adapter_i_ >= 0 && default_adapter_i_ < adapters_.size(), - "Pytorch Vulkan Runtime: Default device adapter is not set correctly!"); - return adapters_[default_adapter_i_]; + return adapters_[default_adapter_i_].get(); } inline Adapter* get_adapter_p(uint32_t i) { - return &adapters_[i]; - } - - inline Adapter& get_adapter(uint32_t i) { - return adapters_[i]; + TORCH_CHECK( + i >= 0 && i < adapters_.size(), + "Pytorch Vulkan Runtime: Adapter at index ", + i, + " is not available!"); + return adapters_[i].get(); } inline uint32_t default_adapter_i() const { return default_adapter_i_; } - using Selector = std::function&)>; - uint32_t init_adapter(const Selector& selector); + using Selector = + std::function&)>; + uint32_t create_adapter(const Selector&); }; // The global runtime is retrieved using this function, where it is declared as diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 6e075e50dd144..de7ac11d418da 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -9,167 +9,217 @@ namespace native { namespace vulkan { namespace api { -Shader::Layout::Factory::Factory(const GPU& gpu) - : device_(gpu.device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); +// +// ShaderSource +// + +ShaderSource::ShaderSource(std::string name, const char* const glsl_src) + : type(ShaderSource::Type::GLSL), + src_code{ + .glsl = + { + glsl_src, + 0u, + }, + }, + kernel_name{std::move(name)} {} + +ShaderSource::ShaderSource( + std::string name, + const uint32_t* const spirv_bin, + const uint32_t size, + const std::vector& layout) + : type(Type::SPIRV), + src_code{ + .spirv = + { + spirv_bin, + size, + }, + }, + kernel_name{std::move(name)}, + kernel_layout{layout} {} + +bool operator==(const ShaderSource& _1, const ShaderSource& _2) { + if (_1.type != _2.type) { + return false; + } + + if (_1.type == ShaderSource::Type::SPIRV) { + return ( + _1.src_code.spirv.bin == _2.src_code.spirv.bin && + _1.src_code.spirv.size == _2.src_code.spirv.size); + } else { + return (_1.src_code.glsl.src == _2.src_code.glsl.src); + } } -Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()( - const Descriptor& descriptor) const { +// +// ShaderLayout +// + +ShaderLayout::ShaderLayout( + const VkDevice device, + const ShaderLayout::Signature& signature) + : device_(device), handle_{VK_NULL_HANDLE} { c10::SmallVector bindings; - uint32_t binding = 0u; - for (const VkDescriptorType type : descriptor.signature) { + uint32_t binding_num = 0u; + for (const VkDescriptorType type : signature) { bindings.push_back({ - binding++, - type, - 1u, - VK_SHADER_STAGE_COMPUTE_BIT, - nullptr, + binding_num++, // binding + type, // descriptorType + 1u, // descriptorCount + VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags + nullptr, // pImmutableSamplers }); } const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{ - VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, - nullptr, - 0u, - static_cast(bindings.size()), - bindings.data(), + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + static_cast(bindings.size()), // bindingCount + bindings.data(), // pBindings }; - VkDescriptorSetLayout descriptor_set_layout{}; VK_CHECK(vkCreateDescriptorSetLayout( - device_, - &descriptor_set_layout_create_info, - nullptr, - &descriptor_set_layout)); - - TORCH_CHECK( - descriptor_set_layout, - "Invalid Vulkan descriptor set layout!"); - - return Handle{ - descriptor_set_layout, - Deleter(device_), - }; + device_, &descriptor_set_layout_create_info, nullptr, &handle_)); } -Shader::Layout::Cache::Cache(Factory factory) - : cache_(std::move(factory)) { +ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; } -void Shader::Layout::Cache::purge() { - cache_.purge(); +ShaderLayout::~ShaderLayout() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroyDescriptorSetLayout(device_, handle_, nullptr); + handle_ = VK_NULL_HANDLE; } -#ifdef USE_VULKAN_SHADERC_RUNTIME +void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkDescriptorSetLayout tmp_handle = lhs.handle_; -struct Shader::Factory::Compiler final { - shaderc::Compiler context; - shaderc::CompileOptions options; - - Compiler() { - options.SetNanClamp(/*enable =*/ true); - options.SetSourceLanguage(shaderc_source_language_glsl); - options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_0); - options.SetWarningsAsErrors(); - #ifdef DEBUG - options.SetGenerateDebugInfo(); - #endif /* DEBUG */ - options.SetOptimizationLevel(shaderc_optimization_level_zero); - } + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; - std::vector compile(const char* const source) const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - source, - "Invalid shader source code!"); - - const shaderc::SpvCompilationResult result = context.CompileGlslToSpv( - source, - ::strlen(source), - shaderc_compute_shader, - "vulkan_shader.comp", - options); - - const shaderc_compilation_status status = result.GetCompilationStatus(); - TORCH_INTERNAL_ASSERT( - shaderc_compilation_status_success == status, - "Shader compilation error: ", - result.GetErrorMessage()); - - return std::vector(result.cbegin(), result.cend()); - } -}; + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; +} + +// +// ShaderModule +// -#else +ShaderModule::ShaderModule(const VkDevice device, const ShaderSource& source) + : device_(device), handle_{VK_NULL_HANDLE} { + const uint32_t* code = source.src_code.spirv.bin; + uint32_t size = source.src_code.spirv.size; -struct Shader::Factory::Compiler final { - std::vector compile(const char* const /* source */) const { - return std::vector{}; + const VkShaderModuleCreateInfo shader_module_create_info{ + VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + size, // codeSize + code, // pCode + }; + + VK_CHECK(vkCreateShaderModule( + device_, &shader_module_create_info, nullptr, &handle_)); +} + +ShaderModule::ShaderModule(ShaderModule&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; +} + +ShaderModule::~ShaderModule() { + if C10_LIKELY (VK_NULL_HANDLE == handle_) { + return; } -}; + vkDestroyShaderModule(device_, handle_, nullptr); + handle_ = VK_NULL_HANDLE; +} -#endif /* USE_VULKAN_SHADERC_RUNTIME */ +void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkShaderModule tmp_handle = lhs.handle_; + + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; -Shader::Factory::Factory(const GPU& gpu) - : device_(gpu.device), - compiler_(new Compiler) { + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; } -// std::unique_ptr requires its template parameter to be fully defined. -// For that reason pimpl through unique_ptr requires the definition of -// the [default] constructor and move assignment operator to appear after -// impl is fully defined. +// +// ShaderLayoutCache +// -Shader::Factory::Factory(Factory&&) = default; -Shader::Factory& Shader::Factory::Factory::operator=(Factory&&) = default; -Shader::Factory::~Factory() = default; +ShaderLayoutCache::ShaderLayoutCache(const VkDevice device) + : cache_mutex_{}, device_(device), cache_{} {} -typename Shader::Factory::Handle Shader::Factory::operator()( - const Descriptor& descriptor) const { - std::vector binary; +ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept + : cache_mutex_{}, device_(other.device_) { + std::lock_guard lock(other.cache_mutex_); + cache_ = std::move(other.cache_); +} - const uint32_t* code = nullptr; - uint32_t size = 0u; +ShaderLayoutCache::~ShaderLayoutCache() { + purge(); +} - if (Descriptor::Type::Source == descriptor.type) { - binary = compiler_->compile(descriptor.shader.source.glsl); - code = binary.data(); - size = sizeof(uint32_t) * static_cast(binary.size()); - } - else if (Descriptor::Type::Binary == descriptor.type) { - code = descriptor.shader.binary.spirv; - size = descriptor.shader.binary.size; +VkDescriptorSetLayout ShaderLayoutCache::retrieve( + const ShaderLayoutCache::Key& key) { + std::lock_guard lock(cache_mutex_); + + auto it = cache_.find(key); + if C10_UNLIKELY (cache_.cend() == it) { + it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first; } - else { - TORCH_INTERNAL_ASSERT(false, "Invalid descriptor type!"); + + return it->second.handle(); +} + +void ShaderLayoutCache::purge() { + std::lock_guard lock(cache_mutex_); + cache_.clear(); +} + +// +// ShaderCache +// + +ShaderCache::ShaderCache(const VkDevice device) + : cache_mutex_{}, device_(device), cache_{} {} + +ShaderCache::ShaderCache(ShaderCache&& other) noexcept + : cache_mutex_{}, device_(other.device_) { + std::lock_guard lock(other.cache_mutex_); + cache_ = std::move(other.cache_); +} + +ShaderCache::~ShaderCache() { + purge(); +} + +VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) { + std::lock_guard lock(cache_mutex_); + + auto it = cache_.find(key); + if C10_UNLIKELY (cache_.cend() == it) { + it = cache_.insert({key, ShaderCache::Value(device_, key)}).first; } - const VkShaderModuleCreateInfo shader_module_create_info{ - VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, - nullptr, - 0u, - size, - code, - }; + return it->second.handle(); +} - VkShaderModule shader_module{}; - VK_CHECK(vkCreateShaderModule( - device_, - &shader_module_create_info, - nullptr, - &shader_module)); - - TORCH_CHECK( - shader_module, - "Invalid Vulkan shader module!"); - - return Handle{ - shader_module, - Deleter(device_), - }; +void ShaderCache::purge() { + cache_.clear(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index e68061a320b70..3db3adc4d00d1 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -3,7 +3,6 @@ #ifdef USE_VULKAN_API #include -#include #include #include @@ -12,274 +11,161 @@ namespace native { namespace vulkan { namespace api { -// -// This struct defines shader, and shader layout, caches intended to minimize -// redundant object reconstructions at the cost of extra memory consumption. -// -// A shader is a small, usually simple, program that typically runs on a GPU as -// part of the graphics or compute pipelines. The shader layout defines the -// interface between that program and the outside world, namely what the host -// (i.e. CPU) sees as configurable parameters of the said shader per dispatch. -// If the shader was a regular function, the shader layout would have been its -// function prototype declaring the number and type of its arguments. -// -// Furthermore, shader layouts, or as Vulkan calls them descriptor set layouts, -// define the blueprint out of which descriptor sets are instantiated. Descriptor -// sets themselves, bundle the input to and output from a shader and contain -// pointers to GPU, and GPU accessible system, memory locations where the actual -// resources reside. Shader layouts are also used in creation of Vulkan pipeline -// layouts, while multiple shaders are bundled together to form a portion of the -// the monolithic state objects that are Vulkan pipelines. -// -// This struct defines the facilities required to create, compile, reuse, -// and destruct the aforementioned Vulkan objects. -// - -struct Shader final { - // - // Layout - // - - struct Layout final { - /* - Signature - */ - - typedef c10::SmallVector Signature; - - /* - Descriptor - */ - - struct Descriptor final { - Signature signature; - }; - - /* - Factory - */ - - class Factory final { - public: - explicit Factory(const GPU& gpu); - - typedef Layout::Descriptor Descriptor; - typedef VK_DELETER(DescriptorSetLayout) Deleter; - typedef api::Handle Handle; - - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; - - Handle operator()(const Descriptor& descriptor) const; - - private: - VkDevice device_; - }; - - struct Object final { - VkDescriptorSetLayout handle; - Signature signature; - - operator bool() const; - }; - - /* - Cache - */ - - class Cache final { - public: - explicit Cache(Factory factory); - Cache(const Cache&) = delete; - Cache& operator=(const Cache&) = delete; - Cache(Cache&&) = default; - Cache& operator=(Cache&&) = default; - ~Cache() = default; - - Object retrieve(const Descriptor& descriptor); - void purge(); - - private: - api::Cache cache_; - } cache; - - explicit Layout(const GPU& gpu) - : cache(Factory(gpu)) { - } - } layout; - - // - // Work Group - // - - typedef utils::uvec3 WorkGroup; - - /* - Descriptor - */ - - struct Descriptor final { - enum class Type { - Source, - Binary, - } type; - - union { - struct { - const char* glsl; // Null-terminated - uint32_t unused; // Padding - } source; - - struct { - const uint32_t* spirv; - uint32_t size; // Bytes - } binary; - } shader; - - Descriptor(const char* glsl); - Descriptor(const uint32_t* spirv, uint32_t bytes); - }; +class ShaderLayout final { + public: + using Signature = c10::SmallVector; - /* - Factory - */ - - class Factory final { - public: - explicit Factory(const GPU& gpu); - Factory(const Factory&) = delete; - Factory& operator=(const Factory&) = delete; - Factory(Factory&&); - Factory& operator=(Factory&&); - ~Factory(); - - typedef Shader::Descriptor Descriptor; - typedef VK_DELETER(ShaderModule) Deleter; - typedef api::Handle Handle; - - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; - - Handle operator()(const Descriptor& descriptor) const; - - private: - VkDevice device_; - struct Compiler; - std::unique_ptr compiler_; - }; + explicit ShaderLayout(const VkDevice, const Signature&); + + ShaderLayout(const ShaderLayout&) = delete; + ShaderLayout& operator=(const ShaderLayout&) = delete; + + ShaderLayout(ShaderLayout&&) noexcept; + ShaderLayout& operator=(ShaderLayout&&) = delete; - /* - Cache - */ + ~ShaderLayout(); - typedef api::Cache Cache; - Cache cache; + private: + VkDevice device_; + VkDescriptorSetLayout handle_; - explicit Shader(const GPU& gpu) - : layout(gpu), - cache(Factory(gpu)) { + public: + VkDescriptorSetLayout handle() const { + return handle_; } + + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. + friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept; }; -// -// Impl -// +struct ShaderSource final { + enum class Type { GLSL, SPIRV } type; + + union { + struct { + const char* src; // Null-terminated + uint32_t unused; // padding + } glsl; + struct { + const uint32_t* bin; + uint32_t size; + } spirv; + } src_code; + + std::string kernel_name; + ShaderLayout::Signature kernel_layout; + + explicit ShaderSource(std::string, const char*); + explicit ShaderSource( + std::string, + const uint32_t*, + const uint32_t, + const std::vector&); +}; -inline bool operator==( - const Shader::Layout::Descriptor& _1, - const Shader::Layout::Descriptor& _2) { - return _1.signature == _2.signature; -} +class ShaderModule final { + public: + explicit ShaderModule(const VkDevice device, const ShaderSource& source); + + ShaderModule(const ShaderModule&) = delete; + ShaderModule& operator=(const ShaderModule&) = delete; -inline size_t Shader::Layout::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - size_t hash = 0u; + ShaderModule(ShaderModule&&) noexcept; + ShaderModule& operator=(ShaderModule&&) = delete; - for (const VkDescriptorType type : descriptor.signature) { - hash = c10::hash_combine( - hash, - c10::get_hash(type)); + ~ShaderModule(); + + private: + VkDevice device_; + VkShaderModule handle_; + + public: + inline VkShaderModule handle() const { + return handle_; } - return hash; -} + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. + friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept; +}; -inline Shader::Layout::Object::operator bool() const { - return VK_NULL_HANDLE != handle; -} +class ShaderLayoutCache final { + public: + explicit ShaderLayoutCache(const VkDevice device); + + ShaderLayoutCache(const ShaderLayoutCache&) = delete; + ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete; + + ShaderLayoutCache(ShaderLayoutCache&&) noexcept; + ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete; + + ~ShaderLayoutCache(); + + using Key = ShaderLayout::Signature; + using Value = ShaderLayout; + + struct Hasher { + inline size_t operator()(const ShaderLayout::Signature& signature) const { + size_t hashed = 0u; -inline Shader::Layout::Object Shader::Layout::Cache::retrieve( - const Descriptor& descriptor) { - return { - cache_.retrieve(descriptor), - descriptor.signature, + for (const VkDescriptorType type : signature) { + hashed = c10::hash_combine(hashed, c10::get_hash(type)); + } + + return hashed; + } }; -} -inline bool operator==( - const Shader::WorkGroup& _1, - const Shader::WorkGroup& _2) { + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; - return (_1.data[0u] == _2.data[0u] && _1.data[1u] == _2.data[1u] && _1.data[2u] == _2.data[2u]); -} + VkDevice device_; + ska::flat_hash_map cache_; -inline Shader::Descriptor::Descriptor(const char* const glsl) - : type(Type::Source), - shader{ - .source = { - glsl, - 0u, - }, - } { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - glsl, - "Invalid shader source code!"); -} + public: + VkDescriptorSetLayout retrieve(const Key&); + void purge(); +}; -inline Shader::Descriptor::Descriptor( - const uint32_t* const code, - const uint32_t size) - : type(Type::Binary), - shader{ - .binary = { - code, - size, - }, - } { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - code && (0u != size), - "Invalid shader binary!"); -} +class ShaderCache final { + public: + explicit ShaderCache(const VkDevice device); -inline bool operator==( - const Shader::Descriptor& _1, - const Shader::Descriptor& _2) { + ShaderCache(const ShaderCache&) = delete; + ShaderCache& operator=(const ShaderCache&) = delete; - if (_1.type != _2.type) - return false; + ShaderCache(ShaderCache&&) noexcept; + ShaderCache& operator=(ShaderCache&&) = delete; - if (_1.type == Shader::Descriptor::Type::Binary) { - return (_1.shader.binary.spirv == _2.shader.binary.spirv && \ - _1.shader.binary.size == _2.shader.binary.size); - } - else { - return (_1.shader.source.glsl == _2.shader.source.glsl); - } -} + ~ShaderCache(); -inline size_t Shader::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - static_assert( - sizeof(Descriptor::shader.source) == sizeof(Descriptor::shader.binary), - "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + using Key = ShaderSource; + using Value = ShaderModule; - return c10::get_hash( - descriptor.type, - descriptor.shader.binary.spirv, - descriptor.shader.binary.size); -} + struct Hasher { + inline size_t operator()(const ShaderSource& source) const { + return c10::get_hash( + source.type, source.src_code.spirv.bin, source.src_code.spirv.size); + } + }; + + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; + + VkDevice device_; + ska::flat_hash_map cache_; + + public: + VkShaderModule retrieve(const Key&); + void purge(); +}; } // namespace api } // namespace vulkan @@ -289,12 +175,11 @@ inline size_t Shader::Factory::Hasher::operator()( inline bool operator==( const VkDescriptorSetLayoutBinding& _1, const VkDescriptorSetLayoutBinding& _2) { - - return (_1.binding == _2.binding && \ - _1.descriptorType == _2.descriptorType && \ - _1.descriptorCount == _2.descriptorCount && \ - _1.stageFlags == _2.stageFlags && \ - _1.pImmutableSamplers == _2.pImmutableSamplers); + return ( + _1.binding == _2.binding && _1.descriptorType == _2.descriptorType && + _1.descriptorCount == _2.descriptorCount && + _1.stageFlags == _2.stageFlags && + _1.pImmutableSamplers == _2.pImmutableSamplers); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/ThreadContext.cpp b/aten/src/ATen/native/vulkan/api/ThreadContext.cpp deleted file mode 100644 index d230d97ecda78..0000000000000 --- a/aten/src/ATen/native/vulkan/api/ThreadContext.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#ifdef USE_VULKAN_API - -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -ThreadContext::ThreadContext(const GPU& gpu) - : gpu_(gpu) { -} - -template -ThreadContext::SingletonThreadLocalObject::SingletonThreadLocalObject(const GPU& gpu) { - TORCH_INTERNAL_ASSERT(false, "SingletonThreadLocalObject doesn't support the generalized template constructor!"); -} - -// -// Specialized template functions -// - -template<> -ThreadContext::SingletonThreadLocalObject::SingletonThreadLocalObject(const GPU& gpu) - : object_(gpu) { -} - -template<> -ThreadContext::SingletonThreadLocalObject::SingletonThreadLocalObject(const GPU& gpu) - : object_(gpu) { -} - -template<> -ThreadContext::SingletonThreadLocalObject::SingletonThreadLocalObject(const GPU& gpu) - : object_(gpu) { -} - -template<> -ThreadContext::SingletonThreadLocalObject::SingletonThreadLocalObject(const GPU& gpu) - : object_(gpu.device, - gpu.adapter->timestamp_compute_and_graphics(), - gpu.adapter->timestamp_period()) { -} - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/ThreadContext.h b/aten/src/ATen/native/vulkan/api/ThreadContext.h deleted file mode 100644 index 0145e345f8d7e..0000000000000 --- a/aten/src/ATen/native/vulkan/api/ThreadContext.h +++ /dev/null @@ -1,81 +0,0 @@ -#pragma once - -#ifdef USE_VULKAN_API - -#include -#include -#include -#include -#include - -namespace at { -namespace native { -namespace vulkan { -namespace api { - -// -// Vulkan Thread Context holds onto all per-thread Vulkan states such as -// Command, Descriptor and Resource objects. -// - -class ThreadContext final { - public: - ThreadContext() = delete; - explicit ThreadContext(const GPU& gpu); - ThreadContext(const ThreadContext&) = delete; - ThreadContext(ThreadContext&&) = default; - ThreadContext& operator=(const ThreadContext&) = delete; - ThreadContext& operator=(ThreadContext&&) = default; - - Command& command(); - Descriptor& descriptor(); - Resource& resource(); - QueryPool& querypool(); - - private: - GPU gpu_; - - private: - template - class SingletonThreadLocalObject final { - public: - explicit SingletonThreadLocalObject(const GPU& gpu); - SingletonThreadLocalObject(const SingletonThreadLocalObject&) = delete; - SingletonThreadLocalObject& operator=(const SingletonThreadLocalObject&) = delete; - SingletonThreadLocalObject(SingletonThreadLocalObject&&) = default; - SingletonThreadLocalObject& operator=(SingletonThreadLocalObject&&) = default; - inline static T& get(const GPU& gpu) { - static thread_local SingletonThreadLocalObject object(gpu); - return object.object_; - } - private: - T object_; - }; -}; - -// -// Impl -// - -inline Command& ThreadContext::command() { - return SingletonThreadLocalObject::get(gpu_); -} - -inline Descriptor& ThreadContext::descriptor() { - return SingletonThreadLocalObject::get(gpu_); -} - -inline Resource& ThreadContext::resource() { - return SingletonThreadLocalObject::get(gpu_); -} - -inline QueryPool& ThreadContext::querypool() { - return SingletonThreadLocalObject::get(gpu_); -} - -} // namespace api -} // namespace vulkan -} // namespace native -} // namespace at - -#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Utils.h b/aten/src/ATen/native/vulkan/api/Utils.h index 10ace9addc3ff..7c350abc23f21 100644 --- a/aten/src/ATen/native/vulkan/api/Utils.h +++ b/aten/src/ATen/native/vulkan/api/Utils.h @@ -1,5 +1,8 @@ #pragma once -#include // For c10::overflows + +#include // For c10::overflows + +#include #ifdef USE_VULKAN_API @@ -13,24 +16,18 @@ namespace utils { // Alignment // -template -inline constexpr Type align_down( - const Type number, - const Type multiple) { +template +inline constexpr Type align_down(const Type number, const Type multiple) { return (number / multiple) * multiple; } -template -inline constexpr Type align_up( - const Type number, - const Type multiple) { +template +inline constexpr Type align_up(const Type number, const Type multiple) { return align_down(number + multiple - 1, multiple); } -template -inline constexpr Type div_up( - const Type numerator, - const Type denominator) { +template +inline constexpr Type div_up(const Type numerator, const Type denominator) { return (numerator + denominator - 1) / denominator; } @@ -76,32 +73,50 @@ inline constexpr To safe_downcast(const From v) { namespace detail { -template +template struct vec final { Type data[N]; }; } // namespace detail -template +template using ivec = detail::vec; using ivec2 = ivec<2u>; using ivec3 = ivec<3u>; using ivec4 = ivec<4u>; -template +template using uvec = detail::vec; using uvec2 = uvec<2u>; using uvec3 = uvec<3u>; using uvec4 = uvec<4u>; -template +template using vec = detail::vec; using vec2 = vec<2u>; using vec3 = vec<3u>; using vec4 = vec<4u>; } // namespace utils + +inline bool operator==(const utils::uvec3& _1, const utils::uvec3& _2) { + return ( + _1.data[0u] == _2.data[0u] && _1.data[1u] == _2.data[1u] && + _1.data[2u] == _2.data[2u]); +} + +inline VkOffset3D create_offset3d(const utils::uvec3& offsets) { + return VkOffset3D{ + static_cast(offsets.data[0u]), + static_cast(offsets.data[1u]), + static_cast(offsets.data[2u])}; +} + +inline VkExtent3D create_extent3d(const utils::uvec3& extents) { + return VkExtent3D{extents.data[0u], extents.data[1u], extents.data[2u]}; +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/vk_api.h b/aten/src/ATen/native/vulkan/api/vk_api.h new file mode 100644 index 0000000000000..d704615900b00 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/vk_api.h @@ -0,0 +1,15 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#ifdef USE_VULKAN_WRAPPER +#ifdef USE_VULKAN_VOLK +#include +#else +#include +#endif /* USE_VULKAN_VOLK */ +#else +#include +#endif /* USE_VULKAN_WRAPPER */ + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h index e75cd56979be7..7b04e54d944bd 100644 --- a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -1,5 +1,5 @@ // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,12 +25,12 @@ /** \mainpage Vulkan Memory Allocator -Version 3.0.0-development (2021-02-16) +Version 3.0.1 (2022-05-26) -Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. \n +Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. \n License: MIT -Documentation of all members: vk_mem_alloc.h +API documentation divided into groups: [Modules](modules.html) \section main_table_of_contents Table of contents @@ -49,7 +49,6 @@ Documentation of all members: vk_mem_alloc.h - [Mapping functions](@ref memory_mapping_mapping_functions) - [Persistently mapped memory](@ref memory_mapping_persistently_mapped_memory) - [Cache flush and invalidate](@ref memory_mapping_cache_control) - - [Finding out if memory is mappable](@ref memory_mapping_finding_if_memory_mappable) - \subpage staying_within_budget - [Querying for budget](@ref staying_within_budget_querying_for_budget) - [Controlling memory usage](@ref staying_within_budget_controlling_memory_usage) @@ -61,19198 +60,19499 @@ Documentation of all members: vk_mem_alloc.h - [Stack](@ref linear_algorithm_stack) - [Double stack](@ref linear_algorithm_double_stack) - [Ring buffer](@ref linear_algorithm_ring_buffer) - - [Buddy allocation algorithm](@ref buddy_algorithm) - \subpage defragmentation - - [Defragmenting CPU memory](@ref defragmentation_cpu) - - [Defragmenting GPU memory](@ref defragmentation_gpu) - - [Additional notes](@ref defragmentation_additional_notes) - - [Writing custom allocation algorithm](@ref defragmentation_custom_algorithm) - - \subpage lost_allocations - \subpage statistics - [Numeric statistics](@ref statistics_numeric_statistics) - [JSON dump](@ref statistics_json_dump) - \subpage allocation_annotation - [Allocation user data](@ref allocation_user_data) - [Allocation names](@ref allocation_names) + - \subpage virtual_allocator - \subpage debugging_memory_usage - [Memory initialization](@ref debugging_memory_usage_initialization) - [Margins](@ref debugging_memory_usage_margins) - [Corruption detection](@ref debugging_memory_usage_corruption_detection) - - \subpage record_and_replay + - \subpage opengl_interop - \subpage usage_patterns - - [Common mistakes](@ref usage_patterns_common_mistakes) - - [Simple patterns](@ref usage_patterns_simple) - - [Advanced patterns](@ref usage_patterns_advanced) + - [GPU-only resource](@ref usage_patterns_gpu_only) + - [Staging copy for upload](@ref usage_patterns_staging_copy_upload) + - [Readback](@ref usage_patterns_readback) + - [Advanced data uploading](@ref usage_patterns_advanced_data_uploading) + - [Other use cases](@ref usage_patterns_other_use_cases) - \subpage configuration - [Pointers to Vulkan functions](@ref config_Vulkan_functions) - [Custom host memory allocator](@ref custom_memory_allocator) - [Device memory allocation callbacks](@ref allocation_callbacks) - [Device heap memory limit](@ref heap_memory_limit) - - \subpage vk_khr_dedicated_allocation - - \subpage enabling_buffer_device_address - - \subpage vk_amd_device_coherent_memory +- Extension support + - \subpage vk_khr_dedicated_allocation + - \subpage enabling_buffer_device_address + - \subpage vk_ext_memory_priority + - \subpage vk_amd_device_coherent_memory - \subpage general_considerations - [Thread safety](@ref general_considerations_thread_safety) + - [Versioning and compatibility](@ref general_considerations_versioning_and_compatibility) - [Validation layer warnings](@ref general_considerations_validation_layer_warnings) - [Allocation algorithm](@ref general_considerations_allocation_algorithm) - [Features not supported](@ref general_considerations_features_not_supported) \section main_see_also See also -- [Product page on GPUOpen](https://gpuopen.com/gaming-product/vulkan-memory-allocator/) -- [Source repository on GitHub](https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator) +- [**Product page on GPUOpen**](https://gpuopen.com/gaming-product/vulkan-memory-allocator/) +- [**Source repository on GitHub**](https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator) +\defgroup group_init Library initialization +\brief API elements related to the initialization and management of the entire library, especially #VmaAllocator object. +\defgroup group_alloc Memory allocation -\page quick_start Quick start - -\section quick_start_project_setup Project setup +\brief API elements related to the allocation, deallocation, and management of Vulkan memory, buffers, images. +Most basic ones being: vmaCreateBuffer(), vmaCreateImage(). -Vulkan Memory Allocator comes in form of a "stb-style" single header file. -You don't need to build it as a separate library project. -You can add this file directly to your project and submit it to code repository next to your other source files. +\defgroup group_virtual Virtual allocator -"Single header" doesn't mean that everything is contained in C/C++ declarations, -like it tends to be in case of inline functions or C++ templates. -It means that implementation is bundled with interface in a single file and needs to be extracted using preprocessor macro. -If you don't do it properly, you will get linker errors. +\brief API elements related to the mechanism of \ref virtual_allocator - using the core allocation algorithm +for user-defined purpose without allocating any real GPU memory. -To do it properly: +\defgroup group_stats Statistics --# Include "vk_mem_alloc.h" file in each CPP file where you want to use the library. - This includes declarations of all members of the library. --# In exacly one CPP file define following macro before this include. - It enables also internal definitions. +\brief API elements that query current status of the allocator, from memory usage, budget, to full dump of the internal state in JSON format. +See documentation chapter: \ref statistics. +*/ -\code -#define VMA_IMPLEMENTATION -#include -\endcode -It may be a good idea to create dedicated CPP file just for this purpose. +#ifdef __cplusplus +extern "C" { +#endif -Note on language: This library is written in C++, but has C-compatible interface. -Thus you can include and use vk_mem_alloc.h in C or C++ code, but full -implementation with `VMA_IMPLEMENTATION` macro must be compiled as C++, NOT as C. +#ifndef VULKAN_H_ + #include +#endif -Please note that this library includes header ``, which in turn -includes `` on Windows. If you need some specific macros defined -before including these headers (like `WIN32_LEAN_AND_MEAN` or -`WINVER` for Windows, `VK_USE_PLATFORM_WIN32_KHR` for Vulkan), you must define -them before every `#include` of this library. +// Define this macro to declare maximum supported Vulkan version in format AAABBBCCC, +// where AAA = major, BBB = minor, CCC = patch. +// If you want to use version > 1.0, it still needs to be enabled via VmaAllocatorCreateInfo::vulkanApiVersion. +#if !defined(VMA_VULKAN_VERSION) + #if defined(VK_VERSION_1_3) + #define VMA_VULKAN_VERSION 1003000 + #elif defined(VK_VERSION_1_2) + #define VMA_VULKAN_VERSION 1002000 + #elif defined(VK_VERSION_1_1) + #define VMA_VULKAN_VERSION 1001000 + #else + #define VMA_VULKAN_VERSION 1000000 + #endif +#endif -You may need to configure the way you import Vulkan functions. +#if defined(__ANDROID__) && defined(VK_NO_PROTOTYPES) && VMA_STATIC_VULKAN_FUNCTIONS + extern PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr; + extern PFN_vkGetDeviceProcAddr vkGetDeviceProcAddr; + extern PFN_vkGetPhysicalDeviceProperties vkGetPhysicalDeviceProperties; + extern PFN_vkGetPhysicalDeviceMemoryProperties vkGetPhysicalDeviceMemoryProperties; + extern PFN_vkAllocateMemory vkAllocateMemory; + extern PFN_vkFreeMemory vkFreeMemory; + extern PFN_vkMapMemory vkMapMemory; + extern PFN_vkUnmapMemory vkUnmapMemory; + extern PFN_vkFlushMappedMemoryRanges vkFlushMappedMemoryRanges; + extern PFN_vkInvalidateMappedMemoryRanges vkInvalidateMappedMemoryRanges; + extern PFN_vkBindBufferMemory vkBindBufferMemory; + extern PFN_vkBindImageMemory vkBindImageMemory; + extern PFN_vkGetBufferMemoryRequirements vkGetBufferMemoryRequirements; + extern PFN_vkGetImageMemoryRequirements vkGetImageMemoryRequirements; + extern PFN_vkCreateBuffer vkCreateBuffer; + extern PFN_vkDestroyBuffer vkDestroyBuffer; + extern PFN_vkCreateImage vkCreateImage; + extern PFN_vkDestroyImage vkDestroyImage; + extern PFN_vkCmdCopyBuffer vkCmdCopyBuffer; + #if VMA_VULKAN_VERSION >= 1001000 + extern PFN_vkGetBufferMemoryRequirements2 vkGetBufferMemoryRequirements2; + extern PFN_vkGetImageMemoryRequirements2 vkGetImageMemoryRequirements2; + extern PFN_vkBindBufferMemory2 vkBindBufferMemory2; + extern PFN_vkBindImageMemory2 vkBindImageMemory2; + extern PFN_vkGetPhysicalDeviceMemoryProperties2 vkGetPhysicalDeviceMemoryProperties2; + #endif // #if VMA_VULKAN_VERSION >= 1001000 +#endif // #if defined(__ANDROID__) && VMA_STATIC_VULKAN_FUNCTIONS && VK_NO_PROTOTYPES -- By default, VMA assumes you you link statically with Vulkan API. If this is not the case, - `#define VMA_STATIC_VULKAN_FUNCTIONS 0` before `#include` of the VMA implementation and use another way. -- You can `#define VMA_DYNAMIC_VULKAN_FUNCTIONS 1` and make sure `vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` globals are defined. - All the remaining Vulkan functions will be fetched automatically. -- Finally, you can provide your own pointers to all Vulkan functions needed by VMA using structure member - VmaAllocatorCreateInfo::pVulkanFunctions, if you fetched them in some custom way e.g. using some loader like [Volk](https://github.com/zeux/volk). +#if !defined(VMA_DEDICATED_ALLOCATION) + #if VK_KHR_get_memory_requirements2 && VK_KHR_dedicated_allocation + #define VMA_DEDICATED_ALLOCATION 1 + #else + #define VMA_DEDICATED_ALLOCATION 0 + #endif +#endif +#if !defined(VMA_BIND_MEMORY2) + #if VK_KHR_bind_memory2 + #define VMA_BIND_MEMORY2 1 + #else + #define VMA_BIND_MEMORY2 0 + #endif +#endif -\section quick_start_initialization Initialization +#if !defined(VMA_MEMORY_BUDGET) + #if VK_EXT_memory_budget && (VK_KHR_get_physical_device_properties2 || VMA_VULKAN_VERSION >= 1001000) + #define VMA_MEMORY_BUDGET 1 + #else + #define VMA_MEMORY_BUDGET 0 + #endif +#endif -At program startup: +// Defined to 1 when VK_KHR_buffer_device_address device extension or equivalent core Vulkan 1.2 feature is defined in its headers. +#if !defined(VMA_BUFFER_DEVICE_ADDRESS) + #if VK_KHR_buffer_device_address || VMA_VULKAN_VERSION >= 1002000 + #define VMA_BUFFER_DEVICE_ADDRESS 1 + #else + #define VMA_BUFFER_DEVICE_ADDRESS 0 + #endif +#endif --# Initialize Vulkan to have `VkPhysicalDevice`, `VkDevice` and `VkInstance` object. --# Fill VmaAllocatorCreateInfo structure and create #VmaAllocator object by - calling vmaCreateAllocator(). +// Defined to 1 when VK_EXT_memory_priority device extension is defined in Vulkan headers. +#if !defined(VMA_MEMORY_PRIORITY) + #if VK_EXT_memory_priority + #define VMA_MEMORY_PRIORITY 1 + #else + #define VMA_MEMORY_PRIORITY 0 + #endif +#endif -\code -VmaAllocatorCreateInfo allocatorInfo = {}; -allocatorInfo.vulkanApiVersion = VK_API_VERSION_1_2; -allocatorInfo.physicalDevice = physicalDevice; -allocatorInfo.device = device; -allocatorInfo.instance = instance; +// Defined to 1 when VK_KHR_external_memory device extension is defined in Vulkan headers. +#if !defined(VMA_EXTERNAL_MEMORY) + #if VK_KHR_external_memory + #define VMA_EXTERNAL_MEMORY 1 + #else + #define VMA_EXTERNAL_MEMORY 0 + #endif +#endif -VmaAllocator allocator; -vmaCreateAllocator(&allocatorInfo, &allocator); -\endcode +// Define these macros to decorate all public functions with additional code, +// before and after returned type, appropriately. This may be useful for +// exporting the functions when compiling VMA as a separate library. Example: +// #define VMA_CALL_PRE __declspec(dllexport) +// #define VMA_CALL_POST __cdecl +#ifndef VMA_CALL_PRE + #define VMA_CALL_PRE +#endif +#ifndef VMA_CALL_POST + #define VMA_CALL_POST +#endif -Only members `physicalDevice`, `device`, `instance` are required. -However, you should inform the library which Vulkan version do you use by setting -VmaAllocatorCreateInfo::vulkanApiVersion and which extensions did you enable -by setting VmaAllocatorCreateInfo::flags (like #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT for VK_KHR_buffer_device_address). -Otherwise, VMA would use only features of Vulkan 1.0 core with no extensions. +// Define this macro to decorate pointers with an attribute specifying the +// length of the array they point to if they are not null. +// +// The length may be one of +// - The name of another parameter in the argument list where the pointer is declared +// - The name of another member in the struct where the pointer is declared +// - The name of a member of a struct type, meaning the value of that member in +// the context of the call. For example +// VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount"), +// this means the number of memory heaps available in the device associated +// with the VmaAllocator being dealt with. +#ifndef VMA_LEN_IF_NOT_NULL + #define VMA_LEN_IF_NOT_NULL(len) +#endif +// The VMA_NULLABLE macro is defined to be _Nullable when compiling with Clang. +// see: https://clang.llvm.org/docs/AttributeReference.html#nullable +#ifndef VMA_NULLABLE + #ifdef __clang__ + #define VMA_NULLABLE _Nullable + #else + #define VMA_NULLABLE + #endif +#endif -\section quick_start_resource_allocation Resource allocation +// The VMA_NOT_NULL macro is defined to be _Nonnull when compiling with Clang. +// see: https://clang.llvm.org/docs/AttributeReference.html#nonnull +#ifndef VMA_NOT_NULL + #ifdef __clang__ + #define VMA_NOT_NULL _Nonnull + #else + #define VMA_NOT_NULL + #endif +#endif -When you want to create a buffer or image: +// If non-dispatchable handles are represented as pointers then we can give +// then nullability annotations +#ifndef VMA_NOT_NULL_NON_DISPATCHABLE + #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + #define VMA_NOT_NULL_NON_DISPATCHABLE VMA_NOT_NULL + #else + #define VMA_NOT_NULL_NON_DISPATCHABLE + #endif +#endif --# Fill `VkBufferCreateInfo` / `VkImageCreateInfo` structure. --# Fill VmaAllocationCreateInfo structure. --# Call vmaCreateBuffer() / vmaCreateImage() to get `VkBuffer`/`VkImage` with memory - already allocated and bound to it. +#ifndef VMA_NULLABLE_NON_DISPATCHABLE + #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + #define VMA_NULLABLE_NON_DISPATCHABLE VMA_NULLABLE + #else + #define VMA_NULLABLE_NON_DISPATCHABLE + #endif +#endif -\code -VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufferInfo.size = 65536; -bufferInfo.usage = VK_BUFFER_USAGE_VERTEX_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; +#ifndef VMA_STATS_STRING_ENABLED + #define VMA_STATS_STRING_ENABLED 1 +#endif -VmaAllocationCreateInfo allocInfo = {}; -allocInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +// +// INTERFACE +// +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// -VkBuffer buffer; -VmaAllocation allocation; -vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); -\endcode +// Sections for managing code placement in file, only for development purposes e.g. for convenient folding inside an IDE. +#ifndef _VMA_ENUM_DECLARATIONS -Don't forget to destroy your objects when no longer needed: +/** +\addtogroup group_init +@{ +*/ -\code -vmaDestroyBuffer(allocator, buffer, allocation); -vmaDestroyAllocator(allocator); -\endcode +/// Flags for created #VmaAllocator. +typedef enum VmaAllocatorCreateFlagBits +{ + /** \brief Allocator and all objects created from it will not be synchronized internally, so you must guarantee they are used from only one thread at a time or synchronized externally by you. + Using this flag may increase performance because internal mutexes are not used. + */ + VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT = 0x00000001, + /** \brief Enables usage of VK_KHR_dedicated_allocation extension. -\page choosing_memory_type Choosing memory type + The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. + When it is `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. -Physical devices in Vulkan support various combinations of memory heaps and -types. Help with choosing correct and optimal memory type for your specific -resource is one of the key features of this library. You can use it by filling -appropriate members of VmaAllocationCreateInfo structure, as described below. -You can also combine multiple methods. + Using this extension will automatically allocate dedicated blocks of memory for + some buffers and images instead of suballocating place for them out of bigger + memory blocks (as if you explicitly used #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT + flag) when it is recommended by the driver. It may improve performance on some + GPUs. --# If you just want to find memory type index that meets your requirements, you - can use function: vmaFindMemoryTypeIndex(), vmaFindMemoryTypeIndexForBufferInfo(), - vmaFindMemoryTypeIndexForImageInfo(). --# If you want to allocate a region of device memory without association with any - specific image or buffer, you can use function vmaAllocateMemory(). Usage of - this function is not recommended and usually not needed. - vmaAllocateMemoryPages() function is also provided for creating multiple allocations at once, - which may be useful for sparse binding. --# If you already have a buffer or an image created, you want to allocate memory - for it and then you will bind it yourself, you can use function - vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(). - For binding you should use functions: vmaBindBufferMemory(), vmaBindImageMemory() - or their extended versions: vmaBindBufferMemory2(), vmaBindImageMemory2(). --# If you want to create a buffer or an image, allocate memory for it and bind - them together, all in one call, you can use function vmaCreateBuffer(), - vmaCreateImage(). This is the easiest and recommended way to use this library. + You may set this flag only if you found out that following device extensions are + supported, you enabled them while creating Vulkan device passed as + VmaAllocatorCreateInfo::device, and you want them to be used internally by this + library: -When using 3. or 4., the library internally queries Vulkan for memory types -supported for that buffer or image (function `vkGetBufferMemoryRequirements()`) -and uses only one of these types. + - VK_KHR_get_memory_requirements2 (device extension) + - VK_KHR_dedicated_allocation (device extension) -If no memory type can be found that meets all the requirements, these functions -return `VK_ERROR_FEATURE_NOT_PRESENT`. + When this flag is set, you can experience following warnings reported by Vulkan + validation layer. You can ignore them. -You can leave VmaAllocationCreateInfo structure completely filled with zeros. -It means no requirements are specified for memory type. -It is valid, although not very useful. - -\section choosing_memory_type_usage Usage - -The easiest way to specify memory requirements is to fill member -VmaAllocationCreateInfo::usage using one of the values of enum #VmaMemoryUsage. -It defines high level, common usage types. -For more details, see description of this enum. - -For example, if you want to create a uniform buffer that will be filled using -transfer only once or infrequently and used for rendering every frame, you can -do it using following code: + > vkBindBufferMemory(): Binding memory to buffer 0x2d but vkGetBufferMemoryRequirements() has not been called on that buffer. + */ + VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT = 0x00000002, + /** + Enables usage of VK_KHR_bind_memory2 extension. -\code -VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufferInfo.size = 65536; -bufferInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. + When it is `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. -VmaAllocationCreateInfo allocInfo = {}; -allocInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; + You may set this flag only if you found out that this device extension is supported, + you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + and you want it to be used internally by this library. -VkBuffer buffer; -VmaAllocation allocation; -vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); -\endcode + The extension provides functions `vkBindBufferMemory2KHR` and `vkBindImageMemory2KHR`, + which allow to pass a chain of `pNext` structures while binding. + This flag is required if you use `pNext` parameter in vmaBindBufferMemory2() or vmaBindImageMemory2(). + */ + VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT = 0x00000004, + /** + Enables usage of VK_EXT_memory_budget extension. -\section choosing_memory_type_required_preferred_flags Required and preferred flags + You may set this flag only if you found out that this device extension is supported, + you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + and you want it to be used internally by this library, along with another instance extension + VK_KHR_get_physical_device_properties2, which is required by it (or Vulkan 1.1, where this extension is promoted). -You can specify more detailed requirements by filling members -VmaAllocationCreateInfo::requiredFlags and VmaAllocationCreateInfo::preferredFlags -with a combination of bits from enum `VkMemoryPropertyFlags`. For example, -if you want to create a buffer that will be persistently mapped on host (so it -must be `HOST_VISIBLE`) and preferably will also be `HOST_COHERENT` and `HOST_CACHED`, -use following code: + The extension provides query for current memory usage and budget, which will probably + be more accurate than an estimation used by the library otherwise. + */ + VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT = 0x00000008, + /** + Enables usage of VK_AMD_device_coherent_memory extension. -\code -VmaAllocationCreateInfo allocInfo = {}; -allocInfo.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; -allocInfo.preferredFlags = VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; -allocInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; + You may set this flag only if you: -VkBuffer buffer; -VmaAllocation allocation; -vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); -\endcode + - found out that this device extension is supported and enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + - checked that `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true and set it while creating the Vulkan device, + - want it to be used internally by this library. -A memory type is chosen that has all the required flags and as many preferred -flags set as possible. + The extension and accompanying device feature provide access to memory types with + `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and `VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flags. + They are useful mostly for writing breadcrumb markers - a common method for debugging GPU crash/hang/TDR. -If you use VmaAllocationCreateInfo::usage, it is just internally converted to -a set of required and preferred flags. + When the extension is not enabled, such memory types are still enumerated, but their usage is illegal. + To protect from this error, if you don't create the allocator with this flag, it will refuse to allocate any memory or create a custom pool in such memory type, + returning `VK_ERROR_FEATURE_NOT_PRESENT`. + */ + VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT = 0x00000010, + /** + Enables usage of "buffer device address" feature, which allows you to use function + `vkGetBufferDeviceAddress*` to get raw GPU pointer to a buffer and pass it for usage inside a shader. -\section choosing_memory_type_explicit_memory_types Explicit memory types + You may set this flag only if you: -If you inspected memory types available on the physical device and you have -a preference for memory types that you want to use, you can fill member -VmaAllocationCreateInfo::memoryTypeBits. It is a bit mask, where each bit set -means that a memory type with that index is allowed to be used for the -allocation. Special value 0, just like `UINT32_MAX`, means there are no -restrictions to memory type index. + 1. (For Vulkan version < 1.2) Found as available and enabled device extension + VK_KHR_buffer_device_address. + This extension is promoted to core Vulkan 1.2. + 2. Found as available and enabled device feature `VkPhysicalDeviceBufferDeviceAddressFeatures::bufferDeviceAddress`. -Please note that this member is NOT just a memory type index. -Still you can use it to choose just one, specific memory type. -For example, if you already determined that your buffer should be created in -memory type 2, use following code: + When this flag is set, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT` using VMA. + The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT` to + allocated memory blocks wherever it might be needed. -\code -uint32_t memoryTypeIndex = 2; + For more information, see documentation chapter \ref enabling_buffer_device_address. + */ + VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT = 0x00000020, + /** + Enables usage of VK_EXT_memory_priority extension in the library. -VmaAllocationCreateInfo allocInfo = {}; -allocInfo.memoryTypeBits = 1u << memoryTypeIndex; + You may set this flag only if you found available and enabled this device extension, + along with `VkPhysicalDeviceMemoryPriorityFeaturesEXT::memoryPriority == VK_TRUE`, + while creating Vulkan device passed as VmaAllocatorCreateInfo::device. -VkBuffer buffer; -VmaAllocation allocation; -vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); -\endcode + When this flag is used, VmaAllocationCreateInfo::priority and VmaPoolCreateInfo::priority + are used to set priorities of allocated Vulkan memory. Without it, these variables are ignored. + A priority must be a floating-point value between 0 and 1, indicating the priority of the allocation relative to other memory allocations. + Larger values are higher priority. The granularity of the priorities is implementation-dependent. + It is automatically passed to every call to `vkAllocateMemory` done by the library using structure `VkMemoryPriorityAllocateInfoEXT`. + The value to be used for default priority is 0.5. + For more details, see the documentation of the VK_EXT_memory_priority extension. + */ + VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT = 0x00000040, -\section choosing_memory_type_custom_memory_pools Custom memory pools + VMA_ALLOCATOR_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaAllocatorCreateFlagBits; +/// See #VmaAllocatorCreateFlagBits. +typedef VkFlags VmaAllocatorCreateFlags; -If you allocate from custom memory pool, all the ways of specifying memory -requirements described above are not applicable and the aforementioned members -of VmaAllocationCreateInfo structure are ignored. Memory type is selected -explicitly when creating the pool and then used to make all the allocations from -that pool. For further details, see \ref custom_memory_pools. +/** @} */ -\section choosing_memory_type_dedicated_allocations Dedicated allocations +/** +\addtogroup group_alloc +@{ +*/ -Memory for allocations is reserved out of larger block of `VkDeviceMemory` -allocated from Vulkan internally. That's the main feature of this whole library. -You can still request a separate memory block to be created for an allocation, -just like you would do in a trivial solution without using any allocator. -In that case, a buffer or image is always bound to that memory at offset 0. -This is called a "dedicated allocation". -You can explicitly request it by using flag #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. -The library can also internally decide to use dedicated allocation in some cases, e.g.: +/// \brief Intended usage of the allocated memory. +typedef enum VmaMemoryUsage +{ + /** No intended memory usage specified. + Use other members of VmaAllocationCreateInfo to specify your requirements. + */ + VMA_MEMORY_USAGE_UNKNOWN = 0, + /** + \deprecated Obsolete, preserved for backward compatibility. + Prefers `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. + */ + VMA_MEMORY_USAGE_GPU_ONLY = 1, + /** + \deprecated Obsolete, preserved for backward compatibility. + Guarantees `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` and `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT`. + */ + VMA_MEMORY_USAGE_CPU_ONLY = 2, + /** + \deprecated Obsolete, preserved for backward compatibility. + Guarantees `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT`, prefers `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. + */ + VMA_MEMORY_USAGE_CPU_TO_GPU = 3, + /** + \deprecated Obsolete, preserved for backward compatibility. + Guarantees `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT`, prefers `VK_MEMORY_PROPERTY_HOST_CACHED_BIT`. + */ + VMA_MEMORY_USAGE_GPU_TO_CPU = 4, + /** + \deprecated Obsolete, preserved for backward compatibility. + Prefers not `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. + */ + VMA_MEMORY_USAGE_CPU_COPY = 5, + /** + Lazily allocated GPU memory having `VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT`. + Exists mostly on mobile platforms. Using it on desktop PC or other GPUs with no such memory type present will fail the allocation. -- When the size of the allocation is large. -- When [VK_KHR_dedicated_allocation](@ref vk_khr_dedicated_allocation) extension is enabled - and it reports that dedicated allocation is required or recommended for the resource. -- When allocation of next big memory block fails due to not enough device memory, - but allocation with the exact requested size succeeds. + Usage: Memory for transient attachment images (color attachments, depth attachments etc.), created with `VK_IMAGE_USAGE_TRANSIENT_ATTACHMENT_BIT`. + Allocations with this usage are always created as dedicated - it implies #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. + */ + VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED = 6, + /** + Selects best memory type automatically. + This flag is recommended for most common use cases. + + When using this flag, if you want to map the allocation (using vmaMapMemory() or #VMA_ALLOCATION_CREATE_MAPPED_BIT), + you must pass one of the flags: #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT + in VmaAllocationCreateInfo::flags. + + It can be used only with functions that let the library know `VkBufferCreateInfo` or `VkImageCreateInfo`, e.g. + vmaCreateBuffer(), vmaCreateImage(), vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo() + and not with generic memory allocation functions. + */ + VMA_MEMORY_USAGE_AUTO = 7, + /** + Selects best memory type automatically with preference for GPU (device) memory. -\page memory_mapping Memory mapping + When using this flag, if you want to map the allocation (using vmaMapMemory() or #VMA_ALLOCATION_CREATE_MAPPED_BIT), + you must pass one of the flags: #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT + in VmaAllocationCreateInfo::flags. -To "map memory" in Vulkan means to obtain a CPU pointer to `VkDeviceMemory`, -to be able to read from it or write to it in CPU code. -Mapping is possible only of memory allocated from a memory type that has -`VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` flag. -Functions `vkMapMemory()`, `vkUnmapMemory()` are designed for this purpose. -You can use them directly with memory allocated by this library, -but it is not recommended because of following issue: -Mapping the same `VkDeviceMemory` block multiple times is illegal - only one mapping at a time is allowed. -This includes mapping disjoint regions. Mapping is not reference-counted internally by Vulkan. -Because of this, Vulkan Memory Allocator provides following facilities: + It can be used only with functions that let the library know `VkBufferCreateInfo` or `VkImageCreateInfo`, e.g. + vmaCreateBuffer(), vmaCreateImage(), vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo() + and not with generic memory allocation functions. + */ + VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE = 8, + /** + Selects best memory type automatically with preference for CPU (host) memory. -\section memory_mapping_mapping_functions Mapping functions + When using this flag, if you want to map the allocation (using vmaMapMemory() or #VMA_ALLOCATION_CREATE_MAPPED_BIT), + you must pass one of the flags: #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT + in VmaAllocationCreateInfo::flags. -The library provides following functions for mapping of a specific #VmaAllocation: vmaMapMemory(), vmaUnmapMemory(). -They are safer and more convenient to use than standard Vulkan functions. -You can map an allocation multiple times simultaneously - mapping is reference-counted internally. -You can also map different allocations simultaneously regardless of whether they use the same `VkDeviceMemory` block. -The way it's implemented is that the library always maps entire memory block, not just region of the allocation. -For further details, see description of vmaMapMemory() function. -Example: + It can be used only with functions that let the library know `VkBufferCreateInfo` or `VkImageCreateInfo`, e.g. + vmaCreateBuffer(), vmaCreateImage(), vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo() + and not with generic memory allocation functions. + */ + VMA_MEMORY_USAGE_AUTO_PREFER_HOST = 9, -\code -// Having these objects initialized: + VMA_MEMORY_USAGE_MAX_ENUM = 0x7FFFFFFF +} VmaMemoryUsage; -struct ConstantBuffer +/// Flags to be passed as VmaAllocationCreateInfo::flags. +typedef enum VmaAllocationCreateFlagBits { - ... -}; -ConstantBuffer constantBufferData; - -VmaAllocator allocator; -VkBuffer constantBuffer; -VmaAllocation constantBufferAllocation; + /** \brief Set this flag if the allocation should have its own memory block. -// You can map and fill your buffer using following code: + Use it for special, big resources, like fullscreen images used as attachments. + */ + VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT = 0x00000001, -void* mappedData; -vmaMapMemory(allocator, constantBufferAllocation, &mappedData); -memcpy(mappedData, &constantBufferData, sizeof(constantBufferData)); -vmaUnmapMemory(allocator, constantBufferAllocation); -\endcode + /** \brief Set this flag to only try to allocate from existing `VkDeviceMemory` blocks and never create new such block. -When mapping, you may see a warning from Vulkan validation layer similar to this one: + If new allocation cannot be placed in any of the existing blocks, allocation + fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY` error. -Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used. + You should not use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT and + #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT at the same time. It makes no sense. + */ + VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT = 0x00000002, + /** \brief Set this flag to use a memory that will be persistently mapped and retrieve pointer to it. -It happens because the library maps entire `VkDeviceMemory` block, where different -types of images and buffers may end up together, especially on GPUs with unified memory like Intel. -You can safely ignore it if you are sure you access only memory of the intended -object that you wanted to map. + Pointer to mapped memory will be returned through VmaAllocationInfo::pMappedData. + It is valid to use this flag for allocation made from memory type that is not + `HOST_VISIBLE`. This flag is then ignored and memory is not mapped. This is + useful if you need an allocation that is efficient to use on GPU + (`DEVICE_LOCAL`) and still want to map it directly if possible on platforms that + support it (e.g. Intel GPU). + */ + VMA_ALLOCATION_CREATE_MAPPED_BIT = 0x00000004, + /** \deprecated Preserved for backward compatibility. Consider using vmaSetAllocationName() instead. + + Set this flag to treat VmaAllocationCreateInfo::pUserData as pointer to a + null-terminated string. Instead of copying pointer value, a local copy of the + string is made and stored in allocation's `pName`. The string is automatically + freed together with the allocation. It is also used in vmaBuildStatsString(). + */ + VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT = 0x00000020, + /** Allocation will be created from upper stack in a double stack pool. -\section memory_mapping_persistently_mapped_memory Persistently mapped memory + This flag is only allowed for custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT flag. + */ + VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT = 0x00000040, + /** Create both buffer/image and allocation, but don't bind them together. + It is useful when you want to bind yourself to do some more advanced binding, e.g. using some extensions. + The flag is meaningful only with functions that bind by default: vmaCreateBuffer(), vmaCreateImage(). + Otherwise it is ignored. -Kepping your memory persistently mapped is generally OK in Vulkan. -You don't need to unmap it before using its data on the GPU. -The library provides a special feature designed for that: -Allocations made with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag set in -VmaAllocationCreateInfo::flags stay mapped all the time, -so you can just access CPU pointer to it any time -without a need to call any "map" or "unmap" function. -Example: + If you want to make sure the new buffer/image is not tied to the new memory allocation + through `VkMemoryDedicatedAllocateInfoKHR` structure in case the allocation ends up in its own memory block, + use also flag #VMA_ALLOCATION_CREATE_CAN_ALIAS_BIT. + */ + VMA_ALLOCATION_CREATE_DONT_BIND_BIT = 0x00000080, + /** Create allocation only if additional device memory required for it, if any, won't exceed + memory budget. Otherwise return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + */ + VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT = 0x00000100, + /** \brief Set this flag if the allocated memory will have aliasing resources. + + Usage of this flag prevents supplying `VkMemoryDedicatedAllocateInfoKHR` when #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT is specified. + Otherwise created dedicated memory will not be suitable for aliasing resources, resulting in Vulkan Validation Layer errors. + */ + VMA_ALLOCATION_CREATE_CAN_ALIAS_BIT = 0x00000200, + /** + Requests possibility to map the allocation (using vmaMapMemory() or #VMA_ALLOCATION_CREATE_MAPPED_BIT). + + - If you use #VMA_MEMORY_USAGE_AUTO or other `VMA_MEMORY_USAGE_AUTO*` value, + you must use this flag to be able to map the allocation. Otherwise, mapping is incorrect. + - If you use other value of #VmaMemoryUsage, this flag is ignored and mapping is always possible in memory types that are `HOST_VISIBLE`. + This includes allocations created in \ref custom_memory_pools. + + Declares that mapped memory will only be written sequentially, e.g. using `memcpy()` or a loop writing number-by-number, + never read or accessed randomly, so a memory type can be selected that is uncached and write-combined. + + \warning Violating this declaration may work correctly, but will likely be very slow. + Watch out for implicit reads introduced by doing e.g. `pMappedData[i] += x;` + Better prepare your data in a local variable and `memcpy()` it to the mapped pointer all at once. + */ + VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT = 0x00000400, + /** + Requests possibility to map the allocation (using vmaMapMemory() or #VMA_ALLOCATION_CREATE_MAPPED_BIT). + + - If you use #VMA_MEMORY_USAGE_AUTO or other `VMA_MEMORY_USAGE_AUTO*` value, + you must use this flag to be able to map the allocation. Otherwise, mapping is incorrect. + - If you use other value of #VmaMemoryUsage, this flag is ignored and mapping is always possible in memory types that are `HOST_VISIBLE`. + This includes allocations created in \ref custom_memory_pools. + + Declares that mapped memory can be read, written, and accessed in random order, + so a `HOST_CACHED` memory type is required. + */ + VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT = 0x00000800, + /** + Together with #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT, + it says that despite request for host access, a not-`HOST_VISIBLE` memory type can be selected + if it may improve performance. + + By using this flag, you declare that you will check if the allocation ended up in a `HOST_VISIBLE` memory type + (e.g. using vmaGetAllocationMemoryProperties()) and if not, you will create some "staging" buffer and + issue an explicit transfer to write/read your data. + To prepare for this possibility, don't forget to add appropriate flags like + `VK_BUFFER_USAGE_TRANSFER_DST_BIT`, `VK_BUFFER_USAGE_TRANSFER_SRC_BIT` to the parameters of created buffer or image. + */ + VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT = 0x00001000, + /** Allocation strategy that chooses smallest possible free range for the allocation + to minimize memory usage and fragmentation, possibly at the expense of allocation time. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT = 0x00010000, + /** Allocation strategy that chooses first suitable free range for the allocation - + not necessarily in terms of the smallest offset but the one that is easiest and fastest to find + to minimize allocation time, possibly at the expense of allocation quality. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT = 0x00020000, + /** Allocation strategy that chooses always the lowest offset in available space. + This is not the most efficient strategy but achieves highly packed data. + Used internally by defragmentation, not recomended in typical usage. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT = 0x00040000, + /** Alias to #VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT. + */ + VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT = VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT, + /** Alias to #VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT. + */ + VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT = VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT, + /** A bit mask to extract only `STRATEGY` bits from entire set of flags. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MASK = + VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT | + VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT | + VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT, -\code -VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufCreateInfo.size = sizeof(ConstantBuffer); -bufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + VMA_ALLOCATION_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaAllocationCreateFlagBits; +/// See #VmaAllocationCreateFlagBits. +typedef VkFlags VmaAllocationCreateFlags; -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_CPU_ONLY; -allocCreateInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; +/// Flags to be passed as VmaPoolCreateInfo::flags. +typedef enum VmaPoolCreateFlagBits +{ + /** \brief Use this flag if you always allocate only buffers and linear images or only optimal images out of this pool and so Buffer-Image Granularity can be ignored. -VkBuffer buf; -VmaAllocation alloc; -VmaAllocationInfo allocInfo; -vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + This is an optional optimization flag. -// Buffer is already mapped. You can access its memory. -memcpy(allocInfo.pMappedData, &constantBufferData, sizeof(constantBufferData)); -\endcode + If you always allocate using vmaCreateBuffer(), vmaCreateImage(), + vmaAllocateMemoryForBuffer(), then you don't need to use it because allocator + knows exact type of your allocations so it can handle Buffer-Image Granularity + in the optimal way. -There are some exceptions though, when you should consider mapping memory only for a short period of time: - -- When operating system is Windows 7 or 8.x (Windows 10 is not affected because it uses WDDM2), - device is discrete AMD GPU, - and memory type is the special 256 MiB pool of `DEVICE_LOCAL + HOST_VISIBLE` memory - (selected when you use #VMA_MEMORY_USAGE_CPU_TO_GPU), - then whenever a memory block allocated from this memory type stays mapped - for the time of any call to `vkQueueSubmit()` or `vkQueuePresentKHR()`, this - block is migrated by WDDM to system RAM, which degrades performance. It doesn't - matter if that particular memory block is actually used by the command buffer - being submitted. -- On Mac/MoltenVK there is a known bug - [Issue #175](https://github.com/KhronosGroup/MoltenVK/issues/175) - which requires unmapping before GPU can see updated texture. -- Keeping many large memory blocks mapped may impact performance or stability of some debugging tools. + If you also allocate using vmaAllocateMemoryForImage() or vmaAllocateMemory(), + exact type of such allocations is not known, so allocator must be conservative + in handling Buffer-Image Granularity, which can lead to suboptimal allocation + (wasted memory). In that case, if you can make sure you always allocate only + buffers and linear images or only optimal images out of this pool, use this flag + to make allocator disregard Buffer-Image Granularity and so make allocations + faster and more optimal. + */ + VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT = 0x00000002, -\section memory_mapping_cache_control Cache flush and invalidate + /** \brief Enables alternative, linear allocation algorithm in this pool. -Memory in Vulkan doesn't need to be unmapped before using it on GPU, -but unless a memory types has `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT` flag set, -you need to manually **invalidate** cache before reading of mapped pointer -and **flush** cache after writing to mapped pointer. -Map/unmap operations don't do that automatically. -Vulkan provides following functions for this purpose `vkFlushMappedMemoryRanges()`, -`vkInvalidateMappedMemoryRanges()`, but this library provides more convenient -functions that refer to given allocation object: vmaFlushAllocation(), -vmaInvalidateAllocation(), -or multiple objects at once: vmaFlushAllocations(), vmaInvalidateAllocations(). + Specify this flag to enable linear allocation algorithm, which always creates + new allocations after last one and doesn't reuse space from allocations freed in + between. It trades memory consumption for simplified algorithm and data + structure, which has better performance and uses less memory for metadata. -Regions of memory specified for flush/invalidate must be aligned to -`VkPhysicalDeviceLimits::nonCoherentAtomSize`. This is automatically ensured by the library. -In any memory type that is `HOST_VISIBLE` but not `HOST_COHERENT`, all allocations -within blocks are aligned to this value, so their offsets are always multiply of -`nonCoherentAtomSize` and two different allocations never share same "line" of this size. + By using this flag, you can achieve behavior of free-at-once, stack, + ring buffer, and double stack. + For details, see documentation chapter \ref linear_algorithm. + */ + VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT = 0x00000004, -Please note that memory allocated with #VMA_MEMORY_USAGE_CPU_ONLY is guaranteed to be `HOST_COHERENT`. + /** Bit mask to extract only `ALGORITHM` bits from entire set of flags. + */ + VMA_POOL_CREATE_ALGORITHM_MASK = + VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT, -Also, Windows drivers from all 3 **PC** GPU vendors (AMD, Intel, NVIDIA) -currently provide `HOST_COHERENT` flag on all memory types that are -`HOST_VISIBLE`, so on this platform you may not need to bother. + VMA_POOL_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaPoolCreateFlagBits; +/// Flags to be passed as VmaPoolCreateInfo::flags. See #VmaPoolCreateFlagBits. +typedef VkFlags VmaPoolCreateFlags; -\section memory_mapping_finding_if_memory_mappable Finding out if memory is mappable +/// Flags to be passed as VmaDefragmentationInfo::flags. +typedef enum VmaDefragmentationFlagBits +{ + /* \brief Use simple but fast algorithm for defragmentation. + May not achieve best results but will require least time to compute and least allocations to copy. + */ + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FAST_BIT = 0x1, + /* \brief Default defragmentation algorithm, applied also when no `ALGORITHM` flag is specified. + Offers a balance between defragmentation quality and the amount of allocations and bytes that need to be moved. + */ + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT = 0x2, + /* \brief Perform full defragmentation of memory. + Can result in notably more time to compute and allocations to copy, but will achieve best memory packing. + */ + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FULL_BIT = 0x4, + /** \brief Use the most roboust algorithm at the cost of time to compute and number of copies to make. + Only available when bufferImageGranularity is greater than 1, since it aims to reduce + alignment issues between different types of resources. + Otherwise falls back to same behavior as #VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FULL_BIT. + */ + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT = 0x8, -It may happen that your allocation ends up in memory that is `HOST_VISIBLE` (available for mapping) -despite it wasn't explicitly requested. -For example, application may work on integrated graphics with unified memory (like Intel) or -allocation from video memory might have failed, so the library chose system memory as fallback. + /// A bit mask to extract only `ALGORITHM` bits from entire set of flags. + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_MASK = + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FAST_BIT | + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT | + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FULL_BIT | + VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT, -You can detect this case and map such allocation to access its memory on CPU directly, -instead of launching a transfer operation. -In order to do that: inspect `allocInfo.memoryType`, call vmaGetMemoryTypeProperties(), -and look for `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` flag in properties of that memory type. + VMA_DEFRAGMENTATION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaDefragmentationFlagBits; +/// See #VmaDefragmentationFlagBits. +typedef VkFlags VmaDefragmentationFlags; -\code -VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufCreateInfo.size = sizeof(ConstantBuffer); -bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; +/// Operation performed on single defragmentation move. See structure #VmaDefragmentationMove. +typedef enum VmaDefragmentationMoveOperation +{ + /// Buffer/image has been recreated at `dstTmpAllocation`, data has been copied, old buffer/image has been destroyed. `srcAllocation` should be changed to point to the new place. This is the default value set by vmaBeginDefragmentationPass(). + VMA_DEFRAGMENTATION_MOVE_OPERATION_COPY = 0, + /// Set this value if you cannot move the allocation. New place reserved at `dstTmpAllocation` will be freed. `srcAllocation` will remain unchanged. + VMA_DEFRAGMENTATION_MOVE_OPERATION_IGNORE = 1, + /// Set this value if you decide to abandon the allocation and you destroyed the buffer/image. New place reserved at `dstTmpAllocation` will be freed, along with `srcAllocation`, which will be destroyed. + VMA_DEFRAGMENTATION_MOVE_OPERATION_DESTROY = 2, +} VmaDefragmentationMoveOperation; -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; -allocCreateInfo.preferredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; +/** @} */ -VkBuffer buf; -VmaAllocation alloc; -VmaAllocationInfo allocInfo; -vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); +/** +\addtogroup group_virtual +@{ +*/ -VkMemoryPropertyFlags memFlags; -vmaGetMemoryTypeProperties(allocator, allocInfo.memoryType, &memFlags); -if((memFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) -{ - // Allocation ended up in mappable memory. You can map it and access it directly. - void* mappedData; - vmaMapMemory(allocator, alloc, &mappedData); - memcpy(mappedData, &constantBufferData, sizeof(constantBufferData)); - vmaUnmapMemory(allocator, alloc); -} -else +/// Flags to be passed as VmaVirtualBlockCreateInfo::flags. +typedef enum VmaVirtualBlockCreateFlagBits { - // Allocation ended up in non-mappable memory. - // You need to create CPU-side buffer in VMA_MEMORY_USAGE_CPU_ONLY and make a transfer. -} -\endcode + /** \brief Enables alternative, linear allocation algorithm in this virtual block. -You can even use #VMA_ALLOCATION_CREATE_MAPPED_BIT flag while creating allocations -that are not necessarily `HOST_VISIBLE` (e.g. using #VMA_MEMORY_USAGE_GPU_ONLY). -If the allocation ends up in memory type that is `HOST_VISIBLE`, it will be persistently mapped and you can use it directly. -If not, the flag is just ignored. -Example: + Specify this flag to enable linear allocation algorithm, which always creates + new allocations after last one and doesn't reuse space from allocations freed in + between. It trades memory consumption for simplified algorithm and data + structure, which has better performance and uses less memory for metadata. -\code -VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufCreateInfo.size = sizeof(ConstantBuffer); -bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + By using this flag, you can achieve behavior of free-at-once, stack, + ring buffer, and double stack. + For details, see documentation chapter \ref linear_algorithm. + */ + VMA_VIRTUAL_BLOCK_CREATE_LINEAR_ALGORITHM_BIT = 0x00000001, -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; -allocCreateInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; + /** \brief Bit mask to extract only `ALGORITHM` bits from entire set of flags. + */ + VMA_VIRTUAL_BLOCK_CREATE_ALGORITHM_MASK = + VMA_VIRTUAL_BLOCK_CREATE_LINEAR_ALGORITHM_BIT, -VkBuffer buf; -VmaAllocation alloc; -VmaAllocationInfo allocInfo; -vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + VMA_VIRTUAL_BLOCK_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaVirtualBlockCreateFlagBits; +/// Flags to be passed as VmaVirtualBlockCreateInfo::flags. See #VmaVirtualBlockCreateFlagBits. +typedef VkFlags VmaVirtualBlockCreateFlags; -if(allocInfo.pMappedData != nullptr) -{ - // Allocation ended up in mappable memory. - // It's persistently mapped. You can access it directly. - memcpy(allocInfo.pMappedData, &constantBufferData, sizeof(constantBufferData)); -} -else +/// Flags to be passed as VmaVirtualAllocationCreateInfo::flags. +typedef enum VmaVirtualAllocationCreateFlagBits { - // Allocation ended up in non-mappable memory. - // You need to create CPU-side buffer in VMA_MEMORY_USAGE_CPU_ONLY and make a transfer. -} -\endcode + /** \brief Allocation will be created from upper stack in a double stack pool. + This flag is only allowed for virtual blocks created with #VMA_VIRTUAL_BLOCK_CREATE_LINEAR_ALGORITHM_BIT flag. + */ + VMA_VIRTUAL_ALLOCATION_CREATE_UPPER_ADDRESS_BIT = VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT, + /** \brief Allocation strategy that tries to minimize memory usage. + */ + VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT = VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT, + /** \brief Allocation strategy that tries to minimize allocation time. + */ + VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT = VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT, + /** Allocation strategy that chooses always the lowest offset in available space. + This is not the most efficient strategy but achieves highly packed data. + */ + VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT = VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT, + /** \brief A bit mask to extract only `STRATEGY` bits from entire set of flags. -\page staying_within_budget Staying within budget - -When developing a graphics-intensive game or program, it is important to avoid allocating -more GPU memory than it's physically available. When the memory is over-committed, -various bad things can happen, depending on the specific GPU, graphics driver, and -operating system: + These strategy flags are binary compatible with equivalent flags in #VmaAllocationCreateFlagBits. + */ + VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MASK = VMA_ALLOCATION_CREATE_STRATEGY_MASK, -- It may just work without any problems. -- The application may slow down because some memory blocks are moved to system RAM - and the GPU has to access them through PCI Express bus. -- A new allocation may take very long time to complete, even few seconds, and possibly - freeze entire system. -- The new allocation may fail with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. -- It may even result in GPU crash (TDR), observed as `VK_ERROR_DEVICE_LOST` - returned somewhere later. + VMA_VIRTUAL_ALLOCATION_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaVirtualAllocationCreateFlagBits; +/// Flags to be passed as VmaVirtualAllocationCreateInfo::flags. See #VmaVirtualAllocationCreateFlagBits. +typedef VkFlags VmaVirtualAllocationCreateFlags; -\section staying_within_budget_querying_for_budget Querying for budget +/** @} */ -To query for current memory usage and available budget, use function vmaGetBudget(). -Returned structure #VmaBudget contains quantities expressed in bytes, per Vulkan memory heap. +#endif // _VMA_ENUM_DECLARATIONS -Please note that this function returns different information and works faster than -vmaCalculateStats(). vmaGetBudget() can be called every frame or even before every -allocation, while vmaCalculateStats() is intended to be used rarely, -only to obtain statistical information, e.g. for debugging purposes. +#ifndef _VMA_DATA_TYPES_DECLARATIONS -It is recommended to use VK_EXT_memory_budget device extension to obtain information -about the budget from Vulkan device. VMA is able to use this extension automatically. -When not enabled, the allocator behaves same way, but then it estimates current usage -and available budget based on its internal information and Vulkan memory heap sizes, -which may be less precise. In order to use this extension: +/** +\addtogroup group_init +@{ */ -1. Make sure extensions VK_EXT_memory_budget and VK_KHR_get_physical_device_properties2 - required by it are available and enable them. Please note that the first is a device - extension and the second is instance extension! -2. Use flag #VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT when creating #VmaAllocator object. -3. Make sure to call vmaSetCurrentFrameIndex() every frame. Budget is queried from - Vulkan inside of it to avoid overhead of querying it with every allocation. +/** \struct VmaAllocator +\brief Represents main object of this library initialized. -\section staying_within_budget_controlling_memory_usage Controlling memory usage +Fill structure #VmaAllocatorCreateInfo and call function vmaCreateAllocator() to create it. +Call function vmaDestroyAllocator() to destroy it. -There are many ways in which you can try to stay within the budget. +It is recommended to create just one object of this type per `VkDevice` object, +right after Vulkan is initialized and keep it alive until before Vulkan device is destroyed. +*/ +VK_DEFINE_HANDLE(VmaAllocator) -First, when making new allocation requires allocating a new memory block, the library -tries not to exceed the budget automatically. If a block with default recommended size -(e.g. 256 MB) would go over budget, a smaller block is allocated, possibly even -dedicated memory for just this resource. +/** @} */ -If the size of the requested resource plus current memory usage is more than the -budget, by default the library still tries to create it, leaving it to the Vulkan -implementation whether the allocation succeeds or fails. You can change this behavior -by using #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag. With it, the allocation is -not made if it would exceed the budget or if the budget is already exceeded. -Some other allocations become lost instead to make room for it, if the mechanism of -[lost allocations](@ref lost_allocations) is used. -If that is not possible, the allocation fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. -Example usage pattern may be to pass the #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag -when creating resources that are not essential for the application (e.g. the texture -of a specific object) and not to pass it when creating critically important resources -(e.g. render targets). +/** +\addtogroup group_alloc +@{ +*/ -Finally, you can also use #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT flag to make sure -a new allocation is created only when it fits inside one of the existing memory blocks. -If it would require to allocate a new block, if fails instead with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. -This also ensures that the function call is very fast because it never goes to Vulkan -to obtain a new block. +/** \struct VmaPool +\brief Represents custom memory pool -Please note that creating \ref custom_memory_pools with VmaPoolCreateInfo::minBlockCount -set to more than 0 will try to allocate memory blocks without checking whether they -fit within budget. +Fill structure VmaPoolCreateInfo and call function vmaCreatePool() to create it. +Call function vmaDestroyPool() to destroy it. +For more information see [Custom memory pools](@ref choosing_memory_type_custom_memory_pools). +*/ +VK_DEFINE_HANDLE(VmaPool) -\page resource_aliasing Resource aliasing (overlap) +/** \struct VmaAllocation +\brief Represents single memory allocation. -New explicit graphics APIs (Vulkan and Direct3D 12), thanks to manual memory -management, give an opportunity to alias (overlap) multiple resources in the -same region of memory - a feature not available in the old APIs (Direct3D 11, OpenGL). -It can be useful to save video memory, but it must be used with caution. +It may be either dedicated block of `VkDeviceMemory` or a specific region of a bigger block of this type +plus unique offset. -For example, if you know the flow of your whole render frame in advance, you -are going to use some intermediate textures or buffers only during a small range of render passes, -and you know these ranges don't overlap in time, you can bind these resources to -the same place in memory, even if they have completely different parameters (width, height, format etc.). +There are multiple ways to create such object. +You need to fill structure VmaAllocationCreateInfo. +For more information see [Choosing memory type](@ref choosing_memory_type). -![Resource aliasing (overlap)](../gfx/Aliasing.png) +Although the library provides convenience functions that create Vulkan buffer or image, +allocate memory for it and bind them together, +binding of the allocation to a buffer or an image is out of scope of the allocation itself. +Allocation object can exist without buffer/image bound, +binding can be done manually by the user, and destruction of it can be done +independently of destruction of the allocation. -Such scenario is possible using VMA, but you need to create your images manually. -Then you need to calculate parameters of an allocation to be made using formula: +The object also remembers its size and some other information. +To retrieve this information, use function vmaGetAllocationInfo() and inspect +returned structure VmaAllocationInfo. +*/ +VK_DEFINE_HANDLE(VmaAllocation) -- allocation size = max(size of each image) -- allocation alignment = max(alignment of each image) -- allocation memoryTypeBits = bitwise AND(memoryTypeBits of each image) +/** \struct VmaDefragmentationContext +\brief An opaque object that represents started defragmentation process. -Following example shows two different images bound to the same place in memory, -allocated to fit largest of them. +Fill structure #VmaDefragmentationInfo and call function vmaBeginDefragmentation() to create it. +Call function vmaEndDefragmentation() to destroy it. +*/ +VK_DEFINE_HANDLE(VmaDefragmentationContext) -\code -// A 512x512 texture to be sampled. -VkImageCreateInfo img1CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; -img1CreateInfo.imageType = VK_IMAGE_TYPE_2D; -img1CreateInfo.extent.width = 512; -img1CreateInfo.extent.height = 512; -img1CreateInfo.extent.depth = 1; -img1CreateInfo.mipLevels = 10; -img1CreateInfo.arrayLayers = 1; -img1CreateInfo.format = VK_FORMAT_R8G8B8A8_SRGB; -img1CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; -img1CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; -img1CreateInfo.usage = VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT; -img1CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; +/** @} */ -// A full screen texture to be used as color attachment. -VkImageCreateInfo img2CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; -img2CreateInfo.imageType = VK_IMAGE_TYPE_2D; -img2CreateInfo.extent.width = 1920; -img2CreateInfo.extent.height = 1080; -img2CreateInfo.extent.depth = 1; -img2CreateInfo.mipLevels = 1; -img2CreateInfo.arrayLayers = 1; -img2CreateInfo.format = VK_FORMAT_R8G8B8A8_UNORM; -img2CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; -img2CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; -img2CreateInfo.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; -img2CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; +/** +\addtogroup group_virtual +@{ +*/ -VkImage img1; -res = vkCreateImage(device, &img1CreateInfo, nullptr, &img1); -VkImage img2; -res = vkCreateImage(device, &img2CreateInfo, nullptr, &img2); +/** \struct VmaVirtualAllocation +\brief Represents single memory allocation done inside VmaVirtualBlock. -VkMemoryRequirements img1MemReq; -vkGetImageMemoryRequirements(device, img1, &img1MemReq); -VkMemoryRequirements img2MemReq; -vkGetImageMemoryRequirements(device, img2, &img2MemReq); +Use it as a unique identifier to virtual allocation within the single block. -VkMemoryRequirements finalMemReq = {}; -finalMemReq.size = std::max(img1MemReq.size, img2MemReq.size); -finalMemReq.alignment = std::max(img1MemReq.alignment, img2MemReq.alignment); -finalMemReq.memoryTypeBits = img1MemReq.memoryTypeBits & img2MemReq.memoryTypeBits; -// Validate if(finalMemReq.memoryTypeBits != 0) +Use value `VK_NULL_HANDLE` to represent a null/invalid allocation. +*/ +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VmaVirtualAllocation); -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +/** @} */ -VmaAllocation alloc; -res = vmaAllocateMemory(allocator, &finalMemReq, &allocCreateInfo, &alloc, nullptr); +/** +\addtogroup group_virtual +@{ +*/ -res = vmaBindImageMemory(allocator, alloc, img1); -res = vmaBindImageMemory(allocator, alloc, img2); +/** \struct VmaVirtualBlock +\brief Handle to a virtual block object that allows to use core allocation algorithm without allocating any real GPU memory. -// You can use img1, img2 here, but not at the same time! +Fill in #VmaVirtualBlockCreateInfo structure and use vmaCreateVirtualBlock() to create it. Use vmaDestroyVirtualBlock() to destroy it. +For more information, see documentation chapter \ref virtual_allocator. -vmaFreeMemory(allocator, alloc); -vkDestroyImage(allocator, img2, nullptr); -vkDestroyImage(allocator, img1, nullptr); -\endcode +This object is not thread-safe - should not be used from multiple threads simultaneously, must be synchronized externally. +*/ +VK_DEFINE_HANDLE(VmaVirtualBlock) -Remember that using resouces that alias in memory requires proper synchronization. -You need to issue a memory barrier to make sure commands that use `img1` and `img2` -don't overlap on GPU timeline. -You also need to treat a resource after aliasing as uninitialized - containing garbage data. -For example, if you use `img1` and then want to use `img2`, you need to issue -an image memory barrier for `img2` with `oldLayout` = `VK_IMAGE_LAYOUT_UNDEFINED`. +/** @} */ -Additional considerations: +/** +\addtogroup group_init +@{ +*/ -- Vulkan also allows to interpret contents of memory between aliasing resources consistently in some cases. -See chapter 11.8. "Memory Aliasing" of Vulkan specification or `VK_IMAGE_CREATE_ALIAS_BIT` flag. -- You can create more complex layout where different images and buffers are bound -at different offsets inside one large allocation. For example, one can imagine -a big texture used in some render passes, aliasing with a set of many small buffers -used between in some further passes. To bind a resource at non-zero offset of an allocation, -use vmaBindBufferMemory2() / vmaBindImageMemory2(). -- Before allocating memory for the resources you want to alias, check `memoryTypeBits` -returned in memory requirements of each resource to make sure the bits overlap. -Some GPUs may expose multiple memory types suitable e.g. only for buffers or -images with `COLOR_ATTACHMENT` usage, so the sets of memory types supported by your -resources may be disjoint. Aliasing them is not possible in that case. +/// Callback function called after successful vkAllocateMemory. +typedef void (VKAPI_PTR* PFN_vmaAllocateDeviceMemoryFunction)( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryType, + VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, + VkDeviceSize size, + void* VMA_NULLABLE pUserData); +/// Callback function called before vkFreeMemory. +typedef void (VKAPI_PTR* PFN_vmaFreeDeviceMemoryFunction)( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryType, + VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, + VkDeviceSize size, + void* VMA_NULLABLE pUserData); -\page custom_memory_pools Custom memory pools +/** \brief Set of callbacks that the library will call for `vkAllocateMemory` and `vkFreeMemory`. -A memory pool contains a number of `VkDeviceMemory` blocks. -The library automatically creates and manages default pool for each memory type available on the device. -Default memory pool automatically grows in size. -Size of allocated blocks is also variable and managed automatically. +Provided for informative purpose, e.g. to gather statistics about number of +allocations or total amount of memory allocated in Vulkan. -You can create custom pool and allocate memory out of it. -It can be useful if you want to: +Used in VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. +*/ +typedef struct VmaDeviceMemoryCallbacks +{ + /// Optional, can be null. + PFN_vmaAllocateDeviceMemoryFunction VMA_NULLABLE pfnAllocate; + /// Optional, can be null. + PFN_vmaFreeDeviceMemoryFunction VMA_NULLABLE pfnFree; + /// Optional, can be null. + void* VMA_NULLABLE pUserData; +} VmaDeviceMemoryCallbacks; -- Keep certain kind of allocations separate from others. -- Enforce particular, fixed size of Vulkan memory blocks. -- Limit maximum amount of Vulkan memory allocated for that pool. -- Reserve minimum or fixed amount of Vulkan memory always preallocated for that pool. +/** \brief Pointers to some Vulkan functions - a subset used by the library. -To use custom memory pools: +Used in VmaAllocatorCreateInfo::pVulkanFunctions. +*/ +typedef struct VmaVulkanFunctions +{ + /// Required when using VMA_DYNAMIC_VULKAN_FUNCTIONS. + PFN_vkGetInstanceProcAddr VMA_NULLABLE vkGetInstanceProcAddr; + /// Required when using VMA_DYNAMIC_VULKAN_FUNCTIONS. + PFN_vkGetDeviceProcAddr VMA_NULLABLE vkGetDeviceProcAddr; + PFN_vkGetPhysicalDeviceProperties VMA_NULLABLE vkGetPhysicalDeviceProperties; + PFN_vkGetPhysicalDeviceMemoryProperties VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties; + PFN_vkAllocateMemory VMA_NULLABLE vkAllocateMemory; + PFN_vkFreeMemory VMA_NULLABLE vkFreeMemory; + PFN_vkMapMemory VMA_NULLABLE vkMapMemory; + PFN_vkUnmapMemory VMA_NULLABLE vkUnmapMemory; + PFN_vkFlushMappedMemoryRanges VMA_NULLABLE vkFlushMappedMemoryRanges; + PFN_vkInvalidateMappedMemoryRanges VMA_NULLABLE vkInvalidateMappedMemoryRanges; + PFN_vkBindBufferMemory VMA_NULLABLE vkBindBufferMemory; + PFN_vkBindImageMemory VMA_NULLABLE vkBindImageMemory; + PFN_vkGetBufferMemoryRequirements VMA_NULLABLE vkGetBufferMemoryRequirements; + PFN_vkGetImageMemoryRequirements VMA_NULLABLE vkGetImageMemoryRequirements; + PFN_vkCreateBuffer VMA_NULLABLE vkCreateBuffer; + PFN_vkDestroyBuffer VMA_NULLABLE vkDestroyBuffer; + PFN_vkCreateImage VMA_NULLABLE vkCreateImage; + PFN_vkDestroyImage VMA_NULLABLE vkDestroyImage; + PFN_vkCmdCopyBuffer VMA_NULLABLE vkCmdCopyBuffer; +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + /// Fetch "vkGetBufferMemoryRequirements2" on Vulkan >= 1.1, fetch "vkGetBufferMemoryRequirements2KHR" when using VK_KHR_dedicated_allocation extension. + PFN_vkGetBufferMemoryRequirements2KHR VMA_NULLABLE vkGetBufferMemoryRequirements2KHR; + /// Fetch "vkGetImageMemoryRequirements2" on Vulkan >= 1.1, fetch "vkGetImageMemoryRequirements2KHR" when using VK_KHR_dedicated_allocation extension. + PFN_vkGetImageMemoryRequirements2KHR VMA_NULLABLE vkGetImageMemoryRequirements2KHR; +#endif +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + /// Fetch "vkBindBufferMemory2" on Vulkan >= 1.1, fetch "vkBindBufferMemory2KHR" when using VK_KHR_bind_memory2 extension. + PFN_vkBindBufferMemory2KHR VMA_NULLABLE vkBindBufferMemory2KHR; + /// Fetch "vkBindImageMemory2" on Vulkan >= 1.1, fetch "vkBindImageMemory2KHR" when using VK_KHR_bind_memory2 extension. + PFN_vkBindImageMemory2KHR VMA_NULLABLE vkBindImageMemory2KHR; +#endif +#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 + PFN_vkGetPhysicalDeviceMemoryProperties2KHR VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties2KHR; +#endif +#if VMA_VULKAN_VERSION >= 1003000 + /// Fetch from "vkGetDeviceBufferMemoryRequirements" on Vulkan >= 1.3, but you can also fetch it from "vkGetDeviceBufferMemoryRequirementsKHR" if you enabled extension VK_KHR_maintenance4. + PFN_vkGetDeviceBufferMemoryRequirements VMA_NULLABLE vkGetDeviceBufferMemoryRequirements; + /// Fetch from "vkGetDeviceImageMemoryRequirements" on Vulkan >= 1.3, but you can also fetch it from "vkGetDeviceImageMemoryRequirementsKHR" if you enabled extension VK_KHR_maintenance4. + PFN_vkGetDeviceImageMemoryRequirements VMA_NULLABLE vkGetDeviceImageMemoryRequirements; +#endif +} VmaVulkanFunctions; --# Fill VmaPoolCreateInfo structure. --# Call vmaCreatePool() to obtain #VmaPool handle. --# When making an allocation, set VmaAllocationCreateInfo::pool to this handle. - You don't need to specify any other parameters of this structure, like `usage`. +/// Description of a Allocator to be created. +typedef struct VmaAllocatorCreateInfo +{ + /// Flags for created allocator. Use #VmaAllocatorCreateFlagBits enum. + VmaAllocatorCreateFlags flags; + /// Vulkan physical device. + /** It must be valid throughout whole lifetime of created allocator. */ + VkPhysicalDevice VMA_NOT_NULL physicalDevice; + /// Vulkan device. + /** It must be valid throughout whole lifetime of created allocator. */ + VkDevice VMA_NOT_NULL device; + /// Preferred size of a single `VkDeviceMemory` block to be allocated from large heaps > 1 GiB. Optional. + /** Set to 0 to use default, which is currently 256 MiB. */ + VkDeviceSize preferredLargeHeapBlockSize; + /// Custom CPU memory allocation callbacks. Optional. + /** Optional, can be null. When specified, will also be used for all CPU-side memory allocations. */ + const VkAllocationCallbacks* VMA_NULLABLE pAllocationCallbacks; + /// Informative callbacks for `vkAllocateMemory`, `vkFreeMemory`. Optional. + /** Optional, can be null. */ + const VmaDeviceMemoryCallbacks* VMA_NULLABLE pDeviceMemoryCallbacks; + /** \brief Either null or a pointer to an array of limits on maximum number of bytes that can be allocated out of particular Vulkan memory heap. -Example: + If not NULL, it must be a pointer to an array of + `VkPhysicalDeviceMemoryProperties::memoryHeapCount` elements, defining limit on + maximum number of bytes that can be allocated out of particular Vulkan memory + heap. -\code -// Create a pool that can have at most 2 blocks, 128 MiB each. -VmaPoolCreateInfo poolCreateInfo = {}; -poolCreateInfo.memoryTypeIndex = ... -poolCreateInfo.blockSize = 128ull * 1024 * 1024; -poolCreateInfo.maxBlockCount = 2; + Any of the elements may be equal to `VK_WHOLE_SIZE`, which means no limit on that + heap. This is also the default in case of `pHeapSizeLimit` = NULL. -VmaPool pool; -vmaCreatePool(allocator, &poolCreateInfo, &pool); + If there is a limit defined for a heap: -// Allocate a buffer out of it. -VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -bufCreateInfo.size = 1024; -bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + - If user tries to allocate more memory from that heap using this allocator, + the allocation fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + - If the limit is smaller than heap size reported in `VkMemoryHeap::size`, the + value of this limit will be reported instead when using vmaGetMemoryProperties(). -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.pool = pool; + Warning! Using this feature may not be equivalent to installing a GPU with + smaller amount of memory, because graphics driver doesn't necessary fail new + allocations with `VK_ERROR_OUT_OF_DEVICE_MEMORY` result when memory capacity is + exceeded. It may return success and just silently migrate some device memory + blocks to system RAM. This driver behavior can also be controlled using + VK_AMD_memory_overallocation_behavior extension. + */ + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount") pHeapSizeLimit; -VkBuffer buf; -VmaAllocation alloc; -VmaAllocationInfo allocInfo; -vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); -\endcode + /** \brief Pointers to Vulkan functions. Can be null. -You have to free all allocations made from this pool before destroying it. + For details see [Pointers to Vulkan functions](@ref config_Vulkan_functions). + */ + const VmaVulkanFunctions* VMA_NULLABLE pVulkanFunctions; + /** \brief Handle to Vulkan instance object. -\code -vmaDestroyBuffer(allocator, buf, alloc); -vmaDestroyPool(allocator, pool); -\endcode + Starting from version 3.0.0 this member is no longer optional, it must be set! + */ + VkInstance VMA_NOT_NULL instance; + /** \brief Optional. The highest version of Vulkan that the application is designed to use. -\section custom_memory_pools_MemTypeIndex Choosing memory type index + It must be a value in the format as created by macro `VK_MAKE_VERSION` or a constant like: `VK_API_VERSION_1_1`, `VK_API_VERSION_1_0`. + The patch version number specified is ignored. Only the major and minor versions are considered. + It must be less or equal (preferably equal) to value as passed to `vkCreateInstance` as `VkApplicationInfo::apiVersion`. + Only versions 1.0, 1.1, 1.2, 1.3 are supported by the current implementation. + Leaving it initialized to zero is equivalent to `VK_API_VERSION_1_0`. + */ + uint32_t vulkanApiVersion; +#if VMA_EXTERNAL_MEMORY + /** \brief Either null or a pointer to an array of external memory handle types for each Vulkan memory type. -When creating a pool, you must explicitly specify memory type index. -To find the one suitable for your buffers or images, you can use helper functions -vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo(). -You need to provide structures with example parameters of buffers or images -that you are going to create in that pool. + If not NULL, it must be a pointer to an array of `VkPhysicalDeviceMemoryProperties::memoryTypeCount` + elements, defining external memory handle types of particular Vulkan memory type, + to be passed using `VkExportMemoryAllocateInfoKHR`. -\code -VkBufferCreateInfo exampleBufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -exampleBufCreateInfo.size = 1024; // Whatever. -exampleBufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; // Change if needed. + Any of the elements may be equal to 0, which means not to use `VkExportMemoryAllocateInfoKHR` on this memory type. + This is also the default in case of `pTypeExternalMemoryHandleTypes` = NULL. + */ + const VkExternalMemoryHandleTypeFlagsKHR* VMA_NULLABLE VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryTypeCount") pTypeExternalMemoryHandleTypes; +#endif // #if VMA_EXTERNAL_MEMORY +} VmaAllocatorCreateInfo; -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; // Change if needed. +/// Information about existing #VmaAllocator object. +typedef struct VmaAllocatorInfo +{ + /** \brief Handle to Vulkan instance object. -uint32_t memTypeIndex; -vmaFindMemoryTypeIndexForBufferInfo(allocator, &exampleBufCreateInfo, &allocCreateInfo, &memTypeIndex); + This is the same value as has been passed through VmaAllocatorCreateInfo::instance. + */ + VkInstance VMA_NOT_NULL instance; + /** \brief Handle to Vulkan physical device object. -VmaPoolCreateInfo poolCreateInfo = {}; -poolCreateInfo.memoryTypeIndex = memTypeIndex; -// ... -\endcode + This is the same value as has been passed through VmaAllocatorCreateInfo::physicalDevice. + */ + VkPhysicalDevice VMA_NOT_NULL physicalDevice; + /** \brief Handle to Vulkan device object. -When creating buffers/images allocated in that pool, provide following parameters: + This is the same value as has been passed through VmaAllocatorCreateInfo::device. + */ + VkDevice VMA_NOT_NULL device; +} VmaAllocatorInfo; -- `VkBufferCreateInfo`: Prefer to pass same parameters as above. - Otherwise you risk creating resources in a memory type that is not suitable for them, which may result in undefined behavior. - Using different `VK_BUFFER_USAGE_` flags may work, but you shouldn't create images in a pool intended for buffers - or the other way around. -- VmaAllocationCreateInfo: You don't need to pass same parameters. Fill only `pool` member. - Other members are ignored anyway. +/** @} */ -\section linear_algorithm Linear allocation algorithm +/** +\addtogroup group_stats +@{ +*/ -Each Vulkan memory block managed by this library has accompanying metadata that -keeps track of used and unused regions. By default, the metadata structure and -algorithm tries to find best place for new allocations among free regions to -optimize memory usage. This way you can allocate and free objects in any order. +/** \brief Calculated statistics of memory usage e.g. in a specific memory type, heap, custom pool, or total. -![Default allocation algorithm](../gfx/Linear_allocator_1_algo_default.png) +These are fast to calculate. +See functions: vmaGetHeapBudgets(), vmaGetPoolStatistics(). +*/ +typedef struct VmaStatistics +{ + /** \brief Number of `VkDeviceMemory` objects - Vulkan memory blocks allocated. + */ + uint32_t blockCount; + /** \brief Number of #VmaAllocation objects allocated. + + Dedicated allocations have their own blocks, so each one adds 1 to `allocationCount` as well as `blockCount`. + */ + uint32_t allocationCount; + /** \brief Number of bytes allocated in `VkDeviceMemory` blocks. + + \note To avoid confusion, please be aware that what Vulkan calls an "allocation" - a whole `VkDeviceMemory` object + (e.g. as in `VkPhysicalDeviceLimits::maxMemoryAllocationCount`) is called a "block" in VMA, while VMA calls + "allocation" a #VmaAllocation object that represents a memory region sub-allocated from such block, usually for a single buffer or image. + */ + VkDeviceSize blockBytes; + /** \brief Total number of bytes occupied by all #VmaAllocation objects. + + Always less or equal than `blockBytes`. + Difference `(blockBytes - allocationBytes)` is the amount of memory allocated from Vulkan + but unused by any #VmaAllocation. + */ + VkDeviceSize allocationBytes; +} VmaStatistics; -Sometimes there is a need to use simpler, linear allocation algorithm. You can -create custom pool that uses such algorithm by adding flag -#VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT to VmaPoolCreateInfo::flags while creating -#VmaPool object. Then an alternative metadata management is used. It always -creates new allocations after last one and doesn't reuse free regions after -allocations freed in the middle. It results in better allocation performance and -less memory consumed by metadata. +/** \brief More detailed statistics than #VmaStatistics. -![Linear allocation algorithm](../gfx/Linear_allocator_2_algo_linear.png) +These are slower to calculate. Use for debugging purposes. +See functions: vmaCalculateStatistics(), vmaCalculatePoolStatistics(). -With this one flag, you can create a custom pool that can be used in many ways: -free-at-once, stack, double stack, and ring buffer. See below for details. +Previous version of the statistics API provided averages, but they have been removed +because they can be easily calculated as: -\subsection linear_algorithm_free_at_once Free-at-once +\code +VkDeviceSize allocationSizeAvg = detailedStats.statistics.allocationBytes / detailedStats.statistics.allocationCount; +VkDeviceSize unusedBytes = detailedStats.statistics.blockBytes - detailedStats.statistics.allocationBytes; +VkDeviceSize unusedRangeSizeAvg = unusedBytes / detailedStats.unusedRangeCount; +\endcode +*/ +typedef struct VmaDetailedStatistics +{ + /// Basic statistics. + VmaStatistics statistics; + /// Number of free ranges of memory between allocations. + uint32_t unusedRangeCount; + /// Smallest allocation size. `VK_WHOLE_SIZE` if there are 0 allocations. + VkDeviceSize allocationSizeMin; + /// Largest allocation size. 0 if there are 0 allocations. + VkDeviceSize allocationSizeMax; + /// Smallest empty range size. `VK_WHOLE_SIZE` if there are 0 empty ranges. + VkDeviceSize unusedRangeSizeMin; + /// Largest empty range size. 0 if there are 0 empty ranges. + VkDeviceSize unusedRangeSizeMax; +} VmaDetailedStatistics; -In a pool that uses linear algorithm, you still need to free all the allocations -individually, e.g. by using vmaFreeMemory() or vmaDestroyBuffer(). You can free -them in any order. New allocations are always made after last one - free space -in the middle is not reused. However, when you release all the allocation and -the pool becomes empty, allocation starts from the beginning again. This way you -can use linear algorithm to speed up creation of allocations that you are going -to release all at once. +/** \brief General statistics from current state of the Allocator - +total memory usage across all memory heaps and types. -![Free-at-once](../gfx/Linear_allocator_3_free_at_once.png) +These are slower to calculate. Use for debugging purposes. +See function vmaCalculateStatistics(). +*/ +typedef struct VmaTotalStatistics +{ + VmaDetailedStatistics memoryType[VK_MAX_MEMORY_TYPES]; + VmaDetailedStatistics memoryHeap[VK_MAX_MEMORY_HEAPS]; + VmaDetailedStatistics total; +} VmaTotalStatistics; -This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount -value that allows multiple memory blocks. +/** \brief Statistics of current memory usage and available budget for a specific memory heap. -\subsection linear_algorithm_stack Stack +These are fast to calculate. +See function vmaGetHeapBudgets(). +*/ +typedef struct VmaBudget +{ + /** \brief Statistics fetched from the library. + */ + VmaStatistics statistics; + /** \brief Estimated current memory usage of the program, in bytes. -When you free an allocation that was created last, its space can be reused. -Thanks to this, if you always release allocations in the order opposite to their -creation (LIFO - Last In First Out), you can achieve behavior of a stack. + Fetched from system using VK_EXT_memory_budget extension if enabled. -![Stack](../gfx/Linear_allocator_4_stack.png) + It might be different than `statistics.blockBytes` (usually higher) due to additional implicit objects + also occupying the memory, like swapchain, pipelines, descriptor heaps, command buffers, or + `VkDeviceMemory` blocks allocated outside of this library, if any. + */ + VkDeviceSize usage; + /** \brief Estimated amount of memory available to the program, in bytes. -This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount -value that allows multiple memory blocks. + Fetched from system using VK_EXT_memory_budget extension if enabled. -\subsection linear_algorithm_double_stack Double stack + It might be different (most probably smaller) than `VkMemoryHeap::size[heapIndex]` due to factors + external to the program, decided by the operating system. + Difference `budget - usage` is the amount of additional memory that can probably + be allocated without problems. Exceeding the budget may result in various problems. + */ + VkDeviceSize budget; +} VmaBudget; -The space reserved by a custom pool with linear algorithm may be used by two -stacks: +/** @} */ -- First, default one, growing up from offset 0. -- Second, "upper" one, growing down from the end towards lower offsets. +/** +\addtogroup group_alloc +@{ +*/ -To make allocation from upper stack, add flag #VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT -to VmaAllocationCreateInfo::flags. +/** \brief Parameters of new #VmaAllocation. -![Double stack](../gfx/Linear_allocator_7_double_stack.png) +To be used with functions like vmaCreateBuffer(), vmaCreateImage(), and many others. +*/ +typedef struct VmaAllocationCreateInfo +{ + /// Use #VmaAllocationCreateFlagBits enum. + VmaAllocationCreateFlags flags; + /** \brief Intended usage of memory. -Double stack is available only in pools with one memory block - -VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. + You can leave #VMA_MEMORY_USAGE_UNKNOWN if you specify memory requirements in other way. \n + If `pool` is not null, this member is ignored. + */ + VmaMemoryUsage usage; + /** \brief Flags that must be set in a Memory Type chosen for an allocation. -When the two stacks' ends meet so there is not enough space between them for a -new allocation, such allocation fails with usual -`VK_ERROR_OUT_OF_DEVICE_MEMORY` error. + Leave 0 if you specify memory requirements in other way. \n + If `pool` is not null, this member is ignored.*/ + VkMemoryPropertyFlags requiredFlags; + /** \brief Flags that preferably should be set in a memory type chosen for an allocation. -\subsection linear_algorithm_ring_buffer Ring buffer + Set to 0 if no additional flags are preferred. \n + If `pool` is not null, this member is ignored. */ + VkMemoryPropertyFlags preferredFlags; + /** \brief Bitmask containing one bit set for every memory type acceptable for this allocation. -When you free some allocations from the beginning and there is not enough free space -for a new one at the end of a pool, allocator's "cursor" wraps around to the -beginning and starts allocation there. Thanks to this, if you always release -allocations in the same order as you created them (FIFO - First In First Out), -you can achieve behavior of a ring buffer / queue. + Value 0 is equivalent to `UINT32_MAX` - it means any memory type is accepted if + it meets other requirements specified by this structure, with no further + restrictions on memory type index. \n + If `pool` is not null, this member is ignored. + */ + uint32_t memoryTypeBits; + /** \brief Pool that this allocation should be created in. -![Ring buffer](../gfx/Linear_allocator_5_ring_buffer.png) + Leave `VK_NULL_HANDLE` to allocate from default pool. If not null, members: + `usage`, `requiredFlags`, `preferredFlags`, `memoryTypeBits` are ignored. + */ + VmaPool VMA_NULLABLE pool; + /** \brief Custom general-purpose pointer that will be stored in #VmaAllocation, can be read as VmaAllocationInfo::pUserData and changed using vmaSetAllocationUserData(). -Pools with linear algorithm support [lost allocations](@ref lost_allocations) when used as ring buffer. -If there is not enough free space for a new allocation, but existing allocations -from the front of the queue can become lost, they become lost and the allocation -succeeds. + If #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT is used, it must be either + null or pointer to a null-terminated string. The string will be then copied to + internal buffer, so it doesn't need to be valid after allocation call. + */ + void* VMA_NULLABLE pUserData; + /** \brief A floating-point value between 0 and 1, indicating the priority of the allocation relative to other memory allocations. -![Ring buffer with lost allocations](../gfx/Linear_allocator_6_ring_buffer_lost.png) + It is used only when #VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT flag was used during creation of the #VmaAllocator object + and this allocation ends up as dedicated or is explicitly forced as dedicated using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. + Otherwise, it has the priority of a memory block where it is placed and this variable is ignored. + */ + float priority; +} VmaAllocationCreateInfo; -Ring buffer is available only in pools with one memory block - -VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. +/// Describes parameter of created #VmaPool. +typedef struct VmaPoolCreateInfo +{ + /** \brief Vulkan memory type index to allocate this pool from. + */ + uint32_t memoryTypeIndex; + /** \brief Use combination of #VmaPoolCreateFlagBits. + */ + VmaPoolCreateFlags flags; + /** \brief Size of a single `VkDeviceMemory` block to be allocated as part of this pool, in bytes. Optional. -\section buddy_algorithm Buddy allocation algorithm + Specify nonzero to set explicit, constant size of memory blocks used by this + pool. -There is another allocation algorithm that can be used with custom pools, called -"buddy". Its internal data structure is based on a tree of blocks, each having -size that is a power of two and a half of its parent's size. When you want to -allocate memory of certain size, a free node in the tree is located. If it's too -large, it is recursively split into two halves (called "buddies"). However, if -requested allocation size is not a power of two, the size of a tree node is -aligned up to the nearest power of two and the remaining space is wasted. When -two buddy nodes become free, they are merged back into one larger node. + Leave 0 to use default and let the library manage block sizes automatically. + Sizes of particular blocks may vary. + In this case, the pool will also support dedicated allocations. + */ + VkDeviceSize blockSize; + /** \brief Minimum number of blocks to be always allocated in this pool, even if they stay empty. -![Buddy allocator](../gfx/Buddy_allocator.png) + Set to 0 to have no preallocated blocks and allow the pool be completely empty. + */ + size_t minBlockCount; + /** \brief Maximum number of blocks that can be allocated in this pool. Optional. -The advantage of buddy allocation algorithm over default algorithm is faster -allocation and deallocation, as well as smaller external fragmentation. The -disadvantage is more wasted space (internal fragmentation). + Set to 0 to use default, which is `SIZE_MAX`, which means no limit. -For more information, please read ["Buddy memory allocation" on Wikipedia](https://en.wikipedia.org/wiki/Buddy_memory_allocation) -or other sources that describe this concept in general. + Set to same value as VmaPoolCreateInfo::minBlockCount to have fixed amount of memory allocated + throughout whole lifetime of this pool. + */ + size_t maxBlockCount; + /** \brief A floating-point value between 0 and 1, indicating the priority of the allocations in this pool relative to other memory allocations. -To use buddy allocation algorithm with a custom pool, add flag -#VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT to VmaPoolCreateInfo::flags while creating -#VmaPool object. + It is used only when #VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT flag was used during creation of the #VmaAllocator object. + Otherwise, this variable is ignored. + */ + float priority; + /** \brief Additional minimum alignment to be used for all allocations created from this pool. Can be 0. -Several limitations apply to pools that use buddy algorithm: + Leave 0 (default) not to impose any additional alignment. If not 0, it must be a power of two. + It can be useful in cases where alignment returned by Vulkan by functions like `vkGetBufferMemoryRequirements` is not enough, + e.g. when doing interop with OpenGL. + */ + VkDeviceSize minAllocationAlignment; + /** \brief Additional `pNext` chain to be attached to `VkMemoryAllocateInfo` used for every allocation made by this pool. Optional. -- It is recommended to use VmaPoolCreateInfo::blockSize that is a power of two. - Otherwise, only largest power of two smaller than the size is used for - allocations. The remaining space always stays unused. -- [Margins](@ref debugging_memory_usage_margins) and - [corruption detection](@ref debugging_memory_usage_corruption_detection) - don't work in such pools. -- [Lost allocations](@ref lost_allocations) don't work in such pools. You can - use them, but they never become lost. Support may be added in the future. -- [Defragmentation](@ref defragmentation) doesn't work with allocations made from - such pool. + Optional, can be null. If not null, it must point to a `pNext` chain of structures that can be attached to `VkMemoryAllocateInfo`. + It can be useful for special needs such as adding `VkExportMemoryAllocateInfoKHR`. + Structures pointed by this member must remain alive and unchanged for the whole lifetime of the custom pool. -\page defragmentation Defragmentation + Please note that some structures, e.g. `VkMemoryPriorityAllocateInfoEXT`, `VkMemoryDedicatedAllocateInfoKHR`, + can be attached automatically by this library when using other, more convenient of its features. + */ + void* VMA_NULLABLE pMemoryAllocateNext; +} VmaPoolCreateInfo; -Interleaved allocations and deallocations of many objects of varying size can -cause fragmentation over time, which can lead to a situation where the library is unable -to find a continuous range of free memory for a new allocation despite there is -enough free space, just scattered across many small free ranges between existing -allocations. +/** @} */ -To mitigate this problem, you can use defragmentation feature: -structure #VmaDefragmentationInfo2, function vmaDefragmentationBegin(), vmaDefragmentationEnd(). -Given set of allocations, -this function can move them to compact used memory, ensure more continuous free -space and possibly also free some `VkDeviceMemory` blocks. +/** +\addtogroup group_alloc +@{ +*/ -What the defragmentation does is: +/// Parameters of #VmaAllocation objects, that can be retrieved using function vmaGetAllocationInfo(). +typedef struct VmaAllocationInfo +{ + /** \brief Memory type index that this allocation was allocated from. -- Updates #VmaAllocation objects to point to new `VkDeviceMemory` and offset. - After allocation has been moved, its VmaAllocationInfo::deviceMemory and/or - VmaAllocationInfo::offset changes. You must query them again using - vmaGetAllocationInfo() if you need them. -- Moves actual data in memory. + It never changes. + */ + uint32_t memoryType; + /** \brief Handle to Vulkan memory object. -What it doesn't do, so you need to do it yourself: + Same memory object can be shared by multiple allocations. -- Recreate buffers and images that were bound to allocations that were defragmented and - bind them with their new places in memory. - You must use `vkDestroyBuffer()`, `vkDestroyImage()`, - `vkCreateBuffer()`, `vkCreateImage()`, vmaBindBufferMemory(), vmaBindImageMemory() - for that purpose and NOT vmaDestroyBuffer(), - vmaDestroyImage(), vmaCreateBuffer(), vmaCreateImage(), because you don't need to - destroy or create allocation objects! -- Recreate views and update descriptors that point to these buffers and images. + It can change after the allocation is moved during \ref defragmentation. + */ + VkDeviceMemory VMA_NULLABLE_NON_DISPATCHABLE deviceMemory; + /** \brief Offset in `VkDeviceMemory` object to the beginning of this allocation, in bytes. `(deviceMemory, offset)` pair is unique to this allocation. -\section defragmentation_cpu Defragmenting CPU memory + You usually don't need to use this offset. If you create a buffer or an image together with the allocation using e.g. function + vmaCreateBuffer(), vmaCreateImage(), functions that operate on these resources refer to the beginning of the buffer or image, + not entire device memory block. Functions like vmaMapMemory(), vmaBindBufferMemory() also refer to the beginning of the allocation + and apply this offset automatically. -Following example demonstrates how you can run defragmentation on CPU. -Only allocations created in memory types that are `HOST_VISIBLE` can be defragmented. -Others are ignored. + It can change after the allocation is moved during \ref defragmentation. + */ + VkDeviceSize offset; + /** \brief Size of this allocation, in bytes. -The way it works is: + It never changes. -- It temporarily maps entire memory blocks when necessary. -- It moves data using `memmove()` function. + \note Allocation size returned in this variable may be greater than the size + requested for the resource e.g. as `VkBufferCreateInfo::size`. Whole size of the + allocation is accessible for operations on memory e.g. using a pointer after + mapping with vmaMapMemory(), but operations on the resource e.g. using + `vkCmdCopyBuffer` must be limited to the size of the resource. + */ + VkDeviceSize size; + /** \brief Pointer to the beginning of this allocation as mapped data. -\code -// Given following variables already initialized: -VkDevice device; -VmaAllocator allocator; -std::vector buffers; -std::vector allocations; + If the allocation hasn't been mapped using vmaMapMemory() and hasn't been + created with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag, this value is null. + It can change after call to vmaMapMemory(), vmaUnmapMemory(). + It can also change after the allocation is moved during \ref defragmentation. + */ + void* VMA_NULLABLE pMappedData; + /** \brief Custom general-purpose pointer that was passed as VmaAllocationCreateInfo::pUserData or set using vmaSetAllocationUserData(). -const uint32_t allocCount = (uint32_t)allocations.size(); -std::vector allocationsChanged(allocCount); + It can change after call to vmaSetAllocationUserData() for this allocation. + */ + void* VMA_NULLABLE pUserData; + /** \brief Custom allocation name that was set with vmaSetAllocationName(). + + It can change after call to vmaSetAllocationName() for this allocation. + + Another way to set custom name is to pass it in VmaAllocationCreateInfo::pUserData with + additional flag #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT set [DEPRECATED]. + */ + const char* VMA_NULLABLE pName; +} VmaAllocationInfo; -VmaDefragmentationInfo2 defragInfo = {}; -defragInfo.allocationCount = allocCount; -defragInfo.pAllocations = allocations.data(); -defragInfo.pAllocationsChanged = allocationsChanged.data(); -defragInfo.maxCpuBytesToMove = VK_WHOLE_SIZE; // No limit. -defragInfo.maxCpuAllocationsToMove = UINT32_MAX; // No limit. +/** \brief Parameters for defragmentation. -VmaDefragmentationContext defragCtx; -vmaDefragmentationBegin(allocator, &defragInfo, nullptr, &defragCtx); -vmaDefragmentationEnd(allocator, defragCtx); +To be used with function vmaBeginDefragmentation(). +*/ +typedef struct VmaDefragmentationInfo +{ + /// \brief Use combination of #VmaDefragmentationFlagBits. + VmaDefragmentationFlags flags; + /** \brief Custom pool to be defragmented. -for (const auto i : c10::irange(allocCount)) { - if(allocationsChanged[i]) - { - // Destroy buffer that is immutably bound to memory region which is no longer valid. - vkDestroyBuffer(device, buffers[i], nullptr); + If null then default pools will undergo defragmentation process. + */ + VmaPool VMA_NULLABLE pool; + /** \brief Maximum numbers of bytes that can be copied during single pass, while moving allocations to different places. - // Create new buffer with same parameters. - VkBufferCreateInfo bufferInfo = ...; - vkCreateBuffer(device, &bufferInfo, nullptr, &buffers[i]); + `0` means no limit. + */ + VkDeviceSize maxBytesPerPass; + /** \brief Maximum number of allocations that can be moved during single pass to a different place. - // You can make dummy call to vkGetBufferMemoryRequirements here to silence validation layer warning. + `0` means no limit. + */ + uint32_t maxAllocationsPerPass; +} VmaDefragmentationInfo; - // Bind new buffer to new memory region. Data contained in it is already moved. - VmaAllocationInfo allocInfo; - vmaGetAllocationInfo(allocator, allocations[i], &allocInfo); - vmaBindBufferMemory(allocator, allocations[i], buffers[i]); - } -} -\endcode +/// Single move of an allocation to be done for defragmentation. +typedef struct VmaDefragmentationMove +{ + /// Operation to be performed on the allocation by vmaEndDefragmentationPass(). Default value is #VMA_DEFRAGMENTATION_MOVE_OPERATION_COPY. You can modify it. + VmaDefragmentationMoveOperation operation; + /// Allocation that should be moved. + VmaAllocation VMA_NOT_NULL srcAllocation; + /** \brief Temporary allocation pointing to destination memory that will replace `srcAllocation`. + + \warning Do not store this allocation in your data structures! It exists only temporarily, for the duration of the defragmentation pass, + to be used for binding new buffer/image to the destination memory using e.g. vmaBindBufferMemory(). + vmaEndDefragmentationPass() will destroy it and make `srcAllocation` point to this memory. + */ + VmaAllocation VMA_NOT_NULL dstTmpAllocation; +} VmaDefragmentationMove; -Setting VmaDefragmentationInfo2::pAllocationsChanged is optional. -This output array tells whether particular allocation in VmaDefragmentationInfo2::pAllocations at the same index -has been modified during defragmentation. -You can pass null, but you then need to query every allocation passed to defragmentation -for new parameters using vmaGetAllocationInfo() if you might need to recreate and rebind a buffer or image associated with it. +/** \brief Parameters for incremental defragmentation steps. -If you use [Custom memory pools](@ref choosing_memory_type_custom_memory_pools), -you can fill VmaDefragmentationInfo2::poolCount and VmaDefragmentationInfo2::pPools -instead of VmaDefragmentationInfo2::allocationCount and VmaDefragmentationInfo2::pAllocations -to defragment all allocations in given pools. -You cannot use VmaDefragmentationInfo2::pAllocationsChanged in that case. -You can also combine both methods. +To be used with function vmaBeginDefragmentationPass(). +*/ +typedef struct VmaDefragmentationPassMoveInfo +{ + /// Number of elements in the `pMoves` array. + uint32_t moveCount; + /** \brief Array of moves to be performed by the user in the current defragmentation pass. + + Pointer to an array of `moveCount` elements, owned by VMA, created in vmaBeginDefragmentationPass(), destroyed in vmaEndDefragmentationPass(). -\section defragmentation_gpu Defragmenting GPU memory + For each element, you should: + + 1. Create a new buffer/image in the place pointed by VmaDefragmentationMove::dstMemory + VmaDefragmentationMove::dstOffset. + 2. Copy data from the VmaDefragmentationMove::srcAllocation e.g. using `vkCmdCopyBuffer`, `vkCmdCopyImage`. + 3. Make sure these commands finished executing on the GPU. + 4. Destroy the old buffer/image. + + Only then you can finish defragmentation pass by calling vmaEndDefragmentationPass(). + After this call, the allocation will point to the new place in memory. -It is also possible to defragment allocations created in memory types that are not `HOST_VISIBLE`. -To do that, you need to pass a command buffer that meets requirements as described in -VmaDefragmentationInfo2::commandBuffer. The way it works is: + Alternatively, if you cannot move specific allocation, you can set VmaDefragmentationMove::operation to #VMA_DEFRAGMENTATION_MOVE_OPERATION_IGNORE. -- It creates temporary buffers and binds them to entire memory blocks when necessary. -- It issues `vkCmdCopyBuffer()` to passed command buffer. + Alternatively, if you decide you want to completely remove the allocation: -Example: + 1. Destroy its buffer/image. + 2. Set VmaDefragmentationMove::operation to #VMA_DEFRAGMENTATION_MOVE_OPERATION_DESTROY. -\code -// Given following variables already initialized: -VkDevice device; -VmaAllocator allocator; -VkCommandBuffer commandBuffer; -std::vector buffers; -std::vector allocations; + Then, after vmaEndDefragmentationPass() the allocation will be freed. + */ + VmaDefragmentationMove* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(moveCount) pMoves; +} VmaDefragmentationPassMoveInfo; +/// Statistics returned for defragmentation process in function vmaEndDefragmentation(). +typedef struct VmaDefragmentationStats +{ + /// Total number of bytes that have been copied while moving allocations to different places. + VkDeviceSize bytesMoved; + /// Total number of bytes that have been released to the system by freeing empty `VkDeviceMemory` objects. + VkDeviceSize bytesFreed; + /// Number of allocations that have been moved to different places. + uint32_t allocationsMoved; + /// Number of empty `VkDeviceMemory` objects that have been released to the system. + uint32_t deviceMemoryBlocksFreed; +} VmaDefragmentationStats; -const uint32_t allocCount = (uint32_t)allocations.size(); -std::vector allocationsChanged(allocCount); +/** @} */ -VkCommandBufferBeginInfo cmdBufBeginInfo = ...; -vkBeginCommandBuffer(commandBuffer, &cmdBufBeginInfo); +/** +\addtogroup group_virtual +@{ +*/ -VmaDefragmentationInfo2 defragInfo = {}; -defragInfo.allocationCount = allocCount; -defragInfo.pAllocations = allocations.data(); -defragInfo.pAllocationsChanged = allocationsChanged.data(); -defragInfo.maxGpuBytesToMove = VK_WHOLE_SIZE; // Notice it's "GPU" this time. -defragInfo.maxGpuAllocationsToMove = UINT32_MAX; // Notice it's "GPU" this time. -defragInfo.commandBuffer = commandBuffer; +/// Parameters of created #VmaVirtualBlock object to be passed to vmaCreateVirtualBlock(). +typedef struct VmaVirtualBlockCreateInfo +{ + /** \brief Total size of the virtual block. -VmaDefragmentationContext defragCtx; -vmaDefragmentationBegin(allocator, &defragInfo, nullptr, &defragCtx); + Sizes can be expressed in bytes or any units you want as long as you are consistent in using them. + For example, if you allocate from some array of structures, 1 can mean single instance of entire structure. + */ + VkDeviceSize size; -vkEndCommandBuffer(commandBuffer); + /** \brief Use combination of #VmaVirtualBlockCreateFlagBits. + */ + VmaVirtualBlockCreateFlags flags; -// Submit commandBuffer. -// Wait for a fence that ensures commandBuffer execution finished. + /** \brief Custom CPU memory allocation callbacks. Optional. -vmaDefragmentationEnd(allocator, defragCtx); - -for (const auto i : c10::irange(allocCount)) { - if(allocationsChanged[i]) - { - // Destroy buffer that is immutably bound to memory region which is no longer valid. - vkDestroyBuffer(device, buffers[i], nullptr); + Optional, can be null. When specified, they will be used for all CPU-side memory allocations. + */ + const VkAllocationCallbacks* VMA_NULLABLE pAllocationCallbacks; +} VmaVirtualBlockCreateInfo; - // Create new buffer with same parameters. - VkBufferCreateInfo bufferInfo = ...; - vkCreateBuffer(device, &bufferInfo, nullptr, &buffers[i]); +/// Parameters of created virtual allocation to be passed to vmaVirtualAllocate(). +typedef struct VmaVirtualAllocationCreateInfo +{ + /** \brief Size of the allocation. - // You can make dummy call to vkGetBufferMemoryRequirements here to silence validation layer warning. + Cannot be zero. + */ + VkDeviceSize size; + /** \brief Required alignment of the allocation. Optional. - // Bind new buffer to new memory region. Data contained in it is already moved. - VmaAllocationInfo allocInfo; - vmaGetAllocationInfo(allocator, allocations[i], &allocInfo); - vmaBindBufferMemory(allocator, allocations[i], buffers[i]); - } -} -\endcode + Must be power of two. Special value 0 has the same meaning as 1 - means no special alignment is required, so allocation can start at any offset. + */ + VkDeviceSize alignment; + /** \brief Use combination of #VmaVirtualAllocationCreateFlagBits. + */ + VmaVirtualAllocationCreateFlags flags; + /** \brief Custom pointer to be associated with the allocation. Optional. -You can combine these two methods by specifying non-zero `maxGpu*` as well as `maxCpu*` parameters. -The library automatically chooses best method to defragment each memory pool. - -You may try not to block your entire program to wait until defragmentation finishes, -but do it in the background, as long as you carefully fullfill requirements described -in function vmaDefragmentationBegin(). - -\section defragmentation_additional_notes Additional notes - -It is only legal to defragment allocations bound to: - -- buffers -- images created with `VK_IMAGE_CREATE_ALIAS_BIT`, `VK_IMAGE_TILING_LINEAR`, and - being currently in `VK_IMAGE_LAYOUT_GENERAL` or `VK_IMAGE_LAYOUT_PREINITIALIZED`. - -Defragmentation of images created with `VK_IMAGE_TILING_OPTIMAL` or in any other -layout may give undefined results. - -If you defragment allocations bound to images, new images to be bound to new -memory region after defragmentation should be created with `VK_IMAGE_LAYOUT_PREINITIALIZED` -and then transitioned to their original layout from before defragmentation if -needed using an image memory barrier. - -While using defragmentation, you may experience validation layer warnings, which you just need to ignore. -See [Validation layer warnings](@ref general_considerations_validation_layer_warnings). - -Please don't expect memory to be fully compacted after defragmentation. -Algorithms inside are based on some heuristics that try to maximize number of Vulkan -memory blocks to make totally empty to release them, as well as to maximimze continuous -empty space inside remaining blocks, while minimizing the number and size of allocations that -need to be moved. Some fragmentation may still remain - this is normal. - -\section defragmentation_custom_algorithm Writing custom defragmentation algorithm - -If you want to implement your own, custom defragmentation algorithm, -there is infrastructure prepared for that, -but it is not exposed through the library API - you need to hack its source code. -Here are steps needed to do this: - --# Main thing you need to do is to define your own class derived from base abstract - class `VmaDefragmentationAlgorithm` and implement your version of its pure virtual methods. - See definition and comments of this class for details. --# Your code needs to interact with device memory block metadata. - If you need more access to its data than it's provided by its public interface, - declare your new class as a friend class e.g. in class `VmaBlockMetadata_Generic`. --# If you want to create a flag that would enable your algorithm or pass some additional - flags to configure it, add them to `VmaDefragmentationFlagBits` and use them in - VmaDefragmentationInfo2::flags. --# Modify function `VmaBlockVectorDefragmentationContext::Begin` to create object - of your new class whenever needed. - - -\page lost_allocations Lost allocations - -If your game oversubscribes video memory, if may work OK in previous-generation -graphics APIs (DirectX 9, 10, 11, OpenGL) because resources are automatically -paged to system RAM. In Vulkan you can't do it because when you run out of -memory, an allocation just fails. If you have more data (e.g. textures) that can -fit into VRAM and you don't need it all at once, you may want to upload them to -GPU on demand and "push out" ones that are not used for a long time to make room -for the new ones, effectively using VRAM (or a cartain memory pool) as a form of -cache. Vulkan Memory Allocator can help you with that by supporting a concept of -"lost allocations". - -To create an allocation that can become lost, include #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT -flag in VmaAllocationCreateInfo::flags. Before using a buffer or image bound to -such allocation in every new frame, you need to query it if it's not lost. -To check it, call vmaTouchAllocation(). -If the allocation is lost, you should not use it or buffer/image bound to it. -You mustn't forget to destroy this allocation and this buffer/image. -vmaGetAllocationInfo() can also be used for checking status of the allocation. -Allocation is lost when returned VmaAllocationInfo::deviceMemory == `VK_NULL_HANDLE`. - -To create an allocation that can make some other allocations lost to make room -for it, use #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag. You will -usually use both flags #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT and -#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT at the same time. - -Warning! Current implementation uses quite naive, brute force algorithm, -which can make allocation calls that use #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT -flag quite slow. A new, more optimal algorithm and data structure to speed this -up is planned for the future. - -Q: When interleaving creation of new allocations with usage of existing ones, -how do you make sure that an allocation won't become lost while it's used in the -current frame? - -It is ensured because vmaTouchAllocation() / vmaGetAllocationInfo() not only returns allocation -status/parameters and checks whether it's not lost, but when it's not, it also -atomically marks it as used in the current frame, which makes it impossible to -become lost in that frame. It uses lockless algorithm, so it works fast and -doesn't involve locking any internal mutex. - -Q: What if my allocation may still be in use by the GPU when it's rendering a -previous frame while I already submit new frame on the CPU? - -You can make sure that allocations "touched" by vmaTouchAllocation() / vmaGetAllocationInfo() will not -become lost for a number of additional frames back from the current one by -specifying this number as VmaAllocatorCreateInfo::frameInUseCount (for default -memory pool) and VmaPoolCreateInfo::frameInUseCount (for custom pool). - -Q: How do you inform the library when new frame starts? - -You need to call function vmaSetCurrentFrameIndex(). - -Example code: + It can be any value and can be used for user-defined purposes. It can be fetched or changed later. + */ + void* VMA_NULLABLE pUserData; +} VmaVirtualAllocationCreateInfo; -\code -struct MyBuffer +/// Parameters of an existing virtual allocation, returned by vmaGetVirtualAllocationInfo(). +typedef struct VmaVirtualAllocationInfo { - VkBuffer m_Buf = nullptr; - VmaAllocation m_Alloc = nullptr; + /** \brief Offset of the allocation. + + Offset at which the allocation was made. + */ + VkDeviceSize offset; + /** \brief Size of the allocation. - // Called when the buffer is really needed in the current frame. - void EnsureBuffer(); -}; + Same value as passed in VmaVirtualAllocationCreateInfo::size. + */ + VkDeviceSize size; + /** \brief Custom pointer associated with the allocation. -void MyBuffer::EnsureBuffer() -{ - // Buffer has been created. - if(m_Buf != VK_NULL_HANDLE) - { - // Check if its allocation is not lost + mark it as used in current frame. - if(vmaTouchAllocation(allocator, m_Alloc)) - { - // It's all OK - safe to use m_Buf. - return; - } - } + Same value as passed in VmaVirtualAllocationCreateInfo::pUserData or to vmaSetVirtualAllocationUserData(). + */ + void* VMA_NULLABLE pUserData; +} VmaVirtualAllocationInfo; - // Buffer not yet exists or lost - destroy and recreate it. +/** @} */ - vmaDestroyBuffer(allocator, m_Buf, m_Alloc); +#endif // _VMA_DATA_TYPES_DECLARATIONS - VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; - bufCreateInfo.size = 1024; - bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; +#ifndef _VMA_FUNCTION_HEADERS - VmaAllocationCreateInfo allocCreateInfo = {}; - allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; - allocCreateInfo.flags = VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT | - VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT; +/** +\addtogroup group_init +@{ +*/ - vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &m_Buf, &m_Alloc, nullptr); -} -\endcode +/// Creates #VmaAllocator object. +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( + const VmaAllocatorCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocator VMA_NULLABLE* VMA_NOT_NULL pAllocator); -When using lost allocations, you may see some Vulkan validation layer warnings -about overlapping regions of memory bound to different kinds of buffers and -images. This is still valid as long as you implement proper handling of lost -allocations (like in the example above) and don't use them. +/// Destroys allocator object. +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( + VmaAllocator VMA_NULLABLE allocator); -You can create an allocation that is already in lost state from the beginning using function -vmaCreateLostAllocation(). It may be useful if you need a "dummy" allocation that is not null. +/** \brief Returns information about existing #VmaAllocator object - handle to Vulkan device etc. -You can call function vmaMakePoolAllocationsLost() to set all eligible allocations -in a specified custom pool to lost state. -Allocations that have been "touched" in current frame or VmaPoolCreateInfo::frameInUseCount frames back -cannot become lost. +It might be useful if you want to keep just the #VmaAllocator handle and fetch other required handles to +`VkPhysicalDevice`, `VkDevice` etc. every time using this function. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocatorInfo* VMA_NOT_NULL pAllocatorInfo); -Q: Can I touch allocation that cannot become lost? +/** +PhysicalDeviceProperties are fetched from physicalDevice by the allocator. +You can access it here, without fetching it again on your own. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( + VmaAllocator VMA_NOT_NULL allocator, + const VkPhysicalDeviceProperties* VMA_NULLABLE* VMA_NOT_NULL ppPhysicalDeviceProperties); -Yes, although it has no visible effect. -Calls to vmaGetAllocationInfo() and vmaTouchAllocation() update last use frame index -also for allocations that cannot become lost, but the only way to observe it is to dump -internal allocator state using vmaBuildStatsString(). -You can use this feature for debugging purposes to explicitly mark allocations that you use -in current frame and then analyze JSON dump to see for how long each allocation stays unused. +/** +PhysicalDeviceMemoryProperties are fetched from physicalDevice by the allocator. +You can access it here, without fetching it again on your own. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( + VmaAllocator VMA_NOT_NULL allocator, + const VkPhysicalDeviceMemoryProperties* VMA_NULLABLE* VMA_NOT_NULL ppPhysicalDeviceMemoryProperties); +/** +\brief Given Memory Type Index, returns Property Flags of this memory type. -\page statistics Statistics +This is just a convenience function. Same information can be obtained using +vmaGetMemoryProperties(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryTypeIndex, + VkMemoryPropertyFlags* VMA_NOT_NULL pFlags); -This library contains functions that return information about its internal state, -especially the amount of memory allocated from Vulkan. -Please keep in mind that these functions need to traverse all internal data structures -to gather these information, so they may be quite time-consuming. -Don't call them too often. +/** \brief Sets index of the current frame. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t frameIndex); -\section statistics_numeric_statistics Numeric statistics +/** @} */ -You can query for overall statistics of the allocator using function vmaCalculateStats(). -Information are returned using structure #VmaStats. -It contains #VmaStatInfo - number of allocated blocks, number of allocations -(occupied ranges in these blocks), number of unused (free) ranges in these blocks, -number of bytes used and unused (but still allocated from Vulkan) and other information. -They are summed across memory heaps, memory types and total for whole allocator. +/** +\addtogroup group_stats +@{ +*/ -You can query for statistics of a custom pool using function vmaGetPoolStats(). -Information are returned using structure #VmaPoolStats. +/** \brief Retrieves statistics from current state of the Allocator. -You can query for information about specific allocation using function vmaGetAllocationInfo(). -It fill structure #VmaAllocationInfo. +This function is called "calculate" not "get" because it has to traverse all +internal data structures, so it may be quite slow. Use it for debugging purposes. +For faster but more brief statistics suitable to be called every frame or every allocation, +use vmaGetHeapBudgets(). -\section statistics_json_dump JSON dump +Note that when using allocator from multiple threads, returned information may immediately +become outdated. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStatistics( + VmaAllocator VMA_NOT_NULL allocator, + VmaTotalStatistics* VMA_NOT_NULL pStats); -You can dump internal state of the allocator to a string in JSON format using function vmaBuildStatsString(). -The result is guaranteed to be correct JSON. -It uses ANSI encoding. -Any strings provided by user (see [Allocation names](@ref allocation_names)) -are copied as-is and properly escaped for JSON, so if they use UTF-8, ISO-8859-2 or any other encoding, -this JSON string can be treated as using this encoding. -It must be freed using function vmaFreeStatsString(). +/** \brief Retrieves information about current memory usage and budget for all memory heaps. -The format of this JSON string is not part of official documentation of the library, -but it will not change in backward-incompatible way without increasing library major version number -and appropriate mention in changelog. +\param allocator +\param[out] pBudgets Must point to array with number of elements at least equal to number of memory heaps in physical device used. -The JSON string contains all the data that can be obtained using vmaCalculateStats(). -It can also contain detailed map of allocated memory blocks and their regions - -free and occupied by allocations. -This allows e.g. to visualize the memory or assess fragmentation. +This function is called "get" not "calculate" because it is very fast, suitable to be called +every frame or every allocation. For more detailed statistics use vmaCalculateStatistics(). +Note that when using allocator from multiple threads, returned information may immediately +become outdated. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetHeapBudgets( + VmaAllocator VMA_NOT_NULL allocator, + VmaBudget* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount") pBudgets); -\page allocation_annotation Allocation names and user data +/** @} */ -\section allocation_user_data Allocation user data +/** +\addtogroup group_alloc +@{ +*/ -You can annotate allocations with your own information, e.g. for debugging purposes. -To do that, fill VmaAllocationCreateInfo::pUserData field when creating -an allocation. It's an opaque `void*` pointer. You can use it e.g. as a pointer, -some handle, index, key, ordinal number or any other value that would associate -the allocation with your custom metadata. +/** +\brief Helps to find memoryTypeIndex, given memoryTypeBits and VmaAllocationCreateInfo. -\code -VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; -// Fill bufferInfo... +This algorithm tries to find a memory type that: -MyBufferMetadata* pMetadata = CreateBufferMetadata(); +- Is allowed by memoryTypeBits. +- Contains all the flags from pAllocationCreateInfo->requiredFlags. +- Matches intended usage. +- Has as many flags from pAllocationCreateInfo->preferredFlags as possible. -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; -allocCreateInfo.pUserData = pMetadata; +\return Returns VK_ERROR_FEATURE_NOT_PRESENT if not found. Receiving such result +from this function or any other allocating function probably means that your +device doesn't support any memory type with requested features for the specific +type of resource you want to use it for. Please check parameters of your +resource, like image layout (OPTIMAL versus LINEAR) or mip level count. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); -VkBuffer buffer; -VmaAllocation allocation; -vmaCreateBuffer(allocator, &bufferInfo, &allocCreateInfo, &buffer, &allocation, nullptr); -\endcode +/** +\brief Helps to find memoryTypeIndex, given VkBufferCreateInfo and VmaAllocationCreateInfo. -The pointer may be later retrieved as VmaAllocationInfo::pUserData: +It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. +It internally creates a temporary, dummy buffer that never has memory bound. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( + VmaAllocator VMA_NOT_NULL allocator, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); -\code -VmaAllocationInfo allocInfo; -vmaGetAllocationInfo(allocator, allocation, &allocInfo); -MyBufferMetadata* pMetadata = (MyBufferMetadata*)allocInfo.pUserData; -\endcode +/** +\brief Helps to find memoryTypeIndex, given VkImageCreateInfo and VmaAllocationCreateInfo. -It can also be changed using function vmaSetAllocationUserData(). +It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. +It internally creates a temporary, dummy image that never has memory bound. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( + VmaAllocator VMA_NOT_NULL allocator, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); -Values of (non-zero) allocations' `pUserData` are printed in JSON report created by -vmaBuildStatsString(), in hexadecimal form. +/** \brief Allocates Vulkan device memory and creates #VmaPool object. -\section allocation_names Allocation names +\param allocator Allocator object. +\param pCreateInfo Parameters of pool to create. +\param[out] pPool Handle to created pool. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( + VmaAllocator VMA_NOT_NULL allocator, + const VmaPoolCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaPool VMA_NULLABLE* VMA_NOT_NULL pPool); -There is alternative mode available where `pUserData` pointer is used to point to -a null-terminated string, giving a name to the allocation. To use this mode, -set #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag in VmaAllocationCreateInfo::flags. -Then `pUserData` passed as VmaAllocationCreateInfo::pUserData or argument to -vmaSetAllocationUserData() must be either null or pointer to a null-terminated string. -The library creates internal copy of the string, so the pointer you pass doesn't need -to be valid for whole lifetime of the allocation. You can free it after the call. +/** \brief Destroys #VmaPool object and frees Vulkan device memory. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NULLABLE pool); -\code -VkImageCreateInfo imageInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; -// Fill imageInfo... +/** @} */ -std::string imageName = "Texture: "; -imageName += fileName; +/** +\addtogroup group_stats +@{ +*/ -VmaAllocationCreateInfo allocCreateInfo = {}; -allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; -allocCreateInfo.flags = VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT; -allocCreateInfo.pUserData = imageName.c_str(); +/** \brief Retrieves statistics of existing #VmaPool object. -VkImage image; -VmaAllocation allocation; -vmaCreateImage(allocator, &imageInfo, &allocCreateInfo, &image, &allocation, nullptr); -\endcode +\param allocator Allocator object. +\param pool Pool object. +\param[out] pPoolStats Statistics of specified pool. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStatistics( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + VmaStatistics* VMA_NOT_NULL pPoolStats); -The value of `pUserData` pointer of the allocation will be different than the one -you passed when setting allocation's name - pointing to a buffer managed -internally that holds copy of the string. +/** \brief Retrieves detailed statistics of existing #VmaPool object. -\code -VmaAllocationInfo allocInfo; -vmaGetAllocationInfo(allocator, allocation, &allocInfo); -const char* imageName = (const char*)allocInfo.pUserData; -printf("Image name: %s\n", imageName); -\endcode +\param allocator Allocator object. +\param pool Pool object. +\param[out] pPoolStats Statistics of specified pool. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaCalculatePoolStatistics( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + VmaDetailedStatistics* VMA_NOT_NULL pPoolStats); -That string is also printed in JSON report created by vmaBuildStatsString(). +/** @} */ -\note Passing string name to VMA allocation doesn't automatically set it to the Vulkan buffer or image created with it. -You must do it manually using an extension like VK_EXT_debug_utils, which is independent of this library. +/** +\addtogroup group_alloc +@{ +*/ +/** \brief Checks magic number in margins around all allocations in given memory pool in search for corruptions. -\page debugging_memory_usage Debugging incorrect memory usage +Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, +`VMA_DEBUG_MARGIN` is defined to nonzero and the pool is created in memory type that is +`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). -If you suspect a bug with memory usage, like usage of uninitialized memory or -memory being overwritten out of bounds of an allocation, -you can use debug features of this library to verify this. +Possible return values: -\section debugging_memory_usage_initialization Memory initialization +- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for specified pool. +- `VK_SUCCESS` - corruption detection has been performed and succeeded. +- `VK_ERROR_UNKNOWN` - corruption detection has been performed and found memory corruptions around one of the allocations. + `VMA_ASSERT` is also fired in that case. +- Other value: Error returned by Vulkan, e.g. memory mapping failure. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool); -If you experience a bug with incorrect and nondeterministic data in your program and you suspect uninitialized memory to be used, -you can enable automatic memory initialization to verify this. -To do it, define macro `VMA_DEBUG_INITIALIZE_ALLOCATIONS` to 1. +/** \brief Retrieves name of a custom pool. -\code -#define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 -#include -\endcode +After the call `ppName` is either null or points to an internally-owned null-terminated string +containing name of the pool that was previously set. The pointer becomes invalid when the pool is +destroyed or its name is changed using vmaSetPoolName(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + const char* VMA_NULLABLE* VMA_NOT_NULL ppName); -It makes memory of all new allocations initialized to bit pattern `0xDCDCDCDC`. -Before an allocation is destroyed, its memory is filled with bit pattern `0xEFEFEFEF`. -Memory is automatically mapped and unmapped if necessary. +/** \brief Sets name of a custom pool. -If you find these values while debugging your program, good chances are that you incorrectly -read Vulkan memory that is allocated but not initialized, or already freed, respectively. +`pName` can be either null or pointer to a null-terminated string with new name for the pool. +Function makes internal copy of the string, so it can be changed or freed immediately after this call. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + const char* VMA_NULLABLE pName); -Memory initialization works only with memory types that are `HOST_VISIBLE`. -It works also with dedicated allocations. -It doesn't work with allocations created with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, -as they cannot be mapped. +/** \brief General purpose memory allocation. -\section debugging_memory_usage_margins Margins +\param allocator +\param pVkMemoryRequirements +\param pCreateInfo +\param[out] pAllocation Handle to allocated memory. +\param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). -By default, allocations are laid out in memory blocks next to each other if possible -(considering required alignment, `bufferImageGranularity`, and `nonCoherentAtomSize`). +You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). -![Allocations without margin](../gfx/Margins_1.png) +It is recommended to use vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(), +vmaCreateBuffer(), vmaCreateImage() instead whenever possible. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( + VmaAllocator VMA_NOT_NULL allocator, + const VkMemoryRequirements* VMA_NOT_NULL pVkMemoryRequirements, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -Define macro `VMA_DEBUG_MARGIN` to some non-zero value (e.g. 16) to enforce specified -number of bytes as a margin before and after every allocation. +/** \brief General purpose memory allocation for multiple allocation objects at once. -\code -#define VMA_DEBUG_MARGIN 16 -#include -\endcode +\param allocator Allocator object. +\param pVkMemoryRequirements Memory requirements for each allocation. +\param pCreateInfo Creation parameters for each allocation. +\param allocationCount Number of allocations to make. +\param[out] pAllocations Pointer to array that will be filled with handles to created allocations. +\param[out] pAllocationInfo Optional. Pointer to array that will be filled with parameters of created allocations. -![Allocations with margin](../gfx/Margins_2.png) +You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). -If your bug goes away after enabling margins, it means it may be caused by memory -being overwritten outside of allocation boundaries. It is not 100% certain though. -Change in application behavior may also be caused by different order and distribution -of allocations across memory blocks after margins are applied. +Word "pages" is just a suggestion to use this function to allocate pieces of memory needed for sparse binding. +It is just a general purpose allocation function able to make multiple allocations at once. +It may be internally optimized to be more efficient than calling vmaAllocateMemory() `allocationCount` times. -The margin is applied also before first and after last allocation in a block. -It may occur only once between two adjacent allocations. +All allocations are made using same parameters. All of them are created out of the same memory pool and type. +If any allocation fails, all allocations already made within this function call are also freed, so that when +returned result is not `VK_SUCCESS`, `pAllocation` array is always entirely filled with `VK_NULL_HANDLE`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( + VmaAllocator VMA_NOT_NULL allocator, + const VkMemoryRequirements* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pVkMemoryRequirements, + const VmaAllocationCreateInfo* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pCreateInfo, + size_t allocationCount, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations, + VmaAllocationInfo* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationInfo); -Margins work with all types of memory. +/** \brief Allocates memory suitable for given `VkBuffer`. -Margin is applied only to allocations made out of memory blocks and not to dedicated -allocations, which have their own memory block of specific size. -It is thus not applied to allocations made using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT flag -or those automatically decided to put into dedicated allocations, e.g. due to its -large size or recommended by VK_KHR_dedicated_allocation extension. -Margins are also not active in custom pools created with #VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT flag. +\param allocator +\param buffer +\param pCreateInfo +\param[out] pAllocation Handle to allocated memory. +\param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). -Margins appear in [JSON dump](@ref statistics_json_dump) as part of free space. +It only creates #VmaAllocation. To bind the memory to the buffer, use vmaBindBufferMemory(). -Note that enabling margins increases memory usage and fragmentation. +This is a special-purpose function. In most cases you should use vmaCreateBuffer(). -\section debugging_memory_usage_corruption_detection Corruption detection +You must free the allocation using vmaFreeMemory() when no longer needed. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -You can additionally define macro `VMA_DEBUG_DETECT_CORRUPTION` to 1 to enable validation -of contents of the margins. +/** \brief Allocates memory suitable for given `VkImage`. -\code -#define VMA_DEBUG_MARGIN 16 -#define VMA_DEBUG_DETECT_CORRUPTION 1 -#include -\endcode +\param allocator +\param image +\param pCreateInfo +\param[out] pAllocation Handle to allocated memory. +\param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). -When this feature is enabled, number of bytes specified as `VMA_DEBUG_MARGIN` -(it must be multiply of 4) before and after every allocation is filled with a magic number. -This idea is also know as "canary". -Memory is automatically mapped and unmapped if necessary. +It only creates #VmaAllocation. To bind the memory to the buffer, use vmaBindImageMemory(). -This number is validated automatically when the allocation is destroyed. -If it's not equal to the expected value, `VMA_ASSERT()` is executed. -It clearly means that either CPU or GPU overwritten the memory outside of boundaries of the allocation, -which indicates a serious bug. +This is a special-purpose function. In most cases you should use vmaCreateImage(). -You can also explicitly request checking margins of all allocations in all memory blocks -that belong to specified memory types by using function vmaCheckCorruption(), -or in memory blocks that belong to specified custom pool, by using function -vmaCheckPoolCorruption(). +You must free the allocation using vmaFreeMemory() when no longer needed. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( + VmaAllocator VMA_NOT_NULL allocator, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -Margin validation (corruption detection) works only for memory types that are -`HOST_VISIBLE` and `HOST_COHERENT`. +/** \brief Frees memory previously allocated using vmaAllocateMemory(), vmaAllocateMemoryForBuffer(), or vmaAllocateMemoryForImage(). +Passing `VK_NULL_HANDLE` as `allocation` is valid. Such function call is just skipped. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( + VmaAllocator VMA_NOT_NULL allocator, + const VmaAllocation VMA_NULLABLE allocation); -\page record_and_replay Record and replay +/** \brief Frees memory and destroys multiple allocations. -\section record_and_replay_introduction Introduction +Word "pages" is just a suggestion to use this function to free pieces of memory used for sparse binding. +It is just a general purpose function to free memory and destroy allocations made using e.g. vmaAllocateMemory(), +vmaAllocateMemoryPages() and other functions. +It may be internally optimized to be more efficient than calling vmaFreeMemory() `allocationCount` times. -While using the library, sequence of calls to its functions together with their -parameters can be recorded to a file and later replayed using standalone player -application. It can be useful to: +Allocations in `pAllocations` array can come from any memory pools and types. +Passing `VK_NULL_HANDLE` as elements of `pAllocations` array is valid. Such entries are just skipped. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( + VmaAllocator VMA_NOT_NULL allocator, + size_t allocationCount, + const VmaAllocation VMA_NULLABLE* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations); -- Test correctness - check if same sequence of calls will not cause crash or - failures on a target platform. -- Gather statistics - see number of allocations, peak memory usage, number of - calls etc. -- Benchmark performance - see how much time it takes to replay the whole - sequence. +/** \brief Returns current information about specified allocation. -\section record_and_replay_usage Usage +Current paramteres of given allocation are returned in `pAllocationInfo`. -Recording functionality is disabled by default. -To enable it, define following macro before every include of this library: +Although this function doesn't lock any mutex, so it should be quite efficient, +you should avoid calling it too often. +You can retrieve same VmaAllocationInfo structure while creating your resource, from function +vmaCreateBuffer(), vmaCreateImage(). You can remember it if you are sure parameters don't change +(e.g. due to defragmentation). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VmaAllocationInfo* VMA_NOT_NULL pAllocationInfo); -\code -#define VMA_RECORDING_ENABLED 1 -\endcode +/** \brief Sets pUserData in given allocation to new value. -To record sequence of calls to a file: Fill in -VmaAllocatorCreateInfo::pRecordSettings member while creating #VmaAllocator -object. File is opened and written during whole lifetime of the allocator. +The value of pointer `pUserData` is copied to allocation's `pUserData`. +It is opaque, so you can use it however you want - e.g. +as a pointer, ordinal number or some handle to you own data. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + void* VMA_NULLABLE pUserData); -To replay file: Use VmaReplay - standalone command-line program. -Precompiled binary can be found in "bin" directory. -Its source can be found in "src/VmaReplay" directory. -Its project is generated by Premake. -Command line syntax is printed when the program is launched without parameters. -Basic usage: +/** \brief Sets pName in given allocation to new value. - VmaReplay.exe MyRecording.csv +`pName` must be either null, or pointer to a null-terminated string. The function +makes local copy of the string and sets it as allocation's `pName`. String +passed as pName doesn't need to be valid for whole lifetime of the allocation - +you can free it after this call. String previously pointed by allocation's +`pName` is freed from memory. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationName( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const char* VMA_NULLABLE pName); -Documentation of file format can be found in file: "docs/Recording file format.md". -It's a human-readable, text file in CSV format (Comma Separated Values). +/** +\brief Given an allocation, returns Property Flags of its memory type. -\section record_and_replay_additional_considerations Additional considerations +This is just a convenience function. Same information can be obtained using +vmaGetAllocationInfo() + vmaGetMemoryProperties(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationMemoryProperties( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkMemoryPropertyFlags* VMA_NOT_NULL pFlags); -- Replaying file that was recorded on a different GPU (with different parameters - like `bufferImageGranularity`, `nonCoherentAtomSize`, and especially different - set of memory heaps and types) may give different performance and memory usage - results, as well as issue some warnings and errors. -- Current implementation of recording in VMA, as well as VmaReplay application, is - coded and tested only on Windows. Inclusion of recording code is driven by - `VMA_RECORDING_ENABLED` macro. Support for other platforms should be easy to - add. Contributions are welcomed. +/** \brief Maps memory represented by given allocation and returns pointer to it. +Maps memory represented by given allocation to make it accessible to CPU code. +When succeeded, `*ppData` contains pointer to first byte of this memory. -\page usage_patterns Recommended usage patterns +\warning +If the allocation is part of a bigger `VkDeviceMemory` block, returned pointer is +correctly offsetted to the beginning of region assigned to this particular allocation. +Unlike the result of `vkMapMemory`, it points to the allocation, not to the beginning of the whole block. +You should not add VmaAllocationInfo::offset to it! -See also slides from talk: -[Sawicki, Adam. Advanced Graphics Techniques Tutorial: Memory management in Vulkan and DX12. Game Developers Conference, 2018](https://www.gdcvault.com/play/1025458/Advanced-Graphics-Techniques-Tutorial-New) +Mapping is internally reference-counted and synchronized, so despite raw Vulkan +function `vkMapMemory()` cannot be used to map same block of `VkDeviceMemory` +multiple times simultaneously, it is safe to call this function on allocations +assigned to the same memory block. Actual Vulkan memory will be mapped on first +mapping and unmapped on last unmapping. +If the function succeeded, you must call vmaUnmapMemory() to unmap the +allocation when mapping is no longer needed or before freeing the allocation, at +the latest. -\section usage_patterns_common_mistakes Common mistakes +It also safe to call this function multiple times on the same allocation. You +must call vmaUnmapMemory() same number of times as you called vmaMapMemory(). -Use of CPU_TO_GPU instead of CPU_ONLY memory +It is also safe to call this function on allocation created with +#VMA_ALLOCATION_CREATE_MAPPED_BIT flag. Its memory stays mapped all the time. +You must still call vmaUnmapMemory() same number of times as you called +vmaMapMemory(). You must not call vmaUnmapMemory() additional time to free the +"0-th" mapping made automatically due to #VMA_ALLOCATION_CREATE_MAPPED_BIT flag. -#VMA_MEMORY_USAGE_CPU_TO_GPU is recommended only for resources that will be -mapped and written by the CPU, as well as read directly by the GPU - like some -buffers or textures updated every frame (dynamic). If you create a staging copy -of a resource to be written by CPU and then used as a source of transfer to -another resource placed in the GPU memory, that staging resource should be -created with #VMA_MEMORY_USAGE_CPU_ONLY. Please read the descriptions of these -enums carefully for details. +This function fails when used on allocation made in memory type that is not +`HOST_VISIBLE`. -Unnecessary use of custom pools +This function doesn't automatically flush or invalidate caches. +If the allocation is made from a memory types that is not `HOST_COHERENT`, +you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + void* VMA_NULLABLE* VMA_NOT_NULL ppData); -\ref custom_memory_pools may be useful for special purposes - when you want to -keep certain type of resources separate e.g. to reserve minimum amount of memory -for them, limit maximum amount of memory they can occupy, or make some of them -push out the other through the mechanism of \ref lost_allocations. For most -resources this is not needed and so it is not recommended to create #VmaPool -objects and allocations out of them. Allocating from the default pool is sufficient. +/** \brief Unmaps memory represented by given allocation, mapped previously using vmaMapMemory(). -\section usage_patterns_simple Simple patterns +For details, see description of vmaMapMemory(). -\subsection usage_patterns_simple_render_targets Render targets +This function doesn't automatically flush or invalidate caches. +If the allocation is made from a memory types that is not `HOST_COHERENT`, +you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation); -When: -Any resources that you frequently write and read on GPU, -e.g. images used as color attachments (aka "render targets"), depth-stencil attachments, -images/buffers used as storage image/buffer (aka "Unordered Access View (UAV)"). +/** \brief Flushes memory of given allocation. -What to do: -Create them in video memory that is fastest to access from GPU using -#VMA_MEMORY_USAGE_GPU_ONLY. +Calls `vkFlushMappedMemoryRanges()` for memory associated with given range of given allocation. +It needs to be called after writing to a mapped memory for memory types that are not `HOST_COHERENT`. +Unmap operation doesn't do that automatically. -Consider using [VK_KHR_dedicated_allocation](@ref vk_khr_dedicated_allocation) extension -and/or manually creating them as dedicated allocations using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT, -especially if they are large or if you plan to destroy and recreate them e.g. when -display resolution changes. -Prefer to create such resources first and all other GPU resources (like textures and vertex buffers) later. +- `offset` must be relative to the beginning of allocation. +- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. +- `offset` and `size` don't have to be aligned. + They are internally rounded down/up to multiply of `nonCoherentAtomSize`. +- If `size` is 0, this call is ignored. +- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, + this call is ignored. -\subsection usage_patterns_simple_immutable_resources Immutable resources +Warning! `offset` and `size` are relative to the contents of given `allocation`. +If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. +Do not pass allocation's offset as `offset`!!! -When: -Any resources that you fill on CPU only once (aka "immutable") or infrequently -and then read frequently on GPU, -e.g. textures, vertex and index buffers, constant buffers that don't change often. +This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize offset, + VkDeviceSize size); -What to do: -Create them in video memory that is fastest to access from GPU using -#VMA_MEMORY_USAGE_GPU_ONLY. +/** \brief Invalidates memory of given allocation. -To initialize content of such resource, create a CPU-side (aka "staging") copy of it -in system memory - #VMA_MEMORY_USAGE_CPU_ONLY, map it, fill it, -and submit a transfer from it to the GPU resource. -You can keep the staging copy if you need it for another upload transfer in the future. -If you don't, you can destroy it or reuse this buffer for uploading different resource -after the transfer finishes. +Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given range of given allocation. +It needs to be called before reading from a mapped memory for memory types that are not `HOST_COHERENT`. +Map operation doesn't do that automatically. -Prefer to create just buffers in system memory rather than images, even for uploading textures. -Use `vkCmdCopyBufferToImage()`. -Dont use images with `VK_IMAGE_TILING_LINEAR`. +- `offset` must be relative to the beginning of allocation. +- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. +- `offset` and `size` don't have to be aligned. + They are internally rounded down/up to multiply of `nonCoherentAtomSize`. +- If `size` is 0, this call is ignored. +- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, + this call is ignored. -\subsection usage_patterns_dynamic_resources Dynamic resources +Warning! `offset` and `size` are relative to the contents of given `allocation`. +If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. +Do not pass allocation's offset as `offset`!!! -When: -Any resources that change frequently (aka "dynamic"), e.g. every frame or every draw call, -written on CPU, read on GPU. +This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if +it is called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize offset, + VkDeviceSize size); -What to do: -Create them using #VMA_MEMORY_USAGE_CPU_TO_GPU. -You can map it and write to it directly on CPU, as well as read from it on GPU. +/** \brief Flushes memory of given set of allocations. -This is a more complex situation. Different solutions are possible, -and the best one depends on specific GPU type, but you can use this simple approach for the start. -Prefer to write to such resource sequentially (e.g. using `memcpy`). -Don't perform random access or any reads from it on CPU, as it may be very slow. -Also note that textures written directly from the host through a mapped pointer need to be in LINEAR not OPTIMAL layout. +Calls `vkFlushMappedMemoryRanges()` for memory associated with given ranges of given allocations. +For more information, see documentation of vmaFlushAllocation(). -\subsection usage_patterns_readback Readback +\param allocator +\param allocationCount +\param allocations +\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. +\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. -When: -Resources that contain data written by GPU that you want to read back on CPU, -e.g. results of some computations. +This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t allocationCount, + const VmaAllocation VMA_NOT_NULL* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); -What to do: -Create them using #VMA_MEMORY_USAGE_GPU_TO_CPU. -You can write to them directly on GPU, as well as map and read them on CPU. +/** \brief Invalidates memory of given set of allocations. -\section usage_patterns_advanced Advanced patterns +Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given ranges of given allocations. +For more information, see documentation of vmaInvalidateAllocation(). -\subsection usage_patterns_integrated_graphics Detecting integrated graphics +\param allocator +\param allocationCount +\param allocations +\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. +\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. -You can support integrated graphics (like Intel HD Graphics, AMD APU) better -by detecting it in Vulkan. -To do it, call `vkGetPhysicalDeviceProperties()`, inspect -`VkPhysicalDeviceProperties::deviceType` and look for `VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU`. -When you find it, you can assume that memory is unified and all memory types are comparably fast -to access from GPU, regardless of `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. +This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t allocationCount, + const VmaAllocation VMA_NOT_NULL* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); -You can then sum up sizes of all available memory heaps and treat them as useful for -your GPU resources, instead of only `DEVICE_LOCAL` ones. -You can also prefer to create your resources in memory types that are `HOST_VISIBLE` to map them -directly instead of submitting explicit transfer (see below). +/** \brief Checks magic number in margins around all allocations in given memory types (in both default and custom pools) in search for corruptions. -\subsection usage_patterns_direct_vs_transfer Direct access versus transfer +\param allocator +\param memoryTypeBits Bit mask, where each bit set means that a memory type with that index should be checked. -For resources that you frequently write on CPU and read on GPU, many solutions are possible: +Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, +`VMA_DEBUG_MARGIN` is defined to nonzero and only for memory types that are +`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). --# Create one copy in video memory using #VMA_MEMORY_USAGE_GPU_ONLY, - second copy in system memory using #VMA_MEMORY_USAGE_CPU_ONLY and submit explicit transfer each time. --# Create just a single copy using #VMA_MEMORY_USAGE_CPU_TO_GPU, map it and fill it on CPU, - read it directly on GPU. --# Create just a single copy using #VMA_MEMORY_USAGE_CPU_ONLY, map it and fill it on CPU, - read it directly on GPU. +Possible return values: -Which solution is the most efficient depends on your resource and especially on the GPU. -It is best to measure it and then make the decision. -Some general recommendations: +- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for any of specified memory types. +- `VK_SUCCESS` - corruption detection has been performed and succeeded. +- `VK_ERROR_UNKNOWN` - corruption detection has been performed and found memory corruptions around one of the allocations. + `VMA_ASSERT` is also fired in that case. +- Other value: Error returned by Vulkan, e.g. memory mapping failure. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryTypeBits); -- On integrated graphics use (2) or (3) to avoid unnecesary time and memory overhead - related to using a second copy and making transfer. -- For small resources (e.g. constant buffers) use (2). - Discrete AMD cards have special 256 MiB pool of video memory that is directly mappable. - Even if the resource ends up in system memory, its data may be cached on GPU after first - fetch over PCIe bus. -- For larger resources (e.g. textures), decide between (1) and (2). - You may want to differentiate NVIDIA and AMD, e.g. by looking for memory type that is - both `DEVICE_LOCAL` and `HOST_VISIBLE`. When you find it, use (2), otherwise use (1). +/** \brief Begins defragmentation process. -Similarly, for resources that you frequently write on GPU and read on CPU, multiple -solutions are possible: +\param allocator Allocator object. +\param pInfo Structure filled with parameters of defragmentation. +\param[out] pContext Context object that must be passed to vmaEndDefragmentation() to finish defragmentation. +\returns +- `VK_SUCCESS` if defragmentation can begin. +- `VK_ERROR_FEATURE_NOT_PRESENT` if defragmentation is not supported. --# Create one copy in video memory using #VMA_MEMORY_USAGE_GPU_ONLY, - second copy in system memory using #VMA_MEMORY_USAGE_GPU_TO_CPU and submit explicit tranfer each time. --# Create just single copy using #VMA_MEMORY_USAGE_GPU_TO_CPU, write to it directly on GPU, - map it and read it on CPU. +For more information about defragmentation, see documentation chapter: +[Defragmentation](@ref defragmentation). +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentation( + VmaAllocator VMA_NOT_NULL allocator, + const VmaDefragmentationInfo* VMA_NOT_NULL pInfo, + VmaDefragmentationContext VMA_NULLABLE* VMA_NOT_NULL pContext); -You should take some measurements to decide which option is faster in case of your specific -resource. +/** \brief Ends defragmentation process. -Note that textures accessed directly from the host through a mapped pointer need to be in LINEAR layout, -which may slow down their usage on the device. -Textures accessed only by the device and transfer operations can use OPTIMAL layout. +\param allocator Allocator object. +\param context Context object that has been created by vmaBeginDefragmentation(). +\param[out] pStats Optional stats for the defragmentation. Can be null. -If you don't want to specialize your code for specific types of GPUs, you can still make -an simple optimization for cases when your resource ends up in mappable memory to use it -directly in this case instead of creating CPU-side staging copy. -For details see [Finding out if memory is mappable](@ref memory_mapping_finding_if_memory_mappable). +Use this function to finish defragmentation started by vmaBeginDefragmentation(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaEndDefragmentation( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NOT_NULL context, + VmaDefragmentationStats* VMA_NULLABLE pStats); + +/** \brief Starts single defragmentation pass. + +\param allocator Allocator object. +\param context Context object that has been created by vmaBeginDefragmentation(). +\param[out] pPassInfo Computed informations for current pass. +\returns +- `VK_SUCCESS` if no more moves are possible. Then you can omit call to vmaEndDefragmentationPass() and simply end whole defragmentation. +- `VK_INCOMPLETE` if there are pending moves returned in `pPassInfo`. You need to perform them, call vmaEndDefragmentationPass(), + and then preferably try another pass with vmaBeginDefragmentationPass(). +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NOT_NULL context, + VmaDefragmentationPassMoveInfo* VMA_NOT_NULL pPassInfo); +/** \brief Ends single defragmentation pass. -\page configuration Configuration +\param allocator Allocator object. +\param context Context object that has been created by vmaBeginDefragmentation(). +\param pPassInfo Computed informations for current pass filled by vmaBeginDefragmentationPass() and possibly modified by you. -Please check "CONFIGURATION SECTION" in the code to find macros that you can define -before each include of this file or change directly in this file to provide -your own implementation of basic facilities like assert, `min()` and `max()` functions, -mutex, atomic etc. -The library uses its own implementation of containers by default, but you can switch to using -STL containers instead. +Returns `VK_SUCCESS` if no more moves are possible or `VK_INCOMPLETE` if more defragmentations are possible. -For example, define `VMA_ASSERT(expr)` before including the library to provide -custom implementation of the assertion, compatible with your project. -By default it is defined to standard C `assert(expr)` in `_DEBUG` configuration -and empty otherwise. +Ends incremental defragmentation pass and commits all defragmentation moves from `pPassInfo`. +After this call: -\section config_Vulkan_functions Pointers to Vulkan functions +- Allocations at `pPassInfo[i].srcAllocation` that had `pPassInfo[i].operation ==` #VMA_DEFRAGMENTATION_MOVE_OPERATION_COPY + (which is the default) will be pointing to the new destination place. +- Allocation at `pPassInfo[i].srcAllocation` that had `pPassInfo[i].operation ==` #VMA_DEFRAGMENTATION_MOVE_OPERATION_DESTROY + will be freed. -There are multiple ways to import pointers to Vulkan functions in the library. -In the simplest case you don't need to do anything. -If the compilation or linking of your program or the initialization of the #VmaAllocator -doesn't work for you, you can try to reconfigure it. +If no more moves are possible you can end whole defragmentation. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NOT_NULL context, + VmaDefragmentationPassMoveInfo* VMA_NOT_NULL pPassInfo); -First, the allocator tries to fetch pointers to Vulkan functions linked statically, -like this: +/** \brief Binds buffer to allocation. -\code -m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; -\endcode +Binds specified buffer to region of memory represented by specified allocation. +Gets `VkDeviceMemory` handle and offset from the allocation. +If you want to create a buffer, allocate memory for it and bind them together separately, +you should use this function for binding instead of standard `vkBindBufferMemory()`, +because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple +allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously +(which is illegal in Vulkan). -If you want to disable this feature, set configuration macro: `#define VMA_STATIC_VULKAN_FUNCTIONS 0`. +It is recommended to use function vmaCreateBuffer() instead of this one. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer); -Second, you can provide the pointers yourself by setting member VmaAllocatorCreateInfo::pVulkanFunctions. -You can fetch them e.g. using functions `vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` or -by using a helper library like [volk](https://github.com/zeux/volk). +/** \brief Binds buffer to allocation with additional parameters. -Third, VMA tries to fetch remaining pointers that are still null by calling -`vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` on its own. -If you want to disable this feature, set configuration macro: `#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0`. +\param allocator +\param allocation +\param allocationLocalOffset Additional offset to be added while binding, relative to the beginning of the `allocation`. Normally it should be 0. +\param buffer +\param pNext A chain of structures to be attached to `VkBindBufferMemoryInfoKHR` structure used internally. Normally it should be null. -Finally, all the function pointers required by the library (considering selected -Vulkan version and enabled extensions) are checked with `VMA_ASSERT` if they are not null. +This function is similar to vmaBindBufferMemory(), but it provides additional parameters. +If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag +or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize allocationLocalOffset, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, + const void* VMA_NULLABLE pNext); -\section custom_memory_allocator Custom host memory allocator +/** \brief Binds image to allocation. -If you use custom allocator for CPU memory rather than default operator `new` -and `delete` from C++, you can make this library using your allocator as well -by filling optional member VmaAllocatorCreateInfo::pAllocationCallbacks. These -functions will be passed to Vulkan, as well as used by the library itself to -make any CPU-side allocations. +Binds specified image to region of memory represented by specified allocation. +Gets `VkDeviceMemory` handle and offset from the allocation. +If you want to create an image, allocate memory for it and bind them together separately, +you should use this function for binding instead of standard `vkBindImageMemory()`, +because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple +allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously +(which is illegal in Vulkan). -\section allocation_callbacks Device memory allocation callbacks +It is recommended to use function vmaCreateImage() instead of this one. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image); -The library makes calls to `vkAllocateMemory()` and `vkFreeMemory()` internally. -You can setup callbacks to be informed about these calls, e.g. for the purpose -of gathering some statistics. To do it, fill optional member -VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. +/** \brief Binds image to allocation with additional parameters. -\section heap_memory_limit Device heap memory limit +\param allocator +\param allocation +\param allocationLocalOffset Additional offset to be added while binding, relative to the beginning of the `allocation`. Normally it should be 0. +\param image +\param pNext A chain of structures to be attached to `VkBindImageMemoryInfoKHR` structure used internally. Normally it should be null. -When device memory of certain heap runs out of free space, new allocations may -fail (returning error code) or they may succeed, silently pushing some existing -memory blocks from GPU VRAM to system RAM (which degrades performance). This -behavior is implementation-dependent - it depends on GPU vendor and graphics -driver. +This function is similar to vmaBindImageMemory(), but it provides additional parameters. -On AMD cards it can be controlled while creating Vulkan device object by using -VK_AMD_memory_overallocation_behavior extension, if available. +If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag +or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize allocationLocalOffset, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, + const void* VMA_NULLABLE pNext); -Alternatively, if you want to test how your program behaves with limited amount of Vulkan device -memory available without switching your graphics card to one that really has -smaller VRAM, you can use a feature of this library intended for this purpose. -To do it, fill optional member VmaAllocatorCreateInfo::pHeapSizeLimit. +/** \brief Creates a new `VkBuffer`, allocates and binds memory for it. +\param allocator +\param pBufferCreateInfo +\param pAllocationCreateInfo +\param[out] pBuffer Buffer that was created. +\param[out] pAllocation Allocation that was created. +\param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). +This function automatically: -\page vk_khr_dedicated_allocation VK_KHR_dedicated_allocation +-# Creates buffer. +-# Allocates appropriate memory for it. +-# Binds the buffer with the memory. -VK_KHR_dedicated_allocation is a Vulkan extension which can be used to improve -performance on some GPUs. It augments Vulkan API with possibility to query -driver whether it prefers particular buffer or image to have its own, dedicated -allocation (separate `VkDeviceMemory` block) for better efficiency - to be able -to do some internal optimizations. +If any of these operations fail, buffer and allocation are not created, +returned value is negative error code, `*pBuffer` and `*pAllocation` are null. -The extension is supported by this library. It will be used automatically when -enabled. To enable it: +If the function succeeded, you must destroy both buffer and allocation when you +no longer need them using either convenience function vmaDestroyBuffer() or +separately, using `vkDestroyBuffer()` and vmaFreeMemory(). -1 . When creating Vulkan device, check if following 2 device extensions are -supported (call `vkEnumerateDeviceExtensionProperties()`). -If yes, enable them (fill `VkDeviceCreateInfo::ppEnabledExtensionNames`). +If #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag was used, +VK_KHR_dedicated_allocation extension is used internally to query driver whether +it requires or prefers the new buffer to have dedicated allocation. If yes, +and if dedicated allocation is possible +(#VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT is not used), it creates dedicated +allocation for this buffer, just like when using +#VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. -- VK_KHR_get_memory_requirements2 -- VK_KHR_dedicated_allocation +\note This function creates a new `VkBuffer`. Sub-allocation of parts of one large buffer, +although recommended as a good practice, is out of scope of this library and could be implemented +by the user as a higher-level logic on top of VMA. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( + VmaAllocator VMA_NOT_NULL allocator, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pBuffer, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -If you enabled these extensions: +/** \brief Creates a buffer with additional minimum alignment. -2 . Use #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag when creating -your #VmaAllocator`to inform the library that you enabled required extensions -and you want the library to use them. +Similar to vmaCreateBuffer() but provides additional parameter `minAlignment` which allows to specify custom, +minimum alignment to be used when placing the buffer inside a larger memory block, which may be needed e.g. +for interop with OpenGL. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBufferWithAlignment( + VmaAllocator VMA_NOT_NULL allocator, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + VkDeviceSize minAlignment, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pBuffer, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -\code -allocatorInfo.flags |= VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT; +/** \brief Creates a new `VkBuffer`, binds already created memory for it. -vmaCreateAllocator(&allocatorInfo, &allocator); -\endcode +\param allocator +\param allocation Allocation that provides memory to be used for binding new buffer to it. +\param pBufferCreateInfo +\param[out] pBuffer Buffer that was created. -That's all. The extension will be automatically used whenever you create a -buffer using vmaCreateBuffer() or image using vmaCreateImage(). +This function automatically: -When using the extension together with Vulkan Validation Layer, you will receive -warnings like this: +-# Creates buffer. +-# Binds the buffer with the supplied memory. - vkBindBufferMemory(): Binding memory to buffer 0x33 but vkGetBufferMemoryRequirements() has not been called on that buffer. +If any of these operations fail, buffer is not created, +returned value is negative error code and `*pBuffer` is null. -It is OK, you should just ignore it. It happens because you use function -`vkGetBufferMemoryRequirements2KHR()` instead of standard -`vkGetBufferMemoryRequirements()`, while the validation layer seems to be -unaware of it. +If the function succeeded, you must destroy the buffer when you +no longer need it using `vkDestroyBuffer()`. If you want to also destroy the corresponding +allocation you can use convenience function vmaDestroyBuffer(). +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAliasingBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pBuffer); -To learn more about this extension, see: +/** \brief Destroys Vulkan buffer and frees allocated memory. -- [VK_KHR_dedicated_allocation in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap44.html#VK_KHR_dedicated_allocation) -- [VK_KHR_dedicated_allocation unofficial manual](http://asawicki.info/articles/VK_KHR_dedicated_allocation.php5) +This is just a convenience function equivalent to: +\code +vkDestroyBuffer(device, buffer, allocationCallbacks); +vmaFreeMemory(allocator, allocation); +\endcode +It it safe to pass null as buffer and/or allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE buffer, + VmaAllocation VMA_NULLABLE allocation); -\page vk_amd_device_coherent_memory VK_AMD_device_coherent_memory +/// Function similar to vmaCreateBuffer(). +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( + VmaAllocator VMA_NOT_NULL allocator, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + VkImage VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pImage, + VmaAllocation VMA_NULLABLE* VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); -VK_AMD_device_coherent_memory is a device extension that enables access to -additional memory types with `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and -`VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flag. It is useful mostly for -allocation of buffers intended for writing "breadcrumb markers" in between passes -or draw calls, which in turn are useful for debugging GPU crash/hang/TDR cases. +/// Function similar to vmaCreateAliasingBuffer(). +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAliasingImage( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + VkImage VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pImage); -When the extension is available but has not been enabled, Vulkan physical device -still exposes those memory types, but their usage is forbidden. VMA automatically -takes care of that - it returns `VK_ERROR_FEATURE_NOT_PRESENT` when an attempt -to allocate memory of such type is made. +/** \brief Destroys Vulkan image and frees allocated memory. -If you want to use this extension in connection with VMA, follow these steps: +This is just a convenience function equivalent to: -\section vk_amd_device_coherent_memory_initialization Initialization +\code +vkDestroyImage(device, image, allocationCallbacks); +vmaFreeMemory(allocator, allocation); +\endcode -1) Call `vkEnumerateDeviceExtensionProperties` for the physical device. -Check if the extension is supported - if returned array of `VkExtensionProperties` contains "VK_AMD_device_coherent_memory". +It it safe to pass null as image and/or allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( + VmaAllocator VMA_NOT_NULL allocator, + VkImage VMA_NULLABLE_NON_DISPATCHABLE image, + VmaAllocation VMA_NULLABLE allocation); -2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. -Attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to `VkPhysicalDeviceFeatures2::pNext` to be returned. -Check if the device feature is really supported - check if `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true. +/** @} */ -3) While creating device with `vkCreateDevice`, enable this extension - add "VK_AMD_device_coherent_memory" -to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. +/** +\addtogroup group_virtual +@{ +*/ -4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. -Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. -Enable this device feature - attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to -`VkPhysicalDeviceFeatures2::pNext` and set its member `deviceCoherentMemory` to `VK_TRUE`. +/** \brief Creates new #VmaVirtualBlock object. -5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you -have enabled this extension and feature - add #VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT -to VmaAllocatorCreateInfo::flags. +\param pCreateInfo Parameters for creation. +\param[out] pVirtualBlock Returned virtual block object or `VMA_NULL` if creation failed. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateVirtualBlock( + const VmaVirtualBlockCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaVirtualBlock VMA_NULLABLE* VMA_NOT_NULL pVirtualBlock); -\section vk_amd_device_coherent_memory_usage Usage +/** \brief Destroys #VmaVirtualBlock object. -After following steps described above, you can create VMA allocations and custom pools -out of the special `DEVICE_COHERENT` and `DEVICE_UNCACHED` memory types on eligible -devices. There are multiple ways to do it, for example: +Please note that you should consciously handle virtual allocations that could remain unfreed in the block. +You should either free them individually using vmaVirtualFree() or call vmaClearVirtualBlock() +if you are sure this is what you want. If you do neither, an assert is called. -- You can request or prefer to allocate out of such memory types by adding - `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` to VmaAllocationCreateInfo::requiredFlags - or VmaAllocationCreateInfo::preferredFlags. Those flags can be freely mixed with - other ways of \ref choosing_memory_type, like setting VmaAllocationCreateInfo::usage. -- If you manually found memory type index to use for this purpose, force allocation - from this specific index by setting VmaAllocationCreateInfo::memoryTypeBits `= 1u << index`. +If you keep pointers to some additional metadata associated with your virtual allocations in their `pUserData`, +don't forget to free them. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyVirtualBlock( + VmaVirtualBlock VMA_NULLABLE virtualBlock); -\section vk_amd_device_coherent_memory_more_information More information +/** \brief Returns true of the #VmaVirtualBlock is empty - contains 0 virtual allocations and has all its space available for new allocations. +*/ +VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaIsVirtualBlockEmpty( + VmaVirtualBlock VMA_NOT_NULL virtualBlock); -To learn more about this extension, see [VK_AMD_device_coherent_memory in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap44.html#VK_AMD_device_coherent_memory) +/** \brief Returns information about a specific virtual allocation within a virtual block, like its size and `pUserData` pointer. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetVirtualAllocationInfo( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaVirtualAllocation VMA_NOT_NULL_NON_DISPATCHABLE allocation, VmaVirtualAllocationInfo* VMA_NOT_NULL pVirtualAllocInfo); -Example use of this extension can be found in the code of the sample and test suite -accompanying this library. +/** \brief Allocates new virtual allocation inside given #VmaVirtualBlock. +If the allocation fails due to not enough free space available, `VK_ERROR_OUT_OF_DEVICE_MEMORY` is returned +(despite the function doesn't ever allocate actual GPU memory). +`pAllocation` is then set to `VK_NULL_HANDLE` and `pOffset`, if not null, it set to `UINT64_MAX`. -\page enabling_buffer_device_address Enabling buffer device address +\param virtualBlock Virtual block +\param pCreateInfo Parameters for the allocation +\param[out] pAllocation Returned handle of the new allocation +\param[out] pOffset Returned offset of the new allocation. Optional, can be null. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaVirtualAllocate( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + const VmaVirtualAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaVirtualAllocation VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pAllocation, + VkDeviceSize* VMA_NULLABLE pOffset); -Device extension VK_KHR_buffer_device_address -allow to fetch raw GPU pointer to a buffer and pass it for usage in a shader code. -It is promoted to core Vulkan 1.2. +/** \brief Frees virtual allocation inside given #VmaVirtualBlock. -If you want to use this feature in connection with VMA, follow these steps: +It is correct to call this function with `allocation == VK_NULL_HANDLE` - it does nothing. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaVirtualFree( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaVirtualAllocation VMA_NULLABLE_NON_DISPATCHABLE allocation); -\section enabling_buffer_device_address_initialization Initialization - -1) (For Vulkan version < 1.2) Call `vkEnumerateDeviceExtensionProperties` for the physical device. -Check if the extension is supported - if returned array of `VkExtensionProperties` contains -"VK_KHR_buffer_device_address". +/** \brief Frees all virtual allocations inside given #VmaVirtualBlock. -2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. -Attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to `VkPhysicalDeviceFeatures2::pNext` to be returned. -Check if the device feature is really supported - check if `VkPhysicalDeviceBufferDeviceAddressFeatures*::bufferDeviceAddress` is true. +You must either call this function or free each virtual allocation individually with vmaVirtualFree() +before destroying a virtual block. Otherwise, an assert is called. -3) (For Vulkan version < 1.2) While creating device with `vkCreateDevice`, enable this extension - add -"VK_KHR_buffer_device_address" to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. +If you keep pointer to some additional metadata associated with your virtual allocation in its `pUserData`, +don't forget to free it as well. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaClearVirtualBlock( + VmaVirtualBlock VMA_NOT_NULL virtualBlock); -4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. -Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. -Enable this device feature - attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to -`VkPhysicalDeviceFeatures2::pNext` and set its member `bufferDeviceAddress` to `VK_TRUE`. +/** \brief Changes custom pointer associated with given virtual allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetVirtualAllocationUserData( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaVirtualAllocation VMA_NOT_NULL_NON_DISPATCHABLE allocation, + void* VMA_NULLABLE pUserData); -5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you -have enabled this feature - add #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT -to VmaAllocatorCreateInfo::flags. +/** \brief Calculates and returns statistics about virtual allocations and memory usage in given #VmaVirtualBlock. -\section enabling_buffer_device_address_usage Usage +This function is fast to call. For more detailed statistics, see vmaCalculateVirtualBlockStatistics(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetVirtualBlockStatistics( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaStatistics* VMA_NOT_NULL pStats); -After following steps described above, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*` using VMA. -The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT*` to -allocated memory blocks wherever it might be needed. +/** \brief Calculates and returns detailed statistics about virtual allocations and memory usage in given #VmaVirtualBlock. -Please note that the library supports only `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*`. -The second part of this functionality related to "capture and replay" is not supported, -as it is intended for usage in debugging tools like RenderDoc, not in everyday Vulkan usage. +This function is slow to call. Use for debugging purposes. +For less detailed statistics, see vmaGetVirtualBlockStatistics(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateVirtualBlockStatistics( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaDetailedStatistics* VMA_NOT_NULL pStats); -\section enabling_buffer_device_address_more_information More information +/** @} */ -To learn more about this extension, see [VK_KHR_buffer_device_address in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap46.html#VK_KHR_buffer_device_address) +#if VMA_STATS_STRING_ENABLED +/** +\addtogroup group_stats +@{ +*/ -Example use of this extension can be found in the code of the sample and test suite -accompanying this library. +/** \brief Builds and returns a null-terminated string in JSON format with information about given #VmaVirtualBlock. +\param virtualBlock Virtual block. +\param[out] ppStatsString Returned string. +\param detailedMap Pass `VK_FALSE` to only obtain statistics as returned by vmaCalculateVirtualBlockStatistics(). Pass `VK_TRUE` to also obtain full list of allocations and free spaces. -\page general_considerations General considerations +Returned string must be freed using vmaFreeVirtualBlockStatsString(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaBuildVirtualBlockStatsString( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + char* VMA_NULLABLE* VMA_NOT_NULL ppStatsString, + VkBool32 detailedMap); -\section general_considerations_thread_safety Thread safety +/// Frees a string returned by vmaBuildVirtualBlockStatsString(). +VMA_CALL_PRE void VMA_CALL_POST vmaFreeVirtualBlockStatsString( + VmaVirtualBlock VMA_NOT_NULL virtualBlock, + char* VMA_NULLABLE pStatsString); -- The library has no global state, so separate #VmaAllocator objects can be used - independently. - There should be no need to create multiple such objects though - one per `VkDevice` is enough. -- By default, all calls to functions that take #VmaAllocator as first parameter - are safe to call from multiple threads simultaneously because they are - synchronized internally when needed. -- When the allocator is created with #VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT - flag, calls to functions that take such #VmaAllocator object must be - synchronized externally. -- Access to a #VmaAllocation object must be externally synchronized. For example, - you must not call vmaGetAllocationInfo() and vmaMapMemory() from different - threads at the same time if you pass the same #VmaAllocation object to these - functions. +/** \brief Builds and returns statistics as a null-terminated string in JSON format. +\param allocator +\param[out] ppStatsString Must be freed using vmaFreeStatsString() function. +\param detailedMap +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( + VmaAllocator VMA_NOT_NULL allocator, + char* VMA_NULLABLE* VMA_NOT_NULL ppStatsString, + VkBool32 detailedMap); -\section general_considerations_validation_layer_warnings Validation layer warnings +VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( + VmaAllocator VMA_NOT_NULL allocator, + char* VMA_NULLABLE pStatsString); -When using this library, you can meet following types of warnings issued by -Vulkan validation layer. They don't necessarily indicate a bug, so you may need -to just ignore them. +/** @} */ -- *vkBindBufferMemory(): Binding memory to buffer 0xeb8e4 but vkGetBufferMemoryRequirements() has not been called on that buffer.* - - It happens when VK_KHR_dedicated_allocation extension is enabled. - `vkGetBufferMemoryRequirements2KHR` function is used instead, while validation layer seems to be unaware of it. -- *Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used.* - - It happens when you map a buffer or image, because the library maps entire - `VkDeviceMemory` block, where different types of images and buffers may end - up together, especially on GPUs with unified memory like Intel. -- *Non-linear image 0xebc91 is aliased with linear buffer 0xeb8e4 which may indicate a bug.* - - It happens when you use lost allocations, and a new image or buffer is - created in place of an existing object that bacame lost. - - It may happen also when you use [defragmentation](@ref defragmentation). +#endif // VMA_STATS_STRING_ENABLED -\section general_considerations_allocation_algorithm Allocation algorithm +#endif // _VMA_FUNCTION_HEADERS -The library uses following algorithm for allocation, in order: +#ifdef __cplusplus +} +#endif --# Try to find free range of memory in existing blocks. --# If failed, try to create a new block of `VkDeviceMemory`, with preferred block size. --# If failed, try to create such block with size/2, size/4, size/8. --# If failed and #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag was - specified, try to find space in existing blocks, possilby making some other - allocations lost. --# If failed, try to allocate separate `VkDeviceMemory` for this allocation, - just like when you use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. --# If failed, choose other memory type that meets the requirements specified in - VmaAllocationCreateInfo and go to point 1. --# If failed, return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +#endif // AMD_VULKAN_MEMORY_ALLOCATOR_H -\section general_considerations_features_not_supported Features not supported +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION +// +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// -Features deliberately excluded from the scope of this library: +// For Visual Studio IntelliSense. +#if defined(__cplusplus) && defined(__INTELLISENSE__) +#define VMA_IMPLEMENTATION +#endif -- Data transfer. Uploading (straming) and downloading data of buffers and images - between CPU and GPU memory and related synchronization is responsibility of the user. - Defining some "texture" object that would automatically stream its data from a - staging copy in CPU memory to GPU memory would rather be a feature of another, - higher-level library implemented on top of VMA. -- Allocations for imported/exported external memory. They tend to require - explicit memory type index and dedicated allocation anyway, so they don't - interact with main features of this library. Such special purpose allocations - should be made manually, using `vkCreateBuffer()` and `vkAllocateMemory()`. -- Sub-allocation of parts of one large buffer. Although recommended as a good practice, - it is the user's responsibility to implement such logic on top of VMA. -- Recreation of buffers and images. Although the library has functions for - buffer and image creation (vmaCreateBuffer(), vmaCreateImage()), you need to - recreate these objects yourself after defragmentation. That's because the big - structures `VkBufferCreateInfo`, `VkImageCreateInfo` are not stored in - #VmaAllocation object. -- Handling CPU memory allocation failures. When dynamically creating small C++ - objects in CPU memory (not Vulkan memory), allocation failures are not checked - and handled gracefully, because that would complicate code significantly and - is usually not needed in desktop PC applications anyway. - Success of an allocation is just checked with an assert. -- Code free of any compiler warnings. Maintaining the library to compile and - work correctly on so many different platforms is hard enough. Being free of - any warnings, on any version of any compiler, is simply not feasible. -- This is a C++ library with C interface. - Bindings or ports to any other programming languages are welcomed as external projects and - are not going to be included into this repository. +#ifdef VMA_IMPLEMENTATION +#undef VMA_IMPLEMENTATION -*/ +#include +#include +#include +#include +#include -#ifdef __cplusplus -extern "C" { +#ifdef _MSC_VER + #include // For functions like __popcnt, _BitScanForward etc. +#endif +#if __cplusplus >= 202002L || _MSVC_LANG >= 202002L // C++20 + #include // For std::popcount #endif -/* -Define this macro to 0/1 to disable/enable support for recording functionality, -available through VmaAllocatorCreateInfo::pRecordSettings. +/******************************************************************************* +CONFIGURATION SECTION + +Define some of these macros before each #include of this header or change them +here if you need other then default behavior depending on your environment. */ -#ifndef VMA_RECORDING_ENABLED - #define VMA_RECORDING_ENABLED 0 -#endif +#ifndef _VMA_CONFIGURATION + +/* +Define this macro to 1 to make the library fetch pointers to Vulkan functions +internally, like: -#if !defined(NOMINMAX) && defined(VMA_IMPLEMENTATION) - #define NOMINMAX // For windows.h + vulkanFunctions.vkAllocateMemory = &vkAllocateMemory; +*/ +#if !defined(VMA_STATIC_VULKAN_FUNCTIONS) && !defined(VK_NO_PROTOTYPES) + #define VMA_STATIC_VULKAN_FUNCTIONS 1 #endif -#if defined(__ANDROID__) && defined(VK_NO_PROTOTYPES) && VMA_STATIC_VULKAN_FUNCTIONS - extern PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr; - extern PFN_vkGetDeviceProcAddr vkGetDeviceProcAddr; - extern PFN_vkGetPhysicalDeviceProperties vkGetPhysicalDeviceProperties; - extern PFN_vkGetPhysicalDeviceMemoryProperties vkGetPhysicalDeviceMemoryProperties; - extern PFN_vkAllocateMemory vkAllocateMemory; - extern PFN_vkFreeMemory vkFreeMemory; - extern PFN_vkMapMemory vkMapMemory; - extern PFN_vkUnmapMemory vkUnmapMemory; - extern PFN_vkFlushMappedMemoryRanges vkFlushMappedMemoryRanges; - extern PFN_vkInvalidateMappedMemoryRanges vkInvalidateMappedMemoryRanges; - extern PFN_vkBindBufferMemory vkBindBufferMemory; - extern PFN_vkBindImageMemory vkBindImageMemory; - extern PFN_vkGetBufferMemoryRequirements vkGetBufferMemoryRequirements; - extern PFN_vkGetImageMemoryRequirements vkGetImageMemoryRequirements; - extern PFN_vkCreateBuffer vkCreateBuffer; - extern PFN_vkDestroyBuffer vkDestroyBuffer; - extern PFN_vkCreateImage vkCreateImage; - extern PFN_vkDestroyImage vkDestroyImage; - extern PFN_vkCmdCopyBuffer vkCmdCopyBuffer; - #if VMA_VULKAN_VERSION >= 1001000 - extern PFN_vkGetBufferMemoryRequirements2 vkGetBufferMemoryRequirements2; - extern PFN_vkGetImageMemoryRequirements2 vkGetImageMemoryRequirements2; - extern PFN_vkBindBufferMemory2 vkBindBufferMemory2; - extern PFN_vkBindImageMemory2 vkBindImageMemory2; - extern PFN_vkGetPhysicalDeviceMemoryProperties2 vkGetPhysicalDeviceMemoryProperties2; - #endif // #if VMA_VULKAN_VERSION >= 1001000 -#endif // #if defined(__ANDROID__) && VMA_STATIC_VULKAN_FUNCTIONS && VK_NO_PROTOTYPES +/* +Define this macro to 1 to make the library fetch pointers to Vulkan functions +internally, like: -#ifndef VULKAN_H_ - #include -#endif + vulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkGetDeviceProcAddr(device, "vkAllocateMemory"); -// Define this macro to declare maximum supported Vulkan version in format AAABBBCCC, -// where AAA = major, BBB = minor, CCC = patch. -// If you want to use version > 1.0, it still needs to be enabled via VmaAllocatorCreateInfo::vulkanApiVersion. -#if !defined(VMA_VULKAN_VERSION) - #if defined(VK_VERSION_1_2) - #define VMA_VULKAN_VERSION 1002000 - #elif defined(VK_VERSION_1_1) - #define VMA_VULKAN_VERSION 1001000 - #else - #define VMA_VULKAN_VERSION 1000000 - #endif +To use this feature in new versions of VMA you now have to pass +VmaVulkanFunctions::vkGetInstanceProcAddr and vkGetDeviceProcAddr as +VmaAllocatorCreateInfo::pVulkanFunctions. Other members can be null. +*/ +#if !defined(VMA_DYNAMIC_VULKAN_FUNCTIONS) + #define VMA_DYNAMIC_VULKAN_FUNCTIONS 1 #endif -#if !defined(VMA_DEDICATED_ALLOCATION) - #if VK_KHR_get_memory_requirements2 && VK_KHR_dedicated_allocation - #define VMA_DEDICATED_ALLOCATION 1 +#ifndef VMA_USE_STL_SHARED_MUTEX + // Compiler conforms to C++17. + #if __cplusplus >= 201703L + #define VMA_USE_STL_SHARED_MUTEX 1 + // Visual studio defines __cplusplus properly only when passed additional parameter: /Zc:__cplusplus + // Otherwise it is always 199711L, despite shared_mutex works since Visual Studio 2015 Update 2. + #elif defined(_MSC_FULL_VER) && _MSC_FULL_VER >= 190023918 && __cplusplus == 199711L && _MSVC_LANG >= 201703L + #define VMA_USE_STL_SHARED_MUTEX 1 #else - #define VMA_DEDICATED_ALLOCATION 0 + #define VMA_USE_STL_SHARED_MUTEX 0 #endif #endif -#if !defined(VMA_BIND_MEMORY2) - #if VK_KHR_bind_memory2 - #define VMA_BIND_MEMORY2 1 - #else - #define VMA_BIND_MEMORY2 0 - #endif -#endif +/* +Define this macro to include custom header files without having to edit this file directly, e.g.: -#if !defined(VMA_MEMORY_BUDGET) - #if VK_EXT_memory_budget && (VK_KHR_get_physical_device_properties2 || VMA_VULKAN_VERSION >= 1001000) - #define VMA_MEMORY_BUDGET 1 - #else - #define VMA_MEMORY_BUDGET 0 - #endif -#endif + // Inside of "my_vma_configuration_user_includes.h": -// Defined to 1 when VK_KHR_buffer_device_address device extension or equivalent core Vulkan 1.2 feature is defined in its headers. -#if !defined(VMA_BUFFER_DEVICE_ADDRESS) - #if VK_KHR_buffer_device_address || VMA_VULKAN_VERSION >= 1002000 - #define VMA_BUFFER_DEVICE_ADDRESS 1 - #else - #define VMA_BUFFER_DEVICE_ADDRESS 0 - #endif -#endif + #include "my_custom_assert.h" // for MY_CUSTOM_ASSERT + #include "my_custom_min.h" // for my_custom_min + #include + #include -// Defined to 1 when VK_EXT_memory_priority device extension is defined in Vulkan headers. -#if !defined(VMA_MEMORY_PRIORITY) - #if VK_EXT_memory_priority - #define VMA_MEMORY_PRIORITY 1 - #else - #define VMA_MEMORY_PRIORITY 0 - #endif -#endif + // Inside a different file, which includes "vk_mem_alloc.h": -// Define these macros to decorate all public functions with additional code, -// before and after returned type, appropriately. This may be useful for -// exporting the functions when compiling VMA as a separate library. Example: -// #define VMA_CALL_PRE __declspec(dllexport) -// #define VMA_CALL_POST __cdecl -#ifndef VMA_CALL_PRE - #define VMA_CALL_PRE -#endif -#ifndef VMA_CALL_POST - #define VMA_CALL_POST -#endif + #define VMA_CONFIGURATION_USER_INCLUDES_H "my_vma_configuration_user_includes.h" + #define VMA_ASSERT(expr) MY_CUSTOM_ASSERT(expr) + #define VMA_MIN(v1, v2) (my_custom_min(v1, v2)) + #include "vk_mem_alloc.h" + ... -// Define this macro to decorate pointers with an attribute specifying the -// length of the array they point to if they are not null. -// -// The length may be one of -// - The name of another parameter in the argument list where the pointer is declared -// - The name of another member in the struct where the pointer is declared -// - The name of a member of a struct type, meaning the value of that member in -// the context of the call. For example -// VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount"), -// this means the number of memory heaps available in the device associated -// with the VmaAllocator being dealt with. -#ifndef VMA_LEN_IF_NOT_NULL - #define VMA_LEN_IF_NOT_NULL(len) +The following headers are used in this CONFIGURATION section only, so feel free to +remove them if not needed. +*/ +#if !defined(VMA_CONFIGURATION_USER_INCLUDES_H) + #include // for assert + #include // for min, max + #include +#else + #include VMA_CONFIGURATION_USER_INCLUDES_H #endif -// The VMA_NULLABLE macro is defined to be _Nullable when compiling with Clang. -// see: https://clang.llvm.org/docs/AttributeReference.html#nullable -#ifndef VMA_NULLABLE - #ifdef __clang__ - #define VMA_NULLABLE _Nullable - #else - #define VMA_NULLABLE - #endif +#ifndef VMA_NULL + // Value used as null pointer. Define it to e.g.: nullptr, NULL, 0, (void*)0. + #define VMA_NULL nullptr #endif -// The VMA_NOT_NULL macro is defined to be _Nonnull when compiling with Clang. -// see: https://clang.llvm.org/docs/AttributeReference.html#nonnull -#ifndef VMA_NOT_NULL - #ifdef __clang__ - #define VMA_NOT_NULL _Nonnull - #else - #define VMA_NOT_NULL - #endif -#endif +#if defined(__ANDROID_API__) && (__ANDROID_API__ < 16) +#include +static void* vma_aligned_alloc(size_t alignment, size_t size) +{ + // alignment must be >= sizeof(void*) + if(alignment < sizeof(void*)) + { + alignment = sizeof(void*); + } -// If non-dispatchable handles are represented as pointers then we can give -// then nullability annotations -#ifndef VMA_NOT_NULL_NON_DISPATCHABLE - #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) - #define VMA_NOT_NULL_NON_DISPATCHABLE VMA_NOT_NULL - #else - #define VMA_NOT_NULL_NON_DISPATCHABLE - #endif -#endif + return memalign(alignment, size); +} +#elif defined(__APPLE__) || defined(__ANDROID__) || (defined(__linux__) && defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC)) +#include -#ifndef VMA_NULLABLE_NON_DISPATCHABLE - #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) - #define VMA_NULLABLE_NON_DISPATCHABLE VMA_NULLABLE - #else - #define VMA_NULLABLE_NON_DISPATCHABLE - #endif +#if defined(__APPLE__) +#include #endif -/** \struct VmaAllocator -\brief Represents main object of this library initialized. - -Fill structure #VmaAllocatorCreateInfo and call function vmaCreateAllocator() to create it. -Call function vmaDestroyAllocator() to destroy it. - -It is recommended to create just one object of this type per `VkDevice` object, -right after Vulkan is initialized and keep it alive until before Vulkan device is destroyed. -*/ -VK_DEFINE_HANDLE(VmaAllocator) - -/// Callback function called after successful vkAllocateMemory. -typedef void (VKAPI_PTR *PFN_vmaAllocateDeviceMemoryFunction)( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t memoryType, - VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, - VkDeviceSize size, - void* VMA_NULLABLE pUserData); -/// Callback function called before vkFreeMemory. -typedef void (VKAPI_PTR *PFN_vmaFreeDeviceMemoryFunction)( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t memoryType, - VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, - VkDeviceSize size, - void* VMA_NULLABLE pUserData); - -/** \brief Set of callbacks that the library will call for `vkAllocateMemory` and `vkFreeMemory`. - -Provided for informative purpose, e.g. to gather statistics about number of -allocations or total amount of memory allocated in Vulkan. - -Used in VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. -*/ -typedef struct VmaDeviceMemoryCallbacks { - /// Optional, can be null. - PFN_vmaAllocateDeviceMemoryFunction VMA_NULLABLE pfnAllocate; - /// Optional, can be null. - PFN_vmaFreeDeviceMemoryFunction VMA_NULLABLE pfnFree; - /// Optional, can be null. - void* VMA_NULLABLE pUserData; -} VmaDeviceMemoryCallbacks; +static void* vma_aligned_alloc(size_t alignment, size_t size) +{ + // Unfortunately, aligned_alloc causes VMA to crash due to it returning null pointers. (At least under 11.4) + // Therefore, for now disable this specific exception until a proper solution is found. + //#if defined(__APPLE__) && (defined(MAC_OS_X_VERSION_10_16) || defined(__IPHONE_14_0)) + //#if MAC_OS_X_VERSION_MAX_ALLOWED >= MAC_OS_X_VERSION_10_16 || __IPHONE_OS_VERSION_MAX_ALLOWED >= __IPHONE_14_0 + // // For C++14, usr/include/malloc/_malloc.h declares aligned_alloc()) only + // // with the MacOSX11.0 SDK in Xcode 12 (which is what adds + // // MAC_OS_X_VERSION_10_16), even though the function is marked + // // availabe for 10.15. That is why the preprocessor checks for 10.16 but + // // the __builtin_available checks for 10.15. + // // People who use C++17 could call aligned_alloc with the 10.15 SDK already. + // if (__builtin_available(macOS 10.15, iOS 13, *)) + // return aligned_alloc(alignment, size); + //#endif + //#endif -/// Flags for created #VmaAllocator. -typedef enum VmaAllocatorCreateFlagBits { - /** \brief Allocator and all objects created from it will not be synchronized internally, so you must guarantee they are used from only one thread at a time or synchronized externally by you. + // alignment must be >= sizeof(void*) + if(alignment < sizeof(void*)) + { + alignment = sizeof(void*); + } - Using this flag may increase performance because internal mutexes are not used. - */ - VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT = 0x00000001, - /** \brief Enables usage of VK_KHR_dedicated_allocation extension. + void *pointer; + if(posix_memalign(&pointer, alignment, size) == 0) + return pointer; + return VMA_NULL; +} +#elif defined(_WIN32) +static void* vma_aligned_alloc(size_t alignment, size_t size) +{ + return _aligned_malloc(size, alignment); +} +#else +static void* vma_aligned_alloc(size_t alignment, size_t size) +{ + return aligned_alloc(alignment, size); +} +#endif - The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. - When it's `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. +#if defined(_WIN32) +static void vma_aligned_free(void* ptr) +{ + _aligned_free(ptr); +} +#else +static void vma_aligned_free(void* VMA_NULLABLE ptr) +{ + free(ptr); +} +#endif - Using this extenion will automatically allocate dedicated blocks of memory for - some buffers and images instead of suballocating place for them out of bigger - memory blocks (as if you explicitly used #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT - flag) when it is recommended by the driver. It may improve performance on some - GPUs. +// If your compiler is not compatible with C++11 and definition of +// aligned_alloc() function is missing, uncommeting following line may help: - You may set this flag only if you found out that following device extensions are - supported, you enabled them while creating Vulkan device passed as - VmaAllocatorCreateInfo::device, and you want them to be used internally by this - library: +//#include - - VK_KHR_get_memory_requirements2 (device extension) - - VK_KHR_dedicated_allocation (device extension) +// Normal assert to check for programmer's errors, especially in Debug configuration. +#ifndef VMA_ASSERT + #ifdef NDEBUG + #define VMA_ASSERT(expr) + #else + #define VMA_ASSERT(expr) assert(expr) + #endif +#endif - When this flag is set, you can experience following warnings reported by Vulkan - validation layer. You can ignore them. +// Assert that will be called very often, like inside data structures e.g. operator[]. +// Making it non-empty can make program slow. +#ifndef VMA_HEAVY_ASSERT + #ifdef NDEBUG + #define VMA_HEAVY_ASSERT(expr) + #else + #define VMA_HEAVY_ASSERT(expr) //VMA_ASSERT(expr) + #endif +#endif - > vkBindBufferMemory(): Binding memory to buffer 0x2d but vkGetBufferMemoryRequirements() has not been called on that buffer. - */ - VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT = 0x00000002, - /** - Enables usage of VK_KHR_bind_memory2 extension. +#ifndef VMA_ALIGN_OF + #define VMA_ALIGN_OF(type) (__alignof(type)) +#endif - The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. - When it's `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. +#ifndef VMA_SYSTEM_ALIGNED_MALLOC + #define VMA_SYSTEM_ALIGNED_MALLOC(size, alignment) vma_aligned_alloc((alignment), (size)) +#endif - You may set this flag only if you found out that this device extension is supported, - you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, - and you want it to be used internally by this library. +#ifndef VMA_SYSTEM_ALIGNED_FREE + // VMA_SYSTEM_FREE is the old name, but might have been defined by the user + #if defined(VMA_SYSTEM_FREE) + #define VMA_SYSTEM_ALIGNED_FREE(ptr) VMA_SYSTEM_FREE(ptr) + #else + #define VMA_SYSTEM_ALIGNED_FREE(ptr) vma_aligned_free(ptr) + #endif +#endif - The extension provides functions `vkBindBufferMemory2KHR` and `vkBindImageMemory2KHR`, - which allow to pass a chain of `pNext` structures while binding. - This flag is required if you use `pNext` parameter in vmaBindBufferMemory2() or vmaBindImageMemory2(). - */ - VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT = 0x00000004, - /** - Enables usage of VK_EXT_memory_budget extension. +#ifndef VMA_COUNT_BITS_SET + // Returns number of bits set to 1 in (v) + #define VMA_COUNT_BITS_SET(v) VmaCountBitsSet(v) +#endif - You may set this flag only if you found out that this device extension is supported, - you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, - and you want it to be used internally by this library, along with another instance extension - VK_KHR_get_physical_device_properties2, which is required by it (or Vulkan 1.1, where this extension is promoted). +#ifndef VMA_BITSCAN_LSB + // Scans integer for index of first nonzero value from the Least Significant Bit (LSB). If mask is 0 then returns UINT8_MAX + #define VMA_BITSCAN_LSB(mask) VmaBitScanLSB(mask) +#endif - The extension provides query for current memory usage and budget, which will probably - be more accurate than an estimation used by the library otherwise. - */ - VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT = 0x00000008, - /** - Enables usage of VK_AMD_device_coherent_memory extension. +#ifndef VMA_BITSCAN_MSB + // Scans integer for index of first nonzero value from the Most Significant Bit (MSB). If mask is 0 then returns UINT8_MAX + #define VMA_BITSCAN_MSB(mask) VmaBitScanMSB(mask) +#endif - You may set this flag only if you: +#ifndef VMA_MIN + #define VMA_MIN(v1, v2) ((std::min)((v1), (v2))) +#endif - - found out that this device extension is supported and enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, - - checked that `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true and set it while creating the Vulkan device, - - want it to be used internally by this library. +#ifndef VMA_MAX + #define VMA_MAX(v1, v2) ((std::max)((v1), (v2))) +#endif - The extension and accompanying device feature provide access to memory types with - `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and `VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flags. - They are useful mostly for writing breadcrumb markers - a common method for debugging GPU crash/hang/TDR. +#ifndef VMA_SWAP + #define VMA_SWAP(v1, v2) std::swap((v1), (v2)) +#endif - When the extension is not enabled, such memory types are still enumerated, but their usage is illegal. - To protect from this error, if you don't create the allocator with this flag, it will refuse to allocate any memory or create a custom pool in such memory type, - returning `VK_ERROR_FEATURE_NOT_PRESENT`. - */ - VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT = 0x00000010, - /** - Enables usage of "buffer device address" feature, which allows you to use function - `vkGetBufferDeviceAddress*` to get raw GPU pointer to a buffer and pass it for usage inside a shader. +#ifndef VMA_SORT + #define VMA_SORT(beg, end, cmp) std::sort(beg, end, cmp) +#endif - You may set this flag only if you: +#ifndef VMA_DEBUG_LOG + #define VMA_DEBUG_LOG(format, ...) + /* + #define VMA_DEBUG_LOG(format, ...) do { \ + printf(format, __VA_ARGS__); \ + printf("\n"); \ + } while(false) + */ +#endif - 1. (For Vulkan version < 1.2) Found as available and enabled device extension - VK_KHR_buffer_device_address. - This extension is promoted to core Vulkan 1.2. - 2. Found as available and enabled device feature `VkPhysicalDeviceBufferDeviceAddressFeatures::bufferDeviceAddress`. +// Define this macro to 1 to enable functions: vmaBuildStatsString, vmaFreeStatsString. +#if VMA_STATS_STRING_ENABLED + static inline void VmaUint32ToStr(char* VMA_NOT_NULL outStr, size_t strLen, uint32_t num) + { + snprintf(outStr, strLen, "%u", static_cast(num)); + } + static inline void VmaUint64ToStr(char* VMA_NOT_NULL outStr, size_t strLen, uint64_t num) + { + snprintf(outStr, strLen, "%llu", static_cast(num)); + } + static inline void VmaPtrToStr(char* VMA_NOT_NULL outStr, size_t strLen, const void* ptr) + { + snprintf(outStr, strLen, "%p", ptr); + } +#endif - When this flag is set, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT` using VMA. - The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT` to - allocated memory blocks wherever it might be needed. +#ifndef VMA_MUTEX + class VmaMutex + { + public: + void Lock() { m_Mutex.lock(); } + void Unlock() { m_Mutex.unlock(); } + bool TryLock() { return m_Mutex.try_lock(); } + private: + std::mutex m_Mutex; + }; + #define VMA_MUTEX VmaMutex +#endif - For more information, see documentation chapter \ref enabling_buffer_device_address. - */ - VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT = 0x00000020, - /** - Enables usage of VK_EXT_memory_priority extension in the library. +// Read-write mutex, where "read" is shared access, "write" is exclusive access. +#ifndef VMA_RW_MUTEX + #if VMA_USE_STL_SHARED_MUTEX + // Use std::shared_mutex from C++17. + #include + class VmaRWMutex + { + public: + void LockRead() { m_Mutex.lock_shared(); } + void UnlockRead() { m_Mutex.unlock_shared(); } + bool TryLockRead() { return m_Mutex.try_lock_shared(); } + void LockWrite() { m_Mutex.lock(); } + void UnlockWrite() { m_Mutex.unlock(); } + bool TryLockWrite() { return m_Mutex.try_lock(); } + private: + std::shared_mutex m_Mutex; + }; + #define VMA_RW_MUTEX VmaRWMutex + #elif defined(_WIN32) && defined(WINVER) && WINVER >= 0x0600 + // Use SRWLOCK from WinAPI. + // Minimum supported client = Windows Vista, server = Windows Server 2008. + class VmaRWMutex + { + public: + VmaRWMutex() { InitializeSRWLock(&m_Lock); } + void LockRead() { AcquireSRWLockShared(&m_Lock); } + void UnlockRead() { ReleaseSRWLockShared(&m_Lock); } + bool TryLockRead() { return TryAcquireSRWLockShared(&m_Lock) != FALSE; } + void LockWrite() { AcquireSRWLockExclusive(&m_Lock); } + void UnlockWrite() { ReleaseSRWLockExclusive(&m_Lock); } + bool TryLockWrite() { return TryAcquireSRWLockExclusive(&m_Lock) != FALSE; } + private: + SRWLOCK m_Lock; + }; + #define VMA_RW_MUTEX VmaRWMutex + #else + // Less efficient fallback: Use normal mutex. + class VmaRWMutex + { + public: + void LockRead() { m_Mutex.Lock(); } + void UnlockRead() { m_Mutex.Unlock(); } + bool TryLockRead() { return m_Mutex.TryLock(); } + void LockWrite() { m_Mutex.Lock(); } + void UnlockWrite() { m_Mutex.Unlock(); } + bool TryLockWrite() { return m_Mutex.TryLock(); } + private: + VMA_MUTEX m_Mutex; + }; + #define VMA_RW_MUTEX VmaRWMutex + #endif // #if VMA_USE_STL_SHARED_MUTEX +#endif // #ifndef VMA_RW_MUTEX - You may set this flag only if you found available and enabled this device extension, - along with `VkPhysicalDeviceMemoryPriorityFeaturesEXT::memoryPriority == VK_TRUE`, - while creating Vulkan device passed as VmaAllocatorCreateInfo::device. +/* +If providing your own implementation, you need to implement a subset of std::atomic. +*/ +#ifndef VMA_ATOMIC_UINT32 + #include + #define VMA_ATOMIC_UINT32 std::atomic +#endif - When this flag is used, VmaAllocationCreateInfo::priority and VmaPoolCreateInfo::priority - are used to set priorities of allocated Vulkan memory. Without it, these variables are ignored. +#ifndef VMA_ATOMIC_UINT64 + #include + #define VMA_ATOMIC_UINT64 std::atomic +#endif - A priority must be a floating-point value between 0 and 1, indicating the priority of the allocation relative to other memory allocations. - Larger values are higher priority. The granularity of the priorities is implementation-dependent. - It is automatically passed to every call to `vkAllocateMemory` done by the library using structure `VkMemoryPriorityAllocateInfoEXT`. - The value to be used for default priority is 0.5. - For more details, see the documentation of the VK_EXT_memory_priority extension. +#ifndef VMA_DEBUG_ALWAYS_DEDICATED_MEMORY + /** + Every allocation will have its own memory block. + Define to 1 for debugging purposes only. */ - VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT = 0x00000040, - - VMA_ALLOCATOR_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VmaAllocatorCreateFlagBits; -typedef VkFlags VmaAllocatorCreateFlags; + #define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY (0) +#endif -/** \brief Pointers to some Vulkan functions - a subset used by the library. +#ifndef VMA_MIN_ALIGNMENT + /** + Minimum alignment of all allocations, in bytes. + Set to more than 1 for debugging purposes. Must be power of two. + */ + #ifdef VMA_DEBUG_ALIGNMENT // Old name + #define VMA_MIN_ALIGNMENT VMA_DEBUG_ALIGNMENT + #else + #define VMA_MIN_ALIGNMENT (1) + #endif +#endif -Used in VmaAllocatorCreateInfo::pVulkanFunctions. -*/ -typedef struct VmaVulkanFunctions { - PFN_vkGetPhysicalDeviceProperties VMA_NULLABLE vkGetPhysicalDeviceProperties; - PFN_vkGetPhysicalDeviceMemoryProperties VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties; - PFN_vkAllocateMemory VMA_NULLABLE vkAllocateMemory; - PFN_vkFreeMemory VMA_NULLABLE vkFreeMemory; - PFN_vkMapMemory VMA_NULLABLE vkMapMemory; - PFN_vkUnmapMemory VMA_NULLABLE vkUnmapMemory; - PFN_vkFlushMappedMemoryRanges VMA_NULLABLE vkFlushMappedMemoryRanges; - PFN_vkInvalidateMappedMemoryRanges VMA_NULLABLE vkInvalidateMappedMemoryRanges; - PFN_vkBindBufferMemory VMA_NULLABLE vkBindBufferMemory; - PFN_vkBindImageMemory VMA_NULLABLE vkBindImageMemory; - PFN_vkGetBufferMemoryRequirements VMA_NULLABLE vkGetBufferMemoryRequirements; - PFN_vkGetImageMemoryRequirements VMA_NULLABLE vkGetImageMemoryRequirements; - PFN_vkCreateBuffer VMA_NULLABLE vkCreateBuffer; - PFN_vkDestroyBuffer VMA_NULLABLE vkDestroyBuffer; - PFN_vkCreateImage VMA_NULLABLE vkCreateImage; - PFN_vkDestroyImage VMA_NULLABLE vkDestroyImage; - PFN_vkCmdCopyBuffer VMA_NULLABLE vkCmdCopyBuffer; -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - PFN_vkGetBufferMemoryRequirements2KHR VMA_NULLABLE vkGetBufferMemoryRequirements2KHR; - PFN_vkGetImageMemoryRequirements2KHR VMA_NULLABLE vkGetImageMemoryRequirements2KHR; +#ifndef VMA_DEBUG_MARGIN + /** + Minimum margin after every allocation, in bytes. + Set nonzero for debugging purposes only. + */ + #define VMA_DEBUG_MARGIN (0) #endif -#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 - PFN_vkBindBufferMemory2KHR VMA_NULLABLE vkBindBufferMemory2KHR; - PFN_vkBindImageMemory2KHR VMA_NULLABLE vkBindImageMemory2KHR; + +#ifndef VMA_DEBUG_INITIALIZE_ALLOCATIONS + /** + Define this macro to 1 to automatically fill new allocations and destroyed + allocations with some bit pattern. + */ + #define VMA_DEBUG_INITIALIZE_ALLOCATIONS (0) #endif -#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 - PFN_vkGetPhysicalDeviceMemoryProperties2KHR VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties2KHR; + +#ifndef VMA_DEBUG_DETECT_CORRUPTION + /** + Define this macro to 1 together with non-zero value of VMA_DEBUG_MARGIN to + enable writing magic value to the margin after every allocation and + validating it, so that memory corruptions (out-of-bounds writes) are detected. + */ + #define VMA_DEBUG_DETECT_CORRUPTION (0) #endif -} VmaVulkanFunctions; -/// Flags to be used in VmaRecordSettings::flags. -typedef enum VmaRecordFlagBits { - /** \brief Enables flush after recording every function call. +#ifndef VMA_DEBUG_GLOBAL_MUTEX + /** + Set this to 1 for debugging purposes only, to enable single mutex protecting all + entry calls to the library. Can be useful for debugging multithreading issues. + */ + #define VMA_DEBUG_GLOBAL_MUTEX (0) +#endif - Enable it if you expect your application to crash, which may leave recording file truncated. - It may degrade performance though. +#ifndef VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY + /** + Minimum value for VkPhysicalDeviceLimits::bufferImageGranularity. + Set to more than 1 for debugging purposes only. Must be power of two. */ - VMA_RECORD_FLUSH_AFTER_CALL_BIT = 0x00000001, + #define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY (1) +#endif - VMA_RECORD_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VmaRecordFlagBits; -typedef VkFlags VmaRecordFlags; +#ifndef VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT + /* + Set this to 1 to make VMA never exceed VkPhysicalDeviceLimits::maxMemoryAllocationCount + and return error instead of leaving up to Vulkan implementation what to do in such cases. + */ + #define VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT (0) +#endif -/// Parameters for recording calls to VMA functions. To be used in VmaAllocatorCreateInfo::pRecordSettings. -typedef struct VmaRecordSettings -{ - /// Flags for recording. Use #VmaRecordFlagBits enum. - VmaRecordFlags flags; - /** \brief Path to the file that should be written by the recording. +#ifndef VMA_SMALL_HEAP_MAX_SIZE + /// Maximum size of a memory heap in Vulkan to consider it "small". + #define VMA_SMALL_HEAP_MAX_SIZE (1024ull * 1024 * 1024) +#endif - Suggested extension: "csv". - If the file already exists, it will be overwritten. - It will be opened for the whole time #VmaAllocator object is alive. - If opening this file fails, creation of the whole allocator object fails. - */ - const char* VMA_NOT_NULL pFilePath; -} VmaRecordSettings; +#ifndef VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE + /// Default size of a block allocated as single VkDeviceMemory from a "large" heap. + #define VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE (256ull * 1024 * 1024) +#endif -/// Description of a Allocator to be created. -typedef struct VmaAllocatorCreateInfo -{ - /// Flags for created allocator. Use #VmaAllocatorCreateFlagBits enum. - VmaAllocatorCreateFlags flags; - /// Vulkan physical device. - /** It must be valid throughout whole lifetime of created allocator. */ - VkPhysicalDevice VMA_NOT_NULL physicalDevice; - /// Vulkan device. - /** It must be valid throughout whole lifetime of created allocator. */ - VkDevice VMA_NOT_NULL device; - /// Preferred size of a single `VkDeviceMemory` block to be allocated from large heaps > 1 GiB. Optional. - /** Set to 0 to use default, which is currently 256 MiB. */ - VkDeviceSize preferredLargeHeapBlockSize; - /// Custom CPU memory allocation callbacks. Optional. - /** Optional, can be null. When specified, will also be used for all CPU-side memory allocations. */ - const VkAllocationCallbacks* VMA_NULLABLE pAllocationCallbacks; - /// Informative callbacks for `vkAllocateMemory`, `vkFreeMemory`. Optional. - /** Optional, can be null. */ - const VmaDeviceMemoryCallbacks* VMA_NULLABLE pDeviceMemoryCallbacks; - /** \brief Maximum number of additional frames that are in use at the same time as current frame. +/* +Mapping hysteresis is a logic that launches when vmaMapMemory/vmaUnmapMemory is called +or a persistently mapped allocation is created and destroyed several times in a row. +It keeps additional +1 mapping of a device memory block to prevent calling actual +vkMapMemory/vkUnmapMemory too many times, which may improve performance and help +tools like RenderDOc. +*/ +#ifndef VMA_MAPPING_HYSTERESIS_ENABLED + #define VMA_MAPPING_HYSTERESIS_ENABLED 1 +#endif - This value is used only when you make allocations with - VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocation cannot become - lost if allocation.lastUseFrameIndex >= allocator.currentFrameIndex - frameInUseCount. +#ifndef VMA_CLASS_NO_COPY + #define VMA_CLASS_NO_COPY(className) \ + private: \ + className(const className&) = delete; \ + className& operator=(const className&) = delete; +#endif - For example, if you double-buffer your command buffers, so resources used for - rendering in previous frame may still be in use by the GPU at the moment you - allocate resources needed for the current frame, set this value to 1. +#define VMA_VALIDATE(cond) do { if(!(cond)) { \ + VMA_ASSERT(0 && "Validation failed: " #cond); \ + return false; \ + } } while(false) - If you want to allow any allocations other than used in the current frame to - become lost, set this value to 0. - */ - uint32_t frameInUseCount; - /** \brief Either null or a pointer to an array of limits on maximum number of bytes that can be allocated out of particular Vulkan memory heap. +/******************************************************************************* +END OF CONFIGURATION +*/ +#endif // _VMA_CONFIGURATION - If not NULL, it must be a pointer to an array of - `VkPhysicalDeviceMemoryProperties::memoryHeapCount` elements, defining limit on - maximum number of bytes that can be allocated out of particular Vulkan memory - heap. - Any of the elements may be equal to `VK_WHOLE_SIZE`, which means no limit on that - heap. This is also the default in case of `pHeapSizeLimit` = NULL. +static const uint8_t VMA_ALLOCATION_FILL_PATTERN_CREATED = 0xDC; +static const uint8_t VMA_ALLOCATION_FILL_PATTERN_DESTROYED = 0xEF; +// Decimal 2139416166, float NaN, little-endian binary 66 E6 84 7F. +static const uint32_t VMA_CORRUPTION_DETECTION_MAGIC_VALUE = 0x7F84E666; - If there is a limit defined for a heap: +// Copy of some Vulkan definitions so we don't need to check their existence just to handle few constants. +static const uint32_t VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY = 0x00000040; +static const uint32_t VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY = 0x00000080; +static const uint32_t VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY = 0x00020000; +static const uint32_t VK_IMAGE_CREATE_DISJOINT_BIT_COPY = 0x00000200; +static const int32_t VK_IMAGE_TILING_DRM_FORMAT_MODIFIER_EXT_COPY = 1000158000; +static const uint32_t VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET = 0x10000000u; +static const uint32_t VMA_ALLOCATION_TRY_COUNT = 32; +static const uint32_t VMA_VENDOR_ID_AMD = 4098; - - If user tries to allocate more memory from that heap using this allocator, - the allocation fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. - - If the limit is smaller than heap size reported in `VkMemoryHeap::size`, the - value of this limit will be reported instead when using vmaGetMemoryProperties(). +// This one is tricky. Vulkan specification defines this code as available since +// Vulkan 1.0, but doesn't actually define it in Vulkan SDK earlier than 1.2.131. +// See pull request #207. +#define VK_ERROR_UNKNOWN_COPY ((VkResult)-13) - Warning! Using this feature may not be equivalent to installing a GPU with - smaller amount of memory, because graphics driver doesn't necessary fail new - allocations with `VK_ERROR_OUT_OF_DEVICE_MEMORY` result when memory capacity is - exceeded. It may return success and just silently migrate some device memory - blocks to system RAM. This driver behavior can also be controlled using - VK_AMD_memory_overallocation_behavior extension. - */ - const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount") pHeapSizeLimit; - /** \brief Pointers to Vulkan functions. Can be null. +#if VMA_STATS_STRING_ENABLED +// Correspond to values of enum VmaSuballocationType. +static const char* VMA_SUBALLOCATION_TYPE_NAMES[] = +{ + "FREE", + "UNKNOWN", + "BUFFER", + "IMAGE_UNKNOWN", + "IMAGE_LINEAR", + "IMAGE_OPTIMAL", +}; +#endif - For details see [Pointers to Vulkan functions](@ref config_Vulkan_functions). - */ - const VmaVulkanFunctions* VMA_NULLABLE pVulkanFunctions; - /** \brief Parameters for recording of VMA calls. Can be null. +static VkAllocationCallbacks VmaEmptyAllocationCallbacks = + { VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL }; - If not null, it enables recording of calls to VMA functions to a file. - If support for recording is not enabled using `VMA_RECORDING_ENABLED` macro, - creation of the allocator object fails with `VK_ERROR_FEATURE_NOT_PRESENT`. - */ - const VmaRecordSettings* VMA_NULLABLE pRecordSettings; - /** \brief Handle to Vulkan instance object. - Starting from version 3.0.0 this member is no longer optional, it must be set! - */ - VkInstance VMA_NOT_NULL instance; - /** \brief Optional. The highest version of Vulkan that the application is designed to use. +#ifndef _VMA_ENUM_DECLARATIONS - It must be a value in the format as created by macro `VK_MAKE_VERSION` or a constant like: `VK_API_VERSION_1_1`, `VK_API_VERSION_1_0`. - The patch version number specified is ignored. Only the major and minor versions are considered. - It must be less or equal (preferably equal) to value as passed to `vkCreateInstance` as `VkApplicationInfo::apiVersion`. - Only versions 1.0, 1.1, 1.2 are supported by the current implementation. - Leaving it initialized to zero is equivalent to `VK_API_VERSION_1_0`. - */ - uint32_t vulkanApiVersion; -} VmaAllocatorCreateInfo; - -/// Creates Allocator object. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( - const VmaAllocatorCreateInfo* VMA_NOT_NULL pCreateInfo, - VmaAllocator VMA_NULLABLE * VMA_NOT_NULL pAllocator); +enum VmaSuballocationType +{ + VMA_SUBALLOCATION_TYPE_FREE = 0, + VMA_SUBALLOCATION_TYPE_UNKNOWN = 1, + VMA_SUBALLOCATION_TYPE_BUFFER = 2, + VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN = 3, + VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR = 4, + VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL = 5, + VMA_SUBALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF +}; -/// Destroys allocator object. -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( - VmaAllocator VMA_NULLABLE allocator); +enum VMA_CACHE_OPERATION +{ + VMA_CACHE_FLUSH, + VMA_CACHE_INVALIDATE +}; -/** \brief Information about existing #VmaAllocator object. -*/ -typedef struct VmaAllocatorInfo +enum class VmaAllocationRequestType { - /** \brief Handle to Vulkan instance object. + Normal, + TLSF, + // Used by "Linear" algorithm. + UpperAddress, + EndOf1st, + EndOf2nd, +}; - This is the same value as has been passed through VmaAllocatorCreateInfo::instance. - */ - VkInstance VMA_NOT_NULL instance; - /** \brief Handle to Vulkan physical device object. +#endif // _VMA_ENUM_DECLARATIONS - This is the same value as has been passed through VmaAllocatorCreateInfo::physicalDevice. - */ - VkPhysicalDevice VMA_NOT_NULL physicalDevice; - /** \brief Handle to Vulkan device object. +#ifndef _VMA_FORWARD_DECLARATIONS +// Opaque handle used by allocation algorithms to identify single allocation in any conforming way. +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VmaAllocHandle); - This is the same value as has been passed through VmaAllocatorCreateInfo::device. - */ - VkDevice VMA_NOT_NULL device; -} VmaAllocatorInfo; +struct VmaMutexLock; +struct VmaMutexLockRead; +struct VmaMutexLockWrite; -/** \brief Returns information about existing #VmaAllocator object - handle to Vulkan device etc. +template +struct AtomicTransactionalIncrement; -It might be useful if you want to keep just the #VmaAllocator handle and fetch other required handles to -`VkPhysicalDevice`, `VkDevice` etc. every time using this function. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo(VmaAllocator VMA_NOT_NULL allocator, VmaAllocatorInfo* VMA_NOT_NULL pAllocatorInfo); +template +struct VmaStlAllocator; -/** -PhysicalDeviceProperties are fetched from physicalDevice by the allocator. -You can access it here, without fetching it again on your own. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( - VmaAllocator VMA_NOT_NULL allocator, - const VkPhysicalDeviceProperties* VMA_NULLABLE * VMA_NOT_NULL ppPhysicalDeviceProperties); +template +class VmaVector; -/** -PhysicalDeviceMemoryProperties are fetched from physicalDevice by the allocator. -You can access it here, without fetching it again on your own. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( - VmaAllocator VMA_NOT_NULL allocator, - const VkPhysicalDeviceMemoryProperties* VMA_NULLABLE * VMA_NOT_NULL ppPhysicalDeviceMemoryProperties); +template +class VmaSmallVector; -/** -\brief Given Memory Type Index, returns Property Flags of this memory type. +template +class VmaPoolAllocator; -This is just a convenience function. Same information can be obtained using -vmaGetMemoryProperties(). -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t memoryTypeIndex, - VkMemoryPropertyFlags* VMA_NOT_NULL pFlags); +template +struct VmaListItem; -/** \brief Sets index of the current frame. +template +class VmaRawList; -This function must be used if you make allocations with -#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT and -#VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flags to inform the allocator -when a new frame begins. Allocations queried using vmaGetAllocationInfo() cannot -become lost in the current frame. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t frameIndex); +template +class VmaList; -/** \brief Calculated statistics of memory usage in entire allocator. -*/ -typedef struct VmaStatInfo -{ - /// Number of `VkDeviceMemory` Vulkan memory blocks allocated. - uint32_t blockCount; - /// Number of #VmaAllocation allocation objects allocated. - uint32_t allocationCount; - /// Number of free ranges of memory between allocations. - uint32_t unusedRangeCount; - /// Total number of bytes occupied by all allocations. - VkDeviceSize usedBytes; - /// Total number of bytes occupied by unused ranges. - VkDeviceSize unusedBytes; - VkDeviceSize allocationSizeMin, allocationSizeAvg, allocationSizeMax; - VkDeviceSize unusedRangeSizeMin, unusedRangeSizeAvg, unusedRangeSizeMax; -} VmaStatInfo; - -/// General statistics from current state of Allocator. -typedef struct VmaStats -{ - VmaStatInfo memoryType[VK_MAX_MEMORY_TYPES]; - VmaStatInfo memoryHeap[VK_MAX_MEMORY_HEAPS]; - VmaStatInfo total; -} VmaStats; +template +class VmaIntrusiveLinkedList; -/** \brief Retrieves statistics from current state of the Allocator. +// Unused in this version +#if 0 +template +struct VmaPair; +template +struct VmaPairFirstLess; -This function is called "calculate" not "get" because it has to traverse all -internal data structures, so it may be quite slow. For faster but more brief statistics -suitable to be called every frame or every allocation, use vmaGetBudget(). +template +class VmaMap; +#endif -Note that when using allocator from multiple threads, returned information may immediately -become outdated. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStats( - VmaAllocator VMA_NOT_NULL allocator, - VmaStats* VMA_NOT_NULL pStats); +#if VMA_STATS_STRING_ENABLED +class VmaStringBuilder; +class VmaJsonWriter; +#endif -/** \brief Statistics of current memory usage and available budget, in bytes, for specific memory heap. -*/ -typedef struct VmaBudget -{ - /** \brief Sum size of all `VkDeviceMemory` blocks allocated from particular heap, in bytes. - */ - VkDeviceSize blockBytes; +class VmaDeviceMemoryBlock; - /** \brief Sum size of all allocations created in particular heap, in bytes. +struct VmaDedicatedAllocationListItemTraits; +class VmaDedicatedAllocationList; - Usually less or equal than `blockBytes`. - Difference `blockBytes - allocationBytes` is the amount of memory allocated but unused - - available for new allocations or wasted due to fragmentation. +struct VmaSuballocation; +struct VmaSuballocationOffsetLess; +struct VmaSuballocationOffsetGreater; +struct VmaSuballocationItemSizeLess; - It might be greater than `blockBytes` if there are some allocations in lost state, as they account - to this value as well. - */ - VkDeviceSize allocationBytes; +typedef VmaList> VmaSuballocationList; - /** \brief Estimated current memory usage of the program, in bytes. +struct VmaAllocationRequest; - Fetched from system using `VK_EXT_memory_budget` extension if enabled. +class VmaBlockMetadata; +class VmaBlockMetadata_Linear; +class VmaBlockMetadata_TLSF; - It might be different than `blockBytes` (usually higher) due to additional implicit objects - also occupying the memory, like swapchain, pipelines, descriptor heaps, command buffers, or - `VkDeviceMemory` blocks allocated outside of this library, if any. - */ - VkDeviceSize usage; +class VmaBlockVector; - /** \brief Estimated amount of memory available to the program, in bytes. +struct VmaPoolListItemTraits; - Fetched from system using `VK_EXT_memory_budget` extension if enabled. +struct VmaCurrentBudgetData; - It might be different (most probably smaller) than `VkMemoryHeap::size[heapIndex]` due to factors - external to the program, like other programs also consuming system resources. - Difference `budget - usage` is the amount of additional memory that can probably - be allocated without problems. Exceeding the budget may result in various problems. - */ - VkDeviceSize budget; -} VmaBudget; +class VmaAllocationObjectAllocator; -/** \brief Retrieves information about current memory budget for all memory heaps. +#endif // _VMA_FORWARD_DECLARATIONS -\param[out] pBudget Must point to array with number of elements at least equal to number of memory heaps in physical device used. -This function is called "get" not "calculate" because it is very fast, suitable to be called -every frame or every allocation. For more detailed statistics use vmaCalculateStats(). +#ifndef _VMA_FUNCTIONS -Note that when using allocator from multiple threads, returned information may immediately -become outdated. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetBudget( - VmaAllocator VMA_NOT_NULL allocator, - VmaBudget* VMA_NOT_NULL pBudget); +/* +Returns number of bits set to 1 in (v). -#ifndef VMA_STATS_STRING_ENABLED -#define VMA_STATS_STRING_ENABLED 1 -#endif +On specific platforms and compilers you can use instrinsics like: -#if VMA_STATS_STRING_ENABLED +Visual Studio: + return __popcnt(v); +GCC, Clang: + return static_cast(__builtin_popcount(v)); -/// Builds and returns statistics as string in JSON format. -/** @param[out] ppStatsString Must be freed using vmaFreeStatsString() function. +Define macro VMA_COUNT_BITS_SET to provide your optimized implementation. +But you need to check in runtime whether user's CPU supports these, as some old processors don't. */ -VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( - VmaAllocator VMA_NOT_NULL allocator, - char* VMA_NULLABLE * VMA_NOT_NULL ppStatsString, - VkBool32 detailedMap); +static inline uint32_t VmaCountBitsSet(uint32_t v) +{ +#if __cplusplus >= 202002L || _MSVC_LANG >= 202002L // C++20 + return std::popcount(v); +#else + uint32_t c = v - ((v >> 1) & 0x55555555); + c = ((c >> 2) & 0x33333333) + (c & 0x33333333); + c = ((c >> 4) + c) & 0x0F0F0F0F; + c = ((c >> 8) + c) & 0x00FF00FF; + c = ((c >> 16) + c) & 0x0000FFFF; + return c; +#endif +} -VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( - VmaAllocator VMA_NOT_NULL allocator, - char* VMA_NULLABLE pStatsString); +static inline uint8_t VmaBitScanLSB(uint64_t mask) +{ +#if defined(_MSC_VER) && defined(_WIN64) + unsigned long pos; + if (_BitScanForward64(&pos, mask)) + return static_cast(pos); + return UINT8_MAX; +#elif defined __GNUC__ || defined __clang__ + return static_cast(__builtin_ffsll(mask)) - 1U; +#else + uint8_t pos = 0; + uint64_t bit = 1; + do + { + if (mask & bit) + return pos; + bit <<= 1; + } while (pos++ < 63); + return UINT8_MAX; +#endif +} -#endif // #if VMA_STATS_STRING_ENABLED +static inline uint8_t VmaBitScanLSB(uint32_t mask) +{ +#ifdef _MSC_VER + unsigned long pos; + if (_BitScanForward(&pos, mask)) + return static_cast(pos); + return UINT8_MAX; +#elif defined __GNUC__ || defined __clang__ + return static_cast(__builtin_ffs(mask)) - 1U; +#else + uint8_t pos = 0; + uint32_t bit = 1; + do + { + if (mask & bit) + return pos; + bit <<= 1; + } while (pos++ < 31); + return UINT8_MAX; +#endif +} -/** \struct VmaPool -\brief Represents custom memory pool +static inline uint8_t VmaBitScanMSB(uint64_t mask) +{ +#if defined(_MSC_VER) && defined(_WIN64) + unsigned long pos; + if (_BitScanReverse64(&pos, mask)) + return static_cast(pos); +#elif defined __GNUC__ || defined __clang__ + if (mask) + return 63 - static_cast(__builtin_clzll(mask)); +#else + uint8_t pos = 63; + uint64_t bit = 1ULL << 63; + do + { + if (mask & bit) + return pos; + bit >>= 1; + } while (pos-- > 0); +#endif + return UINT8_MAX; +} -Fill structure VmaPoolCreateInfo and call function vmaCreatePool() to create it. -Call function vmaDestroyPool() to destroy it. +static inline uint8_t VmaBitScanMSB(uint32_t mask) +{ +#ifdef _MSC_VER + unsigned long pos; + if (_BitScanReverse(&pos, mask)) + return static_cast(pos); +#elif defined __GNUC__ || defined __clang__ + if (mask) + return 31 - static_cast(__builtin_clz(mask)); +#else + uint8_t pos = 31; + uint32_t bit = 1UL << 31; + do + { + if (mask & bit) + return pos; + bit >>= 1; + } while (pos-- > 0); +#endif + return UINT8_MAX; +} -For more information see [Custom memory pools](@ref choosing_memory_type_custom_memory_pools). +/* +Returns true if given number is a power of two. +T must be unsigned integer number or signed integer but always nonnegative. +For 0 returns true. */ -VK_DEFINE_HANDLE(VmaPool) - -typedef enum VmaMemoryUsage +template +inline bool VmaIsPow2(T x) { - /** No intended memory usage specified. - Use other members of VmaAllocationCreateInfo to specify your requirements. - */ - VMA_MEMORY_USAGE_UNKNOWN = 0, - /** Memory will be used on device only, so fast access from the device is preferred. - It usually means device-local GPU (video) memory. - No need to be mappable on host. - It is roughly equivalent of `D3D12_HEAP_TYPE_DEFAULT`. - - Usage: - - - Resources written and read by device, e.g. images used as attachments. - - Resources transferred from host once (immutable) or infrequently and read by - device multiple times, e.g. textures to be sampled, vertex buffers, uniform - (constant) buffers, and majority of other types of resources used on GPU. + return (x & (x - 1)) == 0; +} - Allocation may still end up in `HOST_VISIBLE` memory on some implementations. - In such case, you are free to map it. - You can use #VMA_ALLOCATION_CREATE_MAPPED_BIT with this usage type. - */ - VMA_MEMORY_USAGE_GPU_ONLY = 1, - /** Memory will be mappable on host. - It usually means CPU (system) memory. - Guarantees to be `HOST_VISIBLE` and `HOST_COHERENT`. - CPU access is typically uncached. Writes may be write-combined. - Resources created in this pool may still be accessible to the device, but access to them can be slow. - It is roughly equivalent of `D3D12_HEAP_TYPE_UPLOAD`. - - Usage: Staging copy of resources used as transfer source. - */ - VMA_MEMORY_USAGE_CPU_ONLY = 2, - /** - Memory that is both mappable on host (guarantees to be `HOST_VISIBLE`) and preferably fast to access by GPU. - CPU access is typically uncached. Writes may be write-combined. +// Aligns given value up to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 16. +// Use types like uint32_t, uint64_t as T. +template +static inline T VmaAlignUp(T val, T alignment) +{ + VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); + return (val + alignment - 1) & ~(alignment - 1); +} - Usage: Resources written frequently by host (dynamic), read by device. E.g. textures (with LINEAR layout), vertex buffers, uniform buffers updated every frame or every draw call. - */ - VMA_MEMORY_USAGE_CPU_TO_GPU = 3, - /** Memory mappable on host (guarantees to be `HOST_VISIBLE`) and cached. - It is roughly equivalent of `D3D12_HEAP_TYPE_READBACK`. +// Aligns given value down to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 8. +// Use types like uint32_t, uint64_t as T. +template +static inline T VmaAlignDown(T val, T alignment) +{ + VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); + return val & ~(alignment - 1); +} - Usage: +// Division with mathematical rounding to nearest number. +template +static inline T VmaRoundDiv(T x, T y) +{ + return (x + (y / (T)2)) / y; +} - - Resources written by device, read by host - results of some computations, e.g. screen capture, average scene luminance for HDR tone mapping. - - Any resources read or accessed randomly on host, e.g. CPU-side copy of vertex buffer used as source of transfer, but also used for collision detection. - */ - VMA_MEMORY_USAGE_GPU_TO_CPU = 4, - /** CPU memory - memory that is preferably not `DEVICE_LOCAL`, but also not guaranteed to be `HOST_VISIBLE`. +// Divide by 'y' and round up to nearest integer. +template +static inline T VmaDivideRoundingUp(T x, T y) +{ + return (x + y - (T)1) / y; +} - Usage: Staging copy of resources moved from GPU memory to CPU memory as part - of custom paging/residency mechanism, to be moved back to GPU memory when needed. - */ - VMA_MEMORY_USAGE_CPU_COPY = 5, - /** Lazily allocated GPU memory having `VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT`. - Exists mostly on mobile platforms. Using it on desktop PC or other GPUs with no such memory type present will fail the allocation. +// Returns smallest power of 2 greater or equal to v. +static inline uint32_t VmaNextPow2(uint32_t v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} - Usage: Memory for transient attachment images (color attachments, depth attachments etc.), created with `VK_IMAGE_USAGE_TRANSIENT_ATTACHMENT_BIT`. +static inline uint64_t VmaNextPow2(uint64_t v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} - Allocations with this usage are always created as dedicated - it implies #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. - */ - VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED = 6, - - VMA_MEMORY_USAGE_MAX_ENUM = 0x7FFFFFFF -} VmaMemoryUsage; - -/// Flags to be passed as VmaAllocationCreateInfo::flags. -typedef enum VmaAllocationCreateFlagBits { - /** \brief Set this flag if the allocation should have its own memory block. - - Use it for special, big resources, like fullscreen images used as attachments. - - You should not use this flag if VmaAllocationCreateInfo::pool is not null. - */ - VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT = 0x00000001, - - /** \brief Set this flag to only try to allocate from existing `VkDeviceMemory` blocks and never create new such block. - - If new allocation cannot be placed in any of the existing blocks, allocation - fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY` error. - - You should not use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT and - #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT at the same time. It makes no sense. +// Returns largest power of 2 less or equal to v. +static inline uint32_t VmaPrevPow2(uint32_t v) +{ + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v = v ^ (v >> 1); + return v; +} - If VmaAllocationCreateInfo::pool is not null, this flag is implied and ignored. */ - VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT = 0x00000002, - /** \brief Set this flag to use a memory that will be persistently mapped and retrieve pointer to it. +static inline uint64_t VmaPrevPow2(uint64_t v) +{ + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v = v ^ (v >> 1); + return v; +} - Pointer to mapped memory will be returned through VmaAllocationInfo::pMappedData. +static inline bool VmaStrIsEmpty(const char* pStr) +{ + return pStr == VMA_NULL || *pStr == '\0'; +} - It is valid to use this flag for allocation made from memory type that is not - `HOST_VISIBLE`. This flag is then ignored and memory is not mapped. This is - useful if you need an allocation that is efficient to use on GPU - (`DEVICE_LOCAL`) and still want to map it directly if possible on platforms that - support it (e.g. Intel GPU). +/* +Returns true if two memory blocks occupy overlapping pages. +ResourceA must be in less memory offset than ResourceB. - You should not use this flag together with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT. - */ - VMA_ALLOCATION_CREATE_MAPPED_BIT = 0x00000004, - /** Allocation created with this flag can become lost as a result of another - allocation with #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag, so you - must check it before use. +Algorithm is based on "Vulkan 1.0.39 - A Specification (with all registered Vulkan extensions)" +chapter 11.6 "Resource Memory Association", paragraph "Buffer-Image Granularity". +*/ +static inline bool VmaBlocksOnSamePage( + VkDeviceSize resourceAOffset, + VkDeviceSize resourceASize, + VkDeviceSize resourceBOffset, + VkDeviceSize pageSize) +{ + VMA_ASSERT(resourceAOffset + resourceASize <= resourceBOffset && resourceASize > 0 && pageSize > 0); + VkDeviceSize resourceAEnd = resourceAOffset + resourceASize - 1; + VkDeviceSize resourceAEndPage = resourceAEnd & ~(pageSize - 1); + VkDeviceSize resourceBStart = resourceBOffset; + VkDeviceSize resourceBStartPage = resourceBStart & ~(pageSize - 1); + return resourceAEndPage == resourceBStartPage; +} - To check if allocation is not lost, call vmaGetAllocationInfo() and check if - VmaAllocationInfo::deviceMemory is not `VK_NULL_HANDLE`. +/* +Returns true if given suballocation types could conflict and must respect +VkPhysicalDeviceLimits::bufferImageGranularity. They conflict if one is buffer +or linear image and another one is optimal image. If type is unknown, behave +conservatively. +*/ +static inline bool VmaIsBufferImageGranularityConflict( + VmaSuballocationType suballocType1, + VmaSuballocationType suballocType2) +{ + if (suballocType1 > suballocType2) + { + VMA_SWAP(suballocType1, suballocType2); + } - For details about supporting lost allocations, see Lost Allocations - chapter of User Guide on Main Page. + switch (suballocType1) + { + case VMA_SUBALLOCATION_TYPE_FREE: + return false; + case VMA_SUBALLOCATION_TYPE_UNKNOWN: + return true; + case VMA_SUBALLOCATION_TYPE_BUFFER: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL: + return false; + default: + VMA_ASSERT(0); + return true; + } +} - You should not use this flag together with #VMA_ALLOCATION_CREATE_MAPPED_BIT. - */ - VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT = 0x00000008, - /** While creating allocation using this flag, other allocations that were - created with flag #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT can become lost. +static void VmaWriteMagicValue(void* pData, VkDeviceSize offset) +{ +#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION + uint32_t* pDst = (uint32_t*)((char*)pData + offset); + const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); + for (size_t i = 0; i < numberCount; ++i, ++pDst) + { + *pDst = VMA_CORRUPTION_DETECTION_MAGIC_VALUE; + } +#else + // no-op +#endif +} - For details about supporting lost allocations, see Lost Allocations - chapter of User Guide on Main Page. - */ - VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT = 0x00000010, - /** Set this flag to treat VmaAllocationCreateInfo::pUserData as pointer to a - null-terminated string. Instead of copying pointer value, a local copy of the - string is made and stored in allocation's `pUserData`. The string is automatically - freed together with the allocation. It is also used in vmaBuildStatsString(). - */ - VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT = 0x00000020, - /** Allocation will be created from upper stack in a double stack pool. +static bool VmaValidateMagicValue(const void* pData, VkDeviceSize offset) +{ +#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION + const uint32_t* pSrc = (const uint32_t*)((const char*)pData + offset); + const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); + for (size_t i = 0; i < numberCount; ++i, ++pSrc) + { + if (*pSrc != VMA_CORRUPTION_DETECTION_MAGIC_VALUE) + { + return false; + } + } +#endif + return true; +} - This flag is only allowed for custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT flag. - */ - VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT = 0x00000040, - /** Create both buffer/image and allocation, but don't bind them together. - It is useful when you want to bind yourself to do some more advanced binding, e.g. using some extensions. - The flag is meaningful only with functions that bind by default: vmaCreateBuffer(), vmaCreateImage(). - Otherwise it is ignored. - */ - VMA_ALLOCATION_CREATE_DONT_BIND_BIT = 0x00000080, - /** Create allocation only if additional device memory required for it, if any, won't exceed - memory budget. Otherwise return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. - */ - VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT = 0x00000100, +/* +Fills structure with parameters of an example buffer to be used for transfers +during GPU memory defragmentation. +*/ +static void VmaFillGpuDefragmentationBufferCreateInfo(VkBufferCreateInfo& outBufCreateInfo) +{ + memset(&outBufCreateInfo, 0, sizeof(outBufCreateInfo)); + outBufCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + outBufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + outBufCreateInfo.size = (VkDeviceSize)VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE; // Example size. +} - /** Allocation strategy that chooses smallest possible free range for the - allocation. - */ - VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT = 0x00010000, - /** Allocation strategy that chooses biggest possible free range for the - allocation. - */ - VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT = 0x00020000, - /** Allocation strategy that chooses first suitable free range for the - allocation. - "First" doesn't necessarily means the one with smallest offset in memory, - but rather the one that is easiest and fastest to find. - */ - VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT = 0x00040000, +/* +Performs binary search and returns iterator to first element that is greater or +equal to (key), according to comparison (cmp). - /** Allocation strategy that tries to minimize memory usage. - */ - VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT = VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT, - /** Allocation strategy that tries to minimize allocation time. - */ - VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT = VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT, - /** Allocation strategy that tries to minimize memory fragmentation. - */ - VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT = VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT, +Cmp should return true if first argument is less than second argument. - /** A bit mask to extract only `STRATEGY` bits from entire set of flags. - */ - VMA_ALLOCATION_CREATE_STRATEGY_MASK = - VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT | - VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT | - VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT, +Returned value is the found element, if present in the collection or place where +new element with value (key) should be inserted. +*/ +template +static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT& key, const CmpLess& cmp) +{ + size_t down = 0, up = (end - beg); + while (down < up) + { + const size_t mid = down + (up - down) / 2; // Overflow-safe midpoint calculation + if (cmp(*(beg + mid), key)) + { + down = mid + 1; + } + else + { + up = mid; + } + } + return beg + down; +} - VMA_ALLOCATION_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VmaAllocationCreateFlagBits; -typedef VkFlags VmaAllocationCreateFlags; +template +IterT VmaBinaryFindSorted(const IterT& beg, const IterT& end, const KeyT& value, const CmpLess& cmp) +{ + IterT it = VmaBinaryFindFirstNotLess( + beg, end, value, cmp); + if (it == end || + (!cmp(*it, value) && !cmp(value, *it))) + { + return it; + } + return end; +} -typedef struct VmaAllocationCreateInfo +/* +Returns true if all pointers in the array are not-null and unique. +Warning! O(n^2) complexity. Use only inside VMA_HEAVY_ASSERT. +T must be pointer type, e.g. VmaAllocation, VmaPool. +*/ +template +static bool VmaValidatePointerArray(uint32_t count, const T* arr) { - /// Use #VmaAllocationCreateFlagBits enum. - VmaAllocationCreateFlags flags; - /** \brief Intended usage of memory. + for (uint32_t i = 0; i < count; ++i) + { + const T iPtr = arr[i]; + if (iPtr == VMA_NULL) + { + return false; + } + for (uint32_t j = i + 1; j < count; ++j) + { + if (iPtr == arr[j]) + { + return false; + } + } + } + return true; +} - You can leave #VMA_MEMORY_USAGE_UNKNOWN if you specify memory requirements in other way. \n - If `pool` is not null, this member is ignored. - */ - VmaMemoryUsage usage; - /** \brief Flags that must be set in a Memory Type chosen for an allocation. +template +static inline void VmaPnextChainPushFront(MainT* mainStruct, NewT* newStruct) +{ + newStruct->pNext = mainStruct->pNext; + mainStruct->pNext = newStruct; +} - Leave 0 if you specify memory requirements in other way. \n - If `pool` is not null, this member is ignored.*/ - VkMemoryPropertyFlags requiredFlags; - /** \brief Flags that preferably should be set in a memory type chosen for an allocation. +// This is the main algorithm that guides the selection of a memory type best for an allocation - +// converts usage to required/preferred/not preferred flags. +static bool FindMemoryPreferences( + bool isIntegratedGPU, + const VmaAllocationCreateInfo& allocCreateInfo, + VkFlags bufImgUsage, // VkBufferCreateInfo::usage or VkImageCreateInfo::usage. UINT32_MAX if unknown. + VkMemoryPropertyFlags& outRequiredFlags, + VkMemoryPropertyFlags& outPreferredFlags, + VkMemoryPropertyFlags& outNotPreferredFlags) +{ + outRequiredFlags = allocCreateInfo.requiredFlags; + outPreferredFlags = allocCreateInfo.preferredFlags; + outNotPreferredFlags = 0; - Set to 0 if no additional flags are preferred. \n - If `pool` is not null, this member is ignored. */ - VkMemoryPropertyFlags preferredFlags; - /** \brief Bitmask containing one bit set for every memory type acceptable for this allocation. + switch(allocCreateInfo.usage) + { + case VMA_MEMORY_USAGE_UNKNOWN: + break; + case VMA_MEMORY_USAGE_GPU_ONLY: + if(!isIntegratedGPU || (outPreferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + break; + case VMA_MEMORY_USAGE_CPU_ONLY: + outRequiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + break; + case VMA_MEMORY_USAGE_CPU_TO_GPU: + outRequiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + if(!isIntegratedGPU || (outPreferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + break; + case VMA_MEMORY_USAGE_GPU_TO_CPU: + outRequiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + outPreferredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + break; + case VMA_MEMORY_USAGE_CPU_COPY: + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + break; + case VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED: + outRequiredFlags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; + break; + case VMA_MEMORY_USAGE_AUTO: + case VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE: + case VMA_MEMORY_USAGE_AUTO_PREFER_HOST: + { + if(bufImgUsage == UINT32_MAX) + { + VMA_ASSERT(0 && "VMA_MEMORY_USAGE_AUTO* values can only be used with functions like vmaCreateBuffer, vmaCreateImage so that the details of the created resource are known."); + return false; + } + // This relies on values of VK_IMAGE_USAGE_TRANSFER* being the same VK_BUFFER_IMAGE_TRANSFER*. + const bool deviceAccess = (bufImgUsage & ~(VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT)) != 0; + const bool hostAccessSequentialWrite = (allocCreateInfo.flags & VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT) != 0; + const bool hostAccessRandom = (allocCreateInfo.flags & VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT) != 0; + const bool hostAccessAllowTransferInstead = (allocCreateInfo.flags & VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT) != 0; + const bool preferDevice = allocCreateInfo.usage == VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE; + const bool preferHost = allocCreateInfo.usage == VMA_MEMORY_USAGE_AUTO_PREFER_HOST; - Value 0 is equivalent to `UINT32_MAX` - it means any memory type is accepted if - it meets other requirements specified by this structure, with no further - restrictions on memory type index. \n - If `pool` is not null, this member is ignored. - */ - uint32_t memoryTypeBits; - /** \brief Pool that this allocation should be created in. + // CPU random access - e.g. a buffer written to or transferred from GPU to read back on CPU. + if(hostAccessRandom) + { + if(!isIntegratedGPU && deviceAccess && hostAccessAllowTransferInstead && !preferHost) + { + // Nice if it will end up in HOST_VISIBLE, but more importantly prefer DEVICE_LOCAL. + // Omitting HOST_VISIBLE here is intentional. + // In case there is DEVICE_LOCAL | HOST_VISIBLE | HOST_CACHED, it will pick that one. + // Otherwise, this will give same weight to DEVICE_LOCAL as HOST_VISIBLE | HOST_CACHED and select the former if occurs first on the list. + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + } + else + { + // Always CPU memory, cached. + outRequiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + } + } + // CPU sequential write - may be CPU or host-visible GPU memory, uncached and write-combined. + else if(hostAccessSequentialWrite) + { + // Want uncached and write-combined. + outNotPreferredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - Leave `VK_NULL_HANDLE` to allocate from default pool. If not null, members: - `usage`, `requiredFlags`, `preferredFlags`, `memoryTypeBits` are ignored. - */ - VmaPool VMA_NULLABLE pool; - /** \brief Custom general-purpose pointer that will be stored in #VmaAllocation, can be read as VmaAllocationInfo::pUserData and changed using vmaSetAllocationUserData(). + if(!isIntegratedGPU && deviceAccess && hostAccessAllowTransferInstead && !preferHost) + { + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + } + else + { + outRequiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + // Direct GPU access, CPU sequential write (e.g. a dynamic uniform buffer updated every frame) + if(deviceAccess) + { + // Could go to CPU memory or GPU BAR/unified. Up to the user to decide. If no preference, choose GPU memory. + if(preferHost) + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + else + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + // GPU no direct access, CPU sequential write (e.g. an upload buffer to be transferred to the GPU) + else + { + // Could go to CPU memory or GPU BAR/unified. Up to the user to decide. If no preference, choose CPU memory. + if(preferDevice) + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + else + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + } + } + // No CPU access + else + { + // GPU access, no CPU access (e.g. a color attachment image) - prefer GPU memory + if(deviceAccess) + { + // ...unless there is a clear preference from the user not to do so. + if(preferHost) + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + else + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + // No direct GPU access, no CPU access, just transfers. + // It may be staging copy intended for e.g. preserving image for next frame (then better GPU memory) or + // a "swap file" copy to free some GPU memory (then better CPU memory). + // Up to the user to decide. If no preferece, assume the former and choose GPU memory. + if(preferHost) + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + else + outPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + break; + } + default: + VMA_ASSERT(0); + } - If #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT is used, it must be either - null or pointer to a null-terminated string. The string will be then copied to - internal buffer, so it doesn't need to be valid after allocation call. - */ - void* VMA_NULLABLE pUserData; - /** \brief A floating-point value between 0 and 1, indicating the priority of the allocation relative to other memory allocations. + // Avoid DEVICE_COHERENT unless explicitly requested. + if(((allocCreateInfo.requiredFlags | allocCreateInfo.preferredFlags) & + (VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY | VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY)) == 0) + { + outNotPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY; + } - It is used only when #VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT flag was used during creation of the #VmaAllocator object - and this allocation ends up as dedicated or is explicitly forced as dedicated using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. - Otherwise, it has the priority of a memory block where it is placed and this variable is ignored. - */ - float priority; -} VmaAllocationCreateInfo; + return true; +} -/** -\brief Helps to find memoryTypeIndex, given memoryTypeBits and VmaAllocationCreateInfo. +//////////////////////////////////////////////////////////////////////////////// +// Memory allocation -This algorithm tries to find a memory type that: +static void* VmaMalloc(const VkAllocationCallbacks* pAllocationCallbacks, size_t size, size_t alignment) +{ + void* result = VMA_NULL; + if ((pAllocationCallbacks != VMA_NULL) && + (pAllocationCallbacks->pfnAllocation != VMA_NULL)) + { + result = (*pAllocationCallbacks->pfnAllocation)( + pAllocationCallbacks->pUserData, + size, + alignment, + VK_SYSTEM_ALLOCATION_SCOPE_OBJECT); + } + else + { + result = VMA_SYSTEM_ALIGNED_MALLOC(size, alignment); + } + VMA_ASSERT(result != VMA_NULL && "CPU memory allocation failed."); + return result; +} -- Is allowed by memoryTypeBits. -- Contains all the flags from pAllocationCreateInfo->requiredFlags. -- Matches intended usage. -- Has as many flags from pAllocationCreateInfo->preferredFlags as possible. +static void VmaFree(const VkAllocationCallbacks* pAllocationCallbacks, void* ptr) +{ + if ((pAllocationCallbacks != VMA_NULL) && + (pAllocationCallbacks->pfnFree != VMA_NULL)) + { + (*pAllocationCallbacks->pfnFree)(pAllocationCallbacks->pUserData, ptr); + } + else + { + VMA_SYSTEM_ALIGNED_FREE(ptr); + } +} -\return Returns VK_ERROR_FEATURE_NOT_PRESENT if not found. Receiving such result -from this function or any other allocating function probably means that your -device doesn't support any memory type with requested features for the specific -type of resource you want to use it for. Please check parameters of your -resource, like image layout (OPTIMAL versus LINEAR) or mip level count. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t memoryTypeBits, - const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, - uint32_t* VMA_NOT_NULL pMemoryTypeIndex); +template +static T* VmaAllocate(const VkAllocationCallbacks* pAllocationCallbacks) +{ + return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T), VMA_ALIGN_OF(T)); +} -/** -\brief Helps to find memoryTypeIndex, given VkBufferCreateInfo and VmaAllocationCreateInfo. +template +static T* VmaAllocateArray(const VkAllocationCallbacks* pAllocationCallbacks, size_t count) +{ + return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T) * count, VMA_ALIGN_OF(T)); +} -It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. -It internally creates a temporary, dummy buffer that never has memory bound. -It is just a convenience function, equivalent to calling: +#define vma_new(allocator, type) new(VmaAllocate(allocator))(type) -- `vkCreateBuffer` -- `vkGetBufferMemoryRequirements` -- `vmaFindMemoryTypeIndex` -- `vkDestroyBuffer` -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( - VmaAllocator VMA_NOT_NULL allocator, - const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, - const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, - uint32_t* VMA_NOT_NULL pMemoryTypeIndex); +#define vma_new_array(allocator, type, count) new(VmaAllocateArray((allocator), (count)))(type) -/** -\brief Helps to find memoryTypeIndex, given VkImageCreateInfo and VmaAllocationCreateInfo. +template +static void vma_delete(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr) +{ + ptr->~T(); + VmaFree(pAllocationCallbacks, ptr); +} -It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. -It internally creates a temporary, dummy image that never has memory bound. -It is just a convenience function, equivalent to calling: +template +static void vma_delete_array(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr, size_t count) +{ + if (ptr != VMA_NULL) + { + for (size_t i = count; i--; ) + { + ptr[i].~T(); + } + VmaFree(pAllocationCallbacks, ptr); + } +} -- `vkCreateImage` -- `vkGetImageMemoryRequirements` -- `vmaFindMemoryTypeIndex` -- `vkDestroyImage` -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( - VmaAllocator VMA_NOT_NULL allocator, - const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, - const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, - uint32_t* VMA_NOT_NULL pMemoryTypeIndex); +static char* VmaCreateStringCopy(const VkAllocationCallbacks* allocs, const char* srcStr) +{ + if (srcStr != VMA_NULL) + { + const size_t len = strlen(srcStr); + char* const result = vma_new_array(allocs, char, len + 1); + memcpy(result, srcStr, len + 1); + return result; + } + return VMA_NULL; +} -/// Flags to be passed as VmaPoolCreateInfo::flags. -typedef enum VmaPoolCreateFlagBits { - /** \brief Use this flag if you always allocate only buffers and linear images or only optimal images out of this pool and so Buffer-Image Granularity can be ignored. +#if VMA_STATS_STRING_ENABLED +static char* VmaCreateStringCopy(const VkAllocationCallbacks* allocs, const char* srcStr, size_t strLen) +{ + if (srcStr != VMA_NULL) + { + char* const result = vma_new_array(allocs, char, strLen + 1); + memcpy(result, srcStr, strLen); + result[strLen] = '\0'; + return result; + } + return VMA_NULL; +} +#endif // VMA_STATS_STRING_ENABLED - This is an optional optimization flag. +static void VmaFreeString(const VkAllocationCallbacks* allocs, char* str) +{ + if (str != VMA_NULL) + { + const size_t len = strlen(str); + vma_delete_array(allocs, str, len + 1); + } +} - If you always allocate using vmaCreateBuffer(), vmaCreateImage(), - vmaAllocateMemoryForBuffer(), then you don't need to use it because allocator - knows exact type of your allocations so it can handle Buffer-Image Granularity - in the optimal way. +template +size_t VmaVectorInsertSorted(VectorT& vector, const typename VectorT::value_type& value) +{ + const size_t indexToInsert = VmaBinaryFindFirstNotLess( + vector.data(), + vector.data() + vector.size(), + value, + CmpLess()) - vector.data(); + VmaVectorInsert(vector, indexToInsert, value); + return indexToInsert; +} - If you also allocate using vmaAllocateMemoryForImage() or vmaAllocateMemory(), - exact type of such allocations is not known, so allocator must be conservative - in handling Buffer-Image Granularity, which can lead to suboptimal allocation - (wasted memory). In that case, if you can make sure you always allocate only - buffers and linear images or only optimal images out of this pool, use this flag - to make allocator disregard Buffer-Image Granularity and so make allocations - faster and more optimal. - */ - VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT = 0x00000002, +template +bool VmaVectorRemoveSorted(VectorT& vector, const typename VectorT::value_type& value) +{ + CmpLess comparator; + typename VectorT::iterator it = VmaBinaryFindFirstNotLess( + vector.begin(), + vector.end(), + value, + comparator); + if ((it != vector.end()) && !comparator(*it, value) && !comparator(value, *it)) + { + size_t indexToRemove = it - vector.begin(); + VmaVectorRemove(vector, indexToRemove); + return true; + } + return false; +} +#endif // _VMA_FUNCTIONS - /** \brief Enables alternative, linear allocation algorithm in this pool. +#ifndef _VMA_STATISTICS_FUNCTIONS - Specify this flag to enable linear allocation algorithm, which always creates - new allocations after last one and doesn't reuse space from allocations freed in - between. It trades memory consumption for simplified algorithm and data - structure, which has better performance and uses less memory for metadata. - - By using this flag, you can achieve behavior of free-at-once, stack, - ring buffer, and double stack. For details, see documentation chapter - \ref linear_algorithm. +static void VmaClearStatistics(VmaStatistics& outStats) +{ + outStats.blockCount = 0; + outStats.allocationCount = 0; + outStats.blockBytes = 0; + outStats.allocationBytes = 0; +} - When using this flag, you must specify VmaPoolCreateInfo::maxBlockCount == 1 (or 0 for default). +static void VmaAddStatistics(VmaStatistics& inoutStats, const VmaStatistics& src) +{ + inoutStats.blockCount += src.blockCount; + inoutStats.allocationCount += src.allocationCount; + inoutStats.blockBytes += src.blockBytes; + inoutStats.allocationBytes += src.allocationBytes; +} - For more details, see [Linear allocation algorithm](@ref linear_algorithm). - */ - VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT = 0x00000004, +static void VmaClearDetailedStatistics(VmaDetailedStatistics& outStats) +{ + VmaClearStatistics(outStats.statistics); + outStats.unusedRangeCount = 0; + outStats.allocationSizeMin = VK_WHOLE_SIZE; + outStats.allocationSizeMax = 0; + outStats.unusedRangeSizeMin = VK_WHOLE_SIZE; + outStats.unusedRangeSizeMax = 0; +} - /** \brief Enables alternative, buddy allocation algorithm in this pool. +static void VmaAddDetailedStatisticsAllocation(VmaDetailedStatistics& inoutStats, VkDeviceSize size) +{ + inoutStats.statistics.allocationCount++; + inoutStats.statistics.allocationBytes += size; + inoutStats.allocationSizeMin = VMA_MIN(inoutStats.allocationSizeMin, size); + inoutStats.allocationSizeMax = VMA_MAX(inoutStats.allocationSizeMax, size); +} - It operates on a tree of blocks, each having size that is a power of two and - a half of its parent's size. Comparing to default algorithm, this one provides - faster allocation and deallocation and decreased external fragmentation, - at the expense of more memory wasted (internal fragmentation). +static void VmaAddDetailedStatisticsUnusedRange(VmaDetailedStatistics& inoutStats, VkDeviceSize size) +{ + inoutStats.unusedRangeCount++; + inoutStats.unusedRangeSizeMin = VMA_MIN(inoutStats.unusedRangeSizeMin, size); + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, size); +} - For more details, see [Buddy allocation algorithm](@ref buddy_algorithm). - */ - VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT = 0x00000008, +static void VmaAddDetailedStatistics(VmaDetailedStatistics& inoutStats, const VmaDetailedStatistics& src) +{ + VmaAddStatistics(inoutStats.statistics, src.statistics); + inoutStats.unusedRangeCount += src.unusedRangeCount; + inoutStats.allocationSizeMin = VMA_MIN(inoutStats.allocationSizeMin, src.allocationSizeMin); + inoutStats.allocationSizeMax = VMA_MAX(inoutStats.allocationSizeMax, src.allocationSizeMax); + inoutStats.unusedRangeSizeMin = VMA_MIN(inoutStats.unusedRangeSizeMin, src.unusedRangeSizeMin); + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, src.unusedRangeSizeMax); +} - /** Bit mask to extract only `ALGORITHM` bits from entire set of flags. - */ - VMA_POOL_CREATE_ALGORITHM_MASK = - VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT | - VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT, +#endif // _VMA_STATISTICS_FUNCTIONS - VMA_POOL_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VmaPoolCreateFlagBits; -typedef VkFlags VmaPoolCreateFlags; +#ifndef _VMA_MUTEX_LOCK +// Helper RAII class to lock a mutex in constructor and unlock it in destructor (at the end of scope). +struct VmaMutexLock +{ + VMA_CLASS_NO_COPY(VmaMutexLock) +public: + VmaMutexLock(VMA_MUTEX& mutex, bool useMutex = true) : + m_pMutex(useMutex ? &mutex : VMA_NULL) + { + if (m_pMutex) { m_pMutex->Lock(); } + } + ~VmaMutexLock() { if (m_pMutex) { m_pMutex->Unlock(); } } -/** \brief Describes parameter of created #VmaPool. -*/ -typedef struct VmaPoolCreateInfo { - /** \brief Vulkan memory type index to allocate this pool from. - */ - uint32_t memoryTypeIndex; - /** \brief Use combination of #VmaPoolCreateFlagBits. - */ - VmaPoolCreateFlags flags; - /** \brief Size of a single `VkDeviceMemory` block to be allocated as part of this pool, in bytes. Optional. +private: + VMA_MUTEX* m_pMutex; +}; - Specify nonzero to set explicit, constant size of memory blocks used by this - pool. +// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for reading. +struct VmaMutexLockRead +{ + VMA_CLASS_NO_COPY(VmaMutexLockRead) +public: + VmaMutexLockRead(VMA_RW_MUTEX& mutex, bool useMutex) : + m_pMutex(useMutex ? &mutex : VMA_NULL) + { + if (m_pMutex) { m_pMutex->LockRead(); } + } + ~VmaMutexLockRead() { if (m_pMutex) { m_pMutex->UnlockRead(); } } - Leave 0 to use default and let the library manage block sizes automatically. - Sizes of particular blocks may vary. - */ - VkDeviceSize blockSize; - /** \brief Minimum number of blocks to be always allocated in this pool, even if they stay empty. +private: + VMA_RW_MUTEX* m_pMutex; +}; - Set to 0 to have no preallocated blocks and allow the pool be completely empty. - */ - size_t minBlockCount; - /** \brief Maximum number of blocks that can be allocated in this pool. Optional. +// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for writing. +struct VmaMutexLockWrite +{ + VMA_CLASS_NO_COPY(VmaMutexLockWrite) +public: + VmaMutexLockWrite(VMA_RW_MUTEX& mutex, bool useMutex) + : m_pMutex(useMutex ? &mutex : VMA_NULL) + { + if (m_pMutex) { m_pMutex->LockWrite(); } + } + ~VmaMutexLockWrite() { if (m_pMutex) { m_pMutex->UnlockWrite(); } } - Set to 0 to use default, which is `SIZE_MAX`, which means no limit. +private: + VMA_RW_MUTEX* m_pMutex; +}; - Set to same value as VmaPoolCreateInfo::minBlockCount to have fixed amount of memory allocated - throughout whole lifetime of this pool. - */ - size_t maxBlockCount; - /** \brief Maximum number of additional frames that are in use at the same time as current frame. +#if VMA_DEBUG_GLOBAL_MUTEX + static VMA_MUTEX gDebugGlobalMutex; + #define VMA_DEBUG_GLOBAL_MUTEX_LOCK VmaMutexLock debugGlobalMutexLock(gDebugGlobalMutex, true); +#else + #define VMA_DEBUG_GLOBAL_MUTEX_LOCK +#endif +#endif // _VMA_MUTEX_LOCK - This value is used only when you make allocations with - #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocation cannot become - lost if allocation.lastUseFrameIndex >= allocator.currentFrameIndex - frameInUseCount. +#ifndef _VMA_ATOMIC_TRANSACTIONAL_INCREMENT +// An object that increments given atomic but decrements it back in the destructor unless Commit() is called. +template +struct AtomicTransactionalIncrement +{ +public: + typedef std::atomic AtomicT; - For example, if you double-buffer your command buffers, so resources used for - rendering in previous frame may still be in use by the GPU at the moment you - allocate resources needed for the current frame, set this value to 1. + ~AtomicTransactionalIncrement() + { + if(m_Atomic) + --(*m_Atomic); + } - If you want to allow any allocations other than used in the current frame to - become lost, set this value to 0. - */ - uint32_t frameInUseCount; - /** \brief A floating-point value between 0 and 1, indicating the priority of the allocations in this pool relative to other memory allocations. + void Commit() { m_Atomic = nullptr; } + T Increment(AtomicT* atomic) + { + m_Atomic = atomic; + return m_Atomic->fetch_add(1); + } - It is used only when #VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT flag was used during creation of the #VmaAllocator object. - Otherwise, this variable is ignored. - */ - float priority; -} VmaPoolCreateInfo; +private: + AtomicT* m_Atomic = nullptr; +}; +#endif // _VMA_ATOMIC_TRANSACTIONAL_INCREMENT -/** \brief Describes parameter of existing #VmaPool. -*/ -typedef struct VmaPoolStats { - /** \brief Total amount of `VkDeviceMemory` allocated from Vulkan for this pool, in bytes. - */ - VkDeviceSize size; - /** \brief Total number of bytes in the pool not used by any #VmaAllocation. - */ - VkDeviceSize unusedSize; - /** \brief Number of #VmaAllocation objects created from this pool that were not destroyed or lost. - */ - size_t allocationCount; - /** \brief Number of continuous memory ranges in the pool not used by any #VmaAllocation. - */ - size_t unusedRangeCount; - /** \brief Size of the largest continuous free memory region available for new allocation. +#ifndef _VMA_STL_ALLOCATOR +// STL-compatible allocator. +template +struct VmaStlAllocator +{ + const VkAllocationCallbacks* const m_pCallbacks; + typedef T value_type; - Making a new allocation of that size is not guaranteed to succeed because of - possible additional margin required to respect alignment and buffer/image - granularity. - */ - VkDeviceSize unusedRangeSizeMax; - /** \brief Number of `VkDeviceMemory` blocks allocated for this pool. - */ - size_t blockCount; -} VmaPoolStats; + VmaStlAllocator(const VkAllocationCallbacks* pCallbacks) : m_pCallbacks(pCallbacks) {} + template + VmaStlAllocator(const VmaStlAllocator& src) : m_pCallbacks(src.m_pCallbacks) {} + VmaStlAllocator(const VmaStlAllocator&) = default; + VmaStlAllocator& operator=(const VmaStlAllocator&) = delete; -/** \brief Allocates Vulkan device memory and creates #VmaPool object. + T* allocate(size_t n) { return VmaAllocateArray(m_pCallbacks, n); } + void deallocate(T* p, size_t n) { VmaFree(m_pCallbacks, p); } -@param allocator Allocator object. -@param pCreateInfo Parameters of pool to create. -@param[out] pPool Handle to created pool. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( - VmaAllocator VMA_NOT_NULL allocator, - const VmaPoolCreateInfo* VMA_NOT_NULL pCreateInfo, - VmaPool VMA_NULLABLE * VMA_NOT_NULL pPool); + template + bool operator==(const VmaStlAllocator& rhs) const + { + return m_pCallbacks == rhs.m_pCallbacks; + } + template + bool operator!=(const VmaStlAllocator& rhs) const + { + return m_pCallbacks != rhs.m_pCallbacks; + } +}; +#endif // _VMA_STL_ALLOCATOR -/** \brief Destroys #VmaPool object and frees Vulkan device memory. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( - VmaAllocator VMA_NOT_NULL allocator, - VmaPool VMA_NULLABLE pool); +#ifndef _VMA_VECTOR +/* Class with interface compatible with subset of std::vector. +T must be POD because constructors and destructors are not called and memcpy is +used for these objects. */ +template +class VmaVector +{ +public: + typedef T value_type; + typedef T* iterator; + typedef const T* const_iterator; -/** \brief Retrieves statistics of existing #VmaPool object. + VmaVector(const AllocatorT& allocator); + VmaVector(size_t count, const AllocatorT& allocator); + // This version of the constructor is here for compatibility with pre-C++14 std::vector. + // value is unused. + VmaVector(size_t count, const T& value, const AllocatorT& allocator) : VmaVector(count, allocator) {} + VmaVector(const VmaVector& src); + VmaVector& operator=(const VmaVector& rhs); + ~VmaVector() { VmaFree(m_Allocator.m_pCallbacks, m_pArray); } -@param allocator Allocator object. -@param pool Pool object. -@param[out] pPoolStats Statistics of specified pool. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStats( - VmaAllocator VMA_NOT_NULL allocator, - VmaPool VMA_NOT_NULL pool, - VmaPoolStats* VMA_NOT_NULL pPoolStats); + bool empty() const { return m_Count == 0; } + size_t size() const { return m_Count; } + T* data() { return m_pArray; } + T& front() { VMA_HEAVY_ASSERT(m_Count > 0); return m_pArray[0]; } + T& back() { VMA_HEAVY_ASSERT(m_Count > 0); return m_pArray[m_Count - 1]; } + const T* data() const { return m_pArray; } + const T& front() const { VMA_HEAVY_ASSERT(m_Count > 0); return m_pArray[0]; } + const T& back() const { VMA_HEAVY_ASSERT(m_Count > 0); return m_pArray[m_Count - 1]; } -/** \brief Marks all allocations in given pool as lost if they are not used in current frame or VmaPoolCreateInfo::frameInUseCount back from now. + iterator begin() { return m_pArray; } + iterator end() { return m_pArray + m_Count; } + const_iterator cbegin() const { return m_pArray; } + const_iterator cend() const { return m_pArray + m_Count; } + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } + + void pop_front() { VMA_HEAVY_ASSERT(m_Count > 0); remove(0); } + void pop_back() { VMA_HEAVY_ASSERT(m_Count > 0); resize(size() - 1); } + void push_front(const T& src) { insert(0, src); } + + void push_back(const T& src); + void reserve(size_t newCapacity, bool freeMemory = false); + void resize(size_t newCount); + void clear() { resize(0); } + void shrink_to_fit(); + void insert(size_t index, const T& src); + void remove(size_t index); + + T& operator[](size_t index) { VMA_HEAVY_ASSERT(index < m_Count); return m_pArray[index]; } + const T& operator[](size_t index) const { VMA_HEAVY_ASSERT(index < m_Count); return m_pArray[index]; } -@param allocator Allocator object. -@param pool Pool. -@param[out] pLostAllocationCount Number of allocations marked as lost. Optional - pass null if you don't need this information. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaMakePoolAllocationsLost( - VmaAllocator VMA_NOT_NULL allocator, - VmaPool VMA_NOT_NULL pool, - size_t* VMA_NULLABLE pLostAllocationCount); +private: + AllocatorT m_Allocator; + T* m_pArray; + size_t m_Count; + size_t m_Capacity; +}; -/** \brief Checks magic number in margins around all allocations in given memory pool in search for corruptions. +#ifndef _VMA_VECTOR_FUNCTIONS +template +VmaVector::VmaVector(const AllocatorT& allocator) + : m_Allocator(allocator), + m_pArray(VMA_NULL), + m_Count(0), + m_Capacity(0) {} -Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, -`VMA_DEBUG_MARGIN` is defined to nonzero and the pool is created in memory type that is -`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). +template +VmaVector::VmaVector(size_t count, const AllocatorT& allocator) + : m_Allocator(allocator), + m_pArray(count ? (T*)VmaAllocateArray(allocator.m_pCallbacks, count) : VMA_NULL), + m_Count(count), + m_Capacity(count) {} -Possible return values: +template +VmaVector::VmaVector(const VmaVector& src) + : m_Allocator(src.m_Allocator), + m_pArray(src.m_Count ? (T*)VmaAllocateArray(src.m_Allocator.m_pCallbacks, src.m_Count) : VMA_NULL), + m_Count(src.m_Count), + m_Capacity(src.m_Count) +{ + if (m_Count != 0) + { + memcpy(m_pArray, src.m_pArray, m_Count * sizeof(T)); + } +} -- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for specified pool. -- `VK_SUCCESS` - corruption detection has been performed and succeeded. -- `VK_ERROR_VALIDATION_FAILED_EXT` - corruption detection has been performed and found memory corruptions around one of the allocations. - `VMA_ASSERT` is also fired in that case. -- Other value: Error returned by Vulkan, e.g. memory mapping failure. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption(VmaAllocator VMA_NOT_NULL allocator, VmaPool VMA_NOT_NULL pool); +template +VmaVector& VmaVector::operator=(const VmaVector& rhs) +{ + if (&rhs != this) + { + resize(rhs.m_Count); + if (m_Count != 0) + { + memcpy(m_pArray, rhs.m_pArray, m_Count * sizeof(T)); + } + } + return *this; +} -/** \brief Retrieves name of a custom pool. +template +void VmaVector::push_back(const T& src) +{ + const size_t newIndex = size(); + resize(newIndex + 1); + m_pArray[newIndex] = src; +} -After the call `ppName` is either null or points to an internally-owned null-terminated string -containing name of the pool that was previously set. The pointer becomes invalid when the pool is -destroyed or its name is changed using vmaSetPoolName(). -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( - VmaAllocator VMA_NOT_NULL allocator, - VmaPool VMA_NOT_NULL pool, - const char* VMA_NULLABLE * VMA_NOT_NULL ppName); +template +void VmaVector::reserve(size_t newCapacity, bool freeMemory) +{ + newCapacity = VMA_MAX(newCapacity, m_Count); -/** \brief Sets name of a custom pool. + if ((newCapacity < m_Capacity) && !freeMemory) + { + newCapacity = m_Capacity; + } -`pName` can be either null or pointer to a null-terminated string with new name for the pool. -Function makes internal copy of the string, so it can be changed or freed immediately after this call. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( - VmaAllocator VMA_NOT_NULL allocator, - VmaPool VMA_NOT_NULL pool, - const char* VMA_NULLABLE pName); + if (newCapacity != m_Capacity) + { + T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator, newCapacity) : VMA_NULL; + if (m_Count != 0) + { + memcpy(newArray, m_pArray, m_Count * sizeof(T)); + } + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + m_Capacity = newCapacity; + m_pArray = newArray; + } +} -/** \struct VmaAllocation -\brief Represents single memory allocation. +template +void VmaVector::resize(size_t newCount) +{ + size_t newCapacity = m_Capacity; + if (newCount > m_Capacity) + { + newCapacity = VMA_MAX(newCount, VMA_MAX(m_Capacity * 3 / 2, (size_t)8)); + } -It may be either dedicated block of `VkDeviceMemory` or a specific region of a bigger block of this type -plus unique offset. + if (newCapacity != m_Capacity) + { + T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator.m_pCallbacks, newCapacity) : VMA_NULL; + const size_t elementsToCopy = VMA_MIN(m_Count, newCount); + if (elementsToCopy != 0) + { + memcpy(newArray, m_pArray, elementsToCopy * sizeof(T)); + } + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + m_Capacity = newCapacity; + m_pArray = newArray; + } -There are multiple ways to create such object. -You need to fill structure VmaAllocationCreateInfo. -For more information see [Choosing memory type](@ref choosing_memory_type). + m_Count = newCount; +} -Although the library provides convenience functions that create Vulkan buffer or image, -allocate memory for it and bind them together, -binding of the allocation to a buffer or an image is out of scope of the allocation itself. -Allocation object can exist without buffer/image bound, -binding can be done manually by the user, and destruction of it can be done -independently of destruction of the allocation. +template +void VmaVector::shrink_to_fit() +{ + if (m_Capacity > m_Count) + { + T* newArray = VMA_NULL; + if (m_Count > 0) + { + newArray = VmaAllocateArray(m_Allocator.m_pCallbacks, m_Count); + memcpy(newArray, m_pArray, m_Count * sizeof(T)); + } + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + m_Capacity = m_Count; + m_pArray = newArray; + } +} -The object also remembers its size and some other information. -To retrieve this information, use function vmaGetAllocationInfo() and inspect -returned structure VmaAllocationInfo. +template +void VmaVector::insert(size_t index, const T& src) +{ + VMA_HEAVY_ASSERT(index <= m_Count); + const size_t oldCount = size(); + resize(oldCount + 1); + if (index < oldCount) + { + memmove(m_pArray + (index + 1), m_pArray + index, (oldCount - index) * sizeof(T)); + } + m_pArray[index] = src; +} -Some kinds allocations can be in lost state. -For more information, see [Lost allocations](@ref lost_allocations). -*/ -VK_DEFINE_HANDLE(VmaAllocation) +template +void VmaVector::remove(size_t index) +{ + VMA_HEAVY_ASSERT(index < m_Count); + const size_t oldCount = size(); + if (index < oldCount - 1) + { + memmove(m_pArray + index, m_pArray + (index + 1), (oldCount - index - 1) * sizeof(T)); + } + resize(oldCount - 1); +} +#endif // _VMA_VECTOR_FUNCTIONS -/** \brief Parameters of #VmaAllocation objects, that can be retrieved using function vmaGetAllocationInfo(). -*/ -typedef struct VmaAllocationInfo { - /** \brief Memory type index that this allocation was allocated from. +template +static void VmaVectorInsert(VmaVector& vec, size_t index, const T& item) +{ + vec.insert(index, item); +} - It never changes. - */ - uint32_t memoryType; - /** \brief Handle to Vulkan memory object. +template +static void VmaVectorRemove(VmaVector& vec, size_t index) +{ + vec.remove(index); +} +#endif // _VMA_VECTOR - Same memory object can be shared by multiple allocations. +#ifndef _VMA_SMALL_VECTOR +/* +This is a vector (a variable-sized array), optimized for the case when the array is small. - It can change after call to vmaDefragment() if this allocation is passed to the function, or if allocation is lost. +It contains some number of elements in-place, which allows it to avoid heap allocation +when the actual number of elements is below that threshold. This allows normal "small" +cases to be fast without losing generality for large inputs. +*/ +template +class VmaSmallVector +{ +public: + typedef T value_type; + typedef T* iterator; - If the allocation is lost, it is equal to `VK_NULL_HANDLE`. - */ - VkDeviceMemory VMA_NULLABLE_NON_DISPATCHABLE deviceMemory; - /** \brief Offset in `VkDeviceMemory` object to the beginning of this allocation, in bytes. `(deviceMemory, offset)` pair is unique to this allocation. + VmaSmallVector(const AllocatorT& allocator); + VmaSmallVector(size_t count, const AllocatorT& allocator); + template + VmaSmallVector(const VmaSmallVector&) = delete; + template + VmaSmallVector& operator=(const VmaSmallVector&) = delete; + ~VmaSmallVector() = default; - You usually don't need to use this offset. If you create a buffer or an image together with the allocation using e.g. function - vmaCreateBuffer(), vmaCreateImage(), functions that operate on these resources refer to the beginning of the buffer or image, - not entire device memory block. Functions like vmaMapMemory(), vmaBindBufferMemory() also refer to the beginning of the allocation - and apply this offset automatically. + bool empty() const { return m_Count == 0; } + size_t size() const { return m_Count; } + T* data() { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } + T& front() { VMA_HEAVY_ASSERT(m_Count > 0); return data()[0]; } + T& back() { VMA_HEAVY_ASSERT(m_Count > 0); return data()[m_Count - 1]; } + const T* data() const { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } + const T& front() const { VMA_HEAVY_ASSERT(m_Count > 0); return data()[0]; } + const T& back() const { VMA_HEAVY_ASSERT(m_Count > 0); return data()[m_Count - 1]; } - It can change after call to vmaDefragment() if this allocation is passed to the function, or if allocation is lost. - */ - VkDeviceSize offset; - /** \brief Size of this allocation, in bytes. + iterator begin() { return data(); } + iterator end() { return data() + m_Count; } - It never changes, unless allocation is lost. + void pop_front() { VMA_HEAVY_ASSERT(m_Count > 0); remove(0); } + void pop_back() { VMA_HEAVY_ASSERT(m_Count > 0); resize(size() - 1); } + void push_front(const T& src) { insert(0, src); } - \note Allocation size returned in this variable may be greater than the size - requested for the resource e.g. as `VkBufferCreateInfo::size`. Whole size of the - allocation is accessible for operations on memory e.g. using a pointer after - mapping with vmaMapMemory(), but operations on the resource e.g. using - `vkCmdCopyBuffer` must be limited to the size of the resource. - */ - VkDeviceSize size; - /** \brief Pointer to the beginning of this allocation as mapped data. + void push_back(const T& src); + void resize(size_t newCount, bool freeMemory = false); + void clear(bool freeMemory = false); + void insert(size_t index, const T& src); + void remove(size_t index); - If the allocation hasn't been mapped using vmaMapMemory() and hasn't been - created with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag, this value is null. + T& operator[](size_t index) { VMA_HEAVY_ASSERT(index < m_Count); return data()[index]; } + const T& operator[](size_t index) const { VMA_HEAVY_ASSERT(index < m_Count); return data()[index]; } - It can change after call to vmaMapMemory(), vmaUnmapMemory(). - It can also change after call to vmaDefragment() if this allocation is passed to the function. - */ - void* VMA_NULLABLE pMappedData; - /** \brief Custom general-purpose pointer that was passed as VmaAllocationCreateInfo::pUserData or set using vmaSetAllocationUserData(). +private: + size_t m_Count; + T m_StaticArray[N]; // Used when m_Size <= N + VmaVector m_DynamicArray; // Used when m_Size > N +}; - It can change after call to vmaSetAllocationUserData() for this allocation. - */ - void* VMA_NULLABLE pUserData; -} VmaAllocationInfo; +#ifndef _VMA_SMALL_VECTOR_FUNCTIONS +template +VmaSmallVector::VmaSmallVector(const AllocatorT& allocator) + : m_Count(0), + m_DynamicArray(allocator) {} -/** \brief General purpose memory allocation. +template +VmaSmallVector::VmaSmallVector(size_t count, const AllocatorT& allocator) + : m_Count(count), + m_DynamicArray(count > N ? count : 0, allocator) {} -@param[out] pAllocation Handle to allocated memory. -@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). +template +void VmaSmallVector::push_back(const T& src) +{ + const size_t newIndex = size(); + resize(newIndex + 1); + data()[newIndex] = src; +} -You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). +template +void VmaSmallVector::resize(size_t newCount, bool freeMemory) +{ + if (newCount > N && m_Count > N) + { + // Any direction, staying in m_DynamicArray + m_DynamicArray.resize(newCount); + if (freeMemory) + { + m_DynamicArray.shrink_to_fit(); + } + } + else if (newCount > N && m_Count <= N) + { + // Growing, moving from m_StaticArray to m_DynamicArray + m_DynamicArray.resize(newCount); + if (m_Count > 0) + { + memcpy(m_DynamicArray.data(), m_StaticArray, m_Count * sizeof(T)); + } + } + else if (newCount <= N && m_Count > N) + { + // Shrinking, moving from m_DynamicArray to m_StaticArray + if (newCount > 0) + { + memcpy(m_StaticArray, m_DynamicArray.data(), newCount * sizeof(T)); + } + m_DynamicArray.resize(0); + if (freeMemory) + { + m_DynamicArray.shrink_to_fit(); + } + } + else + { + // Any direction, staying in m_StaticArray - nothing to do here + } + m_Count = newCount; +} -It is recommended to use vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(), -vmaCreateBuffer(), vmaCreateImage() instead whenever possible. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( - VmaAllocator VMA_NOT_NULL allocator, - const VkMemoryRequirements* VMA_NOT_NULL pVkMemoryRequirements, - const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, - VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); +template +void VmaSmallVector::clear(bool freeMemory) +{ + m_DynamicArray.clear(); + if (freeMemory) + { + m_DynamicArray.shrink_to_fit(); + } + m_Count = 0; +} -/** \brief General purpose memory allocation for multiple allocation objects at once. +template +void VmaSmallVector::insert(size_t index, const T& src) +{ + VMA_HEAVY_ASSERT(index <= m_Count); + const size_t oldCount = size(); + resize(oldCount + 1); + T* const dataPtr = data(); + if (index < oldCount) + { + // I know, this could be more optimal for case where memmove can be memcpy directly from m_StaticArray to m_DynamicArray. + memmove(dataPtr + (index + 1), dataPtr + index, (oldCount - index) * sizeof(T)); + } + dataPtr[index] = src; +} -@param allocator Allocator object. -@param pVkMemoryRequirements Memory requirements for each allocation. -@param pCreateInfo Creation parameters for each alloction. -@param allocationCount Number of allocations to make. -@param[out] pAllocations Pointer to array that will be filled with handles to created allocations. -@param[out] pAllocationInfo Optional. Pointer to array that will be filled with parameters of created allocations. +template +void VmaSmallVector::remove(size_t index) +{ + VMA_HEAVY_ASSERT(index < m_Count); + const size_t oldCount = size(); + if (index < oldCount - 1) + { + // I know, this could be more optimal for case where memmove can be memcpy directly from m_DynamicArray to m_StaticArray. + T* const dataPtr = data(); + memmove(dataPtr + index, dataPtr + (index + 1), (oldCount - index - 1) * sizeof(T)); + } + resize(oldCount - 1); +} +#endif // _VMA_SMALL_VECTOR_FUNCTIONS +#endif // _VMA_SMALL_VECTOR -You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). +#ifndef _VMA_POOL_ALLOCATOR +/* +Allocator for objects of type T using a list of arrays (pools) to speed up +allocation. Number of elements that can be allocated is not bounded because +allocator can create multiple blocks. +*/ +template +class VmaPoolAllocator +{ + VMA_CLASS_NO_COPY(VmaPoolAllocator) +public: + VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity); + ~VmaPoolAllocator(); + template T* Alloc(Types&&... args); + void Free(T* ptr); -Word "pages" is just a suggestion to use this function to allocate pieces of memory needed for sparse binding. -It is just a general purpose allocation function able to make multiple allocations at once. -It may be internally optimized to be more efficient than calling vmaAllocateMemory() `allocationCount` times. +private: + union Item + { + uint32_t NextFreeIndex; + alignas(T) char Value[sizeof(T)]; + }; + struct ItemBlock + { + Item* pItems; + uint32_t Capacity; + uint32_t FirstFreeIndex; + }; -All allocations are made using same parameters. All of them are created out of the same memory pool and type. -If any allocation fails, all allocations already made within this function call are also freed, so that when -returned result is not `VK_SUCCESS`, `pAllocation` array is always entirely filled with `VK_NULL_HANDLE`. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( - VmaAllocator VMA_NOT_NULL allocator, - const VkMemoryRequirements* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pVkMemoryRequirements, - const VmaAllocationCreateInfo* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pCreateInfo, - size_t allocationCount, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations, - VmaAllocationInfo* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationInfo); + const VkAllocationCallbacks* m_pAllocationCallbacks; + const uint32_t m_FirstBlockCapacity; + VmaVector> m_ItemBlocks; -/** -@param[out] pAllocation Handle to allocated memory. -@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). + ItemBlock& CreateNewBlock(); +}; -You should free the memory using vmaFreeMemory(). -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( - VmaAllocator VMA_NOT_NULL allocator, - VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, - const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, - VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); +#ifndef _VMA_POOL_ALLOCATOR_FUNCTIONS +template +VmaPoolAllocator::VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity) + : m_pAllocationCallbacks(pAllocationCallbacks), + m_FirstBlockCapacity(firstBlockCapacity), + m_ItemBlocks(VmaStlAllocator(pAllocationCallbacks)) +{ + VMA_ASSERT(m_FirstBlockCapacity > 1); +} -/// Function similar to vmaAllocateMemoryForBuffer(). -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( - VmaAllocator VMA_NOT_NULL allocator, - VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, - const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, - VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); +template +VmaPoolAllocator::~VmaPoolAllocator() +{ + for (size_t i = m_ItemBlocks.size(); i--;) + vma_delete_array(m_pAllocationCallbacks, m_ItemBlocks[i].pItems, m_ItemBlocks[i].Capacity); + m_ItemBlocks.clear(); +} -/** \brief Frees memory previously allocated using vmaAllocateMemory(), vmaAllocateMemoryForBuffer(), or vmaAllocateMemoryForImage(). +template +template T* VmaPoolAllocator::Alloc(Types&&... args) +{ + for (size_t i = m_ItemBlocks.size(); i--; ) + { + ItemBlock& block = m_ItemBlocks[i]; + // This block has some free items: Use first one. + if (block.FirstFreeIndex != UINT32_MAX) + { + Item* const pItem = &block.pItems[block.FirstFreeIndex]; + block.FirstFreeIndex = pItem->NextFreeIndex; + T* result = (T*)&pItem->Value; + new(result)T(std::forward(args)...); // Explicit constructor call. + return result; + } + } -Passing `VK_NULL_HANDLE` as `allocation` is valid. Such function call is just skipped. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( - VmaAllocator VMA_NOT_NULL allocator, - const VmaAllocation VMA_NULLABLE allocation); + // No block has free item: Create new one and use it. + ItemBlock& newBlock = CreateNewBlock(); + Item* const pItem = &newBlock.pItems[0]; + newBlock.FirstFreeIndex = pItem->NextFreeIndex; + T* result = (T*)&pItem->Value; + new(result) T(std::forward(args)...); // Explicit constructor call. + return result; +} -/** \brief Frees memory and destroys multiple allocations. +template +void VmaPoolAllocator::Free(T* ptr) +{ + // Search all memory blocks to find ptr. + for (size_t i = m_ItemBlocks.size(); i--; ) + { + ItemBlock& block = m_ItemBlocks[i]; -Word "pages" is just a suggestion to use this function to free pieces of memory used for sparse binding. -It is just a general purpose function to free memory and destroy allocations made using e.g. vmaAllocateMemory(), -vmaAllocateMemoryPages() and other functions. -It may be internally optimized to be more efficient than calling vmaFreeMemory() `allocationCount` times. + // Casting to union. + Item* pItemPtr; + memcpy(&pItemPtr, &ptr, sizeof(pItemPtr)); -Allocations in `pAllocations` array can come from any memory pools and types. -Passing `VK_NULL_HANDLE` as elements of `pAllocations` array is valid. Such entries are just skipped. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( - VmaAllocator VMA_NOT_NULL allocator, - size_t allocationCount, - const VmaAllocation VMA_NULLABLE * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations); - -/** \brief Returns current information about specified allocation and atomically marks it as used in current frame. - -Current paramteres of given allocation are returned in `pAllocationInfo`. + // Check if pItemPtr is in address range of this block. + if ((pItemPtr >= block.pItems) && (pItemPtr < block.pItems + block.Capacity)) + { + ptr->~T(); // Explicit destructor call. + const uint32_t index = static_cast(pItemPtr - block.pItems); + pItemPtr->NextFreeIndex = block.FirstFreeIndex; + block.FirstFreeIndex = index; + return; + } + } + VMA_ASSERT(0 && "Pointer doesn't belong to this memory pool."); +} -This function also atomically "touches" allocation - marks it as used in current frame, -just like vmaTouchAllocation(). -If the allocation is in lost state, `pAllocationInfo->deviceMemory == VK_NULL_HANDLE`. +template +typename VmaPoolAllocator::ItemBlock& VmaPoolAllocator::CreateNewBlock() +{ + const uint32_t newBlockCapacity = m_ItemBlocks.empty() ? + m_FirstBlockCapacity : m_ItemBlocks.back().Capacity * 3 / 2; -Although this function uses atomics and doesn't lock any mutex, so it should be quite efficient, -you can avoid calling it too often. + const ItemBlock newBlock = + { + vma_new_array(m_pAllocationCallbacks, Item, newBlockCapacity), + newBlockCapacity, + 0 + }; -- You can retrieve same VmaAllocationInfo structure while creating your resource, from function - vmaCreateBuffer(), vmaCreateImage(). You can remember it if you are sure parameters don't change - (e.g. due to defragmentation or allocation becoming lost). -- If you just want to check if allocation is not lost, vmaTouchAllocation() will work faster. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VmaAllocationInfo* VMA_NOT_NULL pAllocationInfo); + m_ItemBlocks.push_back(newBlock); -/** \brief Returns `VK_TRUE` if allocation is not lost and atomically marks it as used in current frame. + // Setup singly-linked list of all free items in this block. + for (uint32_t i = 0; i < newBlockCapacity - 1; ++i) + newBlock.pItems[i].NextFreeIndex = i + 1; + newBlock.pItems[newBlockCapacity - 1].NextFreeIndex = UINT32_MAX; + return m_ItemBlocks.back(); +} +#endif // _VMA_POOL_ALLOCATOR_FUNCTIONS +#endif // _VMA_POOL_ALLOCATOR -If the allocation has been created with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, -this function returns `VK_TRUE` if it's not in lost state, so it can still be used. -It then also atomically "touches" the allocation - marks it as used in current frame, -so that you can be sure it won't become lost in current frame or next `frameInUseCount` frames. +#ifndef _VMA_RAW_LIST +template +struct VmaListItem +{ + VmaListItem* pPrev; + VmaListItem* pNext; + T Value; +}; -If the allocation is in lost state, the function returns `VK_FALSE`. -Memory of such allocation, as well as buffer or image bound to it, should not be used. -Lost allocation and the buffer/image still need to be destroyed. +// Doubly linked list. +template +class VmaRawList +{ + VMA_CLASS_NO_COPY(VmaRawList) +public: + typedef VmaListItem ItemType; -If the allocation has been created without #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, -this function always returns `VK_TRUE`. -*/ -VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaTouchAllocation( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation); + VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks); + // Intentionally not calling Clear, because that would be unnecessary + // computations to return all items to m_ItemAllocator as free. + ~VmaRawList() = default; -/** \brief Sets pUserData in given allocation to new value. + size_t GetCount() const { return m_Count; } + bool IsEmpty() const { return m_Count == 0; } -If the allocation was created with VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT, -pUserData must be either null, or pointer to a null-terminated string. The function -makes local copy of the string and sets it as allocation's `pUserData`. String -passed as pUserData doesn't need to be valid for whole lifetime of the allocation - -you can free it after this call. String previously pointed by allocation's -pUserData is freed from memory. + ItemType* Front() { return m_pFront; } + ItemType* Back() { return m_pBack; } + const ItemType* Front() const { return m_pFront; } + const ItemType* Back() const { return m_pBack; } -If the flag was not used, the value of pointer `pUserData` is just copied to -allocation's `pUserData`. It is opaque, so you can use it however you want - e.g. -as a pointer, ordinal number or some handle to you own data. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - void* VMA_NULLABLE pUserData); + ItemType* PushFront(); + ItemType* PushBack(); + ItemType* PushFront(const T& value); + ItemType* PushBack(const T& value); + void PopFront(); + void PopBack(); -/** \brief Creates new allocation that is in lost state from the beginning. + // Item can be null - it means PushBack. + ItemType* InsertBefore(ItemType* pItem); + // Item can be null - it means PushFront. + ItemType* InsertAfter(ItemType* pItem); + ItemType* InsertBefore(ItemType* pItem, const T& value); + ItemType* InsertAfter(ItemType* pItem, const T& value); -It can be useful if you need a dummy, non-null allocation. + void Clear(); + void Remove(ItemType* pItem); -You still need to destroy created object using vmaFreeMemory(). +private: + const VkAllocationCallbacks* const m_pAllocationCallbacks; + VmaPoolAllocator m_ItemAllocator; + ItemType* m_pFront; + ItemType* m_pBack; + size_t m_Count; +}; -Returned allocation is not tied to any specific memory pool or memory type and -not bound to any image or buffer. It has size = 0. It cannot be turned into -a real, non-empty allocation. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaCreateLostAllocation( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation); +#ifndef _VMA_RAW_LIST_FUNCTIONS +template +VmaRawList::VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks) + : m_pAllocationCallbacks(pAllocationCallbacks), + m_ItemAllocator(pAllocationCallbacks, 128), + m_pFront(VMA_NULL), + m_pBack(VMA_NULL), + m_Count(0) {} -/** \brief Maps memory represented by given allocation and returns pointer to it. +template +VmaListItem* VmaRawList::PushFront() +{ + ItemType* const pNewItem = m_ItemAllocator.Alloc(); + pNewItem->pPrev = VMA_NULL; + if (IsEmpty()) + { + pNewItem->pNext = VMA_NULL; + m_pFront = pNewItem; + m_pBack = pNewItem; + m_Count = 1; + } + else + { + pNewItem->pNext = m_pFront; + m_pFront->pPrev = pNewItem; + m_pFront = pNewItem; + ++m_Count; + } + return pNewItem; +} -Maps memory represented by given allocation to make it accessible to CPU code. -When succeeded, `*ppData` contains pointer to first byte of this memory. -If the allocation is part of bigger `VkDeviceMemory` block, the pointer is -correctly offseted to the beginning of region assigned to this particular -allocation. +template +VmaListItem* VmaRawList::PushBack() +{ + ItemType* const pNewItem = m_ItemAllocator.Alloc(); + pNewItem->pNext = VMA_NULL; + if(IsEmpty()) + { + pNewItem->pPrev = VMA_NULL; + m_pFront = pNewItem; + m_pBack = pNewItem; + m_Count = 1; + } + else + { + pNewItem->pPrev = m_pBack; + m_pBack->pNext = pNewItem; + m_pBack = pNewItem; + ++m_Count; + } + return pNewItem; +} -Mapping is internally reference-counted and synchronized, so despite raw Vulkan -function `vkMapMemory()` cannot be used to map same block of `VkDeviceMemory` -multiple times simultaneously, it is safe to call this function on allocations -assigned to the same memory block. Actual Vulkan memory will be mapped on first -mapping and unmapped on last unmapping. +template +VmaListItem* VmaRawList::PushFront(const T& value) +{ + ItemType* const pNewItem = PushFront(); + pNewItem->Value = value; + return pNewItem; +} -If the function succeeded, you must call vmaUnmapMemory() to unmap the -allocation when mapping is no longer needed or before freeing the allocation, at -the latest. +template +VmaListItem* VmaRawList::PushBack(const T& value) +{ + ItemType* const pNewItem = PushBack(); + pNewItem->Value = value; + return pNewItem; +} -It also safe to call this function multiple times on the same allocation. You -must call vmaUnmapMemory() same number of times as you called vmaMapMemory(). +template +void VmaRawList::PopFront() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const pFrontItem = m_pFront; + ItemType* const pNextItem = pFrontItem->pNext; + if (pNextItem != VMA_NULL) + { + pNextItem->pPrev = VMA_NULL; + } + m_pFront = pNextItem; + m_ItemAllocator.Free(pFrontItem); + --m_Count; +} -It is also safe to call this function on allocation created with -#VMA_ALLOCATION_CREATE_MAPPED_BIT flag. Its memory stays mapped all the time. -You must still call vmaUnmapMemory() same number of times as you called -vmaMapMemory(). You must not call vmaUnmapMemory() additional time to free the -"0-th" mapping made automatically due to #VMA_ALLOCATION_CREATE_MAPPED_BIT flag. +template +void VmaRawList::PopBack() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const pBackItem = m_pBack; + ItemType* const pPrevItem = pBackItem->pPrev; + if(pPrevItem != VMA_NULL) + { + pPrevItem->pNext = VMA_NULL; + } + m_pBack = pPrevItem; + m_ItemAllocator.Free(pBackItem); + --m_Count; +} -This function fails when used on allocation made in memory type that is not -`HOST_VISIBLE`. +template +void VmaRawList::Clear() +{ + if (IsEmpty() == false) + { + ItemType* pItem = m_pBack; + while (pItem != VMA_NULL) + { + ItemType* const pPrevItem = pItem->pPrev; + m_ItemAllocator.Free(pItem); + pItem = pPrevItem; + } + m_pFront = VMA_NULL; + m_pBack = VMA_NULL; + m_Count = 0; + } +} -This function always fails when called for allocation that was created with -#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocations cannot be -mapped. +template +void VmaRawList::Remove(ItemType* pItem) +{ + VMA_HEAVY_ASSERT(pItem != VMA_NULL); + VMA_HEAVY_ASSERT(m_Count > 0); -This function doesn't automatically flush or invalidate caches. -If the allocation is made from a memory types that is not `HOST_COHERENT`, -you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - void* VMA_NULLABLE * VMA_NOT_NULL ppData); + if(pItem->pPrev != VMA_NULL) + { + pItem->pPrev->pNext = pItem->pNext; + } + else + { + VMA_HEAVY_ASSERT(m_pFront == pItem); + m_pFront = pItem->pNext; + } -/** \brief Unmaps memory represented by given allocation, mapped previously using vmaMapMemory(). + if(pItem->pNext != VMA_NULL) + { + pItem->pNext->pPrev = pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(m_pBack == pItem); + m_pBack = pItem->pPrev; + } -For details, see description of vmaMapMemory(). + m_ItemAllocator.Free(pItem); + --m_Count; +} -This function doesn't automatically flush or invalidate caches. -If the allocation is made from a memory types that is not `HOST_COHERENT`, -you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation); - -/** \brief Flushes memory of given allocation. +template +VmaListItem* VmaRawList::InsertBefore(ItemType* pItem) +{ + if(pItem != VMA_NULL) + { + ItemType* const prevItem = pItem->pPrev; + ItemType* const newItem = m_ItemAllocator.Alloc(); + newItem->pPrev = prevItem; + newItem->pNext = pItem; + pItem->pPrev = newItem; + if(prevItem != VMA_NULL) + { + prevItem->pNext = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_pFront == pItem); + m_pFront = newItem; + } + ++m_Count; + return newItem; + } + else + return PushBack(); +} -Calls `vkFlushMappedMemoryRanges()` for memory associated with given range of given allocation. -It needs to be called after writing to a mapped memory for memory types that are not `HOST_COHERENT`. -Unmap operation doesn't do that automatically. +template +VmaListItem* VmaRawList::InsertAfter(ItemType* pItem) +{ + if(pItem != VMA_NULL) + { + ItemType* const nextItem = pItem->pNext; + ItemType* const newItem = m_ItemAllocator.Alloc(); + newItem->pNext = nextItem; + newItem->pPrev = pItem; + pItem->pNext = newItem; + if(nextItem != VMA_NULL) + { + nextItem->pPrev = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_pBack == pItem); + m_pBack = newItem; + } + ++m_Count; + return newItem; + } + else + return PushFront(); +} -- `offset` must be relative to the beginning of allocation. -- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. -- `offset` and `size` don't have to be aligned. - They are internally rounded down/up to multiply of `nonCoherentAtomSize`. -- If `size` is 0, this call is ignored. -- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, - this call is ignored. +template +VmaListItem* VmaRawList::InsertBefore(ItemType* pItem, const T& value) +{ + ItemType* const newItem = InsertBefore(pItem); + newItem->Value = value; + return newItem; +} -Warning! `offset` and `size` are relative to the contents of given `allocation`. -If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. -Do not pass allocation's offset as `offset`!!! +template +VmaListItem* VmaRawList::InsertAfter(ItemType* pItem, const T& value) +{ + ItemType* const newItem = InsertAfter(pItem); + newItem->Value = value; + return newItem; +} +#endif // _VMA_RAW_LIST_FUNCTIONS +#endif // _VMA_RAW_LIST -This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is -called, otherwise `VK_SUCCESS`. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkDeviceSize offset, - VkDeviceSize size); +#ifndef _VMA_LIST +template +class VmaList +{ + VMA_CLASS_NO_COPY(VmaList) +public: + class reverse_iterator; + class const_iterator; + class const_reverse_iterator; -/** \brief Invalidates memory of given allocation. + class iterator + { + friend class const_iterator; + friend class VmaList; + public: + iterator() : m_pList(VMA_NULL), m_pItem(VMA_NULL) {} + iterator(const reverse_iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} -Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given range of given allocation. -It needs to be called before reading from a mapped memory for memory types that are not `HOST_COHERENT`. -Map operation doesn't do that automatically. + T& operator*() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return m_pItem->Value; } + T* operator->() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return &m_pItem->Value; } -- `offset` must be relative to the beginning of allocation. -- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. -- `offset` and `size` don't have to be aligned. - They are internally rounded down/up to multiply of `nonCoherentAtomSize`. -- If `size` is 0, this call is ignored. -- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, - this call is ignored. + bool operator==(const iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem == rhs.m_pItem; } + bool operator!=(const iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem != rhs.m_pItem; } -Warning! `offset` and `size` are relative to the contents of given `allocation`. -If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. -Do not pass allocation's offset as `offset`!!! + iterator operator++(int) { iterator result = *this; ++*this; return result; } + iterator operator--(int) { iterator result = *this; --*this; return result; } -This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if -it is called, otherwise `VK_SUCCESS`. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkDeviceSize offset, - VkDeviceSize size); + iterator& operator++() { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); m_pItem = m_pItem->pNext; return *this; } + iterator& operator--(); -/** \brief Flushes memory of given set of allocations. + private: + VmaRawList* m_pList; + VmaListItem* m_pItem; -Calls `vkFlushMappedMemoryRanges()` for memory associated with given ranges of given allocations. -For more information, see documentation of vmaFlushAllocation(). + iterator(VmaRawList* pList, VmaListItem* pItem) : m_pList(pList), m_pItem(pItem) {} + }; + class reverse_iterator + { + friend class const_reverse_iterator; + friend class VmaList; + public: + reverse_iterator() : m_pList(VMA_NULL), m_pItem(VMA_NULL) {} + reverse_iterator(const iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} -\param allocator -\param allocationCount -\param allocations -\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. -\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. + T& operator*() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return m_pItem->Value; } + T* operator->() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return &m_pItem->Value; } -This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is -called, otherwise `VK_SUCCESS`. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t allocationCount, - const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, - const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, - const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); + bool operator==(const reverse_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem == rhs.m_pItem; } + bool operator!=(const reverse_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem != rhs.m_pItem; } -/** \brief Invalidates memory of given set of allocations. + reverse_iterator operator++(int) { reverse_iterator result = *this; ++* this; return result; } + reverse_iterator operator--(int) { reverse_iterator result = *this; --* this; return result; } -Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given ranges of given allocations. -For more information, see documentation of vmaInvalidateAllocation(). + reverse_iterator& operator++() { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); m_pItem = m_pItem->pPrev; return *this; } + reverse_iterator& operator--(); -\param allocator -\param allocationCount -\param allocations -\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. -\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. + private: + VmaRawList* m_pList; + VmaListItem* m_pItem; -This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if it is -called, otherwise `VK_SUCCESS`. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( - VmaAllocator VMA_NOT_NULL allocator, - uint32_t allocationCount, - const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, - const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, - const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); + reverse_iterator(VmaRawList* pList, VmaListItem* pItem) : m_pList(pList), m_pItem(pItem) {} + }; + class const_iterator + { + friend class VmaList; + public: + const_iterator() : m_pList(VMA_NULL), m_pItem(VMA_NULL) {} + const_iterator(const iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} + const_iterator(const reverse_iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} -/** \brief Checks magic number in margins around all allocations in given memory types (in both default and custom pools) in search for corruptions. + iterator drop_const() { return { const_cast*>(m_pList), const_cast*>(m_pItem) }; } -@param memoryTypeBits Bit mask, where each bit set means that a memory type with that index should be checked. + const T& operator*() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return m_pItem->Value; } + const T* operator->() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return &m_pItem->Value; } -Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, -`VMA_DEBUG_MARGIN` is defined to nonzero and only for memory types that are -`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). + bool operator==(const const_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem == rhs.m_pItem; } + bool operator!=(const const_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem != rhs.m_pItem; } -Possible return values: + const_iterator operator++(int) { const_iterator result = *this; ++* this; return result; } + const_iterator operator--(int) { const_iterator result = *this; --* this; return result; } -- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for any of specified memory types. -- `VK_SUCCESS` - corruption detection has been performed and succeeded. -- `VK_ERROR_VALIDATION_FAILED_EXT` - corruption detection has been performed and found memory corruptions around one of the allocations. - `VMA_ASSERT` is also fired in that case. -- Other value: Error returned by Vulkan, e.g. memory mapping failure. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption(VmaAllocator VMA_NOT_NULL allocator, uint32_t memoryTypeBits); + const_iterator& operator++() { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); m_pItem = m_pItem->pNext; return *this; } + const_iterator& operator--(); -/** \struct VmaDefragmentationContext -\brief Represents Opaque object that represents started defragmentation process. + private: + const VmaRawList* m_pList; + const VmaListItem* m_pItem; -Fill structure #VmaDefragmentationInfo2 and call function vmaDefragmentationBegin() to create it. -Call function vmaDefragmentationEnd() to destroy it. -*/ -VK_DEFINE_HANDLE(VmaDefragmentationContext) + const_iterator(const VmaRawList* pList, const VmaListItem* pItem) : m_pList(pList), m_pItem(pItem) {} + }; + class const_reverse_iterator + { + friend class VmaList; + public: + const_reverse_iterator() : m_pList(VMA_NULL), m_pItem(VMA_NULL) {} + const_reverse_iterator(const reverse_iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} + const_reverse_iterator(const iterator& src) : m_pList(src.m_pList), m_pItem(src.m_pItem) {} -/// Flags to be used in vmaDefragmentationBegin(). None at the moment. Reserved for future use. -typedef enum VmaDefragmentationFlagBits { - VMA_DEFRAGMENTATION_FLAG_INCREMENTAL = 0x1, - VMA_DEFRAGMENTATION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF -} VmaDefragmentationFlagBits; -typedef VkFlags VmaDefragmentationFlags; + reverse_iterator drop_const() { return { const_cast*>(m_pList), const_cast*>(m_pItem) }; } -/** \brief Parameters for defragmentation. + const T& operator*() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return m_pItem->Value; } + const T* operator->() const { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); return &m_pItem->Value; } -To be used with function vmaDefragmentationBegin(). -*/ -typedef struct VmaDefragmentationInfo2 { - /** \brief Reserved for future use. Should be 0. - */ - VmaDefragmentationFlags flags; - /** \brief Number of allocations in `pAllocations` array. - */ - uint32_t allocationCount; - /** \brief Pointer to array of allocations that can be defragmented. + bool operator==(const const_reverse_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem == rhs.m_pItem; } + bool operator!=(const const_reverse_iterator& rhs) const { VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); return m_pItem != rhs.m_pItem; } - The array should have `allocationCount` elements. - The array should not contain nulls. - Elements in the array should be unique - same allocation cannot occur twice. - It is safe to pass allocations that are in the lost state - they are ignored. - All allocations not present in this array are considered non-moveable during this defragmentation. - */ - const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations; - /** \brief Optional, output. Pointer to array that will be filled with information whether the allocation at certain index has been changed during defragmentation. + const_reverse_iterator operator++(int) { const_reverse_iterator result = *this; ++* this; return result; } + const_reverse_iterator operator--(int) { const_reverse_iterator result = *this; --* this; return result; } - The array should have `allocationCount` elements. - You can pass null if you are not interested in this information. - */ - VkBool32* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationsChanged; - /** \brief Numer of pools in `pPools` array. - */ - uint32_t poolCount; - /** \brief Either null or pointer to array of pools to be defragmented. + const_reverse_iterator& operator++() { VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); m_pItem = m_pItem->pPrev; return *this; } + const_reverse_iterator& operator--(); - All the allocations in the specified pools can be moved during defragmentation - and there is no way to check if they were really moved as in `pAllocationsChanged`, - so you must query all the allocations in all these pools for new `VkDeviceMemory` - and offset using vmaGetAllocationInfo() if you might need to recreate buffers - and images bound to them. + private: + const VmaRawList* m_pList; + const VmaListItem* m_pItem; - The array should have `poolCount` elements. - The array should not contain nulls. - Elements in the array should be unique - same pool cannot occur twice. + const_reverse_iterator(const VmaRawList* pList, const VmaListItem* pItem) : m_pList(pList), m_pItem(pItem) {} + }; - Using this array is equivalent to specifying all allocations from the pools in `pAllocations`. - It might be more efficient. - */ - const VmaPool VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(poolCount) pPools; - /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places using transfers on CPU side, like `memcpy()`, `memmove()`. + VmaList(const AllocatorT& allocator) : m_RawList(allocator.m_pCallbacks) {} - `VK_WHOLE_SIZE` means no limit. - */ - VkDeviceSize maxCpuBytesToMove; - /** \brief Maximum number of allocations that can be moved to a different place using transfers on CPU side, like `memcpy()`, `memmove()`. + bool empty() const { return m_RawList.IsEmpty(); } + size_t size() const { return m_RawList.GetCount(); } - `UINT32_MAX` means no limit. - */ - uint32_t maxCpuAllocationsToMove; - /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places using transfers on GPU side, posted to `commandBuffer`. + iterator begin() { return iterator(&m_RawList, m_RawList.Front()); } + iterator end() { return iterator(&m_RawList, VMA_NULL); } - `VK_WHOLE_SIZE` means no limit. - */ - VkDeviceSize maxGpuBytesToMove; - /** \brief Maximum number of allocations that can be moved to a different place using transfers on GPU side, posted to `commandBuffer`. + const_iterator cbegin() const { return const_iterator(&m_RawList, m_RawList.Front()); } + const_iterator cend() const { return const_iterator(&m_RawList, VMA_NULL); } - `UINT32_MAX` means no limit. - */ - uint32_t maxGpuAllocationsToMove; - /** \brief Optional. Command buffer where GPU copy commands will be posted. + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } - If not null, it must be a valid command buffer handle that supports Transfer queue type. - It must be in the recording state and outside of a render pass instance. - You need to submit it and make sure it finished execution before calling vmaDefragmentationEnd(). + reverse_iterator rbegin() { return reverse_iterator(&m_RawList, m_RawList.Back()); } + reverse_iterator rend() { return reverse_iterator(&m_RawList, VMA_NULL); } - Passing null means that only CPU defragmentation will be performed. - */ - VkCommandBuffer VMA_NULLABLE commandBuffer; -} VmaDefragmentationInfo2; + const_reverse_iterator crbegin() const { return const_reverse_iterator(&m_RawList, m_RawList.Back()); } + const_reverse_iterator crend() const { return const_reverse_iterator(&m_RawList, VMA_NULL); } -typedef struct VmaDefragmentationPassMoveInfo { - VmaAllocation VMA_NOT_NULL allocation; - VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory; - VkDeviceSize offset; -} VmaDefragmentationPassMoveInfo; + const_reverse_iterator rbegin() const { return crbegin(); } + const_reverse_iterator rend() const { return crend(); } -/** \brief Parameters for incremental defragmentation steps. + void push_back(const T& value) { m_RawList.PushBack(value); } + iterator insert(iterator it, const T& value) { return iterator(&m_RawList, m_RawList.InsertBefore(it.m_pItem, value)); } -To be used with function vmaBeginDefragmentationPass(). -*/ -typedef struct VmaDefragmentationPassInfo { - uint32_t moveCount; - VmaDefragmentationPassMoveInfo* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(moveCount) pMoves; -} VmaDefragmentationPassInfo; + void clear() { m_RawList.Clear(); } + void erase(iterator it) { m_RawList.Remove(it.m_pItem); } -/** \brief Deprecated. Optional configuration parameters to be passed to function vmaDefragment(). +private: + VmaRawList m_RawList; +}; -\deprecated This is a part of the old interface. It is recommended to use structure #VmaDefragmentationInfo2 and function vmaDefragmentationBegin() instead. -*/ -typedef struct VmaDefragmentationInfo { - /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places. +#ifndef _VMA_LIST_FUNCTIONS +template +typename VmaList::iterator& VmaList::iterator::operator--() +{ + if (m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Back(); + } + return *this; +} - Default is `VK_WHOLE_SIZE`, which means no limit. - */ - VkDeviceSize maxBytesToMove; - /** \brief Maximum number of allocations that can be moved to different place. +template +typename VmaList::reverse_iterator& VmaList::reverse_iterator::operator--() +{ + if (m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pNext; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Front(); + } + return *this; +} - Default is `UINT32_MAX`, which means no limit. - */ - uint32_t maxAllocationsToMove; -} VmaDefragmentationInfo; +template +typename VmaList::const_iterator& VmaList::const_iterator::operator--() +{ + if (m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Back(); + } + return *this; +} -/** \brief Statistics returned by function vmaDefragment(). */ -typedef struct VmaDefragmentationStats { - /// Total number of bytes that have been copied while moving allocations to different places. - VkDeviceSize bytesMoved; - /// Total number of bytes that have been released to the system by freeing empty `VkDeviceMemory` objects. - VkDeviceSize bytesFreed; - /// Number of allocations that have been moved to different places. - uint32_t allocationsMoved; - /// Number of empty `VkDeviceMemory` objects that have been released to the system. - uint32_t deviceMemoryBlocksFreed; -} VmaDefragmentationStats; - -/** \brief Begins defragmentation process. +template +typename VmaList::const_reverse_iterator& VmaList::const_reverse_iterator::operator--() +{ + if (m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pNext; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Back(); + } + return *this; +} +#endif // _VMA_LIST_FUNCTIONS +#endif // _VMA_LIST -@param allocator Allocator object. -@param pInfo Structure filled with parameters of defragmentation. -@param[out] pStats Optional. Statistics of defragmentation. You can pass null if you are not interested in this information. -@param[out] pContext Context object that must be passed to vmaDefragmentationEnd() to finish defragmentation. -@return `VK_SUCCESS` and `*pContext == null` if defragmentation finished within this function call. `VK_NOT_READY` and `*pContext != null` if defragmentation has been started and you need to call vmaDefragmentationEnd() to finish it. Negative value in case of error. - -Use this function instead of old, deprecated vmaDefragment(). - -Warning! Between the call to vmaDefragmentationBegin() and vmaDefragmentationEnd(): - -- You should not use any of allocations passed as `pInfo->pAllocations` or - any allocations that belong to pools passed as `pInfo->pPools`, - including calling vmaGetAllocationInfo(), vmaTouchAllocation(), or access - their data. -- Some mutexes protecting internal data structures may be locked, so trying to - make or free any allocations, bind buffers or images, map memory, or launch - another simultaneous defragmentation in between may cause stall (when done on - another thread) or deadlock (when done on the same thread), unless you are - 100% sure that defragmented allocations are in different pools. -- Information returned via `pStats` and `pInfo->pAllocationsChanged` are undefined. - They become valid after call to vmaDefragmentationEnd(). -- If `pInfo->commandBuffer` is not null, you must submit that command buffer - and make sure it finished execution before calling vmaDefragmentationEnd(). - -For more information and important limitations regarding defragmentation, see documentation chapter: -[Defragmentation](@ref defragmentation). +#ifndef _VMA_INTRUSIVE_LINKED_LIST +/* +Expected interface of ItemTypeTraits: +struct MyItemTypeTraits +{ + typedef MyItem ItemType; + static ItemType* GetPrev(const ItemType* item) { return item->myPrevPtr; } + static ItemType* GetNext(const ItemType* item) { return item->myNextPtr; } + static ItemType*& AccessPrev(ItemType* item) { return item->myPrevPtr; } + static ItemType*& AccessNext(ItemType* item) { return item->myNextPtr; } +}; */ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationBegin( - VmaAllocator VMA_NOT_NULL allocator, - const VmaDefragmentationInfo2* VMA_NOT_NULL pInfo, - VmaDefragmentationStats* VMA_NULLABLE pStats, - VmaDefragmentationContext VMA_NULLABLE * VMA_NOT_NULL pContext); +template +class VmaIntrusiveLinkedList +{ +public: + typedef typename ItemTypeTraits::ItemType ItemType; + static ItemType* GetPrev(const ItemType* item) { return ItemTypeTraits::GetPrev(item); } + static ItemType* GetNext(const ItemType* item) { return ItemTypeTraits::GetNext(item); } + + // Movable, not copyable. + VmaIntrusiveLinkedList() = default; + VmaIntrusiveLinkedList(VmaIntrusiveLinkedList && src); + VmaIntrusiveLinkedList(const VmaIntrusiveLinkedList&) = delete; + VmaIntrusiveLinkedList& operator=(VmaIntrusiveLinkedList&& src); + VmaIntrusiveLinkedList& operator=(const VmaIntrusiveLinkedList&) = delete; + ~VmaIntrusiveLinkedList() { VMA_HEAVY_ASSERT(IsEmpty()); } + + size_t GetCount() const { return m_Count; } + bool IsEmpty() const { return m_Count == 0; } + ItemType* Front() { return m_Front; } + ItemType* Back() { return m_Back; } + const ItemType* Front() const { return m_Front; } + const ItemType* Back() const { return m_Back; } + + void PushBack(ItemType* item); + void PushFront(ItemType* item); + ItemType* PopBack(); + ItemType* PopFront(); + + // MyItem can be null - it means PushBack. + void InsertBefore(ItemType* existingItem, ItemType* newItem); + // MyItem can be null - it means PushFront. + void InsertAfter(ItemType* existingItem, ItemType* newItem); + void Remove(ItemType* item); + void RemoveAll(); -/** \brief Ends defragmentation process. +private: + ItemType* m_Front = VMA_NULL; + ItemType* m_Back = VMA_NULL; + size_t m_Count = 0; +}; -Use this function to finish defragmentation started by vmaDefragmentationBegin(). -It is safe to pass `context == null`. The function then does nothing. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationEnd( - VmaAllocator VMA_NOT_NULL allocator, - VmaDefragmentationContext VMA_NULLABLE context); +#ifndef _VMA_INTRUSIVE_LINKED_LIST_FUNCTIONS +template +VmaIntrusiveLinkedList::VmaIntrusiveLinkedList(VmaIntrusiveLinkedList&& src) + : m_Front(src.m_Front), m_Back(src.m_Back), m_Count(src.m_Count) +{ + src.m_Front = src.m_Back = VMA_NULL; + src.m_Count = 0; +} -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( - VmaAllocator VMA_NOT_NULL allocator, - VmaDefragmentationContext VMA_NULLABLE context, - VmaDefragmentationPassInfo* VMA_NOT_NULL pInfo -); -VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( - VmaAllocator VMA_NOT_NULL allocator, - VmaDefragmentationContext VMA_NULLABLE context -); - -/** \brief Deprecated. Compacts memory by moving allocations. - -@param pAllocations Array of allocations that can be moved during this compation. -@param allocationCount Number of elements in pAllocations and pAllocationsChanged arrays. -@param[out] pAllocationsChanged Array of boolean values that will indicate whether matching allocation in pAllocations array has been moved. This parameter is optional. Pass null if you don't need this information. -@param pDefragmentationInfo Configuration parameters. Optional - pass null to use default values. -@param[out] pDefragmentationStats Statistics returned by the function. Optional - pass null if you don't need this information. -@return `VK_SUCCESS` if completed, negative error code in case of error. - -\deprecated This is a part of the old interface. It is recommended to use structure #VmaDefragmentationInfo2 and function vmaDefragmentationBegin() instead. - -This function works by moving allocations to different places (different -`VkDeviceMemory` objects and/or different offsets) in order to optimize memory -usage. Only allocations that are in `pAllocations` array can be moved. All other -allocations are considered nonmovable in this call. Basic rules: - -- Only allocations made in memory types that have - `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` and `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT` - flags can be compacted. You may pass other allocations but it makes no sense - - these will never be moved. -- Custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT or - #VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT flag are not defragmented. Allocations - passed to this function that come from such pools are ignored. -- Allocations created with #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT or - created as dedicated allocations for any other reason are also ignored. -- Both allocations made with or without #VMA_ALLOCATION_CREATE_MAPPED_BIT - flag can be compacted. If not persistently mapped, memory will be mapped - temporarily inside this function if needed. -- You must not pass same #VmaAllocation object multiple times in `pAllocations` array. - -The function also frees empty `VkDeviceMemory` blocks. - -Warning: This function may be time-consuming, so you shouldn't call it too often -(like after every resource creation/destruction). -You can call it on special occasions (like when reloading a game level or -when you just destroyed a lot of objects). Calling it every frame may be OK, but -you should measure that on your platform. - -For more information, see [Defragmentation](@ref defragmentation) chapter. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragment( - VmaAllocator VMA_NOT_NULL allocator, - const VmaAllocation VMA_NOT_NULL * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations, - size_t allocationCount, - VkBool32* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationsChanged, - const VmaDefragmentationInfo* VMA_NULLABLE pDefragmentationInfo, - VmaDefragmentationStats* VMA_NULLABLE pDefragmentationStats); +template +VmaIntrusiveLinkedList& VmaIntrusiveLinkedList::operator=(VmaIntrusiveLinkedList&& src) +{ + if (&src != this) + { + VMA_HEAVY_ASSERT(IsEmpty()); + m_Front = src.m_Front; + m_Back = src.m_Back; + m_Count = src.m_Count; + src.m_Front = src.m_Back = VMA_NULL; + src.m_Count = 0; + } + return *this; +} -/** \brief Binds buffer to allocation. +template +void VmaIntrusiveLinkedList::PushBack(ItemType* item) +{ + VMA_HEAVY_ASSERT(ItemTypeTraits::GetPrev(item) == VMA_NULL && ItemTypeTraits::GetNext(item) == VMA_NULL); + if (IsEmpty()) + { + m_Front = item; + m_Back = item; + m_Count = 1; + } + else + { + ItemTypeTraits::AccessPrev(item) = m_Back; + ItemTypeTraits::AccessNext(m_Back) = item; + m_Back = item; + ++m_Count; + } +} -Binds specified buffer to region of memory represented by specified allocation. -Gets `VkDeviceMemory` handle and offset from the allocation. -If you want to create a buffer, allocate memory for it and bind them together separately, -you should use this function for binding instead of standard `vkBindBufferMemory()`, -because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple -allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously -(which is illegal in Vulkan). +template +void VmaIntrusiveLinkedList::PushFront(ItemType* item) +{ + VMA_HEAVY_ASSERT(ItemTypeTraits::GetPrev(item) == VMA_NULL && ItemTypeTraits::GetNext(item) == VMA_NULL); + if (IsEmpty()) + { + m_Front = item; + m_Back = item; + m_Count = 1; + } + else + { + ItemTypeTraits::AccessNext(item) = m_Front; + ItemTypeTraits::AccessPrev(m_Front) = item; + m_Front = item; + ++m_Count; + } +} -It is recommended to use function vmaCreateBuffer() instead of this one. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer); +template +typename VmaIntrusiveLinkedList::ItemType* VmaIntrusiveLinkedList::PopBack() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const backItem = m_Back; + ItemType* const prevItem = ItemTypeTraits::GetPrev(backItem); + if (prevItem != VMA_NULL) + { + ItemTypeTraits::AccessNext(prevItem) = VMA_NULL; + } + m_Back = prevItem; + --m_Count; + ItemTypeTraits::AccessPrev(backItem) = VMA_NULL; + ItemTypeTraits::AccessNext(backItem) = VMA_NULL; + return backItem; +} -/** \brief Binds buffer to allocation with additional parameters. +template +typename VmaIntrusiveLinkedList::ItemType* VmaIntrusiveLinkedList::PopFront() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const frontItem = m_Front; + ItemType* const nextItem = ItemTypeTraits::GetNext(frontItem); + if (nextItem != VMA_NULL) + { + ItemTypeTraits::AccessPrev(nextItem) = VMA_NULL; + } + m_Front = nextItem; + --m_Count; + ItemTypeTraits::AccessPrev(frontItem) = VMA_NULL; + ItemTypeTraits::AccessNext(frontItem) = VMA_NULL; + return frontItem; +} -@param allocationLocalOffset Additional offset to be added while binding, relative to the beginnig of the `allocation`. Normally it should be 0. -@param pNext A chain of structures to be attached to `VkBindBufferMemoryInfoKHR` structure used internally. Normally it should be null. +template +void VmaIntrusiveLinkedList::InsertBefore(ItemType* existingItem, ItemType* newItem) +{ + VMA_HEAVY_ASSERT(newItem != VMA_NULL && ItemTypeTraits::GetPrev(newItem) == VMA_NULL && ItemTypeTraits::GetNext(newItem) == VMA_NULL); + if (existingItem != VMA_NULL) + { + ItemType* const prevItem = ItemTypeTraits::GetPrev(existingItem); + ItemTypeTraits::AccessPrev(newItem) = prevItem; + ItemTypeTraits::AccessNext(newItem) = existingItem; + ItemTypeTraits::AccessPrev(existingItem) = newItem; + if (prevItem != VMA_NULL) + { + ItemTypeTraits::AccessNext(prevItem) = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_Front == existingItem); + m_Front = newItem; + } + ++m_Count; + } + else + PushBack(newItem); +} -This function is similar to vmaBindBufferMemory(), but it provides additional parameters. +template +void VmaIntrusiveLinkedList::InsertAfter(ItemType* existingItem, ItemType* newItem) +{ + VMA_HEAVY_ASSERT(newItem != VMA_NULL && ItemTypeTraits::GetPrev(newItem) == VMA_NULL && ItemTypeTraits::GetNext(newItem) == VMA_NULL); + if (existingItem != VMA_NULL) + { + ItemType* const nextItem = ItemTypeTraits::GetNext(existingItem); + ItemTypeTraits::AccessNext(newItem) = nextItem; + ItemTypeTraits::AccessPrev(newItem) = existingItem; + ItemTypeTraits::AccessNext(existingItem) = newItem; + if (nextItem != VMA_NULL) + { + ItemTypeTraits::AccessPrev(nextItem) = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_Back == existingItem); + m_Back = newItem; + } + ++m_Count; + } + else + return PushFront(newItem); +} -If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag -or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkDeviceSize allocationLocalOffset, - VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, - const void* VMA_NULLABLE pNext); +template +void VmaIntrusiveLinkedList::Remove(ItemType* item) +{ + VMA_HEAVY_ASSERT(item != VMA_NULL && m_Count > 0); + if (ItemTypeTraits::GetPrev(item) != VMA_NULL) + { + ItemTypeTraits::AccessNext(ItemTypeTraits::AccessPrev(item)) = ItemTypeTraits::GetNext(item); + } + else + { + VMA_HEAVY_ASSERT(m_Front == item); + m_Front = ItemTypeTraits::GetNext(item); + } -/** \brief Binds image to allocation. + if (ItemTypeTraits::GetNext(item) != VMA_NULL) + { + ItemTypeTraits::AccessPrev(ItemTypeTraits::AccessNext(item)) = ItemTypeTraits::GetPrev(item); + } + else + { + VMA_HEAVY_ASSERT(m_Back == item); + m_Back = ItemTypeTraits::GetPrev(item); + } + ItemTypeTraits::AccessPrev(item) = VMA_NULL; + ItemTypeTraits::AccessNext(item) = VMA_NULL; + --m_Count; +} -Binds specified image to region of memory represented by specified allocation. -Gets `VkDeviceMemory` handle and offset from the allocation. -If you want to create an image, allocate memory for it and bind them together separately, -you should use this function for binding instead of standard `vkBindImageMemory()`, -because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple -allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously -(which is illegal in Vulkan). +template +void VmaIntrusiveLinkedList::RemoveAll() +{ + if (!IsEmpty()) + { + ItemType* item = m_Back; + while (item != VMA_NULL) + { + ItemType* const prevItem = ItemTypeTraits::AccessPrev(item); + ItemTypeTraits::AccessPrev(item) = VMA_NULL; + ItemTypeTraits::AccessNext(item) = VMA_NULL; + item = prevItem; + } + m_Front = VMA_NULL; + m_Back = VMA_NULL; + m_Count = 0; + } +} +#endif // _VMA_INTRUSIVE_LINKED_LIST_FUNCTIONS +#endif // _VMA_INTRUSIVE_LINKED_LIST -It is recommended to use function vmaCreateImage() instead of this one. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkImage VMA_NOT_NULL_NON_DISPATCHABLE image); +// Unused in this version. +#if 0 -/** \brief Binds image to allocation with additional parameters. +#ifndef _VMA_PAIR +template +struct VmaPair +{ + T1 first; + T2 second; -@param allocationLocalOffset Additional offset to be added while binding, relative to the beginnig of the `allocation`. Normally it should be 0. -@param pNext A chain of structures to be attached to `VkBindImageMemoryInfoKHR` structure used internally. Normally it should be null. + VmaPair() : first(), second() {} + VmaPair(const T1& firstSrc, const T2& secondSrc) : first(firstSrc), second(secondSrc) {} +}; -This function is similar to vmaBindImageMemory(), but it provides additional parameters. +template +struct VmaPairFirstLess +{ + bool operator()(const VmaPair& lhs, const VmaPair& rhs) const + { + return lhs.first < rhs.first; + } + bool operator()(const VmaPair& lhs, const FirstT& rhsFirst) const + { + return lhs.first < rhsFirst; + } +}; +#endif // _VMA_PAIR -If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag -or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. +#ifndef _VMA_MAP +/* Class compatible with subset of interface of std::unordered_map. +KeyT, ValueT must be POD because they will be stored in VmaVector. */ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( - VmaAllocator VMA_NOT_NULL allocator, - VmaAllocation VMA_NOT_NULL allocation, - VkDeviceSize allocationLocalOffset, - VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, - const void* VMA_NULLABLE pNext); +template +class VmaMap +{ +public: + typedef VmaPair PairType; + typedef PairType* iterator; -/** -@param[out] pBuffer Buffer that was created. -@param[out] pAllocation Allocation that was created. -@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). + VmaMap(const VmaStlAllocator& allocator) : m_Vector(allocator) {} -This function automatically: + iterator begin() { return m_Vector.begin(); } + iterator end() { return m_Vector.end(); } + size_t size() { return m_Vector.size(); } --# Creates buffer. --# Allocates appropriate memory for it. --# Binds the buffer with the memory. + void insert(const PairType& pair); + iterator find(const KeyT& key); + void erase(iterator it); -If any of these operations fail, buffer and allocation are not created, -returned value is negative error code, *pBuffer and *pAllocation are null. +private: + VmaVector< PairType, VmaStlAllocator> m_Vector; +}; -If the function succeeded, you must destroy both buffer and allocation when you -no longer need them using either convenience function vmaDestroyBuffer() or -separately, using `vkDestroyBuffer()` and vmaFreeMemory(). +#ifndef _VMA_MAP_FUNCTIONS +template +void VmaMap::insert(const PairType& pair) +{ + const size_t indexToInsert = VmaBinaryFindFirstNotLess( + m_Vector.data(), + m_Vector.data() + m_Vector.size(), + pair, + VmaPairFirstLess()) - m_Vector.data(); + VmaVectorInsert(m_Vector, indexToInsert, pair); +} -If #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag was used, -VK_KHR_dedicated_allocation extension is used internally to query driver whether -it requires or prefers the new buffer to have dedicated allocation. If yes, -and if dedicated allocation is possible (VmaAllocationCreateInfo::pool is null -and #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT is not used), it creates dedicated -allocation for this buffer, just like when using -#VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +template +VmaPair* VmaMap::find(const KeyT& key) +{ + PairType* it = VmaBinaryFindFirstNotLess( + m_Vector.data(), + m_Vector.data() + m_Vector.size(), + key, + VmaPairFirstLess()); + if ((it != m_Vector.end()) && (it->first == key)) + { + return it; + } + else + { + return m_Vector.end(); + } +} -\note This function creates a new `VkBuffer`. Sub-allocation of parts of one large buffer, -although recommended as a good practice, is out of scope of this library and could be implemented -by the user as a higher-level logic on top of VMA. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( - VmaAllocator VMA_NOT_NULL allocator, - const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, - const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, - VkBuffer VMA_NULLABLE_NON_DISPATCHABLE * VMA_NOT_NULL pBuffer, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, - VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); +template +void VmaMap::erase(iterator it) +{ + VmaVectorRemove(m_Vector, it - m_Vector.begin()); +} +#endif // _VMA_MAP_FUNCTIONS +#endif // _VMA_MAP -/** \brief Destroys Vulkan buffer and frees allocated memory. - -This is just a convenience function equivalent to: +#endif // #if 0 -\code -vkDestroyBuffer(device, buffer, allocationCallbacks); -vmaFreeMemory(allocator, allocation); -\endcode +#if !defined(_VMA_STRING_BUILDER) && VMA_STATS_STRING_ENABLED +class VmaStringBuilder +{ +public: + VmaStringBuilder(const VkAllocationCallbacks* allocationCallbacks) : m_Data(VmaStlAllocator(allocationCallbacks)) {} + ~VmaStringBuilder() = default; -It it safe to pass null as buffer and/or allocation. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( - VmaAllocator VMA_NOT_NULL allocator, - VkBuffer VMA_NULLABLE_NON_DISPATCHABLE buffer, - VmaAllocation VMA_NULLABLE allocation); + size_t GetLength() const { return m_Data.size(); } + const char* GetData() const { return m_Data.data(); } + void AddNewLine() { Add('\n'); } + void Add(char ch) { m_Data.push_back(ch); } -/// Function similar to vmaCreateBuffer(). -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( - VmaAllocator VMA_NOT_NULL allocator, - const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, - const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, - VkImage VMA_NULLABLE_NON_DISPATCHABLE * VMA_NOT_NULL pImage, - VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, - VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + void Add(const char* pStr); + void AddNumber(uint32_t num); + void AddNumber(uint64_t num); + void AddPointer(const void* ptr); -/** \brief Destroys Vulkan image and frees allocated memory. +private: + VmaVector> m_Data; +}; -This is just a convenience function equivalent to: +#ifndef _VMA_STRING_BUILDER_FUNCTIONS +void VmaStringBuilder::Add(const char* pStr) +{ + const size_t strLen = strlen(pStr); + if (strLen > 0) + { + const size_t oldCount = m_Data.size(); + m_Data.resize(oldCount + strLen); + memcpy(m_Data.data() + oldCount, pStr, strLen); + } +} -\code -vkDestroyImage(device, image, allocationCallbacks); -vmaFreeMemory(allocator, allocation); -\endcode +void VmaStringBuilder::AddNumber(uint32_t num) +{ + char buf[11]; + buf[10] = '\0'; + char* p = &buf[10]; + do + { + *--p = '0' + (num % 10); + num /= 10; + } while (num); + Add(p); +} -It it safe to pass null as image and/or allocation. -*/ -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( - VmaAllocator VMA_NOT_NULL allocator, - VkImage VMA_NULLABLE_NON_DISPATCHABLE image, - VmaAllocation VMA_NULLABLE allocation); +void VmaStringBuilder::AddNumber(uint64_t num) +{ + char buf[21]; + buf[20] = '\0'; + char* p = &buf[20]; + do + { + *--p = '0' + (num % 10); + num /= 10; + } while (num); + Add(p); +} -#ifdef __cplusplus +void VmaStringBuilder::AddPointer(const void* ptr) +{ + char buf[21]; + VmaPtrToStr(buf, sizeof(buf), ptr); + Add(buf); } -#endif +#endif //_VMA_STRING_BUILDER_FUNCTIONS +#endif // _VMA_STRING_BUILDER -#endif // AMD_VULKAN_MEMORY_ALLOCATOR_H +#if !defined(_VMA_JSON_WRITER) && VMA_STATS_STRING_ENABLED +/* +Allows to conveniently build a correct JSON document to be written to the +VmaStringBuilder passed to the constructor. +*/ +class VmaJsonWriter +{ + VMA_CLASS_NO_COPY(VmaJsonWriter) +public: + // sb - string builder to write the document to. Must remain alive for the whole lifetime of this object. + VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb); + ~VmaJsonWriter(); -// For Visual Studio IntelliSense. -#if defined(__cplusplus) && defined(__INTELLISENSE__) -#define VMA_IMPLEMENTATION -#endif + // Begins object by writing "{". + // Inside an object, you must call pairs of WriteString and a value, e.g.: + // j.BeginObject(true); j.WriteString("A"); j.WriteNumber(1); j.WriteString("B"); j.WriteNumber(2); j.EndObject(); + // Will write: { "A": 1, "B": 2 } + void BeginObject(bool singleLine = false); + // Ends object by writing "}". + void EndObject(); -#ifdef VMA_IMPLEMENTATION -#undef VMA_IMPLEMENTATION + // Begins array by writing "[". + // Inside an array, you can write a sequence of any values. + void BeginArray(bool singleLine = false); + // Ends array by writing "[". + void EndArray(); -#include -#include -#include -#include + // Writes a string value inside "". + // pStr can contain any ANSI characters, including '"', new line etc. - they will be properly escaped. + void WriteString(const char* pStr); + + // Begins writing a string value. + // Call BeginString, ContinueString, ContinueString, ..., EndString instead of + // WriteString to conveniently build the string content incrementally, made of + // parts including numbers. + void BeginString(const char* pStr = VMA_NULL); + // Posts next part of an open string. + void ContinueString(const char* pStr); + // Posts next part of an open string. The number is converted to decimal characters. + void ContinueString(uint32_t n); + void ContinueString(uint64_t n); + void ContinueString_Size(size_t n); + // Posts next part of an open string. Pointer value is converted to characters + // using "%p" formatting - shown as hexadecimal number, e.g.: 000000081276Ad00 + void ContinueString_Pointer(const void* ptr); + // Ends writing a string value by writing '"'. + void EndString(const char* pStr = VMA_NULL); -#if VMA_RECORDING_ENABLED - #include - #if defined(_WIN32) - #include - #else - #include - #include - #endif -#endif + // Writes a number value. + void WriteNumber(uint32_t n); + void WriteNumber(uint64_t n); + void WriteSize(size_t n); + // Writes a boolean value - false or true. + void WriteBool(bool b); + // Writes a null value. + void WriteNull(); -/******************************************************************************* -CONFIGURATION SECTION +private: + enum COLLECTION_TYPE + { + COLLECTION_TYPE_OBJECT, + COLLECTION_TYPE_ARRAY, + }; + struct StackItem + { + COLLECTION_TYPE type; + uint32_t valueCount; + bool singleLineMode; + }; -Define some of these macros before each #include of this header or change them -here if you need other then default behavior depending on your environment. -*/ + static const char* const INDENT; -/* -Define this macro to 1 to make the library fetch pointers to Vulkan functions -internally, like: + VmaStringBuilder& m_SB; + VmaVector< StackItem, VmaStlAllocator > m_Stack; + bool m_InsideString; - vulkanFunctions.vkAllocateMemory = &vkAllocateMemory; -*/ -#if !defined(VMA_STATIC_VULKAN_FUNCTIONS) && !defined(VK_NO_PROTOTYPES) - #define VMA_STATIC_VULKAN_FUNCTIONS 1 -#endif + // Write size_t for less than 64bits + void WriteSize(size_t n, std::integral_constant) { m_SB.AddNumber(static_cast(n)); } + // Write size_t for 64bits + void WriteSize(size_t n, std::integral_constant) { m_SB.AddNumber(static_cast(n)); } -/* -Define this macro to 1 to make the library fetch pointers to Vulkan functions -internally, like: + void BeginValue(bool isString); + void WriteIndent(bool oneLess = false); +}; +const char* const VmaJsonWriter::INDENT = " "; - vulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkGetDeviceProcAddr(m_hDevice, vkAllocateMemory); -*/ -#if !defined(VMA_DYNAMIC_VULKAN_FUNCTIONS) - #define VMA_DYNAMIC_VULKAN_FUNCTIONS 1 - #if defined(VK_NO_PROTOTYPES) - extern PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr; - extern PFN_vkGetDeviceProcAddr vkGetDeviceProcAddr; - #endif -#endif +#ifndef _VMA_JSON_WRITER_FUNCTIONS +VmaJsonWriter::VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb) + : m_SB(sb), + m_Stack(VmaStlAllocator(pAllocationCallbacks)), + m_InsideString(false) {} -// Define this macro to 1 to make the library use STL containers instead of its own implementation. -//#define VMA_USE_STL_CONTAINERS 1 +VmaJsonWriter::~VmaJsonWriter() +{ + VMA_ASSERT(!m_InsideString); + VMA_ASSERT(m_Stack.empty()); +} -/* Set this macro to 1 to make the library including and using STL containers: -std::pair, std::vector, std::list, std::unordered_map. +void VmaJsonWriter::BeginObject(bool singleLine) +{ + VMA_ASSERT(!m_InsideString); -Set it to 0 or undefined to make the library using its own implementation of -the containers. -*/ -#if VMA_USE_STL_CONTAINERS - #define VMA_USE_STL_VECTOR 1 - #define VMA_USE_STL_UNORDERED_MAP 1 - #define VMA_USE_STL_LIST 1 -#endif + BeginValue(false); + m_SB.Add('{'); -#ifndef VMA_USE_STL_SHARED_MUTEX - // Compiler conforms to C++17. - #if __cplusplus >= 201703L - #define VMA_USE_STL_SHARED_MUTEX 1 - // Visual studio defines __cplusplus properly only when passed additional parameter: /Zc:__cplusplus - // Otherwise it's always 199711L, despite shared_mutex works since Visual Studio 2015 Update 2. - // See: https://blogs.msdn.microsoft.com/vcblog/2018/04/09/msvc-now-correctly-reports-__cplusplus/ - #elif defined(_MSC_FULL_VER) && _MSC_FULL_VER >= 190023918 && __cplusplus == 199711L && _MSVC_LANG >= 201703L - #define VMA_USE_STL_SHARED_MUTEX 1 - #else - #define VMA_USE_STL_SHARED_MUTEX 0 - #endif -#endif + StackItem item; + item.type = COLLECTION_TYPE_OBJECT; + item.valueCount = 0; + item.singleLineMode = singleLine; + m_Stack.push_back(item); +} -/* -THESE INCLUDES ARE NOT ENABLED BY DEFAULT. -Library has its own container implementation. -*/ -#if VMA_USE_STL_VECTOR - #include -#endif +void VmaJsonWriter::EndObject() +{ + VMA_ASSERT(!m_InsideString); -#if VMA_USE_STL_UNORDERED_MAP - #include -#endif + WriteIndent(true); + m_SB.Add('}'); -#if VMA_USE_STL_LIST - #include -#endif + VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_OBJECT); + m_Stack.pop_back(); +} -/* -Following headers are used in this CONFIGURATION section only, so feel free to -remove them if not needed. -*/ -#include // for assert -#include // for min, max -#include +void VmaJsonWriter::BeginArray(bool singleLine) +{ + VMA_ASSERT(!m_InsideString); -#ifndef VMA_NULL - // Value used as null pointer. Define it to e.g.: nullptr, NULL, 0, (void*)0. - #define VMA_NULL nullptr -#endif + BeginValue(false); + m_SB.Add('['); -#if defined(__ANDROID_API__) && (__ANDROID_API__ < 16) -#include -static void* vma_aligned_alloc(size_t alignment, size_t size) + StackItem item; + item.type = COLLECTION_TYPE_ARRAY; + item.valueCount = 0; + item.singleLineMode = singleLine; + m_Stack.push_back(item); +} + +void VmaJsonWriter::EndArray() { - // alignment must be >= sizeof(void*) - if(alignment < sizeof(void*)) - { - alignment = sizeof(void*); - } + VMA_ASSERT(!m_InsideString); - return memalign(alignment, size); + WriteIndent(true); + m_SB.Add(']'); + + VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_ARRAY); + m_Stack.pop_back(); } -#elif defined(__APPLE__) || defined(__ANDROID__) || (defined(__linux__) && defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC)) -#include -#if defined(__APPLE__) -#include -#endif +void VmaJsonWriter::WriteString(const char* pStr) +{ + BeginString(pStr); + EndString(); +} -static void* vma_aligned_alloc(size_t alignment, size_t size) +void VmaJsonWriter::BeginString(const char* pStr) { -#if defined(__APPLE__) && (defined(MAC_OS_X_VERSION_10_16) || defined(__IPHONE_14_0)) -#if MAC_OS_X_VERSION_MAX_ALLOWED >= MAC_OS_X_VERSION_10_16 || __IPHONE_OS_VERSION_MAX_ALLOWED >= __IPHONE_14_0 - // For C++14, usr/include/malloc/_malloc.h declares aligned_alloc()) only - // with the MacOSX11.0 SDK in Xcode 12 (which is what adds - // MAC_OS_X_VERSION_10_16), even though the function is marked - // availabe for 10.15. That's why the preprocessor checks for 10.16 but - // the __builtin_available checks for 10.15. - // People who use C++17 could call aligned_alloc with the 10.15 SDK already. - if (__builtin_available(macOS 10.15, iOS 13, *)) - return aligned_alloc(alignment, size); -#endif -#endif - // alignment must be >= sizeof(void*) - if(alignment < sizeof(void*)) - { - alignment = sizeof(void*); - } + VMA_ASSERT(!m_InsideString); - void *pointer; - if(posix_memalign(&pointer, alignment, size) == 0) - return pointer; - return VMA_NULL; + BeginValue(true); + m_SB.Add('"'); + m_InsideString = true; + if (pStr != VMA_NULL && pStr[0] != '\0') + { + ContinueString(pStr); + } } -#elif defined(_WIN32) -static void* vma_aligned_alloc(size_t alignment, size_t size) + +void VmaJsonWriter::ContinueString(const char* pStr) { - return _aligned_malloc(size, alignment); + VMA_ASSERT(m_InsideString); + + const size_t strLen = strlen(pStr); + for (size_t i = 0; i < strLen; ++i) + { + char ch = pStr[i]; + if (ch == '\\') + { + m_SB.Add("\\\\"); + } + else if (ch == '"') + { + m_SB.Add("\\\""); + } + else if (ch >= 32) + { + m_SB.Add(ch); + } + else switch (ch) + { + case '\b': + m_SB.Add("\\b"); + break; + case '\f': + m_SB.Add("\\f"); + break; + case '\n': + m_SB.Add("\\n"); + break; + case '\r': + m_SB.Add("\\r"); + break; + case '\t': + m_SB.Add("\\t"); + break; + default: + VMA_ASSERT(0 && "Character not currently supported."); + break; + } + } } -#else -static void* vma_aligned_alloc(size_t alignment, size_t size) + +void VmaJsonWriter::ContinueString(uint32_t n) { - return aligned_alloc(alignment, size); + VMA_ASSERT(m_InsideString); + m_SB.AddNumber(n); } -#endif -#if defined(_WIN32) -static void vma_aligned_free(void* ptr) +void VmaJsonWriter::ContinueString(uint64_t n) { - _aligned_free(ptr); + VMA_ASSERT(m_InsideString); + m_SB.AddNumber(n); } -#else -static void vma_aligned_free(void* ptr) + +void VmaJsonWriter::ContinueString_Size(size_t n) { - free(ptr); + VMA_ASSERT(m_InsideString); + // Fix for AppleClang incorrect type casting + // TODO: Change to if constexpr when C++17 used as minimal standard + WriteSize(n, std::is_same{}); } -#endif -// If your compiler is not compatible with C++11 and definition of -// aligned_alloc() function is missing, uncommeting following line may help: +void VmaJsonWriter::ContinueString_Pointer(const void* ptr) +{ + VMA_ASSERT(m_InsideString); + m_SB.AddPointer(ptr); +} -//#include +void VmaJsonWriter::EndString(const char* pStr) +{ + VMA_ASSERT(m_InsideString); + if (pStr != VMA_NULL && pStr[0] != '\0') + { + ContinueString(pStr); + } + m_SB.Add('"'); + m_InsideString = false; +} -// Normal assert to check for programmer's errors, especially in Debug configuration. -#ifndef VMA_ASSERT - #ifdef NDEBUG - #define VMA_ASSERT(expr) - #else - #define VMA_ASSERT(expr) assert(expr) - #endif -#endif +void VmaJsonWriter::WriteNumber(uint32_t n) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.AddNumber(n); +} -// Assert that will be called very often, like inside data structures e.g. operator[]. -// Making it non-empty can make program slow. -#ifndef VMA_HEAVY_ASSERT - #ifdef NDEBUG - #define VMA_HEAVY_ASSERT(expr) - #else - #define VMA_HEAVY_ASSERT(expr) //VMA_ASSERT(expr) - #endif -#endif +void VmaJsonWriter::WriteNumber(uint64_t n) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.AddNumber(n); +} -#ifndef VMA_ALIGN_OF - #define VMA_ALIGN_OF(type) (__alignof(type)) -#endif +void VmaJsonWriter::WriteSize(size_t n) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + // Fix for AppleClang incorrect type casting + // TODO: Change to if constexpr when C++17 used as minimal standard + WriteSize(n, std::is_same{}); +} -#ifndef VMA_SYSTEM_ALIGNED_MALLOC - #define VMA_SYSTEM_ALIGNED_MALLOC(size, alignment) vma_aligned_alloc((alignment), (size)) -#endif +void VmaJsonWriter::WriteBool(bool b) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.Add(b ? "true" : "false"); +} -#ifndef VMA_SYSTEM_ALIGNED_FREE - // VMA_SYSTEM_FREE is the old name, but might have been defined by the user - #if defined(VMA_SYSTEM_FREE) - #define VMA_SYSTEM_ALIGNED_FREE(ptr) VMA_SYSTEM_FREE(ptr) - #else - #define VMA_SYSTEM_ALIGNED_FREE(ptr) vma_aligned_free(ptr) - #endif -#endif +void VmaJsonWriter::WriteNull() +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.Add("null"); +} -#ifndef VMA_MIN - #define VMA_MIN(v1, v2) (std::min((v1), (v2))) -#endif +void VmaJsonWriter::BeginValue(bool isString) +{ + if (!m_Stack.empty()) + { + StackItem& currItem = m_Stack.back(); + if (currItem.type == COLLECTION_TYPE_OBJECT && + currItem.valueCount % 2 == 0) + { + VMA_ASSERT(isString); + } -#ifndef VMA_MAX - #define VMA_MAX(v1, v2) (std::max((v1), (v2))) -#endif + if (currItem.type == COLLECTION_TYPE_OBJECT && + currItem.valueCount % 2 != 0) + { + m_SB.Add(": "); + } + else if (currItem.valueCount > 0) + { + m_SB.Add(", "); + WriteIndent(); + } + else + { + WriteIndent(); + } + ++currItem.valueCount; + } +} -#ifndef VMA_SWAP - #define VMA_SWAP(v1, v2) std::swap((v1), (v2)) -#endif +void VmaJsonWriter::WriteIndent(bool oneLess) +{ + if (!m_Stack.empty() && !m_Stack.back().singleLineMode) + { + m_SB.AddNewLine(); -#ifndef VMA_SORT - #define VMA_SORT(beg, end, cmp) std::sort(beg, end, cmp) -#endif + size_t count = m_Stack.size(); + if (count > 0 && oneLess) + { + --count; + } + for (size_t i = 0; i < count; ++i) + { + m_SB.Add(INDENT); + } + } +} +#endif // _VMA_JSON_WRITER_FUNCTIONS -#ifndef VMA_DEBUG_LOG - #define VMA_DEBUG_LOG(format, ...) - /* - #define VMA_DEBUG_LOG(format, ...) do { \ - printf(format, __VA_ARGS__); \ - printf("\n"); \ - } while(false) - */ -#endif +static void VmaPrintDetailedStatistics(VmaJsonWriter& json, const VmaDetailedStatistics& stat) +{ + json.BeginObject(); -// Define this macro to 1 to enable functions: vmaBuildStatsString, vmaFreeStatsString. -#if VMA_STATS_STRING_ENABLED - static inline void VmaUint32ToStr(char* outStr, size_t strLen, uint32_t num) + json.WriteString("BlockCount"); + json.WriteNumber(stat.statistics.blockCount); + json.WriteString("BlockBytes"); + json.WriteNumber(stat.statistics.blockBytes); + json.WriteString("AllocationCount"); + json.WriteNumber(stat.statistics.allocationCount); + json.WriteString("AllocationBytes"); + json.WriteNumber(stat.statistics.allocationBytes); + json.WriteString("UnusedRangeCount"); + json.WriteNumber(stat.unusedRangeCount); + + if (stat.statistics.allocationCount > 1) { - snprintf(outStr, strLen, "%u", static_cast(num)); + json.WriteString("AllocationSizeMin"); + json.WriteNumber(stat.allocationSizeMin); + json.WriteString("AllocationSizeMax"); + json.WriteNumber(stat.allocationSizeMax); } - static inline void VmaUint64ToStr(char* outStr, size_t strLen, uint64_t num) + if (stat.unusedRangeCount > 1) { - snprintf(outStr, strLen, "%llu", static_cast(num)); + json.WriteString("UnusedRangeSizeMin"); + json.WriteNumber(stat.unusedRangeSizeMin); + json.WriteString("UnusedRangeSizeMax"); + json.WriteNumber(stat.unusedRangeSizeMax); } - static inline void VmaPtrToStr(char* outStr, size_t strLen, const void* ptr) + json.EndObject(); +} +#endif // _VMA_JSON_WRITER + +#ifndef _VMA_MAPPING_HYSTERESIS + +class VmaMappingHysteresis +{ + VMA_CLASS_NO_COPY(VmaMappingHysteresis) +public: + VmaMappingHysteresis() = default; + + uint32_t GetExtraMapping() const { return m_ExtraMapping; } + + // Call when Map was called. + // Returns true if switched to extra +1 mapping reference count. + bool PostMap() { - snprintf(outStr, strLen, "%p", ptr); +#if VMA_MAPPING_HYSTERESIS_ENABLED + if(m_ExtraMapping == 0) + { + ++m_MajorCounter; + if(m_MajorCounter >= COUNTER_MIN_EXTRA_MAPPING) + { + m_ExtraMapping = 1; + m_MajorCounter = 0; + m_MinorCounter = 0; + return true; + } + } + else // m_ExtraMapping == 1 + PostMinorCounter(); +#endif // #if VMA_MAPPING_HYSTERESIS_ENABLED + return false; } -#endif -#ifndef VMA_MUTEX - class VmaMutex + // Call when Unmap was called. + void PostUnmap() { - public: - void Lock() { m_Mutex.lock(); } - void Unlock() { m_Mutex.unlock(); } - bool TryLock() { return m_Mutex.try_lock(); } - private: - std::mutex m_Mutex; - }; - #define VMA_MUTEX VmaMutex -#endif +#if VMA_MAPPING_HYSTERESIS_ENABLED + if(m_ExtraMapping == 0) + ++m_MajorCounter; + else // m_ExtraMapping == 1 + PostMinorCounter(); +#endif // #if VMA_MAPPING_HYSTERESIS_ENABLED + } -// Read-write mutex, where "read" is shared access, "write" is exclusive access. -#ifndef VMA_RW_MUTEX - #if VMA_USE_STL_SHARED_MUTEX - // Use std::shared_mutex from C++17. - #include - class VmaRWMutex + // Call when allocation was made from the memory block. + void PostAlloc() + { +#if VMA_MAPPING_HYSTERESIS_ENABLED + if(m_ExtraMapping == 1) + ++m_MajorCounter; + else // m_ExtraMapping == 0 + PostMinorCounter(); +#endif // #if VMA_MAPPING_HYSTERESIS_ENABLED + } + + // Call when allocation was freed from the memory block. + // Returns true if switched to extra -1 mapping reference count. + bool PostFree() + { +#if VMA_MAPPING_HYSTERESIS_ENABLED + if(m_ExtraMapping == 1) { - public: - void LockRead() { m_Mutex.lock_shared(); } - void UnlockRead() { m_Mutex.unlock_shared(); } - bool TryLockRead() { return m_Mutex.try_lock_shared(); } - void LockWrite() { m_Mutex.lock(); } - void UnlockWrite() { m_Mutex.unlock(); } - bool TryLockWrite() { return m_Mutex.try_lock(); } - private: - std::shared_mutex m_Mutex; - }; - #define VMA_RW_MUTEX VmaRWMutex - #elif defined(_WIN32) && defined(WINVER) && WINVER >= 0x0600 - // Use SRWLOCK from WinAPI. - // Minimum supported client = Windows Vista, server = Windows Server 2008. - class VmaRWMutex + ++m_MajorCounter; + if(m_MajorCounter >= COUNTER_MIN_EXTRA_MAPPING && + m_MajorCounter > m_MinorCounter + 1) + { + m_ExtraMapping = 0; + m_MajorCounter = 0; + m_MinorCounter = 0; + return true; + } + } + else // m_ExtraMapping == 0 + PostMinorCounter(); +#endif // #if VMA_MAPPING_HYSTERESIS_ENABLED + return false; + } + +private: + static const int32_t COUNTER_MIN_EXTRA_MAPPING = 7; + + uint32_t m_MinorCounter = 0; + uint32_t m_MajorCounter = 0; + uint32_t m_ExtraMapping = 0; // 0 or 1. + + void PostMinorCounter() + { + if(m_MinorCounter < m_MajorCounter) { - public: - VmaRWMutex() { InitializeSRWLock(&m_Lock); } - void LockRead() { AcquireSRWLockShared(&m_Lock); } - void UnlockRead() { ReleaseSRWLockShared(&m_Lock); } - bool TryLockRead() { return TryAcquireSRWLockShared(&m_Lock) != FALSE; } - void LockWrite() { AcquireSRWLockExclusive(&m_Lock); } - void UnlockWrite() { ReleaseSRWLockExclusive(&m_Lock); } - bool TryLockWrite() { return TryAcquireSRWLockExclusive(&m_Lock) != FALSE; } - private: - SRWLOCK m_Lock; - }; - #define VMA_RW_MUTEX VmaRWMutex - #else - // Less efficient fallback: Use normal mutex. - class VmaRWMutex + ++m_MinorCounter; + } + else if(m_MajorCounter > 0) { - public: - void LockRead() { m_Mutex.Lock(); } - void UnlockRead() { m_Mutex.Unlock(); } - bool TryLockRead() { return m_Mutex.TryLock(); } - void LockWrite() { m_Mutex.Lock(); } - void UnlockWrite() { m_Mutex.Unlock(); } - bool TryLockWrite() { return m_Mutex.TryLock(); } - private: - VMA_MUTEX m_Mutex; - }; - #define VMA_RW_MUTEX VmaRWMutex - #endif // #if VMA_USE_STL_SHARED_MUTEX -#endif // #ifndef VMA_RW_MUTEX + --m_MajorCounter; + --m_MinorCounter; + } + } +}; + +#endif // _VMA_MAPPING_HYSTERESIS +#ifndef _VMA_DEVICE_MEMORY_BLOCK /* -If providing your own implementation, you need to implement a subset of std::atomic. +Represents a single block of device memory (`VkDeviceMemory`) with all the +data about its regions (aka suballocations, #VmaAllocation), assigned and free. + +Thread-safety: +- Access to m_pMetadata must be externally synchronized. +- Map, Unmap, Bind* are synchronized internally. */ -#ifndef VMA_ATOMIC_UINT32 - #include - #define VMA_ATOMIC_UINT32 std::atomic -#endif +class VmaDeviceMemoryBlock +{ + VMA_CLASS_NO_COPY(VmaDeviceMemoryBlock) +public: + VmaBlockMetadata* m_pMetadata; -#ifndef VMA_ATOMIC_UINT64 - #include - #define VMA_ATOMIC_UINT64 std::atomic -#endif + VmaDeviceMemoryBlock(VmaAllocator hAllocator); + ~VmaDeviceMemoryBlock(); -#ifndef VMA_DEBUG_ALWAYS_DEDICATED_MEMORY - /** - Every allocation will have its own memory block. - Define to 1 for debugging purposes only. - */ - #define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY (0) -#endif + // Always call after construction. + void Init( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t newMemoryTypeIndex, + VkDeviceMemory newMemory, + VkDeviceSize newSize, + uint32_t id, + uint32_t algorithm, + VkDeviceSize bufferImageGranularity); + // Always call before destruction. + void Destroy(VmaAllocator allocator); -#ifndef VMA_DEBUG_ALIGNMENT - /** - Minimum alignment of all allocations, in bytes. - Set to more than 1 for debugging purposes only. Must be power of two. - */ - #define VMA_DEBUG_ALIGNMENT (1) -#endif + VmaPool GetParentPool() const { return m_hParentPool; } + VkDeviceMemory GetDeviceMemory() const { return m_hMemory; } + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + uint32_t GetId() const { return m_Id; } + void* GetMappedData() const { return m_pMappedData; } + uint32_t GetMapRefCount() const { return m_MapCount; } -#ifndef VMA_DEBUG_MARGIN - /** - Minimum margin before and after every allocation, in bytes. - Set nonzero for debugging purposes only. - */ - #define VMA_DEBUG_MARGIN (0) -#endif + // Call when allocation/free was made from m_pMetadata. + // Used for m_MappingHysteresis. + void PostAlloc() { m_MappingHysteresis.PostAlloc(); } + void PostFree(VmaAllocator hAllocator); -#ifndef VMA_DEBUG_INITIALIZE_ALLOCATIONS - /** - Define this macro to 1 to automatically fill new allocations and destroyed - allocations with some bit pattern. - */ - #define VMA_DEBUG_INITIALIZE_ALLOCATIONS (0) -#endif + // Validates all data structures inside this object. If not valid, returns false. + bool Validate() const; + VkResult CheckCorruption(VmaAllocator hAllocator); -#ifndef VMA_DEBUG_DETECT_CORRUPTION - /** - Define this macro to 1 together with non-zero value of VMA_DEBUG_MARGIN to - enable writing magic value to the margin before and after every allocation and - validating it, so that memory corruptions (out-of-bounds writes) are detected. - */ - #define VMA_DEBUG_DETECT_CORRUPTION (0) -#endif + // ppData can be null. + VkResult Map(VmaAllocator hAllocator, uint32_t count, void** ppData); + void Unmap(VmaAllocator hAllocator, uint32_t count); -#ifndef VMA_DEBUG_GLOBAL_MUTEX - /** - Set this to 1 for debugging purposes only, to enable single mutex protecting all - entry calls to the library. Can be useful for debugging multithreading issues. - */ - #define VMA_DEBUG_GLOBAL_MUTEX (0) -#endif + VkResult WriteMagicValueAfterAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); + VkResult ValidateMagicValueAfterAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); -#ifndef VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY - /** - Minimum value for VkPhysicalDeviceLimits::bufferImageGranularity. - Set to more than 1 for debugging purposes only. Must be power of two. - */ - #define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY (1) -#endif + VkResult BindBufferMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext); + VkResult BindImageMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext); + +private: + VmaPool m_hParentPool; // VK_NULL_HANDLE if not belongs to custom pool. + uint32_t m_MemoryTypeIndex; + uint32_t m_Id; + VkDeviceMemory m_hMemory; -#ifndef VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT /* - Set this to 1 to make VMA never exceed VkPhysicalDeviceLimits::maxMemoryAllocationCount - and return error instead of leaving up to Vulkan implementation what to do in such cases. + Protects access to m_hMemory so it is not used by multiple threads simultaneously, e.g. vkMapMemory, vkBindBufferMemory. + Also protects m_MapCount, m_pMappedData. + Allocations, deallocations, any change in m_pMetadata is protected by parent's VmaBlockVector::m_Mutex. */ - #define VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT (0) -#endif + VMA_MUTEX m_MapAndBindMutex; + VmaMappingHysteresis m_MappingHysteresis; + uint32_t m_MapCount; + void* m_pMappedData; +}; +#endif // _VMA_DEVICE_MEMORY_BLOCK -#ifndef VMA_SMALL_HEAP_MAX_SIZE - /// Maximum size of a memory heap in Vulkan to consider it "small". - #define VMA_SMALL_HEAP_MAX_SIZE (1024ull * 1024 * 1024) -#endif +#ifndef _VMA_ALLOCATION_T +struct VmaAllocation_T +{ + friend struct VmaDedicatedAllocationListItemTraits; -#ifndef VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE - /// Default size of a block allocated as single VkDeviceMemory from a "large" heap. - #define VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE (256ull * 1024 * 1024) -#endif + enum FLAGS + { + FLAG_PERSISTENT_MAP = 0x01, + FLAG_MAPPING_ALLOWED = 0x02, + }; -#ifndef VMA_CLASS_NO_COPY - #define VMA_CLASS_NO_COPY(className) \ - private: \ - className(const className&) = delete; \ - className& operator=(const className&) = delete; -#endif +public: + enum ALLOCATION_TYPE + { + ALLOCATION_TYPE_NONE, + ALLOCATION_TYPE_BLOCK, + ALLOCATION_TYPE_DEDICATED, + }; -static const uint32_t VMA_FRAME_INDEX_LOST = UINT32_MAX; + // This struct is allocated using VmaPoolAllocator. + VmaAllocation_T(bool mappingAllowed); + ~VmaAllocation_T(); -// Decimal 2139416166, float NaN, little-endian binary 66 E6 84 7F. -static const uint32_t VMA_CORRUPTION_DETECTION_MAGIC_VALUE = 0x7F84E666; + void InitBlockAllocation( + VmaDeviceMemoryBlock* block, + VmaAllocHandle allocHandle, + VkDeviceSize alignment, + VkDeviceSize size, + uint32_t memoryTypeIndex, + VmaSuballocationType suballocationType, + bool mapped); + // pMappedData not null means allocation is created with MAPPED flag. + void InitDedicatedAllocation( + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceMemory hMemory, + VmaSuballocationType suballocationType, + void* pMappedData, + VkDeviceSize size); -static const uint8_t VMA_ALLOCATION_FILL_PATTERN_CREATED = 0xDC; -static const uint8_t VMA_ALLOCATION_FILL_PATTERN_DESTROYED = 0xEF; + ALLOCATION_TYPE GetType() const { return (ALLOCATION_TYPE)m_Type; } + VkDeviceSize GetAlignment() const { return m_Alignment; } + VkDeviceSize GetSize() const { return m_Size; } + void* GetUserData() const { return m_pUserData; } + const char* GetName() const { return m_pName; } + VmaSuballocationType GetSuballocationType() const { return (VmaSuballocationType)m_SuballocationType; } -/******************************************************************************* -END OF CONFIGURATION -*/ + VmaDeviceMemoryBlock* GetBlock() const { VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); return m_BlockAllocation.m_Block; } + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + bool IsPersistentMap() const { return (m_Flags & FLAG_PERSISTENT_MAP) != 0; } + bool IsMappingAllowed() const { return (m_Flags & FLAG_MAPPING_ALLOWED) != 0; } + + void SetUserData(VmaAllocator hAllocator, void* pUserData) { m_pUserData = pUserData; } + void SetName(VmaAllocator hAllocator, const char* pName); + void FreeName(VmaAllocator hAllocator); + uint8_t SwapBlockAllocation(VmaAllocator hAllocator, VmaAllocation allocation); + VmaAllocHandle GetAllocHandle() const; + VkDeviceSize GetOffset() const; + VmaPool GetParentPool() const; + VkDeviceMemory GetMemory() const; + void* GetMappedData() const; -// # Copy of some Vulkan definitions so we don't need to check their existence just to handle few constants. + void BlockAllocMap(); + void BlockAllocUnmap(); + VkResult DedicatedAllocMap(VmaAllocator hAllocator, void** ppData); + void DedicatedAllocUnmap(VmaAllocator hAllocator); -static const uint32_t VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY = 0x00000040; -static const uint32_t VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY = 0x00000080; -static const uint32_t VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY = 0x00020000; +#if VMA_STATS_STRING_ENABLED + uint32_t GetBufferImageUsage() const { return m_BufferImageUsage; } -static const uint32_t VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET = 0x10000000u; + void InitBufferImageUsage(uint32_t bufferImageUsage); + void PrintParameters(class VmaJsonWriter& json) const; +#endif -static VkAllocationCallbacks VmaEmptyAllocationCallbacks = { - VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL }; +private: + // Allocation out of VmaDeviceMemoryBlock. + struct BlockAllocation + { + VmaDeviceMemoryBlock* m_Block; + VmaAllocHandle m_AllocHandle; + }; + // Allocation for an object that has its own private VkDeviceMemory. + struct DedicatedAllocation + { + VmaPool m_hParentPool; // VK_NULL_HANDLE if not belongs to custom pool. + VkDeviceMemory m_hMemory; + void* m_pMappedData; // Not null means memory is mapped. + VmaAllocation_T* m_Prev; + VmaAllocation_T* m_Next; + }; + union + { + // Allocation out of VmaDeviceMemoryBlock. + BlockAllocation m_BlockAllocation; + // Allocation for an object that has its own private VkDeviceMemory. + DedicatedAllocation m_DedicatedAllocation; + }; -// Returns number of bits set to 1 in (v). -static inline uint32_t VmaCountBitsSet(uint32_t v) -{ - uint32_t c = v - ((v >> 1) & 0x55555555); - c = ((c >> 2) & 0x33333333) + (c & 0x33333333); - c = ((c >> 4) + c) & 0x0F0F0F0F; - c = ((c >> 8) + c) & 0x00FF00FF; - c = ((c >> 16) + c) & 0x0000FFFF; - return c; -} + VkDeviceSize m_Alignment; + VkDeviceSize m_Size; + void* m_pUserData; + char* m_pName; + uint32_t m_MemoryTypeIndex; + uint8_t m_Type; // ALLOCATION_TYPE + uint8_t m_SuballocationType; // VmaSuballocationType + // Reference counter for vmaMapMemory()/vmaUnmapMemory(). + uint8_t m_MapCount; + uint8_t m_Flags; // enum FLAGS +#if VMA_STATS_STRING_ENABLED + uint32_t m_BufferImageUsage; // 0 if unknown. +#endif +}; +#endif // _VMA_ALLOCATION_T + +#ifndef _VMA_DEDICATED_ALLOCATION_LIST_ITEM_TRAITS +struct VmaDedicatedAllocationListItemTraits +{ + typedef VmaAllocation_T ItemType; + + static ItemType* GetPrev(const ItemType* item) + { + VMA_HEAVY_ASSERT(item->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); + return item->m_DedicatedAllocation.m_Prev; + } + static ItemType* GetNext(const ItemType* item) + { + VMA_HEAVY_ASSERT(item->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); + return item->m_DedicatedAllocation.m_Next; + } + static ItemType*& AccessPrev(ItemType* item) + { + VMA_HEAVY_ASSERT(item->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); + return item->m_DedicatedAllocation.m_Prev; + } + static ItemType*& AccessNext(ItemType* item) + { + VMA_HEAVY_ASSERT(item->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); + return item->m_DedicatedAllocation.m_Next; + } +}; +#endif // _VMA_DEDICATED_ALLOCATION_LIST_ITEM_TRAITS +#ifndef _VMA_DEDICATED_ALLOCATION_LIST /* -Returns true if given number is a power of two. -T must be unsigned integer number or signed integer but always nonnegative. -For 0 returns true. +Stores linked list of VmaAllocation_T objects. +Thread-safe, synchronized internally. */ -template -inline bool VmaIsPow2(T x) +class VmaDedicatedAllocationList { - return (x & (x-1)) == 0; -} +public: + VmaDedicatedAllocationList() {} + ~VmaDedicatedAllocationList(); -// Aligns given value up to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 16. -// Use types like uint32_t, uint64_t as T. -template -static inline T VmaAlignUp(T val, T alignment) -{ - VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); - return (val + alignment - 1) & ~(alignment - 1); -} -// Aligns given value down to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 8. -// Use types like uint32_t, uint64_t as T. -template -static inline T VmaAlignDown(T val, T alignment) -{ - VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); - return val & ~(alignment - 1); -} + void Init(bool useMutex) { m_UseMutex = useMutex; } + bool Validate(); -// Division with mathematical rounding to nearest number. -template -static inline T VmaRoundDiv(T x, T y) -{ - return (x + (y / (T)2)) / y; -} + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats); + void AddStatistics(VmaStatistics& inoutStats); +#if VMA_STATS_STRING_ENABLED + // Writes JSON array with the list of allocations. + void BuildStatsString(VmaJsonWriter& json); +#endif -// Returns smallest power of 2 greater or equal to v. -static inline uint32_t VmaNextPow2(uint32_t v) -{ - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v++; - return v; -} -static inline uint64_t VmaNextPow2(uint64_t v) -{ - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v |= v >> 32; - v++; - return v; -} + bool IsEmpty(); + void Register(VmaAllocation alloc); + void Unregister(VmaAllocation alloc); -// Returns largest power of 2 less or equal to v. -static inline uint32_t VmaPrevPow2(uint32_t v) -{ - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v = v ^ (v >> 1); - return v; -} -static inline uint64_t VmaPrevPow2(uint64_t v) -{ - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v |= v >> 32; - v = v ^ (v >> 1); - return v; -} +private: + typedef VmaIntrusiveLinkedList DedicatedAllocationLinkedList; -static inline bool VmaStrIsEmpty(const char* pStr) -{ - return pStr == VMA_NULL || *pStr == '\0'; -} + bool m_UseMutex = true; + VMA_RW_MUTEX m_Mutex; + DedicatedAllocationLinkedList m_AllocationList; +}; -#if VMA_STATS_STRING_ENABLED +#ifndef _VMA_DEDICATED_ALLOCATION_LIST_FUNCTIONS -static const char* VmaAlgorithmToStr(uint32_t algorithm) +VmaDedicatedAllocationList::~VmaDedicatedAllocationList() { - switch(algorithm) + VMA_HEAVY_ASSERT(Validate()); + + if (!m_AllocationList.IsEmpty()) { - case VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT: - return "Linear"; - case VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT: - return "Buddy"; - case 0: - return "Default"; - default: - VMA_ASSERT(0); - return ""; + VMA_ASSERT(false && "Unfreed dedicated allocations found!"); } } -#endif // #if VMA_STATS_STRING_ENABLED +bool VmaDedicatedAllocationList::Validate() +{ + const size_t declaredCount = m_AllocationList.GetCount(); + size_t actualCount = 0; + VmaMutexLockRead lock(m_Mutex, m_UseMutex); + for (VmaAllocation alloc = m_AllocationList.Front(); + alloc != VMA_NULL; alloc = m_AllocationList.GetNext(alloc)) + { + ++actualCount; + } + VMA_VALIDATE(actualCount == declaredCount); -#ifndef VMA_SORT + return true; +} -template -Iterator VmaQuickSortPartition(Iterator beg, Iterator end, Compare cmp) +void VmaDedicatedAllocationList::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) { - Iterator centerValue = end; --centerValue; - Iterator insertIndex = beg; - for(Iterator memTypeIndex = beg; memTypeIndex < centerValue; ++memTypeIndex) + for(auto* item = m_AllocationList.Front(); item != nullptr; item = DedicatedAllocationLinkedList::GetNext(item)) { - if(cmp(*memTypeIndex, *centerValue)) - { - if(insertIndex != memTypeIndex) - { - VMA_SWAP(*memTypeIndex, *insertIndex); - } - ++insertIndex; - } + const VkDeviceSize size = item->GetSize(); + inoutStats.statistics.blockCount++; + inoutStats.statistics.blockBytes += size; + VmaAddDetailedStatisticsAllocation(inoutStats, item->GetSize()); } - if(insertIndex != centerValue) +} + +void VmaDedicatedAllocationList::AddStatistics(VmaStatistics& inoutStats) +{ + VmaMutexLockRead lock(m_Mutex, m_UseMutex); + + const uint32_t allocCount = (uint32_t)m_AllocationList.GetCount(); + inoutStats.blockCount += allocCount; + inoutStats.allocationCount += allocCount; + + for(auto* item = m_AllocationList.Front(); item != nullptr; item = DedicatedAllocationLinkedList::GetNext(item)) { - VMA_SWAP(*insertIndex, *centerValue); + const VkDeviceSize size = item->GetSize(); + inoutStats.blockBytes += size; + inoutStats.allocationBytes += size; } - return insertIndex; } -template -void VmaQuickSort(Iterator beg, Iterator end, Compare cmp) +#if VMA_STATS_STRING_ENABLED +void VmaDedicatedAllocationList::BuildStatsString(VmaJsonWriter& json) { - if(beg < end) + VmaMutexLockRead lock(m_Mutex, m_UseMutex); + json.BeginArray(); + for (VmaAllocation alloc = m_AllocationList.Front(); + alloc != VMA_NULL; alloc = m_AllocationList.GetNext(alloc)) { - Iterator it = VmaQuickSortPartition(beg, end, cmp); - VmaQuickSort(beg, it, cmp); - VmaQuickSort(it + 1, end, cmp); + json.BeginObject(true); + alloc->PrintParameters(json); + json.EndObject(); } + json.EndArray(); } +#endif // VMA_STATS_STRING_ENABLED -#define VMA_SORT(beg, end, cmp) VmaQuickSort(beg, end, cmp) - -#endif // #ifndef VMA_SORT - -/* -Returns true if two memory blocks occupy overlapping pages. -ResourceA must be in less memory offset than ResourceB. +bool VmaDedicatedAllocationList::IsEmpty() +{ + VmaMutexLockRead lock(m_Mutex, m_UseMutex); + return m_AllocationList.IsEmpty(); +} -Algorithm is based on "Vulkan 1.0.39 - A Specification (with all registered Vulkan extensions)" -chapter 11.6 "Resource Memory Association", paragraph "Buffer-Image Granularity". -*/ -static inline bool VmaBlocksOnSamePage( - VkDeviceSize resourceAOffset, - VkDeviceSize resourceASize, - VkDeviceSize resourceBOffset, - VkDeviceSize pageSize) +void VmaDedicatedAllocationList::Register(VmaAllocation alloc) { - VMA_ASSERT(resourceAOffset + resourceASize <= resourceBOffset && resourceASize > 0 && pageSize > 0); - VkDeviceSize resourceAEnd = resourceAOffset + resourceASize - 1; - VkDeviceSize resourceAEndPage = resourceAEnd & ~(pageSize - 1); - VkDeviceSize resourceBStart = resourceBOffset; - VkDeviceSize resourceBStartPage = resourceBStart & ~(pageSize - 1); - return resourceAEndPage == resourceBStartPage; + VmaMutexLockWrite lock(m_Mutex, m_UseMutex); + m_AllocationList.PushBack(alloc); } -enum VmaSuballocationType +void VmaDedicatedAllocationList::Unregister(VmaAllocation alloc) { - VMA_SUBALLOCATION_TYPE_FREE = 0, - VMA_SUBALLOCATION_TYPE_UNKNOWN = 1, - VMA_SUBALLOCATION_TYPE_BUFFER = 2, - VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN = 3, - VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR = 4, - VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL = 5, - VMA_SUBALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF -}; + VmaMutexLockWrite lock(m_Mutex, m_UseMutex); + m_AllocationList.Remove(alloc); +} +#endif // _VMA_DEDICATED_ALLOCATION_LIST_FUNCTIONS +#endif // _VMA_DEDICATED_ALLOCATION_LIST +#ifndef _VMA_SUBALLOCATION /* -Returns true if given suballocation types could conflict and must respect -VkPhysicalDeviceLimits::bufferImageGranularity. They conflict if one is buffer -or linear image and another one is optimal image. If type is unknown, behave -conservatively. +Represents a region of VmaDeviceMemoryBlock that is either assigned and returned as +allocated memory block or free. */ -static inline bool VmaIsBufferImageGranularityConflict( - VmaSuballocationType suballocType1, - VmaSuballocationType suballocType2) +struct VmaSuballocation +{ + VkDeviceSize offset; + VkDeviceSize size; + void* userData; + VmaSuballocationType type; +}; + +// Comparator for offsets. +struct VmaSuballocationOffsetLess { - if(suballocType1 > suballocType2) + bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const { - VMA_SWAP(suballocType1, suballocType2); + return lhs.offset < rhs.offset; } +}; - switch(suballocType1) +struct VmaSuballocationOffsetGreater +{ + bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const { - case VMA_SUBALLOCATION_TYPE_FREE: - return false; - case VMA_SUBALLOCATION_TYPE_UNKNOWN: - return true; - case VMA_SUBALLOCATION_TYPE_BUFFER: - return - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; - case VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN: - return - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR || - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; - case VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR: - return - suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; - case VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL: - return false; - default: - VMA_ASSERT(0); - return true; + return lhs.offset > rhs.offset; } -} +}; -static void VmaWriteMagicValue(void* pData, VkDeviceSize offset) +struct VmaSuballocationItemSizeLess { -#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION - uint32_t* pDst = (uint32_t*)((char*)pData + offset); - const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); - for(size_t i = 0; i < numberCount; ++i, ++pDst) + bool operator()(const VmaSuballocationList::iterator lhs, + const VmaSuballocationList::iterator rhs) const { - *pDst = VMA_CORRUPTION_DETECTION_MAGIC_VALUE; + return lhs->size < rhs->size; } -#else - // no-op -#endif -} -static bool VmaValidateMagicValue(const void* pData, VkDeviceSize offset) -{ -#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION - const uint32_t* pSrc = (const uint32_t*)((const char*)pData + offset); - const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); - for(size_t i = 0; i < numberCount; ++i, ++pSrc) + bool operator()(const VmaSuballocationList::iterator lhs, + VkDeviceSize rhsSize) const { - if(*pSrc != VMA_CORRUPTION_DETECTION_MAGIC_VALUE) - { - return false; - } + return lhs->size < rhsSize; } -#endif - return true; -} +}; +#endif // _VMA_SUBALLOCATION +#ifndef _VMA_ALLOCATION_REQUEST /* -Fills structure with parameters of an example buffer to be used for transfers -during GPU memory defragmentation. +Parameters of planned allocation inside a VmaDeviceMemoryBlock. +item points to a FREE suballocation. */ -static void VmaFillGpuDefragmentationBufferCreateInfo(VkBufferCreateInfo& outBufCreateInfo) +struct VmaAllocationRequest { - memset(&outBufCreateInfo, 0, sizeof(outBufCreateInfo)); - outBufCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - outBufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - outBufCreateInfo.size = (VkDeviceSize)VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE; // Example size. -} - -// Helper RAII class to lock a mutex in constructor and unlock it in destructor (at the end of scope). -struct VmaMutexLock -{ - VMA_CLASS_NO_COPY(VmaMutexLock) -public: - VmaMutexLock(VMA_MUTEX& mutex, bool useMutex = true) : - m_pMutex(useMutex ? &mutex : VMA_NULL) - { if(m_pMutex) { m_pMutex->Lock(); } } - ~VmaMutexLock() - { if(m_pMutex) { m_pMutex->Unlock(); } } -private: - VMA_MUTEX* m_pMutex; + VmaAllocHandle allocHandle; + VkDeviceSize size; + VmaSuballocationList::iterator item; + void* customData; + uint64_t algorithmData; + VmaAllocationRequestType type; }; +#endif // _VMA_ALLOCATION_REQUEST -// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for reading. -struct VmaMutexLockRead +#ifndef _VMA_BLOCK_METADATA +/* +Data structure used for bookkeeping of allocations and unused ranges of memory +in a single VkDeviceMemory block. +*/ +class VmaBlockMetadata { - VMA_CLASS_NO_COPY(VmaMutexLockRead) public: - VmaMutexLockRead(VMA_RW_MUTEX& mutex, bool useMutex) : - m_pMutex(useMutex ? &mutex : VMA_NULL) - { if(m_pMutex) { m_pMutex->LockRead(); } } - ~VmaMutexLockRead() { if(m_pMutex) { m_pMutex->UnlockRead(); } } -private: - VMA_RW_MUTEX* m_pMutex; -}; + // pAllocationCallbacks, if not null, must be owned externally - alive and unchanged for the whole lifetime of this object. + VmaBlockMetadata(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual); + virtual ~VmaBlockMetadata() = default; -// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for writing. -struct VmaMutexLockWrite -{ - VMA_CLASS_NO_COPY(VmaMutexLockWrite) -public: - VmaMutexLockWrite(VMA_RW_MUTEX& mutex, bool useMutex) : - m_pMutex(useMutex ? &mutex : VMA_NULL) - { if(m_pMutex) { m_pMutex->LockWrite(); } } - ~VmaMutexLockWrite() { if(m_pMutex) { m_pMutex->UnlockWrite(); } } -private: - VMA_RW_MUTEX* m_pMutex; -}; + virtual void Init(VkDeviceSize size) { m_Size = size; } + bool IsVirtual() const { return m_IsVirtual; } + VkDeviceSize GetSize() const { return m_Size; } -#if VMA_DEBUG_GLOBAL_MUTEX - static VMA_MUTEX gDebugGlobalMutex; - #define VMA_DEBUG_GLOBAL_MUTEX_LOCK VmaMutexLock debugGlobalMutexLock(gDebugGlobalMutex, true); -#else - #define VMA_DEBUG_GLOBAL_MUTEX_LOCK + // Validates all data structures inside this object. If not valid, returns false. + virtual bool Validate() const = 0; + virtual size_t GetAllocationCount() const = 0; + virtual size_t GetFreeRegionsCount() const = 0; + virtual VkDeviceSize GetSumFreeSize() const = 0; + // Returns true if this block is empty - contains only single free suballocation. + virtual bool IsEmpty() const = 0; + virtual void GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) = 0; + virtual VkDeviceSize GetAllocationOffset(VmaAllocHandle allocHandle) const = 0; + virtual void* GetAllocationUserData(VmaAllocHandle allocHandle) const = 0; + + virtual VmaAllocHandle GetAllocationListBegin() const = 0; + virtual VmaAllocHandle GetNextAllocation(VmaAllocHandle prevAlloc) const = 0; + virtual VkDeviceSize GetNextFreeRegionSize(VmaAllocHandle alloc) const = 0; + + // Shouldn't modify blockCount. + virtual void AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const = 0; + virtual void AddStatistics(VmaStatistics& inoutStats) const = 0; + +#if VMA_STATS_STRING_ENABLED + virtual void PrintDetailedMap(class VmaJsonWriter& json) const = 0; #endif -// Minimum size of a free suballocation to register it in the free suballocation collection. -static const VkDeviceSize VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER = 16; + // Tries to find a place for suballocation with given parameters inside this block. + // If succeeded, fills pAllocationRequest and returns true. + // If failed, returns false. + virtual bool CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + // Always one of VMA_ALLOCATION_CREATE_STRATEGY_* or VMA_ALLOCATION_INTERNAL_STRATEGY_* flags. + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) = 0; -/* -Performs binary search and returns iterator to first element that is greater or -equal to (key), according to comparison (cmp). + virtual VkResult CheckCorruption(const void* pBlockData) = 0; -Cmp should return true if first argument is less than second argument. + // Makes actual allocation based on request. Request must already be checked and valid. + virtual void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) = 0; -Returned value is the found element, if present in the collection or place where -new element with value (key) should be inserted. -*/ -template -static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, const CmpLess& cmp) -{ - size_t down = 0, up = (end - beg); - while(down < up) - { - const size_t mid = down + (up - down) / 2; // Overflow-safe midpoint calculation - if(cmp(*(beg+mid), key)) - { - down = mid + 1; - } - else - { - up = mid; - } - } - return beg + down; -} + // Frees suballocation assigned to given memory region. + virtual void Free(VmaAllocHandle allocHandle) = 0; -template -IterT VmaBinaryFindSorted(const IterT& beg, const IterT& end, const KeyT& value, const CmpLess& cmp) -{ - IterT it = VmaBinaryFindFirstNotLess( - beg, end, value, cmp); - if(it == end || - (!cmp(*it, value) && !cmp(value, *it))) - { - return it; - } - return end; -} + // Frees all allocations. + // Careful! Don't call it if there are VmaAllocation objects owned by userData of cleared allocations! + virtual void Clear() = 0; -/* -Returns true if all pointers in the array are not-null and unique. -Warning! O(n^2) complexity. Use only inside VMA_HEAVY_ASSERT. -T must be pointer type, e.g. VmaAllocation, VmaPool. -*/ -template -static bool VmaValidatePointerArray(uint32_t count, const T* arr) -{ - for (const auto i : c10::irange(count)) { - const T iPtr = arr[i]; - if(iPtr == VMA_NULL) - { - return false; - } - for(uint32_t j = i + 1; j < count; ++j) - { - if(iPtr == arr[j]) - { - return false; - } - } - } - return true; -} + virtual void SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) = 0; + virtual void DebugLogAllAllocations() const = 0; -template -static inline void VmaPnextChainPushFront(MainT* mainStruct, NewT* newStruct) -{ - newStruct->pNext = mainStruct->pNext; - mainStruct->pNext = newStruct; -} +protected: + const VkAllocationCallbacks* GetAllocationCallbacks() const { return m_pAllocationCallbacks; } + VkDeviceSize GetBufferImageGranularity() const { return m_BufferImageGranularity; } + VkDeviceSize GetDebugMargin() const { return IsVirtual() ? 0 : VMA_DEBUG_MARGIN; } -//////////////////////////////////////////////////////////////////////////////// -// Memory allocation + void DebugLogAllocation(VkDeviceSize offset, VkDeviceSize size, void* userData) const; +#if VMA_STATS_STRING_ENABLED + // mapRefCount == UINT32_MAX means unspecified. + void PrintDetailedMap_Begin(class VmaJsonWriter& json, + VkDeviceSize unusedBytes, + size_t allocationCount, + size_t unusedRangeCount) const; + void PrintDetailedMap_Allocation(class VmaJsonWriter& json, + VkDeviceSize offset, VkDeviceSize size, void* userData) const; + void PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, + VkDeviceSize offset, + VkDeviceSize size) const; + void PrintDetailedMap_End(class VmaJsonWriter& json) const; +#endif -static void* VmaMalloc(const VkAllocationCallbacks* pAllocationCallbacks, size_t size, size_t alignment) -{ - void* result = VMA_NULL; - if((pAllocationCallbacks != VMA_NULL) && - (pAllocationCallbacks->pfnAllocation != VMA_NULL)) - { - result = (*pAllocationCallbacks->pfnAllocation)( - pAllocationCallbacks->pUserData, - size, - alignment, - VK_SYSTEM_ALLOCATION_SCOPE_OBJECT); - } - else - { - result = VMA_SYSTEM_ALIGNED_MALLOC(size, alignment); - } - VMA_ASSERT(result != VMA_NULL && "CPU memory allocation failed."); - return result; -} +private: + VkDeviceSize m_Size; + const VkAllocationCallbacks* m_pAllocationCallbacks; + const VkDeviceSize m_BufferImageGranularity; + const bool m_IsVirtual; +}; -static void VmaFree(const VkAllocationCallbacks* pAllocationCallbacks, void* ptr) +#ifndef _VMA_BLOCK_METADATA_FUNCTIONS +VmaBlockMetadata::VmaBlockMetadata(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual) + : m_Size(0), + m_pAllocationCallbacks(pAllocationCallbacks), + m_BufferImageGranularity(bufferImageGranularity), + m_IsVirtual(isVirtual) {} + +void VmaBlockMetadata::DebugLogAllocation(VkDeviceSize offset, VkDeviceSize size, void* userData) const { - if((pAllocationCallbacks != VMA_NULL) && - (pAllocationCallbacks->pfnFree != VMA_NULL)) + if (IsVirtual()) { - (*pAllocationCallbacks->pfnFree)(pAllocationCallbacks->pUserData, ptr); + VMA_DEBUG_LOG("UNFREED VIRTUAL ALLOCATION; Offset: %llu; Size: %llu; UserData: %p", offset, size, userData); } else { - VMA_SYSTEM_ALIGNED_FREE(ptr); + VMA_ASSERT(userData != VMA_NULL); + VmaAllocation allocation = reinterpret_cast(userData); + + userData = allocation->GetUserData(); + const char* name = allocation->GetName(); + +#if VMA_STATS_STRING_ENABLED + VMA_DEBUG_LOG("UNFREED ALLOCATION; Offset: %llu; Size: %llu; UserData: %p; Name: %s; Type: %s; Usage: %u", + offset, size, userData, name ? name : "vma_empty", + VMA_SUBALLOCATION_TYPE_NAMES[allocation->GetSuballocationType()], + allocation->GetBufferImageUsage()); +#else + VMA_DEBUG_LOG("UNFREED ALLOCATION; Offset: %llu; Size: %llu; UserData: %p; Name: %s; Type: %u", + offset, size, userData, name ? name : "vma_empty", + (uint32_t)allocation->GetSuballocationType()); +#endif // VMA_STATS_STRING_ENABLED } + } -template -static T* VmaAllocate(const VkAllocationCallbacks* pAllocationCallbacks) +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata::PrintDetailedMap_Begin(class VmaJsonWriter& json, + VkDeviceSize unusedBytes, size_t allocationCount, size_t unusedRangeCount) const { - return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T), VMA_ALIGN_OF(T)); -} + json.WriteString("TotalBytes"); + json.WriteNumber(GetSize()); -template -static T* VmaAllocateArray(const VkAllocationCallbacks* pAllocationCallbacks, size_t count) -{ - return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T) * count, VMA_ALIGN_OF(T)); -} + json.WriteString("UnusedBytes"); + json.WriteSize(unusedBytes); -#define vma_new(allocator, type) new(VmaAllocate(allocator))(type) + json.WriteString("Allocations"); + json.WriteSize(allocationCount); -#define vma_new_array(allocator, type, count) new(VmaAllocateArray((allocator), (count)))(type) + json.WriteString("UnusedRanges"); + json.WriteSize(unusedRangeCount); -template -static void vma_delete(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr) -{ - ptr->~T(); - VmaFree(pAllocationCallbacks, ptr); + json.WriteString("Suballocations"); + json.BeginArray(); } -template -static void vma_delete_array(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr, size_t count) +void VmaBlockMetadata::PrintDetailedMap_Allocation(class VmaJsonWriter& json, + VkDeviceSize offset, VkDeviceSize size, void* userData) const { - if(ptr != VMA_NULL) + json.BeginObject(true); + + json.WriteString("Offset"); + json.WriteNumber(offset); + + if (IsVirtual()) { - for(size_t i = count; i--; ) + json.WriteString("Size"); + json.WriteNumber(size); + if (userData) { - ptr[i].~T(); + json.WriteString("CustomData"); + json.BeginString(); + json.ContinueString_Pointer(userData); + json.EndString(); } - VmaFree(pAllocationCallbacks, ptr); - } -} - -static char* VmaCreateStringCopy(const VkAllocationCallbacks* allocs, const char* srcStr) -{ - if(srcStr != VMA_NULL) - { - const size_t len = strlen(srcStr); - char* const result = vma_new_array(allocs, char, len + 1); - memcpy(result, srcStr, len + 1); - return result; } else { - return VMA_NULL; + ((VmaAllocation)userData)->PrintParameters(json); } -} -static void VmaFreeString(const VkAllocationCallbacks* allocs, char* str) -{ - if(str != VMA_NULL) - { - const size_t len = strlen(str); - vma_delete_array(allocs, str, len + 1); - } + json.EndObject(); } -// STL-compatible allocator. -template -class VmaStlAllocator +void VmaBlockMetadata::PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, + VkDeviceSize offset, VkDeviceSize size) const { -public: - const VkAllocationCallbacks* const m_pCallbacks; - typedef T value_type; + json.BeginObject(true); - VmaStlAllocator(const VkAllocationCallbacks* pCallbacks) : m_pCallbacks(pCallbacks) { } - template VmaStlAllocator(const VmaStlAllocator& src) : m_pCallbacks(src.m_pCallbacks) { } + json.WriteString("Offset"); + json.WriteNumber(offset); - T* allocate(size_t n) { return VmaAllocateArray(m_pCallbacks, n); } - void deallocate(T* p, size_t n) { VmaFree(m_pCallbacks, p); } + json.WriteString("Type"); + json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[VMA_SUBALLOCATION_TYPE_FREE]); - template - bool operator==(const VmaStlAllocator& rhs) const - { - return m_pCallbacks == rhs.m_pCallbacks; - } - template - bool operator!=(const VmaStlAllocator& rhs) const - { - return m_pCallbacks != rhs.m_pCallbacks; - } - - VmaStlAllocator& operator=(const VmaStlAllocator& x) = delete; -}; - -#if VMA_USE_STL_VECTOR - -#define VmaVector std::vector + json.WriteString("Size"); + json.WriteNumber(size); -template -static void VmaVectorInsert(std::vector& vec, size_t index, const T& item) -{ - vec.insert(vec.begin() + index, item); + json.EndObject(); } -template -static void VmaVectorRemove(std::vector& vec, size_t index) +void VmaBlockMetadata::PrintDetailedMap_End(class VmaJsonWriter& json) const { - vec.erase(vec.begin() + index); + json.EndArray(); } +#endif // VMA_STATS_STRING_ENABLED +#endif // _VMA_BLOCK_METADATA_FUNCTIONS +#endif // _VMA_BLOCK_METADATA -#else // #if VMA_USE_STL_VECTOR - -/* Class with interface compatible with subset of std::vector. -T must be POD because constructors and destructors are not called and memcpy is -used for these objects. */ -template -class VmaVector +#ifndef _VMA_BLOCK_BUFFER_IMAGE_GRANULARITY +// Before deleting object of this class remember to call 'Destroy()' +class VmaBlockBufferImageGranularity final { public: - typedef T value_type; - - VmaVector(const AllocatorT& allocator) : - m_Allocator(allocator), - m_pArray(VMA_NULL), - m_Count(0), - m_Capacity(0) + struct ValidationContext { - } + const VkAllocationCallbacks* allocCallbacks; + uint16_t* pageAllocs; + }; - VmaVector(size_t count, const AllocatorT& allocator) : - m_Allocator(allocator), - m_pArray(count ? (T*)VmaAllocateArray(allocator.m_pCallbacks, count) : VMA_NULL), - m_Count(count), - m_Capacity(count) - { - } + VmaBlockBufferImageGranularity(VkDeviceSize bufferImageGranularity); + ~VmaBlockBufferImageGranularity(); - // This version of the constructor is here for compatibility with pre-C++14 std::vector. - // value is unused. - VmaVector(size_t count, const T& value, const AllocatorT& allocator) - : VmaVector(count, allocator) {} + bool IsEnabled() const { return m_BufferImageGranularity > MAX_LOW_BUFFER_IMAGE_GRANULARITY; } - VmaVector(const VmaVector& src) : - m_Allocator(src.m_Allocator), - m_pArray(src.m_Count ? (T*)VmaAllocateArray(src.m_Allocator.m_pCallbacks, src.m_Count) : VMA_NULL), - m_Count(src.m_Count), - m_Capacity(src.m_Count) - { - if(m_Count != 0) - { - memcpy(m_pArray, src.m_pArray, m_Count * sizeof(T)); - } - } + void Init(const VkAllocationCallbacks* pAllocationCallbacks, VkDeviceSize size); + // Before destroying object you must call free it's memory + void Destroy(const VkAllocationCallbacks* pAllocationCallbacks); - ~VmaVector() - { - VmaFree(m_Allocator.m_pCallbacks, m_pArray); - } + void RoundupAllocRequest(VmaSuballocationType allocType, + VkDeviceSize& inOutAllocSize, + VkDeviceSize& inOutAllocAlignment) const; - VmaVector& operator=(const VmaVector& rhs) - { - if(&rhs != this) - { - resize(rhs.m_Count); - if(m_Count != 0) - { - memcpy(m_pArray, rhs.m_pArray, m_Count * sizeof(T)); - } - } - return *this; - } + bool CheckConflictAndAlignUp(VkDeviceSize& inOutAllocOffset, + VkDeviceSize allocSize, + VkDeviceSize blockOffset, + VkDeviceSize blockSize, + VmaSuballocationType allocType) const; - bool empty() const { return m_Count == 0; } - size_t size() const { return m_Count; } - T* data() { return m_pArray; } - const T* data() const { return m_pArray; } + void AllocPages(uint8_t allocType, VkDeviceSize offset, VkDeviceSize size); + void FreePages(VkDeviceSize offset, VkDeviceSize size); + void Clear(); - T& operator[](size_t index) - { - VMA_HEAVY_ASSERT(index < m_Count); - return m_pArray[index]; - } - const T& operator[](size_t index) const - { - VMA_HEAVY_ASSERT(index < m_Count); - return m_pArray[index]; - } + ValidationContext StartValidation(const VkAllocationCallbacks* pAllocationCallbacks, + bool isVirutal) const; + bool Validate(ValidationContext& ctx, VkDeviceSize offset, VkDeviceSize size) const; + bool FinishValidation(ValidationContext& ctx) const; - T& front() - { - VMA_HEAVY_ASSERT(m_Count > 0); - return m_pArray[0]; - } - const T& front() const - { - VMA_HEAVY_ASSERT(m_Count > 0); - return m_pArray[0]; - } - T& back() - { - VMA_HEAVY_ASSERT(m_Count > 0); - return m_pArray[m_Count - 1]; - } - const T& back() const - { - VMA_HEAVY_ASSERT(m_Count > 0); - return m_pArray[m_Count - 1]; - } +private: + static const uint16_t MAX_LOW_BUFFER_IMAGE_GRANULARITY = 256; - void reserve(size_t newCapacity, bool freeMemory = false) + struct RegionInfo { - newCapacity = VMA_MAX(newCapacity, m_Count); + uint8_t allocType; + uint16_t allocCount; + }; - if((newCapacity < m_Capacity) && !freeMemory) - { - newCapacity = m_Capacity; - } + VkDeviceSize m_BufferImageGranularity; + uint32_t m_RegionCount; + RegionInfo* m_RegionInfo; - if(newCapacity != m_Capacity) - { - T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator, newCapacity) : VMA_NULL; - if(m_Count != 0) - { - memcpy(newArray, m_pArray, m_Count * sizeof(T)); - } - VmaFree(m_Allocator.m_pCallbacks, m_pArray); - m_Capacity = newCapacity; - m_pArray = newArray; - } - } + uint32_t GetStartPage(VkDeviceSize offset) const { return OffsetToPageIndex(offset & ~(m_BufferImageGranularity - 1)); } + uint32_t GetEndPage(VkDeviceSize offset, VkDeviceSize size) const { return OffsetToPageIndex((offset + size - 1) & ~(m_BufferImageGranularity - 1)); } - void resize(size_t newCount, bool freeMemory = false) - { - size_t newCapacity = m_Capacity; - if(newCount > m_Capacity) - { - newCapacity = VMA_MAX(newCount, VMA_MAX(m_Capacity * 3 / 2, (size_t)8)); - } - else if(freeMemory) - { - newCapacity = newCount; - } + uint32_t OffsetToPageIndex(VkDeviceSize offset) const; + void AllocPage(RegionInfo& page, uint8_t allocType); +}; - if(newCapacity != m_Capacity) - { - T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator.m_pCallbacks, newCapacity) : VMA_NULL; - const size_t elementsToCopy = VMA_MIN(m_Count, newCount); - if(elementsToCopy != 0) - { - memcpy(newArray, m_pArray, elementsToCopy * sizeof(T)); - } - VmaFree(m_Allocator.m_pCallbacks, m_pArray); - m_Capacity = newCapacity; - m_pArray = newArray; - } +#ifndef _VMA_BLOCK_BUFFER_IMAGE_GRANULARITY_FUNCTIONS +VmaBlockBufferImageGranularity::VmaBlockBufferImageGranularity(VkDeviceSize bufferImageGranularity) + : m_BufferImageGranularity(bufferImageGranularity), + m_RegionCount(0), + m_RegionInfo(VMA_NULL) {} - m_Count = newCount; - } +VmaBlockBufferImageGranularity::~VmaBlockBufferImageGranularity() +{ + VMA_ASSERT(m_RegionInfo == VMA_NULL && "Free not called before destroying object!"); +} - void clear(bool freeMemory = false) +void VmaBlockBufferImageGranularity::Init(const VkAllocationCallbacks* pAllocationCallbacks, VkDeviceSize size) +{ + if (IsEnabled()) { - resize(0, freeMemory); + m_RegionCount = static_cast(VmaDivideRoundingUp(size, m_BufferImageGranularity)); + m_RegionInfo = vma_new_array(pAllocationCallbacks, RegionInfo, m_RegionCount); + memset(m_RegionInfo, 0, m_RegionCount * sizeof(RegionInfo)); } +} - void insert(size_t index, const T& src) +void VmaBlockBufferImageGranularity::Destroy(const VkAllocationCallbacks* pAllocationCallbacks) +{ + if (m_RegionInfo) { - VMA_HEAVY_ASSERT(index <= m_Count); - const size_t oldCount = size(); - resize(oldCount + 1); - if(index < oldCount) - { - memmove(m_pArray + (index + 1), m_pArray + index, (oldCount - index) * sizeof(T)); - } - m_pArray[index] = src; + vma_delete_array(pAllocationCallbacks, m_RegionInfo, m_RegionCount); + m_RegionInfo = VMA_NULL; } +} - void remove(size_t index) +void VmaBlockBufferImageGranularity::RoundupAllocRequest(VmaSuballocationType allocType, + VkDeviceSize& inOutAllocSize, + VkDeviceSize& inOutAllocAlignment) const +{ + if (m_BufferImageGranularity > 1 && + m_BufferImageGranularity <= MAX_LOW_BUFFER_IMAGE_GRANULARITY) { - VMA_HEAVY_ASSERT(index < m_Count); - const size_t oldCount = size(); - if(index < oldCount - 1) + if (allocType == VMA_SUBALLOCATION_TYPE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL) { - memmove(m_pArray + index, m_pArray + (index + 1), (oldCount - index - 1) * sizeof(T)); + inOutAllocAlignment = VMA_MAX(inOutAllocAlignment, m_BufferImageGranularity); + inOutAllocSize = VmaAlignUp(inOutAllocSize, m_BufferImageGranularity); } - resize(oldCount - 1); - } - - void push_back(const T& src) - { - const size_t newIndex = size(); - resize(newIndex + 1); - m_pArray[newIndex] = src; } +} - void pop_back() +bool VmaBlockBufferImageGranularity::CheckConflictAndAlignUp(VkDeviceSize& inOutAllocOffset, + VkDeviceSize allocSize, + VkDeviceSize blockOffset, + VkDeviceSize blockSize, + VmaSuballocationType allocType) const +{ + if (IsEnabled()) { - VMA_HEAVY_ASSERT(m_Count > 0); - resize(size() - 1); + uint32_t startPage = GetStartPage(inOutAllocOffset); + if (m_RegionInfo[startPage].allocCount > 0 && + VmaIsBufferImageGranularityConflict(static_cast(m_RegionInfo[startPage].allocType), allocType)) + { + inOutAllocOffset = VmaAlignUp(inOutAllocOffset, m_BufferImageGranularity); + if (blockSize < allocSize + inOutAllocOffset - blockOffset) + return true; + ++startPage; + } + uint32_t endPage = GetEndPage(inOutAllocOffset, allocSize); + if (endPage != startPage && + m_RegionInfo[endPage].allocCount > 0 && + VmaIsBufferImageGranularityConflict(static_cast(m_RegionInfo[endPage].allocType), allocType)) + { + return true; + } } + return false; +} - void push_front(const T& src) +void VmaBlockBufferImageGranularity::AllocPages(uint8_t allocType, VkDeviceSize offset, VkDeviceSize size) +{ + if (IsEnabled()) { - insert(0, src); - } + uint32_t startPage = GetStartPage(offset); + AllocPage(m_RegionInfo[startPage], allocType); - void pop_front() - { - VMA_HEAVY_ASSERT(m_Count > 0); - remove(0); + uint32_t endPage = GetEndPage(offset, size); + if (startPage != endPage) + AllocPage(m_RegionInfo[endPage], allocType); } - - typedef T* iterator; - - iterator begin() { return m_pArray; } - iterator end() { return m_pArray + m_Count; } - -private: - AllocatorT m_Allocator; - T* m_pArray; - size_t m_Count; - size_t m_Capacity; -}; - -template -static void VmaVectorInsert(VmaVector& vec, size_t index, const T& item) -{ - vec.insert(index, item); } -template -static void VmaVectorRemove(VmaVector& vec, size_t index) +void VmaBlockBufferImageGranularity::FreePages(VkDeviceSize offset, VkDeviceSize size) { - vec.remove(index); + if (IsEnabled()) + { + uint32_t startPage = GetStartPage(offset); + --m_RegionInfo[startPage].allocCount; + if (m_RegionInfo[startPage].allocCount == 0) + m_RegionInfo[startPage].allocType = VMA_SUBALLOCATION_TYPE_FREE; + uint32_t endPage = GetEndPage(offset, size); + if (startPage != endPage) + { + --m_RegionInfo[endPage].allocCount; + if (m_RegionInfo[endPage].allocCount == 0) + m_RegionInfo[endPage].allocType = VMA_SUBALLOCATION_TYPE_FREE; + } + } } -#endif // #if VMA_USE_STL_VECTOR - -template -size_t VmaVectorInsertSorted(VectorT& vector, const typename VectorT::value_type& value) +void VmaBlockBufferImageGranularity::Clear() { - const size_t indexToInsert = VmaBinaryFindFirstNotLess( - vector.data(), - vector.data() + vector.size(), - value, - CmpLess()) - vector.data(); - VmaVectorInsert(vector, indexToInsert, value); - return indexToInsert; + if (m_RegionInfo) + memset(m_RegionInfo, 0, m_RegionCount * sizeof(RegionInfo)); } -template -bool VmaVectorRemoveSorted(VectorT& vector, const typename VectorT::value_type& value) +VmaBlockBufferImageGranularity::ValidationContext VmaBlockBufferImageGranularity::StartValidation( + const VkAllocationCallbacks* pAllocationCallbacks, bool isVirutal) const { - CmpLess comparator; - typename VectorT::iterator it = VmaBinaryFindFirstNotLess( - vector.begin(), - vector.end(), - value, - comparator); - if((it != vector.end()) && !comparator(*it, value) && !comparator(value, *it)) + ValidationContext ctx{ pAllocationCallbacks, VMA_NULL }; + if (!isVirutal && IsEnabled()) { - size_t indexToRemove = it - vector.begin(); - VmaVectorRemove(vector, indexToRemove); - return true; + ctx.pageAllocs = vma_new_array(pAllocationCallbacks, uint16_t, m_RegionCount); + memset(ctx.pageAllocs, 0, m_RegionCount * sizeof(uint16_t)); } - return false; + return ctx; } -//////////////////////////////////////////////////////////////////////////////// -// class VmaSmallVector - -/* -This is a vector (a variable-sized array), optimized for the case when the array is small. - -It contains some number of elements in-place, which allows it to avoid heap allocation -when the actual number of elements is below that threshold. This allows normal "small" -cases to be fast without losing generality for large inputs. -*/ - -template -class VmaSmallVector +bool VmaBlockBufferImageGranularity::Validate(ValidationContext& ctx, + VkDeviceSize offset, VkDeviceSize size) const { -public: - typedef T value_type; - - VmaSmallVector(const AllocatorT& allocator) : - m_Count(0), - m_DynamicArray(allocator) - { - } - VmaSmallVector(size_t count, const AllocatorT& allocator) : - m_Count(count), - m_DynamicArray(count > N ? count : 0, allocator) - { - } - template - VmaSmallVector(const VmaSmallVector& src) = delete; - template - VmaSmallVector& operator=(const VmaSmallVector& rhs) = delete; - - bool empty() const { return m_Count == 0; } - size_t size() const { return m_Count; } - T* data() { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } - const T* data() const { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } - - T& operator[](size_t index) - { - VMA_HEAVY_ASSERT(index < m_Count); - return data()[index]; - } - const T& operator[](size_t index) const - { - VMA_HEAVY_ASSERT(index < m_Count); - return data()[index]; - } - - T& front() - { - VMA_HEAVY_ASSERT(m_Count > 0); - return data()[0]; - } - const T& front() const - { - VMA_HEAVY_ASSERT(m_Count > 0); - return data()[0]; - } - T& back() - { - VMA_HEAVY_ASSERT(m_Count > 0); - return data()[m_Count - 1]; - } - const T& back() const + if (IsEnabled()) { - VMA_HEAVY_ASSERT(m_Count > 0); - return data()[m_Count - 1]; - } + uint32_t start = GetStartPage(offset); + ++ctx.pageAllocs[start]; + VMA_VALIDATE(m_RegionInfo[start].allocCount > 0); - void resize(size_t newCount, bool freeMemory = false) - { - if(newCount > N && m_Count > N) - { - // Any direction, staying in m_DynamicArray - m_DynamicArray.resize(newCount, freeMemory); - } - else if(newCount > N && m_Count <= N) - { - // Growing, moving from m_StaticArray to m_DynamicArray - m_DynamicArray.resize(newCount, freeMemory); - if(m_Count > 0) - { - memcpy(m_DynamicArray.data(), m_StaticArray, m_Count * sizeof(T)); - } - } - else if(newCount <= N && m_Count > N) - { - // Shrinking, moving from m_DynamicArray to m_StaticArray - if(newCount > 0) - { - memcpy(m_StaticArray, m_DynamicArray.data(), newCount * sizeof(T)); - } - m_DynamicArray.resize(0, freeMemory); - } - else + uint32_t end = GetEndPage(offset, size); + if (start != end) { - // Any direction, staying in m_StaticArray - nothing to do here + ++ctx.pageAllocs[end]; + VMA_VALIDATE(m_RegionInfo[end].allocCount > 0); } - m_Count = newCount; } + return true; +} - void clear(bool freeMemory = false) +bool VmaBlockBufferImageGranularity::FinishValidation(ValidationContext& ctx) const +{ + // Check proper page structure + if (IsEnabled()) { - m_DynamicArray.clear(freeMemory); - m_Count = 0; - } + VMA_ASSERT(ctx.pageAllocs != VMA_NULL && "Validation context not initialized!"); - void insert(size_t index, const T& src) - { - VMA_HEAVY_ASSERT(index <= m_Count); - const size_t oldCount = size(); - resize(oldCount + 1); - T* const dataPtr = data(); - if(index < oldCount) + for (uint32_t page = 0; page < m_RegionCount; ++page) { - // I know, this could be more optimal for case where memmove can be memcpy directly from m_StaticArray to m_DynamicArray. - memmove(dataPtr + (index + 1), dataPtr + index, (oldCount - index) * sizeof(T)); + VMA_VALIDATE(ctx.pageAllocs[page] == m_RegionInfo[page].allocCount); } - dataPtr[index] = src; + vma_delete_array(ctx.allocCallbacks, ctx.pageAllocs, m_RegionCount); + ctx.pageAllocs = VMA_NULL; } + return true; +} - void remove(size_t index) - { - VMA_HEAVY_ASSERT(index < m_Count); - const size_t oldCount = size(); - if(index < oldCount - 1) - { - // I know, this could be more optimal for case where memmove can be memcpy directly from m_DynamicArray to m_StaticArray. - T* const dataPtr = data(); - memmove(dataPtr + index, dataPtr + (index + 1), (oldCount - index - 1) * sizeof(T)); - } - resize(oldCount - 1); - } +uint32_t VmaBlockBufferImageGranularity::OffsetToPageIndex(VkDeviceSize offset) const +{ + return static_cast(offset >> VMA_BITSCAN_MSB(m_BufferImageGranularity)); +} - void push_back(const T& src) - { - const size_t newIndex = size(); - resize(newIndex + 1); - data()[newIndex] = src; - } +void VmaBlockBufferImageGranularity::AllocPage(RegionInfo& page, uint8_t allocType) +{ + // When current alloc type is free then it can be overriden by new type + if (page.allocCount == 0 || (page.allocCount > 0 && page.allocType == VMA_SUBALLOCATION_TYPE_FREE)) + page.allocType = allocType; - void pop_back() - { - VMA_HEAVY_ASSERT(m_Count > 0); - resize(size() - 1); - } + ++page.allocCount; +} +#endif // _VMA_BLOCK_BUFFER_IMAGE_GRANULARITY_FUNCTIONS +#endif // _VMA_BLOCK_BUFFER_IMAGE_GRANULARITY - void push_front(const T& src) - { - insert(0, src); - } +#if 0 +#ifndef _VMA_BLOCK_METADATA_GENERIC +class VmaBlockMetadata_Generic : public VmaBlockMetadata +{ + friend class VmaDefragmentationAlgorithm_Generic; + friend class VmaDefragmentationAlgorithm_Fast; + VMA_CLASS_NO_COPY(VmaBlockMetadata_Generic) +public: + VmaBlockMetadata_Generic(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual); + virtual ~VmaBlockMetadata_Generic() = default; - void pop_front() - { - VMA_HEAVY_ASSERT(m_Count > 0); - remove(0); - } + size_t GetAllocationCount() const override { return m_Suballocations.size() - m_FreeCount; } + VkDeviceSize GetSumFreeSize() const override { return m_SumFreeSize; } + bool IsEmpty() const override { return (m_Suballocations.size() == 1) && (m_FreeCount == 1); } + void Free(VmaAllocHandle allocHandle) override { FreeSuballocation(FindAtOffset((VkDeviceSize)allocHandle - 1)); } + VkDeviceSize GetAllocationOffset(VmaAllocHandle allocHandle) const override { return (VkDeviceSize)allocHandle - 1; }; - typedef T* iterator; + void Init(VkDeviceSize size) override; + bool Validate() const override; - iterator begin() { return data(); } - iterator end() { return data() + m_Count; } + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const override; + void AddStatistics(VmaStatistics& inoutStats) const override; -private: - size_t m_Count; - T m_StaticArray[N]; // Used when m_Size <= N - VmaVector m_DynamicArray; // Used when m_Size > N -}; +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json, uint32_t mapRefCount) const override; +#endif -//////////////////////////////////////////////////////////////////////////////// -// class VmaPoolAllocator + bool CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) override; -/* -Allocator for objects of type T using a list of arrays (pools) to speed up -allocation. Number of elements that can be allocated is not bounded because -allocator can create multiple blocks. -*/ -template -class VmaPoolAllocator -{ - VMA_CLASS_NO_COPY(VmaPoolAllocator) -public: - VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity); - ~VmaPoolAllocator(); - template T* Alloc(Types... args); - void Free(T* ptr); + VkResult CheckCorruption(const void* pBlockData) override; -private: - union Item - { - uint32_t NextFreeIndex; - alignas(T) char Value[sizeof(T)]; - }; + void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) override; - struct ItemBlock - { - Item* pItems; - uint32_t Capacity; - uint32_t FirstFreeIndex; - }; + void GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) override; + void* GetAllocationUserData(VmaAllocHandle allocHandle) const override; + VmaAllocHandle GetAllocationListBegin() const override; + VmaAllocHandle GetNextAllocation(VmaAllocHandle prevAlloc) const override; + void Clear() override; + void SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) override; + void DebugLogAllAllocations() const override; - const VkAllocationCallbacks* m_pAllocationCallbacks; - const uint32_t m_FirstBlockCapacity; - VmaVector< ItemBlock, VmaStlAllocator > m_ItemBlocks; +private: + uint32_t m_FreeCount; + VkDeviceSize m_SumFreeSize; + VmaSuballocationList m_Suballocations; + // Suballocations that are free. Sorted by size, ascending. + VmaVector> m_FreeSuballocationsBySize; - ItemBlock& CreateNewBlock(); -}; + VkDeviceSize AlignAllocationSize(VkDeviceSize size) const { return IsVirtual() ? size : VmaAlignUp(size, (VkDeviceSize)16); } -template -VmaPoolAllocator::VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity) : - m_pAllocationCallbacks(pAllocationCallbacks), - m_FirstBlockCapacity(firstBlockCapacity), - m_ItemBlocks(VmaStlAllocator(pAllocationCallbacks)) -{ - VMA_ASSERT(m_FirstBlockCapacity > 1); -} + VmaSuballocationList::iterator FindAtOffset(VkDeviceSize offset) const; + bool ValidateFreeSuballocationList() const; -template -VmaPoolAllocator::~VmaPoolAllocator() -{ - for(size_t i = m_ItemBlocks.size(); i--; ) - vma_delete_array(m_pAllocationCallbacks, m_ItemBlocks[i].pItems, m_ItemBlocks[i].Capacity); - m_ItemBlocks.clear(); -} + // Checks if requested suballocation with given parameters can be placed in given pFreeSuballocItem. + // If yes, fills pOffset and returns true. If no, returns false. + bool CheckAllocation( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaSuballocationList::const_iterator suballocItem, + VmaAllocHandle* pAllocHandle) const; -template -template T* VmaPoolAllocator::Alloc(Types... args) -{ - for(size_t i = m_ItemBlocks.size(); i--; ) - { - ItemBlock& block = m_ItemBlocks[i]; - // This block has some free items: Use first one. - if(block.FirstFreeIndex != UINT32_MAX) - { - Item* const pItem = &block.pItems[block.FirstFreeIndex]; - block.FirstFreeIndex = pItem->NextFreeIndex; - T* result = (T*)&pItem->Value; - new(result)T(std::forward(args)...); // Explicit constructor call. - return result; - } - } + // Given free suballocation, it merges it with following one, which must also be free. + void MergeFreeWithNext(VmaSuballocationList::iterator item); + // Releases given suballocation, making it free. + // Merges it with adjacent free suballocations if applicable. + // Returns iterator to new free suballocation at this place. + VmaSuballocationList::iterator FreeSuballocation(VmaSuballocationList::iterator suballocItem); + // Given free suballocation, it inserts it into sorted list of + // m_FreeSuballocationsBySize if it is suitable. + void RegisterFreeSuballocation(VmaSuballocationList::iterator item); + // Given free suballocation, it removes it from sorted list of + // m_FreeSuballocationsBySize if it is suitable. + void UnregisterFreeSuballocation(VmaSuballocationList::iterator item); +}; - // No block has free item: Create new one and use it. - ItemBlock& newBlock = CreateNewBlock(); - Item* const pItem = &newBlock.pItems[0]; - newBlock.FirstFreeIndex = pItem->NextFreeIndex; - T* result = (T*)&pItem->Value; - new(result)T(std::forward(args)...); // Explicit constructor call. - return result; -} +#ifndef _VMA_BLOCK_METADATA_GENERIC_FUNCTIONS +VmaBlockMetadata_Generic::VmaBlockMetadata_Generic(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual) + : VmaBlockMetadata(pAllocationCallbacks, bufferImageGranularity, isVirtual), + m_FreeCount(0), + m_SumFreeSize(0), + m_Suballocations(VmaStlAllocator(pAllocationCallbacks)), + m_FreeSuballocationsBySize(VmaStlAllocator(pAllocationCallbacks)) {} -template -void VmaPoolAllocator::Free(T* ptr) +void VmaBlockMetadata_Generic::Init(VkDeviceSize size) { - // Search all memory blocks to find ptr. - for(size_t i = m_ItemBlocks.size(); i--; ) - { - ItemBlock& block = m_ItemBlocks[i]; + VmaBlockMetadata::Init(size); - // Casting to union. - Item* pItemPtr; - memcpy(&pItemPtr, &ptr, sizeof(pItemPtr)); + m_FreeCount = 1; + m_SumFreeSize = size; - // Check if pItemPtr is in address range of this block. - if((pItemPtr >= block.pItems) && (pItemPtr < block.pItems + block.Capacity)) - { - ptr->~T(); // Explicit destructor call. - const uint32_t index = static_cast(pItemPtr - block.pItems); - pItemPtr->NextFreeIndex = block.FirstFreeIndex; - block.FirstFreeIndex = index; - return; - } - } - VMA_ASSERT(0 && "Pointer doesn't belong to this memory pool."); + VmaSuballocation suballoc = {}; + suballoc.offset = 0; + suballoc.size = size; + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + + m_Suballocations.push_back(suballoc); + m_FreeSuballocationsBySize.push_back(m_Suballocations.begin()); } -template -typename VmaPoolAllocator::ItemBlock& VmaPoolAllocator::CreateNewBlock() +bool VmaBlockMetadata_Generic::Validate() const { - const uint32_t newBlockCapacity = m_ItemBlocks.empty() ? - m_FirstBlockCapacity : m_ItemBlocks.back().Capacity * 3 / 2; - - const ItemBlock newBlock = { - vma_new_array(m_pAllocationCallbacks, Item, newBlockCapacity), - newBlockCapacity, - 0 }; - - m_ItemBlocks.push_back(newBlock); + VMA_VALIDATE(!m_Suballocations.empty()); - // Setup singly-linked list of all free items in this block. - for(uint32_t i = 0; i < newBlockCapacity - 1; ++i) - newBlock.pItems[i].NextFreeIndex = i + 1; - newBlock.pItems[newBlockCapacity - 1].NextFreeIndex = UINT32_MAX; - return m_ItemBlocks.back(); -} + // Expected offset of new suballocation as calculated from previous ones. + VkDeviceSize calculatedOffset = 0; + // Expected number of free suballocations as calculated from traversing their list. + uint32_t calculatedFreeCount = 0; + // Expected sum size of free suballocations as calculated from traversing their list. + VkDeviceSize calculatedSumFreeSize = 0; + // Expected number of free suballocations that should be registered in + // m_FreeSuballocationsBySize calculated from traversing their list. + size_t freeSuballocationsToRegister = 0; + // True if previous visited suballocation was free. + bool prevFree = false; -//////////////////////////////////////////////////////////////////////////////// -// class VmaRawList, VmaList + const VkDeviceSize debugMargin = GetDebugMargin(); -#if VMA_USE_STL_LIST + for (const auto& subAlloc : m_Suballocations) + { + // Actual offset of this suballocation doesn't match expected one. + VMA_VALIDATE(subAlloc.offset == calculatedOffset); -#define VmaList std::list + const bool currFree = (subAlloc.type == VMA_SUBALLOCATION_TYPE_FREE); + // Two adjacent free suballocations are invalid. They should be merged. + VMA_VALIDATE(!prevFree || !currFree); -#else // #if VMA_USE_STL_LIST + VmaAllocation alloc = (VmaAllocation)subAlloc.userData; + if (!IsVirtual()) + { + VMA_VALIDATE(currFree == (alloc == VK_NULL_HANDLE)); + } -template -struct VmaListItem -{ - VmaListItem* pPrev; - VmaListItem* pNext; - T Value; -}; + if (currFree) + { + calculatedSumFreeSize += subAlloc.size; + ++calculatedFreeCount; + ++freeSuballocationsToRegister; -// Doubly linked list. -template -class VmaRawList -{ - VMA_CLASS_NO_COPY(VmaRawList) -public: - typedef VmaListItem ItemType; + // Margin required between allocations - every free space must be at least that large. + VMA_VALIDATE(subAlloc.size >= debugMargin); + } + else + { + if (!IsVirtual()) + { + VMA_VALIDATE((VkDeviceSize)alloc->GetAllocHandle() == subAlloc.offset + 1); + VMA_VALIDATE(alloc->GetSize() == subAlloc.size); + } - VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks); - ~VmaRawList(); - void Clear(); + // Margin required between allocations - previous allocation must be free. + VMA_VALIDATE(debugMargin == 0 || prevFree); + } - size_t GetCount() const { return m_Count; } - bool IsEmpty() const { return m_Count == 0; } + calculatedOffset += subAlloc.size; + prevFree = currFree; + } - ItemType* Front() { return m_pFront; } - const ItemType* Front() const { return m_pFront; } - ItemType* Back() { return m_pBack; } - const ItemType* Back() const { return m_pBack; } + // Number of free suballocations registered in m_FreeSuballocationsBySize doesn't + // match expected one. + VMA_VALIDATE(m_FreeSuballocationsBySize.size() == freeSuballocationsToRegister); - ItemType* PushBack(); - ItemType* PushFront(); - ItemType* PushBack(const T& value); - ItemType* PushFront(const T& value); - void PopBack(); - void PopFront(); + VkDeviceSize lastSize = 0; + for (size_t i = 0; i < m_FreeSuballocationsBySize.size(); ++i) + { + VmaSuballocationList::iterator suballocItem = m_FreeSuballocationsBySize[i]; - // Item can be null - it means PushBack. - ItemType* InsertBefore(ItemType* pItem); - // Item can be null - it means PushFront. - ItemType* InsertAfter(ItemType* pItem); + // Only free suballocations can be registered in m_FreeSuballocationsBySize. + VMA_VALIDATE(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE); + // They must be sorted by size ascending. + VMA_VALIDATE(suballocItem->size >= lastSize); - ItemType* InsertBefore(ItemType* pItem, const T& value); - ItemType* InsertAfter(ItemType* pItem, const T& value); + lastSize = suballocItem->size; + } - void Remove(ItemType* pItem); + // Check if totals match calculated values. + VMA_VALIDATE(ValidateFreeSuballocationList()); + VMA_VALIDATE(calculatedOffset == GetSize()); + VMA_VALIDATE(calculatedSumFreeSize == m_SumFreeSize); + VMA_VALIDATE(calculatedFreeCount == m_FreeCount); -private: - const VkAllocationCallbacks* const m_pAllocationCallbacks; - VmaPoolAllocator m_ItemAllocator; - ItemType* m_pFront; - ItemType* m_pBack; - size_t m_Count; -}; + return true; +} -template -VmaRawList::VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks) : - m_pAllocationCallbacks(pAllocationCallbacks), - m_ItemAllocator(pAllocationCallbacks, 128), - m_pFront(VMA_NULL), - m_pBack(VMA_NULL), - m_Count(0) +void VmaBlockMetadata_Generic::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const { + const uint32_t rangeCount = (uint32_t)m_Suballocations.size(); + inoutStats.statistics.blockCount++; + inoutStats.statistics.blockBytes += GetSize(); + + for (const auto& suballoc : m_Suballocations) + { + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + VmaAddDetailedStatisticsAllocation(inoutStats, suballoc.size); + else + VmaAddDetailedStatisticsUnusedRange(inoutStats, suballoc.size); + } } -template -VmaRawList::~VmaRawList() +void VmaBlockMetadata_Generic::AddStatistics(VmaStatistics& inoutStats) const { - // Intentionally not calling Clear, because that would be unnecessary - // computations to return all items to m_ItemAllocator as free. + inoutStats.blockCount++; + inoutStats.allocationCount += (uint32_t)m_Suballocations.size() - m_FreeCount; + inoutStats.blockBytes += GetSize(); + inoutStats.allocationBytes += GetSize() - m_SumFreeSize; } -template -void VmaRawList::Clear() +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_Generic::PrintDetailedMap(class VmaJsonWriter& json, uint32_t mapRefCount) const { - if(IsEmpty() == false) + PrintDetailedMap_Begin(json, + m_SumFreeSize, // unusedBytes + m_Suballocations.size() - (size_t)m_FreeCount, // allocationCount + m_FreeCount, // unusedRangeCount + mapRefCount); + + for (const auto& suballoc : m_Suballocations) { - ItemType* pItem = m_pBack; - while(pItem != VMA_NULL) + if (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE) { - ItemType* const pPrevItem = pItem->pPrev; - m_ItemAllocator.Free(pItem); - pItem = pPrevItem; + PrintDetailedMap_UnusedRange(json, suballoc.offset, suballoc.size); + } + else + { + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.size, suballoc.userData); } - m_pFront = VMA_NULL; - m_pBack = VMA_NULL; - m_Count = 0; } -} -template -VmaListItem* VmaRawList::PushBack() + PrintDetailedMap_End(json); +} +#endif // VMA_STATS_STRING_ENABLED + +bool VmaBlockMetadata_Generic::CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) { - ItemType* const pNewItem = m_ItemAllocator.Alloc(); - pNewItem->pNext = VMA_NULL; - if(IsEmpty()) + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(!upperAddress); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(pAllocationRequest != VMA_NULL); + VMA_HEAVY_ASSERT(Validate()); + + allocSize = AlignAllocationSize(allocSize); + + pAllocationRequest->type = VmaAllocationRequestType::Normal; + pAllocationRequest->size = allocSize; + + const VkDeviceSize debugMargin = GetDebugMargin(); + + // There is not enough total free space in this block to fulfill the request: Early return. + if (m_SumFreeSize < allocSize + debugMargin) { - pNewItem->pPrev = VMA_NULL; - m_pFront = pNewItem; - m_pBack = pNewItem; - m_Count = 1; + return false; } - else + + // New algorithm, efficiently searching freeSuballocationsBySize. + const size_t freeSuballocCount = m_FreeSuballocationsBySize.size(); + if (freeSuballocCount > 0) { - pNewItem->pPrev = m_pBack; - m_pBack->pNext = pNewItem; - m_pBack = pNewItem; - ++m_Count; + if (strategy == 0 || + strategy == VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT) + { + // Find first free suballocation with size not less than allocSize + debugMargin. + VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( + m_FreeSuballocationsBySize.data(), + m_FreeSuballocationsBySize.data() + freeSuballocCount, + allocSize + debugMargin, + VmaSuballocationItemSizeLess()); + size_t index = it - m_FreeSuballocationsBySize.data(); + for (; index < freeSuballocCount; ++index) + { + if (CheckAllocation( + allocSize, + allocAlignment, + allocType, + m_FreeSuballocationsBySize[index], + &pAllocationRequest->allocHandle)) + { + pAllocationRequest->item = m_FreeSuballocationsBySize[index]; + return true; + } + } + } + else if (strategy == VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET) + { + for (VmaSuballocationList::iterator it = m_Suballocations.begin(); + it != m_Suballocations.end(); + ++it) + { + if (it->type == VMA_SUBALLOCATION_TYPE_FREE && CheckAllocation( + allocSize, + allocAlignment, + allocType, + it, + &pAllocationRequest->allocHandle)) + { + pAllocationRequest->item = it; + return true; + } + } + } + else + { + VMA_ASSERT(strategy & (VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT | VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT )); + // Search staring from biggest suballocations. + for (size_t index = freeSuballocCount; index--; ) + { + if (CheckAllocation( + allocSize, + allocAlignment, + allocType, + m_FreeSuballocationsBySize[index], + &pAllocationRequest->allocHandle)) + { + pAllocationRequest->item = m_FreeSuballocationsBySize[index]; + return true; + } + } + } } - return pNewItem; + + return false; } -template -VmaListItem* VmaRawList::PushFront() +VkResult VmaBlockMetadata_Generic::CheckCorruption(const void* pBlockData) { - ItemType* const pNewItem = m_ItemAllocator.Alloc(); - pNewItem->pPrev = VMA_NULL; - if(IsEmpty()) + for (auto& suballoc : m_Suballocations) { - pNewItem->pNext = VMA_NULL; - m_pFront = pNewItem; - m_pBack = pNewItem; - m_Count = 1; + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + if (!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_UNKNOWN_COPY; + } + } } - else + + return VK_SUCCESS; +} + +void VmaBlockMetadata_Generic::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) +{ + VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); + VMA_ASSERT(request.item != m_Suballocations.end()); + VmaSuballocation& suballoc = *request.item; + // Given suballocation is a free block. + VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + // Given offset is inside this suballocation. + VMA_ASSERT((VkDeviceSize)request.allocHandle - 1 >= suballoc.offset); + const VkDeviceSize paddingBegin = (VkDeviceSize)request.allocHandle - suballoc.offset - 1; + VMA_ASSERT(suballoc.size >= paddingBegin + request.size); + const VkDeviceSize paddingEnd = suballoc.size - paddingBegin - request.size; + + // Unregister this free suballocation from m_FreeSuballocationsBySize and update + // it to become used. + UnregisterFreeSuballocation(request.item); + + suballoc.offset = (VkDeviceSize)request.allocHandle - 1; + suballoc.size = request.size; + suballoc.type = type; + suballoc.userData = userData; + + // If there are any free bytes remaining at the end, insert new free suballocation after current one. + if (paddingEnd) { - pNewItem->pNext = m_pFront; - m_pFront->pPrev = pNewItem; - m_pFront = pNewItem; - ++m_Count; + VmaSuballocation paddingSuballoc = {}; + paddingSuballoc.offset = suballoc.offset + suballoc.size; + paddingSuballoc.size = paddingEnd; + paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + VmaSuballocationList::iterator next = request.item; + ++next; + const VmaSuballocationList::iterator paddingEndItem = + m_Suballocations.insert(next, paddingSuballoc); + RegisterFreeSuballocation(paddingEndItem); } - return pNewItem; + + // If there are any free bytes remaining at the beginning, insert new free suballocation before current one. + if (paddingBegin) + { + VmaSuballocation paddingSuballoc = {}; + paddingSuballoc.offset = suballoc.offset - paddingBegin; + paddingSuballoc.size = paddingBegin; + paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + const VmaSuballocationList::iterator paddingBeginItem = + m_Suballocations.insert(request.item, paddingSuballoc); + RegisterFreeSuballocation(paddingBeginItem); + } + + // Update totals. + m_FreeCount = m_FreeCount - 1; + if (paddingBegin > 0) + { + ++m_FreeCount; + } + if (paddingEnd > 0) + { + ++m_FreeCount; + } + m_SumFreeSize -= request.size; } -template -VmaListItem* VmaRawList::PushBack(const T& value) +void VmaBlockMetadata_Generic::GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) { - ItemType* const pNewItem = PushBack(); - pNewItem->Value = value; - return pNewItem; + outInfo.offset = (VkDeviceSize)allocHandle - 1; + const VmaSuballocation& suballoc = *FindAtOffset(outInfo.offset); + outInfo.size = suballoc.size; + outInfo.pUserData = suballoc.userData; } -template -VmaListItem* VmaRawList::PushFront(const T& value) +void* VmaBlockMetadata_Generic::GetAllocationUserData(VmaAllocHandle allocHandle) const { - ItemType* const pNewItem = PushFront(); - pNewItem->Value = value; - return pNewItem; + return FindAtOffset((VkDeviceSize)allocHandle - 1)->userData; } -template -void VmaRawList::PopBack() +VmaAllocHandle VmaBlockMetadata_Generic::GetAllocationListBegin() const { - VMA_HEAVY_ASSERT(m_Count > 0); - ItemType* const pBackItem = m_pBack; - ItemType* const pPrevItem = pBackItem->pPrev; - if(pPrevItem != VMA_NULL) + if (IsEmpty()) + return VK_NULL_HANDLE; + + for (const auto& suballoc : m_Suballocations) { - pPrevItem->pNext = VMA_NULL; + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + return (VmaAllocHandle)(suballoc.offset + 1); } - m_pBack = pPrevItem; - m_ItemAllocator.Free(pBackItem); - --m_Count; + VMA_ASSERT(false && "Should contain at least 1 allocation!"); + return VK_NULL_HANDLE; } -template -void VmaRawList::PopFront() +VmaAllocHandle VmaBlockMetadata_Generic::GetNextAllocation(VmaAllocHandle prevAlloc) const { - VMA_HEAVY_ASSERT(m_Count > 0); - ItemType* const pFrontItem = m_pFront; - ItemType* const pNextItem = pFrontItem->pNext; - if(pNextItem != VMA_NULL) + VmaSuballocationList::const_iterator prev = FindAtOffset((VkDeviceSize)prevAlloc - 1); + + for (VmaSuballocationList::const_iterator it = ++prev; it != m_Suballocations.end(); ++it) { - pNextItem->pPrev = VMA_NULL; + if (it->type != VMA_SUBALLOCATION_TYPE_FREE) + return (VmaAllocHandle)(it->offset + 1); } - m_pFront = pNextItem; - m_ItemAllocator.Free(pFrontItem); - --m_Count; + return VK_NULL_HANDLE; } -template -void VmaRawList::Remove(ItemType* pItem) +void VmaBlockMetadata_Generic::Clear() { - VMA_HEAVY_ASSERT(pItem != VMA_NULL); - VMA_HEAVY_ASSERT(m_Count > 0); + const VkDeviceSize size = GetSize(); - if(pItem->pPrev != VMA_NULL) - { - pItem->pPrev->pNext = pItem->pNext; - } - else - { - VMA_HEAVY_ASSERT(m_pFront == pItem); - m_pFront = pItem->pNext; - } + VMA_ASSERT(IsVirtual()); + m_FreeCount = 1; + m_SumFreeSize = size; + m_Suballocations.clear(); + m_FreeSuballocationsBySize.clear(); - if(pItem->pNext != VMA_NULL) - { - pItem->pNext->pPrev = pItem->pPrev; - } - else + VmaSuballocation suballoc = {}; + suballoc.offset = 0; + suballoc.size = size; + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + m_Suballocations.push_back(suballoc); + + m_FreeSuballocationsBySize.push_back(m_Suballocations.begin()); +} + +void VmaBlockMetadata_Generic::SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) +{ + VmaSuballocation& suballoc = *FindAtOffset((VkDeviceSize)allocHandle - 1); + suballoc.userData = userData; +} + +void VmaBlockMetadata_Generic::DebugLogAllAllocations() const +{ + for (const auto& suballoc : m_Suballocations) { - VMA_HEAVY_ASSERT(m_pBack == pItem); - m_pBack = pItem->pPrev; + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + DebugLogAllocation(suballoc.offset, suballoc.size, suballoc.userData); } - - m_ItemAllocator.Free(pItem); - --m_Count; } -template -VmaListItem* VmaRawList::InsertBefore(ItemType* pItem) +VmaSuballocationList::iterator VmaBlockMetadata_Generic::FindAtOffset(VkDeviceSize offset) const { - if(pItem != VMA_NULL) + VMA_HEAVY_ASSERT(!m_Suballocations.empty()); + const VkDeviceSize last = m_Suballocations.rbegin()->offset; + if (last == offset) + return m_Suballocations.rbegin().drop_const(); + const VkDeviceSize first = m_Suballocations.begin()->offset; + if (first == offset) + return m_Suballocations.begin().drop_const(); + + const size_t suballocCount = m_Suballocations.size(); + const VkDeviceSize step = (last - first + m_Suballocations.begin()->size) / suballocCount; + auto findSuballocation = [&](auto begin, auto end) -> VmaSuballocationList::iterator { - ItemType* const prevItem = pItem->pPrev; - ItemType* const newItem = m_ItemAllocator.Alloc(); - newItem->pPrev = prevItem; - newItem->pNext = pItem; - pItem->pPrev = newItem; - if(prevItem != VMA_NULL) - { - prevItem->pNext = newItem; - } - else + for (auto suballocItem = begin; + suballocItem != end; + ++suballocItem) { - VMA_HEAVY_ASSERT(m_pFront == pItem); - m_pFront = newItem; + if (suballocItem->offset == offset) + return suballocItem.drop_const(); } - ++m_Count; - return newItem; + VMA_ASSERT(false && "Not found!"); + return m_Suballocations.end().drop_const(); + }; + // If requested offset is closer to the end of range, search from the end + if (offset - first > suballocCount * step / 2) + { + return findSuballocation(m_Suballocations.rbegin(), m_Suballocations.rend()); } - else - return PushBack(); + return findSuballocation(m_Suballocations.begin(), m_Suballocations.end()); } -template -VmaListItem* VmaRawList::InsertAfter(ItemType* pItem) +bool VmaBlockMetadata_Generic::ValidateFreeSuballocationList() const { - if(pItem != VMA_NULL) + VkDeviceSize lastSize = 0; + for (size_t i = 0, count = m_FreeSuballocationsBySize.size(); i < count; ++i) { - ItemType* const nextItem = pItem->pNext; - ItemType* const newItem = m_ItemAllocator.Alloc(); - newItem->pNext = nextItem; - newItem->pPrev = pItem; - pItem->pNext = newItem; - if(nextItem != VMA_NULL) - { - nextItem->pPrev = newItem; - } - else - { - VMA_HEAVY_ASSERT(m_pBack == pItem); - m_pBack = newItem; - } - ++m_Count; - return newItem; + const VmaSuballocationList::iterator it = m_FreeSuballocationsBySize[i]; + + VMA_VALIDATE(it->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_VALIDATE(it->size >= lastSize); + lastSize = it->size; } - else - return PushFront(); + return true; } -template -VmaListItem* VmaRawList::InsertBefore(ItemType* pItem, const T& value) +bool VmaBlockMetadata_Generic::CheckAllocation( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaSuballocationList::const_iterator suballocItem, + VmaAllocHandle* pAllocHandle) const { - ItemType* const newItem = InsertBefore(pItem); - newItem->Value = value; - return newItem; -} + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(suballocItem != m_Suballocations.cend()); + VMA_ASSERT(pAllocHandle != VMA_NULL); -template -VmaListItem* VmaRawList::InsertAfter(ItemType* pItem, const T& value) -{ - ItemType* const newItem = InsertAfter(pItem); - newItem->Value = value; - return newItem; -} + const VkDeviceSize debugMargin = GetDebugMargin(); + const VkDeviceSize bufferImageGranularity = GetBufferImageGranularity(); -template -class VmaList -{ - VMA_CLASS_NO_COPY(VmaList) -public: - class iterator + const VmaSuballocation& suballoc = *suballocItem; + VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + // Size of this suballocation is too small for this request: Early return. + if (suballoc.size < allocSize) { - public: - iterator() : - m_pList(VMA_NULL), - m_pItem(VMA_NULL) - { - } + return false; + } - T& operator*() const - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - return m_pItem->Value; - } - T* operator->() const - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - return &m_pItem->Value; - } + // Start from offset equal to beginning of this suballocation. + VkDeviceSize offset = suballoc.offset + (suballocItem == m_Suballocations.cbegin() ? 0 : GetDebugMargin()); - iterator& operator++() - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - m_pItem = m_pItem->pNext; - return *this; - } - iterator& operator--() + // Apply debugMargin from the end of previous alloc. + if (debugMargin > 0) + { + offset += debugMargin; + } + + // Apply alignment. + offset = VmaAlignUp(offset, allocAlignment); + + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if (bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment) + { + bool bufferImageGranularityConflict = false; + VmaSuballocationList::const_iterator prevSuballocItem = suballocItem; + while (prevSuballocItem != m_Suballocations.cbegin()) { - if(m_pItem != VMA_NULL) + --prevSuballocItem; + const VmaSuballocation& prevSuballoc = *prevSuballocItem; + if (VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, offset, bufferImageGranularity)) { - m_pItem = m_pItem->pPrev; + if (VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } } else - { - VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); - m_pItem = m_pList->Back(); - } - return *this; - } - - iterator operator++(int) - { - iterator result = *this; - ++*this; - return result; - } - iterator operator--(int) - { - iterator result = *this; - --*this; - return result; - } - - bool operator==(const iterator& rhs) const - { - VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); - return m_pItem == rhs.m_pItem; - } - bool operator!=(const iterator& rhs) const - { - VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); - return m_pItem != rhs.m_pItem; + // Already on previous page. + break; } - - private: - VmaRawList* m_pList; - VmaListItem* m_pItem; - - iterator(VmaRawList* pList, VmaListItem* pItem) : - m_pList(pList), - m_pItem(pItem) + if (bufferImageGranularityConflict) { + offset = VmaAlignUp(offset, bufferImageGranularity); } + } - friend class VmaList; - }; + // Calculate padding at the beginning based on current offset. + const VkDeviceSize paddingBegin = offset - suballoc.offset; - class const_iterator + // Fail if requested size plus margin after is bigger than size of this suballocation. + if (paddingBegin + allocSize + debugMargin > suballoc.size) { - public: - const_iterator() : - m_pList(VMA_NULL), - m_pItem(VMA_NULL) - { - } - - const_iterator(const iterator& src) : - m_pList(src.m_pList), - m_pItem(src.m_pItem) - { - } - - const T& operator*() const - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - return m_pItem->Value; - } - const T* operator->() const - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - return &m_pItem->Value; - } + return false; + } - const_iterator& operator++() - { - VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); - m_pItem = m_pItem->pNext; - return *this; - } - const_iterator& operator--() + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if (allocSize % bufferImageGranularity || offset % bufferImageGranularity) + { + VmaSuballocationList::const_iterator nextSuballocItem = suballocItem; + ++nextSuballocItem; + while (nextSuballocItem != m_Suballocations.cend()) { - if(m_pItem != VMA_NULL) + const VmaSuballocation& nextSuballoc = *nextSuballocItem; + if (VmaBlocksOnSamePage(offset, allocSize, nextSuballoc.offset, bufferImageGranularity)) { - m_pItem = m_pItem->pPrev; + if (VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } } else { - VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); - m_pItem = m_pList->Back(); + // Already on next page. + break; } - return *this; - } - - const_iterator operator++(int) - { - const_iterator result = *this; - ++*this; - return result; - } - const_iterator operator--(int) - { - const_iterator result = *this; - --*this; - return result; - } - - bool operator==(const const_iterator& rhs) const - { - VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); - return m_pItem == rhs.m_pItem; - } - bool operator!=(const const_iterator& rhs) const - { - VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); - return m_pItem != rhs.m_pItem; - } - - private: - const_iterator(const VmaRawList* pList, const VmaListItem* pItem) : - m_pList(pList), - m_pItem(pItem) - { + ++nextSuballocItem; } + } - const VmaRawList* m_pList; - const VmaListItem* m_pItem; + *pAllocHandle = (VmaAllocHandle)(offset + 1); + // All tests passed: Success. pAllocHandle is already filled. + return true; +} - friend class VmaList; - }; +void VmaBlockMetadata_Generic::MergeFreeWithNext(VmaSuballocationList::iterator item) +{ + VMA_ASSERT(item != m_Suballocations.end()); + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); - VmaList(const AllocatorT& allocator) : m_RawList(allocator.m_pCallbacks) { } + VmaSuballocationList::iterator nextItem = item; + ++nextItem; + VMA_ASSERT(nextItem != m_Suballocations.end()); + VMA_ASSERT(nextItem->type == VMA_SUBALLOCATION_TYPE_FREE); - bool empty() const { return m_RawList.IsEmpty(); } - size_t size() const { return m_RawList.GetCount(); } + item->size += nextItem->size; + --m_FreeCount; + m_Suballocations.erase(nextItem); +} - iterator begin() { return iterator(&m_RawList, m_RawList.Front()); } - iterator end() { return iterator(&m_RawList, VMA_NULL); } +VmaSuballocationList::iterator VmaBlockMetadata_Generic::FreeSuballocation(VmaSuballocationList::iterator suballocItem) +{ + // Change this suballocation to be marked as free. + VmaSuballocation& suballoc = *suballocItem; + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.userData = VMA_NULL; - const_iterator cbegin() const { return const_iterator(&m_RawList, m_RawList.Front()); } - const_iterator cend() const { return const_iterator(&m_RawList, VMA_NULL); } + // Update totals. + ++m_FreeCount; + m_SumFreeSize += suballoc.size; - void clear() { m_RawList.Clear(); } - void push_back(const T& value) { m_RawList.PushBack(value); } - void erase(iterator it) { m_RawList.Remove(it.m_pItem); } - iterator insert(iterator it, const T& value) { return iterator(&m_RawList, m_RawList.InsertBefore(it.m_pItem, value)); } + // Merge with previous and/or next suballocation if it's also free. + bool mergeWithNext = false; + bool mergeWithPrev = false; -private: - VmaRawList m_RawList; -}; + VmaSuballocationList::iterator nextItem = suballocItem; + ++nextItem; + if ((nextItem != m_Suballocations.end()) && (nextItem->type == VMA_SUBALLOCATION_TYPE_FREE)) + { + mergeWithNext = true; + } -#endif // #if VMA_USE_STL_LIST + VmaSuballocationList::iterator prevItem = suballocItem; + if (suballocItem != m_Suballocations.begin()) + { + --prevItem; + if (prevItem->type == VMA_SUBALLOCATION_TYPE_FREE) + { + mergeWithPrev = true; + } + } -//////////////////////////////////////////////////////////////////////////////// -// class VmaMap + if (mergeWithNext) + { + UnregisterFreeSuballocation(nextItem); + MergeFreeWithNext(suballocItem); + } -// Unused in this version. -#if 0 + if (mergeWithPrev) + { + UnregisterFreeSuballocation(prevItem); + MergeFreeWithNext(prevItem); + RegisterFreeSuballocation(prevItem); + return prevItem; + } + else + { + RegisterFreeSuballocation(suballocItem); + return suballocItem; + } +} -#if VMA_USE_STL_UNORDERED_MAP +void VmaBlockMetadata_Generic::RegisterFreeSuballocation(VmaSuballocationList::iterator item) +{ + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(item->size > 0); -#define VmaPair std::pair + // You may want to enable this validation at the beginning or at the end of + // this function, depending on what do you want to check. + VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); -#define VMA_MAP_TYPE(KeyT, ValueT) \ - std::unordered_map< KeyT, ValueT, std::hash, std::equal_to, VmaStlAllocator< std::pair > > + if (m_FreeSuballocationsBySize.empty()) + { + m_FreeSuballocationsBySize.push_back(item); + } + else + { + VmaVectorInsertSorted(m_FreeSuballocationsBySize, item); + } -#else // #if VMA_USE_STL_UNORDERED_MAP + //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); +} -template -struct VmaPair +void VmaBlockMetadata_Generic::UnregisterFreeSuballocation(VmaSuballocationList::iterator item) { - T1 first; - T2 second; + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(item->size > 0); - VmaPair() : first(), second() { } - VmaPair(const T1& firstSrc, const T2& secondSrc) : first(firstSrc), second(secondSrc) { } -}; + // You may want to enable this validation at the beginning or at the end of + // this function, depending on what do you want to check. + VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); -/* Class compatible with subset of interface of std::unordered_map. -KeyT, ValueT must be POD because they will be stored in VmaVector. -*/ -template -class VmaMap -{ -public: - typedef VmaPair PairType; - typedef PairType* iterator; + VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( + m_FreeSuballocationsBySize.data(), + m_FreeSuballocationsBySize.data() + m_FreeSuballocationsBySize.size(), + item, + VmaSuballocationItemSizeLess()); + for (size_t index = it - m_FreeSuballocationsBySize.data(); + index < m_FreeSuballocationsBySize.size(); + ++index) + { + if (m_FreeSuballocationsBySize[index] == item) + { + VmaVectorRemove(m_FreeSuballocationsBySize, index); + return; + } + VMA_ASSERT((m_FreeSuballocationsBySize[index]->size == item->size) && "Not found."); + } + VMA_ASSERT(0 && "Not found."); - VmaMap(const VmaStlAllocator& allocator) : m_Vector(allocator) { } + //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); +} +#endif // _VMA_BLOCK_METADATA_GENERIC_FUNCTIONS +#endif // _VMA_BLOCK_METADATA_GENERIC +#endif // #if 0 - iterator begin() { return m_Vector.begin(); } - iterator end() { return m_Vector.end(); } +#ifndef _VMA_BLOCK_METADATA_LINEAR +/* +Allocations and their references in internal data structure look like this: - void insert(const PairType& pair); - iterator find(const KeyT& key); - void erase(iterator it); +if(m_2ndVectorMode == SECOND_VECTOR_EMPTY): -private: - VmaVector< PairType, VmaStlAllocator > m_Vector; -}; + 0 +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | + | | + | | +GetSize() +-------+ -#define VMA_MAP_TYPE(KeyT, ValueT) VmaMap +if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER): -template -struct VmaPairFirstLess -{ - bool operator()(const VmaPair& lhs, const VmaPair& rhs) const - { - return lhs.first < rhs.first; - } - bool operator()(const VmaPair& lhs, const FirstT& rhsFirst) const - { - return lhs.first < rhsFirst; - } -}; + 0 +-------+ + | Alloc | 2nd[0] + +-------+ + | Alloc | 2nd[1] + +-------+ + | ... | + +-------+ + | Alloc | 2nd[2nd.size() - 1] + +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | +GetSize() +-------+ -template -void VmaMap::insert(const PairType& pair) -{ - const size_t indexToInsert = VmaBinaryFindFirstNotLess( - m_Vector.data(), - m_Vector.data() + m_Vector.size(), - pair, - VmaPairFirstLess()) - m_Vector.data(); - VmaVectorInsert(m_Vector, indexToInsert, pair); -} +if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK): -template -VmaPair* VmaMap::find(const KeyT& key) -{ - PairType* it = VmaBinaryFindFirstNotLess( - m_Vector.data(), - m_Vector.data() + m_Vector.size(), - key, - VmaPairFirstLess()); - if((it != m_Vector.end()) && (it->first == key)) - { - return it; - } - else - { - return m_Vector.end(); - } -} + 0 +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | + | | + | | + +-------+ + | Alloc | 2nd[2nd.size() - 1] + +-------+ + | ... | + +-------+ + | Alloc | 2nd[1] + +-------+ + | Alloc | 2nd[0] +GetSize() +-------+ -template -void VmaMap::erase(iterator it) +*/ +class VmaBlockMetadata_Linear : public VmaBlockMetadata { - VmaVectorRemove(m_Vector, it - m_Vector.begin()); -} - -#endif // #if VMA_USE_STL_UNORDERED_MAP + VMA_CLASS_NO_COPY(VmaBlockMetadata_Linear) +public: + VmaBlockMetadata_Linear(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual); + virtual ~VmaBlockMetadata_Linear() = default; -#endif // #if 0 + VkDeviceSize GetSumFreeSize() const override { return m_SumFreeSize; } + bool IsEmpty() const override { return GetAllocationCount() == 0; } + VkDeviceSize GetAllocationOffset(VmaAllocHandle allocHandle) const override { return (VkDeviceSize)allocHandle - 1; }; -//////////////////////////////////////////////////////////////////////////////// + void Init(VkDeviceSize size) override; + bool Validate() const override; + size_t GetAllocationCount() const override; + size_t GetFreeRegionsCount() const override; -class VmaDeviceMemoryBlock; + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const override; + void AddStatistics(VmaStatistics& inoutStats) const override; -enum VMA_CACHE_OPERATION { VMA_CACHE_FLUSH, VMA_CACHE_INVALIDATE }; +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json) const override; +#endif -struct VmaAllocation_T -{ -private: - static const uint8_t MAP_COUNT_FLAG_PERSISTENT_MAP = 0x80; + bool CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) override; - enum FLAGS - { - FLAG_USER_DATA_STRING = 0x01, - }; + VkResult CheckCorruption(const void* pBlockData) override; -public: - enum ALLOCATION_TYPE - { - ALLOCATION_TYPE_NONE, - ALLOCATION_TYPE_BLOCK, - ALLOCATION_TYPE_DEDICATED, - }; + void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) override; + + void Free(VmaAllocHandle allocHandle) override; + void GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) override; + void* GetAllocationUserData(VmaAllocHandle allocHandle) const override; + VmaAllocHandle GetAllocationListBegin() const override; + VmaAllocHandle GetNextAllocation(VmaAllocHandle prevAlloc) const override; + VkDeviceSize GetNextFreeRegionSize(VmaAllocHandle alloc) const override; + void Clear() override; + void SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) override; + void DebugLogAllAllocations() const override; +private: /* - This struct is allocated using VmaPoolAllocator. + There are two suballocation vectors, used in ping-pong way. + The one with index m_1stVectorIndex is called 1st. + The one with index (m_1stVectorIndex ^ 1) is called 2nd. + 2nd can be non-empty only when 1st is not empty. + When 2nd is not empty, m_2ndVectorMode indicates its mode of operation. */ + typedef VmaVector> SuballocationVectorType; - VmaAllocation_T(uint32_t currentFrameIndex, bool userDataString) : - m_Alignment{1}, - m_Size{0}, - m_pUserData{VMA_NULL}, - m_LastUseFrameIndex{currentFrameIndex}, - m_MemoryTypeIndex{0}, - m_Type{(uint8_t)ALLOCATION_TYPE_NONE}, - m_SuballocationType{(uint8_t)VMA_SUBALLOCATION_TYPE_UNKNOWN}, - m_MapCount{0}, - m_Flags{userDataString ? (uint8_t)FLAG_USER_DATA_STRING : (uint8_t)0} + enum SECOND_VECTOR_MODE { -#if VMA_STATS_STRING_ENABLED - m_CreationFrameIndex = currentFrameIndex; - m_BufferImageUsage = 0; -#endif - } + SECOND_VECTOR_EMPTY, + /* + Suballocations in 2nd vector are created later than the ones in 1st, but they + all have smaller offset. + */ + SECOND_VECTOR_RING_BUFFER, + /* + Suballocations in 2nd vector are upper side of double stack. + They all have offsets higher than those in 1st vector. + Top of this stack means smaller offsets, but higher indices in this vector. + */ + SECOND_VECTOR_DOUBLE_STACK, + }; - ~VmaAllocation_T() - { - VMA_ASSERT((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) == 0 && "Allocation was not unmapped before destruction."); + VkDeviceSize m_SumFreeSize; + SuballocationVectorType m_Suballocations0, m_Suballocations1; + uint32_t m_1stVectorIndex; + SECOND_VECTOR_MODE m_2ndVectorMode; + // Number of items in 1st vector with hAllocation = null at the beginning. + size_t m_1stNullItemsBeginCount; + // Number of other items in 1st vector with hAllocation = null somewhere in the middle. + size_t m_1stNullItemsMiddleCount; + // Number of items in 2nd vector with hAllocation = null. + size_t m_2ndNullItemsCount; - // Check if owned string was freed. - VMA_ASSERT(m_pUserData == VMA_NULL); - } + SuballocationVectorType& AccessSuballocations1st() { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } + SuballocationVectorType& AccessSuballocations2nd() { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } + const SuballocationVectorType& AccessSuballocations1st() const { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } + const SuballocationVectorType& AccessSuballocations2nd() const { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } - void InitBlockAllocation( - VmaDeviceMemoryBlock* block, - VkDeviceSize offset, - VkDeviceSize alignment, - VkDeviceSize size, - uint32_t memoryTypeIndex, - VmaSuballocationType suballocationType, - bool mapped, - bool canBecomeLost) - { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); - VMA_ASSERT(block != VMA_NULL); - m_Type = (uint8_t)ALLOCATION_TYPE_BLOCK; - m_Alignment = alignment; - m_Size = size; - m_MemoryTypeIndex = memoryTypeIndex; - m_MapCount = mapped ? MAP_COUNT_FLAG_PERSISTENT_MAP : 0; - m_SuballocationType = (uint8_t)suballocationType; - m_BlockAllocation.m_Block = block; - m_BlockAllocation.m_Offset = offset; - m_BlockAllocation.m_CanBecomeLost = canBecomeLost; - } - - void InitLost() - { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); - VMA_ASSERT(m_LastUseFrameIndex.load() == VMA_FRAME_INDEX_LOST); - m_Type = (uint8_t)ALLOCATION_TYPE_BLOCK; - m_MemoryTypeIndex = 0; - m_BlockAllocation.m_Block = VMA_NULL; - m_BlockAllocation.m_Offset = 0; - m_BlockAllocation.m_CanBecomeLost = true; - } - - void ChangeBlockAllocation( - VmaAllocator hAllocator, - VmaDeviceMemoryBlock* block, - VkDeviceSize offset); + VmaSuballocation& FindSuballocation(VkDeviceSize offset) const; + bool ShouldCompact1st() const; + void CleanupAfterFree(); - void ChangeOffset(VkDeviceSize newOffset); + bool CreateAllocationRequest_LowerAddress( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); + bool CreateAllocationRequest_UpperAddress( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); +}; - // pMappedData not null means allocation is created with MAPPED flag. - void InitDedicatedAllocation( - uint32_t memoryTypeIndex, - VkDeviceMemory hMemory, - VmaSuballocationType suballocationType, - void* pMappedData, - VkDeviceSize size) - { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); - VMA_ASSERT(hMemory != VK_NULL_HANDLE); - m_Type = (uint8_t)ALLOCATION_TYPE_DEDICATED; - m_Alignment = 0; - m_Size = size; - m_MemoryTypeIndex = memoryTypeIndex; - m_SuballocationType = (uint8_t)suballocationType; - m_MapCount = (pMappedData != VMA_NULL) ? MAP_COUNT_FLAG_PERSISTENT_MAP : 0; - m_DedicatedAllocation.m_hMemory = hMemory; - m_DedicatedAllocation.m_pMappedData = pMappedData; - } +#ifndef _VMA_BLOCK_METADATA_LINEAR_FUNCTIONS +VmaBlockMetadata_Linear::VmaBlockMetadata_Linear(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual) + : VmaBlockMetadata(pAllocationCallbacks, bufferImageGranularity, isVirtual), + m_SumFreeSize(0), + m_Suballocations0(VmaStlAllocator(pAllocationCallbacks)), + m_Suballocations1(VmaStlAllocator(pAllocationCallbacks)), + m_1stVectorIndex(0), + m_2ndVectorMode(SECOND_VECTOR_EMPTY), + m_1stNullItemsBeginCount(0), + m_1stNullItemsMiddleCount(0), + m_2ndNullItemsCount(0) {} - ALLOCATION_TYPE GetType() const { return (ALLOCATION_TYPE)m_Type; } - VkDeviceSize GetAlignment() const { return m_Alignment; } - VkDeviceSize GetSize() const { return m_Size; } - bool IsUserDataString() const { return (m_Flags & FLAG_USER_DATA_STRING) != 0; } - void* GetUserData() const { return m_pUserData; } - void SetUserData(VmaAllocator hAllocator, void* pUserData); - VmaSuballocationType GetSuballocationType() const { return (VmaSuballocationType)m_SuballocationType; } +void VmaBlockMetadata_Linear::Init(VkDeviceSize size) +{ + VmaBlockMetadata::Init(size); + m_SumFreeSize = size; +} - VmaDeviceMemoryBlock* GetBlock() const - { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); - return m_BlockAllocation.m_Block; - } - VkDeviceSize GetOffset() const; - VkDeviceMemory GetMemory() const; - uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } - bool IsPersistentMap() const { return (m_MapCount & MAP_COUNT_FLAG_PERSISTENT_MAP) != 0; } - void* GetMappedData() const; - bool CanBecomeLost() const; +bool VmaBlockMetadata_Linear::Validate() const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + VMA_VALIDATE(suballocations2nd.empty() == (m_2ndVectorMode == SECOND_VECTOR_EMPTY)); + VMA_VALIDATE(!suballocations1st.empty() || + suballocations2nd.empty() || + m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER); - uint32_t GetLastUseFrameIndex() const + if (!suballocations1st.empty()) { - return m_LastUseFrameIndex.load(); + // Null item at the beginning should be accounted into m_1stNullItemsBeginCount. + VMA_VALIDATE(suballocations1st[m_1stNullItemsBeginCount].type != VMA_SUBALLOCATION_TYPE_FREE); + // Null item at the end should be just pop_back(). + VMA_VALIDATE(suballocations1st.back().type != VMA_SUBALLOCATION_TYPE_FREE); } - bool CompareExchangeLastUseFrameIndex(uint32_t& expected, uint32_t desired) + if (!suballocations2nd.empty()) { - return m_LastUseFrameIndex.compare_exchange_weak(expected, desired); + // Null item at the end should be just pop_back(). + VMA_VALIDATE(suballocations2nd.back().type != VMA_SUBALLOCATION_TYPE_FREE); } - /* - - If hAllocation.LastUseFrameIndex + frameInUseCount < allocator.CurrentFrameIndex, - makes it lost by setting LastUseFrameIndex = VMA_FRAME_INDEX_LOST and returns true. - - Else, returns false. - If hAllocation is already lost, assert - you should not call it then. - If hAllocation was not created with CAN_BECOME_LOST_BIT, assert. - */ - bool MakeLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + VMA_VALIDATE(m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount <= suballocations1st.size()); + VMA_VALIDATE(m_2ndNullItemsCount <= suballocations2nd.size()); + + VkDeviceSize sumUsedSize = 0; + const size_t suballoc1stCount = suballocations1st.size(); + const VkDeviceSize debugMargin = GetDebugMargin(); + VkDeviceSize offset = 0; - void DedicatedAllocCalcStatsInfo(VmaStatInfo& outInfo) + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_DEDICATED); - outInfo.blockCount = 1; - outInfo.allocationCount = 1; - outInfo.unusedRangeCount = 0; - outInfo.usedBytes = m_Size; - outInfo.unusedBytes = 0; - outInfo.allocationSizeMin = outInfo.allocationSizeMax = m_Size; - outInfo.unusedRangeSizeMin = UINT64_MAX; - outInfo.unusedRangeSizeMax = 0; - } + const size_t suballoc2ndCount = suballocations2nd.size(); + size_t nullItem2ndCount = 0; + for (size_t i = 0; i < suballoc2ndCount; ++i) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); - void BlockAllocMap(); - void BlockAllocUnmap(); - VkResult DedicatedAllocMap(VmaAllocator hAllocator, void** ppData); - void DedicatedAllocUnmap(VmaAllocator hAllocator); + VmaAllocation const alloc = (VmaAllocation)suballoc.userData; + if (!IsVirtual()) + { + VMA_VALIDATE(currFree == (alloc == VK_NULL_HANDLE)); + } + VMA_VALIDATE(suballoc.offset >= offset); -#if VMA_STATS_STRING_ENABLED - uint32_t GetCreationFrameIndex() const { return m_CreationFrameIndex; } - uint32_t GetBufferImageUsage() const { return m_BufferImageUsage; } + if (!currFree) + { + if (!IsVirtual()) + { + VMA_VALIDATE((VkDeviceSize)alloc->GetAllocHandle() == suballoc.offset + 1); + VMA_VALIDATE(alloc->GetSize() == suballoc.size); + } + sumUsedSize += suballoc.size; + } + else + { + ++nullItem2ndCount; + } - void InitBufferImageUsage(uint32_t bufferImageUsage) - { - VMA_ASSERT(m_BufferImageUsage == 0); - m_BufferImageUsage = bufferImageUsage; + offset = suballoc.offset + suballoc.size + debugMargin; + } + + VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); } - void PrintParameters(class VmaJsonWriter& json) const; -#endif + for (size_t i = 0; i < m_1stNullItemsBeginCount; ++i) + { + const VmaSuballocation& suballoc = suballocations1st[i]; + VMA_VALIDATE(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE && + suballoc.userData == VMA_NULL); + } -private: - VkDeviceSize m_Alignment; - VkDeviceSize m_Size; - void* m_pUserData; - VMA_ATOMIC_UINT32 m_LastUseFrameIndex; - uint32_t m_MemoryTypeIndex; - uint8_t m_Type; // ALLOCATION_TYPE - uint8_t m_SuballocationType; // VmaSuballocationType - // Bit 0x80 is set when allocation was created with VMA_ALLOCATION_CREATE_MAPPED_BIT. - // Bits with mask 0x7F are reference counter for vmaMapMemory()/vmaUnmapMemory(). - uint8_t m_MapCount; - uint8_t m_Flags; // enum FLAGS + size_t nullItem1stCount = m_1stNullItemsBeginCount; - // Allocation out of VmaDeviceMemoryBlock. - struct BlockAllocation + for (size_t i = m_1stNullItemsBeginCount; i < suballoc1stCount; ++i) { - VmaDeviceMemoryBlock* m_Block; - VkDeviceSize m_Offset; - bool m_CanBecomeLost; - }; + const VmaSuballocation& suballoc = suballocations1st[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); - // Allocation for an object that has its own private VkDeviceMemory. - struct DedicatedAllocation - { - VkDeviceMemory m_hMemory; - void* m_pMappedData; // Not null means memory is mapped. - }; + VmaAllocation const alloc = (VmaAllocation)suballoc.userData; + if (!IsVirtual()) + { + VMA_VALIDATE(currFree == (alloc == VK_NULL_HANDLE)); + } + VMA_VALIDATE(suballoc.offset >= offset); + VMA_VALIDATE(i >= m_1stNullItemsBeginCount || currFree); - union + if (!currFree) + { + if (!IsVirtual()) + { + VMA_VALIDATE((VkDeviceSize)alloc->GetAllocHandle() == suballoc.offset + 1); + VMA_VALIDATE(alloc->GetSize() == suballoc.size); + } + sumUsedSize += suballoc.size; + } + else + { + ++nullItem1stCount; + } + + offset = suballoc.offset + suballoc.size + debugMargin; + } + VMA_VALIDATE(nullItem1stCount == m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount); + + if (m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) { - // Allocation out of VmaDeviceMemoryBlock. - BlockAllocation m_BlockAllocation; - // Allocation for an object that has its own private VkDeviceMemory. - DedicatedAllocation m_DedicatedAllocation; - }; + const size_t suballoc2ndCount = suballocations2nd.size(); + size_t nullItem2ndCount = 0; + for (size_t i = suballoc2ndCount; i--; ) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); -#if VMA_STATS_STRING_ENABLED - uint32_t m_CreationFrameIndex; - uint32_t m_BufferImageUsage; // 0 if unknown. -#endif + VmaAllocation const alloc = (VmaAllocation)suballoc.userData; + if (!IsVirtual()) + { + VMA_VALIDATE(currFree == (alloc == VK_NULL_HANDLE)); + } + VMA_VALIDATE(suballoc.offset >= offset); - void FreeUserDataString(VmaAllocator hAllocator); -}; + if (!currFree) + { + if (!IsVirtual()) + { + VMA_VALIDATE((VkDeviceSize)alloc->GetAllocHandle() == suballoc.offset + 1); + VMA_VALIDATE(alloc->GetSize() == suballoc.size); + } + sumUsedSize += suballoc.size; + } + else + { + ++nullItem2ndCount; + } -/* -Represents a region of VmaDeviceMemoryBlock that is either assigned and returned as -allocated memory block or free. -*/ -struct VmaSuballocation -{ - VkDeviceSize offset; - VkDeviceSize size; - VmaAllocation hAllocation; - VmaSuballocationType type; -}; + offset = suballoc.offset + suballoc.size + debugMargin; + } -// Comparator for offsets. -struct VmaSuballocationOffsetLess -{ - bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const - { - return lhs.offset < rhs.offset; - } -}; -struct VmaSuballocationOffsetGreater -{ - bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const - { - return lhs.offset > rhs.offset; + VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); } -}; -typedef VmaList< VmaSuballocation, VmaStlAllocator > VmaSuballocationList; + VMA_VALIDATE(offset <= GetSize()); + VMA_VALIDATE(m_SumFreeSize == GetSize() - sumUsedSize); -// Cost of one additional allocation lost, as equivalent in bytes. -static const VkDeviceSize VMA_LOST_ALLOCATION_COST = 1048576; + return true; +} -enum class VmaAllocationRequestType +size_t VmaBlockMetadata_Linear::GetAllocationCount() const { - Normal, - // Used by "Linear" algorithm. - UpperAddress, - EndOf1st, - EndOf2nd, -}; - -/* -Parameters of planned allocation inside a VmaDeviceMemoryBlock. + return AccessSuballocations1st().size() - m_1stNullItemsBeginCount - m_1stNullItemsMiddleCount + + AccessSuballocations2nd().size() - m_2ndNullItemsCount; +} -If canMakeOtherLost was false: -- item points to a FREE suballocation. -- itemsToMakeLostCount is 0. +size_t VmaBlockMetadata_Linear::GetFreeRegionsCount() const +{ + // Function only used for defragmentation, which is disabled for this algorithm + VMA_ASSERT(0); + return SIZE_MAX; +} -If canMakeOtherLost was true: -- item points to first of sequence of suballocations, which are either FREE, - or point to VmaAllocations that can become lost. -- itemsToMakeLostCount is the number of VmaAllocations that need to be made lost for - the requested allocation to succeed. -*/ -struct VmaAllocationRequest +void VmaBlockMetadata_Linear::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const { - VkDeviceSize offset; - VkDeviceSize sumFreeSize; // Sum size of free items that overlap with proposed allocation. - VkDeviceSize sumItemSize; // Sum size of items to make lost that overlap with proposed allocation. - VmaSuballocationList::iterator item; - size_t itemsToMakeLostCount; - void* customData; - VmaAllocationRequestType type; + const VkDeviceSize size = GetSize(); + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); - VkDeviceSize CalcCost() const - { - return sumItemSize + itemsToMakeLostCount * VMA_LOST_ALLOCATION_COST; - } -}; + inoutStats.statistics.blockCount++; + inoutStats.statistics.blockBytes += size; -/* -Data structure used for bookkeeping of allocations and unused ranges of memory -in a single VkDeviceMemory block. -*/ -class VmaBlockMetadata -{ -public: - VmaBlockMetadata(VmaAllocator hAllocator); - virtual ~VmaBlockMetadata() { } - virtual void Init(VkDeviceSize size) { m_Size = size; } + VkDeviceSize lastOffset = 0; - // Validates all data structures inside this object. If not valid, returns false. - virtual bool Validate() const = 0; - VkDeviceSize GetSize() const { return m_Size; } - virtual size_t GetAllocationCount() const = 0; - virtual VkDeviceSize GetSumFreeSize() const = 0; - virtual VkDeviceSize GetUnusedRangeSizeMax() const = 0; - // Returns true if this block is empty - contains only single free suballocation. - virtual bool IsEmpty() const = 0; + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while (lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + ++nextAlloc2ndIndex; + } - virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const = 0; - // Shouldn't modify blockCount. - virtual void AddPoolStats(VmaPoolStats& inoutStats) const = 0; + // Found non-null allocation. + if (nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; -#if VMA_STATS_STRING_ENABLED - virtual void PrintDetailedMap(class VmaJsonWriter& json) const = 0; -#endif + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } - // Tries to find a place for suballocation with given parameters inside this block. - // If succeeded, fills pAllocationRequest and returns true. - // If failed, returns false. - virtual bool CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - // Always one of VMA_ALLOCATION_CREATE_STRATEGY_* or VMA_ALLOCATION_INTERNAL_STRATEGY_* flags. - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) = 0; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + VmaAddDetailedStatisticsAllocation(inoutStats, suballoc.size); - virtual bool MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest) = 0; + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + if (lastOffset < freeSpace2ndTo1stEnd) + { + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } - virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) = 0; + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } - virtual VkResult CheckCorruption(const void* pBlockData) = 0; + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while (lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].userData == VMA_NULL) + { + ++nextAlloc1stIndex; + } - // Makes actual allocation based on request. Request must already be checked and valid. - virtual void Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation) = 0; + // Found non-null allocation. + if (nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - // Frees suballocation assigned to given memory region. - virtual void Free(const VmaAllocation allocation) = 0; - virtual void FreeAtOffset(VkDeviceSize offset) = 0; + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } -protected: - const VkAllocationCallbacks* GetAllocationCallbacks() const { return m_pAllocationCallbacks; } + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + VmaAddDetailedStatisticsAllocation(inoutStats, suballoc.size); -#if VMA_STATS_STRING_ENABLED - void PrintDetailedMap_Begin(class VmaJsonWriter& json, - VkDeviceSize unusedBytes, - size_t allocationCount, - size_t unusedRangeCount) const; - void PrintDetailedMap_Allocation(class VmaJsonWriter& json, - VkDeviceSize offset, - VmaAllocation hAllocation) const; - void PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, - VkDeviceSize offset, - VkDeviceSize size) const; - void PrintDetailedMap_End(class VmaJsonWriter& json) const; -#endif + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + if (lastOffset < freeSpace1stTo2ndEnd) + { + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } -private: - VkDeviceSize m_Size; - const VkAllocationCallbacks* m_pAllocationCallbacks; -}; + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } -#define VMA_VALIDATE(cond) do { if(!(cond)) { \ - VMA_ASSERT(0 && "Validation failed: " #cond); \ - return false; \ - } } while(false) + if (m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while (lastOffset < size) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + --nextAlloc2ndIndex; + } -class VmaBlockMetadata_Generic : public VmaBlockMetadata -{ - VMA_CLASS_NO_COPY(VmaBlockMetadata_Generic) -public: - VmaBlockMetadata_Generic(VmaAllocator hAllocator); - virtual ~VmaBlockMetadata_Generic(); - virtual void Init(VkDeviceSize size); + // Found non-null allocation. + if (nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - virtual bool Validate() const; - virtual size_t GetAllocationCount() const { return m_Suballocations.size() - m_FreeCount; } - virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize; } - virtual VkDeviceSize GetUnusedRangeSizeMax() const; - virtual bool IsEmpty() const; + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } - virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; - virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + VmaAddDetailedStatisticsAllocation(inoutStats, suballoc.size); -#if VMA_STATS_STRING_ENABLED - virtual void PrintDetailedMap(class VmaJsonWriter& json) const; -#endif + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to size. + if (lastOffset < size) + { + const VkDeviceSize unusedRangeSize = size - lastOffset; + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusedRangeSize); + } - virtual bool CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest); + // End of loop. + lastOffset = size; + } + } + } +} - virtual bool MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest); +void VmaBlockMetadata_Linear::AddStatistics(VmaStatistics& inoutStats) const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const VkDeviceSize size = GetSize(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); - virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + inoutStats.blockCount++; + inoutStats.blockBytes += size; + inoutStats.allocationBytes += size - m_SumFreeSize; - virtual VkResult CheckCorruption(const void* pBlockData); + VkDeviceSize lastOffset = 0; - virtual void Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation); + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = m_1stNullItemsBeginCount; + while (lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + ++nextAlloc2ndIndex; + } - virtual void Free(const VmaAllocation allocation); - virtual void FreeAtOffset(VkDeviceSize offset); + // Found non-null allocation. + if (nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - //////////////////////////////////////////////////////////////////////////////// - // For defragmentation + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + } - bool IsBufferImageGranularityConflictPossible( - VkDeviceSize bufferImageGranularity, - VmaSuballocationType& inOutPrevSuballocType) const; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; -private: - friend class VmaDefragmentationAlgorithm_Generic; - friend class VmaDefragmentationAlgorithm_Fast; + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + } - uint32_t m_FreeCount; - VkDeviceSize m_SumFreeSize; - VmaSuballocationList m_Suballocations; - // Suballocations that are free and have size greater than certain threshold. - // Sorted by size, ascending. - VmaVector< VmaSuballocationList::iterator, VmaStlAllocator< VmaSuballocationList::iterator > > m_FreeSuballocationsBySize; + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } - bool ValidateFreeSuballocationList() const; + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while (lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].userData == VMA_NULL) + { + ++nextAlloc1stIndex; + } - // Checks if requested suballocation with given parameters can be placed in given pFreeSuballocItem. - // If yes, fills pOffset and returns true. If no, returns false. - bool CheckAllocation( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - VmaSuballocationList::const_iterator suballocItem, - bool canMakeOtherLost, - VkDeviceSize* pOffset, - size_t* itemsToMakeLostCount, - VkDeviceSize* pSumFreeSize, - VkDeviceSize* pSumItemSize) const; - // Given free suballocation, it merges it with following one, which must also be free. - void MergeFreeWithNext(VmaSuballocationList::iterator item); - // Releases given suballocation, making it free. - // Merges it with adjacent free suballocations if applicable. - // Returns iterator to new free suballocation at this place. - VmaSuballocationList::iterator FreeSuballocation(VmaSuballocationList::iterator suballocItem); - // Given free suballocation, it inserts it into sorted list of - // m_FreeSuballocationsBySize if it's suitable. - void RegisterFreeSuballocation(VmaSuballocationList::iterator item); - // Given free suballocation, it removes it from sorted list of - // m_FreeSuballocationsBySize if it's suitable. - void UnregisterFreeSuballocation(VmaSuballocationList::iterator item); -}; + // Found non-null allocation. + if (nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; -/* -Allocations and their references in internal data structure look like this: + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + } -if(m_2ndVectorMode == SECOND_VECTOR_EMPTY): + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; - 0 +-------+ - | | - | | - | | - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount] - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount + 1] - +-------+ - | ... | - +-------+ - | Alloc | 1st[1st.size() - 1] - +-------+ - | | - | | - | | -GetSize() +-------+ + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if (lastOffset < freeSpace1stTo2ndEnd) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + } -if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER): + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } - 0 +-------+ - | Alloc | 2nd[0] - +-------+ - | Alloc | 2nd[1] - +-------+ - | ... | - +-------+ - | Alloc | 2nd[2nd.size() - 1] - +-------+ - | | - | | - | | - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount] - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount + 1] - +-------+ - | ... | - +-------+ - | Alloc | 1st[1st.size() - 1] - +-------+ - | | -GetSize() +-------+ + if (m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while (lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + --nextAlloc2ndIndex; + } -if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK): + // Found non-null allocation. + if (nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - 0 +-------+ - | | - | | - | | - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount] - +-------+ - | Alloc | 1st[m_1stNullItemsBeginCount + 1] - +-------+ - | ... | - +-------+ - | Alloc | 1st[1st.size() - 1] - +-------+ - | | - | | - | | - +-------+ - | Alloc | 2nd[2nd.size() - 1] - +-------+ - | ... | - +-------+ - | Alloc | 2nd[1] - +-------+ - | Alloc | 2nd[0] -GetSize() +-------+ + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + } -*/ -class VmaBlockMetadata_Linear : public VmaBlockMetadata -{ - VMA_CLASS_NO_COPY(VmaBlockMetadata_Linear) -public: - VmaBlockMetadata_Linear(VmaAllocator hAllocator); - virtual ~VmaBlockMetadata_Linear(); - virtual void Init(VkDeviceSize size); + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; - virtual bool Validate() const; - virtual size_t GetAllocationCount() const; - virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize; } - virtual VkDeviceSize GetUnusedRangeSizeMax() const; - virtual bool IsEmpty() const { return GetAllocationCount() == 0; } + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < size) + { + // There is free space from lastOffset to size. + const VkDeviceSize unusedRangeSize = size - lastOffset; + } - virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; - virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + // End of loop. + lastOffset = size; + } + } + } +} #if VMA_STATS_STRING_ENABLED - virtual void PrintDetailedMap(class VmaJsonWriter& json) const; -#endif - - virtual bool CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest); - - virtual bool MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest); - - virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); - - virtual VkResult CheckCorruption(const void* pBlockData); +void VmaBlockMetadata_Linear::PrintDetailedMap(class VmaJsonWriter& json) const +{ + const VkDeviceSize size = GetSize(); + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); - virtual void Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation); + // FIRST PASS - virtual void Free(const VmaAllocation allocation); - virtual void FreeAtOffset(VkDeviceSize offset); + size_t unusedRangeCount = 0; + VkDeviceSize usedBytes = 0; -private: - /* - There are two suballocation vectors, used in ping-pong way. - The one with index m_1stVectorIndex is called 1st. - The one with index (m_1stVectorIndex ^ 1) is called 2nd. - 2nd can be non-empty only when 1st is not empty. - When 2nd is not empty, m_2ndVectorMode indicates its mode of operation. - */ - typedef VmaVector< VmaSuballocation, VmaStlAllocator > SuballocationVectorType; + VkDeviceSize lastOffset = 0; - enum SECOND_VECTOR_MODE + size_t alloc2ndCount = 0; + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) { - SECOND_VECTOR_EMPTY, - /* - Suballocations in 2nd vector are created later than the ones in 1st, but they - all have smaller offset. - */ - SECOND_VECTOR_RING_BUFFER, - /* - Suballocations in 2nd vector are upper side of double stack. - They all have offsets higher than those in 1st vector. - Top of this stack means smaller offsets, but higher indices in this vector. - */ - SECOND_VECTOR_DOUBLE_STACK, - }; + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while (lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + ++nextAlloc2ndIndex; + } - VkDeviceSize m_SumFreeSize; - SuballocationVectorType m_Suballocations0, m_Suballocations1; - uint32_t m_1stVectorIndex; - SECOND_VECTOR_MODE m_2ndVectorMode; + // Found non-null allocation. + if (nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - SuballocationVectorType& AccessSuballocations1st() { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } - SuballocationVectorType& AccessSuballocations2nd() { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } - const SuballocationVectorType& AccessSuballocations1st() const { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } - const SuballocationVectorType& AccessSuballocations2nd() const { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } - // Number of items in 1st vector with hAllocation = null at the beginning. - size_t m_1stNullItemsBeginCount; - // Number of other items in 1st vector with hAllocation = null somewhere in the middle. - size_t m_1stNullItemsMiddleCount; - // Number of items in 2nd vector with hAllocation = null. - size_t m_2ndNullItemsCount; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc2ndCount; + usedBytes += suballoc.size; - bool ShouldCompact1st() const; - void CleanupAfterFree(); + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + ++unusedRangeCount; + } - bool CreateAllocationRequest_LowerAddress( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest); - bool CreateAllocationRequest_UpperAddress( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest); -}; + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } -/* -- GetSize() is the original size of allocated memory block. -- m_UsableSize is this size aligned down to a power of two. - All allocations and calculations happen relative to m_UsableSize. -- GetUnusableSize() is the difference between them. - It is repoted as separate, unused range, not available for allocations. + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + size_t alloc1stCount = 0; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while (lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].userData == VMA_NULL) + { + ++nextAlloc1stIndex; + } -Node at level 0 has size = m_UsableSize. -Each next level contains nodes with size 2 times smaller than current level. -m_LevelCount is the maximum number of levels to use in the current object. -*/ -class VmaBlockMetadata_Buddy : public VmaBlockMetadata -{ - VMA_CLASS_NO_COPY(VmaBlockMetadata_Buddy) -public: - VmaBlockMetadata_Buddy(VmaAllocator hAllocator); - virtual ~VmaBlockMetadata_Buddy(); - virtual void Init(VkDeviceSize size); + // Found non-null allocation. + if (nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - virtual bool Validate() const; - virtual size_t GetAllocationCount() const { return m_AllocationCount; } - virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize + GetUnusableSize(); } - virtual VkDeviceSize GetUnusedRangeSizeMax() const; - virtual bool IsEmpty() const { return m_Root->type == Node::TYPE_FREE; } + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } - virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; - virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc1stCount; + usedBytes += suballoc.size; -#if VMA_STATS_STRING_ENABLED - virtual void PrintDetailedMap(class VmaJsonWriter& json) const; -#endif + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if (lastOffset < size) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + ++unusedRangeCount; + } - virtual bool CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest); + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } - virtual bool MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest); + if (m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while (lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + --nextAlloc2ndIndex; + } - virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + // Found non-null allocation. + if (nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - virtual VkResult CheckCorruption(const void* pBlockData) { return VK_ERROR_FEATURE_NOT_PRESENT; } + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } - virtual void Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation); + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc2ndCount; + usedBytes += suballoc.size; - virtual void Free(const VmaAllocation allocation) { FreeAtOffset(allocation, allocation->GetOffset()); } - virtual void FreeAtOffset(VkDeviceSize offset) { FreeAtOffset(VMA_NULL, offset); } + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < size) + { + // There is free space from lastOffset to size. + ++unusedRangeCount; + } -private: - static const VkDeviceSize MIN_NODE_SIZE = 32; - static const size_t MAX_LEVELS = 30; + // End of loop. + lastOffset = size; + } + } + } - struct ValidationContext - { - size_t calculatedAllocationCount; - size_t calculatedFreeCount; - VkDeviceSize calculatedSumFreeSize; + const VkDeviceSize unusedBytes = size - usedBytes; + PrintDetailedMap_Begin(json, unusedBytes, alloc1stCount + alloc2ndCount, unusedRangeCount); - ValidationContext() : - calculatedAllocationCount(0), - calculatedFreeCount(0), - calculatedSumFreeSize(0) { } - }; + // SECOND PASS + lastOffset = 0; - struct Node + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) { - VkDeviceSize offset; - enum TYPE - { - TYPE_FREE, - TYPE_ALLOCATION, - TYPE_SPLIT, - TYPE_COUNT - } type; - Node* parent; - Node* buddy; - - union + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while (lastOffset < freeSpace2ndTo1stEnd) { - struct - { - Node* prev; - Node* next; - } free; - struct + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) { - VmaAllocation alloc; - } allocation; - struct + ++nextAlloc2ndIndex; + } + + // Found non-null allocation. + if (nextAlloc2ndIndex < suballoc2ndCount) { - Node* leftChild; - } split; - }; - }; + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - // Size of the memory block aligned down to a power of two. - VkDeviceSize m_UsableSize; - uint32_t m_LevelCount; + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } - Node* m_Root; - struct { - Node* front; - Node* back; - } m_FreeList[MAX_LEVELS]; - // Number of nodes in the tree with type == TYPE_ALLOCATION. - size_t m_AllocationCount; - // Number of nodes in the tree with type == TYPE_FREE. - size_t m_FreeCount; - // This includes space wasted due to internal fragmentation. Doesn't include unusable size. - VkDeviceSize m_SumFreeSize; + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.size, suballoc.userData); - VkDeviceSize GetUnusableSize() const { return GetSize() - m_UsableSize; } - void DeleteNode(Node* node); - bool ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const; - uint32_t AllocSizeToLevel(VkDeviceSize allocSize) const; - inline VkDeviceSize LevelToNodeSize(uint32_t level) const { return m_UsableSize >> level; } - // Alloc passed just for validation. Can be null. - void FreeAtOffset(VmaAllocation alloc, VkDeviceSize offset); - void CalcAllocationStatInfoNode(VmaStatInfo& outInfo, const Node* node, VkDeviceSize levelNodeSize) const; - // Adds node to the front of FreeList at given level. - // node->type must be FREE. - // node->free.prev, next can be undefined. - void AddToFreeListFront(uint32_t level, Node* node); - // Removes node from FreeList at given level. - // node->type must be FREE. - // node->free.prev, next stay untouched. - void RemoveFromFreeList(uint32_t level, Node* node); + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } -#if VMA_STATS_STRING_ENABLED - void PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const; -#endif -}; + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } -/* -Represents a single block of device memory (`VkDeviceMemory`) with all the -data about its regions (aka suballocations, #VmaAllocation), assigned and free. + nextAlloc1stIndex = m_1stNullItemsBeginCount; + while (lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while (nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].userData == VMA_NULL) + { + ++nextAlloc1stIndex; + } -Thread-safety: This class must be externally synchronized. -*/ -class VmaDeviceMemoryBlock -{ - VMA_CLASS_NO_COPY(VmaDeviceMemoryBlock) -public: - VmaBlockMetadata* m_pMetadata; + // Found non-null allocation. + if (nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - VmaDeviceMemoryBlock(VmaAllocator hAllocator); + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } - ~VmaDeviceMemoryBlock() - { - VMA_ASSERT(m_MapCount == 0 && "VkDeviceMemory block is being destroyed while it is still mapped."); - VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); - } + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.size, suballoc.userData); - // Always call after construction. - void Init( - VmaAllocator hAllocator, - VmaPool hParentPool, - uint32_t newMemoryTypeIndex, - VkDeviceMemory newMemory, - VkDeviceSize newSize, - uint32_t id, - uint32_t algorithm); - // Always call before destruction. - void Destroy(VmaAllocator allocator); + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if (lastOffset < freeSpace1stTo2ndEnd) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } - VmaPool GetParentPool() const { return m_hParentPool; } - VkDeviceMemory GetDeviceMemory() const { return m_hMemory; } - uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } - uint32_t GetId() const { return m_Id; } - void* GetMappedData() const { return m_pMappedData; } + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } - // Validates all data structures inside this object. If not valid, returns false. - bool Validate() const; + if (m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while (lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while (nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].userData == VMA_NULL) + { + --nextAlloc2ndIndex; + } - VkResult CheckCorruption(VmaAllocator hAllocator); + // Found non-null allocation. + if (nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - // ppData can be null. - VkResult Map(VmaAllocator hAllocator, uint32_t count, void** ppData); - void Unmap(VmaAllocator hAllocator, uint32_t count); + // 1. Process free space before this allocation. + if (lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } - VkResult WriteMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); - VkResult ValidateMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.size, suballoc.userData); - VkResult BindBufferMemory( - const VmaAllocator hAllocator, - const VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkBuffer hBuffer, - const void* pNext); - VkResult BindImageMemory( - const VmaAllocator hAllocator, - const VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkImage hImage, - const void* pNext); + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if (lastOffset < size) + { + // There is free space from lastOffset to size. + const VkDeviceSize unusedRangeSize = size - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } -private: - VmaPool m_hParentPool; // VK_NULL_HANDLE if not belongs to custom pool. - uint32_t m_MemoryTypeIndex; - uint32_t m_Id; - VkDeviceMemory m_hMemory; + // End of loop. + lastOffset = size; + } + } + } - /* - Protects access to m_hMemory so it's not used by multiple threads simultaneously, e.g. vkMapMemory, vkBindBufferMemory. - Also protects m_MapCount, m_pMappedData. - Allocations, deallocations, any change in m_pMetadata is protected by parent's VmaBlockVector::m_Mutex. - */ - VMA_MUTEX m_Mutex; - uint32_t m_MapCount; - void* m_pMappedData; -}; + PrintDetailedMap_End(json); +} +#endif // VMA_STATS_STRING_ENABLED -struct VmaPointerLess +bool VmaBlockMetadata_Linear::CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) { - bool operator()(const void* lhs, const void* rhs) const - { - return lhs < rhs; - } -}; + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(pAllocationRequest != VMA_NULL); + VMA_HEAVY_ASSERT(Validate()); + pAllocationRequest->size = allocSize; + return upperAddress ? + CreateAllocationRequest_UpperAddress( + allocSize, allocAlignment, allocType, strategy, pAllocationRequest) : + CreateAllocationRequest_LowerAddress( + allocSize, allocAlignment, allocType, strategy, pAllocationRequest); +} -struct VmaDefragmentationMove +VkResult VmaBlockMetadata_Linear::CheckCorruption(const void* pBlockData) { - size_t srcBlockIndex; - size_t dstBlockIndex; - VkDeviceSize srcOffset; - VkDeviceSize dstOffset; - VkDeviceSize size; - VmaAllocation hAllocation; - VmaDeviceMemoryBlock* pSrcBlock; - VmaDeviceMemoryBlock* pDstBlock; -}; + VMA_ASSERT(!IsVirtual()); + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + for (size_t i = m_1stNullItemsBeginCount, count = suballocations1st.size(); i < count; ++i) + { + const VmaSuballocation& suballoc = suballocations1st[i]; + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + if (!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_UNKNOWN_COPY; + } + } + } -class VmaDefragmentationAlgorithm; - -/* -Sequence of VmaDeviceMemoryBlock. Represents memory blocks allocated for a specific -Vulkan memory type. - -Synchronized internally with a mutex. -*/ -struct VmaBlockVector -{ - VMA_CLASS_NO_COPY(VmaBlockVector) -public: - VmaBlockVector( - VmaAllocator hAllocator, - VmaPool hParentPool, - uint32_t memoryTypeIndex, - VkDeviceSize preferredBlockSize, - size_t minBlockCount, - size_t maxBlockCount, - VkDeviceSize bufferImageGranularity, - uint32_t frameInUseCount, - bool explicitBlockSize, - uint32_t algorithm, - float priority); - ~VmaBlockVector(); - - VkResult CreateMinBlocks(); - - VmaAllocator GetAllocator() const { return m_hAllocator; } - VmaPool GetParentPool() const { return m_hParentPool; } - bool IsCustomPool() const { return m_hParentPool != VMA_NULL; } - uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } - VkDeviceSize GetPreferredBlockSize() const { return m_PreferredBlockSize; } - VkDeviceSize GetBufferImageGranularity() const { return m_BufferImageGranularity; } - uint32_t GetFrameInUseCount() const { return m_FrameInUseCount; } - uint32_t GetAlgorithm() const { return m_Algorithm; } - - void GetPoolStats(VmaPoolStats* pStats); - - bool IsEmpty(); - bool IsCorruptionDetectionEnabled() const; - - VkResult Allocate( - uint32_t currentFrameIndex, - VkDeviceSize size, - VkDeviceSize alignment, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations); - - void Free(const VmaAllocation hAllocation); - - // Adds statistics of this BlockVector to pStats. - void AddStats(VmaStats* pStats); - -#if VMA_STATS_STRING_ENABLED - void PrintDetailedMap(class VmaJsonWriter& json); -#endif - - void MakePoolAllocationsLost( - uint32_t currentFrameIndex, - size_t* pLostAllocationCount); - VkResult CheckCorruption(); - - // Saves results in pCtx->res. - void Defragment( - class VmaBlockVectorDefragmentationContext* pCtx, - VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags, - VkDeviceSize& maxCpuBytesToMove, uint32_t& maxCpuAllocationsToMove, - VkDeviceSize& maxGpuBytesToMove, uint32_t& maxGpuAllocationsToMove, - VkCommandBuffer commandBuffer); - void DefragmentationEnd( - class VmaBlockVectorDefragmentationContext* pCtx, - uint32_t flags, - VmaDefragmentationStats* pStats); - - uint32_t ProcessDefragmentations( - class VmaBlockVectorDefragmentationContext *pCtx, - VmaDefragmentationPassMoveInfo* pMove, uint32_t maxMoves); - - void CommitDefragmentations( - class VmaBlockVectorDefragmentationContext *pCtx, - VmaDefragmentationStats* pStats); - - //////////////////////////////////////////////////////////////////////////////// - // To be used only while the m_Mutex is locked. Used during defragmentation. - - size_t GetBlockCount() const { return m_Blocks.size(); } - VmaDeviceMemoryBlock* GetBlock(size_t index) const { return m_Blocks[index]; } - size_t CalcAllocationCount() const; - bool IsBufferImageGranularityConflictPossible() const; - -private: - friend class VmaDefragmentationAlgorithm_Generic; - - const VmaAllocator m_hAllocator; - const VmaPool m_hParentPool; - const uint32_t m_MemoryTypeIndex; - const VkDeviceSize m_PreferredBlockSize; - const size_t m_MinBlockCount; - const size_t m_MaxBlockCount; - const VkDeviceSize m_BufferImageGranularity; - const uint32_t m_FrameInUseCount; - const bool m_ExplicitBlockSize; - const uint32_t m_Algorithm; - const float m_Priority; - VMA_RW_MUTEX m_Mutex; - - /* There can be at most one allocation that is completely empty (except when minBlockCount > 0) - - a hysteresis to avoid pessimistic case of alternating creation and destruction of a VkDeviceMemory. */ - bool m_HasEmptyBlock; - // Incrementally sorted by sumFreeSize, ascending. - VmaVector< VmaDeviceMemoryBlock*, VmaStlAllocator > m_Blocks; - uint32_t m_NextBlockId; - - VkDeviceSize CalcMaxBlockSize() const; - - // Finds and removes given block from vector. - void Remove(VmaDeviceMemoryBlock* pBlock); - - // Performs single step in sorting m_Blocks. They may not be fully sorted - // after this call. - void IncrementallySortBlocks(); - - VkResult AllocatePage( - uint32_t currentFrameIndex, - VkDeviceSize size, - VkDeviceSize alignment, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - VmaAllocation* pAllocation); - - // To be used only without CAN_MAKE_OTHER_LOST flag. - VkResult AllocateFromBlock( - VmaDeviceMemoryBlock* pBlock, - uint32_t currentFrameIndex, - VkDeviceSize size, - VkDeviceSize alignment, - VmaAllocationCreateFlags allocFlags, - void* pUserData, - VmaSuballocationType suballocType, - uint32_t strategy, - VmaAllocation* pAllocation); - - VkResult CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex); - - // Saves result to pCtx->res. - void ApplyDefragmentationMovesCpu( - class VmaBlockVectorDefragmentationContext* pDefragCtx, - const VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves); - // Saves result to pCtx->res. - void ApplyDefragmentationMovesGpu( - class VmaBlockVectorDefragmentationContext* pDefragCtx, - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkCommandBuffer commandBuffer); - - /* - Used during defragmentation. pDefragmentationStats is optional. It's in/out - - updated with new data. - */ - void FreeEmptyBlocks(VmaDefragmentationStats* pDefragmentationStats); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + for (size_t i = 0, count = suballocations2nd.size(); i < count; ++i) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + if (suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + if (!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_UNKNOWN_COPY; + } + } + } - void UpdateHasEmptyBlock(); -}; + return VK_SUCCESS; +} -struct VmaPool_T +void VmaBlockMetadata_Linear::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) { - VMA_CLASS_NO_COPY(VmaPool_T) -public: - VmaBlockVector m_BlockVector; - - VmaPool_T( - VmaAllocator hAllocator, - const VmaPoolCreateInfo& createInfo, - VkDeviceSize preferredBlockSize); - ~VmaPool_T(); - - uint32_t GetId() const { return m_Id; } - void SetId(uint32_t id) { VMA_ASSERT(m_Id == 0); m_Id = id; } - - const char* GetName() const { return m_Name; } - void SetName(const char* pName); - -#if VMA_STATS_STRING_ENABLED - //void PrintDetailedMap(class VmaStringBuilder& sb); -#endif - -private: - uint32_t m_Id; - char* m_Name; -}; - -/* -Performs defragmentation: + const VkDeviceSize offset = (VkDeviceSize)request.allocHandle - 1; + const VmaSuballocation newSuballoc = { offset, request.size, userData, type }; -- Updates `pBlockVector->m_pMetadata`. -- Updates allocations by calling ChangeBlockAllocation() or ChangeOffset(). -- Does not move actual data, only returns requested moves as `moves`. -*/ -class VmaDefragmentationAlgorithm -{ - VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm) -public: - VmaDefragmentationAlgorithm( - VmaAllocator hAllocator, - VmaBlockVector* pBlockVector, - uint32_t currentFrameIndex) : - m_hAllocator(hAllocator), - m_pBlockVector(pBlockVector), - m_CurrentFrameIndex(currentFrameIndex) + switch (request.type) { - } - virtual ~VmaDefragmentationAlgorithm() + case VmaAllocationRequestType::UpperAddress: { + VMA_ASSERT(m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER && + "CRITICAL ERROR: Trying to use linear allocator as double stack while it was already used as ring buffer."); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + suballocations2nd.push_back(newSuballoc); + m_2ndVectorMode = SECOND_VECTOR_DOUBLE_STACK; } + break; + case VmaAllocationRequestType::EndOf1st: + { + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) = 0; - virtual void AddAll() = 0; - - virtual VkResult Defragment( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - VmaDefragmentationFlags flags) = 0; - - virtual VkDeviceSize GetBytesMoved() const = 0; - virtual uint32_t GetAllocationsMoved() const = 0; - -protected: - VmaAllocator const m_hAllocator; - VmaBlockVector* const m_pBlockVector; - const uint32_t m_CurrentFrameIndex; + VMA_ASSERT(suballocations1st.empty() || + offset >= suballocations1st.back().offset + suballocations1st.back().size); + // Check if it fits before the end of the block. + VMA_ASSERT(offset + request.size <= GetSize()); - struct AllocationInfo + suballocations1st.push_back(newSuballoc); + } + break; + case VmaAllocationRequestType::EndOf2nd: { - VmaAllocation m_hAllocation; - VkBool32* m_pChanged; + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + // New allocation at the end of 2-part ring buffer, so before first allocation from 1st vector. + VMA_ASSERT(!suballocations1st.empty() && + offset + request.size <= suballocations1st[m_1stNullItemsBeginCount].offset); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - AllocationInfo() : - m_hAllocation(VK_NULL_HANDLE), - m_pChanged(VMA_NULL) - { - } - AllocationInfo(VmaAllocation hAlloc, VkBool32* pChanged) : - m_hAllocation(hAlloc), - m_pChanged(pChanged) + switch (m_2ndVectorMode) { + case SECOND_VECTOR_EMPTY: + // First allocation from second part ring buffer. + VMA_ASSERT(suballocations2nd.empty()); + m_2ndVectorMode = SECOND_VECTOR_RING_BUFFER; + break; + case SECOND_VECTOR_RING_BUFFER: + // 2-part ring buffer is already started. + VMA_ASSERT(!suballocations2nd.empty()); + break; + case SECOND_VECTOR_DOUBLE_STACK: + VMA_ASSERT(0 && "CRITICAL ERROR: Trying to use linear allocator as ring buffer while it was already used as double stack."); + break; + default: + VMA_ASSERT(0); } - }; -}; - -class VmaDefragmentationAlgorithm_Generic : public VmaDefragmentationAlgorithm -{ - VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm_Generic) -public: - VmaDefragmentationAlgorithm_Generic( - VmaAllocator hAllocator, - VmaBlockVector* pBlockVector, - uint32_t currentFrameIndex, - bool overlappingMoveSupported); - virtual ~VmaDefragmentationAlgorithm_Generic(); - virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged); - virtual void AddAll() { m_AllAllocations = true; } - - virtual VkResult Defragment( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - VmaDefragmentationFlags flags); - - virtual VkDeviceSize GetBytesMoved() const { return m_BytesMoved; } - virtual uint32_t GetAllocationsMoved() const { return m_AllocationsMoved; } + suballocations2nd.push_back(newSuballoc); + } + break; + default: + VMA_ASSERT(0 && "CRITICAL INTERNAL ERROR."); + } -private: - uint32_t m_AllocationCount; - bool m_AllAllocations; + m_SumFreeSize -= newSuballoc.size; +} - VkDeviceSize m_BytesMoved; - uint32_t m_AllocationsMoved; +void VmaBlockMetadata_Linear::Free(VmaAllocHandle allocHandle) +{ + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + VkDeviceSize offset = (VkDeviceSize)allocHandle - 1; - struct AllocationInfoSizeGreater + if (!suballocations1st.empty()) { - bool operator()(const AllocationInfo& lhs, const AllocationInfo& rhs) const + // First allocation: Mark it as next empty at the beginning. + VmaSuballocation& firstSuballoc = suballocations1st[m_1stNullItemsBeginCount]; + if (firstSuballoc.offset == offset) { - return lhs.m_hAllocation->GetSize() > rhs.m_hAllocation->GetSize(); + firstSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + firstSuballoc.userData = VMA_NULL; + m_SumFreeSize += firstSuballoc.size; + ++m_1stNullItemsBeginCount; + CleanupAfterFree(); + return; } - }; + } - struct AllocationInfoOffsetGreater + // Last allocation in 2-part ring buffer or top of upper stack (same logic). + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER || + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) { - bool operator()(const AllocationInfo& lhs, const AllocationInfo& rhs) const + VmaSuballocation& lastSuballoc = suballocations2nd.back(); + if (lastSuballoc.offset == offset) { - return lhs.m_hAllocation->GetOffset() > rhs.m_hAllocation->GetOffset(); + m_SumFreeSize += lastSuballoc.size; + suballocations2nd.pop_back(); + CleanupAfterFree(); + return; } - }; - - struct BlockInfo + } + // Last allocation in 1st vector. + else if (m_2ndVectorMode == SECOND_VECTOR_EMPTY) { - size_t m_OriginalBlockIndex; - VmaDeviceMemoryBlock* m_pBlock; - bool m_HasNonMovableAllocations; - VmaVector< AllocationInfo, VmaStlAllocator > m_Allocations; - - BlockInfo(const VkAllocationCallbacks* pAllocationCallbacks) : - m_OriginalBlockIndex(SIZE_MAX), - m_pBlock(VMA_NULL), - m_HasNonMovableAllocations(true), - m_Allocations(pAllocationCallbacks) - { - } - - void CalcHasNonMovableAllocations() - { - const size_t blockAllocCount = m_pBlock->m_pMetadata->GetAllocationCount(); - const size_t defragmentAllocCount = m_Allocations.size(); - m_HasNonMovableAllocations = blockAllocCount != defragmentAllocCount; - } - - void SortAllocationsBySizeDescending() + VmaSuballocation& lastSuballoc = suballocations1st.back(); + if (lastSuballoc.offset == offset) { - VMA_SORT(m_Allocations.begin(), m_Allocations.end(), AllocationInfoSizeGreater()); + m_SumFreeSize += lastSuballoc.size; + suballocations1st.pop_back(); + CleanupAfterFree(); + return; } + } - void SortAllocationsByOffsetDescending() - { - VMA_SORT(m_Allocations.begin(), m_Allocations.end(), AllocationInfoOffsetGreater()); - } - }; + VmaSuballocation refSuballoc; + refSuballoc.offset = offset; + // Rest of members stays uninitialized intentionally for better performance. - struct BlockPointerLess + // Item from the middle of 1st vector. { - bool operator()(const BlockInfo* pLhsBlockInfo, const VmaDeviceMemoryBlock* pRhsBlock) const - { - return pLhsBlockInfo->m_pBlock < pRhsBlock; - } - bool operator()(const BlockInfo* pLhsBlockInfo, const BlockInfo* pRhsBlockInfo) const + const SuballocationVectorType::iterator it = VmaBinaryFindSorted( + suballocations1st.begin() + m_1stNullItemsBeginCount, + suballocations1st.end(), + refSuballoc, + VmaSuballocationOffsetLess()); + if (it != suballocations1st.end()) { - return pLhsBlockInfo->m_pBlock < pRhsBlockInfo->m_pBlock; + it->type = VMA_SUBALLOCATION_TYPE_FREE; + it->userData = VMA_NULL; + ++m_1stNullItemsMiddleCount; + m_SumFreeSize += it->size; + CleanupAfterFree(); + return; } - }; + } - // 1. Blocks with some non-movable allocations go first. - // 2. Blocks with smaller sumFreeSize go first. - struct BlockInfoCompareMoveDestination + if (m_2ndVectorMode != SECOND_VECTOR_EMPTY) { - bool operator()(const BlockInfo* pLhsBlockInfo, const BlockInfo* pRhsBlockInfo) const + // Item from the middle of 2nd vector. + const SuballocationVectorType::iterator it = m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER ? + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetLess()) : + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetGreater()); + if (it != suballocations2nd.end()) { - if(pLhsBlockInfo->m_HasNonMovableAllocations && !pRhsBlockInfo->m_HasNonMovableAllocations) - { - return true; - } - if(!pLhsBlockInfo->m_HasNonMovableAllocations && pRhsBlockInfo->m_HasNonMovableAllocations) - { - return false; - } - if(pLhsBlockInfo->m_pBlock->m_pMetadata->GetSumFreeSize() < pRhsBlockInfo->m_pBlock->m_pMetadata->GetSumFreeSize()) - { - return true; - } - return false; + it->type = VMA_SUBALLOCATION_TYPE_FREE; + it->userData = VMA_NULL; + ++m_2ndNullItemsCount; + m_SumFreeSize += it->size; + CleanupAfterFree(); + return; } - }; + } - typedef VmaVector< BlockInfo*, VmaStlAllocator > BlockInfoVector; - BlockInfoVector m_Blocks; + VMA_ASSERT(0 && "Allocation to free not found in linear allocator!"); +} - VkResult DefragmentRound( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - bool freeOldAllocations); +void VmaBlockMetadata_Linear::GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) +{ + outInfo.offset = (VkDeviceSize)allocHandle - 1; + VmaSuballocation& suballoc = FindSuballocation(outInfo.offset); + outInfo.size = suballoc.size; + outInfo.pUserData = suballoc.userData; +} - size_t CalcBlocksWithNonMovableCount() const; +void* VmaBlockMetadata_Linear::GetAllocationUserData(VmaAllocHandle allocHandle) const +{ + return FindSuballocation((VkDeviceSize)allocHandle - 1).userData; +} - static bool MoveMakesSense( - size_t dstBlockIndex, VkDeviceSize dstOffset, - size_t srcBlockIndex, VkDeviceSize srcOffset); -}; +VmaAllocHandle VmaBlockMetadata_Linear::GetAllocationListBegin() const +{ + // Function only used for defragmentation, which is disabled for this algorithm + VMA_ASSERT(0); + return VK_NULL_HANDLE; +} -class VmaDefragmentationAlgorithm_Fast : public VmaDefragmentationAlgorithm +VmaAllocHandle VmaBlockMetadata_Linear::GetNextAllocation(VmaAllocHandle prevAlloc) const { - VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm_Fast) -public: - VmaDefragmentationAlgorithm_Fast( - VmaAllocator hAllocator, - VmaBlockVector* pBlockVector, - uint32_t currentFrameIndex, - bool overlappingMoveSupported); - virtual ~VmaDefragmentationAlgorithm_Fast(); + // Function only used for defragmentation, which is disabled for this algorithm + VMA_ASSERT(0); + return VK_NULL_HANDLE; +} - virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) { ++m_AllocationCount; } - virtual void AddAll() { m_AllAllocations = true; } +VkDeviceSize VmaBlockMetadata_Linear::GetNextFreeRegionSize(VmaAllocHandle alloc) const +{ + // Function only used for defragmentation, which is disabled for this algorithm + VMA_ASSERT(0); + return 0; +} + +void VmaBlockMetadata_Linear::Clear() +{ + m_SumFreeSize = GetSize(); + m_Suballocations0.clear(); + m_Suballocations1.clear(); + // Leaving m_1stVectorIndex unchanged - it doesn't matter. + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + m_1stNullItemsBeginCount = 0; + m_1stNullItemsMiddleCount = 0; + m_2ndNullItemsCount = 0; +} + +void VmaBlockMetadata_Linear::SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) +{ + VmaSuballocation& suballoc = FindSuballocation((VkDeviceSize)allocHandle - 1); + suballoc.userData = userData; +} - virtual VkResult Defragment( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - VmaDefragmentationFlags flags); +void VmaBlockMetadata_Linear::DebugLogAllAllocations() const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + for (auto it = suballocations1st.begin() + m_1stNullItemsBeginCount; it != suballocations1st.end(); ++it) + if (it->type != VMA_SUBALLOCATION_TYPE_FREE) + DebugLogAllocation(it->offset, it->size, it->userData); - virtual VkDeviceSize GetBytesMoved() const { return m_BytesMoved; } - virtual uint32_t GetAllocationsMoved() const { return m_AllocationsMoved; } + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + for (auto it = suballocations2nd.begin(); it != suballocations2nd.end(); ++it) + if (it->type != VMA_SUBALLOCATION_TYPE_FREE) + DebugLogAllocation(it->offset, it->size, it->userData); +} -private: - struct BlockInfo +VmaSuballocation& VmaBlockMetadata_Linear::FindSuballocation(VkDeviceSize offset) const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + VmaSuballocation refSuballoc; + refSuballoc.offset = offset; + // Rest of members stays uninitialized intentionally for better performance. + + // Item from the 1st vector. { - size_t origBlockIndex; - }; + SuballocationVectorType::const_iterator it = VmaBinaryFindSorted( + suballocations1st.begin() + m_1stNullItemsBeginCount, + suballocations1st.end(), + refSuballoc, + VmaSuballocationOffsetLess()); + if (it != suballocations1st.end()) + { + return const_cast(*it); + } + } - class FreeSpaceDatabase + if (m_2ndVectorMode != SECOND_VECTOR_EMPTY) { - public: - FreeSpaceDatabase() + // Rest of members stays uninitialized intentionally for better performance. + SuballocationVectorType::const_iterator it = m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER ? + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetLess()) : + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetGreater()); + if (it != suballocations2nd.end()) { - FreeSpace s = {}; - s.blockInfoIndex = SIZE_MAX; - for (const auto i : c10::irange(MAX_COUNT)) { - m_FreeSpaces[i] = s; - } + return const_cast(*it); } + } + + VMA_ASSERT(0 && "Allocation not found in linear allocator!"); + return const_cast(suballocations1st.back()); // Should never occur. +} + +bool VmaBlockMetadata_Linear::ShouldCompact1st() const +{ + const size_t nullItemCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; + const size_t suballocCount = AccessSuballocations1st().size(); + return suballocCount > 32 && nullItemCount * 2 >= (suballocCount - nullItemCount) * 3; +} + +void VmaBlockMetadata_Linear::CleanupAfterFree() +{ + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - void Register(size_t blockInfoIndex, VkDeviceSize offset, VkDeviceSize size) + if (IsEmpty()) + { + suballocations1st.clear(); + suballocations2nd.clear(); + m_1stNullItemsBeginCount = 0; + m_1stNullItemsMiddleCount = 0; + m_2ndNullItemsCount = 0; + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + } + else + { + const size_t suballoc1stCount = suballocations1st.size(); + const size_t nullItem1stCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; + VMA_ASSERT(nullItem1stCount <= suballoc1stCount); + + // Find more null items at the beginning of 1st vector. + while (m_1stNullItemsBeginCount < suballoc1stCount && + suballocations1st[m_1stNullItemsBeginCount].type == VMA_SUBALLOCATION_TYPE_FREE) { - if(size < VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) - { - return; - } + ++m_1stNullItemsBeginCount; + --m_1stNullItemsMiddleCount; + } - // Find first invalid or the smallest structure. - size_t bestIndex = SIZE_MAX; - for (const auto i : c10::irange(MAX_COUNT)) { - // Empty structure. - if(m_FreeSpaces[i].blockInfoIndex == SIZE_MAX) - { - bestIndex = i; - break; - } - if(m_FreeSpaces[i].size < size && - (bestIndex == SIZE_MAX || m_FreeSpaces[bestIndex].size > m_FreeSpaces[i].size)) - { - bestIndex = i; - } - } + // Find more null items at the end of 1st vector. + while (m_1stNullItemsMiddleCount > 0 && + suballocations1st.back().type == VMA_SUBALLOCATION_TYPE_FREE) + { + --m_1stNullItemsMiddleCount; + suballocations1st.pop_back(); + } - if(bestIndex != SIZE_MAX) - { - m_FreeSpaces[bestIndex].blockInfoIndex = blockInfoIndex; - m_FreeSpaces[bestIndex].offset = offset; - m_FreeSpaces[bestIndex].size = size; - } + // Find more null items at the end of 2nd vector. + while (m_2ndNullItemsCount > 0 && + suballocations2nd.back().type == VMA_SUBALLOCATION_TYPE_FREE) + { + --m_2ndNullItemsCount; + suballocations2nd.pop_back(); } - bool Fetch(VkDeviceSize alignment, VkDeviceSize size, - size_t& outBlockInfoIndex, VkDeviceSize& outDstOffset) + // Find more null items at the beginning of 2nd vector. + while (m_2ndNullItemsCount > 0 && + suballocations2nd[0].type == VMA_SUBALLOCATION_TYPE_FREE) { - size_t bestIndex = SIZE_MAX; - VkDeviceSize bestFreeSpaceAfter = 0; - for (const auto i : c10::irange(MAX_COUNT)) { - // Structure is valid. - if(m_FreeSpaces[i].blockInfoIndex != SIZE_MAX) - { - const VkDeviceSize dstOffset = VmaAlignUp(m_FreeSpaces[i].offset, alignment); - // Allocation fits into this structure. - if(dstOffset + size <= m_FreeSpaces[i].offset + m_FreeSpaces[i].size) - { - const VkDeviceSize freeSpaceAfter = (m_FreeSpaces[i].offset + m_FreeSpaces[i].size) - - (dstOffset + size); - if(bestIndex == SIZE_MAX || freeSpaceAfter > bestFreeSpaceAfter) - { - bestIndex = i; - bestFreeSpaceAfter = freeSpaceAfter; - } - } - } - } + --m_2ndNullItemsCount; + VmaVectorRemove(suballocations2nd, 0); + } - if(bestIndex != SIZE_MAX) + if (ShouldCompact1st()) + { + const size_t nonNullItemCount = suballoc1stCount - nullItem1stCount; + size_t srcIndex = m_1stNullItemsBeginCount; + for (size_t dstIndex = 0; dstIndex < nonNullItemCount; ++dstIndex) { - outBlockInfoIndex = m_FreeSpaces[bestIndex].blockInfoIndex; - outDstOffset = VmaAlignUp(m_FreeSpaces[bestIndex].offset, alignment); - - if(bestFreeSpaceAfter >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + while (suballocations1st[srcIndex].type == VMA_SUBALLOCATION_TYPE_FREE) { - // Leave this structure for remaining empty space. - const VkDeviceSize alignmentPlusSize = (outDstOffset - m_FreeSpaces[bestIndex].offset) + size; - m_FreeSpaces[bestIndex].offset += alignmentPlusSize; - m_FreeSpaces[bestIndex].size -= alignmentPlusSize; + ++srcIndex; } - else + if (dstIndex != srcIndex) { - // This structure becomes invalid. - m_FreeSpaces[bestIndex].blockInfoIndex = SIZE_MAX; + suballocations1st[dstIndex] = suballocations1st[srcIndex]; } - - return true; + ++srcIndex; } - - return false; + suballocations1st.resize(nonNullItemCount); + m_1stNullItemsBeginCount = 0; + m_1stNullItemsMiddleCount = 0; } - private: - static const size_t MAX_COUNT = 4; - - struct FreeSpace + // 2nd vector became empty. + if (suballocations2nd.empty()) { - size_t blockInfoIndex; // SIZE_MAX means this structure is invalid. - VkDeviceSize offset; - VkDeviceSize size; - } m_FreeSpaces[MAX_COUNT]; - }; - - const bool m_OverlappingMoveSupported; - - uint32_t m_AllocationCount; - bool m_AllAllocations; + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + } - VkDeviceSize m_BytesMoved; - uint32_t m_AllocationsMoved; + // 1st vector became empty. + if (suballocations1st.size() - m_1stNullItemsBeginCount == 0) + { + suballocations1st.clear(); + m_1stNullItemsBeginCount = 0; - VmaVector< BlockInfo, VmaStlAllocator > m_BlockInfos; + if (!suballocations2nd.empty() && m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + // Swap 1st with 2nd. Now 2nd is empty. + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + m_1stNullItemsMiddleCount = m_2ndNullItemsCount; + while (m_1stNullItemsBeginCount < suballocations2nd.size() && + suballocations2nd[m_1stNullItemsBeginCount].type == VMA_SUBALLOCATION_TYPE_FREE) + { + ++m_1stNullItemsBeginCount; + --m_1stNullItemsMiddleCount; + } + m_2ndNullItemsCount = 0; + m_1stVectorIndex ^= 1; + } + } + } - void PreprocessMetadata(); - void PostprocessMetadata(); - void InsertSuballoc(VmaBlockMetadata_Generic* pMetadata, const VmaSuballocation& suballoc); -}; + VMA_HEAVY_ASSERT(Validate()); +} -struct VmaBlockDefragmentationContext +bool VmaBlockMetadata_Linear::CreateAllocationRequest_LowerAddress( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) { - enum BLOCK_FLAG + const VkDeviceSize blockSize = GetSize(); + const VkDeviceSize debugMargin = GetDebugMargin(); + const VkDeviceSize bufferImageGranularity = GetBufferImageGranularity(); + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + if (m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) { - BLOCK_FLAG_USED = 0x00000001, - }; - uint32_t flags; - VkBuffer hBuffer; -}; + // Try to allocate at the end of 1st vector. -class VmaBlockVectorDefragmentationContext -{ - VMA_CLASS_NO_COPY(VmaBlockVectorDefragmentationContext) -public: - VkResult res; - bool mutexLocked; - VmaVector< VmaBlockDefragmentationContext, VmaStlAllocator > blockContexts; - VmaVector< VmaDefragmentationMove, VmaStlAllocator > defragmentationMoves; - uint32_t defragmentationMovesProcessed; - uint32_t defragmentationMovesCommitted; - bool hasDefragmentationPlan; - - VmaBlockVectorDefragmentationContext( - VmaAllocator hAllocator, - VmaPool hCustomPool, // Optional. - VmaBlockVector* pBlockVector, - uint32_t currFrameIndex); - ~VmaBlockVectorDefragmentationContext(); + VkDeviceSize resultBaseOffset = 0; + if (!suballocations1st.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations1st.back(); + resultBaseOffset = lastSuballoc.offset + lastSuballoc.size + debugMargin; + } - VmaPool GetCustomPool() const { return m_hCustomPool; } - VmaBlockVector* GetBlockVector() const { return m_pBlockVector; } - VmaDefragmentationAlgorithm* GetAlgorithm() const { return m_pAlgorithm; } + // Start from offset equal to beginning of free space. + VkDeviceSize resultOffset = resultBaseOffset; - void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged); - void AddAll() { m_AllAllocations = true; } + // Apply alignment. + resultOffset = VmaAlignUp(resultOffset, allocAlignment); - void Begin(bool overlappingMoveSupported, VmaDefragmentationFlags flags); + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if (bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations1st.empty()) + { + bool bufferImageGranularityConflict = false; + for (size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; + if (VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if (bufferImageGranularityConflict) + { + resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + } + } -private: - const VmaAllocator m_hAllocator; - // Null if not from custom pool. - const VmaPool m_hCustomPool; - // Redundant, for convenience not to fetch from m_hCustomPool->m_BlockVector or m_hAllocator->m_pBlockVectors. - VmaBlockVector* const m_pBlockVector; - const uint32_t m_CurrFrameIndex; - // Owner of this object. - VmaDefragmentationAlgorithm* m_pAlgorithm; - - struct AllocInfo - { - VmaAllocation hAlloc; - VkBool32* pChanged; - }; - // Used between constructor and Begin. - VmaVector< AllocInfo, VmaStlAllocator > m_Allocations; - bool m_AllAllocations; -}; + const VkDeviceSize freeSpaceEnd = m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? + suballocations2nd.back().offset : blockSize; -struct VmaDefragmentationContext_T -{ -private: - VMA_CLASS_NO_COPY(VmaDefragmentationContext_T) -public: - VmaDefragmentationContext_T( - VmaAllocator hAllocator, - uint32_t currFrameIndex, - uint32_t flags, - VmaDefragmentationStats* pStats); - ~VmaDefragmentationContext_T(); + // There is enough free space at the end after alignment. + if (resultOffset + allocSize + debugMargin <= freeSpaceEnd) + { + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if ((allocSize % bufferImageGranularity || resultOffset % bufferImageGranularity) && m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + for (size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) + { + const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; + if (VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } + } + else + { + // Already on previous page. + break; + } + } + } - void AddPools(uint32_t poolCount, const VmaPool* pPools); - void AddAllocations( - uint32_t allocationCount, - const VmaAllocation* pAllocations, - VkBool32* pAllocationsChanged); + // All tests passed: Success. + pAllocationRequest->allocHandle = (VmaAllocHandle)(resultOffset + 1); + // pAllocationRequest->item, customData unused. + pAllocationRequest->type = VmaAllocationRequestType::EndOf1st; + return true; + } + } - /* - Returns: - - `VK_SUCCESS` if succeeded and object can be destroyed immediately. - - `VK_NOT_READY` if succeeded but the object must remain alive until vmaDefragmentationEnd(). - - Negative value if error occured and object can be destroyed immediately. - */ - VkResult Defragment( - VkDeviceSize maxCpuBytesToMove, uint32_t maxCpuAllocationsToMove, - VkDeviceSize maxGpuBytesToMove, uint32_t maxGpuAllocationsToMove, - VkCommandBuffer commandBuffer, VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags); + // Wrap-around to end of 2nd vector. Try to allocate there, watching for the + // beginning of 1st vector as the end of free space. + if (m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + VMA_ASSERT(!suballocations1st.empty()); - VkResult DefragmentPassBegin(VmaDefragmentationPassInfo* pInfo); - VkResult DefragmentPassEnd(); + VkDeviceSize resultBaseOffset = 0; + if (!suballocations2nd.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations2nd.back(); + resultBaseOffset = lastSuballoc.offset + lastSuballoc.size + debugMargin; + } -private: - const VmaAllocator m_hAllocator; - const uint32_t m_CurrFrameIndex; - const uint32_t m_Flags; - VmaDefragmentationStats* const m_pStats; - - VkDeviceSize m_MaxCpuBytesToMove; - uint32_t m_MaxCpuAllocationsToMove; - VkDeviceSize m_MaxGpuBytesToMove; - uint32_t m_MaxGpuAllocationsToMove; - - // Owner of these objects. - VmaBlockVectorDefragmentationContext* m_DefaultPoolContexts[VK_MAX_MEMORY_TYPES]; - // Owner of these objects. - VmaVector< VmaBlockVectorDefragmentationContext*, VmaStlAllocator > m_CustomPoolContexts; -}; + // Start from offset equal to beginning of free space. + VkDeviceSize resultOffset = resultBaseOffset; -#if VMA_RECORDING_ENABLED + // Apply alignment. + resultOffset = VmaAlignUp(resultOffset, allocAlignment); -class VmaRecorder -{ -public: - VmaRecorder(); - VkResult Init(const VmaRecordSettings& settings, bool useMutex); - void WriteConfiguration( - const VkPhysicalDeviceProperties& devProps, - const VkPhysicalDeviceMemoryProperties& memProps, - uint32_t vulkanApiVersion, - bool dedicatedAllocationExtensionEnabled, - bool bindMemory2ExtensionEnabled, - bool memoryBudgetExtensionEnabled, - bool deviceCoherentMemoryExtensionEnabled); - ~VmaRecorder(); - - void RecordCreateAllocator(uint32_t frameIndex); - void RecordDestroyAllocator(uint32_t frameIndex); - void RecordCreatePool(uint32_t frameIndex, - const VmaPoolCreateInfo& createInfo, - VmaPool pool); - void RecordDestroyPool(uint32_t frameIndex, VmaPool pool); - void RecordAllocateMemory(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation); - void RecordAllocateMemoryPages(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - const VmaAllocationCreateInfo& createInfo, - uint64_t allocationCount, - const VmaAllocation* pAllocations); - void RecordAllocateMemoryForBuffer(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation); - void RecordAllocateMemoryForImage(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation); - void RecordFreeMemory(uint32_t frameIndex, - VmaAllocation allocation); - void RecordFreeMemoryPages(uint32_t frameIndex, - uint64_t allocationCount, - const VmaAllocation* pAllocations); - void RecordSetAllocationUserData(uint32_t frameIndex, - VmaAllocation allocation, - const void* pUserData); - void RecordCreateLostAllocation(uint32_t frameIndex, - VmaAllocation allocation); - void RecordMapMemory(uint32_t frameIndex, - VmaAllocation allocation); - void RecordUnmapMemory(uint32_t frameIndex, - VmaAllocation allocation); - void RecordFlushAllocation(uint32_t frameIndex, - VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size); - void RecordInvalidateAllocation(uint32_t frameIndex, - VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size); - void RecordCreateBuffer(uint32_t frameIndex, - const VkBufferCreateInfo& bufCreateInfo, - const VmaAllocationCreateInfo& allocCreateInfo, - VmaAllocation allocation); - void RecordCreateImage(uint32_t frameIndex, - const VkImageCreateInfo& imageCreateInfo, - const VmaAllocationCreateInfo& allocCreateInfo, - VmaAllocation allocation); - void RecordDestroyBuffer(uint32_t frameIndex, - VmaAllocation allocation); - void RecordDestroyImage(uint32_t frameIndex, - VmaAllocation allocation); - void RecordTouchAllocation(uint32_t frameIndex, - VmaAllocation allocation); - void RecordGetAllocationInfo(uint32_t frameIndex, - VmaAllocation allocation); - void RecordMakePoolAllocationsLost(uint32_t frameIndex, - VmaPool pool); - void RecordDefragmentationBegin(uint32_t frameIndex, - const VmaDefragmentationInfo2& info, - VmaDefragmentationContext ctx); - void RecordDefragmentationEnd(uint32_t frameIndex, - VmaDefragmentationContext ctx); - void RecordSetPoolName(uint32_t frameIndex, - VmaPool pool, - const char* name); - -private: - struct CallParams - { - uint32_t threadId; - double time; - }; - - class UserDataString - { - public: - UserDataString(VmaAllocationCreateFlags allocFlags, const void* pUserData); - const char* GetString() const { return m_Str; } - - private: - char m_PtrStr[17]; - const char* m_Str; - }; - - bool m_UseMutex; - VmaRecordFlags m_Flags; - FILE* m_File; - VMA_MUTEX m_FileMutex; - std::chrono::time_point m_RecordingStartTime; + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if (bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations2nd.empty()) + { + bool bufferImageGranularityConflict = false; + for (size_t prevSuballocIndex = suballocations2nd.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations2nd[prevSuballocIndex]; + if (VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if (bufferImageGranularityConflict) + { + resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + } + } - void GetBasicParams(CallParams& outParams); + size_t index1st = m_1stNullItemsBeginCount; - // T must be a pointer type, e.g. VmaAllocation, VmaPool. - template - void PrintPointerList(uint64_t count, const T* pItems) - { - if(count) + // There is enough free space at the end after alignment. + if ((index1st == suballocations1st.size() && resultOffset + allocSize + debugMargin <= blockSize) || + (index1st < suballocations1st.size() && resultOffset + allocSize + debugMargin <= suballocations1st[index1st].offset)) { - fprintf(m_File, "%p", pItems[0]); - for(uint64_t i = 1; i < count; ++i) + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if (allocSize % bufferImageGranularity || resultOffset % bufferImageGranularity) { - fprintf(m_File, " %p", pItems[i]); + for (size_t nextSuballocIndex = index1st; + nextSuballocIndex < suballocations1st.size(); + nextSuballocIndex++) + { + const VmaSuballocation& nextSuballoc = suballocations1st[nextSuballocIndex]; + if (VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } + } + else + { + // Already on next page. + break; + } + } } + + // All tests passed: Success. + pAllocationRequest->allocHandle = (VmaAllocHandle)(resultOffset + 1); + pAllocationRequest->type = VmaAllocationRequestType::EndOf2nd; + // pAllocationRequest->item, customData unused. + return true; } } - void PrintPointerList(uint64_t count, const VmaAllocation* pItems); - void Flush(); -}; - -#endif // #if VMA_RECORDING_ENABLED + return false; +} -/* -Thread-safe wrapper over VmaPoolAllocator free list, for allocation of VmaAllocation_T objects. -*/ -class VmaAllocationObjectAllocator +bool VmaBlockMetadata_Linear::CreateAllocationRequest_UpperAddress( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) { - VMA_CLASS_NO_COPY(VmaAllocationObjectAllocator) -public: - VmaAllocationObjectAllocator(const VkAllocationCallbacks* pAllocationCallbacks); + const VkDeviceSize blockSize = GetSize(); + const VkDeviceSize bufferImageGranularity = GetBufferImageGranularity(); + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - template VmaAllocation Allocate(Types... args); - void Free(VmaAllocation hAlloc); + if (m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + VMA_ASSERT(0 && "Trying to use pool with linear algorithm as double stack, while it is already being used as ring buffer."); + return false; + } -private: - VMA_MUTEX m_Mutex; - VmaPoolAllocator m_Allocator; -}; + // Try to allocate before 2nd.back(), or end of block if 2nd.empty(). + if (allocSize > blockSize) + { + return false; + } + VkDeviceSize resultBaseOffset = blockSize - allocSize; + if (!suballocations2nd.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations2nd.back(); + resultBaseOffset = lastSuballoc.offset - allocSize; + if (allocSize > lastSuballoc.offset) + { + return false; + } + } -struct VmaCurrentBudgetData -{ - VMA_ATOMIC_UINT64 m_BlockBytes[VK_MAX_MEMORY_HEAPS]; - VMA_ATOMIC_UINT64 m_AllocationBytes[VK_MAX_MEMORY_HEAPS]; + // Start from offset equal to end of free space. + VkDeviceSize resultOffset = resultBaseOffset; -#if VMA_MEMORY_BUDGET - VMA_ATOMIC_UINT32 m_OperationsSinceBudgetFetch; - VMA_RW_MUTEX m_BudgetMutex; - uint64_t m_VulkanUsage[VK_MAX_MEMORY_HEAPS]; - uint64_t m_VulkanBudget[VK_MAX_MEMORY_HEAPS]; - uint64_t m_BlockBytesAtBudgetFetch[VK_MAX_MEMORY_HEAPS]; -#endif // #if VMA_MEMORY_BUDGET + const VkDeviceSize debugMargin = GetDebugMargin(); - VmaCurrentBudgetData() + // Apply debugMargin at the end. + if (debugMargin > 0) { - for (const auto heapIndex : c10::irange(VK_MAX_MEMORY_HEAPS)) { - m_BlockBytes[heapIndex] = 0; - m_AllocationBytes[heapIndex] = 0; -#if VMA_MEMORY_BUDGET - m_VulkanUsage[heapIndex] = 0; - m_VulkanBudget[heapIndex] = 0; - m_BlockBytesAtBudgetFetch[heapIndex] = 0; -#endif + if (resultOffset < debugMargin) + { + return false; } - -#if VMA_MEMORY_BUDGET - m_OperationsSinceBudgetFetch = 0; -#endif + resultOffset -= debugMargin; } - void AddAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) + // Apply alignment. + resultOffset = VmaAlignDown(resultOffset, allocAlignment); + + // Check next suballocations from 2nd for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if (bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations2nd.empty()) { - m_AllocationBytes[heapIndex] += allocationSize; -#if VMA_MEMORY_BUDGET - ++m_OperationsSinceBudgetFetch; -#endif + bool bufferImageGranularityConflict = false; + for (size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) + { + const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; + if (VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(nextSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if (bufferImageGranularityConflict) + { + resultOffset = VmaAlignDown(resultOffset, bufferImageGranularity); + } } - void RemoveAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) + // There is enough free space. + const VkDeviceSize endOf1st = !suballocations1st.empty() ? + suballocations1st.back().offset + suballocations1st.back().size : + 0; + if (endOf1st + debugMargin <= resultOffset) { - VMA_ASSERT(m_AllocationBytes[heapIndex] >= allocationSize); // DELME - m_AllocationBytes[heapIndex] -= allocationSize; -#if VMA_MEMORY_BUDGET - ++m_OperationsSinceBudgetFetch; -#endif + // Check previous suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if (bufferImageGranularity > 1) + { + for (size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; + if (VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if (VmaIsBufferImageGranularityConflict(allocType, prevSuballoc.type)) + { + return false; + } + } + else + { + // Already on next page. + break; + } + } + } + + // All tests passed: Success. + pAllocationRequest->allocHandle = (VmaAllocHandle)(resultOffset + 1); + // pAllocationRequest->item unused. + pAllocationRequest->type = VmaAllocationRequestType::UpperAddress; + return true; } -}; -// Main allocator object. -struct VmaAllocator_T + return false; +} +#endif // _VMA_BLOCK_METADATA_LINEAR_FUNCTIONS +#endif // _VMA_BLOCK_METADATA_LINEAR + +#if 0 +#ifndef _VMA_BLOCK_METADATA_BUDDY +/* +- GetSize() is the original size of allocated memory block. +- m_UsableSize is this size aligned down to a power of two. + All allocations and calculations happen relative to m_UsableSize. +- GetUnusableSize() is the difference between them. + It is reported as separate, unused range, not available for allocations. + +Node at level 0 has size = m_UsableSize. +Each next level contains nodes with size 2 times smaller than current level. +m_LevelCount is the maximum number of levels to use in the current object. +*/ +class VmaBlockMetadata_Buddy : public VmaBlockMetadata { - VMA_CLASS_NO_COPY(VmaAllocator_T) + VMA_CLASS_NO_COPY(VmaBlockMetadata_Buddy) public: - bool m_UseMutex; - uint32_t m_VulkanApiVersion; - bool m_UseKhrDedicatedAllocation; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). - bool m_UseKhrBindMemory2; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). - bool m_UseExtMemoryBudget; - bool m_UseAmdDeviceCoherentMemory; - bool m_UseKhrBufferDeviceAddress; - bool m_UseExtMemoryPriority; - VkDevice m_hDevice; - VkInstance m_hInstance; - bool m_AllocationCallbacksSpecified; - VkAllocationCallbacks m_AllocationCallbacks; - VmaDeviceMemoryCallbacks m_DeviceMemoryCallbacks; - VmaAllocationObjectAllocator m_AllocationObjectAllocator; + VmaBlockMetadata_Buddy(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual); + virtual ~VmaBlockMetadata_Buddy(); - // Each bit (1 << i) is set if HeapSizeLimit is enabled for that heap, so cannot allocate more than the heap size. - uint32_t m_HeapSizeLimitMask; + size_t GetAllocationCount() const override { return m_AllocationCount; } + VkDeviceSize GetSumFreeSize() const override { return m_SumFreeSize + GetUnusableSize(); } + bool IsEmpty() const override { return m_Root->type == Node::TYPE_FREE; } + VkResult CheckCorruption(const void* pBlockData) override { return VK_ERROR_FEATURE_NOT_PRESENT; } + VkDeviceSize GetAllocationOffset(VmaAllocHandle allocHandle) const override { return (VkDeviceSize)allocHandle - 1; }; + void DebugLogAllAllocations() const override { DebugLogAllAllocationNode(m_Root, 0); } - VkPhysicalDeviceProperties m_PhysicalDeviceProperties; - VkPhysicalDeviceMemoryProperties m_MemProps; + void Init(VkDeviceSize size) override; + bool Validate() const override; - // Default pools. - VmaBlockVector* m_pBlockVectors[VK_MAX_MEMORY_TYPES]; + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const override; + void AddStatistics(VmaStatistics& inoutStats) const override; - // Each vector is sorted by memory (handle value). - typedef VmaVector< VmaAllocation, VmaStlAllocator > AllocationVectorType; - AllocationVectorType* m_pDedicatedAllocations[VK_MAX_MEMORY_TYPES]; - VMA_RW_MUTEX m_DedicatedAllocationsMutex[VK_MAX_MEMORY_TYPES]; +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json, uint32_t mapRefCount) const override; +#endif - VmaCurrentBudgetData m_Budget; - VMA_ATOMIC_UINT32 m_DeviceMemoryCount; // Total number of VkDeviceMemory objects. + bool CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) override; - VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo); - VkResult Init(const VmaAllocatorCreateInfo* pCreateInfo); - ~VmaAllocator_T(); + void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) override; - const VkAllocationCallbacks* GetAllocationCallbacks() const - { - return m_AllocationCallbacksSpecified ? &m_AllocationCallbacks : 0; - } - const VmaVulkanFunctions& GetVulkanFunctions() const - { - return m_VulkanFunctions; - } + void Free(VmaAllocHandle allocHandle) override; + void GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) override; + void* GetAllocationUserData(VmaAllocHandle allocHandle) const override; + VmaAllocHandle GetAllocationListBegin() const override; + VmaAllocHandle GetNextAllocation(VmaAllocHandle prevAlloc) const override; + void Clear() override; + void SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) override; - VkPhysicalDevice GetPhysicalDevice() const { return m_PhysicalDevice; } +private: + static const size_t MAX_LEVELS = 48; - VkDeviceSize GetBufferImageGranularity() const + struct ValidationContext { - return VMA_MAX( - static_cast(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY), - m_PhysicalDeviceProperties.limits.bufferImageGranularity); - } - - uint32_t GetMemoryHeapCount() const { return m_MemProps.memoryHeapCount; } - uint32_t GetMemoryTypeCount() const { return m_MemProps.memoryTypeCount; } - - uint32_t MemoryTypeIndexToHeapIndex(uint32_t memTypeIndex) const - { - VMA_ASSERT(memTypeIndex < m_MemProps.memoryTypeCount); - return m_MemProps.memoryTypes[memTypeIndex].heapIndex; - } - // True when specific memory type is HOST_VISIBLE but not HOST_COHERENT. - bool IsMemoryTypeNonCoherent(uint32_t memTypeIndex) const - { - return (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & (VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT)) == - VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - } - // Minimum alignment for all allocations in specific memory type. - VkDeviceSize GetMemoryTypeMinAlignment(uint32_t memTypeIndex) const - { - return IsMemoryTypeNonCoherent(memTypeIndex) ? - VMA_MAX((VkDeviceSize)VMA_DEBUG_ALIGNMENT, m_PhysicalDeviceProperties.limits.nonCoherentAtomSize) : - (VkDeviceSize)VMA_DEBUG_ALIGNMENT; - } - - bool IsIntegratedGpu() const + size_t calculatedAllocationCount = 0; + size_t calculatedFreeCount = 0; + VkDeviceSize calculatedSumFreeSize = 0; + }; + struct Node { - return m_PhysicalDeviceProperties.deviceType == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU; - } - - uint32_t GetGlobalMemoryTypeBits() const { return m_GlobalMemoryTypeBits; } - -#if VMA_RECORDING_ENABLED - VmaRecorder* GetRecorder() const { return m_pRecorder; } -#endif + VkDeviceSize offset; + enum TYPE + { + TYPE_FREE, + TYPE_ALLOCATION, + TYPE_SPLIT, + TYPE_COUNT + } type; + Node* parent; + Node* buddy; - void GetBufferMemoryRequirements( - VkBuffer hBuffer, - VkMemoryRequirements& memReq, - bool& requiresDedicatedAllocation, - bool& prefersDedicatedAllocation) const; - void GetImageMemoryRequirements( - VkImage hImage, - VkMemoryRequirements& memReq, - bool& requiresDedicatedAllocation, - bool& prefersDedicatedAllocation) const; + union + { + struct + { + Node* prev; + Node* next; + } free; + struct + { + void* userData; + } allocation; + struct + { + Node* leftChild; + } split; + }; + }; - // Main allocation function. - VkResult AllocateMemory( - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, // UINT32_MAX when unknown. - VkImage dedicatedImage, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations); + // Size of the memory block aligned down to a power of two. + VkDeviceSize m_UsableSize; + uint32_t m_LevelCount; + VmaPoolAllocator m_NodeAllocator; + Node* m_Root; + struct + { + Node* front; + Node* back; + } m_FreeList[MAX_LEVELS]; - // Main deallocation function. - void FreeMemory( - size_t allocationCount, - const VmaAllocation* pAllocations); + // Number of nodes in the tree with type == TYPE_ALLOCATION. + size_t m_AllocationCount; + // Number of nodes in the tree with type == TYPE_FREE. + size_t m_FreeCount; + // Doesn't include space wasted due to internal fragmentation - allocation sizes are just aligned up to node sizes. + // Doesn't include unusable size. + VkDeviceSize m_SumFreeSize; - void CalculateStats(VmaStats* pStats); + VkDeviceSize GetUnusableSize() const { return GetSize() - m_UsableSize; } + VkDeviceSize LevelToNodeSize(uint32_t level) const { return m_UsableSize >> level; } - void GetBudget( - VmaBudget* outBudget, uint32_t firstHeap, uint32_t heapCount); + VkDeviceSize AlignAllocationSize(VkDeviceSize size) const + { + if (!IsVirtual()) + { + size = VmaAlignUp(size, (VkDeviceSize)16); + } + return VmaNextPow2(size); + } + Node* FindAllocationNode(VkDeviceSize offset, uint32_t& outLevel) const; + void DeleteNodeChildren(Node* node); + bool ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const; + uint32_t AllocSizeToLevel(VkDeviceSize allocSize) const; + void AddNodeToDetailedStatistics(VmaDetailedStatistics& inoutStats, const Node* node, VkDeviceSize levelNodeSize) const; + // Adds node to the front of FreeList at given level. + // node->type must be FREE. + // node->free.prev, next can be undefined. + void AddToFreeListFront(uint32_t level, Node* node); + // Removes node from FreeList at given level. + // node->type must be FREE. + // node->free.prev, next stay untouched. + void RemoveFromFreeList(uint32_t level, Node* node); + void DebugLogAllAllocationNode(Node* node, uint32_t level) const; #if VMA_STATS_STRING_ENABLED - void PrintDetailedMap(class VmaJsonWriter& json); + void PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const; #endif +}; - VkResult DefragmentationBegin( - const VmaDefragmentationInfo2& info, - VmaDefragmentationStats* pStats, - VmaDefragmentationContext* pContext); - VkResult DefragmentationEnd( - VmaDefragmentationContext context); - - VkResult DefragmentationPassBegin( - VmaDefragmentationPassInfo* pInfo, - VmaDefragmentationContext context); - VkResult DefragmentationPassEnd( - VmaDefragmentationContext context); +#ifndef _VMA_BLOCK_METADATA_BUDDY_FUNCTIONS +VmaBlockMetadata_Buddy::VmaBlockMetadata_Buddy(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual) + : VmaBlockMetadata(pAllocationCallbacks, bufferImageGranularity, isVirtual), + m_NodeAllocator(pAllocationCallbacks, 32), // firstBlockCapacity + m_Root(VMA_NULL), + m_AllocationCount(0), + m_FreeCount(1), + m_SumFreeSize(0) +{ + memset(m_FreeList, 0, sizeof(m_FreeList)); +} - void GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo); - bool TouchAllocation(VmaAllocation hAllocation); +VmaBlockMetadata_Buddy::~VmaBlockMetadata_Buddy() +{ + DeleteNodeChildren(m_Root); + m_NodeAllocator.Free(m_Root); +} - VkResult CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool); - void DestroyPool(VmaPool pool); - void GetPoolStats(VmaPool pool, VmaPoolStats* pPoolStats); +void VmaBlockMetadata_Buddy::Init(VkDeviceSize size) +{ + VmaBlockMetadata::Init(size); - void SetCurrentFrameIndex(uint32_t frameIndex); - uint32_t GetCurrentFrameIndex() const { return m_CurrentFrameIndex.load(); } + m_UsableSize = VmaPrevPow2(size); + m_SumFreeSize = m_UsableSize; - void MakePoolAllocationsLost( - VmaPool hPool, - size_t* pLostAllocationCount); - VkResult CheckPoolCorruption(VmaPool hPool); - VkResult CheckCorruption(uint32_t memoryTypeBits); + // Calculate m_LevelCount. + const VkDeviceSize minNodeSize = IsVirtual() ? 1 : 16; + m_LevelCount = 1; + while (m_LevelCount < MAX_LEVELS && + LevelToNodeSize(m_LevelCount) >= minNodeSize) + { + ++m_LevelCount; + } - void CreateLostAllocation(VmaAllocation* pAllocation); + Node* rootNode = m_NodeAllocator.Alloc(); + rootNode->offset = 0; + rootNode->type = Node::TYPE_FREE; + rootNode->parent = VMA_NULL; + rootNode->buddy = VMA_NULL; - // Call to Vulkan function vkAllocateMemory with accompanying bookkeeping. - VkResult AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory); - // Call to Vulkan function vkFreeMemory with accompanying bookkeeping. - void FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory); - // Call to Vulkan function vkBindBufferMemory or vkBindBufferMemory2KHR. - VkResult BindVulkanBuffer( - VkDeviceMemory memory, - VkDeviceSize memoryOffset, - VkBuffer buffer, - const void* pNext); - // Call to Vulkan function vkBindImageMemory or vkBindImageMemory2KHR. - VkResult BindVulkanImage( - VkDeviceMemory memory, - VkDeviceSize memoryOffset, - VkImage image, - const void* pNext); + m_Root = rootNode; + AddToFreeListFront(0, rootNode); +} - VkResult Map(VmaAllocation hAllocation, void** ppData); - void Unmap(VmaAllocation hAllocation); +bool VmaBlockMetadata_Buddy::Validate() const +{ + // Validate tree. + ValidationContext ctx; + if (!ValidateNode(ctx, VMA_NULL, m_Root, 0, LevelToNodeSize(0))) + { + VMA_VALIDATE(false && "ValidateNode failed."); + } + VMA_VALIDATE(m_AllocationCount == ctx.calculatedAllocationCount); + VMA_VALIDATE(m_SumFreeSize == ctx.calculatedSumFreeSize); - VkResult BindBufferMemory( - VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkBuffer hBuffer, - const void* pNext); - VkResult BindImageMemory( - VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkImage hImage, - const void* pNext); + // Validate free node lists. + for (uint32_t level = 0; level < m_LevelCount; ++level) + { + VMA_VALIDATE(m_FreeList[level].front == VMA_NULL || + m_FreeList[level].front->free.prev == VMA_NULL); - VkResult FlushOrInvalidateAllocation( - VmaAllocation hAllocation, - VkDeviceSize offset, VkDeviceSize size, - VMA_CACHE_OPERATION op); - VkResult FlushOrInvalidateAllocations( - uint32_t allocationCount, - const VmaAllocation* allocations, - const VkDeviceSize* offsets, const VkDeviceSize* sizes, - VMA_CACHE_OPERATION op); + for (Node* node = m_FreeList[level].front; + node != VMA_NULL; + node = node->free.next) + { + VMA_VALIDATE(node->type == Node::TYPE_FREE); - void FillAllocation(const VmaAllocation hAllocation, uint8_t pattern); + if (node->free.next == VMA_NULL) + { + VMA_VALIDATE(m_FreeList[level].back == node); + } + else + { + VMA_VALIDATE(node->free.next->free.prev == node); + } + } + } - /* - Returns bit mask of memory types that can support defragmentation on GPU as - they support creation of required buffer for copy operations. - */ - uint32_t GetGpuDefragmentationMemoryTypeBits(); + // Validate that free lists ar higher levels are empty. + for (uint32_t level = m_LevelCount; level < MAX_LEVELS; ++level) + { + VMA_VALIDATE(m_FreeList[level].front == VMA_NULL && m_FreeList[level].back == VMA_NULL); + } -private: - VkDeviceSize m_PreferredLargeHeapBlockSize; + return true; +} - VkPhysicalDevice m_PhysicalDevice; - VMA_ATOMIC_UINT32 m_CurrentFrameIndex; - VMA_ATOMIC_UINT32 m_GpuDefragmentationMemoryTypeBits; // UINT32_MAX means uninitialized. +void VmaBlockMetadata_Buddy::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const +{ + inoutStats.statistics.blockCount++; + inoutStats.statistics.blockBytes += GetSize(); - VMA_RW_MUTEX m_PoolsMutex; - // Protected by m_PoolsMutex. Sorted by pointer value. - VmaVector > m_Pools; - uint32_t m_NextPoolId; + AddNodeToDetailedStatistics(inoutStats, m_Root, LevelToNodeSize(0)); - VmaVulkanFunctions m_VulkanFunctions; + const VkDeviceSize unusableSize = GetUnusableSize(); + if (unusableSize > 0) + VmaAddDetailedStatisticsUnusedRange(inoutStats, unusableSize); +} - // Global bit mask AND-ed with any memoryTypeBits to disallow certain memory types. - uint32_t m_GlobalMemoryTypeBits; +void VmaBlockMetadata_Buddy::AddStatistics(VmaStatistics& inoutStats) const +{ + inoutStats.blockCount++; + inoutStats.allocationCount += (uint32_t)m_AllocationCount; + inoutStats.blockBytes += GetSize(); + inoutStats.allocationBytes += GetSize() - m_SumFreeSize; +} -#if VMA_RECORDING_ENABLED - VmaRecorder* m_pRecorder; -#endif +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_Buddy::PrintDetailedMap(class VmaJsonWriter& json, uint32_t mapRefCount) const +{ + VmaDetailedStatistics stats; + VmaClearDetailedStatistics(stats); + AddDetailedStatistics(stats); - void ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions); + PrintDetailedMap_Begin( + json, + stats.statistics.blockBytes - stats.statistics.allocationBytes, + stats.statistics.allocationCount, + stats.unusedRangeCount, + mapRefCount); -#if VMA_STATIC_VULKAN_FUNCTIONS == 1 - void ImportVulkanFunctions_Static(); -#endif + PrintDetailedMapNode(json, m_Root, LevelToNodeSize(0)); - void ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions); + const VkDeviceSize unusableSize = GetUnusableSize(); + if (unusableSize > 0) + { + PrintDetailedMap_UnusedRange(json, + m_UsableSize, // offset + unusableSize); // size + } -#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 - void ImportVulkanFunctions_Dynamic(); -#endif + PrintDetailedMap_End(json); +} +#endif // VMA_STATS_STRING_ENABLED - void ValidateVulkanFunctions(); +bool VmaBlockMetadata_Buddy::CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + VMA_ASSERT(!upperAddress && "VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT can be used only with linear algorithm."); - VkDeviceSize CalcPreferredBlockSize(uint32_t memTypeIndex); + allocSize = AlignAllocationSize(allocSize); - VkResult AllocateMemoryOfType( - VkDeviceSize size, - VkDeviceSize alignment, - bool dedicatedAllocation, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, - VkImage dedicatedImage, - const VmaAllocationCreateInfo& createInfo, - uint32_t memTypeIndex, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations); + // Simple way to respect bufferImageGranularity. May be optimized some day. + // Whenever it might be an OPTIMAL image... + if (allocType == VMA_SUBALLOCATION_TYPE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL) + { + allocAlignment = VMA_MAX(allocAlignment, GetBufferImageGranularity()); + allocSize = VmaAlignUp(allocSize, GetBufferImageGranularity()); + } - // Helper function only to be used inside AllocateDedicatedMemory. - VkResult AllocateDedicatedMemoryPage( - VkDeviceSize size, - VmaSuballocationType suballocType, - uint32_t memTypeIndex, - const VkMemoryAllocateInfo& allocInfo, - bool map, - bool isUserDataString, - void* pUserData, - VmaAllocation* pAllocation); + if (allocSize > m_UsableSize) + { + return false; + } - // Allocates and registers new VkDeviceMemory specifically for dedicated allocations. - VkResult AllocateDedicatedMemory( - VkDeviceSize size, - VmaSuballocationType suballocType, - uint32_t memTypeIndex, - bool withinBudget, - bool map, - bool isUserDataString, - void* pUserData, - float priority, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, - VkImage dedicatedImage, - size_t allocationCount, - VmaAllocation* pAllocations); + const uint32_t targetLevel = AllocSizeToLevel(allocSize); + for (uint32_t level = targetLevel; level--; ) + { + for (Node* freeNode = m_FreeList[level].front; + freeNode != VMA_NULL; + freeNode = freeNode->free.next) + { + if (freeNode->offset % allocAlignment == 0) + { + pAllocationRequest->type = VmaAllocationRequestType::Normal; + pAllocationRequest->allocHandle = (VmaAllocHandle)(freeNode->offset + 1); + pAllocationRequest->size = allocSize; + pAllocationRequest->customData = (void*)(uintptr_t)level; + return true; + } + } + } - void FreeDedicatedMemory(const VmaAllocation allocation); + return false; +} - /* - Calculates and returns bit mask of memory types that can support defragmentation - on GPU as they support creation of required buffer for copy operations. - */ - uint32_t CalculateGpuDefragmentationMemoryTypeBits() const; +void VmaBlockMetadata_Buddy::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) +{ + VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); - uint32_t CalculateGlobalMemoryTypeBits() const; + const uint32_t targetLevel = AllocSizeToLevel(request.size); + uint32_t currLevel = (uint32_t)(uintptr_t)request.customData; - bool GetFlushOrInvalidateRange( - VmaAllocation allocation, - VkDeviceSize offset, VkDeviceSize size, - VkMappedMemoryRange& outRange) const; + Node* currNode = m_FreeList[currLevel].front; + VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); + const VkDeviceSize offset = (VkDeviceSize)request.allocHandle - 1; + while (currNode->offset != offset) + { + currNode = currNode->free.next; + VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); + } -#if VMA_MEMORY_BUDGET - void UpdateVulkanBudget(); -#endif // #if VMA_MEMORY_BUDGET -}; + // Go down, splitting free nodes. + while (currLevel < targetLevel) + { + // currNode is already first free node at currLevel. + // Remove it from list of free nodes at this currLevel. + RemoveFromFreeList(currLevel, currNode); -//////////////////////////////////////////////////////////////////////////////// -// Memory allocation #2 after VmaAllocator_T definition + const uint32_t childrenLevel = currLevel + 1; -static void* VmaMalloc(VmaAllocator hAllocator, size_t size, size_t alignment) -{ - return VmaMalloc(&hAllocator->m_AllocationCallbacks, size, alignment); + // Create two free sub-nodes. + Node* leftChild = m_NodeAllocator.Alloc(); + Node* rightChild = m_NodeAllocator.Alloc(); + + leftChild->offset = currNode->offset; + leftChild->type = Node::TYPE_FREE; + leftChild->parent = currNode; + leftChild->buddy = rightChild; + + rightChild->offset = currNode->offset + LevelToNodeSize(childrenLevel); + rightChild->type = Node::TYPE_FREE; + rightChild->parent = currNode; + rightChild->buddy = leftChild; + + // Convert current currNode to split type. + currNode->type = Node::TYPE_SPLIT; + currNode->split.leftChild = leftChild; + + // Add child nodes to free list. Order is important! + AddToFreeListFront(childrenLevel, rightChild); + AddToFreeListFront(childrenLevel, leftChild); + + ++m_FreeCount; + ++currLevel; + currNode = m_FreeList[currLevel].front; + + /* + We can be sure that currNode, as left child of node previously split, + also fulfills the alignment requirement. + */ + } + + // Remove from free list. + VMA_ASSERT(currLevel == targetLevel && + currNode != VMA_NULL && + currNode->type == Node::TYPE_FREE); + RemoveFromFreeList(currLevel, currNode); + + // Convert to allocation node. + currNode->type = Node::TYPE_ALLOCATION; + currNode->allocation.userData = userData; + + ++m_AllocationCount; + --m_FreeCount; + m_SumFreeSize -= request.size; } -static void VmaFree(VmaAllocator hAllocator, void* ptr) +void VmaBlockMetadata_Buddy::GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) { - VmaFree(&hAllocator->m_AllocationCallbacks, ptr); + uint32_t level = 0; + outInfo.offset = (VkDeviceSize)allocHandle - 1; + const Node* const node = FindAllocationNode(outInfo.offset, level); + outInfo.size = LevelToNodeSize(level); + outInfo.pUserData = node->allocation.userData; } -template -static T* VmaAllocate(VmaAllocator hAllocator) +void* VmaBlockMetadata_Buddy::GetAllocationUserData(VmaAllocHandle allocHandle) const { - return (T*)VmaMalloc(hAllocator, sizeof(T), VMA_ALIGN_OF(T)); + uint32_t level = 0; + const Node* const node = FindAllocationNode((VkDeviceSize)allocHandle - 1, level); + return node->allocation.userData; } -template -static T* VmaAllocateArray(VmaAllocator hAllocator, size_t count) +VmaAllocHandle VmaBlockMetadata_Buddy::GetAllocationListBegin() const { - return (T*)VmaMalloc(hAllocator, sizeof(T) * count, VMA_ALIGN_OF(T)); + // Function only used for defragmentation, which is disabled for this algorithm + return VK_NULL_HANDLE; } -template -static void vma_delete(VmaAllocator hAllocator, T* ptr) +VmaAllocHandle VmaBlockMetadata_Buddy::GetNextAllocation(VmaAllocHandle prevAlloc) const { - if(ptr != VMA_NULL) - { - ptr->~T(); - VmaFree(hAllocator, ptr); - } + // Function only used for defragmentation, which is disabled for this algorithm + return VK_NULL_HANDLE; } -template -static void vma_delete_array(VmaAllocator hAllocator, T* ptr, size_t count) +void VmaBlockMetadata_Buddy::DeleteNodeChildren(Node* node) { - if(ptr != VMA_NULL) + if (node->type == Node::TYPE_SPLIT) { - for(size_t i = count; i--; ) - ptr[i].~T(); - VmaFree(hAllocator, ptr); + DeleteNodeChildren(node->split.leftChild->buddy); + DeleteNodeChildren(node->split.leftChild); + const VkAllocationCallbacks* allocationCallbacks = GetAllocationCallbacks(); + m_NodeAllocator.Free(node->split.leftChild->buddy); + m_NodeAllocator.Free(node->split.leftChild); } } -//////////////////////////////////////////////////////////////////////////////// -// VmaStringBuilder - -#if VMA_STATS_STRING_ENABLED - -class VmaStringBuilder +void VmaBlockMetadata_Buddy::Clear() { -public: - VmaStringBuilder(VmaAllocator alloc) : m_Data(VmaStlAllocator(alloc->GetAllocationCallbacks())) { } - size_t GetLength() const { return m_Data.size(); } - const char* GetData() const { return m_Data.data(); } - - void Add(char ch) { m_Data.push_back(ch); } - void Add(const char* pStr); - void AddNewLine() { Add('\n'); } - void AddNumber(uint32_t num); - void AddNumber(uint64_t num); - void AddPointer(const void* ptr); + DeleteNodeChildren(m_Root); + m_Root->type = Node::TYPE_FREE; + m_AllocationCount = 0; + m_FreeCount = 1; + m_SumFreeSize = m_UsableSize; +} -private: - VmaVector< char, VmaStlAllocator > m_Data; -}; +void VmaBlockMetadata_Buddy::SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) +{ + uint32_t level = 0; + Node* const node = FindAllocationNode((VkDeviceSize)allocHandle - 1, level); + node->allocation.userData = userData; +} -void VmaStringBuilder::Add(const char* pStr) +VmaBlockMetadata_Buddy::Node* VmaBlockMetadata_Buddy::FindAllocationNode(VkDeviceSize offset, uint32_t& outLevel) const { - const size_t strLen = strlen(pStr); - if(strLen > 0) + Node* node = m_Root; + VkDeviceSize nodeOffset = 0; + outLevel = 0; + VkDeviceSize levelNodeSize = LevelToNodeSize(0); + while (node->type == Node::TYPE_SPLIT) { - const size_t oldCount = m_Data.size(); - m_Data.resize(oldCount + strLen); - memcpy(m_Data.data() + oldCount, pStr, strLen); + const VkDeviceSize nextLevelNodeSize = levelNodeSize >> 1; + if (offset < nodeOffset + nextLevelNodeSize) + { + node = node->split.leftChild; + } + else + { + node = node->split.leftChild->buddy; + nodeOffset += nextLevelNodeSize; + } + ++outLevel; + levelNodeSize = nextLevelNodeSize; } + + VMA_ASSERT(node != VMA_NULL && node->type == Node::TYPE_ALLOCATION); + return node; } -void VmaStringBuilder::AddNumber(uint32_t num) +bool VmaBlockMetadata_Buddy::ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const { - char buf[11]; - buf[10] = '\0'; - char *p = &buf[10]; - do + VMA_VALIDATE(level < m_LevelCount); + VMA_VALIDATE(curr->parent == parent); + VMA_VALIDATE((curr->buddy == VMA_NULL) == (parent == VMA_NULL)); + VMA_VALIDATE(curr->buddy == VMA_NULL || curr->buddy->buddy == curr); + switch (curr->type) { - *--p = '0' + (num % 10); - num /= 10; + case Node::TYPE_FREE: + // curr->free.prev, next are validated separately. + ctx.calculatedSumFreeSize += levelNodeSize; + ++ctx.calculatedFreeCount; + break; + case Node::TYPE_ALLOCATION: + ++ctx.calculatedAllocationCount; + if (!IsVirtual()) + { + VMA_VALIDATE(curr->allocation.userData != VMA_NULL); + } + break; + case Node::TYPE_SPLIT: + { + const uint32_t childrenLevel = level + 1; + const VkDeviceSize childrenLevelNodeSize = levelNodeSize >> 1; + const Node* const leftChild = curr->split.leftChild; + VMA_VALIDATE(leftChild != VMA_NULL); + VMA_VALIDATE(leftChild->offset == curr->offset); + if (!ValidateNode(ctx, curr, leftChild, childrenLevel, childrenLevelNodeSize)) + { + VMA_VALIDATE(false && "ValidateNode for left child failed."); + } + const Node* const rightChild = leftChild->buddy; + VMA_VALIDATE(rightChild->offset == curr->offset + childrenLevelNodeSize); + if (!ValidateNode(ctx, curr, rightChild, childrenLevel, childrenLevelNodeSize)) + { + VMA_VALIDATE(false && "ValidateNode for right child failed."); + } } - while(num); - Add(p); + break; + default: + return false; + } + + return true; } -void VmaStringBuilder::AddNumber(uint64_t num) +uint32_t VmaBlockMetadata_Buddy::AllocSizeToLevel(VkDeviceSize allocSize) const { - char buf[21]; - buf[20] = '\0'; - char *p = &buf[20]; - do + // I know this could be optimized somehow e.g. by using std::log2p1 from C++20. + uint32_t level = 0; + VkDeviceSize currLevelNodeSize = m_UsableSize; + VkDeviceSize nextLevelNodeSize = currLevelNodeSize >> 1; + while (allocSize <= nextLevelNodeSize && level + 1 < m_LevelCount) { - *--p = '0' + (num % 10); - num /= 10; + ++level; + currLevelNodeSize >>= 1; + nextLevelNodeSize >>= 1; } - while(num); - Add(p); + return level; } -void VmaStringBuilder::AddPointer(const void* ptr) +void VmaBlockMetadata_Buddy::Free(VmaAllocHandle allocHandle) { - char buf[21]; - VmaPtrToStr(buf, sizeof(buf), ptr); - Add(buf); -} - -#endif // #if VMA_STATS_STRING_ENABLED + uint32_t level = 0; + Node* node = FindAllocationNode((VkDeviceSize)allocHandle - 1, level); -//////////////////////////////////////////////////////////////////////////////// -// VmaJsonWriter + ++m_FreeCount; + --m_AllocationCount; + m_SumFreeSize += LevelToNodeSize(level); -#if VMA_STATS_STRING_ENABLED + node->type = Node::TYPE_FREE; -class VmaJsonWriter -{ - VMA_CLASS_NO_COPY(VmaJsonWriter) -public: - VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb); - ~VmaJsonWriter(); + // Join free nodes if possible. + while (level > 0 && node->buddy->type == Node::TYPE_FREE) + { + RemoveFromFreeList(level, node->buddy); + Node* const parent = node->parent; - void BeginObject(bool singleLine = false); - void EndObject(); + m_NodeAllocator.Free(node->buddy); + m_NodeAllocator.Free(node); + parent->type = Node::TYPE_FREE; - void BeginArray(bool singleLine = false); - void EndArray(); + node = parent; + --level; + --m_FreeCount; + } - void WriteString(const char* pStr); - void BeginString(const char* pStr = VMA_NULL); - void ContinueString(const char* pStr); - void ContinueString(uint32_t n); - void ContinueString(uint64_t n); - void ContinueString_Pointer(const void* ptr); - void EndString(const char* pStr = VMA_NULL); + AddToFreeListFront(level, node); +} - void WriteNumber(uint32_t n); - void WriteNumber(uint64_t n); - void WriteBool(bool b); - void WriteNull(); +void VmaBlockMetadata_Buddy::AddNodeToDetailedStatistics(VmaDetailedStatistics& inoutStats, const Node* node, VkDeviceSize levelNodeSize) const +{ + switch (node->type) + { + case Node::TYPE_FREE: + VmaAddDetailedStatisticsUnusedRange(inoutStats, levelNodeSize); + break; + case Node::TYPE_ALLOCATION: + VmaAddDetailedStatisticsAllocation(inoutStats, levelNodeSize); + break; + case Node::TYPE_SPLIT: + { + const VkDeviceSize childrenNodeSize = levelNodeSize / 2; + const Node* const leftChild = node->split.leftChild; + AddNodeToDetailedStatistics(inoutStats, leftChild, childrenNodeSize); + const Node* const rightChild = leftChild->buddy; + AddNodeToDetailedStatistics(inoutStats, rightChild, childrenNodeSize); + } + break; + default: + VMA_ASSERT(0); + } +} -private: - static const char* const INDENT; +void VmaBlockMetadata_Buddy::AddToFreeListFront(uint32_t level, Node* node) +{ + VMA_ASSERT(node->type == Node::TYPE_FREE); - enum COLLECTION_TYPE + // List is empty. + Node* const frontNode = m_FreeList[level].front; + if (frontNode == VMA_NULL) { - COLLECTION_TYPE_OBJECT, - COLLECTION_TYPE_ARRAY, - }; - struct StackItem + VMA_ASSERT(m_FreeList[level].back == VMA_NULL); + node->free.prev = node->free.next = VMA_NULL; + m_FreeList[level].front = m_FreeList[level].back = node; + } + else { - COLLECTION_TYPE type; - uint32_t valueCount; - bool singleLineMode; - }; + VMA_ASSERT(frontNode->free.prev == VMA_NULL); + node->free.prev = VMA_NULL; + node->free.next = frontNode; + frontNode->free.prev = node; + m_FreeList[level].front = node; + } +} - VmaStringBuilder& m_SB; - VmaVector< StackItem, VmaStlAllocator > m_Stack; - bool m_InsideString; +void VmaBlockMetadata_Buddy::RemoveFromFreeList(uint32_t level, Node* node) +{ + VMA_ASSERT(m_FreeList[level].front != VMA_NULL); - void BeginValue(bool isString); - void WriteIndent(bool oneLess = false); -}; + // It is at the front. + if (node->free.prev == VMA_NULL) + { + VMA_ASSERT(m_FreeList[level].front == node); + m_FreeList[level].front = node->free.next; + } + else + { + Node* const prevFreeNode = node->free.prev; + VMA_ASSERT(prevFreeNode->free.next == node); + prevFreeNode->free.next = node->free.next; + } -const char* const VmaJsonWriter::INDENT = " "; + // It is at the back. + if (node->free.next == VMA_NULL) + { + VMA_ASSERT(m_FreeList[level].back == node); + m_FreeList[level].back = node->free.prev; + } + else + { + Node* const nextFreeNode = node->free.next; + VMA_ASSERT(nextFreeNode->free.prev == node); + nextFreeNode->free.prev = node->free.prev; + } +} -VmaJsonWriter::VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb) : - m_SB(sb), - m_Stack(VmaStlAllocator(pAllocationCallbacks)), - m_InsideString(false) +void VmaBlockMetadata_Buddy::DebugLogAllAllocationNode(Node* node, uint32_t level) const { + switch (node->type) + { + case Node::TYPE_FREE: + break; + case Node::TYPE_ALLOCATION: + DebugLogAllocation(node->offset, LevelToNodeSize(level), node->allocation.userData); + break; + case Node::TYPE_SPLIT: + { + ++level; + DebugLogAllAllocationNode(node->split.leftChild, level); + DebugLogAllAllocationNode(node->split.leftChild->buddy, level); + } + break; + default: + VMA_ASSERT(0); + } } -VmaJsonWriter::~VmaJsonWriter() +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_Buddy::PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const { - VMA_ASSERT(!m_InsideString); - VMA_ASSERT(m_Stack.empty()); + switch (node->type) + { + case Node::TYPE_FREE: + PrintDetailedMap_UnusedRange(json, node->offset, levelNodeSize); + break; + case Node::TYPE_ALLOCATION: + PrintDetailedMap_Allocation(json, node->offset, levelNodeSize, node->allocation.userData); + break; + case Node::TYPE_SPLIT: + { + const VkDeviceSize childrenNodeSize = levelNodeSize / 2; + const Node* const leftChild = node->split.leftChild; + PrintDetailedMapNode(json, leftChild, childrenNodeSize); + const Node* const rightChild = leftChild->buddy; + PrintDetailedMapNode(json, rightChild, childrenNodeSize); + } + break; + default: + VMA_ASSERT(0); + } } +#endif // VMA_STATS_STRING_ENABLED +#endif // _VMA_BLOCK_METADATA_BUDDY_FUNCTIONS +#endif // _VMA_BLOCK_METADATA_BUDDY +#endif // #if 0 -void VmaJsonWriter::BeginObject(bool singleLine) +#ifndef _VMA_BLOCK_METADATA_TLSF +// To not search current larger region if first allocation won't succeed and skip to smaller range +// use with VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT as strategy in CreateAllocationRequest(). +// When fragmentation and reusal of previous blocks doesn't matter then use with +// VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT for fastest alloc time possible. +class VmaBlockMetadata_TLSF : public VmaBlockMetadata { - VMA_ASSERT(!m_InsideString); + VMA_CLASS_NO_COPY(VmaBlockMetadata_TLSF) +public: + VmaBlockMetadata_TLSF(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual); + virtual ~VmaBlockMetadata_TLSF(); - BeginValue(false); - m_SB.Add('{'); + size_t GetAllocationCount() const override { return m_AllocCount; } + size_t GetFreeRegionsCount() const override { return m_BlocksFreeCount + 1; } + VkDeviceSize GetSumFreeSize() const override { return m_BlocksFreeSize + m_NullBlock->size; } + bool IsEmpty() const override { return m_NullBlock->offset == 0; } + VkDeviceSize GetAllocationOffset(VmaAllocHandle allocHandle) const override { return ((Block*)allocHandle)->offset; }; - StackItem item; - item.type = COLLECTION_TYPE_OBJECT; - item.valueCount = 0; - item.singleLineMode = singleLine; - m_Stack.push_back(item); -} + void Init(VkDeviceSize size) override; + bool Validate() const override; -void VmaJsonWriter::EndObject() -{ - VMA_ASSERT(!m_InsideString); + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const override; + void AddStatistics(VmaStatistics& inoutStats) const override; - WriteIndent(true); - m_SB.Add('}'); +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json) const override; +#endif - VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_OBJECT); - m_Stack.pop_back(); -} + bool CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) override; -void VmaJsonWriter::BeginArray(bool singleLine) -{ - VMA_ASSERT(!m_InsideString); + VkResult CheckCorruption(const void* pBlockData) override; + void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) override; + + void Free(VmaAllocHandle allocHandle) override; + void GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) override; + void* GetAllocationUserData(VmaAllocHandle allocHandle) const override; + VmaAllocHandle GetAllocationListBegin() const override; + VmaAllocHandle GetNextAllocation(VmaAllocHandle prevAlloc) const override; + VkDeviceSize GetNextFreeRegionSize(VmaAllocHandle alloc) const override; + void Clear() override; + void SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) override; + void DebugLogAllAllocations() const override; - BeginValue(false); - m_SB.Add('['); +private: + // According to original paper it should be preferable 4 or 5: + // M. Masmano, I. Ripoll, A. Crespo, and J. Real "TLSF: a New Dynamic Memory Allocator for Real-Time Systems" + // http://www.gii.upv.es/tlsf/files/ecrts04_tlsf.pdf + static const uint8_t SECOND_LEVEL_INDEX = 5; + static const uint16_t SMALL_BUFFER_SIZE = 256; + static const uint32_t INITIAL_BLOCK_ALLOC_COUNT = 16; + static const uint8_t MEMORY_CLASS_SHIFT = 7; + static const uint8_t MAX_MEMORY_CLASSES = 65 - MEMORY_CLASS_SHIFT; - StackItem item; - item.type = COLLECTION_TYPE_ARRAY; - item.valueCount = 0; - item.singleLineMode = singleLine; - m_Stack.push_back(item); -} + class Block + { + public: + VkDeviceSize offset; + VkDeviceSize size; + Block* prevPhysical; + Block* nextPhysical; -void VmaJsonWriter::EndArray() -{ - VMA_ASSERT(!m_InsideString); + void MarkFree() { prevFree = VMA_NULL; } + void MarkTaken() { prevFree = this; } + bool IsFree() const { return prevFree != this; } + void*& UserData() { VMA_HEAVY_ASSERT(!IsFree()); return userData; } + Block*& PrevFree() { return prevFree; } + Block*& NextFree() { VMA_HEAVY_ASSERT(IsFree()); return nextFree; } - WriteIndent(true); - m_SB.Add(']'); + private: + Block* prevFree; // Address of the same block here indicates that block is taken + union + { + Block* nextFree; + void* userData; + }; + }; - VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_ARRAY); - m_Stack.pop_back(); -} + size_t m_AllocCount; + // Total number of free blocks besides null block + size_t m_BlocksFreeCount; + // Total size of free blocks excluding null block + VkDeviceSize m_BlocksFreeSize; + uint32_t m_IsFreeBitmap; + uint8_t m_MemoryClasses; + uint32_t m_InnerIsFreeBitmap[MAX_MEMORY_CLASSES]; + uint32_t m_ListsCount; + /* + * 0: 0-3 lists for small buffers + * 1+: 0-(2^SLI-1) lists for normal buffers + */ + Block** m_FreeList; + VmaPoolAllocator m_BlockAllocator; + Block* m_NullBlock; + VmaBlockBufferImageGranularity m_GranularityHandler; + + uint8_t SizeToMemoryClass(VkDeviceSize size) const; + uint16_t SizeToSecondIndex(VkDeviceSize size, uint8_t memoryClass) const; + uint32_t GetListIndex(uint8_t memoryClass, uint16_t secondIndex) const; + uint32_t GetListIndex(VkDeviceSize size) const; + + void RemoveFreeBlock(Block* block); + void InsertFreeBlock(Block* block); + void MergeBlock(Block* block, Block* prev); + + Block* FindFreeBlock(VkDeviceSize size, uint32_t& listIndex) const; + bool CheckBlock( + Block& block, + uint32_t listIndex, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaAllocationRequest* pAllocationRequest); +}; -void VmaJsonWriter::WriteString(const char* pStr) +#ifndef _VMA_BLOCK_METADATA_TLSF_FUNCTIONS +VmaBlockMetadata_TLSF::VmaBlockMetadata_TLSF(const VkAllocationCallbacks* pAllocationCallbacks, + VkDeviceSize bufferImageGranularity, bool isVirtual) + : VmaBlockMetadata(pAllocationCallbacks, bufferImageGranularity, isVirtual), + m_AllocCount(0), + m_BlocksFreeCount(0), + m_BlocksFreeSize(0), + m_IsFreeBitmap(0), + m_MemoryClasses(0), + m_ListsCount(0), + m_FreeList(VMA_NULL), + m_BlockAllocator(pAllocationCallbacks, INITIAL_BLOCK_ALLOC_COUNT), + m_NullBlock(VMA_NULL), + m_GranularityHandler(bufferImageGranularity) {} + +VmaBlockMetadata_TLSF::~VmaBlockMetadata_TLSF() { - BeginString(pStr); - EndString(); + if (m_FreeList) + vma_delete_array(GetAllocationCallbacks(), m_FreeList, m_ListsCount); + m_GranularityHandler.Destroy(GetAllocationCallbacks()); } -void VmaJsonWriter::BeginString(const char* pStr) +void VmaBlockMetadata_TLSF::Init(VkDeviceSize size) { - VMA_ASSERT(!m_InsideString); + VmaBlockMetadata::Init(size); - BeginValue(true); - m_SB.Add('"'); - m_InsideString = true; - if(pStr != VMA_NULL && pStr[0] != '\0') - { - ContinueString(pStr); - } + if (!IsVirtual()) + m_GranularityHandler.Init(GetAllocationCallbacks(), size); + + m_NullBlock = m_BlockAllocator.Alloc(); + m_NullBlock->size = size; + m_NullBlock->offset = 0; + m_NullBlock->prevPhysical = VMA_NULL; + m_NullBlock->nextPhysical = VMA_NULL; + m_NullBlock->MarkFree(); + m_NullBlock->NextFree() = VMA_NULL; + m_NullBlock->PrevFree() = VMA_NULL; + uint8_t memoryClass = SizeToMemoryClass(size); + uint16_t sli = SizeToSecondIndex(size, memoryClass); + m_ListsCount = (memoryClass == 0 ? 0 : (memoryClass - 1) * (1UL << SECOND_LEVEL_INDEX) + sli) + 1; + if (IsVirtual()) + m_ListsCount += 1UL << SECOND_LEVEL_INDEX; + else + m_ListsCount += 4; + + m_MemoryClasses = memoryClass + 2; + memset(m_InnerIsFreeBitmap, 0, MAX_MEMORY_CLASSES * sizeof(uint32_t)); + + m_FreeList = vma_new_array(GetAllocationCallbacks(), Block*, m_ListsCount); + memset(m_FreeList, 0, m_ListsCount * sizeof(Block*)); } -void VmaJsonWriter::ContinueString(const char* pStr) +bool VmaBlockMetadata_TLSF::Validate() const { - VMA_ASSERT(m_InsideString); + VMA_VALIDATE(GetSumFreeSize() <= GetSize()); - const size_t strLen = strlen(pStr); - for (const auto i : c10::irange(strLen)) { - char ch = pStr[i]; - if(ch == '\\') + VkDeviceSize calculatedSize = m_NullBlock->size; + VkDeviceSize calculatedFreeSize = m_NullBlock->size; + size_t allocCount = 0; + size_t freeCount = 0; + + // Check integrity of free lists + for (uint32_t list = 0; list < m_ListsCount; ++list) + { + Block* block = m_FreeList[list]; + if (block != VMA_NULL) { - m_SB.Add("\\\\"); + VMA_VALIDATE(block->IsFree()); + VMA_VALIDATE(block->PrevFree() == VMA_NULL); + while (block->NextFree()) + { + VMA_VALIDATE(block->NextFree()->IsFree()); + VMA_VALIDATE(block->NextFree()->PrevFree() == block); + block = block->NextFree(); + } } - else if(ch == '"') + } + + VkDeviceSize nextOffset = m_NullBlock->offset; + auto validateCtx = m_GranularityHandler.StartValidation(GetAllocationCallbacks(), IsVirtual()); + + VMA_VALIDATE(m_NullBlock->nextPhysical == VMA_NULL); + if (m_NullBlock->prevPhysical) + { + VMA_VALIDATE(m_NullBlock->prevPhysical->nextPhysical == m_NullBlock); + } + // Check all blocks + for (Block* prev = m_NullBlock->prevPhysical; prev != VMA_NULL; prev = prev->prevPhysical) + { + VMA_VALIDATE(prev->offset + prev->size == nextOffset); + nextOffset = prev->offset; + calculatedSize += prev->size; + + uint32_t listIndex = GetListIndex(prev->size); + if (prev->IsFree()) { - m_SB.Add("\\\""); + ++freeCount; + // Check if free block belongs to free list + Block* freeBlock = m_FreeList[listIndex]; + VMA_VALIDATE(freeBlock != VMA_NULL); + + bool found = false; + do + { + if (freeBlock == prev) + found = true; + + freeBlock = freeBlock->NextFree(); + } while (!found && freeBlock != VMA_NULL); + + VMA_VALIDATE(found); + calculatedFreeSize += prev->size; } - else if(ch >= 32) + else { - m_SB.Add(ch); + ++allocCount; + // Check if taken block is not on a free list + Block* freeBlock = m_FreeList[listIndex]; + while (freeBlock) + { + VMA_VALIDATE(freeBlock != prev); + freeBlock = freeBlock->NextFree(); + } + + if (!IsVirtual()) + { + VMA_VALIDATE(m_GranularityHandler.Validate(validateCtx, prev->offset, prev->size)); + } } - else switch(ch) + + if (prev->prevPhysical) { - case '\b': - m_SB.Add("\\b"); - break; - case '\f': - m_SB.Add("\\f"); - break; - case '\n': - m_SB.Add("\\n"); - break; - case '\r': - m_SB.Add("\\r"); - break; - case '\t': - m_SB.Add("\\t"); - break; - default: - VMA_ASSERT(0 && "Character not currently supported."); - break; + VMA_VALIDATE(prev->prevPhysical->nextPhysical == prev); } } -} -void VmaJsonWriter::ContinueString(uint32_t n) -{ - VMA_ASSERT(m_InsideString); - m_SB.AddNumber(n); -} + if (!IsVirtual()) + { + VMA_VALIDATE(m_GranularityHandler.FinishValidation(validateCtx)); + } -void VmaJsonWriter::ContinueString(uint64_t n) -{ - VMA_ASSERT(m_InsideString); - m_SB.AddNumber(n); -} + VMA_VALIDATE(nextOffset == 0); + VMA_VALIDATE(calculatedSize == GetSize()); + VMA_VALIDATE(calculatedFreeSize == GetSumFreeSize()); + VMA_VALIDATE(allocCount == m_AllocCount); + VMA_VALIDATE(freeCount == m_BlocksFreeCount); -void VmaJsonWriter::ContinueString_Pointer(const void* ptr) -{ - VMA_ASSERT(m_InsideString); - m_SB.AddPointer(ptr); + return true; } -void VmaJsonWriter::EndString(const char* pStr) +void VmaBlockMetadata_TLSF::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) const { - VMA_ASSERT(m_InsideString); - if(pStr != VMA_NULL && pStr[0] != '\0') + inoutStats.statistics.blockCount++; + inoutStats.statistics.blockBytes += GetSize(); + if (m_NullBlock->size > 0) + VmaAddDetailedStatisticsUnusedRange(inoutStats, m_NullBlock->size); + + for (Block* block = m_NullBlock->prevPhysical; block != VMA_NULL; block = block->prevPhysical) { - ContinueString(pStr); + if (block->IsFree()) + VmaAddDetailedStatisticsUnusedRange(inoutStats, block->size); + else + VmaAddDetailedStatisticsAllocation(inoutStats, block->size); } - m_SB.Add('"'); - m_InsideString = false; } -void VmaJsonWriter::WriteNumber(uint32_t n) +void VmaBlockMetadata_TLSF::AddStatistics(VmaStatistics& inoutStats) const { - VMA_ASSERT(!m_InsideString); - BeginValue(false); - m_SB.AddNumber(n); + inoutStats.blockCount++; + inoutStats.allocationCount += (uint32_t)m_AllocCount; + inoutStats.blockBytes += GetSize(); + inoutStats.allocationBytes += GetSize() - GetSumFreeSize(); } -void VmaJsonWriter::WriteNumber(uint64_t n) +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_TLSF::PrintDetailedMap(class VmaJsonWriter& json) const { - VMA_ASSERT(!m_InsideString); - BeginValue(false); - m_SB.AddNumber(n); -} + size_t blockCount = m_AllocCount + m_BlocksFreeCount; + VmaStlAllocator allocator(GetAllocationCallbacks()); + VmaVector> blockList(blockCount, allocator); -void VmaJsonWriter::WriteBool(bool b) -{ - VMA_ASSERT(!m_InsideString); - BeginValue(false); - m_SB.Add(b ? "true" : "false"); -} + size_t i = blockCount; + for (Block* block = m_NullBlock->prevPhysical; block != VMA_NULL; block = block->prevPhysical) + { + blockList[--i] = block; + } + VMA_ASSERT(i == 0); -void VmaJsonWriter::WriteNull() -{ - VMA_ASSERT(!m_InsideString); - BeginValue(false); - m_SB.Add("null"); + VmaDetailedStatistics stats; + VmaClearDetailedStatistics(stats); + AddDetailedStatistics(stats); + + PrintDetailedMap_Begin(json, + stats.statistics.blockBytes - stats.statistics.allocationBytes, + stats.statistics.allocationCount, + stats.unusedRangeCount); + + for (; i < blockCount; ++i) + { + Block* block = blockList[i]; + if (block->IsFree()) + PrintDetailedMap_UnusedRange(json, block->offset, block->size); + else + PrintDetailedMap_Allocation(json, block->offset, block->size, block->UserData()); + } + if (m_NullBlock->size > 0) + PrintDetailedMap_UnusedRange(json, m_NullBlock->offset, m_NullBlock->size); + + PrintDetailedMap_End(json); } +#endif -void VmaJsonWriter::BeginValue(bool isString) +bool VmaBlockMetadata_TLSF::CreateAllocationRequest( + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) { - if(!m_Stack.empty()) + VMA_ASSERT(allocSize > 0 && "Cannot allocate empty block!"); + VMA_ASSERT(!upperAddress && "VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT can be used only with linear algorithm."); + + // For small granularity round up + if (!IsVirtual()) + m_GranularityHandler.RoundupAllocRequest(allocType, allocSize, allocAlignment); + + allocSize += GetDebugMargin(); + // Quick check for too small pool + if (allocSize > GetSumFreeSize()) + return false; + + // If no free blocks in pool then check only null block + if (m_BlocksFreeCount == 0) + return CheckBlock(*m_NullBlock, m_ListsCount, allocSize, allocAlignment, allocType, pAllocationRequest); + + // Round up to the next block + VkDeviceSize sizeForNextList = allocSize; + VkDeviceSize smallSizeStep = SMALL_BUFFER_SIZE / (IsVirtual() ? 1 << SECOND_LEVEL_INDEX : 4); + if (allocSize > SMALL_BUFFER_SIZE) { - StackItem& currItem = m_Stack.back(); - if(currItem.type == COLLECTION_TYPE_OBJECT && - currItem.valueCount % 2 == 0) + sizeForNextList += (1ULL << (VMA_BITSCAN_MSB(allocSize) - SECOND_LEVEL_INDEX)); + } + else if (allocSize > SMALL_BUFFER_SIZE - smallSizeStep) + sizeForNextList = SMALL_BUFFER_SIZE + 1; + else + sizeForNextList += smallSizeStep; + + uint32_t nextListIndex = 0; + uint32_t prevListIndex = 0; + Block* nextListBlock = VMA_NULL; + Block* prevListBlock = VMA_NULL; + + // Check blocks according to strategies + if (strategy & VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT) + { + // Quick check for larger block first + nextListBlock = FindFreeBlock(sizeForNextList, nextListIndex); + if (nextListBlock != VMA_NULL && CheckBlock(*nextListBlock, nextListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + + // If not fitted then null block + if (CheckBlock(*m_NullBlock, m_ListsCount, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + + // Null block failed, search larger bucket + while (nextListBlock) { - VMA_ASSERT(isString); + if (CheckBlock(*nextListBlock, nextListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + nextListBlock = nextListBlock->NextFree(); } - if(currItem.type == COLLECTION_TYPE_OBJECT && - currItem.valueCount % 2 != 0) + // Failed again, check best fit bucket + prevListBlock = FindFreeBlock(allocSize, prevListIndex); + while (prevListBlock) { - m_SB.Add(": "); + if (CheckBlock(*prevListBlock, prevListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + prevListBlock = prevListBlock->NextFree(); } - else if(currItem.valueCount > 0) + } + else if (strategy & VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT) + { + // Check best fit bucket + prevListBlock = FindFreeBlock(allocSize, prevListIndex); + while (prevListBlock) { - m_SB.Add(", "); - WriteIndent(); + if (CheckBlock(*prevListBlock, prevListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + prevListBlock = prevListBlock->NextFree(); } - else + + // If failed check null block + if (CheckBlock(*m_NullBlock, m_ListsCount, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + + // Check larger bucket + nextListBlock = FindFreeBlock(sizeForNextList, nextListIndex); + while (nextListBlock) { - WriteIndent(); + if (CheckBlock(*nextListBlock, nextListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + nextListBlock = nextListBlock->NextFree(); } - ++currItem.valueCount; } -} - -void VmaJsonWriter::WriteIndent(bool oneLess) -{ - if(!m_Stack.empty() && !m_Stack.back().singleLineMode) + else if (strategy & VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT ) { - m_SB.AddNewLine(); + // Perform search from the start + VmaStlAllocator allocator(GetAllocationCallbacks()); + VmaVector> blockList(m_BlocksFreeCount, allocator); - size_t count = m_Stack.size(); - if(count > 0 && oneLess) + size_t i = m_BlocksFreeCount; + for (Block* block = m_NullBlock->prevPhysical; block != VMA_NULL; block = block->prevPhysical) { - --count; - } - for (const auto i : c10::irange(count)) { - m_SB.Add(INDENT); + if (block->IsFree() && block->size >= allocSize) + blockList[--i] = block; } - } -} -#endif // #if VMA_STATS_STRING_ENABLED + for (; i < m_BlocksFreeCount; ++i) + { + Block& block = *blockList[i]; + if (CheckBlock(block, GetListIndex(block.size), allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + } -//////////////////////////////////////////////////////////////////////////////// + // If failed check null block + if (CheckBlock(*m_NullBlock, m_ListsCount, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; -void VmaAllocation_T::SetUserData(VmaAllocator hAllocator, void* pUserData) -{ - if(IsUserDataString()) + // Whole range searched, no more memory + return false; + } + else { - VMA_ASSERT(pUserData == VMA_NULL || pUserData != m_pUserData); + // Check larger bucket + nextListBlock = FindFreeBlock(sizeForNextList, nextListIndex); + while (nextListBlock) + { + if (CheckBlock(*nextListBlock, nextListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + nextListBlock = nextListBlock->NextFree(); + } - FreeUserDataString(hAllocator); + // If failed check null block + if (CheckBlock(*m_NullBlock, m_ListsCount, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; - if(pUserData != VMA_NULL) + // Check best fit bucket + prevListBlock = FindFreeBlock(allocSize, prevListIndex); + while (prevListBlock) { - m_pUserData = VmaCreateStringCopy(hAllocator->GetAllocationCallbacks(), (const char*)pUserData); + if (CheckBlock(*prevListBlock, prevListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + prevListBlock = prevListBlock->NextFree(); } } - else + + // Worst case, full search has to be done + while (++nextListIndex < m_ListsCount) { - m_pUserData = pUserData; + nextListBlock = m_FreeList[nextListIndex]; + while (nextListBlock) + { + if (CheckBlock(*nextListBlock, nextListIndex, allocSize, allocAlignment, allocType, pAllocationRequest)) + return true; + nextListBlock = nextListBlock->NextFree(); + } } + + // No more memory sadly + return false; } -void VmaAllocation_T::ChangeBlockAllocation( - VmaAllocator hAllocator, - VmaDeviceMemoryBlock* block, - VkDeviceSize offset) +VkResult VmaBlockMetadata_TLSF::CheckCorruption(const void* pBlockData) { - VMA_ASSERT(block != VMA_NULL); - VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); - - // Move mapping reference counter from old block to new block. - if(block != m_BlockAllocation.m_Block) + for (Block* block = m_NullBlock->prevPhysical; block != VMA_NULL; block = block->prevPhysical) { - uint32_t mapRefCount = m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP; - if(IsPersistentMap()) - ++mapRefCount; - m_BlockAllocation.m_Block->Unmap(hAllocator, mapRefCount); - block->Map(hAllocator, mapRefCount, VMA_NULL); + if (!block->IsFree()) + { + if (!VmaValidateMagicValue(pBlockData, block->offset + block->size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_UNKNOWN_COPY; + } + } } - m_BlockAllocation.m_Block = block; - m_BlockAllocation.m_Offset = offset; + return VK_SUCCESS; } -void VmaAllocation_T::ChangeOffset(VkDeviceSize newOffset) +void VmaBlockMetadata_TLSF::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + void* userData) { - VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); - m_BlockAllocation.m_Offset = newOffset; -} + VMA_ASSERT(request.type == VmaAllocationRequestType::TLSF); -VkDeviceSize VmaAllocation_T::GetOffset() const -{ - switch(m_Type) - { - case ALLOCATION_TYPE_BLOCK: - return m_BlockAllocation.m_Offset; - case ALLOCATION_TYPE_DEDICATED: - return 0; - default: - VMA_ASSERT(0); - return 0; - } -} + // Get block and pop it from the free list + Block* currentBlock = (Block*)request.allocHandle; + VkDeviceSize offset = request.algorithmData; + VMA_ASSERT(currentBlock != VMA_NULL); + VMA_ASSERT(currentBlock->offset <= offset); -VkDeviceMemory VmaAllocation_T::GetMemory() const -{ - switch(m_Type) - { - case ALLOCATION_TYPE_BLOCK: - return m_BlockAllocation.m_Block->GetDeviceMemory(); - case ALLOCATION_TYPE_DEDICATED: - return m_DedicatedAllocation.m_hMemory; - default: - VMA_ASSERT(0); - return VK_NULL_HANDLE; - } -} + if (currentBlock != m_NullBlock) + RemoveFreeBlock(currentBlock); -void* VmaAllocation_T::GetMappedData() const -{ - switch(m_Type) + VkDeviceSize debugMargin = GetDebugMargin(); + VkDeviceSize misssingAlignment = offset - currentBlock->offset; + + // Append missing alignment to prev block or create new one + if (misssingAlignment) { - case ALLOCATION_TYPE_BLOCK: - if(m_MapCount != 0) + Block* prevBlock = currentBlock->prevPhysical; + VMA_ASSERT(prevBlock != VMA_NULL && "There should be no missing alignment at offset 0!"); + + if (prevBlock->IsFree() && prevBlock->size != debugMargin) { - void* pBlockData = m_BlockAllocation.m_Block->GetMappedData(); - VMA_ASSERT(pBlockData != VMA_NULL); - return (char*)pBlockData + m_BlockAllocation.m_Offset; + uint32_t oldList = GetListIndex(prevBlock->size); + prevBlock->size += misssingAlignment; + // Check if new size crosses list bucket + if (oldList != GetListIndex(prevBlock->size)) + { + prevBlock->size -= misssingAlignment; + RemoveFreeBlock(prevBlock); + prevBlock->size += misssingAlignment; + InsertFreeBlock(prevBlock); + } + else + m_BlocksFreeSize += misssingAlignment; } else { - return VMA_NULL; + Block* newBlock = m_BlockAllocator.Alloc(); + currentBlock->prevPhysical = newBlock; + prevBlock->nextPhysical = newBlock; + newBlock->prevPhysical = prevBlock; + newBlock->nextPhysical = currentBlock; + newBlock->size = misssingAlignment; + newBlock->offset = currentBlock->offset; + newBlock->MarkTaken(); + + InsertFreeBlock(newBlock); } - break; - case ALLOCATION_TYPE_DEDICATED: - VMA_ASSERT((m_DedicatedAllocation.m_pMappedData != VMA_NULL) == (m_MapCount != 0)); - return m_DedicatedAllocation.m_pMappedData; - default: - VMA_ASSERT(0); - return VMA_NULL; - } -} -bool VmaAllocation_T::CanBecomeLost() const -{ - switch(m_Type) - { - case ALLOCATION_TYPE_BLOCK: - return m_BlockAllocation.m_CanBecomeLost; - case ALLOCATION_TYPE_DEDICATED: - return false; - default: - VMA_ASSERT(0); - return false; + currentBlock->size -= misssingAlignment; + currentBlock->offset += misssingAlignment; } -} - -bool VmaAllocation_T::MakeLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) -{ - VMA_ASSERT(CanBecomeLost()); - /* - Warning: This is a carefully designed algorithm. - Do not modify unless you really know what you're doing :) - */ - uint32_t localLastUseFrameIndex = GetLastUseFrameIndex(); - for(;;) + VkDeviceSize size = request.size + debugMargin; + if (currentBlock->size == size) { - if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) - { - VMA_ASSERT(0); - return false; - } - else if(localLastUseFrameIndex + frameInUseCount >= currentFrameIndex) - { - return false; - } - else // Last use time earlier than current time. + if (currentBlock == m_NullBlock) { - if(CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, VMA_FRAME_INDEX_LOST)) - { - // Setting hAllocation.LastUseFrameIndex atomic to VMA_FRAME_INDEX_LOST is enough to mark it as LOST. - // Calling code just needs to unregister this allocation in owning VmaDeviceMemoryBlock. - return true; - } + // Setup new null block + m_NullBlock = m_BlockAllocator.Alloc(); + m_NullBlock->size = 0; + m_NullBlock->offset = currentBlock->offset + size; + m_NullBlock->prevPhysical = currentBlock; + m_NullBlock->nextPhysical = VMA_NULL; + m_NullBlock->MarkFree(); + m_NullBlock->PrevFree() = VMA_NULL; + m_NullBlock->NextFree() = VMA_NULL; + currentBlock->nextPhysical = m_NullBlock; + currentBlock->MarkTaken(); } } -} - -#if VMA_STATS_STRING_ENABLED - -// Correspond to values of enum VmaSuballocationType. -static const char* VMA_SUBALLOCATION_TYPE_NAMES[] = { - "FREE", - "UNKNOWN", - "BUFFER", - "IMAGE_UNKNOWN", - "IMAGE_LINEAR", - "IMAGE_OPTIMAL", -}; - -void VmaAllocation_T::PrintParameters(class VmaJsonWriter& json) const -{ - json.WriteString("Type"); - json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[m_SuballocationType]); + else + { + VMA_ASSERT(currentBlock->size > size && "Proper block already found, shouldn't find smaller one!"); - json.WriteString("Size"); - json.WriteNumber(m_Size); + // Create new free block + Block* newBlock = m_BlockAllocator.Alloc(); + newBlock->size = currentBlock->size - size; + newBlock->offset = currentBlock->offset + size; + newBlock->prevPhysical = currentBlock; + newBlock->nextPhysical = currentBlock->nextPhysical; + currentBlock->nextPhysical = newBlock; + currentBlock->size = size; - if(m_pUserData != VMA_NULL) - { - json.WriteString("UserData"); - if(IsUserDataString()) + if (currentBlock == m_NullBlock) { - json.WriteString((const char*)m_pUserData); + m_NullBlock = newBlock; + m_NullBlock->MarkFree(); + m_NullBlock->NextFree() = VMA_NULL; + m_NullBlock->PrevFree() = VMA_NULL; + currentBlock->MarkTaken(); } else { - json.BeginString(); - json.ContinueString_Pointer(m_pUserData); - json.EndString(); + newBlock->nextPhysical->prevPhysical = newBlock; + newBlock->MarkTaken(); + InsertFreeBlock(newBlock); } } + currentBlock->UserData() = userData; - json.WriteString("CreationFrameIndex"); - json.WriteNumber(m_CreationFrameIndex); - - json.WriteString("LastUseFrameIndex"); - json.WriteNumber(GetLastUseFrameIndex()); - - if(m_BufferImageUsage != 0) + if (debugMargin > 0) { - json.WriteString("Usage"); - json.WriteNumber(m_BufferImageUsage); + currentBlock->size -= debugMargin; + Block* newBlock = m_BlockAllocator.Alloc(); + newBlock->size = debugMargin; + newBlock->offset = currentBlock->offset + currentBlock->size; + newBlock->prevPhysical = currentBlock; + newBlock->nextPhysical = currentBlock->nextPhysical; + newBlock->MarkTaken(); + currentBlock->nextPhysical->prevPhysical = newBlock; + currentBlock->nextPhysical = newBlock; + InsertFreeBlock(newBlock); } + + if (!IsVirtual()) + m_GranularityHandler.AllocPages((uint8_t)(uintptr_t)request.customData, + currentBlock->offset, currentBlock->size); + ++m_AllocCount; } -#endif - -void VmaAllocation_T::FreeUserDataString(VmaAllocator hAllocator) +void VmaBlockMetadata_TLSF::Free(VmaAllocHandle allocHandle) { - VMA_ASSERT(IsUserDataString()); - VmaFreeString(hAllocator->GetAllocationCallbacks(), (char*)m_pUserData); - m_pUserData = VMA_NULL; -} + Block* block = (Block*)allocHandle; + Block* next = block->nextPhysical; + VMA_ASSERT(!block->IsFree() && "Block is already free!"); -void VmaAllocation_T::BlockAllocMap() -{ - VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); + if (!IsVirtual()) + m_GranularityHandler.FreePages(block->offset, block->size); + --m_AllocCount; - if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) < 0x7F) - { - ++m_MapCount; - } - else + VkDeviceSize debugMargin = GetDebugMargin(); + if (debugMargin > 0) { - VMA_ASSERT(0 && "Allocation mapped too many times simultaneously."); + RemoveFreeBlock(next); + MergeBlock(next, block); + block = next; + next = next->nextPhysical; } -} - -void VmaAllocation_T::BlockAllocUnmap() -{ - VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); - if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) != 0) + // Try merging + Block* prev = block->prevPhysical; + if (prev != VMA_NULL && prev->IsFree() && prev->size != debugMargin) { - --m_MapCount; + RemoveFreeBlock(prev); + MergeBlock(block, prev); } + + if (!next->IsFree()) + InsertFreeBlock(block); + else if (next == m_NullBlock) + MergeBlock(m_NullBlock, block); else { - VMA_ASSERT(0 && "Unmapping allocation not previously mapped."); + RemoveFreeBlock(next); + MergeBlock(next, block); + InsertFreeBlock(next); } } -VkResult VmaAllocation_T::DedicatedAllocMap(VmaAllocator hAllocator, void** ppData) +void VmaBlockMetadata_TLSF::GetAllocationInfo(VmaAllocHandle allocHandle, VmaVirtualAllocationInfo& outInfo) { - VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); + Block* block = (Block*)allocHandle; + VMA_ASSERT(!block->IsFree() && "Cannot get allocation info for free block!"); + outInfo.offset = block->offset; + outInfo.size = block->size; + outInfo.pUserData = block->UserData(); +} - if(m_MapCount != 0) - { - if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) < 0x7F) - { - VMA_ASSERT(m_DedicatedAllocation.m_pMappedData != VMA_NULL); - *ppData = m_DedicatedAllocation.m_pMappedData; - ++m_MapCount; - return VK_SUCCESS; - } - else - { - VMA_ASSERT(0 && "Dedicated allocation mapped too many times simultaneously."); - return VK_ERROR_MEMORY_MAP_FAILED; - } - } - else - { - VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( - hAllocator->m_hDevice, - m_DedicatedAllocation.m_hMemory, - 0, // offset - VK_WHOLE_SIZE, - 0, // flags - ppData); - if(result == VK_SUCCESS) - { - m_DedicatedAllocation.m_pMappedData = *ppData; - m_MapCount = 1; - } - return result; - } +void* VmaBlockMetadata_TLSF::GetAllocationUserData(VmaAllocHandle allocHandle) const +{ + Block* block = (Block*)allocHandle; + VMA_ASSERT(!block->IsFree() && "Cannot get user data for free block!"); + return block->UserData(); } -void VmaAllocation_T::DedicatedAllocUnmap(VmaAllocator hAllocator) +VmaAllocHandle VmaBlockMetadata_TLSF::GetAllocationListBegin() const { - VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); + if (m_AllocCount == 0) + return VK_NULL_HANDLE; - if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) != 0) - { - --m_MapCount; - if(m_MapCount == 0) - { - m_DedicatedAllocation.m_pMappedData = VMA_NULL; - (*hAllocator->GetVulkanFunctions().vkUnmapMemory)( - hAllocator->m_hDevice, - m_DedicatedAllocation.m_hMemory); - } - } - else + for (Block* block = m_NullBlock->prevPhysical; block; block = block->prevPhysical) { - VMA_ASSERT(0 && "Unmapping dedicated allocation not previously mapped."); + if (!block->IsFree()) + return (VmaAllocHandle)block; } + VMA_ASSERT(false && "If m_AllocCount > 0 then should find any allocation!"); + return VK_NULL_HANDLE; } -#if VMA_STATS_STRING_ENABLED - -static void VmaPrintStatInfo(VmaJsonWriter& json, const VmaStatInfo& stat) +VmaAllocHandle VmaBlockMetadata_TLSF::GetNextAllocation(VmaAllocHandle prevAlloc) const { - json.BeginObject(); - - json.WriteString("Blocks"); - json.WriteNumber(stat.blockCount); - - json.WriteString("Allocations"); - json.WriteNumber(stat.allocationCount); - - json.WriteString("UnusedRanges"); - json.WriteNumber(stat.unusedRangeCount); - - json.WriteString("UsedBytes"); - json.WriteNumber(stat.usedBytes); - - json.WriteString("UnusedBytes"); - json.WriteNumber(stat.unusedBytes); + Block* startBlock = (Block*)prevAlloc; + VMA_ASSERT(!startBlock->IsFree() && "Incorrect block!"); - if(stat.allocationCount > 1) + for (Block* block = startBlock->prevPhysical; block; block = block->prevPhysical) { - json.WriteString("AllocationSize"); - json.BeginObject(true); - json.WriteString("Min"); - json.WriteNumber(stat.allocationSizeMin); - json.WriteString("Avg"); - json.WriteNumber(stat.allocationSizeAvg); - json.WriteString("Max"); - json.WriteNumber(stat.allocationSizeMax); - json.EndObject(); + if (!block->IsFree()) + return (VmaAllocHandle)block; } + return VK_NULL_HANDLE; +} - if(stat.unusedRangeCount > 1) - { - json.WriteString("UnusedRangeSize"); - json.BeginObject(true); - json.WriteString("Min"); - json.WriteNumber(stat.unusedRangeSizeMin); - json.WriteString("Avg"); - json.WriteNumber(stat.unusedRangeSizeAvg); - json.WriteString("Max"); - json.WriteNumber(stat.unusedRangeSizeMax); - json.EndObject(); - } +VkDeviceSize VmaBlockMetadata_TLSF::GetNextFreeRegionSize(VmaAllocHandle alloc) const +{ + Block* block = (Block*)alloc; + VMA_ASSERT(!block->IsFree() && "Incorrect block!"); - json.EndObject(); + if (block->prevPhysical) + return block->prevPhysical->IsFree() ? block->prevPhysical->size : 0; + return 0; } -#endif // #if VMA_STATS_STRING_ENABLED - -struct VmaSuballocationItemSizeLess +void VmaBlockMetadata_TLSF::Clear() { - bool operator()( - const VmaSuballocationList::iterator lhs, - const VmaSuballocationList::iterator rhs) const - { - return lhs->size < rhs->size; - } - bool operator()( - const VmaSuballocationList::iterator lhs, - VkDeviceSize rhsSize) const + m_AllocCount = 0; + m_BlocksFreeCount = 0; + m_BlocksFreeSize = 0; + m_IsFreeBitmap = 0; + m_NullBlock->offset = 0; + m_NullBlock->size = GetSize(); + Block* block = m_NullBlock->prevPhysical; + m_NullBlock->prevPhysical = VMA_NULL; + while (block) { - return lhs->size < rhsSize; + Block* prev = block->prevPhysical; + m_BlockAllocator.Free(block); + block = prev; } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// class VmaBlockMetadata - -VmaBlockMetadata::VmaBlockMetadata(VmaAllocator hAllocator) : - m_Size(0), - m_pAllocationCallbacks(hAllocator->GetAllocationCallbacks()) -{ + memset(m_FreeList, 0, m_ListsCount * sizeof(Block*)); + memset(m_InnerIsFreeBitmap, 0, m_MemoryClasses * sizeof(uint32_t)); + m_GranularityHandler.Clear(); } -#if VMA_STATS_STRING_ENABLED - -void VmaBlockMetadata::PrintDetailedMap_Begin(class VmaJsonWriter& json, - VkDeviceSize unusedBytes, - size_t allocationCount, - size_t unusedRangeCount) const +void VmaBlockMetadata_TLSF::SetAllocationUserData(VmaAllocHandle allocHandle, void* userData) { - json.BeginObject(); - - json.WriteString("TotalBytes"); - json.WriteNumber(GetSize()); - - json.WriteString("UnusedBytes"); - json.WriteNumber(unusedBytes); - - json.WriteString("Allocations"); - json.WriteNumber((uint64_t)allocationCount); - - json.WriteString("UnusedRanges"); - json.WriteNumber((uint64_t)unusedRangeCount); - - json.WriteString("Suballocations"); - json.BeginArray(); + Block* block = (Block*)allocHandle; + VMA_ASSERT(!block->IsFree() && "Trying to set user data for not allocated block!"); + block->UserData() = userData; } -void VmaBlockMetadata::PrintDetailedMap_Allocation(class VmaJsonWriter& json, - VkDeviceSize offset, - VmaAllocation hAllocation) const +void VmaBlockMetadata_TLSF::DebugLogAllAllocations() const { - json.BeginObject(true); - - json.WriteString("Offset"); - json.WriteNumber(offset); - - hAllocation->PrintParameters(json); - - json.EndObject(); + for (Block* block = m_NullBlock->prevPhysical; block != VMA_NULL; block = block->prevPhysical) + if (!block->IsFree()) + DebugLogAllocation(block->offset, block->size, block->UserData()); } -void VmaBlockMetadata::PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, - VkDeviceSize offset, - VkDeviceSize size) const +uint8_t VmaBlockMetadata_TLSF::SizeToMemoryClass(VkDeviceSize size) const { - json.BeginObject(true); - - json.WriteString("Offset"); - json.WriteNumber(offset); - - json.WriteString("Type"); - json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[VMA_SUBALLOCATION_TYPE_FREE]); - - json.WriteString("Size"); - json.WriteNumber(size); - - json.EndObject(); + if (size > SMALL_BUFFER_SIZE) + return VMA_BITSCAN_MSB(size) - MEMORY_CLASS_SHIFT; + return 0; } -void VmaBlockMetadata::PrintDetailedMap_End(class VmaJsonWriter& json) const +uint16_t VmaBlockMetadata_TLSF::SizeToSecondIndex(VkDeviceSize size, uint8_t memoryClass) const { - json.EndArray(); - json.EndObject(); + if (memoryClass == 0) + { + if (IsVirtual()) + return static_cast((size - 1) / 8); + else + return static_cast((size - 1) / 64); + } + return static_cast((size >> (memoryClass + MEMORY_CLASS_SHIFT - SECOND_LEVEL_INDEX)) ^ (1U << SECOND_LEVEL_INDEX)); } -#endif // #if VMA_STATS_STRING_ENABLED - -//////////////////////////////////////////////////////////////////////////////// -// class VmaBlockMetadata_Generic - -VmaBlockMetadata_Generic::VmaBlockMetadata_Generic(VmaAllocator hAllocator) : - VmaBlockMetadata(hAllocator), - m_FreeCount(0), - m_SumFreeSize(0), - m_Suballocations(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - m_FreeSuballocationsBySize(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +uint32_t VmaBlockMetadata_TLSF::GetListIndex(uint8_t memoryClass, uint16_t secondIndex) const { -} + if (memoryClass == 0) + return secondIndex; -VmaBlockMetadata_Generic::~VmaBlockMetadata_Generic() -{ + const uint32_t index = static_cast(memoryClass - 1) * (1 << SECOND_LEVEL_INDEX) + secondIndex; + if (IsVirtual()) + return index + (1 << SECOND_LEVEL_INDEX); + else + return index + 4; } -void VmaBlockMetadata_Generic::Init(VkDeviceSize size) +uint32_t VmaBlockMetadata_TLSF::GetListIndex(VkDeviceSize size) const { - VmaBlockMetadata::Init(size); - - m_FreeCount = 1; - m_SumFreeSize = size; - - VmaSuballocation suballoc = {}; - suballoc.offset = 0; - suballoc.size = size; - suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - suballoc.hAllocation = VK_NULL_HANDLE; - - VMA_ASSERT(size > VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER); - m_Suballocations.push_back(suballoc); - VmaSuballocationList::iterator suballocItem = m_Suballocations.end(); - --suballocItem; - m_FreeSuballocationsBySize.push_back(suballocItem); + uint8_t memoryClass = SizeToMemoryClass(size); + return GetListIndex(memoryClass, SizeToSecondIndex(size, memoryClass)); } -bool VmaBlockMetadata_Generic::Validate() const +void VmaBlockMetadata_TLSF::RemoveFreeBlock(Block* block) { - VMA_VALIDATE(!m_Suballocations.empty()); - - // Expected offset of new suballocation as calculated from previous ones. - VkDeviceSize calculatedOffset = 0; - // Expected number of free suballocations as calculated from traversing their list. - uint32_t calculatedFreeCount = 0; - // Expected sum size of free suballocations as calculated from traversing their list. - VkDeviceSize calculatedSumFreeSize = 0; - // Expected number of free suballocations that should be registered in - // m_FreeSuballocationsBySize calculated from traversing their list. - size_t freeSuballocationsToRegister = 0; - // True if previous visited suballocation was free. - bool prevFree = false; + VMA_ASSERT(block != m_NullBlock); + VMA_ASSERT(block->IsFree()); - for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); - suballocItem != m_Suballocations.cend(); - ++suballocItem) + if (block->NextFree() != VMA_NULL) + block->NextFree()->PrevFree() = block->PrevFree(); + if (block->PrevFree() != VMA_NULL) + block->PrevFree()->NextFree() = block->NextFree(); + else { - const VmaSuballocation& subAlloc = *suballocItem; - - // Actual offset of this suballocation doesn't match expected one. - VMA_VALIDATE(subAlloc.offset == calculatedOffset); - - const bool currFree = (subAlloc.type == VMA_SUBALLOCATION_TYPE_FREE); - // Two adjacent free suballocations are invalid. They should be merged. - VMA_VALIDATE(!prevFree || !currFree); - - VMA_VALIDATE(currFree == (subAlloc.hAllocation == VK_NULL_HANDLE)); - - if(currFree) - { - calculatedSumFreeSize += subAlloc.size; - ++calculatedFreeCount; - if(subAlloc.size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) - { - ++freeSuballocationsToRegister; - } - - // Margin required between allocations - every free space must be at least that large. - VMA_VALIDATE(subAlloc.size >= VMA_DEBUG_MARGIN); - } - else + uint8_t memClass = SizeToMemoryClass(block->size); + uint16_t secondIndex = SizeToSecondIndex(block->size, memClass); + uint32_t index = GetListIndex(memClass, secondIndex); + VMA_ASSERT(m_FreeList[index] == block); + m_FreeList[index] = block->NextFree(); + if (block->NextFree() == VMA_NULL) { - VMA_VALIDATE(subAlloc.hAllocation->GetOffset() == subAlloc.offset); - VMA_VALIDATE(subAlloc.hAllocation->GetSize() == subAlloc.size); - - // Margin required between allocations - previous allocation must be free. - VMA_VALIDATE(VMA_DEBUG_MARGIN == 0 || prevFree); + m_InnerIsFreeBitmap[memClass] &= ~(1U << secondIndex); + if (m_InnerIsFreeBitmap[memClass] == 0) + m_IsFreeBitmap &= ~(1UL << memClass); } - - calculatedOffset += subAlloc.size; - prevFree = currFree; } + block->MarkTaken(); + block->UserData() = VMA_NULL; + --m_BlocksFreeCount; + m_BlocksFreeSize -= block->size; +} - // Number of free suballocations registered in m_FreeSuballocationsBySize doesn't - // match expected one. - VMA_VALIDATE(m_FreeSuballocationsBySize.size() == freeSuballocationsToRegister); - - VkDeviceSize lastSize = 0; - for (const auto i : c10::irange(m_FreeSuballocationsBySize.size())) { - VmaSuballocationList::iterator suballocItem = m_FreeSuballocationsBySize[i]; - - // Only free suballocations can be registered in m_FreeSuballocationsBySize. - VMA_VALIDATE(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE); - // They must be sorted by size ascending. - VMA_VALIDATE(suballocItem->size >= lastSize); +void VmaBlockMetadata_TLSF::InsertFreeBlock(Block* block) +{ + VMA_ASSERT(block != m_NullBlock); + VMA_ASSERT(!block->IsFree() && "Cannot insert block twice!"); - lastSize = suballocItem->size; + uint8_t memClass = SizeToMemoryClass(block->size); + uint16_t secondIndex = SizeToSecondIndex(block->size, memClass); + uint32_t index = GetListIndex(memClass, secondIndex); + VMA_ASSERT(index < m_ListsCount); + block->PrevFree() = VMA_NULL; + block->NextFree() = m_FreeList[index]; + m_FreeList[index] = block; + if (block->NextFree() != VMA_NULL) + block->NextFree()->PrevFree() = block; + else + { + m_InnerIsFreeBitmap[memClass] |= 1U << secondIndex; + m_IsFreeBitmap |= 1UL << memClass; } + ++m_BlocksFreeCount; + m_BlocksFreeSize += block->size; +} - // Check if totals match calculacted values. - VMA_VALIDATE(ValidateFreeSuballocationList()); - VMA_VALIDATE(calculatedOffset == GetSize()); - VMA_VALIDATE(calculatedSumFreeSize == m_SumFreeSize); - VMA_VALIDATE(calculatedFreeCount == m_FreeCount); +void VmaBlockMetadata_TLSF::MergeBlock(Block* block, Block* prev) +{ + VMA_ASSERT(block->prevPhysical == prev && "Cannot merge seperate physical regions!"); + VMA_ASSERT(!prev->IsFree() && "Cannot merge block that belongs to free list!"); - return true; + block->offset = prev->offset; + block->size += prev->size; + block->prevPhysical = prev->prevPhysical; + if (block->prevPhysical) + block->prevPhysical->nextPhysical = block; + m_BlockAllocator.Free(prev); } -VkDeviceSize VmaBlockMetadata_Generic::GetUnusedRangeSizeMax() const +VmaBlockMetadata_TLSF::Block* VmaBlockMetadata_TLSF::FindFreeBlock(VkDeviceSize size, uint32_t& listIndex) const { - if(!m_FreeSuballocationsBySize.empty()) - { - return m_FreeSuballocationsBySize.back()->size; - } - else + uint8_t memoryClass = SizeToMemoryClass(size); + uint32_t innerFreeMap = m_InnerIsFreeBitmap[memoryClass] & (~0U << SizeToSecondIndex(size, memoryClass)); + if (!innerFreeMap) { - return 0; - } -} + // Check higher levels for avaiable blocks + uint32_t freeMap = m_IsFreeBitmap & (~0UL << (memoryClass + 1)); + if (!freeMap) + return VMA_NULL; // No more memory avaible -bool VmaBlockMetadata_Generic::IsEmpty() const -{ - return (m_Suballocations.size() == 1) && (m_FreeCount == 1); + // Find lowest free region + memoryClass = VMA_BITSCAN_LSB(freeMap); + innerFreeMap = m_InnerIsFreeBitmap[memoryClass]; + VMA_ASSERT(innerFreeMap != 0); + } + // Find lowest free subregion + listIndex = GetListIndex(memoryClass, VMA_BITSCAN_LSB(innerFreeMap)); + VMA_ASSERT(m_FreeList[listIndex]); + return m_FreeList[listIndex]; } -void VmaBlockMetadata_Generic::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +bool VmaBlockMetadata_TLSF::CheckBlock( + Block& block, + uint32_t listIndex, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaAllocationRequest* pAllocationRequest) { - outInfo.blockCount = 1; + VMA_ASSERT(block.IsFree() && "Block is already taken!"); - const uint32_t rangeCount = (uint32_t)m_Suballocations.size(); - outInfo.allocationCount = rangeCount - m_FreeCount; - outInfo.unusedRangeCount = m_FreeCount; + VkDeviceSize alignedOffset = VmaAlignUp(block.offset, allocAlignment); + if (block.size < allocSize + alignedOffset - block.offset) + return false; - outInfo.unusedBytes = m_SumFreeSize; - outInfo.usedBytes = GetSize() - outInfo.unusedBytes; + // Check for granularity conflicts + if (!IsVirtual() && + m_GranularityHandler.CheckConflictAndAlignUp(alignedOffset, allocSize, block.offset, block.size, allocType)) + return false; - outInfo.allocationSizeMin = UINT64_MAX; - outInfo.allocationSizeMax = 0; - outInfo.unusedRangeSizeMin = UINT64_MAX; - outInfo.unusedRangeSizeMax = 0; + // Alloc successful + pAllocationRequest->type = VmaAllocationRequestType::TLSF; + pAllocationRequest->allocHandle = (VmaAllocHandle)█ + pAllocationRequest->size = allocSize - GetDebugMargin(); + pAllocationRequest->customData = (void*)allocType; + pAllocationRequest->algorithmData = alignedOffset; - for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); - suballocItem != m_Suballocations.cend(); - ++suballocItem) + // Place block at the start of list if it's normal block + if (listIndex != m_ListsCount && block.PrevFree()) { - const VmaSuballocation& suballoc = *suballocItem; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) - { - outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); - outInfo.allocationSizeMax = VMA_MAX(outInfo.allocationSizeMax, suballoc.size); - } - else - { - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, suballoc.size); - outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, suballoc.size); - } + block.PrevFree()->NextFree() = block.NextFree(); + if (block.NextFree()) + block.NextFree()->PrevFree() = block.PrevFree(); + block.PrevFree() = VMA_NULL; + block.NextFree() = m_FreeList[listIndex]; + m_FreeList[listIndex] = █ + if (block.NextFree()) + block.NextFree()->PrevFree() = █ } + + return true; } +#endif // _VMA_BLOCK_METADATA_TLSF_FUNCTIONS +#endif // _VMA_BLOCK_METADATA_TLSF -void VmaBlockMetadata_Generic::AddPoolStats(VmaPoolStats& inoutStats) const -{ - const uint32_t rangeCount = (uint32_t)m_Suballocations.size(); +#ifndef _VMA_BLOCK_VECTOR +/* +Sequence of VmaDeviceMemoryBlock. Represents memory blocks allocated for a specific +Vulkan memory type. - inoutStats.size += GetSize(); - inoutStats.unusedSize += m_SumFreeSize; - inoutStats.allocationCount += rangeCount - m_FreeCount; - inoutStats.unusedRangeCount += m_FreeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, GetUnusedRangeSizeMax()); -} +Synchronized internally with a mutex. +*/ +class VmaBlockVector +{ + friend struct VmaDefragmentationContext_T; + VMA_CLASS_NO_COPY(VmaBlockVector) +public: + VmaBlockVector( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceSize preferredBlockSize, + size_t minBlockCount, + size_t maxBlockCount, + VkDeviceSize bufferImageGranularity, + bool explicitBlockSize, + uint32_t algorithm, + float priority, + VkDeviceSize minAllocationAlignment, + void* pMemoryAllocateNext); + ~VmaBlockVector(); -#if VMA_STATS_STRING_ENABLED + VmaAllocator GetAllocator() const { return m_hAllocator; } + VmaPool GetParentPool() const { return m_hParentPool; } + bool IsCustomPool() const { return m_hParentPool != VMA_NULL; } + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + VkDeviceSize GetPreferredBlockSize() const { return m_PreferredBlockSize; } + VkDeviceSize GetBufferImageGranularity() const { return m_BufferImageGranularity; } + uint32_t GetAlgorithm() const { return m_Algorithm; } + bool HasExplicitBlockSize() const { return m_ExplicitBlockSize; } + float GetPriority() const { return m_Priority; } + const void* GetAllocationNextPtr() const { return m_pMemoryAllocateNext; } + // To be used only while the m_Mutex is locked. Used during defragmentation. + size_t GetBlockCount() const { return m_Blocks.size(); } + // To be used only while the m_Mutex is locked. Used during defragmentation. + VmaDeviceMemoryBlock* GetBlock(size_t index) const { return m_Blocks[index]; } + VMA_RW_MUTEX &GetMutex() { return m_Mutex; } -void VmaBlockMetadata_Generic::PrintDetailedMap(class VmaJsonWriter& json) const -{ - PrintDetailedMap_Begin(json, - m_SumFreeSize, // unusedBytes - m_Suballocations.size() - (size_t)m_FreeCount, // allocationCount - m_FreeCount); // unusedRangeCount + VkResult CreateMinBlocks(); + void AddStatistics(VmaStatistics& inoutStats); + void AddDetailedStatistics(VmaDetailedStatistics& inoutStats); + bool IsEmpty(); + bool IsCorruptionDetectionEnabled() const; - size_t i = 0; - for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); - suballocItem != m_Suballocations.cend(); - ++suballocItem, ++i) - { - if(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) - { - PrintDetailedMap_UnusedRange(json, suballocItem->offset, suballocItem->size); - } - else - { - PrintDetailedMap_Allocation(json, suballocItem->offset, suballocItem->hAllocation); - } - } + VkResult Allocate( + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations); - PrintDetailedMap_End(json); -} + void Free(const VmaAllocation hAllocation); -#endif // #if VMA_STATS_STRING_ENABLED +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json); +#endif -bool VmaBlockMetadata_Generic::CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) -{ - VMA_ASSERT(allocSize > 0); - VMA_ASSERT(!upperAddress); - VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(pAllocationRequest != VMA_NULL); - VMA_HEAVY_ASSERT(Validate()); + VkResult CheckCorruption(); - pAllocationRequest->type = VmaAllocationRequestType::Normal; +private: + const VmaAllocator m_hAllocator; + const VmaPool m_hParentPool; + const uint32_t m_MemoryTypeIndex; + const VkDeviceSize m_PreferredBlockSize; + const size_t m_MinBlockCount; + const size_t m_MaxBlockCount; + const VkDeviceSize m_BufferImageGranularity; + const bool m_ExplicitBlockSize; + const uint32_t m_Algorithm; + const float m_Priority; + const VkDeviceSize m_MinAllocationAlignment; - // There is not enough total free space in this block to fullfill the request: Early return. - if(canMakeOtherLost == false && - m_SumFreeSize < allocSize + 2 * VMA_DEBUG_MARGIN) - { - return false; - } + void* const m_pMemoryAllocateNext; + VMA_RW_MUTEX m_Mutex; + // Incrementally sorted by sumFreeSize, ascending. + VmaVector> m_Blocks; + uint32_t m_NextBlockId; + bool m_IncrementalSort = true; - // New algorithm, efficiently searching freeSuballocationsBySize. - const size_t freeSuballocCount = m_FreeSuballocationsBySize.size(); - if(freeSuballocCount > 0) - { - if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) - { - // Find first free suballocation with size not less than allocSize + 2 * VMA_DEBUG_MARGIN. - VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( - m_FreeSuballocationsBySize.data(), - m_FreeSuballocationsBySize.data() + freeSuballocCount, - allocSize + 2 * VMA_DEBUG_MARGIN, - VmaSuballocationItemSizeLess()); - size_t index = it - m_FreeSuballocationsBySize.data(); - for(; index < freeSuballocCount; ++index) - { - if(CheckAllocation( - currentFrameIndex, - frameInUseCount, - bufferImageGranularity, - allocSize, - allocAlignment, - allocType, - m_FreeSuballocationsBySize[index], - false, // canMakeOtherLost - &pAllocationRequest->offset, - &pAllocationRequest->itemsToMakeLostCount, - &pAllocationRequest->sumFreeSize, - &pAllocationRequest->sumItemSize)) - { - pAllocationRequest->item = m_FreeSuballocationsBySize[index]; - return true; - } - } - } - else if(strategy == VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET) - { - for(VmaSuballocationList::iterator it = m_Suballocations.begin(); - it != m_Suballocations.end(); - ++it) - { - if(it->type == VMA_SUBALLOCATION_TYPE_FREE && CheckAllocation( - currentFrameIndex, - frameInUseCount, - bufferImageGranularity, - allocSize, - allocAlignment, - allocType, - it, - false, // canMakeOtherLost - &pAllocationRequest->offset, - &pAllocationRequest->itemsToMakeLostCount, - &pAllocationRequest->sumFreeSize, - &pAllocationRequest->sumItemSize)) - { - pAllocationRequest->item = it; - return true; - } - } - } - else // WORST_FIT, FIRST_FIT - { - // Search staring from biggest suballocations. - for(size_t index = freeSuballocCount; index--; ) - { - if(CheckAllocation( - currentFrameIndex, - frameInUseCount, - bufferImageGranularity, - allocSize, - allocAlignment, - allocType, - m_FreeSuballocationsBySize[index], - false, // canMakeOtherLost - &pAllocationRequest->offset, - &pAllocationRequest->itemsToMakeLostCount, - &pAllocationRequest->sumFreeSize, - &pAllocationRequest->sumItemSize)) - { - pAllocationRequest->item = m_FreeSuballocationsBySize[index]; - return true; - } - } - } - } + void SetIncrementalSort(bool val) { m_IncrementalSort = val; } - if(canMakeOtherLost) - { - // Brute-force algorithm. TODO: Come up with something better. + VkDeviceSize CalcMaxBlockSize() const; + // Finds and removes given block from vector. + void Remove(VmaDeviceMemoryBlock* pBlock); + // Performs single step in sorting m_Blocks. They may not be fully sorted + // after this call. + void IncrementallySortBlocks(); + void SortByFreeSize(); - bool found = false; - VmaAllocationRequest tmpAllocRequest = {}; - tmpAllocRequest.type = VmaAllocationRequestType::Normal; - for(VmaSuballocationList::iterator suballocIt = m_Suballocations.begin(); - suballocIt != m_Suballocations.end(); - ++suballocIt) - { - if(suballocIt->type == VMA_SUBALLOCATION_TYPE_FREE || - suballocIt->hAllocation->CanBecomeLost()) - { - if(CheckAllocation( - currentFrameIndex, - frameInUseCount, - bufferImageGranularity, - allocSize, - allocAlignment, - allocType, - suballocIt, - canMakeOtherLost, - &tmpAllocRequest.offset, - &tmpAllocRequest.itemsToMakeLostCount, - &tmpAllocRequest.sumFreeSize, - &tmpAllocRequest.sumItemSize)) - { - if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) - { - *pAllocationRequest = tmpAllocRequest; - pAllocationRequest->item = suballocIt; - break; - } - if(!found || tmpAllocRequest.CalcCost() < pAllocationRequest->CalcCost()) - { - *pAllocationRequest = tmpAllocRequest; - pAllocationRequest->item = suballocIt; - found = true; - } - } - } - } + VkResult AllocatePage( + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation); - return found; - } + VkResult AllocateFromBlock( + VmaDeviceMemoryBlock* pBlock, + VkDeviceSize size, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + uint32_t strategy, + VmaAllocation* pAllocation); - return false; -} + VkResult CommitAllocationRequest( + VmaAllocationRequest& allocRequest, + VmaDeviceMemoryBlock* pBlock, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation); -bool VmaBlockMetadata_Generic::MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest) + VkResult CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex); + bool HasEmptyBlock(); +}; +#endif // _VMA_BLOCK_VECTOR + +#ifndef _VMA_DEFRAGMENTATION_CONTEXT +struct VmaDefragmentationContext_T { - VMA_ASSERT(pAllocationRequest && pAllocationRequest->type == VmaAllocationRequestType::Normal); + VMA_CLASS_NO_COPY(VmaDefragmentationContext_T) +public: + VmaDefragmentationContext_T( + VmaAllocator hAllocator, + const VmaDefragmentationInfo& info); + ~VmaDefragmentationContext_T(); - while(pAllocationRequest->itemsToMakeLostCount > 0) - { - if(pAllocationRequest->item->type == VMA_SUBALLOCATION_TYPE_FREE) - { - ++pAllocationRequest->item; - } - VMA_ASSERT(pAllocationRequest->item != m_Suballocations.end()); - VMA_ASSERT(pAllocationRequest->item->hAllocation != VK_NULL_HANDLE); - VMA_ASSERT(pAllocationRequest->item->hAllocation->CanBecomeLost()); - if(pAllocationRequest->item->hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) - { - pAllocationRequest->item = FreeSuballocation(pAllocationRequest->item); - --pAllocationRequest->itemsToMakeLostCount; - } - else - { - return false; - } - } + void GetStats(VmaDefragmentationStats& outStats) { outStats = m_GlobalStats; } - VMA_HEAVY_ASSERT(Validate()); - VMA_ASSERT(pAllocationRequest->item != m_Suballocations.end()); - VMA_ASSERT(pAllocationRequest->item->type == VMA_SUBALLOCATION_TYPE_FREE); + VkResult DefragmentPassBegin(VmaDefragmentationPassMoveInfo& moveInfo); + VkResult DefragmentPassEnd(VmaDefragmentationPassMoveInfo& moveInfo); - return true; -} +private: + // Max number of allocations to ignore due to size constraints before ending single pass + static const uint8_t MAX_ALLOCS_TO_IGNORE = 16; + enum class CounterStatus { Pass, Ignore, End }; -uint32_t VmaBlockMetadata_Generic::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) -{ - uint32_t lostAllocationCount = 0; - for(VmaSuballocationList::iterator it = m_Suballocations.begin(); - it != m_Suballocations.end(); - ++it) + struct FragmentedBlock + { + uint32_t data; + VmaDeviceMemoryBlock* block; + }; + struct StateBalanced + { + VkDeviceSize avgFreeSize = 0; + VkDeviceSize avgAllocSize = UINT64_MAX; + }; + struct StateExtensive { - if(it->type != VMA_SUBALLOCATION_TYPE_FREE && - it->hAllocation->CanBecomeLost() && - it->hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + enum class Operation : uint8_t { - it = FreeSuballocation(it); - ++lostAllocationCount; - } - } - return lostAllocationCount; -} + FindFreeBlockBuffer, FindFreeBlockTexture, FindFreeBlockAll, + MoveBuffers, MoveTextures, MoveAll, + Cleanup, Done + }; -VkResult VmaBlockMetadata_Generic::CheckCorruption(const void* pBlockData) -{ - for(VmaSuballocationList::iterator it = m_Suballocations.begin(); - it != m_Suballocations.end(); - ++it) + Operation operation = Operation::FindFreeBlockTexture; + size_t firstFreeBlock = SIZE_MAX; + }; + struct MoveAllocationData { - if(it->type != VMA_SUBALLOCATION_TYPE_FREE) - { - if(!VmaValidateMagicValue(pBlockData, it->offset - VMA_DEBUG_MARGIN)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } - if(!VmaValidateMagicValue(pBlockData, it->offset + it->size)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } - } - } + VkDeviceSize size; + VkDeviceSize alignment; + VmaSuballocationType type; + VmaAllocationCreateFlags flags; + VmaDefragmentationMove move = {}; + }; - return VK_SUCCESS; -} + const VkDeviceSize m_MaxPassBytes; + const uint32_t m_MaxPassAllocations; + + VmaStlAllocator m_MoveAllocator; + VmaVector> m_Moves; + + uint8_t m_IgnoredAllocs = 0; + uint32_t m_Algorithm; + uint32_t m_BlockVectorCount; + VmaBlockVector* m_PoolBlockVector; + VmaBlockVector** m_pBlockVectors; + size_t m_ImmovableBlockCount = 0; + VmaDefragmentationStats m_GlobalStats = { 0 }; + VmaDefragmentationStats m_PassStats = { 0 }; + void* m_AlgorithmState = VMA_NULL; + + static MoveAllocationData GetMoveData(VmaAllocHandle handle, VmaBlockMetadata* metadata); + CounterStatus CheckCounters(VkDeviceSize bytes); + bool IncrementCounters(VkDeviceSize bytes); + bool ReallocWithinBlock(VmaBlockVector& vector, VmaDeviceMemoryBlock* block); + bool AllocInOtherBlock(size_t start, size_t end, MoveAllocationData& data, VmaBlockVector& vector); + + bool ComputeDefragmentation(VmaBlockVector& vector, size_t index); + bool ComputeDefragmentation_Fast(VmaBlockVector& vector); + bool ComputeDefragmentation_Balanced(VmaBlockVector& vector, size_t index, bool update); + bool ComputeDefragmentation_Full(VmaBlockVector& vector); + bool ComputeDefragmentation_Extensive(VmaBlockVector& vector, size_t index); + + void UpdateVectorStatistics(VmaBlockVector& vector, StateBalanced& state); + bool MoveDataToFreeBlocks(VmaSuballocationType currentType, + VmaBlockVector& vector, size_t firstFreeBlock, + bool& texturePresent, bool& bufferPresent, bool& otherPresent); +}; +#endif // _VMA_DEFRAGMENTATION_CONTEXT -void VmaBlockMetadata_Generic::Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation) +#ifndef _VMA_POOL_T +struct VmaPool_T { - VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); - VMA_ASSERT(request.item != m_Suballocations.end()); - VmaSuballocation& suballoc = *request.item; - // Given suballocation is a free block. - VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); - // Given offset is inside this suballocation. - VMA_ASSERT(request.offset >= suballoc.offset); - const VkDeviceSize paddingBegin = request.offset - suballoc.offset; - VMA_ASSERT(suballoc.size >= paddingBegin + allocSize); - const VkDeviceSize paddingEnd = suballoc.size - paddingBegin - allocSize; - - // Unregister this free suballocation from m_FreeSuballocationsBySize and update - // it to become used. - UnregisterFreeSuballocation(request.item); - - suballoc.offset = request.offset; - suballoc.size = allocSize; - suballoc.type = type; - suballoc.hAllocation = hAllocation; + friend struct VmaPoolListItemTraits; + VMA_CLASS_NO_COPY(VmaPool_T) +public: + VmaBlockVector m_BlockVector; + VmaDedicatedAllocationList m_DedicatedAllocations; - // If there are any free bytes remaining at the end, insert new free suballocation after current one. - if(paddingEnd) - { - VmaSuballocation paddingSuballoc = {}; - paddingSuballoc.offset = request.offset + allocSize; - paddingSuballoc.size = paddingEnd; - paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - VmaSuballocationList::iterator next = request.item; - ++next; - const VmaSuballocationList::iterator paddingEndItem = - m_Suballocations.insert(next, paddingSuballoc); - RegisterFreeSuballocation(paddingEndItem); - } + VmaPool_T( + VmaAllocator hAllocator, + const VmaPoolCreateInfo& createInfo, + VkDeviceSize preferredBlockSize); + ~VmaPool_T(); - // If there are any free bytes remaining at the beginning, insert new free suballocation before current one. - if(paddingBegin) - { - VmaSuballocation paddingSuballoc = {}; - paddingSuballoc.offset = request.offset - paddingBegin; - paddingSuballoc.size = paddingBegin; - paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - const VmaSuballocationList::iterator paddingBeginItem = - m_Suballocations.insert(request.item, paddingSuballoc); - RegisterFreeSuballocation(paddingBeginItem); - } + uint32_t GetId() const { return m_Id; } + void SetId(uint32_t id) { VMA_ASSERT(m_Id == 0); m_Id = id; } - // Update totals. - m_FreeCount = m_FreeCount - 1; - if(paddingBegin > 0) - { - ++m_FreeCount; - } - if(paddingEnd > 0) - { - ++m_FreeCount; - } - m_SumFreeSize -= allocSize; -} + const char* GetName() const { return m_Name; } + void SetName(const char* pName); + +#if VMA_STATS_STRING_ENABLED + //void PrintDetailedMap(class VmaStringBuilder& sb); +#endif + +private: + uint32_t m_Id; + char* m_Name; + VmaPool_T* m_PrevPool = VMA_NULL; + VmaPool_T* m_NextPool = VMA_NULL; +}; -void VmaBlockMetadata_Generic::Free(const VmaAllocation allocation) +struct VmaPoolListItemTraits { - for(VmaSuballocationList::iterator suballocItem = m_Suballocations.begin(); - suballocItem != m_Suballocations.end(); - ++suballocItem) - { - VmaSuballocation& suballoc = *suballocItem; - if(suballoc.hAllocation == allocation) - { - FreeSuballocation(suballocItem); - VMA_HEAVY_ASSERT(Validate()); - return; - } - } - VMA_ASSERT(0 && "Not found!"); -} + typedef VmaPool_T ItemType; + + static ItemType* GetPrev(const ItemType* item) { return item->m_PrevPool; } + static ItemType* GetNext(const ItemType* item) { return item->m_NextPool; } + static ItemType*& AccessPrev(ItemType* item) { return item->m_PrevPool; } + static ItemType*& AccessNext(ItemType* item) { return item->m_NextPool; } +}; +#endif // _VMA_POOL_T + +#ifndef _VMA_CURRENT_BUDGET_DATA +struct VmaCurrentBudgetData +{ + VMA_ATOMIC_UINT32 m_BlockCount[VK_MAX_MEMORY_HEAPS]; + VMA_ATOMIC_UINT32 m_AllocationCount[VK_MAX_MEMORY_HEAPS]; + VMA_ATOMIC_UINT64 m_BlockBytes[VK_MAX_MEMORY_HEAPS]; + VMA_ATOMIC_UINT64 m_AllocationBytes[VK_MAX_MEMORY_HEAPS]; + +#if VMA_MEMORY_BUDGET + VMA_ATOMIC_UINT32 m_OperationsSinceBudgetFetch; + VMA_RW_MUTEX m_BudgetMutex; + uint64_t m_VulkanUsage[VK_MAX_MEMORY_HEAPS]; + uint64_t m_VulkanBudget[VK_MAX_MEMORY_HEAPS]; + uint64_t m_BlockBytesAtBudgetFetch[VK_MAX_MEMORY_HEAPS]; +#endif // VMA_MEMORY_BUDGET + + VmaCurrentBudgetData(); + + void AddAllocation(uint32_t heapIndex, VkDeviceSize allocationSize); + void RemoveAllocation(uint32_t heapIndex, VkDeviceSize allocationSize); +}; -void VmaBlockMetadata_Generic::FreeAtOffset(VkDeviceSize offset) +#ifndef _VMA_CURRENT_BUDGET_DATA_FUNCTIONS +VmaCurrentBudgetData::VmaCurrentBudgetData() { - for(VmaSuballocationList::iterator suballocItem = m_Suballocations.begin(); - suballocItem != m_Suballocations.end(); - ++suballocItem) + for (uint32_t heapIndex = 0; heapIndex < VK_MAX_MEMORY_HEAPS; ++heapIndex) { - VmaSuballocation& suballoc = *suballocItem; - if(suballoc.offset == offset) - { - FreeSuballocation(suballocItem); - return; - } + m_BlockCount[heapIndex] = 0; + m_AllocationCount[heapIndex] = 0; + m_BlockBytes[heapIndex] = 0; + m_AllocationBytes[heapIndex] = 0; +#if VMA_MEMORY_BUDGET + m_VulkanUsage[heapIndex] = 0; + m_VulkanBudget[heapIndex] = 0; + m_BlockBytesAtBudgetFetch[heapIndex] = 0; +#endif } - VMA_ASSERT(0 && "Not found!"); + +#if VMA_MEMORY_BUDGET + m_OperationsSinceBudgetFetch = 0; +#endif } -bool VmaBlockMetadata_Generic::ValidateFreeSuballocationList() const +void VmaCurrentBudgetData::AddAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) { - VkDeviceSize lastSize = 0; - for(size_t i = 0, count = m_FreeSuballocationsBySize.size(); i < count; ++i) - { - const VmaSuballocationList::iterator it = m_FreeSuballocationsBySize[i]; + m_AllocationBytes[heapIndex] += allocationSize; + ++m_AllocationCount[heapIndex]; +#if VMA_MEMORY_BUDGET + ++m_OperationsSinceBudgetFetch; +#endif +} - VMA_VALIDATE(it->type == VMA_SUBALLOCATION_TYPE_FREE); - VMA_VALIDATE(it->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER); - VMA_VALIDATE(it->size >= lastSize); - lastSize = it->size; - } - return true; +void VmaCurrentBudgetData::RemoveAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) +{ + VMA_ASSERT(m_AllocationBytes[heapIndex] >= allocationSize); + m_AllocationBytes[heapIndex] -= allocationSize; + VMA_ASSERT(m_AllocationCount[heapIndex] > 0); + --m_AllocationCount[heapIndex]; +#if VMA_MEMORY_BUDGET + ++m_OperationsSinceBudgetFetch; +#endif } +#endif // _VMA_CURRENT_BUDGET_DATA_FUNCTIONS +#endif // _VMA_CURRENT_BUDGET_DATA -bool VmaBlockMetadata_Generic::CheckAllocation( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - VmaSuballocationList::const_iterator suballocItem, - bool canMakeOtherLost, - VkDeviceSize* pOffset, - size_t* itemsToMakeLostCount, - VkDeviceSize* pSumFreeSize, - VkDeviceSize* pSumItemSize) const +#ifndef _VMA_ALLOCATION_OBJECT_ALLOCATOR +/* +Thread-safe wrapper over VmaPoolAllocator free list, for allocation of VmaAllocation_T objects. +*/ +class VmaAllocationObjectAllocator { - VMA_ASSERT(allocSize > 0); - VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(suballocItem != m_Suballocations.cend()); - VMA_ASSERT(pOffset != VMA_NULL); + VMA_CLASS_NO_COPY(VmaAllocationObjectAllocator) +public: + VmaAllocationObjectAllocator(const VkAllocationCallbacks* pAllocationCallbacks) + : m_Allocator(pAllocationCallbacks, 1024) {} - *itemsToMakeLostCount = 0; - *pSumFreeSize = 0; - *pSumItemSize = 0; + template VmaAllocation Allocate(Types&&... args); + void Free(VmaAllocation hAlloc); - if(canMakeOtherLost) - { - if(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) - { - *pSumFreeSize = suballocItem->size; - } - else - { - if(suballocItem->hAllocation->CanBecomeLost() && - suballocItem->hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) - { - ++*itemsToMakeLostCount; - *pSumItemSize = suballocItem->size; - } - else - { - return false; - } - } +private: + VMA_MUTEX m_Mutex; + VmaPoolAllocator m_Allocator; +}; - // Remaining size is too small for this request: Early return. - if(GetSize() - suballocItem->offset < allocSize) - { - return false; - } +template +VmaAllocation VmaAllocationObjectAllocator::Allocate(Types&&... args) +{ + VmaMutexLock mutexLock(m_Mutex); + return m_Allocator.Alloc(std::forward(args)...); +} - // Start from offset equal to beginning of this suballocation. - *pOffset = suballocItem->offset; +void VmaAllocationObjectAllocator::Free(VmaAllocation hAlloc) +{ + VmaMutexLock mutexLock(m_Mutex); + m_Allocator.Free(hAlloc); +} +#endif // _VMA_ALLOCATION_OBJECT_ALLOCATOR - // Apply VMA_DEBUG_MARGIN at the beginning. - if(VMA_DEBUG_MARGIN > 0) - { - *pOffset += VMA_DEBUG_MARGIN; - } +#ifndef _VMA_VIRTUAL_BLOCK_T +struct VmaVirtualBlock_T +{ + VMA_CLASS_NO_COPY(VmaVirtualBlock_T) +public: + const bool m_AllocationCallbacksSpecified; + const VkAllocationCallbacks m_AllocationCallbacks; + + VmaVirtualBlock_T(const VmaVirtualBlockCreateInfo& createInfo); + ~VmaVirtualBlock_T(); + + VkResult Init() { return VK_SUCCESS; } + bool IsEmpty() const { return m_Metadata->IsEmpty(); } + void Free(VmaVirtualAllocation allocation) { m_Metadata->Free((VmaAllocHandle)allocation); } + void SetAllocationUserData(VmaVirtualAllocation allocation, void* userData) { m_Metadata->SetAllocationUserData((VmaAllocHandle)allocation, userData); } + void Clear() { m_Metadata->Clear(); } + + const VkAllocationCallbacks* GetAllocationCallbacks() const; + void GetAllocationInfo(VmaVirtualAllocation allocation, VmaVirtualAllocationInfo& outInfo); + VkResult Allocate(const VmaVirtualAllocationCreateInfo& createInfo, VmaVirtualAllocation& outAllocation, + VkDeviceSize* outOffset); + void GetStatistics(VmaStatistics& outStats) const; + void CalculateDetailedStatistics(VmaDetailedStatistics& outStats) const; +#if VMA_STATS_STRING_ENABLED + void BuildStatsString(bool detailedMap, VmaStringBuilder& sb) const; +#endif - // Apply alignment. - *pOffset = VmaAlignUp(*pOffset, allocAlignment); +private: + VmaBlockMetadata* m_Metadata; +}; - // Check previous suballocations for BufferImageGranularity conflicts. - // Make bigger alignment if necessary. - if(bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment) - { - bool bufferImageGranularityConflict = false; - VmaSuballocationList::const_iterator prevSuballocItem = suballocItem; - while(prevSuballocItem != m_Suballocations.cbegin()) - { - --prevSuballocItem; - const VmaSuballocation& prevSuballoc = *prevSuballocItem; - if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, *pOffset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) - { - bufferImageGranularityConflict = true; - break; - } - } - else - // Already on previous page. - break; - } - if(bufferImageGranularityConflict) - { - *pOffset = VmaAlignUp(*pOffset, bufferImageGranularity); - } - } +#ifndef _VMA_VIRTUAL_BLOCK_T_FUNCTIONS +VmaVirtualBlock_T::VmaVirtualBlock_T(const VmaVirtualBlockCreateInfo& createInfo) + : m_AllocationCallbacksSpecified(createInfo.pAllocationCallbacks != VMA_NULL), + m_AllocationCallbacks(createInfo.pAllocationCallbacks != VMA_NULL ? *createInfo.pAllocationCallbacks : VmaEmptyAllocationCallbacks) +{ + const uint32_t algorithm = createInfo.flags & VMA_VIRTUAL_BLOCK_CREATE_ALGORITHM_MASK; + switch (algorithm) + { + default: + VMA_ASSERT(0); + case 0: + m_Metadata = vma_new(GetAllocationCallbacks(), VmaBlockMetadata_TLSF)(VK_NULL_HANDLE, 1, true); + break; + case VMA_VIRTUAL_BLOCK_CREATE_LINEAR_ALGORITHM_BIT: + m_Metadata = vma_new(GetAllocationCallbacks(), VmaBlockMetadata_Linear)(VK_NULL_HANDLE, 1, true); + break; + } - // Now that we have final *pOffset, check if we are past suballocItem. - // If yes, return false - this function should be called for another suballocItem as starting point. - if(*pOffset >= suballocItem->offset + suballocItem->size) - { - return false; - } + m_Metadata->Init(createInfo.size); +} - // Calculate padding at the beginning based on current offset. - const VkDeviceSize paddingBegin = *pOffset - suballocItem->offset; +VmaVirtualBlock_T::~VmaVirtualBlock_T() +{ + // Define macro VMA_DEBUG_LOG to receive the list of the unfreed allocations + if (!m_Metadata->IsEmpty()) + m_Metadata->DebugLogAllAllocations(); + // This is the most important assert in the entire library. + // Hitting it means you have some memory leak - unreleased virtual allocations. + VMA_ASSERT(m_Metadata->IsEmpty() && "Some virtual allocations were not freed before destruction of this virtual block!"); - // Calculate required margin at the end. - const VkDeviceSize requiredEndMargin = VMA_DEBUG_MARGIN; + vma_delete(GetAllocationCallbacks(), m_Metadata); +} - const VkDeviceSize totalSize = paddingBegin + allocSize + requiredEndMargin; - // Another early return check. - if(suballocItem->offset + totalSize > GetSize()) - { - return false; - } +const VkAllocationCallbacks* VmaVirtualBlock_T::GetAllocationCallbacks() const +{ + return m_AllocationCallbacksSpecified ? &m_AllocationCallbacks : VMA_NULL; +} - // Advance lastSuballocItem until desired size is reached. - // Update itemsToMakeLostCount. - VmaSuballocationList::const_iterator lastSuballocItem = suballocItem; - if(totalSize > suballocItem->size) - { - VkDeviceSize remainingSize = totalSize - suballocItem->size; - while(remainingSize > 0) - { - ++lastSuballocItem; - if(lastSuballocItem == m_Suballocations.cend()) - { - return false; - } - if(lastSuballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) - { - *pSumFreeSize += lastSuballocItem->size; - } - else - { - VMA_ASSERT(lastSuballocItem->hAllocation != VK_NULL_HANDLE); - if(lastSuballocItem->hAllocation->CanBecomeLost() && - lastSuballocItem->hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) - { - ++*itemsToMakeLostCount; - *pSumItemSize += lastSuballocItem->size; - } - else - { - return false; - } - } - remainingSize = (lastSuballocItem->size < remainingSize) ? - remainingSize - lastSuballocItem->size : 0; - } - } +void VmaVirtualBlock_T::GetAllocationInfo(VmaVirtualAllocation allocation, VmaVirtualAllocationInfo& outInfo) +{ + m_Metadata->GetAllocationInfo((VmaAllocHandle)allocation, outInfo); +} - // Check next suballocations for BufferImageGranularity conflicts. - // If conflict exists, we must mark more allocations lost or fail. - if(allocSize % bufferImageGranularity || *pOffset % bufferImageGranularity) - { - VmaSuballocationList::const_iterator nextSuballocItem = lastSuballocItem; - ++nextSuballocItem; - while(nextSuballocItem != m_Suballocations.cend()) - { - const VmaSuballocation& nextSuballoc = *nextSuballocItem; - if(VmaBlocksOnSamePage(*pOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) - { - VMA_ASSERT(nextSuballoc.hAllocation != VK_NULL_HANDLE); - if(nextSuballoc.hAllocation->CanBecomeLost() && - nextSuballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) - { - ++*itemsToMakeLostCount; - } - else - { - return false; - } - } - } - else - { - // Already on next page. - break; - } - ++nextSuballocItem; - } - } - } - else +VkResult VmaVirtualBlock_T::Allocate(const VmaVirtualAllocationCreateInfo& createInfo, VmaVirtualAllocation& outAllocation, + VkDeviceSize* outOffset) +{ + VmaAllocationRequest request = {}; + if (m_Metadata->CreateAllocationRequest( + createInfo.size, // allocSize + VMA_MAX(createInfo.alignment, (VkDeviceSize)1), // allocAlignment + (createInfo.flags & VMA_VIRTUAL_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0, // upperAddress + VMA_SUBALLOCATION_TYPE_UNKNOWN, // allocType - unimportant + createInfo.flags & VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MASK, // strategy + &request)) { - const VmaSuballocation& suballoc = *suballocItem; - VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + m_Metadata->Alloc(request, + VMA_SUBALLOCATION_TYPE_UNKNOWN, // type - unimportant + createInfo.pUserData); + outAllocation = (VmaVirtualAllocation)request.allocHandle; + if(outOffset) + *outOffset = m_Metadata->GetAllocationOffset(request.allocHandle); + return VK_SUCCESS; + } + outAllocation = (VmaVirtualAllocation)VK_NULL_HANDLE; + if (outOffset) + *outOffset = UINT64_MAX; + return VK_ERROR_OUT_OF_DEVICE_MEMORY; +} - *pSumFreeSize = suballoc.size; +void VmaVirtualBlock_T::GetStatistics(VmaStatistics& outStats) const +{ + VmaClearStatistics(outStats); + m_Metadata->AddStatistics(outStats); +} - // Size of this suballocation is too small for this request: Early return. - if(suballoc.size < allocSize) - { - return false; - } +void VmaVirtualBlock_T::CalculateDetailedStatistics(VmaDetailedStatistics& outStats) const +{ + VmaClearDetailedStatistics(outStats); + m_Metadata->AddDetailedStatistics(outStats); +} - // Start from offset equal to beginning of this suballocation. - *pOffset = suballoc.offset; +#if VMA_STATS_STRING_ENABLED +void VmaVirtualBlock_T::BuildStatsString(bool detailedMap, VmaStringBuilder& sb) const +{ + VmaJsonWriter json(GetAllocationCallbacks(), sb); + json.BeginObject(); - // Apply VMA_DEBUG_MARGIN at the beginning. - if(VMA_DEBUG_MARGIN > 0) - { - *pOffset += VMA_DEBUG_MARGIN; - } + VmaDetailedStatistics stats; + CalculateDetailedStatistics(stats); - // Apply alignment. - *pOffset = VmaAlignUp(*pOffset, allocAlignment); + json.WriteString("Stats"); + VmaPrintDetailedStatistics(json, stats); - // Check previous suballocations for BufferImageGranularity conflicts. - // Make bigger alignment if necessary. - if(bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment) - { - bool bufferImageGranularityConflict = false; - VmaSuballocationList::const_iterator prevSuballocItem = suballocItem; - while(prevSuballocItem != m_Suballocations.cbegin()) - { - --prevSuballocItem; - const VmaSuballocation& prevSuballoc = *prevSuballocItem; - if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, *pOffset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) - { - bufferImageGranularityConflict = true; - break; - } - } - else - // Already on previous page. - break; - } - if(bufferImageGranularityConflict) - { - *pOffset = VmaAlignUp(*pOffset, bufferImageGranularity); - } - } - - // Calculate padding at the beginning based on current offset. - const VkDeviceSize paddingBegin = *pOffset - suballoc.offset; - - // Calculate required margin at the end. - const VkDeviceSize requiredEndMargin = VMA_DEBUG_MARGIN; - - // Fail if requested size plus margin before and after is bigger than size of this suballocation. - if(paddingBegin + allocSize + requiredEndMargin > suballoc.size) - { - return false; - } - - // Check next suballocations for BufferImageGranularity conflicts. - // If conflict exists, allocation cannot be made here. - if(allocSize % bufferImageGranularity || *pOffset % bufferImageGranularity) - { - VmaSuballocationList::const_iterator nextSuballocItem = suballocItem; - ++nextSuballocItem; - while(nextSuballocItem != m_Suballocations.cend()) - { - const VmaSuballocation& nextSuballoc = *nextSuballocItem; - if(VmaBlocksOnSamePage(*pOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) - { - return false; - } - } - else - { - // Already on next page. - break; - } - ++nextSuballocItem; - } - } + if (detailedMap) + { + json.WriteString("Details"); + json.BeginObject(); + m_Metadata->PrintDetailedMap(json); + json.EndObject(); } - // All tests passed: Success. pOffset is already filled. - return true; + json.EndObject(); } +#endif // VMA_STATS_STRING_ENABLED +#endif // _VMA_VIRTUAL_BLOCK_T_FUNCTIONS +#endif // _VMA_VIRTUAL_BLOCK_T -void VmaBlockMetadata_Generic::MergeFreeWithNext(VmaSuballocationList::iterator item) + +// Main allocator object. +struct VmaAllocator_T { - VMA_ASSERT(item != m_Suballocations.end()); - VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_CLASS_NO_COPY(VmaAllocator_T) +public: + bool m_UseMutex; + uint32_t m_VulkanApiVersion; + bool m_UseKhrDedicatedAllocation; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). + bool m_UseKhrBindMemory2; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). + bool m_UseExtMemoryBudget; + bool m_UseAmdDeviceCoherentMemory; + bool m_UseKhrBufferDeviceAddress; + bool m_UseExtMemoryPriority; + VkDevice m_hDevice; + VkInstance m_hInstance; + bool m_AllocationCallbacksSpecified; + VkAllocationCallbacks m_AllocationCallbacks; + VmaDeviceMemoryCallbacks m_DeviceMemoryCallbacks; + VmaAllocationObjectAllocator m_AllocationObjectAllocator; - VmaSuballocationList::iterator nextItem = item; - ++nextItem; - VMA_ASSERT(nextItem != m_Suballocations.end()); - VMA_ASSERT(nextItem->type == VMA_SUBALLOCATION_TYPE_FREE); + // Each bit (1 << i) is set if HeapSizeLimit is enabled for that heap, so cannot allocate more than the heap size. + uint32_t m_HeapSizeLimitMask; - item->size += nextItem->size; - --m_FreeCount; - m_Suballocations.erase(nextItem); -} + VkPhysicalDeviceProperties m_PhysicalDeviceProperties; + VkPhysicalDeviceMemoryProperties m_MemProps; -VmaSuballocationList::iterator VmaBlockMetadata_Generic::FreeSuballocation(VmaSuballocationList::iterator suballocItem) -{ - // Change this suballocation to be marked as free. - VmaSuballocation& suballoc = *suballocItem; - suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - suballoc.hAllocation = VK_NULL_HANDLE; + // Default pools. + VmaBlockVector* m_pBlockVectors[VK_MAX_MEMORY_TYPES]; + VmaDedicatedAllocationList m_DedicatedAllocations[VK_MAX_MEMORY_TYPES]; - // Update totals. - ++m_FreeCount; - m_SumFreeSize += suballoc.size; + VmaCurrentBudgetData m_Budget; + VMA_ATOMIC_UINT32 m_DeviceMemoryCount; // Total number of VkDeviceMemory objects. - // Merge with previous and/or next suballocation if it's also free. - bool mergeWithNext = false; - bool mergeWithPrev = false; + VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo); + VkResult Init(const VmaAllocatorCreateInfo* pCreateInfo); + ~VmaAllocator_T(); - VmaSuballocationList::iterator nextItem = suballocItem; - ++nextItem; - if((nextItem != m_Suballocations.end()) && (nextItem->type == VMA_SUBALLOCATION_TYPE_FREE)) + const VkAllocationCallbacks* GetAllocationCallbacks() const { - mergeWithNext = true; + return m_AllocationCallbacksSpecified ? &m_AllocationCallbacks : VMA_NULL; } - - VmaSuballocationList::iterator prevItem = suballocItem; - if(suballocItem != m_Suballocations.begin()) + const VmaVulkanFunctions& GetVulkanFunctions() const { - --prevItem; - if(prevItem->type == VMA_SUBALLOCATION_TYPE_FREE) - { - mergeWithPrev = true; - } + return m_VulkanFunctions; } - if(mergeWithNext) - { - UnregisterFreeSuballocation(nextItem); - MergeFreeWithNext(suballocItem); - } + VkPhysicalDevice GetPhysicalDevice() const { return m_PhysicalDevice; } - if(mergeWithPrev) - { - UnregisterFreeSuballocation(prevItem); - MergeFreeWithNext(prevItem); - RegisterFreeSuballocation(prevItem); - return prevItem; - } - else + VkDeviceSize GetBufferImageGranularity() const { - RegisterFreeSuballocation(suballocItem); - return suballocItem; + return VMA_MAX( + static_cast(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY), + m_PhysicalDeviceProperties.limits.bufferImageGranularity); } -} - -void VmaBlockMetadata_Generic::RegisterFreeSuballocation(VmaSuballocationList::iterator item) -{ - VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(item->size > 0); - // You may want to enable this validation at the beginning or at the end of - // this function, depending on what do you want to check. - VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); + uint32_t GetMemoryHeapCount() const { return m_MemProps.memoryHeapCount; } + uint32_t GetMemoryTypeCount() const { return m_MemProps.memoryTypeCount; } - if(item->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + uint32_t MemoryTypeIndexToHeapIndex(uint32_t memTypeIndex) const { - if(m_FreeSuballocationsBySize.empty()) - { - m_FreeSuballocationsBySize.push_back(item); - } - else - { - VmaVectorInsertSorted(m_FreeSuballocationsBySize, item); - } + VMA_ASSERT(memTypeIndex < m_MemProps.memoryTypeCount); + return m_MemProps.memoryTypes[memTypeIndex].heapIndex; } - - //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); -} - - -void VmaBlockMetadata_Generic::UnregisterFreeSuballocation(VmaSuballocationList::iterator item) -{ - VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(item->size > 0); - - // You may want to enable this validation at the beginning or at the end of - // this function, depending on what do you want to check. - VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); - - if(item->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + // True when specific memory type is HOST_VISIBLE but not HOST_COHERENT. + bool IsMemoryTypeNonCoherent(uint32_t memTypeIndex) const { - VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( - m_FreeSuballocationsBySize.data(), - m_FreeSuballocationsBySize.data() + m_FreeSuballocationsBySize.size(), - item, - VmaSuballocationItemSizeLess()); - for(size_t index = it - m_FreeSuballocationsBySize.data(); - index < m_FreeSuballocationsBySize.size(); - ++index) - { - if(m_FreeSuballocationsBySize[index] == item) - { - VmaVectorRemove(m_FreeSuballocationsBySize, index); - return; - } - VMA_ASSERT((m_FreeSuballocationsBySize[index]->size == item->size) && "Not found."); - } - VMA_ASSERT(0 && "Not found."); + return (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & (VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT)) == + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; } - - //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); -} - -bool VmaBlockMetadata_Generic::IsBufferImageGranularityConflictPossible( - VkDeviceSize bufferImageGranularity, - VmaSuballocationType& inOutPrevSuballocType) const -{ - if(bufferImageGranularity == 1 || IsEmpty()) + // Minimum alignment for all allocations in specific memory type. + VkDeviceSize GetMemoryTypeMinAlignment(uint32_t memTypeIndex) const { - return false; + return IsMemoryTypeNonCoherent(memTypeIndex) ? + VMA_MAX((VkDeviceSize)VMA_MIN_ALIGNMENT, m_PhysicalDeviceProperties.limits.nonCoherentAtomSize) : + (VkDeviceSize)VMA_MIN_ALIGNMENT; } - VkDeviceSize minAlignment = VK_WHOLE_SIZE; - bool typeConflictFound = false; - for(VmaSuballocationList::const_iterator it = m_Suballocations.cbegin(); - it != m_Suballocations.cend(); - ++it) + bool IsIntegratedGpu() const { - const VmaSuballocationType suballocType = it->type; - if(suballocType != VMA_SUBALLOCATION_TYPE_FREE) - { - minAlignment = VMA_MIN(minAlignment, it->hAllocation->GetAlignment()); - if(VmaIsBufferImageGranularityConflict(inOutPrevSuballocType, suballocType)) - { - typeConflictFound = true; - } - inOutPrevSuballocType = suballocType; - } + return m_PhysicalDeviceProperties.deviceType == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU; } - return typeConflictFound || minAlignment >= bufferImageGranularity; -} - -//////////////////////////////////////////////////////////////////////////////// -// class VmaBlockMetadata_Linear + uint32_t GetGlobalMemoryTypeBits() const { return m_GlobalMemoryTypeBits; } -VmaBlockMetadata_Linear::VmaBlockMetadata_Linear(VmaAllocator hAllocator) : - VmaBlockMetadata(hAllocator), - m_SumFreeSize(0), - m_Suballocations0(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - m_Suballocations1(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - m_1stVectorIndex(0), - m_2ndVectorMode(SECOND_VECTOR_EMPTY), - m_1stNullItemsBeginCount(0), - m_1stNullItemsMiddleCount(0), - m_2ndNullItemsCount(0) -{ -} + void GetBufferMemoryRequirements( + VkBuffer hBuffer, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const; + void GetImageMemoryRequirements( + VkImage hImage, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const; + VkResult FindMemoryTypeIndex( + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkFlags bufImgUsage, // VkBufferCreateInfo::usage or VkImageCreateInfo::usage. UINT32_MAX if unknown. + uint32_t* pMemoryTypeIndex) const; -VmaBlockMetadata_Linear::~VmaBlockMetadata_Linear() -{ -} + // Main allocation function. + VkResult AllocateMemory( + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, // UINT32_MAX if unknown. + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations); -void VmaBlockMetadata_Linear::Init(VkDeviceSize size) -{ - VmaBlockMetadata::Init(size); - m_SumFreeSize = size; -} + // Main deallocation function. + void FreeMemory( + size_t allocationCount, + const VmaAllocation* pAllocations); -bool VmaBlockMetadata_Linear::Validate() const -{ - const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + void CalculateStatistics(VmaTotalStatistics* pStats); - VMA_VALIDATE(suballocations2nd.empty() == (m_2ndVectorMode == SECOND_VECTOR_EMPTY)); - VMA_VALIDATE(!suballocations1st.empty() || - suballocations2nd.empty() || - m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER); + void GetHeapBudgets( + VmaBudget* outBudgets, uint32_t firstHeap, uint32_t heapCount); - if(!suballocations1st.empty()) - { - // Null item at the beginning should be accounted into m_1stNullItemsBeginCount. - VMA_VALIDATE(suballocations1st[m_1stNullItemsBeginCount].hAllocation != VK_NULL_HANDLE); - // Null item at the end should be just pop_back(). - VMA_VALIDATE(suballocations1st.back().hAllocation != VK_NULL_HANDLE); - } - if(!suballocations2nd.empty()) - { - // Null item at the end should be just pop_back(). - VMA_VALIDATE(suballocations2nd.back().hAllocation != VK_NULL_HANDLE); - } +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json); +#endif - VMA_VALIDATE(m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount <= suballocations1st.size()); - VMA_VALIDATE(m_2ndNullItemsCount <= suballocations2nd.size()); + void GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo); - VkDeviceSize sumUsedSize = 0; - const size_t suballoc1stCount = suballocations1st.size(); - VkDeviceSize offset = VMA_DEBUG_MARGIN; + VkResult CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool); + void DestroyPool(VmaPool pool); + void GetPoolStatistics(VmaPool pool, VmaStatistics* pPoolStats); + void CalculatePoolStatistics(VmaPool pool, VmaDetailedStatistics* pPoolStats); - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) - { - const size_t suballoc2ndCount = suballocations2nd.size(); - size_t nullItem2ndCount = 0; - for (const auto i : c10::irange(suballoc2ndCount)) { - const VmaSuballocation& suballoc = suballocations2nd[i]; - const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + void SetCurrentFrameIndex(uint32_t frameIndex); + uint32_t GetCurrentFrameIndex() const { return m_CurrentFrameIndex.load(); } - VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); - VMA_VALIDATE(suballoc.offset >= offset); + VkResult CheckPoolCorruption(VmaPool hPool); + VkResult CheckCorruption(uint32_t memoryTypeBits); - if(!currFree) - { - VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); - VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); - sumUsedSize += suballoc.size; - } - else - { - ++nullItem2ndCount; - } + // Call to Vulkan function vkAllocateMemory with accompanying bookkeeping. + VkResult AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory); + // Call to Vulkan function vkFreeMemory with accompanying bookkeeping. + void FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory); + // Call to Vulkan function vkBindBufferMemory or vkBindBufferMemory2KHR. + VkResult BindVulkanBuffer( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkBuffer buffer, + const void* pNext); + // Call to Vulkan function vkBindImageMemory or vkBindImageMemory2KHR. + VkResult BindVulkanImage( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkImage image, + const void* pNext); - offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; - } + VkResult Map(VmaAllocation hAllocation, void** ppData); + void Unmap(VmaAllocation hAllocation); - VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); - } + VkResult BindBufferMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext); + VkResult BindImageMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext); - for (const auto i : c10::irange(m_1stNullItemsBeginCount)) { - const VmaSuballocation& suballoc = suballocations1st[i]; - VMA_VALIDATE(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE && - suballoc.hAllocation == VK_NULL_HANDLE); + VkResult FlushOrInvalidateAllocation( + VmaAllocation hAllocation, + VkDeviceSize offset, VkDeviceSize size, + VMA_CACHE_OPERATION op); + VkResult FlushOrInvalidateAllocations( + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, const VkDeviceSize* sizes, + VMA_CACHE_OPERATION op); + + void FillAllocation(const VmaAllocation hAllocation, uint8_t pattern); + + /* + Returns bit mask of memory types that can support defragmentation on GPU as + they support creation of required buffer for copy operations. + */ + uint32_t GetGpuDefragmentationMemoryTypeBits(); + +#if VMA_EXTERNAL_MEMORY + VkExternalMemoryHandleTypeFlagsKHR GetExternalMemoryHandleTypeFlags(uint32_t memTypeIndex) const + { + return m_TypeExternalMemoryHandleTypes[memTypeIndex]; } +#endif // #if VMA_EXTERNAL_MEMORY - size_t nullItem1stCount = m_1stNullItemsBeginCount; +private: + VkDeviceSize m_PreferredLargeHeapBlockSize; - for (const auto i : c10::irange(m_1stNullItemsBeginCount, suballoc1stCount)) { - const VmaSuballocation& suballoc = suballocations1st[i]; - const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + VkPhysicalDevice m_PhysicalDevice; + VMA_ATOMIC_UINT32 m_CurrentFrameIndex; + VMA_ATOMIC_UINT32 m_GpuDefragmentationMemoryTypeBits; // UINT32_MAX means uninitialized. +#if VMA_EXTERNAL_MEMORY + VkExternalMemoryHandleTypeFlagsKHR m_TypeExternalMemoryHandleTypes[VK_MAX_MEMORY_TYPES]; +#endif // #if VMA_EXTERNAL_MEMORY - VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); - VMA_VALIDATE(suballoc.offset >= offset); - VMA_VALIDATE(i >= m_1stNullItemsBeginCount || currFree); + VMA_RW_MUTEX m_PoolsMutex; + typedef VmaIntrusiveLinkedList PoolList; + // Protected by m_PoolsMutex. + PoolList m_Pools; + uint32_t m_NextPoolId; - if(!currFree) - { - VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); - VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); - sumUsedSize += suballoc.size; - } - else - { - ++nullItem1stCount; - } + VmaVulkanFunctions m_VulkanFunctions; - offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; - } - VMA_VALIDATE(nullItem1stCount == m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount); + // Global bit mask AND-ed with any memoryTypeBits to disallow certain memory types. + uint32_t m_GlobalMemoryTypeBits; - if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) - { - const size_t suballoc2ndCount = suballocations2nd.size(); - size_t nullItem2ndCount = 0; - for(size_t i = suballoc2ndCount; i--; ) - { - const VmaSuballocation& suballoc = suballocations2nd[i]; - const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + void ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions); - VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); - VMA_VALIDATE(suballoc.offset >= offset); +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 + void ImportVulkanFunctions_Static(); +#endif - if(!currFree) - { - VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); - VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); - sumUsedSize += suballoc.size; - } - else - { - ++nullItem2ndCount; - } + void ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions); - offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; - } +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + void ImportVulkanFunctions_Dynamic(); +#endif - VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); - } + void ValidateVulkanFunctions(); - VMA_VALIDATE(offset <= GetSize()); - VMA_VALIDATE(m_SumFreeSize == GetSize() - sumUsedSize); + VkDeviceSize CalcPreferredBlockSize(uint32_t memTypeIndex); - return true; -} + VkResult AllocateMemoryOfType( + VmaPool pool, + VkDeviceSize size, + VkDeviceSize alignment, + bool dedicatedPreferred, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, + const VmaAllocationCreateInfo& createInfo, + uint32_t memTypeIndex, + VmaSuballocationType suballocType, + VmaDedicatedAllocationList& dedicatedAllocations, + VmaBlockVector& blockVector, + size_t allocationCount, + VmaAllocation* pAllocations); -size_t VmaBlockMetadata_Linear::GetAllocationCount() const -{ - return AccessSuballocations1st().size() - (m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount) + - AccessSuballocations2nd().size() - m_2ndNullItemsCount; -} + // Helper function only to be used inside AllocateDedicatedMemory. + VkResult AllocateDedicatedMemoryPage( + VmaPool pool, + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + const VkMemoryAllocateInfo& allocInfo, + bool map, + bool isUserDataString, + bool isMappingAllowed, + void* pUserData, + VmaAllocation* pAllocation); -VkDeviceSize VmaBlockMetadata_Linear::GetUnusedRangeSizeMax() const -{ - const VkDeviceSize size = GetSize(); + // Allocates and registers new VkDeviceMemory specifically for dedicated allocations. + VkResult AllocateDedicatedMemory( + VmaPool pool, + VkDeviceSize size, + VmaSuballocationType suballocType, + VmaDedicatedAllocationList& dedicatedAllocations, + uint32_t memTypeIndex, + bool map, + bool isUserDataString, + bool isMappingAllowed, + bool canAliasMemory, + void* pUserData, + float priority, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, + size_t allocationCount, + VmaAllocation* pAllocations, + const void* pNextChain = nullptr); + + void FreeDedicatedMemory(const VmaAllocation allocation); + + VkResult CalcMemTypeParams( + VmaAllocationCreateInfo& outCreateInfo, + uint32_t memTypeIndex, + VkDeviceSize size, + size_t allocationCount); + VkResult CalcAllocationParams( + VmaAllocationCreateInfo& outCreateInfo, + bool dedicatedRequired, + bool dedicatedPreferred); /* - We don't consider gaps inside allocation vectors with freed allocations because - they are not suitable for reuse in linear allocator. We consider only space that - is available for new allocations. + Calculates and returns bit mask of memory types that can support defragmentation + on GPU as they support creation of required buffer for copy operations. */ - if(IsEmpty()) - { - return size; - } - - const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + uint32_t CalculateGpuDefragmentationMemoryTypeBits() const; + uint32_t CalculateGlobalMemoryTypeBits() const; - switch(m_2ndVectorMode) - { - case SECOND_VECTOR_EMPTY: - /* - Available space is after end of 1st, as well as before beginning of 1st (which - whould make it a ring buffer). - */ - { - const size_t suballocations1stCount = suballocations1st.size(); - VMA_ASSERT(suballocations1stCount > m_1stNullItemsBeginCount); - const VmaSuballocation& firstSuballoc = suballocations1st[m_1stNullItemsBeginCount]; - const VmaSuballocation& lastSuballoc = suballocations1st[suballocations1stCount - 1]; - return VMA_MAX( - firstSuballoc.offset, - size - (lastSuballoc.offset + lastSuballoc.size)); - } - break; + bool GetFlushOrInvalidateRange( + VmaAllocation allocation, + VkDeviceSize offset, VkDeviceSize size, + VkMappedMemoryRange& outRange) const; - case SECOND_VECTOR_RING_BUFFER: - /* - Available space is only between end of 2nd and beginning of 1st. - */ - { - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - const VmaSuballocation& lastSuballoc2nd = suballocations2nd.back(); - const VmaSuballocation& firstSuballoc1st = suballocations1st[m_1stNullItemsBeginCount]; - return firstSuballoc1st.offset - (lastSuballoc2nd.offset + lastSuballoc2nd.size); - } - break; +#if VMA_MEMORY_BUDGET + void UpdateVulkanBudget(); +#endif // #if VMA_MEMORY_BUDGET +}; - case SECOND_VECTOR_DOUBLE_STACK: - /* - Available space is only between end of 1st and top of 2nd. - */ - { - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - const VmaSuballocation& topSuballoc2nd = suballocations2nd.back(); - const VmaSuballocation& lastSuballoc1st = suballocations1st.back(); - return topSuballoc2nd.offset - (lastSuballoc1st.offset + lastSuballoc1st.size); - } - break; - default: - VMA_ASSERT(0); - return 0; - } +#ifndef _VMA_MEMORY_FUNCTIONS +static void* VmaMalloc(VmaAllocator hAllocator, size_t size, size_t alignment) +{ + return VmaMalloc(&hAllocator->m_AllocationCallbacks, size, alignment); } -void VmaBlockMetadata_Linear::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +static void VmaFree(VmaAllocator hAllocator, void* ptr) { - const VkDeviceSize size = GetSize(); - const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - const size_t suballoc1stCount = suballocations1st.size(); - const size_t suballoc2ndCount = suballocations2nd.size(); + VmaFree(&hAllocator->m_AllocationCallbacks, ptr); +} - outInfo.blockCount = 1; - outInfo.allocationCount = (uint32_t)GetAllocationCount(); - outInfo.unusedRangeCount = 0; - outInfo.usedBytes = 0; - outInfo.allocationSizeMin = UINT64_MAX; - outInfo.allocationSizeMax = 0; - outInfo.unusedRangeSizeMin = UINT64_MAX; - outInfo.unusedRangeSizeMax = 0; +template +static T* VmaAllocate(VmaAllocator hAllocator) +{ + return (T*)VmaMalloc(hAllocator, sizeof(T), VMA_ALIGN_OF(T)); +} - VkDeviceSize lastOffset = 0; +template +static T* VmaAllocateArray(VmaAllocator hAllocator, size_t count) +{ + return (T*)VmaMalloc(hAllocator, sizeof(T) * count, VMA_ALIGN_OF(T)); +} - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) +template +static void vma_delete(VmaAllocator hAllocator, T* ptr) +{ + if(ptr != VMA_NULL) { - const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; - size_t nextAlloc2ndIndex = 0; - while(lastOffset < freeSpace2ndTo1stEnd) - { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc2ndIndex < suballoc2ndCount && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc2ndIndex; - } + ptr->~T(); + VmaFree(hAllocator, ptr); + } +} - // Found non-null allocation. - if(nextAlloc2ndIndex < suballoc2ndCount) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; +template +static void vma_delete_array(VmaAllocator hAllocator, T* ptr, size_t count) +{ + if(ptr != VMA_NULL) + { + for(size_t i = count; i--; ) + ptr[i].~T(); + VmaFree(hAllocator, ptr); + } +} +#endif // _VMA_MEMORY_FUNCTIONS - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } +#ifndef _VMA_DEVICE_MEMORY_BLOCK_FUNCTIONS +VmaDeviceMemoryBlock::VmaDeviceMemoryBlock(VmaAllocator hAllocator) + : m_pMetadata(VMA_NULL), + m_MemoryTypeIndex(UINT32_MAX), + m_Id(0), + m_hMemory(VK_NULL_HANDLE), + m_MapCount(0), + m_pMappedData(VMA_NULL) {} - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - outInfo.usedBytes += suballoc.size; - outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); - outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); +VmaDeviceMemoryBlock::~VmaDeviceMemoryBlock() +{ + VMA_ASSERT(m_MapCount == 0 && "VkDeviceMemory block is being destroyed while it is still mapped."); + VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); +} - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc2ndIndex; - } - // We are at the end. - else - { - // There is free space from lastOffset to freeSpace2ndTo1stEnd. - if(lastOffset < freeSpace2ndTo1stEnd) - { - const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } +void VmaDeviceMemoryBlock::Init( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t newMemoryTypeIndex, + VkDeviceMemory newMemory, + VkDeviceSize newSize, + uint32_t id, + uint32_t algorithm, + VkDeviceSize bufferImageGranularity) +{ + VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); - // End of loop. - lastOffset = freeSpace2ndTo1stEnd; - } - } - } + m_hParentPool = hParentPool; + m_MemoryTypeIndex = newMemoryTypeIndex; + m_Id = id; + m_hMemory = newMemory; - size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; - const VkDeviceSize freeSpace1stTo2ndEnd = - m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; - while(lastOffset < freeSpace1stTo2ndEnd) + switch (algorithm) { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc1stIndex < suballoc1stCount && - suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc1stIndex; - } - - // Found non-null allocation. - if(nextAlloc1stIndex < suballoc1stCount) - { - const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } + case VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT: + m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Linear)(hAllocator->GetAllocationCallbacks(), + bufferImageGranularity, false); // isVirtual + break; + default: + VMA_ASSERT(0); + // Fall-through. + case 0: + m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_TLSF)(hAllocator->GetAllocationCallbacks(), + bufferImageGranularity, false); // isVirtual + } + m_pMetadata->Init(newSize); +} - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - outInfo.usedBytes += suballoc.size; - outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); - outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); +void VmaDeviceMemoryBlock::Destroy(VmaAllocator allocator) +{ + // Define macro VMA_DEBUG_LOG to receive the list of the unfreed allocations + if (!m_pMetadata->IsEmpty()) + m_pMetadata->DebugLogAllAllocations(); + // This is the most important assert in the entire library. + // Hitting it means you have some memory leak - unreleased VmaAllocation objects. + VMA_ASSERT(m_pMetadata->IsEmpty() && "Some allocations were not freed before destruction of this memory block!"); - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc1stIndex; - } - // We are at the end. - else - { - // There is free space from lastOffset to freeSpace1stTo2ndEnd. - if(lastOffset < freeSpace1stTo2ndEnd) - { - const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } + VMA_ASSERT(m_hMemory != VK_NULL_HANDLE); + allocator->FreeVulkanMemory(m_MemoryTypeIndex, m_pMetadata->GetSize(), m_hMemory); + m_hMemory = VK_NULL_HANDLE; - // End of loop. - lastOffset = freeSpace1stTo2ndEnd; - } - } + vma_delete(allocator, m_pMetadata); + m_pMetadata = VMA_NULL; +} - if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) +void VmaDeviceMemoryBlock::PostFree(VmaAllocator hAllocator) +{ + if(m_MappingHysteresis.PostFree()) { - size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; - while(lastOffset < size) + VMA_ASSERT(m_MappingHysteresis.GetExtraMapping() == 0); + if (m_MapCount == 0) { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc2ndIndex != SIZE_MAX && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - --nextAlloc2ndIndex; - } - - // Found non-null allocation. - if(nextAlloc2ndIndex != SIZE_MAX) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - outInfo.usedBytes += suballoc.size; - outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); - outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); - - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - --nextAlloc2ndIndex; - } - // We are at the end. - else - { - // There is free space from lastOffset to size. - if(lastOffset < size) - { - const VkDeviceSize unusedRangeSize = size - lastOffset; - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); - outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); - } - - // End of loop. - lastOffset = size; - } + m_pMappedData = VMA_NULL; + (*hAllocator->GetVulkanFunctions().vkUnmapMemory)(hAllocator->m_hDevice, m_hMemory); } } - - outInfo.unusedBytes = size - outInfo.usedBytes; } -void VmaBlockMetadata_Linear::AddPoolStats(VmaPoolStats& inoutStats) const +bool VmaDeviceMemoryBlock::Validate() const { - const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - const VkDeviceSize size = GetSize(); - const size_t suballoc1stCount = suballocations1st.size(); - const size_t suballoc2ndCount = suballocations2nd.size(); - - inoutStats.size += size; + VMA_VALIDATE((m_hMemory != VK_NULL_HANDLE) && + (m_pMetadata->GetSize() != 0)); - VkDeviceSize lastOffset = 0; + return m_pMetadata->Validate(); +} - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) +VkResult VmaDeviceMemoryBlock::CheckCorruption(VmaAllocator hAllocator) +{ + void* pData = nullptr; + VkResult res = Map(hAllocator, 1, &pData); + if (res != VK_SUCCESS) { - const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; - size_t nextAlloc2ndIndex = m_1stNullItemsBeginCount; - while(lastOffset < freeSpace2ndTo1stEnd) - { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex < suballoc2ndCount && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc2ndIndex; - } - - // Found non-null allocation. - if(nextAlloc2ndIndex < suballoc2ndCount) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + return res; + } - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); - } + res = m_pMetadata->CheckCorruption(pData); - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++inoutStats.allocationCount; + Unmap(hAllocator, 1); - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < freeSpace2ndTo1stEnd) - { - // There is free space from lastOffset to freeSpace2ndTo1stEnd. - const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); - } + return res; +} - // End of loop. - lastOffset = freeSpace2ndTo1stEnd; - } - } +VkResult VmaDeviceMemoryBlock::Map(VmaAllocator hAllocator, uint32_t count, void** ppData) +{ + if (count == 0) + { + return VK_SUCCESS; } - size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; - const VkDeviceSize freeSpace1stTo2ndEnd = - m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; - while(lastOffset < freeSpace1stTo2ndEnd) + VmaMutexLock lock(m_MapAndBindMutex, hAllocator->m_UseMutex); + const uint32_t oldTotalMapCount = m_MapCount + m_MappingHysteresis.GetExtraMapping(); + m_MappingHysteresis.PostMap(); + if (oldTotalMapCount != 0) { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc1stIndex < suballoc1stCount && - suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) + m_MapCount += count; + VMA_ASSERT(m_pMappedData != VMA_NULL); + if (ppData != VMA_NULL) { - ++nextAlloc1stIndex; + *ppData = m_pMappedData; } - - // Found non-null allocation. - if(nextAlloc1stIndex < suballoc1stCount) + return VK_SUCCESS; + } + else + { + VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( + hAllocator->m_hDevice, + m_hMemory, + 0, // offset + VK_WHOLE_SIZE, + 0, // flags + &m_pMappedData); + if (result == VK_SUCCESS) { - const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) + if (ppData != VMA_NULL) { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + *ppData = m_pMappedData; } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++inoutStats.allocationCount; - - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc1stIndex; + m_MapCount = count; } - // We are at the end. - else - { - if(lastOffset < freeSpace1stTo2ndEnd) - { - // There is free space from lastOffset to freeSpace1stTo2ndEnd. - const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); - } + return result; + } +} - // End of loop. - lastOffset = freeSpace1stTo2ndEnd; - } +void VmaDeviceMemoryBlock::Unmap(VmaAllocator hAllocator, uint32_t count) +{ + if (count == 0) + { + return; } - if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + VmaMutexLock lock(m_MapAndBindMutex, hAllocator->m_UseMutex); + if (m_MapCount >= count) { - size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; - while(lastOffset < size) + m_MapCount -= count; + const uint32_t totalMapCount = m_MapCount + m_MappingHysteresis.GetExtraMapping(); + if (totalMapCount == 0) { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex != SIZE_MAX && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - --nextAlloc2ndIndex; - } + m_pMappedData = VMA_NULL; + (*hAllocator->GetVulkanFunctions().vkUnmapMemory)(hAllocator->m_hDevice, m_hMemory); + } + m_MappingHysteresis.PostUnmap(); + } + else + { + VMA_ASSERT(0 && "VkDeviceMemory block is being unmapped while it was not previously mapped."); + } +} - // Found non-null allocation. - if(nextAlloc2ndIndex != SIZE_MAX) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; +VkResult VmaDeviceMemoryBlock::WriteMagicValueAfterAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +{ + VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); - } + void* pData; + VkResult res = Map(hAllocator, 1, &pData); + if (res != VK_SUCCESS) + { + return res; + } - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++inoutStats.allocationCount; + VmaWriteMagicValue(pData, allocOffset + allocSize); - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - --nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < size) - { - // There is free space from lastOffset to size. - const VkDeviceSize unusedRangeSize = size - lastOffset; - inoutStats.unusedSize += unusedRangeSize; - ++inoutStats.unusedRangeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); - } + Unmap(hAllocator, 1); + return VK_SUCCESS; +} - // End of loop. - lastOffset = size; - } - } +VkResult VmaDeviceMemoryBlock::ValidateMagicValueAfterAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +{ + VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); + + void* pData; + VkResult res = Map(hAllocator, 1, &pData); + if (res != VK_SUCCESS) + { + return res; + } + + if (!VmaValidateMagicValue(pData, allocOffset + allocSize)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER FREED ALLOCATION!"); } + + Unmap(hAllocator, 1); + return VK_SUCCESS; } -#if VMA_STATS_STRING_ENABLED -void VmaBlockMetadata_Linear::PrintDetailedMap(class VmaJsonWriter& json) const +VkResult VmaDeviceMemoryBlock::BindBufferMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext) { - const VkDeviceSize size = GetSize(); - const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - const size_t suballoc1stCount = suballocations1st.size(); - const size_t suballoc2ndCount = suballocations2nd.size(); + VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && + hAllocation->GetBlock() == this); + VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && + "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); + const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; + // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. + VmaMutexLock lock(m_MapAndBindMutex, hAllocator->m_UseMutex); + return hAllocator->BindVulkanBuffer(m_hMemory, memoryOffset, hBuffer, pNext); +} - // FIRST PASS +VkResult VmaDeviceMemoryBlock::BindImageMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext) +{ + VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && + hAllocation->GetBlock() == this); + VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && + "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); + const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; + // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. + VmaMutexLock lock(m_MapAndBindMutex, hAllocator->m_UseMutex); + return hAllocator->BindVulkanImage(m_hMemory, memoryOffset, hImage, pNext); +} +#endif // _VMA_DEVICE_MEMORY_BLOCK_FUNCTIONS - size_t unusedRangeCount = 0; - VkDeviceSize usedBytes = 0; +#ifndef _VMA_ALLOCATION_T_FUNCTIONS +VmaAllocation_T::VmaAllocation_T(bool mappingAllowed) + : m_Alignment{ 1 }, + m_Size{ 0 }, + m_pUserData{ VMA_NULL }, + m_pName{ VMA_NULL }, + m_MemoryTypeIndex{ 0 }, + m_Type{ (uint8_t)ALLOCATION_TYPE_NONE }, + m_SuballocationType{ (uint8_t)VMA_SUBALLOCATION_TYPE_UNKNOWN }, + m_MapCount{ 0 }, + m_Flags{ 0 } +{ + if(mappingAllowed) + m_Flags |= (uint8_t)FLAG_MAPPING_ALLOWED; - VkDeviceSize lastOffset = 0; +#if VMA_STATS_STRING_ENABLED + m_BufferImageUsage = 0; +#endif +} - size_t alloc2ndCount = 0; - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) +VmaAllocation_T::~VmaAllocation_T() +{ + VMA_ASSERT(m_MapCount == 0 && "Allocation was not unmapped before destruction."); + + // Check if owned string was freed. + VMA_ASSERT(m_pName == VMA_NULL); +} + +void VmaAllocation_T::InitBlockAllocation( + VmaDeviceMemoryBlock* block, + VmaAllocHandle allocHandle, + VkDeviceSize alignment, + VkDeviceSize size, + uint32_t memoryTypeIndex, + VmaSuballocationType suballocationType, + bool mapped) +{ + VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); + VMA_ASSERT(block != VMA_NULL); + m_Type = (uint8_t)ALLOCATION_TYPE_BLOCK; + m_Alignment = alignment; + m_Size = size; + m_MemoryTypeIndex = memoryTypeIndex; + if(mapped) { - const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; - size_t nextAlloc2ndIndex = 0; - while(lastOffset < freeSpace2ndTo1stEnd) - { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex < suballoc2ndCount && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc2ndIndex; - } + VMA_ASSERT(IsMappingAllowed() && "Mapping is not allowed on this allocation! Please use one of the new VMA_ALLOCATION_CREATE_HOST_ACCESS_* flags when creating it."); + m_Flags |= (uint8_t)FLAG_PERSISTENT_MAP; + } + m_SuballocationType = (uint8_t)suballocationType; + m_BlockAllocation.m_Block = block; + m_BlockAllocation.m_AllocHandle = allocHandle; +} - // Found non-null allocation. - if(nextAlloc2ndIndex < suballoc2ndCount) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; +void VmaAllocation_T::InitDedicatedAllocation( + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceMemory hMemory, + VmaSuballocationType suballocationType, + void* pMappedData, + VkDeviceSize size) +{ + VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); + VMA_ASSERT(hMemory != VK_NULL_HANDLE); + m_Type = (uint8_t)ALLOCATION_TYPE_DEDICATED; + m_Alignment = 0; + m_Size = size; + m_MemoryTypeIndex = memoryTypeIndex; + m_SuballocationType = (uint8_t)suballocationType; + if(pMappedData != VMA_NULL) + { + VMA_ASSERT(IsMappingAllowed() && "Mapping is not allowed on this allocation! Please use one of the new VMA_ALLOCATION_CREATE_HOST_ACCESS_* flags when creating it."); + m_Flags |= (uint8_t)FLAG_PERSISTENT_MAP; + } + m_DedicatedAllocation.m_hParentPool = hParentPool; + m_DedicatedAllocation.m_hMemory = hMemory; + m_DedicatedAllocation.m_pMappedData = pMappedData; + m_DedicatedAllocation.m_Prev = VMA_NULL; + m_DedicatedAllocation.m_Next = VMA_NULL; +} - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - ++unusedRangeCount; - } +void VmaAllocation_T::SetName(VmaAllocator hAllocator, const char* pName) +{ + VMA_ASSERT(pName == VMA_NULL || pName != m_pName); - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++alloc2ndCount; - usedBytes += suballoc.size; + FreeName(hAllocator); - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < freeSpace2ndTo1stEnd) - { - // There is free space from lastOffset to freeSpace2ndTo1stEnd. - ++unusedRangeCount; - } + if (pName != VMA_NULL) + m_pName = VmaCreateStringCopy(hAllocator->GetAllocationCallbacks(), pName); +} - // End of loop. - lastOffset = freeSpace2ndTo1stEnd; - } - } - } +uint8_t VmaAllocation_T::SwapBlockAllocation(VmaAllocator hAllocator, VmaAllocation allocation) +{ + VMA_ASSERT(allocation != VMA_NULL); + VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); + VMA_ASSERT(allocation->m_Type == ALLOCATION_TYPE_BLOCK); - size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; - size_t alloc1stCount = 0; - const VkDeviceSize freeSpace1stTo2ndEnd = - m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; - while(lastOffset < freeSpace1stTo2ndEnd) + if (m_MapCount != 0) + m_BlockAllocation.m_Block->Unmap(hAllocator, m_MapCount); + + m_BlockAllocation.m_Block->m_pMetadata->SetAllocationUserData(m_BlockAllocation.m_AllocHandle, allocation); + VMA_SWAP(m_BlockAllocation, allocation->m_BlockAllocation); + m_BlockAllocation.m_Block->m_pMetadata->SetAllocationUserData(m_BlockAllocation.m_AllocHandle, this); + +#if VMA_STATS_STRING_ENABLED + VMA_SWAP(m_BufferImageUsage, allocation->m_BufferImageUsage); +#endif + return m_MapCount; +} + +VmaAllocHandle VmaAllocation_T::GetAllocHandle() const +{ + switch (m_Type) { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc1stIndex < suballoc1stCount && - suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc1stIndex; - } + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_AllocHandle; + case ALLOCATION_TYPE_DEDICATED: + return VK_NULL_HANDLE; + default: + VMA_ASSERT(0); + return VK_NULL_HANDLE; + } +} - // Found non-null allocation. - if(nextAlloc1stIndex < suballoc1stCount) - { - const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; +VkDeviceSize VmaAllocation_T::GetOffset() const +{ + switch (m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_Block->m_pMetadata->GetAllocationOffset(m_BlockAllocation.m_AllocHandle); + case ALLOCATION_TYPE_DEDICATED: + return 0; + default: + VMA_ASSERT(0); + return 0; + } +} - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - ++unusedRangeCount; - } +VmaPool VmaAllocation_T::GetParentPool() const +{ + switch (m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_Block->GetParentPool(); + case ALLOCATION_TYPE_DEDICATED: + return m_DedicatedAllocation.m_hParentPool; + default: + VMA_ASSERT(0); + return VK_NULL_HANDLE; + } +} - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++alloc1stCount; - usedBytes += suballoc.size; +VkDeviceMemory VmaAllocation_T::GetMemory() const +{ + switch (m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_Block->GetDeviceMemory(); + case ALLOCATION_TYPE_DEDICATED: + return m_DedicatedAllocation.m_hMemory; + default: + VMA_ASSERT(0); + return VK_NULL_HANDLE; + } +} - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc1stIndex; +void* VmaAllocation_T::GetMappedData() const +{ + switch (m_Type) + { + case ALLOCATION_TYPE_BLOCK: + if (m_MapCount != 0 || IsPersistentMap()) + { + void* pBlockData = m_BlockAllocation.m_Block->GetMappedData(); + VMA_ASSERT(pBlockData != VMA_NULL); + return (char*)pBlockData + GetOffset(); } - // We are at the end. else { - if(lastOffset < size) - { - // There is free space from lastOffset to freeSpace1stTo2ndEnd. - ++unusedRangeCount; - } - - // End of loop. - lastOffset = freeSpace1stTo2ndEnd; + return VMA_NULL; } + break; + case ALLOCATION_TYPE_DEDICATED: + VMA_ASSERT((m_DedicatedAllocation.m_pMappedData != VMA_NULL) == (m_MapCount != 0 || IsPersistentMap())); + return m_DedicatedAllocation.m_pMappedData; + default: + VMA_ASSERT(0); + return VMA_NULL; } +} + +void VmaAllocation_T::BlockAllocMap() +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); + VMA_ASSERT(IsMappingAllowed() && "Mapping is not allowed on this allocation! Please use one of the new VMA_ALLOCATION_CREATE_HOST_ACCESS_* flags when creating it."); - if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + if (m_MapCount < 0xFF) { - size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; - while(lastOffset < size) - { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex != SIZE_MAX && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - --nextAlloc2ndIndex; - } + ++m_MapCount; + } + else + { + VMA_ASSERT(0 && "Allocation mapped too many times simultaneously."); + } +} - // Found non-null allocation. - if(nextAlloc2ndIndex != SIZE_MAX) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - ++unusedRangeCount; - } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - ++alloc2ndCount; - usedBytes += suballoc.size; - - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - --nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < size) - { - // There is free space from lastOffset to size. - ++unusedRangeCount; - } +void VmaAllocation_T::BlockAllocUnmap() +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); - // End of loop. - lastOffset = size; - } - } + if (m_MapCount > 0) + { + --m_MapCount; } - - const VkDeviceSize unusedBytes = size - usedBytes; - PrintDetailedMap_Begin(json, unusedBytes, alloc1stCount + alloc2ndCount, unusedRangeCount); - - // SECOND PASS - lastOffset = 0; - - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + else { - const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; - size_t nextAlloc2ndIndex = 0; - while(lastOffset < freeSpace2ndTo1stEnd) - { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex < suballoc2ndCount && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc2ndIndex; - } - - // Found non-null allocation. - if(nextAlloc2ndIndex < suballoc2ndCount) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); - - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < freeSpace2ndTo1stEnd) - { - // There is free space from lastOffset to freeSpace2ndTo1stEnd. - const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } - - // End of loop. - lastOffset = freeSpace2ndTo1stEnd; - } - } + VMA_ASSERT(0 && "Unmapping allocation not previously mapped."); } +} - nextAlloc1stIndex = m_1stNullItemsBeginCount; - while(lastOffset < freeSpace1stTo2ndEnd) - { - // Find next non-null allocation or move nextAllocIndex to the end. - while(nextAlloc1stIndex < suballoc1stCount && - suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) - { - ++nextAlloc1stIndex; - } +VkResult VmaAllocation_T::DedicatedAllocMap(VmaAllocator hAllocator, void** ppData) +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); + VMA_ASSERT(IsMappingAllowed() && "Mapping is not allowed on this allocation! Please use one of the new VMA_ALLOCATION_CREATE_HOST_ACCESS_* flags when creating it."); - // Found non-null allocation. - if(nextAlloc1stIndex < suballoc1stCount) + if (m_MapCount != 0 || IsPersistentMap()) + { + if (m_MapCount < 0xFF) { - const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); - - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - ++nextAlloc1stIndex; + VMA_ASSERT(m_DedicatedAllocation.m_pMappedData != VMA_NULL); + *ppData = m_DedicatedAllocation.m_pMappedData; + ++m_MapCount; + return VK_SUCCESS; } - // We are at the end. else { - if(lastOffset < freeSpace1stTo2ndEnd) - { - // There is free space from lastOffset to freeSpace1stTo2ndEnd. - const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } - - // End of loop. - lastOffset = freeSpace1stTo2ndEnd; + VMA_ASSERT(0 && "Dedicated allocation mapped too many times simultaneously."); + return VK_ERROR_MEMORY_MAP_FAILED; } } - - if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + else { - size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; - while(lastOffset < size) + VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( + hAllocator->m_hDevice, + m_DedicatedAllocation.m_hMemory, + 0, // offset + VK_WHOLE_SIZE, + 0, // flags + ppData); + if (result == VK_SUCCESS) { - // Find next non-null allocation or move nextAlloc2ndIndex to the end. - while(nextAlloc2ndIndex != SIZE_MAX && - suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) - { - --nextAlloc2ndIndex; - } - - // Found non-null allocation. - if(nextAlloc2ndIndex != SIZE_MAX) - { - const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; - - // 1. Process free space before this allocation. - if(lastOffset < suballoc.offset) - { - // There is free space from lastOffset to suballoc.offset. - const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } - - // 2. Process this allocation. - // There is allocation with suballoc.offset, suballoc.size. - PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); + m_DedicatedAllocation.m_pMappedData = *ppData; + m_MapCount = 1; + } + return result; + } +} - // 3. Prepare for next iteration. - lastOffset = suballoc.offset + suballoc.size; - --nextAlloc2ndIndex; - } - // We are at the end. - else - { - if(lastOffset < size) - { - // There is free space from lastOffset to size. - const VkDeviceSize unusedRangeSize = size - lastOffset; - PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); - } +void VmaAllocation_T::DedicatedAllocUnmap(VmaAllocator hAllocator) +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); - // End of loop. - lastOffset = size; - } + if (m_MapCount > 0) + { + --m_MapCount; + if (m_MapCount == 0 && !IsPersistentMap()) + { + m_DedicatedAllocation.m_pMappedData = VMA_NULL; + (*hAllocator->GetVulkanFunctions().vkUnmapMemory)( + hAllocator->m_hDevice, + m_DedicatedAllocation.m_hMemory); } } - - PrintDetailedMap_End(json); + else + { + VMA_ASSERT(0 && "Unmapping dedicated allocation not previously mapped."); + } } -#endif // #if VMA_STATS_STRING_ENABLED -bool VmaBlockMetadata_Linear::CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) +#if VMA_STATS_STRING_ENABLED +void VmaAllocation_T::InitBufferImageUsage(uint32_t bufferImageUsage) { - VMA_ASSERT(allocSize > 0); - VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(pAllocationRequest != VMA_NULL); - VMA_HEAVY_ASSERT(Validate()); - return upperAddress ? - CreateAllocationRequest_UpperAddress( - currentFrameIndex, frameInUseCount, bufferImageGranularity, - allocSize, allocAlignment, allocType, canMakeOtherLost, strategy, pAllocationRequest) : - CreateAllocationRequest_LowerAddress( - currentFrameIndex, frameInUseCount, bufferImageGranularity, - allocSize, allocAlignment, allocType, canMakeOtherLost, strategy, pAllocationRequest); + VMA_ASSERT(m_BufferImageUsage == 0); + m_BufferImageUsage = bufferImageUsage; } -bool VmaBlockMetadata_Linear::CreateAllocationRequest_UpperAddress( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) +void VmaAllocation_T::PrintParameters(class VmaJsonWriter& json) const { - const VkDeviceSize size = GetSize(); - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + json.WriteString("Type"); + json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[m_SuballocationType]); - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) - { - VMA_ASSERT(0 && "Trying to use pool with linear algorithm as double stack, while it is already being used as ring buffer."); - return false; - } + json.WriteString("Size"); + json.WriteNumber(m_Size); + json.WriteString("Usage"); + json.WriteNumber(m_BufferImageUsage); - // Try to allocate before 2nd.back(), or end of block if 2nd.empty(). - if(allocSize > size) + if (m_pUserData != VMA_NULL) { - return false; + json.WriteString("CustomData"); + json.BeginString(); + json.ContinueString_Pointer(m_pUserData); + json.EndString(); } - VkDeviceSize resultBaseOffset = size - allocSize; - if(!suballocations2nd.empty()) + if (m_pName != VMA_NULL) { - const VmaSuballocation& lastSuballoc = suballocations2nd.back(); - resultBaseOffset = lastSuballoc.offset - allocSize; - if(allocSize > lastSuballoc.offset) - { - return false; - } + json.WriteString("Name"); + json.WriteString(m_pName); } +} +#endif // VMA_STATS_STRING_ENABLED - // Start from offset equal to end of free space. - VkDeviceSize resultOffset = resultBaseOffset; - - // Apply VMA_DEBUG_MARGIN at the end. - if(VMA_DEBUG_MARGIN > 0) +void VmaAllocation_T::FreeName(VmaAllocator hAllocator) +{ + if(m_pName) { - if(resultOffset < VMA_DEBUG_MARGIN) - { - return false; - } - resultOffset -= VMA_DEBUG_MARGIN; + VmaFreeString(hAllocator->GetAllocationCallbacks(), m_pName); + m_pName = VMA_NULL; } +} +#endif // _VMA_ALLOCATION_T_FUNCTIONS - // Apply alignment. - resultOffset = VmaAlignDown(resultOffset, allocAlignment); - - // Check next suballocations from 2nd for BufferImageGranularity conflicts. - // Make bigger alignment if necessary. - if(bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations2nd.empty()) - { - bool bufferImageGranularityConflict = false; - for(size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) - { - const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; - if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(nextSuballoc.type, allocType)) - { - bufferImageGranularityConflict = true; - break; - } - } - else - // Already on previous page. - break; - } - if(bufferImageGranularityConflict) +#ifndef _VMA_BLOCK_VECTOR_FUNCTIONS +VmaBlockVector::VmaBlockVector( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceSize preferredBlockSize, + size_t minBlockCount, + size_t maxBlockCount, + VkDeviceSize bufferImageGranularity, + bool explicitBlockSize, + uint32_t algorithm, + float priority, + VkDeviceSize minAllocationAlignment, + void* pMemoryAllocateNext) + : m_hAllocator(hAllocator), + m_hParentPool(hParentPool), + m_MemoryTypeIndex(memoryTypeIndex), + m_PreferredBlockSize(preferredBlockSize), + m_MinBlockCount(minBlockCount), + m_MaxBlockCount(maxBlockCount), + m_BufferImageGranularity(bufferImageGranularity), + m_ExplicitBlockSize(explicitBlockSize), + m_Algorithm(algorithm), + m_Priority(priority), + m_MinAllocationAlignment(minAllocationAlignment), + m_pMemoryAllocateNext(pMemoryAllocateNext), + m_Blocks(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_NextBlockId(0) {} + +VmaBlockVector::~VmaBlockVector() +{ + for (size_t i = m_Blocks.size(); i--; ) + { + m_Blocks[i]->Destroy(m_hAllocator); + vma_delete(m_hAllocator, m_Blocks[i]); + } +} + +VkResult VmaBlockVector::CreateMinBlocks() +{ + for (size_t i = 0; i < m_MinBlockCount; ++i) + { + VkResult res = CreateBlock(m_PreferredBlockSize, VMA_NULL); + if (res != VK_SUCCESS) { - resultOffset = VmaAlignDown(resultOffset, bufferImageGranularity); + return res; } } + return VK_SUCCESS; +} - // There is enough free space. - const VkDeviceSize endOf1st = !suballocations1st.empty() ? - suballocations1st.back().offset + suballocations1st.back().size : - 0; - if(endOf1st + VMA_DEBUG_MARGIN <= resultOffset) +void VmaBlockVector::AddStatistics(VmaStatistics& inoutStats) +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + const size_t blockCount = m_Blocks.size(); + for (uint32_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) { - // Check previous suballocations for BufferImageGranularity conflicts. - // If conflict exists, allocation cannot be made here. - if(bufferImageGranularity > 1) + const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VMA_HEAVY_ASSERT(pBlock->Validate()); + pBlock->m_pMetadata->AddStatistics(inoutStats); + } +} + +void VmaBlockVector::AddDetailedStatistics(VmaDetailedStatistics& inoutStats) +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + const size_t blockCount = m_Blocks.size(); + for (uint32_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VMA_HEAVY_ASSERT(pBlock->Validate()); + pBlock->m_pMetadata->AddDetailedStatistics(inoutStats); + } +} + +bool VmaBlockVector::IsEmpty() +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + return m_Blocks.empty(); +} + +bool VmaBlockVector::IsCorruptionDetectionEnabled() const +{ + const uint32_t requiredMemFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + return (VMA_DEBUG_DETECT_CORRUPTION != 0) && + (VMA_DEBUG_MARGIN > 0) && + (m_Algorithm == 0 || m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) && + (m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags & requiredMemFlags) == requiredMemFlags; +} + +VkResult VmaBlockVector::Allocate( + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + size_t allocIndex; + VkResult res = VK_SUCCESS; + + alignment = VMA_MAX(alignment, m_MinAllocationAlignment); + + if (IsCorruptionDetectionEnabled()) + { + size = VmaAlignUp(size, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); + alignment = VmaAlignUp(alignment, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); + } + + { + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + for (allocIndex = 0; allocIndex < allocationCount; ++allocIndex) { - for(size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + res = AllocatePage( + size, + alignment, + createInfo, + suballocType, + pAllocations + allocIndex); + if (res != VK_SUCCESS) { - const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; - if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(allocType, prevSuballoc.type)) - { - return false; - } - } - else - { - // Already on next page. - break; - } + break; } } + } - // All tests passed: Success. - pAllocationRequest->offset = resultOffset; - pAllocationRequest->sumFreeSize = resultBaseOffset + allocSize - endOf1st; - pAllocationRequest->sumItemSize = 0; - // pAllocationRequest->item unused. - pAllocationRequest->itemsToMakeLostCount = 0; - pAllocationRequest->type = VmaAllocationRequestType::UpperAddress; - return true; + if (res != VK_SUCCESS) + { + // Free all already created allocations. + while (allocIndex--) + Free(pAllocations[allocIndex]); + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); } - return false; + return res; } -bool VmaBlockMetadata_Linear::CreateAllocationRequest_LowerAddress( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) +VkResult VmaBlockVector::AllocatePage( + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation) { - const VkDeviceSize size = GetSize(); - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const bool isUpperAddress = (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; - if(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + VkDeviceSize freeMemory; { - // Try to allocate at the end of 1st vector. + const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); + VmaBudget heapBudget = {}; + m_hAllocator->GetHeapBudgets(&heapBudget, heapIndex, 1); + freeMemory = (heapBudget.usage < heapBudget.budget) ? (heapBudget.budget - heapBudget.usage) : 0; + } - VkDeviceSize resultBaseOffset = 0; - if(!suballocations1st.empty()) - { - const VmaSuballocation& lastSuballoc = suballocations1st.back(); - resultBaseOffset = lastSuballoc.offset + lastSuballoc.size; - } + const bool canFallbackToDedicated = !HasExplicitBlockSize() && + (createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0; + const bool canCreateNewBlock = + ((createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0) && + (m_Blocks.size() < m_MaxBlockCount) && + (freeMemory >= size || !canFallbackToDedicated); + uint32_t strategy = createInfo.flags & VMA_ALLOCATION_CREATE_STRATEGY_MASK; - // Start from offset equal to beginning of free space. - VkDeviceSize resultOffset = resultBaseOffset; + // Upper address can only be used with linear allocator and within single memory block. + if (isUpperAddress && + (m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT || m_MaxBlockCount > 1)) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + // Early reject: requested allocation size is larger that maximum block size for this block vector. + if (size + VMA_DEBUG_MARGIN > m_PreferredBlockSize) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } - // Apply VMA_DEBUG_MARGIN at the beginning. - if(VMA_DEBUG_MARGIN > 0) + // 1. Search existing allocations. Try to allocate. + if (m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + { + // Use only last block. + if (!m_Blocks.empty()) { - resultOffset += VMA_DEBUG_MARGIN; + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks.back(); + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock( + pCurrBlock, size, alignment, createInfo.flags, createInfo.pUserData, suballocType, strategy, pAllocation); + if (res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from last block #%u", pCurrBlock->GetId()); + IncrementallySortBlocks(); + return VK_SUCCESS; + } } - - // Apply alignment. - resultOffset = VmaAlignUp(resultOffset, allocAlignment); - - // Check previous suballocations for BufferImageGranularity conflicts. - // Make bigger alignment if necessary. - if(bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations1st.empty()) + } + else + { + if (strategy != VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT) // MIN_MEMORY or default { - bool bufferImageGranularityConflict = false; - for(size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + const bool isHostVisible = + (m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0; + if(isHostVisible) { - const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; - if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + const bool isMappingAllowed = (createInfo.flags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0; + /* + For non-mappable allocations, check blocks that are not mapped first. + For mappable allocations, check blocks that are already mapped first. + This way, having many blocks, we will separate mappable and non-mappable allocations, + hopefully limiting the number of blocks that are mapped, which will help tools like RenderDoc. + */ + for(size_t mappingI = 0; mappingI < 2; ++mappingI) { - if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + // Forward order in m_Blocks - prefer blocks with smallest amount of free space. + for (size_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) { - bufferImageGranularityConflict = true; - break; + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + const bool isBlockMapped = pCurrBlock->GetMappedData() != VMA_NULL; + if((mappingI == 0) == (isMappingAllowed == isBlockMapped)) + { + VkResult res = AllocateFromBlock( + pCurrBlock, size, alignment, createInfo.flags, createInfo.pUserData, suballocType, strategy, pAllocation); + if (res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); + IncrementallySortBlocks(); + return VK_SUCCESS; + } + } } } - else - // Already on previous page. - break; } - if(bufferImageGranularityConflict) + else { - resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + // Forward order in m_Blocks - prefer blocks with smallest amount of free space. + for (size_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock( + pCurrBlock, size, alignment, createInfo.flags, createInfo.pUserData, suballocType, strategy, pAllocation); + if (res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); + IncrementallySortBlocks(); + return VK_SUCCESS; + } + } } } - - const VkDeviceSize freeSpaceEnd = m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? - suballocations2nd.back().offset : size; - - // There is enough free space at the end after alignment. - if(resultOffset + allocSize + VMA_DEBUG_MARGIN <= freeSpaceEnd) + else // VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT { - // Check next suballocations for BufferImageGranularity conflicts. - // If conflict exists, allocation cannot be made here. - if((allocSize % bufferImageGranularity || resultOffset % bufferImageGranularity) && m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + // Backward order in m_Blocks - prefer blocks with largest amount of free space. + for (size_t blockIndex = m_Blocks.size(); blockIndex--; ) { - for(size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock(pCurrBlock, size, alignment, createInfo.flags, createInfo.pUserData, suballocType, strategy, pAllocation); + if (res == VK_SUCCESS) { - const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; - if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) - { - return false; - } - } - else - { - // Already on previous page. - break; - } + VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); + IncrementallySortBlocks(); + return VK_SUCCESS; } } - - // All tests passed: Success. - pAllocationRequest->offset = resultOffset; - pAllocationRequest->sumFreeSize = freeSpaceEnd - resultBaseOffset; - pAllocationRequest->sumItemSize = 0; - // pAllocationRequest->item, customData unused. - pAllocationRequest->type = VmaAllocationRequestType::EndOf1st; - pAllocationRequest->itemsToMakeLostCount = 0; - return true; } } - // Wrap-around to end of 2nd vector. Try to allocate there, watching for the - // beginning of 1st vector as the end of free space. - if(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + // 2. Try to create new block. + if (canCreateNewBlock) { - VMA_ASSERT(!suballocations1st.empty()); + // Calculate optimal size for new block. + VkDeviceSize newBlockSize = m_PreferredBlockSize; + uint32_t newBlockSizeShift = 0; + const uint32_t NEW_BLOCK_SIZE_SHIFT_MAX = 3; - VkDeviceSize resultBaseOffset = 0; - if(!suballocations2nd.empty()) + if (!m_ExplicitBlockSize) { - const VmaSuballocation& lastSuballoc = suballocations2nd.back(); - resultBaseOffset = lastSuballoc.offset + lastSuballoc.size; + // Allocate 1/8, 1/4, 1/2 as first blocks. + const VkDeviceSize maxExistingBlockSize = CalcMaxBlockSize(); + for (uint32_t i = 0; i < NEW_BLOCK_SIZE_SHIFT_MAX; ++i) + { + const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; + if (smallerNewBlockSize > maxExistingBlockSize && smallerNewBlockSize >= size * 2) + { + newBlockSize = smallerNewBlockSize; + ++newBlockSizeShift; + } + else + { + break; + } + } } - // Start from offset equal to beginning of free space. - VkDeviceSize resultOffset = resultBaseOffset; - - // Apply VMA_DEBUG_MARGIN at the beginning. - if(VMA_DEBUG_MARGIN > 0) + size_t newBlockIndex = 0; + VkResult res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? + CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; + // Allocation of this size failed? Try 1/2, 1/4, 1/8 of m_PreferredBlockSize. + if (!m_ExplicitBlockSize) { - resultOffset += VMA_DEBUG_MARGIN; - } - - // Apply alignment. - resultOffset = VmaAlignUp(resultOffset, allocAlignment); - - // Check previous suballocations for BufferImageGranularity conflicts. - // Make bigger alignment if necessary. - if(bufferImageGranularity > 1 && bufferImageGranularity != allocAlignment && !suballocations2nd.empty()) - { - bool bufferImageGranularityConflict = false; - for(size_t prevSuballocIndex = suballocations2nd.size(); prevSuballocIndex--; ) + while (res < 0 && newBlockSizeShift < NEW_BLOCK_SIZE_SHIFT_MAX) { - const VmaSuballocation& prevSuballoc = suballocations2nd[prevSuballocIndex]; - if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; + if (smallerNewBlockSize >= size) { - if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) - { - bufferImageGranularityConflict = true; - break; - } + newBlockSize = smallerNewBlockSize; + ++newBlockSizeShift; + res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? + CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; } else - // Already on previous page. + { break; - } - if(bufferImageGranularityConflict) - { - resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + } } } - pAllocationRequest->itemsToMakeLostCount = 0; - pAllocationRequest->sumItemSize = 0; - size_t index1st = m_1stNullItemsBeginCount; - - if(canMakeOtherLost) + if (res == VK_SUCCESS) { - while(index1st < suballocations1st.size() && - resultOffset + allocSize + VMA_DEBUG_MARGIN > suballocations1st[index1st].offset) - { - // Next colliding allocation at the beginning of 1st vector found. Try to make it lost. - const VmaSuballocation& suballoc = suballocations1st[index1st]; - if(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE) - { - // No problem. - } - else - { - VMA_ASSERT(suballoc.hAllocation != VK_NULL_HANDLE); - if(suballoc.hAllocation->CanBecomeLost() && - suballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) - { - ++pAllocationRequest->itemsToMakeLostCount; - pAllocationRequest->sumItemSize += suballoc.size; - } - else - { - return false; - } - } - ++index1st; - } - - // Check next suballocations for BufferImageGranularity conflicts. - // If conflict exists, we must mark more allocations lost or fail. - if(allocSize % bufferImageGranularity || resultOffset % bufferImageGranularity) - { - while(index1st < suballocations1st.size()) - { - const VmaSuballocation& suballoc = suballocations1st[index1st]; - if(VmaBlocksOnSamePage(resultOffset, allocSize, suballoc.offset, bufferImageGranularity)) - { - if(suballoc.hAllocation != VK_NULL_HANDLE) - { - // Not checking actual VmaIsBufferImageGranularityConflict(allocType, suballoc.type). - if(suballoc.hAllocation->CanBecomeLost() && - suballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) - { - ++pAllocationRequest->itemsToMakeLostCount; - pAllocationRequest->sumItemSize += suballoc.size; - } - else - { - return false; - } - } - } - else - { - // Already on next page. - break; - } - ++index1st; - } - } + VmaDeviceMemoryBlock* const pBlock = m_Blocks[newBlockIndex]; + VMA_ASSERT(pBlock->m_pMetadata->GetSize() >= size); - // Special case: There is not enough room at the end for this allocation, even after making all from the 1st lost. - if(index1st == suballocations1st.size() && - resultOffset + allocSize + VMA_DEBUG_MARGIN > size) + res = AllocateFromBlock( + pBlock, size, alignment, createInfo.flags, createInfo.pUserData, suballocType, strategy, pAllocation); + if (res == VK_SUCCESS) { - // TODO: This is a known bug that it's not yet implemented and the allocation is failing. - VMA_DEBUG_LOG("Unsupported special case in custom pool with linear allocation algorithm used as ring buffer with allocations that can be lost."); + VMA_DEBUG_LOG(" Created new block #%u Size=%llu", pBlock->GetId(), newBlockSize); + IncrementallySortBlocks(); + return VK_SUCCESS; } - } - - // There is enough free space at the end after alignment. - if((index1st == suballocations1st.size() && resultOffset + allocSize + VMA_DEBUG_MARGIN <= size) || - (index1st < suballocations1st.size() && resultOffset + allocSize + VMA_DEBUG_MARGIN <= suballocations1st[index1st].offset)) - { - // Check next suballocations for BufferImageGranularity conflicts. - // If conflict exists, allocation cannot be made here. - if(allocSize % bufferImageGranularity || resultOffset % bufferImageGranularity) + else { - for (const auto nextSuballocIndex : c10::irange(index1st, suballocations1st.size())) { - const VmaSuballocation& nextSuballoc = suballocations1st[nextSuballocIndex]; - if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) - { - if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) - { - return false; - } - } - else - { - // Already on next page. - break; - } - } + // Allocation from new block failed, possibly due to VMA_DEBUG_MARGIN or alignment. + return VK_ERROR_OUT_OF_DEVICE_MEMORY; } - - // All tests passed: Success. - pAllocationRequest->offset = resultOffset; - pAllocationRequest->sumFreeSize = - (index1st < suballocations1st.size() ? suballocations1st[index1st].offset : size) - - resultBaseOffset - - pAllocationRequest->sumItemSize; - pAllocationRequest->type = VmaAllocationRequestType::EndOf2nd; - // pAllocationRequest->item, customData unused. - return true; } } - return false; + return VK_ERROR_OUT_OF_DEVICE_MEMORY; } -bool VmaBlockMetadata_Linear::MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest) +void VmaBlockVector::Free(const VmaAllocation hAllocation) { - if(pAllocationRequest->itemsToMakeLostCount == 0) + VmaDeviceMemoryBlock* pBlockToDelete = VMA_NULL; + + bool budgetExceeded = false; { - return true; + const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); + VmaBudget heapBudget = {}; + m_hAllocator->GetHeapBudgets(&heapBudget, heapIndex, 1); + budgetExceeded = heapBudget.usage >= heapBudget.budget; } - VMA_ASSERT(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER); - - // We always start from 1st. - SuballocationVectorType* suballocations = &AccessSuballocations1st(); - size_t index = m_1stNullItemsBeginCount; - size_t madeLostCount = 0; - while(madeLostCount < pAllocationRequest->itemsToMakeLostCount) + // Scope for lock. { - if(index == suballocations->size()) + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + + VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); + + if (IsCorruptionDetectionEnabled()) { - index = 0; - // If we get to the end of 1st, we wrap around to beginning of 2nd of 1st. - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) - { - suballocations = &AccessSuballocations2nd(); - } - // else: m_2ndVectorMode == SECOND_VECTOR_EMPTY: - // suballocations continues pointing at AccessSuballocations1st(). - VMA_ASSERT(!suballocations->empty()); + VkResult res = pBlock->ValidateMagicValueAfterAllocation(m_hAllocator, hAllocation->GetOffset(), hAllocation->GetSize()); + VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to validate magic value."); } - VmaSuballocation& suballoc = (*suballocations)[index]; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + + if (hAllocation->IsPersistentMap()) { - VMA_ASSERT(suballoc.hAllocation != VK_NULL_HANDLE); - VMA_ASSERT(suballoc.hAllocation->CanBecomeLost()); - if(suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) - { - suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - suballoc.hAllocation = VK_NULL_HANDLE; - m_SumFreeSize += suballoc.size; - if(suballocations == &AccessSuballocations1st()) - { - ++m_1stNullItemsMiddleCount; - } - else - { - ++m_2ndNullItemsCount; - } - ++madeLostCount; - } - else - { - return false; - } + pBlock->Unmap(m_hAllocator, 1); } - ++index; - } - CleanupAfterFree(); - //VMA_HEAVY_ASSERT(Validate()); // Already called by ClanupAfterFree(). - - return true; -} + const bool hadEmptyBlockBeforeFree = HasEmptyBlock(); + pBlock->m_pMetadata->Free(hAllocation->GetAllocHandle()); + pBlock->PostFree(m_hAllocator); + VMA_HEAVY_ASSERT(pBlock->Validate()); -uint32_t VmaBlockMetadata_Linear::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) -{ - uint32_t lostAllocationCount = 0; + VMA_DEBUG_LOG(" Freed from MemoryTypeIndex=%u", m_MemoryTypeIndex); - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - for(size_t i = m_1stNullItemsBeginCount, count = suballocations1st.size(); i < count; ++i) - { - VmaSuballocation& suballoc = suballocations1st[i]; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE && - suballoc.hAllocation->CanBecomeLost() && - suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + const bool canDeleteBlock = m_Blocks.size() > m_MinBlockCount; + // pBlock became empty after this deallocation. + if (pBlock->m_pMetadata->IsEmpty()) { - suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - suballoc.hAllocation = VK_NULL_HANDLE; - ++m_1stNullItemsMiddleCount; - m_SumFreeSize += suballoc.size; - ++lostAllocationCount; + // Already had empty block. We don't want to have two, so delete this one. + if ((hadEmptyBlockBeforeFree || budgetExceeded) && canDeleteBlock) + { + pBlockToDelete = pBlock; + Remove(pBlock); + } + // else: We now have one empty block - leave it. A hysteresis to avoid allocating whole block back and forth. } - } - - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - for(size_t i = 0, count = suballocations2nd.size(); i < count; ++i) - { - VmaSuballocation& suballoc = suballocations2nd[i]; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE && - suballoc.hAllocation->CanBecomeLost() && - suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + // pBlock didn't become empty, but we have another empty block - find and free that one. + // (This is optional, heuristics.) + else if (hadEmptyBlockBeforeFree && canDeleteBlock) { - suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - suballoc.hAllocation = VK_NULL_HANDLE; - ++m_2ndNullItemsCount; - m_SumFreeSize += suballoc.size; - ++lostAllocationCount; + VmaDeviceMemoryBlock* pLastBlock = m_Blocks.back(); + if (pLastBlock->m_pMetadata->IsEmpty()) + { + pBlockToDelete = pLastBlock; + m_Blocks.pop_back(); + } } + + IncrementallySortBlocks(); } - if(lostAllocationCount) + // Destruction of a free block. Deferred until this point, outside of mutex + // lock, for performance reason. + if (pBlockToDelete != VMA_NULL) { - CleanupAfterFree(); + VMA_DEBUG_LOG(" Deleted empty block #%u", pBlockToDelete->GetId()); + pBlockToDelete->Destroy(m_hAllocator); + vma_delete(m_hAllocator, pBlockToDelete); } - return lostAllocationCount; + m_hAllocator->m_Budget.RemoveAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), hAllocation->GetSize()); + m_hAllocator->m_AllocationObjectAllocator.Free(hAllocation); } -VkResult VmaBlockMetadata_Linear::CheckCorruption(const void* pBlockData) +VkDeviceSize VmaBlockVector::CalcMaxBlockSize() const { - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - for(size_t i = m_1stNullItemsBeginCount, count = suballocations1st.size(); i < count; ++i) + VkDeviceSize result = 0; + for (size_t i = m_Blocks.size(); i--; ) { - const VmaSuballocation& suballoc = suballocations1st[i]; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + result = VMA_MAX(result, m_Blocks[i]->m_pMetadata->GetSize()); + if (result >= m_PreferredBlockSize) { - if(!VmaValidateMagicValue(pBlockData, suballoc.offset - VMA_DEBUG_MARGIN)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } - if(!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } + break; } } + return result; +} - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - for(size_t i = 0, count = suballocations2nd.size(); i < count; ++i) +void VmaBlockVector::Remove(VmaDeviceMemoryBlock* pBlock) +{ + for (uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) { - const VmaSuballocation& suballoc = suballocations2nd[i]; - if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + if (m_Blocks[blockIndex] == pBlock) { - if(!VmaValidateMagicValue(pBlockData, suballoc.offset - VMA_DEBUG_MARGIN)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } - if(!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); - return VK_ERROR_VALIDATION_FAILED_EXT; - } + VmaVectorRemove(m_Blocks, blockIndex); + return; } } - - return VK_SUCCESS; + VMA_ASSERT(0); } -void VmaBlockMetadata_Linear::Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation) +void VmaBlockVector::IncrementallySortBlocks() { - const VmaSuballocation newSuballoc = { request.offset, allocSize, hAllocation, type }; - - switch(request.type) + if (!m_IncrementalSort) + return; + if (m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) { - case VmaAllocationRequestType::UpperAddress: - { - VMA_ASSERT(m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER && - "CRITICAL ERROR: Trying to use linear allocator as double stack while it was already used as ring buffer."); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - suballocations2nd.push_back(newSuballoc); - m_2ndVectorMode = SECOND_VECTOR_DOUBLE_STACK; - } - break; - case VmaAllocationRequestType::EndOf1st: - { - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - - VMA_ASSERT(suballocations1st.empty() || - request.offset >= suballocations1st.back().offset + suballocations1st.back().size); - // Check if it fits before the end of the block. - VMA_ASSERT(request.offset + allocSize <= GetSize()); - - suballocations1st.push_back(newSuballoc); - } - break; - case VmaAllocationRequestType::EndOf2nd: + // Bubble sort only until first swap. + for (size_t i = 1; i < m_Blocks.size(); ++i) { - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - // New allocation at the end of 2-part ring buffer, so before first allocation from 1st vector. - VMA_ASSERT(!suballocations1st.empty() && - request.offset + allocSize <= suballocations1st[m_1stNullItemsBeginCount].offset); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); - - switch(m_2ndVectorMode) + if (m_Blocks[i - 1]->m_pMetadata->GetSumFreeSize() > m_Blocks[i]->m_pMetadata->GetSumFreeSize()) { - case SECOND_VECTOR_EMPTY: - // First allocation from second part ring buffer. - VMA_ASSERT(suballocations2nd.empty()); - m_2ndVectorMode = SECOND_VECTOR_RING_BUFFER; - break; - case SECOND_VECTOR_RING_BUFFER: - // 2-part ring buffer is already started. - VMA_ASSERT(!suballocations2nd.empty()); - break; - case SECOND_VECTOR_DOUBLE_STACK: - VMA_ASSERT(0 && "CRITICAL ERROR: Trying to use linear allocator as ring buffer while it was already used as double stack."); - break; - default: - VMA_ASSERT(0); + VMA_SWAP(m_Blocks[i - 1], m_Blocks[i]); + return; } - - suballocations2nd.push_back(newSuballoc); } - break; - default: - VMA_ASSERT(0 && "CRITICAL INTERNAL ERROR."); } - - m_SumFreeSize -= newSuballoc.size; } -void VmaBlockMetadata_Linear::Free(const VmaAllocation allocation) +void VmaBlockVector::SortByFreeSize() { - FreeAtOffset(allocation->GetOffset()); + VMA_SORT(m_Blocks.begin(), m_Blocks.end(), + [](VmaDeviceMemoryBlock* b1, VmaDeviceMemoryBlock* b2) -> bool + { + return b1->m_pMetadata->GetSumFreeSize() < b2->m_pMetadata->GetSumFreeSize(); + }); } -void VmaBlockMetadata_Linear::FreeAtOffset(VkDeviceSize offset) +VkResult VmaBlockVector::AllocateFromBlock( + VmaDeviceMemoryBlock* pBlock, + VkDeviceSize size, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + uint32_t strategy, + VmaAllocation* pAllocation) { - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const bool isUpperAddress = (allocFlags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; - if(!suballocations1st.empty()) + VmaAllocationRequest currRequest = {}; + if (pBlock->m_pMetadata->CreateAllocationRequest( + size, + alignment, + isUpperAddress, + suballocType, + strategy, + &currRequest)) { - // First allocation: Mark it as next empty at the beginning. - VmaSuballocation& firstSuballoc = suballocations1st[m_1stNullItemsBeginCount]; - if(firstSuballoc.offset == offset) - { - firstSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; - firstSuballoc.hAllocation = VK_NULL_HANDLE; - m_SumFreeSize += firstSuballoc.size; - ++m_1stNullItemsBeginCount; - CleanupAfterFree(); - return; - } + return CommitAllocationRequest(currRequest, pBlock, alignment, allocFlags, pUserData, suballocType, pAllocation); } + return VK_ERROR_OUT_OF_DEVICE_MEMORY; +} - // Last allocation in 2-part ring buffer or top of upper stack (same logic). - if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER || - m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) +VkResult VmaBlockVector::CommitAllocationRequest( + VmaAllocationRequest& allocRequest, + VmaDeviceMemoryBlock* pBlock, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation) +{ + const bool mapped = (allocFlags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0; + const bool isUserDataString = (allocFlags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0; + const bool isMappingAllowed = (allocFlags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0; + + pBlock->PostAlloc(); + // Allocate from pCurrBlock. + if (mapped) { - VmaSuballocation& lastSuballoc = suballocations2nd.back(); - if(lastSuballoc.offset == offset) + VkResult res = pBlock->Map(m_hAllocator, 1, VMA_NULL); + if (res != VK_SUCCESS) { - m_SumFreeSize += lastSuballoc.size; - suballocations2nd.pop_back(); - CleanupAfterFree(); - return; + return res; } } - // Last allocation in 1st vector. - else if(m_2ndVectorMode == SECOND_VECTOR_EMPTY) - { - VmaSuballocation& lastSuballoc = suballocations1st.back(); - if(lastSuballoc.offset == offset) - { - m_SumFreeSize += lastSuballoc.size; - suballocations1st.pop_back(); - CleanupAfterFree(); - return; - } + + *pAllocation = m_hAllocator->m_AllocationObjectAllocator.Allocate(isMappingAllowed); + pBlock->m_pMetadata->Alloc(allocRequest, suballocType, *pAllocation); + (*pAllocation)->InitBlockAllocation( + pBlock, + allocRequest.allocHandle, + alignment, + allocRequest.size, // Not size, as actual allocation size may be larger than requested! + m_MemoryTypeIndex, + suballocType, + mapped); + VMA_HEAVY_ASSERT(pBlock->Validate()); + if (isUserDataString) + (*pAllocation)->SetName(m_hAllocator, (const char*)pUserData); + else + (*pAllocation)->SetUserData(m_hAllocator, pUserData); + m_hAllocator->m_Budget.AddAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), allocRequest.size); + if (VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + m_hAllocator->FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + } + if (IsCorruptionDetectionEnabled()) + { + VkResult res = pBlock->WriteMagicValueAfterAllocation(m_hAllocator, (*pAllocation)->GetOffset(), allocRequest.size); + VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to write magic value."); } + return VK_SUCCESS; +} - // Item from the middle of 1st vector. +VkResult VmaBlockVector::CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex) +{ + VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; + allocInfo.pNext = m_pMemoryAllocateNext; + allocInfo.memoryTypeIndex = m_MemoryTypeIndex; + allocInfo.allocationSize = blockSize; + +#if VMA_BUFFER_DEVICE_ADDRESS + // Every standalone block can potentially contain a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT - always enable the feature. + VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; + if (m_hAllocator->m_UseKhrBufferDeviceAddress) { - VmaSuballocation refSuballoc; - refSuballoc.offset = offset; - // Rest of members stays uninitialized intentionally for better performance. - SuballocationVectorType::iterator it = VmaBinaryFindSorted( - suballocations1st.begin() + m_1stNullItemsBeginCount, - suballocations1st.end(), - refSuballoc, - VmaSuballocationOffsetLess()); - if(it != suballocations1st.end()) - { - it->type = VMA_SUBALLOCATION_TYPE_FREE; - it->hAllocation = VK_NULL_HANDLE; - ++m_1stNullItemsMiddleCount; - m_SumFreeSize += it->size; - CleanupAfterFree(); - return; - } + allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; + VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); } +#endif // VMA_BUFFER_DEVICE_ADDRESS - if(m_2ndVectorMode != SECOND_VECTOR_EMPTY) +#if VMA_MEMORY_PRIORITY + VkMemoryPriorityAllocateInfoEXT priorityInfo = { VK_STRUCTURE_TYPE_MEMORY_PRIORITY_ALLOCATE_INFO_EXT }; + if (m_hAllocator->m_UseExtMemoryPriority) { - // Item from the middle of 2nd vector. - VmaSuballocation refSuballoc; - refSuballoc.offset = offset; - // Rest of members stays uninitialized intentionally for better performance. - SuballocationVectorType::iterator it = m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER ? - VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetLess()) : - VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetGreater()); - if(it != suballocations2nd.end()) + VMA_ASSERT(m_Priority >= 0.f && m_Priority <= 1.f); + priorityInfo.priority = m_Priority; + VmaPnextChainPushFront(&allocInfo, &priorityInfo); + } +#endif // VMA_MEMORY_PRIORITY + +#if VMA_EXTERNAL_MEMORY + // Attach VkExportMemoryAllocateInfoKHR if necessary. + VkExportMemoryAllocateInfoKHR exportMemoryAllocInfo = { VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR }; + exportMemoryAllocInfo.handleTypes = m_hAllocator->GetExternalMemoryHandleTypeFlags(m_MemoryTypeIndex); + if (exportMemoryAllocInfo.handleTypes != 0) + { + VmaPnextChainPushFront(&allocInfo, &exportMemoryAllocInfo); + } +#endif // VMA_EXTERNAL_MEMORY + + VkDeviceMemory mem = VK_NULL_HANDLE; + VkResult res = m_hAllocator->AllocateVulkanMemory(&allocInfo, &mem); + if (res < 0) + { + return res; + } + + // New VkDeviceMemory successfully created. + + // Create new Allocation for it. + VmaDeviceMemoryBlock* const pBlock = vma_new(m_hAllocator, VmaDeviceMemoryBlock)(m_hAllocator); + pBlock->Init( + m_hAllocator, + m_hParentPool, + m_MemoryTypeIndex, + mem, + allocInfo.allocationSize, + m_NextBlockId++, + m_Algorithm, + m_BufferImageGranularity); + + m_Blocks.push_back(pBlock); + if (pNewBlockIndex != VMA_NULL) + { + *pNewBlockIndex = m_Blocks.size() - 1; + } + + return VK_SUCCESS; +} + +bool VmaBlockVector::HasEmptyBlock() +{ + for (size_t index = 0, count = m_Blocks.size(); index < count; ++index) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[index]; + if (pBlock->m_pMetadata->IsEmpty()) { - it->type = VMA_SUBALLOCATION_TYPE_FREE; - it->hAllocation = VK_NULL_HANDLE; - ++m_2ndNullItemsCount; - m_SumFreeSize += it->size; - CleanupAfterFree(); - return; + return true; } } + return false; +} - VMA_ASSERT(0 && "Allocation to free not found in linear allocator!"); +#if VMA_STATS_STRING_ENABLED +void VmaBlockVector::PrintDetailedMap(class VmaJsonWriter& json) +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + + json.BeginObject(); + for (size_t i = 0; i < m_Blocks.size(); ++i) + { + json.BeginString(); + json.ContinueString(m_Blocks[i]->GetId()); + json.EndString(); + + json.BeginObject(); + json.WriteString("MapRefCount"); + json.WriteNumber(m_Blocks[i]->GetMapRefCount()); + + m_Blocks[i]->m_pMetadata->PrintDetailedMap(json); + json.EndObject(); + } + json.EndObject(); } +#endif // VMA_STATS_STRING_ENABLED -bool VmaBlockMetadata_Linear::ShouldCompact1st() const +VkResult VmaBlockVector::CheckCorruption() { - const size_t nullItemCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; - const size_t suballocCount = AccessSuballocations1st().size(); - return suballocCount > 32 && nullItemCount * 2 >= (suballocCount - nullItemCount) * 3; + if (!IsCorruptionDetectionEnabled()) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + for (uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VkResult res = pBlock->CheckCorruption(m_hAllocator); + if (res != VK_SUCCESS) + { + return res; + } + } + return VK_SUCCESS; } -void VmaBlockMetadata_Linear::CleanupAfterFree() +#endif // _VMA_BLOCK_VECTOR_FUNCTIONS + +#ifndef _VMA_DEFRAGMENTATION_CONTEXT_FUNCTIONS +VmaDefragmentationContext_T::VmaDefragmentationContext_T( + VmaAllocator hAllocator, + const VmaDefragmentationInfo& info) + : m_MaxPassBytes(info.maxBytesPerPass == 0 ? VK_WHOLE_SIZE : info.maxBytesPerPass), + m_MaxPassAllocations(info.maxAllocationsPerPass == 0 ? UINT32_MAX : info.maxAllocationsPerPass), + m_MoveAllocator(hAllocator->GetAllocationCallbacks()), + m_Moves(m_MoveAllocator) { - SuballocationVectorType& suballocations1st = AccessSuballocations1st(); - SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + m_Algorithm = info.flags & VMA_DEFRAGMENTATION_FLAG_ALGORITHM_MASK; - if(IsEmpty()) + if (info.pool != VMA_NULL) { - suballocations1st.clear(); - suballocations2nd.clear(); - m_1stNullItemsBeginCount = 0; - m_1stNullItemsMiddleCount = 0; - m_2ndNullItemsCount = 0; - m_2ndVectorMode = SECOND_VECTOR_EMPTY; + m_BlockVectorCount = 1; + m_PoolBlockVector = &info.pool->m_BlockVector; + m_pBlockVectors = &m_PoolBlockVector; + m_PoolBlockVector->SetIncrementalSort(false); + m_PoolBlockVector->SortByFreeSize(); } else { - const size_t suballoc1stCount = suballocations1st.size(); - const size_t nullItem1stCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; - VMA_ASSERT(nullItem1stCount <= suballoc1stCount); - - // Find more null items at the beginning of 1st vector. - while(m_1stNullItemsBeginCount < suballoc1stCount && - suballocations1st[m_1stNullItemsBeginCount].hAllocation == VK_NULL_HANDLE) + m_BlockVectorCount = hAllocator->GetMemoryTypeCount(); + m_PoolBlockVector = VMA_NULL; + m_pBlockVectors = hAllocator->m_pBlockVectors; + for (uint32_t i = 0; i < m_BlockVectorCount; ++i) { - ++m_1stNullItemsBeginCount; - --m_1stNullItemsMiddleCount; + VmaBlockVector* vector = m_pBlockVectors[i]; + if (vector != VMA_NULL) + { + vector->SetIncrementalSort(false); + vector->SortByFreeSize(); + } } - - // Find more null items at the end of 1st vector. - while(m_1stNullItemsMiddleCount > 0 && - suballocations1st.back().hAllocation == VK_NULL_HANDLE) + } + + switch (m_Algorithm) + { + case 0: // Default algorithm + m_Algorithm = VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT; + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT: + { + m_AlgorithmState = vma_new_array(hAllocator, StateBalanced, m_BlockVectorCount); + break; + } + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT: + { + if (hAllocator->GetBufferImageGranularity() > 1) { - --m_1stNullItemsMiddleCount; - suballocations1st.pop_back(); + m_AlgorithmState = vma_new_array(hAllocator, StateExtensive, m_BlockVectorCount); } + break; + } + } +} - // Find more null items at the end of 2nd vector. - while(m_2ndNullItemsCount > 0 && - suballocations2nd.back().hAllocation == VK_NULL_HANDLE) +VmaDefragmentationContext_T::~VmaDefragmentationContext_T() +{ + if (m_PoolBlockVector != VMA_NULL) + { + m_PoolBlockVector->SetIncrementalSort(true); + } + else + { + for (uint32_t i = 0; i < m_BlockVectorCount; ++i) { - --m_2ndNullItemsCount; - suballocations2nd.pop_back(); + VmaBlockVector* vector = m_pBlockVectors[i]; + if (vector != VMA_NULL) + vector->SetIncrementalSort(true); } + } - // Find more null items at the beginning of 2nd vector. - while(m_2ndNullItemsCount > 0 && - suballocations2nd[0].hAllocation == VK_NULL_HANDLE) + if (m_AlgorithmState) + { + switch (m_Algorithm) { - --m_2ndNullItemsCount; - VmaVectorRemove(suballocations2nd, 0); + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT: + vma_delete_array(m_MoveAllocator.m_pCallbacks, reinterpret_cast(m_AlgorithmState), m_BlockVectorCount); + break; + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT: + vma_delete_array(m_MoveAllocator.m_pCallbacks, reinterpret_cast(m_AlgorithmState), m_BlockVectorCount); + break; + default: + VMA_ASSERT(0); } + } +} + +VkResult VmaDefragmentationContext_T::DefragmentPassBegin(VmaDefragmentationPassMoveInfo& moveInfo) +{ + if (m_PoolBlockVector != VMA_NULL) + { + VmaMutexLockWrite lock(m_PoolBlockVector->GetMutex(), m_PoolBlockVector->GetAllocator()->m_UseMutex); - if(ShouldCompact1st()) + if (m_PoolBlockVector->GetBlockCount() > 1) + ComputeDefragmentation(*m_PoolBlockVector, 0); + else if (m_PoolBlockVector->GetBlockCount() == 1) + ReallocWithinBlock(*m_PoolBlockVector, m_PoolBlockVector->GetBlock(0)); + } + else + { + for (uint32_t i = 0; i < m_BlockVectorCount; ++i) { - const size_t nonNullItemCount = suballoc1stCount - nullItem1stCount; - size_t srcIndex = m_1stNullItemsBeginCount; - for (const auto dstIndex : c10::irange(nonNullItemCount)) { - while(suballocations1st[srcIndex].hAllocation == VK_NULL_HANDLE) + if (m_pBlockVectors[i] != VMA_NULL) + { + VmaMutexLockWrite lock(m_pBlockVectors[i]->GetMutex(), m_pBlockVectors[i]->GetAllocator()->m_UseMutex); + + if (m_pBlockVectors[i]->GetBlockCount() > 1) { - ++srcIndex; + if (ComputeDefragmentation(*m_pBlockVectors[i], i)) + break; } - if(dstIndex != srcIndex) + else if (m_pBlockVectors[i]->GetBlockCount() == 1) { - suballocations1st[dstIndex] = suballocations1st[srcIndex]; + if (ReallocWithinBlock(*m_pBlockVectors[i], m_pBlockVectors[i]->GetBlock(0))) + break; } - ++srcIndex; } - suballocations1st.resize(nonNullItemCount); - m_1stNullItemsBeginCount = 0; - m_1stNullItemsMiddleCount = 0; } + } - // 2nd vector became empty. - if(suballocations2nd.empty()) + moveInfo.moveCount = static_cast(m_Moves.size()); + if (moveInfo.moveCount > 0) + { + moveInfo.pMoves = m_Moves.data(); + return VK_INCOMPLETE; + } + + moveInfo.pMoves = VMA_NULL; + return VK_SUCCESS; +} + +VkResult VmaDefragmentationContext_T::DefragmentPassEnd(VmaDefragmentationPassMoveInfo& moveInfo) +{ + VMA_ASSERT(moveInfo.moveCount > 0 ? moveInfo.pMoves != VMA_NULL : true); + + VkResult result = VK_SUCCESS; + VmaStlAllocator blockAllocator(m_MoveAllocator.m_pCallbacks); + VmaVector> immovableBlocks(blockAllocator); + VmaVector> mappedBlocks(blockAllocator); + + VmaAllocator allocator = VMA_NULL; + for (uint32_t i = 0; i < moveInfo.moveCount; ++i) + { + VmaDefragmentationMove& move = moveInfo.pMoves[i]; + size_t prevCount = 0, currentCount = 0; + VkDeviceSize freedBlockSize = 0; + + uint32_t vectorIndex; + VmaBlockVector* vector; + if (m_PoolBlockVector != VMA_NULL) { - m_2ndVectorMode = SECOND_VECTOR_EMPTY; + vectorIndex = 0; + vector = m_PoolBlockVector; } + else + { + vectorIndex = move.srcAllocation->GetMemoryTypeIndex(); + vector = m_pBlockVectors[vectorIndex]; + VMA_ASSERT(vector != VMA_NULL); + } + + switch (move.operation) + { + case VMA_DEFRAGMENTATION_MOVE_OPERATION_COPY: + { + uint8_t mapCount = move.srcAllocation->SwapBlockAllocation(vector->m_hAllocator, move.dstTmpAllocation); + if (mapCount > 0) + { + allocator = vector->m_hAllocator; + VmaDeviceMemoryBlock* newMapBlock = move.srcAllocation->GetBlock(); + bool notPresent = true; + for (FragmentedBlock& block : mappedBlocks) + { + if (block.block == newMapBlock) + { + notPresent = false; + block.data += mapCount; + break; + } + } + if (notPresent) + mappedBlocks.push_back({ mapCount, newMapBlock }); + } - // 1st vector became empty. - if(suballocations1st.size() - m_1stNullItemsBeginCount == 0) + // Scope for locks, Free have it's own lock + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + prevCount = vector->GetBlockCount(); + freedBlockSize = move.dstTmpAllocation->GetBlock()->m_pMetadata->GetSize(); + } + vector->Free(move.dstTmpAllocation); + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + currentCount = vector->GetBlockCount(); + } + + result = VK_INCOMPLETE; + break; + } + case VMA_DEFRAGMENTATION_MOVE_OPERATION_IGNORE: { - suballocations1st.clear(); - m_1stNullItemsBeginCount = 0; + m_PassStats.bytesMoved -= move.srcAllocation->GetSize(); + --m_PassStats.allocationsMoved; + vector->Free(move.dstTmpAllocation); - if(!suballocations2nd.empty() && m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + VmaDeviceMemoryBlock* newBlock = move.srcAllocation->GetBlock(); + bool notPresent = true; + for (const FragmentedBlock& block : immovableBlocks) { - // Swap 1st with 2nd. Now 2nd is empty. - m_2ndVectorMode = SECOND_VECTOR_EMPTY; - m_1stNullItemsMiddleCount = m_2ndNullItemsCount; - while(m_1stNullItemsBeginCount < suballocations2nd.size() && - suballocations2nd[m_1stNullItemsBeginCount].hAllocation == VK_NULL_HANDLE) + if (block.block == newBlock) { - ++m_1stNullItemsBeginCount; - --m_1stNullItemsMiddleCount; + notPresent = false; + break; } - m_2ndNullItemsCount = 0; - m_1stVectorIndex ^= 1; } + if (notPresent) + immovableBlocks.push_back({ vectorIndex, newBlock }); + break; } - } - - VMA_HEAVY_ASSERT(Validate()); -} + case VMA_DEFRAGMENTATION_MOVE_OPERATION_DESTROY: + { + m_PassStats.bytesMoved -= move.srcAllocation->GetSize(); + --m_PassStats.allocationsMoved; + // Scope for locks, Free have it's own lock + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + prevCount = vector->GetBlockCount(); + freedBlockSize = move.srcAllocation->GetBlock()->m_pMetadata->GetSize(); + } + vector->Free(move.srcAllocation); + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + currentCount = vector->GetBlockCount(); + } + freedBlockSize *= prevCount - currentCount; + VkDeviceSize dstBlockSize; + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + dstBlockSize = move.dstTmpAllocation->GetBlock()->m_pMetadata->GetSize(); + } + vector->Free(move.dstTmpAllocation); + { + VmaMutexLockRead lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + freedBlockSize += dstBlockSize * (currentCount - vector->GetBlockCount()); + currentCount = vector->GetBlockCount(); + } -//////////////////////////////////////////////////////////////////////////////// -// class VmaBlockMetadata_Buddy - -VmaBlockMetadata_Buddy::VmaBlockMetadata_Buddy(VmaAllocator hAllocator) : - VmaBlockMetadata(hAllocator), - m_Root(VMA_NULL), - m_AllocationCount(0), - m_FreeCount(1), - m_SumFreeSize(0) -{ - memset(m_FreeList, 0, sizeof(m_FreeList)); -} - -VmaBlockMetadata_Buddy::~VmaBlockMetadata_Buddy() -{ - DeleteNode(m_Root); -} - -void VmaBlockMetadata_Buddy::Init(VkDeviceSize size) -{ - VmaBlockMetadata::Init(size); + result = VK_INCOMPLETE; + break; + } + default: + VMA_ASSERT(0); + } - m_UsableSize = VmaPrevPow2(size); - m_SumFreeSize = m_UsableSize; + if (prevCount > currentCount) + { + size_t freedBlocks = prevCount - currentCount; + m_PassStats.deviceMemoryBlocksFreed += static_cast(freedBlocks); + m_PassStats.bytesFreed += freedBlockSize; + } - // Calculate m_LevelCount. - m_LevelCount = 1; - while(m_LevelCount < MAX_LEVELS && - LevelToNodeSize(m_LevelCount) >= MIN_NODE_SIZE) - { - ++m_LevelCount; + switch (m_Algorithm) + { + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT: + { + if (m_AlgorithmState != VMA_NULL) + { + // Avoid unnecessary tries to allocate when new free block is avaiable + StateExtensive& state = reinterpret_cast(m_AlgorithmState)[vectorIndex]; + if (state.firstFreeBlock != SIZE_MAX) + { + const size_t diff = prevCount - currentCount; + if (state.firstFreeBlock >= diff) + { + state.firstFreeBlock -= diff; + if (state.firstFreeBlock != 0) + state.firstFreeBlock -= vector->GetBlock(state.firstFreeBlock - 1)->m_pMetadata->IsEmpty(); + } + else + state.firstFreeBlock = 0; + } + } + } + } } + moveInfo.moveCount = 0; + moveInfo.pMoves = VMA_NULL; + m_Moves.clear(); - Node* rootNode = vma_new(GetAllocationCallbacks(), Node)(); - rootNode->offset = 0; - rootNode->type = Node::TYPE_FREE; - rootNode->parent = VMA_NULL; - rootNode->buddy = VMA_NULL; - - m_Root = rootNode; - AddToFreeListFront(0, rootNode); -} + // Update stats + m_GlobalStats.allocationsMoved += m_PassStats.allocationsMoved; + m_GlobalStats.bytesFreed += m_PassStats.bytesFreed; + m_GlobalStats.bytesMoved += m_PassStats.bytesMoved; + m_GlobalStats.deviceMemoryBlocksFreed += m_PassStats.deviceMemoryBlocksFreed; + m_PassStats = { 0 }; -bool VmaBlockMetadata_Buddy::Validate() const -{ - // Validate tree. - ValidationContext ctx; - if(!ValidateNode(ctx, VMA_NULL, m_Root, 0, LevelToNodeSize(0))) + // Move blocks with immovable allocations according to algorithm + if (immovableBlocks.size() > 0) { - VMA_VALIDATE(false && "ValidateNode failed."); - } - VMA_VALIDATE(m_AllocationCount == ctx.calculatedAllocationCount); - VMA_VALIDATE(m_SumFreeSize == ctx.calculatedSumFreeSize); - - // Validate free node lists. - for (const auto level : c10::irange(m_LevelCount)) { - VMA_VALIDATE(m_FreeList[level].front == VMA_NULL || - m_FreeList[level].front->free.prev == VMA_NULL); - - for(Node* node = m_FreeList[level].front; - node != VMA_NULL; - node = node->free.next) + switch (m_Algorithm) { - VMA_VALIDATE(node->type == Node::TYPE_FREE); - - if(node->free.next == VMA_NULL) + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT: + { + if (m_AlgorithmState != VMA_NULL) { - VMA_VALIDATE(m_FreeList[level].back == node); + bool swapped = false; + // Move to the start of free blocks range + for (const FragmentedBlock& block : immovableBlocks) + { + StateExtensive& state = reinterpret_cast(m_AlgorithmState)[block.data]; + if (state.operation != StateExtensive::Operation::Cleanup) + { + VmaBlockVector* vector = m_pBlockVectors[block.data]; + VmaMutexLockWrite lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + + for (size_t i = 0, count = vector->GetBlockCount() - m_ImmovableBlockCount; i < count; ++i) + { + if (vector->GetBlock(i) == block.block) + { + VMA_SWAP(vector->m_Blocks[i], vector->m_Blocks[vector->GetBlockCount() - ++m_ImmovableBlockCount]); + if (state.firstFreeBlock != SIZE_MAX) + { + if (i + 1 < state.firstFreeBlock) + { + if (state.firstFreeBlock > 1) + VMA_SWAP(vector->m_Blocks[i], vector->m_Blocks[--state.firstFreeBlock]); + else + --state.firstFreeBlock; + } + } + swapped = true; + break; + } + } + } + } + if (swapped) + result = VK_INCOMPLETE; + break; } - else + } + default: + { + // Move to the begining + for (const FragmentedBlock& block : immovableBlocks) { - VMA_VALIDATE(node->free.next->free.prev == node); + VmaBlockVector* vector = m_pBlockVectors[block.data]; + VmaMutexLockWrite lock(vector->GetMutex(), vector->GetAllocator()->m_UseMutex); + + for (size_t i = m_ImmovableBlockCount; i < vector->GetBlockCount(); ++i) + { + if (vector->GetBlock(i) == block.block) + { + VMA_SWAP(vector->m_Blocks[i], vector->m_Blocks[m_ImmovableBlockCount++]); + break; + } + } } + break; + } } } - // Validate that free lists ar higher levels are empty. - for (const auto level : c10::irange(m_LevelCount, MAX_LEVELS)) { - VMA_VALIDATE(m_FreeList[level].front == VMA_NULL && m_FreeList[level].back == VMA_NULL); + // Bulk-map destination blocks + for (const FragmentedBlock& block : mappedBlocks) + { + VkResult res = block.block->Map(allocator, block.data, VMA_NULL); + VMA_ASSERT(res == VK_SUCCESS); } - - return true; + return result; } -VkDeviceSize VmaBlockMetadata_Buddy::GetUnusedRangeSizeMax() const +bool VmaDefragmentationContext_T::ComputeDefragmentation(VmaBlockVector& vector, size_t index) { - for (const auto level : c10::irange(m_LevelCount)) { - if(m_FreeList[level].front != VMA_NULL) - { - return LevelToNodeSize(level); - } + switch (m_Algorithm) + { + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FAST_BIT: + return ComputeDefragmentation_Fast(vector); + default: + VMA_ASSERT(0); + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_BALANCED_BIT: + return ComputeDefragmentation_Balanced(vector, index, true); + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FULL_BIT: + return ComputeDefragmentation_Full(vector); + case VMA_DEFRAGMENTATION_FLAG_ALGORITHM_EXTENSIVE_BIT: + return ComputeDefragmentation_Extensive(vector, index); } - return 0; } -void VmaBlockMetadata_Buddy::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +VmaDefragmentationContext_T::MoveAllocationData VmaDefragmentationContext_T::GetMoveData( + VmaAllocHandle handle, VmaBlockMetadata* metadata) { - const VkDeviceSize unusableSize = GetUnusableSize(); - - outInfo.blockCount = 1; + MoveAllocationData moveData; + moveData.move.srcAllocation = (VmaAllocation)metadata->GetAllocationUserData(handle); + moveData.size = moveData.move.srcAllocation->GetSize(); + moveData.alignment = moveData.move.srcAllocation->GetAlignment(); + moveData.type = moveData.move.srcAllocation->GetSuballocationType(); + moveData.flags = 0; - outInfo.allocationCount = outInfo.unusedRangeCount = 0; - outInfo.usedBytes = outInfo.unusedBytes = 0; + if (moveData.move.srcAllocation->IsPersistentMap()) + moveData.flags |= VMA_ALLOCATION_CREATE_MAPPED_BIT; + if (moveData.move.srcAllocation->IsMappingAllowed()) + moveData.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT; - outInfo.allocationSizeMax = outInfo.unusedRangeSizeMax = 0; - outInfo.allocationSizeMin = outInfo.unusedRangeSizeMin = UINT64_MAX; - outInfo.allocationSizeAvg = outInfo.unusedRangeSizeAvg = 0; // Unused. - - CalcAllocationStatInfoNode(outInfo, m_Root, LevelToNodeSize(0)); + return moveData; +} - if(unusableSize > 0) +VmaDefragmentationContext_T::CounterStatus VmaDefragmentationContext_T::CheckCounters(VkDeviceSize bytes) +{ + // Ignore allocation if will exceed max size for copy + if (m_PassStats.bytesMoved + bytes > m_MaxPassBytes) { - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusableSize; - outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, unusableSize); - outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusableSize); + if (++m_IgnoredAllocs < MAX_ALLOCS_TO_IGNORE) + return CounterStatus::Ignore; + else + return CounterStatus::End; } + return CounterStatus::Pass; } -void VmaBlockMetadata_Buddy::AddPoolStats(VmaPoolStats& inoutStats) const +bool VmaDefragmentationContext_T::IncrementCounters(VkDeviceSize bytes) { - const VkDeviceSize unusableSize = GetUnusableSize(); - - inoutStats.size += GetSize(); - inoutStats.unusedSize += m_SumFreeSize + unusableSize; - inoutStats.allocationCount += m_AllocationCount; - inoutStats.unusedRangeCount += m_FreeCount; - inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, GetUnusedRangeSizeMax()); - - if(unusableSize > 0) + m_PassStats.bytesMoved += bytes; + // Early return when max found + if (++m_PassStats.allocationsMoved >= m_MaxPassAllocations || m_PassStats.bytesMoved >= m_MaxPassBytes) { - ++inoutStats.unusedRangeCount; - // Not updating inoutStats.unusedRangeSizeMax with unusableSize because this space is not available for allocations. + VMA_ASSERT(m_PassStats.allocationsMoved == m_MaxPassAllocations || + m_PassStats.bytesMoved == m_MaxPassBytes && "Exceeded maximal pass threshold!"); + return true; } + return false; } -#if VMA_STATS_STRING_ENABLED - -void VmaBlockMetadata_Buddy::PrintDetailedMap(class VmaJsonWriter& json) const +bool VmaDefragmentationContext_T::ReallocWithinBlock(VmaBlockVector& vector, VmaDeviceMemoryBlock* block) { - // TODO optimize - VmaStatInfo stat; - CalcAllocationStatInfo(stat); - - PrintDetailedMap_Begin( - json, - stat.unusedBytes, - stat.allocationCount, - stat.unusedRangeCount); + VmaBlockMetadata* metadata = block->m_pMetadata; - PrintDetailedMapNode(json, m_Root, LevelToNodeSize(0)); - - const VkDeviceSize unusableSize = GetUnusableSize(); - if(unusableSize > 0) + for (VmaAllocHandle handle = metadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = metadata->GetNextAllocation(handle)) { - PrintDetailedMap_UnusedRange(json, - m_UsableSize, // offset - unusableSize); // size + MoveAllocationData moveData = GetMoveData(handle, metadata); + // Ignore newly created allocations by defragmentation algorithm + if (moveData.move.srcAllocation->GetUserData() == this) + continue; + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) + { + case CounterStatus::Ignore: + continue; + case CounterStatus::End: + return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; + } + + VkDeviceSize offset = moveData.move.srcAllocation->GetOffset(); + if (offset != 0 && metadata->GetSumFreeSize() >= moveData.size) + { + VmaAllocationRequest request = {}; + if (metadata->CreateAllocationRequest( + moveData.size, + moveData.alignment, + false, + moveData.type, + VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT, + &request)) + { + if (metadata->GetAllocationOffset(request.allocHandle) < offset) + { + if (vector.CommitAllocationRequest( + request, + block, + moveData.alignment, + moveData.flags, + this, + moveData.type, + &moveData.move.dstTmpAllocation) == VK_SUCCESS) + { + m_Moves.push_back(moveData.move); + if (IncrementCounters(moveData.size)) + return true; + } + } + } + } } - - PrintDetailedMap_End(json); + return false; } -#endif // #if VMA_STATS_STRING_ENABLED - -bool VmaBlockMetadata_Buddy::CreateAllocationRequest( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VkDeviceSize bufferImageGranularity, - VkDeviceSize allocSize, - VkDeviceSize allocAlignment, - bool upperAddress, - VmaSuballocationType allocType, - bool canMakeOtherLost, - uint32_t strategy, - VmaAllocationRequest* pAllocationRequest) +bool VmaDefragmentationContext_T::AllocInOtherBlock(size_t start, size_t end, MoveAllocationData& data, VmaBlockVector& vector) { - VMA_ASSERT(!upperAddress && "VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT can be used only with linear algorithm."); - - // Simple way to respect bufferImageGranularity. May be optimized some day. - // Whenever it might be an OPTIMAL image... - if(allocType == VMA_SUBALLOCATION_TYPE_UNKNOWN || - allocType == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || - allocType == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL) + for (; start < end; ++start) { - allocAlignment = VMA_MAX(allocAlignment, bufferImageGranularity); - allocSize = VMA_MAX(allocSize, bufferImageGranularity); + VmaDeviceMemoryBlock* dstBlock = vector.GetBlock(start); + if (dstBlock->m_pMetadata->GetSumFreeSize() >= data.size) + { + if (vector.AllocateFromBlock(dstBlock, + data.size, + data.alignment, + data.flags, + this, + data.type, + 0, + &data.move.dstTmpAllocation) == VK_SUCCESS) + { + m_Moves.push_back(data.move); + if (IncrementCounters(data.size)) + return true; + break; + } + } } + return false; +} - if(allocSize > m_UsableSize) - { - return false; - } +bool VmaDefragmentationContext_T::ComputeDefragmentation_Fast(VmaBlockVector& vector) +{ + // Move only between blocks - const uint32_t targetLevel = AllocSizeToLevel(allocSize); - for(uint32_t level = targetLevel + 1; level--; ) + // Go through allocations in last blocks and try to fit them inside first ones + for (size_t i = vector.GetBlockCount() - 1; i > m_ImmovableBlockCount; --i) { - for(Node* freeNode = m_FreeList[level].front; - freeNode != VMA_NULL; - freeNode = freeNode->free.next) + VmaBlockMetadata* metadata = vector.GetBlock(i)->m_pMetadata; + + for (VmaAllocHandle handle = metadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = metadata->GetNextAllocation(handle)) { - if(freeNode->offset % allocAlignment == 0) + MoveAllocationData moveData = GetMoveData(handle, metadata); + // Ignore newly created allocations by defragmentation algorithm + if (moveData.move.srcAllocation->GetUserData() == this) + continue; + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) { - pAllocationRequest->type = VmaAllocationRequestType::Normal; - pAllocationRequest->offset = freeNode->offset; - pAllocationRequest->sumFreeSize = LevelToNodeSize(level); - pAllocationRequest->sumItemSize = 0; - pAllocationRequest->itemsToMakeLostCount = 0; - pAllocationRequest->customData = (void*)(uintptr_t)level; + case CounterStatus::Ignore: + continue; + case CounterStatus::End: return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; } + + // Check all previous blocks for free space + if (AllocInOtherBlock(0, i, moveData, vector)) + return true; } } - return false; } -bool VmaBlockMetadata_Buddy::MakeRequestedAllocationsLost( - uint32_t currentFrameIndex, - uint32_t frameInUseCount, - VmaAllocationRequest* pAllocationRequest) +bool VmaDefragmentationContext_T::ComputeDefragmentation_Balanced(VmaBlockVector& vector, size_t index, bool update) { - /* - Lost allocations are not supported in buddy allocator at the moment. - Support might be added in the future. - */ - return pAllocationRequest->itemsToMakeLostCount == 0; -} + // Go over every allocation and try to fit it in previous blocks at lowest offsets, + // if not possible: realloc within single block to minimize offset (exclude offset == 0), + // but only if there are noticable gaps between them (some heuristic, ex. average size of allocation in block) + VMA_ASSERT(m_AlgorithmState != VMA_NULL); -uint32_t VmaBlockMetadata_Buddy::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) -{ - /* - Lost allocations are not supported in buddy allocator at the moment. - Support might be added in the future. - */ - return 0; -} + StateBalanced& vectorState = reinterpret_cast(m_AlgorithmState)[index]; + if (update && vectorState.avgAllocSize == UINT64_MAX) + UpdateVectorStatistics(vector, vectorState); -void VmaBlockMetadata_Buddy::Alloc( - const VmaAllocationRequest& request, - VmaSuballocationType type, - VkDeviceSize allocSize, - VmaAllocation hAllocation) -{ - VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); + const size_t startMoveCount = m_Moves.size(); + VkDeviceSize minimalFreeRegion = vectorState.avgFreeSize / 2; + for (size_t i = vector.GetBlockCount() - 1; i > m_ImmovableBlockCount; --i) + { + VmaDeviceMemoryBlock* block = vector.GetBlock(i); + VmaBlockMetadata* metadata = block->m_pMetadata; + VkDeviceSize prevFreeRegionSize = 0; - const uint32_t targetLevel = AllocSizeToLevel(allocSize); - uint32_t currLevel = (uint32_t)(uintptr_t)request.customData; + for (VmaAllocHandle handle = metadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = metadata->GetNextAllocation(handle)) + { + MoveAllocationData moveData = GetMoveData(handle, metadata); + // Ignore newly created allocations by defragmentation algorithm + if (moveData.move.srcAllocation->GetUserData() == this) + continue; + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) + { + case CounterStatus::Ignore: + continue; + case CounterStatus::End: + return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; + } - Node* currNode = m_FreeList[currLevel].front; - VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); - while(currNode->offset != request.offset) + // Check all previous blocks for free space + const size_t prevMoveCount = m_Moves.size(); + if (AllocInOtherBlock(0, i, moveData, vector)) + return true; + + VkDeviceSize nextFreeRegionSize = metadata->GetNextFreeRegionSize(handle); + // If no room found then realloc within block for lower offset + VkDeviceSize offset = moveData.move.srcAllocation->GetOffset(); + if (prevMoveCount == m_Moves.size() && offset != 0 && metadata->GetSumFreeSize() >= moveData.size) + { + // Check if realloc will make sense + if (prevFreeRegionSize >= minimalFreeRegion || + nextFreeRegionSize >= minimalFreeRegion || + moveData.size <= vectorState.avgFreeSize || + moveData.size <= vectorState.avgAllocSize) + { + VmaAllocationRequest request = {}; + if (metadata->CreateAllocationRequest( + moveData.size, + moveData.alignment, + false, + moveData.type, + VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT, + &request)) + { + if (metadata->GetAllocationOffset(request.allocHandle) < offset) + { + if (vector.CommitAllocationRequest( + request, + block, + moveData.alignment, + moveData.flags, + this, + moveData.type, + &moveData.move.dstTmpAllocation) == VK_SUCCESS) + { + m_Moves.push_back(moveData.move); + if (IncrementCounters(moveData.size)) + return true; + } + } + } + } + } + prevFreeRegionSize = nextFreeRegionSize; + } + } + + // No moves perfomed, update statistics to current vector state + if (startMoveCount == m_Moves.size() && !update) { - currNode = currNode->free.next; - VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); + vectorState.avgAllocSize = UINT64_MAX; + return ComputeDefragmentation_Balanced(vector, index, false); } + return false; +} - // Go down, splitting free nodes. - while(currLevel < targetLevel) +bool VmaDefragmentationContext_T::ComputeDefragmentation_Full(VmaBlockVector& vector) +{ + // Go over every allocation and try to fit it in previous blocks at lowest offsets, + // if not possible: realloc within single block to minimize offset (exclude offset == 0) + + for (size_t i = vector.GetBlockCount() - 1; i > m_ImmovableBlockCount; --i) { - // currNode is already first free node at currLevel. - // Remove it from list of free nodes at this currLevel. - RemoveFromFreeList(currLevel, currNode); + VmaDeviceMemoryBlock* block = vector.GetBlock(i); + VmaBlockMetadata* metadata = block->m_pMetadata; - const uint32_t childrenLevel = currLevel + 1; + for (VmaAllocHandle handle = metadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = metadata->GetNextAllocation(handle)) + { + MoveAllocationData moveData = GetMoveData(handle, metadata); + // Ignore newly created allocations by defragmentation algorithm + if (moveData.move.srcAllocation->GetUserData() == this) + continue; + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) + { + case CounterStatus::Ignore: + continue; + case CounterStatus::End: + return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; + } - // Create two free sub-nodes. - Node* leftChild = vma_new(GetAllocationCallbacks(), Node)(); - Node* rightChild = vma_new(GetAllocationCallbacks(), Node)(); + // Check all previous blocks for free space + const size_t prevMoveCount = m_Moves.size(); + if (AllocInOtherBlock(0, i, moveData, vector)) + return true; - leftChild->offset = currNode->offset; - leftChild->type = Node::TYPE_FREE; - leftChild->parent = currNode; - leftChild->buddy = rightChild; - - rightChild->offset = currNode->offset + LevelToNodeSize(childrenLevel); - rightChild->type = Node::TYPE_FREE; - rightChild->parent = currNode; - rightChild->buddy = leftChild; - - // Convert current currNode to split type. - currNode->type = Node::TYPE_SPLIT; - currNode->split.leftChild = leftChild; + // If no room found then realloc within block for lower offset + VkDeviceSize offset = moveData.move.srcAllocation->GetOffset(); + if (prevMoveCount == m_Moves.size() && offset != 0 && metadata->GetSumFreeSize() >= moveData.size) + { + VmaAllocationRequest request = {}; + if (metadata->CreateAllocationRequest( + moveData.size, + moveData.alignment, + false, + moveData.type, + VMA_ALLOCATION_CREATE_STRATEGY_MIN_OFFSET_BIT, + &request)) + { + if (metadata->GetAllocationOffset(request.allocHandle) < offset) + { + if (vector.CommitAllocationRequest( + request, + block, + moveData.alignment, + moveData.flags, + this, + moveData.type, + &moveData.move.dstTmpAllocation) == VK_SUCCESS) + { + m_Moves.push_back(moveData.move); + if (IncrementCounters(moveData.size)) + return true; + } + } + } + } + } + } + return false; +} - // Add child nodes to free list. Order is important! - AddToFreeListFront(childrenLevel, rightChild); - AddToFreeListFront(childrenLevel, leftChild); +bool VmaDefragmentationContext_T::ComputeDefragmentation_Extensive(VmaBlockVector& vector, size_t index) +{ + // First free single block, then populate it to the brim, then free another block, and so on - ++m_FreeCount; - //m_SumFreeSize -= LevelToNodeSize(currLevel) % 2; // Useful only when level node sizes can be non power of 2. - ++currLevel; - currNode = m_FreeList[currLevel].front; + // Fallback to previous algorithm since without granularity conflicts it can achieve max packing + if (vector.m_BufferImageGranularity == 1) + return ComputeDefragmentation_Full(vector); - /* - We can be sure that currNode, as left child of node previously split, - also fullfills the alignment requirement. - */ - } + VMA_ASSERT(m_AlgorithmState != VMA_NULL); - // Remove from free list. - VMA_ASSERT(currLevel == targetLevel && - currNode != VMA_NULL && - currNode->type == Node::TYPE_FREE); - RemoveFromFreeList(currLevel, currNode); + StateExtensive& vectorState = reinterpret_cast(m_AlgorithmState)[index]; - // Convert to allocation node. - currNode->type = Node::TYPE_ALLOCATION; - currNode->allocation.alloc = hAllocation; + bool texturePresent = false, bufferPresent = false, otherPresent = false; + switch (vectorState.operation) + { + case StateExtensive::Operation::Done: // Vector defragmented + return false; + case StateExtensive::Operation::FindFreeBlockBuffer: + case StateExtensive::Operation::FindFreeBlockTexture: + case StateExtensive::Operation::FindFreeBlockAll: + { + // No more blocks to free, just perform fast realloc and move to cleanup + if (vectorState.firstFreeBlock == 0) + { + vectorState.operation = StateExtensive::Operation::Cleanup; + return ComputeDefragmentation_Fast(vector); + } - ++m_AllocationCount; - --m_FreeCount; - m_SumFreeSize -= allocSize; -} + // No free blocks, have to clear last one + size_t last = (vectorState.firstFreeBlock == SIZE_MAX ? vector.GetBlockCount() : vectorState.firstFreeBlock) - 1; + VmaBlockMetadata* freeMetadata = vector.GetBlock(last)->m_pMetadata; -void VmaBlockMetadata_Buddy::DeleteNode(Node* node) -{ - if(node->type == Node::TYPE_SPLIT) - { - DeleteNode(node->split.leftChild->buddy); - DeleteNode(node->split.leftChild); - } + const size_t prevMoveCount = m_Moves.size(); + for (VmaAllocHandle handle = freeMetadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = freeMetadata->GetNextAllocation(handle)) + { + MoveAllocationData moveData = GetMoveData(handle, freeMetadata); + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) + { + case CounterStatus::Ignore: + continue; + case CounterStatus::End: + return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; + } - vma_delete(GetAllocationCallbacks(), node); -} + // Check all previous blocks for free space + if (AllocInOtherBlock(0, last, moveData, vector)) + { + // Full clear performed already + if (prevMoveCount != m_Moves.size() && freeMetadata->GetNextAllocation(handle) == VK_NULL_HANDLE) + reinterpret_cast(m_AlgorithmState)[index] = last; + return true; + } + } -bool VmaBlockMetadata_Buddy::ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const -{ - VMA_VALIDATE(level < m_LevelCount); - VMA_VALIDATE(curr->parent == parent); - VMA_VALIDATE((curr->buddy == VMA_NULL) == (parent == VMA_NULL)); - VMA_VALIDATE(curr->buddy == VMA_NULL || curr->buddy->buddy == curr); - switch(curr->type) - { - case Node::TYPE_FREE: - // curr->free.prev, next are validated separately. - ctx.calculatedSumFreeSize += levelNodeSize; - ++ctx.calculatedFreeCount; - break; - case Node::TYPE_ALLOCATION: - ++ctx.calculatedAllocationCount; - ctx.calculatedSumFreeSize += levelNodeSize - curr->allocation.alloc->GetSize(); - VMA_VALIDATE(curr->allocation.alloc != VK_NULL_HANDLE); - break; - case Node::TYPE_SPLIT: + if (prevMoveCount == m_Moves.size()) { - const uint32_t childrenLevel = level + 1; - const VkDeviceSize childrenLevelNodeSize = levelNodeSize / 2; - const Node* const leftChild = curr->split.leftChild; - VMA_VALIDATE(leftChild != VMA_NULL); - VMA_VALIDATE(leftChild->offset == curr->offset); - if(!ValidateNode(ctx, curr, leftChild, childrenLevel, childrenLevelNodeSize)) + // Cannot perform full clear, have to move data in other blocks around + if (last != 0) + { + for (size_t i = last - 1; i; --i) + { + if (ReallocWithinBlock(vector, vector.GetBlock(i))) + return true; + } + } + + if (prevMoveCount == m_Moves.size()) { - VMA_VALIDATE(false && "ValidateNode for left child failed."); + // No possible reallocs within blocks, try to move them around fast + return ComputeDefragmentation_Fast(vector); } - const Node* const rightChild = leftChild->buddy; - VMA_VALIDATE(rightChild->offset == curr->offset + childrenLevelNodeSize); - if(!ValidateNode(ctx, curr, rightChild, childrenLevel, childrenLevelNodeSize)) + } + else + { + switch (vectorState.operation) { - VMA_VALIDATE(false && "ValidateNode for right child failed."); + case StateExtensive::Operation::FindFreeBlockBuffer: + vectorState.operation = StateExtensive::Operation::MoveBuffers; + break; + default: + VMA_ASSERT(0); + case StateExtensive::Operation::FindFreeBlockTexture: + vectorState.operation = StateExtensive::Operation::MoveTextures; + break; + case StateExtensive::Operation::FindFreeBlockAll: + vectorState.operation = StateExtensive::Operation::MoveAll; + break; } + vectorState.firstFreeBlock = last; + // Nothing done, block found without reallocations, can perform another reallocs in same pass + return ComputeDefragmentation_Extensive(vector, index); } break; - default: - return false; } + case StateExtensive::Operation::MoveTextures: + { + if (MoveDataToFreeBlocks(VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL, vector, + vectorState.firstFreeBlock, texturePresent, bufferPresent, otherPresent)) + { + if (texturePresent) + { + vectorState.operation = StateExtensive::Operation::FindFreeBlockTexture; + return ComputeDefragmentation_Extensive(vector, index); + } - return true; -} + if (!bufferPresent && !otherPresent) + { + vectorState.operation = StateExtensive::Operation::Cleanup; + break; + } -uint32_t VmaBlockMetadata_Buddy::AllocSizeToLevel(VkDeviceSize allocSize) const -{ - // I know this could be optimized somehow e.g. by using std::log2p1 from C++20. - uint32_t level = 0; - VkDeviceSize currLevelNodeSize = m_UsableSize; - VkDeviceSize nextLevelNodeSize = currLevelNodeSize >> 1; - while(allocSize <= nextLevelNodeSize && level + 1 < m_LevelCount) - { - ++level; - currLevelNodeSize = nextLevelNodeSize; - nextLevelNodeSize = currLevelNodeSize >> 1; + // No more textures to move, check buffers + vectorState.operation = StateExtensive::Operation::MoveBuffers; + bufferPresent = false; + otherPresent = false; + } + else + break; } - return level; -} - -void VmaBlockMetadata_Buddy::FreeAtOffset(VmaAllocation alloc, VkDeviceSize offset) -{ - // Find node and level. - Node* node = m_Root; - VkDeviceSize nodeOffset = 0; - uint32_t level = 0; - VkDeviceSize levelNodeSize = LevelToNodeSize(0); - while(node->type == Node::TYPE_SPLIT) + case StateExtensive::Operation::MoveBuffers: { - const VkDeviceSize nextLevelSize = levelNodeSize >> 1; - if(offset < nodeOffset + nextLevelSize) + if (MoveDataToFreeBlocks(VMA_SUBALLOCATION_TYPE_BUFFER, vector, + vectorState.firstFreeBlock, texturePresent, bufferPresent, otherPresent)) { - node = node->split.leftChild; + if (bufferPresent) + { + vectorState.operation = StateExtensive::Operation::FindFreeBlockBuffer; + return ComputeDefragmentation_Extensive(vector, index); + } + + if (!otherPresent) + { + vectorState.operation = StateExtensive::Operation::Cleanup; + break; + } + + // No more buffers to move, check all others + vectorState.operation = StateExtensive::Operation::MoveAll; + otherPresent = false; } else + break; + } + case StateExtensive::Operation::MoveAll: + { + if (MoveDataToFreeBlocks(VMA_SUBALLOCATION_TYPE_FREE, vector, + vectorState.firstFreeBlock, texturePresent, bufferPresent, otherPresent)) { - node = node->split.leftChild->buddy; - nodeOffset += nextLevelSize; + if (otherPresent) + { + vectorState.operation = StateExtensive::Operation::FindFreeBlockBuffer; + return ComputeDefragmentation_Extensive(vector, index); + } + // Everything moved + vectorState.operation = StateExtensive::Operation::Cleanup; } - ++level; - levelNodeSize = nextLevelSize; + break; + } + case StateExtensive::Operation::Cleanup: + // Cleanup is handled below so that other operations may reuse the cleanup code. This case is here to prevent the unhandled enum value warning (C4062). + break; } - VMA_ASSERT(node != VMA_NULL && node->type == Node::TYPE_ALLOCATION); - VMA_ASSERT(alloc == VK_NULL_HANDLE || node->allocation.alloc == alloc); + if (vectorState.operation == StateExtensive::Operation::Cleanup) + { + // All other work done, pack data in blocks even tighter if possible + const size_t prevMoveCount = m_Moves.size(); + for (size_t i = 0; i < vector.GetBlockCount(); ++i) + { + if (ReallocWithinBlock(vector, vector.GetBlock(i))) + return true; + } - ++m_FreeCount; - --m_AllocationCount; - m_SumFreeSize += alloc->GetSize(); + if (prevMoveCount == m_Moves.size()) + vectorState.operation = StateExtensive::Operation::Done; + } + return false; +} - node->type = Node::TYPE_FREE; +void VmaDefragmentationContext_T::UpdateVectorStatistics(VmaBlockVector& vector, StateBalanced& state) +{ + size_t allocCount = 0; + size_t freeCount = 0; + state.avgFreeSize = 0; + state.avgAllocSize = 0; - // Join free nodes if possible. - while(level > 0 && node->buddy->type == Node::TYPE_FREE) + for (size_t i = 0; i < vector.GetBlockCount(); ++i) { - RemoveFromFreeList(level, node->buddy); - Node* const parent = node->parent; - - vma_delete(GetAllocationCallbacks(), node->buddy); - vma_delete(GetAllocationCallbacks(), node); - parent->type = Node::TYPE_FREE; + VmaBlockMetadata* metadata = vector.GetBlock(i)->m_pMetadata; - node = parent; - --level; - //m_SumFreeSize += LevelToNodeSize(level) % 2; // Useful only when level node sizes can be non power of 2. - --m_FreeCount; + allocCount += metadata->GetAllocationCount(); + freeCount += metadata->GetFreeRegionsCount(); + state.avgFreeSize += metadata->GetSumFreeSize(); + state.avgAllocSize += metadata->GetSize(); } - AddToFreeListFront(level, node); + state.avgAllocSize = (state.avgAllocSize - state.avgFreeSize) / allocCount; + state.avgFreeSize /= freeCount; } -void VmaBlockMetadata_Buddy::CalcAllocationStatInfoNode(VmaStatInfo& outInfo, const Node* node, VkDeviceSize levelNodeSize) const +bool VmaDefragmentationContext_T::MoveDataToFreeBlocks(VmaSuballocationType currentType, + VmaBlockVector& vector, size_t firstFreeBlock, + bool& texturePresent, bool& bufferPresent, bool& otherPresent) { - switch(node->type) + const size_t prevMoveCount = m_Moves.size(); + for (size_t i = firstFreeBlock ; i;) { - case Node::TYPE_FREE: - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += levelNodeSize; - outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, levelNodeSize); - outInfo.unusedRangeSizeMin = VMA_MAX(outInfo.unusedRangeSizeMin, levelNodeSize); - break; - case Node::TYPE_ALLOCATION: + VmaDeviceMemoryBlock* block = vector.GetBlock(--i); + VmaBlockMetadata* metadata = block->m_pMetadata; + + for (VmaAllocHandle handle = metadata->GetAllocationListBegin(); + handle != VK_NULL_HANDLE; + handle = metadata->GetNextAllocation(handle)) { - const VkDeviceSize allocSize = node->allocation.alloc->GetSize(); - ++outInfo.allocationCount; - outInfo.usedBytes += allocSize; - outInfo.allocationSizeMax = VMA_MAX(outInfo.allocationSizeMax, allocSize); - outInfo.allocationSizeMin = VMA_MAX(outInfo.allocationSizeMin, allocSize); + MoveAllocationData moveData = GetMoveData(handle, metadata); + // Ignore newly created allocations by defragmentation algorithm + if (moveData.move.srcAllocation->GetUserData() == this) + continue; + switch (CheckCounters(moveData.move.srcAllocation->GetSize())) + { + case CounterStatus::Ignore: + continue; + case CounterStatus::End: + return true; + default: + VMA_ASSERT(0); + case CounterStatus::Pass: + break; + } - const VkDeviceSize unusedRangeSize = levelNodeSize - allocSize; - if(unusedRangeSize > 0) + // Move only single type of resources at once + if (!VmaIsBufferImageGranularityConflict(moveData.type, currentType)) { - ++outInfo.unusedRangeCount; - outInfo.unusedBytes += unusedRangeSize; - outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, unusedRangeSize); - outInfo.unusedRangeSizeMin = VMA_MAX(outInfo.unusedRangeSizeMin, unusedRangeSize); + // Try to fit allocation into free blocks + if (AllocInOtherBlock(firstFreeBlock, vector.GetBlockCount(), moveData, vector)) + return false; } + + if (!VmaIsBufferImageGranularityConflict(moveData.type, VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL)) + texturePresent = true; + else if (!VmaIsBufferImageGranularityConflict(moveData.type, VMA_SUBALLOCATION_TYPE_BUFFER)) + bufferPresent = true; + else + otherPresent = true; } - break; - case Node::TYPE_SPLIT: - { - const VkDeviceSize childrenNodeSize = levelNodeSize / 2; - const Node* const leftChild = node->split.leftChild; - CalcAllocationStatInfoNode(outInfo, leftChild, childrenNodeSize); - const Node* const rightChild = leftChild->buddy; - CalcAllocationStatInfoNode(outInfo, rightChild, childrenNodeSize); - } - break; - default: - VMA_ASSERT(0); } + return prevMoveCount == m_Moves.size(); } +#endif // _VMA_DEFRAGMENTATION_CONTEXT_FUNCTIONS -void VmaBlockMetadata_Buddy::AddToFreeListFront(uint32_t level, Node* node) +#ifndef _VMA_POOL_T_FUNCTIONS +VmaPool_T::VmaPool_T( + VmaAllocator hAllocator, + const VmaPoolCreateInfo& createInfo, + VkDeviceSize preferredBlockSize) + : m_BlockVector( + hAllocator, + this, // hParentPool + createInfo.memoryTypeIndex, + createInfo.blockSize != 0 ? createInfo.blockSize : preferredBlockSize, + createInfo.minBlockCount, + createInfo.maxBlockCount, + (createInfo.flags& VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT) != 0 ? 1 : hAllocator->GetBufferImageGranularity(), + createInfo.blockSize != 0, // explicitBlockSize + createInfo.flags & VMA_POOL_CREATE_ALGORITHM_MASK, // algorithm + createInfo.priority, + VMA_MAX(hAllocator->GetMemoryTypeMinAlignment(createInfo.memoryTypeIndex), createInfo.minAllocationAlignment), + createInfo.pMemoryAllocateNext), + m_Id(0), + m_Name(VMA_NULL) {} + +VmaPool_T::~VmaPool_T() { - VMA_ASSERT(node->type == Node::TYPE_FREE); + VMA_ASSERT(m_PrevPool == VMA_NULL && m_NextPool == VMA_NULL); +} - // List is empty. - Node* const frontNode = m_FreeList[level].front; - if(frontNode == VMA_NULL) +void VmaPool_T::SetName(const char* pName) +{ + const VkAllocationCallbacks* allocs = m_BlockVector.GetAllocator()->GetAllocationCallbacks(); + VmaFreeString(allocs, m_Name); + + if (pName != VMA_NULL) { - VMA_ASSERT(m_FreeList[level].back == VMA_NULL); - node->free.prev = node->free.next = VMA_NULL; - m_FreeList[level].front = m_FreeList[level].back = node; + m_Name = VmaCreateStringCopy(allocs, pName); } else { - VMA_ASSERT(frontNode->free.prev == VMA_NULL); - node->free.prev = VMA_NULL; - node->free.next = frontNode; - frontNode->free.prev = node; - m_FreeList[level].front = node; + m_Name = VMA_NULL; } } +#endif // _VMA_POOL_T_FUNCTIONS -void VmaBlockMetadata_Buddy::RemoveFromFreeList(uint32_t level, Node* node) +#ifndef _VMA_ALLOCATOR_T_FUNCTIONS +VmaAllocator_T::VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo) : + m_UseMutex((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT) == 0), + m_VulkanApiVersion(pCreateInfo->vulkanApiVersion != 0 ? pCreateInfo->vulkanApiVersion : VK_API_VERSION_1_0), + m_UseKhrDedicatedAllocation((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0), + m_UseKhrBindMemory2((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0), + m_UseExtMemoryBudget((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0), + m_UseAmdDeviceCoherentMemory((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT) != 0), + m_UseKhrBufferDeviceAddress((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT) != 0), + m_UseExtMemoryPriority((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT) != 0), + m_hDevice(pCreateInfo->device), + m_hInstance(pCreateInfo->instance), + m_AllocationCallbacksSpecified(pCreateInfo->pAllocationCallbacks != VMA_NULL), + m_AllocationCallbacks(pCreateInfo->pAllocationCallbacks ? + *pCreateInfo->pAllocationCallbacks : VmaEmptyAllocationCallbacks), + m_AllocationObjectAllocator(&m_AllocationCallbacks), + m_HeapSizeLimitMask(0), + m_DeviceMemoryCount(0), + m_PreferredLargeHeapBlockSize(0), + m_PhysicalDevice(pCreateInfo->physicalDevice), + m_GpuDefragmentationMemoryTypeBits(UINT32_MAX), + m_NextPoolId(0), + m_GlobalMemoryTypeBits(UINT32_MAX) { - VMA_ASSERT(m_FreeList[level].front != VMA_NULL); - - // It is at the front. - if(node->free.prev == VMA_NULL) - { - VMA_ASSERT(m_FreeList[level].front == node); - m_FreeList[level].front = node->free.next; - } - else + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) { - Node* const prevFreeNode = node->free.prev; - VMA_ASSERT(prevFreeNode->free.next == node); - prevFreeNode->free.next = node->free.next; + m_UseKhrDedicatedAllocation = false; + m_UseKhrBindMemory2 = false; } - // It is at the back. - if(node->free.next == VMA_NULL) - { - VMA_ASSERT(m_FreeList[level].back == node); - m_FreeList[level].back = node->free.prev; - } - else + if(VMA_DEBUG_DETECT_CORRUPTION) { - Node* const nextFreeNode = node->free.next; - VMA_ASSERT(nextFreeNode->free.prev == node); - nextFreeNode->free.prev = node->free.prev; + // Needs to be multiply of uint32_t size because we are going to write VMA_CORRUPTION_DETECTION_MAGIC_VALUE to it. + VMA_ASSERT(VMA_DEBUG_MARGIN % sizeof(uint32_t) == 0); } -} -#if VMA_STATS_STRING_ENABLED -void VmaBlockMetadata_Buddy::PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const -{ - switch(node->type) + VMA_ASSERT(pCreateInfo->physicalDevice && pCreateInfo->device && pCreateInfo->instance); + + if(m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0)) { - case Node::TYPE_FREE: - PrintDetailedMap_UnusedRange(json, node->offset, levelNodeSize); - break; - case Node::TYPE_ALLOCATION: +#if !(VMA_DEDICATED_ALLOCATION) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0) { - PrintDetailedMap_Allocation(json, node->offset, node->allocation.alloc); - const VkDeviceSize allocSize = node->allocation.alloc->GetSize(); - if(allocSize < levelNodeSize) - { - PrintDetailedMap_UnusedRange(json, node->offset + allocSize, levelNodeSize - allocSize); - } + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT set but required extensions are disabled by preprocessor macros."); } - break; - case Node::TYPE_SPLIT: +#endif +#if !(VMA_BIND_MEMORY2) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0) { - const VkDeviceSize childrenNodeSize = levelNodeSize / 2; - const Node* const leftChild = node->split.leftChild; - PrintDetailedMapNode(json, leftChild, childrenNodeSize); - const Node* const rightChild = leftChild->buddy; - PrintDetailedMapNode(json, rightChild, childrenNodeSize); + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT set but required extension is disabled by preprocessor macros."); } - break; - default: - VMA_ASSERT(0); +#endif } -} -#endif // #if VMA_STATS_STRING_ENABLED - - -//////////////////////////////////////////////////////////////////////////////// -// class VmaDeviceMemoryBlock - -VmaDeviceMemoryBlock::VmaDeviceMemoryBlock(VmaAllocator hAllocator) : - m_pMetadata(VMA_NULL), - m_MemoryTypeIndex(UINT32_MAX), - m_Id(0), - m_hMemory(VK_NULL_HANDLE), - m_MapCount(0), - m_pMappedData(VMA_NULL) -{ -} - -void VmaDeviceMemoryBlock::Init( - VmaAllocator hAllocator, - VmaPool hParentPool, - uint32_t newMemoryTypeIndex, - VkDeviceMemory newMemory, - VkDeviceSize newSize, - uint32_t id, - uint32_t algorithm) -{ - VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); - - m_hParentPool = hParentPool; - m_MemoryTypeIndex = newMemoryTypeIndex; - m_Id = id; - m_hMemory = newMemory; - - switch(algorithm) +#if !(VMA_MEMORY_BUDGET) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0) { - case VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT: - m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Linear)(hAllocator); - break; - case VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT: - m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Buddy)(hAllocator); - break; - default: - VMA_ASSERT(0); - // Fall-through. - case 0: - m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Generic)(hAllocator); + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT set but required extension is disabled by preprocessor macros."); } - m_pMetadata->Init(newSize); -} - -void VmaDeviceMemoryBlock::Destroy(VmaAllocator allocator) -{ - // This is the most important assert in the entire library. - // Hitting it means you have some memory leak - unreleased VmaAllocation objects. - VMA_ASSERT(m_pMetadata->IsEmpty() && "Some allocations were not freed before destruction of this memory block!"); - - VMA_ASSERT(m_hMemory != VK_NULL_HANDLE); - allocator->FreeVulkanMemory(m_MemoryTypeIndex, m_pMetadata->GetSize(), m_hMemory); - m_hMemory = VK_NULL_HANDLE; +#endif +#if !(VMA_BUFFER_DEVICE_ADDRESS) + if(m_UseKhrBufferDeviceAddress) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT is set but required extension or Vulkan 1.2 is not available in your Vulkan header or its support in VMA has been disabled by a preprocessor macro."); + } +#endif +#if VMA_VULKAN_VERSION < 1002000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 2, 0)) + { + VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_2 but required Vulkan version is disabled by preprocessor macros."); + } +#endif +#if VMA_VULKAN_VERSION < 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_1 but required Vulkan version is disabled by preprocessor macros."); + } +#endif +#if !(VMA_MEMORY_PRIORITY) + if(m_UseExtMemoryPriority) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT is set but required extension is not available in your Vulkan header or its support in VMA has been disabled by a preprocessor macro."); + } +#endif - vma_delete(allocator, m_pMetadata); - m_pMetadata = VMA_NULL; -} + memset(&m_DeviceMemoryCallbacks, 0 ,sizeof(m_DeviceMemoryCallbacks)); + memset(&m_PhysicalDeviceProperties, 0, sizeof(m_PhysicalDeviceProperties)); + memset(&m_MemProps, 0, sizeof(m_MemProps)); -bool VmaDeviceMemoryBlock::Validate() const -{ - VMA_VALIDATE((m_hMemory != VK_NULL_HANDLE) && - (m_pMetadata->GetSize() != 0)); + memset(&m_pBlockVectors, 0, sizeof(m_pBlockVectors)); + memset(&m_VulkanFunctions, 0, sizeof(m_VulkanFunctions)); - return m_pMetadata->Validate(); -} +#if VMA_EXTERNAL_MEMORY + memset(&m_TypeExternalMemoryHandleTypes, 0, sizeof(m_TypeExternalMemoryHandleTypes)); +#endif // #if VMA_EXTERNAL_MEMORY -VkResult VmaDeviceMemoryBlock::CheckCorruption(VmaAllocator hAllocator) -{ - void* pData = nullptr; - VkResult res = Map(hAllocator, 1, &pData); - if(res != VK_SUCCESS) + if(pCreateInfo->pDeviceMemoryCallbacks != VMA_NULL) { - return res; + m_DeviceMemoryCallbacks.pUserData = pCreateInfo->pDeviceMemoryCallbacks->pUserData; + m_DeviceMemoryCallbacks.pfnAllocate = pCreateInfo->pDeviceMemoryCallbacks->pfnAllocate; + m_DeviceMemoryCallbacks.pfnFree = pCreateInfo->pDeviceMemoryCallbacks->pfnFree; } - res = m_pMetadata->CheckCorruption(pData); + ImportVulkanFunctions(pCreateInfo->pVulkanFunctions); - Unmap(hAllocator, 1); + (*m_VulkanFunctions.vkGetPhysicalDeviceProperties)(m_PhysicalDevice, &m_PhysicalDeviceProperties); + (*m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties)(m_PhysicalDevice, &m_MemProps); - return res; -} + VMA_ASSERT(VmaIsPow2(VMA_MIN_ALIGNMENT)); + VMA_ASSERT(VmaIsPow2(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY)); + VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.bufferImageGranularity)); + VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.nonCoherentAtomSize)); -VkResult VmaDeviceMemoryBlock::Map(VmaAllocator hAllocator, uint32_t count, void** ppData) -{ - if(count == 0) - { - return VK_SUCCESS; - } + m_PreferredLargeHeapBlockSize = (pCreateInfo->preferredLargeHeapBlockSize != 0) ? + pCreateInfo->preferredLargeHeapBlockSize : static_cast(VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE); - VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); - if(m_MapCount != 0) + m_GlobalMemoryTypeBits = CalculateGlobalMemoryTypeBits(); + +#if VMA_EXTERNAL_MEMORY + if(pCreateInfo->pTypeExternalMemoryHandleTypes != VMA_NULL) { - m_MapCount += count; - VMA_ASSERT(m_pMappedData != VMA_NULL); - if(ppData != VMA_NULL) - { - *ppData = m_pMappedData; - } - return VK_SUCCESS; + memcpy(m_TypeExternalMemoryHandleTypes, pCreateInfo->pTypeExternalMemoryHandleTypes, + sizeof(VkExternalMemoryHandleTypeFlagsKHR) * GetMemoryTypeCount()); } - else +#endif // #if VMA_EXTERNAL_MEMORY + + if(pCreateInfo->pHeapSizeLimit != VMA_NULL) { - VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( - hAllocator->m_hDevice, - m_hMemory, - 0, // offset - VK_WHOLE_SIZE, - 0, // flags - &m_pMappedData); - if(result == VK_SUCCESS) + for(uint32_t heapIndex = 0; heapIndex < GetMemoryHeapCount(); ++heapIndex) { - if(ppData != VMA_NULL) + const VkDeviceSize limit = pCreateInfo->pHeapSizeLimit[heapIndex]; + if(limit != VK_WHOLE_SIZE) { - *ppData = m_pMappedData; + m_HeapSizeLimitMask |= 1u << heapIndex; + if(limit < m_MemProps.memoryHeaps[heapIndex].size) + { + m_MemProps.memoryHeaps[heapIndex].size = limit; + } } - m_MapCount = count; } - return result; - } -} - -void VmaDeviceMemoryBlock::Unmap(VmaAllocator hAllocator, uint32_t count) -{ - if(count == 0) - { - return; } - VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); - if(m_MapCount >= count) + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) { - m_MapCount -= count; - if(m_MapCount == 0) + // Create only supported types + if((m_GlobalMemoryTypeBits & (1u << memTypeIndex)) != 0) { - m_pMappedData = VMA_NULL; - (*hAllocator->GetVulkanFunctions().vkUnmapMemory)(hAllocator->m_hDevice, m_hMemory); + const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(memTypeIndex); + m_pBlockVectors[memTypeIndex] = vma_new(this, VmaBlockVector)( + this, + VK_NULL_HANDLE, // hParentPool + memTypeIndex, + preferredBlockSize, + 0, + SIZE_MAX, + GetBufferImageGranularity(), + false, // explicitBlockSize + 0, // algorithm + 0.5f, // priority (0.5 is the default per Vulkan spec) + GetMemoryTypeMinAlignment(memTypeIndex), // minAllocationAlignment + VMA_NULL); // // pMemoryAllocateNext + // No need to call m_pBlockVectors[memTypeIndex][blockVectorTypeIndex]->CreateMinBlocks here, + // becase minBlockCount is 0. } } - else - { - VMA_ASSERT(0 && "VkDeviceMemory block is being unmapped while it was not previously mapped."); - } } -VkResult VmaDeviceMemoryBlock::WriteMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +VkResult VmaAllocator_T::Init(const VmaAllocatorCreateInfo* pCreateInfo) { - VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); - VMA_ASSERT(allocOffset >= VMA_DEBUG_MARGIN); + VkResult res = VK_SUCCESS; - void* pData; - VkResult res = Map(hAllocator, 1, &pData); - if(res != VK_SUCCESS) +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) { - return res; + UpdateVulkanBudget(); } +#endif // #if VMA_MEMORY_BUDGET - VmaWriteMagicValue(pData, allocOffset - VMA_DEBUG_MARGIN); - VmaWriteMagicValue(pData, allocOffset + allocSize); - - Unmap(hAllocator, 1); - - return VK_SUCCESS; + return res; } -VkResult VmaDeviceMemoryBlock::ValidateMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +VmaAllocator_T::~VmaAllocator_T() { - VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); - VMA_ASSERT(allocOffset >= VMA_DEBUG_MARGIN); + VMA_ASSERT(m_Pools.IsEmpty()); - void* pData; - VkResult res = Map(hAllocator, 1, &pData); - if(res != VK_SUCCESS) + for(size_t memTypeIndex = GetMemoryTypeCount(); memTypeIndex--; ) { - return res; + vma_delete(this, m_pBlockVectors[memTypeIndex]); } +} - if(!VmaValidateMagicValue(pData, allocOffset - VMA_DEBUG_MARGIN)) - { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE FREED ALLOCATION!"); - } - else if(!VmaValidateMagicValue(pData, allocOffset + allocSize)) +void VmaAllocator_T::ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions) +{ +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 + ImportVulkanFunctions_Static(); +#endif + + if(pVulkanFunctions != VMA_NULL) { - VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER FREED ALLOCATION!"); + ImportVulkanFunctions_Custom(pVulkanFunctions); } - Unmap(hAllocator, 1); +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + ImportVulkanFunctions_Dynamic(); +#endif - return VK_SUCCESS; + ValidateVulkanFunctions(); } -VkResult VmaDeviceMemoryBlock::BindBufferMemory( - const VmaAllocator hAllocator, - const VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkBuffer hBuffer, - const void* pNext) -{ - VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && - hAllocation->GetBlock() == this); - VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && - "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); - const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; - // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. - VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); - return hAllocator->BindVulkanBuffer(m_hMemory, memoryOffset, hBuffer, pNext); -} +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 -VkResult VmaDeviceMemoryBlock::BindImageMemory( - const VmaAllocator hAllocator, - const VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkImage hImage, - const void* pNext) +void VmaAllocator_T::ImportVulkanFunctions_Static() { - VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && - hAllocation->GetBlock() == this); - VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && - "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); - const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; - // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. - VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); - return hAllocator->BindVulkanImage(m_hMemory, memoryOffset, hImage, pNext); -} + // Vulkan 1.0 + m_VulkanFunctions.vkGetInstanceProcAddr = (PFN_vkGetInstanceProcAddr)vkGetInstanceProcAddr; + m_VulkanFunctions.vkGetDeviceProcAddr = (PFN_vkGetDeviceProcAddr)vkGetDeviceProcAddr; + m_VulkanFunctions.vkGetPhysicalDeviceProperties = (PFN_vkGetPhysicalDeviceProperties)vkGetPhysicalDeviceProperties; + m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties = (PFN_vkGetPhysicalDeviceMemoryProperties)vkGetPhysicalDeviceMemoryProperties; + m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; + m_VulkanFunctions.vkFreeMemory = (PFN_vkFreeMemory)vkFreeMemory; + m_VulkanFunctions.vkMapMemory = (PFN_vkMapMemory)vkMapMemory; + m_VulkanFunctions.vkUnmapMemory = (PFN_vkUnmapMemory)vkUnmapMemory; + m_VulkanFunctions.vkFlushMappedMemoryRanges = (PFN_vkFlushMappedMemoryRanges)vkFlushMappedMemoryRanges; + m_VulkanFunctions.vkInvalidateMappedMemoryRanges = (PFN_vkInvalidateMappedMemoryRanges)vkInvalidateMappedMemoryRanges; + m_VulkanFunctions.vkBindBufferMemory = (PFN_vkBindBufferMemory)vkBindBufferMemory; + m_VulkanFunctions.vkBindImageMemory = (PFN_vkBindImageMemory)vkBindImageMemory; + m_VulkanFunctions.vkGetBufferMemoryRequirements = (PFN_vkGetBufferMemoryRequirements)vkGetBufferMemoryRequirements; + m_VulkanFunctions.vkGetImageMemoryRequirements = (PFN_vkGetImageMemoryRequirements)vkGetImageMemoryRequirements; + m_VulkanFunctions.vkCreateBuffer = (PFN_vkCreateBuffer)vkCreateBuffer; + m_VulkanFunctions.vkDestroyBuffer = (PFN_vkDestroyBuffer)vkDestroyBuffer; + m_VulkanFunctions.vkCreateImage = (PFN_vkCreateImage)vkCreateImage; + m_VulkanFunctions.vkDestroyImage = (PFN_vkDestroyImage)vkDestroyImage; + m_VulkanFunctions.vkCmdCopyBuffer = (PFN_vkCmdCopyBuffer)vkCmdCopyBuffer; -static void InitStatInfo(VmaStatInfo& outInfo) -{ - memset(&outInfo, 0, sizeof(outInfo)); - outInfo.allocationSizeMin = UINT64_MAX; - outInfo.unusedRangeSizeMin = UINT64_MAX; -} + // Vulkan 1.1 +#if VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2)vkGetBufferMemoryRequirements2; + m_VulkanFunctions.vkGetImageMemoryRequirements2KHR = (PFN_vkGetImageMemoryRequirements2)vkGetImageMemoryRequirements2; + m_VulkanFunctions.vkBindBufferMemory2KHR = (PFN_vkBindBufferMemory2)vkBindBufferMemory2; + m_VulkanFunctions.vkBindImageMemory2KHR = (PFN_vkBindImageMemory2)vkBindImageMemory2; + m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR = (PFN_vkGetPhysicalDeviceMemoryProperties2)vkGetPhysicalDeviceMemoryProperties2; + } +#endif -// Adds statistics srcInfo into inoutInfo, like: inoutInfo += srcInfo. -static void VmaAddStatInfo(VmaStatInfo& inoutInfo, const VmaStatInfo& srcInfo) -{ - inoutInfo.blockCount += srcInfo.blockCount; - inoutInfo.allocationCount += srcInfo.allocationCount; - inoutInfo.unusedRangeCount += srcInfo.unusedRangeCount; - inoutInfo.usedBytes += srcInfo.usedBytes; - inoutInfo.unusedBytes += srcInfo.unusedBytes; - inoutInfo.allocationSizeMin = VMA_MIN(inoutInfo.allocationSizeMin, srcInfo.allocationSizeMin); - inoutInfo.allocationSizeMax = VMA_MAX(inoutInfo.allocationSizeMax, srcInfo.allocationSizeMax); - inoutInfo.unusedRangeSizeMin = VMA_MIN(inoutInfo.unusedRangeSizeMin, srcInfo.unusedRangeSizeMin); - inoutInfo.unusedRangeSizeMax = VMA_MAX(inoutInfo.unusedRangeSizeMax, srcInfo.unusedRangeSizeMax); +#if VMA_VULKAN_VERSION >= 1003000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 3, 0)) + { + m_VulkanFunctions.vkGetDeviceBufferMemoryRequirements = (PFN_vkGetDeviceBufferMemoryRequirements)vkGetDeviceBufferMemoryRequirements; + m_VulkanFunctions.vkGetDeviceImageMemoryRequirements = (PFN_vkGetDeviceImageMemoryRequirements)vkGetDeviceImageMemoryRequirements; + } +#endif } -static void VmaPostprocessCalcStatInfo(VmaStatInfo& inoutInfo) -{ - inoutInfo.allocationSizeAvg = (inoutInfo.allocationCount > 0) ? - VmaRoundDiv(inoutInfo.usedBytes, inoutInfo.allocationCount) : 0; - inoutInfo.unusedRangeSizeAvg = (inoutInfo.unusedRangeCount > 0) ? - VmaRoundDiv(inoutInfo.unusedBytes, inoutInfo.unusedRangeCount) : 0; -} +#endif // VMA_STATIC_VULKAN_FUNCTIONS == 1 -VmaPool_T::VmaPool_T( - VmaAllocator hAllocator, - const VmaPoolCreateInfo& createInfo, - VkDeviceSize preferredBlockSize) : - m_BlockVector( - hAllocator, - this, // hParentPool - createInfo.memoryTypeIndex, - createInfo.blockSize != 0 ? createInfo.blockSize : preferredBlockSize, - createInfo.minBlockCount, - createInfo.maxBlockCount, - (createInfo.flags & VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT) != 0 ? 1 : hAllocator->GetBufferImageGranularity(), - createInfo.frameInUseCount, - createInfo.blockSize != 0, // explicitBlockSize - createInfo.flags & VMA_POOL_CREATE_ALGORITHM_MASK, - createInfo.priority), // algorithm - m_Id(0), - m_Name(VMA_NULL) +void VmaAllocator_T::ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions) { -} + VMA_ASSERT(pVulkanFunctions != VMA_NULL); -VmaPool_T::~VmaPool_T() -{ -} +#define VMA_COPY_IF_NOT_NULL(funcName) \ + if(pVulkanFunctions->funcName != VMA_NULL) m_VulkanFunctions.funcName = pVulkanFunctions->funcName; -void VmaPool_T::SetName(const char* pName) -{ - const VkAllocationCallbacks* allocs = m_BlockVector.GetAllocator()->GetAllocationCallbacks(); - VmaFreeString(allocs, m_Name); + VMA_COPY_IF_NOT_NULL(vkGetInstanceProcAddr); + VMA_COPY_IF_NOT_NULL(vkGetDeviceProcAddr); + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceProperties); + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties); + VMA_COPY_IF_NOT_NULL(vkAllocateMemory); + VMA_COPY_IF_NOT_NULL(vkFreeMemory); + VMA_COPY_IF_NOT_NULL(vkMapMemory); + VMA_COPY_IF_NOT_NULL(vkUnmapMemory); + VMA_COPY_IF_NOT_NULL(vkFlushMappedMemoryRanges); + VMA_COPY_IF_NOT_NULL(vkInvalidateMappedMemoryRanges); + VMA_COPY_IF_NOT_NULL(vkBindBufferMemory); + VMA_COPY_IF_NOT_NULL(vkBindImageMemory); + VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements); + VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements); + VMA_COPY_IF_NOT_NULL(vkCreateBuffer); + VMA_COPY_IF_NOT_NULL(vkDestroyBuffer); + VMA_COPY_IF_NOT_NULL(vkCreateImage); + VMA_COPY_IF_NOT_NULL(vkDestroyImage); + VMA_COPY_IF_NOT_NULL(vkCmdCopyBuffer); - if(pName != VMA_NULL) - { - m_Name = VmaCreateStringCopy(allocs, pName); - } - else - { - m_Name = VMA_NULL; - } -} +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements2KHR); + VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements2KHR); +#endif -#if VMA_STATS_STRING_ENABLED +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + VMA_COPY_IF_NOT_NULL(vkBindBufferMemory2KHR); + VMA_COPY_IF_NOT_NULL(vkBindImageMemory2KHR); +#endif -#endif // #if VMA_STATS_STRING_ENABLED +#if VMA_MEMORY_BUDGET + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties2KHR); +#endif -VmaBlockVector::VmaBlockVector( - VmaAllocator hAllocator, - VmaPool hParentPool, - uint32_t memoryTypeIndex, - VkDeviceSize preferredBlockSize, - size_t minBlockCount, - size_t maxBlockCount, - VkDeviceSize bufferImageGranularity, - uint32_t frameInUseCount, - bool explicitBlockSize, - uint32_t algorithm, - float priority) : - m_hAllocator(hAllocator), - m_hParentPool(hParentPool), - m_MemoryTypeIndex(memoryTypeIndex), - m_PreferredBlockSize(preferredBlockSize), - m_MinBlockCount(minBlockCount), - m_MaxBlockCount(maxBlockCount), - m_BufferImageGranularity(bufferImageGranularity), - m_FrameInUseCount(frameInUseCount), - m_ExplicitBlockSize(explicitBlockSize), - m_Algorithm(algorithm), - m_Priority(priority), - m_HasEmptyBlock(false), - m_Blocks(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - m_NextBlockId(0) -{ -} +#if VMA_VULKAN_VERSION >= 1003000 + VMA_COPY_IF_NOT_NULL(vkGetDeviceBufferMemoryRequirements); + VMA_COPY_IF_NOT_NULL(vkGetDeviceImageMemoryRequirements); +#endif -VmaBlockVector::~VmaBlockVector() -{ - for(size_t i = m_Blocks.size(); i--; ) - { - m_Blocks[i]->Destroy(m_hAllocator); - vma_delete(m_hAllocator, m_Blocks[i]); - } +#undef VMA_COPY_IF_NOT_NULL } -VkResult VmaBlockVector::CreateMinBlocks() -{ - for (const auto i : c10::irange(m_MinBlockCount)) { - VkResult res = CreateBlock(m_PreferredBlockSize, VMA_NULL); - if(res != VK_SUCCESS) - { - return res; - } - } - return VK_SUCCESS; -} +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 -void VmaBlockVector::GetPoolStats(VmaPoolStats* pStats) +void VmaAllocator_T::ImportVulkanFunctions_Dynamic() { - VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + VMA_ASSERT(m_VulkanFunctions.vkGetInstanceProcAddr && m_VulkanFunctions.vkGetDeviceProcAddr && + "To use VMA_DYNAMIC_VULKAN_FUNCTIONS in new versions of VMA you now have to pass " + "VmaVulkanFunctions::vkGetInstanceProcAddr and vkGetDeviceProcAddr as VmaAllocatorCreateInfo::pVulkanFunctions. " + "Other members can be null."); - const size_t blockCount = m_Blocks.size(); +#define VMA_FETCH_INSTANCE_FUNC(memberName, functionPointerType, functionNameString) \ + if(m_VulkanFunctions.memberName == VMA_NULL) \ + m_VulkanFunctions.memberName = \ + (functionPointerType)m_VulkanFunctions.vkGetInstanceProcAddr(m_hInstance, functionNameString); +#define VMA_FETCH_DEVICE_FUNC(memberName, functionPointerType, functionNameString) \ + if(m_VulkanFunctions.memberName == VMA_NULL) \ + m_VulkanFunctions.memberName = \ + (functionPointerType)m_VulkanFunctions.vkGetDeviceProcAddr(m_hDevice, functionNameString); - pStats->size = 0; - pStats->unusedSize = 0; - pStats->allocationCount = 0; - pStats->unusedRangeCount = 0; - pStats->unusedRangeSizeMax = 0; - pStats->blockCount = blockCount; + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceProperties, PFN_vkGetPhysicalDeviceProperties, "vkGetPhysicalDeviceProperties"); + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties, PFN_vkGetPhysicalDeviceMemoryProperties, "vkGetPhysicalDeviceMemoryProperties"); + VMA_FETCH_DEVICE_FUNC(vkAllocateMemory, PFN_vkAllocateMemory, "vkAllocateMemory"); + VMA_FETCH_DEVICE_FUNC(vkFreeMemory, PFN_vkFreeMemory, "vkFreeMemory"); + VMA_FETCH_DEVICE_FUNC(vkMapMemory, PFN_vkMapMemory, "vkMapMemory"); + VMA_FETCH_DEVICE_FUNC(vkUnmapMemory, PFN_vkUnmapMemory, "vkUnmapMemory"); + VMA_FETCH_DEVICE_FUNC(vkFlushMappedMemoryRanges, PFN_vkFlushMappedMemoryRanges, "vkFlushMappedMemoryRanges"); + VMA_FETCH_DEVICE_FUNC(vkInvalidateMappedMemoryRanges, PFN_vkInvalidateMappedMemoryRanges, "vkInvalidateMappedMemoryRanges"); + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory, PFN_vkBindBufferMemory, "vkBindBufferMemory"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory, PFN_vkBindImageMemory, "vkBindImageMemory"); + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements, PFN_vkGetBufferMemoryRequirements, "vkGetBufferMemoryRequirements"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements, PFN_vkGetImageMemoryRequirements, "vkGetImageMemoryRequirements"); + VMA_FETCH_DEVICE_FUNC(vkCreateBuffer, PFN_vkCreateBuffer, "vkCreateBuffer"); + VMA_FETCH_DEVICE_FUNC(vkDestroyBuffer, PFN_vkDestroyBuffer, "vkDestroyBuffer"); + VMA_FETCH_DEVICE_FUNC(vkCreateImage, PFN_vkCreateImage, "vkCreateImage"); + VMA_FETCH_DEVICE_FUNC(vkDestroyImage, PFN_vkDestroyImage, "vkDestroyImage"); + VMA_FETCH_DEVICE_FUNC(vkCmdCopyBuffer, PFN_vkCmdCopyBuffer, "vkCmdCopyBuffer"); - for (const auto blockIndex : c10::irange(blockCount)) { - const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pBlock); - VMA_HEAVY_ASSERT(pBlock->Validate()); - pBlock->m_pMetadata->AddPoolStats(*pStats); +#if VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2, "vkGetBufferMemoryRequirements2"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2, "vkGetImageMemoryRequirements2"); + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2, "vkBindBufferMemory2"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2, "vkBindImageMemory2"); + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2, "vkGetPhysicalDeviceMemoryProperties2"); } -} - -bool VmaBlockVector::IsEmpty() -{ - VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); - return m_Blocks.empty(); -} - -bool VmaBlockVector::IsCorruptionDetectionEnabled() const -{ - const uint32_t requiredMemFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - return (VMA_DEBUG_DETECT_CORRUPTION != 0) && - (VMA_DEBUG_MARGIN > 0) && - (m_Algorithm == 0 || m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) && - (m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags & requiredMemFlags) == requiredMemFlags; -} - -static const uint32_t VMA_ALLOCATION_TRY_COUNT = 32; +#endif -VkResult VmaBlockVector::Allocate( - uint32_t currentFrameIndex, - VkDeviceSize size, - VkDeviceSize alignment, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations) -{ - size_t allocIndex; - VkResult res = VK_SUCCESS; +#if VMA_DEDICATED_ALLOCATION + if(m_UseKhrDedicatedAllocation) + { + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2KHR, "vkGetBufferMemoryRequirements2KHR"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2KHR, "vkGetImageMemoryRequirements2KHR"); + } +#endif - if(IsCorruptionDetectionEnabled()) +#if VMA_BIND_MEMORY2 + if(m_UseKhrBindMemory2) { - size = VmaAlignUp(size, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); - alignment = VmaAlignUp(alignment, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2KHR, "vkBindBufferMemory2KHR"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2KHR, "vkBindImageMemory2KHR"); } +#endif // #if VMA_BIND_MEMORY2 +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) { - VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); - for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) - { - res = AllocatePage( - currentFrameIndex, - size, - alignment, - createInfo, - suballocType, - pAllocations + allocIndex); - if(res != VK_SUCCESS) - { - break; - } - } + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2KHR, "vkGetPhysicalDeviceMemoryProperties2KHR"); } +#endif // #if VMA_MEMORY_BUDGET - if(res != VK_SUCCESS) +#if VMA_VULKAN_VERSION >= 1003000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 3, 0)) { - // Free all already created allocations. - const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); - while(allocIndex--) - { - VmaAllocation_T* const alloc = pAllocations[allocIndex]; - const VkDeviceSize allocSize = alloc->GetSize(); - Free(alloc); - m_hAllocator->m_Budget.RemoveAllocation(heapIndex, allocSize); - } - memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + VMA_FETCH_DEVICE_FUNC(vkGetDeviceBufferMemoryRequirements, PFN_vkGetDeviceBufferMemoryRequirements, "vkGetDeviceBufferMemoryRequirements"); + VMA_FETCH_DEVICE_FUNC(vkGetDeviceImageMemoryRequirements, PFN_vkGetDeviceImageMemoryRequirements, "vkGetDeviceImageMemoryRequirements"); } +#endif - return res; +#undef VMA_FETCH_DEVICE_FUNC +#undef VMA_FETCH_INSTANCE_FUNC } -VkResult VmaBlockVector::AllocatePage( - uint32_t currentFrameIndex, - VkDeviceSize size, - VkDeviceSize alignment, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - VmaAllocation* pAllocation) +#endif // VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + +void VmaAllocator_T::ValidateVulkanFunctions() { - const bool isUpperAddress = (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; - bool canMakeOtherLost = (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT) != 0; - const bool mapped = (createInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0; - const bool isUserDataString = (createInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0; + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceProperties != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkAllocateMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkFreeMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkMapMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkUnmapMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkFlushMappedMemoryRanges != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkInvalidateMappedMemoryRanges != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCreateBuffer != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkDestroyBuffer != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCreateImage != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkDestroyImage != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCmdCopyBuffer != VMA_NULL); - VkDeviceSize freeMemory; +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrDedicatedAllocation) { - const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); - VmaBudget heapBudget = {}; - m_hAllocator->GetBudget(&heapBudget, heapIndex, 1); - freeMemory = (heapBudget.usage < heapBudget.budget) ? (heapBudget.budget - heapBudget.usage) : 0; + VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements2KHR != VMA_NULL); } +#endif - const bool canFallbackToDedicated = !IsCustomPool(); - const bool canCreateNewBlock = - ((createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0) && - (m_Blocks.size() < m_MaxBlockCount) && - (freeMemory >= size || !canFallbackToDedicated); - uint32_t strategy = createInfo.flags & VMA_ALLOCATION_CREATE_STRATEGY_MASK; - - // If linearAlgorithm is used, canMakeOtherLost is available only when used as ring buffer. - // Which in turn is available only when maxBlockCount = 1. - if(m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT && m_MaxBlockCount > 1) +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrBindMemory2) { - canMakeOtherLost = false; + VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL); } +#endif - // Upper address can only be used with linear allocator and within single memory block. - if(isUpperAddress && - (m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT || m_MaxBlockCount > 1)) +#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 + if(m_UseExtMemoryBudget || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) { - return VK_ERROR_FEATURE_NOT_PRESENT; + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR != VMA_NULL); } +#endif - // Validate strategy. - switch(strategy) +#if VMA_VULKAN_VERSION >= 1003000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 3, 0)) { - case 0: - strategy = VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT; - break; - case VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT: - case VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT: - case VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT: - break; - default: - return VK_ERROR_FEATURE_NOT_PRESENT; + VMA_ASSERT(m_VulkanFunctions.vkGetDeviceBufferMemoryRequirements != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetDeviceImageMemoryRequirements != VMA_NULL); } +#endif +} - // Early reject: requested allocation size is larger that maximum block size for this block vector. - if(size + 2 * VMA_DEBUG_MARGIN > m_PreferredBlockSize) +VkDeviceSize VmaAllocator_T::CalcPreferredBlockSize(uint32_t memTypeIndex) +{ + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); + const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; + const bool isSmallHeap = heapSize <= VMA_SMALL_HEAP_MAX_SIZE; + return VmaAlignUp(isSmallHeap ? (heapSize / 8) : m_PreferredLargeHeapBlockSize, (VkDeviceSize)32); +} + +VkResult VmaAllocator_T::AllocateMemoryOfType( + VmaPool pool, + VkDeviceSize size, + VkDeviceSize alignment, + bool dedicatedPreferred, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, + const VmaAllocationCreateInfo& createInfo, + uint32_t memTypeIndex, + VmaSuballocationType suballocType, + VmaDedicatedAllocationList& dedicatedAllocations, + VmaBlockVector& blockVector, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + VMA_ASSERT(pAllocations != VMA_NULL); + VMA_DEBUG_LOG(" AllocateMemory: MemoryTypeIndex=%u, AllocationCount=%zu, Size=%llu", memTypeIndex, allocationCount, size); + + VmaAllocationCreateInfo finalCreateInfo = createInfo; + VkResult res = CalcMemTypeParams( + finalCreateInfo, + memTypeIndex, + size, + allocationCount); + if(res != VK_SUCCESS) + return res; + + if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0) { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; + return AllocateDedicatedMemory( + pool, + size, + suballocType, + dedicatedAllocations, + memTypeIndex, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, + (finalCreateInfo.flags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_CAN_ALIAS_BIT) != 0, + finalCreateInfo.pUserData, + finalCreateInfo.priority, + dedicatedBuffer, + dedicatedImage, + dedicatedBufferImageUsage, + allocationCount, + pAllocations, + blockVector.GetAllocationNextPtr()); } - - /* - Under certain condition, this whole section can be skipped for optimization, so - we move on directly to trying to allocate with canMakeOtherLost. That's the case - e.g. for custom pools with linear algorithm. - */ - if(!canMakeOtherLost || canCreateNewBlock) + else { - // 1. Search existing allocations. Try to allocate without making other allocations lost. - VmaAllocationCreateFlags allocFlagsCopy = createInfo.flags; - allocFlagsCopy &= ~VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT; + const bool canAllocateDedicated = + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0 && + (pool == VK_NULL_HANDLE || !blockVector.HasExplicitBlockSize()); - if(m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + if(canAllocateDedicated) { - // Use only last block. - if(!m_Blocks.empty()) + // Heuristics: Allocate dedicated memory if requested size if greater than half of preferred block size. + if(size > blockVector.GetPreferredBlockSize() / 2) { - VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks.back(); - VMA_ASSERT(pCurrBlock); - VkResult res = AllocateFromBlock( - pCurrBlock, - currentFrameIndex, + dedicatedPreferred = true; + } + // Protection against creating each allocation as dedicated when we reach or exceed heap size/budget, + // which can quickly deplete maxMemoryAllocationCount: Don't prefer dedicated allocations when above + // 3/4 of the maximum allocation count. + if(m_DeviceMemoryCount.load() > m_PhysicalDeviceProperties.limits.maxMemoryAllocationCount * 3 / 4) + { + dedicatedPreferred = false; + } + + if(dedicatedPreferred) + { + res = AllocateDedicatedMemory( + pool, size, - alignment, - allocFlagsCopy, - createInfo.pUserData, suballocType, - strategy, - pAllocation); + dedicatedAllocations, + memTypeIndex, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, + (finalCreateInfo.flags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_CAN_ALIAS_BIT) != 0, + finalCreateInfo.pUserData, + finalCreateInfo.priority, + dedicatedBuffer, + dedicatedImage, + dedicatedBufferImageUsage, + allocationCount, + pAllocations, + blockVector.GetAllocationNextPtr()); if(res == VK_SUCCESS) { - VMA_DEBUG_LOG(" Returned from last block #%u", pCurrBlock->GetId()); + // Succeeded: AllocateDedicatedMemory function already filld pMemory, nothing more to do here. + VMA_DEBUG_LOG(" Allocated as DedicatedMemory"); return VK_SUCCESS; } } } - else + + res = blockVector.Allocate( + size, + alignment, + finalCreateInfo, + suballocType, + allocationCount, + pAllocations); + if(res == VK_SUCCESS) + return VK_SUCCESS; + + // Try dedicated memory. + if(canAllocateDedicated && !dedicatedPreferred) { - if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) - { - // Forward order in m_Blocks - prefer blocks with smallest amount of free space. - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pCurrBlock); - VkResult res = AllocateFromBlock( - pCurrBlock, - currentFrameIndex, - size, - alignment, - allocFlagsCopy, - createInfo.pUserData, - suballocType, - strategy, - pAllocation); - if(res == VK_SUCCESS) - { - VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); - return VK_SUCCESS; - } - } - } - else // WORST_FIT, FIRST_FIT + res = AllocateDedicatedMemory( + pool, + size, + suballocType, + dedicatedAllocations, + memTypeIndex, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, + (finalCreateInfo.flags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_CAN_ALIAS_BIT) != 0, + finalCreateInfo.pUserData, + finalCreateInfo.priority, + dedicatedBuffer, + dedicatedImage, + dedicatedBufferImageUsage, + allocationCount, + pAllocations, + blockVector.GetAllocationNextPtr()); + if(res == VK_SUCCESS) { - // Backward order in m_Blocks - prefer blocks with largest amount of free space. - for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) - { - VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pCurrBlock); - VkResult res = AllocateFromBlock( - pCurrBlock, - currentFrameIndex, - size, - alignment, - allocFlagsCopy, - createInfo.pUserData, - suballocType, - strategy, - pAllocation); - if(res == VK_SUCCESS) - { - VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); - return VK_SUCCESS; - } - } + // Succeeded: AllocateDedicatedMemory function already filld pMemory, nothing more to do here. + VMA_DEBUG_LOG(" Allocated as DedicatedMemory"); + return VK_SUCCESS; } } + // Everything failed: Return error code. + VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); + return res; + } +} - // 2. Try to create new block. - if(canCreateNewBlock) - { - // Calculate optimal size for new block. - VkDeviceSize newBlockSize = m_PreferredBlockSize; - uint32_t newBlockSizeShift = 0; - const uint32_t NEW_BLOCK_SIZE_SHIFT_MAX = 3; +VkResult VmaAllocator_T::AllocateDedicatedMemory( + VmaPool pool, + VkDeviceSize size, + VmaSuballocationType suballocType, + VmaDedicatedAllocationList& dedicatedAllocations, + uint32_t memTypeIndex, + bool map, + bool isUserDataString, + bool isMappingAllowed, + bool canAliasMemory, + void* pUserData, + float priority, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, + size_t allocationCount, + VmaAllocation* pAllocations, + const void* pNextChain) +{ + VMA_ASSERT(allocationCount > 0 && pAllocations); - if(!m_ExplicitBlockSize) - { - // Allocate 1/8, 1/4, 1/2 as first blocks. - const VkDeviceSize maxExistingBlockSize = CalcMaxBlockSize(); - for (const auto i : c10::irange(NEW_BLOCK_SIZE_SHIFT_MAX)) { - const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; - if(smallerNewBlockSize > maxExistingBlockSize && smallerNewBlockSize >= size * 2) - { - newBlockSize = smallerNewBlockSize; - ++newBlockSizeShift; - } - else - { - break; - } - } - } + VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; + allocInfo.memoryTypeIndex = memTypeIndex; + allocInfo.allocationSize = size; + allocInfo.pNext = pNextChain; - size_t newBlockIndex = 0; - VkResult res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? - CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; - // Allocation of this size failed? Try 1/2, 1/4, 1/8 of m_PreferredBlockSize. - if(!m_ExplicitBlockSize) +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + VkMemoryDedicatedAllocateInfoKHR dedicatedAllocInfo = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR }; + if(!canAliasMemory) + { + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + if(dedicatedBuffer != VK_NULL_HANDLE) { - while(res < 0 && newBlockSizeShift < NEW_BLOCK_SIZE_SHIFT_MAX) - { - const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; - if(smallerNewBlockSize >= size) - { - newBlockSize = smallerNewBlockSize; - ++newBlockSizeShift; - res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? - CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - else - { - break; - } - } + VMA_ASSERT(dedicatedImage == VK_NULL_HANDLE); + dedicatedAllocInfo.buffer = dedicatedBuffer; + VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); } - - if(res == VK_SUCCESS) + else if(dedicatedImage != VK_NULL_HANDLE) { - VmaDeviceMemoryBlock* const pBlock = m_Blocks[newBlockIndex]; - VMA_ASSERT(pBlock->m_pMetadata->GetSize() >= size); - - res = AllocateFromBlock( - pBlock, - currentFrameIndex, - size, - alignment, - allocFlagsCopy, - createInfo.pUserData, - suballocType, - strategy, - pAllocation); - if(res == VK_SUCCESS) - { - VMA_DEBUG_LOG(" Created new block #%u Size=%llu", pBlock->GetId(), newBlockSize); - return VK_SUCCESS; - } - else - { - // Allocation from new block failed, possibly due to VMA_DEBUG_MARGIN or alignment. - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } + dedicatedAllocInfo.image = dedicatedImage; + VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); } } } +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - // 3. Try to allocate from existing blocks with making other allocations lost. - if(canMakeOtherLost) +#if VMA_BUFFER_DEVICE_ADDRESS + VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; + if(m_UseKhrBufferDeviceAddress) { - uint32_t tryIndex = 0; - for(; tryIndex < VMA_ALLOCATION_TRY_COUNT; ++tryIndex) + bool canContainBufferWithDeviceAddress = true; + if(dedicatedBuffer != VK_NULL_HANDLE) + { + canContainBufferWithDeviceAddress = dedicatedBufferImageUsage == UINT32_MAX || // Usage flags unknown + (dedicatedBufferImageUsage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT) != 0; + } + else if(dedicatedImage != VK_NULL_HANDLE) + { + canContainBufferWithDeviceAddress = false; + } + if(canContainBufferWithDeviceAddress) { - VmaDeviceMemoryBlock* pBestRequestBlock = VMA_NULL; - VmaAllocationRequest bestRequest = {}; - VkDeviceSize bestRequestCost = VK_WHOLE_SIZE; + allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; + VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); + } + } +#endif // #if VMA_BUFFER_DEVICE_ADDRESS - // 1. Search existing allocations. - if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) - { - // Forward order in m_Blocks - prefer blocks with smallest amount of free space. - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pCurrBlock); - VmaAllocationRequest currRequest = {}; - if(pCurrBlock->m_pMetadata->CreateAllocationRequest( - currentFrameIndex, - m_FrameInUseCount, - m_BufferImageGranularity, - size, - alignment, - (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0, - suballocType, - canMakeOtherLost, - strategy, - &currRequest)) - { - const VkDeviceSize currRequestCost = currRequest.CalcCost(); - if(pBestRequestBlock == VMA_NULL || - currRequestCost < bestRequestCost) - { - pBestRequestBlock = pCurrBlock; - bestRequest = currRequest; - bestRequestCost = currRequestCost; +#if VMA_MEMORY_PRIORITY + VkMemoryPriorityAllocateInfoEXT priorityInfo = { VK_STRUCTURE_TYPE_MEMORY_PRIORITY_ALLOCATE_INFO_EXT }; + if(m_UseExtMemoryPriority) + { + VMA_ASSERT(priority >= 0.f && priority <= 1.f); + priorityInfo.priority = priority; + VmaPnextChainPushFront(&allocInfo, &priorityInfo); + } +#endif // #if VMA_MEMORY_PRIORITY - if(bestRequestCost == 0) - { - break; - } - } - } - } - } - else // WORST_FIT, FIRST_FIT - { - // Backward order in m_Blocks - prefer blocks with largest amount of free space. - for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) - { - VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pCurrBlock); - VmaAllocationRequest currRequest = {}; - if(pCurrBlock->m_pMetadata->CreateAllocationRequest( - currentFrameIndex, - m_FrameInUseCount, - m_BufferImageGranularity, - size, - alignment, - (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0, - suballocType, - canMakeOtherLost, - strategy, - &currRequest)) - { - const VkDeviceSize currRequestCost = currRequest.CalcCost(); - if(pBestRequestBlock == VMA_NULL || - currRequestCost < bestRequestCost || - strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) - { - pBestRequestBlock = pCurrBlock; - bestRequest = currRequest; - bestRequestCost = currRequestCost; +#if VMA_EXTERNAL_MEMORY + // Attach VkExportMemoryAllocateInfoKHR if necessary. + VkExportMemoryAllocateInfoKHR exportMemoryAllocInfo = { VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR }; + exportMemoryAllocInfo.handleTypes = GetExternalMemoryHandleTypeFlags(memTypeIndex); + if(exportMemoryAllocInfo.handleTypes != 0) + { + VmaPnextChainPushFront(&allocInfo, &exportMemoryAllocInfo); + } +#endif // #if VMA_EXTERNAL_MEMORY - if(bestRequestCost == 0 || - strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) - { - break; - } - } - } - } - } + size_t allocIndex; + VkResult res = VK_SUCCESS; + for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + res = AllocateDedicatedMemoryPage( + pool, + size, + suballocType, + memTypeIndex, + allocInfo, + map, + isUserDataString, + isMappingAllowed, + pUserData, + pAllocations + allocIndex); + if(res != VK_SUCCESS) + { + break; + } + } - if(pBestRequestBlock != VMA_NULL) - { - if(mapped) - { - VkResult res = pBestRequestBlock->Map(m_hAllocator, 1, VMA_NULL); - if(res != VK_SUCCESS) - { - return res; - } - } + if(res == VK_SUCCESS) + { + for (allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + dedicatedAllocations.Register(pAllocations[allocIndex]); + } + VMA_DEBUG_LOG(" Allocated DedicatedMemory Count=%zu, MemoryTypeIndex=#%u", allocationCount, memTypeIndex); + } + else + { + // Free all already created allocations. + while(allocIndex--) + { + VmaAllocation currAlloc = pAllocations[allocIndex]; + VkDeviceMemory hMemory = currAlloc->GetMemory(); - if(pBestRequestBlock->m_pMetadata->MakeRequestedAllocationsLost( - currentFrameIndex, - m_FrameInUseCount, - &bestRequest)) - { - // Allocate from this pBlock. - *pAllocation = m_hAllocator->m_AllocationObjectAllocator.Allocate(currentFrameIndex, isUserDataString); - pBestRequestBlock->m_pMetadata->Alloc(bestRequest, suballocType, size, *pAllocation); - UpdateHasEmptyBlock(); - (*pAllocation)->InitBlockAllocation( - pBestRequestBlock, - bestRequest.offset, - alignment, - size, - m_MemoryTypeIndex, - suballocType, - mapped, - (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0); - VMA_HEAVY_ASSERT(pBestRequestBlock->Validate()); - VMA_DEBUG_LOG(" Returned from existing block"); - (*pAllocation)->SetUserData(m_hAllocator, createInfo.pUserData); - m_hAllocator->m_Budget.AddAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), size); - if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) - { - m_hAllocator->FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); - } - if(IsCorruptionDetectionEnabled()) - { - VkResult res = pBestRequestBlock->WriteMagicValueAroundAllocation(m_hAllocator, bestRequest.offset, size); - VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to write magic value."); - } - return VK_SUCCESS; - } - // else: Some allocations must have been touched while we are here. Next try. - } - else + /* + There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory + before vkFreeMemory. + + if(currAlloc->GetMappedData() != VMA_NULL) { - // Could not find place in any of the blocks - break outer loop. - break; + (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); } + */ + + FreeVulkanMemory(memTypeIndex, currAlloc->GetSize(), hMemory); + m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), currAlloc->GetSize()); + m_AllocationObjectAllocator.Free(currAlloc); } - /* Maximum number of tries exceeded - a very unlike event when many other - threads are simultaneously touching allocations making it impossible to make - lost at the same time as we try to allocate. */ - if(tryIndex == VMA_ALLOCATION_TRY_COUNT) - { - return VK_ERROR_TOO_MANY_OBJECTS; - } + + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); } - return VK_ERROR_OUT_OF_DEVICE_MEMORY; + return res; } -void VmaBlockVector::Free( - const VmaAllocation hAllocation) +VkResult VmaAllocator_T::AllocateDedicatedMemoryPage( + VmaPool pool, + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + const VkMemoryAllocateInfo& allocInfo, + bool map, + bool isUserDataString, + bool isMappingAllowed, + void* pUserData, + VmaAllocation* pAllocation) { - VmaDeviceMemoryBlock* pBlockToDelete = VMA_NULL; - - bool budgetExceeded = false; + VkDeviceMemory hMemory = VK_NULL_HANDLE; + VkResult res = AllocateVulkanMemory(&allocInfo, &hMemory); + if(res < 0) { - const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); - VmaBudget heapBudget = {}; - m_hAllocator->GetBudget(&heapBudget, heapIndex, 1); - budgetExceeded = heapBudget.usage >= heapBudget.budget; + VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); + return res; } - // Scope for lock. + void* pMappedData = VMA_NULL; + if(map) { - VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); - - VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); - - if(IsCorruptionDetectionEnabled()) + res = (*m_VulkanFunctions.vkMapMemory)( + m_hDevice, + hMemory, + 0, + VK_WHOLE_SIZE, + 0, + &pMappedData); + if(res < 0) { - VkResult res = pBlock->ValidateMagicValueAroundAllocation(m_hAllocator, hAllocation->GetOffset(), hAllocation->GetSize()); - VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to validate magic value."); + VMA_DEBUG_LOG(" vkMapMemory FAILED"); + FreeVulkanMemory(memTypeIndex, size, hMemory); + return res; } + } - if(hAllocation->IsPersistentMap()) - { - pBlock->Unmap(m_hAllocator, 1); - } + *pAllocation = m_AllocationObjectAllocator.Allocate(isMappingAllowed); + (*pAllocation)->InitDedicatedAllocation(pool, memTypeIndex, hMemory, suballocType, pMappedData, size); + if (isUserDataString) + (*pAllocation)->SetName(this, (const char*)pUserData); + else + (*pAllocation)->SetUserData(this, pUserData); + m_Budget.AddAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), size); + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + } - pBlock->m_pMetadata->Free(hAllocation); - VMA_HEAVY_ASSERT(pBlock->Validate()); + return VK_SUCCESS; +} - VMA_DEBUG_LOG(" Freed from MemoryTypeIndex=%u", m_MemoryTypeIndex); +void VmaAllocator_T::GetBufferMemoryRequirements( + VkBuffer hBuffer, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const +{ +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VkBufferMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR }; + memReqInfo.buffer = hBuffer; - const bool canDeleteBlock = m_Blocks.size() > m_MinBlockCount; - // pBlock became empty after this deallocation. - if(pBlock->m_pMetadata->IsEmpty()) - { - // Already has empty block. We don't want to have two, so delete this one. - if((m_HasEmptyBlock || budgetExceeded) && canDeleteBlock) - { - pBlockToDelete = pBlock; - Remove(pBlock); - } - // else: We now have an empty block - leave it. - } - // pBlock didn't become empty, but we have another empty block - find and free that one. - // (This is optional, heuristics.) - else if(m_HasEmptyBlock && canDeleteBlock) - { - VmaDeviceMemoryBlock* pLastBlock = m_Blocks.back(); - if(pLastBlock->m_pMetadata->IsEmpty()) - { - pBlockToDelete = pLastBlock; - m_Blocks.pop_back(); - } - } + VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; - UpdateHasEmptyBlock(); - IncrementallySortBlocks(); - } + VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; + VmaPnextChainPushFront(&memReq2, &memDedicatedReq); - // Destruction of a free block. Deferred until this point, outside of mutex - // lock, for performance reason. - if(pBlockToDelete != VMA_NULL) + (*m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + + memReq = memReq2.memoryRequirements; + requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); + prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); + } + else +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 { - VMA_DEBUG_LOG(" Deleted empty block"); - pBlockToDelete->Destroy(m_hAllocator); - vma_delete(m_hAllocator, pBlockToDelete); + (*m_VulkanFunctions.vkGetBufferMemoryRequirements)(m_hDevice, hBuffer, &memReq); + requiresDedicatedAllocation = false; + prefersDedicatedAllocation = false; } } -VkDeviceSize VmaBlockVector::CalcMaxBlockSize() const +void VmaAllocator_T::GetImageMemoryRequirements( + VkImage hImage, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const { - VkDeviceSize result = 0; - for(size_t i = m_Blocks.size(); i--; ) +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) { - result = VMA_MAX(result, m_Blocks[i]->m_pMetadata->GetSize()); - if(result >= m_PreferredBlockSize) - { - break; - } + VkImageMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR }; + memReqInfo.image = hImage; + + VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; + + VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; + VmaPnextChainPushFront(&memReq2, &memDedicatedReq); + + (*m_VulkanFunctions.vkGetImageMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + + memReq = memReq2.memoryRequirements; + requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); + prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); + } + else +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + { + (*m_VulkanFunctions.vkGetImageMemoryRequirements)(m_hDevice, hImage, &memReq); + requiresDedicatedAllocation = false; + prefersDedicatedAllocation = false; } - return result; } -void VmaBlockVector::Remove(VmaDeviceMemoryBlock* pBlock) +VkResult VmaAllocator_T::FindMemoryTypeIndex( + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkFlags bufImgUsage, + uint32_t* pMemoryTypeIndex) const { - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - if(m_Blocks[blockIndex] == pBlock) - { - VmaVectorRemove(m_Blocks, blockIndex); - return; - } + memoryTypeBits &= GetGlobalMemoryTypeBits(); + + if(pAllocationCreateInfo->memoryTypeBits != 0) + { + memoryTypeBits &= pAllocationCreateInfo->memoryTypeBits; } - VMA_ASSERT(0); -} -void VmaBlockVector::IncrementallySortBlocks() -{ - if(m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + VkMemoryPropertyFlags requiredFlags = 0, preferredFlags = 0, notPreferredFlags = 0; + if(!FindMemoryPreferences( + IsIntegratedGpu(), + *pAllocationCreateInfo, + bufImgUsage, + requiredFlags, preferredFlags, notPreferredFlags)) { - // Bubble sort only until first swap. - for (const auto i : c10::irange(1, m_Blocks.size())) { - if(m_Blocks[i - 1]->m_pMetadata->GetSumFreeSize() > m_Blocks[i]->m_pMetadata->GetSumFreeSize()) + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + *pMemoryTypeIndex = UINT32_MAX; + uint32_t minCost = UINT32_MAX; + for(uint32_t memTypeIndex = 0, memTypeBit = 1; + memTypeIndex < GetMemoryTypeCount(); + ++memTypeIndex, memTypeBit <<= 1) + { + // This memory type is acceptable according to memoryTypeBits bitmask. + if((memTypeBit & memoryTypeBits) != 0) + { + const VkMemoryPropertyFlags currFlags = + m_MemProps.memoryTypes[memTypeIndex].propertyFlags; + // This memory type contains requiredFlags. + if((requiredFlags & ~currFlags) == 0) { - VMA_SWAP(m_Blocks[i - 1], m_Blocks[i]); - return; + // Calculate cost as number of bits from preferredFlags not present in this memory type. + uint32_t currCost = VMA_COUNT_BITS_SET(preferredFlags & ~currFlags) + + VMA_COUNT_BITS_SET(currFlags & notPreferredFlags); + // Remember memory type with lowest cost. + if(currCost < minCost) + { + *pMemoryTypeIndex = memTypeIndex; + if(currCost == 0) + { + return VK_SUCCESS; + } + minCost = currCost; + } } } } + return (*pMemoryTypeIndex != UINT32_MAX) ? VK_SUCCESS : VK_ERROR_FEATURE_NOT_PRESENT; } -VkResult VmaBlockVector::AllocateFromBlock( - VmaDeviceMemoryBlock* pBlock, - uint32_t currentFrameIndex, +VkResult VmaAllocator_T::CalcMemTypeParams( + VmaAllocationCreateInfo& inoutCreateInfo, + uint32_t memTypeIndex, VkDeviceSize size, - VkDeviceSize alignment, - VmaAllocationCreateFlags allocFlags, - void* pUserData, - VmaSuballocationType suballocType, - uint32_t strategy, - VmaAllocation* pAllocation) + size_t allocationCount) { - VMA_ASSERT((allocFlags & VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT) == 0); - const bool isUpperAddress = (allocFlags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; - const bool mapped = (allocFlags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0; - const bool isUserDataString = (allocFlags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0; - - VmaAllocationRequest currRequest = {}; - if(pBlock->m_pMetadata->CreateAllocationRequest( - currentFrameIndex, - m_FrameInUseCount, - m_BufferImageGranularity, - size, - alignment, - isUpperAddress, - suballocType, - false, // canMakeOtherLost - strategy, - &currRequest)) + // If memory type is not HOST_VISIBLE, disable MAPPED. + if((inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && + (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) { - // Allocate from pCurrBlock. - VMA_ASSERT(currRequest.itemsToMakeLostCount == 0); - - if(mapped) - { - VkResult res = pBlock->Map(m_hAllocator, 1, VMA_NULL); - if(res != VK_SUCCESS) - { - return res; - } - } + inoutCreateInfo.flags &= ~VMA_ALLOCATION_CREATE_MAPPED_BIT; + } - *pAllocation = m_hAllocator->m_AllocationObjectAllocator.Allocate(currentFrameIndex, isUserDataString); - pBlock->m_pMetadata->Alloc(currRequest, suballocType, size, *pAllocation); - UpdateHasEmptyBlock(); - (*pAllocation)->InitBlockAllocation( - pBlock, - currRequest.offset, - alignment, - size, - m_MemoryTypeIndex, - suballocType, - mapped, - (allocFlags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0); - VMA_HEAVY_ASSERT(pBlock->Validate()); - (*pAllocation)->SetUserData(m_hAllocator, pUserData); - m_hAllocator->m_Budget.AddAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), size); - if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) - { - m_hAllocator->FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); - } - if(IsCorruptionDetectionEnabled()) + if((inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0 && + (inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT) != 0) + { + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); + VmaBudget heapBudget = {}; + GetHeapBudgets(&heapBudget, heapIndex, 1); + if(heapBudget.usage + size * allocationCount > heapBudget.budget) { - VkResult res = pBlock->WriteMagicValueAroundAllocation(m_hAllocator, currRequest.offset, size); - VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to write magic value."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; } - return VK_SUCCESS; } - return VK_ERROR_OUT_OF_DEVICE_MEMORY; + return VK_SUCCESS; } -VkResult VmaBlockVector::CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex) +VkResult VmaAllocator_T::CalcAllocationParams( + VmaAllocationCreateInfo& inoutCreateInfo, + bool dedicatedRequired, + bool dedicatedPreferred) { - VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; - allocInfo.memoryTypeIndex = m_MemoryTypeIndex; - allocInfo.allocationSize = blockSize; - -#if VMA_BUFFER_DEVICE_ADDRESS - // Every standalone block can potentially contain a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT - always enable the feature. - VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; - if(m_hAllocator->m_UseKhrBufferDeviceAddress) + VMA_ASSERT((inoutCreateInfo.flags & + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != + (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT) && + "Specifying both flags VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT and VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT is incorrect."); + VMA_ASSERT((((inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT) == 0 || + (inoutCreateInfo.flags & (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0)) && + "Specifying VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT requires also VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT."); + if(inoutCreateInfo.usage == VMA_MEMORY_USAGE_AUTO || inoutCreateInfo.usage == VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE || inoutCreateInfo.usage == VMA_MEMORY_USAGE_AUTO_PREFER_HOST) { - allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; - VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); + if((inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0) + { + VMA_ASSERT((inoutCreateInfo.flags & (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) != 0 && + "When using VMA_ALLOCATION_CREATE_MAPPED_BIT and usage = VMA_MEMORY_USAGE_AUTO*, you must also specify VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT."); + } } -#endif // #if VMA_BUFFER_DEVICE_ADDRESS -#if VMA_MEMORY_PRIORITY - VkMemoryPriorityAllocateInfoEXT priorityInfo = { VK_STRUCTURE_TYPE_MEMORY_PRIORITY_ALLOCATE_INFO_EXT }; - if(m_hAllocator->m_UseExtMemoryPriority) + // If memory is lazily allocated, it should be always dedicated. + if(dedicatedRequired || + inoutCreateInfo.usage == VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED) { - priorityInfo.priority = m_Priority; - VmaPnextChainPushFront(&allocInfo, &priorityInfo); + inoutCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; } -#endif // #if VMA_MEMORY_PRIORITY - VkDeviceMemory mem = VK_NULL_HANDLE; - VkResult res = m_hAllocator->AllocateVulkanMemory(&allocInfo, &mem); - if(res < 0) + if(inoutCreateInfo.pool != VK_NULL_HANDLE) { - return res; + if(inoutCreateInfo.pool->m_BlockVector.HasExplicitBlockSize() && + (inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0) + { + VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT while current custom pool doesn't support dedicated allocations."); + return VK_ERROR_FEATURE_NOT_PRESENT; + } + inoutCreateInfo.priority = inoutCreateInfo.pool->m_BlockVector.GetPriority(); } - // New VkDeviceMemory successfully created. + if((inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0 && + (inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT together with VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT makes no sense."); + return VK_ERROR_FEATURE_NOT_PRESENT; + } - // Create new Allocation for it. - VmaDeviceMemoryBlock* const pBlock = vma_new(m_hAllocator, VmaDeviceMemoryBlock)(m_hAllocator); - pBlock->Init( - m_hAllocator, - m_hParentPool, - m_MemoryTypeIndex, - mem, - allocInfo.allocationSize, - m_NextBlockId++, - m_Algorithm); + if(VMA_DEBUG_ALWAYS_DEDICATED_MEMORY && + (inoutCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + inoutCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; + } - m_Blocks.push_back(pBlock); - if(pNewBlockIndex != VMA_NULL) + // Non-auto USAGE values imply HOST_ACCESS flags. + // And so does VMA_MEMORY_USAGE_UNKNOWN because it is used with custom pools. + // Which specific flag is used doesn't matter. They change things only when used with VMA_MEMORY_USAGE_AUTO*. + // Otherwise they just protect from assert on mapping. + if(inoutCreateInfo.usage != VMA_MEMORY_USAGE_AUTO && + inoutCreateInfo.usage != VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE && + inoutCreateInfo.usage != VMA_MEMORY_USAGE_AUTO_PREFER_HOST) { - *pNewBlockIndex = m_Blocks.size() - 1; + if((inoutCreateInfo.flags & (VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT)) == 0) + { + inoutCreateInfo.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT; + } } return VK_SUCCESS; } -void VmaBlockVector::ApplyDefragmentationMovesCpu( - class VmaBlockVectorDefragmentationContext* pDefragCtx, - const VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves) +VkResult VmaAllocator_T::AllocateMemory( + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + VkBuffer dedicatedBuffer, + VkImage dedicatedImage, + VkFlags dedicatedBufferImageUsage, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations) { - const size_t blockCount = m_Blocks.size(); - const bool isNonCoherent = m_hAllocator->IsMemoryTypeNonCoherent(m_MemoryTypeIndex); + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); - enum BLOCK_FLAG - { - BLOCK_FLAG_USED = 0x00000001, - BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION = 0x00000002, - }; + VMA_ASSERT(VmaIsPow2(vkMemReq.alignment)); - struct BlockInfo + if(vkMemReq.size == 0) { - uint32_t flags; - void* pMappedData; - }; - VmaVector< BlockInfo, VmaStlAllocator > - blockInfo(blockCount, BlockInfo(), VmaStlAllocator(m_hAllocator->GetAllocationCallbacks())); - memset(blockInfo.data(), 0, blockCount * sizeof(BlockInfo)); - - // Go over all moves. Mark blocks that are used with BLOCK_FLAG_USED. - const size_t moveCount = moves.size(); - for (const auto moveIndex : c10::irange(moveCount)) { - const VmaDefragmentationMove& move = moves[moveIndex]; - blockInfo[move.srcBlockIndex].flags |= BLOCK_FLAG_USED; - blockInfo[move.dstBlockIndex].flags |= BLOCK_FLAG_USED; + return VK_ERROR_INITIALIZATION_FAILED; } - VMA_ASSERT(pDefragCtx->res == VK_SUCCESS); + VmaAllocationCreateInfo createInfoFinal = createInfo; + VkResult res = CalcAllocationParams(createInfoFinal, requiresDedicatedAllocation, prefersDedicatedAllocation); + if(res != VK_SUCCESS) + return res; - // Go over all blocks. Get mapped pointer or map if necessary. - for(size_t blockIndex = 0; pDefragCtx->res == VK_SUCCESS && blockIndex < blockCount; ++blockIndex) + if(createInfoFinal.pool != VK_NULL_HANDLE) { - BlockInfo& currBlockInfo = blockInfo[blockIndex]; - VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; - if((currBlockInfo.flags & BLOCK_FLAG_USED) != 0) - { - currBlockInfo.pMappedData = pBlock->GetMappedData(); - // It is not originally mapped - map it. - if(currBlockInfo.pMappedData == VMA_NULL) - { - pDefragCtx->res = pBlock->Map(m_hAllocator, 1, &currBlockInfo.pMappedData); - if(pDefragCtx->res == VK_SUCCESS) - { - currBlockInfo.flags |= BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION; - } - } - } + VmaBlockVector& blockVector = createInfoFinal.pool->m_BlockVector; + return AllocateMemoryOfType( + createInfoFinal.pool, + vkMemReq.size, + vkMemReq.alignment, + prefersDedicatedAllocation, + dedicatedBuffer, + dedicatedImage, + dedicatedBufferImageUsage, + createInfoFinal, + blockVector.GetMemoryTypeIndex(), + suballocType, + createInfoFinal.pool->m_DedicatedAllocations, + blockVector, + allocationCount, + pAllocations); } - - // Go over all moves. Do actual data transfer. - if(pDefragCtx->res == VK_SUCCESS) + else { - const VkDeviceSize nonCoherentAtomSize = m_hAllocator->m_PhysicalDeviceProperties.limits.nonCoherentAtomSize; - VkMappedMemoryRange memRange = { VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE }; - - for (const auto moveIndex : c10::irange(moveCount)) { - const VmaDefragmentationMove& move = moves[moveIndex]; + // Bit mask of memory Vulkan types acceptable for this allocation. + uint32_t memoryTypeBits = vkMemReq.memoryTypeBits; + uint32_t memTypeIndex = UINT32_MAX; + res = FindMemoryTypeIndex(memoryTypeBits, &createInfoFinal, dedicatedBufferImageUsage, &memTypeIndex); + // Can't find any single memory type matching requirements. res is VK_ERROR_FEATURE_NOT_PRESENT. + if(res != VK_SUCCESS) + return res; + do + { + VmaBlockVector* blockVector = m_pBlockVectors[memTypeIndex]; + VMA_ASSERT(blockVector && "Trying to use unsupported memory type!"); + res = AllocateMemoryOfType( + VK_NULL_HANDLE, + vkMemReq.size, + vkMemReq.alignment, + requiresDedicatedAllocation || prefersDedicatedAllocation, + dedicatedBuffer, + dedicatedImage, + dedicatedBufferImageUsage, + createInfoFinal, + memTypeIndex, + suballocType, + m_DedicatedAllocations[memTypeIndex], + *blockVector, + allocationCount, + pAllocations); + // Allocation succeeded + if(res == VK_SUCCESS) + return VK_SUCCESS; - const BlockInfo& srcBlockInfo = blockInfo[move.srcBlockIndex]; - const BlockInfo& dstBlockInfo = blockInfo[move.dstBlockIndex]; + // Remove old memTypeIndex from list of possibilities. + memoryTypeBits &= ~(1u << memTypeIndex); + // Find alternative memTypeIndex. + res = FindMemoryTypeIndex(memoryTypeBits, &createInfoFinal, dedicatedBufferImageUsage, &memTypeIndex); + } while(res == VK_SUCCESS); - VMA_ASSERT(srcBlockInfo.pMappedData && dstBlockInfo.pMappedData); + // No other matching memory type index could be found. + // Not returning res, which is VK_ERROR_FEATURE_NOT_PRESENT, because we already failed to allocate once. + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } +} - // Invalidate source. - if(isNonCoherent) - { - VmaDeviceMemoryBlock* const pSrcBlock = m_Blocks[move.srcBlockIndex]; - memRange.memory = pSrcBlock->GetDeviceMemory(); - memRange.offset = VmaAlignDown(move.srcOffset, nonCoherentAtomSize); - memRange.size = VMA_MIN( - VmaAlignUp(move.size + (move.srcOffset - memRange.offset), nonCoherentAtomSize), - pSrcBlock->m_pMetadata->GetSize() - memRange.offset); - (*m_hAllocator->GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hAllocator->m_hDevice, 1, &memRange); - } +void VmaAllocator_T::FreeMemory( + size_t allocationCount, + const VmaAllocation* pAllocations) +{ + VMA_ASSERT(pAllocations); - // THE PLACE WHERE ACTUAL DATA COPY HAPPENS. - memmove( - reinterpret_cast(dstBlockInfo.pMappedData) + move.dstOffset, - reinterpret_cast(srcBlockInfo.pMappedData) + move.srcOffset, - static_cast(move.size)); + for(size_t allocIndex = allocationCount; allocIndex--; ) + { + VmaAllocation allocation = pAllocations[allocIndex]; - if(IsCorruptionDetectionEnabled()) + if(allocation != VK_NULL_HANDLE) + { + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) { - VmaWriteMagicValue(dstBlockInfo.pMappedData, move.dstOffset - VMA_DEBUG_MARGIN); - VmaWriteMagicValue(dstBlockInfo.pMappedData, move.dstOffset + move.size); + FillAllocation(allocation, VMA_ALLOCATION_FILL_PATTERN_DESTROYED); } - // Flush destination. - if(isNonCoherent) + allocation->FreeName(this); + + switch(allocation->GetType()) { - VmaDeviceMemoryBlock* const pDstBlock = m_Blocks[move.dstBlockIndex]; - memRange.memory = pDstBlock->GetDeviceMemory(); - memRange.offset = VmaAlignDown(move.dstOffset, nonCoherentAtomSize); - memRange.size = VMA_MIN( - VmaAlignUp(move.size + (move.dstOffset - memRange.offset), nonCoherentAtomSize), - pDstBlock->m_pMetadata->GetSize() - memRange.offset); - (*m_hAllocator->GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hAllocator->m_hDevice, 1, &memRange); + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaBlockVector* pBlockVector = VMA_NULL; + VmaPool hPool = allocation->GetParentPool(); + if(hPool != VK_NULL_HANDLE) + { + pBlockVector = &hPool->m_BlockVector; + } + else + { + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + pBlockVector = m_pBlockVectors[memTypeIndex]; + VMA_ASSERT(pBlockVector && "Trying to free memory of unsupported type!"); + } + pBlockVector->Free(allocation); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + FreeDedicatedMemory(allocation); + break; + default: + VMA_ASSERT(0); } } } - - // Go over all blocks in reverse order. Unmap those that were mapped just for defragmentation. - // Regardless of pCtx->res == VK_SUCCESS. - for(size_t blockIndex = blockCount; blockIndex--; ) - { - const BlockInfo& currBlockInfo = blockInfo[blockIndex]; - if((currBlockInfo.flags & BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION) != 0) - { - VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; - pBlock->Unmap(m_hAllocator, 1); - } - } } -void VmaBlockVector::ApplyDefragmentationMovesGpu( - class VmaBlockVectorDefragmentationContext* pDefragCtx, - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkCommandBuffer commandBuffer) +void VmaAllocator_T::CalculateStatistics(VmaTotalStatistics* pStats) { - const size_t blockCount = m_Blocks.size(); - - pDefragCtx->blockContexts.resize(blockCount); - memset(pDefragCtx->blockContexts.data(), 0, blockCount * sizeof(VmaBlockDefragmentationContext)); - - // Go over all moves. Mark blocks that are used with BLOCK_FLAG_USED. - const size_t moveCount = moves.size(); - for (const auto moveIndex : c10::irange(moveCount)) { - const VmaDefragmentationMove& move = moves[moveIndex]; + // Initialize. + VmaClearDetailedStatistics(pStats->total); + for(uint32_t i = 0; i < VK_MAX_MEMORY_TYPES; ++i) + VmaClearDetailedStatistics(pStats->memoryType[i]); + for(uint32_t i = 0; i < VK_MAX_MEMORY_HEAPS; ++i) + VmaClearDetailedStatistics(pStats->memoryHeap[i]); - //if(move.type == VMA_ALLOCATION_TYPE_UNKNOWN) - { - // Old school move still require us to map the whole block - pDefragCtx->blockContexts[move.srcBlockIndex].flags |= VmaBlockDefragmentationContext::BLOCK_FLAG_USED; - pDefragCtx->blockContexts[move.dstBlockIndex].flags |= VmaBlockDefragmentationContext::BLOCK_FLAG_USED; - } + // Process default pools. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; + if (pBlockVector != VMA_NULL) + pBlockVector->AddDetailedStatistics(pStats->memoryType[memTypeIndex]); } - VMA_ASSERT(pDefragCtx->res == VK_SUCCESS); - - // Go over all blocks. Create and bind buffer for whole block if necessary. + // Process custom pools. { - VkBufferCreateInfo bufCreateInfo; - VmaFillGpuDefragmentationBufferCreateInfo(bufCreateInfo); - - for(size_t blockIndex = 0; pDefragCtx->res == VK_SUCCESS && blockIndex < blockCount; ++blockIndex) + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + for(VmaPool pool = m_Pools.Front(); pool != VMA_NULL; pool = m_Pools.GetNext(pool)) { - VmaBlockDefragmentationContext& currBlockCtx = pDefragCtx->blockContexts[blockIndex]; - VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; - if((currBlockCtx.flags & VmaBlockDefragmentationContext::BLOCK_FLAG_USED) != 0) - { - bufCreateInfo.size = pBlock->m_pMetadata->GetSize(); - pDefragCtx->res = (*m_hAllocator->GetVulkanFunctions().vkCreateBuffer)( - m_hAllocator->m_hDevice, &bufCreateInfo, m_hAllocator->GetAllocationCallbacks(), &currBlockCtx.hBuffer); - if(pDefragCtx->res == VK_SUCCESS) - { - pDefragCtx->res = (*m_hAllocator->GetVulkanFunctions().vkBindBufferMemory)( - m_hAllocator->m_hDevice, currBlockCtx.hBuffer, pBlock->GetDeviceMemory(), 0); - } - } + VmaBlockVector& blockVector = pool->m_BlockVector; + const uint32_t memTypeIndex = blockVector.GetMemoryTypeIndex(); + blockVector.AddDetailedStatistics(pStats->memoryType[memTypeIndex]); + pool->m_DedicatedAllocations.AddDetailedStatistics(pStats->memoryType[memTypeIndex]); } } - // Go over all moves. Post data transfer commands to command buffer. - if(pDefragCtx->res == VK_SUCCESS) + // Process dedicated allocations. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) { - for (const auto moveIndex : c10::irange(moveCount)) { - const VmaDefragmentationMove& move = moves[moveIndex]; - - const VmaBlockDefragmentationContext& srcBlockCtx = pDefragCtx->blockContexts[move.srcBlockIndex]; - const VmaBlockDefragmentationContext& dstBlockCtx = pDefragCtx->blockContexts[move.dstBlockIndex]; - - VMA_ASSERT(srcBlockCtx.hBuffer && dstBlockCtx.hBuffer); - - VkBufferCopy region = { - move.srcOffset, - move.dstOffset, - move.size }; - (*m_hAllocator->GetVulkanFunctions().vkCmdCopyBuffer)( - commandBuffer, srcBlockCtx.hBuffer, dstBlockCtx.hBuffer, 1, ®ion); - } + m_DedicatedAllocations[memTypeIndex].AddDetailedStatistics(pStats->memoryType[memTypeIndex]); } - // Save buffers to defrag context for later destruction. - if(pDefragCtx->res == VK_SUCCESS && moveCount > 0) + // Sum from memory types to memory heaps. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) { - pDefragCtx->res = VK_NOT_READY; + const uint32_t memHeapIndex = m_MemProps.memoryTypes[memTypeIndex].heapIndex; + VmaAddDetailedStatistics(pStats->memoryHeap[memHeapIndex], pStats->memoryType[memTypeIndex]); } + + // Sum from memory heaps to total. + for(uint32_t memHeapIndex = 0; memHeapIndex < GetMemoryHeapCount(); ++memHeapIndex) + VmaAddDetailedStatistics(pStats->total, pStats->memoryHeap[memHeapIndex]); + + VMA_ASSERT(pStats->total.statistics.allocationCount == 0 || + pStats->total.allocationSizeMax >= pStats->total.allocationSizeMin); + VMA_ASSERT(pStats->total.unusedRangeCount == 0 || + pStats->total.unusedRangeSizeMax >= pStats->total.unusedRangeSizeMin); } -void VmaBlockVector::FreeEmptyBlocks(VmaDefragmentationStats* pDefragmentationStats) +void VmaAllocator_T::GetHeapBudgets(VmaBudget* outBudgets, uint32_t firstHeap, uint32_t heapCount) { - for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) { - VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; - if(pBlock->m_pMetadata->IsEmpty()) + if(m_Budget.m_OperationsSinceBudgetFetch < 30) { - if(m_Blocks.size() > m_MinBlockCount) + VmaMutexLockRead lockRead(m_Budget.m_BudgetMutex, m_UseMutex); + for(uint32_t i = 0; i < heapCount; ++i, ++outBudgets) { - if(pDefragmentationStats != VMA_NULL) + const uint32_t heapIndex = firstHeap + i; + + outBudgets->statistics.blockCount = m_Budget.m_BlockCount[heapIndex]; + outBudgets->statistics.allocationCount = m_Budget.m_AllocationCount[heapIndex]; + outBudgets->statistics.blockBytes = m_Budget.m_BlockBytes[heapIndex]; + outBudgets->statistics.allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; + + if(m_Budget.m_VulkanUsage[heapIndex] + outBudgets->statistics.blockBytes > m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]) { - ++pDefragmentationStats->deviceMemoryBlocksFreed; - pDefragmentationStats->bytesFreed += pBlock->m_pMetadata->GetSize(); + outBudgets->usage = m_Budget.m_VulkanUsage[heapIndex] + + outBudgets->statistics.blockBytes - m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; + } + else + { + outBudgets->usage = 0; } - VmaVectorRemove(m_Blocks, blockIndex); - pBlock->Destroy(m_hAllocator); - vma_delete(m_hAllocator, pBlock); - } - else - { - break; + // Have to take MIN with heap size because explicit HeapSizeLimit is included in it. + outBudgets->budget = VMA_MIN( + m_Budget.m_VulkanBudget[heapIndex], m_MemProps.memoryHeaps[heapIndex].size); } } + else + { + UpdateVulkanBudget(); // Outside of mutex lock + GetHeapBudgets(outBudgets, firstHeap, heapCount); // Recursion + } } - UpdateHasEmptyBlock(); -} - -void VmaBlockVector::UpdateHasEmptyBlock() -{ - m_HasEmptyBlock = false; - for(size_t index = 0, count = m_Blocks.size(); index < count; ++index) + else +#endif { - VmaDeviceMemoryBlock* const pBlock = m_Blocks[index]; - if(pBlock->m_pMetadata->IsEmpty()) + for(uint32_t i = 0; i < heapCount; ++i, ++outBudgets) { - m_HasEmptyBlock = true; - break; + const uint32_t heapIndex = firstHeap + i; + + outBudgets->statistics.blockCount = m_Budget.m_BlockCount[heapIndex]; + outBudgets->statistics.allocationCount = m_Budget.m_AllocationCount[heapIndex]; + outBudgets->statistics.blockBytes = m_Budget.m_BlockBytes[heapIndex]; + outBudgets->statistics.allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; + + outBudgets->usage = outBudgets->statistics.blockBytes; + outBudgets->budget = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. } } } -#if VMA_STATS_STRING_ENABLED +void VmaAllocator_T::GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo) +{ + pAllocationInfo->memoryType = hAllocation->GetMemoryTypeIndex(); + pAllocationInfo->deviceMemory = hAllocation->GetMemory(); + pAllocationInfo->offset = hAllocation->GetOffset(); + pAllocationInfo->size = hAllocation->GetSize(); + pAllocationInfo->pMappedData = hAllocation->GetMappedData(); + pAllocationInfo->pUserData = hAllocation->GetUserData(); + pAllocationInfo->pName = hAllocation->GetName(); +} -void VmaBlockVector::PrintDetailedMap(class VmaJsonWriter& json) +VkResult VmaAllocator_T::CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool) { - VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + VMA_DEBUG_LOG(" CreatePool: MemoryTypeIndex=%u, flags=%u", pCreateInfo->memoryTypeIndex, pCreateInfo->flags); - json.BeginObject(); + VmaPoolCreateInfo newCreateInfo = *pCreateInfo; - if(IsCustomPool()) + // Protection against uninitialized new structure member. If garbage data are left there, this pointer dereference would crash. + if(pCreateInfo->pMemoryAllocateNext) { - const char* poolName = m_hParentPool->GetName(); - if(poolName != VMA_NULL && poolName[0] != '\0') - { - json.WriteString("Name"); - json.WriteString(poolName); - } + VMA_ASSERT(((const VkBaseInStructure*)pCreateInfo->pMemoryAllocateNext)->sType != 0); + } - json.WriteString("MemoryTypeIndex"); - json.WriteNumber(m_MemoryTypeIndex); + if(newCreateInfo.maxBlockCount == 0) + { + newCreateInfo.maxBlockCount = SIZE_MAX; + } + if(newCreateInfo.minBlockCount > newCreateInfo.maxBlockCount) + { + return VK_ERROR_INITIALIZATION_FAILED; + } + // Memory type index out of range or forbidden. + if(pCreateInfo->memoryTypeIndex >= GetMemoryTypeCount() || + ((1u << pCreateInfo->memoryTypeIndex) & m_GlobalMemoryTypeBits) == 0) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + if(newCreateInfo.minAllocationAlignment > 0) + { + VMA_ASSERT(VmaIsPow2(newCreateInfo.minAllocationAlignment)); + } - json.WriteString("BlockSize"); - json.WriteNumber(m_PreferredBlockSize); - - json.WriteString("BlockCount"); - json.BeginObject(true); - if(m_MinBlockCount > 0) - { - json.WriteString("Min"); - json.WriteNumber((uint64_t)m_MinBlockCount); - } - if(m_MaxBlockCount < SIZE_MAX) - { - json.WriteString("Max"); - json.WriteNumber((uint64_t)m_MaxBlockCount); - } - json.WriteString("Cur"); - json.WriteNumber((uint64_t)m_Blocks.size()); - json.EndObject(); + const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(newCreateInfo.memoryTypeIndex); - if(m_FrameInUseCount > 0) - { - json.WriteString("FrameInUseCount"); - json.WriteNumber(m_FrameInUseCount); - } + *pPool = vma_new(this, VmaPool_T)(this, newCreateInfo, preferredBlockSize); - if(m_Algorithm != 0) - { - json.WriteString("Algorithm"); - json.WriteString(VmaAlgorithmToStr(m_Algorithm)); - } + VkResult res = (*pPool)->m_BlockVector.CreateMinBlocks(); + if(res != VK_SUCCESS) + { + vma_delete(this, *pPool); + *pPool = VMA_NULL; + return res; } - else + + // Add to m_Pools. { - json.WriteString("PreferredBlockSize"); - json.WriteNumber(m_PreferredBlockSize); + VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); + (*pPool)->SetId(m_NextPoolId++); + m_Pools.PushBack(*pPool); } - json.WriteString("Blocks"); - json.BeginObject(); - for (const auto i : c10::irange(m_Blocks.size())) { - json.BeginString(); - json.ContinueString(m_Blocks[i]->GetId()); - json.EndString(); + return VK_SUCCESS; +} - m_Blocks[i]->m_pMetadata->PrintDetailedMap(json); +void VmaAllocator_T::DestroyPool(VmaPool pool) +{ + // Remove from m_Pools. + { + VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); + m_Pools.Remove(pool); } - json.EndObject(); - json.EndObject(); + vma_delete(this, pool); } -#endif // #if VMA_STATS_STRING_ENABLED - -void VmaBlockVector::Defragment( - class VmaBlockVectorDefragmentationContext* pCtx, - VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags, - VkDeviceSize& maxCpuBytesToMove, uint32_t& maxCpuAllocationsToMove, - VkDeviceSize& maxGpuBytesToMove, uint32_t& maxGpuAllocationsToMove, - VkCommandBuffer commandBuffer) +void VmaAllocator_T::GetPoolStatistics(VmaPool pool, VmaStatistics* pPoolStats) { - pCtx->res = VK_SUCCESS; + VmaClearStatistics(*pPoolStats); + pool->m_BlockVector.AddStatistics(*pPoolStats); + pool->m_DedicatedAllocations.AddStatistics(*pPoolStats); +} - const VkMemoryPropertyFlags memPropFlags = - m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags; - const bool isHostVisible = (memPropFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0; +void VmaAllocator_T::CalculatePoolStatistics(VmaPool pool, VmaDetailedStatistics* pPoolStats) +{ + VmaClearDetailedStatistics(*pPoolStats); + pool->m_BlockVector.AddDetailedStatistics(*pPoolStats); + pool->m_DedicatedAllocations.AddDetailedStatistics(*pPoolStats); +} - const bool canDefragmentOnCpu = maxCpuBytesToMove > 0 && maxCpuAllocationsToMove > 0 && - isHostVisible; - const bool canDefragmentOnGpu = maxGpuBytesToMove > 0 && maxGpuAllocationsToMove > 0 && - !IsCorruptionDetectionEnabled() && - ((1u << m_MemoryTypeIndex) & m_hAllocator->GetGpuDefragmentationMemoryTypeBits()) != 0; +void VmaAllocator_T::SetCurrentFrameIndex(uint32_t frameIndex) +{ + m_CurrentFrameIndex.store(frameIndex); - // There are options to defragment this memory type. - if(canDefragmentOnCpu || canDefragmentOnGpu) +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) { - bool defragmentOnGpu; - // There is only one option to defragment this memory type. - if(canDefragmentOnGpu != canDefragmentOnCpu) - { - defragmentOnGpu = canDefragmentOnGpu; - } - // Both options are available: Heuristics to choose the best one. - else - { - defragmentOnGpu = (memPropFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) != 0 || - m_hAllocator->IsIntegratedGpu(); - } - - bool overlappingMoveSupported = !defragmentOnGpu; - - if(m_hAllocator->m_UseMutex) - { - if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) - { - if(!m_Mutex.TryLockWrite()) - { - pCtx->res = VK_ERROR_INITIALIZATION_FAILED; - return; - } - } - else - { - m_Mutex.LockWrite(); - pCtx->mutexLocked = true; - } - } - - pCtx->Begin(overlappingMoveSupported, flags); + UpdateVulkanBudget(); + } +#endif // #if VMA_MEMORY_BUDGET +} - // Defragment. +VkResult VmaAllocator_T::CheckPoolCorruption(VmaPool hPool) +{ + return hPool->m_BlockVector.CheckCorruption(); +} - const VkDeviceSize maxBytesToMove = defragmentOnGpu ? maxGpuBytesToMove : maxCpuBytesToMove; - const uint32_t maxAllocationsToMove = defragmentOnGpu ? maxGpuAllocationsToMove : maxCpuAllocationsToMove; - pCtx->res = pCtx->GetAlgorithm()->Defragment(pCtx->defragmentationMoves, maxBytesToMove, maxAllocationsToMove, flags); +VkResult VmaAllocator_T::CheckCorruption(uint32_t memoryTypeBits) +{ + VkResult finalRes = VK_ERROR_FEATURE_NOT_PRESENT; - // Accumulate statistics. - if(pStats != VMA_NULL) + // Process default pools. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; + if(pBlockVector != VMA_NULL) { - const VkDeviceSize bytesMoved = pCtx->GetAlgorithm()->GetBytesMoved(); - const uint32_t allocationsMoved = pCtx->GetAlgorithm()->GetAllocationsMoved(); - pStats->bytesMoved += bytesMoved; - pStats->allocationsMoved += allocationsMoved; - VMA_ASSERT(bytesMoved <= maxBytesToMove); - VMA_ASSERT(allocationsMoved <= maxAllocationsToMove); - if(defragmentOnGpu) - { - maxGpuBytesToMove -= bytesMoved; - maxGpuAllocationsToMove -= allocationsMoved; - } - else + VkResult localRes = pBlockVector->CheckCorruption(); + switch(localRes) { - maxCpuBytesToMove -= bytesMoved; - maxCpuAllocationsToMove -= allocationsMoved; + case VK_ERROR_FEATURE_NOT_PRESENT: + break; + case VK_SUCCESS: + finalRes = VK_SUCCESS; + break; + default: + return localRes; } } + } - if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) - { - if(m_hAllocator->m_UseMutex) - m_Mutex.UnlockWrite(); - - if(pCtx->res >= VK_SUCCESS && !pCtx->defragmentationMoves.empty()) - pCtx->res = VK_NOT_READY; - - return; - } - - if(pCtx->res >= VK_SUCCESS) + // Process custom pools. + { + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + for(VmaPool pool = m_Pools.Front(); pool != VMA_NULL; pool = m_Pools.GetNext(pool)) { - if(defragmentOnGpu) - { - ApplyDefragmentationMovesGpu(pCtx, pCtx->defragmentationMoves, commandBuffer); - } - else + if(((1u << pool->m_BlockVector.GetMemoryTypeIndex()) & memoryTypeBits) != 0) { - ApplyDefragmentationMovesCpu(pCtx, pCtx->defragmentationMoves); + VkResult localRes = pool->m_BlockVector.CheckCorruption(); + switch(localRes) + { + case VK_ERROR_FEATURE_NOT_PRESENT: + break; + case VK_SUCCESS: + finalRes = VK_SUCCESS; + break; + default: + return localRes; + } } } } + + return finalRes; } -void VmaBlockVector::DefragmentationEnd( - class VmaBlockVectorDefragmentationContext* pCtx, - uint32_t flags, - VmaDefragmentationStats* pStats) +VkResult VmaAllocator_T::AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory) { - if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL && m_hAllocator->m_UseMutex) + AtomicTransactionalIncrement deviceMemoryCountIncrement; + const uint64_t prevDeviceMemoryCount = deviceMemoryCountIncrement.Increment(&m_DeviceMemoryCount); +#if VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT + if(prevDeviceMemoryCount >= m_PhysicalDeviceProperties.limits.maxMemoryAllocationCount) { - VMA_ASSERT(pCtx->mutexLocked == false); - - // Incremental defragmentation doesn't hold the lock, so when we enter here we don't actually have any - // lock protecting us. Since we mutate state here, we have to take the lock out now - m_Mutex.LockWrite(); - pCtx->mutexLocked = true; + return VK_ERROR_TOO_MANY_OBJECTS; } +#endif + + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(pAllocateInfo->memoryTypeIndex); - // If the mutex isn't locked we didn't do any work and there is nothing to delete. - if(pCtx->mutexLocked || !m_hAllocator->m_UseMutex) + // HeapSizeLimit is in effect for this heap. + if((m_HeapSizeLimitMask & (1u << heapIndex)) != 0) { - // Destroy buffers. - for(size_t blockIndex = pCtx->blockContexts.size(); blockIndex--;) + const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; + VkDeviceSize blockBytes = m_Budget.m_BlockBytes[heapIndex]; + for(;;) { - VmaBlockDefragmentationContext &blockCtx = pCtx->blockContexts[blockIndex]; - if(blockCtx.hBuffer) + const VkDeviceSize blockBytesAfterAllocation = blockBytes + pAllocateInfo->allocationSize; + if(blockBytesAfterAllocation > heapSize) { - (*m_hAllocator->GetVulkanFunctions().vkDestroyBuffer)(m_hAllocator->m_hDevice, blockCtx.hBuffer, m_hAllocator->GetAllocationCallbacks()); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + if(m_Budget.m_BlockBytes[heapIndex].compare_exchange_strong(blockBytes, blockBytesAfterAllocation)) + { + break; } - } - - if(pCtx->res >= VK_SUCCESS) - { - FreeEmptyBlocks(pStats); } } - - if(pCtx->mutexLocked) + else { - VMA_ASSERT(m_hAllocator->m_UseMutex); - m_Mutex.UnlockWrite(); + m_Budget.m_BlockBytes[heapIndex] += pAllocateInfo->allocationSize; } -} - -uint32_t VmaBlockVector::ProcessDefragmentations( - class VmaBlockVectorDefragmentationContext *pCtx, - VmaDefragmentationPassMoveInfo* pMove, uint32_t maxMoves) -{ - VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + ++m_Budget.m_BlockCount[heapIndex]; - const uint32_t moveCount = VMA_MIN(uint32_t(pCtx->defragmentationMoves.size()) - pCtx->defragmentationMovesProcessed, maxMoves); + // VULKAN CALL vkAllocateMemory. + VkResult res = (*m_VulkanFunctions.vkAllocateMemory)(m_hDevice, pAllocateInfo, GetAllocationCallbacks(), pMemory); - for(uint32_t i = 0; i < moveCount; ++ i) + if(res == VK_SUCCESS) { - VmaDefragmentationMove& move = pCtx->defragmentationMoves[pCtx->defragmentationMovesProcessed + i]; +#if VMA_MEMORY_BUDGET + ++m_Budget.m_OperationsSinceBudgetFetch; +#endif - pMove->allocation = move.hAllocation; - pMove->memory = move.pDstBlock->GetDeviceMemory(); - pMove->offset = move.dstOffset; + // Informative callback. + if(m_DeviceMemoryCallbacks.pfnAllocate != VMA_NULL) + { + (*m_DeviceMemoryCallbacks.pfnAllocate)(this, pAllocateInfo->memoryTypeIndex, *pMemory, pAllocateInfo->allocationSize, m_DeviceMemoryCallbacks.pUserData); + } - ++ pMove; + deviceMemoryCountIncrement.Commit(); + } + else + { + --m_Budget.m_BlockCount[heapIndex]; + m_Budget.m_BlockBytes[heapIndex] -= pAllocateInfo->allocationSize; } - pCtx->defragmentationMovesProcessed += moveCount; - - return moveCount; + return res; } -void VmaBlockVector::CommitDefragmentations( - class VmaBlockVectorDefragmentationContext *pCtx, - VmaDefragmentationStats* pStats) +void VmaAllocator_T::FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory) { - VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); - - for(uint32_t i = pCtx->defragmentationMovesCommitted; i < pCtx->defragmentationMovesProcessed; ++ i) + // Informative callback. + if(m_DeviceMemoryCallbacks.pfnFree != VMA_NULL) { - const VmaDefragmentationMove &move = pCtx->defragmentationMoves[i]; - - move.pSrcBlock->m_pMetadata->FreeAtOffset(move.srcOffset); - move.hAllocation->ChangeBlockAllocation(m_hAllocator, move.pDstBlock, move.dstOffset); + (*m_DeviceMemoryCallbacks.pfnFree)(this, memoryType, hMemory, size, m_DeviceMemoryCallbacks.pUserData); } - pCtx->defragmentationMovesCommitted = pCtx->defragmentationMovesProcessed; - FreeEmptyBlocks(pStats); -} + // VULKAN CALL vkFreeMemory. + (*m_VulkanFunctions.vkFreeMemory)(m_hDevice, hMemory, GetAllocationCallbacks()); -size_t VmaBlockVector::CalcAllocationCount() const -{ - size_t result = 0; - for (const auto i : c10::irange(m_Blocks.size())) { - result += m_Blocks[i]->m_pMetadata->GetAllocationCount(); - } - return result; + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memoryType); + --m_Budget.m_BlockCount[heapIndex]; + m_Budget.m_BlockBytes[heapIndex] -= size; + + --m_DeviceMemoryCount; } -bool VmaBlockVector::IsBufferImageGranularityConflictPossible() const +VkResult VmaAllocator_T::BindVulkanBuffer( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkBuffer buffer, + const void* pNext) { - if(m_BufferImageGranularity == 1) - { - return false; - } - VmaSuballocationType lastSuballocType = VMA_SUBALLOCATION_TYPE_FREE; - for(size_t i = 0, count = m_Blocks.size(); i < count; ++i) + if(pNext != VMA_NULL) { - VmaDeviceMemoryBlock* const pBlock = m_Blocks[i]; - VMA_ASSERT(m_Algorithm == 0); - VmaBlockMetadata_Generic* const pMetadata = (VmaBlockMetadata_Generic*)pBlock->m_pMetadata; - if(pMetadata->IsBufferImageGranularityConflictPossible(m_BufferImageGranularity, lastSuballocType)) +#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && + m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL) { - return true; + VkBindBufferMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_BUFFER_MEMORY_INFO_KHR }; + bindBufferMemoryInfo.pNext = pNext; + bindBufferMemoryInfo.buffer = buffer; + bindBufferMemoryInfo.memory = memory; + bindBufferMemoryInfo.memoryOffset = memoryOffset; + return (*m_VulkanFunctions.vkBindBufferMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); + } + else +#endif // #if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + { + return VK_ERROR_EXTENSION_NOT_PRESENT; } } - return false; -} - -void VmaBlockVector::MakePoolAllocationsLost( - uint32_t currentFrameIndex, - size_t* pLostAllocationCount) -{ - VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); - size_t lostAllocationCount = 0; - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pBlock); - lostAllocationCount += pBlock->m_pMetadata->MakeAllocationsLost(currentFrameIndex, m_FrameInUseCount); - } - if(pLostAllocationCount != VMA_NULL) + else { - *pLostAllocationCount = lostAllocationCount; + return (*m_VulkanFunctions.vkBindBufferMemory)(m_hDevice, buffer, memory, memoryOffset); } } -VkResult VmaBlockVector::CheckCorruption() +VkResult VmaAllocator_T::BindVulkanImage( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkImage image, + const void* pNext) { - if(!IsCorruptionDetectionEnabled()) + if(pNext != VMA_NULL) { - return VK_ERROR_FEATURE_NOT_PRESENT; - } - - VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pBlock); - VkResult res = pBlock->CheckCorruption(m_hAllocator); - if(res != VK_SUCCESS) +#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && + m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL) { - return res; + VkBindImageMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_INFO_KHR }; + bindBufferMemoryInfo.pNext = pNext; + bindBufferMemoryInfo.image = image; + bindBufferMemoryInfo.memory = memory; + bindBufferMemoryInfo.memoryOffset = memoryOffset; + return (*m_VulkanFunctions.vkBindImageMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); + } + else +#endif // #if VMA_BIND_MEMORY2 + { + return VK_ERROR_EXTENSION_NOT_PRESENT; } } - return VK_SUCCESS; + else + { + return (*m_VulkanFunctions.vkBindImageMemory)(m_hDevice, image, memory, memoryOffset); + } } -void VmaBlockVector::AddStats(VmaStats* pStats) +VkResult VmaAllocator_T::Map(VmaAllocation hAllocation, void** ppData) { - const uint32_t memTypeIndex = m_MemoryTypeIndex; - const uint32_t memHeapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(memTypeIndex); - - VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); - - for (const auto blockIndex : c10::irange(m_Blocks.size())) { - const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; - VMA_ASSERT(pBlock); - VMA_HEAVY_ASSERT(pBlock->Validate()); - VmaStatInfo allocationStatInfo; - pBlock->m_pMetadata->CalcAllocationStatInfo(allocationStatInfo); - VmaAddStatInfo(pStats->total, allocationStatInfo); - VmaAddStatInfo(pStats->memoryType[memTypeIndex], allocationStatInfo); - VmaAddStatInfo(pStats->memoryHeap[memHeapIndex], allocationStatInfo); + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + char *pBytes = VMA_NULL; + VkResult res = pBlock->Map(this, 1, (void**)&pBytes); + if(res == VK_SUCCESS) + { + *ppData = pBytes + (ptrdiff_t)hAllocation->GetOffset(); + hAllocation->BlockAllocMap(); + } + return res; + } + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + return hAllocation->DedicatedAllocMap(this, ppData); + default: + VMA_ASSERT(0); + return VK_ERROR_MEMORY_MAP_FAILED; } } -//////////////////////////////////////////////////////////////////////////////// -// VmaDefragmentationAlgorithm_Generic members definition - -VmaDefragmentationAlgorithm_Generic::VmaDefragmentationAlgorithm_Generic( - VmaAllocator hAllocator, - VmaBlockVector* pBlockVector, - uint32_t currentFrameIndex, - bool overlappingMoveSupported) : - VmaDefragmentationAlgorithm(hAllocator, pBlockVector, currentFrameIndex), - m_AllocationCount(0), - m_AllAllocations(false), - m_BytesMoved(0), - m_AllocationsMoved(0), - m_Blocks(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +void VmaAllocator_T::Unmap(VmaAllocation hAllocation) { - // Create block info for each block. - const size_t blockCount = m_pBlockVector->m_Blocks.size(); - for (const auto blockIndex : c10::irange(blockCount)) { - BlockInfo* pBlockInfo = vma_new(m_hAllocator, BlockInfo)(m_hAllocator->GetAllocationCallbacks()); - pBlockInfo->m_OriginalBlockIndex = blockIndex; - pBlockInfo->m_pBlock = m_pBlockVector->m_Blocks[blockIndex]; - m_Blocks.push_back(pBlockInfo); + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + hAllocation->BlockAllocUnmap(); + pBlock->Unmap(this, 1); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + hAllocation->DedicatedAllocUnmap(this); + break; + default: + VMA_ASSERT(0); } +} - // Sort them by m_pBlock pointer value. - VMA_SORT(m_Blocks.begin(), m_Blocks.end(), BlockPointerLess()); +VkResult VmaAllocator_T::BindBufferMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext) +{ + VkResult res = VK_SUCCESS; + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + res = BindVulkanBuffer(hAllocation->GetMemory(), allocationLocalOffset, hBuffer, pNext); + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + VMA_ASSERT(pBlock && "Binding buffer to allocation that doesn't belong to any block."); + res = pBlock->BindBufferMemory(this, hAllocation, allocationLocalOffset, hBuffer, pNext); + break; + } + default: + VMA_ASSERT(0); + } + return res; } -VmaDefragmentationAlgorithm_Generic::~VmaDefragmentationAlgorithm_Generic() +VkResult VmaAllocator_T::BindImageMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext) { - for(size_t i = m_Blocks.size(); i--; ) + VkResult res = VK_SUCCESS; + switch(hAllocation->GetType()) { - vma_delete(m_hAllocator, m_Blocks[i]); + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + res = BindVulkanImage(hAllocation->GetMemory(), allocationLocalOffset, hImage, pNext); + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); + VMA_ASSERT(pBlock && "Binding image to allocation that doesn't belong to any block."); + res = pBlock->BindImageMemory(this, hAllocation, allocationLocalOffset, hImage, pNext); + break; + } + default: + VMA_ASSERT(0); } + return res; } -void VmaDefragmentationAlgorithm_Generic::AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) +VkResult VmaAllocator_T::FlushOrInvalidateAllocation( + VmaAllocation hAllocation, + VkDeviceSize offset, VkDeviceSize size, + VMA_CACHE_OPERATION op) { - // Now as we are inside VmaBlockVector::m_Mutex, we can make final check if this allocation was not lost. - if(hAlloc->GetLastUseFrameIndex() != VMA_FRAME_INDEX_LOST) + VkResult res = VK_SUCCESS; + + VkMappedMemoryRange memRange = {}; + if(GetFlushOrInvalidateRange(hAllocation, offset, size, memRange)) { - VmaDeviceMemoryBlock* pBlock = hAlloc->GetBlock(); - BlockInfoVector::iterator it = VmaBinaryFindFirstNotLess(m_Blocks.begin(), m_Blocks.end(), pBlock, BlockPointerLess()); - if(it != m_Blocks.end() && (*it)->m_pBlock == pBlock) - { - AllocationInfo allocInfo = AllocationInfo(hAlloc, pChanged); - (*it)->m_Allocations.push_back(allocInfo); - } - else + switch(op) { + case VMA_CACHE_FLUSH: + res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, 1, &memRange); + break; + case VMA_CACHE_INVALIDATE: + res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, 1, &memRange); + break; + default: VMA_ASSERT(0); } - - ++m_AllocationCount; } + // else: Just ignore this call. + return res; } -VkResult VmaDefragmentationAlgorithm_Generic::DefragmentRound( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - bool freeOldAllocations) +VkResult VmaAllocator_T::FlushOrInvalidateAllocations( + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, const VkDeviceSize* sizes, + VMA_CACHE_OPERATION op) { - if(m_Blocks.empty()) - { - return VK_SUCCESS; - } - - // This is a choice based on research. - // Option 1: - uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT; - // Option 2: - //uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT; - // Option 3: - //uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT; + typedef VmaStlAllocator RangeAllocator; + typedef VmaSmallVector RangeVector; + RangeVector ranges = RangeVector(RangeAllocator(GetAllocationCallbacks())); - size_t srcBlockMinIndex = 0; - // When FAST_ALGORITHM, move allocations from only last out of blocks that contain non-movable allocations. - /* - if(m_AlgorithmFlags & VMA_DEFRAGMENTATION_FAST_ALGORITHM_BIT) + for(uint32_t allocIndex = 0; allocIndex < allocationCount; ++allocIndex) { - const size_t blocksWithNonMovableCount = CalcBlocksWithNonMovableCount(); - if(blocksWithNonMovableCount > 0) + const VmaAllocation alloc = allocations[allocIndex]; + const VkDeviceSize offset = offsets != VMA_NULL ? offsets[allocIndex] : 0; + const VkDeviceSize size = sizes != VMA_NULL ? sizes[allocIndex] : VK_WHOLE_SIZE; + VkMappedMemoryRange newRange; + if(GetFlushOrInvalidateRange(alloc, offset, size, newRange)) { - srcBlockMinIndex = blocksWithNonMovableCount - 1; + ranges.push_back(newRange); } } - */ - size_t srcBlockIndex = m_Blocks.size() - 1; - size_t srcAllocIndex = SIZE_MAX; - for(;;) + VkResult res = VK_SUCCESS; + if(!ranges.empty()) { - // 1. Find next allocation to move. - // 1.1. Start from last to first m_Blocks - they are sorted from most "destination" to most "source". - // 1.2. Then start from last to first m_Allocations. - while(srcAllocIndex >= m_Blocks[srcBlockIndex]->m_Allocations.size()) + switch(op) { - if(m_Blocks[srcBlockIndex]->m_Allocations.empty()) - { - // Finished: no more allocations to process. - if(srcBlockIndex == srcBlockMinIndex) - { - return VK_SUCCESS; - } - else - { - --srcBlockIndex; - srcAllocIndex = SIZE_MAX; - } - } - else - { - srcAllocIndex = m_Blocks[srcBlockIndex]->m_Allocations.size() - 1; - } + case VMA_CACHE_FLUSH: + res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); + break; + case VMA_CACHE_INVALIDATE: + res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); + break; + default: + VMA_ASSERT(0); } + } + // else: Just ignore this call. + return res; +} - BlockInfo* pSrcBlockInfo = m_Blocks[srcBlockIndex]; - AllocationInfo& allocInfo = pSrcBlockInfo->m_Allocations[srcAllocIndex]; - - const VkDeviceSize size = allocInfo.m_hAllocation->GetSize(); - const VkDeviceSize srcOffset = allocInfo.m_hAllocation->GetOffset(); - const VkDeviceSize alignment = allocInfo.m_hAllocation->GetAlignment(); - const VmaSuballocationType suballocType = allocInfo.m_hAllocation->GetSuballocationType(); - - // 2. Try to find new place for this allocation in preceding or current block. - for(size_t dstBlockIndex = 0; dstBlockIndex <= srcBlockIndex; ++dstBlockIndex) - { - BlockInfo* pDstBlockInfo = m_Blocks[dstBlockIndex]; - VmaAllocationRequest dstAllocRequest; - if(pDstBlockInfo->m_pBlock->m_pMetadata->CreateAllocationRequest( - m_CurrentFrameIndex, - m_pBlockVector->GetFrameInUseCount(), - m_pBlockVector->GetBufferImageGranularity(), - size, - alignment, - false, // upperAddress - suballocType, - false, // canMakeOtherLost - strategy, - &dstAllocRequest) && - MoveMakesSense( - dstBlockIndex, dstAllocRequest.offset, srcBlockIndex, srcOffset)) - { - VMA_ASSERT(dstAllocRequest.itemsToMakeLostCount == 0); +void VmaAllocator_T::FreeDedicatedMemory(const VmaAllocation allocation) +{ + VMA_ASSERT(allocation && allocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); - // Reached limit on number of allocations or bytes to move. - if((m_AllocationsMoved + 1 > maxAllocationsToMove) || - (m_BytesMoved + size > maxBytesToMove)) - { - return VK_SUCCESS; - } + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + VmaPool parentPool = allocation->GetParentPool(); + if(parentPool == VK_NULL_HANDLE) + { + // Default pool + m_DedicatedAllocations[memTypeIndex].Unregister(allocation); + } + else + { + // Custom pool + parentPool->m_DedicatedAllocations.Unregister(allocation); + } - VmaDefragmentationMove move = {}; - move.srcBlockIndex = pSrcBlockInfo->m_OriginalBlockIndex; - move.dstBlockIndex = pDstBlockInfo->m_OriginalBlockIndex; - move.srcOffset = srcOffset; - move.dstOffset = dstAllocRequest.offset; - move.size = size; - move.hAllocation = allocInfo.m_hAllocation; - move.pSrcBlock = pSrcBlockInfo->m_pBlock; - move.pDstBlock = pDstBlockInfo->m_pBlock; + VkDeviceMemory hMemory = allocation->GetMemory(); - moves.push_back(move); + /* + There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory + before vkFreeMemory. - pDstBlockInfo->m_pBlock->m_pMetadata->Alloc( - dstAllocRequest, - suballocType, - size, - allocInfo.m_hAllocation); + if(allocation->GetMappedData() != VMA_NULL) + { + (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); + } + */ - if(freeOldAllocations) - { - pSrcBlockInfo->m_pBlock->m_pMetadata->FreeAtOffset(srcOffset); - allocInfo.m_hAllocation->ChangeBlockAllocation(m_hAllocator, pDstBlockInfo->m_pBlock, dstAllocRequest.offset); - } + FreeVulkanMemory(memTypeIndex, allocation->GetSize(), hMemory); - if(allocInfo.m_pChanged != VMA_NULL) - { - *allocInfo.m_pChanged = VK_TRUE; - } + m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(allocation->GetMemoryTypeIndex()), allocation->GetSize()); + m_AllocationObjectAllocator.Free(allocation); - ++m_AllocationsMoved; - m_BytesMoved += size; + VMA_DEBUG_LOG(" Freed DedicatedMemory MemoryTypeIndex=%u", memTypeIndex); +} - VmaVectorRemove(pSrcBlockInfo->m_Allocations, srcAllocIndex); +uint32_t VmaAllocator_T::CalculateGpuDefragmentationMemoryTypeBits() const +{ + VkBufferCreateInfo dummyBufCreateInfo; + VmaFillGpuDefragmentationBufferCreateInfo(dummyBufCreateInfo); - break; - } - } + uint32_t memoryTypeBits = 0; - // If not processed, this allocInfo remains in pBlockInfo->m_Allocations for next round. + // Create buffer. + VkBuffer buf = VK_NULL_HANDLE; + VkResult res = (*GetVulkanFunctions().vkCreateBuffer)( + m_hDevice, &dummyBufCreateInfo, GetAllocationCallbacks(), &buf); + if(res == VK_SUCCESS) + { + // Query for supported memory types. + VkMemoryRequirements memReq; + (*GetVulkanFunctions().vkGetBufferMemoryRequirements)(m_hDevice, buf, &memReq); + memoryTypeBits = memReq.memoryTypeBits; - if(srcAllocIndex > 0) - { - --srcAllocIndex; - } - else - { - if(srcBlockIndex > 0) - { - --srcBlockIndex; - srcAllocIndex = SIZE_MAX; - } - else - { - return VK_SUCCESS; - } - } + // Destroy buffer. + (*GetVulkanFunctions().vkDestroyBuffer)(m_hDevice, buf, GetAllocationCallbacks()); } -} -size_t VmaDefragmentationAlgorithm_Generic::CalcBlocksWithNonMovableCount() const -{ - size_t result = 0; - for (const auto i : c10::irange(m_Blocks.size())) { - if(m_Blocks[i]->m_HasNonMovableAllocations) - { - ++result; - } - } - return result; + return memoryTypeBits; } -VkResult VmaDefragmentationAlgorithm_Generic::Defragment( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - VmaDefragmentationFlags flags) +uint32_t VmaAllocator_T::CalculateGlobalMemoryTypeBits() const { - if(!m_AllAllocations && m_AllocationCount == 0) - { - return VK_SUCCESS; - } + // Make sure memory information is already fetched. + VMA_ASSERT(GetMemoryTypeCount() > 0); - const size_t blockCount = m_Blocks.size(); - for (const auto blockIndex : c10::irange(blockCount)) { - BlockInfo* pBlockInfo = m_Blocks[blockIndex]; + uint32_t memoryTypeBits = UINT32_MAX; - if(m_AllAllocations) + if(!m_UseAmdDeviceCoherentMemory) + { + // Exclude memory types that have VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) { - VmaBlockMetadata_Generic* pMetadata = (VmaBlockMetadata_Generic*)pBlockInfo->m_pBlock->m_pMetadata; - for(VmaSuballocationList::const_iterator it = pMetadata->m_Suballocations.begin(); - it != pMetadata->m_Suballocations.end(); - ++it) + if((m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) != 0) { - if(it->type != VMA_SUBALLOCATION_TYPE_FREE) - { - AllocationInfo allocInfo = AllocationInfo(it->hAllocation, VMA_NULL); - pBlockInfo->m_Allocations.push_back(allocInfo); - } + memoryTypeBits &= ~(1u << memTypeIndex); } } - - pBlockInfo->CalcHasNonMovableAllocations(); - - // This is a choice based on research. - // Option 1: - pBlockInfo->SortAllocationsByOffsetDescending(); - // Option 2: - //pBlockInfo->SortAllocationsBySizeDescending(); - } - - // Sort m_Blocks this time by the main criterium, from most "destination" to most "source" blocks. - VMA_SORT(m_Blocks.begin(), m_Blocks.end(), BlockInfoCompareMoveDestination()); - - // This is a choice based on research. - const uint32_t roundCount = 2; - - // Execute defragmentation rounds (the main part). - VkResult result = VK_SUCCESS; - for(uint32_t round = 0; (round < roundCount) && (result == VK_SUCCESS); ++round) - { - result = DefragmentRound(moves, maxBytesToMove, maxAllocationsToMove, !(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL)); } - return result; + return memoryTypeBits; } -bool VmaDefragmentationAlgorithm_Generic::MoveMakesSense( - size_t dstBlockIndex, VkDeviceSize dstOffset, - size_t srcBlockIndex, VkDeviceSize srcOffset) +bool VmaAllocator_T::GetFlushOrInvalidateRange( + VmaAllocation allocation, + VkDeviceSize offset, VkDeviceSize size, + VkMappedMemoryRange& outRange) const { - if(dstBlockIndex < srcBlockIndex) + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + if(size > 0 && IsMemoryTypeNonCoherent(memTypeIndex)) { - return true; - } - if(dstBlockIndex > srcBlockIndex) - { - return false; - } - if(dstOffset < srcOffset) - { - return true; - } - return false; -} - -//////////////////////////////////////////////////////////////////////////////// -// VmaDefragmentationAlgorithm_Fast - -VmaDefragmentationAlgorithm_Fast::VmaDefragmentationAlgorithm_Fast( - VmaAllocator hAllocator, - VmaBlockVector* pBlockVector, - uint32_t currentFrameIndex, - bool overlappingMoveSupported) : - VmaDefragmentationAlgorithm(hAllocator, pBlockVector, currentFrameIndex), - m_OverlappingMoveSupported(overlappingMoveSupported), - m_AllocationCount(0), - m_AllAllocations(false), - m_BytesMoved(0), - m_AllocationsMoved(0), - m_BlockInfos(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) -{ - VMA_ASSERT(VMA_DEBUG_MARGIN == 0); - -} - -VmaDefragmentationAlgorithm_Fast::~VmaDefragmentationAlgorithm_Fast() -{ -} - -VkResult VmaDefragmentationAlgorithm_Fast::Defragment( - VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, - VkDeviceSize maxBytesToMove, - uint32_t maxAllocationsToMove, - VmaDefragmentationFlags flags) -{ - VMA_ASSERT(m_AllAllocations || m_pBlockVector->CalcAllocationCount() == m_AllocationCount); - - const size_t blockCount = m_pBlockVector->GetBlockCount(); - if(blockCount == 0 || maxBytesToMove == 0 || maxAllocationsToMove == 0) - { - return VK_SUCCESS; - } - - PreprocessMetadata(); - - // Sort blocks in order from most destination. - - m_BlockInfos.resize(blockCount); - for (const auto i : c10::irange(blockCount)) { - m_BlockInfos[i].origBlockIndex = i; - } - - VMA_SORT(m_BlockInfos.begin(), m_BlockInfos.end(), [this](const BlockInfo& lhs, const BlockInfo& rhs) -> bool { - return m_pBlockVector->GetBlock(lhs.origBlockIndex)->m_pMetadata->GetSumFreeSize() < - m_pBlockVector->GetBlock(rhs.origBlockIndex)->m_pMetadata->GetSumFreeSize(); - }); - - // THE MAIN ALGORITHM - - FreeSpaceDatabase freeSpaceDb; + const VkDeviceSize nonCoherentAtomSize = m_PhysicalDeviceProperties.limits.nonCoherentAtomSize; + const VkDeviceSize allocationSize = allocation->GetSize(); + VMA_ASSERT(offset <= allocationSize); - size_t dstBlockInfoIndex = 0; - size_t dstOrigBlockIndex = m_BlockInfos[dstBlockInfoIndex].origBlockIndex; - VmaDeviceMemoryBlock* pDstBlock = m_pBlockVector->GetBlock(dstOrigBlockIndex); - VmaBlockMetadata_Generic* pDstMetadata = (VmaBlockMetadata_Generic*)pDstBlock->m_pMetadata; - VkDeviceSize dstBlockSize = pDstMetadata->GetSize(); - VkDeviceSize dstOffset = 0; + outRange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + outRange.pNext = VMA_NULL; + outRange.memory = allocation->GetMemory(); - bool end = false; - for(size_t srcBlockInfoIndex = 0; !end && srcBlockInfoIndex < blockCount; ++srcBlockInfoIndex) - { - const size_t srcOrigBlockIndex = m_BlockInfos[srcBlockInfoIndex].origBlockIndex; - VmaDeviceMemoryBlock* const pSrcBlock = m_pBlockVector->GetBlock(srcOrigBlockIndex); - VmaBlockMetadata_Generic* const pSrcMetadata = (VmaBlockMetadata_Generic*)pSrcBlock->m_pMetadata; - for(VmaSuballocationList::iterator srcSuballocIt = pSrcMetadata->m_Suballocations.begin(); - !end && srcSuballocIt != pSrcMetadata->m_Suballocations.end(); ) + switch(allocation->GetType()) { - VmaAllocation_T* const pAlloc = srcSuballocIt->hAllocation; - const VkDeviceSize srcAllocAlignment = pAlloc->GetAlignment(); - const VkDeviceSize srcAllocSize = srcSuballocIt->size; - if(m_AllocationsMoved == maxAllocationsToMove || - m_BytesMoved + srcAllocSize > maxBytesToMove) + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); + if(size == VK_WHOLE_SIZE) { - end = true; - break; + outRange.size = allocationSize - outRange.offset; } - const VkDeviceSize srcAllocOffset = srcSuballocIt->offset; - - VmaDefragmentationMove move = {}; - // Try to place it in one of free spaces from the database. - size_t freeSpaceInfoIndex; - VkDeviceSize dstAllocOffset; - if(freeSpaceDb.Fetch(srcAllocAlignment, srcAllocSize, - freeSpaceInfoIndex, dstAllocOffset)) + else { - size_t freeSpaceOrigBlockIndex = m_BlockInfos[freeSpaceInfoIndex].origBlockIndex; - VmaDeviceMemoryBlock* pFreeSpaceBlock = m_pBlockVector->GetBlock(freeSpaceOrigBlockIndex); - VmaBlockMetadata_Generic* pFreeSpaceMetadata = (VmaBlockMetadata_Generic*)pFreeSpaceBlock->m_pMetadata; - - // Same block - if(freeSpaceInfoIndex == srcBlockInfoIndex) - { - VMA_ASSERT(dstAllocOffset <= srcAllocOffset); - - // MOVE OPTION 1: Move the allocation inside the same block by decreasing offset. - - VmaSuballocation suballoc = *srcSuballocIt; - suballoc.offset = dstAllocOffset; - suballoc.hAllocation->ChangeOffset(dstAllocOffset); - m_BytesMoved += srcAllocSize; - ++m_AllocationsMoved; - - VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; - ++nextSuballocIt; - pSrcMetadata->m_Suballocations.erase(srcSuballocIt); - srcSuballocIt = nextSuballocIt; - - InsertSuballoc(pFreeSpaceMetadata, suballoc); - - move.srcBlockIndex = srcOrigBlockIndex; - move.dstBlockIndex = freeSpaceOrigBlockIndex; - move.srcOffset = srcAllocOffset; - move.dstOffset = dstAllocOffset; - move.size = srcAllocSize; - - moves.push_back(move); - } - // Different block - else - { - // MOVE OPTION 2: Move the allocation to a different block. - - VMA_ASSERT(freeSpaceInfoIndex < srcBlockInfoIndex); - - VmaSuballocation suballoc = *srcSuballocIt; - suballoc.offset = dstAllocOffset; - suballoc.hAllocation->ChangeBlockAllocation(m_hAllocator, pFreeSpaceBlock, dstAllocOffset); - m_BytesMoved += srcAllocSize; - ++m_AllocationsMoved; - - VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; - ++nextSuballocIt; - pSrcMetadata->m_Suballocations.erase(srcSuballocIt); - srcSuballocIt = nextSuballocIt; - - InsertSuballoc(pFreeSpaceMetadata, suballoc); - - move.srcBlockIndex = srcOrigBlockIndex; - move.dstBlockIndex = freeSpaceOrigBlockIndex; - move.srcOffset = srcAllocOffset; - move.dstOffset = dstAllocOffset; - move.size = srcAllocSize; - - moves.push_back(move); - } + VMA_ASSERT(offset + size <= allocationSize); + outRange.size = VMA_MIN( + VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize), + allocationSize - outRange.offset); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + // 1. Still within this allocation. + outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); + if(size == VK_WHOLE_SIZE) + { + size = allocationSize - offset; } else { - dstAllocOffset = VmaAlignUp(dstOffset, srcAllocAlignment); - - // If the allocation doesn't fit before the end of dstBlock, forward to next block. - while(dstBlockInfoIndex < srcBlockInfoIndex && - dstAllocOffset + srcAllocSize > dstBlockSize) - { - // But before that, register remaining free space at the end of dst block. - freeSpaceDb.Register(dstBlockInfoIndex, dstOffset, dstBlockSize - dstOffset); - - ++dstBlockInfoIndex; - dstOrigBlockIndex = m_BlockInfos[dstBlockInfoIndex].origBlockIndex; - pDstBlock = m_pBlockVector->GetBlock(dstOrigBlockIndex); - pDstMetadata = (VmaBlockMetadata_Generic*)pDstBlock->m_pMetadata; - dstBlockSize = pDstMetadata->GetSize(); - dstOffset = 0; - dstAllocOffset = 0; - } - - // Same block - if(dstBlockInfoIndex == srcBlockInfoIndex) - { - VMA_ASSERT(dstAllocOffset <= srcAllocOffset); - - const bool overlap = dstAllocOffset + srcAllocSize > srcAllocOffset; + VMA_ASSERT(offset + size <= allocationSize); + } + outRange.size = VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize); - bool skipOver = overlap; - if(overlap && m_OverlappingMoveSupported && dstAllocOffset < srcAllocOffset) - { - // If destination and source place overlap, skip if it would move it - // by only < 1/64 of its size. - skipOver = (srcAllocOffset - dstAllocOffset) * 64 < srcAllocSize; - } + // 2. Adjust to whole block. + const VkDeviceSize allocationOffset = allocation->GetOffset(); + VMA_ASSERT(allocationOffset % nonCoherentAtomSize == 0); + const VkDeviceSize blockSize = allocation->GetBlock()->m_pMetadata->GetSize(); + outRange.offset += allocationOffset; + outRange.size = VMA_MIN(outRange.size, blockSize - outRange.offset); - if(skipOver) - { - freeSpaceDb.Register(dstBlockInfoIndex, dstOffset, srcAllocOffset - dstOffset); + break; + } + default: + VMA_ASSERT(0); + } + return true; + } + return false; +} - dstOffset = srcAllocOffset + srcAllocSize; - ++srcSuballocIt; - } - // MOVE OPTION 1: Move the allocation inside the same block by decreasing offset. - else - { - srcSuballocIt->offset = dstAllocOffset; - srcSuballocIt->hAllocation->ChangeOffset(dstAllocOffset); - dstOffset = dstAllocOffset + srcAllocSize; - m_BytesMoved += srcAllocSize; - ++m_AllocationsMoved; - ++srcSuballocIt; - - move.srcBlockIndex = srcOrigBlockIndex; - move.dstBlockIndex = dstOrigBlockIndex; - move.srcOffset = srcAllocOffset; - move.dstOffset = dstAllocOffset; - move.size = srcAllocSize; - - moves.push_back(move); - } - } - // Different block - else - { - // MOVE OPTION 2: Move the allocation to a different block. +#if VMA_MEMORY_BUDGET +void VmaAllocator_T::UpdateVulkanBudget() +{ + VMA_ASSERT(m_UseExtMemoryBudget); - VMA_ASSERT(dstBlockInfoIndex < srcBlockInfoIndex); - VMA_ASSERT(dstAllocOffset + srcAllocSize <= dstBlockSize); + VkPhysicalDeviceMemoryProperties2KHR memProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2_KHR }; - VmaSuballocation suballoc = *srcSuballocIt; - suballoc.offset = dstAllocOffset; - suballoc.hAllocation->ChangeBlockAllocation(m_hAllocator, pDstBlock, dstAllocOffset); - dstOffset = dstAllocOffset + srcAllocSize; - m_BytesMoved += srcAllocSize; - ++m_AllocationsMoved; + VkPhysicalDeviceMemoryBudgetPropertiesEXT budgetProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT }; + VmaPnextChainPushFront(&memProps, &budgetProps); - VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; - ++nextSuballocIt; - pSrcMetadata->m_Suballocations.erase(srcSuballocIt); - srcSuballocIt = nextSuballocIt; + GetVulkanFunctions().vkGetPhysicalDeviceMemoryProperties2KHR(m_PhysicalDevice, &memProps); - pDstMetadata->m_Suballocations.push_back(suballoc); + { + VmaMutexLockWrite lockWrite(m_Budget.m_BudgetMutex, m_UseMutex); - move.srcBlockIndex = srcOrigBlockIndex; - move.dstBlockIndex = dstOrigBlockIndex; - move.srcOffset = srcAllocOffset; - move.dstOffset = dstAllocOffset; - move.size = srcAllocSize; + for(uint32_t heapIndex = 0; heapIndex < GetMemoryHeapCount(); ++heapIndex) + { + m_Budget.m_VulkanUsage[heapIndex] = budgetProps.heapUsage[heapIndex]; + m_Budget.m_VulkanBudget[heapIndex] = budgetProps.heapBudget[heapIndex]; + m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] = m_Budget.m_BlockBytes[heapIndex].load(); - moves.push_back(move); - } + // Some bugged drivers return the budget incorrectly, e.g. 0 or much bigger than heap size. + if(m_Budget.m_VulkanBudget[heapIndex] == 0) + { + m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. } - } - } - - m_BlockInfos.clear(); - - PostprocessMetadata(); - - return VK_SUCCESS; -} - -void VmaDefragmentationAlgorithm_Fast::PreprocessMetadata() -{ - const size_t blockCount = m_pBlockVector->GetBlockCount(); - for (const auto blockIndex : c10::irange(blockCount)) { - VmaBlockMetadata_Generic* const pMetadata = - (VmaBlockMetadata_Generic*)m_pBlockVector->GetBlock(blockIndex)->m_pMetadata; - pMetadata->m_FreeCount = 0; - pMetadata->m_SumFreeSize = pMetadata->GetSize(); - pMetadata->m_FreeSuballocationsBySize.clear(); - for(VmaSuballocationList::iterator it = pMetadata->m_Suballocations.begin(); - it != pMetadata->m_Suballocations.end(); ) - { - if(it->type == VMA_SUBALLOCATION_TYPE_FREE) + else if(m_Budget.m_VulkanBudget[heapIndex] > m_MemProps.memoryHeaps[heapIndex].size) { - VmaSuballocationList::iterator nextIt = it; - ++nextIt; - pMetadata->m_Suballocations.erase(it); - it = nextIt; + m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size; } - else + if(m_Budget.m_VulkanUsage[heapIndex] == 0 && m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] > 0) { - ++it; + m_Budget.m_VulkanUsage[heapIndex] = m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; } } + m_Budget.m_OperationsSinceBudgetFetch = 0; } } +#endif // VMA_MEMORY_BUDGET -void VmaDefragmentationAlgorithm_Fast::PostprocessMetadata() +void VmaAllocator_T::FillAllocation(const VmaAllocation hAllocation, uint8_t pattern) { - const size_t blockCount = m_pBlockVector->GetBlockCount(); - for (const auto blockIndex : c10::irange(blockCount)) { - VmaBlockMetadata_Generic* const pMetadata = - (VmaBlockMetadata_Generic*)m_pBlockVector->GetBlock(blockIndex)->m_pMetadata; - const VkDeviceSize blockSize = pMetadata->GetSize(); - - // No allocations in this block - entire area is free. - if(pMetadata->m_Suballocations.empty()) + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS && + hAllocation->IsMappingAllowed() && + (m_MemProps.memoryTypes[hAllocation->GetMemoryTypeIndex()].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) + { + void* pData = VMA_NULL; + VkResult res = Map(hAllocation, &pData); + if(res == VK_SUCCESS) { - pMetadata->m_FreeCount = 1; - //pMetadata->m_SumFreeSize is already set to blockSize. - VmaSuballocation suballoc = { - 0, // offset - blockSize, // size - VMA_NULL, // hAllocation - VMA_SUBALLOCATION_TYPE_FREE }; - pMetadata->m_Suballocations.push_back(suballoc); - pMetadata->RegisterFreeSuballocation(pMetadata->m_Suballocations.begin()); + memset(pData, (int)pattern, (size_t)hAllocation->GetSize()); + FlushOrInvalidateAllocation(hAllocation, 0, VK_WHOLE_SIZE, VMA_CACHE_FLUSH); + Unmap(hAllocation); } - // There are some allocations in this block. else { - VkDeviceSize offset = 0; - VmaSuballocationList::iterator it; - for(it = pMetadata->m_Suballocations.begin(); - it != pMetadata->m_Suballocations.end(); - ++it) - { - VMA_ASSERT(it->type != VMA_SUBALLOCATION_TYPE_FREE); - VMA_ASSERT(it->offset >= offset); + VMA_ASSERT(0 && "VMA_DEBUG_INITIALIZE_ALLOCATIONS is enabled, but couldn't map memory to fill allocation."); + } + } +} + +uint32_t VmaAllocator_T::GetGpuDefragmentationMemoryTypeBits() +{ + uint32_t memoryTypeBits = m_GpuDefragmentationMemoryTypeBits.load(); + if(memoryTypeBits == UINT32_MAX) + { + memoryTypeBits = CalculateGpuDefragmentationMemoryTypeBits(); + m_GpuDefragmentationMemoryTypeBits.store(memoryTypeBits); + } + return memoryTypeBits; +} - // Need to insert preceding free space. - if(it->offset > offset) +#if VMA_STATS_STRING_ENABLED +void VmaAllocator_T::PrintDetailedMap(VmaJsonWriter& json) +{ + json.WriteString("DefaultPools"); + json.BeginObject(); + { + for (uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + VmaBlockVector* pBlockVector = m_pBlockVectors[memTypeIndex]; + VmaDedicatedAllocationList& dedicatedAllocList = m_DedicatedAllocations[memTypeIndex]; + if (pBlockVector != VMA_NULL) + { + json.BeginString("Type "); + json.ContinueString(memTypeIndex); + json.EndString(); + json.BeginObject(); { - ++pMetadata->m_FreeCount; - const VkDeviceSize freeSize = it->offset - offset; - VmaSuballocation suballoc = { - offset, // offset - freeSize, // size - VMA_NULL, // hAllocation - VMA_SUBALLOCATION_TYPE_FREE }; - VmaSuballocationList::iterator precedingFreeIt = pMetadata->m_Suballocations.insert(it, suballoc); - if(freeSize >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) - { - pMetadata->m_FreeSuballocationsBySize.push_back(precedingFreeIt); - } - } + json.WriteString("PreferredBlockSize"); + json.WriteNumber(pBlockVector->GetPreferredBlockSize()); - pMetadata->m_SumFreeSize -= it->size; - offset = it->offset + it->size; + json.WriteString("Blocks"); + pBlockVector->PrintDetailedMap(json); + + json.WriteString("DedicatedAllocations"); + dedicatedAllocList.BuildStatsString(json); + } + json.EndObject(); } + } + } + json.EndObject(); - // Need to insert trailing free space. - if(offset < blockSize) + json.WriteString("CustomPools"); + json.BeginObject(); + { + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + if (!m_Pools.IsEmpty()) + { + for (uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) { - ++pMetadata->m_FreeCount; - const VkDeviceSize freeSize = blockSize - offset; - VmaSuballocation suballoc = { - offset, // offset - freeSize, // size - VMA_NULL, // hAllocation - VMA_SUBALLOCATION_TYPE_FREE }; - VMA_ASSERT(it == pMetadata->m_Suballocations.end()); - VmaSuballocationList::iterator trailingFreeIt = pMetadata->m_Suballocations.insert(it, suballoc); - if(freeSize > VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + bool displayType = true; + size_t index = 0; + for (VmaPool pool = m_Pools.Front(); pool != VMA_NULL; pool = m_Pools.GetNext(pool)) { - pMetadata->m_FreeSuballocationsBySize.push_back(trailingFreeIt); + VmaBlockVector& blockVector = pool->m_BlockVector; + if (blockVector.GetMemoryTypeIndex() == memTypeIndex) + { + if (displayType) + { + json.BeginString("Type "); + json.ContinueString(memTypeIndex); + json.EndString(); + json.BeginArray(); + displayType = false; + } + + json.BeginObject(); + { + json.WriteString("Name"); + json.BeginString(); + json.ContinueString_Size(index++); + if (pool->GetName()) + { + json.ContinueString(" - "); + json.ContinueString(pool->GetName()); + } + json.EndString(); + + json.WriteString("PreferredBlockSize"); + json.WriteNumber(blockVector.GetPreferredBlockSize()); + + json.WriteString("Blocks"); + blockVector.PrintDetailedMap(json); + + json.WriteString("DedicatedAllocations"); + pool->m_DedicatedAllocations.BuildStatsString(json); + } + json.EndObject(); + } } - } - VMA_SORT( - pMetadata->m_FreeSuballocationsBySize.begin(), - pMetadata->m_FreeSuballocationsBySize.end(), - VmaSuballocationItemSizeLess()); + if (!displayType) + json.EndArray(); + } } - - VMA_HEAVY_ASSERT(pMetadata->Validate()); } + json.EndObject(); } +#endif // VMA_STATS_STRING_ENABLED +#endif // _VMA_ALLOCATOR_T_FUNCTIONS -void VmaDefragmentationAlgorithm_Fast::InsertSuballoc(VmaBlockMetadata_Generic* pMetadata, const VmaSuballocation& suballoc) + +#ifndef _VMA_PUBLIC_INTERFACE +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( + const VmaAllocatorCreateInfo* pCreateInfo, + VmaAllocator* pAllocator) { - // TODO: Optimize somehow. Remember iterator instead of searching for it linearly. - VmaSuballocationList::iterator it = pMetadata->m_Suballocations.begin(); - while(it != pMetadata->m_Suballocations.end()) + VMA_ASSERT(pCreateInfo && pAllocator); + VMA_ASSERT(pCreateInfo->vulkanApiVersion == 0 || + (VK_VERSION_MAJOR(pCreateInfo->vulkanApiVersion) == 1 && VK_VERSION_MINOR(pCreateInfo->vulkanApiVersion) <= 3)); + VMA_DEBUG_LOG("vmaCreateAllocator"); + *pAllocator = vma_new(pCreateInfo->pAllocationCallbacks, VmaAllocator_T)(pCreateInfo); + VkResult result = (*pAllocator)->Init(pCreateInfo); + if(result < 0) { - if(it->offset < suballoc.offset) - { - ++it; - } + vma_delete(pCreateInfo->pAllocationCallbacks, *pAllocator); + *pAllocator = VK_NULL_HANDLE; } - pMetadata->m_Suballocations.insert(it, suballoc); + return result; } -//////////////////////////////////////////////////////////////////////////////// -// VmaBlockVectorDefragmentationContext - -VmaBlockVectorDefragmentationContext::VmaBlockVectorDefragmentationContext( - VmaAllocator hAllocator, - VmaPool hCustomPool, - VmaBlockVector* pBlockVector, - uint32_t currFrameIndex) : - res(VK_SUCCESS), - mutexLocked(false), - blockContexts(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - defragmentationMoves(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - defragmentationMovesProcessed(0), - defragmentationMovesCommitted(0), - hasDefragmentationPlan(0), - m_hAllocator(hAllocator), - m_hCustomPool(hCustomPool), - m_pBlockVector(pBlockVector), - m_CurrFrameIndex(currFrameIndex), - m_pAlgorithm(VMA_NULL), - m_Allocations(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), - m_AllAllocations(false) +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( + VmaAllocator allocator) { + if(allocator != VK_NULL_HANDLE) + { + VMA_DEBUG_LOG("vmaDestroyAllocator"); + VkAllocationCallbacks allocationCallbacks = allocator->m_AllocationCallbacks; // Have to copy the callbacks when destroying. + vma_delete(&allocationCallbacks, allocator); + } } -VmaBlockVectorDefragmentationContext::~VmaBlockVectorDefragmentationContext() +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo(VmaAllocator allocator, VmaAllocatorInfo* pAllocatorInfo) { - vma_delete(m_hAllocator, m_pAlgorithm); + VMA_ASSERT(allocator && pAllocatorInfo); + pAllocatorInfo->instance = allocator->m_hInstance; + pAllocatorInfo->physicalDevice = allocator->GetPhysicalDevice(); + pAllocatorInfo->device = allocator->m_hDevice; } -void VmaBlockVectorDefragmentationContext::AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) +VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( + VmaAllocator allocator, + const VkPhysicalDeviceProperties **ppPhysicalDeviceProperties) { - AllocInfo info = { hAlloc, pChanged }; - m_Allocations.push_back(info); + VMA_ASSERT(allocator && ppPhysicalDeviceProperties); + *ppPhysicalDeviceProperties = &allocator->m_PhysicalDeviceProperties; } -void VmaBlockVectorDefragmentationContext::Begin(bool overlappingMoveSupported, VmaDefragmentationFlags flags) +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( + VmaAllocator allocator, + const VkPhysicalDeviceMemoryProperties** ppPhysicalDeviceMemoryProperties) { - const bool allAllocations = m_AllAllocations || - m_Allocations.size() == m_pBlockVector->CalcAllocationCount(); - - /******************************** - HERE IS THE CHOICE OF DEFRAGMENTATION ALGORITHM. - ********************************/ + VMA_ASSERT(allocator && ppPhysicalDeviceMemoryProperties); + *ppPhysicalDeviceMemoryProperties = &allocator->m_MemProps; +} - /* - Fast algorithm is supported only when certain criteria are met: - - VMA_DEBUG_MARGIN is 0. - - All allocations in this block vector are moveable. - - There is no possibility of image/buffer granularity conflict. - - The defragmentation is not incremental - */ - if(VMA_DEBUG_MARGIN == 0 && - allAllocations && - !m_pBlockVector->IsBufferImageGranularityConflictPossible() && - !(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL)) - { - m_pAlgorithm = vma_new(m_hAllocator, VmaDefragmentationAlgorithm_Fast)( - m_hAllocator, m_pBlockVector, m_CurrFrameIndex, overlappingMoveSupported); - } - else - { - m_pAlgorithm = vma_new(m_hAllocator, VmaDefragmentationAlgorithm_Generic)( - m_hAllocator, m_pBlockVector, m_CurrFrameIndex, overlappingMoveSupported); - } - - if(allAllocations) - { - m_pAlgorithm->AddAll(); - } - else - { - for(size_t i = 0, count = m_Allocations.size(); i < count; ++i) - { - m_pAlgorithm->AddAllocation(m_Allocations[i].hAlloc, m_Allocations[i].pChanged); - } - } +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( + VmaAllocator allocator, + uint32_t memoryTypeIndex, + VkMemoryPropertyFlags* pFlags) +{ + VMA_ASSERT(allocator && pFlags); + VMA_ASSERT(memoryTypeIndex < allocator->GetMemoryTypeCount()); + *pFlags = allocator->m_MemProps.memoryTypes[memoryTypeIndex].propertyFlags; } -//////////////////////////////////////////////////////////////////////////////// -// VmaDefragmentationContext +VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( + VmaAllocator allocator, + uint32_t frameIndex) +{ + VMA_ASSERT(allocator); -VmaDefragmentationContext_T::VmaDefragmentationContext_T( - VmaAllocator hAllocator, - uint32_t currFrameIndex, - uint32_t flags, - VmaDefragmentationStats* pStats) : - m_hAllocator(hAllocator), - m_CurrFrameIndex(currFrameIndex), - m_Flags(flags), - m_pStats(pStats), - m_CustomPoolContexts(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocator->SetCurrentFrameIndex(frameIndex); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStatistics( + VmaAllocator allocator, + VmaTotalStatistics* pStats) { - memset(m_DefaultPoolContexts, 0, sizeof(m_DefaultPoolContexts)); + VMA_ASSERT(allocator && pStats); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + allocator->CalculateStatistics(pStats); } -VmaDefragmentationContext_T::~VmaDefragmentationContext_T() +VMA_CALL_PRE void VMA_CALL_POST vmaGetHeapBudgets( + VmaAllocator allocator, + VmaBudget* pBudgets) { - for(size_t i = m_CustomPoolContexts.size(); i--; ) - { - VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_CustomPoolContexts[i]; - pBlockVectorCtx->GetBlockVector()->DefragmentationEnd(pBlockVectorCtx, m_Flags, m_pStats); - vma_delete(m_hAllocator, pBlockVectorCtx); - } - for(size_t i = m_hAllocator->m_MemProps.memoryTypeCount; i--; ) - { - VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_DefaultPoolContexts[i]; - if(pBlockVectorCtx) - { - pBlockVectorCtx->GetBlockVector()->DefragmentationEnd(pBlockVectorCtx, m_Flags, m_pStats); - vma_delete(m_hAllocator, pBlockVectorCtx); - } - } + VMA_ASSERT(allocator && pBudgets); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + allocator->GetHeapBudgets(pBudgets, 0, allocator->GetMemoryHeapCount()); } -void VmaDefragmentationContext_T::AddPools(uint32_t poolCount, const VmaPool* pPools) +#if VMA_STATS_STRING_ENABLED + +VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( + VmaAllocator allocator, + char** ppStatsString, + VkBool32 detailedMap) { - for (const auto poolIndex : c10::irange(poolCount)) { - VmaPool pool = pPools[poolIndex]; - VMA_ASSERT(pool); - // Pools with algorithm other than default are not defragmented. - if(pool->m_BlockVector.GetAlgorithm() == 0) - { - VmaBlockVectorDefragmentationContext* pBlockVectorDefragCtx = VMA_NULL; + VMA_ASSERT(allocator && ppStatsString); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - for(size_t i = m_CustomPoolContexts.size(); i--; ) - { - if(m_CustomPoolContexts[i]->GetCustomPool() == pool) - { - pBlockVectorDefragCtx = m_CustomPoolContexts[i]; - break; - } - } + VmaStringBuilder sb(allocator->GetAllocationCallbacks()); + { + VmaBudget budgets[VK_MAX_MEMORY_HEAPS]; + allocator->GetHeapBudgets(budgets, 0, allocator->GetMemoryHeapCount()); + + VmaTotalStatistics stats; + allocator->CalculateStatistics(&stats); - if(!pBlockVectorDefragCtx) + VmaJsonWriter json(allocator->GetAllocationCallbacks(), sb); + json.BeginObject(); + { + json.WriteString("General"); + json.BeginObject(); { - pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( - m_hAllocator, - pool, - &pool->m_BlockVector, - m_CurrFrameIndex); - m_CustomPoolContexts.push_back(pBlockVectorDefragCtx); - } + const VkPhysicalDeviceProperties& deviceProperties = allocator->m_PhysicalDeviceProperties; + const VkPhysicalDeviceMemoryProperties& memoryProperties = allocator->m_MemProps; - pBlockVectorDefragCtx->AddAll(); - } - } -} + json.WriteString("API"); + json.WriteString("Vulkan"); -void VmaDefragmentationContext_T::AddAllocations( - uint32_t allocationCount, - const VmaAllocation* pAllocations, - VkBool32* pAllocationsChanged) -{ - // Dispatch pAllocations among defragmentators. Create them when necessary. - for (const auto allocIndex : c10::irange(allocationCount)) { - const VmaAllocation hAlloc = pAllocations[allocIndex]; - VMA_ASSERT(hAlloc); - // DedicatedAlloc cannot be defragmented. - if((hAlloc->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK) && - // Lost allocation cannot be defragmented. - (hAlloc->GetLastUseFrameIndex() != VMA_FRAME_INDEX_LOST)) - { - VmaBlockVectorDefragmentationContext* pBlockVectorDefragCtx = VMA_NULL; - - const VmaPool hAllocPool = hAlloc->GetBlock()->GetParentPool(); - // This allocation belongs to custom pool. - if(hAllocPool != VK_NULL_HANDLE) + json.WriteString("apiVersion"); + json.BeginString(); + json.ContinueString(VK_API_VERSION_MAJOR(deviceProperties.apiVersion)); + json.ContinueString("."); + json.ContinueString(VK_API_VERSION_MINOR(deviceProperties.apiVersion)); + json.ContinueString("."); + json.ContinueString(VK_API_VERSION_PATCH(deviceProperties.apiVersion)); + json.EndString(); + + json.WriteString("GPU"); + json.WriteString(deviceProperties.deviceName); + json.WriteString("deviceType"); + json.WriteNumber(static_cast(deviceProperties.deviceType)); + + json.WriteString("maxMemoryAllocationCount"); + json.WriteNumber(deviceProperties.limits.maxMemoryAllocationCount); + json.WriteString("bufferImageGranularity"); + json.WriteNumber(deviceProperties.limits.bufferImageGranularity); + json.WriteString("nonCoherentAtomSize"); + json.WriteNumber(deviceProperties.limits.nonCoherentAtomSize); + + json.WriteString("memoryHeapCount"); + json.WriteNumber(memoryProperties.memoryHeapCount); + json.WriteString("memoryTypeCount"); + json.WriteNumber(memoryProperties.memoryTypeCount); + } + json.EndObject(); + } + { + json.WriteString("Total"); + VmaPrintDetailedStatistics(json, stats.total); + } + { + json.WriteString("MemoryInfo"); + json.BeginObject(); { - // Pools with algorithm other than default are not defragmented. - if(hAllocPool->m_BlockVector.GetAlgorithm() == 0) + for (uint32_t heapIndex = 0; heapIndex < allocator->GetMemoryHeapCount(); ++heapIndex) { - for(size_t i = m_CustomPoolContexts.size(); i--; ) + json.BeginString("Heap "); + json.ContinueString(heapIndex); + json.EndString(); + json.BeginObject(); { - if(m_CustomPoolContexts[i]->GetCustomPool() == hAllocPool) + const VkMemoryHeap& heapInfo = allocator->m_MemProps.memoryHeaps[heapIndex]; + json.WriteString("Flags"); + json.BeginArray(true); { - pBlockVectorDefragCtx = m_CustomPoolContexts[i]; - break; + if (heapInfo.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) + json.WriteString("DEVICE_LOCAL"); + #if VMA_VULKAN_VERSION >= 1001000 + if (heapInfo.flags & VK_MEMORY_HEAP_MULTI_INSTANCE_BIT) + json.WriteString("MULTI_INSTANCE"); + #endif + + VkMemoryHeapFlags flags = heapInfo.flags & + ~(VK_MEMORY_HEAP_DEVICE_LOCAL_BIT + #if VMA_VULKAN_VERSION >= 1001000 + | VK_MEMORY_HEAP_MULTI_INSTANCE_BIT + #endif + ); + if (flags != 0) + json.WriteNumber(flags); } - } - if(!pBlockVectorDefragCtx) - { - pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( - m_hAllocator, - hAllocPool, - &hAllocPool->m_BlockVector, - m_CurrFrameIndex); - m_CustomPoolContexts.push_back(pBlockVectorDefragCtx); - } - } - } - // This allocation belongs to default pool. - else - { - const uint32_t memTypeIndex = hAlloc->GetMemoryTypeIndex(); - pBlockVectorDefragCtx = m_DefaultPoolContexts[memTypeIndex]; - if(!pBlockVectorDefragCtx) - { - pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( - m_hAllocator, - VMA_NULL, // hCustomPool - m_hAllocator->m_pBlockVectors[memTypeIndex], - m_CurrFrameIndex); - m_DefaultPoolContexts[memTypeIndex] = pBlockVectorDefragCtx; - } - } + json.EndArray(); - if(pBlockVectorDefragCtx) - { - VkBool32* const pChanged = (pAllocationsChanged != VMA_NULL) ? - &pAllocationsChanged[allocIndex] : VMA_NULL; - pBlockVectorDefragCtx->AddAllocation(hAlloc, pChanged); - } - } - } -} + json.WriteString("Size"); + json.WriteNumber(heapInfo.size); -VkResult VmaDefragmentationContext_T::Defragment( - VkDeviceSize maxCpuBytesToMove, uint32_t maxCpuAllocationsToMove, - VkDeviceSize maxGpuBytesToMove, uint32_t maxGpuAllocationsToMove, - VkCommandBuffer commandBuffer, VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags) -{ - if(pStats) - { - memset(pStats, 0, sizeof(VmaDefragmentationStats)); - } + json.WriteString("Budget"); + json.BeginObject(); + { + json.WriteString("BudgetBytes"); + json.WriteNumber(budgets[heapIndex].budget); + json.WriteString("UsageBytes"); + json.WriteNumber(budgets[heapIndex].usage); + } + json.EndObject(); - if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) - { - // For incremental defragmetnations, we just earmark how much we can move - // The real meat is in the defragmentation steps - m_MaxCpuBytesToMove = maxCpuBytesToMove; - m_MaxCpuAllocationsToMove = maxCpuAllocationsToMove; + json.WriteString("Stats"); + VmaPrintDetailedStatistics(json, stats.memoryHeap[heapIndex]); - m_MaxGpuBytesToMove = maxGpuBytesToMove; - m_MaxGpuAllocationsToMove = maxGpuAllocationsToMove; + json.WriteString("MemoryPools"); + json.BeginObject(); + { + for (uint32_t typeIndex = 0; typeIndex < allocator->GetMemoryTypeCount(); ++typeIndex) + { + if (allocator->MemoryTypeIndexToHeapIndex(typeIndex) == heapIndex) + { + json.BeginString("Type "); + json.ContinueString(typeIndex); + json.EndString(); + json.BeginObject(); + { + json.WriteString("Flags"); + json.BeginArray(true); + { + VkMemoryPropertyFlags flags = allocator->m_MemProps.memoryTypes[typeIndex].propertyFlags; + if (flags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) + json.WriteString("DEVICE_LOCAL"); + if (flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) + json.WriteString("HOST_VISIBLE"); + if (flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT) + json.WriteString("HOST_COHERENT"); + if (flags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT) + json.WriteString("HOST_CACHED"); + if (flags & VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT) + json.WriteString("LAZILY_ALLOCATED"); + #if VMA_VULKAN_VERSION >= 1001000 + if (flags & VK_MEMORY_PROPERTY_PROTECTED_BIT) + json.WriteString("PROTECTED"); + #endif + #if VK_AMD_device_coherent_memory + if (flags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) + json.WriteString("DEVICE_COHERENT_AMD"); + if (flags & VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY) + json.WriteString("DEVICE_UNCACHED_AMD"); + #endif + + flags &= ~(VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT + #if VMA_VULKAN_VERSION >= 1001000 + | VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT + #endif + #if VK_AMD_device_coherent_memory + | VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY + | VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY + #endif + | VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT + | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT + | VK_MEMORY_PROPERTY_HOST_CACHED_BIT); + if (flags != 0) + json.WriteNumber(flags); + } + json.EndArray(); + + json.WriteString("Stats"); + VmaPrintDetailedStatistics(json, stats.memoryType[typeIndex]); + } + json.EndObject(); + } + } - if(m_MaxCpuBytesToMove == 0 && m_MaxCpuAllocationsToMove == 0 && - m_MaxGpuBytesToMove == 0 && m_MaxGpuAllocationsToMove == 0) - return VK_SUCCESS; + } + json.EndObject(); + } + json.EndObject(); + } + } + json.EndObject(); + } - return VK_NOT_READY; - } + if (detailedMap == VK_TRUE) + allocator->PrintDetailedMap(json); - if(commandBuffer == VK_NULL_HANDLE) - { - maxGpuBytesToMove = 0; - maxGpuAllocationsToMove = 0; + json.EndObject(); } - VkResult res = VK_SUCCESS; - - // Process default pools. - for(uint32_t memTypeIndex = 0; - memTypeIndex < m_hAllocator->GetMemoryTypeCount() && res >= VK_SUCCESS; - ++memTypeIndex) - { - VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; - if(pBlockVectorCtx) - { - VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); - pBlockVectorCtx->GetBlockVector()->Defragment( - pBlockVectorCtx, - pStats, flags, - maxCpuBytesToMove, maxCpuAllocationsToMove, - maxGpuBytesToMove, maxGpuAllocationsToMove, - commandBuffer); - if(pBlockVectorCtx->res != VK_SUCCESS) - { - res = pBlockVectorCtx->res; - } - } - } + *ppStatsString = VmaCreateStringCopy(allocator->GetAllocationCallbacks(), sb.GetData(), sb.GetLength()); +} - // Process custom pools. - for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); - customCtxIndex < customCtxCount && res >= VK_SUCCESS; - ++customCtxIndex) +VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( + VmaAllocator allocator, + char* pStatsString) +{ + if(pStatsString != VMA_NULL) { - VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; - VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); - pBlockVectorCtx->GetBlockVector()->Defragment( - pBlockVectorCtx, - pStats, flags, - maxCpuBytesToMove, maxCpuAllocationsToMove, - maxGpuBytesToMove, maxGpuAllocationsToMove, - commandBuffer); - if(pBlockVectorCtx->res != VK_SUCCESS) - { - res = pBlockVectorCtx->res; - } + VMA_ASSERT(allocator); + VmaFreeString(allocator->GetAllocationCallbacks(), pStatsString); } - - return res; } -VkResult VmaDefragmentationContext_T::DefragmentPassBegin(VmaDefragmentationPassInfo* pInfo) +#endif // VMA_STATS_STRING_ENABLED + +/* +This function is not protected by any mutex because it just reads immutable data. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( + VmaAllocator allocator, + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) { - VmaDefragmentationPassMoveInfo* pCurrentMove = pInfo->pMoves; - uint32_t movesLeft = pInfo->moveCount; + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); - // Process default pools. - for(uint32_t memTypeIndex = 0; - memTypeIndex < m_hAllocator->GetMemoryTypeCount(); - ++memTypeIndex) - { - VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; - if(pBlockVectorCtx) - { - VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); + return allocator->FindMemoryTypeIndex(memoryTypeBits, pAllocationCreateInfo, UINT32_MAX, pMemoryTypeIndex); +} - if(!pBlockVectorCtx->hasDefragmentationPlan) - { - pBlockVectorCtx->GetBlockVector()->Defragment( - pBlockVectorCtx, - m_pStats, m_Flags, - m_MaxCpuBytesToMove, m_MaxCpuAllocationsToMove, - m_MaxGpuBytesToMove, m_MaxGpuAllocationsToMove, - VK_NULL_HANDLE); +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( + VmaAllocator allocator, + const VkBufferCreateInfo* pBufferCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) +{ + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pBufferCreateInfo != VMA_NULL); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); - if(pBlockVectorCtx->res < VK_SUCCESS) - continue; + const VkDevice hDev = allocator->m_hDevice; + const VmaVulkanFunctions* funcs = &allocator->GetVulkanFunctions(); + VkResult res; - pBlockVectorCtx->hasDefragmentationPlan = true; - } +#if VMA_VULKAN_VERSION >= 1003000 + if(funcs->vkGetDeviceBufferMemoryRequirements) + { + // Can query straight from VkBufferCreateInfo :) + VkDeviceBufferMemoryRequirements devBufMemReq = {VK_STRUCTURE_TYPE_DEVICE_BUFFER_MEMORY_REQUIREMENTS}; + devBufMemReq.pCreateInfo = pBufferCreateInfo; - const uint32_t processed = pBlockVectorCtx->GetBlockVector()->ProcessDefragmentations( - pBlockVectorCtx, - pCurrentMove, movesLeft); + VkMemoryRequirements2 memReq = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2}; + (*funcs->vkGetDeviceBufferMemoryRequirements)(hDev, &devBufMemReq, &memReq); - movesLeft -= processed; - pCurrentMove += processed; - } + res = allocator->FindMemoryTypeIndex( + memReq.memoryRequirements.memoryTypeBits, pAllocationCreateInfo, pBufferCreateInfo->usage, pMemoryTypeIndex); } - - // Process custom pools. - for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); - customCtxIndex < customCtxCount; - ++customCtxIndex) + else +#endif // #if VMA_VULKAN_VERSION >= 1003000 { - VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; - VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); - - if(!pBlockVectorCtx->hasDefragmentationPlan) + // Must create a dummy buffer to query :( + VkBuffer hBuffer = VK_NULL_HANDLE; + res = funcs->vkCreateBuffer( + hDev, pBufferCreateInfo, allocator->GetAllocationCallbacks(), &hBuffer); + if(res == VK_SUCCESS) { - pBlockVectorCtx->GetBlockVector()->Defragment( - pBlockVectorCtx, - m_pStats, m_Flags, - m_MaxCpuBytesToMove, m_MaxCpuAllocationsToMove, - m_MaxGpuBytesToMove, m_MaxGpuAllocationsToMove, - VK_NULL_HANDLE); + VkMemoryRequirements memReq = {}; + funcs->vkGetBufferMemoryRequirements(hDev, hBuffer, &memReq); - if(pBlockVectorCtx->res < VK_SUCCESS) - continue; + res = allocator->FindMemoryTypeIndex( + memReq.memoryTypeBits, pAllocationCreateInfo, pBufferCreateInfo->usage, pMemoryTypeIndex); - pBlockVectorCtx->hasDefragmentationPlan = true; + funcs->vkDestroyBuffer( + hDev, hBuffer, allocator->GetAllocationCallbacks()); } - - const uint32_t processed = pBlockVectorCtx->GetBlockVector()->ProcessDefragmentations( - pBlockVectorCtx, - pCurrentMove, movesLeft); - - movesLeft -= processed; - pCurrentMove += processed; } - - pInfo->moveCount = pInfo->moveCount - movesLeft; - - return VK_SUCCESS; + return res; } -VkResult VmaDefragmentationContext_T::DefragmentPassEnd() + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( + VmaAllocator allocator, + const VkImageCreateInfo* pImageCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) { - VkResult res = VK_SUCCESS; + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pImageCreateInfo != VMA_NULL); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); - // Process default pools. - for(uint32_t memTypeIndex = 0; - memTypeIndex < m_hAllocator->GetMemoryTypeCount(); - ++memTypeIndex) - { - VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; - if(pBlockVectorCtx) - { - VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); + const VkDevice hDev = allocator->m_hDevice; + const VmaVulkanFunctions* funcs = &allocator->GetVulkanFunctions(); + VkResult res; - if(!pBlockVectorCtx->hasDefragmentationPlan) - { - res = VK_NOT_READY; - continue; - } +#if VMA_VULKAN_VERSION >= 1003000 + if(funcs->vkGetDeviceImageMemoryRequirements) + { + // Can query straight from VkImageCreateInfo :) + VkDeviceImageMemoryRequirements devImgMemReq = {VK_STRUCTURE_TYPE_DEVICE_IMAGE_MEMORY_REQUIREMENTS}; + devImgMemReq.pCreateInfo = pImageCreateInfo; + VMA_ASSERT(pImageCreateInfo->tiling != VK_IMAGE_TILING_DRM_FORMAT_MODIFIER_EXT_COPY && (pImageCreateInfo->flags & VK_IMAGE_CREATE_DISJOINT_BIT_COPY) == 0 && + "Cannot use this VkImageCreateInfo with vmaFindMemoryTypeIndexForImageInfo as I don't know what to pass as VkDeviceImageMemoryRequirements::planeAspect."); - pBlockVectorCtx->GetBlockVector()->CommitDefragmentations( - pBlockVectorCtx, m_pStats); + VkMemoryRequirements2 memReq = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2}; + (*funcs->vkGetDeviceImageMemoryRequirements)(hDev, &devImgMemReq, &memReq); - if(pBlockVectorCtx->defragmentationMoves.size() != pBlockVectorCtx->defragmentationMovesCommitted) - res = VK_NOT_READY; - } + res = allocator->FindMemoryTypeIndex( + memReq.memoryRequirements.memoryTypeBits, pAllocationCreateInfo, pImageCreateInfo->usage, pMemoryTypeIndex); } - - // Process custom pools. - for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); - customCtxIndex < customCtxCount; - ++customCtxIndex) + else +#endif // #if VMA_VULKAN_VERSION >= 1003000 { - VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; - VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); - - if(!pBlockVectorCtx->hasDefragmentationPlan) + // Must create a dummy image to query :( + VkImage hImage = VK_NULL_HANDLE; + res = funcs->vkCreateImage( + hDev, pImageCreateInfo, allocator->GetAllocationCallbacks(), &hImage); + if(res == VK_SUCCESS) { - res = VK_NOT_READY; - continue; - } + VkMemoryRequirements memReq = {}; + funcs->vkGetImageMemoryRequirements(hDev, hImage, &memReq); - pBlockVectorCtx->GetBlockVector()->CommitDefragmentations( - pBlockVectorCtx, m_pStats); + res = allocator->FindMemoryTypeIndex( + memReq.memoryTypeBits, pAllocationCreateInfo, pImageCreateInfo->usage, pMemoryTypeIndex); - if(pBlockVectorCtx->defragmentationMoves.size() != pBlockVectorCtx->defragmentationMovesCommitted) - res = VK_NOT_READY; + funcs->vkDestroyImage( + hDev, hImage, allocator->GetAllocationCallbacks()); + } } - return res; } -//////////////////////////////////////////////////////////////////////////////// -// VmaRecorder +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( + VmaAllocator allocator, + const VmaPoolCreateInfo* pCreateInfo, + VmaPool* pPool) +{ + VMA_ASSERT(allocator && pCreateInfo && pPool); + + VMA_DEBUG_LOG("vmaCreatePool"); -#if VMA_RECORDING_ENABLED + VMA_DEBUG_GLOBAL_MUTEX_LOCK -VmaRecorder::VmaRecorder() : - m_UseMutex(true), - m_Flags(0), - m_File(VMA_NULL), - m_RecordingStartTime(std::chrono::high_resolution_clock::now()) -{ + return allocator->CreatePool(pCreateInfo, pPool); } -VkResult VmaRecorder::Init(const VmaRecordSettings& settings, bool useMutex) +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( + VmaAllocator allocator, + VmaPool pool) { - m_UseMutex = useMutex; - m_Flags = settings.flags; - -#if defined(_WIN32) - // Open file for writing. - errno_t err = fopen_s(&m_File, settings.pFilePath, "wb"); + VMA_ASSERT(allocator); - if(err != 0) + if(pool == VK_NULL_HANDLE) { - return VK_ERROR_INITIALIZATION_FAILED; + return; } -#else - // Open file for writing. - m_File = fopen(settings.pFilePath, "wb"); - if(m_File == 0) - { - return VK_ERROR_INITIALIZATION_FAILED; - } -#endif + VMA_DEBUG_LOG("vmaDestroyPool"); - // Write header. - fprintf(m_File, "%s\n", "Vulkan Memory Allocator,Calls recording"); - fprintf(m_File, "%s\n", "1,8"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - return VK_SUCCESS; + allocator->DestroyPool(pool); } -VmaRecorder::~VmaRecorder() +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStatistics( + VmaAllocator allocator, + VmaPool pool, + VmaStatistics* pPoolStats) { - if(m_File != VMA_NULL) - { - fclose(m_File); - } -} + VMA_ASSERT(allocator && pool && pPoolStats); -void VmaRecorder::RecordCreateAllocator(uint32_t frameIndex) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaCreateAllocator\n", callParams.threadId, callParams.time, frameIndex); - Flush(); + allocator->GetPoolStatistics(pool, pPoolStats); } -void VmaRecorder::RecordDestroyAllocator(uint32_t frameIndex) +VMA_CALL_PRE void VMA_CALL_POST vmaCalculatePoolStatistics( + VmaAllocator allocator, + VmaPool pool, + VmaDetailedStatistics* pPoolStats) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && pool && pPoolStats); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDestroyAllocator\n", callParams.threadId, callParams.time, frameIndex); - Flush(); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocator->CalculatePoolStatistics(pool, pPoolStats); } -void VmaRecorder::RecordCreatePool(uint32_t frameIndex, const VmaPoolCreateInfo& createInfo, VmaPool pool) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption(VmaAllocator allocator, VmaPool pool) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && pool); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaCreatePool,%u,%u,%llu,%llu,%llu,%u,%p\n", callParams.threadId, callParams.time, frameIndex, - createInfo.memoryTypeIndex, - createInfo.flags, - createInfo.blockSize, - (uint64_t)createInfo.minBlockCount, - (uint64_t)createInfo.maxBlockCount, - createInfo.frameInUseCount, - pool); - Flush(); -} + VMA_DEBUG_GLOBAL_MUTEX_LOCK -void VmaRecorder::RecordDestroyPool(uint32_t frameIndex, VmaPool pool) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_LOG("vmaCheckPoolCorruption"); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDestroyPool,%p\n", callParams.threadId, callParams.time, frameIndex, - pool); - Flush(); + return allocator->CheckPoolCorruption(pool); } -void VmaRecorder::RecordAllocateMemory(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(createInfo.flags, createInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemory,%llu,%llu,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - vkMemReq.size, - vkMemReq.alignment, - vkMemReq.memoryTypeBits, - createInfo.flags, - createInfo.usage, - createInfo.requiredFlags, - createInfo.preferredFlags, - createInfo.memoryTypeBits, - createInfo.pool, - allocation, - userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordAllocateMemoryPages(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - const VmaAllocationCreateInfo& createInfo, - uint64_t allocationCount, - const VmaAllocation* pAllocations) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(createInfo.flags, createInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryPages,%llu,%llu,%u,%u,%u,%u,%u,%u,%p,", callParams.threadId, callParams.time, frameIndex, - vkMemReq.size, - vkMemReq.alignment, - vkMemReq.memoryTypeBits, - createInfo.flags, - createInfo.usage, - createInfo.requiredFlags, - createInfo.preferredFlags, - createInfo.memoryTypeBits, - createInfo.pool); - PrintPointerList(allocationCount, pAllocations); - fprintf(m_File, ",%s\n", userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordAllocateMemoryForBuffer(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(createInfo.flags, createInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryForBuffer,%llu,%llu,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - vkMemReq.size, - vkMemReq.alignment, - vkMemReq.memoryTypeBits, - requiresDedicatedAllocation ? 1 : 0, - prefersDedicatedAllocation ? 1 : 0, - createInfo.flags, - createInfo.usage, - createInfo.requiredFlags, - createInfo.preferredFlags, - createInfo.memoryTypeBits, - createInfo.pool, - allocation, - userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordAllocateMemoryForImage(uint32_t frameIndex, - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - const VmaAllocationCreateInfo& createInfo, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(createInfo.flags, createInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryForImage,%llu,%llu,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - vkMemReq.size, - vkMemReq.alignment, - vkMemReq.memoryTypeBits, - requiresDedicatedAllocation ? 1 : 0, - prefersDedicatedAllocation ? 1 : 0, - createInfo.flags, - createInfo.usage, - createInfo.requiredFlags, - createInfo.preferredFlags, - createInfo.memoryTypeBits, - createInfo.pool, - allocation, - userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordFreeMemory(uint32_t frameIndex, - VmaAllocation allocation) +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( + VmaAllocator allocator, + VmaPool pool, + const char** ppName) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && pool && ppName); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaFreeMemory,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VMA_DEBUG_LOG("vmaGetPoolName"); -void VmaRecorder::RecordFreeMemoryPages(uint32_t frameIndex, - uint64_t allocationCount, - const VmaAllocation* pAllocations) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaFreeMemoryPages,", callParams.threadId, callParams.time, frameIndex); - PrintPointerList(allocationCount, pAllocations); - fprintf(m_File, "\n"); - Flush(); + *ppName = pool->GetName(); } -void VmaRecorder::RecordSetAllocationUserData(uint32_t frameIndex, - VmaAllocation allocation, - const void* pUserData) +VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( + VmaAllocator allocator, + VmaPool pool, + const char* pName) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && pool); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr( - allocation->IsUserDataString() ? VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT : 0, - pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaSetAllocationUserData,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - allocation, - userDataStr.GetString()); - Flush(); -} + VMA_DEBUG_LOG("vmaSetPoolName"); -void VmaRecorder::RecordCreateLostAllocation(uint32_t frameIndex, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaCreateLostAllocation,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); + pool->SetName(pName); } -void VmaRecorder::RecordMapMemory(uint32_t frameIndex, - VmaAllocation allocation) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( + VmaAllocator allocator, + const VkMemoryRequirements* pVkMemoryRequirements, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocation); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaMapMemory,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VMA_DEBUG_LOG("vmaAllocateMemory"); -void VmaRecorder::RecordUnmapMemory(uint32_t frameIndex, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaUnmapMemory,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VkResult result = allocator->AllocateMemory( + *pVkMemoryRequirements, + false, // requiresDedicatedAllocation + false, // prefersDedicatedAllocation + VK_NULL_HANDLE, // dedicatedBuffer + VK_NULL_HANDLE, // dedicatedImage + UINT32_MAX, // dedicatedBufferImageUsage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_UNKNOWN, + 1, // allocationCount + pAllocation); -void VmaRecorder::RecordFlushAllocation(uint32_t frameIndex, - VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) -{ - CallParams callParams; - GetBasicParams(callParams); + if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaFlushAllocation,%p,%llu,%llu\n", callParams.threadId, callParams.time, frameIndex, - allocation, - offset, - size); - Flush(); + return result; } -void VmaRecorder::RecordInvalidateAllocation(uint32_t frameIndex, - VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( + VmaAllocator allocator, + const VkMemoryRequirements* pVkMemoryRequirements, + const VmaAllocationCreateInfo* pCreateInfo, + size_t allocationCount, + VmaAllocation* pAllocations, + VmaAllocationInfo* pAllocationInfo) { - CallParams callParams; - GetBasicParams(callParams); + if(allocationCount == 0) + { + return VK_SUCCESS; + } - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaInvalidateAllocation,%p,%llu,%llu\n", callParams.threadId, callParams.time, frameIndex, - allocation, - offset, - size); - Flush(); -} + VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocations); -void VmaRecorder::RecordCreateBuffer(uint32_t frameIndex, - const VkBufferCreateInfo& bufCreateInfo, - const VmaAllocationCreateInfo& allocCreateInfo, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(allocCreateInfo.flags, allocCreateInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaCreateBuffer,%u,%llu,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - bufCreateInfo.flags, - bufCreateInfo.size, - bufCreateInfo.usage, - bufCreateInfo.sharingMode, - allocCreateInfo.flags, - allocCreateInfo.usage, - allocCreateInfo.requiredFlags, - allocCreateInfo.preferredFlags, - allocCreateInfo.memoryTypeBits, - allocCreateInfo.pool, - allocation, - userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordCreateImage(uint32_t frameIndex, - const VkImageCreateInfo& imageCreateInfo, - const VmaAllocationCreateInfo& allocCreateInfo, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); - - VmaMutexLock lock(m_FileMutex, m_UseMutex); - UserDataString userDataStr(allocCreateInfo.flags, allocCreateInfo.pUserData); - fprintf(m_File, "%u,%.3f,%u,vmaCreateImage,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - imageCreateInfo.flags, - imageCreateInfo.imageType, - imageCreateInfo.format, - imageCreateInfo.extent.width, - imageCreateInfo.extent.height, - imageCreateInfo.extent.depth, - imageCreateInfo.mipLevels, - imageCreateInfo.arrayLayers, - imageCreateInfo.samples, - imageCreateInfo.tiling, - imageCreateInfo.usage, - imageCreateInfo.sharingMode, - imageCreateInfo.initialLayout, - allocCreateInfo.flags, - allocCreateInfo.usage, - allocCreateInfo.requiredFlags, - allocCreateInfo.preferredFlags, - allocCreateInfo.memoryTypeBits, - allocCreateInfo.pool, - allocation, - userDataStr.GetString()); - Flush(); -} - -void VmaRecorder::RecordDestroyBuffer(uint32_t frameIndex, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_LOG("vmaAllocateMemoryPages"); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDestroyBuffer,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VMA_DEBUG_GLOBAL_MUTEX_LOCK -void VmaRecorder::RecordDestroyImage(uint32_t frameIndex, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); + VkResult result = allocator->AllocateMemory( + *pVkMemoryRequirements, + false, // requiresDedicatedAllocation + false, // prefersDedicatedAllocation + VK_NULL_HANDLE, // dedicatedBuffer + VK_NULL_HANDLE, // dedicatedImage + UINT32_MAX, // dedicatedBufferImageUsage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_UNKNOWN, + allocationCount, + pAllocations); + + if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) + { + for(size_t i = 0; i < allocationCount; ++i) + { + allocator->GetAllocationInfo(pAllocations[i], pAllocationInfo + i); + } + } - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDestroyImage,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); + return result; } -void VmaRecorder::RecordTouchAllocation(uint32_t frameIndex, - VmaAllocation allocation) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( + VmaAllocator allocator, + VkBuffer buffer, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && buffer != VK_NULL_HANDLE && pCreateInfo && pAllocation); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaTouchAllocation,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VMA_DEBUG_LOG("vmaAllocateMemoryForBuffer"); -void VmaRecorder::RecordGetAllocationInfo(uint32_t frameIndex, - VmaAllocation allocation) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaGetAllocationInfo,%p\n", callParams.threadId, callParams.time, frameIndex, - allocation); - Flush(); -} + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetBufferMemoryRequirements(buffer, vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation); -void VmaRecorder::RecordMakePoolAllocationsLost(uint32_t frameIndex, - VmaPool pool) -{ - CallParams callParams; - GetBasicParams(callParams); + VkResult result = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + buffer, // dedicatedBuffer + VK_NULL_HANDLE, // dedicatedImage + UINT32_MAX, // dedicatedBufferImageUsage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_BUFFER, + 1, // allocationCount + pAllocation); + + if(pAllocationInfo && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaMakePoolAllocationsLost,%p\n", callParams.threadId, callParams.time, frameIndex, - pool); - Flush(); + return result; } -void VmaRecorder::RecordDefragmentationBegin(uint32_t frameIndex, - const VmaDefragmentationInfo2& info, - VmaDefragmentationContext ctx) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( + VmaAllocator allocator, + VkImage image, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { - CallParams callParams; - GetBasicParams(callParams); + VMA_ASSERT(allocator && image != VK_NULL_HANDLE && pCreateInfo && pAllocation); - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDefragmentationBegin,%u,", callParams.threadId, callParams.time, frameIndex, - info.flags); - PrintPointerList(info.allocationCount, info.pAllocations); - fprintf(m_File, ","); - PrintPointerList(info.poolCount, info.pPools); - fprintf(m_File, ",%llu,%u,%llu,%u,%p,%p\n", - info.maxCpuBytesToMove, - info.maxCpuAllocationsToMove, - info.maxGpuBytesToMove, - info.maxGpuAllocationsToMove, - info.commandBuffer, - ctx); - Flush(); -} + VMA_DEBUG_LOG("vmaAllocateMemoryForImage"); -void VmaRecorder::RecordDefragmentationEnd(uint32_t frameIndex, - VmaDefragmentationContext ctx) -{ - CallParams callParams; - GetBasicParams(callParams); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaDefragmentationEnd,%p\n", callParams.threadId, callParams.time, frameIndex, - ctx); - Flush(); -} + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetImageMemoryRequirements(image, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); -void VmaRecorder::RecordSetPoolName(uint32_t frameIndex, - VmaPool pool, - const char* name) -{ - CallParams callParams; - GetBasicParams(callParams); + VkResult result = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + VK_NULL_HANDLE, // dedicatedBuffer + image, // dedicatedImage + UINT32_MAX, // dedicatedBufferImageUsage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN, + 1, // allocationCount + pAllocation); + + if(pAllocationInfo && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } - VmaMutexLock lock(m_FileMutex, m_UseMutex); - fprintf(m_File, "%u,%.3f,%u,vmaSetPoolName,%p,%s\n", callParams.threadId, callParams.time, frameIndex, - pool, name != VMA_NULL ? name : ""); - Flush(); + return result; } -VmaRecorder::UserDataString::UserDataString(VmaAllocationCreateFlags allocFlags, const void* pUserData) +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( + VmaAllocator allocator, + VmaAllocation allocation) { - if(pUserData != VMA_NULL) - { - if((allocFlags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0) - { - m_Str = (const char*)pUserData; - } - else - { - // If VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT is not specified, convert the string's memory address to a string and store it. - snprintf(m_PtrStr, 17, "%p", pUserData); - m_Str = m_PtrStr; - } - } - else + VMA_ASSERT(allocator); + + if(allocation == VK_NULL_HANDLE) { - m_Str = ""; + return; } + + VMA_DEBUG_LOG("vmaFreeMemory"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocator->FreeMemory( + 1, // allocationCount + &allocation); } -void VmaRecorder::WriteConfiguration( - const VkPhysicalDeviceProperties& devProps, - const VkPhysicalDeviceMemoryProperties& memProps, - uint32_t vulkanApiVersion, - bool dedicatedAllocationExtensionEnabled, - bool bindMemory2ExtensionEnabled, - bool memoryBudgetExtensionEnabled, - bool deviceCoherentMemoryExtensionEnabled) +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( + VmaAllocator allocator, + size_t allocationCount, + const VmaAllocation* pAllocations) { - fprintf(m_File, "Config,Begin\n"); + if(allocationCount == 0) + { + return; + } - fprintf(m_File, "VulkanApiVersion,%u,%u\n", VK_VERSION_MAJOR(vulkanApiVersion), VK_VERSION_MINOR(vulkanApiVersion)); + VMA_ASSERT(allocator); - fprintf(m_File, "PhysicalDevice,apiVersion,%u\n", devProps.apiVersion); - fprintf(m_File, "PhysicalDevice,driverVersion,%u\n", devProps.driverVersion); - fprintf(m_File, "PhysicalDevice,vendorID,%u\n", devProps.vendorID); - fprintf(m_File, "PhysicalDevice,deviceID,%u\n", devProps.deviceID); - fprintf(m_File, "PhysicalDevice,deviceType,%u\n", devProps.deviceType); - fprintf(m_File, "PhysicalDevice,deviceName,%s\n", devProps.deviceName); + VMA_DEBUG_LOG("vmaFreeMemoryPages"); - fprintf(m_File, "PhysicalDeviceLimits,maxMemoryAllocationCount,%u\n", devProps.limits.maxMemoryAllocationCount); - fprintf(m_File, "PhysicalDeviceLimits,bufferImageGranularity,%llu\n", devProps.limits.bufferImageGranularity); - fprintf(m_File, "PhysicalDeviceLimits,nonCoherentAtomSize,%llu\n", devProps.limits.nonCoherentAtomSize); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - fprintf(m_File, "PhysicalDeviceMemory,HeapCount,%u\n", memProps.memoryHeapCount); - for (const auto i : c10::irange(memProps.memoryHeapCount)) { - fprintf(m_File, "PhysicalDeviceMemory,Heap,%u,size,%llu\n", i, memProps.memoryHeaps[i].size); - fprintf(m_File, "PhysicalDeviceMemory,Heap,%u,flags,%u\n", i, memProps.memoryHeaps[i].flags); - } - fprintf(m_File, "PhysicalDeviceMemory,TypeCount,%u\n", memProps.memoryTypeCount); - for (const auto i : c10::irange(memProps.memoryTypeCount)) { - fprintf(m_File, "PhysicalDeviceMemory,Type,%u,heapIndex,%u\n", i, memProps.memoryTypes[i].heapIndex); - fprintf(m_File, "PhysicalDeviceMemory,Type,%u,propertyFlags,%u\n", i, memProps.memoryTypes[i].propertyFlags); - } + allocator->FreeMemory(allocationCount, pAllocations); +} - fprintf(m_File, "Extension,VK_KHR_dedicated_allocation,%u\n", dedicatedAllocationExtensionEnabled ? 1 : 0); - fprintf(m_File, "Extension,VK_KHR_bind_memory2,%u\n", bindMemory2ExtensionEnabled ? 1 : 0); - fprintf(m_File, "Extension,VK_EXT_memory_budget,%u\n", memoryBudgetExtensionEnabled ? 1 : 0); - fprintf(m_File, "Extension,VK_AMD_device_coherent_memory,%u\n", deviceCoherentMemoryExtensionEnabled ? 1 : 0); +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( + VmaAllocator allocator, + VmaAllocation allocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && allocation && pAllocationInfo); - fprintf(m_File, "Macro,VMA_DEBUG_ALWAYS_DEDICATED_MEMORY,%u\n", VMA_DEBUG_ALWAYS_DEDICATED_MEMORY ? 1 : 0); - fprintf(m_File, "Macro,VMA_DEBUG_ALIGNMENT,%llu\n", (VkDeviceSize)VMA_DEBUG_ALIGNMENT); - fprintf(m_File, "Macro,VMA_DEBUG_MARGIN,%llu\n", (VkDeviceSize)VMA_DEBUG_MARGIN); - fprintf(m_File, "Macro,VMA_DEBUG_INITIALIZE_ALLOCATIONS,%u\n", VMA_DEBUG_INITIALIZE_ALLOCATIONS ? 1 : 0); - fprintf(m_File, "Macro,VMA_DEBUG_DETECT_CORRUPTION,%u\n", VMA_DEBUG_DETECT_CORRUPTION ? 1 : 0); - fprintf(m_File, "Macro,VMA_DEBUG_GLOBAL_MUTEX,%u\n", VMA_DEBUG_GLOBAL_MUTEX ? 1 : 0); - fprintf(m_File, "Macro,VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY,%llu\n", (VkDeviceSize)VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY); - fprintf(m_File, "Macro,VMA_SMALL_HEAP_MAX_SIZE,%llu\n", (VkDeviceSize)VMA_SMALL_HEAP_MAX_SIZE); - fprintf(m_File, "Macro,VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE,%llu\n", (VkDeviceSize)VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - fprintf(m_File, "Config,End\n"); + allocator->GetAllocationInfo(allocation, pAllocationInfo); } -void VmaRecorder::GetBasicParams(CallParams& outParams) +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( + VmaAllocator allocator, + VmaAllocation allocation, + void* pUserData) { - #if defined(_WIN32) - outParams.threadId = GetCurrentThreadId(); - #else - // Use C++11 features to get thread id and convert it to uint32_t. - // There is room for optimization since sstream is quite slow. - // Is there a better way to convert std::this_thread::get_id() to uint32_t? - std::thread::id thread_id = std::this_thread::get_id(); - std::stringstream thread_id_to_string_converter; - thread_id_to_string_converter << thread_id; - std::string thread_id_as_string = thread_id_to_string_converter.str(); - outParams.threadId = static_cast(std::stoi(thread_id_as_string.c_str())); - #endif + VMA_ASSERT(allocator && allocation); - auto current_time = std::chrono::high_resolution_clock::now(); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - outParams.time = std::chrono::duration(current_time - m_RecordingStartTime).count(); + allocation->SetUserData(allocator, pUserData); } -void VmaRecorder::PrintPointerList(uint64_t count, const VmaAllocation* pItems) +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationName( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const char* VMA_NULLABLE pName) { - if(count) - { - fprintf(m_File, "%p", pItems[0]); - for(uint64_t i = 1; i < count; ++i) - { - fprintf(m_File, " %p", pItems[i]); - } - } + allocation->SetName(allocator, pName); } -void VmaRecorder::Flush() +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationMemoryProperties( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkMemoryPropertyFlags* VMA_NOT_NULL pFlags) { - if((m_Flags & VMA_RECORD_FLUSH_AFTER_CALL_BIT) != 0) - { - fflush(m_File); - } + VMA_ASSERT(allocator && allocation && pFlags); + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + *pFlags = allocator->m_MemProps.memoryTypes[memTypeIndex].propertyFlags; } -#endif // #if VMA_RECORDING_ENABLED +VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( + VmaAllocator allocator, + VmaAllocation allocation, + void** ppData) +{ + VMA_ASSERT(allocator && allocation && ppData); -//////////////////////////////////////////////////////////////////////////////// -// VmaAllocationObjectAllocator + VMA_DEBUG_GLOBAL_MUTEX_LOCK -VmaAllocationObjectAllocator::VmaAllocationObjectAllocator(const VkAllocationCallbacks* pAllocationCallbacks) : - m_Allocator(pAllocationCallbacks, 1024) -{ + return allocator->Map(allocation, ppData); } -template VmaAllocation VmaAllocationObjectAllocator::Allocate(Types... args) +VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( + VmaAllocator allocator, + VmaAllocation allocation) { - VmaMutexLock mutexLock(m_Mutex); - return m_Allocator.Alloc(std::forward(args)...); -} + VMA_ASSERT(allocator && allocation); -void VmaAllocationObjectAllocator::Free(VmaAllocation hAlloc) -{ - VmaMutexLock mutexLock(m_Mutex); - m_Allocator.Free(hAlloc); -} + VMA_DEBUG_GLOBAL_MUTEX_LOCK -//////////////////////////////////////////////////////////////////////////////// -// VmaAllocator_T + allocator->Unmap(allocation); +} -VmaAllocator_T::VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo) : - m_UseMutex((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT) == 0), - m_VulkanApiVersion(pCreateInfo->vulkanApiVersion != 0 ? pCreateInfo->vulkanApiVersion : VK_API_VERSION_1_0), - m_UseKhrDedicatedAllocation((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0), - m_UseKhrBindMemory2((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0), - m_UseExtMemoryBudget((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0), - m_UseAmdDeviceCoherentMemory((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT) != 0), - m_UseKhrBufferDeviceAddress((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT) != 0), - m_UseExtMemoryPriority((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT) != 0), - m_hDevice(pCreateInfo->device), - m_hInstance(pCreateInfo->instance), - m_AllocationCallbacksSpecified(pCreateInfo->pAllocationCallbacks != VMA_NULL), - m_AllocationCallbacks(pCreateInfo->pAllocationCallbacks ? - *pCreateInfo->pAllocationCallbacks : VmaEmptyAllocationCallbacks), - m_AllocationObjectAllocator(&m_AllocationCallbacks), - m_HeapSizeLimitMask(0), - m_DeviceMemoryCount(0), - m_PreferredLargeHeapBlockSize(0), - m_PhysicalDevice(pCreateInfo->physicalDevice), - m_CurrentFrameIndex(0), - m_GpuDefragmentationMemoryTypeBits(UINT32_MAX), - m_Pools(VmaStlAllocator(GetAllocationCallbacks())), - m_NextPoolId(0), - m_GlobalMemoryTypeBits(UINT32_MAX) -#if VMA_RECORDING_ENABLED - ,m_pRecorder(VMA_NULL) -#endif +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize offset, + VkDeviceSize size) { - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - m_UseKhrDedicatedAllocation = false; - m_UseKhrBindMemory2 = false; - } - - if(VMA_DEBUG_DETECT_CORRUPTION) - { - // Needs to be multiply of uint32_t size because we are going to write VMA_CORRUPTION_DETECTION_MAGIC_VALUE to it. - VMA_ASSERT(VMA_DEBUG_MARGIN % sizeof(uint32_t) == 0); - } + VMA_ASSERT(allocator && allocation); - VMA_ASSERT(pCreateInfo->physicalDevice && pCreateInfo->device && pCreateInfo->instance); + VMA_DEBUG_LOG("vmaFlushAllocation"); - if(m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0)) - { -#if !(VMA_DEDICATED_ALLOCATION) - if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0) - { - VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT set but required extensions are disabled by preprocessor macros."); - } -#endif -#if !(VMA_BIND_MEMORY2) - if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0) - { - VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT set but required extension is disabled by preprocessor macros."); - } -#endif - } -#if !(VMA_MEMORY_BUDGET) - if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0) - { - VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT set but required extension is disabled by preprocessor macros."); - } -#endif -#if !(VMA_BUFFER_DEVICE_ADDRESS) - if(m_UseKhrBufferDeviceAddress) - { - VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT is set but required extension or Vulkan 1.2 is not available in your Vulkan header or its support in VMA has been disabled by a preprocessor macro."); - } -#endif -#if VMA_VULKAN_VERSION < 1002000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 2, 0)) - { - VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_2 but required Vulkan version is disabled by preprocessor macros."); - } -#endif -#if VMA_VULKAN_VERSION < 1001000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_1 but required Vulkan version is disabled by preprocessor macros."); - } -#endif -#if !(VMA_MEMORY_PRIORITY) - if(m_UseExtMemoryPriority) - { - VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT is set but required extension is not available in your Vulkan header or its support in VMA has been disabled by a preprocessor macro."); - } -#endif + VMA_DEBUG_GLOBAL_MUTEX_LOCK - memset(&m_DeviceMemoryCallbacks, 0 ,sizeof(m_DeviceMemoryCallbacks)); - memset(&m_PhysicalDeviceProperties, 0, sizeof(m_PhysicalDeviceProperties)); - memset(&m_MemProps, 0, sizeof(m_MemProps)); + const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_FLUSH); - memset(&m_pBlockVectors, 0, sizeof(m_pBlockVectors)); - memset(&m_pDedicatedAllocations, 0, sizeof(m_pDedicatedAllocations)); - memset(&m_VulkanFunctions, 0, sizeof(m_VulkanFunctions)); + return res; +} - if(pCreateInfo->pDeviceMemoryCallbacks != VMA_NULL) - { - m_DeviceMemoryCallbacks.pUserData = pCreateInfo->pDeviceMemoryCallbacks->pUserData; - m_DeviceMemoryCallbacks.pfnAllocate = pCreateInfo->pDeviceMemoryCallbacks->pfnAllocate; - m_DeviceMemoryCallbacks.pfnFree = pCreateInfo->pDeviceMemoryCallbacks->pfnFree; - } +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize offset, + VkDeviceSize size) +{ + VMA_ASSERT(allocator && allocation); - ImportVulkanFunctions(pCreateInfo->pVulkanFunctions); + VMA_DEBUG_LOG("vmaInvalidateAllocation"); - (*m_VulkanFunctions.vkGetPhysicalDeviceProperties)(m_PhysicalDevice, &m_PhysicalDeviceProperties); - (*m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties)(m_PhysicalDevice, &m_MemProps); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - VMA_ASSERT(VmaIsPow2(VMA_DEBUG_ALIGNMENT)); - VMA_ASSERT(VmaIsPow2(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY)); - VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.bufferImageGranularity)); - VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.nonCoherentAtomSize)); + const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_INVALIDATE); - m_PreferredLargeHeapBlockSize = (pCreateInfo->preferredLargeHeapBlockSize != 0) ? - pCreateInfo->preferredLargeHeapBlockSize : static_cast(VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE); + return res; +} - m_GlobalMemoryTypeBits = CalculateGlobalMemoryTypeBits(); +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( + VmaAllocator allocator, + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, + const VkDeviceSize* sizes) +{ + VMA_ASSERT(allocator); - if(pCreateInfo->pHeapSizeLimit != VMA_NULL) + if(allocationCount == 0) { - for (const auto heapIndex : c10::irange(GetMemoryHeapCount())) { - const VkDeviceSize limit = pCreateInfo->pHeapSizeLimit[heapIndex]; - if(limit != VK_WHOLE_SIZE) - { - m_HeapSizeLimitMask |= 1u << heapIndex; - if(limit < m_MemProps.memoryHeaps[heapIndex].size) - { - m_MemProps.memoryHeaps[heapIndex].size = limit; - } - } - } + return VK_SUCCESS; } - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(memTypeIndex); + VMA_ASSERT(allocations); - m_pBlockVectors[memTypeIndex] = vma_new(this, VmaBlockVector)( - this, - VK_NULL_HANDLE, // hParentPool - memTypeIndex, - preferredBlockSize, - 0, - SIZE_MAX, - GetBufferImageGranularity(), - pCreateInfo->frameInUseCount, - false, // explicitBlockSize - false, // linearAlgorithm - 0.5f); // priority (0.5 is the default per Vulkan spec) - // No need to call m_pBlockVectors[memTypeIndex][blockVectorTypeIndex]->CreateMinBlocks here, - // becase minBlockCount is 0. - m_pDedicatedAllocations[memTypeIndex] = vma_new(this, AllocationVectorType)(VmaStlAllocator(GetAllocationCallbacks())); + VMA_DEBUG_LOG("vmaFlushAllocations"); - } + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_FLUSH); + + return res; } -VkResult VmaAllocator_T::Init(const VmaAllocatorCreateInfo* pCreateInfo) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( + VmaAllocator allocator, + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, + const VkDeviceSize* sizes) { - VkResult res = VK_SUCCESS; + VMA_ASSERT(allocator); - if(pCreateInfo->pRecordSettings != VMA_NULL && - !VmaStrIsEmpty(pCreateInfo->pRecordSettings->pFilePath)) + if(allocationCount == 0) { -#if VMA_RECORDING_ENABLED - m_pRecorder = vma_new(this, VmaRecorder)(); - res = m_pRecorder->Init(*pCreateInfo->pRecordSettings, m_UseMutex); - if(res != VK_SUCCESS) - { - return res; - } - m_pRecorder->WriteConfiguration( - m_PhysicalDeviceProperties, - m_MemProps, - m_VulkanApiVersion, - m_UseKhrDedicatedAllocation, - m_UseKhrBindMemory2, - m_UseExtMemoryBudget, - m_UseAmdDeviceCoherentMemory); - m_pRecorder->RecordCreateAllocator(GetCurrentFrameIndex()); -#else - VMA_ASSERT(0 && "VmaAllocatorCreateInfo::pRecordSettings used, but not supported due to VMA_RECORDING_ENABLED not defined to 1."); - return VK_ERROR_FEATURE_NOT_PRESENT; -#endif + return VK_SUCCESS; } -#if VMA_MEMORY_BUDGET - if(m_UseExtMemoryBudget) - { - UpdateVulkanBudget(); - } -#endif // #if VMA_MEMORY_BUDGET + VMA_ASSERT(allocations); + + VMA_DEBUG_LOG("vmaInvalidateAllocations"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_INVALIDATE); return res; } -VmaAllocator_T::~VmaAllocator_T() +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption( + VmaAllocator allocator, + uint32_t memoryTypeBits) { -#if VMA_RECORDING_ENABLED - if(m_pRecorder != VMA_NULL) - { - m_pRecorder->RecordDestroyAllocator(GetCurrentFrameIndex()); - vma_delete(this, m_pRecorder); - } -#endif + VMA_ASSERT(allocator); - VMA_ASSERT(m_Pools.empty()); + VMA_DEBUG_LOG("vmaCheckCorruption"); - for(size_t i = GetMemoryTypeCount(); i--; ) - { - if(m_pDedicatedAllocations[i] != VMA_NULL && !m_pDedicatedAllocations[i]->empty()) - { - VMA_ASSERT(0 && "Unfreed dedicated allocations found."); - } + VMA_DEBUG_GLOBAL_MUTEX_LOCK - vma_delete(this, m_pDedicatedAllocations[i]); - vma_delete(this, m_pBlockVectors[i]); - } + return allocator->CheckCorruption(memoryTypeBits); } -void VmaAllocator_T::ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentation( + VmaAllocator allocator, + const VmaDefragmentationInfo* pInfo, + VmaDefragmentationContext* pContext) { -#if VMA_STATIC_VULKAN_FUNCTIONS == 1 - ImportVulkanFunctions_Static(); -#endif + VMA_ASSERT(allocator && pInfo && pContext); - if(pVulkanFunctions != VMA_NULL) + VMA_DEBUG_LOG("vmaBeginDefragmentation"); + + if (pInfo->pool != VMA_NULL) { - ImportVulkanFunctions_Custom(pVulkanFunctions); + // Check if run on supported algorithms + if (pInfo->pool->m_BlockVector.GetAlgorithm() & VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + return VK_ERROR_FEATURE_NOT_PRESENT; } -#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 - ImportVulkanFunctions_Dynamic(); -#endif + VMA_DEBUG_GLOBAL_MUTEX_LOCK - ValidateVulkanFunctions(); + *pContext = vma_new(allocator, VmaDefragmentationContext_T)(allocator, *pInfo); + return VK_SUCCESS; } -#if VMA_STATIC_VULKAN_FUNCTIONS == 1 - -void VmaAllocator_T::ImportVulkanFunctions_Static() +VMA_CALL_PRE void VMA_CALL_POST vmaEndDefragmentation( + VmaAllocator allocator, + VmaDefragmentationContext context, + VmaDefragmentationStats* pStats) { - // Vulkan 1.0 - m_VulkanFunctions.vkGetPhysicalDeviceProperties = (PFN_vkGetPhysicalDeviceProperties)vkGetPhysicalDeviceProperties; - m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties = (PFN_vkGetPhysicalDeviceMemoryProperties)vkGetPhysicalDeviceMemoryProperties; - m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; - m_VulkanFunctions.vkFreeMemory = (PFN_vkFreeMemory)vkFreeMemory; - m_VulkanFunctions.vkMapMemory = (PFN_vkMapMemory)vkMapMemory; - m_VulkanFunctions.vkUnmapMemory = (PFN_vkUnmapMemory)vkUnmapMemory; - m_VulkanFunctions.vkFlushMappedMemoryRanges = (PFN_vkFlushMappedMemoryRanges)vkFlushMappedMemoryRanges; - m_VulkanFunctions.vkInvalidateMappedMemoryRanges = (PFN_vkInvalidateMappedMemoryRanges)vkInvalidateMappedMemoryRanges; - m_VulkanFunctions.vkBindBufferMemory = (PFN_vkBindBufferMemory)vkBindBufferMemory; - m_VulkanFunctions.vkBindImageMemory = (PFN_vkBindImageMemory)vkBindImageMemory; - m_VulkanFunctions.vkGetBufferMemoryRequirements = (PFN_vkGetBufferMemoryRequirements)vkGetBufferMemoryRequirements; - m_VulkanFunctions.vkGetImageMemoryRequirements = (PFN_vkGetImageMemoryRequirements)vkGetImageMemoryRequirements; - m_VulkanFunctions.vkCreateBuffer = (PFN_vkCreateBuffer)vkCreateBuffer; - m_VulkanFunctions.vkDestroyBuffer = (PFN_vkDestroyBuffer)vkDestroyBuffer; - m_VulkanFunctions.vkCreateImage = (PFN_vkCreateImage)vkCreateImage; - m_VulkanFunctions.vkDestroyImage = (PFN_vkDestroyImage)vkDestroyImage; - m_VulkanFunctions.vkCmdCopyBuffer = (PFN_vkCmdCopyBuffer)vkCmdCopyBuffer; + VMA_ASSERT(allocator && context); - // Vulkan 1.1 -#if VMA_VULKAN_VERSION >= 1001000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2)vkGetBufferMemoryRequirements2; - m_VulkanFunctions.vkGetImageMemoryRequirements2KHR = (PFN_vkGetImageMemoryRequirements2)vkGetImageMemoryRequirements2; - m_VulkanFunctions.vkBindBufferMemory2KHR = (PFN_vkBindBufferMemory2)vkBindBufferMemory2; - m_VulkanFunctions.vkBindImageMemory2KHR = (PFN_vkBindImageMemory2)vkBindImageMemory2; - m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR = (PFN_vkGetPhysicalDeviceMemoryProperties2)vkGetPhysicalDeviceMemoryProperties2; - } -#endif -} + VMA_DEBUG_LOG("vmaEndDefragmentation"); -#endif // #if VMA_STATIC_VULKAN_FUNCTIONS == 1 + VMA_DEBUG_GLOBAL_MUTEX_LOCK -void VmaAllocator_T::ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions) + if (pStats) + context->GetStats(*pStats); + vma_delete(allocator, context); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NOT_NULL context, + VmaDefragmentationPassMoveInfo* VMA_NOT_NULL pPassInfo) { - VMA_ASSERT(pVulkanFunctions != VMA_NULL); + VMA_ASSERT(context && pPassInfo); -#define VMA_COPY_IF_NOT_NULL(funcName) \ - if(pVulkanFunctions->funcName != VMA_NULL) m_VulkanFunctions.funcName = pVulkanFunctions->funcName; + VMA_DEBUG_LOG("vmaBeginDefragmentationPass"); - VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceProperties); - VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties); - VMA_COPY_IF_NOT_NULL(vkAllocateMemory); - VMA_COPY_IF_NOT_NULL(vkFreeMemory); - VMA_COPY_IF_NOT_NULL(vkMapMemory); - VMA_COPY_IF_NOT_NULL(vkUnmapMemory); - VMA_COPY_IF_NOT_NULL(vkFlushMappedMemoryRanges); - VMA_COPY_IF_NOT_NULL(vkInvalidateMappedMemoryRanges); - VMA_COPY_IF_NOT_NULL(vkBindBufferMemory); - VMA_COPY_IF_NOT_NULL(vkBindImageMemory); - VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements); - VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements); - VMA_COPY_IF_NOT_NULL(vkCreateBuffer); - VMA_COPY_IF_NOT_NULL(vkDestroyBuffer); - VMA_COPY_IF_NOT_NULL(vkCreateImage); - VMA_COPY_IF_NOT_NULL(vkDestroyImage); - VMA_COPY_IF_NOT_NULL(vkCmdCopyBuffer); + VMA_DEBUG_GLOBAL_MUTEX_LOCK -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements2KHR); - VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements2KHR); -#endif + return context->DefragmentPassBegin(*pPassInfo); +} -#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 - VMA_COPY_IF_NOT_NULL(vkBindBufferMemory2KHR); - VMA_COPY_IF_NOT_NULL(vkBindImageMemory2KHR); -#endif +VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NOT_NULL context, + VmaDefragmentationPassMoveInfo* VMA_NOT_NULL pPassInfo) +{ + VMA_ASSERT(context && pPassInfo); -#if VMA_MEMORY_BUDGET - VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties2KHR); -#endif + VMA_DEBUG_LOG("vmaEndDefragmentationPass"); -#undef VMA_COPY_IF_NOT_NULL -} + VMA_DEBUG_GLOBAL_MUTEX_LOCK -#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + return context->DefragmentPassEnd(*pPassInfo); +} -void VmaAllocator_T::ImportVulkanFunctions_Dynamic() +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( + VmaAllocator allocator, + VmaAllocation allocation, + VkBuffer buffer) { -#define VMA_FETCH_INSTANCE_FUNC(memberName, functionPointerType, functionNameString) \ - if(m_VulkanFunctions.memberName == VMA_NULL) \ - m_VulkanFunctions.memberName = \ - (functionPointerType)vkGetInstanceProcAddr(m_hInstance, functionNameString); -#define VMA_FETCH_DEVICE_FUNC(memberName, functionPointerType, functionNameString) \ - if(m_VulkanFunctions.memberName == VMA_NULL) \ - m_VulkanFunctions.memberName = \ - (functionPointerType)vkGetDeviceProcAddr(m_hDevice, functionNameString); + VMA_ASSERT(allocator && allocation && buffer); - VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceProperties, PFN_vkGetPhysicalDeviceProperties, "vkGetPhysicalDeviceProperties"); - VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties, PFN_vkGetPhysicalDeviceMemoryProperties, "vkGetPhysicalDeviceMemoryProperties"); - VMA_FETCH_DEVICE_FUNC(vkAllocateMemory, PFN_vkAllocateMemory, "vkAllocateMemory"); - VMA_FETCH_DEVICE_FUNC(vkFreeMemory, PFN_vkFreeMemory, "vkFreeMemory"); - VMA_FETCH_DEVICE_FUNC(vkMapMemory, PFN_vkMapMemory, "vkMapMemory"); - VMA_FETCH_DEVICE_FUNC(vkUnmapMemory, PFN_vkUnmapMemory, "vkUnmapMemory"); - VMA_FETCH_DEVICE_FUNC(vkFlushMappedMemoryRanges, PFN_vkFlushMappedMemoryRanges, "vkFlushMappedMemoryRanges"); - VMA_FETCH_DEVICE_FUNC(vkInvalidateMappedMemoryRanges, PFN_vkInvalidateMappedMemoryRanges, "vkInvalidateMappedMemoryRanges"); - VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory, PFN_vkBindBufferMemory, "vkBindBufferMemory"); - VMA_FETCH_DEVICE_FUNC(vkBindImageMemory, PFN_vkBindImageMemory, "vkBindImageMemory"); - VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements, PFN_vkGetBufferMemoryRequirements, "vkGetBufferMemoryRequirements"); - VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements, PFN_vkGetImageMemoryRequirements, "vkGetImageMemoryRequirements"); - VMA_FETCH_DEVICE_FUNC(vkCreateBuffer, PFN_vkCreateBuffer, "vkCreateBuffer"); - VMA_FETCH_DEVICE_FUNC(vkDestroyBuffer, PFN_vkDestroyBuffer, "vkDestroyBuffer"); - VMA_FETCH_DEVICE_FUNC(vkCreateImage, PFN_vkCreateImage, "vkCreateImage"); - VMA_FETCH_DEVICE_FUNC(vkDestroyImage, PFN_vkDestroyImage, "vkDestroyImage"); - VMA_FETCH_DEVICE_FUNC(vkCmdCopyBuffer, PFN_vkCmdCopyBuffer, "vkCmdCopyBuffer"); + VMA_DEBUG_LOG("vmaBindBufferMemory"); -#if VMA_VULKAN_VERSION >= 1001000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2, "vkGetBufferMemoryRequirements2"); - VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2, "vkGetImageMemoryRequirements2"); - VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2, "vkBindBufferMemory2"); - VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2, "vkBindImageMemory2"); - VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2, "vkGetPhysicalDeviceMemoryProperties2"); - } -#endif + VMA_DEBUG_GLOBAL_MUTEX_LOCK -#if VMA_DEDICATED_ALLOCATION - if(m_UseKhrDedicatedAllocation) - { - VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2KHR, "vkGetBufferMemoryRequirements2KHR"); - VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2KHR, "vkGetImageMemoryRequirements2KHR"); - } -#endif + return allocator->BindBufferMemory(allocation, 0, buffer, VMA_NULL); +} -#if VMA_BIND_MEMORY2 - if(m_UseKhrBindMemory2) - { - VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2KHR, "vkBindBufferMemory2KHR"); - VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2KHR, "vkBindImageMemory2KHR"); - } -#endif // #if VMA_BIND_MEMORY2 +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize allocationLocalOffset, + VkBuffer buffer, + const void* pNext) +{ + VMA_ASSERT(allocator && allocation && buffer); -#if VMA_MEMORY_BUDGET - if(m_UseExtMemoryBudget) - { - VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2KHR, "vkGetPhysicalDeviceMemoryProperties2KHR"); - } -#endif // #if VMA_MEMORY_BUDGET + VMA_DEBUG_LOG("vmaBindBufferMemory2"); -#undef VMA_FETCH_DEVICE_FUNC -#undef VMA_FETCH_INSTANCE_FUNC -} + VMA_DEBUG_GLOBAL_MUTEX_LOCK -#endif // #if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + return allocator->BindBufferMemory(allocation, allocationLocalOffset, buffer, pNext); +} -void VmaAllocator_T::ValidateVulkanFunctions() +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( + VmaAllocator allocator, + VmaAllocation allocation, + VkImage image) { - VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceProperties != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkAllocateMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkFreeMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkMapMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkUnmapMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkFlushMappedMemoryRanges != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkInvalidateMappedMemoryRanges != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkCreateBuffer != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkDestroyBuffer != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkCreateImage != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkDestroyImage != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkCmdCopyBuffer != VMA_NULL); + VMA_ASSERT(allocator && allocation && image); -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrDedicatedAllocation) - { - VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements2KHR != VMA_NULL); - } -#endif + VMA_DEBUG_LOG("vmaBindImageMemory"); -#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 - if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrBindMemory2) - { - VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL); - VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL); - } -#endif + VMA_DEBUG_GLOBAL_MUTEX_LOCK -#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 - if(m_UseExtMemoryBudget || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR != VMA_NULL); - } -#endif + return allocator->BindImageMemory(allocation, 0, image, VMA_NULL); } -VkDeviceSize VmaAllocator_T::CalcPreferredBlockSize(uint32_t memTypeIndex) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize allocationLocalOffset, + VkImage image, + const void* pNext) { - const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); - const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; - const bool isSmallHeap = heapSize <= VMA_SMALL_HEAP_MAX_SIZE; - return VmaAlignUp(isSmallHeap ? (heapSize / 8) : m_PreferredLargeHeapBlockSize, (VkDeviceSize)32); + VMA_ASSERT(allocator && allocation && image); + + VMA_DEBUG_LOG("vmaBindImageMemory2"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->BindImageMemory(allocation, allocationLocalOffset, image, pNext); } -VkResult VmaAllocator_T::AllocateMemoryOfType( - VkDeviceSize size, - VkDeviceSize alignment, - bool dedicatedAllocation, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, - VkImage dedicatedImage, - const VmaAllocationCreateInfo& createInfo, - uint32_t memTypeIndex, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( + VmaAllocator allocator, + const VkBufferCreateInfo* pBufferCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkBuffer* pBuffer, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { - VMA_ASSERT(pAllocations != VMA_NULL); - VMA_DEBUG_LOG(" AllocateMemory: MemoryTypeIndex=%u, AllocationCount=%zu, Size=%llu", memTypeIndex, allocationCount, size); - - VmaAllocationCreateInfo finalCreateInfo = createInfo; + VMA_ASSERT(allocator && pBufferCreateInfo && pAllocationCreateInfo && pBuffer && pAllocation); - // If memory type is not HOST_VISIBLE, disable MAPPED. - if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && - (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + if(pBufferCreateInfo->size == 0) { - finalCreateInfo.flags &= ~VMA_ALLOCATION_CREATE_MAPPED_BIT; + return VK_ERROR_INITIALIZATION_FAILED; } - // If memory is lazily allocated, it should be always dedicated. - if(finalCreateInfo.usage == VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED) + if((pBufferCreateInfo->usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY) != 0 && + !allocator->m_UseKhrBufferDeviceAddress) { - finalCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; + VMA_ASSERT(0 && "Creating a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT is not valid if VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT was not used."); + return VK_ERROR_INITIALIZATION_FAILED; } - VmaBlockVector* const blockVector = m_pBlockVectors[memTypeIndex]; - VMA_ASSERT(blockVector); + VMA_DEBUG_LOG("vmaCreateBuffer"); - const VkDeviceSize preferredBlockSize = blockVector->GetPreferredBlockSize(); - bool preferDedicatedMemory = - VMA_DEBUG_ALWAYS_DEDICATED_MEMORY || - dedicatedAllocation || - // Heuristics: Allocate dedicated memory if requested size if greater than half of preferred block size. - size > preferredBlockSize / 2; + VMA_DEBUG_GLOBAL_MUTEX_LOCK - if(preferDedicatedMemory && - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0 && - finalCreateInfo.pool == VK_NULL_HANDLE) - { - finalCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; - } + *pBuffer = VK_NULL_HANDLE; + *pAllocation = VK_NULL_HANDLE; - if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0) - { - if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) - { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - else - { - return AllocateDedicatedMemory( - size, - suballocType, - memTypeIndex, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT) != 0, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, - finalCreateInfo.pUserData, - finalCreateInfo.priority, - dedicatedBuffer, - dedicatedBufferUsage, - dedicatedImage, - allocationCount, - pAllocations); - } - } - else + // 1. Create VkBuffer. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateBuffer)( + allocator->m_hDevice, + pBufferCreateInfo, + allocator->GetAllocationCallbacks(), + pBuffer); + if(res >= 0) { - VkResult res = blockVector->Allocate( - m_CurrentFrameIndex.load(), - size, - alignment, - finalCreateInfo, - suballocType, - allocationCount, - pAllocations); - if(res == VK_SUCCESS) - { - return res; - } + // 2. vkGetBufferMemoryRequirements. + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetBufferMemoryRequirements(*pBuffer, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); - // 5. Try dedicated memory. - if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) - { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } + // 3. Allocate memory using allocator. + res = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + *pBuffer, // dedicatedBuffer + VK_NULL_HANDLE, // dedicatedImage + pBufferCreateInfo->usage, // dedicatedBufferImageUsage + *pAllocationCreateInfo, + VMA_SUBALLOCATION_TYPE_BUFFER, + 1, // allocationCount + pAllocation); - // Protection against creating each allocation as dedicated when we reach or exceed heap size/budget, - // which can quickly deplete maxMemoryAllocationCount: Don't try dedicated allocations when above - // 3/4 of the maximum allocation count. - if(m_DeviceMemoryCount.load() > m_PhysicalDeviceProperties.limits.maxMemoryAllocationCount * 3 / 4) + if(res >= 0) { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } + // 3. Bind buffer with memory. + if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) + { + res = allocator->BindBufferMemory(*pAllocation, 0, *pBuffer, VMA_NULL); + } + if(res >= 0) + { + // All steps succeeded. + #if VMA_STATS_STRING_ENABLED + (*pAllocation)->InitBufferImageUsage(pBufferCreateInfo->usage); + #endif + if(pAllocationInfo != VMA_NULL) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } - res = AllocateDedicatedMemory( - size, - suballocType, - memTypeIndex, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT) != 0, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, - (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, - finalCreateInfo.pUserData, - finalCreateInfo.priority, - dedicatedBuffer, - dedicatedBufferUsage, - dedicatedImage, - allocationCount, - pAllocations); - if(res == VK_SUCCESS) - { - // Succeeded: AllocateDedicatedMemory function already filld pMemory, nothing more to do here. - VMA_DEBUG_LOG(" Allocated as DedicatedMemory"); - return VK_SUCCESS; - } - else - { - // Everything failed: Return error code. - VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); + return VK_SUCCESS; + } + allocator->FreeMemory( + 1, // allocationCount + pAllocation); + *pAllocation = VK_NULL_HANDLE; + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; return res; } + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; + return res; } + return res; } -VkResult VmaAllocator_T::AllocateDedicatedMemory( - VkDeviceSize size, - VmaSuballocationType suballocType, - uint32_t memTypeIndex, - bool withinBudget, - bool map, - bool isUserDataString, - void* pUserData, - float priority, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, - VkImage dedicatedImage, - size_t allocationCount, - VmaAllocation* pAllocations) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBufferWithAlignment( + VmaAllocator allocator, + const VkBufferCreateInfo* pBufferCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkDeviceSize minAlignment, + VkBuffer* pBuffer, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { - VMA_ASSERT(allocationCount > 0 && pAllocations); + VMA_ASSERT(allocator && pBufferCreateInfo && pAllocationCreateInfo && VmaIsPow2(minAlignment) && pBuffer && pAllocation); - if(withinBudget) + if(pBufferCreateInfo->size == 0) { - const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); - VmaBudget heapBudget = {}; - GetBudget(&heapBudget, heapIndex, 1); - if(heapBudget.usage + size * allocationCount > heapBudget.budget) - { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } + return VK_ERROR_INITIALIZATION_FAILED; } - - VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; - allocInfo.memoryTypeIndex = memTypeIndex; - allocInfo.allocationSize = size; - -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - VkMemoryDedicatedAllocateInfoKHR dedicatedAllocInfo = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR }; - if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + if((pBufferCreateInfo->usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY) != 0 && + !allocator->m_UseKhrBufferDeviceAddress) { - if(dedicatedBuffer != VK_NULL_HANDLE) - { - VMA_ASSERT(dedicatedImage == VK_NULL_HANDLE); - dedicatedAllocInfo.buffer = dedicatedBuffer; - VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); - } - else if(dedicatedImage != VK_NULL_HANDLE) - { - dedicatedAllocInfo.image = dedicatedImage; - VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); - } + VMA_ASSERT(0 && "Creating a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT is not valid if VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT was not used."); + return VK_ERROR_INITIALIZATION_FAILED; } -#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 -#if VMA_BUFFER_DEVICE_ADDRESS - VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; - if(m_UseKhrBufferDeviceAddress) - { - bool canContainBufferWithDeviceAddress = true; - if(dedicatedBuffer != VK_NULL_HANDLE) - { - canContainBufferWithDeviceAddress = dedicatedBufferUsage == UINT32_MAX || // Usage flags unknown - (dedicatedBufferUsage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT) != 0; - } - else if(dedicatedImage != VK_NULL_HANDLE) - { - canContainBufferWithDeviceAddress = false; - } - if(canContainBufferWithDeviceAddress) - { - allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; - VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); - } - } -#endif // #if VMA_BUFFER_DEVICE_ADDRESS + VMA_DEBUG_LOG("vmaCreateBufferWithAlignment"); -#if VMA_MEMORY_PRIORITY - VkMemoryPriorityAllocateInfoEXT priorityInfo = { VK_STRUCTURE_TYPE_MEMORY_PRIORITY_ALLOCATE_INFO_EXT }; - if(m_UseExtMemoryPriority) - { - priorityInfo.priority = priority; - VmaPnextChainPushFront(&allocInfo, &priorityInfo); - } -#endif // #if VMA_MEMORY_PRIORITY + VMA_DEBUG_GLOBAL_MUTEX_LOCK - size_t allocIndex; - VkResult res = VK_SUCCESS; - for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) - { - res = AllocateDedicatedMemoryPage( - size, - suballocType, - memTypeIndex, - allocInfo, - map, - isUserDataString, - pUserData, - pAllocations + allocIndex); - if(res != VK_SUCCESS) - { - break; - } - } + *pBuffer = VK_NULL_HANDLE; + *pAllocation = VK_NULL_HANDLE; - if(res == VK_SUCCESS) + // 1. Create VkBuffer. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateBuffer)( + allocator->m_hDevice, + pBufferCreateInfo, + allocator->GetAllocationCallbacks(), + pBuffer); + if(res >= 0) { - // Register them in m_pDedicatedAllocations. - { - VmaMutexLockWrite lock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); - AllocationVectorType* pDedicatedAllocations = m_pDedicatedAllocations[memTypeIndex]; - VMA_ASSERT(pDedicatedAllocations); - for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) - { - VmaVectorInsertSorted(*pDedicatedAllocations, pAllocations[allocIndex]); - } - } + // 2. vkGetBufferMemoryRequirements. + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetBufferMemoryRequirements(*pBuffer, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); - VMA_DEBUG_LOG(" Allocated DedicatedMemory Count=%zu, MemoryTypeIndex=#%u", allocationCount, memTypeIndex); - } - else - { - // Free all already created allocations. - while(allocIndex--) - { - VmaAllocation currAlloc = pAllocations[allocIndex]; - VkDeviceMemory hMemory = currAlloc->GetMemory(); + // 2a. Include minAlignment + vkMemReq.alignment = VMA_MAX(vkMemReq.alignment, minAlignment); - /* - There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory - before vkFreeMemory. + // 3. Allocate memory using allocator. + res = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + *pBuffer, // dedicatedBuffer + VK_NULL_HANDLE, // dedicatedImage + pBufferCreateInfo->usage, // dedicatedBufferImageUsage + *pAllocationCreateInfo, + VMA_SUBALLOCATION_TYPE_BUFFER, + 1, // allocationCount + pAllocation); - if(currAlloc->GetMappedData() != VMA_NULL) + if(res >= 0) + { + // 3. Bind buffer with memory. + if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) { - (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); + res = allocator->BindBufferMemory(*pAllocation, 0, *pBuffer, VMA_NULL); } - */ + if(res >= 0) + { + // All steps succeeded. + #if VMA_STATS_STRING_ENABLED + (*pAllocation)->InitBufferImageUsage(pBufferCreateInfo->usage); + #endif + if(pAllocationInfo != VMA_NULL) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } - FreeVulkanMemory(memTypeIndex, currAlloc->GetSize(), hMemory); - m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), currAlloc->GetSize()); - currAlloc->SetUserData(this, VMA_NULL); - m_AllocationObjectAllocator.Free(currAlloc); + return VK_SUCCESS; + } + allocator->FreeMemory( + 1, // allocationCount + pAllocation); + *pAllocation = VK_NULL_HANDLE; + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; + return res; } - - memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; + return res; } - return res; } -VkResult VmaAllocator_T::AllocateDedicatedMemoryPage( - VkDeviceSize size, - VmaSuballocationType suballocType, - uint32_t memTypeIndex, - const VkMemoryAllocateInfo& allocInfo, - bool map, - bool isUserDataString, - void* pUserData, - VmaAllocation* pAllocation) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAliasingBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pBuffer) { - VkDeviceMemory hMemory = VK_NULL_HANDLE; - VkResult res = AllocateVulkanMemory(&allocInfo, &hMemory); - if(res < 0) + VMA_ASSERT(allocator && pBufferCreateInfo && pBuffer && allocation); + + VMA_DEBUG_LOG("vmaCreateAliasingBuffer"); + + *pBuffer = VK_NULL_HANDLE; + + if (pBufferCreateInfo->size == 0) { - VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); - return res; + return VK_ERROR_INITIALIZATION_FAILED; } - - void* pMappedData = VMA_NULL; - if(map) + if ((pBufferCreateInfo->usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY) != 0 && + !allocator->m_UseKhrBufferDeviceAddress) { - res = (*m_VulkanFunctions.vkMapMemory)( - m_hDevice, - hMemory, - 0, - VK_WHOLE_SIZE, - 0, - &pMappedData); - if(res < 0) - { - VMA_DEBUG_LOG(" vkMapMemory FAILED"); - FreeVulkanMemory(memTypeIndex, size, hMemory); - return res; - } + VMA_ASSERT(0 && "Creating a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT is not valid if VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT was not used."); + return VK_ERROR_INITIALIZATION_FAILED; } - *pAllocation = m_AllocationObjectAllocator.Allocate(m_CurrentFrameIndex.load(), isUserDataString); - (*pAllocation)->InitDedicatedAllocation(memTypeIndex, hMemory, suballocType, pMappedData, size); - (*pAllocation)->SetUserData(this, pUserData); - m_Budget.AddAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), size); - if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + // 1. Create VkBuffer. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateBuffer)( + allocator->m_hDevice, + pBufferCreateInfo, + allocator->GetAllocationCallbacks(), + pBuffer); + if (res >= 0) { - FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + // 2. Bind buffer with memory. + res = allocator->BindBufferMemory(allocation, 0, *pBuffer, VMA_NULL); + if (res >= 0) + { + return VK_SUCCESS; + } + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); } - - return VK_SUCCESS; + return res; } -void VmaAllocator_T::GetBufferMemoryRequirements( - VkBuffer hBuffer, - VkMemoryRequirements& memReq, - bool& requiresDedicatedAllocation, - bool& prefersDedicatedAllocation) const +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( + VmaAllocator allocator, + VkBuffer buffer, + VmaAllocation allocation) { -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - VkBufferMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR }; - memReqInfo.buffer = hBuffer; + VMA_ASSERT(allocator); - VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; + if(buffer == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) + { + return; + } - VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; - VmaPnextChainPushFront(&memReq2, &memDedicatedReq); + VMA_DEBUG_LOG("vmaDestroyBuffer"); - (*m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - memReq = memReq2.memoryRequirements; - requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); - prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); + if(buffer != VK_NULL_HANDLE) + { + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, buffer, allocator->GetAllocationCallbacks()); } - else -#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + + if(allocation != VK_NULL_HANDLE) { - (*m_VulkanFunctions.vkGetBufferMemoryRequirements)(m_hDevice, hBuffer, &memReq); - requiresDedicatedAllocation = false; - prefersDedicatedAllocation = false; + allocator->FreeMemory( + 1, // allocationCount + &allocation); } } -void VmaAllocator_T::GetImageMemoryRequirements( - VkImage hImage, - VkMemoryRequirements& memReq, - bool& requiresDedicatedAllocation, - bool& prefersDedicatedAllocation) const +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( + VmaAllocator allocator, + const VkImageCreateInfo* pImageCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkImage* pImage, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) { -#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 - if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) - { - VkImageMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR }; - memReqInfo.image = hImage; - - VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; - - VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; - VmaPnextChainPushFront(&memReq2, &memDedicatedReq); - - (*m_VulkanFunctions.vkGetImageMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + VMA_ASSERT(allocator && pImageCreateInfo && pAllocationCreateInfo && pImage && pAllocation); - memReq = memReq2.memoryRequirements; - requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); - prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); - } - else -#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(pImageCreateInfo->extent.width == 0 || + pImageCreateInfo->extent.height == 0 || + pImageCreateInfo->extent.depth == 0 || + pImageCreateInfo->mipLevels == 0 || + pImageCreateInfo->arrayLayers == 0) { - (*m_VulkanFunctions.vkGetImageMemoryRequirements)(m_hDevice, hImage, &memReq); - requiresDedicatedAllocation = false; - prefersDedicatedAllocation = false; + return VK_ERROR_INITIALIZATION_FAILED; } -} -VkResult VmaAllocator_T::AllocateMemory( - const VkMemoryRequirements& vkMemReq, - bool requiresDedicatedAllocation, - bool prefersDedicatedAllocation, - VkBuffer dedicatedBuffer, - VkBufferUsageFlags dedicatedBufferUsage, - VkImage dedicatedImage, - const VmaAllocationCreateInfo& createInfo, - VmaSuballocationType suballocType, - size_t allocationCount, - VmaAllocation* pAllocations) -{ - memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + VMA_DEBUG_LOG("vmaCreateImage"); - VMA_ASSERT(VmaIsPow2(vkMemReq.alignment)); + VMA_DEBUG_GLOBAL_MUTEX_LOCK - if(vkMemReq.size == 0) - { - return VK_ERROR_VALIDATION_FAILED_EXT; - } - if((createInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0 && - (createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) - { - VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT together with VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT makes no sense."); - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - if((createInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && - (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0) - { - VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_MAPPED_BIT together with VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT is invalid."); - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - if(requiresDedicatedAllocation) - { - if((createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) - { - VMA_ASSERT(0 && "VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT specified while dedicated allocation is required."); - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - if(createInfo.pool != VK_NULL_HANDLE) - { - VMA_ASSERT(0 && "Pool specified while dedicated allocation is required."); - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - } - if((createInfo.pool != VK_NULL_HANDLE) && - ((createInfo.flags & (VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT)) != 0)) - { - VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT when pool != null is invalid."); - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } + *pImage = VK_NULL_HANDLE; + *pAllocation = VK_NULL_HANDLE; - if(createInfo.pool != VK_NULL_HANDLE) + // 1. Create VkImage. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateImage)( + allocator->m_hDevice, + pImageCreateInfo, + allocator->GetAllocationCallbacks(), + pImage); + if(res >= 0) { - const VkDeviceSize alignmentForPool = VMA_MAX( - vkMemReq.alignment, - GetMemoryTypeMinAlignment(createInfo.pool->m_BlockVector.GetMemoryTypeIndex())); + VmaSuballocationType suballocType = pImageCreateInfo->tiling == VK_IMAGE_TILING_OPTIMAL ? + VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL : + VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR; - VmaAllocationCreateInfo createInfoForPool = createInfo; - // If memory type is not HOST_VISIBLE, disable MAPPED. - if((createInfoForPool.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && - (m_MemProps.memoryTypes[createInfo.pool->m_BlockVector.GetMemoryTypeIndex()].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) - { - createInfoForPool.flags &= ~VMA_ALLOCATION_CREATE_MAPPED_BIT; - } + // 2. Allocate memory using allocator. + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetImageMemoryRequirements(*pImage, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); - return createInfo.pool->m_BlockVector.Allocate( - m_CurrentFrameIndex.load(), - vkMemReq.size, - alignmentForPool, - createInfoForPool, + res = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + VK_NULL_HANDLE, // dedicatedBuffer + *pImage, // dedicatedImage + pImageCreateInfo->usage, // dedicatedBufferImageUsage + *pAllocationCreateInfo, suballocType, - allocationCount, - pAllocations); - } - else - { - // Bit mask of memory Vulkan types acceptable for this allocation. - uint32_t memoryTypeBits = vkMemReq.memoryTypeBits; - uint32_t memTypeIndex = UINT32_MAX; - VkResult res = vmaFindMemoryTypeIndex(this, memoryTypeBits, &createInfo, &memTypeIndex); - if(res == VK_SUCCESS) - { - VkDeviceSize alignmentForMemType = VMA_MAX( - vkMemReq.alignment, - GetMemoryTypeMinAlignment(memTypeIndex)); + 1, // allocationCount + pAllocation); - res = AllocateMemoryOfType( - vkMemReq.size, - alignmentForMemType, - requiresDedicatedAllocation || prefersDedicatedAllocation, - dedicatedBuffer, - dedicatedBufferUsage, - dedicatedImage, - createInfo, - memTypeIndex, - suballocType, - allocationCount, - pAllocations); - // Succeeded on first try. - if(res == VK_SUCCESS) + if(res >= 0) + { + // 3. Bind image with memory. + if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) { - return res; + res = allocator->BindImageMemory(*pAllocation, 0, *pImage, VMA_NULL); } - // Allocation from this memory type failed. Try other compatible memory types. - else + if(res >= 0) { - for(;;) - { - // Remove old memTypeIndex from list of possibilities. - memoryTypeBits &= ~(1u << memTypeIndex); - // Find alternative memTypeIndex. - res = vmaFindMemoryTypeIndex(this, memoryTypeBits, &createInfo, &memTypeIndex); - if(res == VK_SUCCESS) - { - alignmentForMemType = VMA_MAX( - vkMemReq.alignment, - GetMemoryTypeMinAlignment(memTypeIndex)); - - res = AllocateMemoryOfType( - vkMemReq.size, - alignmentForMemType, - requiresDedicatedAllocation || prefersDedicatedAllocation, - dedicatedBuffer, - dedicatedBufferUsage, - dedicatedImage, - createInfo, - memTypeIndex, - suballocType, - allocationCount, - pAllocations); - // Allocation from this alternative memory type succeeded. - if(res == VK_SUCCESS) - { - return res; - } - // else: Allocation from this memory type failed. Try next one - next loop iteration. - } - // No other matching memory type index could be found. - else - { - // Not returning res, which is VK_ERROR_FEATURE_NOT_PRESENT, because we already failed to allocate once. - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - } - } - } - // Can't find any single memory type maching requirements. res is VK_ERROR_FEATURE_NOT_PRESENT. - else - return res; - } -} - -void VmaAllocator_T::FreeMemory( - size_t allocationCount, - const VmaAllocation* pAllocations) -{ - VMA_ASSERT(pAllocations); - - for(size_t allocIndex = allocationCount; allocIndex--; ) - { - VmaAllocation allocation = pAllocations[allocIndex]; - - if(allocation != VK_NULL_HANDLE) - { - if(TouchAllocation(allocation)) - { - if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + // All steps succeeded. + #if VMA_STATS_STRING_ENABLED + (*pAllocation)->InitBufferImageUsage(pImageCreateInfo->usage); + #endif + if(pAllocationInfo != VMA_NULL) { - FillAllocation(allocation, VMA_ALLOCATION_FILL_PATTERN_DESTROYED); + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); } - switch(allocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - VmaBlockVector* pBlockVector = VMA_NULL; - VmaPool hPool = allocation->GetBlock()->GetParentPool(); - if(hPool != VK_NULL_HANDLE) - { - pBlockVector = &hPool->m_BlockVector; - } - else - { - const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); - pBlockVector = m_pBlockVectors[memTypeIndex]; - } - pBlockVector->Free(allocation); - } - break; - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - FreeDedicatedMemory(allocation); - break; - default: - VMA_ASSERT(0); - } + return VK_SUCCESS; } - - // Do this regardless of whether the allocation is lost. Lost allocations still account to Budget.AllocationBytes. - m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(allocation->GetMemoryTypeIndex()), allocation->GetSize()); - allocation->SetUserData(this, VMA_NULL); - m_AllocationObjectAllocator.Free(allocation); + allocator->FreeMemory( + 1, // allocationCount + pAllocation); + *pAllocation = VK_NULL_HANDLE; + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); + *pImage = VK_NULL_HANDLE; + return res; } + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); + *pImage = VK_NULL_HANDLE; + return res; } + return res; } -void VmaAllocator_T::CalculateStats(VmaStats* pStats) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAliasingImage( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + VkImage VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pImage) { - // Initialize. - InitStatInfo(pStats->total); - for (const auto i : c10::irange(VK_MAX_MEMORY_TYPES))InitStatInfo(pStats->memoryType[i]); - for (const auto i : c10::irange(VK_MAX_MEMORY_HEAPS))InitStatInfo(pStats->memoryHeap[i]); + VMA_ASSERT(allocator && pImageCreateInfo && pImage && allocation); - // Process default pools. - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; - VMA_ASSERT(pBlockVector); - pBlockVector->AddStats(pStats); - } + *pImage = VK_NULL_HANDLE; - // Process custom pools. + VMA_DEBUG_LOG("vmaCreateImage"); + + if (pImageCreateInfo->extent.width == 0 || + pImageCreateInfo->extent.height == 0 || + pImageCreateInfo->extent.depth == 0 || + pImageCreateInfo->mipLevels == 0 || + pImageCreateInfo->arrayLayers == 0) { - VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); - for(size_t poolIndex = 0, poolCount = m_Pools.size(); poolIndex < poolCount; ++poolIndex) - { - m_Pools[poolIndex]->m_BlockVector.AddStats(pStats); - } + return VK_ERROR_INITIALIZATION_FAILED; } - // Process dedicated allocations. - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - const uint32_t memHeapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); - VmaMutexLockRead dedicatedAllocationsLock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); - AllocationVectorType* const pDedicatedAllocVector = m_pDedicatedAllocations[memTypeIndex]; - VMA_ASSERT(pDedicatedAllocVector); - for(size_t allocIndex = 0, allocCount = pDedicatedAllocVector->size(); allocIndex < allocCount; ++allocIndex) + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + // 1. Create VkImage. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateImage)( + allocator->m_hDevice, + pImageCreateInfo, + allocator->GetAllocationCallbacks(), + pImage); + if (res >= 0) + { + // 2. Bind image with memory. + res = allocator->BindImageMemory(allocation, 0, *pImage, VMA_NULL); + if (res >= 0) { - VmaStatInfo allocationStatInfo; - (*pDedicatedAllocVector)[allocIndex]->DedicatedAllocCalcStatsInfo(allocationStatInfo); - VmaAddStatInfo(pStats->total, allocationStatInfo); - VmaAddStatInfo(pStats->memoryType[memTypeIndex], allocationStatInfo); - VmaAddStatInfo(pStats->memoryHeap[memHeapIndex], allocationStatInfo); + return VK_SUCCESS; } + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); } - - // Postprocess. - VmaPostprocessCalcStatInfo(pStats->total); - for (const auto i : c10::irange(GetMemoryTypeCount()))VmaPostprocessCalcStatInfo(pStats->memoryType[i]); - for (const auto i : c10::irange(GetMemoryHeapCount()))VmaPostprocessCalcStatInfo(pStats->memoryHeap[i]); + return res; } -void VmaAllocator_T::GetBudget(VmaBudget* outBudget, uint32_t firstHeap, uint32_t heapCount) +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( + VmaAllocator VMA_NOT_NULL allocator, + VkImage VMA_NULLABLE_NON_DISPATCHABLE image, + VmaAllocation VMA_NULLABLE allocation) { -#if VMA_MEMORY_BUDGET - if(m_UseExtMemoryBudget) + VMA_ASSERT(allocator); + + if(image == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) { - if(m_Budget.m_OperationsSinceBudgetFetch < 30) - { - VmaMutexLockRead lockRead(m_Budget.m_BudgetMutex, m_UseMutex); - for(uint32_t i = 0; i < heapCount; ++i, ++outBudget) - { - const uint32_t heapIndex = firstHeap + i; + return; + } - outBudget->blockBytes = m_Budget.m_BlockBytes[heapIndex]; - outBudget->allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; + VMA_DEBUG_LOG("vmaDestroyImage"); - if(m_Budget.m_VulkanUsage[heapIndex] + outBudget->blockBytes > m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]) - { - outBudget->usage = m_Budget.m_VulkanUsage[heapIndex] + - outBudget->blockBytes - m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; - } - else - { - outBudget->usage = 0; - } + VMA_DEBUG_GLOBAL_MUTEX_LOCK - // Have to take MIN with heap size because explicit HeapSizeLimit is included in it. - outBudget->budget = VMA_MIN( - m_Budget.m_VulkanBudget[heapIndex], m_MemProps.memoryHeaps[heapIndex].size); - } - } - else - { - UpdateVulkanBudget(); // Outside of mutex lock - GetBudget(outBudget, firstHeap, heapCount); // Recursion - } + if(image != VK_NULL_HANDLE) + { + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, image, allocator->GetAllocationCallbacks()); } - else -#endif + if(allocation != VK_NULL_HANDLE) { - for(uint32_t i = 0; i < heapCount; ++i, ++outBudget) - { - const uint32_t heapIndex = firstHeap + i; - - outBudget->blockBytes = m_Budget.m_BlockBytes[heapIndex]; - outBudget->allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; - - outBudget->usage = outBudget->blockBytes; - outBudget->budget = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. - } + allocator->FreeMemory( + 1, // allocationCount + &allocation); } } -static const uint32_t VMA_VENDOR_ID_AMD = 4098; - -VkResult VmaAllocator_T::DefragmentationBegin( - const VmaDefragmentationInfo2& info, - VmaDefragmentationStats* pStats, - VmaDefragmentationContext* pContext) +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateVirtualBlock( + const VmaVirtualBlockCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaVirtualBlock VMA_NULLABLE * VMA_NOT_NULL pVirtualBlock) { - if(info.pAllocationsChanged != VMA_NULL) + VMA_ASSERT(pCreateInfo && pVirtualBlock); + VMA_ASSERT(pCreateInfo->size > 0); + VMA_DEBUG_LOG("vmaCreateVirtualBlock"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + *pVirtualBlock = vma_new(pCreateInfo->pAllocationCallbacks, VmaVirtualBlock_T)(*pCreateInfo); + VkResult res = (*pVirtualBlock)->Init(); + if(res < 0) { - memset(info.pAllocationsChanged, 0, info.allocationCount * sizeof(VkBool32)); + vma_delete(pCreateInfo->pAllocationCallbacks, *pVirtualBlock); + *pVirtualBlock = VK_NULL_HANDLE; } + return res; +} - *pContext = vma_new(this, VmaDefragmentationContext_T)( - this, m_CurrentFrameIndex.load(), info.flags, pStats); - - (*pContext)->AddPools(info.poolCount, info.pPools); - (*pContext)->AddAllocations( - info.allocationCount, info.pAllocations, info.pAllocationsChanged); - - VkResult res = (*pContext)->Defragment( - info.maxCpuBytesToMove, info.maxCpuAllocationsToMove, - info.maxGpuBytesToMove, info.maxGpuAllocationsToMove, - info.commandBuffer, pStats, info.flags); - - if(res != VK_NOT_READY) +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyVirtualBlock(VmaVirtualBlock VMA_NULLABLE virtualBlock) +{ + if(virtualBlock != VK_NULL_HANDLE) { - vma_delete(this, *pContext); - *pContext = VMA_NULL; + VMA_DEBUG_LOG("vmaDestroyVirtualBlock"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + VkAllocationCallbacks allocationCallbacks = virtualBlock->m_AllocationCallbacks; // Have to copy the callbacks when destroying. + vma_delete(&allocationCallbacks, virtualBlock); } - - return res; } -VkResult VmaAllocator_T::DefragmentationEnd( - VmaDefragmentationContext context) +VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaIsVirtualBlockEmpty(VmaVirtualBlock VMA_NOT_NULL virtualBlock) { - vma_delete(this, context); - return VK_SUCCESS; + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE); + VMA_DEBUG_LOG("vmaIsVirtualBlockEmpty"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + return virtualBlock->IsEmpty() ? VK_TRUE : VK_FALSE; } -VkResult VmaAllocator_T::DefragmentationPassBegin( - VmaDefragmentationPassInfo* pInfo, - VmaDefragmentationContext context) +VMA_CALL_PRE void VMA_CALL_POST vmaGetVirtualAllocationInfo(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaVirtualAllocation VMA_NOT_NULL_NON_DISPATCHABLE allocation, VmaVirtualAllocationInfo* VMA_NOT_NULL pVirtualAllocInfo) { - return context->DefragmentPassBegin(pInfo); + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE && pVirtualAllocInfo != VMA_NULL); + VMA_DEBUG_LOG("vmaGetVirtualAllocationInfo"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->GetAllocationInfo(allocation, *pVirtualAllocInfo); } -VkResult VmaAllocator_T::DefragmentationPassEnd( - VmaDefragmentationContext context) -{ - return context->DefragmentPassEnd(); +VMA_CALL_PRE VkResult VMA_CALL_POST vmaVirtualAllocate(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + const VmaVirtualAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, VmaVirtualAllocation VMA_NULLABLE_NON_DISPATCHABLE* VMA_NOT_NULL pAllocation, + VkDeviceSize* VMA_NULLABLE pOffset) +{ + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE && pCreateInfo != VMA_NULL && pAllocation != VMA_NULL); + VMA_DEBUG_LOG("vmaVirtualAllocate"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + return virtualBlock->Allocate(*pCreateInfo, *pAllocation, pOffset); } -void VmaAllocator_T::GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo) +VMA_CALL_PRE void VMA_CALL_POST vmaVirtualFree(VmaVirtualBlock VMA_NOT_NULL virtualBlock, VmaVirtualAllocation VMA_NULLABLE_NON_DISPATCHABLE allocation) { - if(hAllocation->CanBecomeLost()) + if(allocation != VK_NULL_HANDLE) { - /* - Warning: This is a carefully designed algorithm. - Do not modify unless you really know what you're doing :) - */ - const uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); - uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); - for(;;) - { - if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) - { - pAllocationInfo->memoryType = UINT32_MAX; - pAllocationInfo->deviceMemory = VK_NULL_HANDLE; - pAllocationInfo->offset = 0; - pAllocationInfo->size = hAllocation->GetSize(); - pAllocationInfo->pMappedData = VMA_NULL; - pAllocationInfo->pUserData = hAllocation->GetUserData(); - return; - } - else if(localLastUseFrameIndex == localCurrFrameIndex) - { - pAllocationInfo->memoryType = hAllocation->GetMemoryTypeIndex(); - pAllocationInfo->deviceMemory = hAllocation->GetMemory(); - pAllocationInfo->offset = hAllocation->GetOffset(); - pAllocationInfo->size = hAllocation->GetSize(); - pAllocationInfo->pMappedData = VMA_NULL; - pAllocationInfo->pUserData = hAllocation->GetUserData(); - return; - } - else // Last use time earlier than current time. - { - if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) - { - localLastUseFrameIndex = localCurrFrameIndex; - } - } - } + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE); + VMA_DEBUG_LOG("vmaVirtualFree"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->Free(allocation); } - else - { -#if VMA_STATS_STRING_ENABLED - uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); - uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); - for(;;) - { - VMA_ASSERT(localLastUseFrameIndex != VMA_FRAME_INDEX_LOST); - if(localLastUseFrameIndex == localCurrFrameIndex) - { - break; - } - else // Last use time earlier than current time. - { - if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) - { - localLastUseFrameIndex = localCurrFrameIndex; - } - } - } -#endif +} - pAllocationInfo->memoryType = hAllocation->GetMemoryTypeIndex(); - pAllocationInfo->deviceMemory = hAllocation->GetMemory(); - pAllocationInfo->offset = hAllocation->GetOffset(); - pAllocationInfo->size = hAllocation->GetSize(); - pAllocationInfo->pMappedData = hAllocation->GetMappedData(); - pAllocationInfo->pUserData = hAllocation->GetUserData(); - } +VMA_CALL_PRE void VMA_CALL_POST vmaClearVirtualBlock(VmaVirtualBlock VMA_NOT_NULL virtualBlock) +{ + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE); + VMA_DEBUG_LOG("vmaClearVirtualBlock"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->Clear(); } -bool VmaAllocator_T::TouchAllocation(VmaAllocation hAllocation) +VMA_CALL_PRE void VMA_CALL_POST vmaSetVirtualAllocationUserData(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaVirtualAllocation VMA_NOT_NULL_NON_DISPATCHABLE allocation, void* VMA_NULLABLE pUserData) { - // This is a stripped-down version of VmaAllocator_T::GetAllocationInfo. - if(hAllocation->CanBecomeLost()) - { - uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); - uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); - for(;;) - { - if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) - { - return false; - } - else if(localLastUseFrameIndex == localCurrFrameIndex) - { - return true; - } - else // Last use time earlier than current time. - { - if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) - { - localLastUseFrameIndex = localCurrFrameIndex; - } - } - } - } - else - { -#if VMA_STATS_STRING_ENABLED - uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); - uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); - for(;;) - { - VMA_ASSERT(localLastUseFrameIndex != VMA_FRAME_INDEX_LOST); - if(localLastUseFrameIndex == localCurrFrameIndex) - { - break; - } - else // Last use time earlier than current time. - { - if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) - { - localLastUseFrameIndex = localCurrFrameIndex; - } - } - } -#endif + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE); + VMA_DEBUG_LOG("vmaSetVirtualAllocationUserData"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->SetAllocationUserData(allocation, pUserData); +} - return true; - } +VMA_CALL_PRE void VMA_CALL_POST vmaGetVirtualBlockStatistics(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaStatistics* VMA_NOT_NULL pStats) +{ + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE && pStats != VMA_NULL); + VMA_DEBUG_LOG("vmaGetVirtualBlockStatistics"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->GetStatistics(*pStats); } -VkResult VmaAllocator_T::CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool) +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateVirtualBlockStatistics(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + VmaDetailedStatistics* VMA_NOT_NULL pStats) { - VMA_DEBUG_LOG(" CreatePool: MemoryTypeIndex=%u, flags=%u", pCreateInfo->memoryTypeIndex, pCreateInfo->flags); + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE && pStats != VMA_NULL); + VMA_DEBUG_LOG("vmaCalculateVirtualBlockStatistics"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + virtualBlock->CalculateDetailedStatistics(*pStats); +} - VmaPoolCreateInfo newCreateInfo = *pCreateInfo; +#if VMA_STATS_STRING_ENABLED - if(newCreateInfo.maxBlockCount == 0) - { - newCreateInfo.maxBlockCount = SIZE_MAX; - } - if(newCreateInfo.minBlockCount > newCreateInfo.maxBlockCount) - { - return VK_ERROR_INITIALIZATION_FAILED; - } - // Memory type index out of range or forbidden. - if(pCreateInfo->memoryTypeIndex >= GetMemoryTypeCount() || - ((1u << pCreateInfo->memoryTypeIndex) & m_GlobalMemoryTypeBits) == 0) +VMA_CALL_PRE void VMA_CALL_POST vmaBuildVirtualBlockStatsString(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + char* VMA_NULLABLE * VMA_NOT_NULL ppStatsString, VkBool32 detailedMap) +{ + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE && ppStatsString != VMA_NULL); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + const VkAllocationCallbacks* allocationCallbacks = virtualBlock->GetAllocationCallbacks(); + VmaStringBuilder sb(allocationCallbacks); + virtualBlock->BuildStatsString(detailedMap != VK_FALSE, sb); + *ppStatsString = VmaCreateStringCopy(allocationCallbacks, sb.GetData(), sb.GetLength()); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaFreeVirtualBlockStatsString(VmaVirtualBlock VMA_NOT_NULL virtualBlock, + char* VMA_NULLABLE pStatsString) +{ + if(pStatsString != VMA_NULL) { - return VK_ERROR_FEATURE_NOT_PRESENT; + VMA_ASSERT(virtualBlock != VK_NULL_HANDLE); + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + VmaFreeString(virtualBlock->GetAllocationCallbacks(), pStatsString); } +} +#endif // VMA_STATS_STRING_ENABLED +#endif // _VMA_PUBLIC_INTERFACE +#endif // VMA_IMPLEMENTATION - const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(newCreateInfo.memoryTypeIndex); +/** +\page quick_start Quick start - *pPool = vma_new(this, VmaPool_T)(this, newCreateInfo, preferredBlockSize); +\section quick_start_project_setup Project setup - VkResult res = (*pPool)->m_BlockVector.CreateMinBlocks(); - if(res != VK_SUCCESS) - { - vma_delete(this, *pPool); - *pPool = VMA_NULL; - return res; - } - - // Add to m_Pools. - { - VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); - (*pPool)->SetId(m_NextPoolId++); - VmaVectorInsertSorted(m_Pools, *pPool); - } +Vulkan Memory Allocator comes in form of a "stb-style" single header file. +You don't need to build it as a separate library project. +You can add this file directly to your project and submit it to code repository next to your other source files. - return VK_SUCCESS; -} +"Single header" doesn't mean that everything is contained in C/C++ declarations, +like it tends to be in case of inline functions or C++ templates. +It means that implementation is bundled with interface in a single file and needs to be extracted using preprocessor macro. +If you don't do it properly, you will get linker errors. -void VmaAllocator_T::DestroyPool(VmaPool pool) -{ - // Remove from m_Pools. - { - VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); - bool success = VmaVectorRemoveSorted(m_Pools, pool); - VMA_ASSERT(success && "Pool not found in Allocator."); - } +To do it properly: - vma_delete(this, pool); -} +-# Include "vk_mem_alloc.h" file in each CPP file where you want to use the library. + This includes declarations of all members of the library. +-# In exactly one CPP file define following macro before this include. + It enables also internal definitions. -void VmaAllocator_T::GetPoolStats(VmaPool pool, VmaPoolStats* pPoolStats) -{ - pool->m_BlockVector.GetPoolStats(pPoolStats); -} +\code +#define VMA_IMPLEMENTATION +#include "vk_mem_alloc.h" +\endcode -void VmaAllocator_T::SetCurrentFrameIndex(uint32_t frameIndex) -{ - m_CurrentFrameIndex.store(frameIndex); +It may be a good idea to create dedicated CPP file just for this purpose. -#if VMA_MEMORY_BUDGET - if(m_UseExtMemoryBudget) - { - UpdateVulkanBudget(); - } -#endif // #if VMA_MEMORY_BUDGET -} +This library includes header ``, which in turn +includes `` on Windows. If you need some specific macros defined +before including these headers (like `WIN32_LEAN_AND_MEAN` or +`WINVER` for Windows, `VK_USE_PLATFORM_WIN32_KHR` for Vulkan), you must define +them before every `#include` of this library. -void VmaAllocator_T::MakePoolAllocationsLost( - VmaPool hPool, - size_t* pLostAllocationCount) -{ - hPool->m_BlockVector.MakePoolAllocationsLost( - m_CurrentFrameIndex.load(), - pLostAllocationCount); -} +This library is written in C++, but has C-compatible interface. +Thus you can include and use vk_mem_alloc.h in C or C++ code, but full +implementation with `VMA_IMPLEMENTATION` macro must be compiled as C++, NOT as C. +Some features of C++14 used. STL containers, RTTI, or C++ exceptions are not used. -VkResult VmaAllocator_T::CheckPoolCorruption(VmaPool hPool) -{ - return hPool->m_BlockVector.CheckCorruption(); -} -VkResult VmaAllocator_T::CheckCorruption(uint32_t memoryTypeBits) -{ - VkResult finalRes = VK_ERROR_FEATURE_NOT_PRESENT; +\section quick_start_initialization Initialization - // Process default pools. - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - if(((1u << memTypeIndex) & memoryTypeBits) != 0) - { - VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; - VMA_ASSERT(pBlockVector); - VkResult localRes = pBlockVector->CheckCorruption(); - switch(localRes) - { - case VK_ERROR_FEATURE_NOT_PRESENT: - break; - case VK_SUCCESS: - finalRes = VK_SUCCESS; - break; - default: - return localRes; - } - } - } +At program startup: - // Process custom pools. - { - VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); - for(size_t poolIndex = 0, poolCount = m_Pools.size(); poolIndex < poolCount; ++poolIndex) - { - if(((1u << m_Pools[poolIndex]->m_BlockVector.GetMemoryTypeIndex()) & memoryTypeBits) != 0) - { - VkResult localRes = m_Pools[poolIndex]->m_BlockVector.CheckCorruption(); - switch(localRes) - { - case VK_ERROR_FEATURE_NOT_PRESENT: - break; - case VK_SUCCESS: - finalRes = VK_SUCCESS; - break; - default: - return localRes; - } - } - } - } +-# Initialize Vulkan to have `VkPhysicalDevice`, `VkDevice` and `VkInstance` object. +-# Fill VmaAllocatorCreateInfo structure and create #VmaAllocator object by + calling vmaCreateAllocator(). - return finalRes; -} +Only members `physicalDevice`, `device`, `instance` are required. +However, you should inform the library which Vulkan version do you use by setting +VmaAllocatorCreateInfo::vulkanApiVersion and which extensions did you enable +by setting VmaAllocatorCreateInfo::flags (like #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT for VK_KHR_buffer_device_address). +Otherwise, VMA would use only features of Vulkan 1.0 core with no extensions. -void VmaAllocator_T::CreateLostAllocation(VmaAllocation* pAllocation) -{ - *pAllocation = m_AllocationObjectAllocator.Allocate(VMA_FRAME_INDEX_LOST, false); - (*pAllocation)->InitLost(); -} +You may need to configure importing Vulkan functions. There are 3 ways to do this: + +-# **If you link with Vulkan static library** (e.g. "vulkan-1.lib" on Windows): + - You don't need to do anything. + - VMA will use these, as macro `VMA_STATIC_VULKAN_FUNCTIONS` is defined to 1 by default. +-# **If you want VMA to fetch pointers to Vulkan functions dynamically** using `vkGetInstanceProcAddr`, + `vkGetDeviceProcAddr` (this is the option presented in the example below): + - Define `VMA_STATIC_VULKAN_FUNCTIONS` to 0, `VMA_DYNAMIC_VULKAN_FUNCTIONS` to 1. + - Provide pointers to these two functions via VmaVulkanFunctions::vkGetInstanceProcAddr, + VmaVulkanFunctions::vkGetDeviceProcAddr. + - The library will fetch pointers to all other functions it needs internally. +-# **If you fetch pointers to all Vulkan functions in a custom way**, e.g. using some loader like + [Volk](https://github.com/zeux/volk): + - Define `VMA_STATIC_VULKAN_FUNCTIONS` and `VMA_DYNAMIC_VULKAN_FUNCTIONS` to 0. + - Pass these pointers via structure #VmaVulkanFunctions. -// An object that increments given atomic but decrements it back in the destructor unless Commit() is called. -template -struct AtomicTransactionalIncrement -{ -public: - typedef std::atomic AtomicT; - ~AtomicTransactionalIncrement() - { - if(m_Atomic) - --(*m_Atomic); - } - T Increment(AtomicT* atomic) - { - m_Atomic = atomic; - return m_Atomic->fetch_add(1); - } - void Commit() - { - m_Atomic = nullptr; - } +\code +VmaVulkanFunctions vulkanFunctions = {}; +vulkanFunctions.vkGetInstanceProcAddr = &vkGetInstanceProcAddr; +vulkanFunctions.vkGetDeviceProcAddr = &vkGetDeviceProcAddr; -private: - AtomicT* m_Atomic = nullptr; -}; +VmaAllocatorCreateInfo allocatorCreateInfo = {}; +allocatorCreateInfo.vulkanApiVersion = VK_API_VERSION_1_2; +allocatorCreateInfo.physicalDevice = physicalDevice; +allocatorCreateInfo.device = device; +allocatorCreateInfo.instance = instance; +allocatorCreateInfo.pVulkanFunctions = &vulkanFunctions; -VkResult VmaAllocator_T::AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory) -{ - AtomicTransactionalIncrement deviceMemoryCountIncrement; - const uint64_t prevDeviceMemoryCount = deviceMemoryCountIncrement.Increment(&m_DeviceMemoryCount); -#if VMA_DEBUG_DONT_EXCEED_MAX_MEMORY_ALLOCATION_COUNT - if(prevDeviceMemoryCount >= m_PhysicalDeviceProperties.limits.maxMemoryAllocationCount) - { - return VK_ERROR_TOO_MANY_OBJECTS; - } -#endif +VmaAllocator allocator; +vmaCreateAllocator(&allocatorCreateInfo, &allocator); +\endcode - const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(pAllocateInfo->memoryTypeIndex); - // HeapSizeLimit is in effect for this heap. - if((m_HeapSizeLimitMask & (1u << heapIndex)) != 0) - { - const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; - VkDeviceSize blockBytes = m_Budget.m_BlockBytes[heapIndex]; - for(;;) - { - const VkDeviceSize blockBytesAfterAllocation = blockBytes + pAllocateInfo->allocationSize; - if(blockBytesAfterAllocation > heapSize) - { - return VK_ERROR_OUT_OF_DEVICE_MEMORY; - } - if(m_Budget.m_BlockBytes[heapIndex].compare_exchange_strong(blockBytes, blockBytesAfterAllocation)) - { - break; - } - } - } - else - { - m_Budget.m_BlockBytes[heapIndex] += pAllocateInfo->allocationSize; - } +\section quick_start_resource_allocation Resource allocation - // VULKAN CALL vkAllocateMemory. - VkResult res = (*m_VulkanFunctions.vkAllocateMemory)(m_hDevice, pAllocateInfo, GetAllocationCallbacks(), pMemory); +When you want to create a buffer or image: - if(res == VK_SUCCESS) - { -#if VMA_MEMORY_BUDGET - ++m_Budget.m_OperationsSinceBudgetFetch; -#endif +-# Fill `VkBufferCreateInfo` / `VkImageCreateInfo` structure. +-# Fill VmaAllocationCreateInfo structure. +-# Call vmaCreateBuffer() / vmaCreateImage() to get `VkBuffer`/`VkImage` with memory + already allocated and bound to it, plus #VmaAllocation objects that represents its underlying memory. - // Informative callback. - if(m_DeviceMemoryCallbacks.pfnAllocate != VMA_NULL) - { - (*m_DeviceMemoryCallbacks.pfnAllocate)(this, pAllocateInfo->memoryTypeIndex, *pMemory, pAllocateInfo->allocationSize, m_DeviceMemoryCallbacks.pUserData); - } +\code +VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufferInfo.size = 65536; +bufferInfo.usage = VK_BUFFER_USAGE_VERTEX_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - deviceMemoryCountIncrement.Commit(); - } - else - { - m_Budget.m_BlockBytes[heapIndex] -= pAllocateInfo->allocationSize; - } +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.usage = VMA_MEMORY_USAGE_AUTO; - return res; -} +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode -void VmaAllocator_T::FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory) -{ - // Informative callback. - if(m_DeviceMemoryCallbacks.pfnFree != VMA_NULL) - { - (*m_DeviceMemoryCallbacks.pfnFree)(this, memoryType, hMemory, size, m_DeviceMemoryCallbacks.pUserData); - } +Don't forget to destroy your objects when no longer needed: - // VULKAN CALL vkFreeMemory. - (*m_VulkanFunctions.vkFreeMemory)(m_hDevice, hMemory, GetAllocationCallbacks()); +\code +vmaDestroyBuffer(allocator, buffer, allocation); +vmaDestroyAllocator(allocator); +\endcode - m_Budget.m_BlockBytes[MemoryTypeIndexToHeapIndex(memoryType)] -= size; - --m_DeviceMemoryCount; -} +\page choosing_memory_type Choosing memory type -VkResult VmaAllocator_T::BindVulkanBuffer( - VkDeviceMemory memory, - VkDeviceSize memoryOffset, - VkBuffer buffer, - const void* pNext) -{ - if(pNext != VMA_NULL) - { -#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 - if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && - m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL) - { - VkBindBufferMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_BUFFER_MEMORY_INFO_KHR }; - bindBufferMemoryInfo.pNext = pNext; - bindBufferMemoryInfo.buffer = buffer; - bindBufferMemoryInfo.memory = memory; - bindBufferMemoryInfo.memoryOffset = memoryOffset; - return (*m_VulkanFunctions.vkBindBufferMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); - } - else -#endif // #if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 - { - return VK_ERROR_EXTENSION_NOT_PRESENT; - } - } - else - { - return (*m_VulkanFunctions.vkBindBufferMemory)(m_hDevice, buffer, memory, memoryOffset); - } -} +Physical devices in Vulkan support various combinations of memory heaps and +types. Help with choosing correct and optimal memory type for your specific +resource is one of the key features of this library. You can use it by filling +appropriate members of VmaAllocationCreateInfo structure, as described below. +You can also combine multiple methods. -VkResult VmaAllocator_T::BindVulkanImage( - VkDeviceMemory memory, - VkDeviceSize memoryOffset, - VkImage image, - const void* pNext) -{ - if(pNext != VMA_NULL) - { -#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 - if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && - m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL) - { - VkBindImageMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_INFO_KHR }; - bindBufferMemoryInfo.pNext = pNext; - bindBufferMemoryInfo.image = image; - bindBufferMemoryInfo.memory = memory; - bindBufferMemoryInfo.memoryOffset = memoryOffset; - return (*m_VulkanFunctions.vkBindImageMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); - } - else -#endif // #if VMA_BIND_MEMORY2 - { - return VK_ERROR_EXTENSION_NOT_PRESENT; - } - } - else - { - return (*m_VulkanFunctions.vkBindImageMemory)(m_hDevice, image, memory, memoryOffset); - } -} +-# If you just want to find memory type index that meets your requirements, you + can use function: vmaFindMemoryTypeIndexForBufferInfo(), + vmaFindMemoryTypeIndexForImageInfo(), vmaFindMemoryTypeIndex(). +-# If you want to allocate a region of device memory without association with any + specific image or buffer, you can use function vmaAllocateMemory(). Usage of + this function is not recommended and usually not needed. + vmaAllocateMemoryPages() function is also provided for creating multiple allocations at once, + which may be useful for sparse binding. +-# If you already have a buffer or an image created, you want to allocate memory + for it and then you will bind it yourself, you can use function + vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(). + For binding you should use functions: vmaBindBufferMemory(), vmaBindImageMemory() + or their extended versions: vmaBindBufferMemory2(), vmaBindImageMemory2(). +-# **This is the easiest and recommended way to use this library:** + If you want to create a buffer or an image, allocate memory for it and bind + them together, all in one call, you can use function vmaCreateBuffer(), + vmaCreateImage(). -VkResult VmaAllocator_T::Map(VmaAllocation hAllocation, void** ppData) -{ - if(hAllocation->CanBecomeLost()) - { - return VK_ERROR_MEMORY_MAP_FAILED; - } +When using 3. or 4., the library internally queries Vulkan for memory types +supported for that buffer or image (function `vkGetBufferMemoryRequirements()`) +and uses only one of these types. - switch(hAllocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); - char *pBytes = VMA_NULL; - VkResult res = pBlock->Map(this, 1, (void**)&pBytes); - if(res == VK_SUCCESS) - { - *ppData = pBytes + (ptrdiff_t)hAllocation->GetOffset(); - hAllocation->BlockAllocMap(); - } - return res; - } - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - return hAllocation->DedicatedAllocMap(this, ppData); - default: - VMA_ASSERT(0); - return VK_ERROR_MEMORY_MAP_FAILED; - } -} +If no memory type can be found that meets all the requirements, these functions +return `VK_ERROR_FEATURE_NOT_PRESENT`. -void VmaAllocator_T::Unmap(VmaAllocation hAllocation) -{ - switch(hAllocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); - hAllocation->BlockAllocUnmap(); - pBlock->Unmap(this, 1); - } - break; - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - hAllocation->DedicatedAllocUnmap(this); - break; - default: - VMA_ASSERT(0); - } -} +You can leave VmaAllocationCreateInfo structure completely filled with zeros. +It means no requirements are specified for memory type. +It is valid, although not very useful. -VkResult VmaAllocator_T::BindBufferMemory( - VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkBuffer hBuffer, - const void* pNext) +\section choosing_memory_type_usage Usage + +The easiest way to specify memory requirements is to fill member +VmaAllocationCreateInfo::usage using one of the values of enum #VmaMemoryUsage. +It defines high level, common usage types. +Since version 3 of the library, it is recommended to use #VMA_MEMORY_USAGE_AUTO to let it select best memory type for your resource automatically. + +For example, if you want to create a uniform buffer that will be filled using +transfer only once or infrequently and then used for rendering every frame as a uniform buffer, you can +do it using following code. The buffer will most likely end up in a memory type with +`VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT` to be fast to access by the GPU device. + +\code +VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufferInfo.size = 65536; +bufferInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.usage = VMA_MEMORY_USAGE_AUTO; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +If you have a preference for putting the resource in GPU (device) memory or CPU (host) memory +on systems with discrete graphics card that have the memories separate, you can use +#VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE or #VMA_MEMORY_USAGE_AUTO_PREFER_HOST. + +When using `VMA_MEMORY_USAGE_AUTO*` while you want to map the allocated memory, +you also need to specify one of the host access flags: +#VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT. +This will help the library decide about preferred memory type to ensure it has `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` +so you can map it. + +For example, a staging buffer that will be filled via mapped pointer and then +used as a source of transfer to the buffer decribed previously can be created like this. +It will likely and up in a memory type that is `HOST_VISIBLE` and `HOST_COHERENT` +but not `HOST_CACHED` (meaning uncached, write-combined) and not `DEVICE_LOCAL` (meaning system RAM). + +\code +VkBufferCreateInfo stagingBufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +stagingBufferInfo.size = 65536; +stagingBufferInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + +VmaAllocationCreateInfo stagingAllocInfo = {}; +stagingAllocInfo.usage = VMA_MEMORY_USAGE_AUTO; +stagingAllocInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; + +VkBuffer stagingBuffer; +VmaAllocation stagingAllocation; +vmaCreateBuffer(allocator, &stagingBufferInfo, &stagingAllocInfo, &stagingBuffer, &stagingAllocation, nullptr); +\endcode + +For more examples of creating different kinds of resources, see chapter \ref usage_patterns. + +Usage values `VMA_MEMORY_USAGE_AUTO*` are legal to use only when the library knows +about the resource being created by having `VkBufferCreateInfo` / `VkImageCreateInfo` passed, +so they work with functions like: vmaCreateBuffer(), vmaCreateImage(), vmaFindMemoryTypeIndexForBufferInfo() etc. +If you allocate raw memory using function vmaAllocateMemory(), you have to use other means of selecting +memory type, as decribed below. + +\note +Old usage values (`VMA_MEMORY_USAGE_GPU_ONLY`, `VMA_MEMORY_USAGE_CPU_ONLY`, +`VMA_MEMORY_USAGE_CPU_TO_GPU`, `VMA_MEMORY_USAGE_GPU_TO_CPU`, `VMA_MEMORY_USAGE_CPU_COPY`) +are still available and work same way as in previous versions of the library +for backward compatibility, but they are not recommended. + +\section choosing_memory_type_required_preferred_flags Required and preferred flags + +You can specify more detailed requirements by filling members +VmaAllocationCreateInfo::requiredFlags and VmaAllocationCreateInfo::preferredFlags +with a combination of bits from enum `VkMemoryPropertyFlags`. For example, +if you want to create a buffer that will be persistently mapped on host (so it +must be `HOST_VISIBLE`) and preferably will also be `HOST_COHERENT` and `HOST_CACHED`, +use following code: + +\code +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; +allocInfo.preferredFlags = VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; +allocInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +A memory type is chosen that has all the required flags and as many preferred +flags set as possible. + +Value passed in VmaAllocationCreateInfo::usage is internally converted to a set of required and preferred flags, +plus some extra "magic" (heuristics). + +\section choosing_memory_type_explicit_memory_types Explicit memory types + +If you inspected memory types available on the physical device and you have +a preference for memory types that you want to use, you can fill member +VmaAllocationCreateInfo::memoryTypeBits. It is a bit mask, where each bit set +means that a memory type with that index is allowed to be used for the +allocation. Special value 0, just like `UINT32_MAX`, means there are no +restrictions to memory type index. + +Please note that this member is NOT just a memory type index. +Still you can use it to choose just one, specific memory type. +For example, if you already determined that your buffer should be created in +memory type 2, use following code: + +\code +uint32_t memoryTypeIndex = 2; + +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.memoryTypeBits = 1u << memoryTypeIndex; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + + +\section choosing_memory_type_custom_memory_pools Custom memory pools + +If you allocate from custom memory pool, all the ways of specifying memory +requirements described above are not applicable and the aforementioned members +of VmaAllocationCreateInfo structure are ignored. Memory type is selected +explicitly when creating the pool and then used to make all the allocations from +that pool. For further details, see \ref custom_memory_pools. + +\section choosing_memory_type_dedicated_allocations Dedicated allocations + +Memory for allocations is reserved out of larger block of `VkDeviceMemory` +allocated from Vulkan internally. That is the main feature of this whole library. +You can still request a separate memory block to be created for an allocation, +just like you would do in a trivial solution without using any allocator. +In that case, a buffer or image is always bound to that memory at offset 0. +This is called a "dedicated allocation". +You can explicitly request it by using flag #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +The library can also internally decide to use dedicated allocation in some cases, e.g.: + +- When the size of the allocation is large. +- When [VK_KHR_dedicated_allocation](@ref vk_khr_dedicated_allocation) extension is enabled + and it reports that dedicated allocation is required or recommended for the resource. +- When allocation of next big memory block fails due to not enough device memory, + but allocation with the exact requested size succeeds. + + +\page memory_mapping Memory mapping + +To "map memory" in Vulkan means to obtain a CPU pointer to `VkDeviceMemory`, +to be able to read from it or write to it in CPU code. +Mapping is possible only of memory allocated from a memory type that has +`VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` flag. +Functions `vkMapMemory()`, `vkUnmapMemory()` are designed for this purpose. +You can use them directly with memory allocated by this library, +but it is not recommended because of following issue: +Mapping the same `VkDeviceMemory` block multiple times is illegal - only one mapping at a time is allowed. +This includes mapping disjoint regions. Mapping is not reference-counted internally by Vulkan. +Because of this, Vulkan Memory Allocator provides following facilities: + +\note If you want to be able to map an allocation, you need to specify one of the flags +#VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT +in VmaAllocationCreateInfo::flags. These flags are required for an allocation to be mappable +when using #VMA_MEMORY_USAGE_AUTO or other `VMA_MEMORY_USAGE_AUTO*` enum values. +For other usage values they are ignored and every such allocation made in `HOST_VISIBLE` memory type is mappable, +but they can still be used for consistency. + +\section memory_mapping_mapping_functions Mapping functions + +The library provides following functions for mapping of a specific #VmaAllocation: vmaMapMemory(), vmaUnmapMemory(). +They are safer and more convenient to use than standard Vulkan functions. +You can map an allocation multiple times simultaneously - mapping is reference-counted internally. +You can also map different allocations simultaneously regardless of whether they use the same `VkDeviceMemory` block. +The way it is implemented is that the library always maps entire memory block, not just region of the allocation. +For further details, see description of vmaMapMemory() function. +Example: + +\code +// Having these objects initialized: +struct ConstantBuffer { - VkResult res = VK_SUCCESS; - switch(hAllocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - res = BindVulkanBuffer(hAllocation->GetMemory(), allocationLocalOffset, hBuffer, pNext); - break; - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); - VMA_ASSERT(pBlock && "Binding buffer to allocation that doesn't belong to any block. Is the allocation lost?"); - res = pBlock->BindBufferMemory(this, hAllocation, allocationLocalOffset, hBuffer, pNext); - break; - } - default: - VMA_ASSERT(0); - } - return res; -} + ... +}; +ConstantBuffer constantBufferData = ... + +VmaAllocator allocator = ... +VkBuffer constantBuffer = ... +VmaAllocation constantBufferAllocation = ... + +// You can map and fill your buffer using following code: + +void* mappedData; +vmaMapMemory(allocator, constantBufferAllocation, &mappedData); +memcpy(mappedData, &constantBufferData, sizeof(constantBufferData)); +vmaUnmapMemory(allocator, constantBufferAllocation); +\endcode + +When mapping, you may see a warning from Vulkan validation layer similar to this one: + +Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used. + +It happens because the library maps entire `VkDeviceMemory` block, where different +types of images and buffers may end up together, especially on GPUs with unified memory like Intel. +You can safely ignore it if you are sure you access only memory of the intended +object that you wanted to map. + + +\section memory_mapping_persistently_mapped_memory Persistently mapped memory + +Kepping your memory persistently mapped is generally OK in Vulkan. +You don't need to unmap it before using its data on the GPU. +The library provides a special feature designed for that: +Allocations made with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag set in +VmaAllocationCreateInfo::flags stay mapped all the time, +so you can just access CPU pointer to it any time +without a need to call any "map" or "unmap" function. +Example: + +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = sizeof(ConstantBuffer); +bufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | + VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + +// Buffer is already mapped. You can access its memory. +memcpy(allocInfo.pMappedData, &constantBufferData, sizeof(constantBufferData)); +\endcode + +\note #VMA_ALLOCATION_CREATE_MAPPED_BIT by itself doesn't guarantee that the allocation will end up +in a mappable memory type. +For this, you need to also specify #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT or +#VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT. +#VMA_ALLOCATION_CREATE_MAPPED_BIT only guarantees that if the memory is `HOST_VISIBLE`, the allocation will be mapped on creation. +For an example of how to make use of this fact, see section \ref usage_patterns_advanced_data_uploading. + +\section memory_mapping_cache_control Cache flush and invalidate + +Memory in Vulkan doesn't need to be unmapped before using it on GPU, +but unless a memory types has `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT` flag set, +you need to manually **invalidate** cache before reading of mapped pointer +and **flush** cache after writing to mapped pointer. +Map/unmap operations don't do that automatically. +Vulkan provides following functions for this purpose `vkFlushMappedMemoryRanges()`, +`vkInvalidateMappedMemoryRanges()`, but this library provides more convenient +functions that refer to given allocation object: vmaFlushAllocation(), +vmaInvalidateAllocation(), +or multiple objects at once: vmaFlushAllocations(), vmaInvalidateAllocations(). + +Regions of memory specified for flush/invalidate must be aligned to +`VkPhysicalDeviceLimits::nonCoherentAtomSize`. This is automatically ensured by the library. +In any memory type that is `HOST_VISIBLE` but not `HOST_COHERENT`, all allocations +within blocks are aligned to this value, so their offsets are always multiply of +`nonCoherentAtomSize` and two different allocations never share same "line" of this size. + +Also, Windows drivers from all 3 PC GPU vendors (AMD, Intel, NVIDIA) +currently provide `HOST_COHERENT` flag on all memory types that are +`HOST_VISIBLE`, so on PC you may not need to bother. + + +\page staying_within_budget Staying within budget + +When developing a graphics-intensive game or program, it is important to avoid allocating +more GPU memory than it is physically available. When the memory is over-committed, +various bad things can happen, depending on the specific GPU, graphics driver, and +operating system: + +- It may just work without any problems. +- The application may slow down because some memory blocks are moved to system RAM + and the GPU has to access them through PCI Express bus. +- A new allocation may take very long time to complete, even few seconds, and possibly + freeze entire system. +- The new allocation may fail with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +- It may even result in GPU crash (TDR), observed as `VK_ERROR_DEVICE_LOST` + returned somewhere later. + +\section staying_within_budget_querying_for_budget Querying for budget + +To query for current memory usage and available budget, use function vmaGetHeapBudgets(). +Returned structure #VmaBudget contains quantities expressed in bytes, per Vulkan memory heap. + +Please note that this function returns different information and works faster than +vmaCalculateStatistics(). vmaGetHeapBudgets() can be called every frame or even before every +allocation, while vmaCalculateStatistics() is intended to be used rarely, +only to obtain statistical information, e.g. for debugging purposes. + +It is recommended to use VK_EXT_memory_budget device extension to obtain information +about the budget from Vulkan device. VMA is able to use this extension automatically. +When not enabled, the allocator behaves same way, but then it estimates current usage +and available budget based on its internal information and Vulkan memory heap sizes, +which may be less precise. In order to use this extension: + +1. Make sure extensions VK_EXT_memory_budget and VK_KHR_get_physical_device_properties2 + required by it are available and enable them. Please note that the first is a device + extension and the second is instance extension! +2. Use flag #VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT when creating #VmaAllocator object. +3. Make sure to call vmaSetCurrentFrameIndex() every frame. Budget is queried from + Vulkan inside of it to avoid overhead of querying it with every allocation. + +\section staying_within_budget_controlling_memory_usage Controlling memory usage + +There are many ways in which you can try to stay within the budget. + +First, when making new allocation requires allocating a new memory block, the library +tries not to exceed the budget automatically. If a block with default recommended size +(e.g. 256 MB) would go over budget, a smaller block is allocated, possibly even +dedicated memory for just this resource. + +If the size of the requested resource plus current memory usage is more than the +budget, by default the library still tries to create it, leaving it to the Vulkan +implementation whether the allocation succeeds or fails. You can change this behavior +by using #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag. With it, the allocation is +not made if it would exceed the budget or if the budget is already exceeded. +VMA then tries to make the allocation from the next eligible Vulkan memory type. +The all of them fail, the call then fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +Example usage pattern may be to pass the #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag +when creating resources that are not essential for the application (e.g. the texture +of a specific object) and not to pass it when creating critically important resources +(e.g. render targets). + +On AMD graphics cards there is a custom vendor extension available: VK_AMD_memory_overallocation_behavior +that allows to control the behavior of the Vulkan implementation in out-of-memory cases - +whether it should fail with an error code or still allow the allocation. +Usage of this extension involves only passing extra structure on Vulkan device creation, +so it is out of scope of this library. + +Finally, you can also use #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT flag to make sure +a new allocation is created only when it fits inside one of the existing memory blocks. +If it would require to allocate a new block, if fails instead with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +This also ensures that the function call is very fast because it never goes to Vulkan +to obtain a new block. + +\note Creating \ref custom_memory_pools with VmaPoolCreateInfo::minBlockCount +set to more than 0 will currently try to allocate memory blocks without checking whether they +fit within budget. + + +\page resource_aliasing Resource aliasing (overlap) + +New explicit graphics APIs (Vulkan and Direct3D 12), thanks to manual memory +management, give an opportunity to alias (overlap) multiple resources in the +same region of memory - a feature not available in the old APIs (Direct3D 11, OpenGL). +It can be useful to save video memory, but it must be used with caution. + +For example, if you know the flow of your whole render frame in advance, you +are going to use some intermediate textures or buffers only during a small range of render passes, +and you know these ranges don't overlap in time, you can bind these resources to +the same place in memory, even if they have completely different parameters (width, height, format etc.). + +![Resource aliasing (overlap)](../gfx/Aliasing.png) + +Such scenario is possible using VMA, but you need to create your images manually. +Then you need to calculate parameters of an allocation to be made using formula: + +- allocation size = max(size of each image) +- allocation alignment = max(alignment of each image) +- allocation memoryTypeBits = bitwise AND(memoryTypeBits of each image) + +Following example shows two different images bound to the same place in memory, +allocated to fit largest of them. + +\code +// A 512x512 texture to be sampled. +VkImageCreateInfo img1CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +img1CreateInfo.imageType = VK_IMAGE_TYPE_2D; +img1CreateInfo.extent.width = 512; +img1CreateInfo.extent.height = 512; +img1CreateInfo.extent.depth = 1; +img1CreateInfo.mipLevels = 10; +img1CreateInfo.arrayLayers = 1; +img1CreateInfo.format = VK_FORMAT_R8G8B8A8_SRGB; +img1CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +img1CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +img1CreateInfo.usage = VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT; +img1CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; + +// A full screen texture to be used as color attachment. +VkImageCreateInfo img2CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +img2CreateInfo.imageType = VK_IMAGE_TYPE_2D; +img2CreateInfo.extent.width = 1920; +img2CreateInfo.extent.height = 1080; +img2CreateInfo.extent.depth = 1; +img2CreateInfo.mipLevels = 1; +img2CreateInfo.arrayLayers = 1; +img2CreateInfo.format = VK_FORMAT_R8G8B8A8_UNORM; +img2CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +img2CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +img2CreateInfo.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; +img2CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; + +VkImage img1; +res = vkCreateImage(device, &img1CreateInfo, nullptr, &img1); +VkImage img2; +res = vkCreateImage(device, &img2CreateInfo, nullptr, &img2); + +VkMemoryRequirements img1MemReq; +vkGetImageMemoryRequirements(device, img1, &img1MemReq); +VkMemoryRequirements img2MemReq; +vkGetImageMemoryRequirements(device, img2, &img2MemReq); + +VkMemoryRequirements finalMemReq = {}; +finalMemReq.size = std::max(img1MemReq.size, img2MemReq.size); +finalMemReq.alignment = std::max(img1MemReq.alignment, img2MemReq.alignment); +finalMemReq.memoryTypeBits = img1MemReq.memoryTypeBits & img2MemReq.memoryTypeBits; +// Validate if(finalMemReq.memoryTypeBits != 0) + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; -VkResult VmaAllocator_T::BindImageMemory( - VmaAllocation hAllocation, - VkDeviceSize allocationLocalOffset, - VkImage hImage, - const void* pNext) -{ - VkResult res = VK_SUCCESS; - switch(hAllocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - res = BindVulkanImage(hAllocation->GetMemory(), allocationLocalOffset, hImage, pNext); - break; - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); - VMA_ASSERT(pBlock && "Binding image to allocation that doesn't belong to any block. Is the allocation lost?"); - res = pBlock->BindImageMemory(this, hAllocation, allocationLocalOffset, hImage, pNext); - break; - } - default: - VMA_ASSERT(0); - } - return res; -} +VmaAllocation alloc; +res = vmaAllocateMemory(allocator, &finalMemReq, &allocCreateInfo, &alloc, nullptr); -VkResult VmaAllocator_T::FlushOrInvalidateAllocation( - VmaAllocation hAllocation, - VkDeviceSize offset, VkDeviceSize size, - VMA_CACHE_OPERATION op) -{ - VkResult res = VK_SUCCESS; +res = vmaBindImageMemory(allocator, alloc, img1); +res = vmaBindImageMemory(allocator, alloc, img2); - VkMappedMemoryRange memRange = {}; - if(GetFlushOrInvalidateRange(hAllocation, offset, size, memRange)) - { - switch(op) - { - case VMA_CACHE_FLUSH: - res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, 1, &memRange); - break; - case VMA_CACHE_INVALIDATE: - res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, 1, &memRange); - break; - default: - VMA_ASSERT(0); - } - } - // else: Just ignore this call. - return res; -} +// You can use img1, img2 here, but not at the same time! -VkResult VmaAllocator_T::FlushOrInvalidateAllocations( - uint32_t allocationCount, - const VmaAllocation* allocations, - const VkDeviceSize* offsets, const VkDeviceSize* sizes, - VMA_CACHE_OPERATION op) -{ - typedef VmaStlAllocator RangeAllocator; - typedef VmaSmallVector RangeVector; - RangeVector ranges = RangeVector(RangeAllocator(GetAllocationCallbacks())); +vmaFreeMemory(allocator, alloc); +vkDestroyImage(allocator, img2, nullptr); +vkDestroyImage(allocator, img1, nullptr); +\endcode - for (const auto allocIndex : c10::irange(allocationCount)) { - const VmaAllocation alloc = allocations[allocIndex]; - const VkDeviceSize offset = offsets != VMA_NULL ? offsets[allocIndex] : 0; - const VkDeviceSize size = sizes != VMA_NULL ? sizes[allocIndex] : VK_WHOLE_SIZE; - VkMappedMemoryRange newRange; - if(GetFlushOrInvalidateRange(alloc, offset, size, newRange)) - { - ranges.push_back(newRange); - } - } +Remember that using resources that alias in memory requires proper synchronization. +You need to issue a memory barrier to make sure commands that use `img1` and `img2` +don't overlap on GPU timeline. +You also need to treat a resource after aliasing as uninitialized - containing garbage data. +For example, if you use `img1` and then want to use `img2`, you need to issue +an image memory barrier for `img2` with `oldLayout` = `VK_IMAGE_LAYOUT_UNDEFINED`. - VkResult res = VK_SUCCESS; - if(!ranges.empty()) - { - switch(op) - { - case VMA_CACHE_FLUSH: - res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); - break; - case VMA_CACHE_INVALIDATE: - res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); - break; - default: - VMA_ASSERT(0); - } - } - // else: Just ignore this call. - return res; -} +Additional considerations: -void VmaAllocator_T::FreeDedicatedMemory(const VmaAllocation allocation) -{ - VMA_ASSERT(allocation && allocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); +- Vulkan also allows to interpret contents of memory between aliasing resources consistently in some cases. +See chapter 11.8. "Memory Aliasing" of Vulkan specification or `VK_IMAGE_CREATE_ALIAS_BIT` flag. +- You can create more complex layout where different images and buffers are bound +at different offsets inside one large allocation. For example, one can imagine +a big texture used in some render passes, aliasing with a set of many small buffers +used between in some further passes. To bind a resource at non-zero offset in an allocation, +use vmaBindBufferMemory2() / vmaBindImageMemory2(). +- Before allocating memory for the resources you want to alias, check `memoryTypeBits` +returned in memory requirements of each resource to make sure the bits overlap. +Some GPUs may expose multiple memory types suitable e.g. only for buffers or +images with `COLOR_ATTACHMENT` usage, so the sets of memory types supported by your +resources may be disjoint. Aliasing them is not possible in that case. - const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); - { - VmaMutexLockWrite lock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); - AllocationVectorType* const pDedicatedAllocations = m_pDedicatedAllocations[memTypeIndex]; - VMA_ASSERT(pDedicatedAllocations); - bool success = VmaVectorRemoveSorted(*pDedicatedAllocations, allocation); - VMA_ASSERT(success); - } - VkDeviceMemory hMemory = allocation->GetMemory(); +\page custom_memory_pools Custom memory pools - /* - There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory - before vkFreeMemory. +A memory pool contains a number of `VkDeviceMemory` blocks. +The library automatically creates and manages default pool for each memory type available on the device. +Default memory pool automatically grows in size. +Size of allocated blocks is also variable and managed automatically. - if(allocation->GetMappedData() != VMA_NULL) - { - (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); - } - */ +You can create custom pool and allocate memory out of it. +It can be useful if you want to: - FreeVulkanMemory(memTypeIndex, allocation->GetSize(), hMemory); +- Keep certain kind of allocations separate from others. +- Enforce particular, fixed size of Vulkan memory blocks. +- Limit maximum amount of Vulkan memory allocated for that pool. +- Reserve minimum or fixed amount of Vulkan memory always preallocated for that pool. +- Use extra parameters for a set of your allocations that are available in #VmaPoolCreateInfo but not in + #VmaAllocationCreateInfo - e.g., custom minimum alignment, custom `pNext` chain. +- Perform defragmentation on a specific subset of your allocations. - VMA_DEBUG_LOG(" Freed DedicatedMemory MemoryTypeIndex=%u", memTypeIndex); -} +To use custom memory pools: -uint32_t VmaAllocator_T::CalculateGpuDefragmentationMemoryTypeBits() const -{ - VkBufferCreateInfo dummyBufCreateInfo; - VmaFillGpuDefragmentationBufferCreateInfo(dummyBufCreateInfo); +-# Fill VmaPoolCreateInfo structure. +-# Call vmaCreatePool() to obtain #VmaPool handle. +-# When making an allocation, set VmaAllocationCreateInfo::pool to this handle. + You don't need to specify any other parameters of this structure, like `usage`. - uint32_t memoryTypeBits = 0; +Example: - // Create buffer. - VkBuffer buf = VK_NULL_HANDLE; - VkResult res = (*GetVulkanFunctions().vkCreateBuffer)( - m_hDevice, &dummyBufCreateInfo, GetAllocationCallbacks(), &buf); - if(res == VK_SUCCESS) - { - // Query for supported memory types. - VkMemoryRequirements memReq; - (*GetVulkanFunctions().vkGetBufferMemoryRequirements)(m_hDevice, buf, &memReq); - memoryTypeBits = memReq.memoryTypeBits; +\code +// Find memoryTypeIndex for the pool. +VkBufferCreateInfo sampleBufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +sampleBufCreateInfo.size = 0x10000; // Doesn't matter. +sampleBufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - // Destroy buffer. - (*GetVulkanFunctions().vkDestroyBuffer)(m_hDevice, buf, GetAllocationCallbacks()); - } +VmaAllocationCreateInfo sampleAllocCreateInfo = {}; +sampleAllocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; - return memoryTypeBits; -} +uint32_t memTypeIndex; +VkResult res = vmaFindMemoryTypeIndexForBufferInfo(allocator, + &sampleBufCreateInfo, &sampleAllocCreateInfo, &memTypeIndex); +// Check res... -uint32_t VmaAllocator_T::CalculateGlobalMemoryTypeBits() const -{ - // Make sure memory information is already fetched. - VMA_ASSERT(GetMemoryTypeCount() > 0); +// Create a pool that can have at most 2 blocks, 128 MiB each. +VmaPoolCreateInfo poolCreateInfo = {}; +poolCreateInfo.memoryTypeIndex = memTypeIndex; +poolCreateInfo.blockSize = 128ull * 1024 * 1024; +poolCreateInfo.maxBlockCount = 2; - uint32_t memoryTypeBits = UINT32_MAX; +VmaPool pool; +res = vmaCreatePool(allocator, &poolCreateInfo, &pool); +// Check res... - if(!m_UseAmdDeviceCoherentMemory) - { - // Exclude memory types that have VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD. - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - if((m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) != 0) - { - memoryTypeBits &= ~(1u << memTypeIndex); - } - } - } +// Allocate a buffer out of it. +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = 1024; +bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return memoryTypeBits; -} +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.pool = pool; -bool VmaAllocator_T::GetFlushOrInvalidateRange( - VmaAllocation allocation, - VkDeviceSize offset, VkDeviceSize size, - VkMappedMemoryRange& outRange) const -{ - const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); - if(size > 0 && IsMemoryTypeNonCoherent(memTypeIndex)) - { - const VkDeviceSize nonCoherentAtomSize = m_PhysicalDeviceProperties.limits.nonCoherentAtomSize; - const VkDeviceSize allocationSize = allocation->GetSize(); - VMA_ASSERT(offset <= allocationSize); +VkBuffer buf; +VmaAllocation alloc; +res = vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, nullptr); +// Check res... +\endcode - outRange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; - outRange.pNext = VMA_NULL; - outRange.memory = allocation->GetMemory(); +You have to free all allocations made from this pool before destroying it. - switch(allocation->GetType()) - { - case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: - outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); - if(size == VK_WHOLE_SIZE) - { - outRange.size = allocationSize - outRange.offset; - } - else - { - VMA_ASSERT(offset + size <= allocationSize); - outRange.size = VMA_MIN( - VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize), - allocationSize - outRange.offset); - } - break; - case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: - { - // 1. Still within this allocation. - outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); - if(size == VK_WHOLE_SIZE) - { - size = allocationSize - offset; - } - else - { - VMA_ASSERT(offset + size <= allocationSize); - } - outRange.size = VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize); +\code +vmaDestroyBuffer(allocator, buf, alloc); +vmaDestroyPool(allocator, pool); +\endcode + +New versions of this library support creating dedicated allocations in custom pools. +It is supported only when VmaPoolCreateInfo::blockSize = 0. +To use this feature, set VmaAllocationCreateInfo::pool to the pointer to your custom pool and +VmaAllocationCreateInfo::flags to #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. + +\note Excessive use of custom pools is a common mistake when using this library. +Custom pools may be useful for special purposes - when you want to +keep certain type of resources separate e.g. to reserve minimum amount of memory +for them or limit maximum amount of memory they can occupy. For most +resources this is not needed and so it is not recommended to create #VmaPool +objects and allocations out of them. Allocating from the default pool is sufficient. + + +\section custom_memory_pools_MemTypeIndex Choosing memory type index + +When creating a pool, you must explicitly specify memory type index. +To find the one suitable for your buffers or images, you can use helper functions +vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo(). +You need to provide structures with example parameters of buffers or images +that you are going to create in that pool. - // 2. Adjust to whole block. - const VkDeviceSize allocationOffset = allocation->GetOffset(); - VMA_ASSERT(allocationOffset % nonCoherentAtomSize == 0); - const VkDeviceSize blockSize = allocation->GetBlock()->m_pMetadata->GetSize(); - outRange.offset += allocationOffset; - outRange.size = VMA_MIN(outRange.size, blockSize - outRange.offset); +\code +VkBufferCreateInfo exampleBufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +exampleBufCreateInfo.size = 1024; // Doesn't matter +exampleBufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - break; - } - default: - VMA_ASSERT(0); - } - return true; - } - return false; -} +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; -#if VMA_MEMORY_BUDGET +uint32_t memTypeIndex; +vmaFindMemoryTypeIndexForBufferInfo(allocator, &exampleBufCreateInfo, &allocCreateInfo, &memTypeIndex); -void VmaAllocator_T::UpdateVulkanBudget() -{ - VMA_ASSERT(m_UseExtMemoryBudget); +VmaPoolCreateInfo poolCreateInfo = {}; +poolCreateInfo.memoryTypeIndex = memTypeIndex; +// ... +\endcode - VkPhysicalDeviceMemoryProperties2KHR memProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2_KHR }; +When creating buffers/images allocated in that pool, provide following parameters: - VkPhysicalDeviceMemoryBudgetPropertiesEXT budgetProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT }; - VmaPnextChainPushFront(&memProps, &budgetProps); +- `VkBufferCreateInfo`: Prefer to pass same parameters as above. + Otherwise you risk creating resources in a memory type that is not suitable for them, which may result in undefined behavior. + Using different `VK_BUFFER_USAGE_` flags may work, but you shouldn't create images in a pool intended for buffers + or the other way around. +- VmaAllocationCreateInfo: You don't need to pass same parameters. Fill only `pool` member. + Other members are ignored anyway. - GetVulkanFunctions().vkGetPhysicalDeviceMemoryProperties2KHR(m_PhysicalDevice, &memProps); +\section linear_algorithm Linear allocation algorithm - { - VmaMutexLockWrite lockWrite(m_Budget.m_BudgetMutex, m_UseMutex); +Each Vulkan memory block managed by this library has accompanying metadata that +keeps track of used and unused regions. By default, the metadata structure and +algorithm tries to find best place for new allocations among free regions to +optimize memory usage. This way you can allocate and free objects in any order. - for (const auto heapIndex : c10::irange(GetMemoryHeapCount())) { - m_Budget.m_VulkanUsage[heapIndex] = budgetProps.heapUsage[heapIndex]; - m_Budget.m_VulkanBudget[heapIndex] = budgetProps.heapBudget[heapIndex]; - m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] = m_Budget.m_BlockBytes[heapIndex].load(); +![Default allocation algorithm](../gfx/Linear_allocator_1_algo_default.png) - // Some bugged drivers return the budget incorrectly, e.g. 0 or much bigger than heap size. - if(m_Budget.m_VulkanBudget[heapIndex] == 0) - { - m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. - } - else if(m_Budget.m_VulkanBudget[heapIndex] > m_MemProps.memoryHeaps[heapIndex].size) - { - m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size; - } - if(m_Budget.m_VulkanUsage[heapIndex] == 0 && m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] > 0) - { - m_Budget.m_VulkanUsage[heapIndex] = m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; - } - } - m_Budget.m_OperationsSinceBudgetFetch = 0; - } -} +Sometimes there is a need to use simpler, linear allocation algorithm. You can +create custom pool that uses such algorithm by adding flag +#VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT to VmaPoolCreateInfo::flags while creating +#VmaPool object. Then an alternative metadata management is used. It always +creates new allocations after last one and doesn't reuse free regions after +allocations freed in the middle. It results in better allocation performance and +less memory consumed by metadata. -#endif // #if VMA_MEMORY_BUDGET +![Linear allocation algorithm](../gfx/Linear_allocator_2_algo_linear.png) -void VmaAllocator_T::FillAllocation(const VmaAllocation hAllocation, uint8_t pattern) -{ - if(VMA_DEBUG_INITIALIZE_ALLOCATIONS && - !hAllocation->CanBecomeLost() && - (m_MemProps.memoryTypes[hAllocation->GetMemoryTypeIndex()].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) - { - void* pData = VMA_NULL; - VkResult res = Map(hAllocation, &pData); - if(res == VK_SUCCESS) - { - memset(pData, (int)pattern, (size_t)hAllocation->GetSize()); - FlushOrInvalidateAllocation(hAllocation, 0, VK_WHOLE_SIZE, VMA_CACHE_FLUSH); - Unmap(hAllocation); - } - else - { - VMA_ASSERT(0 && "VMA_DEBUG_INITIALIZE_ALLOCATIONS is enabled, but couldn't map memory to fill allocation."); - } - } -} +With this one flag, you can create a custom pool that can be used in many ways: +free-at-once, stack, double stack, and ring buffer. See below for details. +You don't need to specify explicitly which of these options you are going to use - it is detected automatically. -uint32_t VmaAllocator_T::GetGpuDefragmentationMemoryTypeBits() -{ - uint32_t memoryTypeBits = m_GpuDefragmentationMemoryTypeBits.load(); - if(memoryTypeBits == UINT32_MAX) - { - memoryTypeBits = CalculateGpuDefragmentationMemoryTypeBits(); - m_GpuDefragmentationMemoryTypeBits.store(memoryTypeBits); - } - return memoryTypeBits; -} +\subsection linear_algorithm_free_at_once Free-at-once -#if VMA_STATS_STRING_ENABLED +In a pool that uses linear algorithm, you still need to free all the allocations +individually, e.g. by using vmaFreeMemory() or vmaDestroyBuffer(). You can free +them in any order. New allocations are always made after last one - free space +in the middle is not reused. However, when you release all the allocation and +the pool becomes empty, allocation starts from the beginning again. This way you +can use linear algorithm to speed up creation of allocations that you are going +to release all at once. -void VmaAllocator_T::PrintDetailedMap(VmaJsonWriter& json) -{ - bool dedicatedAllocationsStarted = false; - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - VmaMutexLockRead dedicatedAllocationsLock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); - AllocationVectorType* const pDedicatedAllocVector = m_pDedicatedAllocations[memTypeIndex]; - VMA_ASSERT(pDedicatedAllocVector); - if(pDedicatedAllocVector->empty() == false) - { - if(dedicatedAllocationsStarted == false) - { - dedicatedAllocationsStarted = true; - json.WriteString("DedicatedAllocations"); - json.BeginObject(); - } +![Free-at-once](../gfx/Linear_allocator_3_free_at_once.png) - json.BeginString("Type "); - json.ContinueString(memTypeIndex); - json.EndString(); +This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount +value that allows multiple memory blocks. - json.BeginArray(); +\subsection linear_algorithm_stack Stack - for(size_t i = 0; i < pDedicatedAllocVector->size(); ++i) - { - json.BeginObject(true); - const VmaAllocation hAlloc = (*pDedicatedAllocVector)[i]; - hAlloc->PrintParameters(json); - json.EndObject(); - } +When you free an allocation that was created last, its space can be reused. +Thanks to this, if you always release allocations in the order opposite to their +creation (LIFO - Last In First Out), you can achieve behavior of a stack. - json.EndArray(); - } - } - if(dedicatedAllocationsStarted) - { - json.EndObject(); - } +![Stack](../gfx/Linear_allocator_4_stack.png) - { - bool allocationsStarted = false; - for (const auto memTypeIndex : c10::irange(GetMemoryTypeCount())) { - if(m_pBlockVectors[memTypeIndex]->IsEmpty() == false) - { - if(allocationsStarted == false) - { - allocationsStarted = true; - json.WriteString("DefaultPools"); - json.BeginObject(); - } +This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount +value that allows multiple memory blocks. - json.BeginString("Type "); - json.ContinueString(memTypeIndex); - json.EndString(); +\subsection linear_algorithm_double_stack Double stack - m_pBlockVectors[memTypeIndex]->PrintDetailedMap(json); - } - } - if(allocationsStarted) - { - json.EndObject(); - } - } +The space reserved by a custom pool with linear algorithm may be used by two +stacks: - // Custom pools - { - VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); - const size_t poolCount = m_Pools.size(); - if(poolCount > 0) - { - json.WriteString("Pools"); - json.BeginObject(); - for (const auto poolIndex : c10::irange(poolCount)) { - json.BeginString(); - json.ContinueString(m_Pools[poolIndex]->GetId()); - json.EndString(); +- First, default one, growing up from offset 0. +- Second, "upper" one, growing down from the end towards lower offsets. - m_Pools[poolIndex]->m_BlockVector.PrintDetailedMap(json); - } - json.EndObject(); - } - } -} +To make allocation from the upper stack, add flag #VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT +to VmaAllocationCreateInfo::flags. -#endif // #if VMA_STATS_STRING_ENABLED +![Double stack](../gfx/Linear_allocator_7_double_stack.png) -//////////////////////////////////////////////////////////////////////////////// -// Public interface +Double stack is available only in pools with one memory block - +VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( - const VmaAllocatorCreateInfo* pCreateInfo, - VmaAllocator* pAllocator) -{ - VMA_ASSERT(pCreateInfo && pAllocator); - VMA_ASSERT(pCreateInfo->vulkanApiVersion == 0 || - (VK_VERSION_MAJOR(pCreateInfo->vulkanApiVersion) == 1 && VK_VERSION_MINOR(pCreateInfo->vulkanApiVersion) <= 2)); - VMA_DEBUG_LOG("vmaCreateAllocator"); - *pAllocator = vma_new(pCreateInfo->pAllocationCallbacks, VmaAllocator_T)(pCreateInfo); - return (*pAllocator)->Init(pCreateInfo); -} +When the two stacks' ends meet so there is not enough space between them for a +new allocation, such allocation fails with usual +`VK_ERROR_OUT_OF_DEVICE_MEMORY` error. -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( - VmaAllocator allocator) -{ - if(allocator != VK_NULL_HANDLE) - { - VMA_DEBUG_LOG("vmaDestroyAllocator"); - VkAllocationCallbacks allocationCallbacks = allocator->m_AllocationCallbacks; - vma_delete(&allocationCallbacks, allocator); - } -} +\subsection linear_algorithm_ring_buffer Ring buffer -VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo(VmaAllocator allocator, VmaAllocatorInfo* pAllocatorInfo) -{ - VMA_ASSERT(allocator && pAllocatorInfo); - pAllocatorInfo->instance = allocator->m_hInstance; - pAllocatorInfo->physicalDevice = allocator->GetPhysicalDevice(); - pAllocatorInfo->device = allocator->m_hDevice; -} +When you free some allocations from the beginning and there is not enough free space +for a new one at the end of a pool, allocator's "cursor" wraps around to the +beginning and starts allocation there. Thanks to this, if you always release +allocations in the same order as you created them (FIFO - First In First Out), +you can achieve behavior of a ring buffer / queue. -VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( - VmaAllocator allocator, - const VkPhysicalDeviceProperties **ppPhysicalDeviceProperties) -{ - VMA_ASSERT(allocator && ppPhysicalDeviceProperties); - *ppPhysicalDeviceProperties = &allocator->m_PhysicalDeviceProperties; -} +![Ring buffer](../gfx/Linear_allocator_5_ring_buffer.png) -VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( - VmaAllocator allocator, - const VkPhysicalDeviceMemoryProperties** ppPhysicalDeviceMemoryProperties) -{ - VMA_ASSERT(allocator && ppPhysicalDeviceMemoryProperties); - *ppPhysicalDeviceMemoryProperties = &allocator->m_MemProps; -} +Ring buffer is available only in pools with one memory block - +VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. -VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( - VmaAllocator allocator, - uint32_t memoryTypeIndex, - VkMemoryPropertyFlags* pFlags) -{ - VMA_ASSERT(allocator && pFlags); - VMA_ASSERT(memoryTypeIndex < allocator->GetMemoryTypeCount()); - *pFlags = allocator->m_MemProps.memoryTypes[memoryTypeIndex].propertyFlags; -} +\note \ref defragmentation is not supported in custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT. -VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( - VmaAllocator allocator, - uint32_t frameIndex) -{ - VMA_ASSERT(allocator); - VMA_ASSERT(frameIndex != VMA_FRAME_INDEX_LOST); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\page defragmentation Defragmentation + +Interleaved allocations and deallocations of many objects of varying size can +cause fragmentation over time, which can lead to a situation where the library is unable +to find a continuous range of free memory for a new allocation despite there is +enough free space, just scattered across many small free ranges between existing +allocations. + +To mitigate this problem, you can use defragmentation feature. +It doesn't happen automatically though and needs your cooperation, +because VMA is a low level library that only allocates memory. +It cannot recreate buffers and images in a new place as it doesn't remember the contents of `VkBufferCreateInfo` / `VkImageCreateInfo` structures. +It cannot copy their contents as it doesn't record any commands to a command buffer. + +Example: + +\code +VmaDefragmentationInfo defragInfo = {}; +defragInfo.pool = myPool; +defragInfo.flags = VMA_DEFRAGMENTATION_FLAG_ALGORITHM_FAST_BIT; - allocator->SetCurrentFrameIndex(frameIndex); -} +VmaDefragmentationContext defragCtx; +VkResult res = vmaBeginDefragmentation(allocator, &defragInfo, &defragCtx); +// Check res... -VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStats( - VmaAllocator allocator, - VmaStats* pStats) +for(;;) { - VMA_ASSERT(allocator && pStats); - VMA_DEBUG_GLOBAL_MUTEX_LOCK - allocator->CalculateStats(pStats); -} + VmaDefragmentationPassMoveInfo pass; + res = vmaBeginDefragmentationPass(allocator, defragCtx, &pass); + if(res == VK_SUCCESS) + break; + else if(res != VK_INCOMPLETE) + // Handle error... -VMA_CALL_PRE void VMA_CALL_POST vmaGetBudget( - VmaAllocator allocator, - VmaBudget* pBudget) -{ - VMA_ASSERT(allocator && pBudget); - VMA_DEBUG_GLOBAL_MUTEX_LOCK - allocator->GetBudget(pBudget, 0, allocator->GetMemoryHeapCount()); + for(uint32_t i = 0; i < pass.moveCount; ++i) + { + // Inspect pass.pMoves[i].srcAllocation, identify what buffer/image it represents. + VmaAllocationInfo allocInfo; + vmaGetAllocationInfo(allocator, pMoves[i].srcAllocation, &allocInfo); + MyEngineResourceData* resData = (MyEngineResourceData*)allocInfo.pUserData; + + // Recreate and bind this buffer/image at: pass.pMoves[i].dstMemory, pass.pMoves[i].dstOffset. + VkImageCreateInfo imgCreateInfo = ... + VkImage newImg; + res = vkCreateImage(device, &imgCreateInfo, nullptr, &newImg); + // Check res... + res = vmaBindImageMemory(allocator, pMoves[i].dstTmpAllocation, newImg); + // Check res... + + // Issue a vkCmdCopyBuffer/vkCmdCopyImage to copy its content to the new place. + vkCmdCopyImage(cmdBuf, resData->img, ..., newImg, ...); + } + + // Make sure the copy commands finished executing. + vkWaitForFences(...); + + // Destroy old buffers/images bound with pass.pMoves[i].srcAllocation. + for(uint32_t i = 0; i < pass.moveCount; ++i) + { + // ... + vkDestroyImage(device, resData->img, nullptr); + } + + // Update appropriate descriptors to point to the new places... + + res = vmaEndDefragmentationPass(allocator, defragCtx, &pass); + if(res == VK_SUCCESS) + break; + else if(res != VK_INCOMPLETE) + // Handle error... } -#if VMA_STATS_STRING_ENABLED +vmaEndDefragmentation(allocator, defragCtx, nullptr); +\endcode -VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( - VmaAllocator allocator, - char** ppStatsString, - VkBool32 detailedMap) -{ - VMA_ASSERT(allocator && ppStatsString); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Although functions like vmaCreateBuffer(), vmaCreateImage(), vmaDestroyBuffer(), vmaDestroyImage() +create/destroy an allocation and a buffer/image at once, these are just a shortcut for +creating the resource, allocating memory, and binding them together. +Defragmentation works on memory allocations only. You must handle the rest manually. +Defragmentation is an iterative process that should repreat "passes" as long as related functions +return `VK_INCOMPLETE` not `VK_SUCCESS`. +In each pass: + +1. vmaBeginDefragmentationPass() function call: + - Calculates and returns the list of allocations to be moved in this pass. + Note this can be a time-consuming process. + - Reserves destination memory for them by creating temporary destination allocations + that you can query for their `VkDeviceMemory` + offset using vmaGetAllocationInfo(). +2. Inside the pass, **you should**: + - Inspect the returned list of allocations to be moved. + - Create new buffers/images and bind them at the returned destination temporary allocations. + - Copy data from source to destination resources if necessary. + - Destroy the source buffers/images, but NOT their allocations. +3. vmaEndDefragmentationPass() function call: + - Frees the source memory reserved for the allocations that are moved. + - Modifies source #VmaAllocation objects that are moved to point to the destination reserved memory. + - Frees `VkDeviceMemory` blocks that became empty. + +Unlike in previous iterations of the defragmentation API, there is no list of "movable" allocations passed as a parameter. +Defragmentation algorithm tries to move all suitable allocations. +You can, however, refuse to move some of them inside a defragmentation pass, by setting +`pass.pMoves[i].operation` to #VMA_DEFRAGMENTATION_MOVE_OPERATION_IGNORE. +This is not recommended and may result in suboptimal packing of the allocations after defragmentation. +If you cannot ensure any allocation can be moved, it is better to keep movable allocations separate in a custom pool. + +Inside a pass, for each allocation that should be moved: + +- You should copy its data from the source to the destination place by calling e.g. `vkCmdCopyBuffer()`, `vkCmdCopyImage()`. + - You need to make sure these commands finished executing before destroying the source buffers/images and before calling vmaEndDefragmentationPass(). +- If a resource doesn't contain any meaningful data, e.g. it is a transient color attachment image to be cleared, + filled, and used temporarily in each rendering frame, you can just recreate this image + without copying its data. +- If the resource is in `HOST_VISIBLE` and `HOST_CACHED` memory, you can copy its data on the CPU + using `memcpy()`. +- If you cannot move the allocation, you can set `pass.pMoves[i].operation` to #VMA_DEFRAGMENTATION_MOVE_OPERATION_IGNORE. + This will cancel the move. + - vmaEndDefragmentationPass() will then free the destination memory + not the source memory of the allocation, leaving it unchanged. +- If you decide the allocation is unimportant and can be destroyed instead of moved (e.g. it wasn't used for long time), + you can set `pass.pMoves[i].operation` to #VMA_DEFRAGMENTATION_MOVE_OPERATION_DESTROY. + - vmaEndDefragmentationPass() will then free both source and destination memory, and will destroy the source #VmaAllocation object. + +You can defragment a specific custom pool by setting VmaDefragmentationInfo::pool +(like in the example above) or all the default pools by setting this member to null. + +Defragmentation is always performed in each pool separately. +Allocations are never moved between different Vulkan memory types. +The size of the destination memory reserved for a moved allocation is the same as the original one. +Alignment of an allocation as it was determined using `vkGetBufferMemoryRequirements()` etc. is also respected after defragmentation. +Buffers/images should be recreated with the same `VkBufferCreateInfo` / `VkImageCreateInfo` parameters as the original ones. + +You can perform the defragmentation incrementally to limit the number of allocations and bytes to be moved +in each pass, e.g. to call it in sync with render frames and not to experience too big hitches. +See members: VmaDefragmentationInfo::maxBytesPerPass, VmaDefragmentationInfo::maxAllocationsPerPass. + +It is also safe to perform the defragmentation asynchronously to render frames and other Vulkan and VMA +usage, possibly from multiple threads, with the exception that allocations +returned in VmaDefragmentationPassMoveInfo::pMoves shouldn't be destroyed until the defragmentation pass is ended. + +Mapping is preserved on allocations that are moved during defragmentation. +Whether through #VMA_ALLOCATION_CREATE_MAPPED_BIT or vmaMapMemory(), the allocations +are mapped at their new place. Of course, pointer to the mapped data changes, so it needs to be queried +using VmaAllocationInfo::pMappedData. + +\note Defragmentation is not supported in custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT. - VmaStringBuilder sb(allocator); - { - VmaJsonWriter json(allocator->GetAllocationCallbacks(), sb); - json.BeginObject(); - VmaBudget budget[VK_MAX_MEMORY_HEAPS]; - allocator->GetBudget(budget, 0, allocator->GetMemoryHeapCount()); +\page statistics Statistics - VmaStats stats; - allocator->CalculateStats(&stats); +This library contains several functions that return information about its internal state, +especially the amount of memory allocated from Vulkan. - json.WriteString("Total"); - VmaPrintStatInfo(json, stats.total); +\section statistics_numeric_statistics Numeric statistics - for(uint32_t heapIndex = 0; heapIndex < allocator->GetMemoryHeapCount(); ++heapIndex) - { - json.BeginString("Heap "); - json.ContinueString(heapIndex); - json.EndString(); - json.BeginObject(); +If you need to obtain basic statistics about memory usage per heap, together with current budget, +you can call function vmaGetHeapBudgets() and inspect structure #VmaBudget. +This is useful to keep track of memory usage and stay withing budget +(see also \ref staying_within_budget). +Example: - json.WriteString("Size"); - json.WriteNumber(allocator->m_MemProps.memoryHeaps[heapIndex].size); +\code +uint32_t heapIndex = ... + +VmaBudget budgets[VK_MAX_MEMORY_HEAPS]; +vmaGetHeapBudgets(allocator, budgets); + +printf("My heap currently has %u allocations taking %llu B,\n", + budgets[heapIndex].statistics.allocationCount, + budgets[heapIndex].statistics.allocationBytes); +printf("allocated out of %u Vulkan device memory blocks taking %llu B,\n", + budgets[heapIndex].statistics.blockCount, + budgets[heapIndex].statistics.blockBytes); +printf("Vulkan reports total usage %llu B with budget %llu B.\n", + budgets[heapIndex].usage, + budgets[heapIndex].budget); +\endcode - json.WriteString("Flags"); - json.BeginArray(true); - if((allocator->m_MemProps.memoryHeaps[heapIndex].flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) != 0) - { - json.WriteString("DEVICE_LOCAL"); - } - json.EndArray(); +You can query for more detailed statistics per memory heap, type, and totals, +including minimum and maximum allocation size and unused range size, +by calling function vmaCalculateStatistics() and inspecting structure #VmaTotalStatistics. +This function is slower though, as it has to traverse all the internal data structures, +so it should be used only for debugging purposes. - json.WriteString("Budget"); - json.BeginObject(); - { - json.WriteString("BlockBytes"); - json.WriteNumber(budget[heapIndex].blockBytes); - json.WriteString("AllocationBytes"); - json.WriteNumber(budget[heapIndex].allocationBytes); - json.WriteString("Usage"); - json.WriteNumber(budget[heapIndex].usage); - json.WriteString("Budget"); - json.WriteNumber(budget[heapIndex].budget); - } - json.EndObject(); +You can query for statistics of a custom pool using function vmaGetPoolStatistics() +or vmaCalculatePoolStatistics(). - if(stats.memoryHeap[heapIndex].blockCount > 0) - { - json.WriteString("Stats"); - VmaPrintStatInfo(json, stats.memoryHeap[heapIndex]); - } +You can query for information about a specific allocation using function vmaGetAllocationInfo(). +It fill structure #VmaAllocationInfo. - for(uint32_t typeIndex = 0; typeIndex < allocator->GetMemoryTypeCount(); ++typeIndex) - { - if(allocator->MemoryTypeIndexToHeapIndex(typeIndex) == heapIndex) - { - json.BeginString("Type "); - json.ContinueString(typeIndex); - json.EndString(); +\section statistics_json_dump JSON dump - json.BeginObject(); +You can dump internal state of the allocator to a string in JSON format using function vmaBuildStatsString(). +The result is guaranteed to be correct JSON. +It uses ANSI encoding. +Any strings provided by user (see [Allocation names](@ref allocation_names)) +are copied as-is and properly escaped for JSON, so if they use UTF-8, ISO-8859-2 or any other encoding, +this JSON string can be treated as using this encoding. +It must be freed using function vmaFreeStatsString(). - json.WriteString("Flags"); - json.BeginArray(true); - VkMemoryPropertyFlags flags = allocator->m_MemProps.memoryTypes[typeIndex].propertyFlags; - if((flags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) != 0) - { - json.WriteString("DEVICE_LOCAL"); - } - if((flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) - { - json.WriteString("HOST_VISIBLE"); - } - if((flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT) != 0) - { - json.WriteString("HOST_COHERENT"); - } - if((flags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT) != 0) - { - json.WriteString("HOST_CACHED"); - } - if((flags & VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT) != 0) - { - json.WriteString("LAZILY_ALLOCATED"); - } -#if VMA_VULKAN_VERSION >= 1001000 - if((flags & VK_MEMORY_PROPERTY_PROTECTED_BIT) != 0) - { - json.WriteString("PROTECTED"); - } -#endif // #if VMA_VULKAN_VERSION >= 1001000 -#if VK_AMD_device_coherent_memory - if((flags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) != 0) - { - json.WriteString("DEVICE_COHERENT"); - } - if((flags & VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY) != 0) - { - json.WriteString("DEVICE_UNCACHED"); - } -#endif // #if VK_AMD_device_coherent_memory - json.EndArray(); +The format of this JSON string is not part of official documentation of the library, +but it will not change in backward-incompatible way without increasing library major version number +and appropriate mention in changelog. - if(stats.memoryType[typeIndex].blockCount > 0) - { - json.WriteString("Stats"); - VmaPrintStatInfo(json, stats.memoryType[typeIndex]); - } +The JSON string contains all the data that can be obtained using vmaCalculateStatistics(). +It can also contain detailed map of allocated memory blocks and their regions - +free and occupied by allocations. +This allows e.g. to visualize the memory or assess fragmentation. - json.EndObject(); - } - } - json.EndObject(); - } - if(detailedMap == VK_TRUE) - { - allocator->PrintDetailedMap(json); - } +\page allocation_annotation Allocation names and user data - json.EndObject(); - } +\section allocation_user_data Allocation user data - const size_t len = sb.GetLength(); - char* const pChars = vma_new_array(allocator, char, len + 1); - if(len > 0) - { - memcpy(pChars, sb.GetData(), len); - } - pChars[len] = '\0'; - *ppStatsString = pChars; -} +You can annotate allocations with your own information, e.g. for debugging purposes. +To do that, fill VmaAllocationCreateInfo::pUserData field when creating +an allocation. It is an opaque `void*` pointer. You can use it e.g. as a pointer, +some handle, index, key, ordinal number or any other value that would associate +the allocation with your custom metadata. +It it useful to identify appropriate data structures in your engine given #VmaAllocation, +e.g. when doing \ref defragmentation. -VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( - VmaAllocator allocator, - char* pStatsString) -{ - if(pStatsString != VMA_NULL) - { - VMA_ASSERT(allocator); - size_t len = strlen(pStatsString); - vma_delete_array(allocator, pStatsString, len + 1); - } -} +\code +VkBufferCreateInfo bufCreateInfo = ... -#endif // #if VMA_STATS_STRING_ENABLED +MyBufferMetadata* pMetadata = CreateBufferMetadata(); -/* -This function is not protected by any mutex because it just reads immutable data. -*/ -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( - VmaAllocator allocator, - uint32_t memoryTypeBits, - const VmaAllocationCreateInfo* pAllocationCreateInfo, - uint32_t* pMemoryTypeIndex) -{ - VMA_ASSERT(allocator != VK_NULL_HANDLE); - VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); - VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.pUserData = pMetadata; - memoryTypeBits &= allocator->GetGlobalMemoryTypeBits(); +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buffer, &allocation, nullptr); +\endcode - if(pAllocationCreateInfo->memoryTypeBits != 0) - { - memoryTypeBits &= pAllocationCreateInfo->memoryTypeBits; - } +The pointer may be later retrieved as VmaAllocationInfo::pUserData: - uint32_t requiredFlags = pAllocationCreateInfo->requiredFlags; - uint32_t preferredFlags = pAllocationCreateInfo->preferredFlags; - uint32_t notPreferredFlags = 0; +\code +VmaAllocationInfo allocInfo; +vmaGetAllocationInfo(allocator, allocation, &allocInfo); +MyBufferMetadata* pMetadata = (MyBufferMetadata*)allocInfo.pUserData; +\endcode - // Convert usage to requiredFlags and preferredFlags. - switch(pAllocationCreateInfo->usage) - { - case VMA_MEMORY_USAGE_UNKNOWN: - break; - case VMA_MEMORY_USAGE_GPU_ONLY: - if(!allocator->IsIntegratedGpu() || (preferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) - { - preferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - } - break; - case VMA_MEMORY_USAGE_CPU_ONLY: - requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - break; - case VMA_MEMORY_USAGE_CPU_TO_GPU: - requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - if(!allocator->IsIntegratedGpu() || (preferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) - { - preferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - } - break; - case VMA_MEMORY_USAGE_GPU_TO_CPU: - requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - preferredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - break; - case VMA_MEMORY_USAGE_CPU_COPY: - notPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - break; - case VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED: - requiredFlags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; - break; - default: - VMA_ASSERT(0); - break; - } +It can also be changed using function vmaSetAllocationUserData(). - // Avoid DEVICE_COHERENT unless explicitly requested. - if(((pAllocationCreateInfo->requiredFlags | pAllocationCreateInfo->preferredFlags) & - (VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY | VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY)) == 0) - { - notPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY; - } +Values of (non-zero) allocations' `pUserData` are printed in JSON report created by +vmaBuildStatsString() in hexadecimal form. - *pMemoryTypeIndex = UINT32_MAX; - uint32_t minCost = UINT32_MAX; - for(uint32_t memTypeIndex = 0, memTypeBit = 1; - memTypeIndex < allocator->GetMemoryTypeCount(); - ++memTypeIndex, memTypeBit <<= 1) - { - // This memory type is acceptable according to memoryTypeBits bitmask. - if((memTypeBit & memoryTypeBits) != 0) - { - const VkMemoryPropertyFlags currFlags = - allocator->m_MemProps.memoryTypes[memTypeIndex].propertyFlags; - // This memory type contains requiredFlags. - if((requiredFlags & ~currFlags) == 0) - { - // Calculate cost as number of bits from preferredFlags not present in this memory type. - uint32_t currCost = VmaCountBitsSet(preferredFlags & ~currFlags) + - VmaCountBitsSet(currFlags & notPreferredFlags); - // Remember memory type with lowest cost. - if(currCost < minCost) - { - *pMemoryTypeIndex = memTypeIndex; - if(currCost == 0) - { - return VK_SUCCESS; - } - minCost = currCost; - } - } - } - } - return (*pMemoryTypeIndex != UINT32_MAX) ? VK_SUCCESS : VK_ERROR_FEATURE_NOT_PRESENT; -} +\section allocation_names Allocation names -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( - VmaAllocator allocator, - const VkBufferCreateInfo* pBufferCreateInfo, - const VmaAllocationCreateInfo* pAllocationCreateInfo, - uint32_t* pMemoryTypeIndex) -{ - VMA_ASSERT(allocator != VK_NULL_HANDLE); - VMA_ASSERT(pBufferCreateInfo != VMA_NULL); - VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); - VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); +An allocation can also carry a null-terminated string, giving a name to the allocation. +To set it, call vmaSetAllocationName(). +The library creates internal copy of the string, so the pointer you pass doesn't need +to be valid for whole lifetime of the allocation. You can free it after the call. - const VkDevice hDev = allocator->m_hDevice; - VkBuffer hBuffer = VK_NULL_HANDLE; - VkResult res = allocator->GetVulkanFunctions().vkCreateBuffer( - hDev, pBufferCreateInfo, allocator->GetAllocationCallbacks(), &hBuffer); - if(res == VK_SUCCESS) - { - VkMemoryRequirements memReq = {}; - allocator->GetVulkanFunctions().vkGetBufferMemoryRequirements( - hDev, hBuffer, &memReq); +\code +std::string imageName = "Texture: "; +imageName += fileName; +vmaSetAllocationName(allocator, allocation, imageName.c_str()); +\endcode - res = vmaFindMemoryTypeIndex( - allocator, - memReq.memoryTypeBits, - pAllocationCreateInfo, - pMemoryTypeIndex); +The string can be later retrieved by inspecting VmaAllocationInfo::pName. +It is also printed in JSON report created by vmaBuildStatsString(). - allocator->GetVulkanFunctions().vkDestroyBuffer( - hDev, hBuffer, allocator->GetAllocationCallbacks()); - } - return res; -} +\note Setting string name to VMA allocation doesn't automatically set it to the Vulkan buffer or image created with it. +You must do it manually using an extension like VK_EXT_debug_utils, which is independent of this library. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( - VmaAllocator allocator, - const VkImageCreateInfo* pImageCreateInfo, - const VmaAllocationCreateInfo* pAllocationCreateInfo, - uint32_t* pMemoryTypeIndex) -{ - VMA_ASSERT(allocator != VK_NULL_HANDLE); - VMA_ASSERT(pImageCreateInfo != VMA_NULL); - VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); - VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); - const VkDevice hDev = allocator->m_hDevice; - VkImage hImage = VK_NULL_HANDLE; - VkResult res = allocator->GetVulkanFunctions().vkCreateImage( - hDev, pImageCreateInfo, allocator->GetAllocationCallbacks(), &hImage); - if(res == VK_SUCCESS) - { - VkMemoryRequirements memReq = {}; - allocator->GetVulkanFunctions().vkGetImageMemoryRequirements( - hDev, hImage, &memReq); +\page virtual_allocator Virtual allocator - res = vmaFindMemoryTypeIndex( - allocator, - memReq.memoryTypeBits, - pAllocationCreateInfo, - pMemoryTypeIndex); +As an extra feature, the core allocation algorithm of the library is exposed through a simple and convenient API of "virtual allocator". +It doesn't allocate any real GPU memory. It just keeps track of used and free regions of a "virtual block". +You can use it to allocate your own memory or other objects, even completely unrelated to Vulkan. +A common use case is sub-allocation of pieces of one large GPU buffer. - allocator->GetVulkanFunctions().vkDestroyImage( - hDev, hImage, allocator->GetAllocationCallbacks()); - } - return res; -} +\section virtual_allocator_creating_virtual_block Creating virtual block -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( - VmaAllocator allocator, - const VmaPoolCreateInfo* pCreateInfo, - VmaPool* pPool) -{ - VMA_ASSERT(allocator && pCreateInfo && pPool); +To use this functionality, there is no main "allocator" object. +You don't need to have #VmaAllocator object created. +All you need to do is to create a separate #VmaVirtualBlock object for each block of memory you want to be managed by the allocator: + +-# Fill in #VmaVirtualBlockCreateInfo structure. +-# Call vmaCreateVirtualBlock(). Get new #VmaVirtualBlock object. + +Example: + +\code +VmaVirtualBlockCreateInfo blockCreateInfo = {}; +blockCreateInfo.size = 1048576; // 1 MB + +VmaVirtualBlock block; +VkResult res = vmaCreateVirtualBlock(&blockCreateInfo, &block); +\endcode + +\section virtual_allocator_making_virtual_allocations Making virtual allocations + +#VmaVirtualBlock object contains internal data structure that keeps track of free and occupied regions +using the same code as the main Vulkan memory allocator. +Similarly to #VmaAllocation for standard GPU allocations, there is #VmaVirtualAllocation type +that represents an opaque handle to an allocation withing the virtual block. - VMA_DEBUG_LOG("vmaCreatePool"); +In order to make such allocation: - VMA_DEBUG_GLOBAL_MUTEX_LOCK +-# Fill in #VmaVirtualAllocationCreateInfo structure. +-# Call vmaVirtualAllocate(). Get new #VmaVirtualAllocation object that represents the allocation. + You can also receive `VkDeviceSize offset` that was assigned to the allocation. - VkResult res = allocator->CreatePool(pCreateInfo, pPool); +Example: -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordCreatePool(allocator->GetCurrentFrameIndex(), *pCreateInfo, *pPool); - } -#endif +\code +VmaVirtualAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.size = 4096; // 4 KB - return res; +VmaVirtualAllocation alloc; +VkDeviceSize offset; +res = vmaVirtualAllocate(block, &allocCreateInfo, &alloc, &offset); +if(res == VK_SUCCESS) +{ + // Use the 4 KB of your memory starting at offset. } - -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( - VmaAllocator allocator, - VmaPool pool) +else { - VMA_ASSERT(allocator); + // Allocation failed - no space for it could be found. Handle this error! +} +\endcode - if(pool == VK_NULL_HANDLE) - { - return; - } +\section virtual_allocator_deallocation Deallocation - VMA_DEBUG_LOG("vmaDestroyPool"); +When no longer needed, an allocation can be freed by calling vmaVirtualFree(). +You can only pass to this function an allocation that was previously returned by vmaVirtualAllocate() +called for the same #VmaVirtualBlock. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +When whole block is no longer needed, the block object can be released by calling vmaDestroyVirtualBlock(). +All allocations must be freed before the block is destroyed, which is checked internally by an assert. +However, if you don't want to call vmaVirtualFree() for each allocation, you can use vmaClearVirtualBlock() to free them all at once - +a feature not available in normal Vulkan memory allocator. Example: -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordDestroyPool(allocator->GetCurrentFrameIndex(), pool); - } -#endif +\code +vmaVirtualFree(block, alloc); +vmaDestroyVirtualBlock(block); +\endcode - allocator->DestroyPool(pool); -} +\section virtual_allocator_allocation_parameters Allocation parameters -VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStats( - VmaAllocator allocator, - VmaPool pool, - VmaPoolStats* pPoolStats) +You can attach a custom pointer to each allocation by using vmaSetVirtualAllocationUserData(). +Its default value is null. +It can be used to store any data that needs to be associated with that allocation - e.g. an index, a handle, or a pointer to some +larger data structure containing more information. Example: + +\code +struct CustomAllocData { - VMA_ASSERT(allocator && pool && pPoolStats); + std::string m_AllocName; +}; +CustomAllocData* allocData = new CustomAllocData(); +allocData->m_AllocName = "My allocation 1"; +vmaSetVirtualAllocationUserData(block, alloc, allocData); +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +The pointer can later be fetched, along with allocation offset and size, by passing the allocation handle to function +vmaGetVirtualAllocationInfo() and inspecting returned structure #VmaVirtualAllocationInfo. +If you allocated a new object to be used as the custom pointer, don't forget to delete that object before freeing the allocation! +Example: - allocator->GetPoolStats(pool, pPoolStats); -} +\code +VmaVirtualAllocationInfo allocInfo; +vmaGetVirtualAllocationInfo(block, alloc, &allocInfo); +delete (CustomAllocData*)allocInfo.pUserData; -VMA_CALL_PRE void VMA_CALL_POST vmaMakePoolAllocationsLost( - VmaAllocator allocator, - VmaPool pool, - size_t* pLostAllocationCount) -{ - VMA_ASSERT(allocator && pool); +vmaVirtualFree(block, alloc); +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section virtual_allocator_alignment_and_units Alignment and units -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordMakePoolAllocationsLost(allocator->GetCurrentFrameIndex(), pool); - } -#endif +It feels natural to express sizes and offsets in bytes. +If an offset of an allocation needs to be aligned to a multiply of some number (e.g. 4 bytes), you can fill optional member +VmaVirtualAllocationCreateInfo::alignment to request it. Example: - allocator->MakePoolAllocationsLost(pool, pLostAllocationCount); -} +\code +VmaVirtualAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.size = 4096; // 4 KB +allocCreateInfo.alignment = 4; // Returned offset must be a multiply of 4 B -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption(VmaAllocator allocator, VmaPool pool) -{ - VMA_ASSERT(allocator && pool); +VmaVirtualAllocation alloc; +res = vmaVirtualAllocate(block, &allocCreateInfo, &alloc, nullptr); +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Alignments of different allocations made from one block may vary. +However, if all alignments and sizes are always multiply of some size e.g. 4 B or `sizeof(MyDataStruct)`, +you can express all sizes, alignments, and offsets in multiples of that size instead of individual bytes. +It might be more convenient, but you need to make sure to use this new unit consistently in all the places: - VMA_DEBUG_LOG("vmaCheckPoolCorruption"); +- VmaVirtualBlockCreateInfo::size +- VmaVirtualAllocationCreateInfo::size and VmaVirtualAllocationCreateInfo::alignment +- Using offset returned by vmaVirtualAllocate() or in VmaVirtualAllocationInfo::offset - return allocator->CheckPoolCorruption(pool); -} +\section virtual_allocator_statistics Statistics -VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( - VmaAllocator allocator, - VmaPool pool, - const char** ppName) -{ - VMA_ASSERT(allocator && pool && ppName); +You can obtain statistics of a virtual block using vmaGetVirtualBlockStatistics() +(to get brief statistics that are fast to calculate) +or vmaCalculateVirtualBlockStatistics() (to get more detailed statistics, slower to calculate). +The functions fill structures #VmaStatistics, #VmaDetailedStatistics respectively - same as used by the normal Vulkan memory allocator. +Example: - VMA_DEBUG_LOG("vmaGetPoolName"); +\code +VmaStatistics stats; +vmaGetVirtualBlockStatistics(block, &stats); +printf("My virtual block has %llu bytes used by %u virtual allocations\n", + stats.allocationBytes, stats.allocationCount); +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +You can also request a full list of allocations and free regions as a string in JSON format by calling +vmaBuildVirtualBlockStatsString(). +Returned string must be later freed using vmaFreeVirtualBlockStatsString(). +The format of this string differs from the one returned by the main Vulkan allocator, but it is similar. - *ppName = pool->GetName(); -} +\section virtual_allocator_additional_considerations Additional considerations -VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( - VmaAllocator allocator, - VmaPool pool, - const char* pName) -{ - VMA_ASSERT(allocator && pool); +The "virtual allocator" functionality is implemented on a level of individual memory blocks. +Keeping track of a whole collection of blocks, allocating new ones when out of free space, +deleting empty ones, and deciding which one to try first for a new allocation must be implemented by the user. - VMA_DEBUG_LOG("vmaSetPoolName"); +Alternative allocation algorithms are supported, just like in custom pools of the real GPU memory. +See enum #VmaVirtualBlockCreateFlagBits to learn how to specify them (e.g. #VMA_VIRTUAL_BLOCK_CREATE_LINEAR_ALGORITHM_BIT). +You can find their description in chapter \ref custom_memory_pools. +Allocation strategies are also supported. +See enum #VmaVirtualAllocationCreateFlagBits to learn how to specify them (e.g. #VMA_VIRTUAL_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT). - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Following features are supported only by the allocator of the real GPU memory and not by virtual allocations: +buffer-image granularity, `VMA_DEBUG_MARGIN`, `VMA_MIN_ALIGNMENT`. - pool->SetName(pName); -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordSetPoolName(allocator->GetCurrentFrameIndex(), pool, pName); - } -#endif -} +\page debugging_memory_usage Debugging incorrect memory usage -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( - VmaAllocator allocator, - const VkMemoryRequirements* pVkMemoryRequirements, - const VmaAllocationCreateInfo* pCreateInfo, - VmaAllocation* pAllocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocation); +If you suspect a bug with memory usage, like usage of uninitialized memory or +memory being overwritten out of bounds of an allocation, +you can use debug features of this library to verify this. - VMA_DEBUG_LOG("vmaAllocateMemory"); +\section debugging_memory_usage_initialization Memory initialization - VMA_DEBUG_GLOBAL_MUTEX_LOCK +If you experience a bug with incorrect and nondeterministic data in your program and you suspect uninitialized memory to be used, +you can enable automatic memory initialization to verify this. +To do it, define macro `VMA_DEBUG_INITIALIZE_ALLOCATIONS` to 1. - VkResult result = allocator->AllocateMemory( - *pVkMemoryRequirements, - false, // requiresDedicatedAllocation - false, // prefersDedicatedAllocation - VK_NULL_HANDLE, // dedicatedBuffer - UINT32_MAX, // dedicatedBufferUsage - VK_NULL_HANDLE, // dedicatedImage - *pCreateInfo, - VMA_SUBALLOCATION_TYPE_UNKNOWN, - 1, // allocationCount - pAllocation); +\code +#define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 +#include "vk_mem_alloc.h" +\endcode -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordAllocateMemory( - allocator->GetCurrentFrameIndex(), - *pVkMemoryRequirements, - *pCreateInfo, - *pAllocation); - } -#endif +It makes memory of new allocations initialized to bit pattern `0xDCDCDCDC`. +Before an allocation is destroyed, its memory is filled with bit pattern `0xEFEFEFEF`. +Memory is automatically mapped and unmapped if necessary. - if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) - { - allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); - } +If you find these values while debugging your program, good chances are that you incorrectly +read Vulkan memory that is allocated but not initialized, or already freed, respectively. - return result; -} +Memory initialization works only with memory types that are `HOST_VISIBLE` and with allocations that can be mapped. +It works also with dedicated allocations. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( - VmaAllocator allocator, - const VkMemoryRequirements* pVkMemoryRequirements, - const VmaAllocationCreateInfo* pCreateInfo, - size_t allocationCount, - VmaAllocation* pAllocations, - VmaAllocationInfo* pAllocationInfo) -{ - if(allocationCount == 0) - { - return VK_SUCCESS; - } +\section debugging_memory_usage_margins Margins - VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocations); +By default, allocations are laid out in memory blocks next to each other if possible +(considering required alignment, `bufferImageGranularity`, and `nonCoherentAtomSize`). - VMA_DEBUG_LOG("vmaAllocateMemoryPages"); +![Allocations without margin](../gfx/Margins_1.png) - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Define macro `VMA_DEBUG_MARGIN` to some non-zero value (e.g. 16) to enforce specified +number of bytes as a margin after every allocation. - VkResult result = allocator->AllocateMemory( - *pVkMemoryRequirements, - false, // requiresDedicatedAllocation - false, // prefersDedicatedAllocation - VK_NULL_HANDLE, // dedicatedBuffer - UINT32_MAX, // dedicatedBufferUsage - VK_NULL_HANDLE, // dedicatedImage - *pCreateInfo, - VMA_SUBALLOCATION_TYPE_UNKNOWN, - allocationCount, - pAllocations); +\code +#define VMA_DEBUG_MARGIN 16 +#include "vk_mem_alloc.h" +\endcode -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordAllocateMemoryPages( - allocator->GetCurrentFrameIndex(), - *pVkMemoryRequirements, - *pCreateInfo, - (uint64_t)allocationCount, - pAllocations); - } -#endif +![Allocations with margin](../gfx/Margins_2.png) - if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) - { - for (const auto i : c10::irange(allocationCount)) { - allocator->GetAllocationInfo(pAllocations[i], pAllocationInfo + i); - } - } +If your bug goes away after enabling margins, it means it may be caused by memory +being overwritten outside of allocation boundaries. It is not 100% certain though. +Change in application behavior may also be caused by different order and distribution +of allocations across memory blocks after margins are applied. - return result; -} +Margins work with all types of memory. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( - VmaAllocator allocator, - VkBuffer buffer, - const VmaAllocationCreateInfo* pCreateInfo, - VmaAllocation* pAllocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && buffer != VK_NULL_HANDLE && pCreateInfo && pAllocation); +Margin is applied only to allocations made out of memory blocks and not to dedicated +allocations, which have their own memory block of specific size. +It is thus not applied to allocations made using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT flag +or those automatically decided to put into dedicated allocations, e.g. due to its +large size or recommended by VK_KHR_dedicated_allocation extension. - VMA_DEBUG_LOG("vmaAllocateMemoryForBuffer"); +Margins appear in [JSON dump](@ref statistics_json_dump) as part of free space. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Note that enabling margins increases memory usage and fragmentation. - VkMemoryRequirements vkMemReq = {}; - bool requiresDedicatedAllocation = false; - bool prefersDedicatedAllocation = false; - allocator->GetBufferMemoryRequirements(buffer, vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation); +Margins do not apply to \ref virtual_allocator. - VkResult result = allocator->AllocateMemory( - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - buffer, // dedicatedBuffer - UINT32_MAX, // dedicatedBufferUsage - VK_NULL_HANDLE, // dedicatedImage - *pCreateInfo, - VMA_SUBALLOCATION_TYPE_BUFFER, - 1, // allocationCount - pAllocation); +\section debugging_memory_usage_corruption_detection Corruption detection -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordAllocateMemoryForBuffer( - allocator->GetCurrentFrameIndex(), - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - *pCreateInfo, - *pAllocation); - } -#endif +You can additionally define macro `VMA_DEBUG_DETECT_CORRUPTION` to 1 to enable validation +of contents of the margins. - if(pAllocationInfo && result == VK_SUCCESS) - { - allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); - } +\code +#define VMA_DEBUG_MARGIN 16 +#define VMA_DEBUG_DETECT_CORRUPTION 1 +#include "vk_mem_alloc.h" +\endcode - return result; -} +When this feature is enabled, number of bytes specified as `VMA_DEBUG_MARGIN` +(it must be multiply of 4) after every allocation is filled with a magic number. +This idea is also know as "canary". +Memory is automatically mapped and unmapped if necessary. + +This number is validated automatically when the allocation is destroyed. +If it is not equal to the expected value, `VMA_ASSERT()` is executed. +It clearly means that either CPU or GPU overwritten the memory outside of boundaries of the allocation, +which indicates a serious bug. + +You can also explicitly request checking margins of all allocations in all memory blocks +that belong to specified memory types by using function vmaCheckCorruption(), +or in memory blocks that belong to specified custom pool, by using function +vmaCheckPoolCorruption(). + +Margin validation (corruption detection) works only for memory types that are +`HOST_VISIBLE` and `HOST_COHERENT`. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( - VmaAllocator allocator, - VkImage image, - const VmaAllocationCreateInfo* pCreateInfo, - VmaAllocation* pAllocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && image != VK_NULL_HANDLE && pCreateInfo && pAllocation); - VMA_DEBUG_LOG("vmaAllocateMemoryForImage"); +\page opengl_interop OpenGL Interop - VMA_DEBUG_GLOBAL_MUTEX_LOCK +VMA provides some features that help with interoperability with OpenGL. - VkMemoryRequirements vkMemReq = {}; - bool requiresDedicatedAllocation = false; - bool prefersDedicatedAllocation = false; - allocator->GetImageMemoryRequirements(image, vkMemReq, - requiresDedicatedAllocation, prefersDedicatedAllocation); +\section opengl_interop_exporting_memory Exporting memory - VkResult result = allocator->AllocateMemory( - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - VK_NULL_HANDLE, // dedicatedBuffer - UINT32_MAX, // dedicatedBufferUsage - image, // dedicatedImage - *pCreateInfo, - VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN, - 1, // allocationCount - pAllocation); +If you want to attach `VkExportMemoryAllocateInfoKHR` structure to `pNext` chain of memory allocations made by the library: -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordAllocateMemoryForImage( - allocator->GetCurrentFrameIndex(), - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - *pCreateInfo, - *pAllocation); - } -#endif +It is recommended to create \ref custom_memory_pools for such allocations. +Define and fill in your `VkExportMemoryAllocateInfoKHR` structure and attach it to VmaPoolCreateInfo::pMemoryAllocateNext +while creating the custom pool. +Please note that the structure must remain alive and unchanged for the whole lifetime of the #VmaPool, +not only while creating it, as no copy of the structure is made, +but its original pointer is used for each allocation instead. - if(pAllocationInfo && result == VK_SUCCESS) - { - allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); - } +If you want to export all memory allocated by the library from certain memory types, +also dedicated allocations or other allocations made from default pools, +an alternative solution is to fill in VmaAllocatorCreateInfo::pTypeExternalMemoryHandleTypes. +It should point to an array with `VkExternalMemoryHandleTypeFlagsKHR` to be automatically passed by the library +through `VkExportMemoryAllocateInfoKHR` on each allocation made from a specific memory type. +Please note that new versions of the library also support dedicated allocations created in custom pools. - return result; -} +You should not mix these two methods in a way that allows to apply both to the same memory type. +Otherwise, `VkExportMemoryAllocateInfoKHR` structure would be attached twice to the `pNext` chain of `VkMemoryAllocateInfo`. -VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( - VmaAllocator allocator, - VmaAllocation allocation) -{ - VMA_ASSERT(allocator); - if(allocation == VK_NULL_HANDLE) - { - return; - } +\section opengl_interop_custom_alignment Custom alignment - VMA_DEBUG_LOG("vmaFreeMemory"); +Buffers or images exported to a different API like OpenGL may require a different alignment, +higher than the one used by the library automatically, queried from functions like `vkGetBufferMemoryRequirements`. +To impose such alignment: - VMA_DEBUG_GLOBAL_MUTEX_LOCK +It is recommended to create \ref custom_memory_pools for such allocations. +Set VmaPoolCreateInfo::minAllocationAlignment member to the minimum alignment required for each allocation +to be made out of this pool. +The alignment actually used will be the maximum of this member and the alignment returned for the specific buffer or image +from a function like `vkGetBufferMemoryRequirements`, which is called by VMA automatically. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordFreeMemory( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +If you want to create a buffer with a specific minimum alignment out of default pools, +use special function vmaCreateBufferWithAlignment(), which takes additional parameter `minAlignment`. - allocator->FreeMemory( - 1, // allocationCount - &allocation); -} +Note the problem of alignment affects only resources placed inside bigger `VkDeviceMemory` blocks and not dedicated +allocations, as these, by definition, always have alignment = 0 because the resource is bound to the beginning of its dedicated block. +Contrary to Direct3D 12, Vulkan doesn't have a concept of alignment of the entire memory block passed on its allocation. -VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( - VmaAllocator allocator, - size_t allocationCount, - const VmaAllocation* pAllocations) -{ - if(allocationCount == 0) - { - return; - } - VMA_ASSERT(allocator); +\page usage_patterns Recommended usage patterns - VMA_DEBUG_LOG("vmaFreeMemoryPages"); +Vulkan gives great flexibility in memory allocation. +This chapter shows the most common patterns. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +See also slides from talk: +[Sawicki, Adam. Advanced Graphics Techniques Tutorial: Memory management in Vulkan and DX12. Game Developers Conference, 2018](https://www.gdcvault.com/play/1025458/Advanced-Graphics-Techniques-Tutorial-New) -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordFreeMemoryPages( - allocator->GetCurrentFrameIndex(), - (uint64_t)allocationCount, - pAllocations); - } -#endif - allocator->FreeMemory(allocationCount, pAllocations); -} +\section usage_patterns_gpu_only GPU-only resource -VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( - VmaAllocator allocator, - VmaAllocation allocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && allocation && pAllocationInfo); +When: +Any resources that you frequently write and read on GPU, +e.g. images used as color attachments (aka "render targets"), depth-stencil attachments, +images/buffers used as storage image/buffer (aka "Unordered Access View (UAV)"). - VMA_DEBUG_GLOBAL_MUTEX_LOCK +What to do: +Let the library select the optimal memory type, which will likely have `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordGetAllocationInfo( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +\code +VkImageCreateInfo imgCreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +imgCreateInfo.imageType = VK_IMAGE_TYPE_2D; +imgCreateInfo.extent.width = 3840; +imgCreateInfo.extent.height = 2160; +imgCreateInfo.extent.depth = 1; +imgCreateInfo.mipLevels = 1; +imgCreateInfo.arrayLayers = 1; +imgCreateInfo.format = VK_FORMAT_R8G8B8A8_UNORM; +imgCreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +imgCreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +imgCreateInfo.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; +imgCreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; - allocator->GetAllocationInfo(allocation, pAllocationInfo); -} +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; +allocCreateInfo.priority = 1.0f; -VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaTouchAllocation( - VmaAllocator allocator, - VmaAllocation allocation) -{ - VMA_ASSERT(allocator && allocation); +VkImage img; +VmaAllocation alloc; +vmaCreateImage(allocator, &imgCreateInfo, &allocCreateInfo, &img, &alloc, nullptr); +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +Also consider: +Consider creating them as dedicated allocations using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT, +especially if they are large or if you plan to destroy and recreate them with different sizes +e.g. when display resolution changes. +Prefer to create such resources first and all other GPU resources (like textures and vertex buffers) later. +When VK_EXT_memory_priority extension is enabled, it is also worth setting high priority to such allocation +to decrease chances to be evicted to system memory by the operating system. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordTouchAllocation( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +\section usage_patterns_staging_copy_upload Staging copy for upload - return allocator->TouchAllocation(allocation); -} +When: +A "staging" buffer than you want to map and fill from CPU code, then use as a source od transfer +to some GPU resource. -VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( - VmaAllocator allocator, - VmaAllocation allocation, - void* pUserData) -{ - VMA_ASSERT(allocator && allocation); +What to do: +Use flag #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT. +Let the library select the optimal memory type, which will always have `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT`. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = 65536; +bufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; - allocation->SetUserData(allocator, pUserData); +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | + VMA_ALLOCATION_CREATE_MAPPED_BIT; -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordSetAllocationUserData( - allocator->GetCurrentFrameIndex(), - allocation, - pUserData); - } -#endif -} +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); -VMA_CALL_PRE void VMA_CALL_POST vmaCreateLostAllocation( - VmaAllocator allocator, - VmaAllocation* pAllocation) -{ - VMA_ASSERT(allocator && pAllocation); +... - VMA_DEBUG_GLOBAL_MUTEX_LOCK; +memcpy(allocInfo.pMappedData, myData, myDataSize); +\endcode - allocator->CreateLostAllocation(pAllocation); +Also consider: +You can map the allocation using vmaMapMemory() or you can create it as persistenly mapped +using #VMA_ALLOCATION_CREATE_MAPPED_BIT, as in the example above. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordCreateLostAllocation( - allocator->GetCurrentFrameIndex(), - *pAllocation); - } -#endif -} -VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( - VmaAllocator allocator, - VmaAllocation allocation, - void** ppData) -{ - VMA_ASSERT(allocator && allocation && ppData); +\section usage_patterns_readback Readback - VMA_DEBUG_GLOBAL_MUTEX_LOCK +When: +Buffers for data written by or transferred from the GPU that you want to read back on the CPU, +e.g. results of some computations. - VkResult res = allocator->Map(allocation, ppData); +What to do: +Use flag #VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT. +Let the library select the optimal memory type, which will always have `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` +and `VK_MEMORY_PROPERTY_HOST_CACHED_BIT`. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordMapMemory( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = 65536; +bufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return res; -} +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT | + VMA_ALLOCATION_CREATE_MAPPED_BIT; -VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( - VmaAllocator allocator, - VmaAllocation allocation) -{ - VMA_ASSERT(allocator && allocation); +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +... -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordUnmapMemory( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +const float* downloadedData = (const float*)allocInfo.pMappedData; +\endcode - allocator->Unmap(allocation); -} -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation(VmaAllocator allocator, VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +\section usage_patterns_advanced_data_uploading Advanced data uploading + +For resources that you frequently write on CPU via mapped pointer and +freqnently read on GPU e.g. as a uniform buffer (also called "dynamic"), multiple options are possible: + +-# Easiest solution is to have one copy of the resource in `HOST_VISIBLE` memory, + even if it means system RAM (not `DEVICE_LOCAL`) on systems with a discrete graphics card, + and make the device reach out to that resource directly. + - Reads performed by the device will then go through PCI Express bus. + The performace of this access may be limited, but it may be fine depending on the size + of this resource (whether it is small enough to quickly end up in GPU cache) and the sparsity + of access. +-# On systems with unified memory (e.g. AMD APU or Intel integrated graphics, mobile chips), + a memory type may be available that is both `HOST_VISIBLE` (available for mapping) and `DEVICE_LOCAL` + (fast to access from the GPU). Then, it is likely the best choice for such type of resource. +-# Systems with a discrete graphics card and separate video memory may or may not expose + a memory type that is both `HOST_VISIBLE` and `DEVICE_LOCAL`, also known as Base Address Register (BAR). + If they do, it represents a piece of VRAM (or entire VRAM, if ReBAR is enabled in the motherboard BIOS) + that is available to CPU for mapping. + - Writes performed by the host to that memory go through PCI Express bus. + The performance of these writes may be limited, but it may be fine, especially on PCIe 4.0, + as long as rules of using uncached and write-combined memory are followed - only sequential writes and no reads. +-# Finally, you may need or prefer to create a separate copy of the resource in `DEVICE_LOCAL` memory, + a separate "staging" copy in `HOST_VISIBLE` memory and perform an explicit transfer command between them. + +Thankfully, VMA offers an aid to create and use such resources in the the way optimal +for the current Vulkan device. To help the library make the best choice, +use flag #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT together with +#VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT. +It will then prefer a memory type that is both `DEVICE_LOCAL` and `HOST_VISIBLE` (integrated memory or BAR), +but if no such memory type is available or allocation from it fails +(PC graphics cards have only 256 MB of BAR by default, unless ReBAR is supported and enabled in BIOS), +it will fall back to `DEVICE_LOCAL` memory for fast GPU access. +It is then up to you to detect that the allocation ended up in a memory type that is not `HOST_VISIBLE`, +so you need to create another "staging" allocation and perform explicit transfers. + +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = 65536; +bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | + VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT | + VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + +VkMemoryPropertyFlags memPropFlags; +vmaGetAllocationMemoryProperties(allocator, alloc, &memPropFlags); + +if(memPropFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) { - VMA_ASSERT(allocator && allocation); + // Allocation ended up in a mappable memory and is already mapped - write to it directly. - VMA_DEBUG_LOG("vmaFlushAllocation"); + // [Executed in runtime]: + memcpy(allocInfo.pMappedData, myData, myDataSize); +} +else +{ + // Allocation ended up in a non-mappable memory - need to transfer. + VkBufferCreateInfo stagingBufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; + stagingBufCreateInfo.size = 65536; + stagingBufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + + VmaAllocationCreateInfo stagingAllocCreateInfo = {}; + stagingAllocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; + stagingAllocCreateInfo.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | + VMA_ALLOCATION_CREATE_MAPPED_BIT; + + VkBuffer stagingBuf; + VmaAllocation stagingAlloc; + VmaAllocationInfo stagingAllocInfo; + vmaCreateBuffer(allocator, &stagingBufCreateInfo, &stagingAllocCreateInfo, + &stagingBuf, &stagingAlloc, stagingAllocInfo); + + // [Executed in runtime]: + memcpy(stagingAllocInfo.pMappedData, myData, myDataSize); + //vkCmdPipelineBarrier: VK_ACCESS_HOST_WRITE_BIT --> VK_ACCESS_TRANSFER_READ_BIT + VkBufferCopy bufCopy = { + 0, // srcOffset + 0, // dstOffset, + myDataSize); // size + vkCmdCopyBuffer(cmdBuf, stagingBuf, buf, 1, &bufCopy); +} +\endcode - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section usage_patterns_other_use_cases Other use cases + +Here are some other, less obvious use cases and their recommended settings: + +- An image that is used only as transfer source and destination, but it should stay on the device, + as it is used to temporarily store a copy of some texture, e.g. from the current to the next frame, + for temporal antialiasing or other temporal effects. + - Use `VkImageCreateInfo::usage = VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT` + - Use VmaAllocationCreateInfo::usage = #VMA_MEMORY_USAGE_AUTO +- An image that is used only as transfer source and destination, but it should be placed + in the system RAM despite it doesn't need to be mapped, because it serves as a "swap" copy to evict + least recently used textures from VRAM. + - Use `VkImageCreateInfo::usage = VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT` + - Use VmaAllocationCreateInfo::usage = #VMA_MEMORY_USAGE_AUTO_PREFER_HOST, + as VMA needs a hint here to differentiate from the previous case. +- A buffer that you want to map and write from the CPU, directly read from the GPU + (e.g. as a uniform or vertex buffer), but you have a clear preference to place it in device or + host memory due to its large size. + - Use `VkBufferCreateInfo::usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT` + - Use VmaAllocationCreateInfo::usage = #VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE or #VMA_MEMORY_USAGE_AUTO_PREFER_HOST + - Use VmaAllocationCreateInfo::flags = #VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT - const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_FLUSH); -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordFlushAllocation( - allocator->GetCurrentFrameIndex(), - allocation, offset, size); - } -#endif +\page configuration Configuration - return res; -} +Please check "CONFIGURATION SECTION" in the code to find macros that you can define +before each include of this file or change directly in this file to provide +your own implementation of basic facilities like assert, `min()` and `max()` functions, +mutex, atomic etc. +The library uses its own implementation of containers by default, but you can switch to using +STL containers instead. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation(VmaAllocator allocator, VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) -{ - VMA_ASSERT(allocator && allocation); +For example, define `VMA_ASSERT(expr)` before including the library to provide +custom implementation of the assertion, compatible with your project. +By default it is defined to standard C `assert(expr)` in `_DEBUG` configuration +and empty otherwise. - VMA_DEBUG_LOG("vmaInvalidateAllocation"); +\section config_Vulkan_functions Pointers to Vulkan functions - VMA_DEBUG_GLOBAL_MUTEX_LOCK +There are multiple ways to import pointers to Vulkan functions in the library. +In the simplest case you don't need to do anything. +If the compilation or linking of your program or the initialization of the #VmaAllocator +doesn't work for you, you can try to reconfigure it. - const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_INVALIDATE); +First, the allocator tries to fetch pointers to Vulkan functions linked statically, +like this: -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordInvalidateAllocation( - allocator->GetCurrentFrameIndex(), - allocation, offset, size); - } -#endif +\code +m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; +\endcode - return res; -} +If you want to disable this feature, set configuration macro: `#define VMA_STATIC_VULKAN_FUNCTIONS 0`. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( - VmaAllocator allocator, - uint32_t allocationCount, - const VmaAllocation* allocations, - const VkDeviceSize* offsets, - const VkDeviceSize* sizes) -{ - VMA_ASSERT(allocator); +Second, you can provide the pointers yourself by setting member VmaAllocatorCreateInfo::pVulkanFunctions. +You can fetch them e.g. using functions `vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` or +by using a helper library like [volk](https://github.com/zeux/volk). - if(allocationCount == 0) - { - return VK_SUCCESS; - } +Third, VMA tries to fetch remaining pointers that are still null by calling +`vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` on its own. +You need to only fill in VmaVulkanFunctions::vkGetInstanceProcAddr and VmaVulkanFunctions::vkGetDeviceProcAddr. +Other pointers will be fetched automatically. +If you want to disable this feature, set configuration macro: `#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0`. - VMA_ASSERT(allocations); +Finally, all the function pointers required by the library (considering selected +Vulkan version and enabled extensions) are checked with `VMA_ASSERT` if they are not null. - VMA_DEBUG_LOG("vmaFlushAllocations"); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section custom_memory_allocator Custom host memory allocator + +If you use custom allocator for CPU memory rather than default operator `new` +and `delete` from C++, you can make this library using your allocator as well +by filling optional member VmaAllocatorCreateInfo::pAllocationCallbacks. These +functions will be passed to Vulkan, as well as used by the library itself to +make any CPU-side allocations. + +\section allocation_callbacks Device memory allocation callbacks + +The library makes calls to `vkAllocateMemory()` and `vkFreeMemory()` internally. +You can setup callbacks to be informed about these calls, e.g. for the purpose +of gathering some statistics. To do it, fill optional member +VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. - const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_FLUSH); +\section heap_memory_limit Device heap memory limit -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - //TODO - } -#endif +When device memory of certain heap runs out of free space, new allocations may +fail (returning error code) or they may succeed, silently pushing some existing_ +memory blocks from GPU VRAM to system RAM (which degrades performance). This +behavior is implementation-dependent - it depends on GPU vendor and graphics +driver. - return res; -} +On AMD cards it can be controlled while creating Vulkan device object by using +VK_AMD_memory_overallocation_behavior extension, if available. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( - VmaAllocator allocator, - uint32_t allocationCount, - const VmaAllocation* allocations, - const VkDeviceSize* offsets, - const VkDeviceSize* sizes) -{ - VMA_ASSERT(allocator); +Alternatively, if you want to test how your program behaves with limited amount of Vulkan device +memory available without switching your graphics card to one that really has +smaller VRAM, you can use a feature of this library intended for this purpose. +To do it, fill optional member VmaAllocatorCreateInfo::pHeapSizeLimit. - if(allocationCount == 0) - { - return VK_SUCCESS; - } - VMA_ASSERT(allocations); - VMA_DEBUG_LOG("vmaInvalidateAllocations"); +\page vk_khr_dedicated_allocation VK_KHR_dedicated_allocation - VMA_DEBUG_GLOBAL_MUTEX_LOCK +VK_KHR_dedicated_allocation is a Vulkan extension which can be used to improve +performance on some GPUs. It augments Vulkan API with possibility to query +driver whether it prefers particular buffer or image to have its own, dedicated +allocation (separate `VkDeviceMemory` block) for better efficiency - to be able +to do some internal optimizations. The extension is supported by this library. +It will be used automatically when enabled. - const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_INVALIDATE); +It has been promoted to core Vulkan 1.1, so if you use eligible Vulkan version +and inform VMA about it by setting VmaAllocatorCreateInfo::vulkanApiVersion, +you are all set. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - //TODO - } -#endif +Otherwise, if you want to use it as an extension: - return res; -} +1 . When creating Vulkan device, check if following 2 device extensions are +supported (call `vkEnumerateDeviceExtensionProperties()`). +If yes, enable them (fill `VkDeviceCreateInfo::ppEnabledExtensionNames`). -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption(VmaAllocator allocator, uint32_t memoryTypeBits) -{ - VMA_ASSERT(allocator); +- VK_KHR_get_memory_requirements2 +- VK_KHR_dedicated_allocation - VMA_DEBUG_LOG("vmaCheckCorruption"); +If you enabled these extensions: - VMA_DEBUG_GLOBAL_MUTEX_LOCK +2 . Use #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag when creating +your #VmaAllocator to inform the library that you enabled required extensions +and you want the library to use them. - return allocator->CheckCorruption(memoryTypeBits); -} +\code +allocatorInfo.flags |= VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT; -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragment( - VmaAllocator allocator, - const VmaAllocation* pAllocations, - size_t allocationCount, - VkBool32* pAllocationsChanged, - const VmaDefragmentationInfo *pDefragmentationInfo, - VmaDefragmentationStats* pDefragmentationStats) -{ - // Deprecated interface, reimplemented using new one. +vmaCreateAllocator(&allocatorInfo, &allocator); +\endcode - VmaDefragmentationInfo2 info2 = {}; - info2.allocationCount = (uint32_t)allocationCount; - info2.pAllocations = pAllocations; - info2.pAllocationsChanged = pAllocationsChanged; - if(pDefragmentationInfo != VMA_NULL) - { - info2.maxCpuAllocationsToMove = pDefragmentationInfo->maxAllocationsToMove; - info2.maxCpuBytesToMove = pDefragmentationInfo->maxBytesToMove; - } - else - { - info2.maxCpuAllocationsToMove = UINT32_MAX; - info2.maxCpuBytesToMove = VK_WHOLE_SIZE; - } - // info2.flags, maxGpuAllocationsToMove, maxGpuBytesToMove, commandBuffer deliberately left zero. +That is all. The extension will be automatically used whenever you create a +buffer using vmaCreateBuffer() or image using vmaCreateImage(). - VmaDefragmentationContext ctx; - VkResult res = vmaDefragmentationBegin(allocator, &info2, pDefragmentationStats, &ctx); - if(res == VK_NOT_READY) - { - res = vmaDefragmentationEnd( allocator, ctx); - } - return res; -} +When using the extension together with Vulkan Validation Layer, you will receive +warnings like this: -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationBegin( - VmaAllocator allocator, - const VmaDefragmentationInfo2* pInfo, - VmaDefragmentationStats* pStats, - VmaDefragmentationContext *pContext) -{ - VMA_ASSERT(allocator && pInfo && pContext); +_vkBindBufferMemory(): Binding memory to buffer 0x33 but vkGetBufferMemoryRequirements() has not been called on that buffer._ - // Degenerate case: Nothing to defragment. - if(pInfo->allocationCount == 0 && pInfo->poolCount == 0) - { - return VK_SUCCESS; - } +It is OK, you should just ignore it. It happens because you use function +`vkGetBufferMemoryRequirements2KHR()` instead of standard +`vkGetBufferMemoryRequirements()`, while the validation layer seems to be +unaware of it. - VMA_ASSERT(pInfo->allocationCount == 0 || pInfo->pAllocations != VMA_NULL); - VMA_ASSERT(pInfo->poolCount == 0 || pInfo->pPools != VMA_NULL); - VMA_HEAVY_ASSERT(VmaValidatePointerArray(pInfo->allocationCount, pInfo->pAllocations)); - VMA_HEAVY_ASSERT(VmaValidatePointerArray(pInfo->poolCount, pInfo->pPools)); +To learn more about this extension, see: - VMA_DEBUG_LOG("vmaDefragmentationBegin"); +- [VK_KHR_dedicated_allocation in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap50.html#VK_KHR_dedicated_allocation) +- [VK_KHR_dedicated_allocation unofficial manual](http://asawicki.info/articles/VK_KHR_dedicated_allocation.php5) - VMA_DEBUG_GLOBAL_MUTEX_LOCK - VkResult res = allocator->DefragmentationBegin(*pInfo, pStats, pContext); -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordDefragmentationBegin( - allocator->GetCurrentFrameIndex(), *pInfo, *pContext); - } -#endif +\page vk_ext_memory_priority VK_EXT_memory_priority - return res; -} +VK_EXT_memory_priority is a device extension that allows to pass additional "priority" +value to Vulkan memory allocations that the implementation may use prefer certain +buffers and images that are critical for performance to stay in device-local memory +in cases when the memory is over-subscribed, while some others may be moved to the system memory. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationEnd( - VmaAllocator allocator, - VmaDefragmentationContext context) -{ - VMA_ASSERT(allocator); +VMA offers convenient usage of this extension. +If you enable it, you can pass "priority" parameter when creating allocations or custom pools +and the library automatically passes the value to Vulkan using this extension. - VMA_DEBUG_LOG("vmaDefragmentationEnd"); +If you want to use this extension in connection with VMA, follow these steps: - if(context != VK_NULL_HANDLE) - { - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section vk_ext_memory_priority_initialization Initialization -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordDefragmentationEnd( - allocator->GetCurrentFrameIndex(), context); - } -#endif +1) Call `vkEnumerateDeviceExtensionProperties` for the physical device. +Check if the extension is supported - if returned array of `VkExtensionProperties` contains "VK_EXT_memory_priority". - return allocator->DefragmentationEnd(context); - } - else - { - return VK_SUCCESS; - } -} +2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. +Attach additional structure `VkPhysicalDeviceMemoryPriorityFeaturesEXT` to `VkPhysicalDeviceFeatures2::pNext` to be returned. +Check if the device feature is really supported - check if `VkPhysicalDeviceMemoryPriorityFeaturesEXT::memoryPriority` is true. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( - VmaAllocator allocator, - VmaDefragmentationContext context, - VmaDefragmentationPassInfo* pInfo - ) -{ - VMA_ASSERT(allocator); - VMA_ASSERT(pInfo); +3) While creating device with `vkCreateDevice`, enable this extension - add "VK_EXT_memory_priority" +to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. - VMA_DEBUG_LOG("vmaBeginDefragmentationPass"); +4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. +Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. +Enable this device feature - attach additional structure `VkPhysicalDeviceMemoryPriorityFeaturesEXT` to +`VkPhysicalDeviceFeatures2::pNext` chain and set its member `memoryPriority` to `VK_TRUE`. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you +have enabled this extension and feature - add #VMA_ALLOCATOR_CREATE_EXT_MEMORY_PRIORITY_BIT +to VmaAllocatorCreateInfo::flags. - if(context == VK_NULL_HANDLE) - { - pInfo->moveCount = 0; - return VK_SUCCESS; - } +\section vk_ext_memory_priority_usage Usage - return allocator->DefragmentationPassBegin(pInfo, context); -} -VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( - VmaAllocator allocator, - VmaDefragmentationContext context) -{ - VMA_ASSERT(allocator); +When using this extension, you should initialize following member: - VMA_DEBUG_LOG("vmaEndDefragmentationPass"); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +- VmaAllocationCreateInfo::priority when creating a dedicated allocation with #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +- VmaPoolCreateInfo::priority when creating a custom pool. - if(context == VK_NULL_HANDLE) - return VK_SUCCESS; +It should be a floating-point value between `0.0f` and `1.0f`, where recommended default is `0.5f`. +Memory allocated with higher value can be treated by the Vulkan implementation as higher priority +and so it can have lower chances of being pushed out to system memory, experiencing degraded performance. - return allocator->DefragmentationPassEnd(context); -} +It might be a good idea to create performance-critical resources like color-attachment or depth-stencil images +as dedicated and set high priority to them. For example: -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( - VmaAllocator allocator, - VmaAllocation allocation, - VkBuffer buffer) -{ - VMA_ASSERT(allocator && allocation && buffer); +\code +VkImageCreateInfo imgCreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +imgCreateInfo.imageType = VK_IMAGE_TYPE_2D; +imgCreateInfo.extent.width = 3840; +imgCreateInfo.extent.height = 2160; +imgCreateInfo.extent.depth = 1; +imgCreateInfo.mipLevels = 1; +imgCreateInfo.arrayLayers = 1; +imgCreateInfo.format = VK_FORMAT_R8G8B8A8_UNORM; +imgCreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +imgCreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +imgCreateInfo.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; +imgCreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; - VMA_DEBUG_LOG("vmaBindBufferMemory"); +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_AUTO; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; +allocCreateInfo.priority = 1.0f; - VMA_DEBUG_GLOBAL_MUTEX_LOCK +VkImage img; +VmaAllocation alloc; +vmaCreateImage(allocator, &imgCreateInfo, &allocCreateInfo, &img, &alloc, nullptr); +\endcode - return allocator->BindBufferMemory(allocation, 0, buffer, VMA_NULL); -} +`priority` member is ignored in the following situations: -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( - VmaAllocator allocator, - VmaAllocation allocation, - VkDeviceSize allocationLocalOffset, - VkBuffer buffer, - const void* pNext) -{ - VMA_ASSERT(allocator && allocation && buffer); +- Allocations created in custom pools: They inherit the priority, along with all other allocation parameters + from the parametrs passed in #VmaPoolCreateInfo when the pool was created. +- Allocations created in default pools: They inherit the priority from the parameters + VMA used when creating default pools, which means `priority == 0.5f`. - VMA_DEBUG_LOG("vmaBindBufferMemory2"); - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\page vk_amd_device_coherent_memory VK_AMD_device_coherent_memory - return allocator->BindBufferMemory(allocation, allocationLocalOffset, buffer, pNext); -} +VK_AMD_device_coherent_memory is a device extension that enables access to +additional memory types with `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and +`VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flag. It is useful mostly for +allocation of buffers intended for writing "breadcrumb markers" in between passes +or draw calls, which in turn are useful for debugging GPU crash/hang/TDR cases. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( - VmaAllocator allocator, - VmaAllocation allocation, - VkImage image) -{ - VMA_ASSERT(allocator && allocation && image); +When the extension is available but has not been enabled, Vulkan physical device +still exposes those memory types, but their usage is forbidden. VMA automatically +takes care of that - it returns `VK_ERROR_FEATURE_NOT_PRESENT` when an attempt +to allocate memory of such type is made. - VMA_DEBUG_LOG("vmaBindImageMemory"); +If you want to use this extension in connection with VMA, follow these steps: - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section vk_amd_device_coherent_memory_initialization Initialization - return allocator->BindImageMemory(allocation, 0, image, VMA_NULL); -} +1) Call `vkEnumerateDeviceExtensionProperties` for the physical device. +Check if the extension is supported - if returned array of `VkExtensionProperties` contains "VK_AMD_device_coherent_memory". -VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( - VmaAllocator allocator, - VmaAllocation allocation, - VkDeviceSize allocationLocalOffset, - VkImage image, - const void* pNext) -{ - VMA_ASSERT(allocator && allocation && image); +2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. +Attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to `VkPhysicalDeviceFeatures2::pNext` to be returned. +Check if the device feature is really supported - check if `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true. - VMA_DEBUG_LOG("vmaBindImageMemory2"); +3) While creating device with `vkCreateDevice`, enable this extension - add "VK_AMD_device_coherent_memory" +to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. +Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. +Enable this device feature - attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to +`VkPhysicalDeviceFeatures2::pNext` and set its member `deviceCoherentMemory` to `VK_TRUE`. - return allocator->BindImageMemory(allocation, allocationLocalOffset, image, pNext); -} +5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you +have enabled this extension and feature - add #VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT +to VmaAllocatorCreateInfo::flags. -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( - VmaAllocator allocator, - const VkBufferCreateInfo* pBufferCreateInfo, - const VmaAllocationCreateInfo* pAllocationCreateInfo, - VkBuffer* pBuffer, - VmaAllocation* pAllocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && pBufferCreateInfo && pAllocationCreateInfo && pBuffer && pAllocation); +\section vk_amd_device_coherent_memory_usage Usage - if(pBufferCreateInfo->size == 0) - { - return VK_ERROR_VALIDATION_FAILED_EXT; - } - if((pBufferCreateInfo->usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY) != 0 && - !allocator->m_UseKhrBufferDeviceAddress) - { - VMA_ASSERT(0 && "Creating a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT is not valid if VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT was not used."); - return VK_ERROR_VALIDATION_FAILED_EXT; - } +After following steps described above, you can create VMA allocations and custom pools +out of the special `DEVICE_COHERENT` and `DEVICE_UNCACHED` memory types on eligible +devices. There are multiple ways to do it, for example: - VMA_DEBUG_LOG("vmaCreateBuffer"); +- You can request or prefer to allocate out of such memory types by adding + `VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` to VmaAllocationCreateInfo::requiredFlags + or VmaAllocationCreateInfo::preferredFlags. Those flags can be freely mixed with + other ways of \ref choosing_memory_type, like setting VmaAllocationCreateInfo::usage. +- If you manually found memory type index to use for this purpose, force allocation + from this specific index by setting VmaAllocationCreateInfo::memoryTypeBits `= 1u << index`. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +\section vk_amd_device_coherent_memory_more_information More information - *pBuffer = VK_NULL_HANDLE; - *pAllocation = VK_NULL_HANDLE; +To learn more about this extension, see [VK_AMD_device_coherent_memory in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VK_AMD_device_coherent_memory.html) - // 1. Create VkBuffer. - VkResult res = (*allocator->GetVulkanFunctions().vkCreateBuffer)( - allocator->m_hDevice, - pBufferCreateInfo, - allocator->GetAllocationCallbacks(), - pBuffer); - if(res >= 0) - { - // 2. vkGetBufferMemoryRequirements. - VkMemoryRequirements vkMemReq = {}; - bool requiresDedicatedAllocation = false; - bool prefersDedicatedAllocation = false; - allocator->GetBufferMemoryRequirements(*pBuffer, vkMemReq, - requiresDedicatedAllocation, prefersDedicatedAllocation); +Example use of this extension can be found in the code of the sample and test suite +accompanying this library. - // 3. Allocate memory using allocator. - res = allocator->AllocateMemory( - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - *pBuffer, // dedicatedBuffer - pBufferCreateInfo->usage, // dedicatedBufferUsage - VK_NULL_HANDLE, // dedicatedImage - *pAllocationCreateInfo, - VMA_SUBALLOCATION_TYPE_BUFFER, - 1, // allocationCount - pAllocation); -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordCreateBuffer( - allocator->GetCurrentFrameIndex(), - *pBufferCreateInfo, - *pAllocationCreateInfo, - *pAllocation); - } -#endif +\page enabling_buffer_device_address Enabling buffer device address - if(res >= 0) - { - // 3. Bind buffer with memory. - if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) - { - res = allocator->BindBufferMemory(*pAllocation, 0, *pBuffer, VMA_NULL); - } - if(res >= 0) - { - // All steps succeeded. - #if VMA_STATS_STRING_ENABLED - (*pAllocation)->InitBufferImageUsage(pBufferCreateInfo->usage); - #endif - if(pAllocationInfo != VMA_NULL) - { - allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); - } +Device extension VK_KHR_buffer_device_address +allow to fetch raw GPU pointer to a buffer and pass it for usage in a shader code. +It has been promoted to core Vulkan 1.2. - return VK_SUCCESS; - } - allocator->FreeMemory( - 1, // allocationCount - pAllocation); - *pAllocation = VK_NULL_HANDLE; - (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); - *pBuffer = VK_NULL_HANDLE; - return res; - } - (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); - *pBuffer = VK_NULL_HANDLE; - return res; - } - return res; -} +If you want to use this feature in connection with VMA, follow these steps: -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( - VmaAllocator allocator, - VkBuffer buffer, - VmaAllocation allocation) -{ - VMA_ASSERT(allocator); +\section enabling_buffer_device_address_initialization Initialization - if(buffer == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) - { - return; - } +1) (For Vulkan version < 1.2) Call `vkEnumerateDeviceExtensionProperties` for the physical device. +Check if the extension is supported - if returned array of `VkExtensionProperties` contains +"VK_KHR_buffer_device_address". - VMA_DEBUG_LOG("vmaDestroyBuffer"); +2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. +Attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to `VkPhysicalDeviceFeatures2::pNext` to be returned. +Check if the device feature is really supported - check if `VkPhysicalDeviceBufferDeviceAddressFeatures::bufferDeviceAddress` is true. - VMA_DEBUG_GLOBAL_MUTEX_LOCK +3) (For Vulkan version < 1.2) While creating device with `vkCreateDevice`, enable this extension - add +"VK_KHR_buffer_device_address" to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordDestroyBuffer( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. +Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. +Enable this device feature - attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to +`VkPhysicalDeviceFeatures2::pNext` and set its member `bufferDeviceAddress` to `VK_TRUE`. - if(buffer != VK_NULL_HANDLE) - { - (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, buffer, allocator->GetAllocationCallbacks()); - } +5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you +have enabled this feature - add #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT +to VmaAllocatorCreateInfo::flags. - if(allocation != VK_NULL_HANDLE) - { - allocator->FreeMemory( - 1, // allocationCount - &allocation); - } -} +\section enabling_buffer_device_address_usage Usage -VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( - VmaAllocator allocator, - const VkImageCreateInfo* pImageCreateInfo, - const VmaAllocationCreateInfo* pAllocationCreateInfo, - VkImage* pImage, - VmaAllocation* pAllocation, - VmaAllocationInfo* pAllocationInfo) -{ - VMA_ASSERT(allocator && pImageCreateInfo && pAllocationCreateInfo && pImage && pAllocation); +After following steps described above, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*` using VMA. +The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT*` to +allocated memory blocks wherever it might be needed. - if(pImageCreateInfo->extent.width == 0 || - pImageCreateInfo->extent.height == 0 || - pImageCreateInfo->extent.depth == 0 || - pImageCreateInfo->mipLevels == 0 || - pImageCreateInfo->arrayLayers == 0) - { - return VK_ERROR_VALIDATION_FAILED_EXT; - } +Please note that the library supports only `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*`. +The second part of this functionality related to "capture and replay" is not supported, +as it is intended for usage in debugging tools like RenderDoc, not in everyday Vulkan usage. - VMA_DEBUG_LOG("vmaCreateImage"); +\section enabling_buffer_device_address_more_information More information - VMA_DEBUG_GLOBAL_MUTEX_LOCK +To learn more about this extension, see [VK_KHR_buffer_device_address in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap46.html#VK_KHR_buffer_device_address) - *pImage = VK_NULL_HANDLE; - *pAllocation = VK_NULL_HANDLE; +Example use of this extension can be found in the code of the sample and test suite +accompanying this library. - // 1. Create VkImage. - VkResult res = (*allocator->GetVulkanFunctions().vkCreateImage)( - allocator->m_hDevice, - pImageCreateInfo, - allocator->GetAllocationCallbacks(), - pImage); - if(res >= 0) - { - VmaSuballocationType suballocType = pImageCreateInfo->tiling == VK_IMAGE_TILING_OPTIMAL ? - VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL : - VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR; +\page general_considerations General considerations - // 2. Allocate memory using allocator. - VkMemoryRequirements vkMemReq = {}; - bool requiresDedicatedAllocation = false; - bool prefersDedicatedAllocation = false; - allocator->GetImageMemoryRequirements(*pImage, vkMemReq, - requiresDedicatedAllocation, prefersDedicatedAllocation); +\section general_considerations_thread_safety Thread safety - res = allocator->AllocateMemory( - vkMemReq, - requiresDedicatedAllocation, - prefersDedicatedAllocation, - VK_NULL_HANDLE, // dedicatedBuffer - UINT32_MAX, // dedicatedBufferUsage - *pImage, // dedicatedImage - *pAllocationCreateInfo, - suballocType, - 1, // allocationCount - pAllocation); +- The library has no global state, so separate #VmaAllocator objects can be used + independently. + There should be no need to create multiple such objects though - one per `VkDevice` is enough. +- By default, all calls to functions that take #VmaAllocator as first parameter + are safe to call from multiple threads simultaneously because they are + synchronized internally when needed. + This includes allocation and deallocation from default memory pool, as well as custom #VmaPool. +- When the allocator is created with #VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT + flag, calls to functions that take such #VmaAllocator object must be + synchronized externally. +- Access to a #VmaAllocation object must be externally synchronized. For example, + you must not call vmaGetAllocationInfo() and vmaMapMemory() from different + threads at the same time if you pass the same #VmaAllocation object to these + functions. +- #VmaVirtualBlock is not safe to be used from multiple threads simultaneously. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordCreateImage( - allocator->GetCurrentFrameIndex(), - *pImageCreateInfo, - *pAllocationCreateInfo, - *pAllocation); - } -#endif +\section general_considerations_versioning_and_compatibility Versioning and compatibility - if(res >= 0) - { - // 3. Bind image with memory. - if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) - { - res = allocator->BindImageMemory(*pAllocation, 0, *pImage, VMA_NULL); - } - if(res >= 0) - { - // All steps succeeded. - #if VMA_STATS_STRING_ENABLED - (*pAllocation)->InitBufferImageUsage(pImageCreateInfo->usage); - #endif - if(pAllocationInfo != VMA_NULL) - { - allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); - } +The library uses [**Semantic Versioning**](https://semver.org/), +which means version numbers follow convention: Major.Minor.Patch (e.g. 2.3.0), where: - return VK_SUCCESS; - } - allocator->FreeMemory( - 1, // allocationCount - pAllocation); - *pAllocation = VK_NULL_HANDLE; - (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); - *pImage = VK_NULL_HANDLE; - return res; - } - (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); - *pImage = VK_NULL_HANDLE; - return res; - } - return res; -} +- Incremented Patch version means a release is backward- and forward-compatible, + introducing only some internal improvements, bug fixes, optimizations etc. + or changes that are out of scope of the official API described in this documentation. +- Incremented Minor version means a release is backward-compatible, + so existing code that uses the library should continue to work, while some new + symbols could have been added: new structures, functions, new values in existing + enums and bit flags, new structure members, but not new function parameters. +- Incrementing Major version means a release could break some backward compatibility. -VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( - VmaAllocator allocator, - VkImage image, - VmaAllocation allocation) -{ - VMA_ASSERT(allocator); +All changes between official releases are documented in file "CHANGELOG.md". - if(image == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) - { - return; - } +\warning Backward compatiblity is considered on the level of C++ source code, not binary linkage. +Adding new members to existing structures is treated as backward compatible if initializing +the new members to binary zero results in the old behavior. +You should always fully initialize all library structures to zeros and not rely on their +exact binary size. - VMA_DEBUG_LOG("vmaDestroyImage"); +\section general_considerations_validation_layer_warnings Validation layer warnings - VMA_DEBUG_GLOBAL_MUTEX_LOCK +When using this library, you can meet following types of warnings issued by +Vulkan validation layer. They don't necessarily indicate a bug, so you may need +to just ignore them. -#if VMA_RECORDING_ENABLED - if(allocator->GetRecorder() != VMA_NULL) - { - allocator->GetRecorder()->RecordDestroyImage( - allocator->GetCurrentFrameIndex(), - allocation); - } -#endif +- *vkBindBufferMemory(): Binding memory to buffer 0xeb8e4 but vkGetBufferMemoryRequirements() has not been called on that buffer.* + - It happens when VK_KHR_dedicated_allocation extension is enabled. + `vkGetBufferMemoryRequirements2KHR` function is used instead, while validation layer seems to be unaware of it. +- *Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used.* + - It happens when you map a buffer or image, because the library maps entire + `VkDeviceMemory` block, where different types of images and buffers may end + up together, especially on GPUs with unified memory like Intel. +- *Non-linear image 0xebc91 is aliased with linear buffer 0xeb8e4 which may indicate a bug.* + - It may happen when you use [defragmentation](@ref defragmentation). - if(image != VK_NULL_HANDLE) - { - (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, image, allocator->GetAllocationCallbacks()); - } - if(allocation != VK_NULL_HANDLE) - { - allocator->FreeMemory( - 1, // allocationCount - &allocation); - } -} +\section general_considerations_allocation_algorithm Allocation algorithm + +The library uses following algorithm for allocation, in order: + +-# Try to find free range of memory in existing blocks. +-# If failed, try to create a new block of `VkDeviceMemory`, with preferred block size. +-# If failed, try to create such block with size / 2, size / 4, size / 8. +-# If failed, try to allocate separate `VkDeviceMemory` for this allocation, + just like when you use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +-# If failed, choose other memory type that meets the requirements specified in + VmaAllocationCreateInfo and go to point 1. +-# If failed, return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + +\section general_considerations_features_not_supported Features not supported + +Features deliberately excluded from the scope of this library: -#endif // #ifdef VMA_IMPLEMENTATION +-# **Data transfer.** Uploading (streaming) and downloading data of buffers and images + between CPU and GPU memory and related synchronization is responsibility of the user. + Defining some "texture" object that would automatically stream its data from a + staging copy in CPU memory to GPU memory would rather be a feature of another, + higher-level library implemented on top of VMA. + VMA doesn't record any commands to a `VkCommandBuffer`. It just allocates memory. +-# **Recreation of buffers and images.** Although the library has functions for + buffer and image creation: vmaCreateBuffer(), vmaCreateImage(), you need to + recreate these objects yourself after defragmentation. That is because the big + structures `VkBufferCreateInfo`, `VkImageCreateInfo` are not stored in + #VmaAllocation object. +-# **Handling CPU memory allocation failures.** When dynamically creating small C++ + objects in CPU memory (not Vulkan memory), allocation failures are not checked + and handled gracefully, because that would complicate code significantly and + is usually not needed in desktop PC applications anyway. + Success of an allocation is just checked with an assert. +-# **Code free of any compiler warnings.** Maintaining the library to compile and + work correctly on so many different platforms is hard enough. Being free of + any warnings, on any version of any compiler, is simply not feasible. + There are many preprocessor macros that make some variables unused, function parameters unreferenced, + or conditional expressions constant in some configurations. + The code of this library should not be bigger or more complicated just to silence these warnings. + It is recommended to disable such warnings instead. +-# This is a C++ library with C interface. **Bindings or ports to any other programming languages** are welcome as external projects but + are not going to be included into this repository. +*/ diff --git a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl new file mode 100644 index 0000000000000..6ec93422b0d6b --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl @@ -0,0 +1,37 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma; +layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta; +layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean; +layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar; +layout(set = 0, binding = 6) uniform PRECISION restrict Block { + ivec3 isize; + int channels_ext; + float eps; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.isize.xyz))) { + const ivec3 chn = ivec3(0, 0, pos.z % uBlock.channels_ext); + imageStore( + uOutput, + pos, + (texelFetch(uInput, pos, 0) + - texelFetch(uMean, chn, 0)) + / sqrt(texelFetch(uVar, chn, 0) + uBlock.eps) + * texelFetch(uGamma, chn, 0) + + texelFetch(uBeta, chn, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_winograd_2_3.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_winograd_2_3.glsl deleted file mode 100644 index 9a3db58d4242c..0000000000000 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_winograd_2_3.glsl +++ /dev/null @@ -1,83 +0,0 @@ -#version 450 core -#define PRECISION $precision -#define FORMAT $format - -layout(std430) buffer; - -/* Qualifiers: layout - storage - precision - memory */ - -layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; -layout(set = 0, binding = 3) buffer PRECISION restrict readonly Bias { - vec4 data[]; -} uBias; -layout(set = 0, binding = 4) uniform PRECISION restrict Block { - ivec4 size; - vec2 clamp; -} uBlock; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec3 opos00 = ivec3(2*pos.xy, pos.z); - - if (all(lessThan(opos00, uBlock.size.xyz))) { - const ivec2 ipos00 = 4*pos.xy; - - vec4 dg[16] = { - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0), - vec4(0,0,0,0) - }; - - for (int y = 0; y < 4; ++y) { - for (int x = 0; x < 4; ++x) { - const ivec2 iposxy = ipos00.xy + ivec2(x,y); - ivec2 wpos = ivec2(4*uBlock.size.w*x, 4*pos.z+y); - for (int z4 = 0; z4 < uBlock.size.w; ++z4) { - const vec4 intex = texelFetch(uInput, ivec3(iposxy, z4), 0); - dg[4*y+x] += vec4( - dot(intex, texelFetch(uKernel, ivec3(wpos.x , wpos.y, 0), 0)), - dot(intex, texelFetch(uKernel, ivec3(wpos.x+1, wpos.y, 0), 0)), - dot(intex, texelFetch(uKernel, ivec3(wpos.x+2, wpos.y, 0), 0)), - dot(intex, texelFetch(uKernel, ivec3(wpos.x+3, wpos.y, 0), 0))); - - wpos += ivec2(4, 0); - } - } - } - - const vec4 o00 = dg[0] + dg[4] + dg[8]; - const vec4 o01 = dg[1] + dg[5] + dg[9]; - const vec4 o02 = dg[2] + dg[6] + dg[10]; - const vec4 o03 = dg[3] + dg[7] + dg[11]; - const vec4 o10 = dg[4] - dg[8] - dg[12]; - const vec4 o11 = dg[5] - dg[9] - dg[13]; - const vec4 o12 = dg[6] - dg[10] - dg[14]; - const vec4 o13 = dg[7] - dg[11] - dg[15]; - - const vec4 b = uBias.data[pos.z]; - imageStore(uOutput, ivec3(opos00.x, opos00.y, opos00.z), clamp(b + o00 + o01 + o02, uBlock.clamp.x, uBlock.clamp.y)); - if (opos00.x+1 < uBlock.size.x) - imageStore(uOutput, ivec3(opos00.x+1, opos00.y, opos00.z), clamp(b + o01 - o02 - o03, uBlock.clamp.x, uBlock.clamp.y)); - if (opos00.y+1 < uBlock.size.y) - imageStore(uOutput, ivec3(opos00.x, opos00.y+1, opos00.z), clamp(b + o10 + o11 + o12, uBlock.clamp.x, uBlock.clamp.y)); - if (opos00.x+1 < uBlock.size.x && opos00.y+1 < uBlock.size.y) - imageStore(uOutput, ivec3(opos00.x+1, opos00.y+1, opos00.z), clamp(b + o11 - o12 - o13, uBlock.clamp.x, uBlock.clamp.y)); - } -} diff --git a/aten/src/ATen/native/vulkan/glsl/dequantize.glsl b/aten/src/ATen/native/vulkan/glsl/dequantize.glsl new file mode 100644 index 0000000000000..9b90bac24c6ab --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/dequantize.glsl @@ -0,0 +1,28 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; //quantized input +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + vec2 scale; + ivec2 zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 texel = texelFetch(uInput, pos, 0); + imageStore( + uOutput, + pos, + (uBlock.scale.x * (texel - uBlock.zero_point.x))); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/glu_channel.glsl b/aten/src/ATen/native/vulkan/glsl/glu_channel.glsl new file mode 100644 index 0000000000000..43d2e075f82ca --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/glu_channel.glsl @@ -0,0 +1,52 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; // output size + int ch; // channel size of the output +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const int z0a = 2 * ((4 * pos.z) / uBlock.ch) * uBlock.ch + ((4 * pos.z) % uBlock.ch); + const int z1a = 2 * ((4 * pos.z + 1) / uBlock.ch) * uBlock.ch + ((4 * pos.z + 1) % uBlock.ch); + const int z2a = 2 * ((4 * pos.z + 2) / uBlock.ch) * uBlock.ch + ((4 * pos.z + 2) % uBlock.ch); + const int z3a = 2 * ((4 * pos.z + 3) / uBlock.ch) * uBlock.ch + ((4 * pos.z + 3) % uBlock.ch); + + const int z0b = z0a + uBlock.ch; + const int z1b = z1a + uBlock.ch; + const int z2b = z2a + uBlock.ch; + const int z3b = z3a + uBlock.ch; + + const float v0a = texelFetch(uInput, ivec3(pos.x, pos.y, z0a / 4), 0)[z0a % 4]; + const float v0b = texelFetch(uInput, ivec3(pos.x, pos.y, z0b / 4), 0)[z0b % 4]; + const float v1a = texelFetch(uInput, ivec3(pos.x, pos.y, z1a / 4), 0)[z1a % 4]; + const float v1b = texelFetch(uInput, ivec3(pos.x, pos.y, z1b / 4), 0)[z1b % 4]; + const float v2a = texelFetch(uInput, ivec3(pos.x, pos.y, z2a / 4), 0)[z2a % 4]; + const float v2b = texelFetch(uInput, ivec3(pos.x, pos.y, z2b / 4), 0)[z2b % 4]; + const float v3a = texelFetch(uInput, ivec3(pos.x, pos.y, z3a / 4), 0)[z3a % 4]; + const float v3b = texelFetch(uInput, ivec3(pos.x, pos.y, z3b / 4), 0)[z3b % 4]; + + imageStore( + uOutput, + pos, + vec4( + v0a * (1 / (1 + exp(-1 * v0b))), + v1a * (1 / (1 + exp(-1 * v1b))), + v2a * (1 / (1 + exp(-1 * v2b))), + v3a * (1 / (1 + exp(-1 * v3b))) + ) + ); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/glu_channel_mul4.glsl b/aten/src/ATen/native/vulkan/glsl/glu_channel_mul4.glsl new file mode 100644 index 0000000000000..2650117dc4487 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/glu_channel_mul4.glsl @@ -0,0 +1,31 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; // output size + int ch; // channel size of the output +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const int chext = uBlock.ch / 4; + const int z0 = 2 * (pos.z / chext) * chext + (pos.z % chext); + const int z1 = z0 + chext; + imageStore( + uOutput, + pos, + texelFetch(uInput, ivec3(pos.x, pos.y, z0), 0) + * 1 / (1 + exp(-1 * texelFetch(uInput, ivec3(pos.x, pos.y, z1), 0)))); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl new file mode 100644 index 0000000000000..07db4318df9c5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl @@ -0,0 +1,65 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0) uniform PRECISION isampler3D uImage; +layout(set = 0, binding = 1) buffer PRECISION Buffer { + uint data[]; +} uBuffer; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 offset; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (pos.y == 0 && pos.z == 0) { + ivec4 texture_pos = ivec4(0,1,2,3) + 4 * pos.x; + + ivec4 last_eight; + last_eight.z = texture_pos.x / (uBlock.size.x * uBlock.size.y); + last_eight.w = texture_pos.x % (uBlock.size.x * uBlock.size.y); + last_eight.y = last_eight.w / uBlock.size.x; + last_eight.x = last_eight.w % uBlock.size.x; + + ivec4 sec_last_eight; + sec_last_eight.z = texture_pos.y / (uBlock.size.x * uBlock.size.y); + sec_last_eight.w = texture_pos.y % (uBlock.size.x * uBlock.size.y); + sec_last_eight.y = sec_last_eight.w / uBlock.size.x; + sec_last_eight.x = sec_last_eight.w % uBlock.size.x; + + ivec4 thr_last_eight; + thr_last_eight.z = texture_pos.z / (uBlock.size.x * uBlock.size.y); + thr_last_eight.w = texture_pos.z % (uBlock.size.x * uBlock.size.y); + thr_last_eight.y = thr_last_eight.w / uBlock.size.x; + thr_last_eight.x = thr_last_eight.w % uBlock.size.x; + + ivec4 four_last_eight; + four_last_eight.z = texture_pos.w / (uBlock.size.x * uBlock.size.y); + four_last_eight.w = texture_pos.w % (uBlock.size.x * uBlock.size.y); + four_last_eight.y = four_last_eight.w / uBlock.size.x; + four_last_eight.x = four_last_eight.w % uBlock.size.x; + + ivec3 last_eight_pos = ivec3(last_eight.x, last_eight.y, last_eight.z / 4); + ivec3 sec_last_eight_pos = ivec3(sec_last_eight.x, sec_last_eight.y, sec_last_eight.z / 4); + ivec3 thr_last_eight_pos = ivec3(thr_last_eight.x, thr_last_eight.y, thr_last_eight.z / 4); + ivec3 four_last_eight_pos = ivec3(four_last_eight.x, four_last_eight.y, four_last_eight.z / 4); + + int texel_1 = texelFetch(uImage, last_eight_pos, 0)[last_eight.z]; + int texel_2 = texelFetch(uImage, sec_last_eight_pos, 0)[sec_last_eight.z]; + int texel_3 = texelFetch(uImage, thr_last_eight_pos, 0)[thr_last_eight.z]; + int texel_4 = texelFetch(uImage, four_last_eight_pos, 0)[four_last_eight.z]; + + uint ui32 = (uint(texel_4 & 0xFF) << 24) + | (uint(texel_3 & 0xFF) << 16) + | (uint(texel_2 & 0xFF) << 8) + | (uint(texel_1 & 0xFF)); + + uBuffer.data[texture_pos.x / 4] = ui32; + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/layernorm.glsl b/aten/src/ATen/native/vulkan/glsl/layernorm.glsl new file mode 100644 index 0000000000000..7a347cfb73a69 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/layernorm.glsl @@ -0,0 +1,151 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma; +layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec3 isize; + int volume; + int offset; + float eps; +} uBlock; + +shared float sh_mem[64]; +shared float mean; +shared float rstd; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This is the simple two-pass algorithm to compute variance. +// This implementation is not efficient when calculating mean and +// variance since every work group will compute the mean and variance +// for the entire tensor. + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec3 tid = ivec3(gl_LocalInvocationID); + const ivec3 group_size = ivec3(gl_WorkGroupSize); + + // Start computing mean. + // Divide work among the 64 invocations in the work group + // and compute partial sums of texels that are "fully filled" + vec4 sum4d = vec4(0); + for (int z = tid.z; z < uBlock.isize.z - 1; z+=group_size.z) { + for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) { + for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) { + sum4d += texelFetch(uInput, ivec3(x, y, z), 0); + } + } + } + float sum = sum4d.x + sum4d.y + sum4d.w + sum4d.z; + + // Still computing the mean, processing the last texel across the channel-batch dimension + if ((uBlock.isize.z - 1) % group_size.z == tid.z) { + for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) { + for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) { + const vec4 last_texel = texelFetch(uInput, ivec3(x, y, uBlock.isize.z - 1), 0); + sum += ( + last_texel.x + + (uBlock.offset >= 1 ? last_texel.y : 0) + + (uBlock.offset >= 2 ? last_texel.z : 0) + + (uBlock.offset == 3 ? last_texel.w : 0) + ); + } + } + } + + // Shared memory (among threads in a work group) that holds partial sums + sh_mem[gl_LocalInvocationIndex] = sum; + + memoryBarrierShared(); + barrier(); + + // Only instance (0, 0, 0) will compute the sum of the 64 partial sums, + // and then compute the mean, dividing the total by the tensor's volume + if (tid == ivec3(0)) { + float total = 0; + for (int z = 0; z < group_size.z; ++z) { + for (int y = 0; y < group_size.y; ++y) { + for (int x = 0; x < group_size.x; ++x) { + total += sh_mem[z * group_size.y * group_size.x + y * group_size.x + x]; + } + } + } + mean = total / uBlock.volume; + } + + memoryBarrierShared(); + barrier(); + + // Start computing variance (using the previously computed mean) + // Divide work among the 64 invocations in the work group + // and compute partial sums of texels that are "fully filled" + vec4 sqsum4d = vec4(0); + for (int z = tid.z; z < uBlock.isize.z - 1; z+=group_size.z) { + for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) { + for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) { + const vec4 val = texelFetch(uInput, ivec3(x, y, z), 0); + sqsum4d += (val - mean) * (val - mean); + } + } + } + float sqsum = sqsum4d.x + sqsum4d.y + sqsum4d.w + sqsum4d.z; + + // Still computing the variance, processing the last texel across the channel-batch dimension + if ((uBlock.isize.z - 1) % group_size.z == tid.z) { + for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) { + for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) { + const vec4 last_texel = texelFetch(uInput, ivec3(x, y, uBlock.isize.z - 1), 0); + sqsum += ( + (last_texel.x - mean) * (last_texel.x - mean) + + (uBlock.offset >= 1 ? (last_texel.y - mean) * (last_texel.y - mean) : 0) + + (uBlock.offset >= 2 ? (last_texel.z - mean) * (last_texel.z - mean) : 0) + + (uBlock.offset == 3 ? (last_texel.w - mean) * (last_texel.w - mean) : 0) + ); + } + } + } + + // Reuse shared memory to hold partial squared sums + sh_mem[gl_LocalInvocationIndex] = sqsum; + + memoryBarrierShared(); + barrier(); + + // Only instance (0, 0, 0) will compute the sum of the 64 partial sums, + // and then compute the squared root of the biased variance, with eps added + // to the denominator for numerical stabilty. + if (tid == ivec3(0)) { + float total2 = 0; + for (int z = 0; z < group_size.z; ++z) { + for (int y = 0; y < group_size.y; ++y) { + for (int x = 0; x < group_size.x; ++x) { + total2 += sh_mem[z * group_size.y * group_size.x + y * group_size.x + x]; + } + } + } + rstd = sqrt(total2 / uBlock.volume + uBlock.eps); + } + + memoryBarrierShared(); + barrier(); + + // Compute layernorm using previously computed mean and rstd + if (all(lessThan(pos, uBlock.isize.xyz))) { + imageStore( + uOutput, + pos, + (texelFetch(uInput, pos, 0) + - mean) + / rstd + * texelFetch(uGamma, pos, 0) + + texelFetch(uBeta, pos, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl new file mode 100644 index 0000000000000..d23796d8af4b3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl @@ -0,0 +1,52 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uImage; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + uint data[]; +} uBuffer; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 offset; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + const int base = pos.x + uBlock.size.x * pos.y + uBlock.size.w * pos.z; + const ivec4 index = base + uBlock.offset; + + int shift = (1 << 8) - 1; + ivec4 masks; + masks.x = shift << 8 * (index.x % 4); + masks.y = shift << 8 * (index.y % 4); + masks.z = shift << 8 * (index.z % 4); + masks.w = shift << 8 * (index.w % 4); + + uint buf_in_1 = uBuffer.data[index.x / 4]; + uint a_v = (buf_in_1 & masks.x) >> 8 * (index.x % 4); + + uint buf_in_2 = uBuffer.data[index.y / 4]; + uint b_v = (buf_in_2 & masks.y) >> 8 * (index.y % 4); + + uint buf_in_3 = uBuffer.data[index.z / 4]; + uint g_v = (buf_in_3 & masks.z) >> 8 * (index.z % 4); + + uint buf_in_4 = uBuffer.data[index.w / 4]; + uint r_v = (buf_in_4 & masks.w) >> 8 * (index.w % 4); + + uvec4 texel = uvec4(a_v, b_v, g_v, r_v); + + imageStore( + uImage, + pos, + texel); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl new file mode 100644 index 0000000000000..910603aa29f26 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl @@ -0,0 +1,29 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; //input +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + vec2 scale; + ivec2 zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 ret = texelFetch(uInput, pos, 0) / uBlock.scale.x + uBlock.zero_point.x; + uvec4 texel = uvec4(int(ret.x), int(ret.y), int(ret.z), int(ret.w)); + imageStore( + uOutput, + pos, + texel); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl new file mode 100644 index 0000000000000..8f6e51397d1c1 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl @@ -0,0 +1,46 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput0; //quantized input +layout(set = 0, binding = 2) uniform PRECISION isampler3D uInput1; //quantized input +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec4 size; + ivec4 isize0; + ivec4 isize1; + vec2 in_scale; + ivec2 in_zero_point; + vec2 out_scale; + ivec2 out_zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec3 input0_pos = pos % uBlock.isize0.xyz; + const ivec3 input1_pos = pos % uBlock.isize1.xyz; + + vec4 texel0 = texelFetch(uInput0, input0_pos, 0); + vec4 texel1 = texelFetch(uInput1, input1_pos, 0); + + vec4 deq_in_0 = uBlock.in_scale.x * (texel0 - uBlock.in_zero_point.x); + vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); + + vec4 res = deq_in_0 + deq_in_1; + vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + + uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl new file mode 100644 index 0000000000000..a361a44e85994 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -0,0 +1,88 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel; +layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; + vec2 scale; + ivec2 zero_point; + vec2 other_inp_scale; + ivec2 other_inp_zero_point; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + uBlock.kernel.xy, uBlock.kernel.zw); + ivec2 kstart = (start - ipos) / uBlock.dilate; + + kstart.x *= 4; + kstart.y += pos.z * uBlock.ikernel.y; + + vec4 q_sum = texelFetch(uBias, ivec3(pos.z, 0, 0), 0); + vec4 sum = uBlock.other_inp_scale.y * (q_sum - uBlock.other_inp_zero_point.y); + + for (int z4 = 0; z4 < uBlock.size.w/4; ++z4, kstart.x += uBlock.ikernel.x*4) { + for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) { + for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, kx += 4) { + const vec4 In = texelFetch(uInput, ivec3(x, y, z4), 0); + vec4 deq_In = uBlock.scale.y * (In - uBlock.zero_point.y); + const ivec4 kxs = kx + ivec4(0, 1, 2, 3); + + const vec4 weight_x = texelFetch(uKernel, ivec3(kxs.x, ky, 0), 0); + if (weight_x != vec4(0.0)) { + vec4 deq_weight_x = uBlock.other_inp_scale.x * (weight_x - uBlock.other_inp_zero_point.x); + sum = fma(deq_In.xxxx, deq_weight_x, sum); + } + + const vec4 weight_y = texelFetch(uKernel, ivec3(kxs.y, ky, 0), 0); + if (weight_y != vec4(0.0)) { + vec4 deq_weight_y = uBlock.other_inp_scale.x * (weight_y - uBlock.other_inp_zero_point.x); + sum = fma(deq_In.yyyy, deq_weight_y, sum); + } + + const vec4 weight_z = texelFetch(uKernel, ivec3(kxs.z, ky, 0), 0); + if (weight_z != vec4(0.0)) { + vec4 deq_weight_z = uBlock.other_inp_scale.x * (weight_z - uBlock.other_inp_zero_point.x); + sum = fma(deq_In.zzzz, deq_weight_z, sum); + } + + const vec4 weight_w = texelFetch(uKernel, ivec3(kxs.w, ky, 0), 0); + if (weight_w != vec4(0.0)) { + vec4 deq_weight_w = uBlock.other_inp_scale.x * (weight_w - uBlock.other_inp_zero_point.x); + sum = fma(deq_In.wwww, deq_weight_w, sum); + } + } + } + } + + sum = clamp(sum, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret = sum / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res = uvec4(int(q_ret.x), int(q_ret.y), int(q_ret.z), int(q_ret.w)); + + imageStore( + uOutput, + pos, + res); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl new file mode 100644 index 0000000000000..41681da3f52bf --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl @@ -0,0 +1,67 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel; +layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; + vec2 scale; + ivec2 zero_point; + vec2 other_inp_scale; + ivec2 other_inp_zero_point; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + uBlock.kernel.xy, uBlock.kernel.zw); + const ivec2 kstart = (start - ipos) / uBlock.dilate; + + vec4 q_sum = texelFetch(uBias, ivec3(pos.z, 0, 0), 0); + vec4 sum = uBlock.other_inp_scale.y * (q_sum - uBlock.other_inp_zero_point.y); + + for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) { + for (int x = start.x, kx = kstart.x + ky * uBlock.ikernel.x; x < end.x; x += uBlock.dilate.x, ++kx) { + const vec4 In = texelFetch(uInput, ivec3(x, y, pos.z), 0); + vec4 deq_In = uBlock.scale.y * (In - uBlock.zero_point.y); + + const vec4 weight = texelFetch(uKernel, ivec3(kx, pos.z, 0), 0); + if (weight != vec4(0.0)) { + vec4 deq_weight = uBlock.other_inp_scale.x * (weight - uBlock.other_inp_zero_point.x); + sum = fma( + deq_In, + deq_weight, + sum); + } + } + } + + sum = clamp(sum, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret = sum / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res = uvec4(int(q_ret.x), int(q_ret.y), int(q_ret.z), int(q_ret.w)); + + imageStore( + uOutput, + pos, + res); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl new file mode 100644 index 0000000000000..7d9805fb9fe49 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl @@ -0,0 +1,136 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel; +layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; + vec2 scale; + ivec2 zero_point; + vec2 other_inp_scale; + ivec2 other_inp_zero_point; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + const ivec3 pos00 = ivec3(pos.x*2 , pos.y*2 , pos.z); + const ivec3 pos10 = ivec3(pos.x*2+1, pos.y*2 , pos.z); + const ivec3 pos01 = ivec3(pos.x*2 , pos.y*2+1, pos.z); + const ivec3 pos11 = ivec3(pos.x*2+1, pos.y*2+1, pos.z); + + if (all(lessThan(pos00, uBlock.size.xyz))) { + const ivec2 ipos00 = pos00.xy * uBlock.stride - uBlock.padding; + const ivec2 ipos10 = pos10.xy * uBlock.stride - uBlock.padding; + const ivec2 ipos01 = pos01.xy * uBlock.stride - uBlock.padding; + const ivec2 ipos11 = pos11.xy * uBlock.stride - uBlock.padding; + + vec4 q_sum00 = texelFetch(uBias, ivec3(pos.z, 0, 0), 0); + vec4 sum00 = uBlock.other_inp_scale.y * (q_sum00 - uBlock.other_inp_zero_point.y); + vec4 sum10 = sum00; + vec4 sum01 = sum00; + vec4 sum11 = sum00; + + for (int z = 0, z4 = 0; z < uBlock.size.w; z += 4, ++z4) { + const ivec4 kxs = z + ivec4(0, 1, 2, 3); + const vec4 q_k1 = texelFetch(uKernel, ivec3(kxs.x, pos.z, 0), 0); + const vec4 k1 = uBlock.other_inp_scale.x * (q_k1 - uBlock.other_inp_zero_point.x); + const vec4 q_k2 = texelFetch(uKernel, ivec3(kxs.y, pos.z, 0), 0); + const vec4 k2 = uBlock.other_inp_scale.x * (q_k2 - uBlock.other_inp_zero_point.x); + const vec4 q_k3 = texelFetch(uKernel, ivec3(kxs.z, pos.z, 0), 0); + const vec4 k3 = uBlock.other_inp_scale.x * (q_k3 - uBlock.other_inp_zero_point.x); + const vec4 q_k4 = texelFetch(uKernel, ivec3(kxs.w, pos.z, 0), 0); + const vec4 k4 = uBlock.other_inp_scale.x * (q_k4 - uBlock.other_inp_zero_point.x); + + const vec4 In00 = texelFetch(uInput, ivec3(ipos00, z4), 0); + vec4 deq_In00 = uBlock.scale.y * (In00 - uBlock.zero_point.y); + const vec4 In10 = texelFetch(uInput, ivec3(ipos10, z4), 0); + vec4 deq_In10 = uBlock.scale.y * (In10 - uBlock.zero_point.y); + const vec4 In01 = texelFetch(uInput, ivec3(ipos01, z4), 0); + vec4 deq_In01 = uBlock.scale.y * (In01 - uBlock.zero_point.y); + const vec4 In11 = texelFetch(uInput, ivec3(ipos11, z4), 0); + vec4 deq_In11 = uBlock.scale.y * (In11 - uBlock.zero_point.y); + + if (q_k1 != vec4(0.0)) { + sum00 = fma(deq_In00.xxxx, k1, sum00); + sum10 = fma(deq_In10.xxxx, k1, sum10); + sum01 = fma(deq_In01.xxxx, k1, sum01); + sum11 = fma(deq_In11.xxxx, k1, sum11); + } + + if (q_k2 != vec4(0.0)) { + sum00 = fma(deq_In00.yyyy, k2, sum00); + sum10 = fma(deq_In10.yyyy, k2, sum10); + sum01 = fma(deq_In01.yyyy, k2, sum01); + sum11 = fma(deq_In11.yyyy, k2, sum11); + } + + if (q_k3 != vec4(0.0)) { + sum00 = fma(deq_In00.zzzz, k3, sum00); + sum10 = fma(deq_In10.zzzz, k3, sum10); + sum01 = fma(deq_In01.zzzz, k3, sum01); + sum11 = fma(deq_In11.zzzz, k3, sum11); + } + + if (q_k4 != vec4(0.0)) { + sum00 = fma(deq_In00.wwww, k4, sum00); + sum10 = fma(deq_In10.wwww, k4, sum10); + sum01 = fma(deq_In01.wwww, k4, sum01); + sum11 = fma(deq_In11.wwww, k4, sum11); + } + } + sum00 = clamp(sum00, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret00 = sum00 / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res00 = uvec4(int(q_ret00.x), int(q_ret00.y), int(q_ret00.z), int(q_ret00.w)); + + sum10 = clamp(sum10, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret10 = sum10 / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res10 = uvec4(int(q_ret10.x), int(q_ret10.y), int(q_ret10.z), int(q_ret10.w)); + + sum01 = clamp(sum01, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret01 = sum01 / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res01 = uvec4(int(q_ret01.x), int(q_ret01.y), int(q_ret01.z), int(q_ret01.w)); + + sum11 = clamp(sum11, uBlock.clamp.x, uBlock.clamp.y); + vec4 q_ret11 = sum11 / uBlock.scale.x + uBlock.zero_point.x; + uvec4 res11 = uvec4(int(q_ret11.x), int(q_ret11.y), int(q_ret11.z), int(q_ret11.w)); + + imageStore( + uOutput, + pos00, + res00); + if (all(lessThan(pos10, uBlock.size.xyz))) { + imageStore( + uOutput, + pos10, + res10); + } + if (all(lessThan(pos01, uBlock.size.xyz))) { + imageStore( + uOutput, + pos01, + res01); + } + if (all(lessThan(pos11, uBlock.size.xyz))) { + imageStore( + uOutput, + pos11, + res11); + } + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl new file mode 100644 index 0000000000000..aa961eb349934 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl @@ -0,0 +1,46 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput0; //quantized input +layout(set = 0, binding = 2) uniform PRECISION isampler3D uInput1; //quantized input +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec4 size; + ivec4 isize0; + ivec4 isize1; + vec2 in_scale; + ivec2 in_zero_point; + vec2 out_scale; + ivec2 out_zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec3 input0_pos = pos % uBlock.isize0.xyz; + const ivec3 input1_pos = pos % uBlock.isize1.xyz; + + vec4 texel0 = texelFetch(uInput0, input0_pos, 0); + vec4 texel1 = texelFetch(uInput1, input1_pos, 0); + + vec4 deq_in_0 = uBlock.in_scale.x * (texel0 - uBlock.in_zero_point.x); + vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); + + vec4 res = deq_in_0 / deq_in_1; + vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + + uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl new file mode 100644 index 0000000000000..459f56915d774 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl @@ -0,0 +1,46 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput0; //quantized input +layout(set = 0, binding = 2) uniform PRECISION isampler3D uInput1; //quantized input +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec4 size; + ivec4 isize0; + ivec4 isize1; + vec2 in_scale; + ivec2 in_zero_point; + vec2 out_scale; + ivec2 out_zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec3 input0_pos = pos % uBlock.isize0.xyz; + const ivec3 input1_pos = pos % uBlock.isize1.xyz; + + vec4 texel0 = texelFetch(uInput0, input0_pos, 0); + vec4 texel1 = texelFetch(uInput1, input1_pos, 0); + + vec4 deq_in_0 = uBlock.in_scale.x * (texel0 - uBlock.in_zero_point.x); + vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); + + vec4 res = deq_in_0 * deq_in_1; + vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + + uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl new file mode 100644 index 0000000000000..6bd00f33a89c0 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl @@ -0,0 +1,46 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput0; //quantized input +layout(set = 0, binding = 2) uniform PRECISION isampler3D uInput1; //quantized input +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec4 size; + ivec4 isize0; + ivec4 isize1; + vec2 in_scale; + ivec2 in_zero_point; + vec2 out_scale; + ivec2 out_zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec3 input0_pos = pos % uBlock.isize0.xyz; + const ivec3 input1_pos = pos % uBlock.isize1.xyz; + + vec4 texel0 = texelFetch(uInput0, input0_pos, 0); + vec4 texel1 = texelFetch(uInput1, input1_pos, 0); + + vec4 deq_in_0 = uBlock.in_scale.x * (texel0 - uBlock.in_zero_point.x); + vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); + + vec4 res = deq_in_0 - deq_in_1; + vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + + uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl new file mode 100644 index 0000000000000..28c167515405e --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl @@ -0,0 +1,36 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec2 isize; + vec2 scale; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = clamp( + ivec2(pos.xy * uBlock.scale), + ivec2(0), + uBlock.isize); + + vec4 texel = texelFetch(uInput, ivec3(ipos, pos.z), 0); + uvec4 ret = uvec4(int(texel.r), int(texel.g), int(texel.b), int(texel.a)); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/select_depth.glsl b/aten/src/ATen/native/vulkan/glsl/select_depth.glsl new file mode 100644 index 0000000000000..5ce7bbcebb389 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/select_depth.glsl @@ -0,0 +1,31 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; + int index; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const int tex = uBlock.index / 4; + const int ind = uBlock.index % 4; + const float v = texelFetch(uInput, ivec3(pos.x, pos.y, tex), 0)[ind]; + + imageStore( + uOutput, + ivec3(pos.x, pos.y, 0), + vec4(v, 0, 0, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/stack_feature.glsl b/aten/src/ATen/native/vulkan/glsl/stack_feature.glsl new file mode 100644 index 0000000000000..dc908ae94a4af --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/stack_feature.glsl @@ -0,0 +1,35 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; +layout(set = 0, binding = 3) uniform PRECISION sampler3D uInput2; +layout(set = 0, binding = 4) uniform PRECISION sampler3D uInput3; +layout(set = 0, binding = 5) uniform PRECISION restrict Block { + ivec3 size; + int z; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 posIn = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(posIn, uBlock.size.xyz))) { + imageStore( + uOutput, + ivec3(posIn.x, posIn.y, uBlock.z), + vec4( + texelFetch(uInput0, posIn, 0).x, + texelFetch(uInput1, posIn, 0).x, + texelFetch(uInput2, posIn, 0).x, + texelFetch(uInput3, posIn, 0).x + )); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/threshold.glsl b/aten/src/ATen/native/vulkan/glsl/threshold.glsl index 1104adf1c9119..cb5c8c27f2eef 100644 --- a/aten/src/ATen/native/vulkan/glsl/threshold.glsl +++ b/aten/src/ATen/native/vulkan/glsl/threshold.glsl @@ -20,9 +20,9 @@ void main() { if (all(lessThan(pos, uBlock.size.xyz))) { vec4 inval = texelFetch(uInput, pos, 0); - vec4 mask1 = vec4(lessThan(inval, vec4(uBlock.params.x))); - vec4 mask2 = vec4(greaterThan(inval, vec4(uBlock.params.x))); - vec4 outval = mask2 * inval + mask1 * uBlock.params.y; + vec4 mask1 = vec4(greaterThan(inval, vec4(uBlock.params.x))); + vec4 mask2 = 1.0f - mask1; + vec4 outval = mask1 * inval + mask2 * uBlock.params.y; imageStore(uOutput, pos, outval); } } diff --git a/aten/src/ATen/native/vulkan/glsl/transform_winograd_2_3_sh.glsl b/aten/src/ATen/native/vulkan/glsl/transform_winograd_2_3_sh.glsl deleted file mode 100644 index 879549494875f..0000000000000 --- a/aten/src/ATen/native/vulkan/glsl/transform_winograd_2_3_sh.glsl +++ /dev/null @@ -1,62 +0,0 @@ -#version 450 core -#define PRECISION $precision -#define FORMAT $format - -layout(std430) buffer; - -/* Qualifiers: layout - storage - precision - memory */ - -layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform PRECISION restrict Block { - ivec4 size; - ivec2 limits; - ivec2 padding; -} uBlock; - -shared vec4 i[4][4][4]; - -ivec2 off[4] = { - ivec2(0, 2), - ivec2(0, 1), - ivec2(0, -1), - ivec2(-2, 0) -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec3 tid = ivec3(gl_LocalInvocationID); - - const ivec2 ipos = (pos.xy/4) * 2 - uBlock.padding + tid.xy; - - const int shz = tid.z*16; - const int shy = tid.y*4; - const int shzy = shz + shy; - - i[tid.z][tid.y][tid.x] = texelFetch(uInput, ivec3(ipos.x, ipos.y, pos.z), 0) * - int(all(greaterThanEqual(ipos, ivec2(0,0)))) * - int(all(lessThan(ipos, uBlock.limits))); - - memoryBarrierShared(); - barrier(); - - const ivec2 ys = off[tid.y] + tid.y; - const ivec2 xs = off[tid.x] + tid.x; - - const vec4 c0 = tid.y != 1 ? i[tid.z][ys.x][xs.x] - i[tid.z][ys.y][xs.x] : - i[tid.z][ys.x][xs.x] + i[tid.z][ys.y][xs.x]; - const vec4 c1 = tid.y != 1 ? i[tid.z][ys.x][xs.y] - i[tid.z][ys.y][xs.y] : - i[tid.z][ys.x][xs.y] + i[tid.z][ys.y][xs.y]; - - vec4 outvec; - if (tid.x == 1) - outvec = c0 + c1; - else - outvec = c0 - c1; - - if (all(lessThan(pos, uBlock.size.xyz))) { - imageStore(uOutput, pos, outvec); - } -} diff --git a/aten/src/ATen/native/vulkan/ops/Arithmetic.cpp b/aten/src/ATen/native/vulkan/ops/Arithmetic.cpp index 268487e10c1c5..1551bd49b7766 100644 --- a/aten/src/ATen/native/vulkan/ops/Arithmetic.cpp +++ b/aten/src/ATen/native/vulkan/ops/Arithmetic.cpp @@ -1,6 +1,8 @@ -#include +#include #include +#include #include +#include namespace at { namespace native { @@ -8,7 +10,14 @@ namespace vulkan { namespace ops { namespace { -using namespace api::utils; +bool broadcast_input(const Tensor& input1, const Tensor& input2) { + return ((height_size(input1) > 1 && height_size(input2) == 1) || + (height_size(input2) > 1 && height_size(input1) == 1) || + (height_size(input1) == height_size(input2))) && + ((width_size(input1) > 1 && width_size(input2) == 1) || + (width_size(input2) > 1 && width_size(input1) == 1) || + (width_size(input1) == width_size(input2))); +} void check_inputs(const Tensor& input1, const Tensor& input2) { TORCH_CHECK( @@ -20,43 +29,52 @@ void check_inputs(const Tensor& input1, const Tensor& input2) { "Vulkan binary elementwise ops require channel to be a multiple of 4 to broadcast along batch dimension!") } - const uint32_t input1_h = height_size(input1); - const uint32_t input1_w = width_size(input1); - const uint32_t input2_h = height_size(input2); - const uint32_t input2_w = width_size(input2); - const std::string broadcast_error_msg = "Incompatible input dimensions for broadcasting for Vulkan binary elementwise op!"; - if (input1_h != input2_h) { - if (input1_h > input2_h) { - TORCH_CHECK(input2_h == 1, broadcast_error_msg); - TORCH_CHECK(input2_w == input1_w || input2_w == 1, broadcast_error_msg); - } else if (input2_h > input1_h) { - TORCH_CHECK(input1_h == 1, broadcast_error_msg); - TORCH_CHECK(input1_w == input2_w || input1_w == 1, broadcast_error_msg); + + TORCH_CHECK(broadcast_input(input1, input2), broadcast_error_msg); +} + +std::vector broadcast_size( + const Tensor& input1, + const Tensor& input2) { + std::vector out = {}; + int input1_size = input1.sizes().size(); + int input2_size = input2.sizes().size(); + if (input1_size > input2_size) { + for (int i = 0; i < input1_size; i++) { + out.push_back(input1.sizes()[i]); } - } else if (input1_w != input2_w) { - if (input1_w > input2_w) { - TORCH_CHECK(input2_w == 1, broadcast_error_msg); - } else if (input2_w > input1_w) { - TORCH_CHECK(input1_h == 1, broadcast_error_msg); + } else { + for (int i = 0; i < input2_size; i++) { + out.push_back(input2.sizes()[i]); + } + } + + if (width_size(input1) > 1 && width_size(input2) == 1) { + out[out.size() - 1] = width_size(input1); + } else if (width_size(input2) > 1 && width_size(input1) == 1) { + out[out.size() - 1] = width_size(input2); + } + + if (out.size() > 1) { + if (height_size(input1) > 1 && height_size(input2) == 1) { + out[out.size() - 2] = height_size(input1); + } else if (height_size(input2) > 1 && height_size(input1) == 1) { + out[out.size() - 2] = height_size(input2); } } -} -bool broadcast_first_input(const vTensor& input1, const vTensor& input2) { - return ( - (input2.extents().data[1u] > 1 && input1.extents().data[1u] == 1) || - (input2.extents().data[2u] > 1 && input1.extents().data[2u] == 1) || - input2.extents().data[0u] > input1.extents().data[0u]); + return out; } +} // namespace +using namespace api::utils; Tensor arithmetic_scalar( const Tensor& self_arg, const Scalar& other, const c10::optional& alpha_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { api::Context* const context = api::context(); const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); @@ -68,116 +86,95 @@ Tensor arithmetic_scalar( v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_output.has_image() && v_self.has_image()) { - const float other_val = alpha_arg - ? other.to() * alpha_arg->to() - : other.to(); - const struct Block final { - uvec3 extents; - float other; - } block{ - v_self.extents(), - other_val, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - adaptive_work_group_size(v_output.extents()), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, vTensor::Stage::Compute, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const float other_val = alpha_arg ? other.to() * alpha_arg->to() + : other.to(); + const struct Block final { + uvec3 extents; + float other; + } block{ + v_self.extents(), + other_val, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor& arithmetic_scalar_( - Tensor& self, + Tensor& self_arg, const Scalar& other, const c10::optional& alpha_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place operator is only supported on Vulkan tensors."); + api::Context* const context = api::context(); - TORCH_CHECK( - self.is_vulkan(), - "Vulkan: In-place add is only supported on Vulkan tensors."); - - vTensor& v_self = convert(self); - - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_self.has_image()) { - const float other_val = alpha_arg - ? other.to() * alpha_arg->to() - : other.to(); - const struct Block final { - uvec3 extents; - float other; - } block{ - v_self.extents(), - other_val, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_self.extents(), - adaptive_work_group_size(v_self.extents()), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + vTensor& v_self = convert(self_arg); + + const float other_val = alpha_arg ? other.to() * alpha_arg->to() + : other.to(); + const struct Block final { + uvec3 extents; + float other; + } block{ + v_self.extents(), + other_val, + }; - return self; + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + // params buffer + params.buffer()); + + return self_arg; } Tensor arithmetic_tensor( const Tensor& self_arg, const Tensor& other_arg, const c10::optional& alpha_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { check_inputs(self_arg, other_arg); api::Context* const context = api::context(); @@ -189,133 +186,196 @@ Tensor arithmetic_tensor( vTensor v_output{ context, - broadcast_first_input(v_self, v_other) ? v_other.sizes() : v_self.sizes(), + broadcast_size(self_arg, other_arg), v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_self.has_image() && v_other.has_image()) { - const float alpha = alpha_arg ? alpha_arg->to() : 1.0; - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input1_extents; - uint32_t fill_1; - uvec3 input2_extents; - float alpha; - } block{ - v_output.extents(), - 0u, - v_self.extents(), - 0u, - v_other.extents(), - alpha, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - adaptive_work_group_size(v_output.extents()), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, vTensor::Stage::Compute, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image(command_buffer, vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_other.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const float alpha = alpha_arg ? alpha_arg->to() : 1.0; + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input1_extents; + uint32_t fill_1; + uvec3 input2_extents; + float alpha; + } block{ + v_output.extents(), + 0u, + v_self.extents(), + 0u, + v_other.extents(), + alpha, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_other.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } +Tensor quantized_arithmetic_tensor( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point, + const api::ShaderSource& shader_descriptor) { + check_inputs(self_arg, other_arg); + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + TORCH_CHECK(v_self.is_quantized(), "Input tensor is not quantized"); + TORCH_CHECK(v_other.is_quantized(), "Input tensor is not quantized"); + + vTensor v_output{ + context, + broadcast_size(self_arg, other_arg), + self.options().dtype(c10::kQUInt8), + scale, + zero_point}; + + const double scale1 = v_self.get_scale(); + const double scale2 = v_other.get_scale(); + const int64_t zero_point1 = v_self.get_zero_point(); + const int64_t zero_point2 = v_other.get_zero_point(); + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input1_extents; + uint32_t fill_1; + uvec3 input2_extents; + uint32_t fill_2; + float scale1; + float scale2; + int32_t zero_point1; + int32_t zero_point2; + float scale; + float _1; + int32_t zero_point; + int32_t _2; + } block{ + v_output.extents(), + 0u, + v_self.extents(), + 0u, + v_other.extents(), + 0u, + safe_downcast(scale1), + safe_downcast(scale2), + safe_downcast(zero_point1), + safe_downcast(zero_point2), + safe_downcast(scale), + 0.0f, + safe_downcast(zero_point), + 0u, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_other.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert_quantized(v_output); +} + Tensor& arithmetic_tensor_( - Tensor& self, + Tensor& self_arg, const Tensor& other_arg, const c10::optional& alpha_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { - check_inputs(self, other_arg); - api::Context* const context = api::context(); + const api::ShaderSource& shader_descriptor) { + check_inputs(self_arg, other_arg); TORCH_CHECK( - self.is_vulkan(), - "Vulkan: In-place add is only supported on Vulkan tensors."); + self_arg.is_vulkan(), + "Vulkan: In-place operator is only supported on Vulkan tensors."); + + api::Context* const context = api::context(); - vTensor& v_self = convert(self); + vTensor& v_self = convert(self_arg); const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); const vTensor& v_other = convert(other); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY ( - v_self.has_image() && v_other.has_image() && !self.is_same(other)) { - const float alpha = alpha_arg ? alpha_arg->to() : 1.0; - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input_extents; - float alpha; - } block{ - v_self.extents(), - 0u, - v_other.extents(), - alpha, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_self.extents(), - adaptive_work_group_size(v_self.extents()), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_other.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const float alpha = alpha_arg ? alpha_arg->to() : 1.0; + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input_extents; + float alpha; + } block{ + v_self.extents(), + 0u, + v_other.extents(), + alpha, + }; - return self; + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + v_other.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return self_arg; } Tensor add_scalar( @@ -323,12 +383,48 @@ Tensor add_scalar( const Scalar& other, const Scalar& alpha) { return arithmetic_scalar( - self_arg, other, c10::optional(alpha), VK_KERNEL(add_scalar), "aten::add.Scalar"); + self_arg, other, c10::optional(alpha), VK_KERNEL(add_scalar)); } Tensor& add_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { return arithmetic_scalar_( - self, other, c10::optional(alpha), VK_KERNEL(add_scalar_), "aten::add_.Scalar"); + self, other, c10::optional(alpha), VK_KERNEL(add_scalar_)); +} + +Tensor quantized_add( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point) { + return quantized_arithmetic_tensor( + self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_add)); +} + +Tensor quantized_sub( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point) { + return quantized_arithmetic_tensor( + self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_sub)); +} + +Tensor quantized_mul( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point) { + return quantized_arithmetic_tensor( + self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_mul)); +} + +Tensor quantized_div( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point) { + return quantized_arithmetic_tensor( + self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_div)); } Tensor add_tensor( @@ -340,16 +436,18 @@ Tensor add_tensor( self_arg, other_arg.item(), c10::optional(alpha.to()), - VK_KERNEL(add_scalar), - "aten::add.Tensor"); + VK_KERNEL(add_scalar)); } return arithmetic_tensor( - self_arg, other_arg, c10::optional(alpha), VK_KERNEL(add), "aten::add.Tensor"); + self_arg, other_arg, c10::optional(alpha), VK_KERNEL(add)); } -Tensor& add_tensor_(Tensor& self, const Tensor& other_arg, const Scalar& alpha) { +Tensor& add_tensor_( + Tensor& self, + const Tensor& other_arg, + const Scalar& alpha) { return arithmetic_tensor_( - self, other_arg, c10::optional(alpha), VK_KERNEL(add_), "aten::add_.Tensor"); + self, other_arg, c10::optional(alpha), VK_KERNEL(add_)); } Tensor sub_scalar( @@ -360,8 +458,7 @@ Tensor sub_scalar( self_arg, other, c10::optional(-1 * alpha.to()), - VK_KERNEL(add_scalar), - "aten::sub.Scalar"); + VK_KERNEL(add_scalar)); } Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { @@ -369,8 +466,7 @@ Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { self, other, c10::optional(-1 * alpha.to()), - VK_KERNEL(add_scalar_), - "aten::sub_.Scalar"); + VK_KERNEL(add_scalar_)); } Tensor sub_tensor( @@ -382,26 +478,28 @@ Tensor sub_tensor( self_arg, other_arg.item(), c10::optional(-1 * alpha.to()), - VK_KERNEL(add_scalar), - "aten::sub.Tensor"); + VK_KERNEL(add_scalar)); } return arithmetic_tensor( - self_arg, other_arg, c10::optional(alpha), VK_KERNEL(sub), "aten::sub.Tensor"); + self_arg, other_arg, c10::optional(alpha), VK_KERNEL(sub)); } -Tensor& sub_tensor_(Tensor& self, const Tensor& other_arg, const Scalar& alpha) { +Tensor& sub_tensor_( + Tensor& self, + const Tensor& other_arg, + const Scalar& alpha) { return arithmetic_tensor_( - self, other_arg, c10::optional(alpha), VK_KERNEL(sub_), "aten::sub_.Tensor"); + self, other_arg, c10::optional(alpha), VK_KERNEL(sub_)); } Tensor mul_scalar(const Tensor& self_arg, const Scalar& other) { return arithmetic_scalar( - self_arg, other, c10::optional(), VK_KERNEL(mul_scalar), "aten::mul.Scalar"); + self_arg, other, c10::optional(), VK_KERNEL(mul_scalar)); } Tensor& mul_scalar_(Tensor& self, const Scalar& other) { return arithmetic_scalar_( - self, other, c10::optional(), VK_KERNEL(mul_scalar_), "aten::mul_.Scalar"); + self, other, c10::optional(), VK_KERNEL(mul_scalar_)); } Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { @@ -410,16 +508,15 @@ Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { self_arg, other_arg.item(), c10::optional(), - VK_KERNEL(mul_scalar), - "aten::mul.Tensor"); + VK_KERNEL(mul_scalar)); } return arithmetic_tensor( - self_arg, other_arg, c10::optional(), VK_KERNEL(mul), "aten::mul.Tensor"); + self_arg, other_arg, c10::optional(), VK_KERNEL(mul)); } Tensor& mul_tensor_(Tensor& self, const Tensor& other_arg) { return arithmetic_tensor_( - self, other_arg, c10::optional(), VK_KERNEL(mul_), "aten::mul_.Tensor"); + self, other_arg, c10::optional(), VK_KERNEL(mul_)); } Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { @@ -427,8 +524,7 @@ Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { self_arg, 1.0 / other.to(), c10::optional(), - VK_KERNEL(mul_scalar), - "aten::div.Scalar"); + VK_KERNEL(mul_scalar)); } Tensor& div_scalar_(Tensor& self, const Scalar& other) { @@ -436,8 +532,7 @@ Tensor& div_scalar_(Tensor& self, const Scalar& other) { self, 1.0 / other.to(), c10::optional(), - VK_KERNEL(mul_scalar_), - "aten::div_.Scalar"); + VK_KERNEL(mul_scalar_)); } Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { @@ -446,16 +541,15 @@ Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { self_arg, 1.0 / other_arg.item(), c10::optional(), - VK_KERNEL(mul_scalar), - "aten::div.Tensor"); + VK_KERNEL(mul_scalar)); } return arithmetic_tensor( - self_arg, other_arg, c10::optional(), VK_KERNEL(div), "aten::div.Tensor"); + self_arg, other_arg, c10::optional(), VK_KERNEL(div)); } Tensor& div_tensor_(Tensor& self, const Tensor& other_arg) { return arithmetic_tensor_( - self, other_arg, c10::optional(), VK_KERNEL(div_), "aten::div_.Tensor"); + self, other_arg, c10::optional(), VK_KERNEL(div_)); } #ifdef USE_VULKAN_API @@ -481,7 +575,6 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { #endif /* USE_VULKAN_API */ -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp new file mode 100644 index 0000000000000..30407e8cec38a --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp @@ -0,0 +1,147 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor batch_norm( + const at::Tensor& input_arg, + const c10::optional& weight_opt /* optional */, + const c10::optional& bias_opt /* optional */, + const c10::optional& running_mean_opt /* optional */, + const c10::optional& running_var_opt /* optional */, + bool training, + double /* momentum, not used in eval mode */, + double eps, + bool /* cudnn_enable, deprecated */) { + TORCH_CHECK(!training, "Vulkan batchnorm only supports evaluation mode."); + TORCH_CHECK( + weight_opt && weight_opt->defined() && bias_opt && bias_opt->defined(), + "Vulkan batchnorm expects weight and bias arguments to be defined"); + TORCH_CHECK( + running_mean_opt && running_mean_opt->defined(), + "running_mean must be defined in evaluation mode."); + TORCH_CHECK( + running_var_opt && running_var_opt->defined(), + "running_var must be defined in evaluation mode."); + TORCH_CHECK(input_arg.dim() == 4, "Vulkan batchnorm expects 4-dim input!"); + TORCH_CHECK( + channels_size(input_arg) % 4 == 0, + "Vulkan batchnorm expects channel dim to be multiple of 4!"); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); + + auto num_features = v_input.sizes()[1]; + auto channels_ext = num_features / 4; + + const Tensor weight_opt_3d = weight_opt->reshape({num_features, 1, 1}); + const Tensor weight = + weight_opt_3d.is_vulkan() ? weight_opt_3d : weight_opt_3d.vulkan(); + const vTensor& v_weight = convert(weight); + TORCH_CHECK( + weight.numel() == num_features, + "weight tensor should contain ", + num_features, + " elements!"); + + const Tensor bias_opt_3d = bias_opt->reshape({num_features, 1, 1}); + const Tensor bias = + bias_opt_3d.is_vulkan() ? bias_opt_3d : bias_opt_3d.vulkan(); + const vTensor& v_bias = convert(bias); + TORCH_CHECK( + bias.numel() == num_features, + "bias tensor should contain ", + num_features, + " elements!"); + + const Tensor running_mean_opt_3d = + running_mean_opt->reshape({num_features, 1, 1}); + const Tensor running_mean = running_mean_opt_3d.is_vulkan() + ? running_mean_opt_3d + : running_mean_opt_3d.vulkan(); + const vTensor& v_running_mean = convert(running_mean); + TORCH_CHECK( + running_mean.numel() == num_features, + "running mean tensor should contain ", + num_features, + " elements!"); + + const Tensor running_var_opt_3d = + running_var_opt->reshape({num_features, 1, 1}); + const Tensor running_var = running_var_opt_3d.is_vulkan() + ? running_var_opt_3d + : running_var_opt_3d.vulkan(); + const vTensor& v_running_var = convert(running_var); + TORCH_CHECK( + running_var.numel() == num_features, + "running var tensor should contain ", + num_features, + " elements!"); + + api::Context* const context = api::context(); + + vTensor v_output{ + context, + v_input_sizes, + v_input.options(), + }; + + const struct Block final { + uvec3 iextents; + int32_t channels_ext; + float epsilon; + } block{ + v_output.extents(), + safe_downcast(channels_ext), + safe_downcast(eps)}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(batchnorm), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_running_mean.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::batch_norm"), TORCH_FN(batch_norm)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp index 46792d68ca996..36b09362a4f53 100644 --- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -14,72 +13,55 @@ Tensor _clamp( const Tensor& self_arg, const c10::optional& min, const c10::optional& max, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { - TORCH_CHECK( - min || max, - "At least one of 'min' or 'max' must not be None"); + const api::ShaderSource& shader_descriptor) { + TORCH_CHECK(min || max, "At least one of 'min' or 'max' must not be None"); api::Context* const context = api::context(); const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); - const vTensor& v_self = convert(self); + const vTensor& v_self = convert(self_arg); vTensor v_output{ - context, - v_self.sizes(), - v_self.options(), + context, + v_self.sizes(), + v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - vec2 clamp; - } block { - v_output.extents(), - 0u, - { + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 clamp; + } block{ + v_output.extents(), + 0u, + { min ? min->to() : -std::numeric_limits::infinity(), max ? max->to() : std::numeric_limits::infinity(), - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } @@ -88,407 +70,334 @@ Tensor clamp( const Tensor& self_arg, const c10::optional& min, const c10::optional& max) { - return _clamp(self_arg, min, max, VK_KERNEL(clamp), "aten::clamp"); + return _clamp(self_arg, min, max, VK_KERNEL(clamp)); } Tensor& _clamp_( - Tensor& self, + Tensor& self_arg, const c10::optional& min, const c10::optional& max, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { - api::Context* const context = api::context(); - - TORCH_CHECK( - min || max, - "At least one of 'min' or 'max' must not be None"); + const api::ShaderSource& shader_descriptor) { + TORCH_CHECK(min || max, "At least one of 'min' or 'max' must not be None"); TORCH_CHECK( - self.is_vulkan(), + self_arg.is_vulkan(), "Vulkan: In-place clamp is only supported on Vulkan tensors."); + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); vTensor& v_self = convert(self); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - vec2 clamp; - } block { - v_self.extents(), - 0u, - { + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 clamp; + } block{ + v_self.extents(), + 0u, + { min ? min->to() : -std::numeric_limits::infinity(), max ? max->to() : std::numeric_limits::infinity(), - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_self.extents(), - context->gpu().adapter->local_work_group_size(), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); - - return self; + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + // params buffer + params.buffer()); + + return self_arg; } Tensor threshold( const Tensor& self, const Scalar& threshold, const Scalar& value) { - return _clamp(self, threshold, value, VK_KERNEL(threshold), "aten::threshold"); + return _clamp(self, threshold, value, VK_KERNEL(threshold)); } Tensor& clamp_( Tensor& self, const c10::optional& min, const c10::optional& max) { - return _clamp_(self, min, max, VK_KERNEL(clamp_), "aten::clamp_"); + return _clamp_(self, min, max, VK_KERNEL(clamp_)); } Tensor activation( const Tensor& self_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { api::Context* const context = api::context(); const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); const vTensor& v_self = convert(self); vTensor v_output{ - context, - v_self.sizes(), - v_self.options(), + context, + v_self.sizes(), + v_self.options(), + }; + + const struct Block final { + uvec3 extents; + uint32_t _; + } block{ + v_output.extents(), + 0u, }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - } block { - v_output.extents(), - 0u, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor& activation_( - Tensor& self, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { - api::Context* const context = api::context(); - + Tensor& self_arg, + const api::ShaderSource& shader_descriptor) { TORCH_CHECK( - self.is_vulkan(), + self_arg.is_vulkan(), "Vulkan: In-place operator is only supported on Vulkan tensors."); - vTensor& v_self = convert(self); + api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - } block { - v_self.extents(), - 0u, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_self.extents(), - context->gpu().adapter->local_work_group_size(), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); - - return self; + vTensor& v_self = convert(self_arg); + + const struct Block final { + uvec3 extents; + uint32_t _; + } block{ + v_self.extents(), + 0u, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + // params buffer + params.buffer()); + + return self_arg; } -Tensor hardtanh( - const Tensor& self, - const Scalar& min, - const Scalar& max) { - return ops::_clamp(self, min, max, VK_KERNEL(clamp), "aten::hardtanh"); +Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) { + return ops::_clamp(self, min, max, VK_KERNEL(clamp)); } -Tensor& hardtanh_( - Tensor& self, - const Scalar& min, - const Scalar& max) { - return ops::_clamp_(self, min, max, VK_KERNEL(clamp_), "aten::hardtanh_"); +Tensor& hardtanh_(Tensor& self, const Scalar& min, const Scalar& max) { + return ops::_clamp_(self, min, max, VK_KERNEL(clamp_)); } Tensor relu(const Tensor& self) { - return ops::_clamp(self, 0, c10::nullopt, VK_KERNEL(clamp), "aten::relu"); + return ops::_clamp(self, 0, c10::nullopt, VK_KERNEL(clamp)); } Tensor& relu_(Tensor& self) { - return ops::_clamp_(self, 0, c10::nullopt, VK_KERNEL(clamp_), "aten::relu_"); + return ops::_clamp_(self, 0, c10::nullopt, VK_KERNEL(clamp_)); } Tensor hardswish(const Tensor& self) { - return ops::activation(self, VK_KERNEL(hardswish), "aten::hardswish"); + return ops::activation(self, VK_KERNEL(hardswish)); } Tensor& hardswish_(Tensor& self) { - return ops::activation_(self, VK_KERNEL(hardswish_), "aten::hardswish_"); + return ops::activation_(self, VK_KERNEL(hardswish_)); } Tensor hardsigmoid(const Tensor& self) { - return ops::activation(self, VK_KERNEL(hardsigmoid), "aten::hardsigmoid"); + return ops::activation(self, VK_KERNEL(hardsigmoid)); } Tensor& hardsigmoid_(Tensor& self) { - return ops::activation_(self, VK_KERNEL(hardsigmoid_), "aten::hardsigmoid_"); + return ops::activation_(self, VK_KERNEL(hardsigmoid_)); } Tensor activation_scalar( const Tensor& self_arg, const Scalar& scalar_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { api::Context* const context = api::context(); const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); const vTensor& v_self = convert(self); vTensor v_output{ - context, - v_self.sizes(), - v_self.options(), + context, + v_self.sizes(), + v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - float scalar_value; - } block { - v_output.extents(), - 0u, - scalar_arg.to(), - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + uvec3 extents; + uint32_t _; + float scalar_value; + } block{ + v_output.extents(), + 0u, + scalar_arg.to(), + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor& activation_scalar_( - Tensor& self, + Tensor& self_arg, const Scalar& scalar_arg, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { - api::Context* const context = api::context(); - + const api::ShaderSource& shader_descriptor) { TORCH_CHECK( - self.is_vulkan(), + self_arg.is_vulkan(), "Vulkan: In-place operator is only supported on Vulkan tensors."); - vTensor& v_self = convert(self); + api::Context* const context = api::context(); + + vTensor& v_self = convert(self_arg); + + const struct Block final { + uvec3 extents; + uint32_t _; + float scalar_value; + } block{ + v_self.extents(), + 0u, + scalar_arg.to(), + }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - float scalar_value; - } block { - v_self.extents(), - 0u, - scalar_arg.to(), - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_self.extents(), - context->gpu().adapter->local_work_group_size(), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); - - return self; + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + // params buffer + params.buffer()); + + return self_arg; } -Tensor hardshrink( - const Tensor& self_arg, - const Scalar& lambd) { - return ops::activation_scalar(self_arg, lambd, VK_KERNEL(hardshrink), "aten::hardshrink"); +Tensor hardshrink(const Tensor& self_arg, const Scalar& lambd) { + float abs_lambd = std::abs(lambd.to()); + return ops::activation_scalar(self_arg, abs_lambd, VK_KERNEL(hardshrink)); } -Tensor& hardshrink_( - Tensor& self, - const Scalar& lambd) { - return ops::activation_scalar_(self, lambd, VK_KERNEL(hardshrink_), "aten::hardshrink_"); +Tensor& hardshrink_(Tensor& self, const Scalar& lambd) { + float abs_lambd = std::abs(lambd.to()); + return ops::activation_scalar_(self, abs_lambd, VK_KERNEL(hardshrink_)); } -Tensor leaky_relu( - const Tensor& self_arg, - const Scalar& negative_slope) { - return ops::activation_scalar(self_arg, negative_slope, VK_KERNEL(leaky_relu), "aten::leaky_relu"); +Tensor leaky_relu(const Tensor& self_arg, const Scalar& negative_slope) { + return ops::activation_scalar( + self_arg, negative_slope, VK_KERNEL(leaky_relu)); } -Tensor& leaky_relu_( - Tensor& self, - const Scalar& negative_slope) { - return ops::activation_scalar_(self, negative_slope, VK_KERNEL(leaky_relu_), "aten::leaky_relu_"); +Tensor& leaky_relu_(Tensor& self, const Scalar& negative_slope) { + return ops::activation_scalar_(self, negative_slope, VK_KERNEL(leaky_relu_)); } Tensor sigmoid(const Tensor& self) { - return ops::activation(self, VK_KERNEL(sigmoid), "aten::sigmoid"); + return ops::activation(self, VK_KERNEL(sigmoid)); } Tensor& sigmoid_(Tensor& self) { - return ops::activation_(self, VK_KERNEL(sigmoid_), "aten::sigmoid_"); + return ops::activation_(self, VK_KERNEL(sigmoid_)); } Tensor tanh(const Tensor& self) { - return ops::activation(self, VK_KERNEL(tanh), "aten::tanh"); + return ops::activation(self, VK_KERNEL(tanh)); } Tensor& tanh_(Tensor& self) { - return ops::activation_(self, VK_KERNEL(tanh_), "aten::tanh_"); + return ops::activation_(self, VK_KERNEL(tanh_)); } - #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { diff --git a/aten/src/ATen/native/vulkan/ops/Clone.cpp b/aten/src/ATen/native/vulkan/ops/Clone.cpp index 8ea7d06c2fe50..fa9f791cea0c0 100644 --- a/aten/src/ATen/native/vulkan/ops/Clone.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clone.cpp @@ -7,11 +7,12 @@ namespace vulkan { namespace ops { namespace { -Tensor clone(const Tensor& src, c10::optional optional_memory_format) { - auto memory_format = - optional_memory_format.value_or(MemoryFormat::Preserve); +Tensor clone( + const Tensor& src, + c10::optional optional_memory_format) { + auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve); TORCH_CHECK( - (c10::MemoryFormat::Preserve == memory_format) || + (c10::MemoryFormat::Preserve == memory_format) || (c10::MemoryFormat::Contiguous == memory_format), "Vulkan supports Preserve and Contiguous memory foramts"); diff --git a/aten/src/ATen/native/vulkan/ops/Common.cpp b/aten/src/ATen/native/vulkan/ops/Common.cpp index d53af6c7f45d5..9336291840967 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.cpp +++ b/aten/src/ATen/native/vulkan/ops/Common.cpp @@ -41,15 +41,15 @@ uint32_t width_size(const Tensor& tensor) { return sizes[dims - 1]; } -api::Shader::WorkGroup adaptive_work_group_size(const api::Shader::WorkGroup& global_work_group) { - api::Shader::WorkGroup local_group_size = {4, 4, 4}; +api::utils::uvec3 adaptive_work_group_size( + const api::utils::uvec3& global_work_group) { + api::utils::uvec3 local_group_size = {4, 4, 4}; if (global_work_group.data[2u] == 1) { if (global_work_group.data[1u] < 8) { local_group_size.data[0u] = 16; local_group_size.data[1u] = 4; local_group_size.data[2u] = 1; - } - else { + } else { local_group_size.data[0u] = 8; local_group_size.data[1u] = 8; local_group_size.data[2u] = 1; diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index 39a6bbfa9f642..2cb6159038bb1 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -48,7 +48,8 @@ uint32_t channels_size(const Tensor& tensor); uint32_t height_size(const Tensor& tensor); uint32_t width_size(const Tensor& tensor); -api::Shader::WorkGroup adaptive_work_group_size(const api::Shader::WorkGroup& global_work_group); +api::utils::uvec3 adaptive_work_group_size( + const api::utils::uvec3& global_work_group); } // namespace ops } // namespace vulkan diff --git a/aten/src/ATen/native/vulkan/ops/Concat.cpp b/aten/src/ATen/native/vulkan/ops/Concat.cpp index 779c36bd2a2e2..d0c2c0cf6afe6 100644 --- a/aten/src/ATen/native/vulkan/ops/Concat.cpp +++ b/aten/src/ATen/native/vulkan/ops/Concat.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include #include @@ -24,135 +22,117 @@ Tensor cat_batch(const TensorList tensors, vTensor& v_output) { Tensor cat_feature(const TensorList tensors, vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::_cat (cat_batch)"); - - int64_t ch_size_allprior = 0; - int64_t ch_interval = 0; - for (const auto& tensor : tensors) { - ch_interval += tensor.sizes()[1]; - } - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write); - - for (const auto& tensor : tensors) { - const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Compute); - - const struct Block final { - uvec3 size; // output texture size - uint32_t fill_0; // dummy - uvec3 isize; // input texture size - uint32_t fill_1; // dummy - uint32_t batch_size; // input tensor's batch size - uint32_t ch_size; // input tensor's channel size - uint32_t ch_interval; // channel interval (total # of channels for all tensors) - uint32_t ch_size_allprior; // # of channels for tensor 0 to i-1 at ith tensor - } block { - v_output.extents(), - 0u, - v_self.extents(), - 0u, - safe_downcast(v_self.sizes()[0]), - safe_downcast(v_self.sizes()[1]), - safe_downcast(ch_interval), - safe_downcast(ch_size_allprior), - }; - - ch_size_allprior += v_self.sizes()[1]; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(cat_feature), - v_self.extents(), - context->gpu().adapter->local_work_group_size(), - // Read/Write access bypasses synchronization but inserts appropriate - // barriers if necessary. - dst_image, - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - src_image, - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } + int64_t ch_size_allprior = 0; + int64_t ch_interval = 0; + for (const auto& tensor : tensors) { + ch_interval += tensor.sizes()[1]; + } + + for (const auto& tensor : tensors) { + const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan(); + const vTensor& v_self = convert(self); + + const struct Block final { + uvec3 size; // output texture size + uint32_t fill_0; // dummy + uvec3 isize; // input texture size + uint32_t fill_1; // dummy + uint32_t batch_size; // input tensor's batch size + uint32_t ch_size; // input tensor's channel size + uint32_t + ch_interval; // channel interval (total # of channels for all tensors) + uint32_t + ch_size_allprior; // # of channels for tensor 0 to i-1 at ith tensor + } block{ + v_output.extents(), + 0u, + v_self.extents(), + 0u, + safe_downcast(v_self.sizes()[0]), + safe_downcast(v_self.sizes()[1]), + safe_downcast(ch_interval), + safe_downcast(ch_size_allprior), + }; + + ch_size_allprior += v_self.sizes()[1]; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(cat_feature), + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); } - command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } Tensor cat_feature_mult4ch(const TensorList tensors, vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::_cat (cat_feature_mult4ch)"); - - int64_t depth_size_allprior = 0; - int64_t ch_interval = 0; - for (const auto& tensor : tensors) { - ch_interval += tensor.sizes()[1]; - } - const int64_t depth_interval = ch_interval / 4; - - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write); - uvec3 src_offset{}; - uvec3 dst_offset{}; - - for (const auto& tensor : tensors) { - const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Transfer); - - const uint32_t depth_slice = safe_downcast(tensor.sizes()[1] / 4); - uvec3 copy_extents {v_self.extents().data[0u], - v_self.extents().data[1u], - depth_slice}; - - for (const auto b : c10::irange(tensor.sizes()[0])) { - src_offset.data[2u] = safe_downcast(depth_slice * b); - dst_offset.data[2u] = depth_size_allprior + safe_downcast(depth_interval * b); - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, - copy_extents, - src_offset, - dst_offset); - } - - depth_size_allprior += depth_slice; - } - else { - TORCH_CHECK(false, "Not implemented!"); - } + + int64_t depth_size_allprior = 0; + int64_t ch_interval = 0; + for (const auto& tensor : tensors) { + ch_interval += tensor.sizes()[1]; + } + const int64_t depth_interval = ch_interval / 4; + + uvec3 src_offset{}; + uvec3 dst_offset{}; + + for (const auto& tensor_arg : tensors) { + const Tensor tensor = + tensor_arg.is_vulkan() ? tensor_arg : tensor_arg.vulkan(); + const vTensor& v_self = convert(tensor); + + const uint32_t depth_slice = safe_downcast(tensor.sizes()[1] / 4); + + uvec3 copy_extents{ + v_self.extents().data[0u], v_self.extents().data[1u], depth_slice}; + + for (const auto b : c10::irange(tensor.sizes()[0])) { + src_offset.data[2u] = safe_downcast(depth_slice * b); + dst_offset.data[2u] = + depth_size_allprior + safe_downcast(depth_interval * b); + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details + copy_extents, + src_offset, + dst_offset, + // fence handle + VK_NULL_HANDLE); } + + depth_size_allprior += depth_slice; } - command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } @@ -163,60 +143,48 @@ Tensor cat_width(const TensorList tensors, vTensor& v_output) { Tensor cat_height(const TensorList tensors, vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::_cat (cat_width)"); - - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write); - - uvec3 src_offset{}; - uvec3 dst_offset{}; - for (const auto& tensor : tensors) { - const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Transfer); - - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, - v_self.extents(), - src_offset, - dst_offset); - // Increment by height - dst_offset.data[1u] += v_self.extents().data[1u]; - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } + uvec3 src_offset{}; + uvec3 dst_offset{}; + + for (const auto& tensor : tensors) { + const vTensor& v_self = convert(tensor); + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details + v_self.extents(), + src_offset, + dst_offset, + // fence handle + VK_NULL_HANDLE); + + // Increment by height + dst_offset.data[1u] += v_self.extents().data[1u]; } - command_pool.submit(context->gpu().queue, command_buffer); return convert(v_output); } -Tensor cat( - const at::TensorList tensors, - const int64_t dim) { - TORCH_CHECK( - tensors.size() > 0, - "Vulkan cat expects at least one tensor"); +Tensor cat(const at::TensorList tensors, const int64_t dim) { + TORCH_CHECK(tensors.size() > 0, "Vulkan cat expects at least one tensor"); at::Tensor tensor = tensors[0]; int64_t cat_dim_size = 0; bool is_mult4ch = true; - for (const auto & t : tensors) { - TORCH_INTERNAL_ASSERT( - t.dim() == 4, "Vulkan cat expects 4 dimensional inputs"); + for (const auto& t : tensors) { + TORCH_INTERNAL_ASSERT( + t.dim() == 4, "Vulkan cat expects 4 dimensional inputs"); if (t.sizes()[1] % 4 != 0) { is_mult4ch = false; @@ -227,8 +195,8 @@ Tensor cat( continue; } TORCH_INTERNAL_ASSERT( - t.size(d) == tensor.size(d), - "Vulkan cat inputs must have matching sizes except concatenated dimension"); + t.size(d) == tensor.size(d), + "Vulkan cat inputs must have matching sizes except concatenated dimension"); } cat_dim_size += t.size(dim); } @@ -236,18 +204,14 @@ Tensor cat( auto result_size = tensor.sizes().vec(); result_size[dim] = cat_dim_size; - vTensor v_output{ - api::context(), - result_size, - tensor.options()}; + vTensor v_output{api::context(), result_size, tensor.options()}; if (dim == 3) { return cat_width(tensors, v_output); } if (dim == 2) { return cat_height(tensors, v_output); - } - else if (dim == 1) { + } else if (dim == 1) { if (is_mult4ch) { return cat_feature_mult4ch(tensors, v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index 9c800c8fd8b5d..e375a887e0e2e 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -1,9 +1,10 @@ -#include -#include #include #include + +#include #include #include +#include #include #include @@ -14,25 +15,19 @@ namespace ops { namespace { using namespace api::utils; +using namespace at::native::vulkan::ops; -struct Experimentation final { - static constexpr bool kUseWinogradConvs = false; -}; - -inline bool is_depthwise( - const IntArrayRef filter, - const int64_t groups) { +inline bool is_depthwise(const IntArrayRef filter, const int64_t groups) { return (filter[Layout::Filter::output] == groups) && - // Only K == 1 supported. - (filter[Layout::Filter::input] == 1); + // Only K == 1 supported. + (filter[Layout::Filter::input] == 1); } inline bool is_pointwise(const IntArrayRef filter) { return (1 == filter[Layout::Filter::height]) && - (1 == filter[Layout::Filter::width]); + (1 == filter[Layout::Filter::width]); } - bool all_lessthan(const IntArrayRef arr, const int t) { bool retval = true; for (const auto i : c10::irange(arr.size())) { @@ -41,16 +36,6 @@ bool all_lessthan(const IntArrayRef arr, const int t) { return retval; } -inline bool is_winograd_n_3( - const IntArrayRef filter, - const IntArrayRef stride, - const IntArrayRef dilation) { - return (3 == filter[Layout::Filter::height]) && - (3 == filter[Layout::Filter::width]) && - all_lessthan(stride, 2) && - all_lessthan(dilation, 2); -} - Conv2dMethod determine_method( const IntArrayRef filter, const IntArrayRef stride, @@ -61,15 +46,10 @@ Conv2dMethod determine_method( return Conv2dDepthwise; if (is_pointwise(filter)) return Conv2dPointwise; - if (Experimentation::kUseWinogradConvs && is_winograd_n_3(filter, stride, dilation)) - return Conv2dWinograd_2_3; return Conv2dSlidingWindow; } -vTensor pack_weights_dw( - api::Context* const context, - api::Command::Buffer& command_buffer, - const Tensor& weight) { +vTensor pack_weights_dw(api::Context* const context, const Tensor& weight) { /* Source */ const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); @@ -77,8 +57,10 @@ vTensor pack_weights_dw( const int64_t src_kw_sz = src_filter[Layout::Filter::width]; const int64_t src_kh_sz = src_filter[Layout::Filter::height]; const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; - const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t num_stacks = + div_up(src_filter[Layout::Filter::output], INT64_C(4)); /* Destination */ const int64_t dst_kw_sz = src_kernel_sz; @@ -95,40 +77,41 @@ vTensor pack_weights_dw( weight.options(), }; - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(command_buffer); - Future::Payload v_weight_payload = v_weight_future.wait(); + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_weight_ptr = mapping.template data(); - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); + memset(dst_weight_ptr, 0, v_weight.nbytes()); - for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { - /* Source */ - const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { + /* Source */ + const float* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; - /* Destination */ - const int64_t dst_oh = src_oc / 4; - const int64_t dst_c = src_oc % 4; + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; - float* const dst_weight_c_ptr = dst_weight_ptr + - dst_c * dst_kernel_sz + - dst_oh * dst_kw_sz; + float* const dst_weight_c_ptr = + dst_weight_ptr + dst_c * dst_kernel_sz + dst_oh * dst_kw_sz; - for (const auto src_ih : c10::irange(src_filter[Layout::Filter::height])) { - memcpy( - dst_weight_c_ptr + src_ih * src_kw_sz, - src_weight_oc_ptr + src_ih * src_kw_sz, - sizeof(float) * src_kw_sz); + for (const auto src_ih : + c10::irange(src_filter[Layout::Filter::height])) { + memcpy( + dst_weight_c_ptr + src_ih * src_kw_sz, + src_weight_oc_ptr + src_ih * src_kw_sz, + sizeof(float) * src_kw_sz); + } } } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_weight); return v_weight; } -vTensor pack_weights_2d( - api::Context* const context, - api::Command::Buffer& command_buffer, - const Tensor& weight) { +vTensor pack_weights_2d(api::Context* const context, const Tensor& weight) { /* Source */ const IntArrayRef src_filter = weight.sizes(); const float* const src_weight_ptr = weight.data_ptr(); @@ -136,10 +119,13 @@ vTensor pack_weights_2d( const int64_t src_kw_sz = src_filter[Layout::Filter::width]; const int64_t src_kh_sz = src_filter[Layout::Filter::height]; const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; - const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); - const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); + const int64_t num_stacks = + div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = + api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); /* Destination */ const int64_t dst_kw_sz = src_kw_sz * stack_depth; @@ -156,212 +142,107 @@ vTensor pack_weights_2d( weight.options(), }; - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(command_buffer); - Future::Payload v_weight_payload = v_weight_future.wait(); + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_weight_ptr = mapping.template data(); - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); + memset(dst_weight_ptr, 0, v_weight.nbytes()); - for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { - /* Source */ - const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { + /* Source */ + const float* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; - /* Destination */ - const int64_t dst_oh = src_oc / 4; - const int64_t dst_c = src_oc % 4; + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; - float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; + float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - for (const auto src_ic : c10::irange(src_filter[Layout::Filter::input])) { - const int64_t dst_ic4 = src_ic / 4; + for (const auto src_ic : c10::irange(src_filter[Layout::Filter::input])) { + const int64_t dst_ic4 = src_ic / 4; - for (const auto src_ih : c10::irange(src_kh_sz)) { - for (const auto src_iw : c10::irange(src_kw_sz)) { - memcpy( - dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + - dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, - src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, - sizeof(float)); + for (const auto src_ih : c10::irange(src_kh_sz)) { + for (const auto src_iw : c10::irange(src_kw_sz)) { + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + + dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, + src_weight_oc_ptr + src_ic * src_kernel_sz + + src_ih * src_kw_sz + src_iw, + sizeof(float)); + } } } } } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_weight); return v_weight; } -vTensor pack_weights_2d_winograd_2_3( - api::Context* const context, - api::Command::Buffer& command_buffer, - const Tensor& weight) { - /* Source */ - const IntArrayRef src_filter = weight.sizes(); - - TORCH_CHECK( - src_filter[Layout::Filter::width] == 3 && src_filter[Layout::Filter::height] == 3, - "Kernel size must be 3x3 for Winograd(2x2, 3x3)!"); - const int64_t src_ic_sz = src_filter[Layout::Filter::input]; - const int64_t src_oc_sz = src_filter[Layout::Filter::output]; - - /* Destination */ - const int64_t dst_ow_sz = div_up(src_ic_sz, INT64_C(4)); - const int64_t dst_oh_sz = div_up(src_oc_sz, INT64_C(4)); - const int64_t dst_kw_sz = 16*dst_ow_sz; - const int64_t dst_kh_sz = 4*dst_oh_sz; - const int64_t dst_block_sz = dst_kw_sz * dst_kh_sz; - - vTensor v_weight{ - context, - { - 4, - 4*dst_oh_sz, - 16*dst_ow_sz, - }, - weight.options(), - }; - - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(command_buffer); - Future::Payload v_weight_payload = v_weight_future.wait(); - - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); - - for (const auto src_oc : c10::irange(src_oc_sz)) { - const int64_t dst_oh = src_oc / 4; - const int64_t dst_iw = src_oc % 4; - - for (const auto src_ic : c10::irange(src_ic_sz)) { - const int64_t dst_ow = src_ic / 4; - const int64_t dst_c = src_ic % 4; - - //const float* const src_k_ptr = src_weight_ptr + src_oc * src_block_sz + src_ic * 9; - float* const dst_k = dst_weight_ptr + dst_c * dst_block_sz; - - const float s00 = weight[src_oc][src_ic][0][0].item(); - const float s01 = weight[src_oc][src_ic][0][1].item(); - const float s02 = weight[src_oc][src_ic][0][2].item(); - const float s10 = weight[src_oc][src_ic][1][0].item(); - const float s11 = weight[src_oc][src_ic][1][1].item(); - const float s12 = weight[src_oc][src_ic][1][2].item(); - const float s20 = weight[src_oc][src_ic][2][0].item(); - const float s21 = weight[src_oc][src_ic][2][1].item(); - const float s22 = weight[src_oc][src_ic][2][2].item(); - - const float m00 = s00; - const float m01 = s01; - const float m02 = s02; - const float m10 = (s00 + s10 + s20)/2.f; - const float m11 = (s01 + s11 + s21)/2.f; - const float m12 = (s02 + s12 + s22)/2.f; - const float m20 = (s00 - s10 + s20)/2.f; - const float m21 = (s01 - s11 + s21)/2.f; - const float m22 = (s02 - s12 + s22)/2.f; - const float m30 = s20; - const float m31 = s21; - const float m32 = s22; - - dst_k[(4*dst_oh + 0)*dst_kw_sz + 0*dst_ow_sz + 4*dst_ow + dst_iw] = m00; - dst_k[(4*dst_oh + 0)*dst_kw_sz + 4*dst_ow_sz + 4*dst_ow + dst_iw] = (m00 + m01 + m02)/2.f; - dst_k[(4*dst_oh + 0)*dst_kw_sz + 8*dst_ow_sz + 4*dst_ow + dst_iw] = (m00 - m01 + m02)/2.f; - dst_k[(4*dst_oh + 0)*dst_kw_sz + 12*dst_ow_sz + 4*dst_ow + dst_iw] = m02; - dst_k[(4*dst_oh + 1)*dst_kw_sz + 0*dst_ow_sz + 4*dst_ow + dst_iw] = m10; - dst_k[(4*dst_oh + 1)*dst_kw_sz + 4*dst_ow_sz + 4*dst_ow + dst_iw] = (m10 + m11 + m12)/2.f; - dst_k[(4*dst_oh + 1)*dst_kw_sz + 8*dst_ow_sz + 4*dst_ow + dst_iw] = (m10 - m11 + m12)/2.f; - dst_k[(4*dst_oh + 1)*dst_kw_sz + 12*dst_ow_sz + 4*dst_ow + dst_iw] = m12; - dst_k[(4*dst_oh + 2)*dst_kw_sz + 0*dst_ow_sz + 4*dst_ow + dst_iw] = m20; - dst_k[(4*dst_oh + 2)*dst_kw_sz + 4*dst_ow_sz + 4*dst_ow + dst_iw] = (m20 + m21 + m22)/2.f; - dst_k[(4*dst_oh + 2)*dst_kw_sz + 8*dst_ow_sz + 4*dst_ow + dst_iw] = (m20 - m21 + m22)/2.f; - dst_k[(4*dst_oh + 2)*dst_kw_sz + 12*dst_ow_sz + 4*dst_ow + dst_iw] = m22; - dst_k[(4*dst_oh + 3)*dst_kw_sz + 0*dst_ow_sz + 4*dst_ow + dst_iw] = m30; - dst_k[(4*dst_oh + 3)*dst_kw_sz + 4*dst_ow_sz + 4*dst_ow + dst_iw] = (m30 + m31 + m32)/2.f; - dst_k[(4*dst_oh + 3)*dst_kw_sz + 8*dst_ow_sz + 4*dst_ow + dst_iw] = (m30 - m31 + m32)/2.f; - dst_k[(4*dst_oh + 3)*dst_kw_sz + 12*dst_ow_sz + 4*dst_ow + dst_iw] = m32; - } - } - - return v_weight; -} - -vTensor pack_weights( - const Tensor& weight_arg, - const Conv2dMethod conv_method) { +vTensor pack_weights(const Tensor& weight_arg, const Conv2dMethod conv_method) { if (weight_arg.is_vulkan()) { return convert(weight_arg); } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything const Tensor weight = weight_arg.contiguous(); if (conv_method == Conv2dDepthwise) { - return pack_weights_dw( - context, - command_buffer, - weight); - } - - if (conv_method == Conv2dWinograd_2_3) { - return pack_weights_2d_winograd_2_3( - context, - command_buffer, - weight); + return pack_weights_dw(context, weight); } - return pack_weights_2d( - context, - command_buffer, - weight); + return pack_weights_2d(context, weight); } -vTensor pack_biases( - const c10::optional& bias, - const Tensor& weight) { +vTensor pack_biases(const c10::optional& bias, const Tensor& weight) { if (bias && bias->is_vulkan()) { return convert(*bias); } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything const int64_t src_w = weight.size(Layout::Filter::output); const int64_t packed_w = div_up(src_w, INT64_C(4)); vTensor v_bias{ - context, - { - 4, - 1, - packed_w, - }, - weight.options(), + context, + { + 4, + 1, + packed_w, + }, + weight.options(), }; - using Future = vTensor::Future; - Future v_bias_future = v_bias.host(command_buffer); - Future::Payload v_bias_payload = v_bias_future.wait(); + api::StagingBuffer staging(context, v_bias.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_bias_ptr = mapping.template data(); - if (bias) { - const float* const src_bias_ptr = bias->contiguous().data_ptr(); - float* const dst_bias_ptr = v_bias_payload.get(); + if (bias) { + const float* const src_bias_ptr = bias->contiguous().data_ptr(); - memset(dst_bias_ptr, 0, v_bias.nbytes()); - for (const auto i : c10::irange(src_w)) { - const int64_t c = i % 4; - const int64_t x = i / 4; - dst_bias_ptr[c * packed_w + x] = src_bias_ptr[i]; + memset(dst_bias_ptr, 0, v_bias.nbytes()); + for (const auto i : c10::irange(src_w)) { + const int64_t c = i % 4; + const int64_t x = i / 4; + dst_bias_ptr[c * packed_w + x] = src_bias_ptr[i]; + } + } else { + memset( + dst_bias_ptr, + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); } } - else { - memset( - v_bias_payload.get(), - // 2's complement integers and IEEE-754 floating point numbers both - // have identical bit representations for 0, so can use memset which - // only accepts uint8_t parameter. - 0, - v_bias.nbytes()); - } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_bias); return v_bias; } @@ -376,14 +257,12 @@ std::array pack_filter( }; return { - align_up(filter[Layout::Filter::output], INT64_C(4)), - align_up(filter[Layout::Filter::input], INT64_C(4)), - effective( - filter[Layout::Filter::height], - dilation[Layout::Parameter::height]), - effective( - filter[Layout::Filter::width], - dilation[Layout::Parameter::width]), + align_up(filter[Layout::Filter::output], INT64_C(4)), + align_up(filter[Layout::Filter::input], INT64_C(4)), + effective( + filter[Layout::Filter::height], dilation[Layout::Parameter::height]), + effective( + filter[Layout::Filter::width], dilation[Layout::Parameter::width]), }; } @@ -391,8 +270,8 @@ std::array pack_params(const std::vector& vector) { TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!"); return { - vector[0], - vector[1], + vector[0], + vector[1], }; } @@ -408,56 +287,54 @@ bool available( const c10::optional& output_min, const c10::optional& output_max) { return api::available() && - // Weight - (4 == weight.ndimension()) && - (weight.size(Layout::Filter::height) > 0) && - (weight.size(Layout::Filter::width) > 0) && - ((weight.device().is_cpu()) || - (c10::DeviceType::Vulkan == weight.device().type())) && - (kFloat == weight.scalar_type()) && - // Bias - ((bias && bias->defined()) ? ((1 == bias->ndimension()) && - ((bias->device().is_cpu()) || - (c10::DeviceType::Vulkan == bias->device().type())) && - (kFloat == bias->scalar_type()) && - (transposed ? false /* to be addded in the future */ - : (weight.size(Layout::Filter::output) == - bias->size(Layout::Filter::output)))) - : true) && - // Stride - (stride[Layout::Parameter::height] > 0) && - (stride[Layout::Parameter::width] > 0) && - // Padding - (padding[Layout::Parameter::height] >= 0) && - (padding[Layout::Parameter::width] >= 0) && - // Dilation - (dilation[Layout::Parameter::height] > 0) && - (dilation[Layout::Parameter::width] > 0) && - // Groups - (groups > 0) && - // Input - (weight.size(Layout::Filter::input) > 0) && - // Output - (weight.size(Layout::Filter::output) > 0) && - // Output - Groups - ((weight.size(Layout::Filter::output) % groups) == 0) && - // Output Min / Max - (!output_min || output_min->isFloatingPoint()) && - (!output_max || output_max->isFloatingPoint()) && - true; + // Weight + (4 == weight.ndimension()) && (weight.size(Layout::Filter::height) > 0) && + (weight.size(Layout::Filter::width) > 0) && + ((weight.device().is_cpu()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type()) && + // Bias + ((bias && bias->defined()) + ? ((1 == bias->ndimension()) && + ((bias->device().is_cpu()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type()) && + (transposed ? false /* to be addded in the future */ + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) + : true) && + // Stride + (stride[Layout::Parameter::height] > 0) && + (stride[Layout::Parameter::width] > 0) && + // Padding + (padding[Layout::Parameter::height] >= 0) && + (padding[Layout::Parameter::width] >= 0) && + // Dilation + (dilation[Layout::Parameter::height] > 0) && + (dilation[Layout::Parameter::width] > 0) && + // Groups + (groups > 0) && + // Input + (weight.size(Layout::Filter::input) > 0) && + // Output + (weight.size(Layout::Filter::output) > 0) && + // Output - Groups + ((weight.size(Layout::Filter::output) % groups) == 0) && + // Output Min / Max + (!output_min || output_min->isFloatingPoint()) && + (!output_max || output_max->isFloatingPoint()) && true; } bool usable(const Tensor& input) { - // Input + // Input return (4 == input.ndimension()) && - (c10::DeviceType::Vulkan == input.device().type()) && - (kFloat == input.scalar_type()) && - (input.size(Layout::Activation4D::batch) >= 0) && - (input.size(Layout::Activation4D::channels) > 0) && - (input.size(Layout::Activation4D::height) > 0) && - (input.size(Layout::Activation4D::width) > 0) && - !input.requires_grad() && - true; + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Activation4D::batch) >= 0) && + (input.size(Layout::Activation4D::channels) > 0) && + (input.size(Layout::Activation4D::height) > 0) && + (input.size(Layout::Activation4D::width) > 0) && !input.requires_grad() && + true; } } // namespace @@ -495,12 +372,8 @@ VulkanOpContext conv2d_context_create( "transposed, output_padding, output_min, output_max) parameters are either " "invalid individually or their combination is not supported by Vulkan impl."); - const auto method = determine_method( - weight.sizes(), - stride, - padding, - dilation, - groups); + const auto method = + determine_method(weight.sizes(), stride, padding, dilation, groups); c10::impl::GenericList packed_context{c10::AnyType::get()}; packed_context.reserve(10); @@ -512,8 +385,12 @@ VulkanOpContext conv2d_context_create( packed_context.emplace_back(output_padding); packed_context.emplace_back(pack_params(dilation)); packed_context.emplace_back(safe_downcast(groups)); - packed_context.emplace_back(output_min ? output_min->template to() : -std::numeric_limits::infinity()); - packed_context.emplace_back(output_max ? output_max->template to() : +std::numeric_limits::infinity()); + packed_context.emplace_back( + output_min ? output_min->template to() + : -std::numeric_limits::infinity()); + packed_context.emplace_back( + output_max ? output_max->template to() + : +std::numeric_limits::infinity()); packed_context.emplace_back(method); c10::impl::GenericList unpacked_context{c10::AnyType::get()}; @@ -534,7 +411,7 @@ VulkanOpContext conv2d_context_create( } void conv2d_sliding_window( - const api::Shader::Descriptor& shader, + const api::ShaderSource& shader, vTensor& v_output, const vTensor& v_input, const vTensor& packed_v_weight, @@ -546,228 +423,84 @@ void conv2d_sliding_window( const float packed_output_min, const float packed_output_max, const IntArrayRef unpacked_filter, - const Conv2dMethod method_, - const std::string& op_name) { - bool valid = C10_LIKELY(v_output.has_image() && v_input.has_image() && packed_v_weight.has_image()); - TORCH_CHECK(valid, "Not Implemented!") - + const Conv2dMethod method_) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - const struct Block final { - uvec3 extents; - int32_t ic4; - ivec4 kernel; - ivec2 ikernel; - ivec2 stride; - ivec2 padding; - ivec2 dilate; - vec2 clamp; - ivec4 src_filter; - } block { + + const struct Block final { + uvec3 extents; + int32_t ic4; + ivec4 kernel; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec4 src_filter; + } block{ v_output.extents(), safe_downcast(packed_filter[Layout::Filter::input]), { - safe_downcast(packed_filter[Layout::Filter::width]), - safe_downcast(packed_filter[Layout::Filter::height]), - safe_downcast(v_input.sizes()[Layout::Activation4D::width]), - safe_downcast(v_input.sizes()[Layout::Activation4D::height]), + safe_downcast(packed_filter[Layout::Filter::width]), + safe_downcast(packed_filter[Layout::Filter::height]), + safe_downcast(v_input.sizes()[Layout::Activation4D::width]), + safe_downcast(v_input.sizes()[Layout::Activation4D::height]), }, { - safe_downcast(unpacked_filter[Layout::Filter::width]), - safe_downcast(unpacked_filter[Layout::Filter::height]), + safe_downcast(unpacked_filter[Layout::Filter::width]), + safe_downcast(unpacked_filter[Layout::Filter::height]), }, { - safe_downcast(packed_stride[Layout::Parameter::width]), - safe_downcast(packed_stride[Layout::Parameter::height]), + safe_downcast(packed_stride[Layout::Parameter::width]), + safe_downcast(packed_stride[Layout::Parameter::height]), }, { - safe_downcast(packed_padding[Layout::Parameter::width]), - safe_downcast(packed_padding[Layout::Parameter::height]), + safe_downcast(packed_padding[Layout::Parameter::width]), + safe_downcast(packed_padding[Layout::Parameter::height]), }, { - safe_downcast(packed_dilation[Layout::Parameter::width]), - safe_downcast(packed_dilation[Layout::Parameter::height]), + safe_downcast(packed_dilation[Layout::Parameter::width]), + safe_downcast(packed_dilation[Layout::Parameter::height]), }, { - packed_output_min, - packed_output_max, + packed_output_min, + packed_output_max, }, - }; - - uvec3 global_size = v_output.extents(); - if (method_ == Conv2dPointwise) { - global_size = { - safe_downcast(div_up(v_output.sizes()[Layout::Filter::width], INT64_C(2))), - safe_downcast(div_up(v_output.sizes()[Layout::Filter::height], INT64_C(2))), - v_output.extents().data[2u] - }; - } + }; - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader, - global_size, - adaptive_work_group_size(global_size), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_bias.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); + uvec3 global_size = v_output.extents(); + if (method_ == Conv2dPointwise) { + global_size = { + safe_downcast( + div_up(v_output.sizes()[Layout::Filter::width], INT64_C(2))), + safe_downcast( + div_up(v_output.sizes()[Layout::Filter::height], INT64_C(2))), + v_output.extents().data[2u]}; } - command_pool.submit(context->gpu().queue, command_buffer); -} - -void conv2d_winograd_2_3( - vTensor& v_output, - const vTensor& v_input, - const vTensor& packed_v_weight, - const vTensor& packed_v_bias, - const IntArrayRef packed_filter, - const IntArrayRef packed_padding, - const float packed_output_min, - const float packed_output_max) { - // Winograd(2x2, 3x3) calculates 2x2 tile of output for every subprogram - const int64_t out_h_units = div_up(v_output.sizes()[Layout::Activation4D::height], INT64_C(2)); - const int64_t out_w_units = div_up(v_output.sizes()[Layout::Activation4D::width], INT64_C(2)); - - bool valid = C10_LIKELY(v_output.has_image() && v_input.has_image() && packed_v_weight.has_image()); - TORCH_CHECK(valid, "Not Implemented!") - - api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "prepacked::conv2d_clamp_run (conv2d_winograd_2_3)"); - - vTensor v_input_winograd{ - context, - { - v_input.sizes()[Layout::Activation4D::batch], - v_input.sizes()[Layout::Activation4D::channels], - out_h_units*4, - out_w_units*4, - }, - v_output.options(), - }; - - { - const struct TransformBlock final { - uvec3 extents; - uint32_t fill; - ivec2 limits; - ivec2 padding; - } transform_block { - v_input_winograd.extents(), - 0u, - { - safe_downcast(v_input.sizes()[Layout::Activation4D::width]), - safe_downcast(v_input.sizes()[Layout::Activation4D::height]), - }, - { - safe_downcast(packed_padding[Layout::Parameter::width]), - safe_downcast(packed_padding[Layout::Parameter::height]), - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(transform_winograd_2_3_sh), - v_input_winograd.extents(), - adaptive_work_group_size(v_input_winograd.extents()), - v_input_winograd.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - v_input.image( - command_buffer, - vTensor::Stage::Compute), - context->resource().pool.uniform(transform_block).object); - } - { - const struct Block final { - uvec3 extents; - int32_t ic4; - vec2 clamp; - } block { - v_output.extents(), - safe_downcast(packed_filter[Layout::Filter::input] / 4), - { - packed_output_min, - packed_output_max, - }, - }; - - uvec3 global_size = { - safe_downcast(out_w_units), - safe_downcast(out_h_units), - v_output.extents().data[2u], - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(conv2d_winograd_2_3), - global_size, - adaptive_work_group_size(global_size), - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - v_input_winograd.image( - command_buffer, - vTensor::Stage::Compute), - packed_v_weight.image( - command_buffer, - vTensor::Stage::Compute), - packed_v_bias.buffer( - command_buffer, - vTensor::Stage::Compute), - context->resource().pool.uniform(block).object); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader, + // pipeline barrier + pipeline_barrier, + // global work group size + global_size, + // local work group size + adaptive_work_group_size(global_size), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); } Tensor conv2d_context_run( @@ -797,77 +530,64 @@ Tensor conv2d_context_run( "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl."); vTensor v_output{ - context, - conv_output_size( - v_input.sizes(), - unpacked_filter, - packed_padding, - packed_stride, - packed_dilation), - input.options(), + context, + conv_output_size( + v_input.sizes(), + unpacked_filter, + packed_padding, + packed_stride, + packed_dilation), + input.options(), }; - switch(method_) { - case Conv2dWinograd_2_3: - conv2d_winograd_2_3( - v_output, - v_input, - packed_v_weight, - packed_v_bias, - packed_filter, - packed_padding, - packed_output_min, - packed_output_max); + switch (method_) { case Conv2dDepthwise: conv2d_sliding_window( - VK_KERNEL(conv2d_dw), - v_output, - v_input, - packed_v_weight, - packed_v_bias, - packed_filter, - packed_stride, - packed_padding, - packed_dilation, - packed_output_min, - packed_output_max, - unpacked_filter, - method_, - "prepacked::conv2d_clamp_run (conv2d_sliding_window::conv2d_dw)"); + VK_KERNEL(conv2d_dw), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_); break; case Conv2dPointwise: conv2d_sliding_window( - VK_KERNEL(conv2d_pw_2x2), - v_output, - v_input, - packed_v_weight, - packed_v_bias, - packed_filter, - packed_stride, - packed_padding, - packed_dilation, - packed_output_min, - packed_output_max, - unpacked_filter, - method_, - "prepacked::conv2d_clamp_run (conv2d_sliding_window::conv2d_pw_2x2)"); + VK_KERNEL(conv2d_pw_2x2), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_); break; default: conv2d_sliding_window( - VK_KERNEL(conv2d), - v_output, - v_input, - packed_v_weight, - packed_v_bias, - packed_filter, - packed_stride, - packed_padding, - packed_dilation, - packed_output_min, - packed_output_max, - unpacked_filter, - method_, - "prepacked::conv2d_clamp_run (conv2d_sliding_window::conv2d)"); + VK_KERNEL(conv2d), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_); break; } @@ -883,33 +603,29 @@ c10::intrusive_ptr create_conv2d_clamp_context( const int64_t groups, const c10::optional& output_min, const c10::optional& output_max) { - return c10::make_intrusive( - conv2d_context_create( - weight, - bias, - stride, - padding, - dilation, - /* transposed = */ false, - /* output_padding_arg = */ {}, - groups, - output_min, - output_max)); + return c10::make_intrusive(conv2d_context_create( + weight, + bias, + stride, + padding, + dilation, + /* transposed = */ false, + /* output_padding_arg = */ {}, + groups, + output_min, + output_max)); } Tensor run_conv2d_clamp_context( const Tensor& input, const c10::intrusive_ptr& vulkan_context) { return conv2d_context_run( - input, - vulkan_context->get_packed(), - vulkan_context->get_unpacked()); + input, vulkan_context->get_packed(), vulkan_context->get_unpacked()); } /* Backwards compatibility */ Conv2dOpContext::Conv2dOpContext(VulkanOpContext vulkan_context) - : vulkan_context_{std::move(vulkan_context)} { -} + : vulkan_context_{std::move(vulkan_context)} {} Conv2dOpContext Conv2dOpContext::create( const Tensor& weight, @@ -922,51 +638,50 @@ Conv2dOpContext Conv2dOpContext::create( const int64_t groups, const c10::optional& output_min, const c10::optional& output_max) { - return Conv2dOpContext { - conv2d_context_create( - weight, - bias, - stride_arg, - padding_arg, - dilation_arg, - transposed, - output_padding_arg, - groups, - output_min, - output_max) - }; + return Conv2dOpContext{conv2d_context_create( + weight, + bias, + stride_arg, + padding_arg, + dilation_arg, + transposed, + output_padding_arg, + groups, + output_min, + output_max)}; } Tensor Conv2dOpContext::run(const Tensor& input_arg) const { return conv2d_context_run( - input_arg, - vulkan_context_.get_packed(), - vulkan_context_.get_unpacked()); + input_arg, vulkan_context_.get_packed(), vulkan_context_.get_unpacked()); } Conv2dOpContext::State Conv2dOpContext::unpack() const { - const c10::impl::GenericList unpacked_ = std::get<1>(vulkan_context_.get_state()); + const c10::impl::GenericList unpacked_ = + std::get<1>(vulkan_context_.get_state()); const Tensor unpacked_weight = unpacked_.get(0).toTensor(); - const c10::optional unpacked_bias - = unpacked_.get(1).isTensor() ? unpacked_.get(1).toTensor() : (c10::optional&) c10::nullopt; + const c10::optional unpacked_bias = unpacked_.get(1).isTensor() + ? unpacked_.get(1).toTensor() + : (c10::optional&)c10::nullopt; const std::vector unpacked_stride = unpacked_.get(2).toIntVector(); const std::vector unpacked_padding = unpacked_.get(3).toIntVector(); const std::vector unpacked_dilation = unpacked_.get(4).toIntVector(); const int64_t unpacked_groups = unpacked_.get(5).toInt(); - const c10::optional unpacked_output_min - = unpacked_.get(6).isScalar() ? unpacked_.get(6).toScalar() : (c10::optional) c10::nullopt; - const c10::optional unpacked_output_max - = unpacked_.get(6).isScalar() ? unpacked_.get(7).toScalar() : (c10::optional) c10::nullopt; + const c10::optional unpacked_output_min = unpacked_.get(6).isScalar() + ? unpacked_.get(6).toScalar() + : (c10::optional)c10::nullopt; + const c10::optional unpacked_output_max = unpacked_.get(6).isScalar() + ? unpacked_.get(7).toScalar() + : (c10::optional)c10::nullopt; return Conv2dOpContext::State{ - unpacked_weight, - unpacked_bias, - unpacked_stride, - unpacked_padding, - unpacked_dilation, - unpacked_groups, - unpacked_output_min, - unpacked_output_max - }; + unpacked_weight, + unpacked_bias, + unpacked_stride, + unpacked_padding, + unpacked_dilation, + unpacked_groups, + unpacked_output_min, + unpacked_output_max}; } c10::intrusive_ptr conv2d_clamp_prepack( @@ -978,18 +693,17 @@ c10::intrusive_ptr conv2d_clamp_prepack( const int64_t groups, const c10::optional& output_min, const c10::optional& output_max) { - return c10::make_intrusive( - Conv2dOpContext::create( - std::move(weight), - std::move(bias), - std::move(stride), - std::move(padding), - std::move(dilation), - /* transposed = */ false, - /* output_padding = */ {}, - groups, - output_min, - output_max)); + return c10::make_intrusive(Conv2dOpContext::create( + std::move(weight), + std::move(bias), + std::move(stride), + std::move(padding), + std::move(dilation), + /* transposed = */ false, + /* output_padding = */ {}, + groups, + output_min, + output_max)); } Tensor conv2d_clamp_run( diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.h b/aten/src/ATen/native/vulkan/ops/Convolution.h index 90fd5282dccca..69680a4b167b4 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.h +++ b/aten/src/ATen/native/vulkan/ops/Convolution.h @@ -14,7 +14,6 @@ enum Conv2dMethod { Conv2dDepthwise, Conv2dPointwise, Conv2dSlidingWindow, - Conv2dWinograd_2_3, }; // private: diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index b7fbea07d9e60..fb4db712a8ad9 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -1,169 +1,147 @@ -#include #include +#include namespace at { namespace native { namespace vulkan { namespace ops { -Tensor& copy_(Tensor& self, const Tensor& src) { +void copy_vulkan_to_vulkan(vTensor& src, vTensor& dst) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + src.image(pipeline_barrier, api::PipelineStage::TRANSFER), + dst.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details + src.extents(), + {0u, 0u, 0u}, + {0u, 0u, 0u}, + // fence handle + VK_NULL_HANDLE); +} + +void copy_cpu_to_vulkan(const Tensor& src, vTensor& dst) { + api::Context* const context = api::context(); + + api::StagingBuffer staging(context, dst.buffer_bytes()); { - // X -> Vulkan - if (at::kVulkan == self.device().type()) { - vTensor& v_self = convert(self); - - // Vulkan -> Vulkan - if (at::kVulkan == src.device().type()) { - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "copy_"); - - command_buffer.copy( - // - Read-only access is implied on const tensors. Memory barriers - // are automatically inserted if a RAW hazard is detected. - // - Recording any potential pending sync operations into the same - // command buffer prevents an expensive queue submission. - convert(src).buffer( - command_buffer, - vTensor::Stage::Transfer), - // - Write-only access never triggers a sync as the contents will be - // overwritten regardless. Having said that, appropriate barriers - // are inserted automatically if WAR or WAW hazards are detected. - // - Recording pending sync operations into the same command buffer - // prevents an expensive queue submission. - v_self.buffer( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write)); - } - command_pool.submit(context->gpu().queue, command_buffer); - } - // CPU -> Vulkan - else { - api::Command::Buffer& command_buffer = command_pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything - const Tensor cpu_src = src.device().is_cpu() ? src : src.cpu(); - - // Requesting write-only host access to the tensor never triggers a sync - // as the contents will be overwritten regardless. Having said that, - // appropriate barriers are inserted automatically if WAR or WAW hazards - // are detected. Examples of such scenario for instance are if any of - // these async operations are on going in the background on 'self': - // - On discrete systems: - // * buffer-to-staging transfers - // * staging-to-buffer transfers - // - On UMA buffer is an alias for staging and accessible both on host - // and device. Consequently: - // * buffer-to-image NHWC -> NC4HW packing - // * image-to-buffer NC4HW -> NHWC unpacking - - using Future = vTensor::Future; - Future v_self_future = v_self.host(command_buffer); - - // Ideally we would have been able to put as much distance between - // requesting the data - a call to host() - and accessing the data - // - a call to wait() - but a local view of the computation graph - // in eager mode makes that optimization non-trivial. - - // This wait() will be a no-op if no hazards are detected, including the - // obvious, yet important, special case of 'self' being an empty tensor. - - Future::Payload v_self_payload = v_self_future.wait(); - - memcpy( - v_self_payload.get(), - cpu_src.contiguous().data_ptr(), - std::min(src.nbytes(), self.nbytes())); - } + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + if (src.dtype() == c10::kQUInt8) { + c10::quint8* data_ptr = mapping.template data(); + memcpy( + data_ptr, + src.contiguous().data_ptr(), + std::min(src.nbytes(), src.nbytes())); + } else { + float* data_ptr = mapping.template data(); + memcpy( + data_ptr, + src.contiguous().data_ptr(), + std::min(src.nbytes(), src.nbytes())); } - // Vulkan -> X - else if (at::kVulkan == src.device().type()) { - api::Command::Buffer& command_buffer = command_pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything - const vTensor& v_src = convert(src); - - // Vulkan -> CPU - if (self.device().is_cpu()) { - // Similar notes as above applies, with the additional consideration of - // potential syncs on read accesses. Namely, - // - on discrete systems, if the (staging, buffer, image) trio, or - // - on UMA, if the (buffer, image) duo - // have gone out of sync as a result of one processor writing to one - // resource which is then either accessed as an another resource type on - // the same or another processor. Same considerations regarding hazard - // avoidance as above applies. - - using Future = vTensor::Future; - const Future v_src_future = v_src.host(command_buffer); - - // Ideally we would have been able to put as much distance between - // requesting the data - a call to host() - and accessing the data - // - a call to wait() - but a local view of the computation graph - // in eager mode makes that optimization non-trivial. - - // This wait() is a no-op if data is not out of sync. More often than - // not though, waits here are expected as the GPU catches up with - // compute submitted from CPU. - - const Future::Payload v_src_payload = v_src_future.wait(); - - memcpy( - self.data_ptr(), - v_src_payload.get(), - std::min(src.nbytes(), self.nbytes())); - } - else { - TORCH_CHECK(false, "Unsupported!"); - } - - // - // WARNING - // - - // This is not great. We almost never want to flush the GPU pipeline as - // that has far reaching consequences, especially if PyTorch is not the only - // process accessing the GPU. If we have done our job properly, above - // synchronization mechanisms should be enough to ensure correctness at a more - // modest cost, as there is no need to flush the entirety of jobs in flight - // if one is only interested on waiting on computation affecting one single - // tensor to finish. - // - // Having said that, we still do need to release all pool resources at one - // point per inference run or we will run out of memory otherwise. There is - // no perfect answer to this problem that checks all boxes, which leaves us - // with one of several design decisions: - // - // 1) Use graph mode to gain an understanding of the computation graph, - // itself allowing us to place pool purges intelligently. Best option - // for performance and memory consumption. Not without its downsides if - // flexibility is a top priority. - // 2) If on eager mode, and hence are seeing operations one at a time, expose - // this release of resources to the user as a Python / C++ function. This - // makes for suboptimal user experience but is efficient in terms of - // performance. - // 3) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, release all resources somewhere ... like here. This is - // not ideal since it requires a pipeline flush to make sure these objects - // are not already in use by a workload in flight. Cannot do much better - // within the constraints of this approach. Good for user experience, - // suboptimal for performance. - // 4) If on eager mode, and interested in keeping this bookkeeping transparent - // to the user, and performance does not matter, make CPU and GPU run in - // lockstep. Obviously this is just bad. Mentioned for the sake of - // completeness. - - context->flush(); + } + utils::pack_staging_to_vtensor(staging.buffer(), dst); +} + +void copy_vulkan_to_cpu(vTensor& src, Tensor& dst) { + api::Context* const context = api::context(); + + api::StagingBuffer staging(context, src.buffer_bytes()); + + api::VulkanFence fence = context->fences().get_fence(); + + { + // Refer to comment in submit_compute_job. When syncing with the GPU, the + // context must not allow other threads to record dispatches into it between + // between calling vkQueueSubmit and flushing the context. Therefore, + // cmd_mutex_ must be manually managed by the calling thread. + std::unique_lock context_lock(context->dispatch_lock()); + + utils::pack_vtensor_to_staging( + src, staging.buffer(), fence.get_submit_handle()); + + fence.wait(); + + context->flush(); + // cmd_mutex_ will be released when exiting this scope. + } + + // Copy data from buffer back to CPU tensor. + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ); + mapping.invalidate(); + + if (dst.is_quantized()) { + c10::quint8* data_ptr = mapping.template data(); + memcpy( + dst.data_ptr(), + data_ptr, + std::min(src.nbytes(), dst.nbytes())); + } else { + float* data_ptr = mapping.template data(); + memcpy( + dst.data_ptr(), + data_ptr, + std::min(src.nbytes(), dst.nbytes())); } + } + + context->fences().return_fence(fence); +} + +Tensor& copy_(Tensor& self, const Tensor& src) { + // Check that sizes are equal + TORCH_CHECK( + self.sizes() == src.sizes(), + "Vulkan copy_: Tensor sizes are mismatched!"); + + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // Vulkan -> Vulkan + if (at::kVulkan == src.device().type()) { + vTensor& v_src = convert(src); + copy_vulkan_to_vulkan(v_src, v_self); + } + // CPU -> Vulkan else { - TORCH_INTERNAL_ASSERT( - false, - "Invalid code path taken! Either the source or the destination tensor " - "was expected to be Vulkan a tensor! Incorrect dispatch?"); + TORCH_CHECK( + src.dtype() == c10::kQUInt8 || src.dtype() == at::kFloat, + "Invalid Data Type: expected QUint8 or Float but got ", + src.dtype()); + copy_cpu_to_vulkan(src, v_self); + } + } + // Vulkan -> X + else if (at::kVulkan == src.device().type()) { + vTensor& v_src = convert(src); + + // Vulkan -> CPU + if (self.device().is_cpu()) { + TORCH_CHECK( + self.dtype() == c10::kQUInt8 || self.dtype() == at::kFloat, + "Invalid Data Type: expected QUint8 or Float but got ", + self.dtype()); + copy_vulkan_to_cpu(v_src, self); + } else { + TORCH_CHECK(false, "Unsupported!"); } + } else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid code path taken! Either the source or the destination tensor " + "was expected to be Vulkan a tensor! Incorrect dispatch?"); } - // No queue submission here. All queue submissions must have been handled - // above either explicitly or as a result of calling tensor.host(). return self; } diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp index a8dbc2aaa50bb..06d44ec061935 100644 --- a/aten/src/ATen/native/vulkan/ops/Factory.cpp +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -1,11 +1,32 @@ -#include +#include #include namespace at { namespace native { namespace vulkan { namespace ops { -namespace { + +Tensor _empty_affine_quantized( + const IntArrayRef sizes, + const c10::optional dtype, + const c10::optional layout, + const c10::optional device, + const c10::optional pin_memory, + const double scale, + const int64_t zero_point, + const optional memory_format) { + return convert_quantized(vTensor{ + api::context(), + sizes, + TensorOptions() + .dtype(dtype) + .layout(layout) + .device(device) + .pinned_memory(pin_memory) + .memory_format(memory_format), + scale, + zero_point}); +} Tensor empty_memory_format( const IntArrayRef sizes, @@ -23,7 +44,7 @@ Tensor empty_memory_format( .device(device) .pinned_memory(pin_memory) .memory_format(memory_format), - }); + }); } Tensor empty_strided( @@ -34,24 +55,25 @@ Tensor empty_strided( const optional device, const optional pin_memory) { return empty_memory_format( - sizes, - dtype, - layout, - device, - pin_memory, - c10::MemoryFormat::Contiguous); + sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous); } #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { - m.impl(TORCH_SELECTIVE_NAME("aten::empty.memory_format"), at::native::vulkan::ops::empty_memory_format); - m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(at::native::vulkan::ops::empty_strided)); + m.impl( + TORCH_SELECTIVE_NAME("aten::empty.memory_format"), + at::native::vulkan::ops::empty_memory_format); + m.impl( + TORCH_SELECTIVE_NAME("aten::_empty_affine_quantized"), + at::native::vulkan::ops::_empty_affine_quantized); + m.impl( + TORCH_SELECTIVE_NAME("aten::empty_strided"), + TORCH_FN(at::native::vulkan::ops::empty_strided)); } #endif /* USE_VULKAN_API */ -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Factory.h b/aten/src/ATen/native/vulkan/ops/Factory.h new file mode 100644 index 0000000000000..9dee6307bb85c --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Factory.h @@ -0,0 +1,21 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor _empty_affine_quantized( + const IntArrayRef sizes, + const c10::optional dtype, + const c10::optional layout, + const c10::optional device, + const c10::optional pin_memory, + const double scale, + const int64_t zero_point, + const optional memory_format); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Glu.cpp b/aten/src/ATen/native/vulkan/ops/Glu.cpp new file mode 100644 index 0000000000000..1778813bce57b --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Glu.cpp @@ -0,0 +1,80 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor glu(const at::Tensor& input_arg, const int64_t dim = -1) { + TORCH_CHECK(input_arg.dim() == 4, "Vulkan glu only supports 4-dim input!"); + TORCH_CHECK( + dim == 1, + "Vulkan glu only supports GLU for dim = 1, but got dim = ", + dim); + TORCH_CHECK( + channels_size(input_arg) % 2 == 0, + "Vulkan glu expects channel dim to be multiple of 2!"); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); + + auto output_ch_size = v_input.sizes()[1] / 2; + + api::Context* const context = api::context(); + + vTensor v_output{ + context, + {v_input_sizes[0], output_ch_size, v_input_sizes[2], v_input_sizes[3]}, + v_input.options(), + }; + + const struct Block final { + uvec3 extents; + int32_t chext; + } block{v_output.extents(), safe_downcast(output_ch_size)}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + output_ch_size % 4 == 0 ? VK_KERNEL(glu_channel_mul4) + : VK_KERNEL(glu_channel), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::glu"), TORCH_FN(glu)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Gru.cpp b/aten/src/ATen/native/vulkan/ops/Gru.cpp index 9a5a508b363ce..e29c6b59fd9fb 100644 --- a/aten/src/ATen/native/vulkan/ops/Gru.cpp +++ b/aten/src/ATen/native/vulkan/ops/Gru.cpp @@ -10,10 +10,11 @@ namespace ops { namespace { // // input_vk: input tensor of shape (L, N, H_in) when batch_first=False -// (N, L, H_in) when batch_first=True containing the features of the input sequence -// hx_vk: initial hidden state for each element in the batch. tensor of shape (D * num_layers, N, H_out) -// output: tensor of shape (N, L, D * H_out)) when batch_first=True -// h_n: tensor of shape (D * num_layers, N, H_out) +// (N, L, H_in) when batch_first=True containing +// the features of the input sequence +// hx_vk: initial hidden state for each element in the batch. tensor of shape (D +// * num_layers, N, H_out) output: tensor of shape (N, L, D * H_out)) when +// batch_first=True h_n: tensor of shape (D * num_layers, N, H_out) // // where // L = sequence length @@ -23,30 +24,40 @@ namespace { // H_out = hidden_size (# of features in the hidden state h) // std::tuple gru_input( - const Tensor & input_vk, // input sequence (vulkan) - const Tensor & hx_vk, // initial hidden state (vulkan) - TensorList params_cpu, // weights/biases (cpu) - bool has_biases, - int64_t num_layers, - double dropout, - bool train, - bool bidirectional, - bool batch_first) { - TORCH_CHECK(static_cast(params_cpu.size()) == 4 * num_layers, - "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."); - TORCH_INTERNAL_ASSERT(input_vk.sizes().size() == 3, "Vulkan gru expects 'input_vk' dims to be 3."); - TORCH_INTERNAL_ASSERT(hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3."); - TORCH_INTERNAL_ASSERT(has_biases, "Vulkan gru expects 'has_biases' to be true."); + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) + TensorList params_cpu, // weights/biases (cpu) + bool has_biases, + int64_t num_layers, + double dropout, + bool train, + bool bidirectional, + bool batch_first) { + TORCH_CHECK( + static_cast(params_cpu.size()) == 4 * num_layers, + "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."); + TORCH_INTERNAL_ASSERT( + input_vk.sizes().size() == 3, + "Vulkan gru expects 'input_vk' dims to be 3."); + TORCH_INTERNAL_ASSERT( + hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3."); + TORCH_INTERNAL_ASSERT( + has_biases, "Vulkan gru expects 'has_biases' to be true."); TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false."); - TORCH_INTERNAL_ASSERT(!bidirectional, "Vulkan gru expects 'bidirectional' to be false."); - TORCH_INTERNAL_ASSERT(batch_first, "Vulkan gru expects 'batch_first' to be true."); - TORCH_INTERNAL_ASSERT(dropout < std::numeric_limits::epsilon()*1000, "Vulkan gru expects 'dropout' to be 0.0."); + TORCH_INTERNAL_ASSERT( + !bidirectional, "Vulkan gru expects 'bidirectional' to be false."); + TORCH_INTERNAL_ASSERT( + batch_first, "Vulkan gru expects 'batch_first' to be true."); + TORCH_INTERNAL_ASSERT( + dropout < std::numeric_limits::epsilon() * 1000, + "Vulkan gru expects 'dropout' to be 0.0."); const auto hidden_size = hx_vk.size(2); - std::vector h_n_list; // hidden output + std::vector h_n_list; // hidden output // reshape to 2D due to Vulkan at::mm op accepts only 2D - auto x = input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); + auto x = + input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); for (int64_t i = 0; i < num_layers; ++i) { // extract each hidden state and squeeze into 2D dim @@ -58,30 +69,34 @@ std::tuple gru_input( const auto& b_ih = params_cpu[i * 4 + 2]; const auto& b_hh = params_cpu[i * 4 + 3]; - const auto& w_i_rzn = w_ih.split(hidden_size); - const auto& w_h_rzn = w_hh.split(hidden_size); - const auto& b_i_rzn = b_ih.split(hidden_size); - const auto& b_h_rzn = b_hh.split(hidden_size); - - const auto& w_ir = w_i_rzn[0]; - const auto& w_iz = w_i_rzn[1]; - const auto& w_in = w_i_rzn[2]; - const auto& w_hr = w_h_rzn[0]; - const auto& w_hz = w_h_rzn[1]; - const auto& w_hn = w_h_rzn[2]; - const auto& b_ir = b_i_rzn[0]; - const auto& b_iz = b_i_rzn[1]; - const auto& b_in = b_i_rzn[2]; - const auto& b_hr = b_h_rzn[0]; - const auto& b_hz = b_h_rzn[1]; - const auto& b_hn = b_h_rzn[2]; - - const auto& r = at::sigmoid(at::addmm(b_ir, x, w_ir.t()) + at::addmm(b_hr, h, w_hr.t())); - const auto& z = at::sigmoid(at::addmm(b_iz, x, w_iz.t()) + at::addmm(b_hz, h, w_hz.t())); - const auto& n = at::tanh(at::addmm(b_in, x, w_in.t()) + r * (at::addmm(b_hn, h, w_hn.t()))); + const auto& w_i_rzn = w_ih.split(hidden_size); + const auto& w_h_rzn = w_hh.split(hidden_size); + const auto& b_i_rzn = b_ih.split(hidden_size); + const auto& b_h_rzn = b_hh.split(hidden_size); + + const auto& w_ir = w_i_rzn[0]; + const auto& w_iz = w_i_rzn[1]; + const auto& w_in = w_i_rzn[2]; + const auto& w_hr = w_h_rzn[0]; + const auto& w_hz = w_h_rzn[1]; + const auto& w_hn = w_h_rzn[2]; + const auto& b_ir = b_i_rzn[0]; + const auto& b_iz = b_i_rzn[1]; + const auto& b_in = b_i_rzn[2]; + const auto& b_hr = b_h_rzn[0]; + const auto& b_hz = b_h_rzn[1]; + const auto& b_hn = b_h_rzn[2]; + + const auto& r = at::sigmoid( + at::addmm(b_ir, x, w_ir.t()) + at::addmm(b_hr, h, w_hr.t())); + const auto& z = at::sigmoid( + at::addmm(b_iz, x, w_iz.t()) + at::addmm(b_hz, h, w_hz.t())); + const auto& n = at::tanh( + at::addmm(b_in, x, w_in.t()) + r * (at::addmm(b_hn, h, w_hn.t()))); h = (z * (-1) + 1) * n + z * h; - x = h; // next input - h_n_list.emplace_back(h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op + x = h; // next input + h_n_list.emplace_back( + h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op } auto h_n = at::cat(h_n_list, 1); @@ -102,8 +117,9 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { std::vector> pack_linear_op_contexts( const std::vector& params_cpu, int64_t num_layers) { - TORCH_CHECK(static_cast(params_cpu.size()) == 4 * num_layers, - "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."); + TORCH_CHECK( + static_cast(params_cpu.size()) == 4 * num_layers, + "Vulkan gru expects 'params_cpu' size to be 4 * 'num_layers'."); std::vector> linear_op_contexts; for (int64_t i = 0; i < num_layers; ++i) { const auto& w_ih = params_cpu.at(i * 4); @@ -112,23 +128,23 @@ std::vector> pack_linear_op_contexts( const auto& b_hh = params_cpu.at(i * 4 + 3); const auto& hidden_size = w_ih.size(0) / 3; - const auto& w_i_rzn = w_ih.split(hidden_size); - const auto& w_h_rzn = w_hh.split(hidden_size); - const auto& b_i_rzn = b_ih.split(hidden_size); - const auto& b_h_rzn = b_hh.split(hidden_size); - - const auto& w_ir = w_i_rzn[0]; - const auto& w_iz = w_i_rzn[1]; - const auto& w_in = w_i_rzn[2]; - const auto& w_hr = w_h_rzn[0]; - const auto& w_hz = w_h_rzn[1]; - const auto& w_hn = w_h_rzn[2]; - const auto& b_ir = b_i_rzn[0]; - const auto& b_iz = b_i_rzn[1]; - const auto& b_in = b_i_rzn[2]; - const auto& b_hr = b_h_rzn[0]; - const auto& b_hz = b_h_rzn[1]; - const auto& b_hn = b_h_rzn[2]; + const auto& w_i_rzn = w_ih.split(hidden_size); + const auto& w_h_rzn = w_hh.split(hidden_size); + const auto& b_i_rzn = b_ih.split(hidden_size); + const auto& b_h_rzn = b_hh.split(hidden_size); + + const auto& w_ir = w_i_rzn[0]; + const auto& w_iz = w_i_rzn[1]; + const auto& w_in = w_i_rzn[2]; + const auto& w_hr = w_h_rzn[0]; + const auto& w_hz = w_h_rzn[1]; + const auto& w_hn = w_h_rzn[2]; + const auto& b_ir = b_i_rzn[0]; + const auto& b_iz = b_i_rzn[1]; + const auto& b_in = b_i_rzn[2]; + const auto& b_hr = b_h_rzn[0]; + const auto& b_hz = b_h_rzn[1]; + const auto& b_hn = b_h_rzn[2]; linear_op_contexts.emplace_back(create_linear_context(w_ir.t(), b_ir)); linear_op_contexts.emplace_back(create_linear_context(w_hr.t(), b_hr)); @@ -148,12 +164,16 @@ VulkanOpContext gru_context_create( bool train, bool bidirectional, bool batch_first) { - - TORCH_INTERNAL_ASSERT(has_biases, "Vulkan gru expects 'has_biases' to be true."); + TORCH_INTERNAL_ASSERT( + has_biases, "Vulkan gru expects 'has_biases' to be true."); TORCH_INTERNAL_ASSERT(!train, "Vulkan gru expects 'train' to be false."); - TORCH_INTERNAL_ASSERT(!bidirectional, "Vulkan gru expects 'bidirectional' to be false."); - TORCH_INTERNAL_ASSERT(batch_first, "Vulkan gru expects 'batch_first' to be true."); - TORCH_INTERNAL_ASSERT(dropout < std::numeric_limits::epsilon()*1000, "Vulkan gru expects 'dropout' to be 0.0."); + TORCH_INTERNAL_ASSERT( + !bidirectional, "Vulkan gru expects 'bidirectional' to be false."); + TORCH_INTERNAL_ASSERT( + batch_first, "Vulkan gru expects 'batch_first' to be true."); + TORCH_INTERNAL_ASSERT( + dropout < std::numeric_limits::epsilon() * 1000, + "Vulkan gru expects 'dropout' to be 0.0."); c10::impl::GenericList packed_context{c10::AnyType::get()}; packed_context.reserve(7); @@ -179,46 +199,73 @@ VulkanOpContext gru_context_create( } std::tuple gru_context_run( - const Tensor & input_vk, // input sequence (vulkan) - const Tensor & hx_vk, // initial hidden state (vulkan) + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) const c10::impl::GenericList& packed_context, const c10::impl::GenericList& unpacked_context) { - TORCH_INTERNAL_ASSERT(input_vk.sizes().size() == 3, "Vulkan gru expects 'input_vk' dims to be 3."); - TORCH_INTERNAL_ASSERT(hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3."); - - const c10::List packed_linear_op_contexts = packed_context.get(0).toList(); + TORCH_INTERNAL_ASSERT( + input_vk.sizes().size() == 3, + "Vulkan gru expects 'input_vk' dims to be 3."); + TORCH_INTERNAL_ASSERT( + hx_vk.sizes().size() == 3, "Vulkan gru expects 'hx_vk' dims to be 3."); + + const c10::List packed_linear_op_contexts = + packed_context.get(0).toList(); const int64_t packed_num_layers = packed_context.get(2).toInt(); - const int64_t linear_op_contexts_per_layer = 6; // (b_ir, w_ir), (b_hr, w_hr), (b_iz, w_iz), (b_hz, w_hz), (b_in, w_in), (b_hn, w_hn) - std::vector h_n_list; // hidden output + const int64_t linear_op_contexts_per_layer = + 6; // (b_ir, w_ir), (b_hr, w_hr), (b_iz, w_iz), (b_hz, w_hz), (b_in, + // w_in), (b_hn, w_hn) + std::vector h_n_list; // hidden output // reshape to 2D due to Vulkan at::mm op accepts only 2D - auto x = input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); + auto x = + input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); for (int64_t i = 0; i < packed_num_layers; ++i) { // extract each hidden state and squeeze into 2D dim auto h = at::slice(hx_vk, 0, i, i + 1, 1); h = h.reshape({h.size(0) * h.size(1), h.size(2)}); - const auto& cxt_ir = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 0].toCustomClass(); - const auto& cxt_hr = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 1].toCustomClass(); - const auto& cxt_iz = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 2].toCustomClass(); - const auto& cxt_hz = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 3].toCustomClass(); - const auto& cxt_in = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 4].toCustomClass(); - const auto& cxt_hn = packed_linear_op_contexts[i * linear_op_contexts_per_layer + 5].toCustomClass(); - - const auto& r = at::sigmoid( - linear_context_run(x, cxt_ir->get_packed(), cxt_ir->get_unpacked(), 1.0f, 1.0f, "aten::addmm") - + linear_context_run(h, cxt_hr->get_packed(), cxt_hr->get_unpacked(), 1.0f, 1.0f, "aten::addmm")); - const auto& z = at::sigmoid( - linear_context_run(x, cxt_iz->get_packed(), cxt_iz->get_unpacked(), 1.0f, 1.0f, "aten::addmm") - + linear_context_run(h, cxt_hz->get_packed(), cxt_hz->get_unpacked(), 1.0f, 1.0f, "aten::addmm")); - const auto& n = at::tanh( - linear_context_run(x, cxt_in->get_packed(), cxt_in->get_unpacked(), 1.0f, 1.0f, "aten::addmm") - + r * (linear_context_run(h, cxt_hn->get_packed(), cxt_hn->get_unpacked(), 1.0f, 1.0f, "aten::addmm"))); + const auto& cxt_ir = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 0] + .toCustomClass(); + const auto& cxt_hr = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 1] + .toCustomClass(); + const auto& cxt_iz = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 2] + .toCustomClass(); + const auto& cxt_hz = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 3] + .toCustomClass(); + const auto& cxt_in = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 4] + .toCustomClass(); + const auto& cxt_hn = + packed_linear_op_contexts[i * linear_op_contexts_per_layer + 5] + .toCustomClass(); + + const auto& r = at::sigmoid( + linear_context_run( + x, cxt_ir->get_packed(), cxt_ir->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_hr->get_packed(), cxt_hr->get_unpacked(), 1.0f, 1.0f)); + const auto& z = at::sigmoid( + linear_context_run( + x, cxt_iz->get_packed(), cxt_iz->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_hz->get_packed(), cxt_hz->get_unpacked(), 1.0f, 1.0f)); + const auto& n = at::tanh( + linear_context_run( + x, cxt_in->get_packed(), cxt_in->get_unpacked(), 1.0f, 1.0f) + + r * + (linear_context_run( + h, cxt_hn->get_packed(), cxt_hn->get_unpacked(), 1.0f, 1.0f))); h = (z * (-1) + 1) * n + z * h; - x = h; // next input - h_n_list.emplace_back(h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op + x = h; // next input + h_n_list.emplace_back( + h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op } auto h_n = at::cat(h_n_list, 1); @@ -235,7 +282,13 @@ c10::intrusive_ptr create_gru_context( bool bidirectional, bool batch_first) { return c10::make_intrusive(gru_context_create( - params_cpu, has_biases, num_layers, dropout, train, bidirectional, batch_first)); + params_cpu, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first)); } std::tuple run_gru_context( @@ -243,16 +296,15 @@ std::tuple run_gru_context( const Tensor& hx_vk, const c10::intrusive_ptr& vulkan_context) { return gru_context_run( - input_vk, - hx_vk, - vulkan_context->get_packed(), - vulkan_context->get_unpacked()); + input_vk, + hx_vk, + vulkan_context->get_packed(), + vulkan_context->get_unpacked()); } /* Backwards compatibility */ GruOpContext::GruOpContext(VulkanOpContext vulkan_context) - : vulkan_context_{std::move(vulkan_context)} { -} + : vulkan_context_{std::move(vulkan_context)} {} GruOpContext GruOpContext::create( const std::vector& params_cpu, // weights/biases (cpu) @@ -262,31 +314,31 @@ GruOpContext GruOpContext::create( bool train, bool bidirectional, bool batch_first) { - return GruOpContext { - gru_context_create( - params_cpu, - has_biases, - num_layers, - dropout, - train, - bidirectional, - batch_first) - }; + return GruOpContext{gru_context_create( + params_cpu, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first)}; } std::tuple GruOpContext::run( - const Tensor & input_vk, // input sequence (vulkan) - const Tensor & hx_vk) const { // initial hidden state (vulkan) + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk) const { // initial hidden state (vulkan) return gru_context_run( - input_vk, - hx_vk, - vulkan_context_.get_packed(), - vulkan_context_.get_unpacked()); + input_vk, + hx_vk, + vulkan_context_.get_packed(), + vulkan_context_.get_unpacked()); } GruOpContext::State GruOpContext::unpack() const { - const c10::impl::GenericList unpacked_ = std::get<1>(vulkan_context_.get_state()); - const std::vector unpacked_params_cpu = unpacked_.get(0).toTensorVector(); + const c10::impl::GenericList unpacked_ = + std::get<1>(vulkan_context_.get_state()); + const std::vector unpacked_params_cpu = + unpacked_.get(0).toTensorVector(); const bool unpacked_has_biases = unpacked_.get(1).toBool(); const int64_t unpacked_num_layers = unpacked_.get(2).toInt(); const double unpacked_dropout = unpacked_.get(3).toDouble(); @@ -294,13 +346,13 @@ GruOpContext::State GruOpContext::unpack() const { const bool unpacked_bidirectional = unpacked_.get(5).toBool(); const bool unpacked_batch_first = unpacked_.get(6).toBool(); return GruOpContext::State{ - unpacked_params_cpu, - unpacked_has_biases, - unpacked_num_layers, - unpacked_dropout, - unpacked_train, - unpacked_bidirectional, - unpacked_batch_first, + unpacked_params_cpu, + unpacked_has_biases, + unpacked_num_layers, + unpacked_dropout, + unpacked_train, + unpacked_bidirectional, + unpacked_batch_first, }; } @@ -313,7 +365,13 @@ c10::intrusive_ptr gru_prepack( bool bidirectional, bool batch_first) { return c10::make_intrusive(GruOpContext::create( - params_cpu, has_biases, num_layers, dropout, train, bidirectional, batch_first)); + params_cpu, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first)); } std::tuple gru_run( diff --git a/aten/src/ATen/native/vulkan/ops/Gru.h b/aten/src/ATen/native/vulkan/ops/Gru.h index 14993cfafb89b..304ce822a0e9a 100644 --- a/aten/src/ATen/native/vulkan/ops/Gru.h +++ b/aten/src/ATen/native/vulkan/ops/Gru.h @@ -12,9 +12,31 @@ namespace vulkan { namespace ops { // packed -// std::vector> linear_op_contexts; // {{ op context for b_ir, w_ir, op context for b_hr, w_hr, -// // op context for b_iz, w_iz, op context for b_hz, w_hz, -// // op context for b_in, w_in, op context for b_hn, w_hn,}, ...} +// std::vector> linear_op_contexts; // +// {{ op context for b_ir, w_ir, op context for b_hr, w_hr, +// // +// op +// context +// for +// b_iz, +// w_iz, +// op +// context +// for +// b_hz, +// w_hz, +// // +// op +// context +// for +// b_in, +// w_in, +// op +// context +// for +// b_hn, +// w_hn,}, +// ...} // bool has_biases{}; // int64_t num_layers{}; // double dropout{}; @@ -41,13 +63,13 @@ VulkanOpContext gru_context_create( bool batch_first); std::tuple gru_context_run( - const Tensor & input_vk, // input sequence (vulkan) - const Tensor & hx_vk, // initial hidden state (vulkan) + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) const c10::impl::GenericList& packed_context, const c10::impl::GenericList& unpacked_context); c10::intrusive_ptr create_gru_context( - std::vector&& params_cpu, // weights/biases (cpu) + std::vector&& params_cpu, // weights/biases (cpu) bool has_biases, int64_t num_layers, double dropout, @@ -57,7 +79,7 @@ c10::intrusive_ptr create_gru_context( std::tuple run_gru_context( const Tensor& input_vk, - const Tensor & hx_vk, + const Tensor& hx_vk, const c10::intrusive_ptr& vulkan_context); // Backwards compatibility @@ -72,11 +94,11 @@ class GruOpContext final : public torch::jit::CustomClassHolder { bool bidirectional, bool batch_first); - using State = std::tuple, bool, int64_t, double, bool, bool, bool>; + using State = + std::tuple, bool, int64_t, double, bool, bool, bool>; - std::tuple run( - const Tensor& input_vk, - const Tensor & hx_vk) const; + std::tuple run(const Tensor& input_vk, const Tensor& hx_vk) + const; State unpack() const; private: @@ -85,7 +107,7 @@ class GruOpContext final : public torch::jit::CustomClassHolder { }; c10::intrusive_ptr gru_prepack( - std::vector&& params_cpu, // weights/biases (cpu) + std::vector&& params_cpu, // weights/biases (cpu) bool has_biases, int64_t num_layers, double dropout, @@ -95,7 +117,7 @@ c10::intrusive_ptr gru_prepack( std::tuple gru_run( const Tensor& input_vk, - const Tensor & hx_vk, + const Tensor& hx_vk, const c10::intrusive_ptr& context); } // namespace ops diff --git a/aten/src/ATen/native/vulkan/ops/Layernorm.cpp b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp new file mode 100644 index 0000000000000..8c34f71890f71 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp @@ -0,0 +1,153 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +void _check_layer_norm_inputs( + const at::Tensor& input, + IntArrayRef normalized_shape, + const c10::optional& weight /* optional */, + const c10::optional& bias /* optional */) { + const auto normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight->defined() || weight->sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight->sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias->defined() || bias->sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias->sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + const auto input_ndim = input.sizes().size(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } +} + +Tensor layer_norm( + const at::Tensor& input_arg, + IntArrayRef normalized_shape, + const c10::optional& weight_opt /* optional */, + const c10::optional& bias_opt /* optional */, + double eps, + bool /* cudnn_enable, deprecated */) { + _check_layer_norm_inputs(input_arg, normalized_shape, weight_opt, bias_opt); + + TORCH_CHECK( + input_arg.dim() == 3 || input_arg.dim() == 4, + "Vulkan layernorm expects 3-dim or 4-dim input!"); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); + + TORCH_CHECK( + input_arg.dim() == 3 || v_input_sizes[Layout::Activation4D::batch] == 1, + "Vulkan layernorm expects batch dim == 1 when the input is 4-dimensional!"); + + TORCH_CHECK( + normalized_shape.size() == 3, + "Vulkan layernorm expects normalized_shape to have length 3, i.e. [C, H, W]"); + + TORCH_CHECK( + weight_opt->defined() && bias_opt->defined(), + "Vulkan layernorm expects weight and bias arguments"); + + const auto volume = + c10::multiply_integers(v_input_sizes.cbegin(), v_input_sizes.end()); + + const Tensor weight = + weight_opt->is_vulkan() ? *weight_opt : weight_opt->vulkan(); + const vTensor& v_weight = convert(weight); + + const Tensor bias = bias_opt->is_vulkan() ? *bias_opt : bias_opt->vulkan(); + const vTensor& v_bias = convert(bias); + + api::Context* const context = api::context(); + + vTensor v_output{ + context, + v_input_sizes, + v_input.options(), + }; + + const struct Block final { + uvec3 iextents; + int32_t volume; + int32_t last_texel_end_offset; + float epsilon; + } block{ + v_input.extents(), + safe_downcast(volume), + safe_downcast((v_input_sizes[input_arg.dim() - 3] - 1) % 4), + safe_downcast(eps)}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(layernorm), + // pipeline barrier + pipeline_barrier, + // global work group size + v_input.extents(), + // local work group size + adaptive_work_group_size(v_input.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::layer_norm"), TORCH_FN(layer_norm)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Lerp.cpp b/aten/src/ATen/native/vulkan/ops/Lerp.cpp index 4a1a351919c5f..67240f64b2ccd 100644 --- a/aten/src/ATen/native/vulkan/ops/Lerp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Lerp.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -47,8 +46,7 @@ void check_inputs_elementwise_op(const Tensor& input1, const Tensor& input2) { Tensor _lerp_scalar( const Tensor& start_arg, const Tensor& end_arg, - const Scalar& weight_arg, - const std::string& op_name) { + const Scalar& weight_arg) { check_inputs_elementwise_op(start_arg, end_arg); api::Context* const context = api::context(); @@ -59,142 +57,118 @@ Tensor _lerp_scalar( const vTensor& v_end = convert(end); vTensor v_output{ - context, - v_start.sizes(), - v_start.options(), + context, + v_start.sizes(), + v_start.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_start.has_image() && v_end.has_image()) { - const float weight = weight_arg.to(); - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input1_extents; - uint32_t fill_1; - uvec3 input2_extents; - float weight; - } block{ - v_output.extents(), - 0u, - v_start.extents(), - 0u, - v_end.extents(), - weight, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(lerp_scalar), - v_output.extents(), - adaptive_work_group_size(v_output.extents()), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, vTensor::Stage::Compute, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_start.image(command_buffer, vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_end.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const float weight = weight_arg.to(); + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input1_extents; + uint32_t fill_1; + uvec3 input2_extents; + float weight; + } block{ + v_output.extents(), + 0u, + v_start.extents(), + 0u, + v_end.extents(), + weight, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(lerp_scalar), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor& _lerp_scalar_( - Tensor& self, + Tensor& self_arg, const Tensor& end_arg, - const Scalar& weight_arg, - const std::string& op_name) { - check_inputs_elementwise_op(self, end_arg); - api::Context* const context = api::context(); + const Scalar& weight_arg) { + check_inputs_elementwise_op(self_arg, end_arg); TORCH_CHECK( - self.is_vulkan(), - "Vulkan: In-place lerp is only supported on Vulkan tensors."); + self_arg.is_vulkan(), + "Vulkan: In-place operator is only supported on Vulkan tensors."); + + api::Context* const context = api::context(); - vTensor& v_self = convert(self); + vTensor& v_self = convert(self_arg); const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan(); const vTensor& v_end = convert(end); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY ( - v_self.has_image() && v_end.has_image() && !self.is_same(end)) { - const float weight = weight_arg.to(); - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input_extents; - float alpha; - } block{ - v_self.extents(), - 0u, - v_end.extents(), - weight, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(lerp_scalar_), - v_self.extents(), - adaptive_work_group_size(v_self.extents()), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_end.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const float weight = weight_arg.to(); + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input_extents; + float alpha; + } block{ + v_self.extents(), + 0u, + v_end.extents(), + weight, + }; - return self; + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(lerp_scalar_), + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return self_arg; } Tensor _lerp_tensor( const Tensor& start_arg, const Tensor& end_arg, - const Tensor& weight_arg, - const std::string& op_name) { + const Tensor& weight_arg) { check_inputs_elementwise_op(start_arg, end_arg); check_inputs_elementwise_op(start_arg, weight_arg); + api::Context* const context = api::context(); const Tensor start = start_arg.is_vulkan() ? start_arg : start_arg.vulkan(); @@ -203,183 +177,155 @@ Tensor _lerp_tensor( const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan(); const vTensor& v_end = convert(end); - const Tensor weight = weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan(); - const vTensor& v_weight = convert(weight); + const Tensor weight = + weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan(); + const vTensor& v_weight = convert(weight_arg); vTensor v_output{ - context, - v_start.sizes(), - v_start.options(), + context, + v_start.sizes(), + v_start.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_start.has_image() && v_end.has_image() && v_weight.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input1_extents; - uint32_t fill_1; - uvec3 input2_extents; - uint32_t fill_2; - uvec3 input3_extents; - uint32_t fill_3; - } block{ - v_output.extents(), - 0u, - v_start.extents(), - 0u, - v_end.extents(), - 0u, - v_weight.extents(), - 0u, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(lerp), - v_output.extents(), - adaptive_work_group_size(v_output.extents()), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, vTensor::Stage::Compute, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_start.image(command_buffer, vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_end.image(command_buffer, vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_weight.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input1_extents; + uint32_t fill_1; + uvec3 input2_extents; + uint32_t fill_2; + uvec3 input3_extents; + uint32_t fill_3; + } block{ + v_output.extents(), + 0u, + v_start.extents(), + 0u, + v_end.extents(), + 0u, + v_weight.extents(), + 0u, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(lerp), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_start.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor& _lerp_tensor_( - Tensor& self, + Tensor& self_arg, const Tensor& end_arg, - const Tensor& weight_arg, - const std::string& op_name) { - check_inputs_elementwise_op(self, end_arg); - check_inputs_elementwise_op(self, weight_arg); - api::Context* const context = api::context(); + const Tensor& weight_arg) { + check_inputs_elementwise_op(self_arg, end_arg); + check_inputs_elementwise_op(self_arg, weight_arg); TORCH_CHECK( - self.is_vulkan(), - "Vulkan: In-place lerp is only supported on Vulkan tensors."); + self_arg.is_vulkan(), + "Vulkan: In-place operator is only supported on Vulkan tensors."); - vTensor& v_self = convert(self); + api::Context* const context = api::context(); - const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan(); - const vTensor& v_end = convert(end); + vTensor& v_self = convert(self_arg); - const Tensor weight = weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan(); - const vTensor& v_weight = convert(weight); - - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY ( - v_self.has_image() && v_end.has_image() && v_weight.has_image() && !self.is_same(end)) { - const struct Block final { - uvec3 extents; - uint32_t fill_0; - uvec3 input1_extents; - uint32_t fill_1; - uvec3 input2_extents; - uint32_t fill_2; - } block{ - v_self.extents(), - 0u, - v_end.extents(), - 0u, - v_weight.extents(), - 0u, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(lerp_), - v_self.extents(), - adaptive_work_group_size(v_self.extents()), - // Read-Write access triggers an async synchronization if necessory - // and inserts appropriate barriers if hazards are detected. - v_self.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_end.image(command_buffer, vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_weight.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const Tensor end = end_arg.is_vulkan() ? end_arg : end_arg.vulkan(); + const vTensor& v_end = convert(end_arg); + + const Tensor weight = + weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan(); + const vTensor& v_weight = convert(weight_arg); + + const struct Block final { + uvec3 extents; + uint32_t fill_0; + uvec3 input1_extents; + uint32_t fill_1; + uvec3 input2_extents; + uint32_t fill_2; + } block{ + v_self.extents(), + 0u, + v_end.extents(), + 0u, + v_weight.extents(), + 0u, + }; - return self; + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(lerp_), + // pipeline barrier + pipeline_barrier, + // global work group size + v_self.extents(), + // local work group size + adaptive_work_group_size(v_self.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + v_end.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return self_arg; } -Tensor lerp_scalar(const Tensor& start, const Tensor& end, const Scalar& weight) { - return _lerp_scalar( - start, end, weight, "aten::lerp.Scalar"); +Tensor lerp_scalar( + const Tensor& start, + const Tensor& end, + const Scalar& weight) { + return _lerp_scalar(start, end, weight); } Tensor& lerp_scalar_(Tensor& self, const Tensor& end, const Scalar& weight) { - return _lerp_scalar_( - self, end, weight, "aten::lerp_.Scalar"); + return _lerp_scalar_(self, end, weight); } -Tensor lerp_tensor(const Tensor& start, const Tensor& end, const Tensor& weight) { +Tensor lerp_tensor( + const Tensor& start, + const Tensor& end, + const Tensor& weight) { if (weight.sizes().size() == 0) { - return _lerp_scalar( - start, end, weight.item(), "aten::lerp.Tensor"); + return _lerp_scalar(start, end, weight.item()); } - return _lerp_tensor( - start, end, weight, "aten::lerp.Tensor"); + return _lerp_tensor(start, end, weight); } Tensor& lerp_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) { if (weight.sizes().size() == 0) { - return _lerp_scalar_( - self, end, weight.item(), "aten::lerp_.Tensor"); + return _lerp_scalar_(self, end, weight.item()); } - return _lerp_tensor_( - self, end, weight, "aten::lerp_.Tensor"); + return _lerp_tensor_(self, end, weight); } #ifdef USE_VULKAN_API diff --git a/aten/src/ATen/native/vulkan/ops/Lstm.cpp b/aten/src/ATen/native/vulkan/ops/Lstm.cpp index d0de0794aae6d..c86583621c3bb 100644 --- a/aten/src/ATen/native/vulkan/ops/Lstm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Lstm.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include namespace at { @@ -7,14 +9,19 @@ namespace vulkan { namespace ops { namespace { // -// input_vk: input tensor of shape (L, N, H_in) when batch_first=False or (N, L, H_in) when batch_first=True +// input_vk: input tensor of shape (L, N, H_in) when batch_first=False or (N, L, +// H_in) when batch_first=True // containing the features of the input sequence -// hx_vk: tensor of shape (D * num_layers, N, H_out) containing the initial hidden state for each element in the input sequence. -// cx_vk: tensor of shape (D * num_layers, N, H_cell) containing the initial cell state for each element in the input sequence. -// output: tensor of shape (L, N, D * H_out) when batch_first=False or (N, L, D * H_out) when batch_first=True -// containing the output features (h_t) from the last layer of the LSTM, for each t -// h_n: tensor of shape (D * num_layers, N, H_out) containing the final hidden state for each element in the sequence. -// c_n: tensor of shape (D * num_layers, N, H_cell) containing the final cell state for each element in the sequence. +// hx_vk: tensor of shape (D * num_layers, N, H_out) containing the initial +// hidden state for each element in the input sequence. cx_vk: tensor of shape +// (D * num_layers, N, H_cell) containing the initial cell state for each +// element in the input sequence. output: tensor of shape (L, N, D * H_out) when +// batch_first=False or (N, L, D * H_out) when batch_first=True +// containing the output features (h_t) from the last layer of the LSTM, +// for each t +// h_n: tensor of shape (D * num_layers, N, H_out) containing the final hidden +// state for each element in the sequence. c_n: tensor of shape (D * num_layers, +// N, H_cell) containing the final cell state for each element in the sequence. // // where // L = sequence length @@ -25,35 +32,54 @@ namespace { // H_out = hidden_size // std::tuple lstm_input( - const Tensor & input_vk, // input sequence (vulkan) - TensorList hx, // initial hidden state (vulkan) & initial cell state (vulkan) - TensorList params_cpu, // weights/biases (cpu) - bool has_biases, - int64_t num_layers, - double dropout, - bool train, - bool bidirectional, - bool batch_first) { - TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "Vulkan LSTM with projections is not supported"); - TORCH_CHECK(static_cast(params_cpu.size()), "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'."); - TORCH_INTERNAL_ASSERT(input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3."); - TORCH_INTERNAL_ASSERT(hx[0].sizes().size() == 3, "Vulkan LSTM expects hidden state dims to be 3."); - TORCH_INTERNAL_ASSERT(hx[1].sizes().size() == 3, "Vulkan LSTM expects cell state dims to be 3."); - TORCH_INTERNAL_ASSERT(has_biases, "Vulkan LSTM expects 'has_biases' to be true."); + const Tensor& input_vk, // input sequence (vulkan) + TensorList + hx, // initial hidden state (vulkan) & initial cell state (vulkan) + TensorList params_cpu, // weights/biases (cpu) + bool has_biases, + int64_t num_layers, + double dropout, + bool train, + bool bidirectional, + bool batch_first) { + TORCH_CHECK( + hx[0].size(2) == hx[1].size(2), + "Vulkan LSTM with projections is not supported"); + TORCH_CHECK( + static_cast(params_cpu.size()), + "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'."); + TORCH_INTERNAL_ASSERT( + input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3."); + TORCH_INTERNAL_ASSERT( + hx[0].sizes().size() == 3, + "Vulkan LSTM expects hidden state dims to be 3."); + TORCH_INTERNAL_ASSERT( + hx[1].sizes().size() == 3, + "Vulkan LSTM expects cell state dims to be 3."); + TORCH_INTERNAL_ASSERT( + has_biases, "Vulkan LSTM expects 'has_biases' to be true."); TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false."); - TORCH_INTERNAL_ASSERT(!bidirectional, "Vulkan LSTM expects 'bidirectional' to be false."); - TORCH_INTERNAL_ASSERT(batch_first, "Vulkan LSTM expects 'batch_first' to be true."); - TORCH_INTERNAL_ASSERT(dropout < std::numeric_limits::epsilon()*1000, "Vulkan LSTM expects 'dropout' to be 0.0."); + TORCH_INTERNAL_ASSERT( + !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false."); + TORCH_INTERNAL_ASSERT( + batch_first, "Vulkan LSTM expects 'batch_first' to be true."); + TORCH_INTERNAL_ASSERT( + dropout < std::numeric_limits::epsilon() * 1000, + "Vulkan LSTM expects 'dropout' to be 0.0."); const Tensor& hx_vk = hx[0]; const Tensor& cx_vk = hx[1]; const auto hidden_size = hx_vk.size(2); - std::vector h_n_list; // hidden state output - std::vector c_n_list; // cell state output + std::vector h_n_list; // hidden state output + std::vector c_n_list; // cell state output // reshape to 2D due to Vulkan at::mm op accepts only 2D - auto x = input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); + auto x = + input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); + + h_n_list.reserve(num_layers); + c_n_list.reserve(num_layers); for (int64_t l = 0; l < num_layers; ++l) { // extract each hidden state and squeeze into 2D dim @@ -90,15 +116,21 @@ std::tuple lstm_input( const auto& b_hg = b_h_ifgo[2]; const auto& b_ho = b_h_ifgo[3]; - const auto& i = at::sigmoid(at::addmm(b_ii, x, w_ii.t()) + at::addmm(b_hi, h, w_hi.t())); - const auto& f = at::sigmoid(at::addmm(b_if, x, w_if.t()) + at::addmm(b_hf, h, w_hf.t())); - const auto& g = at::tanh(at::addmm(b_ig, x, w_ig.t()) + at::addmm(b_hg, h, w_hg.t())); - const auto& o = at::sigmoid(at::addmm(b_io, x, w_io.t()) + at::addmm(b_ho, h, w_ho.t())); + const auto& i = at::sigmoid( + at::addmm(b_ii, x, w_ii.t()) + at::addmm(b_hi, h, w_hi.t())); + const auto& f = at::sigmoid( + at::addmm(b_if, x, w_if.t()) + at::addmm(b_hf, h, w_hf.t())); + const auto& g = + at::tanh(at::addmm(b_ig, x, w_ig.t()) + at::addmm(b_hg, h, w_hg.t())); + const auto& o = at::sigmoid( + at::addmm(b_io, x, w_io.t()) + at::addmm(b_ho, h, w_ho.t())); c = f * c + i * g; h = o * at::tanh(c); - x = h; // next input - h_n_list.emplace_back(h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op - c_n_list.emplace_back(c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op + x = h; // next input + h_n_list.emplace_back( + h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op + c_n_list.emplace_back( + c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op } auto h_n = at::cat(h_n_list, 1); @@ -117,6 +149,232 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { #endif /* USE_VULKAN_API */ } // namespace + +std::vector> pack_lstm_linear_op_contexts( + const std::vector& params_cpu, + int64_t num_layers) { + TORCH_CHECK( + static_cast(params_cpu.size()) == 4 * num_layers, + "Vulkan LSTM expects 'params_cpu' size to be 4 * 'num_layers'."); + std::vector> linear_op_contexts; + linear_op_contexts.reserve(num_layers * 8); + + for (int64_t l = 0; l < num_layers; ++l) { + const auto& w_ih = params_cpu[l * 4]; + const auto& w_hh = params_cpu[l * 4 + 1]; + const auto& b_ih = params_cpu[l * 4 + 2]; + const auto& b_hh = params_cpu[l * 4 + 3]; + const auto& hidden_size = w_ih.size(0) / 4; + + const auto& w_i_ifgo = w_ih.split(hidden_size); + const auto& w_h_ifgo = w_hh.split(hidden_size); + const auto& b_i_ifgo = b_ih.split(hidden_size); + const auto& b_h_ifgo = b_hh.split(hidden_size); + + const auto& w_ii = w_i_ifgo[0]; + const auto& w_if = w_i_ifgo[1]; + const auto& w_ig = w_i_ifgo[2]; + const auto& w_io = w_i_ifgo[3]; + const auto& w_hi = w_h_ifgo[0]; + const auto& w_hf = w_h_ifgo[1]; + const auto& w_hg = w_h_ifgo[2]; + const auto& w_ho = w_h_ifgo[3]; + const auto& b_ii = b_i_ifgo[0]; + const auto& b_if = b_i_ifgo[1]; + const auto& b_ig = b_i_ifgo[2]; + const auto& b_io = b_i_ifgo[3]; + const auto& b_hi = b_h_ifgo[0]; + const auto& b_hf = b_h_ifgo[1]; + const auto& b_hg = b_h_ifgo[2]; + const auto& b_ho = b_h_ifgo[3]; + + linear_op_contexts.emplace_back(create_linear_context(w_ii.t(), b_ii)); + linear_op_contexts.emplace_back(create_linear_context(w_hi.t(), b_hi)); + linear_op_contexts.emplace_back(create_linear_context(w_if.t(), b_if)); + linear_op_contexts.emplace_back(create_linear_context(w_hf.t(), b_hf)); + linear_op_contexts.emplace_back(create_linear_context(w_ig.t(), b_ig)); + linear_op_contexts.emplace_back(create_linear_context(w_hg.t(), b_hg)); + linear_op_contexts.emplace_back(create_linear_context(w_io.t(), b_io)); + linear_op_contexts.emplace_back(create_linear_context(w_ho.t(), b_ho)); + } + return linear_op_contexts; +} + +VulkanOpContext lstm_context_create( + const std::vector& params_cpu, // weights/biases (cpu) + bool has_biases, + int64_t num_layers, + double dropout, + bool train, + bool bidirectional, + bool batch_first) { + TORCH_INTERNAL_ASSERT( + has_biases, "Vulkan LSTM expects 'has_biases' to be true."); + TORCH_INTERNAL_ASSERT(!train, "Vulkan LSTM expects 'train' to be false."); + TORCH_INTERNAL_ASSERT( + !bidirectional, "Vulkan LSTM expects 'bidirectional' to be false."); + TORCH_INTERNAL_ASSERT( + batch_first, "Vulkan LSTM expects 'batch_first' to be true."); + TORCH_INTERNAL_ASSERT( + dropout < std::numeric_limits::epsilon() * 1000, + "Vulkan LSTM expects 'dropout' to be 0.0."); + + c10::impl::GenericList packed_context{c10::AnyType::get()}; + packed_context.reserve(7); + packed_context.emplace_back( + pack_lstm_linear_op_contexts(params_cpu, num_layers)); + packed_context.emplace_back(has_biases); + packed_context.emplace_back(num_layers); + packed_context.emplace_back(dropout); + packed_context.emplace_back(train); + packed_context.emplace_back(bidirectional); + packed_context.emplace_back(batch_first); + + c10::impl::GenericList unpacked_context{c10::AnyType::get()}; + unpacked_context.reserve(7); + unpacked_context.emplace_back(params_cpu); + unpacked_context.emplace_back(has_biases); + unpacked_context.emplace_back(num_layers); + unpacked_context.emplace_back(dropout); + unpacked_context.emplace_back(train); + unpacked_context.emplace_back(bidirectional); + unpacked_context.emplace_back(batch_first); + + return VulkanOpContext::create(packed_context, unpacked_context); +} + +std::tuple lstm_context_run( + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) + const Tensor& cx_vk, // initial cell state (vulkan) + const c10::impl::GenericList& packed_context, + const c10::impl::GenericList& unpacked_context) { + TORCH_INTERNAL_ASSERT( + input_vk.sizes().size() == 3, "Vulkan LSTM expects input dims to be 3."); + TORCH_INTERNAL_ASSERT( + hx_vk.sizes().size() == 3, + "Vulkan LSTM expects hidden state dims to be 3."); + TORCH_INTERNAL_ASSERT( + cx_vk.sizes().size() == 3, + "Vulkan LSTM expects cell state dims to be 3."); + + const c10::List packed_linear_op_contexts = + packed_context.get(0).toList(); + const int64_t packed_num_layers = packed_context.get(2).toInt(); + + const int64_t linear_op_contexts_per_layer = + 8; // (b_ii, w_ii), (b_hi, w_hi), (b_if, w_if), (b_hf, w_hf), (b_ig, + // w_ig), (b_hg, w_hg), (b_io, w_io), (b_ho, w_ho) + std::vector h_n_list; // hidden state output + std::vector c_n_list; // cell state output + + // reshape to 2D due to Vulkan at::mm op accepts only 2D + auto x = + input_vk.reshape({input_vk.size(0) * input_vk.size(1), input_vk.size(2)}); + + h_n_list.reserve(packed_num_layers); + c_n_list.reserve(packed_num_layers); + + for (int64_t l = 0; l < packed_num_layers; ++l) { + // extract each hidden state and squeeze into 2D dim + auto h = at::slice(hx_vk, 0, l, l + 1, 1); + h = h.reshape({h.size(0) * h.size(1), h.size(2)}); + + auto c = at::slice(cx_vk, 0, l, l + 1, 1); + c = c.reshape({c.size(0) * c.size(1), c.size(2)}); + + const auto& cxt_ii = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 0] + .toCustomClass(); + const auto& cxt_hi = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 1] + .toCustomClass(); + const auto& cxt_if = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 2] + .toCustomClass(); + const auto& cxt_hf = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 3] + .toCustomClass(); + const auto& cxt_ig = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 4] + .toCustomClass(); + const auto& cxt_hg = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 5] + .toCustomClass(); + const auto& cxt_io = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 6] + .toCustomClass(); + const auto& cxt_ho = + packed_linear_op_contexts[l * linear_op_contexts_per_layer + 7] + .toCustomClass(); + + const auto& i = at::sigmoid( + linear_context_run( + x, cxt_ii->get_packed(), cxt_ii->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_hi->get_packed(), cxt_hi->get_unpacked(), 1.0f, 1.0f)); + const auto& f = at::sigmoid( + linear_context_run( + x, cxt_if->get_packed(), cxt_if->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_hf->get_packed(), cxt_hf->get_unpacked(), 1.0f, 1.0f)); + const auto& g = at::tanh( + linear_context_run( + x, cxt_ig->get_packed(), cxt_ig->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_hg->get_packed(), cxt_hg->get_unpacked(), 1.0f, 1.0f)); + const auto& o = at::sigmoid( + linear_context_run( + x, cxt_io->get_packed(), cxt_io->get_unpacked(), 1.0f, 1.0f) + + linear_context_run( + h, cxt_ho->get_packed(), cxt_ho->get_unpacked(), 1.0f, 1.0f)); + c = f * c + i * g; + h = o * at::tanh(c); + x = h; // next input + h_n_list.emplace_back( + h.reshape({1, 1, h.size(0), h.size(1)})); // 2D to 4D for cat op + c_n_list.emplace_back( + c.reshape({1, 1, c.size(0), c.size(1)})); // 2D to 4D for cat op + } + + auto h_n = at::cat(h_n_list, 1); + auto c_n = at::cat(c_n_list, 1); + h_n = h_n.reshape({h_n.size(0) * h_n.size(1), h_n.size(2), h_n.size(3)}); + c_n = c_n.reshape({c_n.size(0) * c_n.size(1), c_n.size(2), c_n.size(3)}); + return std::tuple(x, h_n, c_n); +} + +c10::intrusive_ptr create_lstm_context( + std::vector&& params_cpu, + bool has_biases, + int64_t num_layers, + double dropout, + bool train, + bool bidirectional, + bool batch_first) { + return c10::make_intrusive(lstm_context_create( + params_cpu, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first)); +} + +std::tuple run_lstm_context( + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) + const Tensor& cx_vk, // initial cell state (vulkan) + const c10::intrusive_ptr& vulkan_context) { + return lstm_context_run( + input_vk, + hx_vk, + cx_vk, + vulkan_context->get_packed(), + vulkan_context->get_unpacked()); +} + } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Lstm.h b/aten/src/ATen/native/vulkan/ops/Lstm.h new file mode 100644 index 0000000000000..e793ad1d00a75 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Lstm.h @@ -0,0 +1,87 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +// packed +// std::vector> linear_op_contexts; // +// {{ op context for b_ii, w_ii, op context for b_hi, w_hi, +// // +// op +// context +// for +// b_if, +// w_if, +// op +// context +// for +// b_hf, +// w_hf, +// // +// op +// context +// for +// b_ig, +// w_ig, +// op +// context +// for +// b_hg, +// w_hg, +// // +// op +// context +// for +// b_io, +// w_io, +// op +// context +// for +// b_ho, +// w_ho,}, +// ...} +// bool has_biases{}; +// int64_t num_layers{}; +// double dropout{}; +// bool train{}; +// bool bidirectional{}; +// bool batch_first{}; + +// unpacked +// std::vector params_cpu // weights/biases (cpu) +// bool has_biases +// int64_t num_layers +// double dropout +// bool train +// bool bidirectional +// bool batch_first + +c10::intrusive_ptr create_lstm_context( + std::vector&& params_cpu, // weights/biases (cpu) + bool has_biases, + int64_t num_layers, + double dropout, + bool train, + bool bidirectional, + bool batch_first); + +std::tuple run_lstm_context( + const Tensor& input_vk, // input sequence (vulkan) + const Tensor& hx_vk, // initial hidden state (vulkan) + const Tensor& cx_vk, // initial cell state (vulkan) + const c10::intrusive_ptr& vulkan_context); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Mean.cpp b/aten/src/ATen/native/vulkan/ops/Mean.cpp index 3e678056fc3b8..90d2d1bed0e54 100644 --- a/aten/src/ATen/native/vulkan/ops/Mean.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mean.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -13,33 +12,35 @@ using namespace api::utils; Tensor mean( const at::Tensor& input_arg, - const IntArrayRef dim, + const OptionalIntArrayRef opt_dim, const bool keepdim, const optional dtype) { - TORCH_CHECK( - input_arg.dim() == 4, - "Vulkan mean expects 4-dimensional input!"); + TORCH_CHECK(input_arg.dim() == 4, "Vulkan mean expects 4-dimensional input!"); static const std::unordered_set expected_dims_set({2, 3}); std::unordered_set dims_set; - for (const auto& d : dim) { - dims_set.insert(utils::normalize(d, 4)); + if (opt_dim.has_value()) { + auto dim = opt_dim.value(); + for (const auto& d : dim) { + dims_set.insert(utils::normalize(d, 4)); + } } TORCH_CHECK( dims_set == expected_dims_set, - "Vulkan mean currently only supports image-wide reduction!"); + "Vulkan mean: currently only supports image-wide reduction!"); api::Context* const context = api::context(); const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); c10::SmallVector output_sizes{ - v_input_sizes[Layout::Activation4D::batch], - v_input_sizes[Layout::Activation4D::channels], + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], }; if (keepdim) { @@ -48,59 +49,44 @@ Tensor mean( } vTensor v_output{ - context, - output_sizes, - v_input.options(), + context, + output_sizes, + v_input.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::mean.dim"); - - if C10_LIKELY(v_input.has_image()) { - const struct Block final { - uvec3 extents; - int32_t range; - uvec3 iextents; - } block { - v_output.extents(), - safe_downcast( - v_input_sizes[Layout::Activation4D::width] * - v_input_sizes[Layout::Activation4D::height]), - v_input.extents() - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - keepdim ? VK_KERNEL(mean) : VK_KERNEL(mean2d), - v_input.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + uvec3 extents; + int32_t range; + uvec3 iextents; + } block{ + v_output.extents(), + safe_downcast( + v_input_sizes[Layout::Activation4D::width] * + v_input_sizes[Layout::Activation4D::height]), + v_input.extents()}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + keepdim ? VK_KERNEL(mean) : VK_KERNEL(mean2d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_input.extents(), + // local work group size + adaptive_work_group_size(v_input.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index 8bf7815d91e7e..0587a0a95a0ae 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include @@ -10,15 +10,14 @@ namespace ops { namespace { using namespace api::utils; +using namespace at::native::vulkan::ops; -vTensor pack_weights( - const Tensor& weight_arg) { +vTensor pack_weights(const Tensor& weight_arg) { if (weight_arg.is_vulkan()) { return convert(weight_arg); } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything const Tensor weight = weight_arg.contiguous(); const IntArrayRef w_sizes = weight.sizes(); @@ -36,30 +35,33 @@ vTensor pack_weights( vTensor v_weight{ context, { - 4, - dst_kh_sz, - dst_kw_sz, + 4, + dst_kh_sz, + dst_kw_sz, }, weight.options(), }; - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(command_buffer); - Future::Payload v_weight_payload = v_weight_future.wait(); - - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); - - for (const auto src_h : c10::irange(src_kh_sz)) { - for (const auto src_w : c10::irange(src_kw_sz)) { - int64_t dst_plane = 2*(src_h%2) + (src_w%2); - int64_t dst_index = (src_h/2)*dst_kw_sz + (src_w/2); - memcpy( - dst_weight_ptr + dst_plane * dst_plane_sz + dst_index, - src_weight_ptr + src_h * src_kw_sz + src_w, - sizeof(float)); + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_weight_ptr = mapping.template data(); + + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (const auto src_h : c10::irange(src_kh_sz)) { + for (const auto src_w : c10::irange(src_kw_sz)) { + int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2); + int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2); + memcpy( + dst_weight_ptr + dst_plane * dst_plane_sz + dst_index, + src_weight_ptr + src_h * src_kw_sz + src_w, + sizeof(float)); + } } } + utils::pack_staging_to_vtensor(staging.buffer(), v_weight); return v_weight; } @@ -72,9 +74,7 @@ vTensor pack_biases( } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything - using Future = vTensor::Future; if (bias_arg) { const Tensor bias = bias_arg->contiguous(); const IntArrayRef b_sizes = bias.sizes(); @@ -85,8 +85,7 @@ vTensor pack_biases( if (bias.sizes().size() == 2) { src_kw_sz = b_sizes[Layout::Parameter::width]; src_kh_sz = b_sizes[Layout::Parameter::height]; - } - else { + } else { src_kw_sz = b_sizes[Layout::Parameter::height]; src_kh_sz = 1; } @@ -99,76 +98,84 @@ vTensor pack_biases( vTensor v_bias{ context, { - 4, - dst_kh_sz, - dst_kw_sz, + 4, + dst_kh_sz, + dst_kw_sz, }, bias_arg->options(), }; - Future v_bias_future = v_bias.host(command_buffer); - Future::Payload v_bias_payload = v_bias_future.wait(); + api::StagingBuffer staging(context, v_bias.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); - float* const dst_bias_ptr = v_bias_payload.get(); - memset(dst_bias_ptr, 0, v_bias.nbytes()); + float* dst_bias_ptr = mapping.template data(); - for (const auto src_h : c10::irange(src_kh_sz)) { - for (const auto src_w : c10::irange(src_kw_sz)) { - int64_t dst_plane = 2*(src_h%2) + (src_w%2); - int64_t dst_index = (src_h/2)*dst_kw_sz + (src_w/2); - memcpy( - dst_bias_ptr + dst_plane * dst_plane_sz + dst_index, - src_bias_ptr + src_h * src_kw_sz + src_w, - sizeof(float)); + memset(dst_bias_ptr, 0, v_bias.nbytes()); + + for (const auto src_h : c10::irange(src_kh_sz)) { + for (const auto src_w : c10::irange(src_kw_sz)) { + int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2); + int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2); + memcpy( + dst_bias_ptr + dst_plane * dst_plane_sz + dst_index, + src_bias_ptr + src_h * src_kw_sz + src_w, + sizeof(float)); + } } } + utils::pack_staging_to_vtensor(staging.buffer(), v_bias); return v_bias; - } - else { + } else { vTensor v_bias{ api::context(), {1}, weight_arg.options(), }; - Future v_bias_future = v_bias.host(command_buffer); - Future::Payload v_bias_payload = v_bias_future.wait(); - memset( - v_bias_payload.get(), - // 2's complement integers and IEEE-754 floating point numbers both - // have identical bit representations for 0, so can use memset which - // only accepts uint8_t parameter. - 0, - v_bias.nbytes()); + + api::StagingBuffer staging(context, v_bias.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* data_ptr = mapping.template data(); + + memset( + data_ptr, + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); + } + utils::pack_staging_to_vtensor(staging.buffer(), v_bias); return v_bias; } } -bool available( - const Tensor& weight, - const c10::optional& bias) { +bool available(const Tensor& weight, const c10::optional& bias) { return api::available() && - // Weight - (2 == weight.ndimension()) && - (weight.size(Layout::Parameter::height) > 0) && - (weight.size(Layout::Parameter::width) > 0) && - ((weight.device().is_cpu()) || - (c10::DeviceType::Vulkan == weight.device().type())) && - (kFloat == weight.scalar_type()) && - !weight.requires_grad() && - // Bias - ((bias && bias->defined()) ? ((bias->ndimension() > 0) && - ((bias->device().is_cpu()) || - (c10::DeviceType::Vulkan == bias->device().type())) && - (kFloat == bias->scalar_type()) && - ((bias->ndimension() > 1) ? - (bias->size(Layout::Parameter::width) == - weight.size(Layout::Parameter::width)) - : true) && - !bias->requires_grad()) - : true) && - true; + // Weight + (2 == weight.ndimension()) && + (weight.size(Layout::Parameter::height) > 0) && + (weight.size(Layout::Parameter::width) > 0) && + ((weight.device().is_cpu()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type()) && !weight.requires_grad() && + // Bias + ((bias && bias->defined()) + ? ((bias->ndimension() > 0) && + ((bias->device().is_cpu()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type()) && + ((bias->ndimension() > 1) + ? (bias->size(Layout::Parameter::width) == + weight.size(Layout::Parameter::width)) + : true) && + !bias->requires_grad()) + : true) && + true; } bool usable( @@ -176,12 +183,11 @@ bool usable( const Tensor& weight, const c10::optional& /* bias */) { return (2 == input.ndimension()) && - (c10::DeviceType::Vulkan == input.device().type()) && - (kFloat == input.scalar_type()) && - (input.size(Layout::Parameter::width) == - weight.size(Layout::Parameter::height)) && - !input.requires_grad() && - true; + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Parameter::width) == + weight.size(Layout::Parameter::height)) && + !input.requires_grad() && true; } VulkanOpContext context_create( @@ -206,23 +212,36 @@ VulkanOpContext context_create( return VulkanOpContext::create(packed_context, unpacked_context); } +static Tensor reshape_to_2d(const Tensor& input_arg) { + TORCH_CHECK( + input_arg.dim() >= 2, + "Vulkan Linear op only supports input tensor with dim >= 2"); + const IntArrayRef input_sizes = input_arg.sizes(); + const auto d = + c10::multiply_integers(input_sizes.cbegin(), input_sizes.end() - 1); + return input_arg.reshape({d, input_arg.size(-1)}); +} + Tensor context_run( const Tensor& input_arg, const c10::impl::GenericList& packed_context, const c10::impl::GenericList& unpacked_context, const float alpha, - const float beta, - const std::string& op_name) { + const float beta) { api::Context* const context = api::context(); - const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const Tensor input_arg_2d = + input_arg.dim() == 2 ? input_arg : reshape_to_2d(input_arg); + const Tensor input = + input_arg_2d.is_vulkan() ? input_arg_2d : input_arg_2d.vulkan(); const vTensor& v_input = convert(input); const vTensor& packed_v_weight = convert(packed_context.get(0).toTensor()); const vTensor& packed_v_bias = convert(packed_context.get(1).toTensor()); const Tensor& unpacked_weight = unpacked_context.get(0).toTensor(); - const c10::optional& unpacked_bias - = unpacked_context.get(1).isTensor() ? unpacked_context.get(1).toTensor() : c10::optional(); + const c10::optional& unpacked_bias = + unpacked_context.get(1).isTensor() ? unpacked_context.get(1).toTensor() + : c10::optional(); TORCH_CHECK( usable(input, unpacked_weight, unpacked_bias), @@ -231,135 +250,112 @@ Tensor context_run( "combination with the provided weight and bias tensors are unsupported by " "Vulkan impl."); - c10::SmallVector output_sizes{ - v_input.sizes()[Layout::Parameter::height], - unpacked_weight.sizes()[Layout::Parameter::width], - }; - - vTensor v_output { + vTensor v_output{ context, { - v_input.sizes()[Layout::Parameter::height], - unpacked_weight.sizes()[Layout::Parameter::width], + v_input.sizes()[Layout::Parameter::height], + unpacked_weight.sizes()[Layout::Parameter::width], }, input.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if (v_input.has_image() && - packed_v_weight.has_image() && - packed_v_bias.has_image()) { - if (unpacked_bias && unpacked_bias->defined()) { - const struct { - uvec3 size; - int32_t K; - vec2 multiplier; - } block { - v_output.extents(), - safe_downcast(div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))), - { - alpha, - beta, - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(addmm), - { - safe_downcast(div_up(unpacked_weight.sizes()[Layout::Parameter::width], INT64_C(2))), - safe_downcast(div_up(v_input.sizes()[Layout::Parameter::height], INT64_C(2))), - 1, - }, - {8, 8, 1}, - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_bias.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - const struct { - uvec3 size; - int32_t K; - } block_no_bias { - v_output.extents(), - safe_downcast(div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))), - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(mm), - { - safe_downcast(div_up(unpacked_weight.sizes()[Layout::Parameter::width], INT64_C(2))), - safe_downcast(div_up(v_input.sizes()[Layout::Parameter::height], INT64_C(2))), - 1, - }, - {8, 8, 1}, - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block_no_bias).object); - } - } - else { - TORCH_CHECK(false, "Not implemented!"); - } + if (unpacked_bias && unpacked_bias->defined()) { + const struct { + uvec3 size; + int32_t K; + vec2 multiplier; + } block{ + v_output.extents(), + safe_downcast( + div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))), + { + alpha, + beta, + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(addmm), + // pipeline barrier + pipeline_barrier, + // global work group size + { + safe_downcast(div_up( + unpacked_weight.sizes()[Layout::Parameter::width], INT64_C(2))), + safe_downcast( + div_up(v_input.sizes()[Layout::Parameter::height], INT64_C(2))), + 1, + }, + // local work group size + {8, 8, 1}, + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + } else { + const struct { + uvec3 size; + int32_t K; + } block_no_bias{ + v_output.extents(), + safe_downcast( + div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))), + }; + + api::UniformParamsBuffer params(context, block_no_bias); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(mm), + // pipeline barrier + pipeline_barrier, + // global work group size + { + safe_downcast(div_up( + unpacked_weight.sizes()[Layout::Parameter::width], INT64_C(2))), + safe_downcast( + div_up(v_input.sizes()[Layout::Parameter::height], INT64_C(2))), + 1, + }, + // local work group size + {8, 8, 1}, + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); } - command_pool.submit(context->gpu().queue, command_buffer); - return convert(v_output); + Tensor output = convert(v_output); + if (input_arg.dim() == 2) { + return output; + } else { + std::vector shape; + for (const auto i : c10::irange(input_arg.dim() - 1)) { + shape.emplace_back(input_arg.size(i)); + } + shape.emplace_back(output.size(-1)); + return output.reshape(shape); + } } Tensor addmm( @@ -368,33 +364,26 @@ Tensor addmm( const Tensor& weight, const Scalar& beta, const Scalar& alpha) { - VulkanOpContext vulkan_context = context_create(weight, bias); return context_run( - input, - vulkan_context.get_packed(), - vulkan_context.get_unpacked(), - alpha.to(), - beta.to(), - "aten::addmm"); + input, + vulkan_context.get_packed(), + vulkan_context.get_unpacked(), + alpha.to(), + beta.to()); } -Tensor mm( - const Tensor& mat1_arg, - const Tensor& mat2_arg) { - - VulkanOpContext vulkan_context = context_create( - mat2_arg, - c10::optional()); +Tensor mm(const Tensor& mat1_arg, const Tensor& mat2_arg) { + VulkanOpContext vulkan_context = + context_create(mat2_arg, c10::optional()); return context_run( - mat1_arg, - vulkan_context.get_packed(), - vulkan_context.get_unpacked(), - 1.0f, - 1.0f, - "aten::mm"); + mat1_arg, + vulkan_context.get_packed(), + vulkan_context.get_unpacked(), + 1.0f, + 1.0f); } #ifdef USE_VULKAN_API @@ -419,85 +408,71 @@ Tensor linear_context_run( const c10::impl::GenericList& packed_context, const c10::impl::GenericList& unpacked_context, const float alpha, - const float beta, - const std::string& op_name) { - return context_run(input_arg, packed_context, unpacked_context, alpha, beta, op_name); + const float beta) { + return context_run(input_arg, packed_context, unpacked_context, alpha, beta); } c10::intrusive_ptr create_linear_context( Tensor&& weight, c10::optional&& bias) { return c10::make_intrusive( - linear_context_create( - weight, - bias)); + linear_context_create(weight, bias)); } Tensor run_linear_context( const Tensor& input, const c10::intrusive_ptr& vulkan_context) { return linear_context_run( - input, - vulkan_context->get_packed(), - vulkan_context->get_unpacked(), - 1.0, - 1.0, - "prepacked::linear_clamp_run_vulkan"); + input, + vulkan_context->get_packed(), + vulkan_context->get_unpacked(), + 1.0f, + 1.0f); } /* Backwards compatibility */ LinearOpContext::LinearOpContext(VulkanOpContext vulkan_context) - : vulkan_context_{std::move(vulkan_context)} { -} + : vulkan_context_{std::move(vulkan_context)} {} LinearOpContext LinearOpContext::create( const Tensor& weight, const c10::optional& bias) { - return LinearOpContext { - linear_context_create( - weight, - bias) - }; + return LinearOpContext{linear_context_create(weight, bias)}; } Tensor LinearOpContext::run( const Tensor& input_arg, const float alpha, - const float beta, - const std::string& op_name) const { + const float beta) const { return linear_context_run( - input_arg, - vulkan_context_.get_packed(), - vulkan_context_.get_unpacked(), - alpha, - beta, - op_name); + input_arg, + vulkan_context_.get_packed(), + vulkan_context_.get_unpacked(), + alpha, + beta); } LinearOpContext::State LinearOpContext::unpack() const { - const c10::impl::GenericList unpacked_ = std::get<1>(vulkan_context_.get_state()); + const c10::impl::GenericList unpacked_ = + std::get<1>(vulkan_context_.get_state()); const Tensor unpacked_weight = unpacked_.get(0).toTensor(); - const c10::optional unpacked_bias - = unpacked_.get(1).isTensor() ? unpacked_.get(1).toTensor() : c10::optional(); - return LinearOpContext::State{ - unpacked_weight, - unpacked_bias - }; + const c10::optional unpacked_bias = unpacked_.get(1).isTensor() + ? unpacked_.get(1).toTensor() + : c10::optional(); + return LinearOpContext::State{unpacked_weight, unpacked_bias}; } c10::intrusive_ptr linear_prepack( Tensor&& weight, c10::optional&& bias) { return c10::make_intrusive( - LinearOpContext::create( - std::move(weight), - std::move(bias))); + LinearOpContext::create(std::move(weight), std::move(bias))); } Tensor linear_run( const Tensor& input, const c10::intrusive_ptr& context) { - return context->run(input, 1.0, 1.0, "prepacked::linear_clamp_run"); + return context->run(input, 1.0, 1.0); } } // namespace ops diff --git a/aten/src/ATen/native/vulkan/ops/Mm.h b/aten/src/ATen/native/vulkan/ops/Mm.h index db625eddac945..4d573b575bd40 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.h +++ b/aten/src/ATen/native/vulkan/ops/Mm.h @@ -28,8 +28,7 @@ Tensor linear_context_run( const c10::impl::GenericList& packed_context, const c10::impl::GenericList& unpacked_context, const float alpha, - const float beta, - const std::string& op_name); + const float beta); c10::intrusive_ptr create_linear_context( Tensor&& weight, @@ -48,7 +47,7 @@ class LinearOpContext final : public torch::jit::CustomClassHolder { using State = std::tuple>; - Tensor run(const Tensor& input, float beta, float alpha, const std::string& op_name) const; + Tensor run(const Tensor& input, float beta, float alpha) const; State unpack() const; private: diff --git a/aten/src/ATen/native/vulkan/ops/Padding.cpp b/aten/src/ATen/native/vulkan/ops/Padding.cpp index 21497172a3da2..fe6145517b607 100644 --- a/aten/src/ATen/native/vulkan/ops/Padding.cpp +++ b/aten/src/ATen/native/vulkan/ops/Padding.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -11,8 +10,10 @@ namespace { using namespace api::utils; -Tensor pad2d(const Tensor& self_arg, IntArrayRef padding, const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { +Tensor pad2d( + const Tensor& self_arg, + IntArrayRef padding, + const api::ShaderSource& shader_descriptor) { const int pad_dim = padding.size(); const IntArrayRef input_size = self_arg.sizes(); const int input_dim = input_size.size(); @@ -54,69 +55,62 @@ Tensor pad2d(const Tensor& self_arg, IntArrayRef padding, const api::Shader::Des v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY (v_output.has_image() && v_self.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - uvec4 padding; - } block{ - v_output.extents(), - 0u, - { - safe_downcast(pad_left), - safe_downcast(pad_right), - safe_downcast(pad_top), - safe_downcast(pad_bottom) - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, vTensor::Stage::Compute, vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image(command_buffer, vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + uvec3 extents; + uint32_t _; + uvec4 padding; + } block{ + v_output.extents(), + 0u, + {safe_downcast(pad_left), + safe_downcast(pad_right), + safe_downcast(pad_top), + safe_downcast(pad_bottom)}, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor reflection_pad2d(const Tensor& self_arg, IntArrayRef padding) { - return pad2d(self_arg, padding, VK_KERNEL(reflection_pad2d), "aten::reflection_pad2d"); + return pad2d(self_arg, padding, VK_KERNEL(reflection_pad2d)); } Tensor replication_pad2d(const Tensor& self_arg, IntArrayRef padding) { - return pad2d(self_arg, padding, VK_KERNEL(replication_pad2d), "aten::replication_pad2d"); + return pad2d(self_arg, padding, VK_KERNEL(replication_pad2d)); } #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { - m.impl(TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), TORCH_FN(reflection_pad2d)); - m.impl(TORCH_SELECTIVE_NAME("aten::replication_pad2d"), TORCH_FN(replication_pad2d)); + m.impl( + TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), + TORCH_FN(reflection_pad2d)); + m.impl( + TORCH_SELECTIVE_NAME("aten::replication_pad2d"), + TORCH_FN(replication_pad2d)); } #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Permute.cpp b/aten/src/ATen/native/vulkan/ops/Permute.cpp index 557c99592af0a..0aad4d52a00a5 100644 --- a/aten/src/ATen/native/vulkan/ops/Permute.cpp +++ b/aten/src/ATen/native/vulkan/ops/Permute.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -10,78 +9,65 @@ namespace { using namespace api::utils; -Tensor permute_4d(const Tensor& input, const uvec4& in_size, const uvec4& out_size, const uvec4& out_dims, vTensor& v_output) { +Tensor permute_4d( + const Tensor& input_arg, + const uvec4& in_size, + const uvec4& out_size, + const uvec4& out_dims, + vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::permute (permute_4d)"); - - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Read | vTensor::Access::Write); - - const Tensor self = input.is_vulkan() ? input : input.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Compute); - - const struct Block final { - uvec3 size; // output texture size - uint32_t fill_0; // dummy - uvec3 isize; // input texture size - uint32_t fill_1; // dummy - uvec4 tensor_size; // output tensor size - uvec4 itensor_size; // input tensor size - uvec4 dims; // output dims - } block { - v_output.extents(), - 0u, - v_self.extents(), - 0u, - out_size, - in_size, - out_dims, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(permute_4d), - // build up shader operations from the output texture point of view - // to avoid the nondeterministic order of GPU shader operations between texels - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Read/Write access bypasses synchronization but inserts appropriate - // barriers if necessary. - dst_image, - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - src_image, - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_self = convert(input); + + const struct Block final { + uvec3 size; // output texture size + uint32_t fill_0; // dummy + uvec3 isize; // input texture size + uint32_t fill_1; // dummy + uvec4 tensor_size; // output tensor size + uvec4 itensor_size; // input tensor size + uvec4 dims; // output dims + } block{ + v_output.extents(), + 0u, + v_self.extents(), + 0u, + out_size, + in_size, + out_dims, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(permute_4d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ | api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } Tensor permute(const Tensor& self, IntArrayRef dims) { auto nDims = safe_downcast(self.dim()); - TORCH_CHECK(dims.size() == (size_t)nDims, - "number of dims don't match in permute"); + TORCH_CHECK( + dims.size() == (size_t)nDims, "number of dims don't match in permute"); uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u}; uvec4 out_dims{0u, 1u, 2u, 3u}; @@ -92,8 +78,7 @@ Tensor permute(const Tensor& self, IntArrayRef dims) { std::vector seen(nDims); for (const auto i : c10::irange(nDims)) { auto dim = safe_downcast(maybe_wrap_dim(dims[i], nDims)); - TORCH_CHECK(!seen[dim], - "repeated dim in permute"); + TORCH_CHECK(!seen[dim], "repeated dim in permute"); seen[dim] = true; newSizes[i] = oldSizes[dim]; if (dim != i) { @@ -109,10 +94,7 @@ Tensor permute(const Tensor& self, IntArrayRef dims) { return self; } - vTensor v_output{ - api::context(), - newSizes, - self.options()}; + vTensor v_output{api::context(), newSizes, self.options()}; return permute_4d(self, in_size, out_size, out_dims, v_output); } diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp index 7a2fe98ba7d4d..c455f1ba529a9 100644 --- a/aten/src/ATen/native/vulkan/ops/Pool.cpp +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -1,6 +1,5 @@ -#include -#include #include +#include #include namespace at { @@ -24,75 +23,63 @@ Tensor adaptive_avg_pool2d( const vTensor& v_self = convert(self); vTensor v_output{ - context, - { - self.size(Layout::Activation4D::batch), - self.size(Layout::Activation4D::channels), - output_size[Layout::Activation4D::batch], - output_size[Layout::Activation4D::channels], - }, - v_self.options(), + context, + { + self_arg.size(Layout::Activation4D::batch), + self_arg.size(Layout::Activation4D::channels), + output_size[Layout::Activation4D::batch], + output_size[Layout::Activation4D::channels], + }, + v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::_adaptive_avg_pool2d"); - - if C10_LIKELY(v_self.has_image()) { - const uvec3 v_output_size = v_output.extents(); - const uvec3 v_self_size = v_self.extents(); - - const vec2 stride { - static_cast(v_self_size.data[0u]) / v_output_size.data[0u], - static_cast(v_self_size.data[1u]) / v_output_size.data[1u], - }; - - const struct Block final { - uvec3 extents; - uint32_t _; - vec2 kernel; - vec2 stride; - } block { - v_output.extents(), - 0u, - { - v_self_size.data[0u] - (v_output_size.data[0u] - 1u) * stride.data[0u], - v_self_size.data[1u] - (v_output_size.data[1u] - 1u) * stride.data[1u], - }, - stride, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(adaptive_avg_pool2d), - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const uvec3 v_output_size = v_output.extents(); + const uvec3 v_self_size = v_self.extents(); + + const vec2 stride{ + static_cast(v_self_size.data[0u]) / v_output_size.data[0u], + static_cast(v_self_size.data[1u]) / v_output_size.data[1u], + }; + + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 kernel; + vec2 stride; + } block{ + v_output.extents(), + 0u, + { + v_self_size.data[0u] - + (v_output_size.data[0u] - 1u) * stride.data[0u], + v_self_size.data[1u] - + (v_output_size.data[1u] - 1u) * stride.data[1u], + }, + stride, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(adaptive_avg_pool2d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } @@ -104,8 +91,7 @@ Tensor pool2d( const IntArrayRef padding_arg, const IntArrayRef dilation_arg, const bool ceil_mode, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { if (stride_arg.empty()) { stride_arg = kernel_arg; } @@ -116,8 +102,8 @@ Tensor pool2d( static const auto normalize = [](const IntArrayRef parameter) { return std::array{ - parameter[0], - (2 == parameter.size()) ? parameter[1] : parameter[0], + parameter[0], + (2 == parameter.size()) ? parameter[1] : parameter[0], }; }; @@ -166,84 +152,69 @@ Tensor pool2d( const vTensor& v_self = convert(self); vTensor v_output{ - context, - { - input_size[Layout::Activation4D::batch], - input_size[Layout::Activation4D::channels], - output_height, - output_width, - }, - v_self.options(), + context, + { + input_size[Layout::Activation4D::batch], + input_size[Layout::Activation4D::channels], + output_height, + output_width, + }, + v_self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_self.has_image()) { - const struct Block final { - uvec3 extents; - int32_t range; - ivec4 kernel; - ivec2 stride; - ivec2 padding; - ivec2 dilation; - } block { - v_output.extents(), - safe_downcast( - kernel[Layout::Parameter::width] * - kernel[Layout::Parameter::height]), - { + const struct Block final { + uvec3 extents; + int32_t range; + ivec4 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilation; + } block{ + v_output.extents(), + safe_downcast( + kernel[Layout::Parameter::width] * kernel[Layout::Parameter::height]), + { safe_downcast(kernel[Layout::Parameter::width]), safe_downcast(kernel[Layout::Parameter::height]), - safe_downcast(self.size(Layout::Activation4D::width)), - safe_downcast(self.size(Layout::Activation4D::height)), - }, - { + safe_downcast(self_arg.size(Layout::Activation4D::width)), + safe_downcast(self_arg.size(Layout::Activation4D::height)), + }, + { safe_downcast(stride[Layout::Parameter::width]), safe_downcast(stride[Layout::Parameter::height]), - }, - { + }, + { safe_downcast(padding[Layout::Parameter::width]), safe_downcast(padding[Layout::Parameter::height]), - }, - { + }, + { safe_downcast(dilation[Layout::Parameter::width]), safe_downcast(dilation[Layout::Parameter::height]), - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } @@ -257,15 +228,13 @@ Tensor avg_pool2d( const bool /* count_include_pad */, const c10::optional /* divisor_override */) { return pool2d( - self_arg, - kernel_arg, - stride_arg, - padding_arg, - {1,1}, - ceil_mode, - VK_KERNEL(avg_pool2d), - "aten::avg_pool2d" - ); + self_arg, + kernel_arg, + stride_arg, + padding_arg, + {1, 1}, + ceil_mode, + VK_KERNEL(avg_pool2d)); } Tensor max_pool2d( @@ -276,21 +245,21 @@ Tensor max_pool2d( const IntArrayRef dilation_arg, const bool ceil_mode) { return pool2d( - self_arg, - kernel_arg, - stride_arg, - padding_arg, - dilation_arg, - ceil_mode, - VK_KERNEL(max_pool2d), - "aten::max_pool2d" - ); + self_arg, + kernel_arg, + stride_arg, + padding_arg, + dilation_arg, + ceil_mode, + VK_KERNEL(max_pool2d)); } #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { - m.impl(TORCH_SELECTIVE_NAME("aten::_adaptive_avg_pool2d"), TORCH_FN(adaptive_avg_pool2d)); + m.impl( + TORCH_SELECTIVE_NAME("aten::_adaptive_avg_pool2d"), + TORCH_FN(adaptive_avg_pool2d)); m.impl(TORCH_SELECTIVE_NAME("aten::avg_pool2d"), TORCH_FN(avg_pool2d)); m.impl(TORCH_SELECTIVE_NAME("aten::max_pool2d"), TORCH_FN(max_pool2d)); } diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.cpp b/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.cpp new file mode 100644 index 0000000000000..283967fb9087a --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.cpp @@ -0,0 +1,648 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; +using namespace at::native::vulkan::ops; + +inline bool is_depthwise(const IntArrayRef filter, const int64_t groups) { + return (filter[Layout::Filter::output] == groups) && + // Only K == 1 supported. + (filter[Layout::Filter::input] == 1); +} + +inline bool is_pointwise(const IntArrayRef filter) { + return (1 == filter[Layout::Filter::height]) && + (1 == filter[Layout::Filter::width]); +} + +bool all_lessthan(const IntArrayRef arr, const int t) { + bool retval = true; + for (const auto i : c10::irange(arr.size())) { + retval = retval && (arr[i] < t); + } + return retval; +} + +Conv2dQMethod determine_method( + const IntArrayRef filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t groups) { + if (is_depthwise(filter, groups)) + return Conv2dQDepthwise; + if (is_pointwise(filter)) + return Conv2dQPointwise; + return Conv2dQSlidingWindow; +} + +vTensor pack_weights_dw_q(api::Context* const context, const Tensor& weight) { + /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const c10::quint8* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t num_stacks = + div_up(src_filter[Layout::Filter::output], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kernel_sz; + const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + + vTensor v_weight{ + context, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + weight.q_scale(), + weight.q_zero_point(), + }; + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + c10::quint8* dst_weight_ptr = mapping.template data(); + + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { + /* Source */ + const c10::quint8* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + c10::quint8* const dst_weight_c_ptr = + dst_weight_ptr + dst_c * dst_kernel_sz + dst_oh * dst_kw_sz; + + for (const auto src_ih : + c10::irange(src_filter[Layout::Filter::height])) { + memcpy( + dst_weight_c_ptr + src_ih * src_kw_sz, + src_weight_oc_ptr + src_ih * src_kw_sz, + sizeof(c10::quint8) * src_kw_sz); + } + } + } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_weight); + + return v_weight; +} + +vTensor pack_weights_2d_q(api::Context* const context, const Tensor& weight) { + /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const c10::quint8* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; + + const int64_t num_stacks = + div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = + api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kw_sz * stack_depth; + const int64_t dst_kh_sz = src_kh_sz * num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + + vTensor v_weight{ + context, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + weight.q_scale(), + weight.q_zero_point(), + }; + + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + c10::quint8* dst_weight_ptr = mapping.template data(); + + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { + /* Source */ + const c10::quint8* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + c10::quint8* const dst_weight_c_ptr = + dst_weight_ptr + dst_c * dst_kernel_sz; + + for (const auto src_ic : c10::irange(src_filter[Layout::Filter::input])) { + const int64_t dst_ic4 = src_ic / 4; + + for (const auto src_ih : c10::irange(src_kh_sz)) { + for (const auto src_iw : c10::irange(src_kw_sz)) { + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + + dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, + src_weight_oc_ptr + src_ic * src_kernel_sz + + src_ih * src_kw_sz + src_iw, + sizeof(c10::quint8)); + } + } + } + } + } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_weight); + + return v_weight; +} + +vTensor pack_weights_q( + const Tensor& weight_arg, + const Conv2dQMethod conv_method) { + if (weight_arg.is_vulkan()) { + return convert(weight_arg); + } + + api::Context* const context = api::context(); + + const Tensor weight = weight_arg.contiguous(); + + if (conv_method == Conv2dQDepthwise) { + return pack_weights_dw_q(context, weight); + } + + return pack_weights_2d_q(context, weight); +} + +vTensor pack_biases_q(const c10::optional& bias, const Tensor& weight) { + if (bias && bias->is_vulkan()) { + return convert(*bias); + } + + api::Context* const context = api::context(); + + const int64_t src_w = weight.size(Layout::Filter::output); + const int64_t packed_w = div_up(src_w, INT64_C(4)); + vTensor v_bias{ + context, + { + 4, + 1, + packed_w, + }, + weight.options(), + weight.q_scale(), + weight.q_zero_point(), + }; + + api::StagingBuffer staging(context, v_bias.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + c10::quint8* dst_bias_ptr = mapping.template data(); + + if (bias) { + const c10::quint8* const src_bias_ptr = + bias->contiguous().data_ptr(); + + memset(dst_bias_ptr, 0, v_bias.nbytes()); + for (const auto i : c10::irange(src_w)) { + const int64_t c = i % 4; + const int64_t x = i / 4; + dst_bias_ptr[c * packed_w + x] = src_bias_ptr[i]; + } + } else { + memset( + dst_bias_ptr, + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); + } + } + ops::utils::pack_staging_to_vtensor(staging.buffer(), v_bias); + + return v_bias; +} + +std::array pack_filter( + const Tensor& weight, + const IntArrayRef dilation) { + const IntArrayRef filter = weight.sizes(); + + const auto effective = [](const int64_t k, const int64_t d) { + return k + (k - 1) * (d - 1); + }; + + return { + align_up(filter[Layout::Filter::output], INT64_C(4)), + align_up(filter[Layout::Filter::input], INT64_C(4)), + effective( + filter[Layout::Filter::height], dilation[Layout::Parameter::height]), + effective( + filter[Layout::Filter::width], dilation[Layout::Parameter::width]), + }; +} + +std::array pack_params(const std::vector& vector) { + TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!"); + + return { + vector[0], + vector[1], + }; +} + +bool available( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool transposed, + const IntArrayRef /* output_padding */, + const int64_t groups, + const c10::optional& output_min, + const c10::optional& output_max) { + return api::available() && + // Weight + (4 == weight.ndimension()) && (weight.size(Layout::Filter::height) > 0) && + (weight.size(Layout::Filter::width) > 0) && + ((weight.device().is_cpu()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type() || + c10::kQUInt8 == weight.scalar_type()) && + // Bias + ((bias && bias->defined()) + ? ((1 == bias->ndimension()) && + ((bias->device().is_cpu()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type() || + c10::kQUInt8 == bias->scalar_type()) && + (transposed ? false /* to be addded in the future */ + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) + : true) && + // Stride + (stride[Layout::Parameter::height] > 0) && + (stride[Layout::Parameter::width] > 0) && + // Padding + (padding[Layout::Parameter::height] >= 0) && + (padding[Layout::Parameter::width] >= 0) && + // Dilation + (dilation[Layout::Parameter::height] > 0) && + (dilation[Layout::Parameter::width] > 0) && + // Groups + (groups > 0) && + // Input + (weight.size(Layout::Filter::input) > 0) && + // Output + (weight.size(Layout::Filter::output) > 0) && + // Output - Groups + ((weight.size(Layout::Filter::output) % groups) == 0) && + // Output Min / Max + (!output_min || output_min->isFloatingPoint()) && + (!output_max || output_max->isFloatingPoint()) && true; +} + +bool usable(const Tensor& input) { + // Input + return (4 == input.ndimension()) && + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type() || c10::kQUInt8 == input.scalar_type()) && + (input.size(Layout::Activation4D::batch) >= 0) && + (input.size(Layout::Activation4D::channels) > 0) && + (input.size(Layout::Activation4D::height) > 0) && + (input.size(Layout::Activation4D::width) > 0) && !input.requires_grad() && + true; +} + +} // namespace + +VulkanOpContext conv2d_context_create_q( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, + const bool transposed, + const IntArrayRef output_padding_arg, + const int64_t groups, + const c10::optional& output_min, + const c10::optional& output_max) { + const auto stride = expand_param_if_needed(stride_arg, "stride", 2); + const auto padding = expand_param_if_needed(padding_arg, "padding", 2); + const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2); + const auto output_padding = output_padding_arg; // TODO: Deconvolutions + + TORCH_CHECK( + available( + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_min, + output_max), + "Vulkan::convolution not available! " + "Reason: The provided (weight, bias, stride, padding, dilation, groups, " + "transposed, output_padding, output_min, output_max) parameters are either " + "invalid individually or their combination is not supported by Vulkan impl."); + + TORCH_CHECK(weight.is_quantized(), "Weight Tensor is not Quantized"); + TORCH_CHECK(bias->is_quantized(), "Bias Tensor is not Quantized"); + + auto method = + determine_method(weight.sizes(), stride, padding, dilation, groups); + + c10::impl::GenericList packed_context{c10::AnyType::get()}; + packed_context.reserve(10); + packed_context.emplace_back(convert(pack_weights_q(weight, method))); + packed_context.emplace_back(convert(pack_biases_q(bias, weight))); + packed_context.emplace_back(pack_filter(weight, dilation)); + packed_context.emplace_back(pack_params(stride)); + packed_context.emplace_back(pack_params(padding)); + packed_context.emplace_back(output_padding); + packed_context.emplace_back(pack_params(dilation)); + packed_context.emplace_back(safe_downcast(groups)); + packed_context.emplace_back( + output_min ? output_min->template to() + : -std::numeric_limits::infinity()); + packed_context.emplace_back( + output_max ? output_max->template to() + : +std::numeric_limits::infinity()); + packed_context.emplace_back(method); + + c10::impl::GenericList unpacked_context{c10::AnyType::get()}; + unpacked_context.reserve(10); + unpacked_context.emplace_back(weight); + unpacked_context.emplace_back(bias); + unpacked_context.emplace_back(weight.sizes().vec()); + unpacked_context.emplace_back(stride_arg.vec()); + unpacked_context.emplace_back(padding_arg.vec()); + unpacked_context.emplace_back(output_padding_arg.vec()); + unpacked_context.emplace_back(dilation_arg.vec()); + unpacked_context.emplace_back(groups); + unpacked_context.emplace_back(output_min); + unpacked_context.emplace_back(output_max); + unpacked_context.emplace_back(method); + return VulkanOpContext::create(packed_context, unpacked_context); +} + +void conv2d_sliding_window_q( + const api::ShaderSource& shader, + vTensor& v_output, + const vTensor& v_input, + const vTensor& packed_v_weight, + const vTensor& packed_v_bias, + const IntArrayRef packed_filter, + const IntArrayRef packed_stride, + const IntArrayRef packed_padding, + const IntArrayRef packed_dilation, + const float packed_output_min, + const float packed_output_max, + const IntArrayRef unpacked_filter, + const Conv2dQMethod method_, + const double scale, + const int64_t zero_point) { + api::Context* const context = api::context(); + + const double scale_out = v_output.get_scale(); + const int64_t zero_point_out = v_output.get_zero_point(); + + const double weight_scale = packed_v_weight.get_scale(); + const int64_t weight_zero_point = packed_v_weight.get_zero_point(); + + const double bias_scale = packed_v_bias.get_scale(); + const int64_t bias_zero_point = packed_v_bias.get_zero_point(); + + const struct Block final { + uvec3 extents; + int32_t ic4; + ivec4 kernel; + float scale_out; + float scale; + int32_t zero_point_out; + int32_t zero_point; + float weight_scale; + float bias_scale; + int32_t weight_zero_point; + int32_t bias_zero_point; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + } block{ + v_output.extents(), + safe_downcast(packed_filter[Layout::Filter::input]), + { + safe_downcast(packed_filter[Layout::Filter::width]), + safe_downcast(packed_filter[Layout::Filter::height]), + safe_downcast(v_input.sizes()[Layout::Activation4D::width]), + safe_downcast(v_input.sizes()[Layout::Activation4D::height]), + }, + safe_downcast(scale_out), + safe_downcast(scale), + safe_downcast(zero_point_out), + safe_downcast(zero_point), + safe_downcast(weight_scale), + safe_downcast(bias_scale), + safe_downcast(weight_zero_point), + safe_downcast(bias_zero_point), + { + safe_downcast(unpacked_filter[Layout::Filter::width]), + safe_downcast(unpacked_filter[Layout::Filter::height]), + }, + { + safe_downcast(packed_stride[Layout::Parameter::width]), + safe_downcast(packed_stride[Layout::Parameter::height]), + }, + { + safe_downcast(packed_padding[Layout::Parameter::width]), + safe_downcast(packed_padding[Layout::Parameter::height]), + }, + { + safe_downcast(packed_dilation[Layout::Parameter::width]), + safe_downcast(packed_dilation[Layout::Parameter::height]), + }, + { + packed_output_min, + packed_output_max, + }, + }; + + uvec3 global_size = v_output.extents(); + if (method_ == Conv2dQPointwise) { + global_size = { + safe_downcast( + div_up(v_output.sizes()[Layout::Filter::width], INT64_C(2))), + safe_downcast( + div_up(v_output.sizes()[Layout::Filter::height], INT64_C(2))), + v_output.extents().data[2u]}; + } + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader, + // pipeline barrier + pipeline_barrier, + // global work group size + global_size, + // local work group size + adaptive_work_group_size(global_size), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); +} + +Tensor conv2d_context_run_q( + const Tensor& input_arg, + const c10::impl::GenericList& packed_context, + const c10::impl::GenericList& unpacked_context, + double scale, + int64_t zero_point) { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + const vTensor& packed_v_weight = convert(packed_context.get(0).toTensor()); + const vTensor& packed_v_bias = convert(packed_context.get(1).toTensor()); + + const auto packed_filter = packed_context.get(2).toIntVector(); + const auto packed_stride = packed_context.get(3).toIntVector(); + const auto packed_padding = packed_context.get(4).toIntVector(); + const auto packed_dilation = packed_context.get(6).toIntVector(); + const float packed_output_min = + safe_downcast(packed_context.get(8).toDouble()); + const float packed_output_max = + safe_downcast(packed_context.get(9).toDouble()); + const auto unpacked_filter = unpacked_context.get(2).toIntVector(); + const Conv2dQMethod method_ = (Conv2dQMethod)unpacked_context.get(10).toInt(); + + TORCH_CHECK( + usable(input), + "Vulkan Convolution not usable! " + "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl."); + + vTensor v_output{ + context, + conv_output_size( + v_input.sizes(), + unpacked_filter, + packed_padding, + packed_stride, + packed_dilation), + input.options(), + scale, + zero_point, + }; + + if (method_ == Conv2dQSlidingWindow) { + conv2d_sliding_window_q( + VK_KERNEL(quantized_conv2d), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_, + v_input.get_scale(), + v_input.get_zero_point()); + } else if (method_ == Conv2dQPointwise) { + conv2d_sliding_window_q( + VK_KERNEL(quantized_conv2d_pw_2x2), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_, + v_input.get_scale(), + v_input.get_zero_point()); + } else if (method_ == Conv2dQDepthwise) { + conv2d_sliding_window_q( + VK_KERNEL(quantized_conv2d_dw), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter, + method_, + v_input.get_scale(), + v_input.get_zero_point()); + } else { + TORCH_CHECK(false, "Invalid Method"); + } + + return convert_quantized(v_output); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.h b/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.h new file mode 100644 index 0000000000000..4853623a7fa37 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/QuantizedConvolution.h @@ -0,0 +1,44 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +enum Conv2dQMethod { + Conv2dQDepthwise, + Conv2dQPointwise, + Conv2dQSlidingWindow, +}; + +VulkanOpContext conv2d_context_create_q( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, + const bool transposed, + const IntArrayRef output_padding_arg, + const int64_t groups, + const c10::optional& output_min = c10::nullopt, + const c10::optional& output_max = c10::nullopt); + +Tensor conv2d_context_run_q( + const Tensor& input_arg, + const c10::impl::GenericList& packed_context, + const c10::impl::GenericList& unpacked_context, + double scale, + int64_t zero_point); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedFunctions.h b/aten/src/ATen/native/vulkan/ops/QuantizedFunctions.h new file mode 100644 index 0000000000000..21fdbbf001c8b --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/QuantizedFunctions.h @@ -0,0 +1,66 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor quantize_per_tensor( + const at::Tensor& input_arg, + const double scale, + const int64_t zero_point, + const c10::ScalarType dtype); + +Tensor dequantize_helper( + const at::Tensor& input_arg, + const double scale, + const int64_t zero_point, + const c10::ScalarType dtype); + +Tensor dequantize(const Tensor& self); + +Tensor quantized_add( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point); + +Tensor quantized_sub( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point); + +Tensor quantized_mul( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point); + +Tensor quantized_div( + const Tensor& self_arg, + const Tensor& other_arg, + const double scale, + const int64_t zero_point); + +Tensor conv2d( + const Tensor& input_, + const Tensor& weight, + const c10::optional& bias_opt, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + double out_scale, + int64_t out_zero_point); + +Tensor quantized_upsample_nearest2d( + const Tensor& input_arg, + const IntArrayRef output_sizes, + const c10::optional scales_h, + const c10::optional scales_w); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp new file mode 100644 index 0000000000000..c4ba030b5bb4e --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +using namespace api::utils; + +Tensor quantize_per_tensor( + const at::Tensor& input_arg, + const double scale, + const int64_t zero_point, + const c10::ScalarType dtype) { + TORCH_CHECK(dtype == c10::ScalarType::QUInt8, "Expected type c10::kQUint8"); + + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + vTensor v_output{ + context, + input.sizes(), + input.options().dtype(c10::kQUInt8), + scale, + zero_point}; + + const struct Block final { + uvec3 extents; + uint32_t _; + float scale; + float _1; + int32_t zero_point; + int32_t _2; + } block{ + v_output.extents(), + 0u, + safe_downcast(scale), + 0.0f, + safe_downcast(zero_point), + 0u, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(quantize_per_tensor), + // barrier + pipeline_barrier, + // global work group size + v_input.extents(), + // local work group size + adaptive_work_group_size(v_input.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert_quantized(v_output); +} + +// helper for dequantize function to use scale and zero_point +Tensor dequantize_helper( + const at::Tensor& input_arg, + const double scale, + const int64_t zero_point, + const c10::ScalarType dtype) { + TORCH_CHECK(dtype == kFloat, "Expected type Float"); + + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + vTensor v_output{ + context, + input.sizes(), + input.options().dtype(c10::kFloat), + }; + + const struct Block final { + uvec3 extents; + uint32_t _; + float scale; + float _1; + int32_t zero_point; + int32_t _2; + } block{ + v_output.extents(), + 0u, + safe_downcast(scale), + 0.0f, + safe_downcast(zero_point), + 0u, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + context->submit_compute_job( + // shader descriptor + VK_KERNEL(dequantize), + // pipeline barrier + pipeline_barrier, + // global work group size + v_input.extents(), + // local work group size + adaptive_work_group_size(v_input.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +Tensor dequantize(const Tensor& self) { + double q_scale = convert(self).get_scale(); + int64_t zero_point = convert(self).get_zero_point(); + return dequantize_helper(self, q_scale, zero_point, kFloat); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl( + TORCH_SELECTIVE_NAME("aten::quantize_per_tensor"), quantize_per_tensor); + m.impl(TORCH_SELECTIVE_NAME("aten::dequantize.self"), dequantize); +} + +#endif /* USE_VULKAN_API */ + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Register.cpp b/aten/src/ATen/native/vulkan/ops/Register.cpp index cabf75ece4155..4cc1ba4e8bb6b 100644 --- a/aten/src/ATen/native/vulkan/ops/Register.cpp +++ b/aten/src/ATen/native/vulkan/ops/Register.cpp @@ -3,8 +3,11 @@ #include #include #include -#include +#include #include +#include +#include +#include #include #include #include @@ -24,10 +27,8 @@ TORCH_LIBRARY(vulkan, m) { }, // __setstate__ [](VulkanOpContext::State state) { - return c10::make_intrusive( - VulkanOpContext::create( - std::get<0>(state), - std::get<1>(state))); + return c10::make_intrusive(VulkanOpContext::create( + std::get<0>(state), std::get<1>(state))); }); // To maintain backwards compatibility. m.class_("Conv2dOpContext") @@ -171,28 +172,82 @@ TORCH_LIBRARY(vulkan_prepack, m) { "vulkan_prepack::gru_run(Tensor input_vk, " "Tensor hx_vk, " "__torch__.torch.classes.vulkan.GruOpContext G_prepack) -> (Tensor next_input, Tensor hidden_layer)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::create_lstm_context(Tensor[] params_cpu, " + "bool has_biases, " + "int num_layers, " + "float dropout, " + "bool train, " + "bool bidirectional, " + "bool batch_first) " + "-> __torch__.torch.classes.vulkan.VulkanOpContext")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::run_lstm_context(Tensor input_vk, " + "Tensor hx_vk, " + "Tensor cx_vk, " + "__torch__.torch.classes.vulkan.VulkanOpContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)")); } TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv2d_clamp_context"), TORCH_FN(create_conv2d_clamp_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_prepack"), TORCH_FN(conv2d_clamp_prepack)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv2d_transpose_clamp_context"), TORCH_FN(create_conv2d_transpose_clamp_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_transpose_clamp_prepack"), TORCH_FN(conv2d_transpose_clamp_prepack)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_linear_context"), TORCH_FN(create_linear_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::linear_prepack"), TORCH_FN(linear_prepack)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_gru_context"), TORCH_FN(create_gru_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::gru_prepack"), TORCH_FN(gru_prepack)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv2d_clamp_context"), + TORCH_FN(create_conv2d_clamp_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_prepack"), + TORCH_FN(conv2d_clamp_prepack)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME( + "vulkan_prepack::create_conv2d_transpose_clamp_context"), + TORCH_FN(create_conv2d_transpose_clamp_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_transpose_clamp_prepack"), + TORCH_FN(conv2d_transpose_clamp_prepack)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_linear_context"), + TORCH_FN(create_linear_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::linear_prepack"), + TORCH_FN(linear_prepack)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_gru_context"), + TORCH_FN(create_gru_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::gru_prepack"), + TORCH_FN(gru_prepack)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"), + TORCH_FN(create_lstm_context)); } TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv2d_clamp_context"), TORCH_FN(run_conv2d_clamp_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_run"), TORCH_FN(conv2d_clamp_run)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv2d_transpose_clamp_context"), TORCH_FN(run_conv2d_transpose_clamp_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_transpose_clamp_run"), TORCH_FN(conv2d_transpose_clamp_run)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_linear_context"), TORCH_FN(run_linear_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::linear_run"), TORCH_FN(linear_run)); // Backwards compatibility - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_gru_context"), TORCH_FN(run_gru_context)); - m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::gru_run"), TORCH_FN(gru_run)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv2d_clamp_context"), + TORCH_FN(run_conv2d_clamp_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_run"), + TORCH_FN(conv2d_clamp_run)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME( + "vulkan_prepack::run_conv2d_transpose_clamp_context"), + TORCH_FN(run_conv2d_transpose_clamp_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_transpose_clamp_run"), + TORCH_FN(conv2d_transpose_clamp_run)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_linear_context"), + TORCH_FN(run_linear_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::linear_run"), + TORCH_FN(linear_run)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_gru_context"), + TORCH_FN(run_gru_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::gru_run"), + TORCH_FN(gru_run)); // Backwards compatibility + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"), + TORCH_FN(run_lstm_context)); } Tensor convolution( @@ -207,17 +262,9 @@ Tensor convolution( const int64_t groups) { if (transposed) { VulkanOpContext vulkan_context = conv2d_transpose_context_create( - weight, - bias, - stride, - padding, - output_padding, - dilation, - groups); + weight, bias, stride, padding, output_padding, dilation, groups); return conv2d_transpose_context_run( - input, - vulkan_context.get_packed(), - vulkan_context.get_unpacked()); + input, vulkan_context.get_packed(), vulkan_context.get_unpacked()); } VulkanOpContext vulkan_context = conv2d_context_create( weight, @@ -229,16 +276,106 @@ Tensor convolution( output_padding, groups); return conv2d_context_run( - input, - vulkan_context.get_packed(), - vulkan_context.get_unpacked()); + input, vulkan_context.get_packed(), vulkan_context.get_unpacked()); +} + +Tensor quantized_convolution( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool transposed, + const IntArrayRef output_padding, + const int64_t groups, + const double out_scale, + const int64_t out_zero_point) { + if (transposed) { + VulkanOpContext vulkan_context = conv2d_transpose_context_create( + weight, bias, stride, padding, output_padding, dilation, groups); + return conv2d_transpose_context_run( + input, vulkan_context.get_packed(), vulkan_context.get_unpacked()); + } + VulkanOpContext vulkan_context = conv2d_context_create_q( + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + c10::nullopt, + c10::nullopt); + return conv2d_context_run_q( + input, + vulkan_context.get_packed(), + vulkan_context.get_unpacked(), + out_scale, + out_zero_point); +} +} // namespace + +static std::tuple batchify( + const Tensor& input, + const int64_t num_spatial_dims, + const std::string& func_name) { + const auto dim_count_no_batch = num_spatial_dims + 1; + const auto dim_count_batch = dim_count_no_batch + 1; + const auto is_batched = (input.dim() == dim_count_batch); + TORCH_CHECK( + input.dim() == dim_count_no_batch || is_batched, + "Expected ", + dim_count_no_batch, + "D (unbatched) or ", + dim_count_batch, + "D (batched) input to ", + func_name, + ", but got input of size: ", + input.sizes()); + return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched); +} + +Tensor conv2d( + const Tensor& input_, + const Tensor& weight, + const c10::optional& bias_opt, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + double out_scale, + int64_t out_zero_point) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + Tensor input; + bool is_batched; + std::tie(input, is_batched) = + batchify(input_, /*num_spatial_dims=*/2, "conv2d"); + Tensor output; + output = quantized_convolution( + input, + weight, + bias, + stride, + padding, + dilation, + false, + {{0, 0}}, + groups, + out_scale, + out_zero_point); + return is_batched ? output : output.squeeze(0); } TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl("convolution_overrideable", convolution); } -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Select.cpp b/aten/src/ATen/native/vulkan/ops/Select.cpp new file mode 100644 index 0000000000000..2500144b1b280 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Select.cpp @@ -0,0 +1,91 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor select_depth(const Tensor& input_arg, uint32_t index) { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); + + vTensor v_output{ + context, + {v_input_sizes[1], v_input_sizes[2]}, + v_input.options(), + }; + + const struct Block final { + uvec3 size; // output texture size + uint32_t index; + } block{v_output.extents(), index}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(select_depth), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +Tensor select(const Tensor& self, int64_t dim, int64_t index) { + TORCH_CHECK(self.dim() == 3, "Vulkan select only supports 3d tensors!"); + TORCH_CHECK(dim == 0, "Vulkan select only supports dim = 0!"); + + const int64_t size = self.size(dim); + + if (index < -size || index >= size) { + TORCH_CHECK_INDEX( + false, + "select(): index ", + index, + " out of range for tensor of size ", + self.sizes(), + " at dimension ", + dim); + } + if (index < 0) { + index += size; + } + + return select_depth(self, index); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::select.int"), TORCH_FN(select)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 86a466942052f..f5c47187c3bda 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -1,74 +1,68 @@ -#include +#include #include +#include #include namespace at { namespace native { namespace vulkan { namespace ops { -namespace { -Tensor view_internal( - const Tensor& self_arg, - const IntArrayRef shape, - const std::string& op_name) { +Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { api::Context* const context = api::context(); - const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); - const vTensor& v_self = convert(self); + Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + vTensor& v_self = convert(self); + + at::DimVector inferred_size = at::infer_size_dv(shape, self.numel()); vTensor v_output{ - context, - shape, - self.options(), + context, + inferred_size, + self.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - command_buffer.copy( - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_self.buffer( - command_buffer, - vTensor::Stage::Transfer), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.buffer( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write)); - } - command_pool.submit(context->gpu().queue, command_buffer); + api::StagingBuffer buffer(context, v_self.buffer_bytes(), true); + + utils::pack_vtensor_to_staging(v_self, buffer.buffer()); + + api::PipelineBarrier pipeline_barrier{}; + add_buffer_barrier( + pipeline_barrier, + buffer.buffer(), + // Previous access + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE, + // Next access + api::PipelineStage::COMPUTE, + api::MemoryAccessType::READ); + + utils::pack_buffer_to_vtensor(buffer.buffer(), v_output, pipeline_barrier); return convert(v_output); } -inline Tensor view( - const Tensor& self_arg, - const IntArrayRef shape) { - return view_internal(self_arg, shape, "aten::view"); +inline Tensor view(const Tensor& self_arg, const IntArrayRef shape) { + return view_internal(self_arg, shape); } Tensor _reshape_alias( const Tensor& self_arg, const IntArrayRef shape, const IntArrayRef strides) { - return view_internal(self_arg, shape, "aten::_reshape_alias"); + return view_internal(self_arg, shape); } #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::view"), TORCH_FN(view)); - m.impl(TORCH_SELECTIVE_NAME("aten::_reshape_alias"), TORCH_FN(_reshape_alias)); + m.impl( + TORCH_SELECTIVE_NAME("aten::_reshape_alias"), TORCH_FN(_reshape_alias)); } #endif /* USE_VULKAN_API */ -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Slice.cpp b/aten/src/ATen/native/vulkan/ops/Slice.cpp index 1d454c7ff7098..d45bff6af4066 100644 --- a/aten/src/ATen/native/vulkan/ops/Slice.cpp +++ b/aten/src/ATen/native/vulkan/ops/Slice.cpp @@ -1,6 +1,4 @@ #include -#include -#include #include #include @@ -12,193 +10,214 @@ namespace { using namespace api::utils; -Tensor slice_4d(const Tensor& input, const int64_t dim, const int64_t start, const int64_t end, - const int64_t step, const uvec4& in_tsize, const uvec4& out_tsize, vTensor& v_output) { +Tensor slice_4d( + const Tensor& input_arg, + const int64_t dim, + const int64_t start, + const int64_t end, + const int64_t step, + const uvec4& in_tsize, + const uvec4& out_tsize, + vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::slice.Tensor (slice_4d)"); - - const Tensor self = input.is_vulkan() ? input : input.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Compute); - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write); - - const struct Block final { - uvec3 size; // output texture size - uint32_t fill_0; // dummy - uvec3 isize; // input texture size - uint32_t fill_1; // dummy - uvec4 tensor_size; // output tensor size - uvec4 itensor_size; // input tensor size - uvec4 args; // input arguments (dim, start, end, step) - } block { - v_output.extents(), - 0u, - v_self.extents(), - 0u, - out_tsize, - in_tsize, - { safe_downcast(dim), - safe_downcast(start), - safe_downcast(end), - safe_downcast(step) }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(slice_4d), - // build up shader operations from the output texture point of view - // to avoid the nondeterministic order of GPU shader operations between texels - v_output.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - dst_image, - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - src_image, - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_self = convert(input); + + const struct Block final { + uvec3 size; // output texture size + uint32_t fill_0; // dummy + uvec3 isize; // input texture size + uint32_t fill_1; // dummy + uvec4 tensor_size; // output tensor size + uvec4 itensor_size; // input tensor size + uvec4 args; // input arguments (dim, start, end, step) + } block{ + v_output.extents(), + 0u, + v_self.extents(), + 0u, + out_tsize, + in_tsize, + {safe_downcast(dim), + safe_downcast(start), + safe_downcast(end), + safe_downcast(step)}, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(slice_4d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + return convert(v_output); } -Tensor slice_width(const Tensor& input, const int64_t start, const int64_t end, const int64_t step, vTensor& v_output) { +Tensor slice_width( + const Tensor& input_arg, + const int64_t start, + const int64_t end, + const int64_t step, + vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::slice.Tensor (slice_width)"); - - const Tensor self = input.is_vulkan() ? input : input.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Transfer); - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write); - - uvec3 src_offset{}; - uvec3 dst_offset{}; - - if (step == 1) { - src_offset.data[0u] = start; - uvec3 copy_extents {safe_downcast(end - start), - v_self.extents().data[1u], - v_self.extents().data[2u]}; - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_self = convert(input); + + uvec3 src_offset{}; + uvec3 dst_offset{}; + + if (step == 1) { + src_offset.data[0u] = start; + + uvec3 copy_extents{ + safe_downcast(end - start), + v_self.extents().data[1u], + v_self.extents().data[2u]}; + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details + copy_extents, + src_offset, + dst_offset, + // fence handle + VK_NULL_HANDLE); + } else { + uvec3 copy_extents{ + 1u, v_self.extents().data[1u], v_self.extents().data[2u]}; + + const auto x_max = v_self.extents().data[0u]; + + for (int64_t x = start, x_new = 0; x < end; x += step, ++x_new) { + if (x >= x_max) { // out of range + continue; + } + + src_offset.data[0u] = x; + dst_offset.data[0u] = x_new; + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details copy_extents, src_offset, - dst_offset); - } else { - uvec3 copy_extents {1u, - v_self.extents().data[1u], - v_self.extents().data[2u]}; - const auto x_max = v_self.extents().data[0u]; - for (int64_t x = start, x_new = 0; x < end; x += step, ++x_new) { - if (x >= x_max) { // out of range - continue; - } - src_offset.data[0u] = x; - dst_offset.data[0u] = x_new; - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, - copy_extents, - src_offset, - dst_offset); - } - } - } - else { - TORCH_CHECK(false, "Not implemented!"); + dst_offset, + // fence handle + VK_NULL_HANDLE); } } - command_pool.submit(context->gpu().queue, command_buffer); + return convert(v_output); } -Tensor slice_height(const Tensor& input, const int64_t start, const int64_t end, const int64_t step, vTensor& v_output) { +Tensor slice_height( + const Tensor& input_arg, + const int64_t start, + const int64_t end, + const int64_t step, + vTensor& v_output) { api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::slice.Tensor (slice_height)"); - - const Tensor self = input.is_vulkan() ? input : input.vulkan(); - const vTensor& v_self = convert(self); - if C10_LIKELY(v_output.has_image() && v_self.has_image()) { - auto src_image = v_self.image( - command_buffer, - vTensor::Stage::Transfer); - auto dst_image = v_output.image( - command_buffer, - vTensor::Stage::Transfer, - vTensor::Access::Write); - - uvec3 src_offset{}; - uvec3 dst_offset{}; - - if (step == 1) { - src_offset.data[1u] = start; - uvec3 copy_extents {v_self.extents().data[0u], - safe_downcast(end - start), - v_self.extents().data[2u]}; - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_self = convert(input); + + uvec3 src_offset{}; + uvec3 dst_offset{}; + + if (step == 1) { + src_offset.data[1u] = start; + + uvec3 copy_extents{ + v_self.extents().data[0u], + safe_downcast(end - start), + v_self.extents().data[2u]}; + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details + copy_extents, + src_offset, + dst_offset, + // fence handle + VK_NULL_HANDLE); + } else { + uvec3 copy_extents{ + v_self.extents().data[0u], 1u, v_self.extents().data[2u]}; + + const auto y_max = v_self.extents().data[1u]; + for (int64_t y = start, y_new = 0; y < end; y += step, ++y_new) { + if (y >= y_max) { // out of range + continue; + } + src_offset.data[1u] = y; + dst_offset.data[1u] = y_new; + + api::PipelineBarrier pipeline_barrier{}; + + context->submit_texture_copy( + // pipeline barrier + pipeline_barrier, + // images + v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER), + v_output.image( + pipeline_barrier, + api::PipelineStage::TRANSFER, + api::MemoryAccessType::WRITE), + // copy details copy_extents, src_offset, - dst_offset); - } else { - uvec3 copy_extents {v_self.extents().data[0u], - 1u, - v_self.extents().data[2u]}; - const auto y_max = v_self.extents().data[1u]; - for (int64_t y = start, y_new = 0; y < end; y += step, ++y_new) { - if (y >= y_max) { // out of range - continue; - } - src_offset.data[1u] = y; - dst_offset.data[1u] = y_new; - api::helper::copy_texture_to_texture(command_buffer, - src_image, - dst_image, - copy_extents, - src_offset, - dst_offset); - } - } - } - else { - TORCH_CHECK(false, "Not implemented!"); + dst_offset, + // fence handle + VK_NULL_HANDLE); } } - command_pool.submit(context->gpu().queue, command_buffer); + return convert(v_output); } @@ -250,19 +269,15 @@ Tensor slice( } dim += 4 - nDims; - vTensor v_output{ - api::context(), - newSizes, - self.options()}; + vTensor v_output{api::context(), newSizes, self.options()}; if (dim == 3) { slice_width(self, start_val, end_val, step, v_output); - } - else if (dim == 2) { + } else if (dim == 2) { slice_height(self, start_val, end_val, step, v_output); - } - else { - slice_4d(self, dim, start_val, end_val, step, in_tsize, out_tsize, v_output); + } else { + slice_4d( + self, dim, start_val, end_val, step, in_tsize, out_tsize, v_output); } auto result = convert(v_output); diff --git a/aten/src/ATen/native/vulkan/ops/Softmax.cpp b/aten/src/ATen/native/vulkan/ops/Softmax.cpp index f36a5fc54540a..b0e449fc8ad35 100644 --- a/aten/src/ATen/native/vulkan/ops/Softmax.cpp +++ b/aten/src/ATen/native/vulkan/ops/Softmax.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -15,15 +14,13 @@ Tensor softmax_internal( const at::Tensor& input_arg, const int64_t dim, const bool half_to_float, - const api::Shader::Descriptor& shader_descriptor, - const std::string& op_name) { + const api::ShaderSource& shader_descriptor) { TORCH_CHECK( - input_arg.dim() == 4, - "Vulkan softmax expects 4-dimensional input!"); + input_arg.dim() == 4, "Vulkan softmax expects 4-dimensional input!"); - TORCH_CHECK( - dim == 1, - "Vulkan softmax expects dim == 1 (channel)"); + TORCH_CHECK(dim == 1, "Vulkan softmax expects dim == 1 (channel)"); + + api::Context* const context = api::context(); const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); const vTensor& v_input = convert(input); @@ -33,74 +30,56 @@ Tensor softmax_internal( v_input_sizes[Layout::Activation4D::batch] == 1, "Vulkan softmax expects batch dim == 1"); - api::Context* const context = api::context(); - c10::SmallVector output_sizes{ - v_input_sizes[Layout::Activation4D::batch], - v_input_sizes[Layout::Activation4D::channels], - v_input_sizes[Layout::Activation4D::height], - v_input_sizes[Layout::Activation4D::width], + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], + v_input_sizes[Layout::Activation4D::height], + v_input_sizes[Layout::Activation4D::width], }; vTensor v_output{ - context, - output_sizes, - v_input.options(), + context, + output_sizes, + v_input.options(), }; - const api::Shader::WorkGroup global_work_group_size = { - safe_downcast(v_input_sizes[Layout::Activation4D::width]), - safe_downcast(v_input_sizes[Layout::Activation4D::height]), - 1, + const api::utils::uvec3 global_work_group_size = { + safe_downcast(v_input_sizes[Layout::Activation4D::width]), + safe_downcast(v_input_sizes[Layout::Activation4D::height]), + 1, }; - const api::Shader::WorkGroup local_work_group_size = {8, 8, 1}; - - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), op_name); - - if C10_LIKELY(v_input.has_image()) { - const struct Block final { - uvec3 iextents; - int last_texel_end_offset; - } block { - v_input.extents(), - safe_downcast( - (v_input_sizes[Layout::Activation4D::channels] - 1) % 4 - ) - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader_descriptor, - global_work_group_size, - local_work_group_size, - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const api::utils::uvec3 local_work_group_size = {8, 8, 1}; + + const struct Block final { + uvec3 iextents; + int last_texel_end_offset; + } block{ + v_input.extents(), + safe_downcast( + (v_input_sizes[Layout::Activation4D::channels] - 1) % 4)}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader_descriptor, + // pipeline barrier + pipeline_barrier, + // global work group size + global_work_group_size, + // local work group size + local_work_group_size, + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } @@ -109,14 +88,15 @@ Tensor softmax( const at::Tensor& input_arg, const int64_t dim, const bool half_to_float) { - return softmax_internal(input_arg, dim, half_to_float, VK_KERNEL(softmax), "_softmax"); + return softmax_internal(input_arg, dim, half_to_float, VK_KERNEL(softmax)); } Tensor log_softmax( const at::Tensor& input_arg, const int64_t dim, const bool half_to_float) { - return softmax_internal(input_arg, dim, half_to_float, VK_KERNEL(log_softmax), "_log_softmax"); + return softmax_internal( + input_arg, dim, half_to_float, VK_KERNEL(log_softmax)); } #ifdef USE_VULKAN_API diff --git a/aten/src/ATen/native/vulkan/ops/Stack.cpp b/aten/src/ATen/native/vulkan/ops/Stack.cpp new file mode 100644 index 0000000000000..1206cf8c58556 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Stack.cpp @@ -0,0 +1,108 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor stack_feature(const TensorList tensors, vTensor& v_output) { + api::Context* const context = api::context(); + + uint32_t num_tensors = tensors.size(); + for (const auto i : c10::irange(v_output.extents().data[2])) { + const vTensor& v_t0 = convert( + tensors[4 * i].is_vulkan() ? tensors[4 * i] : tensors[4 * i].vulkan()); + const vTensor& v_t1 = 4 * i + 1 < num_tensors + ? convert( + tensors[4 * i + 1].is_vulkan() ? tensors[4 * i + 1] + : tensors[4 * i + 1].vulkan()) + : v_t0; + const vTensor& v_t2 = 4 * i + 2 < num_tensors + ? convert( + tensors[4 * i + 2].is_vulkan() ? tensors[4 * i + 2] + : tensors[4 * i + 2].vulkan()) + : v_t0; + const vTensor& v_t3 = 4 * i + 3 < num_tensors + ? convert( + tensors[4 * i + 3].is_vulkan() ? tensors[4 * i + 3] + : tensors[4 * i + 3].vulkan()) + : v_t0; + + const struct Block final { + uvec3 size; // output texture size + uint32_t z; // texel along the channel-batch dimension to copy data to + } block{v_output.extents(), i}; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(stack_feature), + // pipeline barrier + pipeline_barrier, + // global work group size + v_t0.extents(), + // local work group size + adaptive_work_group_size(v_t0.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_t0.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_t1.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_t2.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_t3.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + } + + return convert(v_output); +} + +Tensor stack(const at::TensorList tensors, const int64_t dim) { + TORCH_CHECK(tensors.size() > 0, "Vulkan stack expects at least one tensor"); + TORCH_CHECK(dim == 0, "Vulkan stack expects dim = 0"); + + at::Tensor tensor = tensors[0]; + + for (const auto& t : tensors) { + TORCH_CHECK(t.dim() == 2, "Vulkan stack expects 2 dimensional inputs"); + + for (const auto d : c10::irange(t.dim())) { + TORCH_CHECK( + t.size(d) == tensor.size(d), + "Vulkan stack inputs must have matching sizes"); + } + } + + uint32_t num_tensors = tensors.size(); + std::vector output_sizes = { + num_tensors, tensor.size(0), tensor.size(1)}; + + vTensor v_output{api::context(), output_sizes, tensor.options()}; + + return stack_feature(tensors, v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl(TORCH_SELECTIVE_NAME("aten::stack"), TORCH_FN(stack)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp index 0de253447d2d8..337222c0fd5ff 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -1,229 +1,15 @@ -#include -#include #include +#include #include namespace at { namespace native { namespace vulkan { namespace ops { -namespace { - -using namespace api::utils; - -uvec3 image_extents(IntArrayRef); -bool requires_image(IntArrayRef); -bool requires_staging(const api::Adapter*); - -VkFormat vk_format(const caffe2::TypeMeta dtype) { - switch (c10::typeMetaToScalarType(dtype)) { - case kFloat: - #ifdef USE_VULKAN_FP16_INFERENCE - return VK_FORMAT_R16G16B16A16_SFLOAT; - #else - return VK_FORMAT_R32G32B32A32_SFLOAT; - #endif /* USE_VULKAN_FP16_INFERENCE */ - - default: - TORCH_CHECK( - false, - "Vulkan tensor format not supported!"); - } - - return VK_FORMAT_UNDEFINED; -} - -VkExtent3D vk_extent(const uvec3& extent) { - return { - extent.data[0u], - extent.data[1u], - extent.data[2u], - }; -} - -vTensor::Access::Flags access( - const VkAccessFlags vk_access) { - vTensor::Access::Flags access = 0u; - - constexpr VkAccessFlags kRead = - VK_ACCESS_HOST_READ_BIT | - VK_ACCESS_MEMORY_READ_BIT | - VK_ACCESS_SHADER_READ_BIT | - VK_ACCESS_TRANSFER_READ_BIT | - VK_ACCESS_UNIFORM_READ_BIT; - - constexpr VkAccessFlags kWrite = - VK_ACCESS_HOST_WRITE_BIT | - VK_ACCESS_MEMORY_WRITE_BIT | - VK_ACCESS_SHADER_WRITE_BIT | - VK_ACCESS_TRANSFER_WRITE_BIT; - - if (vk_access & kRead) { - access |= vTensor::Access::Read; - } - - if (vk_access & kWrite) { - access |= vTensor::Access::Write; - } - - return access; -} - -VkAccessFlags vk_access( - const vTensor::Stage::Flags stage, - const vTensor::Access::Flags access) { - VkAccessFlags vk_access = 0u; - - if (access & vTensor::Access::Read) { - if (stage & vTensor::Stage::Compute) { - vk_access |= VK_ACCESS_SHADER_READ_BIT; - } - - if (stage & vTensor::Stage::Host) { - vk_access |= VK_ACCESS_HOST_READ_BIT; - } - - if (stage & vTensor::Stage::Transfer) { - vk_access |= VK_ACCESS_TRANSFER_READ_BIT; - } - } - - if (access & vTensor::Access::Write) { - if (stage & vTensor::Stage::Compute) { - vk_access |= VK_ACCESS_SHADER_WRITE_BIT; - } - - if (stage & vTensor::Stage::Host) { - vk_access |= VK_ACCESS_HOST_WRITE_BIT; - } - - if (stage & vTensor::Stage::Transfer) { - vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT; - } - } - - return vk_access; -} - -VkImageLayout vk_layout( - const vTensor::Stage::Flags stage, - const vTensor::Access::Flags access) { - switch (stage) { - case vTensor::Stage::Compute: - switch (access) { - case vTensor::Access::Read: - return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL; - default: - return VK_IMAGE_LAYOUT_GENERAL; - } break; - - case vTensor::Stage::Transfer: - switch (access) { - case vTensor::Access::Read: - return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL; - - case vTensor::Access::Write: - return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL; - - default: - TORCH_INTERNAL_ASSERT(false, "Invalid!"); - } break; - - default: - TORCH_INTERNAL_ASSERT(false, "Invalid!"); - } - - return VK_IMAGE_LAYOUT_UNDEFINED; -} - -VkPipelineStageFlags vk_stage( - const vTensor::Stage::Flags stage) { - VkPipelineStageFlags vk_stage = 0u; - - if (stage & vTensor::Stage::Compute) { - vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; - } - - if (stage & vTensor::Stage::Host) { - vk_stage |= VK_PIPELINE_STAGE_HOST_BIT; - } - - if (stage & vTensor::Stage::Transfer) { - vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT; - } - - return vk_stage; -} - -VkDeviceSize buffer_bytes( - const IntArrayRef sizes, - const caffe2::TypeMeta dtype) { - VkDeviceSize size = c10::elementSize(c10::typeMetaToScalarType(dtype)); - - if (requires_image(sizes)) { - const uvec3 extents = image_extents(sizes); - size *= extents.data[0u] * extents.data[1u] * (4u * extents.data[2u]); - } - else { - size *= c10::multiply_integers(sizes); - } - - return size; -} - -vTensor::Buffer allocate_buffer( - const api::Adapter* const adapter, - api::Resource::Pool* const pool, - const IntArrayRef sizes, - const TensorOptions& options) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - adapter, - "Invalid Vulkan adapter!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pool, - "Invalid Vulkan resource pool!"); - - TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); - verify(options); - - const VkFlags usage = - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | - VK_BUFFER_USAGE_TRANSFER_SRC_BIT | - VK_BUFFER_USAGE_TRANSFER_DST_BIT; - - const auto memory = [adapter]() -> api::Resource::Memory::Descriptor { - if (requires_staging(adapter)) { - return { - VMA_MEMORY_USAGE_GPU_ONLY, - 0u, - 0u, - }; - } - - return { - VMA_MEMORY_USAGE_GPU_TO_CPU, - 0u, - VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, - }; - }(); - - return pool->create_buffer({ - buffer_bytes(sizes, options.dtype()), - // Usage - { - usage, - memory, - }, - }); -} - -bool requires_image(const IntArrayRef sizes) { - return (1u <= sizes.size()) && (sizes.size() <= 4u); -} +namespace { -uvec3 image_extents(const IntArrayRef sizes) { +api::utils::uvec3 image_extents(const IntArrayRef sizes) { int64_t width = 1; int64_t height = 1; int64_t depth = 1; @@ -257,889 +43,187 @@ uvec3 image_extents(const IntArrayRef sizes) { } return { - safe_downcast(width), - safe_downcast(height), - safe_downcast(div_up(depth, INT64_C(4))), + api::utils::safe_downcast(width), + api::utils::safe_downcast(height), + api::utils::safe_downcast( + api::utils::div_up(depth, INT64_C(4))), }; } -vTensor::Image allocate_image( - api::Resource::Pool* const pool, - const VkExtent3D& extents, - const TensorOptions& options) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pool, - "Invalid Vulkan resource pool!"); - - verify(options); - - return pool->create_image({ - VK_IMAGE_TYPE_3D, - vk_format(options.dtype()), - extents, - // Usage - { - VK_IMAGE_USAGE_SAMPLED_BIT | - VK_IMAGE_USAGE_STORAGE_BIT | - VK_IMAGE_USAGE_TRANSFER_SRC_BIT | // for vkCmdCopyImage - VK_IMAGE_USAGE_TRANSFER_DST_BIT, // for vkCmdCopyImage - { - VMA_MEMORY_USAGE_GPU_ONLY, - 0u, - 0u, - }, - }, - // View - { - VK_IMAGE_VIEW_TYPE_3D, - vk_format(options.dtype()), - }, - // Sampler - { - VK_FILTER_NEAREST, - VK_SAMPLER_MIPMAP_MODE_NEAREST, - VK_SAMPLER_ADDRESS_MODE_REPEAT, - VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK, - }, - }); -} - -bool requires_staging(const api::Adapter* const adapter) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - adapter, - "Invalid Vulkan adapter!"); - - return !adapter->has_unified_memory(); -} - -vTensor::Buffer allocate_staging( - const api::Adapter* const adapter, - api::Resource::Pool* const pool, - const IntArrayRef sizes, - const TensorOptions& options) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - adapter, - "Invalid Vulkan adapter!"); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pool, - "Invalid Vulkan resource pool!"); - - TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); - verify(options); - - return pool->create_buffer({ - buffer_bytes(sizes, options.dtype()), - // Usage - { - VK_BUFFER_USAGE_TRANSFER_SRC_BIT | - VK_BUFFER_USAGE_TRANSFER_DST_BIT, - { - VMA_MEMORY_USAGE_CPU_COPY, - 0u, - 0u, - }, - }, - }); -} - -vTensor::Fence allocate_fence(api::Resource::Pool* const pool) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pool, - "Invalid Vulkan resource pool!"); - - return pool->fence(); -} - -enum class Barrier { - None, - Exectution, - Memory, -}; - -Barrier categorize( - const VkAccessFlags vk_src_access, - const VkAccessFlags vk_dst_access) { - if (0u == vk_src_access) { - return Barrier::None; - } - - const vTensor::Access::Flags src_access = access(vk_src_access); - const vTensor::Access::Flags dst_access = access(vk_dst_access); - - if ((src_access & vTensor::Access::Read) == src_access) { - if ((dst_access & vTensor::Access::Read) == dst_access) { - // RAR (Read after Read) - return Barrier::None; - } - - // WAR (Write after Read) - return Barrier::Exectution; - } - - // RAW (Read after Write), or WAW (Write after Write) - return Barrier::Memory; -}; - -Barrier categorize( - const VkAccessFlags vk_src_access, - const VkAccessFlags vk_dst_access, - const VkImageLayout vk_src_layout, - const VkImageLayout vk_dst_layout) { - if (vk_src_layout != vk_dst_layout) { - return Barrier::Memory; - } - - return categorize(vk_src_access, vk_dst_access); -} - } // namespace +// +// vTensor +// + vTensor::vTensor( api::Context* const context, const IntArrayRef sizes, const TensorOptions& options) - : vTensor( - context, - &context->resource().pool, - sizes, - options) { -} + : view_(new vTensorStorage{ + context, + sizes, + options, + }) {} vTensor::vTensor( api::Context* const context, - api::Resource::Pool* const pool, const IntArrayRef sizes, - const TensorOptions& options) - : view_(new View{ - context, - pool, - sizes, - options, - }) { -} + const TensorOptions& options, + double q_scale, + int64_t q_zero_point) + : view_( + new vTensorStorage{context, sizes, options, q_scale, q_zero_point}) {} -const vTensor* vTensor::host( - api::Command::Buffer& command_buffer) const { - view_->staging(command_buffer, Stage::Host, Access::Read); - return this; -} +api::VulkanImage& vTensor::image( + api::PipelineBarrier& pipeline_barrier, + const api::PipelineStageFlags stage) const& { + view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ); -vTensor* vTensor::host( - api::Command::Buffer& command_buffer, - const Access::Flags access) { - view_->staging(command_buffer, Stage::Host, access); - return this; + return view_->image_; } -vTensor::Buffer::Object vTensor::buffer( - api::Command::Buffer& command_buffer, - const Stage::Flags stage) const & { - return view_->buffer( - command_buffer, - stage, - Access::Read).object; -} +api::VulkanImage& vTensor::image( + api::PipelineBarrier& pipeline_barrier, + const api::PipelineStageFlags stage, + const api::MemoryAccessFlags access) & { + view_->transition(pipeline_barrier, stage, access); -vTensor::Buffer::Object vTensor::buffer( - api::Command::Buffer& command_buffer, - const Stage::Flags stage, - const Access::Flags access) & { - return view_->buffer( - command_buffer, - stage, - access).object; + return view_->image_; } -vTensor::Image::Object vTensor::image( - api::Command::Buffer& command_buffer, - const Stage::Flags stage) const & { - return view_->image( - command_buffer, - stage, - Access::Read).object; -} +// +// vTensorStorage +// -vTensor::Image::Object vTensor::image( - api::Command::Buffer& command_buffer, - const Stage::Flags stage, - const Access::Flags access) & { - return view_->image( - command_buffer, - stage, - access).object; -} +api::VulkanImage allocate_image( + api::Context* const context_ptr, + api::utils::uvec3& extents, + const caffe2::TypeMeta dtype) { + api::Adapter* adapter_ptr = context_ptr->adapter_ptr(); -vTensor::View::View() - // Resources - : buffer_{}, - image_{}, - staging_{}, - fence_{}, - // Context - context_(nullptr), - pool_(nullptr), - // State - state_{}, - // Metadata - extents_{} { + api::ImageSampler::Properties sampler_props{ + VK_FILTER_NEAREST, + VK_SAMPLER_MIPMAP_MODE_NEAREST, + VK_SAMPLER_ADDRESS_MODE_REPEAT, + VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK, + }; + + VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props); + + return adapter_ptr->vma().create_image3d( + api::create_extent3d(extents), sampler_props, sampler, dtype, true); } -vTensor::View::View( +vTensorStorage::vTensorStorage( api::Context* const context, - api::Resource::Pool* const pool, const IntArrayRef sizes, const TensorOptions& options) - // Resources - : buffer_{}, - image_{}, - staging_{}, - fence_{}, - // Context - context_(context), - pool_(pool), - // State - state_(context->gpu().adapter, sizes), - // Metadata - extents_(image_extents(sizes)), - options_(options), - sizes_(sizes), - strides_(sizes.size()) { + : context_(context), + extents_(image_extents(sizes)), + options_(options), + sizes_(sizes), + strides_(sizes.size()), + image_(allocate_image(context_, extents_, options_.dtype())), + last_access_{} { ops::verify(options); } -vTensor::View::~View() { - release(); -} - -void vTensor::View::release() { - pool_->register_image_cleanup(image_); - pool_->register_buffer_cleanup(buffer_); - if (staging_) { - pool_->register_buffer_cleanup(staging_); - } +vTensorStorage::vTensorStorage( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options, + double q_scale_in, + int64_t q_zero_point_in) + : context_(context), + extents_(image_extents(sizes)), + options_(options), + sizes_(sizes), + strides_(sizes.size()), + is_quantized_{true}, + q_scale{q_scale_in}, + q_zero_point{q_zero_point_in}, + image_(allocate_image(context_, extents_, options_.dtype())), + last_access_{} { + ops::verify(options); } -class vTensor::View::CMD final { - public: - CMD(const View&, api::Command::Buffer&); - CMD(const CMD&) = delete; - CMD& operator=(const CMD&) = delete; - CMD(CMD&&) = delete; - CMD& operator=(CMD&&) = delete; - ~CMD() = default; - - typedef api::Resource::Buffer Buffer; - typedef api::Resource::Image Image; - typedef api::Resource::Fence Fence; - - void barrier(State::Transition transition); - - void copy_buffer_to_staging( - State& state, - const Buffer::Object& buffer, - Buffer::Object& staging); - - void copy_staging_to_buffer( - State& state, - const Buffer::Object& staging, - Buffer::Object& buffer); - - void copy_buffer_to_image( - State& state, - const Buffer::Object& buffer, - Image::Object& image); - - void copy_image_to_buffer( - State& state, - const Image::Object& image, - Buffer::Object& buffer); - - void submit(Fence fence); - - private: - const View& view_; - api::Command::Buffer& command_buffer_; -}; - -vTensor::View::CMD::CMD( - const View& view, - api::Command::Buffer& command_buffer) - : view_(view), - command_buffer_(command_buffer) { +vTensorStorage::~vTensorStorage() { + context_->register_image_cleanup(image_); } -void vTensor::View::CMD::barrier(State::Transition transition) { - // Buffer and Staging are just an alias for the same memory region on UMA. - - if (view_.state_.is_uma()) { - transition.first.buffer.stage |= transition.first.staging.stage; - transition.first.buffer.access |= transition.first.staging.access; - transition.first.staging = {}; - - transition.second.buffer.stage |= transition.second.staging.stage; - transition.second.buffer.access |= transition.second.staging.access; - transition.second.staging = {}; - } - - // Filter out host dependencies out of source, per Vulkan spec host write ordering guarantees: - // https://www.khronos.org/registry/vulkan/specs/1.2/html/vkspec.html#synchronization-submission-host-writes - - const auto filter_stage =[](VkPipelineStageFlags& stage) { - stage &= ~VK_PIPELINE_STAGE_HOST_BIT; - }; - - filter_stage(transition.first.buffer.stage); - filter_stage(transition.first.staging.stage); - - const auto filter_access =[](VkAccessFlags& access) { - access &= ~(VK_ACCESS_HOST_READ_BIT | VK_ACCESS_HOST_WRITE_BIT); - }; - - filter_access(transition.first.buffer.access); - filter_access(transition.first.staging.access); - - api::Pipeline::Barrier barrier{}; +void vTensorStorage::transition( + api::PipelineBarrier& pipeline_barrier, + const api::PipelineStageFlags cur_stage, + const api::MemoryAccessFlags cur_access) { + // Get last stage access + api::PipelineStageFlags prev_stage = last_access_.stage; + api::MemoryAccessFlags prev_access = last_access_.access; - if (transition.second.staging) { - const State::Bundle::Buffer from = transition.first.staging; - const State::Bundle::Buffer to = transition.second.staging; + const VkImageLayout cur_layout = image_.layout(); + const VkImageLayout new_layout = api::vk_layout(cur_stage, cur_access); - const Barrier category = categorize( - from.access, - to.access); + const bool layout_changed = cur_layout != new_layout; + const bool prev_written = (prev_access & api::MemoryAccessType::WRITE) != 0; - if (Barrier::None != category) { - barrier.stage.src |= from.stage; - barrier.stage.dst |= to.stage; - - if (Barrier::Memory == category) { - barrier.buffers.push_back({ - view_.staging().object, - { - from.access, - to.access, - }, - }); - } + if (prev_written || layout_changed) { + VkPipelineStageFlags src_stage = api::vk_stage(prev_stage); + if (0u == src_stage) { + src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; } - } - - if (transition.second.buffer) { - const State::Bundle::Buffer from = transition.first.buffer; - const State::Bundle::Buffer to = transition.second.buffer; - - const Barrier category = categorize( - from.access, - to.access); - - if (Barrier::None != category) { - barrier.stage.src |= from.stage; - barrier.stage.dst |= to.stage; - - if (Barrier::Memory == category) { - barrier.buffers.push_back({ - view_.buffer().object, - { - from.access, - to.access, - }, - }); - } + VkPipelineStageFlags dst_stage = api::vk_stage(cur_stage); + if (0u == dst_stage) { + dst_stage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; } - } - - if (transition.second.image) { - const State::Bundle::Image from = transition.first.image; - const State::Bundle::Image to = transition.second.image; - const Barrier category = categorize( - from.access, - to.access, - from.layout, - to.layout); + pipeline_barrier.stage.src |= src_stage; + pipeline_barrier.stage.dst |= dst_stage; - if (Barrier::None != category) { - barrier.stage.src |= from.stage; - barrier.stage.dst |= to.stage; + pipeline_barrier.images.push_back(api::ImageMemoryBarrier( + api::vk_access(prev_stage, prev_access), + api::vk_access(cur_stage, cur_access), + cur_layout, + new_layout, + image_)); - if (Barrier::Memory == category) { - TORCH_INTERNAL_ASSERT( - from.layout == view_.image().object.layout, - "Invalid image layout!"); - - barrier.images.push_back({ - view_.image().object, - { - from.access, - to.access, - }, - { - from.layout, - to.layout, - }, - }); - - view_.image().object.layout = to.layout; - } - } + image_.set_layout(new_layout); } - // If we are left with anything meaningful, insert a barrier. - - if (barrier) { - if (0u == barrier.stage.src) { - barrier.stage.src = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; - } - - if (0u == barrier.stage.dst) { - barrier.stage.src = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; - } - - command_buffer_.barrier(barrier); - } + last_access_.stage = cur_stage; + last_access_.access = cur_access; } -void vTensor::View::CMD::copy_buffer_to_staging( - State& state, - const Buffer::Object& buffer, - Buffer::Object& staging) { - if (state.is_clean(Component::Staging) || state.is_uma()) { - return; - } +void add_buffer_barrier( + api::PipelineBarrier& pipeline_barrier, + const api::VulkanBuffer& buffer, + const api::PipelineStageFlags prev_stage, + const api::MemoryAccessFlags prev_access, + const api::PipelineStageFlags cur_stage, + const api::MemoryAccessFlags cur_access) { + // Check for RAW + const bool read_requested = (cur_access & api::MemoryAccessType::READ) != 0; + const bool prev_written = (prev_access & api::MemoryAccessType::WRITE) != 0; - barrier( - state.transition({ - // Staging - { - vk_stage(Stage::Transfer), - vk_access(Stage::Transfer, Access::Write), - }, - // Buffer - { - vk_stage(Stage::Transfer), - vk_access(Stage::Transfer, Access::Read), - }, - // Image - {}, - })); + const bool is_RAW = read_requested && prev_written; - command_buffer_.copy(buffer, staging); -} - -void vTensor::View::CMD::copy_staging_to_buffer( - State& state, - const Buffer::Object& staging, - Buffer::Object& buffer) { - if (state.is_clean(Component::Buffer) || state.is_uma()) { - return; - } - - barrier( - state.transition({ - // Staging - { - vk_stage(Stage::Transfer), - vk_access(Stage::Transfer, Access::Read), - }, - // Buffer - { - vk_stage(Stage::Transfer), - vk_access(Stage::Transfer, Access::Write), - }, - // Image - {}, - })); - - command_buffer_.copy(staging, buffer); -} - -void vTensor::View::CMD::copy_buffer_to_image( - State& state, - const Buffer::Object& buffer, - Image::Object& image) { - if (state.is_clean(Component::Image)) { - return; - } - - api::OpProfiler profiler(command_buffer_, view_.context_->querypool(), "copy_buffer_to_image"); - - barrier( - state.transition({ - // Staging - {}, - // Buffer - { - vk_stage(Stage::Compute), - vk_access(Stage::Compute, Access::Read), - }, - // Image - { - vk_stage(Stage::Compute), - vk_access(Stage::Compute, Access::Write), - vk_layout(Stage::Compute, Access::Write), - }, - })); - - const uvec3 extents = view_.extents(); - const uint32_t plane = extents.data[0u] * extents.data[1u]; - - const struct Block final { - uvec3 extents; - uint32_t block; - uvec4 offset; - } block { - extents, - 4u * plane, - { - 0u * plane, - 1u * plane, - 2u * plane, - 3u * plane, - }, - }; - - view_.context_->dispatch( - command_buffer_, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(nchw_to_image), - extents, - adaptive_work_group_size(extents), - image, - buffer, - view_.context_->resource().pool.uniform(block).object); -} - -void vTensor::View::CMD::copy_image_to_buffer( - State& state, - const Image::Object& image, - Buffer::Object& buffer) { - if (state.is_clean(Component::Buffer)) { - return; - } - - api::OpProfiler profiler(command_buffer_, view_.context_->querypool(), "copy_image_to_buffer"); - - barrier( - state.transition({ - // Staging - {}, - // Buffer - { - vk_stage(Stage::Compute), - vk_access(Stage::Compute, Access::Write), - }, - // Image - { - vk_stage(Stage::Compute), - vk_access(Stage::Compute, Access::Read), - vk_layout(Stage::Compute, Access::Read), - }, - })); - - const uvec3 extents = view_.extents(); - const uint32_t plane = extents.data[0u] * extents.data[1u]; - - const struct Block final { - uvec3 extents; - uint32_t block; - uvec4 offset; - } block { - extents, - 4u * plane, - { - 0u * plane, - 1u * plane, - 2u * plane, - 3u * plane, - }, - }; - - view_.context_->dispatch( - command_buffer_, - { - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(image_to_nchw), - view_.extents(), - adaptive_work_group_size(view_.extents()), - image, - buffer, - view_.context_->resource().pool.uniform(block).object); -} - -void vTensor::View::CMD::submit(const api::Resource::Fence fence) { - view_.context_->command().pool.submit( - view_.context_->gpu().queue, - command_buffer_, - fence); -} - -vTensor::Buffer& vTensor::View::buffer() const { - if (!buffer_) { - buffer_ = allocate_buffer( - context_->gpu().adapter, - pool_, - sizes(), - options()); - } - - return buffer_; -} - -vTensor::Buffer& vTensor::View::buffer( - api::Command::Buffer& command_buffer, - const Stage::Flags stage, - const Access::Flags access) const { - CMD cmd(*this, command_buffer); - return buffer(cmd, stage, access); -} - -vTensor::Buffer& vTensor::View::buffer( - CMD& cmd, - const Stage::Flags stage, - const Access::Flags access) const { - if ((access & Access::Read) && state_.is_dirty(Component::Buffer)) { - if (state_.is_clean(Component::Staging)) { - cmd.copy_staging_to_buffer( - state_, - staging(cmd, Stage::Transfer, Access::Read).object, - buffer().object); + if (is_RAW) { + VkPipelineStageFlags src_stage = api::vk_stage(prev_stage); + if (0u == src_stage) { + src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; } - else if (state_.is_clean(Component::Image)) { - cmd.copy_image_to_buffer( - state_, - image(cmd, Stage::Compute, Access::Read).object, - buffer().object); + VkPipelineStageFlags dst_stage = api::vk_stage(cur_stage); + if (0u == dst_stage) { + dst_stage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; } - else { - TORCH_INTERNAL_ASSERT( - false, - "Invalid state!"); - } - } - - cmd.barrier( - state_.transition({ - // Staging - {}, - // Buffer - { - vk_stage(stage), - vk_access(stage, access), - }, - // Image - {}, - })); - if (access & Access::Write) { - state_.set_dirty(Component::All); - } - - state_.set_clean(Component::Buffer); + pipeline_barrier.stage.src |= src_stage; + pipeline_barrier.stage.dst |= dst_stage; - return buffer(); -} - -vTensor::Image& vTensor::View::image() const { - if (!image_ && state_.is_available(Component::Image)) { - image_ = allocate_image( - pool_, - vk_extent(extents()), - options()); + pipeline_barrier.buffers.push_back(api::BufferMemoryBarrier( + api::vk_access(prev_stage, prev_access), + api::vk_access(cur_stage, cur_access), + buffer)); } - - return image_; -} - -vTensor::Image& vTensor::View::image( - api::Command::Buffer& command_buffer, - const Stage::Flags stage, - const Access::Flags access) const { - CMD cmd(*this, command_buffer); - return image(cmd, stage, access); -} - -vTensor::Image& vTensor::View::image( - CMD& cmd, - const Stage::Flags stage, - const Access::Flags access) const { - if ((access & Access::Read) && state_.is_dirty(Component::Image)) { - cmd.copy_buffer_to_image( - state_, - buffer(cmd, stage, Access::Read).object, - image().object); - } - - cmd.barrier( - state_.transition({ - // Staging - {}, - // Buffer - {}, - // Image - { - vk_stage(stage), - vk_access(stage, access), - vk_layout(stage, access), - }, - })); - - if (access & Access::Write) { - state_.set_dirty(Component::All); - } - - state_.set_clean(Component::Image); - - return image(); -} - -vTensor::Buffer& vTensor::View::staging() const { - if (!state_.is_available(Component::Staging)) { - return buffer(); - } - - if (!staging_) { - staging_ = allocate_staging( - context_->gpu().adapter, - pool_, - sizes(), - options()); - } - - return staging_; -} - -vTensor::Buffer& vTensor::View::staging( - api::Command::Buffer& command_buffer, - const Stage::Flags stage, - const Access::Flags access) const { - CMD cmd(*this, command_buffer); - Buffer& staging = this->staging(cmd, stage, access); - cmd.submit(fence(access)); - - return staging; -} - -vTensor::Buffer& vTensor::View::staging( - CMD& cmd, - const Stage::Flags stage, - const Access::Flags access) const { - if ((access & Access::Read) && state_.is_dirty(Component::Staging)) { - cmd.copy_buffer_to_staging( - state_, - buffer(cmd, Stage::Transfer, Access::Read).object, - staging().object); - } - - cmd.barrier( - state_.transition({ - // Staging - { - vk_stage(stage), - vk_access(stage, access), - }, - // Buffer - {}, - // Image - {}, - })); - - if (access & Access::Write) { - state_.set_dirty(Component::All); - } - - state_.set_clean(Component::Staging); - - return staging(); -} - -vTensor::Fence& vTensor::View::fence(const Access::Flags access) const { - if (access & Access::Read) { - fence_ = allocate_fence(&context_->resource().pool); - } - - return fence_; -} - -vTensor::Memory& vTensor::View::wait() const { - if (fence_) { - fence_.wait(); - } - - return staging().memory; -} - -void vTensor::View::verify() const { - TORCH_INTERNAL_ASSERT(!image_ || state_.is_available(Component::Image)); - TORCH_INTERNAL_ASSERT(!staging_ || state_.is_discrete()); -} - -vTensor::View::State::State() - : available_{}, - dirty_{}, - bundle_{} { -} - -vTensor::View::State::State( - const api::Adapter* const adapter, - const IntArrayRef sizes) - : available_{}, - dirty_{}, - bundle_{} { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - adapter, - "Invalid Vulkan adapter!"); - - available_ |= Component::Buffer; - - if (requires_image(sizes)) { - available_ |= Component::Image; - } - - if (requires_staging(adapter)) { - available_ |= Component::Staging; - } -} - -#ifdef VULKAN_TENSOR_DEBUG -std::ostream& operator<<( - std::ostream&, - const vTensor::View::State::Bundle&); -#endif /* VULKAN_TENSOR_DEBUG */ - -vTensor::View::State::Transition -vTensor::View::State::transition(const Bundle bundle) { - const Bundle from = bundle_; - Bundle& to = bundle_; - - if (bundle.staging) { - to.staging = bundle.staging; - } - - if (bundle.buffer) { - to.buffer = bundle.buffer; - } - - if (bundle.image) { - to.image = bundle.image; - } - -#ifdef VULKAN_TENSOR_DEBUG - std::cout << "From:" << std::endl << from << std::endl; - std::cout << "To:" << std::endl << to << std::endl; -#endif /* VULKAN_TENSOR_DEBUG */ - - return Transition{ - from, - to, - }; } void verify(const TensorOptions& options) { @@ -1161,175 +245,6 @@ void verify(const TensorOptions& options) { "'memory_format' tensor option is not yet supported under Vulkan!"); } -// -// Debug -// - -#ifdef VULKAN_TENSOR_DEBUG - -namespace { - -// Considering that VkAccessFlags is a weak typedef of a built-in data type, we -// need to introduce a new type to allow overload resolution distinguish between -// the two. - -struct Access final { - VkAccessFlags value; -}; - -std::ostream& operator<<( - std::ostream& stream, - const Access& access) { - stream << "Access: "; - - if (0u == access.value) { - return stream << " 0"; - } - - if (access.value & VK_ACCESS_HOST_READ_BIT) { - stream << " VK_ACCESS_HOST_READ_BIT"; - } - - if (access.value & VK_ACCESS_HOST_WRITE_BIT) { - stream << " VK_ACCESS_HOST_WRITE_BIT"; - } - - if (access.value & VK_ACCESS_MEMORY_READ_BIT) { - stream << " VK_ACCESS_MEMORY_READ_BIT"; - } - - if (access.value & VK_ACCESS_MEMORY_WRITE_BIT) { - stream << " VK_ACCESS_MEMORY_WRITE_BIT"; - } - - if (access.value & VK_ACCESS_SHADER_READ_BIT) { - stream << " VK_ACCESS_SHADER_READ_BIT"; - } - - if (access.value & VK_ACCESS_SHADER_WRITE_BIT) { - stream << " VK_ACCESS_SHADER_WRITE_BIT"; - } - - if (access.value & VK_ACCESS_TRANSFER_READ_BIT) { - stream << " VK_ACCESS_TRANSFER_READ_BIT"; - } - - if (access.value & VK_ACCESS_TRANSFER_WRITE_BIT) { - stream << " VK_ACCESS_TRANSFER_WRITE_BIT"; - } - - return stream; -} - -// Considering that VkImageLayout is a weak typedef of a built-in data type, -// we need to introduce a new type to allow overload resolution distinguish -// between the two. - -struct Image final { - struct Layout final { - VkImageLayout value; - }; -}; - -std::ostream& operator<<( - std::ostream& stream, - const Image::Layout& layout) { - stream << "Layout: "; - - switch (layout.value) { - case VK_IMAGE_LAYOUT_UNDEFINED: - stream << " VK_IMAGE_LAYOUT_UNDEFINED"; - break; - - case VK_IMAGE_LAYOUT_GENERAL: - stream << " VK_IMAGE_LAYOUT_GENERAL"; - break; - - case VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL: - stream << " VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL"; - break; - - default: - stream << " Unknown!"; - break; - }; - - return stream; -} - -// Considering that VkPipelineStageFlags is a weak typedef of a built-in data -// type, we need to introduce a new type to allow overload resolution distinguish -// between the two. - -struct Stage final { - VkPipelineStageFlags value; -}; - -std::ostream& operator<<( - std::ostream& stream, - const Stage& stage) { - stream << "Stage: "; - - if (0u == stage.value) { - return stream << " 0"; - } - - if (stage.value & VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT) { - stream << " VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT"; - } - - if (stage.value & VK_PIPELINE_STAGE_HOST_BIT) { - stream << " VK_PIPELINE_STAGE_HOST_BIT"; - } - - if (stage.value & VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT) { - stream << " VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT"; - } - - if (stage.value & VK_PIPELINE_STAGE_TRANSFER_BIT) { - stream << " VK_PIPELINE_STAGE_TRANSFER_BIT"; - } - - return stream; -} - -} // namespace - -std::ostream& operator<<( - std::ostream& stream, - const vTensor::View::State::Bundle& bundle) { - stream << "Staging\n " << - Stage{ - bundle.staging.stage, - } << "\n " << - Access{ - bundle.staging.access, - } << std::endl; - - stream << "Buffer\n " << - Stage{ - bundle.buffer.stage, - } << "\n " << - Access{ - bundle.buffer.access, - } << std::endl; - - stream << "Image\n " << - Stage{ - bundle.image.stage, - } << "\n " << - Access{ - bundle.image.access, - } << "\n " << - Image::Layout{ - bundle.image.layout, - } << std::endl; - - return stream; -} - -#endif /* VULKAN_TENSOR_DEBUG */ - } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h index ad37340e3ebec..ecf99ceb9f375 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.h +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -3,8 +3,8 @@ #ifdef USE_VULKAN_API #include -#include #include +#include #include namespace at { @@ -12,373 +12,93 @@ namespace native { namespace vulkan { namespace ops { -// -// This class represents a Vulkan tensor and provides an abstraction layer -// that allows both the CPU, and the GPU, to view a Vulkan (buffer, image) -// pair as one coherent, synchronized unit of storage on both UMA and discrete -// systems. Expanding on the previous sentence, this class tries to address -// two orthogonal implementation complexities that arise as a result of the -// aforementioned goal of memory coherence: -// -// 1) First, synchronization across processors; CPUs and GPUs are separate -// processors, and even though they share the same address space in a system -// with a unified memory architecture, their address spaces only partially -// overlap on systems with a discrete GPU. Consequently on discrete systems, -// while it is still technically possible to take advantage of this shared -// address space to maintain one single copy of the data, different access -// latencies from CPU and GPU to this shared location usually necessitates -// maintaining two copies each in processor-local memory, otherwise memory -// access latency will hurt from the processor to which this data is not -// close. This shared memory is more often than not located in system memory, -// making for slow GPU read and write access over the PCI-e bus on discrete. -// Maintaining two separate copies on the other hand, requires synchronization -// to guarantee coherence. This is not an issue on UMA and this implementation -// accounts for that optimization. -// -// 2) Second, synchronization across resources (i.e. buffers and images); GPU -// drivers pack images in proprietory formats for better locality of access -// and to enable lossless compression. These conversions are both expensive -// (in general) and manual (in Vulkan.) This requires a second order of -// synchronization to guarantee coherence between the contents of the buffer -// and image otherwise they will go out of sync. -// -// It is extremely important to keep in mind that the functionality this class -// provides is generally expensive. For optimal performance, the user of this -// class should: -// -// 1) Avoid frequent CPU <=> GPU transfers which will be triggered if data is -// write accessed on one processor and read / write accessed on the other. -// -// 2) Avoid frequent buffer <=> image conversions which will be trigerred if -// data is write accessed as a buffer (image) and read accessed as an -// image (buffer). -// -// 3) When and if a synchronization is unavoidable, place as much distance -// between the synchronization is triggered and the data is accessed since -// all synchronizations this class provides are async. -// -// For optimal performance, access the data as images, and keep the data on GPU, -// and above all understand the expensive data flow that this class abstracts -// away. -// -// vTensor tries to address a specific concern and intentionally does not expose -// GPU tensor memory directly. Please keep that behavior intact as the whole -// data model fundamentally depends on limiting what the user can achieve through -// the interface to guarantee performance and coherence. -// -// A vTensor is associated with an api::Context as preparation for multi-GPU -// support. -// +struct LastAccess { + api::PipelineStageFlags stage; + api::MemoryAccessFlags access; -class vTensor final { + LastAccess() + : stage{api::PipelineStage::NO_STAGE}, + access{api::MemoryAccessType::NONE} {} + + LastAccess( + api::PipelineStageFlags stage_flags, + api::MemoryAccessFlags access_flags) + : stage{stage_flags}, access{access_flags} {} +}; + +class vTensorStorage final { public: - vTensor() = default; - vTensor( + // Do not allow empty vTensorStorage construction + vTensorStorage() = default; + + vTensorStorage( api::Context* context, IntArrayRef sizes, const TensorOptions& options); - vTensor( + vTensorStorage( api::Context* context, - api::Resource::Pool* pool, IntArrayRef sizes, - const TensorOptions& options); + const TensorOptions& options, + double q_scale, + int64_t q_zero_point); - /* - Types - */ + vTensorStorage(const vTensorStorage&) = delete; + vTensorStorage& operator=(const vTensorStorage&) = delete; - typedef api::Pipeline::Stage Stage; - typedef api::Resource::Memory::Access Access; - typedef api::Resource::Buffer Buffer; - typedef api::Resource::Fence Fence; - typedef api::Resource::Image Image; - typedef api::Resource::Memory Memory; + vTensorStorage(vTensorStorage&&) = default; + vTensorStorage operator=(vTensorStorage&&) = delete; - /* - Future - */ + ~vTensorStorage(); - template - class Future final { - template - using is_convertible = std::enable_if_t< - std::is_convertible< - Access::Pointer, - Access::Pointer>::value>; - - public: - explicit Future(const vTensor* tensor); - Future(const Future&) = delete; - Future& operator=(const Future&) = delete; - Future(Future&&); - Future& operator=(Future&&) &; - Future& operator=(Future&&) && = delete; - template> - Future(Future&&); - template> - Future& operator=(Future&&) &; - template - Future& operator=(Future&&) && = delete; - ~Future(); - - typedef Memory::Handle< - Access::Pointer< - Type, - kAccess>> Payload; - - // This is a blocking operation as the name suggests. A call to host() will - // trigger an async copy if pending writes are detected. Consequently, for - // optimal performance, put as much time and distance between the place - // where a vTensor::host() call occurs and the location where the returned - // future is explicitly waited on as a result of a call to this function. - - Payload wait() const &; - - private: - template - friend class Future; - - // Intentionally disabed to enforce a usage pattern wherein the Future's - // lifetime exceeds that of the Payload as we use the Future's destructor - // to eagerly (as opposed to lazily and upon first use) upload the - // modifications back onto the GPU in an effort to hide the upload latency. - - Payload wait() const && = delete; - - private: - const vTensor* tensor_; - }; - - /* - Host access - these functions will be expensive if they trigger a GPU -> CPU - sync due to pending writes. A call to host() will trigger an async copy in - such scenarios, which is then explicitly waited on as part of Future::wait(). - Consequently, for optimal performance, put as much time and distance between - the place where this function is called, and the location where the future is - waited on. - */ - - template - Future host(api::Command::Buffer&) const &; - - template - Future host(api::Command::Buffer&) &; - - /* - Device access - these functions will be expensive if they trigger a buffer - <-> image or CPU -> GPU sync due to pending writes. These functions are - non-blocking on the host as the copy operation is carried out by the GPU - asynchronously. Regardless, they result in extra work that could have been - avoided or at least minimized if all data access had occured through one - single processor (GPU in this case) and on one type of resource (image for - best performance.) Consequently, for optimal performance, avoid mixed reads - and writes across processor boundaries, and do your best to minimize layout - transitions as a result of working with images only (as opposed to mixed - buffer - image usage.) - This implementation intentionally restricts user access to the buffer and - image objects only, as opposed to their underlying memory, for the sake of - predictability of usage and efficiency. - */ - - Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const &; - Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) &; - - bool has_image() const; - Image::Object image(api::Command::Buffer&, Stage::Flags) const &; - Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) &; - - /* - Metadata - */ - - const api::utils::uvec3& extents() const; - const TensorOptions& options() const; - IntArrayRef sizes() const; - IntArrayRef strides() const; - size_t nbytes() const; + friend class vTensor; private: - // Some overloads below are intentionally disabled to enforce a usage pattern - // that ensures the Tensor's lifetime exceeds that of the scope in which the - // underlying data is accessed. Allowing deleted overloads below to be - // invoked on a temporary would open the door to the possibility of accessing - // the underlying memory out of the expected scope. - - /* - Host - */ + // Context + api::Context* context_; - const vTensor* host(api::Command::Buffer&) const; - vTensor* host(api::Command::Buffer&, Access::Flags); + // Metadata + api::utils::uvec3 extents_; + TensorOptions options_; + c10::SmallVector sizes_; + c10::SmallVector strides_; + bool is_quantized_{false}; + double q_scale{1.0f}; + int64_t q_zero_point{0u}; - template - Future host(api::Command::Buffer&) const && = delete; + // Image Texture + mutable api::VulkanImage image_; - template - Future host(api::Command::Buffer&) && = delete; + // Last Access - used to insert memory barriers + LastAccess last_access_; - /* - Device - */ + private: + // Memory barrier insertion + void transition( + api::PipelineBarrier&, + const api::PipelineStageFlags, + const api::MemoryAccessFlags); + + // Validation + void verify() const; +}; - Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const && = delete; - Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; +class vTensor final { + public: + // Do not allow empty vTensor construction + vTensor() = default; - Image::Object image(api::Command::Buffer&, Stage::Flags) const && = delete; - Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; + vTensor( + api::Context* context, + IntArrayRef sizes, + const TensorOptions& options); + vTensor( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options, + double q_scale, + int64_t q_zero_point); private: - class View final { - public: - View(); - View( - api::Context* context, - api::Resource::Pool* pool, - IntArrayRef sizes, - const TensorOptions& options); - View(const View&) = delete; - View& operator=(const View&) = delete; - View(View&&) = default; - View operator=(View&&) = delete; - ~View(); - - void release(); - - /* - Buffer - */ - - Buffer& buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) const; - - /* - Image - */ - - bool has_image() const; - Image& image(api::Command::Buffer&, Stage::Flags, Access::Flags) const; - - /* - Host - */ - - Buffer& staging(api::Command::Buffer&, Stage::Flags, Access::Flags) const; - vTensor::Memory& wait() const; - - /* - Metadata - */ - - const api::utils::uvec3& extents() const; - const TensorOptions& options() const; - IntArrayRef sizes() const; - IntArrayRef strides() const; - - private: - class CMD; - - class State final { - public: - State(); - State(const api::Adapter*, IntArrayRef); - - struct Bundle final { - struct Buffer final { - VkPipelineStageFlags stage; - VkAccessFlags access; - - operator bool() const; - } staging, buffer; - - struct Image final { - VkPipelineStageFlags stage; - VkAccessFlags access; - VkImageLayout layout; - - operator bool() const; - } image; - }; - - struct Component final { - typedef uint8_t Flags; - - enum Type : Flags { - Buffer = 1u << 0u, - Image = 1u << 1u, - Staging = 1u << 2u, - All = Buffer | Image | Staging, - }; - }; - - // Availability - bool is_available(Component::Flags) const; - bool is_discrete() const; - bool is_uma() const; - - // Clean / Dirty - bool is_clean(Component::Flags) const; - bool is_dirty(Component::Flags) const; - void set_clean(Component::Flags); - void set_dirty(Component::Flags); - - // Transition - typedef std::pair Transition; - Transition transition(Bundle to); - - private: - Component::Flags available_; - Component::Flags dirty_; - Bundle bundle_; - - private: - #ifdef VULKAN_TENSOR_DEBUG - friend class View; - #endif /* VULKAN_TENSOR_DEBUG */ - }; - - typedef State::Component Component; - - private: - // Accessors / Lazy Allocation - Buffer& buffer() const; - Buffer& buffer(CMD&, Stage::Flags, Access::Flags) const; - Image& image() const; - Image& image(CMD&, Stage::Flags, Access::Flags) const; - Buffer& staging() const; - Buffer& staging(CMD&, Stage::Flags, Access::Flags) const; - Fence& fence(Access::Flags) const; - - // Validation - void verify() const; - - private: - // Resources - mutable Buffer buffer_; - mutable Image image_; - mutable Buffer staging_; - mutable Fence fence_; - - // Context - api::Context* context_; - api::Resource::Pool* pool_; - - // State - mutable State state_; - - // Metadata - api::utils::uvec3 extents_; - TensorOptions options_; - c10::SmallVector sizes_; - c10::SmallVector strides_; - - private: - #ifdef VULKAN_TENSOR_DEBUG - friend class vTensor; - friend std::ostream& operator<<( - std::ostream&, - const View::State::Bundle&); - #endif /* VULKAN_TENSOR_DEBUG */ - }; - // Even at the cost of a heap allocation plus the resulting negative impact // on cache locality due to the subsequent pointer chasing, it is still // critcal to share the view across vTensor implementations to minimize @@ -388,211 +108,86 @@ class vTensor final { // at::TensorImpl::release_resources() to function as expected. Now that this // class is made copyable though, a new door to a whole new class of bugs is // opened, in that there now is a chance of two [shallow] copies, have their - // State objects go out of sync as a result of an operation being performed on - // one shallow copy that is not reflected in the other. Technically, if the - // programmer is very careful, it is possible to avoid this trap and not pay - // the cost of indirection, but the resulting bugs of missing memory barriers - // will be so frustrating to hunt down for those unfamiliar with the internal - // mechanics of this class, that I decided to take the performance pentalty - // of this extra layer of indirection in favor of making this class easier - // to use. - - std::shared_ptr view_; + // StorageState objects go out of sync as a result of an operation being + // performed on one shallow copy that is not reflected in the other. + // Technically, if the programmer is very careful, it is possible to avoid + // this trap and not pay the cost of indirection, but the resulting bugs of + // missing memory barriers will be so frustrating to hunt down for those + // unfamiliar with the internal mechanics of this class, that I decided to + // take the performance pentalty of this extra layer of indirection in favor + // of making this class easier to use. + std::shared_ptr view_; - private: - #ifdef VULKAN_TENSOR_DEBUG - friend std::ostream& operator<<( - std::ostream&, - const View::State::Bundle&); - #endif /* VULKAN_TENSOR_DEBUG */ -}; - -vTensor& convert(const Tensor& tensor); -Tensor convert(const vTensor& tensor); - -using vTensorImpl = VulkanOpaqueTensorImpl; -void verify(const TensorOptions& options); - -// -// Impl -// - -template -inline vTensor::Future::Future( - const vTensor* const tensor) - : tensor_(tensor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - tensor_, - "Invalid Vulkan tensor!"); -} - -template -inline vTensor::Future::Future( - Future&& future) - : tensor_(std::move(future.tensor_)) { - future.tensor_ = nullptr; -} + public: + /* + Texture Access + */ -template -inline vTensor::Future& -vTensor::Future::operator=( - Future&& future) & { - tensor_ = std::move(future.tensor_); - future.tensor_ = nullptr; - return *this; -} + api::VulkanImage& image(api::PipelineBarrier&, const api::PipelineStageFlags) + const&; -template -template -inline vTensor::Future::Future( - Future&& future) - : tensor_(std::move(future.tensor_)) { - future.tensor_ = nullptr; -} + api::VulkanImage& image( + api::PipelineBarrier&, + const api::PipelineStageFlags, + const api::MemoryAccessFlags) &; -template -template -inline vTensor::Future& -vTensor::Future::operator=( - Future&& future) & { - tensor_ = std::move(future.tensor_); - future.tensor_ = nullptr; - return *this; -} + /* + Metadata + */ -template -inline vTensor::Future::~Future() { -#if VULKAN_SYNC_TENSORS_EAGERLY - // Sync eagerly in an effort to hide latency. - // Upside: Kick off the async transfer early on to keep the GPU busy. - // Downside: An extra CPU command submission. - if (tensor_ && (Access::Write & kAccess)) { - if (tensor_->has_image()) { - tensor_->image(); - } - else { - tensor_->buffer(); - } + inline const api::utils::uvec3& extents() const { + return view_->extents_; } -#endif -} - -template -inline typename vTensor::Future::Payload -vTensor::Future::wait() const & { - TORCH_CHECK( - tensor_, - "vTensor::Future is in an invalid state! " - "Potential reason: This future is moved from."); - - return tensor_->view_->wait().template map(); -} - -template -inline vTensor::Future -vTensor::host(api::Command::Buffer& command_buffer) const & { - return Future(host(command_buffer)); -} - -template -inline vTensor::Future -vTensor::host(api::Command::Buffer& command_buffer) & { - return Future(host(command_buffer, kAccess)); -} - -inline bool vTensor::has_image() const { - return view_->has_image(); -} - -inline const api::utils::uvec3& vTensor::extents() const { - return view_->extents(); -} -inline const TensorOptions& vTensor::options() const { - return view_->options(); -} - -inline IntArrayRef vTensor::sizes() const { - return view_->sizes(); -} - -inline size_t vTensor::nbytes() const { - return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) * - c10::multiply_integers(sizes()); -} - -inline IntArrayRef vTensor::strides() const { - return view_->strides(); -} - -inline bool vTensor::View::has_image() const { - return state_.is_available(View::Component::Image); -} - -inline const api::utils::uvec3& vTensor::View::extents() const { - return extents_; -} - -inline const TensorOptions& vTensor::View::options() const { - return options_; -} - -inline IntArrayRef vTensor::View::sizes() const { - return sizes_; -} - -inline IntArrayRef vTensor::View::strides() const { - return strides_; -} + inline const TensorOptions& options() const { + return view_->options_; + } -inline vTensor::View::State::Bundle::Buffer::operator bool() const { - return (0u != stage) && - (0u != access); -} + inline IntArrayRef sizes() const { + return view_->sizes_; + } -inline vTensor::View::State::Bundle::Image::operator bool() const { - return (0u != stage) && - (0u != access) && - (VK_IMAGE_LAYOUT_UNDEFINED != layout); -} + inline IntArrayRef strides() const { + return view_->strides_; + } -inline bool vTensor::View::State::is_available( - const Component::Flags components) const { - return available_ & components; -} + inline bool is_quantized() const { + return view_->is_quantized_; + } -inline bool vTensor::View::State::is_discrete() const { - return is_available(Component::Staging); -} + inline double get_scale() const { + return view_->q_scale; + } -inline bool vTensor::View::State::is_uma() const { - return !is_discrete(); -} + inline int64_t get_zero_point() const { + return view_->q_zero_point; + } -inline bool vTensor::View::State::is_clean( - const Component::Flags components) const { - return !is_dirty(components); -} + inline size_t nbytes() const { + return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) * + c10::multiply_integers(sizes()); + } -inline bool vTensor::View::State::is_dirty( - const Component::Flags components) const { - return dirty_ & components; -} + inline VkDeviceSize buffer_bytes() { + return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) * + view_->extents_.data[0u] * view_->extents_.data[1u] * + (4u * view_->extents_.data[2u]); + } +}; -inline void vTensor::View::State::set_clean( - const Component::Flags components) { - dirty_ &= ~components; -} +void add_buffer_barrier( + api::PipelineBarrier&, + const api::VulkanBuffer&, + const api::PipelineStageFlags, + const api::MemoryAccessFlags, + const api::PipelineStageFlags, + const api::MemoryAccessFlags); -inline void vTensor::View::State::set_dirty( - const Component::Flags components) { - dirty_ |= components; -} +using vTensorImpl = VulkanOpaqueTensorImpl; +void verify(const TensorOptions& options); inline vTensor& convert(const Tensor& tensor) { - TORCH_INTERNAL_ASSERT( - tensor.is_vulkan(), - "Vulkan tensor expected!"); + TORCH_INTERNAL_ASSERT(tensor.is_vulkan(), "Vulkan tensor expected!"); vTensorImpl* const impl = static_cast(tensor.unsafeGetTensorImpl()); @@ -610,6 +205,16 @@ inline Tensor convert(const vTensor& tensor) { tensor.strides()); } +inline Tensor convert_quantized(const vTensor& tensor) { + TORCH_CHECK(tensor.is_quantized(), "Not a Quantized Tensor"); + return at::detail::make_tensor( + DispatchKeySet(DispatchKey::Vulkan), + tensor.options().dtype(), + at::Device(at::kVulkan), + tensor, + tensor.sizes(), + tensor.strides()); +} } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/TransposeConvolution2d.cpp b/aten/src/ATen/native/vulkan/ops/TransposeConvolution2d.cpp index 62a47d3df8429..125efa803f3c1 100644 --- a/aten/src/ATen/native/vulkan/ops/TransposeConvolution2d.cpp +++ b/aten/src/ATen/native/vulkan/ops/TransposeConvolution2d.cpp @@ -1,9 +1,10 @@ #include #include -#include + #include #include #include +#include #include #include @@ -14,10 +15,10 @@ namespace ops { namespace { using namespace api::utils; +using namespace at::native::vulkan::ops; vTensor pack_weights_2d_reverse( api::Context* const context, - api::Command::Buffer& command_buffer, const Tensor& weight, bool reversed) { /* Source */ @@ -27,10 +28,13 @@ vTensor pack_weights_2d_reverse( const int64_t src_kw_sz = src_filter[Layout::Filter::width]; const int64_t src_kh_sz = src_filter[Layout::Filter::height]; const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; - const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t src_block_sz = + src_kernel_sz * src_filter[Layout::Filter::input]; - const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); - const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); + const int64_t num_stacks = + div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = + api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); /* Destination */ const int64_t dst_kw_sz = src_kw_sz * stack_depth; @@ -47,37 +51,43 @@ vTensor pack_weights_2d_reverse( weight.options(), }; - using Future = vTensor::Future; - Future v_weight_future = v_weight.host(command_buffer); - Future::Payload v_weight_payload = v_weight_future.wait(); - - float* const dst_weight_ptr = v_weight_payload.get(); - memset(dst_weight_ptr, 0, v_weight.nbytes()); - - for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { - /* Source */ - const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; - - /* Destination */ - const int64_t dst_oh = src_oc / 4; - const int64_t dst_c = src_oc % 4; - - float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; - - for (const auto src_ic : c10::irange(src_filter[Layout::Filter::input])) { - for (const auto src_ih : c10::irange(src_kh_sz)) { - const int64_t dst_h = reversed ? (src_kh_sz - 1 - src_ih) : src_ih; - for (const auto src_iw : c10::irange(src_kw_sz)) { - const int64_t dst_w = reversed ? (src_kw_sz - 1 - src_iw) : src_iw; - const int64_t dst_w_offset = dst_w * stack_depth; - memcpy( - dst_weight_c_ptr + (dst_oh * src_kh_sz + dst_h) * dst_kw_sz + src_ic + dst_w_offset, - src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, - sizeof(float)); + api::StagingBuffer staging(context, v_weight.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_weight_ptr = mapping.template data(); + + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (const auto src_oc : c10::irange(src_filter[Layout::Filter::output])) { + /* Source */ + const float* const src_weight_oc_ptr = + src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; + + for (const auto src_ic : c10::irange(src_filter[Layout::Filter::input])) { + for (const auto src_ih : c10::irange(src_kh_sz)) { + const int64_t dst_h = reversed ? (src_kh_sz - 1 - src_ih) : src_ih; + for (const auto src_iw : c10::irange(src_kw_sz)) { + const int64_t dst_w = reversed ? (src_kw_sz - 1 - src_iw) : src_iw; + const int64_t dst_w_offset = dst_w * stack_depth; + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + dst_h) * dst_kw_sz + + src_ic + dst_w_offset, + src_weight_oc_ptr + src_ic * src_kernel_sz + + src_ih * src_kw_sz + src_iw, + sizeof(float)); + } } } } } + utils::pack_staging_to_vtensor(staging.buffer(), v_weight); return v_weight; } @@ -88,63 +98,57 @@ vTensor pack_weights(const Tensor& weight_arg) { } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything const Tensor weight = at::permute(weight_arg, {1, 0, 2, 3}).contiguous(); - return pack_weights_2d_reverse( - context, - command_buffer, - weight, - true); + return pack_weights_2d_reverse(context, weight, true); } -vTensor pack_biases( - const c10::optional& bias, - const Tensor& weight) { +vTensor pack_biases(const c10::optional& bias, const Tensor& weight) { if (bias && bias->is_vulkan()) { return convert(*bias); } api::Context* const context = api::context(); - api::Command::Buffer& command_buffer = context->command().pool.stream(); // Don't collect the timestamp since the command buffer doesn't record anything const int64_t src_w = weight.size(Layout::TransposedFilter::output); const int64_t packed_w = div_up(src_w, INT64_C(4)); vTensor v_bias{ - context, - { - 4, - 1, - packed_w, - }, - weight.options(), + context, + { + 4, + 1, + packed_w, + }, + weight.options(), }; - using Future = vTensor::Future; - Future v_bias_future = v_bias.host(command_buffer); - Future::Payload v_bias_payload = v_bias_future.wait(); + api::StagingBuffer staging(context, v_bias.buffer_bytes()); + { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + + float* dst_bias_ptr = mapping.template data(); - if (bias) { - const float* const src_bias_ptr = bias->contiguous().data_ptr(); - float* const dst_bias_ptr = v_bias_payload.get(); + if (bias) { + const float* const src_bias_ptr = bias->contiguous().data_ptr(); - memset(dst_bias_ptr, 0, v_bias.nbytes()); - for (const auto i : c10::irange(src_w)) { - const int64_t c = i % 4; - const int64_t x = i / 4; - dst_bias_ptr[c * packed_w + x] = src_bias_ptr[i]; + memset(dst_bias_ptr, 0, v_bias.nbytes()); + for (const auto i : c10::irange(src_w)) { + const int64_t c = i % 4; + const int64_t x = i / 4; + dst_bias_ptr[c * packed_w + x] = src_bias_ptr[i]; + } + } else { + memset( + dst_bias_ptr, + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); } } - else { - memset( - v_bias_payload.get(), - // 2's complement integers and IEEE-754 floating point numbers both - // have identical bit representations for 0, so can use memset which - // only accepts uint8_t parameter. - 0, - v_bias.nbytes()); - } + utils::pack_staging_to_vtensor(staging.buffer(), v_bias); return v_bias; } @@ -159,14 +163,12 @@ std::array pack_filter( }; return { - align_up(filter[Layout::TransposedFilter::output], INT64_C(4)), - align_up(filter[Layout::TransposedFilter::input], INT64_C(4)), - effective( - filter[Layout::Filter::height], - dilation[Layout::Parameter::height]), - effective( - filter[Layout::Filter::width], - dilation[Layout::Parameter::width]), + align_up(filter[Layout::TransposedFilter::output], INT64_C(4)), + align_up(filter[Layout::TransposedFilter::input], INT64_C(4)), + effective( + filter[Layout::Filter::height], dilation[Layout::Parameter::height]), + effective( + filter[Layout::Filter::width], dilation[Layout::Parameter::width]), }; } @@ -174,8 +176,8 @@ std::array pack_params(const std::vector& vector) { TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!"); return { - vector[0], - vector[1], + vector[0], + vector[1], }; } @@ -191,71 +193,73 @@ bool available( const c10::optional& output_min, const c10::optional& output_max) { return api::available() && - // Weight - (4 == weight.ndimension()) && - (weight.size(Layout::Filter::height) > 0) && - (weight.size(Layout::Filter::width) > 0) && - ((weight.device().is_cpu()) || - (c10::DeviceType::Vulkan == weight.device().type())) && - (kFloat == weight.scalar_type()) && - // Bias - ((bias && bias->defined()) ? ((1 == bias->ndimension()) && - ((bias->device().is_cpu()) || - (c10::DeviceType::Vulkan == bias->device().type())) && - (kFloat == bias->scalar_type()) && - (transposed ? (weight.size(Layout::TransposedFilter::output) == - bias->size(Layout::Filter::output)) - : (weight.size(Layout::Filter::output) == - bias->size(Layout::Filter::output)))) - : true) && - // Stride - (stride[Layout::Parameter::height] > 0) && - (stride[Layout::Parameter::width] > 0) && - // Padding - (padding[Layout::Parameter::height] >= 0) && - (padding[Layout::Parameter::width] >= 0) && - // Dilation - (transposed ? (dilation[Layout::Parameter::height] == 1) && - (dilation[Layout::Parameter::width] == 1) - : (dilation[Layout::Parameter::height] > 0) && - (dilation[Layout::Parameter::width] > 0)) && - // Groups - (groups > 0) && - // Input - (weight.size(Layout::Filter::input) > 0) && - // Output - (weight.size(Layout::Filter::output) > 0) && - // Output - Groups - ((weight.size(Layout::Filter::output) % groups) == 0) && - // Output Min / Max - (!output_min || output_min->isFloatingPoint()) && - (!output_max || output_max->isFloatingPoint()) && - true; + // Weight + (4 == weight.ndimension()) && (weight.size(Layout::Filter::height) > 0) && + (weight.size(Layout::Filter::width) > 0) && + ((weight.device().is_cpu()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type()) && + // Bias + ((bias && bias->defined()) + ? ((1 == bias->ndimension()) && + ((bias->device().is_cpu()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type()) && + (transposed ? (weight.size(Layout::TransposedFilter::output) == + bias->size(Layout::Filter::output)) + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) + : true) && + // Stride + (stride[Layout::Parameter::height] > 0) && + (stride[Layout::Parameter::width] > 0) && + // Padding + (padding[Layout::Parameter::height] >= 0) && + (padding[Layout::Parameter::width] >= 0) && + // Dilation + (transposed ? (dilation[Layout::Parameter::height] == 1) && + (dilation[Layout::Parameter::width] == 1) + : (dilation[Layout::Parameter::height] > 0) && + (dilation[Layout::Parameter::width] > 0)) && + // Groups + (groups > 0) && + // Input + (weight.size(Layout::Filter::input) > 0) && + // Output + (weight.size(Layout::Filter::output) > 0) && + // Output - Groups + ((weight.size(Layout::Filter::output) % groups) == 0) && + // Output Min / Max + (!output_min || output_min->isFloatingPoint()) && + (!output_max || output_max->isFloatingPoint()) && true; } bool usable(const Tensor& input) { - // Input + // Input return (4 == input.ndimension()) && - (c10::DeviceType::Vulkan == input.device().type()) && - (kFloat == input.scalar_type()) && - (input.size(Layout::Activation4D::batch) >= 0) && - (input.size(Layout::Activation4D::channels) > 0) && - (input.size(Layout::Activation4D::height) > 0) && - (input.size(Layout::Activation4D::width) > 0) && - !input.requires_grad() && - true; + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Activation4D::batch) >= 0) && + (input.size(Layout::Activation4D::channels) > 0) && + (input.size(Layout::Activation4D::height) > 0) && + (input.size(Layout::Activation4D::width) > 0) && !input.requires_grad() && + true; } static inline std::vector get_conv_transpose_output_size( - IntArrayRef input_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef output_padding, - IntArrayRef stride, IntArrayRef dilation = IntArrayRef()) { + IntArrayRef input_size, + IntArrayRef weight_size, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation = IntArrayRef()) { auto dim = input_size.size(); std::vector output_size(dim); output_size[0] = input_size[input_batch_size_dim]; output_size[1] = weight_size[weight_input_channels_dim]; for (const auto d : c10::irange(2, dim)) { - output_size[d] = stride[d - 2] * (input_size[d] - 1) + weight_size[d] - 2 * padding[d - 2] + output_padding[d - 2]; + output_size[d] = stride[d - 2] * (input_size[d] - 1) + weight_size[d] - + 2 * padding[d - 2] + output_padding[d - 2]; } return output_size; } @@ -275,7 +279,8 @@ VulkanOpContext conv2d_transpose_context_create( const auto stride = expand_param_if_needed(stride_arg, "stride", 2); const auto padding = expand_param_if_needed(padding_arg, "padding", 2); const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2); - const auto output_padding = expand_param_if_needed(output_padding_arg, "output_padding", 2); + const auto output_padding = + expand_param_if_needed(output_padding_arg, "output_padding", 2); TORCH_CHECK( available( @@ -304,8 +309,12 @@ VulkanOpContext conv2d_transpose_context_create( packed_context.emplace_back(pack_params(output_padding)); packed_context.emplace_back(pack_params(dilation)); packed_context.emplace_back(safe_downcast(groups)); - packed_context.emplace_back(output_min ? output_min->template to() : -std::numeric_limits::infinity()); - packed_context.emplace_back(output_max ? output_max->template to() : +std::numeric_limits::infinity()); + packed_context.emplace_back( + output_min ? output_min->template to() + : -std::numeric_limits::infinity()); + packed_context.emplace_back( + output_max ? output_max->template to() + : +std::numeric_limits::infinity()); c10::impl::GenericList unpacked_context{c10::AnyType::get()}; unpacked_context.reserve(10); @@ -324,7 +333,7 @@ VulkanOpContext conv2d_transpose_context_create( } void conv2d_transpose_sliding_window( - const api::Shader::Descriptor& shader, + const api::ShaderSource& shader, vTensor& v_output, const vTensor& v_input, const vTensor& packed_v_weight, @@ -336,96 +345,76 @@ void conv2d_transpose_sliding_window( const float packed_output_min, const float packed_output_max, const IntArrayRef unpacked_filter) { - bool valid = C10_LIKELY(v_output.has_image() && v_input.has_image() && packed_v_weight.has_image()); - TORCH_CHECK(valid, "Not Implemented!") - api::Context* const context = api::context(); - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "prepacked::conv2d_transpose_clamp_run (conv2d_transpose_sliding_window)"); - - const struct Block final { - uvec3 extents; - int32_t ic4; - ivec4 kernel; - ivec2 ikernel; - ivec2 stride; - ivec2 padding; - ivec2 dilate; - vec2 clamp; - ivec4 src_filter; - } block { + + const struct Block final { + uvec3 extents; + int32_t ic4; + ivec4 kernel; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec4 src_filter; + } block{ v_output.extents(), - safe_downcast(packed_filter[Layout::Filter::input]), /* this is aligned up */ + safe_downcast( + packed_filter[Layout::Filter::input]), /* this is aligned up */ { - safe_downcast(packed_filter[Layout::Filter::width]), - safe_downcast(packed_filter[Layout::Filter::height]), - safe_downcast(v_input.sizes()[Layout::Activation4D::width]), - safe_downcast(v_input.sizes()[Layout::Activation4D::height]), + safe_downcast(packed_filter[Layout::Filter::width]), + safe_downcast(packed_filter[Layout::Filter::height]), + safe_downcast(v_input.sizes()[Layout::Activation4D::width]), + safe_downcast(v_input.sizes()[Layout::Activation4D::height]), }, { - safe_downcast(unpacked_filter[Layout::Filter::width]), - safe_downcast(unpacked_filter[Layout::Filter::height]), + safe_downcast(unpacked_filter[Layout::Filter::width]), + safe_downcast(unpacked_filter[Layout::Filter::height]), }, { - safe_downcast(packed_stride[Layout::Parameter::width]), - safe_downcast(packed_stride[Layout::Parameter::height]), + safe_downcast(packed_stride[Layout::Parameter::width]), + safe_downcast(packed_stride[Layout::Parameter::height]), }, { - safe_downcast(packed_padding[Layout::Parameter::width]), - safe_downcast(packed_padding[Layout::Parameter::height]), + safe_downcast(packed_padding[Layout::Parameter::width]), + safe_downcast(packed_padding[Layout::Parameter::height]), }, { - safe_downcast(packed_dilation[Layout::Parameter::width]), - safe_downcast(packed_dilation[Layout::Parameter::height]), + safe_downcast(packed_dilation[Layout::Parameter::width]), + safe_downcast(packed_dilation[Layout::Parameter::height]), }, { - packed_output_min, - packed_output_max, + packed_output_min, + packed_output_max, }, - }; - - uvec3 global_size = v_output.extents(); - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - shader, - global_size, - adaptive_work_group_size(global_size), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_weight.image( - command_buffer, - vTensor::Stage::Compute), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - packed_v_bias.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - command_pool.submit(context->gpu().queue, command_buffer); + }; + + uvec3 global_size = v_output.extents(); + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + shader, + // pipeline barrier + pipeline_barrier, + // global work group size + global_size, + // local work group size + adaptive_work_group_size(global_size), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE), + packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); } Tensor conv2d_transpose_context_run( @@ -455,30 +444,30 @@ Tensor conv2d_transpose_context_run( "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl."); vTensor v_output{ - context, - get_conv_transpose_output_size( - v_input.sizes(), - unpacked_filter, - packed_padding, - packed_output_padding, - packed_stride, - packed_dilation), - input.options(), + context, + get_conv_transpose_output_size( + v_input.sizes(), + unpacked_filter, + packed_padding, + packed_output_padding, + packed_stride, + packed_dilation), + input.options(), }; conv2d_transpose_sliding_window( - VK_KERNEL(conv_transpose2d), - v_output, - v_input, - packed_v_weight, - packed_v_bias, - packed_filter, - packed_stride, - packed_padding, - packed_dilation, - packed_output_min, - packed_output_max, - unpacked_filter); + VK_KERNEL(conv_transpose2d), + v_output, + v_input, + packed_v_weight, + packed_v_bias, + packed_filter, + packed_stride, + packed_padding, + packed_dilation, + packed_output_min, + packed_output_max, + unpacked_filter); return convert(v_output); } @@ -493,32 +482,29 @@ c10::intrusive_ptr create_conv2d_transpose_clamp_context( const int64_t groups, const c10::optional& output_min, const c10::optional& output_max) { - return c10::make_intrusive( - conv2d_transpose_context_create( - weight, - bias, - stride, - padding, - output_padding, - dilation, - groups, - output_min, - output_max)); + return c10::make_intrusive(conv2d_transpose_context_create( + weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + output_min, + output_max)); } Tensor run_conv2d_transpose_clamp_context( const Tensor& input, const c10::intrusive_ptr& vulkan_context) { return conv2d_transpose_context_run( - input, - vulkan_context->get_packed(), - vulkan_context->get_unpacked()); + input, vulkan_context->get_packed(), vulkan_context->get_unpacked()); } /* Backwards compatibility */ -TransposeConv2dOpContext::TransposeConv2dOpContext(VulkanOpContext vulkan_context) - : vulkan_context_{std::move(vulkan_context)} { -} +TransposeConv2dOpContext::TransposeConv2dOpContext( + VulkanOpContext vulkan_context) + : vulkan_context_{std::move(vulkan_context)} {} TransposeConv2dOpContext TransposeConv2dOpContext::create( const Tensor& weight, @@ -530,51 +516,52 @@ TransposeConv2dOpContext TransposeConv2dOpContext::create( const int64_t groups, const c10::optional& output_min, const c10::optional& output_max) { - return TransposeConv2dOpContext { - conv2d_transpose_context_create( - weight, - bias, - stride_arg, - padding_arg, - output_padding_arg, - dilation_arg, - groups, - output_min, - output_max) - }; + return TransposeConv2dOpContext{conv2d_transpose_context_create( + weight, + bias, + stride_arg, + padding_arg, + output_padding_arg, + dilation_arg, + groups, + output_min, + output_max)}; } Tensor TransposeConv2dOpContext::run(const Tensor& input_arg) const { return conv2d_transpose_context_run( - input_arg, - vulkan_context_.get_packed(), - vulkan_context_.get_unpacked()); + input_arg, vulkan_context_.get_packed(), vulkan_context_.get_unpacked()); } TransposeConv2dOpContext::State TransposeConv2dOpContext::unpack() const { - const c10::impl::GenericList unpacked_ = std::get<1>(vulkan_context_.get_state()); + const c10::impl::GenericList unpacked_ = + std::get<1>(vulkan_context_.get_state()); const Tensor unpacked_weight = unpacked_.get(0).toTensor(); - const c10::optional unpacked_bias - = unpacked_.get(1).isTensor() ? unpacked_.get(1).toTensor() : (c10::optional&) c10::nullopt; + const c10::optional unpacked_bias = unpacked_.get(1).isTensor() + ? unpacked_.get(1).toTensor() + : (c10::optional&)c10::nullopt; const std::vector unpacked_stride = unpacked_.get(3).toIntVector(); const std::vector unpacked_padding = unpacked_.get(4).toIntVector(); - const std::vector unpacked_output_padding = unpacked_.get(5).toIntVector(); + const std::vector unpacked_output_padding = + unpacked_.get(5).toIntVector(); const std::vector unpacked_dilation = unpacked_.get(6).toIntVector(); const int64_t unpacked_groups = unpacked_.get(7).toInt(); - const c10::optional unpacked_output_min - = unpacked_.get(6).isScalar() ? unpacked_.get(8).toScalar() : (c10::optional) c10::nullopt; - const c10::optional unpacked_output_max - = unpacked_.get(6).isScalar() ? unpacked_.get(9).toScalar() : (c10::optional) c10::nullopt; + const c10::optional unpacked_output_min = unpacked_.get(6).isScalar() + ? unpacked_.get(8).toScalar() + : (c10::optional)c10::nullopt; + const c10::optional unpacked_output_max = unpacked_.get(6).isScalar() + ? unpacked_.get(9).toScalar() + : (c10::optional)c10::nullopt; return TransposeConv2dOpContext::State{ - unpacked_weight, - unpacked_bias, - unpacked_stride, - unpacked_padding, - unpacked_output_padding, - unpacked_dilation, - unpacked_groups, - unpacked_output_min, - unpacked_output_max, + unpacked_weight, + unpacked_bias, + unpacked_stride, + unpacked_padding, + unpacked_output_padding, + unpacked_dilation, + unpacked_groups, + unpacked_output_min, + unpacked_output_max, }; } diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp index 20516bb387a0e..4ad959dca6ba3 100644 --- a/aten/src/ATen/native/vulkan/ops/Upsample.cpp +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -1,14 +1,12 @@ #include -#include #include +#include #include namespace at { namespace native { namespace vulkan { namespace ops { -namespace { - using namespace api::utils; Tensor upsample_nearest2d( @@ -18,85 +16,153 @@ Tensor upsample_nearest2d( const c10::optional scales_w) { api::Context* const context = api::context(); + TORCH_CHECK( + (4 == input_arg.sizes().size()) && (2 == output_sizes.size()), + "Invalid input!"); + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); const vTensor& v_input = convert(input); const auto v_input_sizes = v_input.sizes(); + vTensor v_output{ + context, + { + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], + output_sizes[Layout::Parameter::height], + output_sizes[Layout::Parameter::width], + }, + input_arg.options(), + }; + + const struct Block final { + uvec3 extents; + uint32_t _; + ivec2 iextents; + vec2 scale; + } block{ + v_output.extents(), + 0u, + { + safe_downcast( + input_arg.size(Layout::Activation4D::width) - 1), + safe_downcast( + input_arg.size(Layout::Activation4D::height) - 1), + }, + { + compute_scales_value( + scales_w, + v_input_sizes[Layout::Activation4D::width], + output_sizes[Layout::Parameter::width]), + compute_scales_value( + scales_h, + v_input_sizes[Layout::Activation4D::height], + output_sizes[Layout::Parameter::height]), + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(upsample_nearest2d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); + + return convert(v_output); +} + +Tensor quantized_upsample_nearest2d( + const Tensor& input_arg, + const IntArrayRef output_sizes, + const c10::optional scales_h, + const c10::optional scales_w) { + api::Context* const context = api::context(); + TORCH_CHECK( - (4 == v_input_sizes.size()) && (2 == output_sizes.size()), + (4 == input_arg.sizes().size()) && (2 == output_sizes.size()), "Invalid input!"); + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const auto v_input_sizes = v_input.sizes(); + vTensor v_output{ - context, - { - v_input_sizes[Layout::Activation4D::batch], - v_input_sizes[Layout::Activation4D::channels], - output_sizes[Layout::Parameter::height], - output_sizes[Layout::Parameter::width], - }, - input.options(), + context, + { + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], + output_sizes[Layout::Parameter::height], + output_sizes[Layout::Parameter::width], + }, + input_arg.options(), + v_input.get_scale(), + v_input.get_zero_point(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::upsample_nearest2d"); - - if C10_LIKELY(v_input.has_image()) { - const struct Block final { - uvec3 extents; - uint32_t _; - ivec2 iextents; - vec2 scale; - } block { - v_output.extents(), - 0u, - { - safe_downcast(input.size(Layout::Activation4D::width) - 1), - safe_downcast(input.size(Layout::Activation4D::height) - 1), - }, - { - compute_scales_value( - scales_w, - v_input_sizes[Layout::Activation4D::width], - output_sizes[Layout::Parameter::width]), - compute_scales_value( - scales_h, - v_input_sizes[Layout::Activation4D::height], - output_sizes[Layout::Parameter::height]), - }, - }; - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(upsample_nearest2d), - v_output.extents(), - adaptive_work_group_size(v_output.extents()), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } - else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + uvec3 extents; + uint32_t _; + ivec2 iextents; + vec2 scale; + } block{ + v_output.extents(), + 0u, + { + safe_downcast( + input_arg.size(Layout::Activation4D::width) - 1), + safe_downcast( + input_arg.size(Layout::Activation4D::height) - 1), + }, + { + compute_scales_value( + scales_w, + v_input_sizes[Layout::Activation4D::width], + output_sizes[Layout::Parameter::width]), + compute_scales_value( + scales_h, + v_input_sizes[Layout::Activation4D::height], + output_sizes[Layout::Parameter::height]), + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(quantized_upsample_nearest2d), + // pipeline barrier + pipeline_barrier, + // global work group size + v_output.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } @@ -104,12 +170,13 @@ Tensor upsample_nearest2d( #ifdef USE_VULKAN_API TORCH_LIBRARY_IMPL(aten, Vulkan, m) { - m.impl(TORCH_SELECTIVE_NAME("aten::upsample_nearest2d"), TORCH_FN(upsample_nearest2d)); + m.impl( + TORCH_SELECTIVE_NAME("aten::upsample_nearest2d"), + TORCH_FN(upsample_nearest2d)); } #endif /* USE_VULKAN_API */ -} // namespace } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Utils.cpp b/aten/src/ATen/native/vulkan/ops/Utils.cpp new file mode 100644 index 0000000000000..0a255f9915bd3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Utils.cpp @@ -0,0 +1,122 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace utils { + +void pack_buffer_to_vtensor( + api::VulkanBuffer& buffer, + vTensor& v_self, + api::PipelineBarrier& pipeline_barrier) { + api::Context* const context = api::context(); + + const api::utils::uvec3 extents = v_self.extents(); + const uint32_t plane = extents.data[0u] * extents.data[1u]; + + const struct Block final { + api::utils::uvec3 extents; + uint32_t block; + api::utils::uvec4 offset; + } block{ + extents, + 4u * plane, + { + 0u * plane, + 1u * plane, + 2u * plane, + 3u * plane, + }, + }; + + api::UniformParamsBuffer params(context, block); + bool is_quantized = v_self.is_quantized(); + api::ShaderSource kernel = is_quantized ? VK_KERNEL(nchw_to_image_quantized) + : VK_KERNEL(nchw_to_image); + + context->submit_compute_job( + // shader descriptor + kernel, + // pipeline barrier + pipeline_barrier, + // global work group size + extents, + // local work group size + adaptive_work_group_size(extents), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_self.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + buffer, + // params buffer + params.buffer()); +} + +void pack_staging_to_vtensor(api::VulkanBuffer& staging, vTensor& v_self) { + api::PipelineBarrier pipeline_barrier{}; + pack_buffer_to_vtensor(staging, v_self, pipeline_barrier); +} + +void pack_vtensor_to_staging( + vTensor& v_self, + api::VulkanBuffer& staging, + const VkFence fence_handle) { + api::Context* const context = api::context(); + + const api::utils::uvec3 extents = v_self.extents(); + const uint32_t plane = extents.data[0u] * extents.data[1u]; + + const struct Block final { + api::utils::uvec3 extents; + uint32_t block; + api::utils::uvec4 offset; + } block{ + extents, + 4u * plane, + { + 0u * plane, + 1u * plane, + 2u * plane, + 3u * plane, + }, + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + bool is_quantized = v_self.is_quantized(); + api::utils::uvec3 copy_extents; + copy_extents.data[0u] = 1; + copy_extents.data[1u] = 1; + copy_extents.data[2u] = + ((v_self.sizes()[1] * v_self.sizes()[2] * v_self.sizes()[3]) / 4); + api::ShaderSource kernel = is_quantized ? VK_KERNEL(image_to_nchw_quantized) + : VK_KERNEL(image_to_nchw); + api::utils::uvec3 extents_to_use = is_quantized ? copy_extents : extents; + + context->submit_compute_job( + // shader descriptor + kernel, + // pipeline barrier + pipeline_barrier, + // global work group size + extents_to_use, + // local work group size + adaptive_work_group_size(extents_to_use), + // fence handle + fence_handle, + // shader arguments + v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE), + staging, + // params buffer + params.buffer()); +} + +} // namespace utils +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Utils.h b/aten/src/ATen/native/vulkan/ops/Utils.h index de218cfc472ab..59358ee173eb0 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.h +++ b/aten/src/ATen/native/vulkan/ops/Utils.h @@ -10,12 +10,22 @@ namespace vulkan { namespace ops { namespace utils { -inline int64_t normalize( - const int64_t dimension, - const int64_t n) { +inline int64_t normalize(const int64_t dimension, const int64_t n) { return (dimension % n + n) % n; } +void pack_buffer_to_vtensor( + api::VulkanBuffer&, + vTensor&, + api::PipelineBarrier&); + +void pack_staging_to_vtensor(api::VulkanBuffer&, vTensor&); + +void pack_vtensor_to_staging( + vTensor&, + api::VulkanBuffer&, + const VkFence fence_handle = VK_NULL_HANDLE); + } // namespace utils } // namespace ops } // namespace vulkan diff --git a/aten/src/ATen/native/vulkan/ops/VulkanOpContext.cpp b/aten/src/ATen/native/vulkan/ops/VulkanOpContext.cpp index 0f400128c3c25..58f07b0d43c4f 100644 --- a/aten/src/ATen/native/vulkan/ops/VulkanOpContext.cpp +++ b/aten/src/ATen/native/vulkan/ops/VulkanOpContext.cpp @@ -8,9 +8,7 @@ namespace ops { VulkanOpContext::VulkanOpContext( c10::impl::GenericList packed_context, c10::impl::GenericList unpacked_context) - : packed_{ packed_context }, - unpacked_{ unpacked_context } { -} + : packed_(packed_context), unpacked_(unpacked_context) {} VulkanOpContext VulkanOpContext::create( c10::impl::GenericList packed_context, diff --git a/aten/src/ATen/native/vulkan/ops/cumsum.cpp b/aten/src/ATen/native/vulkan/ops/cumsum.cpp index e434949964fc5..fd84d3304f396 100644 --- a/aten/src/ATen/native/vulkan/ops/cumsum.cpp +++ b/aten/src/ATen/native/vulkan/ops/cumsum.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -16,73 +15,57 @@ Tensor cumsum( const int64_t dim, const c10::optional dtype) { TORCH_CHECK( - input_arg.dim() <= 4, - "Vulkan cumsum expects input dimension <= 4!"); + input_arg.dim() <= 4, "Vulkan cumsum expects input dimension <= 4!"); TORCH_CHECK( - batch_size(input_arg) == 1, - "Vulkan cumsum expects batch size <= 1!"); + batch_size(input_arg) == 1, "Vulkan cumsum expects batch size <= 1!"); - TORCH_CHECK( - dim < 4, - "Vulkan cumsum expects dim < 4!"); + TORCH_CHECK(dim < 4, "Vulkan cumsum expects dim < 4!"); + + if (dim <= 1) { + // TODO: dim<0, dim=0, dim=1(z axis) + TORCH_CHECK(false, "Not implemented!"); + } api::Context* const context = api::context(); const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); const vTensor& v_input = convert(input); + vTensor v_output{ - context, - input.sizes(), - input.options(), + context, + input_arg.sizes(), + input_arg.options(), }; - api::Command::Pool& command_pool = context->command().pool; - api::Command::Buffer& command_buffer = command_pool.stream(); - { - api::OpProfiler profiler(command_buffer, context->querypool(), "aten::cumsum"); - - if C10_LIKELY(v_input.has_image()) { - const struct Block final { - int32_t axis; - } block { - (3-safe_downcast(dim)), - }; - - if(dim<=1) { - // TODO: dim<0, dim=0, dim=1(z axis) - TORCH_CHECK(false, "Not implemented!"); - } - - context->dispatch( - command_buffer, - { - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - }, - VK_KERNEL(cumsum), - v_input.extents(), - context->gpu().adapter->local_work_group_size(), - // Write-only access bypasses synchronization but inserts appropriate - // barriers if necessary. - v_output.image( - command_buffer, - vTensor::Stage::Compute, - vTensor::Access::Write), - // Read-only access is implied on const tensors and triggers an async - // synchronization if necessary. - v_input.image( - command_buffer, - vTensor::Stage::Compute), - // Object lifetime is managed by the resource pool. - // It is OK not to keep track of the handle. - context->resource().pool.uniform(block).object); - } else { - TORCH_CHECK(false, "Not implemented!"); - } - } - command_pool.submit(context->gpu().queue, command_buffer); + const struct Block final { + int32_t axis; + } block{ + (3 - safe_downcast(dim)), + }; + + api::UniformParamsBuffer params(context, block); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(cumsum), + // pipeline barrier + pipeline_barrier, + // global work group size + v_input.extents(), + // local work group size + adaptive_work_group_size(v_output.extents()), + // fence handle + VK_NULL_HANDLE, + // shader arguments + v_output.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE), + // params buffer + params.buffer()); return convert(v_output); } diff --git a/aten/src/ATen/nnapi/CMakeLists.txt b/aten/src/ATen/nnapi/CMakeLists.txt index 01324049bde6f..065729e5c887e 100644 --- a/aten/src/ATen/nnapi/CMakeLists.txt +++ b/aten/src/ATen/nnapi/CMakeLists.txt @@ -3,7 +3,7 @@ if(PYTORCH_NNAPI_STANDALONE) cmake_minimum_required(VERSION 3.5 FATAL_ERROR) project(pytorch_nnapi) - set(CMAKE_CXX_STANDARD 14) + set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") find_package(Torch REQUIRED) set(NNAPI_SRCS diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 4a70ddf8c2055..ff1778222a97a 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -111,9 +111,18 @@ inline Tensor new_qtensor( const TensorOptions& options, QuantizerPtr quantizer) { auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - at::Allocator* allocator = options.device().is_cuda() - ? at::detail::getCUDAHooks().getCUDADeviceAllocator() - : at::getCPUAllocator(); + auto device = options.device(); + at::Allocator* allocator = nullptr; + // TODO: why isn't this just using GetAllocator + if (device.is_cuda()) { + allocator = at::detail::getCUDAHooks().getCUDADeviceAllocator(); + } else if (device.is_cpu()) { + allocator = at::getCPUAllocator(); + } else if (device.is_meta()) { + allocator = GetAllocator(kMeta); + } else { + TORCH_INTERNAL_ASSERT(0, "unrecognized device for new_qtensor: ", device); + } #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK) { diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp index 0feaa8996941f..dd4c3270843ff 100644 --- a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp +++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS // ${generated_comment} +#include #include #include @@ -54,6 +55,8 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) { ${CompositeViewCopyKernel_Definitions} +${SymIntViewCopyKernel_Definitions} + ${GeneratedCompositeFunctional_Definitions} ${GeneratedCompositeOut_Definitions} diff --git a/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp b/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp index 1a5b4a452592d..7647f459a744b 100644 --- a/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp +++ b/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp @@ -2,6 +2,10 @@ ${includes} ${native_functions_include} +namespace { +${helper_fns} +} // namespace + ${namespace_prologue} ${native_function_definitions} diff --git a/aten/src/ATen/templates/NativeFunction.h b/aten/src/ATen/templates/NativeFunction.h index 35f2dbf742401..4f70db62a4c64 100644 --- a/aten/src/ATen/templates/NativeFunction.h +++ b/aten/src/ATen/templates/NativeFunction.h @@ -14,10 +14,4 @@ #include ${extra_includes} -namespace at { -namespace native { - ${native_function_declarations} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h index 19877092114e0..d6d7205b5793b 100644 --- a/aten/src/ATen/templates/NativeFunctions.h +++ b/aten/src/ATen/templates/NativeFunctions.h @@ -30,10 +30,4 @@ ${NativeFunctions_includes} -namespace at { -namespace native { - ${NativeFunctions_declarations} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index eac2398f44004..af3ed13de7aef 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -23,6 +23,19 @@ namespace at { namespace functionalization { +// This keyset is used by functionalization when it calls into meta kernels +// to accurately propagate stride metadata. +// Exclude any modes: the purpose of calling into meta kernels is only as an implementation +// detail to perform shape inference, and we don't want any modal keys to run. +// Specifically, we want to prevent functionalization and Python modes from running. +constexpr auto exclude_keys_for_meta_dispatch = + c10::functorch_transforms_ks | + c10::DispatchKeySet({ + c10::DispatchKey::FuncTorchDynamicLayerBackMode, + c10::DispatchKey::FuncTorchDynamicLayerFrontMode, + c10::DispatchKey::Python + }); + inline Tensor to_meta(const Tensor& t) { return at::native::empty_strided_meta(t.sizes(), t.strides(), diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index fa757feda4bac..662712c641f11 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -569,9 +569,9 @@ class TORCH_API Tensor: public TensorBase { //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template - using hook_return_void_t = std::enable_if_t::type>::value, unsigned>; + using hook_return_void_t = std::enable_if_t>::value, unsigned>; template - using hook_return_var_t = std::enable_if_t::type, Tensor>::value, unsigned>; + using hook_return_var_t = std::enable_if_t, Tensor>::value, unsigned>; /// Registers a backward hook. /// diff --git a/aten/src/ATen/test/ExclusivelyOwned_test.cpp b/aten/src/ATen/test/ExclusivelyOwned_test.cpp index 0c508adbca425..5d1dcf7127d7c 100644 --- a/aten/src/ATen/test/ExclusivelyOwned_test.cpp +++ b/aten/src/ATen/test/ExclusivelyOwned_test.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -9,8 +10,6 @@ namespace { -using at::Tensor; - template class ExclusivelyOwnedTest : public ::testing::Test { public: @@ -28,15 +27,20 @@ template T getSampleValue(); template <> -Tensor getSampleValue() { +at::Tensor getSampleValue() { return at::native::zeros({2, 2}).to(at::kCPU); } +template <> +caffe2::Tensor getSampleValue() { + return caffe2::Tensor(getSampleValue()); +} + template void assertIsSampleObject(const T& eo); template <> -void assertIsSampleObject(const Tensor& t) { +void assertIsSampleObject(const at::Tensor& t) { EXPECT_EQ(t.sizes(), (c10::IntArrayRef{2, 2})); EXPECT_EQ(t.strides(), (c10::IntArrayRef{2, 1})); ASSERT_EQ(t.scalar_type(), at::ScalarType::Float); @@ -44,6 +48,11 @@ void assertIsSampleObject(const Tensor& t) { EXPECT_EQ(memcmp(zeros, t.data_ptr(), 4 * sizeof(float)), 0); } +template <> +void assertIsSampleObject(const caffe2::Tensor& t) { + assertIsSampleObject(at::Tensor(t)); +} + template void ExclusivelyOwnedTest::SetUp() { @@ -52,7 +61,8 @@ void ExclusivelyOwnedTest::SetUp() { } using ExclusivelyOwnedTypes = ::testing::Types< - Tensor + at::Tensor, + caffe2::Tensor >; TYPED_TEST_CASE(ExclusivelyOwnedTest, ExclusivelyOwnedTypes); @@ -94,5 +104,5 @@ extern "C" void inspectTensor() { } extern "C" void inspectExclusivelyOwnedTensor() { - c10::ExclusivelyOwned t(getSampleValue()); + c10::ExclusivelyOwned t(getSampleValue()); } diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index d14e7cd36ab99..75cd45d0ee78c 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -5,6 +5,7 @@ #include #include #include +#include // for TH compat test only... struct THFloatTensor; @@ -14,6 +15,8 @@ struct THFloatTensor; // NOLINTNEXTLINE(modernize-deprecated-headers) #include #include +#include +#include #define ASSERT_EQ_RESOLVED(X, Y) \ { \ @@ -471,3 +474,48 @@ TEST(BasicTest, FactoryMethodsTest) { ASSERT_FALSE(tensor0.is_pinned()); } } + +TEST(BasicTest, BasicStdTestCPU) { + c10::once_flag flag1, flag2; + + auto simple_do_once = [&]() + { + c10::call_once(flag1, [](){ std::cout << "Simple example: called once\n"; }); + }; + + auto may_throw_function = [&](bool do_throw) + { + if (do_throw) { + std::cout << "throw: call_once will retry\n"; // this may appear more than once + TORCH_CHECK(false, "throw exception"); + } + std::cout << "Didn't throw, call_once will not attempt again\n"; // guaranteed once + }; + + auto do_once = [&](bool do_throw) + { + try { + c10::call_once(flag2, may_throw_function, do_throw); + } + catch (...) { + } + }; + + std::thread st1(simple_do_once); + std::thread st2(simple_do_once); + std::thread st3(simple_do_once); + std::thread st4(simple_do_once); + st1.join(); + st2.join(); + st3.join(); + st4.join(); + + std::thread t1(do_once, true); + std::thread t2(do_once, true); + std::thread t3(do_once, false); + std::thread t4(do_once, true); + t1.join(); + t2.join(); + t3.join(); + t4.join(); +} diff --git a/aten/src/ATen/test/cuda_atomic_ops_test.cu b/aten/src/ATen/test/cuda_atomic_ops_test.cu index d5d261440064b..badef7bc9b1af 100644 --- a/aten/src/ATen/test/cuda_atomic_ops_test.cu +++ b/aten/src/ATen/test/cuda_atomic_ops_test.cu @@ -214,6 +214,11 @@ TEST(TestAtomicOps, TestAtomicAdd) { TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMul)) { if (!at::cuda::is_available()) return; + test_atomic_mul(); + test_atomic_mul(); + test_atomic_mul(); + test_atomic_mul(); + test_atomic_mul(); test_atomic_mul(); test_atomic_mul(); test_atomic_mul(); @@ -222,6 +227,11 @@ TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMul)) { TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMax)) { if (!at::cuda::is_available()) return; + test_atomic_max(); + test_atomic_max(); + test_atomic_max(); + test_atomic_max(); + test_atomic_max(); test_atomic_max(); test_atomic_max(); test_atomic_max(); @@ -230,6 +240,11 @@ TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMax)) { TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMin)) { if (!at::cuda::is_available()) return; + test_atomic_min(); + test_atomic_min(); + test_atomic_min(); + test_atomic_min(); + test_atomic_min(); test_atomic_min(); test_atomic_min(); test_atomic_min(); diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 04d6ed9f1982b..7276261738593 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -4,13 +4,18 @@ #include #include #include -#include #include // TODO: These functions should move to a common place. namespace { +#ifdef USE_VULKAN_FP16_INFERENCE + constexpr float kTolerance = 1e-2; +#else + constexpr float kTolerance = 1e-5; +#endif + bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { float maxValue = 0.0f; @@ -18,21 +23,87 @@ bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { maxValue = fmax(tensor.abs().max().item(), maxValue); } -#ifdef USE_VULKAN_FP16_INFERENCE - constexpr float tolerance = 1e-2; -#else - constexpr float tolerance = 1e-5; -#endif - - return diff.abs().max().item() <= (tolerance * maxValue); + return diff.abs().max().item() <= (kTolerance * maxValue); } bool almostEqual(const at::Tensor& a, const at::Tensor& b) { return checkRtol(a - b, {a, b}); } -bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { - return (a - b).abs().max().item() == 0.0f; +bool checkHardShrink( + const at::Tensor& ref, const at::Tensor& out, const float clamp_thresh) { + float* ref_ptr = ref.data_ptr(); + float* out_ptr = out.data_ptr(); + float ref_max = ref.abs().max().item(); + float out_max = out.abs().max().item(); + float max_val = std::fmax(ref_max, out_max); + + float abs_clamp_thresh = std::abs(clamp_thresh); + + for (int i = 0; i < ref.numel(); ++i) { + float ref_val = ref_ptr[i]; + float out_val = out_ptr[i]; + + float abs_diff = std::abs(ref_val - out_val); + + // For values near the clamp threshold, results may be ambiguous. + float distance_from_thresh = std::abs(std::abs(ref_val) - abs_clamp_thresh); + if (distance_from_thresh < kTolerance * abs_clamp_thresh) { + if (out_val != 0.0f) { + if (abs_diff >= kTolerance * max_val) { + return false; + } + } + } + else if (std::abs(ref_val) < std::abs(abs_clamp_thresh)) { + if (out_val != 0.0f) { + return false; + } + } + else if (abs_diff >= kTolerance * max_val) { + return false; + } + } + return true; +} + +bool checkThreshold( + const at::Tensor& ref, + const at::Tensor& out, + const float clamp_thresh, + const float value) { + float* ref_ptr = ref.data_ptr(); + float* out_ptr = out.data_ptr(); + float ref_max = ref.abs().max().item(); + float out_max = out.abs().max().item(); + float max_val = std::fmax(ref_max, out_max); + + for (int i = 0; i < ref.numel(); ++i) { + float ref_val = ref_ptr[i]; + float out_val = out_ptr[i]; + + float abs_diff = std::abs(ref_val - out_val); + float val_diff = std::abs(out_val - value); + + // For values near the clamp threshold, results may be ambiguous. + float distance_from_thresh = std::abs(std::abs(ref_val) - clamp_thresh); + if (distance_from_thresh < kTolerance * clamp_thresh) { + if (val_diff >= kTolerance * value) { + if (abs_diff >= kTolerance * max_val) { + return false; + } + } + } + else if (std::abs(ref_val) < std::abs(clamp_thresh)) { + if (val_diff >= kTolerance * value) { + return false; + } + } + else if (abs_diff >= kTolerance * max_val) { + return false; + } + } + return true; } void showRtol(const at::Tensor& a, const at::Tensor& b) { @@ -41,13 +112,7 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) { float maxValue = a.abs().max().item(); maxValue = fmax(b.abs().max().item(), maxValue); -#ifdef USE_VULKAN_FP16_INFERENCE - constexpr float tolerance = 1e-2; -#else - constexpr float tolerance = 1e-5; -#endif - - const float maxDiff = maxValue * tolerance; + const float maxDiff = maxValue * kTolerance; std::cout << "Max Diff allowed: " << maxDiff << std::endl; if (diff.sizes().size() == 2) { for (const auto y : c10::irange(diff.sizes()[0])) { @@ -81,11 +146,6 @@ static void gen_allpermutations(std::vector>& out, std::vec } static void slice_test(const std::vector& size, int64_t dim, c10::optional start, c10::optional end, int64_t step) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand(size, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -118,11 +178,6 @@ static void slice_tests(const std::unordered_map>& } static void clone_test(const std::vector& size, c10::optional optional_memory_format) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand(size, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -170,22 +225,30 @@ inline std::vector callOpByName( namespace { class VulkanAPITest : public ::testing::Test { -public: -#if defined (__ANDROID__) // to avoid `Undefined symbols for architecture arm64` error - static void SetUpTestSuite() { - at::native::vulkan::api::context()->querypool().enable(); + public: + void SetUp() { + if (!at::is_vulkan_available()) { + GTEST_SKIP() << "Vulkan is not available"; } +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + } - static void TearDownTestSuite() { - at::native::vulkan::api::context()->querypool().disable(false); + void TearDown() { +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + try { + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + } catch (const std::exception& e) { + std::cout << "Could not get querypool results!" + << " Reason: " << e.what() << std::endl; } #endif + } }; TEST_F(VulkanAPITest, adaptive_avg_pool2d) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; const auto in_cpu = at::rand({5, 7, 47, 31}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); @@ -201,10 +264,6 @@ TEST_F(VulkanAPITest, adaptive_avg_pool2d) { } TEST_F(VulkanAPITest, add) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -223,10 +282,6 @@ TEST_F(VulkanAPITest, add) { } TEST_F(VulkanAPITest, add_broadcast0) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -245,10 +300,6 @@ TEST_F(VulkanAPITest, add_broadcast0) { } TEST_F(VulkanAPITest, add_broadcast1) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -267,9 +318,6 @@ TEST_F(VulkanAPITest, add_broadcast1) { } TEST_F(VulkanAPITest, add_broadcast2) { - if (!at::is_vulkan_available()) { - return; - } const auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -289,10 +337,6 @@ TEST_F(VulkanAPITest, add_broadcast2) { } TEST_F(VulkanAPITest, add_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -311,10 +355,6 @@ TEST_F(VulkanAPITest, add_) { } TEST_F(VulkanAPITest, add_broadcast0_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({16, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -333,10 +373,6 @@ TEST_F(VulkanAPITest, add_broadcast0_) { } TEST_F(VulkanAPITest, add_broadcast1_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -355,10 +391,6 @@ TEST_F(VulkanAPITest, add_broadcast1_) { } TEST_F(VulkanAPITest, add_scalar) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({13, 23, 59, 73}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -376,10 +408,6 @@ TEST_F(VulkanAPITest, add_scalar) { } TEST_F(VulkanAPITest, add_scalar_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({47, 2, 23, 97}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -397,10 +425,6 @@ TEST_F(VulkanAPITest, add_scalar_) { } TEST_F(VulkanAPITest, addmm) { - if (!at::is_vulkan_available()) { - return; - } - constexpr float alpha = 2.1f; constexpr float beta = 103.24; @@ -421,10 +445,6 @@ TEST_F(VulkanAPITest, addmm) { } TEST_F(VulkanAPITest, addmm_expand) { - if (!at::is_vulkan_available()) { - return; - } - constexpr float alpha = 2.1f; constexpr float beta = 103.24; @@ -445,10 +465,6 @@ TEST_F(VulkanAPITest, addmm_expand) { } TEST_F(VulkanAPITest, avg_pool2d) { - if (!at::is_vulkan_available()) { - return; - } - const auto in_cpu = at::rand({3, 19, 43, 79}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::avg_pool2d(in_cpu, {5, 3}, {1, 2}, {2, 0}, true); const auto out_vulkan = at::avg_pool2d(in_cpu.vulkan(), {5, 3}, {1, 2}, {2, 0}, true); @@ -461,11 +477,211 @@ TEST_F(VulkanAPITest, avg_pool2d) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, clamp) { - if (!at::is_vulkan_available()) { - return; +TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { + c10::InferenceMode mode; + + // Act: Vulkan batchnorm only supports evaluation mode + EXPECT_THROW({ + at::batch_norm( + at::rand({3, 8, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: Vulkan batchnorm expects 4-dim input + EXPECT_THROW({ + at::batch_norm( + at::rand({3, 8, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: Vulkan batchnorm expects 4-dim input + EXPECT_THROW({ + at::batch_norm( + at::rand({2, 8, 3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: Vulkan batchnorm expects channel dim to be multiple of 4 + EXPECT_THROW({ + at::batch_norm( + at::rand({4, 7, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: weight tensor contains incorrect number of elements + EXPECT_THROW({ + at::batch_norm( + at::rand({4, 8, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: bias tensor contains incorrect number of elements + EXPECT_THROW({ + at::batch_norm( + at::rand({4, 8, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: running mean tensor contains incorrect number of elements + EXPECT_THROW({ + at::batch_norm( + at::rand({4, 8, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); + + // Act: running var tensor contains incorrect number of elements + EXPECT_THROW({ + at::batch_norm( + at::rand({4, 8, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + true, + 0.1, + 1e-05, + false); + }, ::c10::Error); +} + +TEST_F(VulkanAPITest, batch_norm_small) { + c10::InferenceMode mode; + + const auto input_cpu = at::rand({1, 4, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); + + const auto weight_cpu = at::rand({4}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); + + const auto bias_cpu = at::rand({4}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); + + const auto running_mean_cpu = at::rand({4}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_mean_vulkan = running_mean_cpu.vulkan(); + + const auto running_var_cpu = at::rand({4}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_var_vulkan = running_var_cpu.vulkan(); + + const auto output_cpu = at::batch_norm(input_cpu, weight_cpu, bias_cpu, running_mean_cpu, running_var_cpu, false, 0.1, 1e-05, false); + const auto output_vulkan = at::batch_norm(input_vulkan, weight_vulkan, bias_vulkan, running_mean_vulkan, running_var_vulkan, false, 0.1, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, batch_norm_medium) { + c10::InferenceMode mode; + + const auto input_cpu = at::rand({3, 8, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); + + const auto weight_cpu = at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); + + const auto bias_cpu = at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); + + const auto running_mean_cpu = at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_mean_vulkan = running_mean_cpu.vulkan(); + + const auto running_var_cpu = at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_var_vulkan = running_var_cpu.vulkan(); + + const auto output_cpu = at::batch_norm(input_cpu, weight_cpu, bias_cpu, running_mean_cpu, running_var_cpu, false, 0.1, 1e-05, false); + const auto output_vulkan = at::batch_norm(input_vulkan, weight_vulkan, bias_vulkan, running_mean_vulkan, running_var_vulkan, false, 0.1, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); } + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, batch_norm_large) { + c10::InferenceMode mode; + + + const auto input_cpu = at::rand({79, 52, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); + + const auto weight_cpu = at::rand({52}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); + + const auto bias_cpu = at::rand({52}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); + + const auto running_mean_cpu = at::rand({52}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_mean_vulkan = running_mean_cpu.vulkan(); + + const auto running_var_cpu = at::rand({52}, at::device(at::kCPU).dtype(at::kFloat)); + const auto running_var_vulkan = running_var_cpu.vulkan(); + + const auto output_cpu = at::batch_norm(input_cpu, weight_cpu, bias_cpu, running_mean_cpu, running_var_cpu, false, 0.1, 1e-05, false); + const auto output_vulkan = at::batch_norm(input_vulkan, weight_vulkan, bias_vulkan, running_mean_vulkan, running_var_vulkan, false, 0.1, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, clamp) { const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -484,10 +700,6 @@ TEST_F(VulkanAPITest, clamp) { } TEST_F(VulkanAPITest, clamp_) { - if (!at::is_vulkan_available()) { - return; - } - const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); const auto vulkan = cpu.vulkan(); @@ -506,10 +718,6 @@ TEST_F(VulkanAPITest, clamp_) { } TEST_F(VulkanAPITest, conv2d) { - if (!at::is_vulkan_available()) { - return; - } - constexpr int64_t groups = 1; constexpr std::array stride{2, 2}; constexpr std::array padding{1, 1}; @@ -579,10 +787,6 @@ TEST_F(VulkanAPITest, conv2d) { } TEST_F(VulkanAPITest, conv2d_dw) { - if (!at::is_vulkan_available()) { - return; - } - constexpr int64_t groups = 7; constexpr std::array stride{2, 3}; constexpr std::array padding{0, 4}; @@ -651,10 +855,6 @@ TEST_F(VulkanAPITest, conv2d_dw) { } TEST_F(VulkanAPITest, conv2d_pw) { - if (!at::is_vulkan_available()) { - return; - } - constexpr int64_t groups = 1; constexpr std::array stride{1, 1}; constexpr std::array padding{0, 0}; @@ -722,87 +922,11 @@ TEST_F(VulkanAPITest, conv2d_pw) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, conv2d_winograd) { - if (!at::is_vulkan_available()) { - return; - } - - constexpr int64_t groups = 1; - constexpr std::array stride{1, 1}; - constexpr std::array padding{2, 2}; - constexpr std::array dilation{1, 1}; - - constexpr struct { - uint32_t batches; - uint32_t channels; - uint32_t width; - uint32_t height; - - std::array size() const { - return { - batches, - channels, - width, - height, - }; - } - } input {1, 10, 177, 232}; - - constexpr struct { - uint32_t output_channels; - uint32_t input_channels; - uint32_t width; - uint32_t height; - - std::array size() const { - return { - output_channels, - input_channels, - width, - height, - }; - } - } weights {13, input.channels, 3, 3}; - - const auto input_cpu = at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)); - const auto weights_cpu = at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); - const auto bias_cpu = at::rand({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); - - const auto output_cpu = at::conv2d( - input_cpu, - weights_cpu, - bias_cpu, - stride, - padding, - dilation, - groups); - - const auto output_vulkan = at::conv2d( - input_cpu.vulkan(), - weights_cpu, - bias_cpu, - stride, - padding, - dilation, - groups).cpu(); - - const bool check = almostEqual(output_cpu, output_vulkan); - if (!check) { - showRtol(output_cpu, output_vulkan); - } - - ASSERT_TRUE(check); -} - TEST_F(VulkanAPITest, copy) { - if (!at::is_vulkan_available()) { - return; - } - const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat)); const auto vulkan = cpu.vulkan(); - const auto check = exactlyEqual(cpu, vulkan.cpu()); + const auto check = almostEqual(cpu, vulkan.cpu()); if (!check) { showRtol(cpu, vulkan.cpu()); } @@ -811,9 +935,6 @@ TEST_F(VulkanAPITest, copy) { } TEST_F(VulkanAPITest, cumsum) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; const auto in_cpu = at::rand({1, 17, 37, 49}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); @@ -841,10 +962,6 @@ TEST_F(VulkanAPITest, cumsum) { } TEST_F(VulkanAPITest, div) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat))+0.01; const auto a_vulkan = a_cpu.vulkan(); @@ -863,10 +980,6 @@ TEST_F(VulkanAPITest, div) { } TEST_F(VulkanAPITest, div_broadcast0) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 1, 1}, at::device(at::kCPU).dtype(at::kFloat))+0.01; const auto a_vulkan = a_cpu.vulkan(); @@ -885,10 +998,6 @@ TEST_F(VulkanAPITest, div_broadcast0) { } TEST_F(VulkanAPITest, div_broadcast1) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat))+0.01; const auto a_vulkan = a_cpu.vulkan(); @@ -907,10 +1016,6 @@ TEST_F(VulkanAPITest, div_broadcast1) { } TEST_F(VulkanAPITest, div_broadcast2) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat))+0.01; const auto a_vulkan = a_cpu.vulkan(); @@ -929,10 +1034,6 @@ TEST_F(VulkanAPITest, div_broadcast2) { } TEST_F(VulkanAPITest, div_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat))+0.01; auto a_vulkan = a_cpu.vulkan(); @@ -951,10 +1052,6 @@ TEST_F(VulkanAPITest, div_) { } TEST_F(VulkanAPITest, div_broadcast0_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({12, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat))+0.01; auto a_vulkan = a_cpu.vulkan(); @@ -969,167 +1066,379 @@ TEST_F(VulkanAPITest, div_broadcast0_) { showRtol(b_cpu, b_vulkan.cpu()); } - ASSERT_TRUE(check); + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, div_broadcast1_) { + auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat))+0.01; + auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({8, 1, 1}, at::device(at::kCPU).dtype(at::kFloat))+0.01; + const auto b_vulkan = b_cpu.vulkan(); + + a_cpu.div_(b_cpu); + a_vulkan.div_(b_vulkan); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + showRtol(b_cpu, b_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, div_scalar) { + + const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::div(a_cpu, b_scalar); + const auto c_vulkan = at::div(a_vulkan, b_scalar); + + const auto check = almostEqual(c_cpu, c_vulkan.cpu()); + if (!check) { + showRtol(c_cpu, c_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, div_scalar_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.div_(b_scalar); + a_vulkan.div_(b_scalar); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + showRtol(a_cpu, a_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, empty) { + + ASSERT_NO_THROW(at::empty({1, 17, 41, 53}, at::device(at::kVulkan).dtype(at::kFloat))); +} + +void test_glu(const at::IntArrayRef input_shape) { + const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::glu(in_cpu, 1); + const auto out_vulkan = at::glu(in_vulkan, 1); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, glu_ch_200) { + test_glu({17, 200, 302, 5}); +} + +TEST_F(VulkanAPITest, glu_ch_64) { + test_glu({1, 64, 100, 8}); +} + +TEST_F(VulkanAPITest, glu_ch_32) { + test_glu({1, 32, 100, 19}); +} + +TEST_F(VulkanAPITest, glu_ch_10) { + test_glu({17, 10, 57, 41}); +} + +TEST_F(VulkanAPITest, glu_ch_2) { + test_glu({1, 2, 100, 40}); +} + +TEST_F(VulkanAPITest, hardsigmoid) { + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::hardsigmoid(in_cpu); + const auto out_vulkan = at::hardsigmoid(in_vulkan); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, hardsigmoid_) { + auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; + auto vulkan = cpu.vulkan(); + + at::hardsigmoid_(cpu); + at::hardsigmoid_(vulkan); + + const auto check = almostEqual(cpu, vulkan.cpu()); + if (!check) { + showRtol(cpu, vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, hardshrink) { + for (const auto lambd_value : {-4.2, -1.0, 0.42, 1.0, 4.2, 13.7}) { + // Generate values between -10 and +10 + const auto in_cpu = (at::rand({3, 63, 79, 17}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_vulkan = at::hardshrink(in_vulkan, lambd_value); + + const auto check = checkHardShrink(in_cpu, out_vulkan.cpu(), lambd_value); + ASSERT_TRUE(check); + } +} + +TEST_F(VulkanAPITest, hardshrink_) { + for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) { + // Generate values between -10 and +10 + const auto in_cpu = (at::rand({3, 63, 79, 17}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = in_cpu.hardshrink(lambd_value); + const auto out_vulkan = in_vulkan.hardshrink(lambd_value).cpu(); + + const auto check = checkHardShrink(out_cpu, out_vulkan, lambd_value); + ASSERT_TRUE(check); + } +} + +TEST_F(VulkanAPITest, layer_norm_invalid_inputs) { + c10::InferenceMode mode; + + // Act: incorrect normalized shape + EXPECT_THROW({ + at::layer_norm( + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {8, 5}, + at::rand({8, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({8, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: normalized shape must be [C, H, W] + EXPECT_THROW({ + at::layer_norm( + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {5, 7}, + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: incorrect weight dimensions + EXPECT_THROW({ + at::layer_norm( + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {3, 5, 7}, + at::rand({3, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: incorrect bias dimensions + EXPECT_THROW({ + at::layer_norm( + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {3, 5, 7}, + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: batch dim must be 1 + EXPECT_THROW({ + at::layer_norm( + at::rand({2, 3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {3, 5, 7}, + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: input has too many dimensions + EXPECT_THROW({ + at::layer_norm( + at::rand({1, 2, 3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {3, 5, 7}, + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); + + // Act: input has too few dimensions + EXPECT_THROW({ + at::layer_norm( + at::rand({3, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + {3, 5}, + at::rand({3, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({3, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + 1e-05, + false); + }, ::c10::Error); } -TEST_F(VulkanAPITest, div_broadcast1_) { - if (!at::is_vulkan_available()) { - return; - } +TEST_F(VulkanAPITest, layer_norm_3d_small) { + c10::InferenceMode mode; - auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat))+0.01; - auto a_vulkan = a_cpu.vulkan(); + const auto input_cpu = at::rand({1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); - const auto b_cpu = at::rand({8, 1, 1}, at::device(at::kCPU).dtype(at::kFloat))+0.01; - const auto b_vulkan = b_cpu.vulkan(); + const auto weight_cpu = at::rand({1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - a_cpu.div_(b_cpu); - a_vulkan.div_(b_vulkan); + const auto bias_cpu = at::rand({1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); - const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + const auto output_cpu = at::layer_norm(input_cpu, {1, 1, 1}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {1, 1, 1}, weight_vulkan, bias_vulkan, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); if (!check) { - showRtol(b_cpu, b_vulkan.cpu()); + showRtol(output_cpu, output_vulkan.cpu()); } ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, div_scalar) { - if (!at::is_vulkan_available()) { - return; - } +TEST_F(VulkanAPITest, layer_norm_3d_medium) { + c10::InferenceMode mode; - const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); - const auto a_vulkan = a_cpu.vulkan(); + const auto input_cpu = at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); - const float b_scalar = 3.1415f; + const auto weight_cpu = at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - const auto c_cpu = at::div(a_cpu, b_scalar); - const auto c_vulkan = at::div(a_vulkan, b_scalar); + const auto bias_cpu = at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); - const auto check = almostEqual(c_cpu, c_vulkan.cpu()); + const auto output_cpu = at::layer_norm(input_cpu, {3, 5, 7}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {3, 5, 7}, weight_vulkan, bias_vulkan, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); if (!check) { - showRtol(c_cpu, c_vulkan.cpu()); + showRtol(output_cpu, output_vulkan.cpu()); } ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, div_scalar_) { - if (!at::is_vulkan_available()) { - return; - } +TEST_F(VulkanAPITest, layer_norm_3d_large) { + c10::InferenceMode mode; - auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); - auto a_vulkan = a_cpu.vulkan(); + const auto input_cpu = at::rand({53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); - const float b_scalar = 3.1415f; + const auto weight_cpu = at::rand({53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - a_cpu.div_(b_scalar); - a_vulkan.div_(b_scalar); + const auto bias_cpu = at::rand({53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); - const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + const auto output_cpu = at::layer_norm(input_cpu, {53, 139, 109}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {53, 139, 109}, weight_vulkan, bias_vulkan, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); if (!check) { - showRtol(a_cpu, a_vulkan.cpu()); + showRtol(output_cpu, output_vulkan.cpu()); } ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, empty) { - if (!at::is_vulkan_available()) { - return; - } +TEST_F(VulkanAPITest, layer_norm_4d_small) { + c10::InferenceMode mode; - ASSERT_NO_THROW(at::empty({1, 17, 41, 53}, at::device(at::kVulkan).dtype(at::kFloat))); -} + const auto input_cpu = at::rand({1, 1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); -TEST_F(VulkanAPITest, hardsigmoid) { - if (!at::is_vulkan_available()) { - return; - } + const auto weight_cpu = at::rand({1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; - const auto in_vulkan = in_cpu.vulkan(); + const auto bias_cpu = at::rand({1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); - const auto out_cpu = at::hardsigmoid(in_cpu); - const auto out_vulkan = at::hardsigmoid(in_vulkan); + const auto output_cpu = at::layer_norm(input_cpu, {1, 1, 1}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {1, 1, 1}, weight_vulkan, bias_vulkan, 1e-05, false); - const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); if (!check) { - showRtol(out_cpu, out_vulkan.cpu()); + showRtol(output_cpu, output_vulkan.cpu()); } ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, hardsigmoid_) { - if (!at::is_vulkan_available()) { - return; - } +TEST_F(VulkanAPITest, layer_norm_4d_medium) { + c10::InferenceMode mode; - auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; - auto vulkan = cpu.vulkan(); + const auto input_cpu = at::rand({1, 3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); - at::hardsigmoid_(cpu); - at::hardsigmoid_(vulkan); + const auto weight_cpu = at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - const auto check = almostEqual(cpu, vulkan.cpu()); + const auto bias_cpu = at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); + + const auto output_cpu = at::layer_norm(input_cpu, {3, 5, 7}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {3, 5, 7}, weight_vulkan, bias_vulkan, 1e-05, false); + + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); if (!check) { - showRtol(cpu, vulkan.cpu()); + showRtol(output_cpu, output_vulkan.cpu()); } ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, hardshrink) { - if (!at::is_vulkan_available()) { - return; - } - - for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { - const auto in_cpu = (at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; // between -10 and +10 - const auto in_vulkan = in_cpu.vulkan(); +TEST_F(VulkanAPITest, layer_norm_4d_large) { + c10::InferenceMode mode; - const auto out_cpu = at::hardshrink(in_cpu, lambd_value); - const auto out_vulkan = at::hardshrink(in_vulkan, lambd_value); + const auto input_cpu = at::rand({1, 53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto input_vulkan = input_cpu.vulkan(); - const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + const auto weight_cpu = at::rand({53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight_vulkan = weight_cpu.vulkan(); - if (!check) { - showRtol(out_cpu, out_vulkan.cpu()); - } + const auto bias_cpu = at::rand({53, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_vulkan = bias_cpu.vulkan(); - ASSERT_TRUE(check); - } -} + const auto output_cpu = at::layer_norm(input_cpu, {53, 139, 109}, weight_cpu, bias_cpu, 1e-05, false); + const auto output_vulkan = at::layer_norm(input_vulkan, {53, 139, 109}, weight_vulkan, bias_vulkan, 1e-05, false); -TEST_F(VulkanAPITest, hardshrink_) { - if (!at::is_vulkan_available()) { - return; + const auto check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); } - for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { - const auto cpu = (at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; // between -10 and +10 - const auto vulkan = cpu.vulkan(); - - cpu.hardshrink(lambd_value); - vulkan.hardshrink(lambd_value); - - const auto check = almostEqual(cpu, vulkan.cpu()); - if (!check) { - showRtol(cpu, vulkan.cpu()); - } - - ASSERT_TRUE(check); - } + ASSERT_TRUE(check); } TEST_F(VulkanAPITest, leaky_relu) { - if (!at::is_vulkan_available()) { - return; - } - for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) { const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -1148,10 +1457,6 @@ TEST_F(VulkanAPITest, leaky_relu) { } TEST_F(VulkanAPITest, leaky_relu_) { - if (!at::is_vulkan_available()) { - return; - } - for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) { auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); auto vulkan = cpu.vulkan(); @@ -1169,10 +1474,6 @@ TEST_F(VulkanAPITest, leaky_relu_) { } TEST_F(VulkanAPITest, lerp) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1194,10 +1495,6 @@ TEST_F(VulkanAPITest, lerp) { } TEST_F(VulkanAPITest, lerp_broadcast0) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1219,10 +1516,6 @@ TEST_F(VulkanAPITest, lerp_broadcast0) { } TEST_F(VulkanAPITest, lerp_broadcast1) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1244,10 +1537,6 @@ TEST_F(VulkanAPITest, lerp_broadcast1) { } TEST_F(VulkanAPITest, lerp_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1269,10 +1558,6 @@ TEST_F(VulkanAPITest, lerp_) { } TEST_F(VulkanAPITest, lerp_broadcast0_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1294,10 +1579,6 @@ TEST_F(VulkanAPITest, lerp_broadcast0_) { } TEST_F(VulkanAPITest, lerp_broadcast1_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1319,10 +1600,6 @@ TEST_F(VulkanAPITest, lerp_broadcast1_) { } TEST_F(VulkanAPITest, lerp_scalar) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({13, 23, 59, 73}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1343,10 +1620,6 @@ TEST_F(VulkanAPITest, lerp_scalar) { } TEST_F(VulkanAPITest, lerp_scalar_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({47, 2, 23, 97}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1367,10 +1640,6 @@ TEST_F(VulkanAPITest, lerp_scalar_) { } TEST_F(VulkanAPITest, hardswish) { - if (!at::is_vulkan_available()) { - return; - } - const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; const auto in_vulkan = in_cpu.vulkan(); @@ -1386,11 +1655,7 @@ TEST_F(VulkanAPITest, hardswish) { } TEST_F(VulkanAPITest, threshold) { - if (!at::is_vulkan_available()) { - return; - } - - const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; + const auto in_cpu = at::rand({2, 11, 57, 23}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; const auto in_vulkan = in_cpu.vulkan(); const float threshold = 2.0f; @@ -1399,19 +1664,11 @@ TEST_F(VulkanAPITest, threshold) { const auto out_cpu = at::threshold(in_cpu, threshold, value); const auto out_vulkan = at::threshold(in_vulkan, threshold, value); - const auto check = almostEqual(out_cpu, out_vulkan.cpu()); - if (!check) { - showRtol(out_cpu, out_vulkan.cpu()); - } - + const auto check = checkThreshold(out_cpu, out_vulkan.cpu(), threshold, value); ASSERT_TRUE(check); } TEST_F(VulkanAPITest, hardswish_) { - if (!at::is_vulkan_available()) { - return; - } - auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat))*12 - 6; auto vulkan = cpu.vulkan(); @@ -1427,9 +1684,6 @@ TEST_F(VulkanAPITest, hardswish_) { } TEST_F(VulkanAPITest, max_pool2d) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; const auto in_cpu = at::rand({5, 13, 55, 68}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); @@ -1475,10 +1729,6 @@ TEST_F(VulkanAPITest, mean2d) { } TEST_F(VulkanAPITest, mm) { - if (!at::is_vulkan_available()) { - return; - } - const auto m1_cpu = at::rand({179, 67}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = at::rand({67, 163}, at::device(at::kCPU).dtype(at::kFloat)); const auto out_cpu = m1_cpu.mm(m2_cpu); @@ -1495,10 +1745,6 @@ TEST_F(VulkanAPITest, mm) { } TEST_F(VulkanAPITest, mul) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1517,10 +1763,6 @@ TEST_F(VulkanAPITest, mul) { } TEST_F(VulkanAPITest, mul_broadcast0) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1539,10 +1781,6 @@ TEST_F(VulkanAPITest, mul_broadcast0) { } TEST_F(VulkanAPITest, mul_broadcast1) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1561,10 +1799,6 @@ TEST_F(VulkanAPITest, mul_broadcast1) { } TEST_F(VulkanAPITest, mul_broadcast2) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1583,10 +1817,6 @@ TEST_F(VulkanAPITest, mul_broadcast2) { } TEST_F(VulkanAPITest, mul_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1605,10 +1835,6 @@ TEST_F(VulkanAPITest, mul_) { } TEST_F(VulkanAPITest, mul_broadcast0_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({12, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1627,10 +1853,6 @@ TEST_F(VulkanAPITest, mul_broadcast0_) { } TEST_F(VulkanAPITest, mul_broadcast1_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1649,10 +1871,6 @@ TEST_F(VulkanAPITest, mul_broadcast1_) { } TEST_F(VulkanAPITest, mul_scalar) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1670,10 +1888,6 @@ TEST_F(VulkanAPITest, mul_scalar) { } TEST_F(VulkanAPITest, mul_scalar_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -1691,10 +1905,6 @@ TEST_F(VulkanAPITest, mul_scalar_) { } TEST_F(VulkanAPITest, reflection_pad2d) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({2, 3, 47, 63}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1710,10 +1920,6 @@ TEST_F(VulkanAPITest, reflection_pad2d) { } TEST_F(VulkanAPITest, replication_pad2d) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({2, 3, 47, 63}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1731,15 +1937,12 @@ TEST_F(VulkanAPITest, replication_pad2d) { } TEST_F(VulkanAPITest, reshape) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; - const auto in_cpu = at::rand({47, 11, 83, 97}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu = at::rand({7, 11, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); - const std::array shape{47 * 83, 11 * 97}; + const std::array shape{7 * 8, 11 * 9}; const auto out_cpu = at::reshape(in_cpu, shape); const auto out_vulkan = at::reshape(in_vulkan, shape); @@ -1753,15 +1956,12 @@ TEST_F(VulkanAPITest, reshape) { } TEST_F(VulkanAPITest, reshape_) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; - const auto cpu = at::rand({59, 41, 19, 67}, at::device(at::kCPU).dtype(at::kFloat)); + const auto cpu = at::rand({9, 4, 12, 6}, at::device(at::kCPU).dtype(at::kFloat)); const auto vulkan = cpu.vulkan(); - const std::array shape{59, 41 * 67, 19}; + const std::array shape{9, 4 * 6, 12}; cpu.reshape(shape); vulkan.reshape(shape); @@ -1774,11 +1974,34 @@ TEST_F(VulkanAPITest, reshape_) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, sigmoid) { - if (!at::is_vulkan_available()) { - return; +void test_select(const at::IntArrayRef input_shape, int64_t dim, int64_t index) { + const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::select(in_cpu, dim, index); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::select(in_vulkan, dim, index); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); } + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, select_3d_depth_small) { + test_select({1, 1, 1}, 0, 0); +} + +TEST_F(VulkanAPITest, select_3d_depth_medium) { + test_select({3, 2, 5}, 0, 2); +} + +TEST_F(VulkanAPITest, select_3d_depth_large) { + test_select({100, 1, 144}, 0, 50); +} + +TEST_F(VulkanAPITest, sigmoid) { const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -1794,10 +2017,6 @@ TEST_F(VulkanAPITest, sigmoid) { } TEST_F(VulkanAPITest, sigmoid_) { - if (!at::is_vulkan_available()) { - return; - } - auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); auto vulkan = cpu.vulkan(); @@ -1813,6 +2032,8 @@ TEST_F(VulkanAPITest, sigmoid_) { } TEST_F(VulkanAPITest, softmax) { + c10::InferenceMode mode; + at::Tensor test_in[] = { at::rand({1, 196, 302, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), at::rand({1, 197, 302, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), @@ -1835,7 +2056,8 @@ TEST_F(VulkanAPITest, softmax) { } } -TEST_F(VulkanAPITest, log_softmax) { +// TODO: Currently the op is not working correctly. Add it back when it is fixed. +TEST_F(VulkanAPITest, DISABLED_log_softmax) { at::Tensor test_in[] = { at::rand({1, 196, 302, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), at::rand({1, 197, 302, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), @@ -1859,10 +2081,6 @@ TEST_F(VulkanAPITest, log_softmax) { } TEST_F(VulkanAPITest, tanh) { - if (!at::is_vulkan_available()) { - return; - } - const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) * 30; const auto in_vulkan = in_cpu.vulkan(); @@ -1878,10 +2096,6 @@ TEST_F(VulkanAPITest, tanh) { } TEST_F(VulkanAPITest, tanh_) { - if (!at::is_vulkan_available()) { - return; - } - auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) * 30; auto vulkan = cpu.vulkan(); @@ -1897,10 +2111,6 @@ TEST_F(VulkanAPITest, tanh_) { } TEST_F(VulkanAPITest, sub) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1919,10 +2129,6 @@ TEST_F(VulkanAPITest, sub) { } TEST_F(VulkanAPITest, sub_broadcast0) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1941,10 +2147,6 @@ TEST_F(VulkanAPITest, sub_broadcast0) { } TEST_F(VulkanAPITest, sub_broadcast1) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 5, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1963,10 +2165,6 @@ TEST_F(VulkanAPITest, sub_broadcast1) { } TEST_F(VulkanAPITest, sub_broadcast2) { - if (!at::is_vulkan_available()) { - return; - } - const auto a_cpu = at::rand({3, 4, 179, 221}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -1985,10 +2183,6 @@ TEST_F(VulkanAPITest, sub_broadcast2) { } TEST_F(VulkanAPITest, sub_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -2007,10 +2201,6 @@ TEST_F(VulkanAPITest, sub_) { } TEST_F(VulkanAPITest, sub_broadcast0_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({16, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -2029,10 +2219,6 @@ TEST_F(VulkanAPITest, sub_broadcast0_) { } TEST_F(VulkanAPITest, sub_broadcast1_) { - if (!at::is_vulkan_available()) { - return; - } - auto a_cpu = at::rand({3, 8, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); auto a_vulkan = a_cpu.vulkan(); @@ -2051,11 +2237,6 @@ TEST_F(VulkanAPITest, sub_broadcast1_) { } TEST_F(VulkanAPITest, transposed_conv2d) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange constexpr int64_t groups = 1; constexpr std::array stride{1, 2}; @@ -2131,10 +2312,6 @@ TEST_F(VulkanAPITest, transposed_conv2d) { } TEST_F(VulkanAPITest, upsample_nearest2d) { - if (!at::is_vulkan_available()) { - return; - } - const auto in_cpu = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6}); @@ -2149,13 +2326,100 @@ TEST_F(VulkanAPITest, upsample_nearest2d) { ASSERT_TRUE(check); } -#if !defined(__APPLE__) -TEST_F(VulkanAPITest, cat_dim1_samefeature_success) { - // Guard - if (!at::is_vulkan_available()) { - return; +void test_unbind(const at::IntArrayRef input_shape, int64_t dim) { + const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::unbind(in_cpu, dim); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::unbind(in_vulkan, dim); + + int64_t size = out_vulkan.size(); + + for (const auto i : c10::irange(size)) { + const auto check = almostEqual(out_cpu[i], out_vulkan[i].cpu()); + if (!check) { + std::cout << "The " << i << "th vectors aren't equal." << std::endl; + showRtol(out_cpu[i], out_vulkan[i].cpu()); + } + + ASSERT_TRUE(check); + } +} + +TEST_F(VulkanAPITest, unbind_3d_depth_small) { + test_unbind({1, 1, 1}, 0); +} + +TEST_F(VulkanAPITest, unbind_3d_depth_medium) { + test_unbind({3, 2, 5}, 0); +} + +TEST_F(VulkanAPITest, unbind_3d_depth_large) { + test_unbind({100, 1, 144}, 0); +} + +TEST_F(VulkanAPITest, view_explicit) { + c10::InferenceMode mode; + + const auto in_cpu = at::rand({7, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const std::array shape{7, 8, 9, 1}; + + const auto out_cpu = in_cpu.view(shape); + const auto out_vulkan = in_vulkan.view(shape); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, view_inferred) { + c10::InferenceMode mode; + + const auto in_cpu = at::rand({7, 11, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const std::array shape{7, 11, -1}; + + const auto out_cpu = in_cpu.view(shape); + const auto out_vulkan = in_vulkan.view(shape); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); } + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, view_invalid_inputs) { + c10::InferenceMode mode; + + // Act: only one dimension can be inferred + EXPECT_THROW({ + at::rand({7, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)) + .vulkan().view({7, -1, -1}); + }, ::std::runtime_error); + + // Act: invalid shape dimension + EXPECT_THROW({ + at::rand({7, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)) + .vulkan().view({7, 8, -2}); + }, ::c10::Error); + + // Act: incompatible shape + EXPECT_THROW({ + at::rand({7, 8, 9}, at::device(at::kCPU).dtype(at::kFloat)) + .vulkan().view({7, 70}); + }, ::std::runtime_error); +} + +#if !defined(__APPLE__) +TEST_F(VulkanAPITest, DISABLED_cat_dim1_samefeature_success) { // Arrange const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2174,12 +2438,7 @@ TEST_F(VulkanAPITest, cat_dim1_samefeature_success) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, cat_dim1_difffeature_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - +TEST_F(VulkanAPITest, DISABLED_cat_dim1_difffeature_success) { // Arrange const auto in_cpu1 = at::rand({3, 3, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 8, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2199,11 +2458,6 @@ TEST_F(VulkanAPITest, cat_dim1_difffeature_success) { } TEST_F(VulkanAPITest, cat_dim1_texture2d_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: 2D Texture (VK_IMAGE_VIEW_TYPE_2D) const auto in_cpu1 = at::rand({2, 3, 2, 2}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({2, 3, 2, 2}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2224,11 +2478,6 @@ TEST_F(VulkanAPITest, cat_dim1_texture2d_success) { #endif /* !defined(__APPLE__) */ TEST_F(VulkanAPITest, cat_dim1_singledepth_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: batch x channel (1x1) = single depth texture const auto in_cpu1 = at::rand({1, 1, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({1, 1, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2248,11 +2497,6 @@ TEST_F(VulkanAPITest, cat_dim1_singledepth_success) { } TEST_F(VulkanAPITest, cat_dim1_singletensor_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: single input tensor const auto in_cpu1 = at::rand({3, 7, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2269,12 +2513,7 @@ TEST_F(VulkanAPITest, cat_dim1_singletensor_success) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, cat_dim1_twotensors_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - +TEST_F(VulkanAPITest, DISABLED_cat_dim1_twotensors_success) { // Arrange: two input tensors const auto in_cpu1 = at::rand({3, 7, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 7, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2293,11 +2532,6 @@ TEST_F(VulkanAPITest, cat_dim1_twotensors_success) { } TEST_F(VulkanAPITest, cat_dim1_bat1_mult4ch_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: batch=1 and channel (a multiple of 4 <-> channel %4 == 0) const auto in_cpu1 = at::rand({1, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({1, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2317,11 +2551,6 @@ TEST_F(VulkanAPITest, cat_dim1_bat1_mult4ch_success) { } TEST_F(VulkanAPITest, cat_dim1_bat2_mult4ch_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: batch=2 and channel (a multiple of 4 <-> channel %4 == 0) const auto in_cpu1 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2341,11 +2570,6 @@ TEST_F(VulkanAPITest, cat_dim1_bat2_mult4ch_success) { } TEST_F(VulkanAPITest, cat_dim1_mult4ch_mixed_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: batch=1 and channel (different multiples of 4 <-> channel %4 == 0) const auto in_cpu1 = at::rand({3, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 8, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2364,12 +2588,7 @@ TEST_F(VulkanAPITest, cat_dim1_mult4ch_mixed_success) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, cat_dim1_mult4ch_nonmult4ch_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - +TEST_F(VulkanAPITest, DISABLED_cat_dim1_mult4ch_nonmult4ch_success) { // Arrange: batch=1 and channel (a mixed set of multiples and non-multiples of 4) const auto in_cpu1 = at::rand({3, 3, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2390,11 +2609,6 @@ TEST_F(VulkanAPITest, cat_dim1_mult4ch_nonmult4ch_success) { } TEST_F(VulkanAPITest, cat_dim2_sameheight_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2414,11 +2628,6 @@ TEST_F(VulkanAPITest, cat_dim2_sameheight_success) { } TEST_F(VulkanAPITest, cat_dim2_diffheight_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({3, 9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2438,11 +2647,6 @@ TEST_F(VulkanAPITest, cat_dim2_diffheight_success) { } TEST_F(VulkanAPITest, cat_dim2_singledepth_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: batch x channel (1x1) = single depth texture const auto in_cpu1 = at::rand({1, 1, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_cpu2 = at::rand({1, 1, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2462,11 +2666,6 @@ TEST_F(VulkanAPITest, cat_dim2_singledepth_success) { } TEST_F(VulkanAPITest, cat_dim2_invalidinputs_exceptions) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: Vulkan cat inputs must have matching sizes except concatenated dimension { const auto in_cpu1 = at::rand({3, 5, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2511,11 +2710,6 @@ TEST_F(VulkanAPITest, cat_dim2_invalidinputs_exceptions) { } TEST_F(VulkanAPITest, permute_2d_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2533,18 +2727,15 @@ TEST_F(VulkanAPITest, permute_2d_success) { } TEST_F(VulkanAPITest, permute_3d_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({2, 3, 2}, at::device(at::kCPU).dtype(at::kFloat)); std::vector> all_dims; std::vector in{0, 1, 2}; gen_allpermutations(all_dims, in, 0); - for (const auto& dims : all_dims) { + for (const auto i : c10::irange(1, all_dims.size())) { + const auto dims = all_dims[i]; + // Act const auto out_cpu = at::permute(in_cpu, dims); const auto out_vulkan = at::permute(in_cpu.vulkan(), dims); @@ -2560,18 +2751,15 @@ TEST_F(VulkanAPITest, permute_3d_success) { } TEST_F(VulkanAPITest, permute_4d_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({2, 3, 4, 5}, at::device(at::kCPU).dtype(at::kFloat)); std::vector> all_dims; std::vector in{0, 1, 2, 3}; gen_allpermutations(all_dims, in, 0); - for (const auto& dims : all_dims) { + for (const auto i : c10::irange(1, all_dims.size())) { + const auto dims = all_dims[i]; + // Act const auto out_cpu = at::permute(in_cpu, dims); const auto out_vulkan = at::permute(in_cpu.vulkan(), dims); @@ -2587,11 +2775,6 @@ TEST_F(VulkanAPITest, permute_4d_success) { } TEST_F(VulkanAPITest, permute_4dmclaren_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange: McLaren Model usage const auto in_cpu = at::rand({1, 2, 1, 161}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2609,18 +2792,14 @@ TEST_F(VulkanAPITest, permute_4dmclaren_success) { } TEST_F(VulkanAPITest, permute_4dbig_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({3, 9, 89, 91}, at::device(at::kCPU).dtype(at::kFloat)); std::vector> all_dims; std::vector in{0, 1, 2, 3}; gen_allpermutations(all_dims, in, 0); - for (const auto& dims : all_dims) { + for (const auto i : c10::irange(1, all_dims.size())) { + const auto dims = all_dims[i]; // Act const auto out_cpu = at::permute(in_cpu, dims); const auto out_vulkan = at::permute(in_cpu.vulkan(), dims); @@ -2636,11 +2815,6 @@ TEST_F(VulkanAPITest, permute_4dbig_success) { } TEST_F(VulkanAPITest, permute_negativedims_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({5, 4, 3, 2}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2657,56 +2831,7 @@ TEST_F(VulkanAPITest, permute_negativedims_success) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, permute_1d_nochange) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - - // Arrange - const auto in_cpu = at::rand({161}, at::device(at::kCPU).dtype(at::kFloat)); - - // Act - const auto out_cpu = at::permute(in_cpu, {0}); - const auto out_vulkan = at::permute(in_cpu.vulkan(), {0}); - - // Assert - const auto check = almostEqual(out_cpu, out_vulkan.cpu()); - if (!check) { - showRtol(out_cpu, out_vulkan.cpu()); - } - - ASSERT_TRUE(check); -} - -TEST_F(VulkanAPITest, permute_sameDims_nochange) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - - // Arrange - const auto in_cpu = at::rand({1, 2, 1, 161}, at::device(at::kCPU).dtype(at::kFloat)); - - // Act - const auto out_cpu = at::permute(in_cpu, {0, 1, 2, 3}); - const auto out_vulkan = at::permute(in_cpu.vulkan(), {0, 1, 2, 3}); - - // Assert - const auto check = almostEqual(out_cpu, out_vulkan.cpu()); - if (!check) { - showRtol(out_cpu, out_vulkan.cpu()); - } - - ASSERT_TRUE(check); -} - TEST_F(VulkanAPITest, permute_invalidinputs_exceptions) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const auto in_cpu = at::rand({1, 2, 1, 161}, at::device(at::kCPU).dtype(at::kFloat)); @@ -2827,6 +2952,130 @@ TEST_F(VulkanAPITest, slice_invalidinputs_exceptions) { }, ::c10::Error); } +TEST_F(VulkanAPITest, stack_invalid_inputs) { + // Act: Vulkan stack expects at least one tensor + EXPECT_THROW({ + at::stack({}, 0); + }, ::c10::Error); + + // Act: Vulkan stack expects dim = 0 + EXPECT_THROW({ + at::stack({ + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan()}, 1); + }, ::c10::Error); + + // Act: Vulkan stack expects 2 dimensional inputs + EXPECT_THROW({ + at::stack({ + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({3, 5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan()}, 0); + }, ::c10::Error); + + // Act: Vulkan stack inputs must have matching sizes + EXPECT_THROW({ + at::stack({ + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({5, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), + at::rand({6, 7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan()}, 0); + }, ::c10::Error); +} + +TEST_F(VulkanAPITest, stack_1_tensor) { + // Arrange + const auto in_cpu1 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto out_cpu = at::stack({in_cpu1}, 0); + const auto out_vulkan = at::stack({in_cpu1.vulkan()}, 0); + + // Assert + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, stack_2_tensors) { + // Arrange + const auto in_cpu1 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto out_cpu = at::stack({in_cpu1, in_cpu2}, 0); + const auto out_vulkan = at::stack({in_cpu1.vulkan(), in_cpu2.vulkan()}, 0); + + // Assert + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, stack_3_tensors) { + // Arrange + const auto in_cpu1 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto out_cpu = at::stack({in_cpu1, in_cpu2, in_cpu3}, 0); + const auto out_vulkan = at::stack({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 0); + + // Assert + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, stack_4_tensors) { + // Arrange + const auto in_cpu1 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu4 = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto out_cpu = at::stack({in_cpu1, in_cpu2, in_cpu3, in_cpu4}, 0); + const auto out_vulkan = at::stack({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan(), in_cpu4.vulkan()}, 0); + + // Assert + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, stack_from_1_to_20_tensors) { + std::vector tensors_cpu = {}; + std::vector tensors_vulkan = {}; + + for (const auto i : c10::irange(20)) { + at::Tensor in_cpu = at::rand({221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + tensors_cpu.emplace_back(in_cpu); + tensors_vulkan.emplace_back(in_cpu.vulkan()); + at::Tensor out_cpu = at::stack(tensors_cpu, 0); + at::Tensor out_vulkan = at::stack(tensors_vulkan, 0); + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Error when stacking " << i << " tensors" << std::endl; + showRtol(out_cpu, out_vulkan.cpu()); + } + ASSERT_TRUE(check); + } +} + TEST_F(VulkanAPITest, clone_success) { // Arrange std::multimap, std::vector> mem2sizes { @@ -3099,9 +3348,6 @@ class MobileNetV2 final : public OpsList { }; TEST_F(VulkanAPITest, mobilenetv2) { - if (!at::is_vulkan_available()) { - return; - } c10::InferenceMode mode; MobileNetV2 mn2; @@ -3118,11 +3364,6 @@ TEST_F(VulkanAPITest, mobilenetv2) { } TEST_F(VulkanAPITest, gru_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int H_in = 5; // input_size const int H_out = 7; // hidden_size @@ -3190,11 +3431,6 @@ TEST_F(VulkanAPITest, gru_success) { } TEST_F(VulkanAPITest, gru_mclareninputs_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int H_in = 384; // input_size const int H_out = 384; // hidden_size @@ -3258,11 +3494,6 @@ TEST_F(VulkanAPITest, gru_mclareninputs_success) { } TEST_F(VulkanAPITest, gru_invalidinputs_exceptions) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int H_in = 17; // input_size const int H_out = 50; // hidden_size @@ -3356,11 +3587,6 @@ TEST_F(VulkanAPITest, gru_invalidinputs_exceptions) { } TEST_F(VulkanAPITest, gru_prepack_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int H_in = 81; // input_size const int H_out = 10; // hidden_size @@ -3430,11 +3656,6 @@ TEST_F(VulkanAPITest, gru_prepack_success) { } TEST_F(VulkanAPITest, gru_prepack_invalidinputs_exceptions) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int H_in = 70; // input_size const int H_out = 2; // hidden_size @@ -3559,12 +3780,51 @@ TEST_F(VulkanAPITest, gru_prepack_invalidinputs_exceptions) { }, ::c10::Error); } -TEST_F(VulkanAPITest, lstm_success) { - // Guard - if (!at::is_vulkan_available()) { - return; +void test_linear( + const at::IntArrayRef input_shape, + const at::IntArrayRef weight_shape, + const at::IntArrayRef bias_shape) { + c10::InferenceMode mode; + + const auto input_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto weight = at::rand(weight_shape, at::device(at::kCPU).dtype(at::kFloat)); + const auto bias = at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const auto out_cpu = at::linear(input_cpu, weight, bias); + + auto prepack = callOpByName( + "vulkan_prepack::create_linear_context", + "", + weight.t(), bias); + + auto vulkan_output = callOpByName( + "vulkan_prepack::run_linear_context", + "", + input_cpu.vulkan(), prepack[0]); + + auto out_vulkan = vulkan_output[0].toTensor(); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); } + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, linear_2d) { + test_linear({1, 37}, {41, 37}, {41}); +} + +TEST_F(VulkanAPITest, linear_3d) { + test_linear({1, 1, 37}, {41, 37}, {41}); +} + +TEST_F(VulkanAPITest, linear_4d) { + test_linear({1, 1, 1, 37}, {41, 37}, {41}); +} + +TEST_F(VulkanAPITest, lstm_success) { // Arrange const int input_size = 5; const int hidden_size = 7; @@ -3645,11 +3905,6 @@ TEST_F(VulkanAPITest, lstm_success) { } TEST_F(VulkanAPITest, lstm_mclareninputs_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } - // Arrange const int input_size = 384; const int hidden_size = 384; @@ -3725,99 +3980,87 @@ TEST_F(VulkanAPITest, lstm_mclareninputs_success) { ASSERT_TRUE(check_cell); } +TEST_F(VulkanAPITest, lstm_prepack_success) { + // Arrange + const int input_size = 81; + const int hidden_size = 10; + const int num_layers = 2; + const int L = 1; + const int N = 1; + const double lstm_dropout = .0; + const bool has_biases = true; + const bool train = false; + const bool bidirectional = false; + const bool batch_first = true; + const auto in_cpu = at::rand({N, L, input_size}, at::device(at::kCPU).dtype(at::kFloat)); + const auto h0_cpu = at::rand({num_layers, N, hidden_size}, at::device(at::kCPU).dtype(at::kFloat)); + const auto c0_cpu = at::rand({num_layers, N, hidden_size}, at::device(at::kCPU).dtype(at::kFloat)); -#if defined (__ANDROID__) // to avoid `Undefined symbols for architecture arm64` error -TEST_F(VulkanAPITest, profiling_invalideinputs_exceptions) { - // Guard - if (!at::is_vulkan_available()) { - return; + c10::List weight_ih_l; // shape (4 * hidden_size, l == 0 ? input_size : hidden_size) + c10::List weight_hh_l; // shape (4 * hidden_size, hidden_size) + c10::List bias_ih_l; // shape (4 * hidden_size) + c10::List bias_hh_l; // shape (4 * hidden_size) + for (int l = 0; l < num_layers; ++l) { + if (l == 0) { + weight_ih_l.emplace_back(at::rand({4 * hidden_size, input_size}, at::device(at::kCPU).dtype(at::kFloat))); + } else { + weight_ih_l.emplace_back(at::rand({4 * hidden_size, hidden_size}, at::device(at::kCPU).dtype(at::kFloat))); + } + weight_hh_l.emplace_back(at::rand({4 * hidden_size, hidden_size}, at::device(at::kCPU).dtype(at::kFloat))); + bias_ih_l.emplace_back(at::rand({4 * hidden_size}, at::device(at::kCPU).dtype(at::kFloat))); + bias_hh_l.emplace_back(at::rand({4 * hidden_size}, at::device(at::kCPU).dtype(at::kFloat))); } - // Act: The device doesn't support for timestamps on all graphics and compute queues. - EXPECT_THROW({ - const bool is_timestamps_supported_ = false; - const float timestamp_period = 1.f; - at::native::vulkan::api::QueryPool querypool(at::native::vulkan::api::context()->gpu().device, is_timestamps_supported_, timestamp_period); - querypool.enable(); - }, ::c10::Error); + // put this guard here to run inference inststead of training + // to avoid the following error: + // C++ exception with description "0INTERNAL ASSERT FAILED at "xplat/caffe2/aten/src/ATen/core/boxing/KernelFunction.cpp":31, please report a bug to PyTorch. aten::gru.input has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering (see Note [Ambiguity in AutogradOther kernel]). If you want to override CompositeImplicitAutograd, please open an issue to request a dedicated Autograd dispatch key for the backend. + // If you only want to run inference instead of training, add `c10::InferenceMode mode;` before model.forward(). Note this guard is only available in C++ but not Python at present. + c10::InferenceMode mode; - // Act: The query pool already exists. - EXPECT_THROW({ - auto context = at::native::vulkan::api::context(); - at::native::vulkan::api::QueryPool querypool( - context->gpu().device, - context->gpu().adapter->timestamp_compute_and_graphics(), - context->gpu().adapter->timestamp_period()); - querypool.enable(); - querypool.enable(); // already enabled - }, ::c10::Error); + // Act + const auto out_cpu = at::lstm(in_cpu, {h0_cpu, c0_cpu}, + { weight_ih_l[0], weight_hh_l[0], bias_ih_l[0], bias_hh_l[0], + weight_ih_l[1], weight_hh_l[1], bias_ih_l[1], bias_hh_l[1] }, + has_biases, num_layers, lstm_dropout, train, bidirectional, batch_first); - // Act: The query index cannot exceed Configuration::kMaxQueryCount. - EXPECT_THROW({ - auto context = at::native::vulkan::api::context(); - at::native::vulkan::api::QueryPool querypool( - context->gpu().device, - context->gpu().adapter->timestamp_compute_and_graphics(), - context->gpu().adapter->timestamp_period()); - querypool.enable(); - for (uint32_t i = 0u; i < at::native::vulkan::api::QueryPool::Configuration::kMaxQueryCount + 1u; ++i) { - at::native::vulkan::api::Command::Buffer& command_buffer = context->command().pool.stream(); - { - at::native::vulkan::api::OpProfiler profiler(command_buffer, querypool, "test"); - } - context->command().pool.submit(context->gpu().queue, command_buffer); - } - }, ::c10::Error); -} + auto prepack = callOpByName( + "vulkan_prepack::create_lstm_context", + "", + std::vector({ weight_ih_l.get(0), weight_hh_l.get(0), bias_ih_l.get(0), bias_hh_l.get(0), + weight_ih_l.get(1), weight_hh_l.get(1), bias_ih_l.get(1), bias_hh_l.get(1) }), + has_biases, num_layers, lstm_dropout, train, bidirectional, batch_first); -// NOTE: Keep the following test at the end of file -// so that it can print out the op execution time for all prior tests -TEST_F(VulkanAPITest, profiling_result_success) { - // Guard - if (!at::is_vulkan_available()) { - return; - } + auto out_vulkan = callOpByName( + "vulkan_prepack::run_lstm_context", + "", + in_cpu.vulkan(), h0_cpu.vulkan(), c0_cpu.vulkan(), prepack[0]); - // Arrange - auto is_enabled = at::native::vulkan::api::context()->querypool().is_enabled(); - if (is_enabled) { - auto perf_info = at::native::vulkan::api::context()->querypool().disable(false); - std::cout - << "-----------------------------------------------------------------------------------------" << std::endl - << "Query Name Execution Start End" << std::endl - << "-----------------------------------------------------------------------------------------" << std::endl; - for (size_t i = 0; i < perf_info.size(); i++) { - std::cout << std::left << std::setw(35) << perf_info[i].query_name.c_str() - << std::right << std::setw(15) << perf_info[i].execution_time_us << " us" - << std::setw(15) << perf_info[i].start_time_us << " us" - << std::setw(15) << perf_info[i].end_time_us << " us" << std::left << std::endl; - } - } - at::native::vulkan::api::context()->querypool().enable(); - const auto in_cpu1 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); - const auto in_cpu2 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); - const auto in_cpu3 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); - const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 1); - out_vulkan.cpu(); // to make sure all GPU operations are done + auto cpu_output = std::get<0>(out_cpu); + auto cpu_hidden = std::get<1>(out_cpu); + auto cpu_cell = std::get<2>(out_cpu); + auto vulkan_output = out_vulkan[0].toTensor(); + auto vulkan_hidden = out_vulkan[1].toTensor(); + auto vulkan_cell = out_vulkan[2].toTensor(); - // Act - auto perf_info = at::native::vulkan::api::context()->querypool().disable(true); - for (size_t i = 0; i < perf_info.size(); i++) { - std::cout << std::left << std::setw(35) << perf_info[i].query_name.c_str() - << std::right << std::setw(15) << perf_info[i].execution_time_us << " us" - << std::setw(15) << perf_info[i].start_time_us << " us" - << std::setw(15) << perf_info[i].end_time_us << " us" << std::left << std::endl; + // Assert + const auto check_output = almostEqual(cpu_output, vulkan_output.cpu()); + if (!check_output) { + showRtol(cpu_output, vulkan_output.cpu()); } + ASSERT_TRUE(check_output); - // Assert - ASSERT_TRUE(perf_info.size() == 5u); - ASSERT_TRUE(perf_info[0].query_name == "aten::_cat (cat_feature_mult4ch)"); + const auto check_hidden = almostEqual(cpu_hidden, vulkan_hidden.cpu()); + if (!check_hidden) { + showRtol(cpu_hidden, vulkan_hidden.cpu()); + } + ASSERT_TRUE(check_hidden); - if (is_enabled) { - at::native::vulkan::api::context()->querypool().enable(); + const auto check_cell = almostEqual(cpu_cell, vulkan_cell.cpu()); + if (!check_cell) { + showRtol(cpu_cell, vulkan_cell.cpu()); } + ASSERT_TRUE(check_cell); } -#endif } // namespace diff --git a/aten/src/ATen/test/vulkan_perf_test.cpp b/aten/src/ATen/test/vulkan_perf_test.cpp index 230484a0f9151..f161a515920de 100644 --- a/aten/src/ATen/test/vulkan_perf_test.cpp +++ b/aten/src/ATen/test/vulkan_perf_test.cpp @@ -4,11 +4,45 @@ #include #include +#include +#include +#include +#include +#include namespace { -// using Vulkan Timestamp Queries for the pure GPU execution time only -static void cat_op_channel_perf_gpu_only(benchmark::State& state) { +at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) { + auto q_options = in_cpu.options(); + if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8) { + auto output = at::native::empty_affine_quantized( + in_cpu.sizes(), + q_options.dtype().toScalarType(), + q_options.layout(), + q_options.device(), + q_options.pinned_memory(), + in_cpu.q_scale(), + in_cpu.q_zero_point()); + at::native::vulkan::ops::copy_(output, vulkan); + return output; + } else { + auto output = at::empty(in_cpu.sizes(), q_options); + at::native::vulkan::ops::copy_(output, vulkan); + return output; + } +} + +template +static inline dest_t safe_downcast(src_t v) { + TORCH_CHECK( + std::numeric_limits::min() <= v && + v <= std::numeric_limits::max(), + "integer out of range"); + + return static_cast(v); +} + +static void add_op_benchmark(benchmark::State& state) { // Guard if (!at::is_vulkan_available()) { return; @@ -19,73 +53,942 @@ static void cat_op_channel_perf_gpu_only(benchmark::State& state) { const auto channels = state.range(1); const auto height = state.range(2); const auto width = state.range(3); - const auto in_cpu1 = at::rand({batches, channels, height, width}, at::device(at::kCPU).dtype(at::kFloat)); - const auto in_cpu2 = at::rand({batches, channels, height, width}, at::device(at::kCPU).dtype(at::kFloat)); - const auto in_cpu3 = at::rand({batches, channels, height, width}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan1 = in_cpu1.vulkan(); const auto in_vulkan2 = in_cpu2.vulkan(); - const auto in_vulkan3 = in_cpu3.vulkan(); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::add(in_vulkan1, in_vulkan2).cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("add") / 1000000.0); +#endif +} + +static void add_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_cpu1 = at::quantize_per_tensor( + in_cpu1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_add = at::native::vulkan::ops::quantized_add( + out_vulkan1, out_vulkan2, scale2, zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_add, out_cpu1); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_add") / 1000000.0); +#endif +} + +static void conv2d_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches_in = safe_downcast(state.range(0)); + const auto channels_in = safe_downcast(state.range(1)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 1; + constexpr std::array stride{2, 2}; + constexpr std::array padding{1, 1}; + // TODO: Support conv2d with dilation != 1 + constexpr std::array dilation{1, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, channels_in, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{1, input.channels, 3, 3}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif // Act for (auto _ : state) { - at::native::vulkan::api::context()->querypool().enable(); - const auto vulkan_out = at::cat({in_vulkan1, in_vulkan2, in_vulkan3}, 1); - vulkan_out.cpu(); - auto perf_info = at::native::vulkan::api::context()->querypool().disable(true); - state.SetIterationTime(perf_info[0].execution_time_us / 1'000'000.); // us to sec + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups) + .cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("conv2d") / 1000000.0); +#endif } -static void gru_op_perf(benchmark::State& state) { +static void conv2d_op_q_benchmark(benchmark::State& state) { // Guard if (!at::is_vulkan_available()) { return; } // Arrange - const int H_in = static_cast(state.range(0)); // input_size - const int H_out = static_cast(state.range(1)); // hidden_size - const int num_layers = static_cast(state.range(2)); - const double gru_dropout = .0; - const bool has_biases = true; - const bool train = false; - const bool bidirectional = false; - const bool batch_first = true; - const auto in_cpu = at::rand({1, 1, H_in}, at::device(at::kCPU).dtype(at::kFloat)); - const auto h0_cpu = at::rand({num_layers, 1, H_out}, at::device(at::kCPU).dtype(at::kFloat)); - - c10::List weight_ih_l; // shape (3 * hidden_size, input_size) - c10::List weight_hh_l; // shape (3 * hidden_size, hidden_size) - c10::List bias_ih_l; // shape (3 * hidden_size) - c10::List bias_hh_l; // shape (3 * hidden_size) - for (int i = 0; i < num_layers; ++i) { - weight_ih_l.emplace_back(at::rand({3 * H_out, H_in}, at::device(at::kCPU).dtype(at::kFloat))); - weight_hh_l.emplace_back(at::rand({3 * H_out, H_out}, at::device(at::kCPU).dtype(at::kFloat))); - bias_ih_l.emplace_back(at::rand({3 * H_out}, at::device(at::kCPU).dtype(at::kFloat))); - bias_hh_l.emplace_back(at::rand({3 * H_out}, at::device(at::kCPU).dtype(at::kFloat))); - } - - // put this guard here to run inference inststead of training - // to avoid the following error: - // C++ exception with description "0INTERNAL ASSERT FAILED at "xplat/caffe2/aten/src/ATen/core/boxing/KernelFunction.cpp":31, please report a bug to PyTorch. aten::gru.input has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering (see Note [Ambiguity in AutogradOther kernel]). If you want to override CompositeImplicitAutograd, please open an issue to request a dedicated Autograd dispatch key for the backend. - // If you only want to run inference instead of training, add `c10::InferenceMode mode;` before model.forward(). Note this guard is only available in C++ but not Python at present. - c10::InferenceMode mode; + const auto batches_in = safe_downcast(state.range(0)); + const auto channels_in = safe_downcast(state.range(1)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 1; + constexpr std::array stride{2, 2}; + constexpr std::array padding{1, 1}; + // TODO: Support conv2d with dilation != 1 + constexpr std::array dilation{1, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, channels_in, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{1, input.channels, 3, 3}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto in_vulkan1 = input_cpu.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif // Act - while (state.KeepRunning()) { - // weights/biases should be always on CPU. - const auto out_vulkan = at::gru(in_cpu.vulkan(), h0_cpu.vulkan(), { weight_ih_l.get(0), weight_hh_l.get(0), bias_ih_l.get(0), bias_hh_l.get(0), - weight_ih_l.get(1), weight_hh_l.get(1), bias_ih_l.get(1), bias_hh_l.get(1) }, - has_biases, num_layers, gru_dropout, train, bidirectional, batch_first); + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto shape_match = + at::rand({1, 1, 64, 199}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_conv2d = at::native::vulkan::ops::conv2d( + out_vulkan1, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_conv2d, shape_match); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } - auto vulkan_output = std::get<0>(out_vulkan); - auto vulkan_hidden = std::get<1>(out_vulkan); +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_conv2d") / 1000000.0); +#endif +} - // to avoid out-of-memory issues, release resources by waiting and flushing all GPU operations - at::native::vulkan::api::context()->wait(vulkan_output); - at::native::vulkan::api::context()->wait(vulkan_hidden); - at::native::vulkan::api::context()->flush(); +static void conv2dpw_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; } + + // Arrange + const auto batches_in = safe_downcast(state.range(0)); + const auto channels_in = safe_downcast(state.range(1)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 1; + constexpr std::array stride{1, 1}; + constexpr std::array padding{0, 0}; + constexpr std::array dilation{1, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, channels_in, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{29, input.channels, 1, 1}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups) + .cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("conv2d_pw_2x2") / 1000000.0); +#endif +} + +static void conv2dpw_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches_in = safe_downcast(state.range(0)); + const auto channels_in = safe_downcast(state.range(1)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 1; + constexpr std::array stride{1, 1}; + constexpr std::array padding{0, 0}; + constexpr std::array dilation{1, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, channels_in, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{29, input.channels, 1, 1}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto in_vulkan1 = input_cpu.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto shape_match = + at::rand({1, 29, 127, 397}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_conv2d = at::native::vulkan::ops::conv2d( + out_vulkan1, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_conv2d, shape_match); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_conv2d_pw_2x2") / 1000000.0); +#endif +} + +static void conv2ddw_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches_in = safe_downcast(state.range(0)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, groups, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{groups, 1, 17, 7}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups) + .cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("conv2d_dw") / 1000000.0); +#endif +} + +static void conv2ddw_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches_in = safe_downcast(state.range(0)); + const auto height_in = safe_downcast(state.range(2)); + const auto width_in = safe_downcast(state.range(3)); + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{batches_in, groups, height_in, width_in}; + + struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{groups, 1, 17, 7}; + + const auto input_cpu = + at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto in_vulkan1 = input_cpu.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto shape_match = + at::rand({1, 7, 45, 67}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_conv2d = at::native::vulkan::ops::conv2d( + out_vulkan1, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_conv2d, shape_match); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_conv2d_dw") / 1000000.0); +#endif +} + +static void sub_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::sub(in_vulkan1, in_vulkan2).cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("sub") / 1000000.0); +#endif +} + +static void sub_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_cpu1 = at::quantize_per_tensor( + in_cpu1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_sub = at::native::vulkan::ops::quantized_sub( + out_vulkan1, out_vulkan2, scale2, zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_sub, out_cpu1); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_sub") / 1000000.0); +#endif +} + +static void mul_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::mul(in_vulkan1, in_vulkan2).cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("mul") / 1000000.0); +#endif +} + +static void mul_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_cpu1 = at::quantize_per_tensor( + in_cpu1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_mul = at::native::vulkan::ops::quantized_mul( + out_vulkan1, out_vulkan2, scale2, zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_mul, out_cpu1); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_mul") / 1000000.0); +#endif +} + +static void div_op_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_out = at::div(in_vulkan1, in_vulkan2).cpu(); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("div") / 1000000.0); +#endif +} + +static void div_op_q_benchmark(benchmark::State& state) { + // Guard + if (!at::is_vulkan_available()) { + return; + } + + // Arrange + const auto batches = state.range(0); + const auto channels = state.range(1); + const auto height = state.range(2); + const auto width = state.range(3); + const auto in_cpu1 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand( + {batches, channels, height, width}, + at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan1 = in_cpu1.vulkan(); + const auto in_vulkan2 = in_cpu2.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const auto out_cpu1 = at::quantize_per_tensor( + in_cpu1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + + // Act + const double scale2 = 0.15; + const int zero_point2 = 15; + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + const auto vulkan_div = at::native::vulkan::ops::quantized_div( + out_vulkan1, out_vulkan2, scale2, zero_point2); + const auto vulkan_out = vulkan_to_cpu(vulkan_div, out_cpu1); + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed.count()); + } + +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_div") / 1000000.0); +#endif } static void CommonBenchmarkSettings(benchmark::internal::Benchmark* b) { @@ -95,13 +998,90 @@ static void CommonBenchmarkSettings(benchmark::internal::Benchmark* b) { } // namespace -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(1)->Iterations(100)->Args({3, 40, 221, 193}); // big multiple of 4 channels -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(1)->Iterations(100)->Args({3, 20, 221, 193}); // big multiple of 4 channels -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(1)->Iterations(100)->Args({3, 39, 221, 193}); // big non-multiple of 4 channels -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(1)->Iterations(100)->Args({3, 4, 221, 193}); // small multiple of 4 channels -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(1)->Iterations(100)->Args({3, 3, 221, 193}); // small non-multiple of 4 channels -BENCHMARK(cat_op_channel_perf_gpu_only)->Apply(CommonBenchmarkSettings)->UseManualTime()->Threads(3)->Iterations(100)->Args({3, 40, 221, 193}); // big multiple of 4 channels (multi-thread) -BENCHMARK(gru_op_perf)->Apply(CommonBenchmarkSettings)->Threads(1)->Iterations(100)->Args({384, 384, 2}); // McLaren Model inputs +BENCHMARK(add_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(add_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(conv2d_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({1, 17, 127, 397}); +BENCHMARK(conv2d_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({1, 17, 127, 397}); +BENCHMARK(conv2dpw_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({1, 17, 127, 397}); +BENCHMARK(conv2dpw_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({1, 17, 127, 397}); +BENCHMARK(conv2ddw_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(10) + ->Args({1, 7, 137, 199}); +BENCHMARK(conv2ddw_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(10) + ->Args({1, 7, 137, 199}); +BENCHMARK(sub_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(sub_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(mul_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(mul_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(div_op_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); +BENCHMARK(div_op_q_benchmark) + ->Apply(CommonBenchmarkSettings) + ->UseManualTime() + ->Threads(1) + ->Iterations(100) + ->Args({3, 40, 221, 193}); BENCHMARK_MAIN(); diff --git a/aten/src/ATen/test/vulkan_quantized_api_test.cpp b/aten/src/ATen/test/vulkan_quantized_api_test.cpp new file mode 100644 index 0000000000000..9519b079d35e8 --- /dev/null +++ b/aten/src/ATen/test/vulkan_quantized_api_test.cpp @@ -0,0 +1,1052 @@ +#ifdef USE_VULKAN_API + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace { + +bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { + float maxValue = 0.0f; + + for (const auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + +#ifdef USE_VULKAN_FP16_INFERENCE + constexpr float tolerance = 1e-2; +#else + constexpr float tolerance = 1e-5; +#endif + + return diff.abs().max().item() <= (tolerance * maxValue); +} + +bool almostEqual(const at::Tensor& a, const at::Tensor& b) { + return checkRtol(a - b, {a, b}); +} + +bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { + return (a - b).abs().max().item() == 0.0f; +} + +void showRtol(const at::Tensor& a, const at::Tensor& b) { + const auto diff = (a - b).abs(); + + float maxValue = a.abs().max().item(); + maxValue = fmax(b.abs().max().item(), maxValue); + +#ifdef USE_VULKAN_FP16_INFERENCE + constexpr float tolerance = 1e-2; +#else + constexpr float tolerance = 1e-5; +#endif + + const float maxDiff = maxValue * tolerance; + std::cout << "Max Diff allowed: " << maxDiff << std::endl; + if (diff.sizes().size() == 2) { + for (const auto y : c10::irange(diff.sizes()[0])) { + std::cout << y << ":"; + for (const auto x : c10::irange(diff.sizes()[1])) { + float diff_xy = diff[y][x].item(); + if (diff_xy > maxDiff) { + std::cout << std::setw(5) << x; + } else { + std::cout << std::setw(5) << " "; + } + } + std::cout << std::endl; + } + } +} + +template +inline std::vector makeStack(Inputs&&... inputs) { + return {std::forward(inputs)...}; +} + +template +inline std::vector callOpByHandle( + const c10::OperatorHandle& op, + Args... args) { + auto stack = makeStack(std::forward(args)...); + c10::Dispatcher::singleton().callBoxed(op, &stack); + return stack; +} + +template +inline std::vector callOpByName( + const char* func_name, + const char* overload_name, + Args... args) { + const c10::optional op_handle = + c10::Dispatcher::singleton().findSchema({func_name, overload_name}); + assert(op_handle.has_value()); + return callOpByHandle(op_handle.value(), std::forward(args)...); +} + +} // namespace + +namespace { + +class VulkanAPITest : public ::testing::Test { + public: + void SetUp() { + if (!at::is_vulkan_available()) { + GTEST_SKIP() << "Vulkan is not available"; + } +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + at::native::vulkan::api::context()->reset_querypool(); +#endif + } + + void TearDown() { +#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) + try { + at::native::vulkan::api::context()->querypool().extract_results(); + at::native::vulkan::api::context()->querypool().print_results(); + } catch (const std::exception& e) { + std::cout << "Could not get querypool results!" + << " Reason: " << e.what() << std::endl; + } +#endif + } +}; + +at::Tensor cpu_to_vulkan(at::Tensor in_cpu) { + auto options = in_cpu.options(); + if (options.dtype().toScalarType() == c10::ScalarType::QUInt8) { + auto ret = at::native::vulkan::ops::_empty_affine_quantized( + in_cpu.sizes(), + c10::ScalarType::QUInt8, + options.layout(), + options.device(), + options.pinned_memory(), + in_cpu.q_scale(), + in_cpu.q_zero_point(), + c10::MemoryFormat::Contiguous); + at::native::vulkan::ops::copy_(ret, in_cpu); + return ret; + } else { + auto ret = at::empty(in_cpu.sizes(), options); + at::native::vulkan::ops::copy_(ret, in_cpu); + return ret; + } +} + +at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) { + auto q_options = in_cpu.options(); + if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8) { + auto output = at::native::empty_affine_quantized( + in_cpu.sizes(), + q_options.dtype().toScalarType(), + q_options.layout(), + q_options.device(), + q_options.pinned_memory(), + in_cpu.q_scale(), + in_cpu.q_zero_point()); + at::native::vulkan::ops::copy_(output, vulkan); + return output; + } else { + auto output = at::empty(in_cpu.sizes(), q_options); + at::native::vulkan::ops::copy_(output, vulkan); + return output; + } +} + +TEST_F(VulkanAPITest, support_vulkan) { + const double scale = 0.1; + const int64_t zero_point = 10; + + auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 12 - + 6; + auto in_cpu_quantized = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + + auto in_vulkan_quantized = cpu_to_vulkan(in_cpu_quantized); + at::native::vulkan::api::PipelineBarrier pipeline_barrier{}; + at::native::vulkan::ops::vTensor& v_self = + at::native::vulkan::ops::convert(in_vulkan_quantized); + if (in_cpu.dtype() == c10::kQUInt8) { + v_self.image( + pipeline_barrier, + at::native::vulkan::api::PipelineStage::COMPUTE, + at::native::vulkan::api::MemoryAccessType::READ); + v_self.image( + pipeline_barrier, + at::native::vulkan::api::PipelineStage::COMPUTE, + at::native::vulkan::api::MemoryAccessType::WRITE); + } + auto output = vulkan_to_cpu(in_vulkan_quantized, in_cpu_quantized); + const auto check = almostEqual( + at::native::int_repr_quantized_cpu(in_cpu_quantized), + at::native::int_repr_quantized_cpu(output)); + + if (!check) { + showRtol( + at::native::int_repr_quantized_cpu(in_cpu_quantized), + at::native::int_repr_quantized_cpu(output)); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantize_per_tensor) { + const auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + + auto output_for_quantized_vulkan = vulkan_to_cpu(out_vulkan, out_cpu); + + int rtol = 1; + const auto check = at::allclose( + at::native::int_repr_quantized_cpu(out_cpu), + at::native::int_repr_quantized_cpu(output_for_quantized_vulkan), + rtol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantize_dequantize) { + const auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + // quantize tensors + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + // dequantize tensors + const auto out_cpu_deq = at::dequantize(out_cpu); + const auto out_vulkan_deq = at::native::vulkan::ops::dequantize(out_vulkan); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu); + + float rtol = 1; + float atol = 0.5; + const auto check = + at::allclose(in_cpu, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); + + const auto check_two = + at::allclose(out_cpu_deq, output_for_dequantized_vulkan, rtol, atol); + + if (!check_two) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check_two); +} + +TEST_F(VulkanAPITest, quantized_add) { + const auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_add_broadcast) { + const auto in_cpu = + at::rand({2, 13, 1, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({2, 13, 32, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto in_cpu3 = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_add_broadcast1) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = + at::rand({2, 12, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({12, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto in_cpu3 = + at::rand({2, 12, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_add_broadcast2) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = + at::rand({32, 1}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({1, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto in_cpu3 = + at::rand({32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + + +TEST_F(VulkanAPITest, quantized_add_broadcast3) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = + at::rand({32, 24}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({1}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto in_cpu3 = + at::rand({32, 24}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu3); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_add_dif_params) { + const auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + const double scale = 0.1; + const int zero_point = 10; + const double scale2 = 0.2; + const int zero_point2 = 20; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale2, zero_point2, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale2, zero_point2, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_added_tensors = callOpByName( + "quantized::add", + "", + out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_added_tensors = at::native::vulkan::ops::quantized_add( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_added_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + at::dequantize(reg_added_tensors[0].toTensor()), output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d) { + constexpr int64_t groups = 1; + constexpr std::array stride{2, 2}; + constexpr std::array padding{1, 1}; + // TODO: Support conv2d with dilation != 1 + constexpr std::array dilation{1, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{1, 3, 8, 8}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{1, input.channels, 3, 3}; + + float r1 = 0.1; + float r2 = 0.7; + const auto input_cpu = (r1 - r2) * + at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto weights_cpu = (r1 - r2) * + at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto bias_cpu = (r1 - r2) * + at::rand({weights.output_channels}, + at::device(at::kCPU).dtype(at::kFloat)) + + r2; + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto output_cpu = at::conv2d( + input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups); + + const double scale = 0.10; + const int zero_point = 10; + const auto shape_match = + at::rand({1, 1, 4, 4}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = input_cpu.vulkan(); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto output_vulkan = at::native::vulkan::ops::conv2d( + out_vulkan, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(output_vulkan); + auto output_for_dequantized_vulkan = + vulkan_to_cpu(out_vulkan_deq, shape_match); + + float rtol = 0; + float atol = 1.5; + const auto check = + at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d_pw) { + constexpr int64_t groups = 1; + constexpr std::array stride{1, 1}; + constexpr std::array padding{0, 0}; + constexpr std::array dilation{1, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{1, 17, 127, 397}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{29, input.channels, 1, 1}; + + float r1 = 0.1; + float r2 = 0.7; + const auto input_cpu = (r1 - r2) * + at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto weights_cpu = (r1 - r2) * + at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto bias_cpu = (r1 - r2) * + at::rand({weights.output_channels}, + at::device(at::kCPU).dtype(at::kFloat)) + + r2; + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto output_cpu = at::conv2d( + input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups); + + const double scale = 0.10; + const int zero_point = 10; + const auto shape_match = + at::rand({1, 29, 127, 397}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = input_cpu.vulkan(); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto output_vulkan = at::native::vulkan::ops::conv2d( + out_vulkan, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(output_vulkan); + auto output_for_dequantized_vulkan = + vulkan_to_cpu(out_vulkan_deq, shape_match); + + float rtol = 0; + float atol = 1.5; + const auto check = + at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d_dw) { + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{1, groups, 137, 199}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{groups, 1, 17, 7}; + + float r1 = 0; + float r2 = 0.2; + const auto input_cpu = (r1 - r2) * + at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto weights_cpu = (r1 - r2) * + at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto bias_cpu = (r1 - r2) * + at::rand({weights.output_channels}, + at::device(at::kCPU).dtype(at::kFloat)) + + r2; + + const double w_scale = 0.1; + const int w_zero_point = 10; + + const double b_scale = 0.1; + const int b_zero_point = 10; + + const auto weight_q = at::quantize_per_tensor( + weights_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + const auto bias_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + const auto output_cpu = at::conv2d( + input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups); + + const double scale = 0.10; + const int zero_point = 10; + const auto shape_match = + at::rand({1, 7, 45, 67}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = input_cpu.vulkan(); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale2 = 0.15; + const int zero_point2 = 15; + const auto output_vulkan = at::native::vulkan::ops::conv2d( + out_vulkan, + weight_q, + bias_q, + stride, + padding, + dilation, + groups, + scale2, + zero_point2); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(output_vulkan); + auto output_for_dequantized_vulkan = + vulkan_to_cpu(out_vulkan_deq, shape_match); + + float rtol = 0; + float atol = 1; + const auto check = + at::allclose(output_cpu, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_sub) { + float r1 = 4.0; + float r2 = 7.0; + + float r3 = 2.0; + float r4 = 5.0; + const auto in_cpu = (r1 - r2) * + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = (r3 - r4) * + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) + + r4; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const auto reg_subtracted_tensors = at::sub(in_cpu, in_cpu2); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto vulk_subtracted_tensors = at::native::vulkan::ops::quantized_sub( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_subtracted_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 0.5; + const auto check = at::allclose( + reg_subtracted_tensors, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_mul) { + const auto in_cpu = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto reg_mul_tensors = callOpByName( + "quantized::mul", "", out_cpu, out_cpu2, scale3, zero_point3); + const auto vulk_mul_tensors = at::native::vulkan::ops::quantized_mul( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_mul_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 1.5; + const auto check = at::allclose( + at::dequantize(reg_mul_tensors[0].toTensor()), + output_for_dequantized_vulkan, + rtol, + atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_div) { + float r1 = 2.0; + float r2 = 3.5; + + float r3 = 4.0; + float r4 = 5.5; + const auto in_cpu = (r1 - r2) * + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) + + r2; + const auto in_vulkan = in_cpu.vulkan(); + const auto in_cpu2 = (r3 - r4) * + at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) + + r4; + const auto in_vulkan2 = in_cpu2.vulkan(); + + const double scale = 0.1; + const int zero_point = 10; + + const auto out_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_cpu2 = at::quantize_per_tensor( + in_cpu2, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8); + + const auto reg_div_tensors = at::div(in_cpu, in_cpu2); + + const double scale3 = 0.15; + const int zero_point3 = 15; + const auto vulk_div_tensors = at::native::vulkan::ops::quantized_div( + out_vulkan, out_vulkan2, scale3, zero_point3); + + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(vulk_div_tensors); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 1; + const auto check = + at::allclose(reg_div_tensors, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, quantized_upsample_nearest2d) { + const auto in_cpu = + at::rand({2, 13, 12, 27}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6}, 1, 1); + + const double scale = 0.1; + const int zero_point = 10; + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::native::vulkan::ops::quantize_per_tensor( + in_vulkan, scale, zero_point, c10::ScalarType::QUInt8); + const auto upsample_vulkan = + at::native::vulkan::ops::quantized_upsample_nearest2d( + out_vulkan, {4, 6}, 1, 1); + + const auto in_cpu2 = + at::rand({2, 13, 4, 6}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_vulkan_deq = + at::native::vulkan::ops::dequantize(upsample_vulkan); + auto output_for_dequantized_vulkan = vulkan_to_cpu(out_vulkan_deq, in_cpu2); + + float rtol = 0; + float atol = 1; + const auto check = + at::allclose(out_cpu, output_for_dequantized_vulkan, rtol, atol); + + if (!check) { + std::cout << "Max Diff allowed: " << rtol << std::endl; + } + + ASSERT_TRUE(check); +} + +} // namespace + +#endif /* USE_VULKAN_API */ diff --git a/benchmarks/operator_benchmark/benchmark_all_other_test.py b/benchmarks/operator_benchmark/benchmark_all_other_test.py index a83eb9cb16ede..988e39f07c621 100644 --- a/benchmarks/operator_benchmark/benchmark_all_other_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_other_test.py @@ -7,6 +7,9 @@ groupnorm_test, interpolate_test, instancenorm_test, remainder_test, split_test, sum_test, tensor_to_test ) +from pt import ( # noqa: F401 + ao_sparsifier_test +) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/ao_sparsifier_test.py b/benchmarks/operator_benchmark/pt/ao_sparsifier_test.py new file mode 100644 index 0000000000000..c2ef2aadf9f1b --- /dev/null +++ b/benchmarks/operator_benchmark/pt/ao_sparsifier_test.py @@ -0,0 +1,53 @@ + +import operator_benchmark as op_bench +import torch +from torch import nn + +from torch.ao import sparsity + + +"""Microbenchmarks for sparsifier.""" + +sparse_configs_short = op_bench.config_list( + attr_names=["M", "SL", "SBS", "ZPB"], + attrs=[ + [(32, 16), 0.3, (4, 1), 2], + [(32, 16), 0.6, (1, 4), 4], + [(17, 23), 0.9, (1, 1), 1] + ], + tags=("short",) +) + +sparse_configs_long = op_bench.cross_product_configs( + M=((128, 128), (255, 324)), # Mask shape + SL=(0.0, 1.0, 0.3, 0.6, 0.9, 0.99), # Sparsity level + SBS=((1, 4), (1, 8), (4, 1), (8, 1)), # Sparse block shape + ZPB=(0, 1, 2, 3, 4, None), # Zeros per block + tags=("long",) +) + +class WeightNormSparsifierBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, SL, SBS, ZPB): + weight = torch.ones(M) + model = nn.Module() + model.register_buffer("weight", weight) + + sparse_config = [{"tensor_fqn": "weight"}] + self.sparsifier = sparsity.WeightNormSparsifier( + sparsity_level=SL, + sparse_block_shape=SBS, + zeros_per_block=ZPB, + ) + self.sparsifier.prepare(model, config=sparse_config) + self.inputs = {} # All benchmarks need inputs :) + self.set_module_name("weight_norm_sparsifier_step") + + def forward(self): + self.sparsifier.step() + +all_tests = sparse_configs_short + sparse_configs_long +op_bench.generate_pt_test(all_tests, WeightNormSparsifierBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/qarithmetic_test.py b/benchmarks/operator_benchmark/pt/qarithmetic_test.py index 01be129fe597a..b1103a8a25315 100644 --- a/benchmarks/operator_benchmark/pt/qarithmetic_test.py +++ b/benchmarks/operator_benchmark/pt/qarithmetic_test.py @@ -5,8 +5,7 @@ qarithmetic_binary_configs = op_bench.cross_product_configs( N=(2, 8, 64, 512), dtype=(torch.quint8, torch.qint8, torch.qint32), - # contig=(False, True), # TODO: Reenable this after #29435 - contig=(True,), + contig=(False, True), tags=('short',) ) diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc index df8d2d13eeda2..e8031aedbcb50 100644 --- a/benchmarks/static_runtime/deep_wide_pt_bench.cc +++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc @@ -47,7 +47,7 @@ static void BM_deep_wide_jit_graph_executor(benchmark::State& state) { std::vector inputs({ad_emb_packed, user_emb, wide}); - CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0); + TORCH_CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0); mod.forward(inputs); for (auto _ : state) { @@ -65,7 +65,7 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) { std::vector inputs({ad_emb_packed, user_emb, wide}); - CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0); + TORCH_CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0); mod.forward(inputs); for (auto _ : state) { diff --git a/benchmarks/static_runtime/test_generated_ops.cc b/benchmarks/static_runtime/test_generated_ops.cc index e0e2b58b0f425..fa2d36cd3e151 100644 --- a/benchmarks/static_runtime/test_generated_ops.cc +++ b/benchmarks/static_runtime/test_generated_ops.cc @@ -4349,7 +4349,7 @@ TEST(StaticRuntime, autogen_take_along_dim) { )IR"; auto self0 = at::rand({6, 6, 6}); - auto indices0 = at::argsort(self0, 1); + auto indices0 = at::argsort(self0, 1, true); auto dim0 = 1; std::vector args{self0, indices0, dim0}; testStaticRuntime( @@ -4361,7 +4361,7 @@ TEST(StaticRuntime, autogen_take_along_dim) { /*check_resize=*/true); auto self1 = at::rand({22, 22, 22}); - auto indices1 = at::argsort(self1, 1); + auto indices1 = at::argsort(self1, 1, true); auto dim1 = 1; std::vector args2{self1, indices1, dim1}; testStaticRuntime( @@ -7770,16 +7770,17 @@ TEST(StaticRuntime, autogen_linalg_cond) { TEST(StaticRuntime, autogen_linalg_solve) { const std::string script = R"IR( - graph(%input: Tensor, %other: Tensor): + graph(%A: Tensor, %B: Tensor, %left: bool): %bias: None = prim::Constant() - %ret = aten::linalg_solve(%input, %other) + %ret = aten::linalg_solve(%A, %B, %left) %cloned = aten::clone(%ret, %bias) return (%cloned) )IR"; - auto input0 = at::rand({6, 6, 6}); - auto other0 = at::rand({6, 6, 6}); - std::vector args{input0, other0}; + auto A0 = at::rand({6, 6, 6}); + auto B0 = at::rand({6, 6, 6}); + auto left0 = false; + std::vector args{A0, B0, left0}; testStaticRuntime( script, args, @@ -7788,9 +7789,10 @@ TEST(StaticRuntime, autogen_linalg_solve) { /*use_equalnan=*/false, /*check_resize=*/true); - auto input1 = at::rand({22, 22, 22}); - auto other1 = at::rand({22, 22, 22}); - std::vector args2{input1, other1}; + auto A1 = at::rand({22, 22, 22}); + auto B1 = at::rand({22, 22, 22}); + auto left1 = false; + std::vector args2{A1, B1, left1}; testStaticRuntime( script, args, diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc index 41758ec9f2f37..70d1d1d306939 100644 --- a/benchmarks/static_runtime/test_static_module.cc +++ b/benchmarks/static_runtime/test_static_module.cc @@ -485,7 +485,7 @@ TEST(StaticRuntime, DeepWide) { at::Tensor output_2 = outputs[0].toTensor(); smod.runtime().check_for_memory_leak(); EXPECT_TRUE( - torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7)); + torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5)); } } } @@ -513,7 +513,7 @@ TEST(StaticRuntime, KWargsAPI_1) { at::Tensor output_2 = getTensor(output_ivalue); EXPECT_TRUE( - torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7)); + torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5)); // check for output aliasing EXPECT_EQ(output_ivalue.use_count(), 1); @@ -558,7 +558,7 @@ TEST(StaticRuntime, KWargsAPI_2) { at::Tensor output_2 = getTensor(output_ivalue); EXPECT_TRUE( - torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7)); + torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5)); // check for output aliasing EXPECT_EQ(output_ivalue.use_count(), 1); @@ -636,7 +636,7 @@ TEST(StaticRuntime, CleanUpMemory) { auto output_2 = outputs[0].toTensor(); runtime.check_for_memory_leak(); EXPECT_TRUE(torch::allclose( - output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7)); + output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5)); if (manage_output_tensors) { runtime.deallocateOutputTensors(); runtime.checkOutputTensorMemoryLeaks(); @@ -882,7 +882,7 @@ TEST(StaticRuntime, FusionPass) { EXPECT_TRUE(hit); auto output_2 = getTensor(module.forward(inputs)); EXPECT_TRUE( - torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7)); + torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5)); } } } diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 9bbd387c1bb14..72ee217401ab0 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -877,15 +877,30 @@ TEST(StaticRuntime, Div) { return torch.div(a, b, rounding_mode=c).clone() )JIT"; + const auto div_strided = R"JIT( + def forward(self, a: Tensor, b: Tensor): + a_strided = torch.transpose(a, 0, 1) + b_strided = torch.transpose(b, 0, 1) + return torch.div(a_strided, b_strided).clone() + )JIT"; + auto a = at::randn({2, 3}); auto b = at::randn({2, 3}); + auto bs = at::randn({3, 2}).transpose(0, 1); auto c = at::randn({4, 3, 2}); auto d = at::randn({4, 3, 2}); + auto ds = at::randn({3, 4, 2}).transpose(0, 1); std::vector args0{a, b}; testStaticRuntime(div_tensor, args0); testStaticRuntime(div_tensor, args0, {c, d}); + testStaticRuntime(div_strided, args0); + testStaticRuntime(div_strided, args0, {c, d}); + + testStaticRuntime(div_tensor, {a, bs}); + testStaticRuntime(div_tensor, {a, bs}, {c, ds}); + std::vector args1{a, 3}; testStaticRuntime(div_scalar, args1); testStaticRuntime(div_scalar, args1, {c, 4}); diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index 7690e356adaa0..56ad619ba4a86 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -173,7 +173,7 @@ int loadInput( LOG(INFO) << "Running on GPU."; #ifdef __CUDA_ARCH__ caffe2::TensorCUDA* tensor = blob->GetMutable(); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->Resize(input_dims); if (input_type_list[i] == "uint8_t") { tensor->mutable_data(); @@ -189,17 +189,17 @@ int loadInput( if (input_type_list[i] == "uint8_t") { caffe2::int8::Int8TensorCPU* tensor = blob->GetMutable(); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->t.Resize(input_dims); tensor->t.mutable_data(); } else if (input_type_list[i] == "float") { caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->Resize(input_dims); tensor->mutable_data(); } else if (input_type_list[i] == "int") { caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->Resize(input_dims); tensor->mutable_data(); } else { @@ -495,7 +495,7 @@ int benchmark( net_def.set_name("benchmark"); } caffe2::NetBase* net = workspace->CreateNet(net_def); - CHECK_NOTNULL(net); + TORCH_CHECK_NOTNULL(net); runNetwork( workspace, net, diff --git a/binaries/compare_models_torch.cc b/binaries/compare_models_torch.cc index bf88c390799f1..7afac42589b62 100644 --- a/binaries/compare_models_torch.cc +++ b/binaries/compare_models_torch.cc @@ -74,6 +74,7 @@ C10_DEFINE_string( "cpu", "what backend to use for model (vulkan, cpu, metal) (default=cpu)"); C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison"); +C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour."); C10_DEFINE_bool( report_failures, true, @@ -232,6 +233,48 @@ std::vector create_inputs( return inputs; } +void run_check(float tolerance) { + torch::jit::Module module = torch::jit::load(FLAGS_model); + torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel); + + module.eval(); + refmodule.eval(); + + std::thread::id this_id = std::this_thread::get_id(); + std::cout << "Running check on thread " << this_id << "." << std::endl; + + int passed = 0; + for (int i = 0; i < FLAGS_iter; ++i) { + std::vector refinputs; + std::vector inputs; + create_inputs( + refinputs, inputs, + FLAGS_refbackend, FLAGS_backend, + FLAGS_input_min, FLAGS_input_max); + + const auto refoutput = refmodule.forward(refinputs).toTensor().cpu(); + const auto output = module.forward(inputs).toTensor().cpu(); + + bool check = checkRtol( + refoutput-output, + {refoutput, output}, + tolerance, + FLAGS_report_failures); + + if (check) { + passed += 1; + } + else if (FLAGS_report_failures) { + std::cout << " (Iteration " << i << " failed)" << std::endl; + } + + if (i > 0 && (i+1) % FLAGS_report_freq == 0) { + report_pass_rate(passed, i+1); + } + } + report_pass_rate(passed, FLAGS_iter); +} + int main(int argc, char** argv) { c10::SetUsageMessage( "Run accuracy comparison to a reference model for a pytorch model.\n" @@ -260,41 +303,24 @@ int main(int argc, char** argv) { c10::InferenceMode mode; torch::autograd::AutoGradMode guard(false); torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false); - auto module = torch::jit::load(FLAGS_model); - auto refmodule = torch::jit::load(FLAGS_refmodel); - - module.eval(); - refmodule.eval(); c10::CPUCachingAllocator caching_allocator; c10::optional caching_allocator_guard; if (FLAGS_use_caching_allocator) { caching_allocator_guard.emplace(&caching_allocator); } - std::cout << "Running modules." << std::endl; - - int passed = 0; - for (int i = 0; i < FLAGS_iter; ++i) { - std::vector refinputs; - std::vector inputs; - create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend, FLAGS_input_min, FLAGS_input_max); - const auto refoutput = refmodule.forward(refinputs).toTensor().cpu(); - const auto output = module.forward(inputs).toTensor().cpu(); - - bool check = checkRtol(refoutput-output, {refoutput, output}, tolerance, FLAGS_report_failures); - if (check) { - passed += 1; - if (FLAGS_report_failures && !check) { - std::cout << " (Iteration " << i << " failed)" << std::endl; - } - } + std::vector check_threads; + check_threads.reserve(FLAGS_nthreads); + for (int i = 0; i < FLAGS_nthreads; ++i) { + check_threads.emplace_back(std::thread(run_check, tolerance)); + } - if (i > 0 && (i+1) % FLAGS_report_freq == 0) { - report_pass_rate(passed, i+1); + for (std::thread& th : check_threads) { + if (th.joinable()) { + th.join(); } } - report_pass_rate(passed, FLAGS_iter); return 0; } diff --git a/binaries/convert_and_benchmark.cc b/binaries/convert_and_benchmark.cc index 06983bb8b81b6..ceb029c304749 100644 --- a/binaries/convert_and_benchmark.cc +++ b/binaries/convert_and_benchmark.cc @@ -591,7 +591,7 @@ void runNetwork( } caffe2::NetBase* net = workspace->CreateNet(net_def); - CHECK_NOTNULL(net); + TORCH_CHECK_NOTNULL(net); LOG(INFO) << "Starting benchmark."; caffe2::ObserverConfig::initSampleRate(1, 1, 1, run_individual, warmup); diff --git a/binaries/make_image_db.cc b/binaries/make_image_db.cc index b2f8135f5da2d..3bbe15a062975 100644 --- a/binaries/make_image_db.cc +++ b/binaries/make_image_db.cc @@ -251,7 +251,7 @@ void ConvertImageDataset( // Synthesize key for this entry auto key_len = snprintf( key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str()); - DCHECK_LE(key_len, sizeof(key_cstr)); + TORCH_DCHECK_LE(key_len, sizeof(key_cstr)); // Put in db transaction->Put(string(key_cstr), std::move(value)); diff --git a/binaries/speed_benchmark.cc b/binaries/speed_benchmark.cc index 00f93f474362f..3885047e61f07 100644 --- a/binaries/speed_benchmark.cc +++ b/binaries/speed_benchmark.cc @@ -136,12 +136,12 @@ int main(int argc, char** argv) { if (input_type_list[i] == "uint8_t") { caffe2::int8::Int8TensorCPU* tensor = blob->GetMutable(); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->t.Resize(input_dims); tensor->t.mutable_data(); } else if (input_type_list[i] == "float") { caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); tensor->Resize(input_dims); tensor->mutable_data(); } else { @@ -184,7 +184,7 @@ int main(int argc, char** argv) { } caffe2::NetBase* net = workspace->CreateNet(net_def); - CHECK_NOTNULL(net); + TORCH_CHECK_NOTNULL(net); CAFFE_ENFORCE(net->Run()); net->TEST_Benchmark(FLAGS_warmup, FLAGS_iter, FLAGS_run_individual); diff --git a/buckbuild.bzl b/buckbuild.bzl index 9150061912c88..40f542e3f80df 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1,39 +1,775 @@ -load("//tools/build_defs:expect.bzl", "expect") - # NOTE: This file is shared by internal and OSS BUCK build. # These load paths point to different files in internal and OSS environment + +load("@bazel_skylib//lib:paths.bzl", "paths") load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") +load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") +load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") +load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") +load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") +load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") +load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build") +load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load( ":build_variables.bzl", + "aten_cpu_source_list", + "aten_native_source_list", + "core_sources_common", + "core_sources_full_mobile_no_backend_interface", + "core_trainer_sources", "jit_core_headers", + "jit_core_sources", + "libtorch_profiler_sources", +) +load( + ":pt_ops.bzl", + "USED_PT_BACKENDS", +) +load( + ":pt_template_srcs.bzl", + "METAL_MASKRCNN_SOURCE_LIST", + "METAL_SOURCE_LIST", + "TEMPLATE_MASKRCNN_SOURCE_LIST", + "TEMPLATE_SOURCE_LIST", + "aten_ufunc_generated_all_cpu_sources", + "get_gen_oplist_outs", + "get_generate_code_bin_outs", + "get_metal_registration_files_outs", + "get_metal_registration_files_outs_windows", + "get_metal_source_dict", + "get_template_registration_file_rules", + "get_template_registration_files_outs", + "get_template_source_dict", +) +load( + ":ufunc_defs.bzl", + "aten_ufunc_generated_cpu_kernel_sources", + "aten_ufunc_generated_cpu_sources", + "aten_ufunc_generated_cuda_sources", ) -def read_bool(section, field, default): +def read_bool(section, field, default, required = True): # @lint-ignore BUCKRESTRICTEDSYNTAX - value = read_config(section, field) - if value == None: + val = read_config(section, field) + if val != None: + if val in ["true", "True", "1"]: + return True + elif val in ["false", "False", "0"]: + return False + else: + fail( + "`{}:{}`: must be one of (0, 1, true, false, True, False), but was {}".format(section, field, val), + ) + elif default != None: return default - expect( - value == "0" or value == "1", - "{}.{} == \"{}\", wanted \"0\" or \"1\".".format(section, field, value), - ) - return bool(int(value)) + elif not required: + return None + else: + fail("`{}:{}`: no value set".format(section, field)) + +def _is_build_mode_dev(): + if is_production_build_android(): + # Android Prod builds + return False + if is_production_build_ios(): + # iOS Prod builds + return False + + return True + +def _get_enable_lightweight_dispatch(): + return read_bool("pt", "enable_lightweight_dispatch", False) + +def _get_enable_record_kernel_dtype(): + return read_bool("pt", "enable_record_kernel_dtype", False) + +def get_enable_mobile_dispatch_keys_trimming(): + return read_bool("pt", "enable_mobile_dispatch_keys_trimming", False) + +def get_disable_per_op_profiling(): + return read_bool("pt", "disable_per_op_profiling", True) -def is_oss_build(): - return read_bool("pt", "is_oss", False) +def get_strip_error_messages(): + if IS_OSS: + return True # always strip in OSS CI to expose potential issues + return read_bool("pt", "strip_error_messages", not _is_build_mode_dev()) + +def get_enable_eager_symbolication(): + return read_bool("pt", "enable_eager_symbolication", default = False, required = False) + +def get_static_dispatch_backend(): + static_dispatch_backend = native.read_config("pt", "static_dispatch_backend", None) + if static_dispatch_backend == None: + return [] + return static_dispatch_backend.split(";") + +# @lint-ignore BUCKRESTRICTEDSYNTAX +IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build + +NOT_OSS = not IS_OSS # for targets in caffe2 root path -ROOT_NAME = "//" if is_oss_build() else "//xplat/caffe2" +ROOT = "//" if IS_OSS else "//xplat/caffe2" # for targets in subfolders -ROOT_PATH = "//" if is_oss_build() else "//xplat/caffe2/" +ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/" + +C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10" + +# a dictionary maps third party library name to fbsource and oss target +THIRD_PARTY_LIBS = { + "FP16": ["//xplat/third-party/FP16:FP16", "//third_party:FP16"], + "FXdiv": ["//xplat/third-party/FXdiv:FXdiv", "//third_party:FXdiv"], + "XNNPACK": ["//xplat/third-party/XNNPACK:XNNPACK", "//third_party:XNNPACK"], + "clog": ["//xplat/third-party/clog:clog", "//third_party:clog"], + "cpuinfo": ["//third-party/cpuinfo:cpuinfo", "//third_party:cpuinfo"], + "flatbuffers-api": ["//third-party/flatbuffers:flatbuffers-api", "//third_party:flatbuffers-api"], + "flatc": ["//third-party/flatbuffers:flatc", "//third_party:flatc"], + "fmt": ["//third-party/fmt:fmt", "//third_party:fmt"], + "glog": ["//third-party/glog:glog", "//third_party:glog"], + "gmock": ["//xplat/third-party/gmock:gtest", "//third_party:gmock"], + "gtest": ["//xplat/third-party/gmock:gmock", "//third_party:gtest"], + "kineto": ["//xplat/kineto/libkineto:libkineto", "//third_party:libkineto"], + "libkineto_headers": ["//xplat/kineto/libkineto:libkineto_headers", "//third_party:libkineto_headers"], + "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"], + "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"], + "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"], + "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"], + "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"], + "pyyaml": ["//third-party/pyyaml:pyyaml", "//third_party:pyyaml"], + "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"], + "ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"], + "typing-extensions": ["//third-party/typing-extensions:typing-extensions", "//third_party:typing-extensions"], +} + +def third_party(name): + if name not in THIRD_PARTY_LIBS: + fail("Cannot find third party library " + name + ", please register it in THIRD_PARTY_LIBS first!") + return THIRD_PARTY_LIBS[name][1] if IS_OSS else THIRD_PARTY_LIBS[name][0] + +def get_pt_compiler_flags(): + return select({ + "DEFAULT": _PT_COMPILER_FLAGS + [ + "-std=gnu++17", #to accomodate for eigen + ], + "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS), + }) + +_PT_COMPILER_FLAGS = [ + "-fexceptions", + "-frtti", + "-Os", + "-Wno-unknown-pragmas", + "-Wno-write-strings", + "-Wno-unused-variable", + "-Wno-unused-function", + "-Wno-deprecated-declarations", + "-Wno-shadow", + "-Wno-global-constructors", + "-Wno-missing-prototypes", +] + +ATEN_COMPILER_FLAGS = [ + "-fexceptions", + "-frtti", + "-fPIC", + "-Os", + "-Wno-absolute-value", + "-Wno-deprecated-declarations", + "-Wno-macro-redefined", + "-Wno-tautological-constant-out-of-range-compare", + "-Wno-unknown-pragmas", + "-Wno-unknown-warning-option", + "-Wno-unused-function", + "-Wno-unused-variable", + "-Wno-pass-failed", + "-Wno-shadow", +] + +def get_aten_compiler_flags(): + return ATEN_COMPILER_FLAGS + +_COMMON_PREPROCESSOR_FLAGS = [ + "-DC10_MOBILE", + "-DNO_EXPORT", +] + ( + ["-DC10_MOBILE_TRIM_DISPATCH_KEYS"] if get_enable_mobile_dispatch_keys_trimming() else [] +) + ( + ["-DSTRIP_ERROR_MESSAGES"] if get_strip_error_messages() else [] +) + +def get_aten_preprocessor_flags(): + # read_config is not allowed outside of function in Starlark + ATEN_PREPROCESSOR_FLAGS = _COMMON_PREPROCESSOR_FLAGS + [ + "-DCPU_CAPABILITY_DEFAULT", + "-DCPU_CAPABILITY=DEFAULT", + "-DCAFFE2_USE_LITE_PROTO", + "-DATEN_CUDNN_ENABLED_FBXPLAT=0", + "-DATEN_MKLDNN_ENABLED_FBXPLAT=0", + "-DATEN_NNPACK_ENABLED_FBXPLAT=0", + "-DATEN_MKL_ENABLED_FBXPLAT=0", + "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0", + "-DUSE_PYTORCH_METAL", + "-DUSE_PYTORCH_QNNPACK", + "-DUSE_XNNPACK", + "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", + "-DAT_PARALLEL_OPENMP_FBXPLAT=0", + "-DAT_PARALLEL_NATIVE_FBXPLAT=1", + "-DAT_PARALLEL_NATIVE_TBB_FBXPLAT=0", + "-DUSE_LAPACK_FBXPLAT=0", + "-DAT_BLAS_F2C_FBXPLAT=0", + "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0", + "-DUSE_RUY_QMATMUL", + ] + if get_disable_per_op_profiling(): + ATEN_PREPROCESSOR_FLAGS.append("-DPYTORCH_DISABLE_PER_OP_PROFILING") + if _get_enable_record_kernel_dtype(): + ATEN_PREPROCESSOR_FLAGS.append("-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE") + return ATEN_PREPROCESSOR_FLAGS + +def get_pt_preprocessor_flags(): + # read_config is not allowed outside of function in Starlark + PT_PREPROCESSOR_FLAGS = _COMMON_PREPROCESSOR_FLAGS + [ + "-D_THP_CORE", + "-DUSE_SCALARS", + "-DNO_CUDNN_DESTROY_HANDLE", + "-DBUILD_CAFFE2", + ] + + if _is_build_mode_dev(): + PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS") + return PT_PREPROCESSOR_FLAGS + +# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892 +PT_BACKEND_HEADERS = [ + "CPU", + "CUDA", + "CompositeExplicitAutograd", + "CompositeExplicitAutogradNonFunctional", + "CompositeImplicitAutograd", + "Meta", +] + +def get_aten_static_dispatch_backend_headers(existing_headers): + static_backends = get_static_dispatch_backend() + for backend in static_backends: + if backend != "CPU": + existing_headers["{}Functions.h".format(backend)] = ":gen_aten[{}Functions.h]".format(backend) + existing_headers["{}Functions_inl.h".format(backend)] = ":gen_aten[{}Functions_inl.h]".format(backend) + return existing_headers + +def get_aten_codegen_extra_params(backends): + extra_params = { + "force_schema_registration": True, + } + static_backends = get_static_dispatch_backend() + if static_backends: + extra_params["static_dispatch_backend"] = static_backends + extra_params["enabled_backends"] = static_backends + else: + extra_params["enabled_backends"] = backends + return extra_params + +def get_jit_codegen_params(): + return [] + +def get_unboxing_generated_files(): + srcs = [] + if _get_enable_lightweight_dispatch(): + srcs = [ + "UnboxingFunctions.h", + "UnboxingFunctions_0.cpp", + "UnboxingFunctions_1.cpp", + "UnboxingFunctions_2.cpp", + "UnboxingFunctions_3.cpp", + "UnboxingFunctions_4.cpp", + "RegisterCodegenUnboxedKernels_0.cpp", + "RegisterCodegenUnboxedKernels_1.cpp", + "RegisterCodegenUnboxedKernels_2.cpp", + "RegisterCodegenUnboxedKernels_3.cpp", + "RegisterCodegenUnboxedKernels_4.cpp", + "RegisterCodegenUnboxedKernels_5.cpp", + "RegisterCodegenUnboxedKernels_6.cpp", + "RegisterCodegenUnboxedKernels_7.cpp", + "RegisterCodegenUnboxedKernels_8.cpp", + "RegisterCodegenUnboxedKernels_9.cpp", + ] + res = {} + for file_name in srcs: + res[file_name] = [file_name] + return res + +def get_aten_generated_files(enabled_backends): + # NB: RegisterMeta counts as an optionally enabled backend, + # and is intentionally omitted from here + src_files = [ + "RegisterBackendSelect.cpp", + "RegisterCompositeImplicitAutograd.cpp", + "RegisterCompositeExplicitAutograd.cpp", + "RegisterCompositeExplicitAutogradNonFunctional.cpp", + "CompositeViewCopyKernels.cpp", + "RegisterSchema.cpp", + "Declarations.yaml", + "Functions.cpp", + "Functions.h", + "RedispatchFunctions.h", + "NativeFunctions.h", + "NativeMetaFunctions.h", + "MethodOperators.h", + "FunctionalInverses.h", + "Operators.h", + "Operators_0.cpp", + "Operators_1.cpp", + "Operators_2.cpp", + "Operators_3.cpp", + "Operators_4.cpp", + "CompositeImplicitAutogradFunctions.h", + "CompositeImplicitAutogradFunctions_inl.h", + "CompositeExplicitAutogradFunctions.h", + "CompositeExplicitAutogradFunctions_inl.h", + "CompositeExplicitAutogradNonFunctionalFunctions.h", + "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", + "core/ATenOpList.cpp", + "core/TensorBody.h", + "core/TensorMethods.cpp", + "core/aten_interned_strings.h", + "core/enum_tag.h", + ] + get_aten_derived_type_srcs(enabled_backends) + + # This is tiresome. A better strategy would be to unconditionally + # generate these files, and then only actually COMPILE them depended + # on the generated set. C'est la vie... + if "CPU" in enabled_backends: + src_files.extend(aten_ufunc_generated_cpu_sources()) + src_files.extend(aten_ufunc_generated_cpu_kernel_sources()) + if "CUDA" in enabled_backends: + # Cannot unconditionally include this, because in the Edge selective + # build CUDA is not enabled and thus the ufunc codegen for CUDA gets + # skipped + src_files.extend(aten_ufunc_generated_cuda_sources()) + + res = {} + for file_name in src_files: + res[file_name] = [file_name] + return res + +def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends): + return [ + ":{}[{}]".format(aten_rule_name, "Register" + backend + ".cpp") + for backend in enabled_backends + ] + +def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends): + return [ + ":{}[{}]".format(aten_rule_name, f) + for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] + ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends) + +def get_aten_derived_type_srcs(enabled_backends): + return [ + "Register" + derived_type + ".cpp" + for derived_type in enabled_backends + ] + [ + derived_type + "Functions.h" + for derived_type in enabled_backends + if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend() + ] + [ + derived_type + "Functions_inl.h" + for derived_type in enabled_backends + if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend() + ] + +def gen_aten_files( + name, + extra_flags = {}, + visibility = [], + compatible_with = [], + apple_sdks = None): + extra_params = [] + force_schema_registration = extra_flags.get("force_schema_registration", False) + op_registration_allowlist = extra_flags.get("op_registration_allowlist", None) + op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None) + enabled_backends = extra_flags.get("enabled_backends", None) + static_dispatch_backend = extra_flags.get("static_dispatch_backend", None) + + if force_schema_registration: + extra_params.append("--force_schema_registration") + if op_registration_allowlist != None and is_string(op_registration_allowlist): + extra_params.append("--op_registration_whitelist") + extra_params.append(op_registration_allowlist) + if op_selection_yaml_path != None and is_string(op_selection_yaml_path): + extra_params.append("--op_selection_yaml_path") + extra_params.append(op_selection_yaml_path) + if enabled_backends != None and is_list(enabled_backends): + extra_params.append("--backend_whitelist") + extra_params.extend(enabled_backends) + if _get_enable_lightweight_dispatch(): + extra_params.append("--skip_dispatcher_op_registration") + if static_dispatch_backend: + extra_params.append("--static_dispatch_backend") + extra_params.extend(static_dispatch_backend) + backends = static_dispatch_backend + else: + backends = enabled_backends + fb_xplat_genrule( + name = name, + default_outs = ["."], + outs = get_aten_generated_files(backends), + cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([ + "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), + "--install_dir $OUT", + ] + extra_params), + visibility = visibility, + compatible_with = compatible_with, + apple_sdks = apple_sdks, + ) + +def gen_aten_unboxing_files( + genrule_name, + extra_flags = {}): + extra_params = [] + op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None) + op_registration_allowlist = extra_flags.get("op_registration_allowlist", None) + if op_selection_yaml_path != None and is_string(op_selection_yaml_path): + extra_params.append("--op_selection_yaml_path") + extra_params.append(op_selection_yaml_path) + if op_registration_allowlist != None and is_string(op_registration_allowlist): + extra_params.append("--op_registration_allowlist") + extra_params.append(op_registration_allowlist) + + fb_xplat_genrule( + name = genrule_name, + default_outs = ["."], + outs = get_unboxing_generated_files(), + cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([ + "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), + "--install_dir $OUT", + ] + extra_params), + visibility = ["PUBLIC"], + ) + +def copy_template_registration_files(name, apple_sdks = None): + cmd = [] + cmd_exe = [] + + template_source_dict = get_template_source_dict() + + # Ideally, we would run one copy command for a single source directory along + # with all its child directories, but it's somewhat hard to know if a directory + # is a child of another just bu looking at the metadata (directory relative + # path) that we currently have since 1 directory could look like a parent of + # another and yet come from a different filegroup() rule. + # + for (path_prefix, file_paths) in template_source_dict.items(): + cmd.append("mkdir -p $OUT/{}".format(path_prefix)) + cmd_exe.append("md $OUT/{}".format(path_prefix)) + + # Adding *.cpp is a workaround to prevent cp from thrown an error when it + # encounters a directory (since -r was not specified). If files with an + # extension other than .cpp need to be copied, then the command below + # will not work and will need to be updated. + # + cmd.append("cp -f $(location {0}:templated_selective_build_srcs)/{1}/*.cpp $OUT/{1}/".format(ROOT, path_prefix)) + cmd_exe.append("robocopy /E $(location {0}:templated_selective_build_srcs)/{1} $OUT/{1}".format(ROOT, path_prefix)) + + if NOT_OSS: + for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST: + maskrcnn_file = "$(location //xplat/caffe2/fb/custom_ops/maskrcnn:templated_selective_build_srcs)/" + file_path + cmd.append("cp -f " + maskrcnn_file + " $OUT") + cmd_exe.append("copy " + maskrcnn_file + " $OUT") + + cmd.append("mkdir -p $OUT/aten/src/ATen") + cmd_exe.append("md $OUT/aten/src/ATen") + + # NB: CUDA is skipped here because this is selective build and CUDA is not + # supported for selective build + for ufunc_file in aten_ufunc_generated_all_cpu_sources("$(location " + ROOT + ":gen_aten[{}])"): + cmd.append("cp -f " + ufunc_file + " $OUT/aten/src/ATen") + cmd_exe.append("copy " + ufunc_file + " $OUT/aten/src/ATen") + + if NOT_OSS: + pvd_batch_box_cox_file = "$(location //xplat/caffe2/fb/custom_ops/batch_box_cox:templated_selective_build_srcs)/register_batch_box_cox_ops.cpp" + cmd.append("cp -f " + pvd_batch_box_cox_file + " $OUT") + cmd_exe.append("copy " + pvd_batch_box_cox_file + " $OUT") + + fb_xplat_genrule( + name = name, + cmd = " && ".join(cmd), + cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)), + outs = get_template_registration_files_outs(IS_OSS), + default_outs = ["."], + apple_sdks = apple_sdks, + ) + +def pt_operator_query_codegen( + name, + deps = [], + train = False, + enforce_traced_op_list = False, + pt_allow_forced_schema_registration = True, + compatible_with = [], + apple_sdks = None): + oplist_dir_name = name + "_pt_oplist" + + # @lint-ignore BUCKLINT + fb_native.genrule( + name = oplist_dir_name, + cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) + + "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " + + ("" if enforce_traced_op_list else "--allow_include_all_overloads ") + + "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])), + outs = get_gen_oplist_outs(), + default_outs = ["."], + compatible_with = compatible_with, + ) + + # Aten files + aten_genrule = name + "_aten" + extra_flags = { + "enabled_backends": USED_PT_BACKENDS, + "op_selection_yaml_path": "$(location :{}[selected_operators.yaml])".format(oplist_dir_name), + } + + if train and pt_allow_forced_schema_registration: + extra_flags["force_schema_registration"] = True + + unboxing_genrule = name + "_unboxing" + if _get_enable_lightweight_dispatch(): + gen_aten_unboxing_files( + unboxing_genrule, + extra_flags = extra_flags, + ) + + static_dispatch_backend = get_static_dispatch_backend() + if static_dispatch_backend: + extra_flags["static_dispatch_backend"] = static_dispatch_backend + + gen_aten_files( + aten_genrule, + extra_flags = extra_flags, + compatible_with = compatible_with, + apple_sdks = apple_sdks, + ) + + # unboxing_wrappers files + extra_params = [ + "--operators_yaml_path", + "$(location :" + oplist_dir_name + "[selected_operators.yaml])", + ] + unboxing_and_autograd_genrule = name + "_unboxing_and_autograd" + gen_aten_libtorch_files( + unboxing_and_autograd_genrule, + extra_params, + compatible_with, + apple_sdks = apple_sdks, + ) + + # Template runtime files (prim ops, etc) + template_registration_genrule = name + "_template_registration" + copy_template_registration_files(template_registration_genrule, apple_sdks = apple_sdks) + + # Files needed for metal + if NOT_OSS: + metal_genrule = name + "_metal" + copy_metal(metal_genrule, apple_sdks = apple_sdks) + + srcs = get_aten_selective_cpp_rules( + aten_genrule, + static_dispatch_backend if static_dispatch_backend else USED_PT_BACKENDS, + ) + get_template_registration_file_rules(template_registration_genrule, IS_OSS) + ([ + ":{}[autograd/generated/VariableType_0.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/VariableType_1.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/VariableType_2.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/VariableType_3.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/VariableType_4.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/ADInplaceOrViewType_0.cpp]".format(unboxing_and_autograd_genrule), + ":{}[autograd/generated/ADInplaceOrViewType_1.cpp]".format(unboxing_and_autograd_genrule), + ] if train else []) + ([ + ":{}[SupportedMobileModelsRegistration.cpp]".format(oplist_dir_name), + ] if NOT_OSS else []) + + headers = { + "selected_mobile_ops.h": ":{}[selected_mobile_ops.h]".format(oplist_dir_name), + } + + if _get_enable_lightweight_dispatch(): + srcs.extend([ + ":{}[UnboxingFunctions_0.cpp]".format(unboxing_genrule), + ":{}[UnboxingFunctions_1.cpp]".format(unboxing_genrule), + ":{}[UnboxingFunctions_2.cpp]".format(unboxing_genrule), + ":{}[UnboxingFunctions_3.cpp]".format(unboxing_genrule), + ":{}[UnboxingFunctions_4.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_0.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_1.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_2.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_3.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_4.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_5.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_6.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_7.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_8.cpp]".format(unboxing_genrule), + ":{}[RegisterCodegenUnboxedKernels_9.cpp]".format(unboxing_genrule), + ]) + headers["UnboxingFunctions.h"] = ":{}[UnboxingFunctions.h]".format(unboxing_genrule) + return {"headers": headers, "srcs": srcs} + +def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple_sdks = None): + fb_xplat_genrule( + name = name, + outs = get_generate_code_bin_outs(), + default_outs = ["."], + bash = "mkdir -p tools && " + + "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join( + # Mobile build only needs libtorch - skip python bindings for now, except + # for ovrsource, which needs Python bindings. + (["--subset libtorch"] if not is_arvr_mode() else []) + [ + "--native-functions-path $(location {}:aten_src_path)/aten/src/ATen/native/native_functions.yaml".format(ROOT), + "--tags-path $(location {}:aten_src_path)/aten/src/ATen/native/tags.yaml".format(ROOT), + "--install_dir $OUT", + ] + extra_params, + ), + cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + + "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join( + # Mobile build only needs libtorch - skip python bindings for now, except + # for ovrsource, which needs Python bindings. + (["--subset libtorch"] if not is_arvr_mode() else []) + [ + "--native-functions-path $(location {}:aten_src_path)/aten/src/ATen/native/native_functions.yaml".format(ROOT), + "--tags-path $(location {}:aten_src_path)/aten/src/ATen/native/tags.yaml".format(ROOT), + "--install_dir $OUT", + ] + extra_params, + ), + compatible_with = compatible_with, + apple_sdks = apple_sdks, + ) + +def copy_metal(name, apple_sdks = None): + cmd = [] + cmd_exe = [] + metal_source_dict = get_metal_source_dict() + + # Copy all source files over to bring them into the per app build + for path_prefix in sorted(metal_source_dict.keys()): + cmd.append("mkdir -p $OUT/{}".format(path_prefix)) + cmd_exe.append("mkdir -Force $OUT/{0}".format(path_prefix)) + + # Not every directory has a mm or cpp file so '2>/dev/null || :' are tricks to suppress the error messages and codes. + cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix)) + cmd.append("cp -f {0}/{1}/*.cpp $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix)) + + # Robocopy has a default success code of 1 which buck treats as failure so the echo masks that problem + cmd_exe.append("(robocopy /E /NFL /NDL /NJH /NJS {0}/{1} $OUT/{1}) || ECHO robocopy failed".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix)) + + # Metal custom ops currently have to be brought into selective build because they directly reference metal ops instead of + # going through the dispatcher. There is some weird issues with the genrule and these files locations on windows though, so + # for now we simply skip building them for windows where they very likely arent needed anyway. + # Metal MaskRCNN custom op + for full_path in METAL_MASKRCNN_SOURCE_LIST: + path_prefix = paths.dirname(full_path) + cmd.append("mkdir -p $OUT/{}".format(path_prefix)) + cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2/fb/metal:metal_maskrcnn_sources)", path_prefix)) + + # Unet Metal Prepack Custom op + unet_metal_prepack_file = "$(location //xplat/caffe2/fb/custom_ops/unet_metal_prepack:unet_metal_prepack_sources)" + cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.cpp" + " $OUT") + cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.mm" + " $OUT") + + fb_xplat_genrule( + name = name, + cmd = " && ".join(cmd), + cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)), + # due to an obscure bug certain custom ops werent being copied correctly on windows. ARVR also sometimes builds android targets on windows, + # so we just exclude those targets from being copied for those platforms (They end up uncompiled anyway). + outs = select({ + "DEFAULT": get_metal_registration_files_outs(), + "ovr_config//os:android": get_metal_registration_files_outs_windows(), + "ovr_config//os:windows": get_metal_registration_files_outs_windows(), + }), + default_outs = ["."], + apple_sdks = apple_sdks, + ) + +def get_pt_operator_registry_dict( + name, + deps = [], + train = False, + labels = [], + env = [], + template_select = True, + enforce_traced_op_list = False, + pt_allow_forced_schema_registration = True, + enable_flatbuffer = False, + **kwargs): + code_gen_files = pt_operator_query_codegen( + name, + deps = deps, + train = train, + enforce_traced_op_list = enforce_traced_op_list, + pt_allow_forced_schema_registration = pt_allow_forced_schema_registration, + compatible_with = kwargs.get("compatible_with", []), + apple_sdks = kwargs.get("apple_sdks"), + ) + + return dict( + srcs = code_gen_files["srcs"], + linker_flags = [ + "-Wl,--no-as-needed", + ], + # @lint-ignore BUCKLINT link_whole + link_whole = True, + soname = "libtorch-code-gen.$(ext)", + header_namespace = "ATen", + compiler_flags = get_aten_compiler_flags(), + exported_headers = code_gen_files["headers"], + exported_preprocessor_flags = get_aten_preprocessor_flags() + (["-DTEMPLATE_SELECTIVE_BUILD"] if template_select else []), + headers = kwargs.pop("headers", []), + labels = kwargs.pop("labels", []) + [ + # This library has multiple sources with the same file name + # and does not work with Buck filegroup used in bad practices. + # Opt out of the bad practices check with the below label. + "bad_practices_ignore_override", + "pt_operator_registry", + ], + deps = [ + # need absolute path here + ROOT + ":torch_mobile_core", + ROOT + ":aten_cpu", + ROOT + ":aten_metal_prepack_header", + third_party("glog"), + C10, + ] + ([ROOT + ":torch_mobile_train"] if train else []) + + ([ROOT + ":torch_flatbuffer_all"] if enable_flatbuffer else []), + **kwargs + ) # these targets are shared by internal and OSS BUCK def define_buck_targets( - feature = None, + aten_default_args = dict(), + pt_xplat_cxx_library = fb_xplat_cxx_library, + c2_fbandroid_xplat_compiler_flags = [], labels = []): + # @lint-ignore BUCKLINT + fb_native.filegroup( + name = "metal_build_srcs", + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob(METAL_SOURCE_LIST), + visibility = [ + "PUBLIC", + ], + ) + + # @lint-ignore BUCKLINT + fb_native.filegroup( + name = "templated_selective_build_srcs", + # NB: no glob here, there are generated targets in this list! + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob(TEMPLATE_SOURCE_LIST) + aten_ufunc_generated_all_cpu_sources(":gen_aten[{}]"), + visibility = [ + "PUBLIC", + ], + ) + fb_xplat_cxx_library( name = "th_header", header_namespace = "", @@ -53,7 +789,6 @@ def define_buck_targets( ("aten/src", "THNN/generic/*.h"), ("aten/src", "THNN/generic/*.c"), ]), - feature = feature, labels = labels, ) @@ -93,7 +828,6 @@ def define_buck_targets( ("aten/src", "ATen/native/mkldnn/*.h"), ]), visibility = ["PUBLIC"], - feature = feature, labels = labels, ) @@ -106,7 +840,6 @@ def define_buck_targets( ("aten/src", "ATen/native/vulkan/ops/*.h"), ("aten/src", "ATen/vulkan/*.h"), ]), - feature = feature, labels = labels, visibility = ["PUBLIC"], ) @@ -115,7 +848,6 @@ def define_buck_targets( name = "jit_core_headers", header_namespace = "", exported_headers = subdir_glob([("", x) for x in jit_core_headers]), - feature = feature, labels = labels, ) @@ -143,7 +875,6 @@ def define_buck_targets( "torch/csrc/jit/serialization/mobile_bytecode_generated.h", ], ), - feature = feature, labels = labels, visibility = ["PUBLIC"], deps = [ @@ -159,6 +890,16 @@ def define_buck_targets( ]), ) + fb_xplat_cxx_library( + name = "aten_metal_prepack_header", + header_namespace = "", + exported_headers = subdir_glob([ + ("aten/src", "ATen/native/metal/MetalPrepackOpContext.h"), + ]), + labels = labels, + visibility = ["PUBLIC"], + ) + fb_xplat_cxx_library( name = "torch_mobile_headers", header_namespace = "", @@ -167,7 +908,6 @@ def define_buck_targets( ("", "torch/csrc/jit/mobile/*.h"), ], ), - feature = feature, labels = labels, visibility = ["PUBLIC"], ) @@ -178,7 +918,6 @@ def define_buck_targets( exported_headers = { "Config.h": ":generate_aten_config[Config.h]", }, - feature = feature, labels = labels, ) @@ -192,7 +931,6 @@ def define_buck_targets( # Don't build python bindings on mobile. #"python_functions.h", }, - feature = feature, labels = labels, visibility = ["PUBLIC"], ) @@ -203,7 +941,6 @@ def define_buck_targets( exported_headers = { "version.h": ":generate-version-header[version.h]", }, - feature = feature, labels = labels, ) @@ -214,7 +951,7 @@ def define_buck_targets( "torch/csrc/api/include/torch/version.h.in", "version.txt", ], - cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([ + cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([ "--template-path", "torch/csrc/api/include/torch/version.h.in", "--version-path", @@ -240,3 +977,1162 @@ def define_buck_targets( "PUBLIC", ], ) + + fb_xplat_cxx_library( + name = "common_core", + srcs = [ + "caffe2/core/common.cc", + ], + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = get_pt_compiler_flags(), + labels = labels, + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + windows_preferred_linkage = "static" if is_arvr_mode() else None, + deps = [ + ":caffe2_headers", + C10, + ], + ) + + # @lint-ignore BUCKLINT + fb_native.genrule( + name = "generate_aten_config", + srcs = [ + "aten/src/ATen/Config.h.in", + ], + cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([ + "--install_dir", + "$OUT", + "--input-file", + "aten/src/ATen/Config.h.in", + "--output-file", + "Config.h", + "--replace", + "@AT_MKLDNN_ENABLED@", + "ATEN_MKLDNN_ENABLED_FBXPLAT", + "--replace", + "@AT_MKL_ENABLED@", + "ATEN_MKL_ENABLED_FBXPLAT", + "--replace", + "@AT_MKL_SEQUENTIAL@", + "ATEN_MKL_SEQUENTIAL_FBXPLAT", + "--replace", + "@AT_FFTW_ENABLED@", + "0", + "--replace", + "@AT_POCKETFFT_ENABLED@", + "1", + "--replace", + "@AT_NNPACK_ENABLED@", + "ATEN_NNPACK_ENABLED_FBXPLAT", + "--replace", + "@CAFFE2_STATIC_LINK_CUDA_INT@", + "CAFFE2_STATIC_LINK_CUDA_FBXPLAT", + "--replace", + "@AT_BUILD_WITH_BLAS@", + "USE_BLAS_FBXPLAT", + "--replace", + "@AT_PARALLEL_OPENMP@", + "AT_PARALLEL_OPENMP_FBXPLAT", + "--replace", + "@AT_PARALLEL_NATIVE@", + "AT_PARALLEL_NATIVE_FBXPLAT", + "--replace", + "@AT_PARALLEL_NATIVE_TBB@", + "AT_PARALLEL_NATIVE_TBB_FBXPLAT", + "--replace", + "@AT_BUILD_WITH_LAPACK@", + "USE_LAPACK_FBXPLAT", + "--replace", + "@AT_BLAS_F2C@", + "AT_BLAS_F2C_FBXPLAT", + "--replace", + "@AT_BLAS_USE_CBLAS_DOT@", + "AT_BLAS_USE_CBLAS_DOT_FBXPLAT", + ]), + outs = { + "Config.h": ["Config.h"], + }, + default_outs = ["."], + ) + + gen_aten_files( + name = "gen_aten", + extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS), + visibility = ["PUBLIC"], + ) + + gen_aten_libtorch_files(name = "gen_aten_libtorch") + + gen_aten_libtorch_files( + name = "gen_aten_libtorch_lite", + extra_params = get_jit_codegen_params(), + ) + + fb_xplat_cxx_library( + name = "generated_aten_headers_cpu", + header_namespace = "ATen", + exported_headers = get_aten_static_dispatch_backend_headers({ + "CPUFunctions.h": ":gen_aten[CPUFunctions.h]", + "CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]", + "CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]", + "CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]", + "CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]", + "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]", + "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]", + "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]", + "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]", + "Functions.h": ":gen_aten[Functions.h]", + "MethodOperators.h": ":gen_aten[MethodOperators.h]", + "NativeFunctions.h": ":gen_aten[NativeFunctions.h]", + "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", + "Operators.h": ":gen_aten[Operators.h]", + "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", + "core/TensorBody.h": ":gen_aten[core/TensorBody.h]", + "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", + "core/enum_tag.h": ":gen_aten[core/enum_tag.h]", + }), + labels = labels, + ) + + fb_xplat_cxx_library( + name = "torch_mobile_observer", + srcs = [ + "torch/csrc/jit/mobile/observer.cpp", + ] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]), + compiler_flags = ["-fexceptions"], + header_namespace = "", + exported_headers = subdir_glob( + [ + ("", "torch/csrc/jit/mobile/observer.h"), + ] + ([] if IS_OSS else [ + ("", "torch/fb/observers/ObserverUtil.h"), + ("", "torch/fb/observers/MobileObserverUtil.h"), + ]), + ), + fbobjc_compiler_flags = [ + "-Wno-missing-prototypes", + ], + labels = labels, + visibility = ["PUBLIC"], + deps = [ + C10, + ], + ) + + # Base library shared by lite-interpreter and full-jit. + pt_xplat_cxx_library( + name = "torch_common", + srcs = core_sources_common, + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + deps = [ + ":aten_cpu", + ":generated-autograd-headers", + ":torch_headers", + C10, + third_party("libkineto_headers"), + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_deserialize_common", + srcs = [ + "torch/csrc/jit/mobile/parse_bytecode.cpp", + "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/upgrader_mobile.cpp", + "torch/csrc/jit/serialization/import_read.cpp", + "torch/csrc/jit/serialization/unpickler.cpp", + ], + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/serialization/import_read.h", + "torch/csrc/jit/serialization/unpickler.h", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + extra_flags = { + "fbandroid_compiler_flags": ["-frtti"], + }, + # torch_mobile_deserialize brings in sources neccessary to read a module + # which depends on mobile module definition + # link_whole is enable so that all symbols neccessary for mobile module are compiled + # instead of only symbols used while loading; this prevents symbol + # found definied in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = ["-Wl,--no-as-needed"], + visibility = ["PUBLIC"], + exported_deps = [ + ":aten_cpu", + ":caffe2_headers", + ":caffe2_serialize", + ":torch_common", + ":torch_headers", + ":torch_mobile_headers", + ":torch_mobile_module", + ":torch_mobile_observer", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_module", + srcs = [ + "torch/csrc/jit/mobile/function.cpp", + "torch/csrc/jit/mobile/interpreter.cpp", + "torch/csrc/jit/mobile/module.cpp", + ], + header_namespace = "", + exported_headers = [ + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + extra_flags = { + "fbandroid_compiler_flags": ["-frtti"], + }, + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":aten_cpu", + ":caffe2_headers", + ":torch_common", + ":torch_headers", + ":torch_mobile_headers", + ":torch_mobile_observer", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_debug_symbolication", + srcs = [ + # included in aten_cpu "torch/csrc/jit/frontend/source_range.cpp", + "torch/csrc/jit/ir/scope.cpp", + "torch/csrc/jit/mobile/debug_info.cpp", + "torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp", + "torch/csrc/jit/serialization/source_range_serialization.cpp", + "torch/csrc/jit/serialization/pickle.cpp", + # pickler.cpp doesn't seem to be needed. + # "torch/csrc/jit/serialization/pickler.cpp", + # included in core_sources_common "torch/csrc/jit/serialization/unpickler.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + header_namespace = "", + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":torch_mobile_deserialize", + ], + exported_deps = [ + ":torch_common", + ], + ) + + pt_xplat_cxx_library( + name = "torch_model_tracer", + srcs = [ + "torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.cpp", + "torch/csrc/jit/mobile/model_tracer/CustomClassTracer.cpp", + "torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.cpp", + "torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.cpp", + "torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp", + ], + header_namespace = "", + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":generated-autograd-headers", + ":torch_mobile_deserialize", + ":torch_mobile_headers", + ":torch_mobile_observer", + ] + ([] if IS_OSS else ["//xplat/folly:molly"]), + exported_deps = [ + ":aten_cpu", + ":torch_common", + ] + ([] if IS_OSS else [ + "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox", + "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn", + ]), + ) + + pt_xplat_cxx_library( + name = "torch_mobile_deserialize", + srcs = [ + "torch/csrc/jit/mobile/import.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/mobile/import.h", + ], + # torch_mobile_deserialize brings in sources neccessary to read a module + # which depends on mobile module definition + # link_whole is enable so that all symbols neccessary for mobile module are compiled + # instead of only symbols used while loading; this prevents symbol + # found definied in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":aten_cpu", + ":caffe2_headers", + ":caffe2_serialize", + ":torch_common", + ":torch_headers", + ":torch_mobile_headers", + ":torch_mobile_module", + ":torch_mobile_observer", + ":torch_mobile_deserialize_common", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_core", + srcs = [], + header_namespace = "", + exported_headers = [], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + # torch_mobile_core brings in sources neccessary to read and run a module + # link_whole is enabled so that all symbols linked + # operators, registerations and other few symbols are need in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":generated-autograd-headers", + ":torch_mobile_headers", + ":torch_mobile_observer", + ], + exported_deps = [ + ":aten_cpu", + ":torch_common", + ":torch_mobile_deserialize", + ":torch_supported_mobile_models", + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_core_pickle_and_flatbuffer", + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + visibility = ["PUBLIC"], + exported_deps = [ + ":torch_flatbuffer_all", + ":torch_mobile_core", + ], + ) + + pt_xplat_cxx_library( + name = "torch_core", + srcs = core_sources_full_mobile_no_backend_interface + [ + "torch/csrc/api/src/jit.cpp", + "torch/csrc/jit/serialization/export_bytecode.cpp", + "torch/csrc/jit/serialization/export_module.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + visibility = [ + "//xplat/caffe2/android/...", + "//xplat/caffe2/fb/...", + "//xplat/caffe2/fb/model_tracer/...", + ], + deps = [ + ":aten_cpu", + ":backend_interface_lib", + ":generated-autograd-headers", + ":torch_headers", + ":torch_mobile_deserialize", + third_party("glog"), + third_party("rt"), + C10, + ] + ([] if IS_OSS else [ + "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox", + "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn", + ]), + exported_deps = [ + ":torch_common", + ":torch_mobile_train", + ], + ) + + pt_xplat_cxx_library( + name = "torch_train", + srcs = [ + "torch/csrc/api/src/data/samplers/random.cpp", + "torch/csrc/api/src/data/samplers/sequential.cpp", + "torch/csrc/api/src/optim/optimizer.cpp", + "torch/csrc/api/src/optim/serialize.cpp", + "torch/csrc/api/src/optim/sgd.cpp", + "torch/csrc/api/src/serialize/input-archive.cpp", + "torch/csrc/api/src/serialize/output-archive.cpp", + "torch/csrc/jit/api/module_save.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + visibility = ["PUBLIC"], + deps = [ + ":aten_cpu", + ":torch_headers", + ":torch", + ":torch_core", + ":torch_mobile_deserialize", + ":torch_mobile_train", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_train", + srcs = core_trainer_sources + [ + "torch/csrc/autograd/VariableTypeManual.cpp", + "torch/csrc/autograd/FunctionsManual.cpp", + "torch/csrc/api/src/data/datasets/mnist.cpp", + "torch/csrc/jit/mobile/train/export_data.cpp", + "torch/csrc/jit/mobile/train/optim/sgd.cpp", + "torch/csrc/jit/mobile/train/random.cpp", + "torch/csrc/jit/mobile/train/sequential.cpp", + ":gen_aten_libtorch[autograd/generated/Functions.cpp]", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], + # torch_mobile_train brings in sources neccessary to read and run a mobile + # and save and load mobile params along with autograd + # link_whole is enabled so that all symbols linked + # operators, registerations and autograd related symbols are need in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + deps = [ + ":aten_cpu", + ":generated-autograd-headers", + ":torch_headers", + ":torch_mobile_deserialize", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch", + srcs = [ + "torch/csrc/jit/runtime/register_c10_ops.cpp", + "torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + # torch brings in all sources neccessary to read and run a mobile module/jit module + # link_whole is enabled so that all symbols linked + # operators, registerations and other few symbols are need in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + deps = [ + # This is to have autograd profiler available + # in xplat/caffe2:torch which some builds are using + # notable xplate/facegen:testsAndroid + ":torch_headers", + ":torch_kineto_profiling", + ], + exported_deps = [ + ":aten_cpu", + ":torch_core", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_train_import_data", + srcs = [ + "torch/csrc/jit/mobile/import_data.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], + # torch_mobile_train_import_data brings in sources neccessary to read a mobile module + # link_whole is enabled so that all symbols linked + # operators other few symbols are need in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + deps = [ + ":torch_headers", + ":torch_mobile_observer", + ":torch_mobile_core", + ":torch_mobile_train", + ], + ) + + fb_xplat_cxx_library( + name = "torch_mobile_compatibility", + srcs = [ + # These .cpp brought in through core_sources_common + # "torch/csrc/jit/mobile/compatibility/runtime_compatibility.cpp", + # "torch/csrc/jit/serialization/unpickler.cpp", + "torch/csrc/jit/mobile/compatibility/model_compatibility.cpp", + "torch/csrc/jit/serialization/pickle.cpp", + "torch/csrc/jit/serialization/pickler.cpp", + ], + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/mobile/compatibility/backport.h", + "torch/csrc/jit/mobile/compatibility/backport_manager.h", + "torch/csrc/jit/mobile/compatibility/model_compatibility.h", + "torch/csrc/jit/mobile/compatibility/runtime_compatibility.h", + ], + compiler_flags = [ + "-fexceptions", + "-frtti", + "-Wno-deprecated-declarations", + "-Wno-global-constructors", + ], + labels = labels, + visibility = ["PUBLIC"], + deps = [ + ":torch_mobile_deserialize", + ], + ) + + pt_xplat_cxx_library( + name = "jit_module_saving", + srcs = [ + "torch/csrc/jit/api/module_save.cpp", + "torch/csrc/jit/serialization/export_bytecode.cpp", + "torch/csrc/jit/serialization/export_module.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + exported_headers = [ + "torch/csrc/jit/serialization/export.h", + "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h", + ], + visibility = ["PUBLIC"], + deps = [ + ":torch", + ":torch_mobile_core", + ], + ) + + pt_xplat_cxx_library( + name = "torch_mobile_model_tracer", + srcs = [ + "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp", + "torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp", + ], + headers = [ + "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h", + "torch/csrc/jit/mobile/model_tracer/TensorUtils.h", + ], + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h", + ], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + # torch_mobile_model_tracer brings in sources neccessary to read and run a jit module + # and trace the ops + # link_whole is enabled so that all symbols linked + # operators, registerations and other few symbols are need in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":caffe2_serialize", + ":generated-autograd-headers", + ":torch_mobile_headers", + ":torch_mobile_observer", + ":torch_mobile_core", + ] + ([] if IS_OSS else ["//xplat/folly:molly"]), + exported_deps = [ + ":aten_cpu", + ":torch_common", + ] + ([] if IS_OSS else [ + "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox", + "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn", + "//xplat/caffe2/fb/custom_ops/sparsenn:sparsenn-all", + ]), + ) + + pt_xplat_cxx_library( + name = "torch_mobile_core_flatbuffer", + srcs = [], + header_namespace = "", + exported_headers = [], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":generated-autograd-headers", + ":torch_mobile_headers", + ":torch_mobile_observer", + ], + exported_deps = [ + ":aten_cpu", + ":torch_common", + ] + ([] if IS_OSS else [ + "//xplat/caffe2/fb/runtime:torch_mobile_deserialize_flatbuffer", + ]), + ) + + fb_xplat_cxx_library( + name = "backend_interface_lib", + srcs = [ + "torch/csrc/jit/backends/backend_debug_info.cpp", + "torch/csrc/jit/backends/backend_interface.cpp", + ], + compiler_flags = get_pt_compiler_flags(), + fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags, + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":aten_cpu", + ":torch_common", + ], + ) + + pt_xplat_cxx_library( + name = "torch_kineto_profiling", + srcs = libtorch_profiler_sources, + compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], + exported_preprocessor_flags = get_pt_preprocessor_flags() + [ + "-DUSE_KINETO", + "-DUSE_KINETO_UPDATED", + # Need this otherwise USE_KINETO is undefed + # for mobile + "-DEDGE_PROFILER_USE_KINETO", + ], + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + third_party("glog"), + third_party("kineto"), + ], + exported_deps = [ + ":aten_cpu", + ":torch_common", + ], + ) + + pt_xplat_cxx_library( + name = "torch_edge_profiling", + srcs = ["torch/csrc/jit/mobile/profiler_edge.cpp"], + compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], + exported_preprocessor_flags = get_pt_preprocessor_flags() + [ + "-DUSE_KINETO", + "-DUSE_KINETO_UPDATED", + "-DEDGE_PROFILER_USE_KINETO", + ], + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":torch_common", + ":torch_kineto_profiling", + ":torch_mobile_core", + ], + ) + + fb_xplat_genrule( + name = "mobile_bytecode_header", + srcs = [ + "torch/csrc/jit/serialization/mobile_bytecode.fbs", + ], + outs = { + "mobile_bytecode_generated.h": ["mobile_bytecode_generated.h"], + }, + cmd = "$(exe {})".format(third_party("flatc")) + + " --cpp --gen-mutable --scoped-enums -o ${OUT} ${SRCS}", + default_outs = ["."], + ) + + fb_xplat_cxx_library( + name = "mobile_bytecode", + header_namespace = "", + exported_headers = { + "torch/csrc/jit/serialization/mobile_bytecode_generated.h": ":mobile_bytecode_header[mobile_bytecode_generated.h]", + }, + exported_deps = [ + third_party("flatbuffers-api"), + ], + ) + + fb_xplat_cxx_library( + name = "flatbuffer_serializer", + srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer.cpp"], + exported_headers = [ + "torch/csrc/jit/serialization/flatbuffer_serializer.h", + ], + compiler_flags = [ + "-g0", + "-O3", + "-fexceptions", + "-frtti", + "-Wno-deprecated-declarations", + ], + visibility = ["PUBLIC"], + deps = [ + ":torch_mobile_module", + C10, + ], + exported_deps = [ + ":flatbuffer_loader", + ":mobile_bytecode", + ":torch_mobile_train", + third_party("flatbuffers-api"), + ], + ) + + pt_xplat_cxx_library( + name = "flatbuffer_loader", + srcs = [ + "torch/csrc/jit/mobile/flatbuffer_loader.cpp", + ], + exported_headers = [ + "torch/csrc/jit/mobile/flatbuffer_loader.h", + ], + compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], + exported_preprocessor_flags = get_pt_preprocessor_flags() + [ + "-DUSE_KINETO", + "-DUSE_KINETO_UPDATED", + # Need this otherwise USE_KINETO is undefed + # for mobile + "-DEDGE_PROFILER_USE_KINETO", + ], + extra_flags = { + "fbandroid_compiler_flags": ["-frtti"], + }, + # torch_mobile_deserialize brings in sources neccessary to read a module + # which depends on mobile module definition + # link_whole is enable so that all symbols neccessary for mobile module are compiled + # instead of only symbols used while loading; this prevents symbol + # found definied in runtime + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":mobile_bytecode", + ":torch_mobile_deserialize", + third_party("flatbuffers-api"), + C10, + ], + ) + + fb_xplat_cxx_library( + name = "flatbuffer_serializer_jit", + srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp"], + exported_headers = [ + "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h", + ], + compiler_flags = [ + "-g0", + "-O3", + "-fexceptions", + "-frtti", + "-Wno-deprecated-declarations", + ], + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [ + ":flatbuffer_loader", + ":flatbuffer_serializer", + ":mobile_bytecode", + ":torch_core", + ":torch_mobile_module", + third_party("flatbuffers-api"), + C10, + ], + ) + + fb_xplat_cxx_library( + name = "torch_flatbuffer_all", + visibility = ["PUBLIC"], + exported_deps = [ + ":flatbuffer_loader", + ":flatbuffer_serializer", + ":flatbuffer_serializer_jit", + ], + ) + + pt_xplat_cxx_library( + name = "torch_supported_mobile_models", + srcs = [ + "fb/supported_mobile_models/SupportedMobileModels.cpp", + ] if NOT_OSS else [], + header_namespace = "", + exported_headers = ["fb/supported_mobile_models/SupportedMobileModels.h"] if NOT_OSS else [], + compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + deps = [], + exported_deps = [ + "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox", + "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn", + ] if NOT_OSS else [], + ) + + fb_xplat_cxx_library( + name = "static_runtime", + srcs = [ + "torch/csrc/jit/runtime/static/fusion.cpp", + "torch/csrc/jit/runtime/static/generated_ops.cpp", + "torch/csrc/jit/runtime/static/impl.cpp", + "torch/csrc/jit/runtime/static/memory_planner.cpp", + "torch/csrc/jit/runtime/static/native_ops.cpp", + "torch/csrc/jit/runtime/static/ops.cpp", + "torch/csrc/jit/runtime/static/passes.cpp", + "torch/csrc/jit/runtime/static/te_wrapper.cpp", + ], + compiler_flags = ["-fexceptions"], + labels = labels, + # @lint-ignore BUCKLINT link_whole + link_whole = True, + visibility = ["PUBLIC"], + windows_preferred_linkage = "static" if is_arvr_mode() else None, + deps = [ + ":aten_cpu", + ":caffe2_headers", + ":torch_core", + C10, + ], + ) + + # aten_cpu and aten_native_cpu + for name, srcs in [ + ("aten_cpu", jit_core_sources + aten_cpu_source_list + [ + # Generated + ":gen_aten[Functions.cpp]", + ":gen_aten[Operators_0.cpp]", + ":gen_aten[Operators_1.cpp]", + ":gen_aten[Operators_2.cpp]", + ":gen_aten[Operators_3.cpp]", + ":gen_aten[Operators_4.cpp]", + ":gen_aten[core/ATenOpList.cpp]", + ":gen_aten[core/TensorMethods.cpp]", + # Needed by ATen/native/EmbeddingBag.cpp + "caffe2/perfkernels/embedding_lookup_idx.cc", + ]), + ("aten_native_cpu", aten_native_source_list), + ]: + fb_xplat_cxx_library( + name = name, + srcs = srcs, + header_namespace = "", + # @lint-ignore BUCKLINT + link_whole = True, + visibility = ["PUBLIC"], + deps = [ + third_party("omp"), + third_party("cpuinfo"), + third_party("glog"), + third_party("XNNPACK"), + third_party("pocketfft"), + ], + compiler_flags = get_aten_compiler_flags(), + exported_preprocessor_flags = get_aten_preprocessor_flags(), + exported_deps = [ + ":aten_header", + ":caffe2_headers", + ":common_core", + ":generated_aten_config_header", + ":generated_aten_headers_cpu", + ":jit_core_headers", + ":pthreadpool", + third_party("fmt"), + third_party("ruy"), + C10, + ROOT_PATH + "aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack", + ], + labels = labels, + **aten_default_args + ) + + fb_xplat_cxx_library( + name = "lean_runtime_with_flatbuffer", + srcs = [ + "aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp", + "torch/csrc/jit/mobile/import.cpp", + "torch/csrc/jit/mobile/module.cpp", + "torch/csrc/jit/mobile/observer.cpp", + "torch/csrc/jit/serialization/import_read.cpp", + ], + header_namespace = "", + exported_headers = subdir_glob( + [ + ("", "torch/csrc/jit/ir/*.h"), + ("", "caffe2/serialize/*.h"), + ("", "caffe2/utils/*.h"), + ("", "caffe2/core/*.h"), + ("", "torch/csrc/*.h"), + ("", "torch/csrc/api/include/torch/*.h"), + ("", "torch/csrc/autograd/*.h"), + ("", "torch/csrc/autograd/*/*.h"), + ("", "torch/csrc/jit/api/*.h"), + ("", "torch/csrc/jit/backends/*.h"), + ("", "torch/csrc/jit/mobile/*.h"), + ("", "torch/csrc/jit/runtime/*.h"), + ("", "torch/csrc/jit/passes/*.h"), + ("", "torch/csrc/jit/python/*.h"), + ("", "torch/csrc/jit/frontend/*.h"), + ("", "torch/csrc/jit/serialization/*.h"), + ("", "torch/csrc/utils/*.h"), + ("", "aten/src/ATen/quantized/*.h"), + ] + ([ + ("third_party/miniz-2.1.0", "*.h"), + ] if NOT_OSS else []), + exclude = [ + "torch/csrc/jit/serialization/mobile_bytecode_generated.h", + ], + ), + compiler_flags = get_pt_compiler_flags() + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-fdata-sections", + "-ffunction-sections", + ], + }), + exported_preprocessor_flags = get_pt_preprocessor_flags() + [ + "-DMIN_EDGE_RUNTIME", + ], + linker_flags = [ + "-Wl,--no-as-needed", + ] + select({ + "DEFAULT": [], + "ovr_config//os:macos": [ + "-dead_strip", + ], + "ovr_config//os:xtensa-xos": [ + "-Wl,--gc-sections", + ], + }), + visibility = ["PUBLIC"], + exported_deps = [ + ":lean_runtime_with_tensor", + ], + ) + + pt_xplat_cxx_library( + name = "lean_runtime_with_tensor", + srcs = [ + "aten/src/ATen/Context.cpp", + "aten/src/ATen/EmptyTensor.cpp", + "aten/src/ATen/Utils.cpp", + "aten/src/ATen/detail/CUDAHooksInterface.cpp", + ":gen_aten[Operators_0.cpp]", + ":gen_aten[Operators_1.cpp]", + ":gen_aten[Operators_2.cpp]", + ":gen_aten[Operators_3.cpp]", + ":gen_aten[Operators_4.cpp]", + ":gen_aten[core/TensorMethods.cpp]", + ], + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/runtime/custom_operator.h", + ":gen_aten[core/TensorBody.h]", + ], + compiler_flags = get_pt_compiler_flags() + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-fdata-sections", + "-ffunction-sections", + ], + }), + exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-Dthread_local=", + ], + }), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":generated_aten_config_header", + ":lean_runtime_with_op", + ":aten_header", + C10, + ] + (["//xplat/caffe2/fb/embedded:experimental"] if NOT_OSS else []), + ) + + pt_xplat_cxx_library( + name = "lean_runtime_with_op", + srcs = [ + "aten/src/ATen/SequenceNumber.cpp", + "aten/src/ATen/core/boxing/KernelFunction.cpp", + "aten/src/ATen/core/custom_class.cpp", + "aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp", + "aten/src/ATen/core/dispatch/Dispatcher.cpp", + "aten/src/ATen/core/dispatch/ObservedOperators.cpp", + "aten/src/ATen/core/dispatch/OperatorEntry.cpp", + "aten/src/ATen/core/interned_strings.cpp", + "aten/src/ATen/core/library.cpp", + "aten/src/ATen/core/op_registration/infer_schema.cpp", + "aten/src/ATen/core/function_schema.cpp", + "aten/src/ATen/core/operator_name.cpp", + "aten/src/ATen/core/register_symbols.cpp", + "aten/src/ATen/core/tensor_type.cpp", + "aten/src/ATen/core/union_type.cpp", + "aten/src/ATen/record_function.cpp", + "torch/csrc/jit/frontend/edit_distance.cpp", + "torch/csrc/jit/frontend/error_report.cpp", + "torch/csrc/jit/frontend/function_schema_parser.cpp", + "torch/csrc/jit/frontend/lexer.cpp", + "torch/csrc/jit/frontend/schema_type_parser.cpp", + "torch/csrc/jit/frontend/source_range.cpp", + "torch/csrc/jit/frontend/strtod.cpp", + "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/prim_ops_registery.cpp", + "torch/csrc/jit/runtime/operator.cpp", + "torch/csrc/jit/runtime/slice_indices_adjust.cpp", + ], + header_namespace = "", + exported_headers = [ + "torch/csrc/jit/frontend/edit_distance.h", + "torch/csrc/jit/runtime/slice_indices_adjust.h", + ], + compiler_flags = get_pt_compiler_flags() + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-fdata-sections", + "-ffunction-sections", + ], + }), + exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-Dthread_local=", + ], + }), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":min_runtime_lib", + C10, + ], + ) + + pt_xplat_cxx_library( + name = "min_runtime_lib", + srcs = [ + "aten/src/ATen/ScalarOps.cpp", + "aten/src/ATen/core/Dict.cpp", + "aten/src/ATen/core/List.cpp", + "aten/src/ATen/core/class_type.cpp", + "aten/src/ATen/core/dynamic_type.cpp", + "aten/src/ATen/core/ivalue.cpp", + "aten/src/ATen/core/type.cpp", + "aten/src/ATen/core/type_factory.cpp", + "aten/src/ATen/native/prim_native_functions.cpp", + "torch/csrc/jit/mobile/function.cpp", + "torch/csrc/jit/mobile/interpreter.cpp", + "torch/csrc/jit/mobile/parse_bytecode.cpp", + "torch/csrc/jit/mobile/promoted_prim_ops.cpp", + "torch/csrc/jit/mobile/register_ops_common_utils.cpp", + "torch/csrc/jit/mobile/type_parser.cpp", + "torch/csrc/jit/runtime/instruction.cpp", + "torch/csrc/jit/runtime/jit_exception.cpp", + "torch/csrc/jit/runtime/vararg_functions.cpp", + ], + header_namespace = "", + exported_headers = [ + "caffe2/serialize/versions.h", + "torch/csrc/jit/backends/backend_exception.h", + "torch/csrc/jit/mobile/register_ops_common_utils.h", + "torch/csrc/jit/runtime/instruction.h", + "torch/csrc/jit/runtime/jit_exception.h", + "torch/csrc/jit/runtime/operator.h", + "torch/csrc/jit/runtime/operator_options.h", + "torch/csrc/jit/runtime/vararg_functions.h", + "torch/csrc/jit/serialization/import_export_constants.h", + "torch/csrc/jit/serialization/import_export_functions.h", + ], + compiler_flags = get_pt_compiler_flags() + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-fexceptions", + "-fdata-sections", + "-ffunction-sections", + ], + }), + exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({ + "DEFAULT": [], + "ovr_config//os:xtensa-xos": [ + "-Dthread_local=", + ], + }), + # @lint-ignore BUCKLINT link_whole + link_whole = True, + linker_flags = [ + "-Wl,--no-as-needed", + ], + visibility = ["PUBLIC"], + exported_deps = [ + ":aten_header", + ":generated_aten_headers_cpu", + ":jit_core_headers", + ":torch_mobile_headers", + C10, + ], + ) diff --git a/build.bzl b/build.bzl index 83dab33545bd7..ac9ceaa0559de 100644 --- a/build.bzl +++ b/build.bzl @@ -14,6 +14,7 @@ def define_targets(rules): "caffe2/serialize/istream_adapter.cc", "caffe2/serialize/read_adapter_interface.cc", ], + copts = ["-fexceptions"], tags = [ "supermodule:android/default/pytorch", "supermodule:ios/default/public.pytorch", @@ -25,7 +26,7 @@ def define_targets(rules): ":caffe2_headers", "@com_github_glog//:glog", "//c10", - "//third_party/miniz-2.0.8:miniz", + "//third_party/miniz-2.1.0:miniz", ], ) @@ -143,6 +144,7 @@ GENERATED_H = [ "FunctionalInverses.h", "RedispatchFunctions.h", "RegistrationDeclarations.h", + "VmapGeneratedPlumbing.h", ] GENERATED_H_CORE = [ @@ -192,6 +194,9 @@ GENERATED_CPP = [ "RegisterCompositeImplicitAutograd.cpp", "RegisterZeroTensor.cpp", "RegisterMeta.cpp", + "RegisterQuantizedMeta.cpp", + "RegisterNestedTensorMeta.cpp", + "RegisterSparseMeta.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "CompositeViewCopyKernels.cpp", diff --git a/build_variables.bzl b/build_variables.bzl index b5b14d8f1e2a4..e4b4b82df5f60 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -133,6 +133,7 @@ libtorch_profiler_sources = [ "torch/csrc/profiler/kineto_shim.cpp", "torch/csrc/profiler/nvtx_observer.cpp", "torch/csrc/profiler/kineto_client_interface.cpp", + "torch/csrc/profiler/itt_observer.cpp", "torch/csrc/monitor/counters.cpp", "torch/csrc/monitor/events.cpp", ] @@ -366,6 +367,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/testing/file_check.cpp", "torch/csrc/jit/testing/hooks_for_testing.cpp", "torch/csrc/utils/cpp_stacktraces.cpp", + "torch/csrc/utils/schema_info.cpp", "torch/csrc/utils/tensor_flatten.cpp", "torch/csrc/utils/variadic.cpp", ] @@ -421,7 +423,6 @@ lazy_tensor_core_sources = [ lazy_tensor_ts_sources = [ "torch/csrc/lazy/ts_backend/dynamic_ir.cpp", "torch/csrc/lazy/ts_backend/config.cpp", - "torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp", "torch/csrc/lazy/ts_backend/ops/device_data.cpp", "torch/csrc/lazy/ts_backend/ops/random_ops.cpp", "torch/csrc/lazy/ts_backend/ops/generic.cpp", @@ -749,6 +750,9 @@ libtorch_cuda_distributed_base_sources = [ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/NCCLUtils.cpp", "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", + "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp", + "torch/csrc/distributed/c10d/UCCTracing.cpp", + "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] @@ -871,6 +875,7 @@ libtorch_python_core_sources = [ "torch/csrc/autograd/python_variable_indexing.cpp", "torch/csrc/jit/backends/backend_init.cpp", "torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp", + "torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp", "torch/csrc/jit/python/init.cpp", "torch/csrc/jit/passes/onnx.cpp", "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp", @@ -935,6 +940,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/tensor_numpy.cpp", "torch/csrc/utils/tensor_types.cpp", "torch/csrc/utils/disable_torch_function.cpp", + "torch/csrc/utils/verbose.cpp", ] + lazy_tensor_core_python_sources libtorch_python_distributed_core_sources = [ @@ -1024,6 +1030,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/core/Dict.cpp", "aten/src/ATen/core/Dimname.cpp", "aten/src/ATen/core/Formatting.cpp", + "aten/src/ATen/core/function_schema.cpp", "aten/src/ATen/core/Generator.cpp", "aten/src/ATen/core/List.cpp", "aten/src/ATen/core/NamedTensor.cpp", @@ -1060,10 +1067,6 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/native/AutogradComposite.cpp", "aten/src/ATen/native/DispatchStub.cpp", "aten/src/ATen/native/UpSample.cpp", - "aten/src/ATen/native/mkl/LinearAlgebra.cpp", - "aten/src/ATen/native/mkl/SparseBlasImpl.cpp", - "aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp", - "aten/src/ATen/native/mkl/SpectralOps.cpp", "aten/src/ATen/native/mkldnn/BinaryOps.cpp", "aten/src/ATen/native/mkldnn/Conv.cpp", "aten/src/ATen/native/mkldnn/Copy.cpp", @@ -1084,6 +1087,11 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/native/mkldnn/Utils.cpp", "aten/src/ATen/native/mkldnn/Matmul.cpp", "aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp", + # This is moved to aten_cpu because some of the custom ops use empty_with_tail_padding + # which was available only within aten_native_cpu. Ideally the right fix is to make + # empty_with_tail_padding into an op and use dispatcher with it. But exposing it as an op + # has limited use and hence does not seem to really make sense. + "aten/src/ATen/native/utils/Factory.cpp", "aten/src/ATen/record_function.cpp", "aten/src/ATen/Dispatch.cpp", "aten/src/ATen/SavedTensorHooks.cpp", @@ -1092,6 +1100,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/nnapi/nnapi_wrapper.cpp", "aten/src/ATen/nnapi/nnapi_model_loader.cpp", "aten/src/ATen/native/prim_native_functions.cpp", + "aten/src/ATen/native/verbose_wrapper.cpp", ] aten_cpu_source_codegen_list = [ @@ -1131,8 +1140,8 @@ aten_native_source_codegen_list = [ "aten/src/ATen/native/cpu/IndexKernel.cpp", "aten/src/ATen/native/cpu/LerpKernel.cpp", "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp", - "aten/src/ATen/native/cpu/MaxPooling.cpp", "aten/src/ATen/native/cpu/MaxPoolKernel.cpp", + "aten/src/ATen/native/cpu/MaxPooling.cpp", "aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp", "aten/src/ATen/native/cpu/MultinomialKernel.cpp", "aten/src/ATen/native/cpu/PixelShuffleKernel.cpp", @@ -1153,10 +1162,15 @@ aten_native_source_codegen_list = [ "aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp", "aten/src/ATen/native/cpu/UpSampleKernel.cpp", "aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp", + "aten/src/ATen/native/cpu/WeightNormKernel.cpp", + "aten/src/ATen/native/cpu/airy_ai.cpp", "aten/src/ATen/native/cpu/batch_norm_kernel.cpp", "aten/src/ATen/native/cpu/group_norm_kernel.cpp", "aten/src/ATen/native/cpu/layer_norm_kernel.cpp", - "aten/src/ATen/native/cpu/WeightNormKernel.cpp", + "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp", + "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp", + "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp", + "aten/src/ATen/native/cpu/SparseFactories.cpp", "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp", ] @@ -1165,8 +1179,10 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/ao_sparse/library.cpp", "aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp", "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp", + "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp", "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp", "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp", + "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_serialize.cpp", "aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp", "aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp", "aten/src/ATen/native/quantized/cpu/fused_obs_fake_quant.cpp", @@ -1347,8 +1363,13 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/WeightNorm.cpp", "aten/src/ATen/native/group_norm.cpp", "aten/src/ATen/native/layer_norm.cpp", + "aten/src/ATen/native/mkl/LinearAlgebra.cpp", + "aten/src/ATen/native/mkl/SparseBlasImpl.cpp", + "aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp", + "aten/src/ATen/native/mkl/SpectralOps.cpp", "aten/src/ATen/native/nested/NestedTensorMath.cpp", "aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp", + "aten/src/ATen/native/nested/NestedTensorBackward.cpp", "aten/src/ATen/native/sparse/ParamUtils.cpp", "aten/src/ATen/native/sparse/SoftMax.cpp", "aten/src/ATen/native/sparse/SparseBlas.cpp", @@ -1359,9 +1380,10 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/sparse/SparseTensorMath.cpp", "aten/src/ATen/native/sparse/SparseUnaryOps.cpp", "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp", + "aten/src/ATen/native/sparse/SparseFactories.cpp", + "aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp", "aten/src/ATen/native/transformers/attention.cpp", "aten/src/ATen/native/transformers/transformer.cpp", - "aten/src/ATen/native/utils/Factory.cpp", "aten/src/ATen/native/xnnpack/Activation.cpp", "aten/src/ATen/native/xnnpack/ChannelShuffle.cpp", "aten/src/ATen/native/xnnpack/Convolution.cpp", diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 84e4b0712ee42..5f1b9777a1205 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(c10 CXX) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Main build file for the C10 library. diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp index 2898295c6906e..9879f05e64e4a 100644 --- a/c10/core/Allocator.cpp +++ b/c10/core/Allocator.cpp @@ -57,6 +57,25 @@ void reportMemoryUsageToProfiler( } } +void reportOutOfMemoryToProfiler( + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + Device device) { + auto* reporter_ptr = static_cast( + ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); + if (reporter_ptr) { + reporter_ptr->reportOutOfMemory( + alloc_size, total_allocated, total_reserved, device); + } +} + MemoryReportingInfoBase::MemoryReportingInfoBase() = default; +void MemoryReportingInfoBase::reportOutOfMemory( + int64_t /*alloc_size*/, + int64_t /*total_allocated*/, + int64_t /*total_reserved*/, + Device /*device*/) {} + } // namespace c10 diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h index 4f571fd915111..3ea27fcb89265 100644 --- a/c10/core/Allocator.h +++ b/c10/core/Allocator.h @@ -243,6 +243,12 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { int64_t total_reserved, Device device) = 0; + virtual void reportOutOfMemory( + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + Device device); + virtual bool memoryProfilingEnabled() const = 0; }; @@ -254,4 +260,10 @@ C10_API void reportMemoryUsageToProfiler( int64_t total_reserved, Device device); +C10_API void reportOutOfMemoryToProfiler( + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + Device device); + } // namespace c10 diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp index 88df9b72069f2..60b76edb9c7f9 100644 --- a/c10/core/CPUAllocator.cpp +++ b/c10/core/CPUAllocator.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -16,7 +17,13 @@ namespace c10 { struct C10_API DefaultCPUAllocator final : at::Allocator { DefaultCPUAllocator() = default; at::DataPtr allocate(size_t nbytes) const override { - void* data = alloc_cpu(nbytes); + void* data = nullptr; + try { + data = c10::alloc_cpu(nbytes); + } catch (c10::Error& e) { + profiledCPUMemoryReporter().OutOfMemory(nbytes); + throw e; + } profiledCPUMemoryReporter().New(data, nbytes); return {data, data, &ReportAndDelete, at::Device(at::DeviceType::CPU)}; } @@ -112,13 +119,18 @@ class DefaultMobileCPUAllocator final : public at::Allocator { } else if (profiling_allocator_ptr != nullptr) { data = profiling_allocator_ptr->allocate(alloc_size); } else { - data = c10::alloc_cpu(alloc_size); + try { + data = c10::alloc_cpu(alloc_size); + } catch (c10::Error& e) { + profiledCPUMemoryReporter().OutOfMemory(alloc_size); + throw e; + } auto allocation_planner = GetThreadLocalAllocationPlanner(); if (allocation_planner != nullptr) { allocation_planner->record_allocation(alloc_size, data); } } - // profiledCPUMemoryReporter().New(data, alloc_size); + profiledCPUMemoryReporter().New(data, alloc_size); return { reinterpret_cast(data) + PreGuardBytes, data, @@ -235,6 +247,30 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) { } } +void ProfiledCPUMemoryReporter::OutOfMemory(size_t nbytes) { + auto profile_memory = memoryProfilingEnabled(); + size_t allocated = 0; + if (FLAGS_caffe2_report_cpu_memory_usage || profile_memory) { + std::lock_guard guard(mutex_); + + allocated = allocated_; + } + if (nbytes == 0) { + return; + } + if (FLAGS_caffe2_report_cpu_memory_usage) { + LOG(INFO) << "C10 Out of Memory. Trying to allocate " << nbytes + << " bytes, total alloc " << allocated << " bytes."; + } + if (profile_memory) { + reportOutOfMemoryToProfiler( + static_cast(nbytes), + static_cast(allocated), + 0, + c10::Device(c10::DeviceType::CPU)); + } +} + C10_API at::Allocator* cpu_caching_alloc = nullptr; C10_API uint8_t cpu_caching_alloc_priority = 0; diff --git a/c10/core/CPUAllocator.h b/c10/core/CPUAllocator.h index bf94097417ae6..a899401298180 100644 --- a/c10/core/CPUAllocator.h +++ b/c10/core/CPUAllocator.h @@ -17,12 +17,13 @@ using MemoryDeleter = void (*)(void*); // A helper function that is basically doing nothing. C10_API void NoDelete(void*); -// A simple struct that is used to report C10's memory allocation and -// deallocation status to the profiler +// A simple struct that is used to report C10's memory allocation, +// deallocation status and out-of-memory events to the profiler class C10_API ProfiledCPUMemoryReporter { public: ProfiledCPUMemoryReporter() {} void New(void* ptr, size_t nbytes); + void OutOfMemory(size_t nbytes); void Delete(void* ptr); private: diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index ca995bc9d9ab5..000ad331828b0 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -12,6 +12,23 @@ namespace c10 { +// These contains all device types that also have a BackendComponent +// and therefore participate in per-backend functionality dispatch keys. +// This is most backends except PrivateUse2 and PrivateUse3 +#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \ + _(CPU, extra) \ + _(CUDA, extra) \ + _(HIP, extra) \ + _(XLA, extra) \ + _(MPS, extra) \ + _(IPU, extra) \ + _(XPU, extra) \ + _(HPU, extra) \ + _(VE, extra) \ + _(Lazy, extra) \ + _(Meta, extra) \ + _(PrivateUse1, extra) + enum class DeviceType : int8_t { CPU = 0, CUDA = 1, // CUDA. diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index c07ea8731489a..4423d578b52d0 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -42,169 +42,120 @@ const char* toString(BackendComponent t) { } } +BackendComponent toBackendComponent(DeviceType device_type) { + switch (device_type) { +#define DO_CASE(device, _) \ + case DeviceType::device: { \ + return toBackendComponent(DispatchKey::device); \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + return BackendComponent::InvalidBit; + } +} + const char* toString(DispatchKey t) { switch (t) { case DispatchKey::Undefined: return "Undefined"; - case DispatchKey::CPU: - return "CPU"; - case DispatchKey::CUDA: - return "CUDA"; - case DispatchKey::HIP: - return "HIP"; - case DispatchKey::VE: - return "VE"; + + case DispatchKey::Dense: + return "Dense"; case DispatchKey::FPGA: return "FPGA"; - case DispatchKey::XPU: - return "XPU"; - case DispatchKey::IPU: - return "IPU"; case DispatchKey::ORT: return "ORT"; - case DispatchKey::XLA: - return "XLA"; + case DispatchKey::Vulkan: + return "Vulkan"; + case DispatchKey::Metal: + return "Metal"; + case DispatchKey::Lazy: return "Lazy"; case DispatchKey::MPS: return "MPS"; case DispatchKey::HPU: return "HPU"; - case DispatchKey::Vulkan: - return "Vulkan"; - case DispatchKey::Metal: - return "Metal"; - case DispatchKey::QuantizedCPU: - return "QuantizedCPU"; - case DispatchKey::QuantizedCUDA: - return "QuantizedCUDA"; - case DispatchKey::QuantizedXPU: - return "QuantizedXPU"; + case DispatchKey::Quantized: + return "Quantized"; case DispatchKey::CustomRNGKeyId: return "CustomRNGKeyId"; - case DispatchKey::MkldnnCPU: return "MkldnnCPU"; - case DispatchKey::SparseCPU: - return "SparseCPU"; - case DispatchKey::SparseCUDA: - return "SparseCUDA"; + + case DispatchKey::Sparse: + return "Sparse"; case DispatchKey::SparseCsrCPU: return "SparseCsrCPU"; case DispatchKey::SparseCsrCUDA: return "SparseCsrCUDA"; - case DispatchKey::SparseHIP: - return "SparseHIP"; - case DispatchKey::SparseVE: - return "SparseVE"; - case DispatchKey::SparseXPU: - return "SparseXPU"; case DispatchKey::NestedTensor: return "NestedTensor"; - case DispatchKey::NestedTensorCPU: - return "NestedTensorCPU"; - case DispatchKey::NestedTensorCUDA: - return "NestedTensorCUDA"; + + case DispatchKey::BackendSelect: + return "BackendSelect"; case DispatchKey::Python: return "Python"; - case DispatchKey::PythonTLSSnapshot: - return "PythonTLSSnapshot"; - case DispatchKey::PrivateUse1: - return "PrivateUse1"; - case DispatchKey::PrivateUse2: - return "PrivateUse2"; - case DispatchKey::PrivateUse3: - return "PrivateUse3"; + case DispatchKey::Fake: + return "Fake"; + case DispatchKey::FuncTorchDynamicLayerBackMode: + return "FuncTorchDynamicLayerBackMode"; + + case DispatchKey::Functionalize: + return "Functionalize"; + + case DispatchKey::Named: + return "Named"; - case DispatchKey::Negative: - return "Negative"; case DispatchKey::Conjugate: return "Conjugate"; - case DispatchKey::Meta: - return "Meta"; + case DispatchKey::Negative: + return "Negative"; + case DispatchKey::ZeroTensor: + return "ZeroTensor"; case DispatchKey::ADInplaceOrView: return "ADInplaceOrView"; - case DispatchKey::Autograd: - return "Autograd"; - case DispatchKey::AutogradCPU: - return "AutogradCPU"; - case DispatchKey::AutogradIPU: - return "AutogradIPU"; - case DispatchKey::AutogradXPU: - return "AutogradXPU"; - case DispatchKey::AutogradCUDA: - return "AutogradCUDA"; - case DispatchKey::AutogradXLA: - return "AutogradXLA"; - case DispatchKey::AutogradLazy: - return "AutogradLazy"; - case DispatchKey::AutogradMeta: - return "AutogradMeta"; - case DispatchKey::AutogradMPS: - return "AutogradMPS"; - case DispatchKey::AutogradHPU: - return "AutogradHPU"; - case DispatchKey::AutogradPrivateUse1: - return "AutogradPrivateUse1"; - case DispatchKey::AutogradPrivateUse2: - return "AutogradPrivateUse2"; - case DispatchKey::AutogradPrivateUse3: - return "AutogradPrivateUse3"; case DispatchKey::AutogradOther: return "AutogradOther"; + case DispatchKey::AutogradFunctionality: + return "AutogradFunctionality"; case DispatchKey::AutogradNestedTensor: return "AutogradNestedTensor"; - case DispatchKey::ZeroTensor: - return "ZeroTensor"; - case DispatchKey::BackendSelect: - return "BackendSelect"; - case DispatchKey::Named: - return "Named"; - - case DispatchKey::Functionalize: - return "Functionalize"; - case DispatchKey::Tracer: return "Tracer"; - // Note: AutocastCUDA and Autocast are the same, currently. - // See comments in DispatchKey.h - case DispatchKey::Autocast: - return "Autocast"; - case DispatchKey::AutocastCPU: return "AutocastCPU"; - case DispatchKey::AutocastXPU: return "AutocastXPU"; + case DispatchKey::AutocastCUDA: + return "AutocastCUDA"; + + case DispatchKey::FuncTorchBatched: + return "FuncTorchBatched"; + case DispatchKey::FuncTorchVmapMode: + return "FuncTorchVmapMode"; case DispatchKey::Batched: return "Batched"; - case DispatchKey::VmapMode: return "VmapMode"; - case DispatchKey::CompositeImplicitAutograd: - return "CompositeImplicitAutograd"; - - case DispatchKey::CompositeExplicitAutograd: - return "CompositeExplicitAutograd"; - - case DispatchKey::CompositeExplicitAutogradNonFunctional: - return "CompositeExplicitAutogradNonFunctional"; - - case DispatchKey::TESTING_ONLY_GenericWrapper: - return "TESTING_ONLY_GenericWrapper"; + case DispatchKey::FuncTorchGradWrapper: + return "FuncTorchGradWrapper"; - case DispatchKey::TESTING_ONLY_GenericMode: - return "TESTING_ONLY_GenericMode"; + case DispatchKey::DeferredInit: + return "DeferredInit"; + case DispatchKey::PythonTLSSnapshot: + return "PythonTLSSnapshot"; // Note [Out-of-tree vmap+grad prototype] // The following keys are used in the implementation of the out-of-tree @@ -212,34 +163,57 @@ const char* toString(DispatchKey t) { // https://github.com/zou3519/functorch // We plan on eventually upstreaming the prototype into core, at which // point it will have a different design that should use fewer keys. - case DispatchKey::FuncTorchDynamicLayerBackMode: - return "FuncTorchDynamicLayerBackMode"; case DispatchKey::FuncTorchDynamicLayerFrontMode: return "FuncTorchDynamicLayerFrontMode"; - case DispatchKey::FuncTorchGradWrapper: - return "FuncTorchGradWrapper"; - case DispatchKey::FuncTorchVmapMode: - return "FuncTorchVmapMode"; - case DispatchKey::FuncTorchBatched: - return "FuncTorchBatched"; - // Out-of-core torchdistX dispatch keys - case DispatchKey::Fake: - return "Fake"; - case DispatchKey::DeferredInit: - return "DeferredInit"; + case DispatchKey::TESTING_ONLY_GenericWrapper: + return "TESTING_ONLY_GenericWrapper"; - case DispatchKey::Dense: - return "Dense"; - case DispatchKey::Quantized: - return "Quantized"; - case DispatchKey::Sparse: - return "Sparse"; - case DispatchKey::AutogradFunctionality: - return "AutogradFunctionality"; + case DispatchKey::TESTING_ONLY_GenericMode: + return "TESTING_ONLY_GenericMode"; + + // Aliases + + case DispatchKey::Autograd: + return "Autograd"; + case DispatchKey::CompositeImplicitAutograd: + return "CompositeImplicitAutograd"; + case DispatchKey::CompositeExplicitAutograd: + return "CompositeExplicitAutograd"; + case DispatchKey::CompositeExplicitAutogradNonFunctional: + return "CompositeExplicitAutogradNonFunctional"; + + // Per-backend dispatch keys default: - return "UNKNOWN_TENSOR_TYPE_ID"; + auto bc = toBackendComponent(t); + auto fk = toFunctionalityKey(t); + + switch (fk) { +#define ENTRY(backend, functionality) \ + case BackendComponent::backend##Bit: \ + return #functionality #backend; + +#define FORALL_BC(dkname, prefix) \ + case DispatchKey::dkname: \ + switch (bc) { \ + C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \ + default: \ + return #prefix "Unknown"; \ + } + + C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC) + + default: + switch (bc) { + C10_FORALL_BACKEND_COMPONENTS(ENTRY, Unknown) + default: + return "UnknownUnknown"; + } + +#undef FORALL_BC +#undef ENTRY + } } } diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 69c57ec89f5ea..2f1f1fc5f77e0 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -22,6 +23,36 @@ namespace c10 { // bits and take the highest bit to determine which backend's implementation to // use. +// WARNING! If you add a new backend component to the end of this list, +// make sure you update PrivateUse3Bit. (But you shouldn't: private use +// keys should have higher precedence than all built-in keys) + +#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \ + _(CPU, extra) \ + _(CUDA, extra) \ + _(HIP, extra) \ + _(XLA, extra) \ + _(MPS, extra) \ + _(IPU, extra) \ + _(XPU, extra) \ + _(HPU, extra) \ + _(VE, extra) \ + _(Lazy, extra) \ + _(Meta, extra) \ + _(PrivateUse1, extra) \ + _(PrivateUse2, extra) \ + _(PrivateUse3, extra) + +// WARNING! If we add a new per-backend functionality key that has higher +// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys + +#define C10_FORALL_FUNCTIONALITY_KEYS(_) \ + _(Dense, ) \ + _(Quantized, Quantized) \ + _(Sparse, Sparse) \ + _(NestedTensor, NestedTensor) \ + _(AutogradFunctionality, Autograd) + enum class BackendComponent : uint8_t { // A "backend" is colloquially used to refer to handlers for dispatch @@ -45,31 +76,20 @@ enum class BackendComponent : uint8_t { // of // [backends in this enum] x [keys below that are explicitly marked as having // per-backend functionality] - - InvalidBit = 0, - CPUBit, - CUDABit, - HIPBit, - XLABit, - MPSBit, - IPUBit, - XPUBit, - HPUBit, - VEBit, - LazyBit, + // // A meta tensor is a tensor without any data associated with it. (They // have also colloquially been referred to as tensors on the "null" device). // A meta tensor can be used to dry run operators without actually doing any // computation, e.g., add on two meta tensors would give you another meta // tensor with the output shape and dtype, but wouldn't actually add anything. - MetaBit, - PrivateUse1Bit, - PrivateUse2Bit, - PrivateUse3Bit, + + InvalidBit = 0, +#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit, + C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused) +#undef DEFINE_BACKEND_COMPONENT + // Define an alias to represent end of backend dispatch keys. // If you add new backend keys after PrivateUse3, please also update it here. - // (But you shouldn't: private use keys should have higher precedence than - // all built-in keys) EndOfBackendKeys = PrivateUse3Bit, }; @@ -148,9 +168,11 @@ enum class DispatchKey : uint16_t { // If any of these backends ever need to customize, e.g., Autograd, then we'll // need to add a DispatchKey::*Bit for them. + // TODO: put this in BackendComponents FPGA, // Xilinx support lives out of tree at // https://gitlab.com/pytorch-complex/vitis_kernels + // TODO: put this in BackendComponents // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and // https://github.com/microsoft/onnxruntime, and is also used to test general // backend/extension machinery in the core. cf: @@ -159,8 +181,8 @@ enum class DispatchKey : uint16_t { // - aten/src/ATen/test/extension_backend_test.cpp ORT, - Vulkan, - Metal, + Vulkan, // TODO: put this in BackendComponents + Metal, // TODO: put this in BackendComponents // See [Note: Per-Backend Functionality Dispatch Keys] Quantized, @@ -176,6 +198,8 @@ enum class DispatchKey : uint16_t { // intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp CustomRNGKeyId, + // TODO: Make Mkldnn a functionality key, so we can give it Meta + // support // Here are backends which specify more specialized operators // based on the layout of the tensor. Note that the sparse backends // are one case where ordering matters: sparse multi-dispatches with @@ -186,23 +210,10 @@ enum class DispatchKey : uint16_t { // See [Note: Per-Backend Functionality Dispatch Keys] Sparse, + // TODO: Make SparseCsr a functionality key SparseCsrCPU, SparseCsrCUDA, - // Note [Non-Customizable Backend Keys] - // Every key above here is considered a "non-customizable backend". - // These are backends that will work correctly with autograd, but - // but currently don't require separate implementations - // for autograd sparse or quantized kernels. - // Any new backends that don't need to be customized should go above here. - // If an existing backend needs to e.g. override autograd, then we can - // consider promoting it into the "BackendComponent" enum - // - // For all intents and purposes from the perspective of DispatchKeySet, - // "non-customizable backend" keys are treated the same way - // as other functionality keys - EndOfNonCustomizableBackends = SparseCsrCUDA, - NestedTensor, // In some situations, it is not immediately obvious what the correct @@ -216,8 +227,8 @@ enum class DispatchKey : uint16_t { // Out-of-core key for Fake Tensor in torchdistx. // See https://pytorch.org/torchdistx/latest/fake_tensor.html + // TODO: delete this in favor of Python-implemented fake tensor Fake, - // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key // is to insert code after the "autograd subsystem" runs, so this key should // be directly after ADInplaceOrView and all of the autograd keys. @@ -242,6 +253,7 @@ enum class DispatchKey : uint16_t { // key that triggers before composite operators, in case a composite operator // has named dimension propagation that doesn't match that of its // constituent parts. + // TODO: delete this once torchdim lands in functorch Named, // The Conjugate dispatch key is set for any tensors that need to perform @@ -333,6 +345,7 @@ enum class DispatchKey : uint16_t { Tracer, + // TODO: make Autocast a functionality key // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed // and inputs are saved for backward in the post-autocast type. AutocastCPU, @@ -391,128 +404,24 @@ enum class DispatchKey : uint16_t { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // EndOfFunctionalityKeys, // End of functionality keys. - // ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ // - // Here are backends which you think of as traditionally specifying - // how to implement operations on some device. - - // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] - StartOfDenseBackends, - CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp - CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp - HIP, // NB: I think this is not actually used, due to Note [Masquerading as - // CUDA] - XLA, // lives out of tree at https://github.com/pytorch/xla - MPS, // registered at build/aten/src/ATen/RegisterMPS.cpp - IPU, // lives out of tree at https://github.com/graphcore/poptorch - XPU, // For out of tree Intel's heterogeneous computing plug-in - HPU, // For out of tree & closed source integration of HPU / Habana - VE, // For out of tree & closed source integration of SX-Aurora / NEC - Lazy, // For lazy tensor backends - Meta, - // Here are reserved backends for user-defined backends, see Note [Private use - // DispatchKey] - // To see some example about how to use this, check out ORT - PrivateUse1, - PrivateUse2, - PrivateUse3, - EndOfDenseBackends = PrivateUse3, - - // ~~~~~~~~~~~~~~ "Quantized" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~ // - // keys starting with an _ are not currently used, - // but are needed to ensure that every backend is indexed correctly. - - // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] - StartOfQuantizedBackends, - QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp - QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp - _QuantizedHIP, - _QuantizedXLA, - _QuantizedMPS, - _QuantizedIPU, - QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in - _QuantizedHPU, - _QuantizedVE, - _QuantizedLazy, - _QuantizedMeta, - _QuantizedPrivateUse1, - _QuantizedPrivateUse2, - _QuantizedPrivateUse3, - EndOfQuantizedBackends = _QuantizedPrivateUse3, - - // ~~~~~~~~~~~~~~ "Sparse" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ // - // keys starting with an _ are not currently used, - // but are needed to ensure that every backend is indexed correctly. - - // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] - StartOfSparseBackends, - SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp - SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp - SparseHIP, // TODO: I think this is not actually used, due to Note - // [Masquerading as CUDA] - _SparseXLA, - _SparseMPS, - _SparseIPU, - SparseXPU, // For out of tree Intel's heterogeneous computing plug-in - _SparseHPU, - SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC - _SparseLazy, - _SparseMeta, - _SparsePrivateUse1, - _SparsePrivateUse2, - _SparsePrivateUse3, - EndOfSparseBackends = _SparsePrivateUse3, - - // ~~~~~~~~~~~~~~ "NestedTensor" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ - // // - // keys starting with an _ are not currently used, - // but are needed to ensure that every backend is indexed correctly. - - // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] - StartOfNestedTensorBackends, - // registered at build/aten/src/ATen/RegisterNestedTensorCPU.cpp - NestedTensorCPU, - // registered at build/aten/src/ATen/RegisterNestedTensorCUDA.cpp - NestedTensorCUDA, - _NestedTensorHIP, - _NestedTensorXLA, - _NestedTensorMPS, - _NestedTensorIPU, - _NestedTensorXPU, - _NestedTensorHPU, - _NestedTensorVE, - _NestedTensorLazy, - _NestedTensorMeta, - _NestedTensorPrivateUse1, - _NestedTensorPrivateUse2, - _NestedTensorPrivateUse3, - EndOfNestedTensorBackends = _NestedTensorPrivateUse3, - - // ~~~~~~~~~~~~~~ "Autograd" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~ // - // keys starting with an _ are not currently used, - // but are needed to ensure that every backend is indexed correctly. - - // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] - StartOfAutogradBackends, - AutogradCPU, - AutogradCUDA, - _AutogradHIP, - AutogradXLA, - AutogradMPS, - AutogradIPU, - AutogradXPU, - AutogradHPU, - _AutogradVE, - AutogradLazy, - AutogradMeta, - // Here are some reserved pre-autograd keys for user-defined backends, see - // Note [Private use DispatchKey] - AutogradPrivateUse1, - AutogradPrivateUse2, - AutogradPrivateUse3, - EndOfAutogradBackends = AutogradPrivateUse3, - // If we add a new per-backend functionality key that has higher priority - // than Autograd, then this key should be updated. - EndOfRuntimeBackendKeys = EndOfAutogradBackends, +// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ // +// Here are backends which you think of as traditionally specifying +// how to implement operations on some device. + +#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n, + +#define DEFINE_PER_BACKEND_KEYS(fullname, prefix) \ + StartOf##fullname##Backends, \ + C10_FORALL_BACKEND_COMPONENTS( \ + DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \ + EndOf##fullname##Backends = prefix##PrivateUse3, + + C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS) + +#undef DEFINE_PER_BACKEND_KEYS +#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND + + EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends, // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ // // Note [Alias Dispatch Keys] @@ -697,11 +606,12 @@ constexpr BackendComponent toBackendComponent(DispatchKey k) { static_cast(k) - static_cast(DispatchKey::StartOfNestedTensorBackends)); } else if ( - k >= DispatchKey::StartOfAutogradBackends && - k <= DispatchKey::EndOfAutogradBackends) { + k >= DispatchKey::StartOfAutogradFunctionalityBackends && + k <= DispatchKey::EndOfAutogradFunctionalityBackends) { return static_cast( static_cast(k) - - static_cast(DispatchKey::StartOfAutogradBackends)); + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends)); } else { return BackendComponent::InvalidBit; } @@ -718,13 +628,15 @@ constexpr DispatchKey toFunctionalityKey(DispatchKey k) { return DispatchKey::Sparse; } else if (k <= DispatchKey::EndOfNestedTensorBackends) { return DispatchKey::NestedTensor; - } else if (k <= DispatchKey::EndOfAutogradBackends) { + } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) { return DispatchKey::AutogradFunctionality; } else { return DispatchKey::Undefined; } } +BackendComponent toBackendComponent(DeviceType device_type); + // Given (DispatchKey::Dense, BackendComponent::CUDABit), returns // DispatchKey::CUDA. // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] @@ -756,7 +668,8 @@ constexpr DispatchKey toRuntimePerBackendFunctionalityKey( } if (functionality_k == DispatchKey::AutogradFunctionality) { return static_cast( - static_cast(DispatchKey::StartOfAutogradBackends) + + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends) + static_cast(backend_k)); } return DispatchKey::Undefined; diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 14151fce6feb2..358703210112a 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -28,6 +28,7 @@ constexpr DispatchKeySet non_functional_backend_dispatch_keyset = backend_dispatch_keyset // XLA and LazyTensor are currently the only 2 backends in core // that use functionalization pass in eager mode. + .remove(DispatchKey::Sparse) .remove_backend(BackendComponent::XLABit) .remove_backend(BackendComponent::LazyBit); @@ -38,16 +39,21 @@ bool isBackendDispatchKey(DispatchKey t) { // Note [NestedTensor Not Included in Backend Keys] // NestedTensor has been explicitly removed from the "backend keyset" due // to incompatibility with some kernels, so we don't want it to be - // included in CompositeImplicitAutograd or CompositeExplicitAutograd - // kernels. + // included in CompositeExplicitAutograd kernels. && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t); } // math_dispatch_keyset contains all keys in backend_dispatch_keyset and // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd // maps to [math_dispatch_keyset x full_backend_mask] -constexpr DispatchKeySet math_dispatch_keyset = - backend_dispatch_keyset | autograd_dispatch_keyset; +constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | + autograd_dispatch_keyset | + // See Note [NestedTensor Not Included in Backend Keys] + // The caveat to that note is that nested_tensor is a special case + // where we would like to support composite implict kernels but not + // explicit kernels therefore we manually add the key to the + // math_dispatch_keyset + DispatchKeySet{DispatchKey::NestedTensor}; DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); @@ -77,7 +83,7 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { return autograd_dispatch_keyset.has(toFunctionalityKey(k)); case DispatchKey::CompositeImplicitAutograd: // See Note [NestedTensor Not Included in Backend Keys] - return k != DispatchKey::NestedTensor && math_dispatch_keyset.has(k); + return math_dispatch_keyset.has(k); case DispatchKey::CompositeExplicitAutograd: // See Note [NestedTensor Not Included in Backend Keys] return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k); @@ -118,6 +124,9 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { return DispatchKeySet(DispatchKey::PrivateUse2); case DispatchKey::AutogradPrivateUse3: return DispatchKeySet(DispatchKey::PrivateUse3); + case DispatchKey::AutogradNestedTensor: + return DispatchKeySet(DispatchKey::NestedTensor) | + DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); case DispatchKey::AutogradOther: return autogradother_backends; default: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 656dbfc441295..d3d90693b9066 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -633,6 +633,7 @@ C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ DispatchKey::AutogradFunctionality, DispatchKey::AutogradOther, + DispatchKey::AutogradNestedTensor, }); constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ @@ -741,7 +742,8 @@ constexpr auto autograd_privateuse2_ks = constexpr auto autograd_privateuse3_ks = DispatchKeySet(DispatchKey::AutogradPrivateUse3); constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther); - +constexpr auto autograd_nested = + DispatchKeySet(DispatchKey::AutogradNestedTensor); // keyset correpsonding to functorch keys that have their own dedicated // TensorImpl subclass. constexpr auto functorch_transforms_ks = DispatchKeySet( diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 67828a3b27dd8..9e6b56d611a81 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -114,6 +114,9 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) #undef SPECIALIZE_ScalarTypeToCPPType +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + } // namespace impl template diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index e2a570b03f0fa..28e477481390a 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -1,21 +1,39 @@ - #include -#include +#include +#include namespace c10 { -std::shared_ptr SymInt::toSymbolicIntNode() const { - auto& st = getSymIntTable(); +std::array normalize_symints(SymInt a_, SymInt b_) { + SymIntNode a, b; + if (a_.is_symbolic()) + a = a_.toSymIntNodeImpl(); + if (b_.is_symbolic()) + b = b_.toSymIntNodeImpl(); + + SymIntNodeImpl* common = a ? a.get() : b.get(); + // TODO: technically we need to check that the classes match + if (!a) { + a = common->wrap(a_.as_int_unchecked()); + a_.toSymInt(a); // + } + if (!b) { + b = common->wrap(b_.as_int_unchecked()); + b_.toSymInt(b); + } + return {a, b}; +} + +SymIntNode SymInt::toSymIntNodeImpl() const { TORCH_CHECK(is_symbolic()); - return st.getNode(static_cast(data_) & ~MASK); + return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned()); } -c10::SymInt SymInt::toSymInt(std::shared_ptr sin_sp) { - auto& sit = getSymIntTable(); - uint64_t idx = sit.addNode(sin_sp); - TORCH_CHECK(idx < MAX_SYM_IDX, "SymbolicIntNode index overflow: ", idx); - uint64_t data = idx | IS_SYM; - return c10::SymInt(static_cast(data)); +c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) { + auto ptr = static_cast( + reinterpret_cast(static_cast(sin_sp.release()))); + auto rep = (ptr & ~MASK) | IS_SYM; + return c10::SymInt(static_cast(rep)); } SymInt SymInt::operator+(SymInt sci) const { @@ -29,21 +47,20 @@ SymInt SymInt::operator*(SymInt sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ * sci.data_); } - // TODO: This is way to much boilerplate - std::shared_ptr a = - is_symbolic() ? toSymbolicIntNode() : nullptr; - std::shared_ptr b = - sci.is_symbolic() ? sci.toSymbolicIntNode() : nullptr; + auto res = normalize_symints(*this, sci); + return SymInt::toSymInt(res[0]->mul(res[1])); +} - SymbolicIntNode* common = a ? a.get() : b.get(); - // TODO: technically we need to check that the classes match - if (!a) { - a = common->wrap(data_); - } - if (!b) { - b = common->wrap(sci.data_); +bool SymInt::operator==(SymInt sci) const { + if (!is_symbolic() && !sci.is_symbolic()) { + return data_ == sci.data_; } - return SymInt::toSymInt(a->add(b)); + auto res = normalize_symints(*this, sci); + return res[0]->eq(res[1])->bool_(); +} + +bool SymInt::operator!=(SymInt sci) const { + return !(*this == sci); } bool SymInt::operator<(SymInt sci) const { @@ -66,13 +83,11 @@ bool SymInt::operator<(int64_t sci) const { } bool SymInt::operator==(int64_t sci) const { - TORCH_CHECK(!this->is_symbolic(), "Symbolic eq isn't supported yet"); - return data_ == sci; + return *this == c10::SymInt(sci); } bool SymInt::operator!=(int64_t sci) const { - TORCH_CHECK(!this->is_symbolic(), "Symbolic neq isn't supported yet"); - return data_ != sci; + return *this != c10::SymInt(sci); } SymInt SymInt::operator*(int64_t sci) const { diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 435825900caba..331f10305dec0 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -1,11 +1,13 @@ #pragma once +#include #include #include +#include -namespace c10 { +#include -class SymbolicIntNode; +namespace c10 { // `SymInt` is a C++ wrapper class around int64_t data_ which and is used to // represent concrete dimension values. @@ -23,13 +25,57 @@ class SymbolicIntNode; // functions. // // SymInt will be extenteded to represent a union structure Union[int64_t, -// SymbolicIntNode*] which will be implemented as a single packed int64_t field +// SymIntNodeImpl*] which will be implemented as a single packed int64_t field // named data_. class C10_API SymInt { public: + // TODO: this needs to only accept integers, not pointers /*implicit*/ SymInt(int64_t d) : data_(d){}; SymInt() = default; + // TODO: these implementations are not optimal because they allocate a + // temporary and then use the move constructor/assignment + SymInt(const SymInt& s) : data_(0) { + if (s.is_symbolic()) { + *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + } else { + data_ = s.data_; + } + } + SymInt(SymInt&& s) : data_(s.data_) { + s.data_ = 0; + } + + SymInt& operator=(const SymInt& s) { + if (s.is_symbolic()) { + *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + } else { + data_ = s.data_; + } + return *this; + } + SymInt& operator=(SymInt&& s) { + data_ = s.data_; + if (s.is_symbolic()) + s.data_ = 0; + return *this; + } + + SymIntNodeImpl* toSymIntNodeImplUnowned() const { + uint64_t unextended_bits = static_cast(data_) & ~MASK; + uint64_t sign_bit_mask = 1ULL << (62 - 1); + // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c + uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; + return static_cast( + reinterpret_cast(static_cast(extended_bits))); + } + + ~SymInt() { + if (is_symbolic()) { + SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal + } + } + int64_t expect_int() const { TORCH_CHECK(!is_symbolic()); return data_; @@ -39,16 +85,10 @@ class C10_API SymInt { return (MASK & static_cast(this->data_)) == IS_SYM; } - bool operator==(const SymInt& p2) const { - return data_ == p2.data_; - } - - bool operator!=(const SymInt& p2) const { - return data_ != p2.data_; - } - SymInt operator+(SymInt sci) const; SymInt operator*(SymInt sci) const; + bool operator==(SymInt sci) const; + bool operator!=(SymInt p2) const; bool operator<(SymInt sci) const; void operator*=(SymInt sci); @@ -57,18 +97,13 @@ class C10_API SymInt { bool operator==(int64_t sci) const; bool operator!=(int64_t sci) const; - std::shared_ptr toSymbolicIntNode() const; - static c10::SymInt toSymInt(std::shared_ptr sin); + SymIntNode toSymIntNodeImpl() const; + static c10::SymInt toSymInt(SymIntNode sin); int64_t as_int_unchecked() const { return data_; } - // This is needed for interoperability with IValue - int64_t data() const { - return data_; - } - // Return whether the integer is representable as a SymInt. static bool check_range(int64_t i) { return i > MIN_INT; @@ -76,15 +111,17 @@ class C10_API SymInt { private: // Constraints on the internal representation: - // - Should represent positive and negative ints + // - Should represent positive and small negative ints // - No conversion necessary for operations on ints. - // - We reserve some values to act as indices into our sym int table. + // - Must represent valid 64-bit pointers // // So, the scheme is to reserve large negative numbers: // - 0b0.... means we are a positive int (following two's complement) // - 0b11... means we are a negative int (following two's complement) - // - 0b10... means we are index into the sym table. This means that + // - 0b10... means we are are a pointer. This means that // [-2^63, -2^62-1] are not representable as ints. + // We don't actually need all of this space as on x86_64 + // as the top 16bits aren't used for anything static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62; static constexpr uint64_t IS_SYM = 1ULL << 63; // Since we use the top two bits to determine whether something is symbolic, @@ -93,7 +130,7 @@ class C10_API SymInt { static constexpr uint64_t MAX_SYM_IDX = 1ULL << 62; // Since 0b10... is reserved for symbolic indices, any integers lower than // this value would collide with our representation. - static constexpr int64_t MIN_INT = -1LL & ~(1ULL << 62); + static constexpr int64_t MIN_INT = -1LL & static_cast(~(1ULL << 62)); int64_t data_; }; diff --git a/c10/core/SymIntArrayRef.cpp b/c10/core/SymIntArrayRef.cpp index 73546babf2920..44a419f4a9f45 100644 --- a/c10/core/SymIntArrayRef.cpp +++ b/c10/core/SymIntArrayRef.cpp @@ -1,26 +1,35 @@ #include +#include #include namespace c10 { at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar) { + auto r = asIntArrayRefSlowOpt(ar); + TORCH_CHECK( + r.has_value(), + "SymIntArrayRef expected to contain only concrete integers"); + return *r; +} + +c10::optional asIntArrayRefSlowOpt(c10::SymIntArrayRef ar) { for (c10::SymInt sci : ar) { - TORCH_CHECK(!sci.is_symbolic()); + if (sci.is_symbolic()) { + return c10::nullopt; + } } - return asIntArrayRefUnchecked(ar); + + return {asIntArrayRefUnchecked(ar)}; } at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) { return IntArrayRef(reinterpret_cast(ar.data()), ar.size()); } +// TODO: this print is bad std::ostream& operator<<(std::ostream& os, SymInt s) { - os << "SymInt(" << s.data() << ")"; + os << "SymInt(" << s.as_int_unchecked() << ")"; return os; } -std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list) { - return out << list.wrapped_symint_array_ref; -} - } // namespace c10 diff --git a/c10/core/SymIntArrayRef.h b/c10/core/SymIntArrayRef.h index 96f136c4afea9..bf2eb65c55366 100644 --- a/c10/core/SymIntArrayRef.h +++ b/c10/core/SymIntArrayRef.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -62,6 +63,11 @@ class SymIntArrayRef final { size_t length) : wrapped_symint_array_ref(data, length) {} + template + /* implicit */ SymIntArrayRef( + const SmallVectorTemplateCommon& Vec) + : wrapped_symint_array_ref(Vec) {} + /// Construct an SymIntArrayRef from a range. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef( const c10::SymInt* begin, @@ -189,7 +195,13 @@ class SymIntArrayRef final { TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar); TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar); - -std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list); +TORCH_API c10::optional asIntArrayRefSlowOpt( + c10::SymIntArrayRef ar); + +inline std::ostream& operator<<( + std::ostream& out, + const c10::SymIntArrayRef& list) { + return out << list.wrapped_symint_array_ref; +} } // namespace c10 diff --git a/c10/core/SymIntNodeImpl.cpp b/c10/core/SymIntNodeImpl.cpp new file mode 100644 index 0000000000000..483110a90fa64 --- /dev/null +++ b/c10/core/SymIntNodeImpl.cpp @@ -0,0 +1,11 @@ +#include +#include + +namespace c10 { + +c10::SymInt SymIntNodeImpl::toSymInt() { + auto sit_sp = SymIntNode::reclaim_copy(this); + return SymInt::toSymInt(sit_sp); +} + +} // namespace c10 diff --git a/c10/core/SymIntNodeImpl.h b/c10/core/SymIntNodeImpl.h new file mode 100644 index 0000000000000..e5ffd2d5ef6a3 --- /dev/null +++ b/c10/core/SymIntNodeImpl.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +class SymInt; +class SymIntNodeImpl; +using SymIntNode = c10::intrusive_ptr; + +class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { + public: + c10::SymInt toSymInt(); + virtual ~SymIntNodeImpl(){}; + + template + c10::intrusive_ptr dyn_cast() const { + return c10::intrusive_ptr::reclaim_copy(dynamic_cast(this)); + } + + // these could be pure virtual when we implement LTC versions + virtual SymIntNode add(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode sub(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode mul(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode div(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode mod(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode eq(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode ne(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode gt(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode lt(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode le(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode ge(const SymIntNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymIntNode wrap(int64_t num) { + TORCH_CHECK(false, "NYI"); + }; + virtual bool bool_() { + TORCH_CHECK(false, "NYI"); + }; + virtual int64_t int_() { + TORCH_CHECK(false, "NYI"); + } + virtual std::string str() { + TORCH_CHECK(false, "NYI"); + }; + std::ostream& operator<<(std::ostream& os) { + os << str(); + return os; + }; +}; + +} // namespace c10 diff --git a/c10/core/SymIntTable.cpp b/c10/core/SymIntTable.cpp deleted file mode 100644 index 40f578bdf2f73..0000000000000 --- a/c10/core/SymIntTable.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -namespace c10 { - -uint64_t SymIntTable::addNode(std::shared_ptr sin) { - std::lock_guard lock(mutex_); - auto index = nodes_.size(); - nodes_.push_back(sin); - return index; -} -std::shared_ptr SymIntTable::getNode(size_t index) { - std::lock_guard lock(mutex_); - TORCH_CHECK(index < nodes_.size()); - return nodes_[index]; -} - -c10::SymInt SymbolicIntNode::toSymInt() { - // We will need to figure out a way - // to dedup nodes - auto sit_sp = this->shared_from_this(); - return SymInt::toSymInt(sit_sp); -} - -SymIntTable& getSymIntTable() { - static SymIntTable sit; - return sit; -} - -} // namespace c10 diff --git a/c10/core/SymbolicIntNode.h b/c10/core/SymbolicIntNode.h deleted file mode 100644 index 5cc3cd324257b..0000000000000 --- a/c10/core/SymbolicIntNode.h +++ /dev/null @@ -1,79 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -namespace c10 { - -class C10_API SymbolicIntNode - : public std::enable_shared_from_this { - public: - c10::SymInt toSymInt(); - virtual ~SymbolicIntNode(){}; - // these could be pure virtual when we implement LTC versions - virtual std::shared_ptr add( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr sub( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr mul( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr div( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr mod( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr eq( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr gt( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr lt( - const std::shared_ptr& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual std::shared_ptr wrap(int64_t num) { - TORCH_CHECK(false, "NYI"); - }; - virtual bool bool_() { - TORCH_CHECK(false, "NYI"); - }; - virtual int64_t int_() { - TORCH_CHECK(false, "NYI"); - } - virtual std::string str() { - TORCH_CHECK(false, "NYI"); - }; - std::ostream& operator<<(std::ostream& os) { - os << str(); - return os; - }; -}; - -class C10_API SymIntTable { - public: - uint64_t addNode(std::shared_ptr sin); - std::shared_ptr getNode(size_t index); - - private: - std::vector> nodes_; - std::mutex mutex_; -}; - -C10_API SymIntTable& getSymIntTable(); - -} // namespace c10 diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 7c85803b83bd1..e2d8e9684e6f9 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -151,6 +151,8 @@ TensorImpl::TensorImpl( C10_LOG_API_USAGE_ONCE("tensor.create"); } + // XXX: if updating keyset logic here also update + // _change_backend_component_keys bool inference_mode = c10::InferenceMode::is_enabled(); // TODO: be more explicit about the full key set at call sites so we @@ -184,6 +186,25 @@ TensorImpl::TensorImpl( // Caffe2 operators create Storages with default devices. } +void TensorImpl::_change_backend_component_keys(c10::Device device) { + BackendComponent new_backend = toBackendComponent(device.type()); + BackendComponent old_backend = key_set_.highestBackendKey(); + + // following logic TensorImpl::TensorImpl, update the BackendComponent related + // keys to correspond to device + + // TODO: Autocoast should be a per-backend functionality key, once that change + // is made this key swap will not be necessary. + auto key_set = + key_set_ - c10::getAutocastRelatedKeySetFromBackend(old_backend); + key_set = key_set | c10::getAutocastRelatedKeySetFromBackend(new_backend); + + // See note [Removing keys from DispatchKeySet Only Affects Functionality + // Keys] + key_set = key_set.remove_backend(old_backend); + key_set_ = key_set | DispatchKeySet(new_backend); +} + void TensorImpl::HandleResize() { // If needed, we will free the data. the next mutable_data() call // will create the data storage. @@ -390,12 +411,12 @@ IntArrayRef TensorImpl::sizes_custom() const { TORCH_CHECK( false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes"); } + c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const { - TORCH_CHECK( - false, - "Tensors of type ", - tensorimpl_type_name(), - " do not have sym sizes"); + if (C10_UNLIKELY(is_python_dispatch())) { + return load_pyobj_interpreter()->sym_sizes(this); + } + return sym_sizes_default(); } c10::Device TensorImpl::device_custom() const { @@ -416,6 +437,7 @@ IntArrayRef TensorImpl::strides_custom() const { tensorimpl_type_name(), " do not have strides"); } + int64_t TensorImpl::dim_custom() const { if (is_python_dispatch()) { return load_pyobj_interpreter()->dim(this); @@ -423,11 +445,20 @@ int64_t TensorImpl::dim_custom() const { TORCH_CHECK( false, "Tensors of type ", tensorimpl_type_name(), " do not have dim"); } + int64_t TensorImpl::numel_custom() const { TORCH_CHECK( false, "Tensors of type ", tensorimpl_type_name(), " do not have numel"); } +c10::Layout TensorImpl::layout_custom() const { + if (is_python_dispatch()) { + return load_pyobj_interpreter()->layout(this); + } + TORCH_CHECK( + false, "Tensors of type ", tensorimpl_type_name(), " do not have layout"); +} + static void deletePlacementDeleteContext(void* ptr) { delete static_cast(ptr); } @@ -516,8 +547,13 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach_core( /*dest_impl=*/impl.get(), /*version_counter=*/std::forward(version_counter), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); - impl->refresh_numel(); - impl->refresh_contiguous(); + + // We currently don't support refresh_numel() and refresh_contiguous(). It's + // plausible that we could support it, but currently done to unblock. + if (!has_symbolic_sizes_strides()) { + impl->refresh_numel(); + impl->refresh_contiguous(); + } return impl; } @@ -550,6 +586,9 @@ void TensorImpl::copy_generic_tensor_metadata( const TensorImpl* src_impl, TensorImpl* dest_impl) { dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_; + dest_impl->has_symbolic_sizes_strides_ = + src_impl->has_symbolic_sizes_strides_; + dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index c6912dbb234c7..a2ffa3123b083 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -552,10 +552,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return sizes_default(); } + // TODO: make it non-virtual after a change to XLA virtual c10::SymIntArrayRef sym_sizes() const { + if (C10_UNLIKELY( + sizes_strides_policy_ >= + static_cast(SizesStridesPolicy::CustomSizes))) { + return sym_sizes_custom(); + } return sym_sizes_default(); } + virtual c10::SymIntArrayRef sym_sizes_custom() const; + /** * Return a reference to the strides of this tensor. This reference remains * valid as long as the tensor is live and not restrided. @@ -577,15 +585,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * be faster */ int64_t size(int64_t d) const { - d = maybe_wrap_dim(d, dim(), false); if (C10_UNLIKELY( sizes_strides_policy_ >= static_cast(SizesStridesPolicy::CustomSizes))) { - return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + return size_custom(d); } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); return sizes_and_strides_.size_at_unchecked(d).as_int_unchecked(); } + c10::SymInt sym_size(int64_t d) const { + if (C10_UNLIKELY( + sizes_strides_policy_ >= + static_cast(SizesStridesPolicy::CustomSizes))) { + return sym_size_custom(d); + } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + const auto sizes = this->sym_sizes(); + return sizes[d]; + } + /** * Return the stride of a tensor at some dimension, wrapping the dimension * if necessary. @@ -662,6 +681,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { sizes_and_strides_.size()); } + inline c10::SymIntArrayRef sym_sizes_default() const { + return c10::SymIntArrayRef( + reinterpret_cast(sizes_and_strides_.sizes_data()), + sizes_and_strides_.size()); + } + protected: /** * Customization points for the functions above. sizes_strides_policy_ @@ -674,9 +699,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual IntArrayRef strides_custom() const; virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; // sizes_strides_policy_ >= CustomSizes + // Currently this method only exists to be overwritten by subclasses such as + // NestedTensorImpl. + virtual int64_t size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + + virtual c10::SymInt sym_size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sym_sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + virtual IntArrayRef sizes_custom() const; - virtual c10::SymIntArrayRef sym_sizes_custom() const; virtual Device device_custom() const; + virtual Layout layout_custom() const; virtual int64_t dim_custom() const; virtual int64_t numel_custom() const; @@ -692,11 +735,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } return is_contiguous_; } - inline c10::SymIntArrayRef sym_sizes_default() const { - return c10::SymIntArrayRef( - reinterpret_cast(sizes_and_strides_.sizes_data()), - sizes_and_strides_.size()); - } inline int64_t dim_default() const { return sizes_and_strides_.size(); } @@ -799,8 +837,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_meta(); } - constexpr auto meta_ks = DispatchKeySet(BackendComponent::MetaBit); - return key_set_.has_all(meta_ks); + return device_opt_.has_value() && device_opt_->type() == kMeta; } bool is_cpu() const { @@ -809,9 +846,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_cpu(); } - constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) | - DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU}); - return key_set_.has_any(cpu_bits_ks); + // Note: we cannot rely on dispatch keys to determine the device type + // of a tensor, because "wrapper" tensors (like FunctionalTensorWrapper) + // don't include backend dispatch keys. + return device_opt_.has_value() && device_opt_->type() == kCPU; } bool is_cuda() const { @@ -820,9 +858,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_cuda(); } - constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) | - DispatchKeySet(DispatchKey::SparseCsrCUDA); - return key_set_.has_any(cuda_bits_ks); + return device_opt_.has_value() && device_opt_->type() == kCUDA; } bool is_xpu() const { @@ -831,40 +867,35 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_xpu(); } - constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit); - return key_set_.has_all(xpu_ks); + return device_opt_.has_value() && device_opt_->type() == kXPU; } bool is_ipu() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_ipu(); } - constexpr auto ipu_ks = DispatchKeySet(BackendComponent::IPUBit); - return key_set_.has_all(ipu_ks); + return device_opt_.has_value() && device_opt_->type() == kIPU; } bool is_xla() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_xla(); } - constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit); - return key_set_.has_all(xla_ks); + return device_opt_.has_value() && device_opt_->type() == kXLA; } bool is_hpu() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_hpu(); } - constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit); - return key_set_.has_all(hpu_ks); + return device_opt_.has_value() && device_opt_->type() == kHPU; } bool is_lazy() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_lazy(); } - constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit); - return key_set_.has_all(lazy_ks); + return device_opt_.has_value() && device_opt_->type() == kLazy; } bool is_hip() const { @@ -873,8 +904,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_hip(); } - constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit); - return key_set_.has_all(hip_ks); + return device_opt_.has_value() && device_opt_->type() == kHIP; } bool is_ve() const { @@ -883,8 +913,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_ve(); } - constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit); - return key_set_.has_all(ve_ks); + return device_opt_.has_value() && device_opt_->type() == kVE; } bool is_mkldnn() const { @@ -895,31 +924,28 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_vulkan(); } - constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan); - return key_set_.has_all(vulkan_ks); + return device_opt_.has_value() && device_opt_->type() == kVulkan; } bool is_metal() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_metal(); } - constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal); - return key_set_.has_all(metal_ks); + return device_opt_.has_value() && device_opt_->type() == kMetal; } bool is_mps() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_mps(); } - return key_set_.has(DispatchKey::MPS); + return device_opt_.has_value() && device_opt_->type() == kMPS; } bool is_ort() const { if (C10_UNLIKELY(custom_device_)) { return device_custom().is_ort(); } - constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT); - return key_set_.has_all(ort_ks); + return device_opt_.has_value() && device_opt_->type() == kORT; } bool is_nested() const { @@ -962,6 +988,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } Layout layout() const { + if (C10_UNLIKELY(custom_layout_)) { + return layout_custom(); + } + // NB: This method is not virtual and avoid dispatches for perf. // strided is also the most common layout type, so we check for // strided case first. @@ -1091,6 +1121,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + /** + * XXX: do not use, private api! + * Update the backend component related keys to the backend component + * corresponding to this device. + */ + void _change_backend_component_keys(c10::Device device); + /** * Whether or not the tensor is a zerotensor */ @@ -1738,7 +1775,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // // NB: this lives in header so that we can avoid actually creating the // c10::optional - c10::optional check_pyobj(impl::PyInterpreter* self_interpreter) { + c10::optional check_pyobj( + impl::PyInterpreter* self_interpreter) const { // Note [Memory ordering on Python interpreter tag] impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); @@ -2069,6 +2107,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_non_overlapping_and_dense_; } + bool has_symbolic_sizes_strides() const { + return has_symbolic_sizes_strides_; + } + private: void HandleResize(); @@ -2338,6 +2380,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { custom_device_ = custom_device; } + void set_custom_layout(bool custom_layout) { + custom_layout_ = custom_layout; + } + protected: Storage storage_; @@ -2457,6 +2503,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { reserved_ = false; sizes_strides_policy_ = static_cast(SizesStridesPolicy::Default); custom_device_ = false; + custom_layout_ = false; storage_access_should_throw_ = false; has_symbolic_sizes_strides_ = false; } @@ -2527,6 +2574,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // Call _custom() virtual method for device() bool custom_device_ : 1; + // Call _custom() virtual method for layout() + bool custom_layout_ : 1; + // The set of DispatchKeys which describe this tensor. NB: this // does NOT include Autograd (historically, it did, but // not anymore!) diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index d0131ce5fa130..432fe4f1e4b6c 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -631,27 +631,23 @@ inline DispatchKey computeDispatchKey( case Layout::Strided: { const auto dtype_ = dtype_or_default(dtype); switch (device_.type()) { - case DeviceType::CPU: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedCPU; - } - return DispatchKey::CPU; - } - case DeviceType::CUDA: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedCUDA; - } - return DispatchKey::CUDA; - } - case DeviceType::IPU: { - return DispatchKey::IPU; - } - case DeviceType::XPU: { - if (isQIntType(dtype_)) { - return DispatchKey::QuantizedXPU; - } - return DispatchKey::XPU; - } +#define DO_CASE(device, _) \ + case DeviceType::device: { \ + if (isQIntType(dtype_)) { \ + return DispatchKey::Quantized##device; \ + } \ + return DispatchKey::device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + case DeviceType::FPGA: + return DispatchKey::FPGA; + case DeviceType::ORT: + return DispatchKey::ORT; + case DeviceType::Vulkan: + return DispatchKey::Vulkan; + case DeviceType::Metal: + return DispatchKey::Metal; case DeviceType::MKLDNN: case DeviceType::OPENGL: case DeviceType::OPENCL: @@ -661,31 +657,6 @@ inline DispatchKey computeDispatchKey( "This is a grandfathered Caffe2 device type ", device_.type(), ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error."); - case DeviceType::HIP: - return DispatchKey::HIP; - case DeviceType::VE: - return DispatchKey::VE; - case DeviceType::FPGA: - return DispatchKey::FPGA; - case DeviceType::ORT: - return DispatchKey::ORT; - case DeviceType::XLA: - return DispatchKey::XLA; - case DeviceType::Lazy: - return DispatchKey::Lazy; - case DeviceType::MPS: - return DispatchKey::MPS; - case DeviceType::Vulkan: - return DispatchKey::Vulkan; - case DeviceType::Metal: - return DispatchKey::Metal; - case DeviceType::Meta: - return DispatchKey::Meta; - case DeviceType::HPU: - return DispatchKey::HPU; - case DeviceType::PrivateUse1: { - return DispatchKey::PrivateUse1; - } default: TORCH_CHECK_NOT_IMPLEMENTED( false, @@ -695,16 +666,12 @@ inline DispatchKey computeDispatchKey( } case Layout::Sparse: switch (device_.type()) { - case DeviceType::CPU: - return DispatchKey::SparseCPU; - case DeviceType::CUDA: - return DispatchKey::SparseCUDA; - case DeviceType::HIP: - return DispatchKey::SparseHIP; - case DeviceType::VE: - return DispatchKey::SparseVE; - case DeviceType::XPU: - return DispatchKey::SparseXPU; +#define DO_CASE(device, _) \ + case DeviceType::device: { \ + return DispatchKey::Sparse##device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE default: TORCH_CHECK_NOT_IMPLEMENTED( false, @@ -744,12 +711,10 @@ inline DispatchKey computeDispatchKey( inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { switch (dispatch_key) { - case DispatchKey::SparseCPU: - case DispatchKey::SparseCUDA: - case DispatchKey::SparseHIP: - case DispatchKey::SparseVE: - case DispatchKey::SparseXPU: - return Layout::Sparse; +#define DO_CASE(bc, _) case DispatchKey::Sparse##bc: + C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused) +#undef DO_CASE + return Layout::Sparse; case DispatchKey::SparseCsrCPU: case DispatchKey::SparseCsrCUDA: TORCH_CHECK( @@ -767,53 +732,21 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { switch (dispatch_key) { // stuff that's real - case DispatchKey::CPU: - case DispatchKey::SparseCPU: +#define DO_CASE(suffix, prefix) \ + case DispatchKey::prefix##suffix: \ + return DeviceType::suffix; +#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix) + C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES) +#undef DO_CASES +#undef DO_CASE + case DispatchKey::MkldnnCPU: - case DispatchKey::QuantizedCPU: - case DispatchKey::AutogradCPU: return DeviceType::CPU; - case DispatchKey::CUDA: - case DispatchKey::SparseCUDA: - case DispatchKey::QuantizedCUDA: - case DispatchKey::AutogradCUDA: - return DeviceType::CUDA; - case DispatchKey::HIP: - case DispatchKey::SparseHIP: - return DeviceType::HIP; - case DispatchKey::VE: - case DispatchKey::SparseVE: - return DeviceType::VE; - case DispatchKey::XLA: - case DispatchKey::AutogradXLA: - return DeviceType::XLA; - case DispatchKey::Lazy: - case DispatchKey::AutogradLazy: - return DeviceType::Lazy; case DispatchKey::Vulkan: return DeviceType::Vulkan; - case DispatchKey::Meta: - return DeviceType::Meta; - - // stuff that people are actively developing - case DispatchKey::IPU: - case DispatchKey::AutogradIPU: - return DeviceType::IPU; - case DispatchKey::XPU: - case DispatchKey::SparseXPU: - case DispatchKey::QuantizedXPU: - case DispatchKey::AutogradXPU: - return DeviceType::XPU; - case DispatchKey::MPS: - case DispatchKey::AutogradMPS: - return DeviceType::MPS; - case DispatchKey::HPU: - case DispatchKey::AutogradHPU: - return DeviceType::HPU; + case DispatchKey::ORT: return DeviceType::ORT; - case DispatchKey::PrivateUse1: - return DeviceType::PrivateUse1; default: TORCH_CHECK( false, diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index e951fe76562ed..eec1d23e66da1 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -23,8 +24,7 @@ static c10::intrusive_ptr noop_detach_fn( static void noop_dispatch_fn( const PyInterpreter*, const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) { + torch::jit::Stack* stack) { TORCH_INTERNAL_ASSERT( 0, "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died"); @@ -62,6 +62,20 @@ static c10::IntArrayRef noop_sizes_fn(const PyInterpreter*, const TensorImpl*) { "attempted to call `sizes` on Tensor with nontrivial PyObject after corresponding interpreter died"); } +static c10::SymIntArrayRef noop_sym_sizes_fn( + const PyInterpreter*, + const TensorImpl*) { + TORCH_INTERNAL_ASSERT( + 0, + "attempted to call `sym_sizes` on Tensor with nontrivial PyObject after corresponding interpreter died"); +} + +static c10::Layout noop_layout_fn(const PyInterpreter*, const TensorImpl*) { + TORCH_INTERNAL_ASSERT( + 0, + "attempted to call `layout` on Tensor with nontrivial PyObject after corresponding interpreter died"); +} + void PyInterpreter::disarm() noexcept { name_fn_ = &noop_name_fn; decref_fn_ = &noop_decref_fn; @@ -72,6 +86,8 @@ void PyInterpreter::disarm() noexcept { dim_fn_ = &noop_dim_fn; strides_fn_ = &noop_strides_fn; sizes_fn_ = &noop_sizes_fn; + sym_sizes_fn_ = &noop_sym_sizes_fn; + layout_fn_ = &noop_layout_fn; } } // namespace impl diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 78182a4ddb61c..db3d9753b9dc6 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include #include @@ -125,14 +127,15 @@ struct C10_API PyInterpreter { using dispatch_sig = void( const PyInterpreter*, const c10::OperatorHandle&, - torch::jit::Stack* stack, - // This is a Tensor subclass type object - const std::shared_ptr& type); + torch::jit::Stack* stack); using is_contiguous_sig = bool(const PyInterpreter*, const TensorImpl*); using device_sig = c10::Device(const PyInterpreter*, const TensorImpl*); using dim_sig = int64_t(const PyInterpreter*, const TensorImpl*); using strides_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*); using sizes_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*); + using sym_sizes_sig = + c10::SymIntArrayRef(const PyInterpreter*, const TensorImpl*); + using layout_sig = c10::Layout(const PyInterpreter*, const TensorImpl*); PyInterpreter( name_sig* name_fn, @@ -143,7 +146,9 @@ struct C10_API PyInterpreter { device_sig* device_fn, dim_sig* dim_fn, strides_sig* strides, - sizes_sig* sizes) + sizes_sig* sizes, + sym_sizes_sig* sym_sizes, + layout_sig* layout) : name_fn_(name_fn), decref_fn_(decref_fn), detach_fn_(detach), @@ -152,7 +157,9 @@ struct C10_API PyInterpreter { device_fn_(device_fn), dim_fn_(dim_fn), strides_fn_(strides), - sizes_fn_(sizes) {} + sizes_fn_(sizes), + sym_sizes_fn_(sym_sizes), + layout_fn_(layout) {} name_sig* name_fn_; decref_sig* decref_fn_; @@ -163,6 +170,8 @@ struct C10_API PyInterpreter { dim_sig* dim_fn_; strides_sig* strides_fn_; sizes_sig* sizes_fn_; + sym_sizes_sig* sym_sizes_fn_; + layout_sig* layout_fn_; // UBSAN suppression fixes: "call to function // (anonymous namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*, @@ -192,9 +201,8 @@ struct C10_API PyInterpreter { // Invoke the Python boxed fallback dispatch to go back into Python __ubsan_ignore_function__ void dispatch( const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) const { - return (*dispatch_fn_)(this, op, stack, type); + torch::jit::Stack* stack) const { + return (*dispatch_fn_)(this, op, stack); } __ubsan_ignore_function__ bool is_contiguous(const TensorImpl* self) const { @@ -219,6 +227,15 @@ struct C10_API PyInterpreter { return (*sizes_fn_)(this, self); } + __ubsan_ignore_function__ c10::SymIntArrayRef sym_sizes( + const TensorImpl* self) const { + return (*sym_sizes_fn_)(this, self); + } + + __ubsan_ignore_function__ c10::Layout layout(const TensorImpl* self) const { + return (*layout_fn_)(this, self); + } + // Disarm this PyInterpreter, making all of its methods noops. // Because the function pointers are raw pointers (not atomics), // a disarm() invocation that is concurrent with active destructors diff --git a/c10/core/impl/SizesAndStrides.cpp b/c10/core/impl/SizesAndStrides.cpp index a4e2f17464b6e..b46725f8bf191 100644 --- a/c10/core/impl/SizesAndStrides.cpp +++ b/c10/core/impl/SizesAndStrides.cpp @@ -11,41 +11,34 @@ void SizesAndStrides::resizeSlowPath( !isInline(), "resizeSlowPath called when fast path should have been hit!"); SymInt* tempStorage = outOfLineStorage_; - memcpy( - &inlineStorage_[0], - &tempStorage[0], - C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); - memcpy( - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], - &tempStorage[oldSize], - C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + for (size_t i = 0; i < C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; i++) { + inlineStorage_[i] = std::move(tempStorage[i]); + inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + i] = + std::move(tempStorage[oldSize + i]); + } // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ // HAS BEEN OVERWRITTEN! // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - free(tempStorage); + delete[] tempStorage; } else { if (isInline()) { // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD // OVERWRITE inlineStorage_! // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - SymInt* tempStorage = static_cast(malloc(storageBytes(newSize))); + SymInt* tempStorage = new SymInt[storageElems(newSize)]; TORCH_CHECK( tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); - const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); - const auto bytesToZero = (newSize > oldSize) - ? (newSize - oldSize) * sizeof(tempStorage[0]) - : 0; - memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); - if (bytesToZero) { - memset(&tempStorage[oldSize], 0, bytesToZero); + const auto elemsToCopy = oldSize; + const auto elemsToZero = (newSize > oldSize) ? (newSize - oldSize) : 0; + for (size_t i = 0; i < elemsToCopy; i++) { + tempStorage[i] = std::move(inlineStorage_[i]); + tempStorage[newSize + i] = std::move( + inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + i]); } - memcpy( - &tempStorage[newSize], - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], - bytesToCopy); - if (bytesToZero) { - memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + for (size_t i = 0; i < elemsToZero; i++) { + tempStorage[oldSize + i] = 0; + tempStorage[newSize + oldSize + i] = 0; } outOfLineStorage_ = tempStorage; } else { @@ -57,19 +50,27 @@ void SizesAndStrides::resizeSlowPath( // Shift the old strides to their new starting point. Note // that this does not occur in the inline path above because // the stride starting point is not moving. - memmove( - outOfLineStorage_ + newSize, - outOfLineStorage_ + oldSize, - std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (isGrowing) { + std::move_backward( + outOfLineStorage_ + oldSize, + outOfLineStorage_ + oldSize + oldSize, + outOfLineStorage_ + newSize + oldSize); + } else { + std::move( + outOfLineStorage_ + oldSize, + outOfLineStorage_ + oldSize + newSize, + outOfLineStorage_ + newSize); + } if (!isGrowing) { // Resize after shifting so that we don't lose data. resizeOutOfLineStorage(newSize); } else { // Zero the end of the sizes portion. - const auto bytesToZero = - (newSize - oldSize) * sizeof(outOfLineStorage_[0]); - memset(&outOfLineStorage_[oldSize], 0, bytesToZero); - memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + const auto elemsToZero = newSize - oldSize; + for (size_t i = 0; i < elemsToZero; i++) { + outOfLineStorage_[oldSize + i] = 0; + outOfLineStorage_[newSize + oldSize + i] = 0; + } } } } diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index 56f5398b6bc5c..603a3296cd9f7 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -39,7 +39,7 @@ class C10_API SizesAndStrides { ~SizesAndStrides() { if (C10_UNLIKELY(!isInline())) { - free(outOfLineStorage_); + delete[] outOfLineStorage_; } } @@ -58,7 +58,7 @@ class C10_API SizesAndStrides { } if (C10_LIKELY(rhs.isInline())) { if (C10_UNLIKELY(!isInline())) { - free(outOfLineStorage_); + delete[] outOfLineStorage_; } copyDataInline(rhs); } else { @@ -76,7 +76,9 @@ class C10_API SizesAndStrides { // Move from rhs. rhs.size() == 0 afterwards. SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { if (C10_LIKELY(isInline())) { - memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + for (size_t i = 0; i < sizeof(inlineStorage_) / sizeof(SymInt); i++) { + inlineStorage_[i] = std::move(rhs.inlineStorage_[i]); + } } else { outOfLineStorage_ = rhs.outOfLineStorage_; rhs.outOfLineStorage_ = nullptr; @@ -92,13 +94,13 @@ class C10_API SizesAndStrides { } if (C10_LIKELY(rhs.isInline())) { if (C10_UNLIKELY(!isInline())) { - free(outOfLineStorage_); + delete[] outOfLineStorage_; } copyDataInline(rhs); } else { // They're outline. We're going to steal their vector. if (!isInline()) { - free(outOfLineStorage_); + delete[] outOfLineStorage_; } outOfLineStorage_ = rhs.outOfLineStorage_; rhs.outOfLineStorage_ = nullptr; @@ -269,13 +271,10 @@ class C10_API SizesAndStrides { if (C10_LIKELY( newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { if (oldSize < newSize) { - const auto bytesToZero = - (newSize - oldSize) * sizeof(inlineStorage_[0]); - memset(&inlineStorage_[oldSize], 0, bytesToZero); - memset( - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], - 0, - bytesToZero); + for (size_t i = oldSize; i < newSize; i++) { + inlineStorage_[i] = 0; + inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + i] = 0; + } } size_ = newSize; } else { @@ -292,15 +291,17 @@ class C10_API SizesAndStrides { void copyDataInline(const SizesAndStrides& rhs) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); - memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + for (size_t i = 0; i < sizeof(inlineStorage_) / sizeof(SymInt); i++) { + inlineStorage_[i] = rhs.inlineStorage_[i]; + } } - static size_t storageBytes(size_t size) noexcept { - return size * 2 * sizeof(int64_t); + static size_t storageElems(size_t size) noexcept { + return size * 2; } void allocateOutOfLineStorage(size_t size) { - outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + outOfLineStorage_ = new SymInt[storageElems(size)]; TORCH_CHECK( outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); @@ -308,15 +309,21 @@ class C10_API SizesAndStrides { void resizeOutOfLineStorage(size_t newSize) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); - outOfLineStorage_ = - static_cast(realloc(outOfLineStorage_, storageBytes(newSize))); + auto* newStorage = new SymInt[storageElems(newSize)]; TORCH_CHECK( - outOfLineStorage_, - "Could not allocate memory for Tensor SizesAndStrides!"); + newStorage, "Could not allocate memory for Tensor SizesAndStrides!"); + for (size_t i = 0; i < storageElems(newSize) && i < storageElems(size_); + i++) { + newStorage[i] = std::move(outOfLineStorage_[i]); + } + delete[] outOfLineStorage_; + outOfLineStorage_ = newStorage; } void copyDataOutline(const SizesAndStrides& rhs) noexcept { - memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + for (size_t i = 0; i < storageElems(rhs.size_); i++) { + outOfLineStorage_[i] = rhs.outOfLineStorage_[i]; + } } size_t size_; diff --git a/c10/cuda/CUDAAlgorithm.h b/c10/cuda/CUDAAlgorithm.h new file mode 100644 index 0000000000000..166c5264e5bdf --- /dev/null +++ b/c10/cuda/CUDAAlgorithm.h @@ -0,0 +1,33 @@ +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +#include +#include +#include +#include +#endif +namespace c10 { +namespace cuda { +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +template +__forceinline__ __device__ Iter +lower_bound(Iter start, Iter end, Scalar value) { + return thrust::lower_bound(thrust::device, start, end, value); +} +#else +// thrust::lower_bound is broken on device, see +// https://github.com/NVIDIA/thrust/issues/1734 Implementation inspired by +// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28 +template +__device__ Iter lower_bound(Iter start, Iter end, Scalar value) { + while (start < end) { + auto mid = start + ((end - start) >> 1); + if (*mid < value) { + start = mid + 1; + } else { + end = mid; + } + } + return end; +} +#endif // THRUST_DEVICE_LOWER_BOUND_WORKS +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index d23eae8e261ea..d60f6960e9f91 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -611,6 +611,13 @@ class DeviceCachingAllocator { stats.num_ooms += 1; + c10::reportOutOfMemoryToProfiler( + size, + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current, + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current, + c10::Device(c10::DeviceType::CUDA, static_cast(device))); // "total capacity": total global memory on GPU // "allowed": memory is allowed to use, which set by fraction. // "already allocated": memory allocated by the program using the diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 36536674a12ea..b7fc04b50a8c2 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -16,7 +17,7 @@ namespace cuda { namespace { // Global stream state and constants -static std::once_flag init_flag; +static c10::once_flag init_flag; static DeviceIndex num_gpus = -1; static constexpr int kStreamsPerPoolBits = 5; static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; @@ -40,7 +41,7 @@ static constexpr int kLowPriority = 0; // already been destroyed and thus invoking cudaStreamDestroy could lead to a // crash. It's likely an issue in CUDA, but to be safe - let's just "forget" // the destruction. -static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS]; +static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS]; static std::atomic low_priority_counters[C10_COMPILE_TIME_MAX_GPUS]; static std::atomic high_priority_counters[C10_COMPILE_TIME_MAX_GPUS]; static cudaStream_t low_priority_streams[C10_COMPILE_TIME_MAX_GPUS] @@ -173,7 +174,7 @@ static void initDeviceStreamState(DeviceIndex device_index) { // Init front-end to ensure initialization only occurs once static void initCUDAStreamsOnce() { // Inits default streams (once, globally) - std::call_once(init_flag, initGlobalStreamState); + c10::call_once(init_flag, initGlobalStreamState); if (current_streams) { return; @@ -256,7 +257,7 @@ CUDAStream getStreamFromPool( check_gpu(device_index); // Initializes the stream pools (once) - std::call_once( + c10::call_once( device_flags[device_index], initDeviceStreamState, device_index); if (isHighPriority) { diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp index 5f7e0193693ea..997efde4a9ce6 100644 --- a/c10/test/core/DispatchKeySet_test.cpp +++ b/c10/test/core/DispatchKeySet_test.cpp @@ -182,7 +182,7 @@ TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) { if (tid == DispatchKey::StartOfDenseBackends || tid == DispatchKey::StartOfSparseBackends || tid == DispatchKey::StartOfQuantizedBackends || - tid == DispatchKey::StartOfAutogradBackends) { + tid == DispatchKey::StartOfAutogradFunctionalityBackends) { continue; } DispatchKeySet sing(tid); @@ -225,13 +225,13 @@ TEST(DispatchKeySet, DoubletonPerBackend) { tid1 == DispatchKey::StartOfSparseBackends || tid1 == DispatchKey::StartOfQuantizedBackends || tid1 == DispatchKey::StartOfNestedTensorBackends || - tid1 == DispatchKey::StartOfAutogradBackends) + tid1 == DispatchKey::StartOfAutogradFunctionalityBackends) continue; if (tid2 == DispatchKey::StartOfDenseBackends || tid2 == DispatchKey::StartOfSparseBackends || tid2 == DispatchKey::StartOfQuantizedBackends || tid2 == DispatchKey::StartOfNestedTensorBackends || - tid2 == DispatchKey::StartOfAutogradBackends) + tid2 == DispatchKey::StartOfAutogradFunctionalityBackends) continue; auto backend1 = toBackendComponent(tid1); @@ -386,41 +386,45 @@ TEST(DispatchKeySet, FailAtEndIterator) { c10::Error); } -TEST(DispatchKeySet, TestKeyOrderingInvariants) { - for (uint8_t i = static_cast(DispatchKey::StartOfDenseBackends); - i <= static_cast(DispatchKey::EndOfRuntimeBackendKeys); +TEST(DispatchKeySet, TestBackendComponentToString) { + std::unordered_set seen_strings; + for (int64_t i = 0; + i <= static_cast(BackendComponent::EndOfBackendKeys); i++) { + auto k = static_cast(i); + auto res = std::string(toString(k)); + ASSERT_FALSE(res == "UNKNOWN_BACKEND_BIT"); + ASSERT_FALSE(seen_strings.count(res) > 0); + seen_strings.insert(res); + } +} + +TEST(DispatchKeySet, TestEndOfRuntimeBackendKeysAccurate) { + DispatchKey k; +#define SETTER(fullname, prefix) k = DispatchKey::EndOf##fullname##Backends; + C10_FORALL_FUNCTIONALITY_KEYS(SETTER) +#undef SETTER + ASSERT_TRUE(k == DispatchKey::EndOfRuntimeBackendKeys); +} + +TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) { + std::unordered_set seen_strings; + for (int i = 0; i <= static_cast(DispatchKey::EndOfAliasKeys); i++) { auto k = static_cast(i); - // Note [The Ordering of Per-Backend Dispatch Keys Matters!] - // The DispatchKey enum includes all of the runtime keys for - // Dense/Sparse/Quantized/Autograd, (e.g. CPU, CUDA, SparseCPU, SparseCUDA, - // AutogradCPU, AutogradCUDA, etc). And we expect the ordering of those keys - // to be the same as the ordering of the backends in the `BackendComponent` - // enum. This makes several utilities in `DispatchKey.h` and - // `DispatchKeySet.h` significantly easier to implement. The purpose of the - // test is to assert (through CI) that this invariant is maintained. - // - // The only way that we can really check this invariant is by - // comparing the string names of each enum. - // We only really care about the ordering for "real" keys that are actually - // used, which we expect to be able to print properly. This saves us from - // having to enumerate the full set of possible runtime keys in - // DispatchKey::toString(). It also relies on toString() being implemented - // correctly. - auto functionality_str = std::string(toString(k)); - if (functionality_str == "UNKNOWN_TENSOR_TYPE_ID") + // These synthetic keys never actually get used and don't need + // to be printed + if (k == DispatchKey::EndOfFunctionalityKeys || + k == DispatchKey::StartOfDenseBackends || + k == DispatchKey::StartOfQuantizedBackends || + k == DispatchKey::StartOfSparseBackends || + k == DispatchKey::StartOfNestedTensorBackends || + k == DispatchKey::StartOfAutogradFunctionalityBackends) continue; - - auto computed_backend_k = toBackendComponent(k); - auto computed_backend_str = std::string(toString(computed_backend_k)); - // Skip, e.g., the "Bit" from "CPUBit" - computed_backend_str = - computed_backend_str.substr(0, computed_backend_str.size() - 3); - - ASSERT_TRUE( - functionality_str.find(computed_backend_str) != std::string::npos) - << "DispatchKey invariant broken! Found a key that is not ordered correctly" - << " with its backend bit. key = " << toString(k) << ", " << k - << ", computed backend = " << toString(computed_backend_k); + auto res = std::string(toString(k)); + ASSERT_TRUE(res.find("Unknown") == std::string::npos) + << i << " (before is " << toString(static_cast(i - 1)) + << ")"; + ASSERT_TRUE(seen_strings.count(res) == 0); + seen_strings.insert(res); } } diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index 8f43147ba273c..8892cce015daa 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include using namespace c10; @@ -9,7 +9,7 @@ void check(int64_t value) { EXPECT_TRUE(SymInt::check_range(value)); const auto i = SymInt(value); EXPECT_FALSE(i.is_symbolic()); - EXPECT_EQ(i.data(), value); + EXPECT_EQ(i.as_int_unchecked(), value); } TEST(SymIntTest, ConcreteInts) { @@ -21,7 +21,7 @@ TEST(SymIntTest, ConcreteInts) { } TEST(SymIntTest, AddNode) { - auto n = std::make_shared(); + auto n = c10::make_intrusive(); auto i = n->toSymInt(); EXPECT_TRUE(i.is_symbolic()); } diff --git a/c10/test/util/logging_test.cpp b/c10/test/util/logging_test.cpp index 8d5fd8dc00292..a521856509389 100644 --- a/c10/test/util/logging_test.cpp +++ b/c10/test/util/logging_test.cpp @@ -141,7 +141,7 @@ TEST(LoggingTest, Join) { TEST(LoggingTest, TestDanglingElse) { if (true) - DCHECK_EQ(1, 1); + TORCH_DCHECK_EQ(1, 1); else GTEST_FAIL(); } diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 8ef9be69bdedf..107042bf17542 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -36,6 +36,19 @@ */ namespace c10 { + +// in c++17 std::result_of has been superceded by std::invoke_result. Since +// c++20, std::result_of is removed. +template +#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L +using invoke_result = typename std::invoke_result; +#else +using invoke_result = typename std::result_of; +#endif + +template +using invoke_result_t = typename invoke_result::type; + namespace guts { template @@ -164,7 +177,7 @@ CUDA_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) { template typename std::enable_if< std::is_member_pointer::type>::value, - typename std::result_of::type>::type + typename c10::invoke_result_t>::type invoke(Functor&& f, Args&&... args) { return std::mem_fn(std::forward(f))(std::forward(args)...); } @@ -172,7 +185,7 @@ invoke(Functor&& f, Args&&... args) { template typename std::enable_if< !std::is_member_pointer::type>::value, - typename std::result_of::type>::type + typename c10::invoke_result_t>::type invoke(Functor&& f, Args&&... args) { return std::forward(f)(std::forward(args)...); } diff --git a/c10/util/CallOnce.h b/c10/util/CallOnce.h new file mode 100644 index 0000000000000..a31600ef2e7c0 --- /dev/null +++ b/c10/util/CallOnce.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +// custom c10 call_once implementation to avoid the deadlock in std::call_once. +// The implementation here is a simplified version from folly and likely much +// much higher memory footprint. +template +inline void call_once(Flag& flag, F&& f, Args&&... args) { + if (C10_LIKELY(flag.test_once())) { + return; + } + flag.call_once_slow(std::forward(f), std::forward(args)...); +} + +class once_flag { + public: +#ifndef _WIN32 + // running into build error on MSVC. Can't seem to get a repro locally so I'm + // just avoiding constexpr + // + // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error: + // defaulted default constructor cannot be constexpr because the + // corresponding implicitly declared default constructor would not be + // constexpr 1 error detected in the compilation of + // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu". + constexpr +#endif + once_flag() noexcept = default; + once_flag(const once_flag&) = delete; + once_flag& operator=(const once_flag&) = delete; + + private: + template + friend void call_once(Flag& flag, F&& f, Args&&... args); + + template + void call_once_slow(F&& f, Args&&... args) { + std::lock_guard guard(mutex_); + if (init_.load(std::memory_order_relaxed)) { + return; + } + c10::guts::invoke(f, std::forward(args)...); + init_.store(true, std::memory_order_release); + } + + bool test_once() { + return init_.load(std::memory_order_acquire); + } + + void reset_once() { + init_.store(false, std::memory_order_release); + } + + private: + std::mutex mutex_; + std::atomic init_{false}; +}; + +} // namespace c10 diff --git a/c10/util/ExclusivelyOwnedTensorTraits.h b/c10/util/ExclusivelyOwnedTensorTraits.h new file mode 100644 index 0000000000000..143b4df0a4e5f --- /dev/null +++ b/c10/util/ExclusivelyOwnedTensorTraits.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include + +namespace c10 { +// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and +// at::TensorBase. +template +struct ExclusivelyOwnedTensorTraits { + using repr_type = TensorType; + using pointer_type = TensorType*; + using const_pointer_type = const TensorType*; + + static repr_type nullRepr() { + return TensorType(); + } + + template + static repr_type createInPlace(Args&&... args) { + return TensorType(std::forward(args)...); + } + + static repr_type moveToRepr(TensorType&& x) { + return std::move(x); + } + + static void destroyOwned(TensorType& x) { + TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); + // May be 0 because UndefinedTensorImpl doesn't get its refcount + // incremented. + const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and refcount ", + toDestroy->refcount_, + ", expected 1 or, if isUndefined, 0!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->weakcount_ == 1 || + (toDestroy->weakcount_ == 0 && + toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and weakcount ", + toDestroy->weakcount_, + ", expected 1 or, if isUndefined, 0!"); + if (!isUndefined) { +#ifndef NDEBUG + // Needed to pass the debug assertions in ~intrusive_ptr_target. + toDestroy->refcount_ = 0; + toDestroy->weakcount_ = 0; +#endif + delete toDestroy; + } + } + + static TensorType take(TensorType& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 diff --git a/c10/util/LeftRight.h b/c10/util/LeftRight.h index e45267cb8f7e3..a399c61bef8c7 100644 --- a/c10/util/LeftRight.h +++ b/c10/util/LeftRight.h @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -78,7 +79,7 @@ class LeftRight final { } template - auto read(F&& readFunc) const -> typename std::result_of::type { + auto read(F&& readFunc) const -> typename c10::invoke_result_t { detail::IncrementRAII _increment_counter( &_counters[_foregroundCounterIndex.load()]); @@ -89,7 +90,7 @@ class LeftRight final { // the old or the new state, depending on if the first or the second call to // writeFunc threw. template - auto write(F&& writeFunc) -> typename std::result_of::type { + auto write(F&& writeFunc) -> typename c10::invoke_result_t { std::unique_lock lock(_writeMutex); return _write(writeFunc); @@ -97,7 +98,7 @@ class LeftRight final { private: template - auto _write(const F& writeFunc) -> typename std::result_of::type { + auto _write(const F& writeFunc) -> typename c10::invoke_result_t { /* * Assume, A is in background and B in foreground. In simplified terms, we * want to do the following: @@ -165,7 +166,7 @@ class LeftRight final { template auto _callWriteFuncOnBackgroundInstance( const F& writeFunc, - uint8_t localDataIndex) -> typename std::result_of::type { + uint8_t localDataIndex) -> typename c10::invoke_result_t { try { return writeFunc(_data[localDataIndex ^ 1]); } catch (...) { @@ -205,13 +206,13 @@ class RWSafeLeftRightWrapper final { RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; template - auto read(F&& readFunc) const -> typename std::result_of::type { + auto read(F&& readFunc) const -> typename c10::invoke_result_t { return data_.withLock( [&readFunc](T const& data) { return readFunc(data); }); } template - auto write(F&& writeFunc) -> typename std::result_of::type { + auto write(F&& writeFunc) -> typename c10::invoke_result_t { return data_.withLock([&writeFunc](T& data) { return writeFunc(data); }); } diff --git a/c10/util/Logging.h b/c10/util/Logging.h index e2ed61de606fb..b25d7841e3f40 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -180,7 +180,7 @@ using EnforceNotMet = ::c10::Error; * With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))` * * Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided - * too. Please use them instead of CHECK_EQ and friends for failures in + * too. Please use them instead of TORCH_CHECK_EQ and friends for failures in * user-provided input. */ diff --git a/c10/util/OptionalArrayRef.h b/c10/util/OptionalArrayRef.h index 7ca375d7cb785..ff51f549a5644 100644 --- a/c10/util/OptionalArrayRef.h +++ b/c10/util/OptionalArrayRef.h @@ -74,6 +74,9 @@ class OptionalArrayRef final { Args&&... args) : wrapped_opt_array_ref(ip, il, args...) {} + constexpr OptionalArrayRef(const std::initializer_list& Vec) + : wrapped_opt_array_ref(ArrayRef(Vec)) {} + // Destructor ~OptionalArrayRef() = default; diff --git a/c10/util/Registry.h b/c10/util/Registry.h index 7338db78746f8..beca2ffaad003 100644 --- a/c10/util/Registry.h +++ b/c10/util/Registry.h @@ -64,10 +64,10 @@ class Registry { const RegistryPriority priority = REGISTRY_DEFAULT) { std::lock_guard lock(register_mutex_); // The if statement below is essentially the same as the following line: - // CHECK_EQ(registry_.count(key), 0) << "Key " << key + // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key // << " registered twice."; - // However, CHECK_EQ depends on google logging, and since registration is - // carried out at static initialization time, we do not want to have an + // However, TORCH_CHECK_EQ depends on google logging, and since registration + // is carried out at static initialization time, we do not want to have an // explicit dependency on glog's initialization function. if (registry_.count(key) != 0) { auto cur_priority = priority_[key]; diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index b4086d1f208b0..bf0eb5719a4e8 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -41,7 +41,39 @@ size_t ReplaceAll(std::string& s, c10::string_view from, c10::string_view to) { size_t numReplaced = 0; std::string::size_type last_pos = 0u; std::string::size_type cur_pos = 0u; + std::string::size_type write_pos = 0u; const c10::string_view input(s); + + if (from.size() >= to.size()) { + // If the replacement string is not larger than the original, we + // can do the replacement in-place without allocating new storage. + char* s_data = &s[0]; + + while ((cur_pos = s.find(from.data(), last_pos, from.size())) != + std::string::npos) { + ++numReplaced; + // Append input between replaced sub-strings + if (write_pos != last_pos) { + std::copy(s_data + last_pos, s_data + cur_pos, s_data + write_pos); + } + write_pos += cur_pos - last_pos; + // Append the replacement sub-string + std::copy(to.begin(), to.end(), s_data + write_pos); + write_pos += to.size(); + // Start search from next character after `from` + last_pos = cur_pos + from.size(); + } + + // Append any remaining input after replaced sub-strings + if (write_pos != last_pos) { + std::copy(s_data + last_pos, s_data + input.size(), s_data + write_pos); + write_pos += input.size() - last_pos; + s.resize(write_pos); + } + return numReplaced; + } + + // Otherwise, do an out-of-place replacement in a temporary buffer std::string buffer; while ((cur_pos = s.find(from.data(), last_pos, from.size())) != @@ -58,6 +90,7 @@ size_t ReplaceAll(std::string& s, c10::string_view from, c10::string_view to) { // If nothing was replaced, don't modify the input return 0; } + // Append any remaining input after replaced sub-strings buffer.append(input.begin() + last_pos, input.end()); s = std::move(buffer); return numReplaced; diff --git a/c10/util/Synchronized.h b/c10/util/Synchronized.h index 1679d7060fe05..4bc4e2e7d0cea 100644 --- a/c10/util/Synchronized.h +++ b/c10/util/Synchronized.h @@ -2,6 +2,8 @@ #include +#include + namespace c10 { /** @@ -42,7 +44,7 @@ class Synchronized final { * provided callback safely. */ template - typename std::result_of::type withLock(CB cb) { + typename c10::invoke_result_t withLock(CB cb) { std::lock_guard guard(this->mutex_); return cb(this->data_); } @@ -53,7 +55,7 @@ class Synchronized final { * the provided callback safely. */ template - typename std::result_of::type withLock(CB cb) const { + typename c10::invoke_result_t withLock(CB cb) const { std::lock_guard guard(this->mutex_); return cb(this->data_); } diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index ac3083dc9db65..c119dc15873ef 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -499,8 +499,9 @@ struct map_types_to_values final { template struct map_types_to_values> final { template - static std::tuple)>...> call(Func&& func) { - return std::tuple)>...>{ + static std::tuple>...> call( + Func&& func) { + return std::tuple>...>{ std::forward(func)(type_())...}; } }; diff --git a/c10/util/flat_hash_map.h b/c10/util/flat_hash_map.h index 19cca32e2793b..0e5afed2a6ef5 100644 --- a/c10/util/flat_hash_map.h +++ b/c10/util/flat_hash_map.h @@ -31,11 +31,6 @@ C10_CLANG_DIAGNOSTIC_PUSH() C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") #endif -#ifndef _MSC_VER -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wshadow" -#endif - #ifdef _MSC_VER #define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ #else @@ -645,8 +640,8 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { deallocate_data(new_buckets, num_buckets, old_max_lookups); } - void reserve(uint64_t num_elements) { - uint64_t required_buckets = num_buckets_for_reserve(num_elements); + void reserve(uint64_t num_elements_) { + uint64_t required_buckets = num_buckets_for_reserve(num_elements_); if (required_buckets > bucket_count()) rehash(required_buckets); } @@ -789,9 +784,9 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { return std::max(detailv3::min_lookups, desired); } - uint64_t num_buckets_for_reserve(uint64_t num_elements) const { + uint64_t num_buckets_for_reserve(uint64_t num_elements_) const { return static_cast(std::ceil( - num_elements / std::min(0.5, static_cast(_max_load_factor)))); + num_elements_ / std::min(0.5, static_cast(_max_load_factor)))); } void rehash_for_other_container(const sherwood_v3_table& other) { rehash( @@ -859,10 +854,10 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { void deallocate_data( EntryPointer begin, - uint64_t num_slots_minus_one, - int8_t max_lookups) { + uint64_t num_slots_minus_one_, + int8_t max_lookups_) { AllocatorTraits::deallocate( - *this, begin, num_slots_minus_one + max_lookups + 1); + *this, begin, num_slots_minus_one_ + max_lookups_ + 1); } void reset_to_empty_state() { @@ -1909,8 +1904,8 @@ struct fibonacci_hash_policy { size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); return 64 - detailv3::log2(size); } - void commit(int8_t shift) { - this->shift = shift; + void commit(int8_t shift_) { + shift = shift_; } void reset() { shift = 63; @@ -2106,8 +2101,4 @@ struct power_of_two_std_hash : std::hash { } // end namespace ska -#ifndef _MSC_VER -#pragma GCC diagnostic pop -#endif - C10_CLANG_DIAGNOSTIC_POP() diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 5e31747a3f7b4..c87305b08be57 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -24,9 +24,6 @@ namespace intrusive_ptr { inline void incref(intrusive_ptr_target* self); } -template -struct ExclusivelyOwnedTraits; - // constructor tag used by intrusive_ptr constructors struct DontIncreaseRefcount {}; } // namespace raw @@ -92,7 +89,7 @@ class C10_API intrusive_ptr_target { intrusive_ptr_target* self); template - friend struct ExclusivelyOwnedTraits; + friend struct ExclusivelyOwnedTensorTraits; protected: // protected destructor. We never want to destruct intrusive_ptr_target* @@ -249,7 +246,7 @@ class intrusive_ptr final { TTarget* target_; template - friend struct ExclusivelyOwnedTraits; + friend struct ExclusivelyOwnedTensorTraits; template friend class intrusive_ptr; friend class weak_intrusive_ptr; diff --git a/c10/util/logging_is_google_glog.h b/c10/util/logging_is_google_glog.h index 0849508026a5b..b5860d8c0c9f4 100644 --- a/c10/util/logging_is_google_glog.h +++ b/c10/util/logging_is_google_glog.h @@ -50,6 +50,71 @@ INSTANTIATE_FOR_CONTAINER(set) #include // Additional macros on top of glog +#ifndef NDEBUG +#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2) +#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2) +#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2) +#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2) +#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2) +#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2) +#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2) +#else // !NDEBUG +// These versions generate no code in optimized mode. +#define TORCH_CHECK_EQ(val1, val2) \ + while (false) \ + CHECK_EQ(val1, val2) +#define TORCH_CHECK_NE(val1, val2) \ + while (false) \ + CHECK_NE(val1, val2) +#define TORCH_CHECK_LE(val1, val2) \ + while (false) \ + CHECK_LE(val1, val2) +#define TORCH_CHECK_LT(val1, val2) \ + while (false) \ + CHECK_LT(val1, val2) +#define TORCH_CHECK_GE(val1, val2) \ + while (false) \ + CHECK_GE(val1, val2) +#define TORCH_CHECK_GT(val1, val2) \ + while (false) \ + CHECK_GT(val1, val2) +#define TORCH_DCHECK_EQ(val1, val2) \ + while (false) \ + DCHECK_EQ(val1, val2) +#define TORCH_DCHECK_NE(val1, val2) \ + while (false) \ + DCHECK_NE(val1, val2) +#define TORCH_DCHECK_LE(val1, val2) \ + while (false) \ + DCHECK_LE(val1, val2) +#define TORCH_DCHECK_LT(val1, val2) \ + while (false) \ + DCHECK_LT(val1, val2) +#define TORCH_DCHECK_GE(val1, val2) \ + while (false) \ + DCHECK_GE(val1, val2) +#define TORCH_DCHECK_GT(val1, val2) \ + while (false) \ + DCHECK_GT(val1, val2) +#endif // NDEBUG + +// Check that a pointer is not null. +#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val) + +#ifndef NDEBUG +// Debug only version of TORCH_CHECK_NOTNULL +#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val) +#else // !NDEBUG +// Optimized version - generates no code. +#define TORCH_DCHECK_NOTNULL(val) \ + while (false) \ + DCHECK_NOTNULL(val) +#endif // NDEBUG // Log with source location information override (to be used in generic // warning/error handlers implemented as functions, not macros) diff --git a/c10/util/logging_is_not_google_glog.h b/c10/util/logging_is_not_google_glog.h index 7a3d9ff237961..d27cc18e45300 100644 --- a/c10/util/logging_is_not_google_glog.h +++ b/c10/util/logging_is_not_google_glog.h @@ -61,8 +61,8 @@ void LogMessageFatal(const char* file, int line, const T& message) { MessageLogger(file, line, GLOG_FATAL).stream() << message; } -// Helpers for CHECK_NOTNULL(). Two are necessary to support both raw pointers -// and smart pointers. +// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw +// pointers and smart pointers. template T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) { if (t == nullptr) { @@ -136,63 +136,63 @@ static_assert( ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() #endif // NDEBUG -#define CHECK_OP(val1, val2, op) \ +#define TORCH_CHECK_OP(val1, val2, op) \ FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \ << (val1) << " vs. " << (val2) << ") " -// Check_op macro definitions -#define CHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==) -#define CHECK_NE(val1, val2) CHECK_OP(val1, val2, !=) -#define CHECK_LE(val1, val2) CHECK_OP(val1, val2, <=) -#define CHECK_LT(val1, val2) CHECK_OP(val1, val2, <) -#define CHECK_GE(val1, val2) CHECK_OP(val1, val2, >=) -#define CHECK_GT(val1, val2) CHECK_OP(val1, val2, >) +// TORCH_CHECK_OP macro definitions +#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) +#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) +#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) +#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) +#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) +#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) #ifndef NDEBUG -// Debug only versions of CHECK_OP macros. -#define DCHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==) -#define DCHECK_NE(val1, val2) CHECK_OP(val1, val2, !=) -#define DCHECK_LE(val1, val2) CHECK_OP(val1, val2, <=) -#define DCHECK_LT(val1, val2) CHECK_OP(val1, val2, <) -#define DCHECK_GE(val1, val2) CHECK_OP(val1, val2, >=) -#define DCHECK_GT(val1, val2) CHECK_OP(val1, val2, >) +// Debug only versions of TORCH_CHECK_OP macros. +#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) #else // !NDEBUG // These versions generate no code in optimized mode. -#define DCHECK_EQ(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, ==) -#define DCHECK_NE(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, !=) -#define DCHECK_LE(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, <=) -#define DCHECK_LT(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, <) -#define DCHECK_GE(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, >=) -#define DCHECK_GT(val1, val2) \ - while (false) \ - CHECK_OP(val1, val2, >) +#define TORCH_DCHECK_EQ(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) \ + while (false) \ + TORCH_CHECK_OP(val1, val2, >) #endif // NDEBUG // Check that a pointer is not null. -#define CHECK_NOTNULL(val) \ - ::c10::CheckNotNull( \ +#define TORCH_CHECK_NOTNULL(val) \ + ::c10::CheckNotNull( \ __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) #ifndef NDEBUG -// Debug only version of CHECK_NOTNULL -#define DCHECK_NOTNULL(val) \ - ::c10::CheckNotNull( \ +// Debug only version of TORCH_CHECK_NOTNULL +#define TORCH_DCHECK_NOTNULL(val) \ + ::c10::CheckNotNull( \ __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) #else // !NDEBUG // Optimized version - generates no code. -#define DCHECK_NOTNULL(val) \ - while (false) \ - CHECK_NOTNULL(val) +#define TORCH_DCHECK_NOTNULL(val) \ + while (false) \ + TORCH_CHECK_NOTNULL(val) #endif // NDEBUG // ---------------------- Support for std objects -------------------------- diff --git a/c10/util/order_preserving_flat_hash_map.h b/c10/util/order_preserving_flat_hash_map.h index 74021f0b55e81..58994fc57abdf 100644 --- a/c10/util/order_preserving_flat_hash_map.h +++ b/c10/util/order_preserving_flat_hash_map.h @@ -33,11 +33,6 @@ C10_CLANG_DIAGNOSTIC_PUSH() C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") #endif -#ifndef _MSC_VER -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wshadow" -#endif - #ifdef _MSC_VER #define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ #else @@ -643,8 +638,8 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { deallocate_data(new_buckets, num_buckets, old_max_lookups); } - void reserve(uint64_t num_elements) { - uint64_t required_buckets = num_buckets_for_reserve(num_elements); + void reserve(uint64_t num_elements_) { + uint64_t required_buckets = num_buckets_for_reserve(num_elements_); if (required_buckets > bucket_count()) rehash(required_buckets); } @@ -827,9 +822,9 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { return std::max(detailv3::min_lookups, desired); } - uint64_t num_buckets_for_reserve(uint64_t num_elements) const { + uint64_t num_buckets_for_reserve(uint64_t num_elements_) const { return static_cast(std::ceil( - num_elements / std::min(0.5, static_cast(_max_load_factor)))); + num_elements_ / std::min(0.5, static_cast(_max_load_factor)))); } void rehash_for_other_container(const sherwood_v3_table& other) { rehash( @@ -983,10 +978,10 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { void deallocate_data( EntryPointer begin, - uint64_t num_slots_minus_one, - int8_t max_lookups) { + uint64_t num_slots_minus_one_, + int8_t max_lookups_) { AllocatorTraits::deallocate( - *this, begin, num_slots_minus_one + max_lookups + 1); + *this, begin, num_slots_minus_one_ + max_lookups_ + 1); } void reset_to_empty_state() { @@ -2033,8 +2028,8 @@ struct fibonacci_hash_policy { size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); return 64 - detailv3::log2(size); } - void commit(int8_t shift) { - this->shift = shift; + void commit(int8_t shift_) { + shift = shift_; } void reset() { shift = 63; @@ -2234,8 +2229,4 @@ struct power_of_two_std_hash : std::hash { } // namespace ska_ordered -#ifndef _MSC_VER -#pragma GCC diagnostic pop -#endif - C10_CLANG_DIAGNOSTIC_POP() diff --git a/c2_defs.bzl b/c2_defs.bzl index 01ec0c6d1642d..d77fed977f39e 100644 --- a/c2_defs.bzl +++ b/c2_defs.bzl @@ -93,12 +93,6 @@ def get_c2_tvm(): return bool(int(c2_tvm)) _C2_XPLAT_NO_HPTT_PREPROCESSOR_FLAGS = [ - "-fexceptions", - "-frtti", - "-Wno-shadow", - "-Wno-unknown-pragmas", - "-Wno-unused-variable", - "-Wno-sign-compare", "-Icaffe2", "-Imodules", "-DEIGEN_NO_DEBUG", @@ -139,7 +133,13 @@ def get_c2_xplat_preprocessor_flags(): def get_c2_xplat_no_hptt_compiler_flags(): return [ "-Os", - ] + get_c2_xplat_no_hptt_preprocessor_flags() + "-fexceptions", + "-frtti", + "-Wno-shadow", + "-Wno-unknown-pragmas", + "-Wno-unused-variable", + "-Wno-sign-compare", + ] def get_c2_xplat_compiler_flags(): return get_c2_xplat_no_hptt_compiler_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4c9764348f058..65cdd576d9c28 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -551,8 +551,12 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/cpp/context.cpp ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm + ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm + ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.mm ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/observer/PTMCoreMLObserver.mm ) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm PROPERTIES COMPILE_FLAGS "-fno-objc-arc") + include_directories(${TORCH_ROOT}/third_party/nlohmann/single_include) list(APPEND TORCH_SRCS ${COREML_DELEGATE_SRCS}) endif() endif() @@ -567,7 +571,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) # See https://github.com/pytorch/pytorch/issues/38856 set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_jit.cpp PROPERTIES COMPILE_FLAGS "-Wno-redundant-move -Wno-noexcept-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp PROPERTIES COMPILE_FLAGS -Wno-init-list-lifetime) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp PROPERTIES COMPILE_FLAGS "-Wno-init-list-lifetime") endif() if(NOT INTERN_DISABLE_MOBILE_INTERP) @@ -605,6 +609,13 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ) endif() + if(${USE_ITT}) + list(APPEND TORCH_SRCS + ${TORCH_SRC_DIR}/csrc/itt_wrapper.cpp + ${TORCH_SRC_DIR}/csrc/profiler/itt.cpp + ) + endif() + if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER) list(APPEND TORCH_SRCS ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp @@ -907,6 +918,11 @@ if(HAVE_SOVERSION) VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) endif() +if(USE_UCC) + target_link_libraries(torch_cpu PRIVATE __caffe2_ucc) + target_compile_definitions(torch_cpu PRIVATE USE_UCC) +endif() + if(USE_ROCM) filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$") set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) @@ -968,6 +984,13 @@ elseif(USE_CUDA) target_link_libraries(torch_cuda PRIVATE __caffe2_nccl) target_compile_definitions(torch_cuda PRIVATE USE_NCCL) endif() + if(USE_UCC AND BUILD_SPLIT_CUDA) + target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_ucc) + target_compile_definitions(torch_cuda_cpp PRIVATE USE_UCC) + elseif(USE_UCC) + target_link_libraries(torch_cuda PRIVATE __caffe2_ucc) + target_compile_definitions(torch_cuda PRIVATE USE_UCC) + endif() if(BUILD_LAZY_CUDA_LINALG) add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS}) target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG) @@ -1131,11 +1154,13 @@ endif() ${TORCH_SRC_DIR}/csrc) target_include_directories(torch_cpu PRIVATE - ${TORCH_ROOT}/third_party/miniz-2.0.8) + ${TORCH_ROOT}/third_party/miniz-2.1.0) + + target_include_directories(torch_cpu PRIVATE + ${TORCH_ROOT}/third_party/kineto/libkineto/include) if(USE_KINETO) target_include_directories(torch_cpu PRIVATE - ${TORCH_ROOT}/third_party/kineto/libkineto/include ${TORCH_ROOT}/third_party/kineto/libkineto/src) endif() @@ -1343,6 +1368,16 @@ if(USE_DISTRIBUTED) if(USE_GLOO AND USE_C10D_GLOO) target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO) endif() + if(USE_UCC AND USE_C10D_UCC) + target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) + if(USE_CUDA) + if(BUILD_SPLIT_CUDA) + target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_UCC) + else() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) + endif() + endif() + endif() if(USE_NCCL AND USE_C10D_NCCL) if(USE_ROCM) target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) @@ -2019,7 +2054,7 @@ if(BUILD_PYTHON) target_link_libraries(caffe2_pybind11_state PRIVATE torch_library python::python pybind::pybind11) - if(CAFFE2_USE_MKLDNN) + if(USE_MKLDNN) target_link_libraries(caffe2_pybind11_state PRIVATE caffe2::mkldnn) endif() if(WIN32) diff --git a/caffe2/contrib/fakelowp/fp16_fc_acc_op.h b/caffe2/contrib/fakelowp/fp16_fc_acc_op.h index 49859ed1a373d..27dccc3230b0a 100644 --- a/caffe2/contrib/fakelowp/fp16_fc_acc_op.h +++ b/caffe2/contrib/fakelowp/fp16_fc_acc_op.h @@ -78,7 +78,7 @@ class Fp16FCAccOp final : public Operator { Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. - DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); + TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); Y_shape_cache_.resize(canonical_axis + 1); Y_shape_cache_[canonical_axis] = N; Y->Resize(Y_shape_cache_); diff --git a/caffe2/contrib/ideep/CMakeLists.txt b/caffe2/contrib/ideep/CMakeLists.txt index 8e1f89d0a7e48..02ffe32b9a7d9 100644 --- a/caffe2/contrib/ideep/CMakeLists.txt +++ b/caffe2/contrib/ideep/CMakeLists.txt @@ -1,4 +1,4 @@ -if(CAFFE2_USE_MKLDNN) +if(USE_MKLDNN) message(STATUS "Including IDEEP operators") # ---[ CPU files. diff --git a/caffe2/contrib/nccl/cuda_nccl_gpu.cc b/caffe2/contrib/nccl/cuda_nccl_gpu.cc index ef2b9ab37ea09..82fe523651285 100644 --- a/caffe2/contrib/nccl/cuda_nccl_gpu.cc +++ b/caffe2/contrib/nccl/cuda_nccl_gpu.cc @@ -91,7 +91,7 @@ NCCLContext* getNCCLContext(const NCCLExecution& ex) { LOG(INFO) << "Creating NCCLContext for key: " << key; contexts[key].reset(new NCCLContext(ex)); } - return CHECK_NOTNULL(contexts[key].get()); + return TORCH_CHECK_NOTNULL(contexts[key].get()); } template @@ -153,7 +153,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) { auto& comm = comms[i]; auto& stream = streams[i]; - DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data())); + TORCH_DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data())); CUDA_ENFORCE(cudaStreamWaitEvent(stream, context->master_event_, 0)); f(ctx, comm, stream); } diff --git a/caffe2/contrib/opencl/context.h b/caffe2/contrib/opencl/context.h index b1e61c2124adc..5ea63cb80f196 100644 --- a/caffe2/contrib/opencl/context.h +++ b/caffe2/contrib/opencl/context.h @@ -36,7 +36,7 @@ class OpenCLContext final { public: explicit OpenCLContext(); explicit OpenCLContext(const DeviceOption& option) { - DCHECK_EQ(option.device_type(), PROTO_OPENCL); + TORCH_DCHECK_EQ(option.device_type(), PROTO_OPENCL); OpenCLContext(); } ~OpenCLContext() {} diff --git a/caffe2/contrib/warpctc/ctc_op.cpp b/caffe2/contrib/warpctc/ctc_op.cpp index 047ec6e10bec3..e5ec5ff58a062 100644 --- a/caffe2/contrib/warpctc/ctc_op.cpp +++ b/caffe2/contrib/warpctc/ctc_op.cpp @@ -2,7 +2,7 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/core/operator.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -25,7 +25,7 @@ REGISTER_CPU_OPERATOR(CTC, CTCOp); OPERATOR_SCHEMA(CTC).NumInputs(3, 4).NumOutputs(2, 3); // .EnforceInputOutputGradient({{0, 0}}); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR(CTC, IDEEPFallbackOp>); #endif diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index 1b2c58bacd467..5cd01ba7cc59c 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -785,13 +785,8 @@ void DeserializeBlob(const BlobProto& blob_proto, Blob* result) { // === Local helper functions === // Get dimensions from Tensor proto -std::vector DimsFromTensorProto(const TensorProto& proto) { - std::vector dims; - dims.reserve(proto.dims().size()); - for (const int64_t d : proto.dims()) { - dims.push_back(d); - } - return dims; +c10::IntArrayRef DimsFromTensorProto(const TensorProto& proto) { + return c10::IntArrayRef(proto.dims().data(), proto.dims().size()); } // Get number of elements from Tensor proto diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 68ba36956938b..17992b7c61130 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -296,7 +296,7 @@ inline std::string SerializeBlobProtoAsString_EnforceCheck( int64_t NumelFromTensorProto(const TensorProto& tensor_proto); -std::vector DimsFromTensorProto(const TensorProto& proto); +c10::IntArrayRef DimsFromTensorProto(const TensorProto& proto); TypeMeta GetDataType(const TensorProto& tensor_proto); diff --git a/caffe2/core/common_cudnn.cc b/caffe2/core/common_cudnn.cc index 248897c136442..f8186544054a3 100644 --- a/caffe2/core/common_cudnn.cc +++ b/caffe2/core/common_cudnn.cc @@ -9,7 +9,7 @@ CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() { // New it (never delete) to avoid calling the destructors on process // exit and racing against the CUDA shutdown sequence. static auto* p = new CuDNNWrapper::PerGPUCuDNNStates(); - CHECK_NOTNULL(p); + TORCH_CHECK_NOTNULL(p); return *p; } diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu index 6d537400913eb..bfa563ca6b8bb 100644 --- a/caffe2/core/context_gpu.cu +++ b/caffe2/core/context_gpu.cu @@ -437,7 +437,7 @@ CUDAContext::CUDAContext(const DeviceOption& option) option.has_random_seed() ? option.random_seed() : RandomNumberSeed()) { static Caffe2CudaInitializerHelper g_cuda_initializer_; - DCHECK_EQ(option.device_type(), PROTO_CUDA); + TORCH_DCHECK_EQ(option.device_type(), PROTO_CUDA); } CUDAContext::~CUDAContext() { diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index 3dc7c895577f1..e411d9cd735f1 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -230,7 +230,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_ENFORCE( curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_)); - CHECK_NOTNULL(curand_generator_); + TORCH_CHECK_NOTNULL(curand_generator_); } CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream())); return curand_generator_; diff --git a/caffe2/core/context_gpu_test.cc b/caffe2/core/context_gpu_test.cc index b27dddff4cfca..aba8193e59a11 100644 --- a/caffe2/core/context_gpu_test.cc +++ b/caffe2/core/context_gpu_test.cc @@ -65,7 +65,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) { cudaStream_t getStreamForHandle(cublasHandle_t handle) { cudaStream_t stream = nullptr; CUBLAS_ENFORCE(cublasGetStream(handle, &stream)); - CHECK_NOTNULL(stream); + TORCH_CHECK_NOTNULL(stream); return stream; } diff --git a/caffe2/core/cudnn_wrappers.h b/caffe2/core/cudnn_wrappers.h index 02efba78d7e64..ce3d297bab651 100644 --- a/caffe2/core/cudnn_wrappers.h +++ b/caffe2/core/cudnn_wrappers.h @@ -172,7 +172,7 @@ class CuDNNWrapper { if (!sync_state.state.get()) { sync_state.state.reset(new CuDNNState(context_->device_id())); } - CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f); + TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f); } protected: diff --git a/caffe2/core/hip/common_miopen.hip b/caffe2/core/hip/common_miopen.hip index 86a7c542ba44a..a617bad29a3d1 100644 --- a/caffe2/core/hip/common_miopen.hip +++ b/caffe2/core/hip/common_miopen.hip @@ -25,7 +25,7 @@ MIOPENWrapper::PerGPUMIOPENStates& MIOPENWrapper::miopen_states() // New it (never delete) to avoid calling the destructors on process // exit and racing against the CUDA shutdown sequence. static auto* p = new MIOPENWrapper::PerGPUMIOPENStates(); - CHECK_NOTNULL(p); + TORCH_CHECK_NOTNULL(p); return *p; } diff --git a/caffe2/core/hip/miopen_wrapper.h b/caffe2/core/hip/miopen_wrapper.h index 1df99491a6791..f60bed6c277d6 100644 --- a/caffe2/core/hip/miopen_wrapper.h +++ b/caffe2/core/hip/miopen_wrapper.h @@ -138,7 +138,7 @@ class MIOPENWrapper { sync_state.state.reset(new MIOPENState(context_->device_id())); } - CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f); + TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f); } protected: diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 11fd739b20908..9c9f734575634 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -40,8 +40,9 @@ static_assert( #cmakedefine CAFFE2_USE_GOOGLE_GLOG #cmakedefine CAFFE2_USE_LITE_PROTO #cmakedefine CAFFE2_USE_MKL -#cmakedefine CAFFE2_USE_MKLDNN +#cmakedefine USE_MKLDNN #cmakedefine CAFFE2_USE_NVTX +#cmakedefine CAFFE2_USE_ITT #cmakedefine CAFFE2_USE_TRT #ifndef EIGEN_MPL2_ONLY @@ -80,7 +81,8 @@ static_assert( {"USE_EIGEN_FOR_BLAS", "${CAFFE2_USE_EIGEN_FOR_BLAS}"}, \ {"USE_LITE_PROTO", "${CAFFE2_USE_LITE_PROTO}"}, \ {"USE_MKL", "${CAFFE2_USE_MKL}"}, \ - {"USE_MKLDNN", "${CAFFE2_USE_MKLDNN}"}, \ + {"USE_MKLDNN", "${USE_MKLDNN}"}, \ {"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \ + {"USE_ITT", "${CAFFE2_USE_ITT}"}, \ {"USE_TRT", "${CAFFE2_USE_TRT}"}, \ } diff --git a/caffe2/core/net_gpu_test.cc b/caffe2/core/net_gpu_test.cc index b8d4776520aaa..1eb6fa513a23d 100644 --- a/caffe2/core/net_gpu_test.cc +++ b/caffe2/core/net_gpu_test.cc @@ -79,7 +79,7 @@ void checkChainingAndRun( net_def.set_num_workers(4); std::unique_ptr net(CreateNet(net_def, &ws)); auto* dag = dynamic_cast_if_rtti(net.get()); - CHECK_NOTNULL(dag); + TORCH_CHECK_NOTNULL(dag); const auto& chains = dag->TEST_execution_chains(); EXPECT_EQ(chains, expected); testExecution(net, net_def.op().size()); diff --git a/caffe2/core/net_test.cc b/caffe2/core/net_test.cc index b49ab7bcf9415..a1c80eca6790d 100644 --- a/caffe2/core/net_test.cc +++ b/caffe2/core/net_test.cc @@ -152,7 +152,7 @@ void checkChainingAndRun( net_def.set_num_workers(4); std::unique_ptr net(CreateNet(net_def, &ws)); auto* dag = dynamic_cast_if_rtti(net.get()); - CHECK_NOTNULL(dag); + TORCH_CHECK_NOTNULL(dag); const auto& chains = dag->TEST_execution_chains(); EXPECT_TRUE(chains == expected); testExecution(net, net_def.op().size()); @@ -175,7 +175,7 @@ void checkNumChainsAndRun(const char* spec, const int expected_num_chains) { { std::unique_ptr net(CreateNet(net_def, &ws)); auto* dag = dynamic_cast_if_rtti(net.get()); - CHECK_NOTNULL(dag); + TORCH_CHECK_NOTNULL(dag); const auto& chains = dag->TEST_execution_chains(); EXPECT_EQ(expected_num_chains, chains.size()); testExecution(net, net_def.op().size()); @@ -1108,7 +1108,7 @@ void testProfDAGNetErrorCase(bool test_error) { // with failing op - prof_dag handles invalid runs and returns empty stats, // without - returns stats for each op auto* prof_dag = dynamic_cast_if_rtti(net.get()); - CHECK_NOTNULL(prof_dag); + TORCH_CHECK_NOTNULL(prof_dag); auto stats_proto = prof_dag->GetPerOperatorCost(); ASSERT_EQ( stats_proto.stats_size(), test_error ? 0 : net->GetOperators().size()); diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc index e25c92a6d6075..a16f2cb26846c 100644 --- a/caffe2/core/operator.cc +++ b/caffe2/core/operator.cc @@ -82,7 +82,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws) outputs_.reserve(operator_def.output_size()); for (const string& output_str : operator_def.output()) { - outputs_.push_back(CHECK_NOTNULL(ws->CreateBlob(output_str))); + outputs_.push_back(TORCH_CHECK_NOTNULL(ws->CreateBlob(output_str))); } type_ = operator_def.type(); diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 6b0a15f78f02b..4fd8619631a37 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -157,7 +157,7 @@ class TORCH_API OperatorBase : public Observable { !std::is_same::value, "You should use Input(int, DeviceType) for " "Tensor."); - DCHECK_LT((size_t)idx, inputs_.size()); + TORCH_DCHECK_LT((size_t)idx, inputs_.size()); try { return inputs_.at(idx)->template Get(); } catch (::caffe2::EnforceNotMet& enf) { @@ -178,7 +178,7 @@ class TORCH_API OperatorBase : public Observable { static_assert( std::is_same::value, "Input(int, DeviceType) is only available for Tensor"); - DCHECK_LT((size_t)idx, inputs_.size()); + TORCH_DCHECK_LT((size_t)idx, inputs_.size()); try { // TODO(jerryzh): We'll need to check device type in Get() later // Get() -> Get(type) @@ -193,7 +193,7 @@ class TORCH_API OperatorBase : public Observable { } #if defined(EXPOSE_C2_OPS) || \ !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - DCHECK_LT(0U, newstyle_inputs_.size()); + TORCH_DCHECK_LT(0U, newstyle_inputs_.size()); IValue ival; if (newstyle_inputs_[0].isTensorList()) { // if the first input is a tensor list, we get input tensors by indexing @@ -201,12 +201,12 @@ class TORCH_API OperatorBase : public Observable { // are accessible as inputs. any hypothetical input tensors that come // after the list are not accessible. auto tensorList = newstyle_inputs_[0].toTensorVector(); - DCHECK_LT((size_t)idx, tensorList.size()); + TORCH_DCHECK_LT((size_t)idx, tensorList.size()); ival = tensorList[idx]; } else { // if the first input is not a tensor list, we get input tensors by // indexing into the inputs. - DCHECK_LT((size_t)idx, newstyle_inputs_.size()); + TORCH_DCHECK_LT((size_t)idx, newstyle_inputs_.size()); ival = newstyle_inputs_[idx]; } CAFFE_ENFORCE( diff --git a/caffe2/core/parallel_net_test.cc b/caffe2/core/parallel_net_test.cc index 555155346ff42..7b17faba31509 100644 --- a/caffe2/core/parallel_net_test.cc +++ b/caffe2/core/parallel_net_test.cc @@ -24,8 +24,8 @@ class SleepOp final : public Operator { SleepOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), ms_(OperatorBase::GetSingleArgument("ms", 1000)) { - DCHECK_GT(ms_, 0); - DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?"; + TORCH_DCHECK_GT(ms_, 0); + TORCH_DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?"; } bool RunOnDevice() override { diff --git a/caffe2/core/qtensor.h b/caffe2/core/qtensor.h index 7dc9c59f82f66..0308c14c097f6 100644 --- a/caffe2/core/qtensor.h +++ b/caffe2/core/qtensor.h @@ -187,8 +187,8 @@ class C10_EXPORT QTensor { * Returns the i-th dimension of the qtensor in int. */ inline int dim32(const int i) const { - DCHECK_LT(i, static_cast(dims_.size())) << "Exceeding ndim limit " << dims_.size(); - DCHECK_GE(i, 0) << "Cannot have negative index"; + TORCH_DCHECK_LT(i, static_cast(dims_.size())) << "Exceeding ndim limit " << dims_.size(); + TORCH_DCHECK_GE(i, 0) << "Cannot have negative index"; CAFFE_ENFORCE_LT(dims_[i], std::numeric_limits::max()); return static_cast(dims_[i]); } diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index afbaeb6840413..4c5be742d0cf7 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include C10_CLANG_DIAGNOSTIC_PUSH() @@ -63,6 +65,10 @@ class TORCH_API Tensor final { return impl_.get(); } + TensorImpl* unsafeReleaseTensorImpl() { + return impl_.release(); + } + Tensor UnsafeSharedInstance() const { return Tensor(*this, IDoWantAliasing); } @@ -652,4 +658,8 @@ void TensorPrinter::Print(const Tensor& tensor) { C10_CLANG_DIAGNOSTIC_POP() +namespace c10 { +template <> +struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; +} // namespace c10 #endif // CAFFE2_CORE_TENSOR_H_ diff --git a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc index 830f0dad1b34e..8ec14e1223ae8 100644 --- a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc +++ b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc @@ -163,8 +163,8 @@ string MaxPoolRTCFunction::GetSource( stride_w, pad_t, pad_l); - DCHECK_GE(nbytes, 0); - DCHECK_LT(nbytes, 65536); + TORCH_DCHECK_GE(nbytes, 0); + TORCH_DCHECK_LT(nbytes, 65536); return string(buffer); } @@ -202,8 +202,8 @@ string MaxPoolGradientRTCFunction::GetSource( stride_w, pad_t, pad_l); - DCHECK_GE(nbytes, 0); - DCHECK_LT(nbytes, 65536); + TORCH_DCHECK_GE(nbytes, 0); + TORCH_DCHECK_LT(nbytes, 65536); return string(buffer); } diff --git a/caffe2/distributed/file_store_handler.cc b/caffe2/distributed/file_store_handler.cc index 8535c506206cb..c5f85b57ab39e 100644 --- a/caffe2/distributed/file_store_handler.cc +++ b/caffe2/distributed/file_store_handler.cc @@ -55,7 +55,7 @@ FileStoreHandler::FileStoreHandler( auto ret = mkdir(basePath_.c_str(), 0777); #endif // defined(_MSC_VER) if (ret == -1) { - CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno); + TORCH_CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno); } } @@ -71,7 +71,7 @@ std::string FileStoreHandler::realPath(const std::string& path) { std::array buf; auto ret = realpath(path.c_str(), buf.data()); #endif - CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno); + TORCH_CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno); return std::string(buf.data()); } @@ -152,7 +152,7 @@ bool FileStoreHandler::check(const std::vector& names) { if (fd == -1) { // Only deal with files that don't exist. // Anything else is a problem. - CHECK_EQ(errno, ENOENT); + TORCH_CHECK_EQ(errno, ENOENT); // One of the paths doesn't exist; return early return false; diff --git a/caffe2/experiments/operators/fully_connected_op_decomposition.h b/caffe2/experiments/operators/fully_connected_op_decomposition.h index 218db49758089..0c734b17be33f 100644 --- a/caffe2/experiments/operators/fully_connected_op_decomposition.h +++ b/caffe2/experiments/operators/fully_connected_op_decomposition.h @@ -145,10 +145,10 @@ class FullyConnectedDecompGradientOp : public Operator { const auto& U = Input(1); const auto& V = Input(2); const auto& dY = Input(3); - DCHECK_GE(X.dim(), 1); - DCHECK_GE(U.dim(), 2); - DCHECK_GE(V.dim(), 2); - DCHECK_LE(dY.dim(), 2); + TORCH_DCHECK_GE(X.dim(), 1); + TORCH_DCHECK_GE(U.dim(), 2); + TORCH_DCHECK_GE(V.dim(), 2); + TORCH_DCHECK_LE(dY.dim(), 2); // batch size int M = X.dim() > 1 ? X.dim32(0) : 1; // Feature dimension @@ -156,13 +156,13 @@ class FullyConnectedDecompGradientOp : public Operator { // number of outputs. int N = U.dim32(0); int middle = U.dim32(1); - DCHECK_EQ(K, V.dim32(0)); + TORCH_DCHECK_EQ(K, V.dim32(0)); if (dY.dim() > 1) { - DCHECK_EQ(M, dY.dim32(0)); - DCHECK_EQ(N, dY.dim32(1)); + TORCH_DCHECK_EQ(M, dY.dim32(0)); + TORCH_DCHECK_EQ(N, dY.dim32(1)); } else { - DCHECK_EQ(X.dim(), 1); - DCHECK_EQ(N, dY.numel()); + TORCH_DCHECK_EQ(X.dim(), 1); + TORCH_DCHECK_EQ(N, dY.numel()); } auto* dU = Output(0, U.sizes(), at::dtype()); diff --git a/caffe2/experiments/operators/fully_connected_op_prune.h b/caffe2/experiments/operators/fully_connected_op_prune.h index cd2e6fc19f404..70834a707d134 100644 --- a/caffe2/experiments/operators/fully_connected_op_prune.h +++ b/caffe2/experiments/operators/fully_connected_op_prune.h @@ -17,6 +17,7 @@ #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ +#include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" @@ -249,9 +250,9 @@ class FullyConnectedPruneGradientOp : public Operator { auto& thres = Input(6); // TODO(wyiming): check comp_lb is a float auto& comp_lb = Input(7); - DCHECK_GE(X.dim(), 1); - DCHECK_GE(W.dim(), 2); - DCHECK_LE(dY.dim(), 2); + TORCH_DCHECK_GE(X.dim(), 1); + TORCH_DCHECK_GE(W.dim(), 2); + TORCH_DCHECK_LE(dY.dim(), 2); // batch size int M = X.dim() > 1 ? X.dim32(0) : 1; // Feature dimension @@ -263,17 +264,17 @@ class FullyConnectedPruneGradientOp : public Operator { // TODO(wyiming): this threshold should be // based on distribution of the layer weight float thr = 0.01; - DCHECK_EQ(Mask.dim32(0), W.dim32(0)); - DCHECK_EQ(Mask.dim32(1), W.dim32(1)); - DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0)); - DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1)); - DCHECK_EQ(K, W.numel() / W.dim32(0)); + TORCH_DCHECK_EQ(Mask.dim32(0), W.dim32(0)); + TORCH_DCHECK_EQ(Mask.dim32(1), W.dim32(1)); + TORCH_DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0)); + TORCH_DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1)); + TORCH_DCHECK_EQ(K, W.numel() / W.dim32(0)); if (dY.dim() > 1) { - DCHECK_EQ(M, dY.dim32(0)); - DCHECK_EQ(N, dY.dim32(1)); + TORCH_DCHECK_EQ(M, dY.dim32(0)); + TORCH_DCHECK_EQ(N, dY.dim32(1)); } else { - DCHECK_EQ(X.dim(), 1); - DCHECK_EQ(N, dY.numel()); + TORCH_DCHECK_EQ(X.dim(), 1); + TORCH_DCHECK_EQ(N, dY.numel()); } auto* dW = Output(0, W.sizes(), at::dtype()); diff --git a/caffe2/ideep/CMakeLists.txt b/caffe2/ideep/CMakeLists.txt index 8e1f89d0a7e48..02ffe32b9a7d9 100644 --- a/caffe2/ideep/CMakeLists.txt +++ b/caffe2/ideep/CMakeLists.txt @@ -1,4 +1,4 @@ -if(CAFFE2_USE_MKLDNN) +if(USE_MKLDNN) message(STATUS "Including IDEEP operators") # ---[ CPU files. diff --git a/caffe2/ideep/operators/local_response_normalization_op.cc b/caffe2/ideep/operators/local_response_normalization_op.cc index 52aa1504b39ca..faf7c06f76b27 100644 --- a/caffe2/ideep/operators/local_response_normalization_op.cc +++ b/caffe2/ideep/operators/local_response_normalization_op.cc @@ -15,10 +15,10 @@ class IDEEPLRNOp final : public IDEEPOperator { alpha_(OperatorBase::GetSingleArgument("alpha", 0)), beta_(OperatorBase::GetSingleArgument("beta", 0)), bias_(OperatorBase::GetSingleArgument("bias", 1)) { - DCHECK_GT(size_, 0); - DCHECK_EQ(size_ % 2, 1); - DCHECK_GT(alpha_, 0); - DCHECK_GT(beta_, 0); + TORCH_DCHECK_GT(size_, 0); + TORCH_DCHECK_EQ(size_ % 2, 1); + TORCH_DCHECK_GT(alpha_, 0); + TORCH_DCHECK_GT(beta_, 0); } ~IDEEPLRNOp() override = default; @@ -52,10 +52,10 @@ class IDEEPLRNGradientOp final : public IDEEPOperator { alpha_(OperatorBase::GetSingleArgument("alpha", 0)), beta_(OperatorBase::GetSingleArgument("beta", 0)), bias_(OperatorBase::GetSingleArgument("bias", 1)) { - DCHECK_GT(size_, 0); - DCHECK_EQ(size_ % 2, 1); - DCHECK_GT(alpha_, 0); - DCHECK_GT(beta_, 0); + TORCH_DCHECK_GT(size_, 0); + TORCH_DCHECK_EQ(size_ % 2, 1); + TORCH_DCHECK_GT(alpha_, 0); + TORCH_DCHECK_GT(beta_, 0); } ~IDEEPLRNGradientOp() override = default; diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index e7925d9e5d1e1..933d0843b2544 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -60,7 +60,7 @@ class IDEEPFallbackOp final : public IDEEPOperator { parent_name += "_cpu_output_blob_" + base_def_.type(); } local_output_blobs_.push_back(ws->CreateBlob(parent_name)); - CHECK_NOTNULL(local_output_blobs_.back()); + TORCH_CHECK_NOTNULL(local_output_blobs_.back()); forwarded_output_blobs[base_def_.output(i)] = parent_name; output_inplace_.push_back(false); for (const string &input_name : base_def_.input()) { @@ -74,7 +74,7 @@ class IDEEPFallbackOp final : public IDEEPOperator { // Set up the symbols for the local workspace. for (const string& name : base_def_.input()) { local_input_blobs_.push_back(local_ws_->CreateBlob(name)); - CHECK_NOTNULL(local_input_blobs_.back()); + TORCH_CHECK_NOTNULL(local_input_blobs_.back()); } input_share_.resize(local_input_blobs_.size(), false); base_op_.reset(new CPUOp(base_def_, local_ws_.get())); diff --git a/caffe2/ideep/operators/quantization/int8_given_tensor_fill_op.cc b/caffe2/ideep/operators/quantization/int8_given_tensor_fill_op.cc index 8743ad3579295..07d44ae51a457 100644 --- a/caffe2/ideep/operators/quantization/int8_given_tensor_fill_op.cc +++ b/caffe2/ideep/operators/quantization/int8_given_tensor_fill_op.cc @@ -51,7 +51,7 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator { auto data_type = zero_point_ == 0 ? idtype::u8 : idtype::s8; output->init({shape_, data_type}); - DCHECK_EQ(output->get_nelems(), values_.numel()) + TORCH_DCHECK_EQ(output->get_nelems(), values_.numel()) << "output size: " << output->get_nelems() << " given size: " << values_.numel(); @@ -121,7 +121,7 @@ class IDEEPInt8GivenIntTensorFillOp final : public IDEEPOperator { auto* output = Output(OUTPUT); output->init({shape_, idtype::s32}); output->set_scale(ConvertScales(scales_)); - DCHECK_EQ(output->get_nelems(), values_.numel()) + TORCH_DCHECK_EQ(output->get_nelems(), values_.numel()) << "output size: " << output->get_nelems() << " given size: " << values_.numel(); diff --git a/caffe2/ideep/operators/spatial_batch_norm_op.cc b/caffe2/ideep/operators/spatial_batch_norm_op.cc index 2d43871ddd2ad..edbd82e9f5c9b 100644 --- a/caffe2/ideep/operators/spatial_batch_norm_op.cc +++ b/caffe2/ideep/operators/spatial_batch_norm_op.cc @@ -30,10 +30,10 @@ class IDEEPSpatialBNOp final : public IDEEPOperator { const auto& bias = Input(BIAS); auto* Y = Output(OUTPUT); - DCHECK_EQ(scale.ndims(), 1); - DCHECK_EQ(bias.ndims(), 1); - DCHECK_EQ(scale.get_dim(0), X.get_dim(1)); - DCHECK_EQ(bias.get_dim(0), X.get_dim(1)); + TORCH_DCHECK_EQ(scale.ndims(), 1); + TORCH_DCHECK_EQ(bias.ndims(), 1); + TORCH_DCHECK_EQ(scale.get_dim(0), X.get_dim(1)); + TORCH_DCHECK_EQ(bias.get_dim(0), X.get_dim(1)); if (is_test_) { const auto& est_mean = Input(EST_MEAN); diff --git a/caffe2/image/image_input_op.cc b/caffe2/image/image_input_op.cc index be21e791ad169..ff868e1370501 100644 --- a/caffe2/image/image_input_op.cc +++ b/caffe2/image/image_input_op.cc @@ -1,6 +1,6 @@ #include "caffe2/image/image_input_op.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -26,7 +26,7 @@ OPERATOR_SCHEMA(ImageInput) int batch_size = helper.GetSingleArgument("batch_size", 0); int crop = helper.GetSingleArgument("crop", -1); int color = helper.GetSingleArgument("color", 1); - CHECK_GT(crop, 0); + TORCH_CHECK_GT(crop, 0); out[0] = CreateTensorShape( vector{batch_size, crop, crop, color ? 3 : 1}, TensorProto::FLOAT); @@ -160,7 +160,7 @@ The dimension of the output image will always be cropxcrop NO_GRADIENT(ImageInput); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR(ImageInput, IDEEPFallbackOp>); #endif diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h index d1c27e0845a84..51367788c0660 100644 --- a/caffe2/image/image_input_op.h +++ b/caffe2/image/image_input_op.h @@ -530,8 +530,8 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( if (protos.protos_size() == end + 1) { // We have bounding box information const TensorProto& bounding_proto = protos.protos(end); - DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32); - DCHECK_EQ(bounding_proto.int32_data_size(), 4); + TORCH_DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32); + TORCH_DCHECK_EQ(bounding_proto.int32_data_size(), 4); info.bounding_params.valid = true; info.bounding_params.ymin = bounding_proto.int32_data(0); info.bounding_params.xmin = bounding_proto.int32_data(1); @@ -541,7 +541,7 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( if (image_proto.data_type() == TensorProto::STRING) { // encoded image string. - DCHECK_EQ(image_proto.string_data_size(), 1); + TORCH_DCHECK_EQ(image_proto.string_data_size(), 1); const string& encoded_image_str = image_proto.string_data(0); int encoded_size = encoded_image_str.size(); // We use a cv::Mat to wrap the encoded str so we do not need a copy. @@ -582,7 +582,7 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( // TODO: if image decoding was unsuccessful, set label to 0 if (label_proto.data_type() == TensorProto::FLOAT) { if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) { - DCHECK_EQ(label_proto.float_data_size(), 1); + TORCH_DCHECK_EQ(label_proto.float_data_size(), 1); prefetched_label_.mutable_data()[item_id] = label_proto.float_data(0); } else if (label_type_ == MULTI_LABEL_SPARSE) { @@ -614,7 +614,7 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( } } else if (label_proto.data_type() == TensorProto::INT32) { if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) { - DCHECK_EQ(label_proto.int32_data_size(), 1); + TORCH_DCHECK_EQ(label_proto.int32_data_size(), 1); prefetched_label_.mutable_data()[item_id] = label_proto.int32_data(0); } else if (label_type_ == MULTI_LABEL_SPARSE) { diff --git a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm index 7d868979fa33c..7593281040c2e 100644 --- a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm +++ b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm @@ -284,7 +284,7 @@ constexpr int computeMPSAlignOffset(int kernel, int pad) { size_t ComputeStartIndex( const TensorCPU& tensor, const std::vector& index) { - DCHECK_EQ(index.size(), tensor.dim()); + TORCH_DCHECK_EQ(index.size(), tensor.dim()); size_t ret = 0; for (int i = 0; i < index.size(); i++) { @@ -299,7 +299,7 @@ size_t ComputeStartIndex( utils::ConstTensorView GetSubTensorView( const TensorCPU& tensor, int dim0_start_index) { - DCHECK_EQ(tensor.meta().itemsize(), sizeof(T)); + TORCH_DCHECK_EQ(tensor.meta().itemsize(), sizeof(T)); if (tensor.size() == 0) { return utils::ConstTensorView(nullptr, {}); @@ -1490,7 +1490,7 @@ bool RunOnDeviceWithOrderNCHW() override { caffe2::Timer consT; std::vector refilter(kH * kW * output_channels * input_channels); refilter.assign(kH * kW * output_channels * input_channels, 0.0f); - DCHECK_EQ(refilter.size(), filter.size()); + TORCH_DCHECK_EQ(refilter.size(), filter.size()); auto* filter_ = filter.template data(); // For iOS11+ Reformat weights from WT[IC][OC][kH][kW] to // W[OC][kH][kW][IC]; For previous versions, reformat weights @@ -1512,14 +1512,14 @@ bool RunOnDeviceWithOrderNCHW() override { kw * output_channels * input_channels + oc * input_channels + ic; } - DCHECK_LT(inputIdx, filter.size()); - DCHECK_LT(outputIdx, filter.size()); + TORCH_DCHECK_LT(inputIdx, filter.size()); + TORCH_DCHECK_LT(outputIdx, filter.size()); refilter[outputIdx] = filter_[inputIdx]; } } } } - DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW); + TORCH_DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW); // initialize data structures if (runtimeAtLeastIOS11) { MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor @@ -2225,7 +2225,7 @@ void ProposalsForOneImage( auto keep = utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_); - DCHECK_LE(keep.size(), scores.size()); + TORCH_DCHECK_LE(keep.size(), scores.size()); // 4. sort all (proposal, score) pairs by score from highest to lowest // 5. take top pre_nms_topN (e.g. 6000) diff --git a/caffe2/mobile/contrib/ios/mpscnn/mpscnn_test.mm b/caffe2/mobile/contrib/ios/mpscnn/mpscnn_test.mm index 0fa08fdabd4fd..b145f3031378d 100644 --- a/caffe2/mobile/contrib/ios/mpscnn/mpscnn_test.mm +++ b/caffe2/mobile/contrib/ios/mpscnn/mpscnn_test.mm @@ -126,7 +126,7 @@ void testMPSCNN() { CAFFE_ENFORCE_EQ(t1.sizes(), t2.sizes()); for (auto i = 0; i < t1.size(); ++i) { // FP16 <-> FP32 round trip. - CHECK_NEAR(t1.data()[i], t2.data()[i], 1e-2); + TORCH_CHECK_NEAR(t1.data()[i], t2.data()[i], 1e-2); } } } @@ -197,7 +197,7 @@ void testMPSCNN() { CAFFE_ENFORCE_EQ(t1.size(), t2.size()); for (auto i = 0; i < t1.size(); ++i) { // FP16 <-> FP32 round trip. - CHECK_NEAR(t1.data()[i], t2.data()[i], 1e-2); + TORCH_CHECK_NEAR(t1.data()[i], t2.data()[i], 1e-2); } } } @@ -274,7 +274,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -467,7 +467,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -560,7 +560,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -651,7 +651,7 @@ void testMPSCNN() { const float t2_i = t2.data()[i]; // LOG(INFO) << "i: " << i << ", cpu: " << t1_i << ", mtl: " << // t2_i; - CHECK_NEAR(t1_i, t2_i, 0.7); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.7); } } } @@ -763,7 +763,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -851,7 +851,7 @@ void testMPSCNN() { const float t2_i = t2.data()[i]; // LOG(INFO) << "i: " << i << ", " << "CPU: " << t1_i << ", MTL: " << // t2_i; - CHECK_NEAR(t1_i, t2_i, 0.01); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.01); } } } @@ -932,7 +932,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -991,7 +991,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -1050,7 +1050,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -1166,7 +1166,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.2); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.2); } } } @@ -1264,7 +1264,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.3); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.3); } } } @@ -1378,7 +1378,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -1481,7 +1481,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -1589,7 +1589,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -1713,7 +1713,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -1784,7 +1784,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.02); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.02); } } @@ -1849,7 +1849,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.01); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.01); } } @@ -1914,7 +1914,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.01); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.01); } } @@ -2003,7 +2003,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.05); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.05); } } @@ -2057,7 +2057,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.02); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.02); } } } @@ -2123,7 +2123,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -2237,7 +2237,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -2349,7 +2349,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -2428,7 +2428,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -2570,7 +2570,7 @@ void testMPSCNN() { const float t3_i = t3.data()[i / 5]; if (t3_i - HALF_MIN_VAL * 2 > 0) { LOG(INFO) << i << " " << t1_i << " " << t2_i << " " << t3_i; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } @@ -2579,7 +2579,7 @@ void testMPSCNN() { const float t3_i = t3.data()[i]; const float t4_i = t4.data()[i]; LOG(INFO) << i << " " << t3_i; - CHECK_NEAR(t3_i, t4_i, 0.1); + TORCH_CHECK_NEAR(t3_i, t4_i, 0.1); } } @@ -2634,7 +2634,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -2875,7 +2875,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -2943,7 +2943,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -3030,7 +3030,7 @@ void testMPSCNN() { // FP16 <-> FP32 round trip, accumulation, etc. const float t1_i = t1.data()[i]; const float t2_i = t2.data()[i]; - CHECK_NEAR(t1_i, t2_i, 0.1); + TORCH_CHECK_NEAR(t1_i, t2_i, 0.1); } } } @@ -3072,10 +3072,10 @@ void testMPSCNN() { } return arg->i(); }; - CHECK_EQ(rc(0), 1); - CHECK_EQ(rc(1), 2); - CHECK_EQ(rc(2), 1); - CHECK_EQ(rc(3), 1); + TORCH_CHECK_EQ(rc(0), 1); + TORCH_CHECK_EQ(rc(1), 2); + TORCH_CHECK_EQ(rc(2), 1); + TORCH_CHECK_EQ(rc(3), 1); } { @@ -3117,18 +3117,18 @@ void testMPSCNN() { auto ty = [&](size_t i) { return netdef.op(i).type(); }; auto i0 = [&](size_t i) { return netdef.op(i).input(0); }; auto o0 = [&](size_t i) { return netdef.op(i).output(0); }; - CHECK_EQ(netdef.op_size(), 4); - CHECK_EQ(ty(0), "CopyToMPSCNN"); - CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu")); - CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu")); - CHECK_EQ(ty(3), "CopyFromMPSCNN"); - CHECK_EQ(i0(0), "X"); - CHECK_EQ(i0(1), o0(0)); - CHECK_EQ(i0(2), "X2"); - CHECK_EQ(o0(2), i0(3)); - CHECK_EQ(o0(3), "Y"); - CHECK_EQ(netdef.external_input(0), "X"); - CHECK_EQ(netdef.external_output(0), "Y"); + TORCH_CHECK_EQ(netdef.op_size(), 4); + TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN"); + TORCH_CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu")); + TORCH_CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu")); + TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN"); + TORCH_CHECK_EQ(i0(0), "X"); + TORCH_CHECK_EQ(i0(1), o0(0)); + TORCH_CHECK_EQ(i0(2), "X2"); + TORCH_CHECK_EQ(o0(2), i0(3)); + TORCH_CHECK_EQ(o0(3), "Y"); + TORCH_CHECK_EQ(netdef.external_input(0), "X"); + TORCH_CHECK_EQ(netdef.external_output(0), "Y"); } } @@ -3195,18 +3195,18 @@ void testMPSCNN() { op.add_output("Z"); } netdef = rewriteForMetal(netdef); - CHECK_EQ(netdef.op_size(), 4); + TORCH_CHECK_EQ(netdef.op_size(), 4); auto ty = [&](size_t i) { return netdef.op(i).type(); }; auto i0 = [&](size_t i) { return netdef.op(i).input(0); }; auto o0 = [&](size_t i) { return netdef.op(i).output(0); }; - CHECK_EQ(ty(0), "CopyToMPSCNN"); - CHECK_EQ(ty(1), "MPSCNNConvRelu"); - CHECK_EQ(ty(2), "MPSCNNRelu"); - CHECK_EQ(ty(3), "CopyFromMPSCNN"); - CHECK_EQ(i0(1), o0(0)); - CHECK_EQ(o0(1), "Z"); - CHECK_EQ(i0(2), "Z"); - CHECK_EQ(o0(2), i0(3)); + TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN"); + TORCH_CHECK_EQ(ty(1), "MPSCNNConvRelu"); + TORCH_CHECK_EQ(ty(2), "MPSCNNRelu"); + TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN"); + TORCH_CHECK_EQ(i0(1), o0(0)); + TORCH_CHECK_EQ(o0(1), "Z"); + TORCH_CHECK_EQ(i0(2), "Z"); + TORCH_CHECK_EQ(o0(2), i0(3)); } { @@ -3235,21 +3235,21 @@ void testMPSCNN() { op.add_output("Z"); } netdef = rewriteForMetal(netdef); - CHECK_EQ(netdef.op_size(), 5); + TORCH_CHECK_EQ(netdef.op_size(), 5); auto ty = [&](size_t i) { return netdef.op(i).type(); }; auto i0 = [&](size_t i) { return netdef.op(i).input(0); }; auto o0 = [&](size_t i) { return netdef.op(i).output(0); }; - CHECK_EQ(ty(0), "CopyToMPSCNN"); - CHECK_EQ(ty(1), "MPSCNNConv"); - CHECK_EQ(ty(2), "MPSCNNRelu"); - CHECK_EQ(ty(3), "MPSCNNRelu"); - CHECK_EQ(ty(4), "CopyFromMPSCNN"); - CHECK_EQ(i0(1), o0(0)); - CHECK_EQ(o0(1), "Y"); - CHECK_EQ(i0(2), o0(1)); - CHECK_EQ(o0(2), "Z"); - CHECK_EQ(i0(3), o0(1)); - CHECK_EQ(o0(3), i0(4)); + TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN"); + TORCH_CHECK_EQ(ty(1), "MPSCNNConv"); + TORCH_CHECK_EQ(ty(2), "MPSCNNRelu"); + TORCH_CHECK_EQ(ty(3), "MPSCNNRelu"); + TORCH_CHECK_EQ(ty(4), "CopyFromMPSCNN"); + TORCH_CHECK_EQ(i0(1), o0(0)); + TORCH_CHECK_EQ(o0(1), "Y"); + TORCH_CHECK_EQ(i0(2), o0(1)); + TORCH_CHECK_EQ(o0(2), "Z"); + TORCH_CHECK_EQ(i0(3), o0(1)); + TORCH_CHECK_EQ(o0(3), i0(4)); } { @@ -3277,14 +3277,14 @@ void testMPSCNN() { auto ty = [&](size_t i) { return netdef.op(i).type(); }; auto i0 = [&](size_t i) { return netdef.op(i).input(0); }; auto o0 = [&](size_t i) { return netdef.op(i).output(0); }; - CHECK_EQ(netdef.op_size(), 3); - CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"); - CHECK_EQ(ty(1), "MPSCNNRelu"); - CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess"); - CHECK_EQ(i0(0), "X"); - CHECK_EQ(i0(1), o0(0)); - CHECK_EQ(i0(2), o0(1)); - CHECK_EQ(o0(2), "Z"); + TORCH_CHECK_EQ(netdef.op_size(), 3); + TORCH_CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess"); + TORCH_CHECK_EQ(ty(1), "MPSCNNRelu"); + TORCH_CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess"); + TORCH_CHECK_EQ(i0(0), "X"); + TORCH_CHECK_EQ(i0(1), o0(0)); + TORCH_CHECK_EQ(i0(2), o0(1)); + TORCH_CHECK_EQ(o0(2), "Z"); } LOG(INFO) << "All MPSCNN tests passed."; } @@ -3296,12 +3296,12 @@ NetDef truncateAfter(NetDef def, size_t idx) { for (auto i = 0; i < toRemove; ++i) { def.mutable_op()->RemoveLast(); } - CHECK_EQ(def.op_size(), idx + 1); + TORCH_CHECK_EQ(def.op_size(), idx + 1); return def; } NetDef addMPSCNNCopyFinalizer(NetDef def) { - CHECK_GE(def.op_size(), 1); + TORCH_CHECK_GE(def.op_size(), 1); const auto name = def.mutable_op(def.op_size() - 1)->output(0); def.mutable_op(def.op_size() - 1)->set_output(0, "METAL_COPIER"); { @@ -3315,7 +3315,7 @@ NetDef addMPSCNNCopyFinalizer(NetDef def) { void compareModels(const NetDef& initNet, NetDef predictNet) { auto* arg = predictNet.mutable_op(0)->mutable_arg(0); - CHECK_EQ(arg->name(), "noise_std"); + TORCH_CHECK_EQ(arg->name(), "noise_std"); arg->set_f(0.000001); NetDef metalPredictNet; @@ -3365,7 +3365,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) { { const auto& mt = mws.GetBlob(name)->Get(); const auto& ct = cws.GetBlob(name)->Get(); - CHECK_EQ(mt.sizes(), ct.sizes()); + TORCH_CHECK_EQ(mt.sizes(), ct.sizes()); for (auto j = 0; j < mt.size(); ++j) { if (mt.IsType()) { if (j < 10) { @@ -3373,7 +3373,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) { << ", CPU: " << ct.data()[j] << ", MTL: " << mt.data()[j]; } - CHECK_NEAR(mt.data()[j], ct.data()[j], 5); + TORCH_CHECK_NEAR(mt.data()[j], ct.data()[j], 5); } else { CHECK(mt.IsType()); if (j < 10) { @@ -3381,7 +3381,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) { << ", CPU: " << ct.data()[j] << ", MTL: " << mt.data()[j]; } - CHECK_NEAR(mt.data()[j], ct.data()[j], 5); + TORCH_CHECK_NEAR(mt.data()[j], ct.data()[j], 5); } } } @@ -3428,7 +3428,7 @@ void verifyRewrite( LOG(INFO) << "One of the operator failed."; return; } - // CHECK_EQ(mt.sizes(), ct.sizes()); + // TORCH_CHECK_EQ(mt.sizes(), ct.sizes()); for (auto j = 0; j < fmin(mt.size(), ct.size()); ++j) { if (mt.IsType()) { if (j < 10) { @@ -3437,7 +3437,7 @@ void verifyRewrite( << ", MTL: " << mt.data()[j]; } // Disabling check for now because of precision issues - // CHECK_NEAR(mt.data()[j], ct.data()[j], 5); + // TORCH_CHECK_NEAR(mt.data()[j], ct.data()[j], 5); } else { LOG(INFO) << "Type uint8_t"; CHECK(mt.IsType()); @@ -3447,7 +3447,7 @@ void verifyRewrite( << ", MTL: " << mt.data()[j]; } // Disabling check for now. - // CHECK_NEAR(mt.data()[j], ct.data()[j], 5); + // TORCH_CHECK_NEAR(mt.data()[j], ct.data()[j], 5); } } } diff --git a/caffe2/mobile/contrib/ulp2/ulp.cc b/caffe2/mobile/contrib/ulp2/ulp.cc index e0abb3e4b426c..4442deeeaa56b 100644 --- a/caffe2/mobile/contrib/ulp2/ulp.cc +++ b/caffe2/mobile/contrib/ulp2/ulp.cc @@ -170,7 +170,7 @@ void filterNormalization11(const TensorCPU& WQ, TensorCPU* WQN) { for (auto j = 0; j < WQs; ++j) { bitSum += __builtin_popcount(WQdata[f * WQs + j]); } - DCHECK_LE(bitSum, WQbits); + TORCH_DCHECK_LE(bitSum, WQbits); WQNdata[f] = 2 * bitSum - WQbits; } } diff --git a/caffe2/mobile/contrib/ulp2/ulp_neon.cc b/caffe2/mobile/contrib/ulp2/ulp_neon.cc index 254c811528ce0..48a4aba896da3 100644 --- a/caffe2/mobile/contrib/ulp2/ulp_neon.cc +++ b/caffe2/mobile/contrib/ulp2/ulp_neon.cc @@ -19,7 +19,7 @@ inline void quantize2bNeon(size_t QC, float offset, float inter_center_distance, std::array XQdata) { - DCHECK_EQ(QC % 8, 0); + TORCH_DCHECK_EQ(QC % 8, 0); const auto offset_plus_2_inter_center_distance = vdupq_n_f32(offset + 2 * inter_center_distance); const auto offset_plus_inter_center_distance = vdupq_n_f32(offset + inter_center_distance); const auto offset_ = vdupq_n_f32(offset); @@ -291,7 +291,7 @@ void qgess_packed(const uint8_t* __restrict__ Ablock, F&& f) { static_assert(kUnrollN % 8 == 0, ""); static_assert(TileDepthBytes == 16, ""); - DCHECK_EQ(QK % 16, 0); + TORCH_DCHECK_EQ(QK % 16, 0); uint16x8_t acc[kUnrollM][kUnrollN / 8]; for (size_t mm = 0; mm < kUnrollM; ++mm) { for (size_t nn = 0; nn < kUnrollN / 8; ++nn) { diff --git a/caffe2/mobile/contrib/ulp2/ulp_test.cc b/caffe2/mobile/contrib/ulp2/ulp_test.cc index d373c0ea4d496..9bd308c5a5804 100644 --- a/caffe2/mobile/contrib/ulp2/ulp_test.cc +++ b/caffe2/mobile/contrib/ulp2/ulp_test.cc @@ -19,7 +19,7 @@ void conv(const ConvArgs& args, (X.dim32(1) - KH + args.pad_t + args.pad_b) / args.stride_h + 1, (X.dim32(2) - KW + args.pad_l + args.pad_r) / args.stride_w + 1, W.dim32(0)); - CHECK_EQ(W.dim32(3), X.dim32(3)); + TORCH_CHECK_EQ(W.dim32(3), X.dim32(3)); const auto OH = Y->dim32(1); const auto OW = Y->dim32(2); const auto OC = Y->dim32(3); @@ -155,7 +155,7 @@ inline void gemmNT(int M, int N, int K, const float* A, const float* B, float* C } inline void qgemmNT(int M, int N, int K, const uint8_t* A, const uint8_t* B, float* C) { - CHECK_EQ(K % 8, 0); + TORCH_CHECK_EQ(K % 8, 0); const int QK = K / 8; for (auto m = 0; m < M; ++m) { for (auto n = 0; n < N; ++n) { diff --git a/caffe2/operators/apmeter_op.cc b/caffe2/operators/apmeter_op.cc index 62555dae54f25..b515157c266d1 100644 --- a/caffe2/operators/apmeter_op.cc +++ b/caffe2/operators/apmeter_op.cc @@ -12,7 +12,7 @@ void APMeterOp::BufferPredictions( // Initialize the buffer buffers_.resize(D, std::vector(buffer_size_)); } - DCHECK_EQ(buffers_.size(), D); + TORCH_DCHECK_EQ(buffers_.size(), D); // Fill atmose buffer_size_ data at a time, so truncate the input if needed if (N > buffer_size_) { @@ -48,12 +48,12 @@ bool APMeterOp::RunOnDevice() { auto& label = Input(LABEL); // Check dimensions - DCHECK_EQ(X.dim(), 2); + TORCH_DCHECK_EQ(X.dim(), 2); int N = X.dim32(0); int D = X.dim32(1); - DCHECK_EQ(label.dim(), 2); - DCHECK_EQ(label.dim32(0), N); - DCHECK_EQ(label.dim32(1), D); + TORCH_DCHECK_EQ(label.dim(), 2); + TORCH_DCHECK_EQ(label.dim32(0), N); + TORCH_DCHECK_EQ(label.dim32(1), D); auto* Y = Output(0, {D}, at::dtype()); const auto* Xdata = X.data(); diff --git a/caffe2/operators/atomic_ops.cc b/caffe2/operators/atomic_ops.cc index d5ce0d32cd83b..f41f6bb168815 100644 --- a/caffe2/operators/atomic_ops.cc +++ b/caffe2/operators/atomic_ops.cc @@ -3,7 +3,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -97,7 +97,7 @@ REGISTER_CPU_OPERATOR(CreateMutex, CreateMutexOp); REGISTER_CPU_OPERATOR(AtomicFetchAdd, AtomicFetchAddOp); REGISTER_CPU_OPERATOR(AtomicFetchAdd64, AtomicFetchAddOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( CreateMutex, IDEEPFallbackOp>); diff --git a/caffe2/operators/batch_box_cox_op.cc b/caffe2/operators/batch_box_cox_op.cc index 51b041669bc0f..aa444330969b5 100644 --- a/caffe2/operators/batch_box_cox_op.cc +++ b/caffe2/operators/batch_box_cox_op.cc @@ -124,8 +124,8 @@ bool BatchBoxCoxOp::DoRunWithType() { if (K > 1) { TileArrayIntoVector(lambda1_ptr, D, K, &b.lambda1_); TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_); - DCHECK_EQ(K * D, b.lambda1_.size()); - DCHECK_EQ(K * D, b.lambda2_.size()); + TORCH_DCHECK_EQ(K * D, b.lambda1_.size()); + TORCH_DCHECK_EQ(K * D, b.lambda2_.size()); for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { BoxCoxNonzeroLambda( K * D, @@ -144,7 +144,7 @@ bool BatchBoxCoxOp::DoRunWithType() { int64_t i = 0; if (K > 1) { TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_z_); - DCHECK_EQ(K * D, b.lambda2_z_.size()); + TORCH_DCHECK_EQ(K * D, b.lambda2_z_.size()); for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { BoxCoxZeroLambda( K * D, data_ptr, b.lambda2_z_.data(), k_eps, output_ptr); @@ -176,9 +176,9 @@ bool BatchBoxCoxOp::DoRunWithType() { zeros_.resize(D - n); TileIndicesInPlace(&nonzeros_, D, K); TileIndicesInPlace(&zeros_, D, K); - DCHECK_EQ(nonzeros_.size(), b.lambda1_.size()); - DCHECK_EQ(nonzeros_.size(), b.lambda2_.size()); - DCHECK_EQ(zeros_.size(), b.lambda2_z_.size()); + TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda1_.size()); + TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda2_.size()); + TORCH_DCHECK_EQ(zeros_.size(), b.lambda2_z_.size()); for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { BoxCoxMixedLambda( data_ptr, diff --git a/caffe2/operators/batch_permutation_op.cc b/caffe2/operators/batch_permutation_op.cc index 6d0b59c284eb7..97a0b6233de02 100644 --- a/caffe2/operators/batch_permutation_op.cc +++ b/caffe2/operators/batch_permutation_op.cc @@ -3,7 +3,7 @@ #include #include -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -89,7 +89,7 @@ bool BatchPermutationGradientOp::RunOnDevice() { return true; } -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( BatchPermutation, IDEEPFallbackOp>); diff --git a/caffe2/operators/bbox_transform_op.cc b/caffe2/operators/bbox_transform_op.cc index 5ff210bb5bf51..6ef95c9bed42c 100644 --- a/caffe2/operators/bbox_transform_op.cc +++ b/caffe2/operators/bbox_transform_op.cc @@ -100,7 +100,7 @@ bool BBoxTransformOp::RunOnDevice() { CAFFE_ENFORCE_EQ(iminfo_in.dim32(1), 3); const int batch_size = iminfo_in.dim32(0); - DCHECK_EQ(weights_.size(), 4); + TORCH_DCHECK_EQ(weights_.size(), 4); Eigen::Map boxes0( roi_in.data(), roi_in.dim32(0), roi_in.dim32(1)); diff --git a/caffe2/operators/boolean_unmask_ops_test.cc b/caffe2/operators/boolean_unmask_ops_test.cc index ae7425a151dca..096e60f19f93b 100644 --- a/caffe2/operators/boolean_unmask_ops_test.cc +++ b/caffe2/operators/boolean_unmask_ops_test.cc @@ -64,7 +64,7 @@ TEST(BooleanUnmaskTest, Test) { auto& unmasked_data = unmasked_data_blob->Get(); EXPECT_EQ(unmasked_data.numel(), 1); - CHECK_EQ(unmasked_data.data()[0], 1.0f); + TORCH_CHECK_EQ(unmasked_data.data()[0], 1.0f); } } // namespace caffe2 diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc index fba8aefab7d46..ead3e8f83dd13 100644 --- a/caffe2/operators/box_with_nms_limit_op.cc +++ b/caffe2/operators/box_with_nms_limit_op.cc @@ -161,14 +161,14 @@ const auto& tscores = Input(0); // Pick the first `detections_per_im_` boxes with highest scores auto all_scores_sorted = get_all_scores_sorted(); - DCHECK_GT(all_scores_sorted.size(), detections_per_im_); + TORCH_DCHECK_GT(all_scores_sorted.size(), detections_per_im_); // Reconstruct keeps from `all_scores_sorted` for (auto& cur_keep : keeps) { cur_keep.clear(); } for (int i = 0; i < detections_per_im_; i++) { - DCHECK_GT(all_scores_sorted.size(), i); + TORCH_DCHECK_GT(all_scores_sorted.size(), i); auto& cur = all_scores_sorted[i]; keeps[cur.first].push_back(cur.second); } diff --git a/caffe2/operators/cross_entropy_op.cc b/caffe2/operators/cross_entropy_op.cc index fb42d667e224c..9710096fed5d7 100644 --- a/caffe2/operators/cross_entropy_op.cc +++ b/caffe2/operators/cross_entropy_op.cc @@ -272,8 +272,8 @@ bool MakeTwoClassOp::RunOnDevice() { const auto* Xdata = X.data(); auto* Ydata = Y->template mutable_data(); for (int64_t i = 0; i < N; ++i) { - DCHECK_GE(Xdata[i], 0.0); - DCHECK_LE(Xdata[i], 1.0); + TORCH_DCHECK_GE(Xdata[i], 0.0); + TORCH_DCHECK_LE(Xdata[i], 1.0); Ydata[i * 2] = 1.0 - Xdata[i]; Ydata[i * 2 + 1] = Xdata[i]; } diff --git a/caffe2/operators/deform_conv_op.cu b/caffe2/operators/deform_conv_op.cu index 02e528c5f9279..0257be46d2c9b 100644 --- a/caffe2/operators/deform_conv_op.cu +++ b/caffe2/operators/deform_conv_op.cu @@ -308,7 +308,7 @@ void DeformConvOpBase::DeformableIm2col( at::IntArrayRef im_shape, at::IntArrayRef col_shape, DType* data_col) { - CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); + TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); CAFFE_ENFORCE_EQ(pad_t(), pad_b()); CAFFE_ENFORCE_EQ(pad_l(), pad_r()); const int pad_h = pad_t(); @@ -444,7 +444,7 @@ void DeformConvOpBase::DeformableCol2im( index_t channel_per_deformable_group = im_shape[1] / deformable_group_; index_t num_kernels = size_from_dim_(0, col_shape); // num_axes should be smaller than block size - CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); + TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. // NOLINT_NEXT_LINE(whitespace/operators) @@ -592,7 +592,7 @@ void DeformConvOpBase::DeformableCol2imCoord( kernel_w() * deformable_group_; index_t channel_per_deformable_group = col_shape[0] / deformable_group_; // num_axes should be smaller than block size - CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); + TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS); // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. // NOLINT_NEXT_LINE(whitespace/operators) diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc index 9ea8eea5a2725..d869a568a7fc9 100644 --- a/caffe2/operators/distance_op.cc +++ b/caffe2/operators/distance_op.cc @@ -1,7 +1,7 @@ #include "caffe2/operators/distance_op.h" #include "caffe2/core/types.h" #include "caffe2/utils/eigen_utils.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -420,7 +420,7 @@ REGISTER_CPU_OPERATOR(L1Distance, L1DistanceOp); REGISTER_CPU_OPERATOR( L1DistanceGradient, L1DistanceGradientOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( L1DistanceGradient, IDEEPFallbackOp>); diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h index b95f8401f761a..b846faabc490a 100644 --- a/caffe2/operators/filler_op.h +++ b/caffe2/operators/filler_op.h @@ -444,7 +444,7 @@ class GaussianFillOp final : public FillerOp { : FillerOp(std::forward(args)...), mean_(this->template GetSingleArgument("mean", 0)), std_(this->template GetSingleArgument("std", 1)) { - DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative."; + TORCH_DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative."; } bool Fill(Tensor* output) override { diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h index 563465f6220e4..22810237aadc9 100644 --- a/caffe2/operators/fully_connected_op.h +++ b/caffe2/operators/fully_connected_op.h @@ -73,7 +73,7 @@ class FullyConnectedOp final : public Operator { Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. - DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); + TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); Y_shape_cache_.resize(canonical_axis + 1); Y_shape_cache_[canonical_axis] = N; auto* Y = Output(0, Y_shape_cache_, at::dtype()); diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc index c479b165bc632..25d813fdda3bb 100644 --- a/caffe2/operators/generate_proposals_op.cc +++ b/caffe2/operators/generate_proposals_op.cc @@ -11,7 +11,7 @@ namespace { size_t ComputeStartIndex( const TensorCPU& tensor, const std::vector& index) { - DCHECK_EQ(index.size(), tensor.dim()); + TORCH_DCHECK_EQ(index.size(), tensor.dim()); size_t ret = 0; // NOLINTNEXTLINE(clang-diagnostic-sign-compare) @@ -27,7 +27,7 @@ template utils::ConstTensorView GetSubTensorView( const TensorCPU& tensor, int dim0_start_index) { - DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T)); + TORCH_DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T)); if (tensor.numel() == 0) { return utils::ConstTensorView(nullptr, {}); @@ -244,7 +244,7 @@ void GenerateProposalsOp::ProposalsForOneImage( // 3. remove predicted boxes with either height or width < min_size auto keep = utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_); - DCHECK_LE(keep.size(), scores_sorted.size()); + TORCH_DCHECK_LE(keep.size(), scores_sorted.size()); // 6. apply loose nms (e.g. threshold = 0.7) // 7. take after_nms_topN (e.g. 300) diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h index b783b3db437b5..f70f595d40f1c 100644 --- a/caffe2/operators/generate_proposals_op.h +++ b/caffe2/operators/generate_proposals_op.h @@ -28,7 +28,7 @@ class ConstTensorView { return dims_; } int dim(int i) const { - DCHECK_LE(i, dims_.size()); + TORCH_DCHECK_LE(i, dims_.size()); return dims_[i]; } const T* data() const { diff --git a/caffe2/operators/generate_proposals_op_gpu_test.cc b/caffe2/operators/generate_proposals_op_gpu_test.cc index d328f81726c05..ef1ced04e0452 100644 --- a/caffe2/operators/generate_proposals_op_gpu_test.cc +++ b/caffe2/operators/generate_proposals_op_gpu_test.cc @@ -316,7 +316,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) { // Add angle in bbox deltas int num_boxes = scores.size(); - CHECK_EQ(bbx.size() / 4, num_boxes); + TORCH_CHECK_EQ(bbx.size() / 4, num_boxes); vector bbx_with_angle(num_boxes * box_dim); // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta // at each spatial location for each anchor. @@ -516,7 +516,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) { // Add angle in bbox deltas int num_boxes = scores.size(); - CHECK_EQ(bbx.size() / 4, num_boxes); + TORCH_CHECK_EQ(bbx.size() / 4, num_boxes); vector bbx_with_angle(num_boxes * box_dim); // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta // at each spatial location for each anchor. diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc index 9692c2846e97c..f869adcc17bee 100644 --- a/caffe2/operators/generate_proposals_op_test.cc +++ b/caffe2/operators/generate_proposals_op_test.cc @@ -494,7 +494,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { // Add angle in bbox deltas auto num_boxes = scores.size(); - CHECK_EQ(bbx.size() / 4, num_boxes); + TORCH_CHECK_EQ(bbx.size() / 4, num_boxes); vector bbx_with_angle(num_boxes * box_dim); // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta // at each spatial location for each anchor. @@ -667,7 +667,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) { // Add angle in bbox deltas auto num_boxes = scores.size(); - CHECK_EQ(bbx.size() / 4, num_boxes); + TORCH_CHECK_EQ(bbx.size() / 4, num_boxes); vector bbx_with_angle(num_boxes * box_dim); // bbx (deltas) is in shape (A * 4, H, W). Insert angle delta // at each spatial location for each anchor. diff --git a/caffe2/operators/given_tensor_byte_string_to_uint8_fill_op.h b/caffe2/operators/given_tensor_byte_string_to_uint8_fill_op.h index 5cab269dc407e..28cec6d75b93e 100644 --- a/caffe2/operators/given_tensor_byte_string_to_uint8_fill_op.h +++ b/caffe2/operators/given_tensor_byte_string_to_uint8_fill_op.h @@ -36,7 +36,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp { } bool Fill(Tensor* output) override { - DCHECK_EQ(output->numel(), values_.numel()) + TORCH_DCHECK_EQ(output->numel(), values_.numel()) << "output size: " << output->numel() << " given size: " << values_.numel(); auto* data = output->template mutable_data(); @@ -51,7 +51,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp { private: void Extract() { auto source_values = this->template GetRepeatedArgument("values"); - DCHECK_EQ(source_values.size(), 1) + TORCH_DCHECK_EQ(source_values.size(), 1) << "expected size: 1 " << " given size: " << source_values.size(); diff --git a/caffe2/operators/local_response_normalization_op.cc b/caffe2/operators/local_response_normalization_op.cc index 6bb07e7cb24b7..036816730b776 100644 --- a/caffe2/operators/local_response_normalization_op.cc +++ b/caffe2/operators/local_response_normalization_op.cc @@ -7,7 +7,7 @@ bool LRNOp::RunOnDeviceWithOrderNCHW() { // Note(Yangqing): this one is copied from my Caffe implementation. auto& X = Input(0); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int C = X.dim32(1); const int H = X.dim32(2); @@ -81,7 +81,7 @@ bool LRNOp::RunOnDeviceWithOrderNHWC() { // variants have I written...? auto& X = Input(0); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int H = X.dim32(1); const int W = X.dim32(2); @@ -135,7 +135,7 @@ bool LRNGradientOp::RunOnDeviceWithOrderNCHW() { auto& Y = Input(1); auto& dY = Input(2); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int C = X.dim32(1); const int H = X.dim32(2); @@ -143,8 +143,8 @@ bool LRNGradientOp::RunOnDeviceWithOrderNCHW() { const int image_size = C * H * W; // Loosely checking the size, assuming that the shapes will be the same as // long as the sizes check out. - DCHECK_EQ(X.numel(), Y.numel()); - DCHECK_EQ(X.numel(), dY.numel()); + TORCH_DCHECK_EQ(X.numel(), Y.numel()); + TORCH_DCHECK_EQ(X.numel(), dY.numel()); auto* dX = Output(0, X.sizes(), at::dtype()); const float* Xdata = X.data(); @@ -248,7 +248,7 @@ bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { auto& Y = Input(1); auto& dY = Input(2); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int H = X.dim32(1); const int W = X.dim32(2); @@ -257,8 +257,8 @@ bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { const float* Xdata = X.data(); // Loosely checking the size, assuming that the shapes will be the same as // long as the sizes check out. - DCHECK_EQ(X.numel(), Y.numel()); - DCHECK_EQ(X.numel(), dY.numel()); + TORCH_DCHECK_EQ(X.numel(), Y.numel()); + TORCH_DCHECK_EQ(X.numel(), dY.numel()); auto* dX = Output(0, X.sizes(), at::dtype()); if (!scale_) { scale_ = &local_scale_tensor_; diff --git a/caffe2/operators/local_response_normalization_op.cu b/caffe2/operators/local_response_normalization_op.cu index 811e5ecaa3552..03a10e00e1ee8 100644 --- a/caffe2/operators/local_response_normalization_op.cu +++ b/caffe2/operators/local_response_normalization_op.cu @@ -177,7 +177,7 @@ template<> bool LRNOp::RunOnDeviceWithOrderNCHW() { auto& X = Input(0); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int C = X.dim32(1); const int H = X.dim32(2); @@ -214,7 +214,7 @@ template<> bool LRNOp::RunOnDeviceWithOrderNHWC() { auto& X = Input(0); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int H = X.dim32(1); const int W = X.dim32(2); @@ -252,15 +252,15 @@ bool LRNGradientOp::RunOnDeviceWithOrderNCHW() { auto& Y = Input(1); auto& dY = Input(2); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int C = X.dim32(1); const int H = X.dim32(2); const int W = X.dim32(3); // Loosely checking the size, assuming that the shapes will be the same as // long as the sizes check out. - DCHECK_EQ(X.numel(), Y.numel()); - DCHECK_EQ(X.numel(), dY.numel()); + TORCH_DCHECK_EQ(X.numel(), Y.numel()); + TORCH_DCHECK_EQ(X.numel(), dY.numel()); auto* dX = Output(0, X.sizes(), at::dtype()); const float* Xdata = X.data(); @@ -295,7 +295,7 @@ bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { auto& Y = Input(1); auto& dY = Input(2); - DCHECK_EQ(X.dim(), 4); + TORCH_DCHECK_EQ(X.dim(), 4); const int N = X.dim32(0); const int H = X.dim32(1); const int W = X.dim32(2); @@ -303,8 +303,8 @@ bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { const float* Xdata = X.data(); // Loosely checking the size, assuming that the shapes will be the same as // long as the sizes check out. - DCHECK_EQ(X.numel(), Y.numel()); - DCHECK_EQ(X.numel(), dY.numel()); + TORCH_DCHECK_EQ(X.numel(), Y.numel()); + TORCH_DCHECK_EQ(X.numel(), dY.numel()); auto* dX = Output(0, X.sizes(), at::dtype()); if (!scale_) { scale_ = &local_scale_tensor_; diff --git a/caffe2/operators/local_response_normalization_op.h b/caffe2/operators/local_response_normalization_op.h index b0b02a7f73c51..5d5bda2b4a2c7 100644 --- a/caffe2/operators/local_response_normalization_op.h +++ b/caffe2/operators/local_response_normalization_op.h @@ -22,10 +22,10 @@ class LRNOpBase : public Operator { order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))), pre_pad_((size_ - 1) / 2) { - DCHECK_GT(size_, 0); - DCHECK_EQ(size_ % 2, 1); - DCHECK_GT(alpha_, 0); - DCHECK_GT(beta_, 0); + TORCH_DCHECK_GT(size_, 0); + TORCH_DCHECK_EQ(size_ % 2, 1); + TORCH_DCHECK_GT(alpha_, 0); + TORCH_DCHECK_GT(beta_, 0); } bool RunOnDevice() override { diff --git a/caffe2/operators/mem_query_op.cu b/caffe2/operators/mem_query_op.cu index f53402ce660a9..07e15f040a531 100644 --- a/caffe2/operators/mem_query_op.cu +++ b/caffe2/operators/mem_query_op.cu @@ -11,11 +11,11 @@ class GetGPUMemoryUsageOp final : public Operator { ~GetGPUMemoryUsageOp() override {} bool RunOnDevice() override { - CHECK_EQ(InputSize(), 0); - CHECK_EQ(OutputSize(), 1); + TORCH_CHECK_EQ(InputSize(), 0); + TORCH_CHECK_EQ(OutputSize(), 1); std::vector total_by_gpu = CUDAContext::TotalMemoryByGpu(); std::vector max_by_gpu = CUDAContext::MaxMemoryByGpu(); - CHECK_EQ(total_by_gpu.size(), max_by_gpu.size()); + TORCH_CHECK_EQ(total_by_gpu.size(), max_by_gpu.size()); auto* stats = Output(0, {2, static_cast(total_by_gpu.size())}, at::dtype()); diff --git a/caffe2/operators/multi_class_accuracy_op.cc b/caffe2/operators/multi_class_accuracy_op.cc index 1f44020c4c79c..c48f0fd05ef79 100644 --- a/caffe2/operators/multi_class_accuracy_op.cc +++ b/caffe2/operators/multi_class_accuracy_op.cc @@ -7,13 +7,13 @@ bool MultiClassAccuracyOp::RunOnDevice() { auto& X = Input(PREDICTION); auto& label = Input(LABEL); - DCHECK_EQ(X.dim(), 2); + TORCH_DCHECK_EQ(X.dim(), 2); // amount, number of instances int N = X.dim32(0); // dimension, number of classes int D = X.dim32(1); - DCHECK_EQ(label.dim(), 1); - DCHECK_EQ(label.dim32(0), N); + TORCH_DCHECK_EQ(label.dim(), 1); + TORCH_DCHECK_EQ(label.dim32(0), N); auto* Y0 = Output(0, {D}, at::dtype()); auto* Y1 = Output(1, {D}, at::dtype()); @@ -34,7 +34,7 @@ bool MultiClassAccuracyOp::RunOnDevice() { } } int labelid = labeldata[i]; - DCHECK_LT(labelid, D); + TORCH_DCHECK_LT(labelid, D); if (maxid == labelid) { accuracies[labelid]++; } diff --git a/caffe2/operators/multi_class_accuracy_op.cu b/caffe2/operators/multi_class_accuracy_op.cu index fdb272c1c34d5..52588d1ffccbb 100644 --- a/caffe2/operators/multi_class_accuracy_op.cu +++ b/caffe2/operators/multi_class_accuracy_op.cu @@ -40,13 +40,13 @@ bool MultiClassAccuracyOp::RunOnDevice() { auto& label = Input(LABEL); - DCHECK_EQ(X.dim(), 2); + TORCH_DCHECK_EQ(X.dim(), 2); // amount, number of instances int N = X.dim32(0); // dimension, number of classes int D = X.dim32(1); - DCHECK_EQ(label.dim(), 1); - DCHECK_EQ(label.dim32(0), N); + TORCH_DCHECK_EQ(label.dim(), 1); + TORCH_DCHECK_EQ(label.dim32(0), N); auto* Y0 = Output(0, {D}, at::dtype()); auto* Y1 = Output(1, {D}, at::dtype()); diff --git a/caffe2/operators/operator_fallback_gpu.h b/caffe2/operators/operator_fallback_gpu.h index a728b79b4916f..a829372a11bc2 100644 --- a/caffe2/operators/operator_fallback_gpu.h +++ b/caffe2/operators/operator_fallback_gpu.h @@ -52,12 +52,12 @@ class GPUFallbackOpEx final : public Operator { // Set up the symbols for the local workspace. for (const string& name : def.input()) { local_input_blobs_.push_back(local_ws_.CreateBlob(name)); - CHECK_NOTNULL(local_input_blobs_.back()); + TORCH_CHECK_NOTNULL(local_input_blobs_.back()); } base_op_ = CreateOperator(base_def_, &local_ws_); for (const string& name : def.output()) { local_output_blobs_.push_back(local_ws_.GetBlob(name)); - CHECK_NOTNULL(local_output_blobs_.back()); + TORCH_CHECK_NOTNULL(local_output_blobs_.back()); } } diff --git a/caffe2/operators/perplexity_op.cc b/caffe2/operators/perplexity_op.cc index 882ee4e973535..46a9b2d4c6b16 100644 --- a/caffe2/operators/perplexity_op.cc +++ b/caffe2/operators/perplexity_op.cc @@ -6,7 +6,7 @@ template <> bool PerplexityOp::RunOnDevice() { auto& X = Input(0); - DCHECK_EQ(X.dim(), 1); + TORCH_DCHECK_EQ(X.dim(), 1); int N = X.dim32(0); auto* Y = Output(0, vector(), at::dtype()); diff --git a/caffe2/operators/perplexity_op.cu b/caffe2/operators/perplexity_op.cu index 0f7a066dc41ee..d319ea09fb9d2 100644 --- a/caffe2/operators/perplexity_op.cu +++ b/caffe2/operators/perplexity_op.cu @@ -21,7 +21,7 @@ template <> bool PerplexityOp::RunOnDevice() { auto& X = Input(0); - DCHECK_EQ(X.dim(), 1); + TORCH_DCHECK_EQ(X.dim(), 1); int N = X.dim32(0); auto* Y = Output(0, vector(), at::dtype()); diff --git a/caffe2/operators/piecewise_linear_transform_op.cu b/caffe2/operators/piecewise_linear_transform_op.cu index 5ea95bd6d84db..fe17af9a2f6c5 100644 --- a/caffe2/operators/piecewise_linear_transform_op.cu +++ b/caffe2/operators/piecewise_linear_transform_op.cu @@ -1,14 +1,11 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/operators/piecewise_linear_transform_op.h" - -#include -#include -#include -#include +#include namespace caffe2 { namespace { + __global__ void PieceWiseLinearTransformGeneralKernel( const int N, const int M, @@ -30,8 +27,7 @@ __global__ void PieceWiseLinearTransformGeneralKernel( Y[i] = slopes_group[num_fnc_per_grp - 1] * bounds_group[num_fnc_per_grp] + intercepts_group[num_fnc_per_grp - 1]; } else { - auto low_bound = thrust::lower_bound( - thrust::device, + auto low_bound = c10::cuda::lower_bound( bounds_group, bounds_group + num_fnc_per_grp + 1, X[i]); @@ -59,8 +55,8 @@ __global__ void PieceWiseLinearTransformBinaryKernel1( Y[i] = slopes[num_fnc_per_grp - 1] * bounds[num_fnc_per_grp] + intercepts[num_fnc_per_grp - 1]; } else { - auto low_bound = thrust::lower_bound( - thrust::device, bounds, bounds + num_fnc_per_grp + 1, X[i]); + auto low_bound = c10::cuda::lower_bound( + bounds, bounds + num_fnc_per_grp + 1, X[i]); int bounds_idx = low_bound - bounds - 1; Y[i] = slopes[bounds_idx] * X[i] + intercepts[bounds_idx]; } @@ -87,8 +83,8 @@ __global__ void PieceWiseLinearTransformBinaryKernel2( Y[index + 1] = slopes[num_fnc_per_grp - 1] * bounds[num_fnc_per_grp] + intercepts[num_fnc_per_grp - 1]; } else { - auto low_bound = thrust::lower_bound( - thrust::device, bounds, bounds + num_fnc_per_grp + 1, X[index + 1]); + auto low_bound = c10::cuda::lower_bound( + bounds, bounds + num_fnc_per_grp + 1, X[index + 1]); int bounds_idx = low_bound - bounds - 1; Y[index + 1] = slopes[bounds_idx] * X[index + 1] + intercepts[bounds_idx]; } diff --git a/caffe2/operators/prelu_op.cc b/caffe2/operators/prelu_op.cc index 7470579ba6ecf..b950b2123a2ad 100644 --- a/caffe2/operators/prelu_op.cc +++ b/caffe2/operators/prelu_op.cc @@ -176,7 +176,7 @@ bool PReluGradientOp::RunOnDevice() { CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU"); - DCHECK_EQ(dY.numel(), Y.numel()); + TORCH_DCHECK_EQ(dY.numel(), Y.numel()); auto* dX = Output(0, Y.sizes(), at::dtype()); auto* dW = Output(1, W.sizes(), at::dtype()); diff --git a/caffe2/operators/prelu_op.cu b/caffe2/operators/prelu_op.cu index 38cc592548222..6e1a5dbadaad6 100644 --- a/caffe2/operators/prelu_op.cu +++ b/caffe2/operators/prelu_op.cu @@ -212,7 +212,7 @@ bool PReluGradientOp::RunOnDevice() { CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU"); - DCHECK_EQ(dY.numel(), Y.numel()); + TORCH_DCHECK_EQ(dY.numel(), Y.numel()); auto* dX = Output(0, Y.sizes(), at::dtype()); auto* dW = Output(1, W.sizes(), at::dtype()); diff --git a/caffe2/operators/quant_decode_op.h b/caffe2/operators/quant_decode_op.h index 5253d9975c39b..e581c2cddf0d1 100644 --- a/caffe2/operators/quant_decode_op.h +++ b/caffe2/operators/quant_decode_op.h @@ -37,7 +37,7 @@ void Decode( int sz = output->numel(); for (C10_UNUSED const auto i : c10::irange(sz)) { - DCHECK_LE(*code_ptr, cb_size); + TORCH_DCHECK_LE(*code_ptr, cb_size); *out_ptr++ = cb_ptr[*code_ptr++]; } } else { @@ -49,7 +49,7 @@ void Decode( CAFFE_ENFORCE_EQ(cb_size, output->numel()); auto* out_ptr = output->template mutable_data(); while (gradient_ptr < gradient_end) { - DCHECK_LE(*code_ptr, cb_size); + TORCH_DCHECK_LE(*code_ptr, cb_size); out_ptr[*code_ptr++] += *gradient_ptr++; } } diff --git a/caffe2/operators/quantized/int8_average_pool_op.h b/caffe2/operators/quantized/int8_average_pool_op.h index 1df8f18075648..33faeb6dee3d4 100644 --- a/caffe2/operators/quantized/int8_average_pool_op.h +++ b/caffe2/operators/quantized/int8_average_pool_op.h @@ -43,7 +43,7 @@ class Int8AveragePoolOp final : public ConvPoolOpBase { Y->scale = Y_scale; Y->zero_point = Y_zero_point; - CHECK_EQ(X.t.dim(), 4); + TORCH_CHECK_EQ(X.t.dim(), 4); const int channels = X.t.dim32(3); ConvPoolOpBase::SetOutputSize(X.t, &(Y->t), channels); diff --git a/caffe2/operators/quantized/int8_channel_shuffle_op.h b/caffe2/operators/quantized/int8_channel_shuffle_op.h index 50806d81404fd..9cb309012790e 100644 --- a/caffe2/operators/quantized/int8_channel_shuffle_op.h +++ b/caffe2/operators/quantized/int8_channel_shuffle_op.h @@ -42,10 +42,10 @@ class Int8ChannelShuffleOp final : public ConvPoolOpBase { this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1.0f); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); - CHECK_GE(X.zero_point, std::numeric_limits::min()); - CHECK_LE(X.zero_point, std::numeric_limits::max()); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_GE(X.zero_point, std::numeric_limits::min()); + TORCH_CHECK_LE(X.zero_point, std::numeric_limits::max()); const auto C = X.t.dim32(3); const auto G = this->group_; diff --git a/caffe2/operators/quantized/int8_concat_op.h b/caffe2/operators/quantized/int8_concat_op.h index d0c8d24e9840d..ef5715a6f56a7 100644 --- a/caffe2/operators/quantized/int8_concat_op.h +++ b/caffe2/operators/quantized/int8_concat_op.h @@ -20,13 +20,13 @@ class Int8ConcatOp final : public Operator { if (this->template GetSingleArgument("order", "") == "NHWC") { // Default to C axis axis_ = this->template GetSingleArgument("axis", 3); - CHECK_GE(axis_, 0); - CHECK_LT(axis_, 4); + TORCH_CHECK_GE(axis_, 0); + TORCH_CHECK_LT(axis_, 4); } else if ( this->template GetSingleArgument("order", "") == "NCHW") { axis_ = this->template GetSingleArgument("axis", 1); - CHECK_GE(axis_, 0); - CHECK_LT(axis_, 4); + TORCH_CHECK_GE(axis_, 0); + TORCH_CHECK_LT(axis_, 4); } else { axis_ = this->template GetSingleArgument("axis", 0); } @@ -39,20 +39,20 @@ class Int8ConcatOp final : public Operator { Y->zero_point = X0.zero_point; int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); auto Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_offset, X0.zero_point); - CHECK_EQ(Y_scale, X0.scale); - CHECK_GE(X0.zero_point, std::numeric_limits::min()); - CHECK_LE(X0.zero_point, std::numeric_limits::max()); + TORCH_CHECK_EQ(Y_offset, X0.zero_point); + TORCH_CHECK_EQ(Y_scale, X0.scale); + TORCH_CHECK_GE(X0.zero_point, std::numeric_limits::min()); + TORCH_CHECK_LE(X0.zero_point, std::numeric_limits::max()); auto Y_dims = X0.t.sizes().vec(); if (this->template GetSingleArgument("order", "") == "NHWC") { - CHECK_EQ(Y_dims.size(), 4); + TORCH_CHECK_EQ(Y_dims.size(), 4); } for (const auto i : c10::irange(1, InputSize())) { const auto& Xi = Inputs()[i]->template Get(); - CHECK_EQ(Xi.t.dim(), Y_dims.size()); + TORCH_CHECK_EQ(Xi.t.dim(), Y_dims.size()); for (const auto j : c10::irange(Y_dims.size())) { if (j != axis_) { - CHECK_EQ(Xi.t.size(j), Y_dims[j]); + TORCH_CHECK_EQ(Xi.t.size(j), Y_dims[j]); } } Y_dims[axis_] += Xi.t.size(axis_); diff --git a/caffe2/operators/quantized/int8_conv_op.h b/caffe2/operators/quantized/int8_conv_op.h index 055e85141e261..d830ff351fee1 100644 --- a/caffe2/operators/quantized/int8_conv_op.h +++ b/caffe2/operators/quantized/int8_conv_op.h @@ -56,7 +56,7 @@ class Int8ConvOp final : public ConvPoolOpBase { const bool isDepthwise = this->group_ > 1 && this->group_ == M && this->group_ == C && KC == 1 && KH * KW == 9 && dilation_w() == 1; - CHECK_EQ(Y->t.dim32(3), M); + TORCH_CHECK_EQ(Y->t.dim32(3), M); runWithSharedBuffer(ws_, [&](Tensor* buffer) { initQNNPACK(); diff --git a/caffe2/operators/quantized/int8_conv_transpose_op.h b/caffe2/operators/quantized/int8_conv_transpose_op.h index 797acc2748aef..38c6fd31ba911 100644 --- a/caffe2/operators/quantized/int8_conv_transpose_op.h +++ b/caffe2/operators/quantized/int8_conv_transpose_op.h @@ -47,14 +47,14 @@ class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase { const auto IC = X.t.size(3); - CHECK_EQ(IC, W.t.size(0)); + TORCH_CHECK_EQ(IC, W.t.size(0)); const auto KH = W.t.size(1); const auto KW = W.t.size(2); const auto OC = W.t.size(3); auto sizes = ConvTransposeUnpoolBase::GetOutputSize(X.t, OC); ReinitializeTensor(&(Y->t), sizes, at::dtype().device(CPU)); - CHECK_EQ(OC, Y->t.size(3)); + TORCH_CHECK_EQ(OC, Y->t.size(3)); runWithSharedBuffer(ws_, [&](Tensor* buffer) { initQNNPACK(); diff --git a/caffe2/operators/quantized/int8_fc_op.h b/caffe2/operators/quantized/int8_fc_op.h index 3f0e65adc9f52..0fa30f5b2ab8a 100644 --- a/caffe2/operators/quantized/int8_fc_op.h +++ b/caffe2/operators/quantized/int8_fc_op.h @@ -39,8 +39,8 @@ class Int8FCOp final : public Operator { // (NxHxW)xC == MxK x (NxK) -> MxN const auto K = X.t.size_from_dim(1); const auto N = W.t.size(0); - CHECK_EQ(K, W.t.size(1)); - CHECK_EQ(N, B.t.numel()); + TORCH_CHECK_EQ(K, W.t.size(1)); + TORCH_CHECK_EQ(N, B.t.numel()); const auto M = X.t.numel() / K; ReinitializeTensor(&Y->t, {M, N}, at::dtype().device(CPU)); diff --git a/caffe2/operators/quantized/int8_flatten_op.h b/caffe2/operators/quantized/int8_flatten_op.h index 5660ae49c8b4e..ece57973e25d8 100644 --- a/caffe2/operators/quantized/int8_flatten_op.h +++ b/caffe2/operators/quantized/int8_flatten_op.h @@ -22,8 +22,8 @@ class Int8FlattenOp : public Operator { auto* Y = Outputs()[0]->GetMutable(); int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); auto Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); Y->scale = Y_scale; Y->zero_point = Y_offset; CAFFE_ENFORCE_GE( diff --git a/caffe2/operators/quantized/int8_given_tensor_fill_op.h b/caffe2/operators/quantized/int8_given_tensor_fill_op.h index 8080ca78b344d..cb97076a23398 100644 --- a/caffe2/operators/quantized/int8_given_tensor_fill_op.h +++ b/caffe2/operators/quantized/int8_given_tensor_fill_op.h @@ -46,7 +46,7 @@ class Int8GivenTensorFillOp final : public Operator { } bool Fill(Int8TensorCPU* output) { - DCHECK_EQ(output->t.numel(), values_.numel()) + TORCH_DCHECK_EQ(output->t.numel(), values_.numel()) << "output size: " << output->t.numel() << " given size: " << values_.numel(); auto* data = output->t.template mutable_data(); @@ -98,7 +98,7 @@ class Int8GivenIntTensorFillOp final : public Operator { } bool Fill(Int8TensorCPU* output) { - DCHECK_EQ(output->t.numel(), values_.numel()) + TORCH_DCHECK_EQ(output->t.numel(), values_.numel()) << "output size: " << output->t.numel() << " given size: " << values_.numel(); auto* data = output->t.template mutable_data(); diff --git a/caffe2/operators/quantized/int8_leaky_relu_op.h b/caffe2/operators/quantized/int8_leaky_relu_op.h index 400b4d8472648..6271695874ec1 100644 --- a/caffe2/operators/quantized/int8_leaky_relu_op.h +++ b/caffe2/operators/quantized/int8_leaky_relu_op.h @@ -38,8 +38,8 @@ class Int8LeakyReluOp final : public Operator { const int32_t Y_zero_point = this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_GE(Y_zero_point, std::numeric_limits::min()); - CHECK_LE(Y_zero_point, std::numeric_limits::max()); + TORCH_CHECK_GE(Y_zero_point, std::numeric_limits::min()); + TORCH_CHECK_LE(Y_zero_point, std::numeric_limits::max()); /* * Record quantization parameters for the input, because if the op is diff --git a/caffe2/operators/quantized/int8_max_pool_op.h b/caffe2/operators/quantized/int8_max_pool_op.h index df85ef10211e9..c677b93a7c92e 100644 --- a/caffe2/operators/quantized/int8_max_pool_op.h +++ b/caffe2/operators/quantized/int8_max_pool_op.h @@ -38,10 +38,10 @@ class Int8MaxPoolOp final : public ConvPoolOpBase { const int32_t Y_zero_point = this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_zero_point, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_zero_point, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); - CHECK_EQ(X.t.dim(), 4); + TORCH_CHECK_EQ(X.t.dim(), 4); const int channels = X.t.dim32(3); ConvPoolOpBase::SetOutputSize(X.t, &(Y->t), channels); diff --git a/caffe2/operators/quantized/int8_relu_op.h b/caffe2/operators/quantized/int8_relu_op.h index e6c8a018afd70..8a22c9acdbf44 100644 --- a/caffe2/operators/quantized/int8_relu_op.h +++ b/caffe2/operators/quantized/int8_relu_op.h @@ -34,14 +34,14 @@ class Int8ReluOp final : public Operator { Y->t.ResizeLike(X.t); Y->scale = X.scale; Y->zero_point = X.zero_point; - CHECK_GE(X.zero_point, std::numeric_limits::min()); - CHECK_LE(X.zero_point, std::numeric_limits::max()); + TORCH_CHECK_GE(X.zero_point, std::numeric_limits::min()); + TORCH_CHECK_LE(X.zero_point, std::numeric_limits::max()); const int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1.0f); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); initQNNPACK(); diff --git a/caffe2/operators/quantized/int8_reshape_op.h b/caffe2/operators/quantized/int8_reshape_op.h index e3226a9fb5845..abc8a4a9609f5 100644 --- a/caffe2/operators/quantized/int8_reshape_op.h +++ b/caffe2/operators/quantized/int8_reshape_op.h @@ -32,8 +32,8 @@ class Int8ReshapeOp final : public ReshapeOp { auto* Y = Outputs()[0]->GetMutable(); int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); auto Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); Y->scale = Y_scale; Y->zero_point = Y_offset; DoRunWithTypeImpl(X.t, &Y->t); diff --git a/caffe2/operators/quantized/int8_resize_nearest_op.h b/caffe2/operators/quantized/int8_resize_nearest_op.h index 06d625bf06ba4..8e748b26719ee 100644 --- a/caffe2/operators/quantized/int8_resize_nearest_op.h +++ b/caffe2/operators/quantized/int8_resize_nearest_op.h @@ -49,8 +49,8 @@ class Int8ResizeNearestOp final : public Operator { int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); auto Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); const uint8_t* Xdata = X.t.data(); uint8_t* Ydata = Y->t.mutable_data(); diff --git a/caffe2/operators/quantized/int8_roi_align_op.h b/caffe2/operators/quantized/int8_roi_align_op.h index 360f4a62c0895..e0fa62f17feb0 100644 --- a/caffe2/operators/quantized/int8_roi_align_op.h +++ b/caffe2/operators/quantized/int8_roi_align_op.h @@ -281,10 +281,10 @@ class Int8RoIAlignOp final : public Operator { sampling_ratio_( this->template GetSingleArgument("sampling_ratio", -1)), aligned_(this->template GetSingleArgument("aligned", false)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); - DCHECK_GE(sampling_ratio_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GE(sampling_ratio_, 0); // only supports NHWC CAFFE_ENFORCE(order_ == StorageOrder::NHWC); } diff --git a/caffe2/operators/quantized/int8_sigmoid_op.h b/caffe2/operators/quantized/int8_sigmoid_op.h index d67dad40a3189..aa3f1f64135ef 100644 --- a/caffe2/operators/quantized/int8_sigmoid_op.h +++ b/caffe2/operators/quantized/int8_sigmoid_op.h @@ -33,8 +33,8 @@ class Int8SigmoidOp final : public Operator { const int32_t Y_zero_point = this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_zero_point, 0); - CHECK_EQ(Y_scale, 1.0f / 256.0f); + TORCH_CHECK_EQ(Y_zero_point, 0); + TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f); /* * Record quantization parameters for the input, because if the op is diff --git a/caffe2/operators/quantized/int8_slice_op.h b/caffe2/operators/quantized/int8_slice_op.h index d573cbea4349a..02fa8525d19e2 100644 --- a/caffe2/operators/quantized/int8_slice_op.h +++ b/caffe2/operators/quantized/int8_slice_op.h @@ -76,8 +76,8 @@ class Int8SliceOp final : public SliceOp { auto* Y = Outputs()[0]->GetMutable(); int32_t Y_offset = this->template GetSingleArgument("Y_zero_point", 0); auto Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_offset, X.zero_point); - CHECK_EQ(Y_scale, X.scale); + TORCH_CHECK_EQ(Y_offset, X.zero_point); + TORCH_CHECK_EQ(Y_scale, X.scale); Y->scale = Y_scale; Y->zero_point = Y_offset; diff --git a/caffe2/operators/quantized/int8_softmax_op.h b/caffe2/operators/quantized/int8_softmax_op.h index 7fc296c3263ec..27a10478e29d7 100644 --- a/caffe2/operators/quantized/int8_softmax_op.h +++ b/caffe2/operators/quantized/int8_softmax_op.h @@ -34,8 +34,8 @@ class Int8SoftmaxOp final : public Operator { const int32_t Y_zero_point = this->template GetSingleArgument("Y_zero_point", 0); const float Y_scale = this->template GetSingleArgument("Y_scale", 1); - CHECK_EQ(Y_zero_point, 0); - CHECK_EQ(Y_scale, 1.0f / 256.0f); + TORCH_CHECK_EQ(Y_zero_point, 0); + TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f); /* * Record quantization parameters for the input, because if the op is diff --git a/caffe2/operators/quantized/int8_test.cc b/caffe2/operators/quantized/int8_test.cc index 9b14d3eaec1da..cf668a88874fa 100644 --- a/caffe2/operators/quantized/int8_test.cc +++ b/caffe2/operators/quantized/int8_test.cc @@ -341,7 +341,7 @@ TEST(Int8, SumRelu) { } void setq(int8::Int8TensorCPU* dst, const std::vector& vs) { - CHECK_EQ(vs.size(), static_cast(dst->t.numel())); + TORCH_CHECK_EQ(vs.size(), static_cast(dst->t.numel())); for (auto i = 0U; i < vs.size(); ++i) { uint8_t vq = std::max( std::numeric_limits::min(), @@ -354,7 +354,7 @@ void setq(int8::Int8TensorCPU* dst, const std::vector& vs) { } void biassetq(int8::Int8TensorCPU* dst, const std::vector& vs) { - CHECK_EQ(vs.size(), static_cast(dst->t.numel())); + TORCH_CHECK_EQ(vs.size(), static_cast(dst->t.numel())); for (auto i = 0U; i < vs.size(); ++i) { int32_t vq = std::max( std::numeric_limits::min(), diff --git a/caffe2/operators/quantized/int8_utils.h b/caffe2/operators/quantized/int8_utils.h index 473798c1823a1..56195054689e7 100644 --- a/caffe2/operators/quantized/int8_utils.h +++ b/caffe2/operators/quantized/int8_utils.h @@ -91,8 +91,8 @@ inline void QuantizeMultiplierSmallerThanOne( q_fixed /= 2; --*right_shift; } - CHECK_GE(*right_shift, 0); - CHECK_LE(q_fixed, std::numeric_limits::max()); + TORCH_CHECK_GE(*right_shift, 0); + TORCH_CHECK_LE(q_fixed, std::numeric_limits::max()); *quantized_multiplier = static_cast(q_fixed); } @@ -108,8 +108,8 @@ inline void QuantizeMultiplierGreaterThanOne( q_fixed /= 2; ++*left_shift; } - CHECK_GE(*left_shift, 0); - CHECK_LE(q_fixed, std::numeric_limits::max()); + TORCH_CHECK_GE(*left_shift, 0); + TORCH_CHECK_LE(q_fixed, std::numeric_limits::max()); *quantized_multiplier = static_cast(q_fixed); } diff --git a/caffe2/operators/reducer_functors.h b/caffe2/operators/reducer_functors.h index 0159e030d2637..9827eb3c9b56f 100644 --- a/caffe2/operators/reducer_functors.h +++ b/caffe2/operators/reducer_functors.h @@ -343,7 +343,7 @@ class BaseReducer { } void observeInput(int input, const Tensor& value, int skip_dims) { - DCHECK_EQ(0, input); + TORCH_DCHECK_EQ(0, input); auto dims = value.sizes(); computeMeta(dims, skip_dims); } diff --git a/caffe2/operators/reduction_ops.cc b/caffe2/operators/reduction_ops.cc index e7099a2636b9c..3bc659d003353 100644 --- a/caffe2/operators/reduction_ops.cc +++ b/caffe2/operators/reduction_ops.cc @@ -305,7 +305,7 @@ bool SumElementsGradientOp::RunOnDevice() Tensor sum_grad(Input(1), CPU); auto* dX = Output(0, X.sizes(), at::dtype()); - DCHECK_EQ(sum_grad.numel(), 1); + TORCH_DCHECK_EQ(sum_grad.numel(), 1); math::Set( dX->numel(), static_cast( diff --git a/caffe2/operators/reduction_ops.cu b/caffe2/operators/reduction_ops.cu index 9649b85d015c5..1a81a6c11d909 100644 --- a/caffe2/operators/reduction_ops.cu +++ b/caffe2/operators/reduction_ops.cu @@ -83,7 +83,7 @@ template <> bool SumElementsGradientOp::RunOnDevice() { auto& X = Input(0); auto& dY = Input(1); - DCHECK_EQ(dY.numel(), 1); + TORCH_DCHECK_EQ(dY.numel(), 1); auto* dX = Output(0, X.sizes(), at::dtype()); SumElementsGradientKernel diff --git a/caffe2/operators/resize_3d_op.cc b/caffe2/operators/resize_3d_op.cc index 4f7d999d35635..5b9500e13b40a 100644 --- a/caffe2/operators/resize_3d_op.cc +++ b/caffe2/operators/resize_3d_op.cc @@ -2,7 +2,7 @@ #include "caffe2/utils/math.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include "caffe2/ideep/operators/operator_fallback_ideep.h" #include "caffe2/ideep/utils/ideep_operator.h" #endif @@ -165,7 +165,7 @@ REGISTER_CPU_GRADIENT_OPERATOR( ResizeNearest3DGradient, ResizeNearest3DGradientOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( ResizeNearest3D, IDEEPFallbackOp>); diff --git a/caffe2/operators/resize_op.cc b/caffe2/operators/resize_op.cc index 4d574945cb8a4..840388d85fb7d 100644 --- a/caffe2/operators/resize_op.cc +++ b/caffe2/operators/resize_op.cc @@ -3,7 +3,7 @@ #include "caffe2/utils/cpu_neon.h" #include "caffe2/utils/math.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include "caffe2/ideep/operators/operator_fallback_ideep.h" #include "caffe2/ideep/utils/ideep_operator.h" #endif @@ -297,7 +297,7 @@ REGISTER_CPU_GRADIENT_OPERATOR( ResizeNearestGradient, ResizeNearestGradientOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( ResizeNearest, IDEEPFallbackOp>); diff --git a/caffe2/operators/roi_align_gradient_op.h b/caffe2/operators/roi_align_gradient_op.h index 05d56ac96d934..d6e6b1fe277b0 100644 --- a/caffe2/operators/roi_align_gradient_op.h +++ b/caffe2/operators/roi_align_gradient_op.h @@ -25,10 +25,10 @@ class RoIAlignGradientOp final : public Operator { sampling_ratio_( this->template GetSingleArgument("sampling_ratio", -1)), aligned_(this->template GetSingleArgument("aligned", false)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); - DCHECK_GE(sampling_ratio_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GE(sampling_ratio_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/caffe2/operators/roi_align_op.h b/caffe2/operators/roi_align_op.h index 78ac4caa6b6f9..06505725c97a4 100644 --- a/caffe2/operators/roi_align_op.h +++ b/caffe2/operators/roi_align_op.h @@ -25,9 +25,9 @@ class RoIAlignOp final : public Operator { OP_SINGLE_ARG(int, "pooled_w", pooled_w_, 1), OP_SINGLE_ARG(int, "sampling_ratio", sampling_ratio_, -1), OP_SINGLE_ARG(bool, "aligned", aligned_, false) { - DCHECK_GT(spatial_scale_, 0.0f); - DCHECK_GT(pooled_h_, 0); - DCHECK_GT(pooled_w_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0.0f); + TORCH_DCHECK_GT(pooled_h_, 0); + TORCH_DCHECK_GT(pooled_w_, 0); DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); } diff --git a/caffe2/operators/roi_align_rotated_gradient_op.h b/caffe2/operators/roi_align_rotated_gradient_op.h index 92b7774fd4a55..d5b6a1b183bc6 100644 --- a/caffe2/operators/roi_align_rotated_gradient_op.h +++ b/caffe2/operators/roi_align_rotated_gradient_op.h @@ -22,10 +22,10 @@ class RoIAlignRotatedGradientOp final : public Operator { sampling_ratio_( this->template GetSingleArgument("sampling_ratio", -1)), aligned_(this->template GetSingleArgument("aligned", false)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); - DCHECK_GE(sampling_ratio_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GE(sampling_ratio_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/caffe2/operators/roi_align_rotated_op.h b/caffe2/operators/roi_align_rotated_op.h index 7c813e701b856..f63cf03ab92bc 100644 --- a/caffe2/operators/roi_align_rotated_op.h +++ b/caffe2/operators/roi_align_rotated_op.h @@ -27,10 +27,10 @@ class RoIAlignRotatedOp final : public Operator { sampling_ratio_( this->template GetSingleArgument("sampling_ratio", -1)), aligned_(this->template GetSingleArgument("aligned", false)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); - DCHECK_GE(sampling_ratio_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GE(sampling_ratio_, 0); DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); } USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index 196c3ad259fa7..774160955a3e4 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -137,10 +137,10 @@ bool SliceImpl( src_offset_bytes + i * src_block_size_bytes; char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes; - DCHECK_LE( + TORCH_DCHECK_LE( static_cast(local_src_offset_bytes + dst_block_size_bytes), static_cast(src_bytes + src_nbytes)); - DCHECK_LE( + TORCH_DCHECK_LE( static_cast(local_dst_offset_bytes + dst_block_size_bytes), static_cast(dst_bytes + dst_nbytes)); context->CopyItemsSameDevice( @@ -183,10 +183,10 @@ bool SliceImpl( src_offset_bytes + i * src_block_size_bytes; char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes; - DCHECK_LE( + TORCH_DCHECK_LE( local_src_offset_bytes + src_block_size_bytes, src_bytes + src_nbytes); - DCHECK_LE( + TORCH_DCHECK_LE( local_dst_offset_bytes + src_block_size_bytes, dst_bytes + dst_nbytes); context->CopyItemsSameDevice( diff --git a/caffe2/operators/softmax_op_cudnn.cc b/caffe2/operators/softmax_op_cudnn.cc index 2ca57f000dd84..e51e1c623ebe1 100644 --- a/caffe2/operators/softmax_op_cudnn.cc +++ b/caffe2/operators/softmax_op_cudnn.cc @@ -98,7 +98,7 @@ class CuDNNSoftmaxGradientOp final : public Operator { const int N = Y.size_to_dim(canonical_axis); const int D = Y.size_from_dim(canonical_axis); - CHECK_EQ(Y.sizes(), dY.sizes()); + TORCH_CHECK_EQ(Y.sizes(), dY.sizes()); auto* dX = Output(0, Y.sizes(), at::dtype()); auto* dX_data = dX->template mutable_data(); if (N == 0 || D == 0) { diff --git a/caffe2/operators/softplus_op.cc b/caffe2/operators/softplus_op.cc index 2f4cb05abbcd2..3bc2f5f0ede3d 100644 --- a/caffe2/operators/softplus_op.cc +++ b/caffe2/operators/softplus_op.cc @@ -23,7 +23,7 @@ bool SoftplusGradientOp::RunOnDevice() { auto& Y = Input(0); auto& dY = Input(1); - DCHECK_EQ(dY.numel(), Y.numel()); + TORCH_DCHECK_EQ(dY.numel(), Y.numel()); auto* dX = Output(0, Y.sizes(), at::dtype()); const float* Ydata = Y.data(); diff --git a/caffe2/operators/softplus_op.cu b/caffe2/operators/softplus_op.cu index 03d7530e33f32..5eb30cabbdf99 100644 --- a/caffe2/operators/softplus_op.cu +++ b/caffe2/operators/softplus_op.cu @@ -24,7 +24,7 @@ template <> bool SoftplusOp::RunOnDevice() { auto& X = Input(0); - DCHECK_GT(X.numel(), 0); + TORCH_DCHECK_GT(X.numel(), 0); auto* Y = Output(0, X.sizes(), at::dtype()); SoftplusKernel <<::RunOnDevice() { auto& Y = Input(0); auto& dY = Input(1); - DCHECK_GT(Y.numel(), 0); - DCHECK_EQ(dY.numel(), Y.numel()); + TORCH_DCHECK_GT(Y.numel(), 0); + TORCH_DCHECK_EQ(dY.numel(), Y.numel()); auto* dX = Output(0, Y.sizes(), at::dtype()); SoftplusGradientKernel << #include #endif @@ -584,7 +584,7 @@ OPERATOR_SCHEMA(BRGNCHWCToPackedInt8BGRAStylizerDeprocess) .NumInputs(2) .NumOutputs(1); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( BRGNCHWCToPackedInt8BGRAStylizerDeprocess, IDEEPFallbackOp>); diff --git a/caffe2/operators/summarize_op.cu b/caffe2/operators/summarize_op.cu index ca57066552dff..eac14e2cd1aeb 100644 --- a/caffe2/operators/summarize_op.cu +++ b/caffe2/operators/summarize_op.cu @@ -76,7 +76,7 @@ template<> bool SummarizeOp::RunOnDevice() { auto& X = Input(0); const int N = X.numel(); - DCHECK_GT(N, 0); + TORCH_DCHECK_GT(N, 0); // TODO(Yangqing): Any better way to avoid having to const cast? thrust::device_ptr Xdata(const_cast(X.data())); diff --git a/caffe2/operators/tensor_protos_db_input.h b/caffe2/operators/tensor_protos_db_input.h index 1ea831f4e2eb0..5fddcd13bd7f0 100644 --- a/caffe2/operators/tensor_protos_db_input.h +++ b/caffe2/operators/tensor_protos_db_input.h @@ -80,7 +80,7 @@ bool TensorProtosDBInput::Prefetch() { Tensor src = deserializer.Deserialize(protos.protos(i)); Tensor* dst = BlobGetMutableTensor( &prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU)); - DCHECK_EQ(src.numel() * batch_size_, dst->numel()); + TORCH_DCHECK_EQ(src.numel() * batch_size_, dst->numel()); this->context_.CopyItemsSameDevice( src.dtype(), src.numel(), diff --git a/caffe2/opt/custom/in_batch_broadcast_test.cc b/caffe2/opt/custom/in_batch_broadcast_test.cc index 7d05423f5666b..65278cd18d0c8 100644 --- a/caffe2/opt/custom/in_batch_broadcast_test.cc +++ b/caffe2/opt/custom/in_batch_broadcast_test.cc @@ -8,20 +8,20 @@ using namespace caffe2; namespace { void checkNet(NetDef& net, NetDef& expected_net) { - CHECK_EQ(net.op().size(), expected_net.op().size()); + TORCH_CHECK_EQ(net.op().size(), expected_net.op().size()); for (int i = 0; i < net.op().size(); i++) { auto& op1 = net.op(i); auto& op2 = expected_net.op(i); - CHECK_EQ(op1.type(), op2.type()); - CHECK_EQ(op1.input_size(), op2.input_size()); - CHECK_EQ(op1.output_size(), op2.output_size()); + TORCH_CHECK_EQ(op1.type(), op2.type()); + TORCH_CHECK_EQ(op1.input_size(), op2.input_size()); + TORCH_CHECK_EQ(op1.output_size(), op2.output_size()); for (int j = 0; j < op1.input_size(); j++) { - CHECK_EQ(op1.input(j), op2.input(j)); + TORCH_CHECK_EQ(op1.input(j), op2.input(j)); } for (int j = 0; j < op1.output_size(); j++) { - CHECK_EQ(op1.output(j), op2.output(j)); + TORCH_CHECK_EQ(op1.output(j), op2.output(j)); } - CHECK_EQ( + TORCH_CHECK_EQ( op1.device_option().device_type(), op2.device_option().device_type()); ArgumentHelper helper1(op1); ArgumentHelper helper2(op2); @@ -34,13 +34,13 @@ void checkNet(NetDef& net, NetDef& expected_net) { << "Argument " << name << " doesn't exist"; const auto arg1 = helper1.GetSingleArgument(name, 0); const auto arg2 = helper2.GetSingleArgument(name, 0); - CHECK_EQ(arg1, arg2); + TORCH_CHECK_EQ(arg1, arg2); } } } void checkShapeInfo(ShapeInfoMap& shape_map, ShapeInfoMap& expected_shape_map) { - CHECK_EQ(shape_map.size(), expected_shape_map.size()); + TORCH_CHECK_EQ(shape_map.size(), expected_shape_map.size()); for (auto& [name, shape] : shape_map) { auto it = expected_shape_map.find(name); CHECK(it != expected_shape_map.end()) << "Didn't find name " << name; diff --git a/caffe2/opt/fusion.cc b/caffe2/opt/fusion.cc index beca0cf862171..a1a014855c981 100644 --- a/caffe2/opt/fusion.cc +++ b/caffe2/opt/fusion.cc @@ -80,7 +80,7 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) { auto* blob = ws->CreateBlob(convBiasName); caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); + TORCH_CHECK_NOTNULL(tensor); // Get output channel size_t c = filterTensor->dim32(0); tensor->Resize(c); diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index 687126ee327d5..4aae4d7b45a2b 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -1,7 +1,7 @@ #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/converter.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include "caffe2/ideep/ideep_utils.h" #endif @@ -11,7 +11,7 @@ namespace opt { using namespace nom; -#ifndef CAFFE2_USE_MKLDNN +#ifndef USE_MKLDNN void OptimizeForMkldnn( repr::NNModule* nn, caffe2::Workspace* ws, @@ -1012,7 +1012,7 @@ void OptimizeForMkldnn( setPoolingInferenceMode(nn); } -#endif // CAFFE2_USE_MKLDNN +#endif // USE_MKLDNN } // namespace opt } // namespace caffe2 diff --git a/caffe2/python/CMakeLists.txt b/caffe2/python/CMakeLists.txt index 373a4fff86b64..c092febee4a90 100644 --- a/caffe2/python/CMakeLists.txt +++ b/caffe2/python/CMakeLists.txt @@ -7,7 +7,7 @@ set(Caffe2_CPU_PYTHON_SRCS "/pybind_state_int8.cc" ) -if(CAFFE2_USE_MKLDNN) +if(USE_MKLDNN) set(Caffe2_CPU_PYTHON_SRCS ${Caffe2_CPU_PYTHON_SRCS} "/pybind_state_ideep.cc" diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index 42262d269695d..461b454b6a91f 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -7,13 +7,13 @@ import os - import unittest + import onnx.backend.test import caffe2.python.onnx.backend as c2 - from caffe2.python import core + core.SetEnginePref({}, {}) # This is a pytest magic variable to load extra plugins @@ -175,6 +175,17 @@ '|test_spacetodepth_.*' ')') +# Unsupported ops in opset 17 +backend_test.exclude('(test_layer_normalization_.*' + '|test_blackmanwindow_.*' + '|test_dft_.*' + '|test_hammingwindow_.*' + '|test_hannwindow_.*' + '|test_melweightmatrix_.*' + '|test_stft_.*' + '|test_sequencemap_.*' + ')') + # Skip vgg to speed up CI if 'JENKINS_URL' in os.environ: backend_test.exclude(r'(test_vgg19|test_vgg)') diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index ccaa0afb6ac91..a637f15e7a9d3 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -568,7 +568,7 @@ void addObjectMethods(py::module& m) { .def_property_readonly( "nets", [](Workspace* self) { - CHECK_NOTNULL(self); + TORCH_CHECK_NOTNULL(self); std::map nets; for (const auto& name : self->Nets()) { LOG(INFO) << "name: " << name; @@ -580,7 +580,7 @@ void addObjectMethods(py::module& m) { .def_property_readonly( "blobs", [](Workspace* self) { - CHECK_NOTNULL(self); + TORCH_CHECK_NOTNULL(self); std::map blobs; for (const auto& name : self->Blobs()) { blobs[name] = py::cast(self->GetBlob(name)); @@ -1057,11 +1057,11 @@ void addGlobalMethods(py::module& m) { m.attr("has_mkldnn") = py::bool_(false); m.attr("use_mkldnn") = py::bool_( -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN true -#else // CAFFE2_USE_MKLDNN +#else // USE_MKLDNN false -#endif // CAFFE2_USE_MKLDNN +#endif // USE_MKLDNN ); // if the binary is built with USE_ROCM, this is a ROCm build diff --git a/caffe2/python/recurrent.py b/caffe2/python/recurrent.py index d4762f08c683e..8bb0d9cfd6d65 100644 --- a/caffe2/python/recurrent.py +++ b/caffe2/python/recurrent.py @@ -282,7 +282,7 @@ def map_to_dual_list(m): cell_net.Proto().type = 'simple' # The last output is a list of step workspaces, - # which is only needed internally for gradient propogation + # which is only needed internally for gradient propagation return results[:-1] diff --git a/caffe2/quantization/server/conv_relu_op.h b/caffe2/quantization/server/conv_relu_op.h index ee7a1cc7c375c..9da14bb2126c4 100644 --- a/caffe2/quantization/server/conv_relu_op.h +++ b/caffe2/quantization/server/conv_relu_op.h @@ -12,12 +12,12 @@ class ConvReluOp final : public ConvPoolOpBase { : ConvPoolOpBase(operator_def, ws) { for (auto name : operator_def.input()) { local_input_blobs_.push_back(local_ws_.CreateBlob(name)); - CHECK_NOTNULL(local_input_blobs_.back()); + TORCH_CHECK_NOTNULL(local_input_blobs_.back()); } local_op_.reset(new ConvOp(operator_def, &local_ws_)); for (auto name : operator_def.output()) { local_output_blobs_.push_back(local_ws_.GetBlob(name)); - CHECK_NOTNULL(local_output_blobs_.back()); + TORCH_CHECK_NOTNULL(local_output_blobs_.back()); } } ~ConvReluOp() {} diff --git a/caffe2/quantization/server/fb_fc_packed_op.h b/caffe2/quantization/server/fb_fc_packed_op.h index 17b4b944f4b89..e74d01b836f38 100644 --- a/caffe2/quantization/server/fb_fc_packed_op.h +++ b/caffe2/quantization/server/fb_fc_packed_op.h @@ -129,7 +129,7 @@ class FbFCPackedOperator final : public Operator { CAFFE_ENFORCE(N == W->numCols(), dimErrorString()); Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. - DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); + TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); Y_shape_cache_.resize(canonical_axis + 1); Y_shape_cache_[canonical_axis] = N; auto* Y = Output(0, Y_shape_cache_, at::dtype()); diff --git a/caffe2/quantization/server/fully_connected_fake_lowp_op.cc b/caffe2/quantization/server/fully_connected_fake_lowp_op.cc index b8e403985e42d..18686d3fc3d47 100644 --- a/caffe2/quantization/server/fully_connected_fake_lowp_op.cc +++ b/caffe2/quantization/server/fully_connected_fake_lowp_op.cc @@ -76,7 +76,7 @@ bool FullyConnectedFakeLowpFPOp:: Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. - DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); + TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size()); Y_shape_cache_.resize(canonical_axis + 1); Y_shape_cache_[canonical_axis] = N; auto* Y = Output(0, Y_shape_cache_, at::dtype()); diff --git a/caffe2/quantization/server/op_wrapper.h b/caffe2/quantization/server/op_wrapper.h index 20ea7b94a79a2..759d6b715ee4c 100644 --- a/caffe2/quantization/server/op_wrapper.h +++ b/caffe2/quantization/server/op_wrapper.h @@ -18,13 +18,13 @@ class OpWrapper { : op_(op), qfactory_(qfactory) { for (auto name : op->debug_def().input()) { local_input_blobs_.push_back(local_ws_.CreateBlob(name)); - CHECK_NOTNULL(local_input_blobs_.back()); + TORCH_CHECK_NOTNULL(local_input_blobs_.back()); } OperatorDef def = op->debug_def(); local_op_.reset(new OpType(def, &local_ws_)); for (auto name : def.output()) { local_output_blobs_.push_back(local_ws_.GetBlob(name)); - CHECK_NOTNULL(local_output_blobs_.back()); + TORCH_CHECK_NOTNULL(local_output_blobs_.back()); } } diff --git a/caffe2/queue/blobs_queue.cc b/caffe2/queue/blobs_queue.cc index 4c890088fa2d5..87be7e073a441 100644 --- a/caffe2/queue/blobs_queue.cc +++ b/caffe2/queue/blobs_queue.cc @@ -53,7 +53,7 @@ BlobsQueue::BlobsQueue( } queue_.push_back(blobs); } - DCHECK_EQ(queue_.size(), capacity); + TORCH_DCHECK_EQ(queue_.size(), capacity); } bool BlobsQueue::blockingRead( diff --git a/caffe2/queue/blobs_queue_db.cc b/caffe2/queue/blobs_queue_db.cc index 3f10b1b7035f5..3214f90417ff0 100644 --- a/caffe2/queue/blobs_queue_db.cc +++ b/caffe2/queue/blobs_queue_db.cc @@ -10,7 +10,7 @@ #include "caffe2/core/operator.h" #include "caffe2/queue/blobs_queue.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -42,7 +42,7 @@ class CreateBlobsQueueDBOp : public Operator { REGISTER_CPU_OPERATOR(CreateBlobsQueueDB, CreateBlobsQueueDBOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( CreateBlobsQueueDB, IDEEPFallbackOp, SkipIndices<0>>); diff --git a/caffe2/serialize/CMakeLists.txt b/caffe2/serialize/CMakeLists.txt index 428052aaa20d0..1552b59d0d441 100644 --- a/caffe2/serialize/CMakeLists.txt +++ b/caffe2/serialize/CMakeLists.txt @@ -2,13 +2,13 @@ file(GLOB tmp *_test.cc) set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp}) list(APPEND Caffe2_CPU_SRCS - ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c + ${PROJECT_SOURCE_DIR}/third_party/miniz-2.1.0/miniz.c ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/crc.cc ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc) -list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8) +list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.1.0) set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) diff --git a/caffe2/sgd/ftrl_op.cc b/caffe2/sgd/ftrl_op.cc index 9d012aa8c8490..5c1263e6a2ece 100644 --- a/caffe2/sgd/ftrl_op.cc +++ b/caffe2/sgd/ftrl_op.cc @@ -92,8 +92,8 @@ void SparseFtrlOp::DoRun() { int64_t N = var->size(0); int64_t block_size = M / N; int64_t K = indices.numel(); - DCHECK_EQ(M * 2, n_z->numel()); - DCHECK_EQ(grad.numel(), K * block_size); + TORCH_DCHECK_EQ(M * 2, n_z->numel()); + TORCH_DCHECK_EQ(grad.numel(), K * block_size); T* w = var->template mutable_data(); T* nz = n_z->template mutable_data(); const SIndex* idxs = indices.template data(); diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc index b1318dd1eaff6..e285be53330b9 100644 --- a/caffe2/sgd/iter_op.cc +++ b/caffe2/sgd/iter_op.cc @@ -1,6 +1,6 @@ #include "caffe2/sgd/iter_op.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include #include #endif @@ -28,7 +28,7 @@ void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) { REGISTER_CPU_OPERATOR(Iter, IterOp); REGISTER_CPU_OPERATOR(AtomicIter, AtomicIterOp); -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR(AtomicIter, IDEEPFallbackOp>); #endif diff --git a/caffe2/sgd/learning_rate_functors.h b/caffe2/sgd/learning_rate_functors.h index d733ccc146113..9204e22adb196 100644 --- a/caffe2/sgd/learning_rate_functors.h +++ b/caffe2/sgd/learning_rate_functors.h @@ -278,10 +278,10 @@ class CompositeLearningRate : public LearningRateFunctor { public: CompositeLearningRate( const std::list>& sub_policies) { - DCHECK_GT(sub_policies.size(), 0); + TORCH_DCHECK_GT(sub_policies.size(), 0); int64_t num_iter_start = 1; for (auto it = sub_policies.begin(); it != sub_policies.end(); ++it) { - DCHECK_GT(it->num_iter_, 0); + TORCH_DCHECK_GT(it->num_iter_, 0); sub_policies_[num_iter_start].reset(it->policy_); sub_policy_lr_scales_[num_iter_start] = it->lr_scale_; num_iter_start += it->num_iter_; diff --git a/caffe2/sgd/learning_rate_op.h b/caffe2/sgd/learning_rate_op.h index 74387f47db73a..5b9e2385a0901 100644 --- a/caffe2/sgd/learning_rate_op.h +++ b/caffe2/sgd/learning_rate_op.h @@ -57,42 +57,42 @@ class LearningRateOp final : public Operator { arg_prefix + "active_period", -1); int64_t inactive_period = this->template GetSingleArgument( arg_prefix + "inactive_period", -1); - DCHECK_GE(active_period, 0); - DCHECK_GE(inactive_period, 0); + TORCH_DCHECK_GE(active_period, 0); + TORCH_DCHECK_GE(inactive_period, 0); return new AlternateLearningRate( active_period, inactive_period, active_first); } else if (policy == "hill") { int64_t num_iter = this->template GetSingleArgument(arg_prefix + "num_iter", 0); - DCHECK_GT(num_iter, 0); + TORCH_DCHECK_GT(num_iter, 0); T start_multiplier = this->template GetSingleArgument( arg_prefix + "start_multiplier", 0.); - DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1] - DCHECK_LE(start_multiplier, 1); + TORCH_DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1] + TORCH_DCHECK_LE(start_multiplier, 1); T gamma = this->template GetSingleArgument(arg_prefix + "gamma", 0); - DCHECK_GT(gamma, 0); + TORCH_DCHECK_GT(gamma, 0); T power = this->template GetSingleArgument(arg_prefix + "power", 0); - DCHECK_GT(power, 0); + TORCH_DCHECK_GT(power, 0); T end_multiplier = this->template GetSingleArgument( arg_prefix + "end_multiplier", 0); - DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1] - DCHECK_LE(end_multiplier, 1); + TORCH_DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1] + TORCH_DCHECK_LE(end_multiplier, 1); return new HillLearningRate( num_iter, start_multiplier, gamma, power, end_multiplier); } else if (policy == "slope") { int64_t num_iter_1 = this->template GetSingleArgument( arg_prefix + "num_iter_1", 0); - DCHECK_GT(num_iter_1, 0); + TORCH_DCHECK_GT(num_iter_1, 0); T multiplier_1 = this->template GetSingleArgument( arg_prefix + "multiplier_1", 0.); int64_t num_iter_2 = this->template GetSingleArgument( arg_prefix + "num_iter_2", 0); - DCHECK_GT(num_iter_1, 0); + TORCH_DCHECK_GT(num_iter_1, 0); T multiplier_2 = this->template GetSingleArgument( arg_prefix + "multiplier_2", 0.); - DCHECK_GT(num_iter_2, num_iter_1); + TORCH_DCHECK_GT(num_iter_2, num_iter_1); return new SlopeLearningRate( num_iter_1, multiplier_1, num_iter_2, multiplier_2); } else if (policy == "step") { @@ -100,13 +100,13 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "stepsize", 0); T gamma = this->template GetSingleArgument(arg_prefix + "gamma", 0); - DCHECK_GT(stepsize, 0); - DCHECK_GT(gamma, 0); + TORCH_DCHECK_GT(stepsize, 0); + TORCH_DCHECK_GT(gamma, 0); return new StepLearningRate(stepsize, gamma); } else if (policy == "exp") { T gamma = this->template GetSingleArgument(arg_prefix + "gamma", 0); - DCHECK_GT(gamma, 0); + TORCH_DCHECK_GT(gamma, 0); return new ExpLearningRate(gamma); } else if (policy == "gate") { T multiplier_1 = this->template GetSingleArgument( @@ -122,29 +122,29 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "gamma", 0); T power = this->template GetSingleArgument(arg_prefix + "power", 0); - DCHECK_GT(gamma, 0); - DCHECK_GT(power, 0); + TORCH_DCHECK_GT(gamma, 0); + TORCH_DCHECK_GT(power, 0); return new InvLearningRate(gamma, power); } else if (policy == "poly") { int max_iter = this->template GetSingleArgument(arg_prefix + "max_iter", -1); T power = this->template GetSingleArgument(arg_prefix + "power", 0); - DCHECK_GT(power, 0); + TORCH_DCHECK_GT(power, 0); return new PolyLearningRate(power, max_iter); } else if (policy == "linearWarmup") { T start_multiplier = this->template GetSingleArgument( arg_prefix + "start_multiplier", 0.); int num_iter = this->template GetSingleArgument(arg_prefix + "num_iter", 0); - DCHECK_GE(start_multiplier, 0); + TORCH_DCHECK_GE(start_multiplier, 0); return new LinearWarmupLearningRate(start_multiplier, num_iter); } else if (policy == "constantWarmup") { T multiplier = this->template GetSingleArgument( arg_prefix + "multiplier", 0.5); int num_iter = this->template GetSingleArgument(arg_prefix + "num_iter", 0); - DCHECK_GT(multiplier, 0); + TORCH_DCHECK_GT(multiplier, 0); return new ConstantWarmupLearningRate(multiplier, num_iter); } else if (policy == "pieceWarmup") { T m1 = this->template GetSingleArgument(arg_prefix + "m1", 0.5); @@ -193,8 +193,8 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "stepsize", 0); T decay = this->template GetSingleArgument(arg_prefix + "decay", 1.0); - DCHECK_GT(stepsize, 0); - DCHECK_GE(max_lr, base_lr_); + TORCH_DCHECK_GT(stepsize, 0); + TORCH_DCHECK_GE(max_lr, base_lr_); return new CyclicalLearningRate(base_lr_, max_lr, stepsize, decay); } else if (policy == "constantThenLinearWarmup") { T start_warmup_multiplier = this->template GetSingleArgument( @@ -220,7 +220,7 @@ class LearningRateOp final : public Operator { arg_prefix + "cyclical_step_size", 1000000); T cyclical_decay = this->template GetSingleArgument( arg_prefix + "cyclical_decay", 1.0); - DCHECK_GE(cyclical_max_lr, base_lr_); + TORCH_DCHECK_GE(cyclical_max_lr, base_lr_); return new CompositeCyclicalLearningRate( base_lr_, start_warmup_multiplier, @@ -240,7 +240,7 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "t_mult", 1.0); T lr_shrink = this->template GetSingleArgument( arg_prefix + "lr_shrink", 0.99); - DCHECK_GE(max_lr, min_lr); + TORCH_DCHECK_GE(max_lr, min_lr); return new CosineLearningRate( min_lr, max_lr, period, t_mult, lr_shrink); } else if (policy == "compositeCosine") { @@ -261,7 +261,7 @@ class LearningRateOp final : public Operator { T cosine_lr_shrink = this->template GetSingleArgument( arg_prefix + "cosine_lr_shrink", 0.99); - DCHECK_GE(cosine_max_lr, cosine_min_lr); + TORCH_DCHECK_GE(cosine_max_lr, cosine_min_lr); return new CompositeCosineLearningRate( start_warmup_multiplier, constant_warmup_num_iter, diff --git a/caffe2/test/assets/squeeze_predict_net.pb b/caffe2/test/assets/squeeze_predict_net.pb index a06d95947321a..ac4c476b91cc6 100644 Binary files a/caffe2/test/assets/squeeze_predict_net.pb and b/caffe2/test/assets/squeeze_predict_net.pb differ diff --git a/caffe2/utils/eigen_utils.h b/caffe2/utils/eigen_utils.h index 76f170aa9dd1a..c6c34dba9b5ae 100644 --- a/caffe2/utils/eigen_utils.h +++ b/caffe2/utils/eigen_utils.h @@ -148,7 +148,7 @@ void GetSubArray( out_array->derived().resize(indices.size()); for (const auto i : c10::irange(indices.size())) { - DCHECK_LT(indices[i], array.size()); + TORCH_DCHECK_LT(indices[i], array.size()); (*out_array)[i] = array[indices[i]]; } } @@ -181,7 +181,7 @@ void GetSubArrayRows( out_array->derived().resize(row_indices.size(), array2d.cols()); for (const auto i : c10::irange(row_indices.size())) { - DCHECK_LT(row_indices[i], array2d.size()); + TORCH_DCHECK_LT(row_indices[i], array2d.size()); out_array->row(i) = array2d.row(row_indices[i]).template cast(); } diff --git a/caffe2/utils/hip/math_blas_gpu_test.cc b/caffe2/utils/hip/math_blas_gpu_test.cc index 3405d17a53106..07d4bf11f5a4b 100644 --- a/caffe2/utils/hip/math_blas_gpu_test.cc +++ b/caffe2/utils/hip/math_blas_gpu_test.cc @@ -63,7 +63,7 @@ TEST(MathROCBLASTest, GemmNoTransNoTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 10) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; } // Test Accumulate @@ -83,7 +83,7 @@ TEST(MathROCBLASTest, GemmNoTransNoTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 15) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; } // Test Accumulate @@ -103,7 +103,7 @@ TEST(MathROCBLASTest, GemmNoTransNoTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 20) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; } } @@ -160,7 +160,7 @@ TEST(MathROCBLASTest, GemmNoTransTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 10) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; } // Test Accumulate @@ -180,7 +180,7 @@ TEST(MathROCBLASTest, GemmNoTransTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 15) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; } math::Gemm( @@ -199,7 +199,7 @@ TEST(MathROCBLASTest, GemmNoTransTrans) { tensorY_host->CopyFrom(*tensorY); EXPECT_EQ(tensorY_host->size(), 30); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 20) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; } } @@ -252,7 +252,7 @@ TEST(MathROCBLASTest, GemvNoTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 10) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; } // Test Accumulate @@ -269,7 +269,7 @@ TEST(MathROCBLASTest, GemvNoTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 15) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; } // Test Accumulate @@ -286,7 +286,7 @@ TEST(MathROCBLASTest, GemvNoTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 20) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; } } @@ -339,7 +339,7 @@ TEST(MathROCBLASTest, GemvTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 6) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 6) << i; } // Test Accumulate @@ -356,7 +356,7 @@ TEST(MathROCBLASTest, GemvTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 9) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 9) << i; } // Test Accumulate @@ -373,7 +373,7 @@ TEST(MathROCBLASTest, GemvTrans) { context.FinishDeviceComputation(); tensorY_host->CopyFrom(*tensorY); for (int i = 0; i < tensorY_host->size(); ++i) { - CHECK_EQ(tensorY_host->data()[i], 12) << i; + TORCH_CHECK_EQ(tensorY_host->data()[i], 12) << i; } } } // namespace caffe2 diff --git a/caffe2/utils/knob_patcher.cc b/caffe2/utils/knob_patcher.cc index 98873f5f14a3d..31ff0f1388547 100644 --- a/caffe2/utils/knob_patcher.cc +++ b/caffe2/utils/knob_patcher.cc @@ -82,7 +82,7 @@ class Patcher { if (iter == patches_.end()) { LOG(FATAL) << "patch node not found when unpatching knob value"; } - CHECK_EQ(iter->second, node); + TORCH_CHECK_EQ(iter->second, node); if (node->prev) { iter->second = node->prev; } else { diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h index 5630729d1a5ba..f2ecc711995ad 100644 --- a/caffe2/utils/math-detail.h +++ b/caffe2/utils/math-detail.h @@ -30,7 +30,7 @@ struct ScaleImpl { const T* x, T* y, CPUContext* /*context*/) { - DCHECK_EQ(N, 1); + TORCH_DCHECK_EQ(N, 1); *y = *x * alpha; } }; @@ -56,7 +56,7 @@ struct AxpyImpl { const T* x, T* y, CPUContext* /*context*/) { - DCHECK_EQ(N, 1); + TORCH_DCHECK_EQ(N, 1); *y += *x * alpha; } }; diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index 5b80701c42e89..8facae39cc51c 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -1514,7 +1514,7 @@ C10_EXPORT void Select( float* y, CPUContext* /*context*/) { for (int i = 0; i < N; ++i) { - DCHECK_LT(idx[i], D); + TORCH_DCHECK_LT(idx[i], D); y[i] = x[i * D + idx[i]]; } } diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc index a0ed19d2bb4d0..d29860bd3ac58 100644 --- a/caffe2/utils/math_test.cc +++ b/caffe2/utils/math_test.cc @@ -29,10 +29,10 @@ TEST(MathTest, GemmNoTransNoTrans) { W.numel(), 1, W.mutable_data(), &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < X.numel(); ++i) { - CHECK_EQ(X.data()[i], 1); + TORCH_CHECK_EQ(X.data()[i], 1); } for (int i = 0; i < W.numel(); ++i) { - CHECK_EQ(W.data()[i], 1); + TORCH_CHECK_EQ(W.data()[i], 1); } const float kOne = 1.0; @@ -52,7 +52,7 @@ TEST(MathTest, GemmNoTransNoTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 10) << i; + TORCH_CHECK_EQ(Y.data()[i], 10) << i; } // Test Accumulate math::Gemm( @@ -69,7 +69,7 @@ TEST(MathTest, GemmNoTransNoTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 15) << i; + TORCH_CHECK_EQ(Y.data()[i], 15) << i; } // Test Accumulate math::Gemm( @@ -86,7 +86,7 @@ TEST(MathTest, GemmNoTransNoTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 20) << i; + TORCH_CHECK_EQ(Y.data()[i], 20) << i; } } @@ -104,10 +104,10 @@ TEST(MathTest, GemmNoTransTrans) { W.numel(), 1, W.mutable_data(), &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < X.numel(); ++i) { - CHECK_EQ(X.data()[i], 1); + TORCH_CHECK_EQ(X.data()[i], 1); } for (int i = 0; i < W.numel(); ++i) { - CHECK_EQ(W.data()[i], 1); + TORCH_CHECK_EQ(W.data()[i], 1); } const float kOne = 1.0; @@ -127,7 +127,7 @@ TEST(MathTest, GemmNoTransTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 10) << i; + TORCH_CHECK_EQ(Y.data()[i], 10) << i; } // Test Accumulate math::Gemm( @@ -144,7 +144,7 @@ TEST(MathTest, GemmNoTransTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 15) << i; + TORCH_CHECK_EQ(Y.data()[i], 15) << i; } math::Gemm( CblasNoTrans, @@ -160,7 +160,7 @@ TEST(MathTest, GemmNoTransTrans) { &cpu_context); EXPECT_EQ(Y.numel(), 30); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 20) << i; + TORCH_CHECK_EQ(Y.data()[i], 20) << i; } } @@ -302,10 +302,10 @@ TEST(MathTest, GemvNoTrans) { X.numel(), 1, X.mutable_data(), &cpu_context); EXPECT_EQ(Y.numel(), 5); for (int i = 0; i < A.numel(); ++i) { - CHECK_EQ(A.data()[i], 1); + TORCH_CHECK_EQ(A.data()[i], 1); } for (int i = 0; i < X.numel(); ++i) { - CHECK_EQ(X.data()[i], 1); + TORCH_CHECK_EQ(X.data()[i], 1); } const float kOne = 1.0; @@ -322,7 +322,7 @@ TEST(MathTest, GemvNoTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 10) << i; + TORCH_CHECK_EQ(Y.data()[i], 10) << i; } // Test Accumulate math::Gemv( @@ -336,7 +336,7 @@ TEST(MathTest, GemvNoTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 15) << i; + TORCH_CHECK_EQ(Y.data()[i], 15) << i; } // Test Accumulate math::Gemv( @@ -350,7 +350,7 @@ TEST(MathTest, GemvNoTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 20) << i; + TORCH_CHECK_EQ(Y.data()[i], 20) << i; } } @@ -368,10 +368,10 @@ TEST(MathTest, GemvTrans) { X.numel(), 1, X.mutable_data(), &cpu_context); EXPECT_EQ(Y.numel(), 10); for (int i = 0; i < A.numel(); ++i) { - CHECK_EQ(A.data()[i], 1); + TORCH_CHECK_EQ(A.data()[i], 1); } for (int i = 0; i < X.numel(); ++i) { - CHECK_EQ(X.data()[i], 1); + TORCH_CHECK_EQ(X.data()[i], 1); } const float kOne = 1.0; @@ -388,7 +388,7 @@ TEST(MathTest, GemvTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 6) << i; + TORCH_CHECK_EQ(Y.data()[i], 6) << i; } // Test Accumulate math::Gemv( @@ -402,7 +402,7 @@ TEST(MathTest, GemvTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 9) << i; + TORCH_CHECK_EQ(Y.data()[i], 9) << i; } // Test Accumulate math::Gemv( @@ -416,7 +416,7 @@ TEST(MathTest, GemvTrans) { Y.mutable_data(), &cpu_context); for (int i = 0; i < Y.numel(); ++i) { - CHECK_EQ(Y.data()[i], 12) << i; + TORCH_CHECK_EQ(Y.data()[i], 12) << i; } } @@ -429,9 +429,9 @@ TEST(MathTest, FloatToHalfConversion) { float converted_b = static_cast(at::Half(b)); float converted_c = static_cast(at::Half(c)); - CHECK_EQ(a, converted_a); - CHECK_EQ(b, converted_b); - CHECK_EQ(c, converted_c); + TORCH_CHECK_EQ(a, converted_a); + TORCH_CHECK_EQ(b, converted_b); + TORCH_CHECK_EQ(c, converted_c); } namespace { diff --git a/caffe2/utils/threadpool/WorkersPool.h b/caffe2/utils/threadpool/WorkersPool.h index 8126b82aa2089..d847ffca6817e 100644 --- a/caffe2/utils/threadpool/WorkersPool.h +++ b/caffe2/utils/threadpool/WorkersPool.h @@ -164,7 +164,7 @@ T WaitForVariableChange(std::atomic* var, new_value = var->load(std::memory_order_relaxed); return new_value != initial_value; }); - DCHECK_NE(static_cast(new_value), static_cast(initial_value)); + TORCH_DCHECK_NE(static_cast(new_value), static_cast(initial_value)); return new_value; } } @@ -178,7 +178,7 @@ class BlockingCounter { // decrementing events that the Wait() call will be waiting for. void Reset(std::size_t initial_count) { std::lock_guard g(mutex_); - DCHECK_EQ(count_, 0); + TORCH_DCHECK_EQ(count_, 0); count_ = initial_count; } @@ -188,7 +188,7 @@ class BlockingCounter { // returns false. bool DecrementCount() { const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1; - DCHECK_GE(count_value, 0); + TORCH_DCHECK_GE(count_value, 0); if (count_value == 0) { std::lock_guard g(mutex_); cond_.notify_one(); @@ -338,7 +338,7 @@ class WorkersPool { // One of the tasks will be run on the current thread. int workers_count = tasks.size() - 1; CreateWorkers(workers_count); - DCHECK_LE(workers_count, (int)workers_.size()); + TORCH_DCHECK_LE(workers_count, (int)workers_.size()); counter_to_decrement_when_ready_.Reset(workers_count); for (const auto task : c10::irange(1, tasks.size())) { workers_[task - 1]->StartWork(tasks[task].get()); diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index ac633d271a365..50e061f3e0fad 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -76,7 +76,7 @@ struct compute_2d_context { }; static void compute_2d(void* context_, size_t linear_index) { - DCHECK_LE(linear_index, std::numeric_limits::max()); + TORCH_DCHECK_LE(linear_index, std::numeric_limits::max()); const struct compute_2d_context* context = static_cast(context_); int32_t q; @@ -100,7 +100,7 @@ void legacy_pthreadpool_compute_2d( } } } else { - DCHECK_LE(range_i * range_j, (size_t)std::numeric_limits::max()); + TORCH_DCHECK_LE(range_i * range_j, (size_t)std::numeric_limits::max()); /* Execute in parallel on the thread pool using linearized index */ struct compute_2d_context context = { /*.function = */ function, @@ -155,7 +155,7 @@ void legacy_pthreadpool_compute_2d_tiled( /* Execute in parallel on the thread pool using linearized index */ const size_t tile_range_i = divide_round_up(range_i, tile_i); const size_t tile_range_j = divide_round_up(range_j, tile_j); - DCHECK_LE( + TORCH_DCHECK_LE( tile_range_i * tile_range_j, (size_t)std::numeric_limits::max()); struct compute_2d_tiled_context context = { @@ -237,7 +237,7 @@ void legacy_pthreadpool_compute_3d_tiled( const size_t tile_range_i = divide_round_up(range_i, tile_i); const size_t tile_range_j = divide_round_up(range_j, tile_j); const size_t tile_range_k = divide_round_up(range_k, tile_k); - DCHECK_LE( + TORCH_DCHECK_LE( tile_range_i * tile_range_j * tile_range_k, (size_t)std::numeric_limits::max()); struct compute_3d_tiled_context context = { @@ -349,7 +349,7 @@ void legacy_pthreadpool_compute_4d_tiled( const size_t tile_range_j = divide_round_up(range_j, tile_j); const size_t tile_range_k = divide_round_up(range_k, tile_k); const size_t tile_range_l = divide_round_up(range_l, tile_l); - DCHECK_LE( + TORCH_DCHECK_LE( tile_range_i * tile_range_j * tile_range_k * tile_range_l, (size_t)std::numeric_limits::max()); struct compute_4d_tiled_context context = { diff --git a/caffe2/video/video_input_op.cc b/caffe2/video/video_input_op.cc index 2a131f867ad01..8a6530a207b5b 100644 --- a/caffe2/video/video_input_op.cc +++ b/caffe2/video/video_input_op.cc @@ -54,7 +54,7 @@ OPERATOR_SCHEMA(VideoInput) int index = 0; vector out(output_size); - CHECK_GT(crop_size, 0); + TORCH_CHECK_GT(crop_size, 0); batch_size *= clip_per_video; if (get_rgb) { out[index++] = CreateTensorShape( diff --git a/caffe2/video/video_input_op.h b/caffe2/video/video_input_op.h index 27f7e223bfec2..36d7be54b3260 100644 --- a/caffe2/video/video_input_op.h +++ b/caffe2/video/video_input_op.h @@ -564,7 +564,7 @@ bool VideoInputOp::GetImageAndLabelsFromDBValue( cv::Mat src; if (image_proto.data_type() == TensorProto::STRING) { // encoded image string. - DCHECK_EQ(image_proto.string_data_size(), 1); + TORCH_DCHECK_EQ(image_proto.string_data_size(), 1); const string& encoded_image_str = image_proto.string_data(0); int encoded_size = encoded_image_str.size(); // We use a cv::Mat to wrap the encoded str so we do not need a copy. diff --git a/cmake/Caffe2Config.cmake.in b/cmake/Caffe2Config.cmake.in index f9979008f520b..8045c87598dfc 100644 --- a/cmake/Caffe2Config.cmake.in +++ b/cmake/Caffe2Config.cmake.in @@ -78,6 +78,10 @@ else() endif() endif() +if (@USE_ROCM@) + include("${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake") +endif() + if(@USE_CUDA@) # The file public/cuda.cmake exclusively uses CAFFE2_USE_*. # If Caffe2 was compiled with the libraries below, they must diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index a73811000b087..f652de8cce8e6 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -118,6 +118,10 @@ if(INTERN_BUILD_ATEN_OPS) --source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen --install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen ) + if(SELECTED_OP_LIST) + list(APPEND GEN_UNBOXING_COMMAND + --TEST_ONLY_op_registration_allowlist_yaml_path "${SELECTED_OP_LIST}") + endif() set("GEN_UNBOXING_COMMAND_sources" ${GEN_UNBOXING_COMMAND} --output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_unboxing_sources.cmake diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index deec447c99735..c67746d903dc1 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -824,7 +824,8 @@ if(USE_FBGEMM) if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) # See https://github.com/pytorch/pytorch/issues/74352 target_compile_options(asmjit PRIVATE -Wno-deprecated-copy) - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.1.0) + if(("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.1.6) + OR("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0)) # -Wno-unused-but-set-variable doesn't exist in Apple clang version 13.0.0 (clang-1300.0.29.30) target_compile_options(asmjit PRIVATE -Wno-unused-but-set-variable) endif() @@ -961,6 +962,19 @@ if(USE_FFMPEG) endif() endif() +if(USE_ITT) + find_package(ITT) + if(ITT_FOUND) + include_directories(SYSTEM ${ITT_INCLUDE_DIR}) + list(APPEND Caffe2_DEPENDENCY_LIBS ${ITT_LIBRARIES}) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ITT_LIBRARIES}) + else() + message(WARNING "Not compiling with ITT. Suppress this warning with -DUSE_ITT=OFF") + set(USE_ITT OFF CACHE BOOL "" FORCE) + caffe2_update_option(USE_ITT OFF) + endif() +endif() + # ---[ Caffe2 depends on FP16 library for half-precision conversions if(NOT TARGET fp16 AND NOT USE_SYSTEM_FP16) set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") @@ -1104,7 +1118,7 @@ if(BUILD_PYTHON) endif() # ---[ pybind11 -if(USE_SYSTEM_BIND11) +if(USE_SYSTEM_PYBIND11) find_package(pybind11 CONFIG) if(NOT pybind11_FOUND) find_package(pybind11) @@ -1361,6 +1375,16 @@ if(USE_NCCL) endif() endif() +# ---[ UCC +if(USE_UCC) + if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") + message(WARNING "UCC is currently only supported under Linux.") + caffe2_update_option(USE_UCC OFF) + else() + include(${CMAKE_CURRENT_LIST_DIR}/External/ucc.cmake) + endif() +endif() + # ---[ CUB if(USE_CUDA) find_package(CUB) @@ -1627,17 +1651,13 @@ if(NOT INTERN_BUILD_MOBILE) endif() if(NOT MSVC) - if(CMAKE_VERSION VERSION_LESS "3.1") - set(CMAKE_C_FLAGS "-std=c11 ${CMAKE_C_FLAGS}") - else() - set(CMAKE_C_STANDARD 11) - endif() + set(CMAKE_C_STANDARD 11 CACHE STRING "The C standard whose features are requested to build this target.") endif() string(APPEND CMAKE_CUDA_FLAGS " -Wno-deprecated-gpu-targets --expt-extended-lambda") if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CMAKE_CXX_STANDARD 14) + set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") endif() # use cub in a safe manner, see: @@ -1751,7 +1771,6 @@ if(NOT INTERN_BUILD_MOBILE) endif() set(AT_MKLDNN_ENABLED 0) - set(CAFFE2_USE_MKLDNN OFF) if(USE_MKLDNN) if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) message(WARNING @@ -1767,11 +1786,11 @@ if(NOT INTERN_BUILD_MOBILE) set(AT_MKLDNN_ENABLED 1) include_directories(AFTER SYSTEM ${MKLDNN_INCLUDE_DIR}) if(BUILD_CAFFE2_OPS) - set(CAFFE2_USE_MKLDNN ON) list(APPEND Caffe2_DEPENDENCY_LIBS caffe2::mkldnn) endif(BUILD_CAFFE2_OPS) else() message(WARNING "MKLDNN could not be found.") + caffe2_update_option(USE_MKLDNN OFF) endif() else() message("disabling MKLDNN because USE_MKLDNN is not set") diff --git a/cmake/External/ucc.cmake b/cmake/External/ucc.cmake new file mode 100644 index 0000000000000..359ea67b1a745 --- /dev/null +++ b/cmake/External/ucc.cmake @@ -0,0 +1,20 @@ +if(NOT __UCC_INCLUDED) + set(__UCC_INCLUDED TRUE) + + if(USE_SYSTEM_UCC) + set(UCX_HOME $ENV{UCX_HOME} CACHE PATH "UCX install directory") + set(UCC_HOME $ENV{UCC_HOME} CACHE PATH "UCC install directory") + + add_library(__caffe2_ucc INTERFACE) + + target_include_directories(__caffe2_ucc INTERFACE ${UCX_HOME}/include/) + target_include_directories(__caffe2_ucc INTERFACE ${UCC_HOME}/include/) + + target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucp.so) + target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucs.so) + target_link_libraries(__caffe2_ucc INTERFACE ${UCC_HOME}/lib/libucc.so) + else() + message(FATAL_ERROR "USE_SYSTEM_UCC=OFF is not supported yet when using UCC") + endif() + +endif() diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index c04427cbad850..0b40c24843885 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -1,5 +1,5 @@ -INCLUDE(CheckCSourceCompiles) -INCLUDE(CheckCXXSourceCompiles) +INCLUDE(CheckCSourceRuns) +INCLUDE(CheckCXXSourceRuns) SET(AVX_CODE " #include @@ -51,9 +51,9 @@ MACRO(CHECK_SSE lang type flags) IF(NOT ${lang}_${type}_FOUND) SET(CMAKE_REQUIRED_FLAGS ${__FLAG}) IF(lang STREQUAL "CXX") - CHECK_CXX_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) + CHECK_CXX_SOURCE_RUNS("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) ELSE() - CHECK_C_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) + CHECK_C_SOURCE_RUNS("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I}) ENDIF() IF(${lang}_HAS_${type}_${__FLAG_I}) SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") diff --git a/cmake/Modules/FindITT.cmake b/cmake/Modules/FindITT.cmake new file mode 100644 index 0000000000000..9bd800fd236a3 --- /dev/null +++ b/cmake/Modules/FindITT.cmake @@ -0,0 +1,21 @@ +# - Try to find ITT +# +# The following are set after configuration is done: +# ITT_FOUND : set to true if ITT is found. +# ITT_INCLUDE_DIR : path to ITT include dir. +# ITT_LIBRARIES : list of libraries for ITT + +IF (NOT ITT_FOUND) + SET(ITT_FOUND OFF) + + SET(ITT_INCLUDE_DIR) + SET(ITT_LIBRARIES) + + SET(ITT_ROOT "${PROJECT_SOURCE_DIR}/third_party/ittapi") + FIND_PATH(ITT_INCLUDE_DIR ittnotify.h PATHS ${ITT_ROOT} PATH_SUFFIXES include) + IF (ITT_INCLUDE_DIR) + ADD_SUBDIRECTORY(${ITT_ROOT}) + SET(ITT_LIBRARIES ittnotify) + SET(ITT_FOUND ON) + ENDIF (ITT_INCLUDE_DIR) +ENDIF(NOT ITT_FOUND) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 1a99d1e567a12..a9c6201fb6bef 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -142,10 +142,15 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_FFTW : ${USE_FFTW}") message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}") message(STATUS " USE_MKLDNN : ${USE_MKLDNN}") - if(${CAFFE2_USE_MKLDNN}) + if(${USE_MKLDNN}) message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}") message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}") endif() + message(STATUS " USE_UCC : ${USE_UCC}") + if(${USE_UCC}) + message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") + endif() + message(STATUS " USE_ITT : ${USE_ITT}") message(STATUS " USE_NCCL : ${USE_NCCL}") if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") diff --git a/cmake/VulkanCodegen.cmake b/cmake/VulkanCodegen.cmake index c39b54df3af30..075f2b36ad2a4 100644 --- a/cmake/VulkanCodegen.cmake +++ b/cmake/VulkanCodegen.cmake @@ -62,7 +62,7 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME) execute_process( COMMAND "${PYTHON_EXECUTABLE}" - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py + ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py --glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl --output-path ${VULKAN_GEN_OUTPUT_PATH} --glslc-path=${GLSLC_PATH} diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 0202f15270b21..87bb57da1543f 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -300,10 +300,18 @@ if(HIP_FOUND) find_library(PYTORCH_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib) # TODO: miopen_LIBRARIES should return fullpath to the library file, # however currently it's just the lib name - find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) + if(TARGET ${miopen_LIBRARIES}) + set(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES}) + else() + find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) + endif() # TODO: rccl_LIBRARIES should return fullpath to the library file, # however currently it's just the lib name - find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) + if(TARGET ${rccl_LIBRARIES}) + set(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES}) + else() + find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) + endif() # hiprtc is part of HIP find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib) # roctx is part of roctracer diff --git a/defs.bzl b/defs.bzl index c81f59274c1c1..a40ee4b9091da 100644 --- a/defs.bzl +++ b/defs.bzl @@ -1,19 +1,20 @@ -def get_sleef_deps(): - return [("sleef", None, "sleef")] if not (host_info().arch.is_aarch64) else [] +def get_sleef_arch_deps(): + return [ + ("x86_64", [ + "third-party//sleef:sleef", + ]), + ] -def get_blas_gomp_deps(): - if host_info().arch.is_x86_64: - return [( - "IntelComposerXE", - None, - native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp"), - )] - if host_info().arch.is_aarch64: - return [ - ("OpenBLAS", None, "OpenBLAS"), - ("openmp", None, "omp"), - ] - fail("Unsupported architecture") +def get_blas_gomp_arch_deps(): + return [ + ("x86_64", [ + "third-party//IntelComposerXE:{}".format(native.read_config("fbcode", "mkl_lp64", "mkl_lp64_omp")), + ]), + ("aarch64", [ + "third-party//OpenBLAS:OpenBLAS", + "third-party//openmp:omp", + ]), + ] default_compiler_flags = [ "-Wall", diff --git a/docker.Makefile b/docker.Makefile index 11c438d0fd224..a1772529d926d 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -18,7 +18,7 @@ CUDA_CHANNEL = nvidia # The conda channel to use to install pytorch / torchvision INSTALL_CHANNEL = pytorch -PYTHON_VERSION = 3.7 +PYTHON_VERSION = 3.8 PYTORCH_VERSION = $(shell git describe --tags --always) # Can be either official / dev BUILD_TYPE = dev diff --git a/docs/Makefile b/docs/Makefile index b9719df7ade5c..122bda6231e39 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,7 +18,7 @@ figures: @$(PYCMD) source/scripts/build_quantization_configs.py onnx_supported_aten_ops: - @$(PYCMD) source/scripts/build_onnx_supported_aten_op_csv_table.py + @$(PYCMD) source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py docset: html doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/docs/ --force $(BUILDDIR)/html/ diff --git a/docs/cpp/source/conf.py b/docs/cpp/source/conf.py index 54cd6acdb8fd7..e3a43da1108f0 100644 --- a/docs/cpp/source/conf.py +++ b/docs/cpp/source/conf.py @@ -16,6 +16,9 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +# NB: C++ API doc generation using doxygen / breathe / exhale is currently only +# enabled on nightlies (and not trunk or on PRs) due to OOM errors in CI. +# See https://github.com/pytorch/pytorch/issues/79992. import os # sys.path.insert(0, os.path.abspath('.')) @@ -26,15 +29,17 @@ # If your documentation needs a minimal Sphinx version, state it here. # needs_sphinx = '3.1.2' +run_doxygen = os.environ.get('RUN_DOXYGEN', "false") == "true" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.intersphinx', +] + ([ 'breathe', 'exhale' -] +] if run_doxygen else []) intersphinx_mapping = { 'pytorch': ('https://pytorch.org/docs/master', None) diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst index 0782e30345b75..4a36681947004 100644 --- a/docs/source/autograd.rst +++ b/docs/source/autograd.rst @@ -223,10 +223,12 @@ Profiler ^^^^^^^^ Autograd includes a profiler that lets you inspect the cost of different -operators inside your model - both on the CPU and GPU. There are two modes +operators inside your model - both on the CPU and GPU. There are three modes implemented at the moment - CPU-only using :class:`~torch.autograd.profiler.profile`. -and nvprof based (registers both CPU and GPU activity) using +nvprof based (registers both CPU and GPU activity) using :class:`~torch.autograd.profiler.emit_nvtx`. +and vtune profiler based using +:class:`~torch.autograd.profiler.emit_itt`. .. autoclass:: torch.autograd.profiler.profile @@ -240,6 +242,7 @@ and nvprof based (registers both CPU and GPU activity) using profiler.profile.total_average .. autoclass:: torch.autograd.profiler.emit_nvtx +.. autoclass:: torch.autograd.profiler.emit_itt .. autosummary:: diff --git a/docs/source/backends.rst b/docs/source/backends.rst index c54cf33fbe154..152e0144a416d 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -14,6 +14,7 @@ These backends include: - ``torch.backends.mkl`` - ``torch.backends.mkldnn`` - ``torch.backends.openmp`` +- ``torch.backends.xeon`` torch.backends.cuda @@ -78,6 +79,14 @@ torch.backends.cudnn A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest. +.. attribute:: torch.backends.cudnn.benchmark_limit + + A :class:`int` that specifies the maximum number of cuDNN convolution algorithms to try when + `torch.backends.cudnn.benchmark` is True. Set `benchmark_limit` to zero to try every + available algorithm. Note that this setting only affects convolutions dispatched via the + cuDNN v8 API. + + torch.backends.mps ^^^^^^^^^^^^^^^^^^ .. automodule:: torch.backends.mps @@ -93,6 +102,8 @@ torch.backends.mkl .. autofunction:: torch.backends.mkl.is_available +.. autoclass:: torch.backends.mkl.verbose + torch.backends.mkldnn ^^^^^^^^^^^^^^^^^^^^^ @@ -100,6 +111,8 @@ torch.backends.mkldnn .. autofunction:: torch.backends.mkldnn.is_available +.. autoclass:: torch.backends.mkldnn.verbose + torch.backends.openmp ^^^^^^^^^^^^^^^^^^^^^ @@ -112,3 +125,8 @@ torch.backends.openmp .. add anything to the rendered page for now. .. py:module:: torch.backends.quantized .. py:module:: torch.backends.xnnpack + + +torch.backends.xeon +^^^^^^^^^^^^^^^^^^^ +.. automodule:: torch.backends.xeon diff --git a/docs/source/bottleneck.rst b/docs/source/bottleneck.rst index 3fa1c99b50617..c413771887073 100644 --- a/docs/source/bottleneck.rst +++ b/docs/source/bottleneck.rst @@ -47,7 +47,9 @@ where [args] are any number of arguments to `script.py`, or run evaluating. If the profiler outputs don't help, you could try looking at the result of :func:`torch.autograd.profiler.emit_nvtx()` with ``nvprof``. However, please take into account that the NVTX overhead is very high and - often gives a heavily skewed timeline. + often gives a heavily skewed timeline. Similarly, Intel VTune Profiler helps + to analyze performance on Intel platforms further with + :func:`torch.autograd.profiler.emit_nvtx()`. .. warning:: If you are profiling CUDA code, the first profiler that ``bottleneck`` runs diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index 602d7bae74097..f6e19db5e8255 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -70,7 +70,7 @@ Distributions & RNG - Fritz Obermeyer (`fritzo `__) - Neeraj Pradhan (`neerajprad `__) - Alican Bozkurt (`alicanb `__) -- Vishwak Srinivasan (`vishwakftw `__) +- (emeritus) Vishwak Srinivasan (`vishwakftw `__) Distributed ~~~~~~~~~~~ @@ -101,9 +101,9 @@ Linear Algebra (torch.linalg) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Mike Ruberry (`mruberry `__) -- Vishwak Srinivasan (`vishwakftw `__) - Mario Lezcano (`Lezcano `__) - Ivan Yashchuk (`IvanYashchuk `__) +- (emeritus) Vishwak Srinivasan (`vishwakftw `__) Fast Fourier Transform (torch.fft) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -214,17 +214,23 @@ Windows - (emeritus) Teng Gao (`gaoteng-git `__) - (emeritus) Peter Johnson (`peterjc123 `__) -Apple M1 -~~~~~~~~ +Apple M1/MPS +~~~~~~~~~~~~ - Alban Desmaison (`alband `__) - Nikita Shulga (`malfet `__) +- Kulin Seth (`kulinseth `__) PowerPC ~~~~~~~ - Alfredo Mendoza (`avmgithub `__) +Docs / Tutorials +~~~~~~~~~~~~~~~~ + +- Svetlana Karslioglu (`svekars `__) + Library-level maintainers ------------------------- diff --git a/docs/source/complex_numbers.rst b/docs/source/complex_numbers.rst index 0823833fb3340..3a3f200fb5f48 100644 --- a/docs/source/complex_numbers.rst +++ b/docs/source/complex_numbers.rst @@ -3,6 +3,9 @@ Complex Numbers =============== +.. note:: When using complex numbers, use Pytorch with CUDA 11.6 downloaded via pip wheel as described in + `Get Started `__ and select the CUDA 11.6 pip package. + Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where a and b are real numbers, and *j* is called the imaginary unit, which satisfies the equation :math:`j^2 = -1`. Complex numbers frequently occur in mathematics and engineering, especially in topics like signal processing. Traditionally many users and libraries (e.g., TorchAudio) have diff --git a/docs/source/conf.py b/docs/source/conf.py index 63b5589c178fb..e8b683cd445cd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,6 +18,7 @@ # import os from os import path +import re # import sys import pkgutil @@ -320,7 +321,7 @@ "Quantize", # torch.utils.backcompat "Warning", - "SymbolicIntNode" + "SymIntNode" ] # The suffix(es) of source filenames. @@ -642,12 +643,12 @@ def handle_item(fieldarg, content): # inconsistencies later when references are resolved fieldtype = types.pop(fieldarg) if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = u''.join(n.astext() for n in fieldtype) - typename = typename.replace('int', 'python:int') - typename = typename.replace('long', 'python:long') - typename = typename.replace('float', 'python:float') - typename = typename.replace('bool', 'python:bool') - typename = typename.replace('type', 'python:type') + typename = fieldtype[0].astext() + builtin_types = ['int', 'long', 'float', 'bool', 'type'] + for builtin_type in builtin_types: + pattern = fr'(?`_ for more details on usage and debugging. +There are a couple of fusion backends available to optimize TorchScript execution. The default fuser on CPUs is NNC, which can perform fusions for both CPUs and GPUs. The default fuser on GPUs is NVFuser, which supports a wider range of operators and has demonstrated generated kernels with improved throughput. See the `NVFuser documentation `_ for more details on usage and debugging. References diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 2a3d5286aa192..02950ff971a62 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -84,6 +84,7 @@ Matrix Products cross matmul + vecdot multi_dot householder_product @@ -114,6 +115,7 @@ Experimental Functions cholesky_ex inv_ex + solve_ex lu_factor_ex ldl_factor ldl_factor_ex diff --git a/docs/source/named_tensor.rst b/docs/source/named_tensor.rst index 02113b9c987cf..cd8ac28466da6 100644 --- a/docs/source/named_tensor.rst +++ b/docs/source/named_tensor.rst @@ -300,7 +300,6 @@ operators, see :ref:`name_inference_reference-doc`. .. automethod:: align_as .. automethod:: align_to - .. automethod:: unflatten .. py:method:: flatten(dims, out_dim) -> Tensor :noindex: diff --git a/docs/source/nested.rst b/docs/source/nested.rst index 08522d7750a81..9ad43322d196b 100644 --- a/docs/source/nested.rst +++ b/docs/source/nested.rst @@ -10,16 +10,11 @@ Introduction The PyTorch API of nested tensors is in prototype stage and will change in the near future. -.. warning:: - - torch.NestedTensor currently does not support autograd. It needs to be used in the context - of torch.inference_mode(). - NestedTensor allows the user to pack a list of Tensors into a single, efficient datastructure. The only constraint on the input Tensors is that their dimension must match. -This enables more efficient metadata representations and operator coverage. +This enables more efficient metadata representations and access to purpose built kernels. Construction is straightforward and involves passing a list of Tensors to the constructor. @@ -35,7 +30,7 @@ nested_tensor([ tensor([3, 4, 5, 6, 7]) ]) -Data type and device can be chosen via the usual keyword arguments +Data type and device can be chosen via the usual keyword arguments. >>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda") >>> nt @@ -44,22 +39,108 @@ nested_tensor([ tensor([3., 4., 5., 6., 7.], device='cuda:0') ]) +In order to form a valid NestedTensor the passed Tensors also all need to match in dimension, but none of the other attributes need to. -Operator coverage -+++++++++++++++++ +>>> a = torch.randn(3, 50, 70) # image 1 +>>> b = torch.randn(3, 128, 64) # image 2 +>>> nt = torch.nested_tensor([a, b], dtype=torch.float32) +>>> nt.dim() +4 -We are currently on our path to wholesale extend operator coverage guided by specific ML use cases. +If one of the dimensions don't match, the constructor throws an error. -Operator coverage thus is currently very limited and only unbind is supported. +>>> a = torch.randn(50, 128) # text 1 +>>> b = torch.randn(3, 128, 64) # image 2 +>>> nt = torch.nested_tensor([a, b], dtype=torch.float32) +Traceback (most recent call last): + File "", line 1, in +RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0. ->>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda") +Note that the passed Tensors are being copied into a contiguous piece of memory. The resulting +NestedTensor allocates new memory to store them and does not keep a reference. + + +At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future +we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors. +Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor +has a well defined dimension. If you have a need for this feature, please feel encourage to open a feature request so that +we can track it and plan accordingly. + +size ++++++++++++++++++++++++++ + +Even though a NestedTensor does not support .size() (or .shape), it supports .size(i) if dimension i is regular. + +>>> a = torch.randn(50, 128) # text 1 +>>> b = torch.randn(32, 128) # text 2 +>>> nt = torch.nested_tensor([a, b], dtype=torch.float32) +>>> nt.size(0) +2 +>>> nt.size(1) +Traceback (most recent call last): + File "", line 1, in +RuntimeError: Given dimension 1 is irregular and does not have a size. +>>> nt.size(2) +128 + +If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor. + +>>> a = torch.randn(20, 128) # text 1 +>>> nt = torch.nested_tensor([a, a], dtype=torch.float32) +>>> nt.size(0) +2 +>>> nt.size(1) +20 +>>> nt.size(2) +128 +>>> torch.stack(nt.unbind()).size() +torch.Size([2, 20, 128]) +>>> torch.stack([a, a]).size() +torch.Size([2, 20, 128]) +>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a])) +True + +In the future we might make it easier to detect this condition and convert seamlessly. + +Please open a feature request if you have a need for this (or any other related feature for that manner). + +unbind ++++++++++++++++++++++++++ + +unbind allows you to retrieve a view of the constituents. + +>>> import torch +>>> a = torch.randn(2, 3) +>>> b = torch.randn(3, 4) +>>> nt = torch.nested_tensor([a, b], dtype=torch.float32) >>> nt nested_tensor([ - tensor([0., 1., 2.], device='cuda:0'), - tensor([3., 4., 5., 6., 7.], device='cuda:0') + tensor([[ 1.2286, -1.2343, -1.4842], + [-0.7827, 0.6745, 0.0658]]), + tensor([[-1.1247, -0.4078, -1.0633, 0.8083], + [-0.2871, -0.2980, 0.5559, 1.9885], + [ 0.4074, 2.4855, 0.0733, 0.8285]]) ]) >>> nt.unbind() -[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')] +(tensor([[ 1.2286, -1.2343, -1.4842], + [-0.7827, 0.6745, 0.0658]]), tensor([[-1.1247, -0.4078, -1.0633, 0.8083], + [-0.2871, -0.2980, 0.5559, 1.9885], + [ 0.4074, 2.4855, 0.0733, 0.8285]])) +>>> nt.unbind()[0] is not a +True +>>> nt.unbind()[0].mul_(3) +tensor([[ 3.6858, -3.7030, -4.4525], + [-2.3481, 2.0236, 0.1975]]) +>>> nt +nested_tensor([ + tensor([[ 3.6858, -3.7030, -4.4525], + [-2.3481, 2.0236, 0.1975]]), + tensor([[-1.1247, -0.4078, -1.0633, 0.8083], + [-0.2871, -0.2980, 0.5559, 1.9885], + [ 0.4074, 2.4855, 0.0733, 0.8285]]) +]) + +Note that nt.unbind()[0] is not a, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor. Nested tensor methods +++++++++++++++++++++++++ diff --git a/docs/source/nn.init.rst b/docs/source/nn.init.rst index a980f16f5f6db..a2a2d0bc37252 100644 --- a/docs/source/nn.init.rst +++ b/docs/source/nn.init.rst @@ -6,6 +6,11 @@ torch.nn.init ============= +.. warning:: + All the functions in this module are intended to be used to initialize neural network + parameters, so they all run in :func:`torch.no_grad` mode and will not be taken into + account by autograd. + .. currentmodule:: torch.nn.init .. autofunction:: calculate_gain .. autofunction:: uniform_ diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index e65fcc72da854..c678844edcfaa 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -9,7 +9,7 @@ created on that device. The selected device can be changed with a :any:`torch.cuda.device` context manager. However, once a tensor is allocated, you can do operations on it irrespective -of the selected device, and the results will be always placed in on the same +of the selected device, and the results will be always placed on the same device as the tensor. Cross-GPU operations are not allowed by default, with the exception of diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 6a915163f2937..716f2532ae78f 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -89,9 +89,10 @@ properly in order to ensure that the new :class:`Function` works properly with the autograd engine. - :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` must be - used when saving input or output tensors of the forward to be used later in the backward. - Anything else, i.e., non-tensors and tensors that are neither input nor output - should be stored directly on `ctx`. + used to save any tensors to be used in the backward pass. Non-tensors should + be stored directly on `ctx`. If tensors that are neither input nor output + are saved for backward your :class:`~Function` may not support double backward + (see step 3). - :meth:`~torch.autograd.function.FunctionCtx.mark_dirty` must be used to mark any input that is modified inplace by the forward function. - :meth:`~torch.autograd.function.FunctionCtx.mark_non_differentiable` must diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 8e7804d6ef5f3..2c24c5d7ee521 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -570,27 +570,9 @@ Q: How to export models with primitive type inputs (e.g. int, float)? Q: Does ONNX support implicit scalar datatype casting? No, but the exporter will try to handle that part. Scalars are exported as constant tensors. - The exporter will try to figure out the right datatype for scalars. However when it is unable - to do so, you will need to manually specify the datatype. This often happens with - scripted models, where the datatypes are not recorded. For example:: - - class ImplicitCastType(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - # Exporter knows x is float32, will export "2" as float32 as well. - y = x + 2 - # Currently the exporter doesn't know the datatype of y, so - # "3" is exported as int64, which is wrong! - return y + 3 - # To fix, replace the line above with: - # return y + torch.tensor([3], dtype=torch.float32) - - x = torch.tensor([1.0], dtype=torch.float32) - torch.onnx.export(ImplicitCastType(), x, "implicit_cast.onnx", - example_outputs=ImplicitCastType()(x)) - - We are trying to improve the datatype propagation in the exporter such that implicit casting - is supported in more cases. + The exporter will figure out the right data type for scalars. In rare cases when it is unable + to do so, you will need to manually specify the datatype with e.g. `dtype=torch.float32`. + If you see any errors, please [create a GitHub issue](https://github.com/pytorch/pytorch/issues). Q: Are lists of Tensors exportable to ONNX? diff --git a/docs/source/package.rst b/docs/source/package.rst index 9664460ac96a1..a8840c6215395 100644 --- a/docs/source/package.rst +++ b/docs/source/package.rst @@ -5,7 +5,7 @@ torch.package ============= -``torch.package`` adds support for creating hermetic packages containing arbitrary +``torch.package`` adds support for creating packages containing both artifacts and arbitrary PyTorch code. These packages can be saved, shared, used to load and execute models at a later date or on a different machine, and can even be deployed to production using ``torch::deploy``. @@ -93,7 +93,7 @@ work for exploring the contents. Some common ways to interact with ZIP files: Use the ``file_structure()`` API """""""""""""""""""""""""""""""" -:class:`PackageImporter` and :class:`PackageExporter` provide a ``file_structure()`` method, which will return a printable +:class:`PackageImporter` provides a ``file_structure()`` method, which will return a printable and queryable ``Folder`` object. The ``Folder`` object is a simple directory structure that you can use to explore the current contents of a ``torch.package``. @@ -105,10 +105,10 @@ use the glob-style ``include`` and ``exclude`` filtering arguments. with PackageExporter('my_package.pt') as pe: pe.save_pickle('models', 'model_1.pkl', mod) - # can limit printed items with include/exclude args - print(pe.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storages")) importer = PackageImporter('my_package.pt') + # can limit printed items with include/exclude args + print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storages")) print(importer.file_structure()) # will print out all files @@ -146,8 +146,8 @@ You can also query ``Folder`` objects with the ``has_file()`` method. :: - exporter_file_structure = exporter.file_structure() - found: bool = exporter_file_structure.has_file("package_a/subpackage.py") + importer_file_structure = importer.file_structure() + found: bool = importer_file_structure.has_file("package_a/subpackage.py") See why a given module was included as a dependency? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -269,9 +269,9 @@ Steps: # save as normal, no extra work necessary pe.save_pickle('foo_collection', 'foo1.pkl', foo_1) pe.save_pickle('foo_collection', 'foo2.pkl', foo_2) - print(pe.file_structure()) pi = PackageImporter('foo_package.pt') + print(pi.file_structure()) imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl') print(f"foo_1 string: '{imported_foo.my_string}'") print(f"foo_1 export time: {imported_foo.time_exported}") diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index da6649a2fee3d..e142cf70a619f 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -428,6 +428,24 @@ This module implements the quantized versions of the functional layers such as upsample_bilinear upsample_nearest +torch.nn.quantizable +~~~~~~~~~~~~~~~~~~~~~~ + +This module implements the quantizable versions of some of the nn layers. +These modules can be used in conjunction with the custom module mechanism, +by providing the ``custom_module_config`` argument to both prepare and convert. + +.. currentmodule:: torch.nn.quantizable + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + LSTM + MultiheadAttention + + torch.nn.quantized.dynamic ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: torch.nn.quantized.dynamic diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 3be4d5a390e75..b71b1fb976953 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -41,6 +41,105 @@ to lower precision with minimal accuracy loss. Quantization API Summary ----------------------------- +PyTorch provides two different modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization. + +Eager Mode Quantization is a beta feature. User needs to do fusion and specify where quantization and dequantization happens manually, also it only supports modules and not functionals. + +FX Graph Mode Quantization is a new automated quantization framework in PyTorch, and currently it's a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although people might need to refactor the model to make the model compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models since the model might not be symbolically traceable, we will integrate it into domain libraries like torchvision and users will be able to quantize models similar to the ones in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we'll provide general guidelines, but to actually make it work, users might need to be familiar with ``torch.fx``, especially on how to make a model symbolically traceable. + +New users of quantization are encouraged to try out FX Graph Mode Quantization first, if it does not work, user may try to follow the guideline of `using FX Graph Mode Quantization `_ or fall back to eager mode quantization. + +The following table compares the differences between Eager Mode Quantization and FX Graph Mode Quantization: + ++-----------------+-------------------+-------------------+ +| |Eager Mode |FX Graph | +| |Quantization |Mode | +| | |Quantization | ++-----------------+-------------------+-------------------+ +|Release |beta |prototype | +|Status | | | ++-----------------+-------------------+-------------------+ +|Operator |Manual |Automatic | +|Fusion | | | ++-----------------+-------------------+-------------------+ +|Quant/DeQuant |Manual |Automatic | +|Placement | | | ++-----------------+-------------------+-------------------+ +|Quantizing |Supported |Supported | +|Modules | | | ++-----------------+-------------------+-------------------+ +|Quantizing |Manual |Automatic | +|Functionals/Torch| | | +|Ops | | | ++-----------------+-------------------+-------------------+ +|Support for |Limited Support |Fully | +|Customization | |Supported | ++-----------------+-------------------+-------------------+ +|Quantization Mode|Post Training |Post Training | +|Support |Quantization: |Quantization: | +| |Static, Dynamic, |Static, Dynamic, | +| |Weight Only |Weight Only | +| | | | +| |Quantiztion Aware |Quantiztion Aware | +| |Training: |Training: | +| |Static |Static | ++-----------------+-------------------+-------------------+ +|Input/Output |``torch.nn.Module``|``torch.nn.Module``| +|Model Type | |(May need some | +| | |refactors to make | +| | |the model | +| | |compatible with FX | +| | |Graph Mode | +| | |Quantization) | ++-----------------+-------------------+-------------------+ + + +There are three types of quantization supported: + +1. dynamic quantization (weights quantized with activations read/stored in + floating point and quantized for compute) +2. static quantization (weights quantized, activations quantized, calibration + required post training) +3. static quantization aware training (weights quantized, activations quantized, + quantization numerics modeled during training) + +Please see our `Introduction to Quantization on Pytorch +`_ blog post +for a more comprehensive overview of the tradeoffs between these quantization +types. + +Operator coverage varies between dynamic and static quantization and is captured in the table below. +Note that for FX quantization, the corresponding functionals are also supported. + ++---------------------------+-------------------+--------------------+ +| |Static | Dynamic | +| |Quantization | Quantization | ++---------------------------+-------------------+--------------------+ +| | nn.Linear | | Y | | Y | +| | nn.Conv1d/2d/3d | | Y | | N | ++---------------------------+-------------------+--------------------+ +| | nn.LSTM | | Y (through | | Y | +| | | | custom modules) | | | +| | nn.GRU | | N | | Y | ++---------------------------+-------------------+--------------------+ +| | nn.RNNCell | | N | | Y | +| | nn.GRUCell | | N | | Y | +| | nn.LSTMCell | | N | | Y | ++---------------------------+-------------------+--------------------+ +|nn.EmbeddingBag | Y (activations | | +| | are in fp32) | Y | ++---------------------------+-------------------+--------------------+ +|nn.Embedding | Y | N | ++---------------------------+-------------------+--------------------+ +| nn.MultiheadAttention | Y (through | Not supported | +| | custom modules) | | ++---------------------------+-------------------+--------------------+ +| Activations | Broadly supported | Un-changed, | +| | | computations | +| | | stay in fp32 | ++---------------------------+-------------------+--------------------+ + + Eager Mode Quantization ^^^^^^^^^^^^^^^^^^^^^^^ For a general introduction to the quantization flow, including different types of quantization, please take a look at `General Quantization Flow`_. @@ -69,31 +168,31 @@ Diagram:: / linear_weight_int8 -API example:: +PTDQ API Example:: - import torch - - # define a floating point model - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.fc = torch.nn.Linear(4, 4) + import torch - def forward(self, x): - x = self.fc(x) - return x + # define a floating point model + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) - # create a model instance - model_fp32 = M() - # create a quantized model instance - model_int8 = torch.quantization.quantize_dynamic( - model_fp32, # the original model - {torch.nn.Linear}, # a set of layers to dynamically quantize - dtype=torch.qint8) # the target dtype for quantized weights + def forward(self, x): + x = self.fc(x) + return x - # run the model - input_fp32 = torch.randn(4, 4, 4, 4) - res = model_int8(input_fp32) + # create a model instance + model_fp32 = M() + # create a quantized model instance + model_int8 = torch.quantization.quantize_dynamic( + model_fp32, # the original model + {torch.nn.Linear}, # a set of layers to dynamically quantize + dtype=torch.qint8) # the target dtype for quantized weights + + # run the model + input_fp32 = torch.randn(4, 4, 4, 4) + res = model_int8(input_fp32) To learn more about dynamic quantization please see our `dynamic quantization tutorial `_. @@ -124,14 +223,14 @@ Diagram:: / linear_weight_int8 -API Example:: +PTSQ API Example:: import torch # define a floating point model where some layers could be statically quantized class M(torch.nn.Module): def __init__(self): - super(M, self).__init__() + super().__init__() # QuantStub converts tensors from floating point to quantized self.quant = torch.quantization.QuantStub() self.conv = torch.nn.Conv2d(1, 1, 1) @@ -222,14 +321,14 @@ Diagram:: / linear_weight_int8 -API Example:: +QAT API Example:: import torch # define a floating point model where some layers could benefit from QAT class M(torch.nn.Module): def __init__(self): - super(M, self).__init__() + super().__init__() # QuantStub converts tensors from floating point to quantized self.quant = torch.quantization.QuantStub() self.conv = torch.nn.Conv2d(1, 1, 1) @@ -249,8 +348,8 @@ API Example:: # create a model instance model_fp32 = M() - # model must be set to train mode for QAT logic to work - model_fp32.train() + # model must be set to eval for fusion to work + model_fp32.eval() # attach a global qconfig, which contains information about what kind # of observers to attach. Use 'fbgemm' for server inference and @@ -265,8 +364,9 @@ API Example:: [['conv', 'bn', 'relu']]) # Prepare the model for QAT. This inserts observers and fake_quants in + # the model needs to be set to train for QAT logic to work # the model that will observe weight and activation tensors during calibration. - model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused) + model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused.train()) # run the training loop (not shown) training_loop(model_fp32_prepared) @@ -324,13 +424,14 @@ to do the following in addition: There are multiple quantization types in post training quantization (weight only, dynamic and static) and the configuration is done through `qconfig_mapping` (an argument of the `prepare_fx` function). -API Example:: +FXPTQ API Example:: - from torch.quantization import QConfigMapping + import torch + from torch.ao.quantization import QConfigMapping import torch.quantization.quantize_fx as quantize_fx import copy - model_fp = UserModel(...) + model_fp = UserModel() # # post training dynamic/weight_only quantization @@ -340,9 +441,11 @@ API Example:: model_to_quantize = copy.deepcopy(model_fp) model_to_quantize.eval() qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig) + # a tuple of one or more example inputs are needed to trace the model + example_inputs = (input_fp32) # prepare - model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping) - # no calibration needed when we only have dynamici/weight_only quantization + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) + # no calibration needed when we only have dynamic/weight_only quantization # quantize model_quantized = quantize_fx.convert_fx(model_prepared) @@ -354,7 +457,7 @@ API Example:: qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qconfig('qnnpack')) model_to_quantize.eval() # prepare - model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping) + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) # calibrate (not shown) # quantize model_quantized = quantize_fx.convert_fx(model_prepared) @@ -367,7 +470,7 @@ API Example:: qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qat_qconfig('qnnpack')) model_to_quantize.train() # prepare - model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping) + model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs) # training loop (not shown) # quantize model_quantized = quantize_fx.convert_fx(model_prepared) @@ -432,7 +535,7 @@ Here are a few key attributes for quantized Tensor: * scale (float) * zero_point (int) - * torch.per_tensor_affine would have quantization parameters of + * torch.per_channel_affine would have quantization parameters of * per_channel_scales (list of float) * per_channel_zero_points (list of int) @@ -749,8 +852,8 @@ based on observed tensor data are provided, developers can provide their own quantization functions. Quantization can be applied selectively to different parts of the model or configured differently for different parts of the model. -We also provide support for per channel quantization for **conv2d()**, -**conv3d()** and **linear()** +We also provide support for per channel quantization for **conv1d()**, **conv2d()**, +**conv3d()** and **linear()**. Quantization workflows work by adding (e.g. adding observers as ``.observer`` submodule) or replacing (e.g. converting ``nn.Conv2d`` to @@ -790,106 +893,101 @@ on that output. The observer will be stored under the `activation_post_process` as an attribute of the custom module instance. Relaxing these restrictions may be done at a future time. -Example:: - - import torch - import torch.nn.quantized as nnq - from torch.quantization import QConfigMapping - import torch.quantization.quantize_fx - - # original fp32 module to replace - class CustomModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return self.linear(x) - - # custom observed module, provided by user - class ObservedCustomModule(torch.nn.Module): - def __init__(self, linear): - super().__init__() - self.linear = linear - - def forward(self, x): - return self.linear(x) - - @classmethod - def from_float(cls, float_module): - assert hasattr(float_module, 'qconfig') - observed = cls(float_module.linear) - observed.qconfig = float_module.qconfig - return observed - - # custom quantized module, provided by user - class StaticQuantCustomModule(torch.nn.Module): - def __init__(self, linear): - super().__init__() - self.linear = linear - - def forward(self, x): - return self.linear(x) - - @classmethod - def from_observed(cls, observed_module): - assert hasattr(observed_module, 'qconfig') - assert hasattr(observed_module, 'activation_post_process') - observed_module.linear.activation_post_process = \ - observed_module.activation_post_process - quantized = cls(nnq.Linear.from_float(observed_module.linear)) - return quantized - - # - # example API call (Eager mode quantization) - # - - m = torch.nn.Sequential(CustomModule()).eval() - - prepare_custom_config_dict = { - "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule - } - } - convert_custom_config_dict = { - "observed_to_quantized_custom_module_class": { - ObservedCustomModule: StaticQuantCustomModule - } - } - - m.qconfig = torch.quantization.default_qconfig - mp = torch.quantization.prepare( - m, prepare_custom_config_dict=prepare_custom_config_dict) - # calibration (not shown) - mq = torch.quantization.convert( - mp, convert_custom_config_dict=convert_custom_config_dict) - - # - # example API call (FX graph mode quantization) - # - - m = torch.nn.Sequential(CustomModule()).eval() - - qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_qconfig) - prepare_custom_config_dict = { - "float_to_observed_custom_module_class": { - "static": { - CustomModule: ObservedCustomModule, - } - } - } - convert_custom_config_dict = { - "observed_to_quantized_custom_module_class": { - "static": { - ObservedCustomModule: StaticQuantCustomModule, - } - } - } - mp = torch.quantization.quantize_fx.prepare_fx( - m, qconfig_mapping, prepare_custom_config_dict=prepare_custom_config_dict) - # calibration (not shown) - mq = torch.quantization.quantize_fx.convert_fx( - mp, convert_custom_config_dict=convert_custom_config_dict) +Custom API Example:: + + import torch + import torch.nn.quantized as nnq + from torch.ao.quantization import QConfigMapping + import torch.ao.quantization.quantize_fx + + # original fp32 module to replace + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + + # custom observed module, provided by user + class ObservedCustomModule(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear(x) + + @classmethod + def from_float(cls, float_module): + assert hasattr(float_module, 'qconfig') + observed = cls(float_module.linear) + observed.qconfig = float_module.qconfig + return observed + + # custom quantized module, provided by user + class StaticQuantCustomModule(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear(x) + + @classmethod + def from_observed(cls, observed_module): + assert hasattr(observed_module, 'qconfig') + assert hasattr(observed_module, 'activation_post_process') + observed_module.linear.activation_post_process = \ + observed_module.activation_post_process + quantized = cls(nnq.Linear.from_float(observed_module.linear)) + return quantized + + # + # example API call (Eager mode quantization) + # + + m = torch.nn.Sequential(CustomModule()).eval() + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + CustomModule: ObservedCustomModule + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: StaticQuantCustomModule + } + } + m.qconfig = torch.ao.quantization.default_qconfig + mp = torch.ao.quantization.prepare( + m, prepare_custom_config_dict=prepare_custom_config_dict) + # calibration (not shown) + mq = torch.ao.quantization.convert( + mp, convert_custom_config_dict=convert_custom_config_dict) + # + # example API call (FX graph mode quantization) + # + m = torch.nn.Sequential(CustomModule()).eval() + qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig) + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule, + } + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: StaticQuantCustomModule, + } + } + } + mp = torch.ao.quantization.quantize_fx.prepare_fx( + m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict) + # calibration (not shown) + mq = torch.ao.quantization.quantize_fx.convert_fx( + mp, convert_custom_config=convert_custom_config_dict) Best Practices -------------- @@ -1044,8 +1142,5 @@ Please take a look at `Limitations of Symbolic Tracing `, +:ref:`CSR`, :ref:`CSC`, +:ref:`BSR`, and :ref:`BSC`. __ https://en.wikipedia.org/wiki/Sparse_matrix @@ -136,7 +139,7 @@ only: .. _sparse-hybrid-coo-docs: -Hybrid sparse COO tensors +Sparse hybrid COO tensors ------------------------- Pytorch implements an extension of sparse tensors with scalar values @@ -157,8 +160,8 @@ have: .. note:: - We use (M + K)-dimensional tensor to denote a N-dimensional hybrid - sparse tensor, where M and K are the numbers of sparse and dense + We use (M + K)-dimensional tensor to denote a N-dimensional sparse + hybrid tensor, where M and K are the numbers of sparse and dense dimensions, respectively, such that M + K == N holds. Suppose we want to create a (2 + 1)-dimensional tensor with the entry @@ -208,7 +211,7 @@ invariants: Uncoalesced sparse COO tensors ------------------------------ -PyTorch sparse COO tensor format permits *uncoalesced* sparse tensors, +PyTorch sparse COO tensor format permits sparse *uncoalesced* tensors, where there may be duplicate coordinates in the indices; in this case, the interpretation is that the value at that index is the sum of all duplicate value entries. For example, one can specify multiple values, @@ -242,7 +245,7 @@ sparse tensor with the following properties: For the most part, you shouldn't have to care whether or not a sparse tensor is coalesced or not, as most operations will work - identically given a coalesced or uncoalesced sparse tensor. + identically given a sparse coalesced or uncoalesced tensor. However, some operations can be implemented more efficiently on uncoalesced tensors, and some on coalesced tensors. @@ -340,7 +343,7 @@ When working with uncoalesced sparse COO tensors, one must take into an account the additive nature of uncoalesced data: the values of the same indices are the terms of a sum that evaluation gives the value of the corresponding tensor element. For example, the scalar -multiplication on an uncoalesced sparse tensor could be implemented by +multiplication on a sparse uncoalesced tensor could be implemented by multiplying all the uncoalesced values with the scalar because ``c * (a + b) == c * a + c * b`` holds. However, any nonlinear operation, say, a square root, cannot be implemented by applying the operation to @@ -370,49 +373,143 @@ assumption that the fill value is negative infinity. .. See https://github.com/Quansight-Labs/rfcs/tree/pearu/rfc-fill-value/RFC-0004-sparse-fill-value for a new API +.. _sparse-compressed-docs: + +Sparse Compressed Tensors ++++++++++++++++++++++++++ + +Sparse Compressed Tensors represents a class of sparse tensors that +have a common feature of compressing the indices of a certain dimension +using an encoding that enables certain optimizations on linear algebra +kernels of sparse compressed tensors. This encoding is based on the +`Compressed Sparse Row (CSR)`__ format that PyTorch sparse compressed +tensors extend with the support of sparse tensor batches, allowing +multi-dimensional tensor values, and storing sparse tensor values in +dense blocks. + +__ https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format) + +.. note:: + + We use (B + M + K)-dimensional tensor to denote a N-dimensional + sparse compressed hybrid tensor, where B, M, and K are the numbers + of batch, sparse, and dense dimensions, respectively, such that + ``B + M + K == N`` holds. The number of sparse dimensions for + sparse compressed tensors is always two, ``M == 2``. + +.. note:: + + We say that an indices tensor ``compressed_indices`` uses CSR + encoding if the following invariants are satisfied: + + - ``compressed_indices`` is a contiguous strided 32 or 64 bit + integer tensor + - ``compressed_indices`` shape is ``(*batchsize, + compressed_dim_size + 1)`` where ``compressed_dim_size`` is the + number of compressed dimensions (e.g. rows or columns) + - ``compressed_indices[..., 0] == 0`` where ``...`` denotes batch + indices + - ``compressed_indices[..., compressed_dim_size] == nse`` where + ``nse`` is the number of specified elements + - ``0 <= compressed_indices[..., i] - compressed_indices[..., i - + 1] <= plain_dim_size`` for ``i=1, ..., compressed_dim_size``, + where ``plain_dim_size`` is the number of plain dimensions + (orthogonal to compressed dimensions, e.g. columns or rows). + .. _sparse-csr-docs: Sparse CSR Tensor -+++++++++++++++++ +----------------- -The CSR (Compressed Sparse Row) sparse tensor format implements the CSR format -for storage of 2 dimensional tensors. Although there is no support for N-dimensional -tensors, the primary advantage over the COO format is better use of storage and -much faster computation operations such as sparse matrix-vector multiplication -using MKL and MAGMA backends. CUDA support does not exist as of now. +The primary advantage of the CSR format over the COO format is better +use of storage and much faster computation operations such as sparse +matrix-vector multiplication using MKL and MAGMA backends. -A CSR sparse tensor consists of three 1-D tensors: ``crow_indices``, ``col_indices`` -and ``values``: +In the simplest case, a (0 + 2 + 0)-dimensional sparse CSR tensor +consists of three 1-D tensors: ``crow_indices``, ``col_indices`` and +``values``: + + - The ``crow_indices`` tensor consists of compressed row + indices. This is a 1-D tensor of size ``nrows + 1`` (the number of + rows plus 1). The last element of ``crow_indices`` is the number + of specified elements, ``nse``. This tensor encodes the index in + ``values`` and ``col_indices`` depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given row. - - The ``crow_indices`` tensor consists of compressed row indices. This is a 1-D tensor - of size ``size[0] + 1``. The last element is the number of non-zeros. This tensor - encodes the index in ``values`` and ``col_indices`` depending on where the given row - starts. Each successive number in the tensor subtracted by the number before it denotes - the number of elements in a given row. - - The ``col_indices`` tensor contains the column indices of each value. This is a 1-D - tensor of size ``nnz``. - - The ``values`` tensor contains the values of the CSR tensor. This is a 1-D tensor - of size ``nnz``. + - The ``col_indices`` tensor contains the column indices of each + element. This is a 1-D tensor of size ``nse``. + + - The ``values`` tensor contains the values of the CSR tensor + elements. This is a 1-D tensor of size ``nse``. .. note:: - The index tensors ``crow_indices`` and ``col_indices`` should have element type either - ``torch.int64`` (default) or ``torch.int32``. If you want to use MKL-enabled matrix - operations, use ``torch.int32``. This is as a result of the default linking of pytorch - being with MKL LP64, which uses 32 bit integer indexing. + The index tensors ``crow_indices`` and ``col_indices`` should have + element type either ``torch.int64`` (default) or + ``torch.int32``. If you want to use MKL-enabled matrix operations, + use ``torch.int32``. This is as a result of the default linking of + pytorch being with MKL LP64, which uses 32 bit integer indexing. + +In the general case, the (B + 2 + K)-dimensional sparse CSR tensor +consists of two (B + 1)-dimensional index tensors ``crow_indices`` and +``col_indices``, and of (1 + K)-dimensional ``values`` tensor such +that + + - ``crow_indices.shape == (*batchsize, nrows + 1)`` + + - ``col_indices.shape == (*batchsize, nse)`` + + - ``values.shape == (nse, *densesize)`` + +while the shape of the sparse CSR tensor is ``(*batchsize, nrows, +ncols, *densesize)`` where ``len(batchsize) == B`` and +``len(densesize) == K``. + +.. note:: + + The batches of sparse CSR tensors are dependent: the number of + specified elements in all batches must be the same. This somewhat + artifical constraint allows efficient storage of the indices of + different CSR batches. + +.. note:: + + The number of sparse and dense dimensions can be acquired using + :meth:`torch.Tensor.sparse_dim` and :meth:`torch.Tensor.dense_dim` + methods. The batch dimensions can be computed from the tensor + shape: ``batchsize = tensor.shape[:-tensor.sparse_dim() - + tensor.dense_dim()]``. + +.. note:: + + The memory consumption of a sparse CSR tensor is at least + ``(nrows * 8 + (8 + * + prod(densesize)) * nse) * prod(batchsize)`` bytes (plus a constant + overhead from storing other tensor data). + + With the same example data of :ref:`the note in sparse COO format + introduction`, the memory consumption of a 10 000 + x 10 000 tensor with 100 000 non-zero 32-bit floating point numbers + is at least ``(10000 * 8 + (8 + 4 * 1) * 100 000) * 1 = 1 280 000`` + bytes when using CSR tensor layout. Notice the 1.6 and 310 fold + savings from using CSR storage format compared to using the COO and + strided formats, respectively. Construction of CSR tensors ---------------------------- +''''''''''''''''''''''''''' -Sparse CSR matrices can be directly constructed by using the :func:`torch.sparse_csr_tensor` -method. The user must supply the row and column indices and values tensors separately. -The ``size`` argument is optional and will be deduced from the the ``crow_indices`` -and ``col_indices`` if it is not present. +Sparse CSR tensors can be directly constructed by using the +:func:`torch.sparse_csr_tensor` function. The user must supply the row +and column indices and values tensors separately where the row indices +must be specified using the CSR compression encoding. The ``size`` +argument is optional and will be deduced from the ``crow_indices`` and +``col_indices`` if it is not present. >>> crow_indices = torch.tensor([0, 2, 4]) >>> col_indices = torch.tensor([0, 1, 0, 1]) >>> values = torch.tensor([1, 2, 3, 4]) - >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double) + >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64) >>> csr tensor(crow_indices=tensor([0, 2, 4]), col_indices=tensor([0, 1, 0, 1]), @@ -422,20 +519,29 @@ and ``col_indices`` if it is not present. tensor([[1., 2.], [3., 4.]], dtype=torch.float64) -CSR Tensor Operations ---------------------- +.. note:: + + The values of sparse dimensions in deduced ``size`` is computed + from the size of ``crow_indices`` and the maximal index value in + ``col_indices``. If the number of columns needs to be larger than + in the deduced ``size`` then the ``size`` argument must be + specified explicitly. -The simplest way of constructing a sparse CSR tensor from a strided or sparse COO -tensor is to use :meth:`tensor.to_sparse_csr`. Any zeros in the (strided) tensor will -be interpreted as missing values in the sparse tensor: +The simplest way of constructing a 2-D sparse CSR tensor from a +strided or sparse COO tensor is to use +:meth:`torch.Tensor.to_sparse_csr` method. Any zeros in the (strided) +tensor will be interpreted as missing values in the sparse tensor: - >>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype = torch.float64) + >>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype=torch.float64) >>> sp = a.to_sparse_csr() >>> sp tensor(crow_indices=tensor([0, 1, 3, 3]), col_indices=tensor([2, 0, 1]), values=tensor([1., 1., 2.]), size=(3, 4), nnz=3, dtype=torch.float64) +CSR Tensor Operations +''''''''''''''''''''' + The sparse matrix-vector multiplication can be performed with the :meth:`tensor.matmul` method. This is currently the only math operation supported on CSR tensors. @@ -446,6 +552,272 @@ supported on CSR tensors. [1.3180], [0.0000]], dtype=torch.float64) +.. _sparse-csc-docs: + +Sparse CSC Tensor +----------------- + +The sparse CSC (Compressed Sparse Column) tensor format implements the +CSC format for storage of 2 dimensional tensors with an extension to +supporting batches of sparse CSC tensors and values being +multi-dimensional tensors. + +.. note:: + + Sparse CSC tensor is essentially a transpose of the sparse CSR + tensor when the transposition is about swapping the sparse + dimensions. + +Similarly to :ref:`sparse CSR tensors `, a sparse CSC +tensor consists of three tensors: ``ccol_indices``, ``row_indices`` +and ``values``: + + - The ``ccol_indices`` tensor consists of compressed column + indices. This is a (B + 1)-D tensor of shape ``(*batchsize, ncols + 1)``. + The last element is the number of specified + elements, ``nse``. This tensor encodes the index in ``values`` and + ``row_indices`` depending on where the given column starts. Each + successive number in the tensor subtracted by the number before it + denotes the number of elements in a given column. + + - The ``row_indices`` tensor contains the row indices of each + element. This is a (B + 1)-D tensor of shape ``(*batchsize, nse)``. + + - The ``values`` tensor contains the values of the CSC tensor + elements. This is a (1 + K)-D tensor of shape ``(nse, *densesize)``. + +Construction of CSC tensors +''''''''''''''''''''''''''' + +Sparse CSC tensors can be directly constructed by using the +:func:`torch.sparse_csc_tensor` function. The user must supply the row +and column indices and values tensors separately where the column indices +must be specified using the CSR compression encoding. The ``size`` +argument is optional and will be deduced from the ``row_indices`` and +``ccol_indices`` tensors if it is not present. + + >>> ccol_indices = torch.tensor([0, 2, 4]) + >>> row_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([1, 2, 3, 4]) + >>> csc = torch.sparse_csc_tensor(ccol_indices, row_indices, values, dtype=torch.float64) + >>> csc + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) + >>> csc.to_dense() + tensor([[1., 3.], + [2., 4.]], dtype=torch.float64) + +.. note:: + + The sparse CSC tensor constructor function has the compressed + column indices argument before the row indices argument. + +The (0 + 2 + 0)-dimensional sparse CSC tensors can be constructed from +any two-dimensional tensor using :meth:`torch.Tensor.to_sparse_csc` +method. Any zeros in the (strided) tensor will be interpreted as +missing values in the sparse tensor: + + >>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype=torch.float64) + >>> sp = a.to_sparse_csc() + >>> sp + tensor(ccol_indices=tensor([0, 1, 2, 3, 3]), + row_indices=tensor([1, 1, 0]), + values=tensor([1., 2., 1.]), size=(3, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_csc) + +.. _sparse-bsr-docs: + +Sparse BSR Tensor +----------------- + +The sparse BSR (Block compressed Sparse Row) tensor format implements the +BSR format for storage of two-dimensional tensors with an extension to +supporting batches of sparse BSR tensors and values being blocks of +multi-dimensional tensors. + +A sparse BSR tensor consists of three tensors: ``crow_indices``, +``col_indices`` and ``values``: + + - The ``crow_indices`` tensor consists of compressed row + indices. This is a (B + 1)-D tensor of shape ``(*batchsize, + nrowblocks + 1)``. The last element is the number of specified blocks, + ``nse``. This tensor encodes the index in ``values`` and + ``col_indices`` depending on where the given column block + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of blocks in a given row. + + - The ``col_indices`` tensor contains the column block indices of each + element. This is a (B + 1)-D tensor of shape ``(*batchsize, + nse)``. + + - The ``values`` tensor contains the values of the sparse BSR tensor + elements collected into two-dimensional blocks. This is a (1 + 2 + + K)-D tensor of shape ``(nse, nrowblocks, ncolblocks, + *densesize)``. + +Construction of BSR tensors +''''''''''''''''''''''''''' + +Sparse BSR tensors can be directly constructed by using the +:func:`torch.sparse_bsr_tensor` function. The user must supply the row +and column block indices and values tensors separately where the row block indices +must be specified using the CSR compression encoding. +The ``size`` argument is optional and will be deduced from the ``crow_indices`` and +``col_indices`` tensors if it is not present. + + >>> crow_indices = torch.tensor([0, 2, 4]) + >>> col_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([[[0, 1, 2], [6, 7, 8]], + ... [[3, 4, 5], [9, 10, 11]], + ... [[12, 13, 14], [18, 19, 20]], + ... [[15, 16, 17], [21, 22, 23]]]) + >>> bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) + >>> bsr + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([[[ 0., 1., 2.], + [ 6., 7., 8.]], + [[ 3., 4., 5.], + [ 9., 10., 11.]], + [[12., 13., 14.], + [18., 19., 20.]], + [[15., 16., 17.], + [21., 22., 23.]]]), + size=(4, 6), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + >>> bsr.to_dense() + tensor([[ 0., 1., 2., 3., 4., 5.], + [ 6., 7., 8., 9., 10., 11.], + [12., 13., 14., 15., 16., 17.], + [18., 19., 20., 21., 22., 23.]], dtype=torch.float64) + +The (0 + 2 + 0)-dimensional sparse BSR tensors can be constructed from +any two-dimensional tensor using :meth:`torch.Tensor.to_sparse_bsr` +method that also requires the specification of the values block size: + + >>> dense = torch.tensor([[0, 1, 2, 3, 4, 5], + ... [6, 7, 8, 9, 10, 11], + ... [12, 13, 14, 15, 16, 17], + ... [18, 19, 20, 21, 22, 23]]) + >>> bsr = dense.to_sparse_bsr(blocksize=(2, 3)) + >>> bsr + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([[[ 0, 1, 2], + [ 6, 7, 8]], + [[ 3, 4, 5], + [ 9, 10, 11]], + [[12, 13, 14], + [18, 19, 20]], + [[15, 16, 17], + [21, 22, 23]]]), size=(4, 6), nnz=4, + layout=torch.sparse_bsr) + +.. _sparse-bsc-docs: + +Sparse BSC Tensor +----------------- + +The sparse BSC (Block compressed Sparse Column) tensor format implements the +BSC format for storage of two-dimensional tensors with an extension to +supporting batches of sparse BSC tensors and values being blocks of +multi-dimensional tensors. + +A sparse BSC tensor consists of three tensors: ``ccol_indices``, +``row_indices`` and ``values``: + + - The ``ccol_indices`` tensor consists of compressed column + indices. This is a (B + 1)-D tensor of shape ``(*batchsize, + ncolblocks + 1)``. The last element is the number of specified blocks, + ``nse``. This tensor encodes the index in ``values`` and + ``row_indices`` depending on where the given row block + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of blocks in a given column. + + - The ``row_indices`` tensor contains the row block indices of each + element. This is a (B + 1)-D tensor of shape ``(*batchsize, + nse)``. + + - The ``values`` tensor contains the values of the sparse BSC tensor + elements collected into two-dimensional blocks. This is a (1 + 2 + + K)-D tensor of shape ``(nse, nrowblocks, ncolblocks, + *densesize)``. + +Construction of BSC tensors +''''''''''''''''''''''''''' + +Sparse BSC tensors can be directly constructed by using the +:func:`torch.sparse_bsc_tensor` function. The user must supply the row +and column block indices and values tensors separately where the column block indices +must be specified using the CSR compression encoding. +The ``size`` argument is optional and will be deduced from the ``ccol_indices`` and +``row_indices`` tensors if it is not present. + + >>> ccol_indices = torch.tensor([0, 2, 4]) + >>> row_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([[[0, 1, 2], [6, 7, 8]], + ... [[3, 4, 5], [9, 10, 11]], + ... [[12, 13, 14], [18, 19, 20]], + ... [[15, 16, 17], [21, 22, 23]]]) + >>> bsc = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, dtype=torch.float64) + >>> bsc + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([[[ 0., 1., 2.], + [ 6., 7., 8.]], + [[ 3., 4., 5.], + [ 9., 10., 11.]], + [[12., 13., 14.], + [18., 19., 20.]], + [[15., 16., 17.], + [21., 22., 23.]]]), size=(4, 6), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsc) + +Tools for working with sparse compressed tensors +------------------------------------------------ + +All sparse compressed tensors --- CSR, CSC, BSR, and BSC tensors --- +are conceptionally very similar in that their indices data is split +into two parts: so-called compressed indices that use the CSR +encoding, and so-called plain indices that are orthogonal to the +compressed indices. This allows various tools on these tensors to +share the same implementations that are parameterized by tensor +layout. + +Construction of sparse compressed tensors +''''''''''''''''''''''''''''''''''''''''' + +Sparse CSR, CSC, BSR, and CSC tensors can be constructed by using +:func:`torch.sparse_compressed_tensor` function that have the same +interface as the above discussed constructor functions +:func:`torch.sparse_csr_tensor`, :func:`torch.sparse_csc_tensor`, +:func:`torch.sparse_bsr_tensor`, and :func:`torch.sparse_bsc_tensor`, +respectively, but with an extra required ``layout`` argument. The +following example illustrates a method of constructing CSR and CSC +tensors using the same input data by specifying the corresponding +layout parameter to the :func:`torch.sparse_compressed_tensor` +function: + + >>> compressed_indices = torch.tensor([0, 2, 4]) + >>> plain_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([1, 2, 3, 4]) + >>> csr = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=torch.sparse_csr) + >>> csr + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4, + layout=torch.sparse_csr) + >>> csc = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=torch.sparse_csc) + >>> csc + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4, + layout=torch.sparse_csc) + >>> (csr.transpose(0, 1).to_dense() == csc.to_dense()).all() + tensor(True) + + Supported Linear Algebra operations +++++++++++++++++++++++++++++++++++ @@ -496,16 +868,21 @@ Tensor methods and sparse The following Tensor methods are related to sparse tensors: .. autosummary:: + :toctree: generated :nosignatures: Tensor.is_sparse + Tensor.is_sparse_csr Tensor.dense_dim Tensor.sparse_dim Tensor.sparse_mask Tensor.to_sparse Tensor.to_sparse_coo Tensor.to_sparse_csr - Tensor.indices + Tensor.to_sparse_csc + Tensor.to_sparse_bsr + Tensor.to_sparse_bsc + Tensor.to_dense Tensor.values The following Tensor methods are specific to sparse COO tensors: @@ -518,16 +895,26 @@ The following Tensor methods are specific to sparse COO tensors: Tensor.sparse_resize_ Tensor.sparse_resize_and_clear_ Tensor.is_coalesced - Tensor.to_dense + Tensor.indices -The following methods are specific to :ref:`sparse CSR tensors `: +The following methods are specific to :ref:`sparse CSR tensors ` and :ref:`sparse BSR tensors `: .. autosummary:: + :toctree: generated :nosignatures: Tensor.crow_indices Tensor.col_indices +The following methods are specific to :ref:`sparse CSC tensors ` and :ref:`sparse BSC tensors `: + +.. autosummary:: + :toctree: generated + :nosignatures: + + Tensor.row_indices + Tensor.ccol_indices + The following Tensor methods support sparse COO tensors: :meth:`~torch.Tensor.add` @@ -590,6 +977,10 @@ Torch functions specific to sparse Tensors sparse_coo_tensor sparse_csr_tensor + sparse_csc_tensor + sparse_bsr_tensor + sparse_bsc_tensor + sparse_compressed_tensor sparse.sum sparse.addmm sparse.sampled_addmm @@ -599,6 +990,7 @@ Torch functions specific to sparse Tensors smm sparse.softmax sparse.log_softmax + sparse.spdiags Other functions +++++++++++++++ @@ -628,3 +1020,35 @@ The following :mod:`torch` functions support sparse tensors: :func:`~torch.vstack` :func:`~torch.zeros` :func:`~torch.zeros_like` + +In addition, all zero-preserving unary functions support sparse +COO/CSR/CSC/BSR/CSR tensor inputs: + +:func:`~torch.abs` +:func:`~torch.asin` +:func:`~torch.asinh` +:func:`~torch.atan` +:func:`~torch.atanh` +:func:`~torch.ceil` +:func:`~torch.conj_physical` +:func:`~torch.floor` +:func:`~torch.log1p` +:func:`~torch.neg` +:func:`~torch.round` +:func:`~torch.sin` +:func:`~torch.sinh` +:func:`~torch.sign` +:func:`~torch.sgn` +:func:`~torch.signbit` +:func:`~torch.tan` +:func:`~torch.tanh` +:func:`~torch.trunc` +:func:`~torch.expm1` +:func:`~torch.sqrt` +:func:`~torch.angle` +:func:`~torch.isinf` +:func:`~torch.isposinf` +:func:`~torch.isneginf` +:func:`~torch.isnan` +:func:`~torch.erf` +:func:`~torch.erfinv` diff --git a/docs/source/special.rst b/docs/source/special.rst index 42acd2148a6a9..ac1ce837ef6b5 100644 --- a/docs/source/special.rst +++ b/docs/source/special.rst @@ -12,35 +12,39 @@ The torch.special module, modeled after SciPy's `special Storage` and :class:`torch.cuda.Storage` classes, +like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc., are not +actually ever instantiated. Calling their constructors creates +a :class:`torch.TypedStorage` with the appropriate :class:`torch.dtype` and +:class:`torch.device`. :class:`torch.Storage` classes have all of the +same class methods that :class:`torch.TypedStorage` has. + +A :class:`torch.TypedStorage` is a contiguous, one-dimensional array of elements of a particular :class:`torch.dtype`. It can be given any :class:`torch.dtype`, and the internal data will be interpretted appropriately. +:class:`torch.TypedStorage` contains a :class:`torch.UntypedStorage` which +holds the data as an untyped array of bytes. -Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`, +Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`, which stores all of the data that the :class:`torch.Tensor` views. -For backward compatibility, there are also :class:`torch.Storage` classes -(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc). These -classes are not actually instantiated, and calling their constructors creates -a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`. -:class:`torch.Storage` classes have all of the same class methods that -:class:`torch._TypedStorage` has. - -Also for backward compatibility, :class:`torch.Storage` is an alias for the -storage class that corresponds with the default data type -(:func:`torch.get_default_dtype()`). For instance, if the default data type is -:attr:`torch.float`, :class:`torch.Storage` resolves to -:class:`torch.FloatStorage`. - +.. autoclass:: torch.TypedStorage + :members: + :undoc-members: + :inherited-members: -.. autoclass:: torch._TypedStorage +.. autoclass:: torch.UntypedStorage :members: :undoc-members: :inherited-members: diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index e88c382df17e7..9c4264316fd14 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -669,7 +669,12 @@ Tensor class reference Tensor.arctanh_ Tensor.tolist Tensor.topk + Tensor.to_dense Tensor.to_sparse + Tensor.to_sparse_csr + Tensor.to_sparse_csc + Tensor.to_sparse_bsr + Tensor.to_sparse_bsc Tensor.trace Tensor.transpose Tensor.transpose_ @@ -685,6 +690,7 @@ Tensor class reference Tensor.type Tensor.type_as Tensor.unbind + Tensor.unflatten Tensor.unfold Tensor.uniform_ Tensor.unique diff --git a/docs/source/testing.rst b/docs/source/testing.rst index d1a63f645dfc7..122aa651b9579 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -2,6 +2,7 @@ torch.testing ============= .. automodule:: torch.testing +.. currentmodule:: torch.testing .. autofunction:: assert_close .. autofunction:: make_tensor diff --git a/docs/source/torch.rst b/docs/source/torch.rst index e382bb63e245b..a530c5af136f2 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -533,6 +533,7 @@ Other Operations tril_indices triu triu_indices + unflatten vander view_as_real view_as_complex diff --git a/functorch/.circleci/config.yml b/functorch/.circleci/config.yml new file mode 100644 index 0000000000000..bab6defefd24c --- /dev/null +++ b/functorch/.circleci/config.yml @@ -0,0 +1,316 @@ +version: 2.1 + +executors: + windows-cpu: + machine: + resource_class: windows.xlarge + image: windows-server-2019-vs2019:stable + shell: bash.exe + + windows-gpu: + machine: + resource_class: windows.gpu.nvidia.medium + image: windows-server-2019-nvidia:stable + shell: bash.exe + +commands: + checkout_merge: + description: "checkout merge branch" + steps: + - checkout + designate_upload_channel: + description: "inserts the correct upload channel into ${BASH_ENV}" + steps: + - run: + name: adding UPLOAD_CHANNEL to BASH_ENV + command: | + our_upload_channel=nightly + # On tags upload to test instead + if [[ -n "${CIRCLE_TAG}" ]]; then + our_upload_channel=test + fi + echo "export UPLOAD_CHANNEL=${our_upload_channel}" >> ${BASH_ENV} + +binary_common: &binary_common + parameters: + # Edit these defaults to do a release` + build_version: + description: "version number of release binary; by default, build a nightly" + type: string + default: "" + pytorch_version: + description: "PyTorch version to build against; by default, use a nightly" + type: string + default: "" + # Don't edit these + python_version: + description: "Python version to build against (e.g., 3.7)" + type: string + cu_version: + description: "CUDA version to build against, in CU format (e.g., cpu or cu100)" + type: string + unicode_abi: + description: "Python 2.7 wheel only: whether or not we are cp27mu (default: no)" + type: string + default: "" + wheel_docker_image: + description: "Wheel only: what docker image to use" + type: string + default: "pytorch/manylinux-cuda101" + environment: + PYTHON_VERSION: << parameters.python_version >> + PYTORCH_VERSION: << parameters.pytorch_version >> + UNICODE_ABI: << parameters.unicode_abi >> + CU_VERSION: << parameters.cu_version >> + +jobs: + unittest_linux_cpu: + <<: *binary_common + machine: + image: "ubuntu-2004:202104-01" + resource_class: xlarge + steps: + - checkout + - run: + name: Setup + command: | + touch ${BASH_ENV} + echo "export PARAMETERS_PYTHON_VERSION=<< parameters.python_version >>" >> ${BASH_ENV} + cat ${BASH_ENV} + # For some reason circleci isn't automatically sourcing this within the builds + source ${BASH_ENV} && .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install functorch + command: | + touch ${BASH_ENV} + echo "export PARAMETERS_PYTHON_VERSION=<< parameters.python_version >>" >> ${BASH_ENV} + cat ${BASH_ENV} + # For some reason circleci isn't automatically sourcing this within the builds + source ${BASH_ENV} && .circleci/unittest/linux/scripts/install.sh + - persist_to_workspace: + root: wheels + paths: + - "*" + - store_artifacts: + path: wheels + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/linux/scripts/post_process.sh + - store_test_results: + path: test-reports + + unittest_linux_gpu: + <<: *binary_common + machine: + # https://circleci.com/docs/2.0/configuration-reference/#available-linux-gpu-images + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + steps: + - checkout + - run: + name: Setup + command: | + touch ${BASH_ENV} + echo "export PARAMETERS_PYTHON_VERSION=<< parameters.python_version >>" >> ${BASH_ENV} + cat ${BASH_ENV} + # For some reason circleci isn't automatically sourcing this within the builds + source ${BASH_ENV} && .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install functorch + command: | + touch ${BASH_ENV} + echo "export PARAMETERS_PYTHON_VERSION=<< parameters.python_version >>" >> ${BASH_ENV} + cat ${BASH_ENV} + # For some reason circleci isn't automatically sourcing this within the builds + source ${BASH_ENV} && .circleci/unittest/linux/scripts/install.sh + - persist_to_workspace: + root: wheels + paths: + - "*" + - store_artifacts: + path: wheels + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/linux/scripts/post_process.sh + - store_test_results: + path: test-reports + + unittest_macos_cpu: + <<: *binary_common + macos: + xcode: "12.0" + resource_class: large + steps: + - checkout + - run: + name: Install wget + command: HOMEBREW_NO_AUTO_UPDATE=1 brew install wget + # Disable brew auto update which is very slow + - run: + name: Setup + command: | + touch ${BASH_ENV} + echo "export PARAMETERS_PYTHON_VERSION=<< parameters.python_version >>" >> ${BASH_ENV} + cat ${BASH_ENV} + # For some reason circleci isn't automatically sourcing this within the builds + source ${BASH_ENV} && .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install functorch + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/linux/scripts/post_process.sh + - store_test_results: + path: test-results + + unittest_windows_cpu: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v2-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - save_cache: + key: env-v2-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + paths: + - conda + - env + - run: + name: Install functorch + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/windows/scripts/post_process.sh + - store_test_results: + path: test-reports + + unittest_windows_gpu: + <<: *binary_common + executor: + name: windows-gpu + environment: + CUDA_VERSION: "11.3" + PYTHON_VERSION: << parameters.python_version >> + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v2-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - save_cache: + key: env-v2-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + paths: + - conda + - env + - run: + name: Install CUDA + command: packaging/windows/internal/cuda_install.bat + - run: + name: Update CUDA driver + command: packaging/windows/internal/driver_update.bat + - run: + name: Install functorch + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/windows/scripts/post_process.sh + - store_test_results: + path: test-reports + + binary_win_wheel: + <<: *binary_common + executor: windows-cpu + steps: + - checkout_merge + - designate_upload_channel + - run: + name: Build wheel packages + command: | + set -ex + source packaging/windows/internal/vc_install_helper.sh + packaging/windows/internal/cuda_install.bat + packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + - store_test_results: + path: build_results/ + +workflows: + unittest: + jobs: + - unittest_linux_cpu: + name: unittest_linux_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.7", "3.8", "3.9", "3.10"] + cu_version: ["cpu"] + - unittest_linux_gpu: + name: unittest_linux_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.7", "3.8", "3.9", "3.10"] + cu_version: ["cu102"] + + - unittest_macos_cpu: + name: unittest_macos_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.10"] + cu_version: ["cpu"] + + - unittest_windows_cpu: + name: unittest_windows_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.9"] + cu_version: ["cpu"] + + - unittest_windows_gpu: + name: unittest_windows_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.10"] + cu_version: ["cu113"] + + - binary_win_wheel: + name: binary_win_wheel_<< matrix.cu_version >>_py<< matrix.python_version >> + matrix: + parameters: + python_version: ["3.7", "3.8", "3.9", "3.10"] + cu_version: ["cpu"] diff --git a/functorch/.circleci/unittest/linux/scripts/environment.yml b/functorch/.circleci/unittest/linux/scripts/environment.yml new file mode 100644 index 0000000000000..2b3e5d43683a6 --- /dev/null +++ b/functorch/.circleci/unittest/linux/scripts/environment.yml @@ -0,0 +1,17 @@ +channels: + - defaults +dependencies: + - numpy + - pytest + - pytest-cov + - codecov + - pip + - ca-certificates + - pyyaml + - pip: + - unittest-xml-reporting + - pillow>=4.1.1 + - scipy + - av + - networkx + - ninja diff --git a/functorch/.circleci/unittest/linux/scripts/install.sh b/functorch/.circleci/unittest/linux/scripts/install.sh new file mode 100755 index 0000000000000..c6f7272b1c6fd --- /dev/null +++ b/functorch/.circleci/unittest/linux/scripts/install.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -x +set -e + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +# if [ "${CU_VERSION:-}" == cpu ] ; then +# cudatoolkit="cpuonly" +# else +# if [[ ${#CU_VERSION} -eq 4 ]]; then +# CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" +# elif [[ ${#CU_VERSION} -eq 5 ]]; then +# CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" +# fi +# echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION" +# version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +# cudatoolkit="cudatoolkit=${version}" +# fi + +WHEELS_FOLDER=${HOME}/project/wheels +mkdir -p $WHEELS_FOLDER + +PYVSHORT=${PARAMETERS_PYTHON_VERSION:0:1}${PARAMETERS_PYTHON_VERSION:2:1} + +if [[ "$PYVSHORT" == "38" ]] ; then + PYVSHORT=cp${PYVSHORT}-cp${PYVSHORT} +else + PYVSHORT=cp${PYVSHORT}-cp${PYVSHORT}m +fi + +# if [ "${CU_VERSION:-}" == cpu ] ; then +# pip install https://download.pytorch.org/whl/nightly/cpu/torch-1.9.0.dev20210427%2Bcpu-${PYVSHORT}-linux_x86_64.whl +# pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.10.0.dev20210427%2Bcpu-${PYVSHORT}-linux_x86_64.whl +# USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER +# else +# pip install https://download.pytorch.org/whl/nightly/cu102/torch-1.9.0.dev20210427%2Bcu102-${PYVSHORT}-linux_x86_64.whl +# pip install https://download.pytorch.org/whl/nightly/cu102/torchvision-0.10.0.dev20210427-${PYVSHORT}-linux_x86_64.whl +# USE_NINJA=1 python setup.py develop bdist_wheel -d $WHEELS_FOLDER +# fi + +gcc --version + +# TODO: This should really be a part of environment.yml or the docker image. +# expecttest isn't on conda so it can't be a part of environment.yml :/ +pip install expecttest + +if [ "${CU_VERSION:-}" == cpu ] ; then + conda install -y pytorch torchvision cpuonly -c pytorch-nightly + PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" python setup.py develop bdist_wheel -d $WHEELS_FOLDER +else + conda install -y pytorch torchvision cudatoolkit=10.2 -c pytorch-nightly + PYTORCH_VERSION="$(python -c "import torch; print(torch.__version__)")" python setup.py develop bdist_wheel -d $WHEELS_FOLDER +fi diff --git a/functorch/.circleci/unittest/linux/scripts/post_process.sh b/functorch/.circleci/unittest/linux/scripts/post_process.sh new file mode 100755 index 0000000000000..a84a0dea55e08 --- /dev/null +++ b/functorch/.circleci/unittest/linux/scripts/post_process.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +codecov diff --git a/functorch/.circleci/unittest/linux/scripts/run_test.sh b/functorch/.circleci/unittest/linux/scripts/run_test.sh new file mode 100755 index 0000000000000..9d5eaa5a136f0 --- /dev/null +++ b/functorch/.circleci/unittest/linux/scripts/run_test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +export IN_CI=1 +mkdir test-reports +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +python -m torch.utils.collect_env + +# test_functorch_lagging_op_db.py: Only run this locally because it checks +# the functorch lagging op db vs PyTorch's op db. +EXIT_STATUS=0 +find test \( -name test\*.py ! -name test_functorch_lagging_op_db.py \) | xargs -I {} -n 1 python {} -v || EXIT_STATUS=$? +exit $EXIT_STATUS diff --git a/functorch/.circleci/unittest/linux/scripts/setup_env.sh b/functorch/.circleci/unittest/linux/scripts/setup_env.sh new file mode 100755 index 0000000000000..bbc1a4c24970e --- /dev/null +++ b/functorch/.circleci/unittest/linux/scripts/setup_env.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -x +set -e + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and functorch here, otherwise they also get cached. + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PARAMETERS_PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/functorch/.circleci/unittest/windows/scripts/environment.yml b/functorch/.circleci/unittest/windows/scripts/environment.yml new file mode 100644 index 0000000000000..590b1df530ec7 --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/environment.yml @@ -0,0 +1,19 @@ +channels: + - pytorch + - defaults +dependencies: + - numpy + - pytest + - pytest-cov + - codecov + - pip + - pyyaml + - ca-certificates + - pip: + - unittest-xml-reporting + - pillow>=4.1.1 + - scipy + - av + - networkx + - expecttest + - ninja diff --git a/functorch/.circleci/unittest/windows/scripts/install.sh b/functorch/.circleci/unittest/windows/scripts/install.sh new file mode 100644 index 0000000000000..d425b2b7133b7 --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -ex + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" +conda activate ./env + +# TODO, refactor the below logic to make it easy to understand how to get correct cuda_version. +if [ "${CU_VERSION:-}" == cpu ] ; then + cudatoolkit="cpuonly" + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" + cudatoolkit="cudatoolkit=${version}" +fi + +printf "Installing PyTorch with %s\n" "${cudatoolkit}" +conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c nvidia "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" + +torch_cuda=$(python -c "import torch; print(torch.cuda.is_available())") +echo torch.cuda.is_available is $torch_cuda + +if [ ! -z "${CUDA_VERSION:-}" ] ; then + if [ "$torch_cuda" == "False" ]; then + echo "torch with cuda installed but torch.cuda.is_available() is False" + exit 1 + fi +fi + +source "$this_dir/set_cuda_envs.sh" + +printf "* Installing functorch\n" +"$this_dir/vc_env_helper.bat" python setup.py develop diff --git a/functorch/.circleci/unittest/windows/scripts/install_conda.bat b/functorch/.circleci/unittest/windows/scripts/install_conda.bat new file mode 100644 index 0000000000000..6052ad08b106a --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/install_conda.bat @@ -0,0 +1 @@ +start /wait "" "%miniconda_exe%" /S /InstallationType=JustMe /RegisterPython=0 /AddToPath=0 /D=%tmp_conda% diff --git a/functorch/.circleci/unittest/windows/scripts/post_process.sh b/functorch/.circleci/unittest/windows/scripts/post_process.sh new file mode 100644 index 0000000000000..5c5cbb758a9ef --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" +conda activate ./env diff --git a/functorch/.circleci/unittest/windows/scripts/run_test.sh b/functorch/.circleci/unittest/windows/scripts/run_test.sh new file mode 100644 index 0000000000000..8435aa5c955d7 --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/run_test.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -e + +export IN_CI=1 +mkdir test-reports + +eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" +conda activate ./env + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$this_dir/set_cuda_envs.sh" + +python -m torch.utils.collect_env + +EXIT_STATUS=0 +# TODO: we should be able to acquire the following from some bash commands +# Tests currently ordered in order of runtime... +python test/test_eager_transforms.py -v || EXIT_STATUS=$? +python test/test_compile_cache.py -v || EXIT_STATUS=$? +python test/test_minifier.py -v || EXIT_STATUS=$? +python test/test_memory_efficient_fusion.py -v || EXIT_STATUS=$? +python test/test_pythonkey.py -v || EXIT_STATUS=$? +python test/test_vmap.py -v || EXIT_STATUS=$? +python test/test_ops.py -v || EXIT_STATUS=$? +exit $EXIT_STATUS diff --git a/functorch/.circleci/unittest/windows/scripts/set_cuda_envs.sh b/functorch/.circleci/unittest/windows/scripts/set_cuda_envs.sh new file mode 100644 index 0000000000000..7db3137b59440 --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/set_cuda_envs.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +set -ex + +echo CU_VERSION is "${CU_VERSION}" +echo CUDA_VERSION is "${CUDA_VERSION}" + +# Currenly, CU_VERSION and CUDA_VERSION are not consistent. +# to understand this code, see https://github.com/pytorch/vision/issues/4443 +version="cpu" +if [[ ! -z "${CUDA_VERSION}" ]] ; then + version="$CUDA_VERSION" +else + if [[ ${#CU_VERSION} -eq 5 ]]; then + version="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi +fi + +# Don't use if [[ "$version" == "cpu" ]]; then exit 0 fi. +# It would exit the shell. One result is cpu tests would not run if the shell exit. +# Unless there's an error, Don't exit. +if [[ "$version" != "cpu" ]]; then + # set cuda envs + export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${version}/bin:/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${version}/libnvvp:$PATH" + export CUDA_PATH_V${version/./_}="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${version}" + export CUDA_PATH="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${version}" + + if [ ! -d "$CUDA_PATH" ]; then + echo "$CUDA_PATH" does not exist + exit 1 + fi + + if [ ! -f "${CUDA_PATH}\include\nvjpeg.h" ]; then + echo "nvjpeg does not exist" + exit 1 + fi + + # check cuda driver version + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do + if [[ -x "$path" ]]; then + "$path" || echo "true"; + break + fi + done + + which nvcc + nvcc --version + env | grep CUDA +fi diff --git a/functorch/.circleci/unittest/windows/scripts/setup_env.sh b/functorch/.circleci/unittest/windows/scripts/setup_env.sh new file mode 100644 index 0000000000000..b0b7063111204 --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/setup_env.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + export tmp_conda="$(echo $conda_dir | tr '/' '\\')" + export miniconda_exe="$(echo $root_dir | tr '/' '\\')\\miniconda.exe" + curl --output miniconda.exe https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe -O + "$this_dir/install_conda.bat" + unset tmp_conda + unset miniconda_exe +fi + +eval "$(${conda_dir}/Scripts/conda.exe 'shell.bash' 'hook')" + +# 2. Create test environment at ./env +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/functorch/.circleci/unittest/windows/scripts/vc_env_helper.bat b/functorch/.circleci/unittest/windows/scripts/vc_env_helper.bat new file mode 100644 index 0000000000000..9410135677a4f --- /dev/null +++ b/functorch/.circleci/unittest/windows/scripts/vc_env_helper.bat @@ -0,0 +1,39 @@ +@echo on + +set VC_VERSION_LOWER=16 +set VC_VERSION_UPPER=17 + +for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( + if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( + set "VS15INSTALLDIR=%%i" + set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" + goto vswhere + ) +) + +:vswhere +if "%VSDEVCMD_ARGS%" == "" ( + call "%VS15VCVARSALL%" x64 || exit /b 1 +) else ( + call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 +) + +@echo on + +set DISTUTILS_USE_SDK=1 + +set args=%1 +shift +:start +if [%1] == [] goto done +set args=%args% %1 +shift +goto start + +:done +if "%args%" == "" ( + echo Usage: vc_env_helper.bat [command] [args] + echo e.g. vc_env_helper.bat cl /c test.cpp +) + +%args% || exit /b 1 diff --git a/functorch/.flake8 b/functorch/.flake8 new file mode 100644 index 0000000000000..a6d73773e3b55 --- /dev/null +++ b/functorch/.flake8 @@ -0,0 +1,20 @@ +[flake8] +select = B,C,E,F,P,T4,W,B9 +max-line-length = 120 +# C408 ignored because we like the dict keyword argument syntax +# E501 is not flexible enough, we're using B950 instead +ignore = + E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, + # shebang has extra meaning in fbcode lints, so I think it's not worth trying + # to line this up with executable bit + EXE001, + # these ignores are from flake8-bugbear; please fix! + B007,B008, + # these ignores are from flake8-comprehensions; please fix! + C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 +exclude = + ./.git, + ./benchmarks, + ./docs, + ./examples, + ./notebooks diff --git a/functorch/.github/workflows/docs.yml b/functorch/.github/workflows/docs.yml new file mode 100644 index 0000000000000..017d9949ff7b6 --- /dev/null +++ b/functorch/.github/workflows/docs.yml @@ -0,0 +1,82 @@ +name: Build and Deploy Docs +on: + pull_request: + types: [opened, synchronize, reopened] + push: + branches: + - main + +jobs: + + build-docs: + runs-on: ubuntu-18.04 + steps: + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: "3.9" + architecture: x64 + - name: Checkout functorch + uses: actions/checkout@v2 + - name: Install PyTorch Nightly + run: | + python3 -mpip install --pre torch>=1.12.0.dev -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Install functorch + run: | + python3 setup.py install + - name: Install docs requirements + run: | + cd docs + python3 -mpip install -r requirements.txt + - name: Build docs + run: | + cd docs + make html + - name: Upload docs as GHA artifact + uses: actions/upload-artifact@v2 + with: + name: built-docs + path: docs/build/html + + deploy-docs: + needs: [build-docs] + runs-on: ubuntu-latest + if: (github.ref == 'refs/heads/main' && github.event_name == 'push') + steps: + - uses: actions/checkout@v2 + with: + ref: gh-pages + fetch-depth: 3 + + - name: Download docs artifact + uses: actions/download-artifact@v2 + with: + name: built-docs + path: /tmp/docs + + - name: Copy built docs to nightly + id: copy-docs + run: | + cp -R /tmp/docs/* nightly/ + git log -3 + # Set commit name and hash as variables: commit_name, commit_hash + echo "::set-output name=commit_name::$(git log -1 --format='%s')" + echo "::set-output name=commit_hash::$(git log -1 --format='%h')" + + - name: Git reset to commit/amend + if: ${{ steps.copy-docs.outputs.commit_name == 'auto-generated commit' }} + run: | + # if commit_name is "auto-generated commit" + # then go back in commit history to commit to the same commit + git reset --soft ${{ steps.copy-docs.outputs.commit_hash }}~1 + git log -3 + + - name: Commit and push to gh-pages + uses: github-actions-x/commit@v2.9 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + push-branch: 'gh-pages' + commit-message: 'auto-generated commit' + force-push: 'true' + name: gha + email: gha@email.org diff --git a/functorch/.github/workflows/lint.yml b/functorch/.github/workflows/lint.yml new file mode 100644 index 0000000000000..26e752e2c9b71 --- /dev/null +++ b/functorch/.github/workflows/lint.yml @@ -0,0 +1,63 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + lintrunner: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + fetch-depth: 0 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + architecture: x64 + + - name: Install lintrunner + run: pip install lintrunner==0.8.* + + - name: Initialize lint dependencies + run: lintrunner init + + - name: Run lintrunner on all files + if: github.event_name == 'push' + run: lintrunner -vv --paths-cmd='git grep -Il .' --force-color + + - name: Run lintrunner on PR files + if: github.event_name == 'pull_request' + env: + PR_BASE_SHA: ${{ github.event.pull_request.base.sha }} + run: | + set +e + if ! lintrunner -vv --force-color --merge-base-with "${PR_BASE_SHA}" ; then + echo "" + echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m" + echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" + exit 1 + fi + + - name: Store annotations + if: always() && github.event_name == 'pull_request' + # Don't show this as an error; the above step will have already failed. + continue-on-error: true + env: + PR_BASE_SHA: ${{ github.event.pull_request.base.sha }} + run: | + # The easiest way to get annotations is to just run lintrunner again + # in JSON mode and use jq to massage the output into GitHub Actions + # workflow commands. + lintrunner --merge-base-with "${PR_BASE_SHA}" --output=json | \ + jq --raw-output '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' + + +concurrency: + group: lint-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true diff --git a/functorch/.github/workflows/wheels.yml b/functorch/.github/workflows/wheels.yml new file mode 100644 index 0000000000000..3b1ac1b94b868 --- /dev/null +++ b/functorch/.github/workflows/wheels.yml @@ -0,0 +1,61 @@ +name: Wheels +on: + pull_request: + types: [opened, synchronize, reopened] + push: + branches: + - main + +jobs: + + build-wheel-linux: + runs-on: ubuntu-18.04 + container: pytorch/manylinux-cpu + strategy: + matrix: + python_abi: [ "cp37-cp37m", "cp38-cp38", "cp39-cp39" ] + steps: + - name: Checkout functorch + uses: actions/checkout@v2 + - name: Install PyTorch Nightly + run: | + export PATH="/opt/python/${{ matrix.python_abi }}/bin:$PATH" + python3 -mpip install --pre torch>=1.12.0.dev -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Build wheel + run: | + export PATH="/opt/python/${{ matrix.python_abi }}/bin:$PATH" + python3 -mpip install wheel + python3 setup.py bdist_wheel + # NB: wheels have the linux_x86_64 prefix, need to be manually renamed + - name: Upload wheel as GHA artifact + uses: actions/upload-artifact@v2 + with: + name: functorch-linux.whl + path: dist/*.whl + + build-wheel-mac: + runs-on: macos-latest + strategy: + matrix: + python_version: [ "3.7", "3.8", "3.9" ] + steps: + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version }} + architecture: x64 + - name: Checkout functorch + uses: actions/checkout@v2 + - name: Install PyTorch Nightly + run: | + python3 -mpip install --pre torch>=1.12.0.dev -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Build wheel + run: | + export CC=clang CXX=clang++ + python3 -mpip install wheel + python3 setup.py bdist_wheel + - name: Upload wheel as GHA artifact + uses: actions/upload-artifact@v2 + with: + name: functorch-mac.whl + path: dist/*.whl diff --git a/functorch/.gitignore b/functorch/.gitignore new file mode 100644 index 0000000000000..145ab7d608390 --- /dev/null +++ b/functorch/.gitignore @@ -0,0 +1,21 @@ +build/ +dist/ +functorch.egg-info/ +*__pycache__* +functorch/version.py +functorch/_C.so +.gdbinit +t.py +.vscode/ +ccache.sh +docs/build +docs/src +docs/source/generated +.DS_Store +op_analysis/*.txt + +# Editor temporaries +*.swn +*.swo +*.swp +*.swm diff --git a/functorch/.lintrunner.toml b/functorch/.lintrunner.toml new file mode 100644 index 0000000000000..6e0d756b53bf0 --- /dev/null +++ b/functorch/.lintrunner.toml @@ -0,0 +1,48 @@ +[[linter]] +code = 'FLAKE8' +include_patterns = ['**/*.py'] +exclude_patterns = [ + '.git/**', + 'benchmarks/**', + 'docs/**', + 'examples/**', + 'notebooks/**', +] +command = [ + 'python3', + 'tools/lint/flake8_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/lint/pip_init.py', + '--dry-run={{DRYRUN}}', + 'flake8==3.8.2', + 'flake8-bugbear==20.1.4', + 'flake8-comprehensions==3.3.0', + 'flake8-executable==2.0.4', + 'flake8-pyi==20.5.0', + 'mccabe==0.6.1', + 'pycodestyle==2.6.0', + 'pyflakes==2.2.0', +] + +# [[linter]] +# code = 'BLACK' +# include_patterns = [ +# '**/*.py', +# ] +# command = [ +# 'python3', +# 'tools/lint/black_linter.py', +# '--', +# '@{{PATHSFILE}}' +# ] +# init_command = [ +# 'python3', +# 'tools/lint/pip_init.py', +# '--dry-run={{DRYRUN}}', +# 'black==22.3.0', +# ] +# is_formatter = true diff --git a/functorch/CODE_OF_CONDUCT.md b/functorch/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000..b91e23b17c023 --- /dev/null +++ b/functorch/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/functorch/COMPILE_README.md b/functorch/COMPILE_README.md new file mode 100644 index 0000000000000..964cda6fbec0e --- /dev/null +++ b/functorch/COMPILE_README.md @@ -0,0 +1,75 @@ +# AOT Autograd - Introduction to an experimental compilation feature in Functorch + +The primary compilation API we provide is something called AOTAutograd. AOT +Autograd is an experimental feature that allows ahead of time capture of forward +and backward graphs, and allows easy integration with compilers. This creates an +easy to hack Python-based development environment to speedup training of PyTorch +models. AOT Autograd currently lives inside functorch.compile namespace. + +AOT Autograd is experimental and the APIs are likely to change. We are looking +for feedback. If you are interested in using AOT Autograd and need help or have +suggestions, please feel free to open an issue. We will be happy to help. + +For example, here are some examples of how to use it. +```python +from functorch.compile import aot_function, aot_module, draw_graph +import torch.fx as fx +import torch + +# This simply prints out the FX graph of the forwards and the backwards +def print_graph(name): + def f(fx_g: fx.GraphModule, inps): + print(name) + print(fx_g.code) + return fx_g + return f + +def f(x): + return x.cos().cos() + +nf = aot_function(f, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward")) +nf(torch.randn(3, requires_grad=True)) + +# You can do whatever you want before and after, and you can still backprop through the function. +inp = torch.randn(3, requires_grad=True) +inp = inp.cos() +out = nf(inp) +out = out.sin().sum().backward() + +def f(x): + return x.cos().cos() + +# This draws out the forwards and the backwards graphs as svg files +def graph_drawer(name): + def f(fx_g: fx.GraphModule, inps): + draw_graph(fx_g, name) + return fx_g + return f + +aot_function(f, fw_compiler=graph_drawer("forward"), bw_compiler=graph_drawer("backward"))(torch.randn(3, requires_grad=True)) + +# We also have a convenience API for applying AOTAutograd to modules +from torchvision.models import resnet18 +aot_module(resnet18(), print_graph("forward"), print_graph("backward"))(torch.randn(1,3,200,200)) +# output elided since it's very long + +# In practice, you might want to speed it up by sending it to Torchscript. You might also lower it to Torchscript before passing it to another compiler + +def f(x): + return x.cos().cos() + +def ts_compiler(fx_g: fx.GraphModule, inps): + f = torch.jit.script(fx_g) + print(f.graph) + f = torch.jit.freeze(f.eval()) # Note: This eval() works fine *even* though we're using this for training + return f + +aot_function(f, ts_compiler, ts_compiler)(torch.randn(3, requires_grad=True)) +``` + +## Documentation +* AOT Autograd [documentation](https://pytorch.org/functorch/nightly/) +* Min-cut [recomputation](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) with AOT Autograd. + +## Tutorials +You can use this [tutorial](https://pytorch.org/functorch/nightly/notebooks/aot_autograd_optimizations.html) to play with AOT Autograd. diff --git a/functorch/CONTRIBUTING.md b/functorch/CONTRIBUTING.md new file mode 100644 index 0000000000000..effa6ed2e28cf --- /dev/null +++ b/functorch/CONTRIBUTING.md @@ -0,0 +1,12 @@ +## Contributing +Feedback on our APIs, as well as finding bugs, would be very helpful. + +Please feel free to chat us up on the PyTorch Slack, or open an issue +at https://github.com/pytorch/functorch if you're interested in +contributing. + +To contribute a change to functorch, please make sure you are submitting a +Pull Request to the functorch folder in https://github.com/pytorch/pytorch +repository. The source of truth for functorch has moved there from +https://github.com/pytorch/functorch ; the code in the pytorch/functorch +repository is read-only. diff --git a/functorch/LICENSE b/functorch/LICENSE new file mode 100644 index 0000000000000..22f4f8f28d49c --- /dev/null +++ b/functorch/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2021 Facebook, Inc. and its affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/functorch/README.md b/functorch/README.md new file mode 100644 index 0000000000000..33fa06e6491fc --- /dev/null +++ b/functorch/README.md @@ -0,0 +1,396 @@ +# functorch + +[**Why functorch?**](#why-composable-function-transforms) +| [**Install guide**](#install) +| [**Transformations**](#what-are-the-transforms) +| [**Documentation**](#documentation) +| [**Future Plans**](#future-plans) + +**This library is currently under heavy development - if you have suggestions +on the API or use-cases you'd like to be covered, please open an github issue +or reach out. We'd love to hear about how you're using the library.** + +`functorch` is [JAX-like](https://github.com/google/jax) composable function +transforms for PyTorch. + +It aims to provide composable `vmap` and `grad` transforms that work with +PyTorch modules and PyTorch autograd with good eager-mode performance. + +In addition, there is experimental functionality to trace through these +transformations using FX in order to capture the results of these transforms +ahead of time. This would allow us to compile the results of vmap or grad +to improve performance. + +## Why composable function transforms? + +There are a number of use cases that are tricky to do in +PyTorch today: +- computing per-sample-gradients (or other per-sample quantities) +- running ensembles of models on a single machine +- efficiently batching together tasks in the inner-loop of MAML +- efficiently computing Jacobians and Hessians +- efficiently computing batched Jacobians and Hessians + +Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above +without designing a separate subsystem for each. This idea of composable function +transforms comes from the [JAX framework](https://github.com/google/jax). + +## Install + +There are two ways to install functorch: +1. functorch from source +2. functorch beta (compatible with PyTorch 1.11) + +We recommend trying out the functorch beta first. + +### Installing functorch from source + +
Click to expand +

+ +#### Using Colab + +Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing) + +#### Locally + +First, set up an environment. We will be installing a nightly PyTorch binary +as well as functorch. If you're using conda, create a conda environment: +```bash +conda create --name functorch +conda activate functorch +``` +If you wish to use `venv` instead: +```bash +python -m venv functorch-env +source functorch-env/bin/activate +``` + +Next, install one of the following following PyTorch nightly binaries. +```bash +# For CUDA 10.2 +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html --upgrade +# For CUDA 11.3 +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html --upgrade +# For CPU-only build +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade +``` +If you already have a nightly of PyTorch installed and wanted to upgrade it +(recommended!), append `--upgrade` to one of those commands. + +Install functorch: +```bash +pip install ninja # Makes the build go faster +pip install --user "git+https://github.com/pytorch/functorch.git" +``` + +Run a quick sanity check in python: +```py +import torch +from functorch import vmap +x = torch.randn(3) +y = vmap(torch.sin)(x) +assert torch.allclose(y, x.sin()) +``` + +#### functorch development setup + +`functorch` is a PyTorch C++ Extension module. To install, + +- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source). +`functorch` usually runs on the latest development version of PyTorch. +- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode. + +Then, try to run some tests to make sure all is OK: +```bash +pytest test/test_vmap.py -v +pytest test/test_eager_transforms.py -v +``` + +To do devel install: + +```bash +pip install -e . +``` + +To install with optional dependencies, e.g. for AOTAutograd: + +```bash +pip install -e .[aot] +``` + +To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`). + + +

+
+ +### Installing functorch beta (compatible with PyTorch 1.11) + +
Click to expand +

+ +#### Using Colab + +Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA) + +#### pip + +Prerequisite: [Install PyTorch 1.11](https://pytorch.org/get-started/locally/) + + +```bash +pip install functorch +``` + +Finally, run a quick sanity check in python: +```py +import torch +from functorch import vmap +x = torch.randn(3) +y = vmap(torch.sin)(x) +assert torch.allclose(y, x.sin()) +``` + +

+
+ +## What are the transforms? + +Right now, we support the following transforms: +- `grad`, `vjp`, `jvp`, +- `jacrev`, `jacfwd`, `hessian` +- `vmap` + +Furthermore, we have some utilities for working with PyTorch modules. +- `make_functional(model)` +- `make_functional_with_buffers(model)` + +### vmap + +Note: `vmap` imposes restrictions on the code that it can be used on. +For more details, please read its docstring. + +`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor +operations in `func`. `vmap(func)` returns a new function that maps `func` over +some dimension (default: 0) of each Tensor in `inputs`. + +`vmap` is useful for hiding batch dimensions: one can write a function `func` +that runs on examples and then lift it to a function that can take batches of +examples with `vmap(func)`, leading to a simpler modeling experience: + +```py +from functorch import vmap +batch_size, feature_size = 3, 5 +weights = torch.randn(feature_size, requires_grad=True) + +def model(feature_vec): + # Very simple linear model with activation + assert feature_vec.dim() == 1 + return feature_vec.dot(weights).relu() + +examples = torch.randn(batch_size, feature_size) +result = vmap(model)(examples) +``` + +### grad + +`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute +the gradients of the output of func w.r.t. to `inputs[0]`. + +```py +from functorch import grad +x = torch.randn([]) +cos_x = grad(lambda x: torch.sin(x))(x) +assert torch.allclose(cos_x, x.cos()) + +# Second-order gradients +neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) +assert torch.allclose(neg_sin_x, -x.sin()) +``` + +When composed with `vmap`, `grad` can be used to compute per-sample-gradients: +```py +from functorch import vmap +batch_size, feature_size = 3, 5 + +def model(weights,feature_vec): + # Very simple linear model with activation + assert feature_vec.dim() == 1 + return feature_vec.dot(weights).relu() + +def compute_loss(weights, example, target): + y = model(weights, example) + return ((y - target) ** 2).mean() # MSELoss + +weights = torch.randn(feature_size, requires_grad=True) +examples = torch.randn(batch_size, feature_size) +targets = torch.randn(batch_size) +inputs = (weights,examples, targets) +grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) +``` + +### vjp + +The `vjp` transform applies `func` to `inputs` and returns a new function that +computes vjps given some `cotangents` Tensors. +```py +from functorch import vjp +outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents) +``` + +### jvp + +The `jvp` transforms computes Jacobian-vector-products and is also known as +"forward-mode AD". It is not a higher-order function unlike most other transforms, +but it returns the outputs of `func(inputs)` as well as the `jvp`s. +```py +from functorch import jvp +x = torch.randn(5) +y = torch.randn(5) +f = lambda x, y: (x * y) +_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) +assert torch.allclose(output, x + y) +``` + +### jacrev, jacfwd, and hessian + +The `jacrev` transform returns a new function that takes in `x` and returns the +Jacobian of `torch.sin` with respect to `x` using reverse-mode AD. +```py +from functorch import jacrev +x = torch.randn(5) +jacobian = jacrev(torch.sin)(x) +expected = torch.diag(torch.cos(x)) +assert torch.allclose(jacobian, expected) +``` +Use `jacrev` to compute the jacobian. This can be composed with vmap to produce +batched jacobians: + +```py +x = torch.randn(64, 5) +jacobian = vmap(jacrev(torch.sin))(x) +assert jacobian.shape == (64, 5, 5) +``` + +`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using +forward-mode AD: +```py +from functorch import jacfwd +x = torch.randn(5) +jacobian = jacfwd(torch.sin)(x) +expected = torch.diag(torch.cos(x)) +assert torch.allclose(jacobian, expected) +``` + +Composing `jacrev` with itself or `jacfwd` can produce hessians: +```py +def f(x): + return x.sin().sum() + +x = torch.randn(5) +hessian0 = jacrev(jacrev(f))(x) +hessian1 = jacfwd(jacrev(f))(x) +``` + +The `hessian` is a convenience function that combines `jacfwd` and `jacrev`: +```py +from functorch import hessian + +def f(x): + return x.sin().sum() + +x = torch.randn(5) +hess = hessian(f)(x) +``` + +### Tracing through the transformations +We can also trace through these transformations in order to capture the results as new code using `make_fx`. There is also experimental integration with the NNC compiler (only works on CPU for now!). + +```py +from functorch import make_fx, grad +def f(x): + return torch.sin(x).sum() +x = torch.randn(100) +grad_f = make_fx(grad(f))(x) +print(grad_f.code) + +def forward(self, x_1): + sin = torch.ops.aten.sin(x_1) + sum_1 = torch.ops.aten.sum(sin, None); sin = None + cos = torch.ops.aten.cos(x_1); x_1 = None + _tensor_constant0 = self._tensor_constant0 + mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None + return mul +``` + +### Working with NN modules: make_functional and friends + +Sometimes you may want to perform a transform with respect to the parameters +and/or buffers of an nn.Module. This can happen for example in: +- model ensembling, where all of your weights and buffers have an additional +dimension +- per-sample-gradient computation where you want to compute per-sample-grads +of the loss with respect to the model parameters + +Our solution to this right now is an API that, given an nn.Module, creates a +stateless version of it that can be called like a function. + +- `make_functional(model)` returns a functional version of `model` and the +`model.parameters()` +- `make_functional_with_buffers(model)` returns a functional version of +`model` and the `model.parameters()` and `model.buffers()`. + +Here's an example where we compute per-sample-gradients using an nn.Linear +layer: + +```py +import torch +from functorch import make_functional, vmap, grad + +model = torch.nn.Linear(3, 3) +data = torch.randn(64, 3) +targets = torch.randn(64, 3) + +func_model, params = make_functional(model) + +def compute_loss(params, data, targets): + preds = func_model(params, data) + return torch.mean((preds - targets) ** 2) + +per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets) +``` + +If you're making an ensemble of models, you may find +`combine_state_for_ensemble` useful. + +## Documentation + +For more documentation, see [our docs website](https://pytorch.org/functorch). + +## Debugging +`functorch._C.dump_tensor`: Dumps dispatch keys on stack +`functorch._C._set_vmap_fallback_warning_enabled(False)` if the vmap warning spam bothers you. + +## Future Plans + +In the end state, we'd like to upstream this into PyTorch once we iron out the +design details. To figure out the details, we need your help -- please send us +your use cases by starting a conversation in the issue tracker or trying our +project out. + +## License +Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file. + +## Citing functorch + +If you use functorch in your publication, please cite it by using the following BibTeX entry. + +```bibtex +@Misc{functorch2021, + author = {Horace He, Richard Zou}, + title = {functorch: JAX-like composable function transforms for PyTorch}, + howpublished = {\url{https://github.com/pytorch/functorch}}, + year = {2021} +} +``` diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py new file mode 100755 index 0000000000000..54d2bf1447fb1 --- /dev/null +++ b/functorch/benchmarks/chrome_trace_parser.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +import argparse + +import os +import logging +import pandas as pd + +from functorch._src.benchmark_utils import compute_utilization + +# process the chrome traces output by the pytorch profiler +# require the json input file's name to be in format {model_name}_chrome_trace_*.json +# the runtimes file should have format (model_name, runtime) + + +def get_model_name(filename): + """ + Get model name from a file in format {model_name}_chrome_trace_*.json + """ + _, tail = os.path.split(filename) + modelname = tail[:tail.find("_chrome_trace")] + return modelname + +def get_total_length(run_times_df, modelname): + return float(run_times_df[run_times_df["name"] == modelname]["runtime"]) + + +def main(): + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group(required=True) + parser.add_argument( + "--runtime", "-runf", help="file name of the runtime file", required=True + ) + group.add_argument( + "--filename", "-f", action="append", help="a filename of the json file to process" + ) + group.add_argument( + "--folder", "-fd", help="a folder of the json files to process" + ) + args = parser.parse_args() + + + if args.filename: + filenames = args.filename + elif args.folder: + filenames = [] + directory = args.folder + for filename in os.listdir(directory): + f = os.path.join(directory, filename) + if os.path.isfile(f) and f.endswith(".json"): + filenames.append(f) + else: + print("Please provide a filename or a folder name") + + print("modelname, GPU Utilization, MM and Conv time") + + run_times_df = pd.read_csv(args.runtime) + for filename in filenames: + try: + modelname = get_model_name(filename) + total_length = get_total_length(run_times_df, modelname) * 1e6 + utilization, mm_conv_utilization = compute_utilization(filenames, total_length) + print(f"{modelname}, {utilization}, {mm_conv_utilization}") + except BaseException: + logging.exception(f"{filename}, ERROR") + print(f"{filename}, ERROR") + +if __name__ == "__main__": + main() diff --git a/functorch/benchmarks/cse.py b/functorch/benchmarks/cse.py new file mode 100644 index 0000000000000..028677d6ee259 --- /dev/null +++ b/functorch/benchmarks/cse.py @@ -0,0 +1,103 @@ +import torch +import torch.fx as fx +from functorch import make_fx +from torch.profiler import profile, ProfilerActivity + +from functorch._src.compile_utils import fx_graph_cse + +def profile_it(f, inp): + for _ in range(5): + f(inp) + + itr = 5 + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + for _ in range(itr): + f(inp) + + timing = prof.key_averages() + cuda_time_total = 0 + for e in timing: + cuda_time_total = cuda_time_total + e.cuda_time_total + return cuda_time_total / itr + +def profile_function(name, f, inp): + fx_g = make_fx(f)(inp) + + new_g = fx_graph_cse(fx_g.graph) + new_g = fx.GraphModule(fx_g, new_g) + # do not benchmark against the scripted version because script already does some CSE + # script_f = torch.jit.script(fx_g) + # script_g = torch.jit.script(new_g) + # avg_cuda_time_f = profile_it(script_f, inp) + # avg_cuda_time_g = profile_it(script_g, inp) + avg_cuda_time_f = profile_it(fx_g, inp) + avg_cuda_time_g = profile_it(new_g, inp) + num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes) + + print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}") + +g_gpu = torch.Generator(device='cuda') +g_gpu.manual_seed(2147483647) +inp = torch.randn(2**20, device='cuda', generator=g_gpu) + +def f1(x): + return x.cos().cos() + +profile_function("f1", f1, inp) + +def fsum(x): + a = x.sum() + b = x.sum() + c = x.sum() + d = x.sum() + return a + b + c + d + +profile_function("fsum", fsum, inp) + +def fconcat(x): + a = torch.cat((x, x)) + b = torch.cat((x, x)) + return a + b +profile_function("fconcat", fconcat, inp) + +def fsum2(x): + a = x.sum() + for _ in range(30): + a = a + x.sum() + return a + +profile_function("fsum2", fsum2, inp) + +def fsummulti(x): + a = 0 + for _ in range(3): + a = a + x.sum() + a = a * x.sum() + return a + +profile_function("fsummulti", fsummulti, inp) + +def fsummulti2(x): + a = 0 + for _ in range(30): + a = a + x.sum() + a = a * x.sum() + return a + +profile_function("fsummulti2", fsummulti2, inp) + +def fcos(x): + a = 0 + for _ in range(3): + a = a + x.cos() + return a + +profile_function("fcos", fcos, inp) + +def fcos2(x): + a = 0 + for _ in range(30): + a = a + x.cos() + return a + +profile_function("fcos2", fcos2, inp) diff --git a/functorch/benchmarks/operator_authoring.py b/functorch/benchmarks/operator_authoring.py new file mode 100644 index 0000000000000..88e558bdafc1a --- /dev/null +++ b/functorch/benchmarks/operator_authoring.py @@ -0,0 +1,260 @@ +from functools import partial +import numpy as np +import pandas as pd +import timeit +import torch +from functorch.compile import pointwise_operator + +WRITE_CSV = False +CUDA = False +SIZES = [1, 512, 8192] +NUMBER = [100, 10, 1, 1] +REPEAT = 20 + + +@pointwise_operator +def nnc_add(a, b): + return a + b + + +@pointwise_operator +def nnc_addnorm(a, b, mean, std): + return (a + b - mean) / std + + +def eager_addnorm(a, b, mean, std): + return (a + b - mean) / std + + +def inplace_addnorm(a, b, mean, std, out): + out = torch.add(a, b, out=out) + torch.sub(out, mean, out=out) + torch.div(out, std, out=out) + return out + + +ts_addnorm = torch.jit.script(eager_addnorm) +ts_ip_addnorm = torch.jit.script(inplace_addnorm) + + +def maybe_synced(fn): + if CUDA: + synchronize = torch.cuda.synchronize + synchronize() # warmup + + def _fn(): + result = fn() + synchronize() + return result + + return _fn + return fn + + +def benchmark_loop(setup): + result = np.zeros((REPEAT, len(SIZES), 2), dtype=np.float64) + for s, n in enumerate(SIZES): + nnc, aten = setup(n) + nnc = maybe_synced(nnc) + aten = maybe_synced(aten) + + for r in range(result.shape[0]): + result[r, s, 0] = timeit.timeit(nnc, number=NUMBER[s]) + result[r, s, 1] = timeit.timeit(aten, number=NUMBER[s]) + + result = np.median(result, axis=0) + assert result.shape == (len(SIZES), 2) + result = result[:, 1] / result[:, 0] + print(result) + return result + + +def test(make_args, nnc=nnc_add, aten=torch.add): + def setup(n): + args = make_args(n) + result_aten = aten(*args) + result_nnc = nnc(*args) + assert result_nnc.dtype == result_aten.dtype + assert result_nnc.size() == result_aten.size() + assert result_nnc.stride() == result_aten.stride() + torch.testing.assert_allclose(result_aten, result_nnc) + return (lambda: nnc(*args), lambda: aten(*args)) + + return benchmark_loop(setup) + + +def test_inplace(make_args, nnc=nnc_add, aten=torch.add): + def inplace_setup(n): + a, b = make_args(n) + result_aten = torch.clone(a) + result_nnc = torch.clone(a) + nnc(result_nnc, b, out=result_nnc) + aten(result_aten, b, out=result_aten) + torch.testing.assert_allclose(result_aten, result_nnc) + return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a)) + + return benchmark_loop(inplace_setup) + + +def test_out(make_args, out, nnc=nnc_add, aten=torch.add): + def out_setup(n): + args = make_args(n) + result_aten = out(n) + result_nnc = out(n) + aten(*args, out=result_aten) + nnc(*args, out=result_nnc) + torch.testing.assert_allclose(result_aten, result_nnc) + result = out(n) + return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result)) + + return benchmark_loop(out_setup) + + +def test_backwards(make_args, nnc=nnc_add, aten=torch.add): + def backwards_setup(n): + args = make_args(n) + (grad_var,) = [a for a in args if a.requires_grad] + aten(*args).sum().backward() + correct = grad_var.grad.clone() + grad_var.grad.zero_() + nnc(*args).sum().backward() + torch.testing.assert_allclose(correct, grad_var.grad) + return ( + lambda: nnc(*args).sum().backward(), + lambda: aten(*args).sum().backward(), + ) + + return benchmark_loop(backwards_setup) + + +def main(): + torch.set_num_threads(1) # TODO(jansel): add parallel support + torch._C._jit_override_can_fuse_on_cpu(True) + + device = "cuda" if CUDA else "cpu" + I = partial(torch.randint, 0, 100, device=device) + R = partial(torch.randn, device=device) + + results = [ + ("add", test(lambda n: (R(n, n), R(n, n)))), + ("broadcast1", test(lambda n: (R(n, n), R(1)))), + ("broadcast2", test(lambda n: (R(n, n), R(n, 1)))), + ("broadcast3", test(lambda n: (R(n, 1), R(1, n)))), + ("inplace", test_inplace(lambda n: (R(n, n), R(n, 1)))), + ("out=", test_out(lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n))), + ("transposed1", test(lambda n: (R(n, n), R(n, n).transpose(0, 1)))), + ( + "transposed2", + test(lambda n: (R(n, n).transpose(0, 1), R(n, n).transpose(0, 1))), + ), + ("slice1", test(lambda n: (R(n + 1, n + 1, 2)[:n, :n, 0], R(n, n)))), + ("slice2", test(lambda n: (R(n, n, 2)[:, :, 0], R(n, n, 2)[:, :, 0]))), + ( + "strided out", + test_out( + lambda n: (R(n, n), R(n, n)), + out=lambda n: R(n + 1, n + 1, 2)[:n, :n, 0], + ), + ), + ( + "out convert", + test_out( + lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n, dtype=torch.float64) + ), + ), + ("issue #57611 (n,32,32,2)", test(lambda n: (R(1, 32, 32, 2), R(n, 1, 1, 2)))), + ("float+double", test(lambda n: (R(n, n), R(n, n, dtype=torch.float64)))), + ( + "int+long", + test( + lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int64)) + ), + ), + ( + "int+short", + test( + lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int16)) + ), + ), + ( + "float+int", + test( + lambda n: (R([n, n], dtype=torch.float32), I([n, n], dtype=torch.int32)) + ), + ), + ( + "double+long", + test( + lambda n: (R([n, n], dtype=torch.float64), I([n, n], dtype=torch.int64)) + ), + ), + ( + "fused addnorm", + test( + lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=eager_addnorm, + ), + ), + ( + "fused addnorm (vs TS)", + test( + lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=ts_addnorm, + ), + ), + ( + "fused addnorm out=", + test_out( + lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=inplace_addnorm, + out=lambda n: R(n, n), + ), + ), + ( + "fused addnorm out= (vs TS)", + test_out( + lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=ts_ip_addnorm, + out=lambda n: R(n, n), + ), + ), + ( + "fused addnorm backward", + test_backwards( + lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=eager_addnorm, + ), + ), + ( + "fused addnorm backward (vs TS)", + test_backwards( + lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)), + nnc=nnc_addnorm, + aten=ts_addnorm, + ), + ), + ] + + df = pd.DataFrame( + np.stack([r for n, r in results]), + columns=[f"{n}x{n}".rjust(9) for n in SIZES], + index=[n for n, r in results], + ) + + if WRITE_CSV: + df.to_csv("../operator_authoring_results.csv") + print("wrote ../operator_authoring_results.csv") + + print() + print("Speedups over aten") + pd.options.display.float_format = "{:.2f}x".format + print(df) + + +if __name__ == "__main__": + main() diff --git a/functorch/benchmarks/per_sample_grads.py b/functorch/benchmarks/per_sample_grads.py new file mode 100644 index 0000000000000..e9e3524eca53b --- /dev/null +++ b/functorch/benchmarks/per_sample_grads.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import torchvision.models as models +from opacus.utils.module_modification import convert_batchnorm_modules +import time + +from functorch import vmap, grad +from functorch import make_functional +from opacus import PrivacyEngine + +device = 'cuda' +batch_size = 128 +torch.manual_seed(0) + +model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10)) +model_functorch = model_functorch.to(device) +criterion = nn.CrossEntropyLoss() + +images = torch.randn(batch_size, 3, 32, 32, device=device) +targets = torch.randint(0, 10, (batch_size,), device=device) +func_model, weights = make_functional(model_functorch) + +def compute_loss(weights, image, target): + images = image.unsqueeze(0) + targets = target.unsqueeze(0) + output = func_model(weights, images) + loss = criterion(output, targets) + return loss + +def functorch_per_sample_grad(): + compute_grad = grad(compute_loss) + compute_per_sample_grad = vmap(compute_grad, (None, 0, 0)) + + + start = time.time() + result = compute_per_sample_grad(weights, images, targets) + torch.cuda.synchronize() + end = time.time() + + return result, end - start # end - start in seconds + +torch.manual_seed(0) +model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10)) +model_opacus = model_opacus.to(device) +criterion = nn.CrossEntropyLoss() +for p_f, p_o in zip(model_functorch.parameters(), model_opacus.parameters()): + assert torch.allclose(p_f, p_o) # Sanity check + +privacy_engine = PrivacyEngine( + model_opacus, + sample_rate=0.01, + alphas=[10, 100], + noise_multiplier=1, + max_grad_norm=10000.0, +) + +def opacus_per_sample_grad(): + start = time.time() + output = model_opacus(images) + loss = criterion(output, targets) + loss.backward() + torch.cuda.synchronize() + end = time.time() + expected = [p.grad_sample for p in model_opacus.parameters()] + for p in model_opacus.parameters(): + delattr(p, 'grad_sample') + p.grad = None + return expected, end - start + + +for _ in range(5): + _, seconds = functorch_per_sample_grad() + print(seconds) + +result, seconds = functorch_per_sample_grad() +print(seconds) + +for _ in range(5): + _, seconds = opacus_per_sample_grad() + print(seconds) + +expected, seconds = opacus_per_sample_grad() +print(seconds) + +result = [r.detach() for r in result] +print(len(result)) + +# TODO: The following shows that the per-sample-grads computed are different. +# This concerns me a little; we should compare to a source of truth. +# for i, (r, e) in enumerate(list(zip(result, expected))[::-1]): +# if torch.allclose(r, e, rtol=1e-5): +# continue +# print(-(i+1), ((r - e)/(e + 0.000001)).abs().max()) diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py new file mode 100644 index 0000000000000..ac4cf5f386dcf --- /dev/null +++ b/functorch/benchmarks/pointwise_scorecard.py @@ -0,0 +1,229 @@ +import sys +import time +import torch +import inspect +import itertools + +from functorch import pointwise_operator + +torch.set_num_threads(1) +torch._C._debug_set_fusion_group_inlining(False) + +def rand(*shape): + return torch.rand(*shape).mul(16).add(1) + + +# ------------------------------------------------------------------------------ +# Shape test cases +# ------------------------------------------------------------------------------ +def scalar(): + return (rand(1), rand(1)) + +def small(): + return (rand(32), rand(32)) + +def small_2d(): + return (rand(1, 32), rand(1, 32)) + +def small_broadcast(): + return (rand(4, 32), rand(32)) + +def medium(): + return (rand(32, 12, 64, 64), rand(32, 12, 64, 64)) + +def medium_sliced(): + return (rand(32, 12, 64, 64)[..., ::2], + rand(32, 12, 64, 64)[..., ::2]) + +def medium_transpose(): + return (rand(32, 12, 64, 64).transpose(-1, -2), + rand(32, 12, 64, 64).transpose(-1, -2)) + +def medium2(): + return (rand(32, 3, 224, 224), rand(32, 3, 224, 224)) + +def medium3d(): + return (rand(16, 32, 64), rand(16, 32, 64)) + +def medium_channels_last(): + return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last), + rand(32, 3, 224, 224).to(memory_format=torch.channels_last)) + +def medium_broadcast(): + return (rand(32, 12, 64, 64), rand(64)) + +def medium_broadcast_channels_last(): + return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), + rand(3, 1, 1)) + +def large(): + return (rand(8192, 8192), rand(8192, 8192)) + +def large_transpose(): + return (rand(8192, 8192).transpose(0, 1), + rand(8192, 8192).transpose(0, 1)) + +def large_channels_last(): + return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last), + rand(32, 32, 256, 256).to(memory_format=torch.channels_last)) + +def pathological_broadcast(): + return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2)) + +# ------------------------------------------------------------------------------ +# Operator test cases +# ------------------------------------------------------------------------------ +def add(a, b): + return a + b + +def sub(a, b): + return a - b + +def mul(a, b): + return a * b + +def div(a, b): + return a / b + +def relu(a): + return a.relu() + +def sigmoid(a): + return a.sigmoid() + +def tanh(a): + return a.tanh() + +def log(a): + return a.log() + +def exp(a): + return a.exp() + +def square(a): + return a ** 2 + +def fma(a, b): + return a * b + b + +def hardswish(a): + return a * (a + 3.0).clamp(0.0, 6.0) / 6.0 + +def native_hardswish(a): + return torch._C._nn.hardswish(a) + +def softplus(a): + return (a * 1.0).exp().log1p() / 1.0 + +def mish(a): + return a * ((a * 1.0).exp().log1p() / 1.0).tanh() + +# ------------------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------------------ +def time_cpu(fn, args, iters): + s = time.perf_counter() + for _ in range(iters): + fn(*args) + e = time.perf_counter() + return e - s + +def time_cuda(fn, args, iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn(*args) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / 1e3 + +def benchmark_with_timer(fn, args, timer): + timer(fn, args, 3) + calibration = timer(fn, args, 1) + iters = int(1.0 / calibration) + return timer(fn, args, iters) / iters + +def benchmark(fn, args): + timer = time_cpu if args[0].device.type == "cpu" else time_cuda + return benchmark_with_timer(fn, args, timer) + +def micros(s): + return f"{s * 1e6:.1f}" + +shapes = [ + scalar, + small, + small_2d, + small_broadcast, + medium, + medium2, + medium3d, + medium_sliced, + medium_transpose, + medium_channels_last, + medium_broadcast, + medium_broadcast_channels_last, + large, + large_transpose, + large_channels_last, + pathological_broadcast, +] + +operators = [ + add, + sub, + mul, + div, + relu, + sigmoid, + tanh, + log, + exp, + square, + fma, + hardswish, + native_hardswish, +] + +nope = set() +for shape, operator in itertools.product(shapes, operators): + nargs = len(inspect.signature(operator).parameters) + args = shape()[:nargs] + + try: + if shape == medium_transpose: + raise RuntimeError("pointwise_operator hangs on medium_transpose") + pw_op = pointwise_operator(operator) + torch.testing.assert_allclose(operator(*args), pw_op(*args)) + except Exception: + print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}") + nope.add((operator, shape)) + + ts_op = torch.jit.script(operator) + torch.testing.assert_allclose(operator(*args), ts_op(*args)) + + +print("fuser,device,operator,shape,time") +results = [] +for shape, operator in itertools.product(shapes, operators): + nargs = len(inspect.signature(operator).parameters) + args = shape()[:nargs] + + result = benchmark(operator, args) + print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) + try: + if shape == medium_transpose: + raise RuntimeError("pointwise_operator hangs on medium_transpose") + if (operator, shape) in nope: + raise RuntimeError("pointwise_operator fails on medium_transpose") + pw_op = pointwise_operator(operator) + result = benchmark(pw_op, args) + print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) + except Exception: + print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))])) + + ts_op = torch.jit.script(operator) + result = benchmark(ts_op, args) + print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) + sys.stdout.flush() diff --git a/functorch/benchmarks/process_scorecard.py b/functorch/benchmarks/process_scorecard.py new file mode 100644 index 0000000000000..f95d879238a12 --- /dev/null +++ b/functorch/benchmarks/process_scorecard.py @@ -0,0 +1,19 @@ +import pandas +import matplotlib.pyplot as plt + +df = pandas.read_csv("perf.csv") + +ops = pandas.unique(df["operator"]) +nops = len(ops) +pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"]) +pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T + +plt.rcParams["figure.figsize"] = (20, 100) +fig, axs = plt.subplots(nops) +plt.subplots_adjust(hspace=0.5) +for idx, op in enumerate(ops): + op_speedups = pivot_speedups.T[op].T + op_speedups.plot(ax=axs[idx], kind="bar", ylim=(0, 5), rot=45) + axs[idx].set_title(op) + axs[idx].set_xlabel("") +plt.savefig("scorecard.svg") diff --git a/functorch/benchmarks/transformer_fusion_patterns/__init__.py b/functorch/benchmarks/transformer_fusion_patterns/__init__.py new file mode 100644 index 0000000000000..10a55772ab58b --- /dev/null +++ b/functorch/benchmarks/transformer_fusion_patterns/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py new file mode 100644 index 0000000000000..a6646e150c55b --- /dev/null +++ b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py @@ -0,0 +1,191 @@ +import torch +from functorch.compile import memory_efficient_fusion, clear_compile_cache +import benchmark_helper + + +device = "cuda" +dtype = torch.float16 + +# LightSeq pattern 1 +class DropoutResBias: + @staticmethod + def fn(input, bias, residual): + a = torch.add(input, bias) + b = torch.nn.functional.dropout(a, p=0.7, training=True) + c = b + residual + return c + + @staticmethod + def args(): + batch_size, seq_len, hidden_size = 32, 196, 1024 + input = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=True, + device=device, + dtype=dtype, + ) + bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) + residual = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=False, + device=device, + dtype=dtype, + ) + args = (input, bias, residual) + return args + + +class DropoutResBiasScalar: + @staticmethod + def fn(input, bias, residual, p: float): + a = torch.add(input, bias) + b = torch.nn.functional.dropout(a, p, training=True) + c = b + residual + return c + + @staticmethod + def args(): + batch_size, seq_len, hidden_size = 32, 196, 1024 + input = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=True, + device=device, + dtype=dtype, + ) + bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) + residual = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=False, + device=device, + dtype=dtype, + ) + args = (input, bias, residual, 0.7) + return args + + + +# LightSeq pattern 2 +class BiasReluDropout: + @staticmethod + def fn(input, bias): + a = torch.add(input, bias) + b = torch.nn.functional.relu(a) + c = torch.nn.functional.dropout(b, p=0.6, training=True) + return c + + @staticmethod + def args(): + batch_size = 32 + seq_len = 196 + intermediate_size = 4096 + input = torch.randn( + batch_size, + seq_len, + intermediate_size, + requires_grad=True, + device=device, + dtype=dtype, + ) + bias = torch.randn( + intermediate_size, requires_grad=True, device=device, dtype=dtype + ) + args = (input, bias) + return args + + +class BiasDropoutResLayerNorm: + @staticmethod + def fn(input, bias, residual): + hidden_size = 1024 + a = torch.add(input, bias) + b = torch.nn.functional.dropout(a, p=0.7, training=True) + c = b + residual + d = torch.nn.functional.layer_norm(c, normalized_shape=(hidden_size,)) + return d + + @staticmethod + def args(): + batch_size = 32 + seq_len = 196 + hidden_size = 1024 + + input = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=True, + device=device, + dtype=dtype, + ) + bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype) + residual = torch.randn( + batch_size, + seq_len, + hidden_size, + requires_grad=False, + device=device, + dtype=dtype, + ) + args = (input, bias, residual) + return args + + +class LayerNormSigmoid: + @staticmethod + def fn(inp): + hidden_size = 512 + a = torch.nn.functional.layer_norm(inp, normalized_shape=(hidden_size,)) + b = torch.sigmoid(a) + return b + + @staticmethod + def args(): + batch_size = 8192 + hidden_size = 512 + inp = torch.randn( + batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype + ) + args = (inp,) + return args + + +for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]: + # Clear the compile cache + clear_compile_cache() + + # Get the function and inputs + obj = cl() + fn = obj.fn + args = obj.args() + + # Find the static args + static_argnums = [] + for idx, arg in enumerate(args): + if not isinstance(arg, torch.Tensor): + static_argnums.append(idx) + + # Get the optimized function + opt_fn = memory_efficient_fusion(fn, static_argnums) + + # Profile cuda kernels + benchmark_helper.profile_cuda_kernels(fn, args, "Eager") + with torch.jit.fuser("fuser2"): + benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd") + + # Time it with Torch Timer + benchmark_helper.time_with_torch_timer(fn, args, "Eager") + with torch.jit.fuser("fuser2"): + benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd") + + # Time it with manual Timer + benchmark_helper.time_with_manual_timer(fn, args, "Eager") + with torch.jit.fuser("fuser2"): + benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd") diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py new file mode 100644 index 0000000000000..bad27572e97a0 --- /dev/null +++ b/functorch/benchmarks/transformer_fusion_patterns/benchmark_helper.py @@ -0,0 +1,148 @@ +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.utils.benchmark import Timer +import time + + +def profile_cuda_kernels(fn, args, string_id="Model time"): + print("################################################") + print(f"#### Profiling for {string_id} starts #########") + print("################################################") + warmup = 50 + old_args = args[:] + n_repeats = 1 + n_layers = 1 + ref = fn(*old_args) + gO = torch.rand_like(ref) + for _ in range(0, warmup // n_layers): + args = list(old_args[:]) + ref = fn(*args) + ref.backward(gO) + + torch.cuda.synchronize() + + # Forward profile + def fwd_run(): + for _ in range(0, n_repeats // n_layers): + args = list(old_args[:]) + for arg in args: + if isinstance(arg, torch.Tensor): + arg.grad = None + ref = fn(*args) + + print(f"###### Forward profile for {string_id} starts #####") + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("baseline"): + fwd_run() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + print(f"###### Forward profile for {string_id} ends #####") + + # Backward profile + def bwd_run(): + for _ in range(0, n_repeats // n_layers): + args = list(old_args[:]) + for arg in args: + if isinstance(arg, torch.Tensor): + arg.grad = None + ref = fn(*args) + + print(f"###### Backward profile for {string_id} starts #####") + torch.cuda.synchronize() + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: + with record_function("baseline"): + ref.backward(gO) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + torch.cuda.synchronize() + print(f"###### Backward profile for {string_id} ends #####") + + bwd_run() + print("################################################") + print(f"#### Profiling for {string_id} ends #########") + print("################################################\n\n\n\n") + + +def time_with_torch_timer(fn, args, string_id, kwargs=None): + if kwargs is None: + kwargs = {} + print("################################################") + print(f"#### Torch Timer for {string_id} starts #########") + print("################################################") + ref = fn(*args, **kwargs) + gO = torch.rand_like(ref) + env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn} + grad_none = {"for x in args: x.grad=None"} + fn_call = "fn(*args, **kwargs)" + # Measure end-to-end fwd time + timer = Timer(stmt=f"{fn_call}", globals=env) + fwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3) + timer_blocked = timer.blocked_autorange() + print(f"Forward = {fwd_latency}") + + # Measure end-to-end fwd bwd + timer = Timer( + stmt=f"{grad_none}; fwd = {fn_call}; fwd.backward(gO)", + globals=env, + ) + fwd_bwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3) + timer_blocked = timer.blocked_autorange() + # print(f"Forward + sum + Backward = {fwd_sum_bwd_latency}") + + bwd_latency = round(fwd_bwd_latency - fwd_latency, 3) + print(f"Backward = {bwd_latency}") + + print("################################################") + print(f"#### Torch Timer for {string_id} ends ###############") + print("################################################\n\n\n\n") + + +def time_with_manual_timer(fn, args, string_id): + print("################################################") + print(f"#### Manual Timer for {string_id} starts #########") + print("################################################") + warmup = 50 + repeats = 1000 + old_args = args[:] + ref = fn(*old_args) + gO = torch.rand_like(ref) + for _ in range(0, warmup): + args = list(old_args[:]) + + for arg in args: + if isinstance(arg, torch.Tensor): + arg.grad = None + ref = fn(*args) + ref.backward(gO) + + torch.cuda.synchronize() + + fwd_times = [] + bwd_times = [] + for _ in range(0, repeats): + args = list(old_args[:]) + for arg in args: + if isinstance(arg, torch.Tensor): + arg.grad = None + fwd_start = time.time() + ref = fn(*args) + torch.cuda.synchronize() + fwd_end = time.time() + + bwd_start = time.time() + ref.backward(gO) + torch.cuda.synchronize() + bwd_end = time.time() + + fwd_times.append(fwd_end - fwd_start) + bwd_times.append(bwd_end - bwd_start) + avg_fwd = round(sum(fwd_times) / repeats * 10 ** 6, 2) + avg_bwd = round(sum(bwd_times) / repeats * 10 ** 6, 2) + avg_total = round(avg_fwd + avg_bwd, 2) + + print(f"Forward = {avg_fwd}") + print(f"Backward = {avg_bwd}") + + print("################################################") + print(f"#### Manual Timer for {string_id} ends #########") + print("################################################\n\n\n") diff --git a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py new file mode 100644 index 0000000000000..b2318068645fd --- /dev/null +++ b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py @@ -0,0 +1,66 @@ +import torch +from functorch.compile import memory_efficient_pointwise_fusion, clear_compile_cache +import benchmark_helper + +# ALL comments regarding the patetrns + + +def bias_gelu_dropout(input, bias): + a = torch.add(input, bias) + b = torch.nn.functional.gelu(a) + c = torch.nn.functional.dropout(b, p=0.6, training=True) + return c + + +def aot_fn(input, bias): + a = torch.add(input, bias) + b = a * 0.5 * (1.0 + torch.tanh(0.79788456 * a * (1 + 0.044715 * a * a))) + c = torch.nn.functional.dropout(b, p=0.6, training=True) + return c + + +fn = bias_gelu_dropout + +clear_compile_cache() + +# Set inputs +device = "cuda" +dtype = torch.float16 +batch_size = 32 +seq_len = 196 +intermediate_size = 4096 +# batch_size = 2 +# seq_len = 4 +# intermediate_size = 3 +input = torch.randn( + batch_size, + seq_len, + intermediate_size, + requires_grad=True, + device=device, + dtype=dtype, +) +bias = torch.randn(intermediate_size, requires_grad=True, device=device, dtype=dtype) + + +# Get the optimized function +opt_fn = memory_efficient_pointwise_fusion( + aot_fn, compiler_name="torchscript_nvfuser" +) + + +# Profile cuda kernels +benchmark_helper.profile_cuda_kernels(fn, (input, bias), "Eager") +with torch.jit.fuser("fuser2"): + benchmark_helper.profile_cuda_kernels(opt_fn, (input, bias), "AOTAutograd") + + +# Time it with Torch Timer +benchmark_helper.time_with_torch_timer(fn, (input, bias), "Eager") +with torch.jit.fuser("fuser2"): + benchmark_helper.time_with_torch_timer(opt_fn, (input, bias), "AOTAutograd") + +# Time it with manual Timer +benchmark_helper.time_with_manual_timer(fn, (input, bias), "Eager") +with torch.jit.fuser("fuser2"): + benchmark_helper.time_with_manual_timer(opt_fn, (input, bias), "AOTAutograd") diff --git a/functorch/codegen/gen_functorch_lagging_op_db.py b/functorch/codegen/gen_functorch_lagging_op_db.py new file mode 100644 index 0000000000000..833e34ed4d69f --- /dev/null +++ b/functorch/codegen/gen_functorch_lagging_op_db.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.testing._internal.common_methods_invocations import op_db + + +def num_leading_spaces(line: str) -> int: + result = len(line) - len(line.lstrip()) + # Empty space handling + if result == 0: + return 999999 + return result + + +def deindent(code: str) -> str: + lines = code.split('\n') + min_leading_spaces = min(map(num_leading_spaces, lines)) + lines = [line[min_leading_spaces:] for line in lines] + return '\n'.join(lines) + + +if __name__ == '__main__': + supported = {(opinfo.name, opinfo.variant_test_name) for opinfo in op_db} + supported = sorted(supported) + print(deindent("""\ + # Copyright (c) Facebook, Inc. and its affiliates. + # All rights reserved. + # + # This source code is licensed under the BSD-style license found in the + # LICENSE file in the root directory of this source tree. + from torch.testing._internal.common_methods_invocations import op_db + + # Generated from codegen/gen_functorch_op_db.py via + # python codegen/gen_functorch_lagging_op_db.py > test/functorch_lagging_op_db.py + # + # People add new OpInfos to PyTorch all the time. + # We want them to be able to add OpInfos without breaking our CI. + # To achieve this, we keep our OpInfo library behind that of Pytorch's and + # we periodically update our OpInfo library by regenerating this file""")) + + print("_functorch_lagging_meta = {") + for name, variant in supported: + print(f' {(name, variant)},') + print("}") + + print(deindent("""\ + + + def in_functorch_lagging_op_db(opinfo): + return (opinfo.name, opinfo.variant_test_name) in _functorch_lagging_meta + + + functorch_lagging_op_db = [ + opinfo for opinfo in op_db if in_functorch_lagging_op_db(opinfo) + ]""")) diff --git a/functorch/docs/.gitignore b/functorch/docs/.gitignore new file mode 100644 index 0000000000000..7fc077fafdb15 --- /dev/null +++ b/functorch/docs/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +package-lock.json +package.json +yarn.lock diff --git a/functorch/docs/Makefile b/functorch/docs/Makefile new file mode 100644 index 0000000000000..2d0ae2bdd64c5 --- /dev/null +++ b/functorch/docs/Makefile @@ -0,0 +1,39 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS ?= -WT --keep-going +SPHINXBUILD ?= sphinx-build +SPHINXPROJ ?= functorch +SOURCEDIR ?= source +BUILDDIR ?= build +PYCMD ?= python + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +docset: html + doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/docs/ --force $(BUILDDIR)/html/ + + # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. + cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png + convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png + +html-stable: + # stable differs from `make html` in two ways: + # 1) The stable logo is used instead of the unstable logo + # 2) There will not be a link to the stable docs. + # See conf.py for more details. + RELEASE=1 make html + +.PHONY: help Makefile docset + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + @echo "Removing everything under 'build' and 'source/generated'.." + @rm -rf $(BUILDDIR)/html/ $(BUILDDIR)/doctrees $(SOURCEDIR)/generated diff --git a/functorch/docs/README.md b/functorch/docs/README.md new file mode 100644 index 0000000000000..175c47b86fe95 --- /dev/null +++ b/functorch/docs/README.md @@ -0,0 +1,30 @@ +functorch docs build +-------------------- + +## Build Locally + +Install requirements: +``` +pip install -r requirements.txt +``` + +One may also need to install [pandoc](https://pandoc.org/installing.html). On Linux we can use: `sudo apt-get install pandoc`. Or using `conda` we can use: `conda install -c conda-forge pandoc`. + +To run the docs build: +``` +make html +``` + +Check out the output files in `build/html`. + +## Deploy + +The functorch docs website does not updated automatically. We need to periodically regenerate it. + +You need write permissions to functorch to do this. We use GitHub Pages to serve docs. + +1. Build the docs +2. Save the build/html folder somewhere +3. Checkout the branch `gh-pages`. +4. Delete the contents of the branch and replace it with the build/html folder. `index.html` should be at the root. +5. Commit the changes and push the changes to the `gh-pages` branch. diff --git a/functorch/docs/requirements.txt b/functorch/docs/requirements.txt new file mode 100644 index 0000000000000..ec04b70623615 --- /dev/null +++ b/functorch/docs/requirements.txt @@ -0,0 +1,9 @@ +sphinx==3.5.4 +docutils==0.16 +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme +sphinxcontrib.katex +sphinx_copybutton>=0.3.1 +IPython +myst-nb==0.13.2 +# Fixing upper version due to https://github.com/sphinx-doc/sphinx/issues/10306 +Jinja2<3.1.0 diff --git a/functorch/docs/source/_static/css/custom.css b/functorch/docs/source/_static/css/custom.css new file mode 100644 index 0000000000000..31ded215ee743 --- /dev/null +++ b/functorch/docs/source/_static/css/custom.css @@ -0,0 +1,21 @@ +.codeblock-height-limiter { + max-height: 500px; + overflow: scroll; +} + +div.container a.header-logo { + height: 80px; + width: 160px; + background-image: url("../images/functorch.svg"); + background-size: 160px; + background-position-y: -50%; +} + +.highlight pre { + border: 1px solid rgba(0,0,0,0.15); + border-radius: 4px; +} + +div.cell div.cell_input { + border-bottom-width: 0px; +} diff --git a/functorch/docs/source/_static/images/functorch.svg b/functorch/docs/source/_static/images/functorch.svg new file mode 100644 index 0000000000000..ec7d794122b29 --- /dev/null +++ b/functorch/docs/source/_static/images/functorch.svg @@ -0,0 +1,6 @@ + diff --git a/functorch/docs/source/_templates/autosummary/class.rst b/functorch/docs/source/_templates/autosummary/class.rst new file mode 100644 index 0000000000000..f581ac9d42ea1 --- /dev/null +++ b/functorch/docs/source/_templates/autosummary/class.rst @@ -0,0 +1,12 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :inherited-members: + :members: + +.. autogenerated from source/_templates/autosummary/class.rst diff --git a/functorch/docs/source/_templates/classtemplate.rst b/functorch/docs/source/_templates/classtemplate.rst new file mode 100644 index 0000000000000..4f74842394ec9 --- /dev/null +++ b/functorch/docs/source/_templates/classtemplate.rst @@ -0,0 +1,14 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: + + +.. + autogenerated from source/_templates/classtemplate.rst + note it does not have :inherited-members: diff --git a/functorch/docs/source/_templates/layout.html b/functorch/docs/source/_templates/layout.html new file mode 100644 index 0000000000000..7e10da7294f8b --- /dev/null +++ b/functorch/docs/source/_templates/layout.html @@ -0,0 +1,334 @@ +{# TEMPLATE VAR SETTINGS #} +{%- set url_root = pathto('', 1) %} +{%- if url_root == '#' %}{% set url_root = '' %}{% endif %} +{%- if not embedded and docstitle %} + {%- set titlesuffix = " — "|safe + docstitle|e %} +{%- else %} + {%- set titlesuffix = "" %} +{%- endif %} +{%- set lang_attr = 'en' if language == None else (language | replace('_', '-')) %} +{% import 'theme_variables.jinja' as theme_variables %} + + + + + + + {{ metatags }} + + {% block htmltitle %} + {{ title|striptags|e }}{{ titlesuffix }} + {% endblock %} + + {# FAVICON #} + {% if favicon %} + + {% endif %} + {# CANONICAL URL #} + {% if theme_canonical_url %} + + {% endif %} + + {# CSS #} + + {# OPENSEARCH #} + {% if not embedded %} + {% if use_opensearch %} + + {% endif %} + + {% endif %} + + + + {%- for css in css_files %} + {%- if css|attr("rel") %} + + {%- else %} + + {%- endif %} + {%- endfor %} + {%- for cssfile in extra_css_files %} + + {%- endfor %} + + {%- block linktags %} + {%- if hasdoc('about') %} + + {%- endif %} + {%- if hasdoc('genindex') %} + + {%- endif %} + {%- if hasdoc('search') %} + + {%- endif %} + {%- if hasdoc('copyright') %} + + {%- endif %} + {%- if next %} + + {%- endif %} + {%- if prev %} + + {%- endif %} + {%- endblock %} + + {%- block extrahead %} + + {% if theme_analytics_id %} + + + {% endif %} + + + {% if release == "master" %} + + + {% endif %} + + {% endblock %} + + {# Keep modernizr in head - http://modernizr.com/docs/#installing #} + + + {% include "fonts.html" %} + + + +
+ + + + {% block extrabody %} {% endblock %} + + {# SIDE NAV, TOGGLES ON MOBILE #} + + + + + +
+
+
+ {% include "breadcrumbs.html" %} +
+ +
+ Shortcuts +
+
+ +
+
+ + {% if theme_pytorch_project == 'tutorials' %} + + + + {% endif %} + + {%- block content %} + {% if theme_style_external_links|tobool %} + + +
+
+
+ {{ toc }} +
+
+
+
+
+ + {% include "versions.html" %} + + {% if not embedded %} + + {% if sphinx_version >= "1.8.0" %} + + {%- for scriptfile in script_files %} + {{ js_tag(scriptfile) }} + {%- endfor %} + {% else %} + + {%- for scriptfile in script_files %} + + {%- endfor %} + {% endif %} + + {% endif %} + + + + + + + + + {%- block footer %} + + + + + + {% endblock %} + + + + + + + + diff --git a/functorch/docs/source/aot_autograd.rst b/functorch/docs/source/aot_autograd.rst new file mode 100644 index 0000000000000..5123a35485b19 --- /dev/null +++ b/functorch/docs/source/aot_autograd.rst @@ -0,0 +1,43 @@ +functorch.compile (experimental) +================================ + +AOT Autograd is an experimental feature that allows ahead of time capture of +forward and backward graphs, and allows easy integration with compilers. This +creates an easy to hack Python-based development environment to speedup training +of PyTorch models. AOT Autograd currently lives inside ``functorch.compile`` +namespace. + +.. warning:: + AOT Autograd is experimental and the APIs are likely to change. We are looking + for feedback. If you are interested in using AOT Autograd and need help or have + suggestions, please feel free to open an issue. We will be happy to help. + +.. currentmodule:: functorch.compile + +Compilation APIs (experimental) +------------------------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + aot_function + aot_module + memory_efficient_fusion + +Partitioners (experimental) +--------------------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + default_partition + min_cut_rematerialization_partition + +Compilers (experimental) +------------------------ +.. autosummary:: + :toctree: generated + :nosignatures: + + nop + ts_compile diff --git a/functorch/docs/source/batch_norm.rst b/functorch/docs/source/batch_norm.rst new file mode 100644 index 0000000000000..09eb6001b5b66 --- /dev/null +++ b/functorch/docs/source/batch_norm.rst @@ -0,0 +1,48 @@ +Patching Batch Norm +=================== + +What's happening? +----------------- +Batch Norm requires in-place updates to running_mean and running_var of the same size as the input. +Functorch does not support inplace update to a regular tensor that takes in a batched tensor (i.e. +``regular.add_(batched)`` is not allowed). So when vmaping over a batch of inputs to a single module, +we end up with this error + +How to fix +---------- +All of these options assume that you don't need running stats. If you're using a module this means +that it's assumed you won't use batch norm in evalution mode. If you have a use case that involves +running batch norm with vmap in evaluation mode, please file an issue + +Option 1: Change the BatchNorm +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +If you've built the module yourself, you can change the module to not use running stats. In other +words, anywhere that there's a BatchNorm module, set the ``track_running_stats`` flag to be False + +.. code-block:: python + + BatchNorm2d(64, track_running_stats=False) + + +Option 2: torchvision parameter +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter. These are +often defaulted to be BatchNorm2d if they've been defaulted. Instead you can set it to BatchNorm +that doesn't use running stats + +.. code-block:: python + + import torchvision + from functools import partial + torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False)) + +Option 3: functorch's patching +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +functorch has added some functionality to allow for quick, in-place patching of the module. If you +have a net that you want to change, you can run ``replace_all_batch_norm_modules_`` to update the +module in-place to not use running stats + +.. code-block:: python + + from functorch.experimental import replace_all_batch_norm_modules_ + replace_all_batch_norm_modules_(net) diff --git a/functorch/docs/source/conf.py b/functorch/docs/source/conf.py new file mode 100644 index 0000000000000..c73012793908b --- /dev/null +++ b/functorch/docs/source/conf.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +# import sys + +# source code directory, relative to this file, for sphinx-autobuild +# sys.path.insert(0, os.path.abspath('../..')) + +import torch +import functorch + +RELEASE = os.environ.get('RELEASE', False) + +import pytorch_sphinx_theme +import sys + +# -- General configuration ------------------------------------------------ + +# Required version of sphinx is set from docs/requirements.txt + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + # 'sphinxcontrib.katex', + 'sphinx.ext.autosectionlabel', + 'sphinx_copybutton', + 'myst_nb', +] + +# sys.path.insert(0, os.path.abspath('./notebooks')) + +# build the templated autosummary files +# autosummary_generate = True +numpydoc_show_class_members = False + +# autosectionlabel throws warnings if section names are duplicated. +# The following tells autosectionlabel to not throw a warning for +# duplicated section names that are in different documents. +autosectionlabel_prefix_document = True + +# tell myst to not execute ipynb tutorials. +jupyter_execute_notebooks = "off" + +# katex options +# +# + +katex_prerender = True + +napoleon_use_ivar = True + +# build the templated autosummary files +autosummary_generate = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'functorch' +copyright = 'functorch Contributors' +author = 'functorch Contributors' +functorch_version = str(functorch.__version__) + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +# TODO: change to [:2] at v1.0 +version = 'nightly (' + functorch_version + ')' +# The full version, including alpha/beta/rc tags. +# TODO: verify this works as expected +release = 'nightly' + +# Customized html_title here. +# Default is " ".join(project, release, "documentation") if not set +# TODO: I don't know if this flag works, please check before using it +if RELEASE: + raise RuntimeError('NYI') + # remove hash (start with 'a') from version number if any + # version_end = functorch_version.find('a') + # if version_end == -1: + # html_title = " ".join((project, functorch_version, "documentation")) + # version = functorch_version + # else: + # html_title = " ".join((project, functorch_version[:version_end], "documentation")) + # version = functorch_version[:version_end] + # release = version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['notebooks/colab**', 'notebooks/_src/**'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + +# Disable docstring inheritance +autodoc_inherit_docstrings = False + +# Disable displaying type annotations, these can be very verbose +autodoc_typehints = 'none' + +# Enable overriding of function signatures in the first line of the docstring. +autodoc_docstring_signature = True + +# -- katex javascript in header +# +# def setup(app): +# app.add_javascript("https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.js") + + +# -- Options for HTML output ---------------------------------------------- +# +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# +# + +html_theme = 'pytorch_sphinx_theme' +html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "pytorch_project": "docs", + "navigation_with_keys": True, + "analytics_id": "UA-117752657-2", +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +html_css_files = [ + 'css/custom.css', +] + + +# Called automatically by Sphinx, making this `conf.py` an "extension". +def setup(app): + # NOTE: in Sphinx 1.8+ `html_css_files` is an official configuration value + # and can be moved outside of this function (and the setup(app) function + # can be deleted). + html_css_files = [ + 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css' + ] + + # In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is + # `add_stylesheet` (deprecated in 1.8). + add_css = getattr(app, 'add_css_file', app.add_stylesheet) + for css_file in html_css_files: + add_css(css_file) + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'PyTorchdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'pytorch.tex', 'PyTorch Documentation', + 'Torch Contributors', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'functorch', 'functorch Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'functorch', 'functorch Documentation', + author, 'functorch', 'One line description of project.', + 'Miscellaneous'), +] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'numpy': ('https://numpy.org/doc/stable', None), + "torch": ("https://pytorch.org/docs/stable/", None), +} + +# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- +# See http://stackoverflow.com/a/41184353/3343043 + +from docutils import nodes +from sphinx.util.docfields import TypedField +from sphinx import addnodes +import sphinx.ext.doctest + +# Without this, doctest adds any example with a `>>>` as a test +doctest_test_doctest_blocks = '' +doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS +doctest_global_setup = ''' +import torch +try: + import torchvision +except ImportError: + torchvision = None +''' + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down! + + # (List, unicode, Tuple) -> nodes.field + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong('', fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(' (') + # NOTE: using .pop() here to prevent a single type node to be + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = u''.join(n.astext() for n in fieldtype) + typename = typename.replace('int', 'python:int') + typename = typename.replace('long', 'python:long') + typename = typename.replace('float', 'python:float') + typename = typename.replace('bool', 'python:bool') + typename = typename.replace('type', 'python:type') + par.extend(self.make_xrefs(self.typerolename, domain, typename, + addnodes.literal_emphasis, **kw)) + else: + par += fieldtype + par += nodes.Text(')') + par += nodes.Text(' -- ') + par += content + return par + + fieldname = nodes.field_name('', self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item('', handle_item(fieldarg, content)) + fieldbody = nodes.field_body('', bodynode) + return nodes.field('', fieldname, fieldbody) + +TypedField.make_field = patched_make_field + +copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_is_regexp = True diff --git a/functorch/docs/source/docutils.conf b/functorch/docs/source/docutils.conf new file mode 100644 index 0000000000000..00b6db8269468 --- /dev/null +++ b/functorch/docs/source/docutils.conf @@ -0,0 +1,2 @@ +[html writers] +table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent diff --git a/functorch/docs/source/experimental.rst b/functorch/docs/source/experimental.rst new file mode 100644 index 0000000000000..05f82f08f7452 --- /dev/null +++ b/functorch/docs/source/experimental.rst @@ -0,0 +1,12 @@ +functorch.experimental +====================== + +.. currentmodule:: functorch.experimental + +Experimental Function Transforms +-------------------------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + functionalize diff --git a/functorch/docs/source/functorch.rst b/functorch/docs/source/functorch.rst new file mode 100644 index 0000000000000..c073bc1b604ed --- /dev/null +++ b/functorch/docs/source/functorch.rst @@ -0,0 +1,60 @@ +functorch +========= + +.. currentmodule:: functorch + +Function Transforms +------------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + vmap + grad + grad_and_value + vjp + jvp + jacrev + jacfwd + hessian + +Utilities for working with torch.nn.Modules +------------------------------------------- + +In general, you can transform over a function that calls a ``torch.nn.Module``. +For example, the following is an example of computing a jacobian of a function +that takes three values and returns three values: + +.. code-block:: python + + model = torch.nn.Linear(3, 3) + + def f(x): + return model(x) + + x = torch.randn(3) + jacobian = jacrev(f)(x) + assert jacobian.shape == (3, 3) + +However, if you want to do something like compute a jacobian over the parameters +of the model, then there needs to be a way to construct a function where the +parameters are the inputs to the function. +That's what :func:`make_functional` and :func:`make_functional_with_buffers` are for: +given a ``torch.nn.Module``, these return a new function that accepts ``parameters`` +and the inputs to the Module's forward pass. + +.. autosummary:: + :toctree: generated + :nosignatures: + + make_functional + make_functional_with_buffers + combine_state_for_ensemble + +If you're looking for information on fixing Batch Norm modules, please follow the +guidance here + +.. toctree:: + :maxdepth: 1 + + batch_norm diff --git a/functorch/docs/source/index.rst b/functorch/docs/source/index.rst new file mode 100644 index 0000000000000..6c66a86e21395 --- /dev/null +++ b/functorch/docs/source/index.rst @@ -0,0 +1,78 @@ +:github_url: https://github.com/pytorch/functorch + +functorch +=================================== + +.. currentmodule:: functorch + +functorch is `JAX-like `_ composable function transforms for PyTorch. + +.. note:: + This library is currently in `beta `_. + What this means is that the features generally work (unless otherwise documented) + and we (the PyTorch team) are committed to bringing this library forward. However, the APIs + may change under user feedback and we don't have full coverage over PyTorch operations. + + If you have suggestions on the API or use-cases you'd like to be covered, please + open an github issue or reach out. We'd love to hear about how you're using the library. + +What are composable function transforms? +---------------------------------------- + +- A "function transform" is a higher-order function that accepts a numerical function + and returns a new function that computes a different quantity. + +- functorch has auto-differentiation transforms (``grad(f)`` returns a function that + computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)`` + returns a function that computes ``f`` over batches of inputs), and others. + +- These function transforms can compose with each other arbitrarily. For example, + composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that + stock PyTorch cannot efficiently compute today. + +Why composable function transforms? +----------------------------------- + +There are a number of use cases that are tricky to do in PyTorch today: + +- computing per-sample-gradients (or other per-sample quantities) +- running ensembles of models on a single machine +- efficiently batching together tasks in the inner-loop of MAML +- efficiently computing Jacobians and Hessians +- efficiently computing batched Jacobians and Hessians + +Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express the above without designing a separate subsystem for each. +This idea of composable function transforms comes from the `JAX framework `_. + +Read More +--------- + +Check out our `whirlwind tour `_ or some of our tutorials mentioned below. + + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + install + notebooks/whirlwind_tour.ipynb + ux_limitations + +.. toctree:: + :maxdepth: 2 + :caption: API Reference and Notes + + functorch + experimental + aot_autograd + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials + + notebooks/jacobians_hessians.ipynb + notebooks/ensembling.ipynb + notebooks/per_sample_grads.ipynb + notebooks/neural_tangent_kernels.ipynb + notebooks/aot_autograd_optimizations.ipynb + notebooks/minifier.ipynb diff --git a/functorch/docs/source/install.rst b/functorch/docs/source/install.rst new file mode 100644 index 0000000000000..9196893565d96 --- /dev/null +++ b/functorch/docs/source/install.rst @@ -0,0 +1,29 @@ +Install functorch +================= + +pip +--- + +To install functorch via pip, please first install +`PyTorch 1.11 `_ +and then run the following command: + +:: + + pip install functorch + +We currently support manylinux, x86 MacOS, and Windows, please reach out on +`GitHub `_ for other platforms. + +Colab +----- + +Please see `this colab for instructions. `_ + + +Building from source +-------------------- + +See our `README `_ +for instructions on how to build the functorch main development branch for the +latest and greatest. This requires an installation of the latest PyTorch nightly. diff --git a/functorch/docs/source/notebooks b/functorch/docs/source/notebooks new file mode 120000 index 0000000000000..d4082256dcfe3 --- /dev/null +++ b/functorch/docs/source/notebooks @@ -0,0 +1 @@ +../../notebooks/ \ No newline at end of file diff --git a/functorch/docs/source/ux_limitations.rst b/functorch/docs/source/ux_limitations.rst new file mode 100644 index 0000000000000..58ff259a61d37 --- /dev/null +++ b/functorch/docs/source/ux_limitations.rst @@ -0,0 +1,294 @@ +.. currentmodule:: functorch + +UX Limitations +============== + +functorch, like `JAX `_, has restrictions around +what can be transformed. In general, JAX’s limitations are that transforms +only work with pure functions: that is, functions where the output is completely +determined by the input and that do not involve side effects (like mutation). + +We have a similar guarantee: our transforms work well with pure functions. +However, we do support certain in-place operations. On one hand, writing code +compatible with functorch transforms may involve changing how you write PyTorch +code, on the other hand, you may find that our transforms let you express things +that were previously difficult to express in PyTorch. + +General limitations +------------------- + +All functorch transforms share a limitation in that a function should not +assign to global variables. Instead, all outputs to a function must be returned +from the function. This restriction comes from how functorch is implemented: +each transform wraps Tensor inputs in special functorch Tensor subclasses +that facilitate the transform. + +So, instead of the following: + +:: + + import torch + from functorch import grad + + # Don't do this + intermediate = None + + def f(x): + global intermediate + intermediate = x.sin() + z = intermediate.sin() + return z + + x = torch.randn([]) + grad_x = grad(f)(x) + +Please rewrite ``f`` to return ``intermediate``: + +:: + + def f(x): + intermediate = x.sin() + z = intermediate.sin() + return z, intermediate + + grad_x, intermediate = grad(f, has_aux=True)(x) + +torch.autograd APIs +------------------- + +If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad`` +or ``torch.autograd.backward`` inside of a function being transformed by +:func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`, +:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it. +If it is unable to do so, you'll receive an error message. + +This is a fundamental design limitation in how PyTorch's AD support is implemented +and the reason why we designed the functorch library. Please instead use the functorch +equivalents of the ``torch.autograd`` APIs: +- ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad`` +- ``torch.autograd.functional.jvp`` -> ``functorch.jvp`` +- ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd`` +- ``torch.autograd.functional.hessian`` -> ``functorch.hessian`` + +vmap limitations +---------------- + +.. note:: + :func:`vmap` is our most restrictive transform. + The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not + have these limitations. :func:`jacfwd` (and :func:`hessian`, which is + implemented with :func:`jacfwd`) is a composition of :func:`vmap` and + :func:`jvp` so it also has these limitations. + +``vmap(func)`` is a transform that returns a function that maps ``func`` over +some new dimension of each input Tensor. The mental model for vmap is that it is +like running a for-loop: for pure functions (i.e. in the absence of side +effects), ``vmap(f)(x)`` is equivalent to: + +:: + + torch.stack([f(x_i) for x_i in x.unbind(0)]) + +Mutation: Arbitrary mutation of Python data structures +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the presence of side effects, :func:`vmap` no longer acts like it is running +a for-loop. For example, the following function: + +:: + + def f(x, list): + list.pop() + print("hello!") + return x.sum(0) + + x = torch.randn(3, 1) + lst = [0, 1, 2, 3] + + result = vmap(f, in_dims=(0, None))(x, lst) + +will print "hello!" once and pop only one element from ``lst``. + + +:func:`vmap` executes `f` a single time, so all side effects only happen once. + +This is a consequence of how vmap is implemented. functorch has a special, +internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs, +turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``. +BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) +behavior for each PyTorch operator. + + +Mutation: in-place PyTorch Operations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +:func:`vmap` will raise an error if it encounters an unsupported PyTorch +in-place operation and it will succeed otherwise. Unsupported operations +are those that would cause a Tensor with more elements to be written to a +Tensor with fewer elements. Here's an example of how this can occur: + +:: + + def f(x, y): + x.add_(y) + return x + + x = torch.randn(1) + y = torch.randn(3) + + # Raises an error because `y` has fewer elements than `x`. + vmap(f, in_dims=(None, 0))(x, y) + +``x`` is a Tensor with one element, ``y`` is a Tensor with three elements. +``x + y`` has three elements (due to broadcasting), but attempting to write +three elements back into ``x``, which only has one element, raises an error +due to attempting to write three elements into a Tensor with a single element. + +There is no problem if the Tensor being written to has the same number of +elements (or more): + +:: + + def f(x, y): + x.add_(y) + return x + + x = torch.randn(3) + y = torch.randn(3) + expected = x + y + + # Does not raise an error because x and y have the same number of elements. + vmap(f, in_dims=(0, 0))(x, y) + assert torch.allclose(x, expected) + +Mutation: out= PyTorch Operations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations. +It will error out gracefully if it encounters that in your code. + +This is not a fundamental limitation; we could theoretically support this in the +future but we have chosen not to for now. + +Data-dependent Python control flow +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +We don't yet support ``vmap`` over data-dependent control flow. Data-dependent +control flow is when the condition of an if-statement, while-loop, or +for-loop is a Tensor that is being ``vmap``'ed over. For example, the +following will raise an error message: + +:: + + def relu(x): + if x > 0: + return x + return 0 + + x = torch.randn(3) + vmap(relu)(x) + +However, any control flow that is not dependent on the values in ``vmap``'ed +tensors will work: + +:: + + def custom_dot(x): + if x.dim() == 1: + return torch.dot(x, x) + return (x * x).sum() + + x = torch.randn(3) + vmap(custom_dot)(x) + +JAX supports transforming over +`data-dependent control flow `_ +using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``). +We're investigating adding equivalents of those to functorch +(open an issue on `GitHub `_ to voice your support!). + +Data-dependent operations (.item()) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +We do not (and will not) support vmap over a user-defined function that calls +``.item()`` on a Tensor. For example, the following will raise an error message: + +:: + + def f(x): + return x.item() + + x = torch.randn(3) + vmap(f)(x) + +Please try to rewrite your code to not use ``.item()`` calls. + +You may also encounter an error message about using ``.item()`` but you might +not have used it. In those cases, it is possible that PyTorch internally is +calling ``.item()`` -- please file an issue on GitHub and we'll fix +PyTorch internals. + +Dynamic shape operations (nonzero and friends) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``vmap(f)`` requires that ``f`` applied to every "example" in your input +returns a Tensor with the same shape. Operations such as ``torch.nonzero``, +``torch.is_nonzero`` are not supported and will error as a result. + +To see why, consider the following example: + +:: + + xs = torch.tensor([[0, 1, 2], [0, 0, 3]]) + vmap(torch.nonzero)(xs) + +``torch.nonzero(xs[0])`` returns a Tensor of shape 2; +but ``torch.nonzero(xs[1])`` returns a Tensor of shape 1. +We are unable to construct a single Tensor as an output; +the output would need to be a ragged Tensor (and PyTorch does not yet have +the concept of a ragged Tensor). + + +Randomness +---------- +The user's intention when calling a random operation can be unclear. Specifically, some users may want +the random behavior to be the same across batches while others may want it to differ across batches. +To address this, ``vmap`` takes a randomness flag. + +The flag can only be passed to vmap and can take on 3 values, "error," "different," or "same," defaulting +to error. Under "error" mode, any call to a random function will produce an error asking the user to use +one of the other two flags based on their use case. + +Under "different" randomness, elements in a batch produce different random values. For instance, + +:: + + def add_noise(x): + y = torch.randn(()) # y will be different across the batch + return x + y + + x = torch.ones(3) + result = vmap(add_noise, randomness="different")(x) # we get 3 different values + +Under "same" randomness, elements in a batch produce same random values. For instance, + +:: + + def add_noise(x): + y = torch.randn(()) # y will be the same across the batch + return x + y + + x = torch.ones(3) + result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times + + +.. warning:: + Our system only determine the randomness behavior of PyTorch operators and cannot control the + behavior of other libraries, like numpy. This is similar to JAX's limitations with their solutions + +.. note:: + Multiple vmap calls using either type of supported randomness will not produce + the same results. Like with standard PyTorch, a user can get randomness reproducibility through + either using ``torch.manual_seed()`` outside of vmap or by using generators. + +.. note:: + Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch + doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the + most common forms of randmoness that we see. If your use case does not fit these forms of randomness, please + file an issue. diff --git a/functorch/examples/.gitignore b/functorch/examples/.gitignore new file mode 100644 index 0000000000000..ca86b701abea5 --- /dev/null +++ b/functorch/examples/.gitignore @@ -0,0 +1 @@ +cifar10/ diff --git a/functorch/examples/compilation/README.md b/functorch/examples/compilation/README.md new file mode 100644 index 0000000000000..e66d9dd5761d9 --- /dev/null +++ b/functorch/examples/compilation/README.md @@ -0,0 +1,4 @@ +## Compilation Examples + +> **WARNING**: Compilation is currently very experimental and example +here don't work out of the box with functorch diff --git a/functorch/examples/compilation/eager_fusion.py b/functorch/examples/compilation/eager_fusion.py new file mode 100644 index 0000000000000..cc43a5ce19970 --- /dev/null +++ b/functorch/examples/compilation/eager_fusion.py @@ -0,0 +1,54 @@ +from functorch.compile import aot_function, tvm_compile +import torch +import time +import torch.utils + +a = torch.randn(2000, 1, 4, requires_grad=True) +b = torch.randn(1, 2000, 4) + + +def f(a): + return (a * b).sum(dim=0) + + +fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') +bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') +compiled_f = aot_function(f, fw_compiler, bw_compiler) + +# fw_compiler = lambda x, _: x +# bw_compiler = lambda x, _: x +iters = 10 +out = compiled_f(a) +out.sum().backward() + + +def bench(func): + begin = time.time() + for _ in range(iters): + out = func(a).sin() + out.sum().backward() + a.grad = None + print(time.time() - begin) + + +def bench_jax(): + import jax.numpy as jnp + import jax + jax_a = jnp.array(a.detach().numpy()) + jax_b = jnp.array(b.detach().numpy()) + + def f(a): + return jnp.sin((a * jax_b).sum(axis=[0])).sum() + jit_f = jax.jit(jax.grad(f)) + jit_f(jax_a) + begin = time.time() + for _ in range(iters): + out = jit_f(jax_a) + out.block_until_ready() + print(time.time() - begin) + # for + + +bench(f) +bench(compiled_f) +# bench_jax() diff --git a/functorch/examples/compilation/fuse_module.py b/functorch/examples/compilation/fuse_module.py new file mode 100644 index 0000000000000..dafbc80711a3a --- /dev/null +++ b/functorch/examples/compilation/fuse_module.py @@ -0,0 +1,55 @@ +import timeit +from functorch.compile import compiled_module, tvm_compile +import torch.nn as nn +import torch + + +def nop(f, _): + return f + + +fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') +bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') +fw_compiler = nop +bw_compiler = nop + + +def run(mod, input): + out = mod(input) + out.sum().backward() + grads = [p.grad for p in mod.parameters()] + return (out, *grads) + + +class Foo(nn.Module): + def __init__(self): + super(Foo, self).__init__() + self.param = nn.Parameter(torch.randn(1)) + self.register_buffer("buf", torch.randn(1)) + + def forward(self, x): + return (self.param * x + self.buf).sum(dim=0) + + +input = torch.randn(1) +mod = Foo() +compiled_mod = compiled_module(mod, fw_compiler, bw_compiler) + +for a, b in zip(run(mod, input), run(compiled_mod, input)): + torch.testing.assert_allclose(a, b) + +out = mod(input) +out.sum().backward() +mod.param.data -= mod.param.grad +compiled_mod.orig_module.param.data -= compiled_mod.orig_module.param.grad +compiled_mod.orig_module.param.grad = None + +for a, b in zip(run(mod, input), run(compiled_mod, input)): + torch.testing.assert_allclose(a, b) + +for _ in range(5): + i = 10000 + t = timeit.Timer("mod(input)", globals=globals()).timeit(10000) + print(f"eager {t/i*1e6}") + t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(10000) + print(f"compiled {t/i*1e6}") diff --git a/functorch/examples/compilation/linear_train.py b/functorch/examples/compilation/linear_train.py new file mode 100644 index 0000000000000..2d5f9d7dd37b4 --- /dev/null +++ b/functorch/examples/compilation/linear_train.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functorch import make_functional +from functorch.compile import nnc_jit +import torch +import torch.nn as nn +import time +torch._C._jit_override_can_fuse_on_cpu(True) + + +def bench(f, iters=100, warmup=10): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + f() + print((time.time() - begin)) + + +class Foo(nn.Module): + def __init__(self, num_layers=3, features=100): + super().__init__() + mods = [] + for _ in range(num_layers): + mods.append(nn.Linear(features, features, bias=False)) + self.mod = nn.Sequential(*mods) + + def forward(self, x): + return (self.mod(x)**2).sum() + + +batch_size = 16 +features = 64 +num_layers = 8 +inp = torch.randn((batch_size, features)) + +mod = Foo(num_layers, features) + +jit_mod = torch.jit.script(mod) + +func_model, weights = make_functional(mod) +lr = 1.0 + + +def functional_step(x, weights): + weights = [weight.detach().requires_grad_() for weight in weights] + out = func_model(weights, x) + out.backward() + new_weights = [weight - lr * weight.grad for weight in weights] + return out, new_weights + + +optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0) + + +def jit_step(x, weights): + optim.zero_grad() + loss = jit_mod(x) + loss.backward() + optim.step() + return loss, None + + +def train(train_step, weights): + torch.manual_seed(16) + train_step(inp, weights) + begin = time.time() + for itr in range(1000): + loss, weights = train_step(torch.randn(batch_size, features), weights) + if itr % 200 == 0: + print(f"Loss at {itr}: {loss}") + print("Time taken: ", time.time() - begin) + print() + + +grad_pt = functional_step +grad_nnc = nnc_jit(functional_step) + +print("Starting PT training") +train(grad_pt, weights) + +print("Starting NNC training") +train(grad_nnc, weights) + +print("Starting JIT training") +train(jit_step, None) diff --git a/functorch/examples/compilation/simple_function.py b/functorch/examples/compilation/simple_function.py new file mode 100644 index 0000000000000..14731c7c66661 --- /dev/null +++ b/functorch/examples/compilation/simple_function.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functorch import grad, make_fx +from functorch.compile import nnc_jit +import torch +import time + + +def f(x): + return torch.sin(x).sum() + + +inp = torch.randn(100) +grad_pt = grad(f) +grad_fx = make_fx(grad_pt)(inp) +grad_nnc = nnc_jit(grad_pt) + + +def bench(name, f, iters=10000, warmup=3): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + f() + print(f"{name}: ", time.time() - begin) + + +bench("Pytorch: ", lambda: grad_pt(inp)) +bench("FX: ", lambda: grad_fx(inp)) +bench("NNC: ", lambda: grad_nnc(inp)) diff --git a/functorch/examples/dp_cifar10/README.md b/functorch/examples/dp_cifar10/README.md new file mode 100644 index 0000000000000..dbf65c45dae8a --- /dev/null +++ b/functorch/examples/dp_cifar10/README.md @@ -0,0 +1,25 @@ +## Differential Privacy with ResNet18 + +### Differential Privacy +Differential privacy is a way of training models that ensures no attacker can figure out the training +data from the gradient updates of the model. Recently, a paper was published comparing the performance of +Opacus to a JAX-based system. + +[Original differential privacy paper](https://people.csail.mit.edu/asmith/PS/sensitivity-tcc-final.pdf) +[JAX-based differential privacy paper](https://arxiv.org/pdf/2010.09063.pdf) + +### Opacus +Opacus is a differential privacy library built for PyTorch. They have added hooks to PyTorch's +autograd that compute per sample gradients and a differential privacy engine that computes +differentially private weight updates. + +### Example +This example runs ResNet18 by either having Opacus compute the differentially private updates or +getting the per sample gradients using vmap and grad and computing the differentially private update +from those. + +As a caveat, the transforms version may not be computing the exact same values as the opacus version. +No verification has been done yet for this. + +### Requirements +These examples use Opacus version 1.0.1 and torchvision 0.11.2 diff --git a/functorch/examples/dp_cifar10/cifar10_opacus.py b/functorch/examples/dp_cifar10/cifar10_opacus.py new file mode 100644 index 0000000000000..bcd0aae8b9dba --- /dev/null +++ b/functorch/examples/dp_cifar10/cifar10_opacus.py @@ -0,0 +1,468 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +""" +Runs CIFAR10 training with differential privacy. +""" + +import argparse +import logging +import shutil +import sys +from datetime import datetime, timedelta + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.utils.data +import torchvision.transforms as transforms +from torchvision import models +from opacus import PrivacyEngine +from torchvision.datasets import CIFAR10 +from tqdm import tqdm + + +logging.basicConfig( + format="%(asctime)s:%(levelname)s:%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + stream=sys.stdout, +) +logger = logging.getLogger("ddp") +logger.setLevel(level=logging.INFO) + + +def save_checkpoint(state, is_best, filename="checkpoint.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def accuracy(preds, labels): + return (preds == labels).mean() + + +def train(args, model, train_loader, optimizer, privacy_engine, epoch, device): + start_time = datetime.now() + + model.train() + criterion = nn.CrossEntropyLoss() + + losses = [] + top1_acc = [] + + for i, (images, target) in enumerate(tqdm(train_loader)): + + images = images.to(device) + target = target.to(device) + + # compute output + output = model(images) + loss = criterion(output, target) + preds = np.argmax(output.detach().cpu().numpy(), axis=1) + labels = target.detach().cpu().numpy() + + # measure accuracy and record loss + acc1 = accuracy(preds, labels) + + losses.append(loss.item()) + top1_acc.append(acc1) + + # compute gradient and do SGD step + loss.backward() + + # make sure we take a step after processing the last mini-batch in the + # epoch to ensure we start the next epoch with a clean state + optimizer.step() + optimizer.zero_grad() + + if i % args.print_freq == 0: + if not args.disable_dp: + epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( + delta=args.delta, + alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)), + ) + print( + f"\tTrain Epoch: {epoch} \t" + f"Loss: {np.mean(losses):.6f} " + f"Acc@1: {np.mean(top1_acc):.6f} " + f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}" + ) + else: + print( + f"\tTrain Epoch: {epoch} \t" + f"Loss: {np.mean(losses):.6f} " + f"Acc@1: {np.mean(top1_acc):.6f} " + ) + train_duration = datetime.now() - start_time + return train_duration + + +def test(args, model, test_loader, device): + model.eval() + criterion = nn.CrossEntropyLoss() + losses = [] + top1_acc = [] + + with torch.no_grad(): + for images, target in tqdm(test_loader): + images = images.to(device) + target = target.to(device) + + output = model(images) + loss = criterion(output, target) + preds = np.argmax(output.detach().cpu().numpy(), axis=1) + labels = target.detach().cpu().numpy() + acc1 = accuracy(preds, labels) + + losses.append(loss.item()) + top1_acc.append(acc1) + + top1_avg = np.mean(top1_acc) + + print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ") + return np.mean(top1_acc) + + +# flake8: noqa: C901 +def main(): + args = parse_args() + + if args.debug >= 1: + logger.setLevel(level=logging.DEBUG) + + device = args.device + + if args.secure_rng: + try: + import torchcsprng as prng + except ImportError as e: + msg = ( + "To use secure RNG, you must install the torchcsprng package! " + "Check out the instructions here: https://github.com/pytorch/csprng#installation" + ) + raise ImportError(msg) from e + + generator = prng.create_random_device_generator("/dev/urandom") + + else: + generator = None + + augmentations = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + train_transform = transforms.Compose( + augmentations + normalize if args.disable_dp else normalize + ) + + test_transform = transforms.Compose(normalize) + + train_dataset = CIFAR10( + root=args.data_root, train=True, download=True, transform=train_transform + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=int(args.sample_rate * len(train_dataset)), + generator=generator, + num_workers=args.workers, + pin_memory=True, + ) + + test_dataset = CIFAR10( + root=args.data_root, train=False, download=True, transform=test_transform + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=args.batch_size_test, + shuffle=False, + num_workers=args.workers, + ) + + best_acc1 = 0 + + model = models.__dict__[args.architecture]( + pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c)) + ) + model = model.to(device) + + if args.optim == "SGD": + optimizer = optim.SGD( + model.parameters(), + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + elif args.optim == "RMSprop": + optimizer = optim.RMSprop(model.parameters(), lr=args.lr) + elif args.optim == "Adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr) + else: + raise NotImplementedError("Optimizer not recognized. Please check spelling") + + privacy_engine = None + if not args.disable_dp: + if args.clip_per_layer: + # Each layer has the same clipping threshold. The total grad norm is still bounded by `args.max_per_sample_grad_norm`. + n_layers = len( + [(n, p) for n, p in model.named_parameters() if p.requires_grad] + ) + max_grad_norm = [ + args.max_per_sample_grad_norm / np.sqrt(n_layers) + ] * n_layers + else: + max_grad_norm = args.max_per_sample_grad_norm + + privacy_engine = PrivacyEngine( + secure_mode=args.secure_rng, + ) + clipping = "per_layer" if args.clip_per_layer else "flat" + model, optimizer, train_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=train_loader, + noise_multiplier=args.sigma, + max_grad_norm=max_grad_norm, + clipping=clipping, + ) + + # Store some logs + accuracy_per_epoch = [] + time_per_epoch = [] + + for epoch in range(args.start_epoch, args.epochs + 1): + if args.lr_schedule == "cos": + lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1))) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + train_duration = train( + args, model, train_loader, optimizer, privacy_engine, epoch, device + ) + top1_acc = test(args, model, test_loader, device) + + # remember best acc@1 and save checkpoint + is_best = top1_acc > best_acc1 + best_acc1 = max(top1_acc, best_acc1) + + time_per_epoch.append(train_duration) + accuracy_per_epoch.append(float(top1_acc)) + + save_checkpoint( + { + "epoch": epoch + 1, + "arch": "Convnet", + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + filename=args.checkpoint_file + ".tar", + ) + + time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch] + avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds) + metrics = { + "accuracy": best_acc1, + "accuracy_per_epoch": accuracy_per_epoch, + "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))), + "time_per_epoch": time_per_epoch_seconds, + } + + logger.info( + "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n" + ) + logger.info(metrics) + +def parse_args(): + parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") + parser.add_argument( + "-j", + "--workers", + default=2, + type=int, + metavar="N", + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--epochs", + default=90, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", + default=1, + type=int, + metavar="N", + help="manual epoch number (useful on restarts)", + ) + parser.add_argument( + "-b", + "--batch-size-test", + default=256, + type=int, + metavar="N", + help="mini-batch size for test dataset (default: 256)" + ) + parser.add_argument( + "--sample-rate", + default=0.005, + type=float, + metavar="SR", + help="sample rate used for batch construction (default: 0.005)", + ) + parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial learning rate", + dest="lr", + ) + parser.add_argument( + "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum" + ) + parser.add_argument( + "--wd", + "--weight-decay", + default=0, + type=float, + metavar="W", + help="SGD weight decay", + dest="weight_decay", + ) + parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", + ) + parser.add_argument( + "--resume", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "-e", + "--evaluate", + dest="evaluate", + action="store_true", + help="evaluate model on validation set", + ) + parser.add_argument( + "--seed", default=None, type=int, help="seed for initializing training. " + ) + + parser.add_argument( + "--sigma", + type=float, + default=1.5, + metavar="S", + help="Noise multiplier (default 1.0)", + ) + parser.add_argument( + "-c", + "--max-per-sample-grad_norm", + type=float, + default=10.0, + metavar="C", + help="Clip per-sample gradients to this norm (default 1.0)", + ) + parser.add_argument( + "--disable-dp", + action="store_true", + default=False, + help="Disable privacy training and just train with vanilla SGD", + ) + parser.add_argument( + "--secure-rng", + action="store_true", + default=False, + help="Enable Secure RNG to have trustworthy privacy guarantees." + "Comes at a performance cost. Opacus will emit a warning if secure rng is off," + "indicating that for production use it's recommender to turn it on.", + ) + parser.add_argument( + "--delta", + type=float, + default=1e-5, + metavar="D", + help="Target delta (default: 1e-5)", + ) + + parser.add_argument( + "--checkpoint-file", + type=str, + default="checkpoint", + help="path to save check points", + ) + parser.add_argument( + "--data-root", + type=str, + default="../cifar10", + help="Where CIFAR10 is/will be stored", + ) + parser.add_argument( + "--log-dir", + type=str, + default="/tmp/stat/tensorboard", + help="Where Tensorboard log will be stored", + ) + parser.add_argument( + "--optim", + type=str, + default="SGD", + help="Optimizer to use (Adam, RMSprop, SGD)", + ) + parser.add_argument( + "--lr-schedule", type=str, choices=["constant", "cos"], default="cos" + ) + + parser.add_argument( + "--device", type=str, default="cuda", help="Device on which to run the code." + ) + + parser.add_argument( + "--architecture", + type=str, + default="resnet18", + help="model from torchvision to run", + ) + + parser.add_argument( + "--gn-groups", + type=int, + default=8, + help="Number of groups in GroupNorm", + ) + + parser.add_argument( + "--clip_per_layer", + action="store_true", + default=False, + help="Use static per-layer clipping with the same clipping threshold for each layer. Necessary for DDP. If `False` (default), uses flat clipping.", + ) + parser.add_argument( + "--debug", + type=int, + default=0, + help="debug level (default: 0)", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/functorch/examples/dp_cifar10/cifar10_transforms.py b/functorch/examples/dp_cifar10/cifar10_transforms.py new file mode 100644 index 0000000000000..9b591857055ff --- /dev/null +++ b/functorch/examples/dp_cifar10/cifar10_transforms.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +""" +Runs CIFAR10 training with differential privacy. +""" + +import argparse +import logging +import shutil +import sys +from datetime import datetime, timedelta + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.utils.data +import torchvision.transforms as transforms +from torchvision import models +from torchvision.datasets import CIFAR10 +from tqdm import tqdm + +import functorch +from functorch import vmap, grad_and_value +from functorch import make_functional + +# disable warning spam +functorch._C._set_vmap_fallback_warning_enabled(False) + +logging.basicConfig( + format="%(asctime)s:%(levelname)s:%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + stream=sys.stdout, +) +logger = logging.getLogger("ddp") +logger.setLevel(level=logging.INFO) + + +def save_checkpoint(state, is_best, filename="checkpoint.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + + +def accuracy(preds, labels): + return (preds == labels).mean() + + +def compute_norms(sample_grads): + batch_size = sample_grads[0].shape[0] + norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads] + norms = torch.stack(norms, dim=0).norm(2, dim=0) + return norms, batch_size + + +def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0): + sample_grads = tuple(param.grad_sample for param in model.parameters()) + + # step 0: compute the norms + sample_norms, batch_size = compute_norms(sample_grads) + + # step 1: compute clipping factors + clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6) + clip_factor = clip_factor.clamp(max=1.0) + + # step 2: clip + grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad) + for sample_grad in sample_grads) + + # step 3: add gaussian noise + stddev = max_per_sample_grad_norm * noise_multiplier + noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device) + for grad_param in grads) + grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads)) + + # step 4: assign the new grads, delete the sample grads + for param, param_grad in zip(model.parameters(), grads): + param.grad = param_grad/batch_size + del param.grad_sample + + +def train(args, model, train_loader, optimizer, epoch, device): + start_time = datetime.now() + + criterion = nn.CrossEntropyLoss() + + losses = [] + top1_acc = [] + + for i, (images, target) in enumerate(tqdm(train_loader)): + + images = images.to(device) + target = target.to(device) + + # Step 1: compute per-sample-grads + + # In order to use functional vmap+grad, we need to be able to + # pass the weights to a model. + func_model, weights = make_functional(model) + + # To use vmap+grad to compute per-sample-grads, the forward pass + # must be re-formulated on a single example. + # We use the `grad` operator to compute forward+backward on a single example, + # and finally `vmap` to do forward+backward on multiple examples. + def compute_loss_and_output(weights, image, target): + images = image.unsqueeze(0) + targets = target.unsqueeze(0) + output = func_model(weights, images) + loss = criterion(output, targets) + return loss, output.squeeze(0) + + # `grad(f)` is a functional API that returns a function `f'` that + # computes gradients by running both the forward and backward pass. + # We want to extract some intermediate + # values from the computation (i.e. the loss and output). + # + # To extract the loss, we use the `grad_and_value` API, that returns the + # gradient of the weights w.r.t. the loss and the loss. + # + # To extract the output, we use the `has_aux=True` flag. + # `has_aux=True` assumes that `f` returns a tuple of two values, + # where the first is to be differentiated and the second "auxiliary value" + # is not to be differentiated. `f'` returns the gradient w.r.t. the loss, + # the loss, and the auxiliary value. + grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True) + sample_grads, (sample_loss, output) = \ + vmap(grads_loss_output, (None, 0, 0))(weights, images, target) + loss = sample_loss.mean() + + for grad_sample, weight in zip(sample_grads, model.parameters()): + weight.grad_sample = grad_sample.detach() + + # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise + clip_and_accumulate_and_add_noise( + model, args.max_per_sample_grad_norm, args.sigma) + + preds = np.argmax(output.detach().cpu().numpy(), axis=1) + labels = target.detach().cpu().numpy() + losses.append(loss.item()) + + # measure accuracy and record loss + acc1 = accuracy(preds, labels) + + top1_acc.append(acc1) + + # make sure we take a step after processing the last mini-batch in the + # epoch to ensure we start the next epoch with a clean state + optimizer.step() + optimizer.zero_grad() + + if i % args.print_freq == 0: + print( + f"\tTrain Epoch: {epoch} \t" + f"Loss: {np.mean(losses):.6f} " + f"Acc@1: {np.mean(top1_acc):.6f} " + ) + train_duration = datetime.now() - start_time + return train_duration + + +def test(args, model, test_loader, device): + model.eval() + criterion = nn.CrossEntropyLoss() + losses = [] + top1_acc = [] + + with torch.no_grad(): + for images, target in tqdm(test_loader): + images = images.to(device) + target = target.to(device) + + output = model(images) + loss = criterion(output, target) + preds = np.argmax(output.detach().cpu().numpy(), axis=1) + labels = target.detach().cpu().numpy() + acc1 = accuracy(preds, labels) + + losses.append(loss.item()) + top1_acc.append(acc1) + + top1_avg = np.mean(top1_acc) + + print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ") + return np.mean(top1_acc) + + +# flake8: noqa: C901 +def main(): + args = parse_args() + + if args.debug >= 1: + logger.setLevel(level=logging.DEBUG) + + device = args.device + + if args.secure_rng: + try: + import torchcsprng as prng + except ImportError as e: + msg = ( + "To use secure RNG, you must install the torchcsprng package! " + "Check out the instructions here: https://github.com/pytorch/csprng#installation" + ) + raise ImportError(msg) from e + + generator = prng.create_random_device_generator("/dev/urandom") + + else: + generator = None + + augmentations = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + train_transform = transforms.Compose(normalize) + + test_transform = transforms.Compose(normalize) + + train_dataset = CIFAR10( + root=args.data_root, train=True, download=True, transform=train_transform + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=int(args.sample_rate * len(train_dataset)), + generator=generator, + num_workers=args.workers, + pin_memory=True, + ) + + test_dataset = CIFAR10( + root=args.data_root, train=False, download=True, transform=test_transform + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=args.batch_size_test, + shuffle=False, + num_workers=args.workers, + ) + + best_acc1 = 0 + + model = models.__dict__[args.architecture]( + pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c)) + ) + model = model.to(device) + + if args.optim == "SGD": + optimizer = optim.SGD( + model.parameters(), + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + elif args.optim == "RMSprop": + optimizer = optim.RMSprop(model.parameters(), lr=args.lr) + elif args.optim == "Adam": + optimizer = optim.Adam(model.parameters(), lr=args.lr) + else: + raise NotImplementedError("Optimizer not recognized. Please check spelling") + + # Store some logs + accuracy_per_epoch = [] + time_per_epoch = [] + + for epoch in range(args.start_epoch, args.epochs + 1): + if args.lr_schedule == "cos": + lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1))) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + train_duration = train( + args, model, train_loader, optimizer, epoch, device + ) + top1_acc = test(args, model, test_loader, device) + + # remember best acc@1 and save checkpoint + is_best = top1_acc > best_acc1 + best_acc1 = max(top1_acc, best_acc1) + + time_per_epoch.append(train_duration) + accuracy_per_epoch.append(float(top1_acc)) + + save_checkpoint( + { + "epoch": epoch + 1, + "arch": "Convnet", + "state_dict": model.state_dict(), + "best_acc1": best_acc1, + "optimizer": optimizer.state_dict(), + }, + is_best, + filename=args.checkpoint_file + ".tar", + ) + + time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch] + avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds) + metrics = { + "accuracy": best_acc1, + "accuracy_per_epoch": accuracy_per_epoch, + "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))), + "time_per_epoch": time_per_epoch_seconds, + } + + logger.info( + "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n" + ) + logger.info(metrics) + +def parse_args(): + parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") + parser.add_argument( + "-j", + "--workers", + default=2, + type=int, + metavar="N", + help="number of data loading workers (default: 2)", + ) + parser.add_argument( + "--epochs", + default=90, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", + default=1, + type=int, + metavar="N", + help="manual epoch number (useful on restarts)", + ) + parser.add_argument( + "-b", + "--batch-size-test", + default=256, + type=int, + metavar="N", + help="mini-batch size for test dataset (default: 256)" + ) + parser.add_argument( + "--sample-rate", + default=0.005, + type=float, + metavar="SR", + help="sample rate used for batch construction (default: 0.005)", + ) + parser.add_argument( + "--lr", + "--learning-rate", + default=0.1, + type=float, + metavar="LR", + help="initial learning rate", + dest="lr", + ) + parser.add_argument( + "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum" + ) + parser.add_argument( + "--wd", + "--weight-decay", + default=0, + type=float, + metavar="W", + help="SGD weight decay", + dest="weight_decay", + ) + parser.add_argument( + "-p", + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency (default: 10)", + ) + parser.add_argument( + "--resume", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "-e", + "--evaluate", + dest="evaluate", + action="store_true", + help="evaluate model on validation set", + ) + parser.add_argument( + "--seed", default=None, type=int, help="seed for initializing training. " + ) + + parser.add_argument( + "--sigma", + type=float, + default=1.5, + metavar="S", + help="Noise multiplier (default 1.0)", + ) + parser.add_argument( + "-c", + "--max-per-sample-grad_norm", + type=float, + default=10.0, + metavar="C", + help="Clip per-sample gradients to this norm (default 1.0)", + ) + parser.add_argument( + "--secure-rng", + action="store_true", + default=False, + help="Enable Secure RNG to have trustworthy privacy guarantees." + "Comes at a performance cost. Opacus will emit a warning if secure rng is off," + "indicating that for production use it's recommender to turn it on.", + ) + parser.add_argument( + "--delta", + type=float, + default=1e-5, + metavar="D", + help="Target delta (default: 1e-5)", + ) + + parser.add_argument( + "--checkpoint-file", + type=str, + default="checkpoint", + help="path to save check points", + ) + parser.add_argument( + "--data-root", + type=str, + default="../cifar10", + help="Where CIFAR10 is/will be stored", + ) + parser.add_argument( + "--log-dir", + type=str, + default="/tmp/stat/tensorboard", + help="Where Tensorboard log will be stored", + ) + parser.add_argument( + "--optim", + type=str, + default="SGD", + help="Optimizer to use (Adam, RMSprop, SGD)", + ) + parser.add_argument( + "--lr-schedule", type=str, choices=["constant", "cos"], default="cos" + ) + + parser.add_argument( + "--device", type=str, default="cuda", help="Device on which to run the code." + ) + + parser.add_argument( + "--architecture", + type=str, + default="resnet18", + help="model from torchvision to run", + ) + + parser.add_argument( + "--gn-groups", + type=int, + default=8, + help="Number of groups in GroupNorm", + ) + + parser.add_argument( + "--clip_per_layer", + action="store_true", + default=False, + help="Use static per-layer clipping with the same clipping threshold for each layer. Necessary for DDP. If `False` (default), uses flat clipping.", + ) + parser.add_argument( + "--debug", + type=int, + default=0, + help="debug level (default: 0)", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/functorch/examples/ensembling/parallel_train.py b/functorch/examples/ensembling/parallel_train.py new file mode 100644 index 0000000000000..1c0b0d6fe62e6 --- /dev/null +++ b/functorch/examples/ensembling/parallel_train.py @@ -0,0 +1,146 @@ +import argparse +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble + +# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a +# tutorial on Model Ensembling with JAX by Will Whitney. +# +# The original code comes with the following citation: +# @misc{Whitney2021Parallelizing, +# author = {William F. Whitney}, +# title = { {Parallelizing neural networks on one GPU with JAX} }, +# year = {2021}, +# url = {http://willwhitney.com/parallel-training-jax.html}, +# } + +# GOAL: Demonstrate that it is possible to use eager-mode vmap +# to parallelize training over models. + +parser = argparse.ArgumentParser(description="Functorch Ensembled Models") +parser.add_argument( + "--device", + type=str, + default="cpu", + help="CPU or GPU ID for this process (default: 'cpu')", +) +args = parser.parse_args() + +DEVICE = args.device + +# Step 1: Make some spirals + + +def make_spirals(n_samples, noise_std=0., rotations=1.): + ts = torch.linspace(0, 1, n_samples, device=DEVICE) + rs = ts ** 0.5 + thetas = rs * rotations * 2 * math.pi + signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 + labels = (signs > 0).to(torch.long).to(DEVICE) + + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + points = torch.stack([xs, ys], dim=1) + return points, labels + + +points, labels = make_spirals(100, noise_std=0.05) + + +# Step 2: Define two-layer MLP and loss function +class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + +loss_fn = nn.NLLLoss() + +# Step 3: Make the model functional(!!) and define a training function. +# NB: this mechanism doesn't exist in PyTorch today, but we want it to: +# https://github.com/pytorch/pytorch/issues/49171 +func_model, weights = make_functional(MLPClassifier().to(DEVICE)) + + +def train_step_fn(weights, batch, targets, lr=0.2): + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + + # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon) + # so we are going to re-implement SGD here. + new_weights = [] + with torch.no_grad(): + for grad_weight, weight in zip(grad_weights, weights): + new_weights.append(weight - grad_weight * lr) + + return loss, new_weights + + +# Step 4: Let's verify this actually trains. +# We should see the loss decrease. +def step4(): + global weights + for i in range(2000): + loss, weights = train_step_fn(weights, points, labels) + if i % 100 == 0: + print(loss) + + +step4() + +# Step 5: We're ready for multiple models. Let's define an init_fn +# that, given a number of models, returns to us all of the weights. + + +def init_fn(num_models): + models = [MLPClassifier().to(DEVICE) for _ in range(num_models)] + _, params, _ = combine_state_for_ensemble(models) + return params + +# Step 6: Now, can we try multiple models at the same time? +# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps +# on decreasing + + +def step6(): + parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None)) + batched_weights = init_fn(num_models=2) + for i in range(2000): + loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels) + if i % 200 == 0: + print(loss) + + +step6() + +# Step 7: Now, the flaw with step 6 is that we were training on the same exact +# data. This can lead to all of the models in the ensemble overfitting in the +# same way. The solution that http://willwhitney.com/parallel-training-jax.html +# applies is to randomly subset the data in a way that the models do not recieve +# exactly the same data in each training step! +# Because the goal of this doc is to show that we can use eager-mode vmap to +# achieve similar things as JAX, the rest of this is left as an exercise to the reader. + +# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html +# does, we used the following additional items that PyTorch does not have: +# 1. NN module functional API that turns a module into a (state, state_less_fn) pair +# 2. Functional optimizers +# 3. A "functional" grad API (that effectively wraps autograd.grad) +# 4. Composability between the functional grad API and torch.vmap. diff --git a/functorch/examples/lennard_jones/lennard_jones.py b/functorch/examples/lennard_jones/lennard_jones.py new file mode 100644 index 0000000000000..41192cb1ce38f --- /dev/null +++ b/functorch/examples/lennard_jones/lennard_jones.py @@ -0,0 +1,70 @@ +# This example was adapated from https://github.com/muhrin/milad +# It is licensed under the GLPv3 license. You can find a copy of it +# here: https://www.gnu.org/licenses/gpl-3.0.en.html . + +import torch +from torch import nn +from torch.nn.functional import mse_loss +from functorch import jacrev, vmap + +sigma = 0.5 +epsilon = 4. + + +def lennard_jones(r): + return epsilon * ((sigma / r)**12 - (sigma / r)**6) + + +def lennard_jones_force(r): + """Get magnitude of LJ force""" + return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)) + + +training_size = 1000 +r = torch.linspace(0.5, 2 * sigma, steps=training_size, requires_grad=True) + +# Create a bunch of vectors that point along positive-x +drs = torch.outer(r, torch.tensor([1.0, 0, 0])) +norms = torch.norm(drs, dim=1).reshape(-1, 1) +# Create training energies +training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) +# Create forces with random direction vectors +training_forces = torch.stack([force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]) + +model = nn.Sequential( + nn.Linear(1, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 1) +) + + +def make_prediction(model, drs): + norms = torch.norm(drs, dim=1).reshape(-1, 1) + energies = model(norms) + + network_derivs = vmap(jacrev(model))(norms).squeeze(-1) + forces = -network_derivs * drs / norms + return energies, forces + + +def loss_fn(energies, forces, predicted_energies, predicted_forces): + return mse_loss(energies, predicted_energies) + 0.01 * mse_loss(forces, predicted_forces) / 3 + + +optimiser = torch.optim.Adam(model.parameters(), lr=1e-3) + +for epoch in range(400): + optimiser.zero_grad() + energies, forces = make_prediction(model, drs) + loss = loss_fn(training_energies, training_forces, energies, forces) + loss.backward(retain_graph=True) + optimiser.step() + + if epoch % 20 == 0: + print(loss.cpu().item()) diff --git a/functorch/examples/maml_omniglot/.gitignore b/functorch/examples/maml_omniglot/.gitignore new file mode 100644 index 0000000000000..783c4e5c7b46e --- /dev/null +++ b/functorch/examples/maml_omniglot/.gitignore @@ -0,0 +1,2 @@ +omniglot/ +maml-accs.png diff --git a/functorch/examples/maml_omniglot/README.md b/functorch/examples/maml_omniglot/README.md new file mode 100644 index 0000000000000..dfb6077814bfe --- /dev/null +++ b/functorch/examples/maml_omniglot/README.md @@ -0,0 +1,11 @@ +# Omniglot MAML examples + +In this directory we've provided some examples of traning omniglot that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400). + +They can be run via `python {filename}`. + +`maml-omniglot-higher.py` uses the [facebookresearch/higher](https://github.com/facebookresearch/higher) metalearning package and is the reference implementation. It runs all of its tasks sequentially. + +`maml-omniglot-transforms.py` uses functorch. It runs all of its tasks in parallel. In theory this should lead to some speedups, but we haven't finished implementing all the rules for vmap that would actually make training faster. + +`maml-omniglot-ptonly.py` is an implementation of `maml-omniglot-transforms.py` that runs all of its tasks sequentially (and also doesn't use the higher package). diff --git a/functorch/examples/maml_omniglot/maml-omniglot-higher.py b/functorch/examples/maml_omniglot/maml-omniglot-higher.py new file mode 100755 index 0000000000000..8f6e017f212ad --- /dev/null +++ b/functorch/examples/maml_omniglot/maml-omniglot-higher.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use higher to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 + +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py + +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +from support.omniglot_loaders import OmniglotNShot +import higher +import torch.optim as optim +import torch.nn.functional as F +from torch import nn +import torch +import matplotlib.pyplot as plt +import argparse +import time + +import pandas as pd +import numpy as np +import matplotlib as mpl +mpl.use('Agg') +plt.style.use('bmh') + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument( + '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument( + '--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--device', type=str, help='device', default='cuda') + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + + # Set up the Omniglot loader. + device = args.device + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + device=device, + ) + + # Create a vanilla PyTorch neural network that will be + # automatically monkey-patched by higher later. + # Before higher, models could *not* be created like this + # and the parameters needed to be manually updated and copied + # for the updates. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + Flatten(), + nn.Linear(64, args.n_way)).to(device) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(net.parameters(), lr=1e-3) + + log = [] + for epoch in range(100): + train(db, net, device, meta_opt, epoch, log) + test(db, net, device, epoch, log) + plot(log) + + +def train(db, net, device, meta_opt, epoch, log): + net.train() + n_train_iter = db.x_train.shape[0] // db.batchsz + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num, setsz, c_, h, w = x_spt.size() + querysz = x_qry.size(1) + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) + + qry_losses = [] + qry_accs = [] + meta_opt.zero_grad() + for i in range(task_num): + with higher.innerloop_ctx( + net, inner_opt, copy_initial_weights=False + ) as (fnet, diffopt): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + # higher is able to automatically keep copies of + # your network's parameters as they are being updated. + for _ in range(n_inner_iter): + spt_logits = fnet(x_spt[i]) + spt_loss = F.cross_entropy(spt_logits, y_spt[i]) + diffopt.step(spt_loss) + + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = fnet(x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_losses.append(qry_loss.detach()) + qry_acc = (qry_logits.argmax( + dim=1) == y_qry[i]).sum().item() / querysz + qry_accs.append(qry_acc) + + # print([b.shape for b in fnet[1].buffers()]) + + # Update the model's meta-parameters to optimize the query + # losses across all of the tasks sampled in this batch. + # This unrolls through the gradient steps. + qry_loss.backward() + + meta_opt.step() + qry_losses = sum(qry_losses) / task_num + qry_accs = 100. * sum(qry_accs) / task_num + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + if batch_idx % 4 == 0: + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + ) + + log.append({ + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + }) + + +def test(db, net, device, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + net.train() + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + for _ in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + + task_num, setsz, c_, h, w = x_spt.size() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) + + for i in range(task_num): + with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + for _ in range(n_inner_iter): + spt_logits = fnet(x_spt[i]) + spt_loss = F.cross_entropy(spt_logits, y_spt[i]) + diffopt.step(spt_loss) + + # The query loss and acc induced by these parameters. + qry_logits = fnet(x_qry[i]).detach() + qry_loss = F.cross_entropy( + qry_logits, y_qry[i], reduction='none') + qry_losses.append(qry_loss.detach()) + qry_accs.append( + (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + + qry_losses = torch.cat(qry_losses).mean().item() + qry_accs = 100. * torch.cat(qry_accs).float().mean().item() + print( + f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + ) + log.append({ + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }) + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(6, 4)) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(70, 100) + fig.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +# Won't need this after this PR is merged in: +# https://github.com/pytorch/pytorch/pull/22245 +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +if __name__ == '__main__': + main() diff --git a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py new file mode 100755 index 0000000000000..594237ee7d6e5 --- /dev/null +++ b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use higher to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 + +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py + +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +from support.omniglot_loaders import OmniglotNShot +from functorch import make_functional_with_buffers +import torch.optim as optim +import torch.nn.functional as F +from torch import nn +import torch +import matplotlib.pyplot as plt +import argparse +import time + +import pandas as pd +import numpy as np +import matplotlib as mpl +mpl.use('Agg') +plt.style.use('bmh') + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument( + '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument( + '--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--device', type=str, help='device', default='cuda') + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + + # Set up the Omniglot loader. + device = args.device + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + device=device, + ) + + # Create a vanilla PyTorch neural network that will be + # automatically monkey-patched by higher later. + # Before higher, models could *not* be created like this + # and the parameters needed to be manually updated and copied + # for the updates. + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, momentum=1, affine=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + Flatten(), + nn.Linear(64, args.n_way)).to(device) + + net.train() + fnet, params, buffers = make_functional_with_buffers(net) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(params, lr=1e-3) + + log = [] + for epoch in range(100): + train(db, [params, buffers, fnet], device, meta_opt, epoch, log) + test(db, [params, buffers, fnet], device, epoch, log) + plot(log) + + +def train(db, net, device, meta_opt, epoch, log): + params, buffers, fnet = net + n_train_iter = db.x_train.shape[0] // db.batchsz + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num, setsz, c_, h, w = x_spt.size() + querysz = x_qry.size(1) + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + + # Initialize the inner optimizer to adapt the parameters to + # the support set. + n_inner_iter = 5 + # inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) + + qry_losses = [] + qry_accs = [] + meta_opt.zero_grad() + for i in range(task_num): + # Optimize the likelihood of the support set by taking + # gradient steps w.r.t. the model's parameters. + # This adapts the model's meta-parameters to the task. + new_params = params + for _ in range(n_inner_iter): + spt_logits = fnet(new_params, buffers, x_spt[i]) + spt_loss = F.cross_entropy(spt_logits, y_spt[i]) + grads = torch.autograd.grad(spt_loss, new_params, create_graph=True) + new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = fnet(new_params, buffers, x_qry[i]) + qry_loss = F.cross_entropy(qry_logits, y_qry[i]) + qry_losses.append(qry_loss.detach()) + qry_acc = (qry_logits.argmax( + dim=1) == y_qry[i]).sum().item() / querysz + qry_accs.append(qry_acc) + + # Update the model's meta-parameters to optimize the query + # losses across all of the tasks sampled in this batch. + # This unrolls through the gradient steps. + qry_loss.backward() + + meta_opt.step() + qry_losses = sum(qry_losses) / task_num + qry_accs = 100. * sum(qry_accs) / task_num + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + if batch_idx % 4 == 0: + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + ) + + log.append({ + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + }) + + +def test(db, net, device, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + [params, buffers, fnet] = net + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + for batch_idx in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + task_num, setsz, c_, h, w = x_spt.size() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + for i in range(task_num): + new_params = params + for _ in range(n_inner_iter): + spt_logits = fnet(new_params, buffers, x_spt[i]) + spt_loss = F.cross_entropy(spt_logits, y_spt[i]) + grads = torch.autograd.grad(spt_loss, new_params) + new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + + # The query loss and acc induced by these parameters. + qry_logits = fnet(new_params, buffers, x_qry[i]).detach() + qry_loss = F.cross_entropy( + qry_logits, y_qry[i], reduction='none') + qry_losses.append(qry_loss.detach()) + qry_accs.append( + (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + + qry_losses = torch.cat(qry_losses).mean().item() + qry_accs = 100. * torch.cat(qry_accs).float().mean().item() + print( + f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + ) + log.append({ + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }) + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(6, 4)) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(70, 100) + fig.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +# Won't need this after this PR is merged in: +# https://github.com/pytorch/pytorch/pull/22245 +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +if __name__ == '__main__': + main() diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py new file mode 100755 index 0000000000000..2ac6663faaf78 --- /dev/null +++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use higher to do Model Agnostic Meta Learning (MAML) +for few-shot Omniglot classification. +For more details see the original MAML paper: +https://arxiv.org/abs/1703.03400 + +This code has been modified from Jackie Loong's PyTorch MAML implementation: +https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py + +Our MAML++ fork and experiments are available at: +https://github.com/bamos/HowToTrainYourMAMLPytorch +""" + +from support.omniglot_loaders import OmniglotNShot +from functorch import make_functional_with_buffers, vmap, grad +import functorch +import torch.optim as optim +import torch.nn.functional as F +from torch import nn +import torch +import matplotlib.pyplot as plt +import argparse +import time +import functools + +import pandas as pd +import numpy as np +import matplotlib as mpl +mpl.use('Agg') +plt.style.use('bmh') + + +# Squash the warning spam +functorch._C._set_vmap_fallback_warning_enabled(False) + + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument('--n_way', type=int, help='n way', default=5) + argparser.add_argument( + '--k_spt', type=int, help='k shot for support set', default=5) + argparser.add_argument( + '--k_qry', type=int, help='k shot for query set', default=15) + argparser.add_argument( + '--device', type=str, help='device', default='cuda') + argparser.add_argument( + '--task_num', + type=int, + help='meta batch size, namely task num', + default=32) + argparser.add_argument('--seed', type=int, help='random seed', default=1) + args = argparser.parse_args() + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + + # Set up the Omniglot loader. + device = args.device + db = OmniglotNShot( + '/tmp/omniglot-data', + batchsz=args.task_num, + n_way=args.n_way, + k_shot=args.k_spt, + k_query=args.k_qry, + imgsz=28, + device=device, + ) + + # Create a vanilla PyTorch neural network. + inplace_relu = True + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.BatchNorm2d(64, affine=True, track_running_stats=False), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, affine=True, track_running_stats=False), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, affine=True, track_running_stats=False), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, args.n_way)).to(device) + + net.train() + + # Given this module we've created, rip out the parameters and buffers + # and return a functional version of the module. `fnet` is stateless + # and can be called with `fnet(params, buffers, args, kwargs)` + fnet, params, buffers = make_functional_with_buffers(net) + + # We will use Adam to (meta-)optimize the initial parameters + # to be adapted. + meta_opt = optim.Adam(params, lr=1e-3) + + log = [] + for epoch in range(100): + train(db, [params, buffers, fnet], device, meta_opt, epoch, log) + test(db, [params, buffers, fnet], device, epoch, log) + plot(log) + + +# Trains a model for n_inner_iter using the support and returns a loss +# using the query. +def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): + params, buffers, fnet = net + querysz = x_qry.size(0) + + def compute_loss(new_params, buffers, x, y): + logits = fnet(new_params, buffers, x) + loss = F.cross_entropy(logits, y) + return loss + + new_params = params + for _ in range(n_inner_iter): + grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) + new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + + # The final set of adapted parameters will induce some + # final loss and accuracy on the query dataset. + # These will be used to update the model's meta-parameters. + qry_logits = fnet(new_params, buffers, x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry) + qry_acc = (qry_logits.argmax( + dim=1) == y_qry).sum() / querysz + + return qry_loss, qry_acc + + +def train(db, net, device, meta_opt, epoch, log): + params, buffers, fnet = net + n_train_iter = db.x_train.shape[0] // db.batchsz + + for batch_idx in range(n_train_iter): + start_time = time.time() + # Sample a batch of support and query images and labels. + x_spt, y_spt, x_qry, y_qry = db.next() + + task_num, setsz, c_, h, w = x_spt.size() + + n_inner_iter = 5 + meta_opt.zero_grad() + + # In parallel, trains one model per task. There is a support (x, y) + # for each task and a query (x, y) for each task. + compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter) + qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry) + + # Compute the maml loss by summing together the returned losses. + qry_losses.sum().backward() + + meta_opt.step() + qry_losses = qry_losses.detach().sum() / task_num + qry_accs = 100. * qry_accs.sum() / task_num + i = epoch + float(batch_idx) / n_train_iter + iter_time = time.time() - start_time + if batch_idx % 4 == 0: + print( + f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' + ) + + log.append({ + 'epoch': i, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'train', + 'time': time.time(), + }) + + +def test(db, net, device, epoch, log): + # Crucially in our testing procedure here, we do *not* fine-tune + # the model during testing for simplicity. + # Most research papers using MAML for this task do an extra + # stage of fine-tuning here that should be added if you are + # adapting this code for research. + [params, buffers, fnet] = net + n_test_iter = db.x_test.shape[0] // db.batchsz + + qry_losses = [] + qry_accs = [] + + for batch_idx in range(n_test_iter): + x_spt, y_spt, x_qry, y_qry = db.next('test') + task_num, setsz, c_, h, w = x_spt.size() + + # TODO: Maybe pull this out into a separate module so it + # doesn't have to be duplicated between `train` and `test`? + n_inner_iter = 5 + + for i in range(task_num): + new_params = params + for _ in range(n_inner_iter): + spt_logits = fnet(new_params, buffers, x_spt[i]) + spt_loss = F.cross_entropy(spt_logits, y_spt[i]) + grads = torch.autograd.grad(spt_loss, new_params) + new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + + # The query loss and acc induced by these parameters. + qry_logits = fnet(new_params, buffers, x_qry[i]).detach() + qry_loss = F.cross_entropy( + qry_logits, y_qry[i], reduction='none') + qry_losses.append(qry_loss.detach()) + qry_accs.append( + (qry_logits.argmax(dim=1) == y_qry[i]).detach()) + + qry_losses = torch.cat(qry_losses).mean().item() + qry_accs = 100. * torch.cat(qry_accs).float().mean().item() + print( + f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' + ) + log.append({ + 'epoch': epoch + 1, + 'loss': qry_losses, + 'acc': qry_accs, + 'mode': 'test', + 'time': time.time(), + }) + + +def plot(log): + # Generally you should pull your plotting code out of your training + # script but we are doing it here for brevity. + df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(6, 4)) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(70, 100) + fig.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'maml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/functorch/examples/maml_omniglot/support/omniglot_loaders.py b/functorch/examples/maml_omniglot/support/omniglot_loaders.py new file mode 100644 index 0000000000000..b712b9b31e435 --- /dev/null +++ b/functorch/examples/maml_omniglot/support/omniglot_loaders.py @@ -0,0 +1,302 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py + +import torchvision.transforms as transforms +from PIL import Image +import numpy as np + +import torch +import torch.utils.data as data +import os +import os.path +import errno + + +class Omniglot(data.Dataset): + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + ''' + The items are (filename,category). The index of all the categories can be found in self.idx_classes + Args: + - root: the directory where the dataset will be stored + - transform: how to transform the input + - target_transform: how to transform the target + - download: need to download the dataset + ''' + + def __init__(self, root, transform=None, target_transform=None, + download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') + + self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes = index_classes(self.all_items) + + def __getitem__(self, index): + filename = self.all_items[index][0] + img = str.join('/', [self.all_items[index][2], filename]) + + target = self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ + os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) + + def download(self): + from six.moves import urllib + import zipfile + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print("== Unzip from " + file_path + " to " + file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print("Download finished.") + + +def find_classes(root_dir): + retour = [] + for (root, dirs, files) in os.walk(root_dir): + for f in files: + if (f.endswith("png")): + r = root.split('/') + lr = len(r) + retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) + print("== Found %d items " % len(retour)) + return retour + + +def index_classes(items): + idx = {} + for i in items: + if i[1] not in idx: + idx[i[1]] = len(idx) + print("== Found %d classes" % len(idx)) + return idx + + +class OmniglotNShot: + + def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): + """ + Different from mnistNShot, the + :param root: + :param batchsz: task num + :param n_way: + :param k_shot: + :param k_qry: + :param imgsz: + """ + + self.resize = imgsz + self.device = device + if not os.path.isfile(os.path.join(root, 'omniglot.npy')): + # if root/data.npy does not exist, just download it + self.x = Omniglot( + root, download=True, + transform=transforms.Compose( + [lambda x: Image.open(x).convert('L'), + lambda x: x.resize((imgsz, imgsz)), + lambda x: np.reshape(x, (imgsz, imgsz, 1)), + lambda x: np.transpose(x, [2, 0, 1]), + lambda x: x / 255.]), + ) + + temp = dict() # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} + for (img, label) in self.x: + if label in temp.keys(): + temp[label].append(img) + else: + temp[label] = [img] + + self.x = [] + for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs + self.x.append(np.array(imgs)) + + # as different class may have different number of imgs + self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] + # each character contains 20 imgs + print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] + temp = [] # Free memory + # save all dataset into npy file. + np.save(os.path.join(root, 'omniglot.npy'), self.x) + print('write into omniglot.npy.') + else: + # if data.npy exists, just load it. + self.x = np.load(os.path.join(root, 'omniglot.npy')) + print('load from omniglot.npy.') + + # [1623, 20, 84, 84, 1] + # TODO: can not shuffle here, we must keep training and test set distinct! + self.x_train, self.x_test = self.x[:1200], self.x[1200:] + + # self.normalization() + + self.batchsz = batchsz + self.n_cls = self.x.shape[0] # 1623 + self.n_way = n_way # n way + self.k_shot = k_shot # k shot + self.k_query = k_query # k query + assert (k_shot + k_query) <= 20 + + # save pointer of current read batch in total cache + self.indexes = {"train": 0, "test": 0} + self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached + print("DB: train", self.x_train.shape, "test", self.x_test.shape) + + self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached + "test": self.load_data_cache(self.datasets["test"])} + + def normalization(self): + """ + Normalizes our data, to have a mean of 0 and sdt of 1 + """ + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + self.x_train = (self.x_train - self.mean) / self.std + self.x_test = (self.x_test - self.mean) / self.std + + self.mean = np.mean(self.x_train) + self.std = np.std(self.x_train) + self.max = np.max(self.x_train) + self.min = np.min(self.x_train) + + # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) + + def load_data_cache(self, data_pack): + """ + Collects several batches data for N-shot learning + :param data_pack: [cls_num, 20, 84, 84, 1] + :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks + """ + # take 5 way 1 shot as example: 5 * 1 + setsz = self.k_shot * self.n_way + querysz = self.k_query * self.n_way + data_cache = [] + + # print('preload next 50 caches of batchsz of batch.') + for sample in range(10): # num of episodes + + x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] + for i in range(self.batchsz): # one batch means one set + + x_spt, y_spt, x_qry, y_qry = [], [], [], [] + selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) + + for j, cur_class in enumerate(selected_cls): + + selected_img = np.random.choice(20, self.k_shot + self.k_query, False) + + # meta-training and meta-test + x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]]) + x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]]) + y_spt.append([j for _ in range(self.k_shot)]) + y_qry.append([j for _ in range(self.k_query)]) + + # shuffle inside a batch + perm = np.random.permutation(self.n_way * self.k_shot) + x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm] + y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] + perm = np.random.permutation(self.n_way * self.k_query) + x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm] + y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] + + # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] + x_spts.append(x_spt) + y_spts.append(y_spt) + x_qrys.append(x_qry) + y_qrys.append(y_qry) + + # [b, setsz, 1, 84, 84] + x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize) + y_spts = np.array(y_spts).astype(np.int).reshape(self.batchsz, setsz) + # [b, qrysz, 1, 84, 84] + x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize) + y_qrys = np.array(y_qrys).astype(np.int).reshape(self.batchsz, querysz) + + x_spts, y_spts, x_qrys, y_qrys = [ + torch.from_numpy(z).to(self.device) for z in + [x_spts, y_spts, x_qrys, y_qrys] + ] + + data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) + + return data_cache + + def next(self, mode='train'): + """ + Gets next batch from the dataset with name. + :param mode: The name of the splitting (one of "train", "val", "test") + :return: + """ + # update cache if indexes is larger cached num + if self.indexes[mode] >= len(self.datasets_cache[mode]): + self.indexes[mode] = 0 + self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) + + next_batch = self.datasets_cache[mode][self.indexes[mode]] + self.indexes[mode] += 1 + + return next_batch diff --git a/functorch/examples/maml_regression/evjang.py b/functorch/examples/maml_regression/evjang.py new file mode 100644 index 0000000000000..fcd7a3b292406 --- /dev/null +++ b/functorch/examples/maml_regression/evjang.py @@ -0,0 +1,122 @@ +# Eric Jang originally wrote an implementation of MAML in JAX +# (https://github.com/ericjang/maml-jax). +# We translated his implementation from JAX to PyTorch. + +import matplotlib.pyplot as plt +import math +import torch +import numpy as np +from torch.nn import functional as F +import matplotlib as mpl +mpl.use('Agg') + + +def net(x, params): + x = F.linear(x, params[0], params[1]) + x = F.relu(x) + + x = F.linear(x, params[2], params[3]) + x = F.relu(x) + + x = F.linear(x, params[4], params[5]) + return x + + +params = [ + torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), + torch.Tensor(40).zero_().requires_grad_(), + + torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(40).zero_().requires_grad_(), + + torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(1).zero_().requires_grad_(), +] + +opt = torch.optim.Adam(params, lr=1e-3) +alpha = 0.1 + +K = 20 +losses = [] +num_tasks = 4 + + +def sample_tasks(outer_batch_size, inner_batch_size): + # Select amplitude and phase for the task + As = [] + phases = [] + for _ in range(outer_batch_size): + As.append(np.random.uniform(low=0.1, high=.5)) + phases.append(np.random.uniform(low=0., high=np.pi)) + + def get_batch(): + xs, ys = [], [] + for A, phase in zip(As, phases): + x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + y = A * np.sin(x + phase) + xs.append(x) + ys.append(y) + return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() + x2, y2 = get_batch() + return x1, y1, x2, y2 + + +for it in range(20000): + loss2 = 0.0 + opt.zero_grad() + + def get_loss_for_task(x1, y1, x2, y2): + f = net(x1, params) + loss = F.mse_loss(f, y1) + + # create_graph=True because computing grads here is part of the forward pass. + # We want to differentiate through the SGD update steps and get higher order + # derivatives in the backward pass. + grads = torch.autograd.grad(loss, params, create_graph=True) + new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))] + + v_f = net(x2, new_params) + return F.mse_loss(v_f, y2) + + task = sample_tasks(num_tasks, K) + inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)] + loss2 = sum(inner_losses) / len(inner_losses) + loss2.backward() + + opt.step() + + if it % 100 == 0: + print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + losses.append(loss2.detach()) + +t_A = torch.tensor(0.0).uniform_(0.1, 0.5) +t_b = torch.tensor(0.0).uniform_(0.0, math.pi) + +t_x = torch.empty(4, 1).uniform_(-5, 5) +t_y = t_A * torch.sin(t_x + t_b) + +opt.zero_grad() + +t_params = params +for k in range(5): + t_f = net(t_x, t_params) + t_loss = F.l1_loss(t_f, t_y) + + grads = torch.autograd.grad(t_loss, t_params, create_graph=True) + t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(params))] + + +test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1) +test_y = t_A * torch.sin(test_x + t_b) + +test_f = net(test_x, t_params) + +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') +plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.legend() +plt.savefig('maml-sine.png') +plt.figure() +plt.plot(np.convolve(losses, [.05] * 20)) +plt.savefig('losses.png') diff --git a/functorch/examples/maml_regression/evjang_transforms.py b/functorch/examples/maml_regression/evjang_transforms.py new file mode 100644 index 0000000000000..bc20e2c41c961 --- /dev/null +++ b/functorch/examples/maml_regression/evjang_transforms.py @@ -0,0 +1,129 @@ +# Eric Jang originally wrote an implementation of MAML in JAX +# (https://github.com/ericjang/maml-jax). +# We translated his implementation from JAX to PyTorch. + +from functorch import grad, vmap +import matplotlib.pyplot as plt +import math +import torch +import numpy as np +from torch.nn import functional as F +import matplotlib as mpl +mpl.use('Agg') + + +def net(params, x): + x = F.linear(x, params[0], params[1]) + x = F.relu(x) + + x = F.linear(x, params[2], params[3]) + x = F.relu(x) + + x = F.linear(x, params[4], params[5]) + return x + + +params = [ + torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), + torch.Tensor(40).zero_().requires_grad_(), + + torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(40).zero_().requires_grad_(), + + torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), + torch.Tensor(1).zero_().requires_grad_(), +] + +# TODO: use F.mse_loss + + +def mse_loss(x, y): + return torch.mean((x - y) ** 2) + + +opt = torch.optim.Adam(params, lr=1e-3) +alpha = 0.1 + +K = 20 +losses = [] +num_tasks = 4 + + +def sample_tasks(outer_batch_size, inner_batch_size): + # Select amplitude and phase for the task + As = [] + phases = [] + for _ in range(outer_batch_size): + As.append(np.random.uniform(low=0.1, high=.5)) + phases.append(np.random.uniform(low=0., high=np.pi)) + + def get_batch(): + xs, ys = [], [] + for A, phase in zip(As, phases): + x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + y = A * np.sin(x + phase) + xs.append(x) + ys.append(y) + return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() + x2, y2 = get_batch() + return x1, y1, x2, y2 + + +for it in range(20000): + loss2 = 0.0 + opt.zero_grad() + + def get_loss_for_task(x1, y1, x2, y2): + def inner_loss(params, x1, y1): + f = net(params, x1) + loss = mse_loss(f, y1) + return loss + + grads = grad(inner_loss)(tuple(params), x1, y1) + new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))] + + v_f = net(new_params, x2) + return mse_loss(v_f, y2) + + task = sample_tasks(num_tasks, K) + inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3]) + loss2 = sum(inner_losses) / len(inner_losses) + loss2.backward() + + opt.step() + + if it % 100 == 0: + print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + losses.append(loss2.detach()) + +t_A = torch.tensor(0.0).uniform_(0.1, 0.5) +t_b = torch.tensor(0.0).uniform_(0.0, math.pi) + +t_x = torch.empty(4, 1).uniform_(-5, 5) +t_y = t_A * torch.sin(t_x + t_b) + +opt.zero_grad() + +t_params = params +for k in range(5): + t_f = net(t_params, t_x) + t_loss = F.l1_loss(t_f, t_y) + + grads = torch.autograd.grad(t_loss, t_params, create_graph=True) + t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(params))] + + +test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1) +test_y = t_A * torch.sin(test_x + t_b) + +test_f = net(t_params, test_x) + +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') +plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.legend() +plt.savefig('maml-sine.png') +plt.figure() +plt.plot(np.convolve(losses, [.05] * 20)) +plt.savefig('losses.png') diff --git a/functorch/examples/maml_regression/evjang_transforms_module.py b/functorch/examples/maml_regression/evjang_transforms_module.py new file mode 100644 index 0000000000000..d1483550a29e6 --- /dev/null +++ b/functorch/examples/maml_regression/evjang_transforms_module.py @@ -0,0 +1,126 @@ +# Eric Jang originally wrote an implementation of MAML in JAX +# (https://github.com/ericjang/maml-jax). +# We translated his implementation from JAX to PyTorch. + +from functorch import grad, vmap, make_functional +import matplotlib.pyplot as plt +import math +import torch +import numpy as np +from torch import nn +from torch.nn import functional as F +import matplotlib as mpl +mpl.use('Agg') + + +class ThreeLayerNet(nn.Module): + def __init__(self): + super(ThreeLayerNet, self).__init__() + self.fc1 = nn.Linear(1, 40) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(40, 40) + self.relu2 = nn.ReLU() + self.fc3 = nn.Linear(40, 1) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + x = self.relu2(x) + x = self.fc3(x) + return x + +# TODO: Use F.mse_loss + + +def mse_loss(x, y): + return torch.mean((x - y) ** 2) + + +net, params = make_functional(ThreeLayerNet()) +opt = torch.optim.Adam(params, lr=1e-3) +alpha = 0.1 + +K = 20 +losses = [] +num_tasks = 4 + + +def sample_tasks(outer_batch_size, inner_batch_size): + # Select amplitude and phase for the task + As = [] + phases = [] + for _ in range(outer_batch_size): + As.append(np.random.uniform(low=0.1, high=.5)) + phases.append(np.random.uniform(low=0., high=np.pi)) + + def get_batch(): + xs, ys = [], [] + for A, phase in zip(As, phases): + x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + y = A * np.sin(x + phase) + xs.append(x) + ys.append(y) + return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) + x1, y1 = get_batch() + x2, y2 = get_batch() + return x1, y1, x2, y2 + + +for it in range(20000): + loss2 = 0.0 + opt.zero_grad() + + def get_loss_for_task(x1, y1, x2, y2): + def inner_loss(params, x1, y1): + f = net(params, x1) + loss = mse_loss(f, y1) + return loss + + grads = grad(inner_loss)(params, x1, y1) + new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))] + + v_f = net(new_params, x2) + return mse_loss(v_f, y2) + + task = sample_tasks(num_tasks, K) + inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3]) + loss2 = sum(inner_losses) / len(inner_losses) + loss2.backward() + + opt.step() + + if it % 100 == 0: + print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) + losses.append(loss2.detach()) + +t_A = torch.tensor(0.0).uniform_(0.1, 0.5) +t_b = torch.tensor(0.0).uniform_(0.0, math.pi) + +t_x = torch.empty(4, 1).uniform_(-5, 5) +t_y = t_A * torch.sin(t_x + t_b) + +opt.zero_grad() + +t_params = params +for k in range(5): + t_f = net(t_params, t_x) + t_loss = F.l1_loss(t_f, t_y) + + grads = torch.autograd.grad(t_loss, t_params, create_graph=True) + t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(params))] + + +test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1) +test_y = t_A * torch.sin(test_x + t_b) + +test_f = net(t_params, test_x) + +plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') +plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') +plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') +plt.legend() +plt.savefig('maml-sine.png') +plt.figure() +plt.plot(np.convolve(losses, [.05] * 20)) +plt.savefig('losses.png') diff --git a/functorch/functorch/__init__.py b/functorch/functorch/__init__.py new file mode 100644 index 0000000000000..4f5adc5a7ad14 --- /dev/null +++ b/functorch/functorch/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from . import _C + +# Monkey patch PyTorch. This is a hack, we should try to upstream +# these pieces. +from ._src import monkey_patching as _monkey_patching + +# Top-level APIs. Please think carefully before adding something to the +# top-level namespace: +# - private helper functions should go into functorch._src +# - very experimental things should go into functorch.experimental +# - compilation related things should go into functorch.compile + +# functorch transforms +from ._src.vmap import vmap +from ._src.eager_transforms import ( + grad, grad_and_value, vjp, jacrev, jvp, jacfwd, hessian, +) +from ._src.python_key import make_fx + +# utilities. Maybe these should go in their own namespace in the future? +from ._src.make_functional import ( + make_functional_with_buffers, + make_functional, + combine_state_for_ensemble, + FunctionalModule, + FunctionalModuleWithBuffers, +) + +try: + from .version import __version__ # noqa: F401 +except ImportError: + pass diff --git a/functorch/functorch/_src/__init__.py b/functorch/functorch/_src/__init__.py new file mode 100644 index 0000000000000..10a55772ab58b --- /dev/null +++ b/functorch/functorch/_src/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py new file mode 100644 index 0000000000000..41356b20644e8 --- /dev/null +++ b/functorch/functorch/_src/aot_autograd.py @@ -0,0 +1,713 @@ +from contextlib import contextmanager, nullcontext +import torch +import torch.nn as nn +from torch import Tensor +from functorch import make_fx +from torch.fx import immutable_collections +from torch._subclasses import FakeTensorMode +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch.nn.utils import _stateless +from functorch._C import CompileCache +from functorch.experimental import functionalize +from . import config +from .decompositions import register_decomposition +from .partitioners import default_partition +from .named_members_polyfill import _named_parameters, _named_buffers +from typing import Callable, List, Dict, Any, Tuple, Optional +from functools import wraps + +try: + from torchdynamo import disable as disable_torchdynamo +except ImportError: + def disable_torchdynamo(x): + return x + +pytree._register_pytree_node( + immutable_collections.immutable_list, + lambda x: (list(x), None), + lambda x, c: immutable_collections.immutable_list(x), +) +pytree._register_pytree_node( + immutable_collections.immutable_dict, + lambda x: (list(x.values()), list(x.keys())), + lambda x, c: immutable_collections.immutable_dict( + {key: value for key, value in zip(c, x)} + ), +) + +# TODO - move this to PyTorch core. This overrides the pytree implementation for +# dict to maintain parity with Deepmind pytree. +Context = Any + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + keys = sorted(d.keys()) + values = [d[key] for key in keys] + return values, keys + + +def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: + return {key: value for key, value in zip(context, values)} + + +pytree._register_pytree_node(dict, _dict_flatten, _dict_unflatten) + +aten = torch.ops.aten + + +@contextmanager +def preserve_rng_state(): + rng_state = torch.clone(torch.random.get_rng_state()) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + try: + yield + finally: + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + + +def create_joint_forward_backward(fn): + def joint_forward_backward( + primals: List[Any], tangents: List[Any] + ) -> Tuple[List[Any], List[Any]]: + # Call the forward pass + outs = fn(*primals) + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + assert len(tangents) == len(outs) + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs, tangents): + if isinstance(out, Tensor) and out.requires_grad: + needed_outs.append(out) + needed_tangents.append(tangent) + backward_out = [] + # Call the backwards pass + if grad_primals: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + return outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads + ] + + return joint_forward_backward + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +aot_autograd_decompositions = {} + + +@register_decomposition(aten._reshape_alias, aot_autograd_decompositions) +def _reshape_alias(x, shape, strides): + return aten.view(x, shape) + + +@register_decomposition(aten.new_zeros, aot_autograd_decompositions) +def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None): + return torch.zeros(size, dtype=inp.dtype, device=inp.device) + + +@register_decomposition(aten.new_full, aot_autograd_decompositions) +def new_full(inp, size, value, dtype=None, layout=None, device=None, pin_memory=None): + return torch.full(size, value, dtype=inp.dtype, device=inp.device) + + +graph_being_compiled: str = None +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_graph_being_compiled() -> str: + """ + Returns the name of the graph being compiled. + """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}_{graph_being_compiled}_{nth_graph}" + + +@contextmanager +def track_graph_compiling(graph_name, increment_index=False): + global graph_being_compiled + graph_being_compiled = graph_name + yield + if increment_index: + global nth_graph + nth_graph += 1 + graph_being_compiled = None + + +def create_aot_autograd_function( + flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state +): + """ + Traces the forward and backward graphs of the attr:`flat_fn` to generate a + joint graph. The joint graph is an Fx graph with Aten ops. Please refer to + the tracing mechanism to understand the graph capturing details. + + The joint graph is then passed through attr:`partition_fn` to isolate the + forward and backward portions, which are then respectively compiled via the + provided attr:`fw_compiler` and attr:`bw_compiler`. + + The resulting compiled forward and backward graphs are then wrapped up in a + ``torch.autograd.Function`` object. + """ + if decompositions is None: + decompositions = {} + joint_forward_backward = create_joint_forward_backward(flat_fn) + + compiled_fw = None + compiled_bw = None + num_outs = None + + class CompiledFunction(torch.autograd.Function): + @staticmethod + @disable_torchdynamo + def forward(ctx, *flat_tensor_args): + nonlocal compiled_fw, compiled_bw, num_outs + # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph. + # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed. + old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) + if compiled_fw is None: + flat_tensor_args = pytree.tree_map( + lambda x: x.detach().requires_grad_(x.requires_grad) + if isinstance(x, Tensor) else x, flat_tensor_args + ) + fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext() + with preserve_rng_state(), fake_mode as mode: + # Set input tensors that require grad to leaves + fake_flat_tensor_args = pytree.tree_map( + lambda x: mode.from_tensor(x) if mode else x + if isinstance(x, Tensor) else x, flat_tensor_args + ) + with torch.set_grad_enabled(grad_state): + out = flat_fn(*fake_flat_tensor_args) + out = pytree.tree_map( + lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out + ) + + if isinstance(out, (list, tuple)): + num_outs = len(out) + else: + num_outs = 1 + + joint_inputs = (fake_flat_tensor_args, out) + aot_decompositions = {**aot_autograd_decompositions, **decompositions} + with torch.set_grad_enabled(grad_state): + fx_g = make_fx(joint_forward_backward, aot_decompositions)( + *joint_inputs + ) + + if config.use_functionalize: + # Functionalize the foward backward graph. First create a + # fake fn to make functionalize happy + def fake_fn(primals, tangents): + return fx_g(primals, tangents) + fx_g = make_fx(functionalize(fake_fn))(*joint_inputs) + + if config.debug_joint: + print(fx_g.code) + + with track_graph_compiling("joint"): + fw_module, bw_module = partition_fn(fx_g, joint_inputs) + + if config.debug_graphs: + print(fw_module.code, bw_module.code) + + with track_graph_compiling("forward"): + compiled_fw = fw_compiler(fw_module, flat_tensor_args) + fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) + if config.debug_partitioner: + activation_sizes = 0 + for out in fw_outs[num_outs:]: + if isinstance(out, torch.Tensor): + activation_sizes += out.storage().nbytes() + print(f"Real Activations Stored(GB): {activation_sizes/1e9}") + + bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] + with track_graph_compiling("backward", True): + compiled_bw = bw_compiler(bw_module, bw_args) + else: + fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) + torch._C._jit_set_autocast_mode(old_jit_autocast_flag) + ctx.save_for_backward(*fw_outs[num_outs:]) + return tuple(fw_outs[0:num_outs]) + + @staticmethod + @disable_torchdynamo + def backward(ctx, *flat_args): + # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph. + # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed. + old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) + contiguous_args = [t.contiguous() for t in flat_args] + # contiguous_args = [t for t in flat_args] + out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) + torch._C._jit_set_autocast_mode(old_jit_autocast_flag) + return tuple(out) + + return CompiledFunction + + +class _CompileCache(CompileCache): + pass + + +# using a C++-based pytree reduces the overhead by about 50% +try: + import tree + + HAS_TREE = True +except ImportError: + HAS_TREE = False +compile_cache = None + + +# Inspired by autodidax (thanks!) +class PytreeThunk: + spec = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple = ( + None # if the output spec is a tuple/list, we won't bother unflattening it. + ) + is_really_simple = None # if the output spec is a LeafSpec + + def set(self, spec): + assert self.spec is None or self.spec == spec + self.spec = spec + if type(self.spec) in [tuple, list] and all( + isinstance(i, pytree.LeafSpec) for i in spec.children_specs + ): + self.is_simple = True + if isinstance(self.spec, pytree.LeafSpec): + self.is_really_simple = True + + def unflatten(self, x): + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + return pytree.tree_unflatten(x, self.spec) + + +def filter_tensor_and_static_args(args, static_argnums): + """ + Separate out the tensor and static args. Also, for the static args, store + the hash. + """ + tensor_args = [] + static_args = [] + static_args_hashed = [] + for idx, arg in enumerate(args): + if idx not in static_argnums: + tensor_args.append(arg) + else: + static_args.append(arg) + static_args_hashed.append(arg.__hash__()) + return tensor_args, static_args, static_args_hashed + + +def rearrange(tensor_args, static_args, static_argnums): + """ + Generate the args as per the original spec. static_argnums is sorted. + """ + tensor_index = 0 + static_index = 0 + index = 0 + args = [] + assert len(static_args) == len(static_argnums) + while tensor_index < len(tensor_args) and static_index < len(static_args): + if index == static_argnums[static_index]: + args.append(static_args[static_index]) + static_index += 1 + else: + args.append(tensor_args[tensor_index]) + tensor_index += 1 + index += 1 + + while tensor_index < len(tensor_args): + args.append(tensor_args[tensor_index]) + tensor_index += 1 + + while static_index < len(static_args): + args.append(static_args[static_index]) + static_index += 1 + + return args + + +KNOWN_TYPES = [torch.Tensor, int, str, float, bool] + + +def aot_function( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + hasher_type: str = "StaticShapeHasher", + static_argnums: Optional[Tuple[int]] = None, +) -> Callable: + """ + Traces the forward and backward graph of :attr:`fn` using torch dispatch + mechanism, and then compiles the generated forward and backward graphs + through :attr:`fw_compiler` and :attr:`bw_compiler`. + + :func:`aot_function` traces the forward and backward graph ahead of time, + and generates a joint forward and backward graph. :attr:`partition_fn` is + then used to separate out forward and backward graphs. The partitioner + function can be used to perform optimizations such as recomputation. One can + set `decompositions` dictionary to decompose the operators into a sequence + of core or simpler operators supported by the backend compilers. + + :func:`aot_function` uses a compilation cache, based on input tensor + properties, to detect when there is a need of recompilation. By default, its + behavior is static, i.e., it recompiles if shape of any input tensor + changes. + + :attr:`static_argnums` allows user to mark the arguments of the original + :attr:`fn` as static. This is useful when an argument is a non-tensor, e.g., + ``int`` or ``bool``. A change in the actual value of static arg causes + recompilation. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Callable): A Python function that takes one ore more arguments. Must + return one or more Tensors. + fw_compiler (Callable): A Python function that accepts an Fx graph with + Aten ops and input args, and returns a Callable that semantically is + equivalent to the input Fx graph. + bw_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + partition_fn (Callable): A Python function that takes a joint forward + and backward graph, and partitions it into separate forward and + backward graphs. + decompositions (Dict): A dictionary to define the decomposition of + larger Aten ops into simpler or core Aten ops. + static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark + the arguments of the function as static. + + Returns: + Returns a ``Callable`` that retains the eager behavior of the original + :attr:`fn`, but with forward and backward graph compiled via + :attr:`fw_compile` and :attr:`bw_compile`. + + A simple example usage of :func:`aot_function` is as follows. This example + will print the forward and backward graphs of the function ``fn`` + + >>> fn = lambda x : x.sin().cos() + >>> def print_compile_fn(fx_module, args): + >>> print(fx_module) + >>> return fx_module + >>> aot_fn = aot_function(fn, print_compile_fn) + >>> x = torch.randn(4, 5, requires_grad=True) + >>> aot_fn(x) + + The static argnums are used to mark the non-tensor arguments as static. An + example is as follows where the dropout probability is as argument to the + original function. + + >>> def fn(input, bias, residual, p: float): + >>> a = torch.add(input, bias) + >>> b = torch.nn.functional.dropout(a, p, training=True) + >>> c = b + residual + >>> return c + >>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,)) + + """ + global compile_cache + if compile_cache is None: + compile_cache = CompileCache() + if bw_compiler is None: + bw_compiler = fw_compiler + cached_res = None + + fn_id = id(fn) + fw_compiler_id = id(fw_compiler) + bw_compiler_id = id(bw_compiler) + + if isinstance(static_argnums, int): + static_argnums = [static_argnums] + elif static_argnums is not None and len(static_argnums) == 0: + static_argnums = None + elif static_argnums is not None: + static_argnums = list(static_argnums) + static_argnums.sort() + + @wraps(fn) + def returned_function(*args, **kwargs): + global compile_cache + nonlocal cached_res + + # Separate out static args if static_argnums is present + tensor_args = args + static_args = [] + # TODO - move the hashing part of static_args to C++. + static_args_hashed = [] + if static_argnums is not None: + ( + tensor_args, + static_args, + static_args_hashed, + ) = filter_tensor_and_static_args(args, static_argnums) + + # Now flatten the tensor args + if HAS_TREE: + flat_tensor_args = tree.flatten((tensor_args, kwargs)) + else: + flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs)) + + # Check if the fn is already compiled + num_tensor_args = len(flat_tensor_args) + flat_args_for_cache = flat_tensor_args + static_args_hashed + cached_res = compile_cache.at( + fn_id, + fw_compiler_id, + bw_compiler_id, + num_tensor_args, + hasher_type, + *flat_args_for_cache, + ) + + # Compile the function and save it in the cache + if cached_res is None: + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_tensor_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. + nonlocal out_spec, static_args + + tensor_args, kwargs = pytree.tree_unflatten( + flat_tensor_args, tensor_args_spec + ) + if static_argnums is None: + args = tensor_args + else: + args = rearrange(tensor_args, static_args, static_argnums) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + compiled_fn = create_aot_autograd_function( + flat_fn, + fw_compiler, + bw_compiler, + partition_fn, + decompositions, + grad_state=torch.is_grad_enabled(), + ).apply + cached_res = (compiled_fn, out_spec) + + # Save the compiled_fn in the cache + compile_cache.insert( + fn_id, + fw_compiler_id, + bw_compiler_id, + num_tensor_args, + hasher_type, + cached_res, + *flat_args_for_cache, + ) + + cached_fn, out_spec = cached_res + out = cached_fn(*flat_tensor_args) + return out_spec.unflatten(out) + + return returned_function + + +def num_of_recompilations(): + """ + Returns the numbers of recompilations since the last time cache was cleared. + This is equivalent to the number of entries in the compilation cache. + """ + global compile_cache + if compile_cache is None: + return 0 + return compile_cache.size() + + +def clear_compile_cache(): + """ + Clears the compilation cache. + """ + global compile_cache + if compile_cache is not None: + compile_cache.clear() + compile_cache = None + + +def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: + """ + Traces the forward and backward graph of :attr:`mod` using torch dispatch + tracing mechanism. It is wrapper function, that underneath uses + :func:`aot_function` to perform tracing and compilation. + + :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs + to a new callable which is then compiled through :func:`aot_function`. + + .. warning:: + This API is experimental and likely to change. + + Args: + mod (Callable): A ``nn.Module`` module. + args : args to be passed to :func:`aot_function` + kwargs : kwargs to be passed to :func:`aot_function` + + Returns: + Returns a ``nn.Module`` that retains the eager behavior of the original + :attr:`mod`, but with forward and backward graph compiled. + + """ + + def functional_call(named_params, named_buffers, *args, **kwargs): + params_and_buffers = {**named_params, **named_buffers} + return _stateless.functional_call(mod, params_and_buffers, args, kwargs) + + compiled_f = aot_function(functional_call, *args, **kwargs) + + class AOTModule(nn.Module): + def __init__(self): + super(AOTModule, self).__init__() + self.orig_module = mod + + def forward(self, *args, **kwargs): + return compiled_f( + dict(_named_parameters(mod, remove_duplicate=False)), + dict(_named_buffers(mod, remove_duplicate=False)), + *args, + **kwargs, + ) + + return AOTModule() + + +def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + ######################################################### + + params = { + **dict(_named_parameters(mod, remove_duplicate=False)), + **dict(_named_buffers(mod, remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = tuple(params_flat) + params_len = len(params_flat) + + def functional_call(*args, **kwargs): + with _stateless.reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ): + out = mod(*args[params_len:], **kwargs) + if not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the ouputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + def aot_function_simplified( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + hasher_type: str = "StaticShapeHasher", + static_argnums: Optional[Tuple[int]] = None, + ) -> Callable: + assert static_argnums is None + if bw_compiler is None: + bw_compiler = fw_compiler + compiled_fn = create_aot_autograd_function( + fn, + fw_compiler, + bw_compiler, + partition_fn, + decompositions, + grad_state=torch.is_grad_enabled(), + ).apply + + return compiled_fn + + compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs) + + if top_kwargs: + def forward(*args, **kwargs): + return compiled_f( + *params_flat, + *args, + **kwargs, + ) + else: + def forward(*args): + return compiled_f( + *params_flat, + *args, + ) + + forward.zero_grad = mod.zero_grad + return forward + + +compiled_function = aot_function +compiled_module = aot_module diff --git a/functorch/functorch/_src/benchmark_utils.py b/functorch/functorch/_src/benchmark_utils.py new file mode 100644 index 0000000000000..1b44101837f1d --- /dev/null +++ b/functorch/functorch/_src/benchmark_utils.py @@ -0,0 +1,200 @@ +import time +import os +import json + +import torch +from torch.profiler import profile, ProfilerActivity + + +def synchronize(): + pass + + +class NullContext: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +def dump_chrome_trace(f, input, trace_filename, optimize_ctx, activities, num_runs=1, + devices=None, kwargs_for_f=None, kwargs_for_profiler=None): + """ + Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx] + [num_runs] times to [trace_filename]. + + [activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA. + Return total runtime without the profiler + + Outputs to trace_filename + """ + + if devices is None: + devices = ["cuda"] + + global synchronize + if devices != ["cpu"] and torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + + if kwargs_for_f is None: + kwargs_for_f = {} + if kwargs_for_profiler is None: + kwargs_for_profiler = {} + + with optimize_ctx: + torch.manual_seed(1337) + for _ in range(5): # warmup runs + f(input, **kwargs_for_f) + synchronize() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + t1 = time.perf_counter() + timing = t1 - t0 + + with profile(activities=activities, **kwargs_for_profiler) as prof: + with optimize_ctx: + synchronize() + torch.manual_seed(1337) + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + prof.export_chrome_trace(trace_filename) + + return timing + + +def get_chrome_trace_events(filename): + f = open(filename) + data = json.load(f) + events = data["traceEvents"] + return events + + +def is_gpu_compute_event(event): + global gpu_pids + return "pid" in event and event["pid"] in gpu_pids and "ph" in event and event["ph"] == "X" + + +def get_sorted_gpu_events(events): + sorted_gpu_events = [] + for event in events: + if(not is_gpu_compute_event(event)): + continue + sorted_gpu_events.append(event) + return sorted(sorted_gpu_events, key=lambda x: x["ts"]) + + +def get_duration(sorted_gpu_events): + if len(sorted_gpu_events) == 0: + return 0 + event = sorted_gpu_events[0] + current_end_time = event["ts"] + event["dur"] + total_duration = event["dur"] + for event in sorted_gpu_events[1:]: + start_time = max(event["ts"], current_end_time) + end_time = event["ts"] + event["dur"] + total_duration = total_duration + max(end_time - start_time, 0) + current_end_time = max(current_end_time, end_time) + return total_duration + + +def get_sorted_gpu_mm_conv_events(events): + def is_mm_conv_event(event): + return "name" in event and ("gemm" in event["name"] or "conv" in event["name"] + or "cutlass" in event["name"] or "wgrad" in event["name"]) + gpu_events = get_sorted_gpu_events(events) + sorted_events = [] + for event in gpu_events: + if(not is_mm_conv_event(event)): + continue + sorted_events.append(event) + return sorted_events + + +gpu_pids = [] + + +def compute_utilization(filename: str, total_length: float): + """ + Process the chrome traces outputs by the pytorch profiler to compute GPU Utilization + and percent of times spent on matmal and convolution + + Args: + filename(str): Name of chrome traces file produced by pytorch profiler + + total_length(float): total length of the process without profiler in second + + Return: + tuple: (GPU Utilization, percent of time spent on matmal and convolution) + """ + events = get_chrome_trace_events(filename) + + # get pids of GPU events + global gpu_pids + gpu_pids = [] + for event in events: + if "name" not in event: + continue + if event["name"] == 'process_labels' and "GPU" in event["args"]["labels"]: + gpu_pids.append(event["pid"]) + + total_length = total_length * 1e6 + sorted_gpu_events = get_sorted_gpu_events(events) + utilization = get_duration(sorted_gpu_events) / total_length + + sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events) + mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length + + return utilization, mm_conv_utilization + + +def benchmark_utilization(f, input, trace_folder, optimize_ctx=None, trace_file_name="tmp_chrome_trace", num_runs=1): + """ + Benchmark the GPU Utilization and percent of time spent on matmal and convolution operations of + running f(input, **kwargs_for_f) with [optimize_ctx] [num_runs] times. + It will produce a chrome trace file in trace_folder/trace_file_name.json + + Example: + + ``` + def f(a): + return a.sum() + a = torch.rand(2**20, device="cuda") + utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") + ``` + + Args: + f: function to benchmark + + input: input to :attr:`f` + + trace_folder: name of the folder to store the chrome trace + + optimize_ctx: the context in which f will run + + trace_file_name: name of the dumped chrome trace file, default to "tmp_chrome_trace" + + num_runs: number of times to run f, excluding the warm-up runs, default to 1. + + Return: + tuple: (GPU Utilization, percent of time spent on matmal and convolution) + + """ + isExist = os.path.exists(trace_folder) + if not isExist: + os.makedirs(trace_folder) + print("create folder " + trace_folder) + + if optimize_ctx is None: + optimize_ctx = NullContext() + + chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json") + total_length = dump_chrome_trace(f, input, chrome_trace_file_name, optimize_ctx, + [ProfilerActivity.CUDA], num_runs=num_runs, devices="cuda") + utilization, mm_conv_utilization = compute_utilization(chrome_trace_file_name, total_length) + + return utilization, mm_conv_utilization diff --git a/functorch/functorch/_src/compile_utils.py b/functorch/functorch/_src/compile_utils.py new file mode 100644 index 0000000000000..caebe100ecfb7 --- /dev/null +++ b/functorch/functorch/_src/compile_utils.py @@ -0,0 +1,80 @@ + +import torch +import torch.fx as fx +from torch.utils._pytree import tree_flatten + +aten = torch.ops.aten + + +def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + +rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma, + aten.bernoulli, aten.multinomial, aten.native_dropout, + aten.normal, aten.poisson, aten.binomial, aten.rrelu, + aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm] + + +# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph +def fx_graph_cse(fx_g: torch.fx.graph.Graph): + new_graph = fx.Graph() + env = {} # map from node in the old graph to node in the new graph + hash_env = {} # map from hash to a node in the new graph + token_map = {} # map from hash to token + for n in fx_g.nodes: + # The placeholder, output, and get_attr nodes are copied to the new grpah without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in rand_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs memebrs to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, torch.fx.node.Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + return new_graph + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. + + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() diff --git a/functorch/functorch/_src/compilers.py b/functorch/functorch/_src/compilers.py new file mode 100644 index 0000000000000..e860c8a7c7669 --- /dev/null +++ b/functorch/functorch/_src/compilers.py @@ -0,0 +1,464 @@ +import torch +import torch.fx as fx +import torch.nn as nn +from functools import partial +from typing import Callable, Iterable, Optional, Tuple, Union + +from .aot_autograd import aot_function, aot_module +from .decompositions import get_decompositions +from .partitioners import draw_graph, min_cut_rematerialization_partition, default_partition +from .compile_utils import strip_overloads +import time +import os +import pickle +import random +import copy +import logging + + +# These canonicalizations are needed here (and not decompositions), as the ops +# we're trying to canonicalize to CompositeImplicitAutograd. +def _canonicalize(fx_g): + for node in fx_g.graph.nodes: + if node.target == torch.ops.aten._to_copy: + node.target = torch.ops.aten.to + fx_g.recompile() + return fx_g + + +def ts_compile(fx_g: fx.GraphModule, _) -> Callable: + """ + Compiles the :attr:`fx_g` with Torchscript compiler. + + .. warning:: + This API is experimental and likely to change. + + Args: + fx_g(fx.GraphModule): The input Fx graph module to be compiled. + + Returns: + Torch scripted model. + """ + for node in fx_g.graph.nodes: + if (node.target == torch.ops.aten._to_copy and len(node.args) == 1 + and len(node.kwargs) == 1 and 'dtype' in node.kwargs): + node.target = torch.ops.aten.to + + for node in fx_g.graph.nodes: + new_kwargs = {} + for k, v in node.kwargs.items(): + if isinstance(v, torch.device): + v = v.type + new_kwargs[k] = v + node.kwargs = new_kwargs + + strip_overloads(fx_g) + + fx_g.graph.lint() + + fx_g.recompile() + + f = torch.jit.script(fx_g) + + torch._C._jit_pass_remove_mutation(f.graph) + + f = torch.jit.freeze(f.eval()) + f = torch.jit.optimize_for_inference(f) + return f + + +def tensorexpr_compile(fx_module: fx.GraphModule, flat_args) -> Callable: + """Compiles the given fx_module using TensorExpr Kernel""" + inp_devices = {i.device for i in flat_args if isinstance(i, torch.Tensor)} + assert len(inp_devices) == 1 + inp_device = list(inp_devices)[0] + inputs = [] + output_refs = [] + for node in fx_module.graph.nodes: + if node.op == "placeholder": + inputs.append(node) + elif node.op == "output": + outputs = node.args[0] + if not isinstance(outputs, Iterable): + outputs = (outputs,) + new_outputs = [] + for idx, output in enumerate(outputs): + # Appends (bool, idx) pairs + # if True, read from kernel outputs + # if False, read from kernel inputs + if output in inputs: + output_refs.append((False, inputs.index(output))) + elif output in outputs[:idx]: + output_refs.append((True, output_refs[outputs.index(output)][1])) + else: + output_refs.append((True, len(new_outputs))) + new_outputs.append(output) + node.args = (tuple(new_outputs),) + fx_module.graph.lint() + fx_module.recompile() + + for i in range(0, 100): + attr = f"_tensor_constant{i}" + if hasattr(fx_module, attr): + setattr(fx_module, attr, getattr(fx_module, attr).to(inp_device)) + else: + break + + jit_module = torch.jit.trace(fx_module, flat_args) + jit_module = torch.jit.freeze(jit_module.eval()) + torch._C._jit_trace_module(jit_module._c, tuple(flat_args)) + torch._C._te.remove_unused_self_argument(jit_module.graph) + torch._C._te.annotate_input_shapes(jit_module.graph, tuple(flat_args)) + torch._C._jit_pass_lower_all_tuples(jit_module.graph) + te_kernel = torch._C._te.TensorExprKernel(jit_module.graph) + + def f(*args): + outs = te_kernel.run(args) + if not isinstance(outs, tuple) and not isinstance(outs, list): + outs = (outs,) + real_outs = [] + for out in output_refs: + if out[0]: + real_outs.append(outs[out[1]]) + else: + real_outs.append(args[out[1]]) + return real_outs + + return f + + +def _draw_graph_compile(fx_g, _, name, clear_meta=True): + print(fx_g.code) + draw_graph(fx_g, name, clear_meta=clear_meta) + return fx_g + + +def draw_graph_compile(name): + return partial(_draw_graph_compile, name=name) + + +def _tvm_compile( + fx_module, example_inputs, target=None, tuning_logfile=None, use_ansor_tuning=False +): + import tvm + from tvm import relay, auto_scheduler + from tvm.contrib import graph_executor + import os + + # Find the target and device for TVM. + dev = tvm.cpu(0) + if target is None: + raise ValueError("Setup the TVM target correctly.") + elif isinstance(target, str): + if "cuda" in target: + dev = tvm.cuda(0) + target = tvm.target.Target(target) + elif isinstance(target, tvm.target.target.Target): + if "cuda" in target.keys: + dev = tvm.cuda(0) + + # JIT the model and pass it to Torchscript to Relay frontend parser. TVM + # tutorials suggest tracing instead of scripting. The main reason is to + # avoid Pythonic computation to show up in JIT module. However, with Python + # key tracing, AOT Autograd leads to simpler graphs. Therefore, we use + # scripting here to retrieve the JIT module. + jit_mod = torch.jit.script(fx_module) + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) + + # TVM Autotuning + if use_ansor_tuning: + tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) + if tuning_logfile is None: + log_file = f"{time.time()}.json" + else: + log_file = f"{tuning_logfile}.json" + if len(tasks) != 0: + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=20000, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + # early_stopping=1000, + # verbose=2, + ) + tuner.tune(tune_option) + elif tuning_logfile is not None: + log_file = f"{tuning_logfile}.json" + + if use_ansor_tuning or tuning_logfile is not None: + assert os.path.exists(log_file) + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, config={"relay.backend.use_auto_scheduler": True} + ): + lib = relay.build(mod, target=target, params=params) + else: + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + + # Get a graph executor graph module + m = graph_executor.GraphModule(lib["default"](dev)) + + def exec_tvm(*args): + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + m.set_input( + f"inp_{idx}", + tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(arg.contiguous())), + ) + m.run() + outs = [ + torch.utils.dlpack.from_dlpack(m.get_output(i).to_dlpack()) + for i in range(m.get_num_outputs()) + ] + return outs + + return exec_tvm + + +def tvm_compile(target, tuning_logfile=None, use_ansor_tuning=False): + return partial( + _tvm_compile, + target=target, + tuning_logfile=tuning_logfile, + use_ansor_tuning=use_ansor_tuning, + ) + + +def nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler + and can be used to check accuracy. + + .. warning:: + This API is experimental and likely to change. + + """ + return fx_g + + +def simple_ts_compile(fx_g, _): + strip_overloads(fx_g) + f = torch.jit.script(fx_g) + f = torch.jit.freeze(f.eval()) + return f + + +def nnc_jit(f, static_argnums=None): + return aot_function(f, simple_ts_compile, static_argnums=static_argnums) + + +aten = torch.ops.aten +default_decompositions = { + aten.detach, + aten.gelu_backward, + aten.leaky_relu_backward, + aten.sigmoid_backward, + aten.threshold_backward, + aten.hardtanh_backward, + aten.hardsigmoid_backward, + aten.hardswish_backward, + aten.tanh_backward, + aten.silu_backward, + aten.elu_backward, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.masked_fill.Scalar, + aten.masked_fill.Tensor, + aten.elu, + aten.leaky_relu, + aten.hardtanh, + aten.hardswish, + aten.hardsigmoid, +} + +default_decompositions = get_decompositions(default_decompositions) + + +def print_compile(fx_g, _): + print(fx_g.code) + return fx_g + + +def memory_efficient_fusion( + fn: Union[Callable, nn.Module], static_argnums: Optional[Tuple[int]] = None, **kwargs +): + """ + Wrapper function over :func:`aot_function` and :func:`aot_module` to perform + memory efficient fusion. It uses the + :func:`min_cut_rematerialization_partition` partitioner to perform efficient + recomputation. It uses NVFuser to compile the generated forward and backward + graphs. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` + that takes one ore more arguments. Must return one or more Tensors. + static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark + the arguments of the function as static. + **kwargs: Any other overrides you want to make to the settings + + Returns: + Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior + of the original :attr:`fn`, but whose forward and backward graphs have + gone through recomputation optimizations, and the graphs have been + compiled with nvfuser. + + """ + config = { + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + "hasher_type": "StaticShapeHasher", + "decompositions": default_decompositions, + "static_argnums": static_argnums, + } + config.update(kwargs) + if isinstance(fn, torch.nn.Module): + return aot_module(fn, **config) + else: + return aot_function(fn, **config) + + +def debug_compile(fx_g, inps): + fx_g.to_folder("foo") + print( + f""" +############################################################## +# To minimize FX graph, copy and paste the below and run it # +############################################################## + +import torch +import torch.fx as fx +from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess + +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] +from foo import FxModule +mod = FxModule().cuda() + +with torch.jit.fuser("fuser2"): + # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess + minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess) +""" + ) + from foo import FxModule + + FxModule().cuda()(*inps) + + return ts_compile(fx_g, inps) + + +graph_index = 0 + + +def get_inputs(input_data_path): + """ + Return a random input for the given inputs meta generated from _save_fx_default. + """ + inputs = [] + with (open(input_data_path, 'rb')) as f: + inputs_meta = pickle.load(f) + inputs = [] + for meta in inputs_meta: + if len(meta) == 1: + type = meta + input = type(random.rand()) + else: + type, shape, stride, dtype, device = meta + if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8, int, float}: + input = torch.randint(0, 1, shape, dtype=dtype, device=device) + else: + input = torch.rand(shape, dtype=dtype, device=device) + inputs.append(input) + return inputs + + +def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs): + """ + The forward, backward, and joint computation graph will be stored in + {folder_name}/{current_name}/{current_name}_forward_{graph_index}, + {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and + {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively. + The input shape of the graphs will be stored in the .input files. + These files can be loaded with pickle, + and is a list of format (type, shape, stride, dtype, device). + In the case of type = int or float, it is just (type,). + For joint graph input, it is a nested list [[],[]] + where the two inner lists have the same format. + If dump_example_input is True, example_inputs will be stored in .pt file. + Since each function might produce multiple graphs, + the graph_index is used to distinguish difference graphs + """ + from functorch.compile import aot_module_simplified + + def get_input_meta(args): + input_meta = [] + if len(args) > 0 and isinstance(args[0], tuple): # joint input + input_meta += get_input_meta(args[0]) + input_meta += get_input_meta(args[1]) + return input_meta + for arg in args: + if(type(arg) == int or type(arg) == float): + input_meta.append((type(arg),)) + else: + input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)) + return input_meta + + def graph_saver_helper(gm_to_save, args, type_name): + global graph_index + if len(gm_to_save.graph.nodes) == 0: + logging.log(logging.WARNING, f"No nodes in graph {current_name}_{type_name}_{graph_index}.") + return + + gm = copy.deepcopy(gm_to_save) + gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + gm.recompile() + + input_meta = get_input_meta(args) + + isExist = os.path.exists(f"{folder_name}/{current_name}") + if not isExist: + os.makedirs(f"{folder_name}/{current_name}") + gm.to_folder(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}") + pickle.dump(input_meta, open(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", "wb")) # noqa: E501 + if dump_example_input: + torch.save(args, f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt") # noqa: E501 + + def graph_saver_forward(gm, fw_args): + graph_saver_helper(gm, fw_args, "forward") + return gm + + def graph_saver_backward(gm, bw_args): + graph_saver_helper(gm, bw_args, "backward") + global graph_index + graph_index += 1 + return gm + + def graph_saver_joint(gm, joint_args): + graph_saver_helper(gm, joint_args, "joint") + return default_partition(gm, joint_args) + + return aot_module_simplified(gm, fw_compiler=graph_saver_forward, + bw_compiler=graph_saver_backward, + partition_fn=graph_saver_joint, + decompositions=default_decompositions) + + +def graph_dumper_aot(current_name, folder_name, dump_example_input=False): + """ + Dump the forward, backward, and joint computation graph. + Example Usage: + save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False) + optimize_ctx = torchdynamo.optimize( + save_fx_func + ) + with torch.enable_grad(): + with optimize_ctx: + result = forward_and_backward_pass(model, example_inputs) + """ + global graph_index + graph_index = 0 + return partial(_save_fx_default, current_name, folder_name, dump_example_input) diff --git a/functorch/functorch/_src/config.py b/functorch/functorch/_src/config.py new file mode 100644 index 0000000000000..583dbcec74553 --- /dev/null +++ b/functorch/functorch/_src/config.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Global flags for aot autograd +""" +import os + +use_functionalize = False + +# TODO: flip this to true by default +# Waiting on +# https://github.com/pytorch/pytorch/pull/81617 +# https://github.com/pytorch/pytorch/pull/81609 +# https://github.com/pytorch/pytorch/pull/81604 +# fix for test_aot_autograd_exhaustive_sgn_cpu_float32 _efficientzerotensor +# fix for complex numbers +use_fake_tensor = False + +debug_partitioner = os.environ.get('AOT_PARTITIONER_DEBUG', False) +# Prints out forward + backwards FX graphs +debug_graphs = os.environ.get('AOT_FX_GRAPHS', False) +# Prints out joint graph traced, before partitioning +debug_joint = os.environ.get('AOT_FX_GRAPHS_JOINT', False) diff --git a/functorch/functorch/_src/custom_function.py b/functorch/functorch/_src/custom_function.py new file mode 100644 index 0000000000000..028a246c62a32 --- /dev/null +++ b/functorch/functorch/_src/custom_function.py @@ -0,0 +1,20 @@ +import torch +import functorch._C + +m = functorch._C._dispatch_library("FRAGMENT", "aten", "") + + +def custom_vjp(name, filter_fn, fwd_fn, bwd_fn): + m.def_(f"{name}(Tensor[] args) -> Tensor[]") + m.impl(f"{name}", "CompositeImplicitAutograd", fwd_fn) + + m.def_(f"{name}_vjp(Tensor[] args) -> Tensor[]") + m.impl(f"{name}_vjp", "CompositeImplicitAutograd", bwd_fn) + + # TODO: it looks like the autograd alias key doesn't work + m.gen_backward_binding(f"{name}", "AutogradCPU") + m.gen_backward_binding(f"{name}", "AutogradCUDA") + + def wrapped(*args): + return filter_fn(getattr(torch.ops.aten, name)(args)) + return wrapped diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py new file mode 100644 index 0000000000000..3780d09db20d4 --- /dev/null +++ b/functorch/functorch/_src/decompositions.py @@ -0,0 +1,219 @@ +import torch +from torch import Tensor +import torch._decomp +from typing import Tuple, List, Optional + +aten = torch.ops.aten + +decomposition_table = torch._decomp.decomposition_table +register_decomposition = torch._decomp.register_decomposition +get_decompositions = torch._decomp.get_decompositions + +# Decompositions have been ported to torch._decomp inside of PyTorch core. +# The only decompositions here are temporary or hacks. +# Please submit your contributions to PyTorch core! + + +def maybe_register_decomposition(op): + def decorator(f): + try: + return register_decomposition(op)(f) + except Exception: + return f + return decorator + + +# Functions where we need a special decomposition for jvp but there's another version that +# should be used more generally (ex. for jvp we need to recompute the mean and variance for +# the backwards of a normalization function. Without jvp, it should used the saved value) +decomposition_table_for_jvp = {} + + +def register_decomposition_for_jvp(fn): + return register_decomposition(fn, registry=decomposition_table_for_jvp) + + +@maybe_register_decomposition(aten.trace.default) +def trace(self: Tensor) -> Tensor: + return torch.sum(torch.diag(self)) + + +@maybe_register_decomposition(aten.log_sigmoid_forward.default) +def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool): + # for most norm decompositions, it will be the same as the core version except for here. + # We recompute the mean and variance so that they track gradients through input + + mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim) + var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim) + eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside + eps = eps.detach() + rstd = 1 / torch.sqrt(var + eps) + return mean, rstd + + +@register_decomposition_for_jvp(aten.native_layer_norm_backward) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices = list(range(axis, input_ndim)) + outer_dim_indices = list(range(0, axis)) + + N = 1 + for i in inner_dims: + N *= i + M = 1 + for i in outer_dims: + M *= i + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape), + input.new_zeros(input_shape[axis:]), + input.new_zeros(input_shape[axis:]), + ) + + mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True) + + x_hat = (input - mean_) * rstd_ + if weight is not None: + grad_x_hat = grad_out * weight + else: + grad_x_hat = grad_out + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + inner = a - b - c3 + + if output_mask[0]: + d_input: Optional[Tensor] = (rstd_ / N) * inner + else: + d_input = torch.zeros_like(input) # should be None but doesn't work with vjp + + if output_mask[1] and weight is not None: + if len(outer_dim_indices) > 0: + d_weight: Optional[Tensor] = torch.sum( + grad_out * x_hat, outer_dim_indices, False + ) + else: + d_weight = grad_out * x_hat + elif weight is not None: + d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + d_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2] and bias is not None: + if len(outer_dim_indices) > 0: + d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False) + else: + d_bias = grad_out + elif bias is not None: + d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp + else: + d_bias = torch.zeros(()) # should be None but doesn't work with vjp + + return (d_input, d_weight, d_bias) + + +def prod(x: List[int]): + r = 1 + for i in x: + r *= i + return r + + +@register_decomposition_for_jvp(aten.native_batch_norm_backward) # @register_decomposition_for_jvp after in core +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(input_shape) / input_shape[axis] + mean = save_mean + invstd = save_invstd + if train: + assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required" + + reduciton_dims = [0] + list(range(2, input.dim())) + assert invstd is not None # for typing + mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False) + else: + assert running_mean is not None and running_var is not None + mean = running_mean + invstd = torch.rsqrt(running_var + eps) + + broadcast_mask = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out, reduction_axes) + dot_p = torch.sum(grad_out * (input - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) + + if weight is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 + else: + grad_scale = torch.reshape(invstd * weight, broadcast_mask) + + if train: + proj = (input - mean) * proj_scale + grad_input = ((grad_out - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + elif weight is not None: + grad_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + grad_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = torch.zeros_like(grad_output_sum) # should be None but doesn't work with vjp + + return (grad_input, grad_weight, grad_bias) diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py new file mode 100644 index 0000000000000..6cfa66d1c9b7b --- /dev/null +++ b/functorch/functorch/_src/eager_transforms.py @@ -0,0 +1,1497 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Union, Tuple, List, Any +import torch +import inspect +from functools import partial, wraps +import contextlib +from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map +from .pytree_hacks import tree_map_, treespec_pprint +import torch.autograd.forward_ad as fwAD + +from .vmap import vmap +from .decompositions import decomposition_table, decomposition_table_for_jvp + + +from functorch._C import ( + _wrap_for_grad, + _unwrap_for_grad, + _grad_increment_nesting, + _grad_decrement_nesting, + _jvp_increment_nesting, + _jvp_decrement_nesting, + set_fwd_grad_enabled, + get_fwd_grad_enabled, + _wrap_functional_tensor, + _unwrap_functional_tensor, + _func_decrement_nesting, + _func_increment_nesting, + _assert_wrapped_functional, + _propagate_functional_input_mutation, + set_inplace_requires_grad_allowed, + get_inplace_requires_grad_allowed, +) + +argnums_t = Union[int, Tuple[int, ...]] + + +@contextlib.contextmanager +def enable_inplace_requires_grad(enabled=True): + prev_state = get_inplace_requires_grad_allowed() + set_inplace_requires_grad_allowed(enabled) + try: + yield + finally: + set_inplace_requires_grad_allowed(prev_state) + + +def _create_differentiable(inps, level=None): + def create_differentiable(x): + if isinstance(x, torch.Tensor): + with enable_inplace_requires_grad(): + return x.requires_grad_() + raise ValueError(f'Thing passed to transform API must be Tensor, ' + f'got {type(x)}') + return tree_map(create_differentiable, inps) + + +def _undo_create_differentiable(inps, level=None): + def unwrap_tensors(x): + if isinstance(x, torch.Tensor): + return _unwrap_for_grad(x, level) + # TODO: Remove the following hack for namedtuples + if isinstance(x, tuple): + return tree_map(unwrap_tensors, tuple(x)) + + raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}") + + return tree_map(unwrap_tensors, inps) + + +def _is_differentiable(maybe_tensor): + if not isinstance(maybe_tensor, torch.Tensor): + return False + return maybe_tensor.requires_grad + + +def _any_differentiable(tensor_or_tuple_of_tensors): + flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) + return any(tuple(map(_is_differentiable, flat_args))) + + +def _wrap_tensor_for_grad(maybe_tensor, level): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + return _wrap_for_grad(maybe_tensor, level) + + +def _wrap_all_tensors(tensor_pytree, level): + return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree) + + +def _as_tuple(val): + if isinstance(val, tuple): + return val + return (val,) + +# Version of autograd.grad that handles outputs that don't depend on inputs + + +def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True): + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad) + if len(result) == 0: + diff_outputs, grad_outputs = (), () + else: + diff_outputs, grad_outputs = zip(*result) + if len(diff_outputs) == 0: + return tuple(torch.zeros_like(inp) for inp in inputs) + grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True) + grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi + for gi, inp in zip(grad_inputs, inputs)) + return grad_inputs + +# NOTE [grad and vjp interaction with no_grad] +# +# def f(x): +# with torch.no_grad(): +# c = x ** 2 +# return x - c +# +# The thing to consider is if enable_grad is on/off before grad gets called. +# +# Case 1: enable_grad is on. +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad. +# +# Case 2: enable_grad is off +# with torch.no_grad(): +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad, but not the +# outer one. This is because `grad` is a "function transform": its result +# should not depend on the result of a context manager outside of `f`. +# +# This gives us the following desired behavior: +# - (nested) grad transforms must obey torch.no_grad inside them +# - (nested) grad transforms should not obey torch.no_grad outside them +# +# To achieve this behavior, upon entering grad/vjp: +# - we save the current ("previous") is_grad_enabled (*) +# - we unconditionally enable grad. +# +# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer +# off the stack: +# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad +# active, all subsequent grad transforms must obey it). +# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False, +# then we temporarily restore the previous `is_grad_enabled`. This is +# because we're crossing the boundary from a `grad` outside the +# no_grad to a `grad` inside the no_grad. +# +# NB: vjp has some interesting behavior because the vjp's callable can be called +# under a different grad_mode than the forward computation... +# +# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but +# it respects c10::AutoFwGradMode. We've implemented the same logic for +# our jvp transform (it will have special handling if FwGradMode is disabled). + + +# How do we increment and decrement the nesting? I don't think we can. +def vjp(func: Callable, *primals, has_aux: bool = False): + """ + Standing for the vector-Jacobian product, returns a tuple containing the + results of :attr:`func` applied to :attr:`primals` and a function that, when + given ``cotangents``, computes the reverse-mode Jacobian of :attr:`func` with + respect to :attr:`primals` times ``cotangents``. + + Args: + func (Callable): A Python function that takes one or more arguments. Must + return one or more Tensors. + primals (Tensors): Positional arguments to :attr:`func` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + has_aux (bool): Flag indicating that :attr:`func` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`func` + applied to :attr:`primals` and a function that computes the vjp of + :attr:`func` with respect to all :attr:`primals` using the cotangents passed + to the returned function. If ``has_aux is True``, then instead returns a + ``(output, vjp_fn, aux)`` tuple. + The returned ``vjp_fn`` function will return a tuple of each VJP. + + When used in simple cases, :func:`vjp` behaves the same as :func:`grad` + + >>> x = torch.randn([5]) + >>> f = lambda x: x.sin().sum() + >>> (_, vjpfunc) = functorch.vjp(f, x) + >>> grad = vjpfunc(torch.tensor(1.))[0] + >>> assert torch.allclose(grad, functorch.grad(f)(x)) + + However, :func:`vjp` can support functions with multiple outputs by + passing in the cotangents for each of the outputs + + >>> x = torch.randn([5]) + >>> f = lambda x: (x.sin(), x.cos()) + >>> (_, vjpfunc) = functorch.vjp(f, x) + >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + :func:`vjp` can even support outputs being Python structs + + >>> x = torch.randn([5]) + >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> (_, vjpfunc) = functorch.vjp(f, x) + >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} + >>> vjps = vjpfunc(cotangents) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + The function returned by :func:`vjp` will compute the partials with + respect to each of the :attr:`primals` + + >>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) + >>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y) + >>> cotangents = torch.randn([5, 5]) + >>> vjps = vjpfunc(cotangents) + >>> assert len(vjps) == 2 + >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) + >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents)) + + :attr:`primals` are the positional arguments for :attr:`f`. All kwargs use their + default value + + >>> x = torch.randn([5]) + >>> def f(x, scale=4.): + >>> return x * 4. + >>> + >>> (_, vjpfunc) = functorch.vjp(f, x) + >>> vjps = vjpfunc(torch.ones_like(x)) + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``vjp``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> vjp(f)(x) + + In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``vjp`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + level = _grad_increment_nesting() + try: + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + primals = _wrap_all_tensors(primals, level) + diff_primals = _create_differentiable(primals, level) + primals_out = func(*diff_primals) + + if has_aux: + if not (isinstance(primals_out, tuple) and len(primals_out) == 2): + raise RuntimeError( + "vjp(f, *primals): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + primals_out, aux = primals_out + aux = _undo_create_differentiable(aux, level) + + flat_primals_out, primals_out_spec = tree_flatten(primals_out) + assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)') + flat_diff_primals, primals_spec = tree_flatten(diff_primals) + results = _undo_create_differentiable(primals_out, level) + + for primal_out in flat_primals_out: + assert isinstance(primal_out, torch.Tensor) + if primal_out.is_floating_point() or primal_out.is_complex(): + continue + raise RuntimeError("vjp(f, ...): All outputs of f must be " + "floating-point or complex Tensors, got Tensor " + f"with dtype {primal_out.dtype}") + + def wrapper(cotangents, retain_graph=True, create_graph=None): + if create_graph is None: + create_graph = torch.is_grad_enabled() + flat_cotangents, cotangents_spec = tree_flatten(cotangents) + if primals_out_spec != cotangents_spec: + raise RuntimeError( + f'Expected pytree structure of cotangents to be the same ' + f'as pytree structure of outputs to the function. ' + f'cotangents: {treespec_pprint(cotangents_spec)}, ' + f'primal output: {treespec_pprint(primals_out_spec)}') + result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents, + retain_graph=retain_graph, create_graph=create_graph) + return tree_unflatten(result, primals_spec) + + finally: + _grad_decrement_nesting() + + if has_aux: + return results, wrapper, aux + else: + return results, wrapper + + +def _safe_zero_index(x): + assert len(x) == 1 + return x[0] + + +def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False): + """ + Computes the Jacobian of :attr:`func` with respect to the arg(s) at index + :attr:`argnum` using reverse mode autodiff + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that :attr:`func` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a function that takes in the same inputs as :attr:`func` and + returns the Jacobian of :attr:`func` with respect to the arg(s) at + :attr:`argnums`. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from functorch import jacrev + >>> x = torch.randn(5) + >>> jacobian = jacrev(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from functorch import jacrev + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + :func:`jacrev` can be composed with vmap to produce batched + Jacobians: + + >>> from functorch import jacrev, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacrev(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + Additionally, :func:`jacrev` can be composed with itself to produce + Hessians + + >>> from functorch import jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacrev(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacrev` computes the Jacobian with respect to the first + input. However, it can compute the Jacboian with respect to a different + argument by using :attr:`argnums`: + + >>> from functorch import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian + with respect to multiple arguments + + >>> from functorch import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``jacrev``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> jacrev(f)(x) + + In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``jacrev`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + @wraps(func) + def wrapper_fn(*args): + f_wrapper, primals = _argnums_partial(func, args, argnums) + vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux) + if has_aux: + output, vjp_fn, aux = vjp_out + else: + output, vjp_fn = vjp_out + + # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs] + flat_output, output_spec = tree_flatten(output) + + # NB: vjp already checks that all outputs are tensors + # Step 1: Construct grad_outputs by splitting the standard basis + flat_output_numels = tuple(out.numel() for out in flat_output) + flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels) + basis = tree_unflatten(flat_basis, output_spec) + + results = vmap(vjp_fn)(basis) + + flat_primals, primals_spec = tree_flatten(primals) + flat_results, results_spec = tree_flatten(results) + + # Step 2: The returned jacobian is one big tensor per input. In this step, + # we split each Tensor by output. + flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results] + flat_input_flat_output = [ + tuple(split.view(out.shape + primal.shape) + for split, out in zip(splits, flat_output)) + for splits, primal in zip(flat_results, flat_primals) + ] + + # Step 3: Right now, `jacobian` is a List[List[Tensor]]. + # The outer List corresponds to the number of primals, + # the inner List corresponds to the number of outputs. + # We need to: + # a. Exchange the order of the outer List and inner List + # b. tree_unflatten the inner Lists (which correspond to the primals) + # c. handle the argnums=int case + # d. tree_unflatten the outer List (which corresponds to the outputs) + flat_output_flat_input = tuple(zip(*flat_input_flat_output)) + + flat_output_input = tuple(tree_unflatten(flat_input, primals_spec) + for flat_input in flat_output_flat_input) + + if isinstance(argnums, int): + flat_output_input = tuple(_safe_zero_index(flat_input) + for flat_input in flat_output_input) + output_input = tree_unflatten(flat_output_input, output_spec) + if has_aux: + return output_input, aux + return output_input + return wrapper_fn + +# NOTE: [Computing jacobian with vmap and vjp for multiple outputs] +# +# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). +# It turns out we can compute the jacobian of this function with a single +# call to autograd.grad by using vmap over the correct grad_outputs. +# +# Firstly, one way to compute the jacobian is to stack x**2 and x.sum() +# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) +# +# To get the first row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) +# To get the 2nd row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) +# and so on. +# +# Using vmap, we can vectorize all 4 of these computations into one by +# passing the standard basis for R^4 as the grad_output. +# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). +# +# Now, how do we compute the jacobian *without stacking the output*? +# We can just split the standard basis across the outputs. So to +# compute the jacobian of f(x), we'd use +# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) +# The grad_outputs looks like the following: +# ( torch.tensor([[1, 0, 0], +# [0, 1, 0], +# [0, 0, 1], +# [0, 0, 0]]), +# torch.tensor([[0], +# [0], +# [0], +# [1]]) ) +# +# But we're not done yet! +# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) +# returns a Tensor of shape [4, 3]. We have to remember to split the +# jacobian of shape [4, 3] into two: +# - one of shape [3, 3] for the first output +# - one of shape [ 3] for the second output + + +def _construct_standard_basis_for(tensors, tensor_numels): + # This function: + # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. + # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. + # - Each chunk corresponds to one tensor. The chunk has the same dtype and + # device as the tensor + # + # For example, with tensor_numels = [1, 2, 1], this function returns: + # ( tensor([[1], tensor([[0, 0], tensor([[0], + # [0], [1, 0], [0], + # [0], [0, 1], [0], + # [0]]) , [0, 0]]) , [1]]) ) + # + # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) + # Precondition: tensors always has at least one element. + # + # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] + # for context behind this function. + assert len(tensors) == len(tensor_numels) + assert len(tensors) > 0 + total_numel = sum(tensor_numels) + diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind()) + chunks = tuple(tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels)) + for chunk, diag_start_idx in zip(chunks, diag_start_indices): + chunk.diagonal(diag_start_idx).fill_(1) + chunks = tuple(chunk.view(total_numel, *tensor.shape) + for chunk, tensor in zip(chunks, tensors)) + return chunks + + +def _validate_and_wrap_argnum(argnum, num_args): + if not isinstance(argnum, int): + raise RuntimeError(f'argnum must be int, got: {type(argnum)}') + if argnum >= 0 and argnum < num_args: + return argnum + if argnum < 0 and argnum >= -num_args: + return argnum + num_args + raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs') + + +def _check_unique_non_empty(argnums): + if isinstance(argnums, tuple): + if len(argnums) == 0: + raise RuntimeError("argnums must be non-empty") + if len(set(argnums)) != len(argnums): + raise RuntimeError(f"argnums elements must be unique, got {argnums}") + + +def _replace_args(old_args, new_args, argnums): + if isinstance(argnums, int): + if len(new_args) != 1: + raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}') + return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args))) + if isinstance(argnums, tuple): + if len(new_args) != len(argnums): + raise RuntimeError( + "new_args should have the same size as argnums. " + f"Argnums size {len(argnums)}, new_args size {len(new_args)}") + + def get_right_elem(i): + return new_args[argnums.index(i)] if i in argnums else old_args[i] + + return tuple(get_right_elem(i) for i in range(len(old_args))) + raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}') + + +def _validate_and_wrap_argnums(argnums, num_args): + if isinstance(argnums, int): + return _validate_and_wrap_argnum(argnums, num_args) + if isinstance(argnums, tuple): + return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums) + raise AssertionError("Should never get here") + + +def _slice_argnums(args, argnums, as_tuple=True): + if not isinstance(argnums, int) and not isinstance(argnums, tuple): + raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}') + argnums = _validate_and_wrap_argnums(argnums, len(args)) + _check_unique_non_empty(argnums) + if isinstance(argnums, int): + if as_tuple: + return (args[argnums],) + else: + return args[argnums] + return tuple(args[i] for i in argnums) + + +def _argnums_partial(f, args, argnums): + def f_wrapper(*wrapper_args): + replaced_args = _replace_args(args, wrapper_args, argnums) + return f(*replaced_args) + wrapper_args = _slice_argnums(args, argnums) + wrapper_args = wrapper_args if isinstance(wrapper_args, tuple) else (wrapper_args, ) + return (f_wrapper, wrapper_args) + + +JVP_NESTING = 0 + + +@contextlib.contextmanager +def noop(): + yield + + +@contextlib.contextmanager +def enable_fwd_grad(enabled=True): + prev_state = get_fwd_grad_enabled() + set_fwd_grad_enabled(enabled) + try: + yield + finally: + set_fwd_grad_enabled(prev_state) + + +def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None: + if not isinstance(elts, tuple): + raise RuntimeError( + f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}') + for elt in elts: + if isinstance(elt, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected {argname} to be a tuple of Tensors, got ' + f'a tuple with an element of type {type(elt)}') + if len(elts) == 0: + raise RuntimeError( + f'{api}: Expected {argname} to be a non-empty tuple of Tensors.') + + +def assert_non_empty_tensor_output(output: List[Any], api: str) -> None: + if output == [None] or len(output) < 1: + raise RuntimeError( + f'{api}: Expected f to be a function that has non-empty output (got output = {output})' + ) + for o in output: + if not isinstance(o, torch.Tensor): + raise RuntimeError( + f'{api}: expected f(*primals) to return only tensors' + f', got unsupported type {type(o)}' + ) + + +def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: + if isinstance(output, torch.Tensor): + return + if not isinstance(output, tuple): + raise RuntimeError( + f'{api}: Expected output of f to be a Tensor or Tensors, got ' + f'{type(output)}') + if len(output) == 0: + raise RuntimeError( + f'{api}: Expected output of f to be a non-empty tuple of Tensors.') + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected output of f to be a Tensor or Tensors, got ' + f'{type(out)} as an output') + + +def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None: + if len(output) == 0: + raise RuntimeError( + f'{api}: Expected {argname} to contain at least one Tensor.') + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected {argname} to only contain Tensors, got ' + f'{type(out)}') + + +jvp_str = 'jvp(f, primals, tangents)' + + +def safe_unpack_dual(dual, strict): + if not isinstance(dual, torch.Tensor): + raise RuntimeError( + f'{jvp_str}: expected f(*args) to return only tensors' + f', got unsupported type {type(dual)}' + ) + + primal, tangent = fwAD.unpack_dual(dual) + if tangent is None: + if strict: + raise RuntimeError( + 'jvp(f, primals, tangents, strict=True): ' + 'The output of f is independent of ' + 'the inputs. This is not allowed with strict=True.') + tangent = torch.zeros_like(primal) + return primal, tangent + + +def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False): + """ + Standing for the Jacobian-vector product, returns a tuple containing + the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at + ``primals``" times ``tangents``. This is also known as forward-mode autodiff. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + primals (Tensors): Positional arguments to :attr:`func` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + tangents (Tensors): The "vector" for which Jacobian-vector-product is + computed. Must be the same structure and sizes as the inputs to + ``func``. + has_aux (bool): Flag indicating that :attr:`func` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, jvp_out)`` tuple containing the output of ``func`` + evaluated at ``primals`` and the Jacobian-vector product. + If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple. + + .. warning:: + PyTorch's forward-mode AD coverage on operators is not very good at the + moment. You may see this API error out with "forward-mode AD not + implemented for operator X". If so, please file us a bug report and we + will prioritize it. + + jvp is useful when you wish to compute gradients of a function R^1 -> R^N + + >>> from functorch import jvp + >>> x = torch.randn([]) + >>> f = lambda x: x * torch.tensor([1., 2., 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> assert torch.allclose(value, f(x)) + >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) + + :func:`jvp` can support functions with multiple inputs by passing in the + tangents for each of the inputs + + >>> from functorch import jvp + >>> x = torch.randn(5) + >>> y = torch.randn(5) + >>> f = lambda x, y: (x * y) + >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) + >>> assert torch.allclose(output, x + y) + + """ + if not isinstance(primals, tuple): + raise RuntimeError( + f'{jvp_str}: Expected primals to be a tuple. ' + f'E.g. it should be valid to call f(*primals).') + flat_primals, primals_spec = tree_flatten(primals) + flat_tangents, tangents_spec = tree_flatten(tangents) + if primals_spec != tangents_spec: + raise RuntimeError( + f'{jvp_str}: Expected primals and tangents to have the same python ' + f'structure. For example, if primals is a tuple of 3 tensors, ' + f'tangents also must be. Got primals with structure {primals_spec} ' + f'and tangents with structure {tangents_spec}') + assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals') + assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents') + + level = _jvp_increment_nesting() + try: + global JVP_NESTING + JVP_NESTING += 1 + with enable_fwd_grad(): + ctx = fwAD.dual_level if JVP_NESTING == 1 else noop + with ctx(): + flat_duals = tuple(fwAD.make_dual(p, t) + for p, t in zip(flat_primals, flat_tangents)) + duals = tree_unflatten(flat_duals, primals_spec) + result_duals = func(*duals) + if has_aux: + if not (isinstance(result_duals, tuple) and len(result_duals) == 2): + raise RuntimeError( + f"{jvp_str}: output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + result_duals, aux = result_duals + aux = _undo_create_differentiable(aux, level) + + result_duals, spec = tree_flatten(result_duals) + assert_non_empty_tensor_output(result_duals, jvp_str) + + primals_out, tangents_out = \ + zip(*[safe_unpack_dual(dual, strict) for dual in result_duals]) + primals_out = tree_map( + partial(_undo_create_differentiable, level=level), primals_out) + tangents_out = tree_map( + partial(_undo_create_differentiable, level=level), tangents_out) + + primals_out_unflatten = tree_unflatten(primals_out, spec) + tangents_out_unflatten = tree_unflatten(tangents_out, spec) + if has_aux: + return primals_out_unflatten, tangents_out_unflatten, aux + + return primals_out_unflatten, tangents_out_unflatten + finally: + _jvp_decrement_nesting() + JVP_NESTING -= 1 + + +def safe_unflatten(tensor, dim, shape): + if len(shape) == 0: + assert tensor.shape[dim] == 1 + return tensor.squeeze(dim) + return tensor.unflatten(dim, shape) + + +def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False): + """ + Computes the Jacobian of :attr:`func` with respect to the arg(s) at index + :attr:`argnum` using forward-mode autodiff + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that :attr:`func` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a function that takes in the same inputs as :attr:`func` and + returns the Jacobian of :attr:`func` with respect to the arg(s) at + :attr:`argnums`. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`. + + .. warning:: + PyTorch's forward-mode AD coverage on operators is not very good at the + moment. You may see this API error out with "forward-mode AD not + implemented for operator X". If so, please file us a bug report and we + will prioritize it. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from functorch import jacfwd + >>> x = torch.randn(5) + >>> jacobian = jacfwd(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + :func:`jacfwd` can be composed with vmap to produce batched + Jacobians: + + >>> from functorch import jacfwd, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacfwd(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from functorch import jacfwd + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + Additionally, :func:`jacrev` can be composed with itself or :func:`jacrev` + to produce Hessians + + >>> from functorch import jacfwd, jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacfwd` computes the Jacobian with respect to the first + input. However, it can compute the Jacboian with respect to a different + argument by using :attr:`argnums`: + + >>> from functorch import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian + with respect to multiple arguments + + >>> from functorch import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + """ + @wraps(func) + def wrapper_fn(*args): + f_wrapper, primals = _argnums_partial(func, args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + flat_primals_numels = tuple(p.numel() for p in flat_primals) + flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels) + basis = tree_unflatten(flat_basis, primals_spec) + + def push_jvp(basis): + output = jvp(f_wrapper, primals, basis, has_aux=has_aux) + if has_aux: + _, jvp_out, aux = output + return jvp_out, aux + _, jvp_out = output + return jvp_out + + results = vmap(push_jvp)(basis) + if has_aux: + results, aux = results + # aux is in the standard basis format, e.g. NxN matrix + # We need to fetch the first element as original `func` output + flat_aux, aux_spec = tree_flatten(aux) + flat_aux = [value[0] for value in flat_aux] + aux = tree_unflatten(flat_aux, aux_spec) + + jac_outs, spec = tree_flatten(results) + # Most probably below output check can never raise an error + # as jvp should test the output before + # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)') + + jac_outs_ins = tuple( + tuple( + safe_unflatten(jac_out_in, -1, primal.shape) + for primal, jac_out_in in + zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1)) + ) + for jac_out in jac_outs + ) + jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins) + + if isinstance(argnums, int): + jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins) + if has_aux: + return tree_unflatten(jac_outs_ins, spec), aux + return tree_unflatten(jac_outs_ins, spec) + return wrapper_fn + + +def hessian(func, argnums=0): + """ + Computes the Hessian of :attr:`func` with respect to the arg(s) at index + :attr:`argnum` via a forward-over-reverse strategy. + + The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is + a good default for good performance. It is possible to compute Hessians + through other compositions of :func:`jacfwd` and :func:`jacrev` like + ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Hessian with respect to. + Default: 0. + + Returns: + Returns a function that takes in the same inputs as :attr:`func` and + returns the Hessian of :attr:`func` with respect to the arg(s) at + :attr:`argnums`. + + .. warning:: + PyTorch's forward-mode AD coverage on operators is not very good at the + moment. You may see this API error out with "forward-mode AD not + implemented for operator X". If so, please file us a bug report and we + will prioritize it. + + A basic usage with a R^N -> R^1 function gives a N x N Hessian: + + >>> from functorch import hessian + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hess = jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hess, torch.diag(-x.sin())) + + """ + return jacfwd(jacrev(func, argnums), argnums) + + +def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """ + Returns a function to compute a tuple of the gradient and primal, or + forward, computation. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified :attr:`has_aux` + equals ``True``, function can return a tuple of single-element + Tensor and other auxiliary objects: ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients + with respect to. :attr:`argnums` can be single integer or tuple of + integers. Default: 0. + has_aux (bool): Flag indicating that :attr:`func` returns a tensor and + other auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute a tuple of gradients with respect to its inputs + and the forward computation. By default, the output of the function is + a tuple of the gradient tensor(s) with respect to the first argument + and the primal computation. If specified :attr:`has_aux` equals + ``True``, tuple of gradients and tuple of the forward computation with + output auxiliary objects is returned. If :attr:`argnums` is a tuple of + integers, a tuple of a tuple of the output gradients with respect to + each :attr:`argnums` value and the forward computation is returned. + + See :func:`grad` for examples + """ + @wraps(func) + def wrapper(*args, **kwargs): + level = _grad_increment_nesting() + try: + output, aux, grad_input = None, None, None + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + args = _wrap_all_tensors(args, level) + kwargs = _wrap_all_tensors(kwargs, level) + diff_args = _slice_argnums(args, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_args) + + output = func(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + output, aux = output + + if not isinstance(output, torch.Tensor): + raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) ' + f'to return a Tensor, got {type(output)}') + if output.dim() != 0: + raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) ' + 'to return a scalar Tensor, got tensor with ' + f'{output.dim()} dims. Maybe you wanted to ' + 'use the vjp or jacrev APIs instead?') + + flat_diff_args, spec = tree_flatten(diff_args) + + # NB: need create_graph so that backward pass isn't run in no_grad mode + flat_outputs = _as_tuple(output) + flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True) + grad_input = tree_unflatten(flat_grad_input, spec) + + grad_input = _undo_create_differentiable(grad_input, level) + output = _undo_create_differentiable(output, level) + if aux is not None: + aux = _undo_create_differentiable(aux, level) + + if has_aux: + return grad_input, (output, aux) + return grad_input, output + finally: + _grad_decrement_nesting() + return wrapper + + +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """``grad`` operator helps computing gradients of :attr:`func` with respect to the + input(s) specified by :attr:`argnums`. This operator can be nested to + compute higher-order gradients. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified :attr:`has_aux` equals ``True``, + function can return a tuple of single-element Tensor and other auxiliary objects: + ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. + :attr:`argnums` can be single integer or tuple of integers. Default: 0. + has_aux (bool): Flag indicating that :attr:`func` returns a tensor and other + auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute gradients with respect to its inputs. By default, the output of + the function is the gradient tensor(s) with respect to the first argument. + If specified :attr:`has_aux` equals ``True``, tuple of gradients and output auxiliary objects + is returned. If :attr:`argnums` is a tuple of integers, a tuple of output gradients with + respect to each :attr:`argnums` value is returned. + + Example of using ``grad``: + + >>> from functorch import grad + >>> x = torch.randn([]) + >>> cos_x = grad(lambda x: torch.sin(x))(x) + >>> assert torch.allclose(cos_x, x.cos()) + >>> + >>> # Second-order gradients + >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) + >>> assert torch.allclose(neg_sin_x, -x.sin()) + + When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: + + >>> from functorch import grad + >>> from functorch import vmap + >>> batch_size, feature_size = 3, 5 + >>> + >>> def model(weights, feature_vec): + >>> # Very simple linear model with activation + >>> assert feature_vec.dim() == 1 + >>> return feature_vec.dot(weights).relu() + >>> + >>> def compute_loss(weights, example, target): + >>> y = model(weights, example) + >>> return ((y - target) ** 2).mean() # MSELoss + >>> + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> examples = torch.randn(batch_size, feature_size) + >>> targets = torch.randn(batch_size) + >>> inputs = (weights, examples, targets) + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) + + Example of using ``grad`` with :attr:`has_aux` and :attr:`argnums`: + + >>> from functorch import grad + >>> def my_loss_func(y, y_pred): + >>> loss_per_sample = (0.5 * y_pred - y) ** 2 + >>> loss = loss_per_sample.mean() + >>> return loss, (y_pred, loss_per_sample) + >>> + >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) + >>> y_true = torch.rand(4) + >>> y_preds = torch.rand(4, requires_grad=True) + >>> out = fn(y_true, y_preds) + >>> > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``grad``. + + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> grad(f)(x) + + In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``grad`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + + """ + @wraps(func) + def wrapper(*args, **kwargs): + results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs) + if has_aux: + grad, (_, aux) = results + return grad, aux + grad, _ = results + return grad + return wrapper + + +def _maybe_wrap_functional_tensor(maybe_tensor, level): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + wrapped = _wrap_functional_tensor(maybe_tensor, level) + _assert_wrapped_functional(maybe_tensor, wrapped) + return wrapped + + +def _wrap_all_tensors_to_functional(tensor_pytree, level): + return tree_map(partial(_maybe_wrap_functional_tensor, level=level), tensor_pytree) + + +def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + if not torch._is_functional_tensor(maybe_tensor): + # If it's not a functional tensor, just return it. + # This can happen if we functionalize a fn that returns a global, + # which was never wrapped properly. + return maybe_tensor + return _unwrap_functional_tensor(maybe_tensor, reapply_views) + + +def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool): + return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree) + + +def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable: + """ + functionalize is a transform that can be used to remove (intermediate) + mutations and aliasing from a function, while preserving the function's + semantics. + + ``functionalize(func)`` returns a new function with the same semantics + as ``func``, but with all intermediate mutations removed. + Every inplace operation performed on an intermediate tensor: + ``intermediate.foo_()`` + gets replaced by its out-of-place equivalent: + ``intermediate_updated = intermediate.foo()``. + + functionalize is useful for shipping a pytorch program off to + backends or compilers that aren't able to easily represent + mutations or aliasing operators. + + Args: + func (Callable): A Python function that takes one or more arguments. + remove (str): An optional string argument, that takes on either + the value 'mutations' or 'mutations_and_views'. + If 'mutations' is passed in then all mutating operators + will be replaced with their non-mutating equivalents. + If 'mutations_and_views' is passed in, then additionally, all aliasing + operators will be replaced with their non-aliasing equivalents. + Default: 'mutations'. + + Returns: + Returns a new "functionalized" function. It takes the same inputs as + :attr:`func`, and has the same behavior, but any mutations + (and optionally aliasing) performed on intermeidate tensors + in the function will be removed. + + functionalize will also remove mutations (and views) that were performed on function inputs. + However to preserve semantics, functionalize will "fix up" the mutations after + the transform has finished running, by detecting if any tensor inputs "should have" + been mutated, and copying the new data back to the inputs if necessary. + + + Example:: + + >>> import torch + >>> from functorch import make_fx + >>> from functorch.experimental import functionalize + >>> + >>> A function that uses mutations and views, but only on intermediate tensors. + >>> def f(a): + ... b = a + 1 + ... c = b.view(-1) + ... c.add_(1) + ... return b + ... + >>> inpt = torch.randn(2) + >>> + >>> out1 = f(inpt) + >>> out2 = functionalize(f)(inpt) + >>> + >>> # semantics are the same (outputs are equivalent) + >>> print(torch.allclose(out1, out2)) + True + >>> + >>> f_traced = make_fx(f)(inpt) + >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> + >>> print(f_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]) + add_ = torch.ops.aten.add_(view, 1); view = None + return add + + >>> print(f_no_mutations_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]); add = None + add_1 = torch.ops.aten.add(view, 1); view = None + view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None + return view_1 + + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view_copy = torch.ops.aten.view_copy(add, [-1]); add = None + add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None + return view_copy_1 + + + >>> A function that mutates its input tensor + >>> def f(a): + ... b = a.view(-1) + ... b.add_(1) + ... return a + ... + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> + >>> All mutations and views have been removed, + >>> but there is an extra copy_ in the graph to correctly apply the mutation to the input + >>> after the function has completed. + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + view_copy = torch.ops.aten.view_copy(a_1, [-1]) + add = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None + copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None + return view_copy_1 + + + There are a few "failure modes" for functionalize that are worth calling out: + (1) Like other functorch transforms, `functionalize()` doesn't work with functions + that directly use `.backward()`. The same is true for torch.autograd.grad. + If you want to use autograd, you can compute gradients directly + with `functionalize(grad(f))`. + (2) Like other functorch transforms, `functionalize()` doesn't work with global state. + If you call `functionalize(f)` on a function that takes views / mutations of + non-local state, functionalization will simply no-op and pass the view/mutation + calls directly to the backend. + One way to work around this is is to ensure that any non-local state creation + is wrapped into a larger function, which you then call functionalize on. + (3) `resize_()` has some limitations: functionalize will only work on programs + that use resize_()` as long as the tensor being resized is not a view. + (4) `as_strided()` has some limitations: functionalize will not work on + `as_strided()` calls that result in tensors with overlapping memory. + + + Finally, a helpful mental model for understanding functionalization is that + most user pytorch programs are writting with the public torch API. + When executed, torch operators are generally decomposed into + our internal C++ "ATen" API. + The logic for functionalization happens entirely at the level of ATen. + Functionalization knows how to take every aliasing operator in ATen, + and map it to its non-aliasing equivalent + (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``), + and how to take every mutating operator in ATen, + and map it to its non-mutating equivalent + (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, -1)``), + while tracking aliases and mutations out-of-line to know when to fix things up. + Information about which ATen operators are aliasing or mutating all comes from + https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml. + """ + if remove == 'mutations': + reapply_views = True + elif remove == 'mutations_and_views': + reapply_views = False + else: + raise RuntimeError( + f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}." + " Valid options are:\n" + " remove='mutations': all inplace and out= operators will be removed from the program, and replaced" + " with their out-of-place equivalents.\n" + " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be" + " replaced with their non-aliasing counterparts, {view}_copy.\n" + ) + + @wraps(func) + def wrapped(*args, **kwargs): + try: + func_level = _func_increment_nesting(reapply_views) + func_args = _wrap_all_tensors_to_functional(args, func_level) + func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level) + + flattened_unwrapped_args, _ = tree_flatten(args) + flattened_wrapped_args, _ = tree_flatten(func_args) + flattened_unwrapped_kwargs, _ = tree_flatten(kwargs) + flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs) + + func_outputs = func(*func_args, **func_kwargs) + outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views) + flat_outputs, func_out_spec = tree_flatten(outputs) + + for a in flattened_wrapped_args + flattened_wrapped_kwargs: + if isinstance(a, torch.Tensor): + # Call sync_() on the inputs, to ensure that any pending mutations have been applied. + torch._sync(a) + + # And if any mutations were applied to the inputs, we need to propagate them back to the user. + for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args): + if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor): + _propagate_functional_input_mutation(unwrapped, wrapped) + for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs): + if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor): + _propagate_functional_input_mutation(unwrapped, wrapped) + + return outputs + finally: + _func_decrement_nesting() + return wrapped + + +def _register_jit_decomposition(decomp, use_python=False): + if decomp in decomposition_table_for_jvp: + decomposition_table_used = decomposition_table_for_jvp + elif decomp in decomposition_table: + decomposition_table_used = decomposition_table + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + decomp_fn = decomposition_table_used[decomp] + if use_python: + decomp_fn = torch.jit.ignore(decomp_fn) + sig = inspect.signature(decomp_fn) + + # Create a string wrapping the function from the signature + # example output: + # def wrapped_decomp(x: torch.Tensor, y: int, z: int): + # return decomp_fn(x, y, z) + # Thanks copilot! + def get_function_def(sig): + param_def = [f"{param_str}" for param_str in sig.parameters.values()] + param_use = [f"{param_str}" for param_str in sig.parameters.keys()] + + return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" + + f_str = get_function_def(sig) + graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph + else: + graph = torch.jit.script(decomp_fn).graph + torch.jit._register_decomposition(decomp, graph) + + +# use an alternate way to register an operator into the decomposition table +# _register_jit_decomposition doesn't work for some operators, e.g. addr, +# because the Tensor types generated cannot be unioned by torchscript +# decomp should be type OpOverload +vmap_decompositions_lib = torch.library.Library("aten", "IMPL", "FuncTorchBatched") + + +def _register_python_decomposition_vmap(decomp): + if decomp in decomposition_table: + vmap_decompositions_lib.impl(decomp, decomposition_table[decomp]) + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + + +_register_jit_decomposition(torch.ops.aten.trace.default, use_python=True) +_register_jit_decomposition(torch.ops.aten.nll_loss_backward.default) +_register_jit_decomposition(torch.ops.aten.nll_loss2d_backward.default) +_register_jit_decomposition(torch.ops.aten._log_softmax_backward_data.default) +_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default) +_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default) +_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default) +_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default) +_register_jit_decomposition(torch.ops.aten.cudnn_batch_norm_backward.default) +_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) +_register_python_decomposition_vmap(torch.ops.aten.addr.default) diff --git a/functorch/functorch/_src/fx_minifier.py b/functorch/functorch/_src/fx_minifier.py new file mode 100644 index 0000000000000..a0bb799045805 --- /dev/null +++ b/functorch/functorch/_src/fx_minifier.py @@ -0,0 +1,269 @@ +import subprocess +import torch.fx as fx +import copy +import torch +import math + + +class ConcreteProp(torch.fx.Interpreter): + def run_node(self, n): + result = super().run_node(n) + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return obj + else: + return obj + + from torch.fx.node import map_aggregate + concrete_value = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta['concrete_value'] = concrete_value + return result + + def propagate(self, *args): + return super().run(*args) + + +def _get_placeholders(graph): + return list(filter(lambda x: x.op == 'placeholder', graph.nodes)) + +# inplace modifies node/inps + + +def _convert_node_to_placeholder(node, inps): + if node.op == 'output': + return + node.op = 'placeholder' + node.args = () + node.target = node.name + concrete_val = node.meta['concrete_value'] + if isinstance(concrete_val, torch.Tensor): + inps.append(concrete_val) + else: + inps.append(torch.zeros(())) + for tuple_user in list(node.users): + _convert_node_to_placeholder(tuple_user, inps) + + +def minifier(fail_f: fx.GraphModule, inps, module_fails): + """ + Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails. + + Does 2 main strategies: + 1. Truncates suffix: Removes some suffix from the graph and sets a new output. + 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, + tries replacing quarter of the graph, etc. + + >>> failing_function = fx.symbolic_trace(f) + >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps)) + + note: module_fails returns True if it fails. + """ + failing_graph = fail_f.graph + cur_size = len(failing_graph.nodes) + + def graph_fails(graph, inps): + + mod = fx.GraphModule(fail_f, graph) + mod.graph.lint() + return module_fails(mod, inps) + + ConcreteProp(fail_f).propagate(*inps) + if not graph_fails(failing_graph, inps): + raise RuntimeError("Input graph did not fail the tester") + print(f"Started off with {cur_size} nodes") + + def remove_suffix(cur_graph, cur_inps): + print("Strategy: Remove suffix") + assert graph_fails(cur_graph, cur_inps) + gap = 2**math.floor(math.log2(len(cur_graph.nodes))) + tested = set() + while gap >= 1: + new_graph = fx.Graph() + env = {} + for idx, node in enumerate(cur_graph.nodes): + new_node = new_graph.node_copy(node, lambda x: env[x]) + if node.op not in ['placeholder', 'output']: + if idx % gap == 0 and idx not in tested: + output_node = new_graph.output((new_node,)) + if graph_fails(new_graph, cur_inps) and len(new_graph.nodes) < len(cur_graph.nodes): + print() + print(f"SUCCESS: Removed [{idx}:{len(cur_graph.nodes)})") + return (new_graph, cur_inps), True + else: + tested.add(idx) + new_graph.erase_node(output_node) + env[node] = new_node + gap //= 2 + print("FAIL: Could not remove suffix") + return (cur_graph, cur_inps), False + + def remove_unused_inputs(cur_graph, cur_inps): + assert graph_fails(cur_graph, cur_inps) + ph_nodes = _get_placeholders(cur_graph) + if len(ph_nodes) != len(cur_inps): + print(cur_graph) + print(len(cur_inps)) + assert len(ph_nodes) == len(cur_inps) + + new_inps = [] + for idx in range(len(ph_nodes)): + if len(ph_nodes[idx].users) == 0: + cur_graph.erase_node(ph_nodes[idx]) + else: + new_inps.append(cur_inps[idx]) + + if len(new_inps) < len(cur_inps) and graph_fails(cur_graph, new_inps): + print("Strategy: Remove unused inputs") + print(f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs") + return (cur_graph, new_inps), True + else: + return (cur_graph, new_inps), False + + def eliminate_dead_code(cur_graph, cur_inps): + orig_size = len(cur_graph.nodes) + if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps): + print("Strategy: Eliminate dead code") + print(f"SUCCESS: Went from {orig_size} nodes to {len(cur_graph.nodes)} nodes") + return (cur_graph, cur_inps), True + else: + return (cur_graph, cur_inps), False + + def consolidate_placeholders(cur_graph): + new_graph = fx.Graph() + env = {} + for node in cur_graph.nodes: + if node.op == 'placeholder': + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + + for node in cur_graph.nodes: + if node.op != 'placeholder': + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + return new_graph + + def delta_debugging(cur_graph: fx.Graph, cur_inps): + print("Strategy: Delta Debugging") + assert graph_fails(cur_graph, cur_inps) + starting_placeholders = len(_get_placeholders(cur_graph)) + num_nodes = len(cur_graph.nodes) + gap = int(2**math.floor(math.log2(num_nodes))) + while gap >= 1: + for start_range in range(0, num_nodes, gap): + is_removing = False + new_graph = copy.deepcopy(cur_graph) + new_inps = cur_inps[:] + end_range = min(num_nodes, start_range + gap) + for idx in range(start_range, end_range): + new_node = list(new_graph.nodes)[idx] + if new_node.op not in ['placeholder', 'output']: + is_removing = True + _convert_node_to_placeholder(new_node, new_inps) + if not is_removing: + continue + new_graph = consolidate_placeholders(new_graph) + if graph_fails(new_graph, new_inps): + print( + f"SUCCESS: Removed ({start_range}:{end_range}] - Went from {starting_placeholders} " + f"placeholders to {len(_get_placeholders(new_graph))}" + ) + return (new_graph, new_inps), True + gap //= 2 + + print("FAIL: Could not remove prefix") + return (cur_graph, inps), False + + print("###################") + print(f"Current size: {len(failing_graph.nodes)}") + print("###################") + while True: + any_succeeded = False + strategies = [ + remove_suffix, eliminate_dead_code, remove_unused_inputs, + delta_debugging, eliminate_dead_code, remove_unused_inputs + ] + for strategy in strategies: + out = strategy(copy.deepcopy(failing_graph), inps[:]) + (cur_graph, cur_inps), succeeded = out + if succeeded: + print() + print("###################") + print(f"Current size: {len(cur_graph.nodes)}") + print("###################") + failing_graph = cur_graph + inps = cur_inps + any_succeeded = True + + if not any_succeeded: + break + failing_fx = fx.GraphModule(fail_f, failing_graph) + print(f""" +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] +{failing_fx.code} +f = torch.jit.script(forward) +with torch.jit.fuser("fuser2"): + for _ in range(5): + f(*inps)""") + return failing_fx, inps + + +def check_nvfuser_subprocess(f, inps): + f.to_folder("temp") + with open("_temp.py", 'w') as fil: + fil.write(f''' +import torch +from temp import FxModule +f = FxModule().cuda() +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.ones(shape, dtype=dtype, device='cuda') for shape, dtype in inps] +with torch.jit.fuser("fuser2"): + nf = torch.jit.script(f) + for _ in range(5): + nf(*inps) + ''') + p = subprocess.Popen(["PYTORCH_NVFUSER_DISABLE_FALLBACK=1 python _temp.py"], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + out, err = p.communicate() + if p.returncode != 0: + err = err.decode('utf-8') + print(err) + return True + return False + + +def check_nvfuser_correctness_subprocess(f, inps): + f.to_folder("temp") + with open("_temp.py", 'w') as fil: + fil.write(f''' +import torch +from temp import FxModule +f = FxModule().cuda() +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.randn(shape, dtype=dtype, device='cuda') + if dtype.is_floating_point else torch.ones(shape, dtype=dtype, device='cuda') + for shape, dtype in inps] + +ref = f(*inps) +nv_f = torch.jit.script(f) +with torch.jit.fuser("fuser2"): + for _ in range(5): + res = nv_f(*inps) +for a, b in zip(ref, res): + if not torch.allclose(a, b, atol=0.1): + exit(1) +''') + p = subprocess.Popen(["PYTORCH_NVFUSER_DISABLE_FALLBACK=1 python _temp.py"], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + out, err = p.communicate() + if p.returncode != 0: + err = err.decode('utf-8') + print(err) + return True + return False diff --git a/functorch/functorch/_src/make_functional.py b/functorch/functorch/_src/make_functional.py new file mode 100644 index 0000000000000..7b8c15196e23b --- /dev/null +++ b/functorch/functorch/_src/make_functional.py @@ -0,0 +1,543 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch import Tensor +from typing import List, Tuple +from .named_members_polyfill import _named_parameters, _named_buffers +import copy + +# Utilities to make nn.Module "functional" +# In particular the goal is to be able to provide a function that takes as input +# the parameters and evaluate the nn.Module using fixed inputs. + + +def _del_nested_attr(obj: nn.Module, names: List[str]) -> None: + """ + Deletes the attribute specified by the given list of names. + For example, to delete the attribute obj.conv.weight, + use _del_nested_attr(obj, ['conv', 'weight']) + """ + if len(names) == 1: + delattr(obj, names[0]) + else: + _del_nested_attr(getattr(obj, names[0]), names[1:]) + + +def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None: + """ + Set the attribute specified by the given list of names to value. + For example, to set the attribute obj.conv.weight, + use _del_nested_attr(obj, ['conv', 'weight'], value) + """ + if len(names) == 1: + setattr(obj, names[0], value) + else: + _set_nested_attr(getattr(obj, names[0]), names[1:], value) + + +def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: + if len(names) == 1: + return getattr(obj, names[0]) + else: + _get_nested_attr(getattr(obj, names[0]), names[1:]) + + +def raise_parameter_tying_error(): + raise RuntimeError( + "make_functional(module): we don't yet support models that " + "do parameter tying (also sometimes known as weight sharing). " + "Please try to rewrite your model by replacing all instances of the " + "tied parameter with another and/or comment your support in " + "https://github.com/pytorch/functorch/issues/446") + + +def create_names_map(named_params, tied_named_params): + """ + named_params is a dictionary of tensors: {'A': A, 'B': B} + tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} + with potentially tied (or 'duplicated') tensors + + This function creates a mapping from the names in named_params to the + names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. + """ + named_params = {k: v for k, v in named_params} + tied_named_params = {k: v for k, v in tied_named_params} + + tensors_dict_keys = set(named_params.keys()) + tied_tensors_dict_keys = set(tied_named_params.keys()) + assert tensors_dict_keys.issubset(tied_tensors_dict_keys) + + tensor_to_mapping = {} + for key, tensor in named_params.items(): + tensor_to_mapping[tensor] = (key, []) + for key, tensor in tied_named_params.items(): + assert tensor in tensor_to_mapping + tensor_to_mapping[tensor][1].append(key.split('.')) + result = {key: value for key, value in tensor_to_mapping.values()} + return result + + +def _extract_members(mod: nn.Module, _named_members, named_members, subclass): + all_named_members = tuple(_named_members(mod, remove_duplicate=False)) + named_members = tuple(named_members()) + names_map = create_names_map(named_members, all_named_members) + + # Remove all the members in the model + memo = {} + for name, p in all_named_members: + if p not in memo: + memo[p] = subclass(torch.empty_like(p, device='meta')) + replacement = memo[p] + _set_nested_attr(mod, name.split("."), replacement) + + if len(named_members) == 0: + names, params = (), () + else: + names, params = zip(*named_members) + return params, names, names_map + + +def extract_weights(mod: nn.Module): + """ + This function removes all the Parameters from the model and + return them as a tuple as well as their original attribute names. + The weights must be re-loaded with `load_weights` before the model + can be used again. + Note that this function modifies the model in place and after this + call, mod.parameters() will be empty. + """ + return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter) + + +def extract_buffers(mod: nn.Module): + return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x) + + +def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None: + """ + Reload a set of weights so that `mod` can be used again to perform a forward pass. + Note that the `params` are regular Tensors (that can have history) and so are left + as Tensors. This means that mod.parameters() will still be empty after this call. + """ + for name, p in zip(names, params): + if as_params: + p = nn.Parameter(p) + _del_nested_attr(mod, name.split(".")) + _set_nested_attr(mod, name.split("."), p) + + +def _swap_state(mod: nn.Module, names_map: List[str], elems): + result = [] + for (_, attr_names), elem in zip(names_map.items(), elems): + for i, attr_name in enumerate(attr_names): + if i == 0: + result.append(_get_nested_attr(mod, attr_name)) + _del_nested_attr(mod, attr_name) + _set_nested_attr(mod, attr_name, elem) + return result + + +def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None: + for name, p in zip(names, buffers): + _set_nested_attr(mod, name.split("."), p) + + +def load_state( + model: nn.Module, + weights: List[Tensor], weight_names: List[str], + buffers=(), buffer_names=()): + """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model + + load_state takes `weights` and `buffers` and assigns them to the model. + This is the inverse operation of `make_functional_deprecated_v1`. + """ + assert len(weight_names) == len(weights) + load_weights(model, weight_names, weights) + if len(buffers) > 0: + assert len(buffer_names) == len(buffers) + load_buffers(model, buffer_names, buffers) + return model + + +def make_functional_deprecated_v1(model: nn.Module): + """make_functional_deprecated_v1(model) -> weights, func, weight_names + + Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights) + and returns a functional version of the model, `func`. This makes + it so that it is possible use transforms over the parameters of + `model`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, func, _ = make_functional_deprecated_v1(model) + func(weights, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, _, func = make_functional_deprecated_v1(model) + grad_weights = grad(func)(weights, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use ' + 'make_functional_with_buffers_deprecated_v1(model) instead.') + weights, descriptors, _ = extract_weights(model) + + def fun(weights, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, descriptors, weights) + return mutable_model(*data) + + return weights, fun, descriptors + + +def make_functional_with_buffers_deprecated_v1(model: nn.Module): + """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names + + Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers) + and returns a functional version of the model, `func`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + grad_weights = grad(func)(weights, buffers, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + weights, weight_descriptors, _ = extract_weights(model) + buffers, buf_descriptors, _ = extract_buffers(model) + + def fun(weights, buffers, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, weight_descriptors, weights) + load_buffers(mutable_model, buf_descriptors, buffers) + return mutable_model(*data) + + return weights, buffers, fun, weight_descriptors, buf_descriptors + + +class FunctionalModuleWithBuffers(nn.Module): + """ + This is the callable object returned by :func:`make_functional_with_buffers`. + """ + + def __init__(self, stateless_model, param_names, buffer_names, + param_names_map, buffer_names_map): + super(FunctionalModuleWithBuffers, self).__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.buffer_names = buffer_names + + self.all_names_map = dict(param_names_map) + self.all_names_map.update(buffer_names_map) + + @staticmethod + def _create_from(model, disable_autograd_tracking=False): + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, param_names_map = extract_weights(model_copy) + buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return ( + FunctionalModuleWithBuffers(model_copy, param_names, buffer_names, + param_names_map, buffer_names_map), + params, + buffers, + ) + + def forward(self, params, buffers, *args, **kwargs): + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state( + self.stateless_model, + self.all_names_map, + list(params) + list(buffers)) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.all_names_map, old_state) + + +class FunctionalModule(nn.Module): + """ + This is the callable object returned by :func:`make_functional`. + """ + + def __init__(self, stateless_model, param_names, names_map): + super(FunctionalModule, self).__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.names_map = names_map + + @staticmethod + def _create_from(model, disable_autograd_tracking=False): + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, names_map = extract_weights(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return FunctionalModule(model_copy, param_names, names_map), params + + def forward(self, params, *args, **kwargs): + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state(self.stateless_model, self.names_map, params) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.names_map, old_state) + + +def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): + """make_functional(model, disable_autograd_tracking=False) -> func, params + + Given a ``torch.nn.Module``, :func:`make_functional` extracts the state + (params) and returns a functional version of the model, ``func``. This + makes it so that it is possible use transforms over the parameters of + ``model``. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + func(params, x) + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + + def compute_loss(params, x, t): + y = func(params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, x, t) + + If the model has any buffers, please use :func:`make_functional_with_buffers` instead. + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError('make_functional(model): `model` has buffers. Please use ' + 'make_functional_with_buffers(model) instead.') + return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking) + + +def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False): + """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers + + Given a ``torch.nn.Module``, make_functional_with_buffers extracts the + state (params and buffers) and returns a functional version of the model + ``func`` that can be invoked like a function. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + func(params, buffers, x) + + And here is an example of applying the grad transform over the parameters + of a model: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + + def compute_loss(params, buffers, x, t): + y = func(params, buffers, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, buffers, x, t) + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + return FunctionalModuleWithBuffers._create_from(model, disable_autograd_tracking=disable_autograd_tracking) + + +def transpose_stack(tuple_of_tuple_of_tensors): + tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) + results = tuple(torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors) + return results + + +def combine_state_for_ensemble(models): + """combine_state_for_ensemble(models) -> func, params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their + parameters and buffers together to make ``params`` and ``buffers``. + Each parameter and buffer in the result will have an additional dimension + of size ``M``. + + :func:`combine_state_for_ensemble` also returns ``func``, a functional + version of one of the models in :attr:`models`. One cannot directly run + ``func(params, buffers, *args, **kwargs)`` directly, you probably want to + use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + fmodel, params, buffers = combine_state_for_ensemble(models) + output = vmap(fmodel, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + + This API is subject to change -- we're investigating better ways to + create ensembles and would love your feedback how to improve this. + """ + if len(models) == 0: + raise RuntimeError('combine_state_for_ensemble: Expected at least one model, got 0.') + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError('combine_state_for_ensemble: Expected all models to ' + 'have the same training/eval mode.') + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError('combine_state_for_ensemble: Expected all models to ' + 'be of the same class.') + funcs, params, buffers = zip(*[make_functional_with_buffers(model) + for model in models]) + params = transpose_stack(params) + buffers = transpose_stack(buffers) + return funcs[0], params, buffers + + +def functional_init(model_class, ensemble_shape=(), device='cpu'): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError('NYI: ensemble_shape with more than 1 element') + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple(model_class(*args, **kwargs).to(device) + for _ in range(num_models)) + _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs)) + weights = tuple(make_functional_deprecated_v1(model)[0] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights, fn, names + return wrapped + + +def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError('NYI: ensemble_shape with more than 1 element') + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple(model_class(*args, **kwargs).to(device) + for _ in range(num_models)) + _, _, fn, weight_names, buffer_names = \ + make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs)) + weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2] + for model in models)) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + buffers = tuple(zip(*buffers)) + buffers = tuple(torch.stack(shards).detach() for shards in buffers) + return weights, buffers, fn, weight_names, buffer_names + return wrapped diff --git a/functorch/functorch/_src/monkey_patching.py b/functorch/functorch/_src/monkey_patching.py new file mode 100644 index 0000000000000..04ba4ab9bb63e --- /dev/null +++ b/functorch/functorch/_src/monkey_patching.py @@ -0,0 +1,80 @@ +import torch +import functorch._C as _C +import functools + +# Monkeypatch tensor printing in pytorch +_old_str = torch._tensor_str._str + + +def prep_value(text, indent=4): + first_line_txt = '' + lines = text.split('\n') + lines[0] = lines[0] + lines[0] = ' ' * indent + first_line_txt + lines[0] + for i in range(1, len(lines)): + lines[i] = ' ' * (indent + len(first_line_txt)) + lines[i] + return '\n'.join(lines) + + +@functools.wraps(_old_str) +def _functorch_str(tensor, *, tensor_contents=None): + level = _C.maybe_get_level(tensor) + if level == -1: + return _old_str(tensor) + + if _C.is_functionaltensor(tensor): + # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure + # that it's up to date first + torch._sync(tensor) + + value = _C.get_unwrapped(tensor) + dl_enabled = _C.tls_set_is_included() + try: + # Disable temporarily kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys + if (dl_enabled): + _C._set_dynamic_layer_keys_included(False) + value_repr = repr(value) + finally: + # Reenable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys + if (dl_enabled): + _C._set_dynamic_layer_keys_included(True) + + if _C.is_batchedtensor(tensor): + bdim = _C.maybe_get_bdim(tensor) + assert bdim != -1 + return ( + f'BatchedTensor(lvl={level}, bdim={bdim}, value=\n' + f'{prep_value(value_repr)}\n' + f')' + ) + if _C.is_gradtrackingtensor(tensor): + return ( + f'GradTrackingTensor(lvl={level}, value=\n' + f'{prep_value(value_repr)}\n' + f')' + ) + if _C.is_functionaltensor(tensor): + return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})' + + raise ValueError("We don't know how to print this, please file us an issue") + + +torch._tensor_str._str = _functorch_str + + +# Monkeypatch .backward() to error out if any transforms are active. +# TODO: remove the monkeypatching and add an extension point into PyTorch core +_old_backward = torch.Tensor.backward + + +@functools.wraps(_old_backward) +def _backward(*args, **kwargs): + if _C.are_transforms_active(): + raise RuntimeError( + "backward() called inside a functorch transform. This is not " + "supported, please use functorch.grad or functorch.vjp instead " + "or call backward() outside of functorch transforms.") + return _old_backward(*args, **kwargs) + + +torch.Tensor.backward = _backward diff --git a/functorch/functorch/_src/named_members_polyfill.py b/functorch/functorch/_src/named_members_polyfill.py new file mode 100644 index 0000000000000..80704eb551adb --- /dev/null +++ b/functorch/functorch/_src/named_members_polyfill.py @@ -0,0 +1,32 @@ +# Polyfilled from pytorch core while we figure out the `remove_duplicate` issues. +def _named_members(mod, get_members_fn, prefix='', recurse=True, remove_duplicate=True): + r"""Helper method for yielding various names + members of modules.""" + memo = set() + modules = mod.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, mod)] + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ('.' if module_prefix else '') + k + yield name, v + + +def _named_parameters(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True): + gen = _named_members( + mod, + lambda module: module._parameters.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + for elem in gen: + yield elem + + +def _named_buffers(mod, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True): + gen = _named_members( + mod, + lambda module: module._buffers.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + for elem in gen: + yield elem diff --git a/functorch/functorch/_src/partitioners.py b/functorch/functorch/_src/partitioners.py new file mode 100644 index 0000000000000..88b6e447ba493 --- /dev/null +++ b/functorch/functorch/_src/partitioners.py @@ -0,0 +1,443 @@ +import torch +import torch.fx as fx +import operator +import math +import torch.utils._pytree as pytree +import copy +import os +from collections import defaultdict +from torch.fx.passes import graph_drawer +from typing import Tuple +from .compile_utils import fx_graph_cse, get_aten_target +from . import config + +AOT_PARTITIONER_DEBUG = config.debug_partitioner + +INDUCTOR = False + + +class InvalidNodeBase(object): + def __repr__(self): + return "Invalid Node" + + +InvalidNode = InvalidNodeBase() + + +def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): + """ + Given a graph, extracts out a subgraph that takes the specified nodes as + inputs and returns the specified outputs. + + This includes specifying non-placeholder nodes as inputs. + + The general strategy is to initialize all inputs with proxies as we + encounter them, and trace through the graph, only keeping values which take + in valid proxies. Then, all dead code is eliminated. + """ + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in inputs: + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + for node in joint_graph.nodes: + if node in inputs: + continue + elif node.op == 'placeholder': + env[node] = InvalidNode + elif node.op == 'call_function': + all_args = pytree.tree_flatten((node.args, node.kwargs))[0] + all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)] + if any(all_args): + env[node] = InvalidNode + continue + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'get_attr': + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'output': + pass + output_values = [] + for x in outputs: + if isinstance(x, fx.Node): + if x not in env: + raise RuntimeError(f"Node {x} couldn't be found in env") + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(output_values) + + new_graph.eliminate_dead_code() + new_graph.lint() + return new_graph + + +def _is_primal(node): + return node.op == "placeholder" and "tangents" not in node.target + + +def _is_tangent(node): + return node.op == "placeholder" and "tangents" in node.target + + +def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule): + num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves + outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if node.op == 'output'])[0] + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs + + +def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values): + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes)) + # Construct the forward module + fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs + saved_values) + bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs) + + # This is to filter out saved values that don't actually end up being used by the backwards pass + for node in bwd_graph.nodes: + if node.op == 'placeholder' and not node.users: + for saved_value in saved_values: + if saved_value.name == node.name: + saved_values.remove(saved_value) + break + + # Now, we re-generate the fwd/bwd graphs. + # NB: This might increase compilation time, but I doubt it matters + fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs + saved_values) + bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs) + + fwd_module = fx.GraphModule(joint_module, fwd_graph) + bwd_module = fx.GraphModule(joint_module, bwd_graph) + return fwd_module, bwd_module + + +def default_partition( + joint_module: fx.GraphModule, _joint_inputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the :attr:`joint_module` in a manner that closely resembles the + behavior observed in the original ``.forward()`` and ``.backward()`` of the + callable, i.e., the resulting forward graph contains those operators that + are executed in the original ``.forward()`` callable passed to + :func:`aot_function`. + + The default partitioner collects the operators that are between the forward + inputs and the forward outputs. This helps in finding the tensors which have + to be stashed for the backward pass. These stashed tensors become the output + of the generated forward graph. The remaining operators are then placed in + the backward graph. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) + forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'} + saved_values = [] + for node in joint_module.graph.nodes: + if node.name not in forward_node_names: + continue + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + if 'tensor_meta' not in node.meta and node.op == 'call_function': + users = node.users + assert all(user.target == operator.getitem for user in users) + for user in users: + saved_values.append(user) + else: + saved_values.append(node) + saved_values = list(set(saved_values)) + + return _extract_fwd_bwd_modules(joint_module, saved_values) + + +def _prod(x): + s = 1 + for i in x: + s *= i + return s + + +def _size_of(metadata): + sizes = { + torch.float: 4, + torch.float16: 2, + torch.bfloat16: 2, + torch.float32: 4, + torch.float64: 8, + torch.int: 4, + torch.int8: 1, + torch.int16: 2, + torch.int32: 4, + torch.int64: 8, + torch.uint8: 1, + torch.bool: 1, + } + + numel = _prod(metadata.shape) + dtype = metadata.dtype + + if dtype not in sizes: + raise NotImplementedError("Don't know the size of dtype ", dtype) + + return numel * sizes[dtype] + + +# Used for some investigative purposes +def _count_ops(graph): + from collections import defaultdict + cnt = defaultdict(int) + for node in graph.nodes: + if node.op == 'call_function': + cnt[node.target.__name__] += 1 + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, _joint_inputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimintation. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + try: + import networkx as nx + except ImportError: + raise RuntimeError("Need networkx installed to perform smart recomputation heuristics") + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + fx_g = joint_module.graph + + # add the CSE pass + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + full_bw_graph = joint_module.graph + + name_to_node = {} + for node in joint_module.graph.nodes: + name_to_node[node.name] = node + + def classify_nodes(joint_module): + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == 'placeholder' and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) + required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes + if node.op != 'output'} + unclaimed_nodes = {node for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes} + return required_fw_nodes, required_bw_nodes, unclaimed_nodes + + required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module) + for node in reversed(joint_module.graph.nodes): + if node not in required_fw_nodes: + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + aten = torch.ops.aten + prims = torch.ops.prims + + pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward] # noqa: E501 + if INDUCTOR: + pointwise_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone] # noqa: E501 + misc_ops = [aten.to, aten.type_as, operator.getitem] + + reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax] # noqa: E501 + if INDUCTOR: + reduction_ops += [prims.var, prims.sum, aten.var] + + # not recomputed by default since these are kinda expensive/hard to fuse into + # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward] # noqa: E501 + + # Not used by default since NVFuser can't fuse view ops + # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat] # noqa: E501 + + # These are the view ops that NVFuser can fuse + view_ops = [aten.squeeze, aten.unsqueeze] + if INDUCTOR: + view_ops += [prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors] # noqa: E501 + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d] # noqa: E501 + + unrecomputable_ops = random_ops + compute_intensive_ops + + recomputable_ops = set( + pointwise_ops + + misc_ops + + reduction_ops + + view_ops + ) + fusible_ops = recomputable_ops | set(random_ops) + if AOT_PARTITIONER_DEBUG: + joint_module_ops = set( + str(node.target._overloadpacket) + for node in joint_module.graph.nodes + if node.op == "call_function" and hasattr(node.target, "_overloadpacket") + ) + ops_ignored = joint_module_ops - set([str(i) for i in recomputable_ops]) + print("Ops banned from rematerialization: ", ops_ignored) + print() + + AGGRESSIVE_RECOMPUTATION = False + + def _maybe_size_of(node): + if 'tensor_meta' in node.meta: + return _size_of(node.meta['tensor_meta']) + return 0 + + def ban_recomputation(node): + if AGGRESSIVE_RECOMPUTATION: + return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops) + else: + if node.op != 'call_function': + return False + if get_aten_target(node) not in recomputable_ops: + return True + if node.target == operator.getitem: + return False + # If the output of an op is 4x smaller (arbitrary choice), + # then we don't allow recomputation. + if 'tensor_meta' not in node.meta: + return False + input_tensors_size = sum(_maybe_size_of(i) for i in node.args if isinstance(i, fx.Node)) + output_size = _size_of(node.meta['tensor_meta']) + return (output_size * 4 < input_tensors_size) + + def is_fusible(a, b): + return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops + + def is_materialized(node): + if node.op == 'placeholder': + return True + + return not all(is_fusible(node, user) for user in node.users) + + def get_node_weight(node): + mem_sz = _size_of(node.meta['tensor_meta']) + + # Heuristic to bias towards nodes closer to the backwards pass + # Complete guess about current value + mem_sz = int(mem_sz * (1.5 ** max(min(node.dist_from_bw, 100), 1))) + # mem_sz = int(mem_sz + node.dist_from_bw) + + if is_materialized(node): + return mem_sz + else: + return mem_sz * 2 + + nx_graph = nx.DiGraph() + for node in full_bw_graph.nodes: + if node.op == 'output': + continue + + if node in required_bw_nodes: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + + if node.op == 'placeholder' and "primals" in node.target: + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + # If a node can't be recomputed (too expensive or involves randomness), + # we prevent it from being recomputed by adding an inf edge to the source + # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. + if ban_recomputation(node) and node in required_fw_nodes: + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + if 'tensor_meta' not in node.meta: + weight = math.inf + else: + weight = get_node_weight(node) + + # Creates the weights on the "node" edge + nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) + for user in node.users: + nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) + + cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") + reachable, non_reachable = partition + cutset = set() + for u, nbrs in ((n, nx_graph[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + cut_nodes = set() + for node_in, node_out in cutset: + assert node_in[:-3] == node_out[:-4] + node_name = node_in[:-3] + cut_nodes.add(node_name) + + # To make this stuff deterministic + node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]) + fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values) + if AOT_PARTITIONER_DEBUG: + print("Theoretical Activations Stored: ", sum([_size_of(i.meta['tensor_meta']) for i in saved_values]) / 1e9) + fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function']) + bw_module_nodes = set([node.name for node in bw_module.graph.nodes if node.op == 'call_function']) + remat_nodes = fw_module_nodes & bw_module_nodes + + counts = defaultdict(int) + for node in fw_module.graph.nodes: + if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'): + counts[str(node.target._overloadpacket)] += 1 + print("# nodes rematerialized: ", len(remat_nodes)) + print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True)) + return fw_module, bw_module + + +def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True): + if clear_meta: + new_graph = copy.deepcopy(traced.graph) + traced = fx.GraphModule(traced, new_graph) + for node in traced.graph.nodes: + node.meta = {} + base, ext = os.path.splitext(fname) + if not ext: + ext = ".svg" + print(f"Writing FX graph to file: {base}{ext}") + g = graph_drawer.FxGraphDrawer(traced, figname) + x = g.get_main_dot_graph() + getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}") + + +def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"): + draw_graph(graph, file_name) + return default_partition(graph, joint_inputs) diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py new file mode 100644 index 0000000000000..5fe0aff691ca5 --- /dev/null +++ b/functorch/functorch/_src/python_key.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"] +from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose + +pythonkey_decompose = decompose +PythonTensor = ProxyTensor diff --git a/functorch/functorch/_src/pytree_hacks.py b/functorch/functorch/_src/pytree_hacks.py new file mode 100644 index 0000000000000..3694a53d7debb --- /dev/null +++ b/functorch/functorch/_src/pytree_hacks.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.utils._pytree import tree_flatten, tree_unflatten + + +def tree_map_(fn_, pytree): + flat_args, _ = tree_flatten(pytree) + [fn_(arg) for arg in flat_args] + return pytree + + +class PlaceHolder(): + def __repr__(self): + return '*' + + +def treespec_pprint(spec): + leafs = [PlaceHolder() for _ in range(spec.num_leaves)] + result = tree_unflatten(leafs, spec) + return repr(result) diff --git a/functorch/functorch/_src/top_operators_github_usage.py b/functorch/functorch/_src/top_operators_github_usage.py new file mode 100644 index 0000000000000..9161f98d66faf --- /dev/null +++ b/functorch/functorch/_src/top_operators_github_usage.py @@ -0,0 +1,623 @@ +""" +From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 +Try to keep this list in sync with that. +""" +top_torch = [ + ("t", 6837449), + ("tensor", 585786), + ("mode", 462182), + ("cat", 394818), + ("max", 368038), + ("zeros", 329495), + ("load", 327756), + ("no_grad", 294694), + ("save", 265130), + ("from_numpy", 243063), + ("manual_seed", 165044), + ("ones", 153696), + ("randn", 150796), + ("stack", 133358), + ("sum", 130772), + ("arange", 98087), + ("rand", 94715), + ("mean", 88546), + ("exp", 73883), + ("zeros_like", 72831), + ("min", 72248), + ("sigmoid", 66798), + ("log", 62135), + ("matmul", 47811), + ("clamp", 45304), + ("sqrt", 44911), + ("abs", 43535), + ("tanh", 42793), + ("empty", 40311), + ("argmax", 38435), + ("bmm", 33984), + ("pow", 33571), + ("norm", 31125), + ("mm", 30995), + ("is_tensor", 29546), + ("ones_like", 29512), + ("nonzero", 28681), + ("full", 28373), + ("unsqueeze", 27911), + ("where", 26585), + ("randperm", 26450), + ("eye", 24342), + ("mul", 23236), + ("topk", 22537), + ("as_tensor", 21967), + ("sort", 21412), + ("squeeze", 20863), + ("randint", 20771), + ("linspace", 20041), + ("add", 19201), + ("transpose", 18663), + ("split", 18325), + ("gather", 17904), + ("set_grad_enabled", 16013), + ("sin", 15669), + ("cos", 15562), + ("div", 15513), + ("index_select", 14866), + ("multinomial", 14331), + ("flatten", 14267), + ("isnan", 14170), + ("randn_like", 13096), + ("eq", 12680), + ("einsum", 12480), + ("round", 12367), + ("floor", 11628), + ("allclose", 11000), + ("reshape", 10605), + ("diag", 10167), + ("chunk", 9581), + ("std", 9379), + ("set_default_tensor_type", 9281), + ("triu", 8559), + ("meshgrid", 8292), + ("set_num_threads", 8126), + ("unique", 7964), + ("full_like", 7780), + ("tril", 7538), + ("dot", 7275), + ("sign", 6943), + ("equal", 6916), + ("normal", 6750), + ("cumsum", 6556), + ("dist", 6058), + ("isfinite", 6030), + ("gt", 5935), + ("set_printoptions", 5888), + ("range", 5491), + ("empty_like", 5351), + ("flip", 5342), + ("masked_select", 5341), + ("bernoulli", 5262), + ("atan", 5253), + ("var", 5247), + ("prod", 5200), + ("erf", 5088), + ("inverse", 5072), + ("addmm", 4854), + ("logsumexp", 4582), + ("fft", 4436), + ("lt", 4421), + ("log2", 4316), + ("enable_grad", 4238), + ("rand_like", 4187), + ("argsort", 3972), + ("seed", 3932), + ("mv", 3547), + ("ger", 3309), + ("ge", 3248), + ("atan2", 3210), + ("ceil", 3202), + ("ne", 3075), + ("bincount", 3063), + ("acos", 3055), + ("rsqrt", 3031), + ("svd", 3029), + ("numel", 3003), + ("log1p", 2840), + ("unbind", 2808), + ("le", 2714), + ("isinf", 2707), + ("cross", 2646), + ("set_default_dtype", 2536), + ("argmin", 2535), + ("sparse_coo_tensor", 2489), + ("log10", 2304), + ("kthvalue", 2192), + ("set_rng_state", 2158), + ("get_rng_state", 1996), + ("get_default_dtype", 1879), + ("det", 1868), + ("qr", 1864), + ("histc", 1852), + ("symeig", 1832), + ("trace", 1801), + ("median", 1795), + ("addcmul", 1751), + ("remainder", 1717), + ("baddbmm", 1693), + ("lgamma", 1665), + ("repeat_interleave", 1598), + ("fmod", 1576), + ("reciprocal", 1575), + ("tan", 1560), + ("initial_seed", 1532), + ("take", 1529), + ("stft", 1487), + ("get_num_threads", 1477), + ("real", 1459), + ("cholesky", 1406), + ("quantize_per_tensor", 1392), + ("diag_embed", 1364), + ("lerp", 1363), + ("asin", 1345), + ("eig", 1333), + ("trunc", 1290), + ("diagonal", 1287), + ("cosh", 1279), + ("rfft", 1269), + ("cumprod", 1260), + ("addr", 1211), + ("roll", 1198), + ("narrow", 1188), + ("digamma", 1172), + ("square", 1163), + ("sinh", 1131), + ("logspace", 1084), + ("broadcast_tensors", 1070), + ("irfft", 1013), + ("frac", 997), + ("hann_window", 994), + ("solve", 989), + ("logdet", 977), + ("expm1", 968), + ("cdist", 946), + ("addmv", 903), + ("randint_like", 888), + ("tensordot", 888), + ("ifft", 877), + ("true_divide", 854), + ("erfinv", 830), + ("addcdiv", 819), + ("addbmm", 813), + ("renorm", 781), + ("pinverse", 753), + ("isclose", 740), + ("erfc", 729), + ("is_storage", 725), + ("triangular_solve", 723), + ("rot90", 709), + ("logical_not", 686), + ("geqrf", 681), + ("slogdet", 677), + ("lu", 665), + ("hamming_window", 659), + ("orgqr", 651), + ("ormqr", 622), + ("is_floating_point", 602), + ("diagflat", 562), + ("cholesky_solve", 559), + ("tril_indices", 552), + ("chain_matmul", 551), + ("triu_indices", 548), + ("angle", 522), + ("poisson", 505), + ("matrix_power", 485), + ("unique_consecutive", 471), + ("quantize_per_channel", 465), + ("std_mean", 458), + ("bartlett_window", 447), + ("var_mean", 428), + ("lstsq", 421), + ("logical_and", 419), + ("mvlgamma", 411), + ("blackman_window", 400), + ("bitwise_not", 395), + ("cholesky_inverse", 388), + ("as_strided", 384), + ("floor_divide", 353), + ("cartesian_prod", 321), + ("lu_solve", 317), + ("set_flush_denormal", 310), + ("empty_strided", 283), + ("logical_xor", 282), + ("polygamma", 282), + ("logical_or", 280), + ("set_num_interop_threads", 278), + ("combinations", 274), + ("trapz", 270), + ("matrix_rank", 260), + ("lu_unpack", 255), + ("result_type", 244), + ("conj", 231), + ("cummax", 230), + ("lobpcg", 229), + ("bitwise_xor", 217), + ("promote_types", 213), + ("get_num_interop_threads", 211), + ("cummin", 205), + ("bitwise_and", 198), + ("dequantize", 192), + ("bitwise_or", 191), + ("imag", 191), + ("can_cast", 184), + ("istft", 180), + ("compiled_with_cxx11_abi", 159), + ("is_complex", 151), + ("block_diag", 136), + ("pca_lowrank", 124), + ("absolute", 122), + ("svd_lowrank", 108), + ("neg", 2), +] + +top_nn_functional = [ + ("nn.functional.softmax", 10522), + ("nn.functional.relu", 8572), + ("nn.functional.interpolate", 7277), + ("nn.functional.pad", 5207), + ("nn.functional.log_softmax", 4699), + ("nn.functional.normalize", 2338), + ("nn.functional.cross_entropy", 2083), + ("nn.functional.grid_sample", 1970), + ("nn.functional.one_hot", 1967), + ("nn.functional.mse_loss", 1920), + ("nn.functional.conv2d", 1593), + ("nn.functional.dropout", 1516), + ("nn.functional.softplus", 1385), + ("nn.functional.sigmoid", 1128), + ("nn.functional.linear", 1036), + ("nn.functional.gelu", 930), + ("nn.functional.avg_pool2d", 899), + ("nn.functional.max_pool2d", 876), + ("nn.functional.nll_loss", 863), + ("nn.functional.embedding", 737), + ("nn.functional.tanh", 664), + ("nn.functional.leaky_relu", 640), + ("nn.functional.adaptive_avg_pool2d", 633), + ("nn.functional.cosine_similarity", 627), + ("nn.functional.unfold", 609), + ("nn.functional.conv1d", 596), + ("nn.functional.binary_cross_entropy_with_logits", 591), + ("nn.functional.l1_loss", 571), + ("nn.functional.binary_cross_entropy", 492), + ("nn.functional.elu", 416), + ("nn.functional.batch_norm", 413), + ("nn.functional.upsample", 413), + ("nn.functional.fold", 305), + ("nn.functional.affine_grid", 298), + ("nn.functional.max_pool1d", 297), + ("nn.functional.torch", 294), + ("nn.functional.threshold", 263), + ("nn.functional.smooth_l1_loss", 262), + ("nn.functional.pairwise_distance", 253), + ("nn.functional.logsigmoid", 243), + ("nn.functional.adaptive_max_pool2d", 235), + ("nn.functional.relu6", 213), + ("nn.functional.pixel_shuffle", 209), + ("nn.functional.avg_pool3d", 203), + ("nn.functional.bilinear", 203), + ("nn.functional.conv_transpose2d", 201), + ("nn.functional.gumbel_softmax", 197), + ("nn.functional.max_unpool2d", 196), + ("nn.functional.kl_div", 191), + ("nn.functional.hardtanh", 189), + ("nn.functional.ctc_loss", 185), + ("nn.functional.layer_norm", 178), + ("nn.functional.conv3d", 172), + ("nn.functional.max_unpool3d", 167), + ("nn.functional.hardshrink", 165), + ("nn.functional.hardswish", 156), + ("nn.functional.selu", 156), + ("nn.functional.glu", 155), + ("nn.functional.assert_int_or_pair", 150), + ("nn.functional.hardsigmoid", 146), + ("nn.functional.upsample_bilinear", 146), + ("nn.functional.max_pool3d", 140), + ("nn.functional.adaptive_avg_pool3d", 139), + ("nn.functional.instance_norm", 124), + ("nn.functional.embedding_bag", 122), + ("nn.functional.upsample_nearest", 110), + ("nn.functional.avg_pool1d", 105), + ("nn.functional.prelu", 102), + ("nn.functional.celu", 92), + ("nn.functional.dropout2d", 86), + ("nn.functional.hinge_embedding_loss", 82), + ("nn.functional.softsign", 81), + ("nn.functional.max_unpool1d", 74), + ("nn.functional.silu", 74), + ("nn.functional.softshrink", 70), + ("nn.functional.leaky_relu_", 68), + ("nn.functional.softmin", 67), + ("nn.functional.channel_shuffle", 66), + ("nn.functional.multilabel_margin_loss", 66), + ("nn.functional.dropout3d", 65), + ("nn.functional.multi_margin_loss", 65), + ("nn.functional.lp_pool2d", 64), + ("nn.functional.conv_transpose1d", 62), + ("nn.functional.triplet_margin_loss", 62), + ("nn.functional.tanhshrink", 61), + ("nn.functional.adaptive_max_pool1d", 59), + ("nn.functional.cosine_embedding_loss", 58), + ("nn.functional.multi_head_attention_forward", 58), + ("nn.functional.max_pool1d_with_indices", 53), + ("nn.functional.poisson_nll_loss", 53), + ("nn.functional.margin_ranking_loss", 52), + ("nn.functional.soft_margin_loss", 52), + ("nn.functional.adaptive_max_pool3d", 51), + ("nn.functional.group_norm", 51), + ("nn.functional.local_response_norm", 51), + ("nn.functional.multilabel_soft_margin_loss", 51), + ("nn.functional.relu_", 50), + ("nn.functional.alpha_dropout", 49), + ("nn.functional.feature_alpha_dropout", 49), + ("nn.functional.lp_pool1d", 49), + ("nn.functional.adaptive_max_pool1d_with_indices", 48), + ("nn.functional.adaptive_max_pool2d_with_indices", 48), + ("nn.functional.adaptive_max_pool3d_with_indices", 48), + ("nn.functional.fractional_max_pool2d", 48), + ("nn.functional.fractional_max_pool2d_with_indices", 48), + ("nn.functional.fractional_max_pool3d", 48), + ("nn.functional.fractional_max_pool3d_with_indices", 48), + ("nn.functional.max_pool2d_with_indices", 48), + ("nn.functional.max_pool3d_with_indices", 48), + ("nn.functional.handle_torch_function", 47), + ("nn.functional.has_torch_function", 47), + ("nn.functional.adaptive_avg_pool1d", 43), + ("nn.functional.pdist", 43), + ("nn.functional.rrelu_", 37), + ("nn.functional.elu_", 34), + ("nn.functional.boolean_dispatch", 33), + ("nn.functional.hardtanh_", 26), + ("nn.functional.triplet_margin_with_distance_loss", 23), + ("nn.functional.selu_", 20), + ("nn.functional.pixel_unshuffle", 19), + ("nn.functional.conv_transpose3d", 18), + ("nn.functional.gaussian_nll_loss", 15), + ("nn.functional.has_torch_function_unary", 15), + ("nn.functional.has_torch_function_variadic", 15), + ("nn.functional.celu_", 13), + ("nn.functional.huber_loss", 7), + ("nn.functional.mish", 4), + ("nn.functional.threshold_", 3), + ("nn.functional.grad", 2), + ("nn.functional.conv_tbc", 1), + ("nn.functional.math", 1), +] + +top_nn_module = [ + ("nn.Module", 927129, None), + ("nn.Linear", 530688, "nn.functional.linear"), + ("nn.Sequential", 384968, None), + ("nn.Conv2d", 383320, "nn.functional.conv2d"), + ("nn.ReLU", 318877, "nn.functional.relu"), + ("nn.BatchNorm2d", 233265, "nn.functional.batch_norm"), + ("nn.Dropout", 179268, "nn.functional.dropout"), + ("nn.ModuleList", 171225, None), + ("nn.Parameter", 153291, None), + ("nn.CrossEntropyLoss", 152696, "nn.functional.cross_entropy"), + ("nn.MaxPool2d", 138619, "nn.functional.max_pool2d"), + ("nn.Embedding", 111844, "nn.functional.embedding"), + ("nn.DataParallel", 104238, None), + ("nn.MSELoss", 82954, "nn.functional.mse_loss"), + ("nn.Sigmoid", 75810, "nn.functional.sigmoid"), + ("nn.LeakyReLU", 65632, "nn.functional.leaky_relu"), + ("nn.BatchNorm1d", 65374, "nn.functional.batch_norm"), + ("nn.Softmax", 65114, "nn.functional.softmax"), + ("nn.Tanh", 59445, "nn.functional.tanh"), + ("nn.AdaptiveAvgPool2d", 59071, "nn.functional.adaptive_avg_pool2d"), + ("nn.AvgPool2d", 58377, "nn.functional.avg_pool2d"), + ("nn.ConvTranspose2d", 57524, "nn.functional.conv_transpose2d"), + ("nn.LSTM", 57411, None), + ("nn.Conv1d", 41108, "nn.functional.conv1d"), + ("nn.LayerNorm", 36089, "nn.functional.layer_norm"), + ("nn.BCELoss", 34005, "nn.functional.binary_cross_entropy"), + ("nn.Upsample", 32527, "nn.functional.interpolate"), + ("nn.BCEWithLogitsLoss", 29944, "nn.functional.binary_cross_entropy_with_logits"), + ("nn.GRU", 25421, None), + ("nn.Dropout2d", 23512, "nn.functional.dropout2d"), + ("nn.LogSoftmax", 22897, "nn.functional.log_softmax"), + ("nn.L1Loss", 22778, "nn.functional.l1_loss"), + ("nn.GroupNorm", 22183, "nn.functional.group_norm"), + ("nn.NLLLoss", 21751, "nn.functional.nll_loss"), + ("nn.Conv3d", 20874, "nn.functional.conv3d"), + ("nn.Identity", 17911, None), + ("nn.InstanceNorm2d", 16426, "nn.functional.instance_norm"), + ("nn.BatchNorm3d", 16378, "nn.functional.batch_norm"), + ("nn.PReLU", 13472, "nn.functional.prelu"), + ("nn.ReLU6", 12622, "nn.functional.relu6"), + ("nn.ELU", 12508, "nn.functional.elu"), + ("nn.LSTMCell", 10885, None), + ("nn.Flatten", 10384, "torch.flatten"), + ("nn.ModuleDict", 10255, None), + ("nn.ReflectionPad2d", 9954, "nn.functional.pad"), + ("nn.MaxPool3d", 9526, "nn.functional.max_pool3d"), + ("nn.MaxPool1d", 9154, "nn.functional.max_pool1d"), + ("nn.RNN", 9154, None), + ("nn.ZeroPad2d", 8847, "nn.functional.pad"), + ("nn.ParameterList", 7702, None), + ("nn.SyncBatchNorm", 6814, None), + ("nn.PixelShuffle", 6571, "nn.functional.pixel_shuffle"), + ("nn.SmoothL1Loss", 6517, "nn.functional.smooth_l1_loss"), + ("nn.Hardswish", 6458, "nn.functional.hardswish"), + ("nn.AdaptiveMaxPool2d", 6071, "nn.functional.adaptive_max_pool2d"), + ("nn.SELU", 6043, "nn.functional.selu"), + ("nn.ConvTranspose3d", 6039, "nn.functional.conv_transpose3d"), + ("nn.GRUCell", 5840, None), + ("nn.ReplicationPad2d", 5600, "nn.functional.pad"), + ("nn.KLDivLoss", 5541, "nn.functional.kl_div"), + ("nn.ConvTranspose1d", 5183, "nn.functional.conv_transpose1d"), + ("nn.Softplus", 5120, "nn.functional.softplus"), + ("nn.SiLU", 4895, "nn.functional.silu"), + ("nn.AvgPool3d", 4523, "nn.functional.avg_pool3d"), + ("nn.CosineSimilarity", 4058, "nn.functional.cosine_similarity"), + ("nn.GELU", 3932, "nn.functional.gelu"), + ("nn.UpsamplingBilinear2d", 3673, "nn.functional.interpolate"), + ("nn.InstanceNorm1d", 3658, "nn.functional.instance_norm"), + ("nn.Transformer", 3604, None), + ("nn.MultiheadAttention", 3435, "nn.functional.multi_head_attention_forward"), + ("nn.AvgPool1d", 3195, "nn.functional.avg_pool1d"), + ("nn.Dropout3d", 2964, "nn.functional.dropout3d"), + ("nn.AdaptiveAvgPool3d", 2915, "nn.functional.adaptive_avg_pool3d"), + ("nn.InstanceNorm3d", 2893, "nn.functional.instance_norm"), + ("nn.Hardtanh", 2613, "nn.functional.hardtanh"), + ("nn.MarginRankingLoss", 2568, "nn.functional.margin_ranking_loss"), + ("nn.GLU", 2526, "nn.functional.glu"), + ("nn.AdaptiveAvgPool1d", 2481, "nn.functional.adaptive_avg_pool1d"), + ("nn.EmbeddingBag", 2344, "nn.functional.embedding_bag"), + ("nn.TransformerEncoderLayer", 2292, None), + ("nn.TransformerEncoder", 2091, None), + ("nn.MaxUnpool2d", 2031, "nn.functional.max_unpool2d"), + ("nn.UpsamplingNearest2d", 2004, "nn.functional.interpolate"), + ("nn.ConstantPad1d", 1904, "nn.functional.pad"), + ("nn.ConstantPad2d", 1791, "nn.functional.pad"), + ("nn.CTCLoss", 1789, "nn.functional.ctc_loss"), + ("nn.AdaptiveMaxPool1d", 1713, "nn.functional.adaptive_max_pool1d"), + ("nn.AdaptiveLogSoftmaxWithLoss", 1665, None), + ("nn.Bilinear", 1664, "nn.functional.bilinear"), + ("nn.RNNCell", 1653, None), + ("nn.MultiLabelSoftMarginLoss", 1624, "nn.functional.multilabel_soft_margin_loss"), + ("nn.Unfold", 1452, "nn.functional.unfold"), + ("nn.RReLU", 1431, "nn.functional.rrelu"), + ("nn.CosineEmbeddingLoss", 1357, "nn.functional.cosine_embedding_loss"), + ("nn.LocalResponseNorm", 1331, "nn.functional.local_response_norm"), + ("nn.Softmax2d", 1300, "nn.functional.softmax"), + ("nn.PairwiseDistance", 1241, "nn.functional.pairwise_distance"), + ("nn.LogSigmoid", 1235, "nn.functional.logsigmoid"), + ("nn.TripletMarginLoss", 1230, "nn.functional.triplet_margin_loss"), + ("nn.RNNBase", 1133, None), + ("nn.Threshold", 1043, "nn.functional.threshold"), + ("nn.AdaptiveMaxPool3d", 1025, "nn.functional.adaptive_max_pool3d"), + ("nn.CELU", 1018, "nn.functional.celu"), + ("nn.NLLLoss2d", 966, "nn.functional.nll_loss"), + ("nn.Softsign", 877, "nn.functional.softsign"), + ("nn.ReplicationPad1d", 862, "nn.functional.pad"), + ("nn.SoftMarginLoss", 856, "nn.functional.soft_margin_loss"), + ("nn.ParameterDict", 742, None), + ("nn.ReflectionPad1d", 731, "nn.functional.pad"), + ("nn.Softshrink", 713, "nn.functional.softshrink"), + ("nn.AlphaDropout", 710, "nn.functional.alpha_dropout"), + ("nn.Tanhshrink", 681, "nn.functional.tanhshrink"), + ("nn.PoissonNLLLoss", 676, "nn.functional.poisson_nll_loss"), + ("nn.MaxUnpool3d", 660, "nn.functional.max_unpool3d"), + ("nn.Fold", 630, "nn.functional.fold"), + ("nn.MultiMarginLoss", 622, "nn.functional.multi_margin_loss"), + ("nn.TransformerDecoderLayer", 614, None), + ("nn.TransformerDecoder", 607, None), + ("nn.Hardshrink", 592, "nn.functional.hardshrink"), + ("nn.ConstantPad3d", 582, "nn.functional.pad"), + ("nn.MultiLabelMarginLoss", 580, "nn.functional.multilabel_margin_loss"), + ("nn.LPPool2d", 550, "nn.functional.lp_pool2d"), + ("nn.Softmin", 537, "nn.functional.softmin"), + ("nn.MaxUnpool1d", 518, "nn.functional.max_unpool1d"), + ("nn.FractionalMaxPool2d", 484, "nn.functional.fractional_max_pool2d"), + ("nn.Hardsigmoid", 477, "nn.functional.hardsigmoid"), + ("nn.ReplicationPad3d", 470, "nn.functional.pad"), + ("nn.HingeEmbeddingLoss", 442, "nn.functional.hinge_embedding_loss"), + ("nn.LPPool1d", 386, "nn.functional.lp_pool1d"), + ("nn.FractionalMaxPool3d", 252, "nn.functional.fractional_max_pool3d"), + ("nn.Container", 217, None), + ("nn.Unflatten", 206, "nn.functional.unflatten"), + ("nn.FeatureAlphaDropout", 136, "nn.functional.feature_alpha_dropout"), + ("nn.TripletMarginWithDistanceLoss", 107, "nn.functional.triplet_margin_with_distance_loss"), + ("nn.ChannelShuffle", 90, "nn.functional.channel_shuffle"), + ("nn.RNNCellBase", 88, None), + ("nn.LazyLinear", 81, "nn.functional.linear"), + ("nn.UninitializedParameter", 60, None), + ("nn.CrossMapLRN2d", 59, None), + ("nn.GaussianNLLLoss", 55, "nn.functional.gaussian_nll_loss"), + ("nn.PixelUnshuffle", 45, "nn.functional.pixel_unshuffle"), + ("nn.Mish", 31, "nn.functional.mish"), + ("nn.ReflectionPad3d", 22, "nn.functional.pad"), + ("nn.HuberLoss", 18, "nn.functional.huber_loss"), + ("nn.LazyConv2d", 15, None), + ("nn.LazyConv1d", 9, None), + ("nn.LazyConv3d", 8, None), + ("nn.LazyConvTranspose1d", 8, None), + ("nn.LazyConvTranspose2d", 8, None), + ("nn.LazyConvTranspose3d", 8, None), + ("nn.LazyBatchNorm1d", 3, None), + ("nn.LazyBatchNorm2d", 3, None), + ("nn.LazyBatchNorm3d", 3, None), + ("nn.UninitializedBuffer", 3, None), +] + +# No rankings because these are a little hard to get rankings for +method_only_ops = [ + 'bfloat16', + 'bool', + 'byte', + 'char', + 'contiguous', + 'cpu', + 'cuda', + 'detach', + 'double', + 'expand', + 'expand_as', + 'float', + 'get_device', + 'half', + 'hardshrink', + 'index_add', + 'index_copy', + 'index_fill', + 'index_put', + 'int', + 'is_contiguous', + 'is_pinned', + 'is_set_to', + 'is_shared', + 'is_signed', + 'item', + 'long', + 'masked_scatter', + 'masked_fill', + 'narrow_copy', + 'numpy', + 'pin_memory', + 'repeat', + 'reshape_as', + 'select', + 'short', + 'storage_offset', + 'sum_to_size', + 'to', + 'to_mkldnn', + 'tolist', + 'type', + 'type_as', + 'unfold', + 'view', + 'view_as', +] + + +def get_nn_functional_top_list(): + top_nn_functional_ = {k: v for k, v in top_nn_functional} + for _, count, functional_name in top_nn_module: + if functional_name is None: + continue + if functional_name == 'torch.flatten': + continue + if functional_name not in top_nn_functional_: + top_nn_functional_[functional_name] = count + else: + top_nn_functional_[functional_name] += count + + top_nn_functional_ = [(k, v) for k, v in top_nn_functional_.items()] + top_nn_functional_.sort(key=lambda x: x[1], reverse=True) + return top_nn_functional_ + + +usage_count = {} +for k, v in get_nn_functional_top_list(): + usage_count[k] = v +for k, v in top_torch: + usage_count[k] = v diff --git a/functorch/functorch/_src/vmap.py b/functorch/functorch/_src/vmap.py new file mode 100644 index 0000000000000..1504107a2ca95 --- /dev/null +++ b/functorch/functorch/_src/vmap.py @@ -0,0 +1,490 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import functools +from collections import OrderedDict +from torch import Tensor +from typing import Any, Callable, Optional, Tuple, Union, List +from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten, TreeSpec, _register_pytree_node +from .pytree_hacks import tree_map_ +from functools import partial + +from functorch._C import ( + _add_batch_dim, + _remove_batch_dim, + _vmap_decrement_nesting, + _vmap_increment_nesting, +) + +in_dims_t = Union[int, Tuple] +out_dims_t = Union[int, Tuple[int, ...]] + + +# Temporary OrderedDict registration as pytree +def _odict_flatten(d): + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values, context): + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +# Checks that all args-to-be-batched have the same batch dim size + +def _validate_and_get_batch_size( + flat_in_dims: List[Optional[int]], + flat_args: List) -> int: + batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None] + if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): + raise ValueError( + f'vmap: Expected all tensors to have the same size in the mapped ' + f'dimension, got sizes {batch_sizes} for the mapped dimension') + return batch_sizes[0] + + +def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: + if isinstance(batched_outputs, tuple): + return len(batched_outputs) + return 1 + +# If value is a tuple, check it has length `num_elements`. +# If value is not a tuple, make a tuple with `value` repeated `num_elements` times + + +def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple: + if not isinstance(value, tuple): + return (value,) * num_elements + if len(value) != num_elements: + raise ValueError(error_message_lambda()) + return value + + +def _process_batched_inputs( + in_dims: in_dims_t, args: Tuple, func: Callable +) -> Tuple[int, List[Any], List[Any], TreeSpec]: + if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'expected `in_dims` to be int or a (potentially nested) tuple ' + f'matching the structure of inputs, got: {type(in_dims)}.') + if len(args) == 0: + raise ValueError( + f'vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add ' + f'inputs, or you are trying to vmap over a function with no inputs. ' + f'The latter is unsupported.') + + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'in_dims is not compatible with the structure of `inputs`. ' + f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs ' + f'has structure {args_spec}.') + + for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but in_dim must be either ' + f'an integer dimension or None.') + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but the input is of type ' + f'{type(arg)}. We cannot vmap over non-Tensor arguments, ' + f'please use None as the respective in_dim') + if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for some input, but that input is a Tensor ' + f'of dimensionality {arg.dim()} so expected in_dim to satisfy ' + f'-{arg.dim()} <= in_dim < {arg.dim()}.') + if in_dim is not None and in_dim < 0: + flat_in_dims[i] = in_dim % arg.dim() + + return _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec + +# Creates BatchedTensors for every Tensor in arg that should be batched. +# Returns the (potentially) batched arguments and the batch_size. + + +def _create_batched_inputs( + flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple: + # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] + batched_inputs = [arg if in_dim is None else + _add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args)] + return tree_unflatten(batched_inputs, args_spec) + +# Undos the batching (and any batch dimensions) associated with the `vmap_level`. + + +def _unwrap_batched( + batched_outputs: Union[Tensor, Tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, batch_size: int, func: Callable) -> Tuple: + flat_batched_outputs, output_spec = tree_flatten(batched_outputs) + + for out in flat_batched_outputs: + if isinstance(out, torch.Tensor): + continue + raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return ' + f'Tensors, got type {type(out)} as a return.') + + def incompatible_error(): + raise ValueError( + f'vmap({_get_name(func)}, ..., out_dims={out_dims})(): ' + f'out_dims is not compatible with the structure of `outputs`. ' + f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs ' + f'has structure {output_spec}.') + + if isinstance(batched_outputs, torch.Tensor): + # Some weird edge case requires us to spell out the following + # see test_out_dims_edge_case + if isinstance(out_dims, int): + flat_out_dims = [out_dims] + elif isinstance(out_dims, tuple) and len(out_dims) == 1: + flat_out_dims = out_dims + out_dims = out_dims[0] + else: + incompatible_error() + else: + flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) + if flat_out_dims is None: + incompatible_error() + + flat_outputs = [ + _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) + for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims) + ] + return tree_unflatten(flat_outputs, output_spec) + + +def _check_int(x, func, out_dims): + if isinstance(x, int): + return + raise ValueError( + f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be ' + f'an int or a python collection of ints representing where in the outputs the ' + f'vmapped dimension should appear.') + + +def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None: + if isinstance(out_dims, int): + return + tree_map_(partial(_check_int, func=func, out_dims=out_dims), out_dims) + + +def _get_name(func: Callable): + if hasattr(func, '__name__'): + return func.__name__ + + # Not all callables have __name__, in fact, only static functions/methods do. + # A callable created via functools.partial or an nn.Module, to name some + # examples, don't have a __name__. + return repr(func) + +# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, +# sends those into func, and then unwraps the output BatchedTensors. Operations +# on BatchedTensors perform the batched operations that the user is asking for. +# +# vmap's randomness behavior differs from JAX's, which would require a PRNG key +# to be passed everywhere. + + +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = 'error') -> Callable: + """ + vmap is the vectorizing map; ``vmap(func)`` returns a new function that + maps :attr:`func` over some dimension of the inputs. Semantically, vmap + pushes the map into PyTorch operations called by :attr:`func`, effectively + vectorizing those operations. + + vmap is useful for handling batch dimensions: one can write a function + :attr:`func` that runs on examples and then lift it to a function that can + take batches of examples with ``vmap(func)``. vmap can also be used to + compute batched gradients when composed with autograd. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. :attr:`in_dims` should have a + structure like the inputs. If the :attr:`in_dim` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If :attr:`out_dims` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + + Returns: + Returns a new "batched" function. It takes the same inputs as + :attr:`func`, except each input has an extra dimension at the index + specified by :attr:`in_dims`. It takes returns the same outputs as + :attr:`func`, except each output has an extra dimension at the index + specified by :attr:`out_dims`. + + .. warning: + :func:`vmap` works best with functional-style code. Please do not + perform any side-effects in :attr:`func`, with the exception of + in-place PyTorch operations. Examples of side-effects include mutating + Python data structures and assigning values to variables not captured + in :attr:`func`. + + One example of using :func:`vmap` is to compute batched dot products. PyTorch + doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully + rummaging through docs, use :func:`vmap` to construct a new function. + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = functorch.vmap(torch.dot) # [N, D], [N, D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) + + :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler + model authoring experience. + + >>> batch_size, feature_size = 3, 5 + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> + >>> def model(feature_vec): + >>> # Very simple linear model with activation + >>> return feature_vec.dot(weights).relu() + >>> + >>> examples = torch.randn(batch_size, feature_size) + >>> result = functorch.vmap(model)(examples) + + :func:`vmap` can also help vectorize computations that were previously difficult + or impossible to batch. One example is higher-order gradient computation. + The PyTorch autograd engine computes vjps (vector-Jacobian products). + Computing a full Jacobian matrix for some function f: R^N -> R^N usually + requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`, + we can vectorize the whole computation, computing the Jacobian in a single + call to ``autograd.grad``. + + >>> # Setup + >>> N = 5 + >>> f = lambda x: x ** 2 + >>> x = torch.randn(N, requires_grad=True) + >>> y = f(x) + >>> I_N = torch.eye(N) + >>> + >>> # Sequential approach + >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] + >>> for v in I_N.unbind()] + >>> jacobian = torch.stack(jacobian_rows) + >>> + >>> # vectorized gradient computation + >>> def get_vjp(v): + >>> return torch.autograd.grad(y, x, v) + >>> jacobian = functorch.vmap(get_vjp)(I_N) + + :func:`vmap` can also be nested, producing an output with multiple batched dimensions + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = functorch.vmap(functorch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) + >>> batched_dot(x, y) # tensor of size [2, 3] + + If the inputs are not batched along the first dimension, :attr:`in_dims` specifies + the dimension that each inputs are batched along as + + >>> torch.dot # [N], [N] -> [] + >>> batched_dot = functorch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension + + If there are multiple inputs each of which is batched along different dimensions, + :attr:`in_dims` must be a tuple with the batch dimension for each input as + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = functorch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None + + If the input is a Python struct, :attr:`in_dims` must be a tuple containing a struct + matching the shape of the input: + + >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> input = {'x': x, 'y': y} + >>> batched_dot = functorch.vmap(f, in_dims=({'x': 0, 'y': None},)) + >>> batched_dot(input) + + By default, the output is batched along the first dimension. However, it can be batched + along any dimension by using :attr:`out_dims` + + >>> f = lambda x: x ** 2 + >>> x = torch.randn(2, 5) + >>> batched_pow = functorch.vmap(f, out_dims=1) + >>> batched_pow(x) # [5, 2] + + For any function that uses kwargs, the returned function will not batch the kwargs but will + accept kwargs + + >>> x = torch.randn([2, 5]) + >>> def f(x, scale=4.): + >>> return x * scale + >>> + >>> batched_pow = functorch.vmap(f) + >>> assert torch.allclose(batched_pow(x), x * 4) + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] + + .. note:: + vmap does not provide general autobatching or handle variable-length + sequences out of the box. + """ + _check_randomness_arg(randomness) + + @functools.wraps(func) + def wrapped(*args, **kwargs): + _check_out_dims_is_int_or_int_pytree(out_dims, func) + batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func) + return _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs + ) + + return wrapped + + +def chunk_vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = 'error', + chunks=2) -> Callable: + """ + chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes + everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of + chunks at a time. For more details about vectorizing map, see :func:`vmap`. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. :attr:`in_dims` should have a + structure like the inputs. If the :attr:`in_dim` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If :attr:`out_dims` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + chunks (int): Number of chunks to use to split the input data. Default is 2. + If equals to 1 then :func:`vmap` is called. + + Returns: + Returns a new "batched" function. It takes the same inputs as + :attr:`func`, except each input has an extra dimension at the index + specified by :attr:`in_dims`. It takes returns the same outputs as + :attr:`func`, except each output has an extra dimension at the index + specified by :attr:`out_dims`. + """ + _check_randomness_arg(randomness) + + if chunks == 1: + return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness) + + def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_): + flat_args_chunks = tuple( + t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t, ] * chunks_ + for t, in_dim in zip(flat_args_, flat_in_dims_) + ) + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + def _flatten_chunks_output(chunks_output_): + # chunks_output is a list of chunked outputs + # flatten chunked outputs: + flat_chunks_output = [] + arg_spec_list = [] + for output in chunks_output_: + flat_output, arg_specs = tree_flatten(output) + flat_chunks_output.append(flat_output) + arg_spec_list.append(arg_specs) + + arg_spec = arg_spec_list[0] # all specs should be the same + # transpose chunk dim and flatten structure + # flat_output_chunks is flat list of chunks + flat_output_chunks = list(zip(*flat_chunks_output)) + return flat_output_chunks, arg_spec + + @functools.wraps(func) + def wrapped_with_chunks(*args, **kwargs): + _check_out_dims_is_int_or_int_pytree(out_dims, func) + _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func) + # Chunk flat arguments + chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks) + + # Apply vmap on chunks + chunks_output = [] + rs = torch.get_rng_state() if randomness == "same" else None + for flat_args in chunks_flat_args: + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) + if rs is not None: + torch.set_rng_state(rs) + chunks_output.append( + _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs + ) + ) + flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output) + # Removing temporary variables helps to reduce memory usage on device like CUDA + del chunks_output + + # concat chunks on out_dim + flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec) + assert len(flat_out_dims) == len(flat_output_chunks) + flat_output = [] + for out_dim in flat_out_dims: + flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim)) + # release source data + del flat_output_chunks[0] + del flat_output_chunks + + # finally unflatten the output + return tree_unflatten(flat_output, arg_spec) + + return wrapped_with_chunks + + +# Vmap refactored helper funcions: +def _check_randomness_arg(randomness): + if randomness not in ['error', 'different', 'same']: + raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}") + + +def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs): + vmap_level = _vmap_increment_nesting(batch_size, randomness) + try: + batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec) + batched_outputs = func(*batched_inputs, **kwargs) + return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) + finally: + _vmap_decrement_nesting() diff --git a/functorch/functorch/compile/__init__.py b/functorch/functorch/compile/__init__.py new file mode 100644 index 0000000000000..712b810ae0e31 --- /dev/null +++ b/functorch/functorch/compile/__init__.py @@ -0,0 +1,30 @@ +from .._src.python_key import pythonkey_decompose +from .._src.decompositions import register_decomposition, decomposition_table, get_decompositions +from .._src.fx_minifier import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess +from .._src.aot_autograd import ( + aot_function, + aot_module, + compiled_function, + compiled_module, + num_of_recompilations, + clear_compile_cache, + aot_module_simplified, +) +from .._src.compilers import ( + ts_compile, + tvm_compile, + draw_graph_compile, + nop, + nnc_jit, + memory_efficient_fusion, + debug_compile, + print_compile, + default_decompositions +) +from .._src.partitioners import ( + min_cut_rematerialization_partition, + default_partition, + draw_graph, + draw_joint_graph, +) +from .._src import config diff --git a/functorch/functorch/csrc/ADInterpreters.cpp b/functorch/functorch/csrc/ADInterpreters.cpp new file mode 100644 index 0000000000000..6a269d7e53946 --- /dev/null +++ b/functorch/functorch/csrc/ADInterpreters.cpp @@ -0,0 +1,192 @@ +#include +#include +#include + +namespace at { namespace functorch { + +static void checkForInvalidMutationOnCaptures( + const c10::OperatorHandle& op, + const torch::jit::Stack* stack, + int64_t cur_level) { + if (!isInplaceOp(op.schema())) { + return; + } + auto args = torch::jit::last(stack, op.schema().arguments().size()); + auto mutated_arg = unwrapIfDead(args[0].toTensor()); + auto* wrapper = maybeGetTensorWrapper(mutated_arg); + if (wrapper && wrapper->level().has_value() && wrapper->level().value() == cur_level) { + return; + } + TORCH_CHECK(false, + "During a grad (vjp, jvp, grad, etc) transform, the function provided ", + "attempted to call in-place operation (", op.schema().operator_name(), ") ", + "that would mutate a captured Tensor. This is not supported; please rewrite ", + "the function being transformed to explicitly accept the mutated Tensor(s) ", + "as inputs."); +} + +static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) { + if (!tensor.defined()) { + return tensor; + } + auto* wrapper = maybeGetTensorWrapper(tensor); + if (!wrapper) { + return makeTensorWrapper(tensor, current_level); + } + TORCH_INTERNAL_ASSERT(wrapper->level().value() <= current_level, "escaped?"); + if (wrapper->level().value() == current_level) { + TORCH_INTERNAL_ASSERT(tensor.defined()); + return tensor; + } + return makeTensorWrapper(tensor, current_level); +} + +static void autogradBasedTransformProcess( + const c10::OperatorHandle& op, + torch::jit::Stack* stack, + int64_t current_level, + TransformType transform_type) { + // if is a grad transform, and the operation is in-place, and the mutated + // argument is not currently wrapped in a TensorWrapper, then we need to + // error out otherwise the result is silently incorrect + checkForInvalidMutationOnCaptures(op, stack, current_level); + + // materialize live GradWrappers + auto maybeTransformGradWrappers = [&](const Tensor& tensor) { + return materializeGradWrappers(tensor, current_level); + }; + auto num_args = op.schema().arguments().size(); + foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), maybeTransformGradWrappers); + + auto exclude = keysToExcludeWhenEnteringDynamicLayer(transform_type); + setup_dispatch_key_tls(exclude, {}); + op.callBoxed(stack); +} + +static void autogradBasedTransformSendToNext( + const c10::OperatorHandle& op, + torch::jit::Stack* stack, + int64_t current_level, + TransformType transform_type, + optional prev_grad_mode, + optional prev_fwd_grad_mode) { + if (transform_type == TransformType::Grad) { + TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value()); + } + if (transform_type == TransformType::Jvp) { + TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value()); + } + auto unwrap = [&](const Tensor& tensor) { + if (!tensor.defined()) { + return tensor; + } + auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor); + if (!maybe_tensor_wrapper) { + return tensor; + } + auto tensor_wrapper_level = maybe_tensor_wrapper->level().value(); + TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= current_level); + if (tensor_wrapper_level == current_level) { + return maybe_tensor_wrapper->value(); + } + return tensor; + }; + auto wrap = [&](const Tensor& tensor) { + if (!tensor.defined()) { + return tensor; + } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "wrap " << current_level << std::endl; + // } + return makeTensorWrapper(tensor, current_level); + }; + + // TODO: we only need to do the following (marked with !) on in-place functions + // that modify sizes or strides. There aren't many of them. + // If autograd dispatch key: + // 1. (!) Put a copy of all of the args onto the stack + // 2. Unwrap all the args in the copy set + // 3. Call the operator + // 4. Wrap the output + // 5. (!) refreshMetadata for all the args in the original set + // 6. (!) Pop those args off. + + // Step 1 & 2 + auto args_size = op.schema().arguments().size(); + // Step 1 + auto front = stack->size() - args_size; + for (const auto arg_idx : c10::irange(0, args_size)) { + stack->push_back((*stack)[front + arg_idx]); + } + // Step 2 + foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap); + + // See NOTE [grad and vjp interaction with no_grad] + optional grad_guard; + if (transform_type == TransformType::Grad && prev_grad_mode.has_value() && *prev_grad_mode == false) { + grad_guard.emplace(*prev_grad_mode); + } + optional fw_grad_guard; + if (transform_type == TransformType::Jvp && + prev_fwd_grad_mode.has_value() && prev_fwd_grad_mode.value() == false) { + fw_grad_guard.emplace(*prev_fwd_grad_mode); + } + + // Re-dispatch + if (getDynamicLayerStack().size() == 0) { + sanityCheckStack(op, stack); + } + op.callBoxed(stack); + + // Step 4, 5, 6 + auto ret_size = op.schema().returns().size(); + // Step 4 + foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap); + + // Step 5 + auto args_front = stack->size() - args_size - ret_size; + for (const auto arg_idx : c10::irange(0, args_size)) { + auto& ivalue = (*stack)[args_front + arg_idx]; + if (!ivalue.isTensor()) { + continue; + } + auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor()); + if (!maybe_tensor_wrapper) { + continue; + } + maybe_tensor_wrapper->refreshMetadata(); + } + + // Step 6 + stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size); +} + +void GradInterpreterPtr::processImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + autogradBasedTransformProcess(op, stack, level(), TransformType::Grad); +} + +void GradInterpreterPtr::sendToNextInterpreterImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + autogradBasedTransformSendToNext( + op, stack, level(), + TransformType::Grad, prevGradMode(), nullopt); +} + +void JvpInterpreterPtr::processImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + autogradBasedTransformProcess(op, stack, level(), TransformType::Jvp); +} + +void JvpInterpreterPtr::sendToNextInterpreterImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + autogradBasedTransformSendToNext( + op, stack, level(), + TransformType::Jvp, nullopt, prevFwdGradMode()); +} + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/ADInterpreters.h b/functorch/functorch/csrc/ADInterpreters.h new file mode 100644 index 0000000000000..6f79afc6144ff --- /dev/null +++ b/functorch/functorch/csrc/ADInterpreters.h @@ -0,0 +1,32 @@ +#pragma once +#include + +namespace at { namespace functorch { + +struct GradInterpreterPtr { + explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + bool prevGradMode() const { + return c10::get(base_->meta()).prevGradMode_; + } + private: + const Interpreter* base_; +}; + +struct JvpInterpreterPtr { + explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + bool prevFwdGradMode() const { + return c10::get(base_->meta()).prevFwdGradMode_; + } + private: + const Interpreter* base_; +}; + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/BatchRulesActivation.cpp b/functorch/functorch/csrc/BatchRulesActivation.cpp new file mode 100644 index 0000000000000..5261558e9b147 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesActivation.cpp @@ -0,0 +1,220 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +// NB: most activation functions fit pointwise unary or binary rules. +// These are only the ones that have special batch rules to help with organization +namespace at { namespace functorch { +std::tuple> +glu_batch_rule(const Tensor& self, optional self_bdim, int64_t dim) { + // repeated error message from glu because 0D -> 1D when batched + // this can't pass anyway because a 0-dimensional tensor has "size" 1, which + // can't be evenly halved, but give a nicer error message here. + TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors"); + + const auto rank = rankWithoutBatchDim(self, self_bdim); + const auto dim_ = maybe_wrap_dim(dim, rank) + 1; + + const auto self_ = moveBatchDimToFront(self, self_bdim); + + const auto res = at::glu(self_, dim_); + return std::make_tuple(res, 0); +} + +std::tuple> glu_backward_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& self, optional self_bdim, int64_t dim) { + if (self_bdim) { + // repeated error message from glu because 0D -> 1D when batched + // this can't pass anyway because a 0-dimensional tensor has "size" 1, which + // can't be evenly halved, but give a nicer error message here. + TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors"); + } + + const auto rank = rankWithoutBatchDim(self, self_bdim); + const auto dim_ = maybe_wrap_dim(dim, rank) + 1; + + const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim); + const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size); + const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size); + + const auto res = at::glu_backward(grad_output_, self_, dim_); + return std::make_tuple(res, 0); +} + +std::tuple> prelu_batch_rule( + const Tensor& input, optional input_bdim, + const Tensor& weight, optional weight_bdim) { + if (!weight_bdim && weight.dim() == 0) { + return std::make_tuple(at::prelu(input, weight), input_bdim); + } + + const auto input_ = moveBatchDimToFront(input, input_bdim); + auto weight_flatten = moveBatchDimToFront(weight, weight_bdim); + + if (weight_flatten.dim() > 1) { + // for an input [N, C, ...] + // weight can be a non-vector but the total number of elements must be the same as C + weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1); + } + + const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim); + VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end()); + const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank; + new_shape.reserve(final_size); + + if (weight_flatten.dim() == 2 || !weight_bdim) { + // if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the + // decomposition, we pad the weight to + + // copies checks from prelu if the weight (without vmap) is not a scalar + TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + if (input_logical_rank > 1) { + const auto channel_dim = input_bdim ? 2 : 1; + channel_size = input_.size(channel_dim); + } + const auto weight_num = weight_flatten.size(-1); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // pads to the left so that the flattened shape matches up with the channel + if (!weight_bdim) { + new_shape.insert(new_shape.begin(), 1); + } else { + new_shape.insert(new_shape.begin() + 1, 1); + } + } + + for (int64_t i = new_shape.size(); i < final_size; i ++) { + new_shape.push_back(1); + } + TORCH_INTERNAL_ASSERT((int64_t)new_shape.size() == final_size); + const auto weight_padded = weight_flatten.view(new_shape); + auto zero_tensor = at::zeros(1, input.options()); + + // decomposes function, + auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_); + return std::make_tuple(res, 0); +} + +VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) { + // helper function that get the size of input, ensuring that there's batch dim, without expanding input + if (has_bdim) { + // sad to have to copy but got garbage if tried to return an IntArrayRef and just do input.sizes() + VmapDimVector new_shape(input.sizes().begin(), input.sizes().end()); + return new_shape; + } + VmapDimVector new_shape(1, batch_size); + new_shape.reserve(input.dim() + 1); + new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end()); + return new_shape; +} + +VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) { + // if need_bdim, will return the input with a guaranteed bdim. If not, will return the input logical size (no batch dim) + if (need_bdim) { + return ensure_shape_with_bdim(input, has_bdim, batch_size); + } else if (has_bdim) { // !need_bdim && has_bdim + VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end()); + return new_shape; + } else { // !need_bdim && !has_bdim + VmapDimVector new_shape(input.sizes().begin(), input.sizes().end()); + return new_shape; + } +} + +std::tuple prelu_backward_batched( + const Tensor& grad_out, const Tensor& self, const Tensor& weight, + const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) { + // helper function that produces a batched gradient for prelu using a decomposition inspired by the AOTAutograd ones + const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out); + const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape); + const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out); + const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape); + const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape); + return std::make_tuple(input_grad, weight_grad); +} + +std::tuple,Tensor,optional> prelu_backward_batch_rule( + const Tensor& grad_out, optional grad_out_bdim, + const Tensor& self, optional self_bdim, + const Tensor& weight, optional weight_bdim) { + const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim); + const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); + const auto self_ = moveBatchDimToFront(self, self_bdim); + const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size); + if (!weight_bdim && weight.dim() == 0) { + VmapDimVector weight_grad_shape(1, batch_size); + VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1); + weight_grad_shape_padded[0] = batch_size; + const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape); + return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0); + } + const auto weight_ = moveBatchDimToFront(weight, weight_bdim); + auto weight_flatten = weight_; + if (weight_flatten.dim() > 1) { + // for an input [N, C, ...] + // weight can be a non-vector but the total number of elements must be the same as C + weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1); + } + + const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim); + VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end()); + const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank; + new_shape.reserve(final_size); + + if (weight_flatten.dim() == 2 || !weight_bdim) { + // if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the + // decomposition, we pad the weight to + + // copies checks from prelu if the weight (without vmap) is not a scalar + TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + if (self_logical_rank > 1) { + channel_size = self_.size(self_bdim.has_value() ? 2 : 1); + } + + const auto weight_num = weight_flatten.size(-1); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // pads to the left so that the flattened shape matches up with the channel + if (!weight_bdim) { + new_shape.insert(new_shape.begin(), 1); + } else { + new_shape.insert(new_shape.begin() + 1, 1); + } + } + + for (int64_t i = new_shape.size(); i < final_size; i ++) { + new_shape.push_back(1); + } + // weight grad does not depend on weight values. It is batched iff grad_out or self are batched + const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value(); + + const auto weight_padded = weight_flatten.view(new_shape); + const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size); + const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size); + + const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape); + return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional(0) : nullopt)); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT(glu_backward, glu_backward_batch_rule); + VMAP_SUPPORT(glu, glu_batch_rule); + VMAP_SUPPORT(prelu, prelu_batch_rule) + VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule) +} +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp new file mode 100644 index 0000000000000..2afd6482cc51e --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp @@ -0,0 +1,503 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace at { namespace functorch { + +static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) { + auto result_type = at::native::result_type(logical_scalar_tensor[0], second); + if (logical_scalar_tensor.scalar_type() != result_type) { + logical_scalar_tensor = logical_scalar_tensor.to(result_type); + } + if (second.scalar_type() != result_type) { + second = second.to(result_type); + } +} + +std::tuple _binary_pointwise_helper( + const Tensor& tensor, optional tensor_batch_dim, + const Tensor& other, optional other_batch_dim) { + // compute max logical rank + auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); + auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim); + auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank); + + auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); + auto other_ = moveBatchDimToFront(other, other_batch_dim); + + // In the (0D, ND) case, type promotion semantics are different :/ + auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value()); + auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value()); + if (tensor_is_logical_scalar && !other_is_logical_scalar) { + handleScalarTypePromotion(tensor_, other_); + } + if (other_is_logical_scalar && !tensor_is_logical_scalar) { + handleScalarTypePromotion(other_, tensor_); + } + + // If the dimensions aren't aligned, we need to line them up. + // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3] + // Note that only tensors that have a batch dim need to be modified. + // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed + tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank); + + return std::make_tuple(tensor_, other_); +} + +template +std::tuple> _binary_pointwise_batch_rule( + const Tensor& tensor, optional tensor_batch_dim, + const Tensor& other, optional other_batch_dim, + ExtraArgs... extra_args) { + + auto tensor_other = _binary_pointwise_helper( + tensor, tensor_batch_dim, other, other_batch_dim); + auto tensor_ = std::get<0>(tensor_other); + auto other_ = std::get<1>(tensor_other); + + auto result = Func(tensor_, other_, std::forward(extra_args)...); + return std::make_tuple(result, 0); +} + +template +struct BinaryPointwiseBatchRuleHelper; + +template +struct BinaryPointwiseBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, optional tensor_batch_dim, + const Tensor& other, optional other_batch_dim, + T... extra_args) { + return _binary_pointwise_batch_rule( + tensor, tensor_batch_dim, other, other_batch_dim, + std::forward(extra_args)...); + } +}; + +#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\ + BinaryPointwiseBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +template +struct BinaryRandomPointwiseBatchRuleHelper; + +template +struct BinaryRandomPointwiseBatchRuleHelper> { + static Tensor apply(const Tensor& tensor, const Tensor& other, T... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + auto cur_level = maybe_layer->layerId(); + RandomnessType randomness = maybe_layer->randomness(); + + Tensor tensor_value; + optional tensor_bdim; + std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(tensor, cur_level); + + Tensor other_value; + optional other_bdim; + std::tie(other_value, other_bdim) = unwrapTensorAtLevel(other, cur_level); + + check_randomness(randomness, (tensor_bdim || other_bdim)); + if (randomness == RandomnessType::Different && !tensor_bdim && !other_bdim) { + auto shape = tensor_value.sizes(); + VmapDimVector shapeVec(1, maybe_layer->batchSize()); + shapeVec.reserve(shape.size() + 1); + shapeVec.insert(shapeVec.end(), shape.begin(), shape.end()); + + // not taken care of with binary batch rule, which assumes at least one input is batched + tensor_value = tensor_value.expand(shapeVec); + tensor_bdim = 0; + } else if (randomness == RandomnessType::Same && !tensor_bdim && !other_bdim) { + + // avoids unnecessary checks and batch rule assuming output is batched + return Func(tensor_value, other_value, std::forward(extra_args)...); + } + auto res = _binary_pointwise_batch_rule( + tensor_value, tensor_bdim, other_value, other_bdim, + std::forward(extra_args)...); + return makeBatched(std::get<0>(res), std::get<1>(res), cur_level); + } +}; + +#define BINARY_RANDOM_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\ + BinaryRandomPointwiseBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +template +void binary_pointwise_inplace_batch_rule( + Tensor& tensor, optional tensor_batch_dim, + const Tensor& other, optional other_batch_dim, + ExtraArgs... extra_args) { + if (!tensor_batch_dim && other_batch_dim) { + vmapIncompatibleInplaceError("inplace arithmetic"); + } + + // compute max logical rank + auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); + auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim); + auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank); + + auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); + auto other_ = moveBatchDimToFront(other, other_batch_dim); + + // If the dimensions aren't aligned, we need to line them up. + // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3] + // Note that only tensors that have a batch dim need to be modified. + // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed + tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank); + + (tensor_.*Meth)(other_, std::forward(extra_args)...); +} + +template +std::tuple> comparison_pointwise_batch_rule( + const Tensor& tensor, optional tensor_batch_dim, + const Tensor& other, optional other_batch_dim) { + // compute max logical rank + auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); + auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim); + auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank); + + auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); + auto other_ = moveBatchDimToFront(other, other_batch_dim); + + // If the dimensions aren't aligned, we need to line them up. + // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3] + // Note that only tensors that have a batch dim need to be modified. + // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed + tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank); + + auto result = Func(tensor_, other_); + return std::make_tuple( std::move(result), 0 ); +} + +std::tuple> where_self_batch_rule( + const Tensor& condition, optional condition_bdim, + const Tensor& self, optional self_bdim, const Tensor& other, optional other_bdim) { + auto condition_logical_rank = rankWithoutBatchDim(condition, condition_bdim); + auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); + auto max_logical_rank = std::max({tensor_logical_rank, other_logical_rank, condition_logical_rank}); + + auto condition_ = moveBatchDimToFront(condition, condition_bdim); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto other_ = moveBatchDimToFront(other, other_bdim); + + condition_ = maybePadToLogicalRank(condition_, condition_bdim, max_logical_rank); + self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank); + return std::make_tuple(at::where(condition_, self_, other_), 0); +} + +std::tuple> gelu_backward_batch_rule( + const Tensor& grad_out, optional grad_out_bdim, const Tensor& input, optional input_bdim, + c10::string_view approximate) { + + // repeat the preprocessing from _binary_pointwise_batch_rule + const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim); + auto grad_out_ = std::get<0>(tensor_other); + auto input_ = std::get<1>(tensor_other); + + // gelu_backward doesn't broadcast well so we need to insist all inputs have a bdim + const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim); + grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size); + + return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0); +} + +std::tuple> masked_select_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& mask, optional mask_bdim) { + TORCH_CHECK(!mask_bdim.has_value(), + "vmap: Attempted to vmap over `mask` in torch.masked_select(self, mask) ", + "We cannot support this because for each batch this would return a ", + "differently shaped Tensor. " + "Please voice your support in https://github.com/pytorch/functorch/issues/256"); + auto self_ = moveBatchDimToFront(self, self_bdim); + const auto batch_size = self_.size(0); + const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + const auto max_logical_rank = std::max(self_logical_rank, mask.dim()); + self_ = maybePadToLogicalRank(self_, 0, max_logical_rank); + + // masked_select returns a 1D tensor, so we have to reshape it into 2D + const auto result = at::masked_select(self_, mask).view({ batch_size, -1 }); + return std::make_tuple(result, 0); +} + +std::tuple> masked_select_backward_batch_rule( + const Tensor& grad, optional grad_bdim, + const Tensor& self, optional self_bdim, + const Tensor& mask, optional mask_bdim) { + TORCH_CHECK(!mask_bdim.has_value(), + "vmap: Attempted to vmap over `mask` in torch.masked_select_backward(grad, self, mask) ", + "We cannot support this because for each batch this would return a ", + "differently shaped Tensor. " + "Please voice your support in https://github.com/pytorch/functorch/issues/256"); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto grad_ = moveBatchDimToFront(grad, grad_bdim); + + const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + const auto max_logical_rank = std::max(self_logical_rank, mask.dim()); + + self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank); + + const auto batch_size = get_bdim_size2(grad, grad_bdim, self, self_bdim); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size); + + const auto result = at::masked_select_backward(grad_, self_.contiguous(), mask); + return std::make_tuple(result, 0); +} + +std::tuple> cdist_backward_batch_rule( + const Tensor& grad, optional grad_bdim, + const Tensor& x1, optional x1_bdim, + const Tensor& x2, optional x2_bdim, + const double p, + const Tensor& cdist, optional cdist_bdim) { + + auto x1_ = x1; + if (cdist_bdim && !x1_bdim) { + // We need to make sure that x1 has batch dim if cdist has one + // otherwise, we get + // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5] + // but expected shape compatible with [4, 5] + auto bs = cdist.size(*cdist_bdim); + x1_ = ensure_has_bdim(x1, false, bs); + x1_ = x1_.contiguous(); + x1_bdim = 0; + } + + // We need to apply the same preprocessing on x1 and x2 as in the forward pass + // _binary_pointwise_batch_rule + auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim); + x1_ = std::get<0>(x12); + auto x2_ = std::get<1>(x12); + + auto grad_ = moveBatchDimToFront(grad, grad_bdim); + if ((x1_bdim || x2_bdim) && !grad_bdim) { + // We need to make sure that grad has batch dim if x1 or x2 have one + // Probably, there is an assumption on the strides. + // Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29 + auto bs = get_bdim_size2(x1_, 0, x2_, 0); + grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs); + grad_ = grad_.contiguous(); + } + + auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist); + + optional out_bdim = nullopt; + if (x1_bdim || x2_bdim) { + out_bdim = 0; + } + + return std::make_tuple(out, out_bdim); +} + +Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional gen) { + return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous +} + +TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { + #define BINARY_RANDOM_POINTWISE(op) \ + m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op))); + #define BINARY_RANDOM_POINTWISE2(op, overload) \ + m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload))); + + BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor); + m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper)); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { +#define BINARY_POINTWISE2(op, overload) \ + VMAP_SUPPORT2(op, overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload))); +#define BINARY_POINTWISE(op) \ + VMAP_SUPPORT(op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op))); +#define UNARY_POINTWISE2(op, overload) \ + VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload))); +#define UNARY_POINTWISE(op) \ + VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op))); +#define UNARY_SCALAR_POINTWISE2(op, overload) \ + VMAP_SUPPORT(op, overload, SCALAR_UNARY_BATCH_RULE(ATEN_FN2(op, overload))); + +#define BINARY_SCALAR_2(op, tensor_tensor, tensor_scalar) \ + BINARY_POINTWISE2(op, tensor_tensor);\ + UNARY_POINTWISE2(op, tensor_scalar); + +// For all 3 combinations of Tensor x Tensor, Tensor x Scalar, Scalar x Tensor +#define BINARY_SCALAR_3(op, tensor_tensor, tensor_scalar, scalar_tensor) \ + BINARY_POINTWISE2(op, tensor_tensor);\ + UNARY_POINTWISE2(op, tensor_scalar);\ + POINTWISE_BOXED(op.scalar_tensor); + +#define BINARY_SCALAR_3_Tensor(op, tensor_scalar, scalar_tensor) \ + BINARY_POINTWISE(op);\ + UNARY_POINTWISE2(op, tensor_scalar);\ + POINTWISE_BOXED(op.scalar_tensor); + + // Batching rule registrations start + POINTWISE_BOXED(__ilshift__.Tensor); + POINTWISE_BOXED(__ilshift__.Scalar); + POINTWISE_BOXED(__irshift__.Tensor) + POINTWISE_BOXED(__irshift__.Scalar) + BINARY_SCALAR_2(__lshift__, Tensor, Scalar); + BINARY_SCALAR_2(__rshift__, Tensor, Scalar); + + BINARY_SCALAR_2(add, Tensor, Scalar); + POINTWISE_BOXED(addcdiv); + POINTWISE_BOXED(addcmul); + BINARY_POINTWISE(atan2); + BINARY_SCALAR_2(bitwise_and, Tensor, Scalar); + BINARY_POINTWISE2(bitwise_or, Tensor); + BINARY_POINTWISE2(bitwise_xor, Tensor); + BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor); + BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor); + + UNARY_POINTWISE(clamp); + POINTWISE_BOXED(clamp.Tensor); + BINARY_POINTWISE2(clamp_min, Tensor); + UNARY_POINTWISE(clamp_min); + POINTWISE_BOXED(clamp_min_); + BINARY_POINTWISE2(clamp_max, Tensor); + UNARY_POINTWISE(clamp_max); + POINTWISE_BOXED(clamp_max_); + + VARIADIC_BDIMS_BOXED(_euclidean_dist); + // Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars, + // but cdist can't work with scalars, at least 2d tensors. + BINARY_POINTWISE(_cdist_forward); + VMAP_SUPPORT(_cdist_backward, cdist_backward_batch_rule); + + // Commented out so we have a test op + // BINARY_SCALAR_2(copysign, Tensor, Scalar); + BINARY_SCALAR_2(div, Tensor, Scalar); + BINARY_SCALAR_2(div, Tensor_mode, Scalar_mode); + + BINARY_POINTWISE(floor_divide); + UNARY_POINTWISE2(floor_divide, Scalar); + + BINARY_POINTWISE(fmax); + BINARY_POINTWISE(fmin); + BINARY_SCALAR_2(fmod, Tensor, Scalar); + POINTWISE_BOXED(frexp.Tensor); + BINARY_POINTWISE(heaviside); + BINARY_POINTWISE(hypot); + BINARY_POINTWISE(gcd); + BINARY_POINTWISE(igamma); + BINARY_POINTWISE(igammac); + BINARY_POINTWISE(logaddexp); + BINARY_POINTWISE(logaddexp2); + POINTWISE_BOXED(lerp.Scalar); + POINTWISE_BOXED(lerp.Tensor); + BINARY_POINTWISE(lcm); + POINTWISE_BOXED(log_sigmoid_forward); + BINARY_POINTWISE(maximum); + BINARY_POINTWISE(minimum); + + BINARY_SCALAR_2(mul, Tensor, Scalar); + BINARY_POINTWISE(nextafter); + BINARY_SCALAR_3(pow, Tensor_Tensor, Tensor_Scalar, Scalar); + POINTWISE_BOXED2(pow_, Scalar); + BINARY_POINTWISE(polar); + POINTWISE_BOXED(polygamma); + BINARY_SCALAR_2(sub, Tensor, Scalar); + BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor); + BINARY_POINTWISE(rrelu_with_noise); + BINARY_SCALAR_2(rsub, Tensor, Scalar); + + BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar); + BINARY_SCALAR_3_Tensor(special_xlogy, other_scalar, self_scalar); + BINARY_SCALAR_3_Tensor(special_zeta, other_scalar, self_scalar); + + VMAP_SUPPORT2(where, self, where_self_batch_rule); + + BINARY_SCALAR_3(xlogy, Tensor, Scalar_Other, Scalar_Self); + + POINTWISE_BOXED(elu_backward); + BINARY_POINTWISE(hardsigmoid_backward); + BINARY_POINTWISE(hardtanh_backward); + BINARY_POINTWISE(hardshrink_backward); + BINARY_POINTWISE(hardswish_backward); + // BINARY_POINTWISE(infinitely_differentiable_gelu_backward); + BINARY_POINTWISE(leaky_relu_backward); + BINARY_POINTWISE(logit_backward); + POINTWISE_BOXED(log_sigmoid_backward); + VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule); + BINARY_POINTWISE(sigmoid_backward); + POINTWISE_BOXED(softplus_backward); + BINARY_POINTWISE(softshrink_backward); + BINARY_POINTWISE(tanh_backward); + BINARY_POINTWISE(threshold_backward); + BINARY_POINTWISE(silu_backward); + + using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const; + using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const; + using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const; + using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const; + using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const; + + POINTWISE_BOXED(add_.Tensor); // just testing + VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); + VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); + VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); + VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); + VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(masked_fill_, Scalar, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT(copy_, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + +#define COMPARISON_POINTWISE(op) \ + VMAP_SUPPORT2(op, Tensor, \ + SINGLE_ARG(comparison_pointwise_batch_rule)); \ + UNARY_POINTWISE2(op, Scalar) + + COMPARISON_POINTWISE(eq); + COMPARISON_POINTWISE(gt); + COMPARISON_POINTWISE(ge); + COMPARISON_POINTWISE(le); + COMPARISON_POINTWISE(lt); + COMPARISON_POINTWISE(ne); + +#undef COMPARISON_POINTWISE +#undef BINARY_POINTWISE2 +#undef BINARY_POINTWISE +#undef UNARY_POINTWISE2 +#undef UNARY_POINTWISE +#undef UNARY_SCALAR_POINTWISE2 +#undef BINARY_SCALAR_3 + +#define LOGICAL_COMPARISON_POINTWISE(op) \ + VMAP_SUPPORT(op, \ + SINGLE_ARG(comparison_pointwise_batch_rule)); \ + VMAP_SUPPORT(op ## _, \ + SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + + LOGICAL_COMPARISON_POINTWISE(logical_and); + LOGICAL_COMPARISON_POINTWISE(logical_or); + LOGICAL_COMPARISON_POINTWISE(logical_xor); + +#undef SINGLE_ARG +#undef LOGICAL_COMPARISON_POINTWISE + VMAP_SUPPORT(masked_select, masked_select_batch_rule); + VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule) +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesConvolution.cpp b/functorch/functorch/csrc/BatchRulesConvolution.cpp new file mode 100644 index 0000000000000..8382070283cdc --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesConvolution.cpp @@ -0,0 +1,513 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { namespace functorch { + +// convolution_batch_rule translated from jax with modifications: +// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143 + +// PyTorch's convolution is different from JAX's conv_general_dilated: +// we do not support batch_group_count (which is needed for convolution backwards). +// Instead, there's a convolution_backward op that needs a batching rule. +std::tuple> +convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tensor& rhs, optional rhs_bdim, const optional& bias, optional bias_bdim, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) { + DimVector lhs_spec(stride.size() + 2); + std::iota(lhs_spec.begin(), lhs_spec.end(), 0); + DimVector rhs_spec = lhs_spec; + DimVector out_spec = lhs_spec; + if (transposed) { + rhs_spec[0] = 1; + rhs_spec[1] = 0; + } + + // If we have a batched bias or weight, we need to perform the computation separately. + optional unbatched_bias; + bool separate_bias; + if ((rhs_bdim && bias && bias->defined()) || bias_bdim) { + TORCH_INTERNAL_ASSERT(bias.has_value()); + TORCH_INTERNAL_ASSERT(bias->defined()); + unbatched_bias = nullopt; + separate_bias = true; + } else { + unbatched_bias = bias; + separate_bias = false; + } + std::tuple> result; + if (lhs_bdim && !rhs_bdim) { + auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs); + auto out = at::convolution(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + out = reshape_dim_outof(out_spec[0], lhs.sizes()[*lhs_bdim], out); + result = std::make_tuple(out, out_spec[0]); + } else if (!lhs_bdim && rhs_bdim) { + if (groups == 1) { + auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs); + auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + out = reshape_dim_outof(out_spec[1], rhs.sizes()[*rhs_bdim], out); + result = std::make_tuple(out, out_spec[1]); + } else { + auto dim_with_groups = transposed ? 1 : 0; + auto new_w = reshape_dim_outof(rhs_spec[dim_with_groups] + (*rhs_bdim <= rhs_spec[0]), groups, rhs); + new_w = reshape_dim_into(*rhs_bdim + (rhs_spec[0] < rhs_bdim), rhs_spec[0] + 1, new_w); + new_w = reshape_dim_into(rhs_spec[0], rhs_spec[0], new_w); + auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + out = reshape_dim_outof(out_spec[1], groups, out); + out = reshape_dim_outof(out_spec[1] + 1, rhs.sizes()[*rhs_bdim], out); + out = reshape_dim_into(out_spec[1], out_spec[1] + 1, out); + result = std::make_tuple(out, out_spec[1]); + } + } else if (lhs_bdim && rhs_bdim) { + auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs); + groups *= lhs.sizes()[*lhs_bdim]; + auto dim_with_groups = transposed ? 1 : 0; + auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs); + auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out); + result = std::make_tuple(out, out_spec[1]); + } else { + result = std::make_tuple(at::convolution(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt); + } + if (separate_bias) { + auto A = std::get<0>(result); + auto A_batch_dim = std::get<1>(result); + auto B = *bias; + auto B_batch_dim = bias_bdim; + A = moveBatchDimToFront(A, A_batch_dim); + B = moveBatchDimToFront(B, B_batch_dim); + for (size_t i = 0; i < out_spec.size() - 2; i++) { + B = B.unsqueeze(-1); + } + B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim)); + + return std::make_tuple(at::add(A, B), 0); + } else { + return result; + } +} + +Tensor _convolution_decomp( + const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_r_opt, + IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, IntArrayRef output_padding_, int64_t groups_, + bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + // Ignore everything. If the user called this in the normal way, + // then they should be fine. + (void) benchmark; + (void) deterministic; + (void) cudnn_enabled; + (void) allow_tf32; + return at::convolution( + input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_); +} + +// TODO: delete the following after confirming performance +// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) { +// if (bdim == 0) { +// return value.size(1) == 1; +// } +// return value.size(0) == 1; +// } +// +// std::tuple cudnn_conv_per_sample_grad_rule( +// const Tensor& self, optional self_bdim, +// const Tensor& grad_output, optional grad_output_bdim, +// const Tensor& weight, optional weight_bdim, +// IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, +// bool deterministic, bool allow_tf32, std::array output_mask) { +// TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim); +// // TODO: No clue if this works if the first non-batch dim isn't size 1 +// TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim)); +// TORCH_INTERNAL_ASSERT(self.dim() == 5); +// +// auto bdim_size = self.size(*self_bdim); +// auto self_ = reshape_dim_into(*self_bdim, 0, self); +// auto in_channels = self_.size(1); +// auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); +// +// auto grad_self = at::cudnn_convolution_backward_input( +// self_.sizes(), grad_output_, weight, +// padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +// grad_self = reshape_dim_outof(0, bdim_size, grad_self); +// +// // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py +// auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride); +// auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)}); +// auto grad_sample = at::einsum("noq,npq->nop", {B, A}); +// grad_sample = grad_sample.view({ +// bdim_size, groups, -1, groups, in_channels / groups, +// weight.size(2) * weight.size(3) }); +// grad_sample = at::einsum("ngrg...->ngr...", {grad_sample}); +// grad_sample = grad_sample.reshape( +// {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)}); +// +// return std::make_tuple(grad_self, 0, grad_sample, 0); +// } +// +// std::tuple cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { +// auto maybe_layer = maybeCurrentDynamicLayer(); +// TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); +// int64_t cur_level = maybe_layer->layerId(); +// +// Tensor self_value; +// optional self_bdim; +// std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); +// Tensor grad_output_value; +// optional grad_output_bdim; +// std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level); +// Tensor weight_value; +// optional weight_bdim; +// std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level); +// +// if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) { +// c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); +// auto result = cudnn_conv_per_sample_grad_rule( +// self_value, self_bdim, +// grad_output_value, grad_output_bdim, +// weight_value, weight_bdim, +// padding, stride, dilation, groups, +// benchmark, deterministic, allow_tf32, output_mask); +// return std::make_tuple( +// makeBatched(std::get<0>(result), std::get<1>(result), cur_level), +// makeBatched(std::get<2>(result), std::get<3>(result), cur_level)); +// } +// +// static auto op = c10::Dispatcher::singleton() +// .findSchemaOrThrow("aten::cudnn_convolution_backward", ""); +// return slow_fallback(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask }); + +static Tensor compute_grad_bias( + const Tensor& grad_output_, std::array output_mask) { + if (!output_mask[2]) { + return Tensor(); + } + DimVector reduce_dims; + reduce_dims.resize(grad_output_.dim() - 1); + reduce_dims[0] = 0; + std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2); + return grad_output_.sum(reduce_dims); +} + +// reshapes the batch_size into dim +Tensor make_dummy( + const Tensor& tensor, optional tensor_bdim, + int64_t dim, int64_t batch_size) { + auto tensor_ = tensor_bdim ? tensor.select(*tensor_bdim, 0) : tensor; + auto orig_size = tensor_.size(dim); + tensor_ = tensor_.slice(dim, 0, 1); + + DimVector expand_shape(tensor_.sizes().begin(), tensor_.sizes().end()); + expand_shape[dim] = batch_size * orig_size; + + return tensor_.new_empty({}).expand(expand_shape); +} + +std::tuple> +convolution_backward_input_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& input, optional input_bdim, + const Tensor& weight, optional weight_bdim, + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, + IntArrayRef output_padding, int64_t groups) { + const std::array mask = {true, false, false}; + if (grad_output_bdim && weight_bdim) { + // regular: BNO, BOI -> N(BO), (BO)I -> N(BI) + // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI) + const auto batch_size = weight.size(*weight_bdim); + const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); + const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight); + auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output_, dummy_input, weight_, nullopt, stride, padding, + dilation, transposed, output_padding, groups * batch_size, mask); + const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result)); + return std::make_tuple(grad_input, 1); + } else if (grad_output_bdim && !weight_bdim) { + // BNO, OI -> (BN)O, OI -> (BN)I + // transposed is the same. + const auto batch_size = grad_output.size(*grad_output_bdim); + const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); + auto dummy_input = make_dummy(input, input_bdim, 0, batch_size); + const auto result = at::convolution_backward( + grad_output_, dummy_input, weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result)); + return std::make_tuple(grad_input, 0); + } else if (!grad_output_bdim && weight_bdim) { + const auto batch_size = weight.size(*weight_bdim); + if (groups == 1) { + // regular: NO, BOI -> NO, O(BI) -> N(BI) + // transposed: NO, BIO -> NO, (BI)O -> N(BI) + const auto in_ch_dim = transposed ? 0 : 1; + const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight); + auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output, dummy_input, weight_, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result)); + return std::make_tuple(grad_input, 1); + } + Tensor grad_input; + if (!transposed) { + // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI) + const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight); + auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output, dummy_input, weight_, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + grad_input = std::get<0>(result); // N(GBI) + } else { + // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI) + auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O + weight_ = reshape_dim_outof(1, groups, weight_); // BGIO + weight_ = weight_.transpose(0, 1); // GBIO + weight_ = weight_.flatten(0, 2); // (GBI)O + const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output, dummy_input, weight_, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + grad_input = std::get<0>(result); // N(GBI) + } + // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI) + grad_input = reshape_dim_outof(1, groups, grad_input); + grad_input = reshape_dim_outof(2, batch_size, grad_input); + grad_input = grad_input.transpose(1, 2); + grad_input = reshape_dim_into(2, 2, grad_input); + return std::make_tuple(grad_input, 1); + } else { + TORCH_INTERNAL_ASSERT(input_bdim); + const auto dummy_input = make_dummy(input, input_bdim, 0, 1); + const auto result = at::convolution_backward( + grad_output, dummy_input, weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + return std::make_tuple(std::get<0>(result), nullopt); + } +} +std::tuple> +convolution_backward_weight_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& input, optional input_bdim, + const Tensor& weight, optional weight_bdim, + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, + IntArrayRef output_padding, int64_t groups) { + const std::array mask = {false, true, false}; + if (grad_output_bdim && input_bdim) { + // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed) + const auto batch_size = input.size(*input_bdim); + const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); + const auto input_ = reshape_dim_into(*input_bdim, 1, input); + const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); + const auto result = at::convolution_backward( + grad_output_, input_, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups * batch_size, mask); + auto grad_weight = std::get<1>(result); + grad_weight = reshape_dim_outof(0, batch_size, grad_weight); + return std::make_tuple(grad_weight, 0); + } else if (grad_output_bdim && !input_bdim) { + const auto batch_size = grad_output.size(*grad_output_bdim); + if (groups == 1) { + // regular: BNO, NI -> N(BO), NI -> (BO)I + // transposed: BNO, NI -> N(BO), NI -> I(BO) + const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); + const auto out_ch_dim = transposed ? 1 : 0; + const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size); + const auto result = at::convolution_backward( + grad_output_, input, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = reshape_dim_outof(out_ch_dim, batch_size, grad_weight); + return std::make_tuple(grad_weight, out_ch_dim); + } else { + auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO) + grad_output_ = reshape_dim_outof(2, groups, grad_output_); // BNGO + grad_output_ = grad_output_.movedim(0, 2); // NGBO + grad_output_ = grad_output_.flatten(1, 3); // N(GBO) + if (!transposed) { + // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I + const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); + const auto result = at::convolution_backward( + grad_output_, input, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBOI + grad_weight = grad_weight.transpose(0, 1); // BGOI + grad_weight = grad_weight.flatten(1, 2); // B(GO)I + return std::make_tuple(grad_weight, 0); + } else { + // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO) + const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output_, input, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = reshape_dim_outof(1, batch_size, grad_weight); + return std::make_tuple(grad_weight, 1); + } + } + } else if (!grad_output_bdim && input_bdim) { + const auto batch_size = input.size(*input_bdim); + if (groups == 1) { + // regular: NO, BNI -> NO, N(BI) -> O(BI) + // transposed: NO, BNI -> NO, N(BI) -> (BI)O + const auto input_ = reshape_dim_into(*input_bdim, 1, input); + const auto in_ch_dim = transposed ? 0 : 1; + const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size); + const auto result = at::convolution_backward( + grad_output, input_, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = reshape_dim_outof(in_ch_dim, batch_size, grad_weight); + return std::make_tuple(grad_weight, in_ch_dim); + } else { + auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI) + input_ = reshape_dim_outof(2, groups, input_); // BNGI + input_ = input_.movedim(0, 2); // NGBI + input_ = input_.flatten(1, 3); // N(GBI) + if (!transposed) { + // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI) + const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size); + const auto result = at::convolution_backward( + grad_output, input_, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = reshape_dim_outof(1, batch_size, grad_weight); + return std::make_tuple(grad_weight, 1); + } else { + // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O + const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); + const auto result = at::convolution_backward( + grad_output, input_, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + auto grad_weight = std::get<1>(result); + grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBIO + grad_weight = grad_weight.transpose(0, 1); // BGIO + grad_weight = grad_weight.flatten(1, 2); // B(GI)O + return std::make_tuple(grad_weight, 0); + } + } + } else { + TORCH_INTERNAL_ASSERT(weight_bdim); + const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1); + const auto result = at::convolution_backward( + grad_output, input, dummy_weight, nullopt, stride, padding, + dilation, transposed, output_padding, groups, mask); + return std::make_tuple(std::get<1>(result), nullopt); + + } +} + +std::tuple convolution_backward_plumbing( + const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_, + const c10::OptionalArrayRef bias_sizes_opt, + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, + IntArrayRef output_padding, int64_t groups, std::array output_mask) { + const auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level)){ + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::convolution_backward( + grad_output_, input_, weight_, bias_sizes_opt, stride, padding, + dilation, transposed, output_padding, groups, output_mask); + } + + Tensor grad_output; + optional grad_output_bdim; + std::tie(grad_output, grad_output_bdim) = unwrapTensorAtLevel(grad_output_, cur_level); + Tensor input; + optional input_bdim; + std::tie(input, input_bdim) = unwrapTensorAtLevel(input_, cur_level); + Tensor weight; + optional weight_bdim; + std::tie(weight, weight_bdim) = unwrapTensorAtLevel(weight_, cur_level); + + const auto grad_bias = compute_grad_bias(grad_output_, output_mask); + output_mask[2] = false; + + // TODO: A little bird says that unfold + matmul is actually faster than + // group convolution in many cases. We should benchmark some of + // the common cases and replace things with unfold + matmul as necessary. + + // Notation: + // B - a batch dimension + // G - groups (sometimes omitted because it doesn't matter) + // NO - grad_output + // NI - input + // OI - weight + // "(BO)I" - we don't actually care about the values of this Tensor, + // we just need to create a tensor on the same device with the + // correct shape and pray that the implementation is smart enough + // to not do anything with it. + + // BNO, BNI, BOI + // AKA one of the model ensembling case + if (grad_output_bdim && input_bdim && weight_bdim) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output); + + // BNO, BNI, BOI -> N(BO), N(BI), (BO)I + const auto batch_size = weight.size(*weight_bdim); + input = reshape_dim_into(*input_bdim, 1, input); + weight = reshape_dim_into(*weight_bdim, 0, weight); + const auto result = at::convolution_backward( + grad_output, input, weight, nullopt, stride, padding, dilation, + transposed, output_padding, batch_size * groups, output_mask); + // N(BI), (BO)I -> NBI, BOI + const auto grad_input = output_mask[0] ? + reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor(); + const auto grad_weight = output_mask[1] ? + reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor(); + return std::make_tuple( + output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input, + output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight, + grad_bias); + } + + Tensor grad_input; + if (output_mask[0]) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto result = convolution_backward_input_batch_rule( + grad_output, grad_output_bdim, + input, input_bdim, + weight, weight_bdim, + stride, padding, dilation, transposed, output_padding, groups); + grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level); + } + + Tensor grad_weight; + if (output_mask[1]) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto result = convolution_backward_weight_batch_rule( + grad_output, grad_output_bdim, + input, input_bdim, + weight, weight_bdim, + stride, padding, dilation, transposed, output_padding, groups); + grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level); + } + return std::make_tuple(grad_input, grad_weight, grad_bias); + + // Someone's definitely going to find a problem with this batching rule so + // I'm leaving the following fallback if we need it back. + // static auto op = c10::Dispatcher::singleton() + // .findSchemaOrThrow("aten::convolution_backward", ""); + // auto result = slow_fallback(op, { + // grad_output_, input_, weight_, bias_sizes_opt, + // stride, padding, dilation, transposed, output_padding, groups, output_mask + // }); + // return std::make_tuple(grad_input, std::get<1>(result), grad_bias); +} + + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT(convolution, convolution_batch_rule); + m.impl("_convolution", _convolution_decomp); + m.impl("convolution_backward", convolution_backward_plumbing); +} + +}} // namespace at;:functorch diff --git a/functorch/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/functorch/csrc/BatchRulesDecompositions.cpp new file mode 100644 index 0000000000000..3256847121eed --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesDecompositions.cpp @@ -0,0 +1,261 @@ + +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace functorch { + +#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); +#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); + +TORCH_LIBRARY_IMPL(aten, FT_VMAP_MODE_KEY, m) { + OP_DECOMPOSE(alpha_dropout_); + OP_DECOMPOSE(dropout_); + OP_DECOMPOSE(feature_alpha_dropout_); + OP_DECOMPOSE(feature_dropout_); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + OP_DECOMPOSE2(__and__, Scalar); + OP_DECOMPOSE2(__and__, Tensor); + OP_DECOMPOSE2(__iand__, Tensor); + OP_DECOMPOSE2(__iand__, Scalar); + OP_DECOMPOSE2(__ior__, Tensor); + OP_DECOMPOSE2(__ior__, Scalar); + OP_DECOMPOSE2(__ixor__, Tensor); + OP_DECOMPOSE2(__ixor__, Scalar); + OP_DECOMPOSE2(__or__, Tensor); + OP_DECOMPOSE2(__or__, Scalar); + OP_DECOMPOSE2(__xor__, Tensor); + OP_DECOMPOSE2(__xor__, Scalar); + OP_DECOMPOSE(_batch_norm_impl_index); + OP_DECOMPOSE(absolute); + OP_DECOMPOSE(arctan2); + OP_DECOMPOSE(avg_pool1d); + OP_DECOMPOSE(adaptive_max_pool1d); + OP_DECOMPOSE(adaptive_avg_pool1d); + OP_DECOMPOSE(adaptive_avg_pool2d); + OP_DECOMPOSE(adaptive_avg_pool3d); + OP_DECOMPOSE(adjoint); + OP_DECOMPOSE(arccos); + OP_DECOMPOSE(arccosh); + OP_DECOMPOSE(arcsin); + OP_DECOMPOSE(arcsinh); + OP_DECOMPOSE(arctan); + OP_DECOMPOSE(arctanh); + OP_DECOMPOSE(atleast_1d); + OP_DECOMPOSE2(atleast_1d, Sequence); + OP_DECOMPOSE(atleast_2d); + OP_DECOMPOSE2(atleast_2d, Sequence); + OP_DECOMPOSE(atleast_3d); + OP_DECOMPOSE2(atleast_3d, Sequence); + OP_DECOMPOSE(batch_norm); + OP_DECOMPOSE2(bitwise_or, Scalar); + OP_DECOMPOSE2(bitwise_xor, Scalar); + OP_DECOMPOSE(broadcast_tensors); + OP_DECOMPOSE(broadcast_to); + OP_DECOMPOSE(cartesian_prod); + OP_DECOMPOSE(cdist); + OP_DECOMPOSE(clip); + OP_DECOMPOSE2(clip, Tensor ); + OP_DECOMPOSE(concat); + OP_DECOMPOSE(conj_physical); + OP_DECOMPOSE(combinations); + OP_DECOMPOSE(corrcoef); + OP_DECOMPOSE(cosine_embedding_loss); + OP_DECOMPOSE(cosine_similarity); + OP_DECOMPOSE(cov); + OP_DECOMPOSE(cross_entropy_loss); + OP_DECOMPOSE2(cumulative_trapezoid, x); + OP_DECOMPOSE2(cumulative_trapezoid, dx); + OP_DECOMPOSE2(dsplit, int); + OP_DECOMPOSE2(dsplit, array); + OP_DECOMPOSE(det); + OP_DECOMPOSE(diag_backward); + OP_DECOMPOSE(diff); + OP_DECOMPOSE(dstack); + OP_DECOMPOSE(einsum); + OP_DECOMPOSE(embedding_backward); + OP_DECOMPOSE(expand_as); + OP_DECOMPOSE(fft_fft); + OP_DECOMPOSE(fft_fftshift); + OP_DECOMPOSE(fft_fft2); + OP_DECOMPOSE(fft_fftn); + OP_DECOMPOSE(fft_hfft); + OP_DECOMPOSE(fft_hfft2); + OP_DECOMPOSE(fft_hfftn); + OP_DECOMPOSE(fft_ifft); + OP_DECOMPOSE(fft_ifftshift); + OP_DECOMPOSE(fft_ifft2); + OP_DECOMPOSE(fft_ifftn); + OP_DECOMPOSE(fft_ihfft); + OP_DECOMPOSE(fft_irfft); + OP_DECOMPOSE(fft_irfft2); + OP_DECOMPOSE(fft_irfftn); + OP_DECOMPOSE(fft_rfft); + OP_DECOMPOSE(fft_rfft2); + OP_DECOMPOSE(fft_rfftn); + OP_DECOMPOSE(fix); + OP_DECOMPOSE(fliplr); + OP_DECOMPOSE(flipud); + OP_DECOMPOSE2(float_power, Tensor_Tensor); + OP_DECOMPOSE2(float_power, Tensor_Scalar); + OP_DECOMPOSE(ger); + OP_DECOMPOSE2(gradient, scalarint); + OP_DECOMPOSE2(gradient, scalararray); + OP_DECOMPOSE2(gradient, array); + OP_DECOMPOSE2(gradient, scalarrayint); + OP_DECOMPOSE2(gradient, scalarrayarray); + OP_DECOMPOSE2(gradient, tensorarrayint); + OP_DECOMPOSE2(gradient, tensorarray); + OP_DECOMPOSE2(greater_equal, Tensor ); + OP_DECOMPOSE2(greater, Tensor ); + OP_DECOMPOSE(grid_sampler); + OP_DECOMPOSE(group_norm); + OP_DECOMPOSE(hinge_embedding_loss); + OP_DECOMPOSE2(hsplit, int); + OP_DECOMPOSE2(hsplit, array); + OP_DECOMPOSE(hstack); + OP_DECOMPOSE(index_select_backward); + OP_DECOMPOSE(inner); + OP_DECOMPOSE(instance_norm); + OP_DECOMPOSE(kron); + OP_DECOMPOSE(l1_loss); + OP_DECOMPOSE(layer_norm); + OP_DECOMPOSE2(ldexp, Tensor); + OP_DECOMPOSE2(less_equal, Tensor ); + OP_DECOMPOSE2(less, Tensor ); + OP_DECOMPOSE(linalg_cond); + OP_DECOMPOSE(linalg_det); + OP_DECOMPOSE(linalg_matmul); + OP_DECOMPOSE(linalg_multi_dot); + OP_DECOMPOSE(linalg_svd); + OP_DECOMPOSE(linalg_svdvals); + OP_DECOMPOSE(matmul); + OP_DECOMPOSE(matrix_H); + OP_DECOMPOSE2(max, other ); + OP_DECOMPOSE(max_pool1d_with_indices); + OP_DECOMPOSE(max_pool2d); + OP_DECOMPOSE(meshgrid); + OP_DECOMPOSE2(meshgrid, indexing); + OP_DECOMPOSE(mH); + OP_DECOMPOSE2(min, other ); + OP_DECOMPOSE2(moveaxis, intlist); + OP_DECOMPOSE2(movedim, int); + OP_DECOMPOSE(msort); + OP_DECOMPOSE(mT); + OP_DECOMPOSE(narrow); + OP_DECOMPOSE(negative); + OP_DECOMPOSE(nll_loss_nd); + OP_DECOMPOSE(nll_loss); + OP_DECOMPOSE(nll_loss2d); + OP_DECOMPOSE2(not_equal, Tensor ); + OP_DECOMPOSE(outer); + OP_DECOMPOSE(pairwise_distance); + OP_DECOMPOSE(poisson_nll_loss); + OP_DECOMPOSE(qr); + OP_DECOMPOSE(ravel); + OP_DECOMPOSE2(repeat_interleave, self_int); + OP_DECOMPOSE2(repeat_interleave, self_Tensor); + OP_DECOMPOSE(reshape); + OP_DECOMPOSE(resolve_conj); + OP_DECOMPOSE(resolve_neg); + OP_DECOMPOSE(row_stack); + OP_DECOMPOSE(rrelu); + OP_DECOMPOSE2(softmax, int); + OP_DECOMPOSE(special_gammainc); + OP_DECOMPOSE(special_gammaincc); + OP_DECOMPOSE(special_logit); + OP_DECOMPOSE(special_log_softmax); + OP_DECOMPOSE(special_logsumexp); + OP_DECOMPOSE(special_multigammaln); + OP_DECOMPOSE(special_polygamma); + OP_DECOMPOSE(special_softmax); + OP_DECOMPOSE2(split, sizes); + OP_DECOMPOSE(square); + OP_DECOMPOSE(numpy_T); + OP_DECOMPOSE(reshape_as); + OP_DECOMPOSE(t); + OP_DECOMPOSE2(result_type, Tensor); + OP_DECOMPOSE2(result_type, Scalar); + OP_DECOMPOSE2(result_type, Scalar_Tensor); + OP_DECOMPOSE2(result_type, Scalar_Scalar); + OP_DECOMPOSE(is_same_size); + OP_DECOMPOSE(view_as); + OP_DECOMPOSE2(size, int); + OP_DECOMPOSE(is_complex); + OP_DECOMPOSE(std); + OP_DECOMPOSE2(std, dim); + OP_DECOMPOSE(std_mean); + OP_DECOMPOSE2(std_mean, dim); + OP_DECOMPOSE(swapaxes); + OP_DECOMPOSE2(subtract, Tensor); + OP_DECOMPOSE(sum_to_size); + OP_DECOMPOSE(svd); + OP_DECOMPOSE(swapdims); + OP_DECOMPOSE(take_along_dim); + OP_DECOMPOSE(tensordot); + OP_DECOMPOSE(tile); + OP_DECOMPOSE2(trapezoid, x); + OP_DECOMPOSE2(trapezoid, dx); + OP_DECOMPOSE2(trapz, x); + OP_DECOMPOSE2(trapz, dx); + OP_DECOMPOSE(var); + OP_DECOMPOSE2(var, dim); + OP_DECOMPOSE(var_mean); + OP_DECOMPOSE2(var_mean, dim); + OP_DECOMPOSE2(vsplit, int); + OP_DECOMPOSE2(vsplit, array); + OP_DECOMPOSE(vstack); + OP_DECOMPOSE(orgqr); + OP_DECOMPOSE2(unflatten, int); + OP_DECOMPOSE(_convolution_double_backward); + OP_DECOMPOSE(conv_transpose1d); + OP_DECOMPOSE2(conv_transpose2d, input); + OP_DECOMPOSE2(conv_transpose3d, input); + OP_DECOMPOSE(conv1d); + OP_DECOMPOSE(conv2d); + OP_DECOMPOSE(conv3d); + OP_DECOMPOSE2(conv1d, padding); + OP_DECOMPOSE2(conv2d, padding); + OP_DECOMPOSE2(conv3d, padding); + OP_DECOMPOSE(_convolution_mode); + OP_DECOMPOSE(frobenius_norm); + OP_DECOMPOSE(type_as); + OP_DECOMPOSE(linalg_diagonal); + OP_DECOMPOSE(pad); + OP_DECOMPOSE(_pad_circular); + + // divide, alias for div + OP_DECOMPOSE2(divide, Tensor); + OP_DECOMPOSE2(divide_, Tensor); + OP_DECOMPOSE2(divide, Scalar); + OP_DECOMPOSE2(divide, Tensor_mode); + OP_DECOMPOSE2(divide_, Tensor_mode); + OP_DECOMPOSE2(divide, Scalar_mode); + OP_DECOMPOSE2(divide_, Scalar_mode); + + // divide, alias for div + OP_DECOMPOSE2(true_divide, Tensor); + OP_DECOMPOSE2(true_divide_, Tensor); + OP_DECOMPOSE2(true_divide, Scalar); + OP_DECOMPOSE2(true_divide_, Scalar); + + // multiply, alias for mul + OP_DECOMPOSE2(multiply, Tensor) + OP_DECOMPOSE2(multiply_, Tensor) + OP_DECOMPOSE2(multiply, Scalar) + OP_DECOMPOSE2(multiply_, Scalar) +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesDynamic.cpp b/functorch/functorch/csrc/BatchRulesDynamic.cpp new file mode 100644 index 0000000000000..e752d96d168da --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesDynamic.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + + +namespace at { namespace functorch { + +void unsupportedDynamicOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, "vmap: We do not support batching operators that can output dynamic shape. ", + "Attempted to vmap over ", op.schema().operator_name(), ". ", + "Please voice your support in https://github.com/pytorch/functorch/issues/256"); +} +#define UNSUPPORTED_DYNAMIC(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedDynamicOp>()); + +// NB: item and is_nonzero can decompose to this... +void unsupportedLocalScalarDense(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, + "vmap: It looks like you're either (1) calling .item() on a Tensor or ", + "(2) attempting to use a Tensor in some data-dependent control flow or ", + "(3) encountering this error in PyTorch internals. ", + "For (1): we don't support vmap over calling .item() on a Tensor, please try to ", + "rewrite what you're doing with other operations. ", + "For (2): If you're doing some ", + "control flow instead, we don't support that yet, please shout over at ", + "https://github.com/pytorch/functorch/issues/257 . ", + "For (3): please file an issue."); +} + +void unsupportedItem(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, + "vmap: It looks like you're calling .item() on a Tensor. ", + "We don't support vmap over calling .item() on a Tensor, please try to ", + "rewrite what you're doing with other operations. If error is occurring ", + "somewhere inside PyTorch internals, please file a bug report."); +} + +void unsupportedIsNonzero(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, + "vmap: It looks like you're attempting to use a Tensor in some ", + "data-dependent control flow. ", + "We don't support that yet, please shout over at ", + "https://github.com/pytorch/functorch/issues/257 ."); +} + +void unsupportedAllclose(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, + "vmap over torch.allclose isn't supported yet. Please voice your ", + "support over at github.com/pytorch/functorch/issues/275"); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + UNSUPPORTED_DYNAMIC(nonzero); + UNSUPPORTED_DYNAMIC(where); + UNSUPPORTED_DYNAMIC(unique); + m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&unsupportedLocalScalarDense>()); + m.impl("item", torch::CppFunction::makeFromBoxedFunction<&unsupportedItem>()); + m.impl("is_nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedIsNonzero>()); + m.impl("allclose", torch::CppFunction::makeFromBoxedFunction<&unsupportedAllclose>()); +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesFactory.cpp b/functorch/functorch/csrc/BatchRulesFactory.cpp new file mode 100644 index 0000000000000..3f63d27a0c8e1 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesFactory.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "c10/core/SymIntArrayRef.h" + +namespace at { namespace functorch { + +template +struct NewBlahBatchRuleHelper; + +template +struct NewBlahBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + optional batch_dim, + IntArrayRef shape, + T... extra_args) { + const auto bdim_size = tensor.size(batch_dim.value()); + VmapDimVector new_shape; + new_shape.reserve(shape.size() + 1); + new_shape.emplace_back(bdim_size); + new_shape.insert(new_shape.end(), shape.begin(), shape.end()); + return std::make_tuple(Func(tensor, new_shape, std::forward(extra_args)...), 0); + } +}; + +// USAGE: NEW_BLAH_BATCH_RULE(at::new_zeros) +// INCORRECT USAGE: NEW_BLAH_BATCH_RULE(&at::new_zeros) +// It is important that this macro is not passed a function pointer!! +#define NEW_BLAH_BATCH_RULE(fn) SINGLE_ARG(\ + NewBlahBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +std::tuple> _new_zeros_with_same_feature_meta_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim, + int64_t self_num_batch_dims) { + // The "self, other" naming is too confusing + // What this function really says is "create a new tangent for this base". + const auto& base = other; + const auto& base_bdim = other_bdim; + const auto& tangent = self; + const auto& tangent_bdim = self_bdim; + + // Three case: + // Case 1 Case 2 Case 3 + // base [6] [B, 6] [B, 6] + // tangent [B, 5] [5] [B, 5] + // result [B, 6] [B, 6] [B, 6] + + // Case 2 & 3 + if (base_bdim) { + auto base_ = moveBatchDimToFront(base, base_bdim); + Tensor tangent_ = tangent; + if (tangent_bdim.has_value()) { + // tangent [B, K0, K1, 5] + // base_ [B, 6] + // We want to move B to after the Ks, so that self_num_batch_dims + // (which really means tangent_num_batch_dims) isn't interfered with. + // [B, K0, K1, 6] -> [K0, K1, B, 6] + // + // [K0, K1, B, 6], [B, 5], 2 -> [K0, K1, B, 5] + tangent_ = tangent.movedim(*tangent_bdim, self_num_batch_dims); + } + const auto result = at::_new_zeros_with_same_feature_meta(tangent_, base_, self_num_batch_dims); + return std::make_tuple(result, self_num_batch_dims); + } + + // Case 1: + auto tangent_ = moveBatchDimToFront(tangent, tangent_bdim); + auto result = at::_new_zeros_with_same_feature_meta(tangent_, base, self_num_batch_dims + 1); + return std::make_tuple(result, 0); +} + +bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) { + return true; +} + +Tensor new_empty_symint_decomp( + const Tensor& self, + SymIntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt + ) { + return self.new_empty(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + m.impl("_has_same_storage_numel", _has_same_storage_numel_batch_rule); + VMAP_SUPPORT(ones_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like))); + VMAP_SUPPORT(zeros_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(zeros_like))); + VMAP_SUPPORT(empty_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(empty_like))); + VMAP_SUPPORT(randn_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like))); + VMAP_SUPPORT(rand_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like))); + VMAP_SUPPORT(full_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like))); + VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty))); + m.impl("new_empty.SymInt", new_empty_symint_decomp); + VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros))); + VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE(ATEN_FN(new_ones))); + VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE(ATEN_FN(new_full))); + VMAP_SUPPORT(_new_zeros_with_same_feature_meta, _new_zeros_with_same_feature_meta_batch_rule); + // Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter) +} +}} diff --git a/functorch/functorch/csrc/BatchRulesHelper.cpp b/functorch/functorch/csrc/BatchRulesHelper.cpp new file mode 100644 index 0000000000000..3118a6826d8b8 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesHelper.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { namespace functorch { + +Tensor moveBatchDimToFront(const Tensor& tensor, optional maybe_batch_dim) { + if (!maybe_batch_dim.has_value()) { + return tensor; + } + if (maybe_batch_dim.value() == 0) { + return tensor; + } + return tensor.movedim(maybe_batch_dim.value(), 0); +} + +int64_t rankWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim) { + int64_t result = tensor.dim(); + if (maybe_batch_dim.has_value()) { + result -= 1; + } + return result; +} + +int64_t numelWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim) { + if (!maybe_batch_dim) { + return tensor.numel(); + } + return tensor.numel() / tensor.size(*maybe_batch_dim); +} + +optional valIfNonempty(optional maybe_empty, int64_t new_val) { + if (maybe_empty.has_value()) { + return new_val; + } + return nullopt; +} + +int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) { + // NB: assumes the batch dim is at the front of the tensor + optional bdim = has_batch_dim ? optional(0) : nullopt; + auto rank = rankWithoutBatchDim(tensor, bdim); + auto wrapped_dim = maybe_wrap_dim(logical_dim, rank); + if (has_batch_dim) { + return wrapped_dim + 1; + } + return wrapped_dim; +} + +VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) { + // NB: assumes the batch dim is at the front of the tensor + optional bdim = has_batch_dim ? optional(0) : nullopt; + auto rank = rankWithoutBatchDim(tensor, bdim); + VmapDimVector result; + result.reserve(logical_dims.size()); + for (auto d : logical_dims){ + if (has_batch_dim) { + result.push_back(maybe_wrap_dim(d, rank)+1); + } else { + result.push_back(maybe_wrap_dim(d, rank)); + } + } + return result; +} + +Tensor maybePadToLogicalRank(const Tensor& tensor, optional has_bdim, int64_t logical_rank) { + if (!has_bdim) { + return tensor; + } + auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim); + if (tensor_logical_rank >= logical_rank) { + return tensor; + } + VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end()); + for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) { + new_sizes.insert(new_sizes.begin() + 1, 1); + } + return tensor.view(new_sizes); +} + +void check_randomness(RandomnessType randomness, bool any_tensor_batched) { + TORCH_CHECK( + randomness != RandomnessType::Error, + "vmap: called random operation while in randomness error mode. Please either use the " + "'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap" + ); + + TORCH_CHECK( + !(randomness == RandomnessType::Same && any_tensor_batched), + "Vmap does not currently support same randomness with a batched tensor input. ", + "Please file an issue with functorch" + ) +} + +void check_randomness(RandomnessType randomness) { + check_randomness(randomness, false); // for ops that don't take in any tensors, don't hit same error +} + +Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x) { + auto x_dim = x.dim(); + src = maybe_wrap_dim(src, x_dim); + dst = maybe_wrap_dim(dst, x_dim - 1); // Returned Tensor has one fewer dim + VmapDimVector new_shape(x.sizes().begin(), x.sizes().end()); + new_shape.erase(new_shape.begin() + src); + new_shape[dst] *= x.sizes()[src]; + return at::reshape(x.movedim(src, dst), new_shape); +} + +Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x) { + src = maybe_wrap_dim(src, x.dim()); + VmapDimVector shape(x.sizes().begin(), x.sizes().end()); + TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0); + int64_t size2 = shape[src] / size1; + shape[src] = size1; + shape.insert(shape.begin() + src + 1, size2); + return at::reshape(x, shape); +} + +void vmapIncompatibleInplaceError(const char* schema_name) { + TORCH_CHECK(false, + "vmap: ", schema_name, "(self, *extra_args) is not possible because ", + "there exists a Tensor `other` in extra_args that has more elements ", + "than `self`. This happened due to `other` being vmapped over but ", + "`self` not being vmapped over in a vmap. ", + "Please try to use out-of-place operators instead of ", schema_name, ". ", + "If said operator is being called inside the PyTorch framework, ", + "please file a bug report instead."); +} + +void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + // TODO: templatize based on op and keep static trace_exec + auto * trace_exec = torch::jit::GetDecompositionExecutor(schema); + trace_exec->run((*stack)); + if (stack->back().isTuple()) { + IValue tup = stack->back(); + stack->pop_back(); + for (const auto& elem: tup.toTuple()->elements()) { + stack->push_back(elem); + } + } +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesHelper.h b/functorch/functorch/csrc/BatchRulesHelper.h new file mode 100644 index 0000000000000..834fd01e5ada8 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesHelper.h @@ -0,0 +1,472 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace functorch { +Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x); +Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x); + +Tensor moveBatchDimToFront(const Tensor& tensor, optional maybe_batch_dim); +int64_t rankWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim); +int64_t numelWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim); +optional valIfNonempty(optional maybe_empty, int64_t new_val); +int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim); +VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims); + +void vmapIncompatibleInplaceError(const char* schema_name); + +Tensor maybePadToLogicalRank(const Tensor& tensor, optional has_bdim, int64_t logical_rank); + +void check_randomness(RandomnessType randomness); +void check_randomness(RandomnessType randomness, bool any_tensor_bdim); + +inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, int64_t batch_size) { + if (has_bdim) { + return tensor; + } + const auto sizes = tensor.sizes(); + DimVector expanded_shape; + expanded_shape.reserve(sizes.size()); + expanded_shape.emplace_back(batch_size); + expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end()); + return tensor.expand(expanded_shape); +} + +#define VMAP_SUPPORT(op, batch_rule) \ + m.impl(#op, op ## _generated_plumbing); + +#define VMAP_SUPPORT2(op, overload, batch_rule) \ + m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing); + +#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); +#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); + +// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain +template +struct BasicUnaryBatchRuleHelper; + +template +struct BasicUnaryBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + optional batch_dim, + T... extra_args) { + return std::make_tuple(Func(tensor, std::forward(extra_args)...), batch_dim); + } +}; + +// USAGE: BASIC_UNARY_BATCH_RULE(at::sin) +// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin) +// It is important that this macro is not passed a function pointer!! +#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\ + BasicUnaryBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define UNARY_POINTWISE(op) \ + VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op))); + +template +struct VariadicBdimsBatchRuleHelper; + +template +struct VariadicBdimsBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + optional batch_dim, + T... extra_args) { + auto tensor_ = moveBatchDimToFront(tensor, batch_dim); + return std::make_tuple(Func(tensor_, std::forward(extra_args)...), 0); + } +}; + +// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! +#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\ + VariadicBdimsBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define VARIADIC_BDIMS(op) \ + VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op))); + +#define VARIADIC_BDIMS2(op, overload) \ + VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload))); + +template +void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + auto orig_arguments = torch::jit::last(*stack, num_arguments); + if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + auto arguments = torch::jit::pop(*stack, num_arguments); + std::vector>> tensor_inputs; + std::vector tensor_pos; + for (const auto idx : c10::irange(0, num_arguments)) { + const auto& ivalue = arguments[idx]; + if (ivalue.isTensor()) { + Tensor tensor_value; + optional tensor_bdim; + std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + tensor_inputs.emplace_back(tensor_value, tensor_bdim); + tensor_pos.push_back(idx); + } + } + Func(tensor_inputs); + + size_t tensor_idx = 0; + TORCH_INTERNAL_ASSERT(tensor_pos.size() > 0); + for (const auto arg_idx : c10::irange(0, num_arguments)) { + if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) { + torch::jit::push(stack, arguments[arg_idx]); + } else { + TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size()); + torch::jit::push(stack, tensor_inputs[tensor_idx].first); + tensor_idx++; + } + } + + op.callBoxed(stack); + const auto returns = torch::jit::pop(*stack, num_returns); + for (const auto& ret : returns) { + if (ret.isTensor()) { + torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level)); + } else { + TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values"); + } + } +} + +inline void handle_pointwise_ops(std::vector>> &tensor_inputs) { + int64_t out_logical_rank = 0; + for (auto& tensor_input : tensor_inputs) { + int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second); + out_logical_rank = std::max(out_logical_rank, cur_logical_rank); + } + for (auto& tensor_input: tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank); + } +} + +#define POINTWISE_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define POINTWISE_BOXED2(op, overload) \ + m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction>()); + +inline void handle_variadic_bdims(std::vector>> &tensor_inputs) { + for (auto & tensor_input : tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + } +} + +#define VARIADIC_BDIMS_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +void run_jit_decomposition(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +#define RUN_JIT_DECOMPOSITION(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&run_jit_decomposition>()); + + +using UnpackedBatchedTensor = std::tuple>; + +inline void find_and_unpack_tensors( + const torch::jit::Stack* stack, + int64_t num_args, + int64_t cur_level, + SmallVector* tensors, + SmallVector* tensors_pos, + int64_t* batch_size) { + + int64_t computed_batch_size = -1; + int64_t args_begin = stack->size() - num_args; + + for (const auto idx : c10::irange(0, num_args)) { + const auto& ivalue = (*stack)[args_begin + idx]; + if (!ivalue.isTensor()) { + continue; + } + auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + const auto& tensor_value = std::get<0>(unpacked); + const auto tensor_bdim = std::get<1>(unpacked); + if (tensor_bdim.has_value()) { + auto candidate_batch_size = tensor_value.size(*tensor_bdim); + if (computed_batch_size == -1) { + computed_batch_size = candidate_batch_size; + } + TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size); + } + + tensors->push_back(std::move(unpacked)); + tensors_pos->push_back(idx); + } + TORCH_INTERNAL_ASSERT(computed_batch_size > -1); + *batch_size = computed_batch_size; +} + +inline void boxed_existing_bdim_all_batch_rule( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = stack->size() - num_arguments; + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size; + + find_and_unpack_tensors( + stack, num_arguments, cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + // for each tensor, ensure it has a bdim and reshape it. + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& value = std::get<0>(tensor_inputs[tensor_idx]); + auto bdim = std::get<1>(tensor_inputs[tensor_idx]); + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + if (!bdim.has_value()) { + bdim = 0; + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_); + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } +} + +// Use when all tensors arguments accept one (normal) batch dim. +// This batching rule expands the batch dim on all Tensors, reshapes it into +// dim 0, calls the op, and then reshapes the batch dim out of dim 0. +// This is not the most efficient thing; if there are alternatives, plese try +// to use them. Use this only as a last resort. +#define EXISTING_BDIM_ALL_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction()); + +template +inline void boxed_all_tensors_have_optional_bdim( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = stack->size() - num_arguments; + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size; + + find_and_unpack_tensors( + stack, num_arguments, cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + optional is_no_batch_dim_case; + + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& value = std::get<0>(tensor_inputs[tensor_idx]); + auto bdim = std::get<1>(tensor_inputs[tensor_idx]); + const auto logical_rank = rankWithoutBatchDim(value, bdim); + + if (!is_no_batch_dim_case.has_value()) { + is_no_batch_dim_case = (logical_rank == feature_rank); + } + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + if (!bdim.has_value()) { + bdim = 0; + } + if (*is_no_batch_dim_case) { + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank); + value_ = moveBatchDimToFront(value_, bdim); + if (tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = value_; + continue; + } + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1); + value_ = reshape_dim_into(*bdim, 0, value_); + if (tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = value_; + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + if (*is_no_batch_dim_case) { + (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level); + } else { + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } + } +} + +// Useful for many NN operators. +// The operator must satisfy the following: +// - All arguments must accept an optional batch dim. +// - All arguments must be the same rank +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \ + m.impl(#op, \ + torch::CppFunction::makeFromBoxedFunction<\ + boxed_all_tensors_have_optional_bdim<\ + feature_rank, \ + contig_tensor_index>\ + >()); + +template +struct ExistingBdimBatchRuleHelper; + +template +struct ExistingBdimBatchRuleHelper> { + static std::tuple> apply( + const Tensor& self, + optional self_bdim, + T... extra_args) { + auto self_ = reshape_dim_into(*self_bdim, 0, self); + auto out = Func(self_, std::forward(extra_args)...); + return std::make_tuple(reshape_dim_outof(0, self.sizes()[*self_bdim], out), 0); + } +}; + +// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! +#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\ + ExistingBdimBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + + +#define EXISTING_BDIM(op) \ + VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op))); + +#define EXISTING_BDIM2(op, overload) \ + VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload))); + +#define INVOKE(object,ptrToMember) ((object).*(ptrToMember)) + + +template +Tensor& unary_inplace_batch_rule(Tensor& self, optional, ExtraArgs... extra_args) { + INVOKE(self, Method)(std::forward(extra_args)...); + return self; +} + +inline int64_t get_bdim_size4( + const Tensor& a_value, optional a_bdim, + const Tensor& b_value, optional b_bdim, + const Tensor& c_value, optional c_bdim, + const Tensor& d_value, optional d_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + if (d_bdim) + return d_value.size(*d_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size3( + const Tensor& a_value, optional a_bdim, + const Tensor& b_value, optional b_bdim, + const Tensor& c_value, optional c_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size2( + const Tensor& a_value, optional a_bdim, + const Tensor& b_value, optional b_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +// [start, start + 1, ..., stop - 1] +inline VmapDimVector range(int64_t start, int64_t stop) { + TORCH_INTERNAL_ASSERT(stop >= start); + VmapDimVector dims; + dims.reserve(stop - start); + for (int64_t i = start; i < stop; i++) { + dims.emplace_back(i); + } + return dims; +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp new file mode 100644 index 0000000000000..d7286c55f6876 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp @@ -0,0 +1,218 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +namespace at { namespace functorch { + +// Note [Batching rules for matmul-like operators] +// at::matmul doesn't "de-expand" arguments to get better performance (maybe +// it should). In the batching rules for matmul-like operators (dot, mv, mm), +// we should be careful not to expand any unnecessary dimensions. i.e., if +// only one of the two arguments is a BatchedTensor, then we should try +// not to expand batch dimensions onto the other arg. + +std::tuple> dot_batch_rule(const Tensor& A, optional A_bdim, const Tensor& B, optional B_bdim) { + TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot"); + auto A_ = moveBatchDimToFront(A, A_bdim); + auto B_ = moveBatchDimToFront(B, B_bdim); + if (A_bdim && B_bdim) { + return std::make_tuple(at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0); + } else { + return std::make_tuple(at::matmul(A_, B_.t()), 0); + } +} + +// NB: I wrote this like this because we *might* want its for a future matmul +// batch rule that isn't decomposed... +// "tv" = tensor @ vector +static std::tuple> tv_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim) { + if (self_bdim && other_bdim) { + // See Note [Batching rules for matmul-like operators] + // B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO + auto self_ = at::movedim(self, *self_bdim, -3); + auto other_ = moveBatchDimToFront(other, other_bdim); + other_ = other_.unsqueeze(-1); + auto result = at::matmul(self_, other_).squeeze(-1); + auto result_bdim = result.dim() - 2; + return std::make_tuple( std::move(result), result_bdim ); + } + else if (self_bdim && !other_bdim) { + // B...OI, I -> B...O + auto self_ = moveBatchDimToFront(self, self_bdim); + return std::make_tuple( at::matmul(self_, other), 0 ); + } + else if (!self_bdim && other_bdim) { + // ...OI, BI -> ...OI, IB -> OB + auto other_ = at::movedim(other, *other_bdim, -1); + auto result = at::matmul(self, other_); + return std::make_tuple( std::move(result), 1 ); + } + TORCH_INTERNAL_ASSERT(false, "can't get here"); +} + +static std::tuple> mv_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); + TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1, + "Shape mismatch: ", + "Got incorrect dims for mv(a, b). a has dim ", self_logical_rank, + "and b has dim ", other_logical_rank, + "but expected them to have dim 2 and dim 1"); + return tv_batch_rule(self, self_bdim, other, other_bdim); +} + +static std::tuple> mm_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); + TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2, + "Shape mismatch: Got incorrect dims for mm(a, b). " + "a has dim ", self_logical_rank, + "and b has dim ", other_logical_rank, + "but expected them to have dim 2 and dim 2"); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto other_ = moveBatchDimToFront(other, other_bdim); + return std::make_tuple( at::matmul(self_, other_), 0 ); +} + +static std::tuple> bmm_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); + TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3, + "Shape mismatch: Got incorrect dims for bmm(a, b). " + "a has dim ", self_logical_rank, + "and b has dim ", other_logical_rank, + "but expected them to have dim 3 and dim 3"); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto other_ = moveBatchDimToFront(other, other_bdim); + return std::make_tuple( at::matmul(self_, other_), 0 ); +} + +// AFAICT, nothing here can be batched. So we decompose :) +Tensor addmv_decomp( + const Tensor& input, const Tensor& mat, const Tensor& vec, const Scalar& beta, const Scalar& alpha) { + Tensor out = at::mv(mat, vec); + if (!alpha.equal(1)) { + out = alpha * out; + } + if (!beta.equal(0)) { + out = beta * input + out; + } + return out; +} + +Tensor addbmm_decomp( + const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { + Tensor out = at::bmm(batch1, batch2).sum(0); + if (!alpha.equal(1)) { + out = alpha * out; + } + if (!beta.equal(0)) { + out = beta * input + out; + } + return out; +} + +Tensor baddbmm_decomp( + const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { + Tensor out = at::bmm(batch1, batch2); + if (!alpha.equal(1)) { + out = alpha * out; + } + if (!beta.equal(0)) { + out = beta * input + out; + } + return out; +} + +Tensor linear_decomp( + const Tensor& input, const Tensor& weight, + const c10::optional& bias_opt) { + auto result = input.matmul(weight.t()); + if (bias_opt) { + // NB: It's too much work to figure out how to actually fuse the bias so + // we're not going to. + // TODO: if the result isn't batched but bias is, then we need to do the following. + // Otherwise, it can just be in-place. We should write a more nuanced + // decomposition rule + return result.add(*bias_opt); + } + return result; +} + +Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { + // Decomposition that is probably not very fast... + return at::add(self * beta, at::mm(mat1, mat2), alpha); +} + +void _linalg_check_errors_batch_rule(const Tensor& info, optional info_bdim, c10::string_view api_name, bool is_matrix) { + auto info_ = moveBatchDimToFront(info, info_bdim); + // Not a matrix means this is a batch of matrices + at::_linalg_check_errors(info_, api_name, false); +} + +std::tuple> +householder_product_batch_rule(const Tensor &input, c10::optional input_bdim, + const Tensor &tau, c10::optional tau_bdim) +{ + auto input_ = moveBatchDimToFront(input, input_bdim); + auto tau_ = moveBatchDimToFront(tau, tau_bdim); + + auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim); + + input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size); + tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size); + return std::make_tuple(at::linalg_householder_product(input_, tau_), 0); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT(bmm, bmm_batch_rule); + m.impl("addmv", addmv_decomp); + m.impl("addmm", addmm_decomp); + m.impl("addbmm", addbmm_decomp); + m.impl("baddbmm", baddbmm_decomp); + VMAP_SUPPORT(dot, dot_batch_rule); + VMAP_SUPPORT(mv, mv_batch_rule); + VMAP_SUPPORT(mm, mm_batch_rule); + m.impl("linear", linear_decomp); + VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule); + + VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule); + + VARIADIC_BDIMS_BOXED(cholesky_solve); + VARIADIC_BDIMS_BOXED(linalg_cholesky_ex); + VARIADIC_BDIMS_BOXED(linalg_eig); + VARIADIC_BDIMS_BOXED(linalg_eigh); + VARIADIC_BDIMS_BOXED(linalg_inv_ex); + VARIADIC_BDIMS(linalg_pinv); + VARIADIC_BDIMS_BOXED(linalg_qr); + VARIADIC_BDIMS_BOXED(linalg_slogdet); + + VARIADIC_BDIMS(cholesky); + VARIADIC_BDIMS(cholesky_inverse); + VARIADIC_BDIMS_BOXED(geqrf); + VARIADIC_BDIMS(logdet); + VARIADIC_BDIMS(matrix_exp); + VARIADIC_BDIMS(pinverse); + VARIADIC_BDIMS(inverse); + VARIADIC_BDIMS_BOXED(slogdet); + VARIADIC_BDIMS_BOXED(_linalg_svd); + VARIADIC_BDIMS_BOXED(solve); + VARIADIC_BDIMS_BOXED(symeig); + VARIADIC_BDIMS_BOXED(triangular_solve); + + VARIADIC_BDIMS_BOXED(_linalg_det); + VARIADIC_BDIMS_BOXED(_lu_with_info); +} +}} diff --git a/functorch/functorch/csrc/BatchRulesLoss.cpp b/functorch/functorch/csrc/BatchRulesLoss.cpp new file mode 100644 index 0000000000000..16ee2fb7e9c16 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesLoss.cpp @@ -0,0 +1,290 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace at { namespace functorch { +// Flattens out all dims except the batch dim, and also moves batch dim +// (if it exists) to front. +at::Tensor flatten_logical(const Tensor& tensor, optional bdim) { + if (bdim.has_value()) { + auto result = moveBatchDimToFront(tensor, bdim); + if (result.dim() > 1) { + return result.flatten(1); + } else { + return result; + } + } else { + return tensor.flatten(); + } +} + +std::tuple> +mse_loss_batch_rule(const at::Tensor& self, optional self_bdim, const at::Tensor& target, + optional target_bdim, int64_t reduction) { + auto self_ = flatten_logical(self, self_bdim); + auto target_ = flatten_logical(target, target_bdim); + auto result = at::mse_loss(self_, target_, Reduction::None); + if (result.dim() == 1) { + return std::make_tuple(result, 0); + } else if (reduction == Reduction::None) { + DimVector end_shape; + const auto batched_elem = self_bdim.has_value() ? + moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim); + return std::make_tuple(result.reshape(batched_elem.sizes()), 0); + } else if (reduction == Reduction::Sum) { + return std::make_tuple(result.sum(-1), 0); + } else if (reduction == Reduction::Mean) { + return std::make_tuple(result.mean(-1), 0); + } + TORCH_INTERNAL_ASSERT(false); +}; + +static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) { + if (reduction == at::Reduction::Mean) { + return unreduced.mean(); + } else if (reduction == at::Reduction::Sum) { + return unreduced.sum(); + } + return unreduced; +} + +Tensor binary_cross_entropy_plumbing( + const Tensor& self, const Tensor& target, + const optional& weight, int64_t reduction) { + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) + && !isBatchedAtLevel(weight, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::binary_cross_entropy(self, target, weight, reduction); + } + + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + Tensor target_value; + optional target_bdim; + std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level); + + Tensor result; + if (self_bdim || target_bdim) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim); + auto self_ = moveBatchDimToFront(self_value, self_bdim); + auto target_ = moveBatchDimToFront(target_value, target_bdim); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size); + target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size); + result = at::binary_cross_entropy(self_, target_, nullopt, Reduction::None); + result = makeBatched(result, 0, cur_level); + } else { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + result = at::binary_cross_entropy(self_value, target_value, nullopt, Reduction::None); + } + if (weight.has_value() && weight->defined()) { + result = result * weight.value(); + } + return apply_loss_reduction(result, reduction); +} + +Tensor binary_cross_entropy_backward_plumbing( + const Tensor& grad, const Tensor& input, const Tensor& target, + const c10::optional& weight_opt, int64_t reduction) { + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction); + } + + Tensor grad_value; + optional grad_bdim; + std::tie(grad_value, grad_bdim) = unwrapTensorAtLevel( + reduction == Reduction::None ? grad : grad.expand_as(input), cur_level); + Tensor input_value; + optional input_bdim; + std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level); + Tensor target_value; + optional target_bdim; + std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level); + + Tensor grad_input; + if (grad_bdim || input_bdim || target_bdim) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto bdim_size = get_bdim_size3( + grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim); + + auto grad_ = moveBatchDimToFront(grad_value, grad_bdim); + auto input_ = moveBatchDimToFront(input_value, input_bdim); + auto target_ = moveBatchDimToFront(target_value, target_bdim); + + grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size); + target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size); + + grad_input = at::binary_cross_entropy_backward( + grad_, input_, target_, nullopt, Reduction::None); + grad_input = makeBatched(grad_input, 0, cur_level); + } else { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + grad_input = at::binary_cross_entropy_backward( + grad_value, input_value, target_value, nullopt, Reduction::None); + } + if (weight_opt.has_value() && weight_opt->defined()) { + grad_input = grad_input * weight_opt.value(); + } + if (reduction == Reduction::Mean) { + grad_input.div_(input.numel()); + } + return grad_input; +} + +std::tuple nll_loss_forward_decomposition( + const Tensor & self, + const Tensor & target, + const c10::optional & weight, + int64_t reduction, int64_t ignore_index) { + + // self can be [N, C, ...] or [C] + // target can be [N, ...] or [] + + int64_t channel_dim = 1; + if (self.dim() < 2) { + channel_dim = 0; + } + auto self_ = self; + Tensor weight_; + + if (weight && weight->defined()) { + // Here is a specific case with reduction mean and non-batched tensors + // https://github.com/pytorch/pytorch/issues/61309 + // In this case weight is cancelled: w * x[t] / w -> x[t] + if (!(reduction == Reduction::Mean && self_.dim() < 2)) { + // reshape weights to [1, C, 1, ..., 1] + auto shape = weight->sizes(); + VmapDimVector new_shape(self_.dim(), 1); + new_shape[channel_dim] = shape[0]; + weight_ = weight->reshape(new_shape); + self_ = self_ * weight_; + } + } + auto target_ = target.unsqueeze(channel_dim); + // target can be [N, 1, ...] or [1] + + auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim); + auto total_weight = at::full( + {}, result.numel(), self_.scalar_type(), + self_.layout(), self_.device(), nullopt); + + bool has_ignore_index = ignore_index >= 0; + Tensor ignore_index_mask; + if (has_ignore_index) { + ignore_index_mask = target != ignore_index; + result = result * ignore_index_mask; + total_weight = ignore_index_mask.sum().to(self_); + } + + // Apply the reduction + if (result.dim() > 0) { + if (reduction == Reduction::Sum) { + result = result.sum(); + } else if (reduction == Reduction::Mean) { + if (!weight || !weight->defined()) { + if (has_ignore_index) { + TORCH_INTERNAL_ASSERT(ignore_index_mask.defined()); + // total_weight is ignore_index_mask.sum() + result = result.sum() / total_weight; + } else { + result = result.mean(); + } + } else { + TORCH_INTERNAL_ASSERT(weight_.defined()); + weight_ = weight_.expand(self_.sizes()); + auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim); + if (has_ignore_index) { + TORCH_INTERNAL_ASSERT(ignore_index_mask.defined()); + wsum = wsum * ignore_index_mask; + } + wsum = wsum.sum(); + result = result.sum() / wsum; + total_weight = wsum; + } + } + } else if (reduction == Reduction::Mean && weight && weight->defined()) { + // here weight is [C] and target is [1] + auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim); + if (has_ignore_index) { + TORCH_INTERNAL_ASSERT(ignore_index_mask.defined()); + wsum = wsum * ignore_index_mask; + } + total_weight = wsum.sum(); + } + + return std::make_tuple(result, total_weight); +} + +at::Tensor nll_loss_backward_decomposition( + const at::Tensor & grad_output, const at::Tensor & self, + const at::Tensor & target, const c10::optional & weight, + int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + + int64_t channel_dim = 1; + if (self.dim() < 2) { + channel_dim = 0; + } + auto target_ = target.unsqueeze(channel_dim); + + auto grad_output_ = grad_output; + if (reduction == Reduction::Mean) { + grad_output_ = grad_output_ / total_weight; + } + + auto grad_input = at::zeros_like(self); + grad_input = at::scatter(grad_input, channel_dim, target_, -1.0); + + if (grad_output_.dim() < grad_input.dim() && grad_output_.dim() > 0) { + grad_output_ = grad_output_.unsqueeze(channel_dim); + } + + Tensor weight_; + if (weight && weight->defined()) { + auto self_ = self; + auto shape = weight->sizes(); + VmapDimVector new_shape(self_.dim(), 1); + new_shape[channel_dim] = shape[0]; + weight_ = weight->reshape(new_shape); + grad_output_ = grad_output_ * weight_; + } + + bool has_ignore_index = ignore_index >= 0; + Tensor ignore_index_mask; + if (has_ignore_index) { + ignore_index_mask = target_ != ignore_index; + grad_output_ = grad_output_ * ignore_index_mask; + } + + return grad_input * grad_output_; +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + m.impl("nll_loss_forward", nll_loss_forward_decomposition); + m.impl("nll_loss2d_forward", nll_loss_forward_decomposition); + m.impl("nll_loss_backward", nll_loss_backward_decomposition); + m.impl("nll_loss2d_backward", nll_loss_backward_decomposition); + VMAP_SUPPORT(mse_loss, mse_loss_batch_rule); + // mse_loss_backwards uses a decomposition for its batch rule + m.impl("binary_cross_entropy", binary_cross_entropy_plumbing); + m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing); +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesModules.cpp b/functorch/functorch/csrc/BatchRulesModules.cpp new file mode 100644 index 0000000000000..3d54ba5d0fe47 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesModules.cpp @@ -0,0 +1,442 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { namespace functorch { + +static Tensor getStepTensor(const Tensor& indices, int64_t bdim_size, int64_t num_embeddings) { + // [batch_size, 1, 1, 1, ..., 1] + DimVector view_shape(indices.dim(), 1); + view_shape[0] = bdim_size; + auto range = at::arange(0, bdim_size * num_embeddings, num_embeddings, indices.options()); + return range.view(view_shape); +} + +std::tuple> embedding_batch_rule( + const Tensor& weight, optional weight_bdim, + const Tensor& indices, optional indices_bdim, + int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + if (!weight_bdim && indices_bdim) { + // B*, ED -> B*D + const auto result = at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); + return std::make_tuple(result, indices_bdim); + } else if (weight_bdim && !indices_bdim) { + // *, BED -> *, E(BD) -> *(BD) -> *BD + const auto batch_size = weight.size(*weight_bdim); + const auto weight_ = reshape_dim_into(*weight_bdim, /*embedding_dim*/1, weight); + auto result = at::embedding(weight_, indices, padding_idx, scale_grad_by_freq, sparse); + result = reshape_dim_outof(-1, batch_size, result); + return std::make_tuple(result, result.dim() - 2); + } + TORCH_INTERNAL_ASSERT(weight_bdim && indices_bdim); + // B*, BED -> B*, (BE)D -> B*D + // We'll need to do something extra: add (0, E, 2*E, ...) to the indices. + const auto batch_size = weight.size(*weight_bdim); + const auto num_embeddings = weight.size((*weight_bdim == 0) ? 1 : 0); + const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight); + auto indices_ = moveBatchDimToFront(indices, indices_bdim); + + const auto range = getStepTensor(indices, batch_size, num_embeddings); + indices_ = indices_ + range; + const auto result = at::embedding(weight_, indices_, padding_idx, scale_grad_by_freq, sparse); + return std::make_tuple(result, 0); +} + +std::tuple> +embedding_dense_backward_batch_rule( + const Tensor& grad_, optional grad_bdim, + const Tensor& indices_, optional indices_bdim, + int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + Tensor grad = grad_; + Tensor indices = indices_; + if (!indices_bdim && grad_bdim) { + const auto bdim_size = grad.size(*grad_bdim); + grad = reshape_dim_into(*grad_bdim, -1, grad); + auto result = at::embedding_dense_backward( + grad, indices, num_weights, padding_idx, scale_grad_by_freq); + result = reshape_dim_outof(1, bdim_size, result); + return std::make_tuple(result, 1); + } + const auto bdim_size = indices.size(*indices_bdim); + indices = moveBatchDimToFront(indices, indices_bdim); + grad = moveBatchDimToFront(grad, grad_bdim); + grad = ensure_has_bdim(grad, grad_bdim.has_value(), bdim_size); + const auto range = getStepTensor(indices, bdim_size, num_weights); + auto result = at::embedding_dense_backward( + grad, indices + range, num_weights * bdim_size, -1, scale_grad_by_freq); + result = reshape_dim_outof(0, bdim_size, result); + // Fill in the padding. We can't do it in the embedding_dense_backward call + // because we need to fill in multiple rows! + if (padding_idx >= 0) { + result.select(1, padding_idx).fill_(0); + } + return std::make_tuple(result, 0); +} + +/** + * grid sample batch rule breaks down into 3 cases: + * case 1 (input is batched, grid is not): + * batch input along first dimension, unpack along first dimension + * 2d: + * input: N(BC)H_{in}W_{in}, grid: NH_{out}W_{out}2 + * output: N(BC)H_{out}W_{out} + * 3d: + * input: N(BC)D_{in}H_{in}W_{in}, grid: ND_{out}H_{out}W_{out}3 + * output: N(BC)D_{out}H_{out}W_{out} + * case 2 (input is not batched, grid is batched): + * batch grid along second dimension, unpack along second dimension + * 2d: + * input: NCH_{in}W_{in}, grid: N(BH_{out})W_{out}2 + * output: NC(BH_{out})W_{out} + * 3d: + * input: NCD_{in}H_{in}W_{in}, grid: N(BD_{out})H_{out}W_{out}3 + * output: NC(BD_{out})H_{out}W_{out} + * case 3 (input and grid are both batched): + * batch grid and input along 0th dimension, unpack along 0th dimension + * 2d: + * input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2 + * output: (BN)CH_{out}W_{out} + * 3d: + * input: (BN)CD_{in}H_{in}W_{in}, grid: (BN)D_{out}H_{out}W_{out}3 + * output: (BN)CD_{out}H_{out}W_{out} + */ +template +std::tuple> +grid_sample_batch_rule(const Tensor& input, optional input_bdim, const Tensor& grid, optional grid_bdim, ExtraArgs... extra_args) { + std::tuple> result; + if (input_bdim && !grid_bdim) { + auto new_input = reshape_dim_into(*input_bdim, 1, input); + auto out = Func(new_input, grid, std::forward(extra_args)...); + out = reshape_dim_outof(1, input.sizes()[*input_bdim], out); + result = std::make_tuple(out, 1); + } else if (!input_bdim && grid_bdim) { + // grid of N(BH)W2 -> NC(BH)W or grid of N(BD)HBW3 -> NC(BD)HW + auto new_grid = reshape_dim_into(*grid_bdim, 1, grid); + auto out = Func(input, new_grid, std::forward(extra_args)...); + out = reshape_dim_outof(2, grid.sizes()[*grid_bdim], out); + result = std::make_tuple(out, 2); + } else if (input_bdim && grid_bdim) { + auto new_input = reshape_dim_into(*input_bdim, 0, input); + auto new_grid = reshape_dim_into(*grid_bdim, 0, grid); + auto out = Func(new_input, new_grid, std::forward(extra_args)...); + out = reshape_dim_outof(0, input.sizes()[*grid_bdim], out); + result = std::make_tuple(out, 0); + } else { + result = std::make_tuple(Func(input, grid, std::forward(extra_args)...), nullopt); + } + return result; +} + +std::tuple +grid_sample_backward_helper_in( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& input, optional input_bdim, + const Tensor& grid, optional grid_bdim) { + + auto batch_size = get_bdim_size3( + grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim); + + auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); + grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size); + grad_output_ = reshape_dim_into(0, 0, grad_output_); + + auto input_ = moveBatchDimToFront(input, input_bdim); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size); + input_ = reshape_dim_into(0, 0, input_); + + auto grid_ = moveBatchDimToFront(grid, grid_bdim); + grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size); + grid_ = reshape_dim_into(0, 0, grid_); + + return std::make_tuple(grad_output_, input_, grid_, batch_size); +} + +std::tuple, Tensor, optional> +grid_sample_backward_helper_out( + const std::tuple & bw_out, + optional grad_input_out_bdim, + optional grad_grid_out_bdim, + int64_t bdim_size) { + auto grad_input = std::get<0>(bw_out); + auto grad_grid = std::get<1>(bw_out); + grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input); + grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid); + auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim); + return result; +} + + +template +std::tuple, Tensor, optional> +grid_sample_backward_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& input, optional input_bdim, + const Tensor& grid, optional grid_bdim, + ExtraArgs... extra_args) { + + auto new_bw_input = grid_sample_backward_helper_in( + grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim); + + auto new_grad_output = std::get<0>(new_bw_input); + auto new_input = std::get<1>(new_bw_input); + auto new_grid = std::get<2>(new_bw_input); + int64_t batch_size = std::get<3>(new_bw_input); + + auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward(extra_args)...); + + return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size); +} + +template +std::tuple, Tensor, optional> +cudnn_grid_sample_backward_batch_rule( + const Tensor& input, optional input_bdim, + const Tensor& grid, optional grid_bdim, + const Tensor& grad_output, optional grad_output_bdim) { + + auto new_bw_input = grid_sample_backward_helper_in( + grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim); + + auto new_grad_output = std::get<0>(new_bw_input); + auto new_input = std::get<1>(new_bw_input); + auto new_grid = std::get<2>(new_bw_input); + int64_t bdim_size = std::get<3>(new_bw_input); + + auto bw_out = Func(new_input, new_grid, new_grad_output); + + return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size); +} + +std::tuple> cross_batch_rule( + const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim, + const optional dim) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto other_ = moveBatchDimToFront(other, other_bdim); + + if (other_bdim.has_value() && !self_bdim.has_value()) { + self_ = self_.expand_as(other_); + } + if (self_bdim.has_value() && !other_bdim.has_value()) { + other_ = other_.expand_as(self_); + } + auto new_dim = dim; + if (dim.has_value()) { + auto t = (self_bdim.has_value()) ? self_ : other_; + bool flag = (self_bdim.has_value()) ? true : other_bdim.has_value(); + new_dim = getPhysicalDim(t, flag, *dim); + } else { + // if batch size is 3 we have to avoid that bdim is used as cross' dim argument + // according to cross API: + // > If dim is not given, it defaults to the first dimension found with the size 3 + // we have to skip batch dim and find another dim with size 3 + auto bs = (self_bdim.has_value()) ? self_.size(0) : (other_bdim.has_value()) ? other_.size(0) : -1; + if (bs == 3) { + auto t = (self_bdim.has_value()) ? self_ : other_; + int64_t idx = 1; + for (auto it = t.sizes().begin() + 1; it < t.sizes().end(); ++it, ++idx) { + if (*it == 3) { + new_dim = idx; + break; + } + } + } + } + optional out_dim = (self_bdim.has_value() || other_bdim.has_value()) ? 0 : (optional) nullopt; + return std::make_tuple(at::cross(self_, other_, new_dim), out_dim); +} + +// TODO: replace with targetable functionalization +Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) { + TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); + auto shape = self.sizes().vec(); + + // empty tensor could be converted to one hot representation, + // but shape inference is not possible. + if (self.numel() == 0) { + if (num_classes <= 0) { + AT_ERROR("Can not infer total number of classes from empty tensor."); + } else { + shape.push_back(num_classes); + return at::empty(shape, self.options()); + } + } + + TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please " + "provide an explicit positive num_classes argument."); + + // Disabling all of the following checks. This is OK because scatter has checks too. + // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this. + // // non-empty tensor + // if (self.device().type() != at::kCUDA) { + // //for cuda, rely on device assert thrown by scatter + // TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative."); + // } + // if (self.device().type() != at::kCUDA) { + // //rely on device asserts from scatter to avoid sync here + // TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); + // } + + shape.push_back(num_classes); + Tensor ret = at::zeros(shape, self.options()); + return ret.scatter(-1, self.unsqueeze(-1), 1); +} + +template +struct UpsampleBackwardBatchRuleHelper; + +template +struct UpsampleBackwardBatchRuleHelper> { + static std::tuple> apply( + const Tensor& grad_output, optional grad_output_bdim, + OptionalArrayRef output_size, IntArrayRef input_size, + T... extra_args) { + auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); + TORCH_INTERNAL_ASSERT(input_size.size() > 0); + + // input_size is wrong so we correct it + DimVector physical_input_size(input_size.begin(), input_size.end()); + physical_input_size[0] = grad_output_.sizes()[0]; + + auto out = Func( + grad_output_, + output_size, + physical_input_size, + std::forward(extra_args)...); + return std::make_tuple(reshape_dim_outof(0, grad_output.sizes()[*grad_output_bdim], out), 0); + } + +}; + +template +struct GridSampleBatchRuleHelper; + +template +struct GridSampleBatchRuleHelper> { + static std::tuple> apply( + const Tensor& input, optional input_batch_dim, + const Tensor& grid, optional grid_batch_dim, + T... extra_args) { + return grid_sample_batch_rule( + input, input_batch_dim, grid, grid_batch_dim, std::forward(extra_args)...); + } +}; + +template +struct GridSampleBackwardBatchRuleHelper; + +template +struct GridSampleBackwardBatchRuleHelper> { + static std::tuple, Tensor, optional> apply( + const Tensor& grad_output, optional grad_output_batch_dim, + const Tensor& input, optional input_batch_dim, + const Tensor& grid, optional grid_batch_dim, + T... extra_args) { + return grid_sample_backward_batch_rule( + grad_output, grad_output_batch_dim, + input, input_batch_dim, + grid, grid_batch_dim, + std::forward(extra_args)...); + } +}; + +template +struct CudnnGridSampleBackwardBatchRuleHelper { + static std::tuple, Tensor, optional> apply( + const Tensor& input, optional input_batch_dim, + const Tensor& grid, optional grid_batch_dim, + const Tensor& grad_output, optional grad_output_batch_dim) { + return cudnn_grid_sample_backward_batch_rule( + input, input_batch_dim, + grid, grid_batch_dim, + grad_output, grad_output_batch_dim + ); + } +}; + +#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\ + GridSampleBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn),\ + c10::guts::function_traits::parameter_types>::apply) + +#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\ + GridSampleBackwardBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn),\ + c10::guts::function_traits::parameter_types>::apply) + +#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\ + CudnnGridSampleBackwardBatchRuleHelper::apply + +#define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT2(op, overload, SINGLE_ARG(\ + UpsampleBackwardBatchRuleHelper<\ + decltype(&ATEN_FN2(op, overload)),\ + &ATEN_FN2(op, overload),\ + c10::guts::function_traits::parameter_types>::apply)) + +#define UPSAMPLE_BATCH(op) \ + EXISTING_BDIM2(op, vec); \ + EXISTING_BDIM(op); + + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + EXISTING_BDIM(im2col); + EXISTING_BDIM(im2col_backward); + + VMAP_SUPPORT(embedding, embedding_batch_rule); + VMAP_SUPPORT(embedding_dense_backward, embedding_dense_backward_batch_rule); + + VMAP_SUPPORT(grid_sampler_2d, GRID_SAMPLE_BATCH_RULE(grid_sampler)); + VMAP_SUPPORT(grid_sampler_2d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward)); + + VMAP_SUPPORT(grid_sampler_3d, GRID_SAMPLE_BATCH_RULE(grid_sampler)); + VMAP_SUPPORT(grid_sampler_3d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward)); + VMAP_SUPPORT(cudnn_grid_sampler_backward, CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward)); + + VMAP_SUPPORT(cudnn_grid_sampler, GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); + VMAP_SUPPORT(cross, cross_batch_rule); + + EXISTING_BDIM(pixel_shuffle); + EXISTING_BDIM(pixel_unshuffle); + + VARIADIC_BDIMS(constant_pad_nd); + EXISTING_BDIM(reflection_pad1d); + EXISTING_BDIM(reflection_pad2d); + EXISTING_BDIM(reflection_pad3d); + EXISTING_BDIM(replication_pad1d); + EXISTING_BDIM(replication_pad2d); + EXISTING_BDIM(replication_pad3d); + + EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward); + EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward); + EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward); + + EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward); + EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward); + EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward); + + UPSAMPLE_BATCH(upsample_bicubic2d); + UPSAMPLE_BATCH(upsample_bilinear2d); + UPSAMPLE_BATCH(upsample_linear1d); + UPSAMPLE_BATCH(upsample_nearest1d); + UPSAMPLE_BATCH(upsample_nearest2d); + UPSAMPLE_BATCH(upsample_nearest3d); + UPSAMPLE_BATCH(upsample_trilinear3d); + + UPSAMPLE_BACKWARD(upsample_bicubic2d_backward, vec); + UPSAMPLE_BACKWARD(upsample_bilinear2d_backward, vec); + UPSAMPLE_BACKWARD(upsample_linear1d_backward, vec); + UPSAMPLE_BACKWARD(upsample_nearest1d_backward, vec); + UPSAMPLE_BACKWARD(upsample_nearest2d_backward, vec); + UPSAMPLE_BACKWARD(upsample_nearest3d_backward, vec); + UPSAMPLE_BACKWARD(upsample_trilinear3d_backward, vec); + m.impl("one_hot", one_hot_decomposition_hack); +} +}} diff --git a/functorch/functorch/csrc/BatchRulesNorm.cpp b/functorch/functorch/csrc/BatchRulesNorm.cpp new file mode 100644 index 0000000000000..e78538329582d --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesNorm.cpp @@ -0,0 +1,891 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace at { namespace functorch { + +static bool is_empty_tensor(const Tensor& tensor) { + const auto shape = tensor.sizes(); + return shape.size() == 1 && shape[0] == 0; +} + +static optional compute_stat_bdim( + optional input_bdim, + const Tensor& stat) { + // There's a weird case where mean, rstd can both have shape (0,). + // It's possible that this is a bug on the PyTorch side. + // When that happens we don't want to return a BatchedTensor. + if (input_bdim.has_value() && !is_empty_tensor(stat)) { + return 0; + } + return nullopt; +} + +static Tensor padRight(const Tensor& tensor, optional has_bdim, int64_t logical_rank) { + // NB: Batch dim, if it exists, is assumed to be the first dim + auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim); + if (tensor_logical_rank >= logical_rank) { + return tensor; + } + VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end()); + for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) { + new_sizes.push_back(1); + } + return tensor.view(new_sizes); +} + +template +std::tuple,Tensor,optional,Tensor,optional> +batch_norm_batch_rule( + const Tensor& input, optional input_bdim, + const c10::optional& weight_opt, optional weight_bdim, + const c10::optional& bias_opt, optional bias_bdim, + const c10::optional& running_mean_opt, optional running_mean_bdim, + const c10::optional& running_var_opt, optional running_var_bdim, + bool training, double momentum, double eps) { + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); + auto running_mean = *running_mean_maybe_owned; + c10::MaybeOwned running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt); + auto running_var = *running_var_maybe_owned; + TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))), + "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ", + "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.", + "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ", + "how to patch your module to work with vmap"); + c10::optional bdim_size; + Tensor result0; + Tensor mean; + Tensor rstd; + if (!input_bdim && !running_mean_bdim && !running_var_bdim) { + const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight + const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda + const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps); + result0 = std::get<0>(result).transpose(0, 1); // [C, B, *] + mean = std::get<1>(result); + rstd = std::get<2>(result); + } else { + bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim); + auto input_ = moveBatchDimToFront(input, input_bdim); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value()); + input_ = reshape_dim_into(0, /*channels dim*/1, input_); + + c10::optional running_mean_; + c10::optional running_var_; + if (running_mean.defined()) { + running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim); + running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value()); + running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous(); + } + if (running_var.defined()) { + running_var_ = moveBatchDimToFront(running_var, running_var_bdim); + running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value()); + running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous(); + } + + const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight + const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda + const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps); + result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *] + result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *] + mean = std::get<1>(result); + mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C] + rstd = std::get<2>(result); + rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C] + } + + const auto stats_bdim = compute_stat_bdim(bdim_size, mean); + if (weight.defined()) { + const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim); + auto weight_ = moveBatchDimToFront(weight, weight_bdim); + weight_ = padRight(weight_, weight_bdim, input_logical_rank); + result0 = result0 * weight_; + } + if (bias.defined()) { + const auto result_logical_rank = rankWithoutBatchDim( + result0, + bdim_size.has_value() || weight_bdim.has_value() ? optional(0) : optional(nullopt)); + auto bias_ = moveBatchDimToFront(bias, bias_bdim); + bias_ = padRight(bias_, bias_bdim, result_logical_rank); + result0 = result0 + bias_; + } + result0 = result0.transpose(1, 2); // [B0, B, C, *], because some arg must have been batched, the output must be batched + return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim); +} + +template +std::tuple> batch_norm_backward_no_weight_bias_batch_rule( + const at::Tensor & grad_out, optional grad_out_bdim, + const at::Tensor & input, optional input_bdim, + const c10::optional & running_mean_opt, optional running_mean_bdim, + const c10::optional & running_var_opt, optional running_var_bdim, + const at::Tensor & mean, optional mean_bdim, + const at::Tensor & rstd, optional rstd_bdim, + bool training, double eps) { + c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); + const Tensor& running_mean = *running_mean_maybe_owned; + c10::MaybeOwned running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt); + const Tensor& running_var = *running_var_maybe_owned; + + if (!grad_out_bdim.has_value() && !input_bdim.has_value() && !running_mean_bdim.has_value() && !running_var_bdim.has_value()) { + // for either of these to have bdims, the input, running_mean, or running_var must have had a bdim + TORCH_INTERNAL_ASSERT(!mean_bdim); + TORCH_INTERNAL_ASSERT(!rstd_bdim); + const auto dummy_weight = at::ones(input.size(1), input.options()); + const auto result = Func( + grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false}); + return std::make_tuple(std::get<0>(result), nullopt); + } + + auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); + auto input_ = moveBatchDimToFront(input, input_bdim); + auto mean_ = moveBatchDimToFront(mean, mean_bdim); + auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim); + + // ensure all inputs have bdim. + const auto bdim_size = get_bdim_size4(grad_out, grad_out_bdim, input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim); + grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size); + mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size); + rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size); + + optional running_mean_; + optional running_var_; + if (running_mean.defined()) { + running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim); + running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size); + running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous(); + } + if (running_var.defined()) { + running_var_ = moveBatchDimToFront(running_var, running_var_bdim); + running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size); + running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous(); + } + + input_ = reshape_dim_into(0, /*channels dim*/1, input_); + TORCH_INTERNAL_ASSERT(mean_.dim() == 2); + TORCH_INTERNAL_ASSERT(rstd_.dim() == 2); + mean_ = reshape_dim_into(0, 0, mean_); + rstd_ = reshape_dim_into(0, 0, rstd_); + grad_out_ = grad_out_.transpose(0, 1).flatten(1, 2); // [B0, B, C, *] -> [B, (B0, C), *] + + const auto dummy_weight = at::ones(input_.size(1), input_.options()); + auto result = at::native_batch_norm_backward( + grad_out_.contiguous(), + input_.contiguous(), + dummy_weight, + running_mean_, // contiguous called if there is a tensor given + running_var_, // contiguous called if there is a tensor given + mean_.contiguous(), + rstd_.contiguous(), + training, eps, {true, false, false}); + auto result0 = std::get<0>(result); + result0 = reshape_dim_outof(1, bdim_size, result0); // [B, B0, C, *] + result0 = result0.transpose(0, 1); // [B0, B, C, *] + return std::make_tuple(result0, 0); +} + +template +std::tuple batch_norm_backward_plumbing( + const at::Tensor & grad_out, + const at::Tensor & input, + const c10::optional & weight_opt, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + bool training, + double eps, + std::array output_mask) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); + const Tensor& running_mean = *running_mean_maybe_owned; + c10::MaybeOwned running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt); + const Tensor& running_var = *running_var_maybe_owned; + // NB: not sure why these are optional...these are required from the forward + const Tensor& save_mean = *save_mean_opt; + const Tensor& save_rstd = *save_rstd_opt; + TORCH_INTERNAL_ASSERT(save_mean.defined()); + TORCH_INTERNAL_ASSERT(save_rstd.defined()); + + // plumbing + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + Tensor grad_out_value; + optional grad_out_bdim; + std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level); + Tensor input_value; + optional input_bdim; + std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level); + Tensor mean_value; + optional weight_value; + optional weight_bdim; + if (weight.defined()) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level); + } + optional running_mean_value; + optional running_mean_bdim; + if (running_mean.defined()) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean, cur_level); + } + optional running_var_value; + optional running_var_bdim; + if (running_var.defined()) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var, cur_level); + } + Tensor save_mean_value; + optional save_mean_bdim; + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean, cur_level); + Tensor save_rstd_value; + optional save_rstd_bdim; + std::tie(save_rstd_value, save_rstd_bdim) = unwrapTensorAtLevel(save_rstd, cur_level); + + // results + Tensor grad_bias; + Tensor grad_weight; + Tensor grad_input; + + TORCH_INTERNAL_ASSERT(grad_out_value.dim() > 1); // batch_norm can't operate on 1D tensors so the output will be at least 2D + if (output_mask[2]) { + grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim())); + } + if (output_mask[1] && weight_value.has_value()) { + // NB: output isn't saved... + auto mean = training ? save_mean : running_mean; + auto var = training ? save_rstd : (1 / at::sqrt(running_var + eps)); + const auto normalized_input = (input.transpose(0, 1) - padRight(mean, nullopt, input.dim())) * padRight(var, nullopt, input.dim()); + const auto expanded_grad_weight = normalized_input * grad_out.transpose(0, 1); + grad_weight = expanded_grad_weight.sum(range(1, grad_out.dim())); + } + if (output_mask[0]) { + const auto grad_normalized_input = weight.defined() ? + grad_out.transpose(0, 1) * padRight(weight, nullopt, grad_out.dim()) : grad_out.transpose(0, 1); // [B0, C, B, *] + Tensor grad_normalized_input_value; + optional grad_normalized_input_bdim; + std::tie(grad_normalized_input_value, grad_normalized_input_bdim) = + unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *] + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto results = batch_norm_backward_no_weight_bias_batch_rule( + grad_normalized_input_value, grad_normalized_input_bdim, + input_value, input_bdim, + running_mean_value, running_mean_bdim, + running_var_value, running_var_bdim, + save_mean_value, save_mean_bdim, + save_rstd_value, save_rstd_bdim, + training, eps); + grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level); + } + return std::make_tuple(grad_input, grad_weight, grad_bias); +} + +std::tuple native_group_norm_plumbing( + const Tensor & input, const c10::optional & weight_opt, + const c10::optional & bias_opt, int64_t N, int64_t C, + int64_t HxW, int64_t group, double eps) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({input, weight_opt, bias_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::native_group_norm(input, weight_opt, bias_opt, N, C, HxW, group, eps); + } + + Tensor input_value; + optional input_bdim; + std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level); + + Tensor result0; + Tensor mean; + Tensor rstd; + if (input_bdim) { + const auto input_ = reshape_dim_into(*input_bdim, 0, input_value); + const auto bdim_size = input_value.size(*input_bdim); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto result = at::native_group_norm(input_, nullopt, nullopt, N * bdim_size, C, HxW, group, eps); + result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level); + mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level); + rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level); + } else { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto result = at::native_group_norm(input_value, nullopt, nullopt, N, C, HxW, group, eps); + result0 = std::get<0>(result); + mean = std::get<1>(result); + rstd = std::get<2>(result); + } + + if (weight.defined()) { + const auto padded_weight = padRight(weight, nullopt, result0.dim() - 1); + result0 = result0 * padded_weight; + } + + if (bias.defined()) { + const auto padded_bias = padRight(bias, nullopt, result0.dim() - 1); + result0 = result0 + padded_bias; + } + + return std::make_tuple(result0, mean, rstd); +} + +std::tuple> group_norm_backward_no_weight_bias_batch_rule( + const at::Tensor & grad_out, optional grad_out_bdim, + const at::Tensor & input, optional input_bdim, + const at::Tensor & mean, optional mean_bdim, + const at::Tensor & rstd, optional rstd_bdim, + int64_t N, int64_t C, int64_t HxW, int64_t group) { + auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); + auto input_ = moveBatchDimToFront(input, input_bdim); + auto mean_ = moveBatchDimToFront(mean, mean_bdim); + auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim); + + const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim); + grad_out_ = ensure_has_bdim(grad_out, grad_out_bdim.has_value(), bdim_size); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size); + mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size); + rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size); + + grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *] + input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *] + mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G] + rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G] + + const auto result = native_group_norm_backward( + grad_out_.contiguous(), + input_.contiguous(), + mean_.contiguous(), + rstd_.contiguous(), + nullopt, N * bdim_size, C, HxW, group, {true, false, false}); + auto result0 = std::get<0>(result); + result0 = reshape_dim_outof(0, bdim_size, result0); + return std::make_tuple(result0, 0); +} + +std::tuple native_group_norm_backward_plumbing( + const Tensor & grad_out, const Tensor & input, const Tensor & mean, + const Tensor & rstd, const c10::optional & weight_opt, + int64_t N, int64_t C, int64_t HxW, int64_t group, std::array output_mask +) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + // plumbing + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::native_group_norm_backward(grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask); + } + + Tensor input_value; + optional input_bdim; + std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level); + Tensor weight_value; + optional weight_bdim; + if (weight.defined()){ + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level); + } + Tensor mean_value; + optional mean_bdim; + std::tie(mean_value, mean_bdim) = unwrapTensorAtLevel(mean, cur_level); + Tensor rstd_value; + optional rstd_bdim; + std::tie(rstd_value, rstd_bdim) = unwrapTensorAtLevel(rstd, cur_level); + + // results + Tensor grad_input; + Tensor grad_weight; + Tensor grad_bias; + + TORCH_INTERNAL_ASSERT(grad_out.dim() > 1); // group_norm can't operate on 1D tensors so the output will be at least 2D + if (output_mask[2]) { + grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim())); + } + + if (output_mask[1] && weight.defined()) { + const auto reshaped_input = reshape_dim_outof(1, group, input); + const auto normalized_input = (reshaped_input - padRight(mean, nullopt, reshaped_input.dim())) * padRight(rstd, nullopt, reshaped_input.dim()); + const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out; + grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim())); + } + + if (output_mask[0]) { + const auto grad_normalized_input = weight.defined() ? + grad_out * padRight(weight, nullopt, grad_out.dim() - 1) : grad_out; + Tensor grad_normalized_input_value; + optional grad_normalized_input_bdim; + std::tie(grad_normalized_input_value, grad_normalized_input_bdim) = + unwrapTensorAtLevel(grad_normalized_input, cur_level); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto res = group_norm_backward_no_weight_bias_batch_rule( + grad_normalized_input_value, grad_normalized_input_bdim, + input_value, input_bdim, + mean_value, mean_bdim, + rstd_value, rstd_bdim, + N, C, HxW, group + ); + grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level); + } + return std::make_tuple(grad_input, grad_weight, grad_bias); +} + +C10_ALWAYS_INLINE bool has_same_shape( + const Tensor& tensor, optional tensor_bdim, + IntArrayRef normalized_shape) { + if (!tensor.defined()) { + return true; + } + if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) { + return false; + } + const auto tensor_shape = tensor.sizes(); + for (const auto i : c10::irange(normalized_shape.size())) { + auto j = i; + // (0, 1, 2), 1 -> (0, 2, 3) + if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) { + j = j + 1; + } + if (normalized_shape[i] != tensor_shape[j]) { + return false; + } + } + return true; +} + +C10_ALWAYS_INLINE void check_same_shape( + const Tensor& tensor, optional tensor_bdim, + IntArrayRef normalized_shape, const std::string& name) { + TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape), + "Expected ", name, " to be of same shape as normalized_shape, but got ", + name, " of shape ", + tensor.sizes(), + " and normalized_shape = ", + normalized_shape); +} + +// Ugh, hard to deduplicate +C10_ALWAYS_INLINE void _check_layer_norm_inputs( + IntArrayRef normalized_shape, + const Tensor& weight, optional weight_bdim, + const Tensor& bias, optional bias_bdim) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + check_same_shape(weight, weight_bdim, normalized_shape, "weight"); + check_same_shape(bias, bias_bdim, normalized_shape, "weight"); +} + +std::tuple,Tensor,optional,Tensor,optional> +native_layer_norm_batch_rule( + const Tensor& input, optional input_bdim, + IntArrayRef normalized_shape, + const c10::optional& weight_opt, optional weight_bdim, + const c10::optional& bias_opt, optional bias_bdim, + double eps) { + auto input_ = moveBatchDimToFront(input, input_bdim); + if (!weight_bdim && !bias_bdim) { + const auto result = at::native_layer_norm(input_, normalized_shape, weight_opt, bias_opt, eps); + const auto mean = std::get<1>(result); + const auto rstd = std::get<2>(result); + const auto stats_bdim = compute_stat_bdim(input_bdim, mean); + return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim); + } + + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + _check_layer_norm_inputs(normalized_shape, weight, weight_bdim, bias, bias_bdim); + + const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim); + const auto result = at::native_layer_norm(input_, normalized_shape, nullopt, nullopt, eps); + auto result0 = std::get<0>(result); + const auto mean = std::get<1>(result); + const auto rstd = std::get<2>(result); + const auto stats_bdim = compute_stat_bdim(input_bdim, mean); + + if (weight.defined()) { + auto weight_ = moveBatchDimToFront(weight, weight_bdim); + weight_ = maybePadToLogicalRank(weight_, /*has_bdim*/weight_bdim, input_logical_rank); + result0 = result0 * weight_; + } + if (bias.defined()) { + const auto result_logical_rank = rankWithoutBatchDim( + result0, + input_bdim.has_value() || weight_bdim.has_value() ? optional(0) : optional(nullopt)); + auto bias_ = moveBatchDimToFront(bias, bias_bdim); + bias_ = maybePadToLogicalRank(bias_, /*has_bdim*/bias_bdim, result_logical_rank); + result0 = result0 + bias_; + } + return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim); +} + +std::tuple> native_layer_norm_backward_no_weight_bias_batch_rule( + const at::Tensor & grad_out, optional grad_out_bdim, + const at::Tensor & input, optional input_bdim, + at::IntArrayRef normalized_shape, + const at::Tensor & mean, optional mean_bdim, + const at::Tensor & rstd, optional rstd_bdim) { + + if (!grad_out_bdim.has_value() && !input_bdim.has_value() && + !mean_bdim.has_value() && !rstd_bdim.has_value()) { + const auto result = at::native_layer_norm_backward( + grad_out, input, normalized_shape, mean, rstd, nullopt, nullopt, {true, false, false}); + return std::make_tuple(std::get<0>(result), nullopt); + } + + auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); + auto input_ = moveBatchDimToFront(input, input_bdim); + auto mean_ = moveBatchDimToFront(mean, mean_bdim); + auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim); + + // ensure grad_out / input have bdim. + const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim); + grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size); + input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size); + mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size); + rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size); + + auto result = at::native_layer_norm_backward( + grad_out_.contiguous(), + input_.contiguous(), + normalized_shape, + mean_.contiguous(), + rstd_.contiguous(), + nullopt, nullopt, {true, false, false}); + + return std::make_tuple(std::get<0>(result), 0); +} + +std::tuple native_layer_norm_backward_plumbing( + const at::Tensor & grad_out, + const at::Tensor & input, + at::IntArrayRef normalized_shape, + const at::Tensor & mean, + const at::Tensor & rstd, + const c10::optional & weight_opt, + const c10::optional & bias_opt, + std::array output_mask) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + // plumbing + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt, bias_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::native_layer_norm_backward(grad_out, input, normalized_shape, mean, rstd, + weight_opt, bias_opt, output_mask); + } + Tensor grad_out_value; + optional grad_out_bdim; + std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level); + Tensor input_value; + optional input_bdim; + std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level); + Tensor mean_value; + optional mean_bdim; + std::tie(mean_value, mean_bdim) = unwrapTensorAtLevel(mean, cur_level); + Tensor rstd_value; + optional rstd_bdim; + std::tie(rstd_value, rstd_bdim) = unwrapTensorAtLevel(rstd, cur_level); + optional weight_value; + optional weight_bdim; + if (weight.defined()) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level); + } + optional bias_value; + optional bias_bdim; + if (bias.defined()) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias, cur_level); + } + + // results + Tensor grad_bias; + Tensor grad_weight; + Tensor grad_input; + + if (output_mask[2] && bias_value.has_value()) { + const auto num_front_dims_to_reduce = grad_out.dim() - normalized_shape.size(); + if (num_front_dims_to_reduce == 0) { + grad_bias = grad_out; + } else { + grad_bias = grad_out.sum(range(0, num_front_dims_to_reduce)); + } + } + if (output_mask[1] && weight_value.has_value()) { + // NB: output isn't saved... + const auto normalized_input = (input - mean) * rstd; + const auto expanded_grad_weight = normalized_input * grad_out; + const auto num_front_dims_to_reduce = + expanded_grad_weight.dim() - normalized_shape.size(); + if (num_front_dims_to_reduce == 0) { + grad_weight = expanded_grad_weight; + } else { + grad_weight = expanded_grad_weight.sum(range(0, num_front_dims_to_reduce)); + } + } + if (output_mask[0]) { + const auto grad_normalized_input = weight.defined() ? + grad_out * weight : grad_out; + Tensor grad_normalized_input_value; + optional grad_normalized_input_bdim; + std::tie(grad_normalized_input_value, grad_normalized_input_bdim) = + unwrapTensorAtLevel(grad_normalized_input, cur_level); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + const auto results = native_layer_norm_backward_no_weight_bias_batch_rule( + grad_normalized_input_value, grad_normalized_input_bdim, + input_value, input_bdim, + normalized_shape, + mean_value, mean_bdim, + rstd_value, rstd_bdim); + grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level); + } + return std::make_tuple(grad_input, grad_weight, grad_bias); +} + +template +struct NativeBatchNormBatchRuleHelper { + static std::tuple,Tensor,optional,Tensor,optional> apply( + const Tensor& input, optional input_bdim, + const c10::optional& weight_opt, optional weight_bdim, + const c10::optional& bias_opt, optional bias_bdim, + const c10::optional& running_mean_opt, optional running_mean_bdim, + const c10::optional& running_var_opt, optional running_var_bdim, + bool training, double momentum, double eps) { + return batch_norm_batch_rule( + input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim, + running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps); + } +}; + +template +struct CudnnBatchNormBatchRuleHelper { + static std::tuple,Tensor,optional,Tensor,optional,Tensor,optional> apply( + const Tensor& input, optional input_bdim, + const Tensor& weight_opt, optional weight_bdim, + const c10::optional& bias_opt, optional bias_bdim, + const c10::optional& running_mean_opt, optional running_mean_bdim, + const c10::optional& running_var_opt, optional running_var_bdim, + bool training, double momentum, double eps) { + auto reserve = at::empty({0}, input.options().dtype(kByte)); // in experiments, reserve was never set to anything other than empty by cuda + auto res = batch_norm_batch_rule( + input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim, + running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps); + return std::tuple_cat(res, std::make_tuple(reserve, nullopt)); + } +}; + +template +struct MiopenBatchNormBatchRuleHelper { + static std::tuple,Tensor,optional,Tensor,optional> apply( + const Tensor& input, optional input_bdim, + const Tensor& weight_opt, optional weight_bdim, + const c10::optional& bias_opt, optional bias_bdim, + const c10::optional& running_mean_opt, optional running_mean_bdim, + const c10::optional& running_var_opt, optional running_var_bdim, + bool training, double momentum, double eps) { + return batch_norm_batch_rule( + input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim, + running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps); + } +}; + +#define NATIVE_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\ + NativeBatchNormBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn)>::apply) + +#define CUDNN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\ + CudnnBatchNormBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn)>::apply) + +#define MIOPEN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\ + MiopenBatchNormBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn)>::apply) + +template +struct NativeBatchNormBackwardBatchRuleHelper { + static std::tuple apply( + const at::Tensor & grad_out, + const at::Tensor & input, + const c10::optional & weight_opt, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + bool training, + double eps, + std::array output_mask) { + + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({grad_out, input, weight_opt, running_mean_opt, + running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::native_batch_norm_backward(grad_out, input, weight_opt, + running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, + training, eps, output_mask); + } + + return batch_norm_backward_plumbing( + grad_out, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, training, eps, output_mask); + } +}; + +template +struct CudnnBatchNormBackwardBatchRuleHelper { + static std::tuple apply( + const at::Tensor & input, + const at::Tensor & grad_out, + const at::Tensor & weight, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + double eps, + const at::Tensor & reserve) { + + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt, + running_var_opt, save_mean_opt, save_rstd_opt, reserve}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::cudnn_batch_norm_backward(input, grad_out, weight, + running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve); + } + + return batch_norm_backward_plumbing( + grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true}); + } +}; + +template +struct MiopenBatchNormBackwardBatchRuleHelper { + static std::tuple apply( + const at::Tensor & input, + const at::Tensor & grad_out, + const at::Tensor & weight, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + double eps) { + + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt, + running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::miopen_batch_norm_backward(input, grad_out, weight, + running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); + } + + return batch_norm_backward_plumbing( + grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true}); + } +}; + +#define NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\ + NativeBatchNormBackwardBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn)>::apply) + +#define CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\ + CudnnBatchNormBackwardBatchRuleHelper<\ + decltype(&fn),\ + &fn>::apply) + +#define MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\ + MiopenBatchNormBackwardBatchRuleHelper<\ + decltype(&fn),\ + &fn>::apply) + +std::tuple cudnn_batch_norm_backward_wrapper( + const at::Tensor & grad_out, + const at::Tensor & input, + const at::Tensor& weight_opt, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + bool training, + double eps, + std::array output_mask) { + TORCH_INTERNAL_ASSERT(!training); + auto reserve = at::empty({0}, input.options().dtype(kByte)); + return at::cudnn_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps, reserve); + } + + std::tuple miopen_batch_norm_backward_wrapper( + const at::Tensor & grad_out, + const at::Tensor & input, + const at::Tensor& weight_opt, + const c10::optional & running_mean_opt, + const c10::optional & running_var_opt, + const c10::optional & save_mean_opt, + const c10::optional & save_rstd_opt, + bool training, + double eps, + std::array output_mask) { + TORCH_INTERNAL_ASSERT(!training); // this should be ensured by batch_norm_impl + return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); + } + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm)); + VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm)); + VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm)); + m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward)); + m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper)); + m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper)); + m.impl("native_group_norm", native_group_norm_plumbing); + m.impl("native_group_norm_backward", native_group_norm_backward_plumbing); + VMAP_SUPPORT(native_layer_norm, native_layer_norm_batch_rule); + m.impl("native_layer_norm_backward", native_layer_norm_backward_plumbing); +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesPooling.cpp b/functorch/functorch/csrc/BatchRulesPooling.cpp new file mode 100644 index 0000000000000..a04cba329697b --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesPooling.cpp @@ -0,0 +1,56 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace at { namespace functorch { + +std::tuple,Tensor,optional> +max_pool2d_with_indices_batch_rule( + const Tensor& self, optional self_bdim, + IntArrayRef kernel_size, IntArrayRef stride, + IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) { + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + TORCH_INTERNAL_ASSERT(logical_rank == 3 || logical_rank == 4); + // Tensor[B, C, H, W] -> just call max_pool2d + if (logical_rank == 3) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto result = at::max_pool2d_with_indices( + self_, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(std::move(std::get<0>(result)), 0, std::move(std::get<1>(result)), 0); + } + // Tensor[B, N, C, H, W] -> Tensor[B * N, C, H, W] + auto bdim_size = self.size(*self_bdim); + auto self_ = reshape_dim_into(*self_bdim, 0, self); + auto result = at::max_pool2d_with_indices( + self_, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple( + reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, + reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + EXISTING_BDIM(_adaptive_avg_pool2d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward); + EXISTING_BDIM(_adaptive_avg_pool3d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool3d_backward); + EXISTING_BDIM(avg_pool2d); + EXISTING_BDIM(avg_pool3d); + EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward); + EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward); + EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d); + EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d); + ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2); + ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2); + + VMAP_SUPPORT(max_pool2d_with_indices, max_pool2d_with_indices_batch_rule); + ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2); +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesRandomness.cpp b/functorch/functorch/csrc/BatchRulesRandomness.cpp new file mode 100644 index 0000000000000..a4a9ef9abcb7c --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesRandomness.cpp @@ -0,0 +1,481 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { +namespace functorch { + +template +Tensor random_batching_rule(IntArrayRef shape, ExtraArgs... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + VmapDimVector shapeVec(1, maybe_layer->batchSize()); + shapeVec.reserve(shape.size() + 1); + shapeVec.insert(shapeVec.end(), shape.begin(), shape.end()); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness); + if (randomness == RandomnessType::Different) { + return makeBatched(Func(shapeVec, std::forward(extra_args)...), 0, maybe_layer->layerId()); + } else { + return Func(shape, std::forward(extra_args)...); + } +} + +template +Tensor& random_inplace_batching_rule(Tensor& self, ExtraArgs... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + const auto cur_level = maybe_layer->layerId(); + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + self_value = moveBatchDimToFront(self_value, self_bdim); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness); + TORCH_CHECK( + !(randomness == RandomnessType::Different && !self_bdim), + "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ", + "If this is necessary for your usage, please file an issue with functorch."); + if (randomness == RandomnessType::Same && self_bdim) { + auto intermediate = empty(self.sizes(), self.options()); + Func(intermediate, std::forward(extra_args)...); + self.copy_(intermediate); // batching should make this just work out... + return self; + } else { + Func(self_value, std::forward(extra_args)...); + return self; + } +} + +Tensor& bernoulli_inplace_Tensor_batching_rule(Tensor& self, const Tensor& p_, c10::optional gen) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + auto cur_level = maybe_layer->layerId(); + RandomnessType randomness = maybe_layer->randomness(); + + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + + Tensor other_value; + optional other_bdim; + std::tie(other_value, other_bdim) = unwrapTensorAtLevel(p_, cur_level); + + check_randomness(randomness, other_bdim.has_value()); + + if (!self_bdim && other_bdim) { + vmapIncompatibleInplaceError("inplace bernoulli"); + } + + // compute max logical rank + auto self_logical_rank = rankWithoutBatchDim(self_value, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(other_value, other_bdim); + auto max_logical_rank = std::max(self_logical_rank, other_logical_rank); + + auto self_ = moveBatchDimToFront(self_value, self_bdim); + auto other_ = moveBatchDimToFront(other_value, other_bdim); + + // If the dimensions aren't aligned, we need to line them up. + // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3] + // Note that only tensors that have a batch dim need to be modified. + // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed + self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank); + TORCH_CHECK( + !(randomness == RandomnessType::Different && !self_bdim), + "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ", + "If this is necessary for your usage, please file an issue with functorch."); + if (randomness == RandomnessType::Same && self_bdim) { + auto intermediate = empty(self.sizes(), self.options()); + intermediate.bernoulli_(other_, gen); + self.copy_(intermediate); // batching should make this just work out... + return self; + } else { + self_.bernoulli_(other_, gen); + return self; + } +} + +template +Tensor randperm_batching_rule(int64_t n, ExtraArgs... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + auto const batch_size = maybe_layer->batchSize(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness); + if (randomness == RandomnessType::Different) { + std::vector stackedList(batch_size); + stackedList.reserve(batch_size); + for (int64_t idx = 0; idx < batch_size; ++idx) { + // since this is done in a loop, need to pass by reference for generator to update + stackedList[idx] = Func(n, extra_args...); + } + return makeBatched(at::stack(stackedList), 0, maybe_layer->layerId()); + } else { + return Func(n, std::forward(extra_args)...); + } +} + +template +Tensor unary_pointwise_random_batch_rule(const Tensor& tensor, ExtraArgs... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + const auto cur_level = maybe_layer->layerId(); + + Tensor tensor_value; + optional tensor_bdim; + std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(tensor, cur_level); + tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim); + + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, tensor_bdim.has_value()); + auto shape = tensor_value.sizes(); + VmapDimVector shapeVec(1, maybe_layer->batchSize()); + shapeVec.reserve(shape.size() + 1); + shapeVec.insert(shapeVec.end(), shape.begin(), shape.end()); + + if (randomness == RandomnessType::Different && !tensor_bdim) { + tensor_value = tensor_value.expand(shapeVec); + } + auto out = Func(tensor_value, std::forward(extra_args)...); + if (randomness == RandomnessType::Same && !tensor_bdim) { + return out; + } + return makeBatched(out, 0, cur_level); +} + +template +Tensor tensor_like_random_batch_rule(const Tensor& self, ExtraArgs... extra_args) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + const auto cur_level = maybe_layer->layerId(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness); + + Tensor tensor_value; + optional tensor_bdim; + std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(self, cur_level); + tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim); + + if (randomness == RandomnessType::Same && tensor_bdim) { + tensor_value = tensor_value[0]; + } else if (randomness == RandomnessType::Different && !tensor_bdim) { + auto shape = tensor_value.sizes(); + VmapDimVector shapeVec(1, maybe_layer->batchSize()); + shapeVec.reserve(shape.size() + 1); + shapeVec.insert(shapeVec.end(), shape.begin(), shape.end()); + tensor_value = tensor_value.expand(shapeVec); + } + + auto res = Func(tensor_value, std::forward(extra_args)...); + return (randomness == RandomnessType::Same) ? res : makeBatched(res, 0, cur_level); +} + +std::tuple native_dropout_batching_rule(const Tensor& tensor, double p, c10::optional train) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + const auto cur_level = maybe_layer->layerId(); + RandomnessType randomness = maybe_layer->randomness(); + + Tensor tensor_value; + optional tensor_bdim; + std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(tensor, cur_level); + tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim); + + if (!train.has_value() || train) { + check_randomness(randomness); // if we are in eval mode, we don't use about randomness + } + + if ((train.has_value() && !train) || randomness == RandomnessType::Different) { + auto res = at::native_dropout(tensor_value, p, train); + return std::make_tuple(makeBatched(std::get<0>(res), 0, cur_level), makeBatched(std::get<1>(res), 0, cur_level)); + } + + // repeated code from the CPU kernel since the CUDA one doesn't call bernoulli_ explicitly + double p1m = 1. - p; + // Check for probability of zero to avoid divide by zero and NaN results + double scale = p1m == 0 ? 0. : 1. / p1m; + Tensor mask = at::empty_like(tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + mask.bernoulli_(p1m); + const auto output = tensor.mul(mask).mul_(scale); + return std::make_tuple(output, mask); +} + +Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const c10::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + const auto cur_level = maybe_layer->layerId(); + + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + self_value = moveBatchDimToFront(self_value, self_bdim); + + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, self_bdim.has_value()); + + if (randomness == RandomnessType::Different && !self_bdim) { + auto shape = self_value.sizes(); + VmapDimVector shapeVec(1, maybe_layer->batchSize()); + shapeVec.reserve(shape.size() + 1); + shapeVec.insert(shapeVec.end(), shape.begin(), shape.end()); + self_value = self_value.expand(shapeVec); + } + if (self_value.dim() == 3 && (self_bdim || randomness == RandomnessType::Different)) { + self_value = reshape_dim_into(1, 0, self_value); + } + auto out = multinomial(self_value, num_samples, replacement, generator); + if (randomness == RandomnessType::Same && !self_bdim) { + return out; + } + if(self_value.dim() == 3 && self_bdim) { + out = out.reshape(self.sizes()); + } + return makeBatched(out, 0, cur_level); +} + +template +struct RandomBatchRuleHelper; + +template +struct RandomBatchRuleHelper> { + static Tensor apply(IntArrayRef shape, T... extra_args) { + return random_batching_rule(shape, std::forward(extra_args)...); + } +}; + +template +Tensor rand_int_wrapper(IntArrayRef shape, int64_t high, T... extra_args) { + return Func(high, shape, std::forward(extra_args)...); +} + +template +struct RandomInplaceBatchRuleHelper; + +template +struct RandomInplaceBatchRuleHelper> { + static Tensor& apply(Tensor& self, T... extra_args) { + return random_inplace_batching_rule(self, std::forward(extra_args)...); + } +}; + +template +struct RandIntBatchRuleHelper; + +template +struct RandIntBatchRuleHelper> { + static Tensor apply(int64_t high, IntArrayRef shape, T... extra_args) { + return random_batching_rule), + &rand_int_wrapper, + int64_t, T...>(shape, high, std::forward(extra_args)...); + } +}; + +template +Tensor rand_int_low_wrapper(IntArrayRef shape, T0 scalar0, T1 scalar1, T... extra_args) { + return Func(scalar0, scalar1, shape, std::forward(extra_args)...); +} + +template +struct RandTwoLeadingScalarsBatchRuleHelper; + +template +struct RandTwoLeadingScalarsBatchRuleHelper> { + static Tensor apply(T0 scalar0, T1 scalar1, IntArrayRef shape, T... extra_args) { + return random_batching_rule), + &rand_int_low_wrapper, + int64_t, int64_t, T...>(shape, scalar0, scalar1, std::forward(extra_args)...); + } +}; + +template +struct RandpermBatchRuleHelper; + +template +struct RandpermBatchRuleHelper> { + static Tensor apply(int64_t n, T... extra_args) { + return randperm_batching_rule(n, std::forward(extra_args)...); + } +}; + +template +struct UnaryPointwiseRandomBatchRule; + +template +struct UnaryPointwiseRandomBatchRule> { + static Tensor apply(const Tensor& tensor, T... extra_args) { + return unary_pointwise_random_batch_rule(tensor, std::forward(extra_args)...); + } +}; + +template +struct NormalPointwiseBatchRule; + +template +struct NormalPointwiseBatchRule> { + static Tensor apply(const Tensor& tensor, T... extra_args) { + return unary_pointwise_random_batch_rule(tensor, std::forward(extra_args)...); + } +}; + +template +Tensor normal_wrapper(const Tensor& tensor, double scalar, T... extra_args) { + return Func(scalar, tensor, extra_args...); +} + +template +struct UnaryPointwiseRandomLeadingFloatBatchRule; + +template +struct UnaryPointwiseRandomLeadingFloatBatchRule> { + static Tensor apply(double scalar, const Tensor& tensor, T... extra_args) { + return unary_pointwise_random_batch_rule), + &normal_wrapper, double, + T...>(tensor, scalar, std::forward(extra_args)...); + } +}; + +TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { + #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandomInplaceBatchRuleHelper::parameter_types>::apply)) + + RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float); + + #undef RANDOM_INPLACE_BATCH_RULE2 +} + +TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { + #define RANDOM_BATCH_RULE(op) \ + m.impl(#op, SINGLE_ARG(\ + RandomBatchRuleHelper::parameter_types>::apply)) + + #define RANDOM_BATCH_RULE2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandomBatchRuleHelper::parameter_types>::apply)) + + #define RANDOM_INPLACE_BATCH_RULE(op) \ + m.impl(#op, SINGLE_ARG(\ + RandomInplaceBatchRuleHelper::parameter_types>::apply)) + + #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandomInplaceBatchRuleHelper::parameter_types>::apply)) + + #define RANDINT_BATCH_RULE(op) \ + m.impl(#op, SINGLE_ARG(\ + RandIntBatchRuleHelper::parameter_types>::apply)) + + #define RANDINT_BATCH_RULE2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandIntBatchRuleHelper::parameter_types>::apply)) + + #define RAND_TWO_LEADING_SCALARS_BATCH_RULE(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandTwoLeadingScalarsBatchRuleHelper::parameter_types>::apply)) + #define RANDPERM_BATCH_RULE(op) \ + m.impl(#op, SINGLE_ARG(\ + RandpermBatchRuleHelper::parameter_types>::apply)) + + #define RANDPERM_BATCH_RULE2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + RandpermBatchRuleHelper::parameter_types>::apply)) + + #define UNARY_POINTWISE_RANDOM(op) \ + m.impl(#op, SINGLE_ARG(\ + UnaryPointwiseRandomBatchRule::parameter_types>::apply)) + + #define UNARY_POINTWISE_RANDOM2(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + UnaryPointwiseRandomBatchRule::parameter_types>::apply)) + + #define UNARY_POINTWISE_RANDOM_LEADING_FLOAT(op, overload) \ + m.impl(#op"."#overload, SINGLE_ARG(\ + UnaryPointwiseRandomLeadingFloatBatchRule::parameter_types>::apply)) + + RANDOM_BATCH_RULE(randn); + RANDOM_BATCH_RULE2(randn, generator); + RANDOM_BATCH_RULE2(randn, generator_with_names); + RANDOM_BATCH_RULE2(randn, names); + + RANDOM_BATCH_RULE(rand); + RANDOM_BATCH_RULE2(rand, generator); + RANDOM_BATCH_RULE2(rand, generator_with_names); + RANDOM_BATCH_RULE2(rand, names); + + RANDOM_INPLACE_BATCH_RULE(random_); + RANDOM_INPLACE_BATCH_RULE2(random_, from); + RANDOM_INPLACE_BATCH_RULE2(random_, to); + + RANDOM_INPLACE_BATCH_RULE(cauchy_); + RANDOM_INPLACE_BATCH_RULE(exponential_); + RANDOM_INPLACE_BATCH_RULE(geometric_); + RANDOM_INPLACE_BATCH_RULE(log_normal_); + RANDOM_INPLACE_BATCH_RULE(normal_); + RANDOM_INPLACE_BATCH_RULE(uniform_); + + RANDINT_BATCH_RULE(randint); + RANDINT_BATCH_RULE2(randint, generator); + RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low); + RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low_generator); + + m.impl("bernoulli_.Tensor", at::functorch::bernoulli_inplace_Tensor_batching_rule); + RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float); + UNARY_POINTWISE_RANDOM2(bernoulli, p); + + RANDPERM_BATCH_RULE(randperm); + RANDPERM_BATCH_RULE2(randperm, generator); + + RAND_TWO_LEADING_SCALARS_BATCH_RULE(normal, float_float); + UNARY_POINTWISE_RANDOM2(normal, Tensor_float); + UNARY_POINTWISE_RANDOM_LEADING_FLOAT(normal, float_Tensor); + + m.impl("native_dropout", native_dropout_batching_rule); // needs special casing because cuda version doesn't call bernoulli + + UNARY_POINTWISE_RANDOM(_standard_gamma); + UNARY_POINTWISE_RANDOM(_sample_dirichlet); + m.impl("multinomial", multinomial_batching_rule); + UNARY_POINTWISE_RANDOM(poisson); + UNARY_POINTWISE_RANDOM(bernoulli); + + #define TENSOR_LIKE_COMMON_ARG_TYPES optional, optional, optional, optional, optional + m.impl("randint_like", tensor_like_random_batch_rule); + m.impl("randint_like.low_dtype", tensor_like_random_batch_rule<\ + decltype(&ATEN_FN2(randint_like, low_dtype)), &ATEN_FN2(randint_like, low_dtype), int64_t, int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>); + m.impl("rand_like", tensor_like_random_batch_rule); + m.impl("randn_like", tensor_like_random_batch_rule); + + #undef RANDOM_BATCH_RULE + #undef RANDOM_BATCH_RULE2 + #undef RANDOM_INPLACE_BATCH_RULE + #undef RANDOM_INPLACE_BATCH_RULE2 + #undef RANDINT_BATCH_RULE + #undef RANDINT_BATCH_RULE2 + #undef RAND_TWO_LEADING_SCALARS_BATCH_RULE + #undef RANDPERM_BATCH_RULE + #undef RANDPERM_BATCH_RULE2 + #undef UNARY_POINTWISE_RANDOM + #undef UNARY_POINTWISE_RANDOM2 + #undef UNARY_POINTWISE_RANDOM_LEADING_FLOAT + #undef TENSOR_LIKE_COMMON_ARG_TYPES +} +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/BatchRulesReduceOps.cpp b/functorch/functorch/csrc/BatchRulesReduceOps.cpp new file mode 100644 index 0000000000000..17f7a263f4ee4 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesReduceOps.cpp @@ -0,0 +1,433 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace at { namespace functorch { + +bool is_allowed_dim_on_scalar_tensor(int64_t dim) { + return dim == 0 || dim == -1; +} + +Tensor sum_decomp( + const Tensor& self, optional dtype) { + return at::sum(self, range(0, self.dim()), false, dtype); +} + +Tensor sum_symint_decomp(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional opt_dtype) { + return at::sum(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype); +} + +Tensor mean_decomp( + const Tensor& self, optional dtype) { + return at::mean(self, range(0, self.dim()), false, dtype); +} + +Tensor nansum_decomp( + const Tensor& self, optional dtype) { + return at::nansum(self, range(0, self.dim()), false, dtype); +} + +Tensor prod_decomp( + const Tensor& self, optional dtype) { + return at::prod(self.flatten(), 0, false, dtype); +} + +Tensor max_decomp( + const Tensor& self) { + return std::get<0>(at::max(self.flatten(), 0, false)); +} + +Tensor min_decomp( + const Tensor& self) { + return std::get<0>(at::min(self.flatten(), 0, false)); +} + +Tensor norm_scalar_decomp( + const Tensor& self, const Scalar& p) { + return at::norm(self, p, range(0, self.dim()), false); +} + +Tensor nanmedian_decomp( + const Tensor& self) { + return std::get<0>(at::nanmedian(self.flatten(), 0, false)); +} + +Tensor median_decomp( + const Tensor& self) { + return std::get<0>(at::median(self.flatten(), 0, false)); +} + +enum ReductionCase { DimArray, Dim }; + +// dim_arg_pos allows us to specify the location of the dim/dim array argument. +// Defaults to 1 +template +void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + + auto orig_arguments = torch::jit::last(*stack, num_arguments); + if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + op.callBoxed(stack); + return; + } + + auto arguments = torch::jit::pop(*stack, num_arguments); + + TORCH_INTERNAL_ASSERT(arguments[0].isTensor()); + Tensor self; + optional self_bdim; + std::tie(self, self_bdim) = unwrapTensorAtLevel(arguments[0].toTensor(), cur_level); + + self = moveBatchDimToFront(self, self_bdim); + + auto logical_dim = rankWithoutBatchDim(self, self_bdim); + std::vector dims; + ReductionCase reduction_case; + if (arguments[dim_arg_pos].isIntList()) { + reduction_case = ReductionCase::DimArray; + dims = arguments[dim_arg_pos].toIntList().vec(); + if (dims.size() == 0) { + auto all_dims = range(0, std::max((int64_t)1, logical_dim)); + dims = std::vector(all_dims.begin(), all_dims.end()); + } + } else if (arguments[dim_arg_pos].isInt()) { + reduction_case = ReductionCase::Dim; + dims = {arguments[dim_arg_pos].toInt()}; + } else if (arguments[dim_arg_pos].isNone()) { + auto param_type = schema.arguments()[dim_arg_pos].type()->expect()->getElementType(); + if (param_type->kind() == IntType::Kind) { + reduction_case = ReductionCase::Dim; + if (self.dim() > 1) { + self = self.flatten(1); + } + dims = {0}; + } else if (param_type->kind() == ListType::Kind) { + reduction_case = ReductionCase::DimArray; + if (logical_dim == 0) { + dims = {0}; + } else { + auto all_dims = range(0, self.dim() - 1); + dims = std::vector(all_dims.begin(), all_dims.end()); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims"); + } + } else{ + TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims"); + } + + VmapDimVector new_dims; + new_dims.reserve(dims.size()); + for (auto dim: dims) { + new_dims.push_back(getPhysicalDim(self, self_bdim.has_value(), dim)); + } + bool is_scalar_case = logical_dim == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0]); + if (is_scalar_case) { + self = self.unsqueeze(-1); + new_dims = {1}; + } + arguments[0] = self; + if (reduction_case == ReductionCase::DimArray) { + arguments[dim_arg_pos] = std::vector(new_dims.begin(), new_dims.end()); + } else if (reduction_case == ReductionCase::Dim) { + arguments[dim_arg_pos] = new_dims[0]; + } + for (const auto arg_idx : c10::irange(0, num_arguments)) { + torch::jit::push(stack, arguments[arg_idx]); + } + op.callBoxed(stack); + + const auto returns = torch::jit::pop(*stack, num_returns); + for (const auto& ret : returns) { + if (ret.isTensor()) { + auto res = ret.toTensor(); + if (is_scalar_case) { + res = res.squeeze(-1); + } + torch::jit::push(stack, makeBatched(res, 0, cur_level)); + } else { + TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values"); + } + } +} + +#define REDUCTION_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction()); + +#define REDUCTION_BOXED_ARGS(op, dim_pos) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +// Skipping frobenius/nuclear/all/any since they don't have opinfo tests right now :P + +Tensor dist_decomp(const Tensor& self, const Tensor& other, const Scalar& p) { + return at::norm((self - other), p); +} + +static std::tuple expand_bdims( + const Tensor& a, bool a_has_bdim, + const Tensor& b, bool b_has_bdim) { + Tensor flagpole; + if (a_has_bdim) { + flagpole = a; + } else if (b_has_bdim) { + flagpole = b; + } else { + TORCH_INTERNAL_ASSERT(false); + } + return std::make_tuple( + a_has_bdim ? a : a.expand_as(flagpole), + b_has_bdim ? b : b.expand_as(flagpole)); +} + +std::tuple> _softmax_backward_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& output, optional output_bdim, + int64_t dim, + ScalarType input_dtype) { + // softmax_backward's decomposition is y * gy - y * (y * gy).sum(dim, keepdim=True) + // NB: the CUDA kernel handles strides so we can just expand + // all of the tensors and call it a day. The CPU kernel is not as good but + // idk if the perf on that really matters + auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); + auto output_ = moveBatchDimToFront(output, output_bdim); + + // Expand out that extra dimension for everyone + std::tie(grad_output_, output_) = expand_bdims( + grad_output_, grad_output_bdim.has_value(), + output_, output_bdim.has_value()); + + // Scalar tensor case. softmax turns into the identity when this happens. + // I don't know why the output is zeros, though, but that's what softmax tells me... + if (output_.dim() == 1 && (dim == 0 || dim == -1)) { + return std::make_tuple(at::zeros_like(grad_output_), 0); + } + + dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim); + + // Not sure why output_ needs to be marked as .contiguous(). Someting must + // have changed in PyTorch (and output of softmax is probably always contiguous) + return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0); +} + +std::tuple> _log_softmax_backward_batch_rule( + const Tensor& grad_output, optional grad_output_bdim, + const Tensor& output, optional output_bdim, + int64_t dim, + c10::ScalarType input_dtype) { + // NB: It turns out that expanding + calling log_softmax_backward is generally + // faster than the decomposition. + // Benchmark here: https://gist.github.com/zou3519/ae3b33b5730a84aae8a80a05c89e078a + // Decomposition is (grad_output - grad_output.sum(dim, keepdim=True) * result.exp()) + // We can squeeze out a last mile of performance by writing custom kernels. + auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); + auto output_ = moveBatchDimToFront(output, output_bdim); + + // Expand out that extra dimension for everyone + std::tie(grad_output_, output_) = expand_bdims( + grad_output_, grad_output_bdim.has_value(), + output_, output_bdim.has_value()); + + // Scalar tensor case. log_softmax returns zeros when this happens + if (output_.dim() == 1 && (dim == 0 || dim == -1)) { + return std::make_tuple(at::zeros_like(grad_output_), 0); + } + + dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim); + + return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0); +} + +// aminmax has divergent behavior for 0-d tenosrs. +// reference: https://github.com/pytorch/pytorch/issues/64008 +// TODO: Once the divergent behavior for 0-d scalar is fixed, we should use REDUCTION_BOXED_ARGS +std::tuple, Tensor, optional> aminmax_batching_rule( + const Tensor &self, optional self_bdim, optional dim, bool keep_dim) +{ + auto self_ = moveBatchDimToFront(self, self_bdim); + auto logical_rank = rankWithoutBatchDim(self_, self_bdim); + if (logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + + if (dim.has_value()) { + dim = maybe_wrap_dim(dim.value(), logical_rank) + 1; + } else { + // flatten the input except for batch-dim + auto bsize = self_.size(0); + self_ = self_.view({bsize, -1}); + dim = 1; + } + + Tensor min, max; + std::tie(min, max) = at::aminmax(self_, dim, keep_dim); + + if (logical_rank == 0 && self_.device().is_cuda()) { + // behaviour diverges between cpu and cuda + min = min.squeeze(-1); + max = max.squeeze(-1); + } + return std::make_tuple(min, 0, max, 0); +} + +std::tuple> searchsorted_batch_rule( + const Tensor& sorted_sequence, + optional sorted_sequence_bdim, + const Tensor& self, + optional self_bdim, + bool out_int32, + bool right, + c10::optional side, + const c10::optional& sorter, + c10::optional sorter_bdim) { + auto buckets_logical_rank = rankWithoutBatchDim(sorted_sequence, sorted_sequence_bdim); + + // Preprocess sorter and sorted_sequence. + // If they both exist, and only one has a bdim, then we need to make sure both do. + // After this step, we can forget about sorter for a bit. + auto buckets = moveBatchDimToFront(sorted_sequence, sorted_sequence_bdim); + optional buckets_bdim; + if (sorted_sequence_bdim.has_value()) { + buckets_bdim = 0; + } + + optional sorter_; + if (sorter.has_value() && sorter->defined()) { + auto sorter__ = moveBatchDimToFront(*sorter, sorter_bdim); + if (sorted_sequence_bdim.has_value() != sorter_bdim.has_value()) { + auto bdim_size = get_bdim_size2( + sorted_sequence, sorted_sequence_bdim, + sorter.value(), sorter_bdim); + sorter__ = ensure_has_bdim(sorter__, sorter_bdim.has_value(), bdim_size); + buckets = ensure_has_bdim(buckets, sorted_sequence_bdim.has_value(), bdim_size); + buckets_bdim = 0; + } + sorter_ = sorter__; + } + + // Two cases: buckets_logical_rank is 1, or it is greater than 1. + // searchsorted is basically two operators with different semantics jammed + // into one + if (buckets_logical_rank > 1) { + // B<...>D, B<...>V -> no change + if (buckets_bdim.has_value() && self_bdim.has_value()) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); + return std::make_tuple(result, 0); + } + // B<...>D, <...>V -> B<...>D, B<...>V + if (buckets_bdim.has_value() && !self_bdim.has_value()) { + auto self_ = moveBatchDimToFront(self, self_bdim); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), buckets.size(0)); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); + return std::make_tuple(result, 0); + } + // <...>D, B<...>V -> <...>D, <...>(BV) + if (!buckets_bdim.has_value() && self_bdim.has_value()) { + auto bdim_size = self.size(*self_bdim); + auto self_ = reshape_dim_into(*self_bdim, -1, self); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); + result = reshape_dim_outof(-1, bdim_size, result); + return std::make_tuple(result, result.dim() - 2); + } + TORCH_INTERNAL_ASSERT(false); + } + // buckets_logical_rank == 1 case. + // BD, B* -> BD, B flat(*) + if (buckets_bdim.has_value() && self_bdim.has_value()) { + auto self_ = moveBatchDimToFront(self, self_bdim); + self_ = self_.flatten(1); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); + result = result.view(self_.sizes()); + return std::make_tuple(result, 0); + } + // BD, * -> BD, flat(*) -> BD, B flat(*) + if (buckets_bdim.has_value() && !self_bdim.has_value()) { + auto bdim_size = buckets.size(*buckets_bdim); + auto self_ = ensure_has_bdim(self, false, bdim_size); + self_ = self_.flatten(1); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); + result = result.view(self_.sizes()); + return std::make_tuple(result, 0); + } + // D, B* -> no change + if (!buckets_bdim.has_value() && self_bdim.has_value()) { + auto result = at::searchsorted(buckets, self, out_int32, right, side, sorter_); + return std::make_tuple(result, self_bdim); + } + TORCH_INTERNAL_ASSERT(false); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT2(searchsorted, Tensor, searchsorted_batch_rule); + REDUCTION_BOXED(_fft_r2c); + REDUCTION_BOXED(_fft_c2r); + REDUCTION_BOXED(_fft_c2c); + REDUCTION_BOXED(amax); + // REDUCTION_BOXED(aminmax); Currently fails due to inconsistent scalar semantics. + REDUCTION_BOXED(amin); + REDUCTION_BOXED(any.dim); + REDUCTION_BOXED(argmax); + REDUCTION_BOXED(argmin); + REDUCTION_BOXED(count_nonzero.dim_IntList); + REDUCTION_BOXED(cummax); + REDUCTION_BOXED(cummin); + REDUCTION_BOXED(cumprod); + REDUCTION_BOXED(cumsum); + m.impl("dist", dist_decomp); + REDUCTION_BOXED_ARGS(kthvalue, 2); + REDUCTION_BOXED_ARGS(linalg_vector_norm, 2); + REDUCTION_BOXED(log_softmax.int); + REDUCTION_BOXED(logcumsumexp); + REDUCTION_BOXED(logsumexp); + m.impl("max", max_decomp); + REDUCTION_BOXED(max.dim); + m.impl("mean", mean_decomp); + REDUCTION_BOXED(mean.dim); + m.impl("median", median_decomp); + REDUCTION_BOXED(median.dim); + m.impl("min", min_decomp); + REDUCTION_BOXED(min.dim); + REDUCTION_BOXED(mode); + m.impl("nanmedian", nanmedian_decomp); + REDUCTION_BOXED(nanmedian.dim); + // TODO: re-enable these + // m.impl("nansum", nansum_decomp); + // REDUCTION_BOXED(nansum.dim_IntList); + m.impl("norm.Scalar", norm_scalar_decomp); + REDUCTION_BOXED_ARGS(norm.ScalarOpt_dim, 2); + m.impl("prod", prod_decomp); + REDUCTION_BOXED(prod.dim_int); + REDUCTION_BOXED(std.correction); + REDUCTION_BOXED(_softmax); + REDUCTION_BOXED(sort); + REDUCTION_BOXED_ARGS(sort.stable, 2); + REDUCTION_BOXED(argsort); + REDUCTION_BOXED(std_mean.correction); + m.impl("sum", sum_decomp); + REDUCTION_BOXED(sum.dim_IntList); + REDUCTION_BOXED_ARGS(topk, 2); + REDUCTION_BOXED(var.correction); + REDUCTION_BOXED(var_mean.correction); + REDUCTION_BOXED(_log_softmax); + REDUCTION_BOXED_ARGS(rot90, 2); + VMAP_SUPPORT(aminmax, aminmax_batching_rule); + m.impl("sum.SymInt", sum_symint_decomp); + VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule); + VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule); +} +}} diff --git a/functorch/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/functorch/csrc/BatchRulesScatterOps.cpp new file mode 100644 index 0000000000000..da01d464908e9 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesScatterOps.cpp @@ -0,0 +1,1074 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { namespace functorch { + +static bool any_has_value(ArrayRef> bdims) { + for (const auto& bdim : bdims) { + if (bdim.has_value()) { + return true; + } + } + return false; +} + +static int64_t get_num_leading_nones(ArrayRef> indices) { + int64_t result = 0; + for (const auto& idx : indices) { + if (!idx.has_value() || !idx->defined()) { + result++; + } else { + return result; + } + } + return result; +} + +static int64_t get_max_index_logical_dim( + ArrayRef> indices, + ArrayRef> indices_bdims) { + int64_t max_logical_dim = -1; + TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size()); + TORCH_INTERNAL_ASSERT(indices.size() > 0); + for (const auto i : c10::irange(0, indices.size())) { + const auto& maybe_tensor = indices[i]; + if (!maybe_tensor.has_value() || !maybe_tensor->defined()) { + continue; + } + auto logical_dim = rankWithoutBatchDim(maybe_tensor.value(), indices_bdims[i]); + max_logical_dim = std::max(logical_dim, max_logical_dim); + } + return max_logical_dim; +} + +std::vector> batchIndices( + ArrayRef> indices, + ArrayRef> indices_bdims, + int64_t batch_size, + optional self_bdim, + optional values_bdim = nullopt) { + // There are 3 main cases: + // 1. self is batched, indices/values are not batched + // In this case, we just need to augment indices with a None at the front to + // basically broadcast the indexing across the batch dimension of self. + // + // 2. self is not batched, some indices are batched. + // In this case, we don't need to do anything - indices will automatically + // broadcast to work with the unbatched self. + // + // 3. self is batched, some indices are batched. + // In this case, we simply need to add an arange that indexes along the first + // dimension (i.e. the batch dimension). We also need to make sure this + // broadcasts with the rest of the indices. + // + // In all three cases, depending on if advanced indices are adjacent we will + // have to permute the output. + // See NOTE: [advanced indexing (index.Tensor) batch rule] for more details + // + // There is one more case worth mentioning - boolean tensor indices. If we + // have "batched" boolean tensor indices, that is unrepresentable, as each + // batch would result in a tensor with different values. + std::vector> indices_; + + int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims); + bool indices_batched = any_has_value(indices_bdims); + + for (size_t i = 0; i < indices.size(); i++) { + auto index = indices[i]; + if (index.has_value() && index->numel() != 0) { + const auto idx_bdim = indices_bdims[i]; + indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank)); + if (index.value().dtype() == kBool && indices_bdims[i].has_value()) { + throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask."); + } + } else { + indices_.push_back(index); + } + } + + auto maxIndexDim = maxLogicalRank; + if (indices_batched || values_bdim.has_value()) { + maxIndexDim += 1; + } + + if (!indices_batched && self_bdim.has_value()) { + indices_.insert(indices_.begin(), nullopt); + } else if (indices_batched && !self_bdim.has_value()) { + // do nothing + } else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) { + auto arange_index = at::arange(0, batch_size); + while (arange_index.dim() < maxIndexDim) { + arange_index = arange_index.unsqueeze(-1); + } + // TODO: this is O(N) + indices_.insert(indices_.begin(), arange_index); + } + return indices_; +} + +// Define an "advanced index" to be a selection object that is +// a non-trivial Tensor (i.e. it does not represent :). +static bool is_advanced_index(const optional& idx) { + if (!idx.has_value()) { + return false; + } + if (!idx->defined()) { + return false; + } + return true; +} + +// See NOTE: [advanced indices adjacent] for definition +static bool are_advanced_indices_adjacent(ArrayRef> indices) { + int64_t num_advanced_indices_regions = 0; + bool in_advanced_indices_region = false; + for (const auto& idx : indices) { + if (!in_advanced_indices_region && is_advanced_index(idx)) { + num_advanced_indices_regions++; + in_advanced_indices_region = true; + continue; + } + if (in_advanced_indices_region && !is_advanced_index(idx)) { + in_advanced_indices_region = false; + continue; + } + } + return num_advanced_indices_regions <= 1; +} + +// Given a Tensor[B, , , ...] +// Swaps the regions to produce Tensor[B, , , ...] +// +// Concretely speaking, given +// - tensor: Tensor[B, 2, 3, 4, 5, 6, 7, 8] +// - first_region_size: 2 +// - second_region_size: 3 +// Produces: +// - result: Tensor[B, 4, 5, 6, 2, 3, 7, 8] +// ------- ---- +// region2 region1 +static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int64_t second_region_size) { + VmapDimVector permutation(tensor.dim(), 0); + std::iota(permutation.begin(), permutation.end(), 0); + std::rotate( + permutation.begin() + 1, + permutation.begin() + 1 + first_region_size, + permutation.begin() + 1 + first_region_size + second_region_size); + return tensor.permute(permutation); +} + +std::tuple> index_batch_rule( + const Tensor& self, + optional self_bdim, + ArrayRef> indices, + ArrayRef> indices_bdims) { + + // NOTE: [advanced indexing (index.Tensor) batch rule] + // + // This is a three step procedure: + // 1. batch `indices`. Depends on self_bdim and indices_bdim. + // 2. call at::index + // 3. (maybe) reorder the dimensions in the result. + // Why is step 3 necessary? Let's take a detour first. + // + // NOTE: [advanced indices adjacent] + // Definition: In a list of optional indices, + // we say that "advanced indices are adjacent" if ALL advanced indices are + // not separated by a None (slice). + // + // So, for example, + // [:, :, (0, 1), (0, 1), :] -> True + // [:, (0, 1), :, (0, 1), :] -> False, the advanced indices are separated by a slice + // + // See https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + // for more details. + // + // NOTE: [Why is step 3 necessary?] + // + // In the original self[*indices] expression, + // depending on whether or not the "advanced indices inside `indices` are + // adjacent", something different happens. + // + // For example: + // - self: Tensor[4, 5, 6, 7] + // - indices: [:, (0, 1), (0, 1), :] (advanced indices are adjacent) + // - self[*indices]: Tensor[4, 2, 7] + // If advanced indices are adjacent, you get the output you would expect. + // (0, 1), (0, 1) says "please index these two dimensions at (0, 0) and (1, 1) + // to produce two elements". + // + // If advanced indices are not adjacent, it is ambiguous to where the new + // dimension of size 2 should go. The numpy spec says it should go at the very + // front of the Tensor. + // + // - self: Tensor[4, 5, 6, 7] + // - indices: [:, (0, 1), :, (0, 1)] (advanced indices not adjacent) + // - self[*indices]: Tensor[2, 4, 6] + // + // Now, this leads to some weird interactions with vmap. + // The indices might originally have adjacent advanced indices, but after + // batching them with "batchIndices", they may no longer be adjacent! + // - indices: [:, (0, 1), (0, 1)] + // - batched_indices (for example): [(0, 1), :, (0, 1), (0, 1)] + // This leads to the dimension of size 2 appearing somewhere else. + // + // There are a couple of different cases that we walk through in the code below. + // + // Background reading for why we care about if the advanced indices are adjacent: + // https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + auto self_ = moveBatchDimToFront(self, self_bdim); + TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size()); + bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices); + + // Step 1 + const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim); + auto num_leading_nones = get_num_leading_nones(indices); + auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims); + + // Step 2 + auto res = at::index(self_, List>(batched_indices)); + + // Step 3: There are three cases (these match the cases outlined in batchIndices) + bool self_batched = self_bdim.has_value(); + bool indices_batched = any_has_value(indices_bdims); + + TORCH_INTERNAL_ASSERT(self_batched || indices_batched, "Requires at least one batched to get here"); + + // Case 1 + if (self_batched && !indices_batched) { + if (advanced_indices_are_adjacent) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [:, Tensor[2, 2], Tensor[2, 2], :] + // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :] + // res: Tensor[B, 5, 2, 2, 8] + return std::make_tuple(res, 0); + } else { + // self: Tensor[B, 5, 6, 7] + // indices: [Tensor[2, 2], :, Tensor[2, 2]] + // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]] + // res: Tensor[2, 2, B, 6] + return std::make_tuple(res, max_index_dim); + } + } + + // Case 2 + if (!self_batched && indices_batched) { + if (advanced_indices_are_adjacent) { + // self: Tensor[5, 6, 7, 8] + // indices: [:, :, Tensor[B, 2, 2], Tensor[2, 2]] + // batched_indices: indices (no change) + // res: Tensor[5, 6, B, 2, 2] + return std::make_tuple(res, num_leading_nones); + } else { + // self: Tensor[5, 6, 7, 8, 9] + // indices: [:, :, Tensor[B, 2, 2], :, Tensor[2, 2]] + // batched_indices: indices (no change) + // res: Tensor[B, 2, 2, 5, 6, 8] + return std::make_tuple(res, 0); + } + } + + // Case 3: self_batched and indices_batched + TORCH_INTERNAL_ASSERT(self_batched && indices_batched); + if (!advanced_indices_are_adjacent) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]] + // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]] + // res: Tensor[B, 2, 2, 5, 7] + return std::make_tuple(res, 0); + } + // In other words, in batched_indices, advanced indices are adjacent + if (num_leading_nones == 0) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :] + // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :] + // res: Tensor[B, 2, 2, 7, 8] + return std::make_tuple(res, 0); + } + // This is the tricky case. In indices, advanced indices are adjacent. + // In batched_indices, advanced indices are no longer adjacent + // + // self: Tensor[B, 5, 6, 7, 8, 9] + // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :] + // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :] + // res: Tensor[B, 2, 3, 5, 6, 9] + // expected: Tensor[B, 5, 6, 2, 3, 9] + // + // The resolution is to move dims around until we get the right shape. + // The result is set up as [B, , , ...] + // we just have to move the to before the to produce + // [B, , , ...] + return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0); +} + +// plumbing done since we don't support List> in codegen +Tensor index_plumbing(const Tensor & self, const List> & indices +) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::index(self, indices); + } + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + std::vector> indices_value; + std::vector> indices_bdims; + for (const auto&& indRef : indices) { + optional ind = indRef; + optional index; + optional index_bdim; + if (ind.has_value()) { + std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level); + } + indices_value.push_back(index); + indices_bdims.push_back(index_bdim); + } + auto results = index_batch_rule(self_value, self_bdim, indices_value, indices_bdims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} + +namespace { + // Code is mostly duplicated from + // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3 + // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312 + VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list) + { + int64_t dims_before = 0, dims_after = 0, dims_indexed = 0; + IntArrayRef replacement_shape; + for (const auto dim : c10::irange(indices_list.size())) { + if (!indices_list[dim].defined()) { + if (dims_indexed == 0) { + dims_before++; + } else { + dims_after++; + } + } else { + dims_indexed++; + replacement_shape = indices_list[dim].sizes(); + } + } + + // Replace indexed dimensions in src with stride 0 and the size of the result tensor. + // The offset in these dimensions is computed by the kernel using the index tensor's + // values and the stride of src. The new shape is not meaningful. It's used to make + // the shape compatible with the result tensor. + auto shape = VmapDimVector(src.sizes()); + int64_t end = dims_before + dims_indexed; + shape.erase(shape.begin() + dims_before, shape.begin() + end); + shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end()); + return shape; + } + + // Code is mostly duplicated from + // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3 + // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L379-L405 + VmapDimVector get_indexed_shape(Tensor self, const torch::List> &orig) + { + at::native::checkIndexTensorTypes(orig); + // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors + auto indices = at::native::expandTensors(self, orig); + // next broadcast all index tensors together + try { + indices = at::expand_outplace(indices); + } catch (std::exception &e) { + TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together" + " with shapes "); + } + // add missing null Tensors so that it matches self.dim() + while (indices.size() < static_cast(self.dim())) { + indices.emplace_back(); + } + // if the non-null indices are not all adjacent, transpose self and indices + // together so that they're adjacent at the front + if (!at::native::hasContiguousSubspace(indices)) { + std::tie(self, indices) = at::native::transposeToFront(self, indices); + } + return compute_indexed_shape(self, indices); + } + + std::tuple>, Tensor> + index_put_batch_rule_helper(const Tensor &self, + optional self_bdim, + ArrayRef> indices, + ArrayRef> indices_bdims, + const Tensor &values, + optional values_bdim, + optional opt_batch_size = {}) { + + Tensor self_ = moveBatchDimToFront(self, self_bdim); + Tensor values_ = moveBatchDimToFront(values, values_bdim); + // for inplace variants `index_put_` and `_index_put_impl_` we find the batch_size + // here while for `index_put` does it outside of this function. + const auto batch_size = opt_batch_size ? opt_batch_size.value() : self_.size(0); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + values_ = ensure_has_bdim(values_, values_bdim.has_value(), batch_size); + TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size()); + + // we've already made sure that self has bdim at 0. + const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim); + + auto indexed_shape = get_indexed_shape(self_, List>(indices_)); + + // handle broadcasting support for values + // Eg. Given `indexed_shape.size()` is 5 and + // shape of `values` is (N, 2, 3), then following block + // will reshape `values` to (N, 1, 1, 2, 3). + if ( (int64_t) indexed_shape.size() > values_.dim()) { + auto values_sizes = values_.sizes(); + + // number of unit dims (for broadcasting value to indexed_shape) + auto n_unit_dims = indexed_shape.size() - values_sizes.size(); + VmapDimVector new_values_shape(values_sizes.size() + n_unit_dims); + + // add the batch-dim + new_values_shape[0] = batch_size; + + // insert the unit dims for broadcasting. + for (const auto idx : c10::irange(n_unit_dims)) { + // since batch-dim is already be filled. + new_values_shape[idx + 1] = 1; + } + for (const auto idx: c10::irange(1, values_sizes.size())) { + // since batch and unit dims are already be filled. + new_values_shape[idx + n_unit_dims] = values_sizes[idx]; + } + values_ = values_.view(new_values_shape); + } + + return std::make_tuple(self_, indices_, values_); + } + + auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self, + const List> &indices, + const Tensor &values, int64_t cur_level) + { + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + std::vector> indices_value; + std::vector> indices_bdims; + for (const auto &&indRef : indices) + { + optional ind = indRef; + optional index; + optional index_bdim; + if (ind.has_value()) + { + std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level); + } + indices_value.push_back(index); + indices_bdims.push_back(index_bdim); + } + Tensor values_value; + optional values_bdim; + std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level); + return std::make_tuple(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim); + } + +} // namespace + +void index_put__batch_rule( + const Tensor& self, + optional self_bdim, + ArrayRef> indices, + ArrayRef> indices_bdims, + const Tensor& values, + optional values_bdim, + bool accumulate) { + if (!self_bdim.has_value()) { + vmapIncompatibleInplaceError("index_put_"); + } + Tensor self_, values_; + std::vector> indices_; + std::tie(self_, indices_, values_) = index_put_batch_rule_helper( + self, self_bdim, indices, indices_bdims, values, values_bdim); + at::index_put_(self_, List>(indices_), values_, accumulate); +} + +// plumbing done since we don't support List> in codegen +Tensor& index_put__plumbing(Tensor & self, const List> & indices +, const Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return self.index_put_(indices, values, accumulate); + } + Tensor self_value, values_value; + optional self_bdim, values_bdim; + std::vector> indices_value; + std::vector> indices_bdims; + std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) = + unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level); + index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate); + return self; +} + +void _index_put_impl__batch_rule( + const Tensor& self, + optional self_bdim, + ArrayRef> indices, + ArrayRef> indices_bdims, + const Tensor& values, + optional values_bdim, + bool accumulate, + bool unsafe) { + if (!self_bdim.has_value()) { + vmapIncompatibleInplaceError("_index_put_impl_"); + } + Tensor self_, values_; + std::vector> indices_; + std::tie(self_, indices_, values_) = index_put_batch_rule_helper( + self, self_bdim, indices, indices_bdims, values, values_bdim); + at::_index_put_impl_(self_, List>(indices_), values_, accumulate, unsafe); +} + +// plumbing done since we don't support List> in codegen +Tensor &_index_put_impl__plumbing(Tensor &self, const List> &indices, + const Tensor &values, bool accumulate, bool unsafe) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_index_put_impl_(self, indices, values, accumulate, unsafe); + } + Tensor self_value, values_value; + optional self_bdim, values_bdim; + std::vector> indices_value; + std::vector> indices_bdims; + std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) = + unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level); + _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe); + return self; +} + +static Tensor maybe_permute_values( + const Tensor& values, + ArrayRef> orig_indices, + ArrayRef> orig_indices_bdims) { + bool indices_batched = any_has_value(orig_indices_bdims); + bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices); + auto num_leading_nones = get_num_leading_nones(orig_indices); + auto max_index_dim = get_max_index_logical_dim(orig_indices, orig_indices_bdims); + TORCH_INTERNAL_ASSERT(values.dim() >= num_leading_nones + max_index_dim); + + // NB: values has its B dimension at the front + if (!indices_batched) { + if (advanced_indices_are_adjacent) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [:, Tensor[2, 2], Tensor[2, 2], :] + // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :] + // required values: Tensor[B, 5, 2, 2, 8] + return values; + } + // self: Tensor[B, 5, 6, 7] + // indices: [Tensor[2, 2], :, Tensor[2, 2]] + // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]] + // required values: Tensor[2, 2, B, 6] + return values.movedim(0, max_index_dim); + } + if (!advanced_indices_are_adjacent) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]] + // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]] + // required values: Tensor[B, 2, 2, 5, 7] + return values; + } + // In other words, in batched_indices, advanced indices are adjacent + if (num_leading_nones == 0) { + // self: Tensor[B, 5, 6, 7, 8] + // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :] + // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :] + // required values: Tensor[B, 2, 2, 7, 8] + return values; + } + // This is the tricky case. In indices, advanced indices are adjacent. + // In batched_indices, advanced indices are no longer adjacent + // + // self: Tensor[B, 5, 6, 7, 8, 9] + // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :] + // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :] + // required values: Tensor[B, 2, 3, 5, 6, 9] + // actual values: Tensor[B, 5, 6, 2, 3, 9] + // + // The resolution is to move dims around until we get the right shape. + // The values is set up as [B, , , ...] + // we just have to move the to before the to produce + // [B, , , ...] + return swap_regions(values, num_leading_nones, max_index_dim); +} + +std::tuple> index_put_batch_rule( + const Tensor& self, + optional self_bdim, + ArrayRef> indices, + ArrayRef> indices_bdims, + const Tensor& values, + optional values_bdim, + bool accumulate) { + TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size()); + + // find the batch_size + int64_t batch_size = 0; + if (self_bdim || values_bdim) { + batch_size = get_bdim_size2(self, self_bdim, values, values_bdim); + } else { + // one or more of the indices is batched. + for (size_t i = 0; i < indices.size(); i++) { + if (indices_bdims[i] && indices[i].has_value()) { + batch_size = indices[i].value().size(*indices_bdims[i]); + break; + } + } + } + + Tensor self_, values_; + std::vector> indices_; + std::tie(self_, indices_, values_) = index_put_batch_rule_helper( + self, self_bdim, indices, indices_bdims, values, values_bdim, batch_size); + + // Why do we need to permute values? + // See NOTE [Advanced indexing (index.Tensor) batch rule] for details, + // but the gist is that index_put effectively does the following: + // - result = self_.clone() + // - result[indices_] = values + // - return result + // Now, the problem is, result[indices_] might return a Tensor whose shape is + // the shape of values, but permuted. This is because the shape of result[indices_] + // depends on if the original indices "have adjacent advanced indices" + // and the batched `indices_` might change the "have adjacent advanced indices" property + values_ = maybe_permute_values(values_, indices, indices_bdims); + + auto result = at::index_put(self_, List>(indices_), values_, accumulate); + return std::make_tuple(result, 0); +} + +// plumbing done since we don't support List> in codegen +Tensor index_put_plumbing(const Tensor & self, const List> & indices, + const Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return self.index_put(indices, values, accumulate); + } + Tensor self_value, values_value; + optional self_bdim, values_bdim; + std::vector> indices_value; + std::vector> indices_bdims; + std::tie(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim) = + unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level); + auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} + +namespace { + +template +std::tuple> scatter_batch_rule( + Func f, + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Scalar& value, Args... args) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto index_logical_rank = rankWithoutBatchDim(index, index_bdim); + auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim); + + auto self_ = moveBatchDimToFront(self, self_bdim); + auto index_ = moveBatchDimToFront(index, index_bdim); + + if (self_logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + if (index_logical_rank == 0) { + index_ = index_.unsqueeze(-1); + } + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size); + auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim); + + auto result = f(self_, physical_dim, index_, value, args...); + // result should have same shape as self + if (self_logical_rank == 0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); +} + +template +inline std::tuple> scatter_batch_rule( + Func f, + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Tensor& src, optional src_bdim, Args... args) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto index_logical_rank = rankWithoutBatchDim(index, index_bdim); + auto src_logical_rank = rankWithoutBatchDim(src, src_bdim); + auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim); + + auto self_ = moveBatchDimToFront(self, self_bdim); + auto index_ = moveBatchDimToFront(index, index_bdim); + auto src_ = moveBatchDimToFront(src, src_bdim); + + if (self_logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + if (index_logical_rank == 0) { + index_ = index_.unsqueeze(-1); + } + if (src_logical_rank == 0) { + src_ = src_.unsqueeze(-1); + } + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size); + src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size); + auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim); + + auto result = f(self_, physical_dim, index_, src_, args...); + // result should have same shape as self + if (self_logical_rank == 0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); +} + +} // namespace + +std::tuple> scatter_value_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Scalar& value) { + return scatter_batch_rule(ATEN_FN2(scatter, value), + self, self_bdim, dim, index, index_bdim, value); +} + +std::tuple> scatter_src_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Tensor& src, optional src_bdim) { + return scatter_batch_rule(ATEN_FN2(scatter, src), + self, self_bdim, dim, index, index_bdim, src, src_bdim); +} + +std::tuple> scatter_add_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Tensor& src, optional src_bdim) { + return scatter_batch_rule(ATEN_FN(scatter_add), + self, self_bdim, dim, index, index_bdim, src, src_bdim); +} + +std::tuple> scatter_reduce_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Tensor& src, optional src_bdim, + const c10::string_view reduce) { + return scatter_batch_rule(ATEN_FN2(scatter, reduce), + self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce); +} + +std::tuple> scatter_value_reduce_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Scalar& src, + const c10::string_view reduce) { + return scatter_batch_rule(ATEN_FN2(scatter, value_reduce), + self, self_bdim, dim, index, index_bdim, src, reduce); +} + +std::tuple> gather_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + bool sparse_grad) { + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto index_logical_rank = rankWithoutBatchDim(index, index_bdim); + auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim); + + auto self_ = moveBatchDimToFront(self, self_bdim); + auto index_ = moveBatchDimToFront(index, index_bdim); + + if (self_logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + if (index_logical_rank == 0) { + index_ = index_.unsqueeze(-1); + } + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size); + auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim); + + auto result = at::gather(self_, physical_dim, index_, sparse_grad); + // result should have same rank as index + if (index_logical_rank == 0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); +} + +std::tuple> gather_backward_batch_rule( + const Tensor& grad, optional grad_bdim, + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + bool sparse_grad) { + auto batch_size = get_bdim_size3(grad, grad_bdim, self, self_bdim, index, index_bdim); + auto grad_ = moveBatchDimToFront(grad, grad_bdim); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto index_ = moveBatchDimToFront(index, index_bdim); + + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto index_logical_rank = rankWithoutBatchDim(index, index_bdim); + auto grad_logical_rank = rankWithoutBatchDim(grad, grad_bdim); + + if (grad_logical_rank == 0) { + grad_ = grad_.unsqueeze(-1); + } + if (self_logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + if (index_logical_rank == 0) { + index_ = index_.unsqueeze(-1); + } + grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size); + + auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim); + auto result = at::gather_backward(grad_, self_, physical_dim, index_, sparse_grad); + // result should has same rank as self + if (self_logical_rank == 0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); +} + +namespace { +Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) { + if (index.dim() == 0) { + return index.expand(self_size); + } + + // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1] + // to reshape index_ + auto idx_size = index.size(0); // get non-batch size of index tensor + Tensor index_; + { + VmapDimVector new_index_shape(self_size.size(), 1); + new_index_shape[dim] = idx_size; + index_ = index.view(new_index_shape); + } + // Now apply expand to index_ + { + VmapDimVector new_index_shape = {self_size.begin(), self_size.end()}; + new_index_shape[dim] = idx_size; + index_ = index_.expand(new_index_shape); + } + return index_; +} +} + +Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index) +{ + Tensor index_ = index; + if (self.dim() > index.dim()) { + index_ = get_expanded_index(index, self.sizes(), dim); + } + + auto result = at::gather(self, dim, index_); + + // output of gather has same dimension as `index` while + // output of index_select has same dimension as self + // Eg. t = torch.tensor(1) + // idx = torch.tensor([0]) + // torch.index_select(t, 0, idx) # 0-D + // torch.gather(t, 0, idx) # 1-D + if (self.dim() == 0 && result.dim() != 0) { + result = result.squeeze(-1); + } + + return result; +} + +Tensor index_copy_decomp( + const Tensor &self, int64_t dim, + const Tensor &index, const Tensor &source) +{ + Tensor index_ = index; + if (self.dim() > index.dim()) { + index_ = get_expanded_index(index, self.sizes(), dim); + } + + return at::scatter(self, dim, index_, source); ; +} + +Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, + int64_t dim, c10::optional start, + c10::optional end, int64_t step) +{ + auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong)); + idx = get_expanded_index(idx, self.sizes(), dim); + return at::scatter(self, dim, idx, src); +} + +Tensor select_scatter_decomp( + const Tensor &self, const Tensor &source, + int64_t dim, int64_t index) +{ + // supports negative index + index = maybe_wrap_dim(index, self.size(dim)); + auto index_ = at::scalar_tensor(index, self.options().dtype(kLong)); + + return at::scatter(self, dim, index_.expand_as(self), source.unsqueeze(dim).expand_as(self)); +} + +std::tuple> diagonal_scatter_batch_rule( + const Tensor &self, c10::optional self_bdim, + const Tensor &src, c10::optional src_bdim, + int64_t offset, int64_t dim1, int64_t dim2) +{ + auto self_ = moveBatchDimToFront(self, self_bdim); + auto src_ = moveBatchDimToFront(src, src_bdim); + + auto batch_size = get_bdim_size2(self, self_bdim, src, src_bdim); + + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size); + + auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + dim1 = maybe_wrap_dim(dim1, self_logical_rank) + 1; + dim2 = maybe_wrap_dim(dim2, self_logical_rank) + 1; + + return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0); +} + +std::tuple> index_add_batch_rule( + const Tensor& self, optional self_bdim, + int64_t dim, + const Tensor& index, optional index_bdim, + const Tensor& other, optional other_bdim, + const Scalar& alpha) { + if (!index_bdim) { + // Handle scalar tensors... self, other can be scalar tensors + const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); + const auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); + auto self_ = moveBatchDimToFront(self, self_bdim); + if (self_logical_rank == 0) { + self_ = self_.unsqueeze(-1); + } + auto other_ = moveBatchDimToFront(other, other_bdim); + if (other_logical_rank == 0) { + other_ = other_.unsqueeze(-1); + } + dim = maybe_wrap_dim(dim, self_logical_rank); + + const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim); + self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); + other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size); + + auto result = self_.index_add(dim + 1, index, other_, alpha); + if (self_logical_rank == 0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); + } + + // Index is batched. For-loop and stack is the best thing I can come up with + // right now. We really want generalized index_add kernel in PyTorch + auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim); + std::vector results; + results.reserve(batch_size); + for (const auto i : c10::irange(0, batch_size)) { + const auto& self_slice = self_bdim.has_value() ? + self.select(*self_bdim, i) : self; + const auto& other_slice = other_bdim.has_value() ? + other.select(*other_bdim, i) : other; + const auto& index_slice = index_bdim.has_value() ? + index.select(*index_bdim, i) : index; + results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha)); + } + return std::make_tuple(at::stack(results), 0); +} + +static std::tuple binary_pointwise_align( + const Tensor & self, + optional self_bdim, + const Tensor & mask, + optional mask_bdim) { + // compute max logical rank + auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim); + auto other_logical_rank = rankWithoutBatchDim(mask, mask_bdim); + auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank); + + auto tensor_ = moveBatchDimToFront(self, self_bdim); + auto other_ = moveBatchDimToFront(mask, mask_bdim); + + // If the dimensions aren't aligned, we need to line them up. + // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3] + // Note that only tensors that have a batch dim need to be modified. + // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed + tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank); + other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank); + + return std::make_tuple(tensor_, other_); +} + +std::tuple> masked_fill_scalar_batch_rule( + const Tensor & self, + optional self_bdim, + const Tensor & mask, + optional mask_bdim, + const Scalar& source) { + auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim); + auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source); + return std::make_tuple(result, 0); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + m.impl("index.Tensor", index_plumbing); + m.impl("index_put_", index_put__plumbing); + m.impl("index_put", index_put_plumbing); + m.impl("_index_put_impl_", _index_put_impl__plumbing); + m.impl("slice_scatter", slice_scatter_decomp); + m.impl("select_scatter", select_scatter_decomp); + m.impl("index_copy", index_copy_decomp); + m.impl("index_select", index_select_decomp); + VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule); + VMAP_SUPPORT(index_add, index_add_batch_rule); + VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule); + VMAP_SUPPORT(gather, gather_batch_rule); + VMAP_SUPPORT(gather_backward, gather_backward_batch_rule); + VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule); + VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule); + VMAP_SUPPORT(scatter_add, scatter_add_batch_rule); + VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule); + VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule); +} + +}} diff --git a/functorch/functorch/csrc/BatchRulesUnaryOps.cpp b/functorch/functorch/csrc/BatchRulesUnaryOps.cpp new file mode 100644 index 0000000000000..660cb1f3c7139 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesUnaryOps.cpp @@ -0,0 +1,215 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +namespace at { namespace functorch { + +std::tuple> +clone_batch_rule( + const Tensor& self, + optional self_bdim, + optional memory_format) { + // Memory format support is a little tricky because vmap is allowed to move + // around batch dimensions and some memory formats are rank-dependent. + // Another weird case is: + // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we + // allow the user to clone a Tensor with 3 logical dimensions and 1 batch + // dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims + // and N>1 batch dims? + TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve + || memory_format == MemoryFormat::Contiguous, + "NYI: Tensor.clone(memory_format) inside vmap is only supported with ", + "memory_format torch.preserve_format or torch.contiguous_format (got ", + *memory_format, ")"); + + if (memory_format == MemoryFormat::Contiguous) { + // There is an ambiguity here when the batch dims are not at the front of + // the tensor. + // >>> x = torch.randn(3, B0, 5) + // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x) + // >>> y[0].is_contiguous() + // ??? + // Should we make the whole tensor contiguous, or should we + // make the non-batch dims contiguous? We've chosen the latter because + // philosophically vmap hides the batch dims and operates on a per-sample level. + auto self_ = moveBatchDimToFront(self, self_bdim); + auto result = at::clone(self_, memory_format); + return std::make_tuple(result, 0); + } + + TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve); + auto result = at::clone(self, memory_format); + return std::make_tuple(result, self_bdim); +} + +std::tuple> +contiguous_batch_rule( + const Tensor& self, + optional self_bdim, + MemoryFormat memory_format) { + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ", + "than torch.contiguous_format"); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto result = self_.contiguous(memory_format); + return std::make_tuple(result, 0); +} + +std::tuple> +view_as_complex_batch_rule(const Tensor& self, optional self_bdim) { + // guard against the user passing in a batch of scalar tensors with batch + // size equal to 2. + TORCH_CHECK(self.sizes().size() > 1, "Input tensor must have one or more dimensions"); + + auto self_ = moveBatchDimToFront(self, self_bdim); + auto result = at::view_as_complex(self_); + return std::make_tuple(result, 0); +} + +std::tuple> +to_other_batch_rule(const Tensor& self, optional self_bdim, + const Tensor& other, optional other_bdim, + bool non_blocking, + bool copy, c10::optional memory_format) { + return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + +#define UNARY_POINTWISE_ALL2(op, overload) \ + POINTWISE_BOXED2(op ## _, overload); \ + VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload))); +#define UNARY_POINTWISE_ALL(op) \ + POINTWISE_BOXED(op ## _); \ + VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op))); + + UNARY_POINTWISE(imag); + UNARY_POINTWISE(real); + UNARY_POINTWISE(view_as_real); + VMAP_SUPPORT(view_as_complex, view_as_complex_batch_rule); + VMAP_SUPPORT(clone, clone_batch_rule); + VMAP_SUPPORT(contiguous, contiguous_batch_rule); + VMAP_SUPPORT2(to, device, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, device))); + VMAP_SUPPORT2(to, dtype, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype))); + VMAP_SUPPORT2(to, dtype_layout, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype_layout))); + VMAP_SUPPORT2(to, other, to_other_batch_rule); + + UNARY_POINTWISE(_to_copy); + UNARY_POINTWISE(alias); + UNARY_POINTWISE_ALL(abs); + UNARY_POINTWISE_ALL(acos); + UNARY_POINTWISE_ALL(acosh); + UNARY_POINTWISE(angle); + UNARY_POINTWISE_ALL(asin); + UNARY_POINTWISE_ALL(asinh); + UNARY_POINTWISE_ALL(atan); + UNARY_POINTWISE_ALL(atanh); + UNARY_POINTWISE_ALL(bitwise_not); + UNARY_POINTWISE_ALL(ceil); + UNARY_POINTWISE_ALL(cos); + UNARY_POINTWISE_ALL(cosh); + UNARY_POINTWISE(_conj); + UNARY_POINTWISE_ALL(deg2rad); + UNARY_POINTWISE(detach); + UNARY_POINTWISE_ALL(digamma); + UNARY_POINTWISE_ALL(erf); + UNARY_POINTWISE_ALL(exp); + UNARY_POINTWISE_ALL(expm1); + UNARY_POINTWISE_ALL(floor); + UNARY_POINTWISE_ALL(frac); + UNARY_POINTWISE(isfinite); + UNARY_POINTWISE(isnan); + UNARY_POINTWISE(isinf); + UNARY_POINTWISE(isposinf); + UNARY_POINTWISE(isneginf); + UNARY_POINTWISE(isreal); + UNARY_POINTWISE_ALL(lgamma); + UNARY_POINTWISE_ALL(log); + UNARY_POINTWISE_ALL(log10); + UNARY_POINTWISE_ALL(log1p); + UNARY_POINTWISE_ALL(log2); + UNARY_POINTWISE_ALL(logical_not); + UNARY_POINTWISE_ALL(logit); + UNARY_POINTWISE_ALL(mish); + UNARY_POINTWISE_ALL(mvlgamma); + UNARY_POINTWISE_ALL(nan_to_num); + UNARY_POINTWISE_ALL(neg); + UNARY_POINTWISE_ALL(positive); + UNARY_POINTWISE_ALL(rad2deg); + UNARY_POINTWISE_ALL(reciprocal); + UNARY_POINTWISE_ALL(round); + UNARY_POINTWISE_ALL2(round, decimals); + UNARY_POINTWISE_ALL(rsqrt); + UNARY_POINTWISE_ALL(sgn); + UNARY_POINTWISE_ALL(sign); + UNARY_POINTWISE(signbit); + UNARY_POINTWISE_ALL(sin); + UNARY_POINTWISE_ALL(sinc); + UNARY_POINTWISE_ALL(sinh); + UNARY_POINTWISE_ALL(sqrt); + UNARY_POINTWISE_ALL(tan); + UNARY_POINTWISE_ALL(threshold); + UNARY_POINTWISE_ALL(trunc); + + // special-related + UNARY_POINTWISE_ALL(i0); + UNARY_POINTWISE_ALL(erfc); + UNARY_POINTWISE_ALL(erfinv); + UNARY_POINTWISE_ALL(exp2); + + // torch.special.* functions + UNARY_POINTWISE(special_entr); + UNARY_POINTWISE(special_erf); + UNARY_POINTWISE(special_erfc); + UNARY_POINTWISE(special_erfcx); + UNARY_POINTWISE(special_erfinv); + UNARY_POINTWISE(special_expit); + UNARY_POINTWISE(special_expm1); + UNARY_POINTWISE(special_digamma); + UNARY_POINTWISE(special_psi); + UNARY_POINTWISE(special_exp2); + UNARY_POINTWISE(special_gammaln); + UNARY_POINTWISE(special_i0); + UNARY_POINTWISE(special_i0e); + UNARY_POINTWISE(special_i1); + UNARY_POINTWISE(special_i1e); + UNARY_POINTWISE(special_log1p); + UNARY_POINTWISE(special_ndtr); + UNARY_POINTWISE(special_ndtri); + UNARY_POINTWISE(special_round); + UNARY_POINTWISE(special_sinc); + + // Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) + UNARY_POINTWISE_ALL(elu); + UNARY_POINTWISE(hardshrink); + UNARY_POINTWISE_ALL(hardsigmoid); + UNARY_POINTWISE_ALL(hardtanh); + UNARY_POINTWISE_ALL(hardswish); + UNARY_POINTWISE_ALL(leaky_relu); + UNARY_POINTWISE(log_sigmoid); + UNARY_POINTWISE_ALL(relu); + UNARY_POINTWISE_ALL(relu6); + UNARY_POINTWISE_ALL(selu); + UNARY_POINTWISE_ALL(celu); + UNARY_POINTWISE(gelu); + UNARY_POINTWISE_ALL(sigmoid); + UNARY_POINTWISE_ALL(silu); + UNARY_POINTWISE(softplus); + UNARY_POINTWISE(softshrink); + UNARY_POINTWISE_ALL(tanh); + + POINTWISE_BOXED(fill_.Scalar); + POINTWISE_BOXED(zero_); + +#undef UNARY_POINTWISE +#undef UNARY_POINTWISE_ALL + +} + +#undef INVOKE +}} diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp new file mode 100644 index 0000000000000..e4160ea4c98f1 --- /dev/null +++ b/functorch/functorch/csrc/BatchRulesViews.cpp @@ -0,0 +1,556 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace functorch { + +// Note [Adding vmap support for an operator] +// Hey there! So you have an operator and you want to get it to work with vmap. +// For example, let's say you just invented the `sum.int` operator and want to make +// it so that the following works. +// >>> tensor = torch.randn(B, 3) +// >>> vmap(torch.sum, (0, None))(tensor, 0)` works +// There are three main ways to do so. +// +// Note [Writing batch rule for out-of-place operators] +// If your operator is out-of-place, you can write a batch rule for it. +// The batch rule defines how to perform the operator on inputs where each +// Tensor input may have an additional dimension that is being vmapped over. +// We refer to this dimension as the *batch dimension* or bdim for short. +// +// For example, let's consider writing a batch rule for +// `Tensor sum(const Tensor& self, int64_t dim)`. The signature of the +// batch rule has an additional optional argument after each +// Tensor argument and return. So, in this case, the batch rule has signature +// tuple> sum_batch_rule( +// const Tensor& self, optional self_bdim, int64_t dim); +// +// The vmap call above invokes the batch rule with `self = tensor`, +// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors** +// involved in this case; there exists some plumbing that automatically unwraps +// BatchedTensors before calling the batch rule. +// +// To write the logic of the batch rule: think about the semantics of the +// `sum` operation if `self` had an additional dimension (indicated by self_bdim): +// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual +// - If `self_bdim` is not-null, then we need to modify `dim`. `dim` is equal +// to whatever the user passed in (0 in this case), but we should actually +// perform the reduction over dimension 1 and do `result = self.sum(1)` +// because dim 0 is being vmapped over. +// Finally, we return the result as well as a new bdim +// - If `self_bdim` is null, then there's no batch dim in the result. +// - If `self_bdim` is not-null, then we return where the bdim is. +// Since we invoked `result = self.sum(1)`, the bdim is still at dim 0. +// +// Now that we have written `sum_batch_rule`, we have to register it inside a +// TORCH_LIBRARY_IMPL block: +// TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { +// ... +// VMAP_SUPPORT2(sum, int, sum_batch_rule); +// ... +// } +// +// Note [Reusing batch rules to add vmap support for a complicated operator] +// Can't figure out how to write a batch rule for a big operation? If the +// operation can be expressed as a composition of other operations that do have +// batch rules, then that is another way to add vmap support. For example, +// consider the following schema +// func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) +// and assume we already have batching rules for basic arithmetic operators. +// +// To add vmap support, define a decomposition using the same signature: +// Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1, +// const Tensor& tensor2, const Scalar& value) { +// auto product = torch.mul(tensor1, tensor2); +// return torch.add(self, product, value); +// } +// And register it inside a TORCH_LIBRARY_IMPL block: +// TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { +// ... +// m.impl("addcmul", addcmul_decomp); +// ... +// } +// +// Note [Writing batch rule for in-place operators] +// TODO: This is kinda complicated. Saving this for a future date. + +std::tuple> unsqueeze_batch_rule( + const Tensor& self, + optional self_bdim, + int64_t dim) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, rank + 1) + 1; + return std::make_tuple(self_.unsqueeze(dim), 0); +} + +// NB: repeat is not actually a view, but it is in this file +std::tuple> repeat_batch_rule( + const Tensor& self, + optional self_bdim, + IntArrayRef sizes) { + + VmapDimVector sizes_with_bdim = { sizes.begin(), sizes.end() }; + sizes_with_bdim.insert(sizes_with_bdim.begin(), 1); + auto self_ = moveBatchDimToFront(self, self_bdim); + while (self_.dim() < (int64_t)sizes_with_bdim.size()) { + self_ = self_.unsqueeze(1); + } + return std::make_tuple(self_.repeat(sizes_with_bdim), 0); +} + + +std::tuple> diag_batch_rule( + const Tensor& input, + optional input_bdim, + int64_t diagonal) { + if (!input_bdim) { + return std::make_tuple(at::diag(input, diagonal), nullopt); + } + auto input_ = moveBatchDimToFront(input, input_bdim); + auto rank = rankWithoutBatchDim(input, input_bdim); + + if (rank == 1) { + return std::make_tuple(at::diag_embed(input_, diagonal), 0); + } else if (rank == 2) { + return std::make_tuple(at::diagonal(input_.movedim(0, -1), diagonal).clone(), rank - 2); + } else { + throw std::runtime_error("Passed in an invalid shape to at::diag"); + } +} + +std::tuple> _unsafe_view_batch_rule( + const Tensor& self, + optional self_bdim, + IntArrayRef size) { + auto self_ = moveBatchDimToFront(self, self_bdim); + VmapDimVector view_size(size); + view_size.insert(view_size.begin(), self_.size(0)); + + // See if the view is valid. If it's not, then we copy. + // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used + // anymore. + const at::DimVector inferred_size = at::infer_size_dv(view_size, self_.numel()); + const auto stride = at::detail::computeStride(self_.sizes(), + self_.strides(), + inferred_size); + if (!stride.has_value()) { + self_ = self_.contiguous(); + } + return std::make_tuple(at::_unsafe_view(self_, view_size), 0); +} + +std::tuple> flip_batch_rule(const Tensor& self, optional self_bdim, IntArrayRef dims) { + auto self_ = moveBatchDimToFront(self, self_bdim); + VmapDimVector new_dims; + for (auto i: dims) { + new_dims.push_back(getPhysicalDim(self_, true, i)); + } + return std::make_tuple(at::flip(self_, new_dims), 0); +} + +const Tensor& resize__plumbing( + const Tensor& self, + IntArrayRef size, + c10::optional optional_memory_format) { + TORCH_CHECK( + !optional_memory_format.has_value() || + optional_memory_format == c10::MemoryFormat::Contiguous, + "resize_: batching rule only supports None or Contiguous MemoryFormat"); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard2(kBatchedKey); + return self.resize_(size, optional_memory_format); + } + + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + + // TODO: The following algorithm only works for batch dim == 0. + // To get it to work for something else we need the ability to modify + // the BatchDims attribute of BatchedTensorImpl + TORCH_INTERNAL_ASSERT(self_bdim.value() == 0, "NYI: resize_ batch rule for batch dim != 0"); + + // Resize the wrapped tensor + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + self_value = moveBatchDimToFront(self_value, self_bdim); + VmapDimVector new_size(size); + new_size.insert(new_size.begin(), self_value.size(*self_bdim)); + self_value.resize_(new_size); + + // Update the sizes and strides of the wrapper + auto* batched = maybeGetBatchedImpl(self); + TORCH_INTERNAL_ASSERT(batched); + batched->refreshTensorMetadata(); + + return self; +} + +std::tuple> squeeze_batch_rule(const Tensor& self, optional bdim) { + TORCH_INTERNAL_ASSERT(bdim.has_value()); + // Special case for scalar arrays to replicate PyTorch behavior. + if (self.dim() == 1) { + return std::make_tuple(self.alias(), bdim); + } + + // Manually calculate the output shape by eliding all dimensions of + // size 1 keeping track of where the batch index started and where it + // ended up moving to. We also ensure we do not drop the batch index. + auto shape = self.sizes(); + DimVector squeezed_sizes; + bool before_batch_idx = true; + int64_t new_batch_idx = 0; + int64_t original_idx = 0; + + for (auto it : shape) { + // Keep only dimensions != 1 and the batch dimension (irrespective of size). + if (it != 1 || original_idx == bdim) { + squeezed_sizes.push_back(it); + if (original_idx == bdim) { + before_batch_idx = false; + } + // Only increment for the dimensions that will be kept in the output. + if (before_batch_idx) { + ++new_batch_idx; + } + } + ++original_idx; + } + + auto result = self.view(squeezed_sizes); + return std::make_tuple(result, c10::optional(new_batch_idx)); +} + +std::tuple> squeeze_dim_batch_rule(const Tensor& self, optional bdim, int64_t dim) { + TORCH_INTERNAL_ASSERT(bdim.has_value()); + // Special case for scalar arrays to replicate PyTorch behavior. + if (self.dim() == 1) { + TORCH_CHECK(dim == 0, "Dimension is out of range (expected to be in range of [-1, 0], but got ", dim); + return std::make_tuple(self.alias(), bdim); + } + + // Calculate the proper offset if dim is negative. + auto actual_dim = dim; + if (dim < 0) { + actual_dim = self.dim() + dim - 1; + } + if (actual_dim < bdim) { + // Since dimension to be squeezed is before the batch dimension pass as-is. + auto original_size = self.dim(); + auto result = self.squeeze(actual_dim); + auto updated_batch_idx = *bdim; + if (result.dim() != original_size) { + // A column before batch dimension has been dropped so adjust accordingly. + --updated_batch_idx; + } + return std::make_tuple(result, optional(updated_batch_idx)); + } else { + // Since dimension to be squeezed is after the batch dimension adjust by one to account + // for the original batch dimension. In this case batch dimension won't move. + return std::make_tuple(self.squeeze(actual_dim + 1), bdim); + } +} + +std::tuple, optional> chunk_batching_rule(const Tensor& self, optional self_bdim, int64_t chunks, int64_t dim) { + auto self_ = moveBatchDimToFront(self, self_bdim); + int64_t new_dim = getPhysicalDim(self, self_bdim.has_value(), dim); + return std::make_tuple(at::chunk(self_, chunks, new_dim), 0); +} + +std::tuple> select_batching_rule(const Tensor& self, optional bdim, int64_t dim, int64_t index) { + if (!bdim) { + return std::make_tuple(self.select(dim, index), nullopt); + } + + auto _self = moveBatchDimToFront(self, bdim); + auto dim_physical = getPhysicalDim(_self, true, dim); + auto result = _self.select(dim_physical, index); + return std::make_tuple(result, 0); +} + +std::tuple> _reshape_alias_batch_rule(const Tensor& self, optional bdim, const IntArrayRef shape, const IntArrayRef strides) { + (void) strides; + TORCH_INTERNAL_ASSERT(bdim.has_value()); + + auto self_ = moveBatchDimToFront(self, bdim); + c10::SmallBuffer new_shape(shape.size() + 1); + new_shape[0] = self_.size(0); + std::copy(shape.begin(), shape.end(), new_shape.begin() + 1); + return std::make_tuple(at::reshape(self_, new_shape), 0); +} + +std::tuple> roll_batch_rule(const Tensor& self, optional bdim, IntArrayRef shifts, IntArrayRef dims) { + TORCH_INTERNAL_ASSERT(bdim.has_value()); + + auto self_ = moveBatchDimToFront(self, bdim); + VmapDimVector new_dims; + if (!dims.empty()) { + for (auto i: dims) { + new_dims.push_back(getPhysicalDim(self, true, i)); + } + return std::make_tuple(at::roll(self_, shifts, new_dims), 0); + } + // We will do something like: t.reshape(a, -1).roll(1, dims=[1, ]).reshape(old_shape) + auto old_shape = self_.sizes(); + new_dims.push_back(1); + auto output = at::roll(self_.flatten(1), shifts, new_dims); + output = output.reshape(old_shape); + return std::make_tuple(output, 0); +} + +std::tuple> diagonal_batching_rule( + const Tensor &self, optional self_bdim, + int64_t offset, int64_t dim1, int64_t dim2) +{ + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto dim1_ = maybe_wrap_dim(dim1, logical_rank) + 1; + auto dim2_ = maybe_wrap_dim(dim2, logical_rank) + 1; + auto result = at::diagonal(self_, offset, dim1_, dim2_); + return std::make_tuple(std::move(result), 0); +} + +std::tuple> diagonal_backward_batch_rule( + const Tensor& grad_input, optional grad_input_bdim, + IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim); + auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim); + dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1; + dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1; + c10::SmallBuffer input_sizes_(input_sizes.size() + 1); + input_sizes_[0] = grad_input_.size(0); + std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1); + auto result = at::diagonal_backward(grad_input_, input_sizes_, offset, dim1, dim2); + return std::make_tuple(std::move(result), 0); +} + +std::tuple> slice_batch_rule( + const Tensor& self, + optional self_bdim, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { + auto self_ = moveBatchDimToFront(self, self_bdim); + dim = getPhysicalDim(self, self_bdim.has_value(), dim); + + auto result = self_.slice(dim, start, end, step); + return std::make_tuple(result, 0); +} + +static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { + return dim == 0 || dim == -1; +} + +std::tuple> +transpose_int_batch_rule( + const Tensor& self, + optional self_bdim, + int64_t dim0, + int64_t dim1) { + // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works + // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(lambda x: x.transpose(0, -1), x) + // then we replicate this behavior. + if (/*physical*/self.dim() == 1 && is_allowed_dim_on_scalar_tensor(dim0) && + is_allowed_dim_on_scalar_tensor(dim1)) { + return std::make_tuple(self, self_bdim); + } + auto self_ = moveBatchDimToFront(self, self_bdim); + dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0); + dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1); + auto result = self_.transpose(dim0, dim1); + return std::make_tuple(result, 0); +} + +std::tuple> permute_batching_rule( + const Tensor &self, optional self_bdim, IntArrayRef dims) +{ + if (!self_bdim.has_value()) { + return std::make_tuple(self.permute(dims), self_bdim); + } + + auto self_ = moveBatchDimToFront(self, self_bdim); + VmapDimVector dims_; + dims_.reserve(dims.size() + 1); + dims_.emplace_back(0); + for (auto dim : dims) { + dims_.emplace_back(getPhysicalDim(self_, self_bdim.has_value(), dim)); + } + + return std::make_tuple(self_.permute(dims_), 0); +} + +std::tuple> select_backward_batch_rule( + const Tensor& grad_input, optional grad_input_bdim, + IntArrayRef input_sizes, int64_t dim, int64_t index) { + auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim); + auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim); + dim = maybe_wrap_dim(dim, logical_rank + 1) + 1; + c10::SmallBuffer input_sizes_(input_sizes.size() + 1); + input_sizes_[0] = grad_input_.size(0); + std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1); + auto result = at::select_backward(grad_input_, input_sizes_, dim, index); + return std::make_tuple(std::move(result), 0); +} + +std::tuple> slice_backward_batch_rule( + const Tensor& grad_input, optional grad_input_bdim, + IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim); + auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim); + dim = maybe_wrap_dim(dim, logical_rank) + 1; + c10::SmallBuffer input_sizes_(input_sizes.size() + 1); + input_sizes_[0] = grad_input_.size(0); + std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1); + auto result = at::slice_backward(grad_input_, input_sizes_, dim, start, end, step); + return std::make_tuple(std::move(result), 0); +} + +std::tuple> view_batching_rule( + const Tensor &self, optional self_bdim, IntArrayRef size) +{ + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + auto self_ = moveBatchDimToFront(self, self_bdim); + VmapDimVector size_(size.size() + 1); + // copy batch size + size_[0] = self_.size(0); + std::copy(size.cbegin(), size.cend(), size_.begin() + 1); + return std::make_tuple(self_.view(size_), 0); +} + +Tensor view_symint_decomposition(const Tensor& self, + c10::SymIntArrayRef size) { + return self.view( c10::asIntArrayRefSlow(size)); +} + + +template +std::tuple> expand_batch_rule( + const Tensor &self, optional self_bdim, IntArrayRef size, bool implicit) +{ + auto self_dim = self.dim(); + TORCH_CHECK(static_cast(self_dim - 1) <= size.size(), + "expand: the number of sizes provided (", size.size(), ") ", + "must be greater or equal to the number of dimensions in the tensor (", static_cast(self_dim - 1), ")"); + + auto self_ = moveBatchDimToFront(self, self_bdim); + auto self_sizes = self_.sizes(); + auto batch_size = self_sizes[0]; + + c10::SmallBuffer size_(size.size() + 1); + size_[0] = batch_size; + std::copy(size.cbegin(), size.cend(), size_.begin() + 1); + + // Here, we know we are expanding a (logical) tensor to a larger number + // of dimensions. We have to be careful because we can't call expand directly + // due to the presence of batch dimensions. + // + // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]). + // The result should be a tensor of size [B0, 2, 3]. + // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3] + // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and + // then expand. + auto extra_dims = size.size() - (self_dim - 1); + VmapDimVector view_shape(size_.size(), /*init_value*/1); + view_shape[0] = batch_size; + std::copy(self_sizes.cbegin() + 1, self_sizes.cend(), + view_shape.begin() + 1 + extra_dims); + + return std::make_tuple(Func(self_.view(view_shape), size_, implicit), 0); +} + +std::tuple> unfold_batch_rule( + const Tensor &self, optional self_bdim, int64_t dim, int64_t size, int64_t step) +{ + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, logical_rank) + 1; + if (logical_rank==0) { + self_ = self_.unsqueeze(-1); + } + auto result = self_.unfold(dim, size, step); + if (logical_rank==0) { + result = result.squeeze(-1); + } + return std::make_tuple(result, 0); +} + +std::tuple> movedim_batch_rule(const Tensor& self, optional self_bdim, IntArrayRef source, IntArrayRef destination) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto source_ = getPhysicalDims(self_, self_bdim.has_value(), source); + auto destination_ = getPhysicalDims(self_, self_bdim.has_value(), destination); + return std::make_tuple(self_.movedim(source_, destination_), 0); +} + +std::tuple> diag_embed_batch_rule(const Tensor& self, optional self_bdim, int64_t offset, int64_t dim1, int64_t dim2) { + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + auto self_ = moveBatchDimToFront(self, self_bdim); + dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1; + dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1; + return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0); +} + +// We need to write a real batching rule to fully support symint. +// This requires symint variants of other operations, like `view`, +// which don't exist yet. +Tensor expand_symint_decomp_hack(const Tensor& self, SymIntArrayRef packed_size, bool implicit) { + auto size = asIntArrayRefSlow(packed_size); + return self.expand(size, implicit); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + VMAP_SUPPORT(diag, diag_batch_rule); + VMAP_SUPPORT(chunk, chunk_batching_rule); + m.impl("flatten.using_ints", static_cast(native::flatten)); + VMAP_SUPPORT(flip, flip_batch_rule); + RUN_JIT_DECOMPOSITION(trace) + VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril))); + VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu))); + VMAP_SUPPORT(repeat, repeat_batch_rule); + VMAP_SUPPORT(_unsafe_view, _unsafe_view_batch_rule); + VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule); + m.impl("resize_", resize__plumbing); + VMAP_SUPPORT2(select, int, select_batching_rule); + VMAP_SUPPORT(squeeze, squeeze_batch_rule); + VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule); + VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule); + VMAP_SUPPORT(roll, roll_batch_rule); + VMAP_SUPPORT(permute, permute_batching_rule); + VMAP_SUPPORT(diagonal, diagonal_batching_rule); + VMAP_SUPPORT(diagonal_backward, diagonal_backward_batch_rule); + VMAP_SUPPORT(select_backward, select_backward_batch_rule); + VMAP_SUPPORT(slice_backward, slice_backward_batch_rule); + VMAP_SUPPORT(view, view_batching_rule); + VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule)); + VMAP_SUPPORT(expand_copy, SINGLE_ARG(expand_batch_rule)); + VMAP_SUPPORT(unfold, unfold_batch_rule); + VMAP_SUPPORT2(movedim, intlist, movedim_batch_rule); + VMAP_SUPPORT2(slice, Tensor, slice_batch_rule); + VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule); + VMAP_SUPPORT(diag_embed, diag_embed_batch_rule); + m.impl("expand.SymInt", expand_symint_decomp_hack); + m.impl("view.SymInt", view_symint_decomposition); +} + +}} diff --git a/functorch/functorch/csrc/BatchedFallback.cpp b/functorch/functorch/csrc/BatchedFallback.cpp new file mode 100644 index 0000000000000..6b6c58b243ee1 --- /dev/null +++ b/functorch/functorch/csrc/BatchedFallback.cpp @@ -0,0 +1,401 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at { +namespace functorch { + +bool kVmapFallbackWarningEnabled = true; + +bool isVmapFallbackWarningEnabled() { + return kVmapFallbackWarningEnabled; +} + +void setVmapFallbackWarningEnabled(bool enabled) { + kVmapFallbackWarningEnabled = enabled; +} + +bool kVmapFallbackEnabled = true; + +bool isVmapFallbackEnabled() { + return kVmapFallbackEnabled; +} + +void setVmapFallbackEnabled(bool enabled) { + kVmapFallbackEnabled = enabled; +} + +// Given a linear index, return the actual index. +// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 0] +static at::SmallVector +computeIndex(int64_t linear_idx, IntArrayRef sizes) { + at::SmallVector result; + result.reserve(sizes.size()); + for (auto it = sizes.rbegin(); it != sizes.rend(); it++) { + auto remainder = linear_idx % *it; + result.push_back(remainder); + linear_idx -= remainder; + linear_idx /= *it; + } + std::reverse(std::begin(result), std::end(result)); + return result; +} + +static bool areAllReturnsTensors(const at::FunctionSchema& schema) { + return std::all_of( + schema.returns().begin(), + schema.returns().end(), + [] (const Argument& arg) { return arg.type() == TensorType::get(); }); +} + +static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) { + return std::any_of( + schema.arguments().begin(), + schema.arguments().end(), + [] (const Argument& arg) { + return arg.type()->isSubtypeOf(ListType::ofTensors()) || + arg.type()->isSubtypeOf(ListType::ofOptionalTensors()); + }); +} + +static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) { + TORCH_CHECK(isVmapFallbackEnabled(), + schema.operator_name(), " hit the vmap fallback which is currently disabled"); + if (!isVmapFallbackWarningEnabled()) { + return; + } + TORCH_WARN("There is a performance drop because we have not yet implemented ", + "the batching rule for ", schema.operator_name(), ". Please file ", + "us an issue on GitHub so that we can prioritize its implementation."); +} + +// The general flow of the algorithm is as follows. +// - First, we figure out which arguments are BatchedTensors and save them +// to a vector. We also store a vector of which index of the arguments list +// each BatchedTensor appears in. This will be useful for bookkeeping later. +// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors. +// This returns a vector of VmapPhysicalView that hold tensors that contain +// all of the collective batch dimensions at the front of the tensors. +// - Then, we attempt to call `op` once per slice of the inputs. To do this, +// we repeatedly we slice the input arguments (if they are BatchedTensors), +// put the sliced (or a not-sliced) version of the input onto the stack, invoke +// the operator, and then pop the results off the stack. +void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + warnFallback(schema, /*in_place*/true); + + const auto num_arguments = schema.arguments().size(); + const auto arguments = torch::jit::last(stack, num_arguments); + const auto arguments_begin = stack->size() - num_arguments; + + // `self` is the Tensor being modified in-place + Tensor self = arguments[0].toTensor(); + const auto* self_impl = maybeGetBatchedImpl(self); + std::bitset self_vmap_levels; + if (self_impl) { + self_vmap_levels = createVmapLevelsBitset(self_impl->level()); + } + + // Figure out which arguments are BatchedTensor. Save them to a vector. + // For each BatchedTensor, also record what position of `arguments` they came from. + at::SmallVector batched_tensor_inputs; + VmapDimVector batched_tensor_inputs_position; + for (const auto idx : c10::irange(0, arguments.size())) { + const auto& ivalue = arguments[idx]; + if (!ivalue.isTensor()) { + continue; + } + const auto& tensor = ivalue.toTensor(); + if (!tensor.defined()) { + continue; + } + const auto* batched = maybeGetBatchedImpl(tensor); + if (!batched) { + continue; + } + + // NOTE: [vmap-incompatible in-place operations] + // In-place operations on `self` are not possible if there exists some vmap + // level `l` such that `self` is not being vmapped on that level but another + // argument is. For example, let B0 be a batch dim inside vmap and consider + // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3)) + // - self is torch.ones(3) and does not participate in this vmap + // - other is BatchedTensor(torch.ones(B0, 3)) + // There's no way to do self.add_(other) because `other` has more elements + // elements than `self` due to being vmapped over. + // + // In the vmap fallback, we should error out when we detect this. + auto other_vmap_levels = createVmapLevelsBitset(batched->level()); + if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) { + // Find one vmap level to complain about + auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels; + auto offending_level = llvm::findLastSet(additional_bdims.to_ulong()); + // The following prints out "vmap: aten::add_(tensor, ...) is not possible", + // but it would be better to print out "tensor.add_(...) is not possible". + // Afaict there's no official way to get the add_ and there is no way to + // tell if an operator has method or function variants. + TORCH_CHECK(false, + "vmap: ", schema.name(), "(self, *extra_args) is not possible because ", + "there exists a Tensor `other` in extra_args that has more elements ", + "than `self`. This happened due to `other` being vmapped over but ", + "`self` not being vmapped over at level ", offending_level, ". ", + "Please try to use out-of-place operators instead of ", schema.name(), ". ", + "If said operator is being called inside the PyTorch framework, ", + "please file a bug report instead."); + } + batched_tensor_inputs.push_back(tensor); + batched_tensor_inputs_position.push_back(idx); + } + TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0); + + // MultiBatchVmapTransform the BatchedTensor arguments. This returns + // VmapPhysicalViews that contain all of the batch dimensions. + const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical( + batched_tensor_inputs); + + // Compute the total number of batches + auto num_batch_dims = input_physical_views.front().numBatchDims(); + auto first_physical_view_sizes = input_physical_views.front().tensor().sizes(); + auto batch_sizes = ArrayRef( + first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims); + const auto num_batches = c10::multiply_integers(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); + + // Strategy: For each batch, we are going to push slices (where applicable) + // of the arguments onto `stack`, and call `op`. + for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) { + auto index = computeIndex(linear_idx, batch_sizes); + auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin(); + auto input_physical_views_iter = input_physical_views.begin(); + for (const auto arg_idx : c10::irange(0, num_arguments)) { + // We assume that torch::jit::Stack is backed by vector for + // simplicity. When that is not the case, this code should be updated. + const auto& argument = (*stack)[arguments_begin + arg_idx]; + if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() + || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) { + // argument isn't a BatchedTensor + torch::jit::push(stack, argument); + continue; + } + // argument is a BatchedTensor + TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end()); + const auto& physical_view_for_argument = *input_physical_views_iter; + auto thing = physical_view_for_argument.tensor().index(index); + torch::jit::push(stack, thing); + batched_tensor_inputs_pos_iter++; + input_physical_views_iter++; + } + + op.callBoxed(stack); + torch::jit::drop(stack, 1); + } + + // Return the tensor that was written to in-place + torch::jit::drop(stack, num_arguments); + torch::jit::push(stack, self); +} + +static Tensor safeStack(TensorList tensors) { + auto is_defined = [](const Tensor& t) { return t.defined(); }; + if (std::all_of(tensors.begin(), tensors.end(), is_defined)) { + return at::stack(tensors); + } + // NOTE [vmap through backward and undefined grad] + // While vmapping through backward functions (to compute batched grad), it + // is possible for the backward function to return an undefined grad for some + // grad_input for each example. In that case, we return an undefined grad. + // + // It is theoretically posssible for *some* of the examples to produce an + // undefined grad (a kernel could peek at the gradient values and return an + // undefined tensor if it determines the gradient is full of zeros). We + // could handle this by treating the undefined grad as a zero-filled tensor + // of the correct shape while stacking the tensors together. However I expect + // this to happen very rarely (I have not been able to find an example in our + // codebase) so we just error out in this case. + if (std::none_of(tensors.begin(), tensors.end(), is_defined)) { + return Tensor(); + } + TORCH_CHECK(false, + "vmap: slow fallback received a mix of undefined and defined tensors ", + "as the result of an operation. This is not supported, please file us ", + "an issue on github."); +} + +// TODO: Consider rewriting the following to look like: +// https://gist.github.com/zou3519/7b7c6a4a258d580f62d1d969851be6b1 + +// The general flow of the algorithm is as follows. +// - First, we figure out which arguments are BatchedTensors and save them +// to a vector. We also store a vector of which index of the arguments list +// each BatchedTensor appears in. This will be useful for bookkeeping later. +// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors. +// This returns a vector of VmapPhysicalView that hold tensors that contain +// all of the collective batch dimensions at the front of the tensors. +// - Then, we attempt to call `op` once per slice of the inputs. To do this, +// we repeatedly we slice the input arguments (if they are BatchedTensors), +// put the sliced (or a not-sliced) version of the input onto the stack, invoke +// the operator, and then pop the results off the stack. +// - Each result obtained from the previous step is a slice of the total result, +// so we stack those tensors together to form the final result. +void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + const auto arguments = torch::jit::last(stack, num_arguments); + + TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema), + "Batching rule not implemented for ", schema.operator_name(), ". ", + "We could not generate a fallback."); + + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + op.callBoxed(stack); + return; + } + + if (isInplaceOp(schema)) { + batchedTensorInplaceForLoopFallback(op, stack); + return; + } + TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(), + "Batching rule not implemented for ", schema.operator_name(), "; ", + "the fallback path doesn't work on out= or view ops."); + TORCH_CHECK(num_returns >= 1, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support operations with no returns."); + warnFallback(schema, /*in_place*/false); + + const auto arguments_begin = stack->size() - num_arguments; + + // Figure out which arguments are BatchedTensor. Save them to a vector. + // For each BatchedTensor, also record what position of `arguments` they came from. + at::SmallVector batched_tensor_inputs; + VmapDimVector batched_tensor_inputs_position; + for (const auto idx : c10::irange(0, arguments.size())) { + const auto& ivalue = arguments[idx]; + if (!ivalue.isTensor()) { + continue; + } + const auto& tensor = ivalue.toTensor(); + if (!tensor.defined()) { + continue; + } + const auto* batched = maybeGetBatchedImpl(tensor); + if (!batched) { + continue; + } + batched_tensor_inputs.push_back(tensor); + batched_tensor_inputs_position.push_back(idx); + } + TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0); + + // MultiBatchVmapTransform the BatchedTensor arguments. This returns + // VmapPhysicalViews that contain all of the batch dimensions. + const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical( + batched_tensor_inputs); + + // Compute the total number of batches + auto num_batch_dims = input_physical_views.front().numBatchDims(); + auto some_sizes = input_physical_views.front().tensor().sizes(); + auto batch_sizes = ArrayRef(some_sizes.begin(), some_sizes.begin() + num_batch_dims); + const auto num_batches = c10::multiply_integers(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); + + // Strategy: For each batch, we are going to push slices (where applicable) + // of the arguments onto `stack`, call `op`, and store the result in + // `output_shards`. + // + // NOTE: [Output shards layout] + // Assume that the operator has three outputs: a, b, c. + // The layout of output_shards is as follows: + // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3] + // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3]) + // more easily in the next step. + std::vector output_shards(num_batches * num_returns); + + for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) { + auto index = computeIndex(linear_idx, batch_sizes); + auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin(); + auto input_physical_views_iter = input_physical_views.begin(); + for (const auto arg_idx : c10::irange(0, num_arguments)) { + // We assume that torch::jit::Stack is backed by vector for + // simplicity. When that is not the case, this code should be updated. + const auto& argument = (*stack)[arguments_begin + arg_idx]; + if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() + || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) { + // argument isn't a BatchedTensor + torch::jit::push(stack, argument); + continue; + } + // argument is a BatchedTensor + TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end()); + const auto& physical_view_for_argument = *input_physical_views_iter; + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + torch::jit::push(stack, physical_view_for_argument.tensor().index(index)); + batched_tensor_inputs_pos_iter++; + input_physical_views_iter++; + } + + // std::cout << "[Fallback]: "; + // at::dump_tensor((*stack)[stack->size() - 1].toTensor()); + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + op.callBoxed(stack); + + // Store the result into `output_shards`. See NOTE: [Output shards layout] + // to learn about the details of how we store the shards. + const auto returns = torch::jit::last(stack, num_returns); + for (const auto return_idx : c10::irange(0, returns.size())) { + output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor(); + } + torch::jit::drop(stack, num_returns); + } + + // For each output Tensor, stack the shards of the tensor together to form a return + torch::jit::drop(stack, num_arguments); + auto output_shards_chunks = MatrixRef(output_shards, num_batches); + for (const auto return_idx : c10::irange(0, num_returns)) { + auto shards = output_shards_chunks[return_idx]; + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto flat_output = safeStack(shards); + // See NOTE [vmap through backward and undefined grad] + if (!flat_output.defined()) { + torch::jit::push(stack, flat_output); + continue; + } + VmapDimVector output_sizes(batch_sizes); + output_sizes.insert( + output_sizes.end(), + flat_output.sizes().begin() + 1, + flat_output.sizes().end()); + torch::jit::push( + stack, + input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes))); + } +} + +} +} // namespace at diff --git a/functorch/functorch/csrc/BatchedFallback.h b/functorch/functorch/csrc/BatchedFallback.h new file mode 100644 index 0000000000000..9130245f28b1e --- /dev/null +++ b/functorch/functorch/csrc/BatchedFallback.h @@ -0,0 +1,71 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include + +namespace at { +namespace functorch { + +// If an operator doesn't have a batching rule implemented then we fallback +// to this implementation. The fallback only works on out-of-place operators +// that return only tensors with new memory. (e.g., no in-place operators, no +// view operations). +// +// The fallback effectively takes all of the BatchedTensors in `stack`, slices +// them, and runs `op` on all of the corresponding slices to produce slices +// of the outputs. The output slices then get `torch.stack`ed to create the +// final returns. +// +// The performance of the fallback is not very good because it introduces an +// extra copy from stacking the sliced outputs. Because of this, we prefer to +// write batching rules for operators whenever possible. +void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +bool isVmapFallbackWarningEnabled(); +void setVmapFallbackWarningEnabled(bool enabled); + +bool isVmapFallbackEnabled(); +void setVmapFallbackEnabled(bool enabled); + +template A vector_to_result(const std::vector& buffer) { + return buffer[0].to(); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to()); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to(), buffer[2].to()); +} + +// This is a way to call the slow fallback from inside some plumbing +// TODO: Probably better way to metaprogram this +template +Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + + +} +} // namespace at diff --git a/functorch/functorch/csrc/BatchedTensorImpl.cpp b/functorch/functorch/csrc/BatchedTensorImpl.cpp new file mode 100644 index 0000000000000..487df29000716 --- /dev/null +++ b/functorch/functorch/csrc/BatchedTensorImpl.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#include + +#include +#include + +#include +#include + +namespace at { +namespace functorch { + +BatchedTensorImpl::BatchedTensorImpl(Tensor value, int64_t bdim, int64_t level) + : TensorImpl( + c10::DispatchKeySet(kBatchedKey), + value.dtype(), + value.device() + ) + , value_(std::move(value)) + , level_(level) + , bdim_(bdim) +{ + // TODO: I don't think this ctor gets used. + TORCH_INTERNAL_ASSERT(false); + TORCH_INTERNAL_ASSERT(value_.defined()); + set_storage_access_should_throw(); + set_sizes_strides_policy(SizesStridesPolicy::CustomStrides); + checkInvariants(); + + const auto public_dims = value_.dim() - 1; + const auto value_sizes = value_.sizes(); + const auto value_strides = value_.strides(); + sizes_and_strides_.resize(public_dims); + for (const auto dim : c10::irange(0, public_dims)) { + auto actual_dim = actualDim(dim, /*wrap_dim=*/false); + sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim); + sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim); + } + storage_offset_= value_.storage_offset(); + refresh_numel(); + refresh_contiguous(); +} + +BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, int64_t bdim, int64_t level) + : TensorImpl( + key_set.add(kBatchedKey), + value.dtype(), + value.device() + ) + , value_(std::move(value)) + , level_(level) + , bdim_(bdim) +{ + TORCH_INTERNAL_ASSERT(value_.defined()); + set_storage_access_should_throw(); + set_sizes_strides_policy(SizesStridesPolicy::CustomStrides); + checkInvariants(); + refreshTensorMetadata(); +} + +void BatchedTensorImpl::refreshTensorMetadata() { + const auto public_dims = value_.dim() - 1; + const auto value_sizes = value_.sizes(); + const auto value_strides = value_.strides(); + sizes_and_strides_.resize(public_dims); + for (const auto dim : c10::irange(0, public_dims)) { + auto actual_dim = actualDim(dim, /*wrap_dim=*/false); + sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim); + sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim); + } + storage_offset_= value_.storage_offset(); + refresh_numel(); + refresh_contiguous(); +} + +int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const { + if (wrap_dim) { + const auto ndim = sizes_and_strides_.size(); + dim = maybe_wrap_dim(dim, ndim); + } + auto is_bdim = createBatchDimBitset(bdim_); + + // TODO(vfdev): As BatchedTensorImpl is refactored and has only one dim. + // Below code may be simplified. + + // Example: assume dim = 3, and is_bdim = 10010011000... + // The 1's are batch dims and 0's are normal dims of the underlying value_ Tensor. + // actualDim gives us the index of `dim` in the `value_` Tensor, which is equivalent + // to asking "where does the 3rd (0-indexed) zero occur in the bitset?". + // The answer to that is index 5. + // + // TODO(rzou): the PDEP instruction does exactly this + // (https://stackoverflow.com/questions/7669057/find-nth-set-bit-in-an-int) + // but it might require newer (>= ~2015) CPUs. We should clean this up + // if/when we have dropped support for older CPUs. + int64_t non_bdim_count = 0; + for (int64_t actual_dim = 0; actual_dim < kVmapMaxTensorDims; actual_dim++) { + if (is_bdim[actual_dim]) { + continue; + } + if (non_bdim_count == dim) { + return actual_dim; + } + non_bdim_count++; + } + // If we hit this assert, then that means + // `non_bdim_count` + #num_bdims > kVmapMaxTensorDims. We restrict the number + // of dims a BatchedTensorImpl can have to kVmapMaxTensorDims so this should + // never be hit. + TORCH_INTERNAL_ASSERT(false); +} + +void BatchedTensorImpl::checkInvariants() const { + TORCH_INTERNAL_ASSERT(level_ > -1); +} + +// The following are publically exposed as methods of Tensor + +IntArrayRef BatchedTensorImpl::strides_custom() const { + return strides_default(); +} + +// TODO: implement proper contiguity on batched tensor, then put +// sizes_strides_policy back to Default +bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: querying is_contiguous inside of vmap for memory_format ", + "other than torch.contiguous_format"); + return is_contiguous_; +} + +// The following are some internal inherited methods that we do not support. +// They should never get called. +void BatchedTensorImpl::set_size(int64_t dim, int64_t new_size) { + TORCH_INTERNAL_ASSERT(false, "Can't set_size for BatchedTensorImpl"); +} +void BatchedTensorImpl::set_stride(int64_t dim, int64_t new_stride) { + TORCH_INTERNAL_ASSERT(false, "Can't set_stride for BatchedTensorImpl"); +} +void BatchedTensorImpl::set_storage_offset(int64_t storage_offset) { + TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for BatchedTensorImpl"); +} +#ifdef DEBUG +bool BatchedTensorImpl::has_storage() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "BatchedTensorImpl assumes that storage_ is never set"); + return false; +} +#endif + +const char* BatchedTensorImpl::tensorimpl_type_name() const { + return "BatchedTensorImpl"; +} + +Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) { + DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor); + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + auto batched_level = batched->level(); + TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level); + } + return at::detail::make_tensor(key_set, tensor, bdim, level); +} + +Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) { + return makeBatched(tensor, dim, level); +} + +} +} // namespace at diff --git a/functorch/functorch/csrc/BatchedTensorImpl.h b/functorch/functorch/csrc/BatchedTensorImpl.h new file mode 100644 index 0000000000000..0b7e7f602641a --- /dev/null +++ b/functorch/functorch/csrc/BatchedTensorImpl.h @@ -0,0 +1,148 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace at { +namespace functorch { + +using Tensor = at::Tensor; + +// We assume this in a few other places in the codebase, +// but there isn't a centralized definition. +constexpr int64_t kVmapMaxTensorDims = 64; + +// The valid vmap levels range from [0, 64). This effectively means that we +// support a maximum of 64 nested vmaps. +constexpr int64_t kVmapNumLevels = 64; + +// Store this number of elements of BatchDims on the stack. Most people will +// probably use <= 5 nested vmaps, but adjust this number as necessary. +constexpr int64_t kBatchDimsStackSize = 5; + +// A BatchedTensorImpl holds an underlying Tensor and a single batch dim +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +// +// The batch dimensions are treated as being "private"; they are not user-visible. +// For example, in the following Tensor, +// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) +// dimension 0 is batch dimension. +// +// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) +// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor. +struct BatchedTensorImpl : public c10::TensorImpl { + explicit BatchedTensorImpl(Tensor value, int64_t dim, int64_t level); + explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level); + + // Returns batch dimension of this tensor + int64_t bdim() const { return bdim_; } + + // Returns batch dimension of this tensor + int64_t level() const { return level_; } + + // BatchedTensorImpl wraps a Tensor + const Tensor& value() const { return value_; } + + // Given a public dimension index, return the dimension index in the underlying + // value() tensor. + // For example, if we have + // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) + // bt.actualDim(0) -> 1 + // bt.actualDim(1) -> 2 + // bt.actualDim(2) -> 3 + // bt.actualDim(3) -> Error + int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; + // Override a bunch of methods inherited from TensorImpl to return error messages. + bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; +#ifdef DEBUG + bool has_storage() const override; +#endif + + void refreshTensorMetadata(); + void _unsafe_set_level(int64_t level) { + level_ = level; + } + private: + // see NOTE: [BatchedTensorImpl levels invariant] + void checkInvariants() const; + const char* tensorimpl_type_name() const override; + + Tensor value_; + + int64_t level_; + int64_t bdim_; +}; + +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +inline bool isBatchedTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(kBatchedKey); +} + +// It is unsafe to call this on a Tensor that is not backed by a +// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. +inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) { + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) { + if (!isBatchedTensor(tensor)) { + return nullptr; + } + return unsafeGetBatchedImpl(tensor); +} + +// Returns a bitset. If bit i is set, then that means dim i is a batchdim. +inline std::bitset createBatchDimBitset(int64_t dim) { + std::bitset is_bdim; + is_bdim.set(dim); + return is_bdim; +} + +// Creates a bitset for the given level +inline std::bitset createVmapLevelsBitset(int64_t level) { + std::bitset result; + result.set(level); + return result; +} + +// Use this to construct a BatchedTensor from a regular Tensor +FUNCTORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level); + +// Adds a batch dim to `tensor`, returning a BatchedTensor +FUNCTORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level); + +constexpr DispatchKeySet kKeysToPropagateToWrapper({ + DispatchKey::Negative, + DispatchKey::Conjugate, + DispatchKey::XLA, + DispatchKey::CUDA, + DispatchKey::CPU, +}); + +inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return key_set & kKeysToPropagateToWrapper; +} + +} +} diff --git a/functorch/functorch/csrc/BatchingMetaprogramming.h b/functorch/functorch/csrc/BatchingMetaprogramming.h new file mode 100644 index 0000000000000..e054e58568be2 --- /dev/null +++ b/functorch/functorch/csrc/BatchingMetaprogramming.h @@ -0,0 +1,120 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +namespace at { +namespace functorch { + +// Metaprogramming things +template using typelist = c10::guts::typelist::typelist; +template using head_t = c10::guts::typelist::head_t; +template using concat_t = c10::guts::typelist::concat_t; +template class debug_t; + +// tail operation +template +struct tail final { + static_assert(c10::guts::false_t::value, + "In typelist::tail, the T argument must be typelist<...>."); +}; +template +struct tail> final { + using type = typelist; +}; +template using tail_t = typename tail::type; + +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext { + using type = Next; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> { + using type = Tail; +}; +template struct RemoveBatchDimAfterTensor { + using first = head_t; + using next = tail_t; + using second = head_t; + using tail = tail_t; + + using type = concat_t< + typelist, + typename RemoveBatchDimAfterTensor< + typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type + >::type + >; +}; +template struct RemoveBatchDimAfterTensor> { + using type = typelist; +}; +template <> struct RemoveBatchDimAfterTensor> { + using type = typelist<>; +}; +template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type; + +template struct UnpackSingleItemTuple { + using type = T; +}; +template struct UnpackSingleItemTuple> { + using type = T; +}; +template using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type; + +template struct BuildFunctionHelper; +template struct BuildFunctionHelper> { + using type = Return(Args...); +}; +template +struct BuildFunction { + using type = typename BuildFunctionHelper>::type; +}; +template using build_function_t = typename BuildFunction::type; + + +template struct ToOperatorType { + using batch_rule_return_type = typename c10::guts::function_traits::return_type; + using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types; + + using operator_parameter_types = remove_batch_dim_after_tensor_t; + using operator_return_type = + unpack_single_item_tuple_t< + c10::guts::typelist::to_tuple_t< + remove_batch_dim_after_tensor_t< + c10::guts::typelist::from_tuple_t>>>; + + using type = build_function_t; +}; +template using to_operator_t = typename ToOperatorType::type; + +} +} // namespace at diff --git a/functorch/functorch/csrc/CompileCache.cpp b/functorch/functorch/csrc/CompileCache.cpp new file mode 100644 index 0000000000000..4c87800c88892 --- /dev/null +++ b/functorch/functorch/csrc/CompileCache.cpp @@ -0,0 +1,288 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +/// +/// This design stemmed of from the PointwiseOperatorCompileCache with the +/// purpose of making it more generic for AOTAutograd. This is Compile Cache +/// allowing different types of hashing functions, and is agnostic of the +/// compiler. +/// +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +namespace { +/// Record of thread-local state that changes operator behavior. +struct LocalState { + c10::impl::LocalDispatchKeySet dispatchModifier; + bool gradModeEnabled; + + at::DispatchKeySet apply(at::DispatchKeySet ks) const { + return (ks | dispatchModifier.included_) - dispatchModifier.excluded_; + } + + LocalState() + : dispatchModifier(c10::impl::tls_local_dispatch_key_set()), + gradModeEnabled(at::GradMode::is_enabled()) {} +}; + +/// Helper to pack tensor (dtype, requires grad) into an 8-bit key. +static uint8_t packFlags(const LocalState &state, const at::Tensor &v) { + static_assert(static_cast(at::ScalarType::NumOptions) < 128, + "overflow possible"); + at::ScalarType dtype = v.dtype().toScalarType(); + bool requires_grad = state.gradModeEnabled && v.requires_grad(); + return static_cast(requires_grad) | + (static_cast(dtype) << 1); +} + +using hash_key_t = std::vector; +/// Per-tensor cache specialization key targetting dynamic shapes. Records +/// dtype, dispatch options, aliasing, and per-dim contiguity/broadcasting +/// information. + +enum DimFlags { + /// A leading dimension implicitly added by broadcasting. + SIZE_MISSING = 1 << 0, + + /// Size == 1. + SIZE_ONE = 1 << 1, + + /// Size > 1. + SIZE_OTHER = 1 << 2, + + /// Stride == 0; broadcasting. + STRIDE_ZERO = 1 << 3, + + /// Stride == 1; packed contiguously in memory. + STRIDE_ONE = 1 << 4, + + /// Stride = Stride[i + 1] * Size[i + 1]. + /// Used to collapse dimensions. + STRIDE_CONTIGUOUS = 1 << 5, + + /// Stride = Stride[i - 1] * Size[i - 1]. + /// Used to collapse dimensions in the other direction. + STRIDE_TRANSPOSED_CONTIGUOUS = 1 << 6, // stride[i-1] * sizes[i-1] + + /// Stride must be provided as an argument. + STRIDE_AS_ARG = 1 << 7, +}; + +/// Unique hasher id values to uniquely identify the type of hash. NONE_HASH is +/// used when a tensor is undefined. +enum HasherFlags { + NONE_HASH, + STATIC_HASH, + DYNAMIC_HASH, +}; + +std::vector genDimFlags(c10::IntArrayRef sizes, c10::IntArrayRef strides) { + // Pack all the properties for each dimension into a uint8. + int nDims = sizes.size(); + std::vector dimflags(nDims); + for (int64_t dim = 0; dim < nDims; ++dim) { + uint8_t flag = + (sizes[dim] == 0 ? SIZE_MISSING + : (sizes[dim] == 1 ? SIZE_ONE : SIZE_OTHER)); + if (strides[dim] == 0) { + flag |= STRIDE_ZERO; + } else if (strides[dim] == 1) { + flag |= STRIDE_ONE; + } else if (dim + 1 < (int64_t)sizes.size() && + strides[dim] == strides[dim + 1] * sizes[dim + 1]) { + flag |= STRIDE_CONTIGUOUS; + } else if (dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] && + (dimflags[dim - 1] & STRIDE_CONTIGUOUS) == 0) { + flag |= STRIDE_TRANSPOSED_CONTIGUOUS; + } else { + flag |= STRIDE_AS_ARG; + } + dimflags[dim] = flag; + } + return dimflags; +} + +hash_key_t dynamic_hasher(const LocalState &state, const at::Tensor &v) { + hash_key_t hash = {DYNAMIC_HASH, static_cast(packFlags(state, v)), + static_cast(state.apply(v.key_set()).raw_repr()), + static_cast(v.ndimension())}; + auto dimFlags = genDimFlags(v.sizes(), v.strides()); + hash.insert(hash.end(), dimFlags.begin(), dimFlags.end()); + return hash; +} + +/// Per-tensor cache specialization key targetting static shapes. Recordsdtype, +/// dispatch options, aliasing, and full shapes and strides. +hash_key_t static_hasher(const LocalState &state, const at::Tensor &v) { + hash_key_t hash = {STATIC_HASH, static_cast(packFlags(state, v)), + static_cast(state.apply(v.key_set()).raw_repr()), + static_cast(v.ndimension())}; + hash.insert(hash.end(), v.sizes().begin(), v.sizes().end()); + hash.insert(hash.end(), v.strides().begin(), v.strides().end()); + return hash; +} + +/// ArgCompileCache is a templated class allowing plugging of different types of +/// Hasher/Specialization Keys. +struct CompileCache { +public: + CompileCache() = default; + ~CompileCache() = default; + + /// Array defining groups of aliased tensors. + + /// Cache type mapping specialization keys to compiled kernels. + class vector_hasher { + public: + std::size_t operator()(hash_key_t const &vec) const { + std::size_t seed = vec.size(); + for (auto &i : vec) { + seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } + }; + using Cache = std::unordered_map; + + /// Compute the set of specialization keys based on the inputs to + /// the kernel. + hash_key_t computeCacheKey(PyObject *args, + const std::vector &tensorArgs, + int numTensorArgs, const std::string &hasherType, + int64_t id, int64_t fw_compiler_id, + int64_t bw_compiler_id) { + LocalState state; + hash_key_t cacheKey; + for (int i = 0; i < numTensorArgs; ++i) { + if (tensorArgs[i].defined()) { + // Only hash the tensor when its defined. + if (hasherType == "StaticShapeHasher") { + auto res = static_hasher(state, tensorArgs[i]); + cacheKey.insert(cacheKey.end(), res.begin(), res.end()); + } else if (hasherType == "DynamicShapeHasher") { + auto res = dynamic_hasher(state, tensorArgs[i]); + cacheKey.insert(cacheKey.end(), res.begin(), res.end()); + } + } else { + // Add a value to the cacheKey to indicate a None tensor. + cacheKey.push_back(NONE_HASH); + } + } + cacheKey.push_back(id); + cacheKey.push_back(fw_compiler_id); + cacheKey.push_back(bw_compiler_id); + cacheKey.push_back(numTensorArgs); + + // Cache the non-tensor args. Currently, all the non-tensor args are cached. + for (int i = numTensorArgs; i < PyTuple_Size(args); i++) { + PyObject *arg = PyTuple_GET_ITEM(args, i); + assert(PyLong_Check(arg)); + cacheKey.push_back(PyLong_AsLong(arg)); + } + return cacheKey; + } + + std::vector parsePythonArgs(int numTensorArgs, PyObject *args) { + // Convert to Tensor Args + std::vector tensorArgs(numTensorArgs); + for (int i = 0; i < numTensorArgs; ++i) { + PyObject *arg = PyTuple_GET_ITEM(args, i); + if (arg == Py_None) { + // If an input tensor is None, add it as an undefined tensor. + tensorArgs[i] = at::Tensor(); + } else if (!THPVariable_Check(arg)) { + // Fail if its a non-tensor arg. It should be marked static. + std::string dtype = Py_TYPE(arg)->tp_name; + std::string index = std::to_string(i); + throw std::runtime_error("Found an argument of type " + dtype + + " at index " + index + + ". Non-tensor arguments must be marked static." + " Please set the static_argnums correctly to " + "mark the argument at index " + + index + " static."); + } else { + tensorArgs[i] = THPVariable_Unpack(arg); + } + } + return tensorArgs; + } + + /// Check if the function has already been compiled. + py::object at(int64_t id, int64_t fw_compiler_id, int64_t bw_compiler_id, + int numTensorArgs, const std::string &hasherType, + PyObject *args) { + std::vector tensorArgs = parsePythonArgs(numTensorArgs, args); + hash_key_t cacheKey = + computeCacheKey(args, tensorArgs, numTensorArgs, hasherType, id, + fw_compiler_id, bw_compiler_id); + + auto item = cache_.find(cacheKey); // protected by GIL + + if (C10_LIKELY(item != cache_.end())) { + return item->second; + } + return py::none(); + } + + /// Insert a new compiled functions for new tensor properties. + void insert(int64_t id, int64_t fw_compiler_id, int64_t bw_compiler_id, + int numTensorArgs, const std::string &hasherType, + const py::object &compileFn, PyObject *args) { + std::vector tensorArgs = parsePythonArgs(numTensorArgs, args); + LocalState state; + hash_key_t cacheKey = + computeCacheKey(args, tensorArgs, numTensorArgs, hasherType, id, + fw_compiler_id, bw_compiler_id); + cache_.emplace(cacheKey, compileFn); + } + + const int64_t size() const { return cache_.size(); } + + /// Clear the cache. + void clear() { cache_.clear(); } + +private: + /// Compilation cache holding key and the compiled function. + Cache cache_; +}; + +static CompileCache *createCompileCache() { return new CompileCache(); } + +} // namespace + +namespace at { +namespace functorch { + +void initCompileCacheBindings(PyObject *module) { + py::handle te(module); + py::class_(te, "CompileCache") + .def(py::init(&createCompileCache)) + .def("at", + [](CompileCache &self, int64_t id, int64_t fw_compiler_id, + int64_t bw_compiler_id, int numTensorArgs, + const std::string &hasherType, py::args args) { + return self.at(id, fw_compiler_id, bw_compiler_id, numTensorArgs, + hasherType, args.ptr()); + }) + .def("insert", + [](CompileCache &self, int64_t id, int64_t fw_compiler_id, + int64_t bw_compiler_id, int numTensorArgs, + const std::string &hasherType, const py::object &compileFn, + py::args args, py::kwargs kwargs) { + self.insert(id, fw_compiler_id, bw_compiler_id, numTensorArgs, + hasherType, compileFn, args.ptr()); + }) + .def("clear", [](CompileCache &self) { self.clear(); }) + .def("size", [](CompileCache &self) { return self.size(); }); +} + +} // namespace functorch +} // namespace at diff --git a/functorch/functorch/csrc/CompileCache.h b/functorch/functorch/csrc/CompileCache.h new file mode 100644 index 0000000000000..e67b1db63eb38 --- /dev/null +++ b/functorch/functorch/csrc/CompileCache.h @@ -0,0 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#include + +namespace at { +namespace functorch { + +/// Initialize python bindings for kernel compilation cache. +void initCompileCacheBindings(PyObject *module); + +} // namespace functorch +} // namespace at diff --git a/functorch/functorch/csrc/Constants.h b/functorch/functorch/csrc/Constants.h new file mode 100644 index 0000000000000..f6e614e042465 --- /dev/null +++ b/functorch/functorch/csrc/Constants.h @@ -0,0 +1,31 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include + +namespace at { +namespace functorch { + +#define FT_BATCHED_KEY FuncTorchBatched +#define FT_VMAP_MODE_KEY FuncTorchVmapMode +#define FT_GRAD_WRAPPER_KEY FuncTorchGradWrapper +#define FT_DYNAMIC_LAYER_FRONT_MODE_KEY FuncTorchDynamicLayerFrontMode +#define FT_DYNAMIC_LAYER_BACK_MODE_KEY FuncTorchDynamicLayerBackMode +#define FT_PYTHON_KEY FuncTorchPython + +constexpr auto kBatchedKey = c10::DispatchKey::FT_BATCHED_KEY; +constexpr auto kVmapModeKey = c10::DispatchKey::FT_VMAP_MODE_KEY; +constexpr auto kGradWrapperKey = c10::DispatchKey::FT_GRAD_WRAPPER_KEY; +constexpr auto kDynamicLayerFrontModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_FRONT_MODE_KEY; +constexpr auto kDynamicLayerBackModeKey = c10::DispatchKey::FT_DYNAMIC_LAYER_BACK_MODE_KEY; +//# constexpr auto kPythonKey = c10::DispatchKey::FT_PYTHON_KEY; + +// Some helper macros +#define DECLTYPE_AUTO(...) decltype(__VA_ARGS__), __VA_ARGS__ +#define SINGLE_ARG(...) __VA_ARGS__ + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/CustomFunction.cpp b/functorch/functorch/csrc/CustomFunction.cpp new file mode 100644 index 0000000000000..0c60726b0efb2 --- /dev/null +++ b/functorch/functorch/csrc/CustomFunction.cpp @@ -0,0 +1,290 @@ +#ifndef _WIN32 +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { namespace functorch { + +class PythonKernelHolder : public c10::OperatorKernel { + PyObject* func_; + +public: + + PythonKernelHolder(py::object func) : func_(func.release().ptr()) {} + // This is a generally useful pattern and safer than directly using pybind11's + // py::object destructor. This is because this object may outlive + // libtorch_python, so we want to disarm the deallocation if that happens. + // PyInterpreter does this correctly, pybind11 does not. + ~PythonKernelHolder() override { + getPyInterpreter()->decref(func_, /*is_tensor*/false); + } + + void operator()(const c10::OperatorHandle& op, c10::DispatchKeySet, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + + const auto num_arguments = schema.arguments().size(); + auto arguments = torch::jit::pop(*stack, num_arguments); + + // TODO: Some duplication with torch/csrc/autograd/python_variable.cpp + + py::gil_scoped_acquire g; + + // Pre-scan for arguments that match defaults + int64_t default_suffix_len = 0; + for (int64_t idx = arguments.size() - 1; idx >= 0; idx--) { + const auto& arg = schema.arguments()[idx]; + if (!arg.default_value().has_value()) { + break; + } + const auto& default_ivalue = *arg.default_value(); + const auto& ivalue = arguments[idx]; + if (default_ivalue != ivalue) { + break; + } + default_suffix_len++; + } + + auto args = py::reinterpret_steal(PyTuple_New(num_arguments - default_suffix_len)); + // TODO: actually populate kwargs sometimes? At the moment, every argument + // // just gets passed positionally + py::dict kwargs; + + for (int64_t idx = 0; idx < (int64_t)arguments.size() - default_suffix_len; idx++) { + PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(std::move(arguments[idx])).release().ptr()); + } + + auto out = py::reinterpret_steal(PyObject_Call(func_, args.ptr(), kwargs.ptr())); + if (out.ptr() == nullptr) { + throw python_error(); + } + + if (op.schema().returns().size() == 1) { + torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type())); + } else { + auto outs = py::cast(out); + for (unsigned idx = 0; idx < outs.size(); idx++) { + torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), op.schema().returns()[idx].type())); + } + } + } +}; + +torch::Library::Kind parseKind(const std::string& k) { + static std::unordered_map kind_map = { + {"DEF", torch::Library::DEF}, + {"IMPL", torch::Library::IMPL}, + {"FRAGMENT", torch::Library::FRAGMENT}, + }; + auto it = kind_map.find(k); + TORCH_CHECK(it != kind_map.end(), "could not parse ", k); + return it->second; +} +c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { + static std::unordered_map key_map = { + {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE}, + {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA}, + {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION}, + {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default + }; + auto it = key_map.find(k); + TORCH_CHECK(it != key_map.end(), "could not parse ", k); + return it->second; +} + + +template +inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { + auto mb_key = std::string(key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(key)); + if (mb_key) { + return torch::dispatch(*mb_key, std::forward(raw_f)); + } else { + torch::CppFunction f(std::forward(raw_f)); + return f; + } +} + +std::vector unpack(at::TensorList tl, const char *name, int pos) { + std::vector ret(tl.size()); + for (const auto i : c10::irange(tl.size())) { + const auto &t = tl[i]; + if (!t.defined()) { + continue; + } + ret[i] = static_cast(t); + } + return ret; +} + +std::vector invoke_backward_fn( + PyObject* backward_function, + TensorList grads, + TensorList intermediates) { + std::vector result; + + py::gil_scoped_acquire g; + auto args = py::reinterpret_steal(PyTuple_New(grads.size() + intermediates.size())); + py::dict kwargs; + for (int64_t idx = 0; idx < (int64_t) grads.size(); idx++) { + PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(grads[idx]).release().ptr()); + } + for (int64_t idx = 0; idx < (int64_t) intermediates.size(); idx++) { + PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(intermediates[idx + grads.size()]).release().ptr()); + } + + auto out = py::reinterpret_steal(PyObject_Call(backward_function, args.ptr(), kwargs.ptr())); + if (out.ptr() == nullptr) { + throw python_error(); + } + + for (unsigned idx = 0; idx < grads.size(); idx++) { + auto ivalue = torch::jit::toIValue(PyTuple_GetItem(out.ptr(), idx), TensorType::get()); + result.emplace_back(ivalue.toTensor()); + } + return result; +} + +// TODO: figure out what this is +using torch::autograd::variable_list; +using custom_function_t = std::vector (at::TensorList); + +void copy_range(variable_list& out, torch::autograd::IndexRange range, at::ArrayRef t) { + AT_ASSERT(range.second <= out.size()); + std::cout << range.second << ", " << range.first << ", " << t.size() << std::endl; + AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output"); + std::copy(t.begin(), t.end(), out.begin() + range.first); +} + +struct GenericPythonBackward : public torch::autograd::TraceableFunction { + using TraceableFunction::TraceableFunction; + + variable_list apply(variable_list&& grads) override; + std::string name() const override { return "GenericPythonBackward"; } + void release_variables() override { + std::lock_guard lock(mutex_); + for (auto& t : saved_tensors_) { + t.reset_data(); + } + } + std::vector saved_tensors_; + int64_t num_inputs_; + optional backward_fn_; +}; + +variable_list GenericPythonBackward::apply(variable_list&& grads) { + std::lock_guard lock(mutex_); + + torch::autograd::generated::details::IndexRangeGenerator gen; + auto tensors_ix = gen.range(saved_tensors_.size()); + variable_list grad_inputs(num_inputs_); + + std::vector args; + for (auto& g : grads) { + args.emplace_back(std::move(g)); + } + for (const auto& saved : saved_tensors_) { + args.emplace_back(saved.unpack(shared_from_this())); + } + + if (should_compute_output({ tensors_ix })) { + auto handle = backward_fn_->typed(); + auto grad_result = handle.call(args); + grad_inputs = grad_result; + // copy_range(grad_inputs, tensors_ix, grad_result); + } + return grad_inputs; +} + +using custom_python_function_t = TensorList (*)(TensorList); + +using torch::autograd::compute_requires_grad; +using torch::autograd::collect_next_edges; +using torch::autograd::deleteNode; +using torch::autograd::flatten_tensor_args; + +void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + auto tensors = torch::jit::pop(stack).toTensorList().vec(); + auto tensors_ = unpack(tensors, "tensors", 0); + auto _any_requires_grad = compute_requires_grad(tensors); + (void)_any_requires_grad; + + std::string schema_name = op.schema().name(); + std::string vjp_fn_name = schema_name + "_vjp"; + + std::shared_ptr grad_fn; + if (_any_requires_grad) { + grad_fn = std::shared_ptr(new GenericPythonBackward(), deleteNode); + grad_fn->set_next_edges(collect_next_edges(tensors)); + grad_fn->backward_fn_ = c10::Dispatcher::singleton().findSchemaOrThrow(vjp_fn_name.c_str(), ""); + grad_fn->num_inputs_ = tensors_.size(); + } + + auto typed_handle = op.typed(); + std::vector _tmp = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return typed_handle.call(tensors_); + })(); + auto result = std::move(_tmp); + if (grad_fn) { + for (auto& tensor : result) { + // TODO: is this right? + bool is_input = false; + for (const auto& input : tensors_) { + if (tensor.unsafeGetTensorImpl() == input.unsafeGetTensorImpl()) { + is_input = true; + } + } + + if (!is_input) { + set_history(tensor, grad_fn); + } + grad_fn->saved_tensors_.emplace_back(tensor, !is_input); + } + } + torch::jit::push(stack, result); +} + +void initDispatchBindings(PyObject* module) { + auto m = py::handle(module).cast(); + + py::class_(m, "_DispatchModule", py::module_local()) + .def("def_", [](py::object self, const char* schema, const char* alias) { + self.cast().def(torch::schema(schema, at::functorch::parseAliasAnalysisKind(alias))); + return self; + }, "", py::arg("schema"), py::arg("alias") = "") + .def("impl", [](py::object self, const char* name, const char* dispatch, py::object func) { + self.cast().impl( + name, + dispatch_str(dispatch, torch::CppFunction::makeFromBoxedFunctor(std::make_unique(std::move(func)))) + ); + }, "", py::arg("name"), py::arg("dispatch"), py::arg("func")) + .def("gen_backward_binding", [](py::object self, const char* name, const char* dispatch) { + self.cast().impl( + name, + dispatch_str(dispatch, + torch::CppFunction::makeFromBoxedFunction<&customFunctionBoxed>()) + ); + }, "", py::arg("name"), py::arg("dispatch")) + .def("fallback_fallthrough", [](py::object self, const char* dispatch) { + self.cast().fallback( + dispatch_str(dispatch, torch::CppFunction::makeFallthrough()) + ); + return self; + }, "", py::arg("dispatch") = "") + ; + + m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch) { + auto mb_key = std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch) ); + return std::make_unique(parseKind(kind), std::move(name), mb_key, "/dev/null", 0); + }); +} + + +}} // at::functorch +#endif // #ifndef _WIN32 diff --git a/functorch/functorch/csrc/CustomFunction.h b/functorch/functorch/csrc/CustomFunction.h new file mode 100644 index 0000000000000..f9ef44faacb87 --- /dev/null +++ b/functorch/functorch/csrc/CustomFunction.h @@ -0,0 +1,14 @@ +#pragma once + +#ifndef _WIN32 +#include +#include +#include +#include + +namespace at { namespace functorch { + +void initDispatchBindings(PyObject* module); + +}} +#endif // #ifndef _WIN32 diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp new file mode 100644 index 0000000000000..8bfd388358a0f --- /dev/null +++ b/functorch/functorch/csrc/DynamicLayer.cpp @@ -0,0 +1,511 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace functorch { + +void setDynamicLayerFrontBackKeysIncluded(bool included) { + c10::impl::tls_set_dispatch_key_included(kDynamicLayerFrontModeKey, included); + c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, included); +} + +DynamicLayer::DynamicLayer( + TransformType transform_type, + int64_t layerId, + optional batchSize, + optional randomness, + optional prev_grad_mode, + optional prev_fwd_grad_mode, + optional functionalize_add_back_views) +{ + if (transform_type == TransformType::Grad) { + TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value()); + } + if (transform_type == TransformType::Jvp) { + TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value()); + } + switch (transform_type) { + case TransformType::Vmap: + interpreter_ = Interpreter::Vmap(layerId, batchSize.value(), randomness.value()); + break; + case TransformType::Grad: + interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value()); + break; + case TransformType::Jvp: + interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value()); + break; + case TransformType::Functionalize: + interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value()); + break; + default: + TORCH_INTERNAL_ASSERT(false); + } +} + +TransformType DynamicLayer::key() const { + return interpreter_.key(); +} + +int64_t DynamicLayer::layerId() const { + return interpreter_.level(); +} + +int64_t DynamicLayer::batchSize() const { + return VmapInterpreterPtr(&interpreter_).batchSize(); +} + +RandomnessType DynamicLayer::randomness() const { + return VmapInterpreterPtr(&interpreter_).randomness(); +} + +constexpr DispatchKeySet kFrontBackKeys({kDynamicLayerBackModeKey, kDynamicLayerFrontModeKey}); + +using DynmetaData = std::unordered_map>; +DynmetaData kDynMetaDataSingleton; + +static DynmetaData& getGlobalDynmetaData() { + return kDynMetaDataSingleton; +} + +class FuncTorchTLS : public FuncTorchTLSBase { + public: + FuncTorchTLS() {} + + std::unique_ptr deepcopy() const override { + auto result = std::make_unique(); + result->dynamicLayerStack = dynamicLayerStack; + return result; + } + + int64_t checkSupportsAutogradFunction() const override { + TORCH_CHECK(dynamicLayerStack.size() == 0, + "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ", + "Please rewrite your function to not use autograd.Function while we work on fixing this"); + return 0; + } + + void checkSupportsInplaceRequiresGrad() const override { + TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_, + "You are attempting to call Tensor.requires_grad_() (or perhaps using ", + "torch.autograd.functional.* APIs) inside of a function being transformed ", + "by a functorch transform. ", + "This is unsupported, please attempt to use the functorch transforms ", + "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() " + "outside of a function being transformed instead."); + } + void checkSupportsRetainGrad() const override { + TORCH_CHECK(dynamicLayerStack.size() == 0, + "You are attempting to call Tensor.retain_grad() ", + "inside of a function being transformed ", + "by a functorch transform. ", + "This is unsupported, please attempt to use the functorch transforms ", + "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() " + "outside of a function being transformed instead."); + } + + std::vector dynamicLayerStack; + bool allow_inplace_requires_grad_ = false; +}; + +static FuncTorchTLS* getRawFunctorchTLS() { + auto& state = functorchTLSAccessor(); + if (state == nullptr) { + state = std::make_unique(); + } + // Raw pointer usage OK, `state` keeps the pointer alive + FuncTorchTLSBase* raw_state = state.get(); + FuncTorchTLS* result = static_cast(raw_state); + return result; +} + +void setInplaceRequiresGradAllowed(bool allowed) { + auto* functorch_tls = getRawFunctorchTLS(); + functorch_tls->allow_inplace_requires_grad_ = allowed; +} + +bool getInplaceRequiresGradAllowed() { + auto* functorch_tls = getRawFunctorchTLS(); + return functorch_tls->allow_inplace_requires_grad_; +} + + +static std::vector& dynamicLayerStackAccessor() { + return getRawFunctorchTLS()->dynamicLayerStack; +} + +std::shared_ptr getLifeHandleForLevel(int64_t level) { + auto it = getGlobalDynmetaData().find(level); + TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive"); + return it->second; +} + +optional maybeCurrentDynamicLayer() { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + if (dynamicLayerStack.size() == 0) { + return {}; + } + return dynamicLayerStack.back(); +} + +struct SaveLocalDispatchKeySet { + public: + SaveLocalDispatchKeySet() { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); + auto& layer = dynamicLayerStack.back(); + auto tmp = c10::impl::tls_local_dispatch_key_set(); + layer.interpreter().saveLocalDispatchKeySet(tmp); + } + ~SaveLocalDispatchKeySet() { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); + auto& layer = dynamicLayerStack.back(); + auto tmp = layer.interpreter().getSavedLocalDispatchKeySet(); + layer.interpreter().clearSavedLocalDispatchKeySet(); + c10::impl::_force_tls_local_dispatch_key_set(tmp); + } + SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete; + SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete; +}; + +const std::vector& getDynamicLayerStack() { + return dynamicLayerStackAccessor(); +} + +void setDynamicLayerStack(const std::vector& stack) { + dynamicLayerStackAccessor() = stack; +} + +bool areTransformsActive() { + const auto& data = getGlobalDynmetaData(); + return !data.empty(); +} + +static DynamicLayer popDynamicLayer() { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); + auto result = dynamicLayerStack.back(); + dynamicLayerStack.pop_back(); + + if (dynamicLayerStack.size() == 0) { +#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE + if (c10::show_dispatch_trace_enabled()) { + std::cout << "DynamicLayer off" << std::endl; + } +#endif + setDynamicLayerFrontBackKeysIncluded(false); + } + + return result; +} + +static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + int64_t layerId = 1 + dynamicLayerStack.size(); + TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId()); + dynamicLayerStack.emplace_back(dynamic_layer); + + if (layerId == 1) { + setDynamicLayerFrontBackKeysIncluded(true); +#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE + if (c10::show_dispatch_trace_enabled()) { + std::cout << "DynamicLayer on" << std::endl; + } +#endif + } + + return layerId; +} + +int64_t initAndPushDynamicLayer( + TransformType transform_type, + optional batch_size, + optional randomness, + optional prev_grad_mode, + optional prev_fwd_grad_mode, + optional functionalize_add_back_views) { + const auto& dynamicLayerStack = dynamicLayerStackAccessor(); + const auto layerId = 1 + dynamicLayerStack.size(); + DynamicLayer new_layer(transform_type, layerId, batch_size, randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views); + pushDynamicLayer(std::move(new_layer)); + + auto& data = getGlobalDynmetaData(); + + TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end()); + if (transform_type == TransformType::Grad) { + TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value()); + } + if (transform_type == TransformType::Jvp) { + TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value()); + } + data[layerId] = std::make_shared(true); + return layerId; +} + +DynamicLayer popDynamicLayerAndDeleteMetadata() { + auto result = popDynamicLayer(); + auto level = result.layerId(); + + // TODO: is this lock safe? No one else should be writing to the same bucket + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "deleting metadata" << std::endl; + // } + auto& data = getGlobalDynmetaData(); + auto it = data.find(level); + if (it == data.end()) { + return result; + } + // if (c10::show_dispatch_trace_enabled()) { + // std::cout << "deleted metadata for level " << level << std::endl; + // } + // invalidate the thing + *(it->second) = false; + data.erase(level); + return result; +} + +Tensor unwrapIfDead(const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (!wrapped) { + return tensor; + } + if (wrapped->is_alive()) { + return tensor; + } + return wrapped->value(); +} + +void foreachTensorInplace(std::vector& args, int64_t begin, int64_t end, + std::function func) { + TORCH_INTERNAL_ASSERT(begin >= 0); + TORCH_INTERNAL_ASSERT(end >= 0); + TORCH_INTERNAL_ASSERT(begin <= end); + for (int64_t idx = begin; idx < end; idx++) { + auto ivalue = args[idx]; + // Tensor?[] translates to a c10::List so we need to peek inside List + if (ivalue.isList()) { + bool modified = false; + // TODO: might be more efficient if we scan first then not copy? Depends. + auto list = ivalue.toList().copy(); + for (const auto list_idx : c10::irange(0, list.size())) { + const auto& elt = list.get(list_idx); + if (elt.isTensor()) { + list.set(list_idx, func(elt.toTensor())); + modified = true; + } + } + if (modified) { + args[idx] = list; + } + continue; + } + if (ivalue.isTensorList()) { + auto list = ivalue.toTensorList(); + for (const auto list_idx : c10::irange(0, list.size())) { + list[list_idx] = func(list[list_idx]); + } + args[idx] = list; + } + TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict"); + if (!ivalue.isTensor()) { + continue; + } + Tensor value = ivalue.toTensor(); + Tensor replacement = func(value); + args[idx] = std::move(replacement); + // sanity checks + if (ivalue.toTensor().defined()) { + TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined()); + } + } +} + +std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) { + os << layer.layerId() << ":" << layer.key(); + return os; +} +std::ostream& operator<< (std::ostream& os, const std::vector& dls) { + os << "DynamicLayerStack[ "; + for (const auto& layer : dls) { + os << layer << " "; + } + os << "]"; + return os; +} + +bool isInplaceOp(const FunctionSchema& schema) { + if (!schema.is_mutable() || schema.returns().size() != 1) { + return false; + } + // Check that the first argument is being written to + const auto& first_arg_alias_info = schema.arguments().begin()->alias_info(); + if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) { + return false; + } + // Check that none of the other args are being aliased + for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) { + const auto& alias_info = it->alias_info(); + if (alias_info) { + return false; + } + } + // Check that the first tensor is being returned (i.e., output has a (a!)) + const auto& return_alias_info = schema.returns()[0].alias_info(); + return return_alias_info && return_alias_info->isWrite(); +} + + +#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE +static void dump_local_tls() { + auto tls = c10::impl::tls_local_dispatch_key_set(); + std::cout << "[Local Include] " << tls.included_ << std::endl; + std::cout << "[Local Exclude] " << tls.excluded_ << std::endl; +} +#endif + +struct WithoutTop { + WithoutTop(); + ~WithoutTop(); + DynamicLayer layer_; +}; + +WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {} +WithoutTop::~WithoutTop() { + pushDynamicLayer(std::move(layer_)); +} + +// NOTE: [forward-mode AD decompositions hack] +// +// The mechanism is: in DynamicLayerFrontMode, IF we are dispatching on the +// jvp transform, AND we have a decomposition for the operation, then run +// the decomposition. +// +// Let's break that down. There are a douple of moving pieces. +// +// 0. How do we know what transform we're dispatching on? +// Easy, check the top of the DynamicLayerStack and read the transform. +// +// 1. Next, we must identify when an operation (e.g. nll_loss_backward) +// gets dispatched to. +// - register a special kernel to the DynamicLayerFrontMode key +// (see JVP_DECOMP) +// - that special kernel invokes dynamicLayerFrontFallbackOperator with +// an arg indicating we're going to use a decomp +// +// 2. Next, we need to call the decomposition. See call_decomposition_for_jvp. +// We currently use python decompositions that we torchscript. + +// Ideally c10::OperatorHandle would have a field like this +// to identify the operator. +// The stuff here should map 1:1 with the operator name. +// aten::nll_loss_backward -> nll_loss_backward +// aten::add.Tensor -> add_Tensor + +static void call_decomposition_for_jvp( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + run_jit_decomposition(op, stack); +} + +static void dynamicLayerFrontFallbackOperator( + const c10::OperatorHandle& op, + torch::jit::Stack* stack, + bool decomp_jvp) { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); +#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE + if (c10::show_dispatch_trace_enabled()) { + std::cout << dynamicLayerStack << std::endl; + dump_local_tls(); + } +#endif + + // Hack: if jvp and we have a decomposition registered, then do the decomposition + if (dynamicLayerStack.back().interpreter().key() == TransformType::Jvp && + decomp_jvp) { + return call_decomposition_for_jvp(op, stack); + } + + // Save the current LocalDispatchKeySet (to the current DynamicLayer). + // Upon exiting the current scope, that LocalDispatchKeySet gets restored. + // When the current DynamicLayer dispatches to the next (inner) DynamicLayer, + // it will also temporarily restore the saved LocalDispatchKeySet. + SaveLocalDispatchKeySet guard; + + // Unwrap escaped GradWrappers + auto num_args = op.schema().arguments().size(); + foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead); + + auto& layer = dynamicLayerStack.back(); + layer.interpreter().process(op, stack); +} + +static c10::impl::ForceDispatchKeyGuard +restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) { + return c10::impl::ForceDispatchKeyGuard(key_set); +} + +void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + return dynamicLayerFrontFallbackOperator(op, stack, false); +} + +void dynamicLayerFrontFallBackWithDecomp( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + return dynamicLayerFrontFallbackOperator(op, stack, true); +} + +void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + auto& layer = dynamicLayerStackAccessor().back(); + auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet()); + WithoutTop guard; + + layer.interpreter().sendToNextInterpreter(op, stack); +} + +TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>()); +} + +TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>()); +} + +#define JVP_DECOMP(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>()); + +#define JVP_DECOMP2(op, overload) \ + m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallBackWithDecomp>()); + +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + JVP_DECOMP(nll_loss_backward); + JVP_DECOMP(nll_loss2d_backward); + JVP_DECOMP(_log_softmax_backward_data); + JVP_DECOMP(_softmax_backward_data); + OP_DECOMPOSE(log_sigmoid); + JVP_DECOMP(log_sigmoid_forward); + JVP_DECOMP(native_layer_norm_backward); + JVP_DECOMP(native_batch_norm_backward); + JVP_DECOMP(cudnn_batch_norm_backward); +} + + +} +} // namespace at diff --git a/functorch/functorch/csrc/DynamicLayer.h b/functorch/functorch/csrc/DynamicLayer.h new file mode 100644 index 0000000000000..7d5b5f4a9d820 --- /dev/null +++ b/functorch/functorch/csrc/DynamicLayer.h @@ -0,0 +1,93 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declared bc I am lazy +namespace c10 { struct AutogradMetaInterface; } + +namespace at { +namespace functorch { + +// TODO: we can excise DynamicLayer in favor of Interpreter, +// But I am going to leave it for now as a compatiblity shim to avoid +// needing to refactor a lot of callsites... +struct FUNCTORCH_API DynamicLayer { + explicit DynamicLayer( + TransformType transform_type, + int64_t layerId, + optional batchSize = nullopt, + optional randomness = nullopt, + optional prev_grad_mode = nullopt, + optional pre_fwd_grad_mode = nullopt, + optional functionalize_add_back_views = nullopt); + + TransformType key() const; + int64_t layerId() const; + + const Interpreter& interpreter() const { return interpreter_; } + Interpreter& interpreter() { return interpreter_; } + + // Only valid for vmap + int64_t batchSize() const; + RandomnessType randomness() const; + + private: + Interpreter interpreter_; +}; + +FUNCTORCH_API int64_t initAndPushDynamicLayer( + TransformType transform_type, + optional batch_size = nullopt, + optional randomness = nullopt, + optional prev_grad_mode = nullopt, + optional prev_fwd_grad_mode = nullopt, + optional functionalize_add_back_views = nullopt); +FUNCTORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata(); +FUNCTORCH_API c10::optional maybeCurrentDynamicLayer(); +FUNCTORCH_API const std::vector& getDynamicLayerStack(); +FUNCTORCH_API void setDynamicLayerStack(const std::vector& stack); +FUNCTORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included); + +// NB: Not lock safe, you should only call this from Python where the GIL will +// prevent race conditions. +FUNCTORCH_API bool areTransformsActive(); + +// NB: not lock safe. TODO: does it need a lock? +FUNCTORCH_API std::shared_ptr getLifeHandleForLevel(int64_t level); + +// Returns if an operator is in-place. An operator is inplace if: +// 1. The first argument is a Tensor and it is being written to +// 2. The first argument is being returned +// 3. No other arguments are aliased +// Here is an example of an in-place operator: +// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +bool isInplaceOp(const c10::FunctionSchema& schema); + +Tensor unwrapIfDead(const Tensor& tensor); + +// Pretty printers +std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); +std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); + +void setInplaceRequiresGradAllowed(bool allowed); +bool getInplaceRequiresGradAllowed(); + + +} +} // namespace at diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.cpp b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp new file mode 100644 index 0000000000000..4242305636cfc --- /dev/null +++ b/functorch/functorch/csrc/FunctionalizeInterpreter.cpp @@ -0,0 +1,68 @@ +#include +#include +#include + +namespace at { namespace functorch { + +static void sanityCheckNotFunctional(const c10::OperatorHandle& op, torch::jit::Stack* stack, size_t num_args) { + foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), + [](const Tensor& tensor) { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor)); + return tensor; + }); +} + +void FunctionalizeInterpreterPtr::processImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Functionalize); + + // We always want to call the functionalization kernels if functionalize() is on the layer stack. + // It's the responsibility of the functionalization kernel to no-op and redispatch + // if none of the input tensors are functional. + setup_dispatch_key_tls(exclude, DispatchKeySet(DispatchKey::Functionalize)); + auto functionalization_add_back_views = functionalizeAddBackViews(); + // We have some side-car TLS that we can set to toggle the functionaliation behavior. + // If set, then we functionalization will only remove mutations, instead of + // removing both mutations AND view operators. + at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(functionalization_add_back_views); + + op.callBoxed(stack); + + auto ret_size = op.schema().returns().size(); + foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), + [&](const Tensor& tensor) { + if (at::functionalization::impl::isFunctionalTensor(tensor)) { + auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + // Functorch is responsible for setting the level on the wrapper, since we don't + // have that info available in core (for now). + // We could just "propagate" the level from the input tensors inside of the functionalize kernels, + // but unfortunately we can't do that for factory operators. + wrapper->set_level(level()); + } + return tensor; + } + ); +} + +void FunctionalizeInterpreterPtr::sendToNextInterpreterImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + // For now, we don't support nested functionalization calls. + // This check just enforces that - after the functionalize kernel runs + // and we hit the BackModeFallback, we'll have unwrapped our FunctionalTensors + // so we can check that the unwrapped thing is not another (nested) FunctionalTensor. + auto args_size = op.schema().arguments().size(); + sanityCheckNotFunctional(op, stack, args_size); + + // Re-dispatch + if (getDynamicLayerStack().size() == 0) { + sanityCheckStack(op, stack); + } + op.callBoxed(stack); + + auto ret_size = op.schema().returns().size(); + sanityCheckNotFunctional(op, stack, ret_size); +} + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/FunctionalizeInterpreter.h b/functorch/functorch/csrc/FunctionalizeInterpreter.h new file mode 100644 index 0000000000000..5475b38f068f0 --- /dev/null +++ b/functorch/functorch/csrc/FunctionalizeInterpreter.h @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace at { namespace functorch { + +struct FunctionalizeInterpreterPtr { + explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + bool functionalizeAddBackViews() const { + return c10::get(base_->meta()).functionalizeAddBackViews_; + } + private: + const Interpreter* base_; +}; + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/Interpreter.cpp b/functorch/functorch/csrc/Interpreter.cpp new file mode 100644 index 0000000000000..cce9fa05f70eb --- /dev/null +++ b/functorch/functorch/csrc/Interpreter.cpp @@ -0,0 +1,122 @@ +#include +#include +#include +#include +#include +#include + +namespace at { namespace functorch { + +static DispatchKeySet get_all_dynlayer_keyset() { + // NB: FULL_AFTER does not include the dispatch key + + // "all dispatch keys between DynamicLayer{Front, Back}Mode, inclusive" + auto result = + DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerFrontModeKey) - + DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerBackModeKey); + result = result | DispatchKeySet({kDynamicLayerFrontModeKey}); + + // Hack: don't handle the autocast dispatch keys. Their interaction with functorch + // is weird. + result = result - autocast_dispatch_keyset; + + // Hack: don't handle kVmapModeKey. We need a better way of modeling this. + // In e.g. grad(vmap(f)), kVmapModeKey makes it so that all random operations, + // even after we are done handling the vmap layer, error out. + result = result.remove(kVmapModeKey); + + return result; +} + +// TODO: This should be constexpr, but there are some methods +// of DispatchKeySet that haven't been marked constexpr yet. +static DispatchKeySet all_dynlayer_keyset = get_all_dynlayer_keyset(); + +static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) { + if (key == TransformType::Vmap) { + // NB: Does not include kVmapModeKey. We may modulate the key when + // constructing the DynamicLayer, but we don't control it when entering/exiting + // the DynamicLayer. + return DispatchKeySet({kBatchedKey}); + } else if (key == TransformType::Grad || key == TransformType::Jvp) { + return autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView); + } else if (key == TransformType::Functionalize) { + return DispatchKeySet(DispatchKey::Functionalize); + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported key: ", key); + } +} + +DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key) { + DispatchKeySet exclude = all_dynlayer_keyset; + exclude = exclude.remove(kDynamicLayerBackModeKey); + exclude = exclude - keysForEnteringDynamicLayer(key); + return exclude; +} + +void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) { + auto local_keyset = c10::impl::tls_local_dispatch_key_set(); + local_keyset.excluded_ = local_keyset.excluded_ | exclude; + local_keyset.included_ = local_keyset.included_ | include; + c10::impl::_force_tls_local_dispatch_key_set(local_keyset); +} + +std::ostream& operator<<(std::ostream& os, const TransformType& t) { + switch (t) { + case TransformType::Torch: + os << "Torch"; + break; + case TransformType::Vmap: + os << "Vmap"; + break; + case TransformType::Grad: + os << "Grad"; + break; + case TransformType::Jvp: + os << "Jvp"; + break; + case TransformType::Functionalize: + os << "Functionalize"; + break; + default: + TORCH_INTERNAL_ASSERT(false); + } + return os; +} + +void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + auto num_args = op.schema().arguments().size(); + foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), + [](const Tensor& tensor) { + + auto* wrapper = maybeGetTensorWrapper(tensor); + TORCH_INTERNAL_ASSERT(wrapper == nullptr); + auto* batched = maybeGetBatchedImpl(tensor); + TORCH_INTERNAL_ASSERT(batched == nullptr); + return tensor; + }); +} + +#define INTERPRETER_DISPATCH(type, method) \ + switch (key()) { \ + case TransformType::Vmap: \ + return VmapInterpreterPtr(this). method; \ + case TransformType::Grad: \ + return GradInterpreterPtr(this). method; \ + case TransformType::Jvp: \ + return JvpInterpreterPtr(this). method; \ + case TransformType::Functionalize: \ + return FunctionalizeInterpreterPtr(this). method; \ + default: \ + TORCH_INTERNAL_ASSERT(false, "Unrecognized transform"); \ + } + +void Interpreter::process(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + INTERPRETER_DISPATCH(key_, SINGLE_ARG(processImpl(op, stack))); +} + +void Interpreter::sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + INTERPRETER_DISPATCH(key_, SINGLE_ARG(sendToNextInterpreterImpl(op, stack))); +} + +}} diff --git a/functorch/functorch/csrc/Interpreter.h b/functorch/functorch/csrc/Interpreter.h new file mode 100644 index 0000000000000..2a1a426824b17 --- /dev/null +++ b/functorch/functorch/csrc/Interpreter.h @@ -0,0 +1,187 @@ +#pragma once + +// variant.h doesn't clean up after itself... +#include +#undef DECLTYPE_AUTO + +#include +#include +#include +#include +#include + +namespace at { namespace functorch { + +// NOTE: [functorch interpreter stack] +// +// functorch's dispatching system uses a stack of interpreters. +// Historically we've referred to this as the "DynamicLayerStack". +// +// An interpreter is something that reads in the code it is passed +// and then executes it. We have a different interpreter per-transform: +// the "VmapInterpreter" is responsible for reading in operators (like aten::mv) +// and executing the batched version of it (the batching rule for aten::mv). +// +// Concretely, each interpreter is responsible for two things: +// +// 1) process(ophandle, stack) +// Given an operator handle and a stack of arguments, the interpreter is +// responsible for figuring out how to execute the operation under the semantics +// of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call +// the batching rule. +// +// The batching rules are stored as kernels on the FuncTorchBatched key, so the way +// VmapInterpreter calls the batching rule is roughly: (A) exclude all +// dispatch keys aside from the Batched key, (B) redispatch so we get to the +// Batched key. +// +// 2) sendToNextInterpreter(ophandle, stack) +// The VmapInterpreter, when it sees aten::mv, will process it into a call to +// aten::mm. It then needs to send the call to aten::mm to the next interpreter +// in the interpreter stack. +// +// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack) +// and most Interpreters will implement it this way. + +enum RandomnessType { + Error, // always errors when calling a random function + Same, // randomness appears the same across batches + Different, // randomness appears different across batches + END +}; + +enum class TransformType { + Torch, // Unused + Vmap, + Grad, // reverse-mode AD, aka vjp + Jvp, // forward-mode AD + Functionalize, +}; + +std::ostream& operator<<(std::ostream& os, const TransformType& t); + +// NOTE: [Interpreter "subclassing" design] +// +// How are various Interpreters for different transforms (vmap, grad, ...) +// implemented? +// +// Accessing interpreters is in the hot-path of functorch so we have a constraint +// that this code must be as fast as possible. +// +// As a result, we stay away from virtual methods and this causes our code +// to look a little funny. +// +// `Interpreter` is the struct for Interpreters. It holds ALL of the +// relevant information (what type of interpreter it is and the metadata). +// Metadata for each interpreter is represented as a Union (c10::variant) +// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...). +// +// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this +// if you want to access the metadata fields (like batchSize and randomness). +// +// Each type of interpreter (e.g. Vmap) has a convenience struct +// (e.g. VmapInterpreterPtr) associated with it. +// +// Construct the convenience struct with VmapInterpreterPtr(Interpreter*), +// and then one can access methods on VmapInterpreterPtr like so: +// >>> VmapInterpreterPtr(&interpreter).batchSize() +// +// Finally, Interpreter::process switches on the type of the interpreter +// and calls one of {Transform}Intepreter::processImpl under the hood. +// Same for Interpreter::sendToNextInterpreter :) + +struct VmapInterpreterMeta { + explicit VmapInterpreterMeta(int64_t batchSize, RandomnessType randomness) : + batchSize_(batchSize), randomness_(randomness) {} + int64_t batchSize_; + RandomnessType randomness_; +}; + +struct GradInterpreterMeta { + explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {} + bool prevGradMode_; +}; + +struct JvpInterpreterMeta { + explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {} + bool prevFwdGradMode_; +}; + +struct FunctionalizeInterpreterMeta { + explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) : + functionalizeAddBackViews_(functionalizeAddBackViews) {} + bool functionalizeAddBackViews_; +}; + +typedef c10::variant< + int64_t, + GradInterpreterMeta, + JvpInterpreterMeta, + VmapInterpreterMeta, + FunctionalizeInterpreterMeta +> InterpreterMeta; + + +struct Interpreter { + // factory functions + static Interpreter Vmap(int64_t level, int64_t batchSize, RandomnessType randomness) { + return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(batchSize, randomness)); + } + static Interpreter Grad(int64_t level, bool prevGradMode) { + return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); + } + static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { + return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); + } + static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { + return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews)); + } + + // methods + TransformType key() const { return type_; } + int64_t level() const { return level_; } + const InterpreterMeta& meta() const { return meta_; } + + void process(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack); + + void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) { + TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = std::move(keyset); + } + void clearSavedLocalDispatchKeySet() { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = c10::nullopt; + } + c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + return *savedLocalDispatchKeySet_; + } + + // Please don't use this + explicit Interpreter() = default; + + private: + explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta): + type_(type), level_(level), meta_(meta) {} + + // fields + TransformType type_; + int64_t level_; + optional savedLocalDispatchKeySet_; + InterpreterMeta meta_; +}; + +// Applies the following for-loop: +// for i in range(begin, end): +// args[i] = func(args[i]) +void foreachTensorInplace(std::vector& args, int64_t begin, int64_t end, + std::function func); + +DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key); + +void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include); + +void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp new file mode 100644 index 0000000000000..8181174ea196a --- /dev/null +++ b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp @@ -0,0 +1,714 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace functorch { + + +// NOTE: [What is a batching rule?] +// +// This files contains batching rules written with the legacy (now-deprecated) +// batching rule API. +// Please try to use the new-style batching rule API (see writing_batch_rules.md) +// +// A *batching rule* implements the logic of how to call an operator on inputs +// that have zero or more additional batch dimensions. When one does a vmap, the +// dimension(s) being vmap'ed over get recorded as batch dimensions. +// +// For example, vmap(torch.add)(x, y) +// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]; +// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]; +// 3. and then runs `torch.add(batched_x, batched_y)`. + +// NOTE: [When should I add a batching rule?] +// When you are adding a new operator, you'll need to add a batching rule so +// that vmap can work efficiently with said operator. If you do not, we'll attempt +// to generate a slow fallback for the batching rule. + +// NOTE: [How to write batching rules?] +// The signature of a batching rule should look like exactly like the C++ signature +// of its operator. +// +// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology. +// +// At a high level, what a batching rule does is the following: +// 1. Converts (logical) BatchedTensors to views on physical tensors. +// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical +// arguments that correspond to the physical tensors. +// 3. Calls at:: operations on the physical tensors and arguments to produce +// some physical results. +// 4. Converts physical results back to BatchedTensors. +// +// Steps 1, 2, and 4 differ for operators with different batching behaviors. When +// writing a new batching rule, please select a VmapTransform that matches the +// batching behavior of your operation. The VmapTransform provides helper functions +// to do steps (1), (2), and (4). +// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h) + +// Note: [Future plans] +// The API for writing a batching rule isn't stable. In the future, we'd like +// to think about the problem of translating these batching rules to TorchScript. +// Ideally batching rules in eager mode vs TorchScript would look pretty similar, +// if not use the same mechanism. In order to accomplish that we might have to +// do some refactoring. + +// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor. +static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { + return dim == 0 || dim == -1; +} + +// This check should probably go into the dispatcher... +static bool participatesInCurrentLevel(const Tensor& self) { + auto maybe_level = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_level.has_value()); + auto current_level = maybe_level->layerId(); + auto* maybe_batched_impl = maybeGetBatchedImpl(self); + if (!maybe_batched_impl) { + return false; + } + auto self_level = maybe_batched_impl->level(); + TORCH_INTERNAL_ASSERT(self_level <= current_level); + return self_level == current_level; +} + +static bool participatesInCurrentLevel(TensorList self) { + for (const Tensor& tensor : self) { + if (participatesInCurrentLevel(tensor)) { + return true; + } + } + return false; +} + +bool isPhysicalScalarTensor(const Tensor& logical_tensor) { + if (logical_tensor.dim() > 0) { + return false; + } + auto* batched = maybeGetBatchedImpl(logical_tensor); + if (batched) { + return false; + } + return true; +} + +std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return self.chunk(chunks, dim); + } + + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::chunk(self_physical.tensor(), chunks, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +std::vector tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::tensor_split(self, sections, dim); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +std::vector tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::tensor_split(self, indices, dim); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return self.squeeze_(dim); + } + auto* batched = maybeGetBatchedImpl(self); + TORCH_CHECK(batched && batched->bdim() == 0); + auto logical_dim = self.dim(); + auto dim_physical = 1 + maybe_wrap_dim(dim, logical_dim); + batched->value().squeeze_(dim_physical); + + // Also need to change some metadata... + batched->refreshTensorMetadata(); + return self; +} + +Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return self.unsqueeze_(dim); + } + auto* batched = maybeGetBatchedImpl(self); + TORCH_CHECK(batched && batched->bdim() == 0); + auto logical_dim = self.dim(); + auto dim_physical = 1 + maybe_wrap_dim(dim, logical_dim + 1); + batched->value().unsqueeze_(dim_physical); + + // Also need to change some metadata... + batched->refreshTensorMetadata(); + return self; +} + +Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return self.fill_(value); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + return self; +} + +Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { + auto value_batched = isBatchedTensor(value); + + if (value_batched) { + auto physical_args = + BroadcastingVmapTransform::logicalToPhysical({self, value}); + physical_args[0].tensor().copy_(physical_args[1].tensor()); + } else { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + } + return self; +} + +Tensor& zero_inplace_batching_rule(Tensor &self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().zero_(); + return self; +} + +Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::transpose(self, dim0, dim1); + } + // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works + // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(lambda x: x.transpose(0, -1), x) + // then we replicate this behavior. + if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && + is_allowed_dim_on_scalar_tensor(dim1)) { + return self; + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim0_physical = self_physical.getPhysicalDim(dim0); + auto dim1_physical = self_physical.getPhysicalDim(dim1); + auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) { + return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims; +} + +Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { + if (!participatesInCurrentLevel(grad)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::select_backward(grad, input_sizes, dim, index); + } + auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); + grad_input.select(physical_dim, index).copy_(grad_physical.tensor()); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); +} + +Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + if (!participatesInCurrentLevel(grad)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::slice_backward(grad, input_sizes, dim, start, end, step); + } + auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); + grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor()); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); +} + +std::vector split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::split(self, split_size, dim); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::split(self_physical.tensor(), split_size, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +std::vector split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return split_with_sizes(self, split_sizes, dim); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = split_with_sizes(self_physical.tensor(), split_sizes, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +std::vector unbind_batching_rule(const Tensor& self, int64_t dim) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::unbind(self, dim); + } + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::unbind(self_physical.tensor(), dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +// Checks that the smallest batch stride is greater than the largest example +// stride. This is something we can support but we choose not to because it's +// potentially error prone. +static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) { + auto smallest_batch_stride = std::min_element( + physical_strides.begin(), physical_strides.begin() + num_batch_dims); + auto largest_example_stride = std::max_element( + physical_strides.begin() + num_batch_dims, physical_strides.end()); + if (largest_example_stride == physical_strides.end()) { + // No example dimensions + return; + } + if (num_batch_dims == 1 && physical_strides.size() > 0 && physical_strides[0] == 0) { + // degenerate batch dim + return; + } + TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride, + "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ", + "vmapped over are at the front of the tensor (in memory layout). When they are ", + "not at the front of the tensor this operation can be error prone so we " + "actively discourage it; please file us a bug report and/or try to ", + "express the as_strided operation in terms of PyTorch view operations"); +} + +// given (sizes, strides, storage_offset) returns the maximum location that +// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors +// with zero-size dims). +static optional maximum_indexable_location( + IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { + auto result = native::storage_size_for(sizes, strides); + if (result == 0) { + return nullopt; + } + return result + storage_offset; +} + +// Let x be the "first slice" of physical_tensor. +// This checks that the range of possible memory locations accessible by +// x.as_strided(sizes, strides, maybe_storage_offset) +// are within the bounds of possible memory locations accessible by x. +static void checkBasicAsStridedValidForSlice( + const Tensor& physical_tensor, + int64_t num_batch_dims, + IntArrayRef sizes, + IntArrayRef strides, + optional maybe_storage_offset) { + auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims); + auto slice_strides = physical_tensor.strides().slice(num_batch_dims); + auto base_offset = physical_tensor.storage_offset(); + + auto storage_offset = maybe_storage_offset.value_or(base_offset); + + auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset); + auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset); + + if (!max_as_strided_loc.has_value()) { + return; + } + if (!max_slice_loc.has_value()) { + TORCH_CHECK(false, + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", + "can access memory outside of `tensor`. `tensor` has no storage but the ", + "passed-in (size, stride, storage_offset) imply a result with some storage. ", + "This is not supported inside of vmap, please try to rewrite the ", + "`as_strided` call as a sequence of PyTorch view operations"); + } + + TORCH_CHECK( + *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset, + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", + "can access memory outside of `tensor`. `result` can access some", + "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ", + "`tensor` can only access some memory in range [", base_offset, ", ", + *max_slice_loc, "]. This is not supported inside of vmap, please try to", + "rewrite the `as_strided` call as a sequence of PyTorch view operations"); +} + +// What are the semantics of as_strided inside of vmap? +// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs) +// This returns a view on `x`, `y`, such that each y[i] has: +// - sizes: `sizes` +// - strides: `strides` +// - storage_offset: offset + i * x.stride(batch_dim) +// +// In other words, it is as if we had treated each x[i] as having storage +// offset equal to xs.offset() and called as_strided(sizes, sizes, offset). +// (that is equivalent to x[i].as_strided( +// sizes, sizes, offset + x[i].storage_offset() - xs.offset()) for all i) +// +// Note that this *may* be different from actually running as_strided +// in a for-loop. This is due to how as_strided takes in `offset` to be +// an *absolute* offset. As an example, consider: +// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1) +// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)] +// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))! +// However, we consider the above for-loop comprehension to be a user error: +// a user should have written the following if they wanted to use as_strided +// in a per-sample way: +// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)] +Tensor as_strided_batching_rule( + const Tensor& tensor, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset) { + if (!participatesInCurrentLevel(tensor)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::as_strided(tensor, sizes, strides, storage_offset); + } + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor); + auto num_batch_dims = physical_view.numBatchDims(); + auto physical_sizes = physical_view.getPhysicalShape(sizes); + const auto& physical_tensor = physical_view.tensor(); + + // We can't rely on the physical as_strided call to do this for us because + // we do some sanity checks on the size/strides before calling into as_strided. + TORCH_CHECK(sizes.size() == strides.size(), + "Tensor.as_strided(size, stride, ...): size and stride must have the ", + "same length! Got size ", sizes, " and stride ", strides); + + // Sanity checks: + // 1. All batch dims are at the front in memory layout (not necessary for + // correctness, but we are worried the user might be doing crazy things) + // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset()) + // is valid for a slice of the input tensor. + // See Note: [When will the as_strided batching rule fail?] for details. + checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims); + checkBasicAsStridedValidForSlice( + physical_tensor, num_batch_dims, sizes, strides, storage_offset); + + // physical_strides = physical tensor's batch strides + (logical) strides + auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims); + VmapDimVector physical_strides; + physical_strides.reserve(num_batch_dims + strides.size()); + physical_strides.insert( + physical_strides.end(), batch_strides.begin(), batch_strides.end()); + physical_strides.insert( + physical_strides.end(), strides.begin(), strides.end()); + + // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + // is valid for all i, then it turns out that + // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds + // and creates a tensor y such that each y[i] references the same memory + // locations as zi. See NOTE: [When will the as_strided batching rule fail?] + auto result = physical_view.tensor().as_strided( + physical_sizes, physical_strides, storage_offset); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +// NOTE: [When will the as_strided batching rule fail?] +// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid for all i, then it turns out that +// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and +// creates a tensor y such that each y[i] refers to the same memory as zi. +// +// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()). +// Furthermore, let's say that as a part of being "valid" this as_strided call +// does not return a result that can index memory not indexable by xs[i]. +// +// WLOG, assume that there's only one batch dim and it is at the front of the +// `xs` tensor. Let B be the batch size and S be the stride of the batch dim. +// - If the batch dim isn't at the front of the tensor, then we can just move it +// to the front with movedim/permute. This is always valid because it just swaps +// some strides around. +// - This proof also works for tensors with multiple batch dims. We just have to +// do a little accounting: +// - instead of [B], we'd have [B0, B1, ..., Bk]. +// - instead of [S], we'd have [S0, S1, ..., Sk]. +// - instead of i, we'd have a list of indices [I0, I1, ..., Ik] +// - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i +// +// [Equation 1] +// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has: +// - sizes: sizes +// - strides: strides +// - offset: offset + S * i +// +// x.as_strided itself checks that: +// - (sizes, strides, offset) are in bounds for `x`'s storage. +// - strides are positive +// - offset is positive +// +// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid, then +// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage. +// +// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset) +// won't error out. So all we need to check is that the memory locations are +// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important) +// +// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to +// xs.as_strided([B] + sizes, [S] + strides, offset) +// +// xs.as_strided([B] + sizes, [S] + strides, offset) has: +// - sizes: [B] + sizes +// - strides: [S] + strides +// - offset: offset +// +// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has: +// - sizes: sizes +// - strides: strides +// - offset: offset + S * i +// These memory locations are exactly the same as what we got for [Equation 1], +// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid. +// +// [Hand-wavy proof of Claim 1] +// Part of our definition of being valid is that xs[i].as_strided(...) +// must return a tensor that only uses memory indexable by xs[i]. +// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies: +// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] +// <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// (the largest-index memory location of xs[i].as_strided(...) must be \leq +// the largest-index memory location of xs[i]) +// +// Fiddling that inequality gives us: +// offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// +// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// +// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) +// +// offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) +// (the largest-index memory location of xs.as_strided(size, stride, offset) +// is \leq than the largest-index memory location of xs) +// Under the assumptions we've made, the lower bound (lowest indexed memory) +// is trivially within the storage. +// +// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for +// `xs`'s storage. + +template +Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) { + if (!participatesInCurrentLevel(input)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return Func(input, args...); + } + // guard against the user passing in a batch of scalar tensors with batch + auto* input_batched = unsafeGetBatchedImpl(input); + auto output_physical = Func(input_batched->value(), args...); + return makeBatched(output_physical, input_batched->bdim(), input_batched->level()); +} + +template +Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) { + if (!participatesInCurrentLevel(input)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return (input.*Func)(extra_args...); + } + auto* input_batched = unsafeGetBatchedImpl(input); + auto output_physical = (input_batched->value().*Func)(extra_args...); + return makeBatched(output_physical, input_batched->bdim(), input_batched->level()); +} + +Tensor cat_batching_rule(TensorList tensors, int64_t dim) { + if (!participatesInCurrentLevel(tensors)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::cat(tensors, dim); + } + auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); + auto physical_tensors = fmap( + physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); + TORCH_INTERNAL_ASSERT( + tensors.size() > 0, "The dispatcher should not have dispatched here otherwise."); + auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim)); + return physical_views[0].getPhysicalToLogicalMap().apply(result); +} + +Tensor block_diag_batching_rule(TensorList tensors) { + if (!participatesInCurrentLevel(tensors)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::block_diag(tensors); + } + auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); + auto physical_tensors = fmap( + physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); + TORCH_INTERNAL_ASSERT( + tensors.size() > 0, "The dispatcher should not have dispatched here otherwise."); + // Implementing this as a dummy for loop for now, since I'm not sure how to do it any better. + // I'm probably not accounting for potentially multiple batched dimensions? + auto bdim = physical_tensors[0].size(0); + std::vector batched_outputs; + batched_outputs.reserve(bdim); + for (const auto& i : c10::irange(bdim)) { + std::vector inputs_for_batch; + inputs_for_batch.reserve(physical_tensors.size()); + for (const auto& t : physical_tensors) { + inputs_for_batch.push_back(t[i]); + } + auto out_for_batch = at::block_diag(inputs_for_batch); + batched_outputs.push_back(out_for_batch.unsqueeze(0)); + } + auto result = at::cat(batched_outputs); + return physical_views[0].getPhysicalToLogicalMap().apply(result); +} + +Tensor stack_batching_rule(TensorList tensors, int64_t dim) { + if (!participatesInCurrentLevel(tensors)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return at::stack(tensors, dim); + } + auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); + auto physical_tensors = fmap( + physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); + TORCH_INTERNAL_ASSERT( + tensors.size() > 0, "The dispatcher should not have dispatched here otherwise."); + // NB: stack wraps the dimensionality to (logical dim + 1), so we have to + // manually handle that here. + auto dim_physical = + physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1); + auto result = at::stack(physical_tensors, dim_physical); + return physical_views[0].getPhysicalToLogicalMap().apply(result); +} + +Tensor new_empty_strided_batching_rule( + const Tensor& self, + IntArrayRef size, + IntArrayRef stride, + optional dtype, + optional layout, + optional device, + optional pin_memory) { + if (!participatesInCurrentLevel(self)) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + return self.new_empty_strided( + size, stride, dtype, layout, device, pin_memory); + } + + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_size = physical_view.getPhysicalShape(size); + + // Let [B0, B1, B2] be the shape of the batch dims. We're going to create + // the batch dimensions at the front of the tensor (in memory layout), + // irrespective of whether or not they are actually at the front (in memory layout) + // in the original `self` tensor. This is because when a user calls + // `new_empty_strided` in general, the `strides` they provide are for a new + // tensor and have no relation to the strides of the original tensor. + // + // So, the physical shape of the result should be ([B0, B1, B2] + size), + // but what about the physical strides? + // + // We're actually free to pick whatever stride we want: + // e.g., for size=[5, 3], stride=[0, 1], we could decide to + // use + // - physical size: [B0, B1, B2, 5, 3] + // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1] + // + // Let's select some reasonable strides such that: + // - The batch dims are "contiguous" with respect to each other + // - if empty_strided(size, stride) would have created a contiguous Tensor, + // then this new physical Tensor (with batch dims) is also contiguous + // + // Let S be the size of the storage if one were to construct a tensor + // with `size` and `stride` via empty_strided(size, stride). + // Then the physical sizes/strides should be: + // - physical size: [B0, B1, B2, 5, 3] + // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1] + auto batch_shape = IntArrayRef( + physical_view.tensor().sizes().begin(), physical_view.numBatchDims()); + + // physical_strides = [B1 * B2 * S, B2 * S, S] + auto physical_strides = at::detail::defaultStrides(batch_shape); + TORCH_CHECK(size.size() == stride.size(), + "new_empty_strided(sizes, strides): dimensionality of sizes (", + size.size(), ") must match dimensionality of strides (", + stride.size(), ")"); + auto storage_size = native::storage_size_for(size, stride); + for (auto& physical_stride : physical_strides) { + physical_stride *= storage_size; + } + + // physical_strides = [B1 * B2 * S, B2 * S, S] + strides + physical_strides.insert(physical_strides.end(), stride.begin(), stride.end()); + + auto result = physical_view.tensor().new_empty_strided( + physical_size, physical_strides, dtype, layout, device, pin_memory); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +bool BatchedTensor_is_leaf(const Tensor& self) { + if (torch::autograd::impl::get_autograd_meta(self)) { + return torch::autograd::impl::get_autograd_meta(self)->grad_fn_ == nullptr; + } else { + return true; + } +} + +Tensor& BatchedTensor_requires_grad_(Tensor& self, bool requires_grad) { + self.set_requires_grad(requires_grad); + return self; +} + + +TORCH_LIBRARY_IMPL(_, FT_BATCHED_KEY, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>()); +} + +TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { + // still legacy b/c teturns multiple tensors + m.impl("tensor_split.sections", tensor_split_sections_batching_rule); + m.impl("tensor_split.indices", tensor_split_indices_batching_rule); + m.impl("split.Tensor", split_batching_rule); + m.impl("split_with_sizes", split_with_sizes_batching_rule); + m.impl("unbind.int", unbind_batching_rule); + m.impl("cat", cat_batching_rule); + m.impl("block_diag", block_diag_batching_rule); + m.impl("stack", stack_batching_rule); + + // still legacy b/c needs special inplace rules + m.impl("squeeze_.dim", squeeze_dim__batching_rule); + m.impl("unsqueeze_", unsqueeze__batching_rule); + + // still legacy because these are ridiculously complicated + m.impl("as_strided", as_strided_batching_rule); + m.impl("new_empty_strided", new_empty_strided_batching_rule); + +} +} // namespace functorch +} // namespace at diff --git a/functorch/functorch/csrc/LegacyVmapTransforms.cpp b/functorch/functorch/csrc/LegacyVmapTransforms.cpp new file mode 100644 index 0000000000000..3b57bd35e52ed --- /dev/null +++ b/functorch/functorch/csrc/LegacyVmapTransforms.cpp @@ -0,0 +1,222 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +namespace at { +namespace functorch { + +// Takes a BatchedTensorImpl, permutes all of the batch dims to the front, +// and then returns a physical version of the Tensor. +static Tensor permuteBatchDimsToFront(const BatchedTensorImpl* batched) { + const Tensor& physical_tensor = batched->value(); + if (batched->bdim() == 0) { + return physical_tensor; + } + const auto sizes = physical_tensor.sizes(); + VmapDimVector permutation(sizes.size(), 0); + permutation.reserve(sizes.size()); + const auto is_bdim = createBatchDimBitset(batched->bdim()); + int64_t idx = 0; + permutation[idx++] = batched->bdim(); + for (const auto ptr : c10::irange(0, sizes.size())) { + if (is_bdim[ptr]) { + continue; + } + permutation[idx++] = ptr; + } + return physical_tensor.permute(permutation); +} + +VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) { + auto* batched = maybeGetBatchedImpl(logical_tensor); + TORCH_INTERNAL_ASSERT( + batched, + "logicalToPhysical(tensor) should only be passed a BatchedTensor"); + return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->level()) }; +} + +int64_t VmapPhysicalView::numBatchDims() const { + return levels_.count(); +} + +int64_t VmapPhysicalView::numLogicalDims() const { + return /*physical*/tensor_.dim() - numBatchDims(); +} + +VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const { + auto logical_ndim = numLogicalDims(); + // NB: fmap doesn't have a SmallVector variant, so we don't use it here. + VmapDimVector result; + result.reserve(logical_ndim); + for (auto dim : logical_dims) { + result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims()); + } + return result; +} + +int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const { + auto logical_ndim = numLogicalDims(); + return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims(); +} + +VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const { + VmapDimVector result; + result.reserve(logical_shape.size() + numBatchDims()); + auto tensor_sizes = tensor_.sizes(); + result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims()); + result.insert(result.end(), logical_shape.begin(), logical_shape.end()); + return result; +} + +static std::tuple computeFrontBatchDimsFromLevels(std::bitset levels_bitset) { + int64_t level = 0; + int64_t dim = 0; + for (; level < kVmapNumLevels; level++) { + if (!levels_bitset[level]) { + continue; + } + break; + } + return std::make_tuple(dim, level); +} + +static Tensor moveDimToFrontAndExpand(Tensor tensor, optional dim, int64_t size) { + if (dim) { + tensor = tensor.movedim(*dim, 0); + } else { + tensor = tensor.unsqueeze(0); + auto expanded_sizes = tensor.sizes().vec(); + expanded_sizes[0] = size; + tensor = tensor.expand(expanded_sizes); + } + return tensor; +} + +// The algorithm is as follows: +// 1. Figure out what all of the collective levels in `logical_tensors` is. +// 2. Move all batch dims to the front of the tensors and add extra dims +// of size 1. At this point, every tensor will have a dimension for +// each of the collective levels. +// 3. Compute the batch_sizes. +// 4. Expand each physical tensor so that they have output batch size equal +// to `batch_sizes` +VmapPhysicalViewVec +MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) { + auto cur_level = maybeCurrentDynamicLayer().value().layerId(); + auto bdim_size = -1; + + // Figure out the batch size first + for (const auto& logical_tensor : logical_tensors) { + auto* batched = maybeGetBatchedImpl(logical_tensor); + if (!batched) { + continue; + } + if (batched->level() != cur_level) { + continue; + } + bdim_size = batched->value().size(batched->bdim()); + } + TORCH_INTERNAL_ASSERT(bdim_size != -1); + + std::bitset levels; + levels[cur_level] = 1; + + VmapPhysicalViewVec result; + for (const auto& logical_tensor : logical_tensors) { + auto* batched = maybeGetBatchedImpl(logical_tensor); + if (!batched || (batched->level() != cur_level)) { + // Unsqueeze dim 0, expand it to the correct shape + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size); + result.emplace_back(std::move(value), levels); + continue; + } + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto physical = batched->value(); + auto value = moveDimToFrontAndExpand(physical, batched->bdim(), bdim_size); + result.emplace_back(std::move(value), levels); + } + + return result; +} + +static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, optional dim, int64_t example_ndim) { + if (dim) { + tensor = tensor.movedim(*dim, 0); + } else { + tensor = tensor.unsqueeze(0); + } + auto ndim = tensor.dim() - 1; + for (int64_t i = 0; i < example_ndim - ndim; i++) { + tensor = tensor.unsqueeze(1); + } + return tensor; +} + +VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) { + auto cur_level = maybeCurrentDynamicLayer().value().layerId(); + auto bdim_size = -1; + + // Figure out the batch size first + for (const auto& logical_tensor : logical_tensors) { + auto* batched = maybeGetBatchedImpl(logical_tensor); + if (!batched || (batched->level() != cur_level)) { + continue; + } + bdim_size = batched->value().size(batched->bdim()); + } + TORCH_INTERNAL_ASSERT(bdim_size != -1); + + std::bitset levels; + levels[cur_level] = 1; + + // figure out the example ndim + int64_t max_example_dim = -1; + for (const auto& logical_tensor : logical_tensors) { + max_example_dim = std::max(logical_tensor.dim(), max_example_dim); + } + + VmapPhysicalViewVec result; + for (const auto& logical_tensor : logical_tensors) { + auto* batched = maybeGetBatchedImpl(logical_tensor); + if (!batched || (batched->level() != cur_level)) { + // Unsqueeze dim 0, expand it to the correct shape + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim); + result.emplace_back(std::move(value), levels); + continue; + } + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto physical = batched->value(); + auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdim(), max_example_dim); + result.emplace_back(std::move(value), levels); + } + + return result; +} + +VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const { + return VmapPhysicalToLogicalMap(levels_); +} + +Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const { + auto bdim_level = computeFrontBatchDimsFromLevels(levels_); + return makeBatched(physical_tensor, std::get<0>(bdim_level), std::get<1>(bdim_level)); +} + +void VmapPhysicalToLogicalMap::applyInplace(std::vector& physical_tensors) const { + for (const auto idx : c10::irange(0, physical_tensors.size())) { + physical_tensors[idx] = apply(physical_tensors[idx]); + } +} + +} +} // namespace at diff --git a/functorch/functorch/csrc/LegacyVmapTransforms.h b/functorch/functorch/csrc/LegacyVmapTransforms.h new file mode 100644 index 0000000000000..443c4e867de26 --- /dev/null +++ b/functorch/functorch/csrc/LegacyVmapTransforms.h @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace at { +namespace functorch { + +// This files contains the legacy (now-deprecated) batching rule API. +// Please try to use the new-style batching rule API (see writing_batch_rules.md) + +// This file contains abstractions used for transforming *logical* vmap arguments +// into *physical* arguments. (Keep reading for definitions of these terms). + +// NOTE: [Logical vs physical args] +// Consider the following vmap. +// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) +// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], +// with batch dims 0 and 2: +// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) +// +// We say the *logical* view of the tensor has size [3] -- tensors inside +// `func` appear to have size [3]. +// However, the *physical* underlying tensor (the one passed to vmap) has size +// [2, 3, 4]. +// +// This notion of logical vs physical also extends to non-tensor arguments. +// Consider the previous tensor; let's assume the user called +// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical +// dimension they are reducing over is dim 0 but the physical dim is dim 1 +// (the first non-batch dimension) + +// Forward declared; see NOTE: [What is a VmapPhysicalView?] +struct VmapPhysicalView; + +// Most PyTorch operators take 4 or fewer inputs. +constexpr int64_t kVmapTransformStaticInputSize = 4; +using VmapPhysicalViewVec = SmallVector; + +// Pytorch generally advertises good performance for <= 5 dims. +// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap +// dimensions to get 8. Adjust this number as necessary +constexpr int64_t kVmapStaticDimVecSize = 8; +using VmapDimVector = SmallVector; + +// NOTE: [What is an VmapTransform?] +// An *VmapTransform* converts logical views of tensors to physical views. +// +// Batching rules use VmapTransforms to convert logical arguments to +// physical arguments, then call one or more at:: operator that handles the +// physical arguments, and then converts the physical result back to a logical +// argument. + +// VmapTransform for operators that take tensors with multiple batch dims. +// Given one or more logical views on Tensors, `logicalToPhysical` +// permutes all of the batch dims to the front of the tensor, aligns +// and expands the batch dims to match each other (according to their `level`), +// and returns a VmapPhysicalView on the tensor(s). +struct FUNCTORCH_API MultiBatchVmapTransform { + static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// VmapTransform for operators that broadcast all inputs. +// Given some logical views on Tensors, `logicalToPhysical`: +// - permutes all of the batch dims to the front of the tensors +// - aligns all the batch dims to the collective levels of all of the tensors. +// If a tensor does not have a batch dim for a vmap level, then it receives +// a size-one dimension for said level. +// - aligns the non-batch dims to have the same dimensionality, adding extra +// size-1 dimensions in between the batch dimensions and the non-batch dimensions +// so that the batch dimensions are lined up from the right. +// +// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch +// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors +// of size (B, 1, 2) and (B, 3, 2). +// +// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns +// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't +// actually *need* to return a tensor of size (1, 2) for the second tensor +// because the broadcasting operation takes care of that for us, but we do +// it anyways to keep things simple. +struct FUNCTORCH_API BroadcastingVmapTransform { + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + +// NOTE: [What is a VmapPhysicalView?] +// VmapPhysicalView represents a physical view on a Tensor. +// +// One can use it to further convert logical dimension indices, logical shapes, +// and more to their physical variants, or convert a new (physical) tensor into +// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). +// +// VmapPhysicalView stores a physical tensor with all of its batch dimensions at +// the front and some levels that correspond to said batch dimensions. +// +// The levels bitset specifies which vmap levels correspond to the batch +// dimensions at the front of the tensor. In particular, the number of set bits +// corresponds to the number of batch dimensions on `tensor` and the rightmost +// bit of `levels` specifies the maximum number of nested vmaps we are in at +// this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 +struct FUNCTORCH_API VmapPhysicalView { + VmapPhysicalView(Tensor&& tensor, std::bitset levels) + : levels_(levels), tensor_(tensor) { + // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor)); + } + + Tensor& tensor() { return tensor_; } + const Tensor& tensor() const { return tensor_; } + + // Maps logical dim indices to physical dim indices. Also does dim wrapping. + // + // For example, given: + // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) + // + // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. + // This is because the size of levels tell us that the first two dimensions + // of `tensor_` are batch dimensions, so a logical dim of `n` is actually + // a physical dim of `n + 2`. + VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; + int64_t getPhysicalDim(int64_t logical_dim) const; + + // Returns a VmapPhysicalToLogicalMap object. This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + + // Maps a logical shape to a physical shape by pre-pending the batch + // sizes to the logical shape. + VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; + + int64_t numBatchDims() const; + + private: + int64_t numLogicalDims() const; + + std::bitset levels_; + Tensor tensor_; +}; + +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do the +// mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. +struct FUNCTORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + + +} +} // namespace at diff --git a/functorch/functorch/csrc/Macros.h b/functorch/functorch/csrc/Macros.h new file mode 100644 index 0000000000000..9ca13023fc92f --- /dev/null +++ b/functorch/functorch/csrc/Macros.h @@ -0,0 +1,10 @@ +#pragma once + +// FUNCTORCH_BUILD_MAIN_LIB is set in setup.py. +// We don't really need to use C10_IMPORT because no C++ project relies on +// functorch. But leaving it here for future-proofing. +#ifdef FUNCTORCH_BUILD_MAIN_LIB +#define FUNCTORCH_API C10_EXPORT +#else +#define FUNCTORCH_API C10_IMPORT +#endif diff --git a/functorch/functorch/csrc/PlumbingHelper.cpp b/functorch/functorch/csrc/PlumbingHelper.cpp new file mode 100644 index 0000000000000..e75fb82a38642 --- /dev/null +++ b/functorch/functorch/csrc/PlumbingHelper.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +namespace at { namespace functorch { + +Tensor makeBatched(const Tensor& tensor, optional bdim, int64_t level) { + if (bdim.has_value()) { + TORCH_INTERNAL_ASSERT(*bdim >= 0); + TORCH_INTERNAL_ASSERT(*bdim < tensor.dim()); + return makeBatched(tensor, bdim.value(), level); + } + return tensor; +} + +std::vector makeBatchedVector(const std::vector& tensors, optional bdim, int64_t level) { + std::vector res; + for (const auto & tensor : tensors) { + res.emplace_back(makeBatched(tensor, bdim, level)); + } + return res; +} + +std::tuple> unwrapTensorAtLevel(const Tensor& tensor, int64_t level) { + auto* batched = maybeGetBatchedImpl(tensor); + if (!batched) { + return std::make_tuple(tensor, nullopt); + } + if (batched->level() == level) { + return std::make_tuple(batched->value(), batched->bdim()); + } + return std::make_tuple(tensor, nullopt); +} + +bool isBatchedAtLevel(const Tensor& tensor, int64_t level) { + auto result = unwrapTensorAtLevel(tensor, level); + return std::get<1>(result).has_value(); +} + +bool isBatchedAtLevel(const c10::optional& maybe_tensor, int64_t level) { + if (!maybe_tensor.has_value()) { + return false; + } + return isBatchedAtLevel(*maybe_tensor, level); +} + +bool isBatchedAtLevel(TensorList tensors, int64_t level) { + for (const auto& tensor : tensors) { + if (isBatchedAtLevel(tensor, level)) { + return true; + } + } + return false; +} + +bool isBatchedAtLevel(const c10::List> maybe_tensors, int64_t level) { + for (const auto idx : c10::irange(0, maybe_tensors.size())) { + const auto& maybe_tensor = maybe_tensors.get(idx); + if (isBatchedAtLevel(maybe_tensor, level)) { + return true; + } + } + return false; +} + +bool areAnyBatchedAtLevel(ArrayRef> maybe_tensors, int64_t level) { + for (const auto& maybe_tensor : maybe_tensors) { + if (isBatchedAtLevel(maybe_tensor, level)) { + return true; + } + } + return false; +} + + +}} diff --git a/functorch/functorch/csrc/PlumbingHelper.h b/functorch/functorch/csrc/PlumbingHelper.h new file mode 100644 index 0000000000000..8a8441c3bb29c --- /dev/null +++ b/functorch/functorch/csrc/PlumbingHelper.h @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include +#include +#include +#include + +namespace at { namespace functorch { + +Tensor makeBatched(const Tensor& tensor, optional bdim, int64_t level); +std::tuple> unwrapTensorAtLevel(const Tensor& tensor, int64_t level); + +std::vector makeBatchedVector(const std::vector& tensors, optional bdim, int64_t level); + +// Returns True if ANY tensor in tensors is batched at level +bool isBatchedAtLevel(TensorList tensors, int64_t level); +bool isBatchedAtLevel(const c10::List> maybe_tensors, int64_t level); +bool isBatchedAtLevel(const Tensor& tensor, int64_t level); +bool isBatchedAtLevel(const c10::optional& maybe_tensor, int64_t level); + +// Convenience helper. Returns true if any tensor is batched at level +bool areAnyBatchedAtLevel(ArrayRef> maybe_tensors, int64_t level); + +inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) { + if (ivalue.isTensor()) { + auto maybe_level = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_level.has_value()); + auto current_level = maybe_level->layerId(); + return isBatchedAtLevel(ivalue.toTensor(), current_level); + } + // TODO: should really check this + return false; +} + +}} diff --git a/functorch/functorch/csrc/PyTorchOperatorHacks.cpp b/functorch/functorch/csrc/PyTorchOperatorHacks.cpp new file mode 100644 index 0000000000000..e2f35bea4ef63 --- /dev/null +++ b/functorch/functorch/csrc/PyTorchOperatorHacks.cpp @@ -0,0 +1,398 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace functorch { + +// TODO: all of these should be fixed in a more blessed way. In particular, +// it is bad if any of these go out-of-sync with the implementations in +// pytorch/pytorch. +// +// This file contains hacks for composite PyTorch operators that are problematic. +// For example, the composite op might have in-place operations, +// or call data_ptr. We have some idea of how to fix these things in the long term +// (e.g. functionalization for the in-place operations). + +// TODO: can replace with better conditional functionalization +static Tensor value_selecting_reduction_backward_hack( + const Tensor& grad, + int64_t dim, + const Tensor& indices, + IntArrayRef sizes, + bool keepdim) { + if (!keepdim && sizes.size() > 0) { + auto grad_ = grad.unsqueeze(dim); + auto indices_ = indices.unsqueeze(dim); + return at::zeros(sizes, grad_.options()).scatter(dim, indices_, grad_); + } + return at::zeros(sizes, grad.options()).scatter(dim, indices, grad); +} + +// TODO: upstream into core +Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) { + return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad); +} + +// TODO: https://github.com/pytorch/pytorch/issues/69991 +Tensor frobenius_norm_dim_hack(const Tensor& self, IntArrayRef dim, bool keepdim) { + if (dim.size() == 1 || dim.size() == 0) { + return at::norm(self, 2, dim, keepdim); + } else { + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead"); + if (self.is_complex()){ + return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim)); + } else { + return at::sqrt(at::sum((self * self), dim_, keepdim)); + } + } +} + +static optional> unwrap(const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (wrapped) { + if (wrapped->level().has_value()) { + return std::make_tuple(wrapped->value(), *wrapped->level()); + } + return unwrap(wrapped->value()); + } + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + return std::make_tuple(batched->value(), batched->level()); + } + return nullopt; +} + +static bool can_perform_inplace(const Tensor& a, const Tensor& b) { + // TODO: generalize this to more transforms + auto a_ = unwrap(a); + auto b_ = unwrap(b); + if (!a_.has_value() && b_.has_value()) { + return false; + } + if (!a_.has_value() && !b_.has_value()) { + return true; + } + if (a_.has_value() && !b_.has_value()) { + return true; + } + TORCH_INTERNAL_ASSERT(a_.has_value() && b_.has_value()); + + // If b has any wrapper that a does not, then we cannot do a.inplace_(b) + if (std::get<1>(*a_) < std::get<1>(*b_)) { + return false; + } + if (std::get<1>(*a_) > std::get<1>(*b_)) { + return can_perform_inplace(std::get<0>(*a_), b); + } + return can_perform_inplace(std::get<0>(*a_), std::get<0>(*b_)); +} + +// TODO: linear is pretty important for performance, but I'm not sure how to work +// around the in-place. +Tensor linear_hack(const Tensor& input, const Tensor& weight, const c10::optional& bias_opt) { + // See [Note: hacky wrapper removal for optional tensor] + auto bias = bias_opt.has_value() + ? c10::MaybeOwned::borrowed(*bias_opt) + : c10::MaybeOwned::owned(c10::in_place); + + if (input.is_mkldnn()) { + return at::mkldnn_linear(input, weight, *bias); + } +#if defined(C10_MOBILE) + if (xnnpack::use_linear(input, weight, *bias)) { + return xnnpack::linear(input, weight, *bias); + } +#endif + if (input.dim() == 2 && bias->defined()) { + // Fused op is marginally faster. + return at::addmm(*bias, input, weight.t()); + } + if (input.dim() == 3 && bias->defined() && input.is_contiguous()) { + // Also hit the fused path for contiguous 3D input. + const auto input_sizes = input.sizes(); + const auto result = at::addmm(*bias, input.view({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t()); + return result.view({input_sizes[0], input_sizes[1], result.size(1)}); + } + auto output = at::matmul(input, weight.t()); + if (bias->defined()) { + const auto& stack = getDynamicLayerStack(); + bool any_vmap_layers = std::any_of( + stack.begin(), stack.end(), + [](const DynamicLayer& dl){ return dl.key() == TransformType::Vmap; }); + if (any_vmap_layers) { + return output.add(*bias); + } + return output.add_(*bias); + } + return output; +} + +Tensor nuclear_norm_dim_hack(const Tensor& self, IntArrayRef dim, bool keepdim) { + TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + + auto permutation = at::native::create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); + Tensor p = self.permute(permutation); + Tensor result = at::sum(at::linalg_svdvals(p), -1, keepdim); + if (keepdim) { + result = result.unsqueeze(-1); + auto permutation_reverse = at::native::create_reverse_permutation(permutation); + result = result.permute(permutation_reverse); + } + return result; +} + +Tensor nuclear_norm_hack(const Tensor& self, bool keepdim) { + TORCH_CHECK( + self.dim() == 2, + "Expected a tensor with 2 dimensions, but got a tensor with ", + self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead."); + + return nuclear_norm_dim_hack(self, {0, 1}, keepdim); +} + +Tensor binary_cross_entropy_with_logits_backward_hack( + const Tensor& grad, const Tensor& input, const Tensor& target, + const c10::optional& weight_opt, + const c10::optional& pos_weight_opt, int64_t reduction) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();}); + + Tensor grad_input; + if (pos_weight.defined()) { + auto t = pos_weight.mul(target); + grad_input = t.add(1).sub_(target).mul(input.sigmoid()).sub_(t).mul(grad); + } else { + grad_input = (input.sigmoid() - target).mul(grad); + } + + if (weight.defined()) { + grad_input.mul(weight); + } + + if (reduction == at::Reduction::Mean) { + return grad_input / input.numel(); + } + + return grad_input; +} + +static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) { + if (reduction == at::Reduction::Mean) { + return unreduced.mean(); + } else if (reduction == at::Reduction::Sum) { + return unreduced.sum(); + } + return unreduced; +} + +Tensor binary_cross_entropy_with_logits_hack( + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + const c10::optional& pos_weight_opt, + int64_t reduction) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();}); + + Tensor loss; + auto max_val = (-input).clamp_min(0); + if (pos_weight.defined()) { + // pos_weight need to be broadcasted, thus mul(target) is not inplace. + auto log_weight = (pos_weight - 1).mul(target).add_(1); + loss = (1 - target).mul(input).add(log_weight.mul(((-max_val).exp_().add((-input - max_val).exp_())).log_().add_(max_val))); + } else { + loss = (1 - target).mul(input).add_(max_val).add_((-max_val).exp_().add((-input -max_val).exp_()).log_()); + } + + if (weight.defined()) { + loss = loss * weight; + } + + return apply_loss_reduction(loss, reduction); +} + +Tensor trace_backward_decomp(const Tensor& grad, IntArrayRef sizes) { + if (sizes.size() != 2) { + throw std::runtime_error("expected matrix input"); + } + auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options()); + auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong)); + // Workaround using index_put instead of yet unsupported index_fill_ + grad_input = grad_input.index_put({indices}, grad); + return grad_input.view(sizes); +} + +// dropout hack +// TODO: make the following changes in pytorch/pytorch +namespace dropout_hack { + +template +using Ctype = typename std::conditional::type; + +Tensor make_feature_noise(const Tensor& input) { + auto input_sizes = input.sizes(); + TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input"); + std::vector sizes; + sizes.reserve(input.dim()); + sizes.push_back(input_sizes[0]); + sizes.push_back(input_sizes[1]); + for (const auto i : c10::irange(2, input.dim())) { + (void)i; //Suppress unused variable warning + sizes.push_back(1); + } + // NB: THIS WAS CHANGED FROM THE ORIGINAL + return at::empty(sizes, input.options()); +} + +bool is_fused_kernel_acceptable(const Tensor& input, double p) { + return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0; +} + +// NB: sure, we could have used different overloads here, but I would feel insecure +// knowing that this dispatch depends only on the constness of the references +template +Tensor& multiply(Tensor& input, const Tensor& noise) { + static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp"); + return input.mul_(noise); +} + +template +Tensor multiply(const Tensor& input, const Tensor& noise) { + static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp"); + return input.mul(noise); +} + +template +Ctype _dropout_impl(T& input, double p, bool train) { + TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p); + if (p == 0 || !train || input.numel() == 0) { + return input; + } + + if (p == 1) { + return multiply(input, at::zeros({}, input.options())); + } + + at::Tensor b; // used for alpha_dropout only + + // NB: THIS WAS CHANGED FROM THE ORIGINAL + Tensor noise; + if (feature_dropout) { + auto empty = make_feature_noise(input); + noise = at::bernoulli(empty, 1 - p); + } else { + // NB: it is important that this is at::empty and not at::empty_like + auto empty = at::empty({}, input.options()).expand(input.sizes()); + noise = at::bernoulli(empty, 1 - p); + } + + if (alpha_dropout) { + constexpr double alpha = 1.7580993408473766; + double a = 1. / std::sqrt((alpha * alpha * p + 1) * (1 - p)); + b = noise.add(-1).mul_(alpha * a).add_(alpha * a * p); + noise.mul_(a); + } else { + noise.div_(1 - p); + } + + if (!alpha_dropout) { + return multiply(input, noise); + } else { + return multiply(input, noise).add_(b); + } +} + +#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA) \ +template \ +Ctype ALIAS_NAME(Args&&... args) { \ + return _dropout_impl(std::forward(args)...); \ +} + +ALIAS_SPECIALIZATION(_dropout, false, false) +ALIAS_SPECIALIZATION(_feature_dropout, true, false) +ALIAS_SPECIALIZATION(_alpha_dropout, false, true ) +ALIAS_SPECIALIZATION(_feature_alpha_dropout, true, true ) + +Tensor dropout(const Tensor& input, double p, bool train) { + auto result = [&]() { + NoNamesGuard guard; + if (train && is_fused_kernel_acceptable(input, p)) { + return std::get<0>(at::native_dropout(input, p, train)); + } + return _dropout(input, p, train); + }(); + namedinference::propagate_names(result, input); + return result; +} + +Tensor& dropout_(Tensor& input, double p, bool train) { + return _dropout(input, p, train); +} + +Tensor feature_dropout(const Tensor& input, double p, bool train) { + return _feature_dropout(input, p, train); +} + +Tensor& feature_dropout_(Tensor& input, double p, bool train) { + return _feature_dropout(input, p, train); +} + +Tensor alpha_dropout(const Tensor& input, double p, bool train) { + return _alpha_dropout(input, p, train); +} + +Tensor& alpha_dropout_(Tensor& input, double p, bool train) { + return _alpha_dropout(input, p, train); +} + +Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) { + return _feature_alpha_dropout(input, p, train); +} + +Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) { + return _feature_alpha_dropout(input, p, train); +} + +} // dropout_hack + +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + m.impl("value_selecting_reduction_backward", value_selecting_reduction_backward_hack); + m.impl("index_select_backward", index_select_backward_hack); + m.impl("frobenius_norm.dim", frobenius_norm_dim_hack); + m.impl("linear", linear_hack); + m.impl("binary_cross_entropy_with_logits_backward", binary_cross_entropy_with_logits_backward_hack); + m.impl("binary_cross_entropy_with_logits", binary_cross_entropy_with_logits_hack); + m.impl("trace_backward", trace_backward_decomp); + + m.impl("dropout", dropout_hack::dropout); + m.impl("feature_dropout", dropout_hack::feature_dropout); + m.impl("alpha_dropout", dropout_hack::alpha_dropout); + m.impl("feature_alpha_dropout", dropout_hack::feature_alpha_dropout); + + m.impl("dropout_", dropout_hack::dropout_); + m.impl("feature_dropout_", dropout_hack::feature_dropout_); + m.impl("alpha_dropout_", dropout_hack::alpha_dropout_); + m.impl("feature_alpha_dropout_", dropout_hack::feature_alpha_dropout_); + + m.impl("nuclear_norm", nuclear_norm_hack); + m.impl("nuclear_norm.dim", nuclear_norm_dim_hack); +} + +}} diff --git a/functorch/functorch/csrc/TensorWrapper.cpp b/functorch/functorch/csrc/TensorWrapper.cpp new file mode 100644 index 0000000000000..054be6495c37e --- /dev/null +++ b/functorch/functorch/csrc/TensorWrapper.cpp @@ -0,0 +1,192 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include +#include + +namespace at { +namespace functorch { + +void dumpTensor(std::ostream& ss, const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (!wrapped) { + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", "; + dumpTensor(ss, batched->value()); + ss << "]"; + return; + } + ss << "Tensor" << tensor.sizes(); + return; + } + ss << "Wrapper["; + if (wrapped->level().has_value()) { + ss << "lvl=" << wrapped->level().value() << ", "; + } else { + ss << "dead, "; + } + dumpTensor(ss, wrapped->value()); + ss << "]"; +} + +void TensorWrapper::refreshMetadata() { + auto dim = value_.dim(); + auto sizes = value_.sizes(); + auto strides = value_.strides(); + storage_offset_ = value_.storage_offset(); + sizes_and_strides_.resize(value_.dim()); + for (int64_t i = 0; i < dim; i++) { + sizes_and_strides_.size_at_unchecked(i) = sizes[i]; + sizes_and_strides_.stride_at_unchecked(i) = strides[i]; + } + + refresh_numel(); + refresh_contiguous(); +} + +void dumpTensorCout(const Tensor& tensor) { + dumpTensor(std::cout, tensor); + + std::cout << std::endl; +} + +c10::intrusive_ptr makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) { + auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); + key_set = key_set.add(kGradWrapperKey); + if (should_be_alive) { + return c10::make_intrusive(key_set, tensor, level, getLifeHandleForLevel(level)); + } else { + return c10::make_intrusive(key_set, tensor, level, std::make_shared(false)); + } +} + +Tensor makeTensorWrapper(const Tensor& tensor, int64_t level) { + auto wrapped = maybeGetTensorWrapper(tensor); + if (wrapped) { + TORCH_INTERNAL_ASSERT(wrapped->level() < level); + } + + auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); + key_set = key_set.add(kGradWrapperKey); + auto life_handle = getLifeHandleForLevel(level); + auto result = at::detail::make_tensor(key_set, tensor, level, std::move(life_handle)); + TORCH_INTERNAL_ASSERT(result.key_set().has(kGradWrapperKey)); + return result; +} + +bool TensorWrapper::is_alive() const { + return *is_alive_; +} + +c10::intrusive_ptr TensorWrapper::shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive()); + dest_impl->set_version_counter(version_counter); + + // TODO: is this even right? + dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return dest_impl; +} + +c10::intrusive_ptr TensorWrapper::shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive()); + dest_impl->set_version_counter(version_counter); + + // TODO: is this even right? + dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return dest_impl; +} + +void TensorWrapper::shallow_copy_from(const c10::intrusive_ptr& impl) { + TORCH_INTERNAL_ASSERT(false, "NYI"); +} + +TensorWrapper::TensorWrapper( + c10::DispatchKeySet key_set, + Tensor value, + int64_t level, + std::shared_ptr is_alive, + bool use_value_sizes_strides) + : TensorImpl(key_set, value.dtype(), value.device()) + , value_(std::move(value)) + , level_(level) + , is_alive_(std::move(is_alive)) +{ + TORCH_INTERNAL_ASSERT(value_.defined()); + + // TODO: need to reset sizes/strides on mutation + TORCH_INTERNAL_ASSERT(use_value_sizes_strides); + refreshMetadata(); + + set_storage_access_should_throw(); +} + +// The following are some internal inherited methods that we do not support. +// They should never get called. +void TensorWrapper::set_size(int64_t dim, int64_t new_size) { + TORCH_INTERNAL_ASSERT(false, "Can't set_size for TensorWrapper"); +} +void TensorWrapper::set_stride(int64_t dim, int64_t new_stride) { + TORCH_INTERNAL_ASSERT(false, "Can't set_stride for TensorWrapper"); +} +void TensorWrapper::set_storage_offset(int64_t storage_offset) { + TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper"); +} + +const char* TensorWrapper::tensorimpl_type_name() const { + return "TensorWrapper"; +} + + +TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor) { + if (!tensor.key_set().has(kGradWrapperKey)) { + return nullptr; + } + return (TensorWrapper*)(tensor.unsafeGetTensorImpl()); +} + +void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + auto args_size = op.schema().arguments().size(); + int64_t unwrapped_count = 0; + auto unwrapIfDeadAndIncrement = [&](const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (!wrapped) { + return tensor; + } + if (wrapped->is_alive()) { + return tensor; + } + unwrapped_count++; + return wrapped->value(); + }; + + foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrapIfDeadAndIncrement); + TORCH_INTERNAL_ASSERT(unwrapped_count > 0, "Should have at least one dead wrapper"); + + // re-dispatch + op.callBoxed(stack); +} + +// TensorWrapper backend fallback: Unwrap and fallthrough. + +TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>()); +} + +} +} // namespace at diff --git a/functorch/functorch/csrc/TensorWrapper.h b/functorch/functorch/csrc/TensorWrapper.h new file mode 100644 index 0000000000000..7abfe1782d385 --- /dev/null +++ b/functorch/functorch/csrc/TensorWrapper.h @@ -0,0 +1,68 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace at { +namespace functorch { + +struct FUNCTORCH_API TensorWrapper : public c10::TensorImpl { + explicit TensorWrapper( + c10::DispatchKeySet key_set, + Tensor value, + int64_t level, + std::shared_ptr is_alive, + bool use_value_sizes_strides = true); + + // Override a bunch of methods inherited from TensorImpl to return error messages + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + + void refreshMetadata(); + + const Tensor& value() const { + return value_; + } + optional level() const { + if (is_alive()) { + return level_; + } + return {}; + } + bool is_alive() const; + + // Overrides necessary for autograd + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + + private: + const char* tensorimpl_type_name() const override; + Tensor value_; + int64_t level_; + + // When we exit the level, this wrapper may be marked as "not alive". + // Wrappers that are not alive: + // 1) May still have autograd metadata on them + // 2) Forward dispatches to the underlying value() + std::shared_ptr is_alive_; +}; + +FUNCTORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level); +FUNCTORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor); +FUNCTORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor); +FUNCTORCH_API void dumpTensorCout(const Tensor& tensor); + +} +} // namespace at diff --git a/functorch/functorch/csrc/VmapInterpreter.cpp b/functorch/functorch/csrc/VmapInterpreter.cpp new file mode 100644 index 0000000000000..a8f0283aa3b7d --- /dev/null +++ b/functorch/functorch/csrc/VmapInterpreter.cpp @@ -0,0 +1,24 @@ +#include +#include + +namespace at { namespace functorch { + +void VmapInterpreterPtr::processImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(TransformType::Vmap); + setup_dispatch_key_tls(exclude, DispatchKeySet(kVmapModeKey)); + op.callBoxed(stack); +} + +void VmapInterpreterPtr::sendToNextInterpreterImpl( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + // Re-dispatch + if (getDynamicLayerStack().size() == 0) { + sanityCheckStack(op, stack); + } + op.callBoxed(stack); +} + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/VmapInterpreter.h b/functorch/functorch/csrc/VmapInterpreter.h new file mode 100644 index 0000000000000..084cea956b28e --- /dev/null +++ b/functorch/functorch/csrc/VmapInterpreter.h @@ -0,0 +1,22 @@ +#pragma once +#include + +namespace at { namespace functorch { + +struct VmapInterpreterPtr { + explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + int64_t batchSize() const { + return c10::get(base_->meta()).batchSize_; + } + RandomnessType randomness() const { + return c10::get(base_->meta()).randomness_; + } + private: + const Interpreter* base_; +}; + +}} // namespace at::functorch diff --git a/functorch/functorch/csrc/VmapModeRegistrations.cpp b/functorch/functorch/csrc/VmapModeRegistrations.cpp new file mode 100644 index 0000000000000..922b06e93db4a --- /dev/null +++ b/functorch/functorch/csrc/VmapModeRegistrations.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace functorch { + +void unsupportedRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, "vmap: We do not support calling out variants of random operations inside of vmap. ", + "Please use non-out variants as a workaround"); +} + +TORCH_LIBRARY_IMPL(_, FuncTorchVmapMode, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +void nyiRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, "vmap: we do not yet support ", op.schema().operator_name(), + ". Please file an issue"); +} + +#define UNSUPPORTED_RANDOM(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>()); + +#define UNSUPPORTED_RANDOM2(op, overload) \ + m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>()); + +#define NYI_RANDOM(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>()); + +#define NYI_RANDOM2(op, overload) \ + m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>()); + +TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { + UNSUPPORTED_RANDOM2(bernoulli, out); + UNSUPPORTED_RANDOM2(rand, generator_out); + UNSUPPORTED_RANDOM2(rand, out); + UNSUPPORTED_RANDOM2(randint, generator_out); + UNSUPPORTED_RANDOM2(randint, out); + UNSUPPORTED_RANDOM2(randn, generator_out); + UNSUPPORTED_RANDOM2(randn, out); + UNSUPPORTED_RANDOM2(randperm, generator_out); + UNSUPPORTED_RANDOM2(randperm, out); + UNSUPPORTED_RANDOM2(multinomial, out); + UNSUPPORTED_RANDOM2(normal, float_Tensor_out); + UNSUPPORTED_RANDOM2(normal, Tensor_Tensor_out); + UNSUPPORTED_RANDOM2(normal, float_float_out); + UNSUPPORTED_RANDOM2(rrelu_with_noise, out); + + NYI_RANDOM(rrelu_with_noise); + NYI_RANDOM(rrelu_with_noise_); + NYI_RANDOM(rrelu_); + NYI_RANDOM(rrelu); +} + + +} +} // namespace at diff --git a/functorch/functorch/csrc/init.cpp b/functorch/functorch/csrc/init.cpp new file mode 100644 index 0000000000000..9f6db6455aebe --- /dev/null +++ b/functorch/functorch/csrc/init.cpp @@ -0,0 +1,411 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace functorch { + +static bool has_level(const Tensor& self, int64_t level) { + const auto* batched = maybeGetBatchedImpl(self); + if (!batched) { + return false; + } + return batched->level() >= level; +} + +Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) { + return addBatchDim(self, batch_dim, level); +} + +Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) { + auto t = at::functionalization::impl::to_functional_tensor(self); + at::functionalization::impl::unsafeGetFunctionalWrapper(t)->set_level(level); + return t; +} + +void _assert_wrapped_functional(const Tensor& unwrapped, const Tensor& wrapped) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(wrapped)); + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(unwrapped)); + auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped); + auto& wrapped_inner = wrapped_impl->value(); + TORCH_INTERNAL_ASSERT(unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) +} + +void _propagate_functional_input_mutation(const Tensor& unwrapped, const Tensor& wrapped) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(wrapped)); + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(unwrapped)); + auto wrapped_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped); + // Ensure that the input is up to date by committing any pending updates to the alias. + wrapped_impl->sync_(); + auto& wrapped_inner = wrapped_impl->value(); + // It would probably be more reasonable to check that the two tensors are aliased, + // but we can't do that unless we give BatchedTensorImpl a notion of storage. + if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) { + } else { + TORCH_INTERNAL_ASSERT(unwrapped.nbytes() == wrapped_inner.nbytes()); + TORCH_INTERNAL_ASSERT(unwrapped.sizes() == wrapped_inner.sizes(), + "An inplace-mutation op (like transpose_() was called on an input to the functionalization pass." + " Propagating those mutations to the input is currently not supported."); + unwrapped.copy_(wrapped_inner); + } +} + + +static std::pair remove_existing_batch_dim( + const BatchedTensorImpl* batched, int64_t level) { + + TORCH_INTERNAL_ASSERT(batched->level() == level); + return std::make_pair(batched->value(), batched->bdim()); +} + +// Poor man's version of np.moveaxis. Moves the dimension at `dst` to `src` +// while preserving the order of other existing dimensions. +// We should probably add np.moveaxis (it is more general) to PyTorch. (#36048) +// When we do, replace the following with it. +static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) { + auto logical_dim = self.dim(); + src = maybe_wrap_dim(src, logical_dim); + dst = maybe_wrap_dim(dst, logical_dim); + if (src == dst) { + return self; + } + VmapDimVector permutation; + permutation.reserve(logical_dim); + for (int64_t dim = 0; dim < logical_dim; dim++) { + if (dim == src) { + continue; + } + permutation.push_back(dim); + } + permutation.insert(permutation.begin() + dst, src); + return self.permute(permutation); +} + +// Removes the batch dim with level `level` from `self`. If this causes the +// last batch dim to be removed from a BatchedTensor, then this returns a +// regular Tensor. +// +// If the `level` of the batch dim to remove does not exist in `self`, then we +// add the batch dim in. This can happen if `self` didn't interact with a tensor +// inside the vmap level, for example, +// self = torch.randn(3) +// y = torch.randn(5) +// out = vmap(lambda x: vmap(lambda y: x)(y))(self) +// assert out.shape == (3, 5) +// Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension +// corresponding to the *outer* vmap level and it doesn't have any dimensions that +// correspond to the inner vmap level so we need to create one for the user. +// +// `out_dim` controls where we should put the batch dimension in the output tensor. +Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) { + if (!has_level(self, level)) { + auto self_sizes = self.sizes(); + VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end()); + expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size); + auto result = self.expand(expanded_sizes); + return result; + } + + // Must be batched if has_level(self, /*any_level*/) + const auto* batched = maybeGetBatchedImpl(self); + TORCH_INTERNAL_ASSERT(batched != nullptr); + + Tensor self_without_bdim; + int64_t newly_exposed_logical_dim; + std::tie(self_without_bdim, newly_exposed_logical_dim) = remove_existing_batch_dim(batched, level); + auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim); + return result; +} + +Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) { + // We only ever call that after popping out of a functionalize() call, in which case the current tensors + // should always be wrapped in a FunctionalTensorWrapper. + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); + auto functional = at::functionalization::impl::unsafeGetFunctionalWrapper(self); + + // when regenerating the (potentially mutated) input tensors, the functionalization pass + // regenerates them through a series of view_copy() op calls. + // Functorch wants to turn those back into view ops though. + // Ensure that the input is up to date by committing any pending updates to the alias. + at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(add_back_views); + bool any_updates = functional->apply_updates(); + if (any_updates) { + functional->regenerate_from_base(); + } + return functional->value(); +} + +Tensor _wrap_for_grad(const Tensor& self, int64_t level) { + // NB: different behavior inside?? + // return self; + // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self)); + // TORCH_INTERNAL_ASSERT(self.has_storage()); + return makeTensorWrapper(self, level); +} + +Tensor _unwrap_for_grad(const Tensor& self, int64_t level) { + auto* result = maybeGetTensorWrapper(self); + if (!result) { + return self; + } + TORCH_INTERNAL_ASSERT(result->level().has_value()); + if (result->level() == level) { + return result->value(); + } + return self; +} + +int64_t dlevel(const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (!wrapped) { + return 0; + } + if (!wrapped->is_alive()) { + return -1; + } + return wrapped->level().value(); +} + +bool dump_tensor(const Tensor& self) { + dumpTensorCout(self); + return true; +} + +RandomnessType get_randomness_enum(const std::string& randomness) { + if (randomness == "error") { + return RandomnessType::Error; + } else if (randomness == "same") { + return RandomnessType::Same; + } else if (randomness == "different") { + return RandomnessType::Different; + } else { + TORCH_CHECK(false, "randomness argument must be error, same, or different."); + } +} + +void set_fwd_grad_enabled(bool enabled) { + AutogradState::get_tls_state().set_fw_grad_mode(enabled); +} + +bool get_fwd_grad_enabled() { + return AutogradState::get_tls_state().get_fw_grad_mode(); +} + +int64_t _grad_increment_nesting() { + // See NOTE [grad and vjp interaction with no_grad] + bool prev_grad_mode = c10::GradMode::is_enabled(); + return initAndPushDynamicLayer(TransformType::Grad, nullopt, nullopt, prev_grad_mode); +} + +int64_t _grad_decrement_nesting() { + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad); + return layer.layerId(); +} + +int64_t _jvp_increment_nesting() { + // See NOTE [grad and vjp interaction with no_grad] + bool prev_fwd_grad_mode = get_fwd_grad_enabled(); + return initAndPushDynamicLayer(TransformType::Jvp, nullopt, nullopt, nullopt, prev_fwd_grad_mode); +} + +int64_t _jvp_decrement_nesting() { + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp); + return layer.layerId(); +} + +int64_t _vmap_increment_nesting(int64_t batch_size, const std::string& randomness) { + return initAndPushDynamicLayer(TransformType::Vmap, batch_size, get_randomness_enum(randomness)); +} + +int64_t _vmap_decrement_nesting() { + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap); + return layer.layerId(); +} + +int64_t _func_increment_nesting(bool reapply_views) { + return initAndPushDynamicLayer(TransformType::Functionalize, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, /*functionalize_add_back_views=*/reapply_views); +} + +int64_t _func_decrement_nesting() { + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize); + return layer.layerId(); +} + +static bool is_batchedtensor(const Tensor& tensor) { + auto* batched = maybeGetBatchedImpl(tensor); + return batched != nullptr; +} + +static bool is_gradtrackingtensor(const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + return wrapped != nullptr; +} + +static bool is_functionaltensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); +} + +static Tensor get_unwrapped(const Tensor& tensor) { + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + return batched->value(); + } + auto* wrapped = maybeGetTensorWrapper(tensor); + if (wrapped) { + return wrapped->value(); + } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->value(); + } + TORCH_CHECK(false, "No wrappers present!"); +} + +static int64_t maybe_get_level(const Tensor& tensor) { + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + return batched->level(); + } + auto* wrapped = maybeGetTensorWrapper(tensor); + if (wrapped) { + if (wrapped->level()) { + return *wrapped->level(); + } + // TODO: this is a weird special case... + return -2; + } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->level(); + } + return -1; +} + +static int64_t maybe_get_bdim(const Tensor& tensor) { + auto* batched = maybeGetBatchedImpl(tensor); + if (batched) { + return batched->bdim(); + } + return -1; +} + +static int64_t currentLevel() { + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t current_level = maybe_layer->layerId(); + return current_level; +} + +static std::tuple unwrapTensorAtCurrentLevel(const Tensor& tensor) { + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t current_level = maybe_layer->layerId(); + auto result = unwrapTensorAtLevel(tensor, current_level); + auto value = std::get<0>(result); + auto bdim = std::get<1>(result); + value = moveBatchDimToFront(value, bdim); + return std::make_tuple(value, bdim.has_value() ? 0 : -1); +} + +static void tls_set_vmap_excluded(bool excluded) { + c10::impl::tls_set_dispatch_key_excluded(kBatchedKey, excluded); +} + +static bool tls_set_is_included() { + return c10::impl::tls_is_dispatch_key_included(kDynamicLayerFrontModeKey); +} + +static void _set_dynamic_layer_keys_included(bool value) { + return setDynamicLayerFrontBackKeysIncluded(value); +} + +static void dump_dls() { + std::cout << getDynamicLayerStack() << std::endl; +} + +static void dump_local_tls() { + auto tls = c10::impl::tls_local_dispatch_key_set(); + std::cout << "[Local Include] " << tls.included_ << std::endl; + std::cout << "[Local Exclude] " << tls.excluded_ << std::endl; +} + +} // namespace functorch +} + + +namespace at { namespace functorch { + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim"); + m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim"); + m.def("_wrap_functional_tensor", &at::functorch::_wrap_functional_tensor, "add functional tensor"); + m.def("_assert_wrapped_functional", &at::functorch::_assert_wrapped_functional, "assert wrapped functional"); + m.def("_propagate_functional_input_mutation", &at::functorch::_propagate_functional_input_mutation, "propagate functional input mutations"); + m.def("_unwrap_functional_tensor", &at::functorch::_unwrap_functional_tensor, "remove functional tensor"); + m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim"); + m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim"); + m.def("_func_increment_nesting", &at::functorch::_func_increment_nesting, "functionalization start"); + m.def("_func_decrement_nesting", &at::functorch::_func_decrement_nesting, "functionalization end"); + m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim"); + m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim"); + m.def("_jvp_increment_nesting", &at::functorch::_jvp_increment_nesting); + m.def("_jvp_decrement_nesting", &at::functorch::_jvp_decrement_nesting); + m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "wrap as gradtrackingtensor"); + m.def("_unwrap_for_grad", &at::functorch::_unwrap_for_grad, "unwrap from gradtrackingtensor"); + m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings"); + m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled); + m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled); + m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed); + m.def("get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed); + m.def("dlevel", &at::functorch::dlevel, "dlevel"); + m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor"); + m.def("reshape_dim_into", &at::functorch::reshape_dim_into); + m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof); + m.def("are_transforms_active", &at::functorch::areTransformsActive); + // various debugging things. Maybe we should offer these as first-class APIs + // on Tensors? + m.def("is_batchedtensor", &at::functorch::is_batchedtensor); + m.def("is_gradtrackingtensor", &at::functorch::is_gradtrackingtensor); + m.def("is_functionaltensor", &at::functorch::is_functionaltensor); + m.def("get_unwrapped", &at::functorch::get_unwrapped); + m.def("maybe_get_level", &at::functorch::maybe_get_level); + m.def("maybe_get_bdim", &at::functorch::maybe_get_bdim); + m.def("current_level", &at::functorch::currentLevel); + m.def("unwrap_batchedtensor", &at::functorch::unwrapTensorAtCurrentLevel); + m.def("tls_set_vmap_excluded", &at::functorch::tls_set_vmap_excluded); + m.def("tls_set_is_included", &at::functorch::tls_set_is_included); + m.def("_set_dynamic_layer_keys_included", &at::functorch::_set_dynamic_layer_keys_included); + m.def("dump_dls", &at::functorch::dump_dls); + m.def("dump_local_tls", &at::functorch::dump_local_tls); + m.def("set_fwd_grad_enabled", &at::functorch::set_fwd_grad_enabled); + m.def("get_fwd_grad_enabled", &at::functorch::get_fwd_grad_enabled); + at::functorch::initCompileCacheBindings(m.ptr()); + + // Windows doesn't like this +#ifndef _WIN32 + initDispatchBindings(m.ptr()); +#endif +} + +}} diff --git a/functorch/functorch/experimental/__init__.py b/functorch/functorch/experimental/__init__.py new file mode 100644 index 0000000000000..2b25051373553 --- /dev/null +++ b/functorch/functorch/experimental/__init__.py @@ -0,0 +1,4 @@ +from .batch_norm_replacement import replace_all_batch_norm_modules_ +# PyTorch forward-mode is not mature yet +from .._src.eager_transforms import jvp, jacfwd, hessian, functionalize +from .._src.vmap import chunk_vmap diff --git a/functorch/functorch/experimental/batch_norm_replacement.py b/functorch/functorch/experimental/batch_norm_replacement.py new file mode 100644 index 0000000000000..55fc08a0575bb --- /dev/null +++ b/functorch/functorch/experimental/batch_norm_replacement.py @@ -0,0 +1,22 @@ +import torch.nn as nn + + +def batch_norm_without_running_stats(module: nn.Module): + if isinstance(module, nn.modules.batchnorm._BatchNorm) and module.track_running_stats: + module.running_mean = None + module.running_var = None + module.num_batches_tracked = None + module.track_running_stats = False + + +def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module: + """ + In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and + setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root` + """ + # base case + batch_norm_without_running_stats(root) + + for obj in root.modules(): + batch_norm_without_running_stats(obj) + return root diff --git a/functorch/notebooks/_src/plot_ensembling.py b/functorch/notebooks/_src/plot_ensembling.py new file mode 100644 index 0000000000000..94cd1151ad7bc --- /dev/null +++ b/functorch/notebooks/_src/plot_ensembling.py @@ -0,0 +1,110 @@ +""" +========================== +Model ensembling +========================== +This example illustrates how to vectorize model ensembling using vmap. + +What is model ensembling? +-------------------------------------------------------------------- +Model ensembling combines the predictions from multiple models together. +Traditionally this is done by running each model on some inputs separately +and then combining the predictions. However, if you're running models with +the same architecture, then it may be possible to combine them together +using ``vmap``. ``vmap`` is a function transform that maps functions across +dimensions of the input tensors. One of its use cases is eliminating +for-loops and speeding them up through vectorization. + +Let's demonstrate how to do this using an ensemble of simple CNNs. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +torch.manual_seed(0) + +# Here's a simple CNN +class SimpleCNN(nn.Module): + def __init__(self): + super(SimpleCNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + output = x + return output + +# Let's generate some dummy data. Pretend that we're working with an MNIST dataset +# where the images are 28 by 28. +# Furthermore, let's say we wish to combine the predictions from 10 different +# models. +device = 'cuda' +num_models = 10 +data = torch.randn(100, 64, 1, 28, 28, device=device) +targets = torch.randint(10, (6400,), device=device) +models = [SimpleCNN().to(device) for _ in range(num_models)] + +# We have a couple of options for generating predictions. Maybe we want +# to give each model a different randomized minibatch of data, or maybe we +# want to run the same minibatch of data through each model (e.g. if we were +# testing the effect of different model initializations). + +# Option 1: different minibatch for each model +minibatches = data[:num_models] +predictions1 = [model(minibatch) for model, minibatch in zip(models, minibatches)] + +# Option 2: Same minibatch +minibatch = data[0] +predictions2 = [model(minibatch) for model in models] + + +###################################################################### +# Using vmap to vectorize the ensemble +# -------------------------------------------------------------------- +# Let's use ``vmap`` to speed up the for-loop. We must first prepare the models +# for use with ``vmap``. +# +# First, let's combine the states of the model together by stacking each parameter. +# For example, model[i].fc1.weight has shape [9216, 128]; we are going to stack the +# .fc1.weight of each of the 10 models to produce a big weight of shape [10, 9216, 128]. +# +# functorch offers the following convenience function to do that. It returns a +# stateless version of the model (fmodel) and stacked parameters and buffers. +from functorch import combine_state_for_ensemble +fmodel, params, buffers = combine_state_for_ensemble(models) +[p.requires_grad_() for p in params] + +# Option 1: get predictions using a different minibatch for each model. +# By default, vmap maps a function across the first dimension of all inputs to the +# passed-in function. After `combine_state_for_ensemble`, each of of ``params``, +# ``buffers`` have an additional dimension of size ``num_models`` at the front; +# and ``minibatches`` has a dimension of size ``num_models``. +print([p.size(0) for p in params]) +assert minibatches.shape == (num_models, 64, 1, 28, 28) +from functorch import vmap +predictions1_vmap = vmap(fmodel)(params, buffers, minibatches) +assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6) + +# Option 2: get predictions using the same minibatch of data +# vmap has an in_dims arg that specify which dimensions to map over. +# Using ``None``, we tell vmap we want the same minibatch to apply for all of +# the 10 models. +predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) +assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6) + +# A quick note: there are limitations around what types of functions can be +# transformed by vmap. The best functions to transform are ones that are +# pure functions: a function where the outputs are only determined by the inputs +# that have no side effects (e.g. mutation). vmap is unable to handle mutation of +# arbitrary Python data structures, but it is able to handle many in-place +# PyTorch operations. diff --git a/functorch/notebooks/_src/plot_jacobians_and_hessians.py b/functorch/notebooks/_src/plot_jacobians_and_hessians.py new file mode 100644 index 0000000000000..99db81556830d --- /dev/null +++ b/functorch/notebooks/_src/plot_jacobians_and_hessians.py @@ -0,0 +1,174 @@ +""" +============================= +Jacobians, hessians, and more +============================= + +Computing jacobians or hessians are useful in a number of non-traditional +deep learning models. It is difficult (or annoying) to compute these quantities +efficiently using a standard autodiff system like PyTorch Autograd; functorch +provides ways of computing various higher-order autodiff quantities efficiently. +""" +import torch +import torch.nn.functional as F +from functools import partial +torch.manual_seed(0) + +###################################################################### +# Setup: Comparing functorch vs the naive approach +# -------------------------------------------------------------------- +# Let's start with a function that we'd like to compute the jacobian of. +# This is a simple linear function with non-linear activation. +def predict(weight, bias, x): + return F.linear(x, weight, bias).tanh() + +# Here's some dummy data: a weight, a bias, and a feature vector. +D = 16 +weight = torch.randn(D, D) +bias = torch.randn(D) +x = torch.randn(D) + +# Let's think of ``predict`` as a function that maps the input ``x`` from R^D -> R^D. +# PyTorch Autograd computes vector-Jacobian products. In order to compute the full +# Jacobian of this R^D -> R^D function, we would have to compute it row-by-row +# by using a different unit vector each time. +xp = x.clone().requires_grad_() +unit_vectors = torch.eye(D) + +def compute_jac(xp): + jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] + for vec in unit_vectors] + return torch.stack(jacobian_rows) + +jacobian = compute_jac(xp) + +# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid +# of the for-loop and vectorize the computation. We can't directly apply vmap +# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform: +from functorch import vmap, vjp +_, vjp_fn = vjp(partial(predict, weight, bias), x) +ft_jacobian, = vmap(vjp_fn)(unit_vectors) +assert torch.allclose(ft_jacobian, jacobian) + +# In another tutorial a composition of reverse-mode AD and vmap gave us +# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap +# gives us Jacobian computation! Various compositions of vmap and autodiff +# transforms can give us different interesting quantities. +# +# functorch provides ``jacrev`` as a convenience function that performs +# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums +# argument that says which argument we would like to compute Jacobians with +# respect to. +from functorch import jacrev +ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) +assert torch.allclose(ft_jacobian, jacobian) + +# Let's compare the performance of the two ways to compute jacobian. +# The functorch version is much faster (and becomes even faster the more outputs +# there are). In general, we expect that vectorization via ``vmap`` can help +# eliminate overhead and give better utilization of your hardware. +from torch.utils.benchmark import Timer +without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) +with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) +print(without_vmap.timeit(500)) +print(with_vmap.timeit(500)) + +# It's pretty easy to flip the problem around and say we want to compute +# Jacobians of the parameters to our model (weight, bias) instead of the input. +ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x) + +###################################################################### +# reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd) +# -------------------------------------------------------------------- +# We offer two APIs to compute jacobians: jacrev and jacfwd: +# - jacrev uses reverse-mode AD. As you saw above it is a composition of our +# vjp and vmap transforms. +# - jacfwd uses forward-mode AD. It is implemented as a composition of our +# jvp and vmap transforms. +# jacfwd and jacrev can be subsituted for each other and have different +# performance characteristics. +# +# As a general rule of thumb, if you're computing the jacobian of an R^N -> R^M +# function, if there are many more outputs than inputs (i.e. M > N) then jacfwd is +# preferred, otherwise use jacrev. There are exceptions to this rule, but a +# non-rigorous argument for this follows: + +# In reverse-mode AD, we are computing the jacobian row-by-row, while in +# forward-mode AD (which computes Jacobian-vector products), we are computing +# it column-by-column. The Jacobian matrix has M rows and N columns. +from functorch import jacrev, jacfwd + +# Benchmark with more inputs than outputs +Din = 32 +Dout = 2048 +weight = torch.randn(Dout, Din) +bias = torch.randn(Dout) +x = torch.randn(Din) + +using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) +using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) +print(f'jacfwd time: {using_fwd.timeit(500)}') +print(f'jacrev time: {using_bwd.timeit(500)}') + +# Benchmark with more outputs than inputs +Din = 2048 +Dout = 32 +weight = torch.randn(Dout, Din) +bias = torch.randn(Dout) +x = torch.randn(Din) + +using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) +using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) +print(f'jacfwd time: {using_fwd.timeit(500)}') +print(f'jacrev time: {using_bwd.timeit(500)}') + +###################################################################### +# Hessian computation with functorch.hessian +# -------------------------------------------------------------------- +# We offer a convenience API to compute hessians: functorch.hessian. +# Hessians are the jacobian of the jacobian, which suggests that one can just +# compose functorch's jacobian transforms to compute one. +# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))`` +# +# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or +# ``jacrev(jacrev(f))`` instead to compute hessians. +from functorch import hessian +# # TODO: make sure PyTorch has tanh_backward implemented for jvp!! +# hess0 = hessian(predict, argnums=2)(weight, bias, x) +# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x) +hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x) + +###################################################################### +# Batch Jacobian (and Batch Hessian) +# -------------------------------------------------------------------- +# In the above examples we've been operating with a single feature vector. +# In some cases you might want to take the Jacobian of a batch of outputs +# with respect to a batch of inputs where each input produces an independent +# output. That is, given a batch of inputs of shape (B, N) and a function +# that goes from (B, N) -> (B, M), we would like a Jacobian of shape (B, M, N). +# The easiest way to do this is to sum over the batch dimension and then +# compute the Jacobian of that function: + +def predict_with_output_summed(weight, bias, x): + return predict(weight, bias, x).sum(0) + +batch_size = 64 +Din = 31 +Dout = 33 +weight = torch.randn(Dout, Din) +bias = torch.randn(Dout) +x = torch.randn(batch_size, Din) + +batch_jacobian0 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x) + +# If you instead have a function that goes from R^N -> R^M but inputs that are +# batched, you compose vmap with jacrev to compute batched jacobians: + +compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0)) +batch_jacobian1 = compute_batch_jacobian(weight, bias, x) +assert torch.allclose(batch_jacobian0, batch_jacobian1) + +# Finally, batch hessians can be computed similarly. It's easiest to think about +# them by using vmap to batch over hessian computation, but in some cases the sum +# trick also works. +compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0)) +batch_hess = compute_batch_hessian(weight, bias, x) diff --git a/functorch/notebooks/_src/plot_per_sample_gradients.py b/functorch/notebooks/_src/plot_per_sample_gradients.py new file mode 100644 index 0000000000000..0feb2b80d947a --- /dev/null +++ b/functorch/notebooks/_src/plot_per_sample_gradients.py @@ -0,0 +1,124 @@ +""" +========================== +Per-sample-gradients +========================== + +What is it? +-------------------------------------------------------------------- +Per-sample-gradient computation is computing the gradient for each and every +sample in a batch of data. It is a useful quantity in differential privacy +and optimization research. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +torch.manual_seed(0) + +# Here's a simple CNN +class SimpleCNN(nn.Module): + def __init__(self): + super(SimpleCNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + output = x + return output + +def loss_fn(predictions, targets): + return F.nll_loss(predictions, targets) + +# Let's generate a batch of dummy data. Pretend that we're working with an +# MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64. +device = 'cuda' +num_models = 10 +batch_size = 64 +data = torch.randn(batch_size, 1, 28, 28, device=device) +targets = torch.randint(10, (64,), device=device) + +# In regular model training, one would forward the batch of examples and then +# call .backward() to compute gradients: + +model = SimpleCNN().to(device=device) +predictions = model(data) +loss = loss_fn(predictions, targets) +loss.backward() + +# Conceptually, per-sample-gradient computation is equivalent to: for each sample +# of the data, perform a forward and a backward pass to get a gradient. +def compute_grad(sample, target): + sample = sample.unsqueeze(0) + target = target.unsqueeze(0) + prediction = model(sample) + loss = loss_fn(prediction, target) + return torch.autograd.grad(loss, list(model.parameters())) + +def compute_sample_grads(data, targets): + sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)] + sample_grads = zip(*sample_grads) + sample_grads = [torch.stack(shards) for shards in sample_grads] + return sample_grads + +per_sample_grads = compute_sample_grads(data, targets) + +# sample_grads[0] is the per-sample-grad for model.conv1.weight +# model.conv1.weight.shape is [32, 1, 3, 3]; notice how there is one gradient +# per sample in the batch for a total of 64. +print(per_sample_grads[0].shape) + + +###################################################################### +# Per-sample-grads using functorch +# -------------------------------------------------------------------- +# We can compute per-sample-gradients efficiently by using function transforms. +# First, let's create a stateless functional version of ``model`` by using +# ``functorch.make_functional_with_buffers``. +from functorch import make_functional_with_buffers, vmap, grad +fmodel, params, buffers = make_functional_with_buffers(model) + +# Next, let's define a function to compute the loss of the model given a single +# input rather than a batch of inputs. It is important that this function accepts the +# parameters, the input, and the target, because we will be transforming over them. +# Because the model was originally written to handle batches, we'll use +# ``torch.unsqueeze`` to add a batch dimension. +def compute_loss(params, buffers, sample, target): + batch = sample.unsqueeze(0) + targets = target.unsqueeze(0) + predictions = fmodel(params, buffers, batch) + loss = loss_fn(predictions, targets) + return loss + +# Now, let's use ``grad`` to create a new function that computes the gradient +# with respect to the first argument of compute_loss (i.e. the params). +ft_compute_grad = grad(compute_loss) + +# ``ft_compute_grad`` computes the gradient for a single (sample, target) pair. +# We can use ``vmap`` to get it to compute the gradient over an entire batch +# of samples and targets. Note that in_dims=(None, None, 0, 0) because we wish +# to map ``ft_compute_grad`` over the 0th dimension of the data and targets +# and use the same params and buffers for each. +ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0)) + +# Finally, let's used our transformed function to compute per-sample-gradients: +ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets) +for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads): + assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-6, rtol=1e-6) + +# A quick note: there are limitations around what types of functions can be +# transformed by vmap. The best functions to transform are ones that are +# pure functions: a function where the outputs are only determined by the inputs +# that have no side effects (e.g. mutation). vmap is unable to handle mutation of +# arbitrary Python data structures, but it is able to handle many in-place +# PyTorch operations. diff --git a/functorch/notebooks/aot_autograd_optimizations.ipynb b/functorch/notebooks/aot_autograd_optimizations.ipynb new file mode 100644 index 0000000000000..78204eea700e2 --- /dev/null +++ b/functorch/notebooks/aot_autograd_optimizations.ipynb @@ -0,0 +1,416 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# AOT Autograd - How to use and optimize?\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "## Background\n", + "In this tutorial, we will learn how to use AOT Autograd to speedup training of deep learning models.\n", + "\n", + "For background, AOT Autograd is a toolkit to assist developers in accelerating training on PyTorch. Broadly, it has two key features\n", + "* AOT Autograd traces the forward and backward graph ahead of time. Presence of forward and backward graph ahead of time facilitates joint graph optimizations such as recomputation or activation checkpointing.\n", + "* AOT Autograd provides simple mechanisms to compile the extracted forward and backward graphs through deep learning compilers, such as NVFuser, NNC, TVM and others.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## What will you learn?\n", + "In this tutorial, we will look at how AOT Autograd can be used, in conjunction with backend compilers, to accelerate the training of PyTorch models. More specifically, you will learn\n", + "* How to use AOT Autograd?\n", + "* How AOT Autograd uses backend compilers to perform operation fusion?\n", + "* How AOT Autograd enables training-specific optimizations such as Recomputation?\n", + "\n", + "So, lets get started.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Let's setup a simple model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def fn(a, b, c, d):\n", + " x = a + b + c + d\n", + " return x.cos().cos()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Test that it works\n", + "a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]\n", + "ref = fn(a, b, c, d)\n", + "loss = ref.sum()\n", + "loss.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use AOT Autograd\n", + "\n", + "Now, lets use AOT Autograd and look at the extracted forward and backward graphs. Internally, AOT uses `__torch_dispatch__` based tracing mechanism to extract forward and backward graphs, and wraps them in `torch.Fx` GraphModule containers. Note that AOT Autograd tracing is different from the usual Fx symbolic tracing. AOT Autograd uses Fx GraphModule just to represent the traced graphs (and not for tracing).\n", + "\n", + "AOT Autograd then sends these forward and backward graphs to the user supplied compilers. So, lets write a compiler that just prints the graph." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", + " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", + " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", + " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", + " cos = torch.ops.aten.cos(add_2)\n", + " cos_1 = torch.ops.aten.cos(cos)\n", + " return [cos_1, add_2, cos]\n", + " \n", + "\n", + "\n", + "\n", + "def forward(self, add_2, cos, tangents_1):\n", + " sin = torch.ops.aten.sin(cos); cos = None\n", + " neg = torch.ops.aten.neg(sin); sin = None\n", + " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", + " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", + " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", + " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", + " return [mul_1, mul_1, mul_1, mul_1]\n", + " \n" + ] + } + ], + "source": [ + "from functorch.compile import aot_function\n", + "\n", + "# The compiler_fn is called after the forward and backward graphs are extracted.\n", + "# Here, we just print the code in the compiler_fn. Return of this function is a callable.\n", + "def compiler_fn(fx_module: torch.fx.GraphModule, _):\n", + " print(fx_module.code)\n", + " return fx_module\n", + "\n", + "# Pass on the compiler_fn to the aot_function API\n", + "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n", + "\n", + "# Run the aot_print_fn once to trigger the compilation and print the graphs\n", + "res = aot_print_fn(a, b, c, d)\n", + "assert torch.allclose(ref, res)\n", + "\n", + "from functorch.compile import clear_compile_cache\n", + "clear_compile_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above code prints the Fx graph for the forward and backward graph. You can see that in addition to the original input of the forward pass, the forward graph outputs some additional tensors. These tensors are saved for the backward pass for gradient calculation. We will come back to these later while talking about recomputation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Operator Fusion\n", + "Now that we understand how to use AOT Autograd to print forward and backward graphs, let us use AOT Autograd to use some actual deep learning compiler. In this tutorial, we use PyTorch Neural Network Compiler (NNC) to perform pointwise operator fusion for CPU devices. For CUDA devices, a suitable alternative is NvFuser. So, lets use NNC" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile\n", + "from functorch.compile import ts_compile\n", + "\n", + "# Lets compile the forward and backward through ts_compile.\n", + "aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)\n", + "\n", + "# Correctness checking. Lets clone the input so that we can check grads.\n", + "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", + "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", + "\n", + "res = aot_nnc_fn(*cloned_inputs)\n", + "loss = res.sum()\n", + "loss.backward()\n", + "assert torch.allclose(ref, res)\n", + "assert torch.allclose(a.grad, cloned_a.grad)\n", + "assert torch.allclose(b.grad, cloned_b.grad)\n", + "assert torch.allclose(c.grad, cloned_c.grad)\n", + "assert torch.allclose(d.grad, cloned_d.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets benchmark the original and AOT Autograd + NNC compiled function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Lets write a function to benchmark the forward and backward pass\n", + "import time\n", + "import statistics\n", + "\n", + "def bench(fn, args, prefix):\n", + " warmup = 10\n", + " iterations = 100\n", + "\n", + " for _ in range(warmup):\n", + " ref = fn(*args)\n", + " ref.sum().backward()\n", + " \n", + " fw_latencies = []\n", + " bw_latencies = []\n", + " for _ in range(iterations):\n", + " for arg in args:\n", + " arg.grad = None\n", + "\n", + " fw_begin = time.perf_counter()\n", + " ref = fn(*args)\n", + " fw_end = time.perf_counter()\n", + "\n", + " loss = ref.sum() \n", + "\n", + " bw_begin = time.perf_counter()\n", + " loss.backward()\n", + " bw_end = time.perf_counter()\n", + "\n", + " fw_latencies.append(fw_end - fw_begin)\n", + " bw_latencies.append(bw_end - bw_begin)\n", + " \n", + " avg_fw_latency = statistics.mean(fw_latencies) * 10**6\n", + " avg_bw_latency = statistics.mean(bw_latencies) * 10**6\n", + " print(prefix, \"Fwd = \" + str(avg_fw_latency) + \" us\", \"Bwd = \" + str(avg_bw_latency) + \" us\", sep=', ')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us\n", + "AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us\n" + ] + } + ], + "source": [ + "large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]\n", + "\n", + "# Benchmark the Eager and AOT Autograd functions\n", + "bench(fn, large_inputs, \"Eager\")\n", + "bench(aot_nnc_fn, large_inputs, \"AOT\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the help of NNC, AOT Autograd speeds up both the forward and backward pass. If we look at the printed graphs earlier, all the operators are pointwise. The pointwise operators are memory bandwidth bound, and thus benefit from operator fusion. Looking closely at the numbers, the backward pass gets higher speedup. This is because forward pass has to output some intermediate tensors for gradient calculation for the backward pass, preventing it from saving some memory reads and writes. However, such restriction does not exist in the backward graph." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Recomputation (aka Activation Checkpointing)\n", + "Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in backwards, we recompute them **during** the backwards pass. Recomputing saves memory, but we incur performance overhead.\n", + "\n", + "However, in the presence of fusing compiler, we can do better that that. We can recompute the fusion-friendly operators to save memory, and then rely on the fusing compiler to fuse the recomputed operators. This reduces both memory and runtime. Please refer to this [discuss post](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) for more details.\n", + "\n", + "Here, we use AOT Autograd with NNC to perform similar type of recomputation. At the end of `__torch_dispatch__` tracing, AOT Autograd has a forward graph and joint forward-backward graph. AOT Autograd then uses a partitioner to isolate the forward and backward graph. In the example above, we used a default partitioner. For this experiment, we will use another partitioner called `min_cut_rematerialization_partition` to perform smarter fusion-aware recomputation. The partitioner is configurable and one can write their own partitioner to plug it in AOT Autograd." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", + " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", + " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", + " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", + " cos = torch.ops.aten.cos(add_2)\n", + " cos_1 = torch.ops.aten.cos(cos); cos = None\n", + " return [cos_1, add_2]\n", + " \n", + "\n", + "\n", + "\n", + "def forward(self, add_2, tangents_1):\n", + " cos = torch.ops.aten.cos(add_2)\n", + " sin = torch.ops.aten.sin(cos); cos = None\n", + " neg = torch.ops.aten.neg(sin); sin = None\n", + " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", + " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", + " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", + " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", + " return [mul_1, mul_1, mul_1, mul_1]\n", + " \n" + ] + } + ], + "source": [ + "from functorch.compile import min_cut_rematerialization_partition\n", + "\n", + "# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n", + "# This will show us how the recomputation has modified the graph.\n", + "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n", + "res = aot_fn(a, b, c, d)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that compared to default partitioner, forward pass now outputs fewer tensors, and recomputes some operations in the backward pass. Let us try NNC compiler now to perform operator fusions (note that we also have a wrapper function - `memory_efficient_fusion` which internally uses `min_cut_rematerialization_partition` and Torchscript compiler to achieve the same effect as following code)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Lets set up the partitioner and NNC compiler.\n", + "aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)\n", + "\n", + "# Correctness checking. Lets clone the input so that we can check grads.\n", + "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", + "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", + "\n", + "res = aot_recompute_nnc_fn(*cloned_inputs)\n", + "loss = res.sum()\n", + "loss.backward()\n", + "assert torch.allclose(ref, res)\n", + "assert torch.allclose(a.grad, cloned_a.grad)\n", + "assert torch.allclose(b.grad, cloned_b.grad)\n", + "assert torch.allclose(c.grad, cloned_c.grad)\n", + "assert torch.allclose(d.grad, cloned_d.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, lets benchmark the different functions" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us\n", + "AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us\n", + "AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us\n" + ] + } + ], + "source": [ + "bench(fn, large_inputs, \"Eager\")\n", + "bench(aot_nnc_fn, large_inputs, \"AOT\")\n", + "bench(aot_recompute_nnc_fn, large_inputs, \"AOT_Recomp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We observe that both forward and backward latency improve over the default partitioner (and a lot better than eager). Fewer outputs in the forward pass and fewer inputs in the backward pass, along with fusion, allows better memory bandwidth utilization leading to further speedups." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Actual Usage\n", + "For actual usage on CUDA devices, we've wrapped AOTAutograd in a convenient wrapper - `memory_efficient_fusion`. Use this for fusion on GPU!\n", + "\n", + "```\n", + "from functorch.compile import memory_efficient_fusion\n", + "```\n" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "19eab2eb1bf96423965affa906f5d33b4f667cc21cd0152dc4f24eb30ccbeee2" + }, + "kernelspec": { + "display_name": "Python 3.8.12 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/functorch/notebooks/colab/ensembling_colab.ipynb b/functorch/notebooks/colab/ensembling_colab.ipynb new file mode 100644 index 0000000000000..ae8b7a67a6d85 --- /dev/null +++ b/functorch/notebooks/colab/ensembling_colab.ipynb @@ -0,0 +1,598 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "ensembling_colab.ipynb", + "provenance": [], + "collapsed_sections": [ + "0I5Mm2q2f5aw" + ], + "machine_shape": "hm", + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "### Welcome to the functorch tutorial on ensembling models, in colab." + ], + "metadata": { + "id": "W6b4RUiYnhSt" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuring your colab to run functorch \n" + ], + "metadata": { + "id": "0I5Mm2q2f5aw" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Getting setup** - running functorch currently requires Pytorch Nightly. \n", + "Thus we'll go through a pytorch nightly install and build functorch. \n", + "\n", + "After that and a restart, you'll be ready to run the tutorial here on colab." + ], + "metadata": { + "id": "jnHxd2KFgPJg" + } + }, + { + "cell_type": "markdown", + "source": [ + "Let's setup a restart function:" + ], + "metadata": { + "id": "PvwZSOklhpB2" + } + }, + { + "cell_type": "code", + "source": [ + "def colab_restart():\n", + " print(\"--> Restarting colab instance\") \n", + " get_ipython().kernel.do_shutdown(True)" + ], + "metadata": { + "id": "MklsA-KRhZKC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, let's confirm that we have a gpu. \n", + "(If not, select Runtime -> Change Runtime type above,\n", + " and select GPU under Hardward Accelerator )" + ], + "metadata": { + "id": "Njk9qPgTiiGS" + } + }, + { + "cell_type": "code", + "source": [ + "!nvcc --version" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HxidO4dpiPGi", + "outputId": "5b76c0f4-e83b-4626-c9c4-7165324528ee" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2020 NVIDIA Corporation\n", + "Built on Mon_Oct_12_20:09:46_PDT_2020\n", + "Cuda compilation tools, release 11.1, V11.1.105\n", + "Build cuda_11.1.TC455_06.29190527_0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's remove the default PyTorch install:" + ], + "metadata": { + "id": "HanoUO62jtKx" + } + }, + { + "cell_type": "code", + "source": [ + "!pip uninstall -y torch" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NIoTNykP9xI5", + "outputId": "b79462f1-50f2-42f5-e079-9148d4b238d9" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Found existing installation: torch 1.11.0+cu111\n", + "Uninstalling torch-1.11.0+cu111:\n", + " Successfully uninstalled torch-1.11.0+cu111\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And install the relevant nightly version. (this defaults to 10.2 Cuda which works on most colabs). " + ], + "metadata": { + "id": "n-DFUwBVkHaX" + } + }, + { + "cell_type": "code", + "source": [ + "cuda_version = \"cu102\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. " + ], + "metadata": { + "id": "BH5ffJBkkRR8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install torch -f https://download.pytorch.org/whl/{cuda_version}/torch_stable.html --upgrade" + ], + "metadata": { + "id": "Bi2oymijkav5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next we'll install functorch:" + ], + "metadata": { + "id": "s3rrVgGkmNpi" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UtBgzUPDfIQg" + }, + "outputs": [], + "source": [ + "!pip install functorch" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway." + ], + "metadata": { + "id": "T8dhR1XEmcJ6" + } + }, + { + "cell_type": "code", + "source": [ + "colab_restart() " + ], + "metadata": { + "id": "xo2UY9b8ma8t", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d2ca8716-b99a-4335-c60a-b9ad28e8d8c7" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--> Restarting colab instance\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## -- Tutorial Start -- \n", + "\n" + ], + "metadata": { + "id": "nj6_fW76wM0d" + } + }, + { + "cell_type": "code", + "source": [ + "# Confirm we are ready to start. \n", + "# If this errs, please make sure you have completed the 'configuring your colab' steps above first and then return here.\n", + "\n", + "import functorch " + ], + "metadata": { + "id": "SvUfIxRyeAaL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Model Ensembling\n", + "\n", + "This example illustrates how to vectorize model ensembling, using vmap.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "nLdOLDH6m9oy" + } + }, + { + "cell_type": "markdown", + "source": [ + "**What is model ensembling?**\n", + "\n", + "Model ensembling combines the predictions from multiple models together. Traditionally this is done by running each model on some inputs separately and then combining the predictions. However, if you’re running models with the same architecture, then it may be possible to combine them together using vmap. vmap is a function transform that maps functions across dimensions of the input tensors. One of its use cases is eliminating for-loops and speeding them up through vectorization.\n", + "\n", + "Let’s demonstrate how to do this using an ensemble of simple CNNs.\n", + "\n" + ], + "metadata": { + "id": "CJJBTOl-tawq" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "torch.manual_seed(0);" + ], + "metadata": { + "id": "Gb-yt4VKUUuc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Here's a simple MLP\n", + "class SimpleMLP(nn.Module):\n", + " def __init__(self):\n", + " super(SimpleMLP, self).__init__()\n", + " self.fc1 = nn.Linear(784, 128)\n", + " self.fc2 = nn.Linear(128, 128)\n", + " self.fc3 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, x):\n", + " x = x.flatten(1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.fc2(x)\n", + " x = F.relu(x)\n", + " x = self.fc3(x)\n", + " return x\n" + ], + "metadata": { + "id": "tf-HKHjUUbyY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a minibatch of size 64. Furthermore, lets say we want to combine the predictions from 10 different models. \n" + ], + "metadata": { + "id": "VEDPe-EoU5Fa" + } + }, + { + "cell_type": "code", + "source": [ + "device = 'cuda'\n", + "num_models = 10\n", + "\n", + "data = torch.randn(100, 64, 1, 28, 28, device=device)\n", + "targets = torch.randint(10, (6400,), device=device)\n", + "\n", + "models = [SimpleMLP().to(device) for _ in range(num_models)]" + ], + "metadata": { + "id": "WB2Qe3AHUvPN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We have a couple of options for generating predictions. Maybe we want to give each model a different randomized minibatch of data. Alternatively, maybe we want to run the same minibatch of data through each model (e.g. if we were testing the effect of different model initializations).\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "GOGJ-OUxVcT5" + } + }, + { + "cell_type": "markdown", + "source": [ + "Option 1: different minibatch for each model" + ], + "metadata": { + "id": "CwJBb09MxCN3" + } + }, + { + "cell_type": "code", + "source": [ + "minibatches = data[:num_models]\n", + "predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]" + ], + "metadata": { + "id": "WYjMx8QTUvRu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Option 2: Same minibatch" + ], + "metadata": { + "id": "HNw4_IVzU5Pz" + } + }, + { + "cell_type": "code", + "source": [ + "minibatch = data[0]\n", + "predictions2 = [model(minibatch) for model in models]" + ], + "metadata": { + "id": "vUsb3VfexJrY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Using vmap to vectorize the ensemble\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "aNkX6lFIxzcm" + } + }, + { + "cell_type": "markdown", + "source": [ + "Let’s use vmap to speed up the for-loop. We must first prepare the models for use with vmap.\n", + "\n", + "First, let’s combine the states of the model together by stacking each parameter. For example, `model[i].fc1.weight` has shape `[784, 128]`; we are going to stack the .fc1.weight of each of the 10 models to produce a big weight of shape `[10, 784, 128]`.\n", + "\n", + "functorch offers the 'combine_state_for_ensemble' convenience function to do that. It returns a stateless version of the model (fmodel) and stacked parameters and buffers.\n", + "\n" + ], + "metadata": { + "id": "-sFMojhryviM" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import combine_state_for_ensemble\n", + "\n", + "fmodel, params, buffers = combine_state_for_ensemble(models)\n", + "[p.requires_grad_() for p in params];\n" + ], + "metadata": { + "id": "C3a9_clvyPho" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Option 1: get predictions using a different minibatch for each model. \n", + "\n", + "By default, vmap maps a function across the first dimension of all inputs to the passed-in function. After using the combine_state_for_ensemble, each of the params and buffers have an additional dimension of size 'num_models' at the front, and minibatches has a dimension of size 'num_models'.\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "mFJDWMM9yaYZ" + } + }, + { + "cell_type": "code", + "source": [ + "print([p.size(0) for p in params]) # show the leading 'num_models' dimension\n", + "\n", + "assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ezuFQx1G1zLG", + "outputId": "a8c7626e-5191-4ebe-9cba-55dd1af56e40" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[10, 10, 10, 10, 10, 10]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from functorch import vmap\n", + "\n", + "predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n", + "\n", + "# verify the vmap predictions match the \n", + "assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)" + ], + "metadata": { + "id": "VroLnfD82DDf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Option 2: get predictions using the same minibatch of data.\n", + "\n", + "vmap has an in_dims arg that specifies which dimensions to map over. By using `None`, we tell vmap we want the same minibatch to apply for all of the 10 models.\n", + "\n", + "\n" + ], + "metadata": { + "id": "tlkmyQyfY6XU" + } + }, + { + "cell_type": "code", + "source": [ + "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n", + "\n", + "assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)" + ], + "metadata": { + "id": "WiSMupvCyecd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations." + ], + "metadata": { + "id": "KrXQsUCIGLWm" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Performance\n", + "\n", + "Curious about performance numbers? Here's how the numbers look on Google Colab." + ], + "metadata": { + "id": "MCjBhMrVF5hH" + } + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "without_vmap = Timer(\n", + " stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n", + " globals=globals())\n", + "with_vmap = Timer(\n", + " stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n", + " globals=globals())\n", + "print(f'Predictions without vmap {without_vmap.timeit(100)}')\n", + "print(f'Predictions with vmap {with_vmap.timeit(100)}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gJPrGdS0GBjz", + "outputId": "460d9808-2c70-4936-8c03-6a008bc289d5" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Predictions without vmap \n", + "[model(minibatch) for model, minibatch in zip(models, minibatches)]\n", + " 3.20 ms\n", + " 1 measurement, 100 runs , 1 thread\n", + "Predictions with vmap \n", + "vmap(fmodel)(params, buffers, minibatches)\n", + " 879.02 us\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "There's a large speedup using vmap! \n", + "\n", + "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", + "\n" + ], + "metadata": { + "id": "UI74G9JarQU8" + } + } + ] +} \ No newline at end of file diff --git a/functorch/notebooks/colab/jacobians_hessians_colab.ipynb b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb new file mode 100644 index 0000000000000..b2c2e39b0cfda --- /dev/null +++ b/functorch/notebooks/colab/jacobians_hessians_colab.ipynb @@ -0,0 +1,1120 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "jacobians_hessians_colab.ipynb", + "provenance": [], + "collapsed_sections": [ + "0I5Mm2q2f5aw" + ], + "machine_shape": "hm", + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "### Welcome to the functorch tutorial on Jacobians, Hessians and more - on colab! " + ], + "metadata": { + "id": "W6b4RUiYnhSt" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuring your colab to run functorch \n" + ], + "metadata": { + "id": "0I5Mm2q2f5aw" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Getting setup** - running functorch currently requires at least PyTorch 1.11. \n", + "Thus we'll go through a pytorch 1.11 install and build functorch. \n", + "\n", + "After that and a restart, you'll be ready to run the tutorial here on colab." + ], + "metadata": { + "id": "jnHxd2KFgPJg" + } + }, + { + "cell_type": "markdown", + "source": [ + "Let's setup a restart function:" + ], + "metadata": { + "id": "PvwZSOklhpB2" + } + }, + { + "cell_type": "code", + "source": [ + "def colab_restart():\n", + " print(\"--> Restarting colab instance\") \n", + " get_ipython().kernel.do_shutdown(True)" + ], + "metadata": { + "id": "MklsA-KRhZKC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, let's confirm that we have a gpu. \n", + "(If not, select Runtime -> Change Runtime type above,\n", + " and select GPU under Hardward Accelerator )" + ], + "metadata": { + "id": "Njk9qPgTiiGS" + } + }, + { + "cell_type": "code", + "source": [ + "!nvcc --version" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HxidO4dpiPGi", + "outputId": "d6d31c17-02cf-427b-cae8-6994c57c2320" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2020 NVIDIA Corporation\n", + "Built on Mon_Oct_12_20:09:46_PDT_2020\n", + "Cuda compilation tools, release 11.1, V11.1.105\n", + "Build cuda_11.1.TC455_06.29190527_0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's remove the default PyTorch install:" + ], + "metadata": { + "id": "HanoUO62jtKx" + } + }, + { + "cell_type": "code", + "source": [ + "!pip uninstall -y torch" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NIoTNykP9xI5", + "outputId": "5cc5a77d-9696-4cde-a7e5-3d835058afee" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Found existing installation: torch 1.10.0+cu111\n", + "Uninstalling torch-1.10.0+cu111:\n", + " Successfully uninstalled torch-1.10.0+cu111\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And install the relevant version. (this defaults to 10.2 Cuda which works on most colabs). " + ], + "metadata": { + "id": "n-DFUwBVkHaX" + } + }, + { + "cell_type": "code", + "source": [ + "cuda_version = \"cu102\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. " + ], + "metadata": { + "id": "BH5ffJBkkRR8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install torch -f https://download.pytorch.org/whl/{cuda_version}/torch_stable.html --upgrade" + ], + "metadata": { + "id": "Bi2oymijkav5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next we'll install functorch:" + ], + "metadata": { + "id": "s3rrVgGkmNpi" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UtBgzUPDfIQg" + }, + "outputs": [], + "source": [ + "!pip install functorch" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway." + ], + "metadata": { + "id": "T8dhR1XEmcJ6" + } + }, + { + "cell_type": "code", + "source": [ + "colab_restart() " + ], + "metadata": { + "id": "xo2UY9b8ma8t", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "a59e08c1-7206-4439-e08e-c4b8ff004f49" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--> Restarting colab instance\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## -- Tutorial Start -- \n", + "\n" + ], + "metadata": { + "id": "nj6_fW76wM0d" + } + }, + { + "cell_type": "code", + "source": [ + "# Confirm we are ready to start. \n", + "# If this errs, please make sure you have completed the install steps above first and then return here.\n", + "\n", + "import functorch " + ], + "metadata": { + "id": "SvUfIxRyeAaL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n", + "\n", + "Computing quantities related to Jacobians or Hessians is useful in a number of non-traditional deep learning models. \n", + "\n", + "It is difficult (or annoying) to compute these quantities efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently." + ], + "metadata": { + "id": "OeTtrGkGfsE9" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Computing the Jacobian" + ], + "metadata": { + "id": "viWZDMQtflUG" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "_ = torch.manual_seed(0)\n" + ], + "metadata": { + "id": "w_IinyjzflUH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let’s start with a function that we’d like to compute the jacobian of. This is a simple linear function with non-linear activation.\n", + "\n" + ], + "metadata": { + "id": "cibF_PEYflUH" + } + }, + { + "cell_type": "code", + "source": [ + "def predict(weight, bias, x):\n", + " return F.linear(x, weight, bias).tanh()" + ], + "metadata": { + "id": "qhcD9hWYflUH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's add some dummy data: a weight, a bias, and a feature vector x.\n", + "\n" + ], + "metadata": { + "id": "G8tqQrO_flUH" + } + }, + { + "cell_type": "code", + "source": [ + "D = 16\n", + "weight = torch.randn(D, D)\n", + "bias = torch.randn(D)\n", + "x = torch.randn(D) # feature vector" + ], + "metadata": { + "id": "FZ4uJfZGflUH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n", + "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n", + "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n", + "by using a different unit vector each time." + ], + "metadata": { + "id": "uMAW-ArQflUH" + } + }, + { + "cell_type": "code", + "source": [ + "def compute_jac(xp):\n", + " jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n", + " for vec in unit_vectors]\n", + " return torch.stack(jacobian_rows)" + ], + "metadata": { + "id": "z-BJPtbpflUI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "xp = x.clone().requires_grad_()\n", + "unit_vectors = torch.eye(D)\n", + "\n", + "jacobian = compute_jac(xp)\n", + "\n", + "print(jacobian.shape)\n", + "print(jacobian[0]) # show first row" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644", + "id": "zuWGSXspflUI" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([16, 16])\n", + "tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,\n", + " 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n", + "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n", + "\n" + ], + "metadata": { + "id": "mxlEOUieflUI" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import vmap, vjp\n", + "\n", + "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n", + "\n", + "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n", + "\n", + "# lets confirm both methods compute the same result\n", + "assert torch.allclose(ft_jacobian, jacobian)" + ], + "metadata": { + "id": "DeF6uy4WflUI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. \n", + "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n", + "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n", + "\n", + "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n", + "\n" + ], + "metadata": { + "id": "Hy4REmwDflUI" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import jacrev\n", + "\n", + "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n", + "\n", + "# confirm \n", + "assert torch.allclose(ft_jacobian, jacobian)" + ], + "metadata": { + "id": "Rt7i6_YlflUI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n", + "\n", + "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n", + "\n", + "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n", + "\n", + "\n" + ], + "metadata": { + "id": "JYe2H1UcflUJ" + } + }, + { + "cell_type": "markdown", + "source": [ + "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:" + ], + "metadata": { + "id": "i_143LZwflUJ" + } + }, + { + "cell_type": "code", + "source": [ + "def get_perf(first, first_descriptor, second, second_descriptor):\n", + " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", + " faster = second.times[0]\n", + " slower = first.times[0]\n", + " gain = (slower-faster)/slower\n", + " if gain < 0: gain *=-1 \n", + " final_gain = gain*100\n", + " print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")\n", + " " + ], + "metadata": { + "id": "II7r6jBtflUJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "And then run the performance comparison:" + ], + "metadata": { + "id": "r4clPnPKflUJ" + } + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "\n", + "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n", + "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "no_vmap_timer = without_vmap.timeit(500)\n", + "with_vmap_timer = with_vmap.timeit(500)\n", + "\n", + "print(no_vmap_timer)\n", + "print(with_vmap_timer)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49", + "id": "ZPtoxF6eflUJ" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "compute_jac(xp)\n", + " 2.25 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "\n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 884.34 us\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Lets do a relative performance comparison of the above with our get_perf function:" + ], + "metadata": { + "id": "nGBBi4dZflUJ" + } + }, + { + "cell_type": "code", + "source": [ + "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\");" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c", + "id": "zqV2RzEXflUJ" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 60.7170 percent improvement with vmap \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." + ], + "metadata": { + "id": "EQAB99EQflUJ" + } + }, + { + "cell_type": "code", + "source": [ + "# note the change in input via argnums params of 0,1 to map to weight and bias\n", + "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" + ], + "metadata": { + "id": "8UZpC8DnflUK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n" + ], + "metadata": { + "id": "F3USYENIflUK" + } + }, + { + "cell_type": "markdown", + "source": [ + "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n", + "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n", + "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n", + "\n", + "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n", + "\n", + "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \\to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n", + "\n", + "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n", + "\n" + ], + "metadata": { + "id": "V7B3vE8dflUK" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import jacrev, jacfwd" + ], + "metadata": { + "id": "k7Tok7m3flUK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "First, let's benchmark with more inputs than outputs:\n", + "\n" + ], + "metadata": { + "id": "YrV-gZAaflUL" + } + }, + { + "cell_type": "code", + "source": [ + "Din = 32\n", + "Dout = 2048\n", + "weight = torch.randn(Dout, Din)\n", + "\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "# remember the general rule about taller vs wider...here we have a taller matrix:\n", + "print(weight.shape)\n", + "\n", + "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "jacfwd_timing = using_fwd.timeit(500)\n", + "jacrev_timing = using_bwd.timeit(500)\n", + "\n", + "print(f'jacfwd time: {jacfwd_timing}')\n", + "print(f'jacrev time: {jacrev_timing}')\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1", + "id": "m5j-4hSxflUL" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([2048, 32])\n", + "jacfwd time: \n", + "jacfwd(predict, argnums=2)(weight, bias, x)\n", + " 1.32 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "jacrev time: \n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 12.46 ms\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "and then do a relative benchmark:" + ], + "metadata": { + "id": "k_Sg-4tVflUL" + } + }, + { + "cell_type": "code", + "source": [ + "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\", );" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f", + "id": "_4T96zGjflUL" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 842.8274 percent improvement with jacrev \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "and now the reverse - more outputs (M) than inputs (N):" + ], + "metadata": { + "id": "RCDPot1yflUL" + } + }, + { + "cell_type": "code", + "source": [ + "Din = 2048\n", + "Dout = 32\n", + "weight = torch.randn(Dout, Din)\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "jacfwd_timing = using_fwd.timeit(500)\n", + "jacrev_timing = using_bwd.timeit(500)\n", + "\n", + "print(f'jacfwd time: {jacfwd_timing}')\n", + "print(f'jacrev time: {jacrev_timing}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16", + "id": "_DRFqzqZflUM" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "jacfwd time: \n", + "jacfwd(predict, argnums=2)(weight, bias, x)\n", + " 7.99 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "jacrev time: \n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 1.09 ms\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "and a relative perf comparison:" + ], + "metadata": { + "id": "5SRbMCNsflUM" + } + }, + { + "cell_type": "code", + "source": [ + "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b", + "id": "uF_9GaoiflUM" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 635.2095 percent improvement with jacfwd \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Hessian computation with functorch.hessian\n" + ], + "metadata": { + "id": "J29FQaBQflUM" + } + }, + { + "cell_type": "markdown", + "source": [ + "We offer a convenience API to compute hessians: `functorch.hessian`. \n", + "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n", + "\n", + "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n", + "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n", + "\n" + ], + "metadata": { + "id": "My4DPH97flUM" + } + }, + { + "cell_type": "markdown", + "source": [ + "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n", + "\n" + ], + "metadata": { + "id": "FJt038l5flUM" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import hessian\n", + "\n", + "# lets reduce the size in order not to blow out colab. Hessians require significant memory:\n", + "Din = 512\n", + "Dout = 32\n", + "weight = torch.randn(Dout, Din)\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n", + "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n", + "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n" + ], + "metadata": { + "id": "jEqr2ywZflUM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())" + ], + "metadata": { + "id": "n9BHcICQflUN" + } + }, + { + "cell_type": "code", + "source": [ + "torch.allclose(hess_api, hess_fwdfwd)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8", + "id": "eHiWRkjJflUN" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Batch Jacobian and Batch Hessian\n" + ], + "metadata": { + "id": "Gjt1RO8HflUN" + } + }, + { + "cell_type": "markdown", + "source": [ + "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. \n", + "\n", + "The easiest way to do this is to use vmap:" + ], + "metadata": { + "id": "RjIzdoQNflUN" + } + }, + { + "cell_type": "code", + "source": [ + "batch_size = 64\n", + "Din = 31\n", + "Dout = 33\n", + "\n", + "weight = torch.randn(Dout, Din)\n", + "print(f\"weight shape = {weight.shape}\")\n", + "\n", + "bias = torch.randn(Dout)\n", + "\n", + "x = torch.randn(batch_size, Din)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f", + "id": "B1eoEO4UflUN" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "weight shape = torch.Size([33, 31])\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n", + "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)" + ], + "metadata": { + "id": "nZ_V02NhflUN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n", + "\n" + ], + "metadata": { + "id": "_OLDiY3MflUN" + } + }, + { + "cell_type": "code", + "source": [ + "def predict_with_output_summed(weight, bias, x):\n", + " return predict(weight, bias, x).sum(0)\n", + "\n", + "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n", + "assert torch.allclose(batch_jacobian0, batch_jacobian1)" + ], + "metadata": { + "id": "_QH4hD8PflUO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n", + "\n", + "Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n", + "\n" + ], + "metadata": { + "id": "eUjw65cCflUO" + } + }, + { + "cell_type": "code", + "source": [ + "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n", + "\n", + "batch_hess = compute_batch_hessian(weight, bias, x)\n", + "batch_hess.shape" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5", + "id": "3vAyQjMsflUO" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([64, 33, 31, 31])" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Computing Hessian-vector products\n", + "\n", + "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n", + "- composing reverse-mode AD with reverse-mode AD\n", + "- composing reverse-mode AD with forward-mode AD\n", + "\n", + "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:" + ], + "metadata": { + "id": "Wa8E48sQgpkb" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import jvp, grad, vjp\n", + "\n", + "def hvp(f, primals, tangents):\n", + " return jvp(grad(f), primals, tangents)[1]" + ], + "metadata": { + "id": "trw6WbAth6BM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Here's some sample usage." + ], + "metadata": { + "id": "DQMpRo6nitfr" + } + }, + { + "cell_type": "code", + "source": [ + "def f(x):\n", + " return x.sin().sum()\n", + "\n", + "x = torch.randn(2048)\n", + "tangent = torch.randn(2048)\n", + "\n", + "result = hvp(f, (x,), (tangent,))" + ], + "metadata": { + "id": "sPwg8SOdiVAK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:" + ], + "metadata": { + "id": "zGvUIcB0j1Ez" + } + }, + { + "cell_type": "code", + "source": [ + "def hvp_revrev(f, primals, tangents):\n", + " _, vjp_fn = vjp(grad(f), *primals)\n", + " return vjp_fn(*tangents)" + ], + "metadata": { + "id": "mdDFZdlekAOK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n", + "assert torch.allclose(result, result_hvp_revrev[0])" + ], + "metadata": { + "id": "_CuCk9X0lW7C" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/functorch/notebooks/colab/per_sample_grads_colab.ipynb b/functorch/notebooks/colab/per_sample_grads_colab.ipynb new file mode 100644 index 0000000000000..5912649f11bbc --- /dev/null +++ b/functorch/notebooks/colab/per_sample_grads_colab.ipynb @@ -0,0 +1,795 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "per_sample_grads_colab.ipynb", + "provenance": [], + "collapsed_sections": [], + "machine_shape": "hm", + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "### Welcome to the functorch tutorial on Per-Sample-Gradients, in colab." + ], + "metadata": { + "id": "W6b4RUiYnhSt" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Configuring your colab to run functorch \n" + ], + "metadata": { + "id": "0I5Mm2q2f5aw" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Getting setup** - running functorch currently requires Pytorch Nightly. \n", + "Thus we'll go through a pytorch nightly install and build functorch. \n", + "\n", + "After that and a restart, you'll be ready to run the tutorial here on colab." + ], + "metadata": { + "id": "jnHxd2KFgPJg" + } + }, + { + "cell_type": "markdown", + "source": [ + "Let's setup a restart function:" + ], + "metadata": { + "id": "PvwZSOklhpB2" + } + }, + { + "cell_type": "code", + "source": [ + "def colab_restart():\n", + " print(\"--> Restarting colab instance\") \n", + " get_ipython().kernel.do_shutdown(True)" + ], + "metadata": { + "id": "MklsA-KRhZKC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, let's confirm that we have a gpu. \n", + "(If not, select Runtime -> Change Runtime type above,\n", + " and select GPU under Hardward Accelerator )" + ], + "metadata": { + "id": "Njk9qPgTiiGS" + } + }, + { + "cell_type": "code", + "source": [ + "!nvcc --version" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HxidO4dpiPGi", + "outputId": "d6d31c17-02cf-427b-cae8-6994c57c2320" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2020 NVIDIA Corporation\n", + "Built on Mon_Oct_12_20:09:46_PDT_2020\n", + "Cuda compilation tools, release 11.1, V11.1.105\n", + "Build cuda_11.1.TC455_06.29190527_0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's remove the default PyTorch install:" + ], + "metadata": { + "id": "HanoUO62jtKx" + } + }, + { + "cell_type": "code", + "source": [ + "!pip uninstall -y torch" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NIoTNykP9xI5", + "outputId": "5cc5a77d-9696-4cde-a7e5-3d835058afee" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Found existing installation: torch 1.10.0+cu111\n", + "Uninstalling torch-1.10.0+cu111:\n", + " Successfully uninstalled torch-1.10.0+cu111\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And install the relevant version. (this defaults to 10.2 Cuda which works on most colabs). " + ], + "metadata": { + "id": "n-DFUwBVkHaX" + } + }, + { + "cell_type": "code", + "source": [ + "cuda_version = \"cu102\" # optionally - cu113 (for 11.3) is an option as well if you have 11.3 listed above in the nvcc output. " + ], + "metadata": { + "id": "BH5ffJBkkRR8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install torch -f https://download.pytorch.org/whl/{cuda_version}/torch_stable.html --upgrade" + ], + "metadata": { + "id": "Bi2oymijkav5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next we'll install functorch:" + ], + "metadata": { + "id": "s3rrVgGkmNpi" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UtBgzUPDfIQg" + }, + "outputs": [], + "source": [ + "!pip install functorch" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Finally - restart colab and after that - just skip directly down to the '-- Tutorial Start --' section to get underway." + ], + "metadata": { + "id": "T8dhR1XEmcJ6" + } + }, + { + "cell_type": "code", + "source": [ + "colab_restart() " + ], + "metadata": { + "id": "xo2UY9b8ma8t", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "a59e08c1-7206-4439-e08e-c4b8ff004f49" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--> Restarting colab instance\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## -- Tutorial Start -- \n", + "\n" + ], + "metadata": { + "id": "nj6_fW76wM0d" + } + }, + { + "cell_type": "code", + "source": [ + "# Confirm we are ready to start. \n", + "# If this errs, please make sure you have completed the 'configuring your colab' steps above first and then return here.\n", + "\n", + "import functorch " + ], + "metadata": { + "id": "SvUfIxRyeAaL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Per-sample-gradients\n", + "Per-sample-gradient computation is computing the gradient for each and every sample in a batch of data. \n", + "It is a useful quantity for differential privacy, meta-learning, and optimization research.\n", + "\n", + "Let's walk through an example of per-sample-gradients in action below with a simple CNN model. \n", + "\n", + "\n" + ], + "metadata": { + "id": "nLdOLDH6m9oy" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "\n", + "torch.manual_seed(0);" + ], + "metadata": { + "id": "Gb-yt4VKUUuc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Here's a simple CNN and loss function:\n", + "\n", + "class SimpleCNN(nn.Module):\n", + " def __init__(self):\n", + " super(SimpleCNN, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", + " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", + " self.fc1 = nn.Linear(9216, 128)\n", + " self.fc2 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, 2)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.fc2(x)\n", + " output = F.log_softmax(x, dim=1)\n", + " output = x\n", + " return output\n", + "\n", + "def loss_fn(predictions, targets):\n", + " return F.nll_loss(predictions, targets)" + ], + "metadata": { + "id": "tf-HKHjUUbyY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. \n", + "\n", + "The dummy images are 28 by 28 and we use a minibatch of size 64.\n", + "\n" + ], + "metadata": { + "id": "VEDPe-EoU5Fa" + } + }, + { + "cell_type": "code", + "source": [ + "device = 'cuda'\n", + "\n", + "num_models = 10\n", + "batch_size = 64\n", + "data = torch.randn(batch_size, 1, 28, 28, device=device)\n", + "\n", + "targets = torch.randint(10, (64,), device=device)" + ], + "metadata": { + "id": "WB2Qe3AHUvPN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "In regular model training, one would forward the minibatch through the model, and then call .backward() to compute gradients. This would generate an 'average' gradient of the entire mini-batch:\n", + "\n" + ], + "metadata": { + "id": "GOGJ-OUxVcT5" + } + }, + { + "cell_type": "code", + "source": [ + "model = SimpleCNN().to(device=device)\n", + "predictions = model(data) # move the entire mini-batch through the model\n", + "\n", + "loss = loss_fn(predictions, targets)\n", + "loss.backward() # back propogate the 'average' gradient of this mini-batch" + ], + "metadata": { + "id": "WYjMx8QTUvRu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n", + "- for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n", + "\n" + ], + "metadata": { + "id": "HNw4_IVzU5Pz" + } + }, + { + "cell_type": "code", + "source": [ + "def compute_grad(sample, target):\n", + " \n", + " sample = sample.unsqueeze(0) # prepend batch dimension for processing\n", + " target = target.unsqueeze(0)\n", + "\n", + " prediction = model(sample)\n", + " loss = loss_fn(prediction, target)\n", + "\n", + " return torch.autograd.grad(loss, list(model.parameters()))\n", + "\n", + "\n", + "def compute_sample_grads(data, targets):\n", + " \"\"\" manually process each sample with per sample gradient \"\"\"\n", + " sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n", + " sample_grads = zip(*sample_grads)\n", + " sample_grads = [torch.stack(shards) for shards in sample_grads]\n", + " return sample_grads\n", + "\n", + "per_sample_grads = compute_sample_grads(data, targets)" + ], + "metadata": { + "id": "vUsb3VfexJrY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "`sample_grads[0]` is the per-sample-grad for model.conv1.weight. `model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one gradient, per sample, in the batch for a total of 64.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "aNkX6lFIxzcm" + } + }, + { + "cell_type": "code", + "source": [ + "print(per_sample_grads[0].shape)" + ], + "metadata": { + "id": "C3a9_clvyPho", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "407abc1a-846f-4e50-83bc-c90719a26073" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([64, 32, 1, 3, 3])\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Per-sample-grads, *the efficient way*, using functorch\n", + "\n", + "\n" + ], + "metadata": { + "id": "mFJDWMM9yaYZ" + } + }, + { + "cell_type": "markdown", + "source": [ + "We can compute per-sample-gradients efficiently by using function transforms. \n", + "\n", + "First, let’s create a stateless functional version of `model` by using `functorch.make_functional_with_buffers`. \n", + "\n", + "This will seperate state (the parameters) from the model and turn the model into a pure function:\n", + "\n" + ], + "metadata": { + "id": "tlkmyQyfY6XU" + } + }, + { + "cell_type": "code", + "source": [ + "from functorch import make_functional_with_buffers, vmap, grad\n", + "\n", + "fmodel, params, buffers = make_functional_with_buffers(model)" + ], + "metadata": { + "id": "WiSMupvCyecd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's review the changes - first, the model has become the stateless FunctionalModuleWithBuffers:" + ], + "metadata": { + "id": "wMsbppPNZklo" + } + }, + { + "cell_type": "code", + "source": [ + "fmodel" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xj0cZOJMZbbB", + "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "FunctionalModuleWithBuffers(\n", + " (stateless_model): SimpleCNN(\n", + " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", + " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", + " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n", + " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And the model parameters now exist independently of the model, stored as a tuple:" + ], + "metadata": { + "id": "zv4_YYPxZvvg" + } + }, + { + "cell_type": "code", + "source": [ + "for x in params:\n", + " print(f\"{x.shape}\")\n", + "\n", + "print(f\"\\n{type(params)}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tH0TAZhBZ3bS", + "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([32, 1, 3, 3])\n", + "torch.Size([32])\n", + "torch.Size([64, 32, 3, 3])\n", + "torch.Size([64])\n", + "torch.Size([128, 9216])\n", + "torch.Size([128])\n", + "torch.Size([10, 128])\n", + "torch.Size([10])\n", + "\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. It is important that this function accepts the parameters, the input, and the target, because we will be transforming over them. \n", + "\n", + "Note - because the model was originally written to handle batches, we’ll use `torch.unsqueeze` to add a batch dimension.\n", + "\n" + ], + "metadata": { + "id": "cTgIIZ9Wyih8" + } + }, + { + "cell_type": "code", + "source": [ + "def compute_loss_stateless_model (params, buffers, sample, target):\n", + " batch = sample.unsqueeze(0)\n", + " targets = target.unsqueeze(0)\n", + "\n", + " predictions = fmodel(params, buffers, batch) \n", + " loss = loss_fn(predictions, targets)\n", + " return loss" + ], + "metadata": { + "id": "ItURFU3M-p98" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now, let’s use functorch's `grad` to create a new function that computes the gradient with respect to the first argument of `compute_loss` (i.e. the params)." + ], + "metadata": { + "id": "Qo3sbDK2i_bH" + } + }, + { + "cell_type": "code", + "source": [ + "ft_compute_grad = grad(compute_loss_stateless_model)" + ], + "metadata": { + "id": "sqRp_Sxni-Xm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The `ft_compute_grad` function computes the gradient for a single (sample, target) pair. We can use vmap to get it to compute the gradient over an entire batch of samples and targets. Note that `in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad` over the 0th dimension of the data and targets, and use the same params and buffers for each.\n", + "\n" + ], + "metadata": { + "id": "2pG3Ofqjjc8O" + } + }, + { + "cell_type": "code", + "source": [ + "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" + ], + "metadata": { + "id": "62ecNMO6inqX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, let’s used our transformed function to compute per-sample-gradients:\n", + "\n" + ], + "metadata": { + "id": "_alXdQ3QkETu" + } + }, + { + "cell_type": "code", + "source": [ + "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n", + "\n", + "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n", + "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n", + " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)" + ], + "metadata": { + "id": "1gehVA1c-BHd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs, and that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "BEZaNt1d_bc1" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Performance comparison" + ], + "metadata": { + "id": "BASP151Iml7B" + } + }, + { + "cell_type": "markdown", + "source": [ + "Curious about how the performance of vmap compares?\n", + "\n", + "Currently the best results are obtained on newer GPU's such as the A100 (Ampere) where we've seen up to 25x speedups on this example, but here are some results done in Colab:" + ], + "metadata": { + "id": "jr1xNpV4nJ7u" + } + }, + { + "cell_type": "code", + "source": [ + "def get_perf(first, first_descriptor, second, second_descriptor):\n", + " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", + " second_res = second.times[0]\n", + " first_res = first.times[0]\n", + "\n", + " gain = (first_res-second_res)/first_res\n", + " if gain < 0: gain *=-1 \n", + " final_gain = gain*100\n", + "\n", + " print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")" + ], + "metadata": { + "id": "GnAnMkYmoc-j" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "\n", + "without_vmap = Timer( stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n", + "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",globals=globals())\n", + "no_vmap_timing = without_vmap.timeit(100)\n", + "with_vmap_timing = with_vmap.timeit(100)\n", + "\n", + "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n", + "print(f'Per-sample-grads with vmap {with_vmap_timing}')" + ], + "metadata": { + "id": "Zfnn2C2g-6Fb", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "922f3901-773f-446b-b562-88e78f49036c" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Per-sample-grads without vmap \n", + "compute_sample_grads(data, targets)\n", + " 79.86 ms\n", + " 1 measurement, 100 runs , 1 thread\n", + "Per-sample-grads with vmap \n", + "ft_compute_sample_grad(params, buffers, data, targets)\n", + " 12.93 ms\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing,\"no vmap\" )" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NV9R3LZQoavl", + "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 517.5791 percent improvement with vmap \n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "There are other optimized solutions (like in https://github.com/pytorch/opacus) to computing per-sample-gradients in PyTorch that also perform better than the naive method. But it’s cool that composing `vmap` and `grad` give us a nice speedup.\n", + "\n", + "\n", + "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", + "\n" + ], + "metadata": { + "id": "UI74G9JarQU8" + } + } + ] +} \ No newline at end of file diff --git a/functorch/notebooks/colab/readme.md b/functorch/notebooks/colab/readme.md new file mode 100644 index 0000000000000..fbdf129da00f7 --- /dev/null +++ b/functorch/notebooks/colab/readme.md @@ -0,0 +1,5 @@ +### Holds the colab ready versions of the notebook tutorials. + +These are similar to the jupyter notebooks, but have additional colab specific changes including the building of functorch in colab to prep for running. + +The colabs and notebooks are not auto-synced atm, thus currently updates to one need to be synched to the other. diff --git a/functorch/notebooks/ensembling.ipynb b/functorch/notebooks/ensembling.ipynb new file mode 100644 index 0000000000000..72554b9f9e22a --- /dev/null +++ b/functorch/notebooks/ensembling.ipynb @@ -0,0 +1,391 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94", + "metadata": { + "id": "de1548fb-a313-4e9c-ae5d-8ec4c12ddd94" + }, + "source": [ + "# Model ensembling\n", + "\n", + "This example illustrates how to vectorize model ensembling using vmap.\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "## What is model ensembling?\n", + "Model ensembling combines the predictions from multiple models together.\n", + "Traditionally this is done by running each model on some inputs separately\n", + "and then combining the predictions. However, if you're running models with\n", + "the same architecture, then it may be possible to combine them together\n", + "using `vmap`. `vmap` is a function transform that maps functions across\n", + "dimensions of the input tensors. One of its use cases is eliminating\n", + "for-loops and speeding them up through vectorization.\n", + "\n", + "Let's demonstrate how to do this using an ensemble of simple CNNs." + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "torch.manual_seed(0);" + ], + "metadata": { + "id": "Gb-yt4VKUUuc" + }, + "execution_count": null, + "outputs": [], + "id": "Gb-yt4VKUUuc" + }, + { + "cell_type": "code", + "source": [ + "# Here's a simple MLP\n", + "class SimpleMLP(nn.Module):\n", + " def __init__(self):\n", + " super(SimpleMLP, self).__init__()\n", + " self.fc1 = nn.Linear(784, 128)\n", + " self.fc2 = nn.Linear(128, 128)\n", + " self.fc3 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, x):\n", + " x = x.flatten(1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.fc2(x)\n", + " x = F.relu(x)\n", + " x = self.fc3(x)\n", + " return x\n" + ], + "metadata": { + "id": "tf-HKHjUUbyY" + }, + "execution_count": null, + "outputs": [], + "id": "tf-HKHjUUbyY" + }, + { + "cell_type": "markdown", + "source": [ + "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a minibatch of size 64. Furthermore, lets say we want to combine the predictions from 10 different models. \n" + ], + "metadata": { + "id": "VEDPe-EoU5Fa" + }, + "id": "VEDPe-EoU5Fa" + }, + { + "cell_type": "code", + "source": [ + "device = 'cuda'\n", + "num_models = 10\n", + "\n", + "data = torch.randn(100, 64, 1, 28, 28, device=device)\n", + "targets = torch.randint(10, (6400,), device=device)\n", + "\n", + "models = [SimpleMLP().to(device) for _ in range(num_models)]" + ], + "metadata": { + "id": "WB2Qe3AHUvPN" + }, + "execution_count": null, + "outputs": [], + "id": "WB2Qe3AHUvPN" + }, + { + "cell_type": "markdown", + "source": [ + "We have a couple of options for generating predictions. Maybe we want to give each model a different randomized minibatch of data. Alternatively, maybe we want to run the same minibatch of data through each model (e.g. if we were testing the effect of different model initializations).\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "GOGJ-OUxVcT5" + }, + "id": "GOGJ-OUxVcT5" + }, + { + "cell_type": "markdown", + "source": [ + "Option 1: different minibatch for each model" + ], + "metadata": { + "id": "CwJBb09MxCN3" + }, + "id": "CwJBb09MxCN3" + }, + { + "cell_type": "code", + "source": [ + "minibatches = data[:num_models]\n", + "predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]" + ], + "metadata": { + "id": "WYjMx8QTUvRu" + }, + "execution_count": null, + "outputs": [], + "id": "WYjMx8QTUvRu" + }, + { + "cell_type": "markdown", + "source": [ + "Option 2: Same minibatch" + ], + "metadata": { + "id": "HNw4_IVzU5Pz" + }, + "id": "HNw4_IVzU5Pz" + }, + { + "cell_type": "code", + "source": [ + "minibatch = data[0]\n", + "predictions2 = [model(minibatch) for model in models]" + ], + "metadata": { + "id": "vUsb3VfexJrY" + }, + "execution_count": null, + "outputs": [], + "id": "vUsb3VfexJrY" + }, + { + "cell_type": "markdown", + "source": [ + "## Using vmap to vectorize the ensemble\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "aNkX6lFIxzcm" + }, + "id": "aNkX6lFIxzcm" + }, + { + "cell_type": "markdown", + "source": [ + "Let’s use vmap to speed up the for-loop. We must first prepare the models for use with vmap.\n", + "\n", + "First, let’s combine the states of the model together by stacking each parameter. For example, `model[i].fc1.weight` has shape `[784, 128]`; we are going to stack the .fc1.weight of each of the 10 models to produce a big weight of shape `[10, 784, 128]`.\n", + "\n", + "functorch offers the 'combine_state_for_ensemble' convenience function to do that. It returns a stateless version of the model (fmodel) and stacked parameters and buffers.\n", + "\n" + ], + "metadata": { + "id": "-sFMojhryviM" + }, + "id": "-sFMojhryviM" + }, + { + "cell_type": "code", + "source": [ + "from functorch import combine_state_for_ensemble\n", + "\n", + "fmodel, params, buffers = combine_state_for_ensemble(models)\n", + "[p.requires_grad_() for p in params];\n" + ], + "metadata": { + "id": "C3a9_clvyPho" + }, + "execution_count": null, + "outputs": [], + "id": "C3a9_clvyPho" + }, + { + "cell_type": "markdown", + "source": [ + "Option 1: get predictions using a different minibatch for each model. \n", + "\n", + "By default, vmap maps a function across the first dimension of all inputs to the passed-in function. After using the combine_state_for_ensemble, each of the params and buffers have an additional dimension of size 'num_models' at the front, and minibatches has a dimension of size 'num_models'.\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "mFJDWMM9yaYZ" + }, + "id": "mFJDWMM9yaYZ" + }, + { + "cell_type": "code", + "source": [ + "print([p.size(0) for p in params]) # show the leading 'num_models' dimension\n", + "\n", + "assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ezuFQx1G1zLG", + "outputId": "ab260da3-77f2-4ff9-d843-e0d0f1e0a884" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[10, 10, 10, 10, 10, 10]\n" + ] + } + ], + "id": "ezuFQx1G1zLG" + }, + { + "cell_type": "code", + "source": [ + "from functorch import vmap\n", + "\n", + "predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n", + "\n", + "# verify the vmap predictions match the \n", + "assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)" + ], + "metadata": { + "id": "VroLnfD82DDf" + }, + "execution_count": null, + "outputs": [], + "id": "VroLnfD82DDf" + }, + { + "cell_type": "markdown", + "source": [ + "Option 2: get predictions using the same minibatch of data.\n", + "\n", + "vmap has an in_dims arg that specifies which dimensions to map over. By using `None`, we tell vmap we want the same minibatch to apply for all of the 10 models.\n", + "\n", + "\n" + ], + "metadata": { + "id": "tlkmyQyfY6XU" + }, + "id": "tlkmyQyfY6XU" + }, + { + "cell_type": "code", + "source": [ + "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n", + "\n", + "assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)" + ], + "metadata": { + "id": "WiSMupvCyecd" + }, + "execution_count": null, + "outputs": [], + "id": "WiSMupvCyecd" + }, + { + "cell_type": "markdown", + "source": [ + "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations." + ], + "metadata": { + "id": "KrXQsUCIGLWm" + }, + "id": "KrXQsUCIGLWm" + }, + { + "cell_type": "markdown", + "source": [ + "## Performance\n", + "\n", + "Curious about performance numbers? Here's how the numbers look on Google Colab." + ], + "metadata": { + "id": "MCjBhMrVF5hH" + }, + "id": "MCjBhMrVF5hH" + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "without_vmap = Timer(\n", + " stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n", + " globals=globals())\n", + "with_vmap = Timer(\n", + " stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n", + " globals=globals())\n", + "print(f'Predictions without vmap {without_vmap.timeit(100)}')\n", + "print(f'Predictions with vmap {with_vmap.timeit(100)}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gJPrGdS0GBjz", + "outputId": "04e75950-b964-419c-fa9c-f1590e0081bb" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Predictions without vmap \n", + "[model(minibatch) for model, minibatch in zip(models, minibatches)]\n", + " 3.25 ms\n", + " 1 measurement, 100 runs , 1 thread\n", + "Predictions with vmap \n", + "vmap(fmodel)(params, buffers, minibatches)\n", + " 879.28 us\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ], + "id": "gJPrGdS0GBjz" + }, + { + "cell_type": "markdown", + "source": [ + "There's a large speedup using vmap! \n", + "\n", + "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", + "\n" + ], + "metadata": { + "id": "UI74G9JarQU8" + }, + "id": "UI74G9JarQU8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "colab": { + "name": "ensembling.ipynb", + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/functorch/notebooks/jacobians_hessians.ipynb b/functorch/notebooks/jacobians_hessians.ipynb new file mode 100644 index 0000000000000..172479ae76261 --- /dev/null +++ b/functorch/notebooks/jacobians_hessians.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "Computing jacobians or hessians are useful in a number of non-traditional\n", + "deep learning models. It is difficult (or annoying) to compute these quantities\n", + "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n", + "provides ways of computing various higher-order autodiff quantities efficiently." + ], + "metadata": { + "id": "zPbR6-eP51fe" + }, + "id": "zPbR6-eP51fe" + }, + { + "cell_type": "markdown", + "source": [ + "## Computing the Jacobian" + ], + "metadata": { + "id": "3kDj8fhn52j3" + }, + "id": "3kDj8fhn52j3" + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "_ = torch.manual_seed(0)" + ], + "metadata": { + "id": "w_IinyjzflUH" + }, + "execution_count": null, + "outputs": [], + "id": "w_IinyjzflUH" + }, + { + "cell_type": "markdown", + "source": [ + "Let’s start with a function that we’d like to compute the jacobian of. This is a simple linear function with non-linear activation.\n", + "\n" + ], + "metadata": { + "id": "cibF_PEYflUH" + }, + "id": "cibF_PEYflUH" + }, + { + "cell_type": "code", + "source": [ + "def predict(weight, bias, x):\n", + " return F.linear(x, weight, bias).tanh()" + ], + "metadata": { + "id": "qhcD9hWYflUH" + }, + "execution_count": null, + "outputs": [], + "id": "qhcD9hWYflUH" + }, + { + "cell_type": "markdown", + "source": [ + "Let's add some dummy data: a weight, a bias, and a feature vector x.\n", + "\n" + ], + "metadata": { + "id": "G8tqQrO_flUH" + }, + "id": "G8tqQrO_flUH" + }, + { + "cell_type": "code", + "source": [ + "D = 16\n", + "weight = torch.randn(D, D)\n", + "bias = torch.randn(D)\n", + "x = torch.randn(D) # feature vector" + ], + "metadata": { + "id": "FZ4uJfZGflUH" + }, + "execution_count": null, + "outputs": [], + "id": "FZ4uJfZGflUH" + }, + { + "cell_type": "markdown", + "source": [ + "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n", + "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n", + "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n", + "by using a different unit vector each time." + ], + "metadata": { + "id": "uMAW-ArQflUH" + }, + "id": "uMAW-ArQflUH" + }, + { + "cell_type": "code", + "source": [ + "def compute_jac(xp):\n", + " jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n", + " for vec in unit_vectors]\n", + " return torch.stack(jacobian_rows)" + ], + "metadata": { + "id": "z-BJPtbpflUI" + }, + "execution_count": null, + "outputs": [], + "id": "z-BJPtbpflUI" + }, + { + "cell_type": "code", + "source": [ + "xp = x.clone().requires_grad_()\n", + "unit_vectors = torch.eye(D)\n", + "\n", + "jacobian = compute_jac(xp)\n", + "\n", + "print(jacobian.shape)\n", + "print(jacobian[0]) # show first row" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644", + "id": "zuWGSXspflUI" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([16, 16])\n", + "tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,\n", + " 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])\n" + ] + } + ], + "id": "zuWGSXspflUI" + }, + { + "cell_type": "markdown", + "source": [ + "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n", + "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n", + "\n" + ], + "metadata": { + "id": "mxlEOUieflUI" + }, + "id": "mxlEOUieflUI" + }, + { + "cell_type": "code", + "source": [ + "from functorch import vmap, vjp\n", + "\n", + "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n", + "\n", + "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n", + "\n", + "# lets confirm both methods compute the same result\n", + "assert torch.allclose(ft_jacobian, jacobian)" + ], + "metadata": { + "id": "DeF6uy4WflUI" + }, + "execution_count": null, + "outputs": [], + "id": "DeF6uy4WflUI" + }, + { + "cell_type": "markdown", + "source": [ + "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. \n", + "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n", + "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n", + "\n", + "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n", + "\n" + ], + "metadata": { + "id": "Hy4REmwDflUI" + }, + "id": "Hy4REmwDflUI" + }, + { + "cell_type": "code", + "source": [ + "from functorch import jacrev\n", + "\n", + "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n", + "\n", + "# confirm \n", + "assert torch.allclose(ft_jacobian, jacobian)" + ], + "metadata": { + "id": "Rt7i6_YlflUI" + }, + "execution_count": null, + "outputs": [], + "id": "Rt7i6_YlflUI" + }, + { + "cell_type": "markdown", + "source": [ + "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n", + "\n", + "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n", + "\n", + "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n", + "\n", + "\n" + ], + "metadata": { + "id": "JYe2H1UcflUJ" + }, + "id": "JYe2H1UcflUJ" + }, + { + "cell_type": "markdown", + "source": [ + "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:" + ], + "metadata": { + "id": "i_143LZwflUJ" + }, + "id": "i_143LZwflUJ" + }, + { + "cell_type": "code", + "source": [ + "def get_perf(first, first_descriptor, second, second_descriptor):\n", + " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", + " faster = second.times[0]\n", + " slower = first.times[0]\n", + " gain = (slower-faster)/slower\n", + " if gain < 0: gain *=-1 \n", + " final_gain = gain*100\n", + " print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")" + ], + "metadata": { + "id": "II7r6jBtflUJ" + }, + "execution_count": null, + "outputs": [], + "id": "II7r6jBtflUJ" + }, + { + "cell_type": "markdown", + "source": [ + "And then run the performance comparison:" + ], + "metadata": { + "id": "r4clPnPKflUJ" + }, + "id": "r4clPnPKflUJ" + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "\n", + "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n", + "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "no_vmap_timer = without_vmap.timeit(500)\n", + "with_vmap_timer = with_vmap.timeit(500)\n", + "\n", + "print(no_vmap_timer)\n", + "print(with_vmap_timer)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49", + "id": "ZPtoxF6eflUJ" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "compute_jac(xp)\n", + " 2.25 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "\n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 884.34 us\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ], + "id": "ZPtoxF6eflUJ" + }, + { + "cell_type": "markdown", + "source": [ + "Lets do a relative performance comparison of the above with our get_perf function:" + ], + "metadata": { + "id": "nGBBi4dZflUJ" + }, + "id": "nGBBi4dZflUJ" + }, + { + "cell_type": "code", + "source": [ + "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\");" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c", + "id": "zqV2RzEXflUJ" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 60.7170 percent improvement with vmap \n" + ] + } + ], + "id": "zqV2RzEXflUJ" + }, + { + "cell_type": "markdown", + "source": [ + "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." + ], + "metadata": { + "id": "EQAB99EQflUJ" + }, + "id": "EQAB99EQflUJ" + }, + { + "cell_type": "code", + "source": [ + "# note the change in input via argnums params of 0,1 to map to weight and bias\n", + "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" + ], + "metadata": { + "id": "8UZpC8DnflUK" + }, + "execution_count": null, + "outputs": [], + "id": "8UZpC8DnflUK" + }, + { + "cell_type": "markdown", + "source": [ + "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n" + ], + "metadata": { + "id": "F3USYENIflUK" + }, + "id": "F3USYENIflUK" + }, + { + "cell_type": "markdown", + "source": [ + "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n", + "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n", + "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n", + "\n", + "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n", + "\n", + "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \\to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n", + "\n", + "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n", + "\n" + ], + "metadata": { + "id": "V7B3vE8dflUK" + }, + "id": "V7B3vE8dflUK" + }, + { + "cell_type": "code", + "source": [ + "from functorch import jacrev, jacfwd" + ], + "metadata": { + "id": "k7Tok7m3flUK" + }, + "execution_count": null, + "outputs": [], + "id": "k7Tok7m3flUK" + }, + { + "cell_type": "markdown", + "source": [ + "First, let's benchmark with more inputs than outputs:\n", + "\n" + ], + "metadata": { + "id": "YrV-gZAaflUL" + }, + "id": "YrV-gZAaflUL" + }, + { + "cell_type": "code", + "source": [ + "Din = 32\n", + "Dout = 2048\n", + "weight = torch.randn(Dout, Din)\n", + "\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "# remember the general rule about taller vs wider...here we have a taller matrix:\n", + "print(weight.shape)\n", + "\n", + "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "jacfwd_timing = using_fwd.timeit(500)\n", + "jacrev_timing = using_bwd.timeit(500)\n", + "\n", + "print(f'jacfwd time: {jacfwd_timing}')\n", + "print(f'jacrev time: {jacrev_timing}')\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1", + "id": "m5j-4hSxflUL" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([2048, 32])\n", + "jacfwd time: \n", + "jacfwd(predict, argnums=2)(weight, bias, x)\n", + " 1.32 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "jacrev time: \n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 12.46 ms\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ], + "id": "m5j-4hSxflUL" + }, + { + "cell_type": "markdown", + "source": [ + "and then do a relative benchmark:" + ], + "metadata": { + "id": "k_Sg-4tVflUL" + }, + "id": "k_Sg-4tVflUL" + }, + { + "cell_type": "code", + "source": [ + "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\", );" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f", + "id": "_4T96zGjflUL" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 842.8274 percent improvement with jacrev \n" + ] + } + ], + "id": "_4T96zGjflUL" + }, + { + "cell_type": "markdown", + "source": [ + "and now the reverse - more outputs (M) than inputs (N):" + ], + "metadata": { + "id": "RCDPot1yflUL" + }, + "id": "RCDPot1yflUL" + }, + { + "cell_type": "code", + "source": [ + "Din = 2048\n", + "Dout = 32\n", + "weight = torch.randn(Dout, Din)\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", + "\n", + "jacfwd_timing = using_fwd.timeit(500)\n", + "jacrev_timing = using_bwd.timeit(500)\n", + "\n", + "print(f'jacfwd time: {jacfwd_timing}')\n", + "print(f'jacrev time: {jacrev_timing}')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16", + "id": "_DRFqzqZflUM" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "jacfwd time: \n", + "jacfwd(predict, argnums=2)(weight, bias, x)\n", + " 7.99 ms\n", + " 1 measurement, 500 runs , 1 thread\n", + "jacrev time: \n", + "jacrev(predict, argnums=2)(weight, bias, x)\n", + " 1.09 ms\n", + " 1 measurement, 500 runs , 1 thread\n" + ] + } + ], + "id": "_DRFqzqZflUM" + }, + { + "cell_type": "markdown", + "source": [ + "and a relative perf comparison:" + ], + "metadata": { + "id": "5SRbMCNsflUM" + }, + "id": "5SRbMCNsflUM" + }, + { + "cell_type": "code", + "source": [ + "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b", + "id": "uF_9GaoiflUM" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 635.2095 percent improvement with jacfwd \n" + ] + } + ], + "id": "uF_9GaoiflUM" + }, + { + "cell_type": "markdown", + "source": [ + "## Hessian computation with functorch.hessian\n" + ], + "metadata": { + "id": "J29FQaBQflUM" + }, + "id": "J29FQaBQflUM" + }, + { + "cell_type": "markdown", + "source": [ + "We offer a convenience API to compute hessians: `functorch.hessian`. \n", + "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n", + "\n", + "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n", + "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n", + "\n" + ], + "metadata": { + "id": "My4DPH97flUM" + }, + "id": "My4DPH97flUM" + }, + { + "cell_type": "markdown", + "source": [ + "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n", + "\n" + ], + "metadata": { + "id": "FJt038l5flUM" + }, + "id": "FJt038l5flUM" + }, + { + "cell_type": "code", + "source": [ + "from functorch import hessian\n", + "\n", + "# lets reduce the size in order not to blow out colab. Hessians require significant memory:\n", + "Din = 512\n", + "Dout = 32\n", + "weight = torch.randn(Dout, Din)\n", + "bias = torch.randn(Dout)\n", + "x = torch.randn(Din)\n", + "\n", + "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n", + "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n", + "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n" + ], + "metadata": { + "id": "jEqr2ywZflUM" + }, + "execution_count": null, + "outputs": [], + "id": "jEqr2ywZflUM" + }, + { + "cell_type": "markdown", + "source": [ + "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())" + ], + "metadata": { + "id": "n9BHcICQflUN" + }, + "id": "n9BHcICQflUN" + }, + { + "cell_type": "code", + "source": [ + "torch.allclose(hess_api, hess_fwdfwd)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8", + "id": "eHiWRkjJflUN" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ], + "id": "eHiWRkjJflUN" + }, + { + "cell_type": "markdown", + "source": [ + "## Batch Jacobian and Batch Hessian\n" + ], + "metadata": { + "id": "Gjt1RO8HflUN" + }, + "id": "Gjt1RO8HflUN" + }, + { + "cell_type": "markdown", + "source": [ + "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. \n", + "\n", + "The easiest way to do this is to use vmap:" + ], + "metadata": { + "id": "RjIzdoQNflUN" + }, + "id": "RjIzdoQNflUN" + }, + { + "cell_type": "code", + "source": [ + "batch_size = 64\n", + "Din = 31\n", + "Dout = 33\n", + "\n", + "weight = torch.randn(Dout, Din)\n", + "print(f\"weight shape = {weight.shape}\")\n", + "\n", + "bias = torch.randn(Dout)\n", + "\n", + "x = torch.randn(batch_size, Din)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f", + "id": "B1eoEO4UflUN" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "weight shape = torch.Size([33, 31])\n" + ] + } + ], + "id": "B1eoEO4UflUN" + }, + { + "cell_type": "code", + "source": [ + "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n", + "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)" + ], + "metadata": { + "id": "nZ_V02NhflUN" + }, + "execution_count": null, + "outputs": [], + "id": "nZ_V02NhflUN" + }, + { + "cell_type": "markdown", + "source": [ + "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n", + "\n" + ], + "metadata": { + "id": "_OLDiY3MflUN" + }, + "id": "_OLDiY3MflUN" + }, + { + "cell_type": "code", + "source": [ + "def predict_with_output_summed(weight, bias, x):\n", + " return predict(weight, bias, x).sum(0)\n", + "\n", + "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n", + "assert torch.allclose(batch_jacobian0, batch_jacobian1)" + ], + "metadata": { + "id": "_QH4hD8PflUO" + }, + "execution_count": null, + "outputs": [], + "id": "_QH4hD8PflUO" + }, + { + "cell_type": "markdown", + "source": [ + "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n", + "\n", + "Finally, batch hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n", + "\n" + ], + "metadata": { + "id": "eUjw65cCflUO" + }, + "id": "eUjw65cCflUO" + }, + { + "cell_type": "code", + "source": [ + "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n", + "\n", + "batch_hess = compute_batch_hessian(weight, bias, x)\n", + "batch_hess.shape" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5", + "id": "3vAyQjMsflUO" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([64, 33, 31, 31])" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ], + "id": "3vAyQjMsflUO" + }, + { + "cell_type": "markdown", + "source": [ + "## Computing Hessian-vector products\n", + "\n", + "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n", + "- composing reverse-mode AD with reverse-mode AD\n", + "- composing reverse-mode AD with forward-mode AD\n", + "\n", + "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:" + ], + "metadata": { + "id": "Wa8E48sQgpkb" + }, + "id": "Wa8E48sQgpkb" + }, + { + "cell_type": "code", + "source": [ + "from functorch import jvp, grad, vjp\n", + "\n", + "def hvp(f, primals, tangents):\n", + " return jvp(grad(f), primals, tangents)[1]" + ], + "metadata": { + "id": "trw6WbAth6BM" + }, + "execution_count": null, + "outputs": [], + "id": "trw6WbAth6BM" + }, + { + "cell_type": "markdown", + "source": [ + "Here's some sample usage." + ], + "metadata": { + "id": "DQMpRo6nitfr" + }, + "id": "DQMpRo6nitfr" + }, + { + "cell_type": "code", + "source": [ + "def f(x):\n", + " return x.sin().sum()\n", + "\n", + "x = torch.randn(2048)\n", + "tangent = torch.randn(2048)\n", + "\n", + "result = hvp(f, (x,), (tangent,))" + ], + "metadata": { + "id": "sPwg8SOdiVAK" + }, + "execution_count": null, + "outputs": [], + "id": "sPwg8SOdiVAK" + }, + { + "cell_type": "markdown", + "source": [ + "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:" + ], + "metadata": { + "id": "zGvUIcB0j1Ez" + }, + "id": "zGvUIcB0j1Ez" + }, + { + "cell_type": "code", + "source": [ + "def hvp_revrev(f, primals, tangents):\n", + " _, vjp_fn = vjp(grad(f), *primals)\n", + " return vjp_fn(*tangents)" + ], + "metadata": { + "id": "mdDFZdlekAOK" + }, + "execution_count": null, + "outputs": [], + "id": "mdDFZdlekAOK" + }, + { + "cell_type": "code", + "source": [ + "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n", + "assert torch.allclose(result, result_hvp_revrev[0])" + ], + "metadata": { + "id": "_CuCk9X0lW7C" + }, + "execution_count": null, + "outputs": [], + "id": "_CuCk9X0lW7C" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "colab": { + "name": "jacobians_hessians.ipynb", + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/functorch/notebooks/minifier.ipynb b/functorch/notebooks/minifier.ipynb new file mode 100644 index 0000000000000..41d87a1fa44a8 --- /dev/null +++ b/functorch/notebooks/minifier.ipynb @@ -0,0 +1,433 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using the Minifier\n", + "We have a pretty convenient test case minifier with this interface\n", + "```\n", + "def minifier(fail_f: fx.GraphModule, inps, module_fails):\n", + " \"\"\"\n", + " Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.\n", + "\n", + " Does 2 main strategies:\n", + " 1. Truncates suffix: Removes some suffix from the graph and sets a new output.\n", + " 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,\n", + " tries replacing quarter of the graph, etc.\n", + "\n", + " >>> failing_function = fx.symbolic_trace(f)\n", + " >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))\n", + "\n", + " note: module_fails returns True if it fails.\n", + " ...\n", + "```\n", + "\n", + "Specifically, it takes your FX graph, and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for `module_fails`), until it can't minify it anymore.\n", + "\n", + "1. Truncates Suffix: Given a FX graph, it tries to remove some suffix from the graph. For example, given this:\n", + "\n", + "```\n", + "def f(a):\n", + " b = x * 2\n", + " c = b + 3\n", + " d = c / 4\n", + " return d\n", + "```\n", + "It might try truncating the suffix, and get\n", + "```\n", + "def f(a):\n", + " b = x * 2\n", + " c = b + 3\n", + " return c\n", + "```\n", + "It tries this in a binary search manner, trying to remove the last 1/2, then 3/4, 1/4 then 7/8, 5/8, 3/8...\n", + "\n", + "2. [Delta Debugging](https://en.wikipedia.org/wiki/Delta_debugging): Of course, removing the suffix isn't always sufficient to minify a graph. What if the error is caused by the first instruction? So, we take an approach inspired by delta debugging - we try removing intermediate nodes of the graph. Unlike with suffixes, there are still dependencies on the removed nodes. So, instead of removing them entirely, we promote them to inputs. For example, given the above example:\n", + "\n", + "```\n", + "def f(a):\n", + " b = x * 2\n", + " c = b + 3\n", + " d = c / 4\n", + " return d\n", + "```\n", + "We might remove a middle node (say, c, in this case).\n", + "```\n", + "def f(a, c):\n", + " b = x * 2\n", + " d = c / 4\n", + " return d\n", + "```\n", + "\n", + "Finally, there are 2 auxiliary strategies - eliminating dead code and removing unused inputs. These are somewhat self-explanatory." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So, let's take a look at a toy example. Let's pretend that our graph fails if it has a \"multiply\" in it. Let's create a failing graph." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key\n", + " operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)\n", + " registered at aten/src/ATen/RegisterSchema.cpp:6\n", + " dispatch key: FuncTorchBatched\n", + " previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338\n", + " new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started off with 7 nodes\n", + "###################\n", + "Current size: 7\n", + "###################\n", + "Strategy: Remove suffix\n", + "\n", + "SUCCESS: Removed [4:7)\n", + "\n", + "###################\n", + "Current size: 6\n", + "###################\n", + "Strategy: Delta Debugging\n", + "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n", + "\n", + "###################\n", + "Current size: 6\n", + "###################\n", + "Strategy: Remove unused inputs\n", + "SUCCESS: Went from 4 inputs to 2 inputs\n", + "\n", + "###################\n", + "Current size: 4\n", + "###################\n", + "Strategy: Remove suffix\n", + "FAIL: Could not remove suffix\n", + "Strategy: Delta Debugging\n", + "FAIL: Could not remove prefix\n", + "\n", + "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", + "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", + "\n", + "\n", + "\n", + "def forward(self, div, add):\n", + " mul = torch.ops.aten.mul(add, div); add = div = None\n", + " return (mul,)\n", + " \n", + "f = torch.jit.script(forward)\n", + "with torch.jit.fuser(\"fuser2\"):\n", + " for _ in range(5):\n", + " f(*inps)\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.fx as fx\n", + "from functorch.compile import minifier\n", + "\n", + "def failing_f(x, y):\n", + " y = torch.ops.aten.div(x, y)\n", + " x = torch.ops.aten.add(x, 3)\n", + " x = torch.ops.aten.mul(x, y)\n", + " return torch.ops.aten.sub(x, y)\n", + "\n", + "inps = [torch.randn(3), torch.randn(3)]\n", + "\n", + "def pass_checker(fx_g, inps):\n", + " return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))\n", + "\n", + "min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tada! Our graph is now a minimal example that still fails.\n", + "\n", + "Since the primary use case of this minifier (for now) is for NVFuser repros, we print out a string for convenience that creates a self-contained repro to run the minified graph with NVFuser.\n", + "\n", + "Note that in practice, we provide 2 main \"graph checkers\" - `check_nvfuser_subprocess` and `check_nvfuser_correctness_subprocess`. These are used to check for errors and correctness (i.e. do the results match eager) respectively. These can be used like\n", + "\n", + "```\n", + "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", + "minifier(failing_graph, inps, check_nvfuser_subprocess)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph in the first place to pass to the minifier? One possible way is simply to use `print_compile`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "def forward(self, primals_1):\n", + " cos = torch.ops.aten.cos(primals_1)\n", + " cos_1 = torch.ops.aten.cos(cos)\n", + " return [cos_1, primals_1, cos]\n", + " \n", + "\n", + "\n", + "\n", + "def forward(self, primals_1, cos, tangents_1):\n", + " sin = torch.ops.aten.sin(cos); cos = None\n", + " neg = torch.ops.aten.neg(sin); sin = None\n", + " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", + " sin_1 = torch.ops.aten.sin(primals_1); primals_1 = None\n", + " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", + " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", + " return [mul_1]\n", + " \n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0.6062, 0.9982, 0.6474], grad_fn=)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functorch.compile import aot_function\n", + "\n", + "from functorch.compile import print_compile\n", + "# Or...\n", + "def print_compile(fx_g, _):\n", + " print(fx_g.code)\n", + " return fx_g\n", + "\n", + "def foo(x):\n", + " return x.cos().cos()\n", + "inp = torch.randn(3, requires_grad=True)\n", + "aot_function(foo, print_compile)(inp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve this, we have another \"compiler\" called `debug_compile`. It simply prints out a string that can be copy pasted and run from another file. It leverages FX's `to_folder` feature to serialize the graph to disk, along with any constants.\n", + "\n", + "You can apply it to either the `fw_compiler` to dump the forwards graph or `bw_compiler` to dump the backwards graph." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "##############################################################\n", + "# To minimize FX graph, copy and paste the below and run it #\n", + "##############################################################\n", + "\n", + "import torch\n", + "import torch.fx as fx\n", + "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", + "\n", + "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", + "inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", + "from foo import FxModule\n", + "mod = FxModule().cuda()\n", + "\n", + "with torch.jit.fuser(\"fuser2\"):\n", + " # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess\n", + " minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0.6062, 0.9982, 0.6474], grad_fn=)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from functorch.compile import memory_efficient_fusion, debug_compile\n", + "\n", + "memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So, let's copy paste it and see how it works - note that I made a couple minor modifications to run on CPU and use the previous \"graph fails if there's a multiply in it\" checker." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started off with 10 nodes\n", + "###################\n", + "Current size: 10\n", + "###################\n", + "Strategy: Remove suffix\n", + "\n", + "SUCCESS: Removed [6:10)\n", + "\n", + "###################\n", + "Current size: 8\n", + "###################\n", + "Strategy: Delta Debugging\n", + "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n", + "\n", + "###################\n", + "Current size: 8\n", + "###################\n", + "Strategy: Remove unused inputs\n", + "SUCCESS: Went from 4 inputs to 3 inputs\n", + "\n", + "###################\n", + "Current size: 7\n", + "###################\n", + "Strategy: Remove suffix\n", + "\n", + "SUCCESS: Removed [4:7)\n", + "\n", + "###################\n", + "Current size: 6\n", + "###################\n", + "Strategy: Remove unused inputs\n", + "SUCCESS: Went from 3 inputs to 2 inputs\n", + "\n", + "###################\n", + "Current size: 5\n", + "###################\n", + "Strategy: Delta Debugging\n", + "SUCCESS: Removed (2:3] - Went from 2 placeholders to 3\n", + "\n", + "###################\n", + "Current size: 5\n", + "###################\n", + "Strategy: Remove unused inputs\n", + "SUCCESS: Went from 3 inputs to 2 inputs\n", + "\n", + "###################\n", + "Current size: 4\n", + "###################\n", + "Strategy: Remove suffix\n", + "FAIL: Could not remove suffix\n", + "Strategy: Delta Debugging\n", + "FAIL: Could not remove prefix\n", + "\n", + "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", + "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n", + "\n", + "\n", + "\n", + "def forward(self, tangents_1, neg):\n", + " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", + " return (mul,)\n", + " \n", + "f = torch.jit.script(forward)\n", + "with torch.jit.fuser(\"fuser2\"):\n", + " for _ in range(5):\n", + " f(*inps)\n" + ] + }, + { + "data": { + "text/plain": [ + "(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.fx as fx\n", + "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n", + "\n", + "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n", + "inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]\n", + "from foo import FxModule\n", + "mod = FxModule()\n", + "\n", + "minifier(fx.symbolic_trace(mod), inps, pass_checker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Hopefully that was useful :)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "a1cf69278e4496ab232105d2fffcc75678d2dcbec1c795483197519eb80161c7" + }, + "kernelspec": { + "display_name": "Python 3.8.12 ('py38')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/functorch/notebooks/neural_tangent_kernels.ipynb b/functorch/notebooks/neural_tangent_kernels.ipynb new file mode 100644 index 0000000000000..b56a5d6e1b189 --- /dev/null +++ b/functorch/notebooks/neural_tangent_kernels.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3", + "metadata": {}, + "source": [ + "# Neural Tangent Kernels\n", + "\n", + "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch." + ] + }, + { + "cell_type": "markdown", + "id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, some setup. Let's define a simple CNN that we wish to compute the NTK of." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "855fa70b-5b63-4973-94df-41be57ab6ecf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from functorch import make_functional, vmap, vjp, jvp, jacrev\n", + "device = 'cuda'\n", + "\n", + "class CNN(nn.Module):\n", + " def __init__(self):\n", + " super(CNN, self).__init__()\n", + " self.conv1 = nn.Conv2d(3, 32, (3, 3))\n", + " self.conv2 = nn.Conv2d(32, 32, (3, 3))\n", + " self.conv3 = nn.Conv2d(32, 32, (3, 3))\n", + " self.fc = nn.Linear(21632, 10)\n", + " \n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = x.relu()\n", + " x = self.conv2(x)\n", + " x = x.relu()\n", + " x = self.conv3(x)\n", + " x = x.flatten(1)\n", + " x = self.fc(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "52c600e9-207a-41ec-93b4-5d940827bda0", + "metadata": {}, + "source": [ + "And let's generate some random data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0001a907-f5c9-4532-9ee9-2e94b8487d08", + "metadata": {}, + "outputs": [], + "source": [ + "x_train = torch.randn(20, 3, 32, 32, device=device)\n", + "x_test = torch.randn(5, 3, 32, 32, device=device)" + ] + }, + { + "cell_type": "markdown", + "id": "8af210fe-9613-48ee-a96c-d0836458b0f1", + "metadata": {}, + "source": [ + "## Create a function version of the model\n", + "\n", + "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n", + "\n", + "We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387", + "metadata": {}, + "outputs": [], + "source": [ + "net = CNN().to(device)\n", + "fnet, params = make_functional(net)" + ] + }, + { + "cell_type": "markdown", + "id": "319276a4-da45-499a-af47-0677107559b6", + "metadata": {}, + "source": [ + "Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc", + "metadata": {}, + "outputs": [], + "source": [ + "def fnet_single(params, x):\n", + " return fnet(params, x.unsqueeze(0)).squeeze(0)" + ] + }, + { + "cell_type": "markdown", + "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248", + "metadata": {}, + "source": [ + "## Compute the NTK: method 1 (Jacobian contraction)\n", + "\n", + "We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n", + "\n", + "$$J_{net}(x_1) J_{net}^T(x_2)$$\n", + "\n", + "In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n", + "\n", + "The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840", + "metadata": {}, + "outputs": [], + "source": [ + "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n", + " # Compute J(x1)\n", + " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", + " jac1 = [j.flatten(2) for j in jac1]\n", + " \n", + " # Compute J(x2)\n", + " jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", + " jac2 = [j.flatten(2) for j in jac2]\n", + " \n", + " # Compute J(x1) @ J(x2).T\n", + " result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n", + " result = result.sum(0)\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([20, 5, 10, 10])\n" + ] + } + ], + "source": [ + "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n", + "print(result.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "ea844f45-98fb-4cba-8056-644292b968ab", + "metadata": {}, + "source": [ + "In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the non-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "aae760c9-e906-4fda-b490-1126a86b7e96", + "metadata": {}, + "outputs": [], + "source": [ + "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n", + " # Compute J(x1)\n", + " jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", + " jac1 = [j.flatten(2) for j in jac1]\n", + " \n", + " # Compute J(x2)\n", + " jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", + " jac2 = [j.flatten(2) for j in jac2]\n", + " \n", + " # Compute J(x1) @ J(x2).T\n", + " einsum_expr = None\n", + " if compute == 'full':\n", + " einsum_expr = 'Naf,Mbf->NMab'\n", + " elif compute == 'trace':\n", + " einsum_expr = 'Naf,Maf->NM'\n", + " elif compute == 'diagonal':\n", + " einsum_expr = 'Naf,Maf->NMa'\n", + " else:\n", + " assert False\n", + " \n", + " result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n", + " result = result.sum(0)\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([20, 5])\n" + ] + } + ], + "source": [ + "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n", + "print(result.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa", + "metadata": {}, + "source": [ + "The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." + ] + }, + { + "cell_type": "markdown", + "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa", + "metadata": {}, + "source": [ + "## Compute the NTK: method 2 (NTK-vector products)\n", + "\n", + "The next method we will discuss is a way to compute the NTK using NTK-vector products.\n", + "\n", + "This method reformulates NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n", + "\n", + "$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n", + "where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n", + "\n", + "- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n", + "- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n", + "- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n", + "\n", + "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n", + "\n", + "Let's code that up:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dc4b49d7-3096-45d5-a7a1-7032309a2613", + "metadata": {}, + "outputs": [], + "source": [ + "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n", + " def get_ntk(x1, x2):\n", + " def func_x1(params):\n", + " return func(params, x1)\n", + "\n", + " def func_x2(params):\n", + " return func(params, x2)\n", + "\n", + " output, vjp_fn = vjp(func_x1, params)\n", + "\n", + " def get_ntk_slice(vec):\n", + " # This computes vec @ J(x2).T\n", + " # `vec` is some unit vector (a single slice of the Identity matrix)\n", + " vjps = vjp_fn(vec)\n", + " # This computes J(X1) @ vjps\n", + " _, jvps = jvp(func_x2, (params,), vjps)\n", + " return jvps\n", + "\n", + " # Here's our identity matrix\n", + " basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n", + " return vmap(get_ntk_slice)(basis)\n", + " \n", + " # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n", + " # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n", + " # we actually wish to compute the NTK between every pair of data points\n", + " # between {x1} and {x2}. That's what the vmaps here do.\n", + " result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n", + " \n", + " if compute == 'full':\n", + " return result\n", + " if compute == 'trace':\n", + " return torch.einsum('NMKK->NM', result)\n", + " if compute == 'diagonal':\n", + " return torch.einsum('NMKK->NMK', result)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245", + "metadata": {}, + "outputs": [], + "source": [ + "result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n", + "result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n", + "assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)" + ] + }, + { + "cell_type": "markdown", + "id": "84253466-971d-4475-999c-fe3de6bd25b5", + "metadata": {}, + "source": [ + "Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n", + "\n", + "The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/functorch/notebooks/per_sample_grads.ipynb b/functorch/notebooks/per_sample_grads.ipynb new file mode 100644 index 0000000000000..5ea18bd05424a --- /dev/null +++ b/functorch/notebooks/per_sample_grads.ipynb @@ -0,0 +1,607 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a474c143-05c4-43b6-b12c-17b592d07a6a", + "metadata": { + "id": "a474c143-05c4-43b6-b12c-17b592d07a6a" + }, + "source": [ + "# Per-sample-gradients\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "## What is it?\n", + "\n", + "Per-sample-gradient computation is computing the gradient for each and every\n", + "sample in a batch of data. It is a useful quantity in differential privacy, meta-learning,\n", + "and optimization research.\n" + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from functools import partial\n", + "\n", + "torch.manual_seed(0);" + ], + "metadata": { + "id": "Gb-yt4VKUUuc" + }, + "execution_count": null, + "outputs": [], + "id": "Gb-yt4VKUUuc" + }, + { + "cell_type": "code", + "source": [ + "# Here's a simple CNN and loss function:\n", + "\n", + "class SimpleCNN(nn.Module):\n", + " def __init__(self):\n", + " super(SimpleCNN, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", + " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", + " self.fc1 = nn.Linear(9216, 128)\n", + " self.fc2 = nn.Linear(128, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = F.relu(x)\n", + " x = self.conv2(x)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, 2)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc1(x)\n", + " x = F.relu(x)\n", + " x = self.fc2(x)\n", + " output = F.log_softmax(x, dim=1)\n", + " output = x\n", + " return output\n", + "\n", + "def loss_fn(predictions, targets):\n", + " return F.nll_loss(predictions, targets)" + ], + "metadata": { + "id": "tf-HKHjUUbyY" + }, + "execution_count": null, + "outputs": [], + "id": "tf-HKHjUUbyY" + }, + { + "cell_type": "markdown", + "source": [ + "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. \n", + "\n", + "The dummy images are 28 by 28 and we use a minibatch of size 64.\n", + "\n" + ], + "metadata": { + "id": "VEDPe-EoU5Fa" + }, + "id": "VEDPe-EoU5Fa" + }, + { + "cell_type": "code", + "source": [ + "device = 'cuda'\n", + "\n", + "num_models = 10\n", + "batch_size = 64\n", + "data = torch.randn(batch_size, 1, 28, 28, device=device)\n", + "\n", + "targets = torch.randint(10, (64,), device=device)" + ], + "metadata": { + "id": "WB2Qe3AHUvPN" + }, + "execution_count": null, + "outputs": [], + "id": "WB2Qe3AHUvPN" + }, + { + "cell_type": "markdown", + "source": [ + "In regular model training, one would forward the minibatch through the model, and then call .backward() to compute gradients. This would generate an 'average' gradient of the entire mini-batch:\n", + "\n" + ], + "metadata": { + "id": "GOGJ-OUxVcT5" + }, + "id": "GOGJ-OUxVcT5" + }, + { + "cell_type": "code", + "source": [ + "model = SimpleCNN().to(device=device)\n", + "predictions = model(data) # move the entire mini-batch through the model\n", + "\n", + "loss = loss_fn(predictions, targets)\n", + "loss.backward() # back propogate the 'average' gradient of this mini-batch" + ], + "metadata": { + "id": "WYjMx8QTUvRu" + }, + "execution_count": null, + "outputs": [], + "id": "WYjMx8QTUvRu" + }, + { + "cell_type": "markdown", + "source": [ + "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n", + "- for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n", + "\n" + ], + "metadata": { + "id": "HNw4_IVzU5Pz" + }, + "id": "HNw4_IVzU5Pz" + }, + { + "cell_type": "code", + "source": [ + "def compute_grad(sample, target):\n", + " \n", + " sample = sample.unsqueeze(0) # prepend batch dimension for processing\n", + " target = target.unsqueeze(0)\n", + "\n", + " prediction = model(sample)\n", + " loss = loss_fn(prediction, target)\n", + "\n", + " return torch.autograd.grad(loss, list(model.parameters()))\n", + "\n", + "\n", + "def compute_sample_grads(data, targets):\n", + " \"\"\" manually process each sample with per sample gradient \"\"\"\n", + " sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n", + " sample_grads = zip(*sample_grads)\n", + " sample_grads = [torch.stack(shards) for shards in sample_grads]\n", + " return sample_grads\n", + "\n", + "per_sample_grads = compute_sample_grads(data, targets)" + ], + "metadata": { + "id": "vUsb3VfexJrY" + }, + "execution_count": null, + "outputs": [], + "id": "vUsb3VfexJrY" + }, + { + "cell_type": "markdown", + "source": [ + "`sample_grads[0]` is the per-sample-grad for model.conv1.weight. `model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one gradient, per sample, in the batch for a total of 64.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "aNkX6lFIxzcm" + }, + "id": "aNkX6lFIxzcm" + }, + { + "cell_type": "code", + "source": [ + "print(per_sample_grads[0].shape)" + ], + "metadata": { + "id": "C3a9_clvyPho", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "407abc1a-846f-4e50-83bc-c90719a26073" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([64, 32, 1, 3, 3])\n" + ] + } + ], + "id": "C3a9_clvyPho" + }, + { + "cell_type": "markdown", + "source": [ + "## Per-sample-grads, *the efficient way*, using functorch\n", + "\n", + "\n" + ], + "metadata": { + "id": "mFJDWMM9yaYZ" + }, + "id": "mFJDWMM9yaYZ" + }, + { + "cell_type": "markdown", + "source": [ + "We can compute per-sample-gradients efficiently by using function transforms. \n", + "\n", + "First, let’s create a stateless functional version of `model` by using `functorch.make_functional_with_buffers`. \n", + "\n", + "This will seperate state (the parameters) from the model and turn the model into a pure function:\n", + "\n" + ], + "metadata": { + "id": "tlkmyQyfY6XU" + }, + "id": "tlkmyQyfY6XU" + }, + { + "cell_type": "code", + "source": [ + "from functorch import make_functional_with_buffers, vmap, grad\n", + "\n", + "fmodel, params, buffers = make_functional_with_buffers(model)" + ], + "metadata": { + "id": "WiSMupvCyecd" + }, + "execution_count": null, + "outputs": [], + "id": "WiSMupvCyecd" + }, + { + "cell_type": "markdown", + "source": [ + "Let's review the changes - first, the model has become the stateless FunctionalModuleWithBuffers:" + ], + "metadata": { + "id": "wMsbppPNZklo" + }, + "id": "wMsbppPNZklo" + }, + { + "cell_type": "code", + "source": [ + "fmodel" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xj0cZOJMZbbB", + "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "FunctionalModuleWithBuffers(\n", + " (stateless_model): SimpleCNN(\n", + " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", + " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", + " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n", + " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ], + "id": "Xj0cZOJMZbbB" + }, + { + "cell_type": "markdown", + "source": [ + "And the model parameters now exist independently of the model, stored as a tuple:" + ], + "metadata": { + "id": "zv4_YYPxZvvg" + }, + "id": "zv4_YYPxZvvg" + }, + { + "cell_type": "code", + "source": [ + "for x in params:\n", + " print(f\"{x.shape}\")\n", + "\n", + "print(f\"\\n{type(params)}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tH0TAZhBZ3bS", + "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([32, 1, 3, 3])\n", + "torch.Size([32])\n", + "torch.Size([64, 32, 3, 3])\n", + "torch.Size([64])\n", + "torch.Size([128, 9216])\n", + "torch.Size([128])\n", + "torch.Size([10, 128])\n", + "torch.Size([10])\n", + "\n", + "\n" + ] + } + ], + "id": "tH0TAZhBZ3bS" + }, + { + "cell_type": "markdown", + "source": [ + "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. It is important that this function accepts the parameters, the input, and the target, because we will be transforming over them. \n", + "\n", + "Note - because the model was originally written to handle batches, we’ll use `torch.unsqueeze` to add a batch dimension.\n", + "\n" + ], + "metadata": { + "id": "cTgIIZ9Wyih8" + }, + "id": "cTgIIZ9Wyih8" + }, + { + "cell_type": "code", + "source": [ + "def compute_loss_stateless_model (params, buffers, sample, target):\n", + " batch = sample.unsqueeze(0)\n", + " targets = target.unsqueeze(0)\n", + "\n", + " predictions = fmodel(params, buffers, batch) \n", + " loss = loss_fn(predictions, targets)\n", + " return loss" + ], + "metadata": { + "id": "ItURFU3M-p98" + }, + "execution_count": null, + "outputs": [], + "id": "ItURFU3M-p98" + }, + { + "cell_type": "markdown", + "source": [ + "Now, let’s use functorch's `grad` to create a new function that computes the gradient with respect to the first argument of `compute_loss` (i.e. the params)." + ], + "metadata": { + "id": "Qo3sbDK2i_bH" + }, + "id": "Qo3sbDK2i_bH" + }, + { + "cell_type": "code", + "source": [ + "ft_compute_grad = grad(compute_loss_stateless_model)" + ], + "metadata": { + "id": "sqRp_Sxni-Xm" + }, + "execution_count": null, + "outputs": [], + "id": "sqRp_Sxni-Xm" + }, + { + "cell_type": "markdown", + "source": [ + "The `ft_compute_grad` function computes the gradient for a single (sample, target) pair. We can use vmap to get it to compute the gradient over an entire batch of samples and targets. Note that `in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad` over the 0th dimension of the data and targets, and use the same params and buffers for each.\n", + "\n" + ], + "metadata": { + "id": "2pG3Ofqjjc8O" + }, + "id": "2pG3Ofqjjc8O" + }, + { + "cell_type": "code", + "source": [ + "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" + ], + "metadata": { + "id": "62ecNMO6inqX" + }, + "execution_count": null, + "outputs": [], + "id": "62ecNMO6inqX" + }, + { + "cell_type": "markdown", + "source": [ + "Finally, let’s used our transformed function to compute per-sample-gradients:\n", + "\n" + ], + "metadata": { + "id": "_alXdQ3QkETu" + }, + "id": "_alXdQ3QkETu" + }, + { + "cell_type": "code", + "source": [ + "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n", + "\n", + "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n", + "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n", + " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)" + ], + "metadata": { + "id": "1gehVA1c-BHd" + }, + "execution_count": null, + "outputs": [], + "id": "1gehVA1c-BHd" + }, + { + "cell_type": "markdown", + "source": [ + "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs, and that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "BEZaNt1d_bc1" + }, + "id": "BEZaNt1d_bc1" + }, + { + "cell_type": "markdown", + "source": [ + "## Performance comparison" + ], + "metadata": { + "id": "BASP151Iml7B" + }, + "id": "BASP151Iml7B" + }, + { + "cell_type": "markdown", + "source": [ + "Curious about how the performance of vmap compares?\n", + "\n", + "Currently the best results are obtained on newer GPU's such as the A100 (Ampere) where we've seen up to 25x speedups on this example, but here are some results done in Colab:" + ], + "metadata": { + "id": "jr1xNpV4nJ7u" + }, + "id": "jr1xNpV4nJ7u" + }, + { + "cell_type": "code", + "source": [ + "def get_perf(first, first_descriptor, second, second_descriptor):\n", + " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", + " second_res = second.times[0]\n", + " first_res = first.times[0]\n", + "\n", + " gain = (first_res-second_res)/first_res\n", + " if gain < 0: gain *=-1 \n", + " final_gain = gain*100\n", + "\n", + " print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")" + ], + "metadata": { + "id": "GnAnMkYmoc-j" + }, + "execution_count": null, + "outputs": [], + "id": "GnAnMkYmoc-j" + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.benchmark import Timer\n", + "\n", + "without_vmap = Timer( stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n", + "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",globals=globals())\n", + "no_vmap_timing = without_vmap.timeit(100)\n", + "with_vmap_timing = with_vmap.timeit(100)\n", + "\n", + "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n", + "print(f'Per-sample-grads with vmap {with_vmap_timing}')" + ], + "metadata": { + "id": "Zfnn2C2g-6Fb", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "922f3901-773f-446b-b562-88e78f49036c" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Per-sample-grads without vmap \n", + "compute_sample_grads(data, targets)\n", + " 79.86 ms\n", + " 1 measurement, 100 runs , 1 thread\n", + "Per-sample-grads with vmap \n", + "ft_compute_sample_grad(params, buffers, data, targets)\n", + " 12.93 ms\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ], + "id": "Zfnn2C2g-6Fb" + }, + { + "cell_type": "code", + "source": [ + "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing,\"no vmap\" )" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NV9R3LZQoavl", + "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Performance delta: 517.5791 percent improvement with vmap \n" + ] + } + ], + "id": "NV9R3LZQoavl" + }, + { + "cell_type": "markdown", + "source": [ + "There are other optimized solutions (like in https://github.com/pytorch/opacus) to computing per-sample-gradients in PyTorch that also perform better than the naive method. But it’s cool that composing `vmap` and `grad` give us a nice speedup.\n", + "\n", + "\n", + "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", + "\n" + ], + "metadata": { + "id": "UI74G9JarQU8" + }, + "id": "UI74G9JarQU8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "colab": { + "name": "per_sample_grads.ipynb", + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/functorch/notebooks/whirlwind_tour.ipynb b/functorch/notebooks/whirlwind_tour.ipynb new file mode 100644 index 0000000000000..3c49985d79f4c --- /dev/null +++ b/functorch/notebooks/whirlwind_tour.ipynb @@ -0,0 +1,321 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "903e2f76", + "metadata": {}, + "source": [ + "# Whirlwind Tour\n", + "\n", + "\n", + "## What is functorch?\n", + "\n", + "functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n", + "- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n", + "- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n", + "- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n", + "\n", + "Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n", + "\n", + "\n", + "## Why composable function transforms?\n", + "There are a number of use cases that are tricky to do in PyTorch today:\n", + "- computing per-sample-gradients (or other per-sample quantities)\n", + "- running ensembles of models on a single machine\n", + "- efficiently batching together tasks in the inner-loop of MAML\n", + "- efficiently computing Jacobians and Hessians\n", + "- efficiently computing batched Jacobians and Hessians\n", + "\n", + "Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n", + "\n", + "## What are the transforms?\n", + "\n", + "### grad (gradient computation)\n", + "\n", + "`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. to the first input." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f920b923", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import grad\n", + "x = torch.randn([])\n", + "cos_x = grad(lambda x: torch.sin(x))(x)\n", + "assert torch.allclose(cos_x, x.cos())\n", + "\n", + "# Second-order gradients\n", + "neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n", + "assert torch.allclose(neg_sin_x, -x.sin())" + ] + }, + { + "cell_type": "markdown", + "id": "ef3b2d85", + "metadata": {}, + "source": [ + "### vmap (auto-vectorization)\n", + "\n", + "Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n", + "\n", + "`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n", + "\n", + "vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ebac649", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from functorch import vmap\n", + "batch_size, feature_size = 3, 5\n", + "weights = torch.randn(feature_size, requires_grad=True)\n", + "\n", + "def model(feature_vec):\n", + " # Very simple linear model with activation\n", + " assert feature_vec.dim() == 1\n", + " return feature_vec.dot(weights).relu()\n", + "\n", + "examples = torch.randn(batch_size, feature_size)\n", + "result = vmap(model)(examples)" + ] + }, + { + "cell_type": "markdown", + "id": "5161e6d2", + "metadata": {}, + "source": [ + "When composed with `grad`, `vmap` can be used to compute per-sample-gradients:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffb2fcb1", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import vmap\n", + "batch_size, feature_size = 3, 5\n", + "\n", + "def model(weights,feature_vec):\n", + " # Very simple linear model with activation\n", + " assert feature_vec.dim() == 1\n", + " return feature_vec.dot(weights).relu()\n", + "\n", + "def compute_loss(weights, example, target):\n", + " y = model(weights, example)\n", + " return ((y - target) ** 2).mean() # MSELoss\n", + "\n", + "weights = torch.randn(feature_size, requires_grad=True)\n", + "examples = torch.randn(batch_size, feature_size)\n", + "targets = torch.randn(batch_size)\n", + "inputs = (weights,examples, targets)\n", + "grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)" + ] + }, + { + "cell_type": "markdown", + "id": "11d711af", + "metadata": {}, + "source": [ + "### vjp (vector-Jacobian product)\n", + "\n", + "The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad48f9d4", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import vjp\n", + "\n", + "inputs = torch.randn(3)\n", + "func = torch.sin\n", + "cotangents = (torch.randn(3),)\n", + "\n", + "outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)" + ] + }, + { + "cell_type": "markdown", + "id": "e0221270", + "metadata": {}, + "source": [ + "### jvp (Jacobian-vector product)\n", + "\n", + "The `jvp` transforms computes Jacobian-vector-products and is also known as \"forward-mode AD\". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the jvps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3772f43", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import jvp\n", + "x = torch.randn(5)\n", + "y = torch.randn(5)\n", + "f = lambda x, y: (x * y)\n", + "_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))\n", + "assert torch.allclose(output, x + y)" + ] + }, + { + "cell_type": "markdown", + "id": "7b00953b", + "metadata": {}, + "source": [ + "### jacrev, jacfwd, and hessian\n", + "\n", + "The `jacrev` transform returns a new function that takes in `x` and returns the Jacobian of the function\n", + "with respect to `x` using reverse-mode AD." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20f53be2", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import jacrev\n", + "x = torch.randn(5)\n", + "jacobian = jacrev(torch.sin)(x)\n", + "expected = torch.diag(torch.cos(x))\n", + "assert torch.allclose(jacobian, expected)" + ] + }, + { + "cell_type": "markdown", + "id": "b9007c88", + "metadata": {}, + "source": [ + "Use `jacrev` to compute the jacobian. This can be composed with `vmap` to produce batched jacobians:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97d6c382", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(64, 5)\n", + "jacobian = vmap(jacrev(torch.sin))(x)\n", + "assert jacobian.shape == (64, 5, 5)" + ] + }, + { + "cell_type": "markdown", + "id": "cda642ec", + "metadata": {}, + "source": [ + "`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using forward-mode AD:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8c1dedb", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import jacfwd\n", + "x = torch.randn(5)\n", + "jacobian = jacfwd(torch.sin)(x)\n", + "expected = torch.diag(torch.cos(x))\n", + "assert torch.allclose(jacobian, expected)" + ] + }, + { + "cell_type": "markdown", + "id": "39f85b50", + "metadata": {}, + "source": [ + "Composing `jacrev` with itself or `jacfwd` can produce hessians:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e511139", + "metadata": {}, + "outputs": [], + "source": [ + "def f(x):\n", + " return x.sin().sum()\n", + "\n", + "x = torch.randn(5)\n", + "hessian0 = jacrev(jacrev(f))(x)\n", + "hessian1 = jacfwd(jacrev(f))(x)" + ] + }, + { + "cell_type": "markdown", + "id": "18efdc65", + "metadata": {}, + "source": [ + "The `hessian` is a convenience function that combines `jacfwd` and `jacrev`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd1765df", + "metadata": {}, + "outputs": [], + "source": [ + "from functorch import hessian\n", + "\n", + "def f(x):\n", + " return x.sin().sum()\n", + "\n", + "x = torch.randn(5)\n", + "hess = hessian(f)(x)" + ] + }, + { + "cell_type": "markdown", + "id": "b597d7ad", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "Check out our other tutorials (in the left bar) for more detailed explanations of how to apply functorch transforms for various use cases. `functorch` is very much a work in progress and we'd love to hear how you're using it -- we encourage you to start a conversation at our [issues tracker](https://github.com/pytorch/functorch) to discuss your use case." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/functorch/op_analysis/annotated_ops b/functorch/op_analysis/annotated_ops new file mode 100644 index 0000000000000..d4f8d5926ef68 --- /dev/null +++ b/functorch/op_analysis/annotated_ops @@ -0,0 +1,570 @@ +data, misc +rename, named +align_to, named +align_as, named +align_tensors, named +refine_names, named +dropout, composite pointwise +feature_dropout, composite pointwise +alpha_dropout, reduction +feature_alpha_dropout, reduction +abs, primitive pointwise +absolute, alias +angle, composite pointwise +view_as_real, complex +view_as_complex, complex +sgn, primitive pointwise +real, complex +imag, complex +conj, complex +conj_physical, complex +resolve_conj, complex +acos, primitive pointwise +arccos, alias +add, primitive pointwise +addmv, composite matmul +addr, composite pointwise +affine_grid_generator, factory +all, reduction +any, reduction +arange, factory +argmax, reduction +argmin, reduction +acosh, primitive pointwise +arccosh, alias +asinh, primitive pointwise +arcsinh, alias +atanh, primitive pointwise +arctanh, alias +as_strided, view/reshape +asin, primitive pointwise +arcsin, alias +atan, primitive pointwise +arctan, alias +atleast_1d, view/reshape +atleast_2d, view/reshape +atleast_3d, view/reshape +baddbmm, composite matmul +bartlett_window, factory +bernoulli, factory +bilinear, composite matmul +binary_cross_entropy, reduction +binary_cross_entropy_with_logits, reduction +bincount, misc +bitwise_not, primitive pointwise +copysign, composite pointwise +logical_not, composite pointwise +logical_xor, composite pointwise +logical_and, composite pointwise +logical_or, composite pointwise +blackman_window, factory +bmm, composite matmul +broadcast_tensors, view/reshape +broadcast_to, view/reshape +cat, view/reshape +block_diag, view/reshape +ceil, primitive pointwise +chain_matmul, alias +unsafe_chunk, view/reshape +chunk, view/reshape +tensor_split, view/reshape +clamp, composite pointwise +clamp_max, composite pointwise +clamp_min, composite pointwise +clip, alias +complex, complex +polar, composite pointwise +constant_pad_nd, view/reshape +contiguous, view/reshape +cos, primitive pointwise +cosh, primitive pointwise +cosine_embedding_loss, reduction +count_nonzero, reduction +cummax, reduction +cummin, reduction +cumprod, reduction +cumsum, reduction +ctc_loss, misc +diag_embed, view/reshape +diagflat, view/reshape +diagonal, view/reshape +diff, reduction +gradient, misc +div, primitive pointwise +divide, alias +true_divide, primitive pointwise +dot, reduction +vdot, reduction +einsum, composite matmul +embedding, misc +row_stack, alias +embedding_bag, reduction +empty, factory +new_empty, factory +new_empty_strided, factory +new_full, factory +new_zeros, factory +new_ones, factory +empty_quantized, factory +empty_like, factory +empty_strided, factory +erf, primitive pointwise +erfc, composite pointwise +exp, primitive pointwise +exp2, composite pointwise +expm1, composite pointwise +expand, view/reshape +expand_as, view/reshape +eye, factory +flatten, view/reshape +unflatten, view/reshape +floor, primitive pointwise +floor_divide, composite pointwise +frac, primitive pointwise +full, factory +full_like, factory +from_file, misc +gcd, primitive pointwise +lcm, composite pointwise +grid_sampler, misc +grid_sampler_2d, misc +grid_sampler_3d, misc +hann_window, factory +hamming_window, factory +kaiser_window, factory +hinge_embedding_loss, reduction +group_norm, reduction +index, scatter/gather +index_copy, scatter/gather +index_put, scatter/gather +instance_norm, reduction +inverse, linalg +isclose, composite pointwise +isnan, composite pointwise +isreal, composite pointwise +kl_div, reduction +kron, composite pointwise +kthvalue, reduction +layer_norm, reduction +nan_to_num, composite pointwise +linear, composite matmul +fbgemm_linear_int8_weight_fp32_activation, fbgemm +fbgemm_linear_int8_weight, fbgemm +fbgemm_linear_quantize_weight, fbgemm +fbgemm_pack_gemm_matrix_fp16, fbgemm +fbgemm_linear_fp16_weight_fp32_activation, fbgemm +fbgemm_linear_fp16_weight, fbgemm +fbgemm_pack_quantized_matrix, fbgemm +ldexp, composite pointwise +linspace, factory +log, primitive pointwise +log10, composite pointwise +log1p, composite pointwise +log2, composite pointwise +logaddexp, reduction +logaddexp2, reduction +xlogy, composite pointwise +logdet, linalg +logspace, factory +log_softmax, reduction +logcumsumexp, reduction +logsumexp, reduction +margin_ranking_loss, reduction +matmul, composite matmul +matrix_rank, linalg +matrix_power, alias +matrix_exp, linalg +max, reduction +amax, reduction +mean, reduction +median, reduction +nanmedian, reduction +min, reduction +amin, reduction +mm, composite matmul +mode, misc +mul, reduction +multiply, alias +mv, composite matmul +mvlgamma, primitive pointwise +narrow_copy, view/reshape +narrow, view/reshape +ones, factory +ones_like, factory +pairwise_distance, reduction +cdist, reduction +pdist, reduction +cosine_similarity, reduction +permute, scatter/gather +movedim, view/reshape +moveaxis, alias +numpy_T, view/reshape +pixel_shuffle, view/reshape +pixel_unshuffle, view/reshape +channel_shuffle, view/reshape +pin_memory, misc +pinverse, linalg +poisson_nll_loss, reduction +rad2deg, composite pointwise +deg2rad, composite pointwise +scalar_tensor, factory +rand, factory +rand_like, factory +randint, factory +randint_like, factory +randn, factory +randn_like, factory +randperm, factory +range, factory +ravel, view/reshape +reciprocal, composite pointwise +neg, composite pointwise +negative, alias +repeat, view/reshape +repeat_interleave, view/reshape +reshape, view/reshape +reshape_as, view/reshape +round, primitive pointwise +rrelu, composite pointwise +relu, composite pointwise +relu6, composite pointwise +prelu, composite pointwise +gelu, composite pointwise +hardshrink, composite pointwise +rsqrt, composite pointwise +select, view/reshape +selu, composite pointwise +celu, composite pointwise +silu, composite pointwise +mish, composite pointwise +sigmoid, composite pointwise +logit, composite pointwise +sin, primitive pointwise +sinc, composite pointwise +sinh, primitive pointwise +detach, misc +slice, view/reshape +slogdet, linalg +smm, sparse +softmax, reduction +unsafe_split, view/reshape +split, view/reshape +unsafe_split_with_sizes, view/reshape +split_with_sizes, view/reshape +hsplit, view/reshape +vsplit, view/reshape +dsplit, view/reshape +squeeze, view/reshape +sspaddmm, sparse +stack, view/reshape +hstack, view/reshape +vstack, view/reshape +dstack, view/reshape +stft, fft +istft, fft +sum, reduction +nansum, reduction +sum_to_size, reduction +sqrt, primitive pointwise +square, composite pointwise +std, reduction +std_mean, reduction +prod, reduction +t, view/reshape +tan, primitive pointwise +tanh, primitive pointwise +tensordot, composite matmul +threshold, composite pointwise +tile, view/reshape +transpose, view/reshape +one_hot, scatter/gather +flip, view/reshape +fliplr, view/reshape +flipud, view/reshape +roll, view/reshape +rot90, view/reshape +trapz, composite pointwise +triplet_margin_loss, reduction +trunc, composite pointwise +fix, alias +type_as, misc +unique_dim, misc +unique_consecutive, misc +unique_dim_consecutive, misc +unsqueeze, view/reshape +vander, factory +var, reduction +var_mean, reduction +view_as, view/reshape +where, misc +norm_except_dim, reduction +zeros, factory +zeros_like, factory +poisson, factory +binomial, factory +norm, reduction +frexp, composite pointwise +frobenius_norm, reduction +nuclear_norm, reduction +clone, view/reshape +positive, composite pointwise +sub, primitive pointwise +subtract, alias +rsub, primitive pointwise +heaviside, composite pointwise +addmm, composite matmul +sparse_csr_tensor, sparse +sparse_coo_tensor, sparse +sparse_mask, sparse +to_dense, sparse +coalesce, sparse +indices, sparse +values, sparse +crow_indices, sparse +col_indices, sparse +hspmm, sparse +unbind, view/reshape +to_sparse, sparse +quantize_per_tensor, quantize +quantize_per_channel, quantize +dequantize, quantize +q_per_channel_scales, quantize +q_per_channel_zero_points, quantize +int_repr, quantize +fake_quantize_per_tensor_affine, quantize +fake_quantize_per_tensor_affine_cachemask, quantize +fake_quantize_per_channel_affine, quantize +fake_quantize_per_channel_affine_cachemask, quantize +choose_qparams_optimized, quantize +to, misc +meshgrid, view/reshape +cartesian_prod, misc +combinations, misc +lstm, rnn +gru, rnn +rnn_tanh, rnn +rnn_relu, rnn +lstm_cell, rnn +gru_cell, rnn +rnn_tanh_cell, rnn +rnn_relu_cell, rnn +quantized_lstm_cell, rnn +quantized_gru_cell, rnn +quantized_rnn_relu_cell, rnn +quantized_rnn_tanh_cell, rnn +masked_fill, misc +masked_scatter, misc +view, view/reshape +put, scatter/gather +index_add, scatter/gather +index_fill, scatter/gather +scatter, scatter/gather +scatter_add, scatter/gather +bitwise_and, primitive pointwise +bitwise_or, primitive pointwise +bitwise_xor, primitive pointwise +addbmm, composite matmul +diag, view/reshape +cross, misc +triu, view/reshape +tril, view/reshape +tril_indices, factory +triu_indices, factory +trace, reduction +ne, composite pointwise +not_equal, alias +eq, primitive pointwise +ge, primitive pointwise +greater_equal, alias +le, primitive pointwise +less_equal, alias +gt, primitive pointwise +greater, alias +lt, primitive pointwise +less, alias +take, scatter/gather +take_along_dim, scatter/gather +index_select, scatter/gather +masked_select, scatter/gather +nonzero, misc +nonzero_numpy, misc +gather, scatter/gather +addcmul, composite pointwise +addcdiv, composite pointwise +cross_entropy_loss, reduction +lstsq, alias +triangular_solve, linalg +symeig, linalg +eig, linalg +svd, linalg +swapaxes, alias +swapdims, alias +cholesky, linalg +cholesky_solve, linalg +solve, linalg +cholesky_inverse, linalg +qr, linalg +geqrf, linalg +orgqr, alias +ormqr, linalg +lu_solve, linalg +lu_unpack, linalg +multinomial, misc +lgamma, primitive pointwise +digamma, primitive pointwise +polygamma, primitive pointwise +erfinv, primitive pointwise +i0, primitive pointwise +sign, composite pointwise +signbit, composite pointwise +dist, reduction +atan2, primitive pointwise +lerp, composite pointwise +histc, misc +fmod, primitive pointwise +hypot, composite pointwise +igamma, primitive pointwise +igammac, primitive pointwise +nextafter, composite pointwise +remainder, composite pointwise +fmin, reduction +fmax, reduction +maximum, composite pointwise +minimum, composite pointwise +quantile, misc +nanquantile, misc +sort, misc +msort, misc +argsort, misc +topk, misc +renorm, reduction +unfold, misc +pow, primitive pointwise +float_power, composite pointwise +normal, factory +alias, misc +bucketize, misc +searchsorted, misc +mse_loss, reduction +l1_loss, reduction +multi_margin_loss, reduction +multilabel_margin_loss, reduction +multilabel_margin_loss_forward, reduction +nll_loss, reduction +nll_loss_nd, reduction +nll_loss_forward, reduction +nll_loss2d, reduction +nll_loss2d_forward, reduction +smooth_l1_loss, reduction +huber_loss, reduction +soft_margin_loss, reduction +elu, composite pointwise +glu, composite pointwise +hardsigmoid, composite pointwise +hardtanh, composite pointwise +hardswish, composite pointwise +leaky_relu, composite pointwise +log_sigmoid, composite pointwise +log_sigmoid_forward, composite pointwise +rrelu_with_noise, composite pointwise +softplus, composite pointwise +softshrink, composite pointwise +reflection_pad1d, misc +reflection_pad2d, misc +replication_pad1d, misc +replication_pad2d, misc +replication_pad3d, misc +upsample_linear1d, misc +upsample_bilinear2d, misc +upsample_trilinear3d, misc +upsample_bicubic2d, misc +upsample_nearest1d, misc +upsample_nearest2d, misc +upsample_nearest3d, misc +col2im, misc +column_stack, view/reshape +im2col, view/reshape +isfinite, composite pointwise +isinf, composite pointwise +isposinf, composite pointwise +isneginf, composite pointwise +special_entr, primitive pointwise +special_expm1, alias +special_exp2, alias +special_gammaln, alias +special_erf, alias +special_erfc, alias +special_erfinv, alias +special_ndtr, primitive pointwise +special_xlog1py, composite pointwise +special_i0, alias +special_i0e, composite pointwise +special_i1, primitive pointwise +special_i1e, composite pointwise +special_logit, composite pointwise +special_expit, composite pointwise +fft_fft, fft +fft_ifft, fft +fft_rfft, fft +fft_irfft, fft +fft_hfft, fft +fft_ihfft, fft +fft_fft2, fft +fft_ifft2, fft +fft_rfft2, fft +fft_irfft2, fft +fft_fftn, fft +fft_ifftn, fft +fft_rfftn, fft +fft_irfftn, fft +fft_fftfreq, fft +fft_rfftfreq, fft +fft_fftshift, fft +fft_ifftshift, fft +linalg_cholesky_ex, alias +linalg_cholesky, alias +linalg_det, linalg +det, alias +linalg_lstsq, linalg +linalg_slogdet, alias +linalg_eig, alias +linalg_eigvals, alias +linalg_eigh, linalg +linalg_eigvalsh, linalg +linalg_householder_product, linalg +linalg_inv_ex, alias +linalg_inv, alias +inner, reduction +outer, composite pointwise +ger, alias +linalg_norm, reduction +linalg_vector_norm, reduction +linalg_matrix_norm, reduction +linalg_svd, alias +linalg_svdvals, linalg +linalg_cond, linalg +linalg_pinv, alias +linalg_solve, alias +linalg_tensorinv, linalg +linalg_tensorsolve, linalg +linalg_qr, alias +linalg_matrix_power, linalg +linalg_matrix_rank, alias +linalg_multi_dot, linalg +segment_reduce, misc +pad_sequence, misc +flatten_dense_tensors, misc +unflatten_dense_tensors, misc +bitwise_left_shift, primitive pointwise +bitwise_right_shift, primitive pointwise +trapezoid, reduction +special_ndtri, primitive pointwise +special_psi, primitive pointwise +special_digamma, primitive pointwise +special_erfcx, primitive pointwise +special_xlogy, primitive pointwise +special_zeta, primitive pointwise +special_sinc, primitive pointwise +special_round, primitive pointwise +special_log1p, primitive pointwise +isin, reduction diff --git a/functorch/op_analysis/gen_data.py b/functorch/op_analysis/gen_data.py new file mode 100644 index 0000000000000..a65ea0cb96a6f --- /dev/null +++ b/functorch/op_analysis/gen_data.py @@ -0,0 +1,157 @@ +import yaml +import csv +import torch +from collections import defaultdict + + +def get_ops_for_key(key): + # Needs modified PyTorch C++ code to work + if key is None: + ops = torch._C._dispatch_get_registrations_for_dispatch_key() + else: + ops = torch._C._dispatch_get_registrations_for_dispatch_key(key) + cleaned_ops = [] + for i in ops: + if 'aten::' not in i: + continue + cleaned_ops.append(i[6:].strip()) + return set(cleaned_ops) + + +def gen_data(special_op_lists, analysis_name): + all_ops = get_ops_for_key(None) + composite_ops = get_ops_for_key('CompositeImplicitAutograd') + noncomposite_ops = all_ops - composite_ops + + ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader) + + annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))} + from collections import defaultdict + + uniq_ops = [] + uniq_names = set() + overload_types = defaultdict(list) + cnt = 0 + for op in ops: + func_str = op['func'] + name = func_str[:func_str.index('(')] + if '.' in name: + uniq_name = name[:name.index('.')] + overload_types[name[name.index('.') + 1:]].append(name) + else: + uniq_name = name + op['name'] = uniq_name + full_name = func_str[:func_str.index('(')] + op['full_name'] = full_name + ret_type = func_str[func_str.index('->') + 3:] + op['ret_type'] = ret_type + cnt += 1 + if uniq_name in uniq_names: + continue + uniq_names.add(uniq_name) + uniq_ops.append(op) + + def annotate_ops(ops, is_unique): + categorization = defaultdict(int) + for op in ops: + if op['name'][-1] == '_': + categorization['inplace'] += 1 + op['meta'] = 'inplace' + continue + if not is_unique and 'a!' in op['func'].lower(): + categorization['out'] += 1 + op['meta'] = 'out' + continue + if 'conv' in op['name']: + categorization['conv'] += 1 + op['meta'] = 'conv' + continue + if 'pool' in op['name']: + categorization['pool'] += 1 + op['meta'] = 'pool' + continue + if 'backward' in op['name']: + categorization['backward'] += 1 + op['meta'] = 'backward' + continue + if op['name'][0] == '_' and op['name'][1] != '_': + categorization['private'] += 1 + op['meta'] = 'private' + continue + if 'batch_norm' in op['name']: + categorization['batch_norm'] += 1 + op['meta'] = 'batch_norm' + continue + if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']: + categorization['non_tensor'] += 1 + op['meta'] = 'non_tensor' + continue + if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or \ + 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']: + categorization['backend'] += 1 + op['meta'] = 'backend' + continue + if op['name'] in annotated_ops: + categorization['core'] += 1 + op['meta'] = 'core ' + annotated_ops[op['name']] + continue + categorization['core'] += 1 + op['meta'] = 'core unknown' + return categorization + + annotate_ops(ops, is_unique=False) + with open(f"{analysis_name}", 'w') as f: + for op in ops: + info = [ + op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops) + ] + [check(op) for check in special_op_lists] + f.write(','.join([str(i) for i in info]) + '\n') + + +def name_check(lst): + return lambda x: x['name'] in lst + + +def full_name_check(lst): + return lambda x: x['full_name'] in lst + + +# Generates batching rule data +gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt') + + +def remove_suffix(input_string, suffix): + if suffix and input_string.endswith(suffix): + return input_string[:-len(suffix)] + return input_string + +def remove_prefix(input_string, prefix): + if prefix and input_string.startswith(prefix): + return input_string[len(prefix):] + return input_string + + +if True: + with open('run_ops.txt', 'r') as f: + opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()] + with open('count_ops.txt', 'r') as f: + opinfo_counts = [i.strip() for i in f.readlines()] + opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)}) + + def count_fn(x): + return opinfo_counts[x['full_name']] + + with open('run_decompositions.txt', 'r') as f: + decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()] + + with open('public_api', 'r') as f: + ref_api = [i.strip() for i in f.readlines()] + + def has_ref_impl(x): + name = x['name'] + for prefix in ["linalg_", "special_"]: + name = remove_prefix(name, prefix) + prefixes = ['nn.functional', 'fft', 'special', 'linalg'] + return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api + + gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt') diff --git a/functorch/op_analysis/public_api b/functorch/op_analysis/public_api new file mode 100644 index 0000000000000..c3886b6c616f4 --- /dev/null +++ b/functorch/op_analysis/public_api @@ -0,0 +1,623 @@ +__imul__ +__matmul__ +__radd__ +__rand__ +__rdiv__ +__rmatmul__ +__rmod__ +__rmul__ +__ror__ +__rpow__ +__rrshift__ +__rsub__ +__rxor__ +add +atan2 +bitwise_and +bitwise_left_shift +bitwise_or +bitwise_right_shift +bitwise_xor +complex +copysign +div +eq +float_power +floor_divide +fmax +fmin +fmod +gcd +ge +gt +heaviside +hypot +igamma +igammac +isclose +lcm +ldexp +le +logical_and +logical_or +logical_xor +lt +maximum +minimum +mul +ne +nextafter +polar +pow +remainder +rsub +special.xlog1py +special.zeta +sub +true_divide +xlogy +abs +acos +acosh +angle +asin +asinh +atan +atanh +bitwise_not +ceil +conj +conj_physical +cos +cosh +deg2rad +digamma +erf +erfc +erfinv +exp +exp2 +expm1 +floor +frac +frexp +i0 +imag +isfinite +isinf +isnan +isneginf +isposinf +isreal +lgamma +log +log10 +log1p +log2 +logical_not +logit +multigammaln +nan_to_num +neg +nn.functional.celu +nn.functional.elu +nn.functional.hardsigmoid +nn.functional.logsigmoid +nn.functional.mish +nn.functional.rrelu +nn.functional.selu +nn.functional.silu +nn.functional.softsign +nn.functional.tanhshrink +polygamma +positive +rad2deg +real +reciprocal +round +rsqrt +sgn +sigmoid +sign +signbit +sin +sinc +sinh +special.entr +special.erfcx +special.i0e +special.i1 +special.i1e +special.log_ndtr +special.ndtr +special.ndtri +special.polygamma +sqrt +square +tan +tanh +trunc +fft_fftfreq +fft_rfftfreq +fft.fft +fft.fft2 +fft.fftn +fft.fftshift +fft.hfft +fft.hfft2 +fft.hfftn +fft.ifft +fft.ifft2 +fft.ifftn +fft.ifftshift +fft.ihfft +fft.ihfft2 +fft.ihfftn +fft.irfft +fft.irfft2 +fft.irfftn +fft.rfft +fft.rfft2 +fft.rfftn +istft +stft +as_strided +atleast_1d +atleast_2d +atleast_3d +broadcast_shapes +broadcast_tensors +broadcast_to +cat +chunk +column_stack +contiguous +dsplit +dstack +expand +expand_as +flatten +flip +fliplr +flipud +H +hsplit +hstack +mH / adjoint +movedim +mT +narrow +narrow_copy +permute +ravel +repeat +repeat_interleave +reshape +reshape_as +resize_ +resize_as_ +roll +rot90 +split +split_with_sizes +squeeze +stack +swapaxes +T +t +tensor_split +tile +transpose +unbind +unflatten +unfold +unsqueeze +view +view_as +vsplit +vstack +addbmm +addmm +addmv +addr +baddbmm +block_diag +bmm +cartesian_prod +chain_matmul +cholesky +cholesky_inverse +cholesky_solve +cross +diag +diag_embed +diagflat +diagonal +dot +eig +einsum +frobenius_norm +geqrf +ger +inner +inverse +kron +linalg.cholesky +linalg.cholesky_ex +linalg.cond +linalg.cross +linalg.det +linalg.eig +linalg.eigh +linalg.eigvals +linalg.eigvalsh +linalg.householder_product +linalg.inv +linalg.inv_ex +linalg.lstsq +linalg.lu_factor +linalg.lu_factor_ex +linalg.matrix_norm +linalg.matrix_power +linalg.matrix_rank +linalg.multi_dot +linalg.norm +linalg.pinv +linalg.qr +linalg.slogdet +linalg.solve +linalg.solve_triangular +linalg.svd +linalg.svdvals +linalg.tensorinv +linalg.tensorsolve +linalg.vector_norm +lobpcg +logaddexp +logaddexp2 +logdet +lu +lu_solve +lu_unpack +matmul +matrix_exp +mm +mv +norm +nuclear_norm +orgqr +ormqr +outer +pca_lowrank +pinverse +qr +renorm +solve +svd +svd_lowrank +symeig +tensordot +trace +triangular_solve +tril +tril_indices +triu +triu_indices +vdot +all +amax +amin +nanquantile +quantile +sum +nansum +aminmax +any +argmax +std +var +var_mean +std_mean +argmin +cummax +cummin +cumprod +cumsum +logcumsumexp +logsumexp +max +mean +median +min +mode +nanmean +nanmedian +nanquantile +nansum +quantile +std +std_mean +sum +var +var_mean +cumulative_trapezoid +diff +gradient +grid_sampler +grid_sampler_2d +grid_sampler_3d +lerp +trapezoid +trapz +arange +as_tensor +asarray +bartlett_window +blackman_window +empty +empty_like +empty_strided +eye +from_numpy +full +full_like +hamming_window +hann_window +kaiser_window +linspace +logspace +meshgrid +new +new_empty +new_empty +new_empty_strided +new_full +new_full +new_ones +new_ones +new_tensor +new_zeros +new_zeros +normal +ones +ones_like +rand +rand_like +randint +randint_like +randn +randn_like +randn_like +random_ +randperm +range +tensor +vander +zeros +zeros_like +adaptive_max_pool1d_with_indices +adaptive_max_pool2d_with_indices +adaptive_max_pool3d_with_indices +affine_grid +alpha_dropout +batch_norm_gather_stats_with_counts +batch_norm_stats +binary_cross_entropy +binary_cross_entropy_with_logits +cdist +channel_shuffle +constant_pad_nd +conv3d +convolution +dist +dropout3d +feature_alpha_dropout +feature_dropout +fractional_max_pool2d_with_indices +fractional_max_pool3d_with_indices +gru_cell +gumbel_softmax +kl_div +l1_loss +log_softmax +lp_pool1d +lp_pool2d +lstm_cell +margin_ranking_loss +max_pool1d_with_indices +max_pool2d_with_indices +max_pool3d_with_indices +max_unpool1d +max_unpool2d +max_unpool3d +multi_head_attention_forward +multi_margin_loss +multilabel_margin_loss +multilabel_soft_margin_loss +native_batch_norm +native_dropout +native_layer_norm +native_norm +nn.functional.adaptive_avg_pool1d +nn.functional.adaptive_avg_pool2d +nn.functional.adaptive_avg_pool3d +nn.functional.adaptive_max_pool1d +nn.functional.adaptive_max_pool2d +nn.functional.adaptive_max_pool3d +nn.functional.avg_pool1d +nn.functional.avg_pool2d +nn.functional.avg_pool3d +nn.functional.batch_norm +nn.functional.bilinear +nn.functional.conv_transpose1d +nn.functional.conv_transpose2d +nn.functional.conv_transpose3d +nn.functional.conv1d +nn.functional.conv2d +nn.functional.cosine_embedding_loss +nn.functional.cosine_similarity +nn.functional.cross_entropy +nn.functional.ctc_loss +nn.functional.dropout +nn.functional.dropout2d +nn.functional.embedding +nn.functional.embedding_bag +nn.functional.feature_alpha_dropout +nn.functional.fractional_max_pool2d +nn.functional.fractional_max_pool3d +nn.functional.gaussian_nll_loss +nn.functional.gelu +nn.functional.glu +nn.functional.grid_sample +nn.functional.group_norm +nn.functional.hardshrink +nn.functional.hardsigmoid +nn.functional.hardswish +nn.functional.hardtanh +nn.functional.hinge_embedding_loss +nn.functional.huber_loss +nn.functional.instance_norm +nn.functional.interpolate +nn.functional.kl_div +nn.functional.layer_norm +nn.functional.leaky_relu +nn.functional.linear +nn.functional.local_response_norm +nn.functional.logsigmoid +nn.functional.max_pool1d +nn.functional.max_pool2d +nn.functional.max_pool3d +nn.functional.mse_loss +nn.functional.nll_loss +nn.functional.normalize +nn.functional.one_hot +nn.functional.pad +nn.functional.pairwise_distance +nn.functional.pixel_shuffle +nn.functional.pixel_unshuffle +nn.functional.poisson_nll_loss +nn.functional.prelu +nn.functional.relu +nn.functional.relu6 +nn.functional.softmin +nn.functional.softplus +nn.functional.softshrink +nn.functional.softsign +nn.functional.tanhshrink +nn.functional.threshold +nn.functional.unfold +nn.functional.upsample_bilinear +nn.functional.upsample_nearest +pdist +rnn_relu +rnn_relu_cell +rnn_tanh +rnn_tanh_cell +smooth_l1_loss +soft_margin_loss +softmax +triplet_margin_loss +triplet_margin_with_distance_loss +argsort +argwhere +bincount +bucketize +corrcoef +count_nonzero +cov +histc +histogram +histogramdd +isin +kthvalue +msort +nonzero +searchsorted +sort +topk +unique +unique_consecutive +clamp +clamp_max +clamp_min +where +diagonal_scatter +gather +scatter +scatter_add +scatter_reduce +select_scatter +slice_scatter +index_add +index_copy +index_fill +index_put +index_select +item +put +select +take +take_along_dim +bfloat16 +bool +byte +char +clone +cpu +cuda +double +fill_ +float +half +int +long +short +type_as +zero_ +coalesce +dense_dim +sparse_coo_tensor +sparse_csr_tensor +to_dense +to_sparse +to_sparse_coo +complex +conj +conj_physical +real +resolve_conj +resolve_neg +view_as_complex +view_as_real +bernoulli +binomial +cauchy +exponential_ +geometric_ +multinomial +poisson +__getitem__ +combinations +is_complex +is_floating_point +is_signed +sum_to_size +_masked.amax +_masked.amin +_masked.log_softmax +_masked.mean +_masked.norm +_masked.normalize +_masked.prod +_masked.softmax +_masked.softmin +_masked.std +_masked.sum +_masked.var +masked_fill +masked_scatter +masked_select +addcdiv +addcmul +allclose +equal diff --git a/functorch/packaging/build_wheel.sh b/functorch/packaging/build_wheel.sh new file mode 100644 index 0000000000000..074e7dde77141 --- /dev/null +++ b/functorch/packaging/build_wheel.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -ex + +script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +. "$script_dir/pkg_helpers.bash" + +export BUILD_TYPE=wheel +setup_env 0.2.0 +setup_wheel_python +pip_install numpy pyyaml future ninja +pip_install --upgrade setuptools +setup_pip_pytorch_version +python setup.py clean + +if [[ "$OSTYPE" == "msys" ]]; then + "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel +else + python setup.py bdist_wheel +fi diff --git a/functorch/packaging/pkg_helpers.bash b/functorch/packaging/pkg_helpers.bash new file mode 100644 index 0000000000000..329891a07216c --- /dev/null +++ b/functorch/packaging/pkg_helpers.bash @@ -0,0 +1,414 @@ +# A set of useful bash functions for common functionality we need to do in +# many build scripts + + +# Setup CUDA environment variables, based on CU_VERSION +# +# Inputs: +# CU_VERSION (cpu, cu92, cu100) +# NO_CUDA_PACKAGE (bool) +# BUILD_TYPE (conda, wheel) +# +# Outputs: +# VERSION_SUFFIX (e.g., "") +# PYTORCH_VERSION_SUFFIX (e.g., +cpu) +# WHEEL_DIR (e.g., cu100/) +# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) +# FORCE_CUDA (respected by torchvision setup.py) +# NVCC_FLAGS (respected by torchvision setup.py) +# +# Precondition: CUDA versions are installed in their conventional locations in +# /usr/local/cuda-* +# +# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building +# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == +# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a +# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always +# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU +# version of a Python package. But that doesn't apply if you're on OS X, +# since the default CU_VERSION on OS X is cpu. +setup_cuda() { + + # First, compute version suffixes. By default, assume no version suffixes + export VERSION_SUFFIX="" + export PYTORCH_VERSION_SUFFIX="" + export WHEEL_DIR="" + # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) + if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then + export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" + # Match the suffix scheme of pytorch, unless this package does not have + # CUDA builds (in which case, use default) + if [[ -z "$NO_CUDA_PACKAGE" ]]; then + export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" + export WHEEL_DIR="$CU_VERSION/" + fi + fi + + # Now work out the CUDA settings + case "$CU_VERSION" in + cu115) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5" + else + export CUDA_HOME=/usr/local/cuda-11.5/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu113) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" + else + export CUDA_HOME=/usr/local/cuda-11.3/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu112) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" + else + export CUDA_HOME=/usr/local/cuda-11.2/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu111) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" + else + export CUDA_HOME=/usr/local/cuda-11.1/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu110) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" + else + export CUDA_HOME=/usr/local/cuda-11.0/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" + ;; + cu102) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" + else + export CUDA_HOME=/usr/local/cuda-10.2/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu101) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" + else + export CUDA_HOME=/usr/local/cuda-10.1/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu100) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0" + else + export CUDA_HOME=/usr/local/cuda-10.0/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu92) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" + else + export CUDA_HOME=/usr/local/cuda-9.2/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" + ;; + cpu) + ;; + rocm*) + export FORCE_CUDA=1 + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac + if [[ -n "$CUDA_HOME" ]]; then + # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one + export PATH="$CUDA_HOME/bin:$PATH" + export FORCE_CUDA=1 + fi +} + +# Populate build version if necessary, and add version suffix +# +# Inputs: +# BUILD_VERSION (e.g., 0.2.0 or empty) +# VERSION_SUFFIX (e.g., +cpu) +# +# Outputs: +# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) +# +# Fill BUILD_VERSION if it doesn't exist already with a nightly string +# Usage: setup_build_version 0.2.0 +setup_build_version() { + if [[ -z "$BUILD_VERSION" ]]; then + export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" + else + export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" + fi + + # Set build version based on tag if on tag + if [[ -n "${CIRCLE_TAG}" ]]; then + # Strip tag + export BUILD_VERSION="$(echo "${CIRCLE_TAG}" | sed -e 's/^v//' -e 's/-.*$//')${VERSION_SUFFIX}" + fi +} + +# Set some useful variables for OS X, if applicable +setup_macos() { + if [[ "$(uname)" == Darwin ]]; then + export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ + fi +} + + +# Top-level entry point for things every package will need to do +# +# Usage: setup_env 0.2.0 +setup_env() { + setup_cuda + setup_build_version "$1" + setup_macos +} + +# Function to retry functions that sometimes timeout or have flaky failures +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# Inputs: +# PYTHON_VERSION (3.7, 3.8, 3.9) +# UNICODE_ABI (bool) +# +# Outputs: +# PATH modified to put correct Python version in PATH +# +# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image +setup_wheel_python() { + if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then + eval "$(conda shell.bash hook)" + conda env remove -n "env$PYTHON_VERSION" || true + conda create ${CONDA_CHANNEL_FLAGS} -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" + conda activate "env$PYTHON_VERSION" + # Install libpng from Anaconda (defaults) + conda install ${CONDA_CHANNEL_FLAGS} libpng "jpeg<=9b" -y + else + # Install native CentOS libJPEG, freetype and GnuTLS + yum install -y libjpeg-turbo-devel freetype gnutls + case "$PYTHON_VERSION" in + 3.7) python_abi=cp37-cp37m ;; + 3.8) python_abi=cp38-cp38 ;; + 3.9) python_abi=cp39-cp39 ;; + 3.10) python_abi=cp310-cp310 ;; + *) + echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" + exit 1 + ;; + esac + # Download all the dependencies required to compile image and video_reader + # extensions + + mkdir -p ext_libraries + pushd ext_libraries + popd + export PATH="/opt/python/$python_abi/bin:$(pwd)/ext_libraries/bin:$PATH" + fi +} + +# Install with pip a bit more robustly than the default +pip_install() { + retry pip install --progress-bar off "$@" +} + +# Install torch with pip, respecting PYTORCH_VERSION, and record the installed +# version into PYTORCH_VERSION, if applicable +setup_pip_pytorch_version() { + if [[ -z "$PYTORCH_VERSION" ]]; then + # Install latest prerelease version of torch, per our nightlies, consistent + # with the requested cuda version + pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" + if [[ "$CUDA_VERSION" == "cpu" ]]; then + # CUDA and CPU are ABI compatible on the CPU-only parts, so strip + # in this case + export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" + else + export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//')" + fi + else + pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ + -f "https://download.pytorch.org/whl/${CU_VERSION}/torch_stable.html" \ + -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" + fi +} + +# Fill PYTORCH_VERSION with the latest conda nightly version, and +# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions +# +# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. +setup_conda_pytorch_constraint() { + if [[ -z "$PYTORCH_VERSION" ]]; then + export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly -c pytorch" + export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | \ + python -c "import os, sys, json, re; cuver = os.environ.get('CU_VERSION'); \ + cuver_1 = cuver.replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ + cuver_2 = (cuver[:-1] + '.' + cuver[-1]).replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ + print(re.sub(r'\\+.*$', '', \ + [x['version'] for x in json.load(sys.stdin)['pytorch'] \ + if (x['platform'] == 'darwin' or cuver_1 in x['fn'] or cuver_2 in x['fn']) \ + and 'py' + os.environ['PYTHON_VERSION'] in x['fn']][-1]))")" + if [[ -z "$PYTORCH_VERSION" ]]; then + echo "PyTorch version auto detection failed" + echo "No package found for CU_VERSION=$CU_VERSION and PYTHON_VERSION=$PYTHON_VERSION" + exit 1 + fi + else + export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}" + fi + if [[ "$CU_VERSION" == cpu ]]; then + export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" + export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" + else + export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" + export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" + fi + if [[ "$OSTYPE" == msys && "$CU_VERSION" == cu92 ]]; then + export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev" + fi +} + +# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT +setup_conda_cudatoolkit_constraint() { + export CONDA_BUILD_VARIANT="cuda" + if [[ "$(uname)" == Darwin ]]; then + export CONDA_BUILD_VARIANT="cpu" + else + case "$CU_VERSION" in + cu115) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]" + ;; + cu113) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" + ;; + cu112) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" + ;; + cu111) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" + ;; + cu110) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" + ;; + cu102) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" + ;; + cu101) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" + ;; + cu100) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" + ;; + cu92) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" + ;; + cpu) + export CONDA_CUDATOOLKIT_CONSTRAINT="" + export CONDA_BUILD_VARIANT="cpu" + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac + fi +} + +setup_conda_cudatoolkit_plain_constraint() { + export CONDA_BUILD_VARIANT="cuda" + export CMAKE_USE_CUDA=1 + if [[ "$(uname)" == Darwin ]]; then + export CONDA_BUILD_VARIANT="cpu" + export CMAKE_USE_CUDA=0 + else + case "$CU_VERSION" in + cu115) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.5" + ;; + cu113) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3" + ;; + cu112) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.2" + ;; + cu111) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.1" + ;; + cu102) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.2" + ;; + cu101) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.1" + ;; + cu100) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.0" + ;; + cu92) + export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=9.2" + ;; + cpu) + export CONDA_CUDATOOLKIT_CONSTRAINT="" + export CONDA_BUILD_VARIANT="cpu" + export CMAKE_USE_CUDA=0 + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac + fi +} + +# Build the proper compiler package before building the final package +setup_visual_studio_constraint() { + if [[ "$OSTYPE" == "msys" ]]; then + export VSTOOLCHAIN_PACKAGE=vs$VC_YEAR + conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE + cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchvision/conda_build_config.yaml + fi +} + +setup_junit_results_folder() { + if [[ "$CI" == "true" ]]; then + export CONDA_PYTORCH_BUILD_RESULTS_DIRECTORY="${SOURCE_ROOT_DIR}/build_results/results.xml" + fi +} + + +download_copy_ffmpeg() { + if [[ "$OSTYPE" == "msys" ]]; then + # conda install -yq ffmpeg=4.2 -c pytorch + # curl -L -q https://anaconda.org/pytorch/ffmpeg/4.3/download/win-64/ffmpeg-4.3-ha925a31_0.tar.bz2 --output ffmpeg-4.3-ha925a31_0.tar.bz2 + # bzip2 --decompress --stdout ffmpeg-4.3-ha925a31_0.tar.bz2 | tar -x --file=- + # cp Library/bin/*.dll ../torchvision + echo "FFmpeg is disabled currently on Windows" + else + if [[ "$(uname)" == Darwin ]]; then + conda install -yq ffmpeg=4.2 -c pytorch + conda install -yq wget + else + # pushd ext_libraries + # wget -q https://anaconda.org/pytorch/ffmpeg/4.2/download/linux-64/ffmpeg-4.2-hf484d3e_0.tar.bz2 + # tar -xjvf ffmpeg-4.2-hf484d3e_0.tar.bz2 + # rm -rf ffmpeg-4.2-hf484d3e_0.tar.bz2 + # ldconfig + # which ffmpeg + # popd + echo "FFmpeg is disabled currently on Linux" + fi + fi +} diff --git a/functorch/packaging/windows/internal/cuda_install.bat b/functorch/packaging/windows/internal/cuda_install.bat new file mode 100644 index 0000000000000..41960224ebaed --- /dev/null +++ b/functorch/packaging/windows/internal/cuda_install.bat @@ -0,0 +1,264 @@ +@echo on + +if "%CU_VERSION%" == "cpu" ( + echo Skipping for CPU builds + exit /b 0 +) + +set SRC_DIR=%~dp0\.. + +if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" + +rem in unit test workflow, we get CUDA_VERSION, for example 11.1 +if defined CUDA_VERSION ( + set CUDA_VER=%CUDA_VERSION:.=% +) else ( + set CUDA_VER=%CU_VERSION:cu=% +) + +set /a CUDA_VER=%CU_VERSION:cu=% +set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% +set CUDA_VER_MINOR=%CUDA_VER:~-1,1% +set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% + + +if %CUDA_VER% EQU 92 goto cuda92 +if %CUDA_VER% EQU 100 goto cuda100 +if %CUDA_VER% EQU 101 goto cuda101 +if %CUDA_VER% EQU 102 goto cuda102 +if %CUDA_VER% EQU 110 goto cuda110 +if %CUDA_VER% EQU 111 goto cuda111 +if %CUDA_VER% EQU 112 goto cuda112 +if %CUDA_VER% EQU 113 goto cuda113 +if %CUDA_VER% EQU 115 goto cuda115 + + +echo CUDA %CUDA_VERSION_STR% is not supported +exit /b 1 + +:cuda92 +if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" + set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" +) + +goto cuda_common + +:cuda100 + +if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" + set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" +) + +goto cuda_common + +:cuda101 + +if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" + set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" +) + +goto cuda_common + +:cuda102 + +if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" + set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" +) + +rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. +if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( + curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" + if errorlevel 1 exit /b 1 +) + +echo Installing GPU driver DLLs +7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" + +goto cuda_common + +:cuda110 + +if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" + set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" +) + +goto cuda_common + +:cuda111 + +if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" + set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" +) + +goto cuda_common + +:cuda112 + +if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" + set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( + curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" +) + +goto cuda_common + +:cuda113 + +set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe +if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( + curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" + +) + +set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip +if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( + curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" +) + +goto cuda_common + +:cuda115 + +set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe +if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( + curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" +) + +set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip +if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( + curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" +) + +goto cuda_common + +:cuda_common + +if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( + curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" + if errorlevel 1 exit /b 1 +) + +echo Installing CUDA toolkit... +7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" +pushd "%SRC_DIR%\temp_build\cuda" +sc config wuauserv start= disabled +sc stop wuauserv +sc query wuauserv + +start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" +echo %errorlevel% + +popd + +echo Installing VS integration... +rem It's for VS 2019 +if "%CUDA_VER_MAJOR%" == "10" ( + xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" +) +if "%CUDA_VER_MAJOR%" == "11" ( + xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" +) + +echo Installing NvToolsExt... +7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" + +echo Setting up environment... +set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" +set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" +set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" +set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" + +if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( + echo CUDA %CUDA_VERSION_STR% installed failed. + echo --------- RunDll32.exe.log + type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" + echo --------- setup.exe.log ------- + type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" + exit /b 1 +) + +echo Installing cuDNN... +7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" + +echo Cleaning temp files +rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/functorch/packaging/windows/internal/driver_update.bat b/functorch/packaging/windows/internal/driver_update.bat new file mode 100644 index 0000000000000..00b43affc01cc --- /dev/null +++ b/functorch/packaging/windows/internal/driver_update.bat @@ -0,0 +1,25 @@ +set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" +curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe +if errorlevel 1 exit /b 1 + +start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot +if errorlevel 1 exit /b 1 + +del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL + +setlocal EnableDelayedExpansion +set NVIDIA_GPU_EXISTS=0 +for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( + set GPUS=%%i + if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( + SET NVIDIA_GPU_EXISTS=1 + goto gpu_check_end + ) +) +:gpu_check_end +endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% + +if "%NVIDIA_GPU_EXISTS%" == "0" ( + echo "CUDA Driver installation Failed" + exit /b 1 +) diff --git a/functorch/packaging/windows/internal/vc_env_helper.bat b/functorch/packaging/windows/internal/vc_env_helper.bat new file mode 100644 index 0000000000000..e85a372f93d58 --- /dev/null +++ b/functorch/packaging/windows/internal/vc_env_helper.bat @@ -0,0 +1,43 @@ +@echo on + +set VC_VERSION_LOWER=16 +set VC_VERSION_UPPER=17 +if "%VC_YEAR%" == "2017" ( + set VC_VERSION_LOWER=15 + set VC_VERSION_UPPER=16 +) + +for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( + if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( + set "VS15INSTALLDIR=%%i" + set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" + goto vswhere + ) +) + +:vswhere +if "%VSDEVCMD_ARGS%" == "" ( + call "%VS15VCVARSALL%" x64 || exit /b 1 +) else ( + call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 +) + +@echo on + +set DISTUTILS_USE_SDK=1 + +set args=%1 +shift +:start +if [%1] == [] goto done +set args=%args% %1 +shift +goto start + +:done +if "%args%" == "" ( + echo Usage: vc_env_helper.bat [command] [args] + echo e.g. vc_env_helper.bat cl /c test.cpp +) + +%args% || exit /b 1 diff --git a/functorch/packaging/windows/internal/vc_install_helper.sh b/functorch/packaging/windows/internal/vc_install_helper.sh new file mode 100644 index 0000000000000..cdae18065b9f6 --- /dev/null +++ b/functorch/packaging/windows/internal/vc_install_helper.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -ex + +if [[ "$CU_VERSION" == "cu92" ]]; then + export VC_YEAR=2017 + export VSDEVCMD_ARGS="-vcvars_ver=14.13" + powershell packaging/windows/internal/vs2017_install.ps1 +elif [[ "$CU_VERSION" == "cu100" ]]; then + export VC_YEAR=2017 + export VSDEVCMD_ARGS="" + powershell packaging/windows/internal/vs2017_install.ps1 +else + export VC_YEAR=2019 + export VSDEVCMD_ARGS="" +fi diff --git a/functorch/pull_request_template.md b/functorch/pull_request_template.md new file mode 100644 index 0000000000000..abb0f9bfe5184 --- /dev/null +++ b/functorch/pull_request_template.md @@ -0,0 +1,5 @@ +To contribute a change to functorch, please make sure you are submitting a +Pull Request to the functorch folder in https://github.com/pytorch/pytorch +repository. The source of truth for functorch has moved there from +https://github.com/pytorch/functorch ; the pytorch/functorch repository +is now read-only. diff --git a/functorch/setup.cfg b/functorch/setup.cfg new file mode 100644 index 0000000000000..c2f3b448a6c20 --- /dev/null +++ b/functorch/setup.cfg @@ -0,0 +1,18 @@ +[bdist_wheel] +universal=1 + +[metadata] +license_file = LICENSE + +[pep8] +max-line-length = 120 + +[flake8] +max-line-length = 120 +exclude = docs, benchmarks, notebooks, tools +per-file-ignores = + __init__.py: F401 + functorch/_src/decompositions.py: E501 + +[pydocstyle] +select = D417 # Missing argument descriptions in the docstring diff --git a/functorch/setup.py b/functorch/setup.py new file mode 100644 index 0000000000000..aadef78b0f5a0 --- /dev/null +++ b/functorch/setup.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import distutils.command.clean +import shutil +import glob +import os +import subprocess +from setuptools import setup, find_packages +from torch.utils.cpp_extension import ( + CppExtension, + BuildExtension, +) + +cwd = os.path.dirname(os.path.abspath(__file__)) +version_txt = os.path.join(cwd, 'version.txt') +with open(version_txt, 'r') as f: + version = f.readline().strip() + +try: + sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() +except Exception: + sha = 'Unknown' +package_name = 'functorch' + +if os.getenv('BUILD_VERSION'): + version = os.getenv('BUILD_VERSION') +elif sha != 'Unknown': + version += '+' + sha[:7] + + +def write_version_file(): + version_path = os.path.join(cwd, 'functorch', 'version.py') + with open(version_path, 'w') as f: + f.write("__version__ = '{}'\n".format(version)) + f.write("git_version = {}\n".format(repr(sha))) + + +# pytorch_dep = 'torch' +# if os.getenv('PYTORCH_VERSION'): +# pytorch_dep += "==" + os.getenv('PYTORCH_VERSION') +requirements = [ + # This represents a nightly version of PyTorch. + # It can be installed as a binary or from source. + "torch>=1.13.0.dev", +] + +extras = {} +extras["aot"] = ["networkx", ] + + +class clean(distutils.command.clean.clean): + def run(self): + + with open(".gitignore", "r") as f: + ignores = f.read() + for wildcard in filter(None, ignores.split("\n")): + for filename in glob.glob(wildcard): + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) + + # It's an old-style class in Python 2.7... + distutils.command.clean.clean.run(self) + + +def get_extensions(): + extension = CppExtension + + # See functorch/csrc/Macros.h + define_macros = [('FUNCTORCH_BUILD_MAIN_LIB', None)] + + extra_link_args = [] + extra_compile_args = {"cxx": [ + "-O3", + "-std=c++14", + "-fdiagnostics-color=always", + ]} + debug_mode = os.getenv('DEBUG', '0') == '1' + if debug_mode: + print("Compiling in debug mode") + extra_compile_args = { + "cxx": [ + "-O0", + "-fno-inline", + "-g", + "-std=c++14", + "-fdiagnostics-color=always", + ]} + extra_link_args = ["-O0", "-g"] + + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "functorch", "csrc") + + extension_sources = set( + os.path.join(extensions_dir, p) + for p in glob.glob(os.path.join(extensions_dir, "*.cpp")) + ) + sources = list(extension_sources) + + ext_modules = [ + extension( + "functorch._C", + sources, + include_dirs=[this_dir], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] + + return ext_modules + + +class BuildExtension_(BuildExtension): + def build_extensions(self, *args, **kwargs): + # It turns out for windows this isn't populated? + if hasattr(self.compiler, 'compiler_so'): + if '-Wstrict-prototypes' in self.compiler.compiler_so: + self.compiler.compiler_so.remove('-Wstrict-prototypes') + super().build_extensions(*args, **kwargs) + + +if __name__ == '__main__': + print("Building wheel {}-{}".format(package_name, version)) + write_version_file() + setup( + # Metadata + name=package_name, + version=version, + author='PyTorch Core Team', + url="https://github.com/pytorch/functorch", + description='JAX-like composable function transforms for PyTorch', + license='BSD', + + # Package info + packages=find_packages(), + install_requires=requirements, + extras_require=extras, + ext_modules=get_extensions(), + cmdclass={ + "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True), + 'clean': clean, + }) diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py new file mode 100644 index 0000000000000..332cae09045fd --- /dev/null +++ b/functorch/test/common_utils.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import torch +import functorch +from functorch import vmap +import torch.utils._pytree as pytree +from functorch_lagging_op_db import functorch_lagging_op_db +from functorch_additional_op_db import additional_op_db +from torch.testing._internal.common_methods_invocations import DecorateInfo +import os +import unittest +from torch.testing._internal.common_device_type import toleranceOverride + +IS_FBCODE = os.getenv('FUNCTORCH_TEST_FBCODE') == '1' + + +def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values): + outs = [] + for idx in range(batch_size): + flat_args, args_spec = pytree.tree_flatten(batched_args) + flat_dims, dims_spec = pytree.tree_flatten(in_dims) + assert(args_spec == dims_spec) + new_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(flat_args, flat_dims)] + out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values) + outs.append(out) + + loop_out = [] + if isinstance(outs[0], torch.Tensor): + loop_out = torch.stack(outs) + else: + for idx in range(len(outs[0])): + loop_out.append(torch.stack([i[idx] for i in outs], out_dim)) + return loop_out + + +# This is kind of dangerous, please think carefully before using it. +# Known risks: +# - the return better not be mutated so it's best to return immutable types +# (e.g. prefer tuples to list) +# - Don't hash tensors in a global context, that'll keep them around forever +def memoize(fn): + memo = {} + + def wrapped(*args): + if args not in memo: + memo[args] = fn(*args) + return memo[args] + return wrapped + + +# NB: This is O(2 ** num_tensors). +# num_tensors ranges from 1 to 10, with 2-4 being most common. +# Try not to extravagate it if you're modifying it. +@memoize +def get_bdim_choices(num_tensors): + choices = [] + + # full of zeros + choices.append((0,) * num_tensors) + + # All permutations of (-1, None) + options = (-1, None) + for choice in itertools.product(options, repeat=num_tensors): + choices.append(choice) + + assert choices[-1] == (None,) * num_tensors + return tuple(choices[:-1]) + +# NB: This is O(2 ** num_tensors). +# num_tensors ranges from 1 to 10, with 2-4 being most common. +# Try not to extravagate it if you're modifying it. +def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args): + choices = [] + options = (-1, None) + + # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified + if running_mean is None or running_var is None: + choices.append((None,) + (0,) * (num_tensors - 1)) + for choice in itertools.product(options, repeat=num_tensors - 1): + choices.append((None,) + choice) + + else: + # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but + # running_mean/var are unbatched, so this tests all other cases + choices.append((0,) * num_tensors) + for choice in itertools.product(options, repeat=num_tensors): + input_bdim = choice[0] + running_mean_bdim = choice[1] + running_var_bdim = choice[2] + if input_bdim and (not running_mean_bdim or not running_var_bdim): + continue + choices.append(choice) + + assert choices[-1] == (None,) * num_tensors + return tuple(choices[:-1]) + + +def add_batch_dim(arg, bdim, batch_size=3): + assert bdim == 0 or bdim == -1 + assert isinstance(arg, torch.Tensor) + if bdim == 0: + shape = [1] * len(arg.shape) + shape.insert(bdim, batch_size) + return (arg.repeat(shape), bdim) + if bdim == -1: + arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous() + return (arg, bdim) + + +def construct_in_dims(bdim_choice_for_tensors, is_tensors): + result = [] + bdim = iter(bdim_choice_for_tensors) + for is_tensor in is_tensors: + if not is_tensor: + result.append(None) + continue + result.append(next(bdim)) + return tuple(result) + +def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2): + flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values)) + is_tensors = [isinstance(a, torch.Tensor) for a in flat_args] + bdim_choices = get_bdim_choices(sum(is_tensors)) + + @memoize + def get_batched_arg(arg, bdim): + assert isinstance(arg, torch.Tensor) + assert bdim is not None + result, _ = add_batch_dim(arg, bdim, batch_size) + return result + + for bdim_choice in bdim_choices: + flat_in_dims = construct_in_dims(bdim_choice, is_tensors) + + flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim) + for arg, in_dim in zip(flat_args, flat_in_dims)) + batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec) + in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec) + yield batched_args, in_dims, kwarg_values + + +def is_batch_norm_training(op_name, kwarg_values): + batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm") # instance norm calls batch norm + if op_name not in batch_norm_fns: + return False + + # batch norm and instance norm require the value to be a plain bool + default_training = op_name == "nn.functional.instance_norm" # instance norm defaults to training, batch norm doesn't + is_training = tuple(arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool)) + if len(is_training) == 0: + return default_training + else: + assert len(is_training) == 1 + return is_training[0] + + +def get_exhaustive_batched_inputs_batch_norm_is_training(arg_values, kwarg_values, batch_size=2): + flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values)) + is_tensors = [isinstance(a, torch.Tensor) for a in flat_args] + num_tensors = sum(is_tensors) + if num_tensors == 1: # if there's only an input, can't batch it since running_mean/var will be seen as unbatched tensors + return + bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values) + + @memoize + def get_batched_arg(arg, bdim): + assert isinstance(arg, torch.Tensor) + assert bdim is not None + result, _ = add_batch_dim(arg, bdim, batch_size) + return result + + for bdim_choice in bdim_choices: + flat_in_dims = construct_in_dims(bdim_choice, is_tensors) + + flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim) + for arg, in_dim in zip(flat_args, flat_in_dims)) + batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec) + in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec) + yield batched_args, in_dims, kwarg_values + + +def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True): + out_dim = 0 + batch_size = 2 + + if is_batch_norm_and_training: + generator = get_exhaustive_batched_inputs_batch_norm_is_training(arg_values, kwarg_values, batch_size) + else: + generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size) + + for batched_args, in_dims, kwarg_values in generator: + if compute_loop_out: + loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values) + else: + loop_out = None + # Used for debugging the resulting operations + # from functorch import make_fx + # def f(a): + # return op(a) + # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values) + # print(in_dims, [arg.shape for arg in batched_args], kwarg_values) + batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values) + yield (loop_out, batched_out) + + # Tests case where we dispatch to a batching rule with no bdims + # This should be handled by autogenerated plumbing. For vmap support + # added via a manual plumbing you may need to handle this specially. + def add_bdim_if_tensor(x): + if isinstance(x, torch.Tensor): + return x.unsqueeze(1) + return x + + def f(dummy, *args, **kwargs): + return op(*args, **kwargs) + + dummy = torch.ones(batch_size, 1) + expected = pytree.tree_map(add_bdim_if_tensor, batched_out) + + inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims) + outer_in_dims = (0,) + in_dims + output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **kwarg_values) + yield (expected, output) + + +def opinfo_in_dict(opinfo, d): + return (opinfo.name in d) or (f'{opinfo.name}.{opinfo.variant_test_name}' in d) + + +def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + +# TODO: this doesn't work in python < 3.8 + + +def skip(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = functorch_lagging_op_db + additional_op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [o for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for opinfo in matching_opinfos: + decorators = list(opinfo.decorators) + if expected_failure: + decorator = DecorateInfo(unittest.expectedFailure, + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + else: + decorator = DecorateInfo(unittest.skip("Skipped!"), + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + opinfo.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + return wrapped + + +def tol2(op_name, variant_name, override_dct, *, device_type=None): + return (op_name, variant_name, override_dct, device_type) + + +def tol1(op_name, override_dct, *, device_type=None): + return tol2(op_name, '', override_dct, device_type=device_type) + + +def opsToleranceOverride(test_case_name, base_test_name, overrides): + all_opinfos = functorch_lagging_op_db + additional_op_db + for override in overrides: + op_name, variant_name, override, device_type = override + matching_opinfos = [o for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name] + assert len(matching_opinfos) == 1, f"Couldn't find OpInfo for {override}" + opinfo = matching_opinfos[0] + decorators = list(opinfo.decorators) + decorators.append(DecorateInfo( + toleranceOverride(override), + test_case_name, base_test_name, device_type=device_type)) + opinfo.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + return wrapped + + +class DisableVmapFallback: + def __enter__(self): + self.prev_state = functorch._C._is_vmap_fallback_enabled() + functorch._C._set_vmap_fallback_enabled(False) + + def __exit__(self, *ignored): + functorch._C._set_vmap_fallback_enabled(self.prev_state) + + +def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False): + try: + with DisableVmapFallback(): + thunk() + except Exception: + if not dry_run: + raise + if opinfo.variant_test_name: + print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),") + else: + print(f"xfail('{opinfo.name}'),") diff --git a/functorch/test/discover_coverage.py b/functorch/test/discover_coverage.py new file mode 100644 index 0000000000000..d31a25a2ec49a --- /dev/null +++ b/functorch/test/discover_coverage.py @@ -0,0 +1,900 @@ +import torch +import copy +from torch.testing._internal.common_methods_invocations import op_db +from functorch_additional_op_db import additional_op_db +from enum import Enum +import functorch._src.top_operators_github_usage as top_ops +import pprint +import unittest +import enum +from functorch_lagging_op_db import in_functorch_lagging_op_db +from torch.testing._internal.common_device_type import toleranceOverride + +# Importing these files make modifications to the op_db that we need +import test_ops # noqa: F401 +import test_vmap # noqa: F401 + +all_overridable = list(torch.overrides.get_testing_overrides().keys()) + +public_docs = [ + (torch.nn.functional, 'torch.nn.functional', 'docs/source/nn.functional.rst'), + (torch.fft, 'torch.fft', 'docs/source/fft.rst'), + (torch.special, 'torch.special', 'docs/source/special.rst'), + (torch.linalg, 'torch.linalg', 'docs/source/linalg.rst'), + (torch, 'torch', 'docs/source/torch.rst'), + (torch.Tensor, 'torch.Tensor', 'docs/source/tensors.rst'), +] + +# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different + + +def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/debug-cpu'): + results = {} + all_overridable_apis = set(torch.overrides.get_testing_overrides().keys()) + for module, module_name, src in public_docs: + with open(f'{pytorch_root}/{src}') as f: + lines = f.readlines() + # APIs eitehr begin with 4 spaces or ".. autofunction::" + api_lines1 = [line.strip() for line in lines if line.startswith(' ' * 4)] + api_lines2 = [line.strip()[len('.. autofunction:: '):] + for line in lines if line.startswith('.. autofunction::')] + lines = api_lines1 + api_lines2 + lines = [line[7:] if line.startswith('Tensor.') else line for line in lines] + lines = [line for line in lines if hasattr(module, line)] + for line in lines: + api = getattr(module, line) + if api in all_overridable_apis: + results[f'{module_name}.{line}'] = api + return results + + +denylist = { + 'torch.Tensor.data_ptr', + 'torch.Tensor.dim', + 'torch.Tensor.element_size', + 'torch.Tensor.backward', + 'torch.Tensor.as_strided', + 'torch.Tensor.register_hook', + 'torch.Tensor.record_stream', + 'torch.Tensor.qscheme', + 'torch.Tensor.ndimension', + 'torch.Tensor.smm', + 'torch.Tensor.sspaddmm', + 'torch.Tensor.retain_grad', + 'torch.Tensor.sparse_mask', + 'torch.Tensor.sparse_dim', + 'torch.Tensor.dense_dim', + 'torch.Tensor.values', + 'torch.Tensor.indices', + 'torch.Tensor.numel', + 'torch.Tensor.size', + 'torch.Tensor.nelement', + 'torch.Tensor.q_scale', + 'torch.Tensor.q_zero_point', + 'torch.Tensor.q_per_channel_scales', + 'torch.Tensor.q_per_channel_zero_points', + 'torch.Tensor.q_per_channel_axis', + 'torch.Tensor.int_repr', + 'torch.Tensor.to_sparse', + 'torch.Tensor.is_inference', + 'torch.Tensor.storage', + 'torch.Tensor.storage_type', +} + + +def get_method_only_ops_we_care_about(): + apis = get_public_overridable_apis() + result = [] + for key, _ in apis.items(): + if not key.startswith('torch.Tensor'): + continue + if key in denylist: + continue + api = key.split('.')[2] + # filter out in-place + if api.endswith('_'): + continue + if f'torch.{api}' not in apis.keys(): + result.append(api) + return result + +# Deduplicates torch.abs and Tensor.abs + + +def get_public_overridable_ops(): + results = get_public_overridable_apis() + cpy = copy.deepcopy(results) + for key, _ in cpy.items(): + if not key.startswith('torch.Tensor'): + continue + api = key.split('.')[2] + if f'torch.{api}' in results.keys(): + del results[key] + return results + + +def get_public_overridable_outplace_ops(): + results = get_public_overridable_ops() + cpy = copy.deepcopy(results) + for key, _ in cpy.items(): + # NB: there are no dunder methods bcs we don't document those + if key.endswith('_'): + del results[key] + return results + + +def get_public_overridable_outplace_we_care_about(): + results = get_public_overridable_outplace_ops() + cpy = copy.deepcopy(results) + for key, _ in cpy.items(): + # quantization + if 'quant' in key or '.q_' in key: + del results[key] + + # is_cpu, etc. It doesn't make sense to have OpInfos for these + if '.is_' in key: + del results[key] + + if key in denylist and key in results: + del results[key] + return results + +# e.g. nn.functional.softmax + + +def get_op(dotted_name): + names = dotted_name.split('.') + mod = torch + for name in names: + if not hasattr(mod, name): + return None + mod = getattr(mod, name) + return mod + +# Maps function -> [OpInfo] + + +def get_ops_covered_by_opinfos(): + ops = {} + + def safe_append(dct, key, val): + if key in dct: + dct[key].append(val) + else: + dct[key] = [val] + + for opinfo in op_db: + func_op = get_op(opinfo.name) + if func_op: + safe_append(ops, func_op, opinfo) + if opinfo.method_variant: + safe_append(ops, opinfo.method_variant, opinfo) + if opinfo.inplace_variant: + safe_append(ops, opinfo.inplace_variant, opinfo) + for alias in opinfo.aliases: + safe_append(ops, alias.op, opinfo) + return ops + + +factory_fns = { + 'tensor', 'zeros', 'ones', 'randn', 'arange', 'rand', 'empty', 'randperm', + 'linspace', 'logspace', 'hann_window', 'full', 'eye', 'blackman_window', + 'barlett_window', 'randint', 'range', 'arange', +} + + +def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False): + denylist = set({ + # These are either not real "operators", factory functions + # that trivially work, or not-documented ops. + 'load', 'no_grad', 'save', 'from_numpy', + 'manual_seed', 'set_grad_enabled', + 'set_default_tensor_type', 'set_num_threads', + 'set_printoptions', 'numel', + 'set_default_dtype', 'sparse_coo_tensor', 'set_rng_state', + 'get_rng_state', 'get_default_dtype', 'initial_seed', + 'get_num_threads', 'quantize_per_tensor', + 'hann_window', 'is_tensor', 'as_tensor', + 'equal', 'enable_grad', 'seed', 'is_storage', + 'is_floating_point', 'nn.functional.torch', + 'set_flush_denormal', 'set_num_interop_threads', 'dequantize', + 'get_num_interop_threads', 'nn.functional.math', + 'nn.functional.threshold_', + 'nn.functional.selu_', + 'nn.functional.elu_', + 'nn.functional.rrelu_', + 'nn.functional.leaky_relu_', + 'nn.functional.hardtanh_', + 'nn.functional.has_torch_function', + 'nn.functional.has_torch_function_unary', + 'nn.functional.has_torch_function_variadic', + 'nn.functional.handle_torch_function', + 'nn.functional.adaptive_max_pool1d_with_indices', + 'nn.functional.adaptive_max_pool2d_with_indices', + 'nn.functional.adaptive_max_pool3d_with_indices', + 'nn.functional.fractional_max_pool2d_with_indices', + 'nn.functional.fractional_max_pool3d_with_indices', + 'is_complex', + 'grad', + 'quantize_per_channel', + 'nn.functional.max_pool2d_with_indices', + 'nn.functional.max_pool3d_with_indices', + 'nn.functional.max_pool1d_with_indices', + 'nn.functional.celu_', + 'nn.functional.grad', + 'nn.functional.relu_', + 'nn.functional.boolean_dispatch', + 'nn.functional.assert_int_or_pair', + 'fft', # is namespace + }) + + torch_ops = top_ops.top_torch + nn_fn_ops = top_ops.get_nn_functional_top_list() + torch_ops = [op for op in torch_ops if op[0] not in denylist] + nn_fn_ops = [op for op in nn_fn_ops if op[0] not in denylist] + + ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold] + + # Now, sort by priority + ops.sort(reverse=True, key=lambda op: op[1]) + if not with_counts: + ops = [op[0] for op in ops] + return ops + + +def get_ops_percentage(torch_threshold, nn_fn_threshold): + data = top_ops.top_torch + top_ops.get_nn_functional_top_list() + + def get_num_usages(opname): + # Ignore this, this is heavily inflated + if opname == 't': + return 0 + result = [op[1] for op in data if op[0] == opname] + assert len(result) == 1 + return result[0] + + # get all operators that are not in the denylist + all_ops = get_top_ops(999999, 999999) + total_op_usages = sum([get_num_usages(op) for op in all_ops]) + + # get subset of all operators + subset_ops = get_top_ops(torch_threshold, nn_fn_threshold) + subset_op_usages = sum([get_num_usages(op) for op in subset_ops]) + return subset_op_usages / total_op_usages + + +def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0): + ops = get_top_ops(torch_threshold, nn_fn_threshold) + + ops_with_opinfo = [] + for op in op_db: + ops_with_opinfo.append(op.name) + ops_with_opinfo.extend([op.name for op in op.aliases]) + ops_with_opinfo = set(ops_with_opinfo) + + result = [op for op in ops if op not in ops_with_opinfo] + result = [op for op in result if op not in denylist] + result = [op for op in result if op not in factory_fns] + return result + + +def get_covered_ops(ops_list, invert=False): + ops_covered_by_opinfo = get_ops_covered_by_opinfos() + overridable_outplace_ops = ops_list + results = {} + for key, op in overridable_outplace_ops.items(): + cond = op in ops_covered_by_opinfo + if invert: + cond = not cond + if cond: + results[key] = op + return results + + +class Status(Enum): + Correct = 0 + Fast = 1 + + +tests = { + 'test_vmap_exhaustive', + 'test_op_has_batch_rule', + 'test_vjp', + 'test_vmapvjp', + 'test_vmapvjp_has_batch_rule', + 'test_jvp', + 'test_vmapjvp', +} + + +def is_decorateinfo_skip_or_xfail(decorateinfo): + assert len(decorateinfo.decorators) == 1 + actual_decorator = decorateinfo.decorators[0] + if isinstance(actual_decorator, toleranceOverride): + return False + if actual_decorator == unittest.expectedFailure: + return True + # Assume the rest are skips + return True + + +def get_all_tested_ops(): + overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about() + op_to_opinfo = get_ops_covered_by_opinfos() + result = set({}) + for name, op in get_covered_ops(overridable_outplace_we_care_about).items(): + opinfos = op_to_opinfo[op] + for opinfo in opinfos: + result.add(opinfo.name) + return result + + +def get_skipped_or_xfailed_ops_for(test_name): + overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about() + op_to_opinfo = get_ops_covered_by_opinfos() + result = set({}) + for name, op in get_covered_ops(overridable_outplace_we_care_about).items(): + opinfos = op_to_opinfo[op] + for opinfo in opinfos: + for decorator in opinfo.decorators: + if not hasattr(decorator, 'test_name'): + continue + if decorator.test_name != test_name: + continue + if is_decorateinfo_skip_or_xfail(decorator): + result.add(opinfo.name) + return result + + +def get_statuses(for_subset=None, invert=False): + overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about() + if for_subset is not None: + overridable_outplace_we_care_about = { + k: v + for k, v in overridable_outplace_we_care_about.items() + # Removes "torch." + if k[6:] in for_subset + } + op_to_opinfo = get_ops_covered_by_opinfos() + result = {} + _ = get_covered_ops(overridable_outplace_we_care_about) + + def get_covered_tests(op): + opinfos = op_to_opinfo[op] + result = copy.deepcopy(tests) + for opinfo in opinfos: + for decorator in opinfo.decorators: + if not hasattr(decorator, 'test_name'): + continue + if decorator.test_name in tests and decorator.test_name in result: + result.remove(decorator.test_name) + return result + + def get_all_aliases(op): + opinfos = op_to_opinfo[op] + result = [] + for opinfo in opinfos: + result.append(opinfo.name) + result.extend(opinfo.aliases) + return set(result) + + for name, op in get_covered_ops(overridable_outplace_we_care_about).items(): + successful_tests = get_covered_tests(op) + failed_tests = tests - successful_tests + result[name] = failed_tests if invert else successful_tests + return result + + +def transpose_statuses(for_subset=None, invert=False): + statuses = get_statuses(for_subset, invert=invert) + result = {} + for test in tests: + result[test] = set({}) + for op, supported in statuses.items(): + for test in supported: + result[test].add(op) + return result + + +overridable_apis = get_public_overridable_apis() + +overridable_ops = get_public_overridable_ops() + +overridable_outplace_ops = get_public_overridable_outplace_ops() + +overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about() + +tested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about) +untested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about, invert=True) + +# print("List of OpInfos we need:") +# for key in untested_overridable_outplace_ops.keys(): +# print(key) +# print("-" * 80) +# print("") + +print(f'Overridable public APIs: {len(overridable_apis)}') +print(f'Overridable public ops: {len(overridable_ops)}') +print(f'Overridable public outplace ops: {len(overridable_outplace_ops)}') +print(f'Overridable public outplace ops we care about: {len(overridable_outplace_we_care_about)}') +print(f'OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}') + + +def remove_torch(name): + assert name[:6] == 'torch.' + return name[6:] + + +def get_list_of_all_tests(): + all_tests = list(tested_overridable_outplace_ops.keys()) + return set([remove_torch(test) for test in all_tests]) + + +mytest = { + 'test_vmap_exhaustive', + 'test_op_has_batch_rule', + 'test_vjp', + 'test_vmapvjp', + 'test_vmapvjp_has_batch_rule', +} + +print('*' * 80) +all_tests = get_list_of_all_tests() +for test in mytest: + result = get_skipped_or_xfailed_ops_for(test) + diff = len(all_tests - result) + print(f'{test}: {diff}') + + +def get_jvp_coverage(subset=None): + # - number that support autograd + # - number that support forward_ad (in pytorch core) + # - number that support functorch.jvp + op_to_opinfo = get_ops_covered_by_opinfos() + ops_dct = tested_overridable_outplace_ops + if subset is not None: + ops_dct = {name: op for name, op in ops_dct.items() + if remove_torch(name) in subset} + supports_autograd_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items() + if op_to_opinfo[fn][0].supports_autograd} + supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items() + if op_to_opinfo[fn][0].supports_forward_ad} + + ops = set([remove_torch(test) for test in list(ops_dct.keys())]) + supports_autograd = set([remove_torch(test) + for test in list(supports_autograd_ops_dct.keys())]) + supports_forward_ad = set([remove_torch(test) + for test in list(supports_forwardad_ops_dct.keys())]) + assert supports_forward_ad.issubset(supports_autograd) + assert supports_autograd.issubset(ops) + + failed_ops = get_skipped_or_xfailed_ops_for('test_jvp') + + coverage = len(supports_forward_ad - failed_ops) + no_forward_ad = len(supports_autograd) - len(supports_forward_ad) + print(f'test_jvp, {coverage}, {no_forward_ad}, {len(ops)}') + + +get_jvp_coverage() +get_jvp_coverage(get_top_ops(100, 25)) +for op in get_top_ops(100, 25): + print(op) +print('*' * 80) + +# result = get_skipped_or_xfailed_ops_for('test_vmap_exhaustive') +# result = get_skipped_or_xfailed_ops_for('test_op_has_batch_rule') +# result = get_skipped_or_xfailed_ops_for('test_vjp') +# result = get_skipped_or_xfailed_ops_for('test_vmapvjp') +# result = get_skipped_or_xfailed_ops_for('test_vmapvjp_has_batch_rule') +# import pdb; pdb.set_trace() + +statuses = transpose_statuses() +for test in tests: + print(f'{test} coverage {len(statuses[test])}') + +method_only_ops = get_method_only_ops_we_care_about() +# for op in method_only_ops: +# print(f' {op},') + +top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25) +print('=' * 80) +for op in top_ops_not_covered_by_opinfo: + print(f'{op}, {top_ops.usage_count[op]}') + +# print("top ops not covered by opinfo: ") +# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50) +# for op in top_ops_not_covered_by_opinfo: +# print(f'{op}, {top_ops.usage_count[op]}') + +# print("top ops not covered by opinfo: ") +# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(220, 92) +# for op in top_ops_not_covered_by_opinfo: +# print(f'{op}, {top_ops.usage_count[op]}') + +# print("top ops not covered by opinfo: ") +# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(999, 999) +# for op in top_ops_not_covered_by_opinfo: +# print(f'{op}, {top_ops.usage_count[op]}') + + +def remove_from_set(parent, to_remove): + for to_remove_elt in to_remove: + if to_remove_elt in parent: + parent.remove(to_remove_elt) + + +def print_coverage_info(th=100, nn=25): + print('=' * 80) + print(f"top {th}, {nn} coverage") + statuses = transpose_statuses(get_top_ops(th, nn), invert=True) + top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn) + + # testing problems + exemptions = { + 'torch.nn.functional.dropout', # randomness + } + + # Allowed exemptions + vmap_exemptions = { + 'torch.randn_like', # randomness + 'torch.rand_like', # randomness + 'torch.allclose', # number output + 'torch.unique', # dynamic + 'torch.nonzero', # dynamic + 'torch.masked_select', # dynamic + 'torch.prod', # dynamic (backward) + 'torch.norm', # norm with nuc is not commonly used; we support the other cases. + 'torch.svd', # There isn't a bug, it is just nondeterministic so we can't test it. + 'torch.nn.functional.embedding', # We support everything except the sparse option. + } + remove_from_set(statuses['test_vmap_exhaustive'], vmap_exemptions) + remove_from_set(statuses['test_vmapvjp'], vmap_exemptions) + remove_from_set(statuses['test_vmapvjp_has_batch_rule'], vmap_exemptions) + remove_from_set(statuses['test_op_has_batch_rule'], vmap_exemptions) + remove_from_set(statuses['test_vmapjvp'], vmap_exemptions) + for test in tests: + remove_from_set(statuses[test], exemptions) + + print(f"total ops in set: {th + nn}") + print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}") + for test in tests: + if test in {'test_jvp', 'test_vmapjvp'}: + continue + print(f'{test} failing coverage {len(statuses[test])}') + + # We don't care about these yet + del statuses['test_jvp'] + del statuses['test_vmapjvp'] + + pprint.pprint(statuses) + + +def get_name_to_opinfo_map(): + dct = {} + for op in (op_db + additional_op_db): + def add(name, op): + if name not in dct: + dct[name] = [] + dct[name].append(op) + add(op.name, op) + for alias in op.aliases: + add(alias.name, op) + return dct + + +NAME_TO_OPINFO = get_name_to_opinfo_map() + + +class Support(enum.Enum): + NO = 0 + YES = 1 + UNKNOWN = 2 + + +FACTORY_FNS = { + 'tensor', 'zeros', 'ones', 'randn', 'arange', 'rand', 'empty', 'range', + 'full', 'randperm', 'eye', 'randint', 'linspace', 'logspace', +} + +VJP_EXEMPTIONS = { + 'nn.functional.dropout', # not actually problem, randomness testing artifact + 'nn.functional.dropout2d', # not actually problem, randomness testing artifact + 'nn.functional.rrelu', # not actually problem, randomness testing artifact + 'bernoulli', # not actually problem, randomness testing artifact + 'normal', # not actually problem, randomness testing artifact +} + +VMAP_EXEMPTIONS = { + 'randn_like', # randomness + 'rand_like', # randomness + 'allclose', # number output + 'unique', # dynamic + 'nonzero', # dynamic + 'masked_select', # dynamic + 'prod', # dynamic (backward) + 'norm', # norm with nuc is not commonly used; we support the other cases. + 'svd', # There isn't a bug, it is just nondeterministic so we can't test it. + 'nn.functional.embedding', # We support everything except the sparse option. + 'nn.functional.dropout', # randomness + 'nn.functional.dropout2d', # randomness + 'bernoulli', # randomness + 'multinomial', # randomness + 'normal', # randomness +} + +JVP_EXEMPTIONS = { + 'nn.functional.dropout', # not actually problem, randomness testing artifact + 'nn.functional.dropout2d', # not actually problem, randomness testing artifact + 'nn.functional.rrelu', # not actually problem, randomness testing artifact + 'normal', # not actually problem, randomness testing artifact + 'bernoulli', # not actually problem, randomness testing artifact +} + + +class Operator: + def __init__(self, name): + self.name = name + self.opinfos = NAME_TO_OPINFO.get(name, None) + assert self.opinfos is None or len(self.opinfos) > 0 + + def has_opinfo(self): + return self.opinfos is not None + + def __repr__(self): + return f'Operator("{self.name}")' + + def __hash__(self): + return hash(self.name) + + def no_opinfos_skip_test(self, test_name): + """Returns NO if any opinfos have a skip or xfail for the test""" + if not self.has_opinfo(): + return Support.UNKNOWN + if not any([o in additional_op_db for o in self.opinfos]): + if not any([in_functorch_lagging_op_db(o) for o in self.opinfos]): + return Support.UNKNOWN + for opinfo in self.opinfos: + for decorator in opinfo.decorators: + if not hasattr(decorator, 'test_name'): + continue + if decorator.test_name != test_name: + continue + if is_decorateinfo_skip_or_xfail(decorator): + return Support.NO + return Support.YES + + def any_opinfo_attr(self, attr): + if not self.has_opinfo(): + raise RuntimeError() + return any([getattr(opinfo, attr) for opinfo in self.opinfos]) + + def all_opinfo_attr(self, attr): + if not self.has_opinfo(): + raise RuntimeError() + return all([getattr(opinfo, attr) for opinfo in self.opinfos]) + + def supports_vjp(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in VJP_EXEMPTIONS: + return Support.YES + return self.no_opinfos_skip_test('test_vjp') + + def supports_vmap(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in VMAP_EXEMPTIONS: + return Support.YES + return self.no_opinfos_skip_test('test_vmap_exhaustive') + + def supports_fast_vmap(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in VMAP_EXEMPTIONS: + return Support.YES + return self.no_opinfos_skip_test('test_op_has_batch_rule') + + def supports_vmapvjp(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in VMAP_EXEMPTIONS: + return Support.YES + return self.no_opinfos_skip_test('test_vmapvjp') + + def supports_fast_vmapvjp(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in VMAP_EXEMPTIONS: + return Support.YES + return self.no_opinfos_skip_test('test_vmapvjp_has_batch_rule') + + def supports_jvp(self): + if self.name in FACTORY_FNS: + return Support.YES + if self.name in JVP_EXEMPTIONS: + return Support.YES + if not self.has_opinfo(): + return Support.UNKNOWN + if self.any_opinfo_attr('supports_autograd') and \ + not self.all_opinfo_attr('supports_forward_ad'): + return Support.NO + return self.no_opinfos_skip_test('test_jvp') + + def supports_jvpvjp(self): + if self.name in FACTORY_FNS: + return Support.YES + exemptions = { + # we have support (see OpInfo), testing artifact + 'nn.functional.dropout2d', + 'nn.functional.dropout', + # exception: we dont even support double backward for this + 'nn.functional.hardswish', + 'bernoulli', # this isn't differentiable + 'normal', # not differentiable + } + if self.name in exemptions: + return Support.YES + return self.no_opinfos_skip_test('test_jvpvjp') + + def _supports_vmapjvp_base(self, test): + if self.name in FACTORY_FNS: + return Support.YES + VMAPJVP_EXEMPTIONS = { + 'prod', # dynamic (backward) + 'nn.functional.batch_norm', # testing problem + 'normal', # not actually problem, randomness testing artifact + 'bernoulli', # not actually problem, randomness testing artifact + 'nn.functional.dropout2d', # not actually problem, randomness testing artifact + 'nn.functional.dropout', # not actually problem, randomness testing artifact + # Not a problem. + # It's just that the max_norm testing mutates inputs... + # (we have our own functorch variant of the OpInfo without max_norm) + 'nn.functional.embedding', + } + if self.name in VMAPJVP_EXEMPTIONS: + return Support.YES + if not self.has_opinfo(): + return Support.UNKNOWN + if self.any_opinfo_attr('supports_autograd') and \ + not self.all_opinfo_attr('supports_forward_ad'): + return Support.NO + return self.no_opinfos_skip_test(test) + + def supports_vmapjvp(self): + return self._supports_vmapjvp_base('test_vmapjvpall') + + def supports_fast_vmapjvp(self): + return self._supports_vmapjvp_base('test_vmapjvpall_has_batch_rule') + + +class OperatorSet: + def __init__(self, operators): + self.data = set(operators) + + @classmethod + def from_names(cls, names): + return OperatorSet([Operator(name) for name in names]) + + @classmethod + def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold): + names = get_top_ops(torch_threshold, nn_fn_threshold) + return cls.from_names(names) + + @classmethod + def from_top125(cls): + return cls.from_top_ops_threshold(100, 25) + + @classmethod + def from_top160(cls): + return cls.from_top_ops_threshold(107, 53) + + @classmethod + def all(cls): + dct = get_public_overridable_outplace_we_care_about() + names = dct.keys() + names_sanitized = [] + for n in names: + torch_tensor = 'torch.Tensor.' + torch_dot = 'torch.' + if n.startswith(torch_tensor): + names_sanitized.append(n[len(torch_tensor):]) + elif n.startswith(torch_dot): + names_sanitized.append(n[len(torch_dot):]) + else: + raise AssertionError() + return cls.from_names(names_sanitized) + + def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)): + result = {} + for key in filter: + result[key] = set([]) + for op in self.data: + support_status = operator_method(op) + if support_status in filter: + result[support_status].add(op) + return result + + def summary(self): + checks = [ + 'supports_vjp', + 'supports_vmap', + 'supports_fast_vmap', + 'supports_vmapvjp', + 'supports_fast_vmapvjp', + 'supports_jvp', + 'supports_vmapjvp', + 'supports_fast_vmapjvp', + 'supports_jvpvjp', + ] + result = ['test, yes, no, unknown'] + for check in checks: + accessor = getattr(Operator, check) + all_results = self.query(accessor) + yes_amt = len(all_results[Support.YES]) + no_amt = len(all_results[Support.NO]) + unknown_amt = len(all_results[Support.UNKNOWN]) + result.append(f'{check}, {yes_amt}, {no_amt}, {unknown_amt}') + return '\n'.join(result) + + +opset = OperatorSet.all() +has_no_opinfo = opset.query(Operator.has_opinfo, (False,)) + +print("=" * 30 + " Summary " + "=" * 30) +print(f'% of usages on github: {get_ops_percentage(99999, 99999)}') +print(opset.summary()) + +# sanity checks +result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) + +print("=" * 30 + " Top 60 Summary " + "=" * 30) +print(f'% of usages on github: {get_ops_percentage(35, 25)}') +opset = OperatorSet.from_top_ops_threshold(35, 25) +# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN)) +# kpprint.pprint(result) +# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# pprint.pprint(result) +print(opset.summary()) + +print("=" * 30 + " Top 125 Summary " + "=" * 30) +print(f'% of usages on github: {get_ops_percentage(100, 25)}') +opset = OperatorSet.from_top125() +# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +print("supports_vjp") +result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN)) +pprint.pprint(result) +print("supports_jvp") +result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN)) +pprint.pprint(result) +print("supports_vmapjvp") +result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN)) +pprint.pprint(result) +print("supports_jvpvjp") +result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN)) +pprint.pprint(result) +# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# pprint.pprint(result) +print(opset.summary()) + +# print("=" * 30 + " Top 160 Summary " + "=" * 30) +# opset = OperatorSet.from_top160() +# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN)) +# pprint.pprint(result) +# print(opset.summary()) + +# Print list of everything in order +# all_ops = get_top_ops(999999, 999999, with_counts=True) +# for op, count in all_ops: +# print(f'{op}, {count}') diff --git a/functorch/test/functorch_additional_op_db.py b/functorch/test/functorch_additional_op_db.py new file mode 100644 index 0000000000000..b090121d21807 --- /dev/null +++ b/functorch/test/functorch_additional_op_db.py @@ -0,0 +1,569 @@ +from functools import partial +import itertools +import unittest + +import torch + +from torch.testing import \ + (floating_types, floating_types_and, all_types_and_complex_and) +from torch.testing._internal.common_utils import make_tensor +from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput, DecorateInfo + +# List of OpInfos that aren't in PyTorch Core yet. +# They are here because we wanted a fast way of writing OpInfos and may not be +# 100% correct (w.r.t. to dtypes and other options). +# TODO: Figure out how to upstream these, delete them when they're upstreamed + +additional_op_db = [] + +# https://github.com/pytorch/pytorch/pull/61068 + + +def sample_inputs_conv2d(has_bias, self, device, dtype, requires_grad, extra_args=(), groups=1): + in_ch, out_ch = 6, 4 + inp = make_tensor((2, in_ch * groups, 7, 5), device=device, dtype=dtype, + requires_grad=requires_grad, low=-1, high=1) + weight = make_tensor((out_ch * groups, in_ch, 3, 2), device=device, dtype=dtype, + requires_grad=requires_grad, low=-1, high=1) + bias = None + if has_bias: + bias = make_tensor((out_ch * groups,), device=device, dtype=dtype, + requires_grad=requires_grad, low=-1, high=1) + return [SampleInput(inp, args=((weight, bias) + extra_args))] + + +additional_op_db.extend([ + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='no_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, False), + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 2))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_no_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, False, extra_args=((2, 2))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_padding_with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_padding_no_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='strided_padding_dilation_with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1), (2, 2))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='strided_padding_dilation_no_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1), (2, 2))), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_groups_with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 2), groups=2), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), + OpInfo('nn.functional.conv2d', + aten_name="conv2d", + variant_test_name='stride_depthwise_with_bias', + supports_autograd=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 6), groups=6), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types(), + supports_out=False), +]) + + +# TODO: PyTorch core has a check for if requires_grad=True or not. +# We actually want to test more things for backward here which is why we have our own +def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high) + + M = 20 + S = 5 + + def generator(): + # 0-D index tensor + idx = make_long_input((), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 2 + idx[1, 1] = 2 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},) + + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 4 + idx[1, 1] = 4 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},) + + # Scale the gradient based on the inverse frequency of a particular index. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},) + + return list(generator()) + + +additional_op_db.append( + OpInfo( + "nn.functional.embedding", + variant_test_name="functorch", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_embedding, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + )) + + +def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape, requires_grad=requires_grad): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad) + S = 5 + + shapes = ((S, S), (S, S, S), (S, S, S, S)) + reductions = ("none", "mean", "sum") + + for shape, reduction in itertools.product(shapes, reductions): + yield SampleInput(make_input(shape), + args=(make_input(shape, requires_grad=rhs_requires_grad),), + kwargs={"reduction": reduction}) + + +additional_op_db.append( + OpInfo( + "nn.functional.mse_loss", + variant_test_name="functorch", + sample_inputs_func=sample_inputs_mse_loss, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16), + backward_dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + )) + + +# TODO: upstream sample inputs to pytorch/pytorch. +# We are more comprehensive. +def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): + # Short for "advanced index" + adv_idx = torch.LongTensor([[0, 1], [2, 3]]) + S = 5 + # self_dim, indices + test_args = [ + (3, ([1, 2],)), + (3, (slice(0, 3),)), + (3, ([slice(0, 3), 1],)), + (3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)), + (3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)), + (3, ([slice(None), slice(None), [0, 3]],)), + (3, ([slice(None), [0, 3], slice(None)],)), + (3, ([[0, 3], slice(None), slice(None)],)), + (3, ([[0, 3], [1, 2], slice(None)],)), + (3, ([[0, 3], ],)), + (3, ([[0, 3], slice(None)],)), + (3, ([[0, 3], Ellipsis],)), + (3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)), + (4, ([slice(None), adv_idx, adv_idx, slice(None)],)), + (4, ([slice(None), adv_idx, slice(None), adv_idx],)), + (4, ([adv_idx, slice(None), slice(None), adv_idx],)), + (4, ([slice(None), slice(None), adv_idx, adv_idx],)), + (4, ([Ellipsis, adv_idx, adv_idx],)), + (5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)), + (5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)), + (5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)), + (6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)), + (6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)), + (6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)), + ] + + def get_shape(dim): + return tuple(S + i for i in range(dim)) + + return tuple(SampleInput( + make_tensor(get_shape(self_dim), device=device, dtype=dtype, low=None, high=None, requires_grad=requires_grad), + args=args) + for self_dim, args in test_args) + + +# TODO: split PyTorch's __getitem__. The problem is we don't support indexing +# with masks with vmap. +additional_op_db.append( + OpInfo('__getitem__', + variant_test_name='functorch', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + op=torch.Tensor.__getitem__, + assert_jit_shape_analysis=False, # TODO: support index.Tensor() + supports_forward_ad=True, + sample_inputs_func=sample_inputs_getitem,)) + + +# Turns out at::index_put is different from torch.index_put... +# TODO: figure out how to upstream this +def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + inputs = [] + adv_idx = torch.LongTensor([[0, 1], [2, 3]]) + # self_shape, indices + additional = [ + ((5, 6, 7, 8), [None, adv_idx, adv_idx, None]), + ((5, 6, 7, 8), [None, adv_idx, None, adv_idx]), + ((5, 6, 7, 8), [adv_idx, None, None, adv_idx]), + ((5, 6, 7, 8), [None, None, adv_idx, adv_idx]), + ((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]), + ((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]), + ((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]), + ((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]), + ] + for self_shape, indices in additional: + for broadcast_value in [False, True]: + inp = make_arg(self_shape) + + tmp_indices = [slice(None) if idx is None else idx for idx in indices] + values_shape = inp[tmp_indices].shape + if broadcast_value: + values_shape = values_shape[3:] + values = make_arg(values_shape) + inputs.append(SampleInput(inp, args=(tuple(indices), values))) + return inputs + + +def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, dtype=torch.long, device=device, requires_grad=False) + S = 5 + inputs = [] + for accumulate in [False, True]: + # putting vectors at indexed locations + inputs.append(SampleInput( + make_arg((S, S)), + args=((make_idx((2,), low=0, high=4),), make_arg((2, S))), + kwargs=dict(accumulate=accumulate))) + + # putting multi-dim tensors at indexed locations + inputs.append(SampleInput( + make_arg((S, S, 2)), + args=((make_idx((3,), low=0, high=4),), make_arg((3, S, 2))), + kwargs=dict(accumulate=accumulate))) + + # value with size `0` dim + inputs.append(SampleInput( + make_arg((S, 0)), + args=((make_idx((3,), low=0, high=4),), make_arg((3, 0))), + kwargs=dict(accumulate=accumulate))) + + # scalar value + inputs.append(SampleInput( + make_arg((S,)), + args=((make_idx((), low=0, high=S),), make_arg(())), + kwargs=dict(accumulate=accumulate))) + + # cuda and accumulate don't work well + # Reference: https://github.com/pytorch/pytorch/issues/72053 + if not accumulate and device == 'cuda': + # Broadcast `values` + inputs.append(SampleInput( + make_arg((S, S)), + args=((make_idx((2,), low=0, high=S),), make_arg((S,))), + kwargs=dict(accumulate=accumulate))) + + return inputs + + +additional_op_db.append( + OpInfo( + "index_put", + variant_test_name='functorch', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_index_put, + supports_forward_ad=True, + )) +additional_op_db.append( + OpInfo( + "ops.aten.index_put", + variant_test_name='functorch', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_aten_index_put, + supports_forward_ad=True, + )) + +def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs): + S = 3 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10)) + yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10)) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, 10), + broadcasts_input=True) + +additional_op_db.append( + OpInfo('masked_fill', + variant_test_name='functorch_Scalar_only', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_masked_fill, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + supports_out=False) +) + + +def sample_inputs_new_zeros_with_same_feature_meta(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + matrix = [ + # tangent, base, num_tangent_bdims + ([5], [2, 3], 0), + ([2, 3], [2, 3], 0), + ([5], [2], 0), + ([1, 0, 2], [1, 2], 0), + ([], [1, 2], 0), + ([8, 7, 5], [2, 3, 11], 1), + ([6, 7, 5], [2, 3, 4], 2), + ([6, 4], [3], 2), + ] + results = [] + for tangent_shape, base_shape, num_tangent_bdims in matrix: + tangent = make_arg(tangent_shape) + base = make_arg(base_shape) + results.append(SampleInput( + tangent, + args=(base,), + kwargs=dict(self_num_batch_dims=num_tangent_bdims))) + return results + + +additional_op_db.append( + OpInfo( + "ops.aten._new_zeros_with_same_feature_meta", + variant_test_name='functorchonly', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_new_zeros_with_same_feature_meta, + )) + + +def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + + +additional_op_db.extend([ + OpInfo('bfloat16', + op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), + )), + OpInfo('bool', + op=lambda x, *args, **kwargs: x.bool(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('byte', + op=lambda x, *args, **kwargs: x.byte(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('char', + op=lambda x, *args, **kwargs: x.char(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('double', + op=lambda x, *args, **kwargs: x.double(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('float', + op=lambda x, *args, **kwargs: x.float(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('half', + op=lambda x, *args, **kwargs: x.half(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('int', + op=lambda x, *args, **kwargs: x.int(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('long', + op=lambda x, *args, **kwargs: x.long(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('short', + op=lambda x, *args, **kwargs: x.short(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + variant_test_name='functorch_no_channels_last', + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), +]) diff --git a/functorch/test/functorch_lagging_op_db.py b/functorch/test/functorch_lagging_op_db.py new file mode 100644 index 0000000000000..39ed29f7b9da7 --- /dev/null +++ b/functorch/test/functorch_lagging_op_db.py @@ -0,0 +1,574 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.testing._internal.common_methods_invocations import op_db + +# Generated from codegen/gen_functorch_op_db.py via +# python codegen/gen_functorch_lagging_op_db.py > test/functorch_lagging_op_db.py +# +# People add new OpInfos to PyTorch all the time. +# We want them to be able to add OpInfos without breaking our CI. +# To achieve this, we keep our OpInfo library behind that of Pytorch's and +# we periodically update our OpInfo library by regenerating this file +_functorch_lagging_meta = { + ('H', ''), + ('T', ''), + ('__getitem__', ''), + ('__radd__', ''), + ('__rand__', ''), + ('__rdiv__', ''), + ('__rmatmul__', ''), + ('__rmod__', ''), + ('__rmul__', ''), + ('__ror__', ''), + ('__rpow__', ''), + ('__rsub__', ''), + ('__rxor__', ''), + ('_masked.amax', ''), + ('_masked.amin', ''), + ('_masked.log_softmax', ''), + ('_masked.mean', ''), + ('_masked.norm', ''), + ('_masked.normalize', ''), + ('_masked.prod', ''), + ('_masked.softmax', ''), + ('_masked.softmin', ''), + ('_masked.std', ''), + ('_masked.sum', ''), + ('_masked.var', ''), + ('abs', ''), + ('acos', ''), + ('acosh', ''), + ('add', ''), + ('addbmm', ''), + ('addcdiv', ''), + ('addcmul', ''), + ('addmm', ''), + ('addmm', 'decomposed'), + ('addmv', ''), + ('addr', ''), + ('all', ''), + ('allclose', ''), + ('amax', ''), + ('amin', ''), + ('aminmax', ''), + ('angle', ''), + ('any', ''), + ('argmax', ''), + ('argmin', ''), + ('argsort', ''), + ('argwhere', ''), + ('as_strided', ''), + ('asin', ''), + ('asinh', ''), + ('atan', ''), + ('atan2', ''), + ('atanh', ''), + ('atleast_1d', ''), + ('atleast_2d', ''), + ('atleast_3d', ''), + ('baddbmm', ''), + ('bernoulli', ''), + ('bfloat16', ''), + ('bincount', ''), + ('bitwise_and', ''), + ('bitwise_left_shift', ''), + ('bitwise_not', ''), + ('bitwise_or', ''), + ('bitwise_right_shift', ''), + ('bitwise_xor', ''), + ('block_diag', ''), + ('bmm', ''), + ('bool', ''), + ('broadcast_tensors', ''), + ('broadcast_to', ''), + ('bucketize', ''), + ('byte', ''), + ('cartesian_prod', ''), + ('cat', ''), + ('cdist', ''), + ('ceil', ''), + ('char', ''), + ('cholesky', ''), + ('cholesky_inverse', ''), + ('cholesky_solve', ''), + ('chunk', ''), + ('clamp', ''), + ('clamp', 'scalar'), + ('clone', ''), + ('column_stack', ''), + ('combinations', ''), + ('complex', ''), + ('conj', ''), + ('conj_physical', ''), + ('contiguous', ''), + ('copysign', ''), + ('corrcoef', ''), + ('cos', ''), + ('cosh', ''), + ('count_nonzero', ''), + ('cov', ''), + ('cross', ''), + ('cummax', ''), + ('cummin', ''), + ('cumprod', ''), + ('cumsum', ''), + ('cumulative_trapezoid', ''), + ('deg2rad', ''), + ('diag', ''), + ('diag_embed', ''), + ('diagflat', ''), + ('diagonal', ''), + ('diagonal_scatter', ''), + ('diff', ''), + ('digamma', ''), + ('dist', ''), + ('div', 'floor_rounding'), + ('div', 'no_rounding_mode'), + ('div', 'trunc_rounding'), + ('dot', ''), + ('double', ''), + ('dsplit', ''), + ('dstack', ''), + ('eig', ''), + ('einsum', ''), + ('empty_like', ''), + ('eq', ''), + ('erf', ''), + ('erfc', ''), + ('erfinv', ''), + ('exp', ''), + ('exp2', ''), + ('expand', ''), + ('expand_as', ''), + ('expm1', ''), + ('fft.fft', ''), + ('fft.fft2', ''), + ('fft.fftn', ''), + ('fft.fftshift', ''), + ('fft.hfft', ''), + ('fft.hfft2', ''), + ('fft.hfftn', ''), + ('fft.ifft', ''), + ('fft.ifft2', ''), + ('fft.ifftn', ''), + ('fft.ifftshift', ''), + ('fft.ihfft', ''), + ('fft.ihfft2', ''), + ('fft.ihfftn', ''), + ('fft.irfft', ''), + ('fft.irfft2', ''), + ('fft.irfftn', ''), + ('fft.rfft', ''), + ('fft.rfft2', ''), + ('fft.rfftn', ''), + ('fill_', ''), + ('flatten', ''), + ('flip', ''), + ('fliplr', ''), + ('flipud', ''), + ('float', ''), + ('float_power', ''), + ('floor', ''), + ('floor_divide', ''), + ('fmax', ''), + ('fmin', ''), + ('fmod', ''), + ('frac', ''), + ('frexp', ''), + ('full_like', ''), + ('gather', ''), + ('gcd', ''), + ('ge', ''), + ('geqrf', ''), + ('gradient', ''), + ('gt', ''), + ('half', ''), + ('heaviside', ''), + ('histc', ''), + ('histogram', ''), + ('histogramdd', ''), + ('hsplit', ''), + ('hstack', ''), + ('hypot', ''), + ('i0', ''), + ('igamma', ''), + ('igammac', ''), + ('imag', ''), + ('index_add', ''), + ('index_copy', ''), + ('index_fill', ''), + ('index_put', ''), + ('index_select', ''), + ('inner', ''), + ('int', ''), + ('inverse', ''), + ('isclose', ''), + ('isfinite', ''), + ('isin', ''), + ('isinf', ''), + ('isnan', ''), + ('isneginf', ''), + ('isposinf', ''), + ('isreal', ''), + ('istft', ''), + ('kron', ''), + ('kthvalue', ''), + ('lcm', ''), + ('ldexp', ''), + ('le', ''), + ('lerp', ''), + ('lgamma', ''), + ('linalg.cholesky', ''), + ('linalg.cholesky_ex', ''), + ('linalg.cond', ''), + ('linalg.cross', ''), + ('linalg.det', ''), + ('linalg.det', 'singular'), + ('linalg.eig', ''), + ('linalg.eigh', ''), + ('linalg.eigvals', ''), + ('linalg.eigvalsh', ''), + ('linalg.householder_product', ''), + ('linalg.inv', ''), + ('linalg.inv_ex', ''), + ('linalg.lstsq', ''), + ('linalg.lstsq', 'grad_oriented'), + ('linalg.lu_factor', ''), + ('linalg.lu_factor_ex', ''), + ('linalg.matrix_norm', ''), + ('linalg.matrix_power', ''), + ('linalg.matrix_rank', ''), + ('linalg.matrix_rank', 'hermitian'), + ('linalg.multi_dot', ''), + ('linalg.norm', ''), + ('linalg.norm', 'subgradients_at_zero'), + ('linalg.pinv', ''), + ('linalg.pinv', 'hermitian'), + ('linalg.pinv', 'singular'), + ('linalg.qr', ''), + ('linalg.slogdet', ''), + ('linalg.solve', ''), + ('linalg.solve_triangular', ''), + ('linalg.svd', ''), + ('linalg.svdvals', ''), + ('linalg.tensorinv', ''), + ('linalg.tensorsolve', ''), + ('linalg.vector_norm', ''), + ('log', ''), + ('log10', ''), + ('log1p', ''), + ('log2', ''), + ('log_softmax', ''), + ('log_softmax', 'dtype'), + ('logaddexp', ''), + ('logaddexp2', ''), + ('logcumsumexp', ''), + ('logdet', ''), + ('logical_and', ''), + ('logical_not', ''), + ('logical_or', ''), + ('logical_xor', ''), + ('logit', ''), + ('logsumexp', ''), + ('long', ''), + ('lt', ''), + ('lu', ''), + ('lu_solve', ''), + ('lu_unpack', ''), + ('mH', ''), + ('mT', ''), + ('masked_fill', ''), + ('masked_scatter', ''), + ('masked_select', ''), + ('matmul', ''), + ('matrix_exp', ''), + ('max', 'binary'), + ('max', 'reduction_no_dim'), + ('max', 'reduction_with_dim'), + ('maximum', ''), + ('mean', ''), + ('median', ''), + ('meshgrid', 'list_of_tensors'), + ('meshgrid', 'variadic_tensors'), + ('min', 'binary'), + ('min', 'reduction_no_dim'), + ('min', 'reduction_with_dim'), + ('minimum', ''), + ('mm', ''), + ('mode', ''), + ('movedim', ''), + ('msort', ''), + ('mul', ''), + ('multinomial', ''), + ('mv', ''), + ('mvlgamma', 'mvlgamma_p_1'), + ('mvlgamma', 'mvlgamma_p_3'), + ('mvlgamma', 'mvlgamma_p_5'), + ('nan_to_num', ''), + ('nanmean', ''), + ('nanmedian', ''), + ('nanquantile', ''), + ('nansum', ''), + ('narrow', ''), + ('ne', ''), + ('neg', ''), + ('new_empty', ''), + ('new_full', ''), + ('new_ones', ''), + ('new_zeros', ''), + ('nextafter', ''), + ('nn.functional.adaptive_avg_pool1d', ''), + ('nn.functional.adaptive_avg_pool2d', ''), + ('nn.functional.adaptive_avg_pool3d', ''), + ('nn.functional.adaptive_max_pool1d', ''), + ('nn.functional.adaptive_max_pool2d', ''), + ('nn.functional.adaptive_max_pool3d', ''), + ('nn.functional.avg_pool1d', ''), + ('nn.functional.avg_pool2d', ''), + ('nn.functional.avg_pool3d', ''), + ('nn.functional.batch_norm', ''), + ('nn.functional.batch_norm', 'without_cudnn'), + ('nn.functional.bilinear', ''), + ('nn.functional.binary_cross_entropy', ''), + ('nn.functional.binary_cross_entropy_with_logits', ''), + ('nn.functional.celu', ''), + ('nn.functional.conv1d', ''), + ('nn.functional.conv2d', ''), + ('nn.functional.conv_transpose1d', ''), + ('nn.functional.conv_transpose2d', ''), + ('nn.functional.conv_transpose3d', ''), + ('nn.functional.cosine_embedding_loss', ''), + ('nn.functional.cosine_similarity', ''), + ('nn.functional.cross_entropy', ''), + ('nn.functional.ctc_loss', ''), + ('nn.functional.dropout', ''), + ('nn.functional.dropout2d', ''), + ('nn.functional.elu', ''), + ('nn.functional.embedding', ''), + ('nn.functional.embedding_bag', ''), + ('nn.functional.feature_alpha_dropout', 'with_train'), + ('nn.functional.feature_alpha_dropout', 'without_train'), + ('nn.functional.fractional_max_pool2d', ''), + ('nn.functional.fractional_max_pool3d', ''), + ('nn.functional.gaussian_nll_loss', ''), + ('nn.functional.gelu', ''), + ('nn.functional.glu', ''), + ('nn.functional.grid_sample', ''), + ('nn.functional.group_norm', ''), + ('nn.functional.hardshrink', ''), + ('nn.functional.hardsigmoid', ''), + ('nn.functional.hardswish', ''), + ('nn.functional.hardtanh', ''), + ('nn.functional.hinge_embedding_loss', ''), + ('nn.functional.huber_loss', ''), + ('nn.functional.instance_norm', ''), + ('nn.functional.interpolate', 'area'), + ('nn.functional.interpolate', 'bicubic'), + ('nn.functional.interpolate', 'bilinear'), + ('nn.functional.interpolate', 'linear'), + ('nn.functional.interpolate', 'nearest'), + ('nn.functional.interpolate', 'trilinear'), + ('nn.functional.kl_div', ''), + ('nn.functional.l1_loss', ''), + ('nn.functional.layer_norm', ''), + ('nn.functional.leaky_relu', ''), + ('nn.functional.linear', ''), + ('nn.functional.local_response_norm', ''), + ('nn.functional.logsigmoid', ''), + ('nn.functional.margin_ranking_loss', ''), + ('nn.functional.max_pool1d', ''), + ('nn.functional.max_pool2d', ''), + ('nn.functional.max_pool3d', ''), + ('nn.functional.max_unpool1d', ''), + ('nn.functional.max_unpool1d', 'grad'), + ('nn.functional.max_unpool2d', ''), + ('nn.functional.max_unpool2d', 'grad'), + ('nn.functional.max_unpool3d', ''), + ('nn.functional.max_unpool3d', 'grad'), + ('nn.functional.mish', ''), + ('nn.functional.mse_loss', ''), + ('nn.functional.multi_margin_loss', ''), + ('nn.functional.multilabel_margin_loss', ''), + ('nn.functional.multilabel_soft_margin_loss', ''), + ('nn.functional.nll_loss', ''), + ('nn.functional.normalize', ''), + ('nn.functional.one_hot', ''), + ('nn.functional.pad', 'circular'), + ('nn.functional.pad', 'constant'), + ('nn.functional.pad', 'reflect'), + ('nn.functional.pad', 'replicate'), + ('nn.functional.pairwise_distance', ''), + ('nn.functional.pdist', ''), + ('nn.functional.pixel_shuffle', ''), + ('nn.functional.pixel_unshuffle', ''), + ('nn.functional.poisson_nll_loss', ''), + ('nn.functional.prelu', ''), + ('nn.functional.relu', ''), + ('nn.functional.relu6', ''), + ('nn.functional.rrelu', ''), + ('nn.functional.selu', ''), + ('nn.functional.silu', ''), + ('nn.functional.silu', 'complex'), + ('nn.functional.smooth_l1_loss', ''), + ('nn.functional.soft_margin_loss', ''), + ('nn.functional.softmin', ''), + ('nn.functional.softmin', 'with_dtype'), + ('nn.functional.softplus', ''), + ('nn.functional.softshrink', ''), + ('nn.functional.softsign', ''), + ('nn.functional.tanhshrink', ''), + ('nn.functional.threshold', ''), + ('nn.functional.triplet_margin_loss', ''), + ('nn.functional.triplet_margin_with_distance_loss', ''), + ('nn.functional.unfold', ''), + ('nn.functional.upsample_bilinear', ''), + ('nn.functional.upsample_nearest', ''), + ('nonzero', ''), + ('norm', ''), + ('norm', 'fro'), + ('norm', 'inf'), + ('norm', 'nuc'), + ('normal', ''), + ('normal', 'number_mean'), + ('ones_like', ''), + ('ormqr', ''), + ('outer', ''), + ('pca_lowrank', ''), + ('permute', ''), + ('pinverse', ''), + ('polar', ''), + ('polygamma', 'polygamma_n_0'), + ('polygamma', 'polygamma_n_1'), + ('polygamma', 'polygamma_n_2'), + ('polygamma', 'polygamma_n_3'), + ('polygamma', 'polygamma_n_4'), + ('positive', ''), + ('pow', ''), + ('prod', ''), + ('put', ''), + ('qr', ''), + ('quantile', ''), + ('rad2deg', ''), + ('rand_like', ''), + ('randint_like', ''), + ('randn_like', ''), + ('ravel', ''), + ('real', ''), + ('reciprocal', ''), + ('remainder', ''), + ('renorm', ''), + ('repeat', ''), + ('repeat_interleave', ''), + ('reshape', ''), + ('reshape_as', ''), + ('resize_', ''), + ('resize_as_', ''), + ('resolve_conj', ''), + ('resolve_neg', ''), + ('roll', ''), + ('rot90', ''), + ('round', ''), + ('round', 'decimals_0'), + ('round', 'decimals_3'), + ('round', 'decimals_neg_3'), + ('rsqrt', ''), + ('rsub', ''), + ('scatter', ''), + ('scatter_add', ''), + ('scatter_reduce', 'amax'), + ('scatter_reduce', 'amin'), + ('scatter_reduce', 'mean'), + ('scatter_reduce', 'prod'), + ('scatter_reduce', 'sum'), + ('searchsorted', ''), + ('select', ''), + ('select_scatter', ''), + ('sgn', ''), + ('short', ''), + ('sigmoid', ''), + ('sign', ''), + ('signbit', ''), + ('sin', ''), + ('sinc', ''), + ('sinh', ''), + ('slice_scatter', ''), + ('softmax', ''), + ('softmax', 'with_dtype'), + ('solve', ''), + ('sort', ''), + ('special.entr', ''), + ('special.erfcx', ''), + ('special.i0e', ''), + ('special.i1', ''), + ('special.i1e', ''), + ('special.log_ndtr', ''), + ('special.ndtr', ''), + ('special.ndtri', ''), + ('special.polygamma', 'special_polygamma_n_0'), + ('special.xlog1py', ''), + ('special.zeta', ''), + ('split', ''), + ('split', 'list_args'), + ('split_with_sizes', ''), + ('sqrt', ''), + ('square', ''), + ('squeeze', ''), + ('stack', ''), + ('std', ''), + ('std_mean', ''), + ('stft', ''), + ('sub', ''), + ('sum', ''), + ('sum_to_size', ''), + ('svd', ''), + ('svd_lowrank', ''), + ('symeig', ''), + ('t', ''), + ('take', ''), + ('take_along_dim', ''), + ('tan', ''), + ('tanh', ''), + ('tensor_split', ''), + ('tensordot', ''), + ('tile', ''), + ('to_sparse', ''), + ('topk', ''), + ('trace', ''), + ('transpose', ''), + ('trapezoid', ''), + ('trapz', ''), + ('triangular_solve', ''), + ('tril', ''), + ('triu', ''), + ('true_divide', ''), + ('trunc', ''), + ('unfold', ''), + ('unique', ''), + ('unique_consecutive', ''), + ('unsqueeze', ''), + ('var', ''), + ('var_mean', ''), + ('vdot', ''), + ('view', ''), + ('view_as', ''), + ('view_as_complex', ''), + ('view_as_real', ''), + ('vsplit', ''), + ('vstack', ''), + ('where', ''), + ('xlogy', ''), + ('zero_', ''), + ('zeros_like', ''), +} + + +def in_functorch_lagging_op_db(opinfo): + return (opinfo.name, opinfo.variant_test_name) in _functorch_lagging_meta + + +functorch_lagging_op_db = [ + opinfo for opinfo in op_db if in_functorch_lagging_op_db(opinfo) +] diff --git a/functorch/test/pytest.ini b/functorch/test/pytest.ini new file mode 100644 index 0000000000000..ff3ba09162ecc --- /dev/null +++ b/functorch/test/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts=-s -v diff --git a/functorch/test/test_compile_cache.py b/functorch/test/test_compile_cache.py new file mode 100644 index 0000000000000..2115e58845f3a --- /dev/null +++ b/functorch/test/test_compile_cache.py @@ -0,0 +1,686 @@ +# Owner(s): ["module: functorch"] + +import torch + +import functorch +from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS +import unittest + +from functorch.compile import aot_function, nop + + +class TestCompileCache(TestCase): + def check(self, a, b, aot_fn, fn): + a_clone = a.clone().detach().requires_grad_(True) + b_clone = b.clone().detach().requires_grad_(True) + ref = fn(a, b) + ref.sum().backward() + + res = aot_fn(a_clone, b_clone) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(a.grad, a_clone.grad) + assert torch.allclose(b.grad, b_clone.grad) + + def test_recompilation_on_broadcast(self): + def fn(x, bias): + return x + bias + + for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]: + functorch.compile.clear_compile_cache() + start_num_recomps = functorch.compile.num_of_recompilations() + aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type) + + a = torch.randn(10, 20, requires_grad=True) + b = torch.randn(20, requires_grad=True) + self.check(a, b, aot_autograd_fn, fn) + + a = torch.randn(10, 20, requires_grad=True) + b = torch.randn(10, 20, requires_grad=True) + self.check(a, b, aot_autograd_fn, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_compilation_for_dynamic_shape(self): + def fn(x, bias): + return x + bias + + for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]: + functorch.compile.clear_compile_cache() + start_num_recomps = functorch.compile.num_of_recompilations() + aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type) + + for s in range(10, 20): + a = torch.randn(s, requires_grad=True) + b = torch.randn(s, requires_grad=True) + self.check(a, b, aot_autograd_fn, fn) + + for s in range(10, 20): + a = torch.randn(s, requires_grad=True) + b = torch.randn(s, requires_grad=True) + self.check(a, b, aot_autograd_fn, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + if hasher_type == "DynamicShapeHasher": + assert total_recomps == 1 + elif hasher_type == "StaticShapeHasher": + assert total_recomps == 10 + + for s in range(10, 20): + a = torch.randn(s, s, requires_grad=True) + b = torch.randn(s, s, requires_grad=True) + self.check(a, b, aot_autograd_fn, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + if hasher_type == "DynamicShapeHasher": + assert total_recomps == 2 + elif hasher_type == "StaticShapeHasher": + assert total_recomps == 20 + + def test_global_cache_no_recompilations(self): + def f(x, bias): + return x + bias + + def g(x, bias): + return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias) + + start_num_recomps = functorch.compile.num_of_recompilations() + for _ in range(10): + a = torch.randn(10, 20, requires_grad=True) + b = torch.randn(10, 20, requires_grad=True) + self.check(a, b, g, f) + + end_num_recomps = functorch.compile.num_of_recompilations() + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 1 + + def test_multiple_functions(self): + def f(x, bias): + return x + bias + + def g(x, y): + return x * y + + for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]: + functorch.compile.clear_compile_cache() + aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type) + aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type) + + start_num_recomps = functorch.compile.num_of_recompilations() + a = torch.randn(10, requires_grad=True) + b = torch.randn(10, requires_grad=True) + self.check(a, b, aot_autograd_f, f) + + a = torch.randn(10, requires_grad=True) + b = torch.randn(10, requires_grad=True) + self.check(a, b, aot_autograd_g, g) + + end_num_recomps = functorch.compile.num_of_recompilations() + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + # Force recompilation for function f and check num of recompilations again + a = torch.randn(10, 20, requires_grad=True) + b = torch.randn(10, 20, requires_grad=True) + self.check(a, b, aot_autograd_f, f) + + end_num_recomps = functorch.compile.num_of_recompilations() + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 3 + + def test_high_number_of_args(self): + def f(*args): + res = args[0] + for arg in args: + res = res * arg + return res + + def check(args, aot_autograd_fn, fn): + args_clone = [arg.clone().detach().requires_grad_(True) for arg in args] + ref = fn(*args) + ref.sum().backward() + + res = aot_autograd_fn(*args_clone) + res.sum().backward() + assert torch.allclose(res, ref) + for (arg, arg_clone) in zip(args, args_clone): + assert torch.allclose(arg.grad, arg_clone.grad) + + for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]: + functorch.compile.clear_compile_cache() + + aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type) + + args = [torch.randn(10, requires_grad=True) for _ in range(100)] + check(args, aot_autograd_f, f) + + def test_multiple_compiler(self): + def fn(x, bias): + return x + bias + + def nop_duplicate(fx_g, _): + return fx_g + + for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]: + functorch.compile.clear_compile_cache() + start_num_recomps = functorch.compile.num_of_recompilations() + nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type) + nop_duplicate_fn = aot_function( + fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type + ) + + a = torch.randn(10, 20, requires_grad=True) + b = torch.randn(20, requires_grad=True) + nop_fn(a, b) + nop_duplicate_fn(a, b) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + +@unittest.skipIf(IS_WINDOWS, 'test broken on windows') +class TestCompileCacheStaticArgs(TestCase): + def check(self, a, b, aot_autograd_fn, fn): + a_clone = a.clone().detach().requires_grad_(True) + ref = fn(a, b) + ref.sum().backward() + + res = aot_autograd_fn(a_clone, b) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(a.grad, a_clone.grad) + + def test_failure(self): + # Test that not setting up static_argnums should raise exception + def fn(x, p): + return x * p + + aot_autograd_f = aot_function(fn, nop, nop) + + a = torch.randn(2, 2, requires_grad=True) + b = 2 + try: + # Since b is not marked as static, it should raise exception + aot_autograd_f(a, b) + raise AssertionError() + except RuntimeError: + pass + + def test_simple(self): + def fn(x, static_arg): + return x * static_arg + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1) + + a = torch.randn(2, 2, requires_grad=True) + b = 2 + self.check(a, b, aot_autograd_f, fn) + + # Same type of args, so no recompilation + a = torch.randn(2, 2, requires_grad=True) + b = 2 + self.check(a, b, aot_autograd_f, fn) + + # Trigger recompilation + a = torch.randn(2, 2, requires_grad=True) + b = 3 + self.check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_static_arg_before_tensor_arg(self): + def fn(static_arg, x): + return static_arg - x + + def check(a, b, aot_autograd_fn, fn): + b_clone = b.clone().detach().requires_grad_(True) + ref = fn(a, b) + ref.sum().backward() + + res = aot_autograd_fn(a, b_clone) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(b.grad, b_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0) + + a = 2 + b = torch.randn(2, 2, requires_grad=True) + check(a, b, aot_autograd_f, fn) + + a = 3 + b = torch.randn(2, 2, requires_grad=True) + check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_interleaved_static_args(self): + def fn(static_arg1, x, static_arg2): + return static_arg1 - x - static_arg2 + + def check(a, b, c, aot_autograd_fn, fn): + b_clone = b.clone().detach().requires_grad_(True) + ref = fn(a, b, c) + ref.sum().backward() + + res = aot_autograd_fn(a, b_clone, c) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(b.grad, b_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2)) + + a = 2 + b = torch.randn(2, 2, requires_grad=True) + c = 0.1 + check(a, b, c, aot_autograd_f, fn) + + a = 3 + b = torch.randn(2, 2, requires_grad=True) + c = 0.1 + check(a, b, c, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_dropout(self): + def fn(x, prob): + return torch.nn.functional.dropout(x, p=prob) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1]) + + a = torch.randn(2, 2, requires_grad=True) + b = 0.3 + aot_autograd_f(a, b) + + # Setting the prob to 0. This should cause recompilation. + a = torch.randn(2, 2, requires_grad=True) + b = 0 + self.check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_if_condition(self): + def fn(x, state: bool): + if state: + return torch.sin(x) + else: + return torch.cos(x) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1]) + + a = torch.randn(2, 2, requires_grad=True) + b = True + self.check(a, b, aot_autograd_f, fn) + + a = torch.randn(2, 2, requires_grad=True) + b = True + self.check(a, b, aot_autograd_f, fn) + + a = torch.randn(2, 2, requires_grad=True) + b = False + self.check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_custom(self): + class Record: + def __init__(self, name, multiplier): + self.name = name + self.multiplier = multiplier + + def __eq__(self, other): + return self.name == other.name and self.multiplier == other.multiplier + + def __hash__(self): + return hash((self.name, self.multiplier)) + + def fn(x, record): + return x * record.multiplier + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1]) + + a = torch.randn(2, 2, requires_grad=True) + b = Record("Foo", 0.5) + self.check(a, b, aot_autograd_f, fn) + + a = torch.randn(2, 2, requires_grad=True) + b = Record("Bar", 10.2) + self.check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_tuple(self): + def fn(a_tuple, static_arg): + return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg + + def check(a_tuple, b, aot_autograd_fn, fn): + a0 = a_tuple[0] + a1 = a_tuple[1] + + a0_clone = a0.clone().detach().requires_grad_(True) + a1_clone = a1.clone().detach().requires_grad_(True) + ref = fn(a, b) + ref.sum().backward() + + res = aot_autograd_fn((a0_clone, a1_clone), b) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(a0.grad, a0_clone.grad) + assert torch.allclose(a1.grad, a1_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,)) + + a = ( + torch.randn(2, 2, requires_grad=True), + torch.randn(2, 2, requires_grad=True), + ) + b = 0.1 + check(a, b, aot_autograd_f, fn) + + a = ( + torch.randn(2, 2, requires_grad=True), + torch.randn(2, 2, requires_grad=True), + ) + b = 1 + check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_tuple_with_first_arg_as_static(self): + def fn(static_arg, a_tuple): + return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg + + def check(a, b_tuple, aot_autograd_fn, fn): + b0 = b_tuple[0] + b1 = b_tuple[1] + + b0_clone = b0.clone().detach().requires_grad_(True) + b1_clone = b1.clone().detach().requires_grad_(True) + ref = fn(a, b_tuple) + ref.sum().backward() + + res = aot_autograd_fn(a, (b0_clone, b1_clone)) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(b0.grad, b0_clone.grad) + assert torch.allclose(b1.grad, b1_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,)) + + a = 0.1 + b = ( + torch.randn(2, 2, requires_grad=True), + torch.randn(2, 2, requires_grad=True), + ) + check(a, b, aot_autograd_f, fn) + + a = 1 + b = ( + torch.randn(2, 2, requires_grad=True), + torch.randn(2, 2, requires_grad=True), + ) + check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_dict(self): + def fn(a_dict, static_arg): + return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg + + def check(a_dict, b, aot_autograd_fn, fn): + + a0 = a_dict["foo"] + a1 = a_dict["bar"] + + a0_clone = a0.clone().detach().requires_grad_(True) + a1_clone = a1.clone().detach().requires_grad_(True) + ref = fn(a_dict, b) + ref.sum().backward() + + a_clone = {} + a_clone["foo"] = a0_clone + a_clone["bar"] = a1_clone + res = aot_autograd_fn(a_clone, b) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(a0.grad, a0_clone.grad) + assert torch.allclose(a1.grad, a1_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,)) + + a = {} + a["foo"] = torch.zeros(2, 2, requires_grad=True) + a["bar"] = torch.ones(2, 2, requires_grad=True) + b = 0 + check(a, b, aot_autograd_f, fn) + + a = {} + a["foo"] = torch.randn(2, 2, requires_grad=True) + a["bar"] = torch.randn(2, 2, requires_grad=True) + b = 0.2 + check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_dict_with_static_arg_before_dict(self): + def fn(static_arg, a_dict): + return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg + + def check(a, b_dict, aot_autograd_fn, fn): + + ref = fn(a, b_dict) + res = aot_autograd_fn(a, b_dict) + assert torch.allclose(res, ref) + + b0 = b_dict["foo"] + b1 = b_dict["bar"] + + b0_clone = b0.clone().detach().requires_grad_(True) + b1_clone = b1.clone().detach().requires_grad_(True) + ref.sum().backward() + + b_clone = {} + b_clone["foo"] = b0_clone + b_clone["bar"] = b1_clone + res = aot_autograd_fn(a, b_clone) + res.sum().backward() + assert torch.allclose(res, ref) + assert torch.allclose(b0.grad, b0_clone.grad) + assert torch.allclose(b1.grad, b1_clone.grad) + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,)) + + a = 0.1 + b = {} + b["foo"] = torch.randn(2, 2, requires_grad=True) + b["bar"] = torch.randn(2, 2, requires_grad=True) + check(a, b, aot_autograd_f, fn) + + a = 0.2 + b = {} + b["foo"] = torch.randn(2, 2, requires_grad=True) + b["bar"] = torch.randn(2, 2, requires_grad=True) + check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_tuple_static_args(self): + def fn(x, tuple_static_arg): + return x * tuple_static_arg[0] * tuple_static_arg[1] + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1) + + a = torch.randn(2, 2, requires_grad=True) + b = (2, 3) + self.check(a, b, aot_autograd_f, fn) + + # Same type of args, so no recompilation + a = torch.randn(2, 2, requires_grad=True) + b = (2, 3) + self.check(a, b, aot_autograd_f, fn) + + # Trigger recompilation + a = torch.randn(2, 2, requires_grad=True) + b = (3, 4) + self.check(a, b, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 2 + + def test_arg_none(self): + def check(a, b, c, aot_autograd_fn, fn): + def cloner(x): + if x is not None: + return x.clone().detach().requires_grad_(True) + return None + + def check_grad(x, x_clone): + if x is not None: + return torch.allclose(x.grad, x_clone.grad) + return True + + ref = fn(a, b, c) + res = aot_autograd_fn(a, b, c) + assert torch.allclose(res, ref) + + a_clone = cloner(a) + b_clone = cloner(b) + c_clone = cloner(c) + ref.sum().backward() + res = aot_autograd_fn(a_clone, b_clone, c_clone) + res.sum().backward() + + check_grad(a, a_clone) + check_grad(b, b_clone) + check_grad(c, c_clone) + + def fn(a, b, c): + if a is None and b is None: + return c + elif a is None and c is None: + return b + elif b is None and c is None: + return a + elif a is None: + return b + c + elif b is None: + return a + c + elif c is None: + return a + b + return a + b + c + + functorch.compile.clear_compile_cache() + + start_num_recomps = functorch.compile.num_of_recompilations() + + aot_autograd_f = aot_function(fn, nop, nop) + + t1 = torch.randn(2, 2, requires_grad=True) + check(t1, None, None, aot_autograd_f, fn) + check(None, t1, None, aot_autograd_f, fn) + check(None, None, t1, aot_autograd_f, fn) + + t2 = torch.randn(2, 2, requires_grad=True) + check(t1, t2, None, aot_autograd_f, fn) + check(t1, None, t2, aot_autograd_f, fn) + check(None, t1, t2, aot_autograd_f, fn) + + t3 = torch.randn(2, 2, requires_grad=True) + check(t1, t2, t3, aot_autograd_f, fn) + + # Same type of args, so no recompilation + check(t1, t2, None, aot_autograd_f, fn) + + end_num_recomps = functorch.compile.num_of_recompilations() + + total_recomps = end_num_recomps - start_num_recomps + assert total_recomps == 7 + + +if __name__ == "__main__": + run_tests() diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py new file mode 100644 index 0000000000000..1b372cedefc05 --- /dev/null +++ b/functorch/test/test_eager_transforms.py @@ -0,0 +1,3267 @@ +# Owner(s): ["module: functorch"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from torch.testing._internal.common_utils import ( + TestCase, run_tests, parametrize, subtest, instantiate_parametrized_tests +) +import torch +import torch.nn as nn +import torch.nn.functional as F +import unittest +import warnings +import math +from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU +from torch.testing._internal.common_dtype import get_all_fp_dtypes +from torch.testing._internal.common_utils import IS_WINDOWS +from functools import partial +from functorch.experimental import replace_all_batch_norm_modules_ + +import functorch +from functorch import ( + grad, vjp, vmap, jacrev, jacfwd, grad_and_value, hessian, + jvp, make_functional, make_functional_with_buffers, + combine_state_for_ensemble, make_fx +) +from functorch._src.make_functional import ( + functional_init, functional_init_with_buffers, +) +from functorch._src.eager_transforms import _argnums_partial, enable_fwd_grad +from functorch.experimental import functionalize + +if not IS_WINDOWS: + from functorch._src.custom_function import custom_vjp + +# NB: numpy is a testing dependency! +import numpy as np + +USE_TORCHVISION = False +try: + import torchvision # noqa: F401 + USE_TORCHVISION = True +except ImportError: + warnings.warn("Couldn't import torchvision. Some of our tests use it, try " + "to install it with commands from pytorch.org, post-fixed with " + "`--no-deps` to avoid overwriting the pytorch installation", + UserWarning) + +# TestCase for _argnums_partial, an important helper funciton + + +class TestArgnumsPartial(TestCase): + def test_invalid_argnum_type(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "int or Tuple"): + _argnums_partial(torch.sin, args, 0.0) + with self.assertRaisesRegex(RuntimeError, "int or Tuple"): + _argnums_partial(torch.sin, args, [0]) + with self.assertRaisesRegex(RuntimeError, "must be int"): + _argnums_partial(torch.sin, args, (0.0,)) + + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + def f(a, b, c, d, e): + return a + with self.assertRaisesRegex(RuntimeError, "must be int"): + _argnums_partial(torch.sin, args, ((0, 1), 2)) + + def test_out_of_bounds_argnum_values(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _argnums_partial(torch.sin, args, 1) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _argnums_partial(torch.sin, args, -2) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _argnums_partial(torch.sin, args, (-2,)) + + def test_not_enough_argnums(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "must be non-empty"): + _argnums_partial(torch.sin, args, ()) + + def test_duplicate_argnums(self): + x = torch.randn(3) + args = (x, x) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + _argnums_partial(torch.add, args, (0, 0)) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + _argnums_partial(torch.add, args, (0, -2)) + + def test_flat_args_with_positive_int_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + def f(a, b, c, d, e): + return a + + f_new, res = _argnums_partial(f, args, 0) + self.assertEqual(res, (0.1,)) + self.assertEqual(f_new(*res), 0.1) + + f_new, res = _argnums_partial(f, args, 4) + self.assertEqual(res, (4.1,)) + self.assertEqual(f_new(*res), 0.1) + + def test_flat_args_with_negative_int_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + def f(a, b, c, d, e): + return a + + expected = f(*args) + f_new, res = _argnums_partial(f, args, -1) + self.assertEqual(res, (4.1,)) + self.assertEqual(f_new(*res), expected) + + f_new, res = _argnums_partial(f, args, -5) + self.assertEqual(res, (0.1,)) + self.assertEqual(f_new(*res), expected) + + def test_flat_args_with_tuple_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + def f(a, b, c, d, e): + return a + + f_new, res = _argnums_partial(f, args, (0, 1, 2, 3, 4)) + self.assertEqual(f_new(*res), 0.1) + self.assertEqual(res, args) + + f_new, res = _argnums_partial(f, args, (0, -3)) + self.assertEqual(f_new(*res), 0.1) + self.assertEqual(res, (0.1, 2.1)) + + def test_pytree_args(self): + args = ((0.1, 1.1), 2.0, [3.1]) + + def f(a, b, c): + return a[0] + a[1] + b + c[0] + + expected = f(*args) + + f_new, res = _argnums_partial(f, args, 0) + self.assertEqual(res, args[0:1]) + self.assertEqual(f_new(*res), expected) + + f_new, res = _argnums_partial(f, args, (0,)) + self.assertEqual(res, args[0:1]) + self.assertEqual(f_new(*res), expected) + + f_new, res = _argnums_partial(f, args, -1) + self.assertEqual(res, args[-1:]) + self.assertEqual(f_new(*res), expected) + + f_new, res = _argnums_partial(f, args, (0, -2)) + self.assertEqual(res, args[0:2]) + self.assertEqual(f_new(*res), expected) + + def test_argnums_reorders(self): + args = ((0.1, 1.1, 2.1), 3.1, 4.1) + + def f(a, b, c): + return a[0] + a[1] + a[2] + b + c + + expected = f(*args) + f_new, res = _argnums_partial(f, args, (1, 0)) + self.assertEqual(res, (args[1], args[0])) + self.assertEqual(f_new(*res), expected) + + def test_function_with_default_args(self): + args = ((0.1, 1.1, 2.1), 3.1) + + def f(a, b, c=4.1): + return a[0] + a[1] + a[2] + b + c + + expected = f(*args) + f_new, res = _argnums_partial(f, args, -2) + self.assertEqual(res, args[0:1]) + self.assertEqual(f_new(*res), expected) + + args = ((0.1, 1.1, 2.1), 3.1, 5.1) + expected = f(*args) + f_new, res = _argnums_partial(f, args, -1) + self.assertEqual(res, args[-1:]) + self.assertEqual(f_new(*res), expected) + + +class TestGradTransform(TestCase): + def test_primitive(self, device): + x = torch.randn([], device=device) + result = grad(torch.sin)(x) + self.assertEqual(result, torch.cos(x)) + + def test_composite_simple(self, device): + x = torch.randn(2, 3, 4, device=device) + result = grad(lambda x: torch.flatten(x).sum())(x) + self.assertEqual(result, torch.ones_like(x)) + + def test_fn_with_kwargs(self, device): + def foo(x, y): + return (x * y).sum() + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + expected = grad(foo)(x, y) + result = grad(foo)(x, y=y) + self.assertEqual(result, expected) + + def test_composite_complicated(self, device): + x = torch.randn(3, device=device) + y = torch.randn(3, 5, device=device) + + def foo(x, y): + result = x @ y + return result.sum() + + result = grad(foo)(x, y) + + x.requires_grad_() + out = foo(x, y) + expected, = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_composite_two_ops(self, device): + N, C = 2, 5 + y = torch.randn(N, C, device=device) + targets = torch.randint(0, C, (N,), device=device) + + def foo(y, targets): + return F.cross_entropy(y, targets) + + result = grad(foo)(y, targets) + + y.requires_grad_() + expected, = torch.autograd.grad(foo(y, targets), y) + + self.assertEqual(result, expected) + + def _test_attributes(self, get_attr_lambda, device): + x = torch.randn(2, 3, 5, dtype=torch.double, device=device) + expected = get_attr_lambda(x) + + def foo(x): + self.assertEqual(get_attr_lambda(x), expected) + return x.sum() + + grad(foo)(x) + + def test_shape(self, device): + self._test_attributes(lambda x: x.shape, device) + + def test_dtype(self, device): + self._test_attributes(lambda x: x.dtype, device) + + def test_is_cuda(self, device): + self._test_attributes(lambda x: x.is_cuda, device) + + def test_numel(self, device): + self._test_attributes(lambda x: x.numel(), device) + + def test_inplace(self, device): + x = torch.randn([], device=device) + + def foo(x): + return x.clone().sin_() + + result = grad(foo)(x) + self.assertEqual(result, x.cos()) + + def test_inplace_on_view(self, device): + x = torch.randn(3, device=device) + + def foo(x): + y = x.clone() + y0 = y[0] + y0.sin_() + return y.sum() + + result = grad(foo)(x) + + x.requires_grad_() + out = foo(x) + expected, = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_inplace_on_view_base(self, device): + x = torch.randn(3, device=device) + + def foo(x): + y = x.clone() + y0 = y[0] + y.sin_() + return y0 + + result = grad(foo)(x) + + x.requires_grad_() + out = foo(x) + expected, = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_inplace_on_captures(self, device): + x = torch.tensor([1., 2., 3.], device=device) + captured = torch.randn(3, device=device) + + def foo(x): + captured.copy_(x) + return (x * captured).sum() + + with self.assertRaisesRegex(RuntimeError, 'mutate a captured Tensor'): + grad(foo)(x) + + def test_nesting_simple(self, device): + x = torch.randn([], device=device) + result = grad(grad(torch.sin))(x) + self.assertEqual(result, -torch.sin(x)) + + def test_escaped_wrappers_are_marked_as_dead(self, device): + x = torch.randn([], device=device) + escaped = [] + + def foo(x): + y = x.sin() + escaped.append(y) + return y + + grad(foo)(x) + self.assertEqual(functorch._C.dlevel(escaped[0]), -1) + + def test_escaped_wrappers_are_ignored(self, device): + x = torch.randn([], device=device) + escaped = [] + + def foo(x): + y = x.sin() + escaped.append(y) + return y + + grad(foo)(x) + + something = escaped[0].sum() + self.assertEqual(functorch._C.dlevel(something), 0) + self.assertEqual(something, x.sin().sum()) + + def test_vjp(self, device): + x = torch.randn([], device=device) + out, vjp_fn = vjp(torch.sin, x) + self.assertEqual(out, x.sin()) + + v = torch.randn([], device=device) + result, = vjp_fn(v) + self.assertEqual(result, v * x.cos()) + + def test_vjp_two_outputs(self, device): + def f(x): + return x, x + result, vjp_fn = vjp(f, torch.tensor(1.)) + vjp_fn(result) + + def test_conj_bit(self): + x = torch.tensor(1 + 1j) + + def foo(x): + assert not x.is_conj() + y = x.conj() + assert y.is_conj() + return y + res = grad(foo)(x) + self.assertEqual(res, torch.ones_like(res)) + + def test_composed_with_autograd(self, device): + x = torch.randn([], requires_grad=True, device=device) + + y = grad(torch.sin)(x) + result, = torch.autograd.grad(y, x) + self.assertEqual(result, -x.sin()) + + def test_grad_of_vjp_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + out, vjp_fn = vjp(torch.sin, x) + return grad(lambda y: vjp_fn(y)[0])(y) + + result = foo(x, y) + expected = x.cos() + self.assertEqual(result, expected) + + def test_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + out, vjp_fn = vjp(grad(torch.sin), x) + return vjp_fn(y)[0] + + result = foo(x, y) + expected = -y * x.sin() + self.assertEqual(result, expected) + + def test_grad_of_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x) + return grad(lambda y: vjp_fn(y)[0])(y) + + result = foo(x, y) + expected = x.cos() + self.assertEqual(result, expected) + + def test_views(self, device): + x = torch.randn([], requires_grad=True, device=device) + y = torch.randn([], requires_grad=True, device=device) + + def silly_sin(x): + x = x.view([]) + x = x.sin() + return x + + def foo(x, y): + z1 = grad(silly_sin)(x) + z2 = torch.cos(y) + return z1 + z2 + + result = foo(x, y) + grads = torch.autograd.grad(result, [x, y]) + self.assertEqual(grads[0], -x.sin()) + self.assertEqual(grads[1], -y.sin()) + + def test_view_inplace_simple(self, device): + def foo(x): + x = x.clone() + x.view([]).sin_() + return x + + x = torch.randn([], requires_grad=True, device=device) + result = grad(foo)(x) + self.assertEqual(result, x.cos()) + + def test_invalid_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, 'but only'): + grad(torch.mul, argnums=-3)(x, y) + with self.assertRaisesRegex(RuntimeError, 'but only'): + grad(torch.mul, argnums=2)(x, y) + with self.assertRaisesRegex(RuntimeError, 'int or Tuple'): + grad(torch.mul, argnums=[0])(x, y) + with self.assertRaisesRegex(RuntimeError, 'must be int'): + grad(torch.mul, argnums=('0',))(x, y) + with self.assertRaisesRegex(RuntimeError, 'must be unique'): + grad(torch.mul, argnums=(0, 0))(x, y) + with self.assertRaisesRegex(RuntimeError, 'must be unique'): + grad(torch.mul, argnums=(0, -2))(x, y) + + def test_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gx = grad(torch.mul, argnums=0)(x, y) + self.assertEqual(gx, y) + + gy = grad(torch.mul, argnums=1)(x, y) + self.assertEqual(gy, x) + + gx, = grad(torch.mul, argnums=(0,))(x, y) + self.assertEqual(gx, y) + + gx, gy = grad(torch.mul, argnums=(0, 1))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_out_of_order_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gy, gx = grad(torch.mul, argnums=(1, 0))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_negative_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gx = grad(torch.mul, argnums=-2)(x, y) + self.assertEqual(gx, y) + + gy = grad(torch.mul, argnums=-1)(x, y) + self.assertEqual(gy, x) + + gx, = grad(torch.mul, argnums=(-2,))(x, y) + self.assertEqual(gx, y) + + gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_grad_pytree_inputs(self, device): + x = torch.randn([], device=device) + + def f(a, b): + x, y = a + return 1 * x + 2 * y + 3 * b['foo'] + + args = ((x, x), {'foo': x}) + + gx, gy = grad(f)(*args) + self.assertEqual(gx, torch.tensor(1., device=device)) + self.assertEqual(gy, torch.tensor(2., device=device)) + + (gx, gy), = grad(f, argnums=(0,))(*args) + self.assertEqual(gx, torch.tensor(1., device=device)) + self.assertEqual(gy, torch.tensor(2., device=device)) + + (gx, gy), gz = grad(f, argnums=(0, 1))(*args) + self.assertEqual(gx, torch.tensor(1., device=device)) + self.assertEqual(gy, torch.tensor(2., device=device)) + self.assertEqual(gz['foo'], torch.tensor(3., device=device)) + + def test_grad_aux_tensor(self, device): + + x = torch.randn(3, device=device) + + with self.assertRaisesRegex( + RuntimeError, + r'grad_and_value\(f\)\(\*args\): output of function f should be a tuple' + ): + grad(lambda t: [t, t], has_aux=True)(x) + + with self.assertRaisesRegex( + RuntimeError, + r'grad_and_value\(f\)\(\*args\): output of function f should be a tuple' + ): + grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x) + + def f(t): + y = t.sin() + return y.sum(), t.cos() + + out, aux = grad(f, has_aux=True)(x) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.cos()) + + def test_grad_aux_pytree(self, device): + def f(x): + y = x.sin() + return y.sum(), {'a': x.cos(), 'b': [x.tan()]} + + x = torch.randn(3, device=device) + + out, aux = grad(f, has_aux=True)(x) + _, expected_aux = f(x) + self.assertEqual(aux, expected_aux) + self.assertEqual(out, x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = grad(lambda x: (x.sum(), aux), has_aux=True)(x) + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x) + + def test_zero_grad(self, device): + def f(x): + return (x['a']**2.0).sum() + inps = ({'a': torch.randn(10, device=device) + 3, 'b': torch.randn(10, device=device)}) + grads = grad(f)(inps) + self.assertNotEqual(grads['a'].sum(), 0.0) + self.assertEqual(grads['b'].sum(), 0.0) + + def test_unrelated_grad(self, device): + x = torch.tensor(1., device=device) + y = torch.tensor(2., device=device) + + def unrelated(x): + return y + + result = grad(unrelated)(x) + self.assertEqual(result, torch.zeros_like(x)) + + def test_unrelated_vjp(self, device): + x = torch.tensor(1., device=device) + y = torch.tensor(2., device=device) + v = torch.tensor(1., device=device) + + def unrelated(x): + return y + + out, vjp_fn = vjp(unrelated, x) + result = vjp_fn(v) + expected = (torch.zeros_like(x),) + self.assertEqual(result, expected) + + def test_unrelated_vjp_multiple_inputs_outputs(self, device): + w = torch.tensor(3., device=device) + x = torch.tensor(4., device=device) + y = torch.tensor(2., device=device) + v = torch.tensor(1., device=device) + + def unrelated(w, x): + return y, y, x + + out, vjp_fn = vjp(unrelated, w, x) + result = vjp_fn((v, v, v)) + expected = (torch.zeros_like(x), torch.ones_like(x)) + self.assertEqual(result, expected) + + # TODO: https://github.com/zou3519/functorch/issues/12 + @onlyCPU + def test_unrelated_hessian(self, device): + N = 5 + M = 3 + W = torch.randn(N, M, device=device) + + def f(x): + return W @ x + + x = torch.randn(M) + result = jacrev(jacrev(f))(x) + expected = torch.zeros(N, M, M, device=device) + self.assertEqual(result, expected) + + def test_vjp_pytree_input(self, device): + def f(x): + return x[0] * x[1][0] + + x = torch.randn([], device=device) + v = torch.randn([], device=device) + out, vjp_fn = vjp(f, (x, (x, x))) + self.assertEqual(out, x * x) + result = vjp_fn(v) + self.assertEqual(result, ((x * v, (x * v, 0.)),)) + + def test_vjp_pytree_output(self, device): + def f(x): + return x, (x, x) + + x = torch.randn([], device=device) + v1 = torch.randn([], device=device) + v2 = torch.randn([], device=device) + v3 = torch.randn([], device=device) + _, vjp_fn = vjp(f, x) + result, = vjp_fn((v1, (v2, v3))) + self.assertEqual(result, v1 + v2 + v3) + + def test_vjp_outputs_can_any_pytree(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output" + ): + _, vjp_fn = vjp(lambda _: output, x) + vjp_fn(t) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors" + ): + _, vjp_fn = vjp(lambda _: output, x) + vjp_fn(t) + + # Check list output + output, vjp_fn = vjp(lambda x: [x, x.sum()], x) + vjp_out, = vjp_fn([t, t.sum()]) + assert isinstance(output, list) and len(output) == 2 + assert isinstance(vjp_out, torch.Tensor) + + # Check dict output + output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x) + vjp_out, = vjp_fn({"x": t, "xsum": t.sum()}) + assert isinstance(output, dict) and len(output) == 2 and "xsum" in output + assert isinstance(vjp_out, torch.Tensor) + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, out]}), + ] + + output, vjp_fn = vjp(composite_output, x) + vjp_out, = vjp_fn([(t.sum(), {"a": t, "out": [t, t.sum()]}), ]) + assert isinstance(output, list) + assert isinstance(output[0], tuple) and isinstance(output[0][1], dict) + assert isinstance(vjp_out, torch.Tensor) + + def test_vjp_pytree_error(self, device): + def f(x): + return x, (x, x) + + x = torch.randn([], device=device) + v1 = torch.randn([], device=device) + v2 = torch.randn([], device=device) + v3 = torch.randn([], device=device) + _, vjp_fn = vjp(f, x) + with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'): + result, = vjp_fn(((v1, (v2, v3)),)) + + def test_vjp_aux_tensor(self, device): + + x = torch.randn(3, device=device) + + with self.assertRaisesRegex(RuntimeError, r'vjp\(f, \*primals\): output of function f should be a tuple'): + vjp(lambda t: [t, t], x, has_aux=True) + + with self.assertRaisesRegex(RuntimeError, r'vjp\(f, \*primals\): output of function f should be a tuple'): + vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True) + + def f(t): + y = t.sin() + return y, t.cos() + + out, vjp_fn, aux = vjp(f, x, has_aux=True) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.sin()) + + v = torch.randn(3, device=device) + grad_x, = vjp_fn(v) + self.assertEqual(grad_x, v * x.cos()) + + def test_vjp_aux_pytree(self, device): + def f(x): + y = x.sin() + return y, {'a': x.cos(), 'b': [x.tan()]} + + x = torch.randn(3, device=device) + + out, vjp_fn, aux = vjp(f, x, has_aux=True) + expected_out, expected_aux = f(x) + self.assertEqual(out, expected_out) + self.assertEqual(aux, expected_aux) + + v = torch.randn(3, device=device) + grad_x, = vjp_fn(v) + self.assertEqual(grad_x, v * x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = vjp(lambda x: (x, aux), x, has_aux=True) + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = vjp(lambda x: (x, [x, aux]), x, has_aux=True) + + def test_functional_init(self, device): + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + B = 10 + weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2) + inputs = torch.randn(B, 7, 2, device=device) + vmap(fn)(weights, (inputs,)) + + def test_functional_init_with_buffers(self, device): + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.bn(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + B = 10 + weights, buffers, fn, _, _ = \ + functional_init_with_buffers(MLPClassifier, [B], device=device)(32, 2) + inputs = torch.randn(B, 7, 2, device=device) + vmap(fn)(weights, buffers, (inputs,)) + + def test_advanced_indexing(self, device): + def f(value): + log_prob = torch.ones((), device=device) + val = (torch.zeros(()) > 0) + log_prob[val] = 0 + return value + + result = grad(f)(torch.randn((), device=device)) + self.assertEqual(result, torch.ones_like(result)) + + def f2(value): + value = value.clone() + value[value > 0] = 0 + return value.sum() + + x = torch.randn(100, device=device) + result = grad(f2)(x) + self.assertEqual(result, (x <= 0).type_as(x)) + + def test_tensor_ctor_inside_grad(self, device): + def foo(x): + return x * torch.tensor(2., device=device) + + x = torch.tensor(3.14, device=device) + functorch.grad(foo)(x) + + @parametrize("op_list_data", [ + subtest(([vmap, ], [(4, 2), (64, 3, 32, 32)]), name='vmap'), + subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name='vmap_vmap'), + subtest(([grad, ], [(0, ), [], (4, 2), (64, 3, 32, 32)]), name='grad'), + subtest(([grad, grad], [[], ]), name='grad_grad'), + subtest(([vmap, grad], [(4, 2)]), name='vmap_grad'), + ]) + def test_tensor_print(self, device, op_list_data): + + op_list, shapes = op_list_data + + for dt in get_all_fp_dtypes(): + data = [torch.randn(s, dtype=dt, device=device) for s in shapes] + + for x in data: + buf = None + + def foo(t): + nonlocal buf + buf = repr(t) + return t.mean() + + fn = foo + bdim = 0 + for op in reversed(op_list): + if op == vmap: + fn = op(fn, in_dims=bdim) + bdim += 1 + else: + fn = op(fn) + + expected = f"{repr(x)}" + level = 0 + for op in op_list: + level += 1 + if op == grad: + expected = f"GradTrackingTensor(lvl={level}, value={expected})" + elif op == vmap: + bdim -= 1 + expected = f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})" + + fn(x) + buf = buf.replace("\n", "").replace(" ", "") + expected = expected.replace("\n", "").replace(" ", "") + self.assertEqual(expected, buf) + + def test_no_grad_outside(self, device): + x = torch.randn([], device=device, requires_grad=True) + with torch.no_grad(): + y = grad(torch.sin)(x) + self.assertEqual(y, x.cos()) + self.assertFalse(y.requires_grad) + + def test_no_grad_inside(self, device): + def f(x): + with torch.no_grad(): + shift = x ** 2 + return x ** 2 - shift + + x = torch.randn([], device=device) + y = grad(f)(x) + self.assertEqual(y, 2 * x) + y = grad(grad(f))(x) + self.assertEqual(y, 2) + + x = torch.randn([], device=device, requires_grad=True) + y = grad(f)(x) + z, = torch.autograd.grad(y, x) + self.assertEqual(z, 2) + + def test_no_grad_mixed(self, device): + def f(x): + with torch.no_grad(): + shift = x ** 2 + return x ** 2 - shift + + x = torch.randn([], device=device, requires_grad=True) + with torch.no_grad(): + y = grad(f)(x) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + + def test_no_grad_nested_simple(self, device): + def h(x): + with torch.no_grad(): + shift = grad(lambda x: 0.25 * x ** 4)(x) + return x ** 3 - shift + + x = torch.tensor(1.5, device=device, requires_grad=True) + y = grad(h)(x) + self.assertEqual(y, 3 * x ** 2) + + z, = torch.autograd.grad(y, x) + self.assertEqual(z, 6 * x) + + def test_no_grad_nested_complicated(self, device): + def f(x): + with torch.no_grad(): + shift = x ** 3 + return x ** 3 - shift + + def g(x): + r1 = grad(f)(x) + with torch.no_grad(): + shift = grad(f)(x) + return r1 - shift + + x = torch.randn([], requires_grad=True, device=device) + y = grad(g)(x) + # The only differential part of g is x ** 3 + self.assertEqual(y, 6 * x) + + z, = torch.autograd.grad(y, x) + self.assertEqual(z, 6) + + def test_no_grad_value(self, device): + def h(x): + with torch.no_grad(): + gvalue, value = grad_and_value(lambda x: x ** 3)(x) + return x ** 3 - value + + x = torch.tensor(1.6, device=device, requires_grad=True) + y = grad(h)(x) + self.assertEqual(y, 3 * x ** 2) + + z, = torch.autograd.grad(y, x) + self.assertEqual(z, 6 * x) + + def test_no_grad_outside_vjp(self, device): + def h(x): + return x ** 2 + + x = torch.tensor(2., requires_grad=True, device=device) + with torch.no_grad(): + out, vjp_fn = vjp(h, x) + y, = vjp_fn(torch.tensor(1., device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + self.assertFalse(out.requires_grad) + + def test_no_grad_outside_vjp_fn(self, device): + def h(x): + return x ** 2 + + x = torch.tensor(3.14, requires_grad=True, device=device) + out, vjp_fn = vjp(h, x) + with torch.no_grad(): + y, = vjp_fn(torch.tensor(1., device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + self.assertTrue(out.requires_grad) + + z, = torch.autograd.grad(out, x) + self.assertEqual(z, 2 * x) + + def test_no_grad_outside_vjp_only(self, device): + def h(x): + return x ** 2 + + x = torch.tensor(3.14, requires_grad=True, device=device) + with torch.no_grad(): + out, vjp_fn = vjp(h, x) + y, = vjp_fn(torch.tensor(1., device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(out.requires_grad) + + # This one is a little weird... + self.assertTrue(y.requires_grad) + + z, = torch.autograd.grad(y, x) + self.assertEqual(z, 2) + + +class TestVmapOfGrad(TestCase): + def test_per_sample_grads_inplace_view(self, device): + def compute_loss(weight, x, t): + x = x.mm(weight) + y = x.squeeze_(0) + return (y - t).sum() + + weight = torch.randn(16, 2, device=device) + x = torch.randn(64, 1, 16, device=device) + t = torch.randn(64, 2, device=device) + result = vmap(partial(grad(compute_loss), weight))(x, t) + expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] + expected = torch.stack(expected) + # TODO: Check if the rtol is a problem + self.assertEqual(result, expected, atol=0, rtol=5e-4) + + def test_new_zeros_materializes_tensor(self, device): + N = 3 + C = 5 + + def foo(y, x): + result = x.new_zeros((C,)) + result.copy_(y) + return result.sum() + + x = torch.randn(N, device=device) + y = torch.randn(N, C, device=device) + result = vmap(grad(foo))(y, x) + self.assertEqual(result, torch.ones_like(y)) + + def test_new_empty_materializes_tensor(self, device): + N = 3 + C = 5 + + def foo(y, x): + result = x.new_empty((C,)) + result.copy_(y) + return result.sum() + + x = torch.randn(N, device=device) + y = torch.randn(N, C, device=device) + result = vmap(grad(foo))(y, x) + self.assertEqual(result, torch.ones_like(y)) + + def test_per_sample_grads_simple(self, device): + def compute_loss(weight, x, t): + y = x @ weight + return ((y - t) ** 2).sum() + + weight = torch.randn(16, 2, device=device) + x = torch.randn(64, 16, device=device) + t = torch.randn(64, 2, device=device) + result = vmap(partial(grad(compute_loss), weight))(x, t) + expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] + expected = torch.stack(expected) + # TODO: Check if the rtol is a problem + self.assertEqual(result, expected, atol=0, rtol=5e-4) + + def test_per_sample_grads_embeddingnet(self, device): + class SampleNet(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.emb = nn.Embedding(vocab_size, 16) + self.fc1 = nn.Linear(16, 16) + self.fc2 = nn.Linear(16, 2) + + def forward(self, x): + x = self.emb(x) + x = torch.transpose(x, -1, -2) + x = torch.mean(x, -1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x + + def name(self): + return "SampleNet" + + # Create our inputs... + vocab_size = 1000 + batch_shape = [64] + words_per_sentence = 5 + data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device) + targets = torch.randint(0, 1, (*batch_shape,), device=device) + + # Construct our module + net = SampleNet(vocab_size).to(device=device) + criterion = nn.CrossEntropyLoss() + + net_func, weights = make_functional(net) + + def compute_loss(weights, data, target): + output = net_func(weights, data) + result = criterion(output, target) + return result + + expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)] + expected = zip(*expected) + expected = tuple(torch.stack(shards) for shards in expected) + + result = vmap(partial(grad(compute_loss), weights))(data, targets) + for r, e in zip(result, expected): + # TODO: Check if the rtol is a problem + self.assertEqual(r, e, atol=0, rtol=1e-3) + + def test_log_softmax(self, device): + x = torch.randn(3, 5, device=device) + v = torch.randn(5, device=device) + + def foo(x, v): + _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x) + return vjp_fn(v)[0] + + result = vmap(foo, (0, None))(x, v) + + v = v.expand_as(x) + x.requires_grad_() + output = torch.log_softmax(x, dim=-1) + output.backward(v) + self.assertEqual(result, x.grad) + + +jacrev_and_jacfwd = parametrize("jacapi", [subtest(jacrev, name='jacrev'), subtest(jacfwd, name='jacfwd')]) + +FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name='jacrev')]) + + +class TestJac(TestCase): + @jacrev_and_jacfwd + def test_simple(self, device, jacapi): + x = torch.randn(3, device=device) + y = jacapi(torch.sin)(x) + expected = torch.diagflat(x.cos()) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_simple_not_flat(self, device, jacapi): + x = torch.randn(2, 3, device=device) + y = jacapi(torch.sin)(x) + expected = torch.diagflat(x.view(-1).cos()) + expected = expected.view(2, 3, 2, 3) + assert torch.allclose(y, expected) + + @FIXME_jacrev_only + def test_diff_numel(self, device, jacapi): + x = torch.randn(2, 4, device=device) + + # Tensor[2, 4] -> Tensor[3, 1] + def f(x): + return x[0, 1:].unsqueeze(-1) + + y = jacapi(f)(x) + self.assertEqual(y.shape, (3, 1, 2, 4)) + + expected = x.new_zeros(3, 1, 2, 4) + expected[0, 0, 0, 1] = 1 + expected[1, 0, 0, 2] = 1 + expected[2, 0, 0, 3] = 1 + self.assertEqual(y, expected) + + @FIXME_jacrev_only + def test_vmap_on_jac_simple(self, device, jacapi): + x = torch.randn(2, 3, device=device) + y = vmap(jacapi(torch.sin))(x) + expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)]) + assert torch.allclose(y, expected) + + @FIXME_jacrev_only + def test_nested_jac_simple(self, device, jacapi): + def foo(x): + return x.sin().sum() + + x = torch.randn(3, device=device) + y = jacapi(jacapi(foo))(x) + expected = torch.diagflat(-x.sin()) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_multiple_args(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=1)(x, y) + expected = torch.diagflat(x) + assert torch.allclose(z, expected) + + @jacrev_and_jacfwd + def test_multiple_outputs_multiple_argnums(self, device, jacapi): + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f, argnums=(0, 1))(x, y) + expected_out0_x = torch.diagflat(torch.full_like(x, 2)) + expected_out0_y = torch.diagflat(torch.full_like(y, 3)) + expected_out1_x = torch.diagflat(torch.full_like(x, 4)) + expected_out1_y = torch.diagflat(torch.full_like(y, 5)) + + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertEqual(len(z[0]), 2) + self.assertTrue(isinstance(z[0], tuple)) + self.assertEqual(z[0][0], expected_out0_x) + self.assertEqual(z[0][1], expected_out0_y) + self.assertEqual(z[1][0], expected_out1_x) + self.assertEqual(z[1][1], expected_out1_y) + + @jacrev_and_jacfwd + def test_multiple_outputs_single_argnums(self, device, jacapi): + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + expected_out0_x = torch.diagflat(torch.full_like(x, 2)) + expected_out1_x = torch.diagflat(torch.full_like(x, 4)) + + z = jacapi(f, argnums=0)(x, y) + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertEqual(z, (expected_out0_x, expected_out1_x)) + + z = jacapi(f, argnums=(0,))(x, y) + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertTrue(isinstance(z[0], tuple)) + self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,))) + + @FIXME_jacrev_only + def test_multiple_outputs_pytree(self, device, jacapi): + def f(x, y): + return {'left': 2 * x + 3 * y, 'right': 4 * x + 5 * y} + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f, argnums=(0, 1))(x, y) + expected_left_x = torch.diagflat(torch.full_like(x, 2)) + expected_left_y = torch.diagflat(torch.full_like(y, 3)) + expected_right_x = torch.diagflat(torch.full_like(x, 4)) + expected_right_y = torch.diagflat(torch.full_like(y, 5)) + expected = { + 'left': (expected_left_x, expected_left_y), + 'right': (expected_right_x, expected_right_y), + } + self.assertTrue(isinstance(z, dict)) + self.assertTrue(isinstance(z['left'], tuple)) + self.assertTrue(isinstance(z['right'], tuple)) + self.assertEqual(z, expected) + + @jacrev_and_jacfwd + def test_multiple_inputs_pytree(self, device, jacapi): + def f(a, b, c): + a0, a1 = a + return a0 + a1 * 2 + b * 3 + c * 4 + + x = torch.randn([], device=device) + args = ((x, x), x, x) + + result = jacapi(f, argnums=(0, 1, 2))(*args) + expected = ( + (torch.tensor(1., device=device), torch.tensor(2., device=device)), + torch.tensor(3., device=device), + torch.tensor(4., device=device), + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0,))(*args) + expected = ((torch.tensor(1., device=device), torch.tensor(2., device=device)),) + self.assertEqual(result, expected) + + result = jacapi(f)(*args) + expected = (torch.tensor(1., device=device), torch.tensor(2., device=device)) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_dimensionality(self, device, jacapi): + def f(x): + return x + + x = torch.randn([], device=device) + result = jacapi(f)(x) + self.assertEqual(result.dim(), 0) + self.assertEqual(result, torch.ones_like(x)) + + x = torch.randn([1], device=device) + result = jacapi(f)(x) + self.assertEqual(result.dim(), 2) + self.assertEqual(result, x.new_ones(1, 1)) + + @FIXME_jacrev_only + def test_aux_tensor(self, device, jacapi): + def f(x): + y = x.clone() + return y, y.cos() + + x = torch.randn(3, device=device) + result, aux = jacapi(f, has_aux=True)(x) + + self.assertEqual(result, torch.eye(3, 3, device=device)) + self.assertEqual(aux, x.cos()) + + @jacrev_and_jacfwd + def test_aux_pytree(self, device, jacapi): + def f(x): + y = x.clone() + return y, {'a': y.cos(), 'b': [y.tan()]} + + x = torch.randn(3, device=device) + + result, aux = jacapi(f, has_aux=True)(x) + self.assertEqual(result, torch.eye(3, 3, device=device)) + _, expected_aux = f(x) + self.assertEqual(aux, expected_aux) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = jacapi(lambda x: (x, aux), has_aux=True)(x) + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x) + + @jacrev_and_jacfwd + def test_outputs_can_any_pytree(self, device, jacapi): + x = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, r"(vjp|jvp).+: Expected f to be a function that has non-empty output" + ): + jacapi(lambda _: output)(x) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"(vjp|jvp).+: expected f\(\*primals\) to return only tensors" + ): + jacapi(lambda _: output)(x) + + # Check list output + out = jacapi(lambda x: [x, x.sum()])(x) + assert isinstance(out, list) and len(out) == 2 + + # Check dict output + out = jacapi(lambda x: {"x": x, "xsum": x.sum()})(x) + assert isinstance(out, dict) and len(out) == 2 and "xsum" in out + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, out]}), + ] + + out = jacapi(composite_output)(x) + assert isinstance(out, list) + assert isinstance(out[0], tuple) and isinstance(out[0][1], dict) + + @jacrev_and_jacfwd + def test_multiple_inputs_outputs_pytree(self, device, jacapi): + def f(a, b, c): + a0, a1 = a + return a0 + a1 * 2, {'foo': b * 3 + c * 4} + + x = torch.randn([], device=device) + zero = torch.zeros([], device=device) + args = ((x, x), x, x) + + result = jacapi(f)(*args) + expected = ( + (torch.tensor(1., device=device), torch.tensor(2., device=device)), + {'foo': (zero, zero)}, + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0,))(*args) + expected = ( + ((torch.tensor(1., device=device), torch.tensor(2., device=device)),), + {'foo': ((zero, zero),)}, + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0, 1))(*args) + expected = ( + ((torch.tensor(1., device=device), torch.tensor(2., device=device)), zero), + {'foo': ((zero, zero), torch.tensor(3., device=device))}, + ) + self.assertEqual(result, expected) + + @FIXME_jacrev_only + def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi): + def f(dct): + a = dct['a'] + b = dct['b'] + return {'c': a.sin(), 'd': b.cos()} + + x = torch.randn(3, device=device) + args = ({'a': x, 'b': x},) + + result = jacapi(f)(*args) + expected = { + 'c': {'a': x.cos().diagflat(), 'b': x.new_zeros(3, 3)}, + 'd': {'a': x.new_zeros(3, 3), 'b': -x.sin().diagflat()}, + } + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_unrelated_input(self, device, jacapi): + def f(x, y): + return x + + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + + result = jacapi(f, argnums=(0, 1))(x, y) + expected0 = torch.eye(6, 6, device=device).view(2, 3, 2, 3) + expected1 = y.new_zeros(2, 3, 2, 3) + expected = (expected0, expected1) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_unrelated_output(self, device, jacapi): + y = torch.randn(2, 3, device=device) + + def f(x): + return y + + x = torch.randn(2, 3, device=device) + + result = jacapi(f)(x) + expected = x.new_zeros(2, 3, 2, 3) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_empty_output(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + + def f(x, y): + return () + + with self.assertRaisesRegex(RuntimeError, 'xpected'): + jacapi(f)(x, y) + + @jacrev_and_jacfwd + def test_argnums_tuple(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=(0, 1))(x, y) + expected0 = torch.diagflat(y) + expected1 = torch.diagflat(x) + assert len(z) == 2 + assert torch.allclose(z[0], expected0) + assert torch.allclose(z[1], expected1) + + @jacrev_and_jacfwd + def test_argnums_effect_on_return(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=(0,))(x, y) + expected0 = torch.diagflat(y) + assert isinstance(z, tuple) + assert len(z) == 1 + assert torch.allclose(z[0], expected0) + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=0)(x, y) + expected0 = torch.diagflat(y) + assert isinstance(z, torch.Tensor) + assert torch.allclose(z, expected0) + + @jacrev_and_jacfwd + def test_argnums_defaults_to_zero(self, device, jacapi): + def f(x, y): + return x * 2 + y * 3 + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f)(x, y) + expected = torch.diagflat(torch.full_like(x, 2)) + self.assertEqual(z, expected) + + @jacrev_and_jacfwd + def test_empty_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be non-empty"): + jacapi(torch.sin, argnums=())(x) + + @jacrev_and_jacfwd + def test_out_of_bounds_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"): + jacapi(torch.sin, argnums=2)(x) + + @jacrev_and_jacfwd + def test_negative_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"): + jacapi(torch.sin, argnums=-2)(x) + + @jacrev_and_jacfwd + def test_repeated_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + jacapi(torch.sin, argnums=(0, 0))(x) + + @jacrev_and_jacfwd + def test_float_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be int or Tuple"): + jacapi(torch.sin, argnums=0.0)(x) + with self.assertRaisesRegex(RuntimeError, "must be int"): + jacapi(torch.multiply, argnums=(1, 0.0))(x, x) + + def test_hessian_simple(self, device): + def f(x): + return x.sin() + + x = torch.randn(3, device=device) + hessian(f)(x) + + def _test_against_reference(self, f, inputs, jacapi): + def foo(inputs): + return f(*inputs) + + expected = torch.autograd.functional.jacobian(f, inputs) + result = jacapi(foo)(inputs) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_against_reference_simple(self, device, jacapi): + def f(x): + return 3 * x ** 2 + + x = torch.randn(2, 3, 5, device=device) + self._test_against_reference(f, (x,), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_multi_input(self, device, jacapi): + def f(x, y): + return (x.cos() * x) @ y.sin() + + x = torch.randn(2, 3, device=device) + y = torch.randn(3, 5, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_multi_input_multi_output(self, device, jacapi): + def f(x, y): + return (x * x) @ y, x @ (x.sum(1) * y), y.sum() + + x = torch.randn(5, 3, device=device) + y = torch.randn(3, 5, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_unrelated_outputs(self, device, jacapi): + def f(x, y): + return x, y, x, y + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_zero_dim(self, device, jacapi): + # zero-dim output + def f(x, y): + return x.sum(), y.sum(), x * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y), jacapi) + + # zero-dim input + def g(x): + return torch.stack([x, x, x]) + + x = torch.randn([], device=device) + self._test_against_reference(g, (x,), jacapi) + + # Mixed zero-dim input / zero-dim output + def h(x, y): + return y.sum(), x * y + + x = torch.randn([], device=device) + y = torch.randn(1, device=device) + self._test_against_reference(h, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_correctness_different_devices(self, device, jacapi): + def f(x, y): + return x * y, (x * y).to(device=device) + + x = torch.randn(3) + y = torch.randn(3) + self._test_against_reference(f, (x, y), jacapi) + + +class TestHessian(TestCase): + def _test_against_reference(self, f, inputs): + def foo(inputs): + return f(*inputs) + + expected = torch.autograd.functional.hessian(f, inputs) + result = hessian(foo)(inputs) + self.assertEqual(result, expected) + + def test_hessian_vectorize_correctness_simple(self, device): + def f(x): + return (3 * x ** 2).sum() + + x = torch.randn(2, 3, 5, device=device) + self._test_against_reference(f, (x,)) + + def test_hessian_vectorize_correctness_multi_input(self, device): + def f(x, y, z): + return ((x.relu() * x) @ y.sin() @ z).sum() + + x = torch.randn(2, 3, device=device) + y = torch.randn(3, 5, device=device) + z = torch.randn(5, 5, device=device) + self._test_against_reference(f, (x, y, z)) + + def test_hessian_vectorize_correctness_unrelated_outputs(self, device): + # output unrelated to one input + def f(x, y): + return (x ** 2).sum() + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y)) + + # output unrelated to all inputs + def f(x, y): + return torch.ones([]) + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y)) + + def test_jacfwd_different_levels(self, device): + # Test case from: + # https://github.com/pytorch/functorch/issues/597 + b = 8 + n = 100 + d = 2 + x1 = torch.randn(b, n, d, device=device) + x2 = x1 + A = 0.1 * torch.randn(b, d, d, device=device) + + def loss(A, x1, x2): + x2_hat = (A @ (x1.T)).T + res = x2 - x2_hat + res_sqr = res**2 + return res_sqr.sum() + + hess1 = vmap(jacrev(jacrev(loss)))(A, x1, x2) + hess2 = vmap(hessian(loss))(A, x1, x2) + self.assertEqual(hess2, hess1) + + +class TestJvp(TestCase): + def test_inplace_on_captures(self, device): + x = torch.tensor([1., 2., 3.], device=device) + captured = torch.randn(3, device=device) + + def foo(x): + captured.copy_(x) + return (x * captured).sum() + + with self.assertRaisesRegex(RuntimeError, 'mutate a captured Tensor'): + grad(foo)(x) + + def test_simple(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + result = jvp(torch.sin, (x,), (t,)) + expected = (x.sin(), x.cos() * t) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_multiple_inputs(self, device): + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + def f(x, y): + return x * y + + result = jvp(f, (x, y), (tx, ty)) + expected = (x * y, y * tx + x * ty) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_pytree_inputs(self, device): + def f(x, y, z): + a, b = x + return a + 2 * b + 3 * y + 4 * z + + one = torch.tensor(1., device=device) + primal_outs, tangent_outs = jvp(f, ((one, one), one, one), ((one, one), one, one)) + self.assertEqual(primal_outs, one * 10) + self.assertEqual(tangent_outs, one * 10) + + def test_pytree_inputs_error_cases(self, device): + def f(x): + return x + + one = torch.tensor(1., device=device) + + with self.assertRaisesRegex(RuntimeError, 'Expected primals to be a tuple'): + jvp(f, one, one) + with self.assertRaisesRegex(RuntimeError, 'same python structure'): + jvp(f, ((one, one), one), (one, one)) + with self.assertRaisesRegex(RuntimeError, 'only contain Tensors'): + jvp(f, ((one, one), 1), ((one, one), one)) + with self.assertRaisesRegex(RuntimeError, 'only contain Tensors'): + jvp(f, ((one, one), 1), ((1, one), one)) + with self.assertRaisesRegex(RuntimeError, 'at least one Tensor'): + jvp(f, ((),), ((),)) + + def test_unrelated_input(self, device): + def f(x, y): + return x + + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + result = jvp(f, (x, y), (tx, ty)) + expected = (x, tx) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_unrelated_output(self, device): + y = torch.randn(2, 3, device=device) + + def f(x): + return y + + x = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + + result = jvp(f, (x,), (tx,)) + expected = (y, torch.zeros_like(y)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_strict_mode(self, device): + y = torch.randn(2, 3, device=device) + + def f(x): + return x, y + + x = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + + with self.assertRaisesRegex(RuntimeError, "strict"): + jvp(f, (x,), (tx,), strict=True) + + def test_multiple_outputs(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + def f(x): + return torch.sin(x), torch.cos(x) + + result = jvp(f, (x,), (t,)) + expected = (f(x), (x.cos() * t, -x.sin() * t)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_multiple_inputs_outputs(self, device): + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + result = jvp(f, (x, y), (tx, ty)) + expected = (f(x, y), f(tx, ty)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_primals_tangents_length_mismatch(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + msg = "same python structure" + with self.assertRaisesRegex(RuntimeError, msg): + jvp(torch.sin, (x,), (t, t)) + with self.assertRaisesRegex(RuntimeError, msg): + jvp(torch.sin, (x, x), (t, t, t)) + + def test_nonempty_primals_and_tangents(self, device): + with self.assertRaisesRegex(RuntimeError, "at least one Tensor"): + jvp(torch.sin, (), ()) + + def test_inputs_are_tuples_of_tensors(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + with self.assertRaisesRegex(RuntimeError, 'be a tuple'): + jvp(torch.sin, x, (t,)) + with self.assertRaisesRegex(RuntimeError, 'same python structure'): + jvp(torch.sin, (x,), t) + with self.assertRaisesRegex(RuntimeError, 'same python structure'): + jvp(torch.sin, (x,), [t]) + with self.assertRaisesRegex(RuntimeError, 'only contain Tensors'): + jvp(torch.sin, (1.,), (t,)) + with self.assertRaisesRegex(RuntimeError, 'only contain Tensors'): + jvp(torch.sin, (x,), (1.,)) + + def test_outputs_can_any_pytree(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, r"jvp\(f, primals, tangents\): Expected f to be a function that has non-empty output" + ): + jvp(lambda _: output, (x,), (t,)) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"jvp\(f, primals, tangents\): expected f\(\*primals\) to return only tensors" + ): + jvp(lambda _: output, (x,), (t,)) + + # Check list output + out = jvp(lambda x: [x, x.sum()], (x,), (t,)) + for i in range(2): + assert isinstance(out[i], list) and len(out[i]) == 2 + + # Check dict output + out = jvp(lambda x: {"x": x, "xsum": x.sum()}, (x,), (t,)) + for i in range(2): + assert isinstance(out[i], dict) and len(out[i]) == 2 and "xsum" in out[i] + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, out]}), + ] + + out = jvp(composite_output, (x,), (t,)) + for i in range(2): + assert isinstance(out[i], list) + assert isinstance(out[i][0], tuple) and \ + isinstance(out[i][0][1], dict) + + def test_aux_tensor(self, device): + + x = torch.randn(3, device=device) + t = torch.randn(3, device=device) + + with self.assertRaisesRegex( + RuntimeError, r'jvp\(f, primals, tangents\): output of function f should be a tuple' + ): + jvp(lambda t: [t, t], (x, ), (t, ), has_aux=True) + + with self.assertRaisesRegex( + RuntimeError, r'jvp\(f, primals, tangents\): output of function f should be a tuple' + ): + jvp(lambda t: (t, t + 2, t + 3), (x, ), (t, ), has_aux=True) + + def f(z): + y = z.sin() + return y, z.cos() + + out, jvp_out, aux = jvp(f, (x, ), (t, ), has_aux=True) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.sin()) + self.assertEqual(jvp_out, t * x.cos()) + + def test_aux_pytree(self, device): + def f(x): + y = x.sin() + return y, {'a': x.cos(), 'b': [x.tan()]} + + x = torch.randn(3, device=device) + t = torch.randn(3, device=device) + + out, jvp_out, aux = jvp(f, (x, ), (t, ), has_aux=True) + expected_out, expected_aux = f(x) + self.assertEqual(out, expected_out) + self.assertEqual(aux, expected_aux) + self.assertEqual(jvp_out, t * x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = jvp(lambda x: (x, aux), (x, ), (t, ), has_aux=True) + with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"): + _ = jvp(lambda x: (x, [x, aux]), (x, ), (t, ), has_aux=True) + + def test_fwd_grad_enabled(self, device): + # Tests some private helper functions to enable/disable fwd grad mode + enabled = functorch._C.get_fwd_grad_enabled() + self.assertTrue(enabled) + + try: + functorch._C.set_fwd_grad_enabled(False) + enabled = functorch._C.get_fwd_grad_enabled() + self.assertFalse(enabled) + finally: + functorch._C.set_fwd_grad_enabled(True) + + enabled = functorch._C.get_fwd_grad_enabled() + self.assertTrue(enabled) + + def test_autograd_function_disables_fwd_grad(self, device): + # Sanity check. We don't really assume this anywhere so + # it's fine if this breaks one day. + class MySquare(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + enabled = functorch._C.get_fwd_grad_enabled() + self.assertFalse(enabled) + return x * x + + @staticmethod + def backward(ctx, gx): + return gx + + x = torch.randn(3, requires_grad=True) + MySquare.apply(x) + + def test_enable_fwd_grad(self, device): + # Tests a private helper function + try: + functorch._C.set_fwd_grad_enabled(False) + enabled = functorch._C.get_fwd_grad_enabled() + self.assertFalse(enabled) + + with enable_fwd_grad(): + enabled = functorch._C.get_fwd_grad_enabled() + self.assertTrue(enabled) + + enabled = functorch._C.get_fwd_grad_enabled() + self.assertFalse(enabled) + finally: + functorch._C.set_fwd_grad_enabled(True) + + def test_disable_fwd_grad_outside(self, device): + x = torch.randn([], device=device) + t = torch.ones_like(x) + with enable_fwd_grad(False): + _, y = jvp(torch.sin, (x,), (t,)) + self.assertEqual(y, x.cos()) + + def test_disable_fwd_grad_inside(self, device): + def f(x): + with enable_fwd_grad(False): + shift = x ** 2 + return x ** 2 - shift + + x = torch.randn([], device=device) + t = torch.ones_like(x) + _, y = jvp(f, (x,), (t,)) + self.assertEqual(y, 2 * x) + _, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,)) + self.assertEqual(y, 2) + + def test_disable_fwd_grad_mixed(self, device): + def f(x): + with enable_fwd_grad(False): + shift = x ** 2 + return x ** 2 - shift + + x = torch.randn([], device=device) + t = torch.ones_like(x) + with enable_fwd_grad(): + _, y = jvp(f, (x,), (t,)) + + self.assertEqual(y, 2 * x) + + def test_jvp_inside_autograd_function(self, device): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + t = torch.ones_like(x) + _, neg_sin_x = jvp(torch.cos, (x,), (t,)) + ctx.save_for_backward(x) + return -neg_sin_x + + @staticmethod + def backward(ctx, gx): + x, = ctx.saved_tensors + t = torch.ones_like(x) + _, cos_x = jvp(torch.sin, (x,), (t,)) + return gx * cos_x + + x = torch.randn([], device=device, requires_grad=True) + y = MySin.apply(x) + self.assertEqual(y, x.sin()) + + gx, = torch.autograd.grad(y, x) + self.assertEqual(gx, x.cos()) + + def test_zerotensor_vmapjvp_interaction(self, device): + dummy = torch.ones(4, 1) + x = torch.randn(4, 2) + x_tangent = torch.randn(2) + + def push_jvp(dummy, x): + result = jvp(torch.cov, (x,), (x_tangent,)) + return result + + # Should not error + vmap(vmap(push_jvp, (0, None)))(dummy, x) + + +class TestCustomFunction(TestCase): + @unittest.skipIf(IS_WINDOWS, "Prototype of custom_vjp doesn't link on windows") + @onlyCPU + def test_basic(self, device): + called_impl = False + called_vjp = False + + def my_sin_impl(args): + x, = args + nonlocal called_impl + called_impl = True + return x.sin(), x + + def my_sin_vjp(args): + grad_y, result, x = args + nonlocal called_vjp + called_vjp = True + return (grad_y * 3 * x.cos(),) + + def filter_fn(args): + return args[0] + + my_sin = custom_vjp('my_sin', filter_fn, my_sin_impl, my_sin_vjp) + + x = torch.tensor([1., 2.], requires_grad=True, device=device) + + y = my_sin(x) + self.assertTrue(called_impl) + + y.sum().backward() + self.assertTrue(called_vjp) + + assert torch.allclose(x.grad, 3 * x.cos()) + + +class TestComposability(TestCase): + def test_grad_grad(self, device): + x = torch.randn([], device=device) + y = grad(grad(torch.sin))(x) + self.assertEqual(y, -x.sin()) + + def test_grad_vmap(self, device): + def foo(x): + y = vmap(torch.sin)(x) + return y.sum() + + x = torch.randn(3, device=device) + y = grad(foo)(x) + self.assertEqual(y, x.cos()) + + def test_grad_vjp(self, device): + x = torch.randn(3, device=device) + + def foo(x): + _, vjp_fn = vjp(torch.sin, x) + return vjp_fn(x)[0].sum() + + y = grad(foo)(x) + expected = grad(lambda x: (x * x.cos()).sum())(x) + self.assertEqual(y, expected) + + def test_vmap_grad(self, device): + x = torch.randn(3, device=device) + y = vmap(grad(torch.sin))(x) + self.assertEqual(y, x.cos()) + + def test_vmap_vmap(self, device): + x = torch.randn(2, 3, device=device) + y = vmap(vmap(torch.sin))(x) + self.assertEqual(y, x.sin()) + + def test_vmap_vjp(self, device): + x = torch.randn(3, device=device) + _, vjp_fn = vjp(torch.sin, x) + + def foo(x): + _, vjp_fn = vjp(torch.sin, x) + return vjp_fn(x) + + y = vmap(foo)(x) + self.assertEqual(y, vjp_fn(x)) + + # TODO: there's a very interesting error message when the following + # is on CPU + xs = torch.randn(5, 3, device=device) + expected = torch.stack([vjp_fn(x)[0] for x in xs]) + result = vmap(lambda x: vjp_fn(x)[0])(xs) + self.assertEqual(result, expected) + + def test_vjp_grad(self, device): + x = torch.randn([], device=device) + y, vjp_fn = vjp(grad(torch.sin), x) + self.assertEqual(y, x.cos()) + + v = torch.randn([]) + self.assertEqual(vjp_fn(v)[0], -x.sin() * v) + + def test_vjp_vmap(self, device): + x = torch.randn(3, device=device) + y, vjp_fn = vjp(vmap(torch.sin), x) + self.assertEqual(y, x.sin()) + + v = torch.randn(3, device=device) + self.assertEqual(vjp_fn(v)[0], x.cos() * v) + + def test_vjp_vjp(self, device): + x = torch.randn(3, device=device) + y, vjp_fn = vjp(torch.sin, x) + self.assertEqual(y, x.sin()) + + y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x) + self.assertEqual(y, x * x.cos()) + + y = vjp_fn(x)[0] + # Honestly IDK what the result here is... but at least it runs + + def test_make_fx_vmap(self, device): + def f(x): + return torch.sin(x) + inp = torch.randn(5, 3) + f = vmap(f) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(5, 3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_jacrev(self, device): + def f(x): + return x.sin().sum() + inp = torch.randn(3) + f = jacrev(jacrev(f)) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_vjp(self, device): + def f(x): + return torch.sin(x).sum() + + primals = torch.randn(3) + _, vjp_fn = vjp(f, primals) + cotangent = torch.randn(()) + fx_f = make_fx(vjp_fn)(cotangent, True, True) + new_cotangent = torch.randn(()) + self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) + + def test_requires_grad_inside_transform(self, device): + def f(x): + x.requires_grad_() + return x.sin().sum() + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + vmap(f)(x) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + grad(f)(x) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + vmap(grad(f))(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + grad(grad(f))(x) + + def test_retain_grad_inside_transform(self, device): + def f(x): + y = x.sin() + y.retain_grad() + return y.sum() + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"): + grad(f)(x) + + def test_autograd_functional_jacrev_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_vjp_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_jvp_inside_transform(self, device): + def f(x): + t = torch.ones_like(x) + y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,)) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_jacfwd_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.jacobian( + lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaises(RuntimeError): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaises(RuntimeError): + grad(f)(x) + + +class TestMakeFunctional(TestCase): + @parametrize('disable_autograd_tracking', [True, False]) + def test_disable_autograd_tracking(self, disable_autograd_tracking): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + + def forward(self, x): + x = self.linear(x) + return x + + mod = Foo() + _, params = make_functional(mod, disable_autograd_tracking=disable_autograd_tracking) + self.assertEqual(len(params), 2) + for param in params: + self.assertEqual(param.requires_grad, not disable_autograd_tracking) + + def test_parameter_tying(self): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.bias = nn.Parameter(torch.randn(3)) + self.linear = nn.Linear(3, 3) + self.linear.bias = self.bias + self.linear_tied = self.linear + + def forward(self, x): + x = self.linear(x) + x = self.linear_tied(x) + x = x + self.bias + return x + + torch.manual_seed(1) + mod = Foo() + func, _ = make_functional(mod) + + torch.manual_seed(0) + mod = Foo() + _, params = make_functional(mod) + self.assertEqual(len(params), 2) + + x = torch.randn(2, 3) + result = func(params, x) + expected = mod(x) + self.assertEqual(result, expected) + + def test_buffer_tying(self): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.bias = nn.Parameter(torch.randn(3)) + self.linear = nn.Linear(3, 3) + self.register_buffer('buffer', torch.randn(3)) + self.register_buffer('buffer_tied', self.buffer) + + def forward(self, x): + x = self.linear(x) + x = x + self.bias + x = x + self.buffer + x = x + self.buffer_tied + return x + + torch.manual_seed(1) + mod = Foo() + func, _, _ = make_functional_with_buffers(mod) + + torch.manual_seed(0) + mod = Foo() + _, params, buffers = make_functional_with_buffers(mod) + self.assertEqual(len(params), 3) + self.assertEqual(len(buffers), 1) + + x = torch.randn(2, 3) + result = func(params, buffers, x) + expected = mod(x) + self.assertEqual(result, expected) + + @parametrize('disable_autograd_tracking', [True, False]) + def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + self.register_buffer('buffer', torch.randn(3)) + + def forward(self, x): + x = self.linear(x) + x = x + self.buffer + return x + + mod = Foo() + _, params, buffers = make_functional_with_buffers(mod, disable_autograd_tracking=disable_autograd_tracking) + self.assertEqual(len(params), 2) + self.assertEqual(len(buffers), 1) + for param in params: + self.assertEqual(param.requires_grad, not disable_autograd_tracking) + + def test_parameter_tying_grad(self): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + self.weight = self.linear.weight + self.bias = self.linear.bias + + def forward(self, x): + x = self.linear(x) + x = F.linear(x, self.weight, self.bias) + return x + + x = torch.randn(2, 3) + torch.manual_seed(0) + mod = Foo() + loss = mod(x).sum() + expected = torch.autograd.grad(loss, mod.parameters()) + + mod = Foo() + fmod, _, _ = make_functional_with_buffers(mod) + torch.manual_seed(0) + mod = Foo() + _, params, buffers = make_functional_with_buffers(mod) + + def compute_loss(params, buffers, x): + return fmod(params, buffers, x).sum() + + result = grad(compute_loss)(params, buffers, x) + + self.assertEqual(result, expected) + + def test_parameter_tying_ensemble(self): + class Foo(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + self.weight = self.linear.weight + self.bias = self.linear.bias + self.register_buffer('buffer', torch.randn(3)) + self.register_buffer('buffer_tied', self.buffer) + + def forward(self, x): + x = self.linear(x) + x = F.linear(x, self.weight, self.bias) + x = x + self.buffer + x = x + self.buffer_tied + return x + + num_models = 2 + xs = torch.randn(num_models, 64, 3) + models = [Foo() for _ in range(num_models)] + fmodel, _, _ = combine_state_for_ensemble(models) + + torch.manual_seed(0) + models = [Foo() for _ in range(num_models)] + _, params, buffers = combine_state_for_ensemble(models) + result = vmap(fmodel)(params, buffers, xs) + + torch.manual_seed(0) + models = [Foo() for _ in range(num_models)] + expected = torch.stack([model(x) for model, x in zip(models, xs)]) + + self.assertEqual(result, expected) + + def test_correctness_mnist(self): + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + x = torch.randn(64, 1, 32, 32) + torch.manual_seed(301) + fnet, _ = make_functional(Net()) + + torch.manual_seed(0) + _, params = make_functional(Net()) + result = fnet(params, x) + + torch.manual_seed(0) + net = Net() + expected = net(x) + + self.assertEqual(result, expected) + + def test_combine_state_for_ensemble_error(self): + in_features = 2 + out_features = 2 + + models = [] + with self.assertRaisesRegex(RuntimeError, "Expected at least one model"): + _ = combine_state_for_ensemble(models) + + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1].eval() + with self.assertRaisesRegex(RuntimeError, "same training/eval mode"): + _ = combine_state_for_ensemble(models) + + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1] = torch.nn.Conv2d(3, 3, (3, 3)) + with self.assertRaisesRegex(RuntimeError, "models to be of the same class"): + _ = combine_state_for_ensemble(models) + + def test_combine_state_for_ensemble_smoke(self): + in_features = 2 + out_features = 2 + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + _ = combine_state_for_ensemble(models) + + +class TestExamplesCorrectness(TestCase): + def test_maml_regression(self, device): + class ThreeLayerNet(nn.Module): + def __init__(self): + super(ThreeLayerNet, self).__init__() + self.fc1 = nn.Linear(1, 40) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(40, 40) + self.relu2 = nn.ReLU() + self.fc3 = nn.Linear(40, 1) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + x = self.relu2(x) + x = self.fc3(x) + return x + + # TODO: should replace with F.mse_loss + def mse_loss(x, y): + return torch.mean((x - y) ** 2) + + net, params = make_functional(ThreeLayerNet().to(device)) + K = 20 + num_tasks = 4 + alpha = 0.1 + + def sample_tasks(outer_batch_size, inner_batch_size): + # Select amplitude and phase for the task + As = [] + phases = [] + for _ in range(outer_batch_size): + As.append(np.random.uniform(low=0.1, high=.5)) + phases.append(np.random.uniform(low=0., high=np.pi)) + + def get_batch(): + xs, ys = [], [] + for A, phase in zip(As, phases): + x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) + y = A * np.sin(x + phase) + xs.append(x) + ys.append(y) + return torch.tensor(xs, dtype=torch.float, device=device), \ + torch.tensor(ys, dtype=torch.float, device=device) + x1, y1 = get_batch() + x2, y2 = get_batch() + return x1, y1, x2, y2 + + def get_loss_for_task(use_transform, x1, y1, x2, y2): + def inner_loss(params, x1, y1): + f = net(params, x1) + loss = mse_loss(f, y1) + return loss + + if use_transform: + grads = grad(inner_loss)(params, x1, y1) + else: + loss = inner_loss(params, x1, y1) + grads = torch.autograd.grad(loss, params, create_graph=True) + new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))] + + v_f = net(new_params, x2) + return mse_loss(v_f, y2) + + task = sample_tasks(num_tasks, K) + + # Compute with vmap+grad + inner_losses = vmap(partial(get_loss_for_task, True))(task[0], task[1], task[2], task[3]) + loss2 = sum(inner_losses) / len(inner_losses) + result_grads = torch.autograd.grad(loss2, params) + + # Compute without vmap+grad + inner_losses = [ + get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i]) + for i in range(num_tasks) + ] + loss2 = sum(inner_losses) / len(inner_losses) + expected_grads = torch.autograd.grad(loss2, params) + + self.assertEqual(result_grads, expected_grads) + + def test_maml_omniglot(self, device): + # TODO: there appears to be precision issues for float32 + dtype = torch.double + + # TODO: We don't support inplace relu? + inplace_relu = False + n_way = 5 + n_inner_iter = 2 + num_tasks = 2 + + # real example uses batch norm but it's numerically unstable in the first + # iteration, when near 0, and won't produce same gradients. Uses group norm instead + net = nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, n_way)).to(device).to(dtype) + + fnet, params, buffers = make_functional_with_buffers(net) + net = (params, buffers, fnet) + + def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry): + params, buffers, fnet = net + querysz = x_qry.size(0) + + def compute_loss(new_params, buffers, x, y): + logits = fnet(new_params, buffers, x) + loss = F.cross_entropy(logits, y) + return loss + + new_params = params + for _ in range(n_inner_iter): + if use_transform: + grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) + else: + res = compute_loss(new_params, buffers, x_spt, y_spt) + grads = torch.autograd.grad(res, new_params, create_graph=True) + new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + + qry_logits = fnet(new_params, buffers, x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry) + qry_acc = (qry_logits.argmax( + dim=1) == y_qry).sum() / querysz + + return qry_loss, qry_acc + + # Get some sample inputs... + x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device) + y_spt = torch.randint(0, 5, (num_tasks, 25), device=device) + x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device) + y_qry = torch.randint(0, 5, (num_tasks, 75), device=device) + + # compute with vmap + grad + compute_loss = partial(loss_for_task, net, n_inner_iter, True) + qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry) + result_grads = torch.autograd.grad(qry_losses.sum(), params) + + # compute without vmap + grad + compute_loss = partial(loss_for_task, net, n_inner_iter, False) + losses = [compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0] + for i in range(num_tasks)] + expected_grads = torch.autograd.grad(sum(losses), params) + + self.assertEqual(result_grads, expected_grads) + + @parametrize('originally_track_running_stats', [True, False]) + def test_update_batch_norm(self, device, originally_track_running_stats): + dtype = torch.double + inplace_relu = False + classes = 5 + num_batches = 2 + net = nn.Sequential( + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d(64, affine=True, track_running_stats=originally_track_running_stats), + nn.ReLU(inplace=inplace_relu), + nn.Flatten(), + nn.Linear(43264, classes)).to(device).to(dtype) + + replace_all_batch_norm_modules_(net) + transformed_net = net + fnet, params, buffers = make_functional_with_buffers(transformed_net) + net = (params, buffers, fnet) + criterion = nn.CrossEntropyLoss() + + def compute_loss(x, y, params, buffers): + return criterion(fnet(params, buffers, x), y) + + # Get some sample inputs... + x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype) + y = torch.randint(0, classes, (num_batches, 1), device=device) + + # compute some per sample grads with vmap + grad + result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))(x, y, params, buffers) + + # compute some per sample grads without vmap + grad + fnet, params, buffers = make_functional_with_buffers(transformed_net) + expected_grads = [ + torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), params) + for i in range(num_batches) + ] + expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] + + self.assertEqual(result_grads, expected_grads) + + @parametrize('jac', ['jacfwd', 'jacrev']) + def test_lennard_jones_batched_jac(self, device, jac): + sigma = 0.5 + epsilon = 4. + + jac = getattr(functorch, jac) + + def lennard_jones(r): + return epsilon * ((sigma / r)**12 - (sigma / r)**6) + + def lennard_jones_force(r): + """Get magnitude of LJ force""" + return \ + -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)) + + r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) + drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) + norms = torch.norm(drs, dim=1).reshape(-1, 1) + training_energies = \ + torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) + training_forces = torch.stack( + [force * dr + for force, dr in zip(map(lennard_jones_force, norms), drs)]) + + model = nn.Sequential( + nn.Linear(1, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 1) + ).to(device) + + def make_prediction(model, drs, use_functorch): + norms = torch.norm(drs, dim=1).reshape(-1, 1) + energies = model(norms) + + if use_functorch: + network_derivs = vmap(jac(model))(norms).squeeze(-1) + forces = -network_derivs * drs / norms + else: + forces = [] + for r, dr in zip(norms, drs): + network_deriv = torch.autograd.functional.jacobian( + model, r, create_graph=True) + force = -network_deriv * dr / r + forces.append(force) + forces = torch.cat(forces) + return energies, forces + + def loss_fn(energies, forces, predicted_energies, predicted_forces): + return F.mse_loss(energies, predicted_energies) + \ + 0.01 * F.mse_loss(forces, predicted_forces) / 3 + + energies, forces = make_prediction(model, drs, use_functorch=True) + loss = loss_fn(training_energies, training_forces, energies, forces) + result = torch.autograd.grad(loss, model.parameters()) + + energies, forces = make_prediction(model, drs, use_functorch=False) + loss = loss_fn(training_energies, training_forces, energies, forces) + expected = torch.autograd.grad(loss, model.parameters()) + + self.assertEqual(result, expected) + + def test_ensemble_regression(self, device): + def make_spirals(n_samples, noise_std=0., rotations=1.): + ts = torch.linspace(0, 1, n_samples) + rs = ts ** 0.5 + thetas = rs * rotations * 2 * math.pi + signs = torch.randint(0, 2, (n_samples,)) * 2 - 1 + labels = (signs > 0).to(torch.long) + + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std + points = torch.stack([xs, ys], dim=1) + return points.to(device), labels.to(device) + + points, labels = make_spirals(100, noise_std=0.05) + + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + loss_fn = nn.NLLLoss() + + func_model, weights = make_functional(MLPClassifier().to(device)) + + def train_step_fn(use_transform, weights, batch, targets, lr=0.2): + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + if use_transform: + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + else: + loss = compute_loss(weights, batch, targets) + grad_weights = torch.autograd.grad(loss, weights) + + new_weights = [] + with torch.no_grad(): + for grad_weight, weight in zip(grad_weights, weights): + new_weights.append(weight - grad_weight * lr) + # NB: return looks weird because torch.vmap must return Tensors + return (loss, *new_weights) + + def unpack(train_result): + return train_result[0], train_result[1:] + + def init_fn(num_models): + models = tuple(MLPClassifier().to(device) for _ in range(num_models)) + weights = tuple(make_functional(model)[1] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights + + def slice_weights(batched_weights, index): + return tuple(weight[index].detach().requires_grad_() for weight in batched_weights) + + batched_weights = init_fn(num_models=2) + parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None)) + + result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels)) + + loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels)) + loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels)) + expected_loss = torch.stack([loss0, loss1]) + expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1)) + + self.assertEqual(result_loss, expected_loss) + self.assertEqual(result_weights, expected_weights) + + @parametrize("dropout_layer", [nn.Dropout, nn.AlphaDropout, nn.FeatureAlphaDropout]) + def test_find_learning_rate_ensembling(self, device, dropout_layer): + # This example mimics what a user might do when trying to find the optimal learning rate. They would + # want to run a bunch of models with the same behavior (including the same dropout!) and have them + # each run with different learning rates. Specifically, this is an example of using same randomness with vmap + points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(0, 2, (100,), device=device) + + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.dropout = dropout_layer() + self.fc1 = nn.Linear(16, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.dropout(x) + x = torch.flatten(x, start_dim=1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + loss_fn = nn.NLLLoss() + + func_model, weights = make_functional(MLPClassifier().to(device)) + + def train_step_fn(weights, batch, targets, lr): + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + new_weights = [] + with torch.no_grad(): + for grad_weight, weight in zip(grad_weights, weights): + new_weights.append(weight - grad_weight * lr) + # NB: return looks weird because torch.vmap must return Tensors + return (loss, *new_weights) + + def unpack(train_result): + return train_result[0], train_result[1:] + + def init_fn(num_models): + og_model = MLPClassifier().to(device) + models = tuple(copy.deepcopy(og_model) for _ in range(num_models)) # have same initialization + weights = tuple(make_functional(model)[1] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights + + batched_weights = init_fn(num_models=2) + parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None, 0), randomness="same") + + lrs = torch.tensor([0.2, 0.4], device=device) + result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels, lrs)) + + self.assertEqual(result_loss[0], result_loss[1]) + self.assertNotEqual(tuple(weight[0] for weight in result_weights), + tuple(weight[1] for weight in result_weights)) + + @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") + def test_resnet18_per_sample_grads(self, device): + import torchvision.models as models + model = models.__dict__['resnet18']( + pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c)) + ).to(device) + criterion = nn.CrossEntropyLoss(reduction='sum') # avoid cross batch reductions for for loop comparison + + func_model, weights = make_functional(model) + + def compute_loss(weights, image, target): + output = func_model(weights, images) + loss = criterion(output, targets) + return loss + + batch_size = 3 + images = torch.randn(batch_size, 3, 32, 32, device=device) + targets = torch.randint(0, 10, (batch_size,), device=device) + + result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets) + + expected_grads = [ + torch.autograd.grad(compute_loss(weights, images[i].unsqueeze(0), targets[i].unsqueeze(0)), weights) + for i in range(batch_size) + ] + expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] + + self.assertEqual(result_grads, expected_grads, atol=1e-3, rtol=1.) + + +class TestFunctionalize(TestCase): + def _check_functionalize_correctness(self, f, inpt): + inpt1 = inpt.clone() + inpt2 = inpt.clone() + inpt3 = inpt.clone() + + expected_outputs = f(inpt1) + actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() + # Right now the flavor of functionalize that also removes view ops + # isn't being used with vmap + # That's because {view}_copy ops don't have batching rules yet + # (although we should probably fix that) + actual_outputs_view_copy = functionalize(f, remove='mutations_and_views')(inpt3) + # Check that outputs are the same + self.assertEqual(actual_outputs, expected_outputs) + self.assertEqual(actual_outputs_view_copy, expected_outputs) + + # Inputs might have been mutated by f: check that they were mutated properly + self.assertEqual(inpt1, inpt2) + self.assertEqual(inpt1, inpt3) + + def test_simple_view(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + def test_multioutput_view(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y1, y2 = x.split(2) + y1_view = y1.diagonal() + y1_view.add_(tmp) + return x + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + def test_inplace_view(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(4, device=device) + y = x + x + y2 = y.transpose(1, 0) + z = y2[0] + z.add_(tmp) + return y + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + # See https://github.com/pytorch/functorch/issues/780 + def test_linear(self, device): + + def f(x, y, z) -> torch.Tensor: + return torch._C._nn.linear(x, y, z) + + x = torch.randn(14, 1, 384, device=device) + y = torch.randn(96, 384, device=device) + z = torch.randn(96, device=device) + + out_expected = f(x, y, z) + out_actual = functionalize(f)(x, y, z) + self.assertEqual(out_expected, out_actual) + + def test_multioutput_inplace_slice_view(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, 2, device=device) + y = x.view(8) + z0 = y.reshape(2, 4) + z1 = z0.transpose(1, 0) + z1.unsqueeze_(0) + z1.squeeze_() + z2, z3 = z1.split(2) + z2.add_(tmp) + return x + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + # Ensure functionalize works with List[Optional[Tensor]] arguments. + # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 + def test_functionalize_opt_tensor_list(self, device): + + def f(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return x[indices] + + inpta = torch.ones(4, device=device) + inptb = torch.arange(2, device=device) + out1 = f(inpta, inptb) + out2 = functionalize(f)(inpta, inptb) + self.assertEqual(out1, out2) + out = make_fx(functionalize(f))(inpta, inptb) + self.assertExpectedInline((out.code), """\ + + + +def forward(self, x_1, indices_1) -> torch.Tensor: + index_tensor = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None + return index_tensor + """) + + # Ensure grad(functionalize(f)) works + def test_functionalize_grad(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x + x + z = y.view(4, 2) + y.add_(tmp) + return z.sum() + + inpt1 = torch.ones(4, 2, device=device) + inpt2 = torch.ones(4, 2, device=device) + out1 = grad(f)(inpt1) + out2 = grad(functionalize(f))(inpt2) + self.assertEqual(out1, out2) + self.assertEqual(inpt1, inpt2) + + def test_vmap_functionalize_jvp(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + y = x + x + z = y.view(-1) + y.add_(1) + return z + + def jvp_wrapper(x, t): + return jvp(f, (x,), (t,),) + + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + out1 = vmap(jvp_wrapper)(x, t) + out2 = vmap(functionalize(jvp_wrapper))(x, t) + self.assertEqual(out1, out2) + + def test_functionalize_fx_simple(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + # There's a copy_ in the graph, because the input (x) was mutated. + # To preserve semantics, functionalize() needs to propagate the mutation. + fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False) + out = fn(torch.zeros(4, 2, device=device)) + self.assertExpectedInline((out.code), """\ + + + +def forward(self, x_1) -> torch.Tensor: + view_copy_default = torch.ops.aten.view_copy.default(x_1, [4, 2]) + _tensor_constant0 = self._tensor_constant0 + add_tensor = torch.ops.aten.add.Tensor(view_copy_default, _tensor_constant0); view_copy_default = _tensor_constant0 = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None + copy__default = torch.ops.aten.copy_.default(x_1, view_copy_default_1); x_1 = None + return view_copy_default_1 + """) + + def test_functionalize_fx_transpose_simple(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + return x.transpose(1, 0) + fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False) + out = fn(torch.zeros(4, 2, device=device)) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, x_1) -> torch.Tensor: + transpose_copy_int = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None + return transpose_copy_int + """) + + def test_functionalize_fx_out_op(self, device): + + def f(inpt: torch.Tensor) -> torch.Tensor: + out = torch.empty((), dtype=torch.float32) + torch.add(inpt, inpt, out=out) + out_view = out.view(4) + out_view.add_(1) + return out + + fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False) + out = fn(torch.arange(4, device=device, dtype=torch.float32)) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, inpt_1) -> torch.Tensor: + add_tensor = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None + view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4]) + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4]); add_tensor = None + add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None + view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor_1, [4]); add_tensor_1 = None + return view_copy_default_2 + """) + + def test_functionalize_fx_multi_out_op(self, device): + + def f(inpt: torch.Tensor) -> torch.Tensor: + mins = torch.empty(4, dtype=torch.float32) + maxs = torch.empty(2, 2, dtype=torch.float32) + maxs_view = maxs.view(4) + inpt_view = inpt.view(2, 4) + torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view)) + return (maxs, mins) + + fn = make_fx(functionalize(f, remove='mutations_and_views'), trace_factory_functions=False) + out = fn(torch.arange(8, device=device, dtype=torch.float32)) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, inpt_1) -> torch.Tensor: + view_copy_default = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None + aminmax_default = torch.ops.aten.aminmax.default(view_copy_default, dim = 0); view_copy_default = None + getitem = aminmax_default[0] + getitem_1 = aminmax_default[1]; aminmax_default = None + view_copy_default_1 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None + return (view_copy_default_1, getitem) + """) + + def test_functionalize_fx_reapply_views_simple(self, device): + + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + + out = make_fx(functionalize(f), trace_factory_functions=False)(torch.zeros(4, 2, device=device)) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, x_1) -> torch.Tensor: + view_default = torch.ops.aten.view.default(x_1, [4, 2]) + _tensor_constant0 = self._tensor_constant0 + add_tensor = torch.ops.aten.add.Tensor(view_default, _tensor_constant0); view_default = _tensor_constant0 = None + view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None + copy__default = torch.ops.aten.copy_.default(x_1, view_default_1); x_1 = None + return view_default_1 + """) + + def test_functionalize_nonfunctional_output(self, device): + + global_out = torch.ones(2, device=device) + + def f() -> torch.Tensor: + return global_out + + out = make_fx(functionalize(f))() + self.assertExpectedInline(out.code, """\ + + + +def forward(self) -> torch.Tensor: + _tensor_constant0 = self._tensor_constant0 + return _tensor_constant0 + """) + + def test_functionalize_optional_tensorlist1(self, device): + + def f(a, b) -> torch.Tensor: + # at::index has OptionalTensorList arguments, + # test that here + return a[b] + + a = torch.arange(4).reshape(2, 2) + b = torch.ones(2, dtype=torch.long) + out = make_fx(functionalize(f))(a, b) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, a_1, b_1) -> torch.Tensor: + index_tensor = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None + return index_tensor + """) + + def test_functionalize_optional_tensorlist2(self, device): + + def f(a, b) -> torch.Tensor: + # See https://github.com/pytorch/pytorch/pull/77846 + return torch.ops.aten.index(a, b) + + a = torch.arange(4).reshape(2, 2) + b = torch.ones(2, dtype=torch.long) + out = make_fx(functionalize(f))(a, b) + self.assertExpectedInline(out.code, """\ + + + +def forward(self, a_1, b_1) -> torch.Tensor: + unbind_int = torch.ops.aten.unbind.int(b_1); b_1 = None + getitem = unbind_int[0] + getitem_1 = unbind_int[1]; unbind_int = None + index_tensor = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None + return index_tensor + """) + + + +only_for = ("cpu", "cuda") +instantiate_device_type_tests( + TestGradTransform, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestVmapOfGrad, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestJac, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestJvp, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestHessian, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestComposability, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestExamplesCorrectness, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestCustomFunction, + globals(), + only_for=only_for, +) +instantiate_device_type_tests( + TestFunctionalize, + globals(), + only_for=only_for, +) +instantiate_parametrized_tests( + TestMakeFunctional, +) + +if __name__ == '__main__': + run_tests() diff --git a/functorch/test/test_functionalize.py b/functorch/test/test_functionalize.py new file mode 100644 index 0000000000000..399273bfaf0d2 --- /dev/null +++ b/functorch/test/test_functionalize.py @@ -0,0 +1,49 @@ +# Owner(s): ["module: functorch"] + +import functorch +from unittest.mock import patch +import functools +from torch.testing._internal.common_utils import run_tests +import test_compile_cache +import test_pythonkey + + +def make_functionalize_fn(fn): + @functools.wraps(fn) + def _fn(*args, **kwargs): + with patch.object(functorch.compile.config, "use_functionalize", True): + return fn(*args, **kwargs) + + return _fn + + +def make_functionalize_test(cls): + class FunctionalizeTest(cls): + pass + + FunctionalizeTest.__name__ = f"Functionalize{cls.__name__}" + + for name in dir(cls): + if name.startswith("test_"): + fn = getattr(cls, name) + if not callable(fn): + continue + + new_name = f"{name}_functionalize" + fn = make_functionalize_fn(fn) + fn.__name__ = new_name + setattr(FunctionalizeTest, name, None) + setattr(FunctionalizeTest, new_name, fn) + + return FunctionalizeTest + + +FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache) +FunctionalizeTestCompileCacheStaticArgs = make_functionalize_test(test_compile_cache.TestCompileCacheStaticArgs) +FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_pythonkey.TestAOTAutograd) +FunctionalizeTestPythonKeyContiguous = make_functionalize_test(test_pythonkey.TestContiguous) +FunctionalizeTestPythonKeyRandom = make_functionalize_test(test_pythonkey.TestRandom) +FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(test_pythonkey.TestPartitioning) + +if __name__ == "__main__": + run_tests() diff --git a/functorch/test/test_memory_efficient_fusion.py b/functorch/test/test_memory_efficient_fusion.py new file mode 100644 index 0000000000000..b0f18f06b8295 --- /dev/null +++ b/functorch/test/test_memory_efficient_fusion.py @@ -0,0 +1,371 @@ +# Owner(s): ["module: functorch"] + +import torch +import torch.nn as nn +import torch.fx as fx +from functorch import make_fx +from torch.nn import functional as F +from functorch.compile import memory_efficient_fusion +from functorch._src.compile_utils import fx_graph_cse +from torch.testing._internal.common_utils import TestCase, run_tests +import inspect +import random +from typing import Callable +import unittest + +HAS_CUDA = torch.cuda.is_available() + + +def _num_args(fn: Callable): + return len(inspect.signature(fn).parameters) + + +def gelu_bias(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +def swish(x): + return x * torch.sigmoid(x) + + +def mish(x): + return x.mul(torch.tanh(F.softplus(x))) + + +def hard_sigmoid(x): + return (x + 3.0).clamp(min=0.0, max=6.0).div(6.0) + + +def hard_swish(x): + return x * (x + 3.0).clamp(min=0.0, max=6.0).div(6.0) + + +def hard_mish(x): + return 0.5 * x * (x + 2.0).clamp(min=0.0, max=2.0) + + +# todo: convert these into tests +# def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): +# B, C, H, W = x.shape +# x_dtype = x.dtype +# if flatten: +# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues +# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) +# else: +# x = x.reshape(B, groups, C // groups, H, W) +# std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) +# return std.expand(x.shape).reshape(B, C, H, W) + +# class EvoNorm2dS0(nn.Module): +# def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): +# super().__init__() +# self.apply_act = apply_act # apply activation (non-linearity) +# if group_size: +# assert num_features % group_size == 0 +# self.groups = num_features // group_size +# else: +# self.groups = groups +# self.eps = eps +# self.weight = nn.Parameter(torch.ones(num_features)) +# self.bias = nn.Parameter(torch.zeros(num_features)) +# self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None +# self.reset_parameters() + +# def reset_parameters(self): +# nn.init.ones_(self.weight) +# nn.init.zeros_(self.bias) +# if self.v is not None: +# nn.init.ones_(self.v) + +# def forward(self, x): +# x_dtype = x.dtype +# v_shape = (1, -1, 1, 1) +# if self.v is not None: +# v = self.v.view(v_shape).to(dtype=x_dtype) +# x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) +# return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +# device = "cuda" +# dtype = torch.float + +# evo_norm = EvoNorm2dS0(2048) +# evo_norm_inp = [(128, 2048, 8, 8)] + + +def run_and_compare_activation(self, fn, inps): + with torch.jit.fuser("fuser1"): + device = "cuda" + dtype = torch.float + if isinstance(fn, nn.Module): + fn = fn.to(device=device, dtype=dtype) + + ref_args = [torch.randn(shape, device=device, dtype=dtype, requires_grad=True) for shape in inps] + res_args = [i.clone().detach().requires_grad_(True) for i in ref_args] + + ref = fn(*ref_args) + ref.sum().backward() + + mem_optimized_fn = memory_efficient_fusion(fn) + for _ in range(5): + for i in res_args: + i.grad = None + res = mem_optimized_fn(*res_args) + res.sum().backward() + + self.assertEqual(ref, res) + for ref_arg, res_arg in zip(ref_args, res_args): + self.assertEqual(ref_arg.grad, res_arg.grad) + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") +class TestMemoryEfficientOpAuthoring(TestCase): + def test_gelu_bias(self): + run_and_compare_activation(self, gelu_bias, [(1024,), (1024,)]) + + def test_mish(self): + run_and_compare_activation(self, mish, [(1024,)]) + + def test_swish(self): + run_and_compare_activation(self, swish, [(1024,)]) + + def test_hard_sigmoid(self): + run_and_compare_activation(self, hard_sigmoid, [(1024,)]) + + def test_hard_swish(self): + run_and_compare_activation(self, hard_swish, [(1024,)]) + + def test_layer_norm(self): + def layer_norm(x, weight, bias): + dim = -1 + eps = 1e-5 + mean = torch.mean(x, dim, keepdim=True) + centered = x - mean + var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1) + rvar = 1. / torch.sqrt(var + eps) + normed = (x - mean) * rvar + return normed * weight + bias + + bs = 10 + ln_size = 16 + layer_norm_inps = [(bs, ln_size), (ln_size,), (ln_size,)] + run_and_compare_activation(self, layer_norm, layer_norm_inps) + + def test_rmsnorm(self): + class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + bs = 256 + seq = 256 + hidden = 1024 + t5_norm = T5LayerNorm(hidden) + t5_norm_inputs = [(bs, seq, hidden)] + run_and_compare_activation(self, t5_norm, t5_norm_inputs) + + # TODO - Assertion failure + # def test_hard_mish(self): + # for compiler in compilers: + # run_and_compare_activation(hard_mish, 1024) + + +# check if the CSE modified graph of f has delta less nodes, and do not reduce the number of nodes further on a second pass. +# delta is an integer >= -1. If delta = -1, only check if the new graph +# has less or equal number of nodes +def check(f, t, delta, check_val=True, graph_input=False): + if graph_input: + fx_g = f + else: + fx_g = make_fx(f)(t) + new_graph = fx_graph_cse(fx_g.graph) + new_g = fx.GraphModule(fx_g, new_graph) + + # the number of nodes decrease/ or stay the same + old_num_nodes = len(fx_g.graph.nodes) + new_num_nodes = len(new_graph.nodes) + if delta == -1: + assert old_num_nodes >= new_num_nodes, ( + f"number of nodes increased {old_num_nodes}, {new_num_nodes}") + else: + assert old_num_nodes == new_num_nodes + delta, ( + f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}") + + # a second pass should not reduce more nodes + pass_2_graph = fx_graph_cse(new_graph) + pass_2_num_nodes = len(pass_2_graph.nodes) + assert pass_2_num_nodes == new_num_nodes, ( + f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}") + + # check correctness + if check_val: + true_result = fx_g(t) + our_result = new_g(t) + if true_result is None: # both return None + assert our_result is None, f"true result is None, CSE result is {our_result}" + else: # results returned are the same + assert torch.all(true_result == our_result), ( + f"results are different {true_result}, {our_result}") # check results are the same + + +class NoChangeTestCase(TestCase): + + def test_nochange(self): + def f(x): + a = x + 1 + b = x + a + a = x + d = x + a + return b + d + t = torch.randn(2, 2) + check(f, t, 0) + + def test_empty(self): + def f(x): + pass + t = torch.randn(2, 2) + check(f, t, 0) + + def test_rand_like(self): + def f(x): + a = torch.rand_like(x) + b = torch.rand_like(x) + return a + b + t = torch.randn(2, 2) + check(f, t, 0, check_val=False) + + def test_rand_n(self): + def f(x): + a = torch.randn(4) + b = torch.randn(4) + return a + b + t = torch.randn(2, 2) + check(f, t, 0, check_val=False) + + +class ReduceTestCase(TestCase): + + def test_immutable_list_type(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1) + c = x.sum() + d = x.sum() + return a + b + c + d + t = torch.randn(2, 2) + check(f, t, 2) + + def test_immutable_list_multiple_entries(self): + def f(x): + a = x.sum(dim=[0, 1]) + b = x.sum(dim=[0, 1]) + c = x.sum(dim=1) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(f, t, 2) + + def test_simple(self): + def f(x): + a = x.cos() + b = x.cos() + c = a + a + d = b + b + return c + d + t = torch.randn(2, 2) + check(f, t, 2) + + def test_simple_2(self): + def f(x): + a = x.cos().sin() + b = x.cos().sin() + c = a + a + d = b + b + return c + d + t = torch.randn(1) + check(f, t, 3) + + def test_two_args_default(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1, keepdim=False) + c = x.sum(dim=1, keepdim=False) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(f, t, 3) + + def test_two_args(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1, keepdim=True) + c = x.sum(dim=1, keepdim=True) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(f, t, 2) + + def test_simple_multiple_same_ops(self): + def f(x): + a = x.sum() + b = x.sum() + c = x.sum() + d = x.sum() + return a + b + c + d + t = torch.randn(2, 2) + check(f, t, 3) + + def test_nested_immutable_list_type(self): + def f(x): + a = torch.cat((x, x)) + b = torch.cat((x, x)) + return a + b + t = torch.randn(2, 2) + check(f, t, 1) + + def test_kwarg(self): + def f(x): + a = torch.ones_like(x) + b = torch.ones_like(x) + return a + b + t = torch.randn(2, 2) + check(f, t, 1) + + +class RandomOpTestCase(TestCase): + def test_random(self): + def f(x): + vals = [x] + ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu] + for _ in range(100): + new_val = random.choice(ops)(random.choice(vals)) + vals.append(new_val) + return vals[-1] + + fx_g = fx.symbolic_trace(f) + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + t = torch.randn(2, 2) + + for _ in range(30): + check(fx_g, t, -1, graph_input=True) + + + +if __name__ == "__main__": + run_tests() diff --git a/functorch/test/test_minifier.py b/functorch/test/test_minifier.py new file mode 100644 index 0000000000000..4f026c185c50a --- /dev/null +++ b/functorch/test/test_minifier.py @@ -0,0 +1,53 @@ +# Owner(s): ["module: functorch"] + +import torch +from functorch.compile import minifier +from functorch import make_fx +from torch.testing._internal.common_utils import TestCase, run_tests + + +class TestMinifier(TestCase): + # https://github.com/pytorch/functorch/issues/913 + def test_has_mul_minifier(self): + def failing_f(x, y): + y = y / 3 + x = x + 3 + x = x * y + return x + y + inps = [torch.randn(3), torch.randn(3)] + failing_f = make_fx(failing_f)(*inps) + + def pass_checker(fx_g, inps): + return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes])) + + min_f, inps = minifier(failing_f, inps, pass_checker) + assert len(min_f.graph.nodes) == 4 + assert len(inps) == 2 + + def test_has_add_mul(self): + def failing_f(x): + x = x * 3 + x = x + 5 + x = x.cos() + zero = x - x + result = zero / zero + result = result + 3 + return (result * 2,) + + inps = [torch.randn(3)] + failing_f = make_fx(failing_f)(*inps) + + def pass_checker(fx_g, inps): + # Basically, make sure none of the inputs are nans + for i in inps: + if torch.isnan(i).any(): + return False + return torch.isnan(fx_g(*inps)[0]).any() + + min_f, inps = minifier(failing_f, inps, pass_checker) + assert len(min_f.graph.nodes) == 3 + assert len(inps) == 1 + + +if __name__ == "__main__": + run_tests() diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py new file mode 100644 index 0000000000000..6fcc0e19e7b8b --- /dev/null +++ b/functorch/test/test_ops.py @@ -0,0 +1,1355 @@ +# Owner(s): ["module: functorch"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools + +from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors +import torch +from torch import Tensor +import functools +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_device_type import ops +from torch.testing._internal.common_device_type import \ + toleranceOverride, tol +from functorch_lagging_op_db import functorch_lagging_op_db +from functorch_additional_op_db import additional_op_db +from common_utils import ( + get_fallback_and_vmap_exhaustive, + get_exhaustive_batched_inputs, + get_exhaustive_batched_inputs_batch_norm_is_training, + xfail, + skip, + skipOps, + tol1, + # tol2, + opsToleranceOverride, + check_vmap_fallback, + is_batch_norm_training, +) +from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map +from functorch import grad, vjp, vmap, jacrev, jacfwd +import torch.autograd.forward_ad as fwAD +from functorch._src.eager_transforms import _as_tuple, jvp +aten = torch.ops.aten + + +# Version of autograd.grad with some differences: +# - pytree inputs is allowed (but leaves of the pytree have to all +# be tensors) +# - if an input is not used as part of derivatives, we will return a +# zero-filled tensor for the result +def _autograd_grad( + outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True +): + inputs, inputs_spec = tree_flatten(inputs) + diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + diff_grad_outputs = [ + (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad + ] + if len(diff_grad_outputs) == 0: + diff_outputs, grad_outputs = (), () + else: + diff_outputs, grad_outputs = zip(*diff_grad_outputs) + grad_inputs = torch.autograd.grad( + diff_outputs, + diff_inputs, + grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True, + ) + result = [] + grad_inputs_iter = iter(grad_inputs) + for inp in inputs: + if inp.requires_grad: + grad_input = next(grad_inputs_iter) + if grad_input is None: + result.append(torch.zeros_like(inp)) + else: + result.append(grad_input) + else: + result.append(torch.zeros_like(inp)) + return tree_unflatten(result, inputs_spec) + + +def diff_arg(arg, requires_grad=True): + def is_differentiable_arg(arg): + if requires_grad: + return arg.requires_grad + else: + return arg.is_floating_point() or arg.is_complex() + if is_iterable_of_tensors(arg): + if all([is_differentiable_arg(a) for a in arg]): + return True + if all([not is_differentiable_arg(a) for a in arg]): + return False + raise RuntimeError("NYI: The test runner can't handle this") + return isinstance(arg, Tensor) and is_differentiable_arg(arg) + + +# Given f, returns an f' such that: +# - f' takes only positional arguments +# - All arguments to f' are floating-point Tensors +# - All outputs of f' are floating-point Tensors +def normalize_op_input_output2(f, args, kwargs, output_process_fn_grad=None, requires_grad=True): + flat_args, args_spec = tree_flatten(args) + diff_argnums = tuple(i for i, arg in enumerate(flat_args) if diff_arg(arg, requires_grad=requires_grad)) + assert len(diff_argnums) > 0 + primals = tuple(flat_args[i] for i in diff_argnums) + + @functools.wraps(f) + def wrapped(*primals): + _args = list(flat_args) + for num, arg in zip(diff_argnums, primals): + _args[num] = arg + _args = tree_unflatten(_args, args_spec) + result = f(*_args, **kwargs) + if output_process_fn_grad is not None: + result = output_process_fn_grad(result) + if isinstance(result, tuple): + # TODO: Remove the following hack for namedtuples + result = tuple(result) + result = tuple(r for r in result if torch.is_floating_point(r)) + assert len(result) > 0 + return result + return wrapped, primals + + +# TODO: consolidate with normalize_op_input_output2 +def normalize_op_input_output3(f, args, kwargs, sample_args, output_process_fn_grad=None): + flat_args, args_spec = tree_flatten(args) + flat_sample_args, _ = tree_flatten(sample_args) + diff_argnums = tuple(i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args)) + if diff_arg(sample, requires_grad=True)) + assert len(diff_argnums) > 0 + primals = tuple(flat_args[i] for i in diff_argnums) + + @functools.wraps(f) + def wrapped(*primals): + _args = list(flat_args) + for num, arg in zip(diff_argnums, primals): + _args[num] = arg + _args = tree_unflatten(_args, args_spec) + result = f(*_args, **kwargs) + if output_process_fn_grad is not None: + result = output_process_fn_grad(result) + if isinstance(result, tuple): + # TODO: Remove the following hack for namedtuples + result = tuple(result) + result = tuple(r for r in result if torch.is_floating_point(r)) + assert len(result) > 0 + return result + return wrapped, primals + + +def normalize_op_input_output(f, sample, requires_grad=True): + args = tuple([sample.input] + list(sample.args)) + return normalize_op_input_output2( + f, args, sample.kwargs, sample.output_process_fn_grad, requires_grad=requires_grad + ) + + +def ref_vjp(f, *primals): + result = f(*primals) + + def wrapped(cotangents): + return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents)) + + return result, wrapped + + +def simulate_jvp(f, primals, tangents): + primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents) + return primals_out, tangents_out + + +def ref_jvp(f, primals, tangents): + with fwAD.dual_level(): + duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents)) + result_duals = f(*duals) + result_duals, spec = tree_flatten(result_duals) + primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals)) + return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec) + + +def get_sample_cotangents(f, sample): + fn, primals = normalize_op_input_output(f, sample) + output = fn(*primals) + return tree_map(torch.randn_like, output) + + +# returns a new function g(*args, *cotangents) +# that computes vjps and (*args, cotangents) +def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents): + args = tuple([sample.input] + list(sample.args)) + kwargs = sample.kwargs + flat_args, args_spec = tree_flatten(args) + flat_cotangents, cotangents_spec = tree_flatten(cotangents) + + @functools.wraps(f) + def wrapped(*args): + assert len(args) == len(flat_args) + len(flat_cotangents) + actual_args = args[:len(flat_args)] + cotangents = args[len(flat_args):] + actual_args = tree_unflatten(actual_args, args_spec) + cotangents = tree_unflatten(cotangents, cotangents_spec) + + fn, primals = normalize_op_input_output3(f, actual_args, kwargs, + flat_args, + sample.output_process_fn_grad) + _, vjp_fn = vjp(fn, *primals) + return vjp_fn(cotangents) + + return wrapped, tuple(flat_args + flat_cotangents) + + +# Returns a new function g(*args, *cotangents) that computes vjps and +# sample (*args, *cotangents) +def get_vjpfull_variant(f, sample): + fn, primals = normalize_op_input_output(f, sample) + result = fn(*primals) + cotangents = _as_tuple( + tree_map(lambda x: torch.randn_like(x, requires_grad=True), result)) + num_primals = len(primals) + args = (*primals, *cotangents) + + @functools.wraps(f) + def wrapped(*args): + primals = args[:num_primals] + cotangents = args[num_primals:] + result, vjp_fn = vjp(fn, *primals) + if isinstance(result, torch.Tensor): + assert len(cotangents) == 1 + cotangents = cotangents[0] + return vjp_fn(cotangents) + + return wrapped, args + + +def get_jvp_variant(f, sample): + # We want this higher-order variant of jvp, so that it can + # be used to wrap vmap + fn, primals = normalize_op_input_output(f, sample, requires_grad=False) + tangents = _as_tuple( + tree_map(lambda x: torch.randn_like(x), primals)) + + @functools.wraps(f) + def wrapped(*args): + tangents = args + primals_out, tangents_out = jvp(fn, primals, tangents) + + if isinstance(primals_out, torch.Tensor): + return (primals_out, tangents_out) + else: + flat_primals_out, _ = tree_flatten(primals_out) + flat_tangents_out, _ = tree_flatten(tangents_out) + return tuple(flat_primals_out + flat_tangents_out) + + return wrapped, tangents + + +def get_jvp_variant_primals_tangents(f, sample): + # We want this higher-order variant of jvp, so that it can + # be used to wrap vmap + fn, primals = normalize_op_input_output(f, sample, requires_grad=False) + tangents = _as_tuple( + tree_map(lambda x: torch.randn_like(x), primals)) + + @functools.wraps(f) + def wrapped(*args): + primals_in = args[:len(primals)] + tangents_in = args[len(primals):] + primals_out, tangents_out = jvp(fn, primals_in, tangents_in) + + if isinstance(primals_out, torch.Tensor): + return (primals_out, tangents_out) + else: + flat_primals_out, _ = tree_flatten(primals_out) + flat_tangents_out, _ = tree_flatten(tangents_out) + return tuple(flat_primals_out + flat_tangents_out) + + return wrapped, primals + tangents + + +def is_inplace(op, variant): + if hasattr(variant, "__wrapped__"): + return variant.__wrapped__ is op.get_inplace() + return variant is op.get_inplace() + + +vjp_fail = { + xfail('tensor_split'), + xfail('to_sparse'), + xfail('nn.functional.ctc_loss'), + skip('pca_lowrank', ''), # fails on cuda, runs okay on cpu + skip('svd_lowrank', ''), # fails on cuda, runs okay on cpu +} + + +class TestOperators(TestCase): + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_grad', vjp_fail.union({ + xfail('linalg.eig'), # diagonal_scatter does not support complex + })) + @opsToleranceOverride('TestOperators', 'test_grad', ( + tol1('nn.functional.binary_cross_entropy_with_logits', + {torch.float32: tol(atol=1e-04, rtol=1e-04)}), + )) + def test_grad(self, device, dtype, op): + if op.name in vjp_fail: + self.skipTest("Skipped; Expected failures") + return + + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + args = [sample.input] + list(sample.args) + kwargs = sample.kwargs + + diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg)) + assert len(diff_argnums) > 0 + diff_args = tuple(args[i] for i in diff_argnums) + + def wrapped_fn(*args, **kwargs): + result = op(*args, **kwargs) + if sample.output_process_fn_grad is not None: + result = sample.output_process_fn_grad(result) + + # Reduce into single value for grad + if isinstance(result, torch.Tensor): + return result.sum() + result = sum([res.sum() for res in result]) + return result + + result = grad(wrapped_fn, diff_argnums)(*args, **kwargs) + expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args) + + self.assertEqual(result, expected) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_jvp', set({ + skip('nn.functional.max_pool1d'), # fails on cpu, runs okay on cuda + skip('pca_lowrank', ''), # fails on cuda, runs okay on cpu + skip('svd_lowrank', ''), # fails on cuda, runs okay on cpu + + # ============================================= + # NB: The above failures also fail using PyTorch core's + # forward-mode AD and vmap. + # The failures below are functorch-specific issues + # ============================================= + + # Composite ops that do bad things. Need to be fixed in PyTorch core. + # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage + xfail('tensor_split'), + + # BUG: runs and produces numerical differences + skip('nn.functional.max_unpool1d'), # fails everywhere except on mac + skip('nn.functional.max_unpool2d'), # fails everywhere except on windows + skip('nn.functional.max_unpool3d'), # fails everywhere except on mac + })) + @opsToleranceOverride('TestOperators', 'test_jvp', ( + tol1('nn.functional.conv_transpose3d', + {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type='cuda'), + tol1('nn.functional.binary_cross_entropy_with_logits', + {torch.float32: tol(atol=4e-04, rtol=4e-04)}), + )) + def test_jvp(self, device, dtype, op): + # TODO: when we change supports_autograd to supports_backward_ad, also change in this file + VJP_DECOMP = { + 'nn.functional.logsigmoid', + } + if op.name in VJP_DECOMP: + ref_jvp_local = simulate_jvp + else: + ref_jvp_local = ref_jvp + + if not op.supports_forward_ad and op.name not in VJP_DECOMP: + self.skipTest("Skipped! Forward AD not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + # NB: we used requires_grad=True to determine where the primals are, + # but don't need that information otherwise + fn, primals = normalize_op_input_output(op, sample, requires_grad=True) + primals = tree_map(lambda x: x.detach(), primals) + tangents = tree_map(lambda x: torch.randn_like(x), primals) + primal_outs, tangent_outs = jvp(fn, primals, tangents) + expected_primal_outs, expected_tangent_outs = ref_jvp_local(fn, primals, tangents) + self.assertEqual(primal_outs, expected_primal_outs) + self.assertEqual(tangent_outs, expected_tangent_outs) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_vjp', vjp_fail.union({ + xfail('pca_lowrank', ''), + xfail('svd_lowrank', ''), + })) + @opsToleranceOverride('TestOperators', 'test_vjp', ( + tol1('nn.functional.conv_transpose3d', + {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'), + tol1('nn.functional.binary_cross_entropy_with_logits', + {torch.float32: tol(atol=1e-04, rtol=1e-04)}), + )) + def test_vjp(self, device, dtype, op): + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + def _test(_op): + for sample in samples: + fn, primals = normalize_op_input_output(_op, sample) + result = fn(*primals) + cotangents = tree_map(lambda x: torch.randn_like(x), result) + + out, vjp_fn = vjp(fn, *primals) + self.assertEqual(out, result) + result_vjps = vjp_fn(cotangents) + + _, vjp_fn = ref_vjp(fn, *primals) + expected_vjps = vjp_fn(cotangents) + + self.assertEqual(result_vjps, expected_vjps) + + _test(op) + for a_op in op.aliases: + _test(a_op) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_vjpvjp', vjp_fail.union({ + skip('nn.functional.max_unpool1d'), # Flaky + skip('nn.functional.max_unpool2d'), # Flaky + })) + @opsToleranceOverride('TestOperators', 'test_vjpvjp', ( + tol1('nn.functional.conv_transpose3d', + {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'), + )) + def test_vjpvjp(self, device, dtype, op): + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + if not op.supports_gradgrad: + self.skipTest("Skipped! Operation does not support gradgrad") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + fn, args = get_vjpfull_variant(op, sample) + result = fn(*args) + cotangents = tree_map(lambda x: torch.randn_like(x), result) + + # Compute vjp of vjp + _, vjp_fn = vjp(fn, *args) + result_vjps = vjp_fn(cotangents) + + # Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp, + # but since we're confident that vjp works by itself, this is + # an equivalent way to test that. + _, vjp_fn = ref_vjp(fn, *args) + expected_vjps = vjp_fn(cotangents) + + self.assertEqual(result_vjps, expected_vjps) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + def test_vmapvjpvjp(self, device, dtype, op): + self.skipTest("Skipped; these tests take too long") + op_skip = set({ + }) + op_skip = op_skip.union(vjp_fail) + if op.name in op_skip: + self.skipTest("Skipped; Expected failures") + return + + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + if not op.supports_gradgrad: + self.skipTest("Skipped! Operation does not support gradgrad") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + fn, args = get_vjpfull_variant(op, sample) + result = fn(*args) + cotangents = tree_map(lambda x: torch.randn_like(x), result) + cotangents, _ = tree_flatten(cotangents) + num_args = len(args) + + args_and_cotangents = tuple(args) + tuple(cotangents) + + def vjp_of_vjp(*args_and_cotangents): + args = args_and_cotangents[:num_args] + cotangents = args_and_cotangents[num_args:] + result, vjp_fn = vjp(fn, *args) + result_vjps = vjp_fn(cotangents) + result, _ = tree_flatten(result) + result_vjps, _ = tree_flatten(result_vjps) + return (*result, *result_vjps) + + is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) + generator = get_fallback_and_vmap_exhaustive( + vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + self.assertEqual(loop_out, batched_out) + + vmapvjp_fail = vjp_fail.union({ + # The following are not bugs and are expected behavior + xfail('masked_select'), # Not possible due to dynamic shapes + skip('bernoulli'), # randomness + skip('normal', ''), # randomness + skip('normal', 'number_mean'), # randomness + skip('nn.functional.rrelu'), # randomness + skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness + skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + skip('nn.functional.dropout'), # randomness + skip('nn.functional.dropout2d'), # randomness + xfail('as_strided'), # as_strided is too wild for us to support, wontfix + xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset + xfail('masked_scatter'), # dynamic + xfail('nn.functional.fractional_max_pool2d'), # random + xfail('nn.functional.fractional_max_pool3d'), # random + xfail('take'), # dynamic + + # All of the following are bugs and need to be fixed + skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule + xfail('__getitem__', ''), # dynamic error + xfail('_masked.prod'), # calls aten::item + xfail('eig'), # calls aten::item + xfail('linalg.eig'), # Uses aten::allclose + xfail('linalg.householder_product'), # needs select_scatter + xfail('nanquantile'), # checks q via a .item() call + xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0 + xfail('prod'), # calls nonzero + xfail('quantile'), # checks q via a .item() call + xfail('stft'), + xfail('view_as_complex'), + + # required rank 4 tensor to use channels_last format + xfail('bfloat16'), + xfail('double'), + xfail('float'), + xfail('half'), + + xfail('scatter_reduce', 'prod'), # item call + + # NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format + xfail('nn.functional.max_unpool2d'), + xfail('nn.functional.max_unpool2d', 'grad'), + }) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + @opsToleranceOverride('TestOperators', 'test_vmapvjp', ( + tol1('linalg.svd', + {torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"), + tol1('svd', + {torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"), + )) + @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail) + def test_vmapvjp(self, device, dtype, op): + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + cotangents = get_sample_cotangents(op, sample) + fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) + is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) + generator = get_fallback_and_vmap_exhaustive( + fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + self.assertEqual(loop_out, batched_out) + + vmapjvpall_fail = { + # The following are expected (not a bug) + skip('bernoulli', ''), # randomness + skip('nn.functional.dropout'), # randomness + skip('nn.functional.rrelu'), # randomness + skip('nn.functional.dropout2d', ''), + skip('nn.functional.feature_alpha_dropout', 'without_train'), + skip('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('nn.functional.fractional_max_pool2d'), # Cannot access data pointer of Tensor that doesn't have storage + xfail('nn.functional.fractional_max_pool3d'), # Cannot access data pointer of Tensor that doesn't have storage + + # The following are bugs that we should fix + skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda + xfail('_masked.mean'), + xfail('_masked.prod'), + + # Not actually a problem: embedding with max_norm mutates the weight + # and causes different runs to produce different results. + # skip because this is flaky depending on what the max_norm is! + skip('nn.functional.embedding', ''), + xfail('nn.functional.soft_margin_loss', ''), + xfail('linalg.householder_product'), + xfail('tensor_split'), + xfail('quantile'), + xfail('as_strided'), + xfail('nn.functional.gaussian_nll_loss'), + xfail('scatter'), + xfail('nanquantile'), + xfail('view_as_complex'), + xfail('prod'), + + skip('pca_lowrank', ''), + skip('svd_lowrank', ''), + + xfail('stft'), # transpose_ fallback + + xfail('double'), # required rank 4 tensor to use channels_last format + + skip('nn.functional.max_unpool1d'), # Flaky, seems to sometimes his max_unpool2d + skip('nn.functional.max_unpool2d'), # fails everywhere except on mac + skip('nn.functional.max_unpool3d'), # fails everywhere except on mac + + xfail('nn.functional.prelu'), # Call Tensor.as_strided + + # erroring because running_mean and running_var aren't differentiable + xfail('nn.functional.batch_norm'), + xfail('nn.functional.batch_norm', 'without_cudnn'), + } + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @opsToleranceOverride('TestOperators', 'test_vmapjvpall', ( + tol1('nn.functional.conv_transpose3d', + {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'), + )) + @skipOps('TestOperators', 'test_vmapjvpall', vmapjvpall_fail) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp + # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact + # because that coresponds to "batched forward-mode AD" testing in PyTorch core + def test_vmapjvpall(self, device, dtype, op): + if is_inplace(op, op.get_op()): + # TODO: test in-place + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=False) + + if not op.supports_forward_ad: + self.skipTest("Skipped! Forward AD not supported.") + return + + for sample in samples: + arg_values = [sample.input] + list(sample.args) + kwarg_values = sample.kwargs + args = tuple(arg_values) + tuple(kwarg_values) + fn, args = get_jvp_variant_primals_tangents(op, sample) + is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) + generator = get_fallback_and_vmap_exhaustive( + fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + self.assertEqual(loop_out, batched_out) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({ + xfail('linalg.solve_triangular'), + xfail('nn.functional.huber_loss'), + xfail('lu'), + skip('linalg.det', 'singular'), # https://github.com/pytorch/functorch/issues/961 + xfail('cumprod'), + xfail('lu_solve'), + xfail('linalg.det'), + xfail('linalg.lstsq', 'grad_oriented'), + xfail('linalg.cholesky'), + xfail('linalg.qr'), + xfail('cross'), + xfail('qr'), + xfail('linalg.pinv'), + xfail('masked_fill'), + xfail('copysign'), + xfail('linalg.solve'), + xfail('linalg.eig'), + xfail('complex'), + xfail('linalg.pinv', 'hermitian'), + xfail('matrix_exp'), + xfail('pinverse'), + skip('_masked.mean'), # ??? + xfail('linalg.cholesky_ex'), + xfail('masked_scatter'), + xfail('index_fill'), + xfail('put'), + xfail('take'), + xfail('linalg.eigvals'), + xfail('linalg.qr'), + xfail('linalg.tensorsolve'), + xfail('nn.functional.max_pool3d'), + xfail('vdot'), + xfail('linalg.cross'), + xfail('nanmean'), + xfail('nansum'), + xfail('nn.functional.feature_alpha_dropout', 'without_train'), + xfail('linalg.lu_factor', ''), + xfail('nn.functional.dropout2d', ''), + xfail('pca_lowrank', ''), + xfail('svd_lowrank', ''), + xfail('linalg.lu_factor_ex', ''), + xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('special.log_ndtr', ''), + xfail('fft.ihfft2'), # conj_physical fallback + xfail('fft.ihfftn'), # conj_physical fallback + xfail('istft'), # col2im fallback + xfail('polar'), # complex fallback + xfail('nn.functional.max_unpool3d', 'grad'), + xfail('nn.functional.smooth_l1_loss', ''), + xfail('nn.functional.max_unpool2d', 'grad'), + xfail('nn.functional.soft_margin_loss', ''), + xfail('nn.functional.max_unpool1d', 'grad'), + xfail('nn.functional.embedding', ''), + xfail('lu_unpack'), + xfail('nn.functional.glu'), + xfail('nn.functional.bilinear'), # trilinear doesn't have batching rule + xfail('linalg.eigh'), # _linalg_eigh doesn't have batching rule + xfail('linalg.eigvalsh'), # _linalg_eigh doesn't have batching rule + xfail('logdet'), # _linalg_slogdet doesn't have batching rule + xfail('linalg.slogdet'), # _linalg_slogdet doesn't have batching rule + })) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + def test_vmapjvpall_has_batch_rule(self, device, dtype, op): + if is_inplace(op, op.get_op()): + # TODO: test in-place + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=False) + + if not op.supports_forward_ad: + self.skipTest("Skipped! Forward AD not supported.") + return + + def test(): + for sample in samples: + arg_values = [sample.input] + list(sample.args) + kwarg_values = sample.kwargs + args = tuple(arg_values) + tuple(kwarg_values) + fn, args = get_jvp_variant_primals_tangents(op, sample) + is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) + for loop_out, batched_out in get_fallback_and_vmap_exhaustive( + fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False): + pass + check_vmap_fallback(self, test, op, dry_run=False) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({ + xfail('view_as_complex'), + xfail('cholesky'), + xfail('complex'), + xfail('copysign'), + xfail('cummax'), + xfail('cummin'), + xfail('cumprod'), + xfail('eig'), + xfail('nansum'), + xfail('nanmean'), + xfail('special.log_ndtr'), + xfail('index_copy'), + xfail('index_fill'), + xfail('linalg.cholesky'), + xfail('linalg.cholesky_ex'), + xfail('linalg.det'), + xfail('linalg.eig'), + xfail('linalg.eigh'), + xfail('linalg.eigvals'), + xfail('linalg.householder_product'), + xfail('linalg.lstsq', ''), + xfail('linalg.lstsq', 'grad_oriented'), + xfail('linalg.pinv'), + xfail('linalg.qr'), + xfail('linalg.pinv', 'hermitian'), + xfail('linalg.slogdet'), + xfail('linalg.solve'), + xfail('logdet'), + xfail('lu'), + xfail('lu_solve'), + xfail('lu_unpack'), + xfail('masked_fill'), + xfail('masked_scatter'), + xfail('masked_select'), + xfail('matrix_exp'), + xfail('nanquantile'), + xfail('pinverse'), + xfail('prod'), + xfail('put'), + skip('linalg.det'), # https://github.com/pytorch/functorch/issues/961 + xfail('quantile'), + xfail('renorm'), + xfail('take'), + xfail('tensor_split'), + xfail('to_sparse'), + xfail('unfold'), + xfail('vdot'), + xfail('nn.functional.dropout'), + xfail('_masked.prod'), + xfail('fft.ihfft2'), + xfail('fft.ihfftn'), + xfail('cross'), + xfail('linalg.cross'), + xfail('nn.functional.gaussian_nll_loss'), + xfail('nn.functional.huber_loss'), + xfail('nn.functional.bilinear'), + xfail('nn.functional.fractional_max_pool3d'), + xfail('as_strided'), + xfail('linalg.solve_triangular'), + xfail('stft'), + xfail('nn.functional.rrelu'), + xfail('nn.functional.embedding_bag'), + xfail('nn.functional.max_pool3d'), + xfail('istft'), + xfail('nn.functional.fractional_max_pool2d'), + xfail('linalg.tensorsolve'), + xfail('linalg.lu_factor', ''), + xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('pca_lowrank', ''), + xfail('nn.functional.dropout2d', ''), + xfail('nn.functional.feature_alpha_dropout', 'without_train'), + xfail('svd_lowrank', ''), + xfail('linalg.lu_factor_ex', ''), + + xfail('nn.functional.max_unpool2d', ''), + xfail('nn.functional.multi_margin_loss', ''), + xfail('nn.functional.multilabel_margin_loss', ''), + xfail('nn.functional.pdist', ''), + xfail('nn.functional.smooth_l1_loss', ''), + xfail('scatter_reduce', 'prod'), + xfail('scatter_reduce', 'amax'), + xfail('nn.functional.max_unpool1d', ''), + xfail('nn.functional.max_unpool3d', ''), + xfail('scatter_reduce', 'sum'), + xfail('scatter_reduce', 'mean'), + xfail('nn.functional.max_unpool3d', 'grad'), + xfail('nn.functional.soft_margin_loss', ''), + xfail('scatter_reduce', 'amin'), + xfail('nn.functional.max_unpool1d', 'grad'), + xfail('nn.functional.max_unpool2d', 'grad'), + xfail('qr'), + xfail('linalg.eigvalsh'), # _linalg_eigh doesn't have batching rule + })) + def test_vmapvjp_has_batch_rule(self, device, dtype, op): + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + def test(): + for sample in samples: + cotangents = get_sample_cotangents(op, sample) + fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents) + is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) + for loop_out, batched_out in get_fallback_and_vmap_exhaustive( + fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False): + pass + for a_op in op.aliases: + fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents) + for loop_out, batched_out in get_fallback_and_vmap_exhaustive( + fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False): + pass + + check_vmap_fallback(self, test, op, dry_run=False) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_vjpvmap', vjp_fail.union({ + skip('bernoulli', ''), # vjpvmap testing can't handle randomness + skip('normal', ''), # vjpvmap testing can't handle randomness + skip('normal', 'number_mean'), # vjpvmap testing can't handle randomness + skip('nn.functional.rrelu'), # randomness + skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness + skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + + # fallback path doesn't work + # All of the following are bugs and need to be fixed + xfail('__getitem__', ''), + xfail('index_put', ''), + xfail('matrix_exp'), + xfail('view_as_complex'), + xfail('nn.functional.gaussian_nll_loss'), + xfail('masked_select'), + skip('nn.functional.fractional_max_pool3d'), # generator works on cpu, fails on cuda + xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617 + xfail('as_strided'), + skip('nn.functional.fractional_max_pool2d'), # generator works on cpu, fails on cuda + xfail('column_stack', ''), + xfail('nn.functional.dropout2d', ''), + xfail('svd_lowrank', ''), + xfail('pca_lowrank', ''), + xfail('clamp'), + # something weird happening with channels_last + xfail('bfloat16'), + xfail('double'), + xfail('float'), + xfail('half'), + })) + def test_vjpvmap(self, device, dtype, op): + # NB: there is no vjpvmap_has_batch_rule test because that is almost + # certainly redundant with the vmap_has_batch_rule test in test_vmap.py + + # one-off skip + if op.name == 'nn.functional.dropout': + self.skipTest("Skipped!") + + if not op.supports_autograd: + # If the op doesn't support autograd, vmap(op) won't either + self.skipTest("Skipped! Autograd not supported.") + return + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm") # instance norm calls batch norm + is_batch_norm = op.name in batch_norm_fns + + for sample in samples: + args = [sample.input] + list(sample.args) + kwargs = sample.kwargs + if is_batch_norm and is_batch_norm_training(op.name, kwargs): + generator = get_exhaustive_batched_inputs_batch_norm_is_training(args, kwargs) + else: + generator = get_exhaustive_batched_inputs(args, kwargs) + + for batched_args, in_dims, kwargs in generator: + vmapped_op = vmap(op, in_dims) + fn, primals = normalize_op_input_output2(vmapped_op, batched_args, kwargs, + sample.output_process_fn_grad) + result = fn(*primals) + cotangents = tree_map(lambda x: torch.randn_like(x), result) + + _, vjp_fn = vjp(fn, *primals) + result_vjps = vjp_fn(cotangents) + + _, vjp_fn = ref_vjp(fn, *primals) + expected_vjps = vjp_fn(cotangents) + + self.assertEqual(result_vjps, expected_vjps) + + def _compare_jacobians_of_vjp(self, fn, cotangents_and_primals, argnums=None, atol_rtol=None): + if argnums is None: + argnums = tuple(range(len(cotangents_and_primals))) + + def get_vjp(cotangents, *primals): + _, vjp_fn = vjp(fn, *primals) + return vjp_fn(cotangents) + + jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals) + jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals) + + # For dtype changing operations, the jacobians have different dtype. + jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp) + jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp) + + if atol_rtol is not None: + (atol, rtol) = atol_rtol + self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol) + else: + self.assertEqual(jacobian_jvp, jacobian_vjp) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({ + # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor, + # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3]. + xfail('normal', ''), + xfail('_masked.log_softmax', ''), + xfail('_masked.softmax', ''), + xfail('_masked.softmin', ''), + xfail('cdist', ''), + xfail('cholesky', ''), + xfail('eig', ''), + xfail('logcumsumexp', ''), + xfail('nn.functional.embedding_bag', ''), + xfail('nn.functional.grid_sample', ''), + xfail('nn.functional.hardsigmoid', ''), + xfail('nn.functional.huber_loss', ''), + xfail('nn.functional.instance_norm', ''), + xfail('nn.functional.logsigmoid', ''), + xfail('nn.functional.softmin', ''), + xfail('nn.functional.softmin', 'with_dtype'), + xfail('renorm', ''), + xfail('symeig', ''), + xfail('pca_lowrank', ''), + xfail('svd_lowrank', ''), + xfail('nn.functional.multilabel_margin_loss', ''), + xfail('nn.functional.multilabel_soft_margin_loss', ''), + xfail('scatter_reduce', 'amax'), + xfail('scatter_reduce', 'amin'), + xfail('nn.functional.soft_margin_loss', ''), + xfail('nn.functional.pdist', ''), + xfail('scatter_reduce', 'sum'), + xfail('nn.functional.multi_margin_loss', ''), + xfail('scatter_reduce', 'mean'), + xfail('scatter_reduce', 'prod'), + skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why + })) + def test_jvpvjp(self, device, dtype, op): + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + samples = op.sample_inputs(device, dtype, requires_grad=True) + + # TODO: test in-place + if is_inplace(op, op.get_op()): + self.skipTest("Skipped! NYI: inplace-testing not supported.") + return + + for sample in samples: + fn, primals = normalize_op_input_output(op, sample) + result = fn(*primals) + cotangents = tree_map(lambda x: torch.randn_like(x), result) + + primals_tangents = tree_map(lambda x: torch.randn_like(x), primals) + cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents) + + if isinstance(primals[0], torch.Tensor) and primals[0].numel() == 0: + # typically the first primal arg is the input. If the input has no elements, we will typically run + # into an issue of "Expected Tensor but got None" + continue + + def push_vjp(primals, cotangents): + _, vjp_fn = vjp(fn, *primals) + return vjp_fn(cotangents) + + result = jvp(push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents)) + self.assertEqual(len(result), 2) + + def tree_map2(fn, first, second): + flat_first, spec_first = tree_flatten(first) + flat_second, spec_second = tree_flatten(second) + assert spec_first == spec_second + flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)] + return tree_unflatten(flat_result, spec_first) + + def reference(primals, cotangents, primals_tangents, cotangents_tangents): + with fwAD.dual_level(): + primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents) + _, vjp_fn = ref_vjp(fn, *primal_duals) + + cotangent_duals = tree_map2(fwAD.make_dual, cotangents, cotangents_tangents) + result = vjp_fn(cotangent_duals) + + flat_result, spec = tree_flatten(result) + primals_out, tangents_out = zip(*[fwAD.unpack_dual(r) for r in flat_result]) + tangents_out = [t if t is not None else torch.zeros_like(p) + for p, t in zip(primals_out, tangents_out)] + expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)) + return expected + + # HACK: obviously pytorch should also have the same coverage + # For things that do have the same coverage, we test that jvp x vjp + # are the same between PyTorch and functorch. For things that don't, + # we check that jacfwd(vjp) and jacrev(vjp) are the same. This results + # in slower tests. + FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = { + 'nn.functional.nll_loss', + 'softmax', + 'log_softmax', + 'nn.functional.cross_entropy', + 'nn.functional.layer_norm', + 'nn.functional.batch_norm', + } + if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH: + self.assertFalse(op.supports_fwgrad_bwgrad, + f"{op.name} now supports forward over reverse without a decomposition. " + + "Please remove the decomposition version") + + def is_differentiable(t): + return isinstance(t, torch.Tensor) and t.dtype == torch.float32 + args = (cotangents, *primals) + if op.name == 'nn.functional.binary_cross_entropy': + argnums = (0, 1) # targets is float32 but isn't differentiable + atol_rtol = 1.5e-4, 1.3e-06 + else: + argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i])) + atol_rtol = None + self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol) + else: + expected = reference(primals, cotangents, primals_tangents, cotangents_tangents) + self.assertEqual(result, expected) + + def _make_extremal_inputs(self, shape, device): + if shape is None: + return (None,) + return ( + torch.full(shape, -1000., device=device), + torch.zeros(shape, device=device), + torch.full(shape, 1000., device=device), + ) + + def _arg_and_kwarg_options(self, args_options, kwargs_options): + return itertools.product(*args_options, kwargs_options) + + def test_extremal_numerics_nll_loss(self, device): + N, C = 3, 4 + d1, d2, d3 = 5, 6, 7 + shapes = ( + ((N, C), (N,), (C,)), + ((N, C), (N,), None), + ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), + ((N, C, d1, d2, d3), (N, d1, d2, d3), None), + ) + kwargs_options = ({'ignore_index': 0, 'reduction': 'mean'}, {'reduction': 'sum'}, {'reduction': 'none'}, {}) + for input_shape, target_shape, weight_shape in shapes: + input_options = self._make_extremal_inputs(input_shape, device) + for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options): + if weight_shape is None: + weight = None + else: + weight = torch.randn(weight_shape, device=device) + target = torch.randint(0, C, target_shape, device=device) + target[0] = 1 # since we're ignoring index 0, at least one element must be non-zero + + fn = functools.partial(torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs) + result = fn(input) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(fn, (cotangents, input)) + + def test_extremal_numerics_l1_loss(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {}) + for shape in shapes: + input_options = self._make_extremal_inputs(shape, device) + target_options = self._make_extremal_inputs(shape, device) + for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options): + result = torch.nn.functional.l1_loss(input, target) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target)) + + def test_extremal_numerics_mse_loss(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {}) + for shape in shapes: + input_options = self._make_extremal_inputs(shape, device) + target_options = self._make_extremal_inputs(shape, device) + for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options): + result = torch.nn.functional.mse_loss(input, target) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(torch.nn.functional.mse_loss, (cotangents, input, target)) + + def test_extremal_numerics_softmax(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + kwargs_options = ({'dim': 1}, {}) + for shape in shapes: + input_options = self._make_extremal_inputs(shape, device) + for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options): + result = torch.nn.functional.softmax(input) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(torch.nn.functional.softmax, (cotangents, input)) + + + def test_extremal_numerics_log_softmax(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + kwargs_options = ({'dim': 1}, {}) + for shape in shapes: + input_options = self._make_extremal_inputs(shape, device) + for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options): + result = torch.nn.functional.log_softmax(input) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(torch.nn.functional.log_softmax, (cotangents, input)) + + def test_extremal_numerics_cross_entropy(self, device): + N, C = 3, 4 + d1, d2, d3 = 5, 6, 7 + shapes = ( + ((N, C), (N,), (C,)), + ((N, C), (N,), None), + ((N, C), (N, C), (C,)), + ((N, C), (N, C), None), + ((C,), (), (C,)), + ((C,), (), None), + ((C,), (C,), (C,)), + ((C,), (C,), None), + ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)), + ((N, C, d1, d2, d3), (N, d1, d2, d3), None), + ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)), + ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None), + ) + for input_shape, target_shape, weight_shape in shapes: + input_options = self._make_extremal_inputs(input_shape, device) + kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}] + if input_shape != target_shape: + kwargs_options.append({'ignore_index': 0, 'reduction': 'mean'}) + + for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options): + if weight_shape is None: + weight = None + else: + weight = torch.randn(weight_shape, device=device) + + if input_shape == target_shape: + target = torch.rand(target_shape, device=device) + elif len(target_shape) == 0: + target = torch.tensor(1, device=device) # must be non-zero since ignore_index may be 0 + else: + target = torch.randint(0, C, target_shape, device=device) + + fn = functools.partial(torch.nn.functional.cross_entropy, target=target, weight=weight, **kwargs) + result = fn(input) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 1e-5)) + + def test_extremal_numerics_binary_cross_entropy(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + for shape in shapes: + weight_options = self._make_extremal_inputs(shape, device) + kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}] + + for weight, kwargs in self._arg_and_kwarg_options((weight_options,), kwargs_options): + input = torch.rand(shape, device=device) + target = torch.rand(shape, device=device) + fn = functools.partial(torch.nn.functional.binary_cross_entropy, target=target, weight=weight, **kwargs) + result = fn(input) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 2e-5)) + + def test_extremal_numerics_layer_norm(self, device): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + for shape in shapes: + input_options = self._make_extremal_inputs(shape, device) + normalized_shape = shape[1:] + weight_options = self._make_extremal_inputs(normalized_shape, device) + bias_options = self._make_extremal_inputs(normalized_shape, device) + + for input, bias, weight in self._arg_and_kwarg_options((input_options, bias_options, weight_options), ()): + def fn(input, weight, bias): + return torch.nn.functional.layer_norm(input, normalized_shape, weight=weight, bias=bias) + result = fn(input, weight, bias) + cotangents = torch.randn_like(result, device=device) + self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias)) + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double)) + @skipOps('TestOperators', 'test_vmap_autograd_grad', { + # call inplace functions + xfail('linalg.householder_product'), # inplace + + xfail('linalg.eig'), # all close? + # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0 + xfail('masked_select'), + xfail('nn.functional.max_unpool2d', 'grad'), # contiguous call + xfail('nn.functional.max_unpool2d'), # contiguous call + xfail('to_sparse'), # dispatch key issue + + # numerical inconsistencies, look like bugs + skip('ldexp', dtypes=(torch.float32,), device_type='cpu'), # fails on all but mac + skip('__rmatmul__'), # flaky needs investigation + skip('matmul'), # flaky needs investigation + skip('nn.functional.conv_transpose3d'), # flaky needs investigation + skip('nn.functional.conv_transpose2d'), # flaky needs investigation + skip('nn.functional.conv_transpose1d'), # flaky needs investigation + skip('nn.functional.layer_norm', dtypes=(torch.float32,), device_type='cpu'), # fails on windows + skip('linalg.lu_factor', dtypes=(torch.float32,), device_type='cuda'), # fails on all but windows + skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'), # fails on all but windows + }) + def test_vmap_autograd_grad(self, device, dtype, op): + def is_differentiable(inp): + return isinstance(inp, Tensor) and (inp.grad_fn is not None or inp.requires_grad) + + def get_flat_differentiable(pytree): + flattened = tree_flatten(pytree)[0] + return tuple(i for i in flattened if is_differentiable(i)) + + def get_differentiable_linked(list1, list2): + paired_list = zip(list1, list2) + paired_list = tuple((first, second) for (first, second) in paired_list if is_differentiable(first)) + return zip(*paired_list) + + def filter_none(out): + flattened = tree_flatten(out)[0] + return tuple(o for o in flattened if o is not None) + + if not op.supports_autograd: + self.skipTest("Skipped! Autograd not supported.") + return + + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + + for sample_input in sample_inputs: + fn, primals = normalize_op_input_output(op, sample_input) + out = fn(*primals) + cotangents = tree_map(torch.randn_like, out) + + def compute_grad(cotangents): + out_flattened = out + cotangents_flattened = cotangents + if not isinstance(out_flattened, torch.Tensor): + out_flattened = tree_flatten(out)[0] + cotangents_flattened = tree_flatten(cotangents)[0] + out_flattened, cotangents_flattened = get_differentiable_linked(out_flattened, cotangents_flattened) + + return filter_none( + torch.autograd.grad(out_flattened, get_flat_differentiable(primals), cotangents_flattened, + retain_graph=True, allow_unused=True)) + + is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs) + generator = get_fallback_and_vmap_exhaustive( + compute_grad, (cotangents,), {}, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + self.assertEqual(loop_out, batched_out) + + +only_for = ("cpu", "cuda") +instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) + +if __name__ == '__main__': + run_tests() diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py new file mode 100644 index 0000000000000..388a87c468755 --- /dev/null +++ b/functorch/test/test_pythonkey.py @@ -0,0 +1,595 @@ +# Owner(s): ["module: functorch"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.testing._internal.common_utils import TestCase, run_tests +import torch +import torch.nn as nn +import torch.utils._pytree as pytree +import unittest +import warnings +import itertools +from functools import partial +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from functorch import ( + grad, vjp, vmap, jacrev, + make_fx +) +from functorch._src.aot_autograd import aot_module_simplified +from functorch.compile import ( + nnc_jit, compiled_function, compiled_module, + min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop, + num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, +) + +from torch.testing._internal.common_device_type import ops +from functorch_lagging_op_db import functorch_lagging_op_db +from functorch_additional_op_db import additional_op_db +from common_utils import ( + xfail, + skip, + skipOps, +) + +USE_TORCHVISION = False +try: + import torchvision + USE_TORCHVISION = True +except ImportError: + warnings.warn("Couldn't import torchvision. Some of our tests use it, try " + "to install it with commands from pytorch.org, post-fixed with " + "`--no-deps` to avoid overwriting the pytorch installation", + UserWarning) + +USE_NETWORKX = False +try: + import networkx # noqa: F401 + USE_NETWORKX = True +except ImportError: + warnings.warn("Some tests use networkx but it was not installed", + UserWarning) + +# NB: numpy is a testing dependency! + + +class TestPythonKey(TestCase): + def test_make_fx(self, device): + def f(x): + return torch.sin(x) + inp = torch.randn(3) + fx_f = make_fx(f)(inp) + + new_inp = torch.randn(3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_grad(self, device): + def f(x): + return torch.sin(x).sum() + inp = torch.randn(3) + f = grad(f) + fx_f = make_fx(f)(inp) + + new_inp = torch.randn(3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_scalar_device(self, device): + def f(a, b): + return a + b + inps = [torch.randn(3, device=device), torch.tensor(5)] + fx_f = make_fx(f)(*inps) + self.assertEqual(fx_f(*inps), f(*inps)) + + def test_make_fx_vmap(self, device): + def f(x): + return torch.sin(x) + inp = torch.randn(5, 3) + f = vmap(f) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(5, 3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_jacrev(self, device): + def f(x): + return x.sin().sum() + inp = torch.randn(3) + f = jacrev(jacrev(f)) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_vjp(self, device): + def f(x): + return torch.sin(x).sum() + + primals = torch.randn(3) + _, vjp_fn = vjp(f, primals) + cotangent = torch.randn(()) + fx_f = make_fx(vjp_fn)(cotangent, True, True) + new_cotangent = torch.randn(()) + self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) + + def test_make_fx_no_decompose(self, device): + # FIXME + return self.skipTest("error: maximum recursion reached") + + def f(x): + return torch.tanh(x).sum() + + fx_f = make_fx(grad(f))(torch.randn(5)) + ops = set([i.target for i in fx_f.graph.nodes]) + + self.assertEqual(torch.ops.aten.tanh_backward in ops, True) + + fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5)) + ops = set([i.target for i in fx_f.graph.nodes]) + self.assertEqual(torch.ops.aten.tanh_backward in ops, False) + + def test_nnc_jit(self, device): + def f(x): + return torch.sin(x) + + jit_f = nnc_jit(f) + + inp = torch.randn(3) + self.assertEqual(jit_f(inp), f(inp)) + + def test_nnc_scalar(self, device): + def f(x): + return torch.sin(x) + + jit_f = nnc_jit(f) + + inp = torch.randn(()) + self.assertEqual(jit_f(inp), f(inp)) + + def test_nnc_pytrees(self, device): + def f(x): + return [torch.sin(x[0])] + + jit_f = nnc_jit(f) + + inp = [torch.randn(3)] + self.assertEqual(jit_f(inp), f(inp)) + + def test_external_calls(self, device): + def f(a, b): + return torch.mv(a, b) + jit_f = nnc_jit(f) + inp = [torch.randn(3, 3), torch.randn(3)] + self.assertEqual(jit_f(*inp), f(*inp)) + + def test_nnc_passthrough(self, device): + def f(x, y): + return x + y, y + inp = (torch.randn(3), torch.randn(3)) + jit_f = nnc_jit(f) + self.assertEqual(jit_f(*inp), f(*inp)) + + def f(x): + x['a'] = x['a'] * 2 + return x + inp = ({'a': torch.randn(3), 'b': torch.randn(3)},) + jit_f = nnc_jit(f) + self.assertEqual(jit_f(*inp), f(*inp)) + + @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") + def test_resnet18_backward_trace(self, device): + mod = torchvision.models.resnet18() + + def f(x): + out = mod(x) + out.sum().backward() + return [a.grad for a in mod.parameters()] + + inp = torch.randn(3, 3, 250, 250, requires_grad=True) + grads = f(inp) + + mod.zero_grad() + mod(inp).sum().backward() + grads2 = [a.grad for a in mod.parameters()] + self.assertEqual(grads, grads2) + + +def _outs_and_grads(fn, inps): + outs = fn(*inps) + for out in pytree.tree_flatten(outs)[0]: + if isinstance(out, torch.Tensor) and out.requires_grad: + out.sum().backward(retain_graph=True) + grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]] + for inp in pytree.tree_flatten(inps)[0]: + inp.grad = None + return outs, grads + + +class TestAOTAutograd(TestCase): + def verify_aot_autograd(self, f, inp): + if isinstance(f, nn.Module): + compiled_f = aot_module(f, nop) + else: + compiled_f = aot_function(f, nop) + ref_out, ref_grad = _outs_and_grads(f, inp) + test_out, test_grad = _outs_and_grads(compiled_f, inp) + self.assertEqual(ref_out, test_out) + self.assertEqual(ref_grad, test_grad) + + def test_single_output(self): + def f(a, b): + return a + b + inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] + self.verify_aot_autograd(f, inp) + + def test_multi_output(self): + def f(a, b): + return a + b, a - b + inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] + self.verify_aot_autograd(f, inp) + + def test_multi_output_list(self): + def f(a, b): + return [a + b, a - b] + inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] + self.verify_aot_autograd(f, inp) + + def test_no_grad_input_output(self): + def f(a, b): + return a.cos(), b.cos(), a * b + + inp_thunks = [lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False)] + for inps in itertools.product(inp_thunks, repeat=2): + inps = [i() for i in inps] + self.verify_aot_autograd(f, inps) + + def test_inner_grad(self): + def foo(x): + y = torch.exp(x) + z = torch.autograd.grad(y, x) + return z + inps = [torch.randn((), requires_grad=True)] + self.verify_aot_autograd(foo, inps) + + def test_grad_context(self): + def foo(x): + return x * 2 + inps = [torch.randn((), requires_grad=True)] + graph_size = None + + def assert_graph_empty(fx_g, _): + nonlocal graph_size + graph_size = len(fx_g.graph.nodes) + return fx_g + + start_recompilations = num_of_recompilations() + f = aot_function(foo, nop, assert_graph_empty) + with torch.set_grad_enabled(False): + f(*inps) + self.assertEqual(graph_size, 2) + with torch.set_grad_enabled(True): + f(*inps) + self.assertTrue(graph_size > 2) + self.assertEqual(num_of_recompilations() - start_recompilations, 2) + + def test_output_dict(self): + def f(x): + return {'a': x, 'b': x} + inp = [torch.randn(3, 3, requires_grad=True)] + self.verify_aot_autograd(f, inp) + + def f(x, y): + return {'a': x, 'b': y + x} + inp = [torch.randn(3, requires_grad=True), torch.randn(3)] + self.verify_aot_autograd(f, inp) + + def f(x): + new_d = {} + for k in x: + new_d[k] = x[k] * 2 + return new_d + inp = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}] + self.verify_aot_autograd(f, inp) + + def test_module(self): + mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) + compiled_mod = compiled_module(mod, nop, nop) + inp = torch.randn(32, 32) + ref_out = mod(inp) + ref_out.sum().backward() + ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) + out = compiled_mod(inp) + out.sum().backward() + grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) + self.assertEqual((out, grads), (ref_out, ref_grads)) + + def test_batchnorm(self): + mod = compiled_module(nn.BatchNorm2d(4), nop, nop) + x = torch.ones(1, 4, 2, 2) + mod(x).sum().backward() + + +class TestEagerFusionOpInfo(TestCase): + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + # entries in here need don't work and need to be fixed. + # Each one of these is a bug (or needs to be investigated) + @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', { + xfail('linalg.cholesky'), + skip('msort'), + xfail('nn.functional.dropout'), + xfail('to_sparse'), + xfail('addcdiv'), + xfail('cholesky'), + xfail('cumulative_trapezoid'), + xfail('diag_embed'), + xfail('linalg.householder_product'), + xfail('logit'), + xfail('trapezoid'), + xfail('trapz'), + xfail('corrcoef'), + xfail('cov'), + skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes? + skip('nn.functional.margin_ranking_loss'), # seems flaky + }) + def test_aot_autograd_exhaustive(self, device, dtype, op): + def f(args, kwargs): + return op.op(*args, **kwargs) + if not op.supports_autograd: + return + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in sample_inputs_itr: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in args]): + self.skipTest("not all inputs are float tensors") + if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in kwargs.values()]): + self.skipTest("not all inputs are float tensors") + continue + t = f(args, kwargs) + if isinstance(t, tuple): + self.skipTest("output is a tuple") + continue + + def reset_grads(): + def f(x): + x.grad = None + pytree.tree_map(f, args) + + def get_grads(args): + return pytree.tree_map(lambda x: x.grad, args) + + compiled_f = compiled_function(f, nop, nop) + + reset_grads() + compiled_f(args, kwargs).sum().backward() + compiled_grad = get_grads(args) + + reset_grads() + f(args, kwargs).sum().backward() + orig_grad = get_grads(args) + self.assertEqual(orig_grad, compiled_grad) + + def create_new_arg(x): + return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad) + + args = pytree.tree_map(create_new_arg, args) + + reset_grads() + compiled_f(args, kwargs).sum().backward() + compiled_grad = get_grads(args) + + reset_grads() + f(args, kwargs).sum().backward() + orig_grad = get_grads(args) + self.assertEqual(orig_grad, compiled_grad) + + +def extract_graph(fx_g, _, graph_cell): + graph_cell[0] = fx_g + return fx_g + + +def get_ins_outs(fx_g): + ins = [] + outs = [] + for n in fx_g.graph.nodes: + if n.op == 'placeholder': + ins.append(n) + elif n.op == 'output': + outs = tuple(n.args[0]) + return ins, outs + + +def get_num_ins_outs(fx_g): + return tuple(len(i) for i in get_ins_outs(fx_g)) + + +def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition): + fw_graph_cell = [None] + bw_graph_cell = [None] + aot_function(f, + fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), + bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), + partition_fn=partitioner, + decompositions=default_decompositions)(*inps) + return (fw_graph_cell[0], bw_graph_cell[0]) + + +class TestPartitioning(TestCase): + @unittest.skipIf(not USE_NETWORKX, "networkx not available") + def test_recompute_partitioning(self): + def fn(a, b): + return torch.sin(torch.sin(a)) + b + + # Reference calculation + ref_a = torch.rand(10, 10, requires_grad=True) + ref_b = torch.rand(10, 10, requires_grad=True) + ref = fn(ref_a, ref_b) + ref.sum().backward() + + # Compiled function calculation + res_a = ref_a.clone().detach().requires_grad_(True) + res_b = ref_b.clone().detach().requires_grad_(True) + + def compile_fn(x, _): + return x + + compiled_fn = compiled_function(fn, compile_fn, compile_fn, min_cut_rematerialization_partition) + res = compiled_fn(res_a, res_b) + res.sum().backward() + assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) + assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3) + assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3) + + def test_meta_tensor_inplace_op(self): + # Following module results in inplace ops while tracing. The test checks + # that the meta tensor information is stored for inplace ops. + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(3072, 768, requires_grad=True)) + self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True)) + + def forward(self, add_4): + linear_4 = torch.nn.functional.linear(add_4, self.weight, bias=self.bias) + gelu = torch.nn.functional.gelu(linear_4) + return gelu + + def check_meta_tensor(fx_g, _): + for node in fx_g.graph.nodes: + if node.op != 'output': + assert 'tensor_meta' in node.meta + return fx_g + + inp0 = torch.randn(16, 128, 768, requires_grad=True) + inputs = [inp0, ] + mod = MockModule().to(device="cpu") + aot_mod = aot_module(mod, fw_compiler=check_meta_tensor) + aot_mod(*inputs) + + def test_default_partitioner_getitem(self): + mod = nn.LayerNorm([10]) + + def f(x, mod_weight, mod_bias): + return torch.nn.functional.layer_norm(x, [10], mod_weight, mod_bias, eps=1e-6) + + fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], + partitioner=default_partition) + self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) + self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) + + @unittest.skipIf(not USE_NETWORKX, "networkx not available") + def test_min_cut_partitioner(self): + def f(x): + return x.cos().cos().cos() + + fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)]) + self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) + self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) + + def f(a, b, c, d): + x = a + b + c + d + return x.cos().cos() + + fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True) for _ in range(4)]) + self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) + self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) + + def f(x): + return torch.mm(x, torch.ones(x.shape)).tanh().tanh() + fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(5, 5, requires_grad=True)]) + self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) + + ins, outs = get_ins_outs(fw_graph) + self.assertEqual(outs[1].target, torch.ops.aten.mm.default) + + +class TestContiguous(TestCase): + def test_contiguous(self): + # The test simulates the condition where transpose followed by view + # happens in the backward pass. + # https://discuss.pytorch.org/t/error-on-transpose-and-view/434 + def f(x): + return x.view(2, 3).t() + + inp = torch.randn(6, requires_grad=True) + out = aot_function(f, nop)(inp) + torch.autograd.grad(out, inp, torch.randn(3, 2)) + + +class TestAOTModuleSimplified(TestCase): + def test_aot_module_simplified(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(20, 30) + + def forward(self, x, y): + return (self.linear(x) + y, ) + + mod = MockModule() + mod.zero_grad() + + x = torch.randn(128, 20, requires_grad=True) + y = torch.randn(128, 30, requires_grad=True) + inputs = [x, y] + cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] + + ref = mod(*inputs) + ref[0].sum().backward() + + aot_mod = aot_module_simplified(mod, nop) + aot_mod.zero_grad() + res = aot_mod(*cloned_inputs) + res[0].sum().backward() + + assert torch.allclose(ref[0], res[0]) + assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) + assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) + + +class TestRandom(TestCase): + def test_preserve_random(self): + def fn(x): + return torch.nn.functional.dropout(x, 0.5) + x + + + x = torch.randn(4) + + torch.manual_seed(0) + ref = fn(x) + + torch.manual_seed(0) + aot_fn = aot_function(fn, nop) + res = aot_fn(x) + + assert torch.allclose(ref, res) + + +class TestAutocast(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") + def test_autocast(self): + mod = torchvision.models.resnet18().cuda() + mod.train() + + x = torch.randn(16, 3, 32, 32, device="cuda") + aot_mod = memory_efficient_fusion(mod) + + # Ensure that AOT Autograd works with AMP + with torch.cuda.amp.autocast(True): + res = aot_mod(x) + res.sum().backward() + + +only_for = ("cpu") +instantiate_device_type_tests( + TestPythonKey, + globals(), + only_for=only_for, +) +instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for) + + +if __name__ == '__main__': + run_tests() diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py new file mode 100644 index 0000000000000..eb13da91b64da --- /dev/null +++ b/functorch/test/test_vmap.py @@ -0,0 +1,4248 @@ +# Owner(s): ["module: functorch"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import OrderedDict +from unittest.case import skipIf +from torch.testing._internal.common_utils import TestCase, run_tests +import torch +import torch.nn.functional as F +from torch import Tensor +import functools +import itertools +import warnings +import unittest +from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ + skipCUDAIfNoMagma +from torch.testing._internal.common_device_type import ops +from torch.testing._internal.common_utils import ( + parametrize, + instantiate_parametrized_tests, + subtest +) +from torch.testing._internal.common_device_type import \ + toleranceOverride, tol +from functorch_lagging_op_db import functorch_lagging_op_db +from functorch_additional_op_db import additional_op_db +from common_utils import ( + get_fallback_and_vmap_exhaustive, + xfail, + skip, + skipOps, + check_vmap_fallback, + tol1, + opsToleranceOverride, + is_batch_norm_training, +) +import types +from collections import namedtuple + +import functorch +from functorch import vmap, grad, grad_and_value, jvp, vjp +from functorch.experimental import chunk_vmap +from functorch._C import reshape_dim_into, reshape_dim_outof +from functorch._src.make_functional import functional_init_with_buffers + +FALLBACK_REGEX = 'There is a performance drop' + + +class EnableVmapFallbackWarnings: + def __enter__(self): + self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled() + torch._C._debug_only_display_vmap_fallback_warnings(True) + + def __exit__(self, *ignored): + torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state) + + +class TestVmapAPI(TestCase): + def test_non_tensor_output_raises(self): + with self.assertRaisesRegex(ValueError, "got type as a return"): + vmap(lambda x: 3.14)(torch.ones(3)) + + def multiple_outputs(x): + return x, 3 + + with self.assertRaisesRegex(ValueError, "got type as a return"): + vmap(multiple_outputs)(torch.ones(3)) + + def test_different_map_dim_size_raises(self): + x = torch.randn(2) + y = torch.randn(3) + expected_msg = 'Expected all tensors to have the same size in the mapped dimension' + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(torch.mul)(x, y) + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) + + def test_func_with_no_inputs(self): + expected_msg = 'got no inputs' + + def foo(): + return torch.randn(3) + + def bar(x): + return torch.randn(3) + + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(foo)() + + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(bar)() + + def test_constant_function(self): + output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3)) + self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14])) + + def test_single_input(self): + x = torch.randn(2, 3) + + def square(x): + return x * x + + output = vmap(square)(x) + self.assertEqual(output, x * x) + + def test_multiple_inputs(self): + x = torch.randn(2, 3) + y = torch.randn(2, 3) + output = vmap(torch.mul)(x, y) + self.assertEqual(output, x * y) + + def test_multiple_outputs(self): + def foo(x): + return x * x, x * x * x + + x = torch.randn(3) + outputs = vmap(foo)(x) + self.assertEqual(outputs[0], x * x) + self.assertEqual(outputs[1], x * x * x) + + def test_multiple_outputs2(self): + # This is the same thing as + # def returns_tuple_of_tensors(x): + # return x, x + def returns_tuple_of_tensors(x): + return (x, x) + + def returns_list_of_two_tensors(x): + return [x, x] + + def returns_list_of_one_tensor(x): + return [x] + + x = torch.randn(3) + + # should not throw + vmap(returns_tuple_of_tensors)(x) + vmap(returns_list_of_two_tensors)(x) + vmap(returns_list_of_one_tensor)(x) + + def test_nested_with_same_map_dim(self): + x = torch.randn(2, 3, 5) + y = torch.randn(2, 3, 5) + output = vmap(vmap(torch.mul))(x, y) + self.assertEqual(output, x * y) + + output = vmap(vmap(vmap(torch.mul)))(x, y) + self.assertEqual(output, x * y) + + def test_nested_with_diag_embed(self): + # diag_embed requires special testing because it is registered with conditional functionalization. + x = torch.randn(3, 3, 5) + output = vmap(vmap(torch.diag_embed))(x) + self.assertEqual(output, torch.diag_embed(x)) + + def test_nested_with_different_map_dim(self): + x = torch.randn(2, 3) + y = torch.randn(5, 3) + output = vmap(lambda x: vmap(lambda y: x * y)(y))(x) + self.assertEqual(output.shape, (2, 5, 3)) + self.assertEqual(output, x.view(2, 1, 3) * y) + + z = torch.randn(7, 3) + output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x) + self.assertEqual(output.shape, (2, 5, 7, 3)) + self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z) + + def test_noop_in_inner_vmap(self): + x = torch.randn(3) + y = torch.randn(5) + output = vmap(lambda x: vmap(lambda y: x)(y))(x) + self.assertEqual(output, x.view(3, 1).expand(3, 5)) + + def test_unsupported_op_err_msg(self): + # Unsupported view op + tensor = torch.randn(2, 3) + msg = ( + r"Batching rule not implemented for aten::.+; the " + r"fallback path doesn't work on out= or view ops" + ) + # TODO: find a view op + # with self.assertRaisesRegex(RuntimeError, msg): + # vmap(torch.ravel)(tensor) + + def out_op(x, y): + return torch.abs(x, out=y) + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(out_op)(tensor, tensor) + + # Don't support non-tensor returns. This is a limitation of vmap; + # functions that don't return tensors must be special cased + with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'): + vmap(torch.equal)(tensor, tensor) + + def test_nonzero_out_dims(self): + # Basic test + tensor = torch.randn(2, 3) + result = vmap(lambda x: x, out_dims=1)(tensor) + self.assertEqual(result, tensor.permute(1, 0)) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) + + # Test that the batch dimension gets permuted to dim 2 + tensor = torch.randn(2, 3, 5, 7) + result = vmap(lambda x: x, out_dims=2)(tensor) + self.assertEqual(result, tensor.permute(1, 2, 0, 3)) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) + + # negative out_dim + tensor = torch.randn(2, 3, 5, 7) + result = vmap(lambda x: x, out_dims=-1)(tensor) + self.assertEqual(result, tensor.permute(1, 2, 3, 0)) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) + + # check that out_dims works on ALL outputs + tensor = torch.randn(2, 3, 5, 7) + other = torch.randn(2, 3, 5, 7) + result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other) + self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))) + + # use out_dims with the maximum vmap-able tensor dims (64 dims) + ndims = 64 + shape = [2] + [1] * (ndims - 1) + expected_shape = [1, 1, 2] + [1] * (ndims - 3) + tensor = torch.randn(shape) + result = vmap(lambda x: x, out_dims=2)(tensor) + self.assertEqual(result.shape, expected_shape) + + # test something that is not the identity function + def foo(x, y): + return x, x * y, x * y * y + x = torch.randn(2, 3, 5) + y = torch.randn(2, 3, 5) + result = vmap(foo, out_dims=1)(x, y) + self.assertEqual( + result, + (x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2))) + + def test_multiple_out_dims(self): + def foo(x): + return x, x + + def bar(x, y): + return x, x, x, x * y + + x = torch.randn(2, 3, 5) + y = torch.randn(2, 3, 5) + result = vmap(foo, out_dims=(0, 1))(x) + self.assertEqual(result, (x, x.permute(1, 0, 2))) + + result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y) + expected = ( + x.permute(1, 2, 0), + x, + x.permute(1, 0, 2), + (x * y).permute(1, 2, 0), + ) + self.assertEqual(result, expected) + + def test_nested_out_dims(self): + y = torch.randn(2, 3, 5, 7) + + # Inner vmap has non-zero out_dim + result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y) + self.assertEqual(result.shape, (2, 5, 3, 7)) + self.assertEqual(result, y.permute(0, 2, 1, 3)) + + # all vmaps have non-zero out_dim + result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y) + self.assertEqual(result.shape, (5, 2, 3, 7)) + self.assertEqual(result, y.permute(2, 0, 1, 3)) + + # throwing in some negative out_dims + result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y) + self.assertEqual(result.shape, (5, 7, 3, 2)) + self.assertEqual(result, y.permute(2, 3, 1, 0)) + + # testing fn that isn't the identity + x = torch.randn(2, 3) + y = torch.randn(5, 3) + result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y) + self.assertEqual(result.shape, (3, 2, 5)) + self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0)) + + def test_out_dims_edge_case(self): + def foo(x): + return x + + # Test that we accept out_dims=(1,) for a function with one output. + tensor = torch.randn(2, 3) + expected = vmap(foo, out_dims=1)(tensor) + result = vmap(foo, out_dims=(1,))(tensor) + self.assertEqual(result, expected) + + def test_pytree_returns(self): + x = torch.randn(2, 3) + + def f(x): + y = x.sin() + return y, (y, y), [y, (y, y)] + + y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x) + self.assertEqual(y0, x.sin()) + self.assertEqual(y0, y1) + self.assertEqual(y2, y1) + self.assertEqual(y2, y3) + self.assertEqual(y4, y3) + self.assertEqual(y5, y4) + + def test_pytree_odict_returns(self): + x = torch.randn(2, 3) + + def f(t): + y = t.sin() + return OrderedDict([("sin", y), ("cos", t.cos())]) + + out = vmap(f)(x) + assert isinstance(out, OrderedDict) + expected = f(x) + self.assertEqual(out["sin"], expected["sin"]) + self.assertEqual(out["cos"], expected["cos"]) + + # temporary test for _odict_flatten and _odict_unflatten + def test_pytest_odict_flatten_unflatten(self): + + from functorch._src.vmap import _odict_flatten, _odict_unflatten + + x = torch.randn(2, 3) + inpt = OrderedDict([("sin", x.sin()), ("cos", x.cos())]) + + out = _odict_flatten(inpt) + self.assertEqual(out[0], list(inpt.values())) + self.assertEqual(out[1], list(inpt.keys())) + + recon_inpt = _odict_unflatten(*out) + self.assertEqual(recon_inpt, inpt) + + def test_pytree_returns_outdims(self): + x = torch.randn(2, 3) + + def f(x): + y = x.sin() + return y, (y, y) + + y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x) + self.assertEqual(y0, x.sin()) + self.assertEqual(y1, x.sin()) + self.assertEqual(y2, x.sin().t()) + + def test_pytree_returns_broadcast_simple(self): + x = torch.randn(2, 3) + + def f(x): + y = x.sin() + return y, (y, y) + + y0, (y1, y2) = vmap(f, out_dims=1)(x) + self.assertEqual(y0, x.sin().t()) + self.assertEqual(y1, y0) + self.assertEqual(y2, y0) + + def test_pytree_returns_broadcast_nested(self): + x = torch.randn(2, 3) + + def f(x): + y = x.sin() + return y, (y, y) + + y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x) + self.assertEqual(y0, x.sin()) + self.assertEqual(y1, y0.t()) + self.assertEqual(y2, y0.t()) + + def test_out_dims_must_be_int_or_collection_of_int_err_msg(self): + msg = 'must be an int or a python collection of ints' + tensor = torch.randn(2, 3) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: x, out_dims='lol')(tensor) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: x, out_dims=('lol',))(tensor) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: x, out_dims=None)(tensor) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: x, out_dims=(None,))(tensor) + + def test_out_dims_and_num_outputs_mismatch_err_msg(self): + msg = 'not compatible' + x = torch.randn(2, 3, 5) + + # Too many out_dims + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: x, out_dims=(0, 0))(x) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x) + + # Too few out_dims + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: (x, x), out_dims=(0,))(x) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda x: (x, x, x), out_dims=(0, 0))(x) + + def test_out_dim_out_of_bounds_err_msg(self): + # TODO(rzou): This error message isn't that great. It comes straight + # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to + # the error message in the future in C++ + msg = 'Dimension out of range' + x = torch.randn(2, 3, 5) + with self.assertRaisesRegex(IndexError, msg): + vmap(lambda x: x, out_dims=3)(x) + with self.assertRaisesRegex(IndexError, msg): + vmap(lambda x: x, out_dims=-4)(x) + + def test_non_zero_in_dims(self): + tensor = torch.randn(2, 3, 5) + + # Implicit out_dims = 0; vmap will move the batch dim to the front. + output = vmap(lambda x: x, (1,))(tensor) + self.assertEqual(output, tensor.permute(1, 0, 2)) + self.assertEqual(output.data_ptr(), tensor.data_ptr()) + + x = torch.randn(2, 3) + y = torch.randn(3, 2) + output = vmap(torch.mul, (0, 1))(x, y) + self.assertEqual(output, x * y.t()) + output = vmap(torch.mul, (1, 0))(x, y) + self.assertEqual(output, x.t() * y) + + def test_none_in_dims(self): + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # None in_dim for a Tensor means we don't map over it + output = vmap(torch.mul, (0, None))(x, y) + self.assertEqual(output.shape, (2, 2, 3)) + self.assertEqual(output, x.view(2, 1, 3) * y) + + # None in_dim for non-tensor arguments + output = vmap(torch.mul, (0, None))(x, 2) + self.assertEqual(output, x * 2) + + def test_nested_non_default_in_dims(self): + x = torch.rand(5, 2, 3) + y = torch.rand(3, 5, 2) + result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y) + self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1)) + + def test_nested_negative_in_dims(self): + x = torch.randn(2, 3) + y = torch.randn(2, 3) + output = vmap(torch.mul, (-1, -1))(x, y) + self.assertEqual(output.shape, (3, 2)) + self.assertEqual(output, (x * y).permute(1, 0)) + + def test_non_default_in_dims_out_dims(self): + x = torch.randn(2, 3, 5) + + # Same in_dim as out_dim, vmap over identity + result = vmap(lambda x: x, in_dims=1, out_dims=1)(x) + self.assertEqual(result, x) + self.assertEqual(result.data_ptr(), x.data_ptr()) + + # Different in_dim from out_dim, vmap over identity + result = vmap(lambda x: x, in_dims=2, out_dims=1)(x) + self.assertEqual(result.shape, (2, 5, 3)) + self.assertEqual(result, x.transpose(1, 2)) + self.assertEqual(result.data_ptr(), x.data_ptr()) + + def foo(x): + return x * 2 + + # Same in_dim as out_dim, vmap over operation + result = vmap(foo, in_dims=1, out_dims=1)(x) + self.assertEqual(result, x * 2) + + # Different in_dim as out_dim, vmap over operation + result = vmap(foo, in_dims=2, out_dims=1)(x) + self.assertEqual(result.shape, (2, 5, 3)) + self.assertEqual(result, (x * 2).transpose(1, 2)) + + # Basic nested test. + result = vmap(vmap(foo, 1, 1), 1, 1)(x) + self.assertEqual(result, x * 2) + + def test_item_throws(self): + def f(x): + return x.item() + + with self.assertRaisesRegex(RuntimeError, r'item\(\) on a Tensor'): + vmap(f)(torch.randn(3)) + + def test_data_dependent_control_flow_throws(self): + def f(x): + if x: + return x + return 0 + + with self.assertRaisesRegex(RuntimeError, r'data-dependent control flow'): + vmap(f)(torch.randn(3)) + + def test_accepts_nested_inputs(self): + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Single layer of nesting + out = vmap(lambda z: z[0] + z[1])((x, y)) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y)) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) + self.assertEqual(out, x + y) + + out = vmap(lambda z: z[0] + z[1])([x, y]) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y]) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y]) + self.assertEqual(out, x + y) + + out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y}) + self.assertEqual(out, x + y) + out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y}) + self.assertEqual(out, x + y) + out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) + self.assertEqual(out, x + y) + + # Multiple layers of nesting + out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1]) + out = out_fn({'x': [x, (x,)], 'y': [y, y]}) + self.assertEqual(out, x + x + y + y) + + def test_in_dims_wrong_type_err_msg(self): + x = torch.randn(3) + y = torch.randn(3) + msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple' + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.mul, [0, 0])(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.mul, set({0, 0}))(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.mul, 'lol')(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y]) + # The following should not throw + vmap(torch.mul, (0, 0))(x, y) + + def test_not_enough_in_dims_err_msg(self): + x = torch.randn(3) + y = torch.randn(3) + msg = r'in_dims is not compatible with the structure of `inputs`' + + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.mul, (0,))(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.mul, (0, 0, 0))(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y]) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y]) + # The following should not throw + vmap(torch.mul, (0, 0))(x, y) + + def test_integer_in_dim_but_not_tensor_input_err_msg(self): + def foo(xy): + return xy[0] * xy[1] + + def bar(x, yz): + return x * yz[0] * yz[1] + + x = torch.randn(2, 3) + + # the following are errors in jax (and will always be errors) + msg = 'Got in_dim=0 for an input but the input is of type' + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.sum)(x, 0) + with self.assertRaisesRegex(ValueError, msg): + vmap(torch.sum, (0, 0))(x, 0) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1]) + # The following should not throw + vmap(torch.sum, (0, None))(x, 0) + + def test_in_dim_not_in_tensor_err_msg(self): + def foo(x): + return x * x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w' + with self.assertRaisesRegex(ValueError, msg): + vmap(foo)(torch.randn([])) + with self.assertRaisesRegex(ValueError, msg): + vmap(foo, in_dims=(0,))(torch.randn([])) + with self.assertRaisesRegex(ValueError, msg): + vmap(foo, in_dims=(-3,))(x) + with self.assertRaisesRegex(ValueError, msg): + vmap(foo, in_dims=(2,))(y) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y]) + # the following should not throw + vmap(foo, in_dims=(0,))(torch.randn(2, 3)) + vmap(foo, in_dims=(1,))(torch.randn(2, 3)) + + def test_fallback_does_not_warn_by_default(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.copysign + x = torch.randn(11) + y = torch.randn(11) + with warnings.catch_warnings(record=True) as wa: + vmap(op)(x, y) + # The single warning here is the "vmap is experimental" + # warning, not a warning from the vmap fallback path. + self.assertEqual(len(wa), 1) + + @unittest.expectedFailure + def test_fallback_warns_when_warnings_are_enabled(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.copysign + x = torch.randn(11) + y = torch.randn(11) + with warnings.catch_warnings(record=True) as wa: + with EnableVmapFallbackWarnings(): + vmap(op)(x, y) + self.assertEqual(len(wa), 2) + self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) + + def _assert_uses_vmap_fallback(self, vmap_args, inputs): + return + # with warnings.catch_warnings(record=True) as wa: + # with EnableVmapFallbackWarnings(): + # result = vmap(*vmap_args)(*inputs) + # self.assertEqual(len(wa), 2) + # self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) + + def test_fallback_zero_dim(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.copysign + x = torch.randn(11) + y = torch.randn(11) + self._assert_uses_vmap_fallback((op,), (x, y)) + + B0, B1 = 0, 3 + x = torch.randn(B0, 11) + y = torch.randn(11) + + msg = 'The fallback path does not support vmap over dims of size 0' + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) + + x = torch.randn(B0, B1, 11) + y = torch.randn(B1, 11) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) + + def test_fallback_atan2(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.copysign + + x = torch.randn(5, 7, 11) + y = torch.randn(5, 7, 11) + + self._assert_uses_vmap_fallback((op,), (x, y)) + + # fallback on torch.atan2 + x = torch.randn(7, 11, 5) + y = torch.randn(5, 7, 11) + result = vmap(op, (2, 0))(x, y) + self.assertEqual(result, op(x.permute(2, 0, 1), y)) + + # fallback on torch.atan2, nested vmap + x = torch.randn(7, 11, 5) + y = torch.randn(5, 7, 11) + result = vmap(vmap(op), (2, 0))(x, y) + self.assertEqual(result, op(x.permute(2, 0, 1), y)) + + # big batch size (total 10000) + x = torch.randn(100, 10, 10, 5) + y = torch.randn(100, 10, 10) + result = vmap(vmap(vmap(op)))(x, y) + self.assertEqual(result, op(x, y.view(100, 10, 10, 1))) + + # TODO: No clue what is wrong here. + @unittest.skip + def test_fallback_masked_fill(self): + # NB: One day we will implement a batching rule for masked_fill + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + def run_test(batch_size): + B0 = batch_size + x = torch.randn(B0, 7, 11, 13) + dim = 0 + index = torch.tensor([0, 4, 2]) + values = torch.randn(B0, 3, 13) + + self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values)) + + result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values) + expected = torch.index_add( + x, dim + 1, index, values.view(B0, 3, 1, 13)) + self.assertEqual(result, expected) + + run_test(batch_size=5) + run_test(batch_size=1237) + + def test_fallback_multiple_returns(self): + # NB: One day we will implement a batching rule for torch.var_mean + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + B0, B1, B2 = 2, 3, 1237 + tensor = torch.randn(B0, 10) + + self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,)) + + # fallback correctness on torch.var_mean + result = vmap(torch.var_mean)(tensor) + expected = torch.var_mean(tensor, dim=1) + self.assertEqual(result, expected) + + # nested vmap + tensor = torch.randn(B0, B1, 10) + result = vmap(vmap(torch.var_mean))(tensor) + expected = torch.var_mean(tensor, dim=2) + self.assertEqual(result, expected) + + # big batch size, nested vmap + tensor = torch.randn(B0, B1, B2, 10) + result = vmap(vmap(vmap(torch.var_mean)))(tensor) + expected = torch.var_mean(tensor, dim=3) + self.assertEqual(result, expected) + + def test_inplace_fallback_unary(self): + # Test the in-place fallback on an in-place method that takes no + # additional Tensor arguments. This is the simplest case of the fallback. + # NB: One day we will implement a batching rule for acos_. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.acos_ + B0, B1, B2 = 2, 3, 10000 + + x = torch.randn(B0, 5) + self._assert_uses_vmap_fallback((op,), (x,)) + + # Single vmap + x_orig = torch.rand(B0, 5) + x = x_orig.clone() + result = vmap(op)(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + # Single vmap + different out_dim produces a view(!) + x_orig = torch.rand(B0, 5) + x = x_orig.clone() + result = vmap(op, out_dims=(1,))(x) + self.assertTrue(result._base is x) + self.assertEqual(result, x_orig.t().acos()) + + # Nested vmap + x_orig = torch.randn(B0, B1, 5) + x = x_orig.clone() + result = vmap(vmap(op))(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + # Nested vmap, large batch size + x_orig = torch.randn(B0, B1, B2, 5) + x = x_orig.clone() + result = vmap(vmap(vmap(op)))(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + def test_inplace_fallback_nary_same_levels(self): + # NB: One day we will implement a batching rule for atan2_ + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.atan2_ + outplace_op = torch.atan2 + + x = torch.randn(5, 7, 11) + y = torch.randn(5, 7, 11) + self._assert_uses_vmap_fallback((op,), (x, y)) + + # Single vmap + B0 = 5 + x_orig = torch.randn(7, 11, B0) + x = x_orig.clone() + y = torch.randn(B0, 7, 11) + vmap(op, (2, 0))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2))) + + # Nested vmap + B0, B1 = 5, 7 + x_orig = torch.randn(B1, 11, B0) + x = x_orig.clone() + y = torch.randn(B0, B1, 11) + vmap(vmap(op), (2, 0))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0]))) + + # big batch size (total 10000) + B0, B1, B2 = 100, 10, 10 + x_orig = torch.randn(B0, B1, B2, 5) + x = x_orig.clone() + y = torch.randn(B0, B1, B2) + vmap(vmap(vmap(op)))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1))) + + # ("Fallback isInplaceVmapCompatible check is broken") + @unittest.expectedFailure + def test_inplace_fallback_nary_different_levels(self): + # NB: One day we will implement a batching rule for atan2_ + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.atan2_ + outplace_op = torch.atan2 + B0, B1 = 2, 3 + + x = torch.rand(B0, 7) + y = torch.rand(7) + self._assert_uses_vmap_fallback((op, (0, None)), (x, y)) + + # op(left, right): All of the levels in right are found in left + x_orig = torch.rand(B0, 7) + x = x_orig.clone() + y = torch.rand(7) + vmap(op, in_dims=(0, None))(x, y) + self.assertEqual(x, outplace_op(x_orig, y)) + + x_orig = torch.rand(B0, B1, 7) + x = x_orig.clone() + y = torch.rand(B0, 7) + vmap(vmap(op, in_dims=(0, None)))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7))) + + # op(left, right): Some of the levels in right are not found in left + msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible' + x = torch.rand(7) + y = torch.rand(B0, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(x, y) + + x = torch.rand(B1, 7) + y = torch.rand(B0, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y) + + x = torch.rand(B1, 7) + y = torch.rand(7, B0) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y) + + x = torch.rand(B0, 7) + y = torch.rand(B0, B1, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(None, 0)))(x, y) + + def test_backward_unsupported_interaction(self): + x = torch.randn(3, requires_grad=True) + y = torch.randn(5) + grad = torch.randn_like(x) + err_msg = r'backward\(\) called inside a functorch transform' + + def backward_on_vmapped_tensor(x): + x.sum().backward() + + # FIXME + return self.skipTest("error: element 0 of tensors does not require grad and does not have a grad_fn") + with self.assertRaisesRegex(RuntimeError, err_msg): + vmap(backward_on_vmapped_tensor)(x) + + def backward_with_vmapped_grad(x, grad): + x.backward(grad) + + with self.assertRaisesRegex(RuntimeError, err_msg): + vmap(backward_with_vmapped_grad)(x, grad) + + def completely_unrelated_backward(y): + x.sum().backward() + return y + + with self.assertRaisesRegex(RuntimeError, err_msg): + vmap(completely_unrelated_backward)(y) + + @unittest.expectedFailure + def test_grad_unsupported_interaction(self): + input_tensor = torch.randn(3, requires_grad=True) + err_msg = 'autograd.grad.* called inside torch.vmap' + + captured = torch.randn(3, requires_grad=True) + + def output_to_grad_is_vmapped(input_tensor): + output = (captured * input_tensor).sum() + return torch.autograd.grad([output], [captured])[0] + + with self.assertRaisesRegex(RuntimeError, err_msg): + vmap(output_to_grad_is_vmapped)(input_tensor) + + output = (input_tensor ** 2).sum() + + def input_to_grad_is_vmapped(input_tensor): + return torch.autograd.grad([output], [input_tensor])[0] + + with self.assertRaisesRegex(RuntimeError, err_msg): + vmap(input_to_grad_is_vmapped)(input_tensor) + + def test_batched_gradient_basic(self): + N = 3 + x = torch.randn(N, requires_grad=True) + y = torch.randn(N) + + def vjp_mul(v): + return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0] + + batched_v = torch.eye(N) + jacobian = vmap(vjp_mul)(batched_v) + self.assertEqual(jacobian, torch.diagflat(y)) + + def test_functools_partial(self): + x = torch.randn(3) + y = torch.randn(2, 3) + result = vmap(functools.partial(torch.mul, x))(y) + self.assertEqual(result, x * y) + + def test_nn_module(self): + tensor = torch.randn(2, 3) + model = torch.nn.Linear(3, 3, bias=False) + result = vmap(model)(tensor) + self.assertEqual(result, model(tensor)) + + def test_fallback_with_undefined_grad(self): + B0 = 7 + x = torch.randn(2, 3, 4, 5, requires_grad=True) + weight = torch.randn(3, 3, 1, 1) + v = torch.randn(B0, 2, 3, 4, 5) + + def get_vjp(v): + result = torch.nn.functional.conv2d(x, weight) + grad_x, = torch.autograd.grad(result, x, v) + return grad_x + + # Runs vmap(get_vjp)(v), which should not error out. + # The backward formula for convolution returns an undefined + # Tensor for grad_bias because the original bias does not exist. + # + # In the future we'll probably add a batching rule for convolution + # backward. When this happens, we should modify this test to use a + # different op (and/or create and use a dummy operator) to avoid bitrot. + self._assert_uses_vmap_fallback([get_vjp], [v]) + + def test_reshape_dim_into(self): + x = torch.randn(2, 3, 5, 7) + + y = reshape_dim_into(0, 0, x) + self.assertEqual(y, x.reshape(6, 5, 7)) + + y = reshape_dim_into(0, 1, x) + self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7)) + + y = reshape_dim_into(0, 2, x) + self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7)) + + y = reshape_dim_into(1, 2, x) + self.assertEqual(y, x.movedim(1, 2).reshape(2, 5, 3 * 7)) + + y = reshape_dim_into(0, -2, x) + self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7)) + + y = reshape_dim_into(0, -1, x) + self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7)) + + y = reshape_dim_into(-4, -1, x) + self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7)) + + def test_reshape_dim_outof(self): + x = torch.randn(12, 12, 12).permute(2, 1, 0) + + y = reshape_dim_outof(0, 2, x) + self.assertEqual(y, x.reshape(2, 6, 12, 12)) + + y = reshape_dim_outof(1, 4, x) + self.assertEqual(y, x.reshape(12, 4, 3, 12)) + + y = reshape_dim_outof(2, 6, x) + self.assertEqual(y, x.reshape(12, 12, 6, 2)) + + y = reshape_dim_outof(-1, 6, x) + self.assertEqual(y, x.reshape(12, 12, 6, 2)) + + def test_batch_rule_does_not_need_to_handle_no_batched_input(self): + def f(x, y): + res = torch.dot(y, torch.ones(2)) + return x + res + + x = torch.randn(7, 5) + y = torch.randn(3, 2) + out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y) + expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x + self.assertEqual(out, expected) + + def _test_vmap_autocast(self, device): + + if torch.device(device).type == "cpu": + amp_dtype = torch.bfloat16 + else: + amp_dtype = torch.float16 + + a_float32 = torch.rand(4, 2, 3, device=device) + b_float32 = torch.rand(4, 3, 2, device=device) + c_float32 = torch.rand(4, 2, 2, device=device) + d_float32 = torch.rand(4, 3, 2, device=device) + + # Case 1, autocast inside vmapped function + def func1(x, y, z, w): + with torch.autocast(dtype=amp_dtype, device_type=device): + e_float16 = torch.matmul(x, y) + assert e_float16.dtype == amp_dtype, e_float16.dtype + f_float16 = torch.matmul(z, e_float16) + assert f_float16.dtype == amp_dtype, f_float16.dtype + return torch.matmul(w, f_float16.float()) + + expected = func1(a_float32, b_float32, c_float32, d_float32) + out = vmap(func1)(a_float32, b_float32, c_float32, d_float32) + assert expected.allclose(out) + + # Case 2, autocast decorator inside vmapped function + @torch.autocast(dtype=amp_dtype, device_type=device) + def func2(x, y, z, w): + e_float16 = torch.matmul(x, y) + assert e_float16.dtype == amp_dtype, e_float16.dtype + f_float16 = torch.matmul(z, e_float16) + assert f_float16.dtype == amp_dtype, f_float16.dtype + return torch.matmul(w, f_float16) + + expected = func2(a_float32, b_float32, c_float32, d_float32) + out = vmap(func2)(a_float32, b_float32, c_float32, d_float32) + assert expected.allclose(out) + + # Case 3, autocast is outside vmapped function + def func3(x, y, z, w): + e_float16 = torch.matmul(x, y) + assert e_float16.dtype == amp_dtype, e_float16.dtype + f_float16 = torch.matmul(z, e_float16) + assert f_float16.dtype == amp_dtype, f_float16.dtype + return torch.matmul(w, f_float16) + + with torch.autocast(dtype=amp_dtype, device_type=device): + expected = func3(a_float32, b_float32, c_float32, d_float32) + out = vmap(func3)(a_float32, b_float32, c_float32, d_float32) + + assert expected.allclose(out) + + @unittest.skip("Somehow, vmap and autocast do not work on CPU") + def test_vmap_autocast_cpu(self): + self._test_vmap_autocast("cpu") + + @skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + def test_vmap_autocast_cuda(self): + self._test_vmap_autocast("cuda") + + +def slice_inputs(inputs, bdims, i): + result = [] + for inp, bdim in zip(inputs, bdims): + if bdim is None: + result.append(inp) + else: + result.append(inp.select(bdim, i)) + return tuple(result) + + +def reference_vmap(op, inputs, in_dims=0, out_dims=0): + if isinstance(in_dims, int): + in_dims = (in_dims,) * len(inputs) + bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None] + assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes) + bdim_size = bdim_sizes[0] + results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size)) + + assert len(results) > 0 + op_has_single_return = not isinstance(results[0], tuple) + if op_has_single_return: + assert all(isinstance(result, torch.Tensor) for result in results) + if isinstance(out_dims, int): + out_dims = (out_dims,) * 1 + return torch.stack(results, dim=out_dims[0]) + + assert all(isinstance(result, tuple) for result in results) + num_returns = len(results[0]) + assert all(len(result) == num_returns for result in results) + if isinstance(out_dims, int): + out_dims = (out_dims,) * num_returns + return tuple(torch.stack(result_shards, out_dim) + for result_shards, out_dim in zip(zip(*results), out_dims)) + + +class TensorFactory: + @staticmethod + def rand(size, device='cpu', dtype=torch.float): + return torch.rand(size, device=device, dtype=dtype) + + @staticmethod + def randn(size, device='cpu', dtype=torch.float): + return torch.randn(size, device=device, dtype=dtype) + + @staticmethod + def randp1(size, device='cpu', dtype=torch.float): + return torch.rand(size, device=device, dtype=dtype) + 1 + +# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a +# (slow) sequential map+stack fallback. +# +# check_view: Test if the first returned output is a view of the first input +# check_propagates_grad: Test if the operation propagates gradients. + + +def _vmap_test(self, op, inputs, in_dims=0, out_dims=0, + check_view=False, check_propagates_grad=True): + result = vmap(op, in_dims, out_dims)(*inputs) + reference_result = reference_vmap(op, inputs, in_dims, out_dims) + self.assertEqual(result, reference_result) + op_has_single_return = not isinstance(result, tuple) + + if check_view: + result_as_tuple = (result,) if op_has_single_return else result + for output in result_as_tuple: + input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base + self.assertTrue(output._base is input0_base, + msg="result was not a view of the first input!") + + if not check_propagates_grad: + return + # Assuming input[0] is a floating-point tensor. Check if the vmap + # operation propagates the requires_grad flag to the zeroth output. + # Some vmap operators are implemented in a way that assumes that + # they are composite with respect to autograd. If the operator ever is + # changed to not be composite with respect to autograd, then the + # following check should fail. + inputs_clone = list(inputs) + inputs_clone[0] = inputs[0].clone().requires_grad_() + result = vmap(op, in_dims, out_dims)(*inputs_clone) + result_as_tuple = (result,) if op_has_single_return else result + self.assertTrue(result[0].requires_grad) + + +def should_allow_vmap_fallback_usage(fn): + return getattr(fn, '_allow_vmap_fallback_usage', False) + + +def allowVmapFallbackUsage(fn): + fn._allow_vmap_fallback_usage = True + return fn + +# All tests of TestVmapBase check that the slow vmap fallback is never invoked. +# This is so that we can incrementally add batching rules for operators to +# replace the slow vmap fallback path for said operators. To skip this check, +# please use the allowVmapFallbackUsage decorator. +# +# NB: Don't add tests to TestVmapBase directly, unless you want them to run +# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators. +# +# NB: TestVmapBase is a nested class. This prevents test runners from picking +# it up and running it. + + +class Namespace: + class TestVmapBase(TestCase): + def __init__(self, method_name='runTest'): + super().__init__(method_name) + + test_method = getattr(self, method_name, None) + if test_method is None: + return + + if not should_allow_vmap_fallback_usage(test_method): + setattr(self, method_name, + self._wrap_method_with_vmap_fallback_check(test_method)) + + def _wrap_method_with_vmap_fallback_check(self, method): + # msg = ( + # 'Expected the test to not invoke the vmap fallback path, i.e., ' + # 'all of the operators being tested in this test should have batching ' + # 'rules implemented. If you are intentionally testing something to ' + # 'do with the fallback path, use allowVmapFallbackUsage. Otherwise, ' + # 'please make sure that batching rules are implemented for the ' + # 'operator(s) being tested.' + # ) + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + with warnings.catch_warnings(record=True): + warnings.simplefilter('always') + with EnableVmapFallbackWarnings(): + method(*args, **kwargs) + # for captured_warning in wa: + # self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg) + return types.MethodType(wrapper, self) + + @allowVmapFallbackUsage + def test_vmap_fallback_check_ok(self): + # One day we'll implement a batching rule for torch.var_mean. + # When that happens, please change the example to use an + # operator that doesn't have a batching rule implemented. + op_using_fallback = torch.var_mean + vmap(op_using_fallback)(torch.rand(3)) + + @unittest.expectedFailure + def test_vmap_fallback_check(self): + @self._wrap_method_with_vmap_fallback_check + def no_fallback(self): + pass + + # One day we'll implement a batching rule for torch.var_mean. + # When that happens, please change the example to use an + # operator that doesn't have a batching rule implemented. + op_using_fallback = torch.var_mean + + @self._wrap_method_with_vmap_fallback_check + def uses_fallback(self): + vmap(op_using_fallback)(torch.rand(3)) + + no_fallback(self) + + with self.assertRaises(AssertionError): + uses_fallback(self) + + +def _make_case(op, input_getter=TensorFactory.randn): + return (op, input_getter) + + +class TestVmapOperators(Namespace.TestVmapBase): + def _vmap_test(self, *args, **kwargs): + return _vmap_test(self, *args, **kwargs) + + def _vmap_view_test(self, *args, **kwargs): + self._vmap_test(*args, **kwargs, check_view=True) + + def _test_unary(self, op, getter, device, *args, **kwargs): + test = functools.partial(self._vmap_test, *args, **kwargs) + B0, B1 = 7, 11 + + # Single vmap, various in_dims / out_dims + test(op, [getter([B0, 3], device)]) + test(op, [getter([2, 5, B0, 3], device)], in_dims=2) + test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [getter([B0, B1], device)]) + test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2) + test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)], + in_dims=2, out_dims=2) + + @parametrize("case", [ + (torch.abs, TensorFactory.randn), + (torch.acos, TensorFactory.rand), + (torch.asin, TensorFactory.rand), + (torch.atan, TensorFactory.rand), + (torch.ceil, TensorFactory.randn), + (torch.cos, TensorFactory.rand), + (torch.cosh, TensorFactory.rand), + (torch.digamma, TensorFactory.rand), + (torch.exp, TensorFactory.randn), + (torch.expm1, TensorFactory.randn), + (torch.floor, TensorFactory.randn), + (torch.frac, TensorFactory.randn), + (torch.lgamma, TensorFactory.rand), + (torch.log, TensorFactory.randp1), + (torch.log10, TensorFactory.randp1), + (torch.log1p, TensorFactory.randp1), + (torch.log2, TensorFactory.randp1), + (torch.neg, TensorFactory.randn), + (torch.reciprocal, TensorFactory.randp1), + (torch.relu, TensorFactory.randn), + (torch.round, TensorFactory.randn), + (torch.rsqrt, TensorFactory.randp1), + (torch.sigmoid, TensorFactory.randn), + (torch.sign, TensorFactory.randn), + (torch.sin, TensorFactory.rand), + (torch.sinh, TensorFactory.rand), + (torch.sqrt, TensorFactory.rand), + (torch.tan, TensorFactory.rand), + (torch.tanh, TensorFactory.rand), + (torch.trunc, TensorFactory.randn), + ], name_fn=lambda x: x[0].__name__) + def test_unary_pointwise(self, case): + op, getter = case + self._test_unary(op, getter, 'cpu') + + # test in-place + method = getattr(Tensor, f'{op.__name__ + "_"}') + self._test_unary(method, getter, 'cpu', check_propagates_grad=False) + + def test_clone(self): + # Some basic tests + self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu') + self._test_unary(lambda x: x.clone(memory_format=torch.preserve_format), + TensorFactory.randn, 'cpu') + self._test_unary(lambda x: x.clone(memory_format=torch.contiguous_format), + TensorFactory.randn, 'cpu') + + # Test that the per-examples are contiguous when using torch.contiguous_format + def clone_contiguous(x): + return x.clone(memory_format=torch.contiguous_format) + + B0, B1 = 3, 5 + x = torch.randn(2, B0, 7) + y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x) + self.assertTrue(y.movedim(1, 0).is_contiguous()) + self.assertTrue(y[:, 0, :].is_contiguous()) + + x = torch.randn(2, B0, 7, B1) + y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x) + self.assertTrue(y.is_contiguous()) + self.assertTrue(y[0][0].is_contiguous()) + + msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0)) + + def test_weird_matmul_case(self): + # Check that this doesn't crash. + # https://github.com/pytorch/functorch/issues/417 + x = torch.randn(5, 2, 2, 2) + y = torch.randn(5, 7, 2) + + vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y) + + @parametrize("case", + ( + (torch.clamp_min_, TensorFactory.randn), + (torch.clamp_max_, TensorFactory.randn), + ), name_fn=lambda x: x[0].__name__) + def test_clamp_inplace_variant(self, case): + test = self._vmap_test + + def get_number(getter): + return getter([]).item() + + op, getter = case + device = 'cpu' + B0, B1 = 7, 11 + + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3], device), getter([B0, 3], device)), check_propagates_grad=False) + test(op, (getter([B0], device), getter([B0], device)), check_propagates_grad=False) + test(op, (getter([2, B0, 3], device), getter([2, B0, 3], device)), in_dims=(1, 1), check_propagates_grad=False) + test(op, (getter([B0, 2, 3], device), getter([2, B0, 3], device)), + in_dims=(0, 1), out_dims=1, check_propagates_grad=False) + test(op, (getter([B0, 2, 3], device), getter([1, 1], device)), in_dims=(0, None), check_propagates_grad=False) + test(op, (getter([B0, 3], device), getter([B0, 3], device)), in_dims=(0, 0), check_propagates_grad=False) + + # Nested vmap: op(Tensor, Tensor) + test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 1, 3], device)), check_propagates_grad=False) + + # Python number overload: op(Tensor, Number) + number = get_number(getter) + self._test_unary(lambda t: op(t, number), getter, device, check_propagates_grad=False) + + @parametrize('case', [ + subtest(_make_case(torch.clamp_min), name='clamp_min'), + subtest(_make_case(torch.clamp_max), name='clamp_max'), + ]) + def test_clamp_variant(self, case): + test = self._vmap_test + + def get_number(getter): + return getter([]).item() + + op, getter = case + device = 'cpu' + B0, B1 = 7, 11 + + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3], device), getter([B0, 3], device))) + test(op, (getter([B0], device), getter([B0, 2, 3], device))) + test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) + test(op, (getter([B0], device), getter([2, B0, 3], device)), + in_dims=(0, 1), out_dims=1) + test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) + test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0)) + + # Nested vmap: op(Tensor, Tensor) + test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))) + test(vmap(op, in_dims=(None, 0)), + (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None)) + + # Python number overload: op(Tensor, Number) + number = get_number(getter) + self._test_unary(lambda t: op(t, number), getter, device) + + def test_copy_(self): + x = torch.randn(3) + y = torch.randn(3) + vmap(Tensor.copy_)(x, y) + self.assertEqual(x, y) + + x = torch.randn(3) + y = torch.randn(3, 2) + vmap(Tensor.copy_, in_dims=(1, None))(y, x) + self.assertEqual(y, x.expand(2, 3).t()) + + x = torch.randn(3) + y = torch.randn(2, 3) + with self.assertRaisesRegex(RuntimeError, 'inplace'): + vmap(Tensor.copy_, in_dims=(None, 0))(x, y) + + def test_silu_backward(self): + test = self._vmap_test + device = 'cpu' + getter = TensorFactory.randp1 + B0 = 7 + op = torch.ops.aten.silu_backward + + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3], device), getter([B0, 3], device))) + test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0)) + test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None)) + + @parametrize('case', [ + subtest(_make_case(torch.add), name='add'), + subtest(_make_case(lambda x, y: x + y), name='add_dunder'), + subtest(_make_case(torch.sub), name='sub'), + subtest(_make_case(lambda x, y: x - y), name='sub_dunder'), + subtest(_make_case(torch.mul), name='mul'), + subtest(_make_case(lambda x, y: x * y), name='mul_dunder'), + subtest(_make_case(torch.div, input_getter=TensorFactory.randp1), name='div'), + subtest(_make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), name='div_dunder'), + subtest(_make_case(torch.pow, input_getter=TensorFactory.randp1), name='pow'), + subtest(_make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), name='pow_dunder'), + ]) + def test_arithmetic(self, case): + test = self._vmap_test + + def get_number(getter): + return getter([]).item() + + op, getter = case + device = 'cpu' + B0, B1 = 7, 11 + + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3], device), getter([B0, 3], device))) + test(op, (getter([B0], device), getter([B0, 2, 3], device))) + test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) + test(op, (getter([B0], device), getter([2, B0, 3], device)), + in_dims=(0, 1), out_dims=1) + test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) + test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None)) + + # Nested vmap: op(Tensor, Tensor) + test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))) + test(vmap(op, in_dims=(None, 0)), + (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None)) + + # Python number overload: op(Tensor, Number) (and vice-versa) + number = get_number(getter) + self._test_unary(lambda t: op(t, number), getter, device) + number = get_number(getter) + self._test_unary(lambda t: op(number, t), getter, device) + + # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor) + test(op, (getter([B0], device), getter([B0], device, dtype=torch.double))) + test(op, (getter([B0], device, dtype=torch.double), getter([B0], device))) + test(op, (getter([B0], device), getter([B0], device))) + + # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa) + test(op, (getter([B0, 2], device), getter([B0], device, torch.double))) + test(op, (getter([B0], device, torch.double), getter([B0, 2], device))) + + if not torch.cuda.is_available(): + return + + # TODO(rzou): fix the following + # # Test cross-device scalars + # number = get_number(getter) + # self._test_unary(lambda t: op(t, number), getter, device='cuda') + # self._test_unary(lambda t: op(number, t), getter, device='cuda') + # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') + + # TODO: as_strided BR + @unittest.expectedFailure + def test_as_strided(self): + def _test(sizes, strides, offset, tensor, lambd): + result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) + expected = vmap(lambd)(tensor) + self.assertTrue(result._base is expected._base) + self.assertEqual(result, expected) + + # single vmap test + B0 = 5 + tensors = [ + # contiguous + torch.randn(B0, 2, 3), + # non-contiguous + torch.randn(B0, 3, 2).transpose(1, 2), + # non-zero storage offset + torch.randn(2, B0, 2, 3)[1], + # non-contiguous strides, zero storage offset + torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], + # non-contiguous strides, non-zero storage offset + torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], + ] + + for x in tensors: + S0, S1 = x.stride()[1:] + offset = x.storage_offset() + + # Broadcast + _test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)) + # transpose + _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) + # select + _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) + + # Nested vmap test + B1 = 7 + x = torch.randn(B1, B0, 2, 3) + S0, S1 = x.stride()[2:] + result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x) + expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) + self.assertTrue(result._base is expected._base) + self.assertEqual(result, expected) + + # Check that mal-formatted size/strides doesn't crash + with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'): + x = torch.randn(B0, 2, 3).transpose(0, 1) + vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) + + # Sanity check #1: we require the batch dims to be at the front of the + # tensor (in memory layout). + msg = 'batch dims being vmapped over are at the front of the tensor' + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(2, B0, 3).transpose(0, 1) + vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 2, 3, B1).movedim(3, 1) + vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x) + + # All the Sanity check #2{a,b,c} cases check that + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # doesn't index memory that is out of bounds of xs[i]. This condition + # is important to the correctness of the as_strided batching rule + # (see NOTE: [When will the as_strided_batching_rule fail?]) + + # Sanity check #2a: The maximum indexable location of + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is less than or equal to the maximum indexable location of xs[i]. + msg = 'This is not supported inside of vmap' + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 3) + vmap(lambda x: x.as_strided([3], [1], 1))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 3, 5) + vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, B1, 3, 5) + vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) + + # Sanity check #2b: The min indexable location of + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is greater than or equal to the min indexable location of xs[i]. + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(2, B0, 3)[1] + vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) + + # Sanity check #2c: + # xs[i] is a zero-dim tensor, but + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is not + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 0, 3) + vmap(lambda x: x.as_strided([3], [1]))(x) + + def test_nll_loss(self): + test = self._vmap_test + op = F.nll_loss + B = 3 + + y = torch.randn(B, 2, 5) + t = torch.randint(0, 5, (B, 2)) + test(op, (y, t)) + test(functools.partial(op, reduction='sum'), (y, t)) + test(functools.partial(op, reduction='none'), (y, t)) + + y = torch.randn(B, 2, 5) + t = torch.randint(0, 5, (2,)) + test(op, (y, t), in_dims=(0, None)) + test(functools.partial(op, reduction='sum'), (y, t), in_dims=(0, None)) + test(functools.partial(op, reduction='none'), (y, t), in_dims=(0, None)) + + def test_adaptive_avg_pool2d(self): + test = self._vmap_test + op = functools.partial(F.adaptive_avg_pool2d, output_size=(3, 3)) + + x = torch.randn(3, 5, 7, 9, 11) + test(op, (x,)) + test(op, (x,), in_dims=(1,)) + test(op, (x,), in_dims=(4,)) + + def test_bmm(self): + op = torch.bmm + test = self._vmap_test + B0, B1 = 7, 11 + + # shape mismatch + msg = "" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) + + # left arg is vmapped + test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None)) + test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)), + in_dims=(1, None)) + + # right arg is vmapped + test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) + test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)), + in_dims=(None, 1)) + + # both args are vmapped + test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3))) + test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0)) + test(vmap(op, in_dims=(0, None)), + (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0)) + + def test_cat(self): + test = self._vmap_test + B0, B1 = 5, 7 + + # Quick hack b/c vmap can't accept a list of tensors as an argument + def get_op(dim): + def op(*tensors): + return torch.cat(tensors, dim=dim) + return op + + test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3))) + test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0)) + test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2)) + test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2)) + test(vmap(get_op(0), in_dims=(0, None)), + (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0)) + test(vmap(get_op(0), in_dims=(0, 0)), + (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0)) + + def test_unsafe_view(self): + # Unsafe view isn't exposed, so we get at it via + # vmap(grad(matmul)) + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B = 2 + x = torch.randn(B, 2, 3, 3) + y = torch.randn(B, 3, 3) + + def baz(x, y): + return (x @ y).sum() + + test(functorch.grad(baz), (x, y)) + + def test_conj(self): + op = torch.conj + + def run_test(dtype): + def get(shape): + return torch.randn(shape, dtype=dtype) + B0, B1 = 7, 11 + test = self._vmap_test + + # Single vmap, various in_dims / out_dims + test(op, [get([B0, 3])]) + test(op, [get([2, 5, B0, 3])], in_dims=2) + test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [get([B0, B1])]) + test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + # correctness tests + run_test(torch.float) + run_test(torch.cfloat) + + # check that torch.conj on a non-complex tensor returns the same tensor + real_tensor = torch.randn(3) + result = vmap(op)(real_tensor) + self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) + + def test_contiguous(self): + op = Tensor.contiguous + + self._test_unary(op, TensorFactory.randn, 'cpu') + + # check that contiguous returns the original tensor if the per-examples + # are already contiguous + B0 = 3 + x = torch.randn(B0, 2, 5, 7) + x = x.movedim(0, 2) + result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) + self.assertTrue(result is x) + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) + + def test_stride(self): + B0 = 3 + + x = torch.randn(B0, 2, 5, 7) + + def foo(x): + assert x.stride() == (7 * 5, 7, 1) + return x + + vmap(foo)(x) + + x = torch.randn(2, B0, 5, 7).movedim(1, 0) + + def bar(x): + assert x.stride() == (7 * 5 * B0, 7, 1) + return x + + vmap(bar)(x) + + def test_chunk(self): + test = self._vmap_view_test + op = torch.chunk + B0, B1, B2 = 7, 11, 13 + + # tests for torch.split(self, split_size: int, dim) + test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + def test_clamp(self): + clamp_cases = ( + (lambda t: t.clamp(min=-0.5), TensorFactory.randn), + (lambda t: t.clamp(max=0.5), TensorFactory.randn), + (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn), + (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn), + (lambda t: t.clamp_max(max=0.5), TensorFactory.randn), + ) + for op, getter in clamp_cases: + self._test_unary(op, getter, 'cpu') + + def test_comparison_ops(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + + getter = TensorFactory.randn + B0, B1 = 7, 11 + + ops = ( + torch.eq, lambda x, y: x == y, + torch.gt, lambda x, y: x > y, + torch.ge, lambda x, y: x >= y, + torch.le, lambda x, y: x <= y, + torch.lt, lambda x, y: x < y, + torch.ne, lambda x, y: x != y, + ) + + for op in ops: + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3]), getter([B0, 3]))) + test(op, (getter([B0]), getter([B0, 2, 3]))) + test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1)) + test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1) + test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None)) + test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None)) + + # Nested vmap: op(Tensor, Tensor) + test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3]))) + test(vmap(op, in_dims=(None, 0)), + (getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None)) + + # test number as inputs + number = getter([]).item() + self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False) + + def test_cross_batch_size_three(self): + # Let's test corner case when batch_size is 3 and cross' dim argument is not specified + # According to the cross API, dim will be assigned to the first dim with value 3 + # In this test we ensure that found dim is not batch dim. + op = torch.cross + test = self._vmap_test + B0 = B1 = 3 + test(op, (torch.rand(B0, 2, 3), torch.rand(B0, 2, 3))) + test(vmap(op, in_dims=(0, None)), (torch.rand(B0, B1, 2, 3), torch.rand(B0, B1, 2, 3)), + in_dims=(None, 1)) + + def test_diagonal(self): + tensor = torch.randn(3, 5, 7, 11, 13) + test = self._vmap_view_test + op = torch.diagonal + test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None)) + test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None)) + test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None)) + test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1) + test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1) + test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3), + (tensor,), in_dims=1, out_dims=1) + + def test_dot(self): + op = torch.dot + test = self._vmap_test + B0, B1 = 7, 11 + + # shape mismatch + msg = "" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2)) + + # left arg is vmapped + test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None)) + test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)), + in_dims=(1, None)) + + # right arg is vmapped + test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0)) + test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)), + in_dims=(None, 1)) + + # both args are vmapped + test(op, (torch.rand(B0, 5), torch.rand(B0, 5))) + test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) + test(vmap(op, in_dims=(0, None)), + (torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0)) + + def test_expand_as(self): + op = torch.Tensor.expand_as + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5))) + test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None)) + test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) + test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5))) + test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1)) + test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) + test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) + + def test_fill_and_zero_inplace(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0, B1 = 7, 11 + ops = ( + lambda t: t.fill_(0.1), + lambda t: t.fill_(torch.tensor(0.2)), + lambda t: t.zero_(), + ) + + for op in ops: + # Single vmap, various in_dims / out_dims + test(op, [TensorFactory.randn([B0, 3])]) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [TensorFactory.randn([B0, B1])]) + test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + # test when value is a batched tensor for fill_ operator + B0, B1 = 3, 5 + test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) + + with self.assertRaisesRegex(RuntimeError, + ""): + # Runtime Error is thrown when the tensor being written to isn't being vmapped over + vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]), + TensorFactory.randn([B0])) + + def _test_complex_views(self, op, dtypes): + test = self._vmap_view_test + + def run_test(op, dtype): + def get(shape): + return torch.randn(shape, dtype=dtype) + + B0, B1 = 7, 11 + + # Single vmap, various in_dims / out_dims + test(op, [get([B0, 3])]) + test(op, [get([3, B0])], in_dims=1) + test(op, [get([2, 5, B0, 3])], in_dims=2) + test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [get([B0, B1])]) + test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) + test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + for dtype in dtypes: + run_test(op, dtype) + + def test_real(self): + self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) + + def test_imag(self): + self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) + + def test_view_as_real(self): + self._test_complex_views(torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]) + + def test_view_as_complex(self): + def run_test(dtype): + def get(shape): + return torch.randn(shape, dtype=dtype) + + op = torch.view_as_complex + test = self._vmap_view_test + B0, B1 = 7, 11 + + # Single vmap, various in_dims / out_dims + test(op, [get([B0, 3, 2])]) + test(op, [get([2, 5, B0, 3, 2])], in_dims=2) + test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [get([B0, B1, 2])]) + test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) + test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], + in_dims=2, out_dims=2) + + # Interesting case #1: Batch dim directly before dim of size 2 + test(op, [get([3, B0, 2])], in_dims=1) + test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) + + # Interesting case #2: Batch dim at end of tensor, success cases + # view_as_complex requires that the dim with size 2 have stride 1 + # in order for the view to function propertly + test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) + test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) + test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) + + # Interesting case #3: Batch dim at end of tensor, failure cases + msg = "Tensor must have a last dimension with stride 1" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=1)(get([2, B0])) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) + + # Invalid input: no dimension of size 2 + msg = 'Input tensor must have one or more dimensions' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(get([B0])) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op))(get([B0, B1])) + + # Invalid input: Batch dim has size 2, but the logical last dim does + # not have size 2 + msg = 'Tensor must have a last dimension of size 2' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=1)(get([3, 2])) + + for dtype in [torch.float, torch.double]: + run_test(dtype) + + def test_is_complex(self): + ctensor = torch.randn(3, dtype=torch.cfloat) + tensor = torch.randn(3) + + def foo(x): + if x.is_complex(): + return torch.tensor(1) + else: + return torch.tensor(0) + + self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) + self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) + + def test_is_floating_point(self): + float_tensor = torch.tensor([1., 2., 3.]) + long_tensor = torch.tensor([1, 2, 3]) + + def foo(x): + if x.is_floating_point(): + return torch.tensor(1) + else: + return torch.tensor(0) + + self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1])) + self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0])) + + def test_is_contiguous(self): + def foo(x): + if x.is_contiguous(): + return torch.tensor(1.) + else: + return torch.tensor(0.) + + B0, B1 = 3, 5 + + # Single batch dim + contig = torch.randn(B0, 2, 7) + self.assertEqual(vmap(foo)(contig), torch.ones(B0)) + + noncontig = torch.randn(2, B0, 7) + self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, B0, 7).movedim(1, 0) + self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, 7, B0) + self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) + + # Multiple batch dims + contig = torch.randn(B0, B1, 3) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3) + self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3).movedim(0, 1) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + noncontig = torch.randn(B0, 3, B1) + self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) + + # is_contiguous on empty tensor is True + def bar(x): + assert x.is_contiguous() + return x + + vmap(bar)(torch.randn(B0, 0, 3)) + vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) + vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2)) + + # is_contiguous with other memory formats + def baz(x, memory_format): + x.is_contiguous(memory_format=memory_format) + return x + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 2, 7, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) + + def test_unsqueeze(self): + op = torch.unsqueeze + test = self._vmap_view_test + B0, B1 = 7, 11 + + # unsqueeze dim 0 + test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None)) + test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None)) + + # unsqueeze last dim (positive) + test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None)) + test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None)) + + # unsqueeze last dim (negative) + test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None)) + test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None)) + + # nested vmaps + def unsqueeze_0(x): + return torch.unsqueeze(x, 0) + + def unsqueeze_last(x): + return torch.unsqueeze(x, -1) + + # bdims in canonical order + test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2), )) + test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),)) + + # wild bdims + test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2) + test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2) + test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2) + test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2) + + def test_movedim(self): + op = torch.movedim + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + + # movedim(tensor, int, int) variant + test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None)) + test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), + (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None)) + + # movedim(tensor, intlist, intlist) variant + test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None)) + test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), + (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None)) + test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), + (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None)) + + def test_mm(self): + op = torch.mm + test = self._vmap_test + B0, B1 = 7, 11 + + # shape mismatch + msg = "Shape mismatch" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) + + # left arg is vmapped + test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None)) + test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)), + in_dims=(1, None)) + + # right arg is vmapped + test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) + test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)), + in_dims=(None, 1)) + + # both args are vmapped + test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2))) + test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0)) + test(vmap(op, in_dims=(0, None)), + (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0)) + + def test_mv(self): + op = torch.mv + test = self._vmap_test + B0, B1 = 7, 11 + + # shape mismatch + msg = "" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2)) + + # left arg is vmapped + test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None)) + test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5)), + in_dims=(1, None)) + + # right arg is vmapped + test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) + test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5)), + in_dims=(None, 1)) + + # both args are vmapped + test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5))) + test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) + test(vmap(op, in_dims=(0, None)), + (torch.rand(B1, 2, 5), torch.rand(B0, 5)), in_dims=(None, 0)) + + def test_narrow(self): + op = torch.narrow + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + + test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None)) + test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None)) + test(vmap(op, in_dims=(0, None, None, None)), + (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None)) + test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)), + (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None)) + + def test_new_empty(self): + # Empty is non-deterministic so we just check that the shape of the + # output tensor is what we expect and that the vmap fallback isn't used. + op = Tensor.new_empty + + B0, B1 = 7, 11 + + result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0)) + self.assertEqual(result.shape, [B0, 2, 3]) + + result = vmap(lambda x: op(x, []))(torch.randn(B0)) + self.assertEqual(result.shape, [B0]) + + result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1)) + self.assertEqual(result.shape, [B0, B1, 2, 3]) + + def test_new_empty_strided(self): + # Empty is non-deterministic so we just check that the size and shape + # of the output are what we expect and that the vmap fallback isn't used + B0, B1 = 7, 11 + + def _test_single_vmap(size, stride, B0): + x = torch.randn(B0) + result = vmap(lambda x: x.new_empty_strided(size, stride))(x) + S = torch.empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0] + size) + self.assertEqual(result.stride(), [S] + stride) + + def _test_double_vmap(size, stride, B0, B1): + x = torch.randn(B0, B1) + result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x) + S = torch.empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0, B1] + size) + self.assertEqual(result.stride(), [B1 * S, S] + stride) + + x = torch.randn(B1, B0) + result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(x) + S = x.new_empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0, B1] + size) + self.assertEqual(result.stride(), [B1 * S, S] + stride) + + # contiguous case + _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0) + _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1) + + # expanded + _test_single_vmap([2, 3, 5], [0, 5, 1], B0) + _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1) + + # some of these cases are pretty strange, just verifying that if + # empty_strided allows them then BatchedTensor.new_empty_strided + # can as well + for shape in [[2, 3, 4], [0, 2, 0]]: + for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]: + _test_single_vmap(shape, strides, B0) + _test_double_vmap(shape, strides, B0, B1) + + def test_new_zeros(self): + op = Tensor.new_zeros + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0, B1 = 7, 11 + + test(lambda x: op(x, 2, 3), (torch.rand(B0),)) + test(lambda x: op(x, []), (torch.rand(B0),)) + test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),)) + + def test_select(self): + op = torch.select + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None)) + test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2) + + def test_roll_no_dims(self): + op = torch.roll + test = self._vmap_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None)) + test(op, (torch.rand(2, B0, 5), 3), in_dims=(1, None)) + test(vmap(lambda t: op(t, 3)), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(lambda t: op(t, 3), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2) + + def test_stack(self): + test = self._vmap_test + B0, B1 = 5, 7 + + # Quick hack b/c vmap can't accept a list of tensors as an argument + def get_op(dim): + def op(*tensors): + return torch.stack(tensors, dim=dim) + return op + + test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3))) + test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0)) + test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) + test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2)) + test(vmap(get_op(0), in_dims=(0, None)), + (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0)) + test(vmap(get_op(0), in_dims=(0, 0)), + (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0)) + + def test_slice(self): + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + test(lambda t: t[0:1], (torch.rand(B0, 3, 5),)) + test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2) + test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2) + test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), + (torch.rand(3, 5, B0, B1, B2),), in_dims=2) + + def test_squeeze(self): + def verify_behavior(op, min_ndim=1): + test = self._vmap_view_test + B0, B1 = 1, 11 + # These tests cannot be used with an operator that requires more + # than 1 dimension after batching. + if min_ndim <= 1: + test(op, (torch.rand(B0),)) + test(op, (torch.rand(B1),)) + test(vmap(op), (torch.rand(B0, B1, 1),)) + test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) + test(op, (torch.rand(B0, 3, 5),)) + test(op, (torch.rand(1, B0, 5),), in_dims=1) + test(op, (torch.rand(B0, 0, 1, 5, 1),)) + test(op, (torch.rand(B0, 1, 1, 1, 1),)) + test(vmap(op), (torch.rand(B0, B1, 1, 3, 4),)) + test(vmap(op), (torch.rand(B1, 1, B0, 4, 5),), in_dims=2) + + verify_behavior(torch.squeeze) + verify_behavior(lambda x: torch.squeeze(x, dim=0), min_ndim=1) + verify_behavior(lambda x: torch.squeeze(x, dim=1), min_ndim=2) + verify_behavior(lambda x: torch.squeeze(x, dim=-1), min_ndim=2) + verify_behavior(lambda x: torch.squeeze(x, dim=-2), min_ndim=3) + + msg = "" + try: + torch.squeeze(torch.rand(10), dim=1) + except IndexError as err: + msg = str(err) + with self.assertRaises(RuntimeError, msg=msg): + vmap(lambda x: torch.squeeze(x, dim=1))(torch.rand(10)) + + def _test_mean_sum_dim(self, op): + test = self._vmap_test + B0, B1 = 5, 7 + + # Single vmap, various in_dims / out_dims + test(lambda x: op(x, 0), [torch.randn([B0])]) + test(lambda x: op(x, -1), [torch.randn([B0])]) + test(lambda x: op(x, 0), [torch.randn([B0, 3])]) + test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2) + test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])]) + test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])]) + test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(lambda x: op(x, 2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + def test_sum_dim(self): + self._test_mean_sum_dim(torch.sum) + + def test_mean_dim(self): + self._test_mean_sum_dim(torch.mean) + + def test_argmax_dim(self): + def test(f, args): + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}): + self.assertEqual(loop_out, batched_out) + B0 = 5 + test(lambda x: torch.argmax(x), [torch.randn(B0)]) + test(lambda x: torch.argmax(x), [torch.randn(B0, 2, 3)]) + test(lambda x: torch.argmax(x, 0), [torch.randn(B0, 2, 3)]) + test(lambda x: torch.argmax(x, -1), [torch.randn(B0, 2, 3)]) + test(lambda x: torch.argmax(x, 2), [torch.randn(B0, 2, 3)]) + + def _test_sum_mean(self, op): + test = self._vmap_test + B0, B1 = 5, 7 + + # Single vmap, various in_dims / out_dims + test(op, [torch.randn([B0])]) + test(op, [torch.randn([B0, 3])]) + test(op, [torch.randn([2, 5, B0, 3])], in_dims=2) + test(op, [torch.randn([2, 5, B0, 3])], in_dims=2) + + # Doubly nested vmap + test(vmap(op), [torch.randn([B0, B1])]) + test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])]) + test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2) + + def test_sum(self): + self._test_sum_mean(torch.sum) + + def test_mean(self): + self._test_sum_mean(torch.mean) + + def test_repeat(self): + test = self._vmap_test + B0 = 7 + op = Tensor.repeat + test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),)) + test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1) + + def test_slogdet(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0 = 7 + op = torch.linalg.slogdet + test(op, (torch.rand(B0, 1, 1),)) + test(op, (torch.rand(B0, 2, 2),)) + test(op, (torch.rand(B0, 3, 2, 2),)) + test(op, (torch.rand(3, 2, 2, B0),), in_dims=3) + + def test_reshape(self): + test = self._vmap_test + B0, B1, B2 = 7, 11, 13 + op = torch.reshape + test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True) + test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False) + test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True) + test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1), + (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False) + + def test_reshape_as(self): + test = self._vmap_test + B0, B1, B2 = 7, 11, 13 + op = torch.Tensor.reshape_as + test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True) + test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True) + test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True) + + test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False) + + test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True) + test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)), + (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)), + in_dims=(5, 0), check_view=False) + + def test_result_type(self): + def scalar_tensor_with_dtype(op): + def wrapped(*args, **kwargs): + dtype = op(*args, **kwargs) + return torch.ones([], dtype=dtype) + return wrapped + + test = self._vmap_test + op = scalar_tensor_with_dtype(torch.result_type) + + B0 = 2 + + test(op, (torch.randn(B0), torch.randn(B0, dtype=torch.float64)), + check_propagates_grad=False) + test(op, (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)), + check_propagates_grad=False) + + test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False) + test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False) + + test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0),), + check_propagates_grad=False) + test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), + (torch.randn(B0),), check_propagates_grad=False) + + test(op, (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)), + check_propagates_grad=False) + test(op, (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)), + check_propagates_grad=False) + + test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False) + test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False) + + test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0, 2),), + check_propagates_grad=False) + test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)), + (torch.randn(B0, 2),), check_propagates_grad=False) + + test(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)), + check_propagates_grad=False) + test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), + check_propagates_grad=False) + + def test_tensor_split(self): + test = self._vmap_view_test + op = torch.tensor_split + B0, B1, B2 = 7, 11, 13 + + # tests for torch.tensor_split(self, indices_or_sections: int, dim) + test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) + test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + def test_split(self): + test = self._vmap_view_test + op = torch.split + B0, B1, B2 = 7, 11, 13 + + # tests for torch.split(self, split_size: int, dim) + test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + # tests for torch.split(self, split_size: List[int], dim) + test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + def test_trace(self): + op = torch.trace + test = self._vmap_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 2, 5),)) + test(op, (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + def test_transpose(self): + op = torch.transpose + test = self._vmap_view_test + + B0, B1, B2 = 7, 11, 13 + test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) + test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + # Special case: scalar tensor + for dim1, dim2 in itertools.product([0, -1], [0, -1]): + x = torch.rand(B0) + result = vmap(lambda x: op(x, dim1, dim2))(x) + self.assertTrue(result is x) + + def test_t(self): + op = torch.t + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 2, 5),)) + test(op, (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + def test_T_numpy(self): + def op(t): + return t.T + + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + test(op, (torch.rand(B0, 2, 3, 5),)) + test(op, (torch.rand(B0),)) + test(op, (torch.rand(2, B0, 3, 5),), in_dims=1) + test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2) + test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2) + + def test_to(self): + test = self._vmap_test + B0, B1 = 7, 11 + + test(lambda t: t.to('cpu'), (torch.rand(B0),)) + test(lambda t: t.to(torch.double), (torch.rand(B0),)) + test(lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))) + test(lambda t, o: t.to(o), + (torch.rand(B0), torch.randn(B0, dtype=torch.float64)), + in_dims=(0, None)) + test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),)) + + # also test some casting methods + test(lambda t: t.double(), (torch.rand(B0),)) + test(lambda t: t.float(), (torch.rand(B0),)) + test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False) + test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False) + + def test_unfold(self): + op = torch.Tensor.unfold + test = self._vmap_view_test + B0, B1, B2 = 3, 2, 5 + + test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None)) + test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None)) + test(vmap(op, in_dims=(0, None, None, None)), + (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None)) + test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)), + (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None)) + + def test_unbind(self): + test = self._vmap_view_test + op = torch.unbind + B0, B1, B2 = 7, 11, 13 + + test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None)) + test(op, (torch.rand(B0, 2, 0),)) + test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None)) + test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1), + in_dims=(2, None)) + test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)), + (torch.rand(B1, 2, B0, 32, B2),), in_dims=2) + + def test_view(self): + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + op = torch.Tensor.view + + # We should error out if the view would produce an incorrect result + with self.assertRaises(RuntimeError): + vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10]) + + test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None)) + test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None)) + test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),)) + test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1), + (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1) + + def test_view_as(self): + test = self._vmap_view_test + B0, B1, B2 = 7, 11, 13 + op = torch.Tensor.view_as + + # We should error out if the view would produce an incorrect result + with self.assertRaises(RuntimeError): + vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10)) + + test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5))) + test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0)) + test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None)) + + test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None)) + + test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10))) + test(vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)), + (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)), + in_dims=(2, 0)) + + def test_conv2d(self): + conv_setups = [ + (torch.nn.Conv1d, torch.conv1d, [2, 4, 15]), + (torch.nn.Conv2d, torch.conv2d, [2, 4, 15, 20]), + (torch.nn.Conv3d, torch.conv3d, [2, 4, 15, 20, 25]), + # (torch.nn.ConvTranspose2d, torch.conv_transpose2d, [2, 4, 15, 20]) + ] + for conv_mod, conv_fn, inp_shape in conv_setups: + mod = conv_mod(4, 8, kernel_size=3) + arg_values = [torch.randn(inp_shape), mod.weight, mod.bias] + kwarg_values = {} + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values): + self.assertEqual(loop_out, batched_out) + + arg_values = [torch.randn(inp_shape), mod.weight, None] + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values): + self.assertEqual(loop_out, batched_out) + + mod2 = conv_mod(4, 8, kernel_size=3, groups=2, stride=3, padding=1, dilation=2) + arg_values = [torch.randn(inp_shape), mod2.weight, mod2.bias] + kwarg_values = dict(groups=2, stride=3, padding=1, dilation=2) + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values): + self.assertEqual(loop_out, batched_out) + + arg_values = [torch.randn(inp_shape), mod2.weight, None] + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values): + self.assertEqual(loop_out, batched_out) + + def test_one_hot(self): + sample_inputs = [ + (torch.randint(0, 3, []), 3), + (torch.randint(0, 3, [2, 3, 4]), 4), + ] + for args in sample_inputs: + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(F.one_hot, args, {}): + self.assertEqual(loop_out, batched_out) + + def test_conj_bit(self): + x = torch.tensor([1 + 1j, 2 + 1j]) + + def foo(x): + assert not x.is_conj() + y = x.conj() + assert y.is_conj() + return y + res = vmap(foo)(x) + self.assertEqual(res, x.conj()) + + def test_mode_key(self): + def vmap_f(x): + return x + torch.randn(()) + + def naive_f(x, shape): + return x + torch.randn(shape) + + torch.manual_seed(0) + out1 = vmap(vmap(vmap_f, randomness='different'), randomness='different')(torch.ones(2, 3)) + + torch.manual_seed(0) + out2 = naive_f(torch.ones(2, 3), (2, 3)) + self.assertEqual(out1, out2) + + torch.manual_seed(0) + out1 = vmap(vmap(vmap_f, randomness='different'), randomness='different')(torch.ones(2, 3, 4)) + + torch.manual_seed(0) + out2 = naive_f(torch.ones(2, 3, 4), (2, 3, 1)) + self.assertEqual(out1, out2) + + self.assertTrue(torch.randn(()).dim() == 0) + + @parametrize('in_dim', [0, 1, 2]) + @parametrize('out_dim', [0, 1, 2]) + @parametrize('randomness', ['error', 'same']) + def test_chunk_vmap(self, in_dim, out_dim, randomness): + + x = torch.randn(4, 5, 6) + + def f(x): + y = x.sin() + if randomness != "error": + y = y + torch.rand_like(x) + return y + + rs = torch.get_rng_state() + expected = vmap(f, in_dims=in_dim, out_dims=out_dim, randomness=randomness)(x) + + for chunks in [1, 2, 3, 4, 7, 10, 16]: + torch.set_rng_state(rs) + output = chunk_vmap( + f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks + )(x) + self.assertEqual(output, expected) + + +instantiate_parametrized_tests(TestVmapOperators) + + +def construct_v(output, batch_size): + return torch.randn(batch_size, *output.shape, + dtype=output.dtype, device=output.device) + + +def as_tuple(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return x, + + +def differentiable(args): + return tuple(arg for arg in as_tuple(args) + if isinstance(arg, torch.Tensor) and arg.requires_grad) + + +def _get_rand_no_zeros(*args, **kwargs): + requires_grad = kwargs.get('requires_grad', False) + kwargs_without_requires_grad = kwargs.copy() + kwargs_without_requires_grad['requires_grad'] = False + result = torch.rand(*args, **kwargs_without_requires_grad) + return result.clamp_min_(0.1).requires_grad_(requires_grad) + + +class TestVmapBatchedGradient(Namespace.TestVmapBase): + def _vmap_test(self, *args, **kwargs): + return _vmap_test(self, *args, **kwargs) + + # Tests batched gradient computation of outputs = op(*args, **kwargs) + # by comparing it to a sequential map+stack fallback. + # + # output_process_fn: a function that maps the outputs to the part + # that should be differentiated. + # batch_size: the batch dim size for the batched grad + def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} + outputs = op(*args, **kwargs) + outputs = differentiable(output_process_fn(outputs)) + batched_vectors = tuple(construct_v(out, batch_size) for out in outputs) + + def vector_jacobian_product(*vectors): + return torch.autograd.grad(outputs, differentiable(args), vectors, + retain_graph=True) + self._vmap_test(vector_jacobian_product, batched_vectors, + check_propagates_grad=False) + + # Tests batched second grad computation of outputs = op(*args, **kwargs). + # by comparing it to a sequential map+stack fallback. + # + # output_process_fn: a function that maps the outputs to the part + # that should be differentiated. + # batch_size: the batch dim size for the batched grad + # + # NB: we only test computing batched gradients in the second gradient + # computation. One specific use case that does this is computing the hessian + # matrix of a scalar-valued function; this is useful in Bayesian Logistic + # Regression. + # It might be useful to have a test that computes batched first gradients and + # then uses those to compute batched second gradients in the future. + def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} + outputs = op(*args, **kwargs) + outputs = differentiable(output_process_fn(outputs)) + ones = tuple(torch.ones_like(out) for out in outputs) + # Same thing as summing together all of the outputs and calling .backward() + first_grads = torch.autograd.grad(outputs, differentiable(args), ones, + create_graph=True) + first_grads = differentiable(first_grads) + self.assertNotEqual( + len(first_grads), 0, "None of the first grads depend on the input!") + + batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads) + + def vector_hessian_product(*vectors): + outputs = torch.autograd.grad(first_grads, differentiable(args), vectors, + retain_graph=True, allow_unused=True) + outputs = tuple(out for out in outputs if out is not None) + assert len(outputs) > 0 + return outputs + + self._vmap_test(vector_hessian_product, batched_vectors, + check_propagates_grad=False) + + def _test_arithmetic(self, op, device, test_grad_grad=True): + x = torch.randn(2, 3, requires_grad=True, device=device) + y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) + scalar = 3.14 + self._batched_grad_test(op, (x, y)) + self._batched_grad_test(op, (scalar, y)) + self._batched_grad_test(op, (x, scalar)) + + if test_grad_grad: + self._batched_grad_grad_test(op, (x, y)) + + def test_add(self, device): + self._test_arithmetic(torch.add, device, test_grad_grad=False) + self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False) + + def test_sub(self, device): + self._test_arithmetic(torch.sub, device, test_grad_grad=False) + self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False) + + def test_mul(self, device): + self._test_arithmetic(torch.mul, device) + self._test_arithmetic(lambda x, y: x * y, device) + + def test_div(self, device): + self._test_arithmetic(torch.div, device) + self._test_arithmetic(lambda x, y: x / y, device) + + def test_binary_cross_entropy(self, device): + x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) + target = torch.rand(3, 2, device=device) + + op = functools.partial(F.binary_cross_entropy, target=target) + + self._batched_grad_test(op, (x,), {}) + self._batched_grad_grad_test(op, (x,), {}) + + def test_log_softmax(self, device): + op = functools.partial(torch.log_softmax, dim=-1) + x = torch.randn(3, 2, device=device, requires_grad=True) + + self._batched_grad_test(op, (x,), {}) + self._batched_grad_grad_test(op, (x,), {}) + + def test_expand(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + + def op(x): + return x.expand(5, 5, 2, 3) + self._batched_grad_test(op, (x,)) + + @allowVmapFallbackUsage + def test_index(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + index = torch.tensor([[0, 0], [1, 1]], device=device) + + def op(x): + y = x * x + return y[index] + + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) + + def test_lgamma(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(Tensor.lgamma, (x,)) + self._batched_grad_grad_test(Tensor.lgamma, (x,)) + + def test_log(self, device): + x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) + self._batched_grad_test(torch.log, (x,)) + self._batched_grad_grad_test(torch.log, (x,)) + + def test_logsumexp(self, device): + x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) + + def op(x): + return torch.logsumexp(x, -1) + + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) + + def test_log1p(self, device): + x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) + self._batched_grad_test(torch.log1p, (x,)) + self._batched_grad_grad_test(torch.log1p, (x,)) + + @allowVmapFallbackUsage + def test_max(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.max, (x,)) + + @allowVmapFallbackUsage + def test_median(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.median, (x,)) + + @allowVmapFallbackUsage + def test_min(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.min, (x,)) + + def test_permute(self, device): + x = torch.randn(2, 3, 5, requires_grad=True, device=device) + + def op(x): + return x.permute(2, 0, 1) + + self._batched_grad_test(op, (x,)) + + def test_reshape(self, device): + x = torch.randn(2, 3, 5, requires_grad=True, device=device) + + def op(x): + return x.reshape([2 * 3, 5]) + + self._batched_grad_test(op, (x,)) + + def test_sigmoid(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(Tensor.sigmoid, (x,)) + self._batched_grad_grad_test(Tensor.sigmoid, (x,)) + + def test_stack(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + y = torch.randn(2, 3, device=device, requires_grad=True) + + def op(x, y): + return torch.stack([x, y]) + self._batched_grad_test(op, (x, y)) + + def test_select(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + self._batched_grad_test(lambda x: x[1], (x,)) + self._batched_grad_test(lambda x: x.select(1, 2), (x,)) + self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) + + def test_slice(self, device): + x = torch.randn(2, 3, 5, device=device, requires_grad=True) + self._batched_grad_test(lambda x: x[0:1], (x,)) + self._batched_grad_test(lambda x: x[:, 1:3], (x,)) + self._batched_grad_test(lambda x: x[..., 1:3], (x,)) + + def test_trace(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + self._batched_grad_test(Tensor.trace, (x,)) + + x = torch.randn(3, 2, 2, device=device) + + def sum_grad_trace(x): + return grad(torch.trace)(x).sum() + + output = vmap(grad(sum_grad_trace))(x) + self.assertEqual(output, torch.zeros_like(output)) + + def test_where(self, device): + x = torch.randn(3, 2, device=device) + y = torch.ones(3, 2, device=device) + + def f(x, y): + return torch.where(x > 0, x, y) + + # Check that there is no runtime error, exactness tests are done with opinfo + vmap(f)(x, y) + + x = torch.randint(0, 2, size=(4, 3), dtype=torch.float) + + def f(t): + return torch.where(t) + + with self.assertRaisesRegex(RuntimeError, r"Attempted to vmap over aten::where"): + vmap(f)(x) + + @skipCUDAIfNoMagma + @allowVmapFallbackUsage + def test_symeig(self, device): + def op(x): + return torch.symeig(x, eigenvectors=True)[0] + + x = torch.randn(3, 3, device=device, requires_grad=True) + self._batched_grad_test(op, (x,), {}) + self._batched_grad_grad_test(op, (x,), {}) + + def test_threshold(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) + + @allowVmapFallbackUsage + def test_inplace_view(self, device): + leaf = torch.randn(4, 5, requires_grad=True) + + def func(leaf): + # Make sure the function is non-trivially twice differentiable + base = leaf * leaf + view = base[0] + view.cos_() + return view + + self._batched_grad_test(func, (leaf,), {}) + self._batched_grad_grad_test(func, (leaf,), {}) + + @allowVmapFallbackUsage + def test_inplace_manyview(self, device): + leaf = torch.randn(4, 4, 5, requires_grad=True) + + def func(leaf): + # Make sure the function is non-trivially twice differentiable + base = leaf * leaf + view = base.transpose(0, 2) + view = view[1] + view = view.diagonal() + view = view[::2] + view.cos_() + return view + + self._batched_grad_test(func, (leaf,), {}) + self._batched_grad_grad_test(func, (leaf,), {}) + + def test_diagonal(self, device): + x = torch.randn(4, 5, device=device, requires_grad=True) + self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,)) + + x = torch.randn(3, 4, 5, device=device, requires_grad=True) + self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) + + @allowVmapFallbackUsage + def test_unrelated_output(self, device): + B0 = 3 + x = torch.randn([], requires_grad=True) + y = torch.randn([], requires_grad=True) + gy = torch.randn(B0, requires_grad=True) + + def vjp(v): + res, = torch.autograd.grad(y, x, v, allow_unused=True) + return torch.zeros_like(x) if res is None else res + + result = vmap(vjp)(gy) + self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) + + @allowVmapFallbackUsage + def test_unrelated_output_multiple_grad(self, device): + B0 = 3 + x = torch.randn([], requires_grad=True) + y = torch.randn([], requires_grad=True) + gy = torch.randn(B0, requires_grad=True) + + def vjp(v): + res, = torch.autograd.grad(y, x, v, allow_unused=True) + return torch.zeros_like(x) if res is None else res + + _ = vjp(gy[0]) + result = vmap(vjp)(gy) + self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) + + +class TestVmapOperatorsOpInfo(TestCase): + vmap_fail = { + # These are things that we either cannot fix or are not actually problems + xfail('resize_'), + xfail('resize_as_'), + xfail('to_sparse'), + xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency + xfail('masked_select'), # dynamic op + xfail('nonzero'), # dynamic op + xfail('allclose'), # returns a boolean + xfail('rand_like'), # randomness is tested separately + xfail('randint_like'), # randomness is tested separately + xfail('randn_like'), # randomness is tested separately + xfail('bernoulli', ''), # randomness is tested separately + xfail('normal', ''), # randomness is tested separately + xfail('normal', 'number_mean'), # randomness is tested separately + xfail('multinomial', ''), # randomness + xfail('nn.functional.embedding', ''), # we only support some cases + xfail('nn.functional.rrelu'), # randomness + xfail('nn.functional.dropout2d', ''), # randomness + xfail('nn.functional.feature_alpha_dropout', 'with_train'), # randomness + xfail('as_strided'), # as_strided is too crazy + xfail('nn.functional.fractional_max_pool3d'), # randomness + xfail('nn.functional.fractional_max_pool2d'), # randomness + + # entries in here don't work and need to be fixed. + # Each one of these is a bug + xfail('view_as_complex'), + xfail('tensor_split'), + xfail('svd', device_type='cuda'), + xfail('linalg.svd', device_type='cuda'), + xfail('matrix_exp'), + xfail('histogramdd'), + xfail('nn.functional.gaussian_nll_loss'), + xfail('nn.functional.embedding_bag'), + xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617 + xfail('column_stack', ''), + xfail('pca_lowrank', ''), + xfail('svd_lowrank', ''), + skip('linalg.eigh', ''), # Flaky but is likely a real problem + + # required rank 4 tensor to use channels_last format + xfail('bfloat16'), + xfail('bool'), + xfail('byte'), + xfail('char'), + xfail('double'), + xfail('float'), + xfail('half'), + xfail('int'), + xfail('long'), + xfail('short'), + } + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', ( + tol1('linalg.det', + {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'), + # The following is often flaky, but just on windows. + # We should investigate if it's actually a problem or not. + tol1('nn.functional.conv_transpose3d', + {torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'), + )) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail) + def test_vmap_exhaustive(self, device, dtype, op): + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) + for sample_input in sample_inputs_itr: + arg_values = [sample_input.input] + list(sample_input.args) + kwarg_values = sample_input.kwargs + is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) + try: + generator = get_fallback_and_vmap_exhaustive( + op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + # empty_like and new_empty produce garbage values so we just check the shapes. + if op.name == 'empty_like' or op.name == 'new_empty': + self.assertEqual(loop_out.shape, batched_out.shape) + continue + self.assertEqual(loop_out, batched_out) + for a_op in op.aliases: + a_generator = get_fallback_and_vmap_exhaustive( + a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in a_generator: + self.assertEqual(loop_out, batched_out) + # todo(chilli): Garbage hack I added to deal with indexing not working + except Exception as e: + # Checking if we're throwing an error because of dynamic shapes. + if "dynamic" in e.args[0]: + continue + raise e + + @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', ( + tol1('linalg.det', + {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'), + )) + @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) + @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({ + xfail('complex'), + xfail('copysign'), + xfail('eig'), + xfail('histogram'), + xfail('index_fill'), + xfail('nansum'), + xfail('nanmean'), + # `index_put` OpInfo in pytorch/pytorch has + # masked index as input which is not supported + xfail('index_put', ''), + xfail('isin'), + xfail('linalg.cholesky'), + xfail('linalg.eigvals'), + xfail('linalg.eigvalsh'), + xfail('linalg.inv'), + xfail('linalg.lstsq'), + xfail('linalg.lstsq', 'grad_oriented'), + xfail('linalg.matrix_norm'), + xfail('linalg.matrix_power'), + xfail('linalg.matrix_rank'), + xfail('linalg.matrix_rank', 'hermitian'), + xfail('linalg.pinv'), + xfail('linalg.pinv', 'hermitian'), + xfail('linalg.norm'), + xfail('linalg.solve'), + xfail('linalg.tensorinv'), + xfail('lu_solve'), + xfail('lu_unpack'), + xfail('masked_fill'), + xfail('masked_scatter'), + xfail('masked_select'), + xfail('nanquantile'), + xfail('ormqr'), + xfail('put'), + xfail('quantile'), + xfail('renorm'), + xfail('resize_as_'), + xfail('take'), + xfail('tensor_split'), + xfail('to_sparse'), + xfail('vdot'), + xfail('__getitem__', ''), + xfail('all'), + xfail('any'), + xfail('count_nonzero'), + xfail('nanmean'), + xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency + xfail('resize_'), + xfail('view_as_complex'), + xfail('matrix_exp'), + xfail('bucketize'), + xfail('fft.ihfft2'), + xfail('fft.ihfftn'), + xfail('allclose'), + xfail('argwhere'), + xfail('linalg.cross'), + xfail('unique_consecutive'), + xfail('unique'), + xfail('nn.functional.ctc_loss'), + xfail('nn.functional.gaussian_nll_loss'), + xfail('nn.functional.huber_loss'), + # We can get this to work on CUDA through decomposition, + # but fails on CPU due to max_pool1d_cpu not having a derivative + xfail('nn.functional.max_pool1d'), + xfail('nn.functional.max_pool3d'), + xfail('histc'), + xfail('as_strided'), + xfail('istft'), + xfail('nonzero'), + xfail('nn.functional.fractional_max_pool2d'), + xfail('stft'), + xfail('linalg.solve_triangular'), + xfail('isclose'), + xfail('nn.functional.fractional_max_pool3d'), + xfail('nn.functional.bilinear'), + xfail('nn.functional.embedding_bag'), + xfail('linalg.tensorsolve'), + xfail('bernoulli', ''), + xfail('linalg.lu_factor', ''), + xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('nn.functional.kl_div', ''), + xfail('multinomial', ''), + xfail('column_stack', ''), + xfail('pca_lowrank', ''), + xfail('normal', ''), + xfail('nn.functional.dropout2d', ''), + xfail('normal', 'number_mean'), + xfail('svd_lowrank', ''), + xfail('linalg.lu_factor_ex', ''), + xfail('diagflat', ''), + xfail('special.log_ndtr'), + xfail('nn.functional.triplet_margin_loss', ''), + xfail('nn.functional.pdist', ''), + xfail('scatter_reduce', 'sum'), + xfail('nn.functional.smooth_l1_loss', ''), + xfail('scatter_reduce', 'amax'), + xfail('nn.functional.max_unpool1d', 'grad'), + xfail('nn.functional.multi_margin_loss', ''), + xfail('linalg.norm', 'subgradients_at_zero'), + xfail('scatter_reduce', 'prod'), + xfail('nn.functional.multilabel_margin_loss', ''), + xfail('scatter_reduce', 'amin'), + xfail('nn.functional.max_unpool3d', 'grad'), + xfail('nn.functional.max_unpool2d', ''), + xfail('nn.functional.max_unpool2d', 'grad'), + xfail('nn.functional.margin_ranking_loss', ''), + xfail('nn.functional.max_unpool1d', ''), + xfail('nn.functional.soft_margin_loss', ''), + xfail('scatter_reduce', 'mean'), + xfail('nn.functional.max_unpool3d', ''), + })) + def test_op_has_batch_rule(self, device, dtype, op): + def test(): + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) + for sample_input in sample_inputs_itr: + arg_values = [sample_input.input] + list(sample_input.args) + kwarg_values = sample_input.kwargs + is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values) + generator = get_fallback_and_vmap_exhaustive( + op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in generator: + # empty_like and new_empty produce garbage values so we just check the shapes. + if op.name == 'empty_like' or op.name == 'new_empty': + self.assertEqual(loop_out.shape, batched_out.shape) + continue + self.assertEqual(loop_out, batched_out) + for a_op in op.aliases: + a_generator = get_fallback_and_vmap_exhaustive( + a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training) + for loop_out, batched_out in a_generator: + self.assertEqual(loop_out, batched_out) + check_vmap_fallback(self, test, op) + + def test_conv_double_backward(self, device): + images = torch.randn(2, 1, 5, 5, device=device) + weight = torch.randn(2, 1, 2, 2, device=device) + bias = torch.randn(2, device=device) + ggI = torch.randn_like(images) + ggW = torch.randn_like(weight) + ggb = torch.randn_like(bias) + stride = (1, 1) + padding = (0, 0) + dilation = (1, 1) + transposed = False + output_padding = (0, 0) + groups = 1 + output_mask = (True, True, True) + gO = torch.randn_like(F.conv2d(images, weight, bias, stride, padding, dilation, groups)) + + args = ( + ggI, ggW, ggb, gO, weight, images, stride, padding, dilation, + transposed, output_padding, groups, output_mask, + ) + op = torch.ops.aten._convolution_double_backward + + generator = get_fallback_and_vmap_exhaustive(op, args, {}) + + def test(): + for loop_out, batched_out in generator: + self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4) + + check_vmap_fallback(self, test, op) + + def test_isnan(self, device): + test = functools.partial(_vmap_test, check_propagates_grad=False) + + B, N, C, H, W = 2, 3, 24, 5, 7 + op = torch.isnan + + x = torch.randn(B, N, C, H, W) + x[x > 0] = float('nan') + test(self, op, (x,), in_dims=(0)) + + def test_isinf(self, device): + test = functools.partial(_vmap_test, check_propagates_grad=False) + + B, N, C, H, W = 2, 3, 24, 5, 7 + op = torch.isinf + + x = torch.randn(B, N, C, H, W) + x[x > 0] = float('inf') + test(self, op, (x,), in_dims=(0)) + + def test_foo_like(self, device): + # vfdev-5: Probably, we can remove this line. Flake8 reported as unused + # test = functools.partial(_vmap_test, check_propagates_grad=False) + + B, N, C, H, W = 2, 3, 24, 5, 7 + for op in [torch.ones_like, torch.zeros_like]: + x = torch.randn(B, N, C, H, W) + # todo(chilli): test these better + # Not testing correctness, just that they run + vmap(op, in_dims=(0,))(x,) + + def test_flatten(self, device): + test = functools.partial(_vmap_test, check_propagates_grad=False) + + op = torch.flatten + + x = torch.randn(2, 3, 4, 5) + test(self, op, (x, 1, 2), in_dims=(0, None, None)) + + def test_group_norm(self, device): + test = functools.partial(_vmap_test, check_propagates_grad=False) + + B, N, C, H, W = 2, 3, 24, 5, 7 + op = F.group_norm + + x = torch.randn(B, N, C, H, W) + weight = torch.randn(C) + bias = torch.randn(C) + test(self, op, (x, 3, weight, bias), in_dims=(0, None, None, None)) + + x = torch.randn(B, N, C, H, W) + weight = torch.randn(B, C) + bias = torch.randn(B, C) + test(self, op, (x, 4, weight, bias), in_dims=(0, None, 0, 0)) + + def test_index_put(self, device): + def test(f, t, idx, values): + base = f(t[0], idx[0], values[0]) + self.assertEqual(vmap(f, in_dims=(0, 0, 0))(t, idx, values)[0], base) + self.assertEqual(vmap(f, in_dims=(0, None, None))(t, idx[0], values[0])[0], base) + self.assertEqual(vmap(f, in_dims=(0, None, 0))(t, idx[0], values)[0], base) + self.assertEqual(vmap(f, in_dims=(0, 0, None))(t, idx, values[0])[0], base) + + def f(x, y, z): + x[y] = z + return x + + x = torch.randn(3, 4, 5, device=device) + y = torch.zeros((3, 2), device=device).long() + z = torch.randn(3, 2, 5, device=device) + test(f, x, y, z) + + # indexing innermost dim + def f(t, idx, values): + t[:, idx] = values + return t + + t = torch.zeros((3, 2, 3)) + values = torch.ones((3, 1, 2)) + idx = torch.tensor([[1, 2]]).expand((3, 2)) + test(f, t, idx, values) + + # indexing middle dim + def f(t, idx, values): + t[:, idx, :] = values + return t + + t = torch.zeros((3, 2, 3, 3)) + values = torch.ones((3, 1, 2, 3)) + idx = torch.tensor([[0, 2]]).expand((3, 2)) + test(f, t, idx, values) + + # indexing with slices + def f(t, values): + t[:, :2, :] = values + return t + + base = f(t[0], values[0]) + self.assertEqual(vmap(f, in_dims=(0, 0))(t, values)[0], base) + self.assertEqual(vmap(f, in_dims=(0, None))(t, values[0])[0], base) + + # index_put_ + tensor = torch.zeros(3, 3, 4) + value = torch.ones(3, 2) + idxs = (torch.tensor([[0], [1], [2]]), torch.tensor([[0]]), torch.tensor([1, 2])) + expected = torch.index_put_(tensor.clone(), idxs, value) + + def f(t, idx, v): + torch.index_put_(t, idx, v) + return t + + self.assertEqual(vmap(f, in_dims=(0, (None, None), 0))(tensor, idxs[1:], value), expected) + self.assertEqual(vmap(f, in_dims=(0, (None, None), None))(tensor, idxs[1:], value[0]), expected) + + @parametrize('training', [True, False]) + @parametrize('track_running_stats', [True, False]) + @parametrize('affine', [True, False]) + def test_batch_norm(self, device, affine, track_running_stats, training): + if not track_running_stats and not training: + return + + test = functools.partial(_vmap_test, check_propagates_grad=False) + BN = torch.nn.BatchNorm2d + ensemble_size = 10 + hidden_dim = 3 + + weights, buffers, _, _, _ = \ + functional_init_with_buffers(BN, [ensemble_size])( + hidden_dim, affine=affine, track_running_stats=track_running_stats) + + inputs = [torch.randn(ensemble_size, 32, hidden_dim, 16, 16, device=device)] + in_dims = [0] + + def append(inp, in_dim): + inputs.append(inp) + in_dims.append(in_dim) + + if track_running_stats: + running_mean, running_var, _ = buffers + append(running_mean.to(device), 0) + append(running_var.to(device), 0) + else: + append(None, None) + append(None, None) + + if affine: + weight, bias = weights + append(weight.to(device), 0) + append(bias.to(device), 0) + else: + append(None, None) + append(None, None) + + append(training, None) + + def op(inp, running_mean, running_var, weight, bias, training): + res = F.batch_norm(inp, running_mean, running_var, weight, bias, training) + if track_running_stats: + return res, running_mean, running_var + return res + + test(self, op, tuple(inputs), in_dims=tuple(in_dims)) + + def test_torch_return_types_returns(self, device): + t = torch.randn(3, 2, 2, device=device) + self.assertTrue(isinstance(vmap(torch.min, (0, None))(t, 0), torch.return_types.min)) + self.assertTrue(isinstance(vmap(torch.max, (0, None))(t, 0), torch.return_types.max)) + self.assertTrue(isinstance(vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk)) + self.assertTrue(isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig)) + + def test_namedtuple_returns(self, device): + Point = namedtuple('Point', ['x', 'y']) + + def f(x, y): + return Point(x=x, y=y) + + x = torch.randn(2, 5, device=device) + y = torch.randn(2, 3, device=device) + self.assertTrue(isinstance(vmap(f)(x, y), Point)) + + def test_advanced_indexing(self, device): + def test(f, args): + for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}): + self.assertEqual(loop_out, batched_out) + + def f(x, idx): + return x[:, idx] + + def f2(x, idx): + return x[idx, :] + + def f3(x, idx): + return x[:, :, idx] + + inps = (torch.randn(5, 5, 5, device=device), + torch.randn(5, 5, 5, 5, device=device), + torch.randn(5, 5, 5, 5, 5, device=device)) + idxes = (torch.tensor([0, 1, 2], device=device), + torch.tensor([0, 1, 2], device=device).reshape(3, 1), + torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1)) + for (inp, idx) in itertools.product(inps, idxes): + test(f, (inp, idx)) + test(f2, (inp, idx)) + test(f3, (inp, idx)) + + def test_nested_advanced_indexing(self, device): + e = torch.rand(7, 4, device=device) + idx = torch.tensor([0, 1], device=device).view(2, 1) + + # simple reference implementation for comparison + def _fake_vmap(f, in_dims=0, out_dims=0): + def w(input): + r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))] + return torch.stack(r, out_dims) + + return w + + def with_vmap(_vmap): + def g(idx_): + def f(e_): + return e_[idx_] + + return _vmap(f, in_dims=1)(e) + + r = _vmap(g)(idx) + return r + + a = with_vmap(vmap) + b = with_vmap(_fake_vmap) + self.assertEqual(a, b) + + +class TestRandomness(TestCase): + def _reset_random(self, generator, orig_state, use_generator, seed): + return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed) + + def _get_image(self, batched_input, batch_size, device): + if batched_input == "first": + return torch.ones([batch_size, 3, 3, 14, 14], device=device) + if batched_input == "last": + return torch.ones([3, 3, 14, 14, batch_size], device=device) + assert batched_input == "none" + return torch.ones([3, 3, 14, 14], device=device) + + def _assert_all_slices_equal(self, tensor): + expected = tensor[0] + self.assertTrue((tensor == expected).all()) + + def _assert_all_slices_unique(self, tensor): + B0 = tensor.shape[0] + slices_equal = vmap(vmap(lambda x, y: (x == y).all(), (0, None)), (None, 0))(tensor, tensor) + assert slices_equal.shape == (B0, B0) + slices_equal.diagonal().zero_() + self.assertEqual(slices_equal, torch.zeros_like(slices_equal)) + + def _assert_throws_in_error_mode(self, fn, args, in_dims): + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(fn, in_dims=in_dims, randomness="error")(*args) + + def _assert_throws_in_different_mode_inplace(self, fn, args, in_dims): + with self.assertRaisesRegex(RuntimeError, r"different inplace randomness on an unbatched tensor"): + vmap(fn, in_dims=in_dims, randomness="different")(*args) + + def _assert_throws_in_same_mode_batched(self, fn, args, in_dims): + with self.assertRaisesRegex(RuntimeError, + r"Vmap does not currently support same randomness with a batched tensor input"): + vmap(fn, in_dims=in_dims, randomness="same")(*args) + + def _in_dims(self, *batched_strings): + + def get_in_dim(batched_string): + if batched_string == "first": + return 0 + if batched_string == "last": + return -1 + assert batched_string == "none" + return None + + batched_strings = batched_strings + ("first",) # for the always batched as first dim dummy argument + return tuple(get_in_dim(batched_string) for batched_string in batched_strings) + + @parametrize('randomness', ['same', 'different', 'error']) + @parametrize('use_generator', [True, False]) + def test_factory_ops(self, device, randomness, use_generator): + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device} + ops = [ + lambda _, shape: torch.randn(shape, **kwargs), + lambda _, shape: torch.rand(shape, **kwargs), + lambda _, shape: torch.randint(100, shape, **kwargs), + lambda _, shape: torch.randint(5, 100, shape, **kwargs), + lambda _, shape: torch.normal(0., 1., shape, **kwargs), + ] + B0 = 4 + shape = (3, 3) + seed = 1234567 + + for op in ops: + passed = torch.randn(B0, device=device) + if randomness == 'error': + self._assert_throws_in_error_mode(op, (passed, shape), in_dims=(0, None)) + return + + generator = self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=(0, None), randomness=randomness)(passed, shape) + + generator = self._reset_random(generator, orig_state, use_generator, seed) + if randomness == "different": + expected = op(passed, [B0, *shape]) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + expected = op(passed, shape) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('randomness', ['same', 'different', 'error']) + @parametrize('use_generator', [True, False]) + def test_randperm(self, device, randomness, use_generator): + # needs a special case because randperm doesn't take a batch size + B0 = 4 + seed = 1234567 + passed = torch.randn(B0, device=device) + + torch.manual_seed(seed) + generator = torch.Generator(device=device) + orig_state = generator.get_state() + + kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device} + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed) + return + + vmap_result = vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed) + generator = generator.set_state(orig_state) + torch.manual_seed(seed) + if randomness == 'different': + for i in range(B0): + expected = torch.randperm(10, **kwargs) + self.assertEqual(vmap_result[i], expected) + else: + expected = torch.randperm(10, **kwargs) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_dropout(self, device, randomness, batched_input): + def op(t, ignored): + return torch.nn.functional.dropout(torch.ones_like(t), training=True) + + B0 = 4 + always_batched = torch.randn((B0,)) + passed = self._get_image(batched_input, B0, device) + in_dims = self._in_dims(batched_input) + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + return + + vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + + # Check that the randomness is within bounds... + # ideally this is close to 0.5 + p_estimate = vmap_result.mean() / 2 + self.assertTrue(p_estimate < 0.75) + self.assertTrue(p_estimate > 0.25) + + if randomness == 'different': + self._assert_all_slices_unique(vmap_result) + return + + assert randomness == 'same' + self._assert_all_slices_equal(vmap_result) + + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_alpha_dropout(self, device, randomness, batched_input): + def op(t, ignored): + return torch.nn.functional.alpha_dropout(torch.ones_like(t), training=True) + + B0 = 4 + always_batched = torch.randn((B0,)) + passed = self._get_image(batched_input, B0, device) + in_dims = self._in_dims(batched_input) + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + return + + # I have no clue how to actually test corectness of alpha dropout because the docs + # seem wrong: https://github.com/pytorch/pytorch/issues/74004 + vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + if randomness == 'different': + self._assert_all_slices_unique(vmap_result) + return + + assert randomness == 'same' + self._assert_all_slices_equal(vmap_result) + + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + @parametrize('dim', [2, 3]) + def test_feature_dropout(self, device, randomness, batched_input, dim): + def op(t, ignored): + f = torch.nn.functional.dropout2d if dim == 2 else torch.nn.functional.dropout3d + return f(torch.ones_like(t), training=True) + + B0 = 4 + always_batched = torch.randn((B0,)) + passed = self._get_image(batched_input, B0, device) + if dim == 3: + unsqueeze_dim = -2 if batched_input == "last" else -1 + passed = passed.unsqueeze(unsqueeze_dim) + in_dims = self._in_dims(batched_input) + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + return + + vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + + # Check that the randomness is within bounds... + # ideally this is close to 0.5 + p_estimate = vmap_result.mean() / 2 + self.assertTrue(p_estimate < 0.75) + self.assertTrue(p_estimate > 0.25) + + # Check the "feature" pattern + dims = [-1, -2] if dim == 2 else [-1, -2, -3] + planes_numel = 2 * vmap_result.numel() / (vmap_result.shape[0] * vmap_result.shape[1] * vmap_result.shape[2]) + planes = vmap_result.sum(dims) + result = (planes == 0) ^ (planes == planes_numel) + self.assertEqual(result, torch.ones_like(result, dtype=torch.bool)) + + if randomness == 'different': + self._assert_all_slices_unique(vmap_result) + return + + assert randomness == 'same' + self._assert_all_slices_equal(vmap_result) + + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_feature_alpha_dropout(self, device, randomness, batched_input): + def op(t, ignored): + return torch.nn.functional.feature_alpha_dropout(torch.ones_like(t), training=True) + + B0 = 4 + always_batched = torch.randn((B0,)) + passed = self._get_image(batched_input, B0, device) + unsqueeze_dim = -2 if batched_input == "last" else -1 + passed = passed.unsqueeze(unsqueeze_dim) + in_dims = self._in_dims(batched_input) + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + return + + vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + + # I have no clue how to actually test corectness of alpha dropout because the docs + # seem wrong: https://github.com/pytorch/pytorch/issues/74004 + + # Check the "feature" pattern + dims = [-1, -2, -3] + planes = vmap_result.sum(dims) + max_elt = planes.max() + min_elt = planes.min() + result = (planes == min_elt) ^ (planes == max_elt) + self.assertEqual(result, torch.ones_like(result, dtype=torch.bool)) + + if randomness == 'different': + self._assert_all_slices_unique(vmap_result) + return + + assert randomness == 'same' + self._assert_all_slices_equal(vmap_result) + + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_like_functions(self, device, randomness, batched_input): + seed = 1234567 + supported_ops = [ + lambda t, _: torch.randint_like(t, 20), + lambda t, _: torch.randint_like(t, 0, 20), + lambda t, _: torch.rand_like(t), + lambda t, _: torch.randn_like(t), + ] + B0 = 4 + + for op in supported_ops: + always_batched = torch.randn(B0) + passed = self._get_image(batched_input, B0, device) + in_dims = self._in_dims(batched_input) + + if randomness == 'error': + with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"): + vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched) + return + + torch.manual_seed(seed) + vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched) + + torch.manual_seed(seed) + + if batched_input == "last": + passed = passed.movedim(-1, 0) + if randomness == 'different': + if batched_input == "none": + passed = passed.expand(B0, *passed.shape) + expected = op(passed, 0) + + self._assert_all_slices_unique(vmap_result) + self.assertEqual(expected, vmap_result) + return + + assert randomness == 'same' + if batched_input != "none": + passed = passed[0] + expected = op(passed, 0) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(expected, vmap_result[i]) + + @parametrize('use_generator', [True, False]) + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_random_unary_inplace(self, device, use_generator, randomness, batched_input): + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'generator': generator} if use_generator else {} + ops = [ + lambda t, _: t.random_(**kwargs), + lambda t, _: t.random_(100, **kwargs), + lambda t, _: t.random_(-5, 100, **kwargs), + lambda t, _: t.normal_(**kwargs), + lambda t, _: t.bernoulli_(**kwargs), + lambda t, _: t.cauchy_(**kwargs), + lambda t, _: t.exponential_(**kwargs), + lambda t, _: t.geometric_(0.5, **kwargs), + lambda t, _: t.log_normal_(**kwargs), + lambda t, _: t.uniform_(**kwargs), + ] + B0 = 4 + seed = 1234567 + in_dims = self._in_dims(batched_input) + + for op in ops: + # because of in place updates, clone inputs + always_batched = torch.randn(B0, device=device) + passed = self._get_image(batched_input, B0, device) + passed_expected = passed.clone() + + if randomness == 'error': + self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims) + return + if randomness == 'different' and batched_input == "none": + self._assert_throws_in_different_mode_inplace(op, (passed, always_batched), in_dims=in_dims) + return + + generator = self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched) + + if batched_input == "last": + passed_expected = passed_expected.movedim(-1, 0) + generator = self._reset_random(generator, orig_state, use_generator, seed) + if randomness == "different": + expected = op(passed_expected, always_batched) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + if batched_input != "none": + passed_expected = passed_expected[0].clone() # bug in pytorch, normal_ on views doesn't work + expected = op(passed_expected, always_batched) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('use_generator', [True, False]) + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + @parametrize('batched_probability', ["first", "last", "none"]) + def test_bernoulli_in_place(self, device, use_generator, randomness, batched_input, batched_probability): + B0 = 4 + seed = 1234567 + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'generator': generator} if use_generator else {} + in_dims = self._in_dims(batched_input, batched_probability) + + def op(t, p, ignored): + return t.bernoulli_(p, **kwargs) + + # because of in place updates, clone inputs + always_batched = torch.randn(B0, device=device) + input = self._get_image(batched_input, B0, device) + input_expected = input.clone() + probability = self._get_image(batched_probability, B0, device) - 0.5 + + if randomness == 'error': + self._assert_throws_in_error_mode(op, (input, probability, always_batched), in_dims=in_dims) + return + if randomness == 'same' and batched_probability != "none": + self._assert_throws_in_same_mode_batched(op, (input, probability, always_batched), in_dims=in_dims) + return + if batched_input == "none" and batched_probability != "none": + regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`" + with self.assertRaisesRegex(RuntimeError, regex): + vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched) + return + if randomness == 'different' and batched_input == "none": + self._assert_throws_in_different_mode_inplace(op, (input, probability, always_batched), in_dims=in_dims) + return + + self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched) + + self._reset_random(generator, orig_state, use_generator, seed) + if batched_input == "last": + input_expected = input_expected.movedim(-1, 0) + if batched_probability == "last": + probability = probability.movedim(-1, 0) + if randomness == "different": + expected = op(input_expected, probability, always_batched) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + if batched_input != "none": + input_expected = input_expected[0] + expected = op(input_expected, probability, always_batched) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('use_generator', [True, False]) + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + @parametrize('batched_other', ["first", "last", "none"]) + def test_random_binary_out_of_place(self, device, use_generator, randomness, batched_input, batched_other): + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'generator': generator} if use_generator else {} + ops = [ + lambda t, o, _: torch.normal(t, o, **kwargs), + lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs), + ] + + B0 = 4 + seed = 1234567 + in_dims = self._in_dims(batched_input, batched_other) + + for op in ops: + always_batched = torch.randn(B0, device=device) + input = self._get_image(batched_input, B0, device) + other = self._get_image(batched_other, B0, device) + + if randomness == 'error': + self._assert_throws_in_error_mode(op, (input, other, always_batched), in_dims=in_dims) + return + if randomness == 'same' and (batched_input != "none" or batched_other != "none"): + self._assert_throws_in_same_mode_batched(op, (input, other, always_batched), in_dims=in_dims) + return + + generator = self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, other, always_batched) + + if batched_input == "last": + input = input.movedim(-1, 0) + if batched_other == "last": + other = other.movedim(-1, 0) + + generator = self._reset_random(generator, orig_state, use_generator, seed) + if randomness == "different": + if batched_input == "none": + input = input.expand(B0, *input.shape) + expected = op(input, other, always_batched) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + assert batched_input == "none" and batched_other == "none" + expected = op(input, other, always_batched) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('use_generator', [True, False]) + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_input', ["first", "last", "none"]) + def test_random_unary_out_of_place(self, device, use_generator, randomness, batched_input): + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'generator': generator} if use_generator else {} + ops = [ + lambda t, _: torch.normal(0., torch.abs(t), **kwargs), + lambda t, _: torch.normal(t, 1., **kwargs), + lambda t, _: torch.bernoulli(t - 0.5, **kwargs), + lambda t, _: torch.bernoulli(t, 0.5, **kwargs), + lambda t, _: torch._standard_gamma(t, **kwargs), + lambda t, _: torch._sample_dirichlet(t, **kwargs), + lambda t, _: torch.poisson(t, **kwargs), + ] + + B0 = 4 + seed = 1234567 + in_dims = self._in_dims(batched_input) + + for op in ops: + always_batched = torch.randn(B0, device=device) + passed = self._get_image(batched_input, B0, device) + if randomness == 'error': + self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims) + return + if randomness == 'same' and batched_input != "none": + self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims) + return + + generator = self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched) + + generator = self._reset_random(generator, orig_state, use_generator, seed) + if randomness == "different": + if batched_input == "none": + passed = passed.expand(B0, *passed.shape) + if batched_input == "last": + passed = passed.movedim(-1, 0) + expected = op(passed, always_batched) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + expected = op(passed, always_batched) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + @parametrize('use_generator', [True, False]) + @parametrize('randomness', ['error', 'same', 'different']) + @parametrize('batched_call', [True, False]) + @parametrize('batched_input', ["first", "last", "none"]) + def test_multinomial(self, device, use_generator, randomness, batched_call, batched_input): + def flatten_input(input, batch_call, batch_location): + if batch_call and batch_location != "none": + final_size = 3 # [B0, B, N] + elif not batch_call and batch_location == "none": + final_size = 1 # [N] + else: + final_size = 2 # [B0, N] or [B, N] + + start_idx = final_size - 1 + end_idx = -1 + if batch_location == "last": + start_idx -= 1 + end_idx -= 1 # gets to correct final size because using negative indices + + ret = input.flatten(start_idx, end_idx) + assert ret.dim() == final_size + return ret + + def op(input, _): + return torch.multinomial(input, 10, **kwargs) + + generator = torch.Generator(device=device) + orig_state = generator.get_state() + kwargs = {'generator': generator} if use_generator else {} + + B0 = 4 + seed = 1234567 + in_dims = self._in_dims(batched_input) + + always_batched = torch.randn(B0, device=device) + passed = self._get_image(batched_input, B0, device) + passed = flatten_input(passed, batched_call, batched_input) + if randomness == 'error': + self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims) + return + if randomness == 'same' and batched_input != "none": + self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims) + return + + generator = self._reset_random(generator, orig_state, use_generator, seed) + vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched) + + generator = self._reset_random(generator, orig_state, use_generator, seed) + + if randomness == "different": + if batched_input == "none": + passed = passed.expand(B0, *passed.shape) + if batched_input == "last": + passed = passed.movedim(-1, 0) + orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1] + passed = passed.flatten(0, 1) if batched_call else passed + expected = op(passed, always_batched) + expected.reshape(*orig_passed_size, 10) + self._assert_all_slices_unique(vmap_result) + self.assertEqual(vmap_result, expected) + else: + expected = op(passed, always_batched) + self._assert_all_slices_equal(vmap_result) + for i in range(B0): + self.assertEqual(vmap_result[i], expected) + + def test_unsupported_random(self, device): + x = torch.randn(3, device=device) + y = x.abs() + z = x.abs() + with self.assertRaisesRegex(RuntimeError, "calling out variants"): + def f(x): + return torch.randn(3, device=device, out=y) + vmap(f, randomness='same')(x) + with self.assertRaisesRegex(RuntimeError, "calling out variants"): + def f(x0, x1): + return torch.normal(x, y, out=x) + vmap(f, randomness='same')(z, z) + with self.assertRaisesRegex(RuntimeError, "do not yet support"): + def f(z): + return torch.rrelu(x) + vmap(f, randomness='same')(z) + + @parametrize('in_dim', [0, 1, 2]) + @parametrize('out_dim', [0, 1, 2]) + def test_chunk_vmap(self, in_dim, out_dim): + + randomness = "different" + + x = torch.randn(4, 5, 6) + + def f(x): + y = x.sin() + torch.rand_like(x) + return y + + for chunks in [1, 2, 3, 4, 7, 10, 16]: + output = chunk_vmap( + f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks + )(x) + self._assert_all_slices_unique(output) + + +class TestTransformFailure(TestCase): + @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd']) + def test_fails_with_autograd_function(self, device, transform): + class Test(torch.autograd.Function): + @staticmethod + def forward(_, input): + return input + + @staticmethod + def backward(_, grad_input): + return grad_input + + transform = getattr(functorch, transform) + + def f(x): + return Test.apply(x) + + if transform == grad or transform == grad_and_value: + input = torch.tensor(4.) + else: + input = torch.randn(5) + + if transform == vjp: + transform = functools.partial(transform, f) + elif transform == jvp: + input = (input,) + transform = functools.partial(transform, f, input) + else: + transform = transform(f) + + with self.assertRaisesRegex(RuntimeError, "autograd.Function"): + transform(input) + + +only_for = ("cpu", "cuda") +instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for) + +instantiate_device_type_tests( + TestVmapBatchedGradient, + globals(), + only_for=only_for, +) +instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for) +instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for) + +if __name__ == '__main__': + run_tests() diff --git a/functorch/test/xfail_suggester.py b/functorch/test/xfail_suggester.py new file mode 100644 index 0000000000000..d9ddc95029585 --- /dev/null +++ b/functorch/test/xfail_suggester.py @@ -0,0 +1,142 @@ +import re +import torch + +""" +Instructions: + +1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_pythonkey.py > result.txt +2. python test/xfail_suggester.py +""" + +with open('result.txt') as f: + lines = f.readlines() + +failed = [line for line in lines if line.startswith('FAILED')] +p = re.compile('FAILED test/test_\w+.py::\w+::(\S+)') # noqa: W605 + + +def get_failed_test(line): + m = p.match(line) + if m is None: + return None + return m.group(1) + + +base_names = { + 'test_grad_', + 'test_vjp_', + 'test_vmapvjp_', + 'test_vmapvjp_has_batch_rule_', + 'test_vjpvmap_', + 'test_jvp_', + 'test_vmapjvp_', + 'test_vmapjvpall_has_batch_rule_', + 'test_vmapjvpall_', + 'test_jvpvjp_', + 'test_vjpvjp_', + 'test_decomposition_', + 'test_make_fx_exhaustive_', + 'test_vmap_exhaustive_', + 'test_op_has_batch_rule_', +} + +failed_tests = [get_failed_test(line) for line in lines] +failed_tests = [match for match in failed_tests if match is not None] +failed_tests = sorted(failed_tests) + +suggested_xfails = {} + + +def remove_device_dtype(test): + return '_'.join(test.split('_')[:-2]) + + +def belongs_to_base(test, base): + if not test.startswith(base): + return False + candidates = [try_base for try_base in base_names if len(try_base) > len(base)] + for candidate in candidates: + if test.startswith(candidate): + return False + return True + + +def parse_namespace(base): + mappings = { + 'nn_functional_': 'nn.functional', + 'fft_': 'fft', + 'linalg_': 'linalg', + '_masked_': '_masked', + } + for heading in mappings.keys(): + if base.startswith(heading): + return mappings[heading], base[len(heading):] + return None, base + + +def get_torch_module(namespace): + if namespace is None: + return torch + if namespace == 'nn.functional': + return torch.nn.functional + return getattr(torch, namespace) + + +def parse_base(base): + namespace, rest = parse_namespace(base) + + apis = dir(get_torch_module(namespace)) + apis = sorted(apis, key=lambda x: -len(x)) + + api = rest + variant = '' + for candidate in apis: + if rest.startswith(candidate): + api = candidate + variant = rest[len(candidate) + 1:] + break + print(base, namespace, api, variant) + return namespace, api, variant + + +def any_starts_with(strs, thing): + for s in strs: + if s.startswith(thing): + return True + return False + + +def get_suggested_xfails(base, tests): + result = [] + tests = [test[len(base):] for test in tests if + belongs_to_base(test, base)] + + base_tests = set([remove_device_dtype(test) for test in tests]) + tests = set(tests) + for base in base_tests: + cpu_variant = base + '_cpu_float32' + cuda_variant = base + '_cuda_float32' + namespace, api, variant = parse_base(base) + if namespace is None: + api = api + else: + api = f'{namespace}.{api}' + if cpu_variant in tests and cuda_variant in tests: + result.append(f"xfail('{api}', '{variant}'),") + continue + if cpu_variant in tests: + result.append(f"xfail('{api}', '{variant}', device_type='cpu'),") + continue + if cuda_variant in tests: + result.append(f"xfail('{api}', '{variant}', device_type='cuda'),") + continue + result.append(f"skip('{api}', '{variant}',") + return result + + +result = {base: get_suggested_xfails(base, failed_tests) for base in base_names} +for k, v in result.items(): + print('=' * 50) + print(k) + print('=' * 50) + print('\n'.join(v)) diff --git a/functorch/tools/lint/black_linter.py b/functorch/tools/lint/black_linter.py new file mode 100644 index 0000000000000..9d259fe096b84 --- /dev/null +++ b/functorch/tools/lint/black_linter.py @@ -0,0 +1,228 @@ +import argparse +import concurrent.futures +import json +import logging +import os +import subprocess +import sys +import time +from enum import Enum +from typing import Any, List, NamedTuple, Optional, BinaryIO + + +IS_WINDOWS: bool = os.name == "nt" + + +def eprint(*args: Any, **kwargs: Any) -> None: + print(*args, file=sys.stderr, flush=True, **kwargs) + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name + + +def _run_command( + args: List[str], + *, + stdin: BinaryIO, + timeout: int, +) -> "subprocess.CompletedProcess[bytes]": + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run( + args, + stdin=stdin, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=IS_WINDOWS, # So batch scripts are found. + timeout=timeout, + check=True, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def run_command( + args: List[str], + *, + stdin: BinaryIO, + retries: int, + timeout: int, +) -> "subprocess.CompletedProcess[bytes]": + remaining_retries = retries + while True: + try: + return _run_command(args, stdin=stdin, timeout=timeout) + except subprocess.TimeoutExpired as err: + if remaining_retries == 0: + raise err + remaining_retries -= 1 + logging.warning( + "(%s/%s) Retrying because command failed with: %r", + retries - remaining_retries, + retries, + err, + ) + time.sleep(1) + + +def check_file( + filename: str, + retries: int, + timeout: int, +) -> List[LintMessage]: + try: + with open(filename, "rb") as f: + original = f.read() + with open(filename, "rb") as f: + proc = run_command( + [sys.executable, "-mblack", "--stdin-filename", filename, "-"], + stdin=f, + retries=retries, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.ERROR, + name="timeout", + original=None, + replacement=None, + description=( + "black timed out while trying to process a file. " + "Please report an issue in pytorch/pytorch with the " + "label 'module: lint'" + ), + ) + ] + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.ADVICE, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + ] + + replacement = proc.stdout + if original == replacement: + return [] + + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.WARNING, + name="format", + original=original.decode("utf-8"), + replacement=replacement.decode("utf-8"), + description="Run `lintrunner -a` to apply this patch.", + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Format files with black.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--retries", + default=3, + type=int, + help="times to retry timed out black", + ) + parser.add_argument( + "--timeout", + default=90, + type=int, + help="seconds to wait for black", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit(check_file, x, args.retries, args.timeout): x + for x in args.filenames + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/functorch/tools/lint/flake8_linter.py b/functorch/tools/lint/flake8_linter.py new file mode 100644 index 0000000000000..20274432566c9 --- /dev/null +++ b/functorch/tools/lint/flake8_linter.py @@ -0,0 +1,373 @@ +import argparse +import json +import logging +import os +import re +import subprocess +import sys +import time +from enum import Enum +from typing import Any, Dict, List, NamedTuple, Optional, Set, Pattern + + +IS_WINDOWS: bool = os.name == "nt" + + +def eprint(*args: Any, **kwargs: Any) -> None: + print(*args, file=sys.stderr, flush=True, **kwargs) + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name + + +# fmt: off +# https://www.flake8rules.com/ +DOCUMENTED_IN_FLAKE8RULES: Set[str] = { + "E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117", + "E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129", + "E131", "E133", + "E201", "E202", "E203", + "E211", + "E221", "E222", "E223", "E224", "E225", "E226", "E227", "E228", + "E231", + "E241", "E242", + "E251", + "E261", "E262", "E265", "E266", + "E271", "E272", "E273", "E274", "E275", + "E301", "E302", "E303", "E304", "E305", "E306", + "E401", "E402", + "E501", "E502", + "E701", "E702", "E703", "E704", + "E711", "E712", "E713", "E714", + "E721", "E722", + "E731", + "E741", "E742", "E743", + "E901", "E902", "E999", + "W191", + "W291", "W292", "W293", + "W391", + "W503", "W504", + "W601", "W602", "W603", "W604", "W605", + "F401", "F402", "F403", "F404", "F405", + "F811", "F812", + "F821", "F822", "F823", + "F831", + "F841", + "F901", + "C901", +} + +# https://pypi.org/project/flake8-comprehensions/#rules +DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = { + "C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409", + "C410", + "C411", "C412", "C413", "C413", "C414", "C415", "C416", +} + +# https://github.com/PyCQA/flake8-bugbear#list-of-warnings +DOCUMENTED_IN_BUGBEAR: Set[str] = { + "B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010", + "B011", "B012", "B013", "B014", "B015", + "B301", "B302", "B303", "B304", "B305", "B306", + "B901", "B902", "B903", "B950", +} +# fmt: on + + +# stdin:2: W802 undefined name 'foo' +# stdin:3:6: T484 Name 'foo' is not defined +# stdin:3:-100: W605 invalid escape sequence '\/' +# stdin:3:1: E302 expected 2 blank lines, found 1 +RESULTS_RE: Pattern[str] = re.compile( + r"""(?mx) + ^ + (?P.*?): + (?P\d+): + (?:(?P-?\d+):)? + \s(?P\S+?):? + \s(?P.*) + $ + """ +) + + +def _test_results_re() -> None: + """ + >>> def t(s): return RESULTS_RE.search(s).groupdict() + + >>> t(r"file.py:80:1: E302 expected 2 blank lines, found 1") + ... # doctest: +NORMALIZE_WHITESPACE + {'file': 'file.py', 'line': '80', 'column': '1', 'code': 'E302', + 'message': 'expected 2 blank lines, found 1'} + + >>> t(r"file.py:7:1: P201: Resource `stdout` is acquired but not always released.") + ... # doctest: +NORMALIZE_WHITESPACE + {'file': 'file.py', 'line': '7', 'column': '1', 'code': 'P201', + 'message': 'Resource `stdout` is acquired but not always released.'} + + >>> t(r"file.py:8:-10: W605 invalid escape sequence '/'") + ... # doctest: +NORMALIZE_WHITESPACE + {'file': 'file.py', 'line': '8', 'column': '-10', 'code': 'W605', + 'message': "invalid escape sequence '/'"} + """ + pass + + +def _run_command( + args: List[str], + *, + extra_env: Optional[Dict[str, str]], +) -> "subprocess.CompletedProcess[str]": + logging.debug( + "$ %s", + " ".join( + ([f"{k}={v}" for (k, v) in extra_env.items()] if extra_env else []) + args + ), + ) + start_time = time.monotonic() + try: + return subprocess.run( + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8", + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def run_command( + args: List[str], + *, + extra_env: Optional[Dict[str, str]], + retries: int, +) -> "subprocess.CompletedProcess[str]": + remaining_retries = retries + while True: + try: + return _run_command(args, extra_env=extra_env) + except subprocess.CalledProcessError as err: + if remaining_retries == 0 or not re.match( + r"^ERROR:1:1: X000 linting with .+ timed out after \d+ seconds", + err.stdout, + ): + raise err + remaining_retries -= 1 + logging.warning( + "(%s/%s) Retrying because command failed with: %r", + retries - remaining_retries, + retries, + err, + ) + time.sleep(1) + + +def get_issue_severity(code: str) -> LintSeverity: + # "B901": `return x` inside a generator + # "B902": Invalid first argument to a method + # "B903": __slots__ efficiency + # "B950": Line too long + # "C4": Flake8 Comprehensions + # "C9": Cyclomatic complexity + # "E2": PEP8 horizontal whitespace "errors" + # "E3": PEP8 blank line "errors" + # "E5": PEP8 line length "errors" + # "F401": Name imported but unused + # "F403": Star imports used + # "F405": Name possibly from star imports + # "T400": type checking Notes + # "T49": internal type checker errors or unmatched messages + if any( + code.startswith(x) + for x in [ + "B9", + "C4", + "C9", + "E2", + "E3", + "E5", + "F401", + "F403", + "F405", + "T400", + "T49", + ] + ): + return LintSeverity.ADVICE + + # "F821": Undefined name + # "E999": syntax error + if any(code.startswith(x) for x in ["F821", "E999"]): + return LintSeverity.ERROR + + # "F": PyFlakes Error + # "B": flake8-bugbear Error + # "E": PEP8 "Error" + # "W": PEP8 Warning + # possibly other plugins... + return LintSeverity.WARNING + + +def get_issue_documentation_url(code: str) -> str: + if code in DOCUMENTED_IN_FLAKE8RULES: + return f"https://www.flake8rules.com/rules/{code}.html" + + if code in DOCUMENTED_IN_FLAKE8COMPREHENSIONS: + return "https://pypi.org/project/flake8-comprehensions/#rules" + + if code in DOCUMENTED_IN_BUGBEAR: + return "https://github.com/PyCQA/flake8-bugbear#list-of-warnings" + + return "" + + +def check_files( + filenames: List[str], + flake8_plugins_path: Optional[str], + severities: Dict[str, LintSeverity], + retries: int, +) -> List[LintMessage]: + try: + proc = run_command( + [sys.executable, "-mflake8", "--exit-zero"] + filenames, + extra_env={"FLAKE8_PLUGINS_PATH": flake8_plugins_path} + if flake8_plugins_path + else None, + retries=retries, + ) + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code="FLAKE8", + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.strip() or "(empty)", + stdout=err.stdout.strip() or "(empty)", + ) + ), + ) + ] + + return [ + LintMessage( + path=match["file"], + name=match["code"], + description="{}\nSee {}".format( + match["message"], + get_issue_documentation_url(match["code"]), + ), + line=int(match["line"]), + char=int(match["column"]) + if match["column"] is not None and not match["column"].startswith("-") + else None, + code="FLAKE8", + severity=severities.get(match["code"]) or get_issue_severity(match["code"]), + original=None, + replacement=None, + ) + for match in RESULTS_RE.finditer(proc.stdout) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Flake8 wrapper linter.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--flake8-plugins-path", + help="FLAKE8_PLUGINS_PATH env value", + ) + parser.add_argument( + "--severity", + action="append", + help="map code to severity (e.g. `B950:advice`)", + ) + parser.add_argument( + "--retries", + default=3, + type=int, + help="times to retry timed out flake8", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + flake8_plugins_path = ( + None + if args.flake8_plugins_path is None + else os.path.realpath(args.flake8_plugins_path) + ) + + severities: Dict[str, LintSeverity] = {} + if args.severity: + for severity in args.severity: + parts = severity.split(":", 1) + assert len(parts) == 2, f"invalid severity `{severity}`" + severities[parts[0]] = LintSeverity(parts[1]) + + lint_messages = check_files( + args.filenames, flake8_plugins_path, severities, args.retries + ) + for lint_message in lint_messages: + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + main() diff --git a/functorch/tools/lint/pip_init.py b/functorch/tools/lint/pip_init.py new file mode 100644 index 0000000000000..db1f69d26b227 --- /dev/null +++ b/functorch/tools/lint/pip_init.py @@ -0,0 +1,75 @@ +""" +Initializer script that installs stuff to pip. +""" +import os +import argparse +import logging +import subprocess +import sys +import time + +from typing import List + + +def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]": + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run(args, check=True) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="pip initializer") + parser.add_argument( + "packages", + nargs="+", + help="pip packages to install", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "--dry-run", help="do not install anything, just print what would be done." + ) + + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET if args.verbose else logging.DEBUG, + stream=sys.stderr, + ) + + for package in args.packages: + package_name, _, version = package.partition("=") + if version == "": + raise RuntimeError( + "Package {package_name} did not have a version specified. " + "Please specify a version to product a consistent linting experience." + ) + pip_args = ["pip3", "install"] + + # If we are in a global install, use `--user` to install so that you do not + # need root access in order to initialize linters. + # + # However, `pip install --user` interacts poorly with virtualenvs (see: + # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in + # these cases perform a regular installation. + in_conda = os.environ.get("CONDA_PREFIX") is not None + in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None + if not in_conda and not in_virtualenv: + pip_args.append("--user") + + pip_args.extend(args.packages) + + dry_run = args.dry_run == "1" + if dry_run: + print(f"Would have run: {pip_args}") + sys.exit(0) + + run_command(pip_args) diff --git a/functorch/version.txt b/functorch/version.txt new file mode 100644 index 0000000000000..c181bf5996673 --- /dev/null +++ b/functorch/version.txt @@ -0,0 +1 @@ +0.3.0a0 diff --git a/functorch/writing_batching_rules.md b/functorch/writing_batching_rules.md new file mode 100644 index 0000000000000..5f571c4170858 --- /dev/null +++ b/functorch/writing_batching_rules.md @@ -0,0 +1,98 @@ +So, you want to write some batching rules? This is the guide to get started :) + +First off, what are batching rules and why do we need so many of them? Well, to understand that, we need to understand how vmap works. + +### How does vmap work? +Vmap is a function transform (pioneered by Jax) that allows one to batch functions. That is, given a function `f(x: [N]) -> [N]`, `vmap(f)` now transforms the signature to be `f(x: [B, N]) -> [B, N]`. That is - it adds a batch dimension to both the input and the output of the function. + +This guide will gloss over all the cool things you can do this (there are many!), so let's focus on how we actually implement this. + +One misconception is that this is some magic compiler voodoo, or that it is inherent some function transform. It is not - and there's another framing of it that might make it more clear. + +Instead of providing `vmap`, imagine that we provide a `BatchedTensor` instead. This `BatchedTensor` wraps a `Tensor[B, N, M]`. *But*, to all the users of this tensor, it looks like a `Tensor[N, M]` (that is, without the `B` dimension). Then, when operations are done on this tensor, it transforms that operation to broadcast over the additional `B` dimension as well. + +For example, let's say that we wanted to sum a `BatchedTensor` with shape `[5]` - that is, `torch.sum(x)`. This would give us back a `BatchedTensor` with shape `[]` (i.e. a scalar tensor). **But**, in reality, we this is actually a `Tensor` with shape `[B]`. Instead of running `torch.sum(x: [5])`, we ran `torch.sum(x: [B, 5], dim=1)`. In other words, we transformed the sum operation so that instead of summing the whole tensor, it summed all the dimensions *except* the batch dimension. + +That is how `vmap` works. For every single operator, we define how to transform that operator to broadcast over an additional batch dimension. + +### Basic Batching Rule (unsqueeze) +Let's take a look at our batching rule API. For some reference, the function signature for unsqueeze is `unsqueeze(Tensor(a) self, int dim) -> Tensor(a)`. This can be found [here](functorch/csrc/BatchRulesViews.cpp). +``` +std::tuple> unsqueeze_batch_rule( + const Tensor& self, + optional self_bdim, + int64_t dim) { + auto self_ = moveBatchDimToFront(self, self_bdim); + auto rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, rank + 1) + 1; + return std::make_tuple(self_.unsqueeze(dim), 0); +} +``` +Now, let's look at each part individually. +``` +std::tuple> unsqueeze_batch_rule( + const Tensor& self, + optional self_bdim, + int64_t dim) { +``` +For the most part, the function signature for a batching rule is identical to the function signature for the operator. The only difference is that for each `Tensor` (both in the input and the output), we have an additional `optional`. This is the batch dimension. In the previous explanation, we implicitly assumed that the batch dimension was always at 0, but we allow for batch dimensions to be on arbitrary dimensions. The `optional` part reflects that not all tensors are batched - if a function takes multiple tensors then it's possible for only one of them to be a `BatchedTensor`. Note, however, that we guarantee that at least one tensor will always have a batch dimension. + +``` + auto self_ = moveBatchDimToFront(self, self_bdim); + auto rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, rank + 1) + 1; +``` +For `unsqueeze(x, dim)`, the strategy for the batching rule is pretty simple. We first move the batching dimension to the front. Then, instead of doing `unsqueeze(x, dim)`, we do `unsqueeze(x, dim + 1)` (since there's now an extra bdim). + +``` +return std::make_tuple(self_.unsqueeze(dim), 0); +``` +Now, we return a tuple of the tensor along with its batch dimension (which is now 0 since we moved it to the front). + +``` +VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule); +``` +Finally, we add support for it by using the `VMAP_SUPPORT` macro. + +You may need to use the `VMAP_SUPPORT2` macro if the operator has an overload name. + +### Implementing multiple batching rules with boxed fallbacks or templates +Often, we find that large classes of operators have similar patterns of batching rules. For example, every single pointwise op has a similar pattern. In that case, it's a bit ridiculous to separately write a batching rule for those situations. + +In those cases, we have 2 primary tools - templates and boxed fallbacks. For example, we've written a boxed fallback that covers many reductions (see the [reduction batching rules](functorch/csrc/BatchRulesReduceOps.cpp)). + +There are 3 primary boxed fallbacks that we've used (I'll refer to the macros here). If you feel that there's any pattern that we could/should abstract away, feel free to post an issue. + +1. `POINTWISE_BOXED`: Handles pointwise ops. Takes all tensors in the arguments, moves batch dimensions to the front, and unsqueezes all tensors so that they broadcast. +1. `REDUCTION_BOXED`: Handles reduction ops. Moves batch dimension to the front, and then modifies the dim argument so that it works with the extra batch dimension. For example, if the dim is an integer, then we add one. If it's a dimarray, then we add one to all entries (unless it's empty!, in which case we fill in all the entries except 0). +1. `VARIADIC_BDIMS_BOXED`: Handles ops that already natively support arbitrary batch dimensions. For example, if it supports `[B1,B2,..., N]`. In this case, we can simply move the batch dimension to the front and we're done! + +### Sidestepping batching rules by decomposing operators +Sometimes, it's difficult to implement a batching rule by transforming it into another operator. For example, `trace`. In that case, instead of transforming the operator, we can simply decompose it. + +``` +Tensor trace_decomp(const Tensor& self) { + return at::sum(at::diagonal(self)); +} +... +m.impl("trace", trace_decomp); +``` +In general, this reduces the performance, since instead of launching one kernel we're launching multiple. So, we generally try to avoid this option :) + +PS: There's a special class of operators in PyTorch called CompositeImplicitAutograd operators that can be decomposed without losing performance. We actually decompose those by default! (although we stop the decompositions that aren't tested in [BatchRulesStopDecomposition](functorch/csrc/BatchRulesStopDecomposition.cpp) for debuggability reasons). + +### Testing your batching rule +We generally use OpInfos to test our batching rules. OpInfos are great since they let us test the same operator in many different ways. + +In general, if the operator you've added a batching rule for has an OpInfo test, that's good enough! + +Generally, you can try running `pytest -k op_name` to use `pytest` to find all tests that test your operator. Sometimes, if your operator doesn't match the public API, you need to figure out the public API that corresponds to the operator you've implemented a batching rule for. For example, `torch.where` actually often executes `aten::_s_where` underneath. + +Todo: Add more relevant details @zou + +## Cool, I'm convinced! And I want to write batching rules! Where do I find some? +There's a couple different resources for finding batching rules to write. + +1. [BatchingRegistrations.cpp](functorch/csrc/BatchingRegistrations.cpp): This is probably the easiest place to start. These were batching rules that were written with an old API, and thus have a lot of cruft in them that are no longer necessary. Porting these batching rules to using one of the above options is an easy way to get started and help us reduce tech debt :) Once you've gotten your footing with writing batching rules, you can start helping with writing new batching rules. +2. Popular operators. See [1](https://github.com/facebookresearch/functorch/issues/112), [2](https://github.com/facebookresearch/functorch/issues/101), [3](https://github.com/facebookresearch/functorch/issues/102), and [4](https://github.com/facebookresearch/functorch/issues/102). These contain lists of (user-facing) PyTorch operators sorted by usages, along with whether they have a batching rule implemented or not. +3. [Master List](https://docs.google.com/spreadsheets/d/1Sp4HUjxwMifS5oDQg0yvjqk7hKOpCfKO4jWH4MTGP-k/edit#gid=0). This is the master list of vmap operator support :). It's generated by [this script](op_analysis/gen_data.py). Theoretically, we want to support most of the operators in that list (that aren't composite or out variants). diff --git a/ios/LibTorch-Lite.podspec b/ios/LibTorch-Lite.podspec index d2d9264e0a622..9814eaa367586 100644 --- a/ios/LibTorch-Lite.podspec +++ b/ios/LibTorch-Lite.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch-Lite' - s.version = '1.11.0' + s.version = '1.12.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec index 77bc0537e89ed..3c197f0f103b9 100644 --- a/ios/LibTorch.podspec +++ b/ios/LibTorch.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch' - s.version = '1.11.0' + s.version = '1.12.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' diff --git a/ios/TestApp/Gemfile.lock b/ios/TestApp/Gemfile.lock index d4407d469b174..3586770934fbc 100644 --- a/ios/TestApp/Gemfile.lock +++ b/ios/TestApp/Gemfile.lock @@ -1,61 +1,102 @@ GEM remote: https://rubygems.org/ specs: - CFPropertyList (3.0.2) + CFPropertyList (3.0.5) + rexml addressable (2.8.0) public_suffix (>= 2.0.2, < 5.0) + artifactory (3.0.15) atomos (0.1.3) - babosa (1.0.3) - claide (1.0.3) + aws-eventstream (1.2.0) + aws-partitions (1.601.0) + aws-sdk-core (3.131.2) + aws-eventstream (~> 1, >= 1.0.2) + aws-partitions (~> 1, >= 1.525.0) + aws-sigv4 (~> 1.1) + jmespath (~> 1, >= 1.6.1) + aws-sdk-kms (1.57.0) + aws-sdk-core (~> 3, >= 3.127.0) + aws-sigv4 (~> 1.1) + aws-sdk-s3 (1.114.0) + aws-sdk-core (~> 3, >= 3.127.0) + aws-sdk-kms (~> 1) + aws-sigv4 (~> 1.4) + aws-sigv4 (1.5.0) + aws-eventstream (~> 1, >= 1.0.2) + babosa (1.0.4) + claide (1.1.0) colored (1.2) colored2 (3.1.2) - commander-fastlane (4.4.6) - highline (~> 1.7.2) - declarative (0.0.10) - declarative-option (0.1.0) - digest-crc (0.4.1) + commander (4.6.0) + highline (~> 2.0.0) + declarative (0.0.20) + digest-crc (0.6.4) + rake (>= 12.0.0, < 14.0.0) domain_name (0.5.20190701) unf (>= 0.0.5, < 1.0.0) - dotenv (2.7.5) - emoji_regex (1.0.1) - excon (0.71.1) - faraday (0.17.3) - multipart-post (>= 1.2, < 3) - faraday-cookie_jar (0.0.6) - faraday (>= 0.7.4) + dotenv (2.7.6) + emoji_regex (3.2.3) + excon (0.92.3) + faraday (1.10.0) + faraday-em_http (~> 1.0) + faraday-em_synchrony (~> 1.0) + faraday-excon (~> 1.1) + faraday-httpclient (~> 1.0) + faraday-multipart (~> 1.0) + faraday-net_http (~> 1.0) + faraday-net_http_persistent (~> 1.0) + faraday-patron (~> 1.0) + faraday-rack (~> 1.0) + faraday-retry (~> 1.0) + ruby2_keywords (>= 0.0.4) + faraday-cookie_jar (0.0.7) + faraday (>= 0.8.0) http-cookie (~> 1.0.0) - faraday_middleware (0.13.1) - faraday (>= 0.7.4, < 1.0) - fastimage (2.1.7) - fastlane (2.140.0) + faraday-em_http (1.0.0) + faraday-em_synchrony (1.0.0) + faraday-excon (1.1.0) + faraday-httpclient (1.0.1) + faraday-multipart (1.0.4) + multipart-post (~> 2) + faraday-net_http (1.0.1) + faraday-net_http_persistent (1.2.0) + faraday-patron (1.0.0) + faraday-rack (1.0.0) + faraday-retry (1.0.3) + faraday_middleware (1.2.0) + faraday (~> 1.0) + fastimage (2.2.6) + fastlane (2.206.2) CFPropertyList (>= 2.3, < 4.0.0) - addressable (>= 2.3, < 3.0.0) - babosa (>= 1.0.2, < 2.0.0) + addressable (>= 2.8, < 3.0.0) + artifactory (~> 3.0) + aws-sdk-s3 (~> 1.0) + babosa (>= 1.0.3, < 2.0.0) bundler (>= 1.12.0, < 3.0.0) colored - commander-fastlane (>= 4.4.6, < 5.0.0) + commander (~> 4.6) dotenv (>= 2.1.1, < 3.0.0) - emoji_regex (>= 0.1, < 2.0) + emoji_regex (>= 0.1, < 4.0) excon (>= 0.71.0, < 1.0.0) - faraday (~> 0.17) + faraday (~> 1.0) faraday-cookie_jar (~> 0.0.6) - faraday_middleware (~> 0.13.1) + faraday_middleware (~> 1.0) fastimage (>= 2.1.0, < 3.0.0) gh_inspector (>= 1.1.2, < 2.0.0) - google-api-client (>= 0.29.2, < 0.37.0) - google-cloud-storage (>= 1.15.0, < 2.0.0) - highline (>= 1.7.2, < 2.0.0) + google-apis-androidpublisher_v3 (~> 0.3) + google-apis-playcustomapp_v1 (~> 0.1) + google-cloud-storage (~> 1.31) + highline (~> 2.0) json (< 3.0.0) - jwt (~> 2.1.0) + jwt (>= 2.1.0, < 3) mini_magick (>= 4.9.4, < 5.0.0) - multi_xml (~> 0.5) multipart-post (~> 2.0.0) + naturally (~> 2.2) + optparse (~> 0.1.1) plist (>= 3.1.0, < 4.0.0) - public_suffix (~> 2.0.0) - rubyzip (>= 1.3.0, < 2.0.0) + rubyzip (>= 2.0.0, < 3.0.0) security (= 0.1.3) simctl (~> 1.6.3) - slack-notifier (>= 2.0.0, < 3.0.0) terminal-notifier (>= 2.0.0, < 3.0.0) terminal-table (>= 1.4.5, < 2.0.0) tty-screen (>= 0.6.3, < 1.0.0) @@ -65,90 +106,106 @@ GEM xcpretty (~> 0.3.0) xcpretty-travis-formatter (>= 0.0.3) gh_inspector (1.1.3) - google-api-client (0.36.4) + google-apis-androidpublisher_v3 (0.23.0) + google-apis-core (>= 0.6, < 2.a) + google-apis-core (0.6.0) addressable (~> 2.5, >= 2.5.1) - googleauth (~> 0.9) - httpclient (>= 2.8.1, < 3.0) + googleauth (>= 0.16.2, < 2.a) + httpclient (>= 2.8.1, < 3.a) mini_mime (~> 1.0) representable (~> 3.0) - retriable (>= 2.0, < 4.0) - signet (~> 0.12) - google-cloud-core (1.5.0) + retriable (>= 2.0, < 4.a) + rexml + webrick + google-apis-iamcredentials_v1 (0.12.0) + google-apis-core (>= 0.6, < 2.a) + google-apis-playcustomapp_v1 (0.9.0) + google-apis-core (>= 0.6, < 2.a) + google-apis-storage_v1 (0.16.0) + google-apis-core (>= 0.6, < 2.a) + google-cloud-core (1.6.0) google-cloud-env (~> 1.0) google-cloud-errors (~> 1.0) - google-cloud-env (1.3.0) - faraday (~> 0.11) - google-cloud-errors (1.0.0) - google-cloud-storage (1.25.1) - addressable (~> 2.5) + google-cloud-env (1.6.0) + faraday (>= 0.17.3, < 3.0) + google-cloud-errors (1.2.0) + google-cloud-storage (1.36.2) + addressable (~> 2.8) digest-crc (~> 0.4) - google-api-client (~> 0.33) - google-cloud-core (~> 1.2) - googleauth (~> 0.9) + google-apis-iamcredentials_v1 (~> 0.1) + google-apis-storage_v1 (~> 0.1) + google-cloud-core (~> 1.6) + googleauth (>= 0.16.2, < 2.a) mini_mime (~> 1.0) - googleauth (0.10.0) - faraday (~> 0.12) + googleauth (1.2.0) + faraday (>= 0.17.3, < 3.a) jwt (>= 1.4, < 3.0) memoist (~> 0.16) multi_json (~> 1.11) os (>= 0.9, < 2.0) - signet (~> 0.12) - highline (1.7.10) - http-cookie (1.0.3) + signet (>= 0.16, < 2.a) + highline (2.0.3) + http-cookie (1.0.5) domain_name (~> 0.5) httpclient (2.8.3) - json (2.3.0) - jwt (2.1.0) + jmespath (1.6.1) + json (2.6.2) + jwt (2.4.1) memoist (0.16.2) - mini_magick (4.10.1) - mini_mime (1.0.2) - multi_json (1.14.1) - multi_xml (0.6.0) + mini_magick (4.11.0) + mini_mime (1.1.2) + multi_json (1.15.0) multipart-post (2.0.0) - nanaimo (0.2.6) - naturally (2.2.0) - os (1.0.1) - plist (3.5.0) - public_suffix (2.0.5) - representable (3.0.4) + nanaimo (0.3.0) + naturally (2.2.1) + optparse (0.1.1) + os (1.1.4) + plist (3.6.0) + public_suffix (4.0.7) + rake (13.0.6) + representable (3.2.0) declarative (< 0.1.0) - declarative-option (< 0.2.0) + trailblazer-option (>= 0.1.1, < 0.2.0) uber (< 0.2.0) retriable (3.1.2) + rexml (3.2.5) rouge (2.0.7) - rubyzip (1.3.0) + ruby2_keywords (0.0.5) + rubyzip (2.3.2) security (0.1.3) - signet (0.12.0) - addressable (~> 2.3) - faraday (~> 0.9) + signet (0.17.0) + addressable (~> 2.8) + faraday (>= 0.17.5, < 3.a) jwt (>= 1.5, < 3.0) multi_json (~> 1.10) - simctl (1.6.7) + simctl (1.6.8) CFPropertyList naturally - slack-notifier (2.3.2) terminal-notifier (2.0.0) terminal-table (1.8.0) unicode-display_width (~> 1.1, >= 1.1.1) - tty-cursor (0.7.0) - tty-screen (0.7.0) - tty-spinner (0.9.2) + trailblazer-option (0.1.2) + tty-cursor (0.7.1) + tty-screen (0.8.1) + tty-spinner (0.9.3) tty-cursor (~> 0.7) uber (0.1.0) unf (0.1.4) unf_ext - unf_ext (0.0.7.6) - unicode-display_width (1.6.0) + unf_ext (0.0.8.2) + unicode-display_width (1.8.0) + webrick (1.7.0) word_wrap (1.0.0) - xcodeproj (1.14.0) + xcodeproj (1.22.0) CFPropertyList (>= 2.3.3, < 4.0) atomos (~> 0.1.3) claide (>= 1.0.2, < 2.0) colored2 (~> 3.1) - nanaimo (~> 0.2.6) + nanaimo (~> 0.3.0) + rexml (~> 3.2.4) xcpretty (0.3.0) rouge (~> 2.0.7) - xcpretty-travis-formatter (1.0.0) + xcpretty-travis-formatter (1.0.1) xcpretty (~> 0.2, >= 0.0.7) PLATFORMS @@ -158,4 +215,4 @@ DEPENDENCIES fastlane BUNDLED WITH - 2.0.2 + 2.3.16 diff --git a/ios/TestApp/fastlane/Scanfile b/ios/TestApp/fastlane/Scanfile index cb1f9e2e18890..8d351bf65b493 100644 --- a/ios/TestApp/fastlane/Scanfile +++ b/ios/TestApp/fastlane/Scanfile @@ -2,6 +2,8 @@ scheme("TestAppTests") open_report(false) clean(true) suppress_xcode_output(true) -force_quit_simulator(true) +ensure_devices_found(true) include_simulator_logs(false) deployment_target_version('14.0') +number_of_retries(2) +prelaunch_simulator(true) diff --git a/modules/detectron/CMakeLists.txt b/modules/detectron/CMakeLists.txt index dc5aea2df1797..46276114c5e04 100644 --- a/modules/detectron/CMakeLists.txt +++ b/modules/detectron/CMakeLists.txt @@ -17,7 +17,7 @@ if(BUILD_CAFFE2_OPS) torch_set_target_props(caffe2_detectron_ops_gpu) target_link_libraries(caffe2_detectron_ops_gpu PRIVATE torch ${OpenMP_link}) - if(CAFFE2_USE_MKLDNN) + if(USE_MKLDNN) target_link_libraries(caffe2_detectron_ops_gpu PRIVATE caffe2::mkldnn) endif() install(TARGETS caffe2_detectron_ops_gpu DESTINATION lib) @@ -33,7 +33,7 @@ if(BUILD_CAFFE2_OPS) ${Detectron_HIP_SRCS}) torch_set_target_props(caffe2_detectron_ops_hip) target_compile_options(caffe2_detectron_ops_hip PRIVATE ${HIP_CXX_FLAGS}) - if(CAFFE2_USE_MKLDNN) + if(USE_MKLDNN) target_link_libraries(caffe2_detectron_ops_hip PRIVATE caffe2::mkldnn) endif() target_link_libraries(caffe2_detectron_ops_hip PRIVATE torch) @@ -46,7 +46,7 @@ if(BUILD_CAFFE2_OPS) endif() torch_set_target_props(caffe2_detectron_ops) target_link_libraries(caffe2_detectron_ops PRIVATE torch ${OpenMP_link}) - if(CAFFE2_USE_MKLDNN) + if(USE_MKLDNN) target_link_libraries(caffe2_detectron_ops PRIVATE caffe2::mkldnn) endif() install(TARGETS caffe2_detectron_ops DESTINATION lib) diff --git a/modules/detectron/group_spatial_softmax_op.cu b/modules/detectron/group_spatial_softmax_op.cu index a37a3fba55a73..741da27f59d2b 100644 --- a/modules/detectron/group_spatial_softmax_op.cu +++ b/modules/detectron/group_spatial_softmax_op.cu @@ -103,7 +103,7 @@ bool GroupSpatialSoftmaxOp::RunOnDevice() { int A = D / num_classes_; auto* P = Output(0, X.sizes(), at::dtype()); // Probabilities from softmax - DCHECK_EQ(X.ndim(), 4); + TORCH_DCHECK_EQ(X.ndim(), 4); const float* Xdata = X.data(); float* Pdata = P->mutable_data(); @@ -123,7 +123,7 @@ bool GroupSpatialSoftmaxGradientOp::RunOnDevice() { auto& dY = Input(1); - DCHECK_EQ(Y.ndim(), 4); + TORCH_DCHECK_EQ(Y.ndim(), 4); int N = Y.dim32(0); int D = Y.dim32(1); diff --git a/modules/detectron/ps_roi_pool_op.h b/modules/detectron/ps_roi_pool_op.h index 8f30722e053e5..ecee1dd7041c4 100644 --- a/modules/detectron/ps_roi_pool_op.h +++ b/modules/detectron/ps_roi_pool_op.h @@ -33,8 +33,8 @@ class PSRoIPoolOp final : public Operator { "spatial_scale", 1.)), group_size_(this->template GetSingleArgument("group_size", 1)), output_dim_(this->template GetSingleArgument("output_dim", 1)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(group_size_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(group_size_, 0); pooled_height_ = group_size_; pooled_width_ = group_size_; } @@ -65,8 +65,8 @@ class PSRoIPoolGradientOp final : public Operator { "spatial_scale", 1.)), group_size_(this->template GetSingleArgument("group_size", 1)), output_dim_(this->template GetSingleArgument("output_dim", 1)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(group_size_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(group_size_, 0); pooled_height_ = group_size_; pooled_width_ = group_size_; } diff --git a/modules/detectron/roi_pool_f_op.h b/modules/detectron/roi_pool_f_op.h index 357e6bf45c4b5..604c5606a203e 100644 --- a/modules/detectron/roi_pool_f_op.h +++ b/modules/detectron/roi_pool_f_op.h @@ -33,9 +33,9 @@ class RoIPoolFOp final : public Operator { "spatial_scale", 1.)), pooled_height_(this->template GetSingleArgument("pooled_h", 1)), pooled_width_(this->template GetSingleArgument("pooled_w", 1)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; @@ -59,9 +59,9 @@ class RoIPoolFGradientOp final : public Operator { "spatial_scale", 1.)), pooled_height_(this->template GetSingleArgument("pooled_h", 1)), pooled_width_(this->template GetSingleArgument("pooled_w", 1)) { - DCHECK_GT(spatial_scale_, 0); - DCHECK_GT(pooled_height_, 0); - DCHECK_GT(pooled_width_, 0); + TORCH_DCHECK_GT(spatial_scale_, 0); + TORCH_DCHECK_GT(pooled_height_, 0); + TORCH_DCHECK_GT(pooled_width_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/modules/detectron/softmax_focal_loss_op.cu b/modules/detectron/softmax_focal_loss_op.cu index b7f8d2423ebc0..0612ef7edcc8c 100644 --- a/modules/detectron/softmax_focal_loss_op.cu +++ b/modules/detectron/softmax_focal_loss_op.cu @@ -165,7 +165,7 @@ bool SoftmaxFocalLossOp::RunOnDevice() { P->size(), 0.f, P->mutable_data(), &context_); math::Set( losses_.size(), 0.f, losses_.mutable_data(), &context_); - DCHECK_EQ(X.ndim(), 4); + TORCH_DCHECK_EQ(X.ndim(), 4); const float* Xdata = X.data(); const float* Wdata = wp.data(); diff --git a/modules/detectron/upsample_nearest_op.cc b/modules/detectron/upsample_nearest_op.cc index e5b187d2a8334..631e17b231f91 100644 --- a/modules/detectron/upsample_nearest_op.cc +++ b/modules/detectron/upsample_nearest_op.cc @@ -15,13 +15,13 @@ */ #include "upsample_nearest_op.h" -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN #include "caffe2/ideep/operators/operator_fallback_ideep.h" #include "caffe2/ideep/utils/ideep_operator.h" #endif namespace caffe2 { -#ifdef CAFFE2_USE_MKLDNN +#ifdef USE_MKLDNN REGISTER_IDEEP_OPERATOR( UpsampleNearest, IDEEPFallbackOp>); diff --git a/modules/detectron/upsample_nearest_op.h b/modules/detectron/upsample_nearest_op.h index 636341cb2041a..f850f0381a1e8 100644 --- a/modules/detectron/upsample_nearest_op.h +++ b/modules/detectron/upsample_nearest_op.h @@ -30,7 +30,7 @@ class UpsampleNearestOp final : public Operator { UpsampleNearestOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), scale_(this->template GetSingleArgument("scale", 2)) { - DCHECK_GE(scale_, 1); + TORCH_DCHECK_GE(scale_, 1); } USE_OPERATOR_CONTEXT_FUNCTIONS; @@ -88,7 +88,7 @@ class UpsampleNearestGradientOp final : public Operator { UpsampleNearestGradientOp(const OperatorDef& def, Workspace* ws) : Operator(def, ws), scale_(this->template GetSingleArgument("scale", 2)) { - DCHECK_GE(scale_, 1); + TORCH_DCHECK_GE(scale_, 1); } USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/mypy-strict.ini b/mypy-strict.ini index d233600f37ee5..460599699c46f 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -64,3 +64,6 @@ ignore_missing_imports = True [mypy-mypy.*] ignore_missing_imports = True + +[mypy-usort.*] +ignore_missing_imports = True diff --git a/mypy.ini b/mypy.ini index f57a08022f106..eb6b50218a966 100644 --- a/mypy.ini +++ b/mypy.ini @@ -61,6 +61,21 @@ ignore_missing_imports = True [mypy-torch.ao.quantization.experimental.apot_utils] ignore_missing_imports = True +[mypy-torch.ao.quantization.experimental.quantizer] +ignore_missing_imports = True + +[mypy-torch.ao.quantization.experimental.observer] +ignore_missing_imports = True + +[mypy-torch.ao.quantization.experimental.APoT_tensor] +ignore_missing_imports = True + +[mypy-torch.ao.quantization.experimental.fake_quantize_function] +ignore_missing_imports = True + +[mypy-torch.ao.quantization.experimental.fake_quantize] +ignore_missing_imports = True + # # Files with various errors. Mostly real errors, possibly some false # positives as well. @@ -275,3 +290,6 @@ ignore_missing_imports = True [mypy-dill.*] ignore_missing_imports = True + +[mypy-usort.*] +ignore_missing_imports = True diff --git a/pt_defs.oss.bzl b/pt_defs.oss.bzl deleted file mode 100644 index 879acb31f8b83..0000000000000 --- a/pt_defs.oss.bzl +++ /dev/null @@ -1,810 +0,0 @@ -load("@bazel_skylib//lib:paths.bzl", "paths") -load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") -load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") -load(":build_variables.bzl", "aten_native_source_list") -load( - ":ufunc_defs.bzl", - "aten_ufunc_generated_cpu_kernel_sources", - "aten_ufunc_generated_cpu_sources", - "aten_ufunc_generated_cuda_sources", -) - -USED_PT_BACKENDS = [ - "CPU", - "QuantizedCPU", - "SparseCPU", # brings ~20 kb size regression -] - -# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892 -PT_BACKEND_HEADERS = [ - "CPU", - "CUDA", - "CompositeExplicitAutograd", - "CompositeExplicitAutogradNonFunctional", - "CompositeImplicitAutograd", - "Meta", -] - -PT_BASE_OPS = [ - "aten::_coalesced_", - "aten::_copy_from", - "aten::_empty_affine_quantized", - "aten::_empty_per_channel_affine_quantized", - "aten::_indices", - "aten::_nnz", - "aten::_values", - "aten::add", - "aten::add_", - "aten::arange", - "aten::as_strided", - "aten::as_strided_", - "aten::cat", - "aten::clone", - "aten::coalesce", - "aten::contiguous", - "aten::copy_", - "aten::copy_sparse_to_sparse_", - "aten::dense_dim", - "aten::dequantize", - "aten::div", - "aten::div_", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::empty.memory_format", - "aten::eq", - "aten::equal", - "aten::expand", - "aten::fill_", - "aten::is_coalesced", - "aten::is_complex", - "aten::is_floating_point", - "aten::is_leaf", - "aten::is_nonzero", - "aten::item", - "aten::max", - "aten::min", - "aten::mul", - "aten::mul_", - "aten::narrow", - "aten::ne", - "aten::permute", - "aten::q_per_channel_axis", - "aten::q_per_channel_scales", - "aten::q_per_channel_zero_points", - "aten::q_scale", - "aten::q_zero_point", - "aten::qscheme", - "aten::quantize_per_tensor", - "aten::reshape", - "aten::_reshape_alias", - "aten::resize_", - "aten::resize_as_", - "aten::scalar_tensor", - "aten::select", - "aten::set_", - "aten::size", - "aten::slice", - "aten::sparse_dim", - "aten::sparse_resize_and_clear_", - "aten::squeeze", - "aten::squeeze_", - "aten::stride", - "aten::sub", - "aten::sub_", - "aten::sum", - "aten::t", - "aten::to", - "aten::_to_copy", - "aten::unsqueeze", - "aten::view", - "aten::zero_", - "aten::zeros", - "aten::zeros_like", -] - -def get_aten_compiler_flags(): - return ATEN_COMPILER_FLAGS - -def get_generate_code_bin_outs(): - return { - "autograd/generated/ADInplaceOrViewTypeEverything.cpp": ["autograd/generated/ADInplaceOrViewTypeEverything.cpp"], - "autograd/generated/ADInplaceOrViewType_0.cpp": ["autograd/generated/ADInplaceOrViewType_0.cpp"], - "autograd/generated/ADInplaceOrViewType_1.cpp": ["autograd/generated/ADInplaceOrViewType_1.cpp"], - "autograd/generated/Functions.cpp": ["autograd/generated/Functions.cpp"], - "autograd/generated/Functions.h": ["autograd/generated/Functions.h"], - "autograd/generated/TraceTypeEverything.cpp": ["autograd/generated/TraceTypeEverything.cpp"], - "autograd/generated/TraceType_0.cpp": ["autograd/generated/TraceType_0.cpp"], - "autograd/generated/TraceType_1.cpp": ["autograd/generated/TraceType_1.cpp"], - "autograd/generated/TraceType_2.cpp": ["autograd/generated/TraceType_2.cpp"], - "autograd/generated/TraceType_3.cpp": ["autograd/generated/TraceType_3.cpp"], - "autograd/generated/TraceType_4.cpp": ["autograd/generated/TraceType_4.cpp"], - "autograd/generated/VariableType.h": ["autograd/generated/VariableType.h"], - "autograd/generated/VariableTypeEverything.cpp": ["autograd/generated/VariableTypeEverything.cpp"], - "autograd/generated/VariableType_0.cpp": ["autograd/generated/VariableType_0.cpp"], - "autograd/generated/VariableType_1.cpp": ["autograd/generated/VariableType_1.cpp"], - "autograd/generated/VariableType_2.cpp": ["autograd/generated/VariableType_2.cpp"], - "autograd/generated/VariableType_3.cpp": ["autograd/generated/VariableType_3.cpp"], - "autograd/generated/VariableType_4.cpp": ["autograd/generated/VariableType_4.cpp"], - "autograd/generated/variable_factories.h": ["autograd/generated/variable_factories.h"], - } - -ATEN_COMPILER_FLAGS = [ - "-fexceptions", - "-frtti", - "-fPIC", - "-Os", - "-Wno-absolute-value", - "-Wno-deprecated-declarations", - "-Wno-macro-redefined", - "-Wno-tautological-constant-out-of-range-compare", - "-Wno-unknown-pragmas", - "-Wno-unknown-warning-option", - "-Wno-unused-function", - "-Wno-unused-variable", - "-Wno-pass-failed", - "-Wno-shadow", -] - -PT_COMPILER_FLAGS = [ - "-frtti", - "-Os", - "-Wno-unknown-pragmas", - "-Wno-write-strings", - "-Wno-unused-variable", - "-Wno-unused-function", - "-Wno-deprecated-declarations", - "-Wno-shadow", - "-Wno-global-constructors", - "-Wno-missing-prototypes", - "-std=gnu++17", # to accommodate Eigen -] - -def get_template_source_dict(): - ret = {} - for file_path in TEMPLATE_SOURCE_LIST: - path_prefix = paths.dirname(file_path) - if path_prefix not in ret: - ret[path_prefix] = [] - ret[path_prefix].append(file_path) - return ret - -def get_gen_oplist_outs(): - return { - #"SupportedMobileModelsRegistration.cpp": [ - # "SupportedMobileModelsRegistration.cpp", - #], - "selected_mobile_ops.h": [ - "selected_mobile_ops.h", - ], - "selected_operators.yaml": [ - "selected_operators.yaml", - ], - } - -def get_pt_compiler_flags(): - return PT_COMPILER_FLAGS - -def get_aten_preprocessor_flags(): - # read_config is not allowed outside of function in Starlark - ATEN_PREPROCESSOR_FLAGS = [ - "-DC10_MOBILE", - "-DCPU_CAPABILITY_DEFAULT", - "-DCPU_CAPABILITY=DEFAULT", - "-DCAFFE2_USE_LITE_PROTO", - "-DATEN_CUDNN_ENABLED_FBXPLAT=0", - "-DATEN_MKLDNN_ENABLED_FBXPLAT=0", - "-DATEN_NNPACK_ENABLED_FBXPLAT=0", - "-DATEN_MKL_ENABLED_FBXPLAT=0", - "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0", - "-DUSE_PYTORCH_METAL", - "-DUSE_PYTORCH_QNNPACK", - "-DUSE_XNNPACK", - "-DNO_EXPORT", - "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", - "-DAT_PARALLEL_OPENMP_FBXPLAT=0", - "-DAT_PARALLEL_NATIVE_FBXPLAT=1", - "-DAT_PARALLEL_NATIVE_TBB_FBXPLAT=0", - "-DUSE_LAPACK_FBXPLAT=0", - "-DAT_BLAS_F2C_FBXPLAT=0", - "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0", - "-DUSE_RUY_QMATMUL", # need third_party:ruy - ] - - # if get_disable_per_op_profiling(): - ATEN_PREPROCESSOR_FLAGS.append("-DPYTORCH_DISABLE_PER_OP_PROFILING") - return ATEN_PREPROCESSOR_FLAGS - -TEMPLATE_SOURCE_LIST = [ - "torch/csrc/jit/runtime/register_prim_ops.cpp", - "torch/csrc/jit/runtime/register_special_ops.cpp", -] + aten_native_source_list - -# For selective build, we can lump the CPU and CPU kernel sources altogether -# because there is only ever one vectorization variant that is compiled -def aten_ufunc_generated_all_cpu_sources(gencode_pattern = "{}"): - return ( - aten_ufunc_generated_cpu_sources(gencode_pattern) + - aten_ufunc_generated_cpu_kernel_sources(gencode_pattern) - ) - -def get_template_registration_files_outs(): - outs = {} - - for file_path in TEMPLATE_SOURCE_LIST: - outs[file_path] = [file_path] - - for base_name in aten_ufunc_generated_all_cpu_sources(): - file_path = "aten/src/ATen/{}".format(base_name) - outs[file_path] = [file_path] - - return outs - -def get_pt_preprocessor_flags(): - # read_config is not allowed outside of function in Starlark - PT_PREPROCESSOR_FLAGS = [ - "-D_THP_CORE", - "-DC10_MOBILE", - "-DUSE_SCALARS", - "-DNO_CUDNN_DESTROY_HANDLE", - "-DNO_EXPORT", - "-DBUILD_CAFFE2", - ] - return PT_PREPROCESSOR_FLAGS - -def is_arvr_mode(): - return False - -def get_build_from_deps_query(): - build_from_query = native.read_config("pt", "build_from_deps_query", "1") - return bool(int(build_from_query)) - -def get_enable_lightweight_dispatch(): - enable_lightweight_dispatch = native.read_config("pt", "enable_lightweight_dispatch", "0") - return bool(int(enable_lightweight_dispatch)) - -def get_static_dispatch_backend(): - static_dispatch_backend = native.read_config("pt", "static_dispatch_backend", None) - if static_dispatch_backend == None: - return [] - return static_dispatch_backend.split(";") - -def get_aten_codegen_extra_params(backends): - if get_build_from_deps_query(): - extra_params = { - "force_schema_registration": True, - } - static_backends = get_static_dispatch_backend() - if static_backends: - extra_params["static_dispatch_backend"] = static_backends - extra_params["enabled_backends"] = static_backends - else: - extra_params["enabled_backends"] = backends - return extra_params - else: - return {} - -def gen_aten_files( - name, - extra_flags = {}, - visibility = [], - compatible_with = []): - extra_params = [] - force_schema_registration = extra_flags.get("force_schema_registration", False) - op_registration_allowlist = extra_flags.get("op_registration_allowlist", None) - op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None) - enabled_backends = extra_flags.get("enabled_backends", None) - static_dispatch_backend = extra_flags.get("static_dispatch_backend", None) - - if force_schema_registration: - extra_params.append("--force_schema_registration") - if op_registration_allowlist != None and is_string(op_registration_allowlist): - extra_params.append("--op_registration_whitelist") - extra_params.append(op_registration_allowlist) - if op_selection_yaml_path != None and is_string(op_selection_yaml_path): - extra_params.append("--op_selection_yaml_path") - extra_params.append(op_selection_yaml_path) - if enabled_backends != None and is_list(enabled_backends): - extra_params.append("--backend_whitelist") - extra_params.extend(enabled_backends) - if get_enable_lightweight_dispatch(): - extra_params.append("--skip_dispatcher_op_registration") - if static_dispatch_backend: - extra_params.append("--static_dispatch_backend") - extra_params.extend(static_dispatch_backend) - backends = static_dispatch_backend - else: - backends = enabled_backends - fb_xplat_genrule( - name = name, - default_outs = ["."], - outs = get_aten_generated_files(backends), - cmd = "$(exe //torchgen:gen) " + " ".join([ - "--source-path $(location //:aten_src_path)/aten/src/ATen", - "--install_dir $OUT", - ] + extra_params), - visibility = visibility, - compatible_with = compatible_with, - ) - -def get_aten_generated_files(enabled_backends): - # NB: RegisterMeta counts as an optionally enabled backend, - # and is intentionally omitted from here - src_files = [ - "RegisterBackendSelect.cpp", - "RegisterCompositeImplicitAutograd.cpp", - "RegisterCompositeExplicitAutograd.cpp", - "RegisterCompositeExplicitAutogradNonFunctional.cpp", - "CompositeViewCopyKernels.cpp", - "RegisterSchema.cpp", - "Declarations.yaml", - "Functions.cpp", - "Functions.h", - "RedispatchFunctions.h", - "NativeFunctions.h", - "NativeMetaFunctions.h", - "MethodOperators.h", - "FunctionalInverses.h", - "Operators.h", - "Operators_0.cpp", - "Operators_1.cpp", - "Operators_2.cpp", - "Operators_3.cpp", - "Operators_4.cpp", - "CompositeImplicitAutogradFunctions.h", - "CompositeImplicitAutogradFunctions_inl.h", - "CompositeExplicitAutogradFunctions.h", - "CompositeExplicitAutogradFunctions_inl.h", - "CompositeExplicitAutogradNonFunctionalFunctions.h", - "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", - "core/ATenOpList.cpp", - "core/TensorBody.h", - "core/TensorMethods.cpp", - "core/aten_interned_strings.h", - "core/enum_tag.h", - ] + get_aten_derived_type_srcs(enabled_backends) - - # This is tiresome. A better strategy would be to unconditionally - # generate these files, and then only actually COMPILE them depended - # on the generated set. C'est la vie... - if "CPU" in enabled_backends: - src_files.extend(aten_ufunc_generated_cpu_sources()) - src_files.extend(aten_ufunc_generated_cpu_kernel_sources()) - if "CUDA" in enabled_backends: - # Cannot unconditionally include this, because in the Edge selective - # build CUDA is not enabled and thus the ufunc codegen for CUDA gets - # skipped - src_files.extend(aten_ufunc_generated_cuda_sources()) - - res = {} - for file_name in src_files: - res[file_name] = [file_name] - return res - -def get_template_registration_file_rules(rule_name): - rules = [] - for file_path in TEMPLATE_SOURCE_LIST: - rules.append(":{}[{}]".format(rule_name, file_path)) - for file_path in aten_ufunc_generated_all_cpu_sources(): - rules.append(":{}[aten/src/ATen/{}]".format(rule_name, file_path)) - - return rules - -# Originally, there were two sets of codes in caffe2:aten_cpu, native codes and non-native. -# Now we have only non-naitve sources in aten_cpu. However, there are some aten related -# tests that may require both native and non-native codes. This rule is used to generate -# both aten_cpu and aten_native_cpu. They are using the same compilation setups. -def build_aten_cpu(name, srcs, deps = []): - cxx_library( - name = name, - srcs = srcs, - header_namespace = "", - compiler_flags = get_pt_compiler_flags(), - exported_preprocessor_flags = get_aten_preprocessor_flags(), - link_whole = True, - linker_flags = ["-Wl,--no-as-needed", "-ldl"], - visibility = ["PUBLIC"], - deps = [ - "//third_party:cpuinfo", - "//third_party:glog", - "//third_party:XNNPACK", - #"//third_party/linker_lib:omp", - ], - exported_deps = [ - "//third_party:fmt", - "//aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack", - "//c10:c10", - ":aten_header", - ":caffe2_headers", - ":common_core", - ":generated_aten_config_header", - ":generated_aten_headers_cpu", - ":jit_core_headers", - ":pthreadpool", - "//third_party:ruy_lib", - ], - ) - -######### selective build ######### - -def get_pt_ops_deps(name, deps, train = False, enforce_traced_op_list = False, enable_flatbuffer = False, **kwargs): - if not get_build_from_deps_query(): - return deps - pt_operator_registry( - name, - deps, - train = train, - enforce_traced_op_list = enforce_traced_op_list, - enable_flatbuffer = enable_flatbuffer, - **kwargs - ) - return deps + [":" + name] - -# pt_operator_registry is the method that defines the fb_xplat_cxx_library that contains -# code for all selected PyTorch Operators and kernel functions. This also includes -# operator registration into the dispatcher. -# -# template_select: bool: Indicates if template based selective build is enabled. -# -# enforce_traced_op_list: bool: Enforces that only new-style operator -# lists based on the all_mobile_model_configs.yaml file and tracing based selective -# build are used in this library. -# -# train: bool: Build this library for training (True) or inference only (False). -# If built for training, codegen for VariableType is also included. -# -# pt_allow_forced_schema_registration: Manually disables forced schema registration when set to false, Default is true. -# Only does anything when train=True and the app requires full jit then force_schema_registration needs to occur. -# As Federated Learning migrates to lite interpreter -# we can slowly turn off forced schema registration as it is useless space and floods the compatibility api -# -def pt_operator_registry( - name, - deps = [], - train = False, - labels = [], - env = [], - template_select = True, - enforce_traced_op_list = False, - pt_allow_forced_schema_registration = True, - enable_flatbuffer = False, - **kwargs): - compatible_with = kwargs.get("compatible_with", []) - code_gen_files = pt_operator_query_codegen(name, deps = deps, train = train, enforce_traced_op_list = enforce_traced_op_list, pt_allow_forced_schema_registration = pt_allow_forced_schema_registration, compatible_with = compatible_with) - code_gen_srcs = code_gen_files["srcs"] - - lib_deps = [ - ":aten_cpu", - ":torch_mobile_core", - "//c10:c10", - "//third_party:glog", - ] - - #if train: - # lib_deps = lib_deps + ["fbsource//xplat/caffe2:torch_mobile_train"] - - exported_preprocessor_flags = get_aten_preprocessor_flags() - exported_preprocessor_flags += kwargs.pop("exported_preprocessor_flags", []) - if template_select: - # In addition to the - # original code-gen select, this option further filter more operators based on - # compile-time calculation. Examples include prim ops and any other ops that were - # not filtered out before. The purpose of this option is to reduce the production - # size further. However, it may have less flexibility, especially for tests from - # python, where the used operator list is not explicitly generated. If the tests - # are for functionality but not for size, and it's difficult to maintain an explicit - # operator list, it's suggested to turn this option off. - exported_preprocessor_flags.append("-DTEMPLATE_SELECTIVE_BUILD") - kwargs.pop("exported_headers", []) - cxx_library( - name = name, - srcs = code_gen_srcs, - linker_flags = [ - "-Wl,--no-as-needed", - "-ldl", - ], - link_whole = True, - soname = "libtorch-code-gen.$(ext)", - compiler_flags = get_aten_compiler_flags(), - platform_compiler_flags = get_cpukernel_avx2_flags(), - platform_deps = get_cpukernel_avx2_deps(), - header_namespace = "ATen", - exported_headers = code_gen_files["headers"], - exported_preprocessor_flags = exported_preprocessor_flags, - headers = kwargs.pop("headers", []), - deps = lib_deps + [ - "//third_party:XNNPACK", - ], - **kwargs - ) - -def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends): - return [ - ":{}[{}]".format(aten_rule_name, "Register" + backend + ".cpp") - for backend in enabled_backends - ] - -def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends): - return [ - ":{}[{}]".format(aten_rule_name, f) - for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] - ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends) - -def get_aten_derived_type_srcs(enabled_backends): - return [ - "Register" + derived_type + ".cpp" - for derived_type in enabled_backends - ] + [ - derived_type + "Functions.h" - for derived_type in enabled_backends - if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend() - ] + [ - derived_type + "Functions_inl.h" - for derived_type in enabled_backends - if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend() - ] - -def pt_operator_query_codegen(name, deps = [], train = False, enforce_traced_op_list = False, pt_allow_forced_schema_registration = True, compatible_with = []): - oplist_dir_name = name + "_pt_oplist" - - # @lint-ignore BUCKLINT - fb_xplat_genrule( - name = oplist_dir_name, - cmd = ("$(exe //:gen_oplist) " + - "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " + - ("" if enforce_traced_op_list else "--allow_include_all_overloads ") + - "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])), - outs = get_gen_oplist_outs(), - default_outs = ["."], - compatible_with = compatible_with, - ) - - # Aten files - aten_genrule = name + "_aten" - extra_flags = { - "enabled_backends": USED_PT_BACKENDS, - "op_selection_yaml_path": "$(location :{}[selected_operators.yaml])".format(oplist_dir_name), - } - - if train and pt_allow_forced_schema_registration: - extra_flags["force_schema_registration"] = True - - # if get_enable_lightweight_dispatch(): - # unboxing_genrule = name + "_unboxing" - # gen_aten_unboxing_files( - # unboxing_genrule, - # extra_flags = extra_flags, - # ) - - static_dispatch_backend = get_static_dispatch_backend() - if static_dispatch_backend: - extra_flags["static_dispatch_backend"] = static_dispatch_backend - - gen_aten_files( - aten_genrule, - extra_flags = extra_flags, - compatible_with = compatible_with, - ) - - # unboxing_wrappers files - extra_params = [ - "--operators_yaml_path", - "$(location :" + oplist_dir_name + "[selected_operators.yaml])", - ] - unboxing_and_autograd_genrule = name + "_unboxing_and_autograd" - gen_aten_libtorch_files(unboxing_and_autograd_genrule, extra_params, compatible_with) - - # Template runtime files (prim ops, etc) - template_registration_genrule = name + "_template_registration" - copy_template_registration_files(template_registration_genrule) - - srcs = get_aten_selective_cpp_rules( - aten_genrule, - static_dispatch_backend if static_dispatch_backend else USED_PT_BACKENDS, - ) + get_template_registration_file_rules( - template_registration_genrule, - ) + ([ - ":{}[autograd/generated/VariableType_0.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/VariableType_1.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/VariableType_2.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/VariableType_3.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/VariableType_4.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/ADInplaceOrViewType_0.cpp]".format(unboxing_and_autograd_genrule), - ":{}[autograd/generated/ADInplaceOrViewType_1.cpp]".format(unboxing_and_autograd_genrule), - ] if train else []) + ([ - #":{}[SupportedMobileModelsRegistration.cpp]".format(oplist_dir_name), - ]) - - headers = { - "selected_mobile_ops.h": ":{}[selected_mobile_ops.h]".format(oplist_dir_name), - } - - # if get_enable_lightweight_dispatch(): - # srcs.extend([ - # ":{}[UnboxingFunctions_0.cpp]".format(unboxing_genrule), - # ":{}[UnboxingFunctions_1.cpp]".format(unboxing_genrule), - # ":{}[UnboxingFunctions_2.cpp]".format(unboxing_genrule), - # ":{}[UnboxingFunctions_3.cpp]".format(unboxing_genrule), - # ":{}[UnboxingFunctions_4.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_0.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_1.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_2.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_3.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_4.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_5.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_6.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_7.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_8.cpp]".format(unboxing_genrule), - # ":{}[RegisterCodegenUnboxedKernels_9.cpp]".format(unboxing_genrule), - # ]) - # headers["UnboxingFunctions.h"] = ":{}[UnboxingFunctions.h]".format(unboxing_genrule) - return {"headers": headers, "srcs": srcs} - -def gen_aten_libtorch_files(name, extra_params = [], compatible_with = []): - fb_xplat_genrule( - name = name, - outs = get_generate_code_bin_outs(), - default_outs = ["."], - cmd = "mkdir -p tools && " + - "$(exe //tools/setup_helpers:generate_code_bin) " + " ".join( - # Mobile build only needs libtorch - skip python bindings for now, except - # for ovrsource, which needs Python bindings. - (["--subset libtorch"] if not is_arvr_mode() else []) + [ - "--native-functions-path $(location :aten_src_path)/aten/src/ATen/native/native_functions.yaml", - "--tags-path $(location :aten_src_path)/aten/src/ATen/native/tags.yaml", # todo D35992309 - "--install_dir $OUT", - ] + extra_params, - ), - cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + - "$(exe //tools/setup_helpers:generate_code_bin) " + " ".join( - # Mobile build only needs libtorch - skip python bindings for now, except - # for ovrsource, which needs Python bindings. - (["--subset libtorch"] if not is_arvr_mode() else []) + [ - "--native-functions-path $(location :aten_src_path)/aten/src/ATen/native/native_functions.yaml", - "--tags-path $(location :aten_src_path)/aten/src/ATen/native/tags.yaml", - "--install_dir $OUT", - ] + extra_params, - ), - compatible_with = compatible_with, - ) - -def copy_template_registration_files(name): - cmd = [] - cmd_exe = [] - - template_source_dict = get_template_source_dict() - - # Ideally, we would run one copy command for a single source directory along - # with all its child directories, but it's somewhat hard to know if a directory - # is a child of another just bu looking at the metadata (directory relative - # path) that we currently have since 1 directory could look like a parent of - # another and yet come from a different filegroup() rule. - # - for (path_prefix, file_paths) in template_source_dict.items(): - cmd.append("mkdir -p $OUT/{}".format(path_prefix)) - cmd_exe.append("md $OUT/{}".format(path_prefix)) - - # Adding *.cpp is a workaround to prevent cp from thrown an error when it - # encounters a directory (since -r was not specified). If files with an - # extension other than .cpp need to be copied, then the command below - # will not work and will need to be updated. - # - cmd.append("cp -f {0}/{1}/*.cpp $OUT/{1}/".format("$(location :templated_selective_build_srcs)", path_prefix)) - cmd_exe.append("robocopy /E {0}/{1} $OUT/{1}".format("$(location :templated_selective_build_srcs)", path_prefix)) - - cmd.append("mkdir -p $OUT/aten/src/ATen") - cmd_exe.append("md $OUT/aten/src/ATen") - - # NB: CUDA is skipped here because this is selective build and CUDA is not - # supported for selective build - for ufunc_file in aten_ufunc_generated_all_cpu_sources("$(location :gen_aten[{}])"): - cmd.append("cp -f " + ufunc_file + " $OUT/aten/src/ATen") - cmd_exe.append("copy " + ufunc_file + " $OUT/aten/src/ATen") - - fb_xplat_genrule( - name = name, - cmd = " && ".join(cmd), - cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)), - outs = get_template_registration_files_outs(), - default_outs = ["."], - ) - -def pt_operator_library( - name, - ops = [], - exported_deps = [], - check_decl = True, - train = False, - model = None, - include_all_operators = False, - **kwargs): - model_name = name - - if get_build_from_deps_query(): - ops = [op.strip() for op in ops] - - # If ops are specified, then we are in static selective build mode, so we append - # base ops to this list to avoid additional special case logic in subsequent code. - if len(ops) > 0: - ops.extend(PT_BASE_OPS) - - visibility = kwargs.pop("visibility", ["PUBLIC"]) - - fb_xplat_genrule( - name = name, - out = "model_operators.yaml", - cmd = ( - "$(exe :gen_operators_yaml) " + - "{optionally_root_ops} " + - "{optionally_training_root_ops} " + - "--rule_name {rule_name} " + - "--output_path \"${{OUT}}\" " + - "--model_name {model_name} " + - "--dep_graph_yaml_path pytorch_op_deps.yaml " + - "--models_yaml_path all_mobile_model_configs.yaml " + - #"{optionally_model_versions} " + - #"{optionally_model_assets} " + - #"{optionally_model_traced_backends} " + - "{optionally_include_all_operators}" - ).format( - rule_name = name, - model_name = model_name, - optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", - optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", - #optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", - #optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", - #optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", - optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", - ), - labels = ["pt_operator_library"], # for pt_operator_query_codegen query - visibility = visibility, - **kwargs - ) - else: - if check_decl: - pass - # ensure_ops_are_declared(ops) - - cxx_library( - name = name, - compiler_flags = get_pt_compiler_flags(), - cxx_platform_compiler_flags = get_cpukernel_avx2_flags(), - exported_deps = exported_deps, - **kwargs - ) - -def compose_platform_setting_list(settings): - """Settings object: - os/cpu pair: should be valid key, or at most one part can be wildcard. - flags: the values added to the compiler flags - """ - result = [] - for setting in settings: - result = result.append([ - "^{}-{}$".format(setting["os"], setting["cpu"]), - setting["flags"], - ]) - return result - -def get_cpukernel_avx2_flags(): - # flags = compose_platform_setting_list([ - # { - # "cpu": "x86_64", - # "flags": ["-DHAVE_AVX2_CPU_DEFINITION"], - # "os": "macosx", - # }, - # ]) if build_cpukernel_avx2() else [] - return [] - -def build_cpukernel_avx2(): - return not is_arvr_mode() - -def get_cpukernel_avx2_deps(): - # flags = compose_platform_setting_list([ - # { - # "cpu": "x86_64", - # "flags": ["fbsource//xplat/caffe2:cpukernel_avx2"], - # "os": "macosx", - # }, - # ]) if build_cpukernel_avx2() else [] - return [] diff --git a/pt_ops.bzl b/pt_ops.bzl new file mode 100644 index 0000000000000..73f0f8f40908e --- /dev/null +++ b/pt_ops.bzl @@ -0,0 +1,617 @@ +load("//tools/build_defs:expect.bzl", "expect") +load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") +load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") + +# @lint-ignore BUCKRESTRICTEDSYNTAX +IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build + +USED_PT_BACKENDS = [ + "CPU", + "QuantizedCPU", + "SparseCPU", # brings ~20 kb size regression +] + +def pt_operator_library( + name, + ops = [], + exported_deps = [], + check_decl = True, + train = False, + model = None, + include_all_operators = False, + **kwargs): + (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information( + name, + model, + ) + + ops = [op.strip() for op in ops] + + # If ops are specified, then we are in static selective build mode, so we append + # base ops to this list to avoid additional special case logic in subsequent code. + if len(ops) > 0: + ops.extend(PT_BASE_OPS) + + labels = kwargs.pop("labels", []) + visibility = kwargs.pop("visibility", ["PUBLIC"]) + + fb_xplat_genrule( + name = name, + out = "model_operators.yaml", + cmd = ( + "$(exe {exe}) " + + "{optionally_root_ops} " + + "{optionally_training_root_ops} " + + "--rule_name {rule_name} " + + "--output_path \"${{OUT}}\" " + + "--model_name {model_name} " + + "--dep_graph_yaml_path {dep_graph_yaml} " + + "--models_yaml_path {models_yaml} " + + "{optionally_model_versions} " + + "{optionally_model_assets} " + + "{optionally_model_traced_backends} " + + "{optionally_include_all_operators}" + ).format( + exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml", + rule_name = name, + model_name = model_name, + dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", + models_yaml = "none" if IS_OSS else "$(location fbsource//xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ", + optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", + optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", + optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", + optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", + optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", + optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", + ), + labels = labels + [ + "pt_operator_library", + "supermodule:android/default/pytorch", + "supermodule:ios/default/public.pytorch", + ] + (["pt_train_operator_library"] if train else []), + visibility = visibility, + **kwargs + ) + +def validate_and_extract_model_information(name, model): + model_name = name + model_versions = None + model_assets = None + model_traced_backends = None + + if model != None: + model_name = model.get("name") + expect(model_name != None, "Expected Model Name to be present") + model_versions = model.get("versions") + expect(is_list(model_versions), "Expected model versions to be a list of string") + for ver in model_versions or []: + expect(is_string(ver), "Expected version '{}' to be string".format(str(ver))) + model_assets = model.get("assets") + expect( + model_assets == None or is_list(model_assets), + "Expected model assets to be a list of string if specified", + ) + for asset_name in model_assets or []: + expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name))) + model_traced_backends = model.get("traced_backends") + expect( + model_traced_backends == None or is_list(model_traced_backends), + "Expected model traced backends to be a list of string if specified", + ) + + if model_traced_backends != None: + for backend in model_traced_backends: + expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend))) + expect( + backend in USED_PT_BACKENDS, + "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)), + ) + + return (model_name, model_versions, model_assets, model_traced_backends) + +# This file keeps a list of PyTorch operators used by any targets in +# @fbsource//xplat/... +# The purpose of the list is to avoid generating large number of unused +# operator registration code / BUCK rules at build time. +# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv + +PT_OPS_PRIM = [ + "aten::str", + "aten::list", + "aten::__range_length", + "aten::__derive_index", + "prim::TupleUnpack", + "prim::unchecked_cast", + "aten::IntImplicit", + "aten::FloatImplicit", + "aten::ScalarImplicit", + "aten::Bool.Tensor", + "aten::Bool.int", + "aten::Bool.float", + "aten::Int.Tensor", + "aten::Int.Scalar", + "aten::Int.int", + "aten::Int.bool", + "aten::Int.str", + "aten::Float.Tensor", + "aten::Float.Scalar", + "aten::Float.int", + "aten::Float.bool", + "aten::Float.str", + "aten::format", + "prim::NumToTensor.Scalar", + "prim::RaiseException", + "aten::Size", + "aten::size", + "prim::EnumName", + "prim::EnumValue.int", + "prim::EnumValue.float", + "prim::EnumValue.str", + "prim::TupleIndex", + "aten::ne.int_list", + "prim::unchecked_unwrap_optional", + "prim::device", + "prim::dtype", + "aten::__not__", + "aten::__is__", + "aten::__isnot__", + "aten::element_size", + "aten::numel", + "aten::dim", + "aten::get_device", + "aten::storage_offset", + "aten::is_contiguous", + "aten::select.t", + "aten::__getitem__.t", + "aten::append.t", + "aten::reverse.t", + "aten::extend.t", + "aten::copy.t", + "aten::_set_item.t", + "aten::clear.t", + "aten::Delete.t", + "aten::insert.t", + "aten::pop.t", + "aten::add.t", + "aten::add_.t", + "aten::slice.t", + "aten::list.t", + "aten::mul.left_t", + "aten::mul.right_", + "aten::mul_.t", + "aten::len.t", + "aten::eq.int_list", + "prim::Uninitialized", + "prim::Print", + "aten::eq.enum", + "aten::ne.enum", + "aten::dequantize.tensor", + "aten::dequantize.any", + "aten::add.str", + "aten::eq.int", + "aten::eq.float", + "aten::eq.int_float", + "aten::eq.float_int", + "aten::eq", + "aten::eq.str", + "aten::ne.int", + "aten::ne.float", + "aten::ne.int_float", + "aten::ne.float_int", + "aten::ne", + "aten::ne.str", + "aten::lt.int", + "aten::lt.float", + "aten::lt.int_float", + "aten::lt.float_int", + "aten::lt", + "aten::lt.str", + "aten::gt.int", + "aten::gt.float", + "aten::gt.int_float", + "aten::gt.float_int", + "aten::gt", + "aten::gt.str", + "aten::le.int", + "aten::le.float", + "aten::le.int_float", + "aten::le.float_int", + "aten::le", + "aten::le.str", + "aten::ge.int", + "aten::ge.float", + "aten::ge.int_float", + "aten::ge.float_int", + "aten::ge", + "aten::ge.str", + "aten::add.int", + "aten::add.float", + "aten::add.int_float", + "aten::add.float_int", + "aten::add", + "aten::sub.int", + "aten::sub.float", + "aten::sub.int_float", + "aten::sub.float_int", + "aten::sub", + "aten::mul.int", + "aten::mul.float", + "aten::mul.int_float", + "aten::mul.float_int", + "aten::mul", + "aten::__and__.bool", + "aten::__or__.bool", + "aten::__xor__.bool", + "aten::floor.int", + "aten::floor.float", + "aten::floor.Scalar", + "aten::ceil.int", + "aten::ceil.float", + "aten::ceil.Scalar", + "aten::neg.int", + "aten::neg.float", + "aten::neg.Scalar", + "aten::exp.int", + "aten::exp.float", + "aten::exp.Scalar", + "aten::remainder.int", + "aten::remainder.float", + "aten::remainder.int_float", + "aten::remainder.float_int", + "aten::remainder", + "aten::div.int", + "aten::div.float", + "aten::div", + "aten::floordiv.int", + "aten::floordiv.float", + "aten::floordiv.int_float", + "aten::floordiv.float_int", + "aten::floordiv", + "aten::pow.int", + "aten::pow.float", + "aten::pow.int_float", + "aten::pow.float_int", + "aten::pow.Scalar_Scalar", + "aten::pow.int_to_int", + "prim::min.int", + "prim::min.float", + "prim::min.int_float", + "prim::min.float_int", + "prim::min", + "prim::max.int", + "prim::max.float", + "prim::max.int_float", + "prim::max.float_int", + "prim::max", + "prim::type", + "aten::len.Tensor", + "aten::ord", + "aten::lower", + "aten::__contains__.str_list", + "aten::len.str", + "aten::__getitem__.str", + "aten::copy_.Tensor", + "aten::copy_.int", + "aten::copy_.float", + "aten::backward", + "aten::index.Tensor_hacked_twin", + "aten::_index_put_impl_.hacked_twin", + "aten::index_put_.hacked_twin", + "aten::index_put.hacked_twin", + "aten::to.prim_Device", + "aten::to.prim_dtype", + "prim::is_cuda", + "prim::data", + "prim::min.int_list", + "prim::max.int_list", + "prim::min.self_int", + "prim::max.self_int", + "prim::min.float_list", + "prim::max.float_list", + "prim::min.self_float", + "prim::max.self_float", + "prim::min.bool_list", + "prim::max.bool_list", + "prim::min.self_bool", + "prim::max.self_bool", + "aten::len.Dict_str", + "aten::keys.str", + "aten::values.str", + "aten::__getitem__.Dict_str", + "aten::get.str", + "aten::get.default_str", + "aten::setdefault.str", + "aten::Delete.Dict_str", + "aten::pop.Dict_str", + "aten::pop.Dict_default_str", + "aten::popitem.str", + "aten::clear.str", + "aten::update.str", + "aten::items.str", + "aten::copy.Dict_str", + "aten::__contains__.str", + "aten::_set_item.str", + "aten::dict.str", + "aten::len.Dict_int", + "aten::keys.int", + "aten::values.int", + "aten::__getitem__.Dict_int", + "aten::get.int", + "aten::get.default_int", + "aten::setdefault.int", + "aten::Delete.Dict_int", + "aten::pop.Dict_int", + "aten::pop.Dict_default_int", + "aten::popitem.int", + "aten::clear.int", + "aten::update.int", + "aten::items.int", + "aten::copy.Dict_int", + "aten::__contains__.int", + "aten::_set_item.int", + "aten::dict.int", + "aten::len.Dict_bool", + "aten::keys.bool", + "aten::values.bool", + "aten::__getitem__.Dict_bool", + "aten::get.bool", + "aten::get.default_bool", + "aten::setdefault.bool", + "aten::Delete.Dict_bool", + "aten::pop.Dict_bool", + "aten::pop.Dict_default_bool", + "aten::popitem.bool", + "aten::clear.bool", + "aten::update.bool", + "aten::items.bool", + "aten::copy.Dict_bool", + "aten::__contains__.bool", + "aten::_set_item.bool", + "aten::dict.bool", + "aten::len.Dict_float", + "aten::keys.float", + "aten::values.float", + "aten::__getitem__.Dict_float", + "aten::get.float", + "aten::get.default_float", + "aten::setdefault.float", + "aten::Delete.Dict_float", + "aten::pop.Dict_float", + "aten::pop.Dict_default_float", + "aten::popitem.float", + "aten::clear.float", + "aten::update.float", + "aten::items.float", + "aten::copy.Dict_float", + "aten::__contains__.float", + "aten::_set_item.float", + "aten::dict.float", + "aten::len.Dict_Tensor", + "aten::keys.Tensor", + "aten::values.Tensor", + "aten::__getitem__.Dict_Tensor", + "aten::get.Tensor", + "aten::get.default_Tensor", + "aten::setdefault.Tensor", + "aten::Delete.Dict_Tensor", + "aten::pop.Dict_Tensor", + "aten::pop.Dict_default_Tensor", + "aten::popitem.Tensor", + "aten::clear.Tensor", + "aten::update.Tensor", + "aten::items.Tensor", + "aten::copy.Dict_Tensor", + "aten::__contains__.Tensor", + "aten::_set_item.Tensor", + "aten::dict.Tensor", + "aten::__round_to_zero_floordiv.int", + "aten::mathremainder.int", + "aten::mathremainder.float", + "aten::mathremainder.int_float", + "aten::mathremainder.float_int", + "aten::mathremainder", + "aten::__and__.int", + "aten::__or__.int", + "aten::__xor__.int", + "aten::__lshift__.int", + "aten::__rshift__.int", + "aten::round.int", + "aten::round.float", + "aten::round.Scalar", + "aten::log.int", + "aten::log.float", + "aten::log.Scalar", + "aten::log.int_int", + "aten::log.float_float", + "aten::log.int_float", + "aten::log.float_int", + "aten::log.Scalar_Scalar", + "aten::log1p.int", + "aten::log1p.float", + "aten::log1p.Scalar", + "aten::log10.int", + "aten::log10.float", + "aten::log10.Scalar", + "aten::sqrt.int", + "aten::sqrt.float", + "aten::sqrt.Scalar", + "aten::acos.int", + "aten::acos.float", + "aten::acos.Scalar", + "aten::asin.int", + "aten::asin.float", + "aten::asin.Scalar", + "aten::atan.int", + "aten::atan.float", + "aten::atan.Scalar", + "aten::atan2.int", + "aten::atan2.float", + "aten::atan2.int_float", + "aten::atan2.float_int", + "aten::atan2.Scalar_Scalar", + "aten::cos.int", + "aten::cos.float", + "aten::cos.Scalar", + "aten::sin.int", + "aten::sin.float", + "aten::sin.Scalar", + "aten::tan.int", + "aten::tan.float", + "aten::tan.Scalar", + "aten::asinh.int", + "aten::asinh.float", + "aten::asinh.Scalar", + "aten::atanh.int", + "aten::atanh.float", + "aten::atanh.Scalar", + "aten::acosh.int", + "aten::acosh.float", + "aten::acosh.Scalar", + "aten::sinh.int", + "aten::sinh.float", + "aten::sinh.Scalar", + "aten::cosh.int", + "aten::cosh.float", + "aten::cosh.Scalar", + "aten::tanh.int", + "aten::tanh.float", + "aten::tanh.Scalar", + "aten::degrees.int", + "aten::degrees.float", + "aten::degrees.Scalar", + "aten::radians.int", + "aten::radians.float", + "aten::radians.Scalar", + "aten::fmod.int", + "aten::fmod.float", + "aten::fmod.int_float", + "aten::fmod.float_int", + "aten::fmod", + "aten::factorial.int", + "aten::isnan.float", + "aten::isfinite.float", + "aten::isinf.float", + "aten::gamma.int", + "aten::gamma.float", + "aten::gamma.Scalar", + "aten::erf.int", + "aten::erf.float", + "aten::erf.Scalar", + "aten::erfc.int", + "aten::erfc.float", + "aten::erfc.Scalar", + "aten::expm1.int", + "aten::expm1.float", + "aten::expm1.Scalar", + "aten::fabs.int", + "aten::fabs.float", + "aten::fabs.Scalar", + "aten::lgamma.int", + "aten::lgamma.float", + "aten::lgamma.Scalar", + "prim::abs.int", + "prim::abs.float", + "prim::abs.Scalar", + "aten::gcd.int", + "aten::copysign.int", + "aten::copysign.float", + "aten::copysign.int_float", + "aten::copysign.float_int", + "aten::copysign", + "aten::split", + "aten::tensor.float", + "aten::as_tensor.float", + "aten::tensor.int", + "aten::as_tensor.int", + "aten::tensor.bool", + "aten::as_tensor.bool", + "aten::_infer_size", + "aten::_no_grad_embedding_renorm_", + "aten::tensor", + "aten::as_tensor", + "aten::as_tensor.list", + "aten::_pack_sequence", + "aten::_get_tracing_state", + "aten::is_scripting", + "aten::_no_grad_uniform_", + "aten::_no_grad_normal_", + "aten::_no_grad_fill_", + "aten::_no_grad_zero_", +] + +PT_BASE_OPS = [ + "aten::_coalesced_", + "aten::_copy_from", + "aten::_empty_affine_quantized", + "aten::_empty_per_channel_affine_quantized", + "aten::_indices", + "aten::_nnz", + "aten::_values", + "aten::add", + "aten::add_", + "aten::arange", + "aten::as_strided", + "aten::as_strided_", + "aten::cat", + "aten::clone", + "aten::coalesce", + "aten::contiguous", + "aten::copy_", + "aten::copy_sparse_to_sparse_", + "aten::dense_dim", + "aten::dequantize", + "aten::div", + "aten::div_", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::eq", + "aten::equal", + "aten::expand", + "aten::fill_", + "aten::is_coalesced", + "aten::is_complex", + "aten::is_floating_point", + "aten::is_leaf", + "aten::is_nonzero", + "aten::item", + "aten::max", + "aten::min", + "aten::mul", + "aten::mul_", + "aten::narrow", + "aten::ne", + "aten::permute", + "aten::q_per_channel_axis", + "aten::q_per_channel_scales", + "aten::q_per_channel_zero_points", + "aten::q_scale", + "aten::q_zero_point", + "aten::qscheme", + "aten::quantize_per_tensor", + "aten::reshape", + "aten::_reshape_alias", + "aten::resize_", + "aten::resize_as_", + "aten::scalar_tensor", + "aten::select", + "aten::set_", + "aten::size", + "aten::slice", + "aten::sparse_dim", + "aten::sparse_resize_and_clear_", + "aten::squeeze", + "aten::squeeze_", + "aten::stride", + "aten::sub", + "aten::sub_", + "aten::sum", + "aten::t", + "aten::to", + "aten::_to_copy", + "aten::unsqueeze", + "aten::view", + "aten::zero_", + "aten::zeros", + "aten::zeros_like", +] diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index 8f1499268aaf1..7d8dfd53d376b 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -3,7 +3,7 @@ # being built load("@bazel_skylib//lib:paths.bzl", "paths") -load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") +load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load(":build_variables.bzl", "aten_native_source_list") load( ":ufunc_defs.bzl", @@ -59,6 +59,7 @@ METAL_SOURCE_LIST = [ "aten/src/ATen/native/metal/ops/MetalConvolution.mm", "aten/src/ATen/native/metal/ops/MetalCopy.mm", "aten/src/ATen/native/metal/ops/MetalHardswish.mm", + "aten/src/ATen/native/metal/ops/MetalHardshrink.mm", "aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm", "aten/src/ATen/native/metal/ops/MetalNeurons.mm", "aten/src/ATen/native/metal/ops/MetalPadding.mm", @@ -134,6 +135,7 @@ def get_generate_code_bin_outs(): if is_arvr_mode(): outs.update({ + "autograd/generated/python_enum_tag.cpp": ["autograd/generated/python_enum_tag.cpp"], "autograd/generated/python_fft_functions.cpp": ["autograd/generated/python_fft_functions.cpp"], "autograd/generated/python_functions.h": ["autograd/generated/python_functions.h"], "autograd/generated/python_functions_0.cpp": ["autograd/generated/python_functions_0.cpp"], @@ -150,17 +152,17 @@ def get_generate_code_bin_outs(): "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"], "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"], "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"], - "autograd/generated/python_enum_tag.cpp": ["autograd/generated/python_enum_tag.cpp"], }) return outs -def get_template_registration_files_outs(): +def get_template_registration_files_outs(is_oss = False): outs = {} - for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST: - outs[file_path] = [file_path] + if not is_oss: + for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST: + outs[file_path] = [file_path] - for file_path in TEMPLATE_BATCH_BOX_COX_SOURCE_LIST: - outs[file_path] = [file_path] + for file_path in TEMPLATE_BATCH_BOX_COX_SOURCE_LIST: + outs[file_path] = [file_path] for file_path in TEMPLATE_SOURCE_LIST: outs[file_path] = [file_path] @@ -171,9 +173,9 @@ def get_template_registration_files_outs(): return outs -def get_template_registration_file_rules(rule_name): +def get_template_registration_file_rules(rule_name, is_oss = False): rules = [] - for file_path in TEMPLATE_SOURCE_LIST + TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST: + for file_path in TEMPLATE_SOURCE_LIST if is_oss else (TEMPLATE_SOURCE_LIST + TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST): rules.append(":{}[{}]".format(rule_name, file_path)) for file_path in aten_ufunc_generated_all_cpu_sources(): rules.append(":{}[aten/src/ATen/{}]".format(rule_name, file_path)) diff --git a/pytest.ini b/pytest.ini index c72e4ec8e8d33..53b5ad643ebf2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,12 +1,12 @@ [pytest] addopts = # show summary of all tests that did not pass - -ra + -rEfX # Make tracebacks shorter --tb=native # capture only Python print and C++ py::print, but not C output (low-level Python errors) --capture=sys - # enable all warnings - -Wd + --disable-warnings testpaths = test +junit_logging_reruns = all diff --git a/scripts/buck_setup.sh b/scripts/buck_setup.sh index 0d094fd98e955..8e60d92a5fd15 100644 --- a/scripts/buck_setup.sh +++ b/scripts/buck_setup.sh @@ -1,29 +1,37 @@ #!/bin/bash -printf "\n[Creating .buckconfig]\n" +printf "\nCreating .buckconfig\n" cp .buckconfig.oss .buckconfig +PROXY="" +if [ "$1" == "devserver" ]; then + echo -e '\n[download]\n proxy_host=fwdproxy\n proxy_port=8080\n proxy_type=HTTP\n' >> .buckconfig + PROXY="$(fwdproxy-config curl)" + printf "using proxy $PROXY\n\n" +fi + +cat .buckconfig + cd third_party || return -printf "\n[Generating wrappers for cpuionfo]\n" +printf "\nGenerating cpuinfo wrappers\n" python3 generate-cpuinfo-wrappers.py -printf "\n[Generating wrappers for xnnpack]\n" +printf "\nGenerating xnnpack wrappers\n" python3 generate-xnnpack-wrappers.py # bazel-skylib -printf "\n[Downloading bazel-skylib-1.0.2]\n" -curl -L -o /tmp/bazel-skylib-1.0.2.tar.gz https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz -mkdir bazel-skylib -tar -xf /tmp/bazel-skylib-1.0.2.tar.gz -C bazel-skylib/ +printf "\nDownloading bazel-skylib\n" +rm -rf bazel-skylib; mkdir bazel-skylib +curl -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib # glog -printf "\n[Downloading glog-0.4.0]\n" -curl -L -o /tmp/glog-0.4.0.tar.gz https://github.com/google/glog/archive/v0.4.0.tar.gz -tar -xf /tmp/glog-0.4.0.tar.gz -C /tmp/ -mv /tmp/glog-0.4.0/ glog/ +printf "\nDownloading glog\n" +rm -rf glog; mkdir glog +curl -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 # ruy -printf "\n[Downloading ruy]\n" -curl -L -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip +printf "\nDownloading ruy\n" +curl -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip unzip -q /tmp/ruy.zip -d /tmp/ +rm -rf ruy/ mv /tmp/ruy-a09683b8da7164b9c5704f88aef2dc65aa583e5d ruy/ diff --git a/scripts/build_android.sh b/scripts/build_android.sh index 5913f5e8b768c..225caa68abfcd 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -135,6 +135,7 @@ else fi # Disable unused dependencies CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") CMAKE_ARGS+=("-DUSE_OPENCV=OFF") CMAKE_ARGS+=("-DUSE_LMDB=OFF") diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh index 2bb8763ef17dc..a0402db65a79b 100755 --- a/scripts/build_ios.sh +++ b/scripts/build_ios.sh @@ -104,6 +104,7 @@ CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") # Disable unused dependencies CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") CMAKE_ARGS+=("-DUSE_OPENCV=OFF") CMAKE_ARGS+=("-DUSE_LMDB=OFF") diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh index bc87474ff1f93..0cc49301baf1f 100755 --- a/scripts/build_mobile.sh +++ b/scripts/build_mobile.sh @@ -38,6 +38,7 @@ fi # Disable unused dependencies CMAKE_ARGS+=("-DUSE_ROCM=OFF") CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") CMAKE_ARGS+=("-DUSE_OPENCV=OFF") CMAKE_ARGS+=("-DUSE_LMDB=OFF") diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh index 38e90fa39901e..c9d26ced319a6 100755 --- a/scripts/build_tizen.sh +++ b/scripts/build_tizen.sh @@ -112,6 +112,7 @@ cd $BUILD_ROOT cmake "$CAFFE2_ROOT" \ -DCMAKE_VERBOSE_MAKEFILE=1 \ -DUSE_CUDA=OFF \ + -DUSE_ITT=OFF \ -DUSE_OPENCV=OFF \ -DUSE_LMDB=OFF \ -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ diff --git a/scripts/fbcode-dev-setup/onnx_c2_sanity_check.sh b/scripts/fbcode-dev-setup/onnx_c2_sanity_check.sh deleted file mode 100755 index bb4d1efa4e35c..0000000000000 --- a/scripts/fbcode-dev-setup/onnx_c2_sanity_check.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -set -e - -python -c 'from caffe2.python import build; from pprint import pprint; pprint(build.build_options)' -python -c 'from caffe2.python import core, workspace; print("GPUs found: " + str(workspace.NumCudaDevices()))' -python -c "import onnx" -python -c "import torch" - -echo "Caffe2, PyTorch and ONNX installed successfully!!" diff --git a/scripts/fbcode-dev-setup/onnx_c2_setup.sh b/scripts/fbcode-dev-setup/onnx_c2_setup.sh deleted file mode 100755 index 2fc629657e5fb..0000000000000 --- a/scripts/fbcode-dev-setup/onnx_c2_setup.sh +++ /dev/null @@ -1,154 +0,0 @@ -#!/bin/bash - -# This script helps developers set up the ONNX Caffe2 and PyTorch develop environment on devgpu. -# It creates an virtualenv instance, and installs all the dependencies in this environment. -# The script will creates a folder called onnx-dev folder under the $HOME directory. -# onnx, pytorch and caffe2 are installed separately. -# Please source $HOME/onnx-dev/.onnx_env_init to initialize the development before starting developing. - - -# TODO: support python 3. - -# Set script configuration -set -e -shopt -s expand_aliases - -# Proxy setup -alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" - -# Set the variables -RED='\033[0;31m' -CYAN='\033[0;36m' -NC='\033[0m' -onnx_root="$HOME/local/onnx-dev" # I think hardcoding the onnx root dir is fine, just like fbsource -onnx_root_link="$HOME/onnx-dev" -venv="$onnx_root/onnxvenv" -onnx_init_file="$onnx_root_link/.onnx_env_init" -ccache_root="$onnx_root/ccache" -ccache_script="$(pwd)/ccache_install.sh" -sanity_script="$onnx_root/sanity.sh" - -# Check whether default CUDA exists -# TODO check the required header and lib files -default_cuda="/usr/local/cuda" -if [[ ! -e "$default_cuda" ]]; then - echo "Default CUDA is not found at $default_cuda" -fi - -# Checking to see if CuDNN is present, and install it if not exists -if [ -f /usr/local/cuda/include/cudnn.h ]; then - echo "CuDNN header already exists!!" -else - sudo cp -R /home/engshare/third-party2/cudnn/6.0.21/src/cuda/include/* /usr/local/cuda/include/ - sudo cp -R /home/engshare/third-party2/cudnn/6.0.21/src/cuda/lib64/* /usr/local/cuda/lib64/ -fi - -# TODO set the specific version for each package -# Install the dependencies for Caffe2 -sudo yum install python-virtualenv freetype-devel libpng-devel glog gflags protobuf protobuf-devel protobuf-compiler -y -rpm -q protobuf # check the version and if necessary update the value below -protoc --version # check protoc -protoc_path=$(which protoc) -if [[ "$protoc_path" != "/bin/protoc" ]]; then - echo "Warning: Non-default protoc is detected, the script may not work with non-default protobuf!!!" - echo "Please try to remove the protoc at $protoc_path and rerun this script." - exit 1 -fi - -# Upgrade Cmake to the right version (>3.0) -sudo yum remove cmake3 -y -sudo yum install cmake -y - -# Install the dependencies for CCache -sudo yum install autoconf asciidoc -y - -# Create the root folder -if [ -e "$onnx_root" ]; then - timestamp=$(date "+%Y.%m.%d-%H.%M.%S") - mv --backup=t -T "$onnx_root" "${onnx_root}.old.$timestamp" -fi -mkdir -p "$onnx_root" -if [ -e "$onnx_root_link"]; then - timestamp=$(date "+%Y.%m.%d-%H.%M.%S") - mv --backup=t -T "$onnx_root_link" "${onnx_root_link}.old.$timestamp" -fi -ln -s "$onnx_root" "$onnx_root_link" - -# Set the name of virtualenv instance -with_proxy virtualenv "$venv" - -# Creating a script that can be sourced in the future for the environmental variable -touch "$onnx_init_file" -{ - # shellcheck disable=SC2016 - echo 'if [ -z "$LD_LIBRARY_PATH" ]; then'; - echo ' export LD_LIBRARY_PATH=/usr/local/cuda/lib64'; - echo 'else' - # shellcheck disable=SC2016 - echo ' export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH'; - echo "fi" - # shellcheck disable=SC2016 - echo 'export PATH='"$ccache_root"'/lib:/usr/local/cuda/bin:$PATH'; - echo "source $venv/bin/activate"; - echo 'alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='"'"'*.facebook.com|*.tfbnw.net|*.fb.com'"'"'"' -} >> "$onnx_init_file" -chmod u+x "$onnx_init_file" - -# Installing CCache -cd "$onnx_root" -if [ ! -f "$ccache_script" ]; then - ccache_script="$onnx_root/ccache_install.sh" - with_proxy wget https://raw.githubusercontent.com/pytorch/pytorch/master/scripts/fbcode-dev-setup/ccache_setup.sh -O "$ccache_script" -fi -chmod u+x "$ccache_script" -"$ccache_script" --path "$ccache_root" - -# Test nvcc with CCache -own_ccache=true -if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$ccache_root/cuda/nvcc" != "$CUDA_NVCC_EXECUTABLE" ]] && \ - [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule - if $CUDA_NVCC_EXECUTABLE --version; then - own_ccache=false - fi -fi -if [ "$own_ccache" = true ]; then - echo "export CUDA_NVCC_EXECUTABLE=$ccache_root/cuda/nvcc" >> "$onnx_init_file" -fi - -# Loading env vars -# shellcheck disable=SC1090 -source "$onnx_init_file" - -"$CUDA_NVCC_EXECUTABLE" --version - -# Create a virtualenv, activate it, upgrade pip -if [ -f "$HOME/.pip/pip.conf" ]; then - echo "${RED}Warning: $HOME/.pip/pip.conf is detected, pip install may fail!${NC}" -fi -with_proxy python -m pip install -U pip setuptools -with_proxy python -m pip install future numpy "protobuf>3.2" pytest-runner pyyaml typing ipython - -# Cloning repos -cd "$onnx_root" -with_proxy git clone https://github.com/onnx/onnx --recursive -with_proxy git clone https://github.com/pytorch/pytorch --recursive - -# Build ONNX -cd "$onnx_root/onnx" -with_proxy python setup.py develop - -# Build PyTorch and Caffe2 -cd "$onnx_root/pytorch" -with_proxy pip install -r "requirements.txt" -with_proxy python setup.py develop - -# Sanity checks and useful info -cd "$onnx_root" -with_proxy wget https://raw.githubusercontent.com/pytorch/pytorch/master/scripts/fbcode-dev-setup/onnx_c2_sanity_check.sh -O "$sanity_script" -chmod u+x "$sanity_script" -$sanity_script - -echo "Congrats, you are ready to rock!!" -echo "################ Please run the following command before development ################" -echo -e "${CYAN}source $onnx_init_file${NC}" -echo "#####################################################################################" diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh index 162f2f54cedbc..2cc7e77895535 100755 --- a/scripts/onnx/test.sh +++ b/scripts/onnx/test.sh @@ -66,11 +66,11 @@ if [[ "${SHARD_NUMBER}" == "1" ]]; then --ignore "$top_dir/test/onnx/test_custom_ops.py" \ --ignore "$top_dir/test/onnx/test_utility_funs.py" \ --ignore "$top_dir/test/onnx/test_models.py" \ + --ignore "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \ "${test_paths[@]}" # Heavy memory usage tests that cannot run in parallel. pytest "${args[@]}" \ - "$top_dir/test/onnx/test_models_onnxruntime.py" \ "$top_dir/test/onnx/test_custom_ops.py" \ "$top_dir/test/onnx/test_utility_funs.py" \ "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "not TestModelsONNXRuntime" @@ -81,6 +81,7 @@ if [[ "${SHARD_NUMBER}" == "2" ]]; then # TODO(#79802): Parameterize test_models.py pytest "${args[@]}" \ "$top_dir/test/onnx/test_models.py" \ + "$top_dir/test/onnx/test_models_quantized_onnxruntime.py" \ "$top_dir/test/onnx/test_models_onnxruntime.py" "-k" "TestModelsONNXRuntime" pytest "${args[@]}" "${args_parallel[@]}" \ diff --git a/scripts/release_notes/.gitignore b/scripts/release_notes/.gitignore new file mode 100644 index 0000000000000..484ab7e5c61d7 --- /dev/null +++ b/scripts/release_notes/.gitignore @@ -0,0 +1 @@ +results/* diff --git a/scripts/release_notes/apply_categories.py b/scripts/release_notes/apply_categories.py new file mode 100644 index 0000000000000..b656f83c1e10f --- /dev/null +++ b/scripts/release_notes/apply_categories.py @@ -0,0 +1,28 @@ +# Quick scipt to apply categorized items to the +# base commitlist . Useful if you are refactoring any code +# but want to keep the previous data on categories + +import commitlist +import csv + +category_csv = "results/category_data.csv" +commitlist_csv = "results/commitlist.csv" + +with open(category_csv, "r") as category_data: + reader = csv.DictReader(category_data, commitlist.commit_fields) + rows = list(reader) + category_map = {row["commit_hash"]: row["category"] for row in rows} + +with open(commitlist_csv, "r") as commitlist_data: + reader = csv.DictReader(commitlist_data, commitlist.commit_fields) + commitlist_rows = list(reader) + +for row in commitlist_rows: + hash = row["commit_hash"] + if hash in category_map and category_map[hash] != "Uncategorized": + row["category"] = category_map[hash] + +with open(commitlist_csv, "w") as commitlist_write: + writer = csv.DictWriter(commitlist_write, commitlist.commit_fields) + writer.writeheader() + writer.writerows(commitlist_rows) diff --git a/scripts/release_notes/categorize.py b/scripts/release_notes/categorize.py index 2bdab4f410414..a79c737d18e50 100644 --- a/scripts/release_notes/categorize.py +++ b/scripts/release_notes/categorize.py @@ -1,12 +1,12 @@ import argparse import os import textwrap -from common import categories, topics, CommitDataCache +from common import categories, topics, get_commit_data_cache from commitlist import CommitList class Categorizer: def __init__(self, path, category='Uncategorized'): - self.cache = CommitDataCache() + self.cache = get_commit_data_cache() self.commits = CommitList.from_existing(path) # Special categories: 'Uncategorized' diff --git a/scripts/release_notes/commitlist.py b/scripts/release_notes/commitlist.py index 4abaffa6fb881..2013b6d512a6f 100644 --- a/scripts/release_notes/commitlist.py +++ b/scripts/release_notes/commitlist.py @@ -2,10 +2,13 @@ from common import run, topics, get_features from collections import defaultdict import os +from pathlib import Path import csv import pprint -from common import CommitDataCache +from common import get_commit_data_cache, features_to_dict import re +import dataclasses +from typing import List """ @@ -21,28 +24,30 @@ python commitlist.py --update_to bfcb687b9c """ - +@dataclasses.dataclass(frozen=True) class Commit: - def __init__(self, commit_hash, category, topic, title): - self.commit_hash = commit_hash - self.category = category - self.topic = topic - self.title = title - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.commit_hash == other.commit_hash and \ - self.category == other.category and \ - self.topic == other.topic and \ - self.title == other.title + commit_hash: str + category: str + topic: str + title: str + pr_link: str + author: str + + # This is not a list so that it is easier to put in a spreadsheet + accepter_1: str + accepter_2: str + accepter_3: str + + merge_into: str = None def __repr__(self): return f'Commit({self.commit_hash}, {self.category}, {self.topic}, {self.title})' +commit_fields = tuple(f.name for f in dataclasses.fields(Commit)) + class CommitList: # NB: Private ctor. Use `from_existing` or `create_new`. - def __init__(self, path, commits): + def __init__(self, path: str, commits: List[Commit]): self.path = path self.commits = commits @@ -59,20 +64,28 @@ def create_new(path, base_version, new_version): return CommitList(path, commits) @staticmethod - def read_from_disk(path): + def read_from_disk(path) -> List[Commit]: with open(path) as csvfile: - reader = csv.reader(csvfile) - rows = list(row for row in reader) - assert all(len(row) >= 4 for row in rows) - return [Commit(*row[:4]) for row in rows] - - def write_to_disk(self): - path = self.path - rows = self.commits + reader = csv.DictReader(csvfile) + rows = [] + for row in reader: + if row.get("new_title", "") != "": + row["title"] = row["new_title"] + filtered_rows = {k: row.get(k, "") for k in commit_fields} + rows.append(Commit(**filtered_rows)) + return rows + + def write_result(self): + self.write_to_disk_static(self.path, self.commits) + + @staticmethod + def write_to_disk_static(path, commit_list): + os.makedirs(Path(path).parent, exist_ok=True) with open(path, 'w') as csvfile: writer = csv.writer(csvfile) - for commit in rows: - writer.writerow([commit.commit_hash, commit.category, commit.topic, commit.title]) + writer.writerow(commit_fields) + for commit in commit_list: + writer.writerow(dataclasses.astuple(commit)) def keywordInFile(file, keywords): for key in keywords: @@ -81,13 +94,26 @@ def keywordInFile(file, keywords): return False @staticmethod - def categorize(commit_hash, title): - features = get_features(commit_hash, return_dict=True) + def gen_commit(commit_hash): + feature_item = get_commit_data_cache().get(commit_hash) + features = features_to_dict(feature_item) + category, topic = CommitList.categorize(features) + a1, a2, a3 = (features["accepters"] + ("", "", ""))[:3] + if features["pr_number"] is not None: + pr_link = f"https://github.com/pytorch/pytorch/pull/{features['pr_number']}" + else: + pr_link = None + + return Commit(commit_hash, category, topic, features["title"], pr_link, features["author"], a1, a2, a3) + + @staticmethod + def categorize(features): title = features['title'] labels = features['labels'] category = 'Uncategorized' topic = 'Untopiced' + # We ask contributors to label their PR's appropriately # when they're first landed. # Check if the labels are there first. @@ -100,15 +126,15 @@ def categorize(commit_hash, title): topic = label.split('topic: ', 1)[1] already_topiced = True if already_categorized and already_topiced: - return Commit(commit_hash, category, topic, title) + return category, topic # update this to check if each file starts with caffe2 if 'caffe2' in title: - return Commit(commit_hash, 'caffe2', topic, title) + return 'caffe2', topic if '[codemod]' in title.lower(): - return Commit(commit_hash, 'skip', topic, title) + return 'skip', topic if 'Reverted' in labels: - return Commit(commit_hash, 'skip', topic, title) + return 'skip', topic if 'bc_breaking' in labels: topic = 'bc-breaking' if 'module: deprecation' in labels: @@ -185,7 +211,7 @@ def categorize(commit_hash, title): category = 'python_frontend' - return Commit(commit_hash, category, topic, title) + return category, topic @staticmethod def get_commits_between(base_version, new_version): @@ -201,7 +227,7 @@ def get_commits_between(base_version, new_version): log_lines = commits.split('\n') hashes, titles = zip(*[log_line.split(' ', 1) for log_line in log_lines]) - return [CommitList.categorize(commit_hash, title) for commit_hash, title in zip(hashes, titles)] + return [CommitList.gen_commit(commit_hash) for commit_hash in hashes] def filter(self, *, category=None, topic=None): commits = self.commits @@ -225,40 +251,60 @@ def stat(self): def create_new(path, base_version, new_version): commits = CommitList.create_new(path, base_version, new_version) - commits.write_to_disk() + commits.write_result() def update_existing(path, new_version): commits = CommitList.from_existing(path) commits.update_to(new_version) - commits.write_to_disk() + commits.write_result() def rerun_with_new_filters(path): current_commits = CommitList.from_existing(path) for i in range(len(current_commits.commits)): c = current_commits.commits[i] if 'Uncategorized' in str(c): - current_commits.commits[i] = CommitList.categorize(c.commit_hash, c.title) - current_commits.write_to_disk() - -def to_markdown(commit_list, category): + feature_item = get_commit_data_cache().get(c.commit_hash) + features = features_to_dict(feature_item) + category, topic = CommitList.categorize(features) + current_commits[i] = dataclasses.replace(c, category=category, topic=topic) + current_commits.write_result() + +def get_hash_or_pr_url(commit: Commit): + # cdc = get_commit_data_cache() + pr_link = commit.pr_link + if pr_link is None: + return commit.commit_hash + else: + regex = r'https://github.com/pytorch/pytorch/pull/([0-9]+)' + matches = re.findall(regex, pr_link) + if len(matches) == 0: + return commit.commit_hash + + return f'[#{matches[0]}]({pr_link})' + +def to_markdown(commit_list: CommitList, category): def cleanup_title(commit): match = re.match(r'(.*) \(#\d+\)', commit.title) if match is None: return commit.title return match.group(1) - cdc = CommitDataCache() + merge_mapping = defaultdict(list) + for commit in commit_list.commits: + if commit.merge_into: + merge_mapping[commit.merge_into].append(commit) + + cdc = get_commit_data_cache() lines = [f'\n## {category}\n'] for topic in topics: lines.append(f'### {topic}\n') commits = commit_list.filter(category=category, topic=topic) for commit in commits: - result = cleanup_title(commit) - maybe_pr_number = cdc.get(commit.commit_hash).pr_number - if maybe_pr_number is None: - result = f'- {result} ({commit.commit_hash})\n' - else: - result = f'- {result} ([#{maybe_pr_number}](https://github.com/pytorch/pytorch/pull/{maybe_pr_number}))\n' + if commit.merge_into: + continue + all_related_commits = merge_mapping[commit.commit_hash] + [commit] + commit_list_md = ", ".join(get_hash_or_pr_url(c) for c in all_related_commits) + result = f'- {cleanup_title(commit)} ({commit_list_md})\n' lines.append(result) return lines @@ -301,7 +347,7 @@ def main(): group.add_argument('--rerun_with_new_filters', action='store_true') group.add_argument('--stat', action='store_true') group.add_argument('--export_markdown', action='store_true') - + group.add_argument('--export_csv_categories', action='store_true') parser.add_argument('--path', default='results/commitlist.csv') args = parser.parse_args() @@ -319,6 +365,16 @@ def main(): stats = commits.stat() pprint.pprint(stats) return + + if args.export_csv_categories: + commits = CommitList.from_existing(args.path) + categories = list(commits.stat().keys()) + for category in categories: + print(f"Exporting {category}...") + filename = f'results/export/result_{category}.csv' + CommitList.write_to_disk_static(filename, commits.filter(category=category)) + return + if args.export_markdown: commits = CommitList.from_existing(args.path) categories = list(commits.stat().keys()) @@ -331,7 +387,7 @@ def main(): with open(filename, 'w') as f: f.writelines(lines) return - assert False + raise AssertionError() if __name__ == '__main__': main() diff --git a/scripts/release_notes/common.py b/scripts/release_notes/common.py index 355dee12adaf2..4509a186cc279 100644 --- a/scripts/release_notes/common.py +++ b/scripts/release_notes/common.py @@ -1,5 +1,5 @@ from collections import namedtuple -from os.path import expanduser +from pathlib import Path import locale import subprocess import re @@ -58,6 +58,8 @@ 'docs', 'devs', 'Untopiced', + "not user facing", + "security", ] @@ -67,6 +69,8 @@ 'pr_number', 'files_changed', 'labels', + 'author', + 'accepters' ]) @@ -76,7 +80,9 @@ def dict_to_features(dct): body=dct['body'], pr_number=dct['pr_number'], files_changed=dct['files_changed'], - labels=dct['labels']) + labels=dct['labels'], + author=dct['author'], + accepters=tuple(dct['accepters'])) def features_to_dict(features): @@ -128,7 +134,7 @@ def parse_pr_number(body, commit_hash, title): def get_ghstack_token(): pattern = 'github_oauth = (.*)' - with open(expanduser('~/.ghstackrc'), 'r+') as f: + with open(Path('~/.ghstackrc').expanduser(), 'r+') as f: config = f.read() matches = re.findall(pattern, config) if len(matches) == 0: @@ -146,47 +152,77 @@ def run_query(query): raise Exception("Query failed to run by returning code of {}. {}".format(request.status_code, query)) -def gh_labels(pr_number): - query = f""" - {{ - repository(owner: "pytorch", name: "pytorch") {{ - pullRequest(number: {pr_number}) {{ - labels(first: 10) {{ - edges {{ - node {{ +def github_data(pr_number): + query = """ + { + repository(owner: "pytorch", name: "pytorch") { + pullRequest(number: %s ) { + author { + login + } + reviews(last: 5, states: APPROVED) { + nodes { + author { + login + } + } + } + labels(first: 10) { + edges { + node { name - }} - }} - }} - }} - }} - }} - """ + } + } + } + } + } + } + """ % pr_number query = run_query(query) + edges = query['data']['repository']['pullRequest']['labels']['edges'] - return [edge['node']['name'] for edge in edges] + labels = [edge['node']['name'] for edge in edges] + author = query['data']['repository']['pullRequest']['author']['login'] + nodes = query['data']['repository']['pullRequest']['reviews']['nodes'] + + # using set to dedup multiple accepts from same accepter + accepters = {node["author"]["login"] for node in nodes} + accepters = tuple(sorted(accepters)) + + return labels, author, accepters -def get_features(commit_hash, return_dict=False): +def get_features(commit_hash): title, body, files_changed = ( commit_title(commit_hash), commit_body(commit_hash), commit_files_changed(commit_hash)) pr_number = parse_pr_number(body, commit_hash, title) labels = [] + author = "" + accepters = tuple() if pr_number is not None: - labels = gh_labels(pr_number) - result = Features(title, body, pr_number, files_changed, labels) - if return_dict: - return features_to_dict(result) + labels, author, accepters = github_data(pr_number) + result = Features(title, body, pr_number, files_changed, labels, author, accepters) return result -class CommitDataCache: - def __init__(self, path='results/data.json'): + +_commit_data_cache = None + +def get_commit_data_cache(path='results/data.json'): + global _commit_data_cache + if _commit_data_cache is None: + _commit_data_cache = _CommitDataCache(path) + return _commit_data_cache + +class _CommitDataCache: + def __init__(self, path): self.path = path self.data = {} if os.path.exists(path): self.data = self.read_from_disk() + else: + os.makedirs(Path(path).parent, exist_ok=True) def get(self, commit): if commit not in self.data.keys(): diff --git a/scripts/release_notes/test_release_notes.py b/scripts/release_notes/test_release_notes.py index 898db48c29295..8bd32eee13f47 100644 --- a/scripts/release_notes/test_release_notes.py +++ b/scripts/release_notes/test_release_notes.py @@ -6,12 +6,12 @@ class TestCommitList(unittest.TestCase): def test_create_new(self): with tempfile.TemporaryDirectory() as tempdir: commit_list_path = f'{tempdir}/commitlist.csv' - commit_list = CommitList.create_new(commit_list_path, 'v1.5.0', '7543e7e558') - self.assertEqual(len(commit_list.commits), 2143) + commit_list = CommitList.create_new(commit_list_path, 'v1.5.0', '6000dca5df') + self.assertEqual(len(commit_list.commits), 33) self.assertEqual(commit_list.commits[0].commit_hash, '7335f079ab') self.assertTrue(commit_list.commits[0].title.startswith('[pt][quant] qmul and qadd')) - self.assertEqual(commit_list.commits[-1].commit_hash, '7543e7e558') - self.assertTrue(commit_list.commits[-1].title.startswith('Migrate minall, max, maxall')) + self.assertEqual(commit_list.commits[-1].commit_hash, '6000dca5df') + self.assertTrue(commit_list.commits[-1].title.startswith('[nomnigraph] Copy device option when customize ')) def test_read_write(self): with tempfile.TemporaryDirectory() as tempdir: diff --git a/setup.py b/setup.py index 77eb7b71e58bc..4facfe5deec18 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,8 @@ # # USE_STATIC_MKL # Prefer to link with MKL statically - Unix only +# USE_ITT=0 +# disable use of Intel(R) VTune Profiler's ITT functionality # # USE_NNPACK=0 # disables NNPACK build @@ -294,7 +296,6 @@ def report(*args): sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")) # Fix virtualenv builds - # TODO: Fix for python < 3.3 if not os.path.exists(cmake_python_library): cmake_python_library = "{}/libs/python{}.lib".format( sys.base_prefix, @@ -541,6 +542,11 @@ def run(self): if cmake_cache_vars['USE_LIGHTWEIGHT_DISPATCH']: report('-- Using lightweight dispatch') + if cmake_cache_vars['USE_ITT']: + report('-- Using ITT') + else: + report('-- Not using ITT') + # Do not use clang to compile extensions if `-fstack-clash-protection` is defined # in system CFLAGS c_flags = str(os.getenv('CFLAGS', '')) @@ -657,20 +663,29 @@ class concat_license_files(): is a single license file in the sdist and wheels with all of the necessary licensing info. """ - def __init__(self): + def __init__(self, include_files=False): self.f1 = 'LICENSE' self.f2 = 'third_party/LICENSES_BUNDLED.txt' + self.include_files = include_files def __enter__(self): """Concatenate files""" + + old_path = sys.path + sys.path.append(third_party_path) + try: + from build_bundled import create_bundled + finally: + sys.path = old_path + with open(self.f1, 'r') as f1: self.bsd_text = f1.read() with open(self.f1, 'a') as f1: - with open(self.f2, 'r') as f2: - self.bundled_text = f2.read() - f1.write('\n\n') - f1.write(self.bundled_text) + f1.write('\n\n') + create_bundled(os.path.relpath(third_party_path), f1, + include_files=self.include_files) + def __exit__(self, exception_type, exception_value, traceback): """Restore content of f1""" @@ -690,7 +705,7 @@ def __exit__(self, exception_type, exception_value, traceback): class wheel_concatenate(bdist_wheel): """ check submodules on sdist to prevent incomplete tarballs """ def run(self): - with concat_license_files(): + with concat_license_files(include_files=True): super().run() @@ -930,9 +945,11 @@ def print_box(msg): print('-' * (size + 2)) if __name__ == '__main__': - # Parse the command line and check the arguments - # before we proceed with building deps and setup + # Parse the command line and check the arguments before we proceed with + # building deps and setup. We need to set values so `--help` works. dist = Distribution() + dist.script_name = os.path.basename(sys.argv[0]) + dist.script_args = sys.argv[1:] try: dist.parse_command_line() except setuptools.distutils.errors.DistutilsArgError as e: @@ -1061,6 +1078,7 @@ def print_box(msg): 'include/torch/csrc/deploy/interpreter/*.h', 'include/torch/csrc/deploy/interpreter/*.hpp', 'include/torch/csrc/distributed/c10d/exception.h', + 'include/torch/csrc/distributed/rpc/*.h', 'include/torch/csrc/jit/*.h', 'include/torch/csrc/jit/backends/*.h', 'include/torch/csrc/jit/generated/*.h', @@ -1086,7 +1104,9 @@ def print_box(msg): 'include/torch/csrc/tensor/*.h', 'include/torch/csrc/lazy/backend/*.h', 'include/torch/csrc/lazy/core/*.h', + 'include/torch/csrc/lazy/core/internal_ops/*.h', 'include/torch/csrc/lazy/core/ops/*.h', + 'include/torch/csrc/lazy/ts_backend/*.h', 'include/pybind11/*.h', 'include/pybind11/detail/*.h', 'include/TH/*.h*', @@ -1148,7 +1168,7 @@ def print_box(msg): 'Programming Language :: Python :: 3', ] + ['Programming Language :: Python :: 3.{}'.format(i) for i in range(python_min_version[1], version_range_max)], license='BSD-3', - keywords='pytorch machine learning', + keywords='pytorch, machine learning', ) if EMIT_BUILD_WARNING: print_box(build_update_message) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 5a7e745bca825..1476b2b4abe47 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1,17 +1,5 @@ { - "torch.amp.autocast_mode": [ - "Any", - "Optional" - ], - "torch.ao.nn.sparse.quantized.dynamic.linear": [ - "LinearBlockSparsePattern", - "Optional", - "hide_packed_params_repr" - ], - "torch.ao.nn.sparse.quantized.linear": [ - "Optional", - "hide_packed_params_repr" - ], + "being_migrated": {}, "torch.ao.quantization": [ "ABC", "ABCMeta", @@ -84,6 +72,15 @@ "Union", "get_combined_dict" ], + "torch.ao.quantization.backend_config.utils": [ + "Any", + "Dict", + "Callable", + "List", + "Union", + "Tuple", + "Pattern" + ], "torch.ao.quantization.backend_config.native": [ "Any", "Dict", @@ -103,13 +100,6 @@ "reverse3", "reverse_sequential_wrapper2" ], - "torch.ao.quantization.backend_config.observation_type": [ - "Enum" - ], - "torch.ao.quantization.backend_config.tensorrt": [ - "ObservationType", - "reverse_sequential_wrapper2" - ], "torch.ao.quantization.quantization_types": [ "Any", "Node", @@ -127,12 +117,6 @@ "Set", "Union" ], - "torch.ao.quantization.fx.lower_to_fbgemm": [ - "Dict", - "QConfigAny", - "QuantizedGraphModule", - "Tuple" - ], "torch.ao.quantization.fx.pattern_utils": [ "Any", "Dict", @@ -186,21 +170,6 @@ "sorted_patterns_dict", "get_quantize_handler_cls" ], - "torch.ao.quantization.observer": [ - "ABC", - "ABCMeta", - "Any", - "Dict", - "List", - "Optional", - "OrderedDict", - "Tuple", - "Union", - "abstractmethod", - "calculate_qmin_qmax", - "check_min_max_valid", - "partial" - ], "torch.ao.quantization.qconfig": [ "Any", "FakeQuantize", @@ -273,42 +242,6 @@ "QuantType", "wrap_cpp_module" ], - "torch.ao.sparsity.experimental.pruner.base_pruner": [ - "ActivationReconstruction", - "BaseSparsifier", - "BiasHook", - "ModuleDict", - "ModuleList", - "PruningParametrization", - "ZeroesParametrization", - "fqn_to_module", - "module_to_fqn" - ], - "torch.ao.sparsity.experimental.pruner.parametrization": [ - "Any", - "List" - ], - "torch.ao.sparsity.scheduler.base_scheduler": [ - "BaseSparsifier", - "wraps" - ], - "torch.ao.sparsity.scheduler.lambda_scheduler": [ - "BaseScheduler" - ], - "torch.ao.sparsity.sparsifier.base_sparsifier": [ - "Dict", - "FakeSparsity", - "Optional", - "Tuple", - "defaultdict", - "fqn_to_module", - "module_to_fqn" - ], - "torch.ao.sparsity.sparsifier.weight_norm_sparsifier": [ - "BaseSparsifier", - "Tuple", - "reduce" - ], "torch.autograd": [ "NestedIOFunction", "detect_anomaly", @@ -392,6 +325,14 @@ "Union", "classproperty" ], + "torch.cuda.comm": [ + "broadcast", + "broadcast_coalesced", + "reduce_add", + "reduce_add_coalesced", + "scatter", + "gather" + ], "torch.cuda.amp.autocast_mode": [ "Any" ], @@ -454,11 +395,6 @@ "ProcessGroupMPI", "ProcessGroupNCCL" ], - "torch.distributed.algorithms.ddp_comm_hooks": [ - "DistributedDataParallel", - "Enum", - "partial" - ], "torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks": [ "Any", "GradBucket" @@ -471,23 +407,6 @@ "Any", "Callable" ], - "torch.distributed.algorithms.join": [ - "ABC", - "Any", - "List", - "NamedTuple", - "Optional", - "TracebackType", - "Type", - "abstractmethod" - ], - "torch.distributed.algorithms.model_averaging.averagers": [ - "ABC", - "Dict", - "Iterable", - "Union", - "abstractmethod" - ], "torch.distributed.algorithms.model_averaging.utils": [ "Dict", "Iterable", @@ -529,57 +448,18 @@ "ProcessGroupMPI", "ProcessGroupNCCL" ], - "torch.distributed.elastic.agent.server.api": [ - "Any", - "Callable", - "Dict", - "Enum", - "Event", - "EventSource", - "List", - "Optional", - "ProcessFailure", - "SignalException", - "Std", - "Store", - "Tuple", - "Union", - "closing", - "dataclass", - "field", - "get_logger", - "prof", - "put_metric", - "record" - ], "torch.distributed.elastic.events": [ "Dict", "Enum", "EventMetadataValue", "Optional" ], - "torch.distributed.elastic.events.api": [ - "Dict", - "Enum", - "EventMetadataValue", - "Optional", - "Union", - "asdict", - "dataclass", - "field" - ], "torch.distributed.elastic.events.handlers": [ "Dict" ], "torch.distributed.elastic.metrics": [ "Optional" ], - "torch.distributed.elastic.metrics.api": [ - "Dict", - "Optional", - "namedtuple", - "wraps" - ], "torch.distributed.elastic.multiprocessing": [ "Callable", "Dict", @@ -623,12 +503,6 @@ "get_logger", "wraps" ], - "torch.distributed.elastic.multiprocessing.errors.error_handler": [ - "Optional" - ], - "torch.distributed.elastic.multiprocessing.errors.handlers": [ - "ErrorHandler" - ], "torch.distributed.elastic.multiprocessing.redirects": [ "contextmanager", "partial", @@ -658,69 +532,7 @@ "abstractmethod" ], "torch.distributed.elastic.rendezvous.dynamic_rendezvous": [ - "ABC", - "Any", - "Callable", - "Dict", - "Enum", - "List", - "NodeState", - "Optional", - "PrefixStore", - "RendezvousClosedError", - "RendezvousError", - "RendezvousHandler", - "RendezvousParameters", - "RendezvousStateError", - "RendezvousTimeoutError", - "Set", - "Store", - "Token", - "Tuple", - "abstractmethod", - "cast", - "construct_and_record_rdzv_event", - "dataclass", - "datetime", - "timedelta" - ], - "torch.distributed.elastic.rendezvous.registry": [ - "RendezvousHandler", - "RendezvousParameters", - "create_handler" - ], - "torch.distributed.elastic.rendezvous.utils": [ - "Any", - "Callable", - "Dict", - "Event", - "Optional", - "Thread", - "Tuple", - "Union", - "timedelta" - ], - "torch.distributed.elastic.timer.api": [ - "Any", - "Dict", - "List", - "Optional", - "Set", - "contextmanager", - "getframeinfo", - "stack" - ], - "torch.distributed.elastic.timer.local_timer": [ - "Any", - "Dict", - "Empty", - "List", - "RequestQueue", - "Set", - "TimerClient", - "TimerRequest", - "TimerServer", - "Tuple" + "get_method_name" ], "torch.distributed.elastic.utils.api": [ "Any", @@ -754,34 +566,6 @@ "Union", "accumulate" ], - "torch.distributed.fsdp.fully_sharded_data_parallel": [ - "Any", - "Callable", - "Dict", - "Enum", - "FlatParameter", - "FlattenParamsWrapper", - "Generator", - "Iterable", - "Iterator", - "List", - "Mapping", - "NamedTuple", - "Optional", - "Parameter", - "ProcessGroup", - "Set", - "Shard", - "ShardedTensor", - "Tuple", - "Union", - "Variable", - "auto", - "cast", - "contextmanager", - "dataclass", - "init_from_local_shards" - ], "torch.distributed.fsdp.utils": [ "Any", "Callable", @@ -803,25 +587,6 @@ "Type", "cast" ], - "torch.distributed.launcher.api": [ - "Any", - "Callable", - "ChildFailedError", - "Dict", - "List", - "LocalElasticAgent", - "Optional", - "RendezvousParameters", - "SignalException", - "Std", - "Tuple", - "Union", - "WorkerSpec", - "dataclass", - "field", - "get_logger", - "parse_rendezvous_endpoint" - ], "torch.distributed.nn": [ "Function", "ReduceOp", @@ -856,66 +621,6 @@ "Optional", "get_remote_module_template" ], - "torch.distributed.optim.functional_adadelta": [ - "Dict", - "List", - "Optional", - "Tensor" - ], - "torch.distributed.optim.functional_adagrad": [ - "Dict", - "List", - "Optional", - "Tensor" - ], - "torch.distributed.optim.functional_adam": [ - "Dict", - "List", - "Optional", - "Tensor", - "Tuple" - ], - "torch.distributed.optim.functional_adamax": [ - "Dict", - "List", - "Optional", - "Tensor", - "Tuple" - ], - "torch.distributed.optim.functional_adamw": [ - "Dict", - "List", - "Optional", - "Tensor", - "Tuple" - ], - "torch.distributed.optim.functional_rmsprop": [ - "Dict", - "List", - "Optional", - "Tensor" - ], - "torch.distributed.optim.functional_rprop": [ - "Dict", - "List", - "Optional", - "Tensor", - "Tuple" - ], - "torch.distributed.optim.functional_sgd": [ - "Dict", - "List", - "Optional", - "Tensor" - ], - "torch.distributed.optim.optimizer": [ - "List", - "Lock", - "Optional", - "RRef", - "Tensor", - "defaultdict" - ], "torch.distributed.optim.utils": [ "Type" ], @@ -1079,129 +784,18 @@ "Optional", "Union" ], - "torch.distributed.rpc.server_process_global_profiler": [ - "profile" - ], - "torch.distributions.bernoulli": [ - "ExponentialFamily", - "Number", - "binary_cross_entropy_with_logits", - "broadcast_all", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.beta": [ + "torch.distributions.kl": [ + "Bernoulli", + "Beta", + "Binomial", + "Callable", + "Categorical", + "Cauchy", + "ContinuousBernoulli", + "Dict", "Dirichlet", - "ExponentialFamily", - "Number", - "Real", - "broadcast_all" - ], - "torch.distributions.binomial": [ - "Distribution", - "broadcast_all", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.categorical": [ - "Distribution", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.cauchy": [ "Distribution", - "Number", - "broadcast_all" - ], - "torch.distributions.chi2": [ - "Gamma" - ], - "torch.distributions.continuous_bernoulli": [ - "ExponentialFamily", - "Number", - "binary_cross_entropy_with_logits", - "broadcast_all", - "clamp_probs", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.dirichlet": [ - "ExponentialFamily", - "Function", - "once_differentiable" - ], - "torch.distributions.distribution": [ - "Any", - "Dict", - "Optional", - "lazy_property" - ], - "torch.distributions.exp_family": [ - "Distribution" - ], - "torch.distributions.exponential": [ - "ExponentialFamily", - "Number", - "broadcast_all" - ], - "torch.distributions.fishersnedecor": [ - "Distribution", - "Gamma", - "Number", - "broadcast_all" - ], - "torch.distributions.gamma": [ - "ExponentialFamily", - "Number", - "broadcast_all" - ], - "torch.distributions.geometric": [ - "Distribution", - "Number", - "binary_cross_entropy_with_logits", - "broadcast_all", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.gumbel": [ - "AffineTransform", - "ExpTransform", - "Number", - "TransformedDistribution", - "Uniform", - "broadcast_all" - ], - "torch.distributions.half_cauchy": [ - "AbsTransform", - "Cauchy", - "TransformedDistribution" - ], - "torch.distributions.half_normal": [ - "AbsTransform", - "Normal", - "TransformedDistribution" - ], - "torch.distributions.independent": [ - "Dict", - "Distribution" - ], - "torch.distributions.kl": [ - "Bernoulli", - "Beta", - "Binomial", - "Callable", - "Categorical", - "Cauchy", - "ContinuousBernoulli", - "Dict", - "Dirichlet", - "Distribution", - "Exponential", + "Exponential", "ExponentialFamily", "Gamma", "Geometric", @@ -1221,117 +815,6 @@ "Uniform", "total_ordering" ], - "torch.distributions.kumaraswamy": [ - "AffineTransform", - "PowerTransform", - "TransformedDistribution", - "Uniform", - "broadcast_all" - ], - "torch.distributions.laplace": [ - "Distribution", - "Number", - "broadcast_all" - ], - "torch.distributions.lkj_cholesky": [ - "Beta", - "Distribution", - "broadcast_all" - ], - "torch.distributions.log_normal": [ - "ExpTransform", - "Normal", - "TransformedDistribution" - ], - "torch.distributions.logistic_normal": [ - "Normal", - "StickBreakingTransform", - "TransformedDistribution" - ], - "torch.distributions.lowrank_multivariate_normal": [ - "Distribution", - "lazy_property" - ], - "torch.distributions.mixture_same_family": [ - "Categorical", - "Dict", - "Distribution" - ], - "torch.distributions.multinomial": [ - "Binomial", - "Categorical", - "Distribution", - "broadcast_all" - ], - "torch.distributions.multivariate_normal": [ - "Distribution", - "lazy_property" - ], - "torch.distributions.negative_binomial": [ - "Distribution", - "broadcast_all", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.normal": [ - "ExponentialFamily", - "Number", - "Real", - "broadcast_all" - ], - "torch.distributions.one_hot_categorical": [ - "Categorical", - "Distribution" - ], - "torch.distributions.pareto": [ - "AffineTransform", - "ExpTransform", - "Exponential", - "TransformedDistribution", - "broadcast_all" - ], - "torch.distributions.poisson": [ - "ExponentialFamily", - "Number", - "broadcast_all" - ], - "torch.distributions.relaxed_bernoulli": [ - "Distribution", - "Number", - "SigmoidTransform", - "TransformedDistribution", - "broadcast_all", - "clamp_probs", - "lazy_property", - "logits_to_probs", - "probs_to_logits" - ], - "torch.distributions.relaxed_categorical": [ - "Categorical", - "Distribution", - "ExpTransform", - "TransformedDistribution", - "broadcast_all", - "clamp_probs" - ], - "torch.distributions.studentT": [ - "Chi2", - "Distribution", - "broadcast_all" - ], - "torch.distributions.transformed_distribution": [ - "ComposeTransform", - "Dict", - "Distribution", - "Independent", - "Transform" - ], - "torch.distributions.uniform": [ - "Distribution", - "Number", - "broadcast_all" - ], "torch.distributions.utils": [ "Any", "Dict", @@ -1339,24 +822,6 @@ "is_tensor_like", "update_wrapper" ], - "torch.distributions.von_mises": [ - "Distribution", - "broadcast_all", - "lazy_property" - ], - "torch.distributions.weibull": [ - "AffineTransform", - "Exponential", - "PowerTransform", - "TransformedDistribution", - "broadcast_all" - ], - "torch.distributions.wishart": [ - "ExponentialFamily", - "Number", - "Union", - "lazy_property" - ], "torch.fft": [ "Tensor", "fft", @@ -1384,15 +849,7 @@ "svd_lowrank" ], "torch.futures": [ - "Callable", - "Future", - "Generic", - "List", - "Optional", - "Type", - "TypeVar", - "Union", - "cast" + "Future" ], "torch.fx": [ "ProxyableClassMeta", @@ -1497,44 +954,6 @@ "Tuple", "compatibility" ], - "torch.fx.interpreter": [ - "Any", - "Argument", - "Dict", - "Graph", - "GraphModule", - "Iterator", - "List", - "Node", - "Optional", - "Proxy", - "Target", - "Tracer", - "Tuple", - "Union", - "compatibility", - "map_aggregate", - "map_arg" - ], - "torch.fx.node": [ - "Any", - "ArgsKwargsPair", - "Argument", - "BaseArgumentTypes", - "Callable", - "Dict", - "List", - "Optional", - "Set", - "Target", - "Tuple", - "Union", - "compatibility", - "immutable_dict", - "immutable_list", - "normalize_function", - "normalize_module" - ], "torch.fx.operator_schemas": [ "Any", "Callable", @@ -1555,160 +974,8 @@ "chain", "compatibility" ], - "torch.fx.passes.graph_manipulation": [ - "Any", - "Argument", - "Dict", - "Graph", - "GraphModule", - "List", - "NamedTuple", - "Node", - "Optional", - "ShapeProp", - "Target", - "Tuple", - "compatibility", - "lift_lowering_attrs_to_nodes", - "map_aggregate", - "map_arg" - ], - "torch.fx.passes.net_min_base": [ - "Any", - "Callable", - "Dict", - "FxNetAccFusionsFinder", - "Names", - "NodeList", - "NodeSet", - "Optional", - "ShapeProp", - "TensorOrTensors", - "Tensors", - "Tuple", - "compatibility", - "dataclass", - "map_arg", - "split_by_tags" - ], - "torch.fx.passes.operator_support": [ - "IsNodeSupported", - "SupportDict", - "SupportedArgumentDTypes", - "TargetTypeName", - "TensorMetadata", - "compatibility", - "get_node_target" - ], - "torch.fx.passes.param_fetch": [ - "Any", - "Callable", - "Dict", - "GraphModule", - "List", - "Tuple", - "Type", - "compatibility" - ], - "torch.fx.passes.shape_prop": [ - "Any", - "Dict", - "NamedTuple", - "Node", - "Optional", - "Tuple", - "compatibility", - "map_aggregate" - ], - "torch.fx.passes.split_module": [ - "Any", - "Callable", - "Dict", - "GraphModule", - "List", - "Optional", - "compatibility" - ], - "torch.fx.passes.split_utils": [ - "Dict", - "List", - "NodeList", - "NodeSet", - "Optional", - "compatibility", - "dataclass", - "field", - "map_arg" - ], - "torch.fx.passes.splitter_base": [ - "Any", - "Dict", - "FxGraphDrawer", - "FxNetAccFusionsFinder", - "Iterable", - "List", - "NamedTuple", - "NodeList", - "NodeSet", - "OperatorSupportBase", - "Optional", - "Sequence", - "ShapeProp", - "Tensors", - "Tuple", - "compatibility", - "dataclass", - "defaultdict", - "get_node_target", - "get_size_of_node", - "is_node_output_tensor", - "map_arg", - "split_by_tags" - ], - "torch.fx.passes.tools_common": [ - "Any", - "Dict", - "List", - "Mapping", - "Names", - "NodeList", - "NodeSet", - "Set", - "TensorOrTensors", - "Tensors", - "Tuple", - "Union", - "compatibility", - "dataclass" - ], "torch.fx.proxy": [ - "Any", - "Argument", - "Callable", - "Dict", - "Graph", - "Iterable", - "Iterator", - "Node", - "Optional", - "Target", - "Tuple", - "check_for_mutable_operation", - "compatibility", - "map_aggregate" - ], - "torch.fx.subgraph_rewriter": [ - "Callable", - "Dict", - "Graph", - "GraphModule", - "List", - "NamedTuple", - "Node", - "Optional", - "Set", - "compatibility", - "symbolic_trace" + "assert_fn" ], "torch.hub": [ "HTTPError", @@ -1875,12 +1142,14 @@ "qr", "slogdet", "solve", + "solve_ex", "solve_triangular", "svd", "svdvals", "tensorinv", "tensorsolve", "vander", + "vecdot", "vector_norm" ], "torch.multiprocessing": [ @@ -1976,243 +1245,41 @@ "has_torch_function", "has_torch_function_unary", "has_torch_function_variadic", - "leaky_relu_", - "linear", - "logsigmoid", - "native_channel_shuffle", - "one_hot", - "pairwise_distance", - "pdist", - "pixel_shuffle", - "pixel_unshuffle", - "prelu", - "relu_", - "rrelu_", - "selu_", - "softplus", - "softshrink", - "threshold_" - ], - "torch.nn.init": [ - "Tensor" - ], - "torch.nn.intrinsic.modules": [ - "_FusedModule" - ], - "torch.nn.intrinsic.modules.fused": [ - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "Conv1d", - "Conv2d", - "Conv3d", - "Linear", - "ReLU", - "type_before_parametrizations" - ], - "torch.nn.intrinsic.qat.modules.conv_fused": [ - "Parameter", - "TypeVar", - "fuse_conv_bn_weights" - ], - "torch.nn.intrinsic.qat.modules.linear_fused": [ - "Parameter", - "fuse_linear_bn_weights" - ], - "torch.nn.intrinsic.quantized.modules.conv_relu": [ - "fuse_conv_bn_weights" - ], - "torch.nn.modules.activation": [ - "Module", - "NonDynamicallyQuantizableLinear", - "Optional", - "Parameter", - "Tensor", - "Tuple", - "constant_", - "xavier_normal_", - "xavier_uniform_" - ], - "torch.nn.modules.adaptive": [ - "Linear", - "List", - "Module", - "ModuleList", - "Sequence", - "Sequential", - "Tensor", - "log_softmax", - "namedtuple" - ], - "torch.nn.modules.batchnorm": [ - "Any", - "LazyModuleMixin", - "Module", - "Optional", - "Parameter", - "Tensor", - "UninitializedBuffer", - "UninitializedParameter", - "sync_batch_norm" - ], - "torch.nn.modules.channelshuffle": [ - "Module", - "Tensor" - ], - "torch.nn.modules.container": [ - "Any", - "Dict", - "Iterable", - "Iterator", - "Mapping", - "Module", - "Optional", - "OrderedDict", - "Parameter", - "Tuple", - "TypeVar", - "Union", - "chain", - "islice", - "overload" - ], - "torch.nn.modules.conv": [ - "LazyModuleMixin", - "List", - "Module", - "Optional", - "Parameter", - "Tensor", - "Tuple", - "UninitializedParameter", - "Union" - ], - "torch.nn.modules.distance": [ - "Module", - "Tensor" - ], - "torch.nn.modules.dropout": [ - "Module", - "Tensor" - ], - "torch.nn.modules.flatten": [ - "Module", - "Tensor", - "Tuple", - "Union" - ], - "torch.nn.modules.fold": [ - "Module", - "Tensor" - ], - "torch.nn.modules.instancenorm": [ - "Tensor" - ], - "torch.nn.modules.lazy": [ - "Protocol", - "is_lazy" - ], - "torch.nn.modules.linear": [ - "LazyModuleMixin", - "Module", - "NonDynamicallyQuantizableLinear", - "Parameter", - "Tensor", - "UninitializedParameter" - ], - "torch.nn.modules.loss": [ - "Callable", - "Module", - "Optional", - "PairwiseDistance", - "Tensor" - ], - "torch.nn.modules.module": [ - "Any", - "Callable", - "Dict", - "Iterator", - "List", - "Mapping", - "Optional", - "OrderedDict", - "Parameter", - "RemovableHandle", - "Set", - "Tensor", - "Tuple", - "TypeVar", - "Union", - "device", - "dtype", - "namedtuple", - "overload" - ], - "torch.nn.modules.normalization": [ - "List", - "Module", - "Parameter", - "Size", - "Tensor", - "Tuple", - "Union" - ], - "torch.nn.modules.padding": [ - "Module", - "Sequence", - "Tensor", - "Tuple" - ], - "torch.nn.modules.pixelshuffle": [ - "Module", - "Tensor" + "leaky_relu_", + "linear", + "logsigmoid", + "native_channel_shuffle", + "one_hot", + "pairwise_distance", + "pdist", + "pixel_shuffle", + "pixel_unshuffle", + "prelu", + "relu_", + "rrelu_", + "selu_", + "softplus", + "softshrink", + "threshold_" ], - "torch.nn.modules.pooling": [ - "List", - "Module", - "Optional", + "torch.nn.init": [ "Tensor" ], - "torch.nn.modules.rnn": [ - "List", - "Module", - "Optional", - "PackedSequence", - "Parameter", - "Tensor", - "Tuple", - "overload" + "torch.nn.intrinsic.modules": [ + "_FusedModule" ], - "torch.nn.modules.sparse": [ - "Module", - "Optional", + "torch.nn.intrinsic.qat.modules.linear_fused": [ "Parameter", - "Tensor" + "fuse_linear_bn_weights" ], - "torch.nn.modules.transformer": [ - "Any", - "Callable", - "Dropout", - "LayerNorm", - "Linear", - "Module", - "ModuleList", - "MultiheadAttention", - "Optional", - "Tensor", - "Union", - "xavier_uniform_" + "torch.nn.intrinsic.quantized.modules.conv_relu": [ + "fuse_conv_bn_weights" ], - "torch.nn.modules.upsampling": [ - "Module", - "Optional", - "Tensor" + "torch.nn.modules.linear": [ + "NonDynamicallyQuantizableLinear" ], - "torch.nn.modules.utils": [ - "Any", - "Dict", - "List", - "repeat" + "torch.nn.modules.rnn": [ + "apply_permutation" ], "torch.nn.parallel": [ "DistributedDataParallelCPU" @@ -2220,36 +1287,6 @@ "torch.nn.parallel.comm": [ "List" ], - "torch.nn.parallel.data_parallel": [ - "Module", - "chain", - "gather", - "parallel_apply", - "replicate", - "scatter_kwargs" - ], - "torch.nn.parallel.distributed": [ - "Any", - "Callable", - "Enum", - "Function", - "Join", - "JoinHook", - "Joinable", - "Module", - "RRef", - "ReduceOp", - "Type", - "Variable", - "auto", - "contextmanager", - "dataclass", - "gather", - "is_namedtuple", - "scatter_kwargs", - "tree_flatten", - "tree_unflatten" - ], "torch.nn.parallel.parallel_apply": [ "ExceptionWrapper", "autocast" @@ -2258,8 +1295,7 @@ "OrderedDict" ], "torch.nn.parallel.scatter_gather": [ - "Gather", - "Scatter" + "is_namedtuple" ], "torch.nn.parameter": [ "OrderedDict" @@ -2294,17 +1330,8 @@ "torch.nn.quantized": [ "MaxPool2d" ], - "torch.nn.quantized.dynamic.modules.conv": [ - "Tensor" - ], "torch.nn.quantized.dynamic.modules.rnn": [ - "Dict", - "List", - "Optional", - "PackedSequence", - "Tensor", - "Tuple", - "Union" + "apply_permutation" ], "torch.nn.quantized.functional": [ "List", @@ -2318,86 +1345,17 @@ "torch.nn.quantized.modules.batchnorm": [ "Tensor" ], - "torch.nn.quantized.modules.conv": [ - "List", - "Optional", - "TypeVar", - "WeightedQuantizedModule", - "fuse_conv_bn_weights" - ], - "torch.nn.quantized.modules.embedding_ops": [ - "List", - "Optional", - "Tensor", - "hide_packed_params_repr" - ], - "torch.nn.quantized.modules.functional_modules": [ - "List", - "Tensor" - ], - "torch.nn.quantized.modules.linear": [ - "Iterable", - "Optional", - "WeightedQuantizedModule", - "fuse_linear_bn_weights", - "hide_packed_params_repr", - "type_before_parametrizations" - ], "torch.nn.quantized.modules.utils": [ "repeat" ], - "torch.nn.utils.clip_grad": [ - "Iterable", - "Union" + "torch.nn.utils.rnn": [ + "bind", + "PackedSequence_" ], "torch.nn.utils.convert_parameters": [ "Iterable", "Optional" ], - "torch.nn.utils.parametrizations": [ - "Enum", - "Module", - "Optional", - "Tensor", - "auto" - ], - "torch.nn.utils.parametrize": [ - "Dict", - "Module", - "ModuleDict", - "ModuleList", - "Optional", - "Parameter", - "Sequence", - "Tensor", - "Tuple", - "Union", - "contextmanager" - ], - "torch.nn.utils.rnn": [ - "Iterable", - "List", - "Optional", - "Tensor", - "Tuple", - "Union", - "namedtuple" - ], - "torch.nn.utils.spectral_norm": [ - "Any", - "Module", - "Optional", - "TypeVar", - "normalize" - ], - "torch.nn.utils.weight_norm": [ - "Any", - "Module", - "Parameter", - "TypeVar", - "UninitializedParameter", - "norm_except_dim" - ], "torch.onnx": [ "Dict", "OperatorExportTypes", @@ -2405,126 +1363,19 @@ "TensorProtoDataType", "TrainingMode" ], - "torch.optim.adadelta": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.adagrad": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.adam": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.adamax": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.adamw": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.asgd": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.lbfgs": [ - "Optimizer", - "reduce" - ], - "torch.optim.lr_scheduler": [ - "Counter", - "Optimizer", - "bisect_right", - "wraps" - ], - "torch.optim.nadam": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.optimizer": [ - "chain", - "deepcopy", - "defaultdict" - ], - "torch.optim.radam": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.rmsprop": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.rprop": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.sgd": [ - "List", - "Optimizer", - "Optional", - "Tensor" - ], - "torch.optim.sparse_adam": [ - "Optimizer" - ], - "torch.optim.swa_utils": [ - "Module", - "deepcopy" - ], "torch.overrides": [ "BaseTorchFunctionMode", "TorchFunctionMode", "TorchFunctionModeMeta", "enable_torch_function_mode", "get_default_nowrap_functions", - "has_torch_function", - "push_torch_function_mode" - ], - "torch.package.analyze.find_first_use_of_broken_modules": [ - "Dict", - "List", - "PackagingError" + "has_torch_function" ], "torch.package.analyze.is_from_package": [ "Any", "ModuleType", "is_mangled" ], - "torch.package.analyze.trace_dependencies": [ - "Any", - "Callable", - "Iterable", - "List", - "Tuple" - ], - "torch.package.file_structure_representation": [ - "Dict", - "GlobGroup", - "GlobPattern", - "List" - ], "torch.package.find_file_dependencies": [ "List", "Optional", @@ -2535,92 +1386,12 @@ "Iterable", "Union" ], - "torch.package.importer": [ - "ABC", - "Any", - "Dict", - "List", - "ModuleType", - "Optional", - "Tuple", - "abstractmethod", - "demangle", - "get_mangle_prefix", - "is_mangled" - ], - "torch.package.package_exporter": [ - "ActionHook", - "Any", - "BinaryIO", - "Callable", - "DefaultDict", - "DiGraph", - "Dict", - "Enum", - "GlobGroup", - "GlobPattern", - "Importer", - "List", - "Optional", - "OrderedDict", - "OrderedImporter", - "Path", - "RemovableHandle", - "Sequence", - "Set", - "Storage", - "Union", - "cast", - "create_pickler", - "dataclass", - "defaultdict", - "demangle", - "find_files_source_depends_on", - "is_mangled", - "is_stdlib_module", - "location_tag", - "normalize_storage_type" - ], - "torch.package.package_importer": [ - "Any", - "BinaryIO", - "Callable", - "Dict", - "Directory", - "DirectoryReader", - "GlobPattern", - "Importer", - "List", - "Optional", - "PackageMangler", - "PackageUnpickler", - "Path", - "Union", - "WeakValueDictionary", - "cast", - "contextmanager", - "demangle" - ], "torch.profiler": [ "DeviceType", "ProfilerActivity", "kineto_available", "record_function" ], - "torch.profiler.profiler": [ - "Any", - "Callable", - "Dict", - "Enum", - "Iterable", - "List", - "Optional", - "ProfilerActivity", - "Tuple", - "kineto_available", - "partial", - "warn" - ], "torch.quantization": [ "ABC", "DeQuantStub", @@ -2868,6 +1639,7 @@ "softmax" ], "torch.special": [ + "airy_ai", "bessel_j0", "bessel_j1", "bessel_y0", @@ -2911,12 +1683,15 @@ "polygamma", "psi", "round", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", "shifted_chebyshev_polynomial_t", "shifted_chebyshev_polynomial_u", "shifted_chebyshev_polynomial_v", "shifted_chebyshev_polynomial_w", "sinc", "softmax", + "spherical_bessel_j0", "xlog1py", "xlogy", "zeta" @@ -2960,10 +1735,6 @@ "rand", "randn" ], - "torch.torch_version": [ - "Any", - "Iterable" - ], "torch.types": [ "Any", "Device", @@ -2978,14 +1749,6 @@ "enable_minidumps", "enable_minidumps_on_exceptions" ], - "torch.utils.benchmark.utils.common": [ - "_make_temp_dir", - "ordered_unique", - "select_unit", - "set_torch_threads", - "trim_sigfig", - "unit_to_english" - ], "torch.utils.benchmark.utils.compare": [ "Colorize", "Table", @@ -3074,47 +1837,6 @@ "IO", "Union" ], - "torch.utils.tensorboard.summary": [ - "HistogramProto", - "Optional", - "PrCurvePluginData", - "Summary", - "SummaryMetadata", - "TensorProto", - "TensorShapeProto", - "TextPluginData", - "convert_to_HWC", - "make_np", - "range" - ], - "torch.utils.tensorboard.writer": [ - "Event", - "EventFileWriter", - "ProjectorConfig", - "SessionLog", - "audio", - "custom_scalars", - "figure_to_image", - "get_embedding_info", - "graph", - "histogram", - "histogram_raw", - "hparams", - "image", - "image_boxes", - "load_onnx_graph", - "make_mat", - "make_np", - "make_sprite", - "make_tsv", - "mesh", - "pr_curve", - "pr_curve_raw", - "scalar", - "text", - "video", - "write_pbtxt" - ], "torch": [ "BFloat16Storage", "BFloat16Tensor", @@ -3130,7 +1852,7 @@ "QUInt4x2Storage", "QUInt8Storage", "Storage", - "_TypedStorage", + "TypedStorage", "_adaptive_avg_pool2d", "_adaptive_avg_pool3d", "_add_batch_dim", @@ -3174,7 +1896,6 @@ "_cummax_helper", "_cummin_helper", "_debug_has_internal_overlap", - "_det_lu_based_helper", "_det_lu_based_helper_backward_helper", "_dim_arange", "_dirichlet_grad", @@ -3281,6 +2002,7 @@ "_linalg_inv_out_helper_", "_linalg_qr_helper", "_linalg_svd", + "_linalg_solve_ex", "_log_softmax", "_log_softmax_backward_data", "_logcumsumexp", diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py new file mode 100644 index 0000000000000..c50123a6cf27a --- /dev/null +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -0,0 +1,381 @@ +# -*- coding: utf-8 -*- +# Owner(s): ["module: unknown"] + +import copy +from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo +import logging +import torch +from torch.ao.sparsity._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.sparsity.sparsifier.utils import module_to_fqn +from typing import List + +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3) + self.conv2 = nn.Conv2d(32, 32, kernel_size=3) + self.identity1 = nn.Identity() + self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.linear1 = nn.Linear(4608, 128) + self.identity2 = nn.Identity() + self.linear2 = nn.Linear(128, 10) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + out = self.identity1(out) + out = self.max_pool1(out) + + batch_size = x.shape[0] + out = out.reshape(batch_size, -1) + + out = F.relu(self.identity2(self.linear1(out))) + out = self.linear2(out) + return out + + +class TestActivationSparsifier(TestCase): + def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config): + """Helper function to check if the model, defaults and sparse_config are loaded correctly + in the activation sparsifier + """ + sparsifier_defaults = activation_sparsifier.defaults + combined_defaults = {**defaults, 'sparse_config': sparse_config} + + # more keys are populated in activation sparsifier (eventhough they may be None) + assert len(combined_defaults) <= len(activation_sparsifier.defaults) + + for key, config in sparsifier_defaults.items(): + # all the keys in combined_defaults should be present in sparsifier defaults + assert config == combined_defaults.get(key, None) + + def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list): + """Checks if layers in the model are correctly mapped to it's arguments. + + Args: + activation_sparsifier (sparsifier object) + activation sparsifier object that is being tested. + + defaults (Dict) + all default config (except sparse_config) + + sparse_config (Dict) + default sparse config passed to the sparsifier + + layer_args_list (list of tuples) + Each entry in the list corresponds to the layer arguments. + First entry in the tuple corresponds to all the arguments other than sparse_config + Second entry in the tuple corresponds to sparse_config + """ + # check args + data_groups = activation_sparsifier.data_groups + assert len(data_groups) == len(layer_args_list) + for layer_args in layer_args_list: + layer_arg, sparse_config_layer = layer_args + + # check sparse config + sparse_config_actual = copy.deepcopy(sparse_config) + sparse_config_actual.update(sparse_config_layer) + + name = module_to_fqn(activation_sparsifier.model, layer_arg['layer']) + + assert data_groups[name]['sparse_config'] == sparse_config_actual + + # assert the rest + other_config_actual = copy.deepcopy(defaults) + other_config_actual.update(layer_arg) + other_config_actual.pop('layer') + + for key, value in other_config_actual.items(): + assert key in data_groups[name] + assert value == data_groups[name][key] + + # get_mask should raise error + with self.assertRaises(ValueError): + activation_sparsifier.get_mask(name=name) + + def _check_pre_forward_hook(self, activation_sparsifier, data_list): + """Registering a layer attaches a pre-forward hook to that layer. This function + checks if the pre-forward hook works as expected. Specifically, checks if the + input is aggregated correctly. + + Basically, asserts that the aggregate of input activations is the same as what was + computed in the sparsifier. + + Args: + activation_sparsifier (sparsifier object) + activation sparsifier object that is being tested. + + data_list (list of torch tensors) + data input to the model attached to the sparsifier + + """ + # can only check for the first layer + data_agg_actual = data_list[0] + model = activation_sparsifier.model + layer_name = module_to_fqn(model, model.conv1) + agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn'] + + for i in range(1, len(data_list)): + data_agg_actual = agg_fn(data_agg_actual, data_list[i]) + + assert 'data' in activation_sparsifier.data_groups[layer_name] + assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual) + + return data_agg_actual + + def _check_step(self, activation_sparsifier, data_agg_actual): + """Checks if .step() works as expected. Specifically, checks if the mask is computed correctly. + + Args: + activation_sparsifier (sparsifier object) + activation sparsifier object that is being tested. + + data_agg_actual (torch tensor) + aggregated torch tensor + + """ + model = activation_sparsifier.model + layer_name = module_to_fqn(model, model.conv1) + assert layer_name is not None + + reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn'] + + data_reduce_actual = reduce_fn(data_agg_actual) + mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn'] + sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config'] + mask_actual = mask_fn(data_reduce_actual, **sparse_config) + + mask_model = activation_sparsifier.get_mask(layer_name) + + assert torch.all(mask_model == mask_actual) + + for _, config in activation_sparsifier.data_groups.items(): + assert 'data' not in config + + + def _check_squash_mask(self, activation_sparsifier, data): + """Makes sure that squash_mask() works as usual. Specifically, checks + if the sparsifier hook is attached correctly. + This is achieved by only looking at the identity layers and making sure that + the output == layer(input * mask). + + Args: + activation_sparsifier (sparsifier object) + activation sparsifier object that is being tested. + + data (torch tensor) + dummy batched data + """ + # create a forward hook for checking ouput == layer(input * mask) + def check_output(name): + mask = activation_sparsifier.get_mask(name) + features = activation_sparsifier.data_groups[name].get('features') + feature_dim = activation_sparsifier.data_groups[name].get('feature_dim') + + def hook(module, input, output): + input_data = input[0] + if features is None: + assert torch.all(mask * input_data == output) + else: + for feature_idx in range(0, len(features)): + feature = torch.Tensor([features[feature_idx]], device=input_data.device).long() + inp_data_feature = torch.index_select(input_data, feature_dim, feature) + out_data_feature = torch.index_select(output, feature_dim, feature) + + assert torch.all(mask[feature_idx] * inp_data_feature == out_data_feature) + return hook + + for name, config in activation_sparsifier.data_groups.items(): + if 'identity' in name: + config['layer'].register_forward_hook(check_output(name)) + + activation_sparsifier.model(data) + + + def _check_state_dict(self, sparsifier1): + """Checks if loading and restoring of state_dict() works as expected. + Basically, dumps the state of the sparsifier and loads it in the other sparsifier + and checks if all the configuration are in line. + + This function is called at various times in the workflow to makes sure that the sparsifier + can be dumped and restored at any point in time. + """ + state_dict = sparsifier1.state_dict() + + new_model = Model() + + # create an empty new sparsifier + sparsifier2 = ActivationSparsifier(new_model) + + assert sparsifier2.defaults != sparsifier1.defaults + assert len(sparsifier2.data_groups) != len(sparsifier1.data_groups) + + sparsifier2.load_state_dict(state_dict) + + assert sparsifier2.defaults == sparsifier1.defaults + + # import pdb; pdb.set_trace() + for name, state in sparsifier2.state.items(): + assert name in sparsifier1.state + mask1 = sparsifier1.state[name]['mask'] + mask2 = state['mask'] + + if mask1 is None: + assert mask2 is None + else: + assert type(mask1) == type(mask2) + if isinstance(mask1, List): + assert len(mask1) == len(mask2) + for idx in range(len(mask1)): + assert torch.all(mask1[idx] == mask2[idx]) + else: + # import pdb; pdb.set_trace() + assert torch.all(mask1 == mask2) + + # make sure that the state dict is stored as torch sparse + for _, state in state_dict['state'].items(): + mask = state['mask'] + if mask is not None: + if isinstance(mask, List): + for idx in range(len(mask)): + assert mask[idx].is_sparse + else: + assert mask.is_sparse + + dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups + + for layer_name, config in dg1.items(): + assert layer_name in dg2 + + # exclude hook and layer + config1 = {key: value for key, value in config.items() if key not in ['hook', 'layer']} + config2 = {key: value for key, value in dg2[layer_name].items() if key not in ['hook', 'layer']} + + assert config1 == config2 + + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") + def test_activation_sparsifier(self): + """Simulates the workflow of the activation sparsifier, starting from object creation + till squash_mask(). + The idea is to check that everything works as expected while in the workflow. + """ + # defining aggregate, reduce and mask functions + def agg_fn(x, y): + return x + y + + def reduce_fn(x): + return torch.mean(x, dim=0) + + def _vanilla_norm_sparsifier(data, sparsity_level): + r"""Similar to data norm spasifier but block_shape = (1,1). + Simply, flatten the data, sort it and mask out the values less than threshold + """ + data_norm = torch.abs(data).flatten() + _, sorted_idx = torch.sort(data_norm) + threshold_idx = round(sparsity_level * len(sorted_idx)) + sorted_idx = sorted_idx[:threshold_idx] + + mask = torch.ones_like(data_norm) + mask.scatter_(dim=0, index=sorted_idx, value=0) + mask = mask.reshape(data.shape) + + return mask + + # Creating default function and sparse configs + # default sparse_config + sparse_config = { + 'sparsity_level': 0.5 + } + + defaults = { + 'aggregate_fn': agg_fn, + 'reduce_fn': reduce_fn + } + + # simulate the workflow + # STEP 1: make data and activation sparsifier object + model = Model() # create model + activation_sparsifier = ActivationSparsifier(model, **defaults, **sparse_config) + + # Test Constructor + self._check_constructor(activation_sparsifier, model, defaults, sparse_config) + + # STEP 2: Register some layers + register_layer1_args = { + 'layer': model.conv1, + 'mask_fn': _vanilla_norm_sparsifier + } + sparse_config_layer1 = {'sparsity_level': 0.3} + + register_layer2_args = { + 'layer': model.linear1, + 'features': [0, 10, 234], + 'feature_dim': 1, + 'mask_fn': _vanilla_norm_sparsifier + } + sparse_config_layer2 = {'sparsity_level': 0.1} + + register_layer3_args = { + 'layer': model.identity1, + 'mask_fn': _vanilla_norm_sparsifier + } + sparse_config_layer3 = {'sparsity_level': 0.3} + + register_layer4_args = { + 'layer': model.identity2, + 'features': [0, 10, 20], + 'feature_dim': 1, + 'mask_fn': _vanilla_norm_sparsifier + } + sparse_config_layer4 = {'sparsity_level': 0.1} + + layer_args_list = [(register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2)] + layer_args_list += [(register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4)] + + # Registering.. + for layer_args in layer_args_list: + layer_arg, sparse_config_layer = layer_args + activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer) + + # check if things are registered correctly + self._check_register_layer(activation_sparsifier, defaults, sparse_config, layer_args_list) + + # check state_dict after registering and before model forward + self._check_state_dict(activation_sparsifier) + + # check if forward pre hooks actually work + # some dummy data + data_list = [] + num_data_points = 5 + for _ in range(0, num_data_points): + rand_data = torch.randn(16, 1, 28, 28) + activation_sparsifier.model(rand_data) + data_list.append(rand_data) + + data_agg_actual = self._check_pre_forward_hook(activation_sparsifier, data_list) + # check state_dict() before step() + self._check_state_dict(activation_sparsifier) + + # STEP 3: sparsifier step + activation_sparsifier.step() + + # check state_dict() after step() and before squash_mask() + self._check_state_dict(activation_sparsifier) + + # self.check_step() + self._check_step(activation_sparsifier, data_agg_actual) + + # STEP 4: squash mask + activation_sparsifier.squash_mask() + + self._check_squash_mask(activation_sparsifier, data_list[0]) + + # check state_dict() after squash_mask() + self._check_state_dict(activation_sparsifier) diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py index 577c20a1bddc6..698652ef5bd56 100644 --- a/test/ao/sparsity/test_composability.py +++ b/test/ao/sparsity/test_composability.py @@ -9,6 +9,8 @@ from torch import nn from torch.ao import sparsity from torch.testing._internal.common_utils import TestCase +from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, convert_to_reference_fx, prepare_qat_fx +from torch.ao.sparsity import fqn_to_module logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -20,50 +22,47 @@ "zeros_per_block": 4, } +def _get_model_and_sparsifier_and_sparse_config(qconfig=None): + model = nn.Sequential( + nn.Linear(4, 4), # 0 + nn.ReLU(), + nn.Linear(4, 4), # 2 + nn.ReLU(), + tq.QuantStub(), + nn.Linear(4, 4), # 5 + nn.ReLU(), + tq.DeQuantStub(), + ) + if qconfig: + model[4].qconfig = qconfig + model[5].qconfig = qconfig + + sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults) + + sparse_config = [ + { + "tensor_fqn": '5.weight', + "sparsity_level": 0.7, + "sparse_block_shape": (1, 4), + "zeros_per_block": 4, + }, + {"tensor_fqn": "0.weight"}, + ] + return model, sparsifier, sparse_config + +def _squash_mask_calibrate_and_convert(model, sparsifier, input): + sparsifier.step() + sparsifier.squash_mask() + model(input) + tq.convert(model, inplace=True) + +def _calculate_sparsity(tensor): + return ((tensor == 0).sum() / tensor.numel()).item() + # This series of tests are to check the composability goals for sparsity and quantization. Namely # that performing quantization and sparsity model manipulations in various orderings # does not cause problems class TestComposability(TestCase): - def _get_model_and_sparsifier_and_sparse_config(self, qconfig=None): - model = nn.Sequential( - nn.Linear(4, 4), # 0 - nn.ReLU(), - nn.Linear(4, 4), # 2 - nn.ReLU(), - tq.QuantStub(), - nn.Linear(4, 4), # 5 - nn.ReLU(), - tq.DeQuantStub(), - ) - if qconfig is None: - model[4].qconfig = tq.get_default_qconfig("fbgemm") - model[5].qconfig = tq.get_default_qconfig("fbgemm") - else: - model[4].qconfig = qconfig - model[5].qconfig = qconfig - - sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults) - - sparse_config = [ - { - "tensor_fqn": '5.weight', - "sparsity_level": 0.7, - "sparse_block_shape": (1, 4), - "zeros_per_block": 4, - }, - {"tensor_fqn": "0.weight"}, - ] - return model, sparsifier, sparse_config - - def _squash_mask_calibrate_and_convert(self, model, sparsifier, input): - sparsifier.step() - sparsifier.squash_mask() - model(input) - tq.convert(model, inplace=True) - - def _calculate_sparsity(self, tensor): - return ((tensor == 0).sum() / tensor.numel()).item() - # This test checks whether performing quantization prepare before sparse prepare # causes any issues and verifies that the correct observers are inserted and that # the quantized model works as expected @@ -72,7 +71,7 @@ def test_q_prep_before_s_prep(self): mod, sparsifier, sparse_config, - ) = self._get_model_and_sparsifier_and_sparse_config() + ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm")) tq.prepare(mod, inplace=True) sparsifier.prepare(mod, config=sparse_config) @@ -83,7 +82,7 @@ def test_q_prep_before_s_prep(self): # check that correct observers were inserted self.assertTrue(hasattr(mod[5], "activation_post_process")) - self._squash_mask_calibrate_and_convert( + _squash_mask_calibrate_and_convert( mod, sparsifier, torch.randn(1, 4, 4, 4) ) @@ -101,7 +100,7 @@ def test_s_prep_before_q_prep(self): mod, sparsifier, sparse_config, - ) = self._get_model_and_sparsifier_and_sparse_config() + ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm")) sparsifier.prepare(mod, config=sparse_config) tq.prepare(mod, inplace=True) @@ -115,7 +114,7 @@ def test_s_prep_before_q_prep(self): # occured successfully self.assertTrue(hasattr(mod[5], "activation_post_process")) - self._squash_mask_calibrate_and_convert( + _squash_mask_calibrate_and_convert( mod, sparsifier, torch.randn(1, 4, 4, 4) ) @@ -132,7 +131,7 @@ def test_convert_without_squash_mask(self): mod, sparsifier, sparse_config, - ) = self._get_model_and_sparsifier_and_sparse_config() + ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm")) sparsifier.prepare(mod, config=sparse_config) tq.prepare(mod, inplace=True) @@ -146,7 +145,7 @@ def test_convert_without_squash_mask(self): # occured successfully self.assertTrue(hasattr(mod[5], "activation_post_process")) sparsifier.step() - sparsity_level = self._calculate_sparsity(mod[5].weight) + sparsity_level = _calculate_sparsity(mod[5].weight) mod(torch.randn(1, 4, 4, 4)) tq.convert(mod, inplace=True) @@ -155,7 +154,7 @@ def test_convert_without_squash_mask(self): self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])) # check that module was actually sparsified - cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0]) + cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0]) self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) self.assertGreaterAlmostEqual( sparsity_level, sparse_config[0]["sparsity_level"] @@ -170,7 +169,7 @@ def test_s_prep_before_fusion(self): mod, sparsifier, sparse_config, - ) = self._get_model_and_sparsifier_and_sparse_config() + ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm")) sparsifier.prepare(mod, config=sparse_config) tq.fuse_modules(mod, [["5", "6"]], inplace=True) mod[5].qconfig = tq.get_default_qconfig("fbgemm") @@ -184,7 +183,7 @@ def test_s_prep_before_fusion(self): # check that correct observers were inserted and that matching # occured successfully self.assertTrue(hasattr(mod[5], "activation_post_process")) - self._squash_mask_calibrate_and_convert( + _squash_mask_calibrate_and_convert( mod, sparsifier, torch.randn(1, 4, 4, 4) ) @@ -199,7 +198,7 @@ def test_fusion_before_s_prep(self): mod, sparsifier, _, - ) = self._get_model_and_sparsifier_and_sparse_config() + ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm")) tq.fuse_modules(mod, [["5", "6"]], inplace=True) # its absolutely broken by fusion but will still work if you put the correct fqn in @@ -210,7 +209,7 @@ def test_fusion_before_s_prep(self): "sparse_block_shape": (1, 4), "zeros_per_block": 4, }, - {"tensor_fqn": ".0.weight"}, + {"tensor_fqn": "0.weight"}, ] sparsifier.prepare(mod, config=sparse_config) @@ -226,7 +225,7 @@ def test_fusion_before_s_prep(self): # occured successfully self.assertTrue(hasattr(mod[5], "activation_post_process")) sparsifier.step() - sparsity_level = self._calculate_sparsity(mod[5][0].weight) + sparsity_level = _calculate_sparsity(mod[5][0].weight) mod(torch.randn(1, 4, 4, 4)) tq.convert(mod, inplace=True) @@ -235,7 +234,7 @@ def test_fusion_before_s_prep(self): self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])) # check that module was actually sparsified - cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0]) + cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0]) self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) self.assertGreaterAlmostEqual( sparsity_level, sparse_config[0]["sparsity_level"] @@ -251,7 +250,7 @@ def test_s_prep_before_qat_prep(self): mod, sparsifier, sparse_config, - ) = self._get_model_and_sparsifier_and_sparse_config( + ) = _get_model_and_sparsifier_and_sparse_config( tq.get_default_qat_qconfig("fbgemm") ) sparsifier.prepare(mod, config=sparse_config) @@ -263,7 +262,7 @@ def test_s_prep_before_qat_prep(self): # occured successfully self.assertTrue(hasattr(mod[5], "activation_post_process")) self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear)) - self._squash_mask_calibrate_and_convert( + _squash_mask_calibrate_and_convert( mod, sparsifier, torch.randn(1, 4, 4, 4) ) # check that final module is the expected quantized module and that the model runs @@ -271,12 +270,12 @@ def test_s_prep_before_qat_prep(self): self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])) # check that module was actually sparsified - cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0]) + cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0]) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) # This tests whether performing qat prepare before sparse prepare causes issues. def test_qat_prep_before_s_prep(self): - mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config( + mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config( tq.get_default_qat_qconfig("fbgemm") ) tq.prepare_qat(mod, inplace=True) @@ -289,7 +288,7 @@ def test_qat_prep_before_s_prep(self): "sparse_block_shape": (1, 4), "zeros_per_block": 4, }, - {"tensor_fqn": ".0.weight"}, + {"tensor_fqn": "0.weight"}, ] sparsifier.prepare(mod, config=sparse_config) @@ -303,7 +302,7 @@ def test_qat_prep_before_s_prep(self): self.assertTrue(hasattr(mod[5], "activation_post_process")) self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear)) - self._squash_mask_calibrate_and_convert( + _squash_mask_calibrate_and_convert( mod, sparsifier, torch.randn(1, 4, 4, 4) ) @@ -312,5 +311,273 @@ def test_qat_prep_before_s_prep(self): self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4])) # check that module was actually sparsified - cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0]) + cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0]) + self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + +def _module_has_activation_post_process(model, fqn_of_module): + for node in model.graph.nodes: + # look for an observer whose arg is the target module + if "activation_post_process" in node.name: + if node.args[0].target == fqn_of_module: + return True + return False + +class TestFxComposability(TestCase): + r"""This series of tests checks that various steps of the quantization and sparsity flow + compose cleanly despite variation in sequencing. + """ + def test_q_prep_fx_before_s_prep(self): + r""" + This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx + compose cleanly without issue and that the final result is sparsified without + having to call squash mask between sparse prepare and convert_fx. This also tests the + automatic fusion that occurs during prepare_fx. + """ + ( + mod, + sparsifier, + _, + ) = _get_model_and_sparsifier_and_sparse_config() + + example = torch.randn(1, 4, 4, 4) + qconfig = tq.get_default_qconfig("fbgemm") + qconfig_mapping = tq.QConfigMapping() \ + .set_module_name("4", qconfig) \ + .set_module_name("5", qconfig) + + + mod = prepare_fx(mod, qconfig_mapping, (example,)) + + # its absolutely broken by auto fusion in fx + # but will still work if you put the correct fqn in + sparse_config = [ + { + "tensor_fqn": "5.0.weight", + "sparsity_level": 0.7, + "sparse_block_shape": (1, 4), + "zeros_per_block": 4, + }, + {"tensor_fqn": "0.0.weight"}, + ] + sparsifier.prepare(mod, config=sparse_config) + + # check that correct modules had parametrizations added and + # that none were lost during prepare + self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations")) + self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations")) + + # check that correct observers were inserted and that matching + # occured successfully + self.assertTrue(_module_has_activation_post_process(mod, "5")) + sparsifier.step() + sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + mod(example) + mod = convert_fx(mod) + + # check that final module is the expected quantized module and that the model runs + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU)) + self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4])) + + # check that module was actually sparsified + cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0]) + self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) + self.assertGreaterAlmostEqual( + sparsity_level, sparse_config[0]["sparsity_level"] + ) + self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + + def test_q_prep_fx_s_prep_ref_conv(self): + r""" + This checks that the ordering: prepare_fx -> sparse prepare -> convert_to_reference_fx + compose cleanly without issue and that the final result is sparsified without + having to call squash mask before convert_to_reference_fx. + """ + ( + mod, + sparsifier, + _, + ) = _get_model_and_sparsifier_and_sparse_config() + + example = torch.randn(1, 4, 4, 4) + qconfig = tq.get_default_qconfig("fbgemm") + qconfig_mapping = tq.QConfigMapping() \ + .set_module_name("4", qconfig) \ + .set_module_name("5", qconfig) + + mod = prepare_fx(mod, qconfig_mapping, (example,)) + + # its absolutely broken by auto fusion in fx + # but will still work if you put the correct fqn in + sparse_config = [ + { + "tensor_fqn": "5.0.weight", + "sparsity_level": 0.7, + "sparse_block_shape": (1, 4), + "zeros_per_block": 4, + }, + {"tensor_fqn": "0.0.weight"}, + ] + sparsifier.prepare(mod, config=sparse_config) + + # check that correct modules had parametrizations added and + # that none were lost during prepare + self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations")) + self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations")) + + # check that correct observers were inserted and that matching + # occured successfully + self.assertTrue(_module_has_activation_post_process(mod, "5")) + sparsifier.step() + sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + mod(example) + mod = convert_to_reference_fx(mod) + + # check that final module is the expected quantized module and that the model runs + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.LinearReLU)) + self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4])) + self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.nn.quantized._reference.Linear)) + + # check that module was actually sparsified + cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) + self.assertGreaterAlmostEqual( + sparsity_level, sparse_config[0]["sparsity_level"] + ) + self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + + def test_s_prep_before_q_prep_fx(self): + r""" + This test checks that the ordering of sparse prepare -> prepare_fx -> convert_fx + compose cleanly without issue and that the final result is sparsified without + having to call squash mask before convert_fx. + """ + ( + mod, + sparsifier, + sparse_config, + ) = _get_model_and_sparsifier_and_sparse_config() + sparsifier.prepare(mod, config=sparse_config) + + example = torch.randn(1, 4, 4, 4) + qconfig = tq.get_default_qconfig("fbgemm") + qconfig_mapping = tq.QConfigMapping() \ + .set_module_name("4", qconfig) \ + .set_module_name("5", qconfig) + mod = prepare_fx(mod, qconfig_mapping, (example,)) + + # check that correct modules had parametrizations added and + # that none were lost during prepare + self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations")) + self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations")) + + # check that correct observers were inserted and that matching + # occured successfully + self.assertTrue(_module_has_activation_post_process(mod, "5")) + sparsifier.step() + sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + mod(example) + mod = convert_fx(mod) + + # check that final module is the expected quantized module and that the model runs + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU)) + self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4])) + + # check that module was actually sparsified + cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0]) + self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) + self.assertGreaterAlmostEqual( + sparsity_level, sparse_config[0]["sparsity_level"] + ) + self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + + def test_s_prep_before_qat_prep_fx(self): + r""" + This test checks that the ordering of sparse prepare -> prepare_qat_fx -> convert_fx + compose cleanly without issue and that the final result is sparsified without + having to call squash mask before convert_fx. + """ + ( + mod, + sparsifier, + sparse_config, + ) = _get_model_and_sparsifier_and_sparse_config() + sparsifier.prepare(mod, config=sparse_config) + + example = torch.randn(1, 4, 4, 4) + qconfig = tq.get_default_qat_qconfig("fbgemm") + qconfig_mapping = tq.QConfigMapping() \ + .set_module_name("4", qconfig) \ + .set_module_name("5", qconfig) + mod = prepare_qat_fx(mod, qconfig_mapping, (example,)) + + # check that correct modules had parametrizations added and + # that none were lost during prepare + self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations")) + self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations")) + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.qat.LinearReLU)) + + # check that correct observers were inserted and that matching + # occured successfully + self.assertTrue(_module_has_activation_post_process(mod, "5")) + sparsifier.step() + sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight")) + mod(example) + mod = convert_fx(mod) + + # check that final module is the expected quantized module and that the model runs + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU)) + self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4])) + + # check that module was actually sparsified + cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0]) + self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) + self.assertGreaterAlmostEqual( + sparsity_level, sparse_config[0]["sparsity_level"] + ) + self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + + def test_s_prep_q_prep_fx_ref(self): + r""" + This checks that the ordering: sparse prepare -> prepare_fx -> convert_to_reference_fx + compose cleanly without issue and that the final result is sparsified without + having to call squash mask before convert_to_reference_fx. + """ + ( + mod, + sparsifier, + sparse_config, + ) = _get_model_and_sparsifier_and_sparse_config() + sparsifier.prepare(mod, config=sparse_config) + + example = torch.randn(1, 4, 4, 4) + qconfig = tq.get_default_qconfig("fbgemm") + qconfig_mapping = tq.QConfigMapping() \ + .set_module_name("4", qconfig) \ + .set_module_name("5", qconfig) + mod = prepare_fx(mod, qconfig_mapping, (example,)) + + # check that correct modules had parametrizations added and + # that none were lost during prepare + self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations")) + self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations")) + + # check that correct observers were inserted and that matching + # occured successfully + self.assertTrue(_module_has_activation_post_process(mod, "5")) + sparsifier.step() + sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + mod(example) + mod = convert_to_reference_fx(mod) + + # check that final module is the expected quantized module and that the model runs + self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.LinearReLU)) + self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4])) + self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.nn.quantized._reference.Linear)) + + # check that module was actually sparsified + cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight")) + self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level) + self.assertGreaterAlmostEqual( + sparsity_level, sparse_config[0]["sparsity_level"] + ) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) diff --git a/test/ao/sparsity/test_data_scheduler.py b/test/ao/sparsity/test_data_scheduler.py new file mode 100644 index 0000000000000..543c9afd019fa --- /dev/null +++ b/test/ao/sparsity/test_data_scheduler.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# Owner(s): ["module: unknown"] + +import logging +import warnings +from torch.testing._internal.common_utils import TestCase +from torch import nn +import torch +from typing import Tuple +import copy + +from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier +from torch.ao.sparsity._experimental.data_scheduler import BaseDataScheduler + +logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) + + +class ImplementedDataScheduler(BaseDataScheduler): + def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False): + super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose) + + def get_schedule_param(self): + if self.last_epoch > 0: + return {name: config['sparsity_level'] * 0.5 + for name, config in self.data_sparsifier.data_groups.items()} + else: + return self.base_param + + +class TestBaseDataScheduler(TestCase): + def _get_data(self): + tensor1, param1, emb1 = torch.randn(5, 5), nn.Parameter(torch.randn(10, 10)), nn.Embedding(50, 5) + data_list = [ + ('tensor1', tensor1), ('param1', param1), ('emb1', emb1) + ] + defaults = { + 'sparsity_level': 0.7, + 'sparse_block_shape': (1, 4), + 'zeros_per_block': 2 + } + data_with_config = [ + { + 'name': 'tensor2', 'data': torch.randn(4, 4), + 'config': {'sparsity_level': 0.3} + } + ] + return data_list, data_with_config, defaults + + def _get_sparsifier(self, data_list, data_with_config, defaults): + sparsifier = DataNormSparsifier(data_list, **defaults) + for data_config_dict in data_with_config: + name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config'] + sparsifier.add_data(name=name, data=data, **config) + return sparsifier + + def _get_scheduler(self, sparsifier, schedule_param): + scheduler = ImplementedDataScheduler(sparsifier, schedule_param) + return scheduler + + def _get_schedule_param(self): + return 'sparsity_level' + + def _get_name_data_config(self, some_data, defaults): + config = copy.deepcopy(defaults) + if isinstance(some_data, Tuple): + # dealing with data_list + name, data = some_data + else: + # dealing with data_with_config + name, data, new_config = some_data['name'], some_data['data'], some_data['config'] + config.update(new_config) + return name, data, config + + def test_constructor(self): + """Checks if the warning is thrown if the scheduler step is called + before the sparsifier step""" + data_list, data_with_config, defaults = self._get_data() + sparsifier = self._get_sparsifier(data_list, data_with_config, defaults) + schedule_param = self._get_schedule_param() + scheduler = self._get_scheduler(sparsifier, schedule_param) + + assert scheduler.data_sparsifier == sparsifier + assert scheduler._step_count == 1 + + for name, config in sparsifier.data_groups.items(): + assert scheduler.base_param[name] == config.get(schedule_param, None) + + def test_order_of_steps(self): + data_list, data_with_config, defaults = self._get_data() + sparsifier = self._get_sparsifier(data_list, data_with_config, defaults) + schedule_param = self._get_schedule_param() + scheduler = self._get_scheduler(sparsifier, schedule_param) + + # Sparsifier step is not called + with self.assertWarns(UserWarning): + scheduler.step() + + # Correct order has no warnings + # Note: This will trigger if other warnings are present. + with warnings.catch_warnings(record=True) as w: + sparsifier.step() + scheduler.step() + # Make sure there is no warning related to the base_data_scheduler + for warning in w: + fname = warning.filename + fname = '/'.join(fname.split('/')[-5:]) + assert fname != 'torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py' + + def test_step(self): + data_list, data_with_config, defaults = self._get_data() + sparsifier = self._get_sparsifier(data_list, data_with_config, defaults) + schedule_param = self._get_schedule_param() + scheduler = self._get_scheduler(sparsifier, schedule_param) + + all_data = data_list + data_with_config + + for some_data in all_data: + name, _, config = self._get_name_data_config(some_data, defaults) + assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] + + sparsifier.step() + scheduler.step() + + for some_data in all_data: + name, _, config = self._get_name_data_config(some_data, defaults) + assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] * 0.5 + + # checking step count + step_cnt = 5 + for _ in range(0, step_cnt): + sparsifier.step() + scheduler.step() + + assert scheduler._step_count == step_cnt + 2 # step_cnt + step above + 1 step in constructor + + def test_state_dict(self): + data_list, data_with_config, defaults = self._get_data() + sparsifier = self._get_sparsifier(data_list, data_with_config, defaults) + schedule_param = self._get_schedule_param() + scheduler1 = self._get_scheduler(sparsifier, schedule_param) + + sparsifier.step() + scheduler1.step() + + scheduler2 = self._get_scheduler(sparsifier, schedule_param) + all_data = data_list + data_with_config + for some_data in all_data: + name, _, _ = self._get_name_data_config(some_data, defaults) + assert scheduler1.base_param[name] != scheduler2.base_param[name] + assert scheduler1._last_param[name] == scheduler2.base_param[name] + + scheduler1_state = scheduler1.state_dict() + scheduler2.load_state_dict(scheduler1_state) + + for some_data in all_data: + name, _, _ = self._get_name_data_config(some_data, defaults) + assert scheduler1.base_param[name] == scheduler2.base_param[name] + assert scheduler1._last_param[name] == scheduler2._last_param[name] diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 77becbf7861b7..23c067faf93cc 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -2,15 +2,17 @@ # Owner(s): ["module: unknown"] import logging -import random import torch from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_utils import TestCase -from torch.ao.sparsity import BaseDataSparsifier, DataNormSparsifier + from typing import Tuple from torch import nn import itertools import math +import copy + +from torch.ao.sparsity._experimental.data_sparsifier import BaseDataSparsifier, DataNormSparsifier logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) @@ -26,7 +28,7 @@ def update_mask(self, name, data, **kwargs): linear_state['step_count'] = linear_state.get('step_count', 0) + 1 -class _BaseDataSparsiferTestRunner: +class _BaseDataSparsiferTestCase(TestCase): r"""This helper test class takes in any supported type of and runs some tests. The user is required to pass in the data that needs to sparsified and the runner will run some tests that needs to be passed in order for the data @@ -34,47 +36,59 @@ class _BaseDataSparsiferTestRunner: TODO: Change the structure by creating a separate test case class for each member function """ - def __init__(self, data_list, defaults, data_with_config): - self.data_list = data_list - self.defaults = defaults - self.data_with_config = data_with_config - - # Temporary hack to quickly fix failing tests. - # This will be rewritten as soon as possible - self._test_case = TestCase() - - def _get_name_data_config(self, some_data): + def run_all_checks(self, data_list, data_with_config, defaults): + self.check_constructor(data_list, data_with_config, defaults) + self.check_squash_mask(data_list, data_with_config, defaults) + self.check_add_data(data_list, data_with_config, defaults) + self.check_step(data_list, data_with_config, defaults) + self.check_state_dict(data_list, data_with_config, defaults) + self.check_memory_reference(data_list, data_with_config, defaults) + + @staticmethod + def _get_name_data_config(some_data, defaults=None): if isinstance(some_data, Tuple): # dealing with data_list name, data = some_data - config = self.defaults + config = defaults else: # dealing with data_with_config name, data, config = some_data['name'], some_data['data'], some_data['config'] return name, data, config - def _get_sparsifier(self): - sparsifier = ImplementedSparsifier(data_list=self.data_list, **self.defaults) - assert len(sparsifier.data_groups) == len(self.data_list) - for data_config_dict in self.data_with_config: + @staticmethod + def _make_sparsifier(data_list, data_with_config, defaults, + sparsifier_type=None, sparsifier_kwargs=None): + if sparsifier_type is None: + sparsifier = ImplementedSparsifier(data_list=data_list, **defaults) + else: + kwargs = copy.deepcopy(defaults) + kwargs.update(sparsifier_kwargs) + kwargs['data_list'] = data_list + sparsifier = sparsifier_type(**kwargs) + assert len(sparsifier.data_groups) == len(data_list) + for data_config_dict in data_with_config: name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config'] sparsifier.add_data(name=name, data=data, **config) return sparsifier - def _run_constructor_test(self): - sparsifier = self._get_sparsifier() - assert len(sparsifier.data_groups) == len(self.data_list) + len(self.data_with_config) + def check_constructor(self, data_list, data_with_config, defaults, **kwargs): + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + self.assertEqual(len(sparsifier.data_groups), + len(data_list) + len(data_with_config), + msg="Sparsifier data groups don't match the input " + f"({len(sparsifier.data_groups)} vs. " + f"{len(data_list) + len(data_with_config)}).") - all_data = self.data_list + self.data_with_config + all_data = data_list + data_with_config for some_data in all_data: - name, _, config = self._get_name_data_config(some_data) - assert name in sparsifier.data_groups - assert sparsifier.data_groups[name] == config + name, _, config = self._get_name_data_config(some_data, defaults=defaults) + self.assertIn(name, sparsifier.data_groups) + self.assertEqual(sparsifier.data_groups[name], config) - def _run_step_test(self): - sparsifier = self._get_sparsifier() - all_data = self.data_list + self.data_with_config + def check_step(self, data_list, data_with_config, defaults, **kwargs): + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + all_data = data_list + data_with_config # Check data and mask before doing the step for some_data in all_data: @@ -83,9 +97,9 @@ def _run_step_test(self): sparsified_data = sparsifier.get_data(name=name, return_original=False) original_data = sparsifier.get_data(name=name, return_original=True) mask = sparsifier.get_mask(name=name) - assert torch.all(sparsified_data == data) - assert torch.all(original_data == data) - assert torch.all(mask[0] == 1) + self.assertEqual(sparsified_data, data) + self.assertEqual(original_data, data) + self.assertEqualBroadcasting(mask[0], 1) step_count = 3 @@ -97,15 +111,15 @@ def _run_step_test(self): sparsified_data = sparsifier.get_data(name=name, return_original=False) original_data = sparsifier.get_data(name=name, return_original=True) mask = sparsifier.get_mask(name=name) - assert torch.all(sparsified_data[0] == 0) - assert torch.all(original_data == data) - assert torch.all(mask[0] == 0) + self.assertEqualBroadcasting(sparsified_data[0], 0) + self.assertEqual(original_data, data) + self.assertEqualBroadcasting(mask[0], 0) assert 'step_count' in sparsifier.state[name] assert sparsifier.state[name]['step_count'] == 3 - def _run_squash_mask_test(self): - sparsifier = self._get_sparsifier() - all_data = self.data_list + self.data_with_config + def check_squash_mask(self, data_list, data_with_config, defaults, **kwargs): + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + all_data = data_list + data_with_config for some_data in all_data: name, _, _ = self._get_name_data_config(some_data) assert hasattr(sparsifier._container, name) @@ -116,52 +130,54 @@ def _run_squash_mask_test(self): for some_data in all_data: name, _, _ = self._get_name_data_config(some_data) assert not is_parametrized(sparsifier._container, name) # not parametrized anymore - with self._test_case.assertRaises(ValueError): + with self.assertRaises(ValueError): sparsifier.get_data(name, return_original=True) - def _run_add_data_test(self): - sparsifier = self._get_sparsifier() - all_data = self.data_list + self.data_with_config + def check_add_data(self, data_list, data_with_config, defaults, **kwargs): + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + all_data = data_list + data_with_config for some_data in all_data: - name1, data1, _ = self._get_name_data_config(some_data) + name1, data1, config = self._get_name_data_config(some_data, defaults=defaults) data1 = sparsifier._extract_weight(data1) + data1_old = copy.deepcopy(data1) assert torch.all(data1 == sparsifier.get_data(name=name1)) - # get some other data at random and with the same name - rand_idx = random.randint(0, len(all_data) - 1) - _, data2, _ = self._get_name_data_config(all_data[rand_idx]) - data2 = sparsifier._extract_weight(data2) + + sparsifier.step() + mask = sparsifier.get_mask(name1) + + data2 = torch.randn(data1.shape) # add another data with the same shape as original data sparsifier.add_data(name=name1, data=data2) assert torch.all(data2 == sparsifier.get_data(name=name1)) - def run_tests(self): - self._run_constructor_test() - self._run_squash_mask_test() - self._run_add_data_test() - self._run_step_test() - self._run_state_dict_test() + assert torch.all(sparsifier.get_mask(name1) == mask) # mask should not change + assert torch.all(data1_old == data1) + + assert sparsifier.data_groups[name1] == config # if replaced old_config should match new config - def _run_state_dict_test(self): - sparsifier1 = self._get_sparsifier() - sparsifier2 = ImplementedSparsifier(data_list=[self.data_list[0]]) + def check_state_dict(self, data_list, data_with_config, defaults, **kwargs): + sparsifier1 = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + sparsifier2 = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs) sparsifier1.step() state_dict1 = sparsifier1.state_dict() assert sparsifier1.state != sparsifier2.state - name, _, _ = self._get_name_data_config(self.data_list[0]) - self._test_case.assertNotEqual(sparsifier1.get_mask(name), sparsifier2.get_mask(name)) + name, _, _ = self._get_name_data_config(data_list[0]) + self.assertNotEqual(sparsifier1.get_mask(name), sparsifier2.get_mask(name)) sparsifier2.load_state_dict(state_dict1) assert len(sparsifier1.state) == len(sparsifier2.state) assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups) - for name in sparsifier1.state.keys(): + state1 = state_dict1['state'] + for name in state1.keys(): # compare mask assert name in sparsifier2.state assert 'mask' in sparsifier2.state[name] assert 'mask' in sparsifier1.state[name] - mask1, mask2 = sparsifier1.state[name]['mask'], sparsifier2.state[name]['mask'] - assert torch.all(mask1 == mask2) + mask1, mask2 = state1[name]['mask'], sparsifier2.state[name]['mask'] + assert mask1.is_sparse and not mask2.is_sparse + assert torch.all(mask1.to_dense() == mask2) # mask1 is stored as sparse coo now # compare data_groups dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups @@ -177,87 +193,52 @@ def _run_state_dict_test(self): param2 = getattr(container2.parametrizations, name)[0] assert hasattr(param1, 'mask') assert hasattr(param2, 'mask') - self._test_case.assertEqual(param1.__dict__, param2.__dict__) - - -class TestBaseDataSparsifier(TestCase): - """To add unit tests to support new data types for the BaseDataSparsifier, create the following - data_list: List of tuples of name, data to be added to the constructor - defaults: default config for the above data in data_list - data_with_config: list of dictionaries defining name, data and config (look test_tensors()) - - Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests() - """ - def test_tensors(self): - tensor1, tensor2, tensor3 = torch.randn(3, 3), torch.randn(4, 4), torch.randn(5, 5) - tensor4, tensor5 = torch.randn(1, 1), torch.randn(4, 4) - data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)] - defaults = {'test': 3} - - data_with_config = [ - { - 'name': 'tensor4', 'data': tensor4, 'config': {'test': 7} - }, - { - 'name': 'tensor5', 'data': tensor5, 'config': {'test': 8} - }, - ] - tensor_test = _BaseDataSparsiferTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config) - tensor_test.run_tests() - - def test_nn_parameters(self): - param1, param2, param3 = nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)) - param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(torch.randn(4, 4)) - data_list = [('param1', param1), ('param2', param2), ('param3', param3)] - defaults = {'test': 3} + self.assertEqual(param1.__dict__, param2.__dict__) - data_with_config = [ - { - 'name': 'param4', 'data': param4, 'config': {'test': 7} - }, - { - 'name': 'param5', 'data': param5, 'config': {'test': 8} - }, - ] - param_test = _BaseDataSparsiferTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config) - param_test.run_tests() - - def test_nn_embeddings(self): - emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3) - emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) + def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs): + """Checks if the data is truly "attached" to the sparsifier. Meaning, when the + data is changed outside of the sparsifier, the changes must be reflected on the data + inside the data sparsifier as well. + This makes sure that the sparsifier is holding the memory reference of the data and + not copies. - emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) - data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)] - defaults = {'test': 3} - - data_with_config = [ - { - 'name': 'emb3', 'data': emb3, 'config': {'test': 7} - }, - { - 'name': 'emb3_bag', 'data': emb3_bag, 'config': {'test': 8} - }, - ] - emb_test = _BaseDataSparsiferTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config) - emb_test.run_tests() + This test modifies the data and asserts that data in the sparsifier is changed as well + """ + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs) + all_data = data_list + data_with_config + for some_data in all_data: + name, data, _ = self._get_name_data_config(some_data) + weight = sparsifier._extract_weight(data) + weight.data = weight + torch.randn(*weight.shape) + contained_data = sparsifier.get_data(name=name) + assert id(weight.data) == id(contained_data.data) + assert torch.all(contained_data == weight) -class _NormDataSparsifierTestRunner(_BaseDataSparsiferTestRunner): +class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase): r"""This helper test class takes in any supported type of and runs some tests. This inherits the TestBaseDataSparsifierRuner wherein some functions are over-ridden to take accomodate the specific sparsifier. TODO: Change the structure by creating a separate test case class for each member function """ - def __init__(self, data_list, defaults, data_with_config, norm_type='L1'): - super().__init__(data_list=data_list, defaults=defaults, data_with_config=data_with_config) + def run_all_checks(self, data_list, defaults, data_with_config, norm_type='L1'): assert norm_type in ['L1', 'L2'] - self.norm_type = norm_type - - def _get_bounds_on_actual_sparsity(self, config, tensor_shape): + kwargs = { + 'sparsifier_type': DataNormSparsifier, + 'sparsifier_kwargs': {'norm': norm_type} + } + self.check_constructor(data_list, data_with_config, defaults, **kwargs) + self.check_squash_mask(data_list, data_with_config, defaults, **kwargs) + self.check_add_data(data_list, data_with_config, defaults, **kwargs) + self.check_state_dict(data_list, data_with_config, defaults, **kwargs) + self.check_step(data_list, data_with_config, defaults, norm_type=norm_type) + self.check_step_2_of_4(norm_type=norm_type) + self.check_sparsity_level(data_list, data_with_config, defaults, norm_type=norm_type) + self.check_memory_reference(data_list, data_with_config, defaults, **kwargs) + + @staticmethod + def _get_bounds_on_actual_sparsity(config, tensor_shape): r"""This function gets the bounds on actual sparsity. Note:: Although we specify the sparsity_level parameter, this does not mean that @@ -277,7 +258,7 @@ def _get_bounds_on_actual_sparsity(self, config, tensor_shape): return (1.0, 1.0) else: # min value assumes zeros_per_block is 1 - min_values_sparsified = number_blocks * sparsity_level + min_values_sparsified = round(number_blocks * sparsity_level) # max value assumes actual zeros_per_block max_values_sparsified = min_values_sparsified * min(values_per_block, zeros_per_block) lower_bound = min_values_sparsified / (height * width) @@ -286,17 +267,11 @@ def _get_bounds_on_actual_sparsity(self, config, tensor_shape): lower_bound, upper_bound = round(lower_bound, 3), round(upper_bound, 3) return lower_bound, upper_bound - def _get_sparsifier(self): - sparsifier = DataNormSparsifier(data_list=self.data_list, norm=self.norm_type, **self.defaults) - assert len(sparsifier.data_groups) == len(self.data_list) - for data_config_dict in self.data_with_config: - name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config'] - sparsifier.add_data(name=name, data=data, **config) - return sparsifier - - def _run_step_test(self): - sparsifier = self._get_sparsifier() - all_data = self.data_list + self.data_with_config + def check_step(self, data_list, data_with_config, defaults, norm_type='L1'): + sparsifier = self._make_sparsifier(data_list, data_with_config, defaults, + sparsifier_type=DataNormSparsifier, + sparsifier_kwargs={'norm': norm_type}) + all_data = data_list + data_with_config # mask before step() should not be sparsified for some_data in all_data: @@ -318,7 +293,10 @@ def _run_step_test(self): iters_before_collapse = 100 - test_sparsifier = DataNormSparsifier(sparsity_level=0.5, sparse_block_shape=(1, 4), zeros_per_block=4, norm=self.norm_type) + test_sparsifier = DataNormSparsifier(sparsity_level=0.5, + sparse_block_shape=(1, 4), + zeros_per_block=4, + norm=norm_type) for _ in range(iters_before_collapse): new_data = torch.randn(20, 20) @@ -328,19 +306,19 @@ def _run_step_test(self): mask = mask.to(torch.float) assert (1.0 - mask.mean().item()) > 0 # some sparsity achieved - def _run_step_2_of_4_test(self): + def check_step_2_of_4(self, norm_type): # overriding default config for test purposes default_config = {'sparsity_level': 1.0, 'zeros_per_block': 2, 'sparse_block_shape': (1, 4)} data_list = [('test_data', torch.randn(4, 4))] - sparsifier = DataNormSparsifier(data_list=data_list, norm=self.norm_type, **default_config) + sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type, **default_config) sparsifier.step() for some_data in data_list: name, _ = some_data mask = sparsifier.get_mask(name=name) mask = mask.to(torch.float) - self._test_case.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2) + self.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2) for row in mask: for idx in range(0, len(row), 4): block = row[idx:idx + 4] @@ -348,19 +326,23 @@ def _run_step_2_of_4_test(self): assert (block[:2] == 0).all() assert (block[2:] != 0).all() - def _run_sparsity_level_test(self): + def check_sparsity_level(self, data_list, data_with_config, defaults, norm_type='L1'): sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0] sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)] zeros_per_blocks = [0, 1, 2, 3, 4] - sparsifier = self._get_sparsifier() + sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type) + testcases = itertools.tee(itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)) + assert len(data_with_config) > 0 and 'name' in data_with_config[0] and 'data' in data_with_config[0] # get some data - name, data, _ = self.data_with_config['name'], self.data_with_config['data'] - for sl, sbs, zpb in testcases[0]: - new_name = f'{name}_{sl}_{sbs}_{zpb}' + name, data = data_with_config[0]['name'], data_with_config[0]['data'] + for idx, (sl, sbs, zpb) in enumerate(testcases[0]): + new_name = f'{name}_{idx}' + if zpb > sbs[0] * sbs[1]: + continue current_config = {'sparsity_level': sl, 'sparse_block_shape': sbs, 'zeros_per_block': zpb} sparsifier.add_data(name=new_name, data=data, **current_config) if zpb > sbs[0] * sbs[1]: @@ -368,8 +350,8 @@ def _run_sparsity_level_test(self): sparsifier.step() sparsifier.squash_mask() - for sl, sbs, zpb in testcases[0]: - new_name = f'{name}_{sl}_{sbs}_{zpb}' + for idx, (sl, sbs, zpb) in enumerate(testcases[0]): + new_name = f'{name}_{idx}' sparsified_data = sparsifier.get_data(name=new_name, original=False) # sparse mask sparse_mask = (sparsified_data == 0).float() @@ -381,16 +363,67 @@ def _run_sparsity_level_test(self): true_sl = true_sl * zpb / sbs[0] / sbs[1] assert sparse_mask.mean() == true_sl - def run_tests(self): - self._run_constructor_test() - self._run_squash_mask_test() - self._run_add_data_test() - self._run_state_dict_test() - self._run_step_test() - self._run_step_2_of_4_test() +class TestBaseDataSparsifier(_BaseDataSparsiferTestCase): + """To add unit tests to support new data types for the BaseDataSparsifier, create the following + data_list: List of tuples of name, data to be added to the constructor + defaults: default config for the above data in data_list + data_with_config: list of dictionaries defining name, data and config (look test_tensors()) + + Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests() + """ + def test_tensors(self): + tensor1, tensor2, tensor3 = torch.randn(3, 3), torch.randn(4, 4), torch.randn(5, 5) + tensor4, tensor5 = torch.randn(1, 1), torch.randn(4, 4) + data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)] + defaults = {'test': 3} + + data_with_config = [ + { + 'name': 'tensor4', 'data': tensor4, 'config': {'test': 7} + }, + { + 'name': 'tensor5', 'data': tensor5, 'config': {'test': 8} + }, + ] + self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config) + + def test_nn_parameters(self): + param1, param2, param3 = nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)) + param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(torch.randn(4, 4)) + data_list = [('param1', param1), ('param2', param2), ('param3', param3)] + defaults = {'test': 3} + + data_with_config = [ + { + 'name': 'param4', 'data': param4, 'config': {'test': 7} + }, + { + 'name': 'param5', 'data': param5, 'config': {'test': 8} + }, + ] + self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config) + + def test_nn_embeddings(self): + emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3) + emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) -class TestNormDataSparsifiers(TestCase): + emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) + data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)] + defaults = {'test': 3} + + data_with_config = [ + { + 'name': 'emb3', 'data': emb3, 'config': {'test': 7} + }, + { + 'name': 'emb3_bag', 'data': emb3_bag, 'config': {'test': 8} + }, + ] + self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config) + + +class TestNormDataSparsifiers(_NormDataSparsifierTestCase): """To add unit tests to support new data types for the NormDataSparsifier, create the following data_list: List of tuples of name, data to be added to the constructor defaults: default config for the above data in data_list @@ -399,8 +432,8 @@ class TestNormDataSparsifiers(TestCase): Once the above is done, create an instance of _NormDataSparsifierTestRunner and call run_tests() """ def test_tensors(self): - tensor1, tensor2, tensor3 = torch.randn(3, 3), torch.randn(4, 4), torch.randn(5, 5) - tensor4, tensor5 = torch.randn(10, 10), torch.randn(4, 4) + tensor1, tensor2, tensor3 = torch.randn(1, 10), torch.randn(4, 4), torch.randn(1, 5) + tensor4, tensor5 = torch.randn(1, 2), torch.randn(4, 4) data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)] defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4} @@ -414,16 +447,13 @@ def test_tensors(self): 'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6} }, ] - tensor_test_l1 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L1') - tensor_test_l1.run_tests() - - tensor_test_l2 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L2') - tensor_test_l2.run_tests() + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L1') + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L2') def test_nn_parameters(self): - param1, param2, param3 = nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)) + param1, param2, param3 = nn.Parameter(torch.randn(1, 8)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)) param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(torch.randn(4, 4)) data_list = [('param1', param1), ('param2', param2), ('param3', param3)] defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4} @@ -438,13 +468,10 @@ def test_nn_parameters(self): 'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6} }, ] - param_test_l1 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L1') - param_test_l1.run_tests() - - param_test_l2 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L2') - param_test_l2.run_tests() + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L1') + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L2') def test_nn_embeddings(self): emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3) @@ -464,11 +491,8 @@ def test_nn_embeddings(self): 'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6} }, ] - emb_test_l1 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L1') - emb_test_l1.run_tests() - - emb_test_l2 = _NormDataSparsifierTestRunner(data_list=data_list, defaults=defaults, - data_with_config=data_with_config, norm_type='L2') + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L1') - emb_test_l2.run_tests() + self.run_all_checks(data_list=data_list, defaults=defaults, + data_with_config=data_with_config, norm_type='L2') diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index 04a9343459997..3bcbed09c5946 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -13,8 +13,7 @@ import torch.ao.quantization as tq from torch import nn -from torch.ao.nn.sparse import quantized as ao_nn_sq -from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern +from torch.ao.sparsity.sparsifier.utils import fqn_to_module from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_quantized import ( @@ -110,232 +109,204 @@ def test_sparse_qlinear(self): ) -class TestQuantizedSparseLayers(TestCase): - class SparseQuantizedModel(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.linear = nn.Linear(in_channels, out_channels) - - def forward(self, x): - return self.linear(x) +def _sparse_layer_test_helper( + model_class, + sparse_mapping, + ref_mapping, + qconfig_dict, + fqn_to_check, + test_class, + test_scripting, +): + # SET UP TEST PARAMETERS, INPUTS AND WEIGHTS + # ------------------------------------------ + batch_size = 12 + input_channels = 4 + output_channels = 7 + model = model_class(input_channels, output_channels) + + # For sparse kernels both the activation and weight ZP = 0 + X_scale = 0.2 + X_zp = 2 + W_scale = 1e-2 + W_zp = 0 + + X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32) + float_bias = torch.randn(output_channels, dtype=torch.float32) + + # generate a weight which we'll insert into the model + W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32) + mask = torch.randint(0, 2, W_fp32.shape) + W_fp32 *= mask + with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): + X_q = torch.quantize_per_tensor( + X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8 + ) + X_fp32 = X_q.dequantize() + + W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8) + + # PREPARE MODELS FOR QUANTIZATION + # ------------------------------- + model.linear.weight = nn.Parameter(W_q.dequantize()) + model.eval() + + # Add `sparse_params` to the model. The test for correct + # sparse_param addition is in the sparsifier tests + model.linear.sparse_params = {"sparse_block_shape": (1, 4)} + + # generate model versions + qmodel = copy.deepcopy(model) + sqmodel = copy.deepcopy(model) + + # generate model versions and apply qconfigs + tq.propagate_qconfig_(qmodel, qconfig_dict) + tq.propagate_qconfig_(sqmodel, qconfig_dict) + + tq.prepare(qmodel, inplace=True) + tq.prepare(sqmodel, inplace=True) + + # calibrate + with torch.no_grad(): + qmodel(X_fp32) + sqmodel(X_fp32) + + # ACTUAL TESTING BEGINS HERE + # -------------------------- + + # Make sure the quantization parameters are computed the same way + qparams = qmodel.linear.qconfig.weight().calculate_qparams() + sqparams = sqmodel.linear.qconfig.weight().calculate_qparams() + test_class.assertEqual(qparams, sqparams) + + sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check) + sqmodule_start_class = sqmodule_to_check.__class__ + sqmodule_expected_converted_class = sparse_mapping[sqmodule_start_class] + + qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) + qmodule_start_class = qmodule_to_check.__class__ + qmodule_expected_converted_class = ref_mapping[qmodule_start_class] + + # need to determine whether dynamic quantization is being performed since + # input dtype will be different at the end + is_dynamic = isinstance( + qmodule_to_check.activation_post_process, tq.PlaceholderObserver + ) + + tq.convert(sqmodel, inplace=True, mapping=sparse_mapping) + tq.convert(qmodel, inplace=True, mapping=ref_mapping) + + # this code is a duplicate of above since the references do not + # update to the post-convert modules + sqmodule_to_check = fqn_to_module(sqmodel, fqn_to_check) + qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) + + # check that the modules were converted as expected + assert isinstance( + sqmodule_to_check, sqmodule_expected_converted_class + ), "Convert failed" + assert isinstance( + qmodule_to_check, qmodule_expected_converted_class + ), "Mapping failed" + + row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[ + 2: + ] + assert row_block_size == 1 and col_block_size == 4 + + # only run during serialization/deserialization tests + # makes sure script/save/load doesn't malform the sqmodel + if test_scripting: + scripted_sqmodel = torch.jit.script(sqmodel) + scripted_sqmodel.eval() + buffer = io.BytesIO() + torch.jit.save(scripted_sqmodel, buffer) + buffer.seek(0) + sqmodel = torch.jit.load(buffer) + + # use correct input dtype + if is_dynamic: + Y_ref = qmodel(X_fp32) + Y_hat = sqmodel(X_fp32) + test_class.assertEqual(Y_ref, Y_hat) + else: + Y_ref = qmodel(X_q) + Y_hat = sqmodel(X_q) + test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize()) + +class SparseQuantizedModel(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.linear = nn.Linear(in_channels, out_channels) + + def forward(self, x): + return self.linear(x) +class TestQuantizedSparseLayers(TestCase): @override_qengines def test_sparse_qlinear(self): - batch_size = 12 - input_channels = 4 - output_channels = 7 - model = self.SparseQuantizedModel(input_channels, output_channels) - - # For sparse kernels both the activation and weight ZP = 0 - X_scale = 0.2 - X_zp = 2 - W_scale = 1e-2 - W_zp = 0 - - X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32) - float_bias = torch.randn(output_channels, dtype=torch.float32) - - W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32) - mask = torch.randint(0, 2, W_fp32.shape) - W_fp32 *= mask + # Note: At the moment, for sparse kernels + # fbgemm supports only static quantized sparse linear + # qnnpack supports only dynamically quantized sparse linear + # Hence we have two different tests. + # fbgemm tests static flow, qnnpack tests dynamic. + # Should be unified later on and tests should be fixed + # appropriately. + model_class = SparseQuantizedModel + fqn_to_check = "linear" + if qengine_is_fbgemm(): + sparse_mapping = tq.get_default_static_sparse_quant_module_mappings() + ref_mapping = tq.get_default_static_quant_module_mappings() + qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")} + elif qengine_is_qnnpack(): + sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings() + ref_mapping = tq.get_default_dynamic_quant_module_mappings() + qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig} + else: + return - with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): - X_q = torch.quantize_per_tensor( - X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8 - ) - X_fp32 = X_q.dequantize() - - W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8) - - model.weight = nn.Parameter(W_q.dequantize()) - model.eval() - - # Add `sparse_params` to the model. The test for correct - # sparse_param addition is in the sparsifier tests - model.linear.sparse_params = {'sparse_block_shape': (1, 4)} - - # Note: At the moment, for sparse kernels - # fbgemm supports only static quantized sparse linear - # qnnpack supports only dynamically quantized sparse linear - # Hence we have two different tests. - # fbgemm tests static flow, qnnpack tests dynamic. - # Should be unified later on and tests should be fixed - # appropriately. - if qengine_is_fbgemm(): - model.qconfig = tq.get_default_qconfig('fbgemm') - qmodel = copy.deepcopy(model) - sqmodel = copy.deepcopy(model) - - tq.prepare(qmodel, inplace=True) - tq.prepare(sqmodel, inplace=True) - - with torch.no_grad(): - qmodel(X_fp32) - sqmodel(X_fp32) - - # Make sure the quantization parameters are computed the same way - qparams = qmodel.linear.qconfig.weight().calculate_qparams() - sqparams = sqmodel.linear.qconfig.weight().calculate_qparams() - self.assertEqual(qparams, sqparams) - - # Make sure mapping of sparse kernels does not affect the non-sparse - sparse_mapping = tq.get_default_static_quant_module_mappings() - sparse_mapping[nn.Linear] = ao_nn_sq.Linear - tq.convert(sqmodel, inplace=True, mapping=sparse_mapping) - tq.convert(qmodel, inplace=True) - - assert isinstance(sqmodel.linear, ao_nn_sq.Linear), "Convert failed" - assert isinstance(qmodel.linear, nn.quantized.Linear), "Mapping failed" - - # Make sure numerics are right - Y_ref = qmodel(X_q) - Y_hat = sqmodel(X_q) - self.assertEqual(Y_ref.dequantize(), Y_hat.dequantize()) - - elif qengine_is_qnnpack(): - qconfig = {nn.Linear : tq.qconfig.default_dynamic_qconfig} - qmodel = copy.deepcopy(model) - sqmodel = copy.deepcopy(model) - - tq.propagate_qconfig_(qmodel, qconfig) - tq.propagate_qconfig_(sqmodel, qconfig) - - # Make sure the quantization parameters are computed the same way - qparams = qmodel.linear.qconfig.weight().calculate_qparams() - sqparams = sqmodel.linear.qconfig.weight().calculate_qparams() - self.assertEqual(qparams, sqparams) - - # Make sure mapping of sparse kernels does not affect the non-sparse - sparse_mapping = copy.deepcopy(tq.get_default_dynamic_quant_module_mappings()) - sparse_mapping[nn.Linear] = ao_nn_sq.dynamic.Linear - tq.convert(sqmodel, inplace=True, mapping=sparse_mapping) - tq.convert(qmodel, mapping=tq.get_default_dynamic_quant_module_mappings(), inplace=True) - - assert isinstance(sqmodel.linear, ao_nn_sq.dynamic.Linear), "Convert failed" - assert isinstance(qmodel.linear, nn.quantized.dynamic.Linear), "Mapping failed" - - # Make sure numerics are right - Y_ref = qmodel(X_fp32) - Y_hat = sqmodel(X_fp32) - self.assertEqual(Y_ref, Y_hat) - - # ONEDNN does not support this yet - elif qengine_is_onednn(): - return - - row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[2:] - assert row_block_size == 1 and col_block_size == 4 + _sparse_layer_test_helper( + model_class=model_class, + sparse_mapping=sparse_mapping, + ref_mapping=ref_mapping, + qconfig_dict=qconfig_dict, + fqn_to_check=fqn_to_check, + test_class=self, + test_scripting=False, + ) @override_qengines def test_sparse_qlinear_serdes(self): - batch_size = 12 - input_channels = 4 - output_channels = 7 - model = self.SparseQuantizedModel(input_channels, output_channels) + # Note: At the moment, for sparse kernels + # fbgemm supports only static quantized sparse linear + # qnnpack supports only dynamically quantized sparse linear + # Hence we have two different tests. + # fbgemm tests static flow, qnnpack tests dynamic. + # Should be unified later on and tests should be fixed + # appropriately. + model_class = SparseQuantizedModel + fqn_to_check = "linear" + if qengine_is_fbgemm(): + sparse_mapping = tq.get_default_static_sparse_quant_module_mappings() + ref_mapping = tq.get_default_static_quant_module_mappings() + qconfig_dict = {nn.Linear: tq.get_default_qconfig("fbgemm")} + elif qengine_is_qnnpack(): + sparse_mapping = tq.get_default_dynamic_sparse_quant_module_mappings() + ref_mapping = tq.get_default_dynamic_quant_module_mappings() + qconfig_dict = {nn.Linear: tq.qconfig.default_dynamic_qconfig} + else: + return - # For sparse kernels both the activation and weight ZP = 0 - X_scale = 0.2 - X_zp = 0 - W_scale = 1e-2 - W_zp = 0 + _sparse_layer_test_helper( + model_class=model_class, + sparse_mapping=sparse_mapping, + ref_mapping=ref_mapping, + qconfig_dict=qconfig_dict, + fqn_to_check=fqn_to_check, + test_class=self, + test_scripting=True, + ) - with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): - X_fp32 = torch.randn(batch_size, input_channels, dtype=torch.float32) - float_bias = torch.randn(output_channels, dtype=torch.float32) - X_q = torch.quantize_per_tensor( - X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8 - ) - X_fp32 = X_q.dequantize() - - W_fp32 = torch.randn(output_channels, input_channels, dtype=torch.float32) - mask = torch.randint(0, 2, W_fp32.shape) - W_fp32 *= mask - W_q = torch.quantize_per_tensor(W_fp32, W_scale, W_zp, torch.qint8) - - model.linear.weight = nn.Parameter(W_q.dequantize()) - model.linear.sparse_params = {'sparse_block_shape': (1, 4)} - model.eval() - - # Note: At the moment, for sparse kernels - # fbgemm supports only static quantized sparse linear - # qnnpack supports only dynamically quantized sparse linear - # Hence we have two different tests. - # fbgemm tests static flow, qnnpack tests dynamic. - # Should be unified later on and tests should be fixed - # appropriately. - if qengine_is_fbgemm(): - model.qconfig = tq.get_default_qconfig('fbgemm') - qmodel = copy.deepcopy(model) - sqmodel = copy.deepcopy(model) - - tq.prepare(qmodel, inplace=True) - tq.prepare(sqmodel, inplace=True) - - with torch.no_grad(): - qmodel(X_fp32) - sqmodel(X_fp32) - - # Make sure the quantization parameters are computed the same way - qparams = qmodel.linear.qconfig.weight().calculate_qparams() - sqparams = sqmodel.linear.qconfig.weight().calculate_qparams() - self.assertEqual(qparams, sqparams) - - # Make sure mapping of sparse kernels does not affect the non-sparse - sparse_mapping = tq.get_default_static_quant_module_mappings() - sparse_mapping[nn.Linear] = ao_nn_sq.Linear - tq.convert(sqmodel, inplace=True, mapping=sparse_mapping) - tq.convert(qmodel, inplace=True) - - assert isinstance(sqmodel.linear, ao_nn_sq.Linear), "Convert failed" - assert isinstance(qmodel.linear, nn.quantized.Linear), "Mapping failed" - - scripted_sqmodel = torch.jit.script(sqmodel) - scripted_sqmodel.eval() - buffer = io.BytesIO() - torch.jit.save(scripted_sqmodel, buffer) - buffer.seek(0) - sqmodel = torch.jit.load(buffer) - - # Make sure numerics are right - Y_ref = qmodel(X_q) - Y_hat = sqmodel(X_q) - self.assertEqual(Y_ref.dequantize(), Y_hat.dequantize()) - - elif qengine_is_qnnpack(): - qconfig = {nn.Linear : tq.qconfig.default_dynamic_qconfig} - dqmodel = copy.deepcopy(model) - sdqmodel = copy.deepcopy(model) - - tq.propagate_qconfig_(dqmodel, qconfig) - tq.propagate_qconfig_(sdqmodel, qconfig) - - # Make sure the quantization parameters are computed the same way - qparams = dqmodel.linear.qconfig.weight().calculate_qparams() - sqparams = sdqmodel.linear.qconfig.weight().calculate_qparams() - self.assertEqual(qparams, sqparams) - - # Make sure mapping of sparse kernels does not affect the non-sparse - sparse_mapping = copy.deepcopy(tq.get_default_dynamic_quant_module_mappings()) - sparse_mapping[nn.Linear] = ao_nn_sq.dynamic.Linear - with LinearBlockSparsePattern(1, 4): - tq.convert(sdqmodel, inplace=True, mapping=sparse_mapping) - tq.convert(dqmodel, mapping=tq.get_default_dynamic_quant_module_mappings(), inplace=True) - - assert isinstance(sdqmodel.linear, ao_nn_sq.dynamic.Linear), "Convert failed" - assert isinstance(dqmodel.linear, nn.quantized.dynamic.Linear), "Mapping failed" - - scripted_sdqmodel = torch.jit.script(sdqmodel) - scripted_sdqmodel.eval() - buffer = io.BytesIO() - torch.jit.save(scripted_sdqmodel, buffer) - buffer.seek(0) - sdqmodel = torch.jit.load(buffer) - - # Make sure numerics are right - Y_ref = dqmodel(X_fp32) - Y_hat = sdqmodel(X_fp32) - self.assertEqual(Y_ref, Y_hat) - -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/ao/sparsity/test_pruner.py b/test/ao/sparsity/test_pruner.py index 85f70a017bf0c..6cf7175b9afee 100644 --- a/test/ao/sparsity/test_pruner.py +++ b/test/ao/sparsity/test_pruner.py @@ -7,10 +7,10 @@ import torch from torch import nn -from torch.ao.sparsity import BasePruner, PruningParametrization, ZeroesParametrization +from torch.ao.sparsity._experimental.pruner import BasePruner, PruningParametrization, ZeroesParametrization from torch.nn.utils import parametrize -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) @@ -381,6 +381,7 @@ def _test_step_conv2d_on_device(self, model, config, device): self._check_pruner_valid_after_step(model, pruner, {1}, device) assert model(x).shape == (1, 64, 24, 24) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_step_conv2d(self): bn_model = Conv2dBN() bn_config = [(bn_model.seq[0], bn_model.seq[1]), diff --git a/test/ao/sparsity/test_qlinear_packed_params.py b/test/ao/sparsity/test_qlinear_packed_params.py new file mode 100644 index 0000000000000..b2287207e2d6a --- /dev/null +++ b/test/ao/sparsity/test_qlinear_packed_params.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# Owner(s): ["oncall: mobile"] + +import tempfile +import torch +from torch.ao.nn.sparse.quantized.dynamic.linear import Linear +from torch.testing._internal.common_quantized import ( + qengine_is_qnnpack, + override_quantized_engine, + override_cpu_allocator_for_qnnpack +) +from torch.testing._internal.common_utils import TestCase + +class TestQlinearPackedParams(TestCase): + def test_qlinear_packed_params(self, allow_non_zero_zero_points=False): + # copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations, + # so row/col block indices match that example, but with blocks and + # scaled rows + weight_fp32 = torch.Tensor([ + [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0], + [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]) + + row_block_size = 1 + col_block_size = 4 + out_features = weight_fp32.shape[0] + in_features = weight_fp32.shape[1] + + scales = [2.0, 6.0, 12.0] + zero_points = [ + ((i + 1) if allow_non_zero_zero_points else 0) for i in range(out_features) + ] + dtype = torch.qint8 + + wide_weight_fp32 = torch.zeros((3, 4008)) # 4000 is tile width for Fbgemm + wide_weight_fp32[0][0] = 4 + wide_weight_fp32[0][4004] = 6 + wide_weight_fp32[1][0] = 8 + + per_tensor_small = ( + torch.quantize_per_tensor( + weight_fp32, + scales[0], + zero_points[0], + dtype + ), + True, + [0, 1, 3, 3], + [2, 0, 1], + [x + (1 if allow_non_zero_zero_points else 0) for x in [ + 1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6 + ]], + ) + + per_channel_small = ( + torch.quantize_per_channel( + weight_fp32, + torch.Tensor(scales), + torch.Tensor(zero_points).to(torch.int), + 0, # axis = 0 + dtype, + ), + False, + [0, 1, 3, 3], + [2, 0, 1], + [x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0) for (i, x) in enumerate([ + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2 + ])], + ) + + per_tensor_large = ( + torch.quantize_per_tensor( + wide_weight_fp32, + scales[0], + zero_points[0], + dtype, + ), + True, + [0, 2, 3, 3], + [0, 1001, 0], + [x + (1 if allow_non_zero_zero_points else 0) for x in [ + 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0 + ]], + ) + + for (weight, is_per_tensor_quantized, expected_row_block_indices, expected_col_block_indices, expected_weights) in [ + per_tensor_small, per_channel_small, per_tensor_large + ]: + lin = Linear( + out_features=weight.shape[0], + in_features=weight.shape[1], + row_block_size=row_block_size, + col_block_size=col_block_size, + bias=True, + dtype=dtype, + ) + + bias = torch.ones(size=(weight.shape[0],)) + + lin.set_weight_bias(weight, bias, row_block_size, col_block_size) + + serialized = lin._packed_params._packed_params.__getstate__() + + ( + _, # version + bias_, + out_features_block_size_, + in_features_block_size_, + weight_scales_, + weight_zero_points_, + quantization_scheme_, + row_block_indices_, + col_block_indices_, + weights_, + output_channels_, + input_channels_ + ) = serialized[0] + + # Test Serialization + self.assertEqual(bias_, bias) + self.assertEqual(out_features_block_size_, row_block_size) + self.assertEqual(in_features_block_size_, col_block_size) + self.assertEqual(weight_scales_, [scales[0]] if is_per_tensor_quantized else scales) + self.assertEqual(weight_zero_points_, [zero_points[0]] if is_per_tensor_quantized else zero_points) + self.assertEqual(quantization_scheme_, is_per_tensor_quantized) + self.assertEqual(row_block_indices_, expected_row_block_indices) + self.assertEqual(col_block_indices_, expected_col_block_indices) + self.assertEqual(weights_.tolist(), [v + 128 for v in expected_weights]) # weights are serialized as +128 + self.assertEqual(output_channels_, weight.shape[0]) + self.assertEqual(input_channels_, weight.shape[1]) + + # Test Unpacking + (weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias() + self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight)) + self.assertEqual(bias_, bias) + self.assertEqual(out_features_block_size_, row_block_size) + self.assertEqual(in_features_block_size_, col_block_size) + + # Test Deserialization + with tempfile.TemporaryFile() as file_buff: + torch.save(lin, file_buff) + file_buff.seek(0) + lin2 = torch.load(file_buff) + self.assertEqual(lin._weight_bias(), lin2._weight_bias()) + # Serialize -> Deserialize -> Serialize should match Serialize + self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__()) + + # Test that op output is preserved by serialize -> deserialize + if qengine_is_qnnpack(): + x = torch.rand(size=(1, weight.shape[1])) + y1 = lin(x) + y2 = lin2(x) + self.assertEqual(y1, y2) + + + def test_qlinear_packed_params_qnnpack(self): + torch.manual_seed(0) + with override_quantized_engine('qnnpack'): + with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()): + self.test_qlinear_packed_params(allow_non_zero_zero_points=True) diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index 59bbe64840361..4dd2d49296ecc 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -258,7 +258,8 @@ def test_step_2_of_4(self): sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}]) sparsifier.step() # make sure the sparsity level is approximately 50% - self.assertAlmostEqual(model.linear.parametrizations['weight'][0].mask.mean().item(), 0.5, places=2) + mask = model.linear.parametrizations['weight'][0].mask.to(torch.float) # mean works on float only + self.assertAlmostEqual(mask.mean().item(), 0.5, places=2) # Make sure each block has exactly 50% zeros module = sparsifier.groups[0]['module'] mask = module.parametrizations['weight'][0].mask @@ -359,7 +360,7 @@ def test_constructor(self): def test_step(self): model = Model() sparsifier = NearlyDiagonalSparsifier(nearliness=1) - sparsifier.prepare(model, config=[model.linear]) + sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}]) for g in sparsifier.groups: # Before step @@ -428,7 +429,7 @@ def test_sparsity_levels(self): width, height = layer.weight.shape model.add_module(layer_name, layer) config = { - 'module_fqn': layer_name, + 'tensor_fqn': layer_name + ".weight", 'nearliness': nearliness } diff --git a/test/ao/sparsity/test_sparsity_utils.py b/test/ao/sparsity/test_sparsity_utils.py new file mode 100644 index 0000000000000..add621ebc4bab --- /dev/null +++ b/test/ao/sparsity/test_sparsity_utils.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Owner(s): ["module: unknown"] + + +import logging + +import torch +from torch.ao.sparsity.sparsifier.utils import ( + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) + +from torch.testing._internal.common_quantization import ( + ConvBnReLUModel, + ConvModel, + FunctionalLinear, + LinearAddModel, + ManualEmbeddingBagLinear, + SingleLayerLinearModel, + TwoLayerLinearModel, +) +from torch.testing._internal.common_utils import TestCase + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +model_list = [ + ConvModel, + SingleLayerLinearModel, + TwoLayerLinearModel, + LinearAddModel, + ConvBnReLUModel, + ManualEmbeddingBagLinear, + FunctionalLinear, +] + + +class TestSparsityUtilFunctions(TestCase): + def test_module_to_fqn(self): + """ + Tests that module_to_fqn works as expected when compared to known good + module.get_submodule(fqn) function + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + fqn = module_to_fqn(model, module) + check_module = model.get_submodule(fqn) + self.assertEqual(module, check_module) + + def test_module_to_fqn_fail(self): + """ + Tests that module_to_fqn returns None when an fqn that doesn't + correspond to a path to a node/tensor is given + """ + for model_class in model_list: + model = model_class() + fqn = module_to_fqn(model, torch.nn.Linear(3, 3)) + self.assertEqual(fqn, None) + + def test_module_to_fqn_root(self): + """ + Tests that module_to_fqn returns '' when model and target module are the same + """ + for model_class in model_list: + model = model_class() + fqn = module_to_fqn(model, model) + self.assertEqual(fqn, "") + + def test_fqn_to_module(self): + """ + Tests that fqn_to_module operates as inverse + of module_to_fqn + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + fqn = module_to_fqn(model, module) + check_module = fqn_to_module(model, fqn) + self.assertEqual(module, check_module) + + def test_fqn_to_module_fail(self): + """ + Tests that fqn_to_module returns None when it tries to + find an fqn of a module outside the model + """ + for model_class in model_list: + model = model_class() + fqn = "foo.bar.baz" + check_module = fqn_to_module(model, fqn) + self.assertEqual(check_module, None) + + def test_fqn_to_module_for_tensors(self): + """ + Tests that fqn_to_module works for tensors, actually all parameters + of the model. This is tested by identifying a module with a tensor, + and generating the tensor_fqn using module_to_fqn on the module + + the name of the tensor. + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + module_fqn = module_to_fqn(model, module) + for tensor_name, tensor in module.named_parameters(recurse=False): + tensor_fqn = ( # string manip to handle tensors on root + module_fqn + ("." if module_fqn != "" else "") + tensor_name + ) + check_tensor = fqn_to_module(model, tensor_fqn) + self.assertEqual(tensor, check_tensor) + + def test_get_arg_info_from_tensor_fqn(self): + """ + Tests that get_arg_info_from_tensor_fqn works for all parameters of the model. + Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and + then compares with known (parent) module and tensor_name as well as module_fqn + from module_to_fqn. + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + module_fqn = module_to_fqn(model, module) + for tensor_name, tensor in module.named_parameters(recurse=False): + tensor_fqn = ( + module_fqn + ("." if module_fqn != "" else "") + tensor_name + ) + arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) + self.assertEqual(arg_info["module"], module) + self.assertEqual(arg_info["module_fqn"], module_fqn) + self.assertEqual(arg_info["tensor_name"], tensor_name) + self.assertEqual(arg_info["tensor_fqn"], tensor_fqn) + + def test_get_arg_info_from_tensor_fqn_fail(self): + """ + Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn + inputs. The string outputs still work but the output module is expected to be None. + """ + for model_class in model_list: + model = model_class() + tensor_fqn = "foo.bar.baz" + arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) + self.assertEqual(arg_info["module"], None) + self.assertEqual(arg_info["module_fqn"], "foo.bar") + self.assertEqual(arg_info["tensor_name"], "baz") + self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") diff --git a/test/backends/xeon/test_launch.py b/test/backends/xeon/test_launch.py new file mode 100644 index 0000000000000..056a53ee110df --- /dev/null +++ b/test/backends/xeon/test_launch.py @@ -0,0 +1,65 @@ +# Owner(s): ["module: intel"] + +from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX +import shutil +import subprocess +import tempfile +import unittest + +@unittest.skipIf(not IS_LINUX, "Only works on linux") +class TestTorchrun(TestCase): + def setUp(self): + self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__) + + def tearDown(self): + shutil.rmtree(self._test_dir) + + def test_cpu_info(self): + lscpu_info = """# The following is the parsable format, which can be fed to other +# programs. Each different item in every column has an unique ID +# starting from zero. +# CPU,Core,Socket,Node +0,0,0,0 +1,1,0,0 +2,2,0,0 +3,3,0,0 +4,4,1,1 +5,5,1,1 +6,6,1,1 +7,7,1,1 +8,0,0,0 +9,1,0,0 +10,2,0,0 +11,3,0,0 +12,4,1,1 +13,5,1,1 +14,6,1,1 +15,7,1,1 +""" + from torch.backends.xeon.run_cpu import _CPUinfo + cpuinfo = _CPUinfo(lscpu_info) + assert cpuinfo._physical_core_nums() == 8 + assert cpuinfo._logical_core_nums() == 16 + assert cpuinfo.get_node_physical_cores(0) == [0, 1, 2, 3] + assert cpuinfo.get_node_physical_cores(1) == [4, 5, 6, 7] + assert cpuinfo.get_node_logical_cores(0) == [0, 1, 2, 3, 8, 9, 10, 11] + assert cpuinfo.get_node_logical_cores(1) == [4, 5, 6, 7, 12, 13, 14, 15] + assert cpuinfo.get_all_physical_cores() == [0, 1, 2, 3, 4, 5, 6, 7] + assert cpuinfo.get_all_logical_cores() == [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15] + assert cpuinfo.numa_aware_check([0, 1, 2, 3]) == [0] + assert cpuinfo.numa_aware_check([4, 5, 6, 7]) == [1] + assert cpuinfo.numa_aware_check([2, 3, 4, 5]) == [0, 1] + + def test_multi_threads(self): + num = 0 + with subprocess.Popen(f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use_default_allocator \ + --disable_iomp --disable_numactl --log_path {self._test_dir} --no_python pwd", + shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: + for line in p.stdout.readlines(): + segs = str(line, "utf-8").strip().split("-") + if segs[-1].strip() == "pwd": + num += 1 + assert num == 4, "Failed to launch multiple instances for inference" + +if __name__ == "__main__": + run_tests() diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000000000..d2e929a9a58db --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,146 @@ +from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape +from _pytest.terminal import _get_raw_skip_reason +from _pytest.stash import StashKey +from _pytest.reports import TestReport +from _pytest.config.argparsing import Parser +from _pytest.config import filename_arg +from _pytest.config import Config +from _pytest._code.code import ReprFileLocation +from typing import Union +from typing import Optional +import xml.etree.ElementTree as ET +import functools + +# a lot of this file is copied from _pytest.junitxml and modified to get rerun info + +xml_key = StashKey["LogXMLReruns"]() + + +def pytest_addoption(parser: Parser) -> None: + group = parser.getgroup("terminal reporting") + group.addoption( + "--junit-xml-reruns", + action="store", + dest="xmlpath_reruns", + metavar="path", + type=functools.partial(filename_arg, optname="--junit-xml-reruns"), + default=None, + help="create junit-xml style report file at given path.", + ) + group.addoption( + "--junit-prefix-reruns", + action="store", + metavar="str", + default=None, + help="prepend prefix to classnames in junit-xml output", + ) + parser.addini( + "junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest" + ) + parser.addini( + "junit_logging_reruns", + "Write captured log messages to JUnit report: " + "one of no|log|system-out|system-err|out-err|all", + default="no", + ) + parser.addini( + "junit_log_passing_tests_reruns", + "Capture log information for passing tests to JUnit report: ", + type="bool", + default=True, + ) + parser.addini( + "junit_duration_report_reruns", + "Duration time to report: one of total|call", + default="total", + ) + parser.addini( + "junit_family_reruns", + "Emit XML for schema: one of legacy|xunit1|xunit2", + default="xunit2", + ) + + +def pytest_configure(config: Config) -> None: + xmlpath = config.option.xmlpath_reruns + # Prevent opening xmllog on worker nodes (xdist). + if xmlpath and not hasattr(config, "workerinput"): + junit_family = config.getini("junit_family_reruns") + config.stash[xml_key] = LogXMLReruns( + xmlpath, + config.option.junitprefix, + config.getini("junit_suite_name_reruns"), + config.getini("junit_logging_reruns"), + config.getini("junit_duration_report_reruns"), + junit_family, + config.getini("junit_log_passing_tests_reruns"), + ) + config.pluginmanager.register(config.stash[xml_key]) + + +def pytest_unconfigure(config: Config) -> None: + xml = config.stash.get(xml_key, None) + if xml: + del config.stash[xml_key] + config.pluginmanager.unregister(xml) + + +class _NodeReporterReruns(_NodeReporter): + def _prepare_content(self, content: str, header: str) -> str: + return content + + def _write_content(self, report: TestReport, content: str, jheader: str) -> None: + if content == "": + return + tag = ET.Element(jheader) + tag.text = bin_xml_escape(content) + self.append(tag) + + +class LogXMLReruns(LogXML): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None: + if hasattr(report, "wasxfail"): + reporter._add_simple("skipped", "xfail-marked test passes unexpectedly") + else: + assert report.longrepr is not None + reprcrash: Optional[ReprFileLocation] = getattr( + report.longrepr, "reprcrash", None + ) + if reprcrash is not None: + message = reprcrash.message + else: + message = str(report.longrepr) + message = bin_xml_escape(message) + reporter._add_simple("rerun", message, str(report.longrepr)) + + def pytest_runtest_logreport(self, report: TestReport) -> None: + super().pytest_runtest_logreport(report) + if report.outcome == "rerun": + reporter = self._opentestcase(report) + self.append_rerun(reporter, report) + if report.outcome == "skipped": + if isinstance(report.longrepr, tuple): + fspath, lineno, reason = report.longrepr + reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}" + report.longrepr = (fspath, lineno, reason) + + def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns: + nodeid: Union[str, TestReport] = getattr(report, "nodeid", report) + # Local hack to handle xdist report order. + workernode = getattr(report, "node", None) + + key = nodeid, workernode + + if key in self.node_reporters: + # TODO: breaks for --dist=each + return self.node_reporters[key] + + reporter = _NodeReporterReruns(nodeid, self) + + self.node_reporters[key] = reporter + self.node_reporters_ordered.append(reporter) + + return reporter diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index f3a72946e6885..d6cc58b7c3bc6 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -831,6 +831,111 @@ TEST(CustomAutogradTest, Hooks) { ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index"); } +TEST(CustomAutogradTest, HooksInplace) { + auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); + + int hook1_count = 0; + auto hook1 = ([&hook1_count](Variable grad) { + hook1_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); + }); + + int hook2_count = 0; + auto hook2 = ([&hook2_count](Variable grad) { + hook2_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); + }); + + a.register_hook(hook1); + a.mul_(2); + a.register_hook(hook2); + + auto out = (a + 1).sum(); + out.backward(); + + ASSERT_EQ(hook1_count, 1); + ASSERT_EQ(hook2_count, 1); +} + +TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) { + auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); + + int hook1_count = 0; + auto hook1 = ([&hook1_count](Variable grad) { + hook1_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); + }); + + int hook2_count = 0; + auto hook2 = ([&hook2_count](Variable grad) { + hook2_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); + }); + + int hook3_count = 0; + auto hook3 = ([&hook3_count](Variable grad) { + hook3_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); + }); + + a.register_hook(hook1); + a.retain_grad(); + a.register_hook(hook2); + + a.mul_(2); + a.register_hook(hook3); + + auto out = (a + 1).sum(); + out.backward(); + + ASSERT_EQ(hook1_count, 1); + ASSERT_EQ(hook2_count, 1); + ASSERT_EQ(hook3_count, 1); + + ASSERT_TRUE(a.retains_grad()); + ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5})); +} + +TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) { + auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); + + int hook1_count = 0; + auto hook1 = ([&hook1_count](Variable grad) { + hook1_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4); + }); + + int hook2_count = 0; + auto hook2 = ([&hook2_count](Variable grad) { + hook2_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4); + }); + + int hook3_count = 0; + auto hook3 = ([&hook3_count](Variable grad) { + hook3_count++; + ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); + }); + + a.register_hook(hook1); + a.retain_grad(); + a.register_hook(hook2); + + a.mul_(2); + a.mul_(2); + a.register_hook(hook3); + + auto out = (a + 1).sum(); + out.backward(); + + ASSERT_EQ(hook1_count, 1); + ASSERT_EQ(hook2_count, 1); + ASSERT_EQ(hook3_count, 1); + + ASSERT_TRUE(a.retains_grad()); + ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5})); +} + TEST(CustomAutogradTest, HookNone) { struct NoneGradientFunction : public Function { static variable_list forward(AutogradContext* ctx, Variable x, Variable y) { diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 26d999c5c1efb..f5f52390d7e62 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -3526,7 +3526,7 @@ void _multihead_attn_test_helper( const torch::Tensor V = K; const torch::Tensor Q = decoder_state.clone().resize_({batch_sz, 1, d_model}); - auto attn_mask = torch::randint(0, 2, {1, seq_len}); + auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat); const torch::Tensor attn_mask_tensor = attn_mask.clone(); attn_mask_tensor.masked_fill_( attn_mask_tensor == 0, -std::numeric_limits::infinity()); diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index bf91460c4ba24..d50c6c4f8ef41 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -72,7 +72,7 @@ if(USE_MPI AND USE_C10D_MPI) endif() endif() -if(LINUX) +if(LINUX AND USE_GLOO AND USE_C10D_GLOO) add_executable(example_allreduce example/allreduce.cpp) target_include_directories(example_allreduce PRIVATE $) target_link_libraries(example_allreduce pthread torch_cpu) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 9bd349b619521..dadfd4d74b3f5 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -80,6 +80,7 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_peephole_optimize.cpp ${JIT_TEST_ROOT}/test_qualified_name.cpp ${JIT_TEST_ROOT}/test_save_load.cpp + ${JIT_TEST_ROOT}/test_schema_info.cpp ${JIT_TEST_ROOT}/test_schema_matching.cpp ${JIT_TEST_ROOT}/test_stack_opt.cpp ${JIT_TEST_ROOT}/test_subgraph_matcher.cpp @@ -118,7 +119,7 @@ if(USE_SYSTEM_ONNX) target_link_libraries(test_jit PRIVATE onnx_proto onnx) endif() -if(CAFFE2_USE_MKLDNN) +if(USE_MKLDNN) target_link_libraries(test_jit PRIVATE caffe2::mkldnn) endif() diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index f3a4c0a55a866..13955108bfbd7 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1607,5 +1607,90 @@ TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) { [&graph] { AliasDb aliasDb(graph); }, "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } + +TEST(IRNonDeterminismTest, Basic) { + auto graph = std::make_shared(); + auto graph_string = R"IR( + graph(): + %x : Tensor = prim::MakeTestTensor() + %0 : int = prim::Constant[value=0]() + %1 : NoneType = prim::Constant() + %2 : Tensor = aten::bernoulli(%x, %1) + %3 : Tensor = aten::add(%x, %2, %0) + return (%3))IR"; + parseIR(graph_string, graph.get()); + + for (Node* n : graph->nodes()) { + if (n->kind() == aten::bernoulli) { + ASSERT_TRUE(n->isNondeterministic()); + } else { + ASSERT_FALSE(n->isNondeterministic()); + } + } +} + +TEST(IRNonDeterminismTest, DropoutSpecialCase) { + auto graph = std::make_shared(); + auto graph_string = R"IR( + graph(): + %x : Tensor = prim::MakeTestTensor() + %0 : bool = prim::Constant[value=0]() + %1 : bool = prim::Constant[value=1]() + %3 : int = prim::Constant[value=1]() + %3 : float = prim::Constant[value=1.0]() + %4 : Tensor = aten::dropout(%x, %3, %0) + %5 : Tensor = aten::dropout(%x, %3, %1) + %6 : Tensor = aten::add(%4, %5, %3) + return (%6))IR"; + parseIR(graph_string, graph.get()); + + bool train = false; + for (Node* n : graph->nodes()) { + if (n->kind() == aten::dropout) { + if (!train) { + ASSERT_FALSE(n->isNondeterministic()); + train = true; + } else { + ASSERT_TRUE(n->isNondeterministic()); + } + } else { + ASSERT_FALSE(n->isNondeterministic()); + } + } +} + +TEST(NonDeterminismBackwardsCompatibility, BackwardsCompatibility) { + static const std::vector nondeterministic_ops = { + "aten::dropout(Tensor input, float p, bool train) -> Tensor", + "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)", + "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor", + "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor", + "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", + "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal.Tensor_float(Tensor mean, float std, *, Generator? generator) -> Tensor", + "aten::poisson(Tensor self, Generator? generator) -> Tensor", + "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", + "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", + "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; + for (const std::string& op : nondeterministic_ops) { + const c10::FunctionSchema& schema = torch::jit::parseSchema(op); + const auto& op_handle = c10::Dispatcher::singleton().findOp( + c10::OperatorName(schema.name(), schema.overload_name())); + ASSERT_TRUE(op_handle->hasTag(at::Tag::nondeterministic_seeded)); + } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 078c405195f28..baa54b0024e45 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -145,6 +145,15 @@ class BackendWithCompiler : public PyTorchBackendInterface { auto x_ptr = float_data_ptr(x); auto h_ptr = float_data_ptr(h); auto y_ptr = float_data_ptr(y); +#ifndef NO_PROFILING + RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( + x_ptr, + x.numel() * sizeof(float), + x.numel() * sizeof(float), + x.numel() * sizeof(float) + y.numel() * sizeof(float) + + h.numel() * sizeof(float), + c10::Device(c10::kCPU)); +#endif if (instruction == "aten::add") { y_ptr[0] = x_ptr[0] + h_ptr[0]; } else { diff --git a/test/cpp/jit/test_exception.cpp b/test/cpp/jit/test_exception.cpp index b6b3cbcd67930..7f57bc5ca75a3 100644 --- a/test/cpp/jit/test_exception.cpp +++ b/test/cpp/jit/test_exception.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index c45ca96383e9f..930b26076bbb1 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -679,6 +680,7 @@ void backportAllVersionCheck( #if !defined FB_XPLAT_BUILD TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { + torch::jit::register_flatbuffer_all(); torch::jit::Module module("m"); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); diff --git a/test/cpp/jit/test_schema_info.cpp b/test/cpp/jit/test_schema_info.cpp new file mode 100644 index 0000000000000..939f9fc4480d7 --- /dev/null +++ b/test/cpp/jit/test_schema_info.cpp @@ -0,0 +1,394 @@ +#include +#include +#include + +namespace torch { +namespace utils { +using c10::SchemaArgType; + +TEST(FunctionSchemaIsAliasingTest, Basic) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::test.Tensor(Tensor(a) self, Tensor(b!) other, Tensor more_other) -> (Tensor(a), Tensor(b!))"); + ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 0})); + ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 1})); + ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 0})); + ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 1})); + ASSERT_FALSE(schema.is_aliasing({SchemaArgType::input, 2})); +} + +TEST(FunctionSchemaIsAliasingTest, InvalidArgument) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_THROW(schema.is_aliasing({SchemaArgType::input, 4}), c10::Error); + ASSERT_THROW(schema.is_aliasing({SchemaArgType::output, 4}), c10::Error); +} + +TEST(FunctionSchemaIsMutableTest, Basic) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0})); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); + ASSERT_TRUE(schema.is_mutable("self")); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); + ASSERT_FALSE(schema.is_mutable("other")); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2})); + ASSERT_FALSE(schema.is_mutable("alpha")); +} + +TEST(FunctionSchemaIsMutableTest, InvalidArgument) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error); + ASSERT_THROW(schema.is_mutable({SchemaArgType::output, 4}), c10::Error); + ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error); +} + +TEST(SchemaInfoIsMutableTest, Basic) { + SchemaInfo schema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); + ASSERT_TRUE(schema.is_mutable("self")); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); + ASSERT_FALSE(schema.is_mutable("other")); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2})); + ASSERT_FALSE(schema.is_mutable("alpha")); +} + +TEST(SchemaInfoIsMutableTest, InvalidArgument) { + SchemaInfo schema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error); + ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error); +} + +TEST(SchemaInfoIsMutableTest, AliasingInputs) { + SchemaInfo schema( + "aten::test.Tensor(Tensor(a!) self, Tensor(b) other, *, Scalar alpha=1) -> (Tensor(a!), Tensor(b))"); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0})); + ASSERT_TRUE(schema.is_mutable("self")); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); + ASSERT_FALSE(schema.is_mutable({SchemaArgType::output, 1})); + ASSERT_FALSE(schema.is_mutable("other")); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("self", input); + schema.addArgumentValue("other", input); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 1})); + ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 1})); + ASSERT_TRUE(schema.is_mutable("other")); +} + +TEST(SchemaInfoIsMutableTest, InstanceNorm) { + SchemaInfo schema_info( + "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"); + ASSERT_TRUE(schema_info.is_mutable("running_mean")); + ASSERT_TRUE(schema_info.is_mutable("running_var")); + schema_info.addArgumentValue("use_input_stats", false); + ASSERT_FALSE(schema_info.is_mutable("running_mean")); + ASSERT_FALSE(schema_info.is_mutable("running_var")); +} + +TEST(SchemaInfoIsMutableTest, BatchNorm) { + SchemaInfo schema_info( + "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"); + ASSERT_TRUE(schema_info.is_mutable("running_mean")); + ASSERT_TRUE(schema_info.is_mutable("running_var")); + schema_info.addArgumentValue("training", false); + ASSERT_FALSE(schema_info.is_mutable("running_mean")); + ASSERT_FALSE(schema_info.is_mutable("running_var")); +} + +TEST(SchemaInfoIsNonDeterministicTest, Basic) { + SchemaInfo deterministic_schema_info( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + SchemaInfo nondeterministic_schema_info( + "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor"); + ASSERT_FALSE(deterministic_schema_info.is_nondeterministic()); + ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic()); +} + +TEST(SchemaInfoIsNonDeterministicTest, Dropout) { + SchemaInfo droupout_schema_info( + "aten::dropout(Tensor input, float p, bool train) -> Tensor"); + ASSERT_TRUE(droupout_schema_info.is_nondeterministic()); + droupout_schema_info.addArgumentValue("train", false); + ASSERT_FALSE(droupout_schema_info.is_nondeterministic()); +} + +TEST(FunctionSchemaMayAliasTest, Basic) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0})); +} + +TEST(FunctionSchemaMayAliasTest, InvalidArgument) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_THROW( + schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}), + c10::Error); + ASSERT_THROW( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}), + c10::Error); +} + +TEST(FunctionSchemaMayAliasTest, Wildcard) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0})); +} + +TEST(SchemaInfoMayAliasTest, AliasingInputs) { + SchemaInfo schema( + "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("self", input); + schema.addArgumentValue("other", input); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); +} + +TEST(SchemaInfoMayAliasTest, AliasingOutputs) { + SchemaInfo schema( + "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)"); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("min", input); + schema.addArgumentValue("max", input); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); +} + +TEST(SchemaInfoMayAliasTest, AliasingInputOutput) { + SchemaInfo schema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("self", input); + schema.addArgumentValue("other", input); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); +} + +TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) { + SchemaInfo schema( + "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1})); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("a", input); + schema.addArgumentValue("b", input); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); +} + +TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) { + SchemaInfo schema( + "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0})); +} + +TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) { + SchemaInfo schema( + "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))"); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); + ASSERT_TRUE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1})); +} + +TEST(SchemaInfoMayAliasTest, MismatchingTypes) { + SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)"); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); +} + +TEST(FunctionSchemaMayContainAliasTest, Basic) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))"); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::output, 0})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 1}, {SchemaArgType::output, 0})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 1}, {SchemaArgType::input, 0})); +} + +TEST(FunctionSchemaMayContainAliasTest, Wildcard) { + c10::FunctionSchema schema = torch::jit::parseSchema( + "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)"); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false)); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false)); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0})); +} + +TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) { + c10::FunctionSchema schema = + torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]"); + ASSERT_FALSE( + schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false)); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false)); +} + +TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsEqual) { + SchemaInfo schema( + "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("self", input); + schema.addArgumentValue("other", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false)); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false)); +} + +TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsContained) { + SchemaInfo schema( + "aten::test.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor"); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("self", c10::List({input})); + schema.addArgumentValue("other", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false)); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false)); +} + +TEST(SchemaInfoMayContainAliasTest, ContainAliasOutputs) { + SchemaInfo schema( + "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)"); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::output, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("min", input); + schema.addArgumentValue("max", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::output, 1})); +} + +TEST(SchemaInfoMayContainAliasTest, ContainAliasInputOutput) { + SchemaInfo schema( + "aten::test.tensor(Tensor(a) self, Tensor[] other) -> Tensor(a)"); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("other", c10::List({input})); + schema.addArgumentValue("self", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 1}, false)); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 1}, {SchemaArgType::output, 0}, false)); +} + +TEST(SchemaInfoMayContainAliasTest, InputAndOutputContainers) { + SchemaInfo schema( + "aten::test.tensor(Tensor self, Tensor[] other) -> Tensor[]"); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("other", c10::List({input})); + schema.addArgumentValue("self", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); +} + +TEST(SchemaInfoMayContainAliasTest, Wildcard) { + SchemaInfo schema( + "aten::test.tensor(Tensor a, Tensor[] b, Tensor(*) c) -> Tensor[]"); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 2})); + ASSERT_FALSE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 2}, {SchemaArgType::input, 1})); + at::Tensor input = at::randn({3, 3}); + schema.addArgumentValue("b", c10::List({input})); + schema.addArgumentValue("a", input); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 2})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); + ASSERT_TRUE(schema.may_contain_alias( + {SchemaArgType::input, 2}, {SchemaArgType::input, 1})); +} +} // namespace utils +} // namespace torch diff --git a/test/cpp/lazy/test_backend_device.cpp b/test/cpp/lazy/test_backend_device.cpp index b9b337fb702cc..9c4d1c1467fbe 100644 --- a/test/cpp/lazy/test_backend_device.cpp +++ b/test/cpp/lazy/test_backend_device.cpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace torch { @@ -19,9 +20,14 @@ TEST(BackendDeviceTest, BackendDeviceType) { TEST(BackendDeviceTest, Basic1) { auto device = BackendDevice(); - EXPECT_EQ(device.type(), 0); EXPECT_EQ(device.ordinal(), 0); - EXPECT_STREQ(device.toString().c_str(), "Unknown0"); + if (std::getenv("LTC_TS_CUDA") != nullptr) { + EXPECT_EQ(device.type(), 1); + EXPECT_STREQ(device.toString().c_str(), "CUDA0"); + } else { + EXPECT_EQ(device.type(), 0); + EXPECT_STREQ(device.toString().c_str(), "CPU0"); + } } TEST(BackendDeviceTest, Basic2) { @@ -48,6 +54,27 @@ TEST(BackendDeviceTest, Basic3) { EXPECT_STREQ(device.toString().c_str(), "Test1"); } +TEST(BackendDeviceTest, Basic4) { + // Seems weird to have setters in BackendImplInterface given getBackend() + // returns a const pointer. + auto default_type = getBackend()->GetDefaultDeviceType(); + auto default_ordinal = getBackend()->GetDefaultDeviceOrdinal(); + const_cast(getBackend()) + ->SetDefaultDeviceType(static_cast(c10::kCUDA)); + const_cast(getBackend())->SetDefaultDeviceOrdinal(1); + + auto device = BackendDevice(); + + EXPECT_EQ(device.type(), 1); + EXPECT_EQ(device.ordinal(), 1); + EXPECT_STREQ(device.toString().c_str(), "CUDA1"); + + const_cast(getBackend()) + ->SetDefaultDeviceType(default_type->type); + const_cast(getBackend()) + ->SetDefaultDeviceOrdinal(default_ordinal); +} + TEST(BackendDeviceTest, Compare) { auto type = std::make_shared(); type->type = 1; diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp index e0ad18068bd1e..de80ff9ea5763 100644 --- a/test/cpp/lazy/test_lazy_ops.cpp +++ b/test/cpp/lazy/test_lazy_ops.cpp @@ -22,7 +22,9 @@ namespace lazy { #ifndef FBCODE_CAFFE2 namespace { -// This registers the torchscript backend, without which lazy device won't work +// This registers the torchscript backend, without which lazy device won't work. +// FIXME: This registers the backend for the whole test binary. We should +// probably do it and undo it in the test fixture below. static bool inline init_backend() { torch::lazy::InitTorchScriptBackend(); return true; @@ -89,12 +91,13 @@ TEST(LazyDynamicOpsTest, NarrowCopy) { auto y = torch::rand({Y_DIM}).to(kLazy); auto ly = torch::lazy::TryGetLtcTensor(y); auto dim_node = MakeNode(ly->GetIrValue(), 0); - auto lmn = std::make_shared(dim_node); + auto lmn = c10::make_intrusive(dim_node); auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt()); AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM)); } TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) { + FLAGS_ltc_enable_symbolic_shapes = true; auto xc = torch::rand({10}); auto x = xc.to(kLazy); const size_t Y_DIM = 3; @@ -105,6 +108,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) { ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc // shape inference assumes narrow_copy can copy the whole tensor AllClose(z.cpu(), zc); + FLAGS_ltc_enable_symbolic_shapes = false; } TEST_F(LazyOpsTest, TestScalarTensor) { diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index 95ba2b7b853ed..867b775c1adb4 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -25,7 +25,9 @@ bool checkMetaData( if (line.find(op_name) != std::string::npos) { while (std::getline(trace_file, line)) { if (line.find(metadata_name) != std::string::npos) { - return (line.find(metadata_val) != std::string::npos); + if (line.find(metadata_val) != std::string::npos) { + return true; + } } } } @@ -122,6 +124,39 @@ TEST(MobileProfiler, Backend) { checkMetaData("aten::add", metadata_name, "test_backend", trace_file)); } +TEST(MobileProfiler, BackendMemoryEvents) { + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_backend_for_profiling.ptl"); + + std::vector inputs; + inputs.emplace_back(at::rand({64, 64})); + inputs.emplace_back(at::rand({64, 64})); + std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); + + mobile::Module bc = _load_for_mobile(testModelFile); + { + mobile::KinetoEdgeCPUProfiler profiler( + bc, + trace_file_name, + false, // record input_shapes + true, // profile memory + true, // record callstack + false, // record flops + true); // record module hierarchy + bc.forward(inputs); + } + std::ifstream trace_file(trace_file_name); + std::string line; + ASSERT_TRUE(trace_file.is_open()); + trace_file.seekg(0, std::ios_base::beg); + std::string metadata_name("Bytes"); + ASSERT_TRUE(checkMetaData("[memory]", metadata_name, "16384", trace_file)); + trace_file.seekg(0, std::ios_base::beg); + metadata_name = "Total Reserved"; + ASSERT_TRUE(checkMetaData("[memory]", metadata_name, "49152", trace_file)); +} + } // namespace mobile } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp index 2b6202557fdb8..424d82c77453c 100644 --- a/test/cpp/tensorexpr/padded_buffer.cpp +++ b/test/cpp/tensorexpr/padded_buffer.cpp @@ -9,7 +9,7 @@ namespace jit { namespace tensorexpr { int PaddedBufferBase::Index(const std::vector& indices) const { - DCHECK_EQ(dims_.size(), indices.size()); + TORCH_DCHECK_EQ(dims_.size(), indices.size()); int total_index = 0; for (const auto i : c10::irange(dims_.size())) { total_index += indices[i] * strides_[i]; diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 98e7fca452b0f..5763be459dbc7 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -154,7 +154,7 @@ TEST_F(Kernel, _1) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } @@ -192,7 +192,7 @@ TEST_F(Kernel, _2) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } @@ -230,7 +230,7 @@ TEST_F(Kernel, _3) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } @@ -278,7 +278,7 @@ TEST_F(Kernel, ParallelStrided) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } @@ -321,7 +321,7 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } { @@ -356,10 +356,10 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { std::vector stack = fmap(inputs); k.run(stack); o = stack[0].toTensor(); - CHECK_EQ(o.sizes()[0], 8); - CHECK_EQ(o.sizes()[1], 4); + TORCH_CHECK_EQ(o.sizes()[0], 8); + TORCH_CHECK_EQ(o.sizes()[1], 4); for (size_t i = 0; i < 8 * 4; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } { @@ -412,16 +412,16 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { o = stack[0].toTensor(); // Check sizes - CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); size_t num_el = 1; for (const auto idx : c10::irange(ref.sizes().size())) { - CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); num_el *= ref.sizes()[idx]; } // Check the contents for (const auto i : c10::irange(num_el)) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } { @@ -465,16 +465,16 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { o = stack[0].toTensor(); // Check sizes - CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); size_t num_el = 1; for (const auto idx : c10::irange(ref.sizes().size())) { - CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); num_el *= ref.sizes()[idx]; } // Check the contents for (const auto i : c10::irange(num_el)) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } { @@ -564,17 +564,17 @@ TEST_F(Kernel, CatInputTypesPromotion) { auto o = stack[0].toTensor(); // Check sizes - CHECK_EQ(o.sizes().size(), ref.sizes().size()); - CHECK_EQ(o.dtype(), ref.dtype()); + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); size_t num_el = 1; for (const auto idx : c10::irange(ref.sizes().size())) { - CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); num_el *= ref.sizes()[idx]; } // Check the contents for (const auto i : c10::irange(num_el)) { - CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); } } } @@ -687,17 +687,17 @@ TEST_F(Kernel, CatWoConditionals) { auto o = stack[0].toTensor(); // Check sizes - CHECK_EQ(o.sizes().size(), ref.sizes().size()); - CHECK_EQ(o.dtype(), ref.dtype()); + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); size_t num_el = 1; for (const auto idx : c10::irange(ref.sizes().size())) { - CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); num_el *= ref.sizes()[idx]; } // Check the contents for (const auto i : c10::irange(num_el)) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } getCatWoConditionals() = old_cat_wo_conditionals; } @@ -752,17 +752,17 @@ TEST_F(Kernel, OptimizeConditionals) { auto o = stack[0].toTensor(); // Check sizes - CHECK_EQ(o.sizes().size(), ref.sizes().size()); - CHECK_EQ(o.dtype(), ref.dtype()); + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); size_t num_el = 1; for (const auto idx : c10::irange(ref.sizes().size())) { - CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); num_el *= ref.sizes()[idx]; } // Check the contents for (const auto i : c10::irange(num_el)) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } getOptConditionals() = old_opt_conditionals; getCatWoConditionals() = old_cat_wo_conditionals; @@ -1517,7 +1517,7 @@ TEST_F(Kernel, RunFast) { k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } #endif } @@ -1544,7 +1544,7 @@ TEST_F(Kernel, RunWithAllocatedOutputs) { std::vector stack = fmap(args); k.runWithAllocatedOutputs(stack); for (size_t i = 0; i < 5 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } #endif } @@ -1663,7 +1663,7 @@ TEST_F(Kernel, Vectorize) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 100 * 16; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } #endif } @@ -1699,7 +1699,7 @@ TEST_F(Kernel, DISABLED_FlattenVectorize) { k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 100 * 3; i++) { - CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } #endif } @@ -1729,7 +1729,8 @@ TEST_F(Kernel, Strided1dWithinBounds) { auto output = stack[0].toTensor(); for (size_t i = 0; i < 3; ++i) { - CHECK_EQ(((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); + TORCH_CHECK_EQ( + ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); } } @@ -1784,14 +1785,14 @@ graph(%x : int, %y : int): std::vector inputs = {&x, &y}; std::vector outputs = {&r, &z}; k.runFast(inputs, outputs); - CHECK_EQ(z, x * y); - CHECK_EQ(r, z * x); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); // Verify that TEK::run works correctly with scalar outputs std::vector stack = {x, y}; k.run(stack); - CHECK_EQ(stack[0], x * y * x); - CHECK_EQ(stack[1], x * y); + TORCH_CHECK_EQ(stack[0], x * y * x); + TORCH_CHECK_EQ(stack[1], x * y); } TEST_F(Kernel, ScalarTensorOut) { @@ -1820,8 +1821,8 @@ graph(%x : int, std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; k.runFast(inputs, outputs); - CHECK_EQ(z, x * y); - CHECK_EQ(r, z * x); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); ASSERT_TRUE(at::equal(zt, xt * yt)); ASSERT_TRUE(at::equal(rt, zt * xt)); @@ -1829,9 +1830,9 @@ graph(%x : int, // inputs/utputs std::vector stack = {x, xt, y, yt}; k.run(stack); - CHECK_EQ(stack[0], x * y * x); + TORCH_CHECK_EQ(stack[0], x * y * x); ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); - CHECK_EQ(stack[2], x * y); + TORCH_CHECK_EQ(stack[2], x * y); ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); } diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp index a6049c5ac0b65..96a21e0e07bf3 100644 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ b/test/cpp/tensorexpr/test_quantization.cpp @@ -57,7 +57,7 @@ TEST_F(Quantization, QuantDequantInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, QuantDequantUInt8) { @@ -87,7 +87,7 @@ TEST_F(Quantization, QuantDequantUInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, QuantDequantUInt8_NLC) { @@ -119,7 +119,7 @@ TEST_F(Quantization, QuantDequantUInt8_NLC) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } at::Tensor quantized_add( @@ -174,7 +174,7 @@ TEST_F(Quantization, QuantAddDequantInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, QuantAddDequantUInt8) { @@ -218,7 +218,7 @@ TEST_F(Quantization, QuantAddDequantUInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, QuantSigmoidDequantUInt8) { @@ -254,7 +254,7 @@ TEST_F(Quantization, QuantSigmoidDequantUInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } at::Tensor quantized_mul( @@ -310,7 +310,7 @@ TEST_F(Quantization, QuantMulDequantUInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { @@ -348,7 +348,7 @@ TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } TEST_F(Quantization, UpsampleNearst2d) { @@ -377,7 +377,7 @@ TEST_F(Quantization, UpsampleNearst2d) { std::cout << "y_expected:\n" << y_expected << std::endl; std::cout << "y:\n" << y << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } at::Tensor quantized_cat( @@ -445,7 +445,7 @@ TEST_F(Quantization, QuantCatDequantUInt8) { std::cout << "expected:\n" << expected << std::endl; std::cout << "result:\n" << result << std::endl; } - CHECK_EQ(check, 1); + TORCH_CHECK_EQ(check, 1); } } // namespace jit diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp new file mode 100644 index 0000000000000..c6d05042ef76f --- /dev/null +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include +#include + + +static uint64_t add_counter = 0; +static uint64_t last_saved_value = 0; + +// basic dummy add function +at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + add_counter += 1; + // Since this custom device is just for testing, not bothering to implement kernels. + return at::empty(self.sizes(), self.options()); +} + +// A dummy allocator for our custom device, that secretly uses the CPU +struct DummyCustomAllocator final : at::Allocator { + DummyCustomAllocator() = default; + at::DataPtr allocate(size_t nbytes) const override { + void* data = c10::alloc_cpu(nbytes); + return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; + } + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + c10::free_cpu(ptr); + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } +}; + +// Register our dummy allocator +static DummyCustomAllocator global_custom_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); + +// basic dummy empty function, so we can directly construct tensors on the custom device +// This dummy test device will just use the CPU allocator, and ignores pinned memory. +at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format) { + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); +} +at::Tensor custom_empty_symint(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format) { + constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic(c10::asIntArrayRefSlow(size), &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); +} + +at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { + // Not bothering to implement. + return self; +} + +// basic dummy copy_() function, so we can copy from the custom device to/from CPU +at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { + TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); + TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); + + // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. + TORCH_CHECK(self.sizes() == dst.sizes()); + TORCH_CHECK(self.scalar_type() == dst.scalar_type()); + TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); + + std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes()); + return dst; +} + + +// This macro does the heavy lifting. +// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. +// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. +// Later in this file, we map a custom device to the PrivateUse1 device type, +// which allows user code that puts a tensor on your custom_device to eventually get plumbed +// into the kernels registered here. +// +// This macro registers your kernels to the PyTorch Dispatcher. +// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("add.Tensor", &custom_add_Tensor); + m.impl("empty.memory_format", &custom_empty_memory_format); + m.impl("empty.SymInt", &custom_empty_symint); + m.impl("fill_.Scalar", &custom_fill__scalar); + m.impl("_copy_from", &custom__copy_from); +} + +// This basic implementation doesn't bother dealing with different device indices +// (e.g. custom_device:0 vs. custom_device:1). +// We could do that by letting the user pass in a device index in our exposed device function. +// Note that if you do that, you'll also need to register a device guard to core. +// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`. +c10::Device get_custom_device() { + return c10::Device(c10::DeviceType::PrivateUse1, 0); +} + +bool custom_add_called() { + bool called = false; + if (add_counter > last_saved_value) { + called = true; + last_saved_value = add_counter; + } + return called; +} + +// Here, we're exposing a custom device object that corresponds to our custom backend. +// We do this using pybind: exposing an "extension_name.custom_device()" function in python, +// that's implemented in C++. +// The implementation in this file maps directly to the `PrivateUse1` device type. +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("custom_device", &get_custom_device, "get custom device object"); + m.def("custom_add_called", &custom_add_called, "check if our custom add function was called"); +} diff --git a/test/cpp_extensions/ort_extension.cpp b/test/cpp_extensions/ort_extension.cpp index b646f3b14939d..24617abeb06d5 100644 --- a/test/cpp_extensions/ort_extension.cpp +++ b/test/cpp_extensions/ort_extension.cpp @@ -26,6 +26,11 @@ Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::op return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); } +Tensor empty_symint_override(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + return empty_override(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); +} + Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) { test_int = 1; return out; @@ -53,6 +58,7 @@ std::tuple fake_convolution_backward( } TORCH_LIBRARY_IMPL(aten, ORT, m) { + m.impl("empty.SymInt", empty_symint_override); m.impl("empty.memory_format", empty_override); m.impl("add.out", add_out_override); m.impl("convolution_overrideable", fake_convolution); diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py b/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py index 89f690631a740..df7d2412fd015 100644 --- a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py +++ b/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py @@ -4,11 +4,10 @@ import os import shutil import tempfile -from typing import Dict, cast +from typing import Dict import torch import torch.distributed as dist -from torch import Tensor from torch.distributed._shard import sharded_tensor from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook from torch.distributed._shard.sharding_spec import ( @@ -48,87 +47,6 @@ ) sys.exit(0) -def _sharded_tensor_gather( - self, - dst=0, - out=None, -): - """ - This is a reimplementation of ST:gather using gather instead of gather_object. - The later hangs on CI inside NCCL. - """ - - def shard_size(shard_md): - res = 1 - for s in shard_md.shard_sizes: - res *= s - return res - rank = dist.get_rank(self._process_group) - full_size = self.metadata().size - - world_size = dist.get_world_size(self._process_group) - rank_sizes = [0 for _ in range(world_size)] - max_rank_size = 0 - shard_placement = dict() - local_shards_placement = [] - # collect sizes - for shard_idx, shard_md in enumerate(self.metadata().shards_metadata): - shard_rank = shard_md.placement.rank() - shard_placement[shard_idx] = (shard_rank, rank_sizes[shard_rank]) - if shard_rank == rank: - local_shards_placement.append((shard_md, rank_sizes[shard_rank],)) - - rank_sizes[shard_rank] += shard_size(shard_md) - max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) - - - if rank == dst: - gather_list = [torch.empty((max_rank_size,), device=out.device) for _ in range(world_size)] - else: - gather_list = None - - # FIXME is a rank allowed to not have any data? - with torch.no_grad(): - # XXX we can fastpath this to torch.cat if max_rank_size == rank_sizes[rank] - data = torch.empty(max_rank_size, device=self.local_shards()[0].tensor.device) - for shard in self.local_shards(): - for placement in local_shards_placement: - if placement[0] == shard.metadata: - src = shard.tensor.flatten() - data[placement[1]: placement[1] + src.numel()].copy_(src) - break - - dist.gather( - tensor=data, - gather_list=gather_list, - dst=dst, - group=self._process_group, - ) - if rank != dst: - return - if out is None: - raise ValueError("`out` Tensor must be provided on dst rank!") - - full_size = self.metadata().size - dims = len(full_size) - - - for shard_idx, shard_md in enumerate(self.metadata().shards_metadata): - placement = shard_placement[shard_idx] - tensor = gather_list[placement[0]] - tensor = tensor[placement[1] : placement[1] + shard_size(shard_md)] - tensor = tensor.view(shard_md.shard_sizes) - - out_narrow_view = out - for dim in range(dims): - out_narrow_view = out_narrow_view.narrow( - dim, - shard_md.shard_offsets[dim], - shard_md.shard_sizes[dim], - ) - - out_narrow_view.copy_(tensor) - def assert_state_dict_equal( self: TestCase, @@ -146,11 +64,7 @@ def assert_state_dict_equal( for key, value_1 in state_dict_1.items(): value_2 = state_dict_2[key] - if isinstance(value_1, torch.Tensor): - self.assertTrue( - torch.equal(value_1, value_2), f"Key {key}'s tensor does not match" - ) - elif isinstance(value_1, ShardedTensor): + if isinstance(value_1, ShardedTensor): for local_shard_1, local_shard_2 in zip( value_1.local_shards(), value_2.local_shards() ): @@ -158,6 +72,10 @@ def assert_state_dict_equal( torch.equal(local_shard_1.tensor, local_shard_1.tensor), f"Key {key}'s shard does not match", ) + elif isinstance(value_1, torch.Tensor): + self.assertTrue( + torch.equal(value_1, value_2), f"Key {key}'s tensor does not match" + ) return True @@ -268,8 +186,8 @@ def get_file_path(self) -> str: def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: res = torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None - _sharded_tensor_gather(tensor, out=res) - return cast(Tensor, res) + tensor.gather(out=res) + return res @with_comms(init_rpc=False) @skip_if_lt_x_gpu(2) diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py b/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py new file mode 100644 index 0000000000000..321dc2f546883 --- /dev/null +++ b/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py @@ -0,0 +1,460 @@ +# Owner(s): ["oncall: distributed"] + +import sys +import os +import shutil +import tempfile +from typing import Dict + +import torch +import torch.distributed as dist +from torch.distributed._shard import sharded_tensor +from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook +from torch.distributed._shard.sharding_spec import ( + ChunkShardingSpec, + EnumerableShardingSpec, + ShardingSpec, + ShardMetadata, +) +from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.distributed._shard.sharded_tensor import ( + ShardedTensorTestBase, + with_comms, +) +from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( + MyShardedModel1 +) + + +from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, + run_tests, +) + +from torch.distributed._shard.checkpoint import ( + FileSystemReader, + FileSystemWriter, + load_state_dict, + save_state_dict, +) + + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + + +def assert_state_dict_equal( + self: TestCase, + state_dict_1: Dict[str, torch.Tensor], + state_dict_2: Dict[str, torch.Tensor], +) -> bool: + self.assertEqual( + len(state_dict_1), len(state_dict_2), "state_dict must be the same size" + ) + self.assertEqual( + set(state_dict_1.keys()), + set(state_dict_2.keys()), + "state_dict keys do not match", + ) + + for key, value_1 in state_dict_1.items(): + value_2 = state_dict_2[key] + if isinstance(value_1, ShardedTensor): + for local_shard_1, local_shard_2 in zip( + value_1.local_shards(), value_2.local_shards() + ): + self.assertTrue( + torch.equal(local_shard_1.tensor, local_shard_1.tensor), + f"Key {key}'s shard does not match", + ) + elif isinstance(value_1, torch.Tensor): + self.assertTrue( + torch.equal(value_1, value_2), f"Key {key}'s tensor does not match" + ) + + return True + + +class MyTestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = torch.nn.Linear(5, 5) + self.linear_2 = torch.nn.Linear(5, 1) + self.emb = torch.nn.EmbeddingBag(5, 10) + + +# The ShardedModels are borrowed from test/distributed/_sharded_tensor/test_sharded_tensor.py +class MyShardedModel3(torch.nn.Module): + def __init__( + self, + spec: ShardingSpec, + ) -> None: + super(MyShardedModel3, self).__init__() + self.sharded_tensor: ShardedTensor = sharded_tensor.rand( + spec, 10, 20, init_rrefs=False + ) + + +class TestDistributedStateDictSaveLoad(TestCase): + def test_read_write_only_tensor(self) -> None: + with tempfile.TemporaryDirectory() as path: + state_dict_to_save = MyTestModule().state_dict() + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True) + + state_dict_to_load_to = MyTestModule().state_dict() + + with self.assertRaises(AssertionError): + assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + + # Load from file without any resharding + fs_reader = FileSystemReader(path=path) + load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True) + + assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + + +class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): + @property + def world_size(self) -> int: + return 2 + + @with_comms(init_rpc=False, backend="gloo") + def test_read_write_shard_tensor(self) -> None: + paths = [tempfile.mkdtemp()] + dist.broadcast_object_list(paths) + + path = paths[0] + + # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`. + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + ], + ) + + model_to_save = MyShardedModel1(spec, init_rrefs=False) + + # Test save + model_to_save._register_state_dict_hook(state_dict_hook) + state_dict_to_save = model_to_save.state_dict() + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + + dist.barrier() + + # Create a new model + model_to_load = MyShardedModel1(spec, init_rrefs=False) + # This is not the correct hook for loading the state dict + # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) + model_to_load._register_state_dict_hook(state_dict_hook) + state_dict_to_load_to = model_to_load.state_dict() + + dist.barrier() + + with self.assertRaises(AssertionError): + assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + + # Test load. + fs_reader = FileSystemReader(path=path) + load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + + assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + dist.barrier() + + +class TestDistributedReshardOnLoad(ShardedTensorTestBase): + @property + def world_size(self) -> int: + return 2 + + def get_file_path(self) -> str: + paths = [tempfile.mkdtemp()] if dist.get_rank() == 0 else [None] + dist.broadcast_object_list(paths) + return paths[0] + + def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: + res = torch.zeros(tensor.shape, device="cpu") if dist.get_rank() == 0 else None + tensor.gather(out=res) + return res + + @with_comms(init_rpc=False, backend="gloo") + def test_load_with_different_shard_plan(self) -> None: + path = self.get_file_path() + + # We hardcode the assumption of how many shards are around + self.assertEqual(self.world_size, dist.get_world_size()) + + specs = [ + # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`. + ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + ], + ), + # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`. + ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + "rank:1", + "rank:0", + ], + ), + # This requires the tensors to be [10, 20] + EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[2, 20], + placement="rank:0", + ), + ShardMetadata( + shard_offsets=[2, 0], + shard_sizes=[1, 20], + placement="rank:1", + ), + ShardMetadata( + shard_offsets=[3, 0], + shard_sizes=[3, 20], + placement="rank:0", + ), + ShardMetadata( + shard_offsets=[6, 0], + shard_sizes=[3, 20], + placement="rank:1", + ), + ShardMetadata( + shard_offsets=[9, 0], + shard_sizes=[1, 20], + placement="rank:0", + ), + ] + ), + # This requires the tensors to be [10, 20] + EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[8, 20], + placement="rank:1", + ), + ShardMetadata( + shard_offsets=[8, 0], + shard_sizes=[2, 20], + placement="rank:0", + ), + ] + ), + ] + + for s0 in specs: + for s1 in specs: + if s0 == s1: + continue + + if dist.get_rank() == 0: + shutil.rmtree(path, ignore_errors=True) + os.makedirs(path) + dist.barrier() + + model_to_save = MyShardedModel3(s0) + model_to_save._register_state_dict_hook(state_dict_hook) + state_dict_to_save = model_to_save.state_dict() + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + + dist.barrier() + + model_to_load = MyShardedModel3(s1) + model_to_load._register_state_dict_hook(state_dict_hook) + state_dict_to_load_to = model_to_load.state_dict() + dist.barrier() + + fs_reader = FileSystemReader(path=path) + load_state_dict( + state_dict=state_dict_to_load_to, storage_reader=fs_reader + ) + + dist.barrier() + store_tensor = self.load_tensor(model_to_save.sharded_tensor) + dist.barrier() + load_tensor = self.load_tensor(model_to_load.sharded_tensor) + + if dist.get_rank() == 0: + self.assertTrue( + torch.allclose(store_tensor, load_tensor), msg=f"{s0} vs {s1}" + ) + + @with_comms(init_rpc=False, backend="gloo") + def test_load_rowwise_to_colwise(self) -> None: + path = self.get_file_path() + self.assertEqual(self.world_size, dist.get_world_size()) + + # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`. + src_spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + ], + ) + + # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`. + dst_spec = ChunkShardingSpec( + dim=1, + placements=[ + "rank:0", + "rank:1", + ], + ) + + if dist.get_rank() == 0: + shutil.rmtree(path, ignore_errors=True) + os.makedirs(path) + + model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank()) + model_to_save._register_state_dict_hook(state_dict_hook) + state_dict_to_save = model_to_save.state_dict() + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + + model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank()) + model_to_load._register_state_dict_hook(state_dict_hook) + state_dict_to_load_to = model_to_load.state_dict() + + fs_reader = FileSystemReader(path=path) + + load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + + # We can't use torch.allclose since each ST has a different sharding spec + store_tensor = self.load_tensor(model_to_save.sharded_tensor) + load_tensor = self.load_tensor(model_to_load.sharded_tensor) + + if dist.get_rank() == 0: + self.assertTrue(torch.allclose(store_tensor, load_tensor)) + + + @with_comms(init_rpc=False, backend="gloo") + def test_save_load_bytes(self) -> None: + path = self.get_file_path() + + state_dict_to_save = { + 'bytes0': [1], + 'bytes1': 'string' + } + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + + state_dict_to_load = { + 'bytes0': [2], + 'bytes1': 'other' + } + + fs_reader = FileSystemReader(path=path) + load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader) + + self.assertEqual([1], state_dict_to_load['bytes0']) + self.assertEqual('string', state_dict_to_load['bytes1']) + + + @with_comms(init_rpc=False, backend="gloo") + def test_switch_between_sharded_tensor_to_tensor(self) -> None: + path = self.get_file_path() + tensor_size = 32 + + specs = [ + ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + ], + ), + ChunkShardingSpec( + dim=0, + placements=[ + "rank:0", + "rank:1", + "rank:1", + "rank:0", + ], + ), + EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0], + shard_sizes=[8], + placement="rank:1", + ), + ShardMetadata( + shard_offsets=[8], + shard_sizes=[tensor_size - 8], + placement="rank:0", + ), + ] + ), + EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0], + shard_sizes=[10], + placement="rank:0", + ), + ShardMetadata( + shard_offsets=[10], + shard_sizes=[tensor_size - 10], + placement="rank:1", + ), + ] + ), + ] + + for save_spec in specs: + for load_spec in specs: + save_dict = { + 'sharded': sharded_tensor.rand(save_spec, tensor_size), + 'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}") + } + + fs_writer = FileSystemWriter(path=path) + save_state_dict(state_dict=save_dict, storage_writer=fs_writer) + + # Freaky Friday the tensors + load_dict = { + 'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"), + 'replicated': sharded_tensor.zeros(load_spec, tensor_size) + } + + fs_reader = FileSystemReader(path=path) + load_state_dict(state_dict=load_dict, storage_reader=fs_reader) + + save_dict_sharded = self.load_tensor(save_dict['sharded']) + load_dict_replicated = self.load_tensor(load_dict['replicated']) + + if dist.get_rank() == 0: + self.assertTrue( + torch.allclose(save_dict_sharded, load_dict['sharded']), + f"save-spec {save_spec} load-spec {load_spec}" + ) + self.assertTrue( + torch.allclose(save_dict['replicated'], load_dict_replicated), + f"save-spec {save_spec} load-spec {load_spec}" + ) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_shard/checkpoint/test_utils.py b/test/distributed/_shard/checkpoint/test_utils.py new file mode 100644 index 0000000000000..9e1fd2a23888e --- /dev/null +++ b/test/distributed/_shard/checkpoint/test_utils.py @@ -0,0 +1,127 @@ +# Owner(s): ["oncall: distributed"] + +import sys + +import torch + +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardMetadata, + ShardedTensor, + ShardedTensorMetadata, +) +from torch.distributed._shard.sharded_tensor.metadata import TensorProperties + +from torch.testing._internal.common_utils import ( + TestCase, + TEST_WITH_DEV_DBG_ASAN, + run_tests, +) +from torch.distributed._shard.checkpoint.utils import find_state_dict_object +from torch.distributed._shard.checkpoint.metadata import MetadataIndex +from torch.testing._internal.distributed.distributed_utils import ( + with_fake_comms +) + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + +def create_sharded_tensor(rank, world_size, shards_per_rank): + shards_metadata = [] + local_shards = [] + for idx in range(0, world_size * shards_per_rank): + shard_rank = idx // shards_per_rank + shard_md = ShardMetadata(shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu") + shards_metadata.append(shard_md) + if shard_rank == rank: + shard = Shard.from_tensor_and_offsets( + torch.rand(*shard_md.shard_sizes), + shard_offsets=shard_md.shard_offsets, + rank=rank + ) + local_shards.append(shard) + + sharded_tensor_md = ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([8 * len(shards_metadata)]), + tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)) + ) + + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=sharded_tensor_md + ) + + +class TestMedatadaIndex(TestCase): + def test_init_convert_offset(self): + a = MetadataIndex("foo", [1, 2]) + b = MetadataIndex("foo", torch.Size([1, 2])) + self.assertEqual(a, b) + + def test_index_hint_ignored_on_equals(self): + a = MetadataIndex("foo") + b = MetadataIndex("foo", index=99) + self.assertEqual(a, b) + + def test_index_hint_ignored_on_hash(self): + a = MetadataIndex("foo") + b = MetadataIndex("foo", index=99) + self.assertEqual(hash(a), hash(b)) + + def test_flat_data(self): + state_dict = { + "a": torch.rand(10), + "b": [1, 2, 3], + } + + a = find_state_dict_object(state_dict, MetadataIndex("a")) + self.assertEqual(a, state_dict["a"]) + a = find_state_dict_object(state_dict, MetadataIndex("a", index=99)) + self.assertEqual(a, state_dict["a"]) + + b = find_state_dict_object(state_dict, MetadataIndex("b")) + self.assertEqual(b, state_dict["b"]) + b = find_state_dict_object(state_dict, MetadataIndex("b", index=1)) + self.assertEqual(b, state_dict["b"]) + + with self.assertRaisesRegex(ValueError, "FQN"): + find_state_dict_object(state_dict, MetadataIndex("c")) + with self.assertRaisesRegex(ValueError, "ShardedTensor"): + find_state_dict_object(state_dict, MetadataIndex("a", [0])) + with self.assertRaisesRegex(ValueError, "ShardedTensor"): + find_state_dict_object(state_dict, MetadataIndex("b", [1])) + + @with_fake_comms(rank=0, world_size=2) + def test_sharded_tensor_lookup(self): + st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3) + state_dict = {"st": st} + + obj = find_state_dict_object(state_dict, MetadataIndex("st", [8])) + self.assertEqual(obj, st.local_shards()[1].tensor) + + # good hint + obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=1)) + self.assertEqual(obj, st.local_shards()[1].tensor) + + # bad hint + obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=2)) + self.assertEqual(obj, st.local_shards()[1].tensor) + + # broken hint + obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=99)) + self.assertEqual(obj, st.local_shards()[1].tensor) + + with self.assertRaisesRegex(ValueError, "no offset was provided"): + find_state_dict_object(state_dict, MetadataIndex("st")) + + with self.assertRaisesRegex(ValueError, "Could not find shard"): + find_state_dict_object(state_dict, MetadataIndex("st", [1])) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_shard/sharded_optim/test_sharded_optim.py b/test/distributed/_shard/sharded_optim/test_sharded_optim.py index d3f1468aea3c9..a884d64d399f8 100644 --- a/test/distributed/_shard/sharded_optim/test_sharded_optim.py +++ b/test/distributed/_shard/sharded_optim/test_sharded_optim.py @@ -13,7 +13,6 @@ ) from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, - named_params_with_sharded_tensor ) from torch.testing._internal.common_distributed import ( requires_nccl, @@ -35,7 +34,7 @@ def __init__(self, spec=None, group=None): torch.manual_seed(0) self.param = torch.nn.Parameter(torch.rand(5, 10)) if spec is not None: - self.sharded_param = sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group) + self.sharded_param = torch.nn.Parameter(sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group)) else: self.sharded_param = torch.nn.Parameter(torch.rand(5, 10)) @@ -102,15 +101,15 @@ def test_sharded_optim(self): "rank:3/cuda:3", ], ) - local_model = MyShardedModel().cuda(self.rank) - sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank) + local_model = MyShardedModel().cuda() + sharded_model = MyShardedModel(spec=rowwise_spec).cuda() # copy the parameteres from local model sharded_model.sharded_param.local_shards()[0].tensor = \ local_model.sharded_param.detach().clone().requires_grad_() local_optim = optim.SGD(local_model.parameters(), lr=0.1) - sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model)) + sharded_model_params = dict(sharded_model.named_parameters()) sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1) local_optim.zero_grad() @@ -162,16 +161,16 @@ def test_named_params_with_sharded_tensor(self): "rank:3/cuda:3", ], ) - sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank) - sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model)) + sharded_model = MyShardedModel(spec=rowwise_spec).cuda() + sharded_model_params = dict(sharded_model.named_parameters()) param_keys = list(sharded_model_params.keys()) self.assertEqual(len(param_keys), 2) self.assertTrue("param" in param_keys) self.assertTrue("sharded_param" in param_keys) - sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank) + sharded_linear = MyShardedLinear(rank=self.rank).cuda() sharded_linear.shard_parameter() - sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear)) + sharded_linear_params = dict(sharded_linear.named_parameters()) param_keys = list(sharded_linear_params.keys()) self.assertEqual(len(param_keys), 4) self.assertTrue("linear1.bias" in param_keys) @@ -180,9 +179,5 @@ def test_named_params_with_sharded_tensor(self): self.assertTrue("linear2.weight" in param_keys) self.assertFalse("bias" in param_keys) - - - - if __name__ == '__main__': run_tests() diff --git a/test/distributed/_shard/sharded_tensor/ops/test_linear.py b/test/distributed/_shard/sharded_tensor/ops/test_linear.py index 9e28cca19eb54..77d3b1035b47c 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_linear.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_linear.py @@ -12,7 +12,6 @@ ) from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, - named_params_with_sharded_tensor, ) from torch.distributed._shard.sharded_tensor import ( empty, @@ -127,7 +126,7 @@ def _run_sharded_linear( previous_sharded_weight = sharded_weight.clone() previous_sharded_bias = sharded_linear.bias.clone() sharded_optim = ShardedOptimizer( - dict(named_params_with_sharded_tensor(sharded_linear)), + dict(sharded_linear.named_parameters()), torch.optim.SGD, lr=0.1, ) @@ -192,6 +191,7 @@ def test_sharded_linear_rowwise(self): def test_sharded_linear_errors(self): for spec in generate_chunk_sharding_specs_for_test(0): fc1 = torch.nn.Linear(10, 10).cuda(self.rank) + shard_parameter(fc1, "weight", spec) shard_parameter(fc1, "bias", spec) with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'): fc1(torch.rand(10, 10).cuda(self.rank)) diff --git a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py index e080a63875158..b5863e7ce8ded 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py @@ -8,7 +8,7 @@ from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, - ShardMetadata + ShardMetadata, ) from torch.testing._internal.common_distributed import ( requires_nccl, @@ -26,6 +26,7 @@ generate_chunk_sharding_specs_for_test, ) + class TestMathOps(ShardedTensorTestBase): @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @@ -46,8 +47,12 @@ def test_basic_math_ops(self): sharded_lhs = sharded_tensor.rand(spec, (12, 3)) sharded_rhs = sharded_tensor.rand(spec, (12, 3)) current_rank = dist.get_rank() - global_lhs = torch.empty((12, 3), device=current_rank) if current_rank == 0 else None - global_rhs = torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + global_lhs = ( + torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + ) + global_rhs = ( + torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + ) sharded_lhs.gather(dst=0, out=global_lhs) sharded_rhs.gather(dst=0, out=global_rhs) @@ -56,7 +61,9 @@ def test_basic_math_ops(self): binary_op_ = gen_binary_op_func(op, inplace=True) # test basic math ops between ShardedTensors sharded_output = binary_op(sharded_lhs, sharded_rhs) - output = torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + output = ( + torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + ) sharded_output.gather(dst=0, out=output) if current_rank == 0: @@ -68,25 +75,28 @@ def test_basic_math_ops(self): scalars = [3, 1.8] for scalar in scalars: sharded_output_lhs = binary_op(sharded_lhs, scalar) - - sharded_output_lhs_ = binary_op_(sharded_lhs, scalar) - self.assertTrue(torch.allclose(sharded_output_lhs, sharded_output_lhs_)) - output_lhs = torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + output_lhs = ( + torch.empty((12, 3), device=current_rank) + if current_rank == 0 + else None + ) sharded_output_lhs.gather(dst=0, out=output_lhs) - sharded_output_rhs = binary_op(scalar, sharded_lhs) - output_rhs = torch.empty((12, 3), device=current_rank) if current_rank == 0 else None + sharded_output_rhs = binary_op(scalar, sharded_rhs) + output_rhs = ( + torch.empty((12, 3), device=current_rank) + if current_rank == 0 + else None + ) sharded_output_rhs.gather(dst=0, out=output_rhs) if current_rank == 0: global_output_lhs = binary_op(global_lhs, scalar) - global_output_rhs = binary_op(scalar, global_lhs) + global_output_rhs = binary_op(scalar, global_rhs) self.assertEqual(output_lhs, global_output_lhs) self.assertEqual(output_rhs, global_output_rhs) - - @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() @@ -103,38 +113,41 @@ def test_math_ops_errors(self): sharded_lhs = sharded_tensor.rand(spec, (20, 3)) sharded_rhs = sharded_tensor.rand(spec, (12, 3)) - with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting not supported'): + with self.assertRaisesRegex( + RuntimeError, "Implicit broadcasting not supported" + ): torch.add(sharded_lhs, sharded_rhs) - spec = EnumerableShardingSpec([ - ShardMetadata( - shard_offsets=[0, 0], - shard_sizes=[5, 5], - placement="rank:0/cuda:0", - ), - ShardMetadata( - shard_offsets=[0, 5], - shard_sizes=[5, 5], - placement="rank:1/cuda:1", - ), - ShardMetadata( - shard_offsets=[5, 0], - shard_sizes=[5, 5], - placement="rank:2/cuda:2", - ), - ShardMetadata( - shard_offsets=[5, 5], - shard_sizes=[5, 5], - placement="rank:3/cuda:3", - ) - ]) + spec = EnumerableShardingSpec( + [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_sizes=[5, 5], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_sizes=[5, 5], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="rank:3/cuda:3", + ), + ] + ) st = sharded_tensor.rand(spec, 10, 10) - with self.assertRaisesRegex(RuntimeError, 'not supported'): + with self.assertRaisesRegex(RuntimeError, "not supported"): torch.add(st, sharded_rhs) - @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() @@ -149,7 +162,6 @@ def test_sharded_bmm(self): self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected)) self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected)) - @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() @@ -159,7 +171,7 @@ def test_sharded_bmm_errors(self): st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6)) with self.assertRaisesRegex( NotImplementedError, - 'Both st and st2 need to have same placements for bmm', + "Both st and st2 need to have same placements for bmm", ): torch.bmm(st_lhs, st_rhs) for spec in specs: @@ -167,13 +179,13 @@ def test_sharded_bmm_errors(self): st_rhs = sharded_tensor.rand(spec, (20, 3)) with self.assertRaisesRegex( TypeError, - 'both st and st2 need to be a 3D ShardedTensor', + "both st and st2 need to be a 3D ShardedTensor", ): torch.bmm(st_lhs, st_rhs) rhs = torch.rand(15, 5, 6).cuda(self.rank) with self.assertRaisesRegex( TypeError, - 'st2 needs to be a ShardedTensor for torch.bmm', + "st2 needs to be a ShardedTensor for torch.bmm", ): torch.bmm(st_lhs, rhs) spec.dim = 1 @@ -181,6 +193,6 @@ def test_sharded_bmm_errors(self): st_rhs = sharded_tensor.rand(spec, (15, 5, 6)) with self.assertRaisesRegex( NotImplementedError, - 'Only support performing bmm on tensors sharded on dim 0 now', + "Only support performing bmm on tensors sharded on dim 0 now", ): torch.bmm(st_lhs, st_rhs) diff --git a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py index 3f9bec1f38f5b..58aa774cd05e7 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py @@ -2,6 +2,7 @@ import copy +import torch import torch.distributed._shard.sharded_tensor as sharded_tensor from torch.distributed._shard.sharding_spec import ( @@ -21,6 +22,7 @@ run_tests, ) + class TestTensorOps(ShardedTensorTestBase): @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @@ -41,6 +43,25 @@ def test_deep_copy(self): self.assertEqual(copied_st.local_tensor(), st.local_tensor()) self.assertFalse(copied_st is st) + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_inplace_copy(self): + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + st = sharded_tensor.rand(spec, (12, 5)) + ones_st = sharded_tensor.ones(spec, (12, 5)) + self.assertFalse(torch.equal(ones_st, st)) + st.copy_(ones_st) + self.assertTrue(torch.equal(st, ones_st)) + @with_comms(init_rpc=False) @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() diff --git a/test/distributed/_shard/sharded_tensor/test_megatron_prototype.py b/test/distributed/_shard/sharded_tensor/test_megatron_prototype.py index 2f3770717921c..db946d53735f6 100644 --- a/test/distributed/_shard/sharded_tensor/test_megatron_prototype.py +++ b/test/distributed/_shard/sharded_tensor/test_megatron_prototype.py @@ -7,7 +7,6 @@ import torch.distributed as dist from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, - named_params_with_sharded_tensor, ) from torch.distributed._shard.api import ( shard_parameter, @@ -171,7 +170,7 @@ def _shard_parameter(module, spec): optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1) optim.step() sharded_optim = ShardedOptimizer( - dict(named_params_with_sharded_tensor(sharded_megatron_lm)), + dict(sharded_megatron_lm.named_parameters()), torch.optim.SGD, lr=0.1, ) diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py index 07d2a9e64dd82..5c548db8324dc 100644 --- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py @@ -404,6 +404,7 @@ def test_sharded_tensor_metadata(self): st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) st_metadata = st.metadata() self.assertEqual(torch.Size([10, 20]), st_metadata.size) + self.assertEqual(torch.Size([10, 20]), st.size()) self.assertEqual(torch.float, st.dtype) self.assertEqual(torch.strided, st.layout) self.assertEqual(False, st.requires_grad) @@ -432,7 +433,7 @@ def test_sharded_tensor_metadata(self): # test read only properties, they're read only as we can't simply change # the global metadata without changing the underlying shard's properties - with self.assertRaisesRegex(AttributeError, "can't set attribute"): + with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"): st.requires_grad = True @with_comms @@ -952,7 +953,7 @@ def test_invalid_sharding(self): spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'): - sharded_tensor.empty(spec, 10, 20, layout=torch.sparse) + sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo) spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'): @@ -1069,11 +1070,18 @@ def test_sharded_tensor_sizes(self): st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) self.assertEqual(st.size(1), 20) + # Test with negative indexed size + st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) + self.assertEqual(st.size(-1), 20) + + # Test with dim/ndim + self.assertEqual(st.dim(), 2) + self.assertEqual(st.ndim, 2) # Test with invalid input st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) - with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'): + with self.assertRaisesRegex(IndexError, 'Dimension out of range'): st.size(-3) - with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[-2, 2\\)'): + with self.assertRaisesRegex(IndexError, 'Dimension out of range'): st.size(2) with self.assertRaises(TypeError): @@ -1545,7 +1553,7 @@ def test_sharded_tensor_to_cpu(self): # CPU sharded tensor should return the same instance (no copy) st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg) new_st_cpu = st_cpu.cpu() - self.assertEqual(st_cpu, new_st_cpu) + self.assertTrue(st_cpu is new_st_cpu) # GPU sharded tensor to cpu st = sharded_tensor.zeros(spec, h, w) @@ -1553,7 +1561,7 @@ def test_sharded_tensor_to_cpu(self): spec_before_move = st.sharding_spec() new_st = st.cpu(process_group=gloo_pg) # return a copy of orginal st - self.assertNotEqual(st, new_st) + self.assertFalse(st is new_st) # check the spec is still ChunkShardingSpec spec_after_move = new_st.sharding_spec() self.assertIsInstance(spec_after_move, ChunkShardingSpec) @@ -1586,7 +1594,7 @@ def test_sharded_tensor_to_cpu(self): st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg) new_st = st.cpu() # return a copy of orginal st - self.assertNotEqual(st, new_st) + self.assertFalse(st is new_st) # check the spec is still ChunkShardingSpec spec_after_move = new_st.sharding_spec() self.assertIsInstance(spec_after_move, ChunkShardingSpec) @@ -1603,6 +1611,158 @@ def test_sharded_tensor_to_cpu(self): for meta in metas: self.assertEqual(str(meta.placement.device()), "cpu") + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_sharded_tensor_to_cuda(self): + cpu_spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cpu", + "rank:1/cpu", + "rank:2/cpu", + "rank:3/cpu", + ], + ) + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + # CUDA sharded tensor should return a new ShardedTensor, but same + # local shards(no movements) + st_cuda = sharded_tensor.zeros(spec, h, w) + new_st_cuda = st_cuda.cuda() + self.assertTrue(st_cuda is not new_st_cuda) + self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor()) + + gloo_pg = dist.new_group(backend="gloo") + + # CPU sharded tensor to GPU + st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg) + # test ability to move st to GPU + spec_before_move = st_cpu.sharding_spec() + new_st_gpu = st_cpu.cuda() + # check the spec is still ChunkShardingSpec + spec_after_move = new_st_gpu.sharding_spec() + self.assertIsInstance(spec_after_move, ChunkShardingSpec) + # test specs before and after the move almost the same except placement device + self.assertEqual(spec_before_move.dim, spec_after_move.dim) + self.assertEqual(len(spec_before_move.placements), len(spec_after_move.placements)) + for i, remote_device_after in enumerate(spec_after_move.placements): + remote_device_before = spec_before_move.placements[i] + self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) + self.assertEqual(str(remote_device_before.device().type), "cpu") + self.assertEqual(str(remote_device_after.device().type), "cuda") + + # ensure metdata also get changed to GPU + metas = new_st_gpu.metadata().shards_metadata + for meta in metas: + self.assertEqual(str(meta.placement.device().type), "cuda") + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_sharded_tensor_to_test(self): + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + # CUDA sharded tensor should return a new ShardedTensor, but same + # local shards(no movements) + st = sharded_tensor.zeros(spec, h, w) + # test same dtype, device return itself + st_self = st.to(dtype=st.dtype, device="cuda") + self.assertTrue(st_self is st) + + # test dtype to + st_16 = st.to(torch.float16) + self.assertFalse(st_16 is st) + self.assertEqual(st_16.dtype, torch.float16) + # test device to + st_cpu = st.to(device=torch.device("cpu")) + self.assertFalse(st_cpu is st) + self.assertEqual(st_cpu.local_tensor().device.type, "cpu") + st_cuda = st_cpu.to(device=torch.device("cuda")) + self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + # non-kwarg device to + st_cuda = st_cpu.to(torch.device("cuda")) + self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + st_cpu = st_cuda.to(torch.device("cpu")) + self.assertEqual(st_cpu.local_tensor().device.type, "cpu") + # with string like device conversion + st_cpu = st_cuda.to("cpu") + self.assertEqual(st_cpu.local_tensor().device.type, "cpu") + st_cuda = st_cpu.to("cuda") + self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + # with int like device conversion + st_cpu = st_cuda.to("cpu") + self.assertEqual(st_cpu.local_tensor().device.type, "cpu") + st_cuda = st_cpu.to(self.rank) + self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + + # test tensor to + cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda") + st_cuda = st.to(cuda_tensor) + self.assertFalse(st_cuda is st) + self.assertEqual(st_cuda.dtype, torch.float16) + + cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2") + st_cuda = st.to(cuda_tensor) + self.assertEqual(st_cuda.dtype, torch.float16) + + # test dtype and device together + st_cpu_16 = st.to("cpu", torch.float16) + self.assertEqual(st_cpu_16.dtype, torch.float16) + self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu") + + st_cuda_32 = st_cpu_16.to("cuda", torch.float32) + self.assertEqual(st_cuda_32.dtype, torch.float32) + self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda") + + # test pass additional process group + gloo_pg = dist.new_group(backend="gloo") + st_gloo = st.to(device="cpu", process_group=gloo_pg) + self.assertFalse(st_gloo is st) + self.assertEqual(st_gloo.local_tensor().device.type, "cpu") + self.assertEqual(st_gloo._process_group, gloo_pg) + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_sharded_tensor_device(self): + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + # CUDA sharded tensor should return a new ShardedTensor, but same + # local shards(no movements) + st = sharded_tensor.zeros(spec, h, w) + current_device = torch.device(torch.cuda.current_device()) + self.assertEqual(current_device, st.device) + + # test after to cpu, device get changed + cpu_device = torch.device("cpu") + st_cpu = st.to(device=cpu_device) + self.assertEqual(st_cpu.device, cpu_device) + @skip_if_lt_x_gpu(4) @requires_nccl() def test_uneven_shards(self): @@ -2109,6 +2269,65 @@ def test_init_from_local_shards(self): shard = remote_shard.to_here() self.assertEqual((5, 5), shard.tensor.size()) + @skip_if_lt_x_gpu(4) + def test_st_base_init_from_local_shards_and_global_metadata(self): + world_size = 4 + shards_metadata = [] + shards = [] + for rank in range(world_size): + local_shard_metadata = ShardMetadata( + shard_offsets=[(rank // 2) * 5, (rank % 2) * 5], + shard_sizes=[5, 5], + placement=f"rank:{rank}/cuda:{rank}", + ) + shards_metadata.append(local_shard_metadata) + shards.append( + sharded_tensor.Shard( + torch.randn(5, 5, device=f"cuda:{rank}"), local_shard_metadata + ) + ) + + tensor_properties = TensorProperties( + dtype=torch.get_default_dtype(), + layout=torch.strided, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=False, + ) + + sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 10]), + tensor_properties=tensor_properties, + ) + + st_base = sharded_tensor.ShardedTensorBase._init_from_local_shards_and_global_metadata( + shards, sharded_tensor_metadata=sharded_tensor_metadata + ) + self.assertEqual(4, len(st_base.local_shards())) + + # Verify local shard of st_base + local_shard = st_base.local_shards()[0] + self.assertEqual(torch.device("cuda:0"), local_shard.tensor.device) + self.assertEqual((5, 5), local_shard.tensor.size()) + + # Verify local shard metadata. + self.assertEqual( + (0, 0), + local_shard.metadata.shard_offsets, + ) + self.assertEqual((5, 5), local_shard.metadata.shard_sizes) + self.assertEqual("rank:0/cuda:0", str(local_shard.metadata.placement)) + + # Verify global metadata. + shards_metadata = st_base.metadata().shards_metadata + self.assertEqual(4, len(shards_metadata)) + for rank, shard_metadata in enumerate(shards_metadata): + self.assertEqual( + (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets + ) + self.assertEqual((5, 5), shard_metadata.shard_sizes) + self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) @with_comms @skip_if_lt_x_gpu(4) diff --git a/test/distributed/_shard/sharding_plan/test_sharding_plan.py b/test/distributed/_shard/sharding_plan/test_sharding_plan.py index 5687c9a190559..f09e1e9497696 100644 --- a/test/distributed/_shard/sharding_plan/test_sharding_plan.py +++ b/test/distributed/_shard/sharding_plan/test_sharding_plan.py @@ -8,7 +8,6 @@ import torch.distributed as dist from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, - named_params_with_sharded_tensor, ) from torch.testing._internal.common_distributed import ( requires_nccl, @@ -19,7 +18,10 @@ from torch.distributed._shard.sharding_spec import ChunkShardingSpec from torch.distributed._shard.sharded_tensor import ShardedTensor -from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN +from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, + run_tests, +) from torch.testing._internal.distributed._shard.sharded_tensor import ( TEST_GPU_NUM, ShardedTensorTestBase, @@ -169,7 +171,7 @@ def test_sharding_plan_simple_megatron(self): optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1) optim.step() sharded_optim = ShardedOptimizer( - dict(named_params_with_sharded_tensor(megatron_lm)), + dict(megatron_lm.named_parameters()), torch.optim.SGD, lr=0.1, ) @@ -360,3 +362,6 @@ def test_shard_module_sub_process_group(self): if self.rank >= 2: shard_module(megatron_lm, sharding_plan, process_group=pg) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py index 9de9d93b928e2..aebf3ccd62667 100644 --- a/test/distributed/algorithms/quantization/test_quantization.py +++ b/test/distributed/algorithms/quantization/test_quantization.py @@ -5,8 +5,8 @@ import torch.cuda import sys import torch.distributed as dist -import torch.distributed.algorithms.quantization.quantization as quant -from torch.distributed.algorithms.quantization.quantization import DQuantType +import torch.distributed.algorithms._quantization.quantization as quant +from torch.distributed.algorithms._quantization.quantization import DQuantType from torch.testing._internal.common_distributed import ( MultiProcessTestCase, init_multigpu_helper, diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py index 3c2b5957e44bf..205981c541b03 100644 --- a/test/distributed/fsdp/test_checkpoint_wrapper.py +++ b/test/distributed/fsdp/test_checkpoint_wrapper.py @@ -106,7 +106,7 @@ def test(use_checkpointing, use_wrapper, use_reentrant): functional_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=True) wrapper_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=True) - self.assertEqual(functional_no_reentrant, wrapper_no_reentrant) + self.assertEqual(functional_reentrant, wrapper_reentrant) def test_forward_missing_attributes(self): lin = nn.Linear(1, 1) @@ -176,6 +176,12 @@ def check_fn(l): self.assertTrue(param.requires_grad) self.assertFalse(param.grad is None) + def test_fqn(self): + lin = nn.Linear(10, 10, bias=False) + lin = checkpoint_wrapper(lin) + state_dict = lin.state_dict() + for fqn, _ in lin.named_parameters(): + self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.") if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_flatten_params_wrapper.py b/test/distributed/fsdp/test_flatten_params_wrapper.py index 69c78ee6dde7a..56f7bc4564851 100644 --- a/test/distributed/fsdp/test_flatten_params_wrapper.py +++ b/test/distributed/fsdp/test_flatten_params_wrapper.py @@ -5,12 +5,9 @@ import torch from torch import distributed as dist -from torch.distributed.fsdp.flatten_params_wrapper import ( - FlattenParamsWrapper, - ShardMetadata, -) -from torch.testing._internal.common_utils import run_tests, TestCase - +from torch.distributed.fsdp.flat_param import FlatParamShardMetadata +from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper +from torch.testing._internal.common_utils import TestCase, run_tests if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) @@ -104,7 +101,7 @@ def test_partial_flattening(self): ) num_params_to_flatten = sum(p.numel() for p in params_to_flatten) - module = FlattenParamsWrapper(module, param_list=params_to_flatten) + module = FlattenParamsWrapper(module, params_to_flatten) self.assertEqual(module.flat_param.numel(), num_params_to_flatten) self.assertEqual(sum(p.numel() for p in module.parameters()), num_params) @@ -131,14 +128,14 @@ def test_partial_flattening(self): def test_flatten_nothing(self): module = self._get_transformer() - module = FlattenParamsWrapper(module, param_list=[]) + module = FlattenParamsWrapper(module, []) self.assertIsNone(module.flat_param) def test_empty_module(self): module = self._get_empty_module() in_data = torch.rand(1) ref_out = module(in_data) - module = FlattenParamsWrapper(module, param_list=[]) + module = FlattenParamsWrapper(module, []) self.assertEqual(len(list(module.parameters())), 0) self.assertIsNone(module.flat_param) fpw_out = module(in_data) @@ -184,126 +181,115 @@ def test_sharded_flat_param(self): ) params_to_flatten = list(module.parameters()) flat_module = FlattenParamsWrapper(module, params_to_flatten) - flat_p = flat_module.flat_param - - def _test(kwargs, expected, exception=None, regex=None): - flat_p._is_sharded = True - if exception is not None: - with self.assertRaisesRegex(exception, regex): - flat_p.shard_by_offsets(**kwargs) - else: - flat_p.shard_by_offsets(**kwargs) - self.assertEqual( - flat_p.shard_metadata(), - expected, - msg=f"{flat_p.shard_metadata()}, {expected}", - ) - self.assertEqual(flat_p.num_padded, kwargs["num_padded"]) - - _test( - kwargs={"start": -1, "end": -1, "num_padded": 0}, - expected=None, - exception=ValueError, - regex="Shard the flatten parameter with an invalid offset pair", - ) - _test( - kwargs={"start": 1, "end": 0, "num_padded": 0}, - expected=None, - exception=ValueError, - regex="Shard the flatten parameter with an invalid offset pair", - ) - _test( - kwargs={"start": 0, "end": 1, "num_padded": 3}, - expected=None, - exception=ValueError, - regex="The number of padding is larger than the shard size.", - ) + flat_param_handle = flat_module.handle + + def _test(kwargs, expected): + """ + Tests the subroutine ``_get_shard_metadata()`` that computes shard + metadata based on start and end indices in the unsharded flattened + parameter. + + We manually set the relevant attributes on the flattened parameter + to be able to check the effect of ``_get_shard_metadata()`` via + ``shard_metadata()`` since normally the attributes are set in + ``init_shard_info()`` with the start and end indices fixed based on + rank and world size. + """ + flat_param = flat_module.flat_param + flat_param._is_sharded = True + flat_param._shard_param_offsets, flat_param._shard_indices = \ + flat_param_handle._get_shard_metadata(kwargs["start"], kwargs["end"]) + self.assertEqual( + flat_param_handle.shard_metadata(), + expected, + msg=f"{flat_param_handle.shard_metadata()}, {expected}", + ) _test( - kwargs={"start": 0, "end": 0, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight"], + kwargs={"start": 0, "end": 0}, + expected=FlatParamShardMetadata( + param_names=["0.weight"], param_shapes=[(10, 10)], param_numels=[100], param_offsets=[(0, 0)], ), ) _test( - kwargs={"start": 0, "end": 50, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight"], + kwargs={"start": 0, "end": 50}, + expected=FlatParamShardMetadata( + param_names=["0.weight"], param_shapes=[(10, 10)], param_numels=[100], param_offsets=[(0, 50)], ), ) _test( - kwargs={"start": 0, "end": 99, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight"], + kwargs={"start": 0, "end": 99}, + expected=FlatParamShardMetadata( + param_names=["0.weight"], param_shapes=[(10, 10)], param_numels=[100], param_offsets=[(0, 99)], ), ) _test( - kwargs={"start": 50, "end": 149, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight", "_fpw_module.2.weight"], + kwargs={"start": 50, "end": 149}, + expected=FlatParamShardMetadata( + param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], param_numels=[100, 100], param_offsets=[(50, 99), (0, 49)], ), ) _test( - kwargs={"start": 50, "end": 199, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight", "_fpw_module.2.weight"], + kwargs={"start": 50, "end": 199}, + expected=FlatParamShardMetadata( + param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], param_numels=[100, 100], param_offsets=[(50, 99), (0, 99)], ), ) _test( - kwargs={"start": 99, "end": 199, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.0.weight", "_fpw_module.2.weight"], + kwargs={"start": 99, "end": 199}, + expected=FlatParamShardMetadata( + param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], param_numels=[100, 100], param_offsets=[(99, 99), (0, 99)], ), ) _test( - kwargs={"start": 100, "end": 199, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.2.weight"], + kwargs={"start": 100, "end": 199}, + expected=FlatParamShardMetadata( + param_names=["2.weight"], param_shapes=[(10, 10)], param_numels=[100], param_offsets=[(0, 99)], ), ) _test( - kwargs={"start": 100, "end": 299, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.2.weight", "_fpw_module.4.weight"], + kwargs={"start": 100, "end": 299}, + expected=FlatParamShardMetadata( + param_names=["2.weight", "4.weight"], param_shapes=[(10, 10), (10, 10)], param_numels=[100, 100], param_offsets=[(0, 99), (0, 99)], ), ) _test( - kwargs={"start": 100, "end": 1000, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.2.weight", "_fpw_module.4.weight"], + kwargs={"start": 100, "end": 1000}, + expected=FlatParamShardMetadata( + param_names=["2.weight", "4.weight"], param_shapes=[(10, 10), (10, 10)], param_numels=[100, 100], param_offsets=[(0, 99), (0, 99)], ), ) _test( - kwargs={"start": 299, "end": 299, "num_padded": 0}, - expected=ShardMetadata( - param_names=["_fpw_module.4.weight"], + kwargs={"start": 299, "end": 299}, + expected=FlatParamShardMetadata( + param_names=["4.weight"], param_shapes=[(10, 10)], param_numels=[100], param_offsets=[(99, 99)], diff --git a/test/distributed/fsdp/test_fsdp_apply.py b/test/distributed/fsdp/test_fsdp_apply.py index 7870804d78fc7..d72d57d133b0d 100644 --- a/test/distributed/fsdp/test_fsdp_apply.py +++ b/test/distributed/fsdp/test_fsdp_apply.py @@ -6,12 +6,13 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, -) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, NestedWrappedModule, + TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, @@ -41,12 +42,8 @@ def _init_linear_weights(self, m): m.weight.fill_(1.0) m.bias.fill_(1.0) - @property - def process_group(self): - return dist.distributed_c10d._get_default_group() - def check_weights(self, fsdp, expected_tensor_fn, check): - with fsdp.summon_full_params(fsdp, recurse=True): + with FSDP.summon_full_params(fsdp, recurse=True): linear_modules = [ module for module in fsdp.modules() if type(module) == nn.Linear ] @@ -70,32 +67,36 @@ def _check_apply(self, fsdp): @skip_if_lt_x_gpu(2) def test_nested_module_apply(self): - """ - Checks apply() modifies weights appropriately on a nested FSDP instance. - """ - nested_module = NestedWrappedModule( - self.process_group, wrap_fsdp=True, wrap_everything=True + """Tests that ``apply()`` modifies parameter values in-place on a + non-FSDP-root nested FSDP-wrapped model.""" + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, ) - fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank) - self._check_apply(fsdp_module) + self._check_apply(nested_wrapped_module) @skip_if_lt_x_gpu(2) def test_transformer_module_apply(self): - """ - Checks apply() modifies weights appropriately on a wrapped Transformer - module. - """ - transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank) + """Tests that ``apply()`` modifies parameter values in-place on an + FSDP-wrapped transformer model with shared parameters.""" + transformer = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + ) self._check_apply(transformer) @skip_if_lt_x_gpu(2) def test_apply_in_summon_raises_error(self): - """ - Ensures that if user calls apply() on FSDP instance within full param - summon context, appropriate error is raised. - """ - transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank) - with transformer.summon_full_params(transformer, recurse=True): + """Tests that calling ``apply()`` on an FSDP instance inside the + ``summon_full_params()`` context raises an error.""" + transformer = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + ) + with transformer.summon_full_params(transformer): with self.assertRaisesRegex(ValueError, "expected to be in states"): transformer.apply(self._init_linear_weights) diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py index c527ca7aebc80..432e56ac03599 100644 --- a/test/distributed/fsdp/test_fsdp_comm.py +++ b/test/distributed/fsdp/test_fsdp_comm.py @@ -11,7 +11,13 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest, NestedWrappedModule +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, + FSDPTest, + NestedWrappedModule, + TransformerWithSharedParams, +) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, @@ -45,19 +51,25 @@ def _init_model( sharding_strategy: ShardingStrategy, device: torch.device, ): - group = dist.distributed_c10d._get_default_group() + fsdp_kwargs = {"sharding_strategy": sharding_strategy} if nested_model: - model = NestedWrappedModule( - group, wrap_fsdp=True, sharding_strategy=sharding_strategy, + model = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, ) fsdp_model: FSDP = FSDP( - model, group, sharding_strategy=sharding_strategy, + model, + self.process_group, + **fsdp_kwargs, ).to(device) else: - fsdp_model: FSDP = self._get_wrapped_model( - group, - cuda_first=False, - config={"sharding_strategy": sharding_strategy}, + fsdp_model: FSDP = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs, ) return fsdp_model diff --git a/test/distributed/fsdp/test_fsdp_comm_hooks.py b/test/distributed/fsdp/test_fsdp_comm_hooks.py new file mode 100644 index 0000000000000..a16258855b2ad --- /dev/null +++ b/test/distributed/fsdp/test_fsdp_comm_hooks.py @@ -0,0 +1,390 @@ +# Owner(s): ["oncall: distributed"] + +import sys +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import distributed as dist +from torch.distributed.algorithms._comm_hooks import default_hooks +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy +from torch.testing._internal.common_distributed import ( + requires_nccl, + requires_nccl_version, + sandcastle_skip_if, + skip_if_lt_x_gpu, + skip_if_rocm, +) +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +# bfloat16 is only supported by CUDA 11+ +BFLOAT16_AVAILABLE = ( + torch.cuda.is_available() + and torch.version.cuda is not None + and int(torch.version.cuda.split('.')[0]) >= 11) + +class Net(nn.Module): + + def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): + # to ensure determinism + torch.manual_seed(0) + torch.cuda.manual_seed(0) + super().__init__() + + if has_wrapping: + self.net = FSDP(nn.Sequential( + nn.Linear(8, 16), + nn.ReLU(), + FSDP( + nn.Linear(16, 8), + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + ) + ), + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + ) + else: + self.net = nn.Sequential( + nn.Linear(8, 16), + nn.ReLU(), + nn.Linear(16, 8) + ) + + self.out = nn.Linear(8, 4) + + def forward(self, x): + return self.out(F.relu(self.net(x))) + +class DummyState(object): + + __slots__ = [ + "process_group" + ] + + def __init__(self, process_group): + self.process_group = process_group + +class DummyHook(object): + + def dummy_hook(self, state: DummyState, grad: torch.Tensor): + pass + +class TestCommunicationHooks(FSDPTest): + + @skip_if_lt_x_gpu(2) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD + ]) + def test_default_communication_hook_behavior( + self, + sharding_strategy: Optional[ShardingStrategy] + ): + """ + Tests FSDP's default communication hook's behavior and correctness. + Arguments: + sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. + """ + m = torch.nn.Linear(1, 5, bias=False) + inpt = torch.tensor([self.rank]).float().cuda(self.rank) + + net_default_hook = FSDP( + m, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy + ).to(self.rank) + + # Check that default hook is set to `all_reduce` + for entry in FSDP.fsdp_modules(net_default_hook): + self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook) + + for _ in range(4): + + # Clear gradients + net_default_hook.zero_grad() + loss = net_default_hook(inpt).sum() + loss.backward() + + # For each worker, the gradient on the weight should be worker_rank. + grad = net_default_hook.params[0].grad + expected_grad = ( + sum(i for i in range(dist.get_world_size())) / dist.get_world_size() + ) + # Verify default hook produces expected gradients + self.assertEqual( + grad[0].item(), + expected_grad, + msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}") + + def _get_submodules(self, fsdp_net): + return [ + submodule for submodule in FSDP.fsdp_modules(fsdp_net) + if not submodule.check_is_root() + ] + + def _init_model(self, core, sharding_strategy, mixed_precision=None): + + device = torch.device("cuda") + return FSDP( + core, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + ).to(device) + + @skip_if_lt_x_gpu(2) + @parametrize("has_wrapping", [True, False]) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD, + ShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP + ]) + def test_default_communication_hook_initialization( + self, + has_wrapping: bool, + sharding_strategy: Optional[ShardingStrategy] + ): + """ + Tests FSDP's communication hook interface behavior. + Arguments: + has_wrapping (bool): Configures wrapping of a module. + sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. + """ + + # Initialize a model + fsdp_model_with_hook = self._init_model( + Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), + sharding_strategy=sharding_strategy + ) + dummy_state = DummyState(process_group=None) + + # FSDP currently supports communication hooks for a NO_SHARD strategy + # Check that a `NotImplementedError` is raised for other strategies + if sharding_strategy != ShardingStrategy.NO_SHARD: + # Check that default hook is set to None + for entry in FSDP.fsdp_modules(fsdp_model_with_hook): + self.assertIsNone(entry.communication_hook) + self.assertIsNone(entry.communication_hook_state) + + with self.assertRaisesRegex( + NotImplementedError, + '^Communication hooks are currently only available for a NO_SHARD strategy.$' + ): + fsdp_model_with_hook.register_comm_hook(dummy_state, DummyHook.dummy_hook) + + else: + + # Check that default hook is set to `all_reduce` + for entry in FSDP.fsdp_modules(fsdp_model_with_hook): + self.assertEqual(entry.communication_hook, default_hooks.allreduce_hook) + + dummy_state = DummyState(process_group=None) + + fsdp_model_with_hook.register_comm_hook( + dummy_state, + DummyHook.dummy_hook + ) + + # Check that we can't register comm hook twice + with self.assertRaisesRegex(AssertionError, '^communication hook can be only registered once$'): + fsdp_model_with_hook.register_comm_hook( + dummy_state, + DummyHook.dummy_hook + ) + + # Check dummy hook was registered for the root and all submodules if any + for entry in FSDP.fsdp_modules(fsdp_model_with_hook): + self.assertEqual( + entry.communication_hook, + DummyHook.dummy_hook + ) + self.assertEqual( + entry.communication_hook_state, + dummy_state + ) + + @skip_if_lt_x_gpu(2) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD + ]) + def test_registering_hook_non_root( + self, + sharding_strategy: Optional[ShardingStrategy] + ): + """ + Tests FSDP's communication hook registering for submodules. + Make sure it can't be registered for non-root submodules. + Currently tests only ``NO_SHARD`` strategy. + Arguments: + sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. + + """ + + fsdp_model_with_hook = self._init_model( + Net(has_wrapping=True, sharding_strategy=sharding_strategy), + sharding_strategy=sharding_strategy + ) + dummy_state = DummyState(process_group=None) + # Creating a list of non-root submodules to test + submodules = self._get_submodules(fsdp_model_with_hook) + # Check that assertion is raised for registering a comm hook on a non-root + with self.assertRaisesRegex(AssertionError, '^register_comm_hook can only be called on a root instance.$'): + submodules[1].register_comm_hook(dummy_state, DummyHook.dummy_hook) + + @skip_if_lt_x_gpu(2) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD + ]) + def test_registering_hook_submodules( + self, + sharding_strategy: Optional[ShardingStrategy] + ): + """ + Tests FSDP's communication hook registering for submodules. + Checks behavior if a hook was registered for a non-root submodule + Currently tests only ``NO_SHARD`` strategy. + Arguments: + sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm. + + """ + + fsdp_model_with_hook = self._init_model( + Net(has_wrapping=True, sharding_strategy=sharding_strategy), + sharding_strategy=sharding_strategy + ) + dummy_state = DummyState(process_group=None) + submodules = self._get_submodules(fsdp_model_with_hook) + + # Simulate a registration of a hook on a submodule + submodules[1]._hook_registered = True + # Check that an error is raised when some of submodules have a non-default hook assigned + with self.assertRaisesRegex(AssertionError, '^communication hook can be only registered once$'): + fsdp_model_with_hook.register_comm_hook(dummy_state, DummyHook.dummy_hook) + + # Reinitialize the model + fsdp_model_with_hook = self._init_model( + Net(has_wrapping=True, sharding_strategy=sharding_strategy), + sharding_strategy=sharding_strategy + ) + submodules = self._get_submodules(fsdp_model_with_hook) + submodules[1].communication_hook = DummyHook.dummy_hook + + # Check that an error is raised when some of submodules have a non-default hook assigned + with self.assertRaisesRegex( + AssertionError, + f'^communication hook should be default, but it is {submodules[1].communication_hook.__name__} instead$' + ): + fsdp_model_with_hook.register_comm_hook( + dummy_state, + DummyHook.dummy_hook + ) + + def _check_low_precision_hook(self, state, hook, sharding_strategy, dtype, has_wrapping): + # keep everything deterministic for input data + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + fsdp_with_hook = self._init_model( + Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), + sharding_strategy=sharding_strategy + ) + fsdp_with_hook.register_comm_hook(state, hook) + + mp_only_grad = MixedPrecision(reduce_dtype=dtype) + fsdp_with_mp = self._init_model( + Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy, mixed_precision=mp_only_grad), + sharding_strategy=sharding_strategy, + mixed_precision=mp_only_grad + ) + + optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1) + optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1) + + in_data = torch.rand(16, 8).cuda() + fsdp_with_hook.train() + fsdp_with_mp.train() + loss_hook = fsdp_with_hook(in_data).sum() + loss_mp = fsdp_with_mp(in_data).sum() + loss_hook.backward() + # Make sure grads were cast to the parameter's precision + self.assertEqual(fsdp_with_hook.params[0].dtype, state.parameter_type) + loss_mp.backward() + optim_hook.step() + optim_mp.step() + + dist.barrier() + + for hook_param, mp_param in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()): + self.assertEqual(hook_param.grad, mp_param.grad) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @parametrize("has_wrapping", [True, False]) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD + ]) + def test_fp16_hook( + self, + has_wrapping: bool, + sharding_strategy: Optional[ShardingStrategy] + ): + + state = default_hooks.LowPrecisionState(process_group=None) + hook = default_hooks.fp16_compress_hook + + self._check_low_precision_hook(state, hook, sharding_strategy, torch.float16, has_wrapping) + + @requires_nccl() + @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") + @sandcastle_skip_if( + not BFLOAT16_AVAILABLE, + "BFloat16 is only supported by CUDA 11+", + ) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + @parametrize("has_wrapping", [True, False]) + @parametrize( + "sharding_strategy", + [ + ShardingStrategy.NO_SHARD + ]) + def test_bf16_hook( + self, + has_wrapping: bool, + sharding_strategy: Optional[ShardingStrategy] + ): + + state = default_hooks.LowPrecisionState(process_group=None) + hook = default_hooks.bf16_compress_hook + + self._check_low_precision_hook(state, hook, sharding_strategy, torch.bfloat16, has_wrapping) + + +instantiate_parametrized_tests(TestCommunicationHooks) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 35698092d3540..36dc19eeda808 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -3,15 +3,21 @@ import functools import itertools import sys +from typing import Any, Dict, List, Optional from unittest import mock import torch import torch.distributed as dist import torch.nn as nn -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, +from torch.distributed.fsdp import CPUOffload, MixedPrecision +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + ShardingStrategy, ) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + AlwaysWrapNestedWrappedModule, + CUDAInitMode, DummyDDP, FSDPInitMode, FSDPTest, @@ -19,7 +25,7 @@ NestedWrappedModule, NestedWrappedModuleWithDelay, TransformerWithSharedParams, - subtest_name + subtest_name, ) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, @@ -28,10 +34,6 @@ run_tests, ) -from torch.distributed.fsdp import CPUOffload, MixedPrecision -from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy - - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -43,22 +45,13 @@ ) sys.exit(0) -params = "cpu_offload,backward_prefetch,forward_prefetch,sharding_strategy" +params = "cpu_offload,sharding_strategy" cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] -backward_prefetch_config = [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None] -forward_prefetch_config = ["forward_prefetch", "no_forward_prefetch"] -sharding_strategy_config = [ShardingStrategy.SHARD_GRAD_OP, None, ShardingStrategy.NO_SHARD] -configs = list(itertools.product(cpu_offload_config, - backward_prefetch_config, - forward_prefetch_config, - sharding_strategy_config)) +sharding_strategy_config = [None, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD] +configs = list(itertools.product(cpu_offload_config, sharding_strategy_config)) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", str(CPUOffload(offload_params=False)): "offload_false", - str(BackwardPrefetch.BACKWARD_PRE): "backward_prefetch_pre", - str(BackwardPrefetch.BACKWARD_POST): "backward_prefetch_post", - "forward_prefetch": "forward_prefetch", - "no_forward_prefetch": "no_forward_prefetch", str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op", str(ShardingStrategy.NO_SHARD): "no_shard", } @@ -72,216 +65,236 @@ class TestParityWithDDP(FSDPTest): PyTorch DDP vs. FullyShardedDataParallel. """ - def _get_init_modes_for_test(self, cpu_offload): + def _get_cuda_init_modes(self, cpu_offload: CPUOffload) -> List[CUDAInitMode]: modes = [ - FSDPInitMode.CUDA_AFTER, - FSDPInitMode.CUDA_BEFORE + CUDAInitMode.CUDA_AFTER, + CUDAInitMode.CUDA_BEFORE ] - # Note that FSDPInitMode.CUDA_NEVER works currently only with CPU + # Note that CUDAInitMode.CUDA_NEVER works currently only with CPU # offload as we explicitly bring the param back to CUDA device. In # general, it will not work since we try to all_gather p.data which is # on CPU but NCCL only supports GPU. if cpu_offload.offload_params: - modes.append(FSDPInitMode.CUDA_NEVER) + modes.append(CUDAInitMode.CUDA_NEVER) return modes + def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]: + """Returns a subtest configuration that subtests CUDA initialization + modes and prefetching settings together.""" + return { + "cuda_init_mode": self._get_cuda_init_modes(cpu_offload), + "forward_prefetch": [False, True], + "backward_prefetch": [ + None, + BackwardPrefetch.BACKWARD_PRE, + BackwardPrefetch.BACKWARD_POST, + ] + } + @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - def test_nested_wrapped_model(self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy): - forward_prefetch = (forward_prefetch == "forward_prefetch") - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - self._test_identical_outputs( - NestedWrappedModule, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - ) + def test_nested_wrapped_model( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + ): + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + NestedWrappedModule, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + ) @skip_if_lt_x_gpu(2) - @parametrize("cpu_offload", cpu_offload_config) - @parametrize("sharding_strategy", sharding_strategy_config) - @parametrize("mixed_precision", [True, False]) + @parametrize(params, configs, subtest_name) def test_nested_wrapped_model_single_iteration_mixed_precision( self, - cpu_offload, - sharding_strategy, - mixed_precision + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], ): - init_modes = self._get_init_modes_for_test(cpu_offload) mixed_precision = MixedPrecision( param_dtype=torch.float16, buffer_dtype=torch.float16, reduce_dtype=torch.float16, - ) if mixed_precision else None - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - self._test_identical_outputs( - NestedWrappedModule, - # Only run one step for comparison, as usually grad scaler - # is needed to avoid NaN after first step. - num_steps=1, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - ) + ) + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + NestedWrappedModule, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + num_iters=1, + mixed_precision=mixed_precision, + ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - @parametrize("clip_norm_type", [2.0, None]) - def test_nested_all_wrapped_model( - self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy, clip_norm_type): - forward_prefetch = (forward_prefetch == "forward_prefetch") - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - model_fn = functools.partial(NestedWrappedModule, wrap_everything=True) - self._test_identical_outputs( - model_fn, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - norm_type=clip_norm_type, - sharding_strategy=sharding_strategy, - ) + # TODO (awgu): 2.0 fails tests + # @parametrize("norm_type", [2.0, None]) + @parametrize("norm_type", [None]) + def test_nested_always_wrap_model( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + norm_type: Optional[float], + ): + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + AlwaysWrapNestedWrappedModule, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + norm_type=norm_type, + ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - @parametrize("clip_norm_type", [2.0, None]) - def test_transformer_parameterized( - self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy, clip_norm_type): - forward_prefetch = (forward_prefetch == "forward_prefetch") - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - self._test_identical_outputs( - TransformerWithSharedParams, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - norm_type=clip_norm_type, - sharding_strategy=sharding_strategy, - ) + # TODO (awgu): 2.0 fails tests + # @parametrize("norm_type", [2.0, None]) + @parametrize("norm_type", [None]) + def test_transformer( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + norm_type: Optional[float], + ): + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + TransformerWithSharedParams, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + norm_type=norm_type, + sharding_strategy=sharding_strategy, + ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - def test_delayed_optim_step(self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy): - forward_prefetch = (forward_prefetch == "forward_prefetch") - # We use a model with a long CUDA delay right before the optimizer step. - # This tests our streams logic, and that we don't start the allgather - # until after the optimization step completes. - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - model_fn = functools.partial( - NestedWrappedModuleWithDelay, delay_after_loss_ms=250 - ) - self._test_identical_outputs( - model_fn, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - ) + def test_delayed_optim_step( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + ): + """Tests the FSDP forward, backward, and optimizer step runtime by + using a model with a long CUDA delay after the loss computation/before + the optimizer step to exercise the internal CUDA stream usage in that + the forward pass all-gathers do not start until after the optimizer + step completes.""" + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + NestedWrappedModuleWithDelay, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + init_kwargs={"delay_after_loss_ms": 250}, + ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy): - forward_prefetch = (forward_prefetch == "forward_prefetch") - # We insert a delay in the torch.distributed._reduce_scatter_base op, so that - # the post_backward_stream takes much longer than the backward pass. - # This tests that we properly block at the end of the backward pass for - # the reductions to finish. - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - model_fn = functools.partial( - NestedWrappedModuleWithDelay, delay_before_reduction_ms=250 - ) - self._test_identical_outputs( - model_fn, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - ) + def test_delayed_reduce_scatter( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + ): + """Tests the FSDP forward, backward, and optimizer step runtime by + using a model with a long CUDA delay before the gradient reduce-scatter + to exercise the internal CUDA stream usage in that the backward pass + waits for those reductions to finish.""" + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + NestedWrappedModuleWithDelay, + FSDPInitMode.RECURSIVE, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + init_kwargs={"delay_before_reduction_ms": 250}, + ) def _dummy_ddp_fn(self, model): + # `MixtureOfExperts`` implements custom gradient reduction logic, so + # the reference behavior should follow that logic instead of DDP return DummyDDP(model) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - @parametrize("clip_norm_type", [2.0, None]) + # TODO (awgu): 2.0 fails tests + # @parametrize("norm_type", [2.0, None]) + @parametrize("norm_type", [None]) def test_mixture_of_experts( - self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy, clip_norm_type): - forward_prefetch = (forward_prefetch == "forward_prefetch") - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - self._test_identical_outputs( - MixtureOfExperts, - # MixtureOfExperts implements custom reduce logic, so the reference - # behavior should use that logic instead of PyTorch DDP. - ref_ddp_fn=self._dummy_ddp_fn, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - norm_type=clip_norm_type, - sharding_strategy=sharding_strategy, - ) + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + norm_type: Optional[float], + ): + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + MixtureOfExperts, + FSDPInitMode.RECURSIVE, + ref_init_fn=self._dummy_ddp_fn, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + norm_type=norm_type, + ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) def test_mixture_of_experts_with_delay_before_free( - self, cpu_offload, backward_prefetch, forward_prefetch, sharding_strategy): - forward_prefetch = (forward_prefetch == "forward_prefetch") - init_modes = self._get_init_modes_for_test(cpu_offload) - for fsdp_init_mode in init_modes: - with self.subTest(fsdp_init_mode=fsdp_init_mode): - model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250) - self._test_identical_outputs( - model_fn, - ref_ddp_fn=self._dummy_ddp_fn, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - ) + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + ): + self.run_subtests( + self._get_subtest_config(cpu_offload), + self._test_fsdp_parity, + MixtureOfExperts, + FSDPInitMode.RECURSIVE, + ref_init_fn=self._dummy_ddp_fn, + cpu_offload=cpu_offload, + sharding_strategy=sharding_strategy, + init_kwargs={"delay_before_free_ms": 250} + ) class TestParamInit(FSDPTest): @skip_if_lt_x_gpu(2) @parametrize("mixed_precision", [True, False]) def test_param_change_after_init(self, mixed_precision): - group = dist.distributed_c10d._get_default_group() - # Establish reference behavior. - mixed_precision = MixedPrecision() if mixed_precision else None - config = {"mixed_precision": mixed_precision} - model = self._get_wrapped_model( - group, mixed_precision=mixed_precision, cuda_first=False + """ + Tests that changing FSDP model parameter values in-place after FSDP + initialization persist. + """ + # Establish reference behavior + fsdp_kwargs = {} + if mixed_precision: + fsdp_kwargs["mixed_precision"] = MixedPrecision() + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, + deterministic=True, ) - model.eval() # no dropout for this test - input = model.module.get_input(torch.device("cuda")) - ref_output = model(*input) - - # Change the weights in place. - model = self._get_wrapped_model(group, cuda_first=False) - model.eval() # no dropout for this test - first_param = next(model.parameters()) + input = fsdp_model.module.get_input(torch.device("cuda")) + ref_output = fsdp_model(*input) + # Initialize the same model but change its first parameter value + # in-place after FSDP initialization + new_fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, + deterministic=True, + ) + first_param = next(new_fsdp_model.parameters()) nn.init.normal_(first_param.data) - new_output = model(*input) - + new_output = new_fsdp_model(*input) self.assertNotEqual( ref_output, new_output, @@ -290,24 +303,33 @@ def test_param_change_after_init(self, mixed_precision): class TestHooks(FSDPTest): - # They aspire to make sure that backward hooks are registered and used @skip_if_lt_x_gpu(2) @parametrize("cuda_first", [False, True]) - def test_output_backward_hooks(self, cuda_first): - group = dist.distributed_c10d._get_default_group() - model = self._get_wrapped_model(group, cuda_first=cuda_first) - self._test_output_backward_hooks(model=model) + def test_pre_backward_hook_registration(self, cuda_first: bool): + """Tests that FSDP pre-backward hooks are registered on forward pass + outputs.""" + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER, + ) + self._test_pre_backward_hook_registration(fsdp_model) @skip_if_lt_x_gpu(2) - def test_backward_hooks_after_save(self): - group = dist.distributed_c10d._get_default_group() - model = self._get_wrapped_model(group, cuda_first=False) - self._train_for_several_steps(model, num_steps=2, autocast=False) - state_1 = model.state_dict() - model.load_state_dict(state_1) - self._test_output_backward_hooks(model=model) - - def _test_output_backward_hooks(self, model): + def test_pre_backward_hook_registration_after_state_dict(self): + """Tests that FSDP pre-backward hooks are registered on forward pass + outputs after saving and loading the model from a checkpoint.""" + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + ) + self._train_for_several_steps(fsdp_model, num_steps=2, autocast=False) + state_dict = fsdp_model.state_dict() + fsdp_model.load_state_dict(state_dict) + self._test_pre_backward_hook_registration(fsdp_model) + + def _test_pre_backward_hook_registration(self, model): optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optim.zero_grad() # Inputs always cuda, as computation happes on CUDA device only @@ -325,54 +347,64 @@ def _test_output_backward_hooks(self, model): @skip_if_lt_x_gpu(2) @parametrize("cuda_first", [False, True]) @parametrize("mixed_precision", [True, False]) - def test_register_functions_called(self, cuda_first, mixed_precision): - """Tests that _register_{pre|post}_backward_hooks called during forward.""" - group = dist.distributed_c10d._get_default_group() - mixed_precision = MixedPrecision() if mixed_precision else None - config = {"mixed_precision": mixed_precision} - model = self._get_wrapped_model( - group, mixed_precision=mixed_precision, cuda_first=cuda_first + def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool): + """Tests that ``_register_{pre|post}_backward_hooks()`` are called + during the FSDP forward.""" + fsdp_kwargs = {} + if mixed_precision: + fsdp_kwargs["mixed_precision"] = MixedPrecision() + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE if cuda_first else CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, ) - input = model.module.get_input(torch.device("cuda")) - model._register_post_backward_hooks = mock.MagicMock(return_value=None) - model._register_pre_backward_hooks = mock.MagicMock(return_value=None) - self.assertFalse(model._register_post_backward_hooks.called) - self.assertFalse(model._register_pre_backward_hooks.called) - model(*input) - self.assertTrue(model._register_post_backward_hooks.called) - self.assertTrue(model._register_pre_backward_hooks.called) + input = fsdp_model.module.get_input(torch.device("cuda")) + fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None) + fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None) + self.assertFalse(fsdp_model._register_post_backward_hooks.called) + self.assertFalse(fsdp_model._register_pre_backward_hooks.called) + fsdp_model(*input) + self.assertTrue(fsdp_model._register_post_backward_hooks.called) + self.assertTrue(fsdp_model._register_pre_backward_hooks.called) class TestNoGrad(FSDPTest): @skip_if_lt_x_gpu(2) @parametrize("mixed_precision", [True, False]) def test_transformer_no_grad(self, mixed_precision): - group = dist.distributed_c10d._get_default_group() - mixed_precision = MixedPrecision( - param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16, - ) if mixed_precision else None - config = {"mixed_precision": mixed_precision} - model = self._get_wrapped_model(group, config=config, cuda_first=False) - # Train model for a step + """Tests that for an FSDP-wrapped transformer model with shared + parameters, after training for one iteration, running a forward pass in + ``eval()`` mode gives the same output as running a forward pass in + ``torch.no_grad()``.""" + fsdp_kwargs = {} + if mixed_precision: + fsdp_kwargs["mixed_precision"] = MixedPrecision( + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, + ) + else: + fsdp_kwargs["mixed_precision"] = None + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, + ) self._train_for_several_steps( - model, + fsdp_model, num_steps=1, autocast=False, - mixed_precision=config["mixed_precision"] + mixed_precision=fsdp_kwargs["mixed_precision"] ) - - model.eval() # no dropout for this test - - # Eval in standard mode (i.e., without no_grad) - input = model.module.get_input(torch.device("cuda")) - ref_output = model(*input) - - # Eval with no_grad and compare + input = fsdp_model.module.get_input(torch.device("cuda")) + # Run a forward in eval mode + fsdp_model.eval() + ref_output = fsdp_model(*input) + # Run a forward in `no_grad()` and compare with torch.no_grad(): - no_grad_output = model(*input) - + no_grad_output = fsdp_model(*input) self.assertEqual(ref_output, no_grad_output) diff --git a/test/distributed/fsdp/test_fsdp_fx.py b/test/distributed/fsdp/test_fsdp_fx.py new file mode 100644 index 0000000000000..7b0e0a3ddf2f2 --- /dev/null +++ b/test/distributed/fsdp/test_fsdp_fx.py @@ -0,0 +1,122 @@ +# Owner(s): ["oncall: distributed"] + +from typing import Any + +import torch +from torch.distributed.fsdp._symbolic_trace import _init_execution_info, _patch_tracer +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, +) + + +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight1 = torch.nn.Parameter(torch.randn(6, 6)) + self.weight2 = torch.nn.Parameter(torch.randn(6, 6)) + self.weight_unused = torch.nn.Parameter(torch.randn(2, 2)) + self.layer0 = torch.nn.Linear(6, 6) + self.layer1 = torch.nn.Linear(6, 6, bias=False) + self.layer2 = torch.nn.Sequential( + torch.nn.Linear(6, 3, bias=False), + torch.nn.ReLU(), + torch.nn.Linear(3, 6, bias=False), + ) + self.relu = torch.nn.ReLU() + + def forward(self, x: Any, run_all_layers: bool): + z = self.relu(self.layer0(x)) + z = self.relu(self.layer2(z)) + z = z @ self.weight1 + if run_all_layers: + z = self.relu(self.layer1(z)) + z = z @ self.weight2 + # used to test the case where a module is called more than once + z = self.relu(self.layer0(x)) + return z + + +class TestSymbolicTracing(FSDPTest): + def test_symbolic_tracing_outputs(self): + """ + test ``execution_info.module_forward_order`` and ``execution_info.module_to_execution_infos`` + after running ``tracer.trace()`` inside ``_patch_tracer``. + """ + model = Model() + tracer = torch.fx.Tracer() + execution_info = _init_execution_info(model) + original_call_module = tracer.call_module + original_create_proxy = tracer.create_proxy + with _patch_tracer( + tracer=tracer, root_module=model, execution_info=execution_info + ): + concrete_args = {"run_all_layers": True} + tracer.trace(model, concrete_args) + # the member functions of tracer should not be changed + self.assertEqual(original_call_module, tracer.call_module) + self.assertEqual(original_create_proxy, tracer.create_proxy) + # test tracer.module_forward_order + correct_module_forward_order = [ + model, + model.layer0, + model.relu, + model.layer2, + model.layer2[0], + model.layer2[1], + model.layer2[2], + model.relu, + model.layer1, + model.relu, + model.layer0, + model.relu, + ] + self.assertEqual( + execution_info.module_forward_order, correct_module_forward_order + ) + # test execution_info.module_to_execution_infos + self.assertEqual( + execution_info.module_to_execution_infos[model], + [ + (model.layer0, list(model.layer0.named_parameters())), + (model.layer2, list(model.layer2.named_parameters())), + (model, [("weight1", model.weight1)]), + (model.layer1, list(model.layer1.named_parameters())), + (model, [("weight2", model.weight2)]), + (model.layer0, list(model.layer0.named_parameters())), + ], + ) + self.assertEqual( + execution_info.module_to_execution_infos[model.layer0], + [(model.layer0, list(model.layer0.named_parameters()))], + ) + self.assertEqual( + execution_info.module_to_execution_infos[model.layer1], + [(model.layer1, list(model.layer1.named_parameters()))], + ) + self.assertEqual( + execution_info.module_to_execution_infos[model.layer2], + [ + (model.layer2[0], list(model.layer2[0].named_parameters())), + (model.layer2[2], list(model.layer2[2].named_parameters())), + ], + ) + self.assertEqual(execution_info.module_to_execution_infos[model.relu], []) + # test tracer.param_exec_order + correct_param_order = [ + model.layer0.weight, + model.layer0.bias, + model.layer2[0].weight, + model.layer2[2].weight, + model.weight1, + model.layer1.weight, + model.weight2, + ] + self.assertEqual(execution_info.param_exec_order, correct_param_order) + + +instantiate_parametrized_tests(TestSymbolicTracing) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index f2569266c3471..ae01b22ca66c9 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -12,7 +12,12 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, + FSDPTest, + TransformerWithSharedParams, +) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, @@ -118,15 +123,18 @@ def _test_grad_acc( torch.backends.cuda.matmul.allow_tf32 = False # Initialize the FSDP model and optimizer - group = dist.distributed_c10d._get_default_group() - fsdp_model: FSDP = self._get_wrapped_model( - group, cuda_first=False, add_bn=False, - config={ - "cpu_offload": cpu_offload, - "backward_prefetch": backward_prefetch, - }, - ) # disable BN since the test uses varying batch sizes - fsdp_model.eval() # disable dropout + fsdp_kwargs = { + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + } + fsdp_model: FSDP = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_AFTER, + fsdp_kwargs, + deterministic=True, + add_bn=False, # disable BN since the test uses varying batch sizes + ) device = torch.device("cuda") optim = torch.optim.SGD( fsdp_model.parameters(), lr=0.01, momentum=0.9, diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index ddc8f742b060a..826710d1979cb 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -3,10 +3,16 @@ import sys import torch +import torch.nn as nn from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, + FSDPTest, + TransformerWithSharedParams, +) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, @@ -92,13 +98,25 @@ def test_ignored_modules_transformer(self): transformer model with shared parameters.""" # Initialize an FSDP-wrapped transformer model that has FSDP ignore # the `nn.Transformer` module's parameters - group = dist.distributed_c10d._get_default_group() - wrapped_model = self._get_wrapped_model( - group, cuda_first=True, ignore_modules=True, + model: nn.Module = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) + wrapped_model = FSDP( + model, + self.process_group, + ignored_modules=[model.transformer], ) # Check that the wrapped model's flattened parameter does not include # the ignored transformer module's parameters - nonwrapped_model = self._get_nonwrapped_model(group) + nonwrapped_model: nn.Module = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) total_numel = sum(p.numel() for p in nonwrapped_model.parameters()) ignored_numel = sum( p.numel() for p in nonwrapped_model.transformer.parameters() diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 5d9d8fbdc78c8..ce8e090765e4c 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -1,33 +1,36 @@ # Owner(s): ["oncall: distributed"] +import functools import sys from contextlib import suppress -import functools import torch import torch.distributed as dist import torch.nn as nn -from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer +from torch.distributed.fsdp import FlatParameter from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.wrap import ( + always_wrap_policy, + transformer_auto_wrap_policy, ) +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, NestedWrappedModule, - FSDPInitMode, TransformerWithSharedParams, - _validate, + _assert_module_states, ) from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, - TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -49,88 +52,143 @@ def world_size(self): def process_group(self): return dist.distributed_c10d._get_default_group() + @skip_if_lt_x_gpu(2) + @parametrize("use_second_layer", [True, False]) + @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None]) + def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy): + # When use_second_layer=True, b is involved in forward computation but does + # not receive grad in backward. Otherwise, b is not involved in forward + # computation. + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Linear(10, 10) + self.b = nn.Linear(10, 10) + + def forward(self, x, y): + out1 = self.a(x) + if use_second_layer: + out2 = self.b(y) + return out1, out2 + else: + return out1 + + fsdp = FSDP( + MyModel().cuda(), + sharding_strategy=sharding_strategy, + auto_wrap_policy=always_wrap_policy + ) + x = torch.randn(10, 10, device='cuda') + y = torch.randn(10, 10, device='cuda') + for i in range(4): + if use_second_layer: + a, b = fsdp(x, y) + else: + a = fsdp(x, y) + loss = a.sum() + loss.backward() + + # self.a receives grad, self.b does not + a_grad = fsdp.module.a._fsdp_wrapped_module.flat_param.grad + b_grad = fsdp.module.b._fsdp_wrapped_module.flat_param.grad + self.assertIsNotNone(a_grad) + self.assertIsNone(b_grad) + @skip_if_lt_x_gpu(2) def test_device_id_auto_wrap(self): - """ - Test auto wrapping propagates the device id. - """ - model = TransformerWithSharedParams(group=self.process_group) - my_auto_wrap_policy = functools.partial( + """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all + nested FSDP instances.""" + auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer} + transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) - wrapped = FSDP( - model, - auto_wrap_policy=my_auto_wrap_policy, - device_id=torch.cuda.current_device() + fsdp_kwargs = { + "auto_wrap_policy": auto_wrap_policy, + "device_id": torch.cuda.current_device(), + } + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs, ) - # All FSDP instances should have device_id set - for m in FSDP.fsdp_modules(wrapped): - self.assertEqual(m.device_id, torch.device("cuda", torch.cuda.current_device())) + for fsdp_module in FSDP.fsdp_modules(fsdp_model): + self.assertEqual( + fsdp_module.device_id, + torch.device("cuda", torch.cuda.current_device()), + ) @skip_if_lt_x_gpu(2) @parametrize("use_index", [True, False]) def test_fsdp_device_id(self, use_index): """ - If CPU module is passed into FSDP with device_id - argument, it is moved to the GPU with that device_id. + Tests the FSDP ``device_id`` argument: + - Wrapping a CPU module should move the module to the GPU matching + ``device_id`` + - Wrapping a GPU module already on the GPU matching ``device_id`` + should not raise an error + - Wrapping a GPU module already on GPU and passing a GPU device + without specifying a device ID (i.e. ``torch.device("cuda")``) warns """ dev_id = ( torch.cuda.current_device() if use_index else torch.device("cuda", torch.cuda.current_device()) ) - def _check_device_matches(fsdp, dev_id): - devices = {p.device for p in fsdp.parameters()} + def _check_device_matches(module, device_id): + """Checks that the ``FlatParameter``s in ``module`` have device + matching ``device_id``.""" + devices = { + p.device for p in module.parameters() + if isinstance(p, FlatParameter) + } + assert len(devices) > 0 self.assertEqual(1, len(devices)) - found_dev = devices.pop() - if use_index and not isinstance(dev_id, torch.device): - dev_id = torch.device("cuda", dev_id) - self.assertEqual(found_dev, dev_id) + found_device = devices.pop() + if use_index and not isinstance(device_id, torch.device): + device = torch.device("cuda", device_id) + else: + device = device_id + self.assertEqual(found_device, device) - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=True, - wrap_everything=True, - fsdp_init_mode=FSDPInitMode.CUDA_NEVER, - device_id=dev_id + # Check that FSDP parameters are moved to `device_id` for a CPU module + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_NEVER, + fsdp_kwargs={"device_id": dev_id}, ) - fsdp = FSDP(mod, device_id=dev_id) - # Check FSDP parameters are moved. - _check_device_matches(fsdp, dev_id) - # device_id matching module device before FSDP construction - # should not throw errors. - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=True, - wrap_everything=True, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - device_id=dev_id + _check_device_matches(nested_wrapped_module, dev_id) + # Check that specifying `device_id` for a GPU module already on that + # device does not raise an error + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs={"device_id": dev_id}, ) - fsdp = FSDP(mod, device_id=dev_id) - _check_device_matches(fsdp, dev_id) - # Passing in torch.device("cuda") should work. + _check_device_matches(nested_wrapped_module, dev_id) + # Check that passing in `torch.device("cuda")` for a GPU module warns regex = "does not have explicit index" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) with context: - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=True, - wrap_everything=True, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - device_id=torch.device("cuda") + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs={"device_id": torch.device("cuda")} ) - fsdp = FSDP(mod, device_id=torch.device("cuda")) - _check_device_matches(fsdp, torch.device("cuda", torch.cuda.current_device())) + _check_device_matches( + nested_wrapped_module, + torch.device("cuda", torch.cuda.current_device()) + ) @skip_if_lt_x_gpu(2) def test_module_device_mismatches_device_id(self): - """ - FSDP raises errors when module is on a GPU that does - not match device_id. - """ + """Tests that specifying a ``device_id`` argument to FSDP for a GPU + module that does not match the GPU device ID raises an error.""" context = ( self.assertRaisesRegex( RuntimeError, @@ -138,24 +196,21 @@ def test_module_device_mismatches_device_id(self): ) if self.rank != 0 else suppress() ) with context: - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=True, - wrap_everything=True, - # Would move module to current cuda device before - # wrapping with FSDP - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - # Rank 1 is given device id 0, but model is on cuda:1, - # should throw errors. - device_id=0 + NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + # Move wrapped modules to CUDA before wrapping with FSDP + cuda_init_mode=CUDAInitMode.CUDA_BEFORE, + # Should raise error since rank 1 is given `device_id=0` when + # the model is on cuda:1 + fsdp_kwargs={"device_id": 0}, ) @skip_if_lt_x_gpu(2) def test_multi_device_not_supported(self): - """ - FSDP throws appropriate error when we wrap multi-device module. - """ - class MyModule(nn.Module): + """Tests that wrapping a multi-device module (i.e. with submodules on + both GPU and CPU) with FSDP raises an error.""" + class MultiDeviceModule(nn.Module): def __init__(self): super().__init__() self.a = nn.Linear(1, 1).cuda() @@ -164,7 +219,7 @@ def __init__(self): with self.assertRaisesRegex( RuntimeError, "FSDP only supports single device modules" ): - FSDP(MyModule()) + FSDP(MultiDeviceModule()) @skip_if_lt_x_gpu(2) def test_no_params(self): @@ -197,53 +252,54 @@ def test_no_params(self): @skip_if_lt_x_gpu(2) def test_fsdp_cpu_init_stays_on_cpu(self): - """ - Ensure that CPU model input stays on CPU - after FSDP init and we log a warning. - """ + """Tests that passing a CPU module to FSDP preserves that the wrapped + module is on CPU after FSDP initialization, albeit after loging a + warning, and that FSDP moves CPU input to GPU before the forward.""" torch.cuda.set_device(self.rank) regex = "Module is put on CPU" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) with context: - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=True, - wrap_everything=True, - fsdp_init_mode=FSDPInitMode.CUDA_NEVER, + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_NEVER, ) - fsdp = FSDP(mod) - devices = {p.device for p in fsdp.parameters()} + fsdp_model = FSDP(nested_wrapped_module, self.process_group) + devices = {p.device for p in fsdp_model.parameters()} self.assertEqual(1, len(devices)) self.assertEqual(torch.device("cpu"), devices.pop()) - fsdp = fsdp.cuda() + fsdp_model = fsdp_model.cuda() # Ensure fwd + backward can be performed after moving to CUDA. # CPU input also tests that input is correctly moved to appropriate # CUDA device. - inp = mod.get_input(device=torch.device("cpu")) - fsdp(inp[0]).sum().backward() + inp = fsdp_model.module.get_input(device=torch.device("cpu")) + fsdp_model(*inp).sum().backward() @skip_if_lt_x_gpu(2) - def test_cpu_init_with_sync_module_raises(self): - """ - CPU module with sync_module_states=True throws appropriate - error because it requires GPU comm. - """ - mod = NestedWrappedModule( - group=self.process_group, - wrap_fsdp=False, - wrap_everything=True, - fsdp_init_mode=FSDPInitMode.CUDA_NEVER, + def test_cpu_init_with_sync_module_states(self): + """Tests that passing ``sync_module_states=True`` raises an error for + a CPU module since the synchronization requires GPU communication, + while additionally passing ``device_id`` does not raise an error.""" + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_NEVER, ) with self.assertRaisesRegex( ValueError, "Module has CPU parameters, but sync_module_states=True is specified." ): - FSDP(mod, sync_module_states=True) + FSDP(nested_wrapped_module, self.process_group, sync_module_states=True) # Specifying device_id with sync_module_states=True works. - FSDP(mod, device_id=torch.cuda.current_device(), sync_module_states=True) + FSDP( + nested_wrapped_module, + self.process_group, + device_id=torch.cuda.current_device(), + sync_module_states=True, + ) @skip_if_lt_x_gpu(2) def test_fsdp_same_model_across_ranks(self): @@ -261,19 +317,19 @@ def __init__(self, rank): self.register_buffer("buffer", torch.ones(1) * rank) m = MyModel(self.rank).cuda() - _validate(m, process_group=self.process_group, assert_fn=self.assertNotEqual) + _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual) # Passing sync_module_states into FSDP makes model the same during init. fsdp = FSDP(m, sync_module_states=True) with fsdp.summon_full_params(fsdp): - _validate(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) + _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) # sync_module_states also works with CPU module with device_id passed in m = MyModel(self.rank) - _validate(m, process_group=self.process_group, assert_fn=self.assertNotEqual) + _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual) # Passing sync_module_states into FSDP makes model the same during init. fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True) with fsdp.summon_full_params(fsdp): - _validate(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) + _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) instantiate_parametrized_tests(TestFSDPMisc) diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index d841ed874b9ee..238a72e334c52 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -1,38 +1,38 @@ # Owner(s): ["oncall: distributed"] -import sys import contextlib +import sys from functools import partial from itertools import product +from typing import Any, Dict, List import torch import torch.cuda.nccl as nccl import torch.nn as nn import torch.nn.functional as F from torch import distributed as dist -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - CPUOffload, - MixedPrecision, - BackwardPrefetch, - ShardingStrategy, -) +from torch.distributed.fsdp import BackwardPrefetch, CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from torch.nn.modules.batchnorm import _BatchNorm -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.testing._internal.common_cuda import CUDA11OrLater from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, + TransformerWithSharedParams, subtest_name, ) from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, - TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) -from torch.testing._internal.common_cuda import CUDA11OrLater try: import torchvision @@ -87,40 +87,30 @@ # Buffer original dtype, which can differ from model params. _BUFFER_ORIG_DTYPE = torch.float64 -params = "mp_config,cpu_offload,backward_prefetch,forward_prefetch,full_precision_param_dtype,sharded_grad_scaler" +params = "mp_config,cpu_offload,full_precision_param_dtype,enable_sharded_grad_scaler" cpu_offload_config = [ CPUOffload(offload_params=True), CPUOffload(offload_params=False) ] -backward_prefetch_config = [ - BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST -] -forward_prefetch_config = ["forward_prefetch", "no_forward_prefetch"] full_precision_param_dtype_config = [torch.float32, torch.float64] -sharded_grad_scaler = ["enable_sharded_grad_scaler", None] +enable_sharded_grad_scaler = ["enable_sharded_grad_scaler", None] configs = list(product( mp_configs, cpu_offload_config, - backward_prefetch_config, - forward_prefetch_config, full_precision_param_dtype_config, - sharded_grad_scaler, + enable_sharded_grad_scaler, )) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", str(CPUOffload(offload_params=False)): "offload_false", - str(BackwardPrefetch.BACKWARD_PRE): "backward_prefetch_pre", - str(BackwardPrefetch.BACKWARD_POST): "backward_prefetch_post", - "forward_prefetch": "forward_prefetch", - "no_forward_prefetch": "no_forward_prefetch", str(default_mp): "mp_fp16", str(mp_only_reduce): "mp_only_reduce", str(mp_only_param_and_buf): "mp_only_param_and_buf", str(mp_no_mixed_precision): "mp_no_mp", str(torch.float32): "fp32", str(torch.float64): "fp64", - "enable_sharded_grad_scaler": "sharded_grad_scaler" + "enable_sharded_grad_scaler": "enable_sharded_grad_scaler" } if nccl_supports_bf16: @@ -291,8 +281,6 @@ def _reduce_scatter_base_validate_mp( ) ) - # for t in tensors: - # print(f"tensor type {t.dtype} expected {expected_dtype}") for t in tensors: self.assertEqual(expected_dtype, t.dtype) @@ -306,7 +294,7 @@ def _run_test_mixed_precision_e2e( forward_prefetch, full_precision_param_dtype, sharding_strategy, - sharded_grad_scaler, + enable_sharded_grad_scaler, ): torch.cuda.set_device(self.rank) fsdp_models = [ @@ -337,7 +325,7 @@ def _run_test_mixed_precision_e2e( self._reduce_scatter_base_validate_mp, orig_reduce_scatter, mp_config, ) with patch_reduce_scatter(test_reduce_scatter, full_precision_param_dtype): - scaler = ShardedGradScaler(enabled=sharded_grad_scaler) + scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler) optim = torch.optim.Adam(model.parameters()) for _ in range(3): @@ -443,11 +431,22 @@ def _run_test_mixed_precision_e2e( class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision): - @property def world_size(self): return 2 + def _get_subtest_config(self) -> Dict[str, List[Any]]: + """Returns a subtest configuration that subtests prefetching settings + together.""" + return { + "forward_prefetch": [False, True], + "backward_prefetch": [ + None, + BackwardPrefetch.BACKWARD_PRE, + BackwardPrefetch.BACKWARD_POST, + ] + } + @skip_if_lt_x_gpu(2) def test_mixed_precision_no_reshard_after_forward(self): # Note that we don't exercise all possible different configs so as to @@ -460,7 +459,7 @@ def test_mixed_precision_no_reshard_after_forward(self): forward_prefetch=False, full_precision_param_dtype=torch.float64, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, - sharded_grad_scaler=False, + enable_sharded_grad_scaler=False, ) @skip_if_lt_x_gpu(2) @@ -469,20 +468,17 @@ def test_mixed_precision_e2e_full_shard( self, mp_config, cpu_offload, - backward_prefetch, - forward_prefetch, full_precision_param_dtype, - sharded_grad_scaler, + enable_sharded_grad_scaler, ): - forward_prefetch = (forward_prefetch == "forward_prefetch") - self._run_test_mixed_precision_e2e( - mp_config, - cpu_offload, - backward_prefetch, - forward_prefetch, - full_precision_param_dtype, - ShardingStrategy.FULL_SHARD, - sharded_grad_scaler, + self.run_subtests( + self._get_subtest_config(), + self._run_test_mixed_precision_e2e, + mp_config=mp_config, + cpu_offload=cpu_offload, + full_precision_param_dtype=full_precision_param_dtype, + sharding_strategy=ShardingStrategy.FULL_SHARD, + enable_sharded_grad_scaler=enable_sharded_grad_scaler, ) def _test_mixed_precision_embedding_table(self, mp_config): @@ -494,19 +490,24 @@ def _test_mixed_precision_embedding_table(self, mp_config): self._reduce_scatter_base_validate_mp, orig_reduce_scatter, mp_config, ) with patch_reduce_scatter(test_reduce_scatter, param_dtype): - model = self._get_wrapped_model( - group=torch.distributed.distributed_c10d._get_default_group(), - config={"mixed_precision": mp_config} + # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the + # entire `TransformerWithSharedParams` with a single top-level FSDP + model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + {"mixed_precision": mp_config}, ) - optim = torch.optim.SGD(model.parameters(), lr=0.1) + fsdp_model = FSDP(model, mixed_precision=mp_config) + optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1) for _ in range(6): - inp = model.module.get_input(torch.device("cuda")) + inp = fsdp_model.module.get_input(torch.device("cuda")) # This would fail if we casted integer module inputs such as for # embedding tables. - output = model(*inp) - loss = model.module.get_loss(inp, output).cuda() + output = fsdp_model(*inp) + loss = fsdp_model.module.get_loss(inp, output).cuda() self.assertEqual(loss.dtype, param_dtype) - model.module.run_backward(loss) + fsdp_model.module.run_backward(loss) optim.step() @skip_if_lt_x_gpu(2) @@ -656,7 +657,7 @@ def test_mixed_precision_no_reshard_after_forward(self): forward_prefetch=False, full_precision_param_dtype=torch.float64, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, - sharded_grad_scaler=False, + enable_sharded_grad_scaler=False, ) @skip_if_lt_x_gpu(1) @@ -669,7 +670,7 @@ def test_mixed_precision_e2e_full_shard(self): forward_prefetch=False, full_precision_param_dtype=torch.float64, sharding_strategy=ShardingStrategy.FULL_SHARD, - sharded_grad_scaler=False, + enable_sharded_grad_scaler=False, ) instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded) diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index 84e59277697e6..1f2d5ad8ea8db 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -7,13 +7,21 @@ import torch from torch import distributed as dist +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, apply_activation_checkpointing_wrapper +) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( OptimStateKeyType, StateDictType, ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, + FSDPTest, + TransformerWithSharedParams, +) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, @@ -275,9 +283,12 @@ def _init_transformer_model( raise NotImplementedError() if group is None: group = dist.distributed_c10d._get_default_group() - model = self._get_wrapped_model(group=group, cuda_first=True) if wrap \ - else self._get_nonwrapped_model(group=group, cuda_first=True) - model.eval() # disable dropout for determinism + model = TransformerWithSharedParams.init( + group, + FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) optim = optim_class(model.parameters(), lr=0.01) return model, optim, None @@ -294,6 +305,7 @@ def _step_model( losses = [] module = model.module if hasattr(model, "module") else model for _ in range(num_iters): + optim.zero_grad() inp = module.get_input(device) output = model(*inp) loss = module.get_loss(inp, output).to(device) @@ -477,12 +489,21 @@ def test_full_optim_state_dict_keys(self): device = torch.device("cuda") model = NestedModel().to(device) wrapped_model = NestedModel.wrap(model, ignore_modules=True) + # Add checkpointing to ensure optim_state_dict and state_dict strip out + # checkpointing prefixes. + apply_activation_checkpointing_wrapper( + model, + check_fn=lambda module: isinstance(module, torch.nn.Sequential) + ) optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) self._step_model(model, optim, device) optim_state_dict = FSDP.full_optim_state_dict(wrapped_model, optim, rank0_only=False) with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT): state_dict = wrapped_model.state_dict() self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys()) + # Check that checkpointing prefix was indeed stripped. + for key in optim_state_dict["state"]: + self.assertNotIn(_CHECKPOINT_PREFIX, key) @skip_if_lt_x_gpu(2) def test_full_optim_state_dict_nested_invalid(self): diff --git a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py b/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py index ff3e7479d936c..a1c73d1cafb53 100644 --- a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py +++ b/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py @@ -1,11 +1,14 @@ # Owner(s): ["oncall: distributed"] +from typing import Any, Callable + import torch -from torch.testing._internal.common_fsdp import FSDPTest -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.distributed.fsdp.wrap import ParamExecOrderWrapPolicy, always_wrap_policy -from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._symbolic_trace import TracingConfig +from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy +from torch.distributed.fsdp.wrap import always_wrap_policy, ParamExecOrderWrapPolicy +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -25,25 +28,31 @@ def __init__(self) -> None: ) self.relu = torch.nn.ReLU() - def forward(self, x): + def forward(self, x: Any, use_all_params: bool = True): # `layer0` -> `layer2` -> `layer1` # the forward execution order is NOT consistent with the model definition order. z = self.relu(self.layer0(x)) z = self.relu(self.layer2(z)) - z = self.relu(self.layer1(z)) + if use_all_params: + z = self.relu(self.layer1(z)) return z def get_input(self, device: torch.device): - return (torch.randn((8, 6)).to(device), ) + return (torch.randn((8, 6)).to(device),) def get_loss(self, input, output): return (output - input[0]).sum() @staticmethod - def wrap(sharding_strategy: ShardingStrategy, device: torch.device, init_policy=always_wrap_policy): + def wrap( + sharding_strategy: ShardingStrategy, + device: torch.device, + wrap_policy: Callable, + ) -> torch.nn.Module: model = Model() - wrap_policy = ParamExecOrderWrapPolicy(init_policy=init_policy) - fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy) + fsdp_model = FSDP( + model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy + ) return fsdp_model.to(device) @@ -57,27 +66,63 @@ def device(self): "sharding_strategy", [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP], ) - @parametrize("iters", [1, 3]) - def test_fsdp_flatten_params_exec_order(self, sharding_strategy: ShardingStrategy, iters: int): - """Tests the basic APIs of FSDP with ParamExecOrderWrapPolicy""" - fsdp_model = Model.wrap(sharding_strategy, self.device) + def test_fsdp_flatten_params_exec_order( + self, + sharding_strategy: ShardingStrategy, + ): + """ + Test ``_fsdp_params_exec_order`` with ``ParamExecOrderWrapPolicy``, + after running one iteration of forward and backward pass. + Here ``torch.fx`` is not enabled inside ``ParamExecOrderWrapPolicy``. + """ + wrap_policy = ParamExecOrderWrapPolicy(init_policy=always_wrap_policy) + fsdp_model = Model.wrap(sharding_strategy, self.device, wrap_policy=wrap_policy) self.assertTrue(fsdp_model._is_param_exec_order_prep_stage()) - for _ in range(iters): - input = fsdp_model.module.get_input(self.device) - output = fsdp_model(*input) - loss = fsdp_model.module.get_loss(input, output).to(self.device) - loss.backward() + # run one iteration to record the execution ordering + input = fsdp_model.module.get_input(self.device) + output = fsdp_model(*input) + loss = fsdp_model.module.get_loss(input, output).to(self.device) + loss.backward() + params_list = list(fsdp_model.parameters()) + # Since the forward execution order is NOT consistent with + # the model definition order, the ordering in flatten_named_params_exec_order + # should be different from named_parameters. + self.assertEqual( + fsdp_model._fsdp_params_exec_order, + [params_list[0], params_list[2], params_list[3], params_list[1]], + ) + self.assertTrue(fsdp_model._use_param_exec_order_policy()) + self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage()) + + @skip_if_lt_x_gpu(2) + @parametrize( + "sharding_strategy", + [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP], + ) + def test_fsdp_flatten_params_exec_order_symbolic_trace( + self, + sharding_strategy: ShardingStrategy, + ): + """ + Tests ``ParamExecOrderWrapPolicy`` with symbolic tracing. + With symbolic tracing enabled, ``_is_param_exec_order_prep_stage`` + should always set as False. + """ + wrap_policy = ParamExecOrderWrapPolicy( + init_policy=always_wrap_policy, + tracing_config=TracingConfig(concrete_args={"use_all_params": False}), + ) + fsdp_model = Model.wrap( + sharding_strategy, + self.device, + wrap_policy=wrap_policy, + ) params_list = list(fsdp_model.parameters()) # Since the forward execution order is NOT consistent with the model definition order, # the ordering in flatten_named_params_exec_order should be different from named_parameters self.assertEqual( fsdp_model._fsdp_params_exec_order, - [ - params_list[0], - params_list[2], - params_list[3], - params_list[1] - ] + [params_list[0], params_list[2], params_list[3]], ) self.assertTrue(fsdp_model._use_param_exec_order_policy()) self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage()) diff --git a/test/distributed/fsdp/test_fsdp_pure_fp16.py b/test/distributed/fsdp/test_fsdp_pure_fp16.py index 82648ea457a8b..eea03bea3d8a0 100644 --- a/test/distributed/fsdp/test_fsdp_pure_fp16.py +++ b/test/distributed/fsdp/test_fsdp_pure_fp16.py @@ -2,25 +2,22 @@ import sys -import torch from torch import distributed as dist -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload -from torch.nn.parallel import DistributedDataParallel -from torch.optim import SGD +from torch.distributed.fsdp import CPUOffload from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, - get_full_params, - DeterministicModel, + NestedWrappedModule, ) from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, - TEST_WITH_DEV_DBG_ASAN, run_tests, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -33,51 +30,25 @@ sys.exit(0) -# Test pure fp16 training, also testing the case when the parameter's data type is -# changed after FSDP wrapping and before training loop starts. -# Only run one step for comparision, as usually grad scaler is needed to avoid NaN value -# after first step. class TestPureFP16(FSDPTest): - def _dist_train(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)): - # keep everything deterministic for input data - torch.manual_seed(0) - - model = DeterministicModel(wrap_fsdp, cpu_offload) - if wrap_fsdp: - model = FSDP(model, cpu_offload=cpu_offload) - else: - model = DistributedDataParallel(model, device_ids=[self.rank]) - model.half() - optim = SGD(model.parameters(), lr=0.1) - - in_data = torch.rand(16, 2).cuda().half() - in_data.requires_grad = True - for _ in range(1): - out = model(in_data) - out.sum().backward() - optim.step() - optim.zero_grad() - - if wrap_fsdp: - full_params = get_full_params(model) - torch.cuda.synchronize() - return full_params - - return list(model.parameters()) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) - def test_pure_fp16(self, cpu_offload): - # DDP - ddp_state = self._dist_train(wrap_fsdp=False) - - # FSDP - fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload) - - self.assertEqual(ddp_state, fsdp_state) + def test_pure_fp16(self, cpu_offload: CPUOffload): + """Tests pure FP16 training, including when the parameter's dtype is + changed after FSDP initialization and before training.""" + self._test_fsdp_parity( + NestedWrappedModule, + FSDPInitMode.RECURSIVE, + cuda_init_mode=CUDAInitMode.CUDA_AFTER, + # Run one iteration to avoid NaN without a gradient scaler + num_iters=1, + cpu_offload=cpu_offload, + use_pure_fp16=True, + ) instantiate_parametrized_tests(TestPureFP16) diff --git a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py index 44b8815a9a4b6..1c230cb7400c4 100644 --- a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py +++ b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py @@ -3,24 +3,32 @@ import functools import itertools import sys -import torch import unittest +from typing import Optional +import torch from torch import distributed as dist from torch.cuda.amp.common import amp_definitely_not_available +from torch.distributed.fsdp import CPUOffload, MixedPrecision from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy -from torch.distributed.fsdp import MixedPrecision, CPUOffload from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.testing._internal.common_fsdp import DummyProcessGroup, subtest_name, FSDPInitMode, NestedWrappedModule, FSDPTest from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + DummyProcessGroup, + FSDPInitMode, + FSDPTest, + NestedWrappedModule, + subtest_name, +) from torch.testing._internal.common_utils import ( - TestCase, run_tests, + TEST_WITH_DEV_DBG_ASAN, + TestCase, instantiate_parametrized_tests, parametrize, - TEST_WITH_DEV_DBG_ASAN, + run_tests, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -120,31 +128,37 @@ def test_inf_gradients_skip_optim_step(self): class TestShardedGradScalerParityWithDDP(FSDPTest): def _get_init_modes_for_test(self, cpu_offload): modes = [ - FSDPInitMode.CUDA_AFTER, - FSDPInitMode.CUDA_BEFORE + CUDAInitMode.CUDA_AFTER, + CUDAInitMode.CUDA_BEFORE ] - # Note that FSDPInitMode.CUDA_NEVER works currently only with CPU + # Note that CUDAInitMode.CUDA_NEVER works currently only with CPU # offload as we explicitly bring the param back to CUDA device. In # general, it will not work since we try to all_gather p.data which is # on CPU but NCCL only supports GPU. if cpu_offload.offload_params: - modes.append(FSDPInitMode.CUDA_NEVER) + modes.append(CUDAInitMode.CUDA_NEVER) return modes @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - def test_scaler_enabled(self, cpu_offload, sharding_strategy, mixed_precision): + def test_fsdp_ddp_parity_with_grad_scaler( + self, + cpu_offload: CPUOffload, + sharding_strategy: Optional[ShardingStrategy], + mixed_precision: Optional[str], + ): init_modes = self._get_init_modes_for_test(cpu_offload) mp = MixedPrecision( param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16, - ) if mixed_precision else None - for fsdp_init_mode in init_modes: - self._test_identical_outputs( + ) if mixed_precision is not None else None + for cuda_init_mode in init_modes: + self._test_fsdp_parity( NestedWrappedModule, - fsdp_init_mode=fsdp_init_mode, + FSDPInitMode.RECURSIVE, + cuda_init_mode=cuda_init_mode, cpu_offload=cpu_offload, sharding_strategy=sharding_strategy, mixed_precision=mp, diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 367f5cb0efef0..97c82fbe8ffb2 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import itertools import sys from contextlib import suppress from copy import deepcopy @@ -8,42 +9,53 @@ import torch import torch.nn as nn -from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer from torch import distributed as dist -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) +from torch.distributed.fsdp import CPUOffload, FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, LocalStateDictConfig, - CPUOffload, MixedPrecision, + StateDictType, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel, ) -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.shard_utils import _gather_state_dict -from torch.distributed.fsdp.wrap import enable_wrap, wrap, transformer_auto_wrap_policy -from torch.nn import Linear, Module +from torch.distributed.fsdp.wrap import ( + enable_wrap, + transformer_auto_wrap_policy, + wrap, +) +from torch.nn import ( + Linear, + Module, + TransformerDecoderLayer, + TransformerEncoderLayer, +) from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, - get_full_params, - _get_full_detached_param, - _get_state_dict, SkipModel, - _zero_model, TransformerWithSharedParams, - _validate, + _assert_module_states, + _get_state_dict, + _zero_model, + get_full_params, ) from torch.testing._internal.common_utils import ( + TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, - TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -156,7 +168,12 @@ def forward(self, x): self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs), ) - def _get_state_dict_mgr(self, model, state_dict_type, state_dict_rank0_and_offload): + def _get_state_dict_mgr( + self, + model: nn.Module, + state_dict_type: str, + state_dict_rank0_and_offload: bool, + ): _state_dict_type = STATE_DICT_MAPPING[state_dict_type] if state_dict_type == "state_dict": config = FullStateDictConfig( @@ -190,6 +207,9 @@ def _validate_state_dict_contents( @skip_if_lt_x_gpu(2) @parametrize("checkpoint_wrap", ["first", "second", "both"]) def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap): + """Tests saving the state dict, zeroing a target model's parameters, and + loading the state dict, where the source and target models may have a + checkpoint wrapper.""" for model_call in [ partial(self._get_simple_model), partial(self._get_simple_nested_model) @@ -207,51 +227,62 @@ def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap): @skip_if_lt_x_gpu(2) def test_state_dict_rank0_offload_save_load_flow(self): - # Test taking checkpoint on rank 0 only, and reload - # without redundant CPU memories. - model = TransformerWithSharedParams(group=dist.distributed_c10d._get_default_group()) - my_auto_wrap_policy = partial( + """Tests saving a model checkpoint only on rank 0 and loading it only + on rank 0 with ``sync_module_states=True`` to emulate the workflow to + avoid redundant CPU memory usage.""" + auto_wrap_policy = partial( transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer} + transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) - model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) - ctx = self._get_state_dict_mgr( - model, "state_dict", True + fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs, ) - with ctx: - state_dict = deepcopy(_get_state_dict(model)) - - # All ranks initialize non-FSDP model - grp = dist.distributed_c10d._get_default_group() - model_new = TransformerWithSharedParams(group=grp) - for p in model_new.parameters(): - with torch.no_grad(): - p.zero_() - # Only rank 0 loads the checkpoint + # Force model parameters and buffers to be nonzero + with FSDP.summon_full_params(fsdp_model): + for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()): + if torch.count_nonzero(tensor) == 0: + with torch.no_grad(): + tensor.add_(torch.tensor(1, dtype=tensor.dtype, device=tensor.device)) + with self._get_state_dict_mgr(fsdp_model, "state_dict", True): + state_dict = deepcopy(_get_state_dict(fsdp_model)) + # Initialize a non-wrapped model on all ranks + new_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + ) + _zero_model(new_model, zero_buffers=True) + # Only load the checkpoint on rank 0 if self.rank == 0: - model_new.load_state_dict(state_dict) - - # TransformerWithSharedParams has a buffer of zeros, so can't pass in - # self.assertNotEqual since the buffers would be equal. So just checking that - # there is some difference in the model across ranks before state_dict is - # broadcasted. - with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"): - _validate(model_new, process_group=grp, assert_fn=self.assertEqual) - # FSDP with sync_module_states=True broadcasts the checkpointed states. - model_new = FSDP( - model_new, + new_model.load_state_dict(state_dict) + _assert_module_states( + new_model, + process_group=self.process_group, + assert_fn=self.assertNotEqual, + ) + # Broadcast the module states from rank 0 with `sync_module_states=True` + new_fsdp_model = FSDP( + new_model, device_id=torch.cuda.current_device(), - auto_wrap_policy=my_auto_wrap_policy, - sync_module_states=True + auto_wrap_policy=auto_wrap_policy, + sync_module_states=True, ) - # After wrapping with FSDP models are equal across ranks, and have loaded the checkpoint - with FSDP.summon_full_params(model_new): - _validate(model_new, process_group=grp, assert_fn=self.assertEqual) - - with FullyShardedDataParallel.summon_full_params(model): - with FullyShardedDataParallel.summon_full_params(model_new): - params = list(model.parameters()) - params_new = list(model_new.parameters()) + # Check FSDP models are equal across ranks + with FSDP.summon_full_params(new_fsdp_model): + _assert_module_states( + new_fsdp_model, + process_group=self.process_group, + assert_fn=self.assertEqual, + ) + # Check FSDP models correctly loaded the checkpoint + with FullyShardedDataParallel.summon_full_params(fsdp_model): + with FullyShardedDataParallel.summon_full_params(new_fsdp_model): + params = list(fsdp_model.parameters()) + params_new = list(new_fsdp_model.parameters()) self.assertEqual(params, params_new) @skip_if_lt_x_gpu(2) @@ -350,7 +381,7 @@ def test_save_and_load_after_forward_state_dict( ) model = self._get_simple_nested_model(mixed_precision=mixed_precision) optim = torch.optim.SGD(model.parameters(), lr=0.1) - initial_params = _get_full_detached_param(model) + initial_params = get_full_params(model) for _ in range(6): inp = torch.randn(1, 10, device=torch.cuda.current_device()) output = model(*inp) @@ -360,7 +391,7 @@ def test_save_and_load_after_forward_state_dict( loss.backward() optim.step() - trained_params = _get_full_detached_param(model) + trained_params = get_full_params(model) # Ensure some training occured self.assertNotEqual(initial_params, trained_params) # Save a copy of the state_dict @@ -392,7 +423,7 @@ def test_save_and_load_after_forward_state_dict( with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]): model.load_state_dict(state_dict) - loaded_params = _get_full_detached_param(model) + loaded_params = get_full_params(model) self.assertEqual(loaded_params, trained_params) def _initialize_model( diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py index 8808a5d712b76..d6fe3013bff47 100644 --- a/test/distributed/fsdp/test_fsdp_summon_full_params.py +++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py @@ -7,26 +7,26 @@ import torch import torch.nn as nn from torch import distributed as dist -from torch.distributed.fsdp import CPUOffload, MixedPrecision -from torch.distributed.fsdp import FlatParameter +from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel -from torch.distributed.fsdp.wrap import wrap, enable_wrap +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.flat_param import FlatParamHandle +from torch.distributed.fsdp.wrap import enable_wrap, wrap from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + DeterministicModel, FSDPInitMode, FSDPTest, NestedWrappedModule, - DeterministicModel, ) from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, - run_tests, instantiate_parametrized_tests, parametrize, + run_tests, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -138,7 +138,7 @@ def test_summon_full_param_shard_value(self, mixed_precision): model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision) self.assertEqual(expected_shard_size, self.get_model_param_count(model)) - # we're assuming a single flatenned param + # we're assuming a single flattened param self.assertEqual(1, len(list(model.parameters()))) my_shard = torch.clone(next(model.parameters())) @@ -146,7 +146,7 @@ def test_summon_full_param_shard_value(self, mixed_precision): with model.summon_full_params(model): self.assertEqual(raw_model_size, self.get_model_param_count(model)) parameters = list(model.parameters()) - all_shards = FlatParameter(parameters, requires_grad=False) + all_shards = FlatParamHandle.flatten_params(parameters, requires_grad=False) my_slice = torch.chunk(all_shards, self.world_size)[self.rank] # shards are padded but the full_param tensor is not @@ -351,7 +351,7 @@ def __init__(self, fsdp_1, fsdp_2, fsdp_3): ) params_to_compare = list(model_no_fsdp.parameters()) - with FullyShardedDataParallel.summon_full_params(model_fsdp): + with FSDP.summon_full_params(model_fsdp): fsdp_params = [p.clone() for p in model_fsdp.parameters()] self.assertEqual(params_to_compare, fsdp_params) @@ -472,35 +472,36 @@ def _get_flat_param(): @parametrize("rank0_only", [True, False]) @parametrize("offload_to_cpu", [True, False]) @parametrize("mixed_precision", [True, False]) - def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision): + def test_params_count_and_value( + self, + rank0_only: bool, + offload_to_cpu: bool, + mixed_precision: bool, + ): mixed_precision = MixedPrecision() if mixed_precision else None - fsdp_model = FSDP( - NestedWrappedModule( - group=dist.distributed_c10d._get_default_group(), - wrap_fsdp=True, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - mixed_precision=mixed_precision, - ), - mixed_precision=mixed_precision, + model = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, ) - model = NestedWrappedModule( - group=dist.distributed_c10d._get_default_group(), - wrap_fsdp=False, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, + fsdp_model = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, ) - dev = ( torch.device("cpu") if offload_to_cpu else torch.device("cuda", torch.cuda.current_device()) ) - params_to_compare = ( [p.to(dev) for p in model.module.parameters()] if not rank0_only or self.rank == 0 else list(p.clone() for p in fsdp_model.parameters()) ) - with fsdp_model.summon_full_params( + with FSDP.summon_full_params( fsdp_model, rank0_only=rank0_only, writeback=not rank0_only ): for p1, p2 in itertools.zip_longest( @@ -516,17 +517,16 @@ def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precisio @skip_if_lt_x_gpu(2) def test_raises_rank0_with_writeback(self): - fsdp_model = FSDP( - NestedWrappedModule( - group=dist.distributed_c10d._get_default_group(), - wrap_fsdp=True, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - ) + """Tests that ``summon_full_params()`` with both ``rank0_only=True`` + and ``writeback=True`` raises an error.""" + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, ) - with self.assertRaisesRegex(ValueError, "is not supported"): - with fsdp_model.summon_full_params( - fsdp_model, rank0_only=True, writeback=True + with FSDP.summon_full_params( + nested_wrapped_module, rank0_only=True, writeback=True ): pass @@ -534,21 +534,29 @@ def test_raises_rank0_with_writeback(self): @parametrize("prefix", ["", "test_prefix"]) @parametrize("recurse", [False, True]) def test_named_parameters_buffers(self, prefix: str, recurse: bool): + """Tests that ``named_parameters()`` and ``named_buffers()`` for a + top-level FSDP-wrapped model matches their behavior for the equivalent + non-wrapped model.""" + model = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) + model.register_buffer("buffer", torch.ones(1)) + # `named_parameters()` and `named_buffers` will contain FSDP prefixes + # if called on a non-FSDP root module fsdp_model = FSDP( - NestedWrappedModule( - group=dist.distributed_c10d._get_default_group(), - wrap_fsdp=True, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - ) + NestedWrappedModule.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ), + self.process_group, ) fsdp_model.register_buffer("buffer", torch.ones(1)) - model = NestedWrappedModule( - group=dist.distributed_c10d._get_default_group(), - wrap_fsdp=False, - fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, - ) - model.register_buffer("buffer", torch.ones(1)) - with fsdp_model.summon_full_params(fsdp_model): + with FSDP.summon_full_params(fsdp_model): for call in ["named_parameters", "named_buffers"]: for (n1, p1), (n2, p2) in itertools.zip_longest( getattr(fsdp_model, call)(prefix=prefix, recurse=recurse), diff --git a/test/distributed/fsdp/test_fsdp_traversal.py b/test/distributed/fsdp/test_fsdp_traversal.py index 69ceca082441b..e1b0a77cfe791 100644 --- a/test/distributed/fsdp/test_fsdp_traversal.py +++ b/test/distributed/fsdp/test_fsdp_traversal.py @@ -6,6 +6,8 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, FSDPTest, NestedWrappedModule, ) @@ -14,7 +16,6 @@ run_tests, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -34,21 +35,24 @@ def world_size(self): @skip_if_lt_x_gpu(2) def test_fsdp_modules(self): - group = dist.distributed_c10d._get_default_group() - model = NestedWrappedModule(group, wrap_fsdp=True) - modules = FSDP.fsdp_modules(model) + nested_wrapped_module = NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + ) + modules = FSDP.fsdp_modules(nested_wrapped_module) self.assertEquals( modules, [ - model.module.get_submodule("1"), - model.module.get_submodule("1").get_submodule("0"), - model.module.get_submodule("2"), + nested_wrapped_module.module.get_submodule("1"), + nested_wrapped_module.module.get_submodule("1").get_submodule("0"), + nested_wrapped_module.module.get_submodule("2"), ] ) - modules = FSDP.fsdp_modules(model, root_only=True) + modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True) self.assertEqual( modules, [ - model.module.get_submodule("1"), - model.module.get_submodule("2"), + nested_wrapped_module.module.get_submodule("1"), + nested_wrapped_module.module.get_submodule("2"), ] ) diff --git a/test/distributed/fsdp/test_fsdp_uneven.py b/test/distributed/fsdp/test_fsdp_uneven.py index 93b89f547e1f9..295afbce508bc 100644 --- a/test/distributed/fsdp/test_fsdp_uneven.py +++ b/test/distributed/fsdp/test_fsdp_uneven.py @@ -62,8 +62,6 @@ def test_one_iteration(self): optim.zero_grad() with model.summon_full_params(model): - torch.cuda.synchronize() # TODO: This is here because it was - # originally part of get_full_params(), debug why it is needed here. weight_out = model.module.weight.T.clone() self.assertEqual(ref_forward_output_my_rank, out) self.assertEqual(ref_weight_out, weight_out) diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index 2326c7137c308..b1c8549dd1bf3 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -1,27 +1,24 @@ # Owner(s): ["oncall: distributed"] -from collections import OrderedDict import random import sys import unittest +from collections import OrderedDict import torch import torch.nn as nn from torch import distributed as dist -from torch.distributed.fsdp._utils import ( - _apply_to_tensors, -) +from torch.distributed.fsdp._utils import _apply_to_tensors from torch.distributed.utils import _replace_by_prefix from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, + TestCase, instantiate_parametrized_tests, parametrize, run_tests, subtest, - TestCase, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index a39f3858f73c1..0008f8d23a94a 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -1,46 +1,49 @@ # Owner(s): ["oncall: distributed"] -from enum import Enum, auto import functools import os import tempfile import unittest +from enum import Enum, auto + import torch import torch.nn as nn import torch.nn.functional as F from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel as FSDP, - CPUOffload, BackwardPrefetch, + CPUOffload, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, ) from torch.distributed.fsdp.wrap import ( - always_wrap_policy, - size_based_auto_wrap_policy, - enable_wrap, _or_policy, - wrap, _wrap_batchnorm_individually, + always_wrap_policy, + enable_wrap, + size_based_auto_wrap_policy, transformer_auto_wrap_policy, + wrap, ) -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, -) +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + CUDAInitMode, DummyProcessGroup, - FSDPTest, FSDPInitMode, - _maybe_cuda, + FSDPTest, TransformerWithSharedParams, + _maybe_cuda, ) from torch.testing._internal.common_utils import ( FILE_SCHEMA, - run_tests, - find_free_port, TestCase, - parametrize, + find_free_port, instantiate_parametrized_tests, + parametrize, + run_tests, ) -from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer + class BatchNormNet(nn.Module): def __init__(self): @@ -104,7 +107,7 @@ def _get_linear(self, fin, fout): return nn.Linear(fin, fout, bias=False) def _get_already_wrapped_fsdp( - self, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, nested=False + self, cuda_init_mode=CUDAInitMode.CUDA_BEFORE, nested=False ) -> FSDP: fn_self = self @@ -112,7 +115,7 @@ class MyModel(nn.Module): def __init__(self, nested): super().__init__() # TODO: test the various init modes. - move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE + move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE # if nested=True, the FSDP module will be nested one layer deep # and we should pick that up. if nested: @@ -135,14 +138,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: @skip_if_lt_x_gpu(2) @parametrize("nested", [True, False]) - @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]) - def test_error_already_wrapped(self, nested, fsdp_init_mode): + @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE]) + def test_error_already_wrapped(self, nested, cuda_init_mode): """ Test that an error is raised if we attempt to wrap when submodules are already FSDP. """ - wrapped_fsdp = self._get_already_wrapped_fsdp(nested=nested, fsdp_init_mode=fsdp_init_mode) - if fsdp_init_mode == FSDPInitMode.CUDA_AFTER: + wrapped_fsdp = self._get_already_wrapped_fsdp(nested=nested, cuda_init_mode=cuda_init_mode) + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: wrapped_fsdp = wrapped_fsdp.cuda() with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"): @@ -228,16 +231,16 @@ def wrap_bn_container(module, recurse, *args, **kwargs): ) @parametrize("forward_prefetch", [True, False]) @parametrize( - "fsdp_init_mode", - [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE] + "cuda_init_mode", + [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE] ) - def test_main_wrap_api(self, cpu_offload, backward_prefetch, forward_prefetch, fsdp_init_mode): + def test_main_wrap_api(self, cpu_offload, backward_prefetch, forward_prefetch, cuda_init_mode): - if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params: + if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params: # they don't work together, expected return - move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE + move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE class Nested(nn.Module): def __init__(self): @@ -269,7 +272,7 @@ def forward(self, input): backward_prefetch=backward_prefetch, forward_prefetch=forward_prefetch, ) - if fsdp_init_mode == FSDPInitMode.CUDA_AFTER: + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: wrapped_model = wrapped_model.cuda() modules_in_fsdp_graph_order = [ @@ -367,21 +370,26 @@ def test_always_wrap(self): @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") def test_transformer_auto_wrap_policy(self): - model = TransformerWithSharedParams(group=self.process_group) - my_auto_wrap_policy = functools.partial( + """Tests the ``transformer_auto_wrap_policy``.""" + auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer} + transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) - fsdp_model = FSDP( - model, - process_group=self.process_group, - auto_wrap_policy=my_auto_wrap_policy + fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs, ) - self.assertTrue(isinstance(fsdp_model, FSDP)) - for layer in fsdp_model.module.transformer.encoder.layers: - self.assertTrue(isinstance(layer, FSDP)) - for layer in fsdp_model.module.transformer.decoder.layers: - self.assertTrue(isinstance(layer, FSDP)) + modules = list(fsdp_model.modules()) + encoder_layers = set(fsdp_model.module.transformer.encoder.layers) + decoder_layers = set(fsdp_model.module.transformer.decoder.layers) + for module in modules: + if module is fsdp_model or module in encoder_layers or module in decoder_layers: + self.assertTrue(isinstance(module, FSDP)) + else: + self.assertFalse(isinstance(module, FSDP)) @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") def test_auto_wrap_api(self): @@ -475,16 +483,16 @@ def test_auto_wrap_preset_force_leaf_custom(self): self.assertTrue(isinstance(model.module[1], nn.ModuleList)) @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA") - @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_BEFORE, FSDPInitMode.CUDA_AFTER]) + @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER]) @parametrize( "cpu_offload", [CPUOffload(offload_params=False), CPUOffload(offload_params=True)] ) @parametrize("use_device_id", [True, False]) - def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload, use_device_id): + def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id): # CPU offload and CUDA after don't work together as expected. if ( - cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER + cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER ): return @@ -508,7 +516,7 @@ def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload, use_device_id): # NOTE: We move model to CUDA after init with FSDP to simulate real use # cases where full model cannot be loaded onto GPU, but their shards can. - cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER + cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER try: sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init)) my_auto_wrap_policy = functools.partial( diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index b68765013e76d..766c5ce811f7a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1825,7 +1825,6 @@ def first_bucket_size(ddp_bucket_mb): @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm def test_grad_layout_1devicemodule_1replicaperprocess(self): dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0])) # Tells DDP to use just one device. @@ -2183,7 +2182,7 @@ def div(fut): process_group, allreduce_with_then_hook ) - # check whether the grads are equal to what allreduce returns multuplied by 5. + # check whether the grads are equal to what allreduce returns multiplied by 5. # without the comm_hook, result would be still 0.25 * torch.ones(2, 2). self._run_and_verify_hook(gpu_model, 8, 1.25 * torch.ones(2, 2)) diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py new file mode 100644 index 0000000000000..9c9e0c4422d99 --- /dev/null +++ b/test/distributed/test_c10d_pypg.py @@ -0,0 +1,154 @@ +# Owner(s): ["oncall: distributed"] + +import os + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import ( + run_tests, +) +from torch.futures import Future +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +import test_c10d_common +import weakref +from torch._C._distributed_c10d import _create_work_from_future +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, +) + +def create_work(result): + future = Future() + future.set_result(result) + return _create_work_from_future(future) + +class MyWork(dist._Work): + def __init__(self, result, pg): + super().__init__() + self.result_ = result + self.future_ = torch.futures.Future() + self.future_.set_result(result) + self.pg_ = weakref.ref(pg) + + def wait(self, timeout): + self.pg_().wait_count += 1 + return True + + def get_future(self): + self.pg_().get_future_count += 1 + return self.future_ + +class LonelyRankProcessGroup(dist.ProcessGroup): + """ + This PG only supports world_size of 1 + """ + def __init__(self, rank, world, use_wrapper): + super(LonelyRankProcessGroup, self).__init__(rank, world) + assert rank == 0 + assert world == 1 + + self._rank = rank + self._world = world + self.wait_count = 0 + self.get_future_count = 0 + self.use_wrapper = use_wrapper + self._work = [] + + def broadcast(self, tensor_list, opts): + if self.use_wrapper: + return create_work(tensor_list) + res = MyWork(tensor_list, self) + self._work.append(res) + return res + + def allgather(self, output_tensors, input_tensor, opts): + for o, i in zip(output_tensors[0], input_tensor): + o.copy_(i) + if self.use_wrapper: + return create_work(output_tensors) + + res = MyWork(output_tensors, self) + self._work.append(res) + + return res + + def allreduce(self, tensors, opts): + if self.use_wrapper: + return create_work(tensors) + res = MyWork(tensors, self) + self._work.append(res) + return res + + def size(self): + return self._world + + def getBackendName(self): + return "lonely-pg" + + def __repr__(self): + return f"PLG w:{self._world} r:{self._rank}" + +# We cannot use parametrize as some tests are defined on the base class and use _get_process_group +class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest): + def setUp(self): + super(AbstractDDPSingleRank, self).setUp() + self._spawn_processes() + + @property + def world_size(self): + return 1 + + def tearDown(self): + super(AbstractDDPSingleRank, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + def _get_process_group(self): + return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper) + + def test_ddp_invoke_work_object(self): + pg = self._get_process_group() + + torch.manual_seed(123) + model = nn.Sequential( + nn.Linear(2, 2), + nn.ReLU() + ) + wrapped_model = model + input_tensor = torch.rand(2) + model = DDP(model, process_group=pg) + model(input_tensor).sum().backward() + + ddp_grad = wrapped_model[0].bias.grad.clone() + + wrapped_model.zero_grad() + wrapped_model(input_tensor).sum().backward() + self.assertEqual(wrapped_model[0].bias.grad, ddp_grad) + if not self.use_wrapper: + self.assertTrue(pg.wait_count > 0) + self.assertTrue(pg.get_future_count > 0) + + def test_ddp_with_pypg(self): + pg = self._get_process_group() + + self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None) + + def test_ddp_with_pypg_with_grad_views(self): + pg = self._get_process_group() + + self._test_ddp_with_process_group(pg, [torch.device("cpu")], device_ids=None, gradient_as_bucket_view=True) + +class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase): + @property + def use_wrapper(self): + return False + +class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase): + @property + def use_wrapper(self): + return True + +if __name__ == '__main__': + run_tests() diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index 26f773cb90009..c9bafe0dd8622 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -265,6 +265,7 @@ def test_collective_hang(self): def test_collectives_op_mismatch_debug_mode(self): pg = self._create_wrapper_pg(with_new_group=True) self._test_collectives_op_mismatch(pg, use_cuda=True) + self._test_nccl_only_op_mismatch(pg) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -272,6 +273,7 @@ def test_collectives_op_mismatch_debug_mode(self): def test_collectives_op_mismatch(self): pg = self._create_wrapper_pg(with_new_group=False) self._test_collectives_op_mismatch(pg, use_cuda=True) + self._test_nccl_only_op_mismatch(pg) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -279,6 +281,7 @@ def test_collectives_op_mismatch(self): def test_collective_shape_mismatch_debug_mode(self): pg = self._create_wrapper_pg(with_new_group=True) self._test_collective_shape_mismatch(pg, use_cuda=True) + self._test_nccl_only_shape_mismatch(pg) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -286,6 +289,48 @@ def test_collective_shape_mismatch_debug_mode(self): def test_collective_shape_mismatch(self): pg = self._create_wrapper_pg(with_new_group=False) self._test_collective_shape_mismatch(pg, use_cuda=True) + self._test_nccl_only_shape_mismatch(pg) + + def _test_nccl_only_op_mismatch(self, wrapper_pg): + device = f"cuda:{self.rank}" + with self.assertRaisesRegex(RuntimeError, ".*") as cm: + output = torch.zeros(4 + self.rank, device=device) + input = torch.ones(4 * self.world_size, device=device) + if self.rank == 0: + wrapper_pg._allgather_base(output, input).wait() + else: + wrapper_pg._reduce_scatter_base(output, input).wait() + self._validate_error( + exception=cm.exception, + op_type="ALLGATHER_BASE" if self.rank == 0 else "REDUCE_SCATTER_BASE", + rank=self.rank, + tensor=input, + ) + + def _test_nccl_only_shape_mismatch(self, wrapper_pg): + device = f"cuda:{self.rank}" + with self.assertRaisesRegex(RuntimeError, ".*") as cm: + output = torch.zeros(4 + self.rank, device=device) + input = torch.ones(4 * self.world_size, device=device) + + wrapper_pg._reduce_scatter_base(output, input).wait() + self._validate_error( + exception=cm.exception, + op_type="REDUCE_SCATTER_BASE", + rank=self.rank, + tensor=input, + ) + with self.assertRaisesRegex(RuntimeError, ".*") as cm: + output = torch.zeros(4, device=device) + input = torch.ones((4 + self.rank) * self.world_size, device=device) + + wrapper_pg._reduce_scatter_base(output, input).wait() + self._validate_error( + exception=cm.exception, + op_type="REDUCE_SCATTER_BASE", + rank=self.rank, + tensor=input, + ) @requires_gloo() diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index bcff510bfe0c2..943d12aa85ae3 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -331,6 +331,10 @@ def test_unknown_handler(self): with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"): dist.rendezvous("invalid://") + def test_url_with_node_params(self): + with self.assertRaisesRegex(AssertionError, "has node-specific arguments"): + dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16) + class RendezvousEnvTest(TestCase): @retry_on_connect_failures diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index c8b5551c89371..4db246ab93b33 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -793,7 +793,14 @@ def is_all_nan(tensor): ] -class TestDistributions(TestCase): +class DistributionsTestCase(TestCase): + def setUp(self): + """The tests assume that the validation flag is set.""" + torch.distributions.Distribution.set_default_validate_args(True) + super(DistributionsTestCase, self).setUp() + + +class TestDistributions(DistributionsTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True @@ -3240,7 +3247,7 @@ def test_mode(self): # These tests are only needed for a few distributions that implement custom # reparameterized gradients. Most .rsample() implementations simply rely on # the reparameterization trick and do not need to be tested for accuracy. -class TestRsample(TestCase): +class TestRsample(DistributionsTestCase): @unittest.skipIf(not TEST_NUMPY, "NumPy not found") def test_gamma(self): num_samples = 100 @@ -3448,7 +3455,7 @@ def compute_v(x, alpha): ])) -class TestDistributionShapes(TestCase): +class TestDistributionShapes(DistributionsTestCase): def setUp(self): super(TestDistributionShapes, self).setUp() self.scalar_sample = 1 @@ -3910,7 +3917,7 @@ def test_continuous_bernoulli_shape_tensor_params(self): self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))) -class TestKL(TestCase): +class TestKL(DistributionsTestCase): def setUp(self): super(TestKL, self).setUp() @@ -4304,7 +4311,7 @@ def test_entropy_exponential_family(self): ])) -class TestConstraints(TestCase): +class TestConstraints(DistributionsTestCase): def test_params_constraints(self): normalize_probs_dists = ( Categorical, @@ -4356,7 +4363,7 @@ def test_support_constraints(self): self.assertTrue(ok.all(), msg=message) -class TestNumericalStability(TestCase): +class TestNumericalStability(DistributionsTestCase): def _test_pdf_score(self, dist_class, x, @@ -4573,7 +4580,7 @@ def test_continuous_bernoulli_with_logits_overflow(self): # TODO: make this a pytest parameterized test -class TestLazyLogitsInitialization(TestCase): +class TestLazyLogitsInitialization(DistributionsTestCase): def setUp(self): super(TestLazyLogitsInitialization, self).setUp() # ContinuousBernoulli is not tested because log_prob is not computed simply @@ -4620,7 +4627,7 @@ def test_lazy_probs_initialization(self): @unittest.skipIf(not TEST_NUMPY, "NumPy not found") -class TestAgainstScipy(TestCase): +class TestAgainstScipy(DistributionsTestCase): def setUp(self): super(TestAgainstScipy, self).setUp() positive_var = torch.randn(20).exp() @@ -4794,7 +4801,7 @@ def test_icdf(self): self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist) -class TestFunctors(TestCase): +class TestFunctors(DistributionsTestCase): def test_cat_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100) x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100 @@ -4912,9 +4919,9 @@ def test_stack_transform(self): self.assertEqual(actual_jac, expected_jac) -class TestValidation(TestCase): +class TestValidation(DistributionsTestCase): def setUp(self): - super(TestCase, self).setUp() + super(TestValidation, self).setUp() def test_valid(self): for Dist, params in EXAMPLES: @@ -5007,7 +5014,7 @@ def tearDown(self): super(TestValidation, self).tearDown() -class TestJit(TestCase): +class TestJit(DistributionsTestCase): def _examples(self): for Dist, params in EXAMPLES: for param in params: diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py index da645e0e50366..ea99562b1f0c6 100644 --- a/test/distributions/test_transforms.py +++ b/test/distributions/test_transforms.py @@ -1,5 +1,6 @@ # Owner(s): ["module: distributions"] +import io from numbers import Number import pytest @@ -472,5 +473,18 @@ def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim, assert log_prob.shape == d.batch_shape +def test_save_load_transform(): + # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check + # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after. + dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)]) + x = torch.linspace(0, 1, 10) + log_prob = dist.log_prob(x) + stream = io.BytesIO() + torch.save(dist, stream) + stream.seek(0) + other = torch.load(stream) + assert torch.allclose(log_prob, other.log_prob(x)) + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect index f01221172b700..eb6bd8aca3dcf 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect @@ -2,7 +2,7 @@ torch.fx._symbolic_trace.ProxyableClassMeta [] torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace'] torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen'] torch.fx.graph.PythonCode [] -torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder'] +torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'nested_str', 'recompile', 'to_folder'] torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update'] torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove'] torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node'] diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index bd8c0e63a52cb..a50db90a50567 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -59,6 +59,7 @@ torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> No torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument +torch.fx.passes.reinplace.reinplace(gm, *sample_args) torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None) torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect index 0ac8d53e73b60..696fcbb08cf12 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect @@ -1,907 +1,6979 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, + [[ 4., 14.]]]), size=(3, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]) + [[ 4., 14.]]]) -########## torch.float32/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, layout=torch.sparse_bsc) # _ccol_indices tensor([0], dtype=torch.int32) # _row_indices tensor([], dtype=torch.int32) # _values -tensor([], size=(1, 0, 0)) +tensor([], size=(0, 1, 2)) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), size=(2, 6, 2), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]) + [[8.], + [9.]]]]) -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]) + [[24., 34.], + [25., 35.], + [26., 36.]]]]]) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, dtype=torch.float64, + [[ 4., 14.]]]), size=(3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], dtype=torch.float64) + [[ 4., 14.]]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0], dtype=torch.int32) # _row_indices tensor([], dtype=torch.int32) # _values -tensor([], size=(1, 0, 0), dtype=torch.float64) +tensor([], size=(0, 1, 2), dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) + [[8.], + [9.]]]]), size=(2, 6, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], dtype=torch.float64) + [[8.], + [9.]]]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]], dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.]]]]], dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, + [[ 4., 14.]]]), size=(3, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4]) # _row_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]) + [[ 4., 14.]]]) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, layout=torch.sparse_bsc) # _ccol_indices tensor([0]) # _row_indices tensor([], dtype=torch.int64) # _values -tensor([], size=(1, 0, 0)) +tensor([], size=(0, 1, 2)) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), size=(2, 6, 2), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]) + [[8.], + [9.]]]]) -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]) + [[24., 34.], + [25., 35.], + [26., 36.]]]]]) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, dtype=torch.float64, + [[ 4., 14.]]]), size=(3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4]) # _row_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], dtype=torch.float64) + [[ 4., 14.]]], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0]) # _row_indices tensor([], dtype=torch.int64) # _values -tensor([], size=(1, 0, 0), dtype=torch.float64) +tensor([], size=(0, 1, 2), dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) + [[8.], + [9.]]]]), size=(2, 6, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], dtype=torch.float64) + [[8.], + [9.]]]], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], + + + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], + + + [[[ 9., 19.], + [10., 20.], + [11., 21.]], + + [[10., 20.], + [11., 21.], + [12., 22.]], + + [[11., 21.], + [12., 22.], + [13., 23.]], + + [[12., 22.], + [13., 23.], + [14., 24.]]]], + + + + [[[[13., 23.], + [14., 24.], + [15., 25.]], + + [[14., 24.], + [15., 25.], + [16., 26.]], + + [[15., 25.], + [16., 26.], + [17., 27.]], + + [[16., 26.], + [17., 27.], + [18., 28.]]], + + + [[[17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.]], + + [[20., 30.], + [21., 31.], + [22., 32.]]], + + + [[[21., 31.], + [22., 32.], + [23., 33.]], + + [[22., 32.], + [23., 33.], + [24., 34.]], + + [[23., 33.], + [24., 34.], + [25., 35.]], + + [[24., 34.], + [25., 35.], + [26., 36.]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]) + +########## torch.float32/torch.int32/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), + nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]) + +########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]) + + +########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]) + +########## torch.float32/torch.int64/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), + nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]) + +########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]) + + +########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], - [[ 2., 22.]], + [[ 32.], + [132.]], - [[ 3., 33.]], + [[ 42.], + [142.]]]], - [[ 4., 44.]]], - [[[ 1., 11.]], + [[[[ 22.], + [122.]], - [[ 2., 22.]], + [[ 32.], + [132.]], - [[ 3., 33.]], + [[ 42.], + [142.]]], - [[ 4., 44.]]], + [[[ 23.], + [123.]], - [[[ 1., 11.]], + [[ 33.], + [133.]], - [[ 2., 22.]], + [[ 43.], + [143.]]]], - [[ 3., 33.]], - [[ 4., 44.]]]], + [[[[ 23.], + [123.]], + [[ 33.], + [133.]], - [[[[ 1., 11.]], + [[ 43.], + [143.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 24.], + [124.]], - [[ 4., 44.]]], + [[ 34.], + [134.]], + [[ 44.], + [144.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 24.], + [124.]], - [[ 4., 44.]]], + [[ 34.], + [134.]], + [[ 44.], + [144.]]], - [[[ 1., 11.]], - [[ 2., 22.]], + [[[ 25.], + [125.]], - [[ 3., 33.]], + [[ 35.], + [135.]], - [[ 4., 44.]]]]], dtype=torch.float64) + [[ 45.], + [145.]]]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect index 2de9b362e31e1..267056b76e678 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect @@ -1,907 +1,6945 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, - layout=torch.sparse_bsr) + [[4.], + [5.]]]), size=(4, 3), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]) + [[4.], + [5.]]]) -########## torch.float32/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, layout=torch.sparse_bsr) # _crow_indices tensor([0], dtype=torch.int32) # _col_indices tensor([], dtype=torch.int32) # _values -tensor([], size=(1, 0, 0)) +tensor([], size=(0, 2, 1)) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]) + [[ 8., 18.]]]]) -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34., 44.], + [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]) + [[24., 34., 44.], + [25., 35., 45.]]]]]) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, dtype=torch.float64, + [[4.], + [5.]]]), size=(4, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], dtype=torch.float64) + [[4.], + [5.]]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0], dtype=torch.int32) # _col_indices tensor([], dtype=torch.int32) # _values -tensor([], size=(1, 0, 0), dtype=torch.float64) +tensor([], size=(0, 2, 1), dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], dtype=torch.float64) + [[ 8., 18.]]]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34., 44.], + [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]], dtype=torch.float64) + [[24., 34., 44.], + [25., 35., 45.]]]]], dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, - layout=torch.sparse_bsr) + [[4.], + [5.]]]), size=(4, 3), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4]) # _col_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]) + [[4.], + [5.]]]) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, layout=torch.sparse_bsr) # _crow_indices tensor([0]) # _col_indices tensor([], dtype=torch.int64) # _values -tensor([], size=(1, 0, 0)) +tensor([], size=(0, 2, 1)) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]) + [[ 8., 18.]]]]) -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34., 44.], + [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]) + [[24., 34., 44.], + [25., 35., 45.]]]]]) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), size=(2, 4), nnz=4, dtype=torch.float64, + [[4.], + [5.]]]), size=(4, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4]) # _col_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], dtype=torch.float64) + [[4.], + [5.]]], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0]) # _col_indices tensor([], dtype=torch.int64) # _values -tensor([], size=(1, 0, 0), dtype=torch.float64) +tensor([], size=(0, 2, 1), dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], dtype=torch.float64) + [[ 8., 18.]]]], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), size=(2, 3, 2, 4), nnz=4, + [[24., 34., 44.], + [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], + + [[ 2., 12., 22.], + [ 3., 13., 23.]], + + [[ 3., 13., 23.], + [ 4., 14., 24.]], + + [[ 4., 14., 24.], + [ 5., 15., 25.]]], + + + [[[ 5., 15., 25.], + [ 6., 16., 26.]], + + [[ 6., 16., 26.], + [ 7., 17., 27.]], + + [[ 7., 17., 27.], + [ 8., 18., 28.]], + + [[ 8., 18., 28.], + [ 9., 19., 29.]]], + + + [[[ 9., 19., 29.], + [10., 20., 30.]], + + [[10., 20., 30.], + [11., 21., 31.]], + + [[11., 21., 31.], + [12., 22., 32.]], + + [[12., 22., 32.], + [13., 23., 33.]]]], + + + + [[[[13., 23., 33.], + [14., 24., 34.]], + + [[14., 24., 34.], + [15., 25., 35.]], + + [[15., 25., 35.], + [16., 26., 36.]], + + [[16., 26., 36.], + [17., 27., 37.]]], + + + [[[17., 27., 37.], + [18., 28., 38.]], + + [[18., 28., 38.], + [19., 29., 39.]], + + [[19., 29., 39.], + [20., 30., 40.]], + + [[20., 30., 40.], + [21., 31., 41.]]], + + + [[[21., 31., 41.], + [22., 32., 42.]], + + [[22., 32., 42.], + [23., 33., 43.]], + + [[23., 33., 43.], + [24., 34., 44.]], + + [[24., 34., 44.], + [25., 35., 45.]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, + layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]) + +########## torch.float32/torch.int32/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), + nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]) + +########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]) + + +########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, + layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]) + +########## torch.float32/torch.int64/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), + nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]) + +########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]) + + +########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 22.], + [122.]], - [[ 4., 44.]]], + [[ 32.], + [132.]]], - [[[ 1., 11.]], + [[[ 23.], + [123.]], - [[ 2., 22.]], + [[ 33.], + [133.]]], - [[ 3., 33.]], - [[ 4., 44.]]], + [[[ 24.], + [124.]], + [[ 34.], + [134.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 23.], + [123.]], - [[ 4., 44.]]]], + [[ 33.], + [133.]]], + [[[ 24.], + [124.]], - [[[[ 1., 11.]], + [[ 34.], + [134.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 25.], + [125.]], - [[ 4., 44.]]], + [[ 35.], + [135.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], + [[[[ 24.], + [124.]], - [[ 3., 33.]], + [[ 34.], + [134.]]], - [[ 4., 44.]]], + [[[ 25.], + [125.]], - [[[ 1., 11.]], + [[ 35.], + [135.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 26.], + [126.]], - [[ 4., 44.]]]]], dtype=torch.float64) + [[ 36.], + [136.]]]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect index a449883a3fe20..15e9bb56a85c7 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect @@ -1,17 +1,17 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([1., 2., 3., 4.]) -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -24,89 +24,89 @@ tensor([], dtype=torch.int32) # _values tensor([]) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]) + [5., 6., 7., 8.]]) -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([1., 2., 3., 4.], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -119,89 +119,89 @@ tensor([], dtype=torch.int32) # _values tensor([], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], dtype=torch.float64) + [5., 6., 7., 8.]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], dtype=torch.float64) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4]) # _row_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([1., 2., 3., 4.]) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -214,89 +214,89 @@ tensor([], dtype=torch.int64) # _values tensor([]) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]) + [5., 6., 7., 8.]]) -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4]) # _row_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([1., 2., 3., 4.], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -309,71 +309,1103 @@ tensor([], dtype=torch.int64) # _values tensor([], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], dtype=torch.float64) + [5., 6., 7., 8.]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(3, 2, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]) + +########## torch.float32/torch.int32/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=() ########## + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]) + +########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], - [[0, 2, 4], + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]) + + +########## torch.float64/torch.int32/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(3, 2, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(3, 2, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]) + +########## torch.float32/torch.int64/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]) + +########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], - [[0, 2, 4], + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]) + + +########## torch.float64/torch.int64/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(3, 2, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4]) +# _row_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], dtype=torch.float64) + [[24.], + [25.]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect index 02476652e4b7f..3ab2e1135aa55 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect @@ -1,17 +1,17 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([1., 2., 3., 4.]) -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -24,89 +24,89 @@ tensor([], dtype=torch.int32) # _values tensor([]) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]) + [5., 6., 7., 8.]]) -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], dtype=torch.int32) +tensor([0, 1, 0, 2], dtype=torch.int32) # _values tensor([1., 2., 3., 4.], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -119,89 +119,89 @@ tensor([], dtype=torch.int32) # _values tensor([], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], dtype=torch.int32) + [0, 3, 4]], dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], dtype=torch.int32) + [0, 1, 2, 0]], dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], dtype=torch.float64) + [5., 6., 7., 8.]], dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], dtype=torch.int32) + [0, 3, 4]]], dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], dtype=torch.float64) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4]) # _col_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([1., 2., 3., 4.]) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -214,89 +214,89 @@ tensor([], dtype=torch.int64) # _values tensor([]) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]) + [5., 6., 7., 8.]]) -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4]) # _col_indices -tensor([0, 1, 0, 1]) +tensor([0, 1, 0, 2]) # _values tensor([1., 2., 3., 4.], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -309,71 +309,1103 @@ tensor([], dtype=torch.int64) # _values tensor([], dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), size=(2, 2, 2), nnz=4, + [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]]) + [0, 3, 4]]) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]) + [0, 1, 2, 0]]) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], dtype=torch.float64) + [5., 6., 7., 8.]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(2, 3, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]) + +########## torch.float32/torch.int32/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=() ########## + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]) + +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], - [[0, 2, 4], + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]) + + +########## torch.float64/torch.int32/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), size=(2, 3, 2, 2), nnz=4, + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(2, 3, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]) + +########## torch.float32/torch.int64/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]) + +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], - [[0, 2, 4], + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]) + + +########## torch.float64/torch.int64/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4]) +# _col_indices +tensor([0, 1, 0, 2]) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]) + [0, 3, 4]]]) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], dtype=torch.float64) + [[24.], + [25.]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect index d3a51cb1c939d..46bdb44b2a983 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect @@ -1,907 +1,6981 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], device='cuda:0') + [[ 4., 14.]]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, layout=torch.sparse_bsc) # _ccol_indices tensor([0], device='cuda:0', dtype=torch.int32) # _row_indices tensor([], device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', size=(1, 0, 0)) +tensor([], device='cuda:0', size=(0, 1, 2)) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], device='cuda:0') + [[8.], + [9.]]]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]], device='cuda:0') + [[24., 34.], + [25., 35.], + [26., 36.]]]]], device='cuda:0') -########## torch.float64/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], device='cuda:0', dtype=torch.float64) + [[ 4., 14.]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0], device='cuda:0', dtype=torch.int32) # _row_indices tensor([], device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', size=(1, 0, 0), dtype=torch.float64) +tensor([], device='cuda:0', size=(0, 1, 2), dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], device='cuda:0', dtype=torch.float64) + [[8.], + [9.]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]], device='cuda:0', dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.]]]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], device='cuda:0') # _row_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], device='cuda:0') + [[ 4., 14.]]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, layout=torch.sparse_bsc) # _ccol_indices tensor([0], device='cuda:0') # _row_indices tensor([], device='cuda:0', dtype=torch.int64) # _values -tensor([], device='cuda:0', size=(1, 0, 0)) +tensor([], device='cuda:0', size=(0, 1, 2)) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], device='cuda:0') + [[8.], + [9.]]]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), nnz=4, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]], device='cuda:0') + [[24., 34.], + [25., 35.], + [26., 36.]]]]], device='cuda:0') -########## torch.float64/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=()+(3, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), + row_indices=tensor([0, 1, 0, 2]), values=tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0, 2, 4], device='cuda:0') # _row_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], device='cuda:0', dtype=torch.float64) + [[ 4., 14.]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([0], device='cuda:0') # _row_indices tensor([], device='cuda:0', dtype=torch.int64) # _values -tensor([], device='cuda:0', size=(1, 0, 0), dtype=torch.float64) +tensor([], device='cuda:0', size=(0, 1, 2), dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2,)+(6, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), - values=tensor([[[[ 1., 11.]], + [0, 1, 2, 0]]), + values=tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[8.], + [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values -tensor([[[[ 1., 11.]], +tensor([[[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], + [[4.], + [5.]]], - [[[ 1., 11.]], + [[[5.], + [6.]], - [[ 2., 22.]], + [[6.], + [7.]], - [[ 3., 33.]], + [[7.], + [8.]], - [[ 4., 44.]]]], device='cuda:0', dtype=torch.float64) + [[8.], + [9.]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(9, 4)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], - [[ 2., 22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 3., 33.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 4., 44.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 1., 11.]], + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 2., 22.]], + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 3., 33.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 4., 44.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], - [[[ 1., 11.]], + [[[ 9., 19.], + [10., 20.], + [11., 21.]], - [[ 2., 22.]], + [[10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 33.]], + [[11., 21.], + [12., 22.], + [13., 23.]], - [[ 4., 44.]]]], + [[12., 22.], + [13., 23.], + [14., 24.]]]], - [[[[ 1., 11.]], + [[[[13., 23.], + [14., 24.], + [15., 25.]], - [[ 2., 22.]], + [[14., 24.], + [15., 25.], + [16., 26.]], - [[ 3., 33.]], + [[15., 25.], + [16., 26.], + [17., 27.]], - [[ 4., 44.]]], + [[16., 26.], + [17., 27.], + [18., 28.]]], - [[[ 1., 11.]], + [[[17., 27.], + [18., 28.], + [19., 29.]], - [[ 2., 22.]], + [[18., 28.], + [19., 29.], + [20., 30.]], - [[ 3., 33.]], + [[19., 29.], + [20., 30.], + [21., 31.]], - [[ 4., 44.]]], + [[20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 1., 11.]], + [[[21., 31.], + [22., 32.], + [23., 33.]], - [[ 2., 22.]], + [[22., 32.], + [23., 33.], + [24., 34.]], - [[ 3., 33.]], + [[23., 33.], + [24., 34.], + [25., 35.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), + [[24., 34.], + [25., 35.], + [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[ 1., 11.], + [ 2., 12.], + [ 3., 13.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], + + + [[[ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.]]], + + + [[[ 9., 19.], + [10., 20.], + [11., 21.]], + + [[10., 20.], + [11., 21.], + [12., 22.]], + + [[11., 21.], + [12., 22.], + [13., 23.]], + + [[12., 22.], + [13., 23.], + [14., 24.]]]], + + + + [[[[13., 23.], + [14., 24.], + [15., 25.]], + + [[14., 24.], + [15., 25.], + [16., 26.]], + + [[15., 25.], + [16., 26.], + [17., 27.]], + + [[16., 26.], + [17., 27.], + [18., 28.]]], + + + [[[17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.]], + + [[20., 30.], + [21., 31.], + [22., 32.]]], + + + [[[21., 31.], + [22., 32.], + [23., 33.]], + + [[22., 32.], + [23., 33.], + [24., 34.]], + + [[23., 33.], + [24., 34.], + [25., 35.]], + + [[24., 34.], + [25., 35.], + [26., 36.]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', + size=(9, 4, 4, 2), nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]], device='cuda:0') + + +########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', + size=(9, 4, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', + size=(9, 4, 4, 2), nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]], device='cuda:0') + + +########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[ 1., 101.], + [ 11., 111.], + [ 21., 121.]], + + [[ 2., 102.], + [ 12., 112.], + [ 22., 122.]]], + + + [[[ 2., 102.], + [ 12., 112.], + [ 22., 122.]], + + [[ 3., 103.], + [ 13., 113.], + [ 23., 123.]]], + + + [[[ 3., 103.], + [ 13., 113.], + [ 23., 123.]], + + [[ 4., 104.], + [ 14., 114.], + [ 24., 124.]]], + + + [[[ 4., 104.], + [ 14., 114.], + [ 24., 124.]], + + [[ 5., 105.], + [ 15., 115.], + [ 25., 125.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(9, 4)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', + size=(9, 4, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]]], + + + [[[6.0000e+00, 1.0060e+03], + [1.0600e+02, 1.1060e+03], + [2.0600e+02, 1.2060e+03], + [3.0600e+02, 1.3060e+03]], + + [[1.6000e+01, 1.0160e+03], + [1.1600e+02, 1.1160e+03], + [2.1600e+02, 1.2160e+03], + [3.1600e+02, 1.3160e+03]]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]], + + [[ 42.], + [142.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]], + + [[ 43.], + [143.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]], + + [[ 44.], + [144.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]], + + [[ 45.], + [145.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]], + + [[ 36.], + [136.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]], + + [[ 37.], + [137.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]], + + [[ 38.], + [138.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]], + + [[ 39.], + [139.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]], + + [[ 40.], + [140.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]], + + [[ 41.], + [141.]]], + + + [[[ 22.], + [122.]], - [[ 2., 22.]], + [[ 32.], + [132.]], - [[ 3., 33.]], + [[ 42.], + [142.]]]], - [[ 4., 44.]]], - [[[ 1., 11.]], + [[[[ 22.], + [122.]], - [[ 2., 22.]], + [[ 32.], + [132.]], - [[ 3., 33.]], + [[ 42.], + [142.]]], - [[ 4., 44.]]], + [[[ 23.], + [123.]], - [[[ 1., 11.]], + [[ 33.], + [133.]], - [[ 2., 22.]], + [[ 43.], + [143.]]]], - [[ 3., 33.]], - [[ 4., 44.]]]], + [[[[ 23.], + [123.]], + [[ 33.], + [133.]], - [[[[ 1., 11.]], + [[ 43.], + [143.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 24.], + [124.]], - [[ 4., 44.]]], + [[ 34.], + [134.]], + [[ 44.], + [144.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 24.], + [124.]], - [[ 4., 44.]]], + [[ 34.], + [134.]], + [[ 44.], + [144.]]], - [[[ 1., 11.]], - [[ 2., 22.]], + [[[ 25.], + [125.]], - [[ 3., 33.]], + [[ 35.], + [135.]], - [[ 4., 44.]]]]], device='cuda:0', dtype=torch.float64) + [[ 45.], + [145.]]]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect index 90c158c8860d4..0dd1aff7d4dc2 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect @@ -1,907 +1,6949 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[4.], + [5.]]]), device='cuda:0', size=(4, 3), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], device='cuda:0') + [[4.], + [5.]]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, layout=torch.sparse_bsr) # _crow_indices tensor([0], device='cuda:0', dtype=torch.int32) # _col_indices tensor([], device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', size=(1, 0, 0)) +tensor([], device='cuda:0', size=(0, 2, 1)) -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], device='cuda:0') + [[ 8., 18.]]]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), - nnz=4, layout=torch.sparse_bsr) + [[24., 34., 44.], + [25., 35., 45.]]]]]), device='cuda:0', + size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]], device='cuda:0') + [[24., 34., 44.], + [25., 35., 45.]]]]], device='cuda:0') -########## torch.float64/torch.int32/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[4.], + [5.]]]), device='cuda:0', size=(4, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], device='cuda:0', dtype=torch.float64) + [[4.], + [5.]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0], device='cuda:0', dtype=torch.int32) # _col_indices tensor([], device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', size=(1, 0, 0), dtype=torch.float64) +tensor([], device='cuda:0', size=(0, 2, 1), dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], device='cuda:0', dtype=torch.float64) + [[ 8., 18.]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int32/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + [[24., 34., 44.], + [25., 35., 45.]]]]]), device='cuda:0', + size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]], device='cuda:0', dtype=torch.float64) + [[24., 34., 44.], + [25., 35., 45.]]]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[4.], + [5.]]]), device='cuda:0', size=(4, 3), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], device='cuda:0') # _col_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], device='cuda:0') + [[4.], + [5.]]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, layout=torch.sparse_bsr) # _crow_indices tensor([0], device='cuda:0') # _col_indices tensor([], device='cuda:0', dtype=torch.int64) # _values -tensor([], device='cuda:0', size=(1, 0, 0)) +tensor([], device='cuda:0', size=(0, 2, 1)) -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], device='cuda:0') + [[ 8., 18.]]]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float32/torch.int64/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), - nnz=4, layout=torch.sparse_bsr) + [[24., 34., 44.], + [25., 35., 45.]]]]]), device='cuda:0', + size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]], device='cuda:0') + [[24., 34., 44.], + [25., 35., 45.]]]]], device='cuda:0') -########## torch.float64/torch.int64/batch_shape=()/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=()+(4, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([[[ 1., 11.]], + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]]), device='cuda:0', size=(2, 4), nnz=4, + [[4.], + [5.]]]), device='cuda:0', size=(4, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0, 2, 4], device='cuda:0') # _col_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values -tensor([[[ 1., 11.]], +tensor([[[1.], + [2.]], - [[ 2., 22.]], + [[2.], + [3.]], - [[ 3., 33.]], + [[3.], + [4.]], - [[ 4., 44.]]], device='cuda:0', dtype=torch.float64) + [[4.], + [5.]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=(0, 0) ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), - values=tensor([], size=(1, 0, 0)), device='cuda:0', size=(0, 0), nnz=0, + values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([0], device='cuda:0') # _col_indices tensor([], device='cuda:0', dtype=torch.int64) # _values -tensor([], device='cuda:0', size=(1, 0, 0), dtype=torch.float64) +tensor([], device='cuda:0', size=(0, 2, 1), dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2,)+(2, 6)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]]), device='cuda:0', size=(2, 2, 4), nnz=4, + [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[[[ 1., 11.]], - [[ 2., 22.]], + [[ 2., 12.]], - [[ 3., 33.]], + [[ 3., 13.]], - [[ 4., 44.]]], + [[ 4., 14.]]], - [[[ 1., 11.]], + [[[ 5., 15.]], - [[ 2., 22.]], + [[ 6., 16.]], - [[ 3., 33.]], + [[ 7., 17.]], - [[ 4., 44.]]]], device='cuda:0', dtype=torch.float64) + [[ 8., 18.]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=(1, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(4, 9)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[[[ 1., 11.]], + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], - [[ 2., 22.]], + [[ 2., 12., 22.], + [ 3., 13., 23.]], - [[ 3., 33.]], + [[ 3., 13., 23.], + [ 4., 14., 24.]], - [[ 4., 44.]]], + [[ 4., 14., 24.], + [ 5., 15., 25.]]], - [[[ 1., 11.]], + [[[ 5., 15., 25.], + [ 6., 16., 26.]], - [[ 2., 22.]], + [[ 6., 16., 26.], + [ 7., 17., 27.]], - [[ 3., 33.]], + [[ 7., 17., 27.], + [ 8., 18., 28.]], - [[ 4., 44.]]], + [[ 8., 18., 28.], + [ 9., 19., 29.]]], - [[[ 1., 11.]], + [[[ 9., 19., 29.], + [10., 20., 30.]], - [[ 2., 22.]], + [[10., 20., 30.], + [11., 21., 31.]], - [[ 3., 33.]], + [[11., 21., 31.], + [12., 22., 32.]], - [[ 4., 44.]]]], + [[12., 22., 32.], + [13., 23., 33.]]]], - [[[[ 1., 11.]], + [[[[13., 23., 33.], + [14., 24., 34.]], - [[ 2., 22.]], + [[14., 24., 34.], + [15., 25., 35.]], - [[ 3., 33.]], + [[15., 25., 35.], + [16., 26., 36.]], - [[ 4., 44.]]], + [[16., 26., 36.], + [17., 27., 37.]]], - [[[ 1., 11.]], + [[[17., 27., 37.], + [18., 28., 38.]], - [[ 2., 22.]], + [[18., 28., 38.], + [19., 29., 39.]], - [[ 3., 33.]], + [[19., 29., 39.], + [20., 30., 40.]], - [[ 4., 44.]]], + [[20., 30., 40.], + [21., 31., 41.]]], - [[[ 1., 11.]], + [[[21., 31., 41.], + [22., 32., 42.]], - [[ 2., 22.]], + [[22., 32., 42.], + [23., 33., 43.]], - [[ 3., 33.]], + [[23., 33., 43.], + [24., 34., 44.]], - [[ 4., 44.]]]]]), device='cuda:0', size=(2, 3, 2, 4), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + [[24., 34., 44.], + [25., 35., 45.]]]]]), device='cuda:0', + size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[ 1., 11., 21.], + [ 2., 12., 22.]], + + [[ 2., 12., 22.], + [ 3., 13., 23.]], + + [[ 3., 13., 23.], + [ 4., 14., 24.]], + + [[ 4., 14., 24.], + [ 5., 15., 25.]]], + + + [[[ 5., 15., 25.], + [ 6., 16., 26.]], + + [[ 6., 16., 26.], + [ 7., 17., 27.]], + + [[ 7., 17., 27.], + [ 8., 18., 28.]], + + [[ 8., 18., 28.], + [ 9., 19., 29.]]], + + + [[[ 9., 19., 29.], + [10., 20., 30.]], + + [[10., 20., 30.], + [11., 21., 31.]], + + [[11., 21., 31.], + [12., 22., 32.]], + + [[12., 22., 32.], + [13., 23., 33.]]]], + + + + [[[[13., 23., 33.], + [14., 24., 34.]], + + [[14., 24., 34.], + [15., 25., 35.]], + + [[15., 25., 35.], + [16., 26., 36.]], + + [[16., 26., 36.], + [17., 27., 37.]]], + + + [[[17., 27., 37.], + [18., 28., 38.]], + + [[18., 28., 38.], + [19., 29., 39.]], + + [[19., 29., 39.], + [20., 30., 40.]], + + [[20., 30., 40.], + [21., 31., 41.]]], + + + [[[21., 31., 41.], + [22., 32., 42.]], + + [[22., 32., 42.], + [23., 33., 43.]], + + [[23., 33., 43.], + [24., 34., 44.]], + + [[24., 34., 44.], + [25., 35., 45.]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', + size=(4, 9, 4, 2), nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11.]], +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]], device='cuda:0') + + +########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', + size=(4, 9, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', + size=(4, 9, 4, 2), nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]], device='cuda:0') + + +########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[ 1., 101.], + [ 11., 111.]], + + [[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]]], + + + [[[ 2., 102.], + [ 12., 112.]], + + [[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]]], + + + [[[ 3., 103.], + [ 13., 113.]], + + [[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]]], + + + [[[ 4., 104.], + [ 14., 114.]], + + [[ 5., 105.], + [ 15., 115.]], + + [[ 6., 106.], + [ 16., 116.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(4, 9)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', + size=(4, 9, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[[[1.0000e+00, 1.0010e+03], + [1.0100e+02, 1.1010e+03], + [2.0100e+02, 1.2010e+03], + [3.0100e+02, 1.3010e+03]], + + [[1.1000e+01, 1.0110e+03], + [1.1100e+02, 1.1110e+03], + [2.1100e+02, 1.2110e+03], + [3.1100e+02, 1.3110e+03]], + + [[2.1000e+01, 1.0210e+03], + [1.2100e+02, 1.1210e+03], + [2.2100e+02, 1.2210e+03], + [3.2100e+02, 1.3210e+03]]], + + + [[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]]], + + + + [[[[2.0000e+00, 1.0020e+03], + [1.0200e+02, 1.1020e+03], + [2.0200e+02, 1.2020e+03], + [3.0200e+02, 1.3020e+03]], + + [[1.2000e+01, 1.0120e+03], + [1.1200e+02, 1.1120e+03], + [2.1200e+02, 1.2120e+03], + [3.1200e+02, 1.3120e+03]], + + [[2.2000e+01, 1.0220e+03], + [1.2200e+02, 1.1220e+03], + [2.2200e+02, 1.2220e+03], + [3.2200e+02, 1.3220e+03]]], + + + [[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]]], + + + + [[[[3.0000e+00, 1.0030e+03], + [1.0300e+02, 1.1030e+03], + [2.0300e+02, 1.2030e+03], + [3.0300e+02, 1.3030e+03]], + + [[1.3000e+01, 1.0130e+03], + [1.1300e+02, 1.1130e+03], + [2.1300e+02, 1.2130e+03], + [3.1300e+02, 1.3130e+03]], + + [[2.3000e+01, 1.0230e+03], + [1.2300e+02, 1.1230e+03], + [2.2300e+02, 1.2230e+03], + [3.2300e+02, 1.3230e+03]]], + + + [[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]]], + + + + [[[[4.0000e+00, 1.0040e+03], + [1.0400e+02, 1.1040e+03], + [2.0400e+02, 1.2040e+03], + [3.0400e+02, 1.3040e+03]], + + [[1.4000e+01, 1.0140e+03], + [1.1400e+02, 1.1140e+03], + [2.1400e+02, 1.2140e+03], + [3.1400e+02, 1.3140e+03]], + + [[2.4000e+01, 1.0240e+03], + [1.2400e+02, 1.1240e+03], + [2.2400e+02, 1.2240e+03], + [3.2400e+02, 1.3240e+03]]], + + + [[[5.0000e+00, 1.0050e+03], + [1.0500e+02, 1.1050e+03], + [2.0500e+02, 1.2050e+03], + [3.0500e+02, 1.3050e+03]], + + [[1.5000e+01, 1.0150e+03], + [1.1500e+02, 1.1150e+03], + [2.1500e+02, 1.2150e+03], + [3.1500e+02, 1.3150e+03]], + + [[2.5000e+01, 1.0250e+03], + [1.2500e+02, 1.1250e+03], + [2.2500e+02, 1.2250e+03], + [3.2500e+02, 1.3250e+03]]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + + + + [[[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]]], + + + + [[[[ 23.], + [123.]], + + [[ 33.], + [133.]]], + + + [[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]]], + + + + [[[[ 24.], + [124.]], + + [[ 34.], + [134.]]], + + + [[[ 25.], + [125.]], + + [[ 35.], + [135.]]], + + + [[[ 26.], + [126.]], + + [[ 36.], + [136.]]]]]]]), device='cuda:0', + size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, + layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[[[ 1.], + [101.]], + + [[ 11.], + [111.]]], + + + [[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]]], + + + + [[[[ 2.], + [102.]], + + [[ 12.], + [112.]]], + + + [[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]]], + + + + [[[[ 3.], + [103.]], + + [[ 13.], + [113.]]], + + + [[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]]], + + + + [[[[ 4.], + [104.]], + + [[ 14.], + [114.]]], + + + [[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]]]], + + + + + [[[[[ 5.], + [105.]], + + [[ 15.], + [115.]]], + + + [[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]]], + + + + [[[[ 6.], + [106.]], + + [[ 16.], + [116.]]], + + + [[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]]], + + + + [[[[ 7.], + [107.]], + + [[ 17.], + [117.]]], + + + [[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]]], + + + + [[[[ 8.], + [108.]], + + [[ 18.], + [118.]]], + + + [[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]]]], + + + + + [[[[[ 9.], + [109.]], + + [[ 19.], + [119.]]], + + + [[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]]], + + + + [[[[ 10.], + [110.]], + + [[ 20.], + [120.]]], + + + [[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]]], + + + + [[[[ 11.], + [111.]], + + [[ 21.], + [121.]]], + + + [[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]]], + + + + [[[[ 12.], + [112.]], + + [[ 22.], + [122.]]], + + + [[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]]]]], + + + + + + [[[[[[ 13.], + [113.]], + + [[ 23.], + [123.]]], + + + [[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]]], + + + + [[[[ 14.], + [114.]], + + [[ 24.], + [124.]]], + + + [[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]]], + + + + [[[[ 15.], + [115.]], + + [[ 25.], + [125.]]], + + + [[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]]], + + + + [[[[ 16.], + [116.]], + + [[ 26.], + [126.]]], + + + [[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]]]], + + + + + [[[[[ 17.], + [117.]], + + [[ 27.], + [127.]]], + + + [[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]]], + + + + [[[[ 18.], + [118.]], + + [[ 28.], + [128.]]], + + + [[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]]], + + + + [[[[ 19.], + [119.]], + + [[ 29.], + [129.]]], + + + [[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]]], + + + + [[[[ 20.], + [120.]], + + [[ 30.], + [130.]]], + + + [[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]]]], + + + + + [[[[[ 21.], + [121.]], + + [[ 31.], + [131.]]], + + + [[[ 22.], + [122.]], + + [[ 32.], + [132.]]], + + + [[[ 23.], + [123.]], + + [[ 33.], + [133.]]]], + - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 22.], + [122.]], - [[ 4., 44.]]], + [[ 32.], + [132.]]], - [[[ 1., 11.]], + [[[ 23.], + [123.]], - [[ 2., 22.]], + [[ 33.], + [133.]]], - [[ 3., 33.]], - [[ 4., 44.]]], + [[[ 24.], + [124.]], + [[ 34.], + [134.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[[ 23.], + [123.]], - [[ 4., 44.]]]], + [[ 33.], + [133.]]], + [[[ 24.], + [124.]], - [[[[ 1., 11.]], + [[ 34.], + [134.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 25.], + [125.]], - [[ 4., 44.]]], + [[ 35.], + [135.]]]], - [[[ 1., 11.]], - [[ 2., 22.]], + [[[[ 24.], + [124.]], - [[ 3., 33.]], + [[ 34.], + [134.]]], - [[ 4., 44.]]], + [[[ 25.], + [125.]], - [[[ 1., 11.]], + [[ 35.], + [135.]]], - [[ 2., 22.]], - [[ 3., 33.]], + [[[ 26.], + [126.]], - [[ 4., 44.]]]]], device='cuda:0', dtype=torch.float64) + [[ 36.], + [136.]]]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect index 4292bfcd21994..64435343b7cb6 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect @@ -1,17 +1,17 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([1., 2., 3., 4.], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -24,89 +24,89 @@ tensor([], device='cuda:0', dtype=torch.int32) # _values tensor([], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0') + [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, layout=torch.sparse_csc) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0') + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0') -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -119,89 +119,89 @@ tensor([], device='cuda:0', dtype=torch.int32) # _values tensor([], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0', dtype=torch.float64) + [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], device='cuda:0') # _row_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([1., 2., 3., 4.], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -214,89 +214,89 @@ tensor([], device='cuda:0', dtype=torch.int64) # _values tensor([], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0') + [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, layout=torch.sparse_csc) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0') + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0') -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + row_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([0, 2, 4], device='cuda:0') # _row_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(ccol_indices=tensor([0]), row_indices=tensor([], size=(0,)), @@ -309,71 +309,1103 @@ tensor([], device='cuda:0', dtype=torch.int64) # _values tensor([], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2,)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), row_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _row_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0', dtype=torch.float64) + [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+() ########## # sparse tensor tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), + nnz=4, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0') - [[0, 2, 4], +########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), + nnz=4, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0') + + +########## torch.float64/torch.int32/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, + layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), + nnz=4, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0') + +########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), + nnz=4, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]], device='cuda:0') +# _row_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], - [[0, 2, 4], + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0') + + +########## torch.float64/torch.int64/size=()+(3, 2)+(2,) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(3, 2)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([0, 2, 4], device='cuda:0') +# _row_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), + nnz=4, dtype=torch.float64, layout=torch.sparse_csc) +# _ccol_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) + [[24.], + [25.]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect index 918f2570807f7..ddb5272c79cab 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect @@ -1,17 +1,17 @@ -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([1., 2., 3., 4.], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -24,89 +24,89 @@ tensor([], device='cuda:0', dtype=torch.int32) # _values tensor([], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0') + [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, layout=torch.sparse_csr) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0') + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0') -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) # _values tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int32/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -119,89 +119,89 @@ tensor([], device='cuda:0', dtype=torch.int32) # _values tensor([], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0', dtype=torch.int32) + [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0', dtype=torch.float64) + [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0', dtype=torch.int32) + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], device='cuda:0') # _col_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([1., 2., 3., 4.], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float32/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -214,89 +214,89 @@ tensor([], device='cuda:0', dtype=torch.int64) # _values tensor([], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0') + [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), - nnz=4, layout=torch.sparse_csr) + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], - [0, 2, 4], - [0, 2, 4]], + [0, 3, 4], + [0, 1, 4]], - [[0, 2, 4], + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0') + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0') -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4, + col_indices=tensor([0, 1, 0, 2]), + values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([0, 2, 4], device='cuda:0') # _col_indices -tensor([0, 1, 0, 1], device='cuda:0') +tensor([0, 1, 0, 2], device='cuda:0') # _values tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=()/block_shape=() ########## +########## torch.float64/torch.int64/size=()+(0, 0)+() ########## # sparse tensor tensor(crow_indices=tensor([0]), col_indices=tensor([], size=(0,)), @@ -309,71 +309,1103 @@ tensor([], device='cuda:0', dtype=torch.int64) # _values tensor([], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2,)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2,)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[0, 2, 4], - [0, 2, 4]]), + [0, 3, 4]]), col_indices=tensor([[0, 1, 0, 1], - [0, 1, 0, 1]]), + [0, 1, 2, 0]]), values=tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]]), device='cuda:0', size=(2, 2, 2), + [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[0, 2, 4], - [0, 2, 4]], device='cuda:0') + [0, 3, 4]], device='cuda:0') # _col_indices tensor([[0, 1, 0, 1], - [0, 1, 0, 1]], device='cuda:0') + [0, 1, 2, 0]], device='cuda:0') # _values tensor([[1., 2., 3., 4.], - [1., 2., 3., 4.]], device='cuda:0', dtype=torch.float64) + [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/batch_shape=(2, 3)/block_shape=() ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]]), device='cuda:0', + size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]], + + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int32/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), + nnz=4, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0') - [[0, 2, 4], +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]]), + [0, 3, 4]]]), col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], - - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]]), - values=tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], - - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 2), + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), + nnz=4, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0') + + +########## torch.float64/torch.int32/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int64/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, + layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), + nnz=4, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0') + +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), + nnz=4, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]], + [0, 3, 4]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') +# _values +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], - [[0, 2, 4], + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]], device='cuda:0') + + +########## torch.float64/torch.int64/size=()+(2, 3)+(2,) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[1., 2.], + [2., 3.], + [3., 4.], + [4., 5.]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(2, 3)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 2]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), + nnz=4, dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([0, 2, 4], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 2], device='cuda:0') +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], + + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], + [0, 2, 4], + [0, 3, 4]]]), + col_indices=tensor([[[0, 1, 0, 1], + [0, 1, 2, 0], + [0, 0, 1, 2]], + + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]]), + values=tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], + + [[24.], + [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), + nnz=4, dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([[[0, 2, 4], + [0, 3, 4], + [0, 1, 4]], + + [[0, 1, 4], [0, 2, 4], - [0, 2, 4]]], device='cuda:0') + [0, 3, 4]]], device='cuda:0') # _col_indices tensor([[[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]], + [0, 1, 2, 0], + [0, 0, 1, 2]], - [[0, 1, 0, 1], - [0, 1, 0, 1], - [0, 1, 0, 1]]], device='cuda:0') + [[1, 0, 1, 2], + [0, 2, 0, 1], + [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]], +tensor([[[[[ 1.], + [ 2.]], + + [[ 2.], + [ 3.]], + + [[ 3.], + [ 4.]], + + [[ 4.], + [ 5.]]], + + + [[[ 5.], + [ 6.]], + + [[ 6.], + [ 7.]], + + [[ 7.], + [ 8.]], + + [[ 8.], + [ 9.]]], + + + [[[ 9.], + [10.]], + + [[10.], + [11.]], + + [[11.], + [12.]], + + [[12.], + [13.]]]], + + + + [[[[13.], + [14.]], + + [[14.], + [15.]], + + [[15.], + [16.]], + + [[16.], + [17.]]], + + + [[[17.], + [18.]], + + [[18.], + [19.]], + + [[19.], + [20.]], + + [[20.], + [21.]]], + + + [[[21.], + [22.]], + + [[22.], + [23.]], + + [[23.], + [24.]], - [[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) + [[24.], + [25.]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 86f91976c3c86..7f91b95cfcd1d 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -9,6 +9,19 @@ from torch._C import parse_schema +# How to run this test locally: +# 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly) +# one with your local changes (venv_yours). +# In venv_nightly: +# 2. First ensure that Pytorch is uninstalled, but all prereqs are installed +# 3. Install torch nightly build with +# `pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html` +# 4. Generate original schemas with +# `python test/forward_backward_compatibility/dump_all_function_schemas.py --filename nightly_schemas.txt` +# Now in venv_yours: +# 5. Run this test with +# `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt` + # The date specifies how long the allowlist exclusion should apply to. # # - If we NEVER give BC guarantee for an operator, you can put the @@ -38,6 +51,7 @@ ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)), ("profiler::_record_function_enter", datetime.date(9999, 1, 1)), ("aten::_sparse_addmm", datetime.date(2022, 6, 30)), + ("aten::kl_div_backward", datetime.date(2022, 9, 1)), ("aten::_cholesky_helper", datetime.date(9999, 1, 1)), ("aten::_lstsq_helper", datetime.date(9999, 1, 1)), ("aten::_syevd_helper", datetime.date(9999, 1, 1)), @@ -51,12 +65,19 @@ ("aten::randperm", datetime.date(9999, 1, 1)), ("aten::linalg_solve", datetime.date(2022, 8, 31)), ("aten::linalg_solve.out", datetime.date(2022, 8, 31)), - ("aten::l1_loss_backward.grad_input", datetime.date(2022, 7, 1)), - ("aten::l1_loss_backward", datetime.date(2022, 7, 1)), - ("aten::l1_loss.out", datetime.date(2022, 7, 1)), + ("aten::binary_cross_entropy_with_logits_backward", datetime.date(2022, 9, 21)), ("aten::_linalg_qr_helper", datetime.date(2022, 8, 1)), ("aten::linalg_lu_solve", datetime.date(2022, 8, 1)), ("aten::linalg_lu_solve.out", datetime.date(2022, 8, 1)), + ("aten::linalg_det", datetime.date(2022, 8, 1)), + ("aten::linalg_det.out", datetime.date(2022, 8, 1)), + ("aten::_det_lu_based_helper", datetime.date(2022, 8, 1)), + ("aten::slogdet", datetime.date(2022, 8, 1)), + ("aten::slogdet.out", datetime.date(2022, 8, 1)), + ("aten::linalg_slogdet", datetime.date(2022, 8, 1)), + ("aten::linalg_slogdet.out", datetime.date(2022, 8, 1)), + ("aten::_linalg_solve", datetime.date(2022, 10, 1)), + ("aten::_linalg_solve.solution", datetime.date(2022, 10, 1)), ("aten::solve", datetime.date(9999, 1, 1)), ("aten::solve.solution", datetime.date(9999, 1, 1)), ("aten::_solve_helper", datetime.date(9999, 1, 1)), @@ -90,6 +111,10 @@ ("aten::segment_reduce", datetime.date(2022, 6, 30)), ("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)), ("aten::empty.SymInt", datetime.date(9999, 1, 1)), + ("c10d::broadcast", datetime.date(2022, 6, 25)), + ("aten::.*functional", datetime.date(2022, 8, 1)), + ("aten::_foreach.*", datetime.date(2022, 8, 1)), + ("aten::unflatten", datetime.date(2022, 8, 10)), # TODO: FIXME: prims shouldn't be checked ("prims::.*", datetime.date(9999, 1, 1)), ] @@ -118,6 +143,7 @@ def allow_listed(schema): ("_TorchScriptTesting.*", datetime.date(2099, 9, 17)), ("test_backend", datetime.date(2099, 9, 17)), ("dist_c10d", datetime.date(2099, 9, 17)), + ("__backends__.nnc", datetime.date(2099, 9, 17)), ] def has_valid_upgraders(schema, version_map): diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py new file mode 100644 index 0000000000000..9c59abce4da61 --- /dev/null +++ b/test/fx/test_common_passes.py @@ -0,0 +1,115 @@ +# Owner(s): ["oncall: fx"] + +import torch + +from torch.testing._internal.common_utils import ( + TestCase, parametrize, instantiate_parametrized_tests, run_tests) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.dialect.common.cse_pass import CSEPass +from torch.fx.graph_module import GraphModule + +import itertools + +def FactoryFunctionCall(x, device): + y = torch.full(x.shape, 3, device=device) + z = torch.add(y, x) + return z + + +def TorchTensorCall(x): + y = torch.tensor(3) + return x + y + + +def TakeList(x): + z = torch.cat([x, x]) + return z + + +def ReturnList(x): + a = torch.arange(10).reshape(5, 2) + z = torch.split(a, [1, 4]) + return z + + +def Mutation(x): + y = x + 2 + y.add_(1) + return x + y + + +def MutationInput(x): + x.add_(1) + y = x + 2 + return x + y + + +def MutationFactory(x, device): + y = torch.full(x.shape, 3, device=device) + y.add_(1) + return x + y + + +def MutationTorchTensorCall(x): + y = torch.tensor(3) + y.add_(1) + return x + y + + +def MutationMetadata(x): + x.resize_(2) + return x + + +Passes = [CSEPass] +Test_Cases = [TakeList, + ReturnList, + Mutation, + MutationInput, + MutationMetadata, + MutationTorchTensorCall] +Factory_Test_Cases = [FactoryFunctionCall, MutationFactory] +Devices = ["cpu"] +if torch.cuda.is_available(): + Devices.append("cuda") + +@instantiate_parametrized_tests +class TestCommonPass(TestCase): + + @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) + def test_correctness(self, common_pass, f, device): + inp = torch.randn(10, device=device) + + traced_m = make_fx(f)(inp) + P = common_pass() + + res = P(traced_m) + modified_m = res.graph_module + assert isinstance(modified_m, GraphModule) + + inp_copy = inp.clone() + expected = f(inp) + result = modified_m(inp_copy) + + self.assertEqual(result, expected) + + + @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) + def test_correctness_factory(self, common_pass, f, device): + inp = torch.randn(10, device=device) + traced_m = make_fx(f)(inp, device) + P = common_pass() + + res = P(traced_m) + modified_m = res.graph_module + assert isinstance(modified_m, GraphModule) + + inp_copy = inp.clone() + expected = f(inp, device) + result = modified_m(inp_copy, device) + + self.assertEqual(result, expected) + + +if __name__ == '__main__': + run_tests() diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py new file mode 100644 index 0000000000000..13ed344dc43e1 --- /dev/null +++ b/test/fx/test_cse_pass.py @@ -0,0 +1,233 @@ +# Owner(s): ["oncall: fx"] + +import torch + +from torch.testing._internal.common_utils import ( + TestCase, run_tests) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops +from torch.fx import symbolic_trace + +import random + + +banned_ops = get_CSE_banned_ops() +P_default = CSEPass(banned_ops=banned_ops) + +def check(self, f, t, delta, check_val=True, graph_input=False, P=None): + """ + check if the CSE modified graph of ``f`` + 1) has delta less nodes, and + 2) do not reduce the number of nodes further on a second pass, and + 3) modified returned is true only if the number of nodes decreases. + + Args: + f: function to be checked + t: tensor to be passed to f + delta: an integer >= -1. + If delta = -1, it only checks if the new graph has less or equal number of nodes + check_val: if True, check if the output of f is correct + graph_input: True is f is type GraphModule + P: the pass to use. If None, use P_default + """ + if graph_input: + fx_g = f + else: + fx_g = make_fx(f)(t) + + if P is None: + P = P_default + + res = P(fx_g) + new_g = res.graph_module + new_graph = new_g.graph + modified = res.modified + + # the number of nodes decrease/ or stay the same + old_num_nodes = len(fx_g.graph.nodes) + new_num_nodes = len(new_graph.nodes) + + assert (new_num_nodes < old_num_nodes) == modified, "modified should be True if the number of nodes decrease" + + if delta == -1: + self.assertTrue(old_num_nodes >= new_num_nodes, ( + f"number of nodes increased {old_num_nodes}, {new_num_nodes}")) + else: + self.assertTrue(old_num_nodes == new_num_nodes + delta, ( + f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")) + + # a second pass should not reduce more nodes + res = P(new_g) + pass_2_graph = res.graph_module.graph + pass_2_num_nodes = len(pass_2_graph.nodes) + self.assertTrue(pass_2_num_nodes == new_num_nodes, ( + f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")) + + # check correctness + if check_val: + true_result = fx_g(t) + our_result = new_g(t) + if true_result is None: # both return None + self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}") + else: # results returned are the same + self.assertTrue(torch.all(true_result == our_result), ( + f"results are different {true_result}, {our_result}")) # check results are the same + +class TestCSEPass(TestCase): + + def test_nochange(self): + def f(x): + a = x + 1 + b = x + a + a = x + d = x + a + return b + d + t = torch.randn(2, 2) + check(self, f, t, 0) + + def test_empty(self): + def f(x): + pass + t = torch.randn(2, 2) + check(self, f, t, 0) + + + def test_immutable_list_type(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1) + c = x.sum() + d = x.sum() + return a + b + c + d + t = torch.randn(2, 2) + check(self, f, t, 2) + + def test_immutable_list_multiple_entries(self): + def f(x): + a = x.sum(dim=[0, 1]) + b = x.sum(dim=[0, 1]) + c = x.sum(dim=1) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(self, f, t, 2) + + def test_simple(self): + def f(x): + a = x.cos() + b = x.cos() + c = a + a + d = b + b + return c + d + t = torch.randn(2, 2) + check(self, f, t, 2) + + def test_simple_2(self): + def f(x): + a = x.cos().sin() + b = x.cos().sin() + c = a + a + d = b + b + return c + d + t = torch.randn(1) + check(self, f, t, 3) + + def test_two_args_default(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1, keepdim=False) + c = x.sum(dim=1, keepdim=False) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(self, f, t, 3) + + def test_two_args(self): + def f(x): + a = x.sum(dim=1) + b = x.sum(dim=1, keepdim=True) + c = x.sum(dim=1, keepdim=True) + d = x.sum(dim=1) + return a + b + c + d + t = torch.randn(2, 2) + check(self, f, t, 2) + + def test_simple_multiple_same_ops(self): + def f(x): + a = x.sum() + b = x.sum() + c = x.sum() + d = x.sum() + return a + b + c + d + t = torch.randn(2, 2) + check(self, f, t, 3) + + def test_nested_immutable_list_type(self): + def f(x): + a = torch.cat((x, x)) + b = torch.cat((x, x)) + return a + b + t = torch.randn(2, 2) + check(self, f, t, 1) + + def test_kwarg(self): + def f(x): + a = torch.ones_like(x) + b = torch.ones_like(x) + return a + b + t = torch.randn(2, 2) + check(self, f, t, 1) + + """ + Generate function with random ops and check if the result is the same + """ + def test_random(self): + def f(x): + vals = [x] + ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu] + for _ in range(100): + new_val = random.choice(ops)(random.choice(vals)) + vals.append(new_val) + return vals[-1] + + fx_g = symbolic_trace(f) + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + t = torch.randn(2, 2) + + for _ in range(30): + check(self, fx_g, t, -1, graph_input=True) + + """ + Test that banned list ban ops as expected. + """ + def test_banned_list(self): + def f(x): + a = x + 1 + b = x + 1 + return a + b + + t = torch.randn(2, 2) + P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add]) + check(self, f, t, 0, P=P_ban_add) # check that add is banned + check(self, f, t, 1) # check that add is not banned by default + + def test_rand_like(self): + def f(x): + a = torch.rand_like(x) + b = torch.rand_like(x) + return a + b + t = torch.randn(2, 2) + check(self, f, t, 0, check_val=False) + + def test_rand_n(self): + def f(x): + a = torch.randn(4) + b = torch.randn(4) + return a + b + t = torch.randn(2, 2) + check(self, f, t, 0, check_val=False) + + +if __name__ == '__main__': + run_tests() diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index 930f345e82204..4f46b9982ba94 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] from typing import Set, Type import torch diff --git a/test/fx/test_future.py b/test/fx/test_future.py index 35d0649ffdd57..4f093de54b4f8 100644 --- a/test/fx/test_future.py +++ b/test/fx/test_future.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] from __future__ import annotations # type: ignore[attr-defined] import torch diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index fb53f3494f2ed..d7f3b16f2466c 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import operator diff --git a/test/fx/test_fx_param_shape_control_flow.py b/test/fx/test_fx_param_shape_control_flow.py index cbedca71c3220..e9af35d604577 100644 --- a/test/fx/test_fx_param_shape_control_flow.py +++ b/test/fx/test_fx_param_shape_control_flow.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import unittest import torch diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py index d2b9c067447b1..0371dc20214bb 100644 --- a/test/fx/test_gradual_type.py +++ b/test/fx/test_gradual_type.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import unittest import torch diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py new file mode 100644 index 0000000000000..e087b1dc6c2fd --- /dev/null +++ b/test/fx/test_pass_infra.py @@ -0,0 +1,158 @@ +# Owner(s): ["module: fx"] + +import torch +import torch.fx as fx + +from torch.testing._internal.common_utils import TestCase +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import ( + PassManager, + this_before_that_pass_constraint, + _topological_sort_passes, +) + +def replace_add_with_mul_pass(gm): + modified = False + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.add: + node.target = torch.mul + modified = True + return PassResult(gm, modified) + +def replace_mul_with_div_pass(gm): + modified = False + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.mul: + node.target = torch.div + modified = True + return PassResult(gm, modified) + +class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.add(x, x) + z = torch.add(y, x) + return z + + +class TestPassManager(TestCase): + def test_pass_manager(self): + """ + Tests that the pass manager runs the passes correctly. + """ + + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass], steps=5) + + pm.validate_constraints() + self.assertEqual(len(pm.passes), 2) + + res = pm(traced_m) + modified_m = res.graph_module + assert isinstance(modified_m, fx.GraphModule) + + # Check that all call_function nodes are divs + for node in modified_m.graph.nodes: + if node.op == "call_function": + self.assertEqual(node.target, torch.div) + + def test_this_before_that_pass_constraint(self): + """ + Tests the construction of constraints + """ + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + + # add unfulfillable constraint + pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) + + with self.assertRaises(RuntimeError): + pm.validate_constraints() + + + def test_pass_manager_checks(self): + """ + Tests that users can add in check functions correctly + """ + m = AddModule() + traced_m = fx.symbolic_trace(m) + pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass]) + + def check_div_target(graph_module): + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target != torch.div: + raise ValueError("Target should be div!") + pm.add_checks(check_div_target) + + with self.assertRaises(ValueError): + pm(traced_m) + + def test_pass_manager_bad_checks(self): + """ + Checks that we error if we pass in a check function with the wrong parameters + """ + def check_bad_args(graph_module, i): + pass + + pm = PassManager() + self.assertRaises(TypeError, pm.add_checks, check_bad_args) + + def test_topological_sort(self): + """ + Tests that passes are correctly ordered based on contraints. + """ + + def pass0(x): + return x + + def pass1(x): + return x + 1 + + def pass2(x): + return x + 2 + + def pass3(x): + return x + 3 + + def pass4(x): + return x + 4 + + def pass5(x): + return x + 5 + + # Not passing any constraints should keep the original order + passes = [pass0, pass1, pass2, pass3, pass4, pass5] + sorted = _topological_sort_passes(passes, []) + self.assertEqual(sorted, passes) + + # Graph that we are constructing: + # 5 ----> 0 <---- 4 + # | | + # +-> 2 -> 3 -> 1 <-+ + # Which has a possible topological order of: [4, 5, 0, 2, 3, 1] + passes = [pass0, pass1, pass2, pass3, pass4, pass5] + constraints = [ + this_before_that_pass_constraint(pass5, pass0), + this_before_that_pass_constraint(pass5, pass2), + this_before_that_pass_constraint(pass4, pass0), + this_before_that_pass_constraint(pass4, pass1), + this_before_that_pass_constraint(pass2, pass3), + this_before_that_pass_constraint(pass3, pass1), + ] + sorted = _topological_sort_passes(passes, constraints) + self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1]) + + # Circular dependency should result in the circular_dep flag being set + passes = [pass0, pass1, pass2] + constraints = [ + this_before_that_pass_constraint(passes[0], passes[1]), + this_before_that_pass_constraint(passes[1], passes[2]), + this_before_that_pass_constraint(passes[2], passes[0]), + ] + with self.assertRaises(RuntimeError) as e: + _topological_sort_passes(passes, constraints) + expected_error_msg = f"Circular dependency detected within the following passes: {passes}" + self.assertEqual(e.exception.args[0], expected_error_msg) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index b4e19ded9aa38..afaf9e84f6824 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import os import sys diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py new file mode 100644 index 0000000000000..c864523fe0664 --- /dev/null +++ b/test/fx/test_z3_gradual_types.py @@ -0,0 +1,2294 @@ +# Owner(s): ["module: fx"] +import operator +import unittest +from torch.fx import GraphModule, symbolic_trace +from torch.fx.experimental.meta_tracer import symbolic_trace as meta_symbolic_trace +from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintT, DVar, TVar, T +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint +from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency +from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\ + evaluate_conditional_with_constraints +from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn +from torch.fx.experimental.rewriter import RewritingTracer +from torch.fx.tensor_type import Dyn, TensorType +import torch + + +try: + import z3 # type: ignore[import] + HAS_Z3 = True +except ImportError: + HAS_Z3 = False + + +try: + from torchvision import models + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + + +class HFOperations(unittest.TestCase): + + + def test_conditional_ne_1(self): + """ + This test case is for the HFmodels interface. + A function takes a node and a graph and considers + the conditional the node represents and its negation + and solves each formula with the remaining sets of constraints + Returns: + + """ + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([32, 4, 4]), y: TensorType([32, 4, 4])): + size_5 = x.size() + getitem_7 = size_5[0] + getitem_8 = size_5[1] + getitem_9 = size_5[2] + ne_1 = y != (getitem_7, getitem_8, getitem_9) + return ne_1 + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + + # The node we are considering is the gt node + for n in graph.nodes: + if n.target == operator.ne: + node = n + + # since x and y are equal, the requirement that x != y cannot be true, so we should get unsat + # for the positive condition and sat for the negative condition + positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) + self.assertEqual(positive, z3.unsat) + self.assertEqual(negative, z3.sat) + + def test_bmm(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 2, 3]), y: TensorType([1, 3, 2])): + bmm = torch.bmm(x, y) + return bmm + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + + output = z3.Const(3, tensor_type) + self.assertEqual(s.check(), z3.sat) + self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) + self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) + self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) + + + def test_bmm2(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: TensorType([1, 3, 2])): + bmm = torch.bmm(x, y) + return bmm + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + + output = z3.Const(3, tensor_type) + self.assertEqual(s.check(), z3.sat) + self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) + self.assertEqual(s.model()[output].arg(1).arg(0), 0) + self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) + + def test_bmm3(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 3, 3]), y: TensorType([1, 3, 2])): + bmm = torch.bmm(x, y) + return bmm + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.unsat) + + + def test_transpose(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([1, 2, 3, 4])): + transpose = x.transpose(0, 1) + return transpose + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + b = BasicBlock().forward(torch.rand(1, 2, 3, 4)) + + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + output = z3.Const(2, tensor_type) + self.assertEqual(s.check(), z3.sat) + self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) + self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) + self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) + self.assertEqual(s.model()[output].arg(3).arg(1), b.shape[3]) + + # change the annotation to Dyn + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + + def test_index_select(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2050, 1024]), y: Dyn): + index_select = x.index_select(0, y) + return index_select + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + # print(symbolic_traced) + b = BasicBlock().forward(torch.rand(2050, 1024), torch.ones(8).int()) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + index_select = z3.Const(3, tensor_type) + + # the second dimension of the result should not be affected since + # the index is 0 + self.assertEqual(s.model()[index_select].arg(1).arg(1), b.shape[1]) + + replacement_vector = z3.Const(2, tensor_type) + + # we set the vector to Dyn + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + index_select = z3.Const(3, tensor_type) + s.add(replacement_vector == z3_dyn) + self.assertEqual(s.check(), z3.sat) + + # this implies that the index at 0 should be dyn + self.assertEqual(s.model()[index_select].arg(0).arg(0), 0) + + def test_get_attr(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([1, 2, 3])): + getattr = x.device + to = x.to(getattr) + return to + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + b = BasicBlock().forward(torch.rand(1, 2, 3)) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + attr_res = z3.Const(3, tensor_type) + assert s.model()[attr_res].arg(0).arg(1) == b.shape[0] + assert s.model()[attr_res].arg(1).arg(1) == b.shape[1] + assert s.model()[attr_res].arg(2).arg(1) == b.shape[2] + + + def test_expand(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([1, 4])): + size = x.size() + getitem = size[-1] + expand = x.expand(getitem, 4) + return expand + + b = BasicBlock().forward(torch.rand(1, 4)) + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + expand_res = z3.Const(4, tensor_type) + assert s.model()[expand_res].arg(0).arg(1) == b.shape[0] + assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] + + # change the annotation on the input to Dyn. + # the last dimension should still be 4 + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] + + def test_getitem_tensor(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([4, 4])): + getitem = x[(None, None, slice(None, None, None), slice(None, None, None))] + return getitem + + B = BasicBlock() + b = B.forward(torch.rand(4, 4)) + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(B) + transformed = transform_all_constraints(symbolic_traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + get_item_res = z3.Const(2, tensor_type) + assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] + assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] + assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] + assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] + + # change the annotation on the input to make sure it propagates + # to the output + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, 4]) + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + # dyn check + assert s.model()[get_item_res].arg(2).arg(0) == 0 + + + def test_getitem_tensor2(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([4, 4])): + getitem = x[(None, None)] + return getitem + + B = BasicBlock() + b = B.forward(torch.rand(4, 4)) + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(B) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + get_item_res = z3.Const(2, tensor_type) + assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] + assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] + assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] + assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] + + + def test_getitem_tensor_3(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([4, 4])): + getitem = x[(None, slice(None, None, None), None, slice(None, None, None))] + return getitem + + B = BasicBlock() + b = B.forward(torch.rand(4, 4)) + symbolic_traced: torch.fx.GraphModule = symbolic_trace(B) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + get_item_res = z3.Const(2, tensor_type) + assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] + assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] + assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] + assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] + + + + def test_layer_norm(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + self.l = torch.nn.LayerNorm((1024,)) + + def forward(self, x: Dyn): + return self.l(x) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + # make the output a size 1 tensor which should result + # in the migration of the input + + b = BasicBlock().forward(torch.rand(1024)) + input = z3.Const(1, tensor_type) + output = z3.Const(2, tensor_type) + s.add(output == tensor_type.tensor1(D(1, 1024))) + s.check() + self.assertEqual(s.model()[input], s.model()[output]) + # input shape = output shape + self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1)) + + # change annotation to the wrong shape + for n in graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([10, 10]) + + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.unsat) + + # fix the annotation + for n in graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([10, 1024]) + + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + s.check() + b = BasicBlock().forward(torch.rand(10, 1024)).shape + self.assertEqual(s.model()[output].arg(0).arg(1), b[0]) + self.assertEqual(s.model()[output].arg(1).arg(1), b[1]) + + def test_ne_int_long_type_as(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn, Dyn])): + ne_int = torch.ne(x, y).int() + type_as = ne_int.type_as(y) + long = type_as.long() + return long + + symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock()) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + # migrate one of the parameters to a fully static shape so we can compare + + input = z3.Const(1, tensor_type) + input_2 = z3.Const(2, tensor_type) + s1, s2 = z3.Ints('s1 s2') + + output_long = z3.Const(8, tensor_type) + s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4))) + s.add(input_2 == tensor_type.tensor2(D(1, s1), D(1, s2))) + + self.assertEquals(s.check(), z3.sat) + actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape + self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0]) + self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1]) + + + def test_ne(self): + s1, s2 = z3.Ints('s1 s2') + s11, s22 = z3.Ints('s11 s22') + d1, d2 = D(s11, s1), D(0, s2) + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: Dyn): + return torch.ne(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + # change the annotations + for n in graph.nodes: + if n.name == 'x': + n.type = TensorType([1, 2]) + if n.name == 'y': + n.type = TensorType([2, Dyn]) + + # resulting type should be TensorType([2, 2]) + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + # force the second dimension to be Dyn + # output should still be TensorType([2, 2]) + input = z3.Const(2, tensor_type) + s.add(input == tensor_type.tensor2(d1, d2)) + self.assertEqual(s.check(), z3.sat) + B = BasicBlock().forward(torch.rand(1, 2), torch.rand(2, 1)) + output = z3.Const(3, tensor_type) + self.assertEqual(s.model()[output].arg(0).arg(1), B.shape[0]) + self.assertEqual(s.model()[output].arg(1).arg(1), B.shape[0]) + + + def test_cumsum(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 4, 3])): + t = torch.cumsum(x, 3) + return t + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + + # should be unsat since the index is not valid for this annotation + self.assertEqual(s.check(), z3.unsat) + + # modify the annotation to Dyn which should give sat + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + # # modify the annotation to the right tensor size + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([1, 2, 3, 4]) + + # verify that the input is equal to the output + B = BasicBlock().forward(torch.rand(1, 2, 3, 4)) + res_shape = B.shape + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + # confirm the output matches the expected tensor + result = z3.Const(2, tensor_type) + self.assertEqual(s.model()[result].arg(0).arg(1), res_shape[0]) + self.assertEqual(s.model()[result].arg(1).arg(1), res_shape[1]) + self.assertEqual(s.model()[result].arg(2).arg(1), res_shape[2]) + self.assertEqual(s.model()[result].arg(3).arg(1), res_shape[3]) + + # confirm the output is not dyn + self.assertNotEqual(s.model()[result].arg(0).arg(0).as_long(), 0) + self.assertNotEqual(s.model()[result].arg(1).arg(0).as_long(), 0) + self.assertNotEqual(s.model()[result].arg(2).arg(0).as_long(), 0) + self.assertNotEqual(s.model()[result].arg(3).arg(0).as_long(), 0) + + + def test_cumsum_kwargs(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 4, 3])): + t = torch.cumsum(x, dim=3) + return t + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + + # should be unsat since the index is not valid for this annotation + self.assertEqual(s.check(), z3.unsat) + + # modify the annotation to Dyn which should give sat + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + + def test_arange(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + size = x.size() + getitem = size[-1] + arange = torch.arange(getitem) + return arange + + B = BasicBlock().forward(torch.rand(2, 4)) + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + arange_result = z3.Const(5, tensor_type) + self.assertNotEqual(s.model()[arange_result].arg(0).arg(0).as_long(), 0) + self.assertEqual(s.model()[arange_result].arg(0).arg(1).as_long(), B.size()[0]) + + # change the annotation to Dyn. This will migrate to an arbitirary type + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + def test_scalar_add(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + size = x.size() + getitem = size[-1] + arange = torch.arange(getitem) + add = arange + 1 + return add + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + arange_result = z3.Const(5, tensor_type) + add_result = z3.Const(6, tensor_type) + self.assertEqual(s.model()[arange_result], s.model()[add_result]) + + + def test_regular_add_2(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + to = x.to() + size = to.size() + getitem = size[-1] + add = getitem + 1 + return add + + b = BasicBlock().forward(torch.rand(2, 4)) + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + res = z3.Int(5) + self.assertEqual(s.model()[res], b) + + + def test_regular_add_3(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + to = x.to() + size = to.size() + getitem = size[-1] + add = 1 + getitem + return add + + b = BasicBlock().forward(torch.rand(2, 4)) + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + res = z3.Int(5) + self.assertEqual(s.model()[res], b) + + def test_embedding(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + self.embedding = torch.nn.Embedding(256008, 1024, padding_idx=1) + + def forward(self, x: TensorType([2, 4])): + return self.embedding(x) + + B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long)).size() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + embedding_result = z3.Const(2, tensor_type) + + assert s.model()[embedding_result].arg(0).arg(1) == B[0] + assert s.model()[embedding_result].arg(1).arg(1) == B[1] + assert s.model()[embedding_result].arg(2).arg(1) == B[2] + + # change the type. This should still be satisfiable + for n in traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn]) + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + assert s.model()[embedding_result].arg(0).arg(0) == 0 + assert s.model()[embedding_result].arg(1).arg(0) == 0 + assert s.model()[embedding_result].arg(2).arg(1) == B[2] + + # change the type to Dyn. Here, we will get an arbitirary migration + for n in traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + + self.assertEquals(s.check(), z3.sat) + + + def test_size_two_args(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 2, Dyn])): + size = x.size(-1) + return size + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + d1, d2 = z3.Int(39), z3.Int(2) + d4, d5 = z3.Int('input_d1'), z3.Int('input_d2') + + # migrate the third dimension + s.add(d1 != 0) + + self.assertEqual(s.check(), z3.sat) + input = z3.Const(1, tensor_type) + s.add(input == tensor_type.tensor3(D(3, 39), D(1, 2), D(d4, d5))) + + # check if the item we got is the right one + self.assertEqual(s.check(), z3.sat) + self.assertEqual(s.model()[d5], s.model()[d2]) + self.assertEqual(s.model()[d1], s.model()[d4]) + + def test_size_getitem(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn): + size = x.size() + getitem = size[-1] + return getitem + + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + + self.assertEquals(s.check(), z3.sat) + + # force the input to be of size 4 + + s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') + s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + input = z3.Const(1, tensor_type) + s.add(input == tensor_type.tensor4(d1, d2, d3, d4)) + + # check if the model is still SAT + self.assertEquals(s.check(), z3.sat) + + s1, s2 = z3.Int(23), z3.Int(3) + + # check that the item is correct + self.assertEquals(s.model()[s1], s.model()[s2]) + + # invalid index but should still be SAT because input will be Dyn + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn): + size = x.size() + getitem = size[-10] + return getitem + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + + self.assertEquals(s.check(), z3.sat) + s.add(input != z3_dyn) + self.assertEqual(s.check(), z3.unsat) + + def test_view_mul(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) + + def forward(self, x: TensorType([2, 4])): + size = x.size() + getitem = size[-1] + view = x.view(-1, getitem) + embed_tokens = self.embed_tokens(view) + mul = embed_tokens * 32.0 + return mul + + + # print(B) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + # print(traced) + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + # print(s.model()) + + embedding_result = z3.Const(6, tensor_type) + + # note that the view output will be: tensor3(dim(0, 0), dim(1, 4), dim(1, 1024)) + # this is due to the reshape constraints. This can be lifted + # but would require revising the type rules accordingly so we leave it for now + assert (s.model()[embedding_result].arg(1).arg(1)) == 4 + assert (s.model()[embedding_result].arg(2).arg(1)) == 1024 + + mul_result = z3.Const(13, tensor_type) + assert s.model()[mul_result] == s.model()[embedding_result] + + def test_gt(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 4])): + size = x.size() + getitem_1 = size[-1] + gt = getitem_1 > 1 + return gt + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + res = z3.Bool(4) + self.assertEqual(s.model()[res], True) + + def test_view(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + view = x.view(-1, 8) + return view + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + def test_lt_tensor(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4]), y: Dyn): + lt = x > y + return lt + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + + def test_conditional(self): + """ + This test case is for the HFmodels interface. + A function takes a node and a graph and considers + the conditional the node represents and its negation + and solves each formula with the remaining sets of constraints + Returns: + + """ + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) + + def forward(self, x: TensorType([Dyn, 4])): + size = x.size() + getitem = size[-1] + view = x.view(-1, getitem) + embed_tokens = self.embed_tokens(view) + mul = embed_tokens * 32.0 + getitem_1 = size[-1] + gt = getitem_1 > 1 + return gt + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + + # The node we are considering is the gt node + for n in graph.nodes: + if n.target == operator.gt: + node = n + + positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) + self.assertEqual(positive, z3.sat) + self.assertEqual(negative, z3.unsat) + + # change the annotation to Dyn + for n in graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + # here, both should be SAT since the input is Dyn + positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) + + self.assertEqual(positive, z3.sat) + self.assertEqual(negative, z3.sat) + + + # change the annotation to TensorType[Dyn, Dyn] + for n in graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn]) + + # here, both should be SAT as well + positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) + + self.assertEqual(positive, z3.sat) + self.assertEqual(negative, z3.sat) + + + def test_conditional_2(self): + """ + This test case is for the HFmodels interface. + A function takes a node and a graph and considers + the conditional the node represents and its negation + and solves each formula with the remaining sets of constraints + Returns the opposite result of the above testcase + + """ + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) + + def forward(self, x: TensorType([Dyn, 4])): + size = x.size() + getitem = size[-1] + view = x.view(-1, getitem) + embed_tokens = self.embed_tokens(view) + mul = embed_tokens * 32.0 + getitem_1 = size[-1] + lt = getitem_1 < 1 + return lt + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + + # The node we are considering is the gt node + for n in graph.nodes: + if n.target == operator.lt: + node = n + + positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) + self.assertEqual(positive, z3.unsat) + self.assertEqual(negative, z3.sat) + + +class ComposeOperationsGradualTypes(unittest.TestCase): + + def test_masked_fill(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 4])): + size = x.size() + getitem = size[-1] + arange = torch.arange(getitem) + view = x.view(-1, getitem) + lt = arange > view + masked_fill = x.masked_fill_(lt, 0) + return masked_fill + + B = BasicBlock().forward(torch.rand(2, 4)) + # print(B.shape) + + symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) + # print(symbolic_traced) + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + masked_fill_res = z3.Const(10, tensor_type) + self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0]) + self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1]) + + # change the annotation to Dyn. This will migrate to an arbitirary type + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = Dyn + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) + + transformed = transform_all_constraints(symbolic_traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEqual(s.check(), z3.sat) + + def test_add_reshape_1(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: Dyn): + return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + def test_add_reshape_2(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: Dyn): + return torch.add(torch.reshape(x, (-1, 2)), torch.reshape(y, (2, 2, 2))) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + def test_conv_reshape_add_0(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn, y: Dyn): + return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.sat) + + + def test_conv_reshape_add_0_2(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn, y: TensorType([4, 1])): + return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + + # 4,1 + # 1, 2, 4, 8 + res = B.forward(torch.rand(20, 20), torch.rand(1, 2, 4, 8)).size() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.sat) + + + conv_result = z3.Const(4, tensor_type) + add_result = z3.Const(9, tensor_type) + input_2 = z3.Const(2, tensor_type) + + s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') + s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + + solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) + solver.check() + assert solver.model()[s1].as_long() == res[0] + assert solver.model()[s2].as_long() == res[1] + assert solver.model()[s3].as_long() == res[2] + assert solver.model()[s4].as_long() == res[3] + + solver.add(input_2 == tensor_type.tensor2(D(1, 4), D(1, 1))) + self.assertEquals(solver.check(), z3.sat) + solver.add(add_result == tensor_type.tensor4(d1, d2, d3, d4)) + self.assertEquals(solver.check(), z3.sat) + + # first dimension could be anything because we have broadcasting + assert solver.model()[s1] == res[0] + assert solver.model()[s2] == res[1] + assert solver.model()[s3] == res[2] + assert solver.model()[s4] == res[3] + + def test_conv_reshape_add_0_3(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn, y: TensorType([11, 1])): + return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.unsat) + + + def test_conv_reshape_add_1(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn, y: TensorType([1, 2, 10, 20])): + return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.unsat) + + +class GradualTypes(unittest.TestCase): + def test_conv_reshape_unsat(self): + + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn): + return self.conv1(torch.reshape(x, (1, 2, 10))) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.unsat) + + def test_conv_reshape0(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn): + return self.conv1(torch.reshape(x, (1, 2, 10, 20))) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + res = B.forward(torch.rand(20, 20)).size() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.sat) + conv_result = z3.Const(3, tensor_type) + + s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') + s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) + solver.check() + # print(solver.model()) + # print(type(solver.model()[s1])) + assert solver.model()[s1].as_long() == res[0] + assert solver.model()[s2].as_long() == res[1] + assert solver.model()[s3].as_long() == res[2] + assert solver.model()[s4].as_long() == res[3] + + s1, s2, s3, s4 = z3.Ints('y1 y2 y3 y4') + s11, s22, s33, s44 = z3.Ints('y11 y22 y33 y44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + input = z3.Const(1, tensor_type) + solver.add(input == tensor_type.tensor4(d1, d2, d3, d4)) + + # assert solver.check() == sat + # solver.add(s11 == 1) + # solver.add(s22 == 1) + # solver.add(s33 == 1) + # solver.add(s44 == 1) + # + # print(solver.check()) + # print(solver.model()) + + + def test_conv_reshape1(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: TensorType([20, 20])): + return self.conv1(torch.reshape(x, (1, -1, 10, 20))) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + res = B.forward(torch.rand(20, 20)).size() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.sat) + conv_result = z3.Const(3, tensor_type) + + s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') + s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) + solver.check() + # print(solver.model()) + assert solver.model()[s1].as_long() == res[0] + assert solver.model()[s2].as_long() == res[1] + assert solver.model()[s3].as_long() == res[2] + assert solver.model()[s4].as_long() == res[3] + + +class TestSingleOperation(unittest.TestCase): + def test_conv_dyn(self): + + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) + + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn): + return self.conv1(x) + + + BasicBlock(2, 2, 2, 2, 2, 2, 2).forward(torch.rand(4, 2, 3, 4)) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock(2, 2, 2, 2, 2, 2, 2)) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced) + + solver3 = z3.Solver() + solver3.add(transformed) + assert solver3.check() == z3.sat + + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + solver3.add(x == tensor_type.tensor4(d1, d2, d3, d4), + y == tensor_type.tensor4(b1, b2, b3, b4)) + + assert solver3.check() == z3.sat + assert solver3.model()[s1].as_long() == solver3.model()[e1].as_long() + assert solver3.model()[s11].as_long() == solver3.model()[e11].as_long() + + solver3.add(s2 != 2) + assert solver3.check() == z3.sat + assert solver3.model()[s22].as_long() == 0 + + solver3.add(s22 != 0) + self.assertEquals(solver3.check(), z3.unsat) + + solver2 = z3.Solver() + solver2.add(transformed) + assert solver2.check() == z3.sat + solver2.add(x == tensor_type.tensor3(d1, d2, d3)) + self.assertEquals(solver2.check(), z3.unsat) + + + def test_add(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: Dyn): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + # make the tensor be of size 1 + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s11))) + self.assertEquals(s.check(), z3.sat) + + y = z3.Const(2, tensor_type) + s.add(y == tensor_type.tensor1(D(1, s22))) + self.assertEquals(s.check(), z3.sat) + + s.add(s11 == 1) # tensor[1] + s.add(s22 == 2) # tensor[2] + self.assertEquals(s.check(), z3.sat) + + class BasicBlock2(torch.nn.Module): + def __init__(self): + super(BasicBlock2, self).__init__() + + def forward(self, x: TensorType((Dyn,)), y: Dyn): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock2()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + # make the tensor be of size 1 + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s11))) + self.assertEquals(s.check(), z3.sat) + y = z3.Const(2, tensor_type) + s.add(y == tensor_type.tensor1(D(1, s22))) + self.assertEquals(s.check(), z3.sat) + s.add(s11 == 4) # tensor[4] + s.add(s22 == 5) # tensor[5] + self.assertEquals(s.check(), z3.unsat) + + class BasicBlock3(torch.nn.Module): + def __init__(self): + super(BasicBlock3, self).__init__() + + def forward(self, x: TensorType((Dyn,)), y: Dyn): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock3()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced) + s = z3.Solver() + s.add(transformed) + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor2(d1, d2)) + self.assertEquals(s.check(), z3.unsat) + + def test_add_padding(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType((Dyn,)), y: TensorType((Dyn, Dyn))): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s1))) + + self.assertEquals(s.check(), z3.sat) + + # print(s.model()) + + def test_add_padding_2(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + # print(s.model()) + + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor2(D(1, s1), D(1, s2))) + self.assertEquals(s.check(), z3.sat) + + y = z3.Const(2, tensor_type) + s.add(y == tensor_type.tensor1(D(0, s3))) + self.assertEquals(s.check(), z3.sat) + + add_result = z3.Const(3, tensor_type) + broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(5, tensor_type) + + # print(s.model()) + + assert s.model()[broadcast_res1].decl() == tensor_type.tensor2 + assert s.model()[broadcast_res2].decl() == tensor_type.tensor2 + assert s.model()[add_result].decl() == tensor_type.tensor2 + assert s.model()[y].decl() == tensor_type.tensor1 + + # print(s.model()) + + # prevent broadcasting for that dimension + s.add(s2 > 1) + + assert s.check() + + # the second dimension of the result is a number, not Dyn. + # however if the first input dimension had been 1, we would + # have had dyn in the result, as seen in the next test case + assert s.model()[add_result].arg(1).arg(0).as_long() != 0 + + def test_add_padding_3(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, 1]), y: TensorType([Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + # print(transformed) + self.assertEquals(s.check(), z3.sat) + + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + s.add(s2 != 0) + s.add(x == tensor_type.tensor2(D(0, s1), D(s2, 1))) + s.add(y == tensor_type.tensor1(D(0, s3))) + + self.assertEquals(s.check(), z3.sat) + + # print(s.model()) + + add_result = z3.Const(3, tensor_type) + assert s.model()[add_result].arg(0).arg(0).as_long() == 0 + assert s.model()[add_result].arg(1).arg(0).as_long() == 0 + + + def test_add_padding_4(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 1]), y: TensorType([3])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + + self.assertEquals(s.check(), z3.sat) + + add_result = z3.Const(3, tensor_type) + assert s.model()[add_result] == tensor_type.tensor2(D(1, 2), D(1, 3)) + + def test_add_padding_5(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([2, 2]), y: TensorType([3])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.unsat) + + def test_add_size_3(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn, Dyn, Dyn]), y: TensorType([Dyn, Dyn, Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') + + s.add(x == tensor_type.tensor3(D(1, s1), D(1, 1), D(1, s2))) + s.add(y == tensor_type.tensor3(D(1, s3), D(1, s4), D(1, s5))) + + self.assertEquals(s.check(), z3.sat) + s.add(s2 == 5) + self.assertEquals(s.check(), z3.sat) + s.add(s5 == 6) + self.assertEquals(s.check(), z3.unsat) + + def test_add_padding_6(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') + + s.add(x == tensor_type.tensor1(D(1, s1))) + s.add(y == tensor_type.tensor3(D(1, s2), D(1, s3), D(1, s4))) + + self.assertEquals(s.check(), z3.sat) + + s.add(s1 == 4) + s.add(s4 == 5) + + self.assertEquals(s.check(), z3.unsat) + + def test_add_padding_7(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') + s.add(x == tensor_type.tensor2(D(s1, s2), D(s2, s3))) + self.assertEquals(s.check(), z3.unsat) + + + def test_add_padding_8(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') + s.add(x == tensor_type.tensor1(D(s1, 1))) + s.add(s1 >= 0) + + self.assertEquals(s.check(), z3.sat) + + s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(0, s5))) + self.assertEquals(s.check(), z3.sat) + + def test_add_padding_9(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: TensorType([Dyn, Dyn, Dyn, Dyn])): + return torch.add(x, y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced, counter=0) + s = z3.Solver() + s.add(transformed) + + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + s1, s2, s3, s4, s5, s6, s7 = z3.Ints('s1 s2 s3 s4 s5 s6 s7') + s.add(x == tensor_type.tensor1(D(s1, s7))) + s.add(s1 == 1) + self.assertEquals(s.check(), z3.sat) + + s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(s6, s5))) + self.assertEquals(s.check(), z3.sat) + + s.add(s6 == 1) + + self.assertEquals(s.check(), z3.sat) + s.add(s5 != 1, s7 != 1) + assert s.check() + + assert s.model()[s5].as_long() == s.model()[s7].as_long() + + def test_conv_static(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) + + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation) + + def forward(self, x: TensorType((1, 2, 10, 20))): + return self.conv1(x) + + ast_rewriter = RewritingTracer() + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + res = B.forward(torch.rand(1, 2, 10, 20)).size() + + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + new_transformed_c = transform_all_constraints(traced) + solver = z3.Solver() + solver.add(new_transformed_c) + self.assertEquals(solver.check(), z3.sat) + + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) + solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) + self.assertEquals(solver.check(), z3.sat) + # print(solver.model()) + assert solver.model()[e3].as_long() == res[2] + assert solver.model()[e4].as_long() == res[3] + + B2 = BasicBlock(2, 4, 5, 2, 9, 2, 2) + res2 = B2.forward(torch.rand(1, 2, 10, 20)).size() + + graph2 = ast_rewriter.trace(B2) + traced2 = GraphModule(ast_rewriter.root, graph2, "gm") + new_transformed_c = transform_all_constraints(traced2) + solver = z3.Solver() + solver.add(new_transformed_c) + + solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) + solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) + + self.assertEquals(solver.check(), z3.sat) + assert solver.model()[e3].as_long() == res2[2] + assert solver.model()[e4].as_long() == res2[3] + + def test_reshape_dyn(self): + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn): + return torch.reshape(x, (2, -1)) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + transformed = transform_all_constraints(traced) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s11))) + self.assertEquals(s.check(), z3.sat) + s.add(z3.Or([s11 == 2, s11 == 4, s11 == 9])) + self.assertEquals(s.check(), z3.sat) + s.add(s11 == 9) + self.assertEquals(s.check(), z3.unsat) + + + def test_reshape_annotated(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn])): + return torch.reshape(x, (2, -1)) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor2(d1, d2)) + self.assertEquals(s.check(), z3.unsat) + + def test_reshape_static_target(self): + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: TensorType([Dyn])): + return torch.reshape(x, (2, 3)) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced) + # print(transformed) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s11))) + s.check() + assert s.model()[s11].as_long() == 6 + s.add(s11 != 6) + self.assertEquals(s.check(), z3.unsat) + + def test_reshape_static_target2(self): + s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn): + return torch.reshape(x, (2, 3, 1, 1)) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + transformed = transform_all_constraints(traced) + s = z3.Solver() + s.add(transformed) + self.assertEquals(s.check(), z3.sat) + x = z3.Const(1, tensor_type) + s.add(x == tensor_type.tensor1(D(1, s11))) + s.check() + assert s.model()[s11].as_long() == 6 + s.add(s11 != 6) + self.assertEquals(s.check(), z3.unsat) + + + def test_conv2D_maxpool2d_flatten(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + self.conv1 = torch.nn.Conv2d(3, 6, 5) + self.pool = torch.nn.MaxPool2d(2, 2) + self.conv2 = torch.nn.Conv2d(6, 16, 5) + self.fc1 = torch.nn.Linear(5, 120) + self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) + + def forward(self, x : TensorType((4, 3, 32, 32))): + out = self.conv1(x) + out = self.pool(out) + out = self.conv2(out) + out = self.pool(out) + out = self.fc1(out) + out = self.pool2(out) + out = torch.flatten(out, 1) + return out + + B = BasicBlock() + res = B.forward(torch.rand(4, 3, 32, 32)).shape + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + solver.check() + input = z3.Const(1, tensor_type) + solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 32))) + solver.check() + output = z3.Const(48, tensor_type) + assert solver.model()[output].arg(0).arg(1) == res[0] + assert solver.model()[output].arg(1).arg(1) == res[1] + + def test_conv2D_maxpool2d_flatten_unsat(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + self.conv1 = torch.nn.Conv2d(3, 6, 5) + self.pool = torch.nn.MaxPool2d(2, 2) + self.conv2 = torch.nn.Conv2d(6, 16, 5) + self.fc1 = torch.nn.Linear(5, 120) + self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) + + def forward(self, x : TensorType((4, 3, 32, 32))): + out = self.conv1(x) + out = self.pool(out) + out = self.conv2(out) + out = self.pool(out) + out = self.fc1(out) + out = self.pool2(out) + out = torch.flatten(out, 1) + return out + + B = BasicBlock() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + solver.check() + input = z3.Const(1, tensor_type) + solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 45))) + self.assertEquals(solver.check(), z3.unsat) + + def test_conv2D_maxpool2d_flatten_dyn(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + self.conv1 = torch.nn.Conv2d(3, 6, 5) + self.pool = torch.nn.MaxPool2d(2, 2) + self.conv2 = torch.nn.Conv2d(6, 16, 5) + self.fc1 = torch.nn.Linear(5, 120) + self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) + + def forward(self, x : TensorType((Dyn, 3, 32, 32))): + out = self.conv1(x) + out = self.pool(out) + out = self.conv2(out) + out = self.pool(out) + out = self.fc1(out) + out = self.pool2(out) + out = torch.flatten(out, 1) + return out + + B = BasicBlock() + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + + def test_type_check_flatten(self): + s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') + + class M(torch.nn.Module): + def forward(self, x: TensorType([2, 3, 4, 5])): + return torch.flatten(x, start_dim=1, end_dim=3) + + module = M() + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + flatten = z3.Const(2, tensor_type) + + res = M().forward(torch.rand(2, 3, 4, 5)).size() + assert solver.model()[flatten].arg(0).arg(1) == res[0] + assert solver.model()[flatten].arg(1).arg(1) == res[1] + + class M(torch.nn.Module): + def forward(self, x: TensorType([2, 3, Dyn, 5])): + return torch.flatten(x, start_dim=1, end_dim=3) + + module = M() + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + x = z3.Const(1, tensor_type) + y = z3.Const(2, tensor_type) + + solver.add(x == tensor_type.tensor4(D(1, 2), D(1, 3), D(0, s1), D(1, 5))) + self.assertEquals(solver.check(), z3.sat) + assert solver.model()[y].arg(1).arg(0) == 0 + + + class M(torch.nn.Module): + def forward(self, x: TensorType([2, 3, Dyn])): + return torch.flatten(x, 10, 0) + + module = M() + # print(module.forward(torch.rand(2,3,5)).shape) + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.unsat) + +class ConstraintGeneration(unittest.TestCase): + + def test_add_reshape(self): + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + def forward(self, x: Dyn, y: Dyn): + return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(BasicBlock()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(0) + assert len(new_constraints.conjucts) == 11 + + + def test_conv_reshape_add(self): + class BasicBlock(torch.nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): + super(BasicBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, + padding=padding, groups=groups, bias=False, dilation=dilation) + + def forward(self, x: Dyn, y: Dyn): + return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) + + B = BasicBlock(2, 2, 2, 3, 2, 2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(0) + assert len(new_constraints.conjucts) == 16 + + +class TestInternalConstraints(unittest.TestCase): + def test_precision(self): + + c1 = BinConstraintT(Dyn, TVar('x'), op_precision) + transformed, _ = transform_constraint(c1, 0) + assert transformed == T() + + c2 = BinConstraintT(TensorType([1, Dyn, 3]), TVar('x'), op_precision) + transformed, counter = transform_constraint(c2, 0) + assert len(transformed.conjucts) == 7 + + def test_matching(self): + c1 = BinConstraintT(TVar('x'), + TensorType([DVar('a'), DVar('b'), DVar('c'), DVar('d')]), op_matching) + transformed, _ = transform_constraint(c1, 0) + assert len(transformed.disjuncts) == 2 + + def test_consistency(self): + c1 = BinConstraintT(TVar('x'), + TensorType([DVar('a'), DVar('b')]), op_consistency) + transformed, count = transform_constraint(c1, 0) + + assert len(transformed.disjuncts) == 5 + transformed, count = transform_constraint(transformed, count) + assert len(transformed.disjuncts) == 5 + + # def test_apply_broadcasting(self): + # c1 = ApplyBroadcasting(TVar(1), TVar(2), TVar(3), TVar(4)) + # transformed, count = transform_apply_broadcasting(c1, 5) + # assert len(transformed.conjucts) == 41 + +@skipIfNoTorchVision +class TestResNet(unittest.TestCase): + + def test_resnet50_unsat(self): + traced = symbolic_trace(models.resnet50()) + for n in traced.graph.nodes: + n.type = Dyn + + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + input = z3.Const(1, tensor_type) + # input with 3 dimensions + solver.add(input == tensor_type.tensor3(D(1, 1), D(1, 3), D(1, 224))) + self.assertEquals(solver.check(), z3.unsat) + + + + def test_resnet50(self): + traced = symbolic_trace(models.resnet50()) + for n in traced.graph.nodes: + n.type = Dyn + + sample_input = torch.randn(1, 3, 224, 224) + res = models.resnet50().forward(sample_input).size() + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + linear = z3.Const(650, tensor_type) + + input = z3.Const(1, tensor_type) + solver.add(input == tensor_type.tensor4(D(1, 1), D(1, 3), D(1, 224), D(1, 224))) + self.assertEquals(solver.check(), z3.sat) + assert solver.model()[linear] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) + + def test_resnet502(self): + traced = symbolic_trace(models.resnet50()) + for n in traced.graph.nodes: + n.type = Dyn + + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + linear = z3.Const(650, tensor_type) + input = z3.Const(1, tensor_type) + batch = z3.Int('b') + solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) + solver.add(batch > 4) + solver.check() + assert solver.model()[batch] == solver.model()[linear].arg(0).arg(1) + + def test_resnet503(self): + traced = symbolic_trace(models.resnet50()) + for n in traced.graph.nodes: + n.type = Dyn + + constraints = transform_all_constraints(traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + linear = z3.Const(650, tensor_type) + input = z3.Const(1, tensor_type) + batch, d1, d2 = z3.Ints('b d1 d2') + solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) + solver.add(linear == tensor_type.tensor2(D(1, d1), D(1, d2))) + self.assertEquals(solver.check(), z3.sat) + solver.add(batch != d1) + self.assertEquals(solver.check(), z3.unsat) + +@skipIfNoTorchVision +class TestAlexNet(unittest.TestCase): + def test_alexnet1(self): + + alexnet = models.alexnet() + symbolic_traced : torch.fx.GraphModule = symbolic_trace(alexnet) + + for n in symbolic_traced.graph.nodes: + n.type = Dyn + + # print(symbolic_traced) + + res = alexnet.forward(torch.rand(10, 3, 227, 227)).size() + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + input = z3.Const(1, tensor_type) + conv = z3.Const(2, tensor_type) + solver.add(input == tensor_type.tensor4(D(1, 10), D(1, 3), D(1, 227), D(1, 227))) + self.assertEquals(solver.check(), z3.sat) + assert solver.model()[conv] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) + + relu = z3.Const(7, tensor_type) + assert solver.model()[relu] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) + + maxpool = z3.Const(8, tensor_type) + assert solver.model()[maxpool] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 27), D(1, 27)) + + maxpool2 = z3.Const(42, tensor_type) + assert solver.model()[maxpool2] == tensor_type.tensor4(D(1, 10), D(1, 256), D(1, 6), D(1, 6)) + + flatten = z3.Const(52, tensor_type) + assert solver.model()[flatten] == tensor_type.tensor2(D(1, 10), D(1, 9216)) + + linear = z3.Const(64, tensor_type) + assert solver.model()[linear] == tensor_type.tensor2(D(1, 10), D(1, 4096)) + + linear2 = z3.Const(109, tensor_type) + assert solver.model()[linear2] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) + + + def test_alexnet2(self): + alexnet = models.alexnet() + symbolic_traced : torch.fx.GraphModule = symbolic_trace(alexnet) + + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, 4, 227, 227]) + + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.unsat) + + def test_alexnet3(self): + alexnet = models.alexnet() + symbolic_traced : torch.fx.GraphModule = symbolic_trace(alexnet) + + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn, 227, 227]) + + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.sat) + + def test_alexnet4(self): + alexnet = models.alexnet() + symbolic_traced : torch.fx.GraphModule = symbolic_trace(alexnet) + + for n in symbolic_traced.graph.nodes: + if n.op == 'placeholder': + n.type = TensorType([Dyn, Dyn, 227]) + + constraints = transform_all_constraints(symbolic_traced, counter=0) + solver = z3.Solver() + solver.add(constraints) + self.assertEquals(solver.check(), z3.unsat) + + + +if __name__ == '__main__': + unittest.main() diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 09a58b3cd735f..d01063a65a3b7 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -15,7 +15,7 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase, make_global import torch.testing._internal.jit_utils -from torch.testing._internal.common_utils import IS_SANDCASTLE +from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo from typing import List, Tuple, Iterable, Optional, Dict if __name__ == '__main__': @@ -505,6 +505,7 @@ def fun(x: Any): with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""): sc = torch.jit.script(fun) + @skipIfTorchDynamo("Test does not work with TorchDynamo") @unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode") def test_imported_classes(self): import jit._imported_class_test.foo @@ -1230,6 +1231,7 @@ def test_function(a: int) -> 'ClassWithClassMethod': self.checkScript(test_function, (1,)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_properties(self): """ Test that a scripted class can make use of the @property decorator. diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py index 120f21d920a81..a151756d598f5 100644 --- a/test/jit/test_cuda.py +++ b/test/jit/test_cuda.py @@ -89,7 +89,6 @@ def test_multi_device_synchronize(): FileCheck().check("cuda::synchronize(") \ .run(test_multi_device_synchronize.graph) - @skipIfRocm def test_stream_args(self): # Test stream creation with default arguments @torch.jit.script @@ -119,7 +118,6 @@ def stream_args_all() -> bool: self.assertTrue(stream_default_args_for_priority) self.assertTrue(stream_args_all) - @skipIfRocm def test_event_args(self): # Test Event creation with default arguments @torch.jit.script diff --git a/test/jit/test_custom_operators.py b/test/jit/test_custom_operators.py index feb3b8eb8fb67..6d1fd07fe6c8f 100644 --- a/test/jit/test_custom_operators.py +++ b/test/jit/test_custom_operators.py @@ -123,3 +123,8 @@ def func(x): def test_generic_list(self): self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello') + + # https://github.com/pytorch/pytorch/issues/80508 + def test_where_no_scalar(self): + x = torch.rand(1, 3, 224, 224) + torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise diff --git a/test/jit/test_dtype_analysis.py b/test/jit/test_dtype_analysis.py index 4ccefd8152b7d..af1a7f3b24f28 100644 --- a/test/jit/test_dtype_analysis.py +++ b/test/jit/test_dtype_analysis.py @@ -353,6 +353,7 @@ def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False): # Run the Dtype Analysis graph = traced_fn.graph # Note this is a cached graph input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)] + input_tensors += [v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)] self.prop_dtype_on_graph(graph, input_tensors) self.assert_output_dtype_equal(expected_res, graph) diff --git a/test/jit/test_hash.py b/test/jit/test_hash.py index cb1c1544b10a4..2ca1e9cda0a0e 100644 --- a/test/jit/test_hash.py +++ b/test/jit/test_hash.py @@ -75,7 +75,10 @@ def fn(f1: float, f2: float): self.checkScript(fn, (1.2345, float("inf"))) self.checkScript(fn, (float("inf"), float("inf"))) self.checkScript(fn, (1.2345, float('nan'))) - self.checkScript(fn, (float("nan"), float("nan"))) + if sys.version_info < (3, 10): + # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html : + # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity. + self.checkScript(fn, (float("nan"), float("nan"))) self.checkScript(fn, (float("nan"), float("inf"))) def test_hash_int(self): diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index d4c014e595dfc..aca95b8f62a42 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -18,6 +18,7 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_utils import skipIfTorchDynamo if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" @@ -1397,6 +1398,7 @@ def dict2(self): def dict_bool(self): return {True: 1} + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_dict_bool_conversion(self): def if_predicate(d: Dict[int, int]): if d: @@ -1425,6 +1427,7 @@ def ternary_predicate(d: Dict[int, int]): self.checkScript(ternary_predicate, ({1: 2, 3: 5},)) self.checkScript(ternary_predicate, ({},)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_del(self): def inputs(): return {'hi': 2, 'bye': 3} @@ -1443,6 +1446,7 @@ def fn(x: Dict[str, int]) -> Dict[str, int]: with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x['hi']"): self.checkScript(fn, [{}]) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_dict_variance(self): """ `Dict[T1, _]` is not a subtype of `Dict[T2, _]`, even if `T1` is @@ -1497,6 +1501,7 @@ def test_dicts_with_different_value_types_are_invariant_recursive(self): r"Dict\[str, int\]\]"): torch.jit.script(test_dicts_with_different_value_types_are_invariant_recursive) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_keys(self): @torch.jit.script def keys(x: Dict[str, Tensor]) -> List[str]: @@ -1512,6 +1517,7 @@ def specialized_list(): self.assertTrue(set(specialized_list()) == set([1, 2, 3])) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_values(self): @torch.jit.script def values(x: Dict[str, Tensor]) -> List[Tensor]: @@ -1520,18 +1526,21 @@ def values(x: Dict[str, Tensor]) -> List[Tensor]: the_dict = self.dict() self.assertEqual(set(values(the_dict)), set(the_dict.values())) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_len(self): def length(x: Dict[str, Tensor]) -> int: return len(x) self.checkScript(length, (self.dict(),)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_copy(self): def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x.copy() self.checkScript(func, (self.dict(),)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_items(self): def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: return x.items() @@ -1547,6 +1556,7 @@ def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: for item in eager_out: self.assertTrue(item in script_out) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_pop(self): def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key), x @@ -1570,6 +1580,7 @@ def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor tester(default_pop, 'a', torch.randn(2, 2)) tester(default_pop, 'x', torch.randn(2, 2)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_setdefault(self): def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]: x.setdefault(key, default) @@ -1578,6 +1589,7 @@ def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Ten self.checkScript(setdefault, (self.dict(), 'a', torch.randn(2, 2))) self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2))) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_update(self): def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: a.update(b) @@ -1586,6 +1598,7 @@ def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor self.checkScript(update, (self.dict(), self.dict())) self.checkScript(update, (self.dict(), self.dict2())) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_update_existing_key(self): def foo() -> Dict[str, int]: a: Dict[str, int] = {} @@ -1595,6 +1608,7 @@ def foo() -> Dict[str, int]: self.checkScript(foo, ()) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_aug_assign(self): def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: a['a'] += 1 @@ -1615,6 +1629,7 @@ def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: self.checkScript(aug_assign_dict_tensor, (self.dict(),)) self.checkScript(aug_assign_dict_prim, ({'a': 3.0, 'b': 2.0, 'c': 4.0},)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_popitem(self): @torch.jit.script def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: @@ -1634,6 +1649,7 @@ def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor] self.assertTrue(isinstance(script_out[0][0], str)) self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_clear(self): def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: x.clear() @@ -1641,6 +1657,7 @@ def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: self.checkScript(clear, (self.dict(),)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_get(self): def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key) @@ -1654,6 +1671,7 @@ def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_get_boolkey(self): def get(x: Dict[bool, int], key: bool) -> Optional[int]: return x.get(key) @@ -1667,6 +1685,7 @@ def get_default(x: Dict[bool, int], key: bool) -> int: self.checkScript(get_default, (self.dict_bool(), True)) self.checkScript(get_default, (self.dict_bool(), False)) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_basic(self): def simple(x: Dict[str, int]) -> Dict[str, int]: return x @@ -1711,6 +1730,7 @@ def list_of_dicts() -> List[Dict[str, Tensor]]: self.checkScript(list_of_dicts, ()) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_mutability(self): @torch.jit.script def fn() -> Dict[str, int]: @@ -1720,12 +1740,14 @@ def fn() -> Dict[str, int]: self.assertEqual(fn(), {'ok': 10}) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_key_type(self): with self.assertRaisesRegexWithHighlight(RuntimeError, "but instead found type", "a[None]"): @torch.jit.script def fn(a: Dict[str, int]) -> int: return a[None] + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_loop(self): @torch.jit.script def fn(x: int) -> Dict[str, int]: @@ -1736,6 +1758,7 @@ def fn(x: int) -> Dict[str, int]: self.assertEqual(fn(10), {'ok': 9}) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_view(self): def fn(x, y): l = {"a": x} @@ -1746,6 +1769,7 @@ def fn(x, y): return a == b self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_membership(self): def fn(x: Dict[int, int], y: int) -> int: return x.get(y, 3) @@ -1766,6 +1790,7 @@ def optional(x: Dict[int, int], y: int) -> bool: def bad_types(x: Dict[int, int], y: int) -> int: return x.get(y) # noqa: T484 + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_dict_to_python(self): @torch.jit.ignore def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: @@ -1777,6 +1802,7 @@ def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} self.checkScript(fn, (a_dict, ('a', 'c'))) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_ordered_dict(self): def test_func(fn, inputs): self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) @@ -1817,6 +1843,7 @@ def test_dict_error(): with self.assertRaisesRegexWithHighlight(Exception, "Arguments for call are not", "a[1] = 2"): torch.jit.script(test_dict_error) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_type_annotation_missing_contained_type(self): """ Test that the use of a Dict type annotation without contained @@ -1844,6 +1871,7 @@ def annotated_fn(input: Dict) -> Any: with self.assertRaisesRegex(RuntimeError, r"Attempted to use Dict without contained types"): m = torch.jit.script(annotated_fn) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_dict_preserves_order(self): def dict_ordering(): a : Dict[int, int] = {} @@ -1858,6 +1886,7 @@ def dict_ordering(): key, value = res[i] self.assertTrue(key == i and value == i + 1) + @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_optional_dict_construct(self): class M(torch.nn.Module): def use(self, buffer: Dict[str, Optional[torch.Tensor]]): diff --git a/test/jit/test_parametrization.py b/test/jit/test_parametrization.py index 6ef1aa6c01462..8882a431f2337 100644 --- a/test/jit/test_parametrization.py +++ b/test/jit/test_parametrization.py @@ -46,14 +46,14 @@ def test_traceable(self): def test_scriptable(self): # TODO: Need to fix the scripting in parametrizations - # Currently, all the tests below will throw UnsupportedNodeError + # Currently, all the tests below will throw torch.jit.Error model = nn.Linear(5, 5) parametrize.register_parametrization(model, "weight", self.Symmetric()) x = torch.randn(3, 5) y = model(x) - with self.assertRaises(torch.jit.frontend.UnsupportedNodeError): + with self.assertRaises(torch.jit.Error): # Check scripting works scripted_model = torch.jit.script(model) y_hat = scripted_model(x) diff --git a/test/jit/test_schema_check.py b/test/jit/test_schema_check.py index b9d108ac6bc12..306724ef29386 100644 --- a/test/jit/test_schema_check.py +++ b/test/jit/test_schema_check.py @@ -5,9 +5,9 @@ import torch from torch.utils._pytree import tree_map - +from torch.fx.operator_schemas import normalize_function from torch.testing._internal.schema_check_mode import SchemaCheckMode -from torch.utils._python_dispatch import enable_torch_dispatch_mode +from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode from torch.testing._internal.jit_utils import JitTestCase pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -22,7 +22,9 @@ # which is then used to test that SchemaCheckMode behaves as expected class IncorrectAliasTensor(torch.Tensor): - INCORRECT_OPS = {"aten::add", "aten::sub"} + ALIAS_ARG_OUT = {"aten::add"} + ALIAS_OUT_OUT = {"aten::aminmax"} + MUTATE_ARGS_OUT = {"aten::sub"} elem: torch.Tensor @@ -58,8 +60,14 @@ def wrap(e): return cls(e) if isinstance(e, torch.Tensor) else e unwrapped_args = tree_map(unwrap, args) out = func(*unwrapped_args, **tree_map(unwrap, kwargs)) - if func._schema.name in IncorrectAliasTensor.INCORRECT_OPS: + if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT: args[0].elem = out + if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT: + args[0].elem = torch.rand(args[0].elem.shape) + if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT: + incorrect_out = list(out) + incorrect_out[0] = incorrect_out[1] + return tree_map(wrap, tuple(incorrect_out)) return tree_map(wrap, out) @@ -71,18 +79,133 @@ def test_schema_check_mode_operator_order(self): with enable_torch_dispatch_mode(schema_check): x = torch.rand((3, 3), requires_grad=True) x.relu().sin() - self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) + self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) # Tests that SchemaCheckMode records operator order without grad - def test_schema_check_tensor_operator_order_without_grad(self): + def test_schema_check_mode_operator_order_without_grad(self): schema_check = SchemaCheckMode() with enable_torch_dispatch_mode(schema_check): x = torch.rand((3, 3), requires_grad=False) x.relu().sin() - self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) + self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) + + # Tests that SchemaCheckMode records mutations and aliases with none expected + def test_schema_check_mode_mutated_aliasing_none(self): + x = torch.rand((3, 3), requires_grad=True) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + actual = x.relu().sin() + self.assertEqual([], schema_check.mutated) + self.assertEqual([], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with mutation expected + def test_schema_check_mode_mutated_aliasing_mutation(self): + actual = torch.rand((3, 3), requires_grad=False) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + actual.sinh_() + self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated) + self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with resize_ + def test_schema_check_mode_mutated_aliasing_resize_(self): + actual = torch.rand((3, 3), requires_grad=False) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + actual.resize_(9) + self.assertEqual([('aten::resize_', 'input')], schema_check.mutated) + self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs + def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): + actual = torch.rand((3, 3)) + y = actual + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + actual.add_(y) + self.assertEqual( + [ + ('aten::add_', 'input'), + ('aten::add_', 'other') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::add_', 'input', 'output_0'), + ('aten::add_', 'other', 'output_0') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and alias with as_strided + def test_schema_check_mode_mutated_aliasing_as_strided(self): + x = torch.rand((3, 6, 4)) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + x.as_strided_([3, 6, 4], [9, 1, 1]) + self.assertEqual( + [ + ('aten::as_strided_', 'input') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::as_strided_', 'input', 'output_0') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and aliases with multiple outputs + def test_schema_check_mode_mutated_aliasing_multiple_outputs(self): + x = torch.arange(9.) + m_actual = torch.arange(9.) + e_actual = torch.zeros([9], dtype=torch.int32) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + torch.frexp(x, out=(m_actual, e_actual)) + self.assertEqual( + [ + ('aten::frexp', 'mantissa'), + ('aten::frexp', 'exponent') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::frexp', 'mantissa', 'output_0'), + ('aten::frexp', 'exponent', 'output_1') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs + def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self): + x = torch.rand((3, 3)) + actual = torch.zeros(3) + schema_check = SchemaCheckMode() + with enable_torch_dispatch_mode(schema_check): + torch.aminmax(x, dim=0, out=[actual, actual]) + self.assertEqual( + [ + ('aten::aminmax', 'min'), + ('aten::aminmax', 'max') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::aminmax', 'min', 'output_0'), + ('aten::aminmax', 'min', 'output_1'), + ('aten::aminmax', 'max', 'output_0'), + ('aten::aminmax', 'max', 'output_1') + ], + schema_check.aliasing + ) # Tests that SchemaCheckMode wraps torch.Tensor - def test_schema_check_tensor_functionality(self): + def test_schema_check_mode_functionality(self): x = torch.rand((3, 3), requires_grad=True) expected = x.relu().sin() with enable_torch_dispatch_mode(SchemaCheckMode()): @@ -90,7 +213,7 @@ def test_schema_check_tensor_functionality(self): self.assertEqual(expected, actual) # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overriden - def test_schema_check_tensor_functionality_default_replaced(self): + def test_schema_check_mode_functionality_default_replaced(self): x = torch.rand((3, 3), requires_grad=True) expected = x.add(x, alpha=2) with enable_torch_dispatch_mode(SchemaCheckMode()): @@ -98,7 +221,7 @@ def test_schema_check_tensor_functionality_default_replaced(self): self.assertEqual(expected, actual) # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument - def test_schema_check_tensor_functionality_list_input(self): + def test_schema_check_mode_functionality_list_input(self): a = torch.rand((3, 3)) b = torch.rand((3, 3)) c = torch.rand((3, 3)) @@ -107,8 +230,16 @@ def test_schema_check_tensor_functionality_list_input(self): actual = torch.linalg.multi_dot([a, b, c]) self.assertEqual(expected, actual) + # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation + def test_schema_check_mode_functionality_wildcard_after(self): + x = torch.rand((3, 3)) + expected = x.chunk(6) + with enable_torch_dispatch_mode(SchemaCheckMode()): + actual = x.chunk(6) + self.assertEqual(expected, actual) + # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input - def test_schema_check_tensor_functionality_kwarg_tensor(self): + def test_schema_check_mode_functionality_kwarg_tensor(self): x = torch.rand((3, 5)) w = torch.rand((4)) expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True) @@ -117,7 +248,7 @@ def test_schema_check_tensor_functionality_kwarg_tensor(self): self.assertEqual(expected, actual) # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op - def test_schema_check_tensor_functionality_mutable_inputs(self): + def test_schema_check_mode_functionality_mutable_inputs(self): expected = torch.rand((3, 3), requires_grad=False) actual = torch.clone(expected) expected.sinh_() @@ -125,30 +256,97 @@ def test_schema_check_tensor_functionality_mutable_inputs(self): actual.sinh_() self.assertEqual(expected, actual) + # Tests that SchemaCheckMode wraps Torch.tensor when inputs alias + def test_schema_check_mode_functionality_aliasing_inputs(self): + expected = torch.rand((3, 3)) + x = expected + actual = torch.clone(expected) + y = actual + expected.add_(x) + with enable_torch_dispatch_mode(SchemaCheckMode()): + actual.add_(y) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with multiple tensor outputs + def test_schema_check_mode_functionality_with_multiple_outputs(self): + x = torch.arange(9.) + m_expected, e_expected = torch.frexp(x) + m_actual = torch.arange(9.) + e_actual = torch.zeros([9], dtype=torch.int32) + with enable_torch_dispatch_mode(SchemaCheckMode()): + torch.frexp(x, out=(m_actual, e_actual)) + self.assertEqual(m_expected, m_actual) + self.assertEqual(e_expected, e_actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with aliasing ouputs due to aliasing inputs + def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self): + x = torch.rand((3, 3)) + actual = torch.zeros(3) + with enable_torch_dispatch_mode(SchemaCheckMode()): + torch.aminmax(x, dim=0, out=[actual, actual]) + self.assertEqual(torch.amax(x, dim=0), actual) + + # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input + def test_schema_check_mode_functionality_device_input(self): + with enable_torch_dispatch_mode(SchemaCheckMode()): + x = torch.rand((3, 3), device="cpu", dtype=torch.double) + y = x + x + self.assertEqual(x + x, y) + + # Tests that SchemaCheckMode wraps Torch.tensor in special training op edge case + def test_schema_check_mode_functionality_training_op(self): + x = torch.rand((3, 3), requires_grad=True) + batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + expected = batch(x) + with enable_torch_dispatch_mode(SchemaCheckMode()): + actual = batch(x) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with nested training op edge case + def test_schema_check_mode_functionality_nested_training_op(self): + actual = torch.rand((3, 3)) + batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + expected = torch.clone(actual) + expected.sinh_() + expected.tanh_() + expected.relu_() + expected = batch(expected) + + with enable_torch_dispatch_mode(SchemaCheckMode()): + actual.sinh_() + actual.tanh_() + actual.relu_() + actual = batch(actual) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with empty list input + def test_schema_check_mode_empty_list_input(self): + expected = torch.atleast_1d([]) + with enable_torch_dispatch_mode(SchemaCheckMode()): + actual = torch.atleast_1d([]) + self.assertEqual(expected, actual) + # Tests that an exception is raised for a mismatching mutation def test_mutation_check_fail(self): - with self.assertRaisesRegex(RuntimeError, "Argument running_mean is not defined as mutable but was mutated"): - x = torch.rand((3, 3), requires_grad=True) - batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + x = torch.rand((3, 3)) + y = torch.rand((3, 3)) with enable_torch_dispatch_mode(SchemaCheckMode()): - batch(x) + IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y)) - # Tests that an exception is raised for a mismatching mutation over multiple ops + # # Tests that an exception is raised for a mismatching mutation over multiple ops def test_mutation_check_fail_multiple_operators(self): - with self.assertRaisesRegex(RuntimeError, "Argument running_mean is not defined as mutable but was mutated"): - x = torch.rand((3, 3), requires_grad=True) - batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + x = torch.rand((3, 3)) + y = torch.rand((3, 3)) with enable_torch_dispatch_mode(SchemaCheckMode()): - x = x.sinh() - x = x.tanh() - x = x.relu() - batch(x) + IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y)) # Tests that an exception is raised for a mismatching alias - def test_alias_check_fail(self): + def test_alias_check_fail_simple(self): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): x = torch.rand((3, 3), requires_grad=True) - y = torch.zeros((3, 3)) + y = torch.rand((3, 3)) with enable_torch_dispatch_mode(SchemaCheckMode()): IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2) @@ -168,10 +366,80 @@ def test_alias_check_fail_multiple_operators_centered(self): with enable_torch_dispatch_mode(SchemaCheckMode()): IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu() - # Tests that isAliasOf returns as expected - def test_is_alias_of(self): + # Tests that an exception is raised for a centered mismatching alias over multiple ops + def test_alias_check_fail_outputs_unexpectedly_aliasing(self): + with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"): + x = torch.rand((3, 3)) + s = SchemaCheckMode() + with enable_torch_dispatch_mode(s): + IncorrectAliasTensor(x).aminmax(dim=0) + + # Tests that is_alias_of returns as expected + def test_is_alias_of_basic(self): x = torch.rand((3, 3), requires_grad=True) y = torch.rand((3, 3), requires_grad=True) y = x.add(x, alpha=2) self.assertTrue(torch._C._is_alias_of(x, x)) self.assertFalse(torch._C._is_alias_of(x, y)) + + # Tests that is_alias_of returns as expected with empty containers + def test_is_alias_of_empty_container(self): + x = [] + y = torch.rand((3, 3), requires_grad=True) + self.assertFalse(torch._C._is_alias_of(x, x)) + self.assertFalse(torch._C._is_alias_of(x, y)) + + # Tests that overlaps returns as expected + def test_overlaps_basic(self): + x = torch.rand((3, 3), requires_grad=True) + y = torch.rand((3, 3), requires_grad=True) + z = [x, y] + self.assertTrue(torch._C._overlaps(x, x)) + self.assertFalse(torch._C._overlaps(x, y)) + self.assertTrue(torch._C._overlaps(z, x)) + self.assertTrue(torch._C._overlaps(z, y)) + + # Tests that overlaps returns correctly with empty containers + def test_overlaps_empty_container(self): + x = [] + y = [torch.rand((3, 3), requires_grad=True)] + # Anything overlaps nothing + self.assertTrue(torch._C._overlaps(y, x)) + self.assertTrue(torch._C._overlaps(y, y)) + + # Tests that SchemaInfo Bindings work as expected + def test_schema_info_bind_basic(self): + class SchemaInfoBindTestMode(TorchDispatchMode): + def __init__(self, test_self): + self.test_self = test_self + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + named_arg_list = normalize_function( + func, + args, + kwargs, + normalize_to_only_use_kwargs=True + ).kwargs + schema_info_value_test = torch._C._SchemaInfo(func._schema) + schema_info_values_test = torch._C._SchemaInfo(func._schema) + self.test_self.assertFalse(schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertFalse(schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + for i in named_arg_list: + schema_info_value_test.add_argument_value(i, named_arg_list[i]) + schema_info_values_test.add_argument_values(named_arg_list) + self.test_self.assertTrue(schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertTrue(schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + + return func(*args, **kwargs) + x = torch.rand((3, 3)) + schemaInfoCheck = SchemaInfoBindTestMode(self) + with enable_torch_dispatch_mode(schemaInfoCheck): + x.add(x) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 99d078dd4ad18..50fdec94b9fc0 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -17,7 +17,7 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.common_utils import suppress_warnings, \ skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \ - IS_SANDCASTLE, TemporaryFileName, skipIfCrossRef + IS_SANDCASTLE, TemporaryFileName, skipIfCrossRef, skipIfTorchDynamo from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \ _tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, \ RUN_CUDA_MULTI_GPU, make_global @@ -1795,6 +1795,7 @@ def foo(bar, baz): assert 'baz' in graph_str assert 'quick_brown_fox' in graph_str + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_tracing_hooks(self): class Net(nn.Module): def __init__(self): diff --git a/test/lazy/test_reuse_ir.py b/test/lazy/test_reuse_ir.py index 5621b7364a69f..2d19fe1a5b539 100644 --- a/test/lazy/test_reuse_ir.py +++ b/test/lazy/test_reuse_ir.py @@ -121,7 +121,7 @@ def testBatchNorm(self): torch._lazy.mark_step() torch.testing.assert_close(z.cpu(), z_lazy.cpu()) - assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7 + assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7 metrics.reset() torch._lazy.ir_cache.reset() diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 400479896afdb..a60183ba50ecc 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -54,10 +54,25 @@ def init_lists(): 'pow', # incorrect results 'addcdiv', # incorrect results (on CI not locally?) ]) + # The following ops all show up directly in ts_native_functions.yaml, + # but run functionalized versions of the composite kernels in core. + # This means that we don't expect the ops to show directly in the LTC metrics. + FUNCTIONAL_DECOMPOSE_LIST = set([ + 'block_diag', + 'new_empty_strided', + 'narrow_copy', + 'pixel_shuffle', + 'pixel_unshuffle', + 'select_backward', + '_trilinear', + 'linalg_inv_ex', + 'linalg_pinv.atol_rtol_tensor', + 'logsumexp', + ]) - return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST) + return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST) -(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST) = init_lists() +(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST) = init_lists() torch.manual_seed(42) @@ -96,9 +111,50 @@ def testConvolutionBackward(self): torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu()) torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu()) + def test_view_mark_step_preserved(self): + test_device = get_test_device() + inp = torch.rand(4, device=test_device) + inp_lazy = clone_move(inp) + + def foo(x, *, mark_step): + y = x.view(2, 2) + y.add_(1) + z = x + x + + if mark_step: + torch._lazy.mark_step() + + # y and x should contiue to be aliased after the mark_step call. + y.add_(1) + return x + + + out_ref = foo(inp, mark_step=False) + out = foo(inp_lazy, mark_step=True) + # out will have some pending mutations, which will be synced by the .cpu() call. + torch.testing.assert_close(out_ref.cpu(), out.cpu()) + + def test_tensor_ctr(self): + test_device = get_test_device() + inp = torch.tensor([[1, 2, 3, 4, 5]], device=test_device) + inp_lazy = torch.tensor([[1, 2, 3, 4, 5]], device='lazy') + + def foo(x): + # Calling a view op to ensure that functionalization wrapping occurs. + return x.view(-1) + + out_ref = foo(inp) + out = foo(inp_lazy) + torch.testing.assert_close(out_ref.cpu(), out.cpu()) + + class TestLazyOpInfo(TestCase): - @ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST], allowed_dtypes=(torch.float,)) + @ops([op for op in op_db + if op.name in LAZY_OPS_LIST + and op.name not in SKIP_RUNTIME_ERROR_LIST + and op.name not in FUNCTIONAL_DECOMPOSE_LIST + ], allowed_dtypes=(torch.float,)) def test_dispatched_to_lazy(self, device, dtype, op): def get_name(op): l = [op.name] diff --git a/test/mkl_verbose.py b/test/mkl_verbose.py new file mode 100644 index 0000000000000..879168f866b7b --- /dev/null +++ b/test/mkl_verbose.py @@ -0,0 +1,17 @@ +import argparse +import torch + +def run_model(level): + m = torch.nn.Linear(20, 30) + input = torch.randn(128, 20) + with torch.backends.mkl.verbose(level): + m(input) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--verbose-level", default=0, type=int) + args = parser.parse_args() + try: + run_model(args.verbose_level) + except Exception as e: + print(e) diff --git a/test/mkldnn_verbose.py b/test/mkldnn_verbose.py new file mode 100644 index 0000000000000..804eb9a24567a --- /dev/null +++ b/test/mkldnn_verbose.py @@ -0,0 +1,26 @@ +import argparse +import torch + +class Module(torch.nn.Module): + def __init__(self): + super(Module, self).__init__() + self.conv = torch.nn.Conv2d(1, 10, 5, 1) + + def forward(self, x): + y = self.conv(x) + return y + +def run_model(level): + m = Module().eval() + d = torch.rand(1, 1, 112, 112) + with torch.backends.mkldnn.verbose(level): + m(d) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--verbose-level", default=0, type=int) + args = parser.parse_args() + try: + run_model(args.verbose_level) + except Exception as e: + print(e) diff --git a/test/mobile/custom_build/CMakeLists.txt b/test/mobile/custom_build/CMakeLists.txt index 339ee953f7ec7..521569176c307 100644 --- a/test/mobile/custom_build/CMakeLists.txt +++ b/test/mobile/custom_build/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.1) project(custom_build_project) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") # Find torch library find_package(Torch REQUIRED) diff --git a/test/mobile/lightweight_dispatch/build.sh b/test/mobile/lightweight_dispatch/build.sh index 0ad0d05ce2575..b478f048ff8ed 100755 --- a/test/mobile/lightweight_dispatch/build.sh +++ b/test/mobile/lightweight_dispatch/build.sh @@ -19,13 +19,16 @@ TEST_SRC_ROOT="$PWD/test/mobile/lightweight_dispatch" pushd "$CUSTOM_TEST_ARTIFACT_BUILD_DIR" # prepare test -python "$TEST_SRC_ROOT/tests_setup.py" setup +OP_LIST="lightweight_dispatch_ops.yaml" +export SELECTED_OP_LIST=$TEST_SRC_ROOT/$OP_LIST +python "$TEST_SRC_ROOT/tests_setup.py" setup "$SELECTED_OP_LIST" export USE_DISTRIBUTED=0 export USE_LIGHTWEIGHT_DISPATCH=1 export STATIC_DISPATCH_BACKEND="CPU" export BUILD_LITE_INTERPRETER=1 +export USE_FBGEMM=0 python "${BUILD_LIBTORCH_PY}" ret=$? @@ -42,13 +45,7 @@ if ! build/bin/test_codegen_unboxing; then fi # shutdown test -python "$TEST_SRC_ROOT/tests_setup.py" shutdown - -# run lite interpreter tests -if ! build/bin/test_lite_interpreter_runtime; then - echo "test_lite_interpreter_runtime has failure!" - exit 1 -fi +python "$TEST_SRC_ROOT/tests_setup.py" shutdown "$SELECTED_OP_LIST" popd diff --git a/test/mobile/lightweight_dispatch/lightweight_dispatch_ops.yaml b/test/mobile/lightweight_dispatch/lightweight_dispatch_ops.yaml new file mode 100644 index 0000000000000..525cd5a75c7c5 --- /dev/null +++ b/test/mobile/lightweight_dispatch/lightweight_dispatch_ops.yaml @@ -0,0 +1,6 @@ +# base ops for preparing inputs +- aten::copy_ +- aten::detach +- aten::fill_.Tensor +- aten::to.device +# model introduced ops begin from here diff --git a/test/mobile/lightweight_dispatch/test_codegen_unboxing.cpp b/test/mobile/lightweight_dispatch/test_codegen_unboxing.cpp index 07a845d6008ba..80f26e68d2606 100644 --- a/test/mobile/lightweight_dispatch/test_codegen_unboxing.cpp +++ b/test/mobile/lightweight_dispatch/test_codegen_unboxing.cpp @@ -60,10 +60,10 @@ namespace jit { namespace mobile { // covers int[], ScalarType?, Layout?, Device?, bool? TEST(LiteInterpreterTest, Ones) { - // Load check in model: ones.ptl - auto testModelFile = "ones.ptl"; + // Load check in model: ModelWithDTypeDeviceLayoutPinMemory.ptl + auto testModelFile = "ModelWithDTypeDeviceLayoutPinMemory.ptl"; - // class Model(torch.nn.Module): + // class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module): // def forward(self, x: int): // a = torch.ones([3, x], dtype=torch.int64, layout=torch.strided, device="cpu") // return a @@ -75,10 +75,10 @@ TEST(LiteInterpreterTest, Ones) { } TEST(LiteInterpreterTest, Index) { - // Load check in model: index.ptl - auto testModelFile = "index.ptl"; + // Load check in model: ModelWithTensorOptional.ptl + auto testModelFile = "ModelWithTensorOptional.ptl"; - // class Model(torch.nn.Module): + // class ModelWithTensorOptional(torch.nn.Module): // def forward(self, index): // a = torch.zeros(2, 2) // a[0][1] = 1 @@ -98,10 +98,10 @@ TEST(LiteInterpreterTest, Index) { } TEST(LiteInterpreterTest, Gradient) { - // Load check in model: gradient.ptl - auto testModelFile = "gradient.ptl"; + // Load check in model: ModelWithScalarList.ptl + auto testModelFile = "ModelWithScalarList.ptl"; - // class Model(torch.nn.Module): + // class ModelWithScalarList(torch.nn.Module): // def forward(self, a: int): // values = torch.tensor([4., 1., 1., 16.], ) // if a == 0: @@ -120,8 +120,8 @@ TEST(LiteInterpreterTest, Gradient) { } TEST(LiteInterpreterTest, Upsample) { - // Load check in model: upsample.ptl - auto testModelFile = "upsample.ptl"; + // Load check in model: ModelWithFloatList.ptl + auto testModelFile = "ModelWithFloatList.ptl"; // model = torch.nn.Upsample(scale_factor=(2.0,), mode="linear") Module bc = _load_for_mobile(testModelFile); @@ -132,10 +132,10 @@ TEST(LiteInterpreterTest, Upsample) { } TEST(LiteInterpreterTest, IndexTensor) { - // Load check in model: Index_Tensor.ptl - auto testModelFile = "index_Tensor.ptl"; + // Load check in model: ModelWithListOfOptionalTensors.ptl + auto testModelFile = "ModelWithListOfOptionalTensors.ptl"; - // class Model(torch.nn.Module): + // class ModelWithListOfOptionalTensors(torch.nn.Module): // def forward(self, index): // values = torch.tensor([4., 1., 1., 16.], ) // return values[[index, torch.tensor(0)]] @@ -147,8 +147,8 @@ TEST(LiteInterpreterTest, IndexTensor) { } TEST(LiteInterpreterTest, Conv2d) { - // Load check in model: conv2d.ptl - auto testModelFile = "conv2d.ptl"; + // Load check in model: ModelWithArrayOfInt.ptl + auto testModelFile = "ModelWithArrayOfInt.ptl"; // model = torch.nn.Conv2d(1, 2, (2, 2), stride=(1, 1), padding=(1, 1)) Module bc = _load_for_mobile(testModelFile); @@ -158,10 +158,10 @@ TEST(LiteInterpreterTest, Conv2d) { } TEST(LiteInterpreterTest, AddTensor) { - // Load check in model: add_Tensor.ptl - auto testModelFile = "add_Tensor.ptl"; + // Load check in model: ModelWithTensors.ptl + auto testModelFile = "ModelWithTensors.ptl"; - // class Model(torch.nn.Module): + // class ModelWithTensors(torch.nn.Module): // def forward(self, a): // values = torch.ones(size=[2, 3], names=['N', 'C']) // values[0][0] = a[0] @@ -174,10 +174,10 @@ TEST(LiteInterpreterTest, AddTensor) { } TEST(LiteInterpreterTest, DivideTensor) { - // Load check in model: add_Tensor.ptl - auto testModelFile = "divide_Tensor.ptl"; + // Load check in model: ModelWithStringOptional.ptl + auto testModelFile = "ModelWithStringOptional.ptl"; - // class Model(torch.nn.Module): + // class ModelWithStringOptional(torch.nn.Module): // def forward(self, b): // a = torch.tensor(3, dtype=torch.int64) // out = torch.empty(size=[1], dtype=torch.float) @@ -193,10 +193,10 @@ TEST(LiteInterpreterTest, DivideTensor) { } TEST(LiteInterpreterTest, MultipleOps) { - // Load check in model: multiple_ops.ptl - auto testModelFile = "multiple_ops.ptl"; + // Load check in model: ModelWithMultipleOps.ptl + auto testModelFile = "ModelWithMultipleOps.ptl"; - // class Model(torch.nn.Module): + // class ModelWithMultipleOps(torch.nn.Module): // def __init__(self): // super(Model, self).__init__() // self.ops = torch.nn.Sequential( diff --git a/test/mobile/lightweight_dispatch/tests_setup.py b/test/mobile/lightweight_dispatch/tests_setup.py index 91af29796b9d9..6059961132a23 100644 --- a/test/mobile/lightweight_dispatch/tests_setup.py +++ b/test/mobile/lightweight_dispatch/tests_setup.py @@ -1,203 +1,143 @@ +import functools import os +from io import BytesIO +import shutil + import sys import torch +from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list +_OPERATORS = set() +_FILENAMES = [] +_MODELS = [] -class Setup(object): - def setup(self): - raise NotImplementedError() - - def shutdown(self): - raise NotImplementedError() - - -class FileSetup(object): - path = None - - def shutdown(self): - if os.path.exists(self.path): - os.remove(self.path) - pass +def save_model(cls): + """Save a model and dump all the ops""" -class ModelWithDTypeDeviceLayoutPinMemory(FileSetup): - path = 'ones.ptl' + @functools.wraps(cls) + def wrapper_save(): + _MODELS.append(cls) + model = cls() + scripted = torch.jit.script(model) + buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter()) + buffer.seek(0) + mobile_module = _load_for_lite_interpreter(buffer) + ops = _export_operator_list(mobile_module) + _OPERATORS.update(ops) + path = f"./{cls.__name__}.ptl" + _FILENAMES.append(path) + scripted._save_for_lite_interpreter(path) - def setup(self): - class Model(torch.nn.Module): - def forward(self, x: int): - a = torch.ones(size=[3, x], dtype=torch.int64, layout=torch.strided, device="cpu", pin_memory=False) - return a + return wrapper_save - model = Model() - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module): + def forward(self, x: int): + a = torch.ones(size=[3, x], dtype=torch.int64, layout=torch.strided, device="cpu", pin_memory=False) + return a -class ModelWithTensorOptional(FileSetup): - path = 'index.ptl' - - def setup(self): - class Model(torch.nn.Module): - def forward(self, index): - a = torch.zeros(2, 2) - a[0][1] = 1 - a[1][0] = 2 - a[1][1] = 3 - return a[index] - - model = Model() - - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithTensorOptional(torch.nn.Module): + def forward(self, index): + a = torch.zeros(2, 2) + a[0][1] = 1 + a[1][0] = 2 + a[1][1] = 3 + return a[index] # gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[] -class ModelWithScalarList(FileSetup): - path = 'gradient.ptl' - - def setup(self): - - class Model(torch.nn.Module): - def forward(self, a: int): - values = torch.tensor([4., 1., 1., 16.], ) - if a == 0: - return torch.gradient(values, spacing=torch.scalar_tensor(2., dtype=torch.float64)) - elif a == 1: - return torch.gradient(values, spacing=[torch.tensor(1.).item()]) - - model = Model() - - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithScalarList(torch.nn.Module): + def forward(self, a: int): + values = torch.tensor([4., 1., 1., 16.], ) + if a == 0: + return torch.gradient(values, spacing=torch.scalar_tensor(2., dtype=torch.float64)) + elif a == 1: + return torch.gradient(values, spacing=[torch.tensor(1.).item()]) # upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor -class ModelWithFloatList(FileSetup): - path = 'upsample.ptl' - - def setup(self): - model = torch.nn.Upsample(scale_factor=(2.0,), mode="linear", align_corners=False, recompute_scale_factor=True) - - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithFloatList(torch.nn.Upsample): + def __init__(self): + super().__init__(scale_factor=(2.0,), mode="linear", align_corners=False, recompute_scale_factor=True) # index.Tensor(Tensor self, Tensor?[] indices) -> Tensor -class ModelWithListOfOptionalTensors(FileSetup): - path = 'index_Tensor.ptl' - - def setup(self): - class Model(torch.nn.Module): - def forward(self, index): - values = torch.tensor([[4., 1., 1., 16.]]) - return values[torch.tensor(0), index] - - model = Model() - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithListOfOptionalTensors(torch.nn.Module): + def forward(self, index): + values = torch.tensor([[4., 1., 1., 16.]]) + return values[torch.tensor(0), index] # conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, # int groups=1) -> Tensor -class ModelWithArrayOfInt(FileSetup): - path = 'conv2d.ptl' - - def setup(self): - model = torch.nn.Conv2d(1, 2, (2, 2), stride=(1, 1), padding=(1, 1)) - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) +@save_model +class ModelWithArrayOfInt(torch.nn.Conv2d): + def __init__(self): + super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1)) # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor # ones_like(Tensor self, *, ScalarType?, dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, # MemoryFormat? memory_format=None) -> Tensor -class ModelWithTensors(FileSetup): - path = 'add_Tensor.ptl' - - def setup(self): - class Model(torch.nn.Module): - def forward(self, a): - b = torch.ones_like(a) - return a + b - model = Model() - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) - - -class ModelWithStringOptional(FileSetup): - path = 'divide_Tensor.ptl' - - def setup(self): - class Model(torch.nn.Module): - def forward(self, b): - a = torch.tensor(3, dtype=torch.int64) - out = torch.empty(size=[1], dtype=torch.float) - torch.div(b, a, out=out) - return [torch.div(b, a, rounding_mode='trunc'), out] - model = Model() - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) - - -class ModelWithMultipleOps(FileSetup): - path = 'multiple_ops.ptl' - - def setup(self): - class Model(torch.nn.Module): - def __init__(self): - super(Model, self).__init__() - self.ops = torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Flatten(), - ) - - def forward(self, x): - x[1] = -2 - return self.ops(x) - - model = Model() - # Script the model and save - script_model = torch.jit.script(model) - script_model._save_for_lite_interpreter(self.path) - - -tests = [ - ModelWithDTypeDeviceLayoutPinMemory(), - ModelWithTensorOptional(), - ModelWithScalarList(), - ModelWithFloatList(), - ModelWithListOfOptionalTensors(), - ModelWithArrayOfInt(), - ModelWithTensors(), - ModelWithStringOptional(), - ModelWithMultipleOps(), -] - - -def setup(): - for test in tests: - test.setup() - - -def shutdown(): - for test in tests: - test.shutdown() +@save_model +class ModelWithTensors(torch.nn.Module): + def forward(self, a): + b = torch.ones_like(a) + return a + b + +@save_model +class ModelWithStringOptional(torch.nn.Module): + def forward(self, b): + a = torch.tensor(3, dtype=torch.int64) + out = torch.empty(size=[1], dtype=torch.float) + torch.div(b, a, out=out) + return [torch.div(b, a, rounding_mode='trunc'), out] + + +@save_model +class ModelWithMultipleOps(torch.nn.Module): + def __init__(self): + super().__init__() + self.ops = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Flatten(), + ) + + def forward(self, x): + x[1] = -2 + return self.ops(x) if __name__ == "__main__": command = sys.argv[1] + ops_yaml = sys.argv[2] + backup = ops_yaml + ".bak" if command == "setup": - setup() + tests = [ + ModelWithDTypeDeviceLayoutPinMemory(), + ModelWithTensorOptional(), + ModelWithScalarList(), + ModelWithFloatList(), + ModelWithListOfOptionalTensors(), + ModelWithArrayOfInt(), + ModelWithTensors(), + ModelWithStringOptional(), + ModelWithMultipleOps(), + ] + shutil.copyfile(ops_yaml, backup) + with open(ops_yaml, 'a') as f: + for op in _OPERATORS: + f.write(f"- {op}\n") elif command == "shutdown": - shutdown() + for file in _MODELS: + if os.path.isfile(file): + os.remove(file) + shutil.move(backup, ops_yaml) diff --git a/test/mobile/model_test/gen_test_model.py b/test/mobile/model_test/gen_test_model.py index e9e3908630be4..370e8d08541f7 100644 --- a/test/mobile/model_test/gen_test_model.py +++ b/test/mobile/model_test/gen_test_model.py @@ -34,7 +34,7 @@ ) from quantization_ops import ( GeneralQuantModule, - DynamicQuantModule, + # DynamicQuantModule, StaticQuantModule, FusedQuantModule, ) @@ -89,7 +89,8 @@ "nn_utils_ops": NNUtilsModule(), # quantization ops "general_quant_ops": GeneralQuantModule(), - "dynamic_quant_ops": DynamicQuantModule(), + # TODO(sdym@fb.com): fix and re-enable dynamic_quant_ops + # "dynamic_quant_ops": DynamicQuantModule(), "static_quant_ops": StaticQuantModule(), "fused_quant_ops": FusedQuantModule(), # TorchScript buildin ops diff --git a/test/mobile/model_test/model_ops.yaml b/test/mobile/model_test/model_ops.yaml index 06a3640e4cbe7..43e4876451e38 100644 --- a/test/mobile/model_test/model_ops.yaml +++ b/test/mobile/model_test/model_ops.yaml @@ -1,752 +1,442 @@ root_operators: - aten::Bool.Tensor: 19 - aten::Bool.int: 7 - aten::Float.Scalar: 18 - aten::Float.Tensor: 11 - aten::Float.str: 6 - aten::FloatImplicit: 2 - aten::Int.Scalar: 19 - aten::Int.Tensor: 35 - aten::Int.float: 6 - aten::Int.str: 12 - aten::IntImplicit: 11 - aten::ScalarImplicit: 3 - aten::__and__.Tensor: 13 - aten::__and__.bool: 11 + aten::Bool.Tensor: 32 + aten::Bool.int: 34 + aten::Float.Scalar: 30 + aten::Float.Tensor: 20 + aten::Float.str: 32 + aten::FloatImplicit: 5 + aten::Int.Scalar: 60 + aten::Int.Tensor: 66 + aten::Int.float: 11 + aten::Int.str: 20 + aten::IntImplicit: 37 + aten::ScalarImplicit: 9 + aten::__and__.Tensor: 21 + aten::__and__.bool: 43 aten::__and__.int: 2 - aten::__contains__.int: 5 - aten::__contains__.int_list: 17 - aten::__contains__.str: 22 - aten::__contains__.str_list: 5 - aten::__derive_index: 24 - aten::__getitem__.Dict_int: 4 - aten::__getitem__.Dict_str: 39 - aten::__getitem__.str: 20 - aten::__getitem__.t: 178 - aten::__is__: 83 - aten::__isnot__: 81 + aten::__contains__.int: 6 + aten::__contains__.int_list: 37 + aten::__contains__.str: 60 + aten::__contains__.str_list: 11 + aten::__derive_index: 106 + aten::__getitem__.Dict_int: 2 + aten::__getitem__.Dict_str: 30 + aten::__getitem__.str: 39 + aten::__getitem__.t: 328 + aten::__is__: 80 + aten::__isnot__: 79 aten::__lshift__.int: 2 aten::__not__: 32 - aten::__range_length: 23 + aten::__range_length: 106 aten::__rshift__.int: 2 - aten::__xor__.bool: 10 - aten::_aminmax: 4 - aten::_convolution: 12 + aten::__xor__.bool: 16 + aten::_aminmax: 18 + aten::_convolution: 27 aten::_convolution.deprecated: 3 - aten::_infer_size: 7 - aten::_make_per_tensor_quantized_tensor: 2 - aten::_pack_padded_sequence: 10 - aten::_pad_packed_sequence: 10 - aten::_reshape_from_tensor: 10 + aten::_infer_size: 9 + aten::_make_per_tensor_quantized_tensor: 1 + aten::_pack_padded_sequence: 23 + aten::_pad_packed_sequence: 23 + aten::_reshape_from_tensor: 16 aten::_set_item.int: 7 - aten::_set_item.str: 163 - aten::_set_item.t: 8 - aten::_shape_as_tensor: 10 + aten::_set_item.str: 315 + aten::_set_item.t: 34 + aten::_shape_as_tensor: 16 + aten::abs: 1 + aten::acos: 1 aten::adaptive_avg_pool1d: 1 - aten::adaptive_avg_pool2d: 33 + aten::adaptive_avg_pool2d: 52 aten::adaptive_avg_pool3d: 1 - aten::add.Scalar: 33 - aten::add.Tensor: 63 - aten::add.float: 5 - aten::add.int: 49 + aten::add: 15 + aten::add.Scalar: 81 + aten::add.Tensor: 138 + aten::add.float: 9 + aten::add.int: 128 + aten::add.int_float: 1 aten::add.out: 2 - aten::add.str: 29 - aten::add.t: 11 - aten::add_.Scalar: 15 - aten::add_.Tensor: 29 + aten::add.str: 93 + aten::add.t: 37 + aten::add_.Scalar: 24 + aten::add_.Tensor: 57 + aten::add_.t: 4 aten::addcmul: 2 aten::addmm: 7 - aten::all: 6 - aten::allclose: 1 - aten::any: 14 - aten::append.t: 59 - aten::arange: 16 - aten::arange.start: 6 - aten::arange.start_step: 16 - aten::argmax: 2 - aten::as_strided: 10 - aten::as_tensor.list: 4 - aten::atan: 4 - aten::avg_pool1d: 6 - aten::avg_pool2d: 7 - aten::backward: 23 - aten::batch_norm: 15 - aten::binary_cross_entropy: 15 - aten::binary_cross_entropy_with_logits: 3 - aten::bitwise_not: 13 - aten::bmm: 16 - aten::broadcast_tensors: 1 - aten::cat: 90 - aten::ceil: 3 - aten::ceil.float: 7 - aten::chunk: 19 - aten::clamp: 36 - aten::clamp_: 12 - aten::clamp_min: 3 + aten::all: 14 + aten::all.dim: 1 + aten::allclose: 2 + aten::any: 22 + aten::append.t: 155 + aten::arange: 51 + aten::arange.start: 32 + aten::arange.start_step: 56 + aten::argmax: 5 + aten::as_strided: 16 + aten::as_tensor.list: 6 + aten::atan: 8 + aten::atan2: 1 + aten::avg_pool1d: 32 + aten::avg_pool2d: 4 + aten::backward: 30 + aten::batch_norm: 24 + aten::binary_cross_entropy: 16 + aten::binary_cross_entropy_with_logits: 6 + aten::bitwise_not: 20 + aten::bmm: 48 + aten::broadcast_tensors: 4 + aten::cat: 180 + aten::ceil: 7 + aten::ceil.float: 35 + aten::chunk: 72 + aten::clamp: 65 + aten::clamp_: 19 + aten::clamp_min: 7 aten::clear.str: 2 - aten::clone: 26 + aten::clone: 39 aten::coalesce: 2 - aten::conj: 1 - aten::constant_pad_nd: 17 - aten::contiguous: 113 - aten::conv1d: 12 - aten::conv2d: 10 - aten::conv_transpose2d.input: 5 - aten::copy_: 15 - aten::copy_.Tensor: 27 + aten::complex: 1 + aten::conj: 2 + aten::constant_pad_nd: 16 + aten::contiguous: 237 + aten::conv1d: 28 + aten::conv2d: 15 + aten::conv_transpose2d.input: 8 + aten::copy_: 25 + aten::copy_.Tensor: 69 aten::copy_.int: 1 - aten::cos: 4 - aten::count_nonzero: 4 + aten::cos: 9 + aten::count_nonzero: 3 + aten::cross: 1 aten::ctc_loss.Tensor: 1 - aten::cumsum: 13 + aten::cumsum: 19 aten::dequantize.list: 1 - aten::dequantize.self: 30 - aten::dequantize.tensor: 36 - aten::detach: 34 - aten::dim: 36 - aten::div: 9 - aten::div.Scalar: 8 - aten::div.Tensor: 71 - aten::div.Tensor_mode: 7 - aten::div.float: 3 - aten::div.int: 7 - aten::div_.Tensor: 7 - aten::dropout: 41 - aten::embedding: 16 + aten::dequantize.self: 51 + aten::dequantize.tensor: 61 + aten::detach: 73 + aten::diagonal: 1 + aten::dim: 30 + aten::div: 39 + aten::div.Scalar: 33 + aten::div.Tensor: 88 + aten::div.Tensor_mode: 12 + aten::div.float: 8 + aten::div.int: 35 + aten::div_.Tensor: 13 + aten::dot: 1 + aten::dropout: 78 + aten::embedding: 40 aten::embedding_bag.padding_idx: 2 - aten::empty.memory_format: 11 - aten::empty_like: 11 - aten::empty_strided: 3 - aten::eq.Scalar: 24 - aten::eq.Tensor: 6 - aten::eq.int: 57 - aten::eq.int_list: 20 - aten::eq.str: 43 - aten::exp: 18 - aten::exp.float: 4 - aten::expand: 26 - aten::expand_as: 3 - aten::extend.t: 38 + aten::empty.memory_format: 36 + aten::empty_like: 22 + aten::empty_strided: 5 + aten::endswith: 1 + aten::eq: 3 + aten::eq.Scalar: 34 + aten::eq.Tensor: 5 + aten::eq.int: 133 + aten::eq.int_list: 56 + aten::eq.str: 80 + aten::exp: 30 + aten::exp.float: 5 + aten::expand: 73 + aten::expand_as: 7 + aten::extend.t: 125 aten::feature_dropout: 1 - aten::fill_.Scalar: 17 - aten::find: 3 - aten::flatten.using_ints: 45 - aten::flip: 1 - aten::floor: 5 + aten::fft_irfftn: 1 + aten::fft_rfftn: 1 + aten::fill_.Scalar: 25 + aten::fill_.Tensor: 2 + aten::find: 1 + aten::flatten.using_ints: 74 + aten::flip: 2 + aten::floor: 9 aten::floor.float: 2 - aten::floor_divide: 4 - aten::floor_divide.Scalar: 7 - aten::floordiv.int: 21 - aten::format: 58 - aten::full: 10 - aten::full_like: 10 - aten::gather: 10 - aten::ge.Scalar: 4 - aten::ge.Tensor: 6 - aten::ge.int: 29 - aten::gelu: 12 - aten::get.default_str: 14 - aten::glu: 18 - aten::grid_sampler: 3 - aten::gt.Scalar: 16 - aten::gt.float: 16 - aten::gt.float_int: 3 - aten::gt.int: 52 - aten::hardsigmoid: 3 + aten::floor_divide: 8 + aten::floor_divide.Scalar: 37 + aten::floordiv.int: 50 + aten::format: 49 + aten::frobenius_norm.dim: 1 + aten::full: 18 + aten::full_like: 18 + aten::gather: 25 + aten::ge.Scalar: 6 + aten::ge.Tensor: 33 + aten::ge.int: 107 + aten::gelu: 53 + aten::get.default_str: 21 + aten::glu: 31 + aten::grid_sampler: 7 + aten::gt.Scalar: 21 + aten::gt.float: 48 + aten::gt.float_int: 4 + aten::gt.int: 152 + aten::hardsigmoid: 5 aten::hardsigmoid_: 2 - aten::hardswish_: 4 - aten::hardtanh: 3 - aten::hardtanh_: 3 + aten::hardswish_: 6 + aten::hardtanh: 8 + aten::hardtanh_: 5 aten::hstack: 2 - aten::index.Tensor: 23 - aten::index_fill.int_Scalar: 15 - aten::index_put_: 16 - aten::index_select: 31 + aten::imag: 1 + aten::index.Tensor: 53 + aten::index_fill.int_Scalar: 67 + aten::index_put_: 23 + aten::index_select: 100 + aten::insert.t: 1 aten::is_coalesced: 2 - aten::is_floating_point: 9 + aten::is_floating_point: 10 + aten::isfinite: 1 aten::isnan: 1 - aten::item: 40 - aten::items.str: 3 - aten::keys.str: 15 - aten::layer_norm: 26 - aten::le.Scalar: 1 - aten::le.Tensor: 10 + aten::isnumeric: 2 + aten::item: 86 + aten::items.str: 8 + aten::join: 1 + aten::keys.str: 25 + aten::layer_norm: 97 + aten::le.Scalar: 3 + aten::le.Tensor: 17 aten::le.float: 2 - aten::le.int: 17 + aten::le.int: 18 aten::leaky_relu: 1 - aten::leaky_relu_: 5 + aten::leaky_relu_: 4 aten::len.Dict_int: 5 - aten::len.Tensor: 19 - aten::len.str: 23 - aten::len.t: 177 - aten::linear: 46 - aten::linspace: 3 - aten::list.t: 24 - aten::log: 18 - aten::log10: 4 - aten::log1p: 5 - aten::log_softmax.int: 31 - aten::logical_and: 1 - aten::logical_not: 10 - aten::logit: 7 - aten::lower: 10 - aten::lstm.data: 8 - aten::lstm.input: 4 - aten::lt.Scalar: 8 - aten::lt.Tensor: 1 - aten::lt.float: 16 - aten::lt.int: 46 - aten::masked_fill.Scalar: 16 - aten::matmul: 12 - aten::max: 18 - aten::max.dim: 30 - aten::max.other: 7 - aten::max_pool2d: 10 - aten::maximum: 4 - aten::mean: 10 - aten::mean.dim: 16 + aten::len.Tensor: 34 + aten::len.str: 41 + aten::len.t: 327 + aten::linear: 84 + aten::linspace: 7 + aten::list.t: 57 + aten::log: 27 + aten::log.float: 1 + aten::log10: 3 + aten::log1p: 6 + aten::log_softmax.int: 104 + aten::logical_and: 6 + aten::logical_not: 16 + aten::logical_or: 1 + aten::logit: 9 + aten::lower: 22 + aten::lstm.data: 10 + aten::lstm.input: 3 + aten::lt: 3 + aten::lt.Scalar: 19 + aten::lt.Tensor: 5 + aten::lt.float: 48 + aten::lt.int: 140 + aten::masked_fill.Scalar: 48 + aten::masked_fill_.Scalar: 1 + aten::matmul: 32 + aten::max: 54 + aten::max.dim: 60 + aten::max.other: 12 + aten::max_pool2d: 11 + aten::maximum: 8 + aten::mean: 13 + aten::mean.dim: 49 + aten::meshgrid: 2 aten::meshgrid.indexing: 2 - aten::min: 2 - aten::min.dim: 4 - aten::min.other: 17 - aten::minimum: 4 - aten::mse_loss: 1 - aten::mul.Scalar: 26 - aten::mul.Tensor: 90 - aten::mul.float: 5 - aten::mul.float_int: 3 - aten::mul.int: 26 - aten::mul.int_float: 4 - aten::mul.left_t: 15 + aten::min: 5 + aten::min.dim: 8 + aten::min.other: 28 + aten::minimum: 8 + aten::mm: 1 + aten::mse_loss: 4 + aten::mul.Scalar: 72 + aten::mul.Tensor: 188 + aten::mul.float: 8 + aten::mul.float_int: 4 + aten::mul.int: 62 + aten::mul.int_float: 10 + aten::mul.left_t: 31 aten::mul.out: 1 - aten::mul_.Scalar: 11 - aten::mul_.Tensor: 5 + aten::mul_.Scalar: 18 + aten::mul_.Tensor: 9 aten::nan_to_num: 3 - aten::nan_to_num_: 10 - aten::narrow: 10 - aten::ne.Scalar: 14 - aten::ne.Tensor: 5 - aten::ne.int: 44 + aten::nan_to_num_: 13 + aten::narrow: 14 + aten::ne.Scalar: 21 + aten::ne.Tensor: 4 + aten::ne.float: 3 + aten::ne.float_int: 4 + aten::ne.int: 123 aten::ne.int_float: 2 - aten::ne.int_list: 20 + aten::ne.int_list: 21 aten::ne.str: 3 - aten::neg: 29 - aten::neg.int: 19 - aten::new_zeros: 6 + aten::neg: 73 + aten::neg.int: 54 + aten::new_zeros: 32 aten::nll_loss: 1 aten::nll_loss2d: 1 - aten::nll_loss_nd: 3 - aten::nonzero: 4 - aten::norm.Scalar: 1 - aten::norm.ScalarOpt_dim: 4 - aten::numel: 8 + aten::nll_loss_nd: 4 + aten::nonzero: 12 + aten::norm.Scalar: 2 + aten::norm.ScalarOpt_dim: 8 + aten::numel: 38 + aten::numpy_T: 1 aten::one_hot: 2 - aten::ones: 38 - aten::ones_like: 16 - aten::ord: 20 - aten::permute: 43 - aten::pop.t: 7 - aten::pow.Tensor_Scalar: 3 + aten::ones: 50 + aten::ones_like: 27 + aten::ord: 39 + aten::pad: 10 + aten::permute: 135 + aten::pop.t: 8 + aten::pow.Tensor_Scalar: 8 + aten::pow.Tensor_Tensor: 1 + aten::pow.int: 1 aten::pow.int_float: 2 aten::quantile.scalar: 1 - aten::quantize_per_tensor: 66 - aten::quantize_per_tensor.tensor_qparams: 1 - aten::quantized_lstm.data: 2 - aten::rand: 25 - aten::randint.low: 2 - aten::randn_like: 17 + aten::quantize_per_tensor: 105 + aten::quantize_per_tensor.tensor_qparams: 7 + aten::quantize_per_tensor_dynamic: 2 + aten::quantized_lstm.data: 13 + aten::rad2deg: 1 + aten::rand: 33 + aten::randint.low: 4 + aten::randn_like: 100 + aten::real: 1 aten::reciprocal: 1 - aten::reflection_pad2d: 1 - aten::relu: 82 - aten::relu_: 9 - aten::remainder.Scalar: 2 - aten::remainder.int: 22 - aten::repeat: 16 + aten::reflection_pad2d: 2 + aten::relu: 176 + aten::relu_: 15 + aten::remainder.Scalar: 3 + aten::remainder.int: 50 + aten::repeat: 49 + aten::repeat_interleave.self_Tensor: 1 aten::replace: 1 aten::replication_pad1d: 1 - aten::replication_pad2d: 2 + aten::replication_pad2d: 1 aten::replication_pad3d: 1 aten::requires_grad_: 4 - aten::reshape: 36 + aten::reshape: 88 aten::resize_as_: 1 - aten::resolve_conj: 1 - aten::resolve_neg: 1 + aten::resolve_conj: 12 + aten::resolve_neg: 12 aten::reverse.t: 2 - aten::round.Scalar: 4 - aten::rstrip: 1 - aten::rsub.Scalar: 5 - aten::scatter_.src: 6 - aten::scatter_add_: 10 - aten::select.int: 57 + aten::round: 1 + aten::round.Scalar: 18 + aten::round.float: 1 + aten::rsub.Scalar: 7 + aten::scatter_.src: 17 + aten::scatter_add_: 16 + aten::select.int: 111 aten::selu: 2 - aten::sigmoid: 93 - aten::sin: 4 - aten::size: 66 - aten::size.int: 66 - aten::slice.Tensor: 75 - aten::slice.str: 12 - aten::slice.t: 43 - aten::softmax.int: 63 + aten::set_.source_Tensor: 1 + aten::sigmoid: 176 + aten::sin: 9 + aten::size: 155 + aten::size.int: 172 + aten::slice.Tensor: 177 + aten::slice.str: 18 + aten::slice.t: 141 + aten::softmax.int: 139 aten::softplus: 2 - aten::sort: 18 + aten::sort: 36 aten::sparse_coo_tensor.indices: 1 aten::sparse_resize_and_clear_: 1 - aten::split.str: 10 + aten::split.Tensor: 3 + aten::split.sizes: 1 + aten::split.str: 24 + aten::split_with_sizes: 1 aten::sqrt: 1 - aten::squeeze.dim: 26 - aten::stack: 30 - aten::startswith: 10 - aten::str: 16 - aten::strip: 3 - aten::sub: 8 - aten::sub.Scalar: 26 - aten::sub.Tensor: 94 - aten::sub.int: 52 - aten::sub_.Tensor: 4 - aten::sum: 17 - aten::sum.dim_IntList: 19 - aten::sum.int: 1 + aten::squeeze: 2 + aten::squeeze.dim: 51 + aten::stack: 90 + aten::startswith: 16 + aten::str: 25 + aten::strip: 1 + aten::sub: 36 + aten::sub.Scalar: 70 + aten::sub.Tensor: 145 + aten::sub.float: 3 + aten::sub.int: 146 + aten::sub.int_float: 3 + aten::sub_.Scalar: 1 + aten::sub_.Tensor: 6 + aten::sum: 21 + aten::sum.dim_IntList: 25 aten::t: 3 - aten::tanh: 26 - aten::tensor: 51 - aten::tensor.float: 28 - aten::tensor.int: 34 + aten::tanh: 78 + aten::tensor: 156 + aten::tensor.float: 69 + aten::tensor.int: 82 aten::tensor_split.indices: 4 - aten::to.device: 11 - aten::to.dtype: 23 - aten::to.dtype_layout: 27 - aten::to.prim_Device: 23 - aten::to.prim_dtype: 38 - aten::topk: 10 - aten::transpose.int: 33 - aten::triu: 10 + aten::to.device: 20 + aten::to.dtype: 51 + aten::to.dtype_layout: 79 + aten::to.other: 2 + aten::to.prim_Device: 67 + aten::to.prim_dtype: 89 + aten::topk: 21 + aten::transpose.int: 82 + aten::triu: 16 aten::true_divide.Tensor: 2 - aten::trunc_: 3 - aten::type_as: 6 - aten::unbind.int: 24 + aten::trunc_: 7 + aten::type_as: 33 + aten::unbind.int: 80 aten::unique_consecutive: 2 - aten::unsqueeze: 34 - aten::unsqueeze_: 6 - aten::update.str: 4 - aten::upsample_bicubic2d.vec: 1 - aten::upsample_bilinear2d.vec: 8 + aten::unsqueeze: 82 + aten::unsqueeze_: 12 + aten::update.str: 8 + aten::upsample_bicubic2d.vec: 2 + aten::upsample_bilinear2d.vec: 10 aten::upsample_linear1d.vec: 1 aten::upsample_nearest1d.vec: 2 - aten::upsample_nearest2d: 7 - aten::upsample_nearest2d.vec: 30 + aten::upsample_nearest2d: 13 + aten::upsample_nearest2d.vec: 52 aten::upsample_nearest3d.vec: 2 aten::upsample_trilinear3d.vec: 1 aten::values.int: 3 - aten::view: 61 - aten::vstack: 1 - aten::where.ScalarOther: 4 - aten::where.self: 10 - aten::zeros: 75 + aten::view: 117 + aten::vstack: 2 + aten::where.ScalarOther: 3 + aten::where.self: 16 + aten::zeros: 159 aten::zeros.out: 1 - aten::zeros_like: 7 + aten::zeros_like: 11 prepacked::conv2d_clamp_prepack: 2 - prepacked::conv2d_clamp_run: 32 + prepacked::conv2d_clamp_run: 41 prepacked::conv2d_transpose_clamp_prepack: 1 - prepacked::conv2d_transpose_clamp_run: 1 - prepacked::linear_clamp_run: 26 - prim::ModuleContainerIndex.list: 2 - prim::NumToTensor.Scalar: 15 + prepacked::conv2d_transpose_clamp_run: 2 + prepacked::linear_clamp_run: 36 + prim::ModuleContainerIndex.list: 12 + prim::NumToTensor.Scalar: 18 prim::Print: 1 - prim::RaiseException: 103 - prim::TupleIndex: 157 - prim::TupleUnpack: 120 - prim::Uninitialized: 80 - prim::device: 46 - prim::dtype: 45 - prim::is_cuda: 1 - prim::max.float: 7 - prim::max.int: 14 - prim::max.self_int: 17 - prim::min: 4 - prim::min.int: 35 - prim::min.self_int: 25 - prim::unchecked_cast: 100 - quantized::add: 58 - quantized::add_relu: 1 - quantized::batch_norm2d: 1 - quantized::cat: 4 - quantized::conv1d: 1 - quantized::conv2d: 4 - quantized::conv2d.new: 55 - quantized::conv2d_prepack: 14 - quantized::conv2d_relu.new: 50 - quantized::conv_prepack: 5 - quantized::conv_transpose2d: 2 - quantized::embedding_4bit: 1 - quantized::embedding_byte: 14 - quantized::hardswish: 1 - quantized::instance_norm: 1 - quantized::leaky_relu: 2 - quantized::linear: 27 - quantized::linear_dynamic: 21 - quantized::linear_dynamic_fp16: 18 - quantized::linear_prepack: 29 - quantized::linear_prepack_fp16: 25 - quantized::linear_relu: 2 - quantized::linear_unpack: 4 - quantized::linear_unpack_fp16: 4 - quantized::mul: 4 - quantized::mul.Scalar: 1 -traced_operators: - aten::__and__.Tensor: 13 - aten::__iand__.Tensor: 1 - aten::__ior__.Tensor: 1 - aten::_adaptive_avg_pool2d: 23 - aten::_aminmax: 4 - aten::_batch_norm_impl_index: 15 - aten::_cat: 95 - aten::_coalesce: 2 - aten::_coalesced_: 3 - aten::_convolution: 34 - aten::_convolution.deprecated: 3 - aten::_ctc_loss: 1 - aten::_embedding_bag: 2 - aten::_embedding_bag_backward: 1 - aten::_embedding_bag_sparse_backward: 1 - aten::_empty_affine_quantized: 87 - aten::_empty_per_channel_affine_quantized: 28 - aten::_index_put_impl_: 16 - aten::_indices: 4 - aten::_local_scalar_dense: 188 - aten::_log_softmax: 28 - aten::_log_softmax_backward_data: 4 - aten::_make_per_tensor_quantized_tensor: 2 - aten::_nnz: 3 - aten::_pack_padded_sequence: 10 - aten::_pack_padded_sequence_backward: 3 - aten::_pad_packed_sequence: 10 - aten::_reshape_alias: 93 - aten::_reshape_from_tensor: 10 - aten::_s_where: 15 - aten::_shape_as_tensor: 10 - aten::_slow_conv2d_backward.output_mask: 3 - aten::_slow_conv2d_forward: 33 - aten::_softmax: 63 - aten::_sparse_coo_tensor_unsafe: 4 - aten::_sparse_coo_tensor_with_dims_and_tensors: 5 - aten::_to_copy: 188 - aten::_unsafe_view: 28 - aten::_values: 4 - aten::abs: 1 - aten::abs.out: 1 - aten::adaptive_avg_pool2d: 29 - aten::add.Scalar: 30 - aten::add.Tensor: 72 - aten::add.out: 2 - aten::add_.Scalar: 11 - aten::add_.Tensor: 48 - aten::addmm: 41 - aten::alias: 14 - aten::all: 8 - aten::allclose: 1 - aten::aminmax: 4 - aten::any: 14 - aten::any.dim: 1 - aten::arange: 10 - aten::arange.start: 26 - aten::arange.start_out: 28 - aten::arange.start_step: 8 - aten::argmax: 2 - aten::as_strided: 188 - aten::as_strided_: 39 - aten::atan: 4 - aten::atleast_1d.Sequence: 2 - aten::atleast_2d.Sequence: 1 - aten::avg_pool2d: 7 - aten::batch_norm: 15 - aten::bernoulli_.float: 2 - aten::binary_cross_entropy: 13 - aten::binary_cross_entropy_backward: 12 - aten::binary_cross_entropy_with_logits: 3 - aten::binary_cross_entropy_with_logits_backward: 2 - aten::bitwise_and.Tensor: 13 - aten::bitwise_and_.Tensor: 1 - aten::bitwise_not: 13 - aten::bitwise_or_.Tensor: 1 - aten::bmm: 18 - aten::broadcast_tensors: 1 - aten::cat: 95 - aten::ceil: 4 - aten::ceil_: 1 - aten::chunk: 20 - aten::clamp: 38 - aten::clamp_: 12 - aten::clamp_min: 73 - aten::clamp_min.out: 74 - aten::clamp_min_: 4 - aten::clone: 134 - aten::coalesce: 2 - aten::conj: 1 - aten::constant_pad_nd: 14 - aten::contiguous: 139 - aten::conv1d: 12 - aten::conv2d: 7 - aten::conv_transpose2d.input: 5 - aten::convolution: 19 - aten::convolution_backward: 3 - aten::copy_: 188 - aten::copy_sparse_to_sparse_: 3 - aten::cos: 4 - aten::count_nonzero: 4 - aten::count_nonzero.dim_IntList: 4 - aten::ctc_loss.Tensor: 1 - aten::cudnn_is_acceptable: 12 - aten::cumsum: 14 - aten::dense_dim: 3 - aten::dequantize.self: 63 - aten::dequantize.tensors: 1 - aten::detach: 49 - aten::div.Scalar: 188 - aten::div.Tensor: 188 - aten::div.Tensor_mode: 8 - aten::div_.Scalar: 27 - aten::div_.Tensor: 34 - aten::dropout: 41 - aten::elu: 2 - aten::embedding: 16 - aten::embedding_backward: 4 - aten::embedding_bag.padding_idx: 2 - aten::embedding_dense_backward: 4 - aten::embedding_sparse_backward: 1 - aten::empty.memory_format: 188 - aten::empty_like: 162 - aten::empty_strided: 188 - aten::eq.Scalar: 25 - aten::eq.Tensor: 188 - aten::exp: 15 - aten::exp_: 3 - aten::expand: 63 - aten::expand_as: 17 - aten::feature_dropout: 1 - aten::fill_.Scalar: 188 - aten::flatten.using_ints: 42 - aten::flip: 1 - aten::floor: 6 - aten::floor_divide: 7 - aten::floor_divide.Scalar: 7 - aten::full: 21 - aten::full_like: 10 - aten::gather: 11 - aten::ge.Scalar: 2 - aten::gelu: 12 - aten::glu: 18 - aten::grid_sampler: 3 - aten::grid_sampler_2d: 3 - aten::gt.Scalar: 16 - aten::hardsigmoid: 3 - aten::hardsigmoid_: 2 - aten::hardswish_: 4 - aten::hardtanh: 3 - aten::hstack: 2 - aten::index.Tensor: 20 - aten::index_add_: 4 - aten::index_fill.int_Scalar: 1 - aten::index_fill_.int_Scalar: 1 - aten::index_put_: 16 - aten::index_select: 28 - aten::index_select_backward: 3 - aten::is_coalesced: 3 - aten::is_floating_point: 8 - aten::isclose: 1 - aten::isfinite: 1 - aten::isnan: 1 - aten::item: 188 - aten::layer_norm: 26 - aten::le.Scalar: 2 - aten::le.Tensor: 1 - aten::leaky_relu: 1 - aten::leaky_relu_: 5 - aten::lerp_.Tensor: 1 - aten::linear: 51 - aten::linspace: 3 - aten::linspace.out: 3 - aten::log: 15 - aten::log10: 4 - aten::log1p: 5 - aten::log_: 3 - aten::log_softmax.int: 28 - aten::logical_and: 1 - aten::logical_and.out: 2 - aten::logical_and_: 1 - aten::logit: 7 - aten::lstm.data: 8 - aten::lstm.input: 4 - aten::lt.Scalar: 8 - aten::lt.Tensor: 1 - aten::masked_fill.Scalar: 3 - aten::masked_fill_.Scalar: 18 - aten::matmul: 31 - aten::max: 27 - aten::max.dim: 31 - aten::max.other: 4 - aten::max_pool2d: 7 - aten::maximum: 4 - aten::mean: 16 - aten::mean.dim: 26 - aten::meshgrid.indexing: 2 - aten::min: 25 - aten::min.dim: 5 - aten::min.other: 4 - aten::minimum: 5 - aten::mm: 40 - aten::mul.Scalar: 31 - aten::mul.Tensor: 103 - aten::mul.out: 12 - aten::mul_.Scalar: 11 - aten::mul_.Tensor: 7 - aten::nan_to_num: 3 - aten::nan_to_num.out: 13 - aten::nan_to_num_: 10 - aten::narrow: 188 - aten::native_batch_norm: 15 - aten::native_layer_norm: 26 - aten::native_layer_norm_backward: 1 - aten::ne.Scalar: 15 - aten::ne.Tensor: 6 - aten::neg: 29 - aten::new_empty_strided: 188 - aten::nll_loss: 4 - aten::nll_loss_backward: 4 - aten::nll_loss_forward: 4 - aten::nll_loss_nd: 3 - aten::nonzero: 16 - aten::norm.Scalar: 1 - aten::norm.ScalarOpt_dim: 5 - aten::normal_: 17 - aten::one_hot: 2 - aten::ones: 188 - aten::ones_like: 25 - aten::permute: 44 - aten::pow.Tensor_Scalar: 3 - aten::q_per_channel_scales: 28 - aten::q_per_channel_zero_points: 28 - aten::q_scale: 65 - aten::q_zero_point: 85 - aten::qscheme: 85 - aten::quantile.scalar: 1 - aten::quantize_per_tensor: 84 - aten::quantize_per_tensor.tensor_qparams: 1 - aten::quantized_lstm.data: 2 - aten::quantized_max_pool2d: 3 - aten::rand: 25 - aten::randint.low: 2 - aten::randn_like: 17 - aten::random_.from: 2 - aten::reciprocal: 1 - aten::reflection_pad2d: 1 - aten::relu: 79 - aten::relu_: 4 - aten::remainder.Scalar: 2 - aten::remainder.Tensor: 2 - aten::repeat: 14 - aten::replication_pad2d: 2 - aten::requires_grad_: 2 - aten::reshape: 69 - aten::resize_: 188 - aten::resize_as_: 18 - aten::resolve_conj: 70 - aten::resolve_neg: 1 - aten::result_type.Scalar: 3 - aten::rsub.Scalar: 5 - aten::scalar_tensor: 1 - aten::scatter_.src: 6 - aten::scatter_.value: 2 - aten::scatter_add_: 10 - aten::select.int: 77 - aten::select_backward: 1 - aten::selu: 2 - aten::set_.source_Storage: 186 - aten::set_.source_Storage_storage_offset: 186 - aten::sigmoid: 90 - aten::sigmoid_: 14 - aten::sigmoid_backward: 17 - aten::sin: 4 - aten::slice.Tensor: 188 - aten::slice_backward: 4 - aten::slow_conv_transpose2d: 6 - aten::softmax.int: 63 - aten::softplus: 2 - aten::sort: 20 - aten::sparse_coo_tensor.indices: 1 - aten::sparse_dim: 3 - aten::sparse_resize_and_clear_: 1 - aten::split.Tensor: 20 - aten::sqrt: 1 - aten::squeeze: 13 - aten::squeeze.dim: 38 - aten::squeeze_.dim: 36 - aten::stack: 39 - aten::sub.Scalar: 23 - aten::sub.Tensor: 105 - aten::sub_.Scalar: 1 - aten::sub_.Tensor: 7 - aten::sum: 18 - aten::sum.IntList_out: 29 - aten::sum.dim_IntList: 41 - aten::t: 49 - aten::tanh: 40 - aten::tanh_: 14 - aten::tanh_backward: 5 - aten::tensor_split.indices: 4 - aten::thnn_conv2d: 33 - aten::threshold_backward: 17 - aten::to.device: 35 - aten::to.dtype: 188 - aten::to.dtype_layout: 184 - aten::topk: 10 - aten::transpose.int: 73 - aten::triu: 10 - aten::true_divide.Tensor: 2 - aten::trunc_: 4 - aten::type_as: 6 - aten::unbind.int: 38 - aten::unfold: 14 - aten::uniform_: 25 - aten::unique_consecutive: 2 - aten::unsafe_chunk: 14 - aten::unsafe_split.Tensor: 14 - aten::unsqueeze: 56 - aten::unsqueeze_: 31 - aten::upsample_bilinear2d: 7 - aten::upsample_bilinear2d.vec: 7 - aten::upsample_nearest2d: 31 - aten::upsample_nearest2d.vec: 27 - aten::value_selecting_reduction_backward: 3 - aten::view: 95 - aten::vstack: 1 - aten::where.ScalarOther: 4 - aten::where.self: 15 - aten::zero_: 188 - aten::zeros: 188 - aten::zeros.out: 1 - aten::zeros_like: 6 - prepacked::conv2d_clamp_prepack: 1 - prepacked::conv2d_clamp_run: 32 - prepacked::conv2d_transpose_clamp_run: 1 - prepacked::linear_clamp_run: 26 - quantized::add: 58 - quantized::add_relu: 1 + prim::RaiseException: 100 + prim::TupleIndex: 147 + prim::TupleUnpack: 235 + prim::Uninitialized: 76 + prim::abs: 1 + prim::device: 43 + prim::dtype: 43 + prim::max.float: 25 + prim::max.int: 25 + prim::max.self_int: 34 + prim::min: 20 + prim::min.float: 3 + prim::min.int: 81 + prim::min.self_int: 37 + prim::unchecked_cast: 204 + quantized::add: 90 + quantized::add_relu: 2 quantized::batch_norm2d: 1 quantized::cat: 4 - quantized::conv1d: 1 - quantized::conv2d: 4 - quantized::conv2d.new: 55 + quantized::conv1d: 5 + quantized::conv2d: 2 + quantized::conv2d.new: 86 quantized::conv2d_prepack: 14 - quantized::conv2d_relu.new: 50 - quantized::conv_prepack: 5 - quantized::conv_transpose2d: 2 - quantized::embedding_byte: 14 + quantized::conv2d_relu.new: 83 + quantized::conv_prepack: 3 + quantized::conv_transpose2d: 1 + quantized::embedding_4bit: 18 + quantized::embedding_byte: 34 quantized::hardswish: 1 - quantized::instance_norm: 1 + quantized::instance_norm: 2 quantized::leaky_relu: 2 - quantized::linear: 27 - quantized::linear_dynamic: 21 - quantized::linear_prepack: 29 - quantized::linear_relu: 2 - quantized::mul: 4 - quantized::mul.Scalar: 1 + quantized::linear: 39 + quantized::linear_dynamic: 91 + quantized::linear_dynamic_fp16: 78 + quantized::linear_prepack: 69 + quantized::linear_prepack_fp16: 67 + quantized::linear_relu: 7 + quantized::linear_unpack: 46 + quantized::linear_unpack_fp16: 46 + quantized::mul: 5 + quantized::mul.Scalar: 4 diff --git a/test/mobile/nnc/test_aot_compile.sh b/test/mobile/nnc/test_aot_compile.sh index f4387a83c4415..141a01270f891 100755 --- a/test/mobile/nnc/test_aot_compile.sh +++ b/test/mobile/nnc/test_aot_compile.sh @@ -23,4 +23,5 @@ test_aot_model_compiler() { popd } -test_aot_model_compiler +# Temporarily disable the test since NNC backend is no longer available. +# test_aot_model_compiler diff --git a/test/mobile/nnc/test_nnc_backend.cpp b/test/mobile/nnc/test_nnc_backend.cpp index 35bf60f2cca79..704a1c4eb4caf 100644 --- a/test/mobile/nnc/test_nnc_backend.cpp +++ b/test/mobile/nnc/test_nnc_backend.cpp @@ -65,7 +65,7 @@ REGISTER_NNC_KERNEL( "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN", add_kernel) -TEST(NNCBackendTest, AOTCompileThenExecute) { +TEST(DISABLED_NNCBackendTest, AOTCompileThenExecute) { torch::jit::Module m("m"); auto param = torch::ones({1}); m.register_parameter("param", param, false); diff --git a/test/onnx/debug_embed_params.py b/test/onnx/debug_embed_params.py index 7fe40a5906dcd..3bee953dd4e1d 100644 --- a/test/onnx/debug_embed_params.py +++ b/test/onnx/debug_embed_params.py @@ -1,9 +1,9 @@ import sys -import onnx -from test_pytorch_common import flatten - import caffe2.python.onnx.backend as c2 + +import onnx +import pytorch_test_common import torch import torch.jit from torch.autograd import Variable @@ -41,7 +41,9 @@ def run_embed_params(proto, model, input, state_dict=None, use_gpu=True): parameters = list(model.state_dict().values()) W = {} - for k, v in zip(model_def.graph.input, flatten((input, parameters))): + for k, v in zip( + model_def.graph.input, pytorch_test_common.flatten((input, parameters)) + ): if isinstance(v, Variable): W[k.name] = v.data.cpu().numpy() else: diff --git a/test/onnx/expect/TestOperators.test_baddbmm.expect b/test/onnx/expect/TestOperators.test_baddbmm.expect index fc7eb0f8295e6..058770e803269 100644 --- a/test/onnx/expect/TestOperators.test_baddbmm.expect +++ b/test/onnx/expect/TestOperators.test_baddbmm.expect @@ -9,38 +9,54 @@ graph { name: "MatMul_0" op_type: "MatMul" } + node { + output: "onnx::Mul_11" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\200?" + } + type: TENSOR + } + } node { input: "onnx::Mul_5" input: "onnx::Mul_11" output: "onnx::Add_7" - name: "Mul_1" + name: "Mul_2" op_type: "Mul" } + node { + output: "onnx::Mul_12" + name: "Constant_3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\200?" + } + type: TENSOR + } + } node { input: "onnx::Mul_0" input: "onnx::Mul_12" output: "onnx::Add_9" - name: "Mul_2" + name: "Mul_4" op_type: "Mul" } node { input: "onnx::Add_7" input: "onnx::Add_9" output: "10" - name: "Add_3" + name: "Add_5" op_type: "Add" } name: "torch_jit" - initializer { - data_type: 1 - name: "onnx::Mul_11" - raw_data: "\000\000\200?" - } - initializer { - data_type: 1 - name: "onnx::Mul_12" - raw_data: "\000\000\200?" - } input { name: "onnx::Mul_0" type { diff --git a/test/onnx/expect/TestOperators.test_bitshift.expect b/test/onnx/expect/TestOperators.test_bitshift.expect index 10199d03efcd1..3bc677e2abd0b 100644 --- a/test/onnx/expect/TestOperators.test_bitshift.expect +++ b/test/onnx/expect/TestOperators.test_bitshift.expect @@ -2,11 +2,24 @@ ir_version: 6 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::BitShift_7" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 2 + raw_data: "\001" + } + type: TENSOR + } + } node { input: "onnx::BitShift_0" input: "onnx::BitShift_7" output: "3" - name: "BitShift_0" + name: "BitShift_1" op_type: "BitShift" attribute { name: "direction" @@ -14,11 +27,24 @@ graph { type: STRING } } + node { + output: "onnx::BitShift_8" + name: "Constant_2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 2 + raw_data: "\002" + } + type: TENSOR + } + } node { input: "onnx::BitShift_0" input: "onnx::BitShift_8" output: "6" - name: "BitShift_1" + name: "BitShift_3" op_type: "BitShift" attribute { name: "direction" @@ -27,16 +53,6 @@ graph { } } name: "torch_jit" - initializer { - data_type: 2 - name: "onnx::BitShift_7" - raw_data: "\001" - } - initializer { - data_type: 2 - name: "onnx::BitShift_8" - raw_data: "\002" - } input { name: "onnx::BitShift_0" type { diff --git a/test/onnx/expect/TestOperators.test_clip.expect b/test/onnx/expect/TestOperators.test_clip.expect index 81606851e7851..67dad133acec7 100644 --- a/test/onnx/expect/TestOperators.test_clip.expect +++ b/test/onnx/expect/TestOperators.test_clip.expect @@ -2,25 +2,41 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Clip_6" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\000\277" + } + type: TENSOR + } + } + node { + output: "onnx::Clip_7" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\000?" + } + type: TENSOR + } + } node { input: "onnx::Clip_0" input: "onnx::Clip_6" input: "onnx::Clip_7" output: "5" - name: "Clip_0" + name: "Clip_2" op_type: "Clip" } name: "torch_jit" - initializer { - data_type: 1 - name: "onnx::Clip_6" - raw_data: "\000\000\000\277" - } - initializer { - data_type: 1 - name: "onnx::Clip_7" - raw_data: "\000\000\000?" - } input { name: "onnx::Clip_0" type { diff --git a/test/onnx/expect/TestOperators.test_clip_max.expect b/test/onnx/expect/TestOperators.test_clip_max.expect index 7fdb350daa041..23b001cd4e7e6 100644 --- a/test/onnx/expect/TestOperators.test_clip_max.expect +++ b/test/onnx/expect/TestOperators.test_clip_max.expect @@ -2,20 +2,28 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Clip_7" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\315\314\314=" + } + type: TENSOR + } + } node { input: "onnx::Clip_0" input: "" input: "onnx::Clip_7" output: "5" - name: "Clip_0" + name: "Clip_1" op_type: "Clip" } name: "torch_jit" - initializer { - data_type: 1 - name: "onnx::Clip_7" - raw_data: "\315\314\314=" - } input { name: "onnx::Clip_0" type { diff --git a/test/onnx/expect/TestOperators.test_clip_min.expect b/test/onnx/expect/TestOperators.test_clip_min.expect index 8b260da8f90ba..3bd4c47ef8583 100644 --- a/test/onnx/expect/TestOperators.test_clip_min.expect +++ b/test/onnx/expect/TestOperators.test_clip_min.expect @@ -2,20 +2,28 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Clip_7" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\315\314\314\275" + } + type: TENSOR + } + } node { input: "onnx::Clip_0" input: "onnx::Clip_7" input: "" output: "5" - name: "Clip_0" + name: "Clip_1" op_type: "Clip" } name: "torch_jit" - initializer { - data_type: 1 - name: "onnx::Clip_7" - raw_data: "\315\314\314\275" - } input { name: "onnx::Clip_0" type { diff --git a/test/onnx/expect/TestOperators.test_embedding_bags.expect b/test/onnx/expect/TestOperators.test_embedding_bags.expect index eb4a94b75590b..b2b11d5bd4f0f 100644 --- a/test/onnx/expect/TestOperators.test_embedding_bags.expect +++ b/test/onnx/expect/TestOperators.test_embedding_bags.expect @@ -3,9 +3,22 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - output: "5" + output: "onnx::Loop_33" name: "Constant_0" op_type: "Constant" + attribute { + name: "value" + t { + data_type: 9 + raw_data: "\001" + } + type: TENSOR + } + } + node { + output: "5" + name: "Constant_1" + op_type: "Constant" attribute { name: "value" t { @@ -19,12 +32,12 @@ graph { node { input: "input" output: "onnx::Gather_6" - name: "Shape_1" + name: "Shape_2" op_type: "Shape" } node { output: "onnx::Gather_7" - name: "Constant_2" + name: "Constant_3" op_type: "Constant" attribute { name: "value" @@ -39,7 +52,7 @@ graph { input: "onnx::Gather_6" input: "onnx::Gather_7" output: "onnx::Unsqueeze_8" - name: "Gather_3" + name: "Gather_4" op_type: "Gather" attribute { name: "axis" @@ -49,7 +62,7 @@ graph { } node { output: "onnx::Unsqueeze_9" - name: "Constant_4" + name: "Constant_5" op_type: "Constant" attribute { name: "value" @@ -65,14 +78,14 @@ graph { input: "onnx::Unsqueeze_8" input: "onnx::Unsqueeze_9" output: "onnx::Concat_10" - name: "Unsqueeze_5" + name: "Unsqueeze_6" op_type: "Unsqueeze" } node { input: "offsets" input: "onnx::Concat_10" output: "onnx::Slice_11" - name: "Concat_6" + name: "Concat_7" op_type: "Concat" attribute { name: "axis" @@ -82,7 +95,7 @@ graph { } node { output: "onnx::Slice_12" - name: "Constant_7" + name: "Constant_8" op_type: "Constant" attribute { name: "value" @@ -96,7 +109,7 @@ graph { } node { output: "onnx::Slice_13" - name: "Constant_8" + name: "Constant_9" op_type: "Constant" attribute { name: "value" @@ -110,7 +123,7 @@ graph { } node { output: "onnx::Slice_14" - name: "Constant_9" + name: "Constant_10" op_type: "Constant" attribute { name: "value" @@ -124,7 +137,7 @@ graph { } node { output: "onnx::Slice_15" - name: "Constant_10" + name: "Constant_11" op_type: "Constant" attribute { name: "value" @@ -143,18 +156,18 @@ graph { input: "onnx::Slice_12" input: "onnx::Slice_15" output: "onnx::Shape_16" - name: "Slice_11" + name: "Slice_12" op_type: "Slice" } node { input: "onnx::Shape_16" output: "onnx::Gather_17" - name: "Shape_12" + name: "Shape_13" op_type: "Shape" } node { output: "onnx::Gather_18" - name: "Constant_13" + name: "Constant_14" op_type: "Constant" attribute { name: "value" @@ -169,7 +182,7 @@ graph { input: "onnx::Gather_17" input: "onnx::Gather_18" output: "onnx::Loop_19" - name: "Gather_14" + name: "Gather_15" op_type: "Gather" attribute { name: "axis" @@ -181,7 +194,7 @@ graph { input: "onnx::Loop_19" input: "onnx::Loop_33" output: "20" - name: "Loop_15" + name: "Loop_16" op_type: "Loop" attribute { name: "body" @@ -190,7 +203,7 @@ graph { input: "onnx::Slice_11" input: "21" output: "23" - name: "Gather_16" + name: "Gather_17" op_type: "Gather" attribute { name: "axis" @@ -202,7 +215,7 @@ graph { input: "onnx::Shape_16" input: "21" output: "24" - name: "Gather_17" + name: "Gather_18" op_type: "Gather" attribute { name: "axis" @@ -212,7 +225,7 @@ graph { } node { output: "25" - name: "Constant_18" + name: "Constant_19" op_type: "Constant" attribute { name: "value" @@ -228,12 +241,12 @@ graph { input: "23" input: "25" output: "26" - name: "Unsqueeze_19" + name: "Unsqueeze_20" op_type: "Unsqueeze" } node { output: "27" - name: "Constant_20" + name: "Constant_21" op_type: "Constant" attribute { name: "value" @@ -249,7 +262,7 @@ graph { input: "24" input: "27" output: "28" - name: "Unsqueeze_21" + name: "Unsqueeze_22" op_type: "Unsqueeze" } node { @@ -258,14 +271,14 @@ graph { input: "28" input: "5" output: "29" - name: "Slice_22" + name: "Slice_23" op_type: "Slice" } node { input: "weight" input: "29" output: "30" - name: "Gather_23" + name: "Gather_24" op_type: "Gather" attribute { name: "axis" @@ -276,7 +289,7 @@ graph { node { input: "30" output: "31" - name: "ReduceMean_24" + name: "ReduceMean_25" op_type: "ReduceMean" attribute { name: "axes" @@ -292,7 +305,7 @@ graph { node { input: "onnx::Loop_33" output: "32" - name: "Cast_25" + name: "Cast_26" op_type: "Cast" attribute { name: "to" @@ -356,11 +369,6 @@ graph { name: "weight" raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?\027X\277\277\253\206a?\354\335\226\277L\032o\277\251J\021\277\311\360\215\276\312\274\013\300\252\320\273>\220\"p?\267\020\000\222\233\314?\334\360?\275|t\303\277\214\351\000\300\3065\302\2775\206\306>X\251\227\277x\2160?U^\251?d\221\350?\237F.?\rp9?9X\004=/c\324\277SL\360\277\'\274\332\356\226\275\211\035\241>*\271\204\277>\025W>\036K\035?\036\233\200=\035\313\250\276\017\003\346\277\374p_?\313WD?!\006\351\275\232\\q\277\230\007A?" } - initializer { - data_type: 9 - name: "onnx::Loop_33" - raw_data: "\001" - } input { name: "input" type { @@ -403,16 +411,6 @@ graph { } } } - input { - name: "onnx::Loop_33" - type { - tensor_type { - elem_type: 9 - shape { - } - } - } - } output { name: "20" type { diff --git a/test/onnx/expect/TestOperators.test_empty_like_opset7.expect b/test/onnx/expect/TestOperators.test_empty_like_opset7.expect deleted file mode 100644 index 504162493a003..0000000000000 --- a/test/onnx/expect/TestOperators.test_empty_like_opset7.expect +++ /dev/null @@ -1,68 +0,0 @@ -ir_version: 3 -producer_name: "pytorch" -producer_version: "CURRENT_VERSION" -graph { - node { - input: "onnx::Shape_0" - output: "onnx::ConstantFill_1" - name: "Shape_0" - op_type: "Shape" - } - node { - input: "onnx::ConstantFill_1" - output: "2" - name: "ConstantFill_1" - op_type: "ConstantFill" - attribute { - name: "dtype" - i: 1 - type: INT - } - attribute { - name: "input_as_shape" - i: 1 - type: INT - } - attribute { - name: "value" - f: 0 - type: FLOAT - } - } - name: "torch_jit" - input { - name: "onnx::Shape_0" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 5 - } - dim { - dim_value: 8 - } - } - } - } - } - output { - name: "2" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 5 - } - dim { - dim_value: 8 - } - } - } - } - } -} -opset_import { - version: 7 -} diff --git a/test/onnx/expect/TestOperators.test_expand.expect b/test/onnx/expect/TestOperators.test_expand.expect index 6634173a0a63a..36804d0062e53 100644 --- a/test/onnx/expect/TestOperators.test_expand.expect +++ b/test/onnx/expect/TestOperators.test_expand.expect @@ -16,10 +16,24 @@ graph { type: TENSOR } } + node { + output: "onnx::ConstantOfShape_10" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\003\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::ConstantOfShape_10" output: "onnx::Mul_3" - name: "ConstantOfShape_1" + name: "ConstantOfShape_2" op_type: "ConstantOfShape" attribute { name: "value" @@ -33,7 +47,7 @@ graph { } node { output: "onnx::Mul_4" - name: "Constant_2" + name: "Constant_3" op_type: "Constant" attribute { name: "value" @@ -48,12 +62,12 @@ graph { input: "onnx::Mul_3" input: "onnx::Mul_4" output: "onnx::Equal_5" - name: "Mul_3" + name: "Mul_4" op_type: "Mul" } node { output: "onnx::Equal_6" - name: "Constant_4" + name: "Constant_5" op_type: "Constant" attribute { name: "value" @@ -69,7 +83,7 @@ graph { input: "onnx::Equal_6" input: "onnx::Equal_5" output: "onnx::Where_7" - name: "Equal_5" + name: "Equal_6" op_type: "Equal" } node { @@ -77,23 +91,17 @@ graph { input: "onnx::Mul_3" input: "onnx::Where_1" output: "onnx::Expand_8" - name: "Where_6" + name: "Where_7" op_type: "Where" } node { input: "onnx::Expand_0" input: "onnx::Expand_8" output: "9" - name: "Expand_7" + name: "Expand_8" op_type: "Expand" } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::ConstantOfShape_10" - raw_data: "\003\000\000\000\000\000\000\000" - } input { name: "onnx::Expand_0" type { diff --git a/test/onnx/expect/TestOperators.test_mul_bool.expect b/test/onnx/expect/TestOperators.test_mul_bool.expect new file mode 100644 index 0000000000000..455967e543cbf --- /dev/null +++ b/test/onnx/expect/TestOperators.test_mul_bool.expect @@ -0,0 +1,55 @@ +ir_version: 7 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "onnx::And_0" + input: "onnx::And_1" + output: "2" + name: "And_0" + op_type: "And" + } + name: "torch_jit" + input { + name: "onnx::And_0" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "onnx::And_1" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/test/onnx/expect/TestOperators.test_mul_fp_bool.expect b/test/onnx/expect/TestOperators.test_mul_fp_bool.expect new file mode 100644 index 0000000000000..dee222fbb1fac --- /dev/null +++ b/test/onnx/expect/TestOperators.test_mul_fp_bool.expect @@ -0,0 +1,66 @@ +ir_version: 7 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "onnx::Cast_1" + output: "onnx::Mul_2" + name: "Cast_0" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + } + node { + input: "onnx::Mul_0" + input: "onnx::Mul_2" + output: "3" + name: "Mul_1" + op_type: "Mul" + } + name: "torch_jit" + input { + name: "onnx::Mul_0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "onnx::Cast_1" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/test/onnx/expect/TestOperators.test_narrow.expect b/test/onnx/expect/TestOperators.test_narrow.expect index a7b13c89a646c..b25611d9374a1 100644 --- a/test/onnx/expect/TestOperators.test_narrow.expect +++ b/test/onnx/expect/TestOperators.test_narrow.expect @@ -2,34 +2,58 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Slice_14" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\000\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "onnx::Slice_15" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\002\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "onnx::Slice_16" + name: "Constant_2" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\000\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::Slice_0" input: "onnx::Slice_14" input: "onnx::Slice_15" input: "onnx::Slice_16" output: "12" - name: "Slice_0" + name: "Slice_3" op_type: "Slice" } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::Slice_14" - raw_data: "\000\000\000\000\000\000\000\000" - } - initializer { - dims: 1 - data_type: 7 - name: "onnx::Slice_15" - raw_data: "\002\000\000\000\000\000\000\000" - } - initializer { - dims: 1 - data_type: 7 - name: "onnx::Slice_16" - raw_data: "\000\000\000\000\000\000\000\000" - } input { name: "onnx::Slice_0" type { diff --git a/test/onnx/expect/TestOperators.test_pad.expect b/test/onnx/expect/TestOperators.test_pad.expect index 293877ab834aa..e4554ae3181ae 100644 --- a/test/onnx/expect/TestOperators.test_pad.expect +++ b/test/onnx/expect/TestOperators.test_pad.expect @@ -2,10 +2,38 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::ConstantOfShape_27" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\004\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "onnx::Concat_28" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 4 + data_type: 7 + raw_data: "\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::ConstantOfShape_27" output: "onnx::Concat_10" - name: "ConstantOfShape_0" + name: "ConstantOfShape_2" op_type: "ConstantOfShape" attribute { name: "value" @@ -21,7 +49,7 @@ graph { input: "onnx::Concat_28" input: "onnx::Concat_10" output: "onnx::Reshape_11" - name: "Concat_1" + name: "Concat_3" op_type: "Concat" attribute { name: "axis" @@ -31,7 +59,7 @@ graph { } node { output: "onnx::Reshape_12" - name: "Constant_2" + name: "Constant_4" op_type: "Constant" attribute { name: "value" @@ -47,12 +75,12 @@ graph { input: "onnx::Reshape_11" input: "onnx::Reshape_12" output: "onnx::Slice_13" - name: "Reshape_3" + name: "Reshape_5" op_type: "Reshape" } node { output: "onnx::Slice_14" - name: "Constant_4" + name: "Constant_6" op_type: "Constant" attribute { name: "value" @@ -66,7 +94,7 @@ graph { } node { output: "onnx::Slice_15" - name: "Constant_5" + name: "Constant_7" op_type: "Constant" attribute { name: "value" @@ -80,7 +108,7 @@ graph { } node { output: "onnx::Slice_16" - name: "Constant_6" + name: "Constant_8" op_type: "Constant" attribute { name: "value" @@ -94,7 +122,7 @@ graph { } node { output: "onnx::Slice_17" - name: "Constant_7" + name: "Constant_9" op_type: "Constant" attribute { name: "value" @@ -113,13 +141,13 @@ graph { input: "onnx::Slice_14" input: "onnx::Slice_17" output: "onnx::Transpose_18" - name: "Slice_8" + name: "Slice_10" op_type: "Slice" } node { input: "onnx::Transpose_18" output: "onnx::Reshape_19" - name: "Transpose_9" + name: "Transpose_11" op_type: "Transpose" attribute { name: "perm" @@ -130,7 +158,7 @@ graph { } node { output: "onnx::Reshape_20" - name: "Constant_10" + name: "Constant_12" op_type: "Constant" attribute { name: "value" @@ -146,13 +174,13 @@ graph { input: "onnx::Reshape_19" input: "onnx::Reshape_20" output: "onnx::Cast_21" - name: "Reshape_11" + name: "Reshape_13" op_type: "Reshape" } node { input: "onnx::Cast_21" output: "onnx::Pad_22" - name: "Cast_12" + name: "Cast_14" op_type: "Cast" attribute { name: "to" @@ -164,7 +192,7 @@ graph { input: "onnx::Pad_0" input: "onnx::Pad_22" output: "23" - name: "Pad_13" + name: "Pad_15" op_type: "Pad" attribute { name: "mode" @@ -173,18 +201,6 @@ graph { } } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::ConstantOfShape_27" - raw_data: "\004\000\000\000\000\000\000\000" - } - initializer { - dims: 4 - data_type: 7 - name: "onnx::Concat_28" - raw_data: "\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" - } input { name: "onnx::Pad_0" type { diff --git a/test/onnx/expect/TestOperators.test_repeat.expect b/test/onnx/expect/TestOperators.test_repeat.expect index 5206bce0d88ff..76203f189e3ec 100644 --- a/test/onnx/expect/TestOperators.test_repeat.expect +++ b/test/onnx/expect/TestOperators.test_repeat.expect @@ -16,10 +16,24 @@ graph { type: TENSOR } } + node { + output: "onnx::ConstantOfShape_6" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\004\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::ConstantOfShape_6" output: "onnx::Expand_3" - name: "ConstantOfShape_1" + name: "ConstantOfShape_2" op_type: "ConstantOfShape" attribute { name: "value" @@ -35,23 +49,17 @@ graph { input: "onnx::Expand_0" input: "onnx::Expand_3" output: "onnx::Tile_4" - name: "Expand_2" + name: "Expand_3" op_type: "Expand" } node { input: "onnx::Tile_4" input: "onnx::Tile_1" output: "5" - name: "Tile_3" + name: "Tile_4" op_type: "Tile" } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::ConstantOfShape_6" - raw_data: "\004\000\000\000\000\000\000\000" - } input { name: "onnx::Expand_0" type { diff --git a/test/onnx/expect/TestOperators.test_repeat_dim_overflow.expect b/test/onnx/expect/TestOperators.test_repeat_dim_overflow.expect index 2dbb3a436d42b..cdbadc5f43eb7 100644 --- a/test/onnx/expect/TestOperators.test_repeat_dim_overflow.expect +++ b/test/onnx/expect/TestOperators.test_repeat_dim_overflow.expect @@ -16,10 +16,24 @@ graph { type: TENSOR } } + node { + output: "onnx::ConstantOfShape_6" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\004\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::ConstantOfShape_6" output: "onnx::Expand_3" - name: "ConstantOfShape_1" + name: "ConstantOfShape_2" op_type: "ConstantOfShape" attribute { name: "value" @@ -35,23 +49,17 @@ graph { input: "onnx::Expand_0" input: "onnx::Expand_3" output: "onnx::Tile_4" - name: "Expand_2" + name: "Expand_3" op_type: "Expand" } node { input: "onnx::Tile_4" input: "onnx::Tile_1" output: "5" - name: "Tile_3" + name: "Tile_4" op_type: "Tile" } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::ConstantOfShape_6" - raw_data: "\004\000\000\000\000\000\000\000" - } input { name: "onnx::Expand_0" type { diff --git a/test/onnx/expect/TestOperators.test_shape_value_map.expect b/test/onnx/expect/TestOperators.test_shape_value_map.expect index 174551f9a7c5b..92e5be56549dc 100644 --- a/test/onnx/expect/TestOperators.test_shape_value_map.expect +++ b/test/onnx/expect/TestOperators.test_shape_value_map.expect @@ -54,13 +54,55 @@ graph { name: "Unsqueeze_4" op_type: "Unsqueeze" } + node { + output: "onnx::Concat_26" + name: "Constant_5" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "onnx::Concat_27" + name: "Constant_6" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\002\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + output: "onnx::Concat_28" + name: "Constant_7" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\377" + } + type: TENSOR + } + } node { input: "onnx::Concat_8" input: "onnx::Concat_26" input: "onnx::Concat_27" input: "onnx::Concat_28" output: "onnx::Reshape_15" - name: "Concat_5" + name: "Concat_8" op_type: "Concat" attribute { name: "axis" @@ -72,13 +114,13 @@ graph { input: "x" input: "onnx::Reshape_15" output: "onnx::Transpose_16" - name: "Reshape_6" + name: "Reshape_9" op_type: "Reshape" } node { input: "onnx::Transpose_16" output: "x.1" - name: "Transpose_7" + name: "Transpose_10" op_type: "Transpose" attribute { name: "perm" @@ -92,7 +134,7 @@ graph { node { input: "x.1" output: "onnx::Reshape_18" - name: "Softmax_8" + name: "Softmax_11" op_type: "Softmax" attribute { name: "axis" @@ -102,7 +144,7 @@ graph { } node { output: "onnx::Unsqueeze_20" - name: "Constant_9" + name: "Constant_12" op_type: "Constant" attribute { name: "value" @@ -118,14 +160,28 @@ graph { input: "onnx::Unsqueeze_3" input: "onnx::Unsqueeze_20" output: "onnx::Concat_21" - name: "Unsqueeze_10" + name: "Unsqueeze_13" op_type: "Unsqueeze" } + node { + output: "onnx::Concat_29" + name: "Constant_14" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\377\377\377\377\377\377\377\377" + } + type: TENSOR + } + } node { input: "onnx::Concat_21" input: "onnx::Concat_29" output: "onnx::Reshape_24" - name: "Concat_11" + name: "Concat_15" op_type: "Concat" attribute { name: "axis" @@ -137,34 +193,10 @@ graph { input: "onnx::Reshape_18" input: "onnx::Reshape_24" output: "25" - name: "Reshape_12" + name: "Reshape_16" op_type: "Reshape" } name: "torch_jit" - initializer { - dims: 1 - data_type: 7 - name: "onnx::Concat_26" - raw_data: "\001\000\000\000\000\000\000\000" - } - initializer { - dims: 1 - data_type: 7 - name: "onnx::Concat_27" - raw_data: "\002\000\000\000\000\000\000\000" - } - initializer { - dims: 1 - data_type: 7 - name: "onnx::Concat_28" - raw_data: "\377\377\377\377\377\377\377\377" - } - initializer { - dims: 1 - data_type: 7 - name: "onnx::Concat_29" - raw_data: "\377\377\377\377\377\377\377\377" - } input { name: "x" type { diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect index 2489a21e59ea9..89bda18c735ca 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect @@ -2,12 +2,26 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Resize_6" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 4 + data_type: 1 + raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@" + } + type: TENSOR + } + } node { input: "x" input: "" input: "onnx::Resize_6" output: "5" - name: "Resize_0" + name: "Resize_1" op_type: "Resize" attribute { name: "coordinate_transformation_mode" @@ -31,12 +45,6 @@ graph { } } name: "torch_jit" - initializer { - dims: 4 - data_type: 1 - name: "onnx::Resize_6" - raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@" - } input { name: "x" type { diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect index 2489a21e59ea9..89bda18c735ca 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect @@ -2,12 +2,26 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Resize_6" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 4 + data_type: 1 + raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@" + } + type: TENSOR + } + } node { input: "x" input: "" input: "onnx::Resize_6" output: "5" - name: "Resize_0" + name: "Resize_1" op_type: "Resize" attribute { name: "coordinate_transformation_mode" @@ -31,12 +45,6 @@ graph { } } name: "torch_jit" - initializer { - dims: 4 - data_type: 1 - name: "onnx::Resize_6" - raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@" - } input { name: "x" type { diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_size.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_size.expect index 6848d3d312ae0..53219a4045086 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_size.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_size.expect @@ -59,11 +59,25 @@ graph { name: "Slice_4" op_type: "Slice" } + node { + output: "onnx::Concat_12" + name: "Constant_5" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 2 + data_type: 7 + raw_data: "\020\000\000\000\000\000\000\000\020\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::Concat_6" input: "onnx::Concat_12" output: "onnx::Resize_8" - name: "Concat_5" + name: "Concat_6" op_type: "Concat" attribute { name: "axis" @@ -77,7 +91,7 @@ graph { input: "" input: "onnx::Resize_8" output: "11" - name: "Resize_6" + name: "Resize_7" op_type: "Resize" attribute { name: "coordinate_transformation_mode" @@ -101,12 +115,6 @@ graph { } } name: "torch_jit" - initializer { - dims: 2 - data_type: 7 - name: "onnx::Concat_12" - raw_data: "\020\000\000\000\000\000\000\000\020\000\000\000\000\000\000\000" - } input { name: "x" type { diff --git a/test/onnx/expect/TestOperators.test_view_flatten.expect b/test/onnx/expect/TestOperators.test_view_flatten.expect index ac814160d5bd1..444906b80e476 100644 --- a/test/onnx/expect/TestOperators.test_view_flatten.expect +++ b/test/onnx/expect/TestOperators.test_view_flatten.expect @@ -2,20 +2,28 @@ ir_version: 7 producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { + node { + output: "onnx::Reshape_11" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 2 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000\030\000\000\000\000\000\000\000" + } + type: TENSOR + } + } node { input: "onnx::Reshape_0" input: "onnx::Reshape_11" output: "8" - name: "Reshape_0" + name: "Reshape_1" op_type: "Reshape" } name: "torch_jit" - initializer { - dims: 2 - data_type: 7 - name: "onnx::Reshape_11" - raw_data: "\001\000\000\000\000\000\000\000\030\000\000\000\000\000\000\000" - } input { name: "onnx::Reshape_0" type { diff --git a/test/onnx/export_onnx_tests_filter.py b/test/onnx/export_onnx_tests_filter.py index cf8afafd9b733..868f72fddc342 100644 --- a/test/onnx/export_onnx_tests_filter.py +++ b/test/onnx/export_onnx_tests_filter.py @@ -6,7 +6,7 @@ import google.protobuf.text_format import onnx.backend.test -import test_onnx_common +import onnx_test_common from test_caffe2_common import run_generated_test from torch.testing._internal.common_device_type import get_all_device_types @@ -20,7 +20,7 @@ def collect_generated_testcases( - root_dir=test_onnx_common.pytorch_converted_dir, + root_dir=onnx_test_common.pytorch_converted_dir, verbose=False, fail_dir=None, expect=True, @@ -95,7 +95,7 @@ def collect_generated_testcases( collect_generated_testcases(verbose=verbose, fail_dir=fail_dir, expect=expect) # We already generate the expect files for test_operators.py. collect_generated_testcases( - root_dir=test_onnx_common.pytorch_operator_dir, + root_dir=onnx_test_common.pytorch_operator_dir, verbose=verbose, fail_dir=fail_dir, expect=False, diff --git a/test/onnx/export_onnx_tests_generator.py b/test/onnx/export_onnx_tests_generator.py index ef728cead0d92..39b18d1e9fa6c 100644 --- a/test/onnx/export_onnx_tests_generator.py +++ b/test/onnx/export_onnx_tests_generator.py @@ -4,11 +4,11 @@ import traceback import onnx -import test_onnx_common -from onnx import numpy_helper -from test_nn import new_module_tests +import onnx_test_common import torch +from onnx import numpy_helper +from test_nn import new_module_tests from torch.autograd import Variable from torch.testing._internal.common_nn import module_tests @@ -110,7 +110,7 @@ def convert_tests(testcases, sets=1): onnx_model = onnx.load_from_string(f.getvalue()) onnx.checker.check_model(onnx_model) onnx.helper.strip_doc_string(onnx_model) - output_dir = os.path.join(test_onnx_common.pytorch_converted_dir, test_name) + output_dir = os.path.join(onnx_test_common.pytorch_converted_dir, test_name) if os.path.exists(output_dir): shutil.rmtree(output_dir) @@ -151,7 +151,7 @@ def convert_tests(testcases, sets=1): ) print( "PyTorch converted cases are stored in {}.".format( - test_onnx_common.pytorch_converted_dir + onnx_test_common.pytorch_converted_dir ) ) print_stats(FunctionalModule_nums, nn_module) diff --git a/test/onnx/test_onnx_common.py b/test/onnx/onnx_test_common.py similarity index 81% rename from test/onnx/test_onnx_common.py rename to test/onnx/onnx_test_common.py index 86400f61d861c..9c617c3434bd0 100644 --- a/test/onnx/test_onnx_common.py +++ b/test/onnx/onnx_test_common.py @@ -3,13 +3,15 @@ from __future__ import annotations import os -import unittest +import random +from typing import Any, Mapping, Type import numpy as np import onnxruntime import torch from torch.onnx import _constants, verification +from torch.testing._internal import common_utils onnx_model_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -31,24 +33,39 @@ _ORT_PROVIDERS = ("CPUExecutionProvider",) -def _run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): +def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): kwargs["ort_providers"] = _ORT_PROVIDERS kwargs["opset_version"] = test_suite.opset_version kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs return verification.verify(*args, **kwargs) -class _TestONNXRuntime(unittest.TestCase): +def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): + """Combine class name with the parameterized arguments. + + This function is passed to `parameterized.parameterized_class` as the + `class_name_func` argument. + """ + suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items()) + return f"{cls.__name__}_{suffix}" + + +def set_rng_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + +class _TestONNXRuntime(common_utils.TestCase): opset_version = _constants.onnx_default_opset keep_initializers_as_inputs = True # For IR version 3 type export. is_script = False def setUp(self): - torch.manual_seed(0) + set_rng_seed(0) onnxruntime.set_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - np.random.seed(seed=0) os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0" self.is_script_test_enabled = True @@ -69,12 +86,12 @@ def run_test( input_names=None, output_names=None, fixed_batch_size=False, - training=None, + training=torch.onnx.TrainingMode.EVAL, remained_onnx_input_idx=None, verbose=False, ): def _run_test(m, remained_onnx_input_idx, flatten=True): - return _run_model_test( + return run_model_test( self, m, input_args=input_args, diff --git a/test/onnx/test_pytorch_common.py b/test/onnx/pytorch_test_common.py similarity index 90% rename from test/onnx/test_pytorch_common.py rename to test/onnx/pytorch_test_common.py index 04b12bf189309..77bdd28ad4e45 100644 --- a/test/onnx/test_pytorch_common.py +++ b/test/onnx/pytorch_test_common.py @@ -6,13 +6,11 @@ import unittest import torch -import torch.autograd.function as function +from torch.autograd import function pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.insert(-1, pytorch_test_dir) -from torch.testing._internal.common_utils import * # noqa: F401,F403 - torch.set_default_tensor_type("torch.FloatTensor") BATCH_SIZE = 2 @@ -95,21 +93,21 @@ def wrapper(self, *args, **kwargs): return skip_dec -# skips tests for scripting. -def skipScriptTest(min_opset_version=float("inf")): - def script_dec(func): +def skipTraceTest(min_opset_version=float("inf")): + def skip_dec(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - self.is_script_test_enabled = self.opset_version >= min_opset_version + self.is_trace_test_enabled = self.opset_version >= min_opset_version + if not self.is_trace_test_enabled and not self.is_script: + raise unittest.SkipTest("Skip verify test for torch trace") return func(self, *args, **kwargs) return wrapper - return script_dec + return skip_dec -# TODO(#75630): replace `skipScriptTest` with this to parametrize test class. -def skipScriptTest_New(min_opset_version=float("inf")): +def skipScriptTest(min_opset_version=float("inf")): def skip_dec(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): diff --git a/test/onnx/test_caffe2_common.py b/test/onnx/test_caffe2_common.py index e85f4b8aef365..45f5a44ff6b09 100644 --- a/test/onnx/test_caffe2_common.py +++ b/test/onnx/test_caffe2_common.py @@ -3,12 +3,12 @@ import glob import os +import caffe2.python.onnx.backend as c2 + import numpy as np import onnx.backend.test from onnx import numpy_helper -import caffe2.python.onnx.backend as c2 - def load_tensor_as_numpy_array(f): tensor = onnx.TensorProto() diff --git a/test/onnx/test_custom_ops.py b/test/onnx/test_custom_ops.py index d1ddd95446922..db5ddfd001140 100644 --- a/test/onnx/test_custom_ops.py +++ b/test/onnx/test_custom_ops.py @@ -1,18 +1,17 @@ # Owner(s): ["module: onnx"] +import caffe2.python.onnx.backend as c2 import numpy as np import onnx -from test_pytorch_onnx_caffe2 import do_export -from test_pytorch_onnx_onnxruntime import run_model_test - -import caffe2.python.onnx.backend as c2 +import onnx_test_common import torch import torch.utils.cpp_extension -from torch.onnx.symbolic_helper import _unimplemented -from test_pytorch_common import TestCase, run_tests +from test_pytorch_onnx_caffe2 import do_export +from torch.onnx import symbolic_helper +from torch.testing._internal import common_utils -class TestCustomOps(TestCase): +class TestCustomOps(common_utils.TestCase): def test_custom_add(self): op_source = """ #include @@ -56,7 +55,7 @@ def symbolic_custom_add(g, self, other): np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy()) -class TestCustomAutogradFunction(TestCase): +class TestCustomAutogradFunction(common_utils.TestCase): opset_version = 9 keep_initializers_as_inputs = False onnx_shape_inference = True @@ -83,7 +82,7 @@ def forward(self, x): x = torch.randn(2, 3, 4, requires_grad=True) model = MyModule() - run_model_test(self, model, input_args=(x,)) + onnx_test_common.run_model_test(self, model, input_args=(x,)) def test_register_custom_op(self): class MyClip(torch.autograd.Function): @@ -117,7 +116,9 @@ def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs): elif name == "MyRelu": return g.op("Relu", args[0], outputs=n.outputsSize()) else: - return _unimplemented("prim::PythonOp", "unknown node kind: " + name) + return symbolic_helper._unimplemented( + "prim::PythonOp", "unknown node kind: " + name + ) from torch.onnx import register_custom_op_symbolic @@ -125,10 +126,10 @@ def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs): x = torch.randn(2, 3, 4, requires_grad=True) model = MyModule() - run_model_test(self, model, input_args=(x,)) + onnx_test_common.run_model_test(self, model, input_args=(x,)) -class TestExportAsContribOps(TestCase): +class TestExportAsContribOps(common_utils.TestCase): opset_version = 14 keep_initializers_as_inputs = False onnx_shape_inference = True @@ -159,8 +160,8 @@ def symbolic_custom_gelu(g, input, approximate): x = torch.randn(3, 3, 4, requires_grad=True) model = torch.jit.script(M()) - run_model_test(self, model, input_args=(x,)) + onnx_test_common.run_model_test(self, model, input_args=(x,)) if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py index 609cf9c8f5556..3e9044ffb06c7 100644 --- a/test/onnx/test_models.py +++ b/test/onnx/test_models.py @@ -2,26 +2,22 @@ import unittest +import caffe2.python.onnx.backend as backend +import torch + from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2 from model_defs.mnist import MNIST -from model_defs.op_test import ( - ConcatNet, - DummyNet, - FakeQuantNet, - PermuteNet, - PReluNet, -) +from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet from model_defs.squeezenet import SqueezeNet from model_defs.srresnet import SRResNet from model_defs.super_resolution import SuperResolutionNet -from test_pytorch_common import ( - TestCase, - run_tests, - skipIfNoLapack, - skipIfUnsupportedMinOpsetVersion, - skipScriptTest, -) +from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest +from torch import quantization +from torch.autograd import Variable +from torch.onnx import OperatorExportTypes +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNoLapack from torchvision.models import shufflenet_v2_x1_0 from torchvision.models.alexnet import alexnet from torchvision.models.densenet import densenet121 @@ -35,14 +31,6 @@ from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18 from verify import verify -import caffe2.python.onnx.backend as backend -import torch -import torch.onnx -import torch.onnx.utils -from torch import quantization -from torch.autograd import Variable -from torch.onnx import OperatorExportTypes - if torch.cuda.is_available(): def toC(x): @@ -57,7 +45,7 @@ def toC(x): BATCH_SIZE = 2 -class TestModels(TestCase): +class TestModels(common_utils.TestCase): opset_version = 9 # Caffe2 doesn't support the default. keep_initializers_as_inputs = False @@ -296,4 +284,4 @@ def test_r2plus1d_18_video(self): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index f7813df4a9232..20d81c5b59685 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -3,16 +3,18 @@ import os import unittest from collections import OrderedDict -from typing import Any, List, Mapping, Tuple, Type +from typing import List, Mapping, Tuple +import onnx_test_common import parameterized import PIL -import test_onnx_common + +import torch import torchvision +from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest from test_models import TestModels -from test_pytorch_common import TestCase, run_tests, skipIfUnsupportedMinOpsetVersion -from test_pytorch_common import skipScriptTest_New as skipScriptTest -from test_pytorch_onnx_onnxruntime import run_model_test +from torch import nn +from torch.testing._internal import common_utils from torchvision import ops from torchvision.models.detection import ( faster_rcnn, @@ -24,9 +26,6 @@ transform, ) -import torch -from torch import nn - def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None): opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12, 13, 14] @@ -34,38 +33,43 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None): for opset_version in opset_versions: self.opset_version = opset_version self.onnx_shape_inference = True - run_model_test(self, model, input_args=inputs, rtol=rtol, atol=atol) + onnx_test_common.run_model_test( + self, model, input_args=inputs, rtol=rtol, atol=atol + ) if self.is_script_test_enabled and opset_version > 11: script_model = torch.jit.script(model) - run_model_test(self, script_model, input_args=inputs, rtol=rtol, atol=atol) + onnx_test_common.run_model_test( + self, script_model, input_args=inputs, rtol=rtol, atol=atol + ) TestModels = type( "TestModels", - (TestCase,), - dict(TestModels.__dict__, is_script_test_enabled=False, exportTest=exportTest), + (common_utils.TestCase,), + dict( + TestModels.__dict__, + is_script_test_enabled=False, + is_script=False, + exportTest=exportTest, + ), ) # model tests for scripting with new JIT APIs and shape inference TestModels_new_jit_API = type( "TestModels_new_jit_API", - (TestCase,), + (common_utils.TestCase,), dict( TestModels.__dict__, exportTest=exportTest, is_script_test_enabled=True, + is_script=True, onnx_shape_inference=True, ), ) -def class_name_func(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): - suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items()) - return f"{cls.__name__}_{suffix}" - - def _get_image(rel_path: str, size: Tuple[int, int]) -> torch.Tensor: data_dir = os.path.join(os.path.dirname(__file__), "assets") path = os.path.join(data_dir, *rel_path.split("/")) @@ -176,10 +180,10 @@ def _init_test_roi_heads_faster_rcnn(): @parameterized.parameterized_class( ("is_script",), - ([True, False],), - class_name_func=class_name_func, + [(True,), (False,)], + class_name_func=onnx_test_common.parameterize_class_name, ) -class TestModelsONNXRuntime(test_onnx_common._TestONNXRuntime): +class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime): @skipIfUnsupportedMinOpsetVersion(11) @skipScriptTest() # Faster RCNN model is not scriptable def test_faster_rcnn(self): @@ -414,4 +418,4 @@ def test_shufflenet_v2_dynamic_axes(self): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_models_quantized_onnxruntime.py b/test/onnx/test_models_quantized_onnxruntime.py new file mode 100644 index 0000000000000..2c582e7d447e4 --- /dev/null +++ b/test/onnx/test_models_quantized_onnxruntime.py @@ -0,0 +1,97 @@ +# Owner(s): ["module: onnx"] + +import os +import unittest + +import onnx_test_common +import parameterized +import PIL + +import torch +import torchvision +from torch import nn + + +def _get_test_image_tensor(): + data_dir = os.path.join(os.path.dirname(__file__), "assets") + img_path = os.path.join(data_dir, "grace_hopper_517x606.jpg") + input_image = PIL.Image.open(img_path) + # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/ + preprocess = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(256), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + return preprocess(input_image).unsqueeze(0) + + +# Due to precision error from quantization, check only that the top prediction matches. +class _TopPredictor(nn.Module): + def __init__(self, base_model): + super().__init__() + self.base_model = base_model + + def forward(self, x): + x = self.base_model(x) + _, topk_id = torch.topk(x[0], 1) + return topk_id + + +# TODO: All torchvision quantized model test can be written as single parameterized test case, +# after per-parameter test decoration is supported via #79979, or after they are all enabled, +# whichever is first. +@parameterized.parameterized_class( + ("is_script",), + [(True,), (False,)], + class_name_func=onnx_test_common.parameterize_class_name, +) +class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime): + def run_test(self, model, inputs, *args, **kwargs): + model = _TopPredictor(model) + return super().run_test(model, inputs, *args, **kwargs) + + def test_mobilenet_v3(self): + model = torchvision.models.quantization.mobilenet_v3_large( + pretrained=True, quantize=True + ) + self.run_test(model, _get_test_image_tensor()) + + @unittest.skip("quantized::cat not supported") + def test_inception_v3(self): + model = torchvision.models.quantization.inception_v3( + pretrained=True, quantize=True + ) + self.run_test(model, _get_test_image_tensor()) + + @unittest.skip("quantized::cat not supported") + def test_googlenet(self): + model = torchvision.models.quantization.googlenet( + pretrained=True, quantize=True + ) + self.run_test(model, _get_test_image_tensor()) + + @unittest.skip("quantized::cat not supported") + def test_shufflenet_v2_x0_5(self): + model = torchvision.models.quantization.shufflenet_v2_x0_5( + pretrained=True, quantize=True + ) + self.run_test(model, _get_test_image_tensor()) + + def test_resnet18(self): + model = torchvision.models.quantization.resnet18(pretrained=True, quantize=True) + self.run_test(model, _get_test_image_tensor()) + + def test_resnet50(self): + model = torchvision.models.quantization.resnet50(pretrained=True, quantize=True) + self.run_test(model, _get_test_image_tensor()) + + def test_resnext101_32x8d(self): + model = torchvision.models.quantization.resnext101_32x8d( + pretrained=True, quantize=True + ) + self.run_test(model, _get_test_image_tensor()) diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index f836e4783ebec..6bce330e23557 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -4,13 +4,13 @@ import itertools import onnx -from test_pytorch_common import TestCase, run_tests import torch import torch.onnx from torch.nn import Module from torch.onnx import producer_name, producer_version from torch.onnx._globals import GLOBALS +from torch.testing._internal import common_utils def check_onnx_opset_operator( @@ -70,7 +70,7 @@ def check_onnx_opsets_operator( check_onnx_opset_operator(model, ops[opset_version], opset_version) -class TestONNXOpset(TestCase): +class TestONNXOpset(common_utils.TestCase): def test_opset_fallback(self): class MyModule(Module): def forward(self, x): @@ -96,7 +96,8 @@ def forward(self, x): } ] ops_10 = [ - {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]} + {"op_name": "Constant"}, + {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]}, ] ops = {9: ops_9, 10: ops_10} x = torch.arange(1.0, 6.0, requires_grad=True) @@ -253,11 +254,13 @@ def forward(self, x): {"op_name": "Shape"}, {"op_name": "Constant"}, {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]}, + {"op_name": "Constant"}, { "op_name": "Unsqueeze", "attributes": [{"name": "axes", "i": 0, "type": 7}], }, {"op_name": "Constant"}, + {"op_name": "Constant"}, {"op_name": "Slice", "attributes": []}, ] ops = {10: ops_10} @@ -426,6 +429,7 @@ def forward(self, x): ) ops_9 = [ + {"op_name": "Constant"}, {"op_name": "Shape"}, {"op_name": "Slice"}, {"op_name": "Cast"}, @@ -438,6 +442,7 @@ def forward(self, x): }, ] ops_10 = [ + {"op_name": "Constant"}, {"op_name": "Shape"}, {"op_name": "Constant"}, {"op_name": "Constant"}, @@ -519,4 +524,4 @@ def forward(self, x, grid, mode, padding_mode, align_corers): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index ff890b3f926bd..bd66c38ff5ecf 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -8,34 +8,27 @@ import shutil import tempfile -from test_pytorch_common import ( +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.onnx + +from pytorch_test_common import ( BATCH_SIZE, + flatten, RNN_HIDDEN_SIZE, RNN_INPUT_SIZE, RNN_SEQUENCE_LENGTH, - TestCase, - flatten, - run_tests, - skipIfCaffe2, - skipIfNoLapack, ) - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.onnx -import torch.testing._internal.common_utils as common from torch.autograd import Function, Variable -from torch.nn import Module, functional -from torch.onnx import ( - register_custom_op_symbolic, - unregister_custom_op_symbolic, -) +from torch.nn import functional, Module from torch.onnx.symbolic_helper import ( _get_tensor_dim_size, _get_tensor_sizes, parse_args, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack """Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data] --no-onnx: no onnx python dependence @@ -77,7 +70,7 @@ def forward(self, *args): return self.f(*itertools.chain(args, self.params)) -class TestOperators(TestCase): +class TestOperators(common_utils.TestCase): def assertONNX(self, f, args, params=None, **kwargs): if params is None: params = () @@ -94,7 +87,7 @@ def assertONNX(self, f, args, params=None, **kwargs): import onnx import onnx.checker import onnx.numpy_helper - import test_onnx_common + import onnx_test_common model_def = onnx.ModelProto.FromString(onnx_model_pb) onnx.checker.check_model(model_def) @@ -102,7 +95,7 @@ def assertONNX(self, f, args, params=None, **kwargs): test_function = inspect.stack()[1][0].f_code.co_name test_name = test_function[0:4] + "_operator" + test_function[4:] output_dir = os.path.join( - test_onnx_common.pytorch_operator_dir, test_name + onnx_test_common.pytorch_operator_dir, test_name ) # Assume: # 1) the old test should be delete before the test. @@ -202,6 +195,16 @@ def test_rsub(self): x = torch.randn(2, 3, requires_grad=True).double() self.assertONNX(lambda x: 1 - x, (x,)) + def test_mul_bool(self): + x = torch.tensor([True, False, True, False]) + y = torch.tensor([True, True, False, False]) + self.assertONNX(lambda x, y: torch.mul(x, y), (x, y)) + + def test_mul_fp_bool(self): + x = torch.tensor([9.4, 1.7, 3.6]) + y = torch.tensor([True, True, False]) + self.assertONNX(lambda x, y: torch.mul(x, y), (x, y)) + def test_transpose(self): x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True) self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x) @@ -737,10 +740,6 @@ def test_empty_like(self): x = torch.randn(5, 8, requires_grad=True) self.assertONNX(lambda x: torch.empty_like(x), x) - def test_empty_like_opset7(self): - x = torch.randn(5, 8, requires_grad=True) - self.assertONNX(lambda x: torch.empty_like(x), x, opset_version=7) - def test_zeros_like(self): x = torch.randn(5, 8, requires_grad=True) self.assertONNX(lambda x: torch.zeros_like(x), x) @@ -1159,7 +1158,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): ) return output - register_custom_op_symbolic("::embedding", embedding, _onnx_opset_version) + torch.onnx.register_custom_op_symbolic( + "::embedding", embedding, _onnx_opset_version + ) class Model(torch.nn.Module): def __init__(self): @@ -1176,7 +1177,7 @@ def forward(self, x, y): y = torch.randn(1, 8) self.assertONNX(model, (x, y), opset_version=_onnx_opset_version) - unregister_custom_op_symbolic("::embedding", _onnx_opset_version) + torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. @skipIfCaffe2 @@ -1208,7 +1209,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): output.setType(output_type) return output - register_custom_op_symbolic("::embedding", embedding, _onnx_opset_version) + torch.onnx.register_custom_op_symbolic( + "::embedding", embedding, _onnx_opset_version + ) class Model(torch.nn.Module): def __init__(self): @@ -1233,7 +1236,7 @@ def forward(self, x, y): operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, ) - unregister_custom_op_symbolic("::embedding", _onnx_opset_version) + torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) # Without shapeValueMap, the onnx graph looks like: # graph(%0 : Float(*, 1, 128, 1, strides=[128, 128, 1, 1], requires_grad=0, device=cpu)): @@ -1277,19 +1280,19 @@ def forward(self, x): if __name__ == "__main__": no_onnx_dep_flag = "--no-onnx" - _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS - if no_onnx_dep_flag in common.UNITTEST_ARGS: - common.UNITTEST_ARGS.remove(no_onnx_dep_flag) + _onnx_dep = no_onnx_dep_flag not in common_utils.UNITTEST_ARGS + if no_onnx_dep_flag in common_utils.UNITTEST_ARGS: + common_utils.UNITTEST_ARGS.remove(no_onnx_dep_flag) onnx_test_flag = "--produce-onnx-test-data" - _onnx_test = onnx_test_flag in common.UNITTEST_ARGS - if onnx_test_flag in common.UNITTEST_ARGS: - common.UNITTEST_ARGS.remove(onnx_test_flag) + _onnx_test = onnx_test_flag in common_utils.UNITTEST_ARGS + if onnx_test_flag in common_utils.UNITTEST_ARGS: + common_utils.UNITTEST_ARGS.remove(onnx_test_flag) if _onnx_test: _onnx_dep = True - import test_onnx_common + import onnx_test_common for d in glob.glob( - os.path.join(test_onnx_common.pytorch_operator_dir, "test_operator_*") + os.path.join(onnx_test_common.pytorch_operator_dir, "test_operator_*") ): shutil.rmtree(d) - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_helper.py b/test/onnx/test_pytorch_helper.py index eeb5f88f17f8b..362841d8bf90f 100644 --- a/test/onnx/test_pytorch_helper.py +++ b/test/onnx/test_pytorch_helper.py @@ -4,17 +4,18 @@ import unittest import numpy as np -from pytorch_helper import PyTorchModule -from test_pytorch_common import skipIfNoLapack, run_tests, TestCase import torch.nn.init as init import torch.onnx from caffe2.python.core import workspace from caffe2.python.model_helper import ModelHelper +from pytorch_helper import PyTorchModule from torch import nn +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNoLapack -class TestCaffe2Backend(TestCase): +class TestCaffe2Backend(common_utils.TestCase): @skipIfNoLapack @unittest.skip("test broken because Lapack was always missing.") def test_helper(self): @@ -67,4 +68,4 @@ def _initialize_weights(self): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index f0cabae76b549..a850779e6df9d 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -2,9 +2,9 @@ import onnxruntime import torch -from torch._C import parse_ir +from pytorch_test_common import skipIfNoCuda from torch.onnx import verification -from test_pytorch_common import TestCase, run_tests +from torch.testing._internal import common_utils def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version): @@ -16,14 +16,14 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version): It also does not interact with actual PyTorch modules nor PyTorch tensor inputs. """ - from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version - from torch.onnx.utils import _optimize_graph # Shape inference is required because some ops' symbolic functions # generate sub-graphs based on inputs' types. - _set_onnx_shape_inference(True) - _set_opset_version(opset_version) - graph = _optimize_graph(graph, operator_export_type, params_dict={}) + torch.onnx.symbolic_helper._set_onnx_shape_inference(True) + torch.onnx.symbolic_helper._set_opset_version(opset_version) + graph = torch.onnx.utils._optimize_graph( + graph, operator_export_type, params_dict={} + ) proto, _, _, _ = graph._export_onnx( {}, opset_version, @@ -52,7 +52,7 @@ class _TestJITIRToONNX: ort_providers = ["CPUExecutionProvider"] def run_test(self, graph_ir, example_inputs): - graph = parse_ir(graph_ir) + graph = torch._C.parse_ir(graph_ir) jit_outs = torch._C._jit_interpret_graph(graph, example_inputs) onnx_proto = _jit_graph_to_onnx_model( @@ -79,12 +79,92 @@ def test_example_ir(self): b = torch.randn(2, 3) self.run_test(graph_ir, (a, b)) + def test_add_sub_with_graph_inputs(self): + for op in ["add", "sub", "rsub"]: + graph_ir = f""" + graph(%1 : Float(2, 3), + %2 : Float(2, 3), + %3 : int): + %4 : Float(2, 3) = aten::{op}(%1, %2, %3) + return (%4) + """ + a = torch.randn(2, 3) + b = torch.randn(2, 3) + self.run_test(graph_ir, (a, b, 2)) + + def test_native_layer_norm(self): + graph_ir = """ + graph(%x : Float(2, 3, 2), + %w : Float(3, 2), + %b : Float(3, 2)): + %5 : int = prim::Constant[value=3]() + %6 : int = prim::Constant[value=2]() + %7 : int[] = prim::ListConstruct(%5, %6) + %10 : float = prim::Constant[value=1.0000000000000001e-05]() + %11 : Float(2, 3, 2), %12 : Float(2, 1, 1), %13 : Float(2, 1, 1) = aten::native_layer_norm(%x, %7, %w, %b, %10) + return (%11, %12, %13) + """ + x = torch.randn(2, 3, 2) + w = torch.randn(3, 2) + b = torch.randn(3, 2) + self.run_test(graph_ir, (x, w, b)) + + def test_convolution(self): + graph_ir = """ + graph(%1 : Tensor, + %2 : Tensor): + %3 : NoneType = prim::Constant() + %4 : int[] = prim::Constant[value=[1, 1]]() + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : bool = prim::Constant[value=0]() + %7 : int = prim::Constant[value=1]() + %8 : Tensor = aten::convolution(%1, %2, %3, %4, %5, %4, %6, %5, %7) + return (%8) + """ + x = torch.randn(8, 1, 5, 5) + w = torch.randn(4, 1, 3, 3) + self.run_test(graph_ir, (x, w)) + + def test_log_softmax(self): + graph_ir = """ + graph(%x: Tensor): + %half_to_float: bool = prim::Constant[value=0]() + %dim: int = prim::Constant[value=1]() + %y = aten::_log_softmax(%x, %dim, %half_to_float) + return (%y) + """ + x = torch.randn(5, 2) + self.run_test(graph_ir, (x,)) + + @skipIfNoCuda + def test_log_softmax_half_to_float(self): + graph_ir = """ + graph(%x: Tensor): + %half_to_float: bool = prim::Constant[value=1]() + %dim: int = prim::Constant[value=1]() + %y = aten::_log_softmax(%x, %dim, %half_to_float) + return (%y) + """ + x = torch.randn(5, 2).half().to("cuda") + self.run_test(graph_ir, (x,)) + + def test_native_dropout(self): + graph_ir = """ + graph(%1 : Float(2, 3)): + %2 : float = prim::Constant[value=0.0]() + %training : bool = prim::Constant[value=1]() + %3 : Tensor, %4 : Tensor = aten::native_dropout(%1, %2, %training) + return (%3, %4) + """ + a = torch.randn(2, 3) + self.run_test(graph_ir, (a,)) + def MakeTestCase(opset_version: int) -> type: name = f"TestJITIRToONNX_opset{opset_version}" return type( str(name), - (TestCase,), + (common_utils.TestCase,), dict(_TestJITIRToONNX.__dict__, opset_version=opset_version), ) @@ -92,4 +172,4 @@ def MakeTestCase(opset_version: int) -> type: TestJITIRToONNX_opset14 = MakeTestCase(14) if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 10d419fd4455c..141d3683171f6 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -6,11 +6,20 @@ import unittest from typing import Tuple +import caffe2.python.onnx.backend as c2 + import model_defs.dcgan as dcgan import model_defs.word_language_model as word_language_model import numpy as np import onnx +import torch.onnx +import torch.onnx.operators +import torch.utils.model_zoo as model_zoo import verify +from caffe2.python.operator_test.torch_integration_test import ( + create_bbox_transform_inputs, + generate_rois_rotated, +) from debug_embed_params import run_embed_params from model_defs.lstm_flattening_result import LstmFlatteningResult from model_defs.mnist import MNIST @@ -18,20 +27,23 @@ from model_defs.squeezenet import SqueezeNet from model_defs.srresnet import SRResNet from model_defs.super_resolution import SuperResolutionNet -from test_pytorch_common import ( +from pytorch_test_common import ( BATCH_SIZE, RNN_BATCH_SIZE, RNN_HIDDEN_SIZE, RNN_INPUT_SIZE, RNN_SEQUENCE_LENGTH, skipIfNoCuda, - skipIfNoLapack, skipIfTravis, skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion, - TestCase, - run_tests, ) +from torch import nn +from torch.autograd import function, Variable +from torch.nn.utils import rnn as rnn_utils +from torch.onnx import ExportTypes +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNoLapack # Import various models for testing from torchvision.models.alexnet import alexnet @@ -40,19 +52,6 @@ from torchvision.models.resnet import resnet50 from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn -import caffe2.python.onnx.backend as c2 -import torch.onnx -import torch.onnx.operators -import torch.utils.model_zoo as model_zoo -from caffe2.python.operator_test.torch_integration_test import ( - create_bbox_transform_inputs, - generate_rois_rotated, -) -from torch import nn -from torch.autograd import Variable, function -from torch.nn.utils import rnn as rnn_utils -from torch.onnx import ExportTypes - skip = unittest.skip @@ -130,13 +129,13 @@ def do_export(model, inputs, *args, **kwargs): } -class TestCaffe2Backend_opset9(TestCase): +class TestCaffe2Backend_opset9(common_utils.TestCase): opset_version = 9 embed_params = False def setUp(self): # the following should ideally be super().setUp(), https://github.com/pytorch/pytorch/issues/79630 - TestCase.setUp(self) + common_utils.TestCase.setUp(self) torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) @@ -3199,44 +3198,44 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9 = type( "TestCaffe2BackendEmbed_opset9", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) # opset 7 tests TestCaffe2Backend_opset7 = type( "TestCaffe2Backend_opset7", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=7), ) TestCaffe2BackendEmbed_opset7 = type( "TestCaffe2BackendEmbed_opset7", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7), ) # opset 8 tests TestCaffe2Backend_opset8 = type( "TestCaffe2Backend_opset8", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=8), ) TestCaffe2BackendEmbed_opset8 = type( "TestCaffe2BackendEmbed_opset8", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8), ) # opset 10 tests TestCaffe2Backend_opset10 = type( "TestCaffe2Backend_opset10", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=10), ) TestCaffe2BackendEmbed_opset10 = type( "TestCaffe2BackendEmbed_opset10", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10), ) @@ -3244,9 +3243,9 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9_new_jit_API = type( "TestCaffe2BackendEmbed_opset9_new_jit_API", - (TestCase,), + (common_utils.TestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_caffe2_quantized.py b/test/onnx/test_pytorch_onnx_caffe2_quantized.py index 5aacba78df4a4..6b6e813640087 100644 --- a/test/onnx/test_pytorch_onnx_caffe2_quantized.py +++ b/test/onnx/test_pytorch_onnx_caffe2_quantized.py @@ -1,17 +1,18 @@ # Owner(s): ["module: unknown"] import io -import numpy as np -import onnx import caffe2.python.onnx.backend as c2 + +import numpy as np +import onnx import torch.nn as nn import torch.nn.quantized as nnq import torch.onnx -from test_pytorch_common import TestCase, run_tests +from torch.testing._internal import common_utils -class TestQuantizedOps(TestCase): +class TestQuantizedOps(common_utils.TestCase): def generic_test( self, model, sample_inputs, input_names=None, decimal=3, relaxed_check=False ): @@ -377,4 +378,4 @@ def forward(self, x): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 59ffa5fe6f17e..ccde15a745c25 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -6,13 +6,15 @@ import io import itertools import unittest -from typing import Dict, Optional, Type, Callable, Iterable, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union import onnx +import onnx.numpy_helper import torch +import torch.nn.functional as F from torch import Tensor -from torch.onnx import symbolic_helper, utils, symbolic_registry +from torch.onnx import symbolic_helper, symbolic_registry, utils from torch.onnx._globals import GLOBALS from torch.testing._internal import common_utils @@ -134,27 +136,6 @@ def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int) if attr.name in ("then_branch", "else_branch"): self.assertEqual(expected_output_type, attr.g.output[0].type) - def test_uninitialized_optional(self): - class Module(torch.nn.Module): - def forward(self, y: Optional[Tensor]) -> Optional[Tensor]: - if y is not None: - if y.shape[1] < 5: - if y.size(0) == 1: - y = y + 4 - else: - return y - return y - - y = torch.ones((3, 4), dtype=torch.int) - torch.onnx.export( - torch.jit.script(Module()), - y, - io.BytesIO(), - opset_version=15, - dynamic_axes={"y": {0: "y0", 1: "y1"}}, - input_names=["y"], - ) - class TestONNXExport(common_utils.TestCase): def test_fuse_addmm(self): @@ -575,6 +556,280 @@ def cast_device_cpu_string(src: torch.Tensor) -> torch.Tensor: self._helper_test_to_(cast_device_cpu_string) + def test_script_custom_class_error(self): + class BoxCoder: + def __init__(self, bbox_xform_clip: float) -> None: + self.bbox_xform_clip = bbox_xform_clip + + def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: + boxes = torch.cat(boxes, dim=0) + pred_ctr_x = ( + torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) + * boxes[:, 2] + ) + return pred_ctr_x + + class MyModule(torch.nn.Module): + __annotations__ = { + "box_coder": BoxCoder, + } + + def __init__(self): + super().__init__() + self.box_coder = BoxCoder(1.4) + + def forward(self, box_regression: Tensor, proposals: List[Tensor]): + return self.box_coder.decode(box_regression, proposals) + + model = torch.jit.script(MyModule()) + box_regression = torch.randn([4, 4]) + proposal = [torch.randn(2, 4), torch.randn(2, 4)] + + with self.assertRaises(RuntimeError) as cm: + onnx_model = io.BytesIO() + torch.onnx.export( + model, + (box_regression, proposal), + onnx_model, + ) + + def test_initializer_sequence(self): + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = MyModule(3, 4, 10) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.randn(32, 3) + f = io.BytesIO() + torch.onnx._export(test_model, (x,), f, do_constant_folding=False) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert actual_list == state_dict_list, ( + "Initializers' sequence is not as same as state_dict(). Expected: (" + + ", ".join(state_dict_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + assert actual_list == named_params_list, ( + "Initializers' sequence is not as same as named_parameters(). Expected: (" + + ", ".join(named_params_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + + def test_initializer_sequence_script_model(self): + def list_is_expected(short_list, long_list) -> bool: + if len(short_list) > len(long_list): + return False + + for i in range(len(short_list)): + if short_list[i] not in long_list[i]: + return False + + return True + + def loop(x, y): + for i in range(int(y)): + x = x + i + return x + + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x, y): + x = loop(x, y) + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = torch.jit.script(MyModule(3, 4, 10)) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.ones(2, 3, dtype=torch.float) + y = torch.tensor(5, dtype=torch.long) + f = io.BytesIO() + + torch.onnx.export(test_model, (x, y), f, do_constant_folding=False) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert list_is_expected(state_dict_list, actual_list), ( + "ScriptModel - Initializers' sequence is not as same as state_dict(). Expected: (" + + ", ".join(state_dict_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + assert list_is_expected(named_params_list, actual_list), ( + "ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" + + ", ".join(named_params_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + + def test_onnx_checker_invalid_graph(self): + class CustomAddModule(torch.nn.Module): + def forward(self, x, y): + return torch.add(x, y) + + def symbolic_custom_invalid_add(g, input, other, alpha=None): + return g.op("Add", input, other, invalid_attr_i=1) + + torch.onnx.register_custom_op_symbolic("::add", symbolic_custom_invalid_add, 1) + + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + + test_model = CustomAddModule() + f = io.BytesIO() + + try: + with self.assertRaises(torch.onnx.errors.CheckerError): + torch.onnx.export(test_model, (x, y), f) + finally: + torch.onnx.unregister_custom_op_symbolic("::add", 1) + + self.assertTrue(f.getvalue(), "ONNX graph was not exported.") + loaded_model = onnx.load_from_string(f.getvalue()) + + def test_shape_value_map(self): + class RSoftMax(torch.nn.Module): + def __init__(self, radix, cardinality): + super().__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + return x + + radix = 2 + cardinality = 1 + x = torch.randn(10, 1, 128, 1) + f = io.BytesIO() + torch.onnx.export( + RSoftMax(radix, cardinality), + (x,), + f, + input_names=["x"], + dynamic_axes={"x": [0]}, + ) + loaded_model = onnx.load_from_string(f.getvalue()) + self.assertEqual( + loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128 + ) + + def test_onnx_proto_checker(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 2 * x + + x = torch.randn(1, 2, 3, requires_grad=True) + f = io.BytesIO() + torch.onnx.export(Model(), x, f) + model = onnx.load(f) + model.ir_version = 0 + + def check_proto(): + torch._C._check_onnx_proto(model.SerializeToString()) + + self.assertRaises(RuntimeError, check_proto) + + def test_maintain_dynamic_shapes_of_unreliable_nodes(self): + def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs): + return g.op("com.microsoft::PythonOp") + + torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1) + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1) + + # necessay parameters for transformer embeddings + hidden_size = 48 + max_position_embeddings = 32 + batch_size = 2 + + # issue found that autograd.function making downstream + # node unreliable but with static shape. The issue was first + # discovered with using Apex FusedLayerNorm in Transformers + class CustomLayerNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, embedding): + layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + return layer_norm(embedding) + + class EmbeddingModule(torch.nn.Module): + def forward( + self, + embeddings=None, + ): + embedding_output = CustomLayerNorm.apply(embeddings) + query = embedding_output.transpose(0, 1) + target_len, batch_size, embedding_dim = query.size() + # Reshape is used for consuming batch_size, and if it is static, + # this will be a Constant node in the graph + query = query.reshape(target_len, batch_size, embedding_dim) + return query + + embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size) + + f = io.BytesIO() + torch.onnx.export( + EmbeddingModule().eval(), + (embeddings,), + f, + input_names=["embeddings"], + dynamic_axes={ + "embeddings": { + 0: "batch_size", + 1: "max_position_embeddings", + 2: "hidden_size", + } + }, + custom_opsets={"com.microsoft": 1}, + ) + model = onnx.load(io.BytesIO(f.getvalue())) + + # If there is a constant node with dim=3 and max_position_embeddings, + # batch_size, hidden_size as shape, it means the shape becomes static. + # Normally, with dynamic batch size, this constant node should not exist. + const_node = [n for n in model.graph.node if n.op_type == "Constant"] + self.assertNotEqual(len(const_node), 0) + for node in const_node: + for a in node.attribute: + if a.name == "value": + shape = onnx.numpy_helper.to_array(a.t) + self.assertNotEqual( + shape.tolist(), + [max_position_embeddings, batch_size, hidden_size], + ) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index ccd4fdc1dfb17..f3763df682bcf 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -5,75 +5,39 @@ import io import itertools import os -import random import unittest from collections import OrderedDict from typing import Dict, List, Optional, Tuple, Union -import model_defs.word_language_model as word_language_model import numpy as np -import onnx -import onnxruntime +import onnx_test_common +import parameterized + +import torch import torchvision -from model_defs.lstm_flattening_result import ( - LstmFlatteningResultWithoutSeqLength, - LstmFlatteningResultWithSeqLength, -) -from model_defs.rnn_model_with_packed_sequence import ( - RnnModelWithPackedSequence, - RnnModelWithPackedSequenceWithoutState, - RnnModelWithPackedSequenceWithState, +from model_defs import ( + lstm_flattening_result, + rnn_model_with_packed_sequence, + word_language_model, ) -from test_pytorch_common import ( +from pytorch_test_common import ( BATCH_SIZE, RNN_BATCH_SIZE, RNN_HIDDEN_SIZE, RNN_INPUT_SIZE, RNN_SEQUENCE_LENGTH, - run_tests, - skipIfNoLapack, + skipForAllOpsetVersions, skipIfUnsupportedMaxOpsetVersion, skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion, skipScriptTest, - TestCase, + skipTraceTest, ) -from torchvision import ops -from torchvision.models.detection.image_list import ImageList -from torchvision.models.detection.rpn import ( - AnchorGenerator, - RegionProposalNetwork, - RPNHead, -) -from torchvision.models.detection.transform import GeneralizedRCNNTransform - -import torch -import torch.nn.functional as F -import torch.onnx.verification as verification from torch import Tensor from torch.nn.utils import rnn as rnn_utils -from torch.nn.utils.rnn import PackedSequence -from torch.onnx import ( - register_custom_op_symbolic, - unregister_custom_op_symbolic, -) -from torch.onnx.symbolic_helper import _unimplemented - -_ORT_PROVIDERS = ["CPUExecutionProvider"] - - -def run_model_test(test_suite: Union[_TestONNXRuntime, TestCase], *args, **kwargs): - kwargs["ort_providers"] = _ORT_PROVIDERS - kwargs["opset_version"] = test_suite.opset_version - kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs - return verification.verify(*args, **kwargs) - - -def run_model_test_with_external_data( - test_suite: Union[_TestONNXRuntime, TestCase], *args, **kwargs -): - kwargs["use_external_data"] = True - return run_model_test(test_suite, *args, **kwargs) +from torch.onnx import verification +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNoLapack def _init_test_generalized_rcnn_transform(): @@ -81,16 +45,22 @@ def _init_test_generalized_rcnn_transform(): max_size = 200 image_mean = [0.485, 0.456, 0.406] image_std = [0.229, 0.224, 0.225] - transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + transform = torchvision.models.detection.transform.GeneralizedRCNNTransform( + min_size, max_size, image_mean, image_std + ) return transform def _init_test_rpn(): anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + rpn_anchor_generator = torchvision.models.detection.rpn.AnchorGenerator( + anchor_sizes, aspect_ratios + ) out_channels = 256 - rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) + rpn_head = torchvision.models.detection.rpn.RPNHead( + out_channels, rpn_anchor_generator.num_anchors_per_location()[0] + ) rpn_fg_iou_thresh = 0.7 rpn_bg_iou_thresh = 0.3 rpn_batch_size_per_image = 256 @@ -100,7 +70,7 @@ def _init_test_rpn(): rpn_nms_thresh = 0.7 rpn_score_thresh = 0.0 - rpn = RegionProposalNetwork( + rpn = torchvision.models.detection.rpn.RegionProposalNetwork( rpn_anchor_generator, rpn_head, rpn_fg_iou_thresh, @@ -119,7 +89,7 @@ def _construct_tensor_for_quantization_test( shape: Tuple[int, ...], offset: Optional[Union[int, float]] = None, max_val: Optional[Union[int, float]] = None, -) -> torch.Tensor: +) -> Tensor: """Helper function to generate weights and test inputs in a deterministic way. Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated @@ -141,90 +111,42 @@ def _construct_tensor_for_quantization_test( return tensor -def set_rng_seed(seed): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -class _TestONNXRuntime: - """Abstract base class for test cases. - - Intentionally not a sub-class of unittest.TestCase so that unittest / pytest - don't run it directly. unitest.TestCase is mixed in as another base class when - creating concrete sub-types. See MakeTestCase(). - """ - - opset_version = -1 # Sub-classes must override - keep_initializers_as_inputs = True # For IR version 3 type export. - - def setUp(self): - set_rng_seed(0) - onnxruntime.set_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0" - self.is_script_test_enabled = True - - # The exported ONNX model may have less inputs than the pytorch model because of const folding. - # This mostly happens in unit test, where we widely use torch.size or torch.shape. - # So the output is only dependent on the input shape, not value. - # remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model. - def run_test( - self, - model, - input_args, - input_kwargs=None, - rtol=1e-3, - atol=1e-7, - do_constant_folding=True, - dynamic_axes=None, - additional_test_inputs=None, - input_names=None, - output_names=None, - fixed_batch_size=False, - training=torch.onnx.TrainingMode.EVAL, - remained_onnx_input_idx=None, - verbose=False, - ): - def _run_test(m, remained_onnx_input_idx, flatten=True): - return run_model_test( - self, - m, - input_args=input_args, - input_kwargs=input_kwargs, - rtol=rtol, - atol=atol, - do_constant_folding=do_constant_folding, - dynamic_axes=dynamic_axes, - additional_test_inputs=additional_test_inputs, - input_names=input_names, - output_names=output_names, - fixed_batch_size=fixed_batch_size, - training=training, - remained_onnx_input_idx=remained_onnx_input_idx, - flatten=flatten, - verbose=verbose, - ) - - if isinstance(remained_onnx_input_idx, dict): - scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"] - tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"] - else: - scripting_remained_onnx_input_idx = remained_onnx_input_idx - tracing_remained_onnx_input_idx = remained_onnx_input_idx - - is_script = isinstance( - model, (torch.jit.ScriptModule, torch.jit.ScriptFunction) - ) - - if self.is_script_test_enabled: - script_model = model if is_script else torch.jit.script(model) - _run_test(script_model, scripting_remained_onnx_input_idx, flatten=False) - - if not is_script: - _run_test(model, tracing_remained_onnx_input_idx) - +def _parameterized_class_attrs_and_values(): + attrs = ("opset_version", "is_script", "keep_initializers_as_inputs") + input_values = [] + input_values.extend(itertools.product((7, 8), (True, False), (True,))) + # Valid opset versions are defined in torch/onnx/_constants.py. + # Versions are intentionally set statically, to not be affected by elsewhere changes. + input_values.extend(itertools.product(range(9, 17), (True, False), (True, False))) + return {"attrs": attrs, "input_values": input_values} + + +def _parametrize_rnn_args(arg_name): + options = { + "layers": {1: "unilayer", 3: "trilayer"}, + "bidirectional": {True: "bidirectional", False: "forward"}, + "initial_state": {True: "with_initial_state", False: "no_initial_state"}, + "packed_sequence": { + 0: "without_sequence_lengths", + 1: "with_variable_length_sequences", + 2: "with_batch_first_sequence_lengths", + }, + "dropout": {0.2: "with_dropout", 0.0: "without_dropout"}, + } + + return { + "arg_str": arg_name, + "arg_values": options[arg_name].keys(), + "name_fn": lambda val: options[arg_name][val], + } + + +@parameterized.parameterized_class( + **_parameterized_class_attrs_and_values(), + class_name_func=onnx_test_common.parameterize_class_name, +) +@common_utils.instantiate_parametrized_tests +class TestONNXRuntime(onnx_test_common._TestONNXRuntime): def test_fuse_conv_bn1d(self): class Fuse(torch.nn.Module): def __init__(self): @@ -387,8 +309,6 @@ def run_word_language_model(self, model_name): self.run_test(model, (x, model.hidden)) def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor: - import os - from PIL import Image from torchvision import transforms @@ -459,52 +379,6 @@ def test_heatmaps_to_keypoints(self): assert torch.all(out2[0].eq(out_trace2[0])) assert torch.all(out2[1].eq(out_trace2[1])) - @unittest.skip( - "Unstable loading pretrained quantized mobilenet v3: https://github.com/pytorch/vision/issues/5303" - ) - @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() - def test_mobilenet_v3_quant(self): - model = torchvision.models.quantization.mobilenet_v3_large( - pretrained=True, quantize=True - ) - from PIL import Image - from torchvision import transforms - - data_dir = os.path.join(os.path.dirname(__file__), "assets") - path = os.path.join(data_dir, "grace_hopper_517x606.jpg") - input_image = Image.open(path) - # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/ - preprocess = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] - ) - input_tensor = preprocess(input_image).unsqueeze(0) - - # Due to precision error from quantization, check only that the top prediction matches. - class TopPredictor(torch.nn.Module): - def __init__(self, mobilenet): - super().__init__() - self.mobilenet = mobilenet - - def forward(self, x): - x = self.mobilenet(x) - _, topk_catid = torch.topk(x[0], 1) - return topk_catid - - # Currently, we need convert the model to ScriptModule before export. - # The reason is that PackedParams contains int (not tensor). - # Then it fails when the exporter calls _trace_and_get_graph_from_model(). - # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858 - model = torch.jit.trace(TopPredictor(model), input_tensor) - self.run_test(model, (input_tensor,)) - def test_word_language_model_RNN_TANH(self): self.run_word_language_model("RNN_TANH") @@ -684,7 +558,10 @@ def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]): def test_mixed_optional_default_none(self): class Model(torch.nn.Module): def forward( - self, x, y: Optional[Tensor] = None, z: Optional[Tensor] = None + self, + x, + y: Optional[Tensor] = None, + z: Optional[Tensor] = None, ): if y is not None: return x + y @@ -729,6 +606,7 @@ def forward( self.run_test(model, (x, y, None)) self.run_test(model, (x, None, z)) + @skipTraceTest() # tracing is verified with different set of inputs. See above. @skipIfUnsupportedMinOpsetVersion(15) def test_mixed_optional_default_tensor_script(self): class Model(torch.nn.Module): @@ -823,6 +701,7 @@ def forward( with self.assertRaisesRegex(ValueError, "got too many positional inputs"): self.run_test(model, (x, y)) + @skipTraceTest() # tracing is verified with different set of inputs. See above. @skipIfUnsupportedMinOpsetVersion(15) def test_all_optional_default_tensor_script(self): class Model(torch.nn.Module): @@ -905,6 +784,7 @@ def forward( y1 = torch.randn(2, 3) self.run_test(Model(), (x, (None, y1))) + @skipTraceTest() # tracing is verified with different set of inputs. See above. @skipIfUnsupportedMinOpsetVersion(15) def test_tuple_of_optional_default_tensor_script(self): class Model(torch.nn.Module): @@ -1122,6 +1002,11 @@ def test_hardshrink(self): x = torch.tensor(-0.5).to(dtype=torch.float32) self.run_test(model, x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_hardshrink_dtype(self): + x = torch.rand(3, 3).to(dtype=torch.float64) + self.run_test(torch.nn.Hardshrink(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_softshrink(self): model = torch.nn.Softshrink() @@ -1135,6 +1020,11 @@ def test_softshrink(self): x = torch.tensor(-0.5).to(dtype=torch.float32) self.run_test(model, x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_softshrink_dtype(self): + x = torch.rand(3, 3).to(dtype=torch.float64) + self.run_test(torch.nn.Softshrink(), x) + def test_clamp(self): class ClampModel(torch.nn.Module): def forward(self, x): @@ -1294,6 +1184,14 @@ def forward(self, input1, input2, input3): self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) + def test_numpy_T(self): + class NumpyTranspose(torch.nn.Module): + def forward(self, x): + return x.T + + self.run_test(NumpyTranspose(), torch.randn(4, 7)) + self.run_test(NumpyTranspose(), torch.tensor(-42.0)) + # Conversion of Transpose depends on input shape to be known. # The following test only works when onnx shape inference is enabled. def test_transpose_infer_shape(self): @@ -2882,6 +2780,43 @@ def forward(self, x): def test_interpolate_downsample(self): self._interpolate_tests(False) + @skipIfUnsupportedMinOpsetVersion(11) + def test_interpolate_half_pixel(self): + # testing whether it uses "half_pixel" or "pytorch_half_pixel" + # see https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize + + class MyModel(torch.nn.Module): + def __init__(self, mode, size): + super().__init__() + self.mode = mode + self.size = size + + def forward(self, x): + return torch.nn.functional.interpolate( + x, mode=self.mode, size=self.size + ) + + modes = ["linear", "bicubic"] + x = [ + torch.randn(1, 2, 6, requires_grad=True), + torch.randn(1, 2, 4, 6, requires_grad=True), + torch.randn(1, 2, 4, 4, 6, requires_grad=True), + ] + for mode in modes: + for xi in x: + mode_i = mode + if mode == "bicubic" and xi.dim() != 4: + continue + elif mode == "linear": + if xi.dim() == 4: + mode_i = "bilinear" + elif xi.dim() == 5: + mode_i = "trilinear" + for i in range(xi.dim() - 2): + size = list(xi.shape[2:]) + size[i] = 1 + self.run_test(MyModel(mode_i, size), xi) + @skipIfUnsupportedMinOpsetVersion(11) def test_interpolate_no_shape(self): class MyModel(torch.jit.ScriptModule): @@ -3117,6 +3052,18 @@ def forward(self, x, y): self.run_test(DivModule(), (x, y)) self.run_test(PowModule(), (x, z)) + def test_mul_bool(self): + class MyModel(torch.nn.Module): + def forward(self, x, y): + return torch.mul(x, y) + + x_t = torch.tensor([True, False, True, False]) + y_t = torch.tensor([True, True, False, False]) + z_t = torch.tensor([1.0, 2.0, 3.0, 0.0]) + self.run_test(MyModel(), (x_t, y_t)) + self.run_test(MyModel(), (x_t, z_t)) + self.run_test(MyModel(), (z_t, y_t)) + # fmod was added in version 10 @skipIfUnsupportedMinOpsetVersion(10) @skipIfUnsupportedMaxOpsetVersion(13) @@ -3639,10 +3586,13 @@ def symbolic_python_op( elif name == "MyRelu": return g.op("Relu", args[0], outputs=n.outputsSize()) else: - return _unimplemented("prim::PythonOp", "unknown node kind: " + name) + # TODO(justinchuby): Remove reference to internal names in symbolic_helper + return torch.onnx.symbolic_helper._unimplemented( + "prim::PythonOp", "unknown node kind: " + name + ) - register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1) - self.addCleanup(unregister_custom_op_symbolic, "prim::PythonOp", 1) + torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1) + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1) class MyClipModule(torch.nn.Module): def forward(self, x, min): @@ -4481,9 +4431,13 @@ def make_model(layers, packed_sequence): ) if packed_sequence == 1: - model = RnnModelWithPackedSequence(model, False) + model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence( + model, False + ) if packed_sequence == 2: - model = RnnModelWithPackedSequence(model, True) + model = rnn_model_with_packed_sequence.RnnModelWithPackedSequence( + model, True + ) return model def make_input(batch_size, layers, packed_sequence): @@ -6474,6 +6428,15 @@ def forward(self, x): x = torch.randn(3, 4) self.run_test(SortModel(), x) + @skipIfUnsupportedMinOpsetVersion(11) + def test_argsort(self): + class ArgSortModel(torch.nn.Module): + def forward(self, x): + return torch.argsort(x, dim=1, descending=False) + + x = torch.randn(3, 4) + self.run_test(ArgSortModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_masked_fill(self): class MaskedFillModel(torch.nn.Module): @@ -8169,16 +8132,16 @@ def forward(self, input, target): @skipIfUnsupportedMinOpsetVersion(9) def test_kldiv_loss(self): - x = torch.randn(5) - y = torch.randn(5) + x = torch.rand(5).log() + y = torch.rand(5) self._kldiv_loss(x, y) - x = torch.randn(2, 3, 5) - y = torch.randn(2, 3, 5) + x = torch.rand(2, 3, 5).log() + y = torch.rand(2, 3, 5) self._kldiv_loss(x, y) - x = torch.randn(2, 3, 5, 7) - y = torch.randn(2, 3, 5, 7) + x = torch.rand(2, 3, 5, 7).log() + y = torch.rand(2, 3, 5, 7) self._kldiv_loss(x, y) def _kldiv_loss(self, x, y): @@ -8188,7 +8151,7 @@ def __init__(self): self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True) def forward(self, input, target): - return self.loss(input, target) + return self.loss(input, target.log()) self.run_test(KLDivLossNone(), input_args=(x, y)) @@ -8208,7 +8171,7 @@ def __init__(self): self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True) def forward(self, input, target): - return self.loss(input, target) + return self.loss(input, target.log()) self.run_test(KLDivLossSum(), input_args=(x, y)) @@ -8230,7 +8193,7 @@ def __init__(self): ) def forward(self, input, target): - return self.loss(input, target) + return self.loss(input, target.log()) self.run_test(KLDivLossMiniBatchMean(), input_args=(x, y)) @@ -9069,25 +9032,6 @@ def forward(self, x, y, cond): dynamic_axes={"output_1": [1]}, ) - def test_onnx_proto_checker(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return 2 * x - - x = torch.randn(1, 2, 3, requires_grad=True) - f = io.BytesIO() - torch.onnx._export(Model(), x, f) - model = onnx.load(f) - model.ir_version = 0 - - def check_proto(): - torch._C._check_onnx_proto(model.SerializeToString()) - - self.assertRaises(RuntimeError, check_proto) - @skipScriptTest(min_opset_version=11) # dynamic split support addded in 11 def test_split_tensor_scalar(self): class SplitModel(torch.nn.Module): @@ -9199,6 +9143,7 @@ def _elman_rnn_test( initial_state, packed_sequence, dropout, + **extra_kwargs, ): class ElmanWithStateModel(torch.nn.Module): def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first): @@ -9215,7 +9160,7 @@ def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first): batch_first=batch_first, ) - def forward(self, input: PackedSequence, hx=None): + def forward(self, input: rnn_utils.PackedSequence, hx=None): return self.inner_model(input, hx) class ElmanWithoutStateModel(torch.nn.Module): @@ -9232,7 +9177,7 @@ def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first): batch_first=batch_first, ) - def forward(self, input: PackedSequence): + def forward(self, input: rnn_utils.PackedSequence): return self.inner_model(input) batch_first = packed_sequence == 2 @@ -9246,7 +9191,11 @@ def forward(self, input: PackedSequence): batch_first=batch_first, ) if packed_sequence: - model = RnnModelWithPackedSequenceWithState(model, batch_first) + model = ( + rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( + model, batch_first + ) + ) else: model = ElmanWithStateModel( layers=layers, @@ -9256,7 +9205,9 @@ def forward(self, input: PackedSequence): batch_first=batch_first, ) if packed_sequence: - model = RnnModelWithPackedSequenceWithoutState(model, batch_first) + model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( + model, batch_first + ) def make_input(batch_size): seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) @@ -9286,12 +9237,18 @@ def make_input(batch_size): self.run_test(model, other_input) def _lstm_test( - self, layers, bidirectional, initial_state, packed_sequence, dropout + self, + layers, + bidirectional, + initial_state, + packed_sequence, + dropout, + **extra_kwargs, ): batch_first = packed_sequence == 2 if packed_sequence: - model = LstmFlatteningResultWithSeqLength( + model = lstm_flattening_result.LstmFlatteningResultWithSeqLength( RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, @@ -9300,11 +9257,17 @@ def _lstm_test( batch_first, ) if initial_state: - model = RnnModelWithPackedSequenceWithState(model, batch_first) + model = ( + rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( + model, batch_first + ) + ) else: - model = RnnModelWithPackedSequenceWithoutState(model, batch_first) + model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( + model, batch_first + ) else: - model = LstmFlatteningResultWithoutSeqLength( + model = lstm_flattening_result.LstmFlatteningResultWithoutSeqLength( RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers, @@ -9341,7 +9304,15 @@ def make_input(batch_size): other_input = make_input(RNN_BATCH_SIZE + 1) self.run_test(model, other_input) - def _gru_test(self, layers, bidirectional, initial_state, packed_sequence, dropout): + def _gru_test( + self, + layers, + bidirectional, + initial_state, + packed_sequence, + dropout, + **extra_kwargs, + ): class GRUWithStateModel(torch.nn.Module): def __init__(self, layers, bidirect, dropout, batch_first): super().__init__() @@ -9356,7 +9327,7 @@ def __init__(self, layers, bidirect, dropout, batch_first): batch_first=batch_first, ) - def forward(self, input: PackedSequence, hx): + def forward(self, input: rnn_utils.PackedSequence, hx): return self.inner_model(input, hx) class GRUWithoutStateModel(torch.nn.Module): @@ -9372,7 +9343,7 @@ def __init__(self, layers, bidirect, dropout, batch_first): batch_first=batch_first, ) - def forward(self, input: PackedSequence): + def forward(self, input: rnn_utils.PackedSequence): return self.inner_model(input) class GRUNoSeqLengthWithoutStateModel(torch.nn.Module): @@ -9417,7 +9388,11 @@ def forward(self, input, hx): dropout=dropout, batch_first=batch_first, ) - model = RnnModelWithPackedSequenceWithState(model, batch_first) + model = ( + rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithState( + model, batch_first + ) + ) else: model = GRUWithoutStateModel( layers=layers, @@ -9425,7 +9400,9 @@ def forward(self, input, hx): dropout=dropout, batch_first=batch_first, ) - model = RnnModelWithPackedSequenceWithoutState(model, batch_first) + model = rnn_model_with_packed_sequence.RnnModelWithPackedSequenceWithoutState( + model, batch_first + ) else: if initial_state: model = GRUNoSeqLengthWithStateModel( @@ -9518,7 +9495,9 @@ def forward(self, input): self.run_test(FakeQuantizePerChannelModel(), (x)) @skipIfUnsupportedMinOpsetVersion(13) - @skipScriptTest() # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear + # RuntimeError: Can't redefine method: + # forward on class: __torch__.torch.nn.modules.linear.Linear + @skipScriptTest() def test_fake_quantize_activation(self): from torch import quantization @@ -9939,140 +9918,6 @@ def forward(self, x): ) self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL) - def test_script_custom_class_error(self): - class BoxCoder: - def __init__(self, bbox_xform_clip: float) -> None: - self.bbox_xform_clip = bbox_xform_clip - - def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: - boxes = torch.cat(boxes, dim=0) - pred_ctr_x = ( - torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) - * boxes[:, 2] - ) - return pred_ctr_x - - class MyModule(torch.nn.Module): - __annotations__ = { - "box_coder": BoxCoder, - } - - def __init__(self): - super().__init__() - self.box_coder = BoxCoder(1.4) - - def forward(self, box_regression: Tensor, proposals: List[Tensor]): - return self.box_coder.decode(box_regression, proposals) - - model = torch.jit.script(MyModule()) - box_regression = torch.randn([4, 4]) - proposal = [torch.randn(2, 4), torch.randn(2, 4)] - - with self.assertRaises(RuntimeError) as cm: - onnx_model = io.BytesIO() - torch.onnx.export( - model, - (box_regression, proposal), - onnx_model, - opset_version=self.opset_version, - ) - - def test_initializer_sequence(self): - class MyModule(torch.nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(hidden_size, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - test_model = MyModule(3, 4, 10) - state_dict_list = [k for (k, v) in test_model.state_dict().items()] - named_params_list = [k for (k, v) in test_model.named_parameters()] - - x = torch.randn(32, 3) - f = io.BytesIO() - torch.onnx._export(test_model, (x,), f, do_constant_folding=False) - loaded_model = onnx.load_from_string(f.getvalue()) - - actual_list = [p.name for p in loaded_model.graph.initializer] - assert actual_list == state_dict_list, ( - "Initializers' sequence is not as same as state_dict(). Expected: (" - + ", ".join(state_dict_list) - + "). Actual:(" - + ", ".join(actual_list) - + ")." - ) - assert actual_list == named_params_list, ( - "Initializers' sequence is not as same as named_parameters(). Expected: (" - + ", ".join(named_params_list) - + "). Actual:(" - + ", ".join(actual_list) - + ")." - ) - - def test_initializer_sequence_script_model(self): - def list_is_expected(short_list, long_list) -> bool: - if len(short_list) > len(long_list): - return False - - for i in range(len(short_list)): - if short_list[i] not in long_list[i]: - return False - - return True - - def loop(x, y): - for i in range(int(y)): - x = x + i - return x - - class MyModule(torch.nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(hidden_size, num_classes) - - def forward(self, x, y): - x = loop(x, y) - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - test_model = torch.jit.script(MyModule(3, 4, 10)) - state_dict_list = [k for (k, v) in test_model.state_dict().items()] - named_params_list = [k for (k, v) in test_model.named_parameters()] - - x = torch.ones(2, 3, dtype=torch.float) - y = torch.tensor(5, dtype=torch.long) - f = io.BytesIO() - - torch.onnx.export(test_model, (x, y), f, do_constant_folding=False) - loaded_model = onnx.load_from_string(f.getvalue()) - - actual_list = [p.name for p in loaded_model.graph.initializer] - assert list_is_expected(state_dict_list, actual_list), ( - "ScriptModel - Initializers' sequence is not as same as state_dict(). Expected: (" - + ", ".join(state_dict_list) - + "). Actual:(" - + ", ".join(actual_list) - + ")." - ) - assert list_is_expected(named_params_list, actual_list), ( - "ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" - + ", ".join(named_params_list) - + "). Actual:(" - + ", ".join(actual_list) - + ")." - ) - @skipIfUnsupportedMinOpsetVersion(11) def test_nms(self): num_boxes = 100 @@ -10082,10 +9927,13 @@ def test_nms(self): class Module(torch.nn.Module): def forward(self, boxes, scores): - return ops.nms(boxes, scores, 0.5) + return torchvision.ops.nms(boxes, scores, 0.5) self.run_test(Module(), (boxes, scores)) + @unittest.skip( + "Broken in recent TorchVision, see https://github.com/pytorch/pytorch/issues/81121" + ) @skipIfUnsupportedMinOpsetVersion(11) # TODO: Fails with vision 0.13. See #77671 def test_batched_nms(self): @@ -10097,7 +9945,7 @@ def test_batched_nms(self): class Module(torch.nn.Module): def forward(self, boxes, scores, idxs): - return ops.batched_nms(boxes, scores, idxs, 0.5) + return torchvision.ops.batched_nms(boxes, scores, idxs, 0.5) self.run_test(Module(), (boxes, scores, idxs)) @@ -10113,7 +9961,7 @@ def test_clip_boxes_to_image(self): class Module(torch.nn.Module): def forward(self, boxes, size): shape = (size.shape[0], size.shape[1]) - return ops.boxes.clip_boxes_to_image(boxes, shape) + return torchvision.ops.boxes.clip_boxes_to_image(boxes, shape) self.run_test( Module(), @@ -10123,44 +9971,53 @@ def forward(self, boxes, size): additional_test_inputs=[(boxes, size), (boxes, size_2)], ) + @unittest.skip( + "Broken in recent TorchVision, see https://github.com/pytorch/pytorch/issues/81121" + ) @skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch @skipIfUnsupportedMinOpsetVersion(11) def test_roi_align(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) - model = ops.RoIAlign((5, 5), 1.0, 2) + model = torchvision.ops.RoIAlign((5, 5), 1.0, 2) self.run_test(model, (x, single_roi)) + @unittest.skip( + "Broken in recent TorchVision, see https://github.com/pytorch/pytorch/issues/81121" + ) @skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch @skipIfUnsupportedMinOpsetVersion(11) def test_roi_align_aligned(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) - model1 = ops.RoIAlign((5, 5), 1.0, 2, aligned=True) + model1 = torchvision.ops.RoIAlign((5, 5), 1.0, 2, aligned=True) self.run_test(model1, (x, single_roi)) x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) - model2 = ops.RoIAlign((5, 5), 0.5, 3, aligned=True) + model2 = torchvision.ops.RoIAlign((5, 5), 0.5, 3, aligned=True) self.run_test(model2, (x, single_roi)) x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) - model3 = ops.RoIAlign((5, 5), 1.8, 2, aligned=True) + model3 = torchvision.ops.RoIAlign((5, 5), 1.8, 2, aligned=True) self.run_test(model3, (x, single_roi)) x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) - model4 = ops.RoIAlign((2, 2), 2.5, 0, aligned=True) + model4 = torchvision.ops.RoIAlign((2, 2), 2.5, 0, aligned=True) self.run_test(model4, (x, single_roi)) + @unittest.skip( + "Broken in recent TorchVision, see https://github.com/pytorch/pytorch/issues/81121" + ) @skipIfUnsupportedMinOpsetVersion(11) def test_roi_pool(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) pool_h = 5 pool_w = 5 - model = ops.RoIPool((pool_h, pool_w), 2.0) + model = torchvision.ops.RoIPool((pool_h, pool_w), 2.0) self.run_test(model, (x, rois)) @skipIfUnsupportedMinOpsetVersion(11) @@ -10217,15 +10074,13 @@ def get_features(self, images): @skipIfUnsupportedMinOpsetVersion(11) @skipScriptTest() def test_rpn(self): - set_rng_seed(0) - class RPNModule(torch.nn.Module): def __init__(self): super().__init__() self.rpn = _init_test_rpn() def forward(self, images, features: Dict[str, Tensor]): - images_m = ImageList( + images_m = torchvision.models.detection.image_list.ImageList( images, [(i.shape[-1], i.shape[-2]) for i in images] ) return self.rpn(images_m, features) @@ -10261,7 +10116,9 @@ def test_multi_scale_roi_align(self): class TransformModule(torch.nn.Module): def __init__(self): super().__init__() - self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2) + self.model = torchvision.ops.MultiScaleRoIAlign( + ["feat1", "feat2"], 3, 2 + ) self.image_sizes = [(512, 512)] def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor: @@ -10831,7 +10688,7 @@ def test_index_put_if_5(self): @torch.jit.script def check_init( input_data: Tensor, hidden_size: int, prev_state: Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[Tensor, Tensor]: batch_size = input_data.size(0) spatial_size_0 = input_data.size(2) spatial_size_1 = input_data.size(3) @@ -11816,31 +11673,6 @@ def forward(self, x): x = torch.randn(10, 5) self.run_test(M(), (x,)) - def test_onnx_checker_invalid_graph(self): - class CustomAddModule(torch.nn.Module): - def forward(self, x, y): - return torch.add(x, y) - - def symbolic_custom_invalid_add(g, input, other, alpha=None): - return g.op("Add", input, other, invalid_attr_i=1) - - register_custom_op_symbolic("::add", symbolic_custom_invalid_add, 1) - - x = torch.randn(2, 3, 4) - y = torch.randn(2, 3, 4) - - test_model = CustomAddModule() - f = io.BytesIO() - - try: - with self.assertRaises(torch.onnx.errors.CheckerError): - torch.onnx.export(test_model, (x, y), f) - finally: - unregister_custom_op_symbolic("::add", 1) - - self.assertTrue(f.getvalue(), "ONNX graph was not exported.") - loaded_model = onnx.load_from_string(f.getvalue()) - def test_tuple_output_from_if_with_raised_exception(self): class M(torch.nn.Module): def __init__(self): @@ -11855,36 +11687,6 @@ def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]: x = torch.zeros(1) self.run_test(torch.jit.script(M()), (x,)) - def test_shape_value_map(self): - class RSoftMax(torch.nn.Module): - def __init__(self, radix, cardinality): - super().__init__() - self.radix = radix - self.cardinality = cardinality - - def forward(self, x): - batch = x.size(0) - x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) - x = F.softmax(x, dim=1) - x = x.reshape(batch, -1) - return x - - radix = 2 - cardinality = 1 - x = torch.randn(10, 1, 128, 1) - f = io.BytesIO() - torch.onnx.export( - RSoftMax(radix, cardinality), - (x,), - f, - input_names=["x"], - dynamic_axes={"x": [0]}, - ) - loaded_model = onnx.load_from_string(f.getvalue()) - self.assertEqual( - loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128 - ) - # NOTE: For quantization tests, choose scale and zero point carefully # such that inputs and outputs do not always overflow/underflow. # Otherwise test results could be inaccurate. @@ -11969,8 +11771,78 @@ def forward(self, input): x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8) self.run_test(FlattenModel(), x) + @unittest.skip( + "ONNX Runtime 1.11 does not support quantized cat. Enable after ORT 1.12 is enabled in CI." + ) + @skipIfUnsupportedMinOpsetVersion(10) + @skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: + def test_quantized_cat_when_concatinating_the_same_tensor(self): + class QuantizedSelfConcatenationModel(torch.nn.Module): + def forward(self, x): + return torch.nn.quantized.QFunctional().cat((x, x), dim=1) + + q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8) + self.run_test(QuantizedSelfConcatenationModel(), q_input) + + @common_utils.parametrize( + "x, y", + [ + common_utils.subtest( + [ + torch.quantize_per_tensor( + torch.ones(2, 3), 0.26, 128, torch.quint8 + ), + torch.quantize_per_tensor( + torch.zeros(1, 3), 0.26, 128, torch.quint8 + ), + ], + name="different_shape", + ), + common_utils.subtest( + [ + torch.quantize_per_tensor( + torch.ones(2, 3), 0.26, 128, torch.quint8 + ), + torch.quantize_per_tensor(torch.ones(2, 3), 42, 1, torch.quint8), + ], + name="different_scale", + ), + common_utils.subtest( + [ + torch.quantize_per_tensor( + torch.ones(2, 3), 0.26, 128, torch.quint8 + ), + torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 63, torch.quint8), + ], + name="different_zero_point", + ), + common_utils.subtest( + [ + torch.quantize_per_tensor( + torch.ones(2, 3), 0.26, 128, torch.quint8 + ), + torch.quantize_per_tensor(torch.ones(2, 3), 0.1, 63, torch.quint8), + ], + name="different_zero_point_and_scale", + ), + ], + ) + @unittest.skip( + "ONNX Runtime 1.11 does not support quantized cat. Enable after ORT 1.12 is enabled in CI." + ) @skipIfUnsupportedMinOpsetVersion(10) @skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: + def test_quantized_cat(self, x: torch.Tensor, y: torch.Tensor): + class QuantizedConcatenationModel(torch.nn.Module): + def forward(self, x, y): + return torch.nn.quantized.QFunctional().cat((x, y), dim=0) + + self.run_test(QuantizedConcatenationModel(), (x, y)) + + @skipIfUnsupportedMinOpsetVersion(10) + # torch.jit.frontend.FrontendError: + # Cannot instantiate class 'QFunctional' in a script function + @skipScriptTest() def test_quantized_arithmetic_qfunctional(self): x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) @@ -12264,7 +12136,21 @@ def forward(self, x): self.run_test(Module(True), x, rtol=1e-3, atol=1e-6) @skipIfUnsupportedMinOpsetVersion(16) - def test_grid_sample(self): + @common_utils.parametrize( + "mode", + ("bilinear", "nearest", "bicubic"), + ) + @common_utils.parametrize( + "padding_mode", + ("zeros", "border", "reflection"), + ) + @common_utils.parametrize( + "align_corners", + (True, False), + name_fn=lambda align_corners: str(align_corners), + ) + def test_grid_sample(self, mode, padding_mode, align_corners): + n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4 class GridSampleModule(torch.nn.Module): @@ -12281,23 +12167,41 @@ def forward(self, input, grid): input, grid, self.mode, self.padding_mode, self.align_corners ) - for mode, padding_mode, align_corners in itertools.product( - ("bilinear", "nearest", "bicubic"), - ("zeros", "border", "reflection"), - (True, False), - ): - atol_rtol = {} - if (mode, padding_mode) == ("bicubic", "border"): - if align_corners: - atol_rtol.update({"atol": 0.3, "rtol": 0.4}) - else: - atol_rtol.update({"atol": 0.02, "rtol": 0.02}) - input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2) - self.run_test( - GridSampleModule(mode, padding_mode, align_corners), - (input, grid), - **atol_rtol, - ) + atol_rtol = {} + if (mode, padding_mode) == ("bicubic", "border"): + if align_corners: + atol_rtol.update({"atol": 0.3, "rtol": 0.4}) + else: + atol_rtol.update({"atol": 0.02, "rtol": 0.02}) + input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2) + self.run_test( + GridSampleModule(mode, padding_mode, align_corners), + (input, grid), + **atol_rtol, + ) + + # TODO: The fix of OptionalHasElement is still in master branch, not in release + # Enable the test after it's been released. + @skipForAllOpsetVersions() + @skipTraceTest() + @skipIfUnsupportedMinOpsetVersion(16) + def test_uninitialized_optional(self): + class Module(torch.nn.Module): + def forward(self, y: Optional[Tensor]) -> Optional[Tensor]: + if y is not None: + if y.shape[1] < 5: + if y.size(0) == 1: + y = y + 4 + else: + return y + return y + + self.run_test( + Module(), + torch.ones((3, 4), dtype=torch.int), + dynamic_axes={"y": {0: "y0", 1: "y1"}}, + input_names=["y"], + ) @skipIfUnsupportedMinOpsetVersion(9) def test_device_eq(self): @@ -12341,32 +12245,6 @@ def forward(self, x): self.run_test(LerpModel(), torch.rand(5, 4, 3)) - -def make_test( - name, - base, - layer, - bidirectional, - initial_state, - variable_length, - dropout, - script_test_min_opset_version, - **extra_kwargs, -): - test_name = str( - "_".join( - [ - "test", - name, - layer[1], - bidirectional[1], - initial_state[1], - variable_length[1], - dropout[1], - ] - ) - ) - # Cannot export with older opsets because of "ConstantFill" op # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime # There are still some issues prevent us from enabling script test for these scenarios: @@ -12375,140 +12253,34 @@ def make_test( # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382 # Operator aten::_pack_padded_sequence is not supported by exporter yet. # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384 + # test_elman_*: + # Compiling in script mode fails with errors like: + # torch.jit.frontend.UnsupportedNodeError: annotated assignments + # without assigned value aren't supported + # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 + # test_lstm_*: + # Compiling in script mode fails with errors like: + # RuntimeError: Arguments for call are not valid. + # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 @skipScriptTest() @skipIfUnsupportedMinOpsetVersion(9) - def f(self): - self.is_script_test_enabled = ( - self.opset_version >= script_test_min_opset_version - ) - self._dispatch_rnn_test( - base, - layers=layer[0], - bidirectional=bidirectional[0], - initial_state=initial_state[0], - packed_sequence=variable_length[0], - dropout=dropout[0], - **extra_kwargs, - ) - - f.__name__ = test_name - setattr(_TestONNXRuntime, f.__name__, f) - - -def setup_rnn_tests(): - layers_opts = [(1, "unilayer"), (3, "trilayer")] - bidirectional_opts = [(False, "forward"), (True, "bidirectional")] - initial_state_opts = [(True, "with_initial_state"), (False, "no_initial_state")] - variable_length_opts = [ - (0, "without_sequence_lengths"), - (1, "with_variable_length_sequences"), - (2, "with_batch_first_sequence_lengths"), - ] - dropout_opts = [(0.2, "with_dropout"), (0.0, "without_dropout")] - test_count = 0 - for ( - layer, - bidirectional, - initial_state, - variable_length, - dropout, - ) in itertools.product( - layers_opts, - bidirectional_opts, - initial_state_opts, - variable_length_opts, - dropout_opts, - ): - - for base, name, extra_kwargs in ( - ("elman", "elman_relu", {"nonlinearity": "relu"}), - ("elman", "elman_tanh", {"nonlinearity": "tanh"}), - ("lstm", "lstm", {}), - ("gru", "gru", {}), - ): - # Need Add between list of tensors - script_test_min_opset_version = 11 - - if ( # compiling in script mode fails with errors like: - # torch.jit.frontend.UnsupportedNodeError: annotated assignments - # without assigned value aren't supported - # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 - base == "elman" - or - # compiling in script mode fails with errors like: - # RuntimeError: Arguments for call are not valid. - # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 - base == "lstm" - ): - script_test_min_opset_version = float("inf") - make_test( - name, - base, - layer, - bidirectional, - initial_state, - variable_length, - dropout, - script_test_min_opset_version, - **extra_kwargs, - ) - test_count += 1 - - # sanity check that a representative example does exist - _TestONNXRuntime.test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout - - # make sure no one accidentally disables all the tests without - # noticing - if test_count != 192: - raise ValueError(f"Expected 192 tests but found {test_count}") - - -setup_rnn_tests() - - -def MakeTestCase(opset_version: int, keep_initializers_as_inputs: bool = True) -> type: - name = f"TestONNXRuntime_opset{opset_version}" - if not keep_initializers_as_inputs: - name += "_IRv4" - return type( - str(name), - (TestCase,), - dict( - _TestONNXRuntime.__dict__, - opset_version=opset_version, - keep_initializers_as_inputs=keep_initializers_as_inputs, - ), + @common_utils.parametrize( + "name, nonlinearity", + [ + ("elman", "relu"), + ("elman", "tanh"), + ("lstm", None), + ("gru", None), + ], ) - - -TestONNXRuntime_opset7 = MakeTestCase(7) - -TestONNXRuntime_opset8 = MakeTestCase(8) - -TestONNXRuntime_opset9 = MakeTestCase(9) - -TestONNXRuntime_opset9_IRv4 = MakeTestCase(9, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset10 = MakeTestCase(10) - -TestONNXRuntime_opset10_IRv4 = MakeTestCase(10, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset11 = MakeTestCase(11) - -TestONNXRuntime_opset11_IRv4 = MakeTestCase(11, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset12 = MakeTestCase(12) - -TestONNXRuntime_opset12_IRv4 = MakeTestCase(12, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset13 = MakeTestCase(13, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset14 = MakeTestCase(14, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset15 = MakeTestCase(15, keep_initializers_as_inputs=False) - -TestONNXRuntime_opset16 = MakeTestCase(16, keep_initializers_as_inputs=False) + @common_utils.parametrize(**_parametrize_rnn_args("layers")) + @common_utils.parametrize(**_parametrize_rnn_args("bidirectional")) + @common_utils.parametrize(**_parametrize_rnn_args("initial_state")) + @common_utils.parametrize(**_parametrize_rnn_args("packed_sequence")) + @common_utils.parametrize(**_parametrize_rnn_args("dropout")) + def test_rnn(self, *args, **kwargs): + self._dispatch_rnn_test(*args, **kwargs) if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py index b5d41bbf71fc3..3832a110fc918 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py @@ -3,23 +3,21 @@ import unittest import onnxruntime # noqa: F401 -from test_pytorch_common import ( + +import torch +from pytorch_test_common import ( skipIfNoBFloat16Cuda, skipIfNoCuda, skipIfUnsupportedMinOpsetVersion, skipScriptTest, - TestCase, ) - -# TODO(justinchuby): Remove reference to other unit tests. from test_pytorch_onnx_onnxruntime import TestONNXRuntime - -import torch from torch.cuda.amp import autocast from torch.onnx._globals import GLOBALS +from torch.testing._internal import common_utils -class TestONNXRuntime_cuda(TestCase): +class TestONNXRuntime_cuda(common_utils.TestCase): opset_version = GLOBALS.export_onnx_opset_version keep_initializers_as_inputs = True @@ -151,5 +149,4 @@ def forward(self, x, y): TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test if __name__ == "__main__": - # TODO: convert this to use common_utils.run_tests() - unittest.main(TestONNXRuntime_cuda()) + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index 6b293d65697ad..516ac2cb6cd72 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -1,11 +1,11 @@ # Owner(s): ["module: onnx"] import numpy as np -from test_pytorch_common import run_tests, skipIfUnsupportedMinOpsetVersion, TestCase import torch -from torch.onnx import _constants -from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version +from pytorch_test_common import skipIfUnsupportedMinOpsetVersion +from torch.onnx import _constants, symbolic_helper +from torch.testing._internal import common_utils def expect_tensor(scalar_type, shape=None): @@ -19,12 +19,11 @@ def verify(actual_type): return verify -class TestONNXShapeInference(TestCase): - def __init__(self, *args, **kwargs): - TestCase.__init__(self, *args, **kwargs) +class TestONNXShapeInference(common_utils.TestCase): + def setUp(self): self.opset_version = _constants.onnx_main_opset - _set_onnx_shape_inference(True) - _set_opset_version(self.opset_version) + symbolic_helper._set_onnx_shape_inference(True) + symbolic_helper._set_opset_version(self.opset_version) def run_test(self, g, n, type_assertion_funcs): if not isinstance(type_assertion_funcs, list): @@ -271,4 +270,4 @@ def test_resize_after_concat(self): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 7b8ac8d3f88e7..38eab4f09f03b 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -4,24 +4,21 @@ import io import onnx + +import torch +import torch.onnx +import torch.utils.cpp_extension import torchvision from autograd_helper import CustomFunction as CustomFunction2 -from test_pytorch_common import ( - TestCase, - run_tests, +from pytorch_test_common import ( skipIfNoCuda, skipIfUnsupportedMaxOpsetVersion, skipIfUnsupportedMinOpsetVersion, ) -from verify import verify - -import torch -import torch.onnx -import torch.utils.cpp_extension from torch.onnx import ( OperatorExportTypes, - TrainingMode, register_custom_op_symbolic, + TrainingMode, unregister_custom_op_symbolic, utils, ) @@ -32,9 +29,12 @@ _unpack_list, parse_args, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack +from verify import verify -class _BaseTestCase(TestCase): +class _BaseTestCase(common_utils.TestCase): def setUp(self): super().setUp() torch.manual_seed(0) @@ -170,8 +170,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Transpose") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_reduceL2(self): class ReduceModule(torch.nn.Module): @@ -189,7 +188,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL2") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_reduceL1(self): class NormModule(torch.nn.Module): @@ -207,7 +206,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL1") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_slice(self): class NarrowModule(torch.nn.Module): @@ -226,8 +225,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Slice") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_slice_index_exceeds_dim(self): class SliceIndexExceedsDimModule(torch.nn.Module): @@ -249,8 +247,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Slice") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_slice_negative_index(self): class SliceNegativeIndexModule(torch.nn.Module): @@ -274,7 +271,6 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Slice") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") def test_constant_fold_gather(self): class GatherModule(torch.nn.Module): @@ -313,8 +309,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Unsqueeze") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_unsqueeze_multi_axies(self): class PReluModel(torch.nn.Module): @@ -336,8 +331,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Unsqueeze") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 4) + self.assertEqual(len(list(graph.nodes())), 5) def test_constant_fold_squeeze_without_axes(self): class SqueezeModule(torch.nn.Module): @@ -354,8 +348,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Squeeze") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 2) + self.assertEqual(len(list(graph.nodes())), 4) def test_constant_fold_squeeze_with_axes(self): class SqueezeAxesModule(torch.nn.Module): @@ -373,8 +366,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Squeeze") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_concat(self): class ConcatModule(torch.nn.Module): @@ -410,8 +402,7 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::Concat") self.assertNotEqual(node.kind(), "onnx::Cast") - self.assertNotEqual(node.kind(), "onnx::Constant") - self.assertEqual(len(list(graph.nodes())), 1) + self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_lstm(self): class GruNet(torch.nn.Module): @@ -640,6 +631,16 @@ def forward(self, x): self.assertNotEqual(node.kind(), "onnx::Shape") self.assertEqual(len(list(graph.nodes())), 1) + def test_constant_fold_upsample_scale_fold_as_constant(self): + # upsample scale is a constant, not a model parameter, + # therefore should not be added as initializer after constant folding. + model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + x = torch.randn(1, 32, 224, 224) + f = io.BytesIO() + torch.onnx.export(model, x, f) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertEqual(len(onnx_model.graph.initializer), 0) + def test_verbose(self): class MyModule(torch.nn.Module): def forward(self, input): @@ -1084,6 +1085,7 @@ def gelu(g, self, approximate): self.assertEqual(graph.graph.node[0].op_type, "Gelu") self.assertEqual(graph.opset_import[1].domain, "com.microsoft") + @skipIfNoLapack def test_custom_opsets_inverse(self): class CustomInverse(torch.nn.Module): def forward(self, x): @@ -1301,6 +1303,7 @@ def forward(self, x): "Graph parameter names does not match model parameters.", ) + @skipIfNoCaffe2 def test_modifying_params(self): class MyModel(torch.nn.Module): def __init__(self): @@ -1592,6 +1595,32 @@ def forward(self, input0, input1): self.assertEqual(graph.graph.node[3].op_type, "Gemm") self.assertEqual(graph.graph.node[4].op_type, "Identity") + def test_deduplicate_ignore_upsample_scale(self): + # upsample scale is a constant, not a model parameter, + # therefore should be ignored by shared weight deduplication. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.upsample_1 = torch.nn.Upsample(scale_factor=2) + self.upsample_2 = torch.nn.Upsample(scale_factor=2) + + def forward(self, x): + return self.upsample_1(x), self.upsample_2(x) + + f = io.BytesIO() + x = torch.randn(1, 32, 224, 224) + torch.onnx.export(Model(), x, f) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + # aten::upsample converts to onnx::resize + resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"] + self.assertEqual(len(resize_nodes), 2) + for resize_node in resize_nodes: + scale_node = [ + n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2] + ] + self.assertEqual(len(scale_node), 1) + self.assertEqual(scale_node[0].op_type, "Constant") + def test_bad_symbolic_registration(self): _onnx_opset_version = 9 @@ -1649,4 +1678,4 @@ class TestUtilityFuns_opset15(TestUtilityFuns_opset9): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/onnx/test_verify.py b/test/onnx/test_verify.py index a9a88ac614339..af8c29bbbe1f1 100644 --- a/test/onnx/test_verify.py +++ b/test/onnx/test_verify.py @@ -1,15 +1,14 @@ # Owner(s): ["module: onnx"] -from test_pytorch_common import TestCase, run_tests -from verify import verify - import caffe2.python.onnx.backend as backend import torch from torch.autograd import Function from torch.nn import Module, Parameter +from torch.testing._internal import common_utils +from verify import verify -class TestVerify(TestCase): +class TestVerify(common_utils.TestCase): maxDiff = None def assertVerifyExpectFail(self, *args, **kwargs): @@ -106,4 +105,4 @@ def forward(self, x): if __name__ == "__main__": - run_tests() + common_utils.run_tests() diff --git a/test/package/test_dependency_api.py b/test/package/test_dependency_api.py index 9f1a9c9899e8b..b8350ddf88242 100644 --- a/test/package/test_dependency_api.py +++ b/test/package/test_dependency_api.py @@ -6,6 +6,8 @@ from textwrap import dedent from unittest import skipIf +import torch.nn + from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter from torch.package.package_exporter import PackagingError from torch.testing._internal.common_utils import IS_WINDOWS, run_tests @@ -347,6 +349,29 @@ def test_repackage_mocked_module(self): with self.assertRaises(NotImplementedError): foo2.package_a.get_something() + def test_externing_c_extension(self): + """Externing c extensions modules should allow us to still access them especially those found in torch._C.""" + + buffer = BytesIO() + # The C extension module in question is F.gelu which comes from torch._C._nn + model = torch.nn.TransformerEncoderLayer( + d_model=64, + nhead=2, + dim_feedforward=64, + dropout=1.0, + batch_first=True, + activation="gelu", + norm_first=True, + ) + with PackageExporter(buffer) as e: + e.extern("torch.**") + e.intern("**") + + e.save_pickle("model", "model.pkl", model) + buffer.seek(0) + imp = PackageImporter(buffer) + imp.load_pickle("model", "model.pkl") + if __name__ == "__main__": run_tests() diff --git a/test/package/test_directory_reader.py b/test/package/test_directory_reader.py index 16d1b73d28884..d4bf4ae99057e 100644 --- a/test/package/test_directory_reader.py +++ b/test/package/test_directory_reader.py @@ -45,6 +45,7 @@ class DirectoryReaderTest(PackageTestCase): """Tests use of DirectoryReader as accessor for opened packages.""" @skipIfNoTorchVision + @skipIf(True, "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115") def test_loading_pickle(self): """ Test basic saving and loading of modules and pickles from a DirectoryReader. diff --git a/test/package/test_misc.py b/test/package/test_misc.py index 20e9b0ebbb114..c29602d8e360b 100644 --- a/test/package/test_misc.py +++ b/test/package/test_misc.py @@ -10,7 +10,7 @@ from torch.package import is_from_package, PackageExporter, PackageImporter from torch.package.package_exporter import PackagingError -from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests +from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests, skipIfTorchDynamo try: from .common import PackageTestCase @@ -263,6 +263,7 @@ def test_dunder_package_works_from_package(self): self.assertTrue(imported_mod.is_from_package()) self.assertFalse(mod.is_from_package()) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_std_lib_sys_hackery_checks(self): """ The standard library performs sys.module assignment hackery which diff --git a/test/package/test_model.py b/test/package/test_model.py index cc0c7d81bdcc8..05da0954114ad 100644 --- a/test/package/test_model.py +++ b/test/package/test_model.py @@ -23,6 +23,7 @@ from common import PackageTestCase +@skipIf(True, "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115") @skipIfNoTorchVision class ModelTest(PackageTestCase): """End-to-end tests packaging an entire model.""" diff --git a/test/profiler_utils_mock_events.json b/test/profiler_utils_mock_events.json new file mode 100644 index 0000000000000..00fcfccdfe300 --- /dev/null +++ b/test/profiler_utils_mock_events.json @@ -0,0 +1 @@ +[[{"_name": "aten::matmul", "_start_us": 1656454173440014, "_duration_us": 2254, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173440019, "_duration_us": 2246, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173442289, "_duration_us": 33, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173442291, "_duration_us": 30, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173442325, "_duration_us": 32, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173442326, "_duration_us": 30, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173442360, "_duration_us": 21, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173442361, "_duration_us": 19, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173442384, "_duration_us": 21, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173442385, "_duration_us": 20, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173444252, "_duration_us": 38, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173444282, "_duration_us": 4, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173444291, "_duration_us": 9, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173444296, "_duration_us": 1, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::copy_", "_start_us": 1656454173444305, "_duration_us": 45427, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489760, "_duration_us": 5, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489764, "_duration_us": 0, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489766, "_duration_us": 3, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489767, "_duration_us": 1, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::copy_", "_start_us": 1656454173489771, "_duration_us": 35, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489811, "_duration_us": 2, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489812, "_duration_us": 1, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489814, "_duration_us": 2, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489815, "_duration_us": 0, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::copy_", "_start_us": 1656454173489817, "_duration_us": 21, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489842, "_duration_us": 3, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489844, "_duration_us": 0, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489846, "_duration_us": 1, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489847, "_duration_us": 0, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::copy_", "_start_us": 1656454173489848, "_duration_us": 21, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489873, "_duration_us": 2, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489874, "_duration_us": 0, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::select", "_start_us": 1656454173489875, "_duration_us": 2, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::as_strided", "_start_us": 1656454173489876, "_duration_us": 1, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::copy_", "_start_us": 1656454173489878, "_duration_us": 20, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173489912, "_duration_us": 104, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173489916, "_duration_us": 99, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173490026, "_duration_us": 25, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173490027, "_duration_us": 23, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173490054, "_duration_us": 34, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173490055, "_duration_us": 32, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173490091, "_duration_us": 21, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173490092, "_duration_us": 20, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::matmul", "_start_us": 1656454173490115, "_duration_us": 22, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "aten::mm", "_start_us": 1656454173490116, "_duration_us": 20, "_linked_correlation_id": 0, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173441289, "_duration_us": 2, "_linked_correlation_id": 3074, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173443225, "_duration_us": 9296, "_linked_correlation_id": 3074, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173441296, "_duration_us": 963, "_linked_correlation_id": 3074, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173442309, "_duration_us": 1, "_linked_correlation_id": 3076, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173452523, "_duration_us": 9296, "_linked_correlation_id": 3076, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173442312, "_duration_us": 7, "_linked_correlation_id": 3076, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173442346, "_duration_us": 0, "_linked_correlation_id": 3078, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173461821, "_duration_us": 9293, "_linked_correlation_id": 3078, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173442348, "_duration_us": 6, "_linked_correlation_id": 3078, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173442371, "_duration_us": 0, "_linked_correlation_id": 3080, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173471117, "_duration_us": 9295, "_linked_correlation_id": 3080, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173442373, "_duration_us": 5, "_linked_correlation_id": 3080, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173442395, "_duration_us": 0, "_linked_correlation_id": 3082, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173480414, "_duration_us": 9297, "_linked_correlation_id": 3082, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173442397, "_duration_us": 6, "_linked_correlation_id": 3082, "_device_type": 0}, {"_name": "Memcpy HtoD (Pageable -> Device)", "_start_us": 1656454173489715, "_duration_us": 2, "_linked_correlation_id": 3087, "_device_type": 1}, {"_name": "cudaMemcpyAsync", "_start_us": 1656454173444325, "_duration_us": 24, "_linked_correlation_id": 3087, "_device_type": 0}, {"_name": "cudaStreamSynchronize", "_start_us": 1656454173444350, "_duration_us": 45377, "_linked_correlation_id": 3087, "_device_type": 0}, {"_name": "Memcpy HtoD (Pageable -> Device)", "_start_us": 1656454173489796, "_duration_us": 2, "_linked_correlation_id": 3092, "_device_type": 1}, {"_name": "cudaMemcpyAsync", "_start_us": 1656454173489777, "_duration_us": 14, "_linked_correlation_id": 3092, "_device_type": 0}, {"_name": "cudaStreamSynchronize", "_start_us": 1656454173489791, "_duration_us": 13, "_linked_correlation_id": 3092, "_device_type": 0}, {"_name": "Memcpy HtoD (Pageable -> Device)", "_start_us": 1656454173489828, "_duration_us": 2, "_linked_correlation_id": 3097, "_device_type": 1}, {"_name": "cudaMemcpyAsync", "_start_us": 1656454173489820, "_duration_us": 3, "_linked_correlation_id": 3097, "_device_type": 0}, {"_name": "cudaStreamSynchronize", "_start_us": 1656454173489824, "_duration_us": 13, "_linked_correlation_id": 3097, "_device_type": 0}, {"_name": "Memcpy HtoD (Pageable -> Device)", "_start_us": 1656454173489859, "_duration_us": 2, "_linked_correlation_id": 3102, "_device_type": 1}, {"_name": "cudaMemcpyAsync", "_start_us": 1656454173489851, "_duration_us": 3, "_linked_correlation_id": 3102, "_device_type": 0}, {"_name": "cudaStreamSynchronize", "_start_us": 1656454173489854, "_duration_us": 13, "_linked_correlation_id": 3102, "_device_type": 0}, {"_name": "Memcpy HtoD (Pageable -> Device)", "_start_us": 1656454173489889, "_duration_us": 2, "_linked_correlation_id": 3107, "_device_type": 1}, {"_name": "cudaMemcpyAsync", "_start_us": 1656454173489880, "_duration_us": 3, "_linked_correlation_id": 3107, "_device_type": 0}, {"_name": "cudaStreamSynchronize", "_start_us": 1656454173489884, "_duration_us": 12, "_linked_correlation_id": 3107, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173489972, "_duration_us": 3, "_linked_correlation_id": 3109, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173490013, "_duration_us": 9302, "_linked_correlation_id": 3109, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173489980, "_duration_us": 32, "_linked_correlation_id": 3109, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173490040, "_duration_us": 0, "_linked_correlation_id": 3111, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173499317, "_duration_us": 9306, "_linked_correlation_id": 3111, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173490042, "_duration_us": 7, "_linked_correlation_id": 3111, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173490076, "_duration_us": 0, "_linked_correlation_id": 3113, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173508625, "_duration_us": 9299, "_linked_correlation_id": 3113, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173490078, "_duration_us": 7, "_linked_correlation_id": 3113, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173490102, "_duration_us": 0, "_linked_correlation_id": 3115, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173517925, "_duration_us": 9300, "_linked_correlation_id": 3115, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173490104, "_duration_us": 5, "_linked_correlation_id": 3115, "_device_type": 0}, {"_name": "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", "_start_us": 1656454173490126, "_duration_us": 0, "_linked_correlation_id": 3117, "_device_type": 0}, {"_name": "ampere_sgemm_128x64_nn", "_start_us": 1656454173527228, "_duration_us": 9301, "_linked_correlation_id": 3117, "_device_type": 1}, {"_name": "cudaLaunchKernel", "_start_us": 1656454173490128, "_duration_us": 6, "_linked_correlation_id": 3117, "_device_type": 0}, {"_name": "cudaDeviceSynchronize", "_start_us": 1656454173490313, "_duration_us": 46225, "_linked_correlation_id": 0, "_device_type": 0}], [{"_name": "test_profiler.py(1435): ", "id": 94242239505696, "start_time_ns": 1656454173436288169, "duration_time_ns": 7566917863418487639, "correlation_id": 0, "children": [94242238082288], "parent": null}, {"_name": "torch/testing/_internal/common_utils.py(697): run_tests", "id": 94242238082288, "start_time_ns": 1656454173438182431, "duration_time_ns": 7566917863416593377, "correlation_id": 0, "children": [94242238082800], "parent": 94242239505696}, {"_name": "unittest/main.py(101): __init__", "id": 94242238082800, "start_time_ns": 1656454173438184159, "duration_time_ns": 7566917863416591649, "correlation_id": 0, "children": [94242238083184], "parent": 94242238082288}, {"_name": "unittest/main.py(271): runTests", "id": 94242238083184, "start_time_ns": 1656454173438186629, "duration_time_ns": 7566917863416589179, "correlation_id": 0, "children": [94242238083568], "parent": 94242238082800}, {"_name": "unittest/runner.py(184): run", "id": 94242238083568, "start_time_ns": 1656454173438187601, "duration_time_ns": 7566917863416588207, "correlation_id": 0, "children": [94242238084128], "parent": 94242238083184}, {"_name": "unittest/suite.py(84): __call__", "id": 94242238084128, "start_time_ns": 1656454173438189531, "duration_time_ns": 7566917863416586277, "correlation_id": 0, "children": [94242238084544], "parent": 94242238083568}, {"_name": "unittest/suite.py(122): run", "id": 94242238084544, "start_time_ns": 1656454173438190205, "duration_time_ns": 7566917863416585603, "correlation_id": 0, "children": [94242238084960], "parent": 94242238084128}, {"_name": "unittest/suite.py(84): __call__", "id": 94242238084960, "start_time_ns": 1656454173438191228, "duration_time_ns": 7566917863416584580, "correlation_id": 0, "children": [94242238085376], "parent": 94242238084544}, {"_name": "unittest/suite.py(122): run", "id": 94242238085376, "start_time_ns": 1656454173438191346, "duration_time_ns": 7566917863416584462, "correlation_id": 0, "children": [94242238085792], "parent": 94242238084960}, {"_name": "unittest/case.py(651): __call__", "id": 94242238085792, "start_time_ns": 1656454173438191484, "duration_time_ns": 7566917863416584324, "correlation_id": 0, "children": [94242239133216], "parent": 94242238085376}, {"_name": "torch/testing/_internal/common_utils.py(1886): run", "id": 94242239133216, "start_time_ns": 1656454173438195759, "duration_time_ns": 7566917863416580049, "correlation_id": 0, "children": [94242239133632], "parent": 94242238085792}, {"_name": "torch/testing/_internal/common_utils.py(1829): _run_with_retry", "id": 94242239133632, "start_time_ns": 1656454173438197353, "duration_time_ns": 7566917863416578455, "correlation_id": 0, "children": [94242239134048], "parent": 94242239133216}, {"_name": "unittest/case.py(592): run", "id": 94242239134048, "start_time_ns": 1656454173438198172, "duration_time_ns": 7566917863416577636, "correlation_id": 0, "children": [94242239134464], "parent": 94242239133632}, {"_name": "unittest/case.py(550): _callTestMethod", "id": 94242239134464, "start_time_ns": 1656454173438211703, "duration_time_ns": 7566917863416564105, "correlation_id": 0, "children": [94242239134880], "parent": 94242239134048}, {"_name": "test_profiler.py(1420): test_utils_get_optimizable_events", "id": 94242239134880, "start_time_ns": 1656454173438759703, "duration_time_ns": 7566917863416016105, "correlation_id": 0, "children": [94242239135296], "parent": 94242239134464}, {"_name": "test_profiler.py(1251): load_mock_profile", "id": 94242239135296, "start_time_ns": 1656454173438760534, "duration_time_ns": 7566917863416015274, "correlation_id": 0, "children": [94242239135712, 94240979270032, 94240979295904, 94240979389888, 94240979327296, 94242239499936, 94242239139040, 94242239299696, 94242239301040, 94242239302384, 94242239303728, 94242239305072, 94242239139456], "parent": 94242239134880}, {"_name": "torch/profiler/profiler.py(475): __exit__", "id": 94242239139456, "start_time_ns": 1656454173490143177, "duration_time_ns": 7566917863364632631, "correlation_id": 0, "children": [94242239139872], "parent": 94242239135296}, {"_name": "torch/profiler/profiler.py(484): stop", "id": 94242239139872, "start_time_ns": 1656454173490151443, "duration_time_ns": 7566917863364624365, "correlation_id": 0, "children": [94242239140288], "parent": 94242239139456}, {"_name": "torch/profiler/profiler.py(511): _transit_action", "id": 94242239140288, "start_time_ns": 1656454173490160200, "duration_time_ns": 7566917863364615608, "correlation_id": 0, "children": [94242238898288, 94242238886608], "parent": 94242239139872}, {"_name": "torch/profiler/profiler.py(117): stop_trace", "id": 94242238886608, "start_time_ns": 1656454173490212930, "duration_time_ns": 7566917863364562878, "correlation_id": 0, "children": [94242238887024], "parent": 94242239140288}, {"_name": "torch/autograd/profiler.py(207): __exit__", "id": 94242238887024, "start_time_ns": 1656454173490216323, "duration_time_ns": 7566917863364559485, "correlation_id": 0, "children": [94242238887440], "parent": 94242238886608}, {"_name": "torch/cuda/__init__.py(486): synchronize", "id": 94242238887440, "start_time_ns": 1656454173490222710, "duration_time_ns": 7566917863364553098, "correlation_id": 0, "children": [94242238887856, 94242238888688, 94242238894096, 94242239121280, 94242238897008], "parent": 94242238887024}, {"_name": "torch/cuda/__init__.py(281): __exit__", "id": 94242238897008, "start_time_ns": 1656454173536540711, "duration_time_ns": 7566917863318235097, "correlation_id": 0, "children": [94242239121696], "parent": 94242238887440}, {"_name": "", "id": 94242239121696, "start_time_ns": 1656454173536553153, "duration_time_ns": 7566917863318222655, "correlation_id": 0, "children": [], "parent": 94242238897008}, {"_name": "", "id": 94242239121280, "start_time_ns": 1656454173490312079, "duration_time_ns": 46227101, "correlation_id": 0, "children": [], "parent": 94242238887440}, {"_name": "torch/cuda/__init__.py(272): __enter__", "id": 94242238894096, "start_time_ns": 1656454173490303577, "duration_time_ns": 5394, "correlation_id": 0, "children": [94242238894512, 94242238895760, 94242238896176], "parent": 94242238887440}, {"_name": "torch/cuda/__init__.py(191): _lazy_init", "id": 94242238896176, "start_time_ns": 1656454173490308123, "duration_time_ns": 792, "correlation_id": 0, "children": [94242238896592], "parent": 94242238894096}, {"_name": "torch/cuda/__init__.py(149): is_initialized", "id": 94242238896592, "start_time_ns": 1656454173490308633, "duration_time_ns": 242, "correlation_id": 0, "children": [94242239120864], "parent": 94242238896176}, {"_name": "", "id": 94242239120864, "start_time_ns": 1656454173490308734, "duration_time_ns": 121, "correlation_id": 0, "children": [], "parent": 94242238896592}, {"_name": "torch/_jit_internal.py(982): is_scripting", "id": 94242238895760, "start_time_ns": 1656454173490307337, "duration_time_ns": 660, "correlation_id": 0, "children": [], "parent": 94242238894096}, {"_name": "torch/cuda/__init__.py(480): current_device", "id": 94242238894512, "start_time_ns": 1656454173490304532, "duration_time_ns": 2250, "correlation_id": 0, "children": [94242238894928, 94242239120448], "parent": 94242238894096}, {"_name": "", "id": 94242239120448, "start_time_ns": 1656454173490305817, "duration_time_ns": 934, "correlation_id": 0, "children": [], "parent": 94242238894512}, {"_name": "torch/cuda/__init__.py(191): _lazy_init", "id": 94242238894928, "start_time_ns": 1656454173490305205, "duration_time_ns": 400, "correlation_id": 0, "children": [94242238895344], "parent": 94242238894512}, {"_name": "torch/cuda/__init__.py(149): is_initialized", "id": 94242238895344, "start_time_ns": 1656454173490305315, "duration_time_ns": 249, "correlation_id": 0, "children": [94242239120032], "parent": 94242238894928}, {"_name": "", "id": 94242239120032, "start_time_ns": 1656454173490305469, "duration_time_ns": 64, "correlation_id": 0, "children": [], "parent": 94242238895344}, {"_name": "torch/cuda/__init__.py(268): __init__", "id": 94242238888688, "start_time_ns": 1656454173490238187, "duration_time_ns": 63856, "correlation_id": 0, "children": [94242238889104], "parent": 94242238887440}, {"_name": "torch/cuda/_utils.py(7): _get_device_index", "id": 94242238889104, "start_time_ns": 1656454173490241229, "duration_time_ns": 59393, "correlation_id": 0, "children": [94242239113392, 94242239113808, 94242238889520, 94242239114224, 94242238889936], "parent": 94242238888688}, {"_name": "torch/_utils.py(521): _get_device_index", "id": 94242238889936, "start_time_ns": 1656454173490254695, "duration_time_ns": 45728, "correlation_id": 0, "children": [94242239114640, 94242239115056, 94242239117536, 94242238890352, 94242238890768], "parent": 94242238889104}, {"_name": "torch/_utils.py(497): _get_current_device_index", "id": 94242238890768, "start_time_ns": 1656454173490269804, "duration_time_ns": 30489, "correlation_id": 0, "children": [94242238891184], "parent": 94242238889936}, {"_name": "torch/_utils.py(487): _get_device_attr", "id": 94242238891184, "start_time_ns": 1656454173490277921, "duration_time_ns": 22112, "correlation_id": 0, "children": [94242238891600, 94242239118784, 94242238892432], "parent": 94242238890768}, {"_name": "torch/_utils.py(499): ", "id": 94242238892432, "start_time_ns": 1656454173490290622, "duration_time_ns": 9269, "correlation_id": 0, "children": [94242238892848], "parent": 94242238891184}, {"_name": "torch/cuda/__init__.py(480): current_device", "id": 94242238892848, "start_time_ns": 1656454173490292572, "duration_time_ns": 7253, "correlation_id": 0, "children": [94242238893264, 94242239119616], "parent": 94242238892432}, {"_name": "", "id": 94242239119616, "start_time_ns": 1656454173490296196, "duration_time_ns": 3565, "correlation_id": 0, "children": [], "parent": 94242238892848}, {"_name": "torch/cuda/__init__.py(191): _lazy_init", "id": 94242238893264, "start_time_ns": 1656454173490293743, "duration_time_ns": 1072, "correlation_id": 0, "children": [94242238893680], "parent": 94242238892848}, {"_name": "torch/cuda/__init__.py(149): is_initialized", "id": 94242238893680, "start_time_ns": 1656454173490294339, "duration_time_ns": 402, "correlation_id": 0, "children": [94242239119200], "parent": 94242238893264}, {"_name": "", "id": 94242239119200, "start_time_ns": 1656454173490294551, "duration_time_ns": 124, "correlation_id": 0, "children": [], "parent": 94242238893680}, {"_name": "", "id": 94242239118784, "start_time_ns": 1656454173490289374, "duration_time_ns": 241, "correlation_id": 0, "children": [], "parent": 94242238891184}, {"_name": "torch/_utils.py(478): _get_available_device_type", "id": 94242238891600, "start_time_ns": 1656454173490280148, "duration_time_ns": 8003, "correlation_id": 0, "children": [94242238892016], "parent": 94242238891184}, {"_name": "torch/cuda/__init__.py(77): is_available", "id": 94242238892016, "start_time_ns": 1656454173490282141, "duration_time_ns": 5804, "correlation_id": 0, "children": [94242239117952, 94242239118368], "parent": 94242238891600}, {"_name": "", "id": 94242239118368, "start_time_ns": 1656454173490286599, "duration_time_ns": 1061, "correlation_id": 0, "children": [], "parent": 94242238892016}, {"_name": "", "id": 94242239117952, "start_time_ns": 1656454173490284307, "duration_time_ns": 988, "correlation_id": 0, "children": [], "parent": 94242238892016}, {"_name": "torch/_jit_internal.py(982): is_scripting", "id": 94242238890352, "start_time_ns": 1656454173490268636, "duration_time_ns": 383, "correlation_id": 0, "children": [], "parent": 94242238889936}, {"_name": "", "id": 94242239117536, "start_time_ns": 1656454173490268135, "duration_time_ns": 45, "correlation_id": 0, "children": [], "parent": 94242238889936}, {"_name": "", "id": 94242239115056, "start_time_ns": 1656454173490266016, "duration_time_ns": 43, "correlation_id": 0, "children": [], "parent": 94242238889936}, {"_name": "", "id": 94242239114640, "start_time_ns": 1656454173490264843, "duration_time_ns": 71, "correlation_id": 0, "children": [], "parent": 94242238889936}, {"_name": "", "id": 94242239114224, "start_time_ns": 1656454173490253455, "duration_time_ns": 56, "correlation_id": 0, "children": [], "parent": 94242238889104}, {"_name": "torch/_jit_internal.py(982): is_scripting", "id": 94242238889520, "start_time_ns": 1656454173490250344, "duration_time_ns": 2192, "correlation_id": 0, "children": [], "parent": 94242238889104}, {"_name": "", "id": 94242239113808, "start_time_ns": 1656454173490247257, "duration_time_ns": 104, "correlation_id": 0, "children": [], "parent": 94242238889104}, {"_name": "", "id": 94242239113392, "start_time_ns": 1656454173490245162, "duration_time_ns": 807, "correlation_id": 0, "children": [], "parent": 94242238889104}, {"_name": "torch/cuda/__init__.py(191): _lazy_init", "id": 94242238887856, "start_time_ns": 1656454173490224967, "duration_time_ns": 10586, "correlation_id": 0, "children": [94242238888272], "parent": 94242238887440}, {"_name": "torch/cuda/__init__.py(149): is_initialized", "id": 94242238888272, "start_time_ns": 1656454173490227128, "duration_time_ns": 8241, "correlation_id": 0, "children": [94242239113008], "parent": 94242238887856}, {"_name": "", "id": 94242239113008, "start_time_ns": 1656454173490234177, "duration_time_ns": 892, "correlation_id": 0, "children": [], "parent": 94242238888272}, {"_name": "", "id": 94242238898288, "start_time_ns": 1656454173490187641, "duration_time_ns": 9517, "correlation_id": 0, "children": [94242239140704], "parent": 94242239140288}, {"_name": "enum.py(774): __hash__", "id": 94242239140704, "start_time_ns": 1656454173490190439, "duration_time_ns": 5319, "correlation_id": 0, "children": [94242239112592], "parent": 94242238898288}, {"_name": "", "id": 94242239112592, "start_time_ns": 1656454173490194870, "duration_time_ns": 721, "correlation_id": 0, "children": [], "parent": 94242239140704}, {"_name": "aten::matmul", "id": 94242239305072, "start_time_ns": 1656454173490115971, "duration_time_ns": 21513, "correlation_id": 3116, "children": [94242239305744], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239305744, "start_time_ns": 1656454173490116650, "duration_time_ns": 20114, "correlation_id": 3117, "children": [], "parent": 94242239305072}, {"_name": "aten::matmul", "id": 94242239303728, "start_time_ns": 1656454173490091388, "duration_time_ns": 21342, "correlation_id": 3114, "children": [94242239304400], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239304400, "start_time_ns": 1656454173490092214, "duration_time_ns": 19792, "correlation_id": 3115, "children": [], "parent": 94242239303728}, {"_name": "aten::matmul", "id": 94242239302384, "start_time_ns": 1656454173490054842, "duration_time_ns": 33225, "correlation_id": 3112, "children": [94242239303056], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239303056, "start_time_ns": 1656454173490055485, "duration_time_ns": 31894, "correlation_id": 3113, "children": [], "parent": 94242239302384}, {"_name": "aten::matmul", "id": 94242239301040, "start_time_ns": 1656454173490026585, "duration_time_ns": 24997, "correlation_id": 3110, "children": [94242239301712], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239301712, "start_time_ns": 1656454173490027380, "duration_time_ns": 23566, "correlation_id": 3111, "children": [], "parent": 94242239301040}, {"_name": "aten::matmul", "id": 94242239299696, "start_time_ns": 1656454173489912600, "duration_time_ns": 104156, "correlation_id": 3108, "children": [94242239300368], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239300368, "start_time_ns": 1656454173489916106, "duration_time_ns": 99633, "correlation_id": 3109, "children": [], "parent": 94242239299696}, {"_name": "test_profiler.py(1245): garbage_code", "id": 94242239139040, "start_time_ns": 1656454173442410244, "duration_time_ns": 47490938, "correlation_id": 0, "children": [94242239501088, 94242239482560, 94242239483728, 94242239484368, 94242239485536, 94242239486912, 94242239487552, 94242237992320, 94242237993696, 94242237994336, 94242237995712, 94242237997120, 94242237997856, 94242237999328, 94242238000800], "parent": 94242239135296}, {"_name": "aten::copy_", "id": 94242238000800, "start_time_ns": 1656454173489878232, "duration_time_ns": 20288, "correlation_id": 3107, "children": [], "parent": 94242239139040}, {"_name": "aten::select", "id": 94242237999328, "start_time_ns": 1656454173489875969, "duration_time_ns": 1490, "correlation_id": 3105, "children": [94242238000240], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242238000240, "start_time_ns": 1656454173489876749, "duration_time_ns": 269, "correlation_id": 3106, "children": [], "parent": 94242237999328}, {"_name": "aten::select", "id": 94242237997856, "start_time_ns": 1656454173489873022, "duration_time_ns": 2173, "correlation_id": 3103, "children": [94242237998768], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242237998768, "start_time_ns": 1656454173489874129, "duration_time_ns": 436, "correlation_id": 3104, "children": [], "parent": 94242237997856}, {"_name": "aten::copy_", "id": 94242237997120, "start_time_ns": 1656454173489848771, "duration_time_ns": 20290, "correlation_id": 3102, "children": [], "parent": 94242239139040}, {"_name": "aten::select", "id": 94242237995712, "start_time_ns": 1656454173489846145, "duration_time_ns": 1571, "correlation_id": 3100, "children": [94242237996560], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242237996560, "start_time_ns": 1656454173489847021, "duration_time_ns": 204, "correlation_id": 3101, "children": [], "parent": 94242237995712}, {"_name": "aten::select", "id": 94242237994336, "start_time_ns": 1656454173489842325, "duration_time_ns": 3114, "correlation_id": 3098, "children": [94242237995184], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242237995184, "start_time_ns": 1656454173489844409, "duration_time_ns": 440, "correlation_id": 3099, "children": [], "parent": 94242237994336}, {"_name": "aten::copy_", "id": 94242237993696, "start_time_ns": 1656454173489817557, "duration_time_ns": 20628, "correlation_id": 3097, "children": [], "parent": 94242239139040}, {"_name": "aten::select", "id": 94242237992320, "start_time_ns": 1656454173489814695, "duration_time_ns": 1630, "correlation_id": 3095, "children": [94242237993168], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242237993168, "start_time_ns": 1656454173489815568, "duration_time_ns": 267, "correlation_id": 3096, "children": [], "parent": 94242237992320}, {"_name": "aten::select", "id": 94242239487552, "start_time_ns": 1656454173489811667, "duration_time_ns": 2305, "correlation_id": 3093, "children": [94242237991792], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242237991792, "start_time_ns": 1656454173489812906, "duration_time_ns": 491, "correlation_id": 3094, "children": [], "parent": 94242239487552}, {"_name": "aten::copy_", "id": 94242239486912, "start_time_ns": 1656454173489771721, "duration_time_ns": 34924, "correlation_id": 3092, "children": [], "parent": 94242239139040}, {"_name": "aten::select", "id": 94242239485536, "start_time_ns": 1656454173489766717, "duration_time_ns": 2462, "correlation_id": 3090, "children": [94242239486384], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242239486384, "start_time_ns": 1656454173489767943, "duration_time_ns": 366, "correlation_id": 3091, "children": [], "parent": 94242239485536}, {"_name": "aten::select", "id": 94242239484368, "start_time_ns": 1656454173489760388, "duration_time_ns": 5433, "correlation_id": 3088, "children": [94242239485008], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242239485008, "start_time_ns": 1656454173489764139, "duration_time_ns": 858, "correlation_id": 3089, "children": [], "parent": 94242239484368}, {"_name": "aten::copy_", "id": 94242239483728, "start_time_ns": 1656454173444305057, "duration_time_ns": 45427507, "correlation_id": 3087, "children": [], "parent": 94242239139040}, {"_name": "aten::select", "id": 94242239482560, "start_time_ns": 1656454173444291864, "duration_time_ns": 8740, "correlation_id": 3085, "children": [94242239483200], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242239483200, "start_time_ns": 1656454173444296798, "duration_time_ns": 531, "correlation_id": 3086, "children": [], "parent": 94242239482560}, {"_name": "aten::select", "id": 94242239501088, "start_time_ns": 1656454173444252555, "duration_time_ns": 38328, "correlation_id": 3083, "children": [94242239501584], "parent": 94242239139040}, {"_name": "aten::as_strided", "id": 94242239501584, "start_time_ns": 1656454173444282394, "duration_time_ns": 3993, "correlation_id": 3084, "children": [], "parent": 94242239501088}, {"_name": "aten::matmul", "id": 94242239499936, "start_time_ns": 1656454173442384887, "duration_time_ns": 20958, "correlation_id": 3081, "children": [94242239500512], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242239500512, "start_time_ns": 1656454173442385493, "duration_time_ns": 19655, "correlation_id": 3082, "children": [], "parent": 94242239499936}, {"_name": "aten::matmul", "id": 94240979327296, "start_time_ns": 1656454173442360631, "duration_time_ns": 21026, "correlation_id": 3079, "children": [94242238916288], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94242238916288, "start_time_ns": 1656454173442361296, "duration_time_ns": 19633, "correlation_id": 3080, "children": [], "parent": 94240979327296}, {"_name": "aten::matmul", "id": 94240979389888, "start_time_ns": 1656454173442325764, "duration_time_ns": 31593, "correlation_id": 3077, "children": [94240979374096], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94240979374096, "start_time_ns": 1656454173442326275, "duration_time_ns": 30364, "correlation_id": 3078, "children": [], "parent": 94240979389888}, {"_name": "aten::matmul", "id": 94240979295904, "start_time_ns": 1656454173442289759, "duration_time_ns": 32569, "correlation_id": 3075, "children": [94240169025248], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94240169025248, "start_time_ns": 1656454173442291934, "duration_time_ns": 29693, "correlation_id": 3076, "children": [], "parent": 94240979295904}, {"_name": "aten::matmul", "id": 94240979270032, "start_time_ns": 1656454173440014537, "duration_time_ns": 2254371, "correlation_id": 3073, "children": [94240979296288], "parent": 94242239135296}, {"_name": "aten::mm", "id": 94240979296288, "start_time_ns": 1656454173440019291, "duration_time_ns": 2245915, "correlation_id": 3074, "children": [], "parent": 94240979270032}, {"_name": "torch/profiler/profiler.py(472): __enter__", "id": 94242239135712, "start_time_ns": 1656454173438761183, "duration_time_ns": 1208076, "correlation_id": 0, "children": [94242239136128], "parent": 94242239135296}, {"_name": "torch/profiler/profiler.py(479): start", "id": 94242239136128, "start_time_ns": 1656454173438762066, "duration_time_ns": 1206947, "correlation_id": 0, "children": [94242239136544], "parent": 94242239135712}, {"_name": "torch/profiler/profiler.py(515): _transit_action", "id": 94242239136544, "start_time_ns": 1656454173438764183, "duration_time_ns": 1203897, "correlation_id": 0, "children": [94242239136960], "parent": 94242239136128}, {"_name": "torch/profiler/profiler.py(110): start_trace", "id": 94242239136960, "start_time_ns": 1656454173438766170, "duration_time_ns": 1200818, "correlation_id": 0, "children": [94242239137376, 94242238897424, 94242239137792], "parent": 94242239136544}, {"_name": "torch/profiler/profiler.py(189): _get_distributed_info", "id": 94242239137792, "start_time_ns": 1656454173439946391, "duration_time_ns": 20326, "correlation_id": 0, "children": [94242239138208, 94242239138624], "parent": 94242239136960}, {"_name": "torch/distributed/distributed_c10d.py(415): is_initialized", "id": 94242239138624, "start_time_ns": 1656454173439964257, "duration_time_ns": 2376, "correlation_id": 0, "children": [], "parent": 94242239137792}, {"_name": "torch/distributed/__init__.py(8): is_available", "id": 94242239138208, "start_time_ns": 1656454173439956583, "duration_time_ns": 5736, "correlation_id": 0, "children": [94242238897872], "parent": 94242239137792}, {"_name": "", "id": 94242238897872, "start_time_ns": 1656454173439960911, "duration_time_ns": 1344, "correlation_id": 0, "children": [], "parent": 94242239138208}, {"_name": "", "id": 94242238897424, "start_time_ns": 1656454173439940525, "duration_time_ns": 1813, "correlation_id": 0, "children": [], "parent": 94242239136960}, {"_name": "torch/autograd/profiler.py(205): _start_trace", "id": 94242239137376, "start_time_ns": 1656454173438766630, "duration_time_ns": 63314, "correlation_id": 0, "children": [], "parent": 94242239136960}]] diff --git a/test/quantization/core/experimental/fx_graph_mode_apot.py b/test/quantization/core/experimental/fx_graph_mode_apot.py new file mode 100644 index 0000000000000..74a5ef081da56 --- /dev/null +++ b/test/quantization/core/experimental/fx_graph_mode_apot.py @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms.transforms as transforms +import os +import torch.quantization +from torchvision.models.quantization.resnet import resnet18 + +# Setup warnings +import warnings +warnings.filterwarnings( + action='ignore', + category=DeprecationWarning, + module=r'.*' +) +warnings.filterwarnings( + action='default', + module=r'torch.quantization' +) + +""" +Define helper functions +""" + +# Specify random seed for repeatable results +_ = torch.manual_seed(191009) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def evaluate(model, criterion, data_loader): + model.eval() + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + cnt = 0 + with torch.no_grad(): + for image, target in data_loader: + output = model(image) + loss = criterion(output, target) + cnt += 1 + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], image.size(0)) + top5.update(acc5[0], image.size(0)) + print('') + + return top1, top5 + +def load_model(model_file): + model = resnet18(pretrained=False) + state_dict = torch.load(model_file) + model.load_state_dict(state_dict) + model.to("cpu") + return model + +def print_size_of_model(model): + if isinstance(model, torch.jit.RecursiveScriptModule): + torch.jit.save(model, "temp.p") + else: + torch.jit.save(torch.jit.script(model), "temp.p") + print("Size (MB):", os.path.getsize("temp.p") / 1e6) + os.remove("temp.p") + +def prepare_data_loaders(data_path): + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + dataset = torchvision.datasets.ImageNet(data_path, + split="train", + transform=transforms.Compose([transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize])) + dataset_test = torchvision.datasets.ImageNet(data_path, + split="val", + transform=transforms.Compose([transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize])) + + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=train_batch_size, + sampler=train_sampler) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=eval_batch_size, + sampler=test_sampler) + + return data_loader, data_loader_test + +data_path = '~/my_imagenet/' + +train_batch_size = 30 +eval_batch_size = 50 + +data_loader, data_loader_test = prepare_data_loaders(data_path) +criterion = nn.CrossEntropyLoss() +float_model = resnet18(pretrained=True) +float_model.eval() + +# deepcopy the model since we need to keep the original model around +import copy +model_to_quantize = copy.deepcopy(float_model) + +model_to_quantize.eval() + +""" +Prepare models +""" + +# Note that this is temporary, we'll expose these functions to torch.quantization after official releasee +from torch.quantization.quantize_fx import prepare_fx, convert_fx + +def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + +from torch.ao.quantization.experimental.qconfig import ( + uniform_qconfig_8bit, + apot_weights_qconfig_8bit, + apot_qconfig_8bit, + uniform_qconfig_4bit, + apot_weights_qconfig_4bit, + apot_qconfig_4bit +) + +""" +Prepare full precision model +""" +full_precision_model = float_model + +top1, top5 = evaluate(full_precision_model, criterion, data_loader_test) +print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare model PTQ for specified qconfig for torch.nn.Linear +""" +def prepare_ptq_linear(qconfig): + qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]} + prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers + calibrate(prepared_model, data_loader_test) # run calibration on sample data + return prepared_model + +""" +Prepare model with uniform activation, uniform weight +b=8, k=2 +""" + +prepared_model = prepare_ptq_linear(uniform_qconfig_8bit) +quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model + +top1, top5 = evaluate(quantized_model, criterion, data_loader_test) +print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare model with uniform activation, uniform weight +b=4, k=2 +""" + +prepared_model = prepare_ptq_linear(uniform_qconfig_4bit) +quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model + +top1, top5 = evaluate(quantized_model1, criterion, data_loader_test) +print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare model with uniform activation, APoT weight +(b=8, k=2) +""" + +prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit) + +top1, top5 = evaluate(prepared_model, criterion, data_loader_test) +print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare model with uniform activation, APoT weight +(b=4, k=2) +""" + +prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit) + +top1, top5 = evaluate(prepared_model, criterion, data_loader_test) +print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + + +""" +Prepare model with APoT activation and weight +(b=8, k=2) +""" + +prepared_model = prepare_ptq_linear(apot_qconfig_8bit) + +top1, top5 = evaluate(prepared_model, criterion, data_loader_test) +print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare model with APoT activation and weight +(b=4, k=2) +""" + +prepared_model = prepare_ptq_linear(apot_qconfig_4bit) + +top1, top5 = evaluate(prepared_model, criterion, data_loader_test) +print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg)) + +""" +Prepare eager mode quantized model +""" +eager_quantized_model = resnet18(pretrained=True, quantize=True).eval() +top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test) +print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg)) diff --git a/test/quantization/core/experimental/test_fake_quantize.py b/test/quantization/core/experimental/test_fake_quantize.py new file mode 100644 index 0000000000000..4e9464aca800a --- /dev/null +++ b/test/quantization/core/experimental/test_fake_quantize.py @@ -0,0 +1,92 @@ +# Owner(s): ["oncall: quantization"] + +import torch +import unittest +from torch.ao.quantization.experimental.observer import APoTObserver +from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT +from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize +from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function +forward_helper = fake_quantize_function.forward +backward = fake_quantize_function.backward +from torch.autograd import gradcheck + +class TestFakeQuantize(unittest.TestCase): + r""" Tests fake quantize calculate_qparams() method + by comparing with result from observer calculate_qparams. + Uses hard-coded values: alpha=1.0, b=4, k=2. + """ + def test_fake_calc_qparams(self): + apot_fake = APoTFakeQuantize(b=4, k=2) + apot_fake.activation_post_process.min_val = torch.tensor([0.0]) + apot_fake.activation_post_process.max_val = torch.tensor([1.0]) + + alpha, gamma, quantization_levels, level_indices = apot_fake.calculate_qparams(signed=False) + + observer = APoTObserver(b=4, k=2) + observer.min_val = torch.tensor([0.0]) + observer.max_val = torch.tensor([1.0]) + + qparams_expected = observer.calculate_qparams(signed=False) + + self.assertEqual(alpha, qparams_expected[0]) + self.assertTrue(torch.equal(gamma, qparams_expected[1])) + self.assertTrue(torch.equal(quantization_levels, qparams_expected[2])) + self.assertTrue(torch.equal(level_indices, qparams_expected[3])) + + r""" Tests fake quantize forward() method + by comparing result with expected + quant_dequant_APoT mapping of input tensor. + Uses input tensor with random values from 0 -> 1000 + and APoT observer with hard-coded values b=4, k=2 + """ + def test_forward(self): + # generate a tensor of size 20 with random values + # between 0 -> 1000 to quantize -> dequantize + X = 1000 * torch.rand(20) + + observer = APoTObserver(b=4, k=2) + observer.forward(X) + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + apot_fake = APoTFakeQuantize(b=4, k=2) + apot_fake.enable_observer() + apot_fake.enable_fake_quant() + + X_reduced_precision_fp = apot_fake.forward(torch.clone(X), False) + + # get X_expected by converting fp -> apot -> fp to simulate quantize -> dequantize + X_to_apot = quantize_APoT(X, alpha, gamma, quantization_levels, level_indices) + X_expected = dequantize_APoT(X_to_apot) + + self.assertTrue(torch.equal(X_reduced_precision_fp, X_expected)) + + r""" Tests fake quantize forward() method + throws error when qparams are None + """ + def test_forward_exception(self): + # generate a tensor of size 20 with random values + # between 0 -> 1000 to quantize -> dequantize + X = 1000 * torch.rand(20) + + apot_fake = APoTFakeQuantize(b=4, k=2) + # disable observer so qparams not set, qparams are all None + apot_fake.disable_observer() + apot_fake.enable_fake_quant() + + with self.assertRaises(Exception): + apot_fake.forward(torch.clone(X), False) + + r""" Tests fake quantize helper backward() method + using torch.autograd.gradcheck function. + """ + def test_backward(self): + input = torch.randn(20, dtype=torch.double, requires_grad=True) + + observer = APoTObserver(b=4, k=2) + observer(input) + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + test = gradcheck(fake_quantize_function.apply, (input, alpha, gamma, quantization_levels, level_indices), atol=1e-4) + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/core/experimental/test_linear.py b/test/quantization/core/experimental/test_linear.py new file mode 100644 index 0000000000000..6a46b4fc3ccbf --- /dev/null +++ b/test/quantization/core/experimental/test_linear.py @@ -0,0 +1,65 @@ +# Owner(s): ["oncall: quantization"] + +import torch +from torch.ao.quantization.experimental.linear import LinearAPoT +from torch.nn.modules.linear import Linear +import unittest + +class TestNonUniformObserver(unittest.TestCase): + """ + Test linear_APoT_fn by comparing to uniform linear + for 2d tensors with size (4,4) and k=1 + """ + def test_linear_APoT_k1(self): + # weight: fp tensor + weight = 1000 * torch.rand(4, 4) + + # activtion: fp32 tensor with ~ integer values + activation = torch.randint(low=0, high=255, size=(4, 4), dtype=torch.float) + + # calculate result from calling linear forward method + apot_linear = LinearAPoT(weight, 8, 1) + apot_linear_result = apot_linear(activation) + + # calculate expected results + fp_linear = Linear(4, 4, bias=False) + + # set weight for fp linear + apot_quantized_weight_float = apot_linear.weight.type(torch.FloatTensor) + fp_linear_weight = torch.nn.parameter.Parameter(data=apot_quantized_weight_float) + fp_linear.weight = fp_linear_weight + + fp_linear_result = fp_linear(activation).data + + self.assertTrue(torch.equal(apot_linear_result, fp_linear_result)) + + """ + Test linear_APoT_fn by comparing to uniform linear + for 2d tensors with size (5,3), (3, 5) and k=2 + """ + def test_linear_APoT_k2(self): + # weight: fp tensor + weight = 1000 * torch.rand(5, 3) + + # activtion: fp32 tensor with ~ integer values + # note: transpose of activation matrix will have dimension (3, 5) + activation = torch.randint(low=0, high=255, size=(5, 3), dtype=torch.float) + + # calculate result from calling linear forward method + apot_linear = LinearAPoT(weight, 8, 2) + apot_linear_result = apot_linear(activation) + + # calculate expected results + fp_linear = Linear(4, 4, bias=False) + + # set weight for fp linear + apot_quantized_weight_float = apot_linear.weight.type(torch.FloatTensor) + fp_linear_weight = torch.nn.parameter.Parameter(data=apot_quantized_weight_float) + fp_linear.weight = fp_linear_weight + + fp_linear_result = fp_linear(activation).data + + self.assertTrue(torch.equal(apot_linear_result, fp_linear_result)) + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/core/experimental/test_nonuniform_observer.py b/test/quantization/core/experimental/test_nonuniform_observer.py index 4e2f197044e6d..3f3935f191daa 100644 --- a/test/quantization/core/experimental/test_nonuniform_observer.py +++ b/test/quantization/core/experimental/test_nonuniform_observer.py @@ -2,20 +2,23 @@ from torch.ao.quantization.experimental.observer import APoTObserver import unittest +import torch class TestNonUniformObserver(unittest.TestCase): """ - Test case 1 + Test case 1: calculate_qparams Test that error is thrown when k == 0 """ def test_calculate_qparams_invalid(self): - obs = APoTObserver(max_val=0.0, b=0, k=0) + obs = APoTObserver(b=0, k=0) + obs.min_val = torch.tensor([0.0]) + obs.max_val = torch.tensor([0.0]) with self.assertRaises(AssertionError): - obs_result = obs.calculate_qparams(signed=False) + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) """ - Test case 2 + Test case 2: calculate_qparams APoT paper example: https://arxiv.org/pdf/1909.13144.pdf Assume hardcoded parameters: * b = 4 (total number of bits across all terms) @@ -24,8 +27,16 @@ def test_calculate_qparams_invalid(self): * note: b = k * n """ def test_calculate_qparams_2terms(self): - obs = APoTObserver(max_val=1.0, b=4, k=2) - obs_result = obs.calculate_qparams(signed=False) + obs = APoTObserver(b=4, k=2) + + obs.min_val = torch.tensor([0.0]) + obs.max_val = torch.tensor([1.0]) + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) + + alpha_test = torch.max(-obs.min_val, obs.max_val) + + # check alpha value + self.assertEqual(alpha, alpha_test) # calculate expected gamma value gamma_test = 0 @@ -35,32 +46,39 @@ def test_calculate_qparams_2terms(self): gamma_test = 1 / gamma_test # check gamma value - self.assertEqual(obs_result[0], gamma_test) + self.assertEqual(gamma, gamma_test) # check quantization levels size - quantlevels_size_test = int(len(obs_result[1])) + quantlevels_size_test = int(len(quantization_levels)) quantlevels_size = 2**4 self.assertEqual(quantlevels_size_test, quantlevels_size) # check level indices size - levelindices_size_test = int(len(obs_result[2])) + levelindices_size_test = int(len(level_indices)) self.assertEqual(levelindices_size_test, 16) # check level indices unique values - level_indices_test_list = obs_result[2].tolist() + level_indices_test_list = level_indices.tolist() self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) """ - Test case 3 + Test case 3: calculate_qparams Assume hardcoded parameters: * b = 6 (total number of bits across all terms) * k = 2 (base bitwidth, i.e. bitwidth of every term) * n = 3 (number of additive terms) """ def test_calculate_qparams_3terms(self): - obs = APoTObserver(max_val=1.0, b=6, k=2) + obs = APoTObserver(b=6, k=2) - obs_result = obs.calculate_qparams(signed=False) + obs.min_val = torch.tensor([0.0]) + obs.max_val = torch.tensor([1.0]) + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) + + alpha_test = torch.max(-obs.min_val, obs.max_val) + + # check alpha value + self.assertEqual(alpha, alpha_test) # calculate expected gamma value gamma_test = 0 @@ -70,23 +88,23 @@ def test_calculate_qparams_3terms(self): gamma_test = 1 / gamma_test # check gamma value - self.assertEqual(obs_result[0], gamma_test) + self.assertEqual(gamma, gamma_test) # check quantization levels size - quantlevels_size_test = int(len(obs_result[1])) + quantlevels_size_test = int(len(quantization_levels)) quantlevels_size = 2**6 self.assertEqual(quantlevels_size_test, quantlevels_size) # check level indices size - levelindices_size_test = int(len(obs_result[2])) + levelindices_size_test = int(len(level_indices)) self.assertEqual(levelindices_size_test, 64) # check level indices unique values - level_indices_test_list = obs_result[2].tolist() + level_indices_test_list = level_indices.tolist() self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) """ - Test case 4 + Test case 4: calculate_qparams Same as test case 2 but with signed = True Assume hardcoded parameters: * b = 4 (total number of bits across all terms) @@ -95,8 +113,15 @@ def test_calculate_qparams_3terms(self): * signed = True """ def test_calculate_qparams_signed(self): - obs = APoTObserver(max_val=1.0, b=4, k=2) - obs_result = obs.calculate_qparams(signed=True) + obs = APoTObserver(b=4, k=2) + + obs.min_val = torch.tensor([0.0]) + obs.max_val = torch.tensor([1.0]) + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True) + alpha_test = torch.max(-obs.min_val, obs.max_val) + + # check alpha value + self.assertEqual(alpha, alpha_test) # calculate expected gamma value gamma_test = 0 @@ -106,15 +131,15 @@ def test_calculate_qparams_signed(self): gamma_test = 1 / gamma_test # check gamma value - self.assertEqual(obs_result[0], gamma_test) + self.assertEqual(gamma, gamma_test) # check quantization levels size - quantlevels_size_test = int(len(obs_result[1])) + quantlevels_size_test = int(len(quantization_levels)) self.assertEqual(quantlevels_size_test, 49) # check negatives of each element contained # in quantization levels - quantlevels_test_list = obs_result[1].tolist() + quantlevels_test_list = quantization_levels.tolist() negatives_contained = True for ele in quantlevels_test_list: if not (-ele) in quantlevels_test_list: @@ -122,12 +147,72 @@ def test_calculate_qparams_signed(self): self.assertTrue(negatives_contained) # check level indices size - levelindices_size_test = int(len(obs_result[2])) + levelindices_size_test = int(len(level_indices)) self.assertEqual(levelindices_size_test, 49) # check level indices unique elements - level_indices_test_list = obs_result[2].tolist() + level_indices_test_list = level_indices.tolist() self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) + """ + Test case 5: calculate_qparams + Assume hardcoded parameters: + * b = 6 (total number of bits across all terms) + * k = 1 (base bitwidth, i.e. bitwidth of every term) + * n = 6 (number of additive terms) + """ + def test_calculate_qparams_k1(self): + obs = APoTObserver(b=6, k=1) + + obs.min_val = torch.tensor([0.0]) + obs.max_val = torch.tensor([1.0]) + + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) + + # calculate expected gamma value + gamma_test = 0 + for i in range(6): + gamma_test += 2**(-i) + + gamma_test = 1 / gamma_test + + # check gamma value + self.assertEqual(gamma, gamma_test) + + # check quantization levels size + quantlevels_size_test = int(len(quantization_levels)) + quantlevels_size = 2**6 + self.assertEqual(quantlevels_size_test, quantlevels_size) + + # check level indices size + levelindices_size_test = int(len(level_indices)) + level_indices_size = 2**6 + self.assertEqual(levelindices_size_test, level_indices_size) + + # check level indices unique values + level_indices_test_list = level_indices.tolist() + self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) + + """ + Test forward method on hard-coded tensor with arbitrary values. + Checks that alpha is max of abs value of max and min values in tensor. + """ + def test_forward(self): + obs = APoTObserver(b=4, k=2) + + X = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92]) + + X = obs.forward(X) + + alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True) + + min_val = torch.min(X) + max_val = torch.max(X) + + expected_alpha = torch.max(-min_val, max_val) + + self.assertEqual(alpha, expected_alpha) + + if __name__ == '__main__': unittest.main() diff --git a/test/quantization/core/experimental/test_quantized_tensor.py b/test/quantization/core/experimental/test_quantized_tensor.py index 5ccc362aac0f6..02286b94f8db3 100644 --- a/test/quantization/core/experimental/test_quantized_tensor.py +++ b/test/quantization/core/experimental/test_quantized_tensor.py @@ -1,22 +1,41 @@ # Owner(s): ["oncall: quantization"] import torch -from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT import unittest +from torch.ao.quantization.experimental.observer import APoTObserver +from torch.ao.quantization.experimental.quantizer import quantize_APoT class TestQuantizedTensor(unittest.TestCase): - def test_quantize_APoT(self): - t = torch.Tensor() - with self.assertRaises(NotImplementedError): - TensorAPoT.quantize_APoT(t) - - def test_dequantize(self): - with self.assertRaises(NotImplementedError): - TensorAPoT.dequantize(self) - - def test_q_apot_alpha(self): - with self.assertRaises(NotImplementedError): - TensorAPoT.q_apot_alpha(self) + r""" Tests int_repr on APoTQuantizer with random tensor2quantize + and hard-coded values + """ + def test_int_repr(self): + # generate tensor with random fp values + tensor2quantize = tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391]) + + observer = APoTObserver(b=4, k=2) + + observer.forward(tensor2quantize) + + qparams = observer.calculate_qparams(signed=False) + + # get apot quantized tensor result + qtensor = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=qparams[0], + gamma=qparams[1], + quantization_levels=qparams[2], + level_indices=qparams[3]) + + qtensor_data = qtensor.int_repr().int() + + # expected qtensor values calculated based on + # corresponding level_indices to nearest quantization level + # for each fp value in tensor2quantize + # e.g. + # 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices + expected_qtensor_data = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32) + + self.assertTrue(torch.equal(qtensor_data, expected_qtensor_data)) if __name__ == '__main__': unittest.main() diff --git a/test/quantization/core/experimental/test_quantizer.py b/test/quantization/core/experimental/test_quantizer.py new file mode 100644 index 0000000000000..d689ee8e99e15 --- /dev/null +++ b/test/quantization/core/experimental/test_quantizer.py @@ -0,0 +1,229 @@ +# Owner(s): ["oncall: quantization"] + +import torch +from torch import quantize_per_tensor +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.experimental.observer import APoTObserver +from torch.ao.quantization.experimental.quantizer import APoTQuantizer, quantize_APoT, dequantize_APoT +import unittest +import random + +class TestQuantizer(unittest.TestCase): + r""" Tests quantize_APoT result on random 1-dim tensor + and hardcoded values for b, k by comparing to uniform quantization + (non-uniform quantization reduces to uniform for k = 1) + quantized tensor (https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) + * tensor2quantize: Tensor + * b: 8 + * k: 1 + """ + def test_quantize_APoT_rand_k1(self): + # generate random size of tensor2quantize between 1 -> 20 + size = random.randint(1, 20) + + # generate tensor with random fp values between 0 -> 1000 + tensor2quantize = 1000 * torch.rand(size, dtype=torch.float) + + apot_observer = APoTObserver(b=8, k=1) + apot_observer(tensor2quantize) + alpha, gamma, quantization_levels, level_indices = apot_observer.calculate_qparams(signed=False) + + # get apot quantized tensor result + qtensor = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + # get uniform quantization quantized tensor result + uniform_observer = MinMaxObserver() + uniform_observer(tensor2quantize) + scale, zero_point = uniform_observer.calculate_qparams() + + uniform_quantized = quantize_per_tensor(input=tensor2quantize, + scale=scale, + zero_point=zero_point, + dtype=torch.quint8).int_repr() + + qtensor_data = qtensor.data.int() + uniform_quantized_tensor = uniform_quantized.data.int() + + self.assertTrue(torch.equal(qtensor_data, uniform_quantized_tensor)) + + r""" Tests quantize_APoT for k != 1. + Tests quantize_APoT result on random 1-dim tensor and hardcoded values for + b=4, k=2 by comparing results to hand-calculated results from APoT paper + https://arxiv.org/pdf/1909.13144.pdf + * tensor2quantize: Tensor + * b: 4 + * k: 2 + """ + def test_quantize_APoT_k2(self): + r""" + given b = 4, k = 2, alpha = 1.0, we know: + (from APoT paper example: https://arxiv.org/pdf/1909.13144.pdf) + + quantization_levels = tensor([0.0000, 0.0208, 0.0417, 0.0625, 0.0833, 0.1250, 0.1667, + 0.1875, 0.2500, 0.3333, 0.3750, 0.5000, 0.6667, 0.6875, 0.7500, 1.0000]) + + level_indices = tensor([ 0, 3, 12, 15, 2, 14, 8, 11, 10, 1, 13, 9, 4, 7, 6, 5])) + """ + + # generate tensor with random fp values + tensor2quantize = torch.tensor([0, 0.0215, 0.1692, 0.385, 1, 0.0391]) + + observer = APoTObserver(b=4, k=2) + observer.forward(tensor2quantize) + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + # get apot quantized tensor result + qtensor = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + qtensor_data = qtensor.data.int() + + # expected qtensor values calculated based on + # corresponding level_indices to nearest quantization level + # for each fp value in tensor2quantize + # e.g. + # 0.0215 in tensor2quantize nearest 0.0208 in quantization_levels -> 3 in level_indices + expected_qtensor = torch.tensor([0, 3, 8, 13, 5, 12], dtype=torch.int32) + + self.assertTrue(torch.equal(qtensor_data, expected_qtensor)) + + r""" Tests dequantize_apot result on random 1-dim tensor + and hardcoded values for b, k. + Dequant -> quant an input tensor and verify that + result is equivalent to input + * tensor2quantize: Tensor + * b: 4 + * k: 2 + """ + def test_dequantize_quantize_rand_b4(self): + # make observer + observer = APoTObserver(4, 2) + + # generate random size of tensor2quantize between 1 -> 20 + size = random.randint(1, 20) + + # make tensor2quantize: random fp values between 0 -> 1000 + tensor2quantize = 1000 * torch.rand(size, dtype=torch.float) + + observer.forward(tensor2quantize) + + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + # make mock apot_tensor + original_apot = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + original_input = torch.clone(original_apot.data).int() + + # dequantize apot_tensor + dequantize_result = dequantize_APoT(apot_tensor=original_apot) + + # quantize apot_tensor + final_apot = quantize_APoT(tensor2quantize=dequantize_result, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + result = final_apot.data.int() + + self.assertTrue(torch.equal(original_input, result)) + + r""" Tests dequantize_apot result on random 1-dim tensor + and hardcoded values for b, k. + Dequant -> quant an input tensor and verify that + result is equivalent to input + * tensor2quantize: Tensor + * b: 12 + * k: 4 + """ + def test_dequantize_quantize_rand_b6(self): + # make observer + observer = APoTObserver(12, 4) + + # generate random size of tensor2quantize between 1 -> 20 + size = random.randint(1, 20) + + # make tensor2quantize: random fp values between 0 -> 1000 + tensor2quantize = 1000 * torch.rand(size, dtype=torch.float) + + observer.forward(tensor2quantize) + + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + # make mock apot_tensor + original_apot = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + original_input = torch.clone(original_apot.data).int() + + # dequantize apot_tensor + dequantize_result = dequantize_APoT(apot_tensor=original_apot) + + # quantize apot_tensor + final_apot = quantize_APoT(tensor2quantize=dequantize_result, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + result = final_apot.data.int() + + self.assertTrue(torch.equal(original_input, result)) + + r""" Tests for correct dimensions in dequantize_apot result + on random 3-dim tensor with random dimension sizes + and hardcoded values for b, k. + Dequant an input tensor and verify that + dimensions are same as input. + * tensor2quantize: Tensor + * b: 4 + * k: 2 + """ + def test_dequantize_dim(self): + # make observer + observer = APoTObserver(4, 2) + + # generate random size of tensor2quantize between 1 -> 20 + size1 = random.randint(1, 20) + size2 = random.randint(1, 20) + size3 = random.randint(1, 20) + + # make tensor2quantize: random fp values between 0 -> 1000 + tensor2quantize = 1000 * torch.rand(size1, size2, size3, dtype=torch.float) + + observer.forward(tensor2quantize) + + alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(signed=False) + + # make mock apot_tensor + original_apot = quantize_APoT(tensor2quantize=tensor2quantize, + alpha=alpha, + gamma=gamma, + quantization_levels=quantization_levels, + level_indices=level_indices) + + # dequantize apot_tensor + dequantize_result = dequantize_APoT(apot_tensor=original_apot) + + self.assertEqual(original_apot.data.size(), dequantize_result.size()) + + def test_q_apot_alpha(self): + with self.assertRaises(NotImplementedError): + APoTQuantizer.q_apot_alpha(self) + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py new file mode 100644 index 0000000000000..731d0f5afe6b2 --- /dev/null +++ b/test/quantization/core/test_backend_config.py @@ -0,0 +1,319 @@ +# Owner(s): ["oncall: quantization"] + +import torch +import torch.nn.intrinsic as nni +import torch.nn.qat as nnqat +import torch.nn.quantized._reference as nnqr +from torch.testing._internal.common_quantization import QuantizationTestCase + +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 +from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter +from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer + + +class TestBackendConfig(QuantizationTestCase): + + # ============= + # DTypeConfig + # ============= + + dtype_config1 = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float + ) + + dtype_config2 = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + is_dynamic=True + ) + + dtype_config_dict1 = { + "input_dtype": torch.quint8, + "output_dtype": torch.quint8, + "weight_dtype": torch.qint8, + "bias_dtype": torch.float, + } + + dtype_config_dict2 = { + "input_dtype": torch.float16, + "output_dtype": torch.float, + "is_dynamic": True, + } + + def test_dtype_config_from_dict(self): + self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1) + self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2) + + def test_dtype_config_to_dict(self): + self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1) + self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2) + + # ====================== + # BackendPatternConfig + # ====================== + + _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU) + + _num_tensor_args_to_observation_type = { + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + _input_type_to_index = { + "bias": 0, + "input": 1, + "weight": 2, + } + _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer) + + def _extra_inputs_getter(self, p): + return (torch.rand(3, 3),) + + def _get_backend_op_config1(self): + return BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(self.dtype_config1) \ + .add_dtype_config(self.dtype_config2) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(nnqat.Linear) \ + .set_reference_quantized_module(nnqr.Linear) \ + .set_fused_module(nni.LinearReLU) \ + .set_fuser_method(self._fuser_method) + + def _get_backend_op_config2(self): + return BackendPatternConfig(torch.add) \ + .add_dtype_config(self.dtype_config2) \ + ._set_root_node_getter(_default_root_node_getter) \ + ._set_extra_inputs_getter(self._extra_inputs_getter) \ + ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \ + ._set_input_type_to_index(self._input_type_to_index) \ + ._set_input_output_observed(False) \ + ._set_overwrite_output_fake_quantize(self._fake_quantize) \ + ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) + + def _get_backend_pattern_config_dict1(self): + return { + "pattern": (torch.nn.ReLU, torch.nn.Linear), + "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2], + "root_module": torch.nn.Linear, + "qat_module": nnqat.Linear, + "reference_quantized_module_for_root": nnqr.Linear, + "fused_module": nni.LinearReLU, + "fuser_method": self._fuser_method, + } + + def _get_backend_pattern_config_dict2(self): + return { + "pattern": torch.add, + "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + "dtype_configs": [self.dtype_config_dict2], + "root_node_getter": _default_root_node_getter, + "extra_inputs_getter": self._extra_inputs_getter, + "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type, + "input_type_to_index": self._input_type_to_index, + "input_output_observed": False, + "overwrite_output_fake_quantize": self._fake_quantize, + "overwrite_output_observer": default_fixed_qparams_range_0to1_observer + } + + def test_backend_op_config_set_observation_type(self): + conf = BackendPatternConfig(torch.nn.Linear) + self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + + def test_backend_op_config_add_dtype_config(self): + conf = BackendPatternConfig(torch.nn.Linear) + self.assertEqual(len(conf.dtype_configs), 0) + conf.add_dtype_config(self.dtype_config1) + conf.add_dtype_config(self.dtype_config2) + self.assertEqual(len(conf.dtype_configs), 2) + self.assertEqual(conf.dtype_configs[0], self.dtype_config1) + self.assertEqual(conf.dtype_configs[1], self.dtype_config2) + + def test_backend_op_config_set_root_module(self): + conf = BackendPatternConfig(nni.LinearReLU) + self.assertTrue(conf.root_module is None) + conf.set_root_module(torch.nn.Linear) + self.assertEqual(conf.root_module, torch.nn.Linear) + + def test_backend_op_config_set_qat_module(self): + conf = BackendPatternConfig(torch.nn.Linear) + self.assertTrue(conf.qat_module is None) + conf.set_qat_module(nnqat.Linear) + self.assertEqual(conf.qat_module, nnqat.Linear) + + def test_backend_op_config_set_reference_quantized_module(self): + conf = BackendPatternConfig(torch.nn.Linear) + self.assertTrue(conf.reference_quantized_module is None) + conf.set_reference_quantized_module(nnqr.Linear) + self.assertEqual(conf.reference_quantized_module, nnqr.Linear) + + def test_backend_op_config_set_fused_module(self): + conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) + self.assertTrue(conf.fused_module is None) + conf.set_fused_module(nni.LinearReLU) + self.assertEqual(conf.fused_module, nni.LinearReLU) + + def test_backend_op_config_set_fuser_method(self): + conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) + self.assertTrue(conf.fuser_method is None) + conf.set_fuser_method(self._fuser_method) + self.assertEqual(conf.fuser_method, self._fuser_method) + + def test_backend_op_config_set_root_node_getter(self): + conf = BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) + self.assertTrue(conf._root_node_getter is None) + conf._set_root_node_getter(_default_root_node_getter) + self.assertEqual(conf._root_node_getter, _default_root_node_getter) + + def test_backend_op_config_set_extra_inputs_getter(self): + conf = BackendPatternConfig(torch.nn.Linear) + self.assertTrue(conf._extra_inputs_getter is None) + conf._set_extra_inputs_getter(self._extra_inputs_getter) + self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter) + + def test_backend_op_config_set_num_tensor_args_to_observation_type(self): + conf = BackendPatternConfig(torch.add) + self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0) + conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) + self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) + + def test_backend_op_config_set_input_type_to_index(self): + conf = BackendPatternConfig(torch.addmm) + self.assertEqual(len(conf._input_type_to_index), 0) + conf._set_input_type_to_index(self._input_type_to_index) + self.assertEqual(conf._input_type_to_index, self._input_type_to_index) + + def test_backend_op_config_set_input_output_observed(self): + conf = BackendPatternConfig(torch.nn.Embedding) + self.assertTrue(conf._input_output_observed is None) + conf._set_input_output_observed(False) + self.assertEqual(conf._input_output_observed, False) + + def test_backend_op_config_set_overwrite_output_fake_quantize(self): + conf = BackendPatternConfig(torch.sigmoid) + self.assertTrue(conf._overwrite_output_fake_quantize is None) + conf._set_overwrite_output_fake_quantize(self._fake_quantize) + self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize) + + def test_backend_op_config_set_overwrite_output_observer(self): + conf = BackendPatternConfig(torch.sigmoid) + self.assertTrue(conf._overwrite_output_observer is None) + conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) + self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) + + def test_backend_op_config_from_dict(self): + conf_dict1 = self._get_backend_pattern_config_dict1() + conf1 = BackendPatternConfig.from_dict(conf_dict1) + self.assertEqual(conf1.pattern, (torch.nn.ReLU, torch.nn.Linear)) + self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + self.assertEqual(conf1.root_module, torch.nn.Linear) + self.assertEqual(conf1.qat_module, nnqat.Linear) + self.assertEqual(conf1.reference_quantized_module, nnqr.Linear) + self.assertEqual(conf1.fused_module, nni.LinearReLU) + self.assertEqual(conf1.fuser_method, self._fuser_method) + self.assertTrue(conf1._root_node_getter is None) + self.assertTrue(conf1._extra_inputs_getter is None) + self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0) + self.assertEqual(len(conf1._input_type_to_index), 0) + self.assertTrue(conf1._input_output_observed is None) + self.assertTrue(conf1._overwrite_output_fake_quantize is None) + self.assertTrue(conf1._overwrite_output_observer is None) + # Test temporary/internal keys + conf_dict2 = self._get_backend_pattern_config_dict2() + conf2 = BackendPatternConfig.from_dict(conf_dict2) + self.assertEqual(conf2.pattern, torch.add) + self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + self.assertTrue(conf2.root_module is None) + self.assertTrue(conf2.qat_module is None) + self.assertTrue(conf2.reference_quantized_module is None) + self.assertTrue(conf2.fused_module is None) + self.assertTrue(conf2.fuser_method is None) + self.assertEqual(conf2._root_node_getter, _default_root_node_getter) + self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter) + self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) + self.assertEqual(conf2._input_type_to_index, self._input_type_to_index) + self.assertEqual(conf2._input_output_observed, False) + self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize) + self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) + + def test_backend_op_config_to_dict(self): + conf1 = self._get_backend_op_config1() + conf2 = self._get_backend_op_config2() + conf_dict1 = self._get_backend_pattern_config_dict1() + conf_dict2 = self._get_backend_pattern_config_dict2() + self.assertEqual(conf1.to_dict(), conf_dict1) + self.assertEqual(conf2.to_dict(), conf_dict2) + + # =============== + # BackendConfig + # =============== + + def test_backend_config_set_name(self): + conf = BackendConfig("name1") + self.assertEqual(conf.name, "name1") + conf.set_name("name2") + self.assertEqual(conf.name, "name2") + + def test_backend_config_set_backend_pattern_config(self): + conf = BackendConfig("name1") + self.assertEqual(len(conf.configs), 0) + backend_op_config1 = self._get_backend_op_config1() + backend_op_config2 = self._get_backend_op_config2() + conf.set_backend_pattern_config(backend_op_config1) + self.assertEqual(conf.configs, { + (torch.nn.ReLU, torch.nn.Linear): backend_op_config1, + }) + conf.set_backend_pattern_config(backend_op_config2) + self.assertEqual(conf.configs, { + (torch.nn.ReLU, torch.nn.Linear): backend_op_config1, + torch.add: backend_op_config2 + }) + + def test_backend_config_from_dict(self): + op1 = self._get_backend_op_config1() + op2 = self._get_backend_op_config2() + op_dict1 = self._get_backend_pattern_config_dict1() + op_dict2 = self._get_backend_pattern_config_dict2() + conf_dict = { + "name": "name1", + "configs": [op_dict1, op_dict2], + } + conf = BackendConfig.from_dict(conf_dict) + self.assertEqual(conf.name, "name1") + self.assertEqual(len(conf.configs), 2) + key1 = (torch.nn.ReLU, torch.nn.Linear) + key2 = torch.add + self.assertTrue(key1 in conf.configs) + self.assertTrue(key2 in conf.configs) + self.assertEqual(conf.configs[key1].to_dict(), op_dict1) + self.assertEqual(conf.configs[key2].to_dict(), op_dict2) + + def test_backend_config_to_dict(self): + op1 = self._get_backend_op_config1() + op2 = self._get_backend_op_config2() + op_dict1 = self._get_backend_pattern_config_dict1() + op_dict2 = self._get_backend_pattern_config_dict2() + conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2) + conf_dict = { + "name": "name1", + "configs": [op_dict1, op_dict2], + } + self.assertEqual(conf.to_dict(), conf_dict) + +if __name__ == '__main__': + raise RuntimeError("This _test file is not meant to be run directly, use:\n\n" + "\tpython _test/_test_quantization.py TESTNAME\n\n" + "instead.") diff --git a/test/quantization/core/test_docs.py b/test/quantization/core/test_docs.py new file mode 100644 index 0000000000000..f719a340a1614 --- /dev/null +++ b/test/quantization/core/test_docs.py @@ -0,0 +1,151 @@ +# Owner(s): ["oncall: quantization"] + +import re +from pathlib import Path + +import torch + +# import torch.nn.quantized as nnq +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + SingleLayerLinearModel, +) + + +class TestQuantizationDocs(QuantizationTestCase): + r""" + The tests in this section import code from the quantization docs and check that + they actually run without errors. In cases where objects are undefined in the code snippet, + they must be provided in the test. The imports seem to behave a bit inconsistently, + they can be imported either in the test file or passed as a global input + """ + + def _get_code( + self, path_from_pytorch, unique_identifier, offset=2, short_snippet=False + ): + r""" + This function reads in the code from the docs given a unique identifier. + Most code snippets have a 2 space indentation, for other indentation levels, + change the offset `arg`. the `short_snippet` arg can be set to allow for testing + of smaller snippets, the check that this arg controls is used to make sure that + we are not accidentally only importing a blank line or something. + """ + + def get_correct_path(path_from_pytorch): + r""" + Current working directory when CI is running test seems to vary, this function + looks for docs and if it finds it looks for the path to the + file and if the file exists returns that path, otherwise keeps looking. Will + only work if cwd contains pytorch or docs or a parent contains docs. + """ + # get cwd + cur_dir_path = Path(".").resolve() + + # check if cwd contains pytorch, use that if it does + if (cur_dir_path / "pytorch").is_dir(): + cur_dir_path = (cur_dir_path / "pytorch").resolve() + + # need to find the file, so we check current directory + # and all parent directories to see if the path leads to it + check_dir = cur_dir_path + while not check_dir == check_dir.parent: + file_path = (check_dir / path_from_pytorch).resolve() + if file_path.is_file(): + return file_path + check_dir = check_dir.parent.resolve() + + # no longer passing when file not found + raise FileNotFoundError("could not find {}".format(path_from_pytorch)) + + path_to_file = get_correct_path(path_from_pytorch) + if path_to_file: + file = open(path_to_file) + content = file.readlines() + + # it will register as having a newline at the end in python + if "\n" not in unique_identifier: + unique_identifier += "\n" + + assert unique_identifier in content, "could not find {} in {}".format( + unique_identifier, path_to_file + ) + + # get index of first line of code + line_num_start = content.index(unique_identifier) + 1 + + # next find where the code chunk ends. + # this regex will match lines that don't start + # with a \n or " " with number of spaces=offset + r = r = re.compile("^[^\n," + " " * offset + "]") + # this will return the line of first line that matches regex + line_after_code = next(filter(r.match, content[line_num_start:])) + last_line_num = content.index(line_after_code) + + # remove the first `offset` chars of each line and gather it all together + code = "".join( + [x[offset:] for x in content[line_num_start + 1 : last_line_num]] + ) + + # want to make sure we are actually getting some code, + assert last_line_num - line_num_start > 3 or short_snippet, ( + "The code in {} identified by {} seems suspiciously short:" + "\n\n###code-start####\n{}###code-end####".format( + path_to_file, unique_identifier, code + ) + ) + return code + + return None + + def _test_code(self, code, global_inputs=None): + r""" + This function runs `code` using any vars in `global_inputs` + """ + # if couldn't find the + if code is not None: + expr = compile(code, "test", "exec") + exec(expr, global_inputs) + + def test_quantization_doc_ptdq(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "PTDQ API Example::" + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code) + + def test_quantization_doc_ptsq(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "PTSQ API Example::" + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code) + + def test_quantization_doc_qat(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "QAT API Example::" + + def _dummy_func(*args, **kwargs): + return None + + input_fp32 = torch.randn(1, 1, 1, 1) + global_inputs = {"training_loop": _dummy_func, "input_fp32": input_fp32} + + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) + + def test_quantization_doc_fx(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "FXPTQ API Example::" + + input_fp32 = SingleLayerLinearModel().get_example_inputs() + global_inputs = {"UserModel": SingleLayerLinearModel, "input_fp32": input_fp32} + + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) + + def test_quantization_doc_custom(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "Custom API Example::" + + global_inputs = {"nnq": torch.nn.quantized} + + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 7cbab3be475e1..067b7b481426f 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -1011,6 +1011,30 @@ def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True, dtype=qdtype) + def test_prelu(self): + x = torch.randn((4, 4, 4, 4), dtype=torch.float) + qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8) + + # num_parameters = 1 + prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1) + w = torch.randn(1, dtype=torch.float) + qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8) + prelu_module.set_weight(qw) + qy = prelu_module(qx) + qy_ref = torch.prelu(qx, qw) + + self.assertEqual(qy_ref, qy, + msg="PReLU module API failed") + + # num_parameters = num_channels + prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4) + w = torch.randn(4, dtype=torch.float) + qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8) + prelu_module.set_weight(qw) + qy = prelu_module(qx) + qy_ref = torch.prelu(qx, qw) + self.assertEqual(qy_ref, qy, + msg="PReLU module API failed") class TestDynamicQuantizedModule(QuantizationTestCase): def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index e58f8e2fc3052..41b735afc75c2 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -21,7 +21,7 @@ import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, skipIfSlowGradcheckEnv from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ @@ -130,6 +130,7 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type): X_scale = 1e-10 return X, X_scale, X_zero_point +@skipIfSlowGradcheckEnv class TestQuantizedOps(TestCase): """Helper function to test quantized activation functions.""" @@ -469,6 +470,40 @@ def test_qgelu(self): self.assertEqual(qY.dequantize(), qY_hat.dequantize(), msg="F.gelu failed ({} vs {})".format(qY, qY_hat)) + """Tests the correctness of the quantized::prelu op.""" + def test_qprelu(self): + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + num_params = (0, 1) # 0: num_parameter = num_channels + dtypes = (torch.quint8, torch.qint8) + memory_formats = (torch.channels_last, torch.contiguous_format) + test_cases = itertools.product(shapes, num_params, dtypes, memory_formats) + for shape, num_param, dtype, memory_format in test_cases: + if memory_format == torch.channels_last and len(shape) != 4: + continue + X, scale, zero_point, torch_type = \ + torch.randn(*shape), 0.1, 0, dtype + X = X.to(memory_format=memory_format) + num_parameter = 1 if num_param == 1 or len(shape) == 1 else shape[1] + W = torch.randn(num_parameter) + W, w_scale, w_zero_point = \ + torch.randn(num_parameter), 0.2, 0 + + qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, + dtype=torch_type) + dqX = qX.dequantize() + qW = torch.quantize_per_tensor(W, scale=w_scale, zero_point=w_zero_point, + dtype=torch_type) + dqW = qW.dequantize() + + op = torch.nn.functional.prelu + qop = torch.ops.quantized.prelu + dqY = op(dqX, dqW) + qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point, + dtype=torch_type) + qY_hat = qop(qX, qW, scale, zero_point) + self.assertEqual(qY.dequantize(), qY_hat.dequantize(), + msg="F.prelu failed ({} vs {})".format(qY, qY_hat)) + """Tests the correctness of the quantized::qlayer_norm op.""" @skipIfNoFBGEMM def test_qlayer_norm(self): @@ -2144,21 +2179,24 @@ def test_cat_nhwc(self, X, relu): torch.testing.assert_close(out.dequantize(), ref.dequantize()) self.assertNotEqual(out.stride(), sorted(out.stride())) - @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=1, max_dims=5, - min_side=1, max_side=4), - qparams=hu.qparams()), - dim=st.integers(-1, 5)) @override_qengines - def test_mean(self, X, dim): - X, (scale, zero_point, torch_type) = X - assume(dim < X.ndim) - qX = torch.quantize_per_tensor(torch.tensor(X).float(), scale, zero_point, torch_type) - - Y = torch.mean(qX.dequantize(), dim) - Y = torch.quantize_per_tensor(Y, scale, zero_point, torch_type).dequantize() - qY = torch.mean(qX, dim) - - self.assertEqual(Y, qY.dequantize()) + def test_mean(self): + scale_list = (1, 0.25) + zero_point_list = (0, 2) + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4)) + dtypes = (torch.quint8, torch.qint8) + dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4)) + test_cases = itertools.product(scale_list, zero_point_list, shapes, dtypes, dims) + op = torch.mean + for scale, zp, shape, dtype, dim in test_cases: + if not all([d < len(shape) for d in dim]): + continue + X = torch.randn(*shape) * 10 + qX = torch.quantize_per_tensor(X, scale, zp, dtype) + Y = op(qX.dequantize(), dim) + Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize() + qY = op(qX, dim) + self.assertEqual(Y, qY.dequantize()) @skipIfNoQNNPACK @given(keep=st.booleans()) @@ -2177,6 +2215,28 @@ def test_quantized_mean_qnnpack(self, keep): MQ = XQ.mean((2, 3), keepdim=keep) self.assertTrue(torch.equal(MQ, YQ)) + @override_qengines + def test_std(self): + scale_list = (1, 0.25) + zero_point_list = (0, 2) + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4)) + dtypes = (torch.quint8, torch.qint8) + dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4)) + unbiased_list = (True, False) + keep_dim_list = (True, False) + test_cases = itertools.product(scale_list, zero_point_list, shapes, + dtypes, dims, unbiased_list, keep_dim_list) + op = torch.std + for scale, zp, shape, dtype, dim, unbiased, keep_dim in test_cases: + if not all([d < len(shape) for d in dim]): + continue + X = torch.randn(*shape) * 10 + qX = torch.quantize_per_tensor(X, scale, zp, dtype) + Y = op(qX.dequantize(), dim, unbiased, keep_dim) + Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize() + qY = op(qX, dim, unbiased, keep_dim) + self.assertEqual(Y, qY.dequantize()) + """Tests the correctness of the quantized equal op.""" @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), qparams=hu.qparams()), @@ -2252,9 +2312,9 @@ def equal_ref(qX, qX2): def test_group_norm(self): # hypothesis is flaky for this test, create test cases manually batches_list = (1, 7) - num_groups_list = (1, 2) - channels_per_groups = (1, 2) - elements_per_channels = (8, 17) + num_groups_list = (1, 4) + channels_per_groups = (1, 36, 72) + elements_per_channels = (8, 128, 1024) torch_types = (torch.qint8, torch.quint8) y_scales = (0.1, 4.23) y_zero_points = (0, 1) @@ -2319,7 +2379,7 @@ def test_group_norm(self): ch_end = ch_start + channels_per_group group_vals = dqX[batch_idx][ch_start:ch_end] assume( - float(torch.unique(group_vals).shape[0]) / group_vals.numel() > 0.01 + float(torch.unique(group_vals).shape[0]) / group_vals.numel() > 0.001 or group_vals.numel() < 5) qY = torch.ops.quantized.group_norm(qX, num_groups, weight, bias, eps, Y_scale, Y_zero_point) @@ -2706,15 +2766,6 @@ def test_custom_module_lstm(self): dtype = np.uint8 qtype = torch.quint8 - custom_module_config = { - 'float_to_observed_custom_module_class': { - torch.nn.LSTM: torch.nn.quantizable.LSTM - }, - 'observed_to_quantized_custom_module_class': { - torch.nn.quantizable.LSTM: torch.nn.quantizable.LSTM - } - } - x = np.random.randn(seq_len, batch_size, input_size) scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype) x = torch.from_numpy(x).to(torch.float) @@ -2748,18 +2799,18 @@ def test_custom_module_lstm(self): # Prepare lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine) - lstm_prepared = torch.ao.quantization.prepare( - lstm, prepare_custom_config_dict=custom_module_config) + lstm_prepared = torch.ao.quantization.prepare(lstm) self.assertTrue(hasattr(lstm_prepared[0], 'layers')) self.assertEqual(num_layers, len(lstm_prepared[0].layers)) + assert type(lstm_prepared[0]) == torch.nn.quantizable.LSTM # Calibrate y = lstm_prepared(x) self.assertEqual(y_ref, y) # Quantize - lstm_quantized = torch.ao.quantization.convert( - lstm_prepared, convert_custom_config_dict=custom_module_config) + lstm_quantized = torch.ao.quantization.convert(lstm_prepared) + assert type(lstm_quantized[0]) == torch.nn.quantized.LSTM qy = lstm_quantized(qx) snr = _snr(y, qy) @@ -2818,15 +2869,6 @@ def forward( dtype = np.uint8 qtype = torch.quint8 - custom_module_config = { - 'float_to_observed_custom_module_class': { - torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention - }, - 'observed_to_quantized_custom_module_class': { - torch.nn.quantizable.MultiheadAttention: torch.nn.quantizable.MultiheadAttention - } - } - for kdim, vdim in ((kembed_dim, vembed_dim), (None, None)): fp_data = [ torch.randn(target_seq_length, batch_size, qembed_dim), # Q @@ -2866,7 +2908,7 @@ def forward( else: mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine) mha_prepared = torch.ao.quantization.prepare( - mha, prepare_custom_config_dict=custom_module_config) + mha) # Calibrate y = mha_prepared(*fp_data) @@ -2876,9 +2918,7 @@ def forward( self.assertEqual(y_ref[1], y[1]) # Weight # Quantize - mha_quantized = torch.ao.quantization.convert( - mha_prepared, - convert_custom_config_dict=custom_module_config) + mha_quantized = torch.ao.quantization.convert(mha_prepared) qy = mha_quantized(*q_data) # Reference result diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py index 5467943d17a1d..1a297b9ecf43c 100644 --- a/test/quantization/fx/test_equalize_fx.py +++ b/test/quantization/fx/test_equalize_fx.py @@ -279,7 +279,7 @@ def test_input_weight_equalization_prepare(self): m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) def test_input_weight_equalization_branching(self): @@ -313,7 +313,7 @@ def forward(self, x): example_inputs = (torch.rand(1, 5),) prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence) # Tests that we will add an equalization observer because there is only @@ -337,7 +337,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 5),) prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence) @skipIfNoFBGEMM @@ -369,7 +369,7 @@ def test_input_weight_equalization_convert(self): copy.deepcopy(m), specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict + _equalization_config=default_equalization_qconfig_dict ) output = prepared(x) @@ -379,7 +379,7 @@ def test_input_weight_equalization_convert(self): prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) prepared(x) convert_fx(prepared) # Check if compile self.assertEqual(output, convert_ref_output) @@ -431,7 +431,7 @@ def test_input_weight_equalization_equalization_scales(self): prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) prepared(*example_inputs) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -484,7 +484,7 @@ def test_input_weight_equalization_weights_bias(self): prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) prepared(x) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -544,7 +544,7 @@ def test_input_weight_equalization_activation_values(self): prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) prepared(x) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -783,7 +783,7 @@ def test_input_weight_equalization_graphs(self): prepared = prepare_fx( m, specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict) + _equalization_config=default_equalization_qconfig_dict) equalized_quantized_model = convert_fx(prepared) # Check the order of nodes in the graph @@ -808,7 +808,7 @@ def test_input_weight_equalization_results(self): copy.deepcopy(m), specific_qconfig_dict, example_inputs=example_inputs, - equalization_config={}) + _equalization_config={}) prepared(x) quantized = convert_fx(prepared) # Check if compile quantized_output = quantized(x) @@ -818,7 +818,7 @@ def test_input_weight_equalization_results(self): copy.deepcopy(m), specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=default_equalization_qconfig_dict + _equalization_config=default_equalization_qconfig_dict ) prepared(x) equalized_and_quantized = convert_fx(prepared) # Check if compile @@ -876,7 +876,7 @@ def forward(self, x): copy.deepcopy(float_model), specific_qconfig_dict, example_inputs=example_inputs, - equalization_config=selective_equalization_qconfig_dict, + _equalization_config=selective_equalization_qconfig_dict, ) prepared_model(x) equalized_model = convert_fx(prepared_model) diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 2df4622e693c7..d123a8752ca72 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -2,11 +2,19 @@ # Owner(s): ["oncall: quantization"] import torch +import torch.nn as nn import torch.ao.quantization.quantize_fx as quantize_fx import torch.nn.functional as F from torch.ao.quantization import QConfig, QConfigMapping -from torch.ao.quantization.fx._model_report.detector import DynamicStaticDetector, PerChannelDetector +from torch.ao.quantization.fx._model_report.detector import ( + DynamicStaticDetector, + InputWeightEqualizationDetector, + PerChannelDetector, + OutlierDetector, +) from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver +from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer +from torch.ao.quantization.fx._model_report.model_report import ModelReport from torch.ao.quantization.observer import HistogramObserver, default_per_channel_weight_observer from torch.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU from torch.testing._internal.common_quantization import ( @@ -16,6 +24,7 @@ TwoLayerLinearModel, skipIfNoFBGEMM, skipIfNoQNNPACK, + override_quantized_engine, ) @@ -68,6 +77,39 @@ torch.nn.Conv2d(3, 3, 2, 1), ) +# Test class +# example model to use for tests +class ThreeOps(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + self.bn = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.bn(x) + x = self.relu(x) + return x + + def get_example_inputs(self): + return (torch.randn(1, 3, 3, 3),) + +class TwoThreeOps(nn.Module): + def __init__(self): + super().__init__() + self.block1 = ThreeOps() + self.block2 = ThreeOps() + + def forward(self, x): + x = self.block1(x) + y = self.block2(x) + z = x + y + z = F.relu(z) + return z + + def get_example_inputs(self): + return (torch.randn(1, 3, 3, 3),) class TestFxModelReportDetector(QuantizationTestCase): @@ -87,32 +129,37 @@ def _prepare_model_and_run_input(self, model, q_config_mapping, input): Output has no changes / suggestions """ + @skipIfNoFBGEMM def test_simple_conv(self): - torch.backends.quantized.engine = "onednn" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + with override_quantized_engine('fbgemm'): + torch.backends.quantized.engine = "fbgemm" - input = torch.randn(1, 3, 10, 10) - prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + input = torch.randn(1, 3, 10, 10) + prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input) - # no optims possible and there should be nothing in per_channel_status - self.assertEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) - self.assertEqual(per_channel_info["backend"], torch.backends.quantized.engine) - self.assertEqual(len(per_channel_info["per_channel_status"]), 1) - self.assertEqual(list(per_channel_info["per_channel_status"])[0], "conv") - self.assertEqual( - per_channel_info["per_channel_status"]["conv"]["per_channel_supported"], - True, - ) - self.assertEqual(per_channel_info["per_channel_status"]["conv"]["per_channel_used"], True) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + + # no optims possible and there should be nothing in per_channel_status + self.assertEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) + + # there shoud only be one conv there in this model + self.assertEqual(per_channel_info["conv"]["backend"], torch.backends.quantized.engine) + self.assertEqual(len(per_channel_info), 1) + self.assertEqual(list(per_channel_info)[0], "conv") + self.assertEqual( + per_channel_info["conv"]["per_channel_quantization_supported"], + True, + ) + self.assertEqual(per_channel_info["conv"]["per_channel_quantization_used"], True) """Case includes: Multiple conv or linear @@ -125,35 +172,39 @@ def test_simple_conv(self): @skipIfNoQNNPACK def test_multi_linear_model_without_per_channel(self): - torch.backends.quantized.engine = "qnnpack" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + with override_quantized_engine('qnnpack'): + torch.backends.quantized.engine = "qnnpack" - prepared_model = self._prepare_model_and_run_input( - TwoLayerLinearModel(), - q_config_mapping, - TwoLayerLinearModel().get_example_inputs()[0], - ) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + prepared_model = self._prepare_model_and_run_input( + TwoLayerLinearModel(), + q_config_mapping, + TwoLayerLinearModel().get_example_inputs()[0], + ) - # there should be optims possible - self.assertNotEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) - self.assertEqual(per_channel_info["backend"], torch.backends.quantized.engine) - self.assertEqual(len(per_channel_info["per_channel_status"]), 2) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) - # for each linear layer, should be supported but not used - for linear_key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][linear_key] + # there should be optims possible + self.assertNotEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) + # pick a random key to look at + rand_key: str = list(per_channel_info.keys())[0] + self.assertEqual(per_channel_info[rand_key]["backend"], torch.backends.quantized.engine) + self.assertEqual(len(per_channel_info), 2) + + # for each linear layer, should be supported but not used + for linear_key in per_channel_info.keys(): + module_entry = per_channel_info[linear_key] - self.assertEqual(module_entry["per_channel_supported"], True) - self.assertEqual(module_entry["per_channel_used"], False) + self.assertEqual(module_entry["per_channel_quantization_supported"], True) + self.assertEqual(module_entry["per_channel_quantization_used"], False) """Case includes: Multiple conv or linear @@ -166,70 +217,72 @@ def test_multi_linear_model_without_per_channel(self): @skipIfNoQNNPACK def test_multiple_q_config_options(self): - torch.backends.quantized.engine = "qnnpack" - - # qconfig with support for per_channel quantization - per_channel_qconfig = QConfig( - activation=HistogramObserver.with_args(reduce_range=True), - weight=default_per_channel_weight_observer, - ) - # we need to design the model - class ConvLinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(3, 3, 2, 1) - self.fc1 = torch.nn.Linear(9, 27) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(27, 27) - self.conv2 = torch.nn.Conv2d(3, 3, 2, 1) - - def forward(self, x): - x = self.conv1(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.conv2(x) - return x + with override_quantized_engine('qnnpack'): + torch.backends.quantized.engine = "qnnpack" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global( - torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine) - ).set_object_type(torch.nn.Conv2d, per_channel_qconfig) + # qconfig with support for per_channel quantization + per_channel_qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) - prepared_model = self._prepare_model_and_run_input( - ConvLinearModel(), - q_config_mapping, - torch.randn(1, 3, 10, 10), - ) + # we need to design the model + class ConvLinearModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 2, 1) + self.fc1 = torch.nn.Linear(9, 27) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(27, 27) + self.conv2 = torch.nn.Conv2d(3, 3, 2, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.conv2(x) + return x + + q_config_mapping = QConfigMapping() + q_config_mapping.set_global( + torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine) + ).set_object_type(torch.nn.Conv2d, per_channel_qconfig) + + prepared_model = self._prepare_model_and_run_input( + ConvLinearModel(), + q_config_mapping, + torch.randn(1, 3, 10, 10), + ) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) - # the only suggestions should be to linear layers + # the only suggestions should be to linear layers - # there should be optims possible - self.assertNotEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) + # there should be optims possible + self.assertNotEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) - # to ensure it got into the nested layer - self.assertEqual(len(per_channel_info["per_channel_status"]), 4) + # to ensure it got into the nested layer + self.assertEqual(len(per_channel_info), 4) - # for each layer, should be supported but not used - for key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][key] - self.assertEqual(module_entry["per_channel_supported"], True) + # for each layer, should be supported but not used + for key in per_channel_info.keys(): + module_entry = per_channel_info[key] + self.assertEqual(module_entry["per_channel_quantization_supported"], True) - # if linear False, if conv2d true cuz it uses different config - if "fc" in key: - self.assertEqual(module_entry["per_channel_used"], False) - elif "conv" in key: - self.assertEqual(module_entry["per_channel_used"], True) - else: - raise ValueError("Should only contain conv and linear layers as key values") + # if linear False, if conv2d true cuz it uses different config + if "fc" in key: + self.assertEqual(module_entry["per_channel_quantization_used"], False) + elif "conv" in key: + self.assertEqual(module_entry["per_channel_quantization_used"], True) + else: + raise ValueError("Should only contain conv and linear layers as key values") """Case includes: Multiple conv or linear @@ -242,36 +295,38 @@ def forward(self, x): @skipIfNoQNNPACK def test_sequential_model_format(self): - torch.backends.quantized.engine = "qnnpack" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + with override_quantized_engine('qnnpack'): + torch.backends.quantized.engine = "qnnpack" - prepared_model = self._prepare_model_and_run_input( - NESTED_CONV_LINEAR_EXAMPLE, - q_config_mapping, - torch.randn(1, 3, 10, 10), - ) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + + prepared_model = self._prepare_model_and_run_input( + NESTED_CONV_LINEAR_EXAMPLE, + q_config_mapping, + torch.randn(1, 3, 10, 10), + ) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) - # there should be optims possible - self.assertNotEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) + # there should be optims possible + self.assertNotEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) - # to ensure it got into the nested layer - self.assertEqual(len(per_channel_info["per_channel_status"]), 4) + # to ensure it got into the nested layer + self.assertEqual(len(per_channel_info), 4) - # for each layer, should be supported but not used - for key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][key] + # for each layer, should be supported but not used + for key in per_channel_info.keys(): + module_entry = per_channel_info[key] - self.assertEqual(module_entry["per_channel_supported"], True) - self.assertEqual(module_entry["per_channel_used"], False) + self.assertEqual(module_entry["per_channel_quantization_supported"], True) + self.assertEqual(module_entry["per_channel_quantization_used"], False) """Case includes: Multiple conv or linear @@ -284,36 +339,38 @@ def test_sequential_model_format(self): @skipIfNoQNNPACK def test_conv_sub_class_considered(self): - torch.backends.quantized.engine = "qnnpack" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + with override_quantized_engine('qnnpack'): + torch.backends.quantized.engine = "qnnpack" - prepared_model = self._prepare_model_and_run_input( - LAZY_CONV_LINEAR_EXAMPLE, - q_config_mapping, - torch.randn(1, 3, 10, 10), - ) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + prepared_model = self._prepare_model_and_run_input( + LAZY_CONV_LINEAR_EXAMPLE, + q_config_mapping, + torch.randn(1, 3, 10, 10), + ) - # there should be optims possible - self.assertNotEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + + # there should be optims possible + self.assertNotEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) - # to ensure it got into the nested layer and it considered the lazyConv2d - self.assertEqual(len(per_channel_info["per_channel_status"]), 4) + # to ensure it got into the nested layer and it considered the lazyConv2d + self.assertEqual(len(per_channel_info), 4) - # for each layer, should be supported but not used - for key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][key] + # for each layer, should be supported but not used + for key in per_channel_info.keys(): + module_entry = per_channel_info[key] - self.assertEqual(module_entry["per_channel_supported"], True) - self.assertEqual(module_entry["per_channel_used"], False) + self.assertEqual(module_entry["per_channel_quantization_supported"], True) + self.assertEqual(module_entry["per_channel_quantization_used"], False) """Case includes: Multiple conv or linear @@ -326,35 +383,37 @@ def test_conv_sub_class_considered(self): @skipIfNoFBGEMM def test_fusion_layer_in_sequential(self): - torch.backends.quantized.engine = "fbgemm" - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + with override_quantized_engine('fbgemm'): + torch.backends.quantized.engine = "fbgemm" - prepared_model = self._prepare_model_and_run_input( - FUSION_CONV_LINEAR_EXAMPLE, - q_config_mapping, - torch.randn(1, 3, 10, 10), - ) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + + prepared_model = self._prepare_model_and_run_input( + FUSION_CONV_LINEAR_EXAMPLE, + q_config_mapping, + torch.randn(1, 3, 10, 10), + ) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model) - # no optims possible and there should be nothing in per_channel_status - self.assertEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) + # no optims possible and there should be nothing in per_channel_status + self.assertEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) - # to ensure it got into the nested layer and it considered all the nested fusion components - self.assertEqual(len(per_channel_info["per_channel_status"]), 4) + # to ensure it got into the nested layer and it considered all the nested fusion components + self.assertEqual(len(per_channel_info), 4) - # for each layer, should be supported but not used - for key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][key] - self.assertEqual(module_entry["per_channel_supported"], True) - self.assertEqual(module_entry["per_channel_used"], True) + # for each layer, should be supported but not used + for key in per_channel_info.keys(): + module_entry = per_channel_info[key] + self.assertEqual(module_entry["per_channel_quantization_supported"], True) + self.assertEqual(module_entry["per_channel_quantization_used"], True) """Case includes: Multiple conv or linear @@ -388,39 +447,40 @@ def forward(self, x): x = self.dequant(x) return x - # create a model instance - model_fp32 = QATConvLinearReluModel() + with override_quantized_engine('qnnpack'): + # create a model instance + model_fp32 = QATConvLinearReluModel() - model_fp32.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack") + model_fp32.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack") - # model must be in eval mode for fusion - model_fp32.eval() - model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]]) + # model must be in eval mode for fusion + model_fp32.eval() + model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]]) - # model must be set to train mode for QAT logic to work - model_fp32_fused.train() + # model must be set to train mode for QAT logic to work + model_fp32_fused.train() - # prepare the model for QAT, different than for post training quantization - model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused) + # prepare the model for QAT, different than for post training quantization + model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused) - # run the detector - per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) - optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared) + # run the detector + per_channel_detector = PerChannelDetector(torch.backends.quantized.engine) + optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared) - # there should be optims possible - self.assertNotEqual( - optims_str, - DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), - ) + # there should be optims possible + self.assertNotEqual( + optims_str, + DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine), + ) - # make sure it was able to find the single conv in the fused model - self.assertEqual(len(per_channel_info["per_channel_status"]), 1) + # make sure it was able to find the single conv in the fused model + self.assertEqual(len(per_channel_info), 1) - # for the one conv, it should still give advice to use different qconfig - for key in per_channel_info["per_channel_status"].keys(): - module_entry = per_channel_info["per_channel_status"][key] - self.assertEqual(module_entry["per_channel_supported"], True) - self.assertEqual(module_entry["per_channel_used"], False) + # for the one conv, it should still give advice to use different qconfig + for key in per_channel_info.keys(): + module_entry = per_channel_info[key] + self.assertEqual(module_entry["per_channel_quantization_supported"], True) + self.assertEqual(module_entry["per_channel_quantization_used"], False) """ @@ -743,65 +803,1044 @@ def forward(self, x): z = F.relu(z) return z - # create model, example input, and qconfig mapping - torch.backends.quantized.engine = "fbgemm" - model = TwoBlockNet() - example_input = torch.randint(-10, 0, (1, 3, 3, 3)) + + with override_quantized_engine('fbgemm'): + # create model, example input, and qconfig mapping + torch.backends.quantized.engine = "fbgemm" + model = TwoBlockNet() + example_input = torch.randint(-10, 0, (1, 3, 3, 3)) + example_input = example_input.to(torch.float) + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm")) + + # prep model and select observer + model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input) + obs_ctr = ModelReportObserver + + # find layer to attach to and store + linear_fqn = "block2.linear" # fqn of target linear + + target_linear = None + for node in model_prep.graph.nodes: + if node.target == linear_fqn: + target_linear = node + break + + # insert into both module and graph pre and post + + # set up to insert before target_linear (pre_observer) + with model_prep.graph.inserting_before(target_linear): + obs_to_insert = obs_ctr() + pre_obs_fqn = linear_fqn + ".model_report_pre_observer" + model_prep.add_submodule(pre_obs_fqn, obs_to_insert) + model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args) + + # set up and insert after the target_linear (post_observer) + with model_prep.graph.inserting_after(target_linear): + obs_to_insert = obs_ctr() + post_obs_fqn = linear_fqn + ".model_report_post_observer" + model_prep.add_submodule(post_obs_fqn, obs_to_insert) + model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,)) + + # need to recompile module after submodule added and pass input through + model_prep.recompile() + + num_iterations = 10 + for i in range(num_iterations): + if i % 2 == 0: + example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float) + else: + example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float) + model_prep(example_input) + + # run it through the dynamic vs static detector + dynamic_vs_static_detector = DynamicStaticDetector() + dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep) + + # one of the stats should be stationary, and the other non-stationary + # as a result, dynamic should be recommended + data_dist_info = [ + dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.PRE_OBS_DATA_DIST_KEY], + dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.POST_OBS_DATA_DIST_KEY], + ] + + self.assertTrue("stationary" in data_dist_info) + self.assertTrue("non-stationary" in data_dist_info) + self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"]) + +class TestFxModelReportClass(QuantizationTestCase): + + @skipIfNoFBGEMM + def test_constructor(self): + """ + Tests the constructor of the ModelReport class. + Specifically looks at: + - The desired reports + - Ensures that the observers of interest are properly initialized + """ + + with override_quantized_engine('fbgemm'): + # set the backend for this test + torch.backends.quantized.engine = "fbgemm" + backend = torch.backends.quantized.engine + + # create a model + model = ThreeOps() + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0]) + + # make an example set of detectors + test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)]) + # initialize with an empty detector + model_report = ModelReport(model_prep, test_detector_set) + + # make sure internal valid reports matches + detector_name_set = set([detector.get_detector_name() for detector in test_detector_set]) + self.assertEqual(model_report.get_desired_reports_names(), detector_name_set) + + # now attempt with no valid reports, should raise error + with self.assertRaises(ValueError): + model_report = ModelReport(model, set([])) + + # number of expected obs of interest entries + num_expected_entries = len(test_detector_set) + self.assertEqual(len(model_report.get_observers_of_interest()), num_expected_entries) + + for value in model_report.get_observers_of_interest().values(): + self.assertEqual(len(value), 0) + + @skipIfNoFBGEMM + def test_prepare_model_callibration(self): + """ + Tests model_report.prepare_detailed_calibration that prepares the model for callibration + Specifically looks at: + - Whether observers are properly inserted into regular nn.Module + - Whether the target and the arguments of the observers are proper + - Whether the internal representation of observers of interest is updated + """ + + with override_quantized_engine('fbgemm'): + # create model report object + + # create model + model = TwoThreeOps() + # make an example set of detectors + torch.backends.quantized.engine = "fbgemm" + backend = torch.backends.quantized.engine + test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)]) + # initialize with an empty detector + + # prepare the model + example_input = model.get_example_inputs()[0] + current_backend = torch.backends.quantized.engine + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + + model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input) + + model_report = ModelReport(model_prep, test_detector_set) + + # prepare the model for callibration + prepared_for_callibrate_model = model_report.prepare_detailed_calibration() + + # see whether observers properly in regular nn.Module + # there should be 4 observers present in this case + modules_observer_cnt = 0 + for fqn, module in prepared_for_callibrate_model.named_modules(): + if isinstance(module, ModelReportObserver): + modules_observer_cnt += 1 + + self.assertEqual(modules_observer_cnt, 4) + + model_report_str_check = "model_report" + # also make sure arguments for observers in the graph are proper + for node in prepared_for_callibrate_model.graph.nodes: + # not all node targets are strings, so check + if isinstance(node.target, str) and model_report_str_check in node.target: + # if pre-observer has same args as the linear (next node) + if "pre_observer" in node.target: + self.assertEqual(node.args, node.next.args) + # if post-observer, args are the target linear (previous node) + if "post_observer" in node.target: + self.assertEqual(node.args, (node.prev,)) + + # ensure model_report observers of interest updated + # there should be two entries + self.assertEqual(len(model_report.get_observers_of_interest()), 2) + for detector in test_detector_set: + self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys()) + + # get number of entries for this detector + detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()] + + # assert that the per channel detector has 0 and the dynamic static has 4 + if isinstance(detector, PerChannelDetector): + self.assertEqual(len(detector_obs_of_interest_fqns), 0) + elif isinstance(detector, DynamicStaticDetector): + self.assertEqual(len(detector_obs_of_interest_fqns), 4) + + # ensure that we can prepare for callibration only once + with self.assertRaises(ValueError): + prepared_for_callibrate_model = model_report.prepare_detailed_calibration() + + + def get_module_and_graph_cnts(self, callibrated_fx_module): + r""" + Calculates number of ModelReportObserver modules in the model as well as the graph structure. + Returns a tuple of two elements: + int: The number of ModelReportObservers found in the model + int: The number of model_report nodes found in the graph + """ + # get the number of observers stored as modules + modules_observer_cnt = 0 + for fqn, module in callibrated_fx_module.named_modules(): + if isinstance(module, ModelReportObserver): + modules_observer_cnt += 1 + + # get number of observers in the graph + model_report_str_check = "model_report" + graph_observer_cnt = 0 + # also make sure arguments for observers in the graph are proper + for node in callibrated_fx_module.graph.nodes: + # not all node targets are strings, so check + if isinstance(node.target, str) and model_report_str_check in node.target: + # increment if we found a graph observer + graph_observer_cnt += 1 + + return (modules_observer_cnt, graph_observer_cnt) + + @skipIfNoFBGEMM + def test_generate_report(self): + """ + Tests model_report.generate_model_report to ensure report generation + Specifically looks at: + - Whether correct number of reports are being generated + - Whether observers are being properly removed if specified + - Whether correct blocking from generating report twice if obs removed + """ + + with override_quantized_engine('fbgemm'): + # set the backend for this test + torch.backends.quantized.engine = "fbgemm" + + # check whether the correct number of reports are being generated + filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)]) + single_detector_set = set([DynamicStaticDetector()]) + + # create our models + model_full = TwoThreeOps() + model_single = TwoThreeOps() + + # prepare and callibrate two different instances of same model + # prepare the model + example_input = model_full.get_example_inputs()[0] + current_backend = torch.backends.quantized.engine + q_config_mapping = QConfigMapping() + q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)) + + model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input) + model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input) + + # initialize one with filled detector + model_report_full = ModelReport(model_prep_full, filled_detector_set) + # initialize another with a single detector set + model_report_single = ModelReport(model_prep_single, single_detector_set) + + # prepare the models for callibration + prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration() + prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration() + + # now callibrate the two models + num_iterations = 10 + for i in range(num_iterations): + example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float) + prepared_for_callibrate_model_full(example_input) + prepared_for_callibrate_model_single(example_input) + + # now generate the reports + model_full_report = model_report_full.generate_model_report(True) + model_single_report = model_report_single.generate_model_report(False) + + # check that sizes are appropriate + self.assertEqual(len(model_full_report), len(filled_detector_set)) + self.assertEqual(len(model_single_report), len(single_detector_set)) + + # make sure observers are being properly removed for full report since we put flag in + modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full) + self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules + self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph + + # make sure observers aren't being removed for single report since not specified + modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single) + self.assertNotEqual(modules_observer_cnt, 0) + self.assertNotEqual(graph_observer_cnt, 0) + + # make sure error when try to rerun report generation for full report but not single report + with self.assertRaises(Exception): + model_full_report = model_report_full.generate_model_report( + prepared_for_callibrate_model_full, False + ) + + # make sure we don't run into error for single report + model_single_report = model_report_single.generate_model_report(False) + + @skipIfNoFBGEMM + def test_generate_visualizer(self): + """ + Tests that the ModelReport class can properly create the ModelReportVisualizer instance + Checks that: + - Correct number of modules are represented + - Modules are sorted + - Correct number of features for each module + """ + with override_quantized_engine('fbgemm'): + # set the backend for this test + torch.backends.quantized.engine = "fbgemm" + # test with multiple detectors + detector_set = set() + detector_set.add(OutlierDetector(reference_percentile=0.95)) + detector_set.add(InputWeightEqualizationDetector(0.5)) + + model = TwoThreeOps() + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper( + model, detector_set, model.get_example_inputs()[0] + ) + + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] + example_input = example_input.to(torch.float) + + prepared_for_callibrate_model(example_input) + + # try to visualize without generating report, should throw error + with self.assertRaises(Exception): + mod_rep_visualizaiton = mod_report.generate_visualizer() + + # now get the report by running it through ModelReport instance + generated_report = mod_report.generate_model_report(remove_inserted_observers=False) + + # now we get the visualizer should not error + mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer() + + # since we tested with outlier detector, which looks at every base level module + # should be six entries in the ordered dict + mod_fqns_to_features = mod_rep_visualizer.generated_reports + + self.assertEqual(len(mod_fqns_to_features), 6) + + # outlier detector has 9 feature per module + # input-weight has 12 features per module + # there are 1 common data point, so should be 12 + 9 - 1 = 20 unique features per common modules + # all linears will be common + for module_fqn in mod_fqns_to_features: + if ".linear" in module_fqn: + linear_info = mod_fqns_to_features[module_fqn] + self.assertEqual(len(linear_info), 20) + +class TestFxDetectInputWeightEqualization(QuantizationTestCase): + + class SimpleConv(torch.nn.Module): + def __init__(self, con_dims): + super().__init__() + self.relu = torch.nn.ReLU() + self.conv = torch.nn.Conv2d(con_dims[0], con_dims[1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + class TwoBlockComplexNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32)) + self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3)) + self.conv = torch.nn.Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False) + self.linear = torch.nn.Linear(768, 10) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.block1(x) + x = self.conv(x) + y = self.block2(x) + y = y.repeat(1, 1, 2, 2) + z = x + y + z = z.flatten(start_dim=1) + z = self.linear(z) + z = self.relu(z) + return z + + def get_fusion_modules(self): + return [['conv', 'relu']] + + def get_example_inputs(self): + return (torch.randn((1, 3, 28, 28)),) + + class ReluOnly(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(x) + return x + + def get_example_inputs(self): + return (torch.arange(27).reshape((1, 3, 3, 3)),) + + def _get_prepped_for_calibration_model(self, model, detector_set, fused=False): + r"""Returns a model that has been prepared for callibration and corresponding model_report""" + + # pass in necessary inputs to helper + example_input = model.get_example_inputs()[0] + return _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused) + + @skipIfNoFBGEMM + def test_input_weight_equalization_determine_points(self): + # use fbgemm and create our model instance + # then create model report instance with detector + with override_quantized_engine('fbgemm'): + + detector_set = set([InputWeightEqualizationDetector(0.5)]) + + # get tst model and callibrate + non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set) + fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True) + + # reporter should still give same counts even for fused model + for prepared_for_callibrate_model, mod_report in [non_fused, fused]: + + # supported modules to check + mods_to_check = set([nn.Linear, nn.Conv2d]) + + # get the set of all nodes in the graph their fqns + node_fqns = set([node.target for node in prepared_for_callibrate_model.graph.nodes]) + + # there should be 4 node fqns that have the observer inserted + correct_number_of_obs_inserted = 4 + number_of_obs_found = 0 + obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME + + for node in prepared_for_callibrate_model.graph.nodes: + # if the obs name is inside the target, we found an observer + if obs_name_to_find in str(node.target): + number_of_obs_found += 1 + + self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted) + + # assert that each of the desired modules have the observers inserted + for fqn, module in prepared_for_callibrate_model.named_modules(): + # check if module is a supported module + is_in_include_list = sum(list(map(lambda x: isinstance(module, x), mods_to_check))) > 0 + + if is_in_include_list: + # make sure it has the observer attribute + self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME)) + else: + # if it's not a supported type, it shouldn't have observer attached + self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME)) + + @skipIfNoFBGEMM + def test_input_weight_equalization_report_gen(self): + # use fbgemm and create our model instance + # then create model report instance with detector + with override_quantized_engine('fbgemm'): + + test_input_weight_detector = InputWeightEqualizationDetector(0.4) + detector_set = set([test_input_weight_detector]) + model = self.TwoBlockComplexNet() + # prepare the model for callibration + prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model( + model, detector_set + ) + + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] + example_input = example_input.to(torch.float) + + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = model_report.generate_model_report(True) + + # check that sizes are appropriate only 1 detector + self.assertEqual(len(generated_report), 1) + + # get the specific report for input weight equalization + input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()] + + # we should have 5 layers looked at since 4 conv / linear layers + self.assertEqual(len(input_weight_dict), 4) + + # we can validate that the max and min values of the detector were recorded properly for the first one + # this is because no data has been processed yet, so it should be values from original input + + example_input = example_input.reshape((3, 28, 28)) # reshape input + for module_fqn in input_weight_dict: + # look for the first linear + if "block1.linear" in module_fqn: + block_1_lin_recs = input_weight_dict[module_fqn] + # get input range info and the channel axis + ch_axis = block_1_lin_recs[InputWeightEqualizationDetector.CHANNEL_KEY] + + # ensure that the min and max values extracted match properly + example_min, example_max = torch.aminmax(example_input, dim=ch_axis) + dimension_min = torch.amin(example_min, dim=ch_axis) + dimension_max = torch.amax(example_max, dim=ch_axis) + + # make sure per channel min and max are as expected + min_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX + min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY + + max_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX + max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY + + per_channel_min = block_1_lin_recs[min_per_key] + per_channel_max = block_1_lin_recs[max_per_key] + self.assertEqual(per_channel_min, dimension_min) + self.assertEqual(per_channel_max, dimension_max) + + # make sure per channel min and max are as expected + min_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX + min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY + + max_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX + max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY + + # make sure the global min and max were correctly recorded and presented + global_min = block_1_lin_recs[min_key] + global_max = block_1_lin_recs[max_key] + self.assertEqual(global_min, min(dimension_min)) + self.assertEqual(global_max, max(dimension_max)) + + input_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min)) + # ensure comparision stat passed back is sqrt of range ratios + # need to get the weight ratios first + + # make sure per channel min and max are as expected + min_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX + min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY + + max_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX + max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY + + # get weight per channel and global info + per_channel_min = block_1_lin_recs[min_per_key] + per_channel_max = block_1_lin_recs[max_per_key] + + # make sure per channel min and max are as expected + min_key = InputWeightEqualizationDetector.WEIGHT_PREFIX + min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY + + max_key = InputWeightEqualizationDetector.WEIGHT_PREFIX + max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY + + global_min = block_1_lin_recs[min_key] + global_max = block_1_lin_recs[max_key] + + weight_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min)) + + # also get comp stat for this specific layer + comp_stat = block_1_lin_recs[InputWeightEqualizationDetector.COMP_METRIC_KEY] + + weight_to_input_ratio = weight_ratio / input_ratio + + self.assertEqual(comp_stat, weight_to_input_ratio) + # only looking at the first example so can break + break + + @skipIfNoFBGEMM + def test_input_weight_equalization_report_gen_empty(self): + # tests report gen on a model that doesn't have any layers + # use fbgemm and create our model instance + # then create model report instance with detector + with override_quantized_engine('fbgemm'): + test_input_weight_detector = InputWeightEqualizationDetector(0.4) + detector_set = set([test_input_weight_detector]) + model = self.ReluOnly() + # prepare the model for callibration + prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set) + + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] + example_input = example_input.to(torch.float) + + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = model_report.generate_model_report(True) + + # check that sizes are appropriate only 1 detector + self.assertEqual(len(generated_report), 1) + + # get the specific report for input weight equalization + input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()] + + # we should have 0 layers since there is only a Relu + self.assertEqual(len(input_weight_dict), 0) + + # make sure that the string only has two lines, as should be if no suggestions + self.assertEqual(input_weight_str.count("\n"), 2) + + +class TestFxDetectOutliers(QuantizationTestCase): + + class LargeBatchModel(torch.nn.Module): + def __init__(self, param_size): + super().__init__() + self.param_size = param_size + self.linear = torch.nn.Linear(param_size, param_size) + self.relu_1 = torch.nn.ReLU() + self.conv = torch.nn.Conv2d(param_size, param_size, 1) + self.relu_2 = torch.nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.relu_1(x) + x = self.conv(x) + x = self.relu_2(x) + return x + + def get_example_inputs(self): + param_size = self.param_size + return (torch.randn((1, param_size, param_size, param_size)),) + + def get_outlier_inputs(self): + param_size = self.param_size + random_vals = torch.randn((1, param_size, param_size, param_size)) + # change one in some of them to be a massive value + random_vals[:, 0:param_size:2, 0, 3] = torch.tensor([3.28e8]) + return (random_vals,) + + + def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False): + r"""Returns a model that has been prepared for callibration and corresponding model_report""" + # call the general helper function to callibrate + example_input = model.get_example_inputs()[0] + + # if we specifically want to test data with outliers replace input + if use_outlier_data: + example_input = model.get_outlier_inputs()[0] + + return _get_prepped_for_calibration_model_helper(model, detector_set, example_input) + + @skipIfNoFBGEMM + def test_outlier_detection_determine_points(self): + # use fbgemm and create our model instance + # then create model report instance with detector + # similar to test for InputWeightEqualization but key differences that made refactoring not viable + # not explicitly testing fusion because fx workflow automatically + with override_quantized_engine('fbgemm'): + + detector_set = set([OutlierDetector(reference_percentile=0.95)]) + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model( + self.LargeBatchModel(param_size=128), detector_set + ) + + # supported modules to check + mods_to_check = set([nn.Linear, nn.Conv2d, nn.ReLU]) + + # there should be 4 node fqns that have the observer inserted + correct_number_of_obs_inserted = 4 + number_of_obs_found = 0 + obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME + + number_of_obs_found = sum( + [1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes] + ) + self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted) + + # assert that each of the desired modules have the observers inserted + for fqn, module in prepared_for_callibrate_model.named_modules(): + # check if module is a supported module + is_in_include_list = isinstance(module, tuple(mods_to_check)) + + if is_in_include_list: + # make sure it has the observer attribute + self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME)) + else: + # if it's not a supported type, it shouldn't have observer attached + self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME)) + + @skipIfNoFBGEMM + def test_no_outlier_report_gen(self): + # use fbgemm and create our model instance + # then create model report instance with detector + with override_quantized_engine('fbgemm'): + + # test with multiple detectors + outlier_detector = OutlierDetector(reference_percentile=0.95) + dynamic_static_detector = DynamicStaticDetector(tolerance=0.5) + + param_size: int = 4 + detector_set = set([outlier_detector, dynamic_static_detector]) + model = self.LargeBatchModel(param_size=param_size) + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model( + model, detector_set + ) + + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] + example_input = example_input.to(torch.float) + + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = mod_report.generate_model_report(True) + + # check that sizes are appropriate only 2 detectors + self.assertEqual(len(generated_report), 2) + + # get the specific report for input weight equalization + outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()] + + # we should have 5 layers looked at since 4 conv + linear + relu + self.assertEqual(len(outlier_dict), 4) + + # assert the following are true for all the modules + for module_fqn in outlier_dict: + # get the info for the specific module + module_dict = outlier_dict[module_fqn] + + # there really should not be any outliers since we used a normal distribution to perform this calculation + outlier_info = module_dict[OutlierDetector.OUTLIER_KEY] + self.assertEqual(sum(outlier_info), 0) + + # ensure that the number of ratios and batches counted is the same as the number of params + self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size) + self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size) + + + @skipIfNoFBGEMM + def test_all_outlier_report_gen(self): + # make the percentile 0 and the ratio 1, and then see that everything is outlier according to it + # use fbgemm and create our model instance + # then create model report instance with detector + with override_quantized_engine('fbgemm'): + # create detector of interest + outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0) + + param_size: int = 16 + detector_set = set([outlier_detector]) + model = self.LargeBatchModel(param_size=param_size) + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model( + model, detector_set + ) + + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] + example_input = example_input.to(torch.float) + + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = mod_report.generate_model_report(True) + + # check that sizes are appropriate only 1 detector + self.assertEqual(len(generated_report), 1) + + # get the specific report for input weight equalization + outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()] + + # we should have 5 layers looked at since 4 conv + linear + relu + self.assertEqual(len(outlier_dict), 4) + + # assert the following are true for all the modules + for module_fqn in outlier_dict: + # get the info for the specific module + module_dict = outlier_dict[module_fqn] + + # everything should be an outlier because we said that the max should be equal to the min for all of them + # however we will just test and say most should be in case we have several 0 channel values + outlier_info = module_dict[OutlierDetector.OUTLIER_KEY] + assert sum(outlier_info) >= len(outlier_info) / 2 + + # ensure that the number of ratios and batches counted is the same as the number of params + self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size) + self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size) + + @skipIfNoFBGEMM + def test_multiple_run_consistent_spike_outlier_report_gen(self): + # specifically make a row really high consistently in the number of batches that you are testing and try that + # generate report after just 1 run, and after many runs (30) and make sure above minimum threshold is there + with override_quantized_engine('fbgemm'): + + # detector of interest + outlier_detector = OutlierDetector(reference_percentile=0.95) + + param_size: int = 8 + detector_set = set([outlier_detector]) + model = self.LargeBatchModel(param_size=param_size) + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model( + model, detector_set, use_outlier_data=True + ) + + # now we actually callibrate the model + example_input = model.get_outlier_inputs()[0] + example_input = example_input.to(torch.float) + + # now callibrate minimum 30 times to make it above minimum threshold + for i in range(30): + example_input = model.get_outlier_inputs()[0] + example_input = example_input.to(torch.float) + + # make 2 of the batches to have zero channel + if i % 14 == 0: + # make one channel constant + example_input[0][1] = torch.zeros_like(example_input[0][1]) + + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = mod_report.generate_model_report(True) + + # check that sizes are appropriate only 1 detector + self.assertEqual(len(generated_report), 1) + + # get the specific report for input weight equalization + outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()] + + # we should have 5 layers looked at since 4 conv + linear + relu + self.assertEqual(len(outlier_dict), 4) + + # assert the following are true for all the modules + for module_fqn in outlier_dict: + # get the info for the specific module + module_dict = outlier_dict[module_fqn] + + # because we ran 30 times, we should have at least a couple be significant + # could be less because some channels could possibly be all 0 + sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY] + assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2 + + # half of them should be outliers, because we set a really high value every 2 channels + outlier_info = module_dict[OutlierDetector.OUTLIER_KEY] + self.assertEqual(sum(outlier_info), len(outlier_info) / 2) + + # ensure that the number of ratios and batches counted is the same as the number of params + self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size) + self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size) + + # for the first one ensure the per channel max values are what we set + if module_fqn == "linear.0": + + # check that the non-zero channel count, at least 2 should be there + # for the first module + counts_info = module_dict[OutlierDetector.CONSTANT_COUNTS_KEY] + assert sum(counts_info) >= 2 + + # half of the recorded max values should be what we set + matched_max = sum([val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY]]) + self.assertEqual(matched_max, param_size / 2) + + +class TestFxModelReportVisualizer(QuantizationTestCase): + + def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report): + r""" + Callibrates the passed in model, generates report, and returns the visualizer + """ + # now we actually callibrate the model + example_input = model.get_example_inputs()[0] example_input = example_input.to(torch.float) - q_config_mapping = QConfigMapping() - q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm")) - - # prep model and select observer - model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input) - obs_ctr = ModelReportObserver - - # find layer to attach to and store - linear_fqn = "block2.linear" # fqn of target linear - - target_linear = None - for node in model_prep.graph.nodes: - if node.target == linear_fqn: - target_linear = node - break - - # insert into both module and graph pre and post - - # set up to insert before target_linear (pre_observer) - with model_prep.graph.inserting_before(target_linear): - obs_to_insert = obs_ctr() - pre_obs_fqn = linear_fqn + ".model_report_pre_observer" - model_prep.add_submodule(pre_obs_fqn, obs_to_insert) - model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args) - - # set up and insert after the target_linear (post_observer) - with model_prep.graph.inserting_after(target_linear): - obs_to_insert = obs_ctr() - post_obs_fqn = linear_fqn + ".model_report_post_observer" - model_prep.add_submodule(post_obs_fqn, obs_to_insert) - model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,)) - - # need to recompile module after submodule added and pass input through - model_prep.recompile() - - num_iterations = 10 - for i in range(num_iterations): - if i % 2 == 0: - example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float) - else: - example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float) - model_prep(example_input) - - # run it through the dynamic vs static detector - dynamic_vs_static_detector = DynamicStaticDetector() - dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep) - - # one of the stats should be stationary, and the other non-stationary - # as a result, dynamic should be recommended - data_dist_info = [ - dynam_vs_stat_dict[linear_fqn]["pre_observer_data_dist"], - dynam_vs_stat_dict[linear_fqn]["post_observer_data_dist"], - ] - self.assertTrue("stationary" in data_dist_info) - self.assertTrue("non-stationary" in data_dist_info) - self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"]) + prepared_for_callibrate_model(example_input) + + # now get the report by running it through ModelReport instance + generated_report = mod_report.generate_model_report(remove_inserted_observers=False) + + # now we get the visualizer should not error + mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer() + + return mod_rep_visualizer + + @skipIfNoFBGEMM + def test_get_modules_and_features(self): + """ + Tests the get_all_unique_module_fqns and get_all_unique_feature_names methods of + ModelReportVisualizer + + Checks whether returned sets are of proper size and filtered properly + """ + with override_quantized_engine('fbgemm'): + # set the backend for this test + torch.backends.quantized.engine = "fbgemm" + # test with multiple detectors + detector_set = set() + detector_set.add(OutlierDetector(reference_percentile=0.95)) + detector_set.add(InputWeightEqualizationDetector(0.5)) + + model = TwoThreeOps() + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper( + model, detector_set, model.get_example_inputs()[0] + ) + + mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer( + model, prepared_for_callibrate_model, mod_report + ) + + # ensure the module fqns match the ones given by the get_all_unique_feature_names method + actual_model_fqns = set(mod_rep_visualizer.generated_reports.keys()) + returned_model_fqns = mod_rep_visualizer.get_all_unique_module_fqns() + self.assertEqual(returned_model_fqns, actual_model_fqns) + + # now ensure that features are all properly returned + # all the linears have all the features for two detectors + # can use those as check that method is working reliably + b_1_linear_features = mod_rep_visualizer.generated_reports["block1.linear"] + + # first test all features + returned_all_feats = mod_rep_visualizer.get_all_unique_feature_names(False) + self.assertEqual(returned_all_feats, set(b_1_linear_features.keys())) + + # now test plottable features + plottable_set = set() + + for feature_name in b_1_linear_features: + if type(b_1_linear_features[feature_name]) == torch.Tensor: + plottable_set.add(feature_name) + + returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names() + self.assertEqual(returned_plottable_feats, plottable_set) + + def _prep_visualizer_helper(self): + r""" + Returns a mod rep visualizer that we test in various ways + """ + # set backend for test + torch.backends.quantized.engine = "fbgemm" + + # test with multiple detectors + detector_set = set() + detector_set.add(OutlierDetector(reference_percentile=0.95)) + detector_set.add(InputWeightEqualizationDetector(0.5)) + + model = TwoThreeOps() + + # get tst model and callibrate + prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper( + model, detector_set, model.get_example_inputs()[0] + ) + + mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer( + model, prepared_for_callibrate_model, mod_report + ) + + return mod_rep_visualizer + + @skipIfNoFBGEMM + def test_generate_tables_match_with_report(self): + """ + Tests the generate_table_view() + ModelReportVisualizer + + Checks whether the generated dict has proper information + Visual check that the tables look correct performed during testing + """ + with override_quantized_engine('fbgemm'): + + # get the visualizer + mod_rep_visualizer = self._prep_visualizer_helper() + + table_dict = mod_rep_visualizer.generate_filtered_tables() + + # test primarily the dict since it has same info as str + tensor_headers, tensor_table = table_dict[ModelReportVisualizer.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY] + + # these two together should be the same as the generated report info in terms of keys + tensor_info_modules = set(row[1] for row in tensor_table) + channel_info_modules = set(row[1] for row in channel_table) + combined_modules: Set = tensor_info_modules.union(channel_info_modules) + + generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys()) + self.assertEqual(combined_modules, generated_report_keys) + + @skipIfNoFBGEMM + def test_generate_tables_no_match(self): + """ + Tests the generate_table_view() + ModelReportVisualizer + + Checks whether the generated dict has proper information + Visual check that the tables look correct performed during testing + """ + with override_quantized_engine('fbgemm'): + # get the visualizer + mod_rep_visualizer = self._prep_visualizer_helper() + + # try a random filter and make sure that there are no rows for either table + empty_tables_dict = mod_rep_visualizer.generate_filtered_tables(module_fqn_filter="random not there module") + + # test primarily the dict since it has same info as str + tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY] + channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY] + + tensor_info_modules = set(row[1] for row in tensor_table) + channel_info_modules = set(row[1] for row in channel_table) + combined_modules: Set = tensor_info_modules.union(channel_info_modules) + self.assertEqual(len(combined_modules), 0) # should be no matching modules + + @skipIfNoFBGEMM + def test_generate_tables_single_feat_match(self): + """ + Tests the generate_table_view() + ModelReportVisualizer + + Checks whether the generated dict has proper information + Visual check that the tables look correct performed during testing + """ + with override_quantized_engine('fbgemm'): + # get the visualizer + mod_rep_visualizer = self._prep_visualizer_helper() + + # try a matching filter for feature and make sure only those features show up + # if we filter to a very specific feature name, should only have 1 additional column in each table row + single_feat_dict = mod_rep_visualizer.generate_filtered_tables(feature_filter=OutlierDetector.MAX_VALS_KEY) + + # test primarily the dict since it has same info as str + tensor_headers, tensor_table = single_feat_dict[ModelReportVisualizer.TABLE_TENSOR_KEY] + channel_headers, channel_table = single_feat_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY] + + # get the number of features in each of these + tensor_info_features = len(tensor_headers) + channel_info_features = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + + # make sure that there are no tensor features, and that there is one channel level feature + self.assertEqual(tensor_info_features, 0) + self.assertEqual(channel_info_features, 1) + +def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False): + r"""Returns a model that has been prepared for callibration and corresponding model_report""" + # set the backend for this test + torch.backends.quantized.engine = "fbgemm" + + # create model instance and prepare it + example_input = example_input.to(torch.float) + q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping() + + # if they passed in fusion paramter, make sure to test that + if fused: + model = torch.quantization.fuse_modules(model, model.get_fusion_modules()) + + model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input) + + model_report = ModelReport(model_prep, detector_set) + + # prepare the model for callibration + prepared_for_callibrate_model = model_report.prepare_detailed_calibration() + + return (prepared_for_callibrate_model, model_report) diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index c6bfe3c5887cf..f371bdfddbb8f 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -13,6 +13,7 @@ toq = torch.ops.quantized from torch.ao.quantization.quantize_fx import ( convert_fx, + convert_to_reference_fx, prepare_fx, prepare_qat_fx, ) @@ -501,12 +502,9 @@ def forward(self, x): m = M().eval() # prevent conv2 from getting quantized, so we can test # modules with equal types - qconfig_dict = { - '': torch.ao.quantization.default_qconfig, - 'module_name': [('conv2', None)], - } + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping().set_module_name("conv2", None) example_inputs = (torch.randn(1, 1, 1, 1),) - mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -526,10 +524,10 @@ def forward(self, x): # all of these should be matched expected_types = { conv_name_1: - ((nn.Conv2d, torch.ao.quantization.MinMaxObserver), (nnq.Conv2d, nnq.Conv2d)), + ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nnq.Conv2d, nnq.Conv2d)), conv_name_0: - ((nn.Conv2d, torch.ao.quantization.MinMaxObserver), (nn.Conv2d, nn.Conv2d)), - mul_name_0: ((torch.mul, torch.ao.quantization.MinMaxObserver), (toq.mul, toq.mul)), + ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nn.Conv2d, nn.Conv2d)), + mul_name_0: ((torch.mul, torch.ao.quantization.HistogramObserver), (toq.mul, toq.mul)), relu_name_0: ((F.relu, torch.ao.quantization.FixedQParamsObserver), (F.relu, F.relu)), sigmoid_name_0: ((torch.sigmoid, torch.ao.quantization.FixedQParamsObserver), (torch.sigmoid, torch.sigmoid)), @@ -547,10 +545,10 @@ def forward(self, x): m1 = M().eval() m2 = M().eval() - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1),) - m1p = prepare_fx(m1, qconfig_dict, example_inputs=example_inputs) - m2p = prepare_fx(m2, qconfig_dict, example_inputs=example_inputs) + m1p = prepare_fx(m1, qconfig_mapping, example_inputs=example_inputs) + m2p = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs) results = get_matching_subgraph_pairs(m1p, m2p) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() sigmoid_name_0 = 'base_op_' + get_base_name_for_op( @@ -740,10 +738,10 @@ def forward(self, x): x = _wrapped_hardswish(x) return x - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1, 1, 1),) - m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs) - m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs) + m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs) + m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() add_op_to_sets_of_related_ops( @@ -758,7 +756,7 @@ def forward(self, x): expected_types = { hardswish_name_0: - ((F.hardswish, torch.ao.quantization.MinMaxObserver), (_wrapped_hardswish, _wrapped_hardswish)), + ((F.hardswish, torch.ao.quantization.HistogramObserver), (_wrapped_hardswish, _wrapped_hardswish)), } self.assert_types_for_matched_subgraph_pairs( results, expected_types, m1, m2) @@ -856,7 +854,7 @@ def _test_match_activations( prepare_fn=prepare_fx, ): if qconfig_dict is None: - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() if prepare_fn == prepare_fx: m.eval() else: @@ -919,7 +917,7 @@ def _test_match_shadow_activations( prepare_fn=prepare_fx, compare_fp32_vs_fp32_prepared=True, ): if qconfig_dict is None: - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() if prepare_fn == prepare_fx: m.eval() else: @@ -1208,9 +1206,9 @@ def test_shadow_activations_fqn(self): nn.Sequential(nn.Conv2d(1, 1, 1)), nn.Conv2d(1, 1, 1), ).eval() - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1, 1, 1),) - mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger) datum = torch.randn(1, 1, 1, 1) @@ -1696,10 +1694,10 @@ def forward(self, x): x = _wrapped_linear(x, self.w1, self.b1) return x - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() example_inputs = (torch.randn(1, 1),) - m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs) - m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs) + m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs) + m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs) data = torch.randn(1, 1) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() @@ -1767,9 +1765,9 @@ def test_layer_names(self): nn.Conv2d(1, 1, 1), nn.Sigmoid(), ).eval() - qconfig_dict = {'': torch.ao.quantization.default_qconfig} + qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping("fbgemm") example_inputs = (torch.randn(1, 1, 1, 1),) - mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) # extract weights @@ -1986,7 +1984,7 @@ def test_fp16_shadows_fp32(self): example_inputs = (torch.randn(1, 4),) qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig} mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs) - mq = convert_fx(mp, is_reference=True) + mq = convert_to_reference_fx(mp) mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger) def test_mul_add_cat_stack_skips_shadowing(self): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 11063d90b4d18..c008c700815c3 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -18,6 +18,7 @@ from torch.ao.quantization.quantize_fx import ( prepare_fx, convert_fx, + convert_to_reference_fx, prepare_qat_fx, fuse_fx, ) @@ -1045,7 +1046,7 @@ def forward(self, x): qconfig_dict = {'': qconfig} example_inputs = (torch.rand(1, 1),) prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) - quantized = convert_fx(prepared, is_reference=True) + quantized = convert_to_reference_fx(prepared) qparams = (quantized._scale_0, quantized._zero_point_0) weight_obs = qconfig.weight() weight_obs(quantized.weight) @@ -1225,7 +1226,8 @@ def forward(self, x): m = ModuleClass(*module_constructor_inputs).eval() qconfig_dict = {"": float16_dynamic_qconfig} m = prepare_fx(m, qconfig_dict, example_inputs=inputs) - m = convert_fx(m, is_reference=is_reference) + convert_fn = convert_to_reference_fx if is_reference else convert_fx + m = convert_fn(m) self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) @@ -1856,6 +1858,7 @@ def forward(self, x): # should not crash as in https://github.com/pytorch/pytorch/issues/75825 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) + # TODO: move QConfigMapping tests to test/quantization/core def test_qconfig_mapping_set_global(self): qconfig = get_default_qconfig() qconfig_mapping = QConfigMapping() @@ -3680,9 +3683,9 @@ def forward(self, x): return x m = M().eval() - qconfig_dict = {"": float16_static_qconfig} + qconfig_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig) # make sure quantization runs - m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) + m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1),)) m = convert_fx(m) def test_qparams_fqn(self): @@ -4015,7 +4018,7 @@ def forward(self, x): m_copy = copy.deepcopy(m) m = convert_fx(m) - m_reference = convert_fx(m_copy, is_reference=True) + m_reference = convert_to_reference_fx(m_copy) # checks for non-reference quantized model node_occurrence = { @@ -4069,7 +4072,7 @@ def forward(self, x): qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, example_inputs=example_inputs) - m = convert_fx(m, is_reference=True) + m = convert_to_reference_fx(m) m(*example_inputs) def test_preserve_tuple(self): @@ -4103,7 +4106,7 @@ def forward(self, x): m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) m_copy = copy.deepcopy(m) m = convert_fx(m) - m_ref = convert_fx(m_copy, is_reference=True) + m_ref = convert_to_reference_fx(m_copy) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 @@ -4283,8 +4286,8 @@ def forward(self, x): example_inputs = (torch.randn(5, 10),) m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m_copy = copy.deepcopy(m) - m = convert_fx(m, is_reference=False) - m_ref = convert_fx(m_copy, is_reference=True) + m = convert_fx(m) + m_ref = convert_to_reference_fx(m_copy) result = m(*example_inputs) result_ref = m_ref(*example_inputs) self.assertTrue(torch.equal(result, result_ref)) @@ -4321,8 +4324,8 @@ def forward(self, x): data = self.img_data_dict[dim][0][0] m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_copy = copy.deepcopy(m) - m = convert_fx(m, is_reference=False) - m_ref = convert_fx(m_copy, is_reference=True) + m = convert_fx(m) + m_ref = convert_to_reference_fx(m_copy) result = m(data) result_ref = m_ref(data) self.assertTrue(torch.equal(result, result_ref)) @@ -4414,7 +4417,7 @@ def forward(self, x): data = self.img_data_dict[dim][0][0] m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_ref = copy.deepcopy(m) - m_ref = convert_fx(m_ref, is_reference=True) + m_ref = convert_to_reference_fx(m_ref) m = convert_fx(m) out_ref = m_ref(data) out = m(data) @@ -4634,8 +4637,8 @@ def forward(self, x): break self.assertTrue(found_stack_trace) - # test is_reference == True - mq = convert_fx(copy.deepcopy(mp), is_reference=True) + # test reference model + mq = convert_to_reference_fx(copy.deepcopy(mp)) found_stack_trace = False for n in mq.graph.nodes: if n.op == 'call_module' and n.target == 'linear': @@ -4643,8 +4646,8 @@ def forward(self, x): break self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True") - # test is_reference == False - mq = convert_fx(mp, is_reference=False) + # test quantized model + mq = convert_fx(mp) found_stack_trace = False for n in mq.graph.nodes: if n.op == 'call_module' and n.target == 'linear': @@ -5780,6 +5783,28 @@ def test_elu(self): def test_leaky_relu(self): self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) + def test_prelu(self): + class M(torch.nn.Module): + def __init__(self, num_param: int): + super(M, self).__init__() + self.op = torch.nn.PReLU(num_parameters=num_param) + + def forward(self, input): + return self.op(input) + + X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]] + options = itertools.product([1, 4], self.static_quant_types, [True, False]) + quantized_nodes = { + # is_reference + True: ns.call_module(torch.nn.PReLU), + False: ns.call_module(torch.nn.quantized.PReLU), + } + + for num_parameter, quant_type, is_reference in options: + self.checkGraphModeFxOp( + M(num_parameter), X, quant_type, quantized_nodes[is_reference], + is_reference=is_reference) + def _test_norm_impl( self, float_module, float_op, op_args, data, quantized_module, quantized_op, skip_op_arg_for_functional=False): @@ -5942,9 +5967,10 @@ def forward(self, x): qconfig_dict = {"": qconfig} m = M(module, functional).eval() - m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict) + m_prep = prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict) m_prep(data) - m_quant = torch.ao.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference) + convert_fn = convert_to_reference_fx if is_reference else convert_fx + m_quant = convert_fn(m_prep, is_reference=is_reference) m_quant(data) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @@ -6111,9 +6137,10 @@ def forward(self, x, y): ] m = M().eval() - m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m_prep(*example_inputs) - m_quant = torch.ao.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference) + convert_fn = convert_to_reference_fx if is_reference else convert_fx + m_quant = convert_fn(m_prep) m_quant(*example_inputs) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @@ -6171,15 +6198,14 @@ def forward(self, x): data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) quant_type = QuantType.STATIC - qconfig_dict = { - "": float16_static_qconfig - } + # TODO: use get_default_qconfig_mapping once it handles fp16 + qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig) backend_config_dict = get_test_only_legacy_native_backend_config_dict() node_occurrence = { ns.call_method("to"): 7 } self.checkGraphModeFxOp( - M(), data, quant_type, custom_qconfig_dict=qconfig_dict, + M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, backend_config_dict=backend_config_dict) @@ -6204,13 +6230,13 @@ def forward(self, x): qconfig = torch.ao.quantization.QConfig( activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8), weight=default_weight_observer) - qconfig_dict = {"": qconfig} + qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 7, ns.call_method("dequantize"): 7 } self.checkGraphModeFxOp( - M(), data, quant_type, custom_qconfig_dict=qconfig_dict, + M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, is_reference=True) @skipIfNoFBGEMM @@ -6337,7 +6363,7 @@ def forward(self, x): qconfig_dict = {'': default_qconfig} prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable - quantized = convert_fx(prepared, is_reference=True) + quantized = convert_to_reference_fx(prepared) @skipIfNoFBGEMM @@ -6479,7 +6505,8 @@ def forward(self, xs): m = M().eval() example_inputs = (torch.rand(1, 2),) - m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) + qconfig_mapping = get_default_qconfig_mapping() + m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) @@ -6494,9 +6521,10 @@ def forward(self, xs): m2 = M2().eval() example_inputs = ([torch.rand(1, 2)],) - m2 = prepare_fx(m2, {"": default_qconfig}, example_inputs=example_inputs) + qconfig_mapping = get_default_qconfig_mapping() + m2 = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs) self.checkGraphModuleNodes(m2, expected_node_occurrence={ - ns.call_module(torch.ao.quantization.MinMaxObserver): 1 + ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 }) m2 = convert_fx(m2) self.checkGraphModuleNodes(m2, expected_node_list=[ @@ -6515,9 +6543,10 @@ def forward(self, x): m3 = M3().eval() example_inputs = (torch.rand(1, 2, 3, 4),) - m3 = prepare_fx(m3, {"": default_qconfig}, example_inputs=example_inputs) + qconfig_mapping = get_default_qconfig_mapping() + m3 = prepare_fx(m3, qconfig_mapping, example_inputs=example_inputs) self.checkGraphModuleNodes(m3, expected_node_occurrence={ - ns.call_module(torch.ao.quantization.MinMaxObserver): 1 + ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 }) m3 = convert_fx(m3) self.checkGraphModuleNodes(m3, expected_node_list=[ @@ -6561,20 +6590,18 @@ def forward(self, x): m = M() if eval_mode: m.eval() - qconfig = default_qconfig + qconfig_mapping = get_default_qconfig_mapping() prepare = prepare_fx fq_count = 10 else: m.train() - qconfig = default_qat_qconfig + qconfig_mapping = get_default_qat_qconfig_mapping() prepare = prepare_qat_fx fq_count = 10 - # nothing to fuse so skipping the fuse step m_copy = copy.deepcopy(m) - qconfig_dict = {'': qconfig} example_inputs = (torch.rand(3, 3, 3, 3),) - prepared = prepare(m, qconfig_dict, example_inputs=example_inputs) + prepared = prepare(m, qconfig_mapping, example_inputs=example_inputs) prepared_copy = copy.deepcopy(prepared) # check that prepare does not change model result if eval_mode: @@ -6589,7 +6616,7 @@ def forward(self, x): expected_node_occurrence=count_check) # not runnable quantized = convert_fx(prepared) - quantized_reference = convert_fx(prepared_copy, is_reference=True) + quantized_reference = convert_to_reference_fx(prepared_copy) # This checks that the dequantize from the output of first conv # is being propagated to the end, so that we don't insert extra @@ -6944,16 +6971,14 @@ def forward(self, x): w = torch.randn(4, 4) b = torch.randn(4) m = M(w, b).eval() - qconfig_dict = { - "": float16_static_qconfig, - "object_type": [ - (torch.nn.functional.linear, default_qconfig) - ] - } + # TODO: use get_default_qconfig_mapping once it handles fp16 + qconfig_mapping = QConfigMapping() \ + .set_global(float16_static_qconfig) \ + .set_object_type(torch.nn.functional.linear, default_qconfig) example_inputs = (torch.randn(1, 4),) backend_config_dict = get_test_only_legacy_native_backend_config_dict() m = prepare_fx( - m, qconfig_dict, example_inputs=example_inputs, + m, qconfig_mapping, example_inputs=example_inputs, backend_config_dict=backend_config_dict) expected_occurrence = { # input and weight of linear, output of linear @@ -7094,7 +7119,7 @@ def forward(self, x): qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) model_prepared(*example_inputs) - model_quantized = convert_fx(model_prepared, is_reference=True) + model_quantized = convert_to_reference_fx(model_prepared) out = model_quantized(*example_inputs) self.assertEqual(out.device.type, 'cuda') @@ -7122,7 +7147,7 @@ def forward(self, x): model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,)) model_prepared(input) model_prepared.to(device_after) - model_quantized = convert_fx(model_prepared, is_reference=True) + model_quantized = convert_to_reference_fx(model_prepared) out = model_quantized(input.to(device_after)) self.assertEqual(out.device.type, device_after) @@ -7152,7 +7177,7 @@ def forward(self, x): del model_prepared_first model_prepared_second.load_state_dict(state_dict) model_prepared_second.to(device_after) - model_quantized = convert_fx(model_prepared_second, is_reference=True) + model_quantized = convert_to_reference_fx(model_prepared_second) out = model_quantized(input.to(device_after)) self.assertEqual(out.device.type, device_after) @@ -7160,9 +7185,9 @@ def forward(self, x): def test_model_dropout(self): from torchvision import models m = models.mobilenet_v3_small() - qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} + qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm') example_inputs = (torch.randn(1, 3, 224, 224),) - mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) + mp = prepare_qat_fx(m, qconfig_mapping, example_inputs=example_inputs) mp(*example_inputs) mq = convert_fx(mp) res = mq(*example_inputs) @@ -7608,6 +7633,10 @@ def forward(self, x): torch.testing.assert_allclose(grad[0], grad_ref[0]) if 'fbgemm' in torch.backends.quantized.supported_engines: + # During the lowering step in convert, fold_weight calls quantized::linear_prepack + # which doesn't support QuantizedCuda backend + prepared.cpu() + prepared_ref.cpu() converted = convert_fx(prepared) converted_ref = convert_fx(prepared_ref) inp = torch.rand(5, 5) diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 6648bcaa9afc6..84ab3a723b70f 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -73,6 +73,7 @@ from torch.testing._internal.jit_utils import attrs_with_prefix from torch.testing._internal.jit_utils import get_forward from torch.testing._internal.jit_utils import get_forward_graph +from torch.testing._internal.common_utils import skipIfSlowGradcheckEnv from torch.jit._recursive import wrap_cpp_module @@ -1625,6 +1626,7 @@ def forward(self, x): torch.jit.save(model, b) +@skipIfSlowGradcheckEnv class TestQuantizeJitOps(QuantizationTestCase): """Test graph mode post training static quantization works for individual ops end to end. diff --git a/test/run_test.py b/test/run_test.py index 815c7de79e9bb..0ed38673728d7 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -13,6 +13,8 @@ import subprocess import sys import tempfile +import json +from typing import Dict, Optional, List, cast, Any import torch from torch.utils import cpp_extension @@ -25,18 +27,17 @@ parser as common_parser, ) import torch.distributed as dist -from typing import Optional, List REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent try: # using tools/ to optimize test run. sys.path.append(str(REPO_ROOT)) + from tools.stats.export_test_times import TEST_TIMES_FILE from tools.testing.test_selections import ( - export_S3_test_times, - get_shard_based_on_S3, get_reordered_tests, get_test_case_configs, + calculate_shards, ) HAVE_TEST_SELECTION_TOOLS = True except ImportError: @@ -72,6 +73,7 @@ def skip_test_p(name: str) -> bool: rc += extra_tests return sorted(rc) + TESTS = discover_tests( blocklisted_patterns=[ 'ao', @@ -100,8 +102,6 @@ def skip_test_p(name: str) -> bool: 'test_static_runtime', 'test_throughput_benchmark', 'test_typing', - "distributed/algorithms/ddp_comm_hooks/test_ddp_hooks", - "distributed/algorithms/quantization/test_quantization", "distributed/bin/test_script", "distributed/elastic/multiprocessing/bin/test_script", "distributed/launcher/bin/test_script", @@ -247,6 +247,7 @@ def skip_test_p(name: str) -> bool: RUN_PARALLEL_BLOCKLIST = [ "test_cpp_extensions_jit", + "test_cpp_extensions_open_device_registration", "test_jit_disabled", "test_mobile_optimizer", "test_multiprocessing", @@ -269,9 +270,6 @@ def skip_test_p(name: str) -> bool: "test_torch" ] -# the JSON file to store the S3 test stats -TEST_TIMES_FILE = ".pytorch-test-times.json" - # if a test file takes longer than 5 min, we add it to TARGET_DET_LIST SLOW_TEST_THRESHOLD = 300 @@ -318,6 +316,19 @@ def skip_test_p(name: str) -> bool: DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith("distributed")] + +def discover_functorch_tests(): + pytorch_root = pathlib.Path(__file__).resolve().parent.parent + functorch_test_dir = os.path.join(pytorch_root, 'functorch', 'test') + result = discover_tests(pathlib.Path(functorch_test_dir)) + result = [os.path.join(functorch_test_dir, r) for r in result] + + # Sanity check + assert len(result) >= 8 + return result + +FUNCTORCH_TESTS = discover_functorch_tests() + TESTS_REQUIRING_LAPACK = [ "distributions/test_constraints", "distributions/test_distributions", @@ -383,6 +394,7 @@ def test_cuda_primary_ctx(test_module, test_directory, options): test_module, test_directory, options, extra_unittest_args=["--subprocess"] ) + run_test_with_subprocess = functools.partial(run_test, extra_unittest_args=["--subprocess"]) @@ -390,7 +402,6 @@ def get_run_test_with_subprocess_fn(): return lambda test_module, test_directory, options: run_test_with_subprocess(test_module, test_directory, options) - def _test_cpp_extensions_aot(test_directory, options, use_ninja): if use_ninja: try: @@ -544,6 +555,7 @@ def test_distributed(test_module, test_directory, options): "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja, "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja, "distributed/test_distributed_spawn": test_distributed, + "distributed/algorithms/quantization/test_quantization": test_distributed, "distributed/test_c10d_nccl": get_run_test_with_subprocess_fn(), "distributed/test_c10d_gloo": get_run_test_with_subprocess_fn(), "distributed/test_c10d_common": get_run_test_with_subprocess_fn(), @@ -557,6 +569,7 @@ def test_distributed(test_module, test_directory, options): "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(), } + def parse_test_module(test): return test.split(".")[0] @@ -590,6 +603,16 @@ def parse_args(): action="store_true", help="run all distributed tests", ) + parser.add_argument( + "--functorch", + "--functorch", + action="store_true", + help=( + "If this flag is present, we will only run functorch tests. " + "If this flag is not present, we will not run any functorch tests. " + "This requires functorch to already be installed." + ) + ) parser.add_argument( "-core", "--core", @@ -677,13 +700,6 @@ def parse_args(): help="additional arguments passed through to unittest, e.g., " "python run_test.py -i sparse -- TestSparse.test_factory_size_check", ) - parser.add_argument( - "--export-past-test-times", - nargs="?", - type=str, - const=TEST_TIMES_FILE, - help="dumps test times from previous S3 stats into a file, format JSON", - ) parser.add_argument( "--shard", nargs=2, @@ -778,6 +794,9 @@ def get_selected_tests(options): filter(lambda test_name: test_name in CORE_TEST_LIST, selected_tests) ) + if options.functorch: + selected_tests = FUNCTORCH_TESTS + # process reordering if options.bring_to_front: to_front = set(options.bring_to_front) @@ -838,11 +857,29 @@ def get_selected_tests(options): assert num_shards <= len( selected_tests ), f"Number of shards must be less than {len(selected_tests)}" - # TODO: fix this to use test_times_filename, but currently this is not working - # because setting the export arg immeidately halts the test execution. - selected_tests = get_shard_based_on_S3( - which_shard, num_shards, selected_tests, TEST_TIMES_FILE - ) + + if num_shards == 1: + return selected_tests + + # Download previous test times to make sharding decisions + path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE) + if os.path.exists(path): + with open(path, "r") as f: + test_file_times = cast(Dict[str, Any], json.load(f)) + else: + test_file_times = {} + test_config = os.environ.get("TEST_CONFIG") + if test_config not in test_file_times: + print( + "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan." + ) + selected_tests = selected_tests[which_shard - 1 :: num_shards] + else: + print("Found test time stats from artifacts") + test_file_times_config = test_file_times[test_config] + shards = calculate_shards(num_shards, selected_tests, test_file_times_config) + _, tests_from_shard = shards[which_shard - 1] + selected_tests = tests_from_shard # skip all distributed tests if distributed package is not available. if not dist.is_available(): @@ -882,15 +919,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]: def main(): options = parse_args() - # TODO: move this export & download function in tools/ folder - test_times_filename = options.export_past_test_times - if test_times_filename: - print( - f"Exporting past test times from S3 to {test_times_filename}, no tests will be run." - ) - export_S3_test_times(test_times_filename) - return - test_directory = str(REPO_ROOT / "test") selected_tests = get_selected_tests(options) diff --git a/test/test_ao_sparsity.py b/test/test_ao_sparsity.py index 13d515c586e5b..89c941cddab1b 100644 --- a/test/test_ao_sparsity.py +++ b/test/test_ao_sparsity.py @@ -23,10 +23,20 @@ # Composability from ao.sparsity.test_composability import TestComposability # noqa: F401 +from ao.sparsity.test_composability import TestFxComposability # noqa: F401 + +# Utilities +from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F401 # Data Sparsifier from ao.sparsity.test_data_sparsifier import TestBaseDataSparsifier # noqa: F401 from ao.sparsity.test_data_sparsifier import TestNormDataSparsifiers # noqa: F401 +# Data Scheduler +from ao.sparsity.test_data_scheduler import TestBaseDataScheduler # noqa: F401 + +# Activation Sparsifier +from ao.sparsity.test_activation_sparsifier import TestActivationSparsifier # noqa: F401 + if __name__ == '__main__': run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index 866f25e5e0852..793e175ba6268 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -836,6 +836,41 @@ def test_retain_grad(self): out.backward() self.assertEqual(input * 18, input.grad) + # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between + # retains_grad and hooks in cpp. There's no point testing in python because + # Python hooks use a completely different mechanism. + def test_retain_grad_inplace(self): + a = torch.tensor([1.], requires_grad=True).clone() + a.retain_grad() + a.mul_(2) + a.sum().backward() + self.assertEqual(a.grad, torch.tensor([1.])) + + a = torch.tensor([1.], requires_grad=True).clone() + a.retain_grad() + # Inplace multiple times is OK, the real test here would be in cpp though + # because the index here is always zero, having cpp hooks in addition, + # will force us to properly update the index + a.mul_(2) + a.mul_(2) + a.sum().backward() + self.assertEqual(a.grad, torch.tensor([1.])) + + def test_retain_grad_inplace_over_view(self): + base = torch.tensor([1.], requires_grad=True).clone() + view = base[:] + view2 = base[:] + view.retain_grad() + view2.retain_grad() + view.mul_(2) + (view + view2).sum().backward() + + # The old grad_fn, slice, wouldn't be part of the graph during backward + # so if the retains grad were not properly updated to the new grad_fn, + # the grad would still be None + self.assertEqual(view.grad, view2.grad) + self.assertEqual(view.grad, torch.tensor([1.])) + def test_retain_grad_cycle(self): x = torch.ones(5, 5, requires_grad=True) @@ -4159,8 +4194,10 @@ def jvp(ctx, x_t, y_t): jvp_count[0] += 1 return x_t, y_t - x = torch.rand(2, dtype=torch.double, requires_grad=True) - y = torch.rand(2, dtype=torch.double, requires_grad=True) + # NB: In slow gradcheck we need to loop through numel times so use numel = 1 to ensure + # that fast and slow have the same counts + x = torch.rand(1, dtype=torch.double, requires_grad=True) + y = torch.rand(1, dtype=torch.double, requires_grad=True) gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False, check_batched_grad=False, check_batched_forward_grad=False) self.assertEqual(jvp_count[0], 2) # (2) once per input @@ -4179,8 +4216,8 @@ def jvp(ctx, x_t, y_t): # Repeat the previous test except we mark one input with requires_grad=False # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)! # Otherwise, other counts are halved. - x = torch.rand(2, dtype=torch.double, requires_grad=True) - y = torch.rand(2, dtype=torch.double, requires_grad=False) + x = torch.rand(1, dtype=torch.double, requires_grad=True) + y = torch.rand(1, dtype=torch.double, requires_grad=False) gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False, check_batched_grad=False, check_batched_forward_grad=True) self.assertEqual(jvp_count[0], 5) # 1 + 1 + 3 @@ -4455,6 +4492,54 @@ def test_checkpointing(self): mean_combined = torch.stack(feat_combined).mean() mean_combined.backward() + def _test_checkpointing_non_reentrant_autocast(self, device_type): + for enabled in [True, False]: + def foo(x, y, z): + # torch.mm is on autocast's list of ops that should run in + # the autocast precision + x = torch.mm(x, y) + y = torch.mm(x, z) + z = torch.mm(z, z) + expected_dtype = ( + torch.float32 if not enabled else torch.bfloat16 + ) + self.assertEqual(expected_dtype, z.dtype) + return z + + x = torch.randn(3, 3, requires_grad=True) + y = torch.randn(3, 3, requires_grad=True) + z = torch.randn(3, 3, requires_grad=True) + if device_type == 'cuda': + x = x.cuda() + y = y.cuda() + z = z.cuda() + + with torch.autocast(enabled=enabled, device_type=device_type, dtype=torch.bfloat16): + loss = checkpoint(foo, x, y, z, use_reentrant=False) + loss = loss.sum() + + # Without saving + recasting the autocast type, would raise error in autograd + # about mismatched dtypes. + loss.backward() # triggers recomputation to check it runs in bfloat + + def test_checkpointing_non_reentrant_autocast_cpu(self): + """ + Test that autocast args such as the dtype are preserved during non-reentrant + checkpoint recomputation on CPU. + """ + self._test_checkpointing_non_reentrant_autocast(device_type='cpu') + + @unittest.skipIf( + not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(), + "Test requires CUDA bf16 support" + ) + def test_checkpointing_non_reentrant_autocast_gpu(self): + """ + Test that autocast args/kwargs such as the dtype are preserved during + non-reentrant checkpoint recomputation on GPU. + """ + self._test_checkpointing_non_reentrant_autocast(device_type='cuda') + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") @slowTest def test_checkpointing_without_reentrant_memory_savings(self): @@ -4599,6 +4684,18 @@ def test_checkpointing_without_reentrant(self, input_requires_grad): nn.Linear(nz_bottleneck, nz_inp) ) + # Module holder for testing activation checkpointing with no_reentrant + # supports kwargs. + class MyModule(nn.Module): + def __init__(self, mod): + super().__init__() + self.module = mod + + def forward(self, data): + return self.module(data) + + module = MyModule(mod=module) + # Run model with and without checkpointing and verify gradients are # equivalent, regardless of if inputs require grads or not. module_copy = deepcopy(module) @@ -4610,7 +4707,7 @@ def test_checkpointing_without_reentrant(self, input_requires_grad): data_r.uniform_() data_r.requires_grad = input_requires_grad data_r_copy = data_r.clone() - feat_r = checkpoint(module, data_r, use_reentrant=False) + feat_r = checkpoint(module, data=data_r, use_reentrant=False) feat_combined.append(feat_r) feat_r_no_checkpoint = module_copy(data_r) feat_combined_no_checkpoint.append(feat_r_no_checkpoint) @@ -6789,6 +6886,23 @@ def test_metadata_check_check_conj(self): torch.real(dual) torch.imag(dual) + def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self): + # See https://github.com/pytorch/pytorch/issues/80507 + a = torch.tensor([1.]).as_strided((0,), (1,), 1) + b = torch.tensor([1.]).as_strided((0,), (1,), 2) + + with fwAD.dual_level(): + dual_input = fwAD.make_dual(a, b) + # Check that no copy is made + self.assertIs(fwAD.unpack_dual(dual_input).tangent, b) + + a = torch.tensor([1.]).as_strided((1,), (2,), 0) + b = torch.tensor([1.]).as_strided((1,), (1,), 0) + + with fwAD.dual_level(): + dual_input = fwAD.make_dual(a, b) + dual_input[1:] + # The following test functions want to ensure all the following behaviors: # - Ensure that default level system in the python binding works # - Ensure that only level 0 exists and nesting is properly disabled @@ -6800,6 +6914,7 @@ def test_metadata_check_check_conj(self): # - Ensure that view + inplace for both modes work fine # - Ensure we do proper cleanup on exit of a level + def test_default_level(self): foo = torch.rand(2) bar = torch.rand(2) @@ -6835,6 +6950,24 @@ def test_set_fw_grad_having_own_fw_grad_at_same_level(self): with self.assertRaisesRegex(RuntimeError, "has a forward gradient at the same level"): fwAD.make_dual(baz, dual) + def test_codegen_ignores_undefined_outputs(self): + # This test checks that codegen silently ignores undefined outputs + # Below, grad_input is specified as False in grad_output_mask, so + # convolution backward will return a undefined tensor in that position. + # Note that for this test to work we need to make sure either grad_output + # or weight to be a dual tensor, so grad_input requires forward grad + weight = torch.randn(6, 1, 30, 30) + inp = torch.rand((1, 1, 32, 32)) + out = torch.nn.functional.conv2d(inp, weight) + grad_out = torch.ones_like(out) + + with fwAD.dual_level(): + dual_weight = fwAD.make_dual(weight, torch.ones_like(weight)) + grad_input, _, _ = torch.ops.aten.convolution_backward( + grad_out, inp, dual_weight, (0,), + (1, 1), (0, 0), (1, 1), False, (0, 0), 1, (False, True, False)) + self.assertIsNone(grad_input) + def test_make_dual_inference_tensor_in_inference_mode(self): with torch.inference_mode(): foo = torch.rand(2) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 9875f4ee3567e..e4b1e9e550873 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -15,7 +15,7 @@ import torch.backends.cudnn import torch.utils.cpp_extension from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME -from torch.testing._internal.common_utils import gradcheck +from torch.testing._internal.common_utils import gradcheck, skipIfSlowGradcheckEnv TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None @@ -37,7 +37,8 @@ def remove_build_path(): if os.path.exists(default_build_root): shutil.rmtree(default_build_root) - +# There's only one test that runs gracheck, run slow mode manually +@skipIfSlowGradcheckEnv class TestCppExtensionJIT(common.TestCase): """Tests just-in-time cpp extensions. Don't confuse this with the PyTorch JIT (aka TorchScript). @@ -864,7 +865,8 @@ def test_custom_compound_op_autograd(self): a = torch.randn(5, 5, requires_grad=True) b = torch.randn(5, 5, requires_grad=True) - gradcheck(torch.ops.my.add, [a, b], eps=1e-2) + for fast_mode in (True, False): + gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode) if __name__ == "__main__": diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py new file mode 100644 index 0000000000000..ca69bf8398c72 --- /dev/null +++ b/test/test_cpp_extensions_open_device_registration.py @@ -0,0 +1,102 @@ +# Owner(s): ["module: cpp-extensions"] + +import os +import shutil +import sys + +import torch.testing._internal.common_utils as common +import torch +import torch.utils.cpp_extension +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME + + +TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None +TEST_CUDNN = False +TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None and ROCM_HOME is not None +if TEST_CUDA and torch.version.cuda is not None: # the skip CUDNN test for ROCm + CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, "include/cudnn.h")) + TEST_CUDNN = ( + TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available() + ) +IS_WINDOWS = sys.platform == "win32" + + +def remove_build_path(): + if sys.platform == "win32": + # Not wiping extensions build folder because Windows + return + default_build_root = torch.utils.cpp_extension.get_default_build_root() + if os.path.exists(default_build_root): + shutil.rmtree(default_build_root, ignore_errors=True) + + +class TestCppExtensionOpenRgistration(common.TestCase): + """Tests Open Device Registration with C++ extensions. + """ + + def setUp(self): + super().setUp() + # cpp extensions use relative paths. Those paths are relative to + # this file, so we'll change the working directory temporarily + self.old_working_dir = os.getcwd() + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + def tearDown(self): + super().tearDown() + # return the working directory (see setUp) + os.chdir(self.old_working_dir) + + @classmethod + def setUpClass(cls): + remove_build_path() + + @classmethod + def tearDownClass(cls): + remove_build_path() + + def test_open_device_registration(self): + module = torch.utils.cpp_extension.load( + name="custom_device_extension", + sources=[ + "cpp_extensions/open_registration_extension.cpp", + ], + extra_include_paths=["cpp_extensions"], + extra_cflags=["-g"], + verbose=True, + ) + + self.assertFalse(module.custom_add_called()) + + # create a tensor using our custom device object. + device = module.custom_device() + + x = torch.empty(4, 4, device=device) + y = torch.empty(4, 4, device=device) + + # Check that our device is correct. + self.assertTrue(x.device == device) + self.assertFalse(x.is_cpu) + + self.assertFalse(module.custom_add_called()) + + # calls out custom add kernel, registered to the dispatcher + z = x + y + + # check that it was called + self.assertTrue(module.custom_add_called()) + + z_cpu = z.to(device='cpu') + + # Check that our cross-device copy correctly copied the data to cpu + self.assertTrue(z_cpu.is_cpu) + self.assertFalse(z.is_cpu) + self.assertTrue(z.device == device) + self.assertEqual(z, z_cpu) + + z2 = z_cpu + z_cpu + + # None of our CPU operations should call the custom add function. + self.assertFalse(module.custom_add_called()) + +if __name__ == "__main__": + common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index 1cdf3fa3137f4..cea33fb087898 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -26,7 +26,7 @@ from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \ NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \ slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \ - get_cycles_per_ms + get_cycles_per_ms, parametrize, instantiate_parametrized_tests from torch.testing._internal.autocast_test_lists import AutocastTestLists # load_tests from common_utils is used to automatically filter tests for @@ -569,8 +569,8 @@ def test_serialization_array_with_storage(self): self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) - self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage)) - self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage)) + self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) + self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage)) q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) @@ -1351,7 +1351,6 @@ def _get_external_stream(self, device): out = cudart.cudaStreamDestroy(stream.value) self.assertEqual(out, 0) - @skipIfRocm def test_external_streams(self): device = torch.cuda.device(0) with self._get_external_stream(device) as stream_v: @@ -1359,7 +1358,6 @@ def test_external_streams(self): self.assertEqual(stream_v, ext_stream.cuda_stream) self.assertEqual(ext_stream.device.index, device.idx) - @skipIfRocm @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_external_streams_multi_device(self): device = torch.cuda.device(1) @@ -2806,7 +2804,11 @@ def test_autocast_torch_bf16(self): op, args = op_with_args[0], op_with_args[1] if len(op_with_args) == 3: skip_test = op_with_args[2] # TEST_WITH_ROCM - should_error_from_not_implemented = 'cudnn' in op or 'prelu' in op or 'thnn' in op \ + should_error_from_cudnn = 'cudnn' in op and not\ + ('TORCH_CUDNN_V8_API_ENABLED' in os.environ and + int(os.environ['TORCH_CUDNN_V8_API_ENABLED']) and + torch.cuda.get_device_capability() >= (8, 0)) + should_error_from_not_implemented = should_error_from_cudnn or 'prelu' in op or 'thnn' in op \ or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 'lstm_cell' if not skip_test: if should_error_from_not_implemented: @@ -3749,7 +3751,10 @@ def test_graph_grad_scaling(self): @unittest.skipIf((not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") - def test_graph_make_graphed_callables(self): + @parametrize('with_amp,cache_enabled', [(False, False), (True, False), (True, True)], + name_fn=lambda x, y: '{}{}'.format({True: "with_amp", False: "without_amp"}[x], + {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else '')) + def test_graph_make_graphed_callables(self, with_amp, cache_enabled): torch.manual_seed(5) torch.cuda.manual_seed(5) @@ -3780,9 +3785,10 @@ def test_graph_make_graphed_callables(self): relu_control = torch.nn.functional.relu # This is a good stress test. It graphs four callables: two Modules and two python functions. - model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \ - torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control), - ((x,), (h,), (y_pred,), (y_pred, y))) + with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): + model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \ + torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control), + ((x,), (h,), (y_pred,), (y_pred, y))) real_inputs = [torch.rand_like(x) for _ in range(10)] real_targets = [torch.rand_like(y) for _ in range(10)] @@ -3797,10 +3803,11 @@ def test_graph_make_graphed_callables(self): torch.cuda.manual_seed(5) for data, target in zip(real_inputs, real_targets): opt.zero_grad(set_to_none=True) - y_pred = m(data) - y_pred = relu(y_pred) - loss = loss_fn(y_pred, target) - loss.backward() + with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): + y_pred = m(data) + y_pred = relu(y_pred) + loss = loss_fn(y_pred, target) + loss.backward() opt.step() for p, pc in zip(model_graphed.parameters(), model_control.parameters()): @@ -3951,6 +3958,15 @@ def forward(self, x): loss.backward() optimizer.step() + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUDA_VISIBLE_DEVICES") + @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") + def test_lazy_init(self): + """ Validate that no CUDA calls are made during `import torch` call""" + from subprocess import check_output + test_script = "import os; import torch;os.environ['CUDA_VISIBLE_DEVICES']='32';print(torch.cuda.device_count())" + rc = check_output([sys.executable, '-c', test_script]).decode("ascii").strip() + self.assertEqual(rc, "0") + class TestCudaComm(TestCase): def _test_broadcast(self, input): @@ -4391,5 +4407,7 @@ class TestNamedTupleInput_1(NamedTuple): cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu'))) self.assertTrue(torch.equal(x, cat)) +instantiate_parametrized_tests(TestCuda) + if __name__ == '__main__': run_tests() diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 8e817b366f60d..b35352e3fe6b0 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -3,7 +3,6 @@ import math import sys import errno -import multiprocessing import os import ctypes import faulthandler @@ -1124,6 +1123,8 @@ def test_multiple_dataloaders(self): next(loader2_it) next(loader1_it) next(loader2_it) + del loader1_it + del loader2_it def test_segfault(self): p = ErrorTrackingProcess(target=_test_segfault) @@ -1892,7 +1893,10 @@ def test_proper_exit(self): # - `None` means that no error happens. # In all cases, all processes should end properly. if use_workers: - exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill'] + # TODO: Fix test for 'loader_kill' that would cause running out of shared memory. + # Killing loader process would prevent DataLoader iterator clean up all queues + # and worker processes + exit_methods = [None, 'loader_error', 'worker_error', 'worker_kill'] persistent_workers = self.persistent_workers else: exit_methods = [None, 'loader_error', 'loader_kill'] @@ -2723,22 +2727,28 @@ def __iter__(self): after = os.sched_getaffinity(0) return iter(after) - -def worker_set_affinity(_): - os.sched_setaffinity(0, [multiprocessing.cpu_count() - 1]) - - @unittest.skipIf( not hasattr(os, 'sched_setaffinity'), "os.sched_setaffinity is not available") class TestSetAffinity(TestCase): def test_set_affinity_in_worker_init(self): + # Query the current affinity mask to avoid setting a disallowed one + old_affinity = os.sched_getaffinity(0) + if not old_affinity: + self.skipTest("No affinity information") + # Choose any + expected_affinity = list(old_affinity)[-1] + + def worker_set_affinity(_): + os.sched_setaffinity(0, [expected_affinity]) + + dataset = SetAffinityDataset() dataloader = torch.utils.data.DataLoader( dataset, num_workers=2, worker_init_fn=worker_set_affinity) for sample in dataloader: - self.assertEqual(sample, [multiprocessing.cpu_count() - 1]) + self.assertEqual(sample, [expected_affinity]) class ConvDataset(Dataset): def __init__(self): diff --git a/test/test_datapipe.py b/test/test_datapipe.py index e5d0e65d90204..520e266322e65 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -30,7 +30,6 @@ import numpy as np import torch -import torch.utils.data.backward_compatibility import torch.utils.data.datapipes as dp import torch.utils.data.graph import torch.utils.data.graph_settings @@ -50,6 +49,9 @@ from torch.utils.data.datapipes.utils.decoder import ( basichandlers as decoder_basichandlers, ) +from torch.utils.data.datapipes.utils.snapshot import ( + _simple_graph_snapshot_restoration +) from torch.utils.data.datapipes.dataframe import CaptureDataFrame from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper @@ -277,7 +279,7 @@ def tearDown(self): def test_listdirfiles_iterable_datapipe(self): temp_dir = self.temp_dir.name - datapipe = dp.iter.FileLister(temp_dir, '') + datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, '') count = 0 for pathname in datapipe: @@ -496,7 +498,6 @@ def operations(df): self.compare_capture_and_eager(operations) -@skipIf(True, "Fix DataFramePipes Tests") class TestDataFramesPipes(TestCase): """ Most of test will fail if pandas instaled, but no dill available. @@ -520,7 +521,9 @@ def test_capture(self): dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0])) df_numbers = self._get_dataframes_pipe() df_numbers['k'] = df_numbers['j'] + df_numbers.i * 3 - self.assertEqual(list(dp_numbers), list(df_numbers)) + expected = list(dp_numbers) + actual = list(df_numbers) + self.assertEqual(expected, actual) @skipIfNoDataFrames @skipIfNoDill @@ -554,7 +557,40 @@ def test_unbatch(self): @skipIfNoDill def test_filter(self): df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5) - self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], list(df_numbers)) + actual = list(df_numbers) + self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual) + + @skipIfNoDataFrames + @skipIfNoDill + def test_collate(self): + def collate_i(column): + return column.sum() + + def collate_j(column): + return column.prod() + df_numbers = self._get_dataframes_pipe(range=30).batch(3) + df_numbers = df_numbers.collate({'j': collate_j, 'i': collate_i}) + + expected_i = [3, + 12, + 21, + 30, + 39, + 48, + 57, + 66, + 75, + 84, ] + + actual_i = [] + for i, j in df_numbers: + actual_i.append(i) + self.assertEqual(expected_i, actual_i) + + actual_i = [] + for item in df_numbers: + actual_i.append(item.i) + self.assertEqual(expected_i, actual_i) class IDP_NoLen(IterDataPipe): @@ -603,6 +639,11 @@ def _worker_init_fn(worker_id): torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id) +lambda_fn1 = lambda x: x # noqa: E731 +lambda_fn2 = lambda x: x % 2 # noqa: E731 +lambda_fn3 = lambda x: x >= 5 # noqa: E731 + + class TestFunctionalIterDataPipe(TestCase): def _serialization_test_helper(self, datapipe, use_dill): @@ -702,16 +743,41 @@ def test_serializable(self): def test_serializable_with_dill(self): """Only for DataPipes that take in a function as argument""" input_dp = dp.iter.IterableWrapper(range(10)) - unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ - (dp.iter.Collator, (lambda x: x,), {}), - (dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), - (dp.iter.Filter, (lambda x: x >= 5,), {}), - (dp.iter.Grouper, (lambda x: x >= 5,), {}), - (dp.iter.Mapper, (lambda x: x,), {}), + + datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Collator, (lambda_fn1,), {}), + (dp.iter.Demultiplexer, (2, lambda_fn2,), {}), + (dp.iter.Filter, (lambda_fn3,), {}), + (dp.iter.Grouper, (lambda_fn3,), {}), + (dp.iter.Mapper, (lambda_fn1,), {}), + ] + + def _local_fns(): + def _fn1(x): + return x + + def _fn2(x): + return x % 2 + + def _fn3(x): + return x >= 5 + + return _fn1, _fn2, _fn3 + + fn1, fn2, fn3 = _local_fns() + + datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Collator, (fn1,), {}), + (dp.iter.Demultiplexer, (2, fn2,), {}), + (dp.iter.Filter, (fn3,), {}), + (dp.iter.Grouper, (fn3,), {}), + (dp.iter.Mapper, (fn1,), {}), ] + dp_compare_children = {dp.iter.Demultiplexer} + if HAS_DILL: - for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn: if dpipe in dp_compare_children: dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True) @@ -719,13 +785,16 @@ def test_serializable_with_dill(self): datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] self._serialization_test_for_single_dp(datapipe, use_dill=True) else: - for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: - with warnings.catch_warnings(record=True) as wa: - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") - with self.assertRaises(AttributeError): - p = pickle.dumps(datapipe) + msgs = ( + r"^Lambda function is not supported by pickle", + r"^Local function is not supported by pickle" + ) + for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs): + for dpipe, dp_args, dp_kwargs in dps: + with self.assertWarnsRegex(UserWarning, msg): + datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] + with self.assertRaises((pickle.PicklingError, AttributeError)): + pickle.dumps(datapipe) def test_iterable_wrapper_datapipe(self): @@ -795,7 +864,7 @@ def test_fork_iterdatapipe(self): with self.assertRaises(ValueError): input_dp.fork(num_instances=0) - dp0 = input_dp.fork(num_instances=1) + dp0 = input_dp.fork(num_instances=1, buffer_size=0) self.assertEqual(dp0, input_dp) # Functional Test: making sure all child DataPipe shares the same reference @@ -816,15 +885,19 @@ def test_fork_iterdatapipe(self): self.assertEqual([(i, i) for i in range(10)], output) # Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small - dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5) + dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4) it1 = iter(dp1) - for _ in range(5): + for _ in range(4): next(it1) with self.assertRaises(BufferError): next(it1) with self.assertRaises(BufferError): list(dp2) + dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5) + with self.assertRaises(BufferError): + list(dp2) + # Functional Test: one child DataPipe yields all value first with unlimited buffer with warnings.catch_warnings(record=True) as wa: dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1) @@ -1145,42 +1218,43 @@ def fn_n1(d0, d1): def fn_nn(d0, d1): return -d0, -d1, d0 + d1 - def _helper(ref_fn, fn, input_col=None, output_col=None): + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): for constr in (list, tuple): datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) - res_dp = datapipe.map(fn, input_col, output_col) - ref_dp = datapipe.map(ref_fn) - self.assertEqual(list(res_dp), list(ref_dp)) - # Reset - self.assertEqual(list(res_dp), list(ref_dp)) + if ref_fn is None: + with self.assertRaises(error): + res_dp = datapipe.map(fn, input_col, output_col) + list(res_dp) + else: + res_dp = datapipe.map(fn, input_col, output_col) + ref_dp = datapipe.map(ref_fn) + self.assertEqual(list(res_dp), list(ref_dp)) + # Reset + self.assertEqual(list(res_dp), list(ref_dp)) # Replacing with one input column and default output column _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1) # The index of input column is out of range - with self.assertRaises(IndexError): - _helper(None, fn_1n, 3) + _helper(None, fn_1n, 3, error=IndexError) # Unmatched input columns with fn arguments - with self.assertRaises(TypeError): - _helper(None, fn_n1, 1) + _helper(None, fn_n1, 1, error=TypeError) + # Replacing with multiple input columns and default output column (the left-most input column) _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1]) # output_col can only be specified when input_col is not None - with self.assertRaises(ValueError): - _helper(None, fn_n1, None, 1) + _helper(None, fn_n1, None, 1, error=ValueError) # output_col can only be single-element list or tuple - with self.assertRaises(ValueError): - _helper(None, fn_n1, None, [0, 1]) + _helper(None, fn_n1, None, [0, 1], error=ValueError) # Single-element list as output_col _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) # Replacing with one input column and single specified output column _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) # The index of output column is out of range - with self.assertRaises(IndexError): - _helper(None, fn_1n, 1, 3) + _helper(None, fn_1n, 1, 3, error=IndexError) _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0) @@ -1213,38 +1287,39 @@ def _dict_update(data, newdata, remove_idx=None): del _data[idx] return _data - def _helper(ref_fn, fn, input_col=None, output_col=None): + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): datapipe = dp.iter.IterableWrapper( [{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}] ) - res_dp = datapipe.map(fn, input_col, output_col) - ref_dp = datapipe.map(ref_fn) - self.assertEqual(list(res_dp), list(ref_dp)) - # Reset - self.assertEqual(list(res_dp), list(ref_dp)) + if ref_fn is None: + with self.assertRaises(error): + res_dp = datapipe.map(fn, input_col, output_col) + list(res_dp) + else: + res_dp = datapipe.map(fn, input_col, output_col) + ref_dp = datapipe.map(ref_fn) + self.assertEqual(list(res_dp), list(ref_dp)) + # Reset + self.assertEqual(list(res_dp), list(ref_dp)) # Replacing with one input column and default output column _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") # The key of input column is not in dict - with self.assertRaises(KeyError): - _helper(None, fn_1n, "a") + _helper(None, fn_1n, "a", error=KeyError) # Unmatched input columns with fn arguments - with self.assertRaises(TypeError): - _helper(None, fn_n1, "y") + _helper(None, fn_n1, "y", error=TypeError) # Replacing with multiple input columns and default output column (the left-most input column) _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) _helper(lambda data: _dict_update( data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), fn_nn, ["z", "y"]) # output_col can only be specified when input_col is not None - with self.assertRaises(ValueError): - _helper(None, fn_n1, None, "x") + _helper(None, fn_n1, None, "x", error=ValueError) # output_col can only be single-element list or tuple - with self.assertRaises(ValueError): - _helper(None, fn_n1, None, ["x", "y"]) + _helper(None, fn_n1, None, ["x", "y"], error=ValueError) # Single-element list as output_col _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) # Replacing with one input column and single specified output column @@ -1567,8 +1642,9 @@ def test_zip_iterdatapipe(self): # Reset Test: n_elements_before_reset = 3 res_before_reset, res_after_reset = reset_after_n_next_calls(zipped_dp, n_elements_before_reset) - self.assertEqual(list((i, i) for i in range(5))[:n_elements_before_reset], res_before_reset) - self.assertEqual(list((i, i) for i in range(5)), res_after_reset) + expected_res = [(i, i) for i in range(5)] + self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) + self.assertEqual(expected_res, res_after_reset) class TestFunctionalMapDataPipe(TestCase): @@ -1617,24 +1693,41 @@ def test_serializable(self): def test_serializable_with_dill(self): """Only for DataPipes that take in a function as argument""" input_dp = dp.map.SequenceWrapper(range(10)) - unpicklable_datapipes: List[ + + datapipes_with_lambda_fn: List[ Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] ] = [ - (dp.map.Mapper, (lambda x: x,), {}), + (dp.map.Mapper, (lambda_fn1,), {}), ] + + def _local_fns(): + def _fn1(x): + return x + + return _fn1 + + fn1 = _local_fns() + + datapipes_with_local_fn: List[ + Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] + ] = [ + (dp.map.Mapper, (fn1,), {}), + ] + if HAS_DILL: - for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn: _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] else: - for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: - with warnings.catch_warnings(record=True) as wa: - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self.assertEqual(len(wa), 1) - self.assertRegex( - str(wa[0].message), r"^Lambda function is not supported for pickle" - ) - with self.assertRaises(AttributeError): - p = pickle.dumps(datapipe) + msgs = ( + r"^Lambda function is not supported by pickle", + r"^Local function is not supported by pickle" + ) + for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs): + for dpipe, dp_args, dp_kwargs in dps: + with self.assertWarnsRegex(UserWarning, msg): + datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] + with self.assertRaises((pickle.PicklingError, AttributeError)): + pickle.dumps(datapipe) def test_sequence_wrapper_datapipe(self): seq = list(range(10)) @@ -2122,39 +2215,116 @@ def __iter__(self): class TestGraph(TestCase): - @skipIfNoDill + class CustomIterDataPipe(IterDataPipe): + def add_v(self, x): + return x + self.v + + def __init__(self, source_dp, v=1): + self._dp = source_dp.map(self.add_v) + self.v = 1 + + def __iter__(self): + yield from self._dp + + def __hash__(self): + raise NotImplementedError + + def test_simple_traverse(self): numbers_dp = NumbersDataset(size=50) mapped_dp = numbers_dp.map(lambda x: x * 10) - graph = torch.utils.data.graph.traverse(mapped_dp) - expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}} + graph = torch.utils.data.graph.traverse(mapped_dp, only_datapipe=True) + expected: Dict[Any, Any] = {id(mapped_dp): (mapped_dp, {id(numbers_dp): (numbers_dp, {})})} self.assertEqual(expected, graph) - @skipIfNoDill + dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) + self.assertEqual(len(dps), 2) + self.assertTrue(numbers_dp in dps) + self.assertTrue(mapped_dp in dps) + def test_traverse_forked(self): numbers_dp = NumbersDataset(size=50) dp0, dp1, dp2 = numbers_dp.fork(num_instances=3) dp0_upd = dp0.map(lambda x: x * 10) dp1_upd = dp1.filter(lambda x: x % 3 == 1) combined_dp = dp0_upd.mux(dp1_upd, dp2) - graph = torch.utils.data.graph.traverse(combined_dp) - expected = {combined_dp: {dp0_upd: {dp0: {dp0.main_datapipe: {dp0.main_datapipe.main_datapipe: {}}}}, - dp1_upd: {dp1: {dp1.main_datapipe: {dp1.main_datapipe.main_datapipe: {}}}}, - dp2: {dp2.main_datapipe: {dp2.main_datapipe.main_datapipe: {}}}}} + graph = torch.utils.data.graph.traverse(combined_dp, only_datapipe=True) + expected = { + id(combined_dp): (combined_dp, { + id(dp0_upd): (dp0_upd, { + id(dp0): (dp0, { + id(dp0.main_datapipe): (dp0.main_datapipe, { + id(dp0.main_datapipe.main_datapipe): (dp0.main_datapipe.main_datapipe, {}) + }) + }) + }), + id(dp1_upd): (dp1_upd, { + id(dp1): (dp1, { + id(dp1.main_datapipe): (dp1.main_datapipe, { + id(dp1.main_datapipe.main_datapipe): (dp1.main_datapipe.main_datapipe, {}) + }) + }) + }), + id(dp2): (dp2, { + id(dp2.main_datapipe): (dp2.main_datapipe, { + id(dp2.main_datapipe.main_datapipe): (dp2.main_datapipe.main_datapipe, {}) + }) + }) + }) + } self.assertEqual(expected, graph) + dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) + self.assertEqual(len(dps), 8) + for _dp in [numbers_dp, dp0.main_datapipe, dp0, dp1, dp2, dp0_upd, dp1_upd, combined_dp]: + self.assertTrue(_dp in dps) + def test_traverse_mapdatapipe(self): source_dp = dp.map.SequenceWrapper(range(10)) map_dp = source_dp.map(partial(_fake_add, 1)) graph = torch.utils.data.graph.traverse(map_dp) - expected: Dict[Any, Any] = {map_dp: {source_dp: {}}} + expected: Dict[Any, Any] = {id(map_dp): (map_dp, {id(source_dp): (source_dp, {})})} self.assertEqual(expected, graph) def test_traverse_mixdatapipe(self): source_map_dp = dp.map.SequenceWrapper(range(10)) iter_dp = dp.iter.IterableWrapper(source_map_dp) graph = torch.utils.data.graph.traverse(iter_dp) - expected: Dict[Any, Any] = {iter_dp: {source_map_dp: {}}} + expected: Dict[Any, Any] = {id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})})} + self.assertEqual(expected, graph) + + def test_traverse_circular_datapipe(self): + source_iter_dp = dp.iter.IterableWrapper(list(range(10))) + circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp) + graph = torch.utils.data.graph.traverse(circular_dp, only_datapipe=True) + # See issue: https://github.com/pytorch/data/issues/535 + expected: Dict[Any, Any] = { + id(circular_dp): (circular_dp, { + id(circular_dp._dp): (circular_dp._dp, { + id(source_iter_dp): (source_iter_dp, {}) + }) + }) + } + self.assertEqual(expected, graph) + + dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph) + self.assertEqual(len(dps), 3) + for _dp in [circular_dp, circular_dp._dp, source_iter_dp]: + self.assertTrue(_dp in dps) + + def test_traverse_unhashable_datapipe(self): + source_iter_dp = dp.iter.IterableWrapper(list(range(10))) + unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp) + graph = torch.utils.data.graph.traverse(unhashable_dp, only_datapipe=True) + with self.assertRaises(NotImplementedError): + hash(unhashable_dp) + expected: Dict[Any, Any] = { + id(unhashable_dp): (unhashable_dp, { + id(unhashable_dp._dp): (unhashable_dp._dp, { + id(source_iter_dp): (source_iter_dp, {}) + }) + }) + } self.assertEqual(expected, graph) @@ -2165,7 +2335,7 @@ def unbatch(x): class TestSerialization(TestCase): @skipIfNoDill def test_spawn_lambdas_iter(self): - idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1) + idp = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle() dl = DataLoader(idp, num_workers=2, shuffle=True, multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1) result = list(dl) @@ -2173,7 +2343,7 @@ def test_spawn_lambdas_iter(self): @skipIfNoDill def test_spawn_lambdas_map(self): - mdp = dp.map.SequenceWrapper(range(6)).map(lambda x: x + 1) + mdp = dp.map.SequenceWrapper(range(6)).map(lambda x: x + 1).shuffle() dl = DataLoader(mdp, num_workers=2, shuffle=True, multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1) result = list(dl) @@ -2204,45 +2374,104 @@ def __iter__(self): yield from self._dp def test_circular_serialization_with_pickle(self): - from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _DemultiplexerIterDataPipe - - def _get_name(datapipe): - return datapipe.__name__ - # Test for circular reference issue with pickle - source_dp = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn) - self.assertTrue(list(source_dp) == - list(pickle.loads(pickle.dumps(TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn))))) - res1 = traverse(source_dp, only_datapipe=True) - res2 = traverse(source_dp, only_datapipe=False) - expected_str1 = str({source_dp: - {_get_name(dp.iter.IterableWrapper): {}, - _get_name(_ChildDataPipe): - {_get_name(_DemultiplexerIterDataPipe): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.IterableWrapper): {}}}}}}} - ).replace("'", "") - expected_str2 = str({source_dp: - {_get_name(dp.iter.IterableWrapper): {}, - _get_name(_ChildDataPipe): - {_get_name(_DemultiplexerIterDataPipe): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.IterableWrapper): {}}, - _get_name(dp.iter.IterableWrapper): {}}}}}} - ).replace("'", "") - # For simplicity, compare the resulting string instead of trying to recreate the object - self.assertEqual(expected_str1, str(res1)) - self.assertEqual(expected_str2, str(res2)) - dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn) + self.assertTrue(list(dp1) == list(pickle.loads(pickle.dumps(dp1)))) + + child_1 = dp1._dp + dm_1 = child_1.main_datapipe + m2_1 = dm_1.main_datapipe + m1_1 = m2_1.datapipe + src_1 = m1_1.datapipe + + res1 = traverse(dp1, only_datapipe=True) + res2 = traverse(dp1, only_datapipe=False) + + exp_res_1 = {id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + })} + exp_res_2 = {id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + })} + + self.assertEqual(res1, exp_res_1) + self.assertEqual(res2, exp_res_2) + dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1) self.assertTrue(list(dp2) == list(pickle.loads(pickle.dumps(dp2)))) + + child_2 = dp2._dp + dm_2 = child_2.main_datapipe + m2_2 = dm_2.main_datapipe + m1_2 = m2_2.datapipe + res3 = traverse(dp2, only_datapipe=True) res4 = traverse(dp2, only_datapipe=False) - self.assertTrue(str(dp2) in str(res3)) # Quick check to ensure the result isn't blank - self.assertTrue(str(dp2) in str(res4)) + exp_res_3 = {id(dp2): (dp2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + }), + id(child_2): (child_2, {id(dm_2): (dm_2, { + id(m2_2): (m2_2, {id(m1_2): (m1_2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + }), + })}) + })}) + })} + exp_res_4 = {id(dp2): (dp2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }), + id(child_2): (child_2, {id(dm_2): (dm_2, { + id(m2_2): (m2_2, { + id(m1_2): (m1_2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }) + }), + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }) + }) + })}) + })} + + self.assertEqual(res3, exp_res_3) + self.assertEqual(res4, exp_res_4) class LambdaIterDataPipe(CustomIterDataPipe): @@ -2253,48 +2482,106 @@ def __init__(self, fn, source_dp=None): self._dp = self.source_dp.map(self.add_one).map(self.lambda_fn).map(self.add_v).demux(2, self.classify)[0] @skipIfNoDill + @skipIf(True, "Dill Tests") def test_circular_serialization_with_dill(self): - from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _DemultiplexerIterDataPipe - - def _get_name(datapipe): - return datapipe.__name__ - # Test for circular reference issue with dill - self.assertTrue(list(TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1)) == - list(dill.loads(dill.dumps(TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1))))) - source_dp = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn) - res1 = traverse(source_dp, only_datapipe=True) - res2 = traverse(source_dp, only_datapipe=False) - expected_str1 = str({source_dp: - {_get_name(dp.iter.IterableWrapper): {}, - _get_name(_ChildDataPipe): - {_get_name(_DemultiplexerIterDataPipe): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.IterableWrapper): {}}}}}}}} - ).replace("'", "") - expected_str2 = str({source_dp: - {_get_name(dp.iter.IterableWrapper): {}, - _get_name(_ChildDataPipe): - {_get_name(_DemultiplexerIterDataPipe): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.Mapper): - {_get_name(dp.iter.IterableWrapper): {}}}, - _get_name(dp.iter.IterableWrapper): {}}}}}} - ).replace("'", "") - # For simplicity, compare the resulting string instead of trying to recreate the object - self.assertEqual(expected_str1, str(res1)) - self.assertEqual(expected_str2, str(res2)) - - dp1 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn) + dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1) + self.assertTrue(list(dp1) == list(dill.loads(dill.dumps(dp1)))) + + child_1 = dp1._dp + dm_1 = child_1.main_datapipe + m2_1 = dm_1.main_datapipe + m1_1 = m2_1.datapipe + src_1 = m1_1.datapipe + + res1 = traverse(dp1, only_datapipe=True) + res2 = traverse(dp1, only_datapipe=False) + + exp_res_1 = {id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + })} + exp_res_2 = {id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + })} + + self.assertEqual(res1, exp_res_1) + self.assertEqual(res2, exp_res_2) + dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1) self.assertTrue(list(dp2) == list(dill.loads(dill.dumps(dp2)))) + + child_2 = dp2._dp + dm_2 = child_2.main_datapipe + m2_2 = dm_2.main_datapipe + m1_2 = m2_2.datapipe + res3 = traverse(dp2, only_datapipe=True) res4 = traverse(dp2, only_datapipe=False) - self.assertTrue(str(dp2) in str(res3)) # Quick check to ensure the result isn't blank - self.assertTrue(str(dp2) in str(res4)) + exp_res_3 = {id(dp2): (dp2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + }), + id(child_2): (child_2, {id(dm_2): (dm_2, { + id(m2_2): (m2_2, {id(m1_2): (m1_2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, {id(m1_1): (m1_1, {id(src_1): (src_1, {})})}) + })}) + }), + })}) + })}) + })} + exp_res_4 = {id(dp2): (dp2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }), + id(child_2): (child_2, {id(dm_2): (dm_2, { + id(m2_2): (m2_2, { + id(m1_2): (m1_2, { + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }) + }), + id(dp1): (dp1, { + id(src_1): (src_1, {}), + id(child_1): (child_1, {id(dm_1): (dm_1, { + id(m2_1): (m2_1, { + id(m1_1): (m1_1, {id(src_1): (src_1, {})}), + id(src_1): (src_1, {}) + }) + })}) + }) + }) + })}) + })} + + self.assertEqual(res3, exp_res_3) + self.assertEqual(res4, exp_res_4) class TestSharding(TestCase): @@ -2355,8 +2642,7 @@ def test_old_dataloader(self): expected = list(dp0) dp0 = self._get_pipeline().sharding_filter() - dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2, - worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn) + dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2) items = [] for i in dl: items.append(i) @@ -2588,5 +2874,337 @@ def test_iterdatapipe_singleton_constraint_multiple_outputs(self): next(it1) self.assertEqual(1, next(it3)) +class TestIterDataPipeCountSampleYielded(TestCase): + + def _yield_count_test_helper(self, datapipe, n_expected_samples): + + # Functional Test: Check if number of samples yielded is as expected + res = list(datapipe) + self.assertEqual(len(res), datapipe._number_of_samples_yielded) + + # Functional Test: Check if the count is correct when DataPipe is partially read + it = iter(datapipe) + res = [] + for i, value in enumerate(it): + res.append(value) + if i == n_expected_samples - 1: + break + self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded) + + # Functional Test: Check for reset behavior and if iterator also works + it = iter(datapipe) # reset the DataPipe + res = list(it) + self.assertEqual(len(res), datapipe._number_of_samples_yielded) + + def test_iterdatapipe_sample_yielded_generator_function(self): + # Functional Test: `__iter__` is a generator function + datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10)) + self._yield_count_test_helper(datapipe, n_expected_samples=5) + + def test_iterdatapipe_sample_yielded_generator_function_exception(self): + # Functional Test: `__iter__` is a custom generator function with exception + class _CustomGeneratorFnDataPipe(IterDataPipe): + # This class's `__iter__` has a Runtime Error + def __iter__(self): + yield 0 + yield 1 + yield 2 + raise RuntimeError("Custom test error after yielding 3 elements") + yield 3 + + # Functional Test: Ensure the count is correct even when exception is raised + datapipe: IterDataPipe = _CustomGeneratorFnDataPipe() + with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"): + list(datapipe) + self.assertEqual(3, datapipe._number_of_samples_yielded) + + # Functional Test: Check for reset behavior and if iterator also works + it = iter(datapipe) # reset the DataPipe + with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"): + list(it) + self.assertEqual(3, datapipe._number_of_samples_yielded) + + def test_iterdatapipe_sample_yielded_return_self(self): + class _CustomGeneratorDataPipe(IterDataPipe): + # This class's `__iter__` is not a generator function + def __init__(self): + self.source = iter(range(10)) + + def __iter__(self): + return self.source + + def reset(self): + self.source = iter(range(10)) + + datapipe: IterDataPipe = _CustomGeneratorDataPipe() + self._yield_count_test_helper(datapipe, n_expected_samples=5) + + def test_iterdatapipe_sample_yielded_next(self): + class _CustomNextDataPipe(IterDataPipe): + # This class's `__iter__` returns `self` and has a `__next__` + def __init__(self): + self.source = iter(range(10)) + + def __iter__(self): + return self + + def __next__(self): + return next(self.source) + + def reset(self): + self.source = iter(range(10)) + + datapipe: IterDataPipe = _CustomNextDataPipe() + self._yield_count_test_helper(datapipe, n_expected_samples=5) + + def test_iterdatapipe_sample_yielded_next_exception(self): + class _CustomNextDataPipe(IterDataPipe): + # This class's `__iter__` returns `self` and has a `__next__` + def __init__(self): + self.source = iter(range(10)) + self.count = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.count == 3: + raise RuntimeError("Custom test error after yielding 3 elements") + self.count += 1 + return next(self.source) + + def reset(self): + self.count = 0 + self.source = iter(range(10)) + + # Functional Test: Ensure the count is correct even when exception is raised + datapipe: IterDataPipe = _CustomNextDataPipe() + with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"): + list(datapipe) + self.assertEqual(3, datapipe._number_of_samples_yielded) + + # Functional Test: Check for reset behavior and if iterator also works + it = iter(datapipe) # reset the DataPipe + with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"): + list(it) + self.assertEqual(3, datapipe._number_of_samples_yielded) + + +class _CustomNonGeneratorTestDataPipe(IterDataPipe): + def __init__(self): + self.n = 10 + self.source = list(range(self.n)) + + # This class's `__iter__` is not a generator function + def __iter__(self): + return iter(self.source) + + def __len__(self): + return self.n + + +class _CustomSelfNextTestDataPipe(IterDataPipe): + def __init__(self): + self.n = 10 + self.iter = iter(range(self.n)) + + def __iter__(self): + return self + + def __next__(self): + return next(self.iter) + + def reset(self): + self.iter = iter(range(self.n)) + + def __len__(self): + return self.n + + +class TestIterDataPipeGraphFastForward(TestCase): + + def _fast_forward_graph_test_helper(self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None): + if rng is None: + rng = torch.Generator() + rng = rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng) + + # Test Case: fast forward works with list + rng.manual_seed(0) + fast_forward_fn(datapipe, n_iterations, rng) + actual_res = list(datapipe) + self.assertEqual(len(datapipe) - n_iterations, len(actual_res)) + self.assertEqual(expected_res[n_iterations:], actual_res) + + # Test Case: fast forward works with iterator + rng.manual_seed(0) + fast_forward_fn(datapipe, n_iterations, rng) + it = iter(datapipe) + actual_res = list(it) + self.assertEqual(len(datapipe) - n_iterations, len(actual_res)) + self.assertEqual(expected_res[n_iterations:], actual_res) + with self.assertRaises(StopIteration): + next(it) + + def test_simple_snapshot_graph(self): + graph1 = dp.iter.IterableWrapper(range(10)) + res1 = list(range(10)) + self._fast_forward_graph_test_helper(graph1, _simple_graph_snapshot_restoration, + expected_res=res1) + + graph2 = graph1.map(_mul_10) + res2 = [10 * x for x in res1] + self._fast_forward_graph_test_helper(graph2, _simple_graph_snapshot_restoration, + expected_res=res2) + + rng = torch.Generator() + graph3 = graph2.shuffle() + rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(graph3, rng) + res3 = list(graph3) + self._fast_forward_graph_test_helper(graph3, _simple_graph_snapshot_restoration, + expected_res=res3) + + graph4 = graph3.map(_mul_10) + res4 = [10 * x for x in res3] + self._fast_forward_graph_test_helper(graph4, _simple_graph_snapshot_restoration, + expected_res=res4) + + batch_size = 2 + graph5 = graph4.batch(batch_size) + res5 = [res4[i:i + batch_size] for i in range(0, len(res4), batch_size)] # .batch(2) + self._fast_forward_graph_test_helper(graph5, _simple_graph_snapshot_restoration, + expected_res=res5) + + # With `fork` and `zip` + cdp1, cdp2 = graph5.fork(2) + graph6 = cdp1.zip(cdp2) + rng = rng.manual_seed(100) + torch.utils.data.graph_settings.apply_shuffle_seed(graph6, rng) + res6 = [(x, x) for x in res5] + self._fast_forward_graph_test_helper(graph6, _simple_graph_snapshot_restoration, + expected_res=res6) + + # With `fork` and `concat` + graph7 = cdp1.concat(cdp2) + res7 = res5 * 2 + self._fast_forward_graph_test_helper(graph7, _simple_graph_snapshot_restoration, + expected_res=res7) + + # Raises an exception if the graph has already been restored + with self.assertRaisesRegex(RuntimeError, "Snapshot restoration cannot be applied."): + _simple_graph_snapshot_restoration(graph7, 1) + _simple_graph_snapshot_restoration(graph7, 1) + + def test_simple_snapshot_custom_non_generator(self): + graph = _CustomNonGeneratorTestDataPipe() + self._fast_forward_graph_test_helper(graph, _simple_graph_snapshot_restoration, expected_res=range(10)) + + def test_simple_snapshot_custom_self_next(self): + graph = _CustomSelfNextTestDataPipe() + self._fast_forward_graph_test_helper(graph, _simple_graph_snapshot_restoration, expected_res=range(10)) + + def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None): + """ + Extend the previous test with serialization and deserialization test. + """ + if rng is None: + rng = torch.Generator() + rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng) + it = iter(datapipe) + for _ in range(n_iter): + next(it) + serialized_graph = pickle.dumps(datapipe) + deserialized_graph = pickle.loads(serialized_graph) + self.assertEqual(n_iter, datapipe._number_of_samples_yielded) + self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded) + + rng_for_deserialized = torch.Generator() + rng_for_deserialized.manual_seed(0) + _simple_graph_snapshot_restoration(deserialized_graph, n_iter, rng=rng_for_deserialized) + self.assertEqual(expected_res[n_iter:], list(it)) + self.assertEqual(expected_res[n_iter:], list(deserialized_graph)) + + def test_simple_snapshot_graph_with_serialization(self): + graph1 = dp.iter.IterableWrapper(range(10)) + res1 = list(range(10)) + self._snapshot_test_helper(graph1, expected_res=res1) + + graph2 = graph1.map(_mul_10) + res2 = [10 * x for x in res1] + self._snapshot_test_helper(graph2, expected_res=res2) + + rng = torch.Generator() + graph3 = graph2.shuffle() + rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(graph3, rng) + res3 = list(graph3) + self._snapshot_test_helper(graph3, expected_res=res3) + + graph4 = graph3.map(_mul_10) + res4 = [10 * x for x in res3] + self._snapshot_test_helper(graph4, expected_res=res4) + + batch_size = 2 + graph5 = graph4.batch(batch_size) + res5 = [res4[i:i + batch_size] for i in range(0, len(res4), batch_size)] # .batch(2) + self._snapshot_test_helper(graph5, expected_res=res5) + + # With `fork` and `zip` + cdp1, cdp2 = graph5.fork(2) + graph6 = cdp1.zip(cdp2) + res6 = [(x, x) for x in res5] + self._snapshot_test_helper(graph6, expected_res=res6) + + # With `fork` and `concat` + graph7 = cdp1.concat(cdp2) + res7 = res5 * 2 + self._snapshot_test_helper(graph7, expected_res=res7) + + def test_simple_snapshot_graph_repeated(self): + cdp1, cdp2 = dp.iter.IterableWrapper(range(10)).map(_mul_10).shuffle().map(_mul_10).map(_mul_10).fork(2) + graph = cdp1.zip(cdp2) + + rng = torch.Generator() + rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(graph, rng) + + # Get expected result + expected_res = list(graph) + + rng.manual_seed(0) + torch.utils.data.graph_settings.apply_shuffle_seed(graph, rng) + it = iter(graph) + n_iter = 3 + for _ in range(n_iter): + next(it) + + # First serialization/deserialization + serialized_graph = pickle.dumps(graph) + deserialized_graph = pickle.loads(serialized_graph) + + rng_for_deserialized = torch.Generator() + rng_for_deserialized.manual_seed(0) + _simple_graph_snapshot_restoration(deserialized_graph, deserialized_graph._number_of_samples_yielded, + rng=rng_for_deserialized) + + it = iter(deserialized_graph) + # Get the next element and ensure it is as expected + self.assertEqual(expected_res[3], next(it)) + + # Serializalize/Deserialize and fast-forward again after to ensure it works + serialized_graph2 = pickle.dumps(deserialized_graph) + deserialized_graph2 = pickle.loads(serialized_graph2) + + rng_for_deserialized = torch.Generator() + rng_for_deserialized.manual_seed(0) + _simple_graph_snapshot_restoration(deserialized_graph2, deserialized_graph._number_of_samples_yielded, + rng=rng_for_deserialized) + + # Get the next element and ensure it is as expected + self.assertEqual(expected_res[4:], list(deserialized_graph2)) + + if __name__ == '__main__': run_tests() diff --git a/test/test_decomp.py b/test/test_decomp.py index 78c59f3654d64..16f64a0229da8 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,6 +15,7 @@ suppress_warnings, TEST_WITH_ASAN, run_tests, + skipIfSlowGradcheckEnv, ) from torch.testing._internal.common_device_type import ( onlyNativeDeviceTypes, @@ -155,6 +156,8 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs) (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5, (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, + (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-6, + (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-6, } if ref.is_floating_point(): orig_diff = (orig - ref).abs().max() @@ -238,10 +241,8 @@ def wrapped(*primals): def upcast_tensor(x, dtype=torch.float32): if isinstance(x, Tensor) and x.dtype.is_floating_point: return x.to(dtype=dtype) - elif ( - isinstance(x, torch.dtype) - and x in [torch.float16, torch.bfloat16] - ): + elif (isinstance(x, torch.dtype) + and x in [torch.float16, torch.bfloat16, torch.float]): return dtype else: return x @@ -270,10 +271,15 @@ def normalize_op_input_output(f, sample, requires_grad=True): ("cuda", torch.bfloat16, "nn.functional.dropout"), ("cuda", torch.float64, "nn.functional.dropout"), ("cuda", torch.float32, "nn.functional.dropout"), + (None, None, "new_empty"), # decomp has problem even with opmath # doesn't work ("cuda", torch.bfloat16, "nn.functional.embedding"), + # CompositeAutogradImplicit + # See https://github.com/pytorch/pytorch/issues/81669 + (None, None, "nn.functional.relu6"), + } all_decomposed = set() @@ -325,6 +331,7 @@ def test_unsupported(t): return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs)) +@skipIfSlowGradcheckEnv class TestDecomp(TestCase): longMessage = True @@ -352,7 +359,7 @@ def do_cross_ref(self, device, dtype, op, *, run_all): None, dtype, op.name, - ) in CROSS_REF_EXCLUDE_SET: + ) in CROSS_REF_EXCLUDE_SET or (None, None, op.name) in CROSS_REF_EXCLUDE_SET: self.skipTest(f"{op.name} in {dtype} not supported") test_dtype = dtype @@ -383,7 +390,10 @@ def _torch_dispatch(cls, func, types, args=(), kwargs=None): # Stuff we shouldn't bother testing # (TODO: remove detach from the decomp table?) if func not in decomposition_table or func in [ - torch.ops.aten.detach.default + torch.ops.aten.detach.default, + # non-deterministic ops + torch.ops.aten.new_empty.default, + torch.ops.aten.new_empty.SymInt ] or any_unsupported(args, kwargs): return func(*args, **kwargs) @@ -441,8 +451,10 @@ def _torch_dispatch(cls, func, types, args=(), kwargs=None): def check_decomposed(aten_name): self.assertTrue( any(overload_to_aten_name(c) == aten_name for c in decomposed), - msg=f"aten.{aten_name} was not decomposed, saw calls for: " - + ", ".join(map(str, list(called))), + msg=(f"aten.{aten_name} was not decomposed, saw calls for: " + f"{', '.join(map(str, list(called)))}. If your op is " + f"CompositeImplicitAutograd you should skip this test " + "by updating CROSS_REF_EXCLUDE_SET.") ) aten_name = op.decomp_aten_name or op.aten_name diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 73084486f42ab..3d5662177e5ae 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -8,6 +8,8 @@ import unittest import torch from torch.utils._pytree import tree_map +from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt + aten = torch.ops.aten try: @@ -32,7 +34,7 @@ def add_func(op): @register_meta([aten.add.Tensor, aten.sub.Tensor]) def binary_meta(a, b): - return a.new_empty(a.sym_size()) + return a.new_empty(a.shape) @register_meta(aten.cat.default) @@ -53,7 +55,7 @@ def cat_meta(tensors, dim=0): @register_meta([aten.narrow_copy.SymInt]) def narrow_copy_symint_meta(a, dim, start, length, **kwargs): shape = [] - for i, x in enumerate(a.sym_size()): + for i, x in enumerate(a.shape): if i == dim: shape.append(length) else: @@ -66,68 +68,6 @@ def expand_symint_meta(a, size, implicit=False): return a.new_empty(size) -class PySymInt(object): - def __init__(self, expr, shape_env): - self.expr = expr - self.shape_env = shape_env - - def wrap(self, num): - return PySymInt(sympy.Integer(num), self.shape_env) - - def __str__(self): - return f"PySymInt({self.expr})" - - def __int__(self): - return self.shape_env.evaluate_expr(self.expr) - - def __bool__(self): - return bool(self.shape_env.evaluate_expr(self.expr)) - - -magic_methods = { - 'add': lambda a, b: a + b, - 'radd': lambda a, b: a + b, - 'sub': lambda a, b: a - b, - 'mul': lambda a, b: a * b, - 'div': lambda a, b: a / b, - 'mod': lambda a, b: a % b, - 'eq': lambda a, b: sympy.Eq(a, b), - 'gt': lambda a, b: sympy.Gt(a, b), - 'lt': lambda a, b: sympy.Lt(a, b), -} - -for method, func in magic_methods.items(): - method_name = f'{method}' - - def create_magic_impl(func): - def magic_impl(self, other): - if isinstance(other, PySymInt): - other = other.expr - return PySymInt(func(self.expr, other), self.shape_env) - return magic_impl - - # this should be wrapped transparently into torch._C.SymbolicIntNode - setattr(PySymInt, method_name, create_magic_impl(func)) - - -class ShapeEnv(object): - def __init__(self): - self.guards = [] - self.shape_env = {} - - def create_symint(self, name, val): - sympy_expr = sympy.Symbol(name) - py_sym_int = PySymInt(sympy_expr, self) - cpp_sym_int = torch._C.SymbolicIntNode.new_symint(py_sym_int) - self.shape_env[sympy_expr] = val - return cpp_sym_int - - def evaluate_expr(self, expr): - concrete_val = expr.subs(self.shape_env) - self.guards.append((expr, concrete_val)) - return concrete_val - - def create_contiguous(shape): strides = [1] for dim in reversed(shape[:-1]): @@ -147,6 +87,8 @@ def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device): dtype=dtype, layout=layout, requires_grad=requires_grad, device=device, ) + + r.sym_shape = sym_shape return r __torch_function__ = _disabled_torch_function_impl @@ -159,6 +101,18 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): if func_overload in meta_funcs: return meta_funcs[func_overload](*args, **kwargs) + if func_overload == torch.ops.aten.sym_size.default: + self = args[0] + return self.sym_shape + + # some calls can be redirected to `sym_size` rather than + # `sym_sizes`. `sym_size` uses `dim` to canonicalize an index + # so we need to implement both `sym_size` and `dim` for python + # tensors + if func_overload == torch.ops.aten.dim.default: + self = args[0] + return len(self.sym_shape) + if func_overload == torch.ops.aten.new_empty.default: self = args[0] shape = args[1] @@ -168,12 +122,12 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): def create_symbolic_tensor(name, arg, shape_env): - sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())]) + sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())]) sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())]) return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device) -CPP_SYMINT_CLASS = type(torch._C.SymbolicIntNode.new_symint(1)) +CPP_SYMINT_CLASS = type(torch._C.SymIntNode.new_symint(1)) class TestPySymInt(TestCase): @@ -182,22 +136,22 @@ class TestPySymInt(TestCase): def test_roundtrip(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - self.assertTrue(not isinstance(x.sym_size(0), PySymInt)) - self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS)) + self.assertTrue(not isinstance(x.shape[0], PySymInt)) + self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) - self.assertEqual(int(x.sym_size(0)), 5) - self.assertEqual(int(x.sym_size(1)), 4) - self.assertEqual(int(x.sym_size(2)), 3) + self.assertTrue(x.shape[0] == 5) + self.assertTrue(x.shape[1] == 4) + self.assertTrue(x.shape[2], 3) - self.assertEqual(int(x.sym_size()[0]), 5) - self.assertEqual(int(x.sym_size()[1]), 4) - self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS)) - self.assertEqual(int(x.sym_size()[2]), 3) + self.assertTrue(x.size()[0], 5) + self.assertTrue(x.size()[1], 4) + self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) + self.assertTrue(x.size()[2] == 3) - self.assertEqual(int(x.sym_size(0)), 5) - self.assertEqual(int(x.sym_size(1)), 4) - self.assertEqual(int(x.sym_size(2)), 3) - self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS)) + self.assertTrue(x.size(0) == 5) + self.assertTrue(x.size(1) == 4) + self.assertTrue(x.size(2) == 3) + self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) @skipIfNoSympy def test_binary(self): @@ -206,16 +160,16 @@ def test_binary(self): y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env) z = x + y - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # broadcasting y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) z = x + y - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) @skipIfNoSympy def test_symint_args(self): @@ -223,16 +177,16 @@ def test_symint_args(self): x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env) LAST_DIM = 2 - z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM)) - self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2))) + z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM]) + self.assertTrue(z.shape[2] == int(y.shape[2])) # arithmetic expr with two symints - z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM)) - self.assertEqual(int(z.sym_size(2)), 2) + z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM]) + self.assertTrue(z.shape[2] == 2) # arithmetic expr with a symint and python int - z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1) - self.assertEqual(int(z.sym_size(2)), 2) + z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) + self.assertTrue(z.shape[2] == 2) @skipIfNoSympy def test_symint_vargs(self): @@ -241,56 +195,70 @@ def test_symint_vargs(self): y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) # varargs - z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2)) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand(x.shape[0], y.shape[1], x.shape[2]) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # shape list - z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2))) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand((x.shape[0], y.shape[1], x.shape[2])) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # mixed python symints and ints - z = y.expand(x.sym_size(0), y.sym_size(1), 3) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand(x.shape[0], y.shape[1], 3) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # mixed python symints and ints in a list - z = y.expand((x.sym_size(0), y.sym_size(1), 3)) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand((x.shape[0], y.shape[1], 3)) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # mixed python symints and ints - z = y.expand(5, y.sym_size(1), x.sym_size(2)) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand(5, y.shape[1], x.shape[2]) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) # mixed python ints and symints in a list - z = y.expand((5, y.sym_size(1), x.sym_size(2))) - self.assertEqual(int(z.sym_size(0)), 5) - self.assertEqual(int(z.sym_size(1)), 4) - self.assertEqual(int(z.sym_size(2)), 3) + z = y.expand((5, y.shape[1], x.shape[2])) + self.assertTrue(z.shape[0] == 5) + self.assertTrue(z.shape[1] == 4) + self.assertTrue(z.shape[2] == 3) + + z = y.expand((y.shape[1],)) + z = y.expand(y.shape[1]) @skipIfNoSympy def test_size_expressions(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) - expand_x = x.expand(x.sym_size(0), x.sym_size(0)) - if expand_x.sym_size(0) > 3: + expand_x = x.expand(x.shape[0], x.shape[0]) + if expand_x.shape[0] > 3: result = expand_x + expand_x else: result = expand_x + expand_x gt_op = shape_env.guards[0][0] self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) - self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0])) - self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0))) - self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0))) + self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) + self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) + self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) + + @skipIfNoSympy + def test_aten_ops(self): + + shape_env = ShapeEnv() + x = create_symbolic_tensor("x", torch.randn(5), shape_env) + torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0]) + + shape_env = ShapeEnv() + x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) + torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]]) def test_fx_trace_intlist(self): class CustomModule(torch.nn.Module): @@ -305,5 +273,13 @@ def forward(self, x): torch.fx.symbolic_trace(m) + @skipIfNoSympy + def test_meta_symint(self): + shape_env = ShapeEnv() + a0 = shape_env.create_symint("a0", 2) + r = torch.empty(a0, device='meta') + self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) + + if __name__ == '__main__': run_tests() diff --git a/test/test_dynamo_cudagraphs.py b/test/test_dynamo_cudagraphs.py new file mode 100644 index 0000000000000..f64380a32ebfb --- /dev/null +++ b/test/test_dynamo_cudagraphs.py @@ -0,0 +1,192 @@ +# Owner(s): ["module: cuda graphs"] + +import functools +import sys + +from unittest.mock import patch + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + +try: + import functorch # noqa: F401 + import torchdynamo + from torch.cuda._dynamo_graphs import aot_autograd_cudagraphs + + TEST_DYNAMO = True +except ImportError: + TEST_DYNAMO = False + +TEST_CUDA = torch.cuda.is_available() + +if not TEST_CUDA or not TEST_DYNAMO: + print("CUDA or dynamo not available, skipping tests", file=sys.stderr) + TestCase = object # noqa: F811 + + +def composed(*decs): + def deco(f): + for dec in reversed(decs): + f = dec(f) + return f + + return deco + + +def assert_aot_autograd_counter(ok=True): + def deco(f): + @functools.wraps(f) + def wrap(self, *args, **kwargs): + torchdynamo.utils.counters.clear() + r = f(self, *args, **kwargs) + c_ok = torchdynamo.utils.counters["aot_autograd"]["ok"] + c_not_ok = torchdynamo.utils.counters["aot_autograd"]["not_ok"] + if ok: + self.assertGreater(c_ok, 0) + self.assertEqual(c_not_ok, 0) + else: + self.assertEqual(c_ok, 0) + self.assertGreater(c_not_ok, 0) + return r + + return wrap + + return deco + + +def patch_all(ok=True): + return composed( + patch("torchdynamo.config.verify_correctness", True), + assert_aot_autograd_counter(ok), + ) + + +N_ITERS = 5 + + +class TestDynamoCudaGraphs(TestCase): + @patch_all() + def test_basic(self): + def model(x, y): + return (x + y) * y + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + x = torch.randn(3, device="cuda", requires_grad=True) + y = torch.randn(3, device="cuda") + loss = model(x, y).sum() + loss.backward() + + @patch_all() + def test_dtoh(self): + def model(x, y): + a = x + y + b = a.cpu() * 3 + return b + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + x = torch.randn(3, device="cuda", requires_grad=True) + y = torch.randn(3, device="cuda") + loss = model(x, y).sum() + loss.backward() + + @patch_all() + def test_htod(self): + def model(x, y): + a = x + y + return a * 3 + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + x = torch.randn(3, device="cuda", requires_grad=True) + y = torch.randn((), device="cpu") + loss = model(x, y).sum() + loss.backward() + + @patch("functorch._src.config.use_functionalize", True) + @patch_all(ok=False) # input mutation not supported yet + def test_mutate_input(self): + def model(x, y): + y.add_(3) + return x * y + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + with self.subTest(i): + x = torch.randn(3, device="cuda", requires_grad=True) + y = torch.randn(3, device="cuda") + y_orig = y.clone() + loss = model(x, y).sum() + self.assertEqual(y, y_orig + 3) + loss.backward() + + @patch_all() + def test_mutate_constant(self): + def model(x, y): + c = torch.tensor(1) + c.add_(2) + return x * y * 0 + c + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + with self.subTest(i): + x = torch.randn(1, device="cuda", requires_grad=True) + y = torch.randn(1, device="cuda") + loss = model(x, y).sum() + self.assertEqual(loss, torch.tensor(3.0, device="cuda")) + loss.backward() + + @patch_all() + def test_factory(self): + def model(y): + x = torch.zeros(3, device="cuda:0") + x.add_(3) + return x * y + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + with self.subTest(i): + y = torch.randn(3, device="cuda:0", requires_grad=True) + loss = model(y).sum() + loss.backward() + + @patch("functorch._src.config.use_functionalize", True) + @patch_all() + def test_mutated_metadata(self): + # more tortured example at + # https://github.com/pytorch/pytorch/issues/81385 + def model(x): + x = x.clone() + x.resize_(20) + x.fill_(2) + return x + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + with self.subTest(i): + x = torch.empty(0, device="cuda:0") + rx = model(x) + self.assertEqual(rx, torch.full((20,), 2.0, device="cuda:0")) + + @patch("functorch._src.config.use_functionalize", True) + @patch_all() + def test_dead_fill(self): + def model(x): + x = x.clone() + y = x[0:0] + x.fill_(2) + y.fill_(3) + return x, y + + with torchdynamo.optimize(aot_autograd_cudagraphs): + for i in range(N_ITERS): + with self.subTest(i): + x = torch.empty(20, device="cuda:0") + rx, ry = model(x) + self.assertEqual(rx, torch.full((20,), 2.0, device="cuda:0")) + self.assertEqual(ry, torch.empty(0, device="cuda:0")) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index a1eb96019cfd0..1b1f24d3eee6e 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -1,5 +1,5 @@ # Owner(s): ["module: nn"] - +from dataclasses import dataclass from functools import partial from itertools import product, chain import unittest @@ -12,7 +12,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests -from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests +from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize from torch.testing._internal.common_methods_invocations import SampleInput, op_db from torch.nn.utils._expanded_weights import ExpandedWeight from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \ @@ -27,8 +27,12 @@ def test_forward_helper(self, device): weight = torch.randn(5, 4, device=device) bias = torch.randn(5, device=device) for (weight_batched, bias_batched) in product([True, False], [True, False]): - maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight - maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias + maybe_batched_weight = weight + maybe_batched_bias = bias + if weight_batched: + maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum") + if bias_batched: + maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum") args = (input, maybe_batched_weight, maybe_batched_bias) expanded_args, expanded_kwargs = standard_kwargs(('bias',), args) res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) @@ -45,7 +49,7 @@ def test_forward_helper_failure_args(self, device): weight = torch.randn(5, 4, device=device) bias = torch.randn(5, device=device) with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."): - input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3) + input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum") expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias)) forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"): @@ -61,18 +65,22 @@ def test_forward_helper_failure_args(self, device): for (weight_batched, bias_batched) in product([True, False], [True, False]): if not weight_batched and not bias_batched: continue - maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight - maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias + maybe_batched_weight = weight + maybe_batched_bias = bias + if weight_batched: + maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum") + if bias_batched: + maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum") with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"): expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias)) forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) def test_set_grad_sample_if_exists(self, device): - def test_fn(_): + def test_fn(a): return True orig_weight = torch.randn(4, device=device, requires_grad=True) - expanded_weight = ExpandedWeight(orig_weight, 3) + expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") set_grad_sample_if_exists(expanded_weight, test_fn) self.assertTrue(hasattr(orig_weight, 'grad_sample')) self.assertTrue(orig_weight.grad_sample) @@ -86,7 +94,7 @@ def test_fn(_): self.assertFalse(hasattr(non_tensor, 'grad_sample')) def test_set_grad_sample_if_exists_failure(self, device): - def test_fn(_): + def test_fn(a): return True grad_tensor = torch.randn(4, requires_grad=True, device=device) @@ -95,7 +103,7 @@ def test_fn(_): def test_unpack_expanded_weight_or_tensor(self, device): input = torch.randn(3, requires_grad=True, device=device) - self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3))) + self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"))) input.requires_grad_(False) self.assertEqual(input, unpack_expanded_weight_or_tensor(input)) @@ -103,7 +111,7 @@ def test_unpack_expanded_weight_or_tensor(self, device): def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device): input = torch.randn(3, requires_grad=True, device=device) - self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input)) + self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input)) input.requires_grad_(False) self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input)) @@ -131,38 +139,68 @@ def test_sum_over_all_but_batch_and_last_n(self, device): self.assertEqual(res, input) class TestExpandedWeightFunctional(TestCase): + def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction): + input = sample_input.input + args = sample_input.args + kwargs = sample_input.kwargs + batch_size = input.shape[0] if len(input.shape) > 1 else 1 + + # get per sample grads with ExpandedWeights objects + loss_reduction = "sum" if reduction == torch.sum else "mean" + (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction) + diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) + diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] + diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list] + if not diff_input_list: + return + result = run_op(op, ew_input, *ew_args, **ew_kwargs) + reduction(result).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ + expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list) + + # get per sample grads with for loop + func = partial(run_op, op) + + per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs) + + # check equality + self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) + if loss_reduction == "mean": + # don't check equality of `input.grad`s since these vanilla tensors won't be scaled + expanded_weight_grad = expanded_weight_grad[1:] + per_sample_grad = per_sample_grad[1:] + for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad): + self.assertEqual(result_grad, expected_grad) + @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) - def test_expanded_weight_per_sample_grad(self, device, dtype, op): + def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) for sample_input in supported_inputs(op, sample_inputs): if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) - input = sample_input.input - args = sample_input.args - kwargs = sample_input.kwargs - batch_size = input.shape[0] if len(input.shape) > 1 else 1 - - # get per sample grads with ExpandedWeights objects - (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size) - diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) - diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] - diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list] - if not diff_input_list: - continue - result = run_op(op, ew_input, *ew_args, **ew_kwargs) - result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ - expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list) - # get per sample grads with for loop - func = partial(run_op, op) - per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs) + def reduction(x): + return x.sum() + + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum) + + @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) + def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in supported_inputs(op, sample_inputs): + if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests + sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) + + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) + + @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) + def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in supported_inputs(op, sample_inputs): + if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests + sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) + sample_input.input.requires_grad_(False) - # check equality - self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) - for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad): - if result_grad is None: - result_grad = torch.zeros_like(expected_grad) - self.assertEqual(result_grad, expected_grad) + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) def test_unsupported_expand_weights(self, device, dtype, op): @@ -195,20 +233,54 @@ def test_expanded_weight_forward(self, device, dtype, op): if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs: self.skipTest("embedding is non-determinstic in this case, see issue #74679") batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1 - (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size) - expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) - normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs) - self.assertEqual(expanded_weight_result, normal_result) + for loss_reduction in ["sum", "mean"]: + (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction) + expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) + normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs) + self.assertEqual(expanded_weight_result, normal_result) def test_expanded_weight_error(self, device): batch_size = 3 sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True) sample_weight = make_tensor((4), dtype=torch.float32, device=device, requires_grad=True) with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"): - torch.add(sample_input, ExpandedWeight(sample_weight, batch_size)) + torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum")) + + def _test_embedding_model(self, model, num_embedding, device): + batch_size = 32 + input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device) + return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device) + + def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum"): + batch_size = 32 + input_ending = [input_size] * num_dim + input = torch.randn([batch_size, 3] + input_ending, device=device) + return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction) + + def _test_model(self, model, batch_size, input, device, loss_reduction="sum"): + model = model(10).to(device) + targets = torch.randint(0, 10, (batch_size,), device=device) + criterion = CrossEntropyLoss(reduction=loss_reduction) + result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input) + loss = criterion(result, targets) + loss.backward() + result = [] + for weight in model.parameters(): + result.append(weight.grad_sample) + del weight.grad_sample + + expected = [] + for i in range(batch_size): + loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) + expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))) + + expected = [torch.stack(grad) for grad in zip(*expected)] + for (res, exp) in zip(result, expected): + self.assertEqual(res, exp, atol=1e-4, rtol=5e-5) + - def test_small_model(self, device): - def convnet(num_classes): + def test_cnn_model_sum(self, device): + def convnet(num_classes, num_dim): return nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(), @@ -226,27 +298,75 @@ def convnet(num_classes): nn.Linear(128, num_classes, bias=True), ) - batch_size = 32 - model = convnet(10).to(device) - input = torch.randn([batch_size, 3, 28, 28], device=device) - targets = torch.randint(0, 10, (batch_size,), device=device) - criterion = CrossEntropyLoss(reduction='sum') # use a loss that doesn't average across the batch to test in a for loop - result = call_for_per_sample_grads(model, batch_size, input) - loss = criterion(result, targets) - loss.backward() - result = [] - for weight in model.parameters(): - result.append(weight.grad_sample) - del weight.grad_sample + return self._test_conv_model(convnet, 28, 2, device) - expected = [] - for i in range(batch_size): - loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) - expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))) + def test_cnn_model_mean(self, device): + def convnet(num_classes, num_dim): + return nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(128, num_classes, bias=True), + ) - expected = [torch.stack(grad) for grad in zip(*expected)] - for (res, exp) in zip(result, expected): - self.assertEqual(res, exp, atol=1e-4, rtol=5e-5) + return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean") + + @parametrize('num_dim', [1, 2, 3]) + def test_instance_norm_model(self, num_dim, device): + def instance_norm_model(num_classes, num_dim): + conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + norm_layer = nn.InstanceNorm1d if num_dim == 1 else nn.InstanceNorm2d if num_dim == 2 else nn.InstanceNorm3d + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + norm_layer(32, affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), + ) + return self._test_conv_model(instance_norm_model, 7, num_dim, device) + + @parametrize('num_dim', [1, 2, 3]) + def test_group_norm_model(self, num_dim, device): + def group_norm_model(num_classes, num_dim): + conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, 32, affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), + ) + return self._test_conv_model(group_norm_model, 7, num_dim, device) + + @parametrize('num_dim', [1, 2, 3]) + def test_layer_norm_model(self, num_dim, device): + def layer_norm_model(num_classes, num_dim): + conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + normalized_shape = [7] * num_dim + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + nn.LayerNorm(normalized_shape, elementwise_affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), + ) + return self._test_conv_model(layer_norm_model, 7, num_dim, device) + + def test_embedding_model(self, device): + def embedding_model(num_classes, num_embedding): + return nn.Sequential( + nn.Embedding(num_embedding, 15), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(375, num_classes, bias=True) + ) + return self._test_embedding_model(embedding_model, 16, device) def test_group_norm_error(self, device): # group norm has to call native_group_norm. This checks that it hits the same errors @@ -266,7 +386,7 @@ def _do_test(self, module, input): input.requires_grad_() with freeze_rng_state(): # get per sample grads with ExpandedWeights context manager - actual_res = call_for_per_sample_grads(module, batch_size, input).sum() + actual_res = call_for_per_sample_grads(module, loss_reduction="sum")(input).sum() actual_res.backward() actual_grads = [] for param in module.parameters(): @@ -308,7 +428,7 @@ def forward(self, input): with freeze_rng_state(): # get per sample grads with ExpandedWeights context manager, calling .backward() twice test_module = TestModule(module) - actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum() + actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(input).sum() actual_res.backward() actual_grads = [] for param in module.parameters(): @@ -337,15 +457,72 @@ def test_per_sample_api_failing(self): module = nn.Linear(10, 10) input = torch.randn(64, 10) with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"): - call_for_per_sample_grads("fail", 64, input) - with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"): - call_for_per_sample_grads(module, 6.4, input) + call_for_per_sample_grads("fail")(input) + with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be None or an integer"): + call_for_per_sample_grads(module, batch_size=6.4)(input) with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"): - call_for_per_sample_grads(module, -64, input) + call_for_per_sample_grads(module, batch_size=-64)(input) with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"): - loss = call_for_per_sample_grads(module, 64, input).sum() + loss = call_for_per_sample_grads(module)(input).sum() loss.backward() # populate grad_sample fields - call_for_per_sample_grads(module, 64, input) + call_for_per_sample_grads(module)(input) + + module = nn.Linear(10, 10) # reset to not have grad_sample fields + with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"): + call_for_per_sample_grads(module, loss_reduction="")(input) + + def test_per_sample_api_compute_batch_size(self): + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, input1, input2): + return self.linear(input1) + self.linear(input2) + + module = CustomModule() + input1 = torch.randn(4, 5) + input2 = torch.randn(5, 5) + + with self.assertRaisesRegex(RuntimeError, "found at least one input with batch size 4 and one with batch size 5"): + call_for_per_sample_grads(module)(input1, input2) + + input2 = torch.randn(4, 5) + call_for_per_sample_grads(module)(input1, input2) + + module = CustomModule() + call_for_per_sample_grads(module)(input1, input2=input2) + + module = CustomModule() + call_for_per_sample_grads(module)(input1=input1, input2=input2) + + def test_per_sample_api_compute_batch_size_not_pytreeable(self): + @dataclass + class NonPytreeableTuple: + elem1: torch.Tensor + elem2: torch.Tensor + + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, input1, input2): + return self.linear(input1.elem1) + self.linear(input1.elem2) + + input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5)) + model = CustomModule() + with self.assertRaisesRegex(RuntimeError, "ExpandedWeights cannot compute the batch size from the inputs"): + call_for_per_sample_grads(model)(input, "") + + # would prefer for it to error because input is not pytree-able but that's hard to detect + with self.assertRaisesRegex(RuntimeError, "Expected ExpandedWeights to have batch size matching input"): + call_for_per_sample_grads(model)(input, torch.randn(5)) + + model = CustomModule() # TODO: functional call bug, sam will fix + call_for_per_sample_grads(model)(input, torch.randn(4, 5)) + model = CustomModule() + call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5)) class ContextManagerTests(TestBase): def __init__(self, *args, **kwargs): @@ -378,10 +555,17 @@ def test_context_manager_multiple_inputs(self, test_case, device): raise unittest.SkipTest("Can't get per sample gradients for input of rank 1") test_case._do_test_multi_input(module, input) +def filter_supported_tests(t): + supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm', 'InstanceNorm'] + if 'module_name' in t and t['module_name'] in supported_modules: + return True + if 'fullname' in t and any([module + "_" in t['fullname'] for module in supported_modules]): + return not('Conv' in t['fullname'] and 'pad' in t['fullname']) + # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # These currently use the legacy nn tests supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm', 'InstanceNorm'] -supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules] +supported_tests = [t for t in module_tests + new_module_tests if filter_supported_tests(t)] for test_param in supported_tests: if 'constructor' not in test_param: name = test_param.pop('module_name') @@ -418,9 +602,11 @@ def run_op(op, input, *args, **kwargs): else: return op(input, *args, **kwargs) -def make_expanded_weight(sample_input, batch_size): +def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"): def expanded_weight_or_clone(arg): - return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg) + if is_diff_tensor(arg): + return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction) + return clone_if_tensor(arg) ew_input = clone_if_tensor(sample_input.input) ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args) @@ -434,14 +620,16 @@ def supported_inputs(op, sample_inputs, supported_inputs=True): """ def filter_fn(input): convolutions = ["nn.functional.conv1d", "nn.functional.conv2d", "nn.functional.conv3d"] + batched_input_size = dict(zip(convolutions, [3, 4, 5])) if op.name == "nn.functional.linear": - is_supported_input = len(input.input.shape) > 1 # input of rank 1 means no batch dim + is_supported_input = input.input.dim() > 1 # input of rank 1 means no batch dim elif op.name == "nn.functional.layer_norm": normalized_shape = input.args[0] is_supported_input = input.input.shape != normalized_shape # would cause inter-batch operations elif op.name in convolutions: # currently can't deal with padding computation on Python level is_supported_input = 'padding' not in input.kwargs or not isinstance(input.kwargs['padding'], str) + is_supported_input = is_supported_input and input.input.dim() == batched_input_size[op.name] elif op.name == "nn.functional.embedding": idx = input.args[0] is_supported_input = len(idx.shape) > 1 # there's no batch size @@ -451,12 +639,12 @@ def filter_fn(input): return is_supported_input if supported_inputs else not is_supported_input return [input for input in sample_inputs if filter_fn(input)] -def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs): +def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs): # get per sample grads by getting derivative for each input in a for loop per_sample_grad = [] for i in range(batch_size): per_sample_input = input[i] - result = func(per_sample_input.unsqueeze(0), *args, **kwargs) + result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs)) diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values()) diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad] per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True)) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 8cfa4507b4ac9..6ea09369ff2ac 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1,6 +1,6 @@ # Owner(s): ["module: meta tensors"] -from torch.testing._internal.common_utils import TestCase, run_tests, skipIfCrossRef +from torch.testing._internal.common_utils import TestCase, run_tests, skipIfCrossRef, skipIfRocm import torch import itertools from torch.testing._internal.jit_utils import RUN_CUDA @@ -10,9 +10,12 @@ FakeTensorConverter, DynamicOutputShapeException, ) +from torch.testing import FileCheck from torch.utils._python_dispatch import enable_torch_dispatch_mode +from torch import nn import unittest import torch._prims as prims +import contextlib import copy class FakeTensorTest(TestCase): @@ -39,6 +42,13 @@ def test_parameter_instantiation(self): y = torch.nn.parameter.Parameter(x) self.assertTrue(isinstance(y, torch.nn.Parameter)) + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_index_cuda_with_cpu(self): + with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): + x = torch.rand([2048], device='cuda') + out = x[torch.zeros([36], dtype=torch.int64)] + self.checkType(out, "cuda", [36]) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_shape_take_not_device(self): with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): @@ -79,6 +89,25 @@ def test_type_as(self): self.assertEqual(out.device.type, "cuda") self.assertTrue(isinstance(out, FakeTensor)) + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_setitem(self): + for device in ["cpu", "cuda"]: + with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): + x = torch.rand([16, 1], device=device) + x[..., 0] = 0 + + def test_fake_dispatch_keys(self): + with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): + x = torch.rand([4]) + f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU") + f.run(torch._C._dispatch_key_set(x)) + + with torch.inference_mode(): + x = torch.rand([4]) + y = x + x + FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y)) + FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y)) + def test_constructor(self): with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): x = torch.rand([4, 4], device="cpu") @@ -109,6 +138,16 @@ def test_fake_mode_error(self): with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): y = x[0] + def test_fake_grad_copy(self): + x = torch.rand([4, 4], requires_grad=True) + x.grad = torch.rand([4, 4]) + mode = FakeTensorMode() + fake_x = mode.from_tensor(x) + prims.utils.compare_tensor_meta(fake_x, x) + prims.utils.compare_tensor_meta(fake_x.grad, x.grad) + + self.assertTrue(isinstance(fake_x.grad, FakeTensor)) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_like_constructor(self): with enable_torch_dispatch_mode(FakeTensorMode(inner=None)): @@ -130,20 +169,21 @@ def test_binary_op_type_promotion(self): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cpu_fallback(self): - with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=False)): + with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=False)): filters = torch.randn(8, 4, 3, 3).cuda() inputs = torch.randn(1, 4, 5, 5).cuda() - with self.assertRaises(NotImplementedError): - torch.nn.functional.conv2d(inputs, filters, padding=1) + out = torch.nn.functional.conv2d(inputs, filters, padding=1) + self.assertEqual(out.device.type, "cuda") + self.assertEqual(list(out.size()), [1, 8, 5, 5]) - with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)): + with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)): # intentionally bad inputs filters = torch.randn(8, 20, 3, 3).cuda() inputs = torch.randn(1, 7, 10, 5).cuda() with self.assertRaises(RuntimeError): torch.nn.functional.conv2d(inputs, filters, padding=1) - with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)): + with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)): filters = torch.randn(8, 4, 3, 3).cuda() inputs = torch.randn(1, 4, 5, 5).cuda() @@ -151,9 +191,127 @@ def test_cpu_fallback(self): self.assertEqual(out.device.type, "cuda") self.assertEqual(list(out.size()), [1, 8, 5, 5]) + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_normalize_device(self): + with FakeTensorMode(): + x = torch.empty(1, device="cuda") + y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}") + out = x + y + self.checkType(out, "cuda", [1]) + + @skipIfRocm + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_cudnn_rnn(self): + def fn( + a0, + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + b8, + b9, + b10, + b11, + b12, + b13, + b14, + b15, + a3, + a4, + a5, + ): + a1 = [ + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + b8, + b9, + b10, + b11, + b12, + b13, + b14, + b15, + ] + return torch.ops.aten._cudnn_rnn( + a0, + a1, + 4, + a3, + a4, + a5, + 2, + 2048, + 0, + 2, + False, + 0.0, + False, + True, + [], + None, + ) + + mode = FakeTensorMode(inner=None) + for i, context in enumerate([contextlib.nullcontext, lambda: enable_torch_dispatch_mode(mode)]): + with context(): + inps = ( + torch.randn([92, 8, 2048]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192, 4096]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192, 4096]).cuda(), + torch.randn([8192, 2048]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([8192]).cuda(), + torch.randn([167837696]).cuda(), + torch.randn([4, 8, 2048]).cuda(), + torch.randn([4, 8, 2048]).cuda(), + ) + out = fn(*inps) + self.assertIs(out[4], inps[-3]) + for ten in out: + if i == 1: + self.assertTrue(isinstance(ten, FakeTensor)) + self.assertTrue(ten.device.type == 'cuda') + + @skipIfRocm + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_fallback_memory_prop(self): + m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half) + m = m.to(memory_format=torch.channels_last) + mode = FakeTensorMode(inner=None) + # TODO: module.to() doesn't work because it assigns .data, which is ignored + with torch._subclasses.fake_tensor.FakeCopyMode(mode): + mod_copied = copy.deepcopy(m) + + with enable_torch_dispatch_mode(mode): + input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last) + out = mod_copied(input) + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.checkType(out, "cuda", [20, 33, 24, 49]) + def test_data_dependent_operator(self): with enable_torch_dispatch_mode( - FakeTensorMode(inner=None, allow_cpu_fallback=False) + FakeTensorMode(inner=None, allow_fallback_kernels=False) ): x = torch.rand([10, 10]) @@ -199,7 +357,9 @@ def test_new(self): a = torch.rand([16, 1]) self.checkType(a.new(10, 10), "cpu", [10, 10]) self.checkType(a.new([1, 2, 3, 4]), "cpu", [4]) - self.checkType(a.new(device='cuda'), "cuda", [0]) + b = torch.rand([4, 4], device='cuda') + self.checkType(b.new(device='cuda'), "cuda", [0]) + def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type): return maybe_contained_type.isSubtypeOf(type) or any( @@ -219,7 +379,7 @@ def test_memoized_conversion_from_meta(self): converter = mode.fake_tensor_converter self.assertTrue(converter(mode, x, "cpu") is converter(mode, x, "cpu")) - def test_separate_tensor_storages(self): + def test_separate_tensor_storages_view(self): x = torch.rand(2, 2, 2) y = x[0] mode = FakeTensorMode(inner=None) @@ -228,6 +388,26 @@ def test_separate_tensor_storages(self): y_conv = converter(mode, y) self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv)) + def test_separate_tensor_storages_non_view(self): + x = torch.rand(2, 2, 2) + y = torch.rand(4, 2) + y.set_(x.storage()) + mode = FakeTensorMode(inner=None) + converter = mode.fake_tensor_converter + x_conv = converter(mode, x) + y_conv = converter(mode, y) + stor_id = torch._C._storage_id(x_conv) + self.assertEqual(stor_id, torch._C._storage_id(y_conv)) + del x + self.assertEqual(len(converter.tensor_memo), 1) + converter.meta_converter.check_for_expired_weak_storages() + self.assertEqual(len(converter.meta_converter.storage_memo), 1) + del y + self.assertEqual(len(converter.tensor_memo), 0) + converter.meta_converter.check_for_expired_weak_storages() + self.assertEqual(len(converter.meta_converter.storage_memo), 0) + + def test_dead_weak_ref(self): x = torch.rand(2, 2, 2) y = x[0] @@ -240,6 +420,17 @@ def test_dead_weak_ref(self): y_conv = converter(mode, y) self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv)) + def test_dead_key(self): + x = torch.rand(2, 2, 2) + mode = FakeTensorMode(inner=None) + converter = FakeTensorConverter() + x_conv = converter(mode, x) + self.assertEqual(len(converter.tensor_memo), 1) + self.assertEqual(len(converter.meta_converter.tensor_memo), 1) + del x + self.assertEqual(len(converter.tensor_memo), 0) + self.assertEqual(len(converter.meta_converter.tensor_memo), 0) + def test_no_active_mode(self): mode = FakeTensorMode(inner=None) with enable_torch_dispatch_mode(mode): @@ -260,13 +451,13 @@ def test_separate_mode_error(self): def test_no_ref_cycle(self): x = torch.rand([4]) - mode = torch._prims.utils.get_prim_fake_mode() + mode = torch._prims.get_prim_fake_mode() y = mode.from_tensor(x) - assert mode is torch._prims.utils.get_prim_fake_mode() + assert mode is torch._prims.get_prim_fake_mode() self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1) del mode del y - new_mode = torch._prims.utils.get_prim_fake_mode() + new_mode = torch._prims.get_prim_fake_mode() self.assertEqual(len(new_mode.fake_tensor_converter.tensor_memo), 0) diff --git a/test/test_functional_optim.py b/test/test_functional_optim.py index 862030f324028..7dd0276fb6c97 100644 --- a/test/test_functional_optim.py +++ b/test/test_functional_optim.py @@ -1,11 +1,14 @@ # Owner(s): ["oncall: distributed"] +from typing import List, Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from torch.optim import SGD, Adam, AdamW from torch.testing._internal.common_utils import TestCase, run_tests -from torch.distributed.optim.utils import functional_optim_map +from torch.distributed.optim.utils import functional_optim_map, register_functional_optim class MyModule(torch.nn.Module): def __init__(self): @@ -17,6 +20,49 @@ def __init__(self): def forward(self, t1): return self.lin2(F.relu(self.lin1(t1))) +# dummy class to showcase custom optimizer registration with functional wrapper +class MyDummyFnOptimizer(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + _allow_empty_param_list: bool = False, + ): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 < weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + # call the custom optimizer step_param implementation + with torch.no_grad(): + raise RuntimeError("MyDummyFnOptimizer does not support step_param() as of now") + + def step(self, gradients: List[Optional[Tensor]]): + # call the custom optimizer step implementation + with torch.no_grad(): + raise RuntimeError("MyDummyFnOptimizer does not support step() as of now") class TestFunctionalOptimParity(TestCase): def _validate_parameters(self, params_1, params_2): @@ -80,6 +126,17 @@ def _test_functional_optim_parity(self, optim_cls, *args, **kwargs): self.assertNotEqual(old_module_optim_params[i], optim_param) self.assertNotEqual(old_module_functional_params[i], functional_param) + def _test_functional_optim_registration(self): + fn_map_key = "MyDummyFnOptimizer" + fn_optim = MyDummyFnOptimizer + register_functional_optim(fn_map_key, fn_optim) + functional_optim_cls = functional_optim_map.get(fn_map_key, None) + if not functional_optim_cls: + raise ValueError(f"Functional optimizer not registered for {fn_map_key}") + + def test_functional_optim_registration(self): + self._test_functional_optim_registration() + def test_functional_optim_parity_sgd(self): self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 59837372ccb5f..54645503470a3 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -1,11 +1,13 @@ # Owner(s): ["module: codegen"] import torch -from torch.testing._internal.common_utils import TestCase, run_tests -from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs, log_input +from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO +from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs from torch.utils._pytree import tree_map +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.reinplace import reinplace -import logging +import unittest def are_aliased(x, y): if x._base is None and y._base is None: @@ -16,91 +18,76 @@ def are_aliased(x, y): return y._base is x return x._base is y._base -# Just for testing: a logging tensor that also transforms out-of-place ops into inplace ops. -# That way even if the outer wrapper is functionalized, the inner wrapper will also need functionalization. -class InplaceLoggingTensor(LoggingTensorReentrant): - @staticmethod - def __new__(cls, e): - r = torch.Tensor._make_wrapper_subclass(cls, e.shape, dtype=e.dtype, requires_grad=False) - r.elem = e - return r - - __torch_function__ = torch._C._disabled_torch_function_impl - - def __str__(self): - return f'InplaceLoggingTensor({self.elem})' - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(e): - if isinstance(e, InplaceLoggingTensor): - return e.elem - else: - return e - - def wrap(e): - if isinstance(e, torch.Tensor): - return InplaceLoggingTensor(e) - else: - return e - f = func - # this subclass converts all `add()` ops into `add_()` ops - if f is torch.ops.aten.add.Tensor: - f = torch.ops.aten.add_.Tensor - - with cls.context(): - rs = tree_map(wrap, f(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) - # after running the (potentially transformed) op, - # log the original op that we saw. - logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) - return rs - - - +# We can unify testing and use functionalize() here instead +# if/when functorch moves into core. +# This is basically a crappy version of `functionalize()` for single-tensor-arg inputs. +def _functionalize(f, *, reapply_views: bool): + def wrapped(a): + input_functional = torch._to_functional_tensor(a) + torch._enable_functionalization(reapply_views=reapply_views) + try: + out = f(input_functional) + finally: + torch._disable_functionalization() + torch._sync(input_functional) + inpt_new = torch._from_functional_tensor(input_functional) + if inpt_new is not a: + # Existing deficiency in functionalize(): + # we don't correctly mutate input metadata (yet?) + if inpt_new.shape == a.shape: + a.copy_(inpt_new) + tree_map(torch._sync, out) + out_unwrapped = tree_map(torch._from_functional_tensor, out) + return out_unwrapped + + return wrapped + +@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457") class TestFunctionalization(TestCase): - def get_logs(self, func, inpt, *, reapply_views=False): - input_clone_logging = LoggingTensor(inpt.clone()) - input_functional_logging = torch._to_functional_tensor(input_clone_logging) + def get_logs(self, func, inpt, *, reapply_views=False, run_reinplace=False): + inpt_clone = inpt.clone() + traced_f = make_fx(_functionalize(func, reapply_views=reapply_views))(inpt) + if run_reinplace: + traced_f = reinplace(traced_f, inpt_clone) + return traced_f.code - with capture_logs() as logs: - log_input("input", input_clone_logging) - torch._enable_functionalization(reapply_views=reapply_views) - try: - func(input_functional_logging) - finally: - torch._disable_functionalization() - return logs - - def assert_functionalization(self, func, inpt, *, reapply_views=False): + def assert_functionalization(self, func, inpt, *, reapply_views=False, mutated_input_metadata=False): input_clone = inpt.clone() input_clone2 = inpt.clone() - input_functional = torch._to_functional_tensor(input_clone2) + input_clone3 = inpt.clone() # Compare outputs (and mutated inputs), with and without functionalization. out_ref = func(inpt) - - torch._enable_functionalization(reapply_views=reapply_views) - try: - out_functional = func(input_functional) - finally: - torch._disable_functionalization() - - # We need to sync the input tensors first, in case there are any queued mutations left. - torch._sync(input_functional) - self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur + out_functional = _functionalize(func, reapply_views=reapply_views)(input_clone) + # The reinplacing pass is only valid to run with reapply_views=True. + functional_func = make_fx(_functionalize(func, reapply_views=True))(input_clone2) + reinplace_func = reinplace(make_fx(_functionalize(func, reapply_views=True))(input_clone2), input_clone2) + + # NOTE: for now, need to pass in fresh inputs here, because make_fx + # will directly mutate the inputs that you trace with. + # Once this is fixed we can clean this up. + out_reinplace = reinplace_func(input_clone3) + + # functionalize() deficiency: input metadata mutations aren't propagated properly, + # so we just need to skip checks here for the tests that exercise that. + if not mutated_input_metadata: + self.assertEqual(inpt, input_clone) # input mutations should still occur + self.assertEqual(inpt, input_clone3) # Handle tests with multi-tensor outputs - if isinstance(out_ref, tuple) and isinstance(out_functional, tuple): - out_refs, out_functionals = list(out_ref), list(out_functional) + if isinstance(out_ref, tuple): + out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace) else: - out_refs, out_functionals = [out_ref], [out_functional] + out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace] + + for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces): + self.assertEqual(out_ref_, out_functional_) + self.assertEqual(out_ref_, out_reinplace_) - for out_ref_, out_functional_ in zip(out_refs, out_functionals): - self.assertEqual(out_ref_.size(), out_functional_.size()) - torch._sync(out_functional_) - out_functional_unwrapped = torch._from_functional_tensor(out_functional_) - self.assertEqual(out_ref_, out_functional_unwrapped) + def test_save_for_backwards_segfault(self): + inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True) + inp.exp() def test_multiple_views_of_same_base(self): def f(x): @@ -124,15 +111,34 @@ def f(x): return y self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [4, 2]) -$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.], - [1., 1.], - [1., 1.], - [1., 1.]])) -$3 = torch._ops.aten.view_copy.default($2, [4, 2]) -$4 = torch._ops.aten.mul.Tensor($3, $3)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]) + mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1) + copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None + return add_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_default = torch.ops.aten.view.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None + view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]) + mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1) + copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None + return add_tensor + """) def test_simple_out(self): def f(x): @@ -145,14 +151,32 @@ def f(x): return w self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [4, 2]) -$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.], - [1., 1.], - [1., 1.], - [1., 1.]])) -$3 = torch._ops.aten.mul.Tensor($2, $2)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None + empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None + mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None + return mul_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None + empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None + mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None + return mul_tensor + """) def test_multi_out(self): def f(x): @@ -164,9 +188,32 @@ def f(x): return out_max self.assert_functionalization(f, torch.arange(8, dtype=torch.float32)) logs = self.get_logs(f, torch.arange(8, dtype=torch.float32)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1, $2 = torch._ops.aten.aminmax.default($0, dim=0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + empty_1 = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None + getitem = aminmax_default[0] + getitem_1 = aminmax_default[1]; aminmax_default = None + return getitem + """) + + reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + empty_1 = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None + getitem = aminmax_default[0] + getitem_1 = aminmax_default[1]; aminmax_default = None + return getitem + """) def test_tensor_ctr(self): def f(x): @@ -174,7 +221,50 @@ def f(x): z = y.view(-1) z.add_(1) return y - self.assert_functionalization(f, torch.arange(3, dtype=torch.float32)) + + inpt = torch.arange(3, dtype=torch.float32) + self.assert_functionalization(f, inpt) + + logs = self.get_logs(f, inpt) + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + _tensor_constant0 = self._tensor_constant0 + lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None + view_copy_default = torch.ops.aten.view_copy.default(lift_fresh, [-1]); lift_fresh = None + add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [3]); add_tensor = None + return view_copy_default_1 + """) + + reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + _tensor_constant0 = self._tensor_constant0 + lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None + view_default = torch.ops.aten.view.default(lift_fresh, [-1]); lift_fresh = None + add_tensor = torch.ops.aten.add_.Tensor(view_default, 1) + view_default_1 = torch.ops.aten.view.default(view_default, [3]); view_default = None + return view_default_1 + """) + + + def test_tensor_list_mixed_functional_nonfunctional(self): + nonfunctional_tensor = torch.ones(2, dtype=torch.long) + + def f(x): + # simple test: 1 view op, 1 inplace op + functional_tensor = torch.ones(2, dtype=torch.long) + out = x[functional_tensor, nonfunctional_tensor] + return out + out = f(torch.ones(2, 2)) + out_functional = _functionalize(f, reapply_views=True)(torch.ones(2, 2)) + self.assertEqual(out, out_functional) def test_inplace_on_non_view(self): def f(x): @@ -186,13 +276,32 @@ def f(x): return y self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [4, 2]) -$2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.], - [1., 1.], - [1., 1.], - [1., 1.]]))""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None + copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None + return view_copy_default_1 + """) + + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_default = torch.ops.aten.view.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None + copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None + view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None + return view_default_1 + """) # Some ops that are mutable are neither inplace nor out= ops. # They also need special handling. @@ -201,9 +310,21 @@ def f(x): return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0) logs = self.get_logs(f, torch.ones(1)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + _fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0) + getitem = _fused_moving_avg_obs_fq_helper_functional_default[0] + getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1] + getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2] + getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3] + getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4] + getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None + copy__default = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None + return (getitem, getitem_1) + """) # noqa: B950 def test_as_strided(self): def f(x): @@ -212,10 +333,17 @@ def f(x): return x self.assert_functionalization(f, torch.ones(9)) logs = self.get_logs(f, torch.ones(9)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.as_strided_copy.default($0, [2], [2], 1) -$2 = torch._ops.aten.add.Tensor($1, 1)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1) + add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None + as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); add_tensor = None + copy__default = torch.ops.aten.copy_.default(a_1, as_strided_scatter_default); a_1 = None + return as_strided_scatter_default + """) def test_tensor_list_composite(self): def f(x): @@ -224,11 +352,14 @@ def f(x): return y self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.block_diag.default([LoggingTensor(tensor([[1., 1.], - [1., 1.]])), LoggingTensor(tensor([[1., 1.], - [1., 1.]]))])""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None + return block_diag_default + """) def test_cat(self): def f(x): @@ -237,27 +368,64 @@ def f(x): return out self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.cat.default([LoggingTensor(tensor([[1., 1.], - [1., 1.]]))])""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None + return cat_default + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None + return cat_default + """) + def test_diagonal(self): def f(x): # test: view ops that take a subset of the original tensor (select/diagonal) tmp = torch.ones(2) - y = x.diagonal() + y = x.clone().diagonal() y.add_(tmp) z = x * x return z self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.diagonal_copy.default($0) -$2 = torch._ops.aten.add.Tensor($1, tensor([1., 1.])) -$3 = torch._ops.aten.diagonal_scatter.default($0, $2) -$4 = torch._ops.aten.mul.Tensor($3, $3)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + clone_default = torch.ops.aten.clone.default(a_1) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(clone_default); clone_default = None + add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None + mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None + return mul_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + clone_default = torch.ops.aten.clone.default(a_1) + diagonal_default = torch.ops.aten.diagonal.default(clone_default); clone_default = None + add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, ones); diagonal_default = ones = None + mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None + return mul_tensor + """) def test_diagonal_mutated_input(self): def f(x): @@ -268,6 +436,19 @@ def f(x): return x x = torch.ones(2, 2) self.assert_functionalization(f, x) + logs = self.get_logs(f, torch.ones(2, 2)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1) + add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None + diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); add_tensor = None + copy__default = torch.ops.aten.copy_.default(a_1, diagonal_scatter_default); a_1 = None + return diagonal_scatter_default + """) def test_split(self): def f(x): @@ -280,15 +461,26 @@ def f(x): return y3 self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1, $2 = torch._ops.aten.split_copy.Tensor($0, 2) -$3 = torch._ops.aten.diagonal_copy.default($2) -$4 = torch._ops.aten.add.Tensor($3, tensor([1., 1.])) -$5, $6 = torch._ops.aten.split_copy.Tensor($0, 2) -$7 = torch._ops.aten.diagonal_scatter.default($6, $4) -$8 = torch._ops.aten.slice_scatter.default($0, $7, 0, 2, 4) -$9 = torch._ops.aten.mul.Tensor($8, $8)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2) + getitem = split_copy_tensor[0] + getitem_1 = split_copy_tensor[1]; split_copy_tensor = None + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None + add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None + split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2) + getitem_2 = split_copy_tensor_1[0] + getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None + diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None + slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); diagonal_scatter_default = None + mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default) + copy__default = torch.ops.aten.copy_.default(a_1, slice_scatter_default); a_1 = slice_scatter_default = None + return add_tensor + """) # noqa: B950 def test_view_inplace(self): def f(x): @@ -298,13 +490,23 @@ def f(x): y = x[0] y.add_(tmp) return x - self.assert_functionalization(f, torch.ones(4, 2)) + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.transpose_copy.int($0, 1, 0) -$2 = torch._ops.aten.select_copy.int($1, 0, 0) -$3 = torch._ops.aten.add.Tensor($2, tensor([1., 1., 1., 1.]))""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0) + select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None + add_tensor = torch.ops.aten.add.Tensor(select_copy_int, ones); select_copy_int = ones = None + transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None + select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None + transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None + transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None + return transpose_copy_int_3 + """) # noqa: B950 def test_optional_tensor_list(self): def f(x): @@ -317,10 +519,19 @@ def f(x): return y self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [8]) -$2 = torch._ops.aten.index_put.default($1, [tensor([0, 1, 2, 3])], tensor([0., 1., 2., 3.]))""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]) + arange = torch.ops.aten.arange.default(4, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None + view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2]) + copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None + return index_put_default + """) # noqa: B950 def test_scalars(self): def f(x): @@ -333,24 +544,71 @@ def f(x): return z self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [4, 2]) -$2 = torch._ops.aten.add.Tensor($1, 1) -$3 = torch._ops.aten.mul.Tensor($2, 2) -$4 = torch._ops.aten.div.Tensor($3, 1)""") + self.assertExpectedInline(logs, """\ + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None + mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2) + div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None + copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None + return div_tensor + """) + + @skipIfTorchDynamo("Test does not work with TorchDynamo") def test_metadata_change(self): def f(x): # ops like ge_() are allowed to change the dtype of the input. # functionalization should pick up on that. - return x.ge_(0) + y = x.clone() + out = y.ge_(0) + return out self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.ge.Scalar($0, 0) -$2 = torch._ops.aten._to_copy.default($1, dtype=torch.float32, layout=torch.strided)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + clone_default = torch.ops.aten.clone.default(a_1); a_1 = None + ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None + _to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None + return _to_copy_default + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + clone_default = torch.ops.aten.clone.default(a_1); a_1 = None + ge_scalar = torch.ops.aten.ge_.Scalar(clone_default, 0) + _to_copy_default = torch.ops.aten._to_copy.default(clone_default, dtype = torch.float32, layout = torch.strided); clone_default = None + return _to_copy_default + """) # noqa: B950 + + @skipIfTorchDynamo("Test does not work with TorchDynamo") + def test_metadata_change_out_op(self): + def f(t, y): + out_1 = torch.ones(1) + return torch.add(t, y, out=out_1) + + inpt1, inpt2 = torch.tensor([1]), torch.tensor([1]) + inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2) + + out_ref = f(inpt1, inpt2) + torch._enable_functionalization(reapply_views=True) + try: + out_functional = f(inpt1_func, inpt2_func) + finally: + torch._disable_functionalization() + self.assertEqual(out_ref, torch._from_functional_tensor(out_functional)) + def test_only_one_view(self): def f(x): @@ -359,9 +617,14 @@ def f(x): # so there should be a total of 1 op in the output trace. return x.view(4, 2) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view_copy.default($0, [4, 2])""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None + return view_copy_default + """) def test_everything(self): def f(x): @@ -379,35 +642,79 @@ def f(x): return z2 self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.add.Tensor($0, $0) -$2 = torch._ops.aten.view_copy.default($1, [8]) -$3 = torch._ops.aten._reshape_alias_copy.default($2, [2, 4], [4, 1]) -$4 = torch._ops.aten.transpose_copy.int($3, 1, 0) -$5 = torch._ops.aten.unsqueeze_copy.default($4, 0) -$6 = torch._ops.aten.squeeze_copy.default($5) -$7, $8 = torch._ops.aten.split_copy.Tensor($6, 2) -$9 = torch._ops.aten.add.Tensor($7, tensor([[1., 1.], - [1., 1.]])) -$10 = torch._ops.aten.select_copy.int($3, 0, 0) -$11 = torch._ops.aten.clone.default($9, memory_format=torch.contiguous_format) -$12 = torch._ops.aten._unsafe_view.default($11, [4]) -$13 = torch._ops.aten.view_copy.default($1, [8]) -$14 = torch._ops.aten._reshape_alias_copy.default($13, [2, 4], [4, 1]) -$15 = torch._ops.aten.transpose_copy.int($14, 1, 0) -$16 = torch._ops.aten.unsqueeze_copy.default($15, 0) -$17 = torch._ops.aten.squeeze_copy.default($16) -$18 = torch._ops.aten.slice_scatter.default($17, $9, 0, 0, 2) -$19 = torch._ops.aten.unsqueeze_copy.default($18, 0) -$20 = torch._ops.aten.squeeze_copy.dim($19, 0) -$21 = torch._ops.aten.transpose_copy.int($20, 1, 0) -$22 = torch._ops.aten._reshape_alias_copy.default($21, [8], [1]) -$23 = torch._ops.aten.view_copy.default($22, [4, 2]) -$24 = torch._ops.aten.view_copy.default($23, [8]) -$25 = torch._ops.aten._reshape_alias_copy.default($24, [2, 4], [4, 1]) -$26 = torch._ops.aten.select_copy.int($25, 0, 0) -$27 = torch._ops.aten.add.Tensor($26, $12)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8]) + _reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None + transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0) + unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None + squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None + split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None + getitem = split_copy_tensor[0] + getitem_1 = split_copy_tensor[1]; split_copy_tensor = None + add_tensor_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None + select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None + clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format) + _unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None + view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None + _reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None + transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None + unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None + squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None + slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None + unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None + squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None + transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None + _reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None + view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None + view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None + _reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None + select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None + add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None + return add_tensor_1 + """) # noqa: B950 + + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + view_default = torch.ops.aten.view.default(add_tensor, [8]) + _reshape_alias_default = torch.ops.aten._reshape_alias.default(view_default, [2, 4], [4, 1]); view_default = None + transpose_int = torch.ops.aten.transpose.int(_reshape_alias_default, 1, 0) + unsqueeze_default = torch.ops.aten.unsqueeze.default(transpose_int, 0); transpose_int = None + squeeze_default = torch.ops.aten.squeeze.default(unsqueeze_default); unsqueeze_default = None + split_tensor = torch.ops.aten.split.Tensor(squeeze_default, 2); squeeze_default = None + getitem = split_tensor[0] + getitem_1 = split_tensor[1]; split_tensor = None + add_tensor_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None + select_int = torch.ops.aten.select.int(_reshape_alias_default, 0, 0); _reshape_alias_default = None + clone_default = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format) + _unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None + view_default_1 = torch.ops.aten.view.default(add_tensor, [8]); add_tensor = None + _reshape_alias_default_1 = torch.ops.aten._reshape_alias.default(view_default_1, [2, 4], [4, 1]); view_default_1 = None + transpose_int_1 = torch.ops.aten.transpose.int(_reshape_alias_default_1, 1, 0); _reshape_alias_default_1 = None + unsqueeze_default_1 = torch.ops.aten.unsqueeze.default(transpose_int_1, 0); transpose_int_1 = None + squeeze_default_1 = torch.ops.aten.squeeze.default(unsqueeze_default_1); unsqueeze_default_1 = None + unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(squeeze_default_1, 0); squeeze_default_1 = None + squeeze_dim = torch.ops.aten.squeeze.dim(unsqueeze_default_2, 0); unsqueeze_default_2 = None + transpose_int_2 = torch.ops.aten.transpose.int(squeeze_dim, 1, 0); squeeze_dim = None + _reshape_alias_default_2 = torch.ops.aten._reshape_alias.default(transpose_int_2, [8], [1]); transpose_int_2 = None + view_default_2 = torch.ops.aten.view.default(_reshape_alias_default_2, [4, 2]); _reshape_alias_default_2 = None + view_default_3 = torch.ops.aten.view.default(view_default_2, [8]); view_default_2 = None + _reshape_alias_default_3 = torch.ops.aten._reshape_alias.default(view_default_3, [2, 4], [4, 1]); view_default_3 = None + select_int_1 = torch.ops.aten.select.int(_reshape_alias_default_3, 0, 0); _reshape_alias_default_3 = None + add_tensor_2 = torch.ops.aten.add.Tensor(select_int_1, _unsafe_view_default); select_int_1 = _unsafe_view_default = None + return getitem + """) def test_reapply_views_simple(self): def f(x): @@ -418,15 +725,19 @@ def f(x): return y self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True) logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.view.default($0, [4, 2]) -$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.], - [1., 1.], - [1., 1.], - [1., 1.]])) -$3 = torch._ops.aten.view.default($2, [4, 2]) -$4 = torch._ops.aten.mul.Tensor($3, $3)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + view_default = torch.ops.aten.view.default(a_1, [4, 2]) + add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None + view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]) + mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1) + copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None + return add_tensor + """) def test_aliases_maintained_after_pass_when_reapplying_views(self): def f(x): @@ -455,8 +766,6 @@ def f(x): def test_copy_(self): def f(x): tmp = torch.zeros(2, 2) - # NOTE: LoggingTensor isn't a mode, which means that the diagonal call - # will not be logged. This is fine for testing. tmp_slice = tmp.diagonal() y = tmp_slice.copy_(x) z = y.add_(x) @@ -466,34 +775,131 @@ def f(x): # to() is a composite op that noops when the dtype/shape match, so nothing gets logged. # self.assert_functionalization(f, torch.ones(2)) logs = self.get_logs(f, torch.ones(2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0) -$2 = torch._ops.aten.add.Tensor($1, $0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None + copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None + add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None + return add_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None + copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1) + add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None + return diagonal_default + """) # Test 2: copy_() with same dtype, different shape self.assert_functionalization(f, torch.ones(1)) logs = self.get_logs(f, torch.ones(1)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0) -$2 = torch._ops.aten.add.Tensor($1, $0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None + copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None + add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None + return add_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None + copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1) + add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None + return diagonal_default + """) # Test 3: copy_() with different dtype, same shape self.assert_functionalization(f, torch.ones(2, dtype=torch.long)) logs = self.get_logs(f, torch.ones(2, dtype=torch.long)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0) -$2 = torch._ops.aten.add.Tensor($1, $0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None + copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None + add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None + return add_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None + copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1) + add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None + return diagonal_default + """) # noqa: B950 # Test 4: copy_() with different dtype, different shape self.assert_functionalization(f, torch.ones(1, dtype=torch.long)) logs = self.get_logs(f, torch.ones(1, dtype=torch.long)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0) -$2 = torch._ops.aten.add.Tensor($1, $0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None + copy_default = torch.ops.aten.copy.default(diagonal_copy_default, a_1); diagonal_copy_default = None + add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None + return add_tensor + """) + + reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None + copy_default = torch.ops.aten.copy_.default(diagonal_default, a_1) + add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, a_1); a_1 = None + return diagonal_default + """) # noqa: B950 + + def test_expand_symint(self): + # Once some existing SymInt bugs are ironed out, we should update + # this test to plumb FakeSymbolicTensors through it + def f(x): + return x.expand(x.size(0), x.size(1)) + + self.assert_functionalization(f, torch.ones(2, 2)) + logs = self.get_logs(f, torch.ones(2, 2)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None + return expand_copy_default + """) def test_fill_(self): def f(x): @@ -504,11 +910,29 @@ def f(x): self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.add.Tensor($0, $0) -$2 = torch._ops.aten.diagonal_copy.default($1) -$3 = torch._ops.aten.fill.Scalar($2, 0)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor) + fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None + diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None + return diagonal_scatter_default + """) + + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + diagonal_default = torch.ops.aten.diagonal.default(add_tensor) + fill_scalar = torch.ops.aten.fill_.Scalar(diagonal_default, 0); diagonal_default = None + return add_tensor + """) def test_resize_smaller(self): def f(w): @@ -523,22 +947,49 @@ def f(w): self.assert_functionalization(f, torch.ones(8, 2)) logs = self.get_logs(f, torch.ones(8, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.add.Tensor($0, 1) -$2 = torch._ops.aten.view_copy.default($1, [4, 4]) -$3 = torch._ops.aten.resize.functional($2, [3, 3]) -$4 = torch._ops.aten.as_strided_copy.default($2, [3, 3], [3, 1]) -$5 = torch._ops.aten.view_copy.default($4, [-1]) -$6 = torch._ops.aten.add.Tensor($5, 1) -$7 = torch._ops.aten.view_copy.default($1, [4, 4]) -$8 = torch._ops.aten.as_strided_copy.default($7, [3, 3], [3, 1]) -$9 = torch._ops.aten.view_copy.default($6, [3, 3]) -$10 = torch._ops.aten.as_strided_scatter.default($7, $9, [3, 3], [3, 1]) -$11 = torch._ops.aten.view_copy.default($10, [8, 2]) -$12 = torch._ops.aten.view_copy.default($11, [4, 4]) -$13 = torch._ops.aten.as_strided_copy.default($12, [3, 3], [3, 1]) -$14 = torch._ops.aten.add.Tensor($13, 1)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None + view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4]) + resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3]) + as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None + view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None + add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None + view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None + as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1]) + view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None + as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None + view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None + view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None + as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None + add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None + return add_tensor_2 + """) # noqa: B950 + + reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None + view_default = torch.ops.aten.view.default(add_tensor, [4, 4]) + resize_default = torch.ops.aten.resize.default(view_default, [3, 3]) + as_strided_default = torch.ops.aten.as_strided.default(view_default, [3, 3], [3, 1]); view_default = None + view_default_1 = torch.ops.aten.view.default(as_strided_default, [-1]); as_strided_default = None + add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_1, 1) + view_default_2 = torch.ops.aten.view.default(add_tensor, [4, 4]); add_tensor = None + as_strided_default_1 = torch.ops.aten.as_strided.default(view_default_2, [3, 3], [3, 1]) + view_default_3 = torch.ops.aten.view.default(view_default_1, [3, 3]); view_default_1 = None + view_default_4 = torch.ops.aten.view.default(view_default_2, [8, 2]); view_default_2 = None + view_default_5 = torch.ops.aten.view.default(view_default_4, [4, 4]); view_default_4 = None + as_strided_default_2 = torch.ops.aten.as_strided.default(view_default_5, [3, 3], [3, 1]); view_default_5 = None + add_tensor_2 = torch.ops.aten.add_.Tensor(as_strided_default_2, 1) + return as_strided_default_2 + """) def test_resize_larger_valid(self): def f(x): @@ -559,14 +1010,34 @@ def f(x): self.assert_functionalization(f, torch.ones(8, 2)) logs = self.get_logs(f, torch.ones(8, 2)) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.aten.add.Tensor($0, 1) -$2 = torch._ops.aten.resize.functional($1, [5, 5]) -$3 = torch._ops.aten.view_copy.default($2, [25]) -$4 = torch._ops.aten.fill.Scalar($3, 1) -$5 = torch._ops.aten.view_copy.default($4, [5, 5]) -$6 = torch._ops.aten.add.Tensor($5, 1)""") + self.assertExpectedInline(logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None + resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None + view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None + fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None + view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None + add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1) + return (view_copy_default_1, add_tensor_1) + """) + + reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, a_1): + add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None + resize_default = torch.ops.aten.resize_.default(add_tensor, [5, 5]) + view_default = torch.ops.aten.view.default(add_tensor, [25]); add_tensor = None + fill_scalar = torch.ops.aten.fill_.Scalar(view_default, 1) + view_default_1 = torch.ops.aten.view.default(view_default, [5, 5]); view_default = None + add_tensor_1 = torch.ops.aten.add.Tensor(view_default_1, 1) + return (view_default_1, add_tensor_1) + """) def test_resize_larger_invalid(self): def f(x): @@ -631,46 +1102,36 @@ def test_mixed_wrappers_invalid(self): with self.assertRaises(RuntimeError): x1_not_functional.add_(x2_functional) - # This tests the behavior of functionalization with multiple layers of wrapped tensor subclasses. - def test_multiple_levels_of_wrapping(self): + def test_index_mutation_on_non_input(self): def f(x): - # call an inplace op and have it get logged twice (by the outer + inner wrapper) - x.add_(1) - - # Test 1: both the inner and outer wrapper are "functionalized" - x_inner_and_outer_functional = torch._to_functional_tensor( - InplaceLoggingTensor(torch._to_functional_tensor(LoggingTensor(torch.ones(4))))) - - with capture_logs() as logs: - f(x_inner_and_outer_functional) + tmp = torch.zeros(10) + tmp[5].fill_(1) + return tmp + self.assert_functionalization(f, torch.ones(2)) + logs = self.get_logs(f, torch.ones(2)) + self.assertExpectedInline(logs, """\ - # Since both wrappers were unctionalized, they both log "add" - self.assertExpectedInline('\n'.join(logs), """\ -$1 = torch._ops.aten.add.Tensor($0, 1) -$3 = torch._ops.aten.add.Tensor($2, 1)""") - # Test 2: only the inner wrapper is "functionalized" - x_only_inner_functional = InplaceLoggingTensor(torch._to_functional_tensor(LoggingTensor(torch.ones(4)))) - with capture_logs() as logs: - f(x_only_inner_functional) +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + select_copy_int = torch.ops.aten.select_copy.int(zeros, 0, 5) + fill_scalar = torch.ops.aten.fill.Scalar(select_copy_int, 1); select_copy_int = None + select_scatter_default = torch.ops.aten.select_scatter.default(zeros, fill_scalar, 0, 5); zeros = fill_scalar = None + return select_scatter_default + """) # noqa: B950 - # Since only the inner wrapper is functionalized, then the inner (first) log is functionalized - self.assertExpectedInline('\n'.join(logs), """\ -$1 = torch._ops.aten.add.Tensor($0, 1) -$3 = torch._ops.aten.add_.Tensor($2, 1)""") + reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True) + self.assertExpectedInline(reinplaced_logs, """\ - # Test 3: only the inner wrapper is "functionalized" - x_only_outer_functional = torch._to_functional_tensor(InplaceLoggingTensor(LoggingTensor(torch.ones(4)))) - with capture_logs() as logs: - f(x_only_outer_functional) - # Only the outer add_ is functionalized - # Since only the outer wrapper is functionalized, then the outer (second) log is functionalized - self.assertExpectedInline('\n'.join(logs), """\ -$1 = torch._ops.aten.add_.Tensor($0, 1) -$3 = torch._ops.aten.add.Tensor($2, 1)""") +def forward(self, a_1): + zeros = torch.ops.aten.zeros.default([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) + select_int = torch.ops.aten.select.int(zeros, 0, 5) + fill_scalar = torch.ops.aten.fill_.Scalar(select_int, 1); select_int = None + return zeros + """) if __name__ == '__main__': run_tests() diff --git a/test/test_fx.py b/test/test_fx.py index 625614d29a8e3..f69d5046cc9e3 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import builtins import contextlib @@ -42,6 +42,9 @@ from fx.test_dce_pass import TestDCE # noqa: F401 from fx.test_fx_const_fold import TestConstFold # noqa: F401 from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow # noqa: F401 +from fx.test_pass_infra import TestPassManager # noqa: F401 +from fx.test_common_passes import TestCommonPass # noqa: F401 +from fx.test_cse_pass import TestCSEPass # noqa: F401 if sys.version_info >= (3, 7): from fx.test_gradual_type import AnnotationsTest # noqa: F401 @@ -52,9 +55,9 @@ IS_FBCODE, IS_MACOS, IS_WINDOWS, - TEST_WITH_ROCM, find_library_location, run_tests, + skipIfSlowGradcheckEnv, ) from torch.testing._internal.jit_utils import JitTestCase @@ -149,7 +152,7 @@ def setUp(self): self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations torch.fx.proxy.TracerBase.check_mutable_operations = True - if not (TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS): + if not (IS_FBCODE or IS_WINDOWS or IS_MACOS): lib_file_path = find_library_location('libtorchbind_test.so') torch.ops.load_library(str(lib_file_path)) @@ -577,7 +580,7 @@ def forward(self, a, b): self.checkGraphModule(m, (a, b)) def test_native_callable(self): - if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS: + if IS_FBCODE or IS_WINDOWS or IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") # This test exercises the case where we use FX to translate from Python # code to some native callable object @@ -1551,7 +1554,7 @@ def forward(self, x): self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', 'call_module', 'output'])) - # Test shape propogation and make sure results match actual + # Test shape propagation and make sure results match actual self.assertEqual(output_shape, ref_out.shape) self.assertEqual(output_stride, ref_out.stride()) @@ -2362,7 +2365,7 @@ def test_update_args_kwargs_yells_at_you(self): node.__update_args_kwargs((), {}) def test_torchbind_class_attribute_in_fx(self): - if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS: + if IS_FBCODE or IS_WINDOWS or IS_MACOS: self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") class FooBar1234(torch.nn.Module): @@ -2377,7 +2380,7 @@ def forward(self): self.checkGraphModule(m, ()) def test_torchbind_class_attribute_in_fx_tensor_arg(self): - if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS: + if IS_FBCODE or IS_WINDOWS or IS_MACOS: self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping") class FooBar2341(torch.nn.Module): @@ -3280,6 +3283,7 @@ def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): .run(scripted.code) @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108") + @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10") def test_assert(self): def f(x): assert x > 1 @@ -4017,7 +4021,7 @@ def generate_test_func(cls, func_name, fn): def functional_test(self): if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \ - sys.version_info >= (3, 8) and sys.version_info < (3, 10): + sys.version_info >= (3, 8) and sys.version_info < (3, 11): exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] with self.assertRaisesRegex(exc, err): symbolic_trace(fn) @@ -4058,9 +4062,10 @@ def tearDownClass(cls): instantiate_device_type_tests(TestOperatorSignatures, globals()) @skipIfNoTorchVision +@skipIfSlowGradcheckEnv class TestVisionTracing(JitTestCase): def setUp(self): - # Checking for mutable operations whil tracing is feature flagged + # Checking for mutable operations while tracing is feature flagged # Enable it in testing but not by default self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations torch.fx.proxy.TracerBase.check_mutable_operations = True @@ -4076,11 +4081,17 @@ def tearDown(self): UNTRACEABLE_MODELS = { "fasterrcnn_resnet50_fpn": PROXY_ITERATED, + "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED, "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED, "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED, "maskrcnn_resnet50_fpn": PROXY_ITERATED, + "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED, "keypointrcnn_resnet50_fpn": PROXY_ITERATED, "retinanet_resnet50_fpn": PROXY_ITERATED, + "retinanet_resnet50_fpn_v2": PROXY_ITERATED, + "ssd300_vgg16": PROXY_ITERATED, + "fcos_resnet50_fpn": PROXY_ITERATED, + "ssdlite320_mobilenet_v3_large": PROXY_ITERATED, } UNSCRIPTABLE_MODELS = { "googlenet": INCONSISTENT_TYPE, @@ -4132,7 +4143,7 @@ def run_test(self): @classmethod def generate_classification_tests(cls): for k, v in torchvision_models.__dict__.items(): - if callable(v) and k[0].lower() == k[0] and k[0] != "_": + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != 'get_weight': test_name = 'test_torchvision_models_' + k x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224) kwargs = dict(num_classes=50) @@ -4164,7 +4175,7 @@ def generate_video_tests(cls): for k, v in torchvision_models.video.__dict__.items(): if callable(v) and k[0].lower() == k[0] and k[0] != "_": test_name = 'test_torchvision_models_video_' + k - x = torch.rand(1, 3, 4, 112, 112) + x = torch.rand(1, 3, 4, 112, 112) if k != 'mvit_v1_b' else torch.rand(1, 3, 16, 224, 224) kwargs = dict(num_classes=50) model_test = cls.generate_test_fn(k, v, x, kwargs) setattr(cls, test_name, model_test) diff --git a/test/test_fx_backends.py b/test/test_fx_backends.py new file mode 100644 index 0000000000000..645abdd9e8018 --- /dev/null +++ b/test/test_fx_backends.py @@ -0,0 +1,258 @@ +# Owner(s): ["module: fx"] + +import copy +import sys +import logging +from typing import List, Tuple + +import torch +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.backends.nvfuser import NvFuserBackend + +from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + skipCUDAIfRocm, + dtypes, +) + +if not TEST_CUDA: + print('CUDA not available, skipping tests', file=sys.stderr) + TestCase = object # noqa: F811 + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class HF_T5_Partial(torch.nn.Module): + + def inputs_meta(self): + return [ + (torch.Size([512, 512]), torch.float32), + (torch.Size([512, 512]), torch.float32), + (torch.Size([512, 512]), torch.float32), + (torch.Size([512, 512]), torch.float32), + (torch.Size([512]), torch.float32), + (torch.Size([2048, 512]), torch.float32), + (torch.Size([512, 2048]), torch.float32), + (torch.Size([512]), torch.float32), + (torch.Size([8, 1024, 512]), torch.float32), + (torch.Size([8, 8, 1024, 1024]), torch.float32), + ] + + def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, + primals_6, primals_7, primals_8, primals_9, primals_10): + pow_1 = torch.ops.aten.pow(primals_9, 2) + mean = torch.ops.aten.mean(pow_1, [-1], True) + add = torch.ops.aten.add(mean, 1e-06) + rsqrt = torch.ops.aten.rsqrt(add) + mul = torch.ops.aten.mul(primals_9, rsqrt) + mul_1 = torch.ops.aten.mul(primals_5, mul) + t = torch.ops.aten.t(primals_3) + view = torch.ops.aten.view(mul_1, [8192, 512]) + mm = torch.ops.aten.mm(view, t) + _unsafe_view = torch.ops.aten._unsafe_view(mm, [8, 1024, 512]) + view_1 = torch.ops.aten.view(_unsafe_view, [8, -1, 8, 64]) + transpose = torch.ops.aten.transpose(view_1, 1, 2) + t_1 = torch.ops.aten.t(primals_1) + view_2 = torch.ops.aten.view(mul_1, [8192, 512]) + mm_1 = torch.ops.aten.mm(view_2, t_1) + _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [8, 1024, 512]) + view_3 = torch.ops.aten.view(_unsafe_view_1, [8, -1, 8, 64]) + transpose_1 = torch.ops.aten.transpose(view_3, 1, 2) + t_2 = torch.ops.aten.t(primals_4) + view_4 = torch.ops.aten.view(mul_1, [8192, 512]) + mm_2 = torch.ops.aten.mm(view_4, t_2) + _unsafe_view_2 = torch.ops.aten._unsafe_view(mm_2, [8, 1024, 512]) + view_5 = torch.ops.aten.view(_unsafe_view_2, [8, -1, 8, 64]) + transpose_2 = torch.ops.aten.transpose(view_5, 1, 2) + transpose_3 = torch.ops.aten.transpose(transpose_1, 3, 2) + expand = torch.ops.aten.expand(transpose, [8, 8, 1024, 64]) + clone = torch.ops.aten.clone(expand, memory_format=torch.contiguous_format) + _unsafe_view_3 = torch.ops.aten._unsafe_view(clone, [64, 1024, 64]) + expand_1 = torch.ops.aten.expand(transpose_3, [8, 8, 64, 1024]) + clone_1 = torch.ops.aten.clone(expand_1, memory_format=torch.contiguous_format) + _unsafe_view_4 = torch.ops.aten._unsafe_view(clone_1, [64, 64, 1024]) + bmm = torch.ops.aten.bmm(_unsafe_view_3, _unsafe_view_4) + _unsafe_view_5 = torch.ops.aten._unsafe_view(bmm, [8, 8, 1024, 1024]) + add_ = torch.ops.aten.add_(_unsafe_view_5, primals_10) + _softmax = torch.ops.aten._softmax(add_, -1, False) + expand_2 = torch.ops.aten.expand(_softmax, [8, 8, 1024, 1024]) + view_6 = torch.ops.aten.view(expand_2, [64, 1024, 1024]) + expand_3 = torch.ops.aten.expand(transpose_2, [8, 8, 1024, 64]) + clone_2 = torch.ops.aten.clone(expand_3, memory_format=torch.contiguous_format) + _unsafe_view_6 = torch.ops.aten._unsafe_view(clone_2, [64, 1024, 64]) + bmm_1 = torch.ops.aten.bmm(view_6, _unsafe_view_6) + _unsafe_view_7 = torch.ops.aten._unsafe_view(bmm_1, [8, 8, 1024, 64]) + transpose_4 = torch.ops.aten.transpose(_unsafe_view_7, 1, 2) + clone_3 = torch.ops.aten.clone(transpose_4, memory_format=torch.contiguous_format) + view_7 = torch.ops.aten.view(clone_3, [8, -1, 512]) + t_3 = torch.ops.aten.t(primals_2) + view_8 = torch.ops.aten.view(view_7, [8192, 512]) + mm_3 = torch.ops.aten.mm(view_8, t_3) + _unsafe_view_8 = torch.ops.aten._unsafe_view(mm_3, [8, 1024, 512]) + add_1 = torch.ops.aten.add(primals_9, _unsafe_view_8) + pow_2 = torch.ops.aten.pow(add_1, 2) + mean_1 = torch.ops.aten.mean(pow_2, [-1], True) + add_2 = torch.ops.aten.add(mean_1, 1e-06) + rsqrt_1 = torch.ops.aten.rsqrt(add_2) + mul_2 = torch.ops.aten.mul(add_1, rsqrt_1) + mul_3 = torch.ops.aten.mul(primals_8, mul_2) + t_4 = torch.ops.aten.t(primals_6) + view_9 = torch.ops.aten.view(mul_3, [8192, 512]) + mm_4 = torch.ops.aten.mm(view_9, t_4) + _unsafe_view_9 = torch.ops.aten._unsafe_view(mm_4, [8, 1024, 2048]) + relu = torch.ops.aten.relu(_unsafe_view_9) + t_5 = torch.ops.aten.t(primals_7) + view_10 = torch.ops.aten.view(relu, [8192, 2048]) + mm_5 = torch.ops.aten.mm(view_10, t_5) + _unsafe_view_10 = torch.ops.aten._unsafe_view(mm_5, [8, 1024, 512]) + add_3 = torch.ops.aten.add(add_1, _unsafe_view_10) + return [add_3, rsqrt, _unsafe_view_3, t_3, _softmax, view_6, mul_2, t, view_9, t_1, primals_5, add_1, + _unsafe_view_4, view_2, view_10, t_5, t_2, primals_8, view_4, view_8, rsqrt_1, primals_9, t_4, + mul, _unsafe_view_6, relu, view] + + +class TestFxNvFuserBackend(TestCase): + + def _generate_random_inputs(self, device, inputs_meta: List[Tuple[torch.Size, torch.dtype]]): + inputs = [] + for meta in inputs_meta: + shape, dtype = meta + + if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8}: + input = torch.randint(0, 1, shape, dtype=dtype, device=device) + else: + input = torch.rand(shape, dtype=dtype, device=device) + + inputs.append(input) + + return inputs + + + @skipCUDAIfRocm + @dtypes(torch.float32) + def test_nvfuser_call_module_backend(self, device, dtype): + + class Model(torch.nn.Module): + + def __init__(self): + super(Model, self).__init__() + self.bn = torch.nn.BatchNorm2d(3) + self.relu = torch.nn.ReLU() + + def forward(self, inp): + o = self.bn(inp) + o = self.relu(o) + return o + + inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device) + m = Model().to(dtype=dtype, device=device) + + # note that the traced module here contains only `call_module` node, + # which isn't fused by nvfuser backend. But `nvfuser.compile` should run without error + traced = symbolic_trace(m) + + nvfuser = NvFuserBackend() + compiled_module = nvfuser.compile(traced) + + eager_result = m(inp) + nvfuser_result = compiled_module(inp) + + torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) + + + @skipCUDAIfRocm + @dtypes(torch.float32) + def test_nvfuser_backend(self, device, dtype): + m = HF_T5_Partial() + m.to(device) + + traced = symbolic_trace(m) + + nvfuser = NvFuserBackend() + compiled_module = nvfuser.compile(traced) + + inputs = self._generate_random_inputs(device, m.inputs_meta()) + + eager_result = m(*inputs) + nvfuser_result = compiled_module(*inputs) + + torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) + + @skipCUDAIfRocm + @dtypes(torch.float32) + def test_aten_square(self, device, dtype): + + def fn(x): + square = torch.square(x) + a = square + 1 + b = a + 1 + return b + + inputs = torch.randn(4, device=device) + traced = make_fx(fn)(inputs) + + nvfuser = NvFuserBackend() + compiled_module = nvfuser.compile(copy.deepcopy(traced)) + + for node in compiled_module.graph.nodes: + if node.op == "call_function": + assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" + + eager_result = traced(inputs) + nvfuser_result = compiled_module(inputs) + torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) + + @skipCUDAIfRocm + @dtypes(torch.float32) + def test_aten_leakyrelu(self, device, dtype): + + def fn(x): + square = torch.ops.aten.leaky_relu(x, 0.1) + a = square + 1 + b = a + 1 + return b + + inputs = torch.randn(4, device=device) + traced = make_fx(fn)(inputs) + + nvfuser = NvFuserBackend() + compiled_module = nvfuser.compile(copy.deepcopy(traced)) + + for node in compiled_module.graph.nodes: + if node.op == "call_function": + assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" + + eager_result = traced(inputs) + nvfuser_result = compiled_module(inputs) + torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) + + @skipCUDAIfRocm + @dtypes(torch.float32) + def test_aten_where(self, device, dtype): + + def fn(x): + where = torch.ops.aten.where(x < 0, -x, x) + a = where + 1 + b = a + 1 + return b + + inputs = torch.randn(4, device=device) + traced = make_fx(fn)(inputs) + + nvfuser = NvFuserBackend() + compiled_module = nvfuser.compile(copy.deepcopy(traced)) + + for node in compiled_module.graph.nodes: + if node.op == "call_function": + assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" + + eager_result = traced(inputs) + nvfuser_result = compiled_module(inputs) + torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) + +instantiate_device_type_tests(TestFxNvFuserBackend, globals(), only_for="cuda") + +if __name__ == "__main__": + run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 047e0f02a3b20..5ace6bb830d5b 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,4 +1,4 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: fx"] import math import numbers @@ -38,7 +38,7 @@ type_matches, create_type_hint, ) -from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp +from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.split_module import split_module from torch.testing._internal.common_device_type import ( ops, @@ -72,93 +72,6 @@ def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> Graph class TestFXExperimental(JitTestCase): - def test_serialize_graph(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - self.e = torch.rand(4) - self.conv = torch.nn.Conv2d(3, 3, 2, bias=False) - - def forward(self, a, b, c): - add_1 = a + b - conv1 = self.conv(c) - linear = self.linear(add_1 + conv1) - add_2 = linear + self.e - return add_2 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - b = torch.rand(4) - c = torch.rand(3, 3, 2, 2) - graph_manipulation.get_size_of_all_nodes(traced, [a, b, c]) - - partitioner = Partitioner() - devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)] - partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - # Fix for now to add type/shape to output - for node in traced.graph.nodes: - if node.op == "output": - node.meta["tensor_meta"] = _extract_tensor_metadata(a) - for mod in module_with_submodules.modules(): - if isinstance(mod, GraphModule): - for node in mod.graph.nodes: - node.meta["tensor_meta"] = _extract_tensor_metadata(a) - for node in module_with_submodules.graph.nodes: - node.meta["tensor_meta"] = _extract_tensor_metadata(a) - - weights1 = {} - weights2 = {} - serialized_graph1 = graph_manipulation.serialize_module(traced, weights1) - serialized_graph2 = graph_manipulation.serialize_module( - module_with_submodules, weights2 - ) - assert len(weights1) == 4 - assert len(weights2) == 4 - assert len(serialized_graph1["nodes"]) == 10 - assert len(serialized_graph1["weights"]) == 4 - assert len(serialized_graph1["modules"]) == 0 - assert len(serialized_graph2["nodes"]) == 6 - assert len(serialized_graph2["weights"]) == 1 - assert len(serialized_graph2["modules"]) == 1 - assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]" - assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32" - assert serialized_graph1["weights"]["linear.weight"]["is_quantized"] is False - assert serialized_graph1["nodes"][0]["shape"] == "[4]" - assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32" - assert serialized_graph1["nodes"][0]["target"] == "a" - assert serialized_graph1["nodes"][0]["op_code"] == "placeholder" - assert serialized_graph1["nodes"][0]["name"] == "a" - assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1" - assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True - - # Test the users of the nodes. No users of the last/output node. - assert serialized_graph2["nodes"][0]["users"][0]["name"] == "submod_0" - assert serialized_graph2["nodes"][1]["users"][0]["name"] == "submod_0" - assert serialized_graph2["nodes"][4]["users"][0]["name"] == "output" - assert serialized_graph2["nodes"][5]["users"] == [] - - # Test quantization info serialization. - x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) - q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32) - q_tensor_channel = torch.quantize_per_channel( - x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8 - ) - result, _ = graph_manipulation.serialize_tensor_quantization( - q_tensor, weights={}, pcq_prefix="foo" - ) - result2, per_channel_dict = graph_manipulation.serialize_tensor_quantization( - q_tensor_channel, weights={}, pcq_prefix="bar" - ) - assert result["qscheme"] == "torch.per_tensor_affine" - assert result["q_scale"] == 1.0 - assert result2["qscheme"] == "torch.per_channel_affine" - assert result2["q_per_channel_scales"] == "bar_per_channel_scales" - assert per_channel_dict["bar_per_channel_zero_points"]["shape"] == "[2]" - def test_find_single_partition(self): class TestModule(torch.nn.Module): def forward(self, a, b): diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py new file mode 100644 index 0000000000000..e8085fbb92ab2 --- /dev/null +++ b/test/test_fx_passes.py @@ -0,0 +1,276 @@ +# Owner(s): ["module: fx.passes"] + +import operator +import logging + +import torch +from torch.fx._symbolic_trace import symbolic_trace + +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions + +from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests +from torch.testing._internal.jit_utils import JitTestCase + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + self.param = torch.nn.Parameter(torch.rand(4, 4)) + + def forward(self, a, b, c): + add = a + b + + linear_1 = self.linear(add) + + add_1 = add + c + add_2 = add_1 + self.param + add_3 = add_1 + linear_1 + add_4 = add_2 + add_3 + + linear_2 = self.linear2(add_4) + + add_5 = linear_2 + add_4 + add_6 = add_5 + a + relu = add_6.relu() + + return add_4, add_6, relu + +class TestPartitionFunctions: + @staticmethod + def forward1(a, b, c): + add = a + b + add_1 = add + b + add_2 = add_1 + c + relu_1 = add_2.relu() + add_3 = add_1 + add_2 + add_4 = add_1 + relu_1 + add_3 + relu_2 = add_4.relu() + add_5 = relu_2 + add_4 + add_6 = add_5 + add_4 + return add_4, add_6 + + @staticmethod + def forward2(a, b, _): + add = a + b + add_1 = add + b + relu_1 = add_1.relu() # blocked by this + add_3 = add_1 + relu_1 + add_4 = add_1 + add_3 + return add_4, add_1 + + @staticmethod + def forward3(a, b, c): + add = a + b + add_1 = a + c + add_2 = b + c + return add, add_1, add_2 + + @staticmethod + def forward4(a, b, c): + add = a + b + add_1 = a + c + add_2 = b + c + return torch.where(add > 0, add_1, add_2) + + @staticmethod + def forward5(a, b, c): + # add should be fused right branch, as left branch is not supported + add = a + 1 + # left branch + relu = add.relu() + # right branch + add_1 = add + 2 + return relu, add_1 + + @staticmethod + def forward6(a, b, c): + # add should have its own partition, as neither branchs are supported + add = a + 1 + # left branch + relu = add.relu() + # right branch + relu_1 = add.relu() + return relu, relu_1 + + @staticmethod + def forward7(a, b, c): + # both branches are supported, all adds should be fused together + add = a + 1 + # left branch + add_1 = add + 2 + # right branch is larger + add_2 = add + 1 + add_3 = add_2 + 1 + return add_3, add_1 + + @staticmethod + def forward8(a, b, c): + # both branches are in the same partition, add should join the same partition + add = a + 1 + # left branch + add_1 = add + 2 + # right branch + add_2 = add + 1 + # left and right branch merges + add_3 = add_2 + add_1 + + return add_3 + + @staticmethod + def forward9(a, b, c): + add = a + 1 + # branch 1 + add_1 = add + 1 + # branch 2 + add_2 = add + 1 + # branch_3 + add_3 = add + 1 + out = torch.stack([add_1, add_2, add_3]) + return out + + @staticmethod + def forward10(a, b, c): + add = a + 1 + # branch 1 + add_1 = add + 1 + # branch 2 + add_2 = add + 1 + # branch 3: depends on branch 2 + add_3 = add + add_2 + out = torch.stack([add_1, add_2, add_3]) + return out + + @staticmethod + def forward11(a, b, c): + add = a + 1 + # branch 1 + add_1 = add.relu() + # branch 2 depends on branch 1 + add_2 = add + add_1 + # branch 3 + add_3 = add.relu() + out = torch.stack([add_1, add_2, add_3]) + return out + +# A mock OperatorSupport class, where only operator.add is supported +class MockOperatorSupport(OperatorSupport): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in {operator.add} + +class TestFXGraphPasses(JitTestCase): + + @parametrize("fn, expected_partition", [ + (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]]), + (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]]), + + # 2 branches cases + (TestPartitionFunctions.forward5, [["add_1", "add"]]), + (TestPartitionFunctions.forward6, [["add"]]), + (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]]), + (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]]), + + # 3 branch cases + (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']]), + (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']]), + (TestPartitionFunctions.forward11, [['add_1'], ['add']]), + ]) + def test_partitioner(self, fn, expected_partition): + traced = symbolic_trace(fn) + + supported_ops = MockOperatorSupport() + partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True) + partitions = partitioner.propose_partitions() + + partitions_name = [[node.name for node in partition.nodes] for partition in partitions] + assert len(partitions_name) == len(expected_partition) + for i in range(len(partitions_name)): + assert set(partitions_name[i]) == set(expected_partition[i]) + + fused_graph = partitioner.fuse_partitions(partitions) + + a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) + + expected = fn(a, b, c) + result = fused_graph(a, b, c) + torch.testing.assert_close(expected, result) + + + @parametrize("fn, expected_partition", [ + # horizontal fusion without a common downstream node, not supported yet + (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]), + # horizontal fusion with a common downstream node, not supported yet + (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]), + ]) + def test_partitioner_xfail(self, fn, expected_partition): + traced = symbolic_trace(fn) + + supported_ops = MockOperatorSupport() + partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True) + partitions = partitioner.propose_partitions() + + partitions_name = [[node.name for node in partition.nodes] for partition in partitions] + with self.assertRaises(Exception): + assert len(partitions_name) == len(expected_partition) + + @parametrize("partition", [ + [['add', 'add_1'], ['add_5', 'add_6']], + [['add', 'add_1', 'add_2']], # vertical fusion + [['add_2', 'add_3']], # horizontal fusion + [['add_3', 'add_4']], + [['add_6', 'add_5']], # arbitray node order + [['add_4', 'add_1', 'add_3', 'add_2']], # arbitray node order + [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitray partition order + [['add_5', 'linear2']], # includes call_function + call_module node + [['add_6', 'relu']], # includes call_function + call_module node + [['param', 'add_2']], # includes get_attr + call_module nodes + [['param', 'add_1', 'linear']], # includes get_attr + call_function + call_module nodes + [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]], # full graph + ]) + def test_fuser_util(self, partition): + m = TestModule() + gm = symbolic_trace(m) + + nodes_by_name = {node.name : node for node in gm.graph.nodes} + + partitions = [] + for node_names in partition: + partitions.append([nodes_by_name[name] for name in node_names]) + + fused_graph = fuse_by_partitions(gm, partitions) + + a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) + + expected = m(a, b, c) + result = fused_graph(a, b, c) + + torch.testing.assert_close(expected, result) + + @parametrize("partition", [ + [['add', 'add_1'], ['add_1', 'add_5', 'add_6']], # add_1 exists in multiple partitions + [['add', 'add_1', 'add_3']], # invalid partition: circular dependency + [['add_4', 'add_5']], # invalid partition: circular dependency + [['relu', 'add_5']], # invalid partition: circular dependency + ]) + def test_fuser_util_xfail(self, partition): + m = TestModule() + gm = symbolic_trace(m) + + nodes_by_name = {node.name : node for node in gm.graph.nodes} + + partitions = [] + for node_names in partition: + partitions.append([nodes_by_name[name] for name in node_names]) + + with self.assertRaises(Exception): + fuse_by_partitions(gm, partitions) + +instantiate_parametrized_tests(TestFXGraphPasses) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py new file mode 100644 index 0000000000000..037633670bf14 --- /dev/null +++ b/test/test_fx_reinplace_pass.py @@ -0,0 +1,251 @@ +# Owner(s): ["module: functionalization"] +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.fx.passes.reinplace import reinplace +from torch.fx.experimental.proxy_tensor import make_fx + +try: + from functorch.experimental import functionalize + HAS_FUNCTIONALIZATION = True +except Exception as e: + HAS_FUNCTIONALIZATION = False + +class TestReinplacePass(TestCase): + + def test_reinplace_basic(self): + # Basic test: the out-of-place add() call should be converted + # into add_() + def f(x): + a = x.clone() + b = a.add(1) + return b + + inpt = torch.ones(2) + f2 = reinplace(make_fx(f)(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, x_1): + clone_default = torch.ops.aten.clone.default(x_1); x_1 = None + add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1) + return clone_default + """) + + + def test_reinplace_with_view(self): + def f(x): + a = x.clone() + a_view = a.view(-1) + # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program + b = a.add(1) + # Second add() is fine to re-inplace + c = a_view.add(1) + return c + + inpt = torch.ones(2) + f2 = reinplace(make_fx(f)(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, x_1): + clone_default = torch.ops.aten.clone.default(x_1); x_1 = None + view_default = torch.ops.aten.view.default(clone_default, [-1]) + add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None + add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1) + return view_default + """) + + # This test won't actually run in CI, because it requires functionalize() from functorch. + # I'm planning on testing more comprehensively with torchbench models, + # but we can make this testing better once functorch moves into pytorch/pytorch. + def test_reinplace_scatter_op(self): + def f(a_): + # for now, don't test mutations to inputs + a = a_.clone() + e = a.view(-1) + b = a.view(-1) + c = b[0] + d = c.view(-1) + d.add_(1) + return a + e + + if not HAS_FUNCTIONALIZATION: + return + inpt = torch.ones(4) + f2 = reinplace(make_fx(functionalize(f))(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + # NOTE: one slight pessimization here is the fact that + # there are a bunch of redundant views in the graph. + # Technically, half of these views are duplicates that we could de-dup. + # This shouldn't really hurt performance though, since creating an extra view + # is effectively just moving some metadata around (and allocating a new TensorImpl). + # We can/should update the pass in the future to clean this up. + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, a__1): + clone_default = torch.ops.aten.clone.default(a__1); a__1 = None + view_default = torch.ops.aten.view.default(clone_default, [-1]) + view_default_1 = torch.ops.aten.view.default(clone_default, [-1]) + select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None + view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None + add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1) + view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None + select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0) + view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None + view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None + view_default_6 = torch.ops.aten.view.default(view_default_5, [-1]) + add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None + return view_default_5 + """) + + def test_reinplace_scatter_twice(self): + def f(a_): + # for now, don't test mutations to inputs + a = a_.clone() + b = a[:, 1] + c = b[1] + c.add_(1) + return a + + if not HAS_FUNCTIONALIZATION: + return + + inpt = torch.ones(4, 4) + f2 = reinplace(make_fx(functionalize(f))(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, a__1): + clone_default = torch.ops.aten.clone.default(a__1); a__1 = None + slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807) + select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None + select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None + add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None + slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807) + select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None + return clone_default + """) + + def test_reinplace_scatter_twice_with_different_view_op_valid(self): + def f(a_): + a = a_.clone() + b = a[:, 1] + c = b[1] + c_updated = c.add(1) + good_mirror_of_b = a.as_strided((4,), (4,), 1) + # good_mirror_of_b points to the same region of memory as b. + # and this scatter op below tries to scatter c_updated into the same region + # that c currently takes up. + # reinplacing logic checks this by confirming that: + # c_updated + # good_mirror_of_b.select(0, 1) + # have the same size/stride/storage_offset. + b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1) + return b_updated + + inpt = torch.ones(4, 4) + f2 = reinplace(make_fx(f)(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, a__1): + clone_default = torch.ops.aten.clone.default(a__1); a__1 = None + slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807) + select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None + select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None + add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None + as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None + return as_strided_default + """) + + # Test example where we have a scatter op, where the base tensor + # has the same size/stride/storage offset (even though it is a different view), + # making it valid to re-inplace + def test_reinplace_scatter_twice_with_different_view_op_invalid(self): + def f(a_): + a = a_.clone() + b = a[:, 1] + c = b[1] + c_updated = c.add(1) + good_mirror_of_b = a.as_strided((4,), (4,), 1) + # The first arg to select_scatter is an equivalent view to b. + # However, the select_scatter call below tries to put c_updated + # into a different slice of "b" than what "c" currently occupies. + # + b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0) + return b_updated + + inpt = torch.ones(4, 4) + f2 = reinplace(make_fx(f)(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, a__1): + clone_default = torch.ops.aten.clone.default(a__1); a__1 = None + slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807) + select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None + select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None + add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None + as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None + select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 0); as_strided_default = add_tensor = None + return select_scatter_default + """) # noqa: B950 + + def test_reinplace_scatter_twice_with_different_view_op_invalid2(self): + def f(a_): + a = a_.clone() + b = a[:, 1] + c = b[1] + c_updated = c.add(1) + bad_mirror_of_b = a.as_strided((4,), (4,), 0) + # The first arg to select_scatter points to a different than c's base. + # This makes it invalid to re-inplace. + b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1) + return b_updated + + inpt = torch.ones(4, 4) + f2 = reinplace(make_fx(f)(inpt), inpt) + expected_out = f(inpt) + actual_out = f2(inpt) + # self.assertEqual(actual_out, expected_out) + self.assertExpectedInline(f2.code, """\ + + + +def forward(self, a__1): + clone_default = torch.ops.aten.clone.default(a__1); a__1 = None + slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807) + select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None + select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None + add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None + as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None + select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 1); as_strided_default = add_tensor = None + return select_scatter_default + """) # noqa: B950 + +if __name__ == '__main__': + run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index 55306f0c910db..b07b83cc40c28 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -101,7 +101,7 @@ suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ freeze_rng_state, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ - skipIfCrossRef, IS_MACOS + skipIfCrossRef, IS_MACOS, skipIfTorchDynamo from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ _trace, do_input_map, get_execution_plan, make_global, \ execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ @@ -2003,6 +2003,7 @@ def sublist_format(x, y): check(fn, torch.jit.script(fn), x, y) check(fn, torch.jit.trace(fn, (x, y)), x, y) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_python_ivalue(self): # Test if pure python object can be hold as IValue and conversion # between IValue and PyObject are correct @@ -2773,6 +2774,7 @@ def forward(self, input=[]): # noqa: B006 with self.assertRaisesRegex(Exception, "Mutable default parameters"): torch.jit.script(Test()) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_warnings(self): import warnings @@ -3974,6 +3976,7 @@ def select_expr_or_var(): o2 = cu.f() self.assertEqual(o1, o2) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_cpp_module_iterator(self): a = nn.Module() a.name = 'a' @@ -5365,7 +5368,7 @@ def func(x): def func2(x): return x.sum(dim=4) - # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument + # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument self.run_pass('constant_propagation', func.graph) self.run_pass('constant_propagation', func2.graph) g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False) @@ -6446,6 +6449,7 @@ def func(a, b): inputs = self._make_scalar_vars([42, 1337], torch.int64) self.checkScript(func, inputs, optimize=True) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_while_nest_if(self): def func(a, b): # type: (int, int) -> int @@ -6602,6 +6606,7 @@ def func(a, b): checkMathWrap("remainder", 2) checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)]) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_if_nest_while(self): def func(a, b): # type: (int, int) -> int @@ -7437,6 +7442,9 @@ def func(): # tensor from empty list is type float in python and annotated type in torchscript if "annotate" in li and "dtype" not in option: continue + # Skip unsigned tensor initializaton for signed values on 3.10 + if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li: + continue code = tensor_template.format(list_create=li, tensor_op=op, options=option) scope = {} exec(code, globals(), scope) @@ -10193,6 +10201,7 @@ def forward(self): cm_load = torch.jit.load(buffer) FileCheck().check_not("Double(1, 3)").run(cm_load.forward.graph) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_type_annotations_repeated_list(self): @torch.jit.script def float_fn(x, y): @@ -11131,12 +11140,12 @@ def test_rand(): def randint(): return torch.randint(0, 5, [1, 2]) out = randint() - self.assertEqual(out.dtype, torch.double) - # although the type should be int here, testing that the runtime dtype - # and shape analysis dtype is the same. + self.assertEqual(out.dtype, torch.int64) if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: - FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \ - .check_not("Float(*, *, requires_grad=0, device=cpu)").run(randint.graph_for()) + FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \ + .check_not("Float(*, *, requires_grad=0, device=cpu)") \ + .check_not("Double(*, *, requires_grad=0, device=cpu)") \ + .run(randint.graph_for()) @unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") @@ -11235,14 +11244,12 @@ def test_rand(): def randint(): return torch.randint(0, 5, [1, 2]) - # although the type should be int here, testing that the runtime dtype - # and shape analysis dtype is the same. with enable_profiling_mode_for_profiling_tests(): with num_profiled_runs(1): out = randint() graph_str = torch.jit.last_executed_optimized_graph() - self.assertEqual(out.dtype, torch.double) - FileCheck().check("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) + self.assertEqual(out.dtype, torch.int64) + FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) def test_erase_number_types(self): @@ -11655,6 +11662,7 @@ def test_none_type_str(self): python_type = eval(none_type.annotation_str, g) assert python_type is type(None) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_zip_enumerate_modulelist(self): class Sub(torch.nn.Module): def __init__(self): @@ -11765,7 +11773,17 @@ def fn(x): self.checkScript(fn, ([1, 2, 3, 4, 5],)) - def fn_enumerate_start_index(x): + def fn_enumerate_start_arg(x): + # type: (List[int]) -> int + sum = 0 + for (i, v) in enumerate(x, 1): + sum += i * v + + return sum + + self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],)) + + def fn_enumerate_start_kwarg(x): # type: (List[int]) -> int sum = 0 for (i, v) in enumerate(x, start=1): @@ -11773,7 +11791,7 @@ def fn_enumerate_start_index(x): return sum - self.checkScript(fn, ([1, 2, 3, 4, 5],)) + self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],)) def fn_nested_enumerate(x): # type: (List[int]) -> int @@ -11783,7 +11801,7 @@ def fn_nested_enumerate(x): return sum - self.checkScript(fn, ([1, 2, 3, 4, 5],)) + self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],)) with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'): @torch.jit.script @@ -16118,6 +16136,22 @@ def forward(self, x): self.checkModule(MyModule(), (torch.ones(2, 3),)) + def test_context_manager(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, x, y): + p = x + y + q = p + 2.0 + return q + + x = torch.randn(3, 2, dtype=torch.float) + y = torch.randn(3, 2, dtype=torch.float) + for fuser_name in ['fuser0', 'fuser1', 'none']: + with torch.jit.fuser(fuser_name): + self.checkModule(MyModule(), (x, y)) + # known to be failing in tracer EXCLUDE_TRACED = { # The following fail due to #12024. diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 09f7dde3dcb1e..031643d021f5f 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -41,6 +41,7 @@ if RUN_NVFUSER and torch.version.cuda is not None: CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')[:2]) +os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition' os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng' os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' # TODO: enable complex when we fixes the extremal cases in OpInfo diff --git a/test/test_linalg.py b/test/test_linalg.py index 1c5a2d7c19443..8739da0c99138 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -17,7 +17,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices, - make_fullrank_matrices_with_distinct_singular_values) + make_fullrank_matrices_with_distinct_singular_values, + freeze_rng_state) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, has_cusolver, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, @@ -533,14 +534,14 @@ def test_cholesky_errors_and_warnings(self, device, dtype): # dtypes should be safely castable out = torch.empty(*A.shape, dtype=torch.int, device=device) - with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"): + with self.assertRaisesRegex(RuntimeError, "but got int instead"): torch.linalg.cholesky(A, out=out) # device should match if torch.cuda.is_available(): wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' out = torch.empty(0, device=wrong_device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"): + with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): torch.linalg.cholesky(A, out=out) # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py @@ -697,19 +698,6 @@ def test_cholesky_ex_non_pd(self, device, dtype): with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'): torch.linalg.cholesky_ex(A, check_errors=True) - @skipCUDAIfNoMagmaAndNoCusolver - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_cholesky_ex_out_info_error(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - - # dtype for info must be torch.int32 - A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) - L = torch.empty(A.shape, dtype=dtype, device=device) - info = torch.empty(A.shape[:-2], dtype=torch.int64, device=device) - with self.assertRaisesRegex(RuntimeError, "but got info with dtype Long"): - torch.linalg.cholesky_ex(A, out=(L, info)) - def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1): def check(m, a, b, beta, alpha): if dtype == torch.bfloat16: @@ -1509,12 +1497,13 @@ def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): S = 10 error_test_cases = [ # input size, p settings, dim, error type, error regex - ((S, ), ['fro', 'nuc'], None, RuntimeError, r'input tensor must be a matrix or a batch of matrices'), + ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'), ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'), ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'), ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'), - ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple of ints'), - ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'Expected dims to be different'), + ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'), + ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'), + ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'), ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'), ((S, ), [0], (4, ), IndexError, r'Dimension out of range'), ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'), @@ -1596,7 +1585,7 @@ def test_matrix_norm(self, device, dtype): # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm A = make_tensor((2, 2, 2), dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a matrix.*'): + with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'): torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device)) with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'): torch.linalg.matrix_norm(A, dim=(0,)) @@ -6202,6 +6191,19 @@ def run_test(coeff_shape, data_shape): run_test([3, 4], [3, 3, 3]) run_test([3, 4], [3, 3, 3, 3]) + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.complex64) + def test_linalg_matrix_exp_no_warnings(self, device, dtype): + # this tests https://github.com/pytorch/pytorch/issues/80948 + with freeze_rng_state(): + torch.manual_seed(42) + tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device) + tens = (0.5 * (tens.transpose(-1, -2) + tens)) + with warnings.catch_warnings(record=True) as w: + tens.imag = torch.matrix_exp(tens.imag) + self.assertFalse(len(w)) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) @@ -6532,9 +6534,8 @@ def test_slogdet_errors_and_warnings(self, device, dtype): with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): torch.linalg.slogdet(a) - # slogdet requires the input to be of float, double, cfloat or cdouble types a = torch.randn(2, 2, device=device, dtype=torch.bfloat16) - with self.assertRaisesRegex(RuntimeError, r'of float, double, cfloat or cdouble types'): + with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'): torch.linalg.slogdet(a) # if non-empty out tensor with wrong shape is passed a warning is given @@ -6549,16 +6550,6 @@ def test_slogdet_errors_and_warnings(self, device, dtype): self.assertEqual(len(w), 1) self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) - # dtypes should be safely castable - sign_out = torch.empty_like(a).to(torch.int) - logabsdet_out = torch.empty_like(a).to(torch.int) - with self.assertRaisesRegex(RuntimeError, "but got sign with dtype Int"): - torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) - - sign_out = torch.empty(0, device=device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, "but got logabsdet with dtype Int"): - torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) - # device should match if torch.cuda.is_available(): wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' diff --git a/test/test_meta.py b/test/test_meta.py index 7897b343029d0..a6b56f2664142 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -13,6 +13,7 @@ suppress_warnings, TEST_WITH_ASAN, run_tests, + skipIfSlowGradcheckEnv, ) from torch.testing._internal.common_device_type import ( ops, @@ -30,6 +31,8 @@ from collections import defaultdict import unittest import warnings +import weakref + bf16 = torch.bfloat16 f64 = torch.float64 @@ -62,6 +65,7 @@ } +@skipIfSlowGradcheckEnv class TestMetaConverter(TestCase): def assertSameVersionCounter(self, m1, m2): # Cannot easily test m1 and m2 have same storage due to @@ -166,6 +170,35 @@ def test_imag(self): self.assertEqual(m.stride(), y.stride()) self.assertEqual(m.storage_offset(), y.storage_offset()) + def test_weakref(self): + x = torch.randn(4, 4, 4) + m = MetaConverter() + y = m(x) + z = m(x) + self.assertIs(y, z) + self.assertEqual(len(m.tensor_memo), 1) + self.assertEqual(len(m.storage_memo), 1) + del x + self.assertEqual(len(m.tensor_memo), 0) + m.check_for_expired_weak_storages() + self.assertEqual(len(m.storage_memo), 0) + li = [] + for i in range(4): + li.append(torch.rand([i])) + m(li[-1]) + self.assertEqual(len(m.tensor_memo), 4) + del li + self.assertEqual(len(m.tensor_memo), 0) + m.check_for_expired_weak_storages() + self.assertEqual(len(m.storage_memo), 0) + + def test_tensor_outlives_converter(self): + m = MetaConverter() + ref = weakref.ref(m) + x = torch.randn([4, 4]) + y = m(x) + del m + self.assertIs(ref(), None) def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable): flat_meta_rs, _ = tree_flatten(meta_rs) @@ -315,6 +348,10 @@ def run_meta_crossref( else: indices.append(meta_index) meta_args = (meta_args[0], indices) + + if kwargs.get("device", None) is not None: + meta_kwargs["device"] = "meta" + try: # Suppress warnings, this doesn't matter for test_meta.py # but it does matter if you want to use this decorator @@ -370,78 +407,63 @@ def run_meta_crossref( RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ") meta_function_expected_failures = { - torch.Tensor.item: {b8, bf16, c128, c64, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense - torch.Tensor.to_sparse: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::to_sparse, aten::to_sparse.sparse_dim - torch.allclose: {bf16, f16, f32, f64}, # aten::_local_scalar_dense - torch.argwhere: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nonzero - torch.bincount: {i16, i32, i64, i8, u8}, # aten::bincount - torch.bucketize: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::bucketize.Tensor, aten::bucketize.Tensor_out - torch.combinations: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select - torch.complex: {f16, f32, f64}, # aten::complex.out - torch.corrcoef: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense - torch.count_nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::count_nonzero.dim_IntList - torch.cov: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense - torch.fft.hfft2: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.fft.hfft: {b8, f32, f64, i16, i32, i64, i8, u8}, - torch.fft.hfftn: {b8, f32, f64, i16, i32, i64, i8, u8}, # aten::_fft_c2c - torch.floor_divide: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::floor_divide, aten::floor_divide.out - torch.frexp: {bf16, f16, f32, f64}, # aten::frexp.Tensor_out - torch.functional.istft: {f32, f64}, # aten::view_as_complex - torch.functional.unique: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_unique2, aten::unique_dim - torch.functional.unique_consecutive: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::unique_consecutive - torch.histc: {bf16, f32, f64}, # aten::histc, aten::histc.out - torch.histogram: {f32, f64}, # aten::histogram.bin_ct, aten::histogram.bins_tensor - torch.histogramdd: {f32, f64}, # aten::_histogramdd_bin_edges, aten::_histogramdd_from_bin_tensors - torch.kthvalue: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::kthvalue.values - torch.logcumsumexp: {bf16, f32, f64}, # aten::_logcumsumexp, aten::_logcumsumexp.out - torch.masked_select: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::masked_select, aten::masked_select.out - torch.matrix_exp: {bf16, f32, f64}, # aten::linalg_matrix_exp - torch.median: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::median, aten::median.dim_values - torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::mode - torch.multinomial: {bf16, f32, f64}, # aten::multinomial, aten::multinomial.out - torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense, aten::mvlgamma.out - torch.nn.functional.conv1d: {bf16, f32, f64, i64}, - torch.nn.functional.conv2d: {bf16, f32, f64, i64}, - torch.nn.functional.conv_transpose1d: {f32, f64, i64}, - torch.nn.functional.conv_transpose2d: {f32, f64, i64}, - torch.nn.functional.conv_transpose3d: {f32, f64, i64}, - torch.nn.functional.ctc_loss: {f32, f64}, - torch.nn.functional.gaussian_nll_loss: {bf16, f32, f64}, # aten::_local_scalar_dense - torch.nn.functional.grid_sample: {f32, f64}, # aten::grid_sampler_2d, aten::grid_sampler_3d - torch.nn.functional.max_pool3d: {f32, f64}, # aten::max_pool3d_with_indices - torch.nn.functional.max_pool3d_with_indices: {f32, f64}, # aten::max_pool3d_with_indices - torch.nn.functional.max_unpool1d: {f32, f64}, # aten::max_unpool2d - torch.nn.functional.max_unpool2d: {f32, f64}, # aten::max_unpool2d - torch.nn.functional.max_unpool3d: {f32, f64}, # aten::max_unpool3d - torch.nn.functional.multi_margin_loss: {f32, f64}, # aten::multi_margin_loss - torch.nn.functional.multilabel_margin_loss: {f32, f64}, # aten::multilabel_margin_loss_forward - torch.nn.functional.one_hot: {i64}, # aten::_local_scalar_dense - torch.nn.functional.pdist: {f32, f64}, # aten::_pdist_forward - torch.nn.functional.prelu: {bf16, f32, f64}, # aten::prelu - torch.nn.functional.rrelu: {bf16, f32, f64}, # aten::rrelu_with_noise - torch.nn.functional.unfold: {bf16, f16, f32, f64}, # aten::im2col - torch.nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nonzero, aten::nonzero.out - torch.polar: {f32, f64}, # aten::polar.out - torch.repeat_interleave: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::repeat_interleave.Tensor - torch.segment_reduce: {bf16, f16, f32, f64}, # aten::segment_reduce - torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out - torch.symeig: {f32, f64}, - torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::take, aten::take.out - torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::vdot - torch.ormqr: {f32, f64}, - torch.cholesky: {f32, f64}, # aten::cholesky, aten::cholesky.out - torch.cholesky_inverse: {f32, f64}, # aten::cholesky_inverse, aten::cholesky_inverse.out - torch.cholesky_solve: {f32, f64}, # aten::_cholesky_solve_helper - torch.eig: {f32, f64}, # aten::_local_scalar_dense - torch.geqrf: {f32, f64}, # aten::geqrf - torch.linalg.det: {f32, f64}, # aten::_det_lu_based_helper - torch.linalg.eig: {f32, f64}, # aten::linalg_eig - torch.linalg.eigvals: {f32, f64}, - torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product - torch.linalg.lstsq: {f32, f64}, # aten::linalg_lstsq.out - torch.linalg.slogdet: {f32, f64}, # aten::linalg_slogdet - torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular - torch.logdet: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero + torch.Tensor.to_sparse : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.allclose : {f64, f16, c128, c64, bf16, f32}, + torch.argwhere : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.combinations : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.corrcoef : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32}, + torch.count_nonzero : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.cov : {f64, i32, c128, i64, i16, u8, c64, bf16, i8, f32}, + torch.functional.istft : {f64, c64, c128, f32}, + torch.geqrf : {f64, c64, c128, f32}, + torch.linalg.householder_product : {f64, c64, c128, f32}, + torch.linalg.solve_triangular : {f64, c64, c128, f32}, + torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.matrix_exp : {f64, c128, c64, bf16, f32}, + torch.nn.functional.unfold : {f64, f16, c128, c64, bf16, f32}, + torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32}, + torch.ormqr : {f64, c64, c128, f32}, + torch.repeat_interleave : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32}, + torch.take : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, + torch.bincount : {i32, i64, u8, i16, i8}, + torch.bucketize : {f64, i32, i64, f16, u8, i16, bf16, i8, f32}, + torch.frexp : {f64, f16, bf16, f32}, + torch.functional.unique : {f64, i32, i64, u8, i16, bf16, b8, i8, f32}, + torch.functional.unique_consecutive : {f64, i32, i64, u8, i16, bf16, b8, i8, f32}, + torch.histc : {f64, bf16, f32}, + torch.histogram : {f64, f32}, + torch.histogramdd : {f64, f32}, + torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32}, + torch.logcumsumexp : {f64, bf16, f32}, + torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32}, + torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32}, + torch.multinomial : {f64, bf16, f32}, + torch.mvlgamma : {f64, i32, i64, u8, i16, bf16, i8, f32}, + torch.nn.functional.ctc_loss : {f64, f32}, + torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32}, + torch.nn.functional.grid_sample : {f64, f32}, + torch.nn.functional.max_pool3d : {f64, f32}, + torch.nn.functional.max_pool3d_with_indices : {f64, f32}, + torch.nn.functional.max_unpool1d : {f64, f32}, + torch.nn.functional.max_unpool2d : {f64, f32}, + torch.nn.functional.max_unpool3d : {f64, f32}, + torch.nn.functional.multi_margin_loss : {f64, f32}, + torch.nn.functional.multilabel_margin_loss : {f64, f32}, + torch.nn.functional.one_hot : {i64}, + torch.nn.functional.pdist : {f64, f32}, + torch.nn.functional.rrelu : {f64, bf16, f32}, + torch.polar : {f64, f32}, + torch.segment_reduce : {f64, f16, bf16, f32}, + torch.searchsorted : {f64, i32, i64, f16, u8, i16, bf16, i8, f32}, + torch.symeig : {f64, f32, c128, c64}, + torch.cholesky : {f64, f32, c128, c64}, + torch.cholesky_inverse : {f64, f32, c128, c64}, + torch.cholesky_solve : {f64, f32, c128, c64}, + torch.eig : {f64, f32, c128, c64}, + torch.linalg.eig : {f64, f32, c128, c64}, + torch.linalg.eigvals : {f64, f32, c128, c64}, + torch.linalg.lstsq : {f64, f32, c128, c64}, } """ @@ -456,23 +478,65 @@ def run_meta_crossref( """ meta_function_skips = { - torch.aminmax: {b8, f32, f64, i16, i32, i64, i8, u8}, - torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, - torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8}, - torch.diff: {b8}, - torch.equal: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, - torch.functional.cdist: {f32, f64}, - torch.nanmean: {bf16, f16, f32, f64}, - torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8}, - torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8}, - torch.nn.functional.cross_entropy: {bf16, f32, f64}, - torch.nn.functional.interpolate: {bf16, f32, f64, u8}, - torch.nanmean: {bf16, f16, f32, f64}, # TODO(chilli): Doesn't seem to work for some reason? - torch.nn.functional.nll_loss: {bf16, f32, f64}, # TODO - torch.linalg.pinv: {f32, f64}, - torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, + torch.Tensor.__rmatmul__ : {bf16, c128, f64, f32, f16, c64}, + torch.Tensor.matmul : {f64, f32, c128, c64}, + torch.fft.fft2 : {i8, i64, u8, c128, b8, f64, i16, f32, i32, c64, c32, f16}, + torch.fft.fft : {i8, i64, u8, c128, b8, f64, i16, f32, i32, c64, c32, f16}, + torch.fft.fftn : {i8, i64, u8, c128, b8, f64, i16, f32, i32, c64, c32, f16}, + torch.fft.ifft2 : {i8, i64, u8, c128, b8, f64, i16, f32, i32, c64, c32, f16, c32}, + torch.fft.ifft : {c128, c64, c32, f16}, + torch.fft.ifftn : {i8, i64, u8, c128, b8, f64, i16, f32, i32, c64, c32, f16}, + torch.fft.hfft: {f16}, + torch.fft.hfftn: {f16}, + torch.fft.hfft2: {f16}, + torch.fft.ihfft: {f16}, + torch.fft.ihfft2 : {i8, i64, u8, f64, b8, f32, i32, i16, f16, c32, f16}, + torch.fft.ihfftn : {i8, i64, u8, f64, b8, f32, i32, i16, c32, f16}, + torch.fft.irfft2 : {f16}, + torch.fft.irfft : {f16}, + torch.fft.irfftn : {f16}, + torch.fft.rfft2 : {i8, i64, u8, f64, b8, f32, i32, i16, c32, f16}, + torch.fft.rfft : {i8, i64, u8, f64, b8, f32, i32, i16, c32, f16}, + torch.fft.rfftn : {i8, i64, u8, f64, b8, f32, i32, i16, c32, f16}, + torch.functional.atleast_2d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.functional.atleast_3d : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.functional.cartesian_prod : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.functional.einsum : {bf16, c128, f64, f32, f16, c64}, + torch.functional.stft : {c128, f32, c64, f64}, + torch.functional.tensordot : {bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64}, + torch.inner : {bf16, i8, i64, u8, c128, f64, i16, f32, i32, c64}, + torch.linalg.lu_solve : {c128, c64}, + torch.linalg.matrix_norm : {c128, f32, c64, f64}, + torch.linalg.matrix_power : {c128, c64}, + torch.linalg.matrix_rank : {c128, c64}, + torch.linalg.svd : {c128, c64}, + torch.matmul : {bf16, c128, f64, f32, f16, c64}, + torch.nanquantile : {f64, f32}, + torch.nn.functional.batch_norm : {f64, f32}, + torch.nn.functional.binary_cross_entropy : {bf16, f64, f32, f16}, + torch.nn.functional.dropout3d : {bf16, f64, f32, f16}, + torch.nn.functional.local_response_norm : {bf16, f64, f32, f16}, + torch.svd : {c128, c64}, + torch.take_along_dim : {bf16, i8, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.vstack : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.aminmax : {i8, i64, u8, f64, b8, f32, i32, i16}, + torch.cummax : {bf16, i8, i64, u8, f64, b8, f32, i32, i16}, + torch.cummin : {bf16, i8, i64, u8, f64, b8, f32, i32, i16}, + torch.diff : {b8}, + torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, + torch.functional.cdist : {f64, f32}, + torch.nanmean : {bf16, f64, f32, f16}, + torch.nn.functional.cross_entropy : {bf16, f64, f32}, + torch.nn.functional.interpolate : {bf16, f64, f32, u8}, + torch.nn.functional.nll_loss : {bf16, f64, f32}, + torch.linalg.pinv : {f64, f32}, + torch.linalg.cond : {c128, c64, f32, f64}, + torch.linalg.vander: {c128, c64, f32, f64, i16, i32, i64, i8, u8}, + torch.linalg.vecdot : {bf16, f64, f32, f16}, + torch.empty : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64}, } + meta_function_device_expected_failures = defaultdict(dict) meta_function_device_skips = defaultdict(dict) @@ -482,24 +546,6 @@ def run_meta_crossref( meta_function_device_expected_failures['cuda'] = { torch.corrcoef: {bf16, f16}, # aten::_local_scalar_dense torch.cov: {f16}, # aten::_local_scalar_dense - torch.fft.fft2: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.fft: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.fftn: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.hfft2: {c32, f16}, # aten::_fft_c2c - torch.fft.hfft: {c32, f16}, - torch.fft.hfftn: {c32, f16}, # aten::_fft_c2c - torch.fft.ifft2: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.ifft: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.ifftn: {c32, f16}, # aten::_fft_c2c, aten::_fft_c2c.out - torch.fft.ihfft2: {f16}, - torch.fft.ihfft: {f16}, - torch.fft.ihfftn: {f16}, - torch.fft.irfft2: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.irfft: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.irfftn: {c32, f16}, # aten::_fft_c2r, aten::_fft_c2r.out - torch.fft.rfft2: {f16}, - torch.fft.rfft: {f16}, - torch.fft.rfftn: {f16}, torch.functional.unique: {f16}, # aten::_unique2, aten::unique_dim torch.functional.unique_consecutive: {f16}, # aten::unique_consecutive torch.geqrf: {f32, f64}, # aten::geqrf @@ -512,11 +558,6 @@ def run_meta_crossref( torch.median: {f16}, # aten::median, aten::median.dim_values torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out torch.mvlgamma: {f16}, # aten::_local_scalar_dense, aten::mvlgamma.out - torch.nn.functional.conv1d: {f16, c32}, - torch.nn.functional.conv2d: {f16, c32}, - torch.nn.functional.conv_transpose1d: {bf16, f16}, - torch.nn.functional.conv_transpose2d: {bf16, f16}, - torch.nn.functional.conv_transpose3d: {bf16, f16}, torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense torch.nn.functional.grid_sample: {f16}, # aten::grid_sampler_2d, aten::grid_sampler_3d torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices @@ -526,10 +567,8 @@ def run_meta_crossref( torch.nn.functional.max_unpool3d: {f16}, # aten::max_unpool3d torch.nn.functional.multi_margin_loss: {bf16, f16}, # aten::multi_margin_loss torch.nn.functional.multilabel_margin_loss: {bf16, f16}, # aten::multilabel_margin_loss_forward - torch.nn.functional.prelu: {f16}, # aten::prelu torch.nn.functional.rrelu: {f16}, # aten::rrelu_with_noise torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out - torch.vdot: {f16}, # aten::vdot } meta_function_device_skips['cuda'] = { @@ -595,114 +634,108 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # these always fail meta_dispatch_expected_failures = { - aten._convolution.default: {c64, i64, f64, c128, bf16, f32}, - aten._ctc_loss.default: {f64, f32}, - aten._histogramdd_bin_edges.default: {f64, f32}, - aten._histogramdd_from_bin_cts.default: {f64, f32}, - aten._histogramdd_from_bin_tensors.default: {f64, f32}, - aten._local_scalar_dense.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten._pdist_forward.default: {f64, f32}, - aten._unique2.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, - aten.bincount.default: {i8, i64, i16, u8, i32}, - aten.bucketize.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.bucketize.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.col2im.default: {c64, f32, f64, c128}, - aten.complex.default: {c64, f64, c128, f16, f32}, - aten.complex.out: {f16}, - aten.convolution.default: {c64, i64, f64, c128, bf16, f32}, - aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.equal.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.floor_divide.default: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.floor_divide.out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.frexp.Tensor: {bf16, f16, f64, f32}, - aten.grid_sampler_2d.default: {f64, f32}, - aten.grid_sampler_3d.default: {f64, f32}, - aten.histc.default: {bf16, f64, f32}, - aten.histc.out: {bf16, f64, f32}, - aten.histogram.bin_ct: {f64, f32}, - aten.histogram.bins_tensor: {f64, f32}, - aten.im2col.default: {bf16, f16, f64, f32}, - aten.kthvalue.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.linalg_matrix_exp.default: {bf16, f64, f32}, - aten.log_sigmoid_forward.output: {bf16, f64, f32}, - aten.logcumsumexp.default: {bf16, f64, f32}, - aten.logcumsumexp.out: {bf16, f64, f32}, - aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.max_pool3d_with_indices.default: {f64, f32}, - aten.max_unpool2d.default: {f64, f32}, - aten.max_unpool3d.default: {f64, f32}, - aten.median.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.median.dim: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.mode.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.multi_margin_loss.default: {f64, f32}, - aten.multilabel_margin_loss_forward.default: {f64, f32}, - aten.multinomial.default: {bf16, f64, f32}, - aten.multinomial.out: {bf16, f64, f32}, - aten.mvlgamma.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.mvlgamma.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.nll_loss2d_forward.default: {bf16, f64, f32}, - aten.nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.nonzero.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.polar.default: {f64, f32}, - aten.prelu.default: {bf16, f64, f32}, - aten.rrelu_with_noise.default: {bf16, f64, f32}, - aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.segment_reduce.default: {bf16, f16, f32, f64}, - aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.to_sparse.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.to_sparse.sparse_dim: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.unique_consecutive.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, - aten.unique_dim.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, - aten.upsample_nearest3d.vec: {bf16, u8, f64, f32}, - aten.vdot.default: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten.vdot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, - aten._det_lu_based_helper.default: {f32, f64}, # aten::_det_lu_based_helper - aten.cholesky.default: {f32, f64}, # aten::cholesky - aten.cholesky.out: {f32, f64}, # aten::cholesky.out - aten.cholesky_inverse.default: {f32, f64}, # aten::cholesky_inverse - aten.cholesky_inverse.out: {f32, f64}, # aten::cholesky_inverse.out - aten.cholesky_solve.default: {f32, f64}, # aten::_cholesky_solve_helper - aten.cholesky_solve.out: {f32, f64}, # aten::_cholesky_solve_helper - aten.eig.default: {f32, f64}, # aten::_local_scalar_dense - aten.geqrf.default: {f32, f64}, # aten::geqrf - aten.linalg_eig.default: {f32, f64}, # aten::linalg_eig - aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product - aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out - aten.linalg_lstsq.default: {f32, f64}, # aten::linalg_lstsq.out - aten.linalg_slogdet.default: {f32, f64}, # aten::linalg_slogdet - aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular - aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out - aten.logdet.default: {f32, f64}, # aten::_local_scalar_dense, aten::nonzero - aten.ormqr.default: {f32, f64}, # aten::ormqr - aten.ormqr.out: {f32, f64}, # aten::ormqr.out - aten.symeig.default: {f32, f64}, # aten::_symeig_helper + aten.allclose.default: {f16, bf16, f32, f64, c64, c128}, # NotImplementedError: 'aten::_local_scalar_dense' + aten._fft_c2c.out : {f16, c64, i8, f64, c128, i32, i64, f32, c32, b8, i16, u8}, + aten._fft_r2c.out : {f16, i8, f64, i32, i64, f32, b8, i16, u8}, + aten.cholesky.default : {c64, c128, f64, f32}, + aten.cholesky.out : {c64, c128, f64, f32}, + aten.cholesky_inverse.default : {c64, c128, f64, f32}, + aten.cholesky_inverse.out : {c64, c128, f64, f32}, + aten.cholesky_solve.default : {c64, c128, f64, f32}, + aten.cholesky_solve.out : {c64, c128, f64, f32}, + aten.count_nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.count_nonzero.dim_IntList : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.eig.default : {c64, c128, f64, f32}, + aten.geqrf.default : {c64, c128, f64, f32}, + aten.im2col.default : {c64, bf16, f32, f16, f64, c128}, + aten.linalg_eig.default : {c64, c128, f64, f32}, + aten.linalg_householder_product.default : {c64, c128, f64, f32}, + aten.linalg_householder_product.out : {c64, c128, f64, f32}, + aten.linalg_lstsq.default : {c64, c128, f64, f32}, + aten.linalg_matrix_exp.default : {c64, bf16, f32, f64, c128}, + aten.linalg_solve_triangular.default : {c64, c128, f64, f32}, + aten.linalg_solve_triangular.out : {c64, c128, f64, f32}, + aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.native_group_norm.default : {bf16}, + aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8}, + aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8}, + aten.ormqr.default : {c64, c128, f64, f32}, + aten.ormqr.out : {c64, c128, f64, f32}, + aten.polar.out : {f32, f64}, + aten.symeig.default : {c64, c128, f64, f32}, + aten.take.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.take.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.tensordot.out : {c64, i8, f64, c128, i64, bf16, f32, i32, i16, u8}, + aten.to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten._ctc_loss.default : {f32, f64}, + aten._histogramdd_bin_edges.default : {f32, f64}, + aten._histogramdd_from_bin_cts.default : {f32, f64}, + aten._histogramdd_from_bin_tensors.default : {f32, f64}, + aten._local_scalar_dense.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten._pdist_forward.default : {f32, f64}, + aten._unique2.default : {i8, f64, i64, bf16, f32, i32, b8, i16, u8}, + aten.bincount.default : {i64, i8, i32, i16, u8}, + aten.bucketize.Tensor : {f16, i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.bucketize.Tensor_out : {f16, i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.col2im.default : {c64, f32, f64, c128}, + aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, + aten.frexp.Tensor : {bf16, f32, f16, f64}, + aten.grid_sampler_2d.default : {f32, f64}, + aten.grid_sampler_3d.default : {f32, f64}, + aten.histc.default : {bf16, f32, f64}, + aten.histc.out : {bf16, f32, f64}, + aten.histogram.bin_ct : {f32, f64}, + aten.histogram.bins_tensor : {f32, f64}, + aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.log_sigmoid_forward.output : {bf16, f32, f64}, + aten.logcumsumexp.default : {bf16, f32, f64}, + aten.logcumsumexp.out : {bf16, f32, f64}, + aten.max_pool3d_with_indices.default : {f32, f64}, + aten.max_unpool2d.default : {f32, f64}, + aten.max_unpool3d.default : {f32, f64}, + aten.median.default : {i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.median.dim : {i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8}, + aten.multi_margin_loss.default : {f32, f64}, + aten.multilabel_margin_loss_forward.default : {f32, f64}, + aten.multinomial.default : {bf16, f32, f64}, + aten.multinomial.out : {bf16, f32, f64}, + aten.mvlgamma.default : {i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.mvlgamma.out : {i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.nll_loss2d_forward.default : {bf16, f32, f64}, + aten.polar.default : {f32, f64}, + aten.rrelu_with_noise.default : {bf16, f32, f64}, + aten.searchsorted.Tensor : {f16, i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.searchsorted.Tensor_out : {f16, i8, f64, i64, bf16, f32, i32, i16, u8}, + aten.segment_reduce.default : {bf16, f32, f16, f64}, + aten.unique_consecutive.default : {i8, f64, i64, bf16, f32, i32, b8, i16, u8}, + aten.unique_dim.default : {i8, f64, i64, bf16, f32, i32, b8, i16, u8}, + aten.upsample_nearest3d.vec : {bf16, f32, f64, u8}, } # these sometimes pass and sometimes fail meta_dispatch_skips = { - aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32}, # at::nonzero doesn't have a Meta function - aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, + aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128}, # at::nonzero doesn't have a Meta function + aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32, c64, c128}, aten.aminmax.default: {i64, u8, b8, f32, i8, f64, i16, i32}, aten.cummax.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, aten.cummin.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32}, + aten.linalg_lu_solve.default: {c32, c64, c128}, + aten.linalg_lu_solve.out: {c32, c64, c128}, aten.linalg_pinv.atol_rtol_tensor: {f32, f64}, aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64}, aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, + aten.empty.SymInt: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, } meta_dispatch_device_expected_failures = defaultdict(dict) meta_dispatch_device_skips = defaultdict(dict) meta_dispatch_device_expected_failures['cuda'] = { - aten._convolution.default: {f16, c32}, aten._unique2.default: {f16}, # aten::_unique2 aten._use_cudnn_ctc_loss.default: {f32, f64}, # aten::_use_cudnn_ctc_loss - aten.convolution.default: {f16, c32}, aten.cudnn_grid_sampler.default: {f16, f32, f64}, # aten::cudnn_grid_sampler aten.geqrf.default: {f32, f64}, # aten::geqrf aten.grid_sampler_2d.default: {f16}, # aten::grid_sampler_2d @@ -735,19 +768,21 @@ def __torch_function__(self, func, types, args=(), kwargs=None): aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward aten.ormqr.default: {f32, f64}, # aten::ormqr aten.ormqr.out: {f32, f64}, # aten::ormqr.out - aten.prelu.default: {f16}, # aten::prelu aten.rrelu_with_noise.default: {f16}, # aten::rrelu_with_noise aten.tensordot.out: {f16}, # aten::tensordot.out aten.unique_consecutive.default: {f16}, # aten::unique_consecutive aten.unique_dim.default: {f16}, # aten::unique_dim aten.upsample_nearest3d.vec: {f16}, # aten::upsample_nearest3d.vec - aten.vdot.default: {f16}, # aten::vdot - aten.vdot.out: {f16}, # aten::vdot } meta_dispatch_device_skips['cuda'] = { - aten._conj.default: {c32, f16}, + aten._conj.default: {c32, f16}, # file issue + aten._linalg_svd.default: {c64, c128}, # aten::linalg_eigvalsh.out aten.cudnn_batch_norm.default: {f32, f64}, + aten.log_softmax.int : {c32, c64}, + aten.softmax.int : {c32, c64}, + aten.softmax.int : {c32, c64}, + aten.cummax.default: {f16}, aten.cummin.default: {f16}, # ROCm stuff; technically this should be expected failure but it's @@ -795,11 +830,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): device_type=self.device_type, ) - # NB: we're running these tests only on CUDA because there are some # inconsistencies between CUDA and CPU, and running on CUDA makes it easier # to ignore the CPU case when inconsistencies arise. Ideally we deal # with the inconsistencies but this takes time. +@skipIfSlowGradcheckEnv class TestMeta(TestCase): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyCUDA @@ -815,7 +850,7 @@ def test_meta(self, device, dtype, op): for sample_input in samples: args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs - with MetaCrossRefFunctionMode.push(self, dtype=dtype, device=device): + with MetaCrossRefFunctionMode(self, dtype=dtype, device=device): expected = func(*args, **kwargs) if isinstance(expected, torch.Tensor) and op.supports_out: func(*args, **kwargs, out=expected) @@ -831,11 +866,16 @@ def test_dispatch_meta(self, device, dtype, op): for sample_input in samples: args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs + with MetaCrossRefDispatchMode.push(self, dtype=dtype, device=device): expected = func(*args, **kwargs) if isinstance(expected, torch.Tensor) and op.supports_out: func(*args, **kwargs, out=expected) + def test_empty_quantized(self): + r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8) + self.assertEqual(r.device.type, 'meta') + instantiate_device_type_tests(TestMeta, globals()) def print_op_str_if_not_supported(op_str): diff --git a/test/test_mkl_verbose.py b/test/test_mkl_verbose.py new file mode 100644 index 0000000000000..5e6cbda12a2f2 --- /dev/null +++ b/test/test_mkl_verbose.py @@ -0,0 +1,34 @@ +# Owner(s): ["module: unknown"] + +from torch.testing._internal.common_utils import TestCase, run_tests +import os +import subprocess +import sys + +class TestMKLVerbose(TestCase): + def test_verbose_on(self): + num = 0 + loc = os.path.dirname(os.path.abspath(__file__)) + with subprocess.Popen(f'{sys.executable} -u {loc}/mkl_verbose.py --verbose-level=1', shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: + for line in p.stdout.readlines(): + line = str(line, 'utf-8').strip() + if line.startswith("MKL_VERBOSE"): + num = num + 1 + elif line == 'Failed to set MKL into verbose mode. Please consider to disable this verbose scope.': + return + self.assertTrue(num > 0, 'oneMKL verbose messages not found.') + + def test_verbose_off(self): + num = 0 + loc = os.path.dirname(os.path.abspath(__file__)) + with subprocess.Popen(f'{sys.executable} -u {loc}/mkl_verbose.py --verbose-level=0', shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: + for line in p.stdout.readlines(): + line = str(line, 'utf-8').strip() + if line.startswith("MKL_VERBOSE"): + num = num + 1 + self.assertEqual(num, 0, 'unexpected oneMKL verbose messages found.') + +if __name__ == '__main__': + run_tests() diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index cb9eb4828cacc..4f9518d8fd03e 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -136,6 +136,22 @@ def test_unsupported(self): with self.assertRaises(RuntimeError) as context: creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn) + def test_mkldnn_conv_shapecheck(self): + input = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) + w1 = torch.full((1, 1, 1, 24,), 1, dtype=torch.float32) + b1 = torch.full((1,), 1, dtype=torch.float32) + w2 = torch.full((1, 1, 2, 24,), 1, dtype=torch.float32) + b2 = torch.full((2,), 1, dtype=torch.float32) + options = zip([-1, 0, 0, 0, 0, 0, 0], # padding + [1, 0, 1, 1, 1, 1, 1], # stride + [1, 1, 0, 1, 1, 1, 1], # dilation + [1, 1, 1, 0, 2, 1, 1], # groups + [w1, w1, w1, w1, w1, w1, w2], # weight + [b1, b1, b1, b1, b1, b2, b1]) # bias + for pad, st, dil, gr, w, b in options: + with self.assertRaises(RuntimeError) as _: + torch.mkldnn_convolution(input, w, b, [pad] * 2, [st] * 2, [dil] * 2, gr) + def test_autograd_to_mkldnn(self): # MKLDNN only supports float32 root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) diff --git a/test/test_mkldnn_verbose.py b/test/test_mkldnn_verbose.py new file mode 100644 index 0000000000000..b7d8607ee50e1 --- /dev/null +++ b/test/test_mkldnn_verbose.py @@ -0,0 +1,34 @@ +# Owner(s): ["module: unknown"] + +from torch.testing._internal.common_utils import TestCase, run_tests +import os +import subprocess +import sys + +class TestMKLDNNVerbose(TestCase): + def test_verbose_on(self): + num = 0 + loc = os.path.dirname(os.path.abspath(__file__)) + with subprocess.Popen(f'{sys.executable} -u {loc}/mkldnn_verbose.py --verbose-level=1', shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: + for line in p.stdout.readlines(): + line = str(line, 'utf-8').strip() + if line.startswith("onednn_verbose"): + num = num + 1 + elif line == 'Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope.': + return + self.assertTrue(num > 0, 'oneDNN verbose messages not found.') + + def test_verbose_off(self): + num = 0 + loc = os.path.dirname(os.path.abspath(__file__)) + with subprocess.Popen(f'{sys.executable} -u {loc}/mkldnn_verbose.py --verbose-level=0', shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: + for line in p.stdout.readlines(): + line = str(line, 'utf-8').strip() + if line.startswith("onednn_verbose"): + num = num + 1 + self.assertEqual(num, 0, 'unexpected oneDNN verbose messages found.') + +if __name__ == '__main__': + run_tests() diff --git a/test/test_module_init.py b/test/test_module_init.py index a988612971f04..61a9a2feac77d 100644 --- a/test/test_module_init.py +++ b/test/test_module_init.py @@ -226,6 +226,7 @@ def build_constructor_arg_db(): 'factory_kwargs': {}, }), torch.nn.quantized.MaxPool2d: ((3,), {}), + torch.nn.quantized.PReLU: ((0.01, 0), {}), torch.nn.quantized.Quantize: ((0.1, 0), { 'dtype': torch.int16, 'factory_kwargs': {}, @@ -370,10 +371,12 @@ def generate_tests(test_cls, constructor_arg_db): # See https://github.com/pytorch/pytorch/issues/55396 torch.nn.quantized.Embedding, torch.nn.quantized.EmbeddingBag, + torch.nn.quantized.LSTM, + torch.nn.quantized.MultiheadAttention, } # no need to support kwargs for these modules even though # they have parameters / buffers because they are passed in - # already instantiated + # already instantiated s MODULES_WITHOUT_KWARGS_SUPPORT = { torch.nn.BCELoss, torch.nn.BCEWithLogitsLoss, diff --git a/test/test_modules.py b/test/test_modules.py index 3ed5f3be76f37..a62bfff8de698 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -533,6 +533,12 @@ def _to_device(obj): gpu_module = module_cls(*args, **kwargs).to(dtype).to(device) gpu_module.train(training) + # === Lazy modules need to see an input to initialize params === + if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin): + with torch.no_grad(): + cpu_module(*cpu_forward_args, **cpu_forward_kwargs) + gpu_module(*gpu_forward_args, **gpu_forward_kwargs) + for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()): gpu_p.data.copy_(cpu_p) diff --git a/test/test_mps.py b/test/test_mps.py index d5475dab8a87b..c2cf7565995de 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9,17 +9,26 @@ import subprocess import tempfile import os +import pprint import torch import torch.nn as nn import torch.nn.functional as F import itertools +from collections import defaultdict from torch._six import inf from torch.nn import Parameter -from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN +from torch.testing._internal.common_utils import \ + (gradcheck, gradgradcheck, run_tests, TestCase, download_file, + TEST_WITH_UBSAN) +from torch.testing import make_tensor from torch.testing._comparison import TensorLikePair +from torch.testing._internal.common_dtype import get_all_dtypes import torch.backends.mps from torch.distributions import Uniform, Exponential +from functools import partial +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase import numpy as np import torch @@ -366,6 +375,12 @@ def _linear_helper(self, in_features, out_features, shape, bias=True, backward_p self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size()) self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05) + def test_linear1D(self): + self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=False) + + def test_linear1D_backward(self): + self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=True) + def test_linear2D(self): self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=False) @@ -782,7 +797,6 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine) helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine) - def test_instance_norm(self): def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False): @@ -1444,6 +1458,14 @@ def test_to(self): torch.tensor(4, dtype=torch.int32)) self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int), torch.tensor(-8.34, device='cpu').to('mps').to(torch.int)) + # Cast int8 and uint8 to float and compare results + # See https://github.com/pytorch/pytorch/issues/80009 for more details + cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8) + cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.uint8) + for x_cpu in [cpu_byte, cpu_char]: + x_mps = x_cpu.to('mps') + self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32)) + def test_setitem_scalar(self) -> None: device = 'mps' @@ -1468,6 +1490,164 @@ def test_stride_of_strides(self) -> None: z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu") self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z) + def test_type_casting(self): + # https://github.com/pytorch/pytorch/issues/81567 + def helper(data, to_dtype): + a_cpu = torch.tensor(data) + a_mps = a_cpu.to(torch.device('mps')) + + res_cpu = a_cpu.type(to_dtype) + res_mps = a_mps.type(to_dtype) + self.assertEqual(res_cpu, res_mps) + + helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor) + helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor) + + def test_to_casting(self): + # https://github.com/pytorch/pytorch/issues/81567 + def helper(data, to_dtype): + a_cpu = torch.tensor(data) + a_mps = a_cpu.to(torch.device('mps')) + + res_cpu = a_cpu.to(to_dtype) + res_mps = a_mps.to(to_dtype) + self.assertEqual(res_cpu, res_mps) + + helper([9.0, 3.0, 5.0, 4.0], torch.int64) + helper([9.0, 3.0, 5.0, 4.0], torch.float) + helper([9.0, 3.0, 5.0, 4.0], torch.int32) + helper([9.0, 3.0, 5.0, 4.0], torch.short) + helper([9.0, 3.0, 5.0, 4.0], torch.half) + helper([9.0, 3.0, 5.0, 4.0], torch.int8) + helper([9.0, 3.0, 5.0, 4.0], torch.uint8) + + def test_storage_offset_greater_than_src_nbytes(self): + # https://github.com/pytorch/pytorch/issues/80844 + n_tensors = 100 + n_tensor_elems = 784 + elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32) + + tensor_list = [] + for i in range(0, n_tensors - 1): + # create a list of contiguous view tensors (view tensor created by the slice op) + t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)] + tensor_list.append(t) + + for i in range(0, n_tensors - 1): + t = tensor_list[i].view(1, 784) + t_mps = t.to("mps") + self.assertEqual(t, t_mps.cpu()) + + # See https://github.com/pytorch/pytorch/issues/82427 + # Test should not crash + def test_bool_full(self): + x = torch.full((3, 3), True, device='mps') + + # Empty unary op should return tensor of the same size + def test_empty_neg(self): + x = torch.tensor([[]], device='mps') + y = -x + self.assertEqual(x, y) + + +class TestLogical(TestCase): + def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): + return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad) + + def test_logical_not(self): + def helper(x): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + result = torch.logical_not(x) + result_cpu = torch.logical_not(cpu_x) + + self.assertEqual(result, result_cpu) + + helper(self._wrap_tensor([1, 1, 0, 0])) + helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True)) + helper(self._wrap_tensor([True, True, False, False])) + helper(self._wrap_tensor(1)) + helper(self._wrap_tensor(0)) + helper(self._wrap_tensor(True)) + helper(self._wrap_tensor(False)) + + def test_logical_and(self): + def helper(x, other): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + cpu_other = other + other = cpu_other.detach().clone().to('mps') + + result = torch.logical_and(x, other) + result_cpu = torch.logical_and(cpu_x, cpu_other) + self.assertEqual(result, result_cpu) + + helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor(([1, 0, 0, 1]))) + helper( + self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), + self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) + ) + helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) + + def test_logical_or(self): + def helper(x, other): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + cpu_other = other + other = cpu_other.detach().clone().to('mps') + + result = torch.logical_or(x, other) + result_cpu = torch.logical_or(cpu_x, cpu_other) + + self.assertEqual(result, result_cpu) + + helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor(([1, 0, 0, 1]))) + helper( + self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), + self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) + ) + helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) + + def test_logical_xor(self): + def helper(x, other): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + cpu_other = other + other = cpu_other.detach().clone().to('mps') + + result = torch.logical_xor(x, other) + result_cpu = torch.logical_xor(cpu_x, cpu_other) + + self.assertEqual(result, result_cpu) + + helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor(([1, 0, 0, 1]))) + helper( + self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), + self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) + ) + helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) + helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) + class TestSmoothL1Loss(TestCase): @@ -1585,19 +1765,51 @@ def _nll_loss_helper(self, input_size, reduction, expected): output_mps.sum().backward() self.assertEqual(input.grad, input_mps.grad.to('cpu')) + def _nll_loss_1d_helper(self, input_size, reduction): + + # CPU + input = torch.rand(input_size, requires_grad=True, device='cpu') + num_channels = input_size[0] + target = torch.randint(num_channels, [], device='cpu') + + # MPS + input_mps = input.detach().clone().to('mps').requires_grad_() + target_mps = target.detach().clone().to('mps') + + output_cpu = F.nll_loss(input, target, reduction=reduction) + output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu')) + + output_cpu.sum().backward() + output_mps.sum().backward() + self.assertEqual(input.grad, input_mps.grad.to('cpu')) + def test_as_strided(self): - def helper(n, c): - values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] - values_1 = [[1.0, 1.0], [1.0, 1.0]] - cpu_x = torch.tensor(values, device='cpu') - ones1 = torch.tensor(values_1, device='mps') - x = cpu_x.detach().clone().to('mps').requires_grad_() - strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2)) - strided_mps = torch.as_strided(x, (2, 2), (1, 2)) + values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + values_1 = [[1.0, 1.0], [1.0, 1.0]] + cpu_x = torch.tensor(values, device='cpu') + ones1 = torch.tensor(values_1, device='mps') + x = cpu_x.detach().clone().to('mps').requires_grad_() + strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2)) + strided_mps = torch.as_strided(x, (2, 2), (1, 2)) + self.assertEqual(strided_mps, strided_cpu) + strided_cpu_out = strided_cpu + ones1.to('cpu') + strided_mps_out = strided_mps + ones1 + self.assertEqual(strided_cpu_out, strided_mps_out) + + # test with storage offsets + cpu_x = torch.rand(3, 3, device='cpu') + mps_x = cpu_x.to('mps') + strided_cpu1 = torch.as_strided(cpu_x, (2, 2), (1, 2), 0) + strided_mps1 = torch.as_strided(mps_x, (2, 2), (1, 2), 0) + strided_cpu2 = torch.as_strided(cpu_x, (2, 2), (1, 2), 1) + strided_mps2 = torch.as_strided(mps_x, (2, 2), (1, 2), 1) + strided_cpu_out = strided_cpu1 - strided_cpu2 + strided_mps_out = strided_mps1 - strided_mps2 + self.assertEqual(strided_cpu_out, strided_mps_out) - self.assertEqual(strided_mps, strided_cpu) - helper(3, 3) def test_sum_backward(self): def helper(n, c): @@ -1615,6 +1827,11 @@ def helper(n, c): helper(3, 3) + def test_nll_loss_1d(self, device='cpu'): + self._nll_loss_1d_helper([10], "none") + self._nll_loss_1d_helper([10], "mean") + self._nll_loss_1d_helper([10], "sum") + def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'): self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device)) self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device)) @@ -1668,6 +1885,35 @@ def compute_result_and_gradient(reduction, target_dtype): self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu']) self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu']) + # L1 loss + def test_l1_loss(self): + def helper(shape, reduction): + # create the criterion + loss = torch.nn.L1Loss(reduction=reduction) + + inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) + targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() + targetMPS = targetCPU.detach().clone().to('mps') + + # forward pass + outputCPU = loss(inputCPU, targetCPU) + outputMPS = loss(inputMPS, targetMPS) + self.assertEqual(outputCPU, outputMPS) + + # backward pass + if reduction != 'none': + # chose 2 just to make the grad_output > 1 in backward pass + outputCPU.backward(gradient=torch.full_like(outputCPU, 2)) + outputMPS.backward(gradient=torch.full_like(outputMPS, 2)) + self.assertEqual(inputCPU.grad, inputMPS.grad) + + helper([8, 5, 4], 'none') + helper([7, 5, 2, 4], 'sum') + # verify if changes in shape would cause cached graph lookup problems + helper([7, 5, 2, 4, 6], 'sum') + helper([8, 4, 5, 7, 6], 'mean') + # Mean Squared Error def test_mse_loss(self): def helper(shape, reduction): @@ -2030,9 +2276,14 @@ def helper(shape): helper((2, 3, 4, 5)) - # Test forward argmax - def test_argmax(self): - def helper(n, c, h, w, dtype=torch.float32): + # Test forward argmin argmax + def test_argmin_argmax(self): + def helper(n, c, h, w, reduction_type, dtype=torch.float32): + if reduction_type == "max": + arg_reduction_fn = torch.argmax + else: + arg_reduction_fn = torch.argmin + cpu_x = None x = None if(dtype not in [torch.float32, torch.bool]): @@ -2045,46 +2296,50 @@ def helper(n, c, h, w, dtype=torch.float32): cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True) x = cpu_x.detach().clone().to('mps').requires_grad_() - y = torch.argmax(x) - ref_y = torch.argmax(cpu_x) + y = arg_reduction_fn(x) + ref_y = arg_reduction_fn(cpu_x) self.assertEqual(y, ref_y) - y_0 = torch.argmax(x, dim=0) - refy_0 = torch.argmax(cpu_x, dim=0) + y_0 = arg_reduction_fn(x, dim=0) + refy_0 = arg_reduction_fn(cpu_x, dim=0) self.assertEqual(y_0, refy_0) - y_0dim = torch.argmax(x, dim=0, keepdim=True) - refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True) + y_0dim = arg_reduction_fn(x, dim=0, keepdim=True) + refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True) self.assertEqual(y_0dim, refy_0dim) - y_1 = torch.argmax(x, dim=1) - refy_1 = torch.argmax(cpu_x, dim=1) + y_1 = arg_reduction_fn(x, dim=1) + refy_1 = arg_reduction_fn(cpu_x, dim=1) self.assertEqual(y_1, refy_1) - y_1dim = torch.argmax(x, dim=1, keepdim=True) - refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True) + y_1dim = arg_reduction_fn(x, dim=1, keepdim=True) + refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True) self.assertEqual(y_1dim, refy_1dim) - y_2 = torch.argmax(x, dim=2) - refy_2 = torch.argmax(cpu_x, dim=2) + y_2 = arg_reduction_fn(x, dim=2) + refy_2 = arg_reduction_fn(cpu_x, dim=2) self.assertEqual(y_2, refy_2) - y_2dim = torch.argmax(x, dim=2, keepdim=True) - refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True) + y_2dim = arg_reduction_fn(x, dim=2, keepdim=True) + refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True) self.assertEqual(y_2dim, refy_2dim) - y_3 = torch.argmax(x, dim=3) - refy_3 = torch.argmax(cpu_x, dim=3) + y_3 = arg_reduction_fn(x, dim=3) + refy_3 = arg_reduction_fn(cpu_x, dim=3) self.assertEqual(y_3, refy_3) - y_3dim = torch.argmax(x, dim=3, keepdim=True) - refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True) + y_3dim = arg_reduction_fn(x, dim=3, keepdim=True) + refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True) self.assertEqual(y_3dim, refy_3dim) - helper(2, 8, 4, 4, torch.float32) - helper(2, 8, 4, 4, torch.int32) - helper(2, 8, 4, 4, torch.float16) - helper(2, 8, 4, 4, torch.int64) + helper(2, 8, 4, 4, "max", torch.float32) + helper(2, 8, 4, 4, "max", torch.int32) + helper(2, 8, 4, 4, "max", torch.float16) + helper(2, 8, 4, 4, "max", torch.int64) + helper(2, 8, 4, 4, "min", torch.float32) + helper(2, 8, 4, 4, "min", torch.int32) + helper(2, 8, 4, 4, "min", torch.float16) + helper(2, 8, 4, 4, "min", torch.int64) # Test forward max # Note - don't test grad now @@ -2659,6 +2914,8 @@ def helper(shape): self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu) helper((4, 5, 6, 7)) + # verify if a change in shape of input would cause problems with graph caching + helper((9, 5, 6, 7)) # Test var def test_var(self): @@ -2757,6 +3014,8 @@ def helper(shape): self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu) helper((4, 5, 6, 7)) + # verify if a change in shape of input would cause problems with graph caching + helper((9, 5, 6, 7)) # Test forward amax def test_amax(self): @@ -3105,6 +3364,20 @@ def helper(N, C, H, W): helper(1, 1, 4, 4) helper(7, 5, 3, 2) + def test_upsample_nearest1d(self): + def helper(N, C, H, W): + inputCPU = torch.arange(C * H * W, device='cpu', dtype=torch.float, + requires_grad=True).reshape(C, H, W) + inputMPS = inputCPU.detach().clone().to('mps') + + outputCPU = torch.nn.functional.interpolate(inputCPU, scale_factor=2.0, mode='nearest') + outputMPS = torch.nn.functional.interpolate(inputMPS, scale_factor=2.0, mode='nearest') + + self.assertEqual(outputCPU, outputMPS) + + helper(1, 1, 4, 4) + helper(7, 5, 3, 2) + # Test concat forward def test_cat1(self): def helper(shape_x, shape_y, shape_z): @@ -3123,16 +3396,49 @@ def helper(shape_x, shape_y, shape_z): self.assertEqual(cat, cat_cpu) helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]) - # Empty test - Currently failing! Empty tensor not handled! - # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5]) + helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]) + helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5]) + helper([2, 2, 6, 5], [0], [2, 5, 6, 5]) + helper([0], [2, 3, 6, 5], [2, 5, 6, 5]) + helper([2, 3, 4, 5], [2, 5, 4, 5], [0]) + helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5]) + helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]) + helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5]) + + def test_constant_pad(self): + m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5) + input_cpu = torch.randn(1, 16, 16, 16) + input_mps = input_cpu.detach().clone().to("mps") + r_cpu = m(input_cpu) + r_mps = m(input_mps) + self.assertEqual(r_cpu, r_mps.to("cpu")) + + def test_circular_pad(self): + # https://github.com/pytorch/pytorch/issues/80856 + k_cpu = torch.ones(3, 3, 9, 9) + k_mps = k_cpu.detach().clone().to("mps") + + x_cpu = torch.rand(1, 3, 32, 32) + x_mps = x_cpu.detach().clone().to("mps") + + x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular') + x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular') + + y_cpu = F.conv2d(x_pad_cpu, k_cpu) + y_mps = F.conv2d(x_pad_mps, k_mps) + + self.assertEqual(y_cpu, y_mps.cpu()) def test_pad(self): - def helper(shape, padding, op): + def helper(shape, padding, op, value=0): inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) inputCPU.retain_grad() inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() - padCriteria = op(padding) + if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]): + padCriteria = op(padding, value) + else: + padCriteria = op(padding) outputCPU = padCriteria(inputCPU) outputMPS = padCriteria(inputMPS) self.assertEqual(outputCPU, outputMPS) @@ -3148,6 +3454,8 @@ def helper(shape, padding, op): helper((2, 4, 4), (1, 3), nn.ReflectionPad1d) # Replication 1D helper((2, 1, 6), 3, nn.ReplicationPad1d) + # Constant Pad 1D + helper((2, 3, 4), 2, nn.ConstantPad1d) # 2D Padding helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d) @@ -3157,24 +3465,46 @@ def helper(shape, padding, op): helper((2, 1, 6, 8), 2, nn.ReplicationPad2d) # verify if a change in shape of padding would cause problems with graph caching helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d) + # Constant Pad 2D + helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d) # 3D Padding helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d) # verify if a change in shape of padding would cause problems with graph caching helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d) + # Constant Pad 3D + helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) # Test stack forward def test_stack(self): # All shapes must be same - def helper(shape): - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - x = cpu_x.detach().clone().to('mps') + def helper(shape, dtype=torch.float32): - cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - y = cpu_y.detach().clone().to('mps') + x, cpu_x = None, None + y, cpu_y = None, None + z, cpu_z = None, None - cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - z = cpu_z.detach().clone().to('mps') + if(dtype not in [torch.float32, torch.bool]): + cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) + y = cpu_y.detach().clone().to('mps') + cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) + z = cpu_z.detach().clone().to('mps') + elif (dtype == torch.bool): + cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) + y = cpu_y.detach().clone().to('mps') + cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) + z = cpu_z.detach().clone().to('mps') + else: + cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) + x = cpu_x.detach().clone().to('mps').requires_grad_() + cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) + y = cpu_y.detach().clone().to('mps').requires_grad_() + cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) + z = cpu_z.detach().clone().to('mps').requires_grad_() stack = torch.stack([x, y, z], dim=1) stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1) @@ -3182,6 +3512,10 @@ def helper(shape): self.assertEqual(stack, stack_cpu) helper([2, 8, 4, 5]) + helper([2, 8, 4, 5], dtype=torch.float16) + helper([2, 8, 4, 5], dtype=torch.int32) + helper([2, 8, 4, 5], dtype=torch.int64) + helper([2, 8, 4, 5], dtype=torch.bool) # Empty test - Currently failing! Empty tensor not handled! # helper([0, 2, 4, 5]) @@ -3397,8 +3731,30 @@ def helper(shape, alpha=1.0): for alpha in [0.000001, 1.0, 2.3, 0.34, 23]: helper(shape, alpha) - # Test softplus + # Test glu + def test_glu(self): + def helper(shape, dim=0): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) + x = cpu_x.detach().clone().to('mps').requires_grad_() + + for activation_func in [torch.nn.GLU(dim=dim)]: + glu_result = activation_func(x) + glu_result_cpu = activation_func(cpu_x) + + cpu_grad = torch.randn(glu_result_cpu.shape) + grad = cpu_grad.to('mps') + glu_result.backward(gradient=grad) + glu_result_cpu.backward(gradient=cpu_grad) + + self.assertEqual(glu_result, glu_result_cpu) + self.assertEqual(x.grad, cpu_x.grad) + + for shape in [[4], (2, 4), (2, 8, 4, 6)]: + for dim in range(len(shape)): + helper(shape, dim) + + # Test softplus def test_softplus(self): def helper(shape): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) @@ -3407,7 +3763,14 @@ def helper(shape): softplus_result = torch.nn.Softplus(beta=0.5, threshold=0.5)(x) softplus_result_cpu = torch.nn.Softplus(beta=0.5, threshold=0.5)(cpu_x) + cpu_grad = torch.randn(softplus_result.shape) + grad = cpu_grad.to('mps') + + softplus_result.backward(gradient=grad) + softplus_result_cpu.backward(gradient=cpu_grad) + self.assertEqual(softplus_result, softplus_result_cpu) + self.assertEqual(x.grad, cpu_x.grad) # Test empty shape too for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]: @@ -3689,6 +4052,54 @@ def helper(shape): helper((2, 8, 4, 5)) + # Test index add + def test_index_add(self): + def helper(shape, dim, index, source_shape, alpha, idx_dtype=torch.int32): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + + cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype) + idx = cpu_idx.detach().clone().to('mps') + + cpu_source = torch.randn(source_shape, device='cpu', dtype=torch.float, requires_grad=False) + source = cpu_source.detach().clone().to('mps') + + idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha) + idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha) + self.assertEqual(idx_result, idx_result_cpu) + + helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5) + helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0) + helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5) + helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0) + helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4) + helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0) + # test result dim=1 + helper((2,), 0, [1], (1,), 6.0) + helper(2, 0, 1, 1, 6) + + # Test flip + def test_flip(self): + def helper(shape, dims): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + + flip_result = torch.flip(x, dims=dims) + flip_result_cpu = torch.flip(cpu_x, dims=dims) + + self.assertEqual(flip_result, flip_result_cpu) + + helper((2, 8, 4, 5), [0]) + helper((8, 8, 4, 5), [0, 1]) + helper((2, 8, 4, 5), (0, 1, 2, 3)) + helper((2, 3, 3), (-1,)) + # empty dims + helper((2, 8, 4, 5), []) + # input.numel() == 1 + helper((1,), (0,)) + # input.numel() == 0 + helper((0,), (0,)) + # Test index select def test_index_select(self): def helper(shape, dim, index, idx_dtype=torch.int32): @@ -4108,9 +4519,6 @@ def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float): # Test normal def test_normal(self): def helper(shape, mean=0.0, std=1.0): - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - x = cpu_x.detach().clone().to('mps') - mps_out = torch.normal(mean, std, shape, device='mps') mean_array = np.ones(shape) @@ -4123,6 +4531,7 @@ def helper(shape, mean=0.0, std=1.0): cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False) std_tensor = cpu_std_tensor.detach().clone().to('mps') + # test out mps_out = torch.zeros(shape, device='mps') torch.normal(mean_tensor, std, out=mps_out) @@ -4132,14 +4541,22 @@ def helper(shape, mean=0.0, std=1.0): mps_out = torch.zeros(shape, device='mps') torch.normal(mean_tensor, std_tensor, out=mps_out) + # test without out + mps_out = torch.normal(mean_tensor, std) + self.assertEqual(mps_out.size(), mean_tensor.size()) + + mps_out = torch.normal(mean, std_tensor) + self.assertEqual(mps_out.size(), std_tensor.size()) + + inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size()) + mps_out = torch.normal(mean_tensor, std_tensor) + self.assertEqual(mps_out.size(), inferred_shape) + helper((2, 3, 4, 5, 6)) helper((100, 100), 2.5, 1.2) def test_bernoulli(self): def helper(shape, prob=0.5): - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - x = cpu_x.detach().clone().to('mps') - prob_array = np.ones(shape) prob_array *= prob cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False) @@ -4255,6 +4672,12 @@ def helper(alpha): helper(0.1) helper(0.2) + # Test int32 tensor + int64 scalar add + # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534 + x = torch.ones(4, dtype=torch.int32, device='mps') + self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps')) + self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps'))) + def test_types_binary_op(self): # Float * Bool cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu") @@ -4591,6 +5014,845 @@ def maybe_transpose(cond, m): m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) +class TestGatherScatter(TestCase): + def test_slicing_with_step(self): + # Slicing with step + # https://github.com/pytorch/pytorch/issues/78886 + x_mps = torch.zeros(10, dtype=torch.float32, device="mps") + x_mps[::2] = 1.0 + + x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu") + x_cpu[::2] = 1.0 + + self.assertEqual(x_cpu, x_mps) + + def test_slicing_replace_column(self): + # https://github.com/pytorch/pytorch/issues/78074 + def _helper(tensor_data): + x_cpu = torch.tensor(tensor_data) + x_mps = x_cpu.to('mps') + + x_cpu[:, 0] = 7 + x_mps[:, 0] = 7 + + self.assertEqual(x_cpu, x_mps) + + _helper([[1, 2, 3], [4, 5, 6]]) + _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) + + def test_inplace_scatter(self): + # https://github.com/pytorch/pytorch/issues/79672 + a_mps = torch.ones((2, 2),).to(torch.device("mps")) + b_mps = torch.ones((2, 2),).to(torch.device("mps")) + + a_cpu = torch.ones((2, 2),).to(torch.device("cpu")) + b_cpu = torch.ones((2, 2),).to(torch.device("cpu")) + + a_mps[:, 0] += b_mps[:, 0] + a_cpu[:, 0] += b_cpu[:, 0] + self.assertEqual(a_cpu, a_mps) + + a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0] + a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0] + self.assertEqual(a_cpu, a_mps) + +# These tests were taken from test/test_view_ops.py +# They are subset of those tests as currently only this subset is working. +# This whole `class` will be removed when we add generic device testing. There +# are no additional tests added apart from what is part of test_view_ops.py +class TestViewOpsMPS(TestCase): + exact_dtype = True + + def is_view_of(self, base, other): + if (not other._is_view() or + other is base or + other._base is not base or + base.device != other.device): + return False + # Note: only validates storage on native device types + # because some accelerators, like XLA, do not expose storage + if base.device.type == 'mps': + if base.storage().data_ptr() != other.storage().data_ptr(): + return False + + return True + + # Returns true if v1 and v2 are views of the same base + def is_view_of_same_base(self, v1, v2): + if (not v1._is_view() or v1 is v2): + return False + return self.is_view_of(v1._base, v2) + + # Performs transpose if contiguous=True, else returns the input tensor as is + def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): + if contiguous: + return x + else: + return x.transpose(dim0, dim1) + + def test_diagonal_view(self, device="mps"): + t = torch.ones((5, 5), device=device) + v = torch.diagonal(t) + self.assertTrue(self.is_view_of(t, v)) + + v[0] = 0 + self.assertEqual(t[0, 0], v[0]) + + t = torch.ones((3, 3, 3), device="mps") + v = torch.diagonal(t, offset=1, dim1=1, dim2=2) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0, 1], v[0, 0]) + + def test_select_view(self, device="mps") -> None: + t = torch.ones((5, 5), device=device) + v = t.select(0, 2) + self.assertTrue(self.is_view_of(t, v)) + + v[0] = 0 + self.assertEqual(t[2, 0], v[0]) + + def test_unbind_view(self, device="mps") -> None: + t = torch.zeros((5, 5), device=device) + tup = torch.unbind(t) + + for idx, v in enumerate(tup): + self.assertTrue(self.is_view_of(t, v)) + + v[0] = idx + 1 + self.assertEqual(t[idx, 0], v[0]) + + def test_expand_view(self, device="mps") -> None: + t = torch.ones((5, 1), device=device) + v = t.expand(5, 5) + self.assertTrue(self.is_view_of(t, v)) + + v[2, 2] = 0 + self.assertEqual(t[2, 0], v[2, 2]) + + def test_expand_as_view(self, device="mps"): + t = torch.ones((5, 1), device=device) + e = torch.empty((5, 5), device=device) + v = t.expand_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[2, 2] = 0 + self.assertEqual(t[2, 0], v[2, 2]) + + def test_narrow_view(self, device="mps"): + t = torch.ones((5, 5), device=device) + v = torch.narrow(t, 1, 2, 2) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 2], v[0, 0]) + + def test_permute_view(self, device="mps") -> None: + t = torch.ones((5, 5), device=device) + v = t.permute(1, 0) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_transpose_view(self, device="mps"): + for fn in (torch.swapdims, torch.swapaxes, torch.transpose): + t = torch.ones((5, 5), device=device) + v = fn(t, 0, 1) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_transpose_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.swapdims_(0, 1) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.swapaxes_(0, 1) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.transpose_(0, 1) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_t_view(self, device="mps"): + t = torch.ones((5, 5), device=device) + v = t.t() + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_t_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.t_() + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_T_view(self, device="mps"): + for op in ("T", "H", "mT", "mH"): + t = torch.ones((5, 5), device=device) + v = getattr(t, op) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + # requires aten::unfold + # def test_unfold_view(self, device="mps"): + # t = torch.ones(10, device=device) + # v = t.unfold(0, 3, 2) + # self.assertTrue(self.is_view_of(t, v)) + + # v[1, 0] = 0 + # self.assertEqual(t[2], v[1, 0]) + + def test_squeeze_view(self, device="mps"): + t = torch.ones(5, 1, 5, device=device) + v = torch.squeeze(t) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertTrue(t is v._base) + + def test_squeeze_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.squeeze_() + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertTrue(t is v._base) + + def test_unsqueeze_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.unsqueeze(t, 1) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0, 1] = 0 + self.assertEqual(t[0, 1], v[0, 0, 1]) + + def test_unsqueeze_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.unsqueeze_(1) + self.assertTrue(self.is_view_of(t, v)) + v[0, 0, 1] = 0 + self.assertEqual(t[0, 1], v[0, 0, 1]) + + def test_as_strided_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.as_strided(t, (25,), (1,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_as_strided_inplace_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view_as(t) + v = v.as_strided_((25,), (1,)) + self.assertTrue(self.is_view_of(t, v)) + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t.view(25) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_as_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,)) + v = t.view_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_contiguous_self(self, device="mps"): + t = torch.ones(5, 5, device=device) + s = t.contiguous() + self.assertTrue(s is t) + + def test_contiguous_nonview(self, device="mps"): + t = torch.ones(5, 5, device=device) + nv = t.t().contiguous() + self.assertTrue(not self.is_view_of(t, nv)) + + nv[0, 0] = 0 + self.assertNotEqual(t[0, 0], nv[0, 0]) + + def test_reshape_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = torch.reshape(t, (25,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_as_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,), device=device) + v = t.reshape_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_nonview(self, device="mps"): + t = torch.ones(5, 5, device=device) + nv = torch.reshape(t.t(), (25,)) + self.assertTrue(not self.is_view_of(t, nv)) + + nv[6] = 0 + self.assertNotEqual(t[1, 1], nv[6]) + + def test_flatten_view(self, device="mps"): + def test_writes_propagate(t, v): + idx_t = (0,) * t.ndim + idx_v = (0,) * v.ndim + v[idx_v] = 0 + self.assertEqual(t[idx_t], v[idx_v]) + + t = torch.ones(1, 2, 3, 4, device=device) + v = t.flatten() + self.assertTrue(self.is_view_of(t, v)) + test_writes_propagate(t, v) + + # zero-dimensional tensor + t = torch.tensor(1, device=device) + v = t.flatten() + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of(t, v)) + + t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) + v = t.flatten(0, 1) + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of_same_base(t, v)) + + # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: + t = torch.ones(720, device=device) \ + .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)) + # [--1--|---2---|-3-] [--1--|----2---|-3-] + v1 = t.flatten(0, 1) + v2 = v1.flatten(1, 3) + v3 = v2.flatten(2, 2) + test_writes_propagate(t, v1) + self.assertTrue(self.is_view_of_same_base(t, v1)) + test_writes_propagate(t, v2) + self.assertTrue(self.is_view_of_same_base(t, v2)) + test_writes_propagate(t, v3) + self.assertTrue(self.is_view_of_same_base(t, v3)) + + def test_flatten_nonview(self, device="mps"): + def assert_is_nonview(t, nv): + idx_t = (0,) * t.ndim + idx_nv = (0,) * nv.ndim + self.assertTrue(not nv._is_view()) + nv[idx_nv] = 0 + self.assertNotEqual(t[idx_t], nv[idx_nv]) + t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) + nv = t.flatten(1, 3) + assert_is_nonview(t, nv) + + t = torch.ones(2, 2, device=device).T + nv = t.flatten() + assert_is_nonview(t, nv) + + # flatten returns the original object if start_dim=end_dim + t = t = torch.ones(2, 2, device=device) + nv = t.flatten(1, 1) + self.assertTrue(t is nv) + + def test_basic_indexing_slice_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[:2, :3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_ellipses_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[..., :2] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_newaxis_view(self, device="mps"): + t = torch.ones(5, 5, device=device) + v = t[None, :2, 3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 3], v[0, 0]) + + def test_chunk_view(self, device="mps"): + t = torch.zeros(3, 3, device=device) + l = torch.chunk(t, 3) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + def test_split_view(self, device="mps"): + t = torch.zeros(3, 3, device=device) + l = torch.split(t, [1, 1, 1]) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + def test_movedim_view(self, device="mps"): + def run_test(device, op): + t = torch.zeros(3, 3, device=device) + out = op(t) + + self.assertTrue(self.is_view_of(t, out)) + + # Randomly change values in output + # and verify that original is changed + # as well. + for _ in range(3): + idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) + out[idx_1, idx_2] = random.random() + self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) + + for fn in [torch.movedim, torch.moveaxis]: + op = partial(fn, source=(0, 1), destination=(1, 0)) + run_test(device, op) + + op = partial(fn, source=0, destination=1) + run_test(device, op) + + # Testing that the generated view_copy kernel and its derivative are implemented correctly + def test_view_copy(self, device="mps"): + a = torch.randn(4, device=device, requires_grad=True) + a_ref = a.clone().detach().requires_grad_() + a_view = a_ref.view(2, 2) + a_view_copy = torch.view_copy(a, (2, 2)) + + # view_copy ops don't preserve view relationship + self.assertTrue(self.is_view_of(a_ref, a_view)) + self.assertFalse(self.is_view_of(a, a_view_copy)) + + a_view_copy.sum().backward() + a_view.sum().backward() + + # forward and backward give the same shape + result + self.assertEqual(a_view_copy, a_view) + self.assertEqual(a.grad, a_ref.grad) + + def test_view_copy_out(self, device="mps"): + a = torch.randn(2, 2, device=device) + out = torch.empty(2, device=device) + + torch.diagonal_copy(a, out=out) + expected = torch.diagonal_copy(a) + + self.assertEqual(expected, out) + + a = torch.randn(4, device=device) + out1 = torch.empty(2, device=device) + out2 = torch.empty(2, device=device) + + torch.split_copy(a, 2, out=(out1, out2)) + expected1, expected2 = torch.split_copy(a, 2) + + self.assertEqual(expected1, out1) + self.assertEqual(expected2, out2) + + def test_empty_reshape(self, device="mps"): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) + # should be viewable -- i.e. data_ptr is the same. + self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) + + # match NumPy semantics -- don't infer the size of dimension with a degree of freedom + self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) + + def test_expand(self, device="mps"): + tensor = torch.rand(1, 8, 1, device=device) + tensor2 = torch.rand(5, device=device) + template = torch.rand(4, 8, 5, device=device) + target = template.size() + self.assertEqual(tensor.expand_as(template).size(), target) + self.assertEqual(tensor.expand(4, 8, 5).size(), target) + self.assertEqual(tensor.expand(target).size(), target) + self.assertEqual(tensor2.expand_as(template).size(), target) + self.assertEqual(tensor2.expand(4, 8, 5).size(), target) + self.assertEqual(tensor2.expand(target).size(), target) + + # test double expand + self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) + + # test non-contiguous + noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] + self.assertFalse(noncontig.is_contiguous()) + self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) + + # make sure it's compatible with unsqueeze + expanded = tensor2.expand(1, 1, 5) + unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) + self.assertEqual(expanded, unsqueezed) + self.assertEqual(expanded.stride(), unsqueezed.stride()) + + # test -1 as target size + self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) + self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) + + # test expanding empty to empty + self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)) + + def test_view_empty(self, device="mps"): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) + + def test_reshape(self, device="mps"): + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) + self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) + self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + + y = torch.randn(4, 4, 4, device=device)[:, 0, :] + # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape + if device != "meta": + self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) + self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) + self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) + + s = torch.randn((), device=device) + self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) + self.assertEqual(s.reshape(-1).shape, (1,)) + self.assertRaises(RuntimeError, lambda: s.reshape(2)) + + empty = torch.tensor([], device=device) + self.assertEqual(empty, empty.reshape(-1)) + self.assertEqual(empty, empty.reshape([0])) + # TODO: fix these once we have multi-dimensional empty tensors + self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) + self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) + self.assertRaises(RuntimeError, lambda: empty.reshape(1)) + + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) + self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))) + + def test_narrow(self, device="mps"): + x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]])) + self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]])) + self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]])) + self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]])) + self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]])) + + def test_narrow_tensor(self, device="mps"): + x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]])) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor(0.), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0]), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0, 1]), 1) + + def test_t(self, device="mps"): + # Test 0D tensors + x = torch.randn(()) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 1D tensors + x = torch.arange(4) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 2D tensors + x = torch.rand((2, 2)) + self.assertEqual(x.t(), x.transpose(0, 1)) + x = x.to_sparse() + self.assertEqual(x.t(), x.transpose(0, 1)) + + # Test 3D tensor + x = torch.rand((2, 2, 2)) + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): + x.t() + x = x.to_sparse() + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): + x.t() + + def test_split(self, device="mps"): + tensor = torch.rand(7, 4) + split_size = 3 + dim = 0 + target_sizes = ([3, 4], [3, 4], [1, 4]) + splits = tensor.split(split_size, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + # Variable sections split + tensor = torch.randn(20, 10) + dim = 0 + split_sizes = [5, 5, 10] + target_sizes = ([[5, 10], [5, 10], [10, 10]]) + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + split_sizes = [2, 2, 6] + target_sizes = ([20, 2], [20, 2], [20, 6]) + dim = 1 + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + def test_chunk(self, device="mps"): + tensor = torch.rand(4, 7) + num_chunks = 3 + dim = 1 + target_sizes = ([4, 3], [4, 3], [4, 1]) + splits = tensor.chunk(num_chunks, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, + atol=0, rtol=0) + start = start + target_size[dim] + + # Invalid chunk sizes + error_regex = 'chunk expects.*greater than 0' + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(0) + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(-2) + + def test_unsqueeze(self, device="mps") -> None: + x = torch.randn(2, 3, 4) + y = x.unsqueeze(1) + self.assertEqual(y, x.view(2, 1, 3, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.view(2, 3, 1, 4)) + + x = x[:, 1] + self.assertFalse(x.is_contiguous()) + y = x.unsqueeze(1) + self.assertEqual(y, x.contiguous().view(2, 1, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.contiguous().view(2, 4, 1)) + + # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) + def test_big_transpose(self, device="mps"): + t = torch.rand(456, 789, device=device) + t1 = t.t().contiguous() + t2 = torch.from_numpy(t.cpu().numpy().transpose()) + self.assertEqual(t1, t2) + + def test_T(self, device="mps"): + a = torch.randn(2, 3, 4, device=device) + t1 = a.T + t2 = a.permute(2, 1, 0) + self.assertEqual(t2, t1) + b = torch.randn(10, device=device) + self.assertEqual(b, b.T) + scalar = torch.tensor(5, device=device) + self.assertEqual(scalar, scalar.T) + + def test_transposes(self, device="mps", dtype=torch.float32): + for op in ("T", "H", "mT", "mH", "adjoint"): + shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),) + for shape in shapes: + a = make_tensor(shape, device=device, dtype=dtype) + t1 = getattr(a, op) + if op == "adjoint": + t1 = t1() + t2 = a + if a.ndim != 0: + t2 = t2.transpose(-2, -1) + if op[-1] == "H" or op == "adjoint": + t2 = t2.conj() + self.assertEqual(t2, t1) + + def test_transposes_errors(self, device="mps", dtype=torch.float32): + for op in ("H", "mT", "mH", "adjoint"): + shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) + for shape in shapes: + a = make_tensor(shape, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "only supported on matrices"): + t1 = getattr(a, op) + if op == "adjoint": + t1 = t1() + + def test_python_types(self, device="mps"): + a1 = torch.randn((1, 2), device=device, dtype=torch.float32) + a2 = torch.randn((1, 2), device=device, dtype=torch.float32) + self.assertEqual(a1.dtype, a2.dtype) + + b1 = torch.arange(10, 20, dtype=torch.int64, device=device) + b2 = torch.arange(10, 20, dtype=int, device=device) + self.assertEqual(b1.dtype, b2.dtype) + + c1 = torch.tensor([True, False], dtype=torch.bool, device=device) + c2 = torch.tensor([True, False], dtype=bool, device=device) + self.assertEqual(c1.dtype, c2.dtype) + + # TODO: is resize best put in test_view_ops? + def test_resize_as_preserves_strides(self, device="mps"): + x = torch.empty(2, 3).t() + old_strides = x.stride() + x.resize_as_(x) + self.assertEqual(x.stride(), old_strides) + + def test_memory_format_resize_as(self, device="mps"): + def test_helper(shape, memory_format, device="mps"): + xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) + flat = torch.randn(xc.numel(), device=device) + flat.resize_as_(xc, memory_format=torch.preserve_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), torch.channels_last, device="mps") + test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps") + + def test_memory_format_resize_(self, device="mps"): + def test_helper(shape, numel, memory_format, device="mps"): + flat = torch.randn(numel, device=device) + flat.resize_(shape, memory_format=memory_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps") + test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps") + + # TODO: OpInfo this + def _test_atleast(self, device, torch_fn): + # 0-dim + s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) + + gradcheck(lambda x: torch_fn(x), s) + gradgradcheck(lambda x: torch_fn(x), s) + + # 1-dim + a = torch.rand(4, dtype=torch.double, requires_grad=True) + + gradcheck(lambda x: torch_fn(x), a) + gradgradcheck(lambda x: torch_fn(x), a) + + # 2,3,4-dim + b = torch.rand(4, 3, dtype=torch.double, requires_grad=True) + c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True) + d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True) + + input_tuple = (s, a, b, c, d) + gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) + gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) + + def test_atleast_gradient(self, device="mps"): + self._test_atleast(device, torch.atleast_1d) + self._test_atleast(device, torch.atleast_2d) + self._test_atleast(device, torch.atleast_3d) + + def test_view(self, device="mps"): + tensor = torch.rand(15, device=device) + template = torch.rand(3, 5, device=device) + empty = torch.empty(0, device=device) + target = template.size() + self.assertEqual(tensor.view_as(template).size(), target) + self.assertEqual(tensor.view(3, 5).size(), target) + self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) + self.assertEqual(tensor.view(-1, 5).size(), target) + self.assertEqual(tensor.view(3, -1).size(), target) + tensor_view = tensor.view(5, 3) + tensor_view.fill_(random.uniform(0, 1)) + self.assertEqual(empty.view_as(empty), empty) + self.assertEqual(empty.view(0), empty) + self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) + self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) + + # test size inference with empty tensors + self.assertEqual(empty.view(-1).size(), torch.Size([0])) + self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(-1, 0) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(3, 0, -1, 0) + + self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) + self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) + self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) + + # RuntimeError: Invalid device for storage: mps + def test_contiguous(self, device="mps"): + x = torch.randn(1, 16, 5, 5, device=device) + self.assertTrue(x.is_contiguous()) + stride = list(x.stride()) + stride[0] = 20 + # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 + x.set_(x.storage(), 0, x.size(), stride) + self.assertTrue(x.is_contiguous()) + + def test_resize_all_dtypes_and_devices(self, device="mps"): + shape = (2, 2) + for dt in (torch.half, torch.bfloat16, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + x.resize_(shape) + self.assertEqual(shape, x.shape) + + def test_resize_as_all_dtypes_and_devices(self, device="mps"): + for dt in (torch.half, torch.bfloat16, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) + x.resize_as_(y) + self.assertEqual(y.shape, x.shape) + + def test_resize_overflow(self, device="mps"): + x = torch.empty((), dtype=torch.float64) + with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'): + x.resize_([2, 4, 2**29, 2**29]) + with self.assertRaisesRegex(RuntimeError, 'overflow'): + x.resize_([8, 8, 2**29, 2**29]) + + def test_view_all_dtypes_and_devices(self, device="mps"): + for dt in (torch.float, torch.bool): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + self.assertEqual(x.view(6).shape, [6]) class TestRNNMPS(TestCase): def test_lstm_1(self, device="mps", dtype=torch.float32): @@ -4600,15 +5862,34 @@ def test_lstm_1(self, device="mps", dtype=torch.float32): hx = torch.zeros(2, 3, 4, device="cpu") cx = torch.zeros(2, 3, 4, device="cpu") - cpu_output, _ = rnn(input, (hx, cx)) + cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) - device = torch.device("mps") rnn = rnn.to(device) input = input.to(device) hx = hx.to(device) cx = cx.to(device) - output, _ = rnn(input, (hx, cx)) + output, (hn, cn) = rnn(input, (hx, cx)) + self.assertEqual(cpu_output, output) + self.assertEqual(cpu_hn, hn) + self.assertEqual(cpu_cn, cn) + + # test batch_first + rnn = nn.LSTM(1, 4, 2, device="cpu", batch_first=True) + input = torch.randn(3, 2, 1, device="cpu") + hx = torch.zeros(2, 3, 4, device="cpu") + cx = torch.zeros(2, 3, 4, device="cpu") + cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) + + rnn = rnn.to(device) + input = input.to(device) + hx = hx.to(device) + cx = cx.to(device) + output, (hn, cn) = rnn(input, (hx, cx)) + + self.assertEqual(cpu_output, output) + self.assertEqual(cpu_hn, hn) + self.assertEqual(cpu_cn, cn) @unittest.skipIf(True, "Backward of lstm returns wrong result") def test_lstm_2(self, device="mps", dtype=torch.float32): @@ -4710,8 +5991,11 @@ def test_assert_close(self): with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): torch.testing.assert_close(a, inf) - with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): - torch.testing.assert_close(a, nan) + # TODO: The NaN test is failing when all the tests in test_mps are run + # together but passes when run separately. There seems to be memory + # corruption which needs to be fixed for this test to be enabled. + # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): + # torch.testing.assert_close(a, nan) @unittest.expectedFailure def test_mps_compat(self): @@ -4775,7 +6059,612 @@ def test_serialization_map_location(self): self.assertEqual(x2.device.type, "cpu") +MPS_DTYPES = get_all_dtypes() +for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]: + del MPS_DTYPES[MPS_DTYPES.index(t)] + +class TestConsistency(TestCase): + # TODO: This is only used while some ops are being added. + # This list should contain all ops and dtypes eventually + # This can be generated automatically in the `new_mps_allowlist.txt` file + # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU` + # You most likely do NOT want to modify this manually + ALLOWLIST_OP = { + '__radd__': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rand__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rmul__': ['torch.bool', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__ror__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '__rxor__': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + '_masked.normalize': ['torch.float32'], + 'abs': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.uint8'], + 'add': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'addcdiv': ['torch.float32'], + 'addcmul': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'addmv': ['torch.float32'], + 'addr': ['torch.float32'], + 'all': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'any': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'argmax': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'asin': ['torch.float32'], + 'asinh': ['torch.float32'], + 'atan': ['torch.float32'], + 'atan2': ['torch.float32'], + 'atanh': ['torch.float32'], + 'atleast_1d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'atleast_2d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'atleast_3d': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'baddbmm': ['torch.float32'], + 'bitwise_and': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_left_shift': ['torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_not': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_or': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_right_shift': ['torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bitwise_xor': ['torch.bool', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'bmm': ['torch.float32'], + 'ceil': ['torch.float32'], + 'chunk': ['torch.float16', 'torch.float32', 'torch.int64'], + 'clone': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'column_stack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'conj': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'conj_physical': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'contiguous': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'corrcoef': ['torch.float32'], + 'deg2rad': ['torch.float32'], + 'diag': ['torch.float32', 'torch.int32'], + 'diagflat': ['torch.int32'], + 'diff': ['torch.float32'], + 'dist': ['torch.float32'], + 'dot': ['torch.float32', 'torch.int32'], + 'einsum': ['torch.float32'], + 'erf': ['torch.float32'], + 'fill': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'flatten': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'floor': ['torch.float32'], + 'hstack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'index_select': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'isinf': ['torch.float16', 'torch.float32'], + 'isnan': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'kron': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'linalg.norm': ['torch.float16', + 'torch.float32', + 'torch.float16', + 'torch.float32'], + 'linalg.svd': ['torch.float32'], + 'linalg.vector_norm': ['torch.float16'], + 'log1p': ['torch.float32'], + 'log_softmax': ['torch.float32'], + 'logaddexp': ['torch.float32'], + 'logaddexp2': ['torch.float32'], + 'masked_select': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'mm': ['torch.float32'], + 'mv': ['torch.float32'], + 'neg': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32'], + 'nn.functional.adaptive_max_pool1d': ['torch.float32'], + 'nn.functional.adaptive_max_pool2d': ['torch.float32'], + 'nn.functional.binary_cross_entropy': ['torch.float32'], + 'nn.functional.celu': ['torch.float32'], + 'nn.functional.elu': ['torch.float32'], + 'nn.functional.embedding': ['torch.float16', 'torch.float32'], + 'nn.functional.feature_alpha_dropout': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.hardtanh': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'nn.functional.hinge_embedding_loss': ['torch.float32'], + 'nn.functional.kl_div': ['torch.float32'], + 'nn.functional.l1_loss': ['torch.float32'], + 'nn.functional.huber_loss': ['torch.float32'], + 'nn.functional.leaky_relu': ['torch.float32'], + 'nn.functional.mse_loss': ['torch.float16', 'torch.float32'], + 'nn.functional.relu': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.relu6': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.prelu': ['torch.float32'], + 'nn.functional.selu': ['torch.float32'], + 'nn.functional.silu': ['torch.float32'], + 'nn.functional.smooth_l1_loss': ['torch.float32', + 'torch.float16'], + 'nn.functional.softmin': ['torch.float32'], + 'nn.functional.threshold': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'nn.functional.upsample_bilinear': ['torch.float32'], + 'norm': ['torch.float32', 'torch.float16', 'torch.float32'], + 'positive': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'rad2deg': ['torch.float32'], + 'ravel': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'real': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'repeat': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'repeat_interleave': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resize_': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resize_as_': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resolve_conj': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'resolve_neg': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'round': ['torch.float32'], + 'sgn': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sign': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.uint8'], + 'sin': ['torch.float32'], + 'sinh': ['torch.float32'], + 'softmax': ['torch.float32'], + 'split': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sqrt': ['torch.float32'], + 'square': ['torch.float32'], + 'squeeze': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'stack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'sub': ['torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'sum_to_size': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'svd': ['torch.float32'], + 't': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'tanh': ['torch.float32'], + 'tensordot': ['torch.float32'], + 'topk': ['torch.float32'], + 'tril': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'triu': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'true_divide': ['torch.float32'], + 'trunc': ['torch.float32'], + 'unsqueeze': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'view': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'view_as': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'vsplit': ['torch.bool', + 'torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8'], + 'vstack': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64'], + 'zero_': ['torch.float16', + 'torch.float32', + 'torch.int16', + 'torch.int32', + 'torch.int64', + 'torch.uint8']} + + # These ops that are problematic. So never run them even when + # generating the new allowlist. + # If the dtype list is None, all dtypes are excluded. + # All the entries in this list should be removed + BLOCKLIST = { + # Functions that hang + 'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool], + # Functions that hard crash + 'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64], + 'nn.functional.nll_loss': [torch.float32], + 'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32], + 'std': [torch.float16], + 'stft': [torch.float32], 'var': [torch.float16], + + # These were moved from ALLOWLIST to BLOCK as they are not working + # locally + 'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], + '__radd__': ['torch.bool', 'torch.uint8'], + '__rmul__': ['torch.uint8'], + 'add': ['torch.bool', 'torch.uint8'], + 'square': ['torch.int32', 'torch.int64', 'torch.uint8'], + 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], + 'diag': ['torch.int64'], + 'diagflat': ['torch.int64'], + + # Functions that are flaky + # These are detected as "ok" by the expect case but actually fail to run sometimes + 'H': None, + 'T': None, + 'as_strided': None, + 'broadcast_tensors': None, + 'broadcast': None, + 'broadcast_to': None, + 'diagonal': None, + 'divfloor_rounding': None, + 'divno_rounding_mode': None, + 'divtrunc_rounding': None, + 'dsplit': None, + 'hsplit': None, + 'empty': None, + 'expand_as': None, + 'expand': None, + 'ge': None, + 'ne': None, + 'le': None, + 'lt': None, + 'gt': None, + 'transpose': None, + 'splitlist_args': None, + 'select': None, + 'reshape': None, + 'reshape_as': None, + 'permute': None, + 'norm': None, + 'nn.functional.pixel_unshuffle': None, + 'nn.functional.pixel_shuffle': None, + 'nn.functional.cross_entropy': None, + 'nn.functional.one_hot': None, + 'narrow': None, + 'movedim': None, + 'minreduction_with_dim': None, + 'minreduction_no_dim': None, + 'minbinary': None, + 'meshgridvariadic_tensors': None, + 'meshgridlist_of_tensors': None, + 'maxreduction_with_dim': None, + 'maxreduction_no_dim': None, + 'maxbinary': None, + 'maximum': None, + 'minimum': None, + 'mT': None, + 'mH': None, + 'outer': None, + 'softmaxwith_dtype': None, + 'rounddecimals_neg_3': None, + 'rounddecimals_3': None, + 'rounddecimals_0': None, + 'normnuc': None, + 'nn.functional.softminwith_dtype': None, + 'nn.functional.feature_alpha_dropoutwith_train': None, + 'log_softmaxdtype': None, + 'split_with_sizes': None, + 'trapezoid': None, + 'eq': None, + 'mul': None, + 'cartesian_prod': None, + 'nonzero': None, + 'bool': None, + 'inner': None, + 'dstack': None, + 'take_along_dim': None, + } + + # Used for accept mode only + NEW_ALLOW_LIST = defaultdict(list) + + @ops(op_db, allowed_dtypes=MPS_DTYPES) + def test_output_match(self, device, dtype, op): + self.assertEqual(device, "cpu") + if not torch.backends.mps.is_available(): + self.skipTest("MPS is not available") + + key = op.name + op.variant_test_name + if key in self.BLOCKLIST: + if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]: + self.skipTest(f"Running test with {op.name} hangs so skipping") + + # Make this an expecttest manually + # When this env variable is set, generate a new ALLOWLIST_OP + # that reflects the current state of what passes or not + if os.environ.get("EXPECTTEST_ACCEPT", None) == "1": + generate_new_truth = True + else: + generate_new_truth = False + + if not generate_new_truth: + if op.name not in self.ALLOWLIST_OP: + self.skipTest(f"{op.name} is not in the allow list for test on MPS") + else: + if str(dtype) not in self.ALLOWLIST_OP[op.name]: + self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded") + try: + cpu_samples = op.sample_inputs(device, dtype) + + for cpu_sample in cpu_samples: + mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x) + + # TODO: This checks only the function variant. We should also check the method and inplace version + # when they exist + cpu_args = [cpu_sample.input] + list(cpu_sample.args) + cpu_kwargs = cpu_sample.kwargs + mps_args = [mps_sample.input] + list(mps_sample.args) + mps_kwargs = mps_sample.kwargs + + cpu_out = op(*cpu_args, **cpu_kwargs) + mps_out = op(*mps_args, **mps_kwargs) + self.assertEqual(cpu_out, mps_out) + except Exception as e: + if not generate_new_truth: + raise e + else: + if generate_new_truth: + self.NEW_ALLOW_LIST[op.name].append(str(dtype)) + + # We could write it only once. But I don't know how to detect that the current test is the last one + # So each test append to the dict and write it. + with open("new_mps_allowlist.txt", "w") as f: + pprint.pprint(self.NEW_ALLOW_LIST, stream=f) + +# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. +# This requires mps to be properly registered in the device generic test framework which is not the +# case right now. +instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu") if __name__ == "__main__": run_tests() diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 2b9b00191cc58..839dc01c90dbb 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1083,22 +1083,13 @@ def test_flatten_nodims(self): def test_unflatten(self): # test args: tensor, int, namedshape self.assertTrue(torch.equal( - torch.ones(4).unflatten(0, (('A', 2), ('B', 2))), + torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))), torch.ones(2, 2, names=('A', 'B')))) self.assertTrue(torch.equal( - torch.ones(4).unflatten(0, [('A', 2), ('B', 2)]), + torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]), torch.ones(2, 2, names=('A', 'B')))) self.assertTrue(torch.equal( - torch.ones(4).unflatten(0, (['A', 2], ['B', 2])), - torch.ones(2, 2, names=('A', 'B')))) - self.assertTrue(torch.equal( - torch.ones(4).unflatten(-1, (['A', 2], ['B', 2])), - torch.ones(2, 2, names=('A', 'B')))) - self.assertTrue(torch.equal( - torch.ones(4).unflatten(-1, (['A', -1], ['B', 2])), - torch.ones(2, 2, names=('A', 'B')))) - self.assertTrue(torch.equal( - torch.ones(4).unflatten(-1, (['A', 2], ['B', -1])), + torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])), torch.ones(2, 2, names=('A', 'B')))) self.assertTrue(torch.equal( torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)), @@ -1112,18 +1103,13 @@ def test_unflatten(self): .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])), torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3')))) - # test args: namedtensor, int, namedshape - self.assertTrue(torch.equal( - torch.ones(2, 4, names=('A', 'B')).unflatten(1, (('B1', 2), ('B2', 2))), - torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) - # test args: namedtensor, str, namedshape self.assertTrue(torch.equal( torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))), torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) # test invalid args: namedtensor, str, sizes - with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"): + with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"): torch.tensor([1], names=('A',)).unflatten('A', (1, 1)) # test invalid args: namedtensor, int, sizes @@ -1195,8 +1181,11 @@ def test_simple_reduce(op, device): check_output(op(t, 1), ['N', 'L']) check_output(op(t, -1), ['N', 'C']) check_output(op(t, 'C'), ['N', 'L']) - with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): - op(t, None) + if op.__name__ in ['sum', 'mean']: + check_output(op(t, None), []) + else: + with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): + op(t, None) with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'): op(t, 'H') diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index cd957aded6589..01d0d58fb708e 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -16,12 +16,12 @@ 'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'slogdet', 'sort', 'topk', 'lstsq', 'linalg_inv_ex', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_linalg_eigh", "_unpack_dual", 'linalg_qr', - 'linalg_svd', '_linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask', + 'linalg_svd', '_linalg_svd', 'linalg_slogdet', '_linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask', 'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq', 'linalg_eig', 'linalg_cholesky_ex', 'frexp', 'lu_unpack', 'histogram', 'histogramdd', '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', '_fused_moving_avg_obs_fq_helper', 'linalg_lu_factor', 'linalg_lu_factor_ex', 'linalg_lu', - '_det_lu_based_helper', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor', '_linalg_solve' + '_linalg_det', '_lu_with_info', 'linalg_ldl_factor_ex', 'linalg_ldl_factor', 'linalg_solve_ex', '_linalg_solve_ex' } @@ -72,9 +72,9 @@ def test_namedtuple_return(self): op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), - op(operators=['linalg_svd'], input=(), names=('U', 'S', 'Vh'), hasout=True), - op(operators=['_linalg_svd'], input=(), names=('U', 'S', 'Vh'), hasout=True), - op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), + op(operators=['linalg_svd', '_linalg_svd'], input=(), names=('U', 'S', 'Vh'), hasout=True), + op(operators=['slogdet', 'linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True), + op(operators=['_linalg_slogdet'], input=(), names=('sign', 'logabsdet', 'LU', 'pivots'), hasout=True), op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['geqrf'], input=(), names=('a', 'tau'), hasout=True), op(operators=['symeig', 'eig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True), @@ -83,10 +83,10 @@ def test_namedtuple_return(self): op(operators=['linalg_eig'], input=(), names=('eigenvalues', 'eigenvectors'), hasout=True), op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), op(operators=['_linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), - op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True), op(operators=['linalg_cholesky_ex'], input=(), names=('L', 'info'), hasout=True), op(operators=['linalg_inv_ex'], input=(), names=('inverse', 'info'), hasout=True), - op(operators=['_linalg_solve'], input=(a,), names=('result', 'LU', 'pivots'), hasout=True), + op(operators=['linalg_solve_ex'], input=(a,), names=('result', 'info'), hasout=True), + op(operators=['_linalg_solve_ex'], input=(a,), names=('result', 'LU', 'pivots', 'info'), hasout=True), op(operators=['linalg_lu_factor'], input=(), names=('LU', 'pivots'), hasout=True), op(operators=['linalg_lu_factor_ex'], input=(), names=('LU', 'pivots', 'info'), hasout=True), op(operators=['linalg_ldl_factor'], input=(), names=('LD', 'pivots'), hasout=True), @@ -111,8 +111,8 @@ def test_namedtuple_return(self): op(operators=['_fused_moving_avg_obs_fq_helper'], input=(torch.tensor([1]), torch.tensor([1]), torch.tensor([0.1]), torch.tensor([0.1]), torch.tensor([0.1]), torch.tensor([1]), 0.01, 0, 255, 0), names=('output', 'mask',), hasout=False), - op(operators=['_det_lu_based_helper'], - input=(), names=('det', 'lu', 'pivs'), hasout=False), + op(operators=['_linalg_det'], + input=(), names=('result', 'LU', 'pivots'), hasout=True), op(operators=['aminmax'], input=(), names=('min', 'max'), hasout=True), op(operators=['_lu_with_info'], input=(), names=('LU', 'pivots', 'info'), hasout=False), diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index b27e9572ed83e..34a0ebd05ee97 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -8,8 +8,9 @@ dtypesIfCUDA, instantiate_device_type_tests, skipMeta, + onlyCPU ) -from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state +from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize from torch import nested_tensor # Tests are ported from pytorch/nestedtensor. @@ -18,6 +19,29 @@ def _iter_constructors(): # yield as_nested_tensor yield nested_tensor +# Helper functions to pad a noncontiguous nested tensor +# can be replaced once to_padded_tensor supports noncontiguous memory +def noncontiguous_to_padded_tensor(input, shape=None): + tensors = input.unbind() + ntensors = len(tensors) + assert ntensors > 0 + if shape is None: + shape = [] + for size in tensors[0].shape: + shape.append(size) + for i in range(1, ntensors): + new_shape = tensors[i].shape + for j in range(len(shape)): + shape[j] = max(shape[j], new_shape[j]) + shape = [ntensors] + shape + result = tensors[0].new_zeros(shape) + for itensor in range(ntensors): + tensor = tensors[itensor] + view = result[itensor] + for idim in range(tensor.dim()): + view = view.narrow(idim, 0, tensor.size(idim)) + view.copy_(tensor) + return result class TestNestedTensor(TestCase): @torch.inference_mode() @@ -136,9 +160,21 @@ def test_dim(self): def test_numel(self): for constructor in _iter_constructors(): a1 = constructor([]) - self.assertRaisesRegex( - RuntimeError, "numel is disabled", lambda: a1.numel(), - ) + self.assertEqual(a1.numel(), 0) + a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)]) + self.assertEqual(a1.numel(), 2) + a1 = constructor([torch.randn(2, 2, 2)]) + self.assertEqual(a1.numel(), 8) + a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)]) + self.assertEqual(a1.numel(), 12) + a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)]) + self.assertEqual(a1.numel(), 27) + a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)]) + self.assertEqual(a1.numel(), 341) + + # Interesting edge case + a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)]) + self.assertEqual(a1.numel(), 6) @torch.inference_mode() def test_size(self): @@ -206,6 +242,72 @@ def test_to_padded_tensor_on_empty_tensor(self): self.assertEqual(empty, torch.tensor([])) class TestNestedTensorDeviceType(TestCase): + # Helper function to assert 2 nested tensors are equal + def nt_equal(self, nt1, nt2): + self.assertEqual(nt1.dtype, nt2.dtype) + self.assertEqual(nt1.device, nt2.device) + ub1 = nt1.unbind() + ub2 = nt2.unbind() + self.assertEqual(len(ub1), len(ub2)) + n = len(ub1) + for i in range(n): + self.assertEqual(ub1[i], ub2[i]) + + # Helper function to generate a random nested tensor + def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None): + if min_dims is None: + min_dims = tuple([0] * len(max_dims)) + ts1 = [] + for _ in range(num_tensors): + tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item() + for (min_dim, max_dim) in zip(min_dims, max_dims)]) + t1 = torch.randn(tensor_dims, device=device, dtype=dtype) + ts1.append(t1) + return torch.nested_tensor(ts1, device=device, dtype=dtype) + + # Helper function to generate a pair of random nested tensors + # the 2 nested tensors have same shapes + def random_nt_pair(self, device, dtype, num_tensors, max_dims): + ts1 = [] + ts2 = [] + for _ in range(num_tensors): + tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims]) + t1 = torch.randn(tensor_dims, device=device, dtype=dtype) + t2 = torch.randn(tensor_dims, device=device, dtype=dtype) + ts1.append(t1) + ts2.append(t2) + return (torch.nested_tensor(ts1, device=device, dtype=dtype), + torch.nested_tensor(ts2, device=device, dtype=dtype)) + + # Helper function to generate a pair of random nested tensors + # one is contiguous, the other is not, but they appear to have same entries + # an output nested tensor consists of + # * `len(ragged_sizes)` matrices + # * matrices[i].shape == (20, ragged_sizes[i]) + def random_nt_noncontiguous_pair(self, ragged_sizes, device, dtype): + xs = [] + for size in ragged_sizes: + xs.append(torch.randn((size, 20), device=device, dtype=dtype)) + # contiguous nested tensor + ys = [] + for x in xs: + ys.append(x.transpose(-1, -2)) + nt_contiguous = torch.nested_tensor(ys) + # noncontiguous nested tensor + n = len(ragged_sizes) + nt_noncontiguous = torch.nested_tensor(xs).transpose(-1, -2) + return nt_contiguous, nt_noncontiguous + + @dtypes(torch.float, torch.float16, torch.double) + def test_unbind_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = self.random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + ub_contiguous = nt_contiguous.unbind() + ub_noncontiguous = nt_noncontiguous.unbind() + self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) + n = len(ub_contiguous) + for i in range(n): + self.assertEqual(ub_contiguous[i], ub_noncontiguous[i]) + @dtypes(torch.float) @skipMeta def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): @@ -349,6 +451,25 @@ def test_to_padded_tensor_dim4(self, device, dtype): padded = nt.to_padded_tensor(pad) self.assertEqual(padded, correct_output) + # TODO: test noncontiguous to_padded_tensor + # For now this tests the functionality of noncontiguous_to_padded_tensor + # and the error message of to_padded_tensor + # since to_padded_tensor does not support noncontiguous buffer yet + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_to_padded_tensor_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = self.random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + # test noncontiguous_to_padded_tensor functionality + self.assertEqual( + nt_contiguous.to_padded_tensor(0.0), + noncontiguous_to_padded_tensor(nt_noncontiguous)) + # test to_padded_tensor error message + self.assertRaisesRegex( + RuntimeError, + r"for now to_padded_tensor only supports contiguous nested tensor", + lambda: nt_noncontiguous.to_padded_tensor(0.0) + ) + @skipMeta def test_device_checks(self, device): nt = torch.nested_tensor([], device=device) @@ -359,11 +480,7 @@ def test_device_checks(self, device): def test_nested_tensor_indexing(self, device, dtype): # edge case: empty nested tensor nt0 = torch.nested_tensor([]) - self.assertRaisesRegex( - RuntimeError, - "cannot index an empty nested tensor", - lambda: nt0[0] - ) + self.assertRaises(IndexError, lambda: nt0[0]) # normal case x0 = torch.randn((2, 5), device=device, dtype=dtype) x1 = torch.randn((3, 4), device=device, dtype=dtype) @@ -391,40 +508,14 @@ def test_nested_tensor_indexing(self, device, dtype): answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4) self.assertEqual(nt[1, 1, :], answer) - # Helper functions for testing elementwise ops - def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None): - if min_dims is None: - min_dims = tuple([0] * len(max_dims)) - ts1 = [] - for _ in range(num_tensors): - tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item() - for (min_dim, max_dim) in zip(min_dims, max_dims)]) - t1 = torch.randn(tensor_dims, device=device, dtype=dtype) - ts1.append(t1) - return torch.nested_tensor(ts1, device=device, dtype=dtype) - - # Helper functions for testing elementwise ops - def random_nt_pair(self, device, dtype, num_tensors, max_dims): - ts1 = [] - ts2 = [] - for _ in range(num_tensors): - tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims]) - t1 = torch.randn(tensor_dims, device=device, dtype=dtype) - t2 = torch.randn(tensor_dims, device=device, dtype=dtype) - ts1.append(t1) - ts2.append(t2) - return (torch.nested_tensor(ts1, device=device, dtype=dtype), - torch.nested_tensor(ts2, device=device, dtype=dtype)) - - def nt_equal(self, nt1, nt2): - self.assertEqual(nt1.dtype, nt2.dtype) - self.assertEqual(nt1.device, nt2.device) - ub1 = nt1.unbind() - ub2 = nt2.unbind() - self.assertEqual(len(ub1), len(ub2)) - n = len(ub1) + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_nested_tensor_indexing_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = self.random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) + n = nt_contiguous.size(0) for i in range(n): - self.assertEqual(ub1[i], ub2[i]) + self.assertEqual(nt_contiguous[i], nt_noncontiguous[i]) @dtypes(torch.float, torch.float16) @skipMeta @@ -439,10 +530,35 @@ def test_nested_tensor_add(self, device, dtype): @skipMeta @torch.inference_mode() def test_nested_tensor_mul(self, device, dtype): + # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) out = nt1 * nt2 self.nt_equal(ref, out) + # nested tensor * scalar + number = 10.0 + scalar = torch.tensor(number).to(dtype).to(device) + ref = torch.nested_tensor([t * number for t in nt1.unbind()]) + out_number0 = nt1 * number + out_number1 = number * nt1 + out_scalar0 = nt1 * scalar + out_scalar1 = scalar * nt1 + self.nt_equal(out_number0, ref) + self.nt_equal(out_number1, ref) + self.nt_equal(out_scalar0, ref) + self.nt_equal(out_scalar1, ref) + # error case: numel == 1 but dim > 0 + vector = torch.tensor([number]).to(dtype).to(device) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a nested self and non-nested other", + lambda: nt1.mul(vector) + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a non-nested self and nested other", + lambda: vector.mul(nt1) + ) @dtypes(torch.float, torch.float16) @skipMeta @@ -457,10 +573,67 @@ def test_nested_tensor_add_in_place(self, device, dtype): @skipMeta @torch.inference_mode() def test_nested_tensor_mul_in_place(self, device, dtype): + # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) nt1 *= nt2 self.nt_equal(ref, nt1) + # nested tensor * scalar + number = 10.0 + scalar = torch.tensor(number).to(dtype).to(device) + ref = torch.nested_tensor([t * number for t in nt1.unbind()]) + out_number = nt1.clone() + out_number *= number + out_scalar = nt1.clone() + out_scalar *= scalar + self.nt_equal(out_number, ref) + self.nt_equal(out_scalar, ref) + self.assertRaisesRegex( + RuntimeError, + r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", + lambda: scalar.mul_(nt1) + ) + # error case: numel == 1 but dim > 0 + vector = torch.tensor([number]).to(dtype).to(device) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a nested self and non-nested other", + lambda: nt1.mul_(vector) + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a non-nested self and nested other", + lambda: vector.mul_(nt1) + ) + + @onlyCPU + @skipMeta + @dtypes(torch.float) + def test_nested_tensor_sum_dim(self, device, dtype): + params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7))) + + def test_sum(nt, dim, keepdim=True): + nt2 = nt.clone() + nt = nt.sum(dim=dim, keepdim=keepdim) + ub2 = nt2.unbind() + ub2 = [t.sum(-1, keepdim=keepdim) for t in ub2] + nt2 = torch.nested_tensor(ub2) + self.nt_equal(nt, nt2) + return + + for ntensors, max_sizes in params: + test_sum(self.random_nt(device, dtype, ntensors, max_sizes), len(max_sizes)) + + # Test error inputs + with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"): + torch.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True) + + with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"): + torch.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True) + + with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."): + torch.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1) + @dtypes(torch.float, torch.float16) @skipMeta @@ -550,6 +723,489 @@ def test_dropout(self, device, dtype): expect_tensor[j] /= 1.0 - p self.nt_equal(nt, expect) + # dropout works directly on the underlying buffer memory + # so contiguous / noncontiguous does not make any difference + + # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half' + @dtypes(torch.float, torch.double) + @torch.inference_mode() + def test_softmax(self, device, dtype): + # normal nested tensor + ntensors = 4 + nt = self.random_nt(device, dtype, ntensors, (4, 4)) + # error case: softmax across nested dimension + self.assertRaisesRegex( + RuntimeError, + "Cannot apply softmax across nested dimension 0", + lambda: torch.nn.functional.softmax(nt, 0) + ) + self.assertRaisesRegex( + RuntimeError, + "Cannot apply softmax across nested dimension 0", + lambda: torch.nn.functional.softmax(nt, -3) + ) + # error case: dimension out of range + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4)) + # normal case: should equal to padding -inf + softmaxer = torch.nn.Softmax(1) + y0 = softmaxer(nt) + y1 = torch.nn.functional.softmax(nt, 1) + self.nt_equal(y0, y1) + pt = nt.to_padded_tensor(float("-inf")) + # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan + # however, physically speaking that should be 0.0 + expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0) + self.assertEqual(y0.to_padded_tensor(0.0), expect) + # edge case: empty nested tensor + nt0 = torch.nested_tensor([]) + y = torch.nn.functional.softmax(nt0, 1) + self.nt_equal(nt0, y) + # edge case: nesting scalars + nt1 = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)]) + self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0)) + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1)) + + @dtypes(torch.float, torch.double) + @torch.inference_mode() + def test_softmax_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = self.random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + self.nt_equal( + torch.nn.functional.softmax(nt_contiguous, -1), + torch.nn.functional.softmax(nt_noncontiguous, -1)) + + # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_bmm(self, device, dtype): + # error case: one is nested but the other is not + nt = torch.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + t = torch.randn(4, device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a nested self and non-nested other", + lambda: nt.bmm(t) + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a non-nested self and nested other", + lambda: t.bmm(nt) + ) + # error case: not 3D tensors + nt0 = torch.nested_tensor([], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + nt2 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt0.bmm(nt0) + ) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt0.bmm(nt1) + ) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt0.bmm(nt2) + ) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt1.bmm(nt0) + ) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt1.bmm(nt1) + ) + self.assertRaisesRegex( + RuntimeError, + "batch1 must be a 3D tensor", + lambda: nt1.bmm(nt2) + ) + self.assertRaisesRegex( + RuntimeError, + "batch2 must be a 3D tensor", + lambda: nt2.bmm(nt0) + ) + self.assertRaisesRegex( + RuntimeError, + "batch2 must be a 3D tensor", + lambda: nt2.bmm(nt1) + ) + # error case: incompatible batch size + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((4, 6)), + torch.randn((4, 5)), + torch.randn((4, 7))], + device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", + lambda: nt0.bmm(nt1) + ) + self.assertRaisesRegex( + RuntimeError, + "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", + lambda: nt1.bmm(nt0) + ) + # error case: underlying matrices cannot be multiplied + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", + lambda: nt0.bmm(nt0) + ) + # normal nested tensor + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + actual = nt0.bmm(nt1).to_padded_tensor(0.0) + expect = nt0.to_padded_tensor(0.0).bmm(nt1.to_padded_tensor(0.0)) + self.assertEqual(actual, expect) + + # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_bmm_noncontiguous(self, device, dtype): + nt0_contiguous, nt0_noncontiguous = self.random_nt_noncontiguous_pair((2, 3), device, dtype) + nt1_contiguous, nt1_noncontiguous = self.random_nt_noncontiguous_pair((6, 7), device, dtype) + self.nt_equal( + nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), + nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) + + # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_matmul(self, device, dtype): + # error case: one is nested but the other is not + nt = torch.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + t = torch.randn(4, device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a nested self and non-nested other", + lambda: torch.matmul(nt, t) + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a non-nested self and nested other", + lambda: torch.matmul(t, nt) + ) + # error case: not 3+D tensors + nt0 = torch.nested_tensor([], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + nt2 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt0) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt1) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt2) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt0) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt1) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt2) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", + lambda: torch.matmul(nt2, nt0) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", + lambda: torch.matmul(nt2, nt1) + ) + # error case: incompatible batch size + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((4, 6)), + torch.randn((4, 5)), + torch.randn((4, 7))], + device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", + lambda: torch.matmul(nt0, nt1) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", + lambda: torch.matmul(nt1, nt0) + ) + # error case: incompatible generalized batch size + nt0 = torch.nested_tensor([torch.randn((2, 2, 4)), + torch.randn((2, 3, 4))], + device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((3, 4, 6)), + torch.randn((3, 4, 5))], + device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, no broadcasting is currently performed: " + r"[0-9]+-th nested matrices in batch at dimension [0-9]+ " + r"have mismatching sizes [0-9]+ and [0-9]+", + lambda: torch.matmul(nt0, nt1) + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, no broadcasting is currently performed: " + r"[0-9]+-th nested matrices in batch at dimension [0-9]+ " + r"have mismatching sizes [0-9]+ and [0-9]+", + lambda: torch.matmul(nt1, nt0) + ) + # error case: underlying matrices cannot be multiplied + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", + lambda: torch.matmul(nt0, nt0) + ) + # normal nested tensor: 3D + nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + actual = torch.matmul(nt0, nt1).to_padded_tensor(0.0) + expect = torch.matmul(nt0.to_padded_tensor(0.0), nt1.to_padded_tensor(0.0)) + self.assertEqual(actual, expect) + # normal nested tensor: 4D + nt0 = torch.nested_tensor([torch.randn((8, 2, 4)), + torch.randn((8, 3, 7))], + device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((8, 4, 6)), + torch.randn((8, 7, 5))], + device=device, dtype=dtype) + actual = torch.matmul(nt0, nt1).to_padded_tensor(0.0) + expect = torch.matmul(nt0.to_padded_tensor(0.0), nt1.to_padded_tensor(0.0)) + self.assertEqual(actual, expect) + # normal nested tensor: 5D + nt0 = torch.nested_tensor([torch.randn((8, 9, 2, 4)), + torch.randn((8, 9, 3, 7))], + device=device, dtype=dtype) + nt1 = torch.nested_tensor([torch.randn((8, 9, 4, 6)), + torch.randn((8, 9, 7, 5))], + device=device, dtype=dtype) + actual = torch.matmul(nt0, nt1).to_padded_tensor(0.0) + expect = torch.matmul(nt0.to_padded_tensor(0.0), nt1.to_padded_tensor(0.0)) + self.assertEqual(actual, expect) + + # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_matmul_noncontiguous(self, device, dtype): + nt0_contiguous, nt0_noncontiguous = self.random_nt_noncontiguous_pair((2, 3), device, dtype) + nt1_contiguous, nt1_noncontiguous = self.random_nt_noncontiguous_pair((6, 7), device, dtype) + self.nt_equal( + torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), + torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous)) + + @dtypes(torch.float, torch.double) + def test_linear(self, device, dtype): + a = torch.randn(1, 2, device=device, dtype=dtype) + b = torch.randn(2, 2, device=device, dtype=dtype) + c = torch.randn(3, 2, device=device, dtype=dtype) + nt = torch.nested_tensor([a, b, c]) + + weight = torch.randn(2, 2, device=device, dtype=dtype) + bias = torch.randn(2, device=device, dtype=dtype) + # success case + torch.functional.F.linear(nt, weight, bias) + + # invalid nested tensor dimension + msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2' + nt1 = torch.nested_tensor([torch.randn(1, device=device, dtype=dtype), + torch.randn(2, device=device, dtype=dtype)]) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt1, weight, bias) + + # invalid weight shape + msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3' + weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, weight1, bias) + + # inconsistent last dim of nested tensor + msg = r"all tensors in NestedTensor must have the same trailing dim" + nt2 = torch.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype), + torch.randn(2, 3, device=device, dtype=dtype)]) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt2, weight, bias) + + # Mismatch of nested tensor last dim and weight dimension + weight2 = torch.randn(2, 4, device=device, dtype=dtype) + msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \ + r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, weight2, bias) + + # Nested tensor input and nested weight + nt_weight = nt.clone() + msg = r"Linear does not support nested weight when input is a nested tensor." + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, nt_weight, bias) + + # TODO: test noncontiguous linear + # For now this tests the error message of linear + # since linear does not support noncontiguous buffer yet + @dtypes(torch.float, torch.double) + def test_linear_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = self.random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + weight = torch.randn((8, 5), device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"for now linear only supports contiguous nested tensor", + lambda: torch.nn.functional.linear(nt_noncontiguous, weight) + ) + + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_transpose(self, device, dtype): + nt = self.random_nt(device, dtype, 4, (4, 4)) + # error case: transpose nested dimension + self.assertRaisesRegex( + RuntimeError, + "Nested tensor dimension 0 cannot be transposed", + lambda: nt.transpose(0, 1) + ) + self.assertRaisesRegex( + RuntimeError, + "Nested tensor dimension 0 cannot be transposed", + lambda: nt.transpose(1, -3) + ) + # error case: dimension out of range + self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) + self.assertRaises(IndexError, lambda: nt.transpose(-4, -1)) + # normal case + ntT = nt.transpose(-1, -2) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = nt.to_padded_tensor(0.0) + ptT = pt.transpose(-1, -2) + self.assertEqual(ptT, ptT_from_ntT) + + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_reshape(self, device, dtype): + nt = self.random_nt(device, dtype, 4, (4, 4)) + # error case: empty shape + self.assertRaisesRegex( + RuntimeError, + r"shape '\[\]' is invalid for a nested tensor", + lambda: nt.reshape(()) + ) + # error case: empty nested tensor + nt_empty = torch.nested_tensor([]) + self.assertRaisesRegex( + RuntimeError, + "empty nested tensor cannot be reshaped", + lambda: nt_empty.reshape(-1) + ) + # error case: invalid proposed shape for underlying tensors + self.assertRaisesRegex( + RuntimeError, + r"invalid shape dimension -2", + lambda: nt.reshape(-2, 2, 3) + ) + self.assertRaisesRegex( + RuntimeError, + r"shape '\[.*\]' is invalid for input of size [0-9]+", + lambda: nt.reshape(4, 2, 3) + ) + # normal case + x0 = torch.randn((2, 20), device=device, dtype=dtype) + x1 = torch.randn((3, 20), device=device, dtype=dtype) + nt = torch.nested_tensor([x0, x1]) + pt = nt.to_padded_tensor(0.0) + self.assertRaisesRegex( + RuntimeError, + r"for now reshape cannot change the implicit batch dimension", + lambda: nt.transpose(-1, -2).reshape(40, -1) + ) + # inherit only the ragged dimension + # (2, 20) -> (2, 5, 4) + # (3, 20) -> (3, 5, 4) + nt1 = nt.reshape(2, -1, 5, 4) + # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) + pt1 = pt.reshape(2, -1, 5, 4) + self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) + # also inherit regular dimension + nt2 = nt1.reshape(2, -1, -1, 2, 2) + pt2 = pt1.reshape(2, -1, 5, 2, 2) + self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2) + + @parametrize("input_dim", [3, 4]) + def test_scaled_dot_product_attention(self, device, input_dim): + + def rand_tensor(*shape): + return torch.randn(shape, device=device) + + E = 10 + if input_dim == 3: + # Shape: (N, L, E); ragged L + query = torch.nested_tensor([rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]) + + # Shape: (N, S, E); ragged S + key = torch.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) + value = torch.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) + elif input_dim == 4: + # Shape: (N, N', L, E); ragged N' and L + query = torch.nested_tensor([rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]) + # Shape: (N, N', S, E); ragged N' and S + key = torch.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) + value = torch.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) + else: + self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") + + def rand_mask(size): + return torch.randint(0, 2, size=size, dtype=torch.bool, device=device) + + # Shape: (N, L, S); ragged L and S matching above + attn_mask = torch.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]) + + dropout_p = 0.0 # no dropout for reproducibility + need_attn_weights: bool = True + + # Success case: no attn_mask set and is_causal=False. + actual = torch.ops.aten._scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=dropout_p, need_attn_weights=need_attn_weights) + + expected_outputs = [] + expected_attn_weights = [] + for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): + (output, attn_weights) = torch.ops.aten._scaled_dot_product_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p, + need_attn_weights=need_attn_weights) + expected_outputs.append(output.squeeze(0)) + expected_attn_weights.append(attn_weights.squeeze(0)) + expected_output_nested = torch.nested_tensor(expected_outputs) + expected_attn_weight_nested = torch.nested_tensor(expected_attn_weights) + self.nt_equal(actual[0], expected_output_nested) + self.nt_equal(actual[1], expected_attn_weight_nested) + + # Error case: explicit attn_mask set. + with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"): + torch.ops.aten._scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, need_attn_weights=need_attn_weights) + + # Error case: is_causal=True. + with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): + torch.ops.aten._scaled_dot_product_attention( + query, key, value, dropout_p=dropout_p, need_attn_weights=need_attn_weights, is_causal=True) + + class TestNestedTensorAutograd(TestCase): def nt_equal(self, nt1, nt2): self.assertEqual(nt1.dtype, nt2.dtype) @@ -666,6 +1322,180 @@ def grad_test_func(a, b, c): data = (a, b, c) assert torch.autograd.gradcheck(grad_test_func, inputs=data) + def test_size_dim(self): + a = torch.nested_tensor([]) + self.assertEqual(a.size(0), 0) + + a = torch.nested_tensor([torch.tensor(1)]) + self.assertEqual(a.size(0), 1) + + a = torch.nested_tensor([torch.tensor(1), torch.tensor(2)]) + self.assertEqual(a.size(0), 2) + + a = torch.nested_tensor([torch.rand(1, 2), + torch.rand(1, 8)]) + self.assertEqual(a.size(0), 2) + self.assertEqual(a.size(1), 1) + self.assertRaisesRegex( + RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2)) + + a = torch.nested_tensor([torch.rand(3, 4), + torch.rand(5, 4)]) + self.assertEqual(a.size(0), 2) + self.assertRaisesRegex( + RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1)) + self.assertEqual(a.size(2), 4) + + def test_nested_tensor_bmm_gradcheck(self): + a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64) + b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64) + c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64) + d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64) + + def grad_test_func(a, b, c, d): + nt0 = torch.nested_tensor([a, b]) + nt1 = torch.nested_tensor([c, d]) + result = nt0.bmm(nt1) + return result.to_padded_tensor(0.0) + + data = (a, b, c, d) + assert torch.autograd.gradcheck(grad_test_func, inputs=data) + + def test_nested_tensor_bmm_backward(self): + nt0 = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))]).requires_grad_(True) + nt1 = torch.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))]).requires_grad_(True) + with torch.no_grad(): + pt0 = nt0.to_padded_tensor(0.0).requires_grad_(True) + pt1 = nt1.to_padded_tensor(0.0).requires_grad_(True) + + ynt = nt0.bmm(nt1) + ypt = pt0.bmm(pt1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(nt0.grad.to_padded_tensor(0.0), pt0.grad) + self.assertEqual(nt1.grad.to_padded_tensor(0.0), pt1.grad) + + def test_nested_tensor_matmul_gradcheck(self): + a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64) + b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64) + c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64) + d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64) + + def grad_test_func(a, b, c, d): + nt0 = torch.nested_tensor([a, b]) + nt1 = torch.nested_tensor([c, d]) + result = torch.matmul(nt0, nt1) + return result.to_padded_tensor(0.0) + + data = (a, b, c, d) + assert torch.autograd.gradcheck(grad_test_func, inputs=data) + + def test_nested_tensor_matmul_backward(self): + nt0 = torch.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))]).requires_grad_(True) + nt1 = torch.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))]).requires_grad_(True) + with torch.no_grad(): + pt0 = nt0.to_padded_tensor(0.0).requires_grad_(True) + pt1 = nt1.to_padded_tensor(0.0).requires_grad_(True) + + ynt = torch.matmul(nt0, nt1) + ypt = torch.matmul(pt0, pt1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.nt_equal(nt0.grad.to_padded_tensor(0.0), pt0.grad) + self.nt_equal(nt1.grad.to_padded_tensor(0.0), pt1.grad) + + def test_nested_tensor_transpose_gradcheck(self): + a = torch.randn(2, 5, requires_grad=True) + b = torch.randn(3, 4, requires_grad=True) + + def grad_test_func(a, b): + nt = torch.nested_tensor([a, b]) + result = nt.transpose(-2, -1).transpose(-2, -1) + return result.to_padded_tensor(0.0) + + data = (a, b) + assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) + + def test_nested_tensor_transpose_backward(self): + nt = torch.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))]).requires_grad_(True) + with torch.no_grad(): + pt = nt.to_padded_tensor(0.0).requires_grad_(True) + + ynt = nt.transpose(-2, -1) + ypt = pt.transpose(-2, -1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(nt.grad.to_padded_tensor(0.0), pt.grad) + + def test_nested_tensor_reshape_gradcheck(self): + a = torch.randn(2, 6, requires_grad=True) + b = torch.randn(3, 6, requires_grad=True) + + def grad_test_func(a, b): + nt = torch.nested_tensor([a, b]) + result = nt.reshape(2, -1, 2, 3) + return result.to_padded_tensor(0.0) + + data = (a, b) + assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) + + def test_nested_tensor_reshape_backward(self): + nt = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))]).requires_grad_(True) + with torch.no_grad(): + pt = nt.to_padded_tensor(0.0).requires_grad_(True) + + ynt = nt.reshape(2, -1, 2, 3) + ypt = pt.reshape(2, -1, 2, 3) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(nt.grad.to_padded_tensor(0.0), pt.grad) + + def test_nested_tensor_linear(self): + + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64) + + weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) + bias = torch.randn(2, requires_grad=True, dtype=torch.float64) + + def grad_test_func(a, b, c, weight, bias=None): + nt = torch.nested_tensor([a, b, c]) + # This implicitly tests to_padded_tensor grads + d = torch.functional.F.linear(nt, weight, bias) + return d.to_padded_tensor(0) + data = (a, b, c, weight, bias) + assert torch.autograd.gradcheck(grad_test_func, inputs=data) + + # Test linear with no bias added + data = (a, b, c, weight) + assert torch.autograd.gradcheck(grad_test_func, inputs=data) + + def test_nested_tensor_linear_backward(self): + a = torch.randn(1, 2, requires_grad=False) + b = torch.randn(2, 2, requires_grad=False) + c = torch.randn(3, 2, requires_grad=False) + + weight = torch.randn(2, 2, requires_grad=True) + bias = torch.randn(2, requires_grad=True) + nt = torch.nested_tensor([a, b, c]) + + out = torch.functional.F.linear(nt, weight, bias) + + out.backward(out.clone()) + + assert weight.grad is not None + assert bias.grad is not None + + assert a.grad is None + assert b.grad is None + assert c.grad is None + + instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/test/test_nn.py b/test/test_nn.py index 12b5f61540647..09d6e67e9fb2e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -19,6 +19,8 @@ import sys import os import subprocess +import weakref +import gc import torch @@ -47,7 +49,7 @@ download_file, get_function_arglist, load_tests, skipIfMps,\ suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, set_default_dtype, IS_WINDOWS, \ - slowTest + slowTest, skipIfTorchDynamo from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, \ @@ -451,6 +453,7 @@ def test_module_backcompat(self): def test_conv_backcompat(self): from torch.serialization import SourceChangeWarning + # This file was generated by running on PyTorch 1.0.1 on Python 2: # # import torch @@ -562,6 +565,7 @@ def bw_hook(inc, h_module, grad_input, grad_output): test_fwd.remove() test_bwd.remove() + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hooks(self): self._test_hooks("register_backward_hook") self._test_hooks("register_full_backward_hook") @@ -614,6 +618,7 @@ def forward(self, arg1, arg2, arg3): mod.register_full_backward_hook(lambda mod, gI, gO: None) mod(inp, inp.detach(), inp) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_no_requires_grad(self): mod = nn.Linear(2, 3) @@ -802,6 +807,7 @@ def bw_hook(module, grad_input, grad_output): with module.register_full_backward_hook(bw_hook): module(inp1, inp2).sum().backward() + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_backward_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -819,6 +825,7 @@ def bw_hook(module, grad_input, grad_output): expected_grad = sig_x * (1 - sig_x) * 2 self.assertEqual(input.grad, expected_grad) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_hook_forward_preforward_writable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -1573,6 +1580,59 @@ def test_Sequential_delitem(self): del n[1::2] self.assertEqual(n, nn.Sequential(l1, l3)) + def test_Sequential_add(self): + l1 = nn.Linear(1, 2) + l2 = nn.Linear(2, 3) + l3 = nn.Linear(3, 4) + l4 = nn.Linear(4, 5) + n = nn.Sequential(l1, l2) + other = nn.Sequential(l3, l4) + self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4)) + + def test_Sequential_iadd(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n = nn.Sequential(l1, l2, l3) + n2 = nn.Sequential(l4) + n += n2 + n2 += n + self.assertEqual(n, nn.Sequential(l1, l2, l3, l4)) + self.assertEqual(n2, nn.Sequential(l4, l1, l2, l3, l4)) + + def test_Sequential_mul(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n = nn.Sequential(l1, l2, l3, l4) + n2 = n * 2 + self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4)) + + def test_Sequential_rmul(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n = nn.Sequential(l1, l2, l3, l4) + n2 = 2 * n + self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4)) + + def test_Sequential_imul(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n = nn.Sequential(l1, l2, l3, l4) + n *= 2 + self.assertEqual(n, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4)) + n *= 2 + self.assertEqual( + n, + nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4) + ) + def test_Sequential_append(self): l1 = nn.Linear(10, 20) l2 = nn.Linear(20, 30) @@ -1584,6 +1644,63 @@ def test_Sequential_append(self): self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4)) self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4)) + def test_Sequential_pop(self): + l1 = nn.Linear(1, 2) + l2 = nn.Linear(2, 3) + l3 = nn.Linear(3, 4) + l4 = nn.Linear(4, 5) + n1 = nn.Sequential(l1, l2, l3, l4) + self.assertEqual(l4, n1.pop(3)) + n2 = nn.Sequential(l1, l2, l3) + self.assertEqual(n1, n2) + # check order of the index + for k, mod in zip(range(len(n1)), n1): + self.assertIs(n1[k], mod) + + def test_Sequential_insert(self): + l1 = nn.Linear(1, 2) + l2 = nn.Linear(2, 3) + l3 = nn.Linear(3, 4) + + n1 = nn.Sequential(l1, l2, l3) + module_1 = nn.Linear(4, 5) + n2 = nn.Sequential(l1, module_1, l2, l3) + self.assertEqual(n1.insert(1, module_1), n2) + + # test for negative support + n3 = nn.Sequential(l1, l2, l3) + module_2 = nn.Linear(5, 6) + n4 = nn.Sequential(l1, module_2, l2, l3) + self.assertEqual(n3.insert(-2, module_2), n4) + + def test_Sequential_insert_fail_case(self): + l1 = nn.Linear(1, 2) + l2 = nn.Linear(2, 3) + l3 = nn.Linear(3, 4) + + module = nn.Linear(5, 6) + + # test for error case + n1 = nn.Sequential(l1, l2, l3) + with self.assertRaises(IndexError): + n1.insert(-5, module) + + with self.assertRaises(AssertionError): + n1.insert(1, [nn.Linear(6, 7)]) + + def test_Sequential_extend(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n1 = nn.Sequential(l1, l2) + n2 = nn.Sequential(l3, l4) + n3 = nn.Sequential(l1, l2) + for l in n2: + n1.append(l) + n3.extend(n2) + self.assertEqual(n3, n1) + def test_ModuleList(self): modules = [nn.ReLU(), nn.Linear(5, 5)] module_list = nn.ModuleList(modules) @@ -1659,6 +1776,14 @@ def check(): module_list.extend(s.modules()) check() + modules = [nn.ReLU(), nn.Linear(5, 5), nn.Conv2d(3, 4, 3)] + module_list = nn.ModuleList(modules) + self.assertEqual(modules.pop(1), module_list.pop(1)) + self.assertEqual(modules, module_list) + # check order of the index + for k, mod in zip(range(len(module_list)), module_list): + self.assertIs(module_list[k], mod) + # verify the right exception is thrown when trying to "forward" through a ModuleList self.assertRaises(NotImplementedError, module_list) self.assertRaises(NotImplementedError, module_list, torch.rand(1, 3)) @@ -1760,6 +1885,7 @@ def check(): self.assertRaises(NotImplementedError, module_dict) self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3)) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_ParameterList(self): def make_param(): return Parameter(torch.randn(2, 2)) @@ -1868,6 +1994,7 @@ def make_param(): self.assertIsNotNone(p2.grad_fn) self.assertIs(p2._base, p) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_ParameterDict(self): parameters = OrderedDict([ ('p1', Parameter(torch.randn(10, 10))), @@ -3279,6 +3406,54 @@ def forward(self, X): parametrize.type_before_parametrizations(model) == original_type ) + def test_deepcopy_after_parametrization(self): + r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" + + class AddOne(nn.Module): + def forward(self, x): + return x + 1.0 + + class ModelWithoutDeepcopy(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True) + self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True) + self.attr = [1.0, 2.0, 3.0, 4.0] + + class ActualModel(ModelWithoutDeepcopy): + # Emulate custom implementation of the deepcopying. + def __deepcopy__(self, memo): + result = self.__new__(self.__class__) + memo[id(self)] = result + result.__dict__ = deepcopy(self.__dict__, memo) + return result + + def check_deepcopy(m1: nn.Module, m2: nn.Module): + w1 = m1.parametrizations.weight.original + w2 = m2.parametrizations.weight.original + b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias + b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias + # Weights, biases and attributes should be equal but they must be different objects. + self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys()) + self.assertIsNot(m1, m2) + self.assertEqual(w1, w2) + self.assertIsNot(w1, w2) + self.assertEqual(b1, b2) + self.assertIsNot(b1, b2) + self.assertEqual(m1.attr, m2.attr) + self.assertIsNot(m1.attr, m2.attr) + + for model in (ModelWithoutDeepcopy(), ActualModel()): + # General check that we are able to create deepcopy. + parametrize.register_parametrization(model, "weight", AddOne()) + check_deepcopy(model, deepcopy(model)) + # Check that this works on models with several parametrized tensors. + parametrize.register_parametrization(model, "bias", AddOne()) + check_deepcopy(model, deepcopy(model)) + # Check that this works on models where tensors have more than one parametrization. + parametrize.register_parametrization(model, "weight", AddOne()) + check_deepcopy(model, deepcopy(model)) + def test_transfer_parametrizations_and_params(self): r"""Test that all parametrizations and their associated parameters are transferred.""" @@ -4369,6 +4544,7 @@ def test_prune_importance_scores_mimic_default(self): self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_rnn_pruning(self): l = torch.nn.LSTM(32, 32) # This Module has 4 parameters called: @@ -4401,6 +4577,7 @@ def test_rnn_pruning(self): assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_rnn_weight_norm(self): def check_weight_norm(l, name, num_params): # This Module has 4 or 5 parameters called: @@ -4440,14 +4617,14 @@ def check_weight_norm(l, name, num_params): def test_weight_norm(self): for dtype in [torch.float, torch.bfloat16]: - input = torch.randn(3, 40, dtype=dtype) - m = nn.Linear(40, 50).to(dtype=dtype) + input = torch.randn(3, 4, dtype=dtype) + m = nn.Linear(4, 5).to(dtype=dtype) expected_output = m(input) # add weight normalization m = torch.nn.utils.weight_norm(m) self.assertEqual(m.weight_v.size(), m.weight.size()) - self.assertEqual(m.weight_g.size(), (50, 1)) + self.assertEqual(m.weight_g.size(), (5, 1)) self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) # remove weight norm @@ -4459,11 +4636,11 @@ def test_weight_norm(self): # test with dim=1 m = torch.nn.utils.weight_norm(m, dim=1) self.assertEqual(m.weight_v.size(), m.weight.size()) - self.assertEqual(m.weight_g.size(), (1, 40)) + self.assertEqual(m.weight_g.size(), (1, 4)) self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) # test with dim=None - m = nn.Linear(40, 50).to(dtype=dtype) + m = nn.Linear(4, 5).to(dtype=dtype) expected_output = m(input) m = torch.nn.utils.weight_norm(m, dim=None) self.assertEqual(m(input), expected_output) @@ -4472,6 +4649,12 @@ def test_weight_norm(self): m = torch.nn.utils.weight_norm(m) m = torch.nn.utils.weight_norm(m) + # For float16, the forward of the Module doesn't work but we must still be able + # to register the weight norm as this is often done before sending the Module to + # CUDA. + m = nn.Linear(4, 5, dtype=torch.float16) + m = torch.nn.utils.weight_norm(m) + def test_parameterlistdict_setting_attributes(self): with warnings.catch_warnings(record=True) as w: mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) @@ -4521,6 +4704,7 @@ def test_weight_norm_pickle(self): m = pickle.loads(pickle.dumps(m)) self.assertIsInstance(m, nn.Linear) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_spectral_norm(self): input = torch.randn(3, 5) m = nn.Linear(5, 7) @@ -5840,32 +6024,6 @@ def test_multihead_attn_3d_attn_mask(self): # output_2d in shape of [T, 1, D] self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_self_attn_TxT_attn_mask(self): - embed_dim = 16 - num_heads = 4 - batch_size = 10 - tgt_len = 16 - - query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda") # [N, T, D] - attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float() # [T, T] - attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) - - attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len) - - mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda() - mta_model.eval() - - # Generate 3D results - with torch.inference_mode(): - output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0] - output_mask_4d = output_mask_4d.transpose(0, 1) # [N, T, D] - - output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0] - output_mask_TxT = output_mask_TxT.transpose(0, 1) # [N, T, D] - - self.assertEqual(output_mask_4d, output_mask_TxT) - def test_multihead_attn_no_bias(self): embed_dim = 8 num_heads = 4 @@ -6390,6 +6548,7 @@ def test_load_state_dict_BC(self): self.assertEqual(bn.num_batches_tracked.dtype, torch.long) self.assertEqual(bn.num_batches_tracked.item(), 0) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_load_state_dict_ref_cycle(self): # load_state_dict shouldn't cause a reference cycle involving Tensors import gc @@ -6521,6 +6680,7 @@ def set_extra_state(self): with self.assertRaisesRegex(RuntimeError, 'Missing key'): m.load_state_dict(m.state_dict()) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_parameter_assignment(self): l = nn.Linear(5, 5) @@ -7908,187 +8068,6 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) - def test_transformerencoder(self): - def get_a_test_layer(use_cuda, activation, batch_first=False): - d_model = 4 - nhead = 2 - dim_feedforward = 16 - dropout = 0.0 - device = torch.device("cuda" if use_cuda else "cpu") - - layer = nn.TransformerEncoderLayer( - d_model, - nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - batch_first=batch_first).to(device) - - with torch.no_grad(): - # set constant weights of the model - for idx, p in enumerate(layer.parameters()): - x = p.data - sz = x.view(-1).size(0) - shape = x.shape - x = torch.cos(torch.arange(0, sz).float().view(shape)) - p.data.copy_(x) - - return layer - - # this is a deterministic test for TransformerEncoder - activation = F.relu - use_cuda = torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - - def _test(batch_first, training): - def perm_fn(x): - return x.transpose(1, 0) if batch_first else x - - encoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation, - batch_first=batch_first) - - model = nn.TransformerEncoder(encoder_layer, 1).to(device) - if not training: - model = model.eval() - - # deterministic input - encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], - [0.5387, 0.1655, 0.3565, 0.0471]], - [[0.8335, 0.2799, 0.5031, 0.2947], - [0.1402, 0.0318, 0.7636, 0.1346]], - [[0.6333, 0.9344, 0.1376, 0.9938], - [0.8924, 0.2872, 0.6692, 0.2944]], - [[0.9897, 0.6915, 0.3154, 0.1733], - [0.8645, 0.3513, 0.3064, 0.0767]], - [[0.8117, 0.2366, 0.4838, 0.7881], - [0.3718, 0.4945, 0.9511, 0.0864]]] - )).to(device) - result = model(encoder_input) - ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249], - [2.427987, 0.021213, -0.602496, -0.084103]], - [[2.424689, 0.019155, -0.604793, -0.085672], - [2.413863, 0.022211, -0.612486, -0.072490]], - [[2.433774, 0.021598, -0.598343, -0.087548], - [2.425104, 0.019748, -0.604515, -0.084839]], - [[2.436185, 0.022682, -0.596625, -0.087261], - [2.433556, 0.021891, -0.598509, -0.086832]], - [[2.416246, 0.017512, -0.610712, -0.082961], - [2.422901, 0.024187, -0.606178, -0.074929]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - - # all 0 - mask = torch.zeros([2, 5]).to(device) == 1 - result = model(encoder_input, src_key_padding_mask=mask) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - mask[0, 1] = 1 - mask[1, 3] = 1 - mask[1, 4] = 1 - # If mask is not left aligned - # We disable nested tensor - model.enable_nested_tensor = False - result = model(encoder_input, src_key_padding_mask=mask) - ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642], - [2.428811, 0.021445, -0.601912, -0.084252]], - [[2.425009, 0.019155, -0.604566, -0.085899], - [2.415408, 0.02249, -0.611415, -0.073]], - [[2.434199, 0.021682, -0.598039, -0.087699], - [2.42598, 0.019941, -0.603896, -0.085091]], - [[2.436457, 0.022736, -0.59643, -0.08736], - [2.434021, 0.022093, -0.598179, -0.08679]], - [[2.416531, 0.017498, -0.610513, -0.083181], - [2.4242, 0.024653, -0.605266, -0.074959]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - - # test case 2, multiple layers no norm - model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=False).to(device) - if not training: - model = model.eval() - result = model(encoder_input, src_key_padding_mask=mask) - ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003], - [2.419102, 0.017452, -0.608703, -0.085026]], - [[2.419043, 0.017445, -0.608744, -0.084999], - [2.419052, 0.017446, -0.608738, -0.085004]], - [[2.419067, 0.017448, -0.608727, -0.085010], - [2.419098, 0.017452, -0.608706, -0.085024]], - [[2.419072, 0.017449, -0.608724, -0.085012], - [2.419119, 0.017455, -0.608691, -0.085034]], - [[2.419019, 0.017442, -0.608761, -0.084989], - [2.419075, 0.017449, -0.608722, -0.085014]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - - model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=False).to(device) - if not training: - model = model.eval() - result = model(encoder_input, src_key_padding_mask=mask) - ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025], - [2.419101, 0.017453, -0.608704, -0.085025]], - [[2.419101, 0.017453, -0.608703, -0.085025], - [2.419101, 0.017453, -0.608704, -0.085025]], - [[2.419101, 0.017453, -0.608703, -0.085025], - [2.419101, 0.017453, -0.608704, -0.085025]], - [[2.419101, 0.017453, -0.608703, -0.085025], - [2.419101, 0.017453, -0.608704, -0.085025]], - [[2.419101, 0.017453, -0.608703, -0.085025], - [2.419101, 0.017453, -0.608704, -0.085025]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - - # test case 3, multiple layers with norm - # d_model = 4 - norm = nn.LayerNorm(4) - model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=False).to(device) - if not training: - model = model.eval() - result = model(encoder_input, src_key_padding_mask=mask) - ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238], - [1.695955, -0.357639, -0.893050, -0.445266]], - [[1.695948, -0.357634, -0.893082, -0.445233], - [1.695950, -0.357635, -0.893077, -0.445238]], - [[1.695951, -0.357636, -0.893069, -0.445246], - [1.695955, -0.357639, -0.893052, -0.445264]], - [[1.695952, -0.357636, -0.893066, -0.445249], - [1.695957, -0.357641, -0.893041, -0.445276]], - [[1.695946, -0.357632, -0.893095, -0.445220], - [1.695952, -0.357637, -0.893065, -0.445251]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - - model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=False).to(device) - if not training: - model = model.eval() - result = model(encoder_input, src_key_padding_mask=mask) - ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265], - [1.695955, -0.357639, -0.893051, -0.445265]], - [[1.695955, -0.357639, -0.893051, -0.445265], - [1.695955, -0.357639, -0.893051, -0.445265]], - [[1.695955, -0.357639, -0.893051, -0.445265], - [1.695955, -0.357639, -0.893051, -0.445265]], - [[1.695955, -0.357639, -0.893051, -0.445265], - [1.695955, -0.357639, -0.893051, -0.445265]], - [[1.695955, -0.357639, -0.893051, -0.445265], - [1.695955, -0.357639, -0.893051, -0.445265]]] - )).to(device) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) - for batch_first in (True, False): - for training in (True, False): - # Fast path requires inference mode. - if training: - cm = contextlib.nullcontext() - else: - cm = torch.no_grad() - with cm: - _test(batch_first, training) - def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -9370,6 +9349,7 @@ def test_inplace_thnn(self): self.assertEqual(grad_output, grad_output_clone) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_pixel_shuffle_unshuffle(self): def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True, upscale_factor=None): @@ -10002,6 +9982,7 @@ def test_pdist_empty_col(self): inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True) self.assertTrue(gradcheck(F.pdist, (inp,))) + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") @unittest.expectedFailure def test_pdist_cpu_gradgrad_unimplemented(self): inp = torch.randn(4, 5, requires_grad=True) @@ -11090,6 +11071,14 @@ def test_channel_shuffle(self): y = y.contiguous(memory_format=torch.contiguous_format) self.assertEqual(y, y_ref) + + def test_channel_shuffle_return_self(self): + # gh-76616: nn.ChannelShuffle will return self with an empty input tensor + groups = 3 + input_tensor = torch.rand([0, 9, 4, 4]) + output = torch.nn.ChannelShuffle(groups)(input_tensor) + torch.testing.assert_close(output, input_tensor) + def test_upsamplingLinear1d(self): for align_corners in [True, False]: for recompute_scale_factor in [True, False]: @@ -11593,31 +11582,33 @@ def test_softmin(self): self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0)) def test_log_softmax_cpu(self, dtype=torch.bfloat16): - inputf = torch.rand(32, 100, device="cpu", dtype=torch.float, requires_grad=True) - input = inputf.to(dtype).detach().requires_grad_(True) - outf = F.log_softmax(inputf, dim=-1) - out = F.log_softmax(input, dim=-1) - self.assertEqual(out.dtype, dtype) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(out, outf, atol=0.1, rtol=0) + for dim in [0, 1]: + inputf = torch.rand(200, 200, device="cpu", dtype=torch.float, requires_grad=True) + input = inputf.to(dtype).detach().requires_grad_(True) + outf = F.log_softmax(inputf, dim=dim) + out = F.log_softmax(input, dim=dim) + self.assertEqual(out.dtype, dtype) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(out, outf, atol=0.1, rtol=0) - out.sum().backward() - outf.sum().backward() - self.assertEqual(input.grad.dtype, dtype) - self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0) + out.sum().backward() + outf.sum().backward() + self.assertEqual(input.grad.dtype, dtype) + self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0) def test_softmax_cpu(self, dtype=torch.bfloat16): - inputf = torch.rand(32, 100, device="cpu", dtype=torch.float, requires_grad=True) - input = inputf.to(dtype).detach().requires_grad_(True) - outf = F.softmax(inputf, dim=-1) - out = F.softmax(input, dim=-1) - self.assertEqual(out.dtype, dtype) - self.assertEqualIgnoreType(out, outf, atol=1e-3, rtol=0) + for dim in [0, 1]: + inputf = torch.rand(200, 200, device="cpu", dtype=torch.float, requires_grad=True) + input = inputf.to(dtype).detach().requires_grad_(True) + outf = F.softmax(inputf, dim=dim) + out = F.softmax(input, dim=dim) + self.assertEqual(out.dtype, dtype) + self.assertEqualIgnoreType(out, outf, atol=1e-3, rtol=0) - out.sum().backward() - outf.sum().backward() - self.assertEqual(input.grad.dtype, dtype) - self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0) + out.sum().backward() + outf.sum().backward() + self.assertEqual(input.grad.dtype, dtype) + self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0) def test_adaptive_log_softmax(self): # args validation @@ -11847,34 +11838,77 @@ def test_functional_grad_conv(self): output = F.conv1d(input, weight, dilation=2) grad_output = torch.randn(output.shape) - grad_input_autograd = torch.autograd.grad(output, input, grad_output)[0] + grad_input_autograd, grad_weight_autograd = torch.autograd.grad(output, (input, weight), grad_output) + grad_input_functional = torch.nn.grad.conv1d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) + grad_weight_functional = torch.nn.grad.conv1d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + # Conv 2D input = torch.randn(1, 1, 5, 5, requires_grad=True) weight = torch.randn(1, 1, 3, 3, requires_grad=True) output = F.conv2d(input, weight, dilation=2) grad_output = torch.randn(output.shape) - grad_input_autograd = torch.autograd.grad(output, input, grad_output)[0] + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) + grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + # Conv 3D input = torch.randn(1, 1, 5, 5, 5, requires_grad=True) weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True) output = F.conv3d(input, weight, dilation=2) grad_output = torch.randn(output.shape) - grad_input_autograd = torch.autograd.grad(output, input, grad_output)[0] + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + grad_input_functional = torch.nn.grad.conv3d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) - # Warning for _grad_input_padding - with warnings.catch_warnings(record=True) as w: - torch.nn.grad._grad_input_padding(torch.rand(1, 2, 3), [1, 2, 5], (1,), (0,), (3,)) - self.assertEqual(len(w), 1) + grad_weight_functional = torch.nn.grad.conv3d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + def test_functional_grad_conv2d(self): + BATCH_SIZE = 4 + IN_CH = 8 + OUT_CH = 16 + SPATIAL = 32 + + def _test_conv2d(stride, kernel_size, groups, dilation): + padding = kernel_size // 2 + + input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True) + + weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True) + + output = F.conv2d(input, weight, + stride=stride, padding=padding, dilation=dilation, groups=groups) + + grad_output = torch.randn(output.shape) + + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + + grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, + stride=stride, padding=padding, dilation=dilation, groups=groups) + self.assertEqual(grad_input_functional, grad_input_autograd) + + grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, + stride=stride, padding=padding, dilation=dilation, groups=groups) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + strides = [1, 2] + kernel_sizes = [1, 3, 5] + groups = [1, 2, 4] + dilates = [1, 2] + + for s, k, g, d in product(strides, kernel_sizes, groups, dilates): + _test_conv2d(s, k, g, d) def test_flatten(self): tensor_input = torch.randn(2, 1, 2, 3) @@ -14156,12 +14190,14 @@ def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), decorators=[onlyCUDA, disablecuDNN], name='slow3d_cuda'), - subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), + # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated'), - subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), + # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), subtest(((0, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch1d'), subtest(((2, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), @@ -14199,8 +14235,9 @@ def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d_transposed'), subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d_transposed'), - subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), + # decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), # === miopen === subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d'), @@ -14356,6 +14393,7 @@ def _make_noncontiguous(inp): bias.requires_grad_(False) self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) + def test_Dropout(self, device): input = torch.empty(1000) self._test_dropout(nn.Dropout, device, input) @@ -15895,6 +15933,72 @@ def helper(n, c, h, w, kernel_size, stride=None): helper(10, 512, 31, 31, 3, stride=2) helper(1, 129, 8, 8, 3, stride=2) + @onlyNativeDeviceTypes + @dtypes(torch.half, torch.float, torch.double) + @onlyCUDA + def test_max_pool3d_ndhwc(self, device, dtype): + def helper(n, c, h, w, d, kernel_size, stride=None): + batch = n + if not batch: + batch = 1 + input = torch.randn(batch, c, d, h, w, dtype=dtype, device=device) + input = input.contiguous(memory_format=torch.channels_last_3d).requires_grad_() + if not n: + input = input.squeeze(0).detach().clone().requires_grad_() + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * 3 + if stride is None: + stride = kernel_size + elif isinstance(stride, int): + stride = [stride] * 3 + grad = torch.randn(batch, c, + (d - kernel_size[0]) // stride[0] + 1, + (h - kernel_size[1]) // stride[1] + 1, + (w - kernel_size[2]) // stride[2] + 1, + dtype=dtype, device=device) + grad = grad.contiguous(memory_format=torch.channels_last_3d) + if not n: + grad = grad.squeeze(0) + pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(device) + + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(device) + out, ind = pool(input) + out.backward(grad) + ref_out, ref_ind = ref_pool(ref_input) + ref_out.backward(ref_grad) + + if len(out.shape) == 4: + self.assertTrue(out.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)) + else: + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d)) + self.assertTrue(ref_out.is_contiguous()) + if len(ind.shape) == 4: + self.assertTrue(ind.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)) + else: + self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last_3d)) + self.assertTrue(ref_ind.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(ind, ref_ind) + if dtype == torch.half: + self.assertEqual(input.grad, ref_input.grad, atol=0.05, rtol=0.01) + else: + self.assertEqual(input.grad, ref_input.grad) + + helper(4, 8, 8, 8, 8, 7) + helper(4, 8, 8, 8, 8, (5, 6, 7)) + helper(1, 8, 8, 8, 8, (5, 6, 7)) + helper(0, 6, 12, 13, 14, (5, 6, 7)) + helper(4, 8, 7, 7, 7, 3, stride=1) + helper(10, 128, 19, 19, 19, 3, stride=2) + helper(10, 128, 19, 19, 19, (1, 2, 3), stride=2) + helper(1, 128, 19, 19, 19, (1, 2, 3), stride=2) + helper(0, 128, 19, 19, 19, (1, 2, 3), stride=2) + helper(1, 79, 4, 4, 4, 3, stride=2) + helper(0, 79, 4, 4, 4, 3, stride=2) + + @onlyCPU def test_max_pool2d_bfloat16(self, device): def helper(n, c, h, w, kernel_size, stride, memory_format): @@ -16673,6 +16777,7 @@ def test_embedding_padding_idx(self, device, dtype): # with an offset array. Compare against an equivalent 2D input that uses # padding indices to fill in the gaps indicated by the offset array + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") @onlyNativeDeviceTypes @dtypes(torch.float32, torch.float64) @dtypesIfCUDA(torch.half, torch.bfloat16) @@ -17868,17 +17973,25 @@ def test_pool3d_size_one_feature_dim(self, device): self.assertEqual(out_y, out_x.to(device), msg=test) @onlyCUDA - @largeTensorTest('6GB') + @largeTensorTest('18GB') + @largeTensorTest('180GB', 'cpu') def test_pool3d_large_size_int64(self, device): # See https://github.com/pytorch/pytorch/issues/52822 - x = torch.randn(70, 32, 100, 100, 100, dtype=torch.half, device=device) + x = torch.randn(70, 32, 100, 100, 100, dtype=torch.half, device=device, requires_grad=True) y = torch.nn.functional.max_pool3d(x, 5) + g = torch.randn_like(y, dtype=torch.half) + torch.cuda.synchronize() + y.backward(g) torch.cuda.synchronize() - ref_x = x.cpu().float() # max_pool3d_cpu is not implemented for half + ref_x = x.detach().cpu().float() # max_pool3d_cpu is not implemented for half + ref_x.requires_grad = True + ref_g = g.cpu().float() ref_y = torch.nn.functional.max_pool3d(ref_x, 5) + ref_y.backward(ref_g) self.assertEqual(y, ref_y, exact_dtype=False) + self.assertEqual(x.grad, ref_x.grad, exact_dtype=False) @onlyCUDA def test_AvgPool3d_backward_after_cat_dim1_device(self, device): @@ -19394,12 +19507,12 @@ def test_softmax_bfloat16(self, device): @onlyCPU @dtypes(torch.float, torch.double) def test_conv_thnn_nhwc(self, device, dtype): - def helper(n, c, h, w, out_channels, kernel_size, dilation, groups, weight_memory_format): + def helper(n, c, h, w, out_channels, kernel_size, dilation, groups, input_format, weight_format): input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ - .to(memory_format=torch.channels_last) + .to(memory_format=input_format) input.requires_grad_() conv = nn.Conv2d(c, out_channels, kernel_size, dilation=dilation, groups=groups)\ - .to(device='cpu', dtype=dtype, memory_format=weight_memory_format) + .to(device='cpu', dtype=dtype, memory_format=weight_format) for p in conv.parameters(): p.data = torch.randint_like(p, -3, 3) @@ -19426,16 +19539,29 @@ def helper(n, c, h, w, out_channels, kernel_size, dilation, groups, weight_memor self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) with torch.backends.mkldnn.flags(enabled=False): - for mf in [torch.contiguous_format, torch.channels_last]: + formats = [[torch.channels_last, torch.channels_last], + [torch.channels_last, torch.contiguous_format], + [torch.contiguous_format, torch.channels_last]] + for input_format, weight_format in formats: # non-dilated conv: thnn_conv2d normal path (with im2col) - helper(2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, weight_memory_format=mf) - helper(2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, weight_memory_format=mf) + helper(2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, + input_format=input_format, weight_format=weight_format) + helper(2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, + input_format=input_format, weight_format=weight_format) + # test when input chanels is 1 and not converted to channels last + helper(2, 1, 10, 10, out_channels=8, kernel_size=3, dilation=1, groups=1, + input_format=torch.contiguous_format, weight_format=torch.channels_last) # non-dilated conv: thnn_conv2d fast path (skip im2col) - helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, weight_memory_format=mf) - helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=16, weight_memory_format=mf) + helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, + input_format=input_format, weight_format=weight_format) + # ic == oc == 1 here, so need to stick input to CL to activate channels last + helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=16, + input_format=torch.channels_last, weight_format=weight_format) # dilated conv: slow_conv_dilated2d - helper(2, 8, 11, 13, out_channels=16, kernel_size=3, dilation=2, groups=1, weight_memory_format=mf) - helper(2, 16, 11, 13, out_channels=32, kernel_size=3, dilation=2, groups=16, weight_memory_format=mf) + helper(2, 8, 11, 13, out_channels=16, kernel_size=3, dilation=2, groups=1, + input_format=input_format, weight_format=weight_format) + helper(2, 16, 11, 13, out_channels=32, kernel_size=3, dilation=2, groups=16, + input_format=input_format, weight_format=weight_format) @onlyCUDA @skipCUDAIfRocmVersionLessThan((4, 3)) @@ -20132,6 +20258,7 @@ def test_maxpool3d_non_square_backward(self, device): # Check that clip_grad_norm_ raises an error if the total norm of the # parameters' gradients is non-finite + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_clip_grad_norm_error_if_nonfinite(self, device): norms_pos = [0.1, 1, 2, 3.5, inf] norms_neg = [-0.1, -1, -2, -3.5] @@ -20712,6 +20839,13 @@ def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(self, dev query = torch.randn(3, 3, 3, dtype=dtype, device=device) mha(query, query, query) + @dtypes(torch.double) + @torch.no_grad() + def test_multihead_attn_fast_path_small_test(self, device, dtype): + mha = torch.nn.MultiheadAttention(3, 3, batch_first=True, dtype=dtype, device=device).eval() + query = torch.randn(3, 3, 3, dtype=dtype, device=device) + mha(query, query, query) + @dtypes(torch.double) @torch.no_grad() def test_multihead_attn_in_proj_bias_none(self, device, dtype): @@ -20818,6 +20952,7 @@ def tearDown(self): nn.modules.module._global_forward_hooks = OrderedDict() nn.modules.module._global_forward_pre_hooks = OrderedDict() + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_global_hooks(self): module = nn.Sigmoid @@ -20938,6 +21073,7 @@ def bw_hook(module, grad_input, grad_output): expected_grad = sig_x * (1 - sig_x) * 2 self.assertEqual(input.grad, expected_grad) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_global_forward_preforward_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -20959,6 +21095,7 @@ def forward_hook(m, input, output): expected_grad = -sig_x * (1 - sig_x) * 2 * mask self.assertEqual(input.grad, expected_grad) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_forward_preforward_hook_removable(self): """ This test is to test when multiple pre-forward hook functions can be @@ -20994,6 +21131,7 @@ def removable_hook_2(m, input): self.assertEqual(len(handle.hooks_dict_ref()), 0) self.assertEqual(len(handle_2.hooks_dict_ref()), 0) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_module_forward_forward_hook_removable(self): """ This test is to test when multiple forward hook functions can be registered @@ -21029,6 +21167,7 @@ def removable_hook_2(m, input, output): self.assertEqual(len(handle.hooks_dict_ref()), 0) self.assertEqual(len(handle_2.hooks_dict_ref()), 0) + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_global_and_local_hooks_order(self): module = nn.Sigmoid() @@ -21710,6 +21849,9 @@ def test_pickle_softsign(self): # Make sure it does not throw an exception s = pickle.dumps(F.softsign) +def _hook_to_pickle(*args, **kwargs): + pass + class TestStateDictHooks(TestCase): def test_load_state_dict_pre_hook(self): @@ -21742,6 +21884,24 @@ def hook_with_module(module, state_dict, prefix, local_metadata, strict, missing m_load.load_state_dict(m_state_dict) self.assertEqual(2, hook_called) + def test_no_extra_ref_to_module(self): + try: + gc.disable() + m = nn.Linear(10, 10) + + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + weak_m = weakref.ref(m) + del m + + self.assertEqual(weak_m(), None) + finally: + gc.enable() + + def test_pickled_hook(self): + m = nn.Linear(10, 10) + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + pickle.loads(pickle.dumps(m)) + def test_load_state_dict_module_pre_hook(self): hook_called = 0 diff --git a/test/test_ops.py b/test/test_ops.py index 73774eeaf95cc..050ae49649e6a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,6 +7,7 @@ import itertools import torch import contextlib +from collections import defaultdict from importlib import import_module from torch.utils._pytree import tree_map @@ -31,12 +32,14 @@ IS_FBCODE, first_sample, parametrize, + skipIfSlowGradcheckEnv, ) from torch.testing._internal.common_methods_invocations import ( op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, + ReductionPythonRefInfo, SpectralFuncInfo, ops_and_refs, python_ref_db, @@ -65,7 +68,7 @@ from torch.testing._internal import composite_compliance from torch.utils._pytree import tree_flatten -from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import TorchDispatchMode # TODO: fixme https://github.com/pytorch/pytorch/issues/68972 torch.set_default_dtype(torch.float32) @@ -95,6 +98,7 @@ # Tests that apply to all operators and aren't related to any particular # system +@skipIfSlowGradcheckEnv class TestCommon(TestCase): exact_dtype = True @@ -116,202 +120,6 @@ def tearDownClass(cls): assert len(filtered_ops) == 0, err_msg - # Validates that each OpInfo specifies its forward and backward dtypes - # correctly for CPU and CUDA devices - @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @skipMeta - @onlyNativeDeviceTypes - @ops(ops_and_refs, dtypes=OpDTypes.none) - def test_dtypes(self, device, op): - # Check complex32 support only if the op claims. - # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. - device_type = torch.device(device).type - include_complex32 = ( - (torch.complex32,) - if op.supports_dtype(torch.complex32, device_type) - else () - ) - - # dtypes to try to backward in - allowed_backward_dtypes = floating_and_complex_types_and( - *((torch.half, torch.bfloat16) + include_complex32) - ) - - # lists for (un)supported dtypes - supported_dtypes = set() - unsupported_dtypes = set() - supported_backward_dtypes = set() - unsupported_backward_dtypes = set() - - def unsupported(dtype): - unsupported_dtypes.add(dtype) - if dtype in allowed_backward_dtypes: - unsupported_backward_dtypes.add(dtype) - - for dtype in all_types_and_complex_and( - *((torch.half, torch.bfloat16, torch.bool) + include_complex32) - ): - # tries to acquire samples - failure indicates lack of support - requires_grad = dtype in allowed_backward_dtypes - try: - samples = tuple( - op.sample_inputs(device, dtype, requires_grad=requires_grad) - ) - except Exception as e: - unsupported(dtype) - continue - - for sample in samples: - # tries to call operator with the sample - failure indicates - # lack of support - try: - result = op(sample.input, *sample.args, **sample.kwargs) - supported_dtypes.add(dtype) - except Exception as e: - # NOTE: some ops will fail in forward if their inputs - # require grad but they don't support computing the gradient - # in that type! This is a bug in the op! - unsupported(dtype) - continue - - # Checks for backward support in the same dtype, if the input has - # one or more tensors requiring grad - def _tensor_requires_grad(x): - if isinstance(x, dict): - for k, v in x.items(): - if _tensor_requires_grad(v): - return True - if isinstance(x, (list, tuple)): - for a in x: - if _tensor_requires_grad(a): - return True - if isinstance(x, torch.Tensor) and x.requires_grad: - return True - - return False - - requires_grad = _tensor_requires_grad(sample.input) \ - or _tensor_requires_grad(sample.args) or _tensor_requires_grad(sample.kwargs) - if not requires_grad: - continue - - try: - result = sample.output_process_fn_grad(result) - if isinstance(result, torch.Tensor): - backward_tensor = result - elif isinstance(result, Sequence) and isinstance( - result[0], torch.Tensor - ): - backward_tensor = result[0] - else: - continue - - # Note: this grad may not have the same dtype as dtype - # For functions like complex (float -> complex) or abs - # (complex -> float) the grad tensor will have a - # different dtype than the input. - # For simplicity, this is still modeled as these ops - # supporting grad in the input dtype. - grad = torch.randn_like(backward_tensor) - backward_tensor.backward(grad) - supported_backward_dtypes.add(dtype) - except Exception as e: - unsupported_backward_dtypes.add(dtype) - - # Checks that dtypes are listed correctly and generates an informative - # error message - - supported_forward = supported_dtypes - unsupported_dtypes - partially_supported_forward = supported_dtypes & unsupported_dtypes - unsupported_forward = unsupported_dtypes - supported_dtypes - supported_backward = supported_backward_dtypes - unsupported_backward_dtypes - partially_supported_backward = ( - supported_backward_dtypes & unsupported_backward_dtypes - ) - unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes - - device_type = torch.device(device).type - - claimed_forward = set(op.supported_dtypes(device_type)) - supported_but_unclaimed_forward = supported_forward - claimed_forward - claimed_but_unsupported_forward = claimed_forward & unsupported_forward - - claimed_backward = set(op.supported_backward_dtypes(device_type)) - supported_but_unclaimed_backward = supported_backward - claimed_backward - claimed_but_unsupported_backward = claimed_backward & unsupported_backward - - # Partially supporting a dtype is not an error, but we print a warning - if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: - msg = "Some dtypes for {0} on device type {1} are only partially supported!\n".format( - op.name, device_type - ) - if len(partially_supported_forward) > 0: - msg = ( - msg - + "The following dtypes only worked on some samples during forward: {0}.\n".format( - partially_supported_forward - ) - ) - if len(partially_supported_backward) > 0: - msg = ( - msg - + "The following dtypes only worked on some samples during backward: {0}.\n".format( - partially_supported_backward - ) - ) - print(msg) - - if ( - len(supported_but_unclaimed_forward) - + len(claimed_but_unsupported_forward) - + len(supported_but_unclaimed_backward) - + len(claimed_but_unsupported_backward) - ) == 0: - return - - # Reference operators often support additional dtypes, and that's OK - if op in python_ref_db: - if ( - len(claimed_but_unsupported_forward) - + len(claimed_but_unsupported_backward) - ) == 0: - return - - # Generates error msg - msg = "The supported dtypes for {0} on device type {1} are incorrect!\n".format( - op.name, device_type - ) - if len(supported_but_unclaimed_forward) > 0: - msg = ( - msg - + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format( - supported_but_unclaimed_forward - ) - ) - if len(supported_but_unclaimed_backward) > 0: - msg = ( - msg - + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format( - supported_but_unclaimed_backward - ) - ) - if len(claimed_but_unsupported_forward) > 0: - msg = ( - msg - + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format( - claimed_but_unsupported_forward - ) - ) - if len(claimed_but_unsupported_backward) > 0: - msg = ( - msg - + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format( - claimed_but_unsupported_backward - ) - ) - - self.fail(msg) - # Validates that each OpInfo works correctly on different CUDA devices @onlyCUDA @deviceCountAtLeast(2) @@ -361,7 +169,7 @@ def test_numpy_ref(self, device, dtype, op): @onlyNativeDeviceTypes @ops(python_ref_db) def test_python_ref_meta(self, device, dtype, op): - mode = torch._prims.utils.get_prim_fake_mode() + mode = torch._prims.get_prim_fake_mode() def _to_tensormeta(x): if isinstance(x, torch.Tensor): @@ -381,18 +189,22 @@ def _to_tensormeta(x): continue if isinstance(result, torch.Tensor): + self.assertTrue(isinstance(meta_result, FakeTensor)) prims.utils.compare_tensor_meta(result, meta_result) elif isinstance(result, Sequence): for a, b in zip(result, meta_result): if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): + self.assertTrue(isinstance(b, FakeTensor)) prims.utils.compare_tensor_meta(a, b) - def _ref_test_helper(self, ctx, device, dtype, op, skip_zero_numel=False): + def _ref_test_helper(self, ctx, device, dtype, op, skip_zero_numel=False, skip_zero_dim=False): # NOTE: this test works by comparing the reference ex = None for sample in op.reference_inputs(device, dtype, requires_grad=False): if isinstance(sample.input, torch.Tensor) and sample.input.numel() == 0 and skip_zero_numel: continue + if isinstance(sample.input, torch.Tensor) and sample.input.ndim == 0 and skip_zero_dim: + continue with ctx(): ref_result = op(sample.input, *sample.args, **sample.kwargs) torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs) @@ -440,6 +252,8 @@ def _ref_test_helper(self, ctx, device, dtype, op, skip_zero_numel=False): # If the results are not close, checks that the # reference is more accurate than the torch op def _make_precise(x): + if isinstance(x, torch.dtype): + return precise_dtype if isinstance(x, torch.Tensor) and x.dtype is dtype: return x.to(precise_dtype) return x @@ -488,7 +302,7 @@ def test_python_ref(self, device, dtype, op): # In this test, primTorch refs call into the refs namespace # For example, a ref with torch.foo in it will calls refs.foo instead # Direct calls to refs and prims are not affected - self._ref_test_helper(lambda: TorchRefsMode.push(strict=True), device, dtype, op) + self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op) # Tests that experimental Python References perform the same computation # as the operators they reference, when operator calls in the torch @@ -509,7 +323,7 @@ def test_python_ref_torch_fallback(self, device, dtype, op): @parametrize('executor', ['aten', 'nvfuser']) def test_python_ref_executor(self, device, dtype, op, executor): # TODO: Not all dtypes are supported with nvfuser - from torch._prims.utils import _torch_dtype_to_nvfuser_dtype_map + from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map: raise unittest.SkipTest(f"nvfuser doesn't support dtype {dtype}") @@ -520,16 +334,28 @@ def test_python_ref_executor(self, device, dtype, op, executor): if executor == "nvfuser" and not op.supports_nvfuser: raise unittest.SkipTest(f"{op.name} doesn't support nvfuser") + # nvFuser doesn't support reduction operations on 0-dim tensors yet + skip_zero_dim = False + if executor == "nvfuser" and isinstance(op, ReductionPythonRefInfo): + skip_zero_dim = True + + # skip zero-dim tensors for some composites of reduction operations + normalization_ops = ["_refs.softmax", "_refs.logsumexp", "_refs.log_softmax"] + if executor == "nvfuser" and op.name in normalization_ops: + skip_zero_dim = True + from torch._prims.executor import make_traced from copy import copy op = copy(op) + executor = "strictly_nvfuser" if executor == "nvfuser" else executor op.op = partial(make_traced(op.op), executor=executor) self._ref_test_helper( contextlib.nullcontext, device, dtype, op, - skip_zero_numel=(executor == "nvfuser"), # nvfuser doesn't support zero-sized tensors + skip_zero_numel=("nvfuser" in executor), # nvfuser doesn't support zero-sized tensors + skip_zero_dim=skip_zero_dim, ) @skipMeta @@ -546,7 +372,7 @@ def test_errors(self, device, op): @onlyNativeDeviceTypes @ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) def test_python_ref_errors(self, device, op): - mode = torch._prims.utils.get_prim_fake_mode() + mode = torch._prims.get_prim_fake_mode() def _to_tensormeta(x): if isinstance(x, torch.Tensor): @@ -785,8 +611,12 @@ def _any_nonempty(out): # - Case 1: out has the correct shape, dtype, and device but is noncontiguous # - Case 2: out has the correct dtype and device, but is zero elements # - Case 3: out has the correct shape and dtype, but is on a different device type - # - Case 4: out has the with correct shape and device, but a dtype that cannot + # - Case 4: out has the correct shape and device, but a dtype that cannot # "safely" cast to + # + # Case 3 and 4 are slightly different when the op is a factory function: + # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out + # - if device, dtype are passed, device and dtype should match @ops(_ops_and_refs, dtypes=OpDTypes.any_one) def test_out(self, device, dtype, op): # Prefers running in float32 but has a fallback for the first listed supported dtype @@ -911,15 +741,26 @@ def _case_two_transform(t): elif torch.cuda.is_available(): wrong_device = "cuda" + + factory_fn_msg = ( + "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its " + "OpInfo with `is_factory_function=True`." + ) if wrong_device is not None: def _case_three_transform(t): return make_tensor(t.shape, dtype=t.dtype, device=wrong_device) out = _apply_out_transform(_case_three_transform, expected) - msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}" - with self.assertRaises(RuntimeError, msg=msg_fail): + + if op.is_factory_function and sample.kwargs.get("device", None) is None: op_out(out=out) + else: + msg_fail = ( + f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}." + ) + factory_fn_msg + with self.assertRaises(RuntimeError, msg=msg_fail): + op_out(out=out) # Case 4: out= with correct shape and device, but a dtype # that output cannot be "safely" cast to (long). @@ -951,9 +792,13 @@ def _case_four_transform(t): "Expected RuntimeError when doing an unsafe cast from a result of dtype " f"{expected.dtype} into an out= with dtype torch.long" ) - ) - with self.assertRaises(RuntimeError, msg=msg_fail): + ) + factory_fn_msg + + if op.is_factory_function and sample.kwargs.get("dtype", None) is None: op_out(out=out) + else: + with self.assertRaises(RuntimeError, msg=msg_fail): + op_out(out=out) # Tests that the forward and backward passes of operations produce the # same values for the cross-product of op variants (method, inplace) @@ -1188,6 +1033,202 @@ def convert_boolean_tensors(x): self.assertEqual(expect, actual) + # Validates that each OpInfo specifies its forward and backward dtypes + # correctly for CPU and CUDA devices + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipMeta + @onlyNativeDeviceTypes + @ops(ops_and_refs, dtypes=OpDTypes.none) + def test_dtypes(self, device, op): + # Check complex32 support only if the op claims. + # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. + device_type = torch.device(device).type + include_complex32 = ( + (torch.complex32,) + if op.supports_dtype(torch.complex32, device_type) + else () + ) + + # dtypes to try to backward in + allowed_backward_dtypes = floating_and_complex_types_and( + *((torch.half, torch.bfloat16) + include_complex32) + ) + + # lists for (un)supported dtypes + supported_dtypes = set() + unsupported_dtypes = set() + supported_backward_dtypes = set() + unsupported_backward_dtypes = set() + + def unsupported(dtype): + unsupported_dtypes.add(dtype) + if dtype in allowed_backward_dtypes: + unsupported_backward_dtypes.add(dtype) + + for dtype in all_types_and_complex_and( + *((torch.half, torch.bfloat16, torch.bool) + include_complex32) + ): + # tries to acquire samples - failure indicates lack of support + requires_grad = dtype in allowed_backward_dtypes + try: + samples = tuple( + op.sample_inputs(device, dtype, requires_grad=requires_grad) + ) + except Exception as e: + unsupported(dtype) + continue + + for sample in samples: + # tries to call operator with the sample - failure indicates + # lack of support + try: + result = op(sample.input, *sample.args, **sample.kwargs) + supported_dtypes.add(dtype) + except Exception as e: + # NOTE: some ops will fail in forward if their inputs + # require grad but they don't support computing the gradient + # in that type! This is a bug in the op! + unsupported(dtype) + continue + + # Checks for backward support in the same dtype, if the input has + # one or more tensors requiring grad + def _tensor_requires_grad(x): + if isinstance(x, dict): + for k, v in x.items(): + if _tensor_requires_grad(v): + return True + if isinstance(x, (list, tuple)): + for a in x: + if _tensor_requires_grad(a): + return True + if isinstance(x, torch.Tensor) and x.requires_grad: + return True + + return False + + requires_grad = _tensor_requires_grad(sample.input) \ + or _tensor_requires_grad(sample.args) or _tensor_requires_grad(sample.kwargs) + if not requires_grad: + continue + + try: + result = sample.output_process_fn_grad(result) + if isinstance(result, torch.Tensor): + backward_tensor = result + elif isinstance(result, Sequence) and isinstance( + result[0], torch.Tensor + ): + backward_tensor = result[0] + else: + continue + + # Note: this grad may not have the same dtype as dtype + # For functions like complex (float -> complex) or abs + # (complex -> float) the grad tensor will have a + # different dtype than the input. + # For simplicity, this is still modeled as these ops + # supporting grad in the input dtype. + grad = torch.randn_like(backward_tensor) + backward_tensor.backward(grad) + supported_backward_dtypes.add(dtype) + except Exception as e: + unsupported_backward_dtypes.add(dtype) + + # Checks that dtypes are listed correctly and generates an informative + # error message + + supported_forward = supported_dtypes - unsupported_dtypes + partially_supported_forward = supported_dtypes & unsupported_dtypes + unsupported_forward = unsupported_dtypes - supported_dtypes + supported_backward = supported_backward_dtypes - unsupported_backward_dtypes + partially_supported_backward = ( + supported_backward_dtypes & unsupported_backward_dtypes + ) + unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes + + device_type = torch.device(device).type + + claimed_forward = set(op.supported_dtypes(device_type)) + supported_but_unclaimed_forward = supported_forward - claimed_forward + claimed_but_unsupported_forward = claimed_forward & unsupported_forward + + claimed_backward = set(op.supported_backward_dtypes(device_type)) + supported_but_unclaimed_backward = supported_backward - claimed_backward + claimed_but_unsupported_backward = claimed_backward & unsupported_backward + + # Partially supporting a dtype is not an error, but we print a warning + if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: + msg = "Some dtypes for {0} on device type {1} are only partially supported!\n".format( + op.name, device_type + ) + if len(partially_supported_forward) > 0: + msg = ( + msg + + "The following dtypes only worked on some samples during forward: {0}.\n".format( + partially_supported_forward + ) + ) + if len(partially_supported_backward) > 0: + msg = ( + msg + + "The following dtypes only worked on some samples during backward: {0}.\n".format( + partially_supported_backward + ) + ) + print(msg) + + if ( + len(supported_but_unclaimed_forward) + + len(claimed_but_unsupported_forward) + + len(supported_but_unclaimed_backward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + + # Reference operators often support additional dtypes, and that's OK + if op in python_ref_db: + if ( + len(claimed_but_unsupported_forward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + + # Generates error msg + msg = "The supported dtypes for {0} on device type {1} are incorrect!\n".format( + op.name, device_type + ) + if len(supported_but_unclaimed_forward) > 0: + msg = ( + msg + + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format( + supported_but_unclaimed_forward + ) + ) + if len(supported_but_unclaimed_backward) > 0: + msg = ( + msg + + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format( + supported_but_unclaimed_backward + ) + ) + if len(claimed_but_unsupported_forward) > 0: + msg = ( + msg + + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format( + claimed_but_unsupported_forward + ) + ) + if len(claimed_but_unsupported_backward) > 0: + msg = ( + msg + + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format( + claimed_but_unsupported_backward + ) + ) + + self.fail(msg) + class TestCompositeCompliance(TestCase): # Checks if the operator (if it is composite) is written to support most @@ -1203,8 +1244,8 @@ def test_operator(self, device, dtype, op): for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs - composite_compliance.check_with_mode(op, args, kwargs) - composite_compliance.check_all_permutations(op, args, kwargs) + composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual) + composite_compliance.check_all_permutations(op, args, kwargs, self.assertEqual) @unittest.skipIf( IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" @@ -1216,7 +1257,12 @@ def test_backward(self, device, dtype, op): for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs - composite_compliance.check_backward_formula(op, args, kwargs, sample.output_process_fn_grad) + # We pass assertEqual so that decorators like `toleranceOverride` + # actually work (otherwise they silently do nothing!) + composite_compliance.check_backward_formula( + op.get_op(), args, kwargs, + sample.output_process_fn_grad, + op.gradcheck_wrapper, self.assertEqual) @unittest.skipIf( IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" @@ -1234,9 +1280,13 @@ def test_forward_ad(self, device, dtype, op): for sample in samples: args = [sample.input] + list(sample.args) kwargs = sample.kwargs - composite_compliance.check_forward_ad_formula(op, args, kwargs) + # We pass assertEqual so that decorators like `toleranceOverride` + # actually work (otherwise they silently do nothing!) + composite_compliance.check_forward_ad_formula( + op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual) +@skipIfSlowGradcheckEnv class TestMathBits(TestCase): # Tests that # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors @@ -1425,7 +1475,7 @@ def is_bit_set(x): ) # input strides and size may have been altered due to the result of an inplace op -def test_inplace_view(func, input, rs, input_size, input_strides): +def check_inplace_view(func, input, rs, input_size, input_strides): if func is None: return # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out @@ -1446,18 +1496,20 @@ def test_inplace_view(func, input, rs, input_size, input_strides): # A mode that when enabled runs correctness checks to ensure # that operators have expected tags based on their input and # ouput tensor properties +@skipIfSlowGradcheckEnv class TestTagsMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): if isinstance(args[0], torch.Tensor): old_size = args[0].size() old_stride = args[0].stride() rs = func(*args, **kwargs) - test_inplace_view(func, args[0], rs, old_size, old_stride) + check_inplace_view(func, args[0], rs, old_size, old_stride) else: rs = func(*args, **kwargs) return rs # Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags` +@skipIfSlowGradcheckEnv class TestTags(TestCase): @onlyCPU @ops(ops_and_refs, dtypes=OpDTypes.any_one) @@ -1469,20 +1521,21 @@ def test_tags(self, device, dtype, op): if isinstance(input, torch.Tensor): old_size = input.size() old_stride = input.stride() - with push_torch_dispatch_mode(TestTagsMode): + with TestTagsMode(): rs = op(input, *sample.args, **sample.kwargs) # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 aten_name = op.aten_name if op.aten_name is not None else op.name opoverloadpacket = getattr(torch.ops.aten, aten_name, None) - test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) + check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) +@skipIfSlowGradcheckEnv class TestRefsOpsInfo(TestCase): - import_paths = ["_refs", "_refs.special", "_refs.nn.functional"] + import_paths = ["_refs", "_refs.special", "_refs.nn.functional", "_refs.fft"] module_alls = [(path, import_module(f"torch.{path}").__all__) for path in import_paths] - ref_ops_names = itertools.chain.from_iterable( - [f"{path}.{op}" for op in module_all] for path, module_all in module_alls) + ref_ops_names = tuple(itertools.chain.from_iterable( + [f"{path}.{op}" for op in module_all] for path, module_all in module_alls)) ref_db_names = set(ref_op.name for ref_op in python_ref_db) # TODO: References that do not have an entry in python_ref_db @@ -1499,18 +1552,101 @@ class TestRefsOpsInfo(TestCase): '_refs.std_var', '_refs.swap_axes', '_refs.uniform', + '_refs.scalar_tensor', + '_refs.trunc_divide', '_refs.zeros', '_refs.zeros_like' } + not_in_decomp_table = { + # duplicated in _decomp and _refs + '_refs.nn.functional.elu', + '_refs.nn.functional.mse_loss', + '_refs.masked_fill', + '_refs.transpose', + '_refs.var', + '_refs.rsub', + # these are not aten ops? + '_refs.broadcast_shapes', + '_refs.broadcast_tensors', + '_refs.nn.functional.tanhshrink', + '_refs.swap_axes', + # CompositeImplicitAutograd + '_refs.allclose', + '_refs.atleast_1d', + '_refs.atleast_2d', + '_refs.atleast_3d', + '_refs.broadcast_to', + '_refs.chunk', + '_refs.column_stack', + '_refs.contiguous', + '_refs.dsplit', + '_refs.dstack', + '_refs.fill', + '_refs.flatten', + '_refs.fliplr', + '_refs.flipud', + '_refs.float_power', + '_refs.hsplit', + '_refs.hstack', + '_refs.isclose', + '_refs.isfinite', + '_refs.narrow', + '_refs.positive', + '_refs.ravel', + '_refs.reshape', + '_refs.square', + '_refs.tensor_split', + '_refs.true_divide', + '_refs.trunc_divide', + '_refs.vsplit', + '_refs.vstack', + '_refs.linalg.matrix_norm', + '_refs.linalg.norm', + '_refs.linalg.svd', + '_refs.linalg.svdvals', + # ref implementation missing kwargs + '_refs.empty', # missing "pin_memory" + '_refs.empty_like', # missing "layout" + '_refs.full', # missing "layout" + '_refs.full_like', # missing "layout" + '_refs.ones', # missing "layout" + '_refs.ones_like', # missing "layout" + '_refs.round', # missing "decimals" + '_refs.scalar_tensor', # missing "layout" + '_refs.zeros', # missing "layout" + '_refs.zeros_like', # missing "layout" + # other + '_refs.as_strided', # _prims._as_strided_meta: "reduce() of empty sequence with no initial value" + '_refs.copy_to', # torch._C._jit_get_operation: No such operator aten::copy_to + '_refs.clone', # test_meta.py: view size is not compatible with input tensor's size and stride + '_refs.equal', # 'bool' object has no attribute 'dtype' + '_refs.conj', # Calls _prims.conj + } + @parametrize("op", ref_ops_names) def test_refs_are_in_python_ref_db(self, op): if op in self.skip_ref_ops: raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db") self.assertIn(op, self.ref_db_names) + @parametrize("op", ref_ops_names) + def test_refs_are_in_decomp_table(self, op): + path = op.split('.') + module_path = '.'.join(path[:-1]) + op_name = path[-1] + op_impl = getattr(import_module(f"torch.{module_path}"), op_name) + + if op in self.not_in_decomp_table: + self.assertFalse(op_impl in torch._decomp.decomposition_table.values(), + f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()") + else: + self.assertTrue(op_impl in torch._decomp.decomposition_table.values(), + f"Did not find {op} in torch._decomp.decomposition_table.values()") + fake_skips = ( + "aminmax", # failing input "cholesky", # Could not run 'aten::cholesky' with arguments from the 'Meta' backend "cholesky_inverse", # Could not run 'aten::cholesky' with arguments from the 'Meta' backend "cov", # aweights cannot be negtaive @@ -1544,6 +1680,14 @@ def test_refs_are_in_python_ref_db(self, op): "nn.functional.one_hot", ) +fake_autocast_device_skips = defaultdict(dict) + +# TODO: investigate/fix +fake_autocast_device_skips["cpu"] = set( + ("linalg.pinv",) +) + + dynamic_output_op_tests = ( "argwhere", "bincount", @@ -1562,15 +1706,46 @@ def test_refs_are_in_python_ref_db(self, op): "index_select", ) +aliasing_failures = ( + "histogramdd", + "nn.functional.pixel_shuffle", + "nn.functional.pixel_unshuffle", +) + +fake_striding_skips = ( + "fft.fft2", + "fft.fft", + "fft.fftn", + "fft.hfft2", + "fft.hfft", + "fft.hfftn", + "fft.ifft2", + "fft.ifft", + "fft.ifftn", + "fft.ihfft2", + "fft.ihfft", + "fft.ihfftn", + "fft.irfft2", + "fft.irfft", + "fft.irfftn", + "fft.rfft2", + "fft.rfft", + "fft.rfftn", + "svd", + "linalg.svd", + "nn.functional.conv_transpose2d", +) + + +@skipIfSlowGradcheckEnv class TestFakeTensorNonErroring(TestCase): - @onlyCPU - @ops(op_db, dtypes=OpDTypes.any_one) - def test_fake(self, device, dtype, op): + def _test_fake_helper(self, device, dtype, op, context): name = op.name if op.variant_test_name: name += "." + op.variant_test_name - if name in fake_skips or "sparse" in name: + if name in fake_skips or "sparse" in name or "jiterator" in name: self.skipTest("Skip failing test") + samples = op.sample_inputs(device, dtype, requires_grad=False) for sample in samples: try: @@ -1586,10 +1761,25 @@ def map_to_fake(e): args = tree_map(map_to_fake, sample.args) kwargs = tree_map(map_to_fake, sample.kwargs) - with enable_torch_dispatch_mode(mode): - res_fake = op(input, *args, **kwargs) + try: + with context(): + res = op(sample.input, *sample.args, **sample.kwargs) + except Exception as e: + continue - res = op(sample.input, *sample.args, **sample.kwargs) + with context(): + with enable_torch_dispatch_mode(mode): + res_fake = op(input, *args, **kwargs) + + def outputs_alias_inputs(outputs, inputs): + input_storages = set() + for out in tree_flatten(outputs)[0]: + if isinstance(out, torch.Tensor): + input_storages.add(out.storage()._cdata) + for inp in tree_flatten(inputs)[0]: + if isinstance(inp, torch.Tensor) and inp.storage()._cdata in input_storages: + return True + return False for fake_out, real_out in zip( tree_flatten(res_fake)[0], tree_flatten(res)[0] @@ -1601,7 +1791,20 @@ def map_to_fake(e): self.assertTrue(isinstance(fake_out, FakeTensor)) # if you see a shape exception here, you may need to add # a `dynamic_output_shape` tag to an operator - prims.utils.compare_tensor_meta(fake_out, real_out) + + check_strides = name not in fake_striding_skips + + # if there is a striding failure here as a result of adding a primtorch ref, + # feel free to add the op to `fake_striding_skips` but please tag + # @eellison on the pr. + # see: https://github.com/pytorch/pytorch/issues/78050 + prims.utils.compare_tensor_meta(fake_out, real_out, check_strides) + + if name not in aliasing_failures: + fake_aliasing = outputs_alias_inputs((input, args, kwargs), res_fake) + real_aliasing = outputs_alias_inputs((sample.input, sample, args, sample.kwargs), res) + self.assertEqual(fake_aliasing, real_aliasing) + self.assertTrue(name not in dynamic_output_op_tests) except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: @@ -1609,6 +1812,17 @@ def map_to_fake(e): except torch._subclasses.fake_tensor.DynamicOutputShapeException: self.assertTrue(name in dynamic_output_op_tests or name in sometimes_dynamic_output_op_test) + @ops(op_db, dtypes=OpDTypes.any_one) + def test_fake(self, device, dtype, op): + self._test_fake_helper(device, dtype, op, contextlib.nullcontext) + + @ops(op_db, dtypes=OpDTypes.any_one) + def test_fake_autocast(self, device, dtype, op): + if op.name in fake_autocast_device_skips[device]: + self.skipTest("Skip failing test") + context = torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast + self._test_fake_helper(device, dtype, op, context) + instantiate_device_type_tests(TestCommon, globals()) instantiate_device_type_tests(TestCompositeCompliance, globals()) diff --git a/test/test_ops_gradients.py b/test/test_ops_gradients.py index 64cd71fdee6d6..bed924c2aec0a 100644 --- a/test/test_ops_gradients.py +++ b/test/test_ops_gradients.py @@ -5,7 +5,7 @@ import torch from torch.testing._internal.common_utils import \ - (TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck) + (TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) @@ -44,7 +44,9 @@ def is_inplace(variant): return variant is op.get_inplace() include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex - samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs) + + samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs, + small_inputs_only=is_slow_gradcheck_env()) for sample in samples: if sample.broadcasts_input and is_inplace(variant): diff --git a/test/test_ops_jit.py b/test/test_ops_jit.py index 0d44f59f37eb6..21fccf6622953 100644 --- a/test/test_ops_jit.py +++ b/test/test_ops_jit.py @@ -6,7 +6,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_utils import \ - (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample) + (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, skipIfSlowGradcheckEnv) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference @@ -29,6 +29,7 @@ # autodifferentiation behavior. # Inherits from JitCommonTestCase instead of TestCase directly to share # functionality with original test_jit.py method operator tests +@skipIfSlowGradcheckEnv class TestJit(JitCommonTestCase): exact_dtype = True diff --git a/test/test_optim.py b/test/test_optim.py index 58795b9299752..7ea094d399367 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -21,7 +21,7 @@ EPOCH_DEPRECATION_WARNING from torch.optim.swa_utils import AveragedModel, SWALR, update_bn from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ - parametrize, instantiate_parametrized_tests + parametrize, instantiate_parametrized_tests, gradcheck # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -41,7 +41,7 @@ class TestOptim(TestCase): exact_dtype = True def _test_rosenbrock_sparse(self, constructor, scheduler_constructors=None, - sparse_only=False): + sparse_only=False, maximize=False): if scheduler_constructors is None: scheduler_constructors = [] params_t = torch.tensor([1.5, 1.5]) @@ -96,7 +96,10 @@ def eval(params, sparse_grad, w): optimizer_c.step(functools.partial(eval, params_c, False, w)) self.assertEqual(params.data, params_c.data) - self.assertLessEqual(params.data.dist(solution), initial_dist) + if not maximize: + self.assertLessEqual(params.data.dist(solution), initial_dist) + else: + self.assertGreaterEqual(rosenbrock(params.data), rosenbrock(params_t)) def _test_basic_cases_template(self, weight, bias, input, constructor, scheduler_constructors, constructor_accepts_maximize=True): @@ -317,6 +320,32 @@ def _test_complex_optimizer(self, optimizer_constructor): self.assertEqual(torch.view_as_real(complex_param), real_param) + def _test_complex_2d(self, optimizer_constructor, f=None): + if f is None: + f = rosenbrock + a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True) + a1_real = a1.real.clone().detach() + a1_imag = a1.imag.clone().detach() + a1_real.requires_grad_() + a1_imag.requires_grad_() + optim1 = optimizer_constructor([a1]) + optim2 = optimizer_constructor([a1_real, a1_imag]) + + for i in range(10): + optim1.zero_grad() + optim2.zero_grad() + a2 = torch.complex(a1_real, a1_imag) + f(a1).backward() + f(a2).backward() + + self.assertEqual(a1.grad.real, a1_real.grad) + self.assertEqual(a1.grad.imag, a1_imag.grad) + + optim1.step() + optim2.step() + self.assertEqual(a1.real, a1_real) + self.assertEqual(a1.imag, a1_imag) + def _build_params_dict(self, weight, bias, **kwargs): return [{'params': [weight]}, dict(params=[bias], **kwargs)] @@ -563,27 +592,14 @@ def test_adam(self): lambda opt: ReduceLROnPlateau(opt)], constructor_accepts_maximize=True ) + self._test_complex_2d(optimizer) + with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): optimizer(None, lr=1e-2, betas=(1.0, 0.0)) with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): optimizer(None, lr=1e-2, weight_decay=-1) - # Test whether variance parameter is always real - def test_complex_adam_variance(self): - complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True) - target = torch.randn(5, 5, dtype=torch.complex64) - optimizer = optim.Adam([complex_param], lr=0.001) - - for i in range(20): - optimizer.zero_grad() - loss = (complex_param - target).pow(2).sum() - loss.backward() - optimizer.step() - for idx in optimizer.state_dict()['state'].keys(): - variance = optimizer.state_dict()['state'][idx]['exp_avg_sq'] - self.assertEqual(variance.imag, torch.zeros(variance.imag.shape)) - def test_adamw(self): for optimizer in [optim.AdamW, optim_mt.AdamW]: self._test_basic_cases( @@ -612,6 +628,12 @@ def test_sparse_adam(self): [], True ) + self._test_rosenbrock_sparse( + lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True), + [], + True, + True + ) with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0)) with self.assertRaisesRegex(ValueError, "SparseAdam requires dense parameter tensors"): @@ -792,32 +814,38 @@ def test_radam(self): def test_rmsprop(self): for optimizer in [optim.RMSprop, optim_mt.RMSprop]: self._test_basic_cases( - lambda weight, bias: optimizer([weight, bias], lr=1e-2) + lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-2, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2) + lr=1e-2, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, centered=True) + lr=1e-2, centered=True, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, centered=True, momentum=0.1) + lr=1e-2, centered=True, momentum=0.1, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, momentum=0.1) + lr=1e-2, momentum=0.1, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, momentum=0.1, weight_decay=1) + lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize), + constructor_accepts_maximize=True ) with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): optimizer(None, lr=1e-2, momentum=-1.0) @@ -825,17 +853,20 @@ def test_rmsprop(self): def test_asgd(self): for optimizer in [optim.ASGD, optim_mt.ASGD]: self._test_basic_cases( - lambda weight, bias: optimizer([weight, bias], lr=1e-3, t0=100) + lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, t0=100, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( + lambda weight, bias, maximize: optimizer( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, t0=100) + lr=1e-3, t0=100, maximize=maximize), + constructor_accepts_maximize=True ) self._test_basic_cases( - lambda weight, bias: optimizer( - self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, weight_decay=1) + lambda weight, bias, maximize: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, weight_decay=1, maximize=maximize), + constructor_accepts_maximize=True ) with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): optimizer(None, lr=1e-2, weight_decay=-0.5) @@ -2665,5 +2696,30 @@ def test_bn_update_eval_momentum(self): instantiate_parametrized_tests(TestLRScheduler) + +def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored): + # Ignored is the list of values in `opt_differentiable_state`, we do this + # for `gradcheck` to correctly track the state tensors as function inputs + # because otherwise it can't unpack the values in the `opt_differentiable_state` + # dict + p = p.clone() + p.grad = grad + opt_differentiable_state = {k: v.clone() for k, v in opt_differentiable_state.items()} + opt = opt_class([p], **kwargs) + opt.state.update(opt_differentiable_state) + opt.step() + return (p,) + tuple(opt_differentiable_state.values()) + + +class TestDifferentiableOptimizer(TestCase): + + def test_sgd(self): + p = torch.rand(10, requires_grad=True, dtype=torch.float64) + grad = torch.rand(10, requires_grad=True, dtype=torch.float64) + mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64) + state = {'momentum_buffer': mbuff} + gradcheck(_diff_fn, (p, grad, state, torch.optim.SGD, {'lr': 0.9, 'differentiable': True}, *state.values())) + + if __name__ == '__main__': run_tests() diff --git a/test/test_overrides.py b/test/test_overrides.py index 6e992b4ab38f4..dae399732a5e8 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -19,7 +19,6 @@ TorchFunctionMode ) from torch.utils._mode_utils import find_outermost_mode, all_same_mode, all_same_mode_scope -from functools import partial Tensor = torch.Tensor @@ -338,7 +337,7 @@ def generate_tensor_like_torch_implementations(): msg = ( "The following functions are not tested for __torch_function__ " "support, please ensure there is an entry in the dict returned by " - "torch._overrides.get_testing_overrides for this function or if a " + "torch.overrides.get_testing_overrides for this function or if a " "__torch_function__ override does not make sense, add an entry to " "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}" ) @@ -649,7 +648,11 @@ def instance_gen(): func_args.append(3.5) elif t == 'bool': func_args.append(False) - elif t.startswith('int') or t in {'Dimname', 'DimnameList'}: + elif t == 'Dimname': + func_args.append("") + elif t == 'DimnameList': + func_args.append([""]) + elif t.startswith('int'): func_args.append(0) elif t in {'Stream'}: func_args.append(torch.Stream()) @@ -1127,7 +1130,7 @@ def __torch_function__(self, *args, **kwargs): return -1 # NB: factory functions get overridden too! x = torch.randn(1) - with torch.overrides.push_torch_function_mode(A): + with A(): self.assertEqual(torch.randn(3), -1) self.assertEqual(torch.add(x, x), -1) self.assertEqual(torch.split(None, [2]), -1) # python side @@ -1138,7 +1141,7 @@ class A(TorchFunctionMode): def __torch_function__(self, *args, **kwargs): return -1 - with torch.overrides.push_torch_function_mode(A): + with A(): self.assertEqual(torch.tensor([1]), -1) self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1) self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1) @@ -1157,7 +1160,7 @@ def __torch_function__(self, *args, **kwargs): return -40 x = SubTensor() - with torch.overrides.push_torch_function_mode(A): + with A(): self.assertEqual(torch.neg(x), -40) self.assertEqual(torch.mean(x), -40) self.assertEqual(torch.mm(x, x), -40) @@ -1169,7 +1172,7 @@ def __torch_function__(self, *args, **kwargs): return NotImplemented x = SubTensor() - with torch.overrides.push_torch_function_mode(MyMode): + with MyMode(): self.assertEqual(torch.mean(x), 0) self.assertEqual(torch.mm(x, x), -1) self.assertEqual(bar(x), 1) @@ -1177,38 +1180,6 @@ def __torch_function__(self, *args, **kwargs): TypeError, r'SubTensor.+MyMode', lambda: self.assertEqual(torch.max(x, x))) - def test_mode_stack(self): - logs = [] - - class Logger(TorchFunctionMode): - def __init__(self, name): - self.name = name - - def __torch_function__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - logs.append(self.name) - return func(*args, **kwargs) - - x = torch.randn(1) - with torch.overrides.push_torch_function_mode(partial(Logger, "A")): - with torch.overrides.push_torch_function_mode(partial(Logger, "B")): - torch.mean(x) - - self.assertEqual(logs, ["B", "A"]) - - def test_push_mode_instance_errors(self): - class A(TorchFunctionMode): - pass - with self.assertRaisesRegex(ValueError, 'instance of TorchFunctionMode'): - with torch.overrides.push_torch_function_mode(A()): - pass - - def test_push_mode_returns_unrelated(self): - with self.assertRaisesRegex(ValueError, 'return a TorchFunctionMode'): - with torch.overrides.push_torch_function_mode(lambda *, inner: None): - pass - def test_enable_torch_function_mode_trivial(self): class A(TorchFunctionMode): def __torch_function__(self, *args, **kwargs): @@ -1323,8 +1294,7 @@ def test_error_with_ancestor(self): class A(TorchFunctionMode): pass - x = A() - with x: + with A() as x: pass with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"): @@ -1454,7 +1424,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): x = torch.randn(1) y = torch.randn(1) - with torch.overrides.push_torch_function_mode(A): + with A(): torch.sub(x, y) # add hits the torch function again! self.assertEqual(log, [torch.sub, torch.add]) @@ -1473,7 +1443,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): called = True return func(*args, **kwargs) - with torch.overrides.push_torch_function_mode(A): + with A(): torch._C._nn._parse_to('cpu') self.assertTrue(called) @@ -1494,7 +1464,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): called = True return func(*args, **kwargs) - with torch.overrides.push_torch_function_mode(A): + with A(): torch.distributions.Bernoulli(0.3) self.assertTrue(called) @@ -1527,7 +1497,7 @@ class B(torch.Tensor): b = B() - with torch.overrides.push_torch_function_mode(A): + with A(): r = torch.neg(b) self.assertIs(type(r), B) @@ -1535,7 +1505,7 @@ class B(torch.Tensor): called = 0 - with torch.overrides.push_torch_function_mode(A): + with A(): r = bar(b) self.assertIs(type(r), B) @@ -1556,7 +1526,7 @@ class B(torch.Tensor): pass x = B(torch.randn(5)) - with torch.overrides.push_torch_function_mode(A): + with A(): with torch._C.DisableTorchFunction(): self.assertNotIsInstance(torch.sum(x), B) diff --git a/test/test_prims.py b/test/test_prims.py index 8afeee157553d..faae19b7da6a9 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -1,10 +1,13 @@ # Owner(s): ["module: primTorch"] from functools import partial +from itertools import product +from warnings import catch_warnings +import unittest import torch from torch.testing import make_tensor -from torch.testing._internal.common_utils import parametrize, run_tests, TestCase +from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, @@ -14,6 +17,11 @@ from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input import torch._prims as prims from torch._prims.executor import make_traced +import torch._refs as refs + + +if TEST_SCIPY: + import scipy.special class TestPrims(TestCase): @@ -106,6 +114,30 @@ def _wrapper(a): self.assertTrue(result.is_contiguous) self.assertEqual(_wrapper(a), result) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @dtypes(torch.float64, torch.long) + def test_cbrt_prim(self, device, dtype): + make_arg = partial(make_tensor, device=device, dtype=dtype) + batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)] + shapes = [(), (0,), (1,), (5,)] + + try: + # Sets the default dtype to NumPy's default dtype of double + cur_default = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + + # Tested here, as this OP is not currently exposed or tested in ATen + for b, s in product(batches, shapes): + x = make_arg(b + s) + y = prims.cbrt(x) + + x_np = x.cpu().numpy() + y_np = scipy.special.cbrt(x_np) + + self.assertEqual(y, y_np, exact_device=False) + finally: + torch.set_default_dtype(cur_default) + @onlyCUDA @skipCUDAIfRocm def test_nvfuser_impl_is_used(self, device): @@ -114,7 +146,7 @@ def test_nvfuser_impl_is_used(self, device): # This test is not intended to test the correctness of the nvfuser implementation from torch._C._nvfuser import FusionDefinition as fd - prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.Ops)) + prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops)) ops_without_nvfuser_impl = { name for name in prim_nvfuser_ops @@ -125,6 +157,136 @@ def test_nvfuser_impl_is_used(self, device): ), (f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl} ", "while there exists nvfuser implementations for them.") + @onlyCUDA + @skipCUDAIfRocm + def test_nvfuser_executor_cached_noncontiguous(self, device): + # This test is to ensure that nvfuser computes correct results for noncontiguous tensors + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsMode + from torch._prims.executor import execute + + a = torch.randn(3, 3, device=device) + + def func(a): + return torch.sigmoid(a) + + with TorchRefsMode(): + gm = make_fx(func)(a) + + # First run to create the cache + execute(gm, a, executor="nvfuser") + + # a.mT is noncontiguous, but it shouldn't affect correctness + expected = execute(gm, a.mT, executor="aten") + actual = execute(gm, a.mT, executor="nvfuser") + self.assertEqual(expected, actual) + + def test_nvfuser_capability_context(self, device): + # This test is to ensure that the torch calls are replaced with refs + # based on the nvfuser+prims capability + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsNvfuserCapabilityMode + + # It's assumed that digamma is not supported by nvfuser + # If it's ever supported, this test will need to be updated + self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None) + + a = torch.randn(3, 3, device=device) + + def func(a): + return torch.digamma(a) + + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(func)(a) + + # Check that the torch.digamma is not replaced with torch.ops.prims.digamma + call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes) + includes_aten_digamma = any( + torch.ops.aten.digamma.default == node.target + for node in call_function_nodes + ) + includes_prims_digamma = any( + torch.ops.prims.digamma.default == node.target + for node in call_function_nodes + ) + self.assertTrue(includes_aten_digamma) + self.assertFalse(includes_prims_digamma) + + # Check mixed case, sigmoid is replaced with refs, but digamma is not + def func(a): + return torch.sigmoid(torch.digamma(a)) + + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(func)(a) + + call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes) + includes_aten_sigmoid = any( + torch.ops.aten.sigmoid.default == node.target + for node in call_function_nodes + ) + includes_prims_digamma = any( + torch.ops.prims.digamma.default == node.target + for node in call_function_nodes + ) + self.assertFalse(includes_aten_sigmoid) + self.assertFalse(includes_prims_digamma) + + @onlyCUDA + @skipCUDAIfRocm + def test_nvfuser_executor_partitioned(self, device): + # This test is to ensure that nvfuser partitioned executor works correctly + # It's assumed that digamma is not supported by nvfuser + # If it's ever supported, this test will need to be updated + self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None) + + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsMode + from torch._prims.executor import execute + + a = torch.randn(3, 4, device=device) + b = torch.rand(3, 1, device=device) + c = torch.rand(3, 4, device=device) + + def func(a, b, c): + aa = torch.digamma(a) # not supported by nvfuser + d = torch.add(b, c) + dd = torch.sqrt(d) + return torch.mul(aa, dd.digamma()) + + with TorchRefsMode(): + gm = make_fx(func)(a, b, c) + + expected = execute(gm, a, b, c, executor="aten") + actual = execute(gm, a, b, c, executor="nvfuser") + self.assertEqual(expected, actual) + + @onlyCUDA + @skipCUDAIfRocm + def test_nvfuser_executor_partitioned_no_partitions_error(self, device): + # This test is to ensure that nvfuser partitioned executor works correctly + # It's assumed that digamma is not supported by nvfuser + # If it's ever supported, this test will need to be updated + self.assertTrue(torch.ops.prims.digamma.default.impl_nvfuser is None) + + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsMode + from torch._prims.executor import execute + + a = torch.randn(3, 4, device=device) + + def func(a): + return torch.digamma(a) # not supported by nvfuser + + with TorchRefsMode(): + gm = make_fx(func)(a) + + with catch_warnings(record=True) as w: + # Trigger warning + execute(gm, a, executor="nvfuser") + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("is not supported by nvFuser" in str(w[-1].message)) + @onlyCUDA @skipCUDAIfRocm @dtypes(torch.float32) @@ -149,9 +311,10 @@ def _wrapper(a): @onlyCUDA @skipCUDAIfRocm @dtypes(torch.float32) - def test_pytree_output(self, device, dtype): + def test_pytree_input_output(self, device, dtype): @make_traced - def fn(a, b): + def fn(a, b_dict): + b = b_dict["b"] d = {} d["c"] = torch.add(a, b) return (d, torch.add(a, d["c"])) @@ -159,11 +322,75 @@ def fn(a, b): make_arg = partial(make_tensor, device=device, dtype=dtype) a = make_arg((5, 5)) b = make_arg((1, 5)) + b_dict = {"b": b} - result_aten = fn(a, b, executor="aten") - result_nvfuser = fn(a, b, executor="nvfuser") + result_aten = fn(a, b_dict, executor="aten") + result_nvfuser = fn(a, b_dict, executor="nvfuser") self.assertEqual(result_aten, result_nvfuser) + @dtypes(torch.float32) + def test_memory_format_strides(self, device, dtype): + shapes = ( + (), + (0,), + (1,), + (5), + (1, 0), + (1, 1), + (3, 7), + (3, 0, 2), + (1, 1, 2), + (4, 1, 1), + (7, 8, 9), + ) + + channels_last_shapes = ( + (0, 0, 0, 0), + (1, 0, 3, 0), + (0, 2, 3, 5), + (2, 2, 2, 0), + (5, 4, 3, 2), + (8, 8, 7, 2), + (9, 1, 3, 1), + (4, 5, 8, 7) + ) + + channels_last_3d_shapes = ( + (0, 8, 7, 9, 2), + (5, 0, 7, 9, 2), + (5, 0, 7, 9, 0), + (5, 8, 7, 9, 2), + (5, 1, 7, 9, 2), + (5, 1, 7, 9, 1), + ) + + pairs = ( + (shapes, torch.contiguous_format), + (channels_last_shapes, torch.contiguous_format), + (channels_last_3d_shapes, torch.contiguous_format), + (channels_last_shapes, torch.channels_last), + (channels_last_3d_shapes, torch.channels_last_3d), + ) + + for shapes, memory_format in pairs: + for shape in shapes: + # tests empty + expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format) + actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format) + self.assertEqual(expected.stride(), actual.stride()) + + # tests clone + a = torch.testing.make_tensor(shape, device=device, dtype=dtype) + expected = torch.clone(a, memory_format=memory_format) + actual = torch.clone(a, memory_format=memory_format) + self.assertEqual(expected.stride(), actual.stride()) + + # tests contiguous + a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True) + expected = a.contiguous(memory_format=memory_format) + actual = refs.contiguous(a, memory_format=memory_format) + self.assertEqual(expected.stride(), actual.stride()) + class TestPrimsBasic(TestCase): def test_torch_ops(self): @@ -184,5 +411,45 @@ def test_mul_complex(self): instantiate_device_type_tests(TestPrims, globals()) + +class TestRefs(TestCase): + @dtypes(torch.float32) + def test_constant_pad_nd_memory_format(self, device, dtype): + # Test memory format is preserved in unambiguous cases + for mf, ndim in ( + (torch.channels_last, 4), + (torch.contiguous_format, 4), + (torch.channels_last_3d, 5), + (torch.contiguous_format, 5), + ): + a = torch.zeros([2] * ndim).to(memory_format=mf) + res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim)) + self.assertTrue(res.is_contiguous(memory_format=mf)) + + # Ambiguous cases + + # is_channels_last_ and is_contiguous_, results in channels_last output + a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1)) + self.assertTrue(a.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(a.is_contiguous()) + actual = refs.constant_pad_nd(a, pad=[1] * 8) + expect = torch.constant_pad_nd(a, pad=[1] * 8) + self.assertEqual(actual.stride(), expect.stride()) + self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last)) + + # is_channels_last_contiguous_ but not is_channels_last_, results in + # contiguous output + a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1)) + self.assertTrue(a.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(a.is_contiguous()) + actual = refs.constant_pad_nd(a, pad=[1] * 8) + expect = torch.constant_pad_nd(a, pad=[1] * 8) + self.assertEqual(actual.stride(), expect.stride()) + self.assertTrue(actual.is_contiguous()) + + +instantiate_device_type_tests(TestRefs, globals()) + + if __name__ == "__main__": run_tests() diff --git a/test/test_profiler.py b/test/test_profiler.py index e093a50178aaf..a5f60837e898b 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -1,13 +1,15 @@ # Owner(s): ["oncall: profiler"] import collections +import expecttest import gc import io import json import os import re import tempfile +from typing import List, Optional import unittest -import time +from dataclasses import dataclass, field import torch import torch.nn as nn @@ -26,6 +28,14 @@ DeviceType, ProfilerAction, ProfilerActivity, ExecutionGraphObserver, _utils ) +from torch.profiler._pattern_matcher import (Pattern, NamePattern, + ExtraCUDACopyPattern, + ForLoopIndexingPattern, + FP32MatMulPattern, + OptimizerSingleTensorPattern, + SynchronizedDataLoaderPattern, + GradNotSetToNonePattern, + Conv2dBiasFollowedByBatchNorm2dPattern) from torch.testing._internal.common_device_type import skipCUDAVersionIn try: @@ -645,6 +655,46 @@ def create_mkldnn_tensor(): ] ) + def test_oom_tracing(self): + def run_profiler(tensor_creation_fn): + with _profile(profile_memory=True, record_shapes=True) as prof: + with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"): + x = tensor_creation_fn() + return prof + + def create_cuda_tensor_oom(): + device = torch.device("cuda:0") + return torch.empty(1024, 1024, 1024, 20, dtype=torch.float32, device=device) + + def check_trace(fname): + prof.export_chrome_trace(fname) + with io.open(fname, 'r') as f: + trace = json.load(f) + self.assertTrue("traceEvents" in trace) + events = trace["traceEvents"] + found_out_of_memory_events = False + for evt in events: + self.assertTrue("name" in evt) + if evt["name"] == "[OutOfMemory]": + found_out_of_memory_events = True + self.assertTrue("args" in evt) + self.assertTrue("Device Type" in evt["args"]) + self.assertTrue("Device Id" in evt["args"]) + self.assertTrue("Bytes" in evt["args"]) + + # Memory should be an instantaneous event. + self.assertTrue("dur" not in evt["args"]) + self.assertTrue("cat" not in evt["args"]) + self.assertTrue(found_out_of_memory_events) + + if torch.cuda.is_available(): + with TemporaryFileName(mode="w+") as fname: + prof = run_profiler(create_cuda_tensor_oom) + check_trace(fname) + + + + @unittest.skipIf(not kineto_available(), "Kineto is required") def test_module_hierarchy(self): class A(nn.Module): @@ -1110,6 +1160,241 @@ def test_profiler_correlation_id(self): id_uniqueness_set.add(corr_id) self.assertTrue(corr_id < uint32_max) +def find_node_with_name(nodes, name): + for node in nodes: + if node.name() == name: + return node + result = find_node_with_name(node.children, name) + if result is not None: + return result + +class TestTorchTidyProfiler(TestCase): + def test_extra_fields(self): + with profile(with_stack=True, profile_memory=True) as p: + _ = torch.ones((1,)) + + nodes = p.profiler.kineto_results.experimental_event_tree() + node = find_node_with_name(nodes, "aten::ones") + self.assertIsNotNone(node) + + self.assertIsInstance( + node.extra_fields, + torch._C._autograd._ExtraFields_TorchOp) + + self.assertIsInstance( + node.parent.extra_fields, + torch._C._autograd._ExtraFields_PyCCall) + + self.assertEqual(node.children[0].name(), "aten::empty") + self.assertEqual(node.children[0].children[0].name(), "[memory]") + self.assertIsInstance( + node.children[0].children[0].extra_fields, + torch._C._autograd._ExtraFields_Allocation) + + def test_tensor_properties(self): + x = torch.ones(10, 10).as_strided([4, 4], [12, 3]) + y = torch.ones(4, 1) + + with profile(with_stack=True, profile_memory=True, record_shapes=True) as p: + _ = x + y + + nodes = p.profiler.kineto_results.experimental_event_tree() + node = find_node_with_name(nodes, "aten::add") + self.assertIsNotNone(node) + + self.assertIsInstance( + node.extra_fields, + torch._C._autograd._ExtraFields_TorchOp) + + self.assertEqual(node.extra_fields.inputs.shapes, [[4, 4], [4, 1], []]) + + input_info = node.extra_fields.inputs + self.assertEqual(input_info.dtypes, ['float', 'float', 'Scalar']) + + layout_info = [x.layout if x else None for x in input_info.tensor_metadata] + self.assertEqual(layout_info, [torch.strided, torch.strided, None]) + + +@dataclass(frozen=True) +class MockKinetoEvent(): + _name: str + _start_us: int + _duration_us: int + _linked_correlation_id: int + _device_type: int + + def name(self) -> str: + return self._name + + def start_us(self) -> int: + return self._start_us + + def duration_us(self) -> int: + return self._duration_us + + def linked_correlation_id(self) -> int: + return self._linked_correlation_id + + def device_type(self) -> DeviceType: + return DeviceType.CUDA if self._device_type == 1 else DeviceType.CPU + + +@dataclass(frozen=True) +class MockProfilerEvent(): + _name: str + id: int + start_time_ns: int + duration_time_ns: int + correlation_id: int = 0 + children: List["MockProfilerEvent"] = field(default_factory=list) + parent: Optional["MockProfilerEvent"] = None + + @property + def end_time_ns(self): + return self.start_time_ns + self.duration_time_ns + + def name(self) -> str: + return self._name + + def __post__init__(self, parent, children): + object.__setattr__(self, "parent", parent) + object.__setattr__(self, "children", children) + + +class TestExperimentalUtils(TestCase): + + @staticmethod + def generate_mock_profile(): + cuda_events = [ + MockKinetoEvent("cudaLaunchKernel", 400, 100, 1, 0), + MockKinetoEvent("cudaLaunchKernel", 500, 100, 2, 0), + MockKinetoEvent("cudaLaunchKernel", 600, 100, 3, 0), + MockKinetoEvent("cudaLaunchKernel", 700, 100, 4, 0), + MockKinetoEvent("cudaLaunchKernel", 800, 100, 5, 0), + MockKinetoEvent("cudaLaunchKernel", 1500, 100, 6, 0), + MockKinetoEvent("GPU", 900, 100, 1, 1), + MockKinetoEvent("GPU", 1000, 100, 2, 1), + MockKinetoEvent("GPU", 1100, 100, 3, 1), + MockKinetoEvent("GPU", 1200, 100, 4, 1), + MockKinetoEvent("GPU", 1300, 100, 5, 1), + MockKinetoEvent("GPU", 1700, 100, 6, 1) + ] + cpu_events = [ + MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000), + MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100000, + 100000), + MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200000, + 100000), + MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300000, + 100000), + MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400000, + 100000), + MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500000, + 100000), + MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600000, + 100000), + MockProfilerEvent("CPU (After cudaLaunchKernel)", 8, 700000, + 100000), + MockProfilerEvent("CPU (After GPU)", 9, 800000, 100000), + MockProfilerEvent("CPU (After GPU)", 10, 900000, 100000), + MockProfilerEvent("CPU (After GPU)", 11, 1100000, 100000), + MockProfilerEvent("CPU (After GPU)", 12, 1200000, 500000), + ] + + profiler = unittest.mock.Mock() + profiler.kineto_results = unittest.mock.Mock() + profiler.kineto_results.events = unittest.mock.Mock( + return_value=cuda_events) + profiler.kineto_results.experimental_event_tree = unittest.mock.Mock( + return_value=cpu_events) + return profiler + + @staticmethod + def load_mock_profile(): + accept = expecttest.ACCEPT + json_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "profiler_utils_mock_events.json") + if accept and torch.cuda.is_available(): + + def garbage_code(x): + for i in range(5): + x[0, i] = i + + x = torch.ones((4096, 4096), device="cuda") + x = x @ x + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True) as prof: + for _ in range(5): + x = x @ x + garbage_code(x) + for _ in range(5): + x = x @ x + + kineto_events = [{ + '_name': + e.name(), + '_start_us': + e.start_us(), + '_duration_us': + e.duration_us(), + '_linked_correlation_id': + e.linked_correlation_id(), + '_device_type': + 1 if e.device_type() == DeviceType.CUDA else 0 + } for e in prof.profiler.kineto_results.events()] + + def EventTreeDFS(event_tree): + from collections import deque + stack = deque(event_tree) + while stack: + curr_event = stack.pop() + yield curr_event + for child_event in curr_event.children: + stack.append(child_event) + + profiler_events = [{ + '_name': e.name(), + 'id': e.id, + 'start_time_ns': e.start_time_ns, + 'duration_time_ns': e.duration_time_ns, + 'correlation_id': e.correlation_id, + 'children': [child.id for child in e.children], + 'parent': e.parent.id if e.parent else None + } for e in EventTreeDFS( + prof.profiler.kineto_results.experimental_event_tree())] + + with open(json_file_path, "w") as f: + json.dump([kineto_events, profiler_events], f) + + assert (os.path.exists(json_file_path)) + with open(json_file_path, "r") as f: + kineto_events, profiler_events = json.load(f) + + cuda_events = [ + MockKinetoEvent(*event.values()) for event in kineto_events + ] + cpu_events = [] + id_map = {} + for e in profiler_events: + event = MockProfilerEvent(**e) + id_map[event.id] = event + cpu_events.append(event) + for event in cpu_events: + parent = None if event.parent is None else id_map[event.parent] + children = [id_map[child] for child in event.children] + event.__post__init__(parent, children) + cpu_events = [event for event in cpu_events if event.parent is None] + profiler = unittest.mock.Mock() + profiler.kineto_results = unittest.mock.Mock() + profiler.kineto_results.events = unittest.mock.Mock( + return_value=cuda_events) + profiler.kineto_results.experimental_event_tree = unittest.mock.Mock( + return_value=cpu_events) + return profiler + def test_utils_compute_self_time(self): with profile() as prof: t1, t2 = torch.ones(1, requires_grad=True), torch.ones( @@ -1129,52 +1414,301 @@ def test_utils_compute_self_time(self): for child in event_key.event.children ])) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_utils_intervals_overlap(self): + event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5)) + intervals = [ + _utils.Interval(0, 9), + _utils.Interval(1, 2), + _utils.Interval(2, 3), + _utils.Interval(3, 4), + _utils.Interval(4, 5), + _utils.Interval(8, 12), + ] + print(event.intervals_overlap(intervals)) + self.assertEqual(event.intervals_overlap(intervals), 5) + def test_utils_compute_queue_depth(self): - x = torch.ones((4096, 4096), device="cuda") + + def format_queue_depth(queue_depth_list, events): + res = "" + for data, event in zip(queue_depth_list, events): + res += f"{data.queue_depth} [{event.name()}]\n" + return res + + # We have to use Mock because time series data is too flaky to test + profiler = self.generate_mock_profile() + basic_evaluation = _utils.BasicEvaluation(profiler) + self.assertExpectedInline( + format_queue_depth(basic_evaluation.queue_depth_list, + basic_evaluation.cuda_events), """\ +1 [cudaLaunchKernel] +2 [cudaLaunchKernel] +3 [cudaLaunchKernel] +4 [cudaLaunchKernel] +5 [cudaLaunchKernel] +4 [GPU] +3 [GPU] +2 [GPU] +1 [GPU] +0 [GPU] +1 [cudaLaunchKernel] +0 [GPU] +""") + self.assertExpectedInline( + format_queue_depth([ + basic_evaluation.metrics[k] + for k in basic_evaluation.event_keys + ], basic_evaluation.events), """\ +0 [CPU (Before cudaLaunchKernel)] +0 [CPU (Before cudaLaunchKernel)] +0 [CPU (Before cudaLaunchKernel)] +0 [CPU (Before cudaLaunchKernel)] +1 [CPU (After cudaLaunchKernel)] +2 [CPU (After cudaLaunchKernel)] +3 [CPU (After cudaLaunchKernel)] +4 [CPU (After cudaLaunchKernel)] +5 [CPU (After GPU)] +4 [CPU (After GPU)] +2 [CPU (After GPU)] +1 [CPU (After GPU)] +""") + + def test_utils_compute_queue_depth_when_no_cuda_events(self): + # For traces with only cpu events, we expect empty queue depth list + x = torch.ones((1024, 1024)) with profile() as prof: - # First half we want it to be compute bound for _ in range(5): - y = torch.mm(x, x) - # Second half we want it to be overhead bound - # So we are synchronize and sleeping - torch.cuda.synchronize() - for _ in range(3): - y[0] += 1 - time.sleep(0.1) + x = x @ x basic_evaluation = _utils.BasicEvaluation(prof.profiler) - for entry in basic_evaluation.compute_queue_depth(): - self.assertTrue(entry.queue_depth >= 0) - - - def test_extra_fields(self): - with profile(with_stack=True, profile_memory=True) as p: - _ = torch.ones((1,)) + self.assertFalse(basic_evaluation.compute_queue_depth()) + + def test_utils_compute_idle_time(self): + profiler = self.generate_mock_profile() + basic_evaluation = _utils.BasicEvaluation(profiler) + expected_output = "\n".join([ + f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name()}]" + for event_key in basic_evaluation.event_keys + ]) + self.assertExpectedInline( + expected_output, """\ +100000 [CPU (Before cudaLaunchKernel)] +100000 [CPU (Before cudaLaunchKernel)] +100000 [CPU (Before cudaLaunchKernel)] +100000 [CPU (Before cudaLaunchKernel)] +0 [CPU (After cudaLaunchKernel)] +0 [CPU (After cudaLaunchKernel)] +0 [CPU (After cudaLaunchKernel)] +0 [CPU (After cudaLaunchKernel)] +0 [CPU (After GPU)] +0 [CPU (After GPU)] +0 [CPU (After GPU)] +100000 [CPU (After GPU)]""") + + def test_utils_get_optimizable_events(self): + basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile()) + optimizable_events = basic_evaluation.get_optimizable_events( + 2, print_enable=False) + expected_output = "\n".join( + [f"{event_key.event.name()}" for event_key in optimizable_events]) + self.assertExpectedInline( + expected_output, """\ + +aten::copy_""") + + def test_profiler_name_pattern(self): + x = torch.ones((4096, 4096)) + with profile() as prof: + for _ in range(5): + x = x @ x + x = x + x + matched_events = NamePattern(prof, "aten::mm").matched_events() + output = "\n".join([f"{event.name()}" for event in matched_events]) + self.assertExpectedInline(output, """\ +aten::mm +aten::mm +aten::mm +aten::mm +aten::mm""") + + def test_profiler_pattern_match_helper(self): + x = torch.ones((100, 100)) + with profile() as prof: + for _ in range(5): + x = x @ x + x = x + x + event_tree = prof.profiler.kineto_results.experimental_event_tree() + pattern = Pattern(prof) + self.assertEqual([], pattern.siblings_of(event_tree[0])[0]) + self.assertEqual(event_tree[1:], pattern.siblings_of(event_tree[0])[1]) + child_nodes = event_tree[0].children + self.assertEqual([], pattern.siblings_of(child_nodes[0])[0]) + self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1]) + self.assertEqual(event_tree[0], + pattern.root_of(event_tree[0].children[0].children[0])) + self.assertEqual(None, pattern.next_of(event_tree[-1])) + self.assertEqual(event_tree[1], pattern.next_of(event_tree[0])) + self.assertEqual(event_tree[0], pattern.prev_of(event_tree[1])) - def find_ones(nodes): - for n in nodes: - if n.name() == "aten::ones": - return n - result = find_ones(n.children) - if result: - return result + @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_profiler_extra_cuda_copy_pattern(self): + cases = ( + (0, lambda: torch.ones((100, 100), device="cuda")), + (1, lambda: torch.ones((100, 100)).to("cuda")), + (1, lambda: torch.zeros((100, 100)).to("cuda")), + (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")), + (1, lambda: torch.ones((100, 100)).cuda()), + (1, lambda: torch.zeros((100, 100)).cuda()), + (1, lambda: torch.empty((100, 100)).fill_(5).cuda()), + (1, lambda: torch.rand((100, 100)).cuda()), + (1, lambda: torch.randn((100, 100)).cuda()), + (1, lambda: torch.full((100, 100), 10).cuda()), + ) + num_matched = [] + for _, fn in cases: + with profile(with_stack=True, record_shapes=True) as prof: + fn() + pattern = ExtraCUDACopyPattern(prof) + num_matched.append(len(pattern.matched_events())) + self.assertEqual(num_matched, [i for i, _ in cases]) + + @unittest.skipIf(TEST_WITH_CROSSREF, + "crossref intercepts calls and changes the callsite.") + def test_profiler_for_loop_indexing_pattern(self): + x = torch.ones((100, 100)) + + def case1(): + for i in range(100): + x[i] = i + + def case2(): + y = 0 + for i in range(100): + y += x[i] + + def case3(): + y = 1 + for i in range(100): + y *= x[i] + + def case4(): + y = x + for _ in range(100): + y = y @ x + + def case5(): + for i in range(100): + x[i, :] = torch.arange(100) + i + + cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5)) + num_matched = [] + for _, fn in cases: + with profile(with_stack=True) as prof: + fn() + pattern = ForLoopIndexingPattern(prof) + num_matched.append(len(pattern.matched_events())) + self.assertEqual(num_matched, [i for i, _ in cases]) - node = find_ones(p.profiler.kineto_results.experimental_event_tree()) - self.assertIsNotNone(node) - self.assertIsInstance( - node.extra_fields, - torch._C._autograd._ExtraFields_TorchOp) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_profiler_fp32_matmul_pattern(self): + x = torch.ones((100, 100), device="cuda") + with profile(with_stack=True) as prof: + x = x @ x + pattern = FP32MatMulPattern(prof) + has_tf32 = 0 if pattern.skip else 1 + num_matched = len(pattern.matched_events()) + self.assertEqual(num_matched, has_tf32) - self.assertIsInstance( - node.parent.extra_fields, - torch._C._autograd._ExtraFields_PyCCall) - self.assertEqual(node.children[0].name(), "aten::empty") - self.assertEqual(node.children[0].children[0].name(), "[memory]") - self.assertIsInstance( - node.children[0].children[0].extra_fields, - torch._C._autograd._ExtraFields_Allocation) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_profiler_extra_cuda_copy_pattern_benchmark(self): + with profile(with_stack=True, record_shapes=True) as prof: + x = torch.ones((100, 100)).to("cuda") + x = torch.ones((50, 50)).to("cuda") + pattern = ExtraCUDACopyPattern(prof) + shapes_factor_map = pattern.benchmark(pattern.matched_events()) + self.assertEqual(len(shapes_factor_map), 2) + + def test_profiler_optimizer_single_tensor_pattern(self): + x = torch.ones((100, 100)) + cases = ( + (1, lambda: torch.optim.Adam(model.parameters())), + (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)), + (1, lambda: torch.optim.AdamW(model.parameters())), + (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)), + (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)), + (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)), + ) + num_matched = [] + for _, fn in cases: + with profile(with_stack=True) as prof: + model = nn.Sequential( + nn.Linear(100, 100), + nn.ReLU(), + nn.Linear(100, 10), + ) + optimizer = fn() + optimizer.zero_grad() + y_hat = model(x) + loss = torch.nn.functional.cross_entropy(y_hat, torch.randint(0, 10, (100,))) + loss.backward() + optimizer.step() + pattern = OptimizerSingleTensorPattern(prof) + num_matched.append(len(pattern.matched_events())) + self.assertEqual(num_matched, [i for i, _ in cases]) + + def test_profiler_synchronized_dataloader_pattern(self): + dataset = torch.rand((100, 100)) + sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10) + async_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=4) + with profile(with_stack=True) as prof: + next(iter(sync_dataloader)) + next(iter(async_dataloader)) + pattern = SynchronizedDataLoaderPattern(prof) + num_matched = len(pattern.matched_events()) + self.assertEqual(num_matched, 1) + + def test_profiler_grad_not_set_to_none_pattern(self): + x = torch.ones((100, 100)) + model = nn.Sequential( + nn.Linear(100, 100), + nn.ReLU(), + nn.Linear(100, 10), + ) + optimizer = torch.optim.Adam(model.parameters()) + cases = ( + (1, lambda: optimizer.zero_grad()), + (1, lambda: model.zero_grad()), + (0, lambda: optimizer.zero_grad(set_to_none=True)), + (0, lambda: model.zero_grad(set_to_none=True)) + ) + num_matched = [] + for _, fn in cases: + with profile(with_stack=True) as prof: + y_hat = model(x) + loss = torch.nn.functional.cross_entropy(y_hat, torch.randint(0, 10, (100,))) + loss.backward() + optimizer.step() + fn() + pattern = GradNotSetToNonePattern(prof) + num_matched.append(len(pattern.matched_events())) + self.assertEqual(num_matched, [i for i, _ in cases]) + + def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self): + x = torch.randn((1, 3, 32, 32)) + cases = ( + (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))), + (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))), + (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))) + ) + num_matched = [] + for _, model in cases: + with profile(with_stack=True, record_shapes=True) as prof: + model(x) + pattern = Conv2dBiasFollowedByBatchNorm2dPattern(prof) + num_matched.append(len(pattern.matched_events())) + self.assertEqual(num_matched, [i for i, _ in cases]) if __name__ == '__main__': diff --git a/test/test_profiler_tree.py b/test/test_profiler_tree.py index 7b988401829dd..41d7f770461b1 100644 --- a/test/test_profiler_tree.py +++ b/test/test_profiler_tree.py @@ -4,12 +4,35 @@ import os import re import textwrap +import traceback import unittest +import expecttest + import torch +from torch._C._autograd import _ExtraFields_PyCall, _ExtraFields_PyCCall from torch.testing._internal.common_utils import ( TestCase, run_tests, IS_WINDOWS, TEST_WITH_CROSSREF) +# These functions can vary from based on platform and build (e.g. with CUDA) +# and generally distract from rather than adding to the test. +PRUNE_FUNCTIONS = { + "torch/profiler/profiler.py(...): start": True, + "torch/profiler/profiler.py(...): stop_trace": True, + "cudaStreamIsCapturing": False, + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": False, +} + +# ROCTracer is currently not producing events that profiler can extract. We +# should bring it up to parity with CUPTI Kineto / profiler integration, but in +# the mean time there is still utility in running tests but not checking that +# the values match expected value. +# 1) We will still catch runtime errors and assert failures +# 2) We can diff the output to see how far we are from parity +# +# TODO: We also fail to capture events for Windows on some platforms. +ALLOW_CUDA_FAILURE = (torch.version.hip is not None) or IS_WINDOWS + class ProfilerTree: @@ -25,11 +48,14 @@ def test(f): """ @functools.wraps(f) - def begin_unit_test_marker(self, replicates=5): + def begin_unit_test_marker(self, replicates=3): try: for i in range(replicates): self.tree_replicate = i - return f(self) + out = f(self) + if self.tree_replicate is None: + break + return out finally: delattr(self, "tree_replicate") return begin_unit_test_marker @@ -42,12 +68,23 @@ def flatten(nodes, depth=0, out=None): out = [] for node in nodes: - out.append((depth, cls.fmt_name(node.name()))) - flatten(node.children, depth + 1, out) + cls.validate_node(node) + name = cls.fmt_name(node.name()) + add_ellipses = PRUNE_FUNCTIONS.get(name.strip(), None) + if add_ellipses is None: + out.append((depth, name)) + flatten(node.children, depth + 1, out) + elif add_ellipses: + out.append((depth, "...")) return out flat_nodes = flatten(profiler.kineto_results.experimental_event_tree()) + + # Profiler inserts a `cudaDeviceSynchronize` at the end of profiling. + if flat_nodes and flat_nodes[-1][1] == "cudaDeviceSynchronize": + flat_nodes = flat_nodes[:-1] + min_depth = min([d + 1 for d, name in flat_nodes if "begin_unit_test_marker" in name] or [0]) return textwrap.indent( "\n".join([f"{' ' * (d - min_depth)}{name.rstrip()}" for d, name in flat_nodes if d >= min_depth]), @@ -62,7 +99,7 @@ def fmt_name(name: str) -> str: match = re.match(r"(.*)\.py\(([0-9]+)\): (.*)$", name) if match: - filename, lineno, fn = match.groups() + filename, _, fn = match.groups() # This test can appear as `test/test_profiler_tree.py` depending on # where it is run from. @@ -73,16 +110,52 @@ def fmt_name(name: str) -> str: filename = filename.replace(os.sep, "/") # We don't want to have to update this test every time PyTorch changes. - lineno = lineno if os.path.split(filename.strip())[1] == "test_profiler_tree" else "..." + # At some point we should test some line numbers, but for now it's + # too brittle. + lineno = "..." + return f"{filename}.py({lineno}): {fn}" + for kernel_pattern in ( + "void at::native::elementwise_kernel", + "void at::native::reduce_kernel", + "void at::native::vectorized_elementwise_kernel", + "void at::native::unrolled_elementwise_kernel", + + r"void [a-zA-Z0-9]+_kernel", # Nvidia kernels. + ): + name = re.sub( + rf"{kernel_pattern}<.+>\(.+\)$", + f"{kernel_pattern.replace('[a-zA-Z0-9]+', '...')}<...>(...)", + name) + return re.sub( "object at 0x[0-9a-fA-F]+>", "object at 0xXXXXXXXXXXXX>", name) + @classmethod + def validate_node(cls, node): + extra_fields = node.extra_fields + if isinstance(extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)): + # Check that the lineage established by the profiler matches the + # caller recorded by the Python tracer. + parent = node.parent + while parent is not None: + if isinstance(parent.extra_fields, _ExtraFields_PyCall): + break + parent = parent.parent + + def to_string(frame_state): + return f"{frame_state.file_name}(...): {frame_state.function_name}" + + if parent: + parent_name = to_string(parent.extra_fields.callsite) + caller_name = to_string(extra_fields.caller) + assert parent_name == caller_name, f"{parent_name} vs. {caller_name}" + class TestProfilerTree(TestCase): - def assertTreesMatch(self, actual: str, expected: str): + def assertTreesMatch(self, actual: str, expected: str, allow_failure: bool = False): # Warning: Here be dragons # Different platforms will have subtly different behavior for Python # tracing. Observed differences include: @@ -97,6 +170,11 @@ def assertTreesMatch(self, actual: str, expected: str): # change in the codebase which changes the trace produced, simply use # EXPECTTEST_ACCEPT=1 to update the tests to reflect the new structure. + # expecttest will not show the diff view if `len(actual) < len(expected)` + if not expecttest.ACCEPT: + actual = actual.ljust(len(expected)) + self.maxDiff = None + replicate = getattr(self, "tree_replicate", None) self.assertIsNotNone(replicate, "Please annotate test with `@ProfilerTree.test`") @@ -107,7 +185,15 @@ def assertTreesMatch(self, actual: str, expected: str): if replicate: self.assertEqual(actual, expected) else: - self.assertExpectedInline(actual, expected, skip=1) + try: + self.assertExpectedInline(actual, expected, skip=1) + except AssertionError as e: + if allow_failure: + self.tree_replicate = None + msg = traceback.format_exception_only(type(e), e)[0] + print(msg.split("AssertionError:")[-1]) + else: + raise @ProfilerTree.test def test_profiler_experimental_tree(self): @@ -296,71 +382,7 @@ def test_profiler_experimental_tree_with_memory(self): [memory]""" ) - self.assertTreesMatch( - ProfilerTree.format(p.profiler, 12), - """\ - aten::add - [memory] - aten::ones - aten::empty - [memory] - aten::fill_ - aten::sub - [memory] - aten::pow - aten::result_type - aten::to - [memory] - aten::ones_like - aten::empty_like - aten::empty_strided - [memory] - aten::fill_ - autograd::engine::evaluate_function: PowBackward0 - PowBackward0 - aten::pow - aten::result_type - aten::to - [memory] - aten::copy_ - aten::mul - [memory] - aten::mul - aten::to - aten::_to_copy - aten::empty_strided - [memory] - aten::copy_ - [memory] - [memory] - [memory] - aten::mul - [memory] - [memory] - [memory] - [memory] - autograd::engine::evaluate_function: SubBackward0 - SubBackward0 - aten::neg - [memory] - [memory] - autograd::engine::evaluate_function: AddBackward0 - AddBackward0 - autograd::engine::evaluate_function: torch::autograd::AccumulateGrad - torch::autograd::AccumulateGrad - aten::new_empty_strided - aten::empty_strided - [memory] - aten::copy_ - autograd::engine::evaluate_function: torch::autograd::AccumulateGrad - torch::autograd::AccumulateGrad - aten::detach - detach - [memory]""" - ) - @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.") - @unittest.skipIf(torch.has_cuda, "CUDA invokes extra Python functions.") @ProfilerTree.test def test_profiler_experimental_tree_with_memory_and_stack(self): t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True) @@ -373,17 +395,9 @@ def test_profiler_experimental_tree_with_memory_and_stack(self): self.assertTreesMatch( ProfilerTree.format(p.profiler, 12), """\ - test_profiler_tree.py(367): test_profiler_experimental_tree_with_memory_and_stack + test_profiler_tree.py(...): test_profiler_experimental_tree_with_memory_and_stack torch/profiler/profiler.py(...): __enter__ - torch/profiler/profiler.py(...): start - torch/profiler/profiler.py(...): _transit_action - torch/profiler/profiler.py(...): start_trace - torch/autograd/profiler.py(...): _start_trace - - torch/profiler/profiler.py(...): _get_distributed_info - torch/distributed/__init__.py(...): is_available - - torch/distributed/distributed_c10d.py(...): is_initialized + ... aten::add [memory] @@ -415,8 +429,8 @@ def test_profiler_experimental_tree_with_memory_and_stack(self): aten::empty_strided [memory] aten::fill_ - - + + autograd::engine::evaluate_function: PowBackward0 PowBackward0 aten::pow @@ -461,16 +475,13 @@ def test_profiler_experimental_tree_with_memory_and_stack(self): torch/profiler/profiler.py(...): __exit__ torch/profiler/profiler.py(...): stop torch/profiler/profiler.py(...): _transit_action - + enum.py(...): __hash__ - torch/profiler/profiler.py(...): stop_trace - torch/autograd/profiler.py(...): __exit__ - """ + ...""" ) @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.") - @unittest.skipIf(torch.has_cuda, "CUDA invokes extra Python functions.") @ProfilerTree.test def test_profiler_experimental_tree_with_stack_and_modules(self): class MyModule(torch.nn.Module): @@ -495,24 +506,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertTreesMatch( ProfilerTree.format(p.profiler, 12), """\ - test_profiler_tree.py(491): test_profiler_experimental_tree_with_stack_and_modules + test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_modules torch/profiler/profiler.py(...): __enter__ - torch/profiler/profiler.py(...): start - torch/profiler/profiler.py(...): _transit_action - torch/profiler/profiler.py(...): start_trace - torch/autograd/profiler.py(...): _start_trace - - torch/profiler/profiler.py(...): _get_distributed_info - torch/distributed/__init__.py(...): is_available - - torch/distributed/distributed_c10d.py(...): is_initialized + ... aten::ones aten::empty aten::fill_ nn.Module: MyModule_0 - test_profiler_tree.py(485): forward + test_profiler_tree.py(...): forward nn.Module: ReLU_0 torch/nn/modules/activation.py(...): forward @@ -532,12 +535,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: aten::transpose aten::as_strided aten::matmul - aten::t - aten::transpose - aten::as_strided - aten::mv - aten::empty - aten::addmv_ + aten::unsqueeze + aten::as_strided + aten::mm + aten::resolve_conj + aten::resolve_conj + aten::resolve_conj + aten::squeeze_ + aten::as_strided_ aten::add_ nn.Module: ReLU_1 @@ -553,7 +558,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: aten::fill_ nn.Module: MyModule_0 - test_profiler_tree.py(485): forward + test_profiler_tree.py(...): forward nn.Module: ReLU_0 torch/nn/modules/activation.py(...): forward @@ -573,12 +578,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: aten::transpose aten::as_strided aten::matmul - aten::t - aten::transpose - aten::as_strided - aten::mv - aten::empty - aten::addmv_ + aten::unsqueeze + aten::as_strided + aten::mm + aten::resolve_conj + aten::resolve_conj + aten::resolve_conj + aten::squeeze_ + aten::as_strided_ aten::add_ nn.Module: ReLU_1 @@ -594,10 +601,359 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: enum.py(...): __hash__ - torch/profiler/profiler.py(...): stop_trace - torch/autograd/profiler.py(...): __exit__ - """ + ...""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + @ProfilerTree.test + def test_profiler_experimental_tree_cuda(self): + with torch.profiler.profile(profile_memory=True) as p: + weight = torch.ones(1, device="cuda", requires_grad=True) + x = torch.ones(1, device="cuda") + y = torch.add(weight, x) + loss = torch.pow(y, 2) + loss.backward() + torch.optim.SGD([weight], lr=0.01, momentum=0.9).step() + + self.assertTreesMatch( + ProfilerTree.format(p.profiler, 12), + """\ + aten::ones + aten::empty + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + aten::ones + aten::empty + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + aten::add + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + aten::pow + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + aten::result_type + aten::to + [memory] + aten::ones_like + aten::empty_like + aten::empty_strided + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + autograd::engine::evaluate_function: PowBackward0 + PowBackward0 + aten::pow + aten::result_type + aten::to + [memory] + aten::copy_ + cudaMemcpyAsync + Memcpy DtoD (Device -> Device) + aten::mul + [memory] + aten::mul + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory] + aten::mul + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory] + [memory] + autograd::engine::evaluate_function: AddBackward0 + AddBackward0 + autograd::engine::evaluate_function: torch::autograd::AccumulateGrad + torch::autograd::AccumulateGrad + aten::detach + detach + [memory] + aten::zeros + aten::empty + [memory] + aten::zero_ + Optimizer.step#SGD.step + aten::empty + [memory] + [memory] + [memory] + aten::clone + aten::empty_strided + [memory] + aten::copy_ + cudaMemcpyAsync + Memcpy DtoD (Device -> Device) + aten::detach + detach + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory]""", # noqa: B950 + allow_failure=ALLOW_CUDA_FAILURE, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + @ProfilerTree.test + def test_profiler_experimental_tree_cuda_with_stream(self): + streams = [torch.cuda.Stream() for _ in range(3)] + results = [] + with torch.profiler.profile(profile_memory=True) as p: + x = torch.ones((4, 4), device="cuda") + for stream in streams: + with torch.cuda.stream(stream): + results.append(torch.tanh(x) - x) + del results + for s in streams: + torch.cuda.current_stream().wait_stream(s) + + self.assertTreesMatch( + ProfilerTree.format(p.profiler, 12), + """\ + aten::ones + aten::empty + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + aten::tanh + cudaMalloc + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + aten::sub + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory] + aten::tanh + cudaMalloc + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + aten::sub + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory] + aten::tanh + cudaMalloc + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + aten::sub + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory]""", + allow_failure=ALLOW_CUDA_FAILURE, + ) + + @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + @ProfilerTree.test + def test_profiler_experimental_tree_cuda_detailed(self): + model = torch.nn.modules.Linear(1, 1, device="cuda") + model.train() + opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + def step(): + x = torch.ones((1, 1), device="cuda") + loss = model(x) + loss.backward() + opt.step() + + # Warmup + for _ in range(3): + step() + + with torch.profiler.profile(profile_memory=True, with_stack=True) as p: + step() + + self.assertTreesMatch( + ProfilerTree.format(p.profiler, 12), + """\ + test_profiler_tree.py(...): test_profiler_experimental_tree_cuda_detailed + torch/profiler/profiler.py(...): __enter__ + ... + test_profiler_tree.py(...): step + + aten::ones + aten::empty + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + nn.Module: Linear_0 + + torch/nn/modules/linear.py(...): forward + torch/nn/modules/module.py(...): __getattr__ + torch/nn/modules/module.py(...): __getattr__ + + aten::linear + aten::t + aten::transpose + aten::as_strided + aten::addmm + cudaMemcpyAsync + Memcpy DtoD (Device -> Device) + cudaLaunchKernel + void ..._kernel<...>(...) + [memory] + aten::expand + aten::as_strided + torch/_tensor.py(...): backward + + torch/autograd/__init__.py(...): backward + + + + torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple + torch/autograd/__init__.py(...): _make_grads + + + + aten::ones_like + aten::empty_like + aten::empty_strided + [memory] + aten::fill_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + + + autograd::engine::evaluate_function: AddmmBackward0 + AddmmBackward0 + aten::t + aten::transpose + aten::as_strided + aten::mm + cudaLaunchKernel + void ..._kernel<...>(...) + [memory] + aten::t + aten::transpose + aten::as_strided + aten::sum + aten::sum + cudaLaunchKernel + void at::native::reduce_kernel<...>(...) + [memory] + aten::view + aten::view + autograd::engine::evaluate_function: torch::autograd::AccumulateGrad + torch::autograd::AccumulateGrad + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + autograd::engine::evaluate_function: TBackward0 + TBackward0 + aten::t + aten::transpose + aten::as_strided + autograd::engine::evaluate_function: torch::autograd::AccumulateGrad + torch::autograd::AccumulateGrad + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + [memory] + torch/optim/optimizer.py(...): wrapper + + torch/autograd/profiler.py(...): __init__ + + aten::zeros + aten::empty + [memory] + aten::zero_ + torch/autograd/profiler.py(...): __enter__ + torch/_ops.py(...): __call__ + + Optimizer.step#SGD.step + aten::empty + [memory] + [memory] + [memory] + torch/optim/optimizer.py(...): _use_grad + + torch/autograd/grad_mode.py(...): __init__ + + + torch/optim/sgd.py(...): step + + + torch/_tensor.py(...): __hash__ + + + + + + torch/_tensor.py(...): __hash__ + + + + torch/optim/sgd.py(...): sgd + torch/optim/sgd.py(...): _single_tensor_sgd + + [memory] + aten::mul_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + + [memory] + aten::mul_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + [memory] + + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + + aten::add_ + cudaLaunchKernel + void at::native::vectorized_elementwise_kernel<...>(...) + torch/_tensor.py(...): __hash__ + + + torch/_tensor.py(...): __hash__ + + + torch/autograd/grad_mode.py(...): __init__ + + + torch/autograd/profiler.py(...): __exit__ + torch/_ops.py(...): __call__ + + [memory] + [memory] + torch/profiler/profiler.py(...): __exit__ + torch/profiler/profiler.py(...): stop + torch/profiler/profiler.py(...): _transit_action + + enum.py(...): __hash__ + + ...""", # noqa: B950 + allow_failure=ALLOW_CUDA_FAILURE, ) + if __name__ == '__main__': run_tests() diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 53851b484a5c4..2d7caa807acdd 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1,15 +1,66 @@ -# Owner(s): ["oncall: fx"] +# Owner(s): ["module: ProxyTensor"] from torch.testing._internal.common_utils import TestCase, run_tests import torch import unittest import warnings +import torch.nn.utils._stateless as stateless +from collections.abc import Iterable from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed +from torch._subclasses.fake_tensor import DynamicOutputShapeException +from torch._decomp import decomposition_table from torch.testing._internal.common_device_type import ops -from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule +from torch.utils._pytree import tree_map +from torch import nn +import re + +aten = torch.ops.aten + +try: + import sympy # noqa: F401 + HAS_SYMPY = True +except ImportError: + HAS_SYMPY = False +skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") + + +def process_failures(): + """ + Takes file containing failures like + + FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition # noqa: B950 + + and processes them into a list of opinfo xfails + """ + f = open('pytest_failures') + failures = f.readlines() + failures = [i.strip() for i in failures] + + def process_failure_string(s, matcher): + out = re.search(matcher, s) + return out.groups() + + SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)' + failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures] + + def create_normalized_name(op): + if op.variant_test_name == '': + s = op.name + else: + s = f"{op.name}.{op.variant_test_name}" + return s.replace('.', '_') + + remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db} + + print("symbolic_tensor_failures = {") + for failure, reason in failures: + print(f" xfail{remap_opinfo[failure]}, # {reason}") + print("}") + # Copied from functorch def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): @@ -58,40 +109,148 @@ def wrapped(fn): UserWarning) +def _create_new_input(x): + if not isinstance(x, torch.Tensor): + return x + if x.dtype != torch.float: + return x + 1 + if x.is_leaf: + return torch.rand_like(x, requires_grad=True) + else: + return torch.rand_like(x) + class TestProxyTensor(TestCase): - def test_make_fx(self, device): + def _test(self, f, inps): + fx_f = make_fx(f)(*inps) + new_inps = tree_map(_create_new_input, inps) + self.assertEqual(fx_f(*new_inps), f(*new_inps)) + + def test_make_fx_simple(self, device): def f(x): return torch.sin(x) - inp = torch.randn(3) - fx_f = make_fx(f)(inp) - - new_inp = torch.randn(3) - self.assertEqual(fx_f(new_inp), f(new_inp)) + self._test(f, (torch.randn(3),)) def test_scalar_device(self, device): def f(a, b): return a + b - inps = [torch.randn(3, device=device), torch.tensor(5)] - fx_f = make_fx(f)(*inps) - self.assertEqual(fx_f(*inps), f(*inps)) + self._test(f, [torch.randn(3, device=device), torch.tensor(5)]) + + def test_isolated_graphmodule(self): + def is_any_sum(gm): + return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes) + + def is_any_digamma(gm): + return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes) + def is_any_sigmoid(gm): + return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes) + + def inner(x): + return torch.sum(x) + + def f(x): + gm = get_isolated_graphmodule(inner, (x,), {}) + self.assertTrue(is_any_sum(gm)) + return x + torch.randn(x.shape) + + # get_isolated_graphmodule uses make_fx internally that shouldn't be traced + # by the outer make_fx call + traced = make_fx(f)(torch.randn(3)) + self.assertFalse(is_any_sum(traced)) + + # When factory functions are used, they should not be traced + # by the outer make_fx call + def inner_with_factory(): + val = torch.tensor(float(1)) + val.add_(2) + return torch.full((10, 10), val).sum() + + def f1(x): + gm = get_isolated_graphmodule(inner_with_factory, (), {}) + self.assertTrue(is_any_sum(gm)) + return torch.sigmoid(x) + + def f2(x): + gm = get_isolated_graphmodule(f1, (x,), {}) + self.assertFalse(is_any_sum(gm)) + self.assertTrue(is_any_sigmoid(gm)) + return torch.digamma(x) + + traced = make_fx(f2)(torch.randn(3)) + self.assertFalse(is_any_sum(traced)) + self.assertFalse(is_any_sigmoid(traced)) + self.assertTrue(is_any_digamma(traced)) + + # Verify nested make_fx calls don't make factory functions to be leaked + # into the outer graph + def f2(x): + gm = make_fx(f1)(x) + self.assertFalse(is_any_sum(gm)) + self.assertTrue(is_any_sigmoid(gm)) + return torch.digamma(x) + + traced = make_fx(f2)(torch.randn(3)) + self.assertFalse(is_any_sum(traced)) + self.assertTrue(is_any_sigmoid(traced)) + self.assertTrue(is_any_digamma(traced)) + + # Verify interaction with non-ProxyTensor modes + from torch.testing._internal.logging_tensor import LoggingTensorMode + + def f1_logging(x): + with LoggingTensorMode(): + gm = get_isolated_graphmodule(inner_with_factory, (), {}) + self.assertTrue(is_any_sum(gm)) + return torch.sigmoid(x) + + def f2_logging(x): + with LoggingTensorMode(), LoggingTensorMode(): + gm = get_isolated_graphmodule(f1_logging, (x,), {}) + self.assertFalse(is_any_sum(gm)) + self.assertTrue(is_any_sigmoid(gm)) + return torch.digamma(x) + + traced = make_fx(f2_logging)(torch.randn(3)) + self.assertFalse(is_any_sum(traced)) + self.assertFalse(is_any_sigmoid(traced)) + self.assertTrue(is_any_digamma(traced)) + + # Verify interaction with another tensor subclass + # This case currently doesn't work and should raise an error + # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068 + from torch.testing._internal.logging_tensor import LoggingTensor + + def f1_logging_tensor(x): + gm = get_isolated_graphmodule(inner_with_factory, (), {}) + self.assertTrue(is_any_sum(gm)) + return torch.sigmoid(x) + + def f2_logging_tensor(x): + x = LoggingTensor(x) + gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {}) + self.assertFalse(is_any_sum(gm)) + self.assertTrue(is_any_sigmoid(gm)) + return torch.digamma(x) + + with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"): + traced = make_fx(f2_logging_tensor)(torch.randn(3)) + self.assertFalse(is_any_sum(traced)) + self.assertFalse(is_any_sigmoid(traced)) # this fails, sigmoid is traced with LoggingTensor + self.assertTrue(is_any_digamma(traced)) @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") def test_resnet18_backward_trace(self, device): mod = torchvision.models.resnet18() def f(x): + for a in mod.parameters(): + a.grad = None out = mod(x) out.sum().backward() return [a.grad for a in mod.parameters()] inp = torch.randn(3, 3, 250, 250, requires_grad=True) - grads = f(inp) - - mod.zero_grad() - mod(inp).sum().backward() - grads2 = [a.grad for a in mod.parameters()] - self.assertEqual(grads, grads2) + self._test(f, [inp]) def test_proxy_tensor(self): def f_grad(x): @@ -104,11 +263,16 @@ def f_backward(x): return x.grad for f in [f_grad, f_backward]: - traced_graph = make_fx(f)(torch.randn(3, requires_grad=True)) - inp = torch.randn(3, requires_grad=True) - traced_graph_out = traced_graph(inp) - assert inp.grad is None - torch.testing.assert_close(traced_graph_out, f(inp)) + self._test(f, [torch.randn(3, requires_grad=True)]) + + def test_inplace_metadata(self): + def f(x): + x = x.clone() + x.unsqueeze_(-1) + assert x.shape[-1] == 1 + return x + + self._test(f, [torch.randn(5)]) def test_mode_tracing_factory_function(self): def f(x): @@ -118,7 +282,7 @@ def f(x): traced = make_fx(f)(torch.randn(3)) self.assertTrue( any( - isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn' + node.target == aten.randn.default for node in traced.graph.nodes ) ) @@ -126,16 +290,243 @@ def f(x): def test_mode_tracing_factory_function_no_factory_function(self): def f(x): return x + torch.randn(x.shape) - - traced = make_fx(f, trace_factory_functions=False)(torch.randn(3)) # default behavior should not trace factory functions + # setting the flag to false should not trace factory functions + traced = make_fx(f, trace_factory_functions=False)(torch.randn(3)) self.assertFalse( any( - isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn' + node.target == aten.randn.default for node in traced.graph.nodes ) ) + def test_make_fx_overloads(self): + def f(x): + return x.cos() + torch.randn(x.shape) + + traced = make_fx(f)(torch.randn(3)) + + self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload) + for node in traced.graph.nodes if node.op == 'call_function'])) + + def test_tensor_constants(self): + def f(): + val = torch.tensor(float('inf')) + return torch.full((100, 100), val) + + self._test(f, []) + + def test_constant_proxy_tensor(self): + from torch.fx.experimental.proxy_tensor import make_fx + + def f(): + val = torch.tensor(float('inf')) + return torch.full((100, 100), val) + + g = make_fx(f)() + self.assertEqual(g(), f()) + + def test_constant_proxy_tensor_mut(self): + from torch.fx.experimental.proxy_tensor import make_fx + + def f(): + val = torch.tensor(float(1)) + val.add_(2) + return torch.full((100, 100), val) + + g = make_fx(f)() + self.assertEqual(g(), f()) + # In case we mutated shared state in the g graph! + self.assertEqual(g(), f()) + + g = make_fx(f, tracing_mode="fake")() + self.assertEqual(g(), f()) + # In case we mutated shared state in the g graph! + self.assertEqual(g(), f()) + + def test_constant_unbind(self): + def f(): + val = torch.tensor([2]) + r, = torch.unbind(val, 0) + return r.item() + + g = make_fx(f)() + self.assertEqual(g(), f()) + + def test_issue82547(self): + x = nn.Parameter(torch.randn(3, 3)) + + def f(): + return torch.ops.aten.t.default(x) + self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")()) + + class A(torch.Tensor): + pass + + x = A(torch.randn(3, 3)) + self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")()) + + def test_use_fake_and_tensor(self): + def f(x, y): + z = torch.tensor([2.0, 3.0]) + return x + y + z + + g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2)) + x, y = torch.randn(2), torch.randn(2) + self.assertEqual(g(x, y), f(x, y)) + + def test_decomposition_interpreter(self): + def fn(x): + return torch.nn.functional.silu(x) + + x = torch.rand((4, 4)) + fx_module = make_fx(fn, decomposition_table=None)(x) + + found_silu = False + for n in fx_module.graph.nodes: + if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default: + found_silu = True + + self.assertTrue(found_silu) + + new_graph = torch.fx.Graph() + silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]} + DecompositionInterpreter( + fx_module, + new_graph=new_graph, + decomposition_table=silu_decomp_table, + ).run(x) + + decomposed_module = torch.fx.GraphModule(fx_module, new_graph) + + for n in decomposed_module.graph.nodes: + self.assertTrue(n.target != torch.ops.aten.silu) + self.assertTrue(n.target != torch.ops.aten.silu.default) + + self.assertEqual(fx_module(x), decomposed_module(x)) + + def test_make_fx_model_fwd_bwd(self, device): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x).relu() + + model = Foo() + + def f(x, params): + out = stateless.functional_call(model, params, x).sum() + out.backward() + return list(params.values()) + input = torch.randn(3, 5, requires_grad=True) + params = dict(model.named_parameters()) + fx_f = make_fx(f)(input, params) + # fx may change the order of parameters in list, so using set() to compare + self.assertTrue( + torch.allclose(fx_f(input, params)[0], f(input, params)[0]) + or + torch.allclose(fx_f(input, params)[0], f(input, params)[1]) + ) + self.assertTrue( + torch.allclose(fx_f(input, params)[1], f(input, params)[0]) + or + torch.allclose(fx_f(input, params)[1], f(input, params)[1]) + ) + + def test_make_fx_model_fwd_bwd_wgtupdate(self, device): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x).relu() + + model = Foo() + + def f(args, params, buffers): + if not isinstance(args, Iterable): + args = [args] + params_and_buffers = {**params, **buffers} + out = stateless.functional_call(model, params_and_buffers, args) + out.sum().backward() + return [p - 1e-4 * p.grad for p in params.values()] + + input = torch.randn(3, 5, requires_grad=True) + params = dict(model.named_parameters()) + buffers = dict(model.named_buffers()) + fx_f = make_fx(f)(input, params, buffers) + # fx may change the order of parameters in list, so using set() to compare + # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03 + self.assertTrue( + torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03) + or + torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03) + ) + self.assertTrue( + torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03) + or + torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03) + ) + +# TODO: Need to test the guards themselves specifically as well +@skipIfNoSympy +class TestSymbolicTracing(TestCase): + def _test_dynamic(self, fn, trace_inputs, test_inputs): + """ + Tests fn traced with trace_inputs against test_inputs + Also returns shape env + """ + trace_inputs = [torch.randn(shape) for shape in trace_inputs] + traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs) + for input in test_inputs: + input = [torch.randn(shape) for shape in input] + self.assertEqual(traced_f(*input), fn(*input)) + return traced_f.shape_env + + + def test_unary(self): + def f(x): + assert x.shape[0] < 20 + return x.cos() + test_inputs = [] + test_inputs.append([(2, 5)]) + test_inputs.append([(6, 8)]) + shape_env = self._test_dynamic(f, [(3, 4)], test_inputs) + self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(4, 5))) + self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(25, 5))) + assert len(shape_env.guards) == 1 + + def test_binary_broadcast(self): + def f(a, b): + c = a * b + return c + + test_inputs = [] + test_inputs.append([(1, 5), (3, 1)]) + test_inputs.append([(1, 4), (4, 1)]) + shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs) + assert len(shape_env.guards) == 0 + + def test_cat(self): + def f(a, b): + val = torch.mul(a, b) + out = torch.cat([val, val]) + if out.shape[0] * out.shape[1] > 20: + out = out.cos() + return out + + test_inputs = [] + test_inputs.append([(1, 5), (6, 1)]) + test_inputs.append([(1, 4), (3, 1)]) + shape_env = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) + self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(1, 10), torch.randn(6, 1))) + self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(1, 2), torch.randn(4, 1))) + assert len(shape_env.guards) == 1 + make_fx_failures = { + # unknown xfail('allclose'), xfail('equal'), xfail('linalg.eigvals'), @@ -150,8 +541,7 @@ def f(x): skip('nn.functional.max_unpool2d', '', device_type='cpu'), skip('nn.functional.max_unpool3d', '', device_type='cpu'), skip('linalg.lstsq'), # flaky, probably just a precision issue - xfail('histogram'), - xfail('scatter'), + # data-dependent control flow xfail('cov'), xfail('istft'), @@ -160,60 +550,454 @@ def f(x): xfail('quantile'), xfail('tensor_split'), xfail('corrcoef'), - # Masked failures (creating a scalar tensor just to call `.item` on it) - xfail('_masked.amax'), - xfail('_masked.amax'), - xfail('_masked.amin'), - xfail('_masked.argmax'), - xfail('_masked.argmin'), - xfail('_masked.cumprod'), - xfail('_masked.cumsum'), - xfail('_masked.log_softmax'), - xfail('_masked.logaddexp'), - xfail('_masked.logsumexp'), - xfail('_masked.mean'), - xfail('_masked.median'), - xfail('_masked.norm'), - xfail('_masked.prod'), - xfail('_masked.softmax'), - xfail('_masked.softmin'), - xfail('_masked.std'), - xfail('_masked.sum'), - xfail('_masked.var'), # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse xfail('sparse.sampled_addmm'), - # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse + # ??? xfail('nn.functional.ctc_loss'), + # Sparse tensors are not supported with faketensors for now + xfail('to_sparse'), + # segfaults + skip('block_diag'), } +fake_tensor_failures = { + # FakeTensor fallback doesn't work + xfail('segment_reduce', 'lengths'), + xfail('multinomial'), + xfail('mvlgamma', 'mvlgamma_p_1'), + xfail('mvlgamma', 'mvlgamma_p_3'), + xfail('mvlgamma', 'mvlgamma_p_5'), + xfail('cholesky'), + xfail('cholesky_inverse'), + # ASAN failures due to divide by 0 + skip('nn.functional.nll_loss'), +} + +symbolic_tensor_failures = { + # Needs complex-value support + xfail('polar'), + xfail('complex'), + xfail('linalg.eig'), + xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('__rmatmul__', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition + xfail('__rpow__', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition + xfail('_masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition + xfail('_masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition + xfail('_masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel + xfail('_masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... + xfail('_masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition + xfail('_masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition + xfail('_masked.normalize', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition + xfail('_masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... + xfail('_masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('_masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... + xfail('addbmm', ''), # aten.addbmm.default - couldn't find symbolic meta function/decomposition + xfail('addmm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition + xfail('addmm', 'decomposed'), # aten.mm.default - couldn't find symbolic meta function/decomposition + xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition + xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('all', ''), # Unexpected type when computing elementwise type promotion! + xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition + xfail('argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition + xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition + xfail('argsort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition + xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition + xfail('as_strided', ''), # aten.as_strided.default - couldn't find symbolic meta function/decomposition + xfail('as_strided_scatter', ''), # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition + xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition + xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition + xfail('bfloat16', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('bmm', ''), # aten.bmm.default - couldn't find symbolic meta function/decomposition + xfail('bool', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('broadcast_tensors', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition + xfail('byte', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel + xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('chalf', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('char', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... + xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('clamp_max', ''), # Received type that is neither a tensor or a number! + xfail('clone', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition + xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel + xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition + xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba... + xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition + xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition + xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition + xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition + xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition + xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition + xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition + xfail('diag_embed', ''), # aten.diag_embed.default - couldn't find symbolic meta function/decomposition + xfail('diagflat', ''), # Tensors of type TensorImpl do not have numel + xfail('diagonal', ''), # aten.diagonal.default - couldn't find symbolic meta function/decomposition + xfail('diagonal_scatter', ''), # aten.diagonal_scatter.default - couldn't find symbolic meta function/decomposition + xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition + xfail('double', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition + xfail('eig', ''), # aten.eig.default - couldn't find symbolic meta function/decomposition + xfail('einsum', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('expand_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.fftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.fftshift', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.hfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.hfft', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('fft.hfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ifft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ifft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ifftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ifftshift', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ihfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ihfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.ihfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.irfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.irfft', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('fft.irfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('fill', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('flatten', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('float', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('float_power', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition + xfail('full_like', ''), # aten.full_like.default - couldn't find symbolic meta function/decomposition + xfail('gather', ''), # aten.gather.default - couldn't find symbolic meta function/decomposition + xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition + xfail('gradient', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('half', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because... + xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c... + xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition + xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('hstack', ''), # Tensors of type TensorImpl do not have numel + xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition + xfail('index_add', ''), # Float + xfail('index_copy', ''), # Expected a long tensor for index, but got Float + xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition + xfail('index_put', ''), # aten.index_put.default - couldn't find symbolic meta function/decomposition + xfail('index_reduce', ''), # Float + xfail('index_select', ''), # Tensors of type TensorImpl do not have numel + xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('int', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('inverse', ''), # Tensors of type TensorImpl do not have numel + xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition + xfail('isreal', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition + xfail('lerp', ''), # aten.lerp.Scalar - couldn't find symbolic meta function/decomposition + xfail('linalg.cholesky', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel + xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition + xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition + xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition + xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition + xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct... + xfail('linalg.inv', ''), # aten.linalg_inv_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.inv_ex', ''), # aten.linalg_inv_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.ldl_factor', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.ldl_factor_ex', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos... + xfail('linalg.ldl_solve', ''), # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition + xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition + xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition + xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape + xfail('linalg.matrix_norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition + xfail('linalg.matrix_rank', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('linalg.matrix_rank', 'hermitian'), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('norm', 'fro'), # TensorImpl do not have numel + xfail('norm', 'inf'), # TensorImpl do not have numel + xfail('linalg.norm', ''), # TensorImpl do not have numel + xfail('linalg.norm', 'subgradients_at_zero'), # TensorImpl do not have numel + xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition + xfail('linalg.pinv', 'singular'), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.pinv', 'hermitian'), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo... + xfail('linalg.qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition + xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition + xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition + xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de... + xfail('linalg.svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition + xfail('linalg.svdvals', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition + xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('linalg.vecdot', ''), # Could not run 'aten::vdot' with arguments from the 'Meta' backend. This could be ... + xfail('linalg.vector_norm', ''), # TensorImpl do not have numel + xfail('log_softmax', 'dtype'), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition + xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition + xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition + xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('logsumexp', ''), # Tensors of type TensorImpl do not have numel + xfail('long', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition + xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition + xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition + xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition + xfail('matmul', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition + xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition + xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition + xfail('mean', ''), # Unexpected type when computing elementwise type promotion! + xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... + xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel + xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel + xfail('min', 'reduction_with_dim'), # aten.min.dim - couldn't find symbolic meta function/decomposition + xfail('mm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition + xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition + xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition + xfail('mv', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition + xfail('nanmean', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('native_layer_norm', ''), # Unexpected type when computing elementwise type promot... + xfail('nn.functional.adaptive_avg_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.adaptive_avg_pool2d', ''), # argument 'size' must be tuple of ints, but found element o... + xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func... + xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... + xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... + xfail('nn.functional.avg_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.avg_pool2d', ''), # aten.avg_pool2d.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.batch_norm', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom... + xfail('nn.functional.binary_cross_entropy_with_logits', ''), # aten.binary_cross_entropy_with_logits.default - couldn'... + xfail('nn.functional.conv1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.conv2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.conv_transpose1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo... + xfail('nn.functional.conv_transpose2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo... + xfail('nn.functional.conv_transpose3d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo... + xfail('nn.functional.cosine_embedding_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.dropout2d', ''), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.dropout3d', ''), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.dropout', ''), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun... + xfail('nn.functional.embedding', ''), # argument 'size' must be tuple of ints, but found element of type tor... + xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t... + xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t... + xfail('nn.functional.glu', ''), # aten.glu.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos... + xfail('nn.functional.group_norm', ''), # 'torch._C.SymbolicIntNode' and 'int' + xfail('nn.functional.hardsigmoid', ''), # Received type that is neither a tensor or a number! + xfail('nn.functional.hardswish', ''), # Received type that is neither a tensor or a number! + xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco... + xfail('nn.functional.huber_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.instance_norm', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... + xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function... + xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... + xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... + xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... + xfail('nn.functional.kl_div', ''), # Unexpected type when computing elementwise type pro... + xfail('nn.functional.l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.layer_norm', ''), # Unexpected type when computing elementwise type... + xfail('nn.functional.linear', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d... + xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d... + xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... + xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... + xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom... + xfail('nn.functional.mse_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the... + xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ... + xfail('nn.functional.multilabel_soft_margin_loss', ''), # aten.new_empty.default - couldn't find symbolic meta functio... + xfail('nn.functional.normalize', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.pad', 'circular'), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.pad', 'constant'), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition + xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo... + xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco... + xfail('nn.functional.pairwise_distance', ''), # TensorImpl does not have numel + xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... + xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... + xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... + xfail('nn.functional.poisson_nll_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema' + xfail('nn.functional.prelu', ''), # Tensors of type TensorImpl do not have numel + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.soft_margin_loss', ''), # aten.soft_margin_loss.default - couldn't find symbolic meta function/de... + xfail('nn.functional.softmin', 'with_dtype'), # aten._to_copy.default - couldn't find symbolic meta function/decompos... + xfail('nn.functional.triplet_margin_loss', ''), # Unexpected type when computing element... + xfail('nn.functional.triplet_margin_with_distance_loss', ''), # Unexpected type when com... + xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de... + xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... + xfail('norm', ''), # TensorImpl does not have numel + xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition + xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition + xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition + xfail('ones_like', ''), # aten.ones_like.default - couldn't find symbolic meta function/decomposition + xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition + xfail('outer', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('pca_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition + xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition + xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition + xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition + xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition + xfail('rand_like', ''), # aten.randn_like.default - couldn't find symbolic meta function/decomposition + xfail('randint_like', ''), # aten.randint_like.default - couldn't find symbolic meta function/decomposition + xfail('randn_like', ''), # aten.randn_like.default - couldn't find symbolic meta function/decomposition + xfail('ravel', ''), # Tensors of type TensorImpl do not have numel + xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition + xfail('repeat', ''), # aten.repeat.default - couldn't find symbolic meta function/decomposition + xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('reshape', ''), # Tensors of type TensorImpl do not have numel + xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition + xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition + xfail('roll', ''), # Tensors of type TensorImpl do not have numel + xfail('rot90', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('round', ''), # aten.round.default - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_0'), # aten.round.decimals - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_neg_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition + xfail('scatter_add', ''), # aten.scatter_add.default - couldn't find symbolic meta function/decomposition + xfail('scatter', ''), # aten.scatter.src - couldn't find symbolic meta function/decomposition + xfail('scatter_reduce', 'amax'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition + xfail('scatter_reduce', 'amin'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition + xfail('scatter_reduce', 'mean'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition + xfail('scatter_reduce', 'prod'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition + xfail('scatter_reduce', 'sum'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition + xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ... + xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition + xfail('select', ''), # aten.select.int - couldn't find symbolic meta function/decomposition + xfail('select_scatter', ''), # aten.select_scatter.default - couldn't find symbolic meta function/decomposition + xfail('sgn', ''), # aten.sgn.default - couldn't find symbolic meta function/decomposition + xfail('short', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('sinc', ''), # aten.sinc.default - couldn't find symbolic meta function/decomposition + xfail('slice_scatter', ''), # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition + xfail('softmax', 'with_dtype'), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition + xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition + xfail('special.bessel_j0', ''), # aten.special_bessel_j0.default - couldn't find symbolic meta function/decomposition + xfail('special.bessel_j1', ''), # aten.special_bessel_j1.default - couldn't find symbolic meta function/decomposition + xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition + xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition + xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me... + xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me... + xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition + xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decomposition + xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f... + xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta... + xfail('special.laguerre_polynomial_l', ''), # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta... + xfail('special.log_ndtr', ''), # aten.special_log_ndtr.default - couldn't find symbolic meta function/decomposition + xfail('special.modified_bessel_i0', ''), # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct... + xfail('special.modified_bessel_i1', ''), # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct... + xfail('special.modified_bessel_k0', ''), # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct... + xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct... + xfail('special.ndtri', ''), # aten.special_ndtri.default - couldn't find symbolic meta function/decomposition + xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... + xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... + xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... + xfail('special.spherical_bessel_j0', ''), # aten.special_spherical_bessel_j0.default - couldn't find symbolic meta fun... + xfail('special.xlog1py', ''), # aten.special_xlog1py.default - couldn't find symbolic meta function/decomposition + xfail('split', ''), # 'torch._C.SymbolicIntNode' and 'int' + xfail('split', 'list_args'), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('split_with_sizes', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('stack', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymbolicIntNode a... + xfail('std', ''), # Unexpected type when computing elementwise type promotion! + xfail('std_mean', ''), # Unexpected type when computing elementwise type promotion! + xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymbolicIntNode at... + xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition + xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition + xfail('symeig', ''), # aten.symeig.default - couldn't find symbolic meta function/decomposition + xfail('take_along_dim', ''), # dtype of indices should be Long but got Float + xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition + xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('tile', ''), # aten.repeat.default - couldn't find symbolic meta function/decomposition + xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition + xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition + xfail('tril', ''), # aten.tril.default - couldn't find symbolic meta function/decomposition + xfail('triu', ''), # aten.triu.default - couldn't find symbolic meta function/decomposition + xfail('unfold', ''), # aten.unfold.default - couldn't find symbolic meta function/decomposition + xfail('var_mean', ''), # Unexpected type when computing elementwise type promotion! + xfail('var', ''), # Unexpected type when computing elementwise type promotion! + xfail('vdot', ''), # aten.vdot.default - couldn't find symbolic meta function/decomposition + xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition + xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('view', ''), # Tensors of type TensorImpl do not have numel + xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('where', ''), # expected predicate to be bool, got torch.float32 + xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition + xfail('zeros_like', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition +} + +def _test_make_fx_helper(self, device, dtype, op, tracing_mode): + def f(args, kwargs): + return op.op(*args, **kwargs) + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) + new_f = None + for sample_input in sample_inputs_itr: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + + try: + new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs) + except DynamicOutputShapeException as e: + self.skipTest("Dynamic output shape operation in trace") + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: + arg.uniform_(0, 1) + try: + old_out = f(args, kwargs) + except Exception: + continue + new_out = wrapper_set_seed(new_f, args, kwargs) + self.assertEqual(new_out, old_out) class TestProxyTensorOpInfo(TestCase): @ops(op_db, allowed_dtypes=(torch.float,)) - @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures - ) + @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures) def test_make_fx_exhaustive(self, device, dtype, op): + _test_make_fx_helper(self, device, dtype, op, "real") - def f(args, kwargs): - return op.op(*args, **kwargs) - sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) - new_f = None - for sample_input in sample_inputs_itr: - args = [sample_input.input] + list(sample_input.args) - kwargs = sample_input.kwargs - - new_f = make_fx(f, trace_factory_functions=True)(args, kwargs) - for arg in args: - if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: - arg.uniform_(0, 1) - try: - old_out = f(args, kwargs) - except Exception: - continue - new_out = wrapper_set_seed(new_f, args, kwargs) - self.assertEqual(new_out, old_out) + @ops(op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures)) + def test_make_fx_fake_exhaustive(self, device, dtype, op): + _test_make_fx_helper(self, device, dtype, op, "fake") + @skipIfNoSympy + @ops(op_db, allowed_dtypes=(torch.float,)) + @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) + def test_make_fx_symbolic_exhaustive(self, device, dtype, op): + _test_make_fx_helper(self, device, dtype, op, "symbolic") only_for = ("cpu") diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index b830dc64ef7be..c41c6b0324926 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -188,7 +188,7 @@ def test_no_new_bindings(self): "StreamObjType", "StringType", "SUM", - "SymbolicIntNode", + "SymIntNode", "TensorType", "ThroughputBenchmark", "TracingState", @@ -247,6 +247,8 @@ def test_no_new_bindings(self): "view_copy", "generated", "dynamic_output_shape", + "nondeterministic_bitwise", + "nondeterministic_seeded", } torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")} @@ -276,6 +278,12 @@ def test_correct_module_names(self): # no new entries should be added to this allow_dict. # New APIs must follow the public API guidelines. allow_dict = json.load(json_file) + # Because we want minimal modifications to the `allowlist_for_publicAPI.json`, + # we are adding the entries for the migrated modules here from the original + # locations. + for modname in allow_dict["being_migrated"]: + if modname in allow_dict: + allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[modname] def test_module(modname): split_strs = modname.split('.') @@ -294,8 +302,13 @@ def check_one_element(elem, modname, mod, *, is_public, is_all): why_not_looks_public = "" if elem_module is None: why_not_looks_public = "because it does not have a `__module__` attribute" + # If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}), + # the module's starting package would be referred to as the new location even + # if there is a "from foo import a" inside the "bar.py". + modname = allow_dict["being_migrated"].get(modname, modname) elem_modname_starts_with_mod = elem_module is not None and \ - elem_module.startswith(modname) and '._' not in elem_module + elem_module.startswith(modname) and \ + '._' not in elem_module if not why_not_looks_public and not elem_modname_starts_with_mod: why_not_looks_public = f"because its `__module__` attribute (`{elem_module}`) is not within the " \ f"torch library or does not start with the submodule where it is defined (`{modname}`)" diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index e926b33191f6f..de720e5e56e90 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -11,10 +11,9 @@ from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \ log_input, capture_logs, capture_logs_with_logging_tensor_mode from torch.utils._pytree import tree_map -from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode import logging -from functools import partial class TestPythonRegistration(TestCase): @@ -278,6 +277,12 @@ def _test(): test_helper("CONSERVATIVE") + def test_error_for_unsupported_ns_or_kind(self) -> None: + with self.assertRaisesRegex(ValueError, "Unsupported kind"): + my_lib1 = Library("myns", "BLA") + + with self.assertRaisesRegex(ValueError, "reserved namespace"): + my_lib1 = Library("prim", "DEF") class TestPythonDispatch(TestCase): def test_basic(self) -> None: @@ -352,23 +357,20 @@ def test_kwarg_only(self) -> None: def test_kwarg_only_and_positional_default(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.ones(1)) - y = LoggingTensor(torch.ones(1)) log_input("x", x) - log_input("y", y) - torch.ops.aten.kl_div(x, y) - torch.ops.aten.kl_div(x, y, 2) - torch.ops.aten.kl_div(x, y, log_target=True) - torch.ops.aten.kl_div(x, y, 2, log_target=True) + torch.ops.aten._foobar(x) + torch.ops.aten._foobar(x, False) + torch.ops.aten._foobar(x, arg3=False) + torch.ops.aten._foobar(x, False, arg3=False) - # What we are testing here is that we omit reduction + # What we are testing here is that we omit arg2 # if it is defaulted, even if a kwarg is set self.assertExpectedInline('\n'.join(logs), '''\ $0 = input('x') -$1 = input('y') -$2 = torch._ops.aten.kl_div.default($0, $1) -$3 = torch._ops.aten.kl_div.default($0, $1, 2) -$4 = torch._ops.aten.kl_div.default($0, $1, log_target=True) -$5 = torch._ops.aten.kl_div.default($0, $1, 2, log_target=True)''') +$1 = torch._ops.aten._foobar.default($0) +$2 = torch._ops.aten._foobar.default($0, False) +$3 = torch._ops.aten._foobar.default($0, arg3=False) +$4 = torch._ops.aten._foobar.default($0, False, arg3=False)''') def test_produce_real_type(self) -> None: with capture_logs() as logs: @@ -742,26 +744,33 @@ def test_enable_torch_dispatch_mode_error(self) -> None: def test_enable_torch_dispatch_mode_basic(self) -> None: with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): torch.empty([]) - self.assertExpectedInline('\n'.join(logs), """\ -$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32, device=device(type='cpu'), pin_memory=False)""") + + self.assertExpectedInline('\n'.join(logs), ("$0 = torch._ops.aten.empty.memory_format([], dtype=torch.float32," + + " device=device(type='cpu'), pin_memory=False)")) def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None: x = torch.randn([]) y = torch.randn([]) with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): x + y self.assertExpectedInline('\n'.join(logs), """\ $2 = torch._ops.aten.add.Tensor($0, $1)""") + def test_nested_push_regular(self): + with LoggingTensorMode.push() as mode: + # This previously errored + with LoggingTensorMode(): + pass + def test_nested_push_logging_tensor_mode(self): x = torch.randn([]) y = torch.randn([]) with capture_logs(is_mode=True) as logs: - with push_torch_dispatch_mode(LoggingTensorMode): - with push_torch_dispatch_mode(LoggingTensorMode): + with LoggingTensorMode(): + with LoggingTensorMode(): torch.empty([]) x + y @@ -843,12 +852,12 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None: with capture_logs(is_mode=True) as logs1: - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): torch.ones([2, 3]) with no_dispatch(): torch.ones([2, 3]) with capture_logs(is_mode=True) as logs2: - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): torch.ones([2, 3]) self.assertEqual(logs1, logs2) @@ -869,21 +878,21 @@ class A(LoggingTensorMode): pass with self.assertRaisesRegex(ValueError, "there is already an active mode"): - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): - with enable_torch_dispatch_mode(A(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): + with enable_torch_dispatch_mode(A()): pass # For nesting to be a noop, they need to be the same instance with self.assertRaisesRegex(ValueError, "there is already an active mode"): - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)): + with enable_torch_dispatch_mode(LoggingTensorMode()): + with enable_torch_dispatch_mode(LoggingTensorMode()): pass def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None: # "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance). # It's the equivalent of a noop, so it will only write once to the log x = torch.tensor([3.]) - mode = LoggingTensorMode(inner=None) + mode = LoggingTensorMode() with capture_logs(is_mode=True) as logs: log_input("x", x) with enable_torch_dispatch_mode(mode): @@ -900,8 +909,8 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): x = torch.tensor([3.]) with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(A(inner=None)): - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), ignore_preexisting=True): + with enable_torch_dispatch_mode(A()): + with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True): x + x self.assertExpectedInline('\n'.join(logs), """\ $1 = torch._ops.aten.add.Tensor($0, $0)""") @@ -912,10 +921,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise AssertionError x = torch.tensor([3.]) - outer_mode = A(inner=None) + outer_mode = A() with capture_logs(is_mode=True) as logs: with enable_torch_dispatch_mode(outer_mode): - with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), replace=outer_mode): + with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode): x + x self.assertExpectedInline('\n'.join(logs), """\ $1 = torch._ops.aten.add.Tensor($0, $0)""") @@ -939,58 +948,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): pass self.assertTrue(isinstance(torch.zeros(()), A)) - def test_push_torch_dispatch_mode(self) -> None: - class ErrorA(RuntimeError): - def __init__(self, msg=None): - return super().__init__(msg) - - class A(TorchDispatchMode): - def __init__(self, msg=None): - self.msg = msg - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - raise ErrorA(self.msg) - - x = torch.randn(3) - with self.assertRaises(ErrorA): - with push_torch_dispatch_mode(A): - torch.add(x, x) - - with self.assertRaisesRegex(ErrorA, r"partial constructor"): - with push_torch_dispatch_mode(partial(A, "partial constructor")): - x + x - - def test_torch_dispatch_mode_stack(self) -> None: - logs = [] - - class Logger(TorchDispatchMode): - def __init__(self, name): - self.name = name - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - logs.append(self.name) - return func(*args, **kwargs) - - x = torch.randn(1) - with Logger.push("A"): - with Logger.push("B"): - x + x - self.assertEqual(logs, ["B", "A"]) - - def test_push_mode_instance_errors(self): - class A(TorchDispatchMode): - pass - with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'): - with push_torch_dispatch_mode(A()): - pass - - def test_push_mode_returns_unrelated(self): - with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'): - with push_torch_dispatch_mode(lambda *, inner: None): - pass - def test_ctor_no_inner(self): class A(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -1094,6 +1051,65 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): assert self.assertRaisesRegex(RuntimeError, "subclass Mode but.* associated to a python object of type Mode") + def test_notimplemented_mode(self): + sub_count = 0 + + class PoliteMode(TorchDispatchMode): + def __init__(self): + self.pre_count = 0 + self.post_count = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.pre_count += 1 + if any(t is not torch.Tensor for t in types): + return NotImplemented + self.post_count += 1 + return func(*args, **kwargs) + + class SubTensor(torch.Tensor): + def __new__(cls, elem): + r = torch.Tensor._make_wrapper_subclass(cls, elem.shape) + r.elem = elem + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + nonlocal sub_count + sub_count += 1 + + def unwrap(t): + if isinstance(t, SubTensor): + return t.elem + else: + return t + + return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + + __torch_function__ = torch._C._disabled_torch_function_impl + + a = SubTensor(torch.randn(2)) + with PoliteMode() as mode: + a.abs() + + self.assertEqual(mode.pre_count, 2) + self.assertEqual(mode.post_count, 1) + self.assertEqual(sub_count, 1) + + # make sure this doesn't error + with PoliteMode(): + with PoliteMode(): + a.abs() + + def test_disable_mode(self): + class FailEverythingMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + raise RuntimeError("arf") + + with FailEverythingMode() as m: + self.assertRaises(RuntimeError, lambda: torch.ones([2, 3])) + with enable_torch_dispatch_mode(None, replace=m): + torch.ones([2, 3]) + def test_make_wrapper_subclass_with_modes(self): class ModeTensor(torch.Tensor): def __new__(cls, elem, mode): @@ -1682,7 +1698,7 @@ def __new__(cls, data, wrapper): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): - if func == torch.ops.aten.stride: + if func == torch.ops.aten.stride.default: return (4, 2) return NotImplemented @@ -1693,7 +1709,7 @@ def __new__(cls, data, wrapper): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): - if func == torch.ops.aten.stride: + if func == torch.ops.aten.stride.default: return None return NotImplemented @@ -1732,7 +1748,7 @@ def __new__(cls, data, wrapper): def __torch_dispatch__(cls, func, types, args, kwargs): if func.overloadpacket == torch.ops.aten.dim: return data.dim() - if func.overloadpacket == torch.ops.aten.size: + if func.overloadpacket == torch.ops.aten.sym_size: return (5, 3) return NotImplemented @@ -1745,13 +1761,13 @@ def __new__(cls, data, wrapper): def __torch_dispatch__(cls, func, types, args, kwargs): if func.overloadpacket == torch.ops.aten.dim: return data.dim() - if func.overloadpacket == torch.ops.aten.size: + if func.overloadpacket == torch.ops.aten.sym_size: return None return NotImplemented - err_msg = "no implementation found for 'torch.ops.aten.size'" + err_msg = "no implementation found for 'torch.ops.aten.sym_size'" e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass) - with self.assertRaisesRegex(TypeError, err_msg): + with self.assertRaisesRegex(RuntimeError, err_msg): e.size() e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass) @@ -1760,5 +1776,51 @@ def __torch_dispatch__(cls, func, types, args, kwargs): e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) self.assertEqual(e.size(), (4, 2)) + def test_layout_slow_path(self): + for use_wrapper_subclass in [True, False]: + data = torch.randn(6, 2) + + class LayoutNotImplemented(torch.Tensor): + @staticmethod + def __new__(cls, data, wrapper): + return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + return NotImplemented + + class LayoutCustomReturn(torch.Tensor): + @staticmethod + def __new__(cls, data, wrapper): + return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func.overloadpacket == torch.ops.prim.layout: + return torch.sparse_csr + return NotImplemented + + class LayoutDefaultReturn(torch.Tensor): + @staticmethod + def __new__(cls, data, wrapper): + return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func.overloadpacket == torch.ops.prim.layout: + return data.layout + return NotImplemented + + err_msg = "no implementation found for 'torch.ops.prim.layout'" + e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass) + with self.assertRaisesRegex(TypeError, err_msg): + e.layout + + e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass) + self.assertEqual(e.layout, torch.sparse_csr) + + e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) + self.assertEqual(e.layout, torch.strided) + if __name__ == '__main__': run_tests() diff --git a/test/test_quantization.py b/test/test_quantization.py index 47d25249c951c..16f1d2cd318a4 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -36,8 +36,9 @@ from quantization.core.test_workflow_module import TestHistogramObserver # noqa: F401 from quantization.core.test_workflow_module import TestDistributed # noqa: F401 from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401 +from quantization.core.test_backend_config import TestBackendConfig # noqa: F401 from quantization.core.test_utils import TestUtils # noqa: F401 - +from quantization.core.test_docs import TestQuantizationDocs # noqa: F401 # Eager Mode Workflow. Tests for the functionality of APIs and different features implemented # using eager mode. @@ -87,6 +88,10 @@ from quantization.fx.test_model_report_fx import TestFxModelReportDetector # noqa: F401 from quantization.fx.test_model_report_fx import TestFxModelReportObserver # noqa: F401 from quantization.fx.test_model_report_fx import TestFxModelReportDetectDynamicStatic # noqa: F401 + from quantization.fx.test_model_report_fx import TestFxModelReportClass # noqa: F401 + from quantization.fx.test_model_report_fx import TestFxDetectInputWeightEqualization # noqa: F401 + from quantization.fx.test_model_report_fx import TestFxDetectOutliers # noqa: F401 + from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer # noqa: F401 except ImportError: pass diff --git a/test/test_reductions.py b/test/test_reductions.py index 6f64dc316a9eb..f29bb56087bef 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -14,7 +14,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and, - integral_types_and, floating_and_complex_types_and, all_types_and, + integral_types_and, floating_and_complex_types_and, all_types_and, all_types, ) from torch.testing._internal.common_utils import ( TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict, @@ -825,15 +825,16 @@ def test_cumsum_integer_upcast(self, device): def test_cumprod_integer_upcast(self, device): self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) - def test_mode(self, device): + @dtypes(*all_types()) + def test_mode(self, device, dtype): SIZE = 10 - x = torch.arange(1., SIZE * SIZE + 1, device=device).clone().resize_(SIZE, SIZE) + x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE) x[:2] = 1 x[:, :2] = 1 x0 = x.clone() # Pre-calculated results. - res1val = torch.ones(SIZE, device=device) + res1val = torch.ones(SIZE, device=device, dtype=dtype) # The indices are the position of the last appearance of the mode element. res1ind = torch.ones(SIZE, device=device, dtype=torch.long) res1ind[0] = SIZE - 1 @@ -844,7 +845,7 @@ def test_mode(self, device): self.assertEqual(res1ind, res2ind, atol=0, rtol=0) # Test use of result tensor - res2val = torch.tensor((), device=device) + res2val = torch.tensor((), device=device, dtype=dtype) res2ind = torch.tensor((), device=device, dtype=torch.long) torch.mode(x, keepdim=False, out=(res2val, res2ind)) self.assertEqual(res1val, res2val, atol=0, rtol=0) @@ -858,10 +859,10 @@ def test_mode(self, device): # input unchanged self.assertEqual(x, x0, atol=0, rtol=0) - def _test_mode_intervals(self, shape, intervals, device, v=1): - x = torch.arange(0, shape[0] * shape[1], device=device) - x[v] = x.numel() - x = x.resize_(shape) + def _test_mode_intervals(self, shape, intervals, device, dtype, v=1): + x = torch.arange(0, shape[1], device=device, dtype=dtype).expand(shape) + x = x.contiguous() + x[:, v] = intervals[0][0] # Set the value of each interval to the mode "v" for (beg, end) in intervals: @@ -875,14 +876,15 @@ def _test_mode_intervals(self, shape, intervals, device, v=1): self.assertTrue((values == v).all().item()) @onlyCUDA - def test_mode_large(self, device): + @dtypes(*all_types_and(torch.half, torch.bfloat16)) + def test_mode_large(self, device, dtype): # i should be less than (d - 2) / 2 def testset_for_shape(shape, i): d = shape[-1] # Mode only in the middle. - self._test_mode_intervals(shape, [(i, d - i)], device) + self._test_mode_intervals(shape, [(i, d - i)], device, dtype) # Mode in discontiguous parts of the input. - self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device) + self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device, dtype) # More than one line of (65535) thread blocks testset_for_shape((65536, 10), 3) @@ -893,6 +895,32 @@ def testset_for_shape(shape, i): # Naive kernel for big slice sizes (> 2048) testset_for_shape((10, 4096), 10) + def test_mode_boolean(self, device): + shapes = [ + (10, 10), + (4, 2048), + (1, 4096), + ] + + for shape in shapes: + a = torch.zeros(shape, device=device, dtype=torch.bool) + + a[:, (shape[1] - 1) // 2:] = True + values, indices = a.mode(-1) + self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool)) + print(indices) + indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1) + self.assertEqual(values, indexed) + + a.fill_(False) + a[:, shape[1] // 2 + 1:] = True + values, indices = a.mode(-1) + print(indices) + self.assertEqual(values, torch.zeros(shape[0], dtype=torch.bool)) + indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1) + self.assertEqual(values, indexed) + + @expectedFailureMeta # mode only supports CPU and CUDA device type @onlyNativeDeviceTypes def test_mode_wrong_dtype(self, device): diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index fa0c14d4dcb70..e8d874b6a52b3 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -12,7 +12,7 @@ (instantiate_device_type_tests, dtypes, dtypesIfCUDA, toleranceOverride, tol,) from torch.testing._internal.common_dtype import \ - (get_all_dtypes, get_all_fp_dtypes,) + (get_all_dtypes,) # Protects against includes accidentally setting the default dtype assert torch.get_default_dtype() is torch.float32 @@ -208,7 +208,7 @@ def test_scatter_reduce_sum(self, device, dtype): include_self=include_self) @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True)) - @dtypesIfCUDA(*get_all_fp_dtypes(include_half=True, include_bfloat16=True)) + @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False)) def test_scatter_reduce_prod(self, device, dtype): for include_self in (True, False): self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype, @@ -216,7 +216,7 @@ def test_scatter_reduce_prod(self, device, dtype): include_self=include_self) @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False)) - @dtypesIfCUDA(*get_all_fp_dtypes(include_half=True, include_bfloat16=True)) + @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False)) def test_scatter_reduce_mean(self, device, dtype): for include_self in (True, False): self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype, @@ -224,7 +224,7 @@ def test_scatter_reduce_mean(self, device, dtype): include_self=include_self) @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False)) - @dtypesIfCUDA(*get_all_fp_dtypes(include_half=True, include_bfloat16=True)) + @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False)) def test_scatter_reduce_amax(self, device, dtype): for include_self in (True, False): self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype, @@ -243,7 +243,7 @@ def test_scatter_reduce_amax(self, device, dtype): @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False)) - @dtypesIfCUDA(*get_all_fp_dtypes(include_half=True, include_bfloat16=True)) + @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False)) def test_scatter_reduce_amin(self, device, dtype): for include_self in (True, False): self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype, diff --git a/test/test_serialization.py b/test/test_serialization.py index 2643b4bcad5cd..2d95dbee6f068 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -97,7 +97,7 @@ def _test_serialization_assert(self, b, c): self.assertTrue(isinstance(c[1], torch.FloatTensor)) self.assertTrue(isinstance(c[2], torch.FloatTensor)) self.assertTrue(isinstance(c[3], torch.FloatTensor)) - self.assertTrue(isinstance(c[4], torch.storage._TypedStorage)) + self.assertTrue(isinstance(c[4], torch.storage.TypedStorage)) self.assertEqual(c[4].dtype, torch.float) c[0].fill_(10) self.assertEqual(c[0], c[2], atol=0, rtol=0) @@ -370,7 +370,7 @@ def test_serialization_backwards_compat(self): self.assertTrue(isinstance(c[1], torch.FloatTensor)) self.assertTrue(isinstance(c[2], torch.FloatTensor)) self.assertTrue(isinstance(c[3], torch.FloatTensor)) - self.assertTrue(isinstance(c[4], torch.storage._TypedStorage)) + self.assertTrue(isinstance(c[4], torch.storage.TypedStorage)) self.assertEqual(c[4].dtype, torch.float32) c[0].fill_(10) self.assertEqual(c[0], c[2], atol=0, rtol=0) @@ -621,8 +621,8 @@ def save_load_check(a, b): a = torch.tensor([], dtype=dtype, device=device) for other_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): - s = torch._TypedStorage( - wrap_storage=a.storage()._untyped(), + s = torch.TypedStorage( + wrap_storage=a.storage().untyped(), dtype=other_dtype) save_load_check(a, s) save_load_check(a.storage(), s) @@ -653,8 +653,8 @@ def test_save_different_dtype_error(self): torch.save([a.storage(), a.imag.storage()], f) a = torch.randn(10, device=device) - s_bytes = torch._TypedStorage( - wrap_storage=a.storage()._untyped(), + s_bytes = torch.TypedStorage( + wrap_storage=a.storage().untyped(), dtype=torch.uint8) with self.assertRaisesRegex(RuntimeError, error_msg): @@ -898,6 +898,10 @@ def __setstate__(self, state): self.reloaded = True +class TestEmptySubclass(torch.Tensor): + ... + + class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): wrapped_tensor = torch.rand(2) @@ -956,6 +960,25 @@ def test_cloned_deepcopy(self, requires_grad): self.assertEqual(new_tensor.requires_grad, my_tensor.requires_grad) + def test_empty_class_serialization(self): + tensor = TestEmptySubclass([1.]) + # Ensures it runs fine + tensor2 = copy.copy(tensor) + + with BytesIOContext() as f: + torch.save(tensor, f) + f.seek(0) + tensor2 = torch.load(f) + + tensor = TestEmptySubclass() + # Ensures it runs fine + # Note that tensor.data_ptr() == 0 here + tensor2 = copy.copy(tensor) + + with BytesIOContext() as f: + torch.save(tensor, f) + f.seek(0) + tensor2 = torch.load(f) instantiate_device_type_tests(TestBothSerialization, globals()) diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index b6557eed0d257..c09d58ea4c8be 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -11,7 +11,7 @@ from torch._six import nan from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, torch_to_numpy_dtype_dict) + TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes, dtypesIfCUDA, largeTensorTest) @@ -63,6 +63,7 @@ def test_unbind(self, device): self.assertEqual(x.select(dim, i), res2[i]) # TODO: update to work on CUDA, too? + @skipIfTorchDynamo("TorchDynamo fails with an unknown error") @onlyCPU def test_tolist(self, device): list0D = [] @@ -475,6 +476,7 @@ def test_flip_numpy(self, device, dtype): @onlyCUDA # CPU is too slow @largeTensorTest('17GB') # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB + @largeTensorTest("81GB", "cpu") # even for CUDA test, sufficient system memory is required def test_flip_large_tensor(self, device): t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() torch_fn = partial(torch.flip, dims=(0,)) @@ -546,6 +548,7 @@ def test_rot90(self, device): self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) + @skipIfTorchDynamo("TorchDynamo fails with an unknown error") @dtypes(torch.cfloat, torch.cdouble) def test_complex_rot90(self, device, dtype): shape = self._rand_shape(random.randint(2, 4), 5, 10) diff --git a/test/test_sparse.py b/test/test_sparse.py index abaed1953d3fb..e0b50e1b3ed98 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -8,7 +8,7 @@ import unittest from torch.testing import make_tensor from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ - do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \ + do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ DeterministicGuard, first_sample from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number @@ -26,6 +26,9 @@ floating_and_complex_types_and, integral_types, floating_types_and, ) +if TEST_SCIPY: + import scipy.sparse + # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -786,6 +789,53 @@ def test_shape(sparse_dims, nnz, with_size): test_shape(4, 3, [7, 7, 7, 3, 3, 3, 0]) test_shape(4, 0, [0, 0, 7, 3, 3, 3, 0]) + @coalescedonoff + @dtypes(torch.double, torch.cdouble) + def test_permute(self, device, dtype, coalesced): + # trivial checks + s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse() + with self.assertRaisesRegex(RuntimeError, "does not match the length"): + s.permute(dims=(1, 0)) + with self.assertRaisesRegex(RuntimeError, "duplicate dims"): + s.permute(dims=(1, 1, 1)) + + def test_shape(sparse_dims, nnz, with_size): + ndim = len(with_size) + valid_sparse_dims = torch.arange(-ndim, -ndim + sparse_dims) + valid_dense_dims = torch.arange(-ndim + sparse_dims, 0) + + for dims in itertools.permutations(range(-ndim, 0)): + s = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] + d = self.safeToDense(s) + + dims_sparse, _ = torch.tensor(dims[:sparse_dims]).sort() + dims_dense, _ = torch.tensor(dims[sparse_dims:]).sort() + + if (valid_sparse_dims == dims_sparse).all() and (valid_dense_dims == dims_dense).all(): + # if valid permutation, test for correctness + s_permuted = s.permute(dims) + self.assertEqual(s_permuted, d.permute(dims)) + + # if s is coalesced, and perm does not touch 0-dim, + # the result has to be coalesced as well + if dims[0] == 0: + self.assertEqual(s_permuted.is_coalesced(), s.is_coalesced()) + else: + self.assertFalse(s_permuted.is_coalesced()) + + gradcheck(lambda t: t.permute(dims).to_dense(), s.requires_grad_(True), check_sparse_nnz=True) + else: + # otherwise check if exception is thrown + fail_message = "transpositions between sparse and dense dimensions are not allowed" + with self.assertRaisesRegex(RuntimeError, fail_message): + s.permute(dims) + + test_shape(2, 3, [2, 3, 4, 5]) + test_shape(2, 3, [2, 2, 0]) + # if nnz=0, it is not true that t == t.to_dense().to_sparse() + # unless t.sparse_dim == t.dim (i.e. t is not hybrid) + test_shape(3, 0, [0, 0, 2]) + @coalescedonoff @onlyCPU @dtypes(torch.double) @@ -973,6 +1023,21 @@ def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=N for i in range(sizes[d]): test_shape(1, 10, sizes, d, i) + @dtypes(*integral_types()) + def test_select_no_type_promotion(self, device, dtype): + # see https://github.com/pytorch/pytorch/issues/82150 + idx = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]]) + val = torch.ones(6, dtype=dtype) + s = torch.sparse_coo_tensor(idx, val, size=(3, 3)) + + for t in (s, s * torch.tensor(0, dtype=dtype)): + # empty checks + self.assertEqual(t.dtype, t[2].dtype) + self.assertEqual(t.dtype, t[0, 1].dtype) + # sum should not promote + self.assertEqual(t.dtype, t[0, 0].dtype) + self.assertEqual(t.dtype, t[1, 1].dtype) + @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_index_select(self, device, dtype, coalesced): @@ -1397,6 +1462,27 @@ def fn(S, D): test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) + @coalescedonoff + @dtypes(torch.double) + def test_sparse_mul(self, device, dtype, coalesced): + # https://github.com/pytorch/pytorch/issues/79914 + a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) + b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) + gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], check_sparse_nnz=True) + + def test_shape(sparse_dims, nnz, with_shape): + a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) + b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) + + self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense()) + gradcheck(lambda x, y: (x * y).to_dense(), [a, b], check_sparse_nnz=True) + # Issues with 0-dim indices/values + gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], check_sparse_nnz=True) + + # TODO: Re-enable these + # test_shape(2, 3, [2, 3, 4, 5]) + # test_shape(2, 3, [2, 2, 0]) + @coalescedonoff @dtypes(torch.double) def test_dsmm(self, device, dtype, coalesced): @@ -3312,6 +3398,13 @@ def sparse_log(x): test_op(3, 100, [3, 4, 2, 3, 5, 2], coalesced) test_op(4, 100, [3, 4, 2, 3, 5, 2], coalesced) + + @dtypes(torch.double) + def test_softmax_zero_nnz(self, device, dtype): + t = torch.sparse_coo_tensor([[]], [], (3,), device=device, dtype=dtype) + out = torch.sparse.softmax(t, 0) + self.assertEqual(out.to_dense(), torch.zeros_like(t)) + # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA @skipIfRocm @coalescedonoff @@ -3425,26 +3518,6 @@ def assign_to(): self.assertRaises(TypeError, assign_to) - def test_cpu_sparse_dense_mul(self, device): - # general multiplication is not supported, but 0dim multiplication is supported - s = torch.sparse_coo_tensor([[0], [1]], [5.0], (2, 3), device=device) - t23 = s.to_dense() - t0 = torch.tensor(2.0, device=device) - r = s * 2.0 - self.assertEqual(r, 2.0 * s) - self.assertEqual(r, t0 * s) - self.assertEqual(r, s * t0) - if device == 'cpu': - with self.assertRaisesRegex(RuntimeError, r"mul\(sparse, dense\) is not supported"): - s * t23 - with self.assertRaisesRegex(RuntimeError, r"mul\(dense, sparse\) is not supported"): - t23 * s - elif device == 'cuda': - with self.assertRaisesRegex(NotImplementedError, "CUDA"): - s * t23 - with self.assertRaisesRegex(NotImplementedError, "CUDA"): - t23 * s - @dtypes(torch.double, torch.cdouble) def test_full_broadcast_to(self, device, dtype): def can_broadcast(s0, s1): @@ -3490,6 +3563,168 @@ def test(sparse_dims, nnz, with_size, new_size): test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3]) test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3]) + @coalescedonoff + @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) + @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) + def test_sparse_dense_mul(self, device, dtype, coalesced): + skipTestIfUncoalesced = False + # This case always coalesce inputs and that could lead to loss of precision, + # hence it is inhibited for float16/bfloat16 by providing already coalesced tensors. + if not coalesced and dtype in {torch.float16, torch.bfloat16}: + skipTestIfUncoalesced = True + # to_dense is problematic for boolean non-coalesced CUDA tensors + # see https://github.com/pytorch/pytorch/issues/81648 + if not coalesced and dtype == torch.bool and torch.device(device).type == "cuda": + skipTestIfUncoalesced = True + + if skipTestIfUncoalesced: + self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs") + + shape = (2, 3, 4, 10) + nnz = 10 + + def check(self, s, d): + res = d * s + + # check commutativity + self.assertEqual(res, s * d) + + # check correctness + self.assertEqual(res.to_dense(), s.to_dense() * d) + + # check in-placeness for dense + if d.dim() >= s.dim(): + dc = d.clone() + self.assertEqual(d.mul_(s), dc.mul_(s.to_dense())) + + # check in-placeness for sparse + if s.dim() >= d.dim(): + # for sparse + sc = s.clone() + self.assertEqual(s.mul_(d).to_dense(), sc.to_dense().mul_(d)) + + for dim in range(len(shape) + 1): + sub_shape = shape[dim:] + sparse_dim = len(sub_shape) // 2 + + def check_empty(sparse_shape, nnz, dense_shape, coalesce): + from itertools import product + for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))): + empty_sparse_shape = sparse_shape + shape_suffix + empty_dense_shape = dense_shape + shape_suffix + s = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0] + d = make_tensor(empty_dense_shape, dtype=dtype, device=device) + check(self, s, d) + + # check scalar multiplication + s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] + for scalar in (True, 1, 1.0): + res_sparse = s * scalar + res_dense = s.to_dense() * scalar + # check correctness and dtype + self.assertEqual(s.to(res_sparse.dtype), res_sparse) + self.assertEqual(res_sparse.dtype, res_dense.dtype) + + # Case 1: sparse broadcasts over dense + s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] + d = make_tensor(shape, dtype=dtype, device=device) + check(self, s, d) + check_empty(sub_shape, nnz, shape, coalesced) + + # Case 2: dense broadcasts over sparse + s = self._gen_sparse(3, nnz, shape, dtype, device, coalesced)[0] + d = make_tensor(sub_shape, dtype=dtype, device=device) + check(self, s, d) + check_empty(shape, nnz, sub_shape, coalesced) + + @unittest.skipIf(not TEST_NUMPY, "NumPy is not availible") + @onlyCPU + @dtypes(*all_types_and_complex_and(torch.bool)) + def test_sparse_spdiags(self, device, dtype): + + make_diags = functools.partial(make_tensor, dtype=dtype, device=device) + make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device) + + if TEST_SCIPY: + def reference(diags, offsets, shape): + return scipy.sparse.spdiags(diags, offsets, *shape).toarray() + + else: + def reference(diags, offsets, shape): + result = torch.zeros(shape, dtype=dtype, device=device) + for i, off in enumerate(offsets): + res_view = result.diagonal(off) + data = diags[i] + if off > 0: + data = data[off:] + + m = min(res_view.shape[0], data.shape[0]) + res_view[:m] = data[:m] + return result + + def check_valid(diags, offsets, shape, layout=None): + ref_out = reference(diags, offsets, shape) + out = torch.sparse.spdiags(diags, offsets, shape, layout=layout) + if layout is None: + ex_layout = torch.sparse_coo + else: + ex_layout = layout + out_dense = out.to_dense() + self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}") + self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}") + + def check_invalid(args, error): + with self.assertRaisesRegex(RuntimeError, error): + torch.sparse.spdiags(*args) + + def valid_cases(): + # some normal cases + yield (make_diags((1, 5)), make_offsets([0]), (5, 5)) + yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4)) + # noncontigous diags + yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5)) + # noncontigous offsets + yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) + # noncontigous diags + offsets + yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) + # correct dimensionality, 2d, 2d , and shapes match, but the number of diagonals is zero + yield (make_diags((0, 3)), make_offsets([]), (3, 3)) + # forward rotation of upper diagonals + yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4)) + # rotation exausts input space to read from + yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3)) + # Simple cases repeated with special output format + yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc) + yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr) + # vector diags + yield (make_diags((3, )), make_offsets([1]), (4, 4)) + # Scalar offset + yield (make_diags((1, 3)), make_offsets(2), (4, 4)) + # offsets out of range + yield (make_diags((1, 3)), make_offsets([3]), (3, 3)) + yield (make_diags((1, 3)), make_offsets([-3]), (3, 3)) + + for case in valid_cases(): + check_valid(*case) + + def invalid_cases(): + yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d" + yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector" + yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix" + yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\ + r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" + yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\ + r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" + yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\ + r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+" + yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values" + yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+" + + + for case, error_regex in invalid_cases(): + check_invalid(case, error_regex) + + class TestSparseOneOff(TestCase): @unittest.skipIf(not TEST_CUDA, 'CUDA not available') @@ -3685,6 +3920,37 @@ def test_future_empty_dim(self, device, dtype, op): self.assertEqual(actual, expected) +class TestSparseMeta(TestCase): + exact_dtype = True + + def test_basic(self): + r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta') + self.assertTrue(r.is_meta) + self.assertEqual(r.device.type, "meta") + r2 = torch.empty_like(r) + self.assertTrue(r2.is_meta) + self.assertEqual(r, r2) + r3 = torch.sparse_coo_tensor(size=(4, 4), device='meta') + self.assertTrue(r3.is_meta) + self.assertEqual(r, r3) + r.sparse_resize_((4, 4), 1, 1) + r.sparse_resize_and_clear_((4, 4, 4), 2, 1) + self.assertEqual(r.sparse_dim(), 2) + self.assertEqual(r.dense_dim(), 1) + self.assertEqual(r._dimV(), 1) + self.assertEqual(r._nnz(), 0) + # TODO: nnz zero sparse tensors should always be coalesced... + self.assertEqual(r.is_coalesced(), False) + r._coalesced_(True) + self.assertEqual(r.is_coalesced(), True) + # TODO: this sort of aliasing will need to be handled by + # functionalization + self.assertEqual(r._indices(), torch.empty(2, 0, device='meta', dtype=torch.int64)) + self.assertEqual(r._values(), torch.empty(0, 4, device='meta')) + self.assertEqual(r.indices(), torch.empty(2, 0, device='meta', dtype=torch.int64)) + self.assertEqual(r.values(), torch.empty(0, 4, device='meta')) + + # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta') diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index ef15c5be8e3e5..b9423763795d1 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1,20 +1,22 @@ # Owner(s): ["module: sparse"] +import copy import torch import random import itertools import unittest +import functools from torch.testing import make_tensor from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC from torch.testing._internal.common_utils import \ - (TEST_WITH_ROCM, TEST_SCIPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize, + (TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize, subtest) from torch.testing._internal.common_device_type import \ (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric, precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse) from torch.testing._internal.common_methods_invocations import \ (op_db, sparse_csr_unary_ufuncs, ReductionOpInfo) -from torch.testing._internal.common_cuda import _get_torch_cuda_version, CUDA11OrLater +from torch.testing._internal.common_cuda import _get_torch_cuda_version, CUDA11OrLater, TEST_CUDA from torch.testing._internal.common_dtype import ( floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and, all_types_and_complex, floating_and_complex_types_and @@ -24,6 +26,8 @@ if TEST_SCIPY: import scipy.sparse as sp +if TEST_NUMPY: + import numpy as np # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -171,46 +175,146 @@ def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_ device = self.device_type return self.genSparseCompressedTensor(size, nnz, device=device, dtype=dtype, index_dtype=index_dtype, layout=layout) - def _generate_small_inputs(self, layout, device, dtype, index_dtype): + def _generate_small_inputs_utils(self, layout, device=None, dtype=None): + + def shape(shape, basedim=0, blocksize=(1, 1), dense_shape=()): + # Below, we define compressed and plain indices that + # correspond to row compressed tensors. In order to reuse + # the indices tensors for column compressed tensors, we + # swap the row and columns in shape dims (basedim and + # basedim + 1, respectively) to obtain the correct shape + # for column compressed tensors. Batch and dense + # dimensions remain as they are. + # + # Similarly, we reuse indices of non-block tensors for + # block tensors, that means, we'll need to multiply the + # base shape of the non-block tensor with blocksize to get + # the base shape of a block tensor. + if layout is torch.sparse_csc: + shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:] + elif layout is torch.sparse_bsc: + shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:] + elif layout is torch.sparse_bsr: + shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:] + return shape + + def values(lst, basedim=0, blocksize=(1, 1), densesize=(), device=device, dtype=dtype): + # Below, we define values for non-blocked and non-hybrid + # tensors. To reuse these for blocked tensors, we replace + # all values in lst with a double-list that "shape" + # corresponds to blocksize. + # To support hybrid tensors, the values in lst are further + # replaced with a N-list where N==len(densesize) and the + # shape corresponds to densesize. + + max_val = torch.iinfo(dtype).max if dtype in [torch.int16, torch.int8, torch.uint8] else None + + def list_add(lst, value): + # recursively add a value to lst items + if isinstance(lst, list): + return [list_add(item, value) for item in lst] + rc = lst + value + return rc if max_val is None else (rc % max_val) + + def stretch_values(value, bdim, values_item_shape): + # replace a value with a new value that extends the + # dimensionality of the value by + # len(values_item_shape) from right. The left + # dimensions up to bdim are considered as batch + # dimensions. + if not values_item_shape: + return value + if isinstance(value, list) and bdim >= 0: + return [stretch_values(item, bdim - 1, values_item_shape) for item in value] + new_value = functools.reduce(lambda x, dims: [copy.deepcopy(x) for _ in range(dims)], + reversed(values_item_shape), None) + for p in itertools.product(*map(list, map(range, values_item_shape))): + row = functools.reduce(lambda x, i: x.__getitem__(i), p[:-1], new_value) + row[p[-1]] = list_add(value, sum([i * 10 ** d for d, i in enumerate(p)])) + return new_value + + if layout is torch.sparse_bsr: + values_item_shape = blocksize + densesize + elif layout is torch.sparse_bsc: + values_item_shape = tuple(reversed(blocksize)) + densesize + else: + values_item_shape = densesize + + if not lst: + return torch.tensor(lst, device=device, dtype=dtype).reshape(0, *values_item_shape) + + lst = stretch_values(lst, basedim, values_item_shape) + + return torch.tensor(lst, device=device, dtype=dtype) + + return shape, values + + def _generate_small_inputs(self, layout, device=None, dtype=None, index_dtype=None, + enable_batched=True, enable_hybrid=True): """Generator of inputs to sparse compressed tensor factory functions. The input is defined as a 4-tuple: compressed_indices, plain_indices, values, expected_size_from_shape_inference """ - from operator import mul - from functools import reduce - if layout in {torch.sparse_csr, torch.sparse_csc}: + if index_dtype is None: + index_dtype = torch.int64 + + shape, values = self._generate_small_inputs_utils(layout, device, dtype) + + # a regular tensor + yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), + torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), + values([1, 2, 3, 4], 0, (2, 1)), + shape((2, 3), 0, (2, 1))) + + # a tensor with zero dimensions + yield (torch.tensor([0, ], device=device, dtype=index_dtype), + torch.tensor([], device=device, dtype=index_dtype), + values([], 0, (2, 1)), + shape((0, 0), 0, (2, 1))) + + if enable_batched: + # a batched tensor with one batch dimension + yield (torch.tensor([[0, 2, 4], [0, 3, 4]], device=device, dtype=index_dtype), + torch.tensor([[0, 1, 0, 1], [0, 1, 2, 0]], device=device, dtype=index_dtype), + values([[1, 2, 3, 4], [5, 6, 7, 8]], 1, (1, 2)), + shape((2, 2, 3), 1, (1, 2))) + + # a batched tensor with two batch dimensions + yield (torch.tensor([[[0, 2, 4], [0, 3, 4], [0, 1, 4]], + [[0, 1, 4], [0, 2, 4], [0, 3, 4]]], + device=device, dtype=index_dtype), + torch.tensor([[[0, 1, 0, 1], [0, 1, 2, 0], [0, 0, 1, 2]], + [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]], + device=device, dtype=index_dtype), + values([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 2, (2, 3)), + shape((2, 3, 2, 3), 2, (2, 3))) + + if enable_hybrid: + # a tensor with one dense dimension yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), - torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype), - torch.tensor([1, 2, 3, 4], device=device, dtype=dtype), - (2, 2)) - yield (torch.tensor([0, ], device=device, dtype=index_dtype), - torch.tensor([], device=device, dtype=index_dtype), - torch.tensor([], device=device, dtype=dtype), - (0, 0)) - for batch_shape in [(2,), (2, 3)]: - prod = reduce(mul, batch_shape, 1) - yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1), - torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1), - torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1), - (*batch_shape, 2, 2)) - else: - assert layout in {torch.sparse_bsr, torch.sparse_bsc} + torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), + values([1, 2, 3, 4], 0, (3, 2), (2,)), + shape((2, 3, 2), 0, (3, 2))) + + # a tensor with two dense dimensions yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), - torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype), - torch.tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 44]]], device=device, dtype=dtype), - (2, 4)) - yield (torch.tensor([0, ], device=device, dtype=index_dtype), - torch.tensor([], device=device, dtype=index_dtype), - torch.tensor([], device=device, dtype=dtype).reshape(1, 0, 0), - (0, 0)) - for batch_shape in [(2,), (2, 3)]: - prod = reduce(mul, batch_shape, 1) - yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1), - torch.tensor([0, 1, 0, 1], device=device, dtype=index_dtype).repeat(prod, 1).reshape(*batch_shape, -1), - torch.tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 44]]], - device=device, dtype=dtype).repeat(prod, 1, 1).reshape(*batch_shape, 4, 1, 2), - (*batch_shape, 2, 4)) + torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), + values([1, 2, 3, 4], 0, (2, 3), (4, 2)), + shape((2, 3, 4, 2), 0, (2, 3))) + + if enable_batched and enable_hybrid: + # a batched tensor with two batch dimensions and two dense dimensions + yield (torch.tensor([[[0, 2, 4], [0, 3, 4], [0, 1, 4]], + [[0, 1, 4], [0, 2, 4], [0, 3, 4]]], + device=device, dtype=index_dtype), + torch.tensor([[[0, 1, 0, 1], [0, 1, 2, 0], [0, 0, 1, 2]], + [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]], + device=device, dtype=index_dtype), + values([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 2, (3, 2), (2, 1)), + shape((2, 3, 2, 3, 2, 1), 2, (3, 2))) @all_sparse_compressed_layouts() @onlyCPU @@ -324,41 +428,37 @@ def test_clone(self, layout, device, dtype): def test_print(self, layout, device): compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout] printed = [] - for index_dtype in [torch.int32, torch.int64]: - for dtype in [torch.float32, torch.float64]: - for compressed_indices, plain_indices, values, size in self._generate_small_inputs( - layout, device, dtype, index_dtype): - batch_shape = tuple(size[:-2]) - block_shape = tuple(values.shape[-2:]) if layout in {torch.sparse_bsr, torch.sparse_bsc} else () - blocksize0, blocksize1 = block_shape if layout in {torch.sparse_bsr, torch.sparse_bsc} else (1, 1) - if size not in [(2 * blocksize0, 2 * blocksize1), (0, 0), - (2, 3, 2 * blocksize0, 2 * blocksize1), (2, 2 * blocksize0, 2 * blocksize1)]: - # Skip inputs that are not in the list of - # expected sizes to ensure the stability of - # test_print in the case - # _generate_small_inputs is extended with new - # inputs - continue - if block_shape not in [(), (0, 0), (1, 2)]: - # Skip inputs that are not in the list of - # expected block sizes to ensure test_print - # stability. - continue - printed.append("########## {}/{}/batch_shape={}/block_shape={} ##########".format( - dtype, index_dtype, batch_shape, block_shape)) - x = torch.sparse_compressed_tensor(compressed_indices, - plain_indices, - values, dtype=dtype, layout=layout, device=device) - printed.append("# sparse tensor") - printed.append(str(x)) - printed.append(f"# _{compressed_indices_mth.__name__}") - printed.append(str(compressed_indices_mth(x))) - printed.append(f"# _{plain_indices_mth.__name__}") - printed.append(str(plain_indices_mth(x))) - printed.append("# _values") - printed.append(str(x.values())) + for enable_hybrid in [False, True]: + for index_dtype in [torch.int32, torch.int64]: + for dtype in [torch.float32, torch.float64]: + for compressed_indices, plain_indices, values, size in self._generate_small_inputs( + layout, device, dtype, index_dtype, enable_hybrid=enable_hybrid): + block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0 + base_ndim = 2 + batch_ndim = compressed_indices.dim() - 1 + dense_ndim = values.dim() - batch_ndim - block_ndim - 1 + if enable_hybrid and dense_ndim == 0: + # non-hybrid cases are covered by the enable_hybrid==False loop + continue + batchsize = size[:batch_ndim] + basesize = size[batch_ndim:batch_ndim + base_ndim] + densesize = size[batch_ndim + base_ndim:] + assert len(densesize) == dense_ndim + printed.append("########## {}/{}/size={}+{}+{} ##########".format( + dtype, index_dtype, batchsize, basesize, densesize)) + x = torch.sparse_compressed_tensor(compressed_indices, + plain_indices, + values, size, dtype=dtype, layout=layout, device=device) + printed.append("# sparse tensor") + printed.append(str(x)) + printed.append(f"# _{compressed_indices_mth.__name__}") + printed.append(str(compressed_indices_mth(x))) + printed.append(f"# _{plain_indices_mth.__name__}") + printed.append(str(plain_indices_mth(x))) + printed.append("# _values") + printed.append(str(x.values())) + printed.append('') printed.append('') - printed.append('') orig_maxDiff = self.maxDiff self.maxDiff = None try: @@ -446,6 +546,11 @@ def test_consistency(self, layout, device, dtype, op): or layout == torch.sparse_bsc and op.supports_sparse_bsc): self.skipTest(f"{op.name} does not support input with {layout} layout") + # FIXME: remove in followup once integer support is landed for segment_reduce + if (layout == torch.sparse_csr and not dtype.is_floating_point + and op.name in ('_masked.mean', '_masked.amax', '_masked.amin')): + self.skipTest(f"{op.name} does not support input with {layout} layout") + require_mask = isinstance(op, ReductionOpInfo) and '_masked.' in op.name if require_mask and layout in {torch.sparse_bsr, torch.sparse_bsc}: self.skipTest(f"{op.name} does not support input with {layout} layout") @@ -493,7 +598,8 @@ def test_consistency(self, layout, device, dtype, op): assert torch.is_tensor(output) strided_output = output.to_dense() if require_mask: - expected *= torch._masked._output_mask(op.op, sample.input, **sample.kwargs) + output_mask = torch._masked._output_mask(op.op, sample.input, **sample.kwargs) + expected.masked_fill_(~output_mask, 0) self.assertEqual(strided_output, expected) count += 1 @@ -501,6 +607,284 @@ def test_consistency(self, layout, device, dtype, op): if not count: raise ValueError("Expected at least one sample with keepdim and/or explicit mask for reductions.") + @skipMeta + @all_sparse_compressed_layouts() + @all_sparse_compressed_layouts('layout2') + @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) + def test_empty_like(self, layout, layout2, device, dtype): + for compressed_indices, plain_indices, values, size in self._generate_small_inputs(layout): + sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, + dtype=dtype, layout=layout, device=device) + if layout == layout2: + result = torch.empty_like(sparse, layout=layout2) + compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout] + torch._validate_sparse_compressed_tensor_args(compressed_indices_mth(result), + plain_indices_mth(result), + result.values(), + result.shape, + result.layout) + self.assertEqual(sparse.shape, result.shape) + else: + self.assertRaisesRegex( + RuntimeError, + "empty_like with different sparse layout is not supported", + lambda: torch.empty_like(sparse, layout=layout2) + ) + + @skipMeta + @all_sparse_compressed_layouts() + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_validate(self, layout, device, dtype): + for index_dtype in [torch.int32, torch.int64]: + for compressed_indices, plain_indices, values, size in self._generate_small_inputs( + layout, device, dtype, index_dtype, enable_batched=True, enable_hybrid=True): + torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout) + + def _generate_invalid_input(self, layout, device): + from functools import partial + + shape, values = self._generate_small_inputs_utils(layout, device=device) + + tensor = partial(torch.tensor, device=device) + values = partial(values, device=device) + + yield ('incontiguous compressed_indices', + tensor([0, -1, 2, -1, 4, -1])[::2], + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + 'expected compressed_indices to be a strided and contiguous tensor') + + yield ('incontiguous plain_indices', + tensor([0, 2, 4]), + tensor([0, -1, 1, -1, 0, -1, 2, -1])[::2], + values([1, 2, 3, 4]), + shape((2, 3)), + 'expected plain_indices to be a strided and contiguous tensor') + + yield ('incontiguous values', + tensor([0, 2, 4]), + tensor([0, 1, 0, 2]), + values([1, 1, 2, 2, 3, 3, 4, 4])[::2], + shape((2, 3)), + 'expected values to be a strided and contiguous tensor') + + yield ('0-D compressed_indices', + tensor(0), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + 'compressed_indices must have dimensionality >= 1 but got 0') + + yield ('compressed/plain_indices mismatch of dimensionalites', + tensor([[0, 2, 4]]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + 'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively') + + if layout in {torch.sparse_csr, torch.sparse_csc}: + yield ('indices and values mismatch of dimensionalites', + tensor([[0, 2, 4]]), + tensor([[0, 1, 0, 2]]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1') + else: + yield ('indices and values mismatch of dimensionalites', + tensor([[0, 2, 4]]), + tensor([[0, 1, 0, 2]]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 2\) but got 3') + + yield ('invalid size', + tensor([0, 2, 4]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + (2,), + r'tensor dimensionality must be sum of batch, base, and dense dimensionalites \(=0 \+ 2 \+ 0\) but got 1') + + yield ('invalid batchsize', + tensor([[0, 2, 4]]), + tensor([[0, 1, 0, 2]]), + values([[1, 2, 3, 4]]), + shape((2, 2, 3), 1), + r'all batch dimensions of compressed_indices \(=\[1\]\), plain_indices \(=\[1\]\), ' + r'and values \(=\[1\]\) must be equal to tensor batch dimensions \(=\[2\]\)') + + if layout is torch.sparse_bsr: + yield ('invalid blocksize', + tensor([0, 2, 4]), + tensor([0, 1, 0, 2]), + tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]), + shape((2, 3)), + r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape') + + if layout is torch.sparse_bsc: + yield ('invalid blocksize', + tensor([0, 2, 4]), + tensor([0, 1, 0, 2]), + tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]), + shape((3, 2)), + r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape') + + yield ('invalid compressed_indices shape', + tensor([0, 2, 3, 4]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'compressed_indices.shape\[-1\] must be equal to the number of compressed_indices_names \+ 1 \(=3\), but got 4') + + yield ('invalid compressed_indices shape', + tensor([0, 2, 4]), + tensor([0, 1, 0, 1, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'plain_indices.shape\[-1\] must be equal to nnz \(=4\) as defined by values.shape\[0\], but got 5') + + yield ('compressed/plain_indices mismatch of dtype', + tensor([0, 2, 4], dtype=torch.int32), + tensor([0, 1, 0, 2], dtype=torch.int64), + values([1, 2, 3, 4]), + shape((2, 3)), + r'compressed_indices and plain_indices must have the same dtype, bot got Int and Long, respectively') + + yield ('invalid compressed/plain_indices dtype', + tensor([0, 2, 4], dtype=torch.int16), + tensor([0, 1, 0, 2], dtype=torch.int16), + values([1, 2, 3, 4]), + shape((2, 3)), + r'compressed_indices and plain_indices dtype must be Int or Long, but got Short') + + # CUDA kernel asserts are not recoverable, so we skip these for now + if torch.device(device).type == 'cpu': + yield ('invalid compressed_indices[0]', + tensor([1, 2, 4]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'`compressed_indices\[..., 0\] == 0` is not satisfied.') + + yield ('invalid compressed_indices[-1]', + tensor([0, 2, 5]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'`compressed_indices\[..., -1\] == nnz` is not satisfied.') + + yield ('invalid compressed_indices.diff(dim=-1)', + tensor([0, 0, 4]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.') + + yield ('invalid compressed_indices.diff(dim=-1)', + tensor([0, 5, 4]), + tensor([0, 1, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.') + + yield ('invalid min(plain_indices)', + tensor([0, 2, 4]), + tensor([0, -1, 0, 3]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'`0 <= plain_indices < plain_dim` is not satisfied.') + + yield ('invalid max(plain_indices)', + tensor([0, 2, 4]), + tensor([0, 1, 0, 3]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'`0 <= plain_indices < plain_dim` is not satisfied.') + + yield ('non-coalesced', + tensor([0, 2, 4]), + tensor([1, 0, 0, 2]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'`plain_indices\[..., compressed_indices\[..., i - 1\]:compressed_indices\[..., i\]\] ' + 'for all i = 1, ..., compressed_dim ' + 'are sorted and distinct along the last dimension values` is not satisfied.') + + if TEST_CUDA and torch.device(device).type == 'cpu': + yield ('indices and values mismatch of device', + torch.tensor([0, 2, 4]), + torch.tensor([0, 1, 0, 1]), + values([1, 2, 3, 4], device='cuda'), + shape((2, 3)), + r'device of compressed_indices \(=cpu\) must match device of values \(=cuda:0\)') + yield ('compressed_indices and values mismatch of device', + torch.tensor([0, 2, 4], device='cuda'), + torch.tensor([0, 1, 0, 1]), + values([1, 2, 3, 4]), + shape((2, 3)), + r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!') + yield ('compressed/plain_indices mismatch of device', + torch.tensor([0, 2, 4], device='cuda'), + torch.tensor([0, 1, 0, 1]), + values([1, 2, 3, 4], device='cuda'), + shape((2, 3)), + r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!') + + @skipMeta + @all_sparse_compressed_layouts() + @parametrize('target', [subtest('validate_sparse_compressed_tensor_args'), + subtest('sparse_compressed_tensor'), + subtest('sparse_compressed_tensor_no_size')]) + def test_invalid_input(self, layout, device, target): + for label, compressed_indices, plain_indices, values, size, errmsg in self._generate_invalid_input(layout, device): + if layout is torch.sparse_bsr: + errmsg = errmsg.replace('compressed_indices_name', 'row block').replace('plain_indices_name', 'column block') + elif layout is torch.sparse_bsc: + errmsg = errmsg.replace('compressed_indices_name', 'column block').replace('plain_indices_name', 'row block') + elif layout is torch.sparse_csr: + errmsg = errmsg.replace('compressed_indices_name', 'row').replace('plain_indices_name', 'column') + elif layout is torch.sparse_csc: + errmsg = errmsg.replace('compressed_indices_name', 'column').replace('plain_indices_name', 'row') + if layout in {torch.sparse_csr, torch.sparse_bsr}: + errmsg = errmsg.replace('compressed_indices', 'crow_indices') \ + .replace('plain_indices', 'col_indices') \ + .replace('plain_dim', 'ncols') \ + .replace('compressed_dim', 'nrows') + else: + errmsg = errmsg.replace('compressed_indices', 'ccol_indices') \ + .replace('plain_indices', 'row_indices') \ + .replace('plain_dim', 'nrows') \ + .replace('compressed_dim', 'ncols') + + if target == 'sparse_compressed_tensor_no_size' and label in { + 'invalid size', 'invalid batchsize', 'invalid compressed_indices shape', 'invalid max(plain_indices)', + 'invalid blocksize'}: + # Skip invalid size input as a valid size is estimated for other inputs + continue + + with self.assertRaisesRegex(RuntimeError, errmsg): + if target == 'validate_sparse_compressed_tensor_args': + torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout) + elif target == 'sparse_compressed_tensor': + torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout) + elif target == 'sparse_compressed_tensor_no_size': + torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=layout) + else: + raise NotImplementedError(target) + + @skipMeta + @onlyCPU + @all_sparse_compressed_layouts() + def test_dim(self, layout): + for compressed_indices, plain_indices, values, size in self._generate_small_inputs(layout): + batch_dim = compressed_indices.dim() - 1 + sparse_dim = 2 + block_dim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0 + dense_dim = values.dim() - batch_dim - block_dim - 1 + sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout) + self.assertEqual(sparse.sparse_dim(), sparse_dim) + self.assertEqual(sparse.dense_dim(), dense_dim) + class TestSparseCSR(TestCase): @@ -573,6 +957,35 @@ def test_sparse_csr_select(self, device, dtype): with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"): sparse[0, 0, 0, 0] = 99.0 + @parametrize("index_dtype", [torch.int32, torch.int64]) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + def test_sparse_bsr_select(self, device, dtype, index_dtype): + shape = (3, 6, 10) + nnz = 6 + sparse = self.genSparseBSRTensor(shape, (2, 2), nnz, dtype=dtype, device=device, index_dtype=index_dtype) + + # select from batch dimensions + sparse_selected02 = sparse.select(0, 2) + expected_sparse_selected02 = torch.sparse_bsr_tensor(sparse.crow_indices().select(0, 2).contiguous(), + sparse.col_indices().select(0, 2).contiguous(), + sparse.values().select(0, 2).contiguous(), + size=(6, 10), + dtype=dtype, + device=device) + self.assertEqual(expected_sparse_selected02, sparse_selected02) + + msg = "selecting non-batch dimensions is currently only supported for CSR tensors" + # selecting from rows or columns for batched CSR is not yet implemented + with self.assertRaisesRegex(RuntimeError, msg): + sparse.select(-2, 0) + + with self.assertRaisesRegex(RuntimeError, msg): + sparse.select(-1, 0) + + # assigning to sparse via indexing is disabled + with self.assertRaisesRegex(RuntimeError, msg): + sparse[0, 0, 0] = 99.0 + @skipMeta @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_resize(self, device, dtype): @@ -621,161 +1034,6 @@ def test_resize_errors(self, device, dtype): new_shape = (2, 2) a.resize_(new_shape) - def test_factory_type_invariants_check(self, device): - with self.assertRaisesRegex(RuntimeError, "both crow_indices and col_indices should have the same type."): - torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64), - torch.tensor([0, 1, 0, 1], dtype=torch.int32), - torch.tensor([1, 2, 3, 4]), - device=device) - - with self.assertRaisesRegex(RuntimeError, "crow_indices and col_indices must be an int32 or int64 type"): - torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int16), - torch.tensor([0, 1, 0, 1], dtype=torch.int16), - torch.tensor([1, 2, 3, 4]), - device=device) - - def test_factory_layout_invariants_check(self, device): - with self.assertRaisesRegex(RuntimeError, "expected values to be a strided and contiguous tensor"): - values = torch.tensor([1.], device=device).expand(4,) - torch.sparse_csr_tensor(torch.tensor([0, 2, 4], device=device), - torch.tensor([0, 1, 0, 1], device=device), - values) - - with self.assertRaisesRegex(RuntimeError, "expected col_indices to be a strided and contiguous tensor"): - col_indices = torch.tensor([0], device=device).expand(4,) - torch.sparse_csr_tensor(torch.tensor([0, 2, 4]), - col_indices, - torch.tensor([1, 2, 3, 4])) - - with self.assertRaisesRegex(RuntimeError, "expected crow_indices to be a strided and contiguous tensor"): - crow_indices = torch.arange(6, device=device) - torch.sparse_csr_tensor(crow_indices[::2], - torch.tensor([0, 1, 0, 1], device=device), - torch.tensor([1, 2, 3, 4])) - - def test_factory_shape_invariants_check(self, device): - crow_indices = torch.tensor([0, 2, 4], device=device) - col_indices = torch.tensor([0, 1, 0, 1], device=device) - values = torch.tensor([1, 2, 3, 4], device=device) - size = (2, 10) - torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device) - - with self.assertRaisesRegex(RuntimeError, r"size of a batched CSR tensor must have length >= 2, but got: 1"): - torch.sparse_csr_tensor(crow_indices, col_indices, values, - size=(2,), - device=device) - - with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim >= 1 but got crow_indices\.dim\(\)\ = 0"): - torch.sparse_csr_tensor(torch.zeros((), device=device, dtype=torch.int64), - col_indices, - values, - size, - device=device) - - with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim >= 1 but got col_indices\.dim\(\)\ = 0"): - torch.sparse_csr_tensor(crow_indices, - torch.zeros((), device=device, dtype=torch.int64), - values, - size, - device=device) - - with self.assertRaisesRegex(RuntimeError, r"values must have dim >= 1 but got values\.dim\(\)\ = 0"): - torch.sparse_csr_tensor(crow_indices, - col_indices, - torch.zeros((), device=device, dtype=torch.int64), - size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - r"crow_indices\.size\(-1\) must be equal to size\[-2\] \+ 1 \(that is 2\), but got: 3"): - torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 1), - device=device) - - - with self.assertRaisesRegex(RuntimeError, - r"number of dimensions of crow_indices and col_indices must be the same"): - torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - r"non-zero dense dimensions \(=1\) is not supported"): - torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - r"number of dimensions of indices must be one less"): - torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - r"all batch dimensions of the provided size \(\[2\]\), indices \(\[2\], \[3\]\)," - r" and values \(\[4\]\) must be the same"): - torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10), - device=device) - - def test_factory_indices_invariants_check(self, device): - crow_indices = [0, 2, 4] - col_indices = [0, 1, 0, 1] - values = [1, 2, 3, 4] - size = (2, 10) - with self.assertRaisesRegex(RuntimeError, "0th value of crow_indices must be 0."): - torch.sparse_csr_tensor(torch.tensor([-1, 0, 4]), torch.tensor(col_indices), torch.tensor(values), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - "last value of crow_indices should be equal to the length of col_indices."): - torch.sparse_csr_tensor(torch.tensor([0, 2, 5]), torch.tensor(col_indices), torch.tensor(values), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, - r"at position i \= 2," + - r" the condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"): - torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, r"col_indices\.min\(\) should be greater or equal to zero"): - torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size, - device=device) - - with self.assertRaisesRegex(RuntimeError, r"size\[-1\] should be greater than col_indices\.max\(\)"): - torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size, - device=device) - - @onlyCUDA - @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) - def test_factory_device_type_inference(self, device, dtype): - cpu_cuda = ('cpu', 'cuda') - cpu_cuda_none = cpu_cuda + (None,) - for crow_indices_device, col_indices_device, values_device, device in itertools.product(cpu_cuda, - cpu_cuda, - cpu_cuda, - cpu_cuda_none): - for index_dtype in [torch.int32, torch.int64]: - crow_indices = torch.tensor([0, 2, 4], dtype=index_dtype, device=crow_indices_device) - col_indices = torch.tensor([0, 1, 0, 1], dtype=index_dtype, device=col_indices_device) - values = torch.tensor([1, 2, 3, 4], dtype=dtype, device=values_device) - if device is None and (crow_indices_device != col_indices_device or - crow_indices_device != values_device): - with self.assertRaises(RuntimeError): - torch.sparse_csr_tensor(crow_indices, - col_indices, - values, - size=(2, 10), - device=device) - else: - t = torch.sparse_csr_tensor(crow_indices, - col_indices, - values, - size=(2, 10), - device=device) - should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda')) - self.assertEqual(should_be_cuda, t.is_cuda) - t.crow_indices().dtype == index_dtype - t.col_indices().dtype == index_dtype - t.values().dtype == dtype - t.crow_indices().device == t.values().device - t.col_indices().device == t.values().device - @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_sparse_csr_from_dense(self, device, dtype): dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device) @@ -2255,6 +2513,7 @@ def _to_from_layout(layout_a, layout_b): @skipMeta @all_sparse_compressed_layouts() + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_dense_to_from_sparse_compressed(self, device, layout): """ This test tests conversion from dense to/from CSR and CSC @@ -2269,14 +2528,14 @@ def test_dense_to_from_sparse_compressed(self, device, layout): shapes = [(6, 10), (0, 10), (6, 0), (0, 0)] blocksizes = [(2, 2)] + batch_sizes = [(3, )] + if layout is torch.sparse_bsr: blocksizes += [(3, 5), (6, 10)] + batch_sizes += [(2, 3), (1, 1, 1, 2)] - for shape, blocksize in itertools.product(shapes, blocksizes): - dense = make_tensor(shape, dtype=torch.float, device=device) - dense = dense.relu() # Introduce some sparsity + def _test_matrix(pt_matrix, dense, layout, blocksize): sp_matrix = self._construct_sp_matrix(dense, layout, blocksize=blocksize) - pt_matrix = self._convert_to_layout(dense, layout, blocksize=blocksize) compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout] @@ -2286,12 +2545,96 @@ def test_dense_to_from_sparse_compressed(self, device, layout): self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix)) self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values()) + + for shape, blocksize in itertools.product(shapes, blocksizes): + dense = make_tensor(shape, dtype=torch.float, device=device) + dense = dense.relu() # Introduce some sparsity + pt_matrix = self._convert_to_layout(dense, layout, blocksize=blocksize) + _test_matrix(pt_matrix, dense, layout, blocksize) self.assertEqual(dense, pt_matrix.to_dense()) + if layout is not torch.sparse_bsr: + # TODO: Remove this once support has been enabled + return + + # Test batch shapes (ND inputs) + + # Case 1: Same sparsity pattern across matrices + for shape, blocksize, batch_shape in itertools.product(shapes, blocksizes, batch_sizes): + full_shape = batch_shape + shape + batch_len = functools.reduce(lambda x, y: x * y, batch_shape, 1) + dense = make_tensor(full_shape, dtype=torch.float, device=device) + # select the first batch to create the mask + mask = dense[tuple(np.unravel_index(0, batch_shape))].relu().bool() + dense = dense * mask + pt_tensor = self._convert_to_layout(dense, layout, blocksize=blocksize) + for i in range(batch_len): + batch_idx = tuple(np.unravel_index(i, batch_shape)) + _test_matrix(pt_tensor[batch_idx], dense[batch_idx], layout, blocksize) + # todo: check whole conversion once to_dense impl for n-d batched-bsr + # take 3d slices of dense/sparse to convert/compare for now + if dense.dim() > 3: + part_dim = dense.dim() - 3 + part_shape = batch_shape[:part_dim] + len_partition = functools.reduce(lambda x, y: x * y, part_shape, 1) + for i in range(len_partition): + part_idx = tuple(np.unravel_index(i, part_shape)) + self.assertEqual(dense[part_idx], pt_tensor[part_idx].to_dense()) + else: + self.assertEqual(dense, pt_tensor.to_dense()) + + # Verify exception when given 0 sized batch + for shape, blocksize in itertools.product(shapes, blocksizes): + dense = make_tensor((0,) + shape, dtype=torch.float, device=device) + # TODO: Support zero sized batch dimensions + with self.assertRaisesRegex(RuntimeError, "to_sparse_bsr: Expected product of batch dimensions to be non-zero."): + self._convert_to_layout(dense, layout, blocksize=blocksize) + + # TODO: Case 2: Different sparsity pattern across matrices, but same number of zeros + # NOTE: For blocksparse formats this applies at a per-block level, + dense = make_tensor((2, 4, 4), dtype=torch.float, device=device) + blocksize = (2, 2) + mask = torch.tensor([ + [[True, True], [False, True]], + [[True, False], [True, True]]], + device=device).view((2, 2, 2, 1, 1)) + mask = mask.expand((2, 2, 2, 2, 2)) + mask = mask.transpose(2, 3) + mask = mask.reshape_as(dense) + dense = dense * mask + if layout == torch.sparse_bsr: + # this is not an error as long as the nse is equal for bsr + pt_tensor = self._convert_to_layout(dense, layout, blocksize=blocksize) + for i in range(2): + _test_matrix(pt_tensor[i], dense[i], layout, blocksize) + self.assertEqual(dense, pt_tensor.to_dense()) + else: + with self.assertRaisesRegex(RuntimeError, "Expect the same sparsity pattern across matrices for ND input."): + self._convert_to_layout(dense, layout, blocksize=blocksize) + + # TODO: Case 3: Different sparsity pattern across matrices, but different number of zeros + dense = make_tensor((2, 4, 4), dtype=torch.float, device=device) + blocksize = (2, 2) + mask = torch.tensor( + [[[True, True], [False, False]], + [[True, False], [True, True]]], + device=device).view((2, 2, 2, 1, 1)) + mask = mask.expand((2, 2, 2, 2, 2)) + mask = mask.transpose(2, 3) + mask = mask.reshape_as(dense) + dense = dense * mask + if layout == torch.sparse_bsr: + msg = "Expect the same number of specified elements per batch." + else: + msg = "Expect the same sparsity pattern across matrices for ND input." + with self.assertRaisesRegex(RuntimeError, msg): + self._convert_to_layout(dense, layout, blocksize=blocksize) + @skipMeta @all_sparse_compressed_layouts() @coalescedonoff @dtypes(torch.double) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout): """ This test tests conversion from COO to CSR and CSC and CSC to CSR and CSC diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index b4f37cc1558eb..a0cc0c20c1645 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -948,6 +948,52 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) + @skipCPUIfNoFFT + @onlyNativeDeviceTypes + @dtypes(torch.double) + def test_istft_against_librosa(self, device, dtype): + if not TEST_LIBROSA: + raise unittest.SkipTest('librosa not found') + + def librosa_istft(x, n_fft, hop_length, win_length, window, length, center): + if window is None: + window = np.ones(n_fft if win_length is None else win_length) + else: + window = window.cpu().numpy() + + return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, + win_length=win_length, length=length, window=window, center=center) + + def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None, + length=None, center=True): + x = torch.randn(size, dtype=dtype, device=device) + if win_sizes is not None: + window = torch.randn(*win_sizes, dtype=dtype, device=device) + else: + window = None + + x_stft = x.stft(n_fft, hop_length, win_length, window, center=center, + onesided=True, return_complex=True) + + ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length, + window, length, center) + result = x_stft.istft(n_fft, hop_length, win_length, window, + length=length, center=center) + self.assertEqual(result, ref_result) + + for center in [True, False]: + _test(10, 7, center=center) + _test(4000, 1024, center=center) + _test(4000, 1024, center=center, length=4000) + + _test(10, 7, 2, center=center) + _test(4000, 1024, 512, center=center) + _test(4000, 1024, 512, center=center, length=4000) + + _test(10, 7, 2, win_sizes=(7,), center=center) + _test(4000, 1024, 512, win_sizes=(1024,), center=center) + _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000) + @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double, torch.cdouble) diff --git a/test/test_stateless.py b/test/test_stateless.py index e3e3f03277d82..d7b7be547be83 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -156,6 +156,26 @@ def test_reparametrized_module_change_parametrization_original(self): self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) self.assertEqual(orig_sn_weight, module.l1.weight) + def test_reparamertize_module_fail_reset_to_original(self): + module = MockModule() + torch.nn.utils.parametrizations.spectral_norm(module.l1) + self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) + orig_sn_weight = module.l1.weight.clone() + # We substitute the parameter inside the parametrization + # the parametrization itself is not overwritten so it will be applied with a different + # value for the original tensor + parameters = {'l1.parametrizations.weight.original': torch.nn.Parameter(torch.tensor([[1.0]])), + 'l1.bias': torch.tensor([0.0]), + 'buffer': torch.tensor([0.0])} + with self.assertRaisesRegex(RuntimeError, "shapes cannot be multiplied"): + x = torch.rand((4, 5)) # to work, it should be of size (1, 1) + stateless.functional_call(module, parameters, x) # this call will fail because x is the wrong size + + # verify that the spectral normalization is still applied + self.assertTrue('l1.parametrizations.weight.original' in dict(module.named_parameters())) + self.assertEqual(orig_sn_weight, module.l1.weight) + + def test_setattr(self): class Foo(torch.nn.Module): def __init__(self): diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index a501d4b1a6fe8..b3087eee18e06 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -23,6 +23,9 @@ def __call__(self, *args, **kwargs): def benchmark(self, args, kwargs, warmup_runs, main_runs): self.static_module.benchmark(args, kwargs, warmup_runs, main_runs) + def runAsync(self, args, kwargs): + return self.static_module.runAsync(args, kwargs) + def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs): return self.static_module.benchmark_individual_ops( args, kwargs, warmup_runs, main_runs @@ -222,6 +225,20 @@ def test_fork_wait_1(self): output_test = static_runtime_module(inp1, inp2) torch.testing.assert_close(output_test, output_ref) + """ + Test Case: To test simple fork/wait operation with + StaticRuntime runAsync API returning future + """ + def test_fork_wait_1_async(self): + inp1 = torch.ones(5, 5) + inp2 = torch.randn(5, 5) + torch_graph = torch.jit.script(fork_wait_graph1) + output_ref = torch_graph(inp1, inp2) + static_runtime_module = StaticModule(torch_graph) + output_test = static_runtime_module.runAsync((inp1, inp2), {}) + output_test.wait() + torch.testing.assert_close(output_test.value(), output_ref) + """ Test Case: To test fork/wait operation in a graph on a loop subgraph performing mix of operations @@ -235,6 +252,20 @@ def test_fork_wait_2(self): output_test = static_runtime_module(inp1, inp2) torch.testing.assert_close(output_test, output_ref) + """ + Test Case: To test fork/wait operation on a loop + subgraph with StaticRuntime runAsync API returning future + """ + def test_fork_wait_2_async(self): + inp1 = torch.randn(5, 5) + inp2 = torch.randn(5, 5) + torch_graph = torch.jit.script(fork_wait_graph2) + output_ref = torch_graph(inp1, inp2) + static_runtime_module = StaticModule(torch_graph) + output_test = static_runtime_module.runAsync((inp1, inp2), {}) + output_test.wait() + torch.testing.assert_close(output_test.value(), output_ref) + """ Test Case: To test fork/wait operation in a graph on having multiple fork/wait operations @@ -247,6 +278,21 @@ def test_fork_wait_3(self): static_runtime_module = StaticModule(torch_graph) output_test = static_runtime_module(input, num_forks) torch.testing.assert_close(output_test, output_ref) + + """ + Test Case: To test fork/wait operation in a graph with + multiple fork/wait operations on runAsync API returning future + """ + def test_fork_wait_3_async(self): + input = torch.ones(3, 3) + num_forks = 10 + torch_graph = torch.jit.script(fork_wait_graph3) + output_ref = torch_graph(input, num_forks) + static_runtime_module = StaticModule(torch_graph) + output_test = static_runtime_module.runAsync((input, num_forks), {}) + output_test.wait() + torch.testing.assert_close(output_test.value(), output_ref) + """ Test Case: To test fork/wait operation in a graph on multiple nested fork/wait operations @@ -261,6 +307,22 @@ def test_fork_wait_4(self): output_test = static_runtime_module(input, num_forks, num_child_forks) torch.testing.assert_close(output_test, output_ref) + """ + Test Case: To test fork/wait operation in a graph with multiple + nested fork/wait operations on runAsync API returning future + """ + def test_fork_wait_4_async(self): + input = torch.ones(3, 3) + num_forks = 10 + num_child_forks = 10 + torch_graph = torch.jit.script(fork_wait_graph4) + static_runtime_module = StaticModule(torch_graph) + output_ref = torch_graph(input, num_forks, num_child_forks) + output_test = static_runtime_module.runAsync( + (input, num_forks, num_child_forks), {}) + output_test.wait() + torch.testing.assert_close(output_test.value(), output_ref) + """ Test Case: To test exception handling in fork/wait operation. Add.Tensor op is called for tensors with @@ -290,6 +352,36 @@ def test_fork_wait_exception(self): f"not contain expected substring: \"{expected_error_msg}\"" ) from error + """ + Test Case: To test exception handling in fork/wait + operation with runAsync API. Add.Tensor op is called for + tensors with non-matching dims on the forked subgraph + and the exception raised by subgraph is set on future returned + by prim::fork to parent graph. Returned exception is + checked for substring expected_error_msg as declared below + """ + def test_fork_wait_exception_async(self): + # incompatible tensors for add due to shape mismatch + input1 = torch.randn(4, 7) + input2 = torch.randn(4, 5) + torch_graph = torch.jit.script(fork_wait_graph_exception) + try: + static_runtime_module = StaticModule(torch_graph) + output_test = static_runtime_module.runAsync( + (input1, input2), {}) + except Exception as error: + expected_error_msg = ( + "The size of tensor a (7) must match the size " + "of tensor b (5) at non-singleton dimension 1" + ) + # test fails if error does not contain expected substr + if str(error).find(expected_error_msg) == -1: + raise RuntimeError( + "Tried execution of add.Tensors with incompatible shape. " + "Exception raised by forked runtime execution does " + f"not contain expected substring: \"{expected_error_msg}\"" + ) from error + def test_multihead_attention_layer(self): HID_DIM = 256 QUERY_LEN = 8 diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index c341ef36dae11..174ba0debdb1d 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -14,11 +14,11 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest, - TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize) + TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize, skipIfTorchDynamo) from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, onlyCPU, largeTensorTest, precisionOverride, dtypes, - onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta, get_all_device_types) + onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta) from torch.testing._internal.common_dtype import ( all_types_and_complex_and, all_types_and, floating_and_complex_types, floating_types, floating_and_complex_types_and, integral_types_and, get_all_dtypes @@ -839,6 +839,7 @@ def test_device_rounding(self, device, dtype): # Note: This test failed on XLA since its test cases are created by empty_strided which # doesn't support overlapping sizes/strides in XLA impl + @skipIfTorchDynamo("TorchDynamo fails on this test for unknown reasons") @onlyNativeDeviceTypes def test_like_fn_stride_proparation_vs_tensoriterator_unary_op(self, device): # Test like functions against tensoriterator based unary operator (exp) to @@ -2472,124 +2473,138 @@ def test_range_warning(self, device): self.assertEqual(len(w), 1) # TODO: this test should be updated - @onlyCPU def test_arange(self, device): - res = torch.tensor(range(10000)) - res1 = torch.arange(0, 10000) # Use a larger number so vectorized code can be triggered - res2 = torch.tensor([], dtype=torch.int64) + res = torch.tensor(range(10000), device=device) + res1 = torch.arange(0, 10000, device=device) # Use a larger number so vectorized code can be triggered + res2 = torch.tensor([], dtype=torch.int64, device=device) torch.arange(0, 10000, out=res2) self.assertEqual(res, res1, atol=0, rtol=0) self.assertEqual(res, res2, atol=0, rtol=0) # Vectorization on non-contiguous tensors - res = torch.rand(3, 3, 300000).to(torch.int64) + res = torch.rand(3, 3, 300000, device=device).to(torch.int64) res = res.permute(2, 0, 1) torch.arange(0, 300000 * 3 * 3, out=res) - self.assertEqual(res.flatten(), torch.arange(0, 300000 * 3 * 3)) + self.assertEqual(res.flatten(), torch.arange(0, 300000 * 3 * 3, device=device)) # Check arange with only one argument - res1 = torch.arange(10) - res2 = torch.arange(0, 10) + res1 = torch.arange(10, device=device) + res2 = torch.arange(0, 10, device=device) self.assertEqual(res1, res2, atol=0, rtol=0) # Check arange for non-contiguous tensors. - x = torch.zeros(2, 3) + x = torch.zeros(2, 3, device=device) torch.arange(0, 4, out=x.narrow(1, 1, 2)) - res2 = torch.tensor(((0., 0., 1.), (0., 2., 3.))) + res2 = torch.tensor(((0., 0., 1.), (0., 2., 3.)), device=device) self.assertEqual(x, res2, atol=1e-16, rtol=0) # Check negative - res1 = torch.tensor((1., 0.)) - res2 = torch.tensor([]) + res1 = torch.tensor((1., 0.), device=device) + res2 = torch.tensor([], device=device) torch.arange(1, -1, -1, out=res2) self.assertEqual(res1, res2, atol=0, rtol=0) # Equal bounds - res1 = torch.ones(1) - res2 = torch.tensor([]) + res1 = torch.ones(1, device=device) + res2 = torch.tensor([], device=device) torch.arange(1, 0, -1, out=res2) self.assertEqual(res1, res2, atol=0, rtol=0) torch.arange(1, 2, 1, out=res2) self.assertEqual(res1, res2, atol=0, rtol=0) # FloatTensor - res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor()) + out = torch.tensor([], dtype=torch.float, device=device) + res1 = torch.arange(0.6, 0.89, 0.1, out=out) self.assertEqual(res1, [0.6, 0.7, 0.8]) - res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor()) + out = torch.tensor([], dtype=torch.float, device=device) + res1 = torch.arange(1, 10, 0.3, out=out) self.assertEqual(res1.size(0), 30) self.assertEqual(res1[0], 1) self.assertEqual(res1[29], 9.7) # DoubleTensor - res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor()) + out = torch.tensor([], dtype=torch.double, device=device) + res1 = torch.arange(0.6, 0.89, 0.1, out=out) self.assertEqual(res1, [0.6, 0.7, 0.8]) - res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor()) + out = torch.tensor([], dtype=torch.double, device=device) + res1 = torch.arange(1, 10, 0.3, out=out) self.assertEqual(res1.size(0), 30) self.assertEqual(res1[0], 1) self.assertEqual(res1[29], 9.7) # Bool Input matching numpy semantics - r = torch.arange(True) + r = torch.arange(True, device=device) self.assertEqual(r[0], 0) - r2 = torch.arange(False) + r2 = torch.arange(False, device=device) self.assertEqual(len(r2), 0) self.assertEqual(r.dtype, torch.int64) self.assertEqual(r2.dtype, torch.int64) # Check that it's exclusive - r = torch.arange(0, 5) + r = torch.arange(0, 5, device=device) self.assertEqual(r.min(), 0) self.assertEqual(r.max(), 4) self.assertEqual(r.numel(), 5) - r = torch.arange(0, 5, 2) + r = torch.arange(0, 6, 3, device=device) + self.assertEqual(r.min(), 0) + self.assertEqual(r.max(), 3) + self.assertEqual(r.numel(), 2) + + r = torch.arange(0, 5, 2, device=device) self.assertEqual(r.min(), 0) self.assertEqual(r.max(), 4) self.assertEqual(r.numel(), 3) - r1 = torch.arange(0, 5 + 1e-6) + r = torch.arange(0, -5, -2, device=device) + self.assertEqual(r.min(), -4) + self.assertEqual(r.max(), 0) + self.assertEqual(r.numel(), 3) + + r1 = torch.arange(0, 5 + 1e-6, device=device) # NB: without the dtype, we'll infer output type to be int64 - r2 = torch.arange(0, 5, dtype=torch.float32) - r3 = torch.arange(0, 5 - 1e-6) + r2 = torch.arange(0, 5, dtype=torch.float32, device=device) + r3 = torch.arange(0, 5 - 1e-6, device=device) self.assertEqual(r1[:-1], r2, atol=0, rtol=0) self.assertEqual(r2, r3, atol=0, rtol=0) - r1 = torch.arange(10, -1 + 1e-6, -1) + r1 = torch.arange(10, -1 + 1e-6, -1, device=device) # NB: without the dtype, we'll infer output type to be int64 - r2 = torch.arange(10, -1, -1, dtype=torch.float32) - r3 = torch.arange(10, -1 - 1e-6, -1) + r2 = torch.arange(10, -1, -1, dtype=torch.float32, device=device) + r3 = torch.arange(10, -1 - 1e-6, -1, device=device) self.assertEqual(r1, r2, atol=0, rtol=0) self.assertEqual(r2, r3[:-1], atol=0, rtol=0) + w = 1449629115440469 + r = torch.arange(0, 100 * w, w, device=device) + self.assertEqual(r.numel(), 100) + # Test Rounding Errors - line = torch.zeros(size=(1, 49)) + line = torch.zeros(size=(1, 49), device=device) self.assertWarnsRegex(UserWarning, 'The out tensor will be resized', lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line)) self.assertEqual(line.shape, [50]) x = torch.empty(1).expand(10) self.assertRaises(RuntimeError, lambda: torch.arange(10, out=x)) - msg = "unsupported range" - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'))) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'))) - - for device in get_all_device_types(): - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device)) - # check with step size - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device)) - self.assertRaisesRegex( - RuntimeError, "overflow", - lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device)) - - # check that it holds a consistent output shape on precision-cornered step sizes - d = torch.arange(-4.0, 4.0, 0.01, dtype=torch.float32, device=device) - self.assertEqual(d.shape[0], 800) + msg = "unsupported range" + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device)) + # check with step size + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device)) + + self.assertRaisesRegex( + RuntimeError, "overflow", + lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device)) + + # check that it holds a consistent output shape on precision-cornered step sizes + d = torch.arange(-4.0, 4.0, 0.01, dtype=torch.float32, device=device) + self.assertEqual(d.shape[0], 800) # TODO: this test should be updated @onlyCPU diff --git a/test/test_testing.py b/test/test_testing.py index 89be4cf3a2cd9..6eadd4058ad47 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1134,7 +1134,7 @@ def test_matching(self): def test_mismatching_crow_indices_msg(self): actual_crow_indices = (0, 1, 2) - actual_col_indices = (1, 0) + actual_col_indices = (0, 1) actual_values = (1, 2) actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) @@ -1192,7 +1192,7 @@ def test_matching(self): def test_mismatching_ccol_indices_msg(self): actual_ccol_indices = (0, 1, 2) - actual_row_indices = (1, 0) + actual_row_indices = (0, 1) actual_values = (1, 2) actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) @@ -1250,7 +1250,7 @@ def test_matching(self): def test_mismatching_crow_indices_msg(self): actual_crow_indices = (0, 1, 2) - actual_col_indices = (1, 0) + actual_col_indices = (0, 1) actual_values = ([[1]], [[2]]) actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) @@ -1308,7 +1308,7 @@ def test_matching(self): def test_mismatching_ccol_indices_msg(self): actual_ccol_indices = (0, 1, 2) - actual_row_indices = (1, 0) + actual_row_indices = (0, 1) actual_values = ([[1]], [[2]]) actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) @@ -1774,8 +1774,9 @@ def test_circular_dependencies(self) -> None: "torch.distributed.elastic.rendezvous", # depps on etcd "torch.backends._coreml", # depends on pycoreml "torch.contrib.", # something weird - "torch.testing._internal.common_fx2trt", # needs fx "torch.testing._internal.distributed.", # just fails + "torch.ao.sparsity._experimental.", # depends on pytorch_lightning, not user-facing + "torch.cuda._dynamo_graphs", # depends on torchdynamo ] # See https://github.com/pytorch/pytorch/issues/77801 if not sys.version_info >= (3, 9): diff --git a/test/test_torch.py b/test/test_torch.py index 265bff919aa11..f335443de7536 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -34,7 +34,7 @@ TestCase, TEST_WITH_ROCM, run_tests, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, - TEST_WITH_CROSSREF, + TEST_WITH_CROSSREF, skipIfTorchDynamo, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, @@ -56,7 +56,7 @@ tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN) from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, - all_types_and, floating_types, floating_and_complex_types, integral_types, + all_types_and, floating_types, floating_and_complex_types, ) # Protects against includes accidentally setting the default dtype @@ -157,7 +157,7 @@ def rand_byte(): for i in range(10): bytes_list = [rand_byte() for _ in range(element_size)] scalar = bytes_to_scalar(bytes_list, dtype, device) - self.assertEqual(scalar.storage()._untyped().tolist(), bytes_list) + self.assertEqual(scalar.storage().untyped().tolist(), bytes_list) @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, @@ -175,7 +175,7 @@ def test_storage(self, device, dtype): v_s[el_num], v[dim0][dim1]) - v_s_byte = v.storage()._untyped() + v_s_byte = v.storage().untyped() el_size = v.element_size() for el_num in range(v.numel()): @@ -238,7 +238,7 @@ def test_tensor_from_storage(self, device, dtype): a_s = a.storage() b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size()) self.assertEqual(a, b) - c = torch.tensor(a_s._untyped(), device=device, dtype=dtype).reshape(a.size()) + c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size()) self.assertEqual(a, c) for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): @@ -255,7 +255,7 @@ def test_set_storage(self, device, dtype): a_s = a.storage() b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size()) self.assertEqual(a, b) - c = torch.tensor([], device=device, dtype=dtype).set_(a_s._untyped()).reshape(a.size()) + c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size()) self.assertEqual(a, c) for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): @@ -267,11 +267,11 @@ def test_set_storage(self, device, dtype): def _check_storage_meta(self, s, s_check): self.assertTrue( - isinstance(s, (torch._UntypedStorage, torch._TypedStorage)) and + isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and isinstance(s_check, type(s)), ( - 's and s_check must both be one of _UntypedStorage or ' - '_TypedStorage, but got' + 's and s_check must both be one of UntypedStorage or ' + 'TypedStorage, but got' f' {type(s).__name__} and {type(s_check).__name__}')) self.assertEqual(s.device.type, 'meta') @@ -282,9 +282,9 @@ def _check_storage_meta(self, s, s_check): with self.assertRaisesRegex(NotImplementedError, r'Not available'): s[0] - if isinstance(s, torch._TypedStorage): + if isinstance(s, torch.TypedStorage): self.assertEqual(s.dtype, s_check.dtype) - self._check_storage_meta(s._untyped(), s_check._untyped()) + self._check_storage_meta(s.untyped(), s_check.untyped()) @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @@ -296,8 +296,8 @@ def test_typed_storage_meta(self, device, dtype): [[1, 2, 3, 4, 5, 6]], ] for args in args_list: - s_check = torch._TypedStorage(*args, dtype=dtype, device=device) - s = torch._TypedStorage(*args, dtype=dtype, device='meta') + s_check = torch.TypedStorage(*args, dtype=dtype, device=device) + s = torch.TypedStorage(*args, dtype=dtype, device='meta') self._check_storage_meta(s, s_check) @onlyNativeDeviceTypes @@ -309,8 +309,8 @@ def test_untyped_storage_meta(self, device): [[1, 2, 3, 4, 5, 6]], ] for args in args_list: - s_check = torch._UntypedStorage(*args, device=device) - s = torch._UntypedStorage(*args, device='meta') + s_check = torch.UntypedStorage(*args, device=device) + s = torch.UntypedStorage(*args, device='meta') self._check_storage_meta(s, s_check) @onlyNativeDeviceTypes @@ -326,7 +326,7 @@ def test_storage_meta_from_tensor(self, device, dtype): @onlyCPU @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_storage_meta_errors(self, device, dtype): - s0 = torch._TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) + s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): s0.cpu() @@ -361,11 +361,19 @@ def test_storage_meta_errors(self, device, dtype): s0._write_file(f, True, True, s0.element_size()) for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']: - s1 = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) + s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): s1.copy_(s0) + @onlyCUDA + def test_module_share_memory(self): + # Test fix for issue #80733 + # See https://github.com/pytorch/pytorch/issues/80733 + model = torch.nn.Linear(3, 1) + model_cuda = model.to('cuda') + model.share_memory() + @dtypes(torch.float32, torch.complex64) def test_deepcopy(self, device, dtype): from copy import deepcopy @@ -730,6 +738,7 @@ def test_scalar_check(self, device): self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape) # Uses mismatched arange out size to trigger a warning + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering") def test_cpp_warnings_have_python_context(self, device): # Creates long string in advance to avoid a too-long Python line @@ -2728,6 +2737,7 @@ def test_large_cumprod(self, device, dtype): x[2::3] = .5 self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0)) + @skipIfTorchDynamo("Torchdynamo fails with unknown reason") @skipIfMps def test_discontiguous_out_cumsum(self, device): x = torch.randn(4, 8, device=device) @@ -2802,6 +2812,21 @@ def test_copy_all_dtypes_and_devices(self, device): # not the data self.assertEqual(x, y) + @onlyCPU + def test_bfloat16_float_copy(self, device): + for shape in [(20, 7), (249, 137), (1029, 917), (1, 7, 19, 17), (3, 77, 1091)]: + input = torch.randn(shape, dtype=torch.float, device=device) + out1 = input.to(torch.bfloat16) + self.assertEqual(input, out1, atol=0, rtol=1e-2, exact_dtype=False) + out2 = out1.to(torch.float) + self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False) + + input_s = input[..., ::2, :] + out1 = input_s.to(torch.bfloat16) + self.assertEqual(input_s, out1, atol=0, rtol=1e-2, exact_dtype=False) + out2 = out1.to(torch.float) + self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False) + # FIXME: move to data movement test suite @onlyNativeDeviceTypes def test_copy_math_view(self, device): @@ -2934,27 +2959,25 @@ def test_narrow_empty(self, device): # FIXME: move to indexing test suite @parametrize("reduce", ['prod', 'amin', 'amax', 'mean']) - @dtypes(*floating_types_and(torch.half, torch.bfloat16)) + @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_index_reduce(self, device, dtype, reduce): size = (3, 4, 5) index_dtypes = [torch.int, torch.long] include_selfs = [True, False] - reduction_init = {'prod': 1, 'mean': 0, 'amin': float('inf'), 'amax': -float('inf')} + amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max + amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min + reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init} - for dest_contig, src_contig, index_contig in product([True, False], repeat=3): + for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3): for idx_dtype, include_self in product(index_dtypes, include_selfs): for dim in range(len(size)): num_src = np.random.randint(10) num_dest = size[dim] - dest = torch.randn(size, dtype=dtype, device=device) - if not dest_contig: - dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=True) - src = torch.randn(*size[:dim], num_src, *size[dim + 1:], dtype=dtype, device=device) - if not src_contig: - # noncontiguous_like fails with RuntimeError: XLA tensors do not have storage - src = torch.testing.make_non_contiguous(src) + dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig) + src_size = size[:dim] + (num_src,) + size[dim + 1:] + src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig) idx = torch.randint(num_dest, (num_src,), dtype=idx_dtype, device=device) - if not index_contig: + if index_noncontig: # noncontiguous_like fails with RuntimeError: XLA tensors do not have storage idx = torch.testing.make_non_contiguous(idx) expected = dest.clone() @@ -2977,7 +3000,10 @@ def test_index_reduce(self, device, dtype, reduce): counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected) counts.index_add_(0, idx, torch.ones_like(src)) counts.masked_fill_(counts == 0, 1) - expected /= counts + if (dtype.is_floating_point): + expected.div_(counts) + else: + expected.div_(counts, rounding_mode="floor") expected = expected.transpose(0, dim) self.assertEqual(dest, expected) @@ -3434,7 +3460,7 @@ def test_scatter_reduce_non_unique_index(self, device, dtype): self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}") @onlyCUDA - @dtypes(*integral_types(), *complex_types()) + @dtypes(*complex_types()) def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype): height = 2 width = 2 @@ -3901,62 +3927,6 @@ def test_dim_function_empty(self, device): with self.assertRaisesRegex(RuntimeError, "INDICES element is out of DATA bounds"): torch.index_select(w, 1, ind_05) - # FIXME: find a test suite for the pdist operator - def _brute_pdist(self, inp, p=2): - """Computes the same as torch.pdist using primitives""" - n = inp.shape[-2] - k = n * (n - 1) // 2 - if k == 0: - # torch complains about empty indices - return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device) - square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1) - unroll = square.view(square.shape[:-2] + (n * n,)) - inds = torch.ones(k, dtype=torch.int) - inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int) - return unroll[..., inds.cumsum(0)] - - # FIXME: find a test suite for the pdist operator - def _pdist_single(self, shape, device, p, dtype, trans, grad_check=False): - x = torch.randn(shape, dtype=dtype, device=device) - if trans: - x.transpose_(-2, -1) - if grad_check: - x.requires_grad_() - y = x.detach().clone().requires_grad_() - else: - y = x - actual = torch.pdist(x, p=p) - expected = self._brute_pdist(y, p=p) - self.assertEqual(expected.shape, actual.shape) - self.assertEqual(expected, actual) - if grad_check and expected.size() != torch.Size([0]): - g0 = torch.rand_like(actual) - actual.backward(g0) - expected.backward(g0) - self.assertEqual(x.grad, y.grad) - - # FIXME: find a test suite for the pdist operator - @slowTest - def test_pdist_norm_forward(self, device): - for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]: - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - for trans in [False, True]: - for dtype in [torch.float32, torch.float64]: - self._pdist_single(shape, device, p, dtype, trans, grad_check=False) - - # do a simplified comparison with big inputs, see: - # https://github.com/pytorch/pytorch/issues/15511 - for dtype in [torch.float32, torch.float64]: - self._pdist_single((1000, 2), device, 2, dtype, trans=False, grad_check=False) - - # FIXME: find a test suite for the pdist operator - @slowTest - def test_pdist_norm_backward(self, device): - for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]: - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - for trans in [False, True]: - self._pdist_single(shape, device, p, torch.float64, trans, grad_check=True) - # FIXME: find a test suite for the pdist operator @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") @skipIfRocm @@ -3966,7 +3936,7 @@ def test_pdist_norm_backward(self, device): def test_pdist_norm_large(self, device): # use dim0>=46342 for forward, see: # https://github.com/pytorch/pytorch/issues/30583 - # Compare output using GPU with the CPU implementation, as brute_pdist uses too much memory + # Compare output using GPU with the CPU implementation x = torch.randn(50000, 1, dtype=torch.float32) # 50k * 4 bytes = 200 KB # Will require 1249975000 float32s expected_cpu = torch.pdist(x, p=2) # ~1250M * 4 bytes = 5 GB on CPU @@ -4600,32 +4570,32 @@ def _test_helper(x, y, bias, memory_format): _unsqueeze_op_add, _unsqueeze_op_clone, ] + x_c = x.contiguous() + y_c = y.contiguous() + b_c = bias.contiguous() for fn in fns: - x_c = x.contiguous() - y_c = y.contiguous() - result_c = fn(x_c, y_c) - result = fn(x, y) - self.assertEqual(result, result_c) + is_inplace = '_(' in inspect.getsource(fn) + x_clone = x.clone() if is_inplace else x + x_c_clone = x_c.clone() if is_inplace else x_c + result_c = fn(x_c_clone, y_c) + result = fn(x_clone, y) + self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip())) self.assertTrue( result.is_contiguous(memory_format=memory_format), "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), memory_format)) for fn in bias_fns: - x_c = x.contiguous() - b_c = bias.contiguous() result_c = fn(x_c, b_c) result = fn(x, bias) - self.assertEqual(result, result_c) + self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip())) self.assertTrue( result.is_contiguous(memory_format=memory_format), "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), memory_format)) for fn in return_contig_fns: - x_c = x.contiguous() - y_c = y.contiguous() result_c = fn(x_c, y_c) result = fn(x, y) - self.assertEqual(result, result_c) + self.assertEqual(result, result_c, "Failed for '{}'".format(inspect.getsource(fn).strip())) self.assertTrue( result.is_contiguous(memory_format=torch.contiguous_format), "result of the '{}' is not in '{}' format".format(inspect.getsource(fn).strip(), torch.contiguous_format)) @@ -4642,6 +4612,7 @@ def _test_helper(x, y, bias, memory_format): torch.channels_last_3d) # FIXME: make this a elementwise unary and elementwise binary OpInfo test + @skipIfTorchDynamo("Torchdynamo fails with unknown reason") def test_strides_propagation(self, device): def _test_helper(x, op, unary=False): def compare_strides(s1, s2, div): @@ -5244,8 +5215,9 @@ def test_multinomial_invalid(self, device): def test(probs): with self.assertRaisesRegex(RuntimeError, 'probability tensor contains either `inf`, `nan` or element < 0'): - torch.multinomial(probs.to(device), 2) - torch.cuda.synchronize() + out = torch.multinomial(probs.to(device), 2) + if out.is_cuda: + torch.cuda.synchronize() test(torch.tensor([1., -1., 1.])) test(torch.tensor([1., inf, 1.])) @@ -5258,8 +5230,9 @@ def test_multinomial_invalid_distribution(self, device): def test(probs, replacement): with self.assertRaisesRegex(RuntimeError, r"invalid multinomial distribution \(sum of probabilities <= 0\)"): - torch.multinomial(probs, 2, replacement) - torch.cuda.synchronize() + out = torch.multinomial(probs, 2, replacement) + if out.is_cuda: + torch.cuda.synchronize() x = torch.zeros(3, device=device) y = torch.zeros(3, 3, device=device) @@ -5884,7 +5857,7 @@ def test_unflatten(self): torch.ones(2, 3, 0, 4, 5, 2)) # test invalid args: tensor, str, sizes - with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"): + with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"): torch.tensor([1]).unflatten('A', (1, 1)) # test invalid args: tensor, str, namedshape @@ -6099,6 +6072,7 @@ def test_permute(self): self.assertEqual(perm, new) self.assertEqual(x.size(), orig) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_reversed(self): val = torch.arange(0, 10) self.assertEqual(reversed(val), torch.arange(9, -1, -1)) @@ -6147,6 +6121,7 @@ def test_pickle(self): b = pickle.loads(serialized) self.assertEqual(a, b) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_pickle_parameter(self): import pickle a = torch.nn.Parameter(torch.randn(5, 5)) @@ -6156,6 +6131,7 @@ def test_pickle_parameter(self): self.assertEqual(a.requires_grad, b.requires_grad) self.assertEqual(a, b) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_pickle_parameter_no_requires_grad(self): import pickle a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False) @@ -6468,7 +6444,7 @@ def test_storage_error(self): torch.storage._LegacyStorage() for storage_class in torch._storage_classes: - if storage_class in [torch._UntypedStorage, torch._TypedStorage]: + if storage_class in [torch.UntypedStorage, torch.TypedStorage]: continue device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu' @@ -6499,9 +6475,9 @@ def test_storage_error(self): s = storage_class() with self.assertRaisesRegex(RuntimeError, r"No positional arguments"): - storage_class(0, wrap_storage=s._untyped()) + storage_class(0, wrap_storage=s.untyped()) - with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"): + with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"): storage_class(wrap_storage=s) if torch.cuda.is_available(): @@ -6517,40 +6493,40 @@ def test_storage_error(self): s_other_device = s.cuda() with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"): - storage_class(wrap_storage=s_other_device._untyped()) + storage_class(wrap_storage=s_other_device.untyped()) - # _TypedStorage constructor errors + # TypedStorage constructor errors with self.assertRaisesRegex(RuntimeError, r"No positional arguments"): - torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype) + torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype) with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"): - torch._TypedStorage(wrap_storage=s._untyped()) + torch.TypedStorage(wrap_storage=s.untyped()) with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"): - torch._TypedStorage(wrap_storage=s._untyped(), dtype=0) + torch.TypedStorage(wrap_storage=s.untyped(), dtype=0) with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"): - torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device) + torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device) - with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"): - torch._TypedStorage(wrap_storage=s, dtype=dtype) + with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"): + torch.TypedStorage(wrap_storage=s, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"): - torch._TypedStorage(dtype=dtype, device='xla') + torch.TypedStorage(dtype=dtype, device='xla') if torch.cuda.is_available(): if storage_class in quantized_storages: with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"): - torch._TypedStorage(dtype=dtype, device='cuda') + torch.TypedStorage(dtype=dtype, device='cuda') with self.assertRaisesRegex(TypeError, r"Argument type not recognized"): - torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device) + torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device) with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"): - torch._TypedStorage(0, 0, dtype=dtype, device=device) + torch.TypedStorage(0, 0, dtype=dtype, device=device) - if isinstance(s, torch._TypedStorage): - s_other = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) + if isinstance(s, torch.TypedStorage): + s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'cannot set item'): s.fill_(s_other) @@ -8129,19 +8105,32 @@ def test_type_alias(self): for dtype, alias in type_alias_map.items(): self.assertIs(alias, dtype) - # FIXME: Describe this test def test_doc_template(self) -> None: + """ + Test that all public API doc strings use the same standard template for + all common arguments such as tensor or dim + """ from torch._torch_docs import __file__ as doc_file from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args with open(doc_file, "r", encoding="utf-8") as f: doc_strs = f.read() - for doc_str in re.findall(r'add_docstr\((.*?),.*?("""|\'\'\')(.*?)("""|\'\'\')\)', doc_strs, re.MULTILINE | re.DOTALL): + matches = re.findall( + r'add_docstr\(([^,]+?),[^"\']*?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')(?:\.|,?[^,\)]*?\))', + doc_strs, + re.MULTILINE | re.DOTALL, + ) + self.assertTrue(matches) + + for m in matches: + func = m[0].strip() + desc = m[1].strip() + for common_args in [multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args]: for k, v in common_args.items(): - self.assertNotIn(v, doc_str[2], 'The argument description "{}" in {} can be ' - 'replaced by {{{}}}'.format(v, doc_str[0], k)) + self.assertNotIn(v, desc, 'The argument description "{}" in {} can be ' + 'replaced by {{{}}}'.format(v, func, k)) def test_doc(self): checked_types = (types.MethodType, types.FunctionType, @@ -8297,6 +8286,7 @@ class SlotTensor2(SlotTensor1): self.assertTrue(m1[0]) self.assertTrue(m2[0]) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_tensor_dict_dealloc(self): m, t = Tracker.make() x = torch.empty(2) @@ -8318,6 +8308,7 @@ def __del__(self): del fin_tensor self.assertTrue(m[0]) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_tensor_weakref_dealloc(self): x = torch.empty(2) @@ -8331,6 +8322,7 @@ def cb(r): self.assertTrue(m[0]) self.assertEqual(wref(), None) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_tensor_cycle_via_dict(self): m1, t1 = Tracker.make() x = torch.empty(2) @@ -8406,6 +8398,7 @@ def __del__(self): self.assertTrue(m2[0]) # FIXME: move to test_autograd? + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") def test_backward_hooks_traverse(self): m1, t1 = Tracker.make() m2, t2 = Tracker.make() @@ -8431,6 +8424,7 @@ def test_backward_hooks_traverse(self): self.assertTrue(m1[0]) self.assertTrue(m2[0]) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") def test_dead_weak_ref(self): x = torch.empty(2) w_x = weakref.ref(x) diff --git a/test/test_transformers.py b/test/test_transformers.py index 19670f418313b..68fb89697a4a0 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1,19 +1,84 @@ # Owner(s): ["module: nn"] +import contextlib import torch +import torch.nn as nn +import torch.nn.functional as F import unittest from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import TEST_FAIRSEQ, parametrize, instantiate_parametrized_tests +from torch.testing._internal.common_utils import ( + TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state) from torch.testing._internal.common_cuda import TEST_CUDA if TEST_FAIRSEQ: import fairseq.models.transformer as fairseq_transformer +@contextlib.contextmanager +def set_default_dtype(dtype): + saved_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(saved_dtype) + class TestTransformers(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True + device_list = ['cpu'] # TODO: is there a way to do parametrize for this? + if TEST_CUDA: + device_list.append('cuda') + + @unittest.skip("4D mask not supported yet - activate when 4D mask supported") + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") # TODO: make this work for both cuda and cpu + def test_self_attn_TxT_attn_mask(self): + embed_dim = 16 + num_heads = 4 + batch_size = 10 + tgt_len = 16 + + query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda") # [N, T, D] + attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float() # [T, T] + attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) + + attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len) + + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda() + mta_model.eval() + + # Generate 3D results + with torch.inference_mode(): + output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0] + output_mask_4d = output_mask_4d.transpose(0, 1) # [N, T, D] + + output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0] + output_mask_TxT = output_mask_TxT.transpose(0, 1) # [N, T, D] + + self.assertEqual(output_mask_4d, output_mask_TxT) + + @parametrize("device", device_list) + def test_transformerencoderlayer_src_mask(self, device): + batch_size = 2 + seqlen = 4 + d_model = 8 + nhead = 8 + dim_feedforward = 32 + + model = torch.nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True).to(device) + src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model + src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) + + model(src, src_mask=src_mask) + model.eval() + with torch.no_grad(): + model(src, src_mask=src_mask) + @parametrize("use_torchscript", [True, False]) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @@ -39,12 +104,252 @@ def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_ mask = torch.Tensor([[0, 1]]).to(torch.bool) if with_no_grad: - with torch.no_grad(): - model(x, src_key_padding_mask=mask) + cm = torch.no_grad() else: + cm = contextlib.nullcontext() + with cm: model(x, src_key_padding_mask=mask) - @unittest.skipIf(not TEST_FAIRSEQ, "numpy not found") + @parametrize("with_no_grad", [True, False]) + @parametrize("training", [True, False]) + @parametrize("enable_nested_tensor", [False]) + @parametrize("device", device_list) + def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device): + """ + Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has + batch size == sequence length + """ + model = torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True), + num_layers=2, + enable_nested_tensor=enable_nested_tensor + ).to(device) + + with torch.no_grad(): + # set constant weights of the model + for idx, p in enumerate(model.parameters()): + x = p.data + sz = x.view(-1).size(0) + shape = x.shape + x = torch.cos(torch.arange(0, sz).float().view(shape)) + p.data.copy_(x) + + if training: + model = model.train() + else: + model = model.eval() + x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device) + src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device) + + if with_no_grad: + cm = torch.no_grad() + else: + cm = contextlib.nullcontext() + with cm: + result = model(x, mask=src_mask) + + ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351], + [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]], + [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689], + [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]] + ).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + @parametrize("batch_first", [True, False]) + @parametrize("training", [True, False]) + @parametrize("enable_nested_tensor", [True, False]) + @parametrize("device", device_list) + def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device): + def get_a_test_layer(activation, batch_first=False): + d_model = 4 + nhead = 2 + dim_feedforward = 16 + dropout = 0.0 + + layer = nn.TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + batch_first=batch_first, + ).to(device) + + with torch.no_grad(): + # set constant weights of the model + for idx, p in enumerate(layer.parameters()): + x = p.data + sz = x.view(-1).size(0) + shape = x.shape + x = torch.cos(torch.arange(0, sz).float().view(shape)) + p.data.copy_(x) + + return layer + + # this is a deterministic test for TransformerEncoder + activation = F.relu + + def _test(batch_first, training, enable_nested_tensor): + def perm_fn(x): + return x.transpose(1, 0) if batch_first else x + + encoder_layer = get_a_test_layer(activation=activation, + batch_first=batch_first) + + model = nn.TransformerEncoder(encoder_layer, 1).to(device) + if not training: + model = model.eval() + + # deterministic input + encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], + [0.5387, 0.1655, 0.3565, 0.0471]], + [[0.8335, 0.2799, 0.5031, 0.2947], + [0.1402, 0.0318, 0.7636, 0.1346]], + [[0.6333, 0.9344, 0.1376, 0.9938], + [0.8924, 0.2872, 0.6692, 0.2944]], + [[0.9897, 0.6915, 0.3154, 0.1733], + [0.8645, 0.3513, 0.3064, 0.0767]], + [[0.8117, 0.2366, 0.4838, 0.7881], + [0.3718, 0.4945, 0.9511, 0.0864]]] + )).to(device) + result = model(encoder_input) + ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249], + [2.427987, 0.021213, -0.602496, -0.084103]], + [[2.424689, 0.019155, -0.604793, -0.085672], + [2.413863, 0.022211, -0.612486, -0.072490]], + [[2.433774, 0.021598, -0.598343, -0.087548], + [2.425104, 0.019748, -0.604515, -0.084839]], + [[2.436185, 0.022682, -0.596625, -0.087261], + [2.433556, 0.021891, -0.598509, -0.086832]], + [[2.416246, 0.017512, -0.610712, -0.082961], + [2.422901, 0.024187, -0.606178, -0.074929]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + # all 0 src_mask + src_mask = torch.zeros([5, 5]).to(device) == 1 + result = model(encoder_input, mask=src_mask) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + # all 0 + mask = torch.zeros([2, 5]).to(device) == 1 + result = model(encoder_input, src_key_padding_mask=mask) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + mask[0, 1] = 1 + mask[1, 3] = 1 + mask[1, 4] = 1 + # If mask is not left aligned + # We disable nested tensor + model.enable_nested_tensor = enable_nested_tensor + result = model(encoder_input, src_key_padding_mask=mask) + ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642], + [2.428811, 0.021445, -0.601912, -0.084252]], + [[2.425009, 0.019155, -0.604566, -0.085899], + [2.415408, 0.02249, -0.611415, -0.073]], + [[2.434199, 0.021682, -0.598039, -0.087699], + [2.42598, 0.019941, -0.603896, -0.085091]], + [[2.436457, 0.022736, -0.59643, -0.08736], + [2.434021, 0.022093, -0.598179, -0.08679]], + [[2.416531, 0.017498, -0.610513, -0.083181], + [2.4242, 0.024653, -0.605266, -0.074959]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + # test case 2, multiple layers no norm + model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device) + if not training: + model = model.eval() + result = model(encoder_input, src_key_padding_mask=mask) + ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003], + [2.419102, 0.017452, -0.608703, -0.085026]], + [[2.419043, 0.017445, -0.608744, -0.084999], + [2.419052, 0.017446, -0.608738, -0.085004]], + [[2.419067, 0.017448, -0.608727, -0.085010], + [2.419098, 0.017452, -0.608706, -0.085024]], + [[2.419072, 0.017449, -0.608724, -0.085012], + [2.419119, 0.017455, -0.608691, -0.085034]], + [[2.419019, 0.017442, -0.608761, -0.084989], + [2.419075, 0.017449, -0.608722, -0.085014]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device) + if not training: + model = model.eval() + result = model(encoder_input, src_key_padding_mask=mask) + ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025], + [2.419101, 0.017453, -0.608704, -0.085025]], + [[2.419101, 0.017453, -0.608703, -0.085025], + [2.419101, 0.017453, -0.608704, -0.085025]], + [[2.419101, 0.017453, -0.608703, -0.085025], + [2.419101, 0.017453, -0.608704, -0.085025]], + [[2.419101, 0.017453, -0.608703, -0.085025], + [2.419101, 0.017453, -0.608704, -0.085025]], + [[2.419101, 0.017453, -0.608703, -0.085025], + [2.419101, 0.017453, -0.608704, -0.085025]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + # test case 3, multiple layers with norm + # d_model = 4 + norm = nn.LayerNorm(4) + model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device) + if not training: + model = model.eval() + result = model(encoder_input, src_key_padding_mask=mask) + ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238], + [1.695955, -0.357639, -0.893050, -0.445266]], + [[1.695948, -0.357634, -0.893082, -0.445233], + [1.695950, -0.357635, -0.893077, -0.445238]], + [[1.695951, -0.357636, -0.893069, -0.445246], + [1.695955, -0.357639, -0.893052, -0.445264]], + [[1.695952, -0.357636, -0.893066, -0.445249], + [1.695957, -0.357641, -0.893041, -0.445276]], + [[1.695946, -0.357632, -0.893095, -0.445220], + [1.695952, -0.357637, -0.893065, -0.445251]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device) + if not training: + model = model.eval() + result = model(encoder_input, src_key_padding_mask=mask) + ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265], + [1.695955, -0.357639, -0.893051, -0.445265]], + [[1.695955, -0.357639, -0.893051, -0.445265], + [1.695955, -0.357639, -0.893051, -0.445265]], + [[1.695955, -0.357639, -0.893051, -0.445265], + [1.695955, -0.357639, -0.893051, -0.445265]], + [[1.695955, -0.357639, -0.893051, -0.445265], + [1.695955, -0.357639, -0.893051, -0.445265]], + [[1.695955, -0.357639, -0.893051, -0.445265], + [1.695955, -0.357639, -0.893051, -0.445265]]] + )).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + + # TODO: remove set default dtype to double by making ref_output more precise. + # Added because this test was copied from test_nn.py, which has default + # dtype double. If default dtype is float, tests will say tensors not close because + # ref output precision too low + with set_default_dtype(torch.double): + if training: + cm = contextlib.nullcontext() + else: + cm = torch.no_grad() # transformer fast path requires no grad + with cm: + _test(batch_first, training, enable_nested_tensor) + + @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found") @unittest.skipIf(not TEST_CUDA, 'CUDA not available') def test_decoder_only_layer(self): DEFAULT_PADDING_IDX = 0 @@ -346,4 +651,83 @@ def set_weights_deterministic(model): self.assertEqual(result.shape, ref_output.shape) torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) + @parametrize("input_dim,attn_mask_dim,is_causal", + [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), + (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], + name_fn=lambda input_dim, attn_dim, is_causal: ( + f"{input_dim}D_input_dim_" + ( + f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask" + if attn_dim is not None else "no_attn_mask"))) + @parametrize("dropout_p", [0.0, 0.2, 0.5]) + @parametrize("device", device_list) + def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p): + # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used. + dtypes = [torch.double, torch.float] + for dtype in dtypes: + + def rand_tensor(*shape): + return torch.randn(shape, device=device, dtype=dtype) + + # This test compares python and C++ implementations of SDP. + N, N_prime, L, S, E = 5, 2, 4, 3, 6 + if input_dim == 3: + query = rand_tensor(N, L, E) + key = rand_tensor(N, S, E) + value = rand_tensor(N, S, E) + elif input_dim == 4: + query = rand_tensor(N, N_prime, L, E) + key = rand_tensor(N, N_prime, S, E) + value = rand_tensor(N, N_prime, S, E) + else: + self.fail(f'Invalid input_dim {input_dim} encountered in SDP test') + + attn_mask = None + if attn_mask_dim is not None: + assert attn_mask_dim in [2, input_dim] + mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S)) + attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal + else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool)) + + with freeze_rng_state(): + # Python impl only supports float mask and 3D inputs. + attn_mask_float = attn_mask + if attn_mask_float is not None: + attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype) + attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf")) + q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E) + a = attn_mask_float + if a is not None and attn_mask_dim > 3: + a = a.view(-1, L, S) + expected = F._scaled_dot_product_attention( + q, k, v, attn_mask=a, dropout_p=dropout_p) + if input_dim > 3: + expected = (expected[0].view(-1, N_prime, L, E), expected[1].view(-1, N_prime, L, S)) + + need_attn_weights: bool = True + with freeze_rng_state(): + if is_causal: + # NB: Don't pass attn_mask here + actual = torch.ops.aten._scaled_dot_product_attention( + query, key, value, None, dropout_p, need_attn_weights, is_causal) + + # Error case: both explicit attn_mask and is_causal are set + with self.assertRaisesRegex(RuntimeError, + "Explicit attn_mask should not be set when is_causal=True"): + torch.ops.aten._scaled_dot_product_attention( + query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) + else: + actual = torch.ops.aten._scaled_dot_product_attention( + query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) + + # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. + # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. + if dropout_p == 0.0 or device == 'cpu': + self.assertEqual(actual, expected) + + +# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for +# cross device / dtype testing. instantiate_parametrized_tests(TestTransformers) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index d7611eddf8228..25190f8976ccc 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -1,6 +1,6 @@ # Owner(s): ["module: type promotion"] -from functools import wraps +from functools import (partial, wraps) import itertools import unittest @@ -816,8 +816,9 @@ def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device, coalesced): suffix = '_' if inplace else '' err = "{} {}({}, {})".format(" coalesced" if coalesced else "uncoalesced", op_name + suffix, dtype1, dtype2) - def op(t1, t2): - return getattr(t1, op_name + suffix)(t2) + def op(t1, t2, suf=None): + suf = suffix if suf is None else suf + return getattr(t1, op_name + suf)(t2) add_sub = op_name == 'add' or op_name == 'sub' @@ -852,21 +853,25 @@ def op(t1, t2): self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense()) # Test op(dense, sparse) - if add_sub: + if add_sub or op_name == 'mul': if inplace: e, d1, s1, d2, s2 = [x.clone() for x in test_tensors] dense_sparse = op(d1, s2) + dense_sparse = dense_sparse.to_dense() if dense_sparse.is_sparse else dense_sparse self.assertEqual(e, dense_sparse, atol=precision, rtol=rtol, msg=err) else: # sparse division only supports division by a scalar # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz' self.assertRaises(RuntimeError, lambda: op(d1, s2)) - # Test op(sparse, dense) not supported for any ops: + # Test op(sparse, dense) not supported for all ops but 'mul'. # add(sparse, dense) is not supported. Use add(dense, sparse) instead. # sparse division only supports division by a scalar - # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'. - self.assertRaises(RuntimeError, lambda: op(s1, d2)) + if op_name != 'mul': + self.assertRaises(RuntimeError, lambda: op(s1, d2)) + else: + # No type promotions for inplace operations, hence suf='' + op(s1, d2, suf='') # Test op(sparse, scalar) if not add_sub and not (self.device_type == 'cpu' and dtype1 == torch.half): @@ -932,6 +937,53 @@ def test_integer_addcdiv_deprecated(self, device, dtype): with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported+'): t.addcdiv_(t, t) + def _ternary_promotion_common(self, device, op1, op2): + make_arg = partial(make_tensor, device=device) + + types = ( + (torch.float64, torch.float64, torch.complex128), + (torch.long, torch.bfloat16, torch.float32), + ) + + for type1, type2, type3 in types: + arg1 = make_arg([5, 5], dtype=type1) + arg2 = make_arg([5, 5], dtype=type2) + arg3 = make_arg([1, 5], dtype=type3) + + res1 = op1(arg1, arg2, arg3) + res2 = op2(arg1, arg2, arg3) + + # res1 and res2 are not guaranteed to be the same. They are the + # same when all the inputs are tensors with one or more dimensions. + self.assertEqual(res1, res2) + self.assertEqual(res1.dtype, res2.dtype) + + # Fails on XLA: + # https://github.com/pytorch/pytorch/pull/74234#issuecomment-1117169366 + # https://github.com/pytorch/xla/issues/3551 + @onlyNativeDeviceTypes + def test_addcdiv_promotion(self, device): + def op1(arg1, arg2, arg3): + return torch.addcdiv(arg1, arg2, arg3) + + def op2(arg1, arg2, arg3): + return arg1 + arg2 / arg3 + + self._ternary_promotion_common(device, op1, op2) + + # Fails on XLA: + # https://github.com/pytorch/pytorch/pull/74234#issuecomment-1117169366 + # https://github.com/pytorch/xla/issues/3551 + @onlyNativeDeviceTypes + def test_addcmul_promotion(self, device): + def op1(arg1, arg2, arg3): + return torch.addcmul(arg1, arg2, arg3) + + def op2(arg1, arg2, arg3): + return arg1 + arg2 * arg3 + + self._ternary_promotion_common(device, op1, op2) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @float_double_default_dtype @onlyCPU diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 776e2591a6c9e..6c65457ae24f1 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1,5 +1,4 @@ # Owner(s): ["module: tests"] - import torch import numpy as np @@ -11,7 +10,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, - numpy_to_torch_dtype_dict, + numpy_to_torch_dtype_dict, skipIfTorchDynamo ) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) @@ -127,6 +126,7 @@ def test_conj_self(self, device, dtype): s = t.conj() self.assertTrue(s is t) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) def test_view_dtype_new(self, device, dtype): @@ -475,6 +475,8 @@ def test_select_view(self, device) -> None: v[0] = 0 self.assertEqual(t[2, 0], v[0]) + # Lazy hasn't implemented unbind yet. + @onlyNativeDeviceTypes def test_unbind_view(self, device) -> None: t = torch.zeros((5, 5), device=device) tup = torch.unbind(t) @@ -505,6 +507,9 @@ def test_unbind(self): stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True) gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True) + # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken, + # causing asserts to trigger. + @onlyNativeDeviceTypes def test_expand_view(self, device) -> None: t = torch.ones((5, 1), device=device) v = t.expand(5, 5) @@ -718,6 +723,8 @@ def test_contiguous_self(self, device): self.assertTrue(s is t) @skipMeta + # self.is_view_of reports false positives for lazy + @onlyNativeDeviceTypes def test_contiguous_nonview(self, device): t = torch.ones(5, 5, device=device) nv = t.t().contiguous() @@ -744,6 +751,8 @@ def test_reshape_as_view(self, device): self.assertEqual(t[1, 1], v[6]) @skipMeta + # self.is_view_of reports false positives for lazy + @onlyNativeDeviceTypes def test_reshape_nonview(self, device): t = torch.ones(5, 5, device=device) nv = torch.reshape(t.t(), (25,)) @@ -752,6 +761,9 @@ def test_reshape_nonview(self, device): nv[6] = 0 self.assertNotEqual(t[1, 1], nv[6]) + # This test use as_strided to construct a tensor with overlapping memory, + # which is not handled by the functionalization pass. + @onlyNativeDeviceTypes def test_flatten_view(self, device): def test_writes_propagate(t, v): idx_t = (0,) * t.ndim @@ -1016,6 +1028,7 @@ def test_empty_reshape(self, device): # match NumPy semantics -- don't infer the size of dimension with a degree of freedom self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_expand(self, device): tensor = torch.rand(1, 8, 1, device=device) tensor2 = torch.rand(5, device=device) @@ -1257,6 +1270,7 @@ def test_chunk(self, device): tensor.chunk(-2) # TODO: make work on CUDA, too + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @onlyCPU def test_unsqueeze(self, device) -> None: x = torch.randn(2, 3, 4) @@ -1820,7 +1834,7 @@ def test_crow_col_indices(self, device): t.crow_indices() t.col_indices() -instantiate_device_type_tests(TestViewOps, globals()) +instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True) instantiate_device_type_tests(TestOldViewOps, globals()) if __name__ == '__main__': diff --git a/third_party/BUCK.oss b/third_party/BUCK.oss index 246b05b6e6040..3d30464f57676 100644 --- a/third_party/BUCK.oss +++ b/third_party/BUCK.oss @@ -1,13 +1,28 @@ load("//third_party:glog.buck.bzl", "define_glog") load("//third_party:xnnpack.buck.bzl", "define_xnnpack") load("//third_party:kineto.buck.bzl", "define_kineto") +load("//:buckbuild.bzl", "third_party") define_glog() -define_xnnpack() +define_xnnpack(third_party) define_kineto() +# a placeholder for libraries that are not implemented in OSS +cxx_library( + name = "no-op", + visibility = ['PUBLIC'], +) + +cxx_library( + name = "rt", + exported_platform_linker_flags = [ + ("^linux-.*$", ["-lrt"]), + ], + visibility = ['PUBLIC'], +) + cxx_library( name = "fmt", srcs = ['fmt/src/format.cc'], @@ -60,6 +75,16 @@ cxx_library( visibility = ["PUBLIC"], ) +cxx_library( + name = "pocketfft_header", + header_namespace = "", + exported_headers = { + "pocketfft_hdronly.h": "pocketfft/pocketfft_hdronly.h", + }, + reexport_all_header_dependencies = True, + visibility = ["PUBLIC"], +) + cxx_library( name = "FXdiv", header_namespace = "", @@ -165,9 +190,9 @@ cxx_library( cxx_library( name = "miniz", - srcs = ["miniz-2.0.8/miniz.c"], + srcs = ["miniz-2.1.0/miniz.c"], header_namespace = "", - exported_headers = {"miniz.h": "miniz-2.0.8/miniz.h"}, + exported_headers = {"miniz.h": "miniz-2.1.0/miniz.h"}, exported_preprocessor_flags = [ "-DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS", ], @@ -315,3 +340,73 @@ cxx_binary( visibility = ["PUBLIC"], deps = [":flatc_library"], ) + +cxx_library( + name = "gtest_headers", + exported_preprocessor_flags = [ + "-DGTEST_USE_OWN_TR1_TUPLE=0", + "-DGTEST_HAS_TR1_TUPLE=0", + "-D_CRT_DECLARE_NONSTDC_NAMES", + "-D_CRT_NONSTDC_NO_WARNINGS", + "-D_CRT_NONSTDC_NO_DEPRECATE", + ], + include_directories = [ + "googletest/googletest", + ], + public_system_include_directories = [ + "googletest/googletest/include", + ], + raw_headers = glob([ + "googletest/googletest/src/**/*.h", + "googletest/googletest/include/**/*.h", + ]), + visibility = [ + "PUBLIC", + ], +) + +cxx_library( + name = "gtest", + srcs = [ + "googletest/googletest/src/gtest-all.cc", + "googletest/googletest/src/gtest_main.cc", + ], + include_directories = [ + "googletest/googletest", + ], + raw_headers = glob([ + "googletest/googletest/src/**/*.cc", + "googletest/googletest/src/**/*.h", + ]), + visibility = [ + "PUBLIC", + ], + xcode_public_headers_symlinks = True, + exported_deps = [ + ":gtest_headers", + ], +) + +cxx_library( + name = "gmock", + srcs = [ + "googletest/googlemock/src/gmock-all.cc", + ], + include_directories = [ + "googletest/googlemock", + ], + public_system_include_directories = [ + "googletest/googlemock/include", + ], + raw_headers = glob([ + "googletest/googlemock/include/**/*.h", + "googletest/googlemock/src/**/*.cc", + ]), + visibility = ["PUBLIC"], + deps = [ + ":gtest", + ], + exported_deps = [ + ":gtest_headers", + ], +) diff --git a/third_party/LICENSES_BUNDLED.txt b/third_party/LICENSES_BUNDLED.txt index 9b61374c0aa76..d03c1c2137e8b 100644 --- a/third_party/LICENSES_BUNDLED.txt +++ b/third_party/LICENSES_BUNDLED.txt @@ -4,361 +4,301 @@ compatibly licensed. We list these here. Name: FP16 License: MIT Files: third_party/FP16 - For details, see third_party/FP16/LICENSE - -Name: FP16-source -License: MIT -Files: third_party/XNNPACK/build/FP16-source - For details, see third_party/XNNPACK/build/FP16-source/LICENSE + For details, see: third_party/FP16/LICENSE Name: FXdiv License: MIT Files: third_party/FXdiv - For details, see third_party/FXdiv/LICENSE - -Name: FXdiv-source -License: MIT -Files: third_party/XNNPACK/build/FXdiv-source - For details, see third_party/XNNPACK/build/FXdiv-source/LICENSE + For details, see: third_party/FXdiv/LICENSE Name: NNPACK License: BSD-2-Clause Files: third_party/NNPACK - For details, see third_party/NNPACK/LICENSE + For details, see: third_party/NNPACK/LICENSE Name: QNNPACK License: BSD-3-Clause Files: third_party/QNNPACK - For details, see third_party/QNNPACK/LICENSE + For details, see: third_party/QNNPACK/LICENSE Name: XNNPACK License: BSD-3-Clause Files: third_party/XNNPACK - For details, see third_party/XNNPACK/LICENSE + For details, see: third_party/XNNPACK/LICENSE Name: benchmark License: Apache-2.0 Files: third_party/benchmark, - third_party/onnx/third_party/benchmark, + third_party/protobuf/third_party/benchmark, third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark, - third_party/protobuf/third_party/benchmark - For details, see third_party/benchmark/LICENSE, - third_party/onnx/third_party/benchmark/LICENSE, + third_party/onnx/third_party/benchmark + For details, see: third_party/benchmark/LICENSE, + third_party/protobuf/third_party/benchmark/LICENSE, third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark/LICENSE, - third_party/protobuf/third_party/benchmark/LICENSE - -Name: breakpad -License: BSD-3-Clause -Files: third_party/breakpad - For details, see third_party/breakpad/LICENSE + third_party/onnx/third_party/benchmark/LICENSE Name: clog License: BSD-2-Clause -Files: third_party/QNNPACK/deps/clog, - third_party/XNNPACK/build/clog-source/deps/clog, - third_party/XNNPACK/build/cpuinfo-source/deps/clog, - third_party/cpuinfo/deps/clog, - third_party/fbgemm/third_party/cpuinfo/deps/clog - For details, see third_party/QNNPACK/deps/clog/LICENSE, - third_party/XNNPACK/build/clog-source/deps/clog/LICENSE, - third_party/XNNPACK/build/cpuinfo-source/deps/clog/LICENSE, - third_party/cpuinfo/deps/clog/LICENSE, - third_party/fbgemm/third_party/cpuinfo/deps/clog/LICENSE - -Name: clog-source -License: BSD-2-Clause -Files: third_party/XNNPACK/build/clog-source - For details, see third_party/XNNPACK/build/clog-source/LICENSE +Files: third_party/cpuinfo/deps/clog, + third_party/fbgemm/third_party/cpuinfo/deps/clog, + third_party/QNNPACK/deps/clog + For details, see: third_party/cpuinfo/deps/clog/LICENSE, + third_party/fbgemm/third_party/cpuinfo/deps/clog/LICENSE, + third_party/QNNPACK/deps/clog/LICENSE + +Name: cpplint +License: BSD-3-Clause +Files: third_party/nlohmann/tools/cpplint + For details, see: third_party/nlohmann/tools/cpplint/LICENSE Name: cpuinfo License: BSD-2-Clause Files: third_party/cpuinfo, third_party/fbgemm/third_party/cpuinfo - For details, see third_party/cpuinfo/LICENSE, + For details, see: third_party/cpuinfo/LICENSE, third_party/fbgemm/third_party/cpuinfo/LICENSE -Name: cpuinfo-source -License: BSD-2-Clause -Files: third_party/XNNPACK/build/cpuinfo-source - For details, see third_party/XNNPACK/build/cpuinfo-source/LICENSE - Name: cudnn_frontend License: MIT Files: third_party/cudnn_frontend - For details, see third_party/cudnn_frontend/LICENSE.txt + For details, see: third_party/cudnn_frontend/LICENSE.txt Name: dart License: Apache-2.0 Files: third_party/flatbuffers/dart - For details, see third_party/flatbuffers/dart/LICENSE + For details, see: third_party/flatbuffers/dart/LICENSE + +Name: doctest +License: MIT +Files: third_party/nlohmann/tests/thirdparty/doctest + For details, see: third_party/nlohmann/tests/thirdparty/doctest/LICENSE.txt Name: eigen License: BSD-3-Clause Files: third_party/eigen - For details, see third_party/eigen/COPYING.BSD + For details, see: third_party/eigen/COPYING.BSD Name: enum License: BSD-3-Clause Files: third_party/python-enum/enum - For details, see third_party/python-enum/enum/LICENSE + For details, see: third_party/python-enum/enum/LICENSE Name: fbgemm License: BSD-3-Clause Files: third_party/fbgemm - For details, see third_party/fbgemm/LICENSE + For details, see: third_party/fbgemm/LICENSE Name: flatbuffers License: Apache-2.0 Files: third_party/flatbuffers - For details, see third_party/flatbuffers/LICENSE.txt + For details, see: third_party/flatbuffers/LICENSE.txt Name: fmt License: MIT with exception -Files: third_party/fmt, - third_party/kineto/libkineto/third_party/fmt - For details, see third_party/fmt/LICENSE.rst, - third_party/kineto/libkineto/third_party/fmt/LICENSE.rst +Files: third_party/kineto/libkineto/third_party/fmt, + third_party/fmt + For details, see: third_party/kineto/libkineto/third_party/fmt/LICENSE.rst, + third_party/fmt/LICENSE.rst Name: foxi License: MIT Files: third_party/foxi - For details, see third_party/foxi/LICENSE + For details, see: third_party/foxi/LICENSE Name: gemmlowp License: Apache-2.0 Files: third_party/gemmlowp/gemmlowp - For details, see third_party/gemmlowp/gemmlowp/LICENSE + For details, see: third_party/gemmlowp/gemmlowp/LICENSE Name: generator License: Apache-2.0 -Files: third_party/XNNPACK/build/googletest-source/googlemock/scripts/generator, - third_party/benchmark/build/third_party/googletest/src/googlemock/scripts/generator, - third_party/fbgemm/third_party/googletest/googlemock/scripts/generator, +Files: third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator, third_party/googletest/googlemock/scripts/generator, - third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator, + third_party/fbgemm/third_party/googletest/googlemock/scripts/generator, third_party/protobuf/third_party/googletest/googlemock/scripts/generator, third_party/tensorpipe/third_party/googletest/googlemock/scripts/generator - For details, see third_party/XNNPACK/build/googletest-source/googlemock/scripts/generator/LICENSE, - third_party/benchmark/build/third_party/googletest/src/googlemock/scripts/generator/LICENSE, - third_party/fbgemm/third_party/googletest/googlemock/scripts/generator/LICENSE, + For details, see: third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator/LICENSE, third_party/googletest/googlemock/scripts/generator/LICENSE, - third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator/LICENSE, + third_party/fbgemm/third_party/googletest/googlemock/scripts/generator/LICENSE, third_party/protobuf/third_party/googletest/googlemock/scripts/generator/LICENSE, third_party/tensorpipe/third_party/googletest/googlemock/scripts/generator/LICENSE Name: gloo License: BSD-3-Clause Files: third_party/gloo - For details, see third_party/gloo/LICENSE - -Name: googlebenchmark-source -License: Apache-2.0 -Files: third_party/XNNPACK/build/googlebenchmark-source - For details, see third_party/XNNPACK/build/googlebenchmark-source/LICENSE + For details, see: third_party/gloo/LICENSE Name: googlemock License: BSD-3-Clause -Files: third_party/XNNPACK/build/googletest-source/googlemock, +Files: third_party/kineto/libkineto/third_party/googletest/googlemock, third_party/fbgemm/third_party/googletest/googlemock, - third_party/kineto/libkineto/third_party/googletest/googlemock, third_party/protobuf/third_party/googletest/googlemock, third_party/tensorpipe/third_party/googletest/googlemock - For details, see third_party/XNNPACK/build/googletest-source/googlemock/LICENSE, + For details, see: third_party/kineto/libkineto/third_party/googletest/googlemock/LICENSE, third_party/fbgemm/third_party/googletest/googlemock/LICENSE, - third_party/kineto/libkineto/third_party/googletest/googlemock/LICENSE, third_party/protobuf/third_party/googletest/googlemock/LICENSE, third_party/tensorpipe/third_party/googletest/googlemock/LICENSE Name: googletest License: BSD-3-Clause -Files: third_party/XNNPACK/build/googletest-source/googletest, +Files: third_party/kineto/libkineto/third_party/googletest, + third_party/kineto/libkineto/third_party/googletest/googletest, + third_party/googletest, third_party/fbgemm/third_party/googletest, third_party/fbgemm/third_party/googletest/googletest, - third_party/googletest, - third_party/kineto/libkineto/third_party/googletest, - third_party/kineto/libkineto/third_party/googletest/googletest, third_party/protobuf/third_party/googletest, third_party/protobuf/third_party/googletest/googletest, third_party/tensorpipe/third_party/googletest, third_party/tensorpipe/third_party/googletest/googletest - For details, see third_party/XNNPACK/build/googletest-source/googletest/LICENSE, + For details, see: third_party/kineto/libkineto/third_party/googletest/LICENSE, + third_party/kineto/libkineto/third_party/googletest/googletest/LICENSE, + third_party/googletest/LICENSE, third_party/fbgemm/third_party/googletest/LICENSE, third_party/fbgemm/third_party/googletest/googletest/LICENSE, - third_party/googletest/LICENSE, - third_party/kineto/libkineto/third_party/googletest/LICENSE, - third_party/kineto/libkineto/third_party/googletest/googletest/LICENSE, third_party/protobuf/third_party/googletest/LICENSE, third_party/protobuf/third_party/googletest/googletest/LICENSE, third_party/tensorpipe/third_party/googletest/LICENSE, third_party/tensorpipe/third_party/googletest/googletest/LICENSE -Name: googletest-source -License: BSD-3-Clause -Files: third_party/XNNPACK/build/googletest-source - For details, see third_party/XNNPACK/build/googletest-source/LICENSE - Name: gtest License: BSD-3-Clause Files: third_party/ideep/mkl-dnn/tests/gtest, third_party/ideep/mkl-dnn/third_party/oneDNN/tests/gtests/gtest - For details, see third_party/ideep/mkl-dnn/tests/gtest/LICENSE, + For details, see: third_party/ideep/mkl-dnn/tests/gtest/LICENSE, third_party/ideep/mkl-dnn/third_party/oneDNN/tests/gtests/gtest/LICENSE Name: ideep License: MIT Files: third_party/ideep - For details, see third_party/ideep/LICENSE + For details, see: third_party/ideep/LICENSE Name: ios-cmake License: BSD-3-Clause Files: third_party/ios-cmake - For details, see third_party/ios-cmake/LICENSE + For details, see: third_party/ios-cmake/LICENSE Name: json License: MIT Files: third_party/cudnn_frontend/include/contrib/nlohmann/json - For details, see third_party/cudnn_frontend/include/contrib/nlohmann/json/LICENSE.txt + For details, see: third_party/cudnn_frontend/include/contrib/nlohmann/json/LICENSE.txt Name: kineto License: BSD-3-Clause Files: third_party/kineto - For details, see third_party/kineto/LICENSE - -Name: libdisasm -License: Clarified Artistic License -Files: third_party/breakpad/src/third_party/libdisasm - For details, see third_party/breakpad/src/third_party/libdisasm/LICENSE + For details, see: third_party/kineto/LICENSE Name: libnop License: Apache-2.0 Files: third_party/tensorpipe/third_party/libnop - For details, see third_party/tensorpipe/third_party/libnop/LICENSE + For details, see: third_party/tensorpipe/third_party/libnop/LICENSE Name: libuv License: MIT Files: third_party/tensorpipe/third_party/libuv - For details, see third_party/tensorpipe/third_party/libuv/LICENSE - -Name: lss -License: BSD-3-Clause -Files: third_party/breakpad/src/third_party/lss - For details, see third_party/breakpad/src/third_party/lss/LICENSE + For details, see: third_party/tensorpipe/third_party/libuv/LICENSE -Name: miniz-2.0.8 +Name: miniz-2.1.0 License: MIT -Files: third_party/miniz-2.0.8 - For details, see third_party/miniz-2.0.8/LICENSE +Files: third_party/miniz-2.1.0 + For details, see: third_party/miniz-2.1.0/LICENSE Name: mkl-dnn License: Apache-2.0 Files: third_party/ideep/mkl-dnn - For details, see third_party/ideep/mkl-dnn/LICENSE + For details, see: third_party/ideep/mkl-dnn/LICENSE Name: nccl License: BSD-3-Clause Files: third_party/nccl/nccl - For details, see third_party/nccl/nccl/LICENSE.txt + For details, see: third_party/nccl/nccl/LICENSE.txt Name: neon2sse License: BSD-Source-Code Files: third_party/neon2sse - For details, see third_party/neon2sse/LICENSE + For details, see: third_party/neon2sse/LICENSE Name: oneDNN License: Apache-2.0 Files: third_party/ideep/mkl-dnn/third_party/oneDNN - For details, see third_party/ideep/mkl-dnn/third_party/oneDNN/LICENSE - -Name: onnx -License: Apache-2.0 -Files: third_party/onnx - For details, see third_party/onnx/LICENSE + For details, see: third_party/ideep/mkl-dnn/third_party/oneDNN/LICENSE Name: onnx License: MIT Files: third_party/onnx-tensorrt/third_party/onnx - For details, see third_party/onnx-tensorrt/third_party/onnx/LICENSE + For details, see: third_party/onnx-tensorrt/third_party/onnx/LICENSE + +Name: onnx +License: Apache-2.0 +Files: third_party/onnx + For details, see: third_party/onnx/LICENSE Name: onnx-tensorrt License: MIT Files: third_party/onnx-tensorrt - For details, see third_party/onnx-tensorrt/LICENSE + For details, see: third_party/onnx-tensorrt/LICENSE Name: protobuf License: BSD-3-Clause Files: third_party/protobuf - For details, see third_party/protobuf/LICENSE + For details, see: third_party/protobuf/LICENSE Name: psimd License: MIT -Files: third_party/XNNPACK/deps/psimd, - third_party/psimd - For details, see third_party/XNNPACK/deps/psimd/LICENSE, - third_party/psimd/LICENSE +Files: third_party/psimd + For details, see: third_party/psimd/LICENSE Name: pthreadpool License: BSD-2-Clause Files: third_party/pthreadpool - For details, see third_party/pthreadpool/LICENSE - -Name: pthreadpool-source -License: BSD-2-Clause -Files: third_party/XNNPACK/build/pthreadpool-source - For details, see third_party/XNNPACK/build/pthreadpool-source/LICENSE + For details, see: third_party/pthreadpool/LICENSE Name: pybind11 License: BSD-3-Clause -Files: third_party/onnx/third_party/pybind11, +Files: third_party/pybind11, third_party/onnx-tensorrt/third_party/onnx/third_party/pybind11, - third_party/pybind11, + third_party/onnx/third_party/pybind11, third_party/tensorpipe/third_party/pybind11 - For details, see third_party/onnx/third_party/pybind11/LICENSE, + For details, see: third_party/pybind11/LICENSE, third_party/onnx-tensorrt/third_party/onnx/third_party/pybind11/LICENSE, - third_party/pybind11/LICENSE, + third_party/onnx/third_party/pybind11/LICENSE, third_party/tensorpipe/third_party/pybind11/LICENSE Name: python-peachpy License: BSD-2-Clause Files: third_party/python-peachpy - For details, see third_party/python-peachpy/LICENSE.rst + For details, see: third_party/python-peachpy/LICENSE.rst Name: python-six License: MIT Files: third_party/python-six - For details, see third_party/python-six/LICENSE + For details, see: third_party/python-six/LICENSE Name: sleef License: BSL-1.0 Files: third_party/sleef - For details, see third_party/sleef/LICENSE.txt - -Name: src -License: BSD-3-Clause -Files: third_party/benchmark/build/third_party/googletest/src - For details, see third_party/benchmark/build/third_party/googletest/src/LICENSE + For details, see: third_party/sleef/LICENSE.txt Name: swift License: Apache-2.0 Files: third_party/flatbuffers/swift - For details, see third_party/flatbuffers/swift/LICENSE + For details, see: third_party/flatbuffers/swift/LICENSE Name: tb_plugin License: BSD-3-Clause Files: third_party/kineto/tb_plugin - For details, see third_party/kineto/tb_plugin/LICENSE + For details, see: third_party/kineto/tb_plugin/LICENSE Name: tbb License: Apache-2.0 Files: third_party/tbb - For details, see third_party/tbb/LICENSE + For details, see: third_party/tbb/LICENSE Name: tensorpipe License: BSD-3-Clause Files: third_party/tensorpipe - For details, see third_party/tensorpipe/LICENSE.txt + For details, see: third_party/tensorpipe/LICENSE.txt Name: zstd License: BSD-3-Clause Files: third_party/zstd - For details, see third_party/zstd/LICENSE - + For details, see: third_party/zstd/LICENSE \ No newline at end of file diff --git a/third_party/build_bundled.py b/third_party/build_bundled.py index c05e1c3642fe1..4da1b84a6f32e 100644 --- a/third_party/build_bundled.py +++ b/third_party/build_bundled.py @@ -37,23 +37,35 @@ def collect_license(current): return collected -def create_bundled(d, outstream): +def create_bundled(d, outstream, include_files=False): """Write the information to an open outstream""" collected = collect_license(d) sorted_keys = sorted(collected.keys()) outstream.write('The Pytorch repository and source distributions bundle ' 'several libraries that are \n') - outstream.write('compatibly licensed. We list these here.\n\n') + outstream.write('compatibly licensed. We list these here.') + files_to_include = [] for k in sorted_keys: c = collected[k] files = ',\n '.join(c['Files']) license_file = ',\n '.join(c['License_file']) + outstream.write('\n\n') outstream.write(f"Name: {c['Name']}\n") outstream.write(f"License: {c['License']}\n") outstream.write(f"Files: {files}\n") - outstream.write(' For details, see ') + outstream.write(' For details, see') + if include_files: + outstream.write(' the files concatenated below: ') + files_to_include += c['License_file'] + else: + outstream.write(': ') outstream.write(license_file) + for fname in files_to_include: outstream.write('\n\n') + outstream.write(fname) + outstream.write('\n' + '-' * len(fname) + '\n') + with open(fname, 'r') as fid: + outstream.write(fid.read()) def identify_license(f, exception=''): @@ -156,7 +168,7 @@ def squeeze(t): if __name__ == '__main__': - third_party = os.path.join(mydir) + third_party = os.path.relpath(mydir) parser = argparse.ArgumentParser( description="Generate bundled licenses file", ) diff --git a/third_party/fbgemm b/third_party/fbgemm index 2e9be65810107..499cd22f5c2e2 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 2e9be65810107a9595da717f95d21924b73be833 +Subproject commit 499cd22f5c2e26041e4f190f628b48478a89a030 diff --git a/third_party/generate-xnnpack-wrappers.py b/third_party/generate-xnnpack-wrappers.py index 23992645672a9..c1bb51ad9cf5d 100644 --- a/third_party/generate-xnnpack-wrappers.py +++ b/third_party/generate-xnnpack-wrappers.py @@ -3,6 +3,7 @@ from __future__ import print_function import collections import os +import sys BANNER = "Auto-generated by generate-wrappers.py script. Do not modify" WRAPPER_SRC_NAMES = { @@ -53,9 +54,9 @@ "PROD_AVX512SKX_MICROKERNEL_SRCS", ] -def update_sources(): +def update_sources(xnnpack_path): sources = collections.defaultdict(list) - with open("./XNNPACK/CMakeLists.txt") as cmake: + with open(os.path.join(xnnpack_path, "XNNPACK/CMakeLists.txt")) as cmake: lines = cmake.readlines() i = 0 while i < len(lines): @@ -74,17 +75,16 @@ def update_sources(): sources[name].append(value[4:]) else: i += 1 - print(sources) return sources -if __name__ == "__main__": +def gen_wrappers(xnnpack_path): xnnpack_sources = collections.defaultdict(list) - sources = update_sources() + sources = update_sources(xnnpack_path) for name in WRAPPER_SRC_NAMES: xnnpack_sources[WRAPPER_SRC_NAMES[name]].extend(sources[name]) for condition, filenames in xnnpack_sources.items(): for filename in filenames: - filepath = os.path.join("XNNPACK/wrappers", filename) + filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename) if not os.path.isdir(os.path.dirname(filepath)): os.makedirs(os.path.dirname(filepath)) with open(filepath, "w") as wrapper: @@ -102,3 +102,38 @@ def update_sources(): print("#if %s" % condition, file=wrapper) print("#include <%s>" % filename, file=wrapper) print("#endif /* %s */" % condition, file=wrapper) + + # update xnnpack_wrapper_defs.bzl file under the same folder + with open(os.path.join(os.path.dirname(__file__), "xnnpack_wrapper_defs.bzl"), 'w') as wrapper_defs: + print('"""', file=wrapper_defs) + print(BANNER, file=wrapper_defs) + print('"""', file=wrapper_defs) + for name in WRAPPER_SRC_NAMES: + print('\n' + name + ' = [', file=wrapper_defs) + for file_name in sources[name]: + print(' "xnnpack_wrappers/{}",'.format(file_name), file=wrapper_defs) + print(']', file=wrapper_defs) + + # update xnnpack_src_defs.bzl file under the same folder + with open(os.path.join(os.path.dirname(__file__), "xnnpack_src_defs.bzl"), 'w') as src_defs: + print('"""', file=src_defs) + print(BANNER, file=src_defs) + print('"""', file=src_defs) + for name in SRC_NAMES: + print('\n' + name + ' = [', file=src_defs) + for file_name in sources[name]: + print(' "XNNPACK/src/{}",'.format(file_name), file=src_defs) + print(']', file=src_defs) + + +def main(argv): + if argv is None or len(argv) == 0: + gen_wrappers(".") + else: + gen_wrappers(argv[0]) + +# The first argument is the place where the "xnnpack_wrappers" folder will be created. +# Run it without arguments will generate "xnnpack_wrappers" in the current path. +# The two .bzl files will always be generated in the current path. +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/third_party/ittapi b/third_party/ittapi new file mode 160000 index 0000000000000..5b8a7d7422611 --- /dev/null +++ b/third_party/ittapi @@ -0,0 +1 @@ +Subproject commit 5b8a7d7422611c3a0d799fb5fc5dd4abfae35b42 diff --git a/third_party/kineto.BUILD b/third_party/kineto.BUILD new file mode 100644 index 0000000000000..d8e484ae80b6b --- /dev/null +++ b/third_party/kineto.BUILD @@ -0,0 +1,10 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "kineto", + hdrs = glob(["libkineto/include/*.h",]), + includes = [ + "libkineto/include/", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/kineto.buck.bzl b/third_party/kineto.buck.bzl index bd0a40d6ecea5..623d45c0b6029 100644 --- a/third_party/kineto.buck.bzl +++ b/third_party/kineto.buck.bzl @@ -168,3 +168,14 @@ def define_kineto(): ":fmt", ], ) + + cxx_library( + name = "libkineto_headers", + exported_headers = native.glob([ + "kineto/libkineto/include/*.h", + ]), + public_include_directories = [ + "kineto/libkineto/include", + ], + visibility = ["PUBLIC"], + ) diff --git a/third_party/miniz-2.0.8/BUILD.bazel b/third_party/miniz-2.1.0/BUILD.bazel similarity index 100% rename from third_party/miniz-2.0.8/BUILD.bazel rename to third_party/miniz-2.1.0/BUILD.bazel diff --git a/third_party/miniz-2.0.8/ChangeLog.md b/third_party/miniz-2.1.0/ChangeLog.md similarity index 88% rename from third_party/miniz-2.0.8/ChangeLog.md rename to third_party/miniz-2.1.0/ChangeLog.md index 8b44cee2b153b..3ee292d7996d9 100755 --- a/third_party/miniz-2.0.8/ChangeLog.md +++ b/third_party/miniz-2.1.0/ChangeLog.md @@ -1,16 +1,31 @@ ## Changelog +### 2.1.0 + + - More instances of memcpy instead of cast and use memcpy per default + - Remove inline for c90 support + - New function to read files via callback functions when adding them + - Fix out of bounds read while reading Zip64 extended information + - guard memcpy when n == 0 because buffer may be NULL + - Implement inflateReset() function + - Move comp/decomp alloc/free prototypes under guarding #ifndef MZ_NO_MALLOC + - Fix large file support under Windows + - Don't warn if _LARGEFILE64_SOURCE is not defined to 1 + - Fixes for MSVC warnings + - Remove check that path of file added to archive contains ':' or '\' + - Add !defined check on MINIZ_USE_ALIGNED_LOADS_AND_STORES + ### 2.0.8 - Remove unimplemented functions (mz_zip_locate_file and mz_zip_locate_file_v2) - Add license, changelog, readme and example files to release zip - Fix heap overflow to user buffer in tinfl_status tinfl_decompress - - Fix corrupt archive if uncompressed file smaller than 4 byte and file is added by mz_zip_writer_add_mem* + - Fix corrupt archive if uncompressed file smaller than 4 byte and the file is added by mz_zip_writer_add_mem* ### 2.0.7 - Removed need in C++ compiler in cmake build - - Fixed loads of uninitialized value errors found with Valgrind by memsetting m_dict to 0 in tdefl_init. + - Fixed a lot of uninitialized value errors found with Valgrind by memsetting m_dict to 0 in tdefl_init - Fix resource leak in mz_zip_reader_init_file_v2 - Fix assert with mz_zip_writer_add_mem* w/MZ_DEFAULT_COMPRESSION - cmake build: install library and headers @@ -19,7 +34,7 @@ ### 2.0.6 - Improve MZ_ZIP_FLAG_WRITE_ZIP64 documentation - - Remove check for cur_archive_file_ofs > UINT_MAX, because cur_archive_file_ofs is not used after this point + - Remove check for cur_archive_file_ofs > UINT_MAX because cur_archive_file_ofs is not used after this point - Add cmake debug configuration - Fix PNG height when creating png files - Add "iterative" file extraction method based on mz_zip_reader_extract_to_callback. @@ -49,11 +64,12 @@ ### 2.0.0 beta -- Matthew Sitton merged to vogl ZIP64 changes. Miniz is now licensed as MIT since the vogl code base is MIT licensed +- Matthew Sitton merged miniz 1.x to Rich Geldreich's vogl ZIP64 changes. Miniz is now licensed as MIT since the vogl code base is MIT licensed - Miniz is now split into several files - Miniz does now not seek backwards when creating ZIP files. That is the ZIP files can be streamed -- Miniz automatically switches to the ZIP64 format the created ZIP files goes over ZIP file limits +- Miniz automatically switches to the ZIP64 format when the created ZIP files goes over ZIP file limits - Similar to [SQLite](https://www.sqlite.org/amalgamation.html) the Miniz source code is amalgamated into one miniz.c/miniz.h pair in a build step (amalgamate.sh). Please use miniz.c/miniz.h in your projects +- Miniz 2 is only source back-compatible with miniz 1.x. It breaks binary compatibility because structures changed ### v1.16 BETA Oct 19, 2013 @@ -66,7 +82,7 @@ The inflator now has a new failure status TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRE - The inflator coroutine func. is subtle and complex so I'm being cautious about this release. I would greatly appreciate any help with testing or any feedback. I feel good about these changes, and they've been through several hours of automated testing, but they will probably not fix anything for the majority of prev. users so I'm going to mark this release as beta for a few weeks and continue testing it at work/home on various things. -- The inflator in raw (non-zlib) mode is now usable on gzip or similar data streams that have a bunch of bytes following the raw deflate data (problem discovered by rustyzip author williamw520). +- The inflator in raw (non-zlib) mode is now usable on gzip or similiar data streams that have a bunch of bytes following the raw deflate data (problem discovered by rustyzip author williamw520). This version should *never* read beyond the last byte of the raw deflate data independent of how many bytes you pass into the input buffer. This issue was caused by the various Huffman bitbuffer lookahead optimizations, and would not be an issue if the caller knew and enforced the precise size of the raw compressed data *or* if the compressed data was in zlib format (i.e. always followed by the byte aligned zlib adler32). So in other words, you can now call the inflator on deflate streams that are followed by arbitrary amounts of data and it's guaranteed that decompression will stop exactly on the last byte. @@ -87,7 +103,7 @@ Merged over a few very minor bug fixes that I fixed in the zip64 branch. This is Interim bugfix release while I work on the next major release with zip64 and streaming compression/decompression support. Fixed the MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY bug (thanks kahmyong.moon@hp.com), which could cause the locate files func to not find files when this flag was specified. Also fixed a bug in mz_zip_reader_extract_to_mem_no_alloc() with user provided read buffers (thanks kymoon). I also merged lots of compiler fixes from various github repo branches and Google Code issue reports. I finally added cmake support (only tested under for Linux so far), compiled and tested with clang v3.3 and gcc 4.6 (under Linux), added defl_write_image_to_png_file_in_memory_ex() (supports Y flipping for OpenGL use, real-time compression), added a new PNG example (example6.c - Mandelbrot), and I added 64-bit file I/O support (stat64(), etc.) for glibc. - Critical fix for the MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY bug (thanks kahmyong.moon@hp.com) which could cause locate files to not find files. This bug - would only have occurred in earlier versions if you explicitly used this flag, OR if you used mz_zip_extract_archive_file_to_heap() or mz_zip_add_mem_to_archive_file_in_place() + would only have occured in earlier versions if you explicitly used this flag, OR if you used mz_zip_extract_archive_file_to_heap() or mz_zip_add_mem_to_archive_file_in_place() (which used this flag). If you can't switch to v1.15 but want to fix this bug, just remove the uses of this flag from both helper funcs (and of course don't use the flag). - Bugfix in mz_zip_reader_extract_to_mem_no_alloc() from kymoon when pUser_read_buf is not NULL and compressed size is > uncompressed size - Fixing mz_zip_reader_extract_*() funcs so they don't try to extract compressed data from directory entries, to account for weird zipfiles which contain zero-size compressed data on dir entries. @@ -104,7 +120,7 @@ Interim bugfix release while I work on the next major release with zip64 and str - Retested this build under Windows (VS 2010, including static analysis), tcc 0.9.26, gcc v4.6 and clang v3.3. - Added example6.c, which dumps an image of the mandelbrot set to a PNG file. - Modified example2 to help test the MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY flag more. -- In r3: Bugfix to mz_zip_writer_add_file() found during merge: Fix possible src file fclose() leak if alignment bytes+local header file write failed +- In r3: Bugfix to mz_zip_writer_add_file() found during merge: Fix possible src file fclose() leak if alignment bytes+local header file write faiiled - In r4: Minor bugfix to mz_zip_writer_add_from_zip_reader(): Was pushing the wrong central dir header offset, appears harmless in this release, but it became a problem in the zip64 branch ### v1.14 - May 20, 2012 diff --git a/third_party/miniz-2.0.8/LICENSE b/third_party/miniz-2.1.0/LICENSE similarity index 100% rename from third_party/miniz-2.0.8/LICENSE rename to third_party/miniz-2.1.0/LICENSE diff --git a/third_party/miniz-2.0.8/examples/example1.c b/third_party/miniz-2.1.0/examples/example1.c similarity index 84% rename from third_party/miniz-2.0.8/examples/example1.c rename to third_party/miniz-2.1.0/examples/example1.c index 45254b7840d7a..d6e33faaa6fda 100755 --- a/third_party/miniz-2.0.8/examples/example1.c +++ b/third_party/miniz-2.1.0/examples/example1.c @@ -2,7 +2,6 @@ // Public domain, May 15 2011, Rich Geldreich, richgel99@gmail.com. See "unlicense" statement at the end of tinfl.c. #include #include "miniz.h" -#include "miniz_zip.h" typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; @@ -18,26 +17,6 @@ static const char *s_pStr = "Good morning Dr. Chandra. This is Hal. I am ready f int main(int argc, char *argv[]) { - { - mz_zip_archive zip_archive; - memset(&zip_archive, 0, sizeof(zip_archive)); // mz_zip_archive contains a bunch of pointers. set all to nullptr - mz_bool status = mz_zip_writer_init(&zip_archive, 0); - if (!status) - return; - - status = mz_zip_writer_add_file(&zip_archive, "Images.zip", "Images/title.png", NULL, 0, MZ_DEFAULT_COMPRESSION); - if (!status) - return; - - status = mz_zip_writer_finalize_archive(&zip_archive); - if (!status) - return; - - status = mz_zip_writer_end(&zip_archive); - if (!status) - return; - } - uint step = 0; int cmp_status; uLong src_len = (uLong)strlen(s_pStr); diff --git a/third_party/miniz-2.0.8/examples/example2.c b/third_party/miniz-2.1.0/examples/example2.c similarity index 98% rename from third_party/miniz-2.0.8/examples/example2.c rename to third_party/miniz-2.1.0/examples/example2.c index eb0019b063c6a..c3a84bacbfeda 100755 --- a/third_party/miniz-2.0.8/examples/example2.c +++ b/third_party/miniz-2.1.0/examples/example2.c @@ -61,7 +61,7 @@ int main(int argc, char *argv[]) // Add a new file to the archive. Note this is an IN-PLACE operation, so if it fails your archive is probably hosed (its central directory may not be complete) but it should be recoverable using zip -F or -FF. So use caution with this guy. // A more robust way to add a file to an archive would be to read it into memory, perform the operation, then write a new archive out to a temp file and then delete/rename the files. // Or, write a new archive to disk to a temp file, then delete/rename the files. For this test this API is fine. - status = mz_zip_add_mem_to_archive_file_in_place(s_Test_archive_filename, archive_filename, data, 2, s_pComment, (uint16)strlen(s_pComment), MZ_BEST_COMPRESSION); + status = mz_zip_add_mem_to_archive_file_in_place(s_Test_archive_filename, archive_filename, data, strlen(data) + 1, s_pComment, (uint16)strlen(s_pComment), MZ_BEST_COMPRESSION); if (!status) { printf("mz_zip_add_mem_to_archive_file_in_place failed!\n"); diff --git a/third_party/miniz-2.0.8/examples/example3.c b/third_party/miniz-2.1.0/examples/example3.c similarity index 100% rename from third_party/miniz-2.0.8/examples/example3.c rename to third_party/miniz-2.1.0/examples/example3.c diff --git a/third_party/miniz-2.0.8/examples/example4.c b/third_party/miniz-2.1.0/examples/example4.c similarity index 100% rename from third_party/miniz-2.0.8/examples/example4.c rename to third_party/miniz-2.1.0/examples/example4.c diff --git a/third_party/miniz-2.0.8/examples/example5.c b/third_party/miniz-2.1.0/examples/example5.c similarity index 100% rename from third_party/miniz-2.0.8/examples/example5.c rename to third_party/miniz-2.1.0/examples/example5.c diff --git a/third_party/miniz-2.0.8/examples/example6.c b/third_party/miniz-2.1.0/examples/example6.c similarity index 100% rename from third_party/miniz-2.0.8/examples/example6.c rename to third_party/miniz-2.1.0/examples/example6.c diff --git a/third_party/miniz-2.0.8/miniz.c b/third_party/miniz-2.1.0/miniz.c similarity index 98% rename from third_party/miniz-2.0.8/miniz.c rename to third_party/miniz-2.1.0/miniz.c index 9a1ff5f67f320..d6d17b6d96d31 100755 --- a/third_party/miniz-2.0.8/miniz.c +++ b/third_party/miniz-2.1.0/miniz.c @@ -410,6 +410,33 @@ int mz_inflateInit(mz_streamp pStream) return mz_inflateInit2(pStream, MZ_DEFAULT_WINDOW_BITS); } +int mz_inflateReset(mz_streamp pStream) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + inflate_state *pDecomp; + if (!pStream) + return MZ_STREAM_ERROR; + + pStream->data_type = 0; + pStream->adler = 0; + pStream->msg = NULL; + pStream->total_in = 0; + pStream->total_out = 0; + pStream->reserved = 0; + + pDecomp = (inflate_state *)pStream->state; + + tinfl_init(&pDecomp->m_decomp); + pDecomp->m_dict_ofs = 0; + pDecomp->m_dict_avail = 0; + pDecomp->m_last_status = TINFL_STATUS_NEEDS_MORE_INPUT; + pDecomp->m_first_call = 1; + pDecomp->m_has_flushed = 0; + /* pDecomp->m_window_bits = window_bits */; + + return MZ_OK; +} + int mz_inflate(mz_streamp pStream, int flush) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -1379,13 +1406,13 @@ static int tdefl_flush_block(tdefl_compressor *d, int flush) #if MINIZ_USE_UNALIGNED_LOADS_AND_STORES #ifdef MINIZ_UNALIGNED_USE_MEMCPY -static inline mz_uint16 TDEFL_READ_UNALIGNED_WORD(const mz_uint8* p) +static mz_uint16 TDEFL_READ_UNALIGNED_WORD(const mz_uint8* p) { mz_uint16 ret; memcpy(&ret, p, sizeof(mz_uint16)); return ret; } -static inline mz_uint16 TDEFL_READ_UNALIGNED_WORD2(const mz_uint16* p) +static mz_uint16 TDEFL_READ_UNALIGNED_WORD2(const mz_uint16* p) { mz_uint16 ret; memcpy(&ret, p, sizeof(mz_uint16)); @@ -1397,9 +1424,12 @@ static inline mz_uint16 TDEFL_READ_UNALIGNED_WORD2(const mz_uint16* p) #endif static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahead_pos, mz_uint max_dist, mz_uint max_match_len, mz_uint *pMatch_dist, mz_uint *pMatch_len) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) mz_uint dist, pos = lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK, match_len = *pMatch_len, probe_pos = pos, next_probe_pos, probe_len; mz_uint num_probes_left = d->m_max_probes[match_len >= 32]; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const mz_uint16 *s = (const mz_uint16 *)(d->m_dict + pos), *p, *q; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) mz_uint16 c01 = TDEFL_READ_UNALIGNED_WORD(&d->m_dict[pos + match_len - 1]), s01 = TDEFL_READ_UNALIGNED_WORD2(s); MZ_ASSERT(max_match_len <= TDEFL_MAX_MATCH_LEN); if (max_match_len <= match_len) @@ -1450,12 +1480,9 @@ static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahe #else static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahead_pos, mz_uint max_dist, mz_uint max_match_len, mz_uint *pMatch_dist, mz_uint *pMatch_len) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) mz_uint dist, pos = lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK, match_len = *pMatch_len, probe_pos = pos, next_probe_pos, probe_len; mz_uint num_probes_left = d->m_max_probes[match_len >= 32]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const mz_uint8 *s = d->m_dict + pos, *p, *q; - // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) mz_uint8 c0 = d->m_dict[pos + match_len], c1 = d->m_dict[pos + match_len - 1]; MZ_ASSERT(max_match_len <= TDEFL_MAX_MATCH_LEN); if (max_match_len <= match_len) @@ -1497,6 +1524,16 @@ static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahe #endif /* #if MINIZ_USE_UNALIGNED_LOADS_AND_STORES */ #if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN +#ifdef MINIZ_UNALIGNED_USE_MEMCPY +static mz_uint32 TDEFL_READ_UNALIGNED_WORD32(const mz_uint8* p) +{ + mz_uint32 ret; + memcpy(&ret, p, sizeof(mz_uint32)); + return ret; +} +#else +#define TDEFL_READ_UNALIGNED_WORD32(p) *(const mz_uint32 *)(p) +#endif static mz_bool tdefl_compress_fast(tdefl_compressor *d) { /* Faster, minimally featured LZRW1-style match+parse loop with better register utilization. Intended for applications where raw throughput is valued more highly than ratio. */ @@ -1531,12 +1568,12 @@ static mz_bool tdefl_compress_fast(tdefl_compressor *d) { mz_uint cur_match_dist, cur_match_len = 1; mz_uint8 *pCur_dict = d->m_dict + cur_pos; - mz_uint first_trigram = (*(const mz_uint32 *)pCur_dict) & 0xFFFFFF; + mz_uint first_trigram = TDEFL_READ_UNALIGNED_WORD32(pCur_dict) & 0xFFFFFF; mz_uint hash = (first_trigram ^ (first_trigram >> (24 - (TDEFL_LZ_HASH_BITS - 8)))) & TDEFL_LEVEL1_HASH_SIZE_MASK; mz_uint probe_pos = d->m_hash[hash]; d->m_hash[hash] = (mz_uint16)lookahead_pos; - if (((cur_match_dist = (mz_uint16)(lookahead_pos - probe_pos)) <= dict_size) && ((*(const mz_uint32 *)(d->m_dict + (probe_pos &= TDEFL_LZ_DICT_SIZE_MASK)) & 0xFFFFFF) == first_trigram)) + if (((cur_match_dist = (mz_uint16)(lookahead_pos - probe_pos)) <= dict_size) && ((TDEFL_READ_UNALIGNED_WORD32(d->m_dict + (probe_pos &= TDEFL_LZ_DICT_SIZE_MASK)) & 0xFFFFFF) == first_trigram)) { const mz_uint16 *p = (const mz_uint16 *)pCur_dict; const mz_uint16 *q = (const mz_uint16 *)(d->m_dict + probe_pos); @@ -1566,7 +1603,11 @@ static mz_bool tdefl_compress_fast(tdefl_compressor *d) cur_match_dist--; pLZ_code_buf[0] = (mz_uint8)(cur_match_len - TDEFL_MIN_MATCH_LEN); +#ifdef MINIZ_UNALIGNED_USE_MEMCPY + memcpy(&pLZ_code_buf[1], &cur_match_dist, sizeof(cur_match_dist)); +#else *(mz_uint16 *)(&pLZ_code_buf[1]) = (mz_uint16)cur_match_dist; +#endif pLZ_code_buf += 3; *pLZ_flags = (mz_uint8)((*pLZ_flags >> 1) | 0x80); @@ -2195,6 +2236,7 @@ void *tdefl_write_image_to_png_file_in_memory(const void *pImage, int w, int h, return tdefl_write_image_to_png_file_in_memory_ex(pImage, w, h, num_chans, pLen_out, 6, MZ_FALSE); } +#ifndef MINIZ_NO_MALLOC /* Allocate the tdefl_compressor and tinfl_decompressor structures in C so that */ /* non-C language bindings to tdefL_ and tinfl_ API don't need to worry about */ /* structure size and allocation mechanism. */ @@ -2207,6 +2249,7 @@ void tdefl_compressor_free(tdefl_compressor *pComp) { MZ_FREE(pComp); } +#endif #ifdef _MSC_VER #pragma warning(pop) @@ -2760,8 +2803,12 @@ tinfl_status tinfl_decompress(tinfl_decompressor *r, const mz_uint8 *pIn_buf_nex const mz_uint8 *pSrc_end = pSrc + (counter & ~7); do { +#ifdef MINIZ_UNALIGNED_USE_MEMCPY + memcpy(pOut_buf_cur, pSrc, sizeof(mz_uint32)*2); +#else ((mz_uint32 *)pOut_buf_cur)[0] = ((const mz_uint32 *)pSrc)[0]; ((mz_uint32 *)pOut_buf_cur)[1] = ((const mz_uint32 *)pSrc)[1]; +#endif pOut_buf_cur += 8; } while ((pSrc += 8) < pSrc_end); if ((counter &= 7) < 3) @@ -2958,6 +3005,7 @@ int tinfl_decompress_mem_to_callback(const void *pIn_buf, size_t *pIn_buf_size, return result; } +#ifndef MINIZ_NO_MALLOC tinfl_decompressor *tinfl_decompressor_alloc() { tinfl_decompressor *pDecomp = (tinfl_decompressor *)MZ_MALLOC(sizeof(tinfl_decompressor)); @@ -2970,6 +3018,7 @@ void tinfl_decompressor_free(tinfl_decompressor *pDecomp) { MZ_FREE(pDecomp); } +#endif #ifdef __cplusplus } @@ -3038,8 +3087,8 @@ static FILE *mz_freopen(const char *pPath, const char *pMode, FILE *pStream) #define MZ_FWRITE fwrite #define MZ_FTELL64 _ftelli64 #define MZ_FSEEK64 _fseeki64 -#define MZ_FILE_STAT_STRUCT _stat -#define MZ_FILE_STAT _stat +#define MZ_FILE_STAT_STRUCT _stat64 +#define MZ_FILE_STAT _stat64 #define MZ_FFLUSH fflush #define MZ_FREOPEN mz_freopen #define MZ_DELETE_FILE remove @@ -3073,7 +3122,7 @@ static FILE *mz_freopen(const char *pPath, const char *pMode, FILE *pStream) #define MZ_FFLUSH fflush #define MZ_FREOPEN(f, m, s) freopen(f, m, s) #define MZ_DELETE_FILE remove -#elif defined(__GNUC__) && _LARGEFILE64_SOURCE +#elif defined(__GNUC__) && defined(_LARGEFILE64_SOURCE) #ifndef MINIZ_NO_TIME #include #endif @@ -3328,8 +3377,8 @@ static MZ_FORCEINLINE mz_bool mz_zip_array_push_back(mz_zip_archive *pZip, mz_zi size_t orig_size = pArray->m_size; if (!mz_zip_array_resize(pZip, pArray, orig_size + n, MZ_TRUE)) return MZ_FALSE; - if (n > 0) // zdevito: pElements may be null when n == 0 and ASAN complains - memcpy((mz_uint8 *)pArray->m_p + orig_size * pArray->m_element_size, pElements, n * pArray->m_element_size); + if (n > 0) + memcpy((mz_uint8 *)pArray->m_p + orig_size * pArray->m_element_size, pElements, n * pArray->m_element_size); return MZ_TRUE; } @@ -3739,7 +3788,27 @@ static mz_bool mz_zip_reader_read_central_dir(mz_zip_archive *pZip, mz_uint flag if (extra_size_remaining) { - const mz_uint8 *pExtra_data = p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size; + const mz_uint8 *pExtra_data; + void* buf = NULL; + + if (MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size + ext_data_size > n) + { + buf = MZ_MALLOC(ext_data_size); + if(buf==NULL) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if (pZip->m_pRead(pZip->m_pIO_opaque, cdir_ofs + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size, buf, ext_data_size) != ext_data_size) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + + pExtra_data = (mz_uint8*)buf; + } + else + { + pExtra_data = p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size; + } do { @@ -3748,14 +3817,20 @@ static mz_bool mz_zip_reader_read_central_dir(mz_zip_archive *pZip, mz_uint flag // NOLINTNEXTLINE(cppcoreguidelines-init-variables) mz_uint32 field_data_size; - if (extra_size_remaining < (sizeof(mz_uint16) * 2)) - return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } field_id = MZ_READ_LE16(pExtra_data); field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); - if ((field_data_size + sizeof(mz_uint16) * 2) > extra_size_remaining) - return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + if ((field_data_size + sizeof(mz_uint16) * 2) > extra_size_remaining) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } if (field_id == MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) { @@ -3768,6 +3843,8 @@ static mz_bool mz_zip_reader_read_central_dir(mz_zip_archive *pZip, mz_uint flag pExtra_data += sizeof(mz_uint16) * 2 + field_data_size; extra_size_remaining = extra_size_remaining - sizeof(mz_uint16) * 2 - field_data_size; } while (extra_size_remaining); + + MZ_FREE(buf); } } @@ -4995,7 +5072,7 @@ size_t mz_zip_reader_extract_iter_read(mz_zip_reader_extract_iter_state* pState, if ((pState->flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!pState->file_stat.m_method)) { /* The file is stored or the caller has requested the compressed data, calc amount to return. */ - copied_to_caller = MZ_MIN( buf_size, pState->comp_remaining ); + copied_to_caller = (size_t)MZ_MIN( buf_size, pState->comp_remaining ); /* Zip is in memory....or requires reading from a file? */ if (pState->pZip->m_pState->m_pMem) @@ -6092,7 +6169,7 @@ static mz_bool mz_zip_writer_add_to_central_dir(mz_zip_archive *pZip, const char if (((mz_uint64)pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size + extra_size + user_extra_data_len + comment_size) >= MZ_UINT32_MAX) return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); - if (!mz_zip_writer_create_central_dir_header(pZip, central_dir_header, filename_size, extra_size + user_extra_data_len, comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_header_ofs, ext_attributes)) + if (!mz_zip_writer_create_central_dir_header(pZip, central_dir_header, filename_size, (mz_uint16)(extra_size + user_extra_data_len), comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_header_ofs, ext_attributes)) return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); if ((!mz_zip_array_push_back(pZip, &pState->m_central_dir, central_dir_header, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE)) || @@ -6116,13 +6193,7 @@ static mz_bool mz_zip_writer_validate_archive_name(const char *pArchive_name) if (*pArchive_name == '/') return MZ_FALSE; - while (*pArchive_name) - { - if ((*pArchive_name == '\\') || (*pArchive_name == ':')) - return MZ_FALSE; - - pArchive_name++; - } + /* Making sure the name does not contain drive letters or DOS style backward slashes is the responsibility of the program using miniz*/ return MZ_TRUE; } @@ -6160,8 +6231,8 @@ mz_bool mz_zip_writer_add_mem_ex(mz_zip_archive *pZip, const char *pArchive_name } mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, - mz_uint level_and_flags, mz_uint64 uncomp_size, mz_uint32 uncomp_crc32, MZ_TIME_T *last_modified, - const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) + mz_uint level_and_flags, mz_uint64 uncomp_size, mz_uint32 uncomp_crc32, MZ_TIME_T *last_modified, + const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) { mz_uint16 method = 0, dos_time = 0, dos_date = 0; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -6259,8 +6330,8 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n if (!pState->m_zip64) { /* Bail early if the archive would obviously become too large */ - if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size - + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len + + if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size + + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len + pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len + MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF) { @@ -6318,7 +6389,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); } - if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, extra_size + user_extra_data_len, 0, 0, 0, method, bit_flags, dos_time, dos_date)) + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)(extra_size + user_extra_data_len), 0, 0, 0, method, bit_flags, dos_time, dos_date)) return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); if (pZip->m_pWrite(pZip->m_pIO_opaque, local_dir_header_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) @@ -6345,7 +6416,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n { if ((comp_size > MZ_UINT32_MAX) || (cur_archive_file_ofs > MZ_UINT32_MAX)) return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); - if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, user_extra_data_len, 0, 0, 0, method, bit_flags, dos_time, dos_date)) + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)user_extra_data_len, 0, 0, 0, method, bit_flags, dos_time, dos_date)) return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); if (pZip->m_pWrite(pZip->m_pIO_opaque, local_dir_header_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) @@ -6438,9 +6509,9 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); } - if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, extra_size, pComment, - comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_dir_header_ofs, ext_attributes, - user_extra_data_central, user_extra_data_central_len)) + if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, (mz_uint16)extra_size, pComment, + comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_dir_header_ofs, ext_attributes, + user_extra_data_central, user_extra_data_central_len)) return MZ_FALSE; pZip->m_total_files++; @@ -6449,8 +6520,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n return MZ_TRUE; } -#ifndef MINIZ_NO_STDIO -mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, MZ_FILE *pSrc_file, mz_uint64 size_to_add, const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, +mz_bool mz_zip_writer_add_read_buf_callback(mz_zip_archive *pZip, const char *pArchive_name, mz_file_read_func read_callback, void* callback_opaque, mz_uint64 size_to_add, const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) { mz_uint16 gen_flags = MZ_ZIP_LDH_BIT_FLAG_HAS_LOCATOR; @@ -6463,6 +6533,7 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, mz_uint32 extra_size = 0; mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE]; mz_zip_internal_state *pState; + mz_uint64 file_ofs = 0; if (!(level_and_flags & MZ_ZIP_FLAG_ASCII_FILENAME)) gen_flags |= MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_UTF8; @@ -6565,7 +6636,7 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); } - if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, extra_size + user_extra_data_len, 0, 0, 0, method, gen_flags, dos_time, dos_date)) + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)(extra_size + user_extra_data_len), 0, 0, 0, method, gen_flags, dos_time, dos_date)) return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) @@ -6589,7 +6660,7 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, { if ((comp_size > MZ_UINT32_MAX) || (cur_archive_file_ofs > MZ_UINT32_MAX)) return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); - if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, user_extra_data_len, 0, 0, 0, method, gen_flags, dos_time, dos_date)) + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)user_extra_data_len, 0, 0, 0, method, gen_flags, dos_time, dos_date)) return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) @@ -6627,11 +6698,12 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, while (uncomp_remaining) { mz_uint n = (mz_uint)MZ_MIN((mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE, uncomp_remaining); - if ((MZ_FREAD(pRead_buf, 1, n, pSrc_file) != n) || (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pRead_buf, n) != n)) + if ((read_callback(callback_opaque, file_ofs, pRead_buf, n) != n) || (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pRead_buf, n) != n)) { pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); } + file_ofs += n; uncomp_crc32 = (mz_uint32)mz_crc32(uncomp_crc32, (const mz_uint8 *)pRead_buf, n); uncomp_remaining -= n; cur_archive_file_ofs += n; @@ -6666,12 +6738,13 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, tdefl_status status; tdefl_flush flush = TDEFL_NO_FLUSH; - if (MZ_FREAD(pRead_buf, 1, in_buf_size, pSrc_file) != in_buf_size) + if (read_callback(callback_opaque, file_ofs, pRead_buf, in_buf_size)!= in_buf_size) { mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); break; } + file_ofs += in_buf_size; uncomp_crc32 = (mz_uint32)mz_crc32(uncomp_crc32, (const mz_uint8 *)pRead_buf, in_buf_size); uncomp_remaining -= in_buf_size; @@ -6739,7 +6812,7 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); } - if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, extra_size, pComment, comment_size, + if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, (mz_uint16)extra_size, pComment, comment_size, uncomp_size, comp_size, uncomp_crc32, method, gen_flags, dos_time, dos_date, local_dir_header_ofs, ext_attributes, user_extra_data_central, user_extra_data_central_len)) return MZ_FALSE; @@ -6750,6 +6823,26 @@ mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, return MZ_TRUE; } +#ifndef MINIZ_NO_STDIO + +static size_t mz_file_read_func_stdio(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) +{ + MZ_FILE *pSrc_file = (MZ_FILE *)pOpaque; + mz_int64 cur_ofs = MZ_FTELL64(pSrc_file); + + if (((mz_int64)file_ofs < 0) || (((cur_ofs != (mz_int64)file_ofs)) && (MZ_FSEEK64(pSrc_file, (mz_int64)file_ofs, SEEK_SET)))) + return 0; + + return MZ_FREAD(pBuf, 1, n, pSrc_file); +} + +mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, MZ_FILE *pSrc_file, mz_uint64 size_to_add, const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) +{ + return mz_zip_writer_add_read_buf_callback(pZip, pArchive_name, mz_file_read_func_stdio, pSrc_file, size_to_add, pFile_time, pComment, comment_size, level_and_flags, + user_extra_data, user_extra_data_len, user_extra_data_central, user_extra_data_central_len); +} + mz_bool mz_zip_writer_add_file(mz_zip_archive *pZip, const char *pArchive_name, const char *pSrc_filename, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags) { MZ_FILE *pSrc_file = NULL; @@ -6957,7 +7050,6 @@ mz_bool mz_zip_writer_add_from_zip_reader(mz_zip_archive *pZip, mz_zip_archive * if ((local_header_extra_len) && ((local_header_comp_size == MZ_UINT32_MAX) || (local_header_uncomp_size == MZ_UINT32_MAX))) { mz_zip_array file_data_array; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const mz_uint8 *pExtra_data; mz_uint32 extra_size_remaining = local_header_extra_len; diff --git a/third_party/miniz-2.0.8/miniz.h b/third_party/miniz-2.1.0/miniz.h similarity index 97% rename from third_party/miniz-2.0.8/miniz.h rename to third_party/miniz-2.1.0/miniz.h index 7e18807fbe719..2cad1370c6388 100755 --- a/third_party/miniz-2.0.8/miniz.h +++ b/third_party/miniz-2.1.0/miniz.h @@ -1,4 +1,4 @@ -/* miniz.c 2.0.8 - public domain deflate/inflate, zlib-subset, ZIP reading/writing/appending, PNG writing +/* miniz.c 2.1.0 - public domain deflate/inflate, zlib-subset, ZIP reading/writing/appending, PNG writing See "unlicense" statement at the end of this file. Rich Geldreich , last updated Oct. 13, 2013 Implements RFC 1950: http://www.ietf.org/rfc/rfc1950.txt and RFC 1951: http://www.ietf.org/rfc/rfc1951.txt @@ -24,7 +24,7 @@ zlib replacement in many apps: The z_stream struct, optional memory allocation callbacks deflateInit/deflateInit2/deflate/deflateReset/deflateEnd/deflateBound - inflateInit/inflateInit2/inflate/inflateEnd + inflateInit/inflateInit2/inflate/inflateReset/inflateEnd compress, compress2, compressBound, uncompress CRC-32, Adler-32 - Using modern, minimal code size, CPU cache friendly routines. Supports raw deflate streams or standard zlib streams with adler-32 checking. @@ -116,11 +116,11 @@ -/* Defines to completely disable specific portions of miniz.c: +/* Defines to completely disable specific portions of miniz.c: If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */ /* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */ -// #define MINIZ_NO_STDIO // enable STDIO API mz_zip_reader_init_file for pytorch file extractor, see D28168870 +/*#define MINIZ_NO_STDIO */ /* If MINIZ_NO_TIME is specified then the ZIP archive functions will not be able to get the current time, or */ /* get/set file times, and the C run-time funcs that get/set times won't be called. */ @@ -139,7 +139,7 @@ /* Define MINIZ_NO_ZLIB_COMPATIBLE_NAME to disable zlib names, to prevent conflicts against stock zlib. */ #define MINIZ_NO_ZLIB_COMPATIBLE_NAMES -/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc. +/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc. Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */ @@ -170,13 +170,18 @@ #define MINIZ_LITTLE_ENDIAN 0 #endif +/* Set MINIZ_USE_UNALIGNED_LOADS_AND_STORES only if not set */ +#if !defined(MINIZ_USE_UNALIGNED_LOADS_AND_STORES) #if MINIZ_X86_OR_X64_CPU /* Set MINIZ_USE_UNALIGNED_LOADS_AND_STORES to 1 on CPU's that permit efficient integer loads and stores from unaligned addresses. */ /* zdevito: ASAN doesn't like unligned loads and stores, and -O3 optimizes the unoptimized code pattern away anyawy */ #define MINIZ_USE_UNALIGNED_LOADS_AND_STORES 0 +/* zdevito: ASAN doesn't like unligned loads and stores, and -O3 optimizes the unoptimized code pattern away anyawy */ +/*#define MINIZ_UNALIGNED_USE_MEMCPY*/ #else #define MINIZ_USE_UNALIGNED_LOADS_AND_STORES 0 #endif +#endif #if defined(_M_X64) || defined(_WIN64) || defined(__MINGW64__) || defined(_LP64) || defined(__LP64__) || defined(__ia64__) || defined(__x86_64__) /* Set MINIZ_HAS_64BIT_REGISTERS to 1 if operations on 64-bit integers are reasonably fast (and don't involve compiler generated calls to helper functions). */ @@ -235,11 +240,11 @@ enum MZ_DEFAULT_COMPRESSION = -1 }; -#define MZ_VERSION "10.0.3" -#define MZ_VERNUM 0xA030 +#define MZ_VERSION "10.1.0" +#define MZ_VERNUM 0xA100 #define MZ_VER_MAJOR 10 -#define MZ_VER_MINOR 0 -#define MZ_VER_REVISION 3 +#define MZ_VER_MINOR 1 +#define MZ_VER_REVISION 0 #define MZ_VER_SUBREVISION 0 #ifndef MINIZ_NO_ZLIB_APIS @@ -362,6 +367,9 @@ int mz_inflateInit(mz_streamp pStream); /* window_bits must be MZ_DEFAULT_WINDOW_BITS (to parse zlib header/footer) or -MZ_DEFAULT_WINDOW_BITS (raw deflate). */ int mz_inflateInit2(mz_streamp pStream, int window_bits); +/* Quickly resets a compressor without having to reallocate anything. Same as calling mz_inflateEnd() followed by mz_inflateInit()/mz_inflateInit2(). */ +int mz_inflateReset(mz_streamp pStream); + /* Decompresses the input stream to the output, consuming only as much of the input as needed, and writing as much to the output as possible. */ /* Parameters: */ /* pStream is the stream to read from and write to. You must initialize/update the next_in, avail_in, next_out, and avail_out members. */ @@ -445,6 +453,7 @@ typedef void *const voidpc; #define compressBound mz_compressBound #define inflateInit mz_inflateInit #define inflateInit2 mz_inflateInit2 +#define inflateReset mz_inflateReset #define inflate mz_inflate #define inflateEnd mz_inflateEnd #define uncompress mz_uncompress @@ -738,11 +747,13 @@ mz_uint32 tdefl_get_adler32(tdefl_compressor *d); /* strategy may be either MZ_DEFAULT_STRATEGY, MZ_FILTERED, MZ_HUFFMAN_ONLY, MZ_RLE, or MZ_FIXED */ mz_uint tdefl_create_comp_flags_from_zip_params(int level, int window_bits, int strategy); +#ifndef MINIZ_NO_MALLOC /* Allocate the tdefl_compressor structure in C so that */ /* non-C language bindings to tdefl_ API don't need to worry about */ /* structure size and allocation mechanism. */ -tdefl_compressor *tdefl_compressor_alloc(); +tdefl_compressor *tdefl_compressor_alloc(void); void tdefl_compressor_free(tdefl_compressor *pComp); +#endif #ifdef __cplusplus } @@ -790,12 +801,13 @@ int tinfl_decompress_mem_to_callback(const void *pIn_buf, size_t *pIn_buf_size, struct tinfl_decompressor_tag; typedef struct tinfl_decompressor_tag tinfl_decompressor; +#ifndef MINIZ_NO_MALLOC /* Allocate the tinfl_decompressor structure in C so that */ /* non-C language bindings to tinfl_ API don't need to worry about */ /* structure size and allocation mechanism. */ - -tinfl_decompressor *tinfl_decompressor_alloc(); +tinfl_decompressor *tinfl_decompressor_alloc(void); void tinfl_decompressor_free(tinfl_decompressor *pDecomp); +#endif /* Max size of LZ dictionary. */ #define TINFL_LZ_DICT_SIZE 32768 @@ -1270,6 +1282,12 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n mz_uint64 uncomp_size, mz_uint32 uncomp_crc32, MZ_TIME_T *last_modified, const char *user_extra_data_local, mz_uint user_extra_data_local_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len); +/* Adds the contents of a file to an archive. This function also records the disk file's modified time into the archive. */ +/* File data is supplied via a read callback function. User mz_zip_writer_add_(c)file to add a file directly.*/ +mz_bool mz_zip_writer_add_read_buf_callback(mz_zip_archive *pZip, const char *pArchive_name, mz_file_read_func read_callback, void* callback_opaque, mz_uint64 size_to_add, + const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, const char *user_extra_data_local, mz_uint user_extra_data_local_len, + const char *user_extra_data_central, mz_uint user_extra_data_central_len); + #ifndef MINIZ_NO_STDIO /* Adds the contents of a disk file to an archive. This function also records the disk file's modified time into the archive. */ /* level_and_flags - compression level (0-10, see MZ_BEST_SPEED, MZ_BEST_COMPRESSION, etc.) logically OR'd with zero or more mz_zip_flags, or just set to MZ_DEFAULT_COMPRESSION. */ diff --git a/third_party/miniz-2.0.8/readme.md b/third_party/miniz-2.1.0/readme.md similarity index 100% rename from third_party/miniz-2.0.8/readme.md rename to third_party/miniz-2.1.0/readme.md diff --git a/third_party/nlohmann b/third_party/nlohmann new file mode 160000 index 0000000000000..87cda1d664659 --- /dev/null +++ b/third_party/nlohmann @@ -0,0 +1 @@ +Subproject commit 87cda1d6646592ac5866dc703c8e1839046a6806 diff --git a/third_party/onnx b/third_party/onnx index 96046b8ccfb8e..f7ee1ac60d06a 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 96046b8ccfb8e6fa82f6b2b34b3d56add2e8849c +Subproject commit f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b diff --git a/third_party/pybind11 b/third_party/pybind11 index 8de7772cc72da..aa304c9c7d725 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 8de7772cc72daca8e947b79b83fea46214931604 +Subproject commit aa304c9c7d725ffb9d10af08a3b34cb372307020 diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 549c70e039532..ee07488e26749 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,586 +1,1791 @@ +load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") +load( + ":xnnpack_src_defs.bzl", + "HOT_SRCS", + "JIT_SRCS", + "LOGGING_SRCS", + "OPERATOR_SRCS", + "SUBGRAPH_SRCS", + "TABLE_SRCS", +) +load( + ":xnnpack_wrapper_defs.bzl", + "AARCH32_ASM_MICROKERNEL_SRCS", + "AARCH64_ASM_MICROKERNEL_SRCS", + "PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS", + "PROD_AARCH64_NEON_MICROKERNEL_SRCS", + "PROD_AVX2_MICROKERNEL_SRCS", + "PROD_AVX512F_MICROKERNEL_SRCS", + "PROD_AVX512SKX_MICROKERNEL_SRCS", + "PROD_AVX_MICROKERNEL_SRCS", + "PROD_F16C_MICROKERNEL_SRCS", + "PROD_FMA3_MICROKERNEL_SRCS", + "PROD_NEONDOT_MICROKERNEL_SRCS", + "PROD_NEONFMA_MICROKERNEL_SRCS", + "PROD_NEONFP16_MICROKERNEL_SRCS", + "PROD_NEONV8_MICROKERNEL_SRCS", + "PROD_NEON_MICROKERNEL_SRCS", + "PROD_SCALAR_AARCH32_MICROKERNEL_SRCS", + "PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS", + "PROD_SSE2_MICROKERNEL_SRCS", + "PROD_SSE41_MICROKERNEL_SRCS", + "PROD_SSE_MICROKERNEL_SRCS", + "PROD_SSSE3_MICROKERNEL_SRCS", + "PROD_XOP_MICROKERNEL_SRCS", +) -def define_xnnpack(): - cxx_library( - name = "XNNPACK", - srcs = ["XNNPACK/src/allocator.c", "XNNPACK/src/init.c", "XNNPACK/src/memory-planner.c", "XNNPACK/src/operator-delete.c", "XNNPACK/src/runtime.c", "XNNPACK/src/subgraph.c", "XNNPACK/src/tensor.c", "XNNPACK/src/datatype-strings.c", "XNNPACK/src/operator-strings.c", "XNNPACK/src/subgraph-strings.c"], - deps = [":operators", ":subgraph", ":tables", ":ukernels_scalar", "//third_party:cpuinfo", "//third_party:pthreadpool", "//third_party:pthreadpool_header", ":arm_lib", ":x86_and_x86_64_lib"], - exported_deps = [], - compiler_flags = ["-w"], - preferred_linkage = "static", - exported_headers = {"xnnpack.h": "XNNPACK/include/xnnpack.h"}, - exported_preprocessor_flags = [], +# This defines XNNPACK targets for both fbsource BUCK and OSS BUCK +# Note that the file path is relative to the BUCK file that called from, not to this bzl file. +# So for fbsource build it points to xplat/third-party/XNNPACK/XNNPACK, +# and for OSS it points to pytorch/third_party/XNNPACK +def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = False): + WINDOWS_FLAGS = [ + "/D__x86_64__", + "/EHsc", + "/wd4090", # 'function': different 'const' qualifiers + "/wd4146", # unary minus operator applied to unsigned type, result still unsigned + ] + ([ + "/D__AVX512F__", # needed to avoid linkage errors + "-mavx2", + "/D__builtin_clz=__lzcnt", # Intrinsics are spelled differently in MSVC + "/Drestrict=", # MSVC doesn't understand [restrict XNN_NUM_ELEMENTS(N)] syntax + ] if XNNPACK_WINDOWS_AVX512F_ENABLED else []) + + WINDOWS_CLANG_COMPILER_FLAGS = [ + "-Wno-error", + "-Wno-error=undef", + "-Wno-error=incompatible-pointer-types", + "-Wno-error=incompatible-pointer-types-discards-qualifiers", + ] + + fb_xplat_cxx_library( + name = "interface", header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0", "-DXNN_NO_Q8_OPERATORS", "-DXNN_NO_F16_OPERATORS", "-DXNN_NO_NCHW_OPERATORS", "-DXNN_NO_QU8_OPERATORS", "-DXNN_NO_S8_OPERATORS", "-DXNN_NO_U8_OPERATORS", "-DXNN_NO_VCVT_OPERATORS", "-DXNN_NO_X32_OPERATORS", "-DXNN_NO_X8_OPERATORS", "-DXNN_NO_XX_OPERATORS"], - soname = "", + exported_headers = { + "xnnpack.h": "XNNPACK/include/xnnpack.h", + }, + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + exported_deps = [ + # Dependency only on pthreadpool interface + third_party("pthreadpool_header"), + ], ) - cxx_library( - name = "ukernels_scalar", - srcs = ["XNNPACK/wrappers/params-init.c", "XNNPACK/wrappers/u8-lut32norm/scalar.c", "XNNPACK/wrappers/xx-copy/memcpy.c", "XNNPACK/wrappers/x8-lut/gen/lut-scalar-x4.c", "XNNPACK/wrappers/x32-depthtospace2d-chw2hwc/scalar.c"], - deps = [":interface", "//third_party:FP16", "//third_party:FXdiv"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "operators", + # srcs have to include HOT_SRCS to be able to build on ARVR + srcs = OPERATOR_SRCS + HOT_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-Oz", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("cpuinfo"), + third_party("FP16"), + third_party("FXdiv"), + third_party("clog"), + ], ) - cxx_library( - name = "operators", - srcs = ["XNNPACK/src/operators/argmax-pooling-nhwc.c", "XNNPACK/src/operators/average-pooling-nhwc.c", "XNNPACK/src/operators/binary-elementwise-nd.c", "XNNPACK/src/operators/channel-shuffle-nc.c", "XNNPACK/src/operators/constant-pad-nd.c", "XNNPACK/src/operators/convolution-nchw.c", "XNNPACK/src/operators/convolution-nhwc.c", "XNNPACK/src/operators/deconvolution-nhwc.c", "XNNPACK/src/operators/depth-to-space-nchw2nhwc.c", "XNNPACK/src/operators/depth-to-space-nhwc.c", "XNNPACK/src/operators/fully-connected-nc.c", "XNNPACK/src/operators/global-average-pooling-ncw.c", "XNNPACK/src/operators/global-average-pooling-nwc.c", "XNNPACK/src/operators/lut-elementwise-nc.c", "XNNPACK/src/operators/max-pooling-nhwc.c", "XNNPACK/src/operators/prelu-nc.c", "XNNPACK/src/operators/resize-bilinear-nchw.c", "XNNPACK/src/operators/resize-bilinear-nhwc.c", "XNNPACK/src/operators/softmax-nc.c", "XNNPACK/src/operators/unary-elementwise-nc.c", "XNNPACK/src/operators/unpooling-nhwc.c", "XNNPACK/src/indirection.c", "XNNPACK/src/operator-run.c", "XNNPACK/src/packing.c"], - deps = [":interface", "//third_party:cpuinfo", "//third_party:FP16", "//third_party:FXdiv", "//third_party:clog"], - exported_deps = [], - compiler_flags = ["-w", "-Os"], + fb_xplat_cxx_library( + name = "subgraph", + srcs = SUBGRAPH_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + third_party("FXdiv"), + third_party("clog"), + ], + ) + + fb_xplat_cxx_library( + name = "tables", + srcs = TABLE_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + third_party("FXdiv"), + third_party("clog"), + ], ) - cxx_library( - name = "arm_lib", - srcs = [], - deps = [":jit_memory", ":ukernels_asm_aarch32", ":ukernels_asm_aarch64", ":ukernels_neon", ":ukernels_neon_aarch64", ":ukernels_neon_dot", ":ukernels_neon_fma", ":ukernels_neon_fp16", ":ukernels_neon_fp16arith_aarch64", ":ukernels_neon_v8", ":ukernels_scalar_aarch32"], - exported_deps = [], - compiler_flags = ["-w"], + fb_xplat_cxx_library( + name = "jit_memory", + # srcs have to include HOT_SRCS to be able to build on ARVR + srcs = JIT_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-Oz", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platforms = (APPLE, ANDROID, CXX, WINDOWS), preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "third-party/XNNPACK", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = [], - soname = "", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("clog"), + ], ) - cxx_library( - name = "x86_and_x86_64_lib", - srcs = [], - deps = [":ukernels_avx", ":ukernels_avx2", ":ukernels_avx512", ":ukernels_avx512skx", ":ukernels_f16c", ":ukernels_fma3", ":ukernels_sse", ":ukernels_sse2", ":ukernels_sse41", ":ukernels_ssse3", ":ukernels_xop"], - exported_deps = [], - compiler_flags = ["-w"], + fb_xplat_cxx_library( + name = "ukernels_scalar", + srcs = PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "third-party/XNNPACK", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = [], - soname = "", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + third_party("FXdiv"), + ], ) - cxx_library( - name = "tables", - srcs = ["XNNPACK/src/tables/exp2-k-over-64.c", "XNNPACK/src/tables/exp2-k-over-2048.c", "XNNPACK/src/tables/exp2minus-k-over-4.c", "XNNPACK/src/tables/exp2minus-k-over-8.c", "XNNPACK/src/tables/exp2minus-k-over-16.c", "XNNPACK/src/tables/exp2minus-k-over-64.c", "XNNPACK/src/tables/exp2minus-k-over-2048.c"], - deps = [":interface", "//third_party:FP16", "//third_party:FXdiv", "//third_party:clog"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_sse", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_SSE_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_sse_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], + windows_srcs = PROD_SSE_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "subgraph", - srcs = ["XNNPACK/src/subgraph/abs.c", "XNNPACK/src/subgraph/add2.c", "XNNPACK/src/subgraph/argmax-pooling-2d.c", "XNNPACK/src/subgraph/average-pooling-2d.c", "XNNPACK/src/subgraph/bankers-rounding.c", "XNNPACK/src/subgraph/ceiling.c", "XNNPACK/src/subgraph/clamp.c", "XNNPACK/src/subgraph/convert.c", "XNNPACK/src/subgraph/convolution-2d.c", "XNNPACK/src/subgraph/deconvolution-2d.c", "XNNPACK/src/subgraph/depth-to-space.c", "XNNPACK/src/subgraph/depthwise-convolution-2d.c", "XNNPACK/src/subgraph/divide.c", "XNNPACK/src/subgraph/elu.c", "XNNPACK/src/subgraph/floor.c", "XNNPACK/src/subgraph/fully-connected.c", "XNNPACK/src/subgraph/global-average-pooling-2d.c", "XNNPACK/src/subgraph/hardswish.c", "XNNPACK/src/subgraph/leaky-relu.c", "XNNPACK/src/subgraph/max-pooling-2d.c", "XNNPACK/src/subgraph/maximum2.c", "XNNPACK/src/subgraph/minimum2.c", "XNNPACK/src/subgraph/multiply2.c", "XNNPACK/src/subgraph/negate.c", "XNNPACK/src/subgraph/prelu.c", "XNNPACK/src/subgraph/sigmoid.c", "XNNPACK/src/subgraph/softmax.c", "XNNPACK/src/subgraph/square-root.c", "XNNPACK/src/subgraph/square.c", "XNNPACK/src/subgraph/squared-difference.c", "XNNPACK/src/subgraph/static-constant-pad.c", "XNNPACK/src/subgraph/static-reshape.c", "XNNPACK/src/subgraph/static-resize-bilinear-2d.c", "XNNPACK/src/subgraph/subtract.c", "XNNPACK/src/subgraph/unpooling-2d.c"], - deps = [":interface", "//third_party:FP16", "//third_party:FXdiv", "//third_party:clog"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_sse2", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse2", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_SSE2_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], + deps = [ + ":interface", + third_party("FP16"), + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_sse2_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse2", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], + windows_srcs = PROD_SSE2_MICROKERNEL_SRCS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_avx512", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mavx512f"], + fb_xplat_cxx_library( + name = "ukernels_ssse3", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mssse3", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_SSSE3_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], + deps = [ + ":interface", + third_party("FP16"), + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_ssse3_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-mavx512f"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f32-dwconv/gen/up16x3-minmax-avx512f.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x4-minmax-avx512f.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x9-minmax-avx512f.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x25-minmax-avx512f.c", "XNNPACK/wrappers/f32-gemm/gen/1x16-minmax-avx512f-broadcast.c", "XNNPACK/wrappers/f32-gemm/gen/7x16-minmax-avx512f-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/1x16-minmax-avx512f-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/7x16-minmax-avx512f-broadcast.c", "XNNPACK/wrappers/f32-prelu/gen/avx512f-2x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vadd-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vaddc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vdiv-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vdivc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vmaxc-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vmin-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vminc-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vmul-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vmulc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vrdivc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vrsubc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiff-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiffc-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vsub-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vsubc-minmax-avx512f-x32.c", "XNNPACK/wrappers/f32-vclamp/gen/vclamp-avx512f-x16.c", "XNNPACK/wrappers/f32-velu/gen/velu-avx512f-rr1-lut16-p3-perm-x64.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-avx512f-x16.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-avx512f-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-avx512f-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-avx512f-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-avx512f-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-avx512f-x16.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-avx512f-rr2-lut32-p2-perm2-scalef-div-x64.c", "XNNPACK/wrappers/f32-vunary/gen/vabs-avx512f-x16.c", "XNNPACK/wrappers/f32-vunary/gen/vneg-avx512f-x16.c", "XNNPACK/wrappers/f32-vunary/gen/vsqr-avx512f-x16.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mssse3", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], + windows_srcs = PROD_SSSE3_MICROKERNEL_SRCS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_neon_fp16arith_aarch64", - srcs = ["XNNPACK/wrappers/f16-dwconv/gen/up8x25-minmax-neonfp16arith-acc2.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x3-minmax-neonfp16arith.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x4-minmax-neonfp16arith.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x9-minmax-neonfp16arith.c", "XNNPACK/wrappers/f16-gavgpool/gen/7p7x-minmax-neonfp16arith-c8.c", "XNNPACK/wrappers/f16-gavgpool/gen/7x-minmax-neonfp16arith-c8.c", "XNNPACK/wrappers/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c", "XNNPACK/wrappers/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c", "XNNPACK/wrappers/f16-ibilinear/gen/neonfp16arith-c8.c", "XNNPACK/wrappers/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c", "XNNPACK/wrappers/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c", "XNNPACK/wrappers/f16-maxpool/9p8x-minmax-neonfp16arith-c8.c", "XNNPACK/wrappers/f16-prelu/gen/neonfp16arith-2x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vmulc-minmax-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vclamp/gen/vclamp-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vhswish/gen/vhswish-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vlrelu/gen/vlrelu-neonfp16arith-x16.c", "XNNPACK/wrappers/f16-vmulcaddc/gen/c8-minmax-neonfp16arith-2x.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_sse41", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse4.1", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_SSE41_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], + deps = [ + ":interface", + third_party("FP16"), + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_sse41_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["(aarch64|arm64)", ["-march=armv8.2-a+fp16"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-msse4.1", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], + windows_srcs = PROD_SSE41_MICROKERNEL_SRCS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( + fb_xplat_cxx_library( name = "ukernels_avx", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mavx"], + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_AVX_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_avx_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-mavx"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-avx-int16-x16.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x25-minmax-avx.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x3-minmax-avx.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x4-minmax-avx.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x9-minmax-avx.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-avx-x24.c", "XNNPACK/wrappers/f32-gemm/gen/1x16-minmax-avx-broadcast.c", "XNNPACK/wrappers/f32-gemm/gen/5x16-minmax-avx-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/1x16-minmax-avx-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/5x16-minmax-avx-broadcast.c", "XNNPACK/wrappers/f32-prelu/gen/avx-2x16.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-avx-x32.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-avx-x32.c", "XNNPACK/wrappers/f32-vbinary/gen/vadd-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vaddc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vdiv-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vdivc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vmaxc-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vmin-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vminc-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vmul-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vmulc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vrdivc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vrsubc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiff-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiffc-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vsub-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vbinary/gen/vsubc-minmax-avx-x16.c", "XNNPACK/wrappers/f32-vclamp/gen/vclamp-avx-x16.c", "XNNPACK/wrappers/f32-velu/gen/velu-avx-rr2-lut4-p4-perm-x32.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-avx-x16.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-avx-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-avx-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-avx-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-avx-x16.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-avx-x16.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-avx-rr2-p5-nr2-x40.c", "XNNPACK/wrappers/f32-vsqrt/gen/avx-sqrt-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vabs-avx-x16.c", "XNNPACK/wrappers/f32-vunary/gen/vneg-avx-x16.c", "XNNPACK/wrappers/f32-vunary/gen/vsqr-avx-x16.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", "XNNPACK/wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qc8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qc8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-avx-x32.c", "XNNPACK/wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qs8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qs8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-avx-mul32-ld32-x8.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", "XNNPACK/wrappers/qs8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", "XNNPACK/wrappers/qs8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-avx-mul16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-avx-mul16.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-avx-x32.c", "XNNPACK/wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qu8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qu8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-avx-mul32-ld32-x8.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", "XNNPACK/wrappers/qu8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", "XNNPACK/wrappers/qu8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", "XNNPACK/wrappers/x8-lut/gen/lut-avx-x64.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], + windows_srcs = PROD_AVX_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_sse41", - srcs = [], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_f16c", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mf16c", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_F16C_MICROKERNEL_SRCS, + ), + ], + platforms = (APPLE, ANDROID, CXX, WINDOWS), preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_f16c_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-msse4.1"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-sse41-int16-x16.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-sse41-x8.c", "XNNPACK/wrappers/f32-prelu/gen/sse41-2x8.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-sse41-x32.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-sse41-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-sse41-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-sse41-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-sse41-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-sse41-x8.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-sse41-rr2-lut64-p2-div-x8.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", "XNNPACK/wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qc8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qs8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16-add16.c", "XNNPACK/wrappers/qs8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16-add16.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-sse41-x16.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", "XNNPACK/wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qs8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", "XNNPACK/wrappers/qs8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", "XNNPACK/wrappers/qs8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-sse41-x16.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", "XNNPACK/wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qu8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", "XNNPACK/wrappers/qu8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", "XNNPACK/wrappers/s8-ibilinear/gen/sse41-c16.c", "XNNPACK/wrappers/s8-maxpool/9p8x-minmax-sse41-c16.c", "XNNPACK/wrappers/s8-vclamp/sse41-x64.c", "XNNPACK/wrappers/u8-ibilinear/gen/sse41-c16.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mf16c", + ], + ), + ], + platforms = (APPLE, ANDROID, CXX, WINDOWS), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], + windows_srcs = PROD_F16C_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_neon", - srcs = ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-neon-int16-x16.c", "XNNPACK/wrappers/f32-argmaxpool/4x-neon-c4.c", "XNNPACK/wrappers/f32-argmaxpool/9p8x-neon-c4.c", "XNNPACK/wrappers/f32-argmaxpool/9x-neon-c4.c", "XNNPACK/wrappers/f32-avgpool/9p8x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-avgpool/9x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-neon-2x2.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x3-minmax-neon.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x4-minmax-neon.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x9-minmax-neon.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x25-minmax-neon-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-neon-2x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-neon-1x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-neon-1x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-neon-1x4.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-neon-x8.c", "XNNPACK/wrappers/f32-gavgpool-cw/neon-x4.c", "XNNPACK/wrappers/f32-gavgpool/7p7x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-gavgpool/7x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-gemm/gen/4x2-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-neon-lane-ld128.c", "XNNPACK/wrappers/f32-ibilinear-chw/gen/neon-p8.c", "XNNPACK/wrappers/f32-ibilinear/gen/neon-c8.c", "XNNPACK/wrappers/f32-igemm/gen/1x8-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/4x2-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-neon-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-neon-lane-ld128.c", "XNNPACK/wrappers/f32-maxpool/9p8x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-pavgpool/9p8x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-pavgpool/9x-minmax-neon-c4.c", "XNNPACK/wrappers/f32-prelu/gen/neon-2x8.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-neon-x32.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-neon-x32.c", "XNNPACK/wrappers/f32-raddstoreexpminusmax/gen/neon-rr2-lut64-p2-x8.c", "XNNPACK/wrappers/f32-rmax/neon.c", "XNNPACK/wrappers/f32-spmm/gen/32x1-minmax-neon.c", "XNNPACK/wrappers/f32-vbinary/gen/vadd-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vaddc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmaxc-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmin-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vminc-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmul-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmulc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vrsubc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiff-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiffc-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsub-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsubc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vclamp/gen/vclamp-neon-x8.c", "XNNPACK/wrappers/f32-velu/gen/velu-neon-rr2-lut16-p3-x8.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-neon-x16.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-neon-x8.c", "XNNPACK/wrappers/f32-vmulcaddc/gen/c4-minmax-neon-2x.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-neon-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-neon-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-neon-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-neon-x8.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-neon-rr2-lut64-p2-nr2recps-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vabs-neon-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vneg-neon-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vsqr-neon-x8.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-neon-mla8-ld64.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-neon-mla8-ld64.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-neon-mla8-ld64.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8c2s4-minmax-fp32-neon-mlal.c", "XNNPACK/wrappers/qc8-gemm/gen/2x8c2s4-minmax-fp32-neon-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8c2s4-minmax-fp32-neon-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/2x8c2s4-minmax-fp32-neon-mlal.c", "XNNPACK/wrappers/qs8-dwconv/gen/up8x25-minmax-rndnu-neon-mla8-ld64.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x9-minmax-rndnu-neon-mla8-ld64.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-neon-x32.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7p7x-minmax-rndnu-neon-c8.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7x-minmax-rndnu-neon-c8.c", "XNNPACK/wrappers/qs8-gemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qs8-gemm/gen/1x8c2s4-minmax-rndnu-neon-mlal.c", "XNNPACK/wrappers/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qs8-gemm/gen/2x8c2s4-minmax-rndnu-neon-mlal.c", "XNNPACK/wrappers/qs8-igemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qs8-igemm/gen/1x8c2s4-minmax-rndnu-neon-mlal.c", "XNNPACK/wrappers/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qs8-igemm/gen/2x8c2s4-minmax-rndnu-neon-mlal.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-neon-ld64-x16.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-neon-ld64-x32.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-neon-ld64-x16.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-neon-ld64-x32.c", "XNNPACK/wrappers/qs8-vmul/gen/minmax-rndnu-neon-ld64-x16.c", "XNNPACK/wrappers/qs8-vmulc/gen/minmax-rndnu-neon-ld64-x16.c", "XNNPACK/wrappers/qu8-avgpool/9p8x-minmax-neon-c8.c", "XNNPACK/wrappers/qu8-avgpool/9x-minmax-neon-c8.c", "XNNPACK/wrappers/qu8-dwconv/gen/up8x25-minmax-rndnu-neon-mul8.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x9-minmax-rndnu-neon-mul8.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-neon-x32.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7p7x-minmax-rndnu-neon-c8.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7x-minmax-rndnu-neon-c8.c", "XNNPACK/wrappers/qu8-gemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-gemm/gen/3x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-igemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-igemm/gen/3x8-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-neon-ld64-x16.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-neon-ld64-x32.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-neon-ld64-x16.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-neon-ld64-x32.c", "XNNPACK/wrappers/qu8-vmul/gen/minmax-rndnu-neon-ld64-x16.c", "XNNPACK/wrappers/qu8-vmulc/gen/minmax-rndnu-neon-ld64-x16.c", "XNNPACK/wrappers/s8-ibilinear/gen/neon-c8.c", "XNNPACK/wrappers/s8-ibilinear/gen/neon-c16.c", "XNNPACK/wrappers/s8-maxpool/9p8x-minmax-neon-c16.c", "XNNPACK/wrappers/s8-vclamp/neon-x64.c", "XNNPACK/wrappers/u8-ibilinear/gen/neon-c8.c", "XNNPACK/wrappers/u8-ibilinear/gen/neon-c16.c", "XNNPACK/wrappers/u8-maxpool/9p8x-minmax-neon-c16.c", "XNNPACK/wrappers/u8-rmax/neon.c", "XNNPACK/wrappers/u8-vclamp/neon-x64.c", "XNNPACK/wrappers/xx-fill/neon-x64.c", "XNNPACK/wrappers/xx-pad/neon.c", "XNNPACK/wrappers/x8-zip/xm-neon.c", "XNNPACK/wrappers/x8-zip/x2-neon.c", "XNNPACK/wrappers/x8-zip/x3-neon.c", "XNNPACK/wrappers/x8-zip/x4-neon.c", "XNNPACK/wrappers/x32-packx/x4-neon-st4.c", "XNNPACK/wrappers/x32-unpool/neon.c", "XNNPACK/wrappers/x32-zip/xm-neon.c", "XNNPACK/wrappers/x32-zip/x2-neon.c", "XNNPACK/wrappers/x32-zip/x3-neon.c", "XNNPACK/wrappers/x32-zip/x4-neon.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_xop", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mxop", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_preprocessor_flags = [ + ( + "windows-x86_64", + [ + "-Drestrict=", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_XOP_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_xop_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["^(android-armv7|iphoneos-armv7)$", ["-march=armv7-a", "-mfpu=neon", "-mfloat-abi=softfp"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mxop", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_preprocessor_flags = [ + ( + "windows-x86_64", + [ + "-Drestrict=", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], + windows_srcs = PROD_XOP_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_neon_dot", - srcs = [], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_fma3", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mfma", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(i[3-6]86|x86|x86_64|AMD64)$", + [ + "-mfma", + "-mf16c", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_FMA3_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mfma", + "-mf16c", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mfma", + "-mf16c", + ], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_fma3_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["(aarch64|arm64)", ["-march=armv8.2-a+dotprod"]], ["^android-armv7$", ["-march=armv8.2-a+dotprod", "-mfpu=neon-fp-armv8", "-mfloat-abi=softfp"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["^((?!iphoneos-armv7).)*$", ["XNNPACK/wrappers/qc8-gemm/gen/1x8c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-igemm/gen/1x16c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-neondot.c", "XNNPACK/wrappers/qs8-gemm/gen/1x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-igemm/gen/1x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-igemm/gen/1x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-gemm/gen/1x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-gemm/gen/1x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-igemm/gen/1x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-igemm/gen/1x16c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-neondot.c", "XNNPACK/wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-neondot.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mfma", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(i[3-6]86|x86|x86_64|AMD64)$", + [ + "-mfma", + "-mf16c", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mfma", + "-mf16c", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mfma", + "-mf16c", + ], + windows_srcs = PROD_FMA3_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_neon_aarch64", - srcs = ["XNNPACK/wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-neonfma-2x2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-neonfma-3x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-neonfma-2x4-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-neonfma-4x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-neonfma-1x4-acc2.c", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-gemm/gen/4x2-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/1x8-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/4x2-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-neonfma-lane-ld64.c", "XNNPACK/wrappers/f32-spmm/gen/32x2-minmax-neonfma.c", "XNNPACK/wrappers/f32-spmm/gen/32x4-minmax-neonfma.c", "XNNPACK/wrappers/f32-vbinary/gen/vdiv-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vdivc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vrdivc-minmax-neon-x8.c", "XNNPACK/wrappers/f32-vsqrt/gen/neon-sqrt-x4.c", "XNNPACK/wrappers/x8-lut/gen/lut-neon-tbx128x4-x64.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_avx2", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx2", + "-mfma", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_AVX2_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_avx2_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["(aarch64|arm64)", ["-mfpu=neon-vfpv4"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx2", + "-mfma", + "-mf16c", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mavx2", + "-mfma", + "-mf16c", + ], + windows_srcs = PROD_AVX2_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_neon_v8", - srcs = ["XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-neonv8-x32.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-neonv8-x32.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-neonv8-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-neonv8-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-neonv8-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-neonv8-x8.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-neonv8-mla8-ld64.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-neonv8-mla8-ld64.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-neonv8-mla8-ld64.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8-minmax-fp32-neonv8-mlal-lane-prfm.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8-minmax-fp32-neonv8-mlal-lane.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8c2s4-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-gemm/gen/1x16-minmax-fp32-neonv8-mlal-lane.c", "XNNPACK/wrappers/qc8-gemm/gen/2x8c2s4-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-gemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8-minmax-fp32-neonv8-mlal-lane-prfm.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8-minmax-fp32-neonv8-mlal-lane.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8c2s4-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/1x16-minmax-fp32-neonv8-mlal-lane.c", "XNNPACK/wrappers/qc8-igemm/gen/2x8c2s4-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-neonv8-mlal.c", "XNNPACK/wrappers/qc8-igemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_avx512", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx512f", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx512f", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_AVX512F_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_avx512_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["(aarch64|arm64)", ["-march=armv8-a", "-mfpu=neon-fp-armv8"]], ["^android-armv7$", ["-march=armv8-a", "-mfpu=neon-fp-armv8", "-mfloat-abi=softfp"]], ["^iphoneos-armv7$", ["-mcpu=cyclone", "-mtune=generic"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx512f", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "x86", + [ + "-mavx512f", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], + windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], + windows_srcs = PROD_AVX512F_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( + fb_xplat_cxx_library( name = "ukernels_avx512skx", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mavx512f", "-mavx512cd", "-mavx512bw", "-mavx512dq", "-mavx512vl"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["^(i[3-6]86|x86|x86_64|AMD64)$", ["-mavx512f", "-mavx512cd", "-mavx512bw", "-mavx512dq", "-mavx512vl"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-avx512skx-x16.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-avx512skx-x16.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-avx512skx-x128.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-avx512skx-x128.c", "XNNPACK/wrappers/qc8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qc8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qs8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qs8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-avx512skx-x32.c", "XNNPACK/wrappers/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qu8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-avx512skx-x32.c", "XNNPACK/wrappers/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", "XNNPACK/wrappers/x8-lut/gen/lut-avx512skx-vpshufb-x64.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(i[3-6]86|x86|x86_64|AMD64)$", + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + ), + ], + platform_srcs = [ + ( + "x86|x86_64|platform009|platform010", + PROD_AVX512SKX_MICROKERNEL_SRCS, + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + deps = [ + ":interface", + ], ) - cxx_library( - name = "ukernels_neon_fp16", - srcs = ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-neonfp16-x16.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-neonfp16-x16.c"], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "ukernels_avx512skx_ovr_win32", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["arm", ["-mfpu=neon-fp16"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(i[3-6]86|x86|x86_64|AMD64)$", + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + windows_compiler_flags_override = WINDOWS_FLAGS + [ + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + ], + windows_srcs = PROD_AVX512SKX_MICROKERNEL_SRCS, + deps = [ + ":interface", + ], ) - cxx_library( - name = "interface", - srcs = [], - deps = [], - exported_deps = ["//third_party:pthreadpool_header"], - compiler_flags = ["-w"], - preferred_linkage = "static", - exported_headers = {"xnnpack.h": "XNNPACK/include/xnnpack.h"}, - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "ukernels_neon", + srcs = PROD_NEON_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(android-armv7|iphoneos-armv7)$", + [ + "-march=armv7-a", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_fma3", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mfma", "-mf16c"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "ukernels_neon_fma", + srcs = PROD_NEONFMA_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["^(i[3-6]86|x86|x86_64|AMD64)$", ["-mfma", "-mf16c"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-dwconv/gen/up8x25-minmax-fma3-acc2.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x3-minmax-fma3.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x4-minmax-fma3.c", "XNNPACK/wrappers/f16-dwconv/gen/up16x9-minmax-fma3.c", "XNNPACK/wrappers/f16-ibilinear/gen/fma3-c8.c", "XNNPACK/wrappers/f16-vmulcaddc/gen/c8-minmax-fma3-2x.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x25-minmax-fma3.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x3-minmax-fma3.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x4-minmax-fma3.c", "XNNPACK/wrappers/f32-dwconv/gen/up16x9-minmax-fma3.c", "XNNPACK/wrappers/f32-gemm/gen/1x16-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-gemm/gen/1x16s4-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-gemm/gen/4x16s4-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-gemm/gen/5x16-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/1x16-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/1x16s4-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/4x16s4-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-igemm/gen/5x16-minmax-fma3-broadcast.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-fma3-x16.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "arm", + [ + "-mfpu=neon-vfpv4", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "jit_memory", - srcs = ["XNNPACK/src/jit/aarch32-assembler.cc", "XNNPACK/src/jit/aarch64-assembler.cc", "XNNPACK/src/jit/assembler.cc", "XNNPACK/src/jit/memory.c"], - deps = [":interface", "//third_party:clog"], - exported_deps = [], - compiler_flags = ["-w", "-Os"], + fb_xplat_cxx_library( + name = "ukernels_neon_fp16", + srcs = PROD_NEONFP16_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "arm", + [ + "-mfpu=neon-fp16", + ], + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_neon_v8", + srcs = PROD_NEONV8_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "(aarch64|arm64)", + [ + "-march=armv8-a", + "-mfpu=neon-fp-armv8", + ], + ), + ( + "^android-armv7$", + [ + "-march=armv8-a", + "-mfpu=neon-fp-armv8", + "-mfloat-abi=softfp", + ], + ), + ( + "^iphoneos-armv7$", + [ + "-mcpu=cyclone", + "-mtune=generic", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_sse2", - srcs = [], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_neon_dot", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "(aarch64|arm64)", + [ + "-march=armv8.2-a+dotprod", + ], + ), + ( + "^android-armv7$", + [ + "-march=armv8.2-a+dotprod", + "-mfpu=neon-fp-armv8", + "-mfloat-abi=softfp", + ], + ), + ], + platform_srcs = [ + # excluding iphoneos-armv7, matching everything else + ( + "^((?!iphoneos-armv7).)*$", + PROD_NEONDOT_MICROKERNEL_SRCS, + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_neon_aarch64", + srcs = PROD_AARCH64_NEON_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-msse2"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-sse2-int16-x32.c", "XNNPACK/wrappers/f32-argmaxpool/4x-sse2-c4.c", "XNNPACK/wrappers/f32-argmaxpool/9p8x-sse2-c4.c", "XNNPACK/wrappers/f32-argmaxpool/9x-sse2-c4.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-sse2-x16.c", "XNNPACK/wrappers/f32-prelu/gen/sse2-2x8.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-sse2-x32.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-sse2-x32.c", "XNNPACK/wrappers/f32-raddstoreexpminusmax/gen/sse2-rr2-p5-x20-acc2.c", "XNNPACK/wrappers/f32-velu/gen/velu-sse2-rr2-lut16-p3-x12.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-sse2-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-sse2-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-sse2-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-sse2-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-sse2-x8.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-sse2-rr2-lut64-p2-div-x8.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", "XNNPACK/wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", "XNNPACK/wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qc8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qs8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16-add16.c", "XNNPACK/wrappers/qs8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16-add16.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-sse2-x32.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", "XNNPACK/wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qs8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qs8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qs8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-avgpool/9p8x-minmax-sse2-c8.c", "XNNPACK/wrappers/qu8-avgpool/9x-minmax-sse2-c8.c", "XNNPACK/wrappers/qu8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-sse2-x32.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", "XNNPACK/wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qu8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/qu8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", "XNNPACK/wrappers/s8-ibilinear/gen/sse2-c8.c", "XNNPACK/wrappers/s8-maxpool/9p8x-minmax-sse2-c16.c", "XNNPACK/wrappers/s8-vclamp/sse2-x64.c", "XNNPACK/wrappers/u8-ibilinear/gen/sse2-c8.c", "XNNPACK/wrappers/u8-maxpool/9p8x-minmax-sse2-c16.c", "XNNPACK/wrappers/u8-rmax/sse2.c", "XNNPACK/wrappers/u8-vclamp/sse2-x64.c", "XNNPACK/wrappers/xx-fill/sse2-x64.c", "XNNPACK/wrappers/xx-pad/sse2.c", "XNNPACK/wrappers/x8-zip/xm-sse2.c", "XNNPACK/wrappers/x8-zip/x2-sse2.c", "XNNPACK/wrappers/x8-zip/x3-sse2.c", "XNNPACK/wrappers/x8-zip/x4-sse2.c", "XNNPACK/wrappers/x32-unpool/sse2.c", "XNNPACK/wrappers/x32-zip/xm-sse2.c", "XNNPACK/wrappers/x32-zip/x2-sse2.c", "XNNPACK/wrappers/x32-zip/x3-sse2.c", "XNNPACK/wrappers/x32-zip/x4-sse2.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "(aarch64|arm64)", + [ + "-mfpu=neon-vfpv4", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_sse", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "ukernels_neon_fp16arith_aarch64", + srcs = PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.c"), + ("XNNPACK/src", "**/*.h"), + ]), + header_namespace = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "(aarch64|arm64)", + [ + "-march=armv8.2-a+fp16", + ], + ), + ], preferred_linkage = "static", - exported_preprocessor_flags = [], + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], + ) + + fb_xplat_cxx_library( + name = "ukernels_scalar_aarch32", + srcs = PROD_SCALAR_AARCH32_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/src", "**/*.c"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-msse"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f32-avgpool/9p8x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-avgpool/9x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-sse-2x2.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x3-minmax-sse.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x4-minmax-sse.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x9-minmax-sse.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x25-minmax-sse.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-sse-2x4-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-sse-1x4-acc3.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-sse-4x4.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-sse-2x4.c", "XNNPACK/wrappers/f32-gavgpool-cw/sse-x4.c", "XNNPACK/wrappers/f32-gavgpool/7p7x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-gavgpool/7x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-sse-load1.c", "XNNPACK/wrappers/f32-gemm/gen/4x2c4-minmax-sse.c", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-sse-load1.c", "XNNPACK/wrappers/f32-ibilinear-chw/gen/sse-p8.c", "XNNPACK/wrappers/f32-ibilinear/gen/sse-c8.c", "XNNPACK/wrappers/f32-igemm/gen/1x8-minmax-sse-load1.c", "XNNPACK/wrappers/f32-igemm/gen/4x2c4-minmax-sse.c", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-sse-load1.c", "XNNPACK/wrappers/f32-maxpool/9p8x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-pavgpool/9p8x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-pavgpool/9x-minmax-sse-c4.c", "XNNPACK/wrappers/f32-rmax/sse.c", "XNNPACK/wrappers/f32-spmm/gen/32x1-minmax-sse.c", "XNNPACK/wrappers/f32-vbinary/gen/vadd-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vaddc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vdiv-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vdivc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmaxc-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmin-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vminc-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmul-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmulc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vrdivc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vrsubc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiff-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiffc-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsub-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsubc-minmax-sse-x8.c", "XNNPACK/wrappers/f32-vclamp/gen/vclamp-sse-x8.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-sse-x8.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-sse-x8.c", "XNNPACK/wrappers/f32-vmulcaddc/gen/c4-minmax-sse-2x.c", "XNNPACK/wrappers/f32-vsqrt/gen/sse-sqrt-x4.c", "XNNPACK/wrappers/f32-vunary/gen/vabs-sse-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vneg-sse-x8.c", "XNNPACK/wrappers/f32-vunary/gen/vsqr-sse-x8.c", "XNNPACK/wrappers/x32-packx/x4-sse.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^(android-armv7|iphoneos-armv7)$", + [ + "-march=armv7-a", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ), + ], + platforms = (APPLE, ANDROID, CXX, WINDOWS), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + third_party("FP16"), + ], ) - cxx_library( + fb_xplat_cxx_library( name = "ukernels_asm_aarch32", - srcs = ["XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a7.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-ld64.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/4x4-aarch32-vfp-ld64.S", "XNNPACK/wrappers/f32-gemm/4x4-minmax-aarch32-vfp-ld64.S", "XNNPACK/wrappers/f32-gemm/4x8-minmax-aarch32-neon-cortex-a55.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a7.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-ld64.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/4x8-minmax-aarch32-neon-cortex-a55.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a7.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a7.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-aarch32-neondot-cortex-a55.S", "XNNPACK/wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-aarch32-neondot-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-aarch32-neondot-cortex-a55.S", "XNNPACK/wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-aarch32-neondot-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a7.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a7.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-ld64.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a7.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a7.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S"], - deps = [":interface", ":jit_memory", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + srcs = AARCH32_ASM_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "xnnpack/assembly.h"), + ("XNNPACK/src", "**/*.S"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["^android-armv7$", ["-march=armv8.2-a+dotprod", "-mfpu=neon-fp-armv8", "-mfloat-abi=softfp"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "^android-armv7$", + [ + "-march=armv8.2-a+dotprod", + "-mfpu=neon-fp-armv8", + "-mfloat-abi=softfp", + ], + ), + ], + platforms = (APPLE, ANDROID, CXX, WINDOWS), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + ":jit_memory", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_ssse3", - srcs = [], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "ukernels_asm_aarch64", + srcs = AARCH64_ASM_MICROKERNEL_SRCS, + headers = subdir_glob([ + ("XNNPACK/src", "xnnpack/assembly.h"), + ("XNNPACK/src", "**/*.S"), + ]), header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-mssse3"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-ssse3-2x4-acc2.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + apple_sdks = (IOS, MACOSX, APPLETVOS), + compiler_flags = [ + "-O2", + ], + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], + labels = labels, + platform_compiler_flags = [ + ( + "(aarch64|arm64)", + [ + "-march=armv8.2-a+fp16+dotprod", + ], + ), + ], + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS, + deps = [ + ":interface", + ":jit_memory", + third_party("FP16"), + ], ) - cxx_library( - name = "ukernels_f16c", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mf16c"], + fb_xplat_cxx_library( + name = "arm64_lib", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-mf16c"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-f16c-x16.c", "XNNPACK/wrappers/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c", "XNNPACK/wrappers/f16-gavgpool/gen/7x-minmax-f16c-c8.c", "XNNPACK/wrappers/f16-maxpool/9p8x-minmax-f16c-c8.c", "XNNPACK/wrappers/f16-prelu/gen/f16c-2x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vadd-minmax-f16c-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vaddc-minmax-f16c-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vmul-minmax-f16c-x16.c", "XNNPACK/wrappers/f16-vbinary/gen/vmulc-minmax-f16c-x16.c", "XNNPACK/wrappers/f16-vclamp/gen/vclamp-f16c-x16.c", "XNNPACK/wrappers/f16-vhswish/gen/vhswish-f16c-x16.c", "XNNPACK/wrappers/f16-vlrelu/gen/vlrelu-f16c-x16.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-f16c-x16.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", visibility = ["PUBLIC"], + deps = [ + ":jit_memory", + ":ukernels_asm_aarch64", + ":ukernels_neon", + ":ukernels_neon_aarch64", + ":ukernels_neon_dot", + ":ukernels_neon_fma", + ":ukernels_neon_fp16", + ":ukernels_neon_fp16arith_aarch64", + ":ukernels_neon_v8", + ], ) - cxx_library( - name = "ukernels_xop", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mxop"], + fb_xplat_cxx_library( + name = "x86_and_x86_64_lib", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows-x86_64", ["-Drestrict="]], ["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", "XNNPACK/wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qc8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qc8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", "XNNPACK/wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qs8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qs8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-xop-mul32-ld32-x8.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-xop-mul32-ld32-x8.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-xop-mul32.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-xop-mul32.c", "XNNPACK/wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qu8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qu8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-xop-mul32-ld32-x8.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-xop-mul32-ld32-x8.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", visibility = ["PUBLIC"], + deps = [ + ":ukernels_avx", + ":ukernels_avx2", + ":ukernels_avx512", + ":ukernels_avx512skx", + ":ukernels_f16c", + ":ukernels_fma3", + ":ukernels_sse", + ":ukernels_sse2", + ":ukernels_sse41", + ":ukernels_ssse3", + ":ukernels_xop", + ], ) - cxx_library( - name = "ukernels_scalar_aarch32", - srcs = ["XNNPACK/wrappers/f16-f32-vcvt/gen/vcvt-scalar-x4.c", "XNNPACK/wrappers/f32-argmaxpool/4x-scalar-c1.c", "XNNPACK/wrappers/f32-argmaxpool/9p8x-scalar-c1.c", "XNNPACK/wrappers/f32-argmaxpool/9x-scalar-c1.c", "XNNPACK/wrappers/f32-avgpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-avgpool/9x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-conv-hwc/3x3s2p0p1c3x4-scalar-1x1.c", "XNNPACK/wrappers/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c", "XNNPACK/wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x3-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x4-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x9-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x25-minmax-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv/gen/up1x25-scalar-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-scalar-4x1.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-scalar-2x1-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-scalar-2x1-acc2.c", "XNNPACK/wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-scalar-2x1-acc2.c", "XNNPACK/wrappers/f32-f16-vcvt/gen/vcvt-scalar-fabsf-x2.c", "XNNPACK/wrappers/f32-gavgpool-cw/scalar-x1.c", "XNNPACK/wrappers/f32-gavgpool/7p7x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-gavgpool/7x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-gemm/gen/1x4-minmax-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/1x4-relu-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/1x4-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/4x2-minmax-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/4x2-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/4x4-minmax-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/4x4-relu-scalar.c", "XNNPACK/wrappers/f32-gemm/gen/4x4-scalar.c", "XNNPACK/wrappers/f32-ibilinear-chw/gen/scalar-p4.c", "XNNPACK/wrappers/f32-ibilinear/gen/scalar-c2.c", "XNNPACK/wrappers/f32-igemm/gen/1x4-minmax-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/1x4-relu-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/1x4-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/4x2-minmax-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/4x2-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/4x4-minmax-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/4x4-relu-scalar.c", "XNNPACK/wrappers/f32-igemm/gen/4x4-scalar.c", "XNNPACK/wrappers/f32-maxpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-pavgpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-pavgpool/9x-minmax-scalar-c1.c", "XNNPACK/wrappers/f32-prelu/gen/scalar-2x4.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-scalar-imagic-x4.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-scalar-imagic-x4.c", "XNNPACK/wrappers/f32-raddstoreexpminusmax/gen/scalar-rr2-p5-x4-acc2.c", "XNNPACK/wrappers/f32-rmax/scalar.c", "XNNPACK/wrappers/f32-spmm/gen/8x1-minmax-scalar.c", "XNNPACK/wrappers/f32-spmm/gen/8x2-minmax-scalar.c", "XNNPACK/wrappers/f32-spmm/gen/8x4-minmax-scalar.c", "XNNPACK/wrappers/f32-vbinary/gen/vadd-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vaddc-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vdiv-minmax-scalar-x2.c", "XNNPACK/wrappers/f32-vbinary/gen/vdivc-minmax-scalar-x2.c", "XNNPACK/wrappers/f32-vbinary/gen/vmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmaxc-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmin-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vminc-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmul-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vmulc-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vrdivc-minmax-scalar-x2.c", "XNNPACK/wrappers/f32-vbinary/gen/vrsubc-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiff-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsqrdiffc-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsub-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vbinary/gen/vsubc-minmax-scalar-x8.c", "XNNPACK/wrappers/f32-vclamp/gen/vclamp-scalar-x4.c", "XNNPACK/wrappers/f32-velu/gen/velu-scalar-rr2-lut16-p3-x4.c", "XNNPACK/wrappers/f32-vhswish/gen/vhswish-scalar-x4.c", "XNNPACK/wrappers/f32-vlrelu/gen/vlrelu-scalar-x4.c", "XNNPACK/wrappers/f32-vmulcaddc/gen/c1-minmax-scalar-2x.c", "XNNPACK/wrappers/f32-vrelu/gen/vrelu-scalar-x8.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndd-scalar-libm-x1.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndne-scalar-libm-x1.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndu-scalar-libm-x1.c", "XNNPACK/wrappers/f32-vrnd/gen/vrndz-scalar-libm-x1.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-scalar-rr2-lut64-p2-div-x2.c", "XNNPACK/wrappers/f32-vsqrt/gen/scalar-sqrt-x1.c", "XNNPACK/wrappers/f32-vunary/gen/vabs-scalar-x4.c", "XNNPACK/wrappers/f32-vunary/gen/vneg-scalar-x4.c", "XNNPACK/wrappers/f32-vunary/gen/vsqr-scalar-x4.c", "XNNPACK/wrappers/qc8-dwconv/gen/up2x9-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qc8-dwconv/gen/up2x25-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qc8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8-minmax-fp32-neon-mlal-lane.c", "XNNPACK/wrappers/qc8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qc8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8-minmax-fp32-neon-mlal-lane.c", "XNNPACK/wrappers/qc8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-dwconv/gen/up1x9-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-dwconv/gen/up1x25-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-scalar-x4.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-scalar-imagic-c1.c", "XNNPACK/wrappers/qs8-gavgpool/gen/7x-minmax-fp32-scalar-imagic-c1.c", "XNNPACK/wrappers/qs8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-scalar-x1.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-scalar-x1.c", "XNNPACK/wrappers/qs8-vmul/gen/minmax-fp32-scalar-x4.c", "XNNPACK/wrappers/qs8-vmulc/gen/minmax-fp32-scalar-x4.c", "XNNPACK/wrappers/qu8-avgpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/qu8-avgpool/9x-minmax-scalar-c1.c", "XNNPACK/wrappers/qu8-dwconv/gen/up1x9-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-dwconv/gen/up1x25-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-scalar-x4.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-scalar-imagic-c1.c", "XNNPACK/wrappers/qu8-gavgpool/gen/7x-minmax-fp32-scalar-imagic-c1.c", "XNNPACK/wrappers/qu8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-scalar-x1.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-scalar-x1.c", "XNNPACK/wrappers/qu8-vmul/gen/minmax-fp32-scalar-x4.c", "XNNPACK/wrappers/qu8-vmulc/gen/minmax-fp32-scalar-x4.c", "XNNPACK/wrappers/s8-ibilinear/gen/scalar-c1.c", "XNNPACK/wrappers/s8-maxpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/s8-vclamp/scalar-x4.c", "XNNPACK/wrappers/u8-ibilinear/gen/scalar-c1.c", "XNNPACK/wrappers/u8-maxpool/9p8x-minmax-scalar-c1.c", "XNNPACK/wrappers/u8-rmax/scalar.c", "XNNPACK/wrappers/u8-vclamp/scalar-x4.c", "XNNPACK/wrappers/xx-fill/scalar-x16.c", "XNNPACK/wrappers/xx-pad/scalar.c", "XNNPACK/wrappers/x8-zip/xm-scalar.c", "XNNPACK/wrappers/x8-zip/x2-scalar.c", "XNNPACK/wrappers/x8-zip/x3-scalar.c", "XNNPACK/wrappers/x8-zip/x4-scalar.c", "XNNPACK/wrappers/x32-packx/x2-scalar.c", "XNNPACK/wrappers/x32-packx/x3-scalar.c", "XNNPACK/wrappers/x32-packx/x4-scalar.c", "XNNPACK/wrappers/x32-unpool/scalar.c", "XNNPACK/wrappers/x32-zip/xm-scalar.c", "XNNPACK/wrappers/x32-zip/x2-scalar.c", "XNNPACK/wrappers/x32-zip/x3-scalar.c", "XNNPACK/wrappers/x32-zip/x4-scalar.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "x86_and_x86_64_lib_ovr_win32", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["^(android-armv7|iphoneos-armv7)$", ["-march=armv7-a", "-mfpu=neon", "-mfloat-abi=softfp"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", visibility = ["PUBLIC"], + deps = [ + ":ukernels_avx2_ovr_win32", + ":ukernels_avx512_ovr_win32", + ":ukernels_avx512skx_ovr_win32", + ":ukernels_avx_ovr_win32", + ":ukernels_f16c_ovr_win32", + ":ukernels_fma3_ovr_win32", + ":ukernels_sse2_ovr_win32", + ":ukernels_sse41_ovr_win32", + ":ukernels_sse_ovr_win32", + ":ukernels_ssse3_ovr_win32", + ":ukernels_xop_ovr_win32", + ], ) - cxx_library( - name = "ukernels_neon_fma", - srcs = ["XNNPACK/wrappers/f32-dwconv/gen/up8x3-minmax-neonfma.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x4-minmax-neonfma.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x9-minmax-neonfma.c", "XNNPACK/wrappers/f32-dwconv/gen/up8x25-minmax-neonfma-acc2.c", "XNNPACK/wrappers/f32-gemm/gen/1x8s4-minmax-neonfma.c", "XNNPACK/wrappers/f32-gemm/gen/6x8s4-minmax-neonfma.c", "XNNPACK/wrappers/f32-ibilinear-chw/gen/neonfma-p8.c", "XNNPACK/wrappers/f32-ibilinear/gen/neonfma-c8.c", "XNNPACK/wrappers/f32-igemm/gen/1x8s4-minmax-neonfma.c", "XNNPACK/wrappers/f32-igemm/gen/6x8s4-minmax-neonfma.c", "XNNPACK/wrappers/f32-raddstoreexpminusmax/gen/neonfma-rr1-lut64-p2-x16.c", "XNNPACK/wrappers/f32-spmm/gen/32x1-minmax-neonfma-pipelined.c", "XNNPACK/wrappers/f32-velu/gen/velu-neonfma-rr1-lut16-p3-x16.c", "XNNPACK/wrappers/f32-velu/gen/velu-neonfma-rr1-p6-x8.c", "XNNPACK/wrappers/f32-vmulcaddc/gen/c4-minmax-neonfma-2x.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-neonfma-rr1-lut64-p2-nr2recps-x16.c"], - deps = [":interface", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], + fb_xplat_cxx_library( + name = "arm_lib", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["arm", ["-mfpu=neon-vfpv4"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", visibility = ["PUBLIC"], + deps = [ + ":jit_memory", + ":ukernels_asm_aarch32", + ":ukernels_asm_aarch64", + ":ukernels_neon", + ":ukernels_neon_aarch64", + ":ukernels_neon_dot", + ":ukernels_neon_fma", + ":ukernels_neon_fp16", + ":ukernels_neon_fp16arith_aarch64", + ":ukernels_neon_v8", + ":ukernels_scalar_aarch32", + ], ) - cxx_library( - name = "ukernels_avx2", - srcs = [], - deps = [":interface"], - exported_deps = [], - compiler_flags = ["-w", "-O2", "-mavx2", "-mfma", "-mf16c"], + fb_xplat_cxx_library( + name = "armv7_lib", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, preferred_linkage = "static", - exported_preprocessor_flags = [], - header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["x86", ["-mavx2", "-mfma", "-mf16c"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - platform_srcs = [["x86|x86_64|platform009", ["XNNPACK/wrappers/f16-gemm/gen/1x16-minmax-avx2-broadcast.c", "XNNPACK/wrappers/f16-gemm/gen/4x16-minmax-avx2-broadcast.c", "XNNPACK/wrappers/f16-igemm/gen/1x16-minmax-avx2-broadcast.c", "XNNPACK/wrappers/f16-igemm/gen/4x16-minmax-avx2-broadcast.c", "XNNPACK/wrappers/f32-qs8-vcvt/gen/vcvt-avx2-x64.c", "XNNPACK/wrappers/f32-qu8-vcvt/gen/vcvt-avx2-x64.c", "XNNPACK/wrappers/f32-velu/gen/velu-avx2-rr1-lut4-p4-perm-x56.c", "XNNPACK/wrappers/f32-vsigmoid/gen/vsigmoid-avx2-rr1-p5-div-x40.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qc8-gemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qc8-igemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qs8-f32-vcvt/gen/vcvt-avx2-x16.c", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qs8-gemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qs8-igemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qs8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", "XNNPACK/wrappers/qs8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", "XNNPACK/wrappers/qu8-f32-vcvt/gen/vcvt-avx2-x16.c", "XNNPACK/wrappers/qu8-gemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qu8-gemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qu8-igemm/gen/1x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qu8-igemm/gen/3x8c8-minmax-fp32-avx2.c", "XNNPACK/wrappers/qu8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", "XNNPACK/wrappers/qu8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", "XNNPACK/wrappers/x8-lut/gen/lut-avx2-x128.c"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", visibility = ["PUBLIC"], + deps = [ + ":jit_memory", + ":ukernels_asm_aarch32", + ":ukernels_neon", + ":ukernels_neon_dot", + ":ukernels_neon_fma", + ":ukernels_neon_v8", + ":ukernels_scalar_aarch32", + ], ) - cxx_library( - name = "ukernels_asm_aarch64", - srcs = ["XNNPACK/wrappers/f16-gemm/gen-inc/1x8inc-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen-inc/1x16inc-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen-inc/4x8inc-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen-inc/4x16inc-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen-inc/6x8inc-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-cortex-a55.S", "XNNPACK/wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-cortex-a75.S", "XNNPACK/wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen-inc/8x8inc-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen/1x8-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen/1x16-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen/4x8-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen/4x16-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen/6x8-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-cortex-a55.S", "XNNPACK/wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-cortex-a75.S", "XNNPACK/wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f16-gemm/gen/8x8-minmax-aarch64-neonfp16arith-ld64.S", "XNNPACK/wrappers/f16-igemm/4x16-minmax-aarch64-neonfp16arith-ld32.S", "XNNPACK/wrappers/f32-dwconv/up4x9-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-dwconv/up4x9-minmax-aarch64-neonfma.S", "XNNPACK/wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/1x12inc-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/4x12inc-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen-inc/5x8inc-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/5x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a73.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/1x12-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/4x12-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/5x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/5x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a73.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/1x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/5x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/5x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld64.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld128.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", "XNNPACK/wrappers/f32-igemm/1x8-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/1x12-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/4x8-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-igemm/4x12-minmax-aarch64-neonfma-cortex-a53.S", "XNNPACK/wrappers/f32-igemm/6x8-minmax-aarch64-neonfma-cortex-a55.S", "XNNPACK/wrappers/f32-igemm/6x8-minmax-aarch64-neonfma-cortex-a73.S", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mull.S", "XNNPACK/wrappers/qc8-gemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-igemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qs8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mull.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mull.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/2x8c16-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld32.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/2x8c16-minmax-rndnu-aarch64-neon-mlal.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld64.S", "XNNPACK/wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a75.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a75.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", "XNNPACK/wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S"], - deps = [":interface", ":jit_memory", "//third_party:FP16"], - exported_deps = [], - compiler_flags = ["-w", "-O2"], - preferred_linkage = "static", - exported_preprocessor_flags = [], + fb_xplat_cxx_library( + name = "XNNPACK", + apple_sdks = (IOS, MACOSX, APPLETVOS), + labels = labels, + deps = [ + ":operators", + ":subgraph", + ":tables", + ":ukernels_scalar", + third_party("cpuinfo"), + third_party("pthreadpool"), + ] + select({ + "DEFAULT": [ + ":arm_lib", + ":x86_and_x86_64_lib", + ], + "ovr_config//os:windows": [":x86_and_x86_64_lib_ovr_win32"] if XNNPACK_WINDOWS_AVX512F_ENABLED else [ + ":arm_lib", + ":x86_and_x86_64_lib", + ], + # doesn't cover iphonesimulator-x86_64 + "ovr_config//runtime:arm64-linux-ubuntu-neon": [":arm64_lib"], + "ovr_config//runtime:platform009": [":x86_and_x86_64_lib"], + "ovr_config//runtime:platform010": [":x86_and_x86_64_lib"], + }), + exported_headers = { + "xnnpack.h": "XNNPACK/include/xnnpack.h", + }, + fbobjc_preprocessor_flags = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + ], header_namespace = "", - headers = subdir_glob([("XNNPACK/src", "**/*.S"), ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h")]), - linker_flags = [], - platform_compiler_flags = [["(aarch64|arm64)", ["-march=armv8.2-a+fp16+dotprod"]]], - platform_linker_flags = [], - platform_preprocessor_flags = [["windows", ["-D_WINDOWS", "-D_WIN32", "-DWIN32", "-DNOMINMAX", "-D_CRT_SECURE_NO_WARNINGS", "-D_USE_MATH_DEFINES"]], ["windows.*64$", ["-D_WIN64"]]], - preprocessor_flags = ["-DXNN_LOG_LEVEL=0"], - soname = "", + headers = subdir_glob([ + ("XNNPACK/src", "**/*.h"), + ("XNNPACK/include", "**/*.h"), + ]), + platforms = (APPLE, ANDROID, CXX, WINDOWS), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + "-DXNN_NO_Q8_OPERATORS", + "-DXNN_NO_F16_OPERATORS", + "-DXNN_NO_NCHW_OPERATORS", + "-DXNN_NO_QU8_OPERATORS", + "-DXNN_NO_S8_OPERATORS", + "-DXNN_NO_U8_OPERATORS", + "-DXNN_NO_VCVT_OPERATORS", + "-DXNN_NO_X32_OPERATORS", + "-DXNN_NO_X8_OPERATORS", + "-DXNN_NO_XX_OPERATORS", + ], + srcs = [ + "XNNPACK/src/allocator.c", + "XNNPACK/src/init.c", + "XNNPACK/src/memory-planner.c", + "XNNPACK/src/operator-delete.c", + "XNNPACK/src/runtime.c", + "XNNPACK/src/subgraph.c", + "XNNPACK/src/tensor.c", + ] + LOGGING_SRCS, visibility = ["PUBLIC"], + windows_clang_compiler_flags_override = (WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS) if XNNPACK_WINDOWS_AVX512F_ENABLED else WINDOWS_FLAGS, + windows_compiler_flags_override = WINDOWS_FLAGS if XNNPACK_WINDOWS_AVX512F_ENABLED else [], ) diff --git a/third_party/xnnpack_src_defs.bzl b/third_party/xnnpack_src_defs.bzl new file mode 100644 index 0000000000000..d7586e9463cd4 --- /dev/null +++ b/third_party/xnnpack_src_defs.bzl @@ -0,0 +1,554 @@ +""" +Auto-generated by generate-wrappers.py script. Do not modify +""" + +OPERATOR_SRCS = [ + "XNNPACK/src/operators/argmax-pooling-nhwc.c", + "XNNPACK/src/operators/average-pooling-nhwc.c", + "XNNPACK/src/operators/binary-elementwise-nd.c", + "XNNPACK/src/operators/channel-shuffle-nc.c", + "XNNPACK/src/operators/constant-pad-nd.c", + "XNNPACK/src/operators/convolution-nchw.c", + "XNNPACK/src/operators/convolution-nhwc.c", + "XNNPACK/src/operators/deconvolution-nhwc.c", + "XNNPACK/src/operators/depth-to-space-nchw2nhwc.c", + "XNNPACK/src/operators/depth-to-space-nhwc.c", + "XNNPACK/src/operators/fully-connected-nc.c", + "XNNPACK/src/operators/global-average-pooling-ncw.c", + "XNNPACK/src/operators/global-average-pooling-nwc.c", + "XNNPACK/src/operators/lut-elementwise-nc.c", + "XNNPACK/src/operators/max-pooling-nhwc.c", + "XNNPACK/src/operators/prelu-nc.c", + "XNNPACK/src/operators/resize-bilinear-nchw.c", + "XNNPACK/src/operators/resize-bilinear-nhwc.c", + "XNNPACK/src/operators/softmax-nc.c", + "XNNPACK/src/operators/unary-elementwise-nc.c", + "XNNPACK/src/operators/unpooling-nhwc.c", +] + +SUBGRAPH_SRCS = [ + "XNNPACK/src/subgraph/abs.c", + "XNNPACK/src/subgraph/add2.c", + "XNNPACK/src/subgraph/argmax-pooling-2d.c", + "XNNPACK/src/subgraph/average-pooling-2d.c", + "XNNPACK/src/subgraph/bankers-rounding.c", + "XNNPACK/src/subgraph/ceiling.c", + "XNNPACK/src/subgraph/clamp.c", + "XNNPACK/src/subgraph/convert.c", + "XNNPACK/src/subgraph/convolution-2d.c", + "XNNPACK/src/subgraph/deconvolution-2d.c", + "XNNPACK/src/subgraph/depth-to-space.c", + "XNNPACK/src/subgraph/depthwise-convolution-2d.c", + "XNNPACK/src/subgraph/divide.c", + "XNNPACK/src/subgraph/elu.c", + "XNNPACK/src/subgraph/floor.c", + "XNNPACK/src/subgraph/fully-connected.c", + "XNNPACK/src/subgraph/global-average-pooling-2d.c", + "XNNPACK/src/subgraph/hardswish.c", + "XNNPACK/src/subgraph/leaky-relu.c", + "XNNPACK/src/subgraph/max-pooling-2d.c", + "XNNPACK/src/subgraph/maximum2.c", + "XNNPACK/src/subgraph/minimum2.c", + "XNNPACK/src/subgraph/multiply2.c", + "XNNPACK/src/subgraph/negate.c", + "XNNPACK/src/subgraph/prelu.c", + "XNNPACK/src/subgraph/sigmoid.c", + "XNNPACK/src/subgraph/softmax.c", + "XNNPACK/src/subgraph/square-root.c", + "XNNPACK/src/subgraph/square.c", + "XNNPACK/src/subgraph/squared-difference.c", + "XNNPACK/src/subgraph/static-constant-pad.c", + "XNNPACK/src/subgraph/static-reshape.c", + "XNNPACK/src/subgraph/static-resize-bilinear-2d.c", + "XNNPACK/src/subgraph/subtract.c", + "XNNPACK/src/subgraph/unpooling-2d.c", +] + +LOGGING_SRCS = [ + "XNNPACK/src/datatype-strings.c", + "XNNPACK/src/operator-strings.c", + "XNNPACK/src/subgraph-strings.c", +] + +HOT_SRCS = [ + "XNNPACK/src/indirection.c", + "XNNPACK/src/operator-run.c", + "XNNPACK/src/packing.c", +] + +TABLE_SRCS = [ + "XNNPACK/src/tables/exp2-k-over-64.c", + "XNNPACK/src/tables/exp2-k-over-2048.c", + "XNNPACK/src/tables/exp2minus-k-over-4.c", + "XNNPACK/src/tables/exp2minus-k-over-8.c", + "XNNPACK/src/tables/exp2minus-k-over-16.c", + "XNNPACK/src/tables/exp2minus-k-over-64.c", + "XNNPACK/src/tables/exp2minus-k-over-2048.c", +] + +JIT_SRCS = [ + "XNNPACK/src/jit/aarch32-assembler.cc", + "XNNPACK/src/jit/aarch64-assembler.cc", + "XNNPACK/src/jit/assembler.cc", + "XNNPACK/src/jit/memory.c", +] + +JIT_AARCH32_SRCS = [ + "XNNPACK/src/f32-gemm/4x8-aarch32-neon-cortex-a7.cc", + "XNNPACK/src/f32-gemm/4x8-aarch32-neon-cortex-a53.cc", + "XNNPACK/src/f32-gemm/4x8-aarch32-neon-cortex-a55.cc", + "XNNPACK/src/f32-gemm/4x8-aarch32-neon-cortex-a75.cc", + "XNNPACK/src/f32-gemm/4x8-aarch32-neon-ld64.cc", + "XNNPACK/src/f32-igemm/4x8-aarch32-neon-cortex-a7.cc", + "XNNPACK/src/f32-igemm/4x8-aarch32-neon-cortex-a53.cc", + "XNNPACK/src/f32-igemm/4x8-aarch32-neon-cortex-a55.cc", + "XNNPACK/src/f32-igemm/4x8-aarch32-neon-cortex-a75.cc", + "XNNPACK/src/f32-igemm/4x8-aarch32-neon-ld64.cc", + "XNNPACK/src/qc8-gemm/4x8-fp32-aarch32-neonv8-mlal-lane-ld64.cc", + "XNNPACK/src/qc8-gemm/4x8c4-fp32-aarch32-neondot-ld64.cc", + "XNNPACK/src/qc8-igemm/4x8-fp32-aarch32-neonv8-mlal-lane-ld64.cc", + "XNNPACK/src/qc8-igemm/4x8c4-fp32-aarch32-neondot-ld64.cc", + "XNNPACK/src/qs8-gemm/4x8-rndnu-aarch32-neon-mlal-lane-ld64.cc", + "XNNPACK/src/qs8-gemm/4x8c4-rndnu-aarch32-neondot-ld64.cc", + "XNNPACK/src/qs8-igemm/4x8-rndnu-aarch32-neon-mlal-lane-ld64.cc", + "XNNPACK/src/qs8-igemm/4x8c4-rndnu-aarch32-neondot-ld64.cc", +] + +JIT_AARCH64_SRCS = [ + "XNNPACK/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.cc", + "XNNPACK/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.cc", + "XNNPACK/src/f32-igemm/1x8-aarch64-neonfma-cortex-a75.cc", + "XNNPACK/src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.cc", +] + +PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS = [ + "XNNPACK/src/params-init.c", + "XNNPACK/src/u8-lut32norm/scalar.c", + "XNNPACK/src/xx-copy/memcpy.c", + "XNNPACK/src/x8-lut/gen/lut-scalar-x4.c", + "XNNPACK/src/x32-depthtospace2d-chw2hwc/scalar.c", +] + +PROD_SSE_MICROKERNEL_SRCS = [ + "XNNPACK/src/f32-avgpool/9p8x-minmax-sse-c4.c", + "XNNPACK/src/f32-avgpool/9x-minmax-sse-c4.c", + "XNNPACK/src/f32-conv-hwc2chw/3x3s2p1c3x4-sse-2x2.c", + "XNNPACK/src/f32-dwconv/gen/up8x3-minmax-sse.c", + "XNNPACK/src/f32-dwconv/gen/up8x4-minmax-sse.c", + "XNNPACK/src/f32-dwconv/gen/up8x9-minmax-sse.c", + "XNNPACK/src/f32-dwconv/gen/up8x25-minmax-sse.c", + "XNNPACK/src/f32-dwconv2d-chw/gen/3x3p1-minmax-sse-2x4-acc2.c", + "XNNPACK/src/f32-dwconv2d-chw/gen/3x3s2p1-minmax-sse-1x4-acc3.c", + "XNNPACK/src/f32-dwconv2d-chw/gen/5x5p2-minmax-sse-4x4.c", + "XNNPACK/src/f32-dwconv2d-chw/gen/5x5s2p2-minmax-sse-2x4.c", + "XNNPACK/src/f32-gavgpool-cw/sse-x4.c", + "XNNPACK/src/f32-gavgpool/7p7x-minmax-sse-c4.c", + "XNNPACK/src/f32-gavgpool/7x-minmax-sse-c4.c", + "XNNPACK/src/f32-gemm/gen/1x8-minmax-sse-load1.c", + "XNNPACK/src/f32-gemm/gen/4x2c4-minmax-sse.c", + "XNNPACK/src/f32-gemm/gen/4x8-minmax-sse-load1.c", + "XNNPACK/src/f32-ibilinear-chw/gen/sse-p8.c", + "XNNPACK/src/f32-ibilinear/gen/sse-c8.c", + "XNNPACK/src/f32-igemm/gen/1x8-minmax-sse-load1.c", + "XNNPACK/src/f32-igemm/gen/4x2c4-minmax-sse.c", + "XNNPACK/src/f32-igemm/gen/4x8-minmax-sse-load1.c", + "XNNPACK/src/f32-maxpool/9p8x-minmax-sse-c4.c", + "XNNPACK/src/f32-pavgpool/9p8x-minmax-sse-c4.c", + "XNNPACK/src/f32-pavgpool/9x-minmax-sse-c4.c", + "XNNPACK/src/f32-rmax/sse.c", + "XNNPACK/src/f32-spmm/gen/32x1-minmax-sse.c", + "XNNPACK/src/f32-vbinary/gen/vadd-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vaddc-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vdiv-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vdivc-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vmaxc-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vmin-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vminc-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vmul-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vmulc-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vrdivc-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vrsubc-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiff-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiffc-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vsub-minmax-sse-x8.c", + "XNNPACK/src/f32-vbinary/gen/vsubc-minmax-sse-x8.c", + "XNNPACK/src/f32-vclamp/gen/vclamp-sse-x8.c", + "XNNPACK/src/f32-vhswish/gen/vhswish-sse-x8.c", + "XNNPACK/src/f32-vlrelu/gen/vlrelu-sse-x8.c", + "XNNPACK/src/f32-vmulcaddc/gen/c4-minmax-sse-2x.c", + "XNNPACK/src/f32-vsqrt/gen/sse-sqrt-x4.c", + "XNNPACK/src/f32-vunary/gen/vabs-sse-x8.c", + "XNNPACK/src/f32-vunary/gen/vneg-sse-x8.c", + "XNNPACK/src/f32-vunary/gen/vsqr-sse-x8.c", + "XNNPACK/src/x32-packx/x4-sse.c", +] + +PROD_SSE2_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-f32-vcvt/gen/vcvt-sse2-int16-x32.c", + "XNNPACK/src/f32-argmaxpool/4x-sse2-c4.c", + "XNNPACK/src/f32-argmaxpool/9p8x-sse2-c4.c", + "XNNPACK/src/f32-argmaxpool/9x-sse2-c4.c", + "XNNPACK/src/f32-f16-vcvt/gen/vcvt-sse2-x16.c", + "XNNPACK/src/f32-prelu/gen/sse2-2x8.c", + "XNNPACK/src/f32-qs8-vcvt/gen/vcvt-sse2-x32.c", + "XNNPACK/src/f32-qu8-vcvt/gen/vcvt-sse2-x32.c", + "XNNPACK/src/f32-raddstoreexpminusmax/gen/sse2-rr2-p5-x20-acc2.c", + "XNNPACK/src/f32-velu/gen/velu-sse2-rr2-lut16-p3-x12.c", + "XNNPACK/src/f32-vlrelu/gen/vlrelu-sse2-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndd-sse2-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndne-sse2-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndu-sse2-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndz-sse2-x8.c", + "XNNPACK/src/f32-vsigmoid/gen/vsigmoid-sse2-rr2-lut64-p2-div-x8.c", + "XNNPACK/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", + "XNNPACK/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", + "XNNPACK/src/qc8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qc8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qc8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qc8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qs8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16-add16.c", + "XNNPACK/src/qs8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16-add16.c", + "XNNPACK/src/qs8-f32-vcvt/gen/vcvt-sse2-x32.c", + "XNNPACK/src/qs8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", + "XNNPACK/src/qs8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", + "XNNPACK/src/qs8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qs8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qs8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qs8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qs8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qs8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qs8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qu8-avgpool/9p8x-minmax-sse2-c8.c", + "XNNPACK/src/qu8-avgpool/9x-minmax-sse2-c8.c", + "XNNPACK/src/qu8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", + "XNNPACK/src/qu8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", + "XNNPACK/src/qu8-f32-vcvt/gen/vcvt-sse2-x32.c", + "XNNPACK/src/qu8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", + "XNNPACK/src/qu8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", + "XNNPACK/src/qu8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qu8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qu8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qu8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "XNNPACK/src/qu8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qu8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "XNNPACK/src/qu8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "XNNPACK/src/s8-ibilinear/gen/sse2-c8.c", + "XNNPACK/src/s8-maxpool/9p8x-minmax-sse2-c16.c", + "XNNPACK/src/s8-vclamp/sse2-x64.c", + "XNNPACK/src/u8-ibilinear/gen/sse2-c8.c", + "XNNPACK/src/u8-maxpool/9p8x-minmax-sse2-c16.c", + "XNNPACK/src/u8-rmax/sse2.c", + "XNNPACK/src/u8-vclamp/sse2-x64.c", + "XNNPACK/src/xx-fill/sse2-x64.c", + "XNNPACK/src/xx-pad/sse2.c", + "XNNPACK/src/x8-zip/xm-sse2.c", + "XNNPACK/src/x8-zip/x2-sse2.c", + "XNNPACK/src/x8-zip/x3-sse2.c", + "XNNPACK/src/x8-zip/x4-sse2.c", + "XNNPACK/src/x32-unpool/sse2.c", + "XNNPACK/src/x32-zip/xm-sse2.c", + "XNNPACK/src/x32-zip/x2-sse2.c", + "XNNPACK/src/x32-zip/x3-sse2.c", + "XNNPACK/src/x32-zip/x4-sse2.c", +] + +PROD_SSSE3_MICROKERNEL_SRCS = [ + "XNNPACK/src/f32-dwconv2d-chw/gen/3x3p1-minmax-ssse3-2x4-acc2.c", +] + +PROD_SSE41_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-f32-vcvt/gen/vcvt-sse41-int16-x16.c", + "XNNPACK/src/f32-f16-vcvt/gen/vcvt-sse41-x8.c", + "XNNPACK/src/f32-prelu/gen/sse41-2x8.c", + "XNNPACK/src/f32-qs8-vcvt/gen/vcvt-sse41-x32.c", + "XNNPACK/src/f32-vlrelu/gen/vlrelu-sse41-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndd-sse41-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndne-sse41-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndu-sse41-x8.c", + "XNNPACK/src/f32-vrnd/gen/vrndz-sse41-x8.c", + "XNNPACK/src/f32-vsigmoid/gen/vsigmoid-sse41-rr2-lut64-p2-div-x8.c", + "XNNPACK/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", + "XNNPACK/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", + "XNNPACK/src/qc8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qc8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qc8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qc8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qs8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16-add16.c", + "XNNPACK/src/qs8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16-add16.c", + "XNNPACK/src/qs8-f32-vcvt/gen/vcvt-sse41-x16.c", + "XNNPACK/src/qs8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", + "XNNPACK/src/qs8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", + "XNNPACK/src/qs8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qs8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qs8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qs8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qs8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", + "XNNPACK/src/qs8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "XNNPACK/src/qs8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "XNNPACK/src/qu8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", + "XNNPACK/src/qu8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", + "XNNPACK/src/qu8-f32-vcvt/gen/vcvt-sse41-x16.c", + "XNNPACK/src/qu8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", + "XNNPACK/src/qu8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", + "XNNPACK/src/qu8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qu8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qu8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qu8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "XNNPACK/src/qu8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", + "XNNPACK/src/qu8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "XNNPACK/src/qu8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "XNNPACK/src/s8-ibilinear/gen/sse41-c16.c", + "XNNPACK/src/s8-maxpool/9p8x-minmax-sse41-c16.c", + "XNNPACK/src/s8-vclamp/sse41-x64.c", + "XNNPACK/src/u8-ibilinear/gen/sse41-c16.c", +] + +PROD_AVX_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-f32-vcvt/gen/vcvt-avx-int16-x16.c", + "XNNPACK/src/f32-dwconv/gen/up8x25-minmax-avx.c", + "XNNPACK/src/f32-dwconv/gen/up16x3-minmax-avx.c", + "XNNPACK/src/f32-dwconv/gen/up16x4-minmax-avx.c", + "XNNPACK/src/f32-dwconv/gen/up16x9-minmax-avx.c", + "XNNPACK/src/f32-f16-vcvt/gen/vcvt-avx-x24.c", + "XNNPACK/src/f32-gemm/gen/1x16-minmax-avx-broadcast.c", + "XNNPACK/src/f32-gemm/gen/5x16-minmax-avx-broadcast.c", + "XNNPACK/src/f32-igemm/gen/1x16-minmax-avx-broadcast.c", + "XNNPACK/src/f32-igemm/gen/5x16-minmax-avx-broadcast.c", + "XNNPACK/src/f32-prelu/gen/avx-2x16.c", + "XNNPACK/src/f32-qs8-vcvt/gen/vcvt-avx-x32.c", + "XNNPACK/src/f32-qu8-vcvt/gen/vcvt-avx-x32.c", + "XNNPACK/src/f32-vbinary/gen/vadd-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vaddc-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vdiv-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vdivc-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vmaxc-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vmin-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vminc-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vmul-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vmulc-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vrdivc-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vrsubc-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiff-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiffc-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vsub-minmax-avx-x16.c", + "XNNPACK/src/f32-vbinary/gen/vsubc-minmax-avx-x16.c", + "XNNPACK/src/f32-vclamp/gen/vclamp-avx-x16.c", + "XNNPACK/src/f32-velu/gen/velu-avx-rr2-lut4-p4-perm-x32.c", + "XNNPACK/src/f32-vhswish/gen/vhswish-avx-x16.c", + "XNNPACK/src/f32-vlrelu/gen/vlrelu-avx-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndd-avx-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndne-avx-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndu-avx-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndz-avx-x16.c", + "XNNPACK/src/f32-vsigmoid/gen/vsigmoid-avx-rr2-p5-nr2-x40.c", + "XNNPACK/src/f32-vsqrt/gen/avx-sqrt-x8.c", + "XNNPACK/src/f32-vunary/gen/vabs-avx-x16.c", + "XNNPACK/src/f32-vunary/gen/vneg-avx-x16.c", + "XNNPACK/src/f32-vunary/gen/vsqr-avx-x16.c", + "XNNPACK/src/qc8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", + "XNNPACK/src/qc8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", + "XNNPACK/src/qc8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qc8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qc8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qc8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qs8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", + "XNNPACK/src/qs8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", + "XNNPACK/src/qs8-f32-vcvt/gen/vcvt-avx-x32.c", + "XNNPACK/src/qs8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qs8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qs8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qs8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qs8-vadd/gen/minmax-avx-mul32-ld32-x8.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", + "XNNPACK/src/qs8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "XNNPACK/src/qs8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "XNNPACK/src/qu8-dwconv/gen/up16x9-minmax-fp32-avx-mul16.c", + "XNNPACK/src/qu8-dwconv/gen/up16x25-minmax-fp32-avx-mul16.c", + "XNNPACK/src/qu8-f32-vcvt/gen/vcvt-avx-x32.c", + "XNNPACK/src/qu8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qu8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qu8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qu8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "XNNPACK/src/qu8-vadd/gen/minmax-avx-mul32-ld32-x8.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", + "XNNPACK/src/qu8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "XNNPACK/src/qu8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "XNNPACK/src/x8-lut/gen/lut-avx-x64.c", +] + +PROD_F16C_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-f32-vcvt/gen/vcvt-f16c-x16.c", + "XNNPACK/src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c", + "XNNPACK/src/f16-gavgpool/gen/7x-minmax-f16c-c8.c", + "XNNPACK/src/f16-maxpool/9p8x-minmax-f16c-c8.c", + "XNNPACK/src/f16-prelu/gen/f16c-2x16.c", + "XNNPACK/src/f16-vbinary/gen/vadd-minmax-f16c-x16.c", + "XNNPACK/src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c", + "XNNPACK/src/f16-vbinary/gen/vmul-minmax-f16c-x16.c", + "XNNPACK/src/f16-vbinary/gen/vmulc-minmax-f16c-x16.c", + "XNNPACK/src/f16-vclamp/gen/vclamp-f16c-x16.c", + "XNNPACK/src/f16-vhswish/gen/vhswish-f16c-x16.c", + "XNNPACK/src/f16-vlrelu/gen/vlrelu-f16c-x16.c", + "XNNPACK/src/f32-f16-vcvt/gen/vcvt-f16c-x16.c", +] + +PROD_XOP_MICROKERNEL_SRCS = [ + "XNNPACK/src/qc8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", + "XNNPACK/src/qc8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", + "XNNPACK/src/qc8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qc8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qc8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qc8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qs8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", + "XNNPACK/src/qs8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", + "XNNPACK/src/qs8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qs8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qs8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qs8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qs8-vadd/gen/minmax-xop-mul32-ld32-x8.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-xop-mul32-ld32-x8.c", + "XNNPACK/src/qu8-dwconv/gen/up16x9-minmax-fp32-xop-mul32.c", + "XNNPACK/src/qu8-dwconv/gen/up16x25-minmax-fp32-xop-mul32.c", + "XNNPACK/src/qu8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qu8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qu8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qu8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "XNNPACK/src/qu8-vadd/gen/minmax-xop-mul32-ld32-x8.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-xop-mul32-ld32-x8.c", +] + +PROD_FMA3_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-dwconv/gen/up8x25-minmax-fma3-acc2.c", + "XNNPACK/src/f16-dwconv/gen/up16x3-minmax-fma3.c", + "XNNPACK/src/f16-dwconv/gen/up16x4-minmax-fma3.c", + "XNNPACK/src/f16-dwconv/gen/up16x9-minmax-fma3.c", + "XNNPACK/src/f16-ibilinear/gen/fma3-c8.c", + "XNNPACK/src/f16-vmulcaddc/gen/c8-minmax-fma3-2x.c", + "XNNPACK/src/f32-dwconv/gen/up8x25-minmax-fma3.c", + "XNNPACK/src/f32-dwconv/gen/up16x3-minmax-fma3.c", + "XNNPACK/src/f32-dwconv/gen/up16x4-minmax-fma3.c", + "XNNPACK/src/f32-dwconv/gen/up16x9-minmax-fma3.c", + "XNNPACK/src/f32-gemm/gen/1x16-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-gemm/gen/1x16s4-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-gemm/gen/4x16s4-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-gemm/gen/5x16-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-igemm/gen/1x16-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-igemm/gen/1x16s4-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-igemm/gen/4x16s4-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-igemm/gen/5x16-minmax-fma3-broadcast.c", + "XNNPACK/src/f32-vhswish/gen/vhswish-fma3-x16.c", +] + +PROD_AVX2_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-gemm/gen/1x16-minmax-avx2-broadcast.c", + "XNNPACK/src/f16-gemm/gen/4x16-minmax-avx2-broadcast.c", + "XNNPACK/src/f16-igemm/gen/1x16-minmax-avx2-broadcast.c", + "XNNPACK/src/f16-igemm/gen/4x16-minmax-avx2-broadcast.c", + "XNNPACK/src/f32-qs8-vcvt/gen/vcvt-avx2-x64.c", + "XNNPACK/src/f32-qu8-vcvt/gen/vcvt-avx2-x64.c", + "XNNPACK/src/f32-velu/gen/velu-avx2-rr1-lut4-p4-perm-x56.c", + "XNNPACK/src/f32-vsigmoid/gen/vsigmoid-avx2-rr1-p5-div-x40.c", + "XNNPACK/src/qc8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qc8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qc8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qc8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qc8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qc8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qs8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qs8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qs8-f32-vcvt/gen/vcvt-avx2-x16.c", + "XNNPACK/src/qs8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qs8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qs8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qs8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qs8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", + "XNNPACK/src/qu8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qu8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "XNNPACK/src/qu8-f32-vcvt/gen/vcvt-avx2-x16.c", + "XNNPACK/src/qu8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qu8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qu8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qu8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "XNNPACK/src/qu8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", + "XNNPACK/src/x8-lut/gen/lut-avx2-x128.c", +] + +PROD_AVX512F_MICROKERNEL_SRCS = [ + "XNNPACK/src/f32-dwconv/gen/up16x3-minmax-avx512f.c", + "XNNPACK/src/f32-dwconv/gen/up16x4-minmax-avx512f.c", + "XNNPACK/src/f32-dwconv/gen/up16x9-minmax-avx512f.c", + "XNNPACK/src/f32-dwconv/gen/up16x25-minmax-avx512f.c", + "XNNPACK/src/f32-gemm/gen/1x16-minmax-avx512f-broadcast.c", + "XNNPACK/src/f32-gemm/gen/7x16-minmax-avx512f-broadcast.c", + "XNNPACK/src/f32-igemm/gen/1x16-minmax-avx512f-broadcast.c", + "XNNPACK/src/f32-igemm/gen/7x16-minmax-avx512f-broadcast.c", + "XNNPACK/src/f32-prelu/gen/avx512f-2x16.c", + "XNNPACK/src/f32-vbinary/gen/vadd-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vaddc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vdiv-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vdivc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vmaxc-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vmin-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vminc-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vmul-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vmulc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vrdivc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vrsubc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiff-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vsqrdiffc-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vsub-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vbinary/gen/vsubc-minmax-avx512f-x32.c", + "XNNPACK/src/f32-vclamp/gen/vclamp-avx512f-x16.c", + "XNNPACK/src/f32-velu/gen/velu-avx512f-rr1-lut16-p3-perm-x64.c", + "XNNPACK/src/f32-vhswish/gen/vhswish-avx512f-x16.c", + "XNNPACK/src/f32-vlrelu/gen/vlrelu-avx512f-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndd-avx512f-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndne-avx512f-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndu-avx512f-x16.c", + "XNNPACK/src/f32-vrnd/gen/vrndz-avx512f-x16.c", + "XNNPACK/src/f32-vsigmoid/gen/vsigmoid-avx512f-rr2-lut32-p2-perm2-scalef-div-x64.c", + "XNNPACK/src/f32-vunary/gen/vabs-avx512f-x16.c", + "XNNPACK/src/f32-vunary/gen/vneg-avx512f-x16.c", + "XNNPACK/src/f32-vunary/gen/vsqr-avx512f-x16.c", +] + +PROD_AVX512SKX_MICROKERNEL_SRCS = [ + "XNNPACK/src/f16-f32-vcvt/gen/vcvt-avx512skx-x16.c", + "XNNPACK/src/f32-f16-vcvt/gen/vcvt-avx512skx-x16.c", + "XNNPACK/src/f32-qs8-vcvt/gen/vcvt-avx512skx-x128.c", + "XNNPACK/src/f32-qu8-vcvt/gen/vcvt-avx512skx-x128.c", + "XNNPACK/src/qc8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qc8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qs8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qs8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qs8-f32-vcvt/gen/vcvt-avx512skx-x32.c", + "XNNPACK/src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qs8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", + "XNNPACK/src/qs8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", + "XNNPACK/src/qu8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qu8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "XNNPACK/src/qu8-f32-vcvt/gen/vcvt-avx512skx-x32.c", + "XNNPACK/src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "XNNPACK/src/qu8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", + "XNNPACK/src/qu8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", + "XNNPACK/src/x8-lut/gen/lut-avx512skx-vpshufb-x64.c", +] diff --git a/third_party/xnnpack_wrapper_defs.bzl b/third_party/xnnpack_wrapper_defs.bzl new file mode 100644 index 0000000000000..26556a7fbfa25 --- /dev/null +++ b/third_party/xnnpack_wrapper_defs.bzl @@ -0,0 +1,1131 @@ +""" +Auto-generated by generate-wrappers.py script. Do not modify +""" + +AARCH32_ASM_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a7.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-ld64.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/4x4-aarch32-vfp-ld64.S", + "xnnpack_wrappers/f32-gemm/4x4-minmax-aarch32-vfp-ld64.S", + "xnnpack_wrappers/f32-gemm/4x8-minmax-aarch32-neon-cortex-a55.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a7.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-ld64.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/4x8-minmax-aarch32-neon-cortex-a55.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a7.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a7.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-aarch32-neondot-cortex-a55.S", + "xnnpack_wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-aarch32-neondot-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-aarch32-neondot-cortex-a55.S", + "xnnpack_wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-aarch32-neondot-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a7.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a7.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-aarch32-neondot-ld64.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a7.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a7.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S", +] + +PROD_NEONDOT_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/qc8-gemm/gen/1x8c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-gemm/gen/4x8c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-igemm/gen/1x16c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-igemm/gen/4x8c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-neondot.c", + "xnnpack_wrappers/qs8-gemm/gen/1x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-gemm/gen/4x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-igemm/gen/1x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-igemm/gen/1x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-igemm/gen/4x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-gemm/gen/1x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-gemm/gen/1x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-igemm/gen/1x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-igemm/gen/1x16c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-neondot.c", + "xnnpack_wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-neondot.c", +] + +PROD_NEONFMA_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-dwconv/gen/up8x3-minmax-neonfma.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x4-minmax-neonfma.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x9-minmax-neonfma.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x25-minmax-neonfma-acc2.c", + "xnnpack_wrappers/f32-gemm/gen/1x8s4-minmax-neonfma.c", + "xnnpack_wrappers/f32-gemm/gen/6x8s4-minmax-neonfma.c", + "xnnpack_wrappers/f32-ibilinear-chw/gen/neonfma-p8.c", + "xnnpack_wrappers/f32-ibilinear/gen/neonfma-c8.c", + "xnnpack_wrappers/f32-igemm/gen/1x8s4-minmax-neonfma.c", + "xnnpack_wrappers/f32-igemm/gen/6x8s4-minmax-neonfma.c", + "xnnpack_wrappers/f32-raddstoreexpminusmax/gen/neonfma-rr1-lut64-p2-x16.c", + "xnnpack_wrappers/f32-spmm/gen/32x1-minmax-neonfma-pipelined.c", + "xnnpack_wrappers/f32-velu/gen/velu-neonfma-rr1-lut16-p3-x16.c", + "xnnpack_wrappers/f32-velu/gen/velu-neonfma-rr1-p6-x8.c", + "xnnpack_wrappers/f32-vmulcaddc/gen/c4-minmax-neonfma-2x.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-neonfma-rr1-lut64-p2-nr2recps-x16.c", +] + +PROD_SSSE3_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-ssse3-2x4-acc2.c", +] + +PROD_SCALAR_AARCH32_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-scalar-x4.c", + "xnnpack_wrappers/f32-argmaxpool/4x-scalar-c1.c", + "xnnpack_wrappers/f32-argmaxpool/9p8x-scalar-c1.c", + "xnnpack_wrappers/f32-argmaxpool/9x-scalar-c1.c", + "xnnpack_wrappers/f32-avgpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-avgpool/9x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-conv-hwc/3x3s2p0p1c3x4-scalar-1x1.c", + "xnnpack_wrappers/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c", + "xnnpack_wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x3-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x4-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x9-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x25-minmax-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv/gen/up1x25-scalar-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-scalar-4x1.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-scalar-2x1-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-scalar-2x1-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-scalar-2x1-acc2.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-scalar-fabsf-x2.c", + "xnnpack_wrappers/f32-gavgpool-cw/scalar-x1.c", + "xnnpack_wrappers/f32-gavgpool/7p7x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-gavgpool/7x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-gemm/gen/1x4-minmax-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/1x4-relu-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/1x4-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/4x2-minmax-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/4x2-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/4x4-minmax-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/4x4-relu-scalar.c", + "xnnpack_wrappers/f32-gemm/gen/4x4-scalar.c", + "xnnpack_wrappers/f32-ibilinear-chw/gen/scalar-p4.c", + "xnnpack_wrappers/f32-ibilinear/gen/scalar-c2.c", + "xnnpack_wrappers/f32-igemm/gen/1x4-minmax-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/1x4-relu-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/1x4-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/4x2-minmax-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/4x2-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/4x4-minmax-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/4x4-relu-scalar.c", + "xnnpack_wrappers/f32-igemm/gen/4x4-scalar.c", + "xnnpack_wrappers/f32-maxpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-pavgpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-pavgpool/9x-minmax-scalar-c1.c", + "xnnpack_wrappers/f32-prelu/gen/scalar-2x4.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-scalar-imagic-x4.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-scalar-imagic-x4.c", + "xnnpack_wrappers/f32-raddstoreexpminusmax/gen/scalar-rr2-p5-x4-acc2.c", + "xnnpack_wrappers/f32-rmax/scalar.c", + "xnnpack_wrappers/f32-spmm/gen/8x1-minmax-scalar.c", + "xnnpack_wrappers/f32-spmm/gen/8x2-minmax-scalar.c", + "xnnpack_wrappers/f32-spmm/gen/8x4-minmax-scalar.c", + "xnnpack_wrappers/f32-vbinary/gen/vadd-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vaddc-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vdiv-minmax-scalar-x2.c", + "xnnpack_wrappers/f32-vbinary/gen/vdivc-minmax-scalar-x2.c", + "xnnpack_wrappers/f32-vbinary/gen/vmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmaxc-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmin-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vminc-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmul-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmulc-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vrdivc-minmax-scalar-x2.c", + "xnnpack_wrappers/f32-vbinary/gen/vrsubc-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiff-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiffc-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsub-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsubc-minmax-scalar-x8.c", + "xnnpack_wrappers/f32-vclamp/gen/vclamp-scalar-x4.c", + "xnnpack_wrappers/f32-velu/gen/velu-scalar-rr2-lut16-p3-x4.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-scalar-x4.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-scalar-x4.c", + "xnnpack_wrappers/f32-vmulcaddc/gen/c1-minmax-scalar-2x.c", + "xnnpack_wrappers/f32-vrelu/gen/vrelu-scalar-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-scalar-libm-x1.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-scalar-libm-x1.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-scalar-libm-x1.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-scalar-libm-x1.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-scalar-rr2-lut64-p2-div-x2.c", + "xnnpack_wrappers/f32-vsqrt/gen/scalar-sqrt-x1.c", + "xnnpack_wrappers/f32-vunary/gen/vabs-scalar-x4.c", + "xnnpack_wrappers/f32-vunary/gen/vneg-scalar-x4.c", + "xnnpack_wrappers/f32-vunary/gen/vsqr-scalar-x4.c", + "xnnpack_wrappers/qc8-dwconv/gen/up2x9-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qc8-dwconv/gen/up2x25-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qc8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8-minmax-fp32-neon-mlal-lane.c", + "xnnpack_wrappers/qc8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qc8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8-minmax-fp32-neon-mlal-lane.c", + "xnnpack_wrappers/qc8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-dwconv/gen/up1x9-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-dwconv/gen/up1x25-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-scalar-x4.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-scalar-imagic-c1.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7x-minmax-fp32-scalar-imagic-c1.c", + "xnnpack_wrappers/qs8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-scalar-x1.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-scalar-x1.c", + "xnnpack_wrappers/qs8-vmul/gen/minmax-fp32-scalar-x4.c", + "xnnpack_wrappers/qs8-vmulc/gen/minmax-fp32-scalar-x4.c", + "xnnpack_wrappers/qu8-avgpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/qu8-avgpool/9x-minmax-scalar-c1.c", + "xnnpack_wrappers/qu8-dwconv/gen/up1x9-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-dwconv/gen/up1x25-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-scalar-x4.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-scalar-imagic-c1.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7x-minmax-fp32-scalar-imagic-c1.c", + "xnnpack_wrappers/qu8-gemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-gemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-igemm/gen/1x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-igemm/gen/2x2-minmax-fp32-scalar-fmagic.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-scalar-x1.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-scalar-x1.c", + "xnnpack_wrappers/qu8-vmul/gen/minmax-fp32-scalar-x4.c", + "xnnpack_wrappers/qu8-vmulc/gen/minmax-fp32-scalar-x4.c", + "xnnpack_wrappers/s8-ibilinear/gen/scalar-c1.c", + "xnnpack_wrappers/s8-maxpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/s8-vclamp/scalar-x4.c", + "xnnpack_wrappers/u8-ibilinear/gen/scalar-c1.c", + "xnnpack_wrappers/u8-maxpool/9p8x-minmax-scalar-c1.c", + "xnnpack_wrappers/u8-rmax/scalar.c", + "xnnpack_wrappers/u8-vclamp/scalar-x4.c", + "xnnpack_wrappers/xx-fill/scalar-x16.c", + "xnnpack_wrappers/xx-pad/scalar.c", + "xnnpack_wrappers/x8-zip/xm-scalar.c", + "xnnpack_wrappers/x8-zip/x2-scalar.c", + "xnnpack_wrappers/x8-zip/x3-scalar.c", + "xnnpack_wrappers/x8-zip/x4-scalar.c", + "xnnpack_wrappers/x32-packx/x2-scalar.c", + "xnnpack_wrappers/x32-packx/x3-scalar.c", + "xnnpack_wrappers/x32-packx/x4-scalar.c", + "xnnpack_wrappers/x32-unpool/scalar.c", + "xnnpack_wrappers/x32-zip/xm-scalar.c", + "xnnpack_wrappers/x32-zip/x2-scalar.c", + "xnnpack_wrappers/x32-zip/x3-scalar.c", + "xnnpack_wrappers/x32-zip/x4-scalar.c", +] + +PROD_XOP_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", + "xnnpack_wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qc8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-xop-mul16-add16.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-xop-mul16-add16.c", + "xnnpack_wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qs8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-xop-mul32-ld32-x8.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-xop-mul32-ld32-x8.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-xop-mul32.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-xop-mul32.c", + "xnnpack_wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qu8-gemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/2x4c8-minmax-fp32-xop-ld64.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-xop-mul32-ld32-x8.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-xop-mul32-ld32-x8.c", +] + +PROD_FMA3_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-dwconv/gen/up8x25-minmax-fma3-acc2.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x3-minmax-fma3.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x4-minmax-fma3.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x9-minmax-fma3.c", + "xnnpack_wrappers/f16-ibilinear/gen/fma3-c8.c", + "xnnpack_wrappers/f16-vmulcaddc/gen/c8-minmax-fma3-2x.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x25-minmax-fma3.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x3-minmax-fma3.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x4-minmax-fma3.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x9-minmax-fma3.c", + "xnnpack_wrappers/f32-gemm/gen/1x16-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-gemm/gen/1x16s4-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-gemm/gen/4x16s4-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-gemm/gen/5x16-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/1x16-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/1x16s4-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/4x16s4-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/5x16-minmax-fma3-broadcast.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-fma3-x16.c", +] + +PROD_AARCH64_NEON_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-neonfma-2x2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-neonfma-3x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-neonfma-2x4-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-neonfma-4x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-neonfma-1x4-acc2.c", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-gemm/gen/4x2-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/1x8-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/4x2-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-neonfma-lane-ld64.c", + "xnnpack_wrappers/f32-spmm/gen/32x2-minmax-neonfma.c", + "xnnpack_wrappers/f32-spmm/gen/32x4-minmax-neonfma.c", + "xnnpack_wrappers/f32-vbinary/gen/vdiv-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vdivc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vrdivc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vsqrt/gen/neon-sqrt-x4.c", + "xnnpack_wrappers/x8-lut/gen/lut-neon-tbx128x4-x64.c", +] + +PROD_NEONFP16_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-neonfp16-x16.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-neonfp16-x16.c", +] + +PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/params-init.c", + "xnnpack_wrappers/u8-lut32norm/scalar.c", + "xnnpack_wrappers/xx-copy/memcpy.c", + "xnnpack_wrappers/x8-lut/gen/lut-scalar-x4.c", + "xnnpack_wrappers/x32-depthtospace2d-chw2hwc/scalar.c", +] + +PROD_AVX_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-avx-int16-x16.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x25-minmax-avx.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x3-minmax-avx.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x4-minmax-avx.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x9-minmax-avx.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-avx-x24.c", + "xnnpack_wrappers/f32-gemm/gen/1x16-minmax-avx-broadcast.c", + "xnnpack_wrappers/f32-gemm/gen/5x16-minmax-avx-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/1x16-minmax-avx-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/5x16-minmax-avx-broadcast.c", + "xnnpack_wrappers/f32-prelu/gen/avx-2x16.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-avx-x32.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-avx-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vadd-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vaddc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vdiv-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vdivc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vmaxc-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vmin-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vminc-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vmul-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vmulc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vrdivc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vrsubc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiff-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiffc-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vsub-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vsubc-minmax-avx-x16.c", + "xnnpack_wrappers/f32-vclamp/gen/vclamp-avx-x16.c", + "xnnpack_wrappers/f32-velu/gen/velu-avx-rr2-lut4-p4-perm-x32.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-avx-x16.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-avx-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-avx-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-avx-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-avx-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-avx-x16.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-avx-rr2-p5-nr2-x40.c", + "xnnpack_wrappers/f32-vsqrt/gen/avx-sqrt-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vabs-avx-x16.c", + "xnnpack_wrappers/f32-vunary/gen/vneg-avx-x16.c", + "xnnpack_wrappers/f32-vunary/gen/vsqr-avx-x16.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", + "xnnpack_wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qc8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qc8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-avx-mul16-add16.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-avx-mul16-add16.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-avx-x32.c", + "xnnpack_wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qs8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qs8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-avx-mul32-ld32-x8.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", + "xnnpack_wrappers/qs8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "xnnpack_wrappers/qs8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-avx-mul16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-avx-mul16.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-avx-x32.c", + "xnnpack_wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qu8-gemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qu8-igemm/gen/2x4c8-minmax-fp32-avx-ld128.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-avx-mul32-ld32-x8.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-avx-mul32-ld32-x8.c", + "xnnpack_wrappers/qu8-vmul/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "xnnpack_wrappers/qu8-vmulc/gen/minmax-fp32-avx-mul16-ld64-x16.c", + "xnnpack_wrappers/x8-lut/gen/lut-avx-x64.c", +] + +PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-dwconv/gen/up8x25-minmax-neonfp16arith-acc2.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x3-minmax-neonfp16arith.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x4-minmax-neonfp16arith.c", + "xnnpack_wrappers/f16-dwconv/gen/up16x9-minmax-neonfp16arith.c", + "xnnpack_wrappers/f16-gavgpool/gen/7p7x-minmax-neonfp16arith-c8.c", + "xnnpack_wrappers/f16-gavgpool/gen/7x-minmax-neonfp16arith-c8.c", + "xnnpack_wrappers/f16-gemm/gen/1x16-minmax-neonfp16arith-ld64.c", + "xnnpack_wrappers/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c", + "xnnpack_wrappers/f16-ibilinear/gen/neonfp16arith-c8.c", + "xnnpack_wrappers/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c", + "xnnpack_wrappers/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c", + "xnnpack_wrappers/f16-maxpool/9p8x-minmax-neonfp16arith-c8.c", + "xnnpack_wrappers/f16-prelu/gen/neonfp16arith-2x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vmulc-minmax-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vclamp/gen/vclamp-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vhswish/gen/vhswish-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vlrelu/gen/vlrelu-neonfp16arith-x16.c", + "xnnpack_wrappers/f16-vmulcaddc/gen/c8-minmax-neonfp16arith-2x.c", +] + +PROD_F16C_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-f16c-x16.c", + "xnnpack_wrappers/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c", + "xnnpack_wrappers/f16-gavgpool/gen/7x-minmax-f16c-c8.c", + "xnnpack_wrappers/f16-maxpool/9p8x-minmax-f16c-c8.c", + "xnnpack_wrappers/f16-prelu/gen/f16c-2x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vadd-minmax-f16c-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vaddc-minmax-f16c-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vmul-minmax-f16c-x16.c", + "xnnpack_wrappers/f16-vbinary/gen/vmulc-minmax-f16c-x16.c", + "xnnpack_wrappers/f16-vclamp/gen/vclamp-f16c-x16.c", + "xnnpack_wrappers/f16-vhswish/gen/vhswish-f16c-x16.c", + "xnnpack_wrappers/f16-vlrelu/gen/vlrelu-f16c-x16.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-f16c-x16.c", +] + +PROD_NEONV8_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-neonv8-x32.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-neonv8-x32.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-neonv8-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-neonv8-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-neonv8-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-neonv8-x8.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-neonv8-mla8-ld64.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-neonv8-mla8-ld64.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-neonv8-mla8-ld64.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8-minmax-fp32-neonv8-mlal-lane-prfm.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8-minmax-fp32-neonv8-mlal-lane.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8c2s4-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-gemm/gen/1x16-minmax-fp32-neonv8-mlal-lane.c", + "xnnpack_wrappers/qc8-gemm/gen/2x8c2s4-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-gemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8-minmax-fp32-neonv8-mlal-lane-prfm.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8-minmax-fp32-neonv8-mlal-lane.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8c2s4-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/1x16-minmax-fp32-neonv8-mlal-lane.c", + "xnnpack_wrappers/qc8-igemm/gen/2x8c2s4-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-neonv8-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/4x16-minmax-fp32-neonv8-mlal-lane.c", +] + +PROD_AVX512SKX_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-avx512skx-x16.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-avx512skx-x16.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-avx512skx-x128.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-avx512skx-x128.c", + "xnnpack_wrappers/qc8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qc8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qs8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qs8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-avx512skx-x32.c", + "xnnpack_wrappers/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up32x9-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qu8-dwconv/gen/up32x25-minmax-fp32-avx512skx-mul32.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-avx512skx-x32.c", + "xnnpack_wrappers/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-avx512skx-mul32-ld128-x16.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-avx512skx-mul32-ld128-x16.c", + "xnnpack_wrappers/x8-lut/gen/lut-avx512skx-vpshufb-x64.c", +] + +PROD_NEON_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-neon-int16-x16.c", + "xnnpack_wrappers/f32-argmaxpool/4x-neon-c4.c", + "xnnpack_wrappers/f32-argmaxpool/9p8x-neon-c4.c", + "xnnpack_wrappers/f32-argmaxpool/9x-neon-c4.c", + "xnnpack_wrappers/f32-avgpool/9p8x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-avgpool/9x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-neon-2x2.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x3-minmax-neon.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x4-minmax-neon.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x9-minmax-neon.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x25-minmax-neon-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-neon-2x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-neon-1x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-neon-1x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-neon-1x4.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-neon-x8.c", + "xnnpack_wrappers/f32-gavgpool-cw/neon-x4.c", + "xnnpack_wrappers/f32-gavgpool/7p7x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-gavgpool/7x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-gemm/gen/4x2-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-neon-lane-ld128.c", + "xnnpack_wrappers/f32-ibilinear-chw/gen/neon-p8.c", + "xnnpack_wrappers/f32-ibilinear/gen/neon-c8.c", + "xnnpack_wrappers/f32-igemm/gen/1x8-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/4x2-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-neon-lane-ld64.c", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-neon-lane-ld128.c", + "xnnpack_wrappers/f32-maxpool/9p8x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-pavgpool/9p8x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-pavgpool/9x-minmax-neon-c4.c", + "xnnpack_wrappers/f32-prelu/gen/neon-2x8.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-neon-x32.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-neon-x32.c", + "xnnpack_wrappers/f32-raddstoreexpminusmax/gen/neon-rr2-lut64-p2-x8.c", + "xnnpack_wrappers/f32-rmax/neon.c", + "xnnpack_wrappers/f32-spmm/gen/32x1-minmax-neon.c", + "xnnpack_wrappers/f32-vbinary/gen/vadd-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vaddc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmaxc-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmin-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vminc-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmul-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmulc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vrsubc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiff-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiffc-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsub-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsubc-minmax-neon-x8.c", + "xnnpack_wrappers/f32-vclamp/gen/vclamp-neon-x8.c", + "xnnpack_wrappers/f32-velu/gen/velu-neon-rr2-lut16-p3-x8.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-neon-x16.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-neon-x8.c", + "xnnpack_wrappers/f32-vmulcaddc/gen/c4-minmax-neon-2x.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-neon-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-neon-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-neon-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-neon-x8.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-neon-rr2-lut64-p2-nr2recps-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vabs-neon-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vneg-neon-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vsqr-neon-x8.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-neon-mla8-ld64.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-neon-mla8-ld64.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-neon-mla8-ld64.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8c2s4-minmax-fp32-neon-mlal.c", + "xnnpack_wrappers/qc8-gemm/gen/2x8c2s4-minmax-fp32-neon-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8c2s4-minmax-fp32-neon-mlal.c", + "xnnpack_wrappers/qc8-igemm/gen/2x8c2s4-minmax-fp32-neon-mlal.c", + "xnnpack_wrappers/qs8-dwconv/gen/up8x25-minmax-rndnu-neon-mla8-ld64.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x9-minmax-rndnu-neon-mla8-ld64.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-neon-x32.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7p7x-minmax-rndnu-neon-c8.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7x-minmax-rndnu-neon-c8.c", + "xnnpack_wrappers/qs8-gemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qs8-gemm/gen/1x8c2s4-minmax-rndnu-neon-mlal.c", + "xnnpack_wrappers/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qs8-gemm/gen/2x8c2s4-minmax-rndnu-neon-mlal.c", + "xnnpack_wrappers/qs8-igemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qs8-igemm/gen/1x8c2s4-minmax-rndnu-neon-mlal.c", + "xnnpack_wrappers/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qs8-igemm/gen/2x8c2s4-minmax-rndnu-neon-mlal.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-neon-ld64-x16.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-neon-ld64-x32.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-neon-ld64-x16.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-neon-ld64-x32.c", + "xnnpack_wrappers/qs8-vmul/gen/minmax-rndnu-neon-ld64-x16.c", + "xnnpack_wrappers/qs8-vmulc/gen/minmax-rndnu-neon-ld64-x16.c", + "xnnpack_wrappers/qu8-avgpool/9p8x-minmax-neon-c8.c", + "xnnpack_wrappers/qu8-avgpool/9x-minmax-neon-c8.c", + "xnnpack_wrappers/qu8-dwconv/gen/up8x25-minmax-rndnu-neon-mul8.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x9-minmax-rndnu-neon-mul8.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-neon-x32.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7p7x-minmax-rndnu-neon-c8.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7x-minmax-rndnu-neon-c8.c", + "xnnpack_wrappers/qu8-gemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-gemm/gen/3x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-igemm/gen/1x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-igemm/gen/3x8-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-neon-mlal-lane.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-neon-ld64-x16.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-neon-ld64-x32.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-neon-ld64-x16.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-neon-ld64-x32.c", + "xnnpack_wrappers/qu8-vmul/gen/minmax-rndnu-neon-ld64-x16.c", + "xnnpack_wrappers/qu8-vmulc/gen/minmax-rndnu-neon-ld64-x16.c", + "xnnpack_wrappers/s8-ibilinear/gen/neon-c8.c", + "xnnpack_wrappers/s8-ibilinear/gen/neon-c16.c", + "xnnpack_wrappers/s8-maxpool/9p8x-minmax-neon-c16.c", + "xnnpack_wrappers/s8-vclamp/neon-x64.c", + "xnnpack_wrappers/u8-ibilinear/gen/neon-c8.c", + "xnnpack_wrappers/u8-ibilinear/gen/neon-c16.c", + "xnnpack_wrappers/u8-maxpool/9p8x-minmax-neon-c16.c", + "xnnpack_wrappers/u8-rmax/neon.c", + "xnnpack_wrappers/u8-vclamp/neon-x64.c", + "xnnpack_wrappers/xx-fill/neon-x64.c", + "xnnpack_wrappers/xx-pad/neon.c", + "xnnpack_wrappers/x8-zip/xm-neon.c", + "xnnpack_wrappers/x8-zip/x2-neon.c", + "xnnpack_wrappers/x8-zip/x3-neon.c", + "xnnpack_wrappers/x8-zip/x4-neon.c", + "xnnpack_wrappers/x32-packx/x4-neon-st4.c", + "xnnpack_wrappers/x32-unpool/neon.c", + "xnnpack_wrappers/x32-zip/xm-neon.c", + "xnnpack_wrappers/x32-zip/x2-neon.c", + "xnnpack_wrappers/x32-zip/x3-neon.c", + "xnnpack_wrappers/x32-zip/x4-neon.c", +] + +PROD_AVX2_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-gemm/gen/1x16-minmax-avx2-broadcast.c", + "xnnpack_wrappers/f16-gemm/gen/4x16-minmax-avx2-broadcast.c", + "xnnpack_wrappers/f16-igemm/gen/1x16-minmax-avx2-broadcast.c", + "xnnpack_wrappers/f16-igemm/gen/4x16-minmax-avx2-broadcast.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-avx2-x64.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-avx2-x64.c", + "xnnpack_wrappers/f32-velu/gen/velu-avx2-rr1-lut4-p4-perm-x56.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-avx2-rr1-p5-div-x40.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qc8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qc8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qc8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qs8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-avx2-x16.c", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qs8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qs8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x9-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qu8-dwconv/gen/up16x25-minmax-fp32-avx2-mul32.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-avx2-x16.c", + "xnnpack_wrappers/qu8-gemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qu8-gemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qu8-igemm/gen/1x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qu8-igemm/gen/3x8c8-minmax-fp32-avx2.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-avx2-mul32-ld64-x16.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-avx2-mul32-ld64-x16.c", + "xnnpack_wrappers/x8-lut/gen/lut-avx2-x128.c", +] + +PROD_SSE_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-avgpool/9p8x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-avgpool/9x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-conv-hwc2chw/3x3s2p1c3x4-sse-2x2.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x3-minmax-sse.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x4-minmax-sse.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x9-minmax-sse.c", + "xnnpack_wrappers/f32-dwconv/gen/up8x25-minmax-sse.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3p1-minmax-sse-2x4-acc2.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/3x3s2p1-minmax-sse-1x4-acc3.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5p2-minmax-sse-4x4.c", + "xnnpack_wrappers/f32-dwconv2d-chw/gen/5x5s2p2-minmax-sse-2x4.c", + "xnnpack_wrappers/f32-gavgpool-cw/sse-x4.c", + "xnnpack_wrappers/f32-gavgpool/7p7x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-gavgpool/7x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-sse-load1.c", + "xnnpack_wrappers/f32-gemm/gen/4x2c4-minmax-sse.c", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-sse-load1.c", + "xnnpack_wrappers/f32-ibilinear-chw/gen/sse-p8.c", + "xnnpack_wrappers/f32-ibilinear/gen/sse-c8.c", + "xnnpack_wrappers/f32-igemm/gen/1x8-minmax-sse-load1.c", + "xnnpack_wrappers/f32-igemm/gen/4x2c4-minmax-sse.c", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-sse-load1.c", + "xnnpack_wrappers/f32-maxpool/9p8x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-pavgpool/9p8x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-pavgpool/9x-minmax-sse-c4.c", + "xnnpack_wrappers/f32-rmax/sse.c", + "xnnpack_wrappers/f32-spmm/gen/32x1-minmax-sse.c", + "xnnpack_wrappers/f32-vbinary/gen/vadd-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vaddc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vdiv-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vdivc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmaxc-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmin-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vminc-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmul-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vmulc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vrdivc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vrsubc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiff-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiffc-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsub-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vbinary/gen/vsubc-minmax-sse-x8.c", + "xnnpack_wrappers/f32-vclamp/gen/vclamp-sse-x8.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-sse-x8.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-sse-x8.c", + "xnnpack_wrappers/f32-vmulcaddc/gen/c4-minmax-sse-2x.c", + "xnnpack_wrappers/f32-vsqrt/gen/sse-sqrt-x4.c", + "xnnpack_wrappers/f32-vunary/gen/vabs-sse-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vneg-sse-x8.c", + "xnnpack_wrappers/f32-vunary/gen/vsqr-sse-x8.c", + "xnnpack_wrappers/x32-packx/x4-sse.c", +] + +PROD_SSE41_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-sse41-int16-x16.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-sse41-x8.c", + "xnnpack_wrappers/f32-prelu/gen/sse41-2x8.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-sse41-x32.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-sse41-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-sse41-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-sse41-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-sse41-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-sse41-x8.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-sse41-rr2-lut64-p2-div-x8.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", + "xnnpack_wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qc8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qs8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16-add16.c", + "xnnpack_wrappers/qs8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16-add16.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-sse41-x16.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", + "xnnpack_wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qs8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", + "xnnpack_wrappers/qs8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "xnnpack_wrappers/qs8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-sse41-x16.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-sse41-c8.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7x-minmax-fp32-sse41-c8.c", + "xnnpack_wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qu8-gemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/3x4c8-minmax-fp32-sse41-ld64.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-sse41-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-sse41-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-vmul/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "xnnpack_wrappers/qu8-vmulc/gen/minmax-fp32-sse41-mul16-ld64-x16.c", + "xnnpack_wrappers/s8-ibilinear/gen/sse41-c16.c", + "xnnpack_wrappers/s8-maxpool/9p8x-minmax-sse41-c16.c", + "xnnpack_wrappers/s8-vclamp/sse41-x64.c", + "xnnpack_wrappers/u8-ibilinear/gen/sse41-c16.c", +] + +PROD_SSE2_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-f32-vcvt/gen/vcvt-sse2-int16-x32.c", + "xnnpack_wrappers/f32-argmaxpool/4x-sse2-c4.c", + "xnnpack_wrappers/f32-argmaxpool/9p8x-sse2-c4.c", + "xnnpack_wrappers/f32-argmaxpool/9x-sse2-c4.c", + "xnnpack_wrappers/f32-f16-vcvt/gen/vcvt-sse2-x16.c", + "xnnpack_wrappers/f32-prelu/gen/sse2-2x8.c", + "xnnpack_wrappers/f32-qs8-vcvt/gen/vcvt-sse2-x32.c", + "xnnpack_wrappers/f32-qu8-vcvt/gen/vcvt-sse2-x32.c", + "xnnpack_wrappers/f32-raddstoreexpminusmax/gen/sse2-rr2-p5-x20-acc2.c", + "xnnpack_wrappers/f32-velu/gen/velu-sse2-rr2-lut16-p3-x12.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-sse2-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-sse2-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-sse2-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-sse2-x8.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-sse2-x8.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-sse2-rr2-lut64-p2-div-x8.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", + "xnnpack_wrappers/qc8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", + "xnnpack_wrappers/qc8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qc8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qc8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qs8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16-add16.c", + "xnnpack_wrappers/qs8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16-add16.c", + "xnnpack_wrappers/qs8-f32-vcvt/gen/vcvt-sse2-x32.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", + "xnnpack_wrappers/qs8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", + "xnnpack_wrappers/qs8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qs8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qs8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qs8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qs8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qs8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qs8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-avgpool/9p8x-minmax-sse2-c8.c", + "xnnpack_wrappers/qu8-avgpool/9x-minmax-sse2-c8.c", + "xnnpack_wrappers/qu8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c", + "xnnpack_wrappers/qu8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c", + "xnnpack_wrappers/qu8-f32-vcvt/gen/vcvt-sse2-x32.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7p7x-minmax-fp32-sse2-c8.c", + "xnnpack_wrappers/qu8-gavgpool/gen/7x-minmax-fp32-sse2-c8.c", + "xnnpack_wrappers/qu8-gemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qu8-gemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/1x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qu8-igemm/gen/3x4c8-minmax-fp32-sse2-ld64.c", + "xnnpack_wrappers/qu8-vadd/gen/minmax-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-vaddc/gen/minmax-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-vmul/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/qu8-vmulc/gen/minmax-fp32-sse2-mul16-ld64-x8.c", + "xnnpack_wrappers/s8-ibilinear/gen/sse2-c8.c", + "xnnpack_wrappers/s8-maxpool/9p8x-minmax-sse2-c16.c", + "xnnpack_wrappers/s8-vclamp/sse2-x64.c", + "xnnpack_wrappers/u8-ibilinear/gen/sse2-c8.c", + "xnnpack_wrappers/u8-maxpool/9p8x-minmax-sse2-c16.c", + "xnnpack_wrappers/u8-rmax/sse2.c", + "xnnpack_wrappers/u8-vclamp/sse2-x64.c", + "xnnpack_wrappers/xx-fill/sse2-x64.c", + "xnnpack_wrappers/xx-pad/sse2.c", + "xnnpack_wrappers/x8-zip/xm-sse2.c", + "xnnpack_wrappers/x8-zip/x2-sse2.c", + "xnnpack_wrappers/x8-zip/x3-sse2.c", + "xnnpack_wrappers/x8-zip/x4-sse2.c", + "xnnpack_wrappers/x32-unpool/sse2.c", + "xnnpack_wrappers/x32-zip/xm-sse2.c", + "xnnpack_wrappers/x32-zip/x2-sse2.c", + "xnnpack_wrappers/x32-zip/x3-sse2.c", + "xnnpack_wrappers/x32-zip/x4-sse2.c", +] + +PROD_AVX512F_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f32-dwconv/gen/up16x3-minmax-avx512f.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x4-minmax-avx512f.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x9-minmax-avx512f.c", + "xnnpack_wrappers/f32-dwconv/gen/up16x25-minmax-avx512f.c", + "xnnpack_wrappers/f32-gemm/gen/1x16-minmax-avx512f-broadcast.c", + "xnnpack_wrappers/f32-gemm/gen/7x16-minmax-avx512f-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/1x16-minmax-avx512f-broadcast.c", + "xnnpack_wrappers/f32-igemm/gen/7x16-minmax-avx512f-broadcast.c", + "xnnpack_wrappers/f32-prelu/gen/avx512f-2x16.c", + "xnnpack_wrappers/f32-vbinary/gen/vadd-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vaddc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vdiv-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vdivc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vmaxc-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vmin-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vminc-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vmul-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vmulc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vrdivc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vrsubc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiff-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vsqrdiffc-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vsub-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vbinary/gen/vsubc-minmax-avx512f-x32.c", + "xnnpack_wrappers/f32-vclamp/gen/vclamp-avx512f-x16.c", + "xnnpack_wrappers/f32-velu/gen/velu-avx512f-rr1-lut16-p3-perm-x64.c", + "xnnpack_wrappers/f32-vhswish/gen/vhswish-avx512f-x16.c", + "xnnpack_wrappers/f32-vlrelu/gen/vlrelu-avx512f-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndd-avx512f-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndne-avx512f-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndu-avx512f-x16.c", + "xnnpack_wrappers/f32-vrnd/gen/vrndz-avx512f-x16.c", + "xnnpack_wrappers/f32-vsigmoid/gen/vsigmoid-avx512f-rr2-lut32-p2-perm2-scalef-div-x64.c", + "xnnpack_wrappers/f32-vunary/gen/vabs-avx512f-x16.c", + "xnnpack_wrappers/f32-vunary/gen/vneg-avx512f-x16.c", + "xnnpack_wrappers/f32-vunary/gen/vsqr-avx512f-x16.c", +] + +AARCH64_ASM_MICROKERNEL_SRCS = [ + "xnnpack_wrappers/f16-gemm/gen-inc/1x8inc-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen-inc/1x16inc-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen-inc/4x8inc-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen-inc/4x16inc-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen-inc/6x8inc-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-cortex-a55.S", + "xnnpack_wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-cortex-a75.S", + "xnnpack_wrappers/f16-gemm/gen-inc/6x16inc-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen-inc/8x8inc-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen/1x8-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen/1x16-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen/4x8-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen/4x16-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen/6x8-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-cortex-a55.S", + "xnnpack_wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-cortex-a75.S", + "xnnpack_wrappers/f16-gemm/gen/6x16-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f16-gemm/gen/8x8-minmax-aarch64-neonfp16arith-ld64.S", + "xnnpack_wrappers/f16-igemm/4x16-minmax-aarch64-neonfp16arith-ld32.S", + "xnnpack_wrappers/f32-dwconv/up4x9-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-dwconv/up4x9-minmax-aarch64-neonfma.S", + "xnnpack_wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen-inc/1x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/1x12inc-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/4x12inc-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen-inc/5x8inc-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/5x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a73.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-gemm/gen-inc/6x8inc-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/1x12-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/4x12-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/5x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/5x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a73.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/1x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/5x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/5x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld64.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld128.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S", + "xnnpack_wrappers/f32-igemm/1x8-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/1x12-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/4x8-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-igemm/4x12-minmax-aarch64-neonfma-cortex-a53.S", + "xnnpack_wrappers/f32-igemm/6x8-minmax-aarch64-neonfma-cortex-a55.S", + "xnnpack_wrappers/f32-igemm/6x8-minmax-aarch64-neonfma-cortex-a73.S", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qc8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qc8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mull.S", + "xnnpack_wrappers/qc8-gemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qc8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qc8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qc8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-igemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qc8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-gemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qs8-gemm/gen/1x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qs8-gemm/gen/1x16c4-minmax-rndnu-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-fp32-aarch64-neon-mull.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mull.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/2x8c16-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld32.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-igemm/gen/1x8c8-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal-prfm.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c8-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c16-minmax-fp32-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/2x8c16-minmax-rndnu-aarch64-neon-mlal.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x8-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-fp32-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld64.S", + "xnnpack_wrappers/qs8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-gemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a75.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a75.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-gemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-igemm/gen/4x8c4-minmax-rndnu-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a53.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-cortex-a75.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-ld64.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a53.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-cortex-a75.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16-minmax-rndnu-aarch64-neon-mlal-lane-prfm-ld64.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16c4-minmax-fp32-aarch64-neondot-ld128.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-cortex-a55.S", + "xnnpack_wrappers/qu8-igemm/gen/4x16c4-minmax-rndnu-aarch64-neondot-ld128.S", +] diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl new file mode 100644 index 0000000000000..2ca59eec5ff61 --- /dev/null +++ b/tools/BUCK.bzl @@ -0,0 +1,276 @@ +# @lint-ignore-every FBCODEBZLADDLOADS +load("//tools/build_defs:glob_defs.bzl", "subdir_glob") + +# shared by internal and OSS BUCK +def define_tools_targets( + python_binary, + python_library, + python_test, + third_party, + torchgen_deps, + contacts = []): + python_library( + name = "substitutelib", + srcs = ["substitute.py"], + base_module = "", + ) + + python_binary( + name = "substitute", + main_module = "substitute", + visibility = ["PUBLIC"], + deps = [ + ":substitutelib", + ], + ) + + python_library( + name = "jit", + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob([ + "jit/*.py", + "jit/templates/*", + ]), + base_module = "tools", + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + ], + ) + + python_binary( + name = "gen_unboxing_bin", + main_module = "tools.jit.gen_unboxing", + visibility = [ + "PUBLIC", + ], + deps = [ + ":jit", + ], + ) + + python_library( + name = "gen_selected_mobile_ops_header", + srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"], + base_module = "tools", + visibility = ["PUBLIC"], + ) + + python_library( + name = "gen_oplist_lib", + srcs = subdir_glob([ + ("code_analyzer", "gen_oplist.py"), + ("code_analyzer", "gen_op_registration_allowlist.py"), + ]), + base_module = "", + tests = [ + ":gen_oplist_test", + ], + deps = [ + ":gen_selected_mobile_ops_header", + torchgen_deps, + third_party("pyyaml"), + ], + ) + + python_binary( + name = "gen_oplist", + main_module = "gen_oplist", + visibility = ["PUBLIC"], + deps = [ + ":gen_oplist_lib", + ], + ) + + python_library( + name = "gen_operators_yaml_lib", + srcs = subdir_glob([ + ("code_analyzer", "gen_operators_yaml.py"), + ("code_analyzer", "gen_op_registration_allowlist.py"), + ]), + base_module = "", + tests = [ + ":gen_operators_yaml_test", + ], + deps = [ + third_party("pyyaml"), + torchgen_deps, + ], + ) + + python_binary( + name = "gen_operators_yaml", + main_module = "gen_operators_yaml", + visibility = ["PUBLIC"], + deps = [ + ":gen_operators_yaml_lib", + ], + ) + + python_library( + name = "autograd", + # @lint-ignore BUCKRESTRICTEDSYNTAX + srcs = glob( + ["autograd/*.py"], + ), + base_module = "tools", + resources = [ + "autograd/deprecated.yaml", + "autograd/derivatives.yaml", + "autograd/templates/ADInplaceOrViewType.cpp", + "autograd/templates/Functions.cpp", + "autograd/templates/Functions.h", + "autograd/templates/TraceType.cpp", + "autograd/templates/VariableType.cpp", + "autograd/templates/VariableType.h", + "autograd/templates/annotated_fn_args.py.in", + "autograd/templates/python_enum_tag.cpp", + "autograd/templates/python_fft_functions.cpp", + "autograd/templates/python_functions.cpp", + "autograd/templates/python_functions.h", + "autograd/templates/python_linalg_functions.cpp", + "autograd/templates/python_nn_functions.cpp", + "autograd/templates/python_return_types.cpp", + "autograd/templates/python_sparse_functions.cpp", + "autograd/templates/python_special_functions.cpp", + "autograd/templates/python_torch_functions.cpp", + "autograd/templates/python_variable_methods.cpp", + "autograd/templates/variable_factories.h", + ], + visibility = ["PUBLIC"], + deps = [ + third_party("pyyaml"), + torchgen_deps, + ], + ) + + python_library( + name = "generate_code", + srcs = [ + "setup_helpers/generate_code.py", + ], + base_module = "tools", + deps = [ + ":autograd", + ":jit", + torchgen_deps, + ], + ) + + python_binary( + name = "generate_code_bin", + main_module = "tools.setup_helpers.generate_code", + # Windows does not support inplace: + # https://github.com/facebook/buck/issues/2161. + # + # Note that //arvr/mode/embedded/win/clang-aarch64-release sets + # its target platform to + # ovr_config//platform/embedded:clang-aarch64-linux-release, hence + # that is why we are selecting that OS to trigger this behavior. + package_style = select({ + "DEFAULT": "inplace", + "ovr_config//os:linux-arm64": "standalone", + }), + visibility = ["PUBLIC"], + # Because Windows does not support inplace packaging, we need to + # ensure it is unzipped before executing it, otherwise it will not + # be able to find any resources using path manipulation. + # + # See note above about why the OS is Linux here and not Windows. + zip_safe = select({ + "DEFAULT": True, + "ovr_config//os:linux-arm64": False, + }), + deps = [ + ":generate_code", + ], + ) + + python_library( + name = "gen-version-header-lib", + srcs = [ + "setup_helpers/gen_version_header.py", + ], + base_module = "", + deps = [], + ) + + python_binary( + name = "gen-version-header", + main_module = "setup_helpers.gen_version_header", + visibility = ["PUBLIC"], + deps = [ + ":gen-version-header-lib", + ], + ) + + python_library( + name = "gen_aten_vulkan_spv_lib", + srcs = [ + "gen_vulkan_spv.py", + ], + base_module = "", + deps = [ + torchgen_deps, + ], + ) + + python_binary( + name = "gen_aten_vulkan_spv_bin", + main_module = "gen_vulkan_spv", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_aten_vulkan_spv_lib", + ], + ) + + python_test( + name = "selective_build_test", + srcs = [ + "test/test_selective_build.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + ], + ) + + python_test( + name = "gen_oplist_test", + srcs = [ + "test/gen_oplist_test.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + ":gen_oplist_lib", + ], + ) + + python_test( + name = "gen_operators_yaml_test", + srcs = [ + "test/gen_operators_yaml_test.py", + ], + visibility = ["PUBLIC"], + contacts = contacts, + deps = [ + ":gen_operators_yaml_lib", + ], + ) + + python_test( + name = "test_codegen", + srcs = [ + "test/test_codegen.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + ":autograd", + ], + ) diff --git a/tools/BUCK.oss b/tools/BUCK.oss new file mode 100644 index 0000000000000..97f67945120ed --- /dev/null +++ b/tools/BUCK.oss @@ -0,0 +1,10 @@ +load("//:buckbuild.bzl", "third_party") +load(":BUCK.bzl", "define_tools_targets") + +define_tools_targets( + python_binary = python_binary, + python_library = python_library, + python_test = python_test, + third_party = third_party, + torchgen_deps = "//torchgen:torchgen", +) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index ca41b17a43e7b..46e1b8e20b491 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -import os import argparse +import os import sys sys.path.append( @@ -113,6 +113,9 @@ "*/hip/*", # These files are compatible with both cuda and hip "aten/src/ATen/core/*", + # Correct path to generate HIPConfig.h: + # CUDAConfig.h.in -> (amd_build) HIPConfig.h.in -> (cmake) HIPConfig.h + "aten/src/ATen/cuda/CUDAConfig.h", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu", "torch/csrc/jit/codegen/cuda/runtime/broadcast.cu", @@ -157,7 +160,7 @@ def is_hip_clang() -> bool: do_write = False with open(gloo_cmake_file, "r") as sources: lines = sources.readlines() - newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIBRARY_PATH") for line in lines] + newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIB_PATH") for line in lines] if lines == newlines: print("%s skipped" % gloo_cmake_file) else: diff --git a/tools/autograd/BUCK.oss b/tools/autograd/BUCK.oss deleted file mode 100644 index 04403f4a52696..0000000000000 --- a/tools/autograd/BUCK.oss +++ /dev/null @@ -1,35 +0,0 @@ -python_library( - name = "autograd", - srcs = glob( - ["*.py"], - ), - base_module = "tools.autograd", - resources = [ - "deprecated.yaml", - "derivatives.yaml", - "templates/ADInplaceOrViewType.cpp", - "templates/Functions.cpp", - "templates/Functions.h", - "templates/TraceType.cpp", - "templates/VariableType.cpp", - "templates/VariableType.h", - "templates/annotated_fn_args.py.in", - "templates/python_fft_functions.cpp", - "templates/python_functions.cpp", - "templates/python_functions.h", - "templates/python_linalg_functions.cpp", - "templates/python_nn_functions.cpp", - "templates/python_return_types.cpp", - "templates/python_sparse_functions.cpp", - "templates/python_special_functions.cpp", - "templates/python_torch_functions.cpp", - "templates/python_variable_methods.cpp", - "templates/variable_factories.h", - "templates/python_enum_tag.cpp", - ], - visibility = ["PUBLIC"], - deps = [ - "//third_party:pyyaml", - "//torchgen:torchgen", - ], -) diff --git a/tools/autograd/context.py b/tools/autograd/context.py index af1a6025ed8da..da39ad467e154 100644 --- a/tools/autograd/context.py +++ b/tools/autograd/context.py @@ -1,10 +1,10 @@ +import functools +from typing import Callable + from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI from torchgen.context import native_function_manager from torchgen.utils import T -import functools -from typing import Callable - # Like tools.api.context.with_native_function, but for # NativeFunctionWithDifferentiabilityInfo. def with_native_function_with_differentiability_info( diff --git a/tools/autograd/deprecated.yaml b/tools/autograd/deprecated.yaml index b7cbf3c918d0f..52f7ec50b6ea1 100644 --- a/tools/autograd/deprecated.yaml +++ b/tools/autograd/deprecated.yaml @@ -1,92 +1,134 @@ # Deprecated function signatures. These are exposed in Python, but not included # in the error message suggestions. -- name: add(Tensor self, Scalar alpha, Tensor other) +- name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor aten: add(self, other, alpha) -- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor out) +- name: add_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: add_(self, other, alpha) + +- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) aten: add_out(out, self, other, alpha) -- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor aten: addbmm(self, batch1, batch2, beta, alpha) -- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor out) +- name: addbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: addbmm_(self, batch1, batch2, beta, alpha) + +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) aten: addbmm_out(out, self, batch1, batch2, beta, alpha) -- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor aten: addbmm(self, batch1, batch2, beta, 1) -- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor out) +- name: addbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: addbmm_(self, batch1, batch2, beta, 1) + +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) aten: addbmm_out(out, self, batch1, batch2, beta, 1) -- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor aten: addcdiv(self, tensor1, tensor2, value) -- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor out) +- name: addcdiv_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcdiv_(self, tensor1, tensor2, value) + +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) aten: addcdiv_out(out, self, tensor1, tensor2, value) -- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor aten: addcmul(self, tensor1, tensor2, value) -- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor out) +- name: addcmul_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcmul_(self, tensor1, tensor2, value) + +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) aten: addcmul_out(out, self, tensor1, tensor2, value) -- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor aten: addmm(self, mat1, mat2, beta, alpha) -- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2, *, Tensor out) +- name: addmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, alpha) + +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) aten: addmm_out(out, self, mat1, mat2, beta, alpha) -- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor aten: addmm(self, mat1, mat2, beta, 1) -- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2, *, Tensor out) +- name: addmm_(Scalar beta, Tensor(a!) self, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, 1) + +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) aten: addmm_out(out, self, mat1, mat2, beta, 1) -- name: sspaddmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) +- name: sspaddmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor aten: sspaddmm(self, mat1, mat2, beta, alpha) -- name: sspaddmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) +- name: sspaddmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor aten: sspaddmm(self, mat1, mat2, beta, 1) -- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec) +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor aten: addmv(self, mat, vec, beta, alpha) -- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec, *, Tensor out) +- name: addmv_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, alpha) + +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) aten: addmv_out(out, self, mat, vec, beta, alpha) -- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec) +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec) -> Tensor aten: addmv(self, mat, vec, beta, 1) -- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec, *, Tensor out) +- name: addmv_(Scalar beta, Tensor(a!) self, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, 1) + +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) aten: addmv_out(out, self, mat, vec, beta, 1) -- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2) +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor aten: addr(self, vec1, vec2, beta, alpha) -- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2, *, Tensor out) +- name: addr_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, alpha) + +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) aten: addr_out(out, self, vec1, vec2, beta, alpha) -- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2) +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2) -> Tensor aten: addr(self, vec1, vec2, beta, 1) -- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2, *, Tensor out) +- name: addr_(Scalar beta, Tensor(a!) self, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, 1) + +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) aten: addr_out(out, self, vec1, vec2, beta, 1) -- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor aten: baddbmm(self, batch1, batch2, beta, alpha) -- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor out) +- name: baddbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, alpha) + +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) aten: baddbmm_out(out, self, batch1, batch2, beta, alpha) -- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor aten: baddbmm(self, batch1, batch2, beta, 1) -- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor out) +- name: baddbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, 1) + +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) aten: baddbmm_out(out, self, batch1, batch2, beta, 1) -- name: sub(Tensor self, Scalar alpha, Tensor other) +- name: sub(Tensor self, Scalar alpha, Tensor other) -> Tensor aten: sub(self, other, alpha) -- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor out) +- name: sub_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: sub_(self, other, alpha) + +- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) aten: sub_out(out, self, other, alpha) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 550b749ba68f0..c9c708026854b 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -277,7 +277,8 @@ output_differentiability: [False] - name: acosh(Tensor self) -> Tensor - self: grad * (self.pow(2) - 1).rsqrt().conj() +# Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case) + self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()" result: auto_element_wise - name: acosh_(Tensor(a!) self) -> Tensor(a!) @@ -341,6 +342,11 @@ mat2: self.transpose(1, 2).conj().bmm(grad) result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t) +- name: _NestedTensor_GeneralizedBMM(Tensor self, Tensor mat2) -> Tensor + self: _NestedTensor_GeneralizedBMM(grad, mat2.transpose(-2, -1).conj()) + mat2: _NestedTensor_GeneralizedBMM(self.transpose(-2, -1).conj(), grad) + result: _NestedTensor_GeneralizedBMM(self_t, mat2_p) + _NestedTensor_GeneralizedBMM(self_p, mat2_t) + - name: cat(Tensor[] tensors, int dim=0) -> Tensor tensors: cat_tensors_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors), dim) result: cat_jvp(tensors, dim) @@ -493,10 +499,16 @@ self: deg2rad_backward(grad) result: auto_element_wise -- name: _det_lu_based_helper(Tensor self) -> (Tensor det, Tensor lu, Tensor pivs) - self: _det_lu_based_helper_backward(grad, det, self, lu, pivs) +- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + A: linalg_det_backward(grad, result, A, LU, pivots) + result: at::linalg_lu_solve(LU, pivots, A_t, /*left*/true, /*adjoint*/!A_p.is_complex() && A_p.is_contiguous()).diagonal(0, -2, -1).sum(-1) * result output_differentiability: [True, False, False] +- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots) + sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, True, False, False] + - name: block_diag(Tensor[] tensors) -> Tensor tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors)) result: block_diag_jvp(tensors) @@ -528,10 +540,10 @@ - name: div.Tensor(Tensor self, Tensor other) -> Tensor self: div_tensor_self_backward(grad, other, self.scalar_type()) other: div_tensor_other_backward(grad, self, other) - result: self_t / other_p - other_t * (self_p / other_p) / other_p + result: (self_t - other_t * result) / other_p - name: div.Scalar(Tensor self, Scalar other) -> Tensor - self: div_tensor_self_backward(grad, at::scalar_to_tensor(other), self.scalar_type()) + self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type()) result: self_t / other - name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor @@ -540,7 +552,7 @@ result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p" - name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor - self: div_tensor_self_backward(grad, at::scalar_to_tensor(other), self.scalar_type(), rounding_mode) + self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type(), rounding_mode) result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other" - name: dot(Tensor self, Tensor tensor) -> Tensor @@ -934,9 +946,6 @@ - name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor self: not_implemented("zeta") -- name: logdet(Tensor self) -> Tensor - self: logdet_backward(grad, self, result) - - name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) self: zeros_like(grad) result: self_t.zero_() @@ -1031,11 +1040,11 @@ result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t) - name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor - self: grad.expand(self.sizes()).to(self.scalar_type()) / self.numel() + self: grad.expand(self.sizes()) / self.numel() result: auto_linear -- name: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type()) / _safe_size(self.sizes(), dim) +- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: mean_backward(grad, self.sizes(), dim, self.numel(), keepdim) result: auto_linear - name: median(Tensor self) -> Tensor @@ -1088,9 +1097,11 @@ - name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) - name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) - name: mm(Tensor self, Tensor mat2) -> Tensor self: mm_mat1_backward(grad, mat2, self.sizes(), self.strides(), self.layout(), 1) @@ -1107,7 +1118,7 @@ result: other_t * self_p + self_t * other_p - name: mul.Scalar(Tensor self, Scalar other) -> Tensor - self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type()) + self: mul_tensor_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type()) result: self_t * other - name: mv(Tensor self, Tensor vec) -> Tensor @@ -1261,11 +1272,11 @@ self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() -- name: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) +- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" index: non_differentiable source: grad.take(index).reshape_as(source) - result: auto_linear # It is affine, but sure + result: self_t.put(index, source_t, accumulate) - name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) @@ -1331,6 +1342,10 @@ # making it impossible (hard) to detect when it is actually a view. # - name: reshape(Tensor self, IntArrayRef shape) +- name: _reshape_nested(Tensor self, int[] shape) -> Tensor + self: _reshape_nested_backward(self, grad) + result: auto_linear + - name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a) self: grad.reshape(self.sizes()) result: auto_linear @@ -1385,8 +1400,11 @@ result: auto_element_wise - name: sgn(Tensor self) -> Tensor - self: sgn_backward(result, grad, self) - result: auto_element_wise + self: sgn_backward(self, grad, result) + # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric) + # The function is not holomorphic, so there's no reason for its Jacobian to be Hermitian + # auto_element_wise has a name that's a bit deceiving in the complex case + result: sgn_backward(self_p, self_t, result) - name: sin(Tensor self) -> Tensor self: grad * self.cos().conj() @@ -1429,18 +1447,10 @@ src: grad.contiguous().as_strided(size, stride, storage_offset) result: auto_linear -- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) - self: slogdet_backward(grad, self, sign, logabsdet) - output_differentiability: [false, true] - -- name: linalg_slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) - self: slogdet_backward(grad, self, sign, logabsdet) - output_differentiability: [false, true] - -- name: _linalg_solve(Tensor A, Tensor B, *, bool left=True) -> (Tensor result, Tensor LU, Tensor pivots) +- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1]) result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())" - output_differentiability: [True, False, False] # LU is an auxiliary tensor not exposed to the user + output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) @@ -1490,10 +1500,14 @@ - name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor self: std_backward(result, grad, self, dim, correction, keepdim) - result: handle_r_to_c(result.scalar_type(), var_jvp(self_t, self_p, result, dim, correction, keepdim) / (2 * result)) + # pointwise (variance) + sum + sqrt + result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) - name: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) - self: var_std_mean_backward(grads, self, result0, result1, dim, correction, keepdim, true) + self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) + result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) - name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) @@ -1517,12 +1531,17 @@ self: grad.expand(self.sizes()) result: auto_linear -- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +- name: sum.SymInt(Tensor self, SymInt[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: sum_backward(grad, self.sym_sizes(), dim, keepdim) + result: auto_linear + +- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: sum_backward(grad, self.sizes(), dim, keepdim) result: auto_linear - name: nansum(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) + result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype) # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) @@ -1569,7 +1588,7 @@ result: auto_linear - name: take(Tensor self, Tensor index) -> Tensor - self: zeros_like(self).put_(index, grad, true) + self: take_backward(grad, self, index) index: non_differentiable result: auto_linear @@ -1672,7 +1691,12 @@ result: auto_linear - name: lift(Tensor self) -> Tensor - self: not_implemented("lift") + self: grad + result: auto_linear + +- name: lift_fresh(Tensor(a) self) -> Tensor(a) + self: grad + result: auto_linear - name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) self: grad.squeeze(dim) @@ -1684,15 +1708,25 @@ - name: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor self: var_backward(grad, self, dim, correction, keepdim) - result: handle_r_to_c(result.scalar_type(), var_jvp(self_t, self_p, result, dim, correction, keepdim)) + # pointwise + sum + result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) - name: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) - self: var_std_mean_backward(grads, self, result0, result1, dim, correction, keepdim, false) + self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) + result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) - name: view(Tensor(a) self, int[] size) -> Tensor(a) self: grad.reshape(self.sizes()) result: auto_linear +- name: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a) + # TODO: add proper double backward for view.SymInt + # by SymIntizing `reshape` + self: grad.reshape(c10::asIntArrayRefSlow(self.sym_sizes())) + result: auto_linear + - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) output_differentiability: [False] @@ -1805,11 +1839,6 @@ indices: non_differentiable self: not_implemented("embedding_renorm") -- name: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - self: kl_div_backward(grad, self, target, reduction, log_target) - target: kl_div_target_backward(grad, self, target, reduction, log_target) - result: apply_loss_reduction(kl_div_backward(self_t, self_p, target_p, at::Reduction::None, log_target) + kl_div_target_backward(target_t, self_p, target_p, at::Reduction::None, log_target), reduction) - - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor self: mse_loss_backward(grad, self, target, reduction) target: mse_loss_backward(grad, target, self, reduction) @@ -2130,6 +2159,9 @@ result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) output_differentiability: [True, False] +- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + input, weight, bias: linear_backward(input, grad, weight, grad_input_mask) + #mps - name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode) @@ -2184,7 +2216,7 @@ input, weight, bias: "grad.defined() ? convolution_backward_overrideable(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" - name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, false, output_padding, groups, grad_input_mask) + grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) - name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" @@ -2285,11 +2317,6 @@ self: zeros_like(grad) result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result)) -- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction, log_target) - self: zeros_like(grad) - target: zeros_like(grad) - - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor grad_output: log_sigmoid_backward(grad, self, buffer) self: log_sigmoid_double_backward(grad * grad_output, self) @@ -2356,6 +2383,7 @@ # self_is_result is always false here since double backward call is an out-of-place call, self is input itself grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) self: zeros_like(grad) + result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false) - name: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor grad_output: reflection_pad1d(grad, padding) @@ -2753,21 +2781,6 @@ - name: _pin_memory(Tensor self, Device? device=None) -> Tensor self: grad -# Empty factory functions have explicit non-differentiability so that they propagate the class -# when used with a Tensor subclass together with __torch_dispatch__. -# All the other factory functions are composite and call into one of these. -- name: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - self: non_differentiable - output_differentiability: [False] - -- name: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - self: non_differentiable - output_differentiability: [False] - -- name: empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - self: non_differentiable - output_differentiability: [False] - - name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor self: non_differentiable other: non_differentiable @@ -2783,6 +2796,9 @@ self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result) index: non_differentiable +- name: special_airy_ai(Tensor x) -> Tensor + x: non_differentiable + - name: special_bessel_j0(Tensor self) -> Tensor self: non_differentiable @@ -2887,6 +2903,12 @@ - name: special_modified_bessel_k1(Tensor self) -> Tensor self: non_differentiable +- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + x: non_differentiable + +- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + x: non_differentiable + - name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor x: non_differentiable n: non_differentiable @@ -2926,3 +2948,6 @@ - name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor x: non_differentiable + +- name: special_spherical_bessel_j0(Tensor x) -> Tensor + x: non_differentiable diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 89269e8e0e0ff..82d165890f1ec 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -14,26 +14,28 @@ torch/testing/_internal/generated """ -from collections import defaultdict import argparse import os import textwrap +from collections import defaultdict -from typing import Dict, List, Any +from typing import Any, Dict, List -from torchgen.gen import parse_native_yaml -from torchgen.utils import FileManager +import torchgen.api.python as python from torchgen.context import with_native_function + +from torchgen.gen import parse_native_yaml from torchgen.model import BaseOperatorName, NativeFunction -import torchgen.api.python as python +from torchgen.utils import FileManager + from .gen_python_functions import ( - should_generate_py_binding, - is_py_torch_function, - is_py_nn_function, + is_py_fft_function, is_py_linalg_function, - is_py_variable_method, + is_py_nn_function, is_py_special_function, - is_py_fft_function, + is_py_torch_function, + is_py_variable_method, + should_generate_py_binding, ) diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 25a04fb14acc1..3747ae3341f45 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -25,6 +25,8 @@ import argparse import os +from typing import List + from torchgen.api import cpp from torchgen.api.autograd import ( match_differentiability_info, @@ -32,16 +34,16 @@ ) from torchgen.gen import parse_native_yaml from torchgen.selective_build.selector import SelectiveBuilder -from typing import List + from . import gen_python_functions from .gen_autograd_functions import ( gen_autograd_functions_lib, gen_autograd_functions_python, ) -from .gen_trace_type import gen_trace_type -from .gen_variable_type import gen_variable_type from .gen_inplace_or_view_type import gen_inplace_or_view_type +from .gen_trace_type import gen_trace_type from .gen_variable_factories import gen_variable_factories +from .gen_variable_type import gen_variable_type from .load_derivatives import load_derivatives diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 3e1e55b82b2fc..072c236f61351 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -4,8 +4,6 @@ # Functions.h/cpp: subclasses of autograd::Node # python_functions.h/cpp: Python bindings for the above classes # -from .gen_inplace_or_view_type import VIEW_FUNCTIONS - from typing import List, Sequence, Tuple from torchgen.api.autograd import ( @@ -16,25 +14,28 @@ uses_single_grad, ) from torchgen.api.types import ( - Binding, + ArrayRefCType, BaseCType, - OptionalCType, - tensorT, - longT, - doubleT, - scalarT, - stringT, + Binding, boolT, + doubleT, intArrayRefT, - tensorListT, - MutRefCType, ListCType, - ArrayRefCType, + longT, + MutRefCType, + OptionalCType, optionalIntArrayRefT, + scalarT, + stringT, + symIntArrayRefT, + tensorListT, + tensorT, ) from torchgen.code_template import CodeTemplate -from torchgen.utils import FileManager from torchgen.model import Argument +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import VIEW_FUNCTIONS FUNCTION_DECLARATION = CodeTemplate( """\ @@ -281,6 +282,20 @@ return tup; """ +GETTER_BODY_ARRAYREF_SYMINT = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + auto si = prop[i]; + if (si.is_symbolic()) { + auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); + } else { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked())); + } +} +return tup; +""" + GETTER_BODY_ARRAYREF_DOUBLE = """\ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); for (auto i : c10::irange(prop.size())) { @@ -520,6 +535,13 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG ) ) + elif type == BaseCType(symIntArrayRefT): + saved_variables.append(f"std::vector {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) elif type == BaseCType(optionalIntArrayRefT): saved_variables.append(f"c10::OptionalArray {name};") getter_definitions.append( diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index 541ef2b5312bf..69f3eecf590cc 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -4,40 +4,42 @@ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # The fallback is expected to mimick this codegen, so we should keep the two in sync. +from typing import Dict, List, Optional, Sequence, Tuple + from torchgen.api import cpp from torchgen.api.autograd import ( - NativeFunctionWithDifferentiabilityInfo, - gen_differentiable_outputs, dispatch_strategy, + gen_differentiable_outputs, + NativeFunctionWithDifferentiabilityInfo, ) from torchgen.api.types import ( - Binding, - DispatcherSignature, - CType, BaseCType, - OptionalCType, - longT, + Binding, boolT, + CType, + DispatcherSignature, intArrayRefT, + longT, + OptionalCType, symIntArrayRefT, ) from torchgen.code_template import CodeTemplate from torchgen.context import with_native_function from torchgen.model import ( - Type, NativeFunction, + SchemaKind, SelfArgument, TensorOptionsArguments, - SchemaKind, + Type, ) -from typing import List, Optional, Sequence, Tuple, Dict from torchgen.utils import FileManager + from .context import with_native_function_with_differentiability_info from .gen_trace_type import ( + get_return_value, MANUAL_AUTOGRAD, - type_wrapper_name, tie_return_values, - get_return_value, + type_wrapper_name, ) # See NOTE [ Autograd View Variables ] in variable.h for details. diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 2336b2354915d..aeb1c4af9e0dc 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -30,24 +30,16 @@ # message, but use what's there # -from collections import defaultdict import itertools import re -import yaml +from collections import defaultdict -from .gen_trace_type import should_trace +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple -from torchgen.code_template import CodeTemplate +import yaml from torchgen.api import cpp -from torchgen.api.types import CppSignatureGroup from torchgen.api.python import ( - PythonArgument, - PythonSignature, - PythonSignatureDeprecated, - PythonSignatureGroup, - PythonSignatureNativeFunctionPair, arg_parser_output_exprs, - argument_type_str, cpp_dispatch_exprs, cpp_dispatch_target, dispatch_lambda_args, @@ -55,20 +47,28 @@ dispatch_lambda_return_str, has_tensor_options, namedtuple_fieldnames, + PythonSignature, + PythonSignatureDeprecated, + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, signature, + signature_from_schema, ) -from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml + +from torchgen.code_template import CodeTemplate from torchgen.context import with_native_function +from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml from torchgen.model import ( Argument, BaseOperatorName, + FunctionSchema, NativeFunction, Type, Variant, ) -from torchgen.utils import split_name_params, YamlLoader, FileManager +from torchgen.utils import FileManager, split_name_params, YamlLoader -from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable +from .gen_trace_type import should_trace # # declarations blocklist @@ -124,7 +124,6 @@ "_symeig.*", "_svd.*", "slice", - "randint(_out)?", "item", "_local_scalar_dense", "to", @@ -155,7 +154,8 @@ "copy", # only used by the functionalization pass "fill.Tensor", # only used by the functionalization pass "fill.Scalar", # only used by the functionalization pass - "lift", + "lift.*", + "normal_functional", # only used by the functionalization pas ] SKIP_PYTHON_BINDINGS = list( @@ -500,46 +500,10 @@ def load_deprecated_signatures( # the call) to generate the full python signature. # We join the deprecated and the original signatures using type-only form. - # native function -> type-only signature - @with_native_function - def signature_original(f: NativeFunction) -> str: - # remove inplace suffix but keep outplace suffix - opname = str(f.func.name.name.base) - if f.func.is_out_fn(): - opname += "_out" - if f.func.name.name.inplace and pyi: - opname += "_" - args = CppSignatureGroup.from_native_function( - f, method=False - ).signature.arguments() - # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml. - types = ", ".join( - argument_type_str(a.argument.type) - for a in args - if isinstance(a.argument, Argument) - ) - return f"{opname}({types})" - - # deprecated -> type-only native signature (according to the call order) - def signature_deprecated( - opname: str, params: List[str], call_args: List[str] - ) -> str: - # create a mapping of parameter name to parameter type - types: Dict[str, str] = {} - for param in params: - if param == "*": - continue - type, name = param.split(" ") - types[name] = type - # if the name in the call is not in the parameter list, assume it's - # a literal Scalar - rearranged_types = ", ".join(types.get(arg, "Scalar") for arg in call_args) - return f"{opname}({rearranged_types})" - - # group the original ATen signatures by type-only signature + # group the original ATen signatures by name grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) for pair in pairs: - grouped[signature_original(pair.function)].append(pair) + grouped[pair.signature.name].append(pair) # find matching original signatures for each deprecated signature results: List[PythonSignatureNativeFunctionPair] = [] @@ -548,66 +512,89 @@ def signature_deprecated( deprecated_defs = yaml.load(f, Loader=YamlLoader) for deprecated in deprecated_defs: - _, params = split_name_params(deprecated["name"]) + schema = FunctionSchema.parse(deprecated["name"]) aten_name, call_args = split_name_params(deprecated["aten"]) - - for pair in grouped[signature_deprecated(aten_name, params, call_args)]: - # It uses the types from the original ATen declaration, but the - # ordering and parameter names from the deprecated overload. Any - # default parameter values from the original ATen declaration are - # ignored. - # Deprecated signature might reorder input_args and input_kwargs, - # but never changes output_args nor TensorOptions (if any?), - # so here we only look into these two types of args. - python_sig = pair.signature - src_args: Dict[str, PythonArgument] = { - a.name: PythonArgument( - name=a.name, - type=a.type, - default=None, - default_init=None, + is_out = aten_name.endswith("_out") + if is_out: + aten_name = aten_name.replace("_out", "") + + # HACK: these are fixed constants used to pass the the aten function. + # The type must be known ahead of time + known_constants = { + "1": Type.parse("Scalar"), + } + schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} + for name in call_args: + assert ( + name in schema_args_by_name or name in known_constants + ), f"deprecation definiton: Unrecognized value {name}" + + # Map deprecated signature arguments to their aten signature and test + # if the types and alias annotation match. + def is_schema_compatible( + aten_schema: FunctionSchema, + ) -> bool: + arguments: Iterable[Argument] + if is_out: + arguments = itertools.chain( + aten_schema.arguments.out, aten_schema.arguments.flat_non_out ) - for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs) - } - - args: List[str] = [] - input_args: List[PythonArgument] = [] - input_kwargs: List[PythonArgument] = [] - - kwarg_only = False - for param in params: - if param == "*": - kwarg_only = True - continue - _, param_name = param.split(" ") - args.append(param_name) - - if param_name not in src_args: - # output argument - continue - - if not kwarg_only: - if not method or param_name != "self": - input_args.append(src_args[param_name]) + else: + arguments = aten_schema.arguments.flat_all + + for i, arg in enumerate(arguments): + if i < len(call_args): + arg_name = call_args[i] + if arg_name in known_constants: + schema_type = known_constants[arg_name] + schema_annotation = None + else: + schema_arg = schema_args_by_name[arg_name] + schema_type = schema_arg.type + schema_annotation = schema_arg.annotation + + if schema_type != arg.type or schema_annotation != arg.annotation: + return False else: - input_kwargs.append(src_args[param_name]) + if arg.default is None: + return False + + return len(schema.returns) == len(aten_schema.returns) and all( + a == b for a, b in zip(schema.returns, aten_schema.returns) + ) + + any_schema_found = False + for pair in grouped[aten_name]: + if not is_schema_compatible(pair.function.func): + continue + any_schema_found = True + + python_sig = signature_from_schema( + schema, + category_override=pair.function.category_override, + method=method, + pyi=pyi, + ) results.append( PythonSignatureNativeFunctionPair( signature=PythonSignatureDeprecated( name=python_sig.name, - input_args=tuple(input_args), - input_kwargs=tuple(input_kwargs), + input_args=python_sig.input_args, + input_kwargs=python_sig.input_kwargs, output_args=python_sig.output_args, tensor_options_args=python_sig.tensor_options_args, method=python_sig.method, - deprecated_args_names=tuple(args), + deprecated_schema=schema, deprecated_args_exprs=tuple(call_args), returns=python_sig.returns, ), function=pair.function, ) ) + assert ( + any_schema_found + ), f"No native function with name {aten_name} matched signature:\n {str(schema)}" return results @@ -915,7 +902,9 @@ def emit_dispatch_case( overload.signature, overload.base, namedtuple_typenames ), call_dispatch_out=emit_single_dispatch( - overload.signature, overload.outplace, namedtuple_typenames + overload.signature, + overload.outplace, + namedtuple_typenames, ), ) else: @@ -1047,21 +1036,13 @@ def group_overloads( + "\n".join(f"- {candidate}" for candidate in candidates) ) - grouped: List[PythonSignatureGroup] = [] - for sig, base in bases.items(): - outplace = outplaces.get(sig) - grouped.append( - PythonSignatureGroup( - # prefer the signature with optional out=... arguments because it's the - # superset that can be used to parse input for both base and outplace. - signature=outplace.signature - if outplace is not None - else base.signature, - base=base.function, - outplace=outplace.function if outplace is not None else None, - ) + grouped = [ + PythonSignatureGroup.from_pairs( + functional=base, + out=outplaces.get(sig), ) - + for sig, base in bases.items() + ] return sort_overloads(grouped) @@ -1134,9 +1115,9 @@ def is_arg_smaller(t1: Type, t2: Type) -> bool: str(t1) == "Tensor[]" and str(t2).find("[]") != -1 or - # Prioritize SymIntArrayRef overload over IntArrayRef - str(t1) == "int[]" - and str(t2) == "SymInt[]" + # Prioritize IntArrayRef overload over SymIntArrayRef + str(t1) == "SymInt[]" + and str(t2) == "int[]" ) def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: @@ -1203,8 +1184,12 @@ def emit_single_dispatch( @with_native_function def go(f: NativeFunction) -> str: # header comments + if isinstance(ps, PythonSignatureDeprecated): + schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" + else: + schema_comment = f"// aten::{f.func}" + deprecated = "[deprecated] " if ps.deprecated else "" - schema_comment = f"// {deprecated}aten::{f.func}" # dispatch lambda signature name = cpp.name(f.func) diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 46c3baf3b1e24..f1e584fa64e21 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -1,17 +1,13 @@ import itertools -from typing import List, Sequence, Union, Dict +from typing import Dict, List, Sequence, Union -from torchgen.api.types import DispatcherSignature from torchgen.api import cpp + +from torchgen.api.types import DispatcherSignature from torchgen.code_template import CodeTemplate from torchgen.context import with_native_function +from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments from torchgen.utils import FileManager -from torchgen.model import ( - Argument, - NativeFunction, - SchemaKind, - TensorOptionsArguments, -) # Note [Manual Backend kernels] # For these ops, we want to manually register to dispatch key Backend and diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index 26eb2d91595d9..07abc98a8d4c2 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -3,15 +3,16 @@ # This writes one file: variable_factories.h import re -from typing import Optional, List +from typing import List, Optional -from torchgen.api.types import CppSignatureGroup -from torchgen.api import cpp import torchgen.api.python as python -from torchgen.gen import parse_native_yaml +from torchgen.api import cpp + +from torchgen.api.types import CppSignatureGroup from torchgen.context import with_native_function -from torchgen.utils import mapMaybe, FileManager +from torchgen.gen import parse_native_yaml from torchgen.model import NativeFunction, TensorOptionsArguments, Variant +from torchgen.utils import FileManager, mapMaybe OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>") TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4d09dcffd561e..2667cf29228c1 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -25,81 +25,92 @@ # which will in turn dispatch back to VariableType for its # differentiable subcomponents. # -from .context import with_native_function_with_differentiability_info -from .gen_trace_type import ( - MANUAL_BACKEND, - MANUAL_AUTOGRAD_AND_TRACER, - declare_returned_variables, - tie_return_values, - get_return_value, - type_wrapper_name, -) -from .gen_inplace_or_view_type import ( - get_view_info, - is_tensor_type, - is_tensor_list_type, - unpack_args, - get_base_name, - use_derived, - modifies_arguments, - WRAPPER_REGISTRATION, - TMP_VAR, - METHOD_DEFINITION, - ASSIGN_RETURN_VALUE, - gen_formals, - ALL_VIEW_FUNCTIONS, - unpacked_name, - AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION, +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +from torchgen.api import cpp +from torchgen.api.autograd import ( + DifferentiableInput, + dispatch_strategy, + gen_differentiable_outputs, + is_differentiable, + NativeFunctionWithDifferentiabilityInfo, + SavedAttribute, ) from torchgen.api.types import ( + BaseCType, Binding, DispatcherSignature, - BaseCType, intArrayRefT, - tensorT, - tensorListT, + ListCType, MutRefCType, OptionalCType, - ListCType, - SpecialArgName, scalarT, + SpecialArgName, stringT, + symIntArrayRefT, + tensorListT, + tensorT, TupleCType, VectorCType, ) -from torchgen.api.autograd import ( - DifferentiableInput, - NativeFunctionWithDifferentiabilityInfo, - SavedAttribute, - dispatch_strategy, - gen_differentiable_outputs, - is_differentiable, -) -from torchgen.api import cpp from torchgen.code_template import CodeTemplate from torchgen.context import native_function_manager, with_native_function -from torchgen.utils import mapMaybe, FileManager from torchgen.model import ( Argument, + BaseType, + ListType, NativeFunction, SchemaKind, SelfArgument, TensorOptionsArguments, - BaseType, - ListType, ) -from typing import Callable, List, Optional, Sequence, Tuple, Union, Dict +from torchgen.utils import FileManager, mapMaybe + +from .context import with_native_function_with_differentiability_info +from .gen_inplace_or_view_type import ( + ALL_VIEW_FUNCTIONS, + ASSIGN_RETURN_VALUE, + AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION, + gen_formals, + get_base_name, + get_view_info, + is_tensor_list_type, + is_tensor_type, + METHOD_DEFINITION, + modifies_arguments, + TMP_VAR, + unpack_args, + unpacked_name, + use_derived, + WRAPPER_REGISTRATION, +) +from .gen_trace_type import ( + declare_returned_variables, + get_return_value, + MANUAL_AUTOGRAD_AND_TRACER, + MANUAL_BACKEND, + tie_return_values, + type_wrapper_name, +) # We don't set or modify grad_fn on these methods. Generally, they return # tensors that have requires_grad=False. In-place functions listed here will # not examine or modify requires_grad or grad_fn. +# NB: this does NOT include overload name DONT_REQUIRE_DERIVATIVE = { # These only depend on the input Tensor's shape and device, not the data + "empty_like", "ones_like", + "full_like", "zeros_like", "rand_like", "randn_like", + "new_empty", + "new_empty_strided", + "new_full", + "new_zeros", + "new_ones", # These are only implemented on integral types "__and__", "__iand__", @@ -134,6 +145,7 @@ "isinf", "signbit", "isin", + "allclose", # Functions return none are not differentiable "record_stream", # These functions are not differentiable @@ -149,6 +161,7 @@ # but will not error out. # C -> C, R -> C functions for which backward is correctly implemented and tested GRADIENT_IMPLEMENTED_FOR_COMPLEX = { + "fill", "t", "view", "reshape", @@ -238,6 +251,8 @@ "exp", "nonzero", "mean", + "std_mean", + "var_mean", "inverse", "solve", "linalg_cholesky", @@ -283,6 +298,7 @@ "replication_pad2d", "replication_pad3d", "take", + "put", "put_", "_to_copy", "replication_pad1d_backward", @@ -322,7 +338,7 @@ "conj_physical_", "_neg_view", "_reshape_alias", - "_det_lu_based_helper", + "_linalg_det", "lu_solve", "linalg_solve_triangular", "linalg_pinv", @@ -338,7 +354,8 @@ "pixel_shuffle", "pixel_unshuffle", "linalg_lu_solve", - "_linalg_solve", + "_linalg_slogdet", + "_linalg_solve_ex", } GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { @@ -511,6 +528,8 @@ "dequantize_self", # lift() should never actually be called with a requires_grad=True tensor, "lift", + "lift_fresh", + "lift_fresh_copy", # Nested Tensors related functions # _nested_tensor_size() should never actually be called with requires_grad=True tensor "_nested_tensor_size", @@ -636,7 +655,7 @@ FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate( """\ -if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined()) { +if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) { // The hardcoded 0 here will need to be updated once we support multiple levels. ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); } @@ -645,7 +664,8 @@ FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate( """\ -if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()) { +if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined() + && ${out_arg}.defined()) { ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false); } """ @@ -657,7 +677,7 @@ auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value(); TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size()); for (auto i=0; i<${out_arg}.size(); ++i) { - if (${out_arg}_new_fw_grad[i].defined()) { + if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) { // The hardcoded 0 here will need to be updated once we support multiple levels. ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace}); } @@ -1080,6 +1100,8 @@ def save_variables( name += "_" elif type == BaseCType(intArrayRefT): expr = expr + ".vec()" + elif type == BaseCType(symIntArrayRefT): + expr = expr + ".vec()" elif type == BaseCType(stringT): expr = f"std::string({expr})" elif type == OptionalCType(BaseCType(stringT)): @@ -1404,6 +1426,8 @@ def emit_fw_derivatives() -> List[str]: else: is_inplace_str = "false" + requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names) + if all( (isinstance(var_type, BaseType) and var_type.is_tensor_like()) for var_type in derivative.var_types @@ -1416,6 +1440,7 @@ def emit_fw_derivatives() -> List[str]: out_arg=res[0], is_inplace=is_inplace_str ) ) + requires_fw_grad += f" && ({derivative.var_names[0]}.defined())" else: tuple_type = TupleCType( [BaseCType(tensorT)] * len(derivative.var_types) @@ -1453,9 +1478,7 @@ def emit_fw_derivatives() -> List[str]: content.append( FW_DERIVATIVE_TEMPLATE.substitute( fw_grad_opt_definition=fw_grad_opt_definition, - requires_fw_grad=get_any_has_forward_grad_name( - derivative.var_names - ), + requires_fw_grad=requires_fw_grad, formula=derivative.formula, out_arg="_".join(res), unpacked_arguments=unpacked_arguments, diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index bc6652ef93b85..7bf43cfb3c992 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -2,50 +2,49 @@ # # Each autograd function is represented by `DifferentiabilityInfo` containing # a list of `Derivative`. See `torchgen.api.autograd` for the data models. -from collections import defaultdict import re -from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional +from collections import defaultdict +from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple + import yaml +from torchgen.api import cpp from torchgen.api.autograd import ( Derivative, DifferentiabilityInfo, - SavedAttribute, ForwardDerivative, + SavedAttribute, ) from torchgen.api.types import ( + BaseCType, Binding, + boolT, CppSignatureGroup, - NamedCType, - BaseCType, - VectorCType, intArrayRefT, - tensorOptionsT, - typeAndSizeT, - longT, - boolT, layoutT, - tensorGeometryT, + longT, + NamedCType, + OptionalCType, scalarTypeT, SpecialArgName, - OptionalCType, stringT, -) -from torchgen.api import cpp -from torchgen.gen import ( - parse_native_yaml, - get_grouped_by_view_native_functions, + symIntArrayRefT, + tensorGeometryT, + tensorOptionsT, + typeAndSizeT, + VectorCType, ) from torchgen.context import with_native_function +from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml from torchgen.model import ( FunctionSchema, NativeFunction, - Variant, - Type, NativeFunctionsViewGroup, OperatorName, + Type, + Variant, ) -from torchgen.utils import IDENT_REGEX, split_name_params, YamlLoader, concatMap +from torchgen.utils import concatMap, IDENT_REGEX, split_name_params, YamlLoader _GLOBAL_LOAD_DERIVATIVE_CACHE = {} @@ -698,6 +697,14 @@ def stride_expr(name: str) -> str: "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)), }, ), + # replace self.sym_sizes() with self_sym_sizes + ( + r"{}.sym_sizes\(\)", + { + "suffix": "_sym_sizes", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + }, + ), # replace self->sizes() with self_sizes_opt ( r"{}->sizes\(\)", diff --git a/tools/autograd/templates/python_enum_tag.cpp b/tools/autograd/templates/python_enum_tag.cpp index cec5ffabd1c7a..0d86c52d752f9 100644 --- a/tools/autograd/templates/python_enum_tag.cpp +++ b/tools/autograd/templates/python_enum_tag.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/tools/autograd/templates/python_functions.cpp b/tools/autograd/templates/python_functions.cpp index 3be7a01a48401..57343a53ea982 100644 --- a/tools/autograd/templates/python_functions.cpp +++ b/tools/autograd/templates/python_functions.cpp @@ -5,11 +5,14 @@ #include #include +#include #include "torch/csrc/autograd/generated/Functions.h" #include "torch/csrc/autograd/python_cpp_function.h" #include #include +#include #include +#include // NOTE: See [Sharded File] comment in VariableType diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index fdbecf062b4b1..13f14bff75b87 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg) END_HANDLE_TH_ERRORS } -// TODO: FIXME This should be super temprorary until we fix the XLA issue. -static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "sym_size(int64_t dim)", - "sym_size()", - "sym_size(Dimname dim)", - }); - auto& self_ = THPVariable_Unpack(self); - ParsedArgs<3> parsed_args; - auto r = parser.parse(self, args, kwargs, parsed_args); - - if(r.has_torch_function()){ - return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); - } - if (r.idx == 0) { - if (jit::tracer::isTracing()) { - // will error out if a tensor has symints - return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0))); - } else { - return torch::toPyObject(self_.sym_size(r.toInt64(0))); - } - } else if (r.idx == 1) { - return THPSize_NewFromSymSizes(self_); - } - else if (r.idx == 2) { - if (jit::tracer::isTracing()) { - TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT"); - } - return wrap(self_.size(r.dimname(0))); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - - static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS @@ -152,14 +115,10 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa // will error out if a tensor has symints return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0))); } else { - return wrap(self_.size(r.toInt64(0))); - //return torch::toPyObject(self_.sym_size(r.toInt64(0))); + return torch::toPyObject(self_.sym_size(r.toInt64(0))); } } else if (r.idx == 1) { - // we can't do the normal wrapping here because IntArrayRef maps to both - // torch.Size and tuple in python. - return THPSize_New(self_); - //return THPSize_NewFromSymSizes(self_); + return THPSize_NewFromSymSizes(self_); } else if (r.idx == 2) { if (jit::tracer::isTracing()) { @@ -1187,7 +1146,7 @@ static PyObject* THPVariable_set_( at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, "Expected a Storage of type ", self.dtype(), - " or an _UntypedStorage, but got type ", storage_scalar_type, + " or an UntypedStorage, but got type ", storage_scalar_type, " for argument 1 'storage'"); auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor { pybind11::gil_scoped_release no_gil; @@ -1203,7 +1162,7 @@ static PyObject* THPVariable_set_( at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, "Expected a Storage of type ", self.dtype(), - " or an _UntypedStorage, but got type ", storage_scalar_type, + " or an UntypedStorage, but got type ", storage_scalar_type, " for argument 1 'storage'"); auto dispatch_set_ = [](const Tensor& self, Storage source, @@ -1322,7 +1281,6 @@ PyMethodDef variable_methods[] = { {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL}, {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL}, {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL}, - {"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL}, {"_storage", THPVariable_storage, METH_NOARGS, NULL}, {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL}, {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/tools/bazel_tools/BUILD.bazel b/tools/bazel_tools/BUILD.bazel new file mode 100644 index 0000000000000..f4c37fb7389bd --- /dev/null +++ b/tools/bazel_tools/BUILD.bazel @@ -0,0 +1,5 @@ +sh_binary( + name = "shellwrap", + srcs = ["shellwrap.sh"], + visibility = ["//visibility:public"], +) diff --git a/tools/bazel_tools/shellwrap.sh b/tools/bazel_tools/shellwrap.sh new file mode 100755 index 0000000000000..1ebab29a6a73c --- /dev/null +++ b/tools/bazel_tools/shellwrap.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# This script is helpful in entering an interactive shell from a bazel build +# before running a given bazel executable. +# This can provide a quick way to explore the sandbox directory and filesystem. +# Typical use is with +# +# bazel run --run_under=//tools/bazel:shell_wrapper //:target +# OR +# bazel run --config=shell //:target + +shell='/bin/bash' +rcfile='/tmp/pytorch_bazel_tools_shellwrap' +while [[ $# -gt 0 ]] ; do + case "$1" in + --shell_bin_path) + # path for the shell executable + shell="$2" + shift 2 + ;; + --rcfile) + # path for the file used to write the environment + rcfile="$2" + shift 2 + ;; + *) + # remaining arguments are part of the command for execution + break + ;; + esac +done + +if ! tty -s; then + echo 'A tty is not available.' + echo "Use \`bazel run\`, not \`bazel test\`." + exit 1 +fi + +NOCOLOR='\033[0m' +YELLOW='\033[1;33m' + +# store the environment in a file +export PYTORCH_SHELL_COMMAND=$* +echo "alias run=\"$*\"" > "$rcfile" +echo "PS1='\s-\v\$ '" >> "$rcfile" + +echo ===== +# print the execution command (command is yellow) +echo -e "alias run=${YELLOW}$PYTORCH_SHELL_COMMAND${NOCOLOR}" +echo ===== + +echo "Entering interactive shell at the execution root:" + +# quote escape all the arguments to use as a single input string +cmd="'$shell' --noprofile --rcfile '$rcfile'" + +# run the command in a script psuedo terminal and dump to null +/usr/bin/script -c "$cmd" -q /dev/null diff --git a/tools/build_defs/android/build_mode_defs.bzl b/tools/build_defs/android/build_mode_defs.bzl new file mode 100644 index 0000000000000..2bf307f09c418 --- /dev/null +++ b/tools/build_defs/android/build_mode_defs.bzl @@ -0,0 +1,5 @@ +# @lint-ignore-every BUCKRESTRICTEDSYNTAX +def is_production_build(): + if read_config("pt", "is_oss", "0") == "0": + fail("This file is for open source pytorch build. Do not use it in fbsource!") + return False diff --git a/tools/build_defs/apple/build_mode_defs.bzl b/tools/build_defs/apple/build_mode_defs.bzl new file mode 100644 index 0000000000000..2bf307f09c418 --- /dev/null +++ b/tools/build_defs/apple/build_mode_defs.bzl @@ -0,0 +1,5 @@ +# @lint-ignore-every BUCKRESTRICTEDSYNTAX +def is_production_build(): + if read_config("pt", "is_oss", "0") == "0": + fail("This file is for open source pytorch build. Do not use it in fbsource!") + return False diff --git a/tools/build_defs/buck_helpers.bzl b/tools/build_defs/buck_helpers.bzl index c946f87eba065..a084e01eff93c 100644 --- a/tools/build_defs/buck_helpers.bzl +++ b/tools/build_defs/buck_helpers.bzl @@ -11,6 +11,7 @@ IGNORED_ATTRIBUTE_PREFIX = [ IGNORED_ATTRIBUTES = [ "feature", "platforms", + "contacts", ] def filter_attributes(kwgs): @@ -25,49 +26,3 @@ def filter_attributes(kwgs): if key.startswith(invalid_prefix): kwgs.pop(key) return kwgs - -# maps known fbsource deps to OSS deps -DEPS_MAP = { - "//third-party/FP16:FP16": "//third_party:FP16", - "//third-party/FXdiv:FXdiv": "//third_party:FXdiv", - "//third-party/XNNPACK:XNNPACK": "//third_party:XNNPACK", - "//third-party/clog:clog": "//third_party:clog", - "//third-party/cpuinfo:cpuinfo": "//third_party:cpuinfo", - "//third-party/fmt:fmt": "//third_party:fmt", - "//third-party/glog:glog": "//third_party:glog", - "//third-party/psimd:psimd": "//third_party:psimd", - "//third-party/pthreadpool:pthreadpool": "//third_party:pthreadpool", - "//third-party/pthreadpool:pthreadpool_header": "//third_party:pthreadpool_header", - "//third-party/ruy:ruy_xplat_lib": "//third_party:ruy_lib", -} - -# map fbsource deps to OSS deps -def to_oss_deps(deps = []): - new_deps = [] - for dep in deps: - new_deps += map_deps(dep) - return new_deps - -def map_deps(dep): - # remove @fbsource prefix - if dep.startswith("@fbsource"): - dep = dep[len("@fbsource"):] - - # ignore all fbsource linker_lib targets - if dep.startswith("//xplat/third-party/linker_lib"): - return [] - - # map targets in caffe2 root folder. Just use relative path - if dep.startswith("//xplat/caffe2:"): - return [dep[len("//xplat/caffe2"):]] - - # map targets in caffe2 subfolders - if dep.startswith("//xplat/caffe2/"): - return ["//" + dep[len("//xplat/caffe2/"):]] - - # map other known targets - if dep in DEPS_MAP: - return DEPS_MAP[dep] - - # drop other unknown deps - return [] diff --git a/tools/build_defs/fb_xplat_cxx_library.bzl b/tools/build_defs/fb_xplat_cxx_library.bzl index 7859836062135..025a3bc4d229b 100644 --- a/tools/build_defs/fb_xplat_cxx_library.bzl +++ b/tools/build_defs/fb_xplat_cxx_library.bzl @@ -1,19 +1,19 @@ # Only used for PyTorch open source BUCK build # @lint-ignore-every BUCKRESTRICTEDSYNTAX -load(":buck_helpers.bzl", "filter_attributes", "to_oss_deps") +load( + ":buck_helpers.bzl", + "filter_attributes", +) def fb_xplat_cxx_library( name, - deps = [], - exported_deps = [], + extra_flags = {}, **kwgs): if read_config("pt", "is_oss", "0") == "0": fail("This file is for open source pytorch build. Do not use it in fbsource!") cxx_library( name = name, - deps = to_oss_deps(deps), - exported_deps = to_oss_deps(exported_deps), **filter_attributes(kwgs) ) diff --git a/tools/build_defs/fb_xplat_cxx_test.bzl b/tools/build_defs/fb_xplat_cxx_test.bzl new file mode 100644 index 0000000000000..c06176630e08d --- /dev/null +++ b/tools/build_defs/fb_xplat_cxx_test.bzl @@ -0,0 +1,18 @@ +# Only used for PyTorch open source BUCK build +# @lint-ignore-every BUCKRESTRICTEDSYNTAX +load(":buck_helpers.bzl", "filter_attributes") + +def fb_xplat_cxx_test( + name, + deps = [], + **kwgs): + if read_config("pt", "is_oss", "0") == "0": + fail("This file is for open source pytorch build. Do not use it in fbsource!") + + cxx_test( + name = name, + deps = deps + [ + "//third_party:gtest", + ], + **filter_attributes(kwgs) + ) diff --git a/tools/build_defs/fb_xplat_genrule.bzl b/tools/build_defs/fb_xplat_genrule.bzl index 8d61967cdb416..f1e09617888cf 100644 --- a/tools/build_defs/fb_xplat_genrule.bzl +++ b/tools/build_defs/fb_xplat_genrule.bzl @@ -1,7 +1,7 @@ # Only used for PyTorch open source BUCK build # @lint-ignore-every BUCKRESTRICTEDSYNTAX -def fb_xplat_genrule(default_outs = ["."], **kwargs): +def fb_xplat_genrule(default_outs = ["."], apple_sdks = None, **kwargs): if read_config("pt", "is_oss", "0") == "0": fail("This file is for open source pytorch build. Do not use it in fbsource!") diff --git a/tools/build_defs/fb_python_binary.bzl b/tools/build_defs/select.bzl similarity index 64% rename from tools/build_defs/fb_python_binary.bzl rename to tools/build_defs/select.bzl index f403b4bbbda6c..36c86e93cb8db 100644 --- a/tools/build_defs/fb_python_binary.bzl +++ b/tools/build_defs/select.bzl @@ -1,10 +1,8 @@ # Only used for PyTorch open source BUCK build # @lint-ignore-every BUCKRESTRICTEDSYNTAX -load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") - -def fb_python_binary(**kwgs): +def select(conditions): if read_config("pt", "is_oss", "0") == "0": fail("This file is for open source pytorch build. Do not use it in fbsource!") - python_binary(**kwgs) + return conditions["DEFAULT"] diff --git a/tools/build_defs/windows/windows_flag_map.bzl b/tools/build_defs/windows/windows_flag_map.bzl new file mode 100644 index 0000000000000..e25d9897e076b --- /dev/null +++ b/tools/build_defs/windows/windows_flag_map.bzl @@ -0,0 +1,9 @@ +# Only used for PyTorch open source BUCK build +# @lint-ignore-every BUCKRESTRICTEDSYNTAX + +def windows_convert_gcc_clang_flags(flags = []): + if read_config("pt", "is_oss", "0") == "0": + fail("This file is for open source pytorch build. Do not use it in fbsource!") + + # not implemented + return [] diff --git a/tools/build_libtorch.py b/tools/build_libtorch.py index c5508773f643b..3b85d415d0273 100644 --- a/tools/build_libtorch.py +++ b/tools/build_libtorch.py @@ -1,6 +1,6 @@ import argparse -from os.path import dirname, abspath import sys +from os.path import abspath, dirname # By appending pytorch_root to sys.path, this module can import other torch # modules even when run as a standalone script. i.e., it's okay either you diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index eba8ea1dcf66d..dbdf34eeda6e5 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -1,14 +1,15 @@ import os import platform -from glob import glob import shutil +from glob import glob from typing import Dict, Optional -from .setup_helpers.env import IS_64BIT, IS_WINDOWS, check_negative_env_flag -from .setup_helpers.cmake import USE_NINJA, CMake - from setuptools import distutils # type: ignore[import] +from .setup_helpers.cmake import CMake, USE_NINJA + +from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS + def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: vc_arch = "x64" if IS_64BIT else "x86" diff --git a/tools/code_analyzer/gen_op_registration_allowlist.py b/tools/code_analyzer/gen_op_registration_allowlist.py index 65e56856a7893..b01142c872f1c 100644 --- a/tools/code_analyzer/gen_op_registration_allowlist.py +++ b/tools/code_analyzer/gen_op_registration_allowlist.py @@ -9,11 +9,12 @@ """ import argparse -import yaml from collections import defaultdict from typing import Dict, List, Set +import yaml + DepGraph = Dict[str, Set[str]] diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index 0daa27f0480e4..58b8763c142c7 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -2,7 +2,7 @@ import argparse import json import sys -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional import yaml from gen_op_registration_allowlist import ( @@ -11,8 +11,8 @@ load_op_dep_graph, ) from torchgen.selective_build.operator import ( - SelectiveBuildOperator, merge_operator_dicts, + SelectiveBuildOperator, ) from torchgen.selective_build.selector import merge_kernel_metadata diff --git a/tools/code_analyzer/gen_oplist.py b/tools/code_analyzer/gen_oplist.py index b5d31b9221674..1e5d1277afcdf 100644 --- a/tools/code_analyzer/gen_oplist.py +++ b/tools/code_analyzer/gen_oplist.py @@ -4,16 +4,16 @@ import os import sys from functools import reduce -from typing import Set, List, Any +from typing import Any, List, Set import yaml +from tools.lite_interpreter.gen_selected_mobile_ops_header import ( + write_selected_mobile_ops, +) from torchgen.selective_build.selector import ( combine_selective_builders, SelectiveBuilder, ) -from tools.lite_interpreter.gen_selected_mobile_ops_header import ( - write_selected_mobile_ops, -) def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]: diff --git a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py index a64670b6ada3b..e6d0786a32b05 100644 --- a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py +++ b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py @@ -8,19 +8,20 @@ marked as covered. """ -from coverage import CoveragePlugin, CoverageData # type: ignore[import] from inspect import ( - ismodule, - isclass, - ismethod, - isfunction, - iscode, getsourcefile, getsourcelines, + isclass, + iscode, + isfunction, + ismethod, + ismodule, ) from time import time from typing import Any +from coverage import CoverageData, CoveragePlugin # type: ignore[import] + # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link: # https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine diff --git a/tools/download_mnist.py b/tools/download_mnist.py index 80894ad2bdbba..52fa411eda9f8 100644 --- a/tools/download_mnist.py +++ b/tools/download_mnist.py @@ -1,9 +1,9 @@ import argparse import gzip import os +import sys from urllib.error import URLError from urllib.request import urlretrieve -import sys MIRRORS = [ "http://yann.lecun.com/exdb/mnist/", diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py index 0a1ae07c23429..3b79e4f0eac40 100755 --- a/tools/fast_nvcc/fast_nvcc.py +++ b/tools/fast_nvcc/fast_nvcc.py @@ -14,7 +14,7 @@ import subprocess import sys import time -from typing import Awaitable, DefaultDict, Dict, List, Match, Optional, Set, cast +from typing import Awaitable, cast, DefaultDict, Dict, List, Match, Optional, Set from typing_extensions import TypedDict diff --git a/tools/gdb/pytorch-gdb.py b/tools/gdb/pytorch-gdb.py index 0ed516078f769..a955b45967989 100644 --- a/tools/gdb/pytorch-gdb.py +++ b/tools/gdb/pytorch-gdb.py @@ -1,7 +1,8 @@ -import gdb # type: ignore[import] import textwrap from typing import Any +import gdb # type: ignore[import] + class DisableBreakpoints: """ diff --git a/aten/src/ATen/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py similarity index 71% rename from aten/src/ATen/gen_vulkan_spv.py rename to tools/gen_vulkan_spv.py index 0d0906ded60e7..74b1212bdbe26 100644 --- a/aten/src/ATen/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -4,6 +4,7 @@ import array import glob import os +import re import sys import subprocess from torchgen.code_template import CodeTemplate @@ -15,6 +16,33 @@ def getName(filePath): return os.path.basename(filePath).replace("/", "_").replace(".", "_") +def isDescriptorLine(lineStr): + descriptorLineId = r"^layout\(set" + return re.search(descriptorLineId, lineStr) + +typeIdMapping = { + r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", + r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", + r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + r"\buniform\b.*\bBlock\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER", +} + +def determineDescriptorType(lineStr): + for identifier, typeNum in typeIdMapping.items(): + if re.search(identifier, lineStr): + return typeNum + + raise Exception("Could not identify descriptor type of line: {}".format(lineStr)) + +def getLayout(srcFilePath): + layout = [] + with open(srcFilePath, 'r') as srcFile: + for line in srcFile: + if isDescriptorLine(line): + layout.append(determineDescriptorType(line)) + + return layout + def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format( hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath)) @@ -27,7 +55,7 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): templateSrcPaths.sort() print("templateSrcPaths:{}".format(templateSrcPaths)) - spvPaths = [] + spvPaths = {} for templateSrcPath in templateSrcPaths: print("templateSrcPath {}".format(templateSrcPath)) name = getName(templateSrcPath).replace("_glsl", "") @@ -52,24 +80,31 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): print("\nglslc cmd:", cmd) subprocess.check_call(cmd) - spvPaths.append(spvPath) + spvPaths[spvPath] = templateSrcPath h = "#pragma once\n" h += "#include \n" - nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n" - nsend = "\n} } } //namespace at::native::vulkan\n" + h += "#include \n" + h += "#include " + + nsbegin = "\nnamespace at {\nnamespace native {\nnamespace vulkan {\n" + nsend = "\n}\n}\n} //namespace at::native::vulkan\n" h += nsbegin cpp = "#include ".format(H_NAME) cpp += nsbegin - for spvPath in spvPaths: + for spvPath, srcPath in spvPaths.items(): name = getName(spvPath) name_len = name + "_len" h += "extern const uint32_t {}[];\n".format(name) h += "extern const uint32_t {};\n".format(name_len) + layout = getLayout(srcPath) + name_layout = name + "_layout" + h += "extern const std::vector {};\n".format(name_layout) + cpp += "const uint32_t " + name + "[] = {\n" sizeBytes = 0 print("spvPath:{}".format(spvPath)) @@ -80,6 +115,12 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): cpp += "};\n" cpp += "const uint32_t {} = {};\n".format(name_len, sizeBytes) + # Add layout + cpp += "const std::vector {} = {{\n".format(name_layout) + for descriptor in layout: + cpp += " {},\n".format(descriptor) + cpp += "};\n" + cpp += nsend h += nsend diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index 792e8aef494b4..96970bd2b1c35 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -3,9 +3,10 @@ import re import subprocess from pathlib import Path -from setuptools import distutils # type: ignore[import] from typing import Optional, Union +from setuptools import distutils # type: ignore[import] + UNKNOWN = "Unknown" RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/") diff --git a/tools/iwyu/fixup.py b/tools/iwyu/fixup.py index 4ce80bb0f52bc..2a585762273be 100644 --- a/tools/iwyu/fixup.py +++ b/tools/iwyu/fixup.py @@ -1,5 +1,5 @@ -import sys import re +import sys QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"') ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>") diff --git a/tools/jit/BUCK.oss b/tools/jit/BUCK.oss deleted file mode 100644 index 8c0105f1cf8e7..0000000000000 --- a/tools/jit/BUCK.oss +++ /dev/null @@ -1,12 +0,0 @@ -python_library( - name = "jit", - srcs = glob([ - "*.py", - "templates/*", - ]), - base_module = "tools.jit", - visibility = ["PUBLIC"], - deps = [ - "//torchgen:torchgen", - ], -) diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py index 4973f673a451d..ebeaa21bc7be9 100644 --- a/tools/jit/gen_unboxing.py +++ b/tools/jit/gen_unboxing.py @@ -4,21 +4,20 @@ import pathlib import sys from dataclasses import dataclass -from typing import Union, Sequence, List +from typing import List, Sequence, Union import yaml -from typing_extensions import Literal -from torchgen.api import cpp -from torchgen.api import unboxing +from torchgen.api import cpp, unboxing from torchgen.api.translate import translate from torchgen.api.types import CppSignatureGroup from torchgen.api.unboxing import convert_arguments from torchgen.context import method_with_native_function -from torchgen.gen import parse_native_yaml, cpp_string, get_custom_build_selector -from torchgen.model import NativeFunction, NativeFunctionsGroup, Variant, Argument +from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml +from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import Target, FileManager, mapMaybe, make_file_manager +from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target +from typing_extensions import Literal # Generates UnboxingFunctions.h & UnboxingFunctions.cpp. @@ -159,6 +158,9 @@ def gen_unboxing( def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: return fn.root_name + selected_op_num: int = len(selector.operators) + # a best practice threshold of operators to enable sharding + sharding_threshold: int = 100 cpu_fm.write_sharded( "UnboxingFunctions.cpp", native_functions, @@ -166,7 +168,7 @@ def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: env_callable=lambda fn: { "definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)] }, - num_shards=5, + num_shards=1 if selected_op_num < sharding_threshold else 5, sharded_keys={"definitions"}, ) cpu_fm.write( @@ -187,7 +189,7 @@ def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: env_callable=lambda fn: { "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)] }, - num_shards=10, + num_shards=1 if selected_op_num < sharding_threshold else 10, sharded_keys={"unboxed_ops"}, ) @@ -246,7 +248,7 @@ def main(args: List[str]) -> None: op_registration_allowlist = None selector = get_custom_build_selector( - options.op_registration_allowlist, + op_registration_allowlist, options.op_selection_yaml_path, ) diff --git a/tools/jit/test/test_gen_unboxing.py b/tools/jit/test/test_gen_unboxing.py index 70a37add4fd8c..de016b1642229 100644 --- a/tools/jit/test/test_gen_unboxing.py +++ b/tools/jit/test/test_gen_unboxing.py @@ -1,6 +1,6 @@ import tempfile import unittest -from unittest.mock import patch, NonCallableMock +from unittest.mock import NonCallableMock, patch import tools.jit.gen_unboxing as gen_unboxing diff --git a/tools/linter/adapters/actionlint_linter.py b/tools/linter/adapters/actionlint_linter.py index bbc93954eda4d..d9131b37ec007 100644 --- a/tools/linter/adapters/actionlint_linter.py +++ b/tools/linter/adapters/actionlint_linter.py @@ -1,8 +1,9 @@ import argparse -import os -import re +import concurrent.futures import json import logging +import os +import re import subprocess import time from enum import Enum @@ -60,12 +61,12 @@ def run_command( logging.debug("took %dms", (end_time - start_time) * 1000) -def check_files( +def check_file( binary: str, - files: List[str], + file: str, ) -> List[LintMessage]: try: - proc = run_command([binary] + files) + proc = run_command([binary, file]) except OSError as err: return [ LintMessage( @@ -133,6 +134,22 @@ def check_files( print(json.dumps(err_msg._asdict()), flush=True) exit(0) - lint_messages = check_files(args.binary, args.filenames) - for lint_message in lint_messages: - print(json.dumps(lint_message._asdict()), flush=True) + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit( + check_file, + args.binary, + filename, + ): filename + for filename in args.filenames + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise diff --git a/tools/linter/adapters/black_linter.py b/tools/linter/adapters/black_linter.py index 9d259fe096b84..8459b6a1e1427 100644 --- a/tools/linter/adapters/black_linter.py +++ b/tools/linter/adapters/black_linter.py @@ -7,7 +7,7 @@ import sys import time from enum import Enum -from typing import Any, List, NamedTuple, Optional, BinaryIO +from typing import Any, BinaryIO, List, NamedTuple, Optional IS_WINDOWS: bool = os.name == "nt" diff --git a/tools/linter/adapters/circleci_linter.py b/tools/linter/adapters/circleci_linter.py index 8a76ed396f9fd..6200b383ee351 100644 --- a/tools/linter/adapters/circleci_linter.py +++ b/tools/linter/adapters/circleci_linter.py @@ -2,15 +2,15 @@ Checks that the configuration in .circleci/config.yml has been properly regenerated. """ -import os import argparse +import json +import logging +import os import subprocess import sys -import logging import time from enum import Enum from typing import List, NamedTuple, Optional -import json CHECKED_IN_FILE = "config.yml" diff --git a/tools/linter/adapters/flake8_linter.py b/tools/linter/adapters/flake8_linter.py index 20274432566c9..26f8dd8eec3ff 100644 --- a/tools/linter/adapters/flake8_linter.py +++ b/tools/linter/adapters/flake8_linter.py @@ -7,7 +7,7 @@ import sys import time from enum import Enum -from typing import Any, Dict, List, NamedTuple, Optional, Set, Pattern +from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set IS_WINDOWS: bool = os.name == "nt" diff --git a/tools/linter/adapters/grep_linter.py b/tools/linter/adapters/grep_linter.py index 61a81ad12dc3a..f6bd714eb4a7f 100644 --- a/tools/linter/adapters/grep_linter.py +++ b/tools/linter/adapters/grep_linter.py @@ -61,16 +61,51 @@ def run_command( def lint_file( matching_line: str, + allowlist_pattern: str, replace_pattern: str, linter_name: str, error_name: str, error_description: str, -) -> LintMessage: +) -> Optional[LintMessage]: # matching_line looks like: # tools/linter/clangtidy_linter.py:13:import foo.bar.baz split = matching_line.split(":") filename = split[0] + if allowlist_pattern: + try: + proc = run_command(["grep", "-nEHI", allowlist_pattern, filename]) + except Exception as err: + return LintMessage( + path=None, + line=None, + char=None, + code=linter_name, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + + # allowlist pattern was found, abort lint + if proc.returncode == 0: + return None + original = None replacement = None if replace_pattern: @@ -109,7 +144,7 @@ def lint_file( return LintMessage( path=split[0], - line=int(split[1]), + line=int(split[1]) if len(split) > 1 else None, char=None, code=linter_name, severity=LintSeverity.ERROR, @@ -130,11 +165,20 @@ def main() -> None: required=True, help="pattern to grep for", ) + parser.add_argument( + "--allowlist-pattern", + help="if this pattern is true in the file, we don't grep for pattern", + ) parser.add_argument( "--linter-name", required=True, help="name of the linter", ) + parser.add_argument( + "--match-first-only", + action="store_true", + help="only match the first hit in the file", + ) parser.add_argument( "--error-name", required=True, @@ -174,8 +218,14 @@ def main() -> None: stream=sys.stderr, ) + files_with_matches = [] + if args.match_first_only: + files_with_matches = ["--files-with-matches"] + try: - proc = run_command(["grep", "-nEHI", args.pattern, *args.filenames]) + proc = run_command( + ["grep", "-nEHI", *files_with_matches, args.pattern, *args.filenames] + ) except Exception as err: err_msg = LintMessage( path=None, @@ -209,12 +259,14 @@ def main() -> None: for line in lines: lint_message = lint_file( line, + args.allowlist_pattern, args.replace_pattern, args.linter_name, args.error_name, args.error_description, ) - print(json.dumps(lint_message._asdict()), flush=True) + if lint_message is not None: + print(json.dumps(lint_message._asdict()), flush=True) if __name__ == "__main__": diff --git a/tools/linter/adapters/nativefunctions_linter.py b/tools/linter/adapters/nativefunctions_linter.py index 28065f2b7af46..12a6c7e0062df 100644 --- a/tools/linter/adapters/nativefunctions_linter.py +++ b/tools/linter/adapters/nativefunctions_linter.py @@ -14,14 +14,15 @@ the YAML, not to be prescriptive about it. """ -import ruamel.yaml # type: ignore[import] import argparse import json import sys -from io import StringIO from enum import Enum +from io import StringIO from typing import NamedTuple, Optional +import ruamel.yaml # type: ignore[import] + class LintSeverity(str, Enum): ERROR = "error" diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index d67c38ce9f5ad..f921bcb73e331 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -1,9 +1,9 @@ """ Initializer script that installs stuff to pip. """ -import os import argparse import logging +import os import subprocess import sys import time diff --git a/tools/linter/adapters/s3_init_config.json b/tools/linter/adapters/s3_init_config.json index 736ab6addb842..0b0e87e8e26cf 100644 --- a/tools/linter/adapters/s3_init_config.json +++ b/tools/linter/adapters/s3_init_config.json @@ -1,4 +1,10 @@ { + "HOW TO UPDATE THE BINARIES": [ + "Upload the new file to S3 under a new folder with the version number embedded in (see actionlint for an example).", + "(Don't override the old files, otherwise you'll break `lintrunner install` for anyone using an older commit of pytorch.)", + "'Hash' is the sha256 of the uploaded file.", + "Validate the new download url and hash by running 'lintrunner init' to pull the new binaries and then run 'lintrunner' to try linting the files." + ], "clang-format": { "Darwin": { "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/mac/clang-format-mojave", @@ -21,12 +27,12 @@ }, "actionlint": { "Darwin": { - "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/macos/actionlint", - "hash": "3ce2c94280c540e20b270acae60bdd9e72ad17d6cb35b688951b1ec1eb8cbdd6" + "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.15/Darwin_amd64/actionlint", + "hash": "e9a0e0b17e54cfefe7964b6aa1da8921b1f8f2318c31c0eb1a17ea3e8ab10db2" }, "Linux": { - "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/actionlint", - "hash": "693f464106474760f0edf4a1778215095eacc4bd5f79aab5dc950892f120828b" + "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.15/Linux_arm64/actionlint", + "hash": "d6b45ae67f29a2bf9ddd226071ddd8f158fdf2992e8515a06838e5fef90f3a2d" } } } diff --git a/tools/linter/adapters/shellcheck_linter.py b/tools/linter/adapters/shellcheck_linter.py index d94c5a1ce0478..025595d39f29e 100644 --- a/tools/linter/adapters/shellcheck_linter.py +++ b/tools/linter/adapters/shellcheck_linter.py @@ -1,9 +1,9 @@ import argparse import json import logging +import shutil import subprocess import time -import shutil from enum import Enum from typing import List, NamedTuple, Optional diff --git a/tools/linter/adapters/testowners_linter.py b/tools/linter/adapters/testowners_linter.py index b65cfde4d79dc..dfd5172a39e10 100755 --- a/tools/linter/adapters/testowners_linter.py +++ b/tools/linter/adapters/testowners_linter.py @@ -8,10 +8,10 @@ - Each owner label actually exists in PyTorch - Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS """ -import json import argparse +import json from enum import Enum -from typing import List, Any, Optional, NamedTuple +from typing import Any, List, NamedTuple, Optional from urllib.request import urlopen diff --git a/tools/linter/adapters/ufmt_linter.py b/tools/linter/adapters/ufmt_linter.py new file mode 100644 index 0000000000000..7174e83212990 --- /dev/null +++ b/tools/linter/adapters/ufmt_linter.py @@ -0,0 +1,141 @@ +import argparse +import concurrent.futures +import json +import logging +import os +import sys +from enum import Enum +from pathlib import Path +from typing import Any, List, NamedTuple, Optional + +from ufmt.core import make_black_config, ufmt_string +from usort import Config as UsortConfig + + +IS_WINDOWS: bool = os.name == "nt" + + +def eprint(*args: Any, **kwargs: Any) -> None: + print(*args, file=sys.stderr, flush=True, **kwargs) + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name + + +def format_error_message(filename: str, err: Exception) -> LintMessage: + return LintMessage( + path=filename, + line=None, + char=None, + code="UFMT", + severity=LintSeverity.ADVICE, + name="command-failed", + original=None, + replacement=None, + description=(f"Failed due to {err.__class__.__name__}:\n{err}"), + ) + + +def check_file( + filename: str, +) -> List[LintMessage]: + with open(filename, "rb") as f: + original = f.read().decode("utf-8") + + try: + path = Path(filename) + + usort_config = UsortConfig.find(path) + black_config = make_black_config(path) + + # Use UFMT API to call both usort and black + replacement = ufmt_string( + path=path, + content=original, + usort_config=usort_config, + black_config=black_config, + ) + + if original == replacement: + return [] + + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="UFMT", + severity=LintSeverity.WARNING, + name="format", + original=original, + replacement=replacement, + description="Run `lintrunner -a` to apply this patch.", + ) + ] + except Exception as err: + return [format_error_message(filename, err)] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Format files with ufmt (black + usort).", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = {executor.submit(check_file, x): x for x in args.filenames} + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/tools/linter/adapters/update_s3.py b/tools/linter/adapters/update_s3.py index 847b8b2d158fa..5f19b47255be4 100644 --- a/tools/linter/adapters/update_s3.py +++ b/tools/linter/adapters/update_s3.py @@ -6,11 +6,12 @@ """ import argparse -import boto3 # type: ignore[import] -import json import hashlib -import os +import json import logging +import os + +import boto3 # type: ignore[import] def compute_file_sha256(path: str) -> str: diff --git a/tools/linter/adapters/workflow_consistency_linter.py b/tools/linter/adapters/workflow_consistency_linter.py new file mode 100644 index 0000000000000..6e5fb4db20ff2 --- /dev/null +++ b/tools/linter/adapters/workflow_consistency_linter.py @@ -0,0 +1,115 @@ +"""Checks for consistency of jobs between different GitHub workflows. + +Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`. +""" +import argparse +import itertools +import json +from collections import defaultdict +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Iterable, NamedTuple, Optional + +from yaml import CSafeLoader, dump, load + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +def glob_yamls(path: Path) -> Iterable[Path]: + return itertools.chain(path.glob("**/*.yml"), path.glob("**/*.yaml")) + + +def load_yaml(path: Path) -> Any: + with open(path) as f: + return load(f, CSafeLoader) + + +def is_workflow(yaml: Any) -> bool: + return yaml.get("jobs") is not None + + +def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None: + job_id = list(job.keys())[0] + with open(path) as f: + lines = f.readlines() + for i, line in enumerate(lines): + if f"{job_id}:" in line: + line_number = i + 1 + + lint_message = LintMessage( + path=str(path), + line=line_number, + char=None, + code="WORKFLOWSYNC", + severity=LintSeverity.ERROR, + name="workflow-inconsistency", + original=None, + replacement=None, + description=f"Job doesn't match other jobs with sync-tag: '{sync_tag}'", + ) + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="workflow consistency linter.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + # Go through the provided files, aggregating jobs with the same sync tag + tag_to_jobs = defaultdict(list) + for path in args.filenames: + workflow = load_yaml(Path(path)) + jobs = workflow["jobs"] + for job_id, job in jobs.items(): + try: + sync_tag = job["with"]["sync-tag"] + except KeyError: + continue + + # remove the "if" field, which we allow to be different between jobs + # (since you might have different triggering conditions on pull vs. + # trunk, say.) + if "if" in job: + del job["if"] + + tag_to_jobs[sync_tag].append((path, {job_id: job})) + + # For each sync tag, check that all the jobs have the same code. + for sync_tag, path_and_jobs in tag_to_jobs.items(): + baseline_path, baseline_dict = path_and_jobs.pop() + baseline_str = dump(baseline_dict) + + printed_baseline = False + + for path, job_dict in path_and_jobs: + job_str = dump(job_dict) + if baseline_str != job_str: + print_lint_message(path, job_dict, sync_tag) + + if not printed_baseline: + print_lint_message(baseline_path, baseline_dict, sync_tag) + printed_baseline = True diff --git a/tools/linter/clang_tidy/generate_build_files.py b/tools/linter/clang_tidy/generate_build_files.py index fff8bf492e0fd..35f1b81d89893 100644 --- a/tools/linter/clang_tidy/generate_build_files.py +++ b/tools/linter/clang_tidy/generate_build_files.py @@ -1,6 +1,6 @@ +import os import subprocess import sys -import os from typing import List diff --git a/tools/lite_interpreter/BUCK.oss b/tools/lite_interpreter/BUCK.oss deleted file mode 100644 index 10415c26aee78..0000000000000 --- a/tools/lite_interpreter/BUCK.oss +++ /dev/null @@ -1,6 +0,0 @@ -python_library( - name = "gen_selected_mobile_ops_header", - srcs = ["gen_selected_mobile_ops_header.py"], - base_module = "tools.lite_interpreter", - visibility = ["PUBLIC"], -) diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index 37cd9e6903bf5..aebb36ca156b6 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -2,10 +2,10 @@ import argparse import os from typing import Set -from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.code_template import CodeTemplate import yaml +from torchgen.code_template import CodeTemplate +from torchgen.selective_build.selector import SelectiveBuilder # Safely load fast C Yaml loader/dumper if they are available try: diff --git a/tools/miniz_target_definition.bzl b/tools/miniz_target_definition.bzl index 7040ff6beaa10..49eaa00643b2f 100644 --- a/tools/miniz_target_definition.bzl +++ b/tools/miniz_target_definition.bzl @@ -5,13 +5,13 @@ def add_miniz_lib(): cpp_library( name = "miniz", srcs = [ - "third_party/miniz-2.0.8/fb/FollyCrcPlugin.cpp", - "third_party/miniz-2.0.8/fb/miniz-fb.c", + "third_party/miniz-2.1.0/fb/FollyCrcPlugin.cpp", + "third_party/miniz-2.1.0/fb/miniz-fb.c", ], headers = { - "caffe2/third_party/miniz-2.0.8/miniz.c": "third_party/miniz-2.0.8/miniz.c", - "miniz-fb.h": "third_party/miniz-2.0.8/fb/miniz-fb.h", - "miniz.h": "third_party/miniz-2.0.8/miniz.h", + "caffe2/third_party/miniz-2.1.0/miniz.c": "third_party/miniz-2.1.0/miniz.c", + "miniz-fb.h": "third_party/miniz-2.1.0/fb/miniz-fb.h", + "miniz.h": "third_party/miniz-2.1.0/miniz.h", }, header_namespace = "", # -fexceptions is required, otherwise, when we use @mode/opt-clang-thinlto, diff --git a/tools/nightly.py b/tools/nightly.py index 32733c5d9477e..4d1c9291fd8bb 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -24,25 +24,26 @@ Pulling will reinstalle the conda dependencies as well as the nightly binaries into the repo directory. """ +import contextlib +import datetime +import functools +import glob +import json +import logging import os import re +import shutil +import subprocess import sys -import json -import glob +import tempfile import time import uuid -import shutil -import logging -import datetime -import tempfile -import functools -import contextlib -import subprocess -from ast import literal_eval from argparse import ArgumentParser +from ast import literal_eval from typing import ( Any, Callable, + cast, Dict, Generator, Iterable, @@ -53,7 +54,6 @@ Set, Tuple, TypeVar, - cast, ) LOGGER: Optional[logging.Logger] = None diff --git a/tools/nvcc_fix_deps.py b/tools/nvcc_fix_deps.py index 0f9fe90124ffc..bb420d90308b4 100644 --- a/tools/nvcc_fix_deps.py +++ b/tools/nvcc_fix_deps.py @@ -13,10 +13,10 @@ """ -import sys import subprocess +import sys from pathlib import Path -from typing import List, TextIO, Optional +from typing import List, Optional, TextIO def resolve_include(path: Path, include_dirs: List[Path]) -> Path: diff --git a/tools/perf_kernel_defs.bzl b/tools/perf_kernel_defs.bzl index 2a699840c8bf8..dfefb734b111b 100644 --- a/tools/perf_kernel_defs.bzl +++ b/tools/perf_kernel_defs.bzl @@ -3,7 +3,14 @@ load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library") is_dbg_build = native.read_config("fbcode", "build_mode", "").find("dbg") != -1 is_sanitizer = native.read_config("fbcode", "sanitizer", "") != "" -def define_perf_kernels(prefix, levels_and_flags, compiler_common_flags, dependencies, external_deps): +def define_perf_kernels( + prefix, + levels_and_flags, + compiler_common_flags = [], + arch_compiler_common_flags = {}, + dependencies = [], + arch_dependencies = [], + external_deps = []): vectorize_flags = ([ # "-Rpass=loop-vectorize", # Add vectorization information to output "-DENABLE_VECTORIZATION=1", @@ -30,25 +37,30 @@ def define_perf_kernels(prefix, levels_and_flags, compiler_common_flags, depende ["**/*.h"], ) - kernel_targets = [] - for level, flags in levels_and_flags: - cpp_library( - name = prefix + "perfkernels_" + level, - srcs = native.glob(["**/*_" + level + ".cc"]), - headers = cpp_headers, - compiler_flags = compiler_common_flags + flags, - compiler_specific_flags = compiler_specific_flags, - exported_deps = dependencies, - exported_external_deps = external_deps, - ) - kernel_targets.append(":" + prefix + "perfkernels_" + level) + kernel_targets = {} + for arch, levels_and_flags in levels_and_flags.items(): + for level, flags in levels_and_flags: + cpp_library( + name = prefix + "perfkernels_" + level, + srcs = native.glob(["**/*_" + level + ".cc"]), + headers = cpp_headers, + compiler_flags = compiler_common_flags + flags, + arch_compiler_flags = arch_compiler_common_flags, + compiler_specific_flags = compiler_specific_flags, + exported_deps = dependencies, + exported_arch_deps = arch_dependencies, + exported_external_deps = external_deps, + ) + kernel_targets.setdefault(arch, []).append(":" + prefix + "perfkernels_" + level) cpp_library( name = prefix + "perfkernels", srcs = common_srcs, headers = cpp_headers, compiler_flags = compiler_common_flags, + arch_compiler_flags = arch_compiler_common_flags, compiler_specific_flags = compiler_specific_flags, link_whole = True, - exported_deps = kernel_targets + dependencies, + exported_arch_deps = kernel_targets.items() + arch_dependencies, + exported_deps = dependencies, ) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index eef5279268c40..79f5ee2979b34 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1,21 +1,22 @@ import argparse import collections from pprint import pformat +from typing import Dict, List, Sequence -from torchgen.model import Variant from torchgen.api.python import ( PythonSignatureGroup, PythonSignatureNativeFunctionPair, returns_named_tuple_pyi, ) from torchgen.gen import parse_native_yaml + +from torchgen.model import Variant from torchgen.utils import FileManager -from typing import Sequence, List, Dict from tools.autograd.gen_python_functions import ( - should_generate_py_binding, - load_signatures, group_overloads, + load_signatures, + should_generate_py_binding, ) """ @@ -432,6 +433,12 @@ def gen_pyi( "_is_functional_tensor": [ "def _is_functional_tensor(t: Tensor) -> _bool: ..." ], + "_from_functional_tensor": [ + "def _from_functional_tensor(t: Tensor) -> Tensor: ..." + ], + "_to_functional_tensor": [ + "def _to_functional_tensor(t: Tensor) -> Tensor: ..." + ], "range": [ "def range(start: Number, end: Number," " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format( @@ -618,7 +625,7 @@ def gen_pyi( "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."], "_make_subclass": [ "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False," - " dispatch_device: _bool=False) -> Tensor: ..." + " dispatch_device: _bool=False, device_for_backend_keys: Optional[_device] = None) -> Tensor: ..." ], "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)], "__setitem__": [ @@ -689,8 +696,8 @@ def gen_pyi( "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..." ], "set_": [ - "def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...", - "def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...", + "def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...", + "def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...", ], "split": [ "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...", diff --git a/tools/render_junit.py b/tools/render_junit.py index 68adadde04491..95c281d99d492 100644 --- a/tools/render_junit.py +++ b/tools/render_junit.py @@ -5,7 +5,13 @@ from typing import Any, List, Union try: - from junitparser import JUnitXml, TestSuite, TestCase, Error, Failure # type: ignore[import] + from junitparser import ( # type: ignore[import] + Error, + Failure, + JUnitXml, + TestCase, + TestSuite, + ) except ImportError: raise ImportError( "junitparser not found, please install with 'pip install junitparser'" diff --git a/tools/setup_helpers/BUCK.oss b/tools/setup_helpers/BUCK.oss deleted file mode 100644 index afcd31fb3a03d..0000000000000 --- a/tools/setup_helpers/BUCK.oss +++ /dev/null @@ -1,41 +0,0 @@ -python_library( - name = "generate_code", - srcs = [ - "generate_code.py", - ], - base_module = "tools.setup_helpers", - deps = [ - "//tools/autograd:autograd", - "//tools/jit:jit", - "//torchgen:torchgen", - ], -) - -python_binary( - name = "generate_code_bin", - main_module = "tools.setup_helpers.generate_code", - visibility = ["PUBLIC"], - # package_style = "inplace", - zip_safe = False, - deps = [ - ":generate_code", - ], -) - -python_library( - name = "gen-version-header-lib", - srcs = [ - "gen_version_header.py", - ], - base_module = "tools.setup_helpers", - deps = [], -) - -python_binary( - name = "gen-version-header", - main_module = "tools.setup_helpers.gen_version_header", - visibility = ["PUBLIC"], - deps = [ - ":gen-version-header-lib", - ], -) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index ee314bdc8af99..4f4e22e2e0aed 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -4,16 +4,16 @@ import multiprocessing import os import platform -import re -from subprocess import check_call, check_output, CalledProcessError import sys import sysconfig from distutils.version import LooseVersion -from typing import IO, Any, Dict, List, Optional, Union, cast +from subprocess import CalledProcessError, check_call, check_output +from typing import Any, cast, Dict, List, Optional from . import which -from .env import BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag -from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR +from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file +from .env import BUILD_DIR, check_negative_env_flag, IS_64BIT, IS_DARWIN, IS_WINDOWS +from .numpy_ import NUMPY_INCLUDE_DIR, USE_NUMPY def _mkdir_p(d: str) -> None: @@ -31,84 +31,6 @@ def _mkdir_p(d: str) -> None: USE_NINJA = not check_negative_env_flag("USE_NINJA") and which("ninja") is not None -CMakeValue = Optional[Union[bool, str]] - - -def convert_cmake_value_to_python_value( - cmake_value: str, cmake_type: str -) -> CMakeValue: - r"""Convert a CMake value in a string form to a Python value. - - Args: - cmake_value (string): The CMake value in a string form (e.g., "ON", "OFF", "1"). - cmake_type (string): The CMake type of :attr:`cmake_value`. - - Returns: - A Python value corresponding to :attr:`cmake_value` with type :attr:`cmake_type`. - """ - - cmake_type = cmake_type.upper() - up_val = cmake_value.upper() - if cmake_type == "BOOL": - # https://gitlab.kitware.com/cmake/community/wikis/doc/cmake/VariablesListsStrings#boolean-values-in-cmake - return not ( - up_val in ("FALSE", "OFF", "N", "NO", "0", "", "NOTFOUND") - or up_val.endswith("-NOTFOUND") - ) - elif cmake_type == "FILEPATH": - if up_val.endswith("-NOTFOUND"): - return None - else: - return cmake_value - else: # Directly return the cmake_value. - return cmake_value - - -def get_cmake_cache_variables_from_file( - cmake_cache_file: IO[str], -) -> Dict[str, CMakeValue]: - r"""Gets values in CMakeCache.txt into a dictionary. - - Args: - cmake_cache_file: A CMakeCache.txt file object. - Returns: - dict: A ``dict`` containing the value of cached CMake variables. - """ - - results = dict() - for i, line in enumerate(cmake_cache_file, 1): - line = line.strip() - if not line or line.startswith(("#", "//")): - # Blank or comment line, skip - continue - - # Almost any character can be part of variable name and value. As a practical matter, we assume the type must be - # valid if it were a C variable name. It should match the following kinds of strings: - # - # USE_CUDA:BOOL=ON - # "USE_CUDA":BOOL=ON - # USE_CUDA=ON - # USE_CUDA:=ON - # Intel(R) MKL-DNN_SOURCE_DIR:STATIC=/path/to/pytorch/third_party/ideep/mkl-dnn - # "OpenMP_COMPILE_RESULT_CXX_openmp:experimental":INTERNAL=FALSE - matched = re.match( - r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line - ) - if matched is None: # Illegal line - raise ValueError( - "Unexpected line {} in {}: {}".format(i, repr(cmake_cache_file), line) - ) - _, variable, type_, value = matched.groups() - if type_ is None: - type_ = "" - if type_.upper() in ("INTERNAL", "STATIC"): - # CMake internal variable, do not touch - continue - results[variable] = convert_cmake_value_to_python_value(value, type_) - - return results - - class CMake: "Manages cmake." @@ -307,6 +229,7 @@ def generate( "WERROR", "OPENSSL_ROOT_DIR", "STATIC_DISPATCH_BACKEND", + "SELECTED_OP_LIST", ) } ) diff --git a/tools/setup_helpers/cmake_utils.py b/tools/setup_helpers/cmake_utils.py new file mode 100644 index 0000000000000..8fb41c913e256 --- /dev/null +++ b/tools/setup_helpers/cmake_utils.py @@ -0,0 +1,85 @@ +""" +This is refactored from cmake.py to avoid circular imports issue with env.py, +which calls get_cmake_cache_variables_from_file +""" + +import re +from typing import Dict, IO, Optional, Union + + +CMakeValue = Optional[Union[bool, str]] + + +def convert_cmake_value_to_python_value( + cmake_value: str, cmake_type: str +) -> CMakeValue: + r"""Convert a CMake value in a string form to a Python value. + + Args: + cmake_value (string): The CMake value in a string form (e.g., "ON", "OFF", "1"). + cmake_type (string): The CMake type of :attr:`cmake_value`. + + Returns: + A Python value corresponding to :attr:`cmake_value` with type :attr:`cmake_type`. + """ + + cmake_type = cmake_type.upper() + up_val = cmake_value.upper() + if cmake_type == "BOOL": + # https://gitlab.kitware.com/cmake/community/wikis/doc/cmake/VariablesListsStrings#boolean-values-in-cmake + return not ( + up_val in ("FALSE", "OFF", "N", "NO", "0", "", "NOTFOUND") + or up_val.endswith("-NOTFOUND") + ) + elif cmake_type == "FILEPATH": + if up_val.endswith("-NOTFOUND"): + return None + else: + return cmake_value + else: # Directly return the cmake_value. + return cmake_value + + +def get_cmake_cache_variables_from_file( + cmake_cache_file: IO[str], +) -> Dict[str, CMakeValue]: + r"""Gets values in CMakeCache.txt into a dictionary. + + Args: + cmake_cache_file: A CMakeCache.txt file object. + Returns: + dict: A ``dict`` containing the value of cached CMake variables. + """ + + results = dict() + for i, line in enumerate(cmake_cache_file, 1): + line = line.strip() + if not line or line.startswith(("#", "//")): + # Blank or comment line, skip + continue + + # Almost any character can be part of variable name and value. As a practical matter, we assume the type must be + # valid if it were a C variable name. It should match the following kinds of strings: + # + # USE_CUDA:BOOL=ON + # "USE_CUDA":BOOL=ON + # USE_CUDA=ON + # USE_CUDA:=ON + # Intel(R) MKL-DNN_SOURCE_DIR:STATIC=/path/to/pytorch/third_party/ideep/mkl-dnn + # "OpenMP_COMPILE_RESULT_CXX_openmp:experimental":INTERNAL=FALSE + matched = re.match( + r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line + ) + if matched is None: # Illegal line + raise ValueError( + "Unexpected line {} in {}: {}".format(i, repr(cmake_cache_file), line) + ) + _, variable, type_, value = matched.groups() + if type_ is None: + type_ = "" + if type_.upper() in ("INTERNAL", "STATIC"): + # CMake internal variable, do not touch + continue + results[variable] = convert_cmake_value_to_python_value(value, type_) + + return results diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index bf693cacc381b..cb0c4650e6912 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -3,7 +3,7 @@ import struct import sys from itertools import chain -from typing import Iterable, List, Optional, cast +from typing import cast, Iterable, List, Optional IS_WINDOWS = platform.system() == "Windows" @@ -62,7 +62,7 @@ def __init__(self, cmake_build_type_env: Optional[str] = None) -> None: cmake_cache_txt = os.path.join(BUILD_DIR, "CMakeCache.txt") if os.path.isfile(cmake_cache_txt): # Found CMakeCache.txt. Use the build type specified in it. - from .cmake import get_cmake_cache_variables_from_file + from .cmake_utils import get_cmake_cache_variables_from_file with open(cmake_cache_txt) as f: cmake_cache_vars = get_cmake_cache_variables_from_file(f) diff --git a/tools/setup_helpers/gen_version_header.py b/tools/setup_helpers/gen_version_header.py index bd576af6f1114..cdfd5372fd696 100644 --- a/tools/setup_helpers/gen_version_header.py +++ b/tools/setup_helpers/gen_version_header.py @@ -4,7 +4,7 @@ import argparse import os -from typing import Dict, Tuple, cast +from typing import cast, Dict, Tuple Version = Tuple[int, int, int] diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 4440e6c2e0a2e..8defd769539a4 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -2,8 +2,9 @@ import os import pathlib import sys +from typing import Any, cast, Optional + import yaml -from typing import Any, Optional, cast try: # use faster C loader if available @@ -25,10 +26,11 @@ def generate_code( force_schema_registration: bool = False, operator_selector: Any = None, ) -> None: - from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python - from tools.autograd.gen_annotated_fn_args import gen_annotated from torchgen.selective_build.selector import SelectiveBuilder + from tools.autograd.gen_annotated_fn_args import gen_annotated + from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python + # Build ATen based Variable classes if install_dir is None: install_dir = os.fspath(gen_dir / "torch/csrc") @@ -207,8 +209,8 @@ def main() -> None: assert os.path.isfile( ts_native_functions ), f"Unable to access {ts_native_functions}" - from torchgen.gen_lazy_tensor import run_gen_lazy_tensor from torchgen.dest.lazy_ir import GenTSLazyIR + from torchgen.gen_lazy_tensor import run_gen_lazy_tensor run_gen_lazy_tensor( aten_path=aten_path, diff --git a/tools/sgx_caffe2_target_definitions.bzl b/tools/sgx_caffe2_target_definitions.bzl index 551244fe8c964..d7277298cc9b6 100644 --- a/tools/sgx_caffe2_target_definitions.bzl +++ b/tools/sgx_caffe2_target_definitions.bzl @@ -225,24 +225,26 @@ def add_sgx_perf_kernel_libs(): # these are esentially disabled for hte sgx build but we still need them # to avoid linking issues - levels_and_flags = ([ - ( - "avx2", - [ - "-mavx2", - "-mfma", - "-mavx", - "-mf16c", - ], - ), - ( - "avx", - [ - "-mavx", - "-mf16c", - ], - ), - ]) + levels_and_flags = { + "x86_64": [ + ( + "avx2", + [ + "-mavx2", + "-mfma", + "-mavx", + "-mf16c", + ], + ), + ( + "avx", + [ + "-mavx", + "-mf16c", + ], + ), + ], + } define_perf_kernels( prefix = "sgx_", diff --git a/tools/shared/__init__.py b/tools/shared/__init__.py index 36c0f2bdbca0c..6bcc9aa6271eb 100644 --- a/tools/shared/__init__.py +++ b/tools/shared/__init__.py @@ -1,2 +1,2 @@ -from .module_loader import import_module from .cwrap_common import set_declaration_defaults, sort_by_number_of_args +from .module_loader import import_module diff --git a/tools/stats/export_slow_tests.py b/tools/stats/export_slow_tests.py deleted file mode 100644 index 13afbf984a235..0000000000000 --- a/tools/stats/export_slow_tests.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import os -import statistics -from collections import defaultdict -from tools.stats.s3_stat_parser import ( - get_previous_reports_for_branch, - Report, - Version2Report, -) -from typing import cast, DefaultDict, Dict, List, Any -from urllib.request import urlopen - -SLOW_TESTS_FILE = ".pytorch-slow-tests.json" -SLOW_TEST_CASE_THRESHOLD_SEC = 60.0 -RELATIVE_DIFFERENCE_THRESHOLD = 0.1 -IGNORED_JOBS = ["asan", "periodic"] - - -def get_test_case_times() -> Dict[str, float]: - reports: List[Report] = get_previous_reports_for_branch("origin/viable/strict", "") - # an entry will be like ("test_doc_examples (__main__.TestTypeHints)" -> [values])) - test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list) - for report in reports: - if report.get("format_version", 1) != 2: # type: ignore[misc] - raise RuntimeError("S3 format currently handled is version 2 only") - v2report = cast(Version2Report, report) - - if any(job_name in str(report["build_job"]) for job_name in IGNORED_JOBS): - continue - - for test_file in v2report["files"].values(): - for suitename, test_suite in test_file["suites"].items(): - for casename, test_case in test_suite["cases"].items(): - # The below attaches a __main__ as that matches the format of test.__class__ in - # common_utils.py (where this data will be used), and also matches what the output - # of a running test would look like. - name = f"{casename} (__main__.{suitename})" - succeeded: bool = test_case["status"] is None - if succeeded: - test_names_to_times[name].append(test_case["seconds"]) - return { - test_case: statistics.mean(times) - for test_case, times in test_names_to_times.items() - } - - -def filter_slow_tests(test_cases_dict: Dict[str, float]) -> Dict[str, float]: - return { - test_case: time - for test_case, time in test_cases_dict.items() - if time >= SLOW_TEST_CASE_THRESHOLD_SEC - } - - -def get_test_infra_slow_tests() -> Dict[str, float]: - url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json" - contents = urlopen(url, timeout=1).read().decode("utf-8") - return cast(Dict[str, float], json.loads(contents)) - - -def too_similar( - calculated_times: Dict[str, float], other_times: Dict[str, float], threshold: float -) -> bool: - # check that their keys are the same - if calculated_times.keys() != other_times.keys(): - return False - - for test_case, test_time in calculated_times.items(): - other_test_time = other_times[test_case] - relative_difference = abs( - (other_test_time - test_time) / max(other_test_time, test_time) - ) - if relative_difference > threshold: - return False - return True - - -def export_slow_tests(options: Any) -> None: - filename = options.filename - if os.path.exists(filename): - print(f"Overwriting existent file: {filename}") - with open(filename, "w+") as file: - slow_test_times: Dict[str, float] = filter_slow_tests(get_test_case_times()) - if options.ignore_small_diffs: - test_infra_slow_tests_dict = get_test_infra_slow_tests() - if too_similar( - slow_test_times, test_infra_slow_tests_dict, options.ignore_small_diffs - ): - slow_test_times = test_infra_slow_tests_dict - json.dump( - slow_test_times, file, indent=" ", separators=(",", ": "), sort_keys=True - ) - file.write("\n") - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Export a JSON of slow test cases in PyTorch unit test suite" - ) - parser.add_argument( - "-f", - "--filename", - nargs="?", - type=str, - default=SLOW_TESTS_FILE, - const=SLOW_TESTS_FILE, - help="Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests.json", - ) - parser.add_argument( - "--ignore-small-diffs", - nargs="?", - type=float, - const=RELATIVE_DIFFERENCE_THRESHOLD, - help="Compares generated results with stats/slow-tests.json in pytorch/test-infra. If the relative differences " - "between test times for each test are smaller than the threshold and the set of test cases have not " - "changed, we will export the stats already in stats/slow-tests.json. Else, we will export the calculated " - "results. The default threshold is 10%.", - ) - return parser.parse_args() - - -def main() -> None: - options = parse_args() - export_slow_tests(options) - - -if __name__ == "__main__": - main() diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py new file mode 100644 index 0000000000000..4554f546ee050 --- /dev/null +++ b/tools/stats/export_test_times.py @@ -0,0 +1,17 @@ +import pathlib +import sys + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent +sys.path.append(str(REPO_ROOT)) +from tools.stats.import_test_stats import get_test_times + +TEST_TIMES_FILE = ".pytorch-test-times.json" + + +def main() -> None: + print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}") + get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE) + + +if __name__ == "__main__": + main() diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 7249c5fccb65a..fbc33a685d4ae 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -5,7 +5,7 @@ import os import pathlib import re -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, cast, Dict, List, Optional from urllib.request import urlopen @@ -41,6 +41,7 @@ def fetch_and_cache( This fetch and cache utils allows sharing between different process. """ path = os.path.join(dirpath, name) + print(f"Downloading {url} to {path}") def is_cached_file_valid() -> bool: # Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check @@ -80,6 +81,20 @@ def get_slow_tests( return {} +def get_test_times(dirpath: str, filename: str) -> Dict[str, Dict[str, float]]: + url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json" + + def process_response(the_response: Dict[str, Any]) -> Any: + build_environment = os.environ["BUILD_ENVIRONMENT"] + return the_response[build_environment] + + try: + return fetch_and_cache(dirpath, filename, url, process_response) + except Exception: + print("Couldn't download test times...") + return {} + + def get_disabled_tests( dirpath: str, filename: str = DISABLED_TESTS_FILE ) -> Optional[Dict[str, Any]]: diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py new file mode 100644 index 0000000000000..81183aff2b0fe --- /dev/null +++ b/tools/stats/monitor.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +import datetime +import json +import signal +import time +from typing import Any, Dict, List + +import psutil # type: ignore[import] +import pynvml # type: ignore[import] + + +def get_processes_running_python_tests() -> List[Any]: + python_processes = [] + for process in psutil.process_iter(): + try: + if "python" in process.name() and process.cmdline(): + python_processes.append(process) + except (psutil.NoSuchProcess, psutil.AccessDenied): + # access denied or the process died + pass + return python_processes + + +def get_per_process_cpu_info() -> List[Dict[str, Any]]: + processes = get_processes_running_python_tests() + per_process_info = [] + for p in processes: + info = { + "pid": p.pid, + "cmd": " ".join(p.cmdline()), + "cpu_percent": p.cpu_percent(), + "rss_memory": p.memory_info().rss, + "uss_memory": p.memory_full_info().uss, + } + if "pss" in p.memory_full_info(): + # only availiable in linux + info["pss_memory"] = p.memory_full_info().pss + per_process_info.append(info) + return per_process_info + + +def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + per_process_info = [] + for p in processes: + info = {"pid": p.pid, "gpu_memory": p.usedGpuMemory} + per_process_info.append(info) + return per_process_info + + +if __name__ == "__main__": + + handle = None + try: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + except pynvml.NVMLError: + # no pynvml avaliable, probably because not cuda + pass + + kill_now = False + + def exit_gracefully(*args: Any) -> None: + global kill_now + kill_now = True + + signal.signal(signal.SIGTERM, exit_gracefully) + + while not kill_now: + try: + stats = { + "time": datetime.datetime.utcnow().isoformat("T") + "Z", + "total_cpu_percent": psutil.cpu_percent(), + "per_process_cpu_info": get_per_process_cpu_info(), + } + if handle is not None: + stats["per_process_gpu_info"] = get_per_process_gpu_info(handle) + stats["total_gpu_utilizaiton"] = pynvml.nvmlDeviceGetUtilizationRates( + handle + ).gpu + except Exception as e: + stats = { + "time": datetime.datetime.utcnow().isoformat("T") + "Z", + "error": str(e), + } + finally: + print(json.dumps(stats)) + time.sleep(1) diff --git a/tools/stats/print_test_stats.py b/tools/stats/print_test_stats.py index c6994c1ea68cd..b82c1236525d8 100755 --- a/tools/stats/print_test_stats.py +++ b/tools/stats/print_test_stats.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import ( Any, + cast, DefaultDict, Dict, Iterable, @@ -22,24 +23,24 @@ Optional, Set, Tuple, - cast, ) from xml.dom import minidom from typing_extensions import TypedDict + from tools.stats.s3_stat_parser import ( - newify_case, + Commit, get_S3_object_from_bucket, get_test_stats_summaries_for_job, + HAVE_BOTO3, + newify_case, Report, + ReportMetaMeta, Status, - Commit, - HAVE_BOTO3, - Version2Case, - VersionedReport, Version1Report, + Version2Case, Version2Report, - ReportMetaMeta, + VersionedReport, ) from tools.stats.scribe import send_to_scribe @@ -785,7 +786,9 @@ def build_info() -> ReportMetaMeta: os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", "HEAD")) ), "build_branch": os.environ.get("BRANCH", os.environ.get("CIRCLE_BRANCH", "")), - "build_job": os.environ.get("JOB_BASE_NAME", os.environ.get("CIRCLE_JOB", "")), + "build_job": os.environ.get( + "BUILD_ENVIRONMENT", os.environ.get("CIRCLE_JOB", "") + ), "build_workflow_id": os.environ.get( "WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID", "") ), @@ -877,7 +880,7 @@ def assemble_s3_object( def send_report_to_s3(head_report: Version2Report) -> None: - job = os.getenv("JOB_BASE_NAME", os.environ.get("CIRCLE_JOB")) + job = os.getenv("BUILD_ENVIRONMENT", os.environ.get("CIRCLE_JOB")) sha1 = os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", "")) now = datetime.datetime.utcnow().isoformat() @@ -929,7 +932,7 @@ def print_regressions(head_report: Report, *, num_prev_commits: int) -> None: else: commits = commits[:-1] - job = os.environ.get("JOB_BASE_NAME", "") + job = os.environ.get("BUILD_ENVIRONMENT", "") objects: Dict[Commit, List[Report]] = defaultdict(list) for commit in commits: diff --git a/tools/stats/s3_stat_parser.py b/tools/stats/s3_stat_parser.py index a715651c1378d..2691888ecbfab 100644 --- a/tools/stats/s3_stat_parser.py +++ b/tools/stats/s3_stat_parser.py @@ -4,7 +4,8 @@ import subprocess from collections import defaultdict from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Union, Any, cast +from typing import Any, cast, Dict, List, Optional, Tuple, Union + from typing_extensions import Literal, TypedDict try: diff --git a/tools/stats/sccache_stats_to_json.py b/tools/stats/sccache_stats_to_json.py deleted file mode 100644 index e0e9726315c58..0000000000000 --- a/tools/stats/sccache_stats_to_json.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import sys -import os -from typing import Any - -GITHUB_JOB_ID = os.environ["OUR_GITHUB_JOB_ID"] - - -def parse_value(value: str) -> Any: - # Take the value from a line of `sccache --show-stats` and try to parse - # out a value - try: - return int(value) - except ValueError: - # sccache reports times as 0.000 s, so detect that here and strip - # off the non-numeric parts - if value.endswith(" s"): - return float(value[: -len(" s")]) - - return value - - -def get_name(name: str) -> str: - return name.replace(" ", "_").replace("-", "_").lower() - - -STAT_NAMES = { - "compile_requests", - "compile_requests_executed", - "cache_hits", - "cache_misses", - "cache_timeouts", - "cache_read_errors", - "forced_recaches", - "cache_write_errors", - "compilation_failures", - "cache_errors", - "non_cacheable_compilations", - "non_cacheable_calls", - "non_compilation_calls", - "unsupported_compiler_calls", - "average_cache_write", - "average_cache_read_miss", - "average_cache_read_hit", - "failed_distributed_compilations", -} - - -if __name__ == "__main__": - data = {"job_id": int(GITHUB_JOB_ID)} - for line in sys.stdin: - line = line.strip() - values = [x.strip() for x in line.split(" ")] - values = [x for x in values if x != ""] - if len(values) == 2: - name = get_name(values[0]) - if name in STAT_NAMES: - data[name] = parse_value(values[1]) - - print(json.dumps(data, indent=2)) diff --git a/tools/stats/scribe.py b/tools/stats/scribe.py index 47a8a819206c2..2ca2d8c6824f2 100644 --- a/tools/stats/scribe.py +++ b/tools/stats/scribe.py @@ -1,7 +1,7 @@ import base64 import bz2 -import os import json +import os from typing import Any diff --git a/tools/stats/test_history.py b/tools/stats/test_history.py index 83751441bb7d7..c964fb487522b 100755 --- a/tools/stats/test_history.py +++ b/tools/stats/test_history.py @@ -4,10 +4,10 @@ import subprocess import sys from datetime import datetime, timezone -from signal import SIG_DFL, SIGPIPE, signal +from signal import SIG_DFL, signal, SIGPIPE from typing import Dict, Iterator, List, Optional, Set, Tuple -from tools.stats.s3_stat_parser import Report, get_cases, get_test_stats_summaries +from tools.stats.s3_stat_parser import get_cases, get_test_stats_summaries, Report def get_git_commit_history(*, path: str, ref: str) -> List[Tuple[str, datetime]]: diff --git a/tools/stats/upload_sccache_stats.py b/tools/stats/upload_sccache_stats.py index 1d320b767b5fe..c155e62897005 100644 --- a/tools/stats/upload_sccache_stats.py +++ b/tools/stats/upload_sccache_stats.py @@ -2,14 +2,14 @@ import json import os from pathlib import Path -from typing import Dict, List, Any from tempfile import TemporaryDirectory +from typing import Any, Dict, List from tools.stats.upload_stats_lib import ( download_gha_artifacts, download_s3_artifacts, - upload_to_rockset, unzip, + upload_to_rockset, ) diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index e7c4e41418718..1cba78f68da1e 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -1,11 +1,14 @@ +import gzip +import io +import json import os -import requests import zipfile from pathlib import Path -from typing import Dict, List, Any +from typing import Any, Dict, List -import rockset # type: ignore[import] import boto3 # type: ignore[import] +import requests +import rockset # type: ignore[import] PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch" S3_RESOURCE = boto3.resource("s3") @@ -110,6 +113,29 @@ def upload_to_rockset(collection: str, docs: List[Any]) -> None: print("Done!") +def upload_to_s3( + workflow_run_id: int, + workflow_run_attempt: int, + collection: str, + docs: List[Dict[str, Any]], +) -> None: + print(f"Writing {len(docs)} documents to S3") + body = io.StringIO() + for doc in docs: + json.dump(doc, body) + body.write("\n") + + S3_RESOURCE.Object( + "ossci-raw-job-status", + f"{collection}/{workflow_run_id}/{workflow_run_attempt}", + ).put( + Body=gzip.compress(body.getvalue().encode()), + ContentEncoding="gzip", + ContentType="application/json", + ) + print("Done!") + + def unzip(p: Path) -> None: """Unzip the provided zipfile to a similarly-named directory. diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index 19e74d6c08dc7..9b919716ebc3a 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -1,49 +1,66 @@ import argparse import os +import sys import xml.etree.ElementTree as ET from pathlib import Path -from typing import Dict, List, Any, Tuple, Optional from tempfile import TemporaryDirectory +from typing import Any, Dict, List, Tuple from tools.stats.upload_stats_lib import ( download_gha_artifacts, download_s3_artifacts, - upload_to_rockset, unzip, + upload_to_s3, ) +def get_job_id(report: Path) -> int: + # [Job id in artifacts] + # Retrieve the job id from the report path. In our GHA workflows, we append + # the job id to the end of the report name, so `report` looks like: + # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml + # and we want to get `5596745227` out of it. + return int(report.parts[0].rpartition("_")[2]) + + def parse_xml_report( tag: str, report: Path, workflow_id: int, workflow_run_attempt: int, - skip_tag: Optional[str] = None, ) -> List[Dict[str, Any]]: """Convert a test report xml file into a JSON-serializable list of test cases.""" print(f"Parsing {tag}s for test report: {report}") - # [Job id in artifacts] - # Retrieve the job id from the report path. In our GHA workflows, we append - # the job id to the end of the report name, so `report` looks like: - # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml - # and we want to get `5596745227` out of it. - job_id = int(report.parts[0].rpartition("_")[2]) + + job_id = get_job_id(report) print(f"Found job id: {job_id}") root = ET.parse(report) test_cases = [] for test_case in root.iter(tag): - case = process_xml_element(test_case, skip_tag) + case = process_xml_element(test_case) case["workflow_id"] = workflow_id case["workflow_run_attempt"] = workflow_run_attempt case["job_id"] = job_id + + # [invoking file] + # The name of the file that the test is located in is not necessarily + # the same as the name of the file that invoked the test. + # For example, `test_jit.py` calls into multiple other test files (e.g. + # jit/test_dce.py). For sharding/test selection purposes, we want to + # record the file that invoked the test. + # + # To do this, we leverage an implementation detail of how we write out + # tests (https://bit.ly/3ajEV1M), which is that reports are created + # under a folder with the same name as the invoking file. + case["invoking_file"] = report.parent.name test_cases.append(case) return test_cases -def process_xml_element(element: ET.Element, skip_tag: Optional[str]) -> Dict[str, Any]: +def process_xml_element(element: ET.Element) -> Dict[str, Any]: """Convert a test suite element into a JSON-serializable dict.""" ret: Dict[str, Any] = {} @@ -89,23 +106,33 @@ def process_xml_element(element: ET.Element, skip_tag: Optional[str]) -> Dict[st # "bar": {"text": "another"} # } for child in element: - if child.tag == skip_tag: - continue - if child.tag not in ret: - ret[child.tag] = process_xml_element(child, skip_tag) + ret[child.tag] = process_xml_element(child) else: # If there are multiple tags with the same name, they should be # coalesced into a list. if not isinstance(ret[child.tag], list): ret[child.tag] = [ret[child.tag]] - ret[child.tag].append(process_xml_element(child, skip_tag)) + ret[child.tag].append(process_xml_element(child)) return ret +def get_pytest_parallel_times() -> Dict[Any, Any]: + pytest_parallel_times = {} + for report in Path(".").glob("**/python-pytest/**/*.xml"): + invoking_file = report.parent.name + root = ET.parse(report) + assert len(list(root.iter("testsuite"))) == 1 + for test_suite in root.iter("testsuite"): + pytest_parallel_times[ + (invoking_file, get_job_id(report)) + ] = test_suite.attrib["time"] + return pytest_parallel_times + + def get_tests( workflow_run_id: int, workflow_run_attempt: int -) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: +) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]: with TemporaryDirectory() as temp_dir: print("Using temporary directory:", temp_dir) os.chdir(temp_dir) @@ -125,7 +152,6 @@ def get_tests( # Parse the reports and transform them to JSON test_cases = [] - test_suites = [] for xml_report in Path(".").glob("**/*.xml"): test_cases.extend( parse_xml_report( @@ -135,17 +161,100 @@ def get_tests( workflow_run_attempt, ) ) - test_suites.extend( - parse_xml_report( - "testsuite", - xml_report, - workflow_run_id, - workflow_run_attempt, - skip_tag="testcase", - ) - ) - return test_cases, test_suites + pytest_parallel_times = get_pytest_parallel_times() + + return test_cases, pytest_parallel_times + + +def get_invoking_file_times( + test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any] +) -> List[Dict[str, Any]]: + def get_key(summary: Dict[str, Any]) -> Any: + return ( + summary["invoking_file"], + summary["job_id"], + ) + + def init_value(summary: Dict[str, Any]) -> Any: + return { + "job_id": summary["job_id"], + "workflow_id": summary["workflow_id"], + "workflow_run_attempt": summary["workflow_run_attempt"], + "invoking_file": summary["invoking_file"], + "time": 0.0, + } + + ret = {} + for summary in test_case_summaries: + key = get_key(summary) + if key not in ret: + ret[key] = init_value(summary) + ret[key]["time"] += summary["time"] + + for key, val in ret.items(): + # when running in parallel in pytest, adding the test times will not give the correct + # time used to run the file, which will make the sharding incorrect, so if the test is + # run in parallel, we take the time reported by the testsuite + if key in pytest_parallel_times: + val["time"] = pytest_parallel_times[key] + + return list(ret.values()) + + +def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Group test cases by classname, file, and job_id. We perform the aggregation + manually instead of using the `test-suite` XML tag because xmlrunner does + not produce reliable output for it. + """ + + def get_key(test_case: Dict[str, Any]) -> Any: + return ( + test_case.get("file"), + test_case.get("classname"), + test_case["job_id"], + test_case["workflow_id"], + test_case["workflow_run_attempt"], + # [see: invoking file] + test_case["invoking_file"], + ) + + def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]: + return { + "file": test_case.get("file"), + "classname": test_case.get("classname"), + "job_id": test_case["job_id"], + "workflow_id": test_case["workflow_id"], + "workflow_run_attempt": test_case["workflow_run_attempt"], + # [see: invoking file] + "invoking_file": test_case["invoking_file"], + "tests": 0, + "failures": 0, + "errors": 0, + "skipped": 0, + "successes": 0, + "time": 0.0, + } + + ret = {} + for test_case in test_cases: + key = get_key(test_case) + if key not in ret: + ret[key] = init_value(test_case) + + ret[key]["tests"] += 1 + + if "failure" in test_case: + ret[key]["failures"] += 1 + elif "error" in test_case: + ret[key]["errors"] += 1 + elif "skipped" in test_case: + ret[key]["skipped"] += 1 + else: + ret[key]["successes"] += 1 + + ret[key]["time"] += test_case["time"] + return list(ret.values()) if __name__ == "__main__": @@ -162,7 +271,42 @@ def get_tests( required=True, help="which retry of the workflow this is", ) + parser.add_argument( + "--head-branch", + required=True, + help="Head branch of the workflow", + ) args = parser.parse_args() - test_cases, test_suites = get_tests(args.workflow_run_id, args.workflow_run_attempt) - upload_to_rockset("test_run", test_cases) - upload_to_rockset("test_suite", test_suites) + test_cases, pytest_parallel_times = get_tests( + args.workflow_run_id, args.workflow_run_attempt + ) + + # Flush stdout so that any errors in rockset upload show up last in the logs. + sys.stdout.flush() + + # For PRs, only upload a summary of test_runs. This helps lower the + # volume of writes we do to Rockset. + test_case_summary = summarize_test_cases(test_cases) + invoking_file_times = get_invoking_file_times( + test_case_summary, pytest_parallel_times + ) + + upload_to_s3( + args.workflow_run_id, + args.workflow_run_attempt, + "test_run_summary", + test_case_summary, + ) + + upload_to_s3( + args.workflow_run_id, + args.workflow_run_attempt, + "invoking_file_times", + invoking_file_times, + ) + + if args.head_branch == "master": + # For master jobs, upload everytihng. + upload_to_s3( + args.workflow_run_id, args.workflow_run_attempt, "test_run", test_cases + ) diff --git a/tools/target_definitions.bzl b/tools/target_definitions.bzl index 3b78e3c9d4e75..6c4a53dbfc89c 100644 --- a/tools/target_definitions.bzl +++ b/tools/target_definitions.bzl @@ -168,7 +168,7 @@ def add_torch_libs(): "//gloo/fb/transport/tls:tls", "//gloo/transport/tcp:tcp", "//tensorpipe:tensorpipe_cpu", - ] + (["//kineto/libkineto:kineto"] if use_kineto() else []) + + ] + (["//kineto/libkineto:kineto"] if use_kineto() else ["//kineto/libkineto:kineto_activity_header"]) + (["//caffe2:mobile_bytecode"] if enable_flatbuffer else []) ), exported_external_deps = [ diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py new file mode 100644 index 0000000000000..87455d3a13ff5 --- /dev/null +++ b/tools/test/gen_operators_yaml_test.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import unittest + +from gen_operators_yaml import make_filter_from_options, verify_all_specified_present + + +class GenOperatorsYAMLTest(unittest.TestCase): + def setUp(self): + pass + + def test_filter_creation(self): + filter_func = make_filter_from_options( + model_name="abc", + model_versions=["100", "101"], + model_assets=None, + model_backends=None, + ) + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 102, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + }, + { + "model": { + "name": "abcd", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + + filtered_configs = list(filter(filter_func, config)) + assert ( + len(filtered_configs) == 2 + ), "Expected 2 elements in filtered_configs, but got {}".format( + len(filtered_configs) + ) + + def test_verification_success(self): + filter_func = make_filter_from_options( + model_name="abc", + model_versions=["100", "101"], + model_assets=["asset-1", "asset-2"], + model_backends=None, + ) + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + filtered_configs = list(filter(filter_func, config)) + try: + verify_all_specified_present( + model_assets=["asset-1", "asset-2"], + model_versions=["100", "101"], + selected_models_yaml=filtered_configs, + rule_name="test", + model_name="abc", + new_style_rule=True, + ) + except Exception: + self.fail( + "expected verify_all_specified_present to succeed instead it raised an exception" + ) + + def test_verification_fail(self): + config = [ + { + "model": { + "name": "abc", + "version": 100, + "asset": "asset-1", + "backend": "CPU", + }, + "root_operators": [], + "traced_operators": [], + }, + { + "model": { + "name": "abc", + "version": 101, + "asset": "asset-2", + "backend": "CPU", + }, + "root_operators": [], + }, + ] + + good_assets = ["asset-1", "asset-2"] + good_versions = ["100", "101"] + good_name = "abc" + + # Test bad asset + filter_func_bad_asset = make_filter_from_options( + model_name=good_name, + model_versions=good_versions, + model_assets=["asset-1", "asset-2", "asset-3"], + model_backends=None, + ) + filtered_configs_asset = list(filter(filter_func_bad_asset, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=["asset-1", "asset-2", "asset-3"], + model_versions=good_versions, + selected_models_yaml=filtered_configs_asset, + rule_name="test", + model_name=good_name, + new_style_rule=True, + ) + + # Test bad version + filter_func_bad_version = make_filter_from_options( + model_name=good_name, + model_versions=["100", "101", "102"], + model_assets=good_assets, + model_backends=None, + ) + filtered_configs_version = list(filter(filter_func_bad_version, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=good_assets, + model_versions=["100", "101", "102"], + selected_models_yaml=filtered_configs_version, + rule_name="test", + model_name=good_name, + new_style_rule=True, + ) + + # Test bad name + filter_func_bad_name = make_filter_from_options( + model_name="abcd", + model_versions=good_versions, + model_assets=good_assets, + model_backends=None, + ) + filtered_configs_name = list(filter(filter_func_bad_name, config)) + with self.assertRaises(RuntimeError): + verify_all_specified_present( + model_assets=good_assets, + model_versions=good_versions, + selected_models_yaml=filtered_configs_name, + rule_name="test", + model_name="abcd", + new_style_rule=True, + ) diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py new file mode 100644 index 0000000000000..d58e2ccc90671 --- /dev/null +++ b/tools/test/gen_oplist_test.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import unittest +from unittest.mock import MagicMock + +from gen_oplist import throw_if_any_op_includes_overloads + + +class GenOplistTest(unittest.TestCase): + def setUp(self): + pass + + def test_throw_if_any_op_includes_overloads(self): + selective_builder = MagicMock() + selective_builder.operators = MagicMock() + selective_builder.operators.items.return_value = [ + ("op1", MagicMock(include_all_overloads=True)), + ("op2", MagicMock(include_all_overloads=False)), + ("op3", MagicMock(include_all_overloads=True)), + ] + + self.assertRaises( + Exception, throw_if_any_op_includes_overloads, selective_builder + ) + + selective_builder.operators.items.return_value = [ + ("op1", MagicMock(include_all_overloads=False)), + ("op2", MagicMock(include_all_overloads=False)), + ("op3", MagicMock(include_all_overloads=False)), + ] + + # Here we do not expect it to throw an exception since none of the ops + # include all overloads. + throw_if_any_op_includes_overloads(selective_builder) diff --git a/tools/test/test_cmake.py b/tools/test/test_cmake.py index 2c4bead6db3b7..618b951a8c54c 100644 --- a/tools/test/test_cmake.py +++ b/tools/test/test_cmake.py @@ -1,13 +1,14 @@ import contextlib import os import typing -from typing import Iterator, Optional, Sequence import unittest import unittest.mock +from typing import Iterator, Optional, Sequence -import tools.setup_helpers.env # noqa: F401 unused but resolves circular import import tools.setup_helpers.cmake +import tools.setup_helpers.env # noqa: F401 unused but resolves circular import + T = typing.TypeVar("T") diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index 22b5470f63268..a96533769d364 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -1,11 +1,24 @@ import dataclasses import typing import unittest +from typing import Dict -from tools.autograd import gen_autograd_functions -from tools.autograd import load_derivatives import torchgen.model +from tools.autograd import gen_autograd_functions, load_derivatives +from torchgen.gen import ( + get_native_function_declarations, + get_native_function_schema_registrations, +) +from torchgen.model import ( + BackendIndex, + BackendMetadata, + DispatchKey, + NativeFunction, + OperatorName, +) +from torchgen.selective_build.selector import SelectiveBuilder + class TestCreateDerivative(unittest.TestCase): def test_named_grads(self) -> None: @@ -130,6 +143,133 @@ def test_non_differentiable_output_output_differentiability(self) -> None: assert "grad_z = grads[1]" in definition +class TestGenSchemaRegistration(unittest.TestCase): + def setUp(self) -> None: + self.selector = SelectiveBuilder.get_nop_selector() + self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml( + {"func": "custom::func() -> bool"}, + loc=torchgen.model.Location(__file__, 1), + valid_tags=set(), + ) + + def test_default_namespace_schema_registration_code_valid(self) -> None: + native_functions = [DEFAULT_NATIVE_FUNCTION] + registrations, _ = get_native_function_schema_registrations( + native_functions=native_functions, + schema_selector=self.selector, + ) + self.assertEqual(registrations, ['m.def("func() -> bool", {});\n']) + + def test_custom_namespace_schema_registration_code_valid(self) -> None: + _, registrations = get_native_function_schema_registrations( + native_functions=[self.custom_native_function], + schema_selector=self.selector, + ) + self.assertEqual( + registrations, + """ +TORCH_LIBRARY(custom, m) { + m.def("func() -> bool", {}); + +};""", + ) + + def test_mixed_namespace_schema_registration_code_valid(self) -> None: + ( + aten_registrations, + custom_registrations, + ) = get_native_function_schema_registrations( + native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function], + schema_selector=self.selector, + ) + self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n']) + self.assertEqual( + custom_registrations, + """ +TORCH_LIBRARY(custom, m) { + m.def("func() -> bool", {}); + +};""", + ) + + def test_3_namespaces_schema_registration_code_invalid(self) -> None: + custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml( + {"func": "custom2::func() -> bool"}, + loc=torchgen.model.Location(__file__, 1), + valid_tags=set(), + ) + with self.assertRaises(AssertionError): + get_native_function_schema_registrations( + native_functions=[ + DEFAULT_NATIVE_FUNCTION, + self.custom_native_function, + custom2_native_function, + ], + schema_selector=self.selector, + ) + + +class TestGenNativeFunctionDeclaration(unittest.TestCase): + def setUp(self) -> None: + self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml( + {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, + loc=torchgen.model.Location(__file__, 1), + valid_tags=set(), + ) + self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml( + { + "func": "op_2() -> bool", + "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"}, + }, + loc=torchgen.model.Location(__file__, 1), + valid_tags=set(), + ) + + backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = { + DispatchKey.CPU: {}, + DispatchKey.QuantizedCPU: {}, + } + BackendIndex.grow_index(backend_indices, op_1_backend_index) + BackendIndex.grow_index(backend_indices, op_2_backend_index) + self.backend_indices = { + k: BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + device_guard=False, + index=backend_indices[k], + ) + for k in backend_indices + } + + def test_native_function_declaration_1_op_2_ns_error(self) -> None: + with self.assertRaises(AssertionError): + get_native_function_declarations( + grouped_native_functions=[ + self.op_1_native_function, + self.op_2_native_function, + ], + backend_indices=self.backend_indices, + ) + + def test_native_function_declaration_1_op_1_ns_valid(self) -> None: + self.assertIsInstance(self.op_1_native_function, NativeFunction) + declaration = get_native_function_declarations( + grouped_native_functions=[ + self.op_1_native_function, + ], + backend_indices=self.backend_indices, + ) + target = """ +namespace at { +namespace native { +TORCH_API bool kernel_1(); +} // namespace native +} // namespace at + """ + self.assertEqual("\n".join(declaration), target) + + # Represents the most basic NativeFunction. Use dataclasses.replace() # to edit for use. DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml( diff --git a/tools/test/test_codegen_model.py b/tools/test/test_codegen_model.py index 710e90697116f..cb31561275edf 100644 --- a/tools/test/test_codegen_model.py +++ b/tools/test/test_codegen_model.py @@ -1,15 +1,16 @@ # Owner(s): ["module: codegen"] -import expecttest -import unittest -import yaml import textwrap +import unittest -from torchgen.model import NativeFunctionsGroup, DispatchKey +import expecttest import torchgen.dest as dest import torchgen.gen as gen +import yaml from torchgen.gen import LineLoader, parse_native_yaml_struct +from torchgen.model import DispatchKey, NativeFunctionsGroup + class TestCodegenModel(expecttest.TestCase): def assertParseErrorInline(self, yaml_str: str, expect: str) -> None: diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index dbe32ecec169b..9091cca6dddf6 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -3,10 +3,11 @@ import os import tempfile import unittest + import expecttest +from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 from torchgen.gen_backend_stubs import run -from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py") @@ -237,7 +238,7 @@ def test_unrecognized_key(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""", # noqa: B950 + """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen""", # noqa: B950 ) # if use_out_as_primary is provided, it must be a bool diff --git a/tools/test/test_import_test_stats.py b/tools/test/test_import_test_stats.py index ea9aad8df40db..3679ad82d3d18 100644 --- a/tools/test/test_import_test_stats.py +++ b/tools/test/test_import_test_stats.py @@ -1,9 +1,10 @@ import os import unittest -from tools.stats.import_test_stats import get_disabled_issues from typing import List from unittest.mock import patch +from tools.stats.import_test_stats import get_disabled_issues + class TestGetDisabledIssues(unittest.TestCase): def run_assert_disabled_issues( diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py new file mode 100644 index 0000000000000..50a3ba56eb795 --- /dev/null +++ b/tools/test/test_selective_build.py @@ -0,0 +1,281 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from torchgen.selective_build.operator import * +from torchgen.selective_build.selector import ( + combine_selective_builders, + SelectiveBuilder, +) + + +class TestSelectiveBuild(unittest.TestCase): + def test_selective_build_operator(self): + op = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + self.assertTrue(op.is_root_operator) + self.assertFalse(op.is_used_for_training) + self.assertFalse(op.include_all_overloads) + + def test_selector_factory(self): + yaml_config_v1 = """ +debug_info: + - model1@v100 + - model2@v51 +operators: + aten::add: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: Yes + aten::add.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No + aten::mul.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No +""" + + yaml_config_v2 = """ +debug_info: + - model1@v100 + - model2@v51 +operators: + aten::sub: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: No + debug_info: + - model1@v100 + aten::sub.int: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: No +""" + + yaml_config_all = "include_all_operators: Yes" + + yaml_config_invalid = "invalid:" + + selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1) + + self.assertTrue(selector1.is_operator_selected("aten::add")) + self.assertTrue(selector1.is_operator_selected("aten::add.int")) + # Overload name is not used for checking in v1. + self.assertTrue(selector1.is_operator_selected("aten::add.float")) + + def gen(): + return SelectiveBuilder.from_yaml_str(yaml_config_invalid) + + self.assertRaises(Exception, gen) + + selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all) + + self.assertTrue(selector_all.is_operator_selected("aten::add")) + self.assertTrue(selector_all.is_operator_selected("aten::sub")) + self.assertTrue(selector_all.is_operator_selected("aten::sub.int")) + self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32")) + + selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2) + + self.assertFalse(selector2.is_operator_selected("aten::add")) + self.assertTrue(selector2.is_operator_selected("aten::sub")) + self.assertTrue(selector2.is_operator_selected("aten::sub.int")) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + False, + False, + ) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float")) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add")) + self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int")) + self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub")) + + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + True, + False, + ) + + self.assertTrue(selector_legacy_v1.is_root_operator("aten::add")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float")) + self.assertFalse( + selector_legacy_v1.is_operator_selected_for_training("aten::add.float") + ) + + selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( + ["aten::add", "aten::add.int", "aten::mul.int"], + False, + True, + ) + + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) + self.assertTrue( + selector_legacy_v1.is_operator_selected_for_training("aten::add") + ) + self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float")) + self.assertTrue( + selector_legacy_v1.is_operator_selected_for_training("aten::add.float") + ) + + def test_operator_combine(self): + op1 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op2 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op3 = SelectiveBuildOperator( + "aten::add", + is_root_operator=True, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op4 = SelectiveBuildOperator( + "aten::add.int", + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=False, + _debug_info=None, + ) + + op5 = combine_operators(op1, op2) + + self.assertTrue(op5.is_root_operator) + self.assertFalse(op5.is_used_for_training) + + op6 = combine_operators(op1, op4) + + self.assertTrue(op6.is_root_operator) + self.assertTrue(op6.is_used_for_training) + + def gen_new_op(): + return combine_operators(op1, op3) + + self.assertRaises(Exception, gen_new_op) + + def test_training_op_fetch(self): + yaml_config = """ +operators: + aten::add.int: + is_used_for_training: No + is_root_operator: Yes + include_all_overloads: No + aten::add: + is_used_for_training: Yes + is_root_operator: No + include_all_overloads: Yes +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + self.assertTrue(selector.is_operator_selected_for_training("aten::add.int")) + self.assertTrue(selector.is_operator_selected_for_training("aten::add")) + + def test_kernel_dtypes(self): + yaml_config = """ +kernel_metadata: + add_kernel: + - int8 + - int32 + sub_kernel: + - int16 + - int32 + add/sub_kernel: + - float + - complex +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) + + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) + + def test_merge_kernel_dtypes(self): + yaml_config1 = """ +kernel_metadata: + add_kernel: + - int8 + add/sub_kernel: + - float + - complex + - none + mul_kernel: + - int8 +""" + + yaml_config2 = """ +kernel_metadata: + add_kernel: + - int32 + sub_kernel: + - int16 + - int32 + add/sub_kernel: + - float + - complex +""" + + selector1 = SelectiveBuilder.from_yaml_str(yaml_config1) + selector2 = SelectiveBuilder.from_yaml_str(yaml_config2) + + selector = combine_selective_builders(selector1, selector2) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) + + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) + self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) + self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) + + self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8")) + self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32")) + + def test_all_kernel_dtypes_selected(self): + yaml_config = """ +include_all_non_op_selectives: True +""" + + selector = SelectiveBuilder.from_yaml_str(yaml_config) + + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16")) + self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32")) + self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float")) diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index b846bb53c0cb3..23f05cb99fe89 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -1,8 +1,8 @@ import random import unittest +from typing import Dict, List, Tuple from tools.testing.test_selections import calculate_shards -from typing import Dict, List, Tuple class TestCalculateShards(unittest.TestCase): diff --git a/tools/test/test_upload_test_stats.py b/tools/test/test_upload_test_stats.py index 71b5c4c6c0513..58be2ed4a646d 100644 --- a/tools/test/test_upload_test_stats.py +++ b/tools/test/test_upload_test_stats.py @@ -1,9 +1,9 @@ -import unittest import os +import unittest -IN_CI = os.environ.get("CI") +from tools.stats.upload_test_stats import get_tests, summarize_test_cases -from tools.stats.upload_test_stats import get_tests +IN_CI = os.environ.get("CI") class TestUploadTestStats(unittest.TestCase): @@ -13,9 +13,10 @@ class TestUploadTestStats(unittest.TestCase): ) def test_existing_job(self) -> None: """Run on a known-good job and make sure we don't error and get basically okay reults.""" - test_cases, test_suites = get_tests(2465214458, 1) - self.assertEqual(len(test_cases), 731457) - self.assertEqual(len(test_suites), 7781) + test_cases, _ = get_tests(2561394934, 1) + self.assertEqual(len(test_cases), 609873) + summary = summarize_test_cases(test_cases) + self.assertEqual(len(summary), 5068) if __name__ == "__main__": diff --git a/tools/test/test_utils.py b/tools/test/test_utils.py new file mode 100644 index 0000000000000..92fa311140bbb --- /dev/null +++ b/tools/test/test_utils.py @@ -0,0 +1,22 @@ +import unittest + +from torchgen.utils import NamespaceHelper + + +class TestNamespaceHelper(unittest.TestCase): + def test_create_from_namespaced_tuple(self) -> None: + helper = NamespaceHelper.from_namespaced_entity("aten::add") + self.assertEqual(helper.entity_name, "add") + self.assertEqual(helper.get_cpp_namespace(), "aten") + + def test_default_namespace(self) -> None: + helper = NamespaceHelper.from_namespaced_entity("add") + self.assertEqual(helper.entity_name, "add") + self.assertEqual(helper.get_cpp_namespace(), "") + self.assertEqual(helper.get_cpp_namespace("default"), "default") + + def test_namespace_levels_more_than_max(self) -> None: + with self.assertRaises(AssertionError): + NamespaceHelper( + namespace_str="custom_1::custom_2", entity_name="", max_level=1 + ) diff --git a/tools/testing/explicit_ci_jobs.py b/tools/testing/explicit_ci_jobs.py index 3de04e1a18e92..daff3cce8956f 100755 --- a/tools/testing/explicit_ci_jobs.py +++ b/tools/testing/explicit_ci_jobs.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 -import yaml -import textwrap -import subprocess -import pathlib import argparse import fnmatch +import pathlib +import subprocess +import textwrap -from typing import Dict, List, Any +from typing import Any, Dict, List + +import yaml REPO_ROOT = pathlib.Path(__file__).parent.parent.parent diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index f041be184dd86..dd4e4cf5c6df8 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -1,9 +1,9 @@ -import os import modulefinder -import sys +import os import pathlib +import sys import warnings -from typing import Dict, Any, List, Set +from typing import Any, Dict, List, Set REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent @@ -20,6 +20,7 @@ "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", "test_cpp_extensions_jit", + "test_cpp_extensions_open_device_registration", "test_cuda", "test_cuda_primary_ctx", "test_dataloader", diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 9056054ec730b..caa447c6907d5 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -1,68 +1,9 @@ -import json import os import subprocess -from tools.stats.s3_stat_parser import ( - get_previous_reports_for_branch, - Report, - Version2Report, - HAVE_BOTO3, -) -from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests - -from typing import Any, Dict, List, Optional, Tuple, cast -from typing_extensions import TypedDict - - -class JobTimeJSON(TypedDict): - commit: str - JOB_BASE_NAME: str - job_times: Dict[str, float] - - -def _get_stripped_CI_job() -> str: - """E.g. convert 'pytorch_windows_vs2019_py36_cuda10.1_build' to 'pytorch_windows_vs2019_py36_cuda10.1'.""" - job = os.environ.get("JOB_BASE_NAME", "").rstrip("0123456789") - if job.endswith("_slow_test"): - job = job[: len(job) - len("_slow_test")] - elif job.endswith("_test") or job.endswith("-test"): - job = job[: len(job) - len("_test")] - elif job.endswith("_build") or job.endswith("-build"): - job = job[: len(job) - len("_build")] - return job +from typing import Dict, List, Tuple - -def _get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON: - return { - "commit": subprocess.check_output( - ["git", "rev-parse", "HEAD"], encoding="ascii" - ).strip(), - "JOB_BASE_NAME": _get_stripped_CI_job(), - "job_times": job_times, - } - - -def _calculate_job_times(reports: List["Report"]) -> Dict[str, float]: - """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))""" - jobs_to_times: Dict[str, Tuple[float, int]] = dict() - for report in reports: - v_report = cast(Version2Report, report) - assert ( - "format_version" in v_report.keys() and v_report.get("format_version") == 2 - ), "S3 format currently handled is version 2 only" - files: Dict[str, Any] = v_report["files"] - for name, test_file in files.items(): - if name not in jobs_to_times: - jobs_to_times[name] = (test_file["total_seconds"], 1) - else: - curr_avg, curr_count = jobs_to_times[name] - new_count = curr_count + 1 - new_avg = ( - curr_avg * curr_count + test_file["total_seconds"] - ) / new_count - jobs_to_times[name] = (new_avg, new_count) - - return {job: time for job, (time, _) in jobs_to_times.items()} +from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests def calculate_shards( @@ -99,63 +40,6 @@ def calculate_shards( return sharded_jobs -def _pull_job_times_from_S3() -> Dict[str, float]: - if HAVE_BOTO3: - ci_job_prefix = _get_stripped_CI_job() - s3_reports: List["Report"] = get_previous_reports_for_branch( - "origin/viable/strict", ci_job_prefix - ) - else: - print( - "Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser." - ) - print( - "If not installed, please install boto3 for automatic sharding and test categorization." - ) - s3_reports = [] - - if len(s3_reports) == 0: - print("Gathered no reports from S3. Please proceed without them.") - return dict() - - return _calculate_job_times(s3_reports) - - -def _query_past_job_times(test_times_file: Optional[str] = None) -> Dict[str, float]: - """Read historic test job times from a file. - - If the file doesn't exist or isn't matching current commit. It will download data from S3 and exported it. - """ - if test_times_file and os.path.exists(test_times_file): - with open(test_times_file) as file: - test_times_json: JobTimeJSON = json.load(file) - - curr_commit = subprocess.check_output( - ["git", "rev-parse", "HEAD"], encoding="ascii" - ).strip() - file_commit = test_times_json.get("commit", "") - curr_ci_job = _get_stripped_CI_job() - file_ci_job = test_times_json.get("JOB_BASE_NAME", "N/A") - if curr_commit != file_commit: - print(f"Current test times file is from different commit {file_commit}.") - elif curr_ci_job != file_ci_job: - print(f"Current test times file is for different CI job {file_ci_job}.") - else: - print( - f"Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values." - ) - return test_times_json.get("job_times", {}) - - # Found file, but commit or CI job in JSON doesn't match - print( - f"Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}" - ) - - job_times = export_S3_test_times(test_times_file) - - return job_times - - def _query_changed_test_files() -> List[str]: default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}" cmd = ["git", "diff", "--name-only", default_branch, "HEAD"] @@ -169,45 +53,6 @@ def _query_changed_test_files() -> List[str]: return lines -# Get sharded test allocation based on historic S3 data. -def get_shard_based_on_S3( - which_shard: int, num_shards: int, tests: List[str], test_times_file: str -) -> List[str]: - # Short circuit and don't do any work if there's only 1 shard - if num_shards == 1: - return tests - - jobs_to_times = _query_past_job_times(test_times_file) - - # Got no stats from S3, returning early to save runtime - if len(jobs_to_times) == 0: - print("Gathered no stats from S3. Proceeding with default sharding plan.") - return tests[which_shard - 1 :: num_shards] - - shards = calculate_shards(num_shards, tests, jobs_to_times) - _, tests_from_shard = shards[which_shard - 1] - return tests_from_shard - - -def get_slow_tests_based_on_S3( - test_list: List[str], td_list: List[str], slow_test_threshold: int -) -> List[str]: - """Get list of slow tests based on historic S3 data.""" - jobs_to_times: Dict[str, float] = _query_past_job_times() - - # Got no stats from S3, returning early to save runtime - if len(jobs_to_times) == 0: - print("Gathered no stats from S3. No new slow tests calculated.") - return [] - - slow_tests: List[str] = [] - for test in test_list: - if test in jobs_to_times and test not in td_list: - if jobs_to_times[test] > slow_test_threshold: - slow_tests.append(test) - return slow_tests - - def get_reordered_tests(tests: List[str]) -> List[str]: """Get the reordered test filename list based on github PR history or git changed file.""" prioritized_tests: List[str] = [] @@ -248,20 +93,6 @@ def get_reordered_tests(tests: List[str]) -> List[str]: return tests -# TODO Refactor this and unify with tools.stats.export_slow_tests -def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]: - test_times: Dict[str, float] = _pull_job_times_from_S3() - if test_times_filename is not None: - print(f"Exporting S3 test stats to {test_times_filename}.") - if os.path.exists(test_times_filename): - print(f"Overwriting existent file: {test_times_filename}") - with open(test_times_filename, "w+") as file: - job_times_json = _get_job_times_json(test_times) - json.dump(job_times_json, file, indent=" ", separators=(",", ": ")) - file.write("\n") - return test_times - - def get_test_case_configs(dirpath: str) -> None: get_slow_tests(dirpath=dirpath) get_disabled_tests(dirpath=dirpath) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index ea28278dfe8f6..05acd066f7f08 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -74,6 +74,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/gloo ${TORCH_ROOT}/third_party/onnx ${TORCH_ROOT}/third_party/flatbuffers/include + ${TORCH_ROOT}/third_party/kineto/libkineto/include ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include @@ -117,6 +118,13 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") -Wno-writable-strings) endif() +if(USE_ITT) + list(APPEND TORCH_PYTHON_SRCS + ${TORCH_SRC_DIR}/csrc/itt.cpp + ) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_ITT) +endif() + if(USE_CUDA) include(${TORCH_ROOT}/cmake/public/cuda.cmake) append_filelist("libtorch_python_cuda_core_sources" TORCH_PYTHON_SRCS) @@ -142,6 +150,10 @@ if(USE_ROCM) list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${roctracer_INCLUDE_DIRS}) endif() +if(USE_EXPERIMENTAL_CUDNN_V8_API) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API) +endif() + if(USE_CUDNN OR USE_ROCM) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp @@ -247,6 +259,9 @@ if(USE_DISTRIBUTED) if(USE_NCCL) list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) endif() + if(USE_UCC) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_ucc) + endif() # Same for MPI. if(USE_MPI) list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES}) @@ -280,6 +295,9 @@ if(USE_DEPLOY) if(USE_GLOO AND USE_C10D_GLOO) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO) endif() + if(USE_UCC AND USE_C10D_UCC) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_UCC) + endif() if(USE_NCCL AND USE_C10D_NCCL) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL) # Put nccl headers on the include path. We are specifically only setting diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1399da23379fa..d6ad4ec95a6f8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -13,7 +13,7 @@ from typing_extensions import Literal from torch._six import inf from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt -from torch.storage import _TypedStorage +from torch.storage import TypedStorage import builtins @@ -165,6 +165,12 @@ class Future(object): def _jit_set_num_profiled_runs(num: _size) -> _size: ... +class SymIntNode(object): + def get_pyobj(self) -> Any: ... + + @staticmethod + def new_symint(self) -> SymIntNode: ... + # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h class MobileOptimizerType: ... @@ -452,12 +458,34 @@ class _InsertPoint: def __enter__(self) -> None: ... def __exit__(self, *args) -> None: ... +# Defined in torch/csrc/jit/ir/ir.h +class Use: + @property + def user(self) -> Node: ... + @property + def offset(self) -> _int: ... + def isAfter(self, other: Use) -> _bool: ... + ... + # Defined in torch/csrc/jit/ir/ir.h class Value: def type(self)-> JitType: ... def setType(self, t: JitType) -> Value: ... + def setTypeAs(self, other: Value) -> Value: ... + def inferTypeFrom(self, t: Tensor) -> None: ... def debugName(self) -> str: ... + def setDebugName(self, name: str) -> None: ... + def unique(self) -> _int: ... + def offset(self) -> _int: ... + def node(self) -> Node: ... + def uses(self) -> List[Use]: ... + def replaceAllUsesWith(self, val: Value) -> None: ... + def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ... def requires_grad(self) -> _bool: ... + def requiresGrad(self) -> _bool: ... + def copyMetadata(self, other: Value) -> Value: ... + def isCompleteTensor(self) -> _bool: ... + def toIValue(self) -> IValue: ... ... # Defined in torch/csrc/jit/ir/ir.h @@ -477,12 +505,17 @@ class Node: def schema(self) -> str: ... def input(self) -> Value: ... def inputs(self) -> List[Value]: ... + def inputsAt(self, idx: _int) -> Value: ... + def inputsSize(self) -> _int: ... def output(self) -> Value: ... def outputs(self) -> List[Value]: ... + def outputsAt(self, idx: _int) -> Value: ... def outputsSize(self) -> _int: ... + def hasMultipleOutputs(self) -> _bool: ... def blocks(self) -> List[Block]: ... def addBlock(self) -> Block: ... def mustBeNone(self) -> _bool: ... + def matches(self, pattern: str) -> _bool: ... def kind(self) -> str: ... def kindOf(self, name: str) -> str: ... def addInput(self, name: str) -> Value: ... @@ -503,9 +536,47 @@ class Node: def scopeName(self) -> str: ... def isNondeterministic(self) -> _bool: ... def copyAttributes(self, rhs: Node) -> Node: ... - def hasAttributes(self, name: str) -> _bool: ... + def copyMetadata(self, rhs: Node) -> Node: ... + def hasAttributes(self) -> _bool: ... + def hasAttribute(self, name: str) -> _bool: ... + def removeAttribute(self, attr: str) -> Node: ... def namedInput(self, name: str) -> Value: ... def sourceRange(self) -> SourceRange: ... + def owningBlock(self) -> Block: ... + def findNode(self, kind: str, recurse: _bool = True) -> Node: ... + def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ... + def getModuleHierarchy(self) -> str: ... + def prev(self) -> Node: ... + def destroy(self) -> None: ... + + # Accessors for attributes as types. + def f(self, name: str) -> _float: ... + def f_(self, name: str, val: _float) -> Node: ... + def fs(self, name: str) -> List[_float]: ... + def fs_(self, name: str, val: List[_float]) -> Node: ... + def c(self, name: str) -> complex: ... + def c_(self, name: str, val: complex) -> Node: ... + def s(self, name: str) -> str: ... + def s_(self, name: str, val: str) -> Node: ... + def ss(self, name: str) -> List[str]: ... + def ss_(self, name: str, val: List[str]) -> Node: ... + def i(self, name: str) -> _int: ... + def i_(self, name: str, val: _int) -> Node: ... + # Cannot define "is" like this because it's a reserved keyword in python. + # def is(self, name: str) -> List[_int]: ... + # def is_(self, name: str, val: List[_int]) -> Node: ... + def g(self, name: str) -> Graph: ... + def g_(self, name: str, val: Graph) -> Node: ... + def gs(self, name: str) -> List[Graph]: ... + def gs_(self, name: str, val: List[Graph]) -> Node: ... + def ival(self, name: str) -> IValue: ... + def ival_(self, name: str, val: IValue) -> Node: ... + def t(self, name: str) -> Tensor: ... + def t_(self, name: str, val: Tensor) -> Node: ... + def ts(self, name: str) -> List[Tensor]: ... + def ts_(self, name: str, val: List[Tensor]) -> Node: ... + def ty_(self, name: str, val: JitType) -> Node: ... + def tys_(self, name: str, val: List[JitType]) -> Node: ... ... # Defined in torch/torch/csrc/jit/ir/ir.h @@ -732,6 +803,8 @@ def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModul def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... +def _add_meta_to_tls_dispatch_include() -> None: ... +def _remove_meta_from_tls_dispatch_include() -> None: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack @@ -834,6 +907,8 @@ class FileCheck(object): # TODO (add more FileCheck signature) def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... def run(self, test_string: str) -> None: ... + def check(self, test_string: str) -> 'FileCheck': ... + def check_not(self, test_string: str) -> 'FileCheck': ... ... # Defined in torch/csrc/jit/python/init.cpp @@ -988,6 +1063,8 @@ def _cuda_jiterator_compile_and_launch_kernel(code_string: str, num_outputs: _int, tensors: Tuple, kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ... +def _cuda_get_cudnn_benchmark_limit() -> _int: ... +def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ... def _nccl_version() -> _int: ... def _nccl_unique_id() -> bytes: ... def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ... @@ -1120,6 +1197,7 @@ class JitType: def with_dtype(self, dtype: _dtype) -> JitType: ... def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ... def kind(self) -> str: ... + def scalarType(self) -> Optional[str]: ... class InferredType: def __init__(self, arg: Union[JitType, str]): ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index b5654d997708d..34f7766b42b4c 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,4 +1,4 @@ -from typing import List, Set, Callable, Any +from typing import List, Set, Callable, Any, Union from enum import Enum import torch @@ -10,6 +10,7 @@ class ProfilerState(Enum): CPU = ... CUDA = ... NVTX = ... + ITT = ... KINETO = ... KINETO_GPU_FALLBACK = ... @@ -85,16 +86,58 @@ class _KinetoEvent: ... class _ProfilerEvent: + tag: _EventType id: int correlation_id: int + start_tid: int start_time_ns: int end_time_ns: int duration_time_ns: int parent: _ProfilerEvent children: List[_ProfilerEvent] + extra_fields: Union[_ExtraFields_Allocation, _ExtraFields_Backend, + _ExtraFields_PyCall, _ExtraFields_PyCCall, + _ExtraFields_TorchOp] def name(self) -> str: ... ... +class _PyFrameState: + line_number: int + function_name: str + file_name: str + ... + +class _EventType(Enum): + Allocation = ... + Backend = ... + PyCall = ... + PyCCall = ... + TorchOp = ... + Kineto = ... + +class _Inputs: + shapes: List[List[int]] + dtypes: List[str] + +class _ExtraFields_TorchOp: + allow_tf32_cublas: bool + inputs: _Inputs + ... + +class _ExtraFields_Backend: + ... + +class _ExtraFields_Allocation: + ... + +class _ExtraFields_PyCCall: + caller: _PyFrameState + ... + +class _ExtraFields_PyCall: + caller: _PyFrameState + ... + class _ProfilerResult: def events(self) -> List[_KinetoEvent]: ... def legacy_events(self) -> List[List[ProfilerEvent]]: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6192b1f043889..64126fd423330 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -371,6 +371,15 @@ class ProcessGroupNCCL(ProcessGroup): def _group_end() -> None: ... ... +class ProcessGroupUCC(ProcessGroup): + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + class ProcessGroupMPI(ProcessGroup): def __init__( self, diff --git a/torch/_C/_itt.pyi b/torch/_C/_itt.pyi new file mode 100644 index 0000000000000..8de715bc85a98 --- /dev/null +++ b/torch/_C/_itt.pyi @@ -0,0 +1,4 @@ +# Defined in torch/csrc/itt.cpp +def rangePush(message: str) -> None: ... +def rangePop() -> None: ... +def mark(message: str) -> None: ... diff --git a/torch/_C/_lazy.pyi b/torch/_C/_lazy.pyi index e86b80837d589..808dbdd11a4b4 100644 --- a/torch/_C/_lazy.pyi +++ b/torch/_C/_lazy.pyi @@ -7,6 +7,7 @@ def _wait_device_ops(devices: List[str]): ... def _reset_metrics(): ... def _counter_names() -> List[str]: ... def _counter_value(name: str) -> int: ... +def _metrics_report() -> str: ... def _get_graph_hash(tensors: List[Tensor]) -> str: ... def _sync_multi(tensors: List[Tensor], devices: List[str], wait: bool = True, sync_ltc_data: bool = True): ... def _get_tensor_id(tensor: Tensor) -> int: ... diff --git a/torch/_C/_verbose.pyi b/torch/_C/_verbose.pyi new file mode 100644 index 0000000000000..2388ce2bb8a5e --- /dev/null +++ b/torch/_C/_verbose.pyi @@ -0,0 +1,3 @@ +# Defined in torch/csrc/utils/verbose.cpp +def mkl_set_verbose(enable: int) -> int: ... +def mkldnn_set_verbose(level: int) -> int: ... diff --git a/torch/_VF.py b/torch/_VF.py index f42c731576acb..b0b6c1dd85b46 100644 --- a/torch/_VF.py +++ b/torch/_VF.py @@ -10,10 +10,11 @@ introducing torch._VF """ -import torch import sys import types +import torch + class VFModule(types.ModuleType): vf: types.ModuleType diff --git a/torch/__config__.py b/torch/__config__.py index edddcbce46459..f7e3e209654a8 100644 --- a/torch/__config__.py +++ b/torch/__config__.py @@ -8,6 +8,7 @@ def show(): """ return torch._C._show_config() + # TODO: In principle, we could provide more structured version/config # information here. For now only CXX_FLAGS is exposed, as Timer # uses them. @@ -15,6 +16,7 @@ def _cxx_flags(): """Returns the CXX_FLAGS used when building PyTorch.""" return torch._C._cxx_flags() + def parallel_info(): r"""Returns detailed string with parallelization settings""" return torch._C._parallel_info() diff --git a/torch/__future__.py b/torch/__future__.py index 789ec655caa9e..9ac8406e8f8ea 100644 --- a/torch/__future__.py +++ b/torch/__future__.py @@ -11,9 +11,11 @@ """ _overwrite_module_params_on_conversion = False + def set_overwrite_module_params_on_conversion(value): global _overwrite_module_params_on_conversion _overwrite_module_params_on_conversion = value + def get_overwrite_module_params_on_conversion(): return _overwrite_module_params_on_conversion diff --git a/torch/__init__.py b/torch/__init__.py index 0448734fd9771..80853c97f562a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -40,7 +40,7 @@ 'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode', 'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage', 'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage', - '_TypedStorage', + 'TypedStorage', 'UntypedStorage', 'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor', 'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor', 'lobpcg', 'use_deterministic_algorithms', @@ -493,14 +493,14 @@ def use_deterministic_algorithms(mode, *, warn_only=False): >>> torch.use_deterministic_algorithms(True) # Forward mode nondeterministic error - >>> torch.randn(10).index_copy(0, torch.tensor([0]), torch.randn(1)) + >>> torch.randn(10, device='cuda').kthvalue(0) ... - RuntimeError: index_copy does not have a deterministic implementation... + RuntimeError: kthvalue CUDA does not have a deterministic implementation... # Backward mode nondeterministic error - >>> torch.randn(10, requires_grad=True, device='cuda').index_select(0, torch.tensor([0], device='cuda')).backward() + >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward() ... - RuntimeError: index_add_cuda_ does not have a deterministic implementation... + RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation... """ _C._set_deterministic_algorithms(mode, warn_only=warn_only) @@ -656,10 +656,10 @@ def is_warn_always_enabled(): ################################################################################ from ._tensor import Tensor -from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage +from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage # NOTE: New Storage classes should never be added. When adding a new -# dtype, use torch.storage._TypedStorage directly. +# dtype, use torch.storage.TypedStorage directly. class ByteStorage(_LegacyStorage): @classproperty @@ -747,11 +747,11 @@ def dtype(self): return torch.quint2x4 _storage_classes = { - _UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage, + UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage, ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage, - _TypedStorage + TypedStorage } # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings() diff --git a/torch/_appdirs.py b/torch/_appdirs.py index ab3ea14a96fb8..9395d9fb9f5c7 100644 --- a/torch/_appdirs.py +++ b/torch/_appdirs.py @@ -50,28 +50,28 @@ __version_info__ = tuple(int(segment) for segment in __version__.split(".")) -import sys import os +import sys unicode = str -if sys.platform.startswith('java'): +if sys.platform.startswith("java"): import platform + os_name = platform.java_ver()[3][0] - if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc. - system = 'win32' - elif os_name.startswith('Mac'): # "Mac OS X", etc. - system = 'darwin' - else: # "Linux", "SunOS", "FreeBSD", etc. + if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc. + system = "win32" + elif os_name.startswith("Mac"): # "Mac OS X", etc. + system = "darwin" + else: # "Linux", "SunOS", "FreeBSD", etc. # Setting this to "linux2" is not ideal, but only Windows or Mac # are actually checked for and the rest of the module expects # *sys.platform* style strings. - system = 'linux2' + system = "linux2" else: system = sys.platform - def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): r"""Return full path to the user-specific data dir for this application. @@ -114,12 +114,12 @@ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): path = os.path.join(path, appauthor, appname) else: path = os.path.join(path, appname) - elif system == 'darwin': - path = os.path.expanduser('~/Library/Application Support/') + elif system == "darwin": + path = os.path.expanduser("~/Library/Application Support/") if appname: path = os.path.join(path, appname) else: - path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share")) + path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share")) if appname: path = os.path.join(path, appname) if appname and version: @@ -167,16 +167,19 @@ def site_data_dir(appname=None, appauthor=None, version=None, multipath=False): path = os.path.join(path, appauthor, appname) else: path = os.path.join(path, appname) - elif system == 'darwin': - path = os.path.expanduser('/Library/Application Support') + elif system == "darwin": + path = os.path.expanduser("/Library/Application Support") if appname: path = os.path.join(path, appname) else: # XDG default for $XDG_DATA_DIRS # only first, if multipath is False - path = os.getenv('XDG_DATA_DIRS', - os.pathsep.join(['/usr/local/share', '/usr/share'])) - pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] + path = os.getenv( + "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"]) + ) + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] if appname: if version: appname = os.path.join(appname, version) @@ -224,12 +227,12 @@ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False): """ if system == "win32": path = user_data_dir(appname, appauthor, None, roaming) - elif system == 'darwin': - path = os.path.expanduser('~/Library/Preferences/') + elif system == "darwin": + path = os.path.expanduser("~/Library/Preferences/") if appname: path = os.path.join(path, appname) else: - path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config")) + path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) if appname: path = os.path.join(path, appname) if appname and version: @@ -267,19 +270,21 @@ def site_config_dir(appname=None, appauthor=None, version=None, multipath=False) WARNING: Do not use this on Windows. See the Vista-Fail note above for why. """ - if system == 'win32': + if system == "win32": path = site_data_dir(appname, appauthor) if appname and version: path = os.path.join(path, version) - elif system == 'darwin': - path = os.path.expanduser('/Library/Preferences') + elif system == "darwin": + path = os.path.expanduser("/Library/Preferences") if appname: path = os.path.join(path, appname) else: # XDG default for $XDG_CONFIG_DIRS # only first, if multipath is False - path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg') - pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] + path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg") + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] if appname: if version: appname = os.path.join(appname, version) @@ -336,12 +341,12 @@ def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): path = os.path.join(path, appname) if opinion: path = os.path.join(path, "Cache") - elif system == 'darwin': - path = os.path.expanduser('~/Library/Caches') + elif system == "darwin": + path = os.path.expanduser("~/Library/Caches") if appname: path = os.path.join(path, appname) else: - path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache')) + path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) if appname: path = os.path.join(path, appname) if appname and version: @@ -383,7 +388,7 @@ def user_state_dir(appname=None, appauthor=None, version=None, roaming=False): if system in ["win32", "darwin"]: path = user_data_dir(appname, appauthor, None, roaming) else: - path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state")) + path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state")) if appname: path = os.path.join(path, appname) if appname and version: @@ -424,9 +429,7 @@ def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): This can be disabled with the `opinion=False` option. """ if system == "darwin": - path = os.path.join( - os.path.expanduser('~/Library/Logs'), - appname) + path = os.path.join(os.path.expanduser("~/Library/Logs"), appname) elif system == "win32": path = user_data_dir(appname, appauthor, version) version = False @@ -444,8 +447,10 @@ def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): class AppDirs(object): """Convenience wrapper for getting application dirs.""" - def __init__(self, appname=None, appauthor=None, version=None, - roaming=False, multipath=False): + + def __init__( + self, appname=None, appauthor=None, version=None, roaming=False, multipath=False + ): self.appname = appname self.appauthor = appauthor self.version = version @@ -454,41 +459,43 @@ def __init__(self, appname=None, appauthor=None, version=None, @property def user_data_dir(self): - return user_data_dir(self.appname, self.appauthor, - version=self.version, roaming=self.roaming) + return user_data_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) @property def site_data_dir(self): - return site_data_dir(self.appname, self.appauthor, - version=self.version, multipath=self.multipath) + return site_data_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) @property def user_config_dir(self): - return user_config_dir(self.appname, self.appauthor, - version=self.version, roaming=self.roaming) + return user_config_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) @property def site_config_dir(self): - return site_config_dir(self.appname, self.appauthor, - version=self.version, multipath=self.multipath) + return site_config_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) @property def user_cache_dir(self): - return user_cache_dir(self.appname, self.appauthor, - version=self.version) + return user_cache_dir(self.appname, self.appauthor, version=self.version) @property def user_state_dir(self): - return user_state_dir(self.appname, self.appauthor, - version=self.version) + return user_state_dir(self.appname, self.appauthor, version=self.version) @property def user_log_dir(self): - return user_log_dir(self.appname, self.appauthor, - version=self.version) + return user_log_dir(self.appname, self.appauthor, version=self.version) + +# ---- internal support stuff -#---- internal support stuff def _get_win_folder_from_registry(csidl_name): """This is a fallback technique at best. I'm not sure if using the @@ -505,14 +512,15 @@ def _get_win_folder_from_registry(csidl_name): key = _winreg.OpenKey( _winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", ) dir, type = _winreg.QueryValueEx(key, shell_folder_name) return dir def _get_win_folder_with_pywin32(csidl_name): - from win32com.shell import shellcon, shell + from win32com.shell import shell, shellcon + dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0) # Try to make this a unicode path because SHGetFolderPath does # not return unicode strings when there is unicode data in the @@ -530,6 +538,7 @@ def _get_win_folder_with_pywin32(csidl_name): if has_high_char: try: import win32api + dir = win32api.GetShortPathName(dir) except ImportError: pass @@ -564,15 +573,23 @@ def _get_win_folder_with_ctypes(csidl_name): return buf.value + def _get_win_folder_with_jna(csidl_name): import array + from com.sun import jna from com.sun.jna.platform import win32 buf_size = win32.WinDef.MAX_PATH * 2 - buf = array.zeros('c', buf_size) + buf = array.zeros("c", buf_size) shell = win32.Shell32.INSTANCE - shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf) + shell.SHGetFolderPath( + None, + getattr(win32.ShlObj, csidl_name), + None, + win32.ShlObj.SHGFP_TYPE_CURRENT, + buf, + ) dir = jna.Native.toString(buf.tostring()).rstrip("\0") # Downgrade to short path name if have highbit chars. See @@ -583,42 +600,48 @@ def _get_win_folder_with_jna(csidl_name): has_high_char = True break if has_high_char: - buf = array.zeros('c', buf_size) + buf = array.zeros("c", buf_size) kernel = win32.Kernel32.INSTANCE if kernel.GetShortPathName(dir, buf, buf_size): dir = jna.Native.toString(buf.tostring()).rstrip("\0") return dir + if system == "win32": try: import win32com.shell + _get_win_folder = _get_win_folder_with_pywin32 except ImportError: try: from ctypes import windll + _get_win_folder = _get_win_folder_with_ctypes except ImportError: try: import com.sun.jna + _get_win_folder = _get_win_folder_with_jna except ImportError: _get_win_folder = _get_win_folder_from_registry -#---- self test code +# ---- self test code if __name__ == "__main__": appname = "MyApp" appauthor = "MyCompany" - props = ("user_data_dir", - "user_config_dir", - "user_cache_dir", - "user_state_dir", - "user_log_dir", - "site_data_dir", - "site_config_dir") + props = ( + "user_data_dir", + "user_config_dir", + "user_cache_dir", + "user_state_dir", + "user_log_dir", + "site_data_dir", + "site_config_dir", + ) print("-- app dirs %s --" % __version__) diff --git a/torch/_classes.py b/torch/_classes.py index f36463d881987..3de7c9e1a2bec 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -1,22 +1,25 @@ import types + import torch._C + class _ClassNamespace(types.ModuleType): def __init__(self, name): - super(_ClassNamespace, self).__init__('torch.classes' + name) + super(_ClassNamespace, self).__init__("torch.classes" + name) self.name = name def __getattr__(self, attr): proxy = torch._C._get_custom_class_python_wrapper(self.name, attr) if proxy is None: - raise RuntimeError(f'Class {self.name}.{attr} not registered!') + raise RuntimeError(f"Class {self.name}.{attr} not registered!") return proxy + class _Classes(types.ModuleType): - __file__ = '_classes.py' + __file__ = "_classes.py" def __init__(self): - super(_Classes, self).__init__('torch.classes') + super(_Classes, self).__init__("torch.classes") def __getattr__(self, name): namespace = _ClassNamespace(name) @@ -47,5 +50,6 @@ def load_library(self, path): """ torch.ops.load_library(path) + # The classes "namespace" classes = _Classes() diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 0bed70c23f2a4..9f9c4f1404f3e 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,9 +1,13 @@ +import inspect +from collections import defaultdict +from functools import wraps +from itertools import chain +from typing import Callable, Dict, NamedTuple, Sequence, Tuple, Union + import torch import torch._ops import torch.library -from typing import Callable, Union, Dict, Sequence from torch.utils._pytree import tree_map -from collections import defaultdict __all__ = ["decomposition_table", "register_decomposition", "get_decompositions"] @@ -36,7 +40,50 @@ def clamp_min(x): a Meta implementation, we will register it to the dispatcher. Use `disable_meta` to disable this behavior. """ - def decomposition_decorator(f): + + def decomposition_decorator(f: Callable) -> Callable: + sig = inspect.signature(f) + out_annotation = f.__annotations__.get("out") + # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this + fn = f + if out_annotation and getattr(out_annotation, "__origin__", None) is tuple: + out_names = sig.return_annotation._fields + # If out is a tuple, we need to register a function that unpacks all the out + # elements as this is what native_functions.yaml expects + + @wraps(f) + def _fn(*args, **kwargs): + out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) + # Either all of the out kwargs are set or none of them + is_none = out_kwargs[0] is None + assert all((o is None) == is_none for o in out_kwargs) + return f(*args, **kwargs, out=None if is_none else out_kwargs) + + out_params = [ + inspect.Parameter( + o, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=t, + ) + for o, t in zip(out_names, out_annotation.__args__) + ] + # Drop the out parameter and concatenate the new kwargs in the signature + params = chain( + (v for k, v in sig.parameters.items() if k != "out"), out_params + ) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + ) + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = { + k: v for k, v in f.__annotations__.items() if k != "out" + } + for o in out_params: + _fn.__annotations__[o.name] = o.annotation + + fn = _fn + nonlocal registry if registry is None: registry = decomposition_table @@ -52,7 +99,7 @@ def add_op_to_table(aten_op): for op_overload in overloads: if op_overload in registry: raise RuntimeError(f"duplicate registrations for {op_overload}") - registry[op_overload] = f + registry[op_overload] = fn # TODO: factor this logic into OpOverload or Library API name = op_overload._schema.name if op_overload._schema.overload_name: @@ -63,17 +110,31 @@ def add_op_to_table(aten_op): # which don't have corresponding dispatcher entries, we need # to filter those out and torch._C._dispatch_has_kernel(name) - and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta') # Don't register a meta kernel to any operator that has # a CompositeImplicitAutograd kernel in core. # Otherwise we won't be able to run autograd for that operator with the meta backend. - and 'CompositeImplicitAutograd' not in torch._C._dispatch_dump(name) + and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name) + and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta") ): - meta_lib.impl(op_overload, f) + if any( + a.alias_info is not None and not a.alias_info.is_write + for a in op_overload._schema.arguments + ): + raise RuntimeError( + f""" +Attempting to register a python meta kernel for a view operator: {str(op_overload)}. +We shouldn't do this, because the output will report as not having aliased storages. +All view ops have meta kernels in C++ today, so we should use those instead. + +If you're registering an operator through the `@register_decomposition` decorator, +Please set `disable_meta=True`. + """ + ) + meta_lib.impl(op_overload, fn) # To handle allowing multiple aten_ops at once tree_map(add_op_to_table, aten_op) - return f + return fn return decomposition_decorator @@ -104,6 +165,7 @@ def get_decompositions( decompositions[op] = decomposition_table[op] return decompositions + # populate the table import torch._decomp.decompositions import torch._refs diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 153b94510d85f..4b177d99ce605 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1,13 +1,14 @@ +import functools +from enum import Enum +from typing import Callable, List, Optional, Tuple + import torch +import torch._prims_common as utils +import torch.nn.functional as F from torch import Tensor from torch._decomp import register_decomposition -from enum import Enum -from typing import Tuple, Optional, List, Callable -import torch.nn.functional as F -import functools -from torch.utils._pytree import tree_map, tree_flatten -import torch._prims.utils as utils -from torch._prims.wrappers import out_wrapper_multi +from torch._prims_common.wrappers import out_wrapper +from torch.utils._pytree import tree_flatten, tree_map # None of these functions are publicly accessible; get at them # from torch._decomps @@ -28,9 +29,12 @@ class Reduction(Enum): def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND): @functools.wraps(f) def inner(*args, **kwargs): - flat_args = [x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)] - computation_dtype, result_dtype = utils.elementwise_dtypes(*flat_args, - type_promotion_kind=type_promotion) + flat_args = [ + x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor) + ] + computation_dtype, result_dtype = utils.elementwise_dtypes( + *flat_args, type_promotion_kind=type_promotion + ) # TODO: pretty sure this is not quite right def increase_prec(x): @@ -50,9 +54,16 @@ def decrease_prec(x): return inner -pw_cast_for_opmath = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) -reduction_complex_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT) -pw_cast_for_int_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) + +pw_cast_for_opmath = functools.partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +) +reduction_complex_to_real = functools.partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT +) +pw_cast_for_int_to_real = functools.partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) # This expands x until x.dim() == dim. Might be useful as an operator def _unsqueeze_to_dim(x: Tensor, dim: int): @@ -79,6 +90,7 @@ def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float z = (x * beta).exp() return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) + @register_decomposition(aten.elu) @pw_cast_for_opmath def elu( @@ -189,7 +201,7 @@ def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"): M_SQRT2 = 1.41421356237309504880 M_SQRT1_2 = 0.70710678118654752440 M_2_SQRTPI = 1.12837916709551257390 - if approximate == 'tanh': + if approximate == "tanh": kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 kKappa = 0.044715 x_sq = self * self @@ -281,7 +293,9 @@ def rrelu_with_noise_backward( return grad_output.mul(noise) else: negative_slope = (lower + upper) / 2 - return aten.leaky_relu_backward(grad_output, self, negative_slope, self_is_result) + return aten.leaky_relu_backward( + grad_output, self, negative_slope, self_is_result + ) @register_decomposition(aten.log_sigmoid_backward) @@ -313,11 +327,13 @@ def to_real_dtype(dtype: torch.dtype): elif dtype == torch.complex128: return torch.float64 + # TODO: None of these loss castings are quite correct, see # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels # perform the pointwise portion in opmath, but don't maintain it between the # pointwise portion and the reduction + @register_decomposition(aten.mse_loss) @pw_cast_for_opmath def mse_loss( @@ -365,13 +381,13 @@ def huber_loss_backward( def _nll_loss_backward( - grad_output: Tensor, - self: Tensor, - target: Tensor, - weight: Optional[Tensor], - reduction: int, - ignore_index: int, - total_weight: Tensor, + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, ) -> Tensor: channel_dim = 0 if self.dim() < 2 else 1 if reduction == Reduction.MEAN.value: @@ -404,12 +420,16 @@ def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor: assert self.dim() > 0, "glu does not support 0-dimensional tensors" wrap_dim = utils.canonicalize_dim(self.dim(), dim) nIn = self.size(wrap_dim) - assert nIn % 2 == 0, f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" + assert ( + nIn % 2 == 0 + ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" inputSize = nIn // 2 firstHalf = self.narrow(wrap_dim, 0, inputSize) secondHalf = self.narrow(wrap_dim, inputSize, inputSize) gradInputFirstHalf = torch.sigmoid(secondHalf) - gradInputSecondHalf = (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + gradInputSecondHalf = ( + (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + ) gradInputFirstHalf = gradInputFirstHalf * grad_output return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim) @@ -452,7 +472,9 @@ def nll_loss_backward( grad_output.dim() <= 1 and grad_output.numel() == 1 ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}" - return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight) + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) @register_decomposition(aten.nll_loss2d_backward) @@ -473,18 +495,20 @@ def nll_loss2d_backward( target.dim() == 3 ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}" - assert( - self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2] + assert ( + self.shape[0] == target.shape[0] + and self.shape[2] == target.shape[1] + and self.shape[3] == target.shape[2] ), f"size mismatch (got input: {self.shape}, target: {target.shape}" - assert ( - total_weight.numel() == 1 - ), ( + assert total_weight.numel() == 1, ( "expected total_weight to be a single element tensor, " f"got: {total_weight.shape} ( {total_weight.numel()}, elements)" ) - return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight) + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) @register_decomposition(aten.binary_cross_entropy) @@ -725,6 +749,7 @@ def embedding( return weight.index_select(0, indices.reshape(-1)).view(size) + # TODO: Correct the type promotion semantics @register_decomposition(aten.embedding_dense_backward) def embedding_dense_backward( @@ -735,7 +760,7 @@ def embedding_dense_backward( scale_grad_by_freq: bool, ): numel = indices.numel() - grad = grad_output.view(numel, grad_output.size(-1)) + grad = grad_output.reshape(numel, grad_output.size(-1)) grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1])) indices_rank1 = indices.reshape(numel) if scale_grad_by_freq: @@ -759,7 +784,7 @@ def prod(x: List[int]): return r -@register_decomposition(aten.split_with_sizes) +@register_decomposition(aten.split_with_sizes, disable_meta=True) def split_with_sizes( self: Tensor, split_sizes: List[int], dim: int = 0 ) -> List[Tensor]: @@ -773,7 +798,7 @@ def split_with_sizes( return splits -@register_decomposition(aten.split.Tensor) +@register_decomposition(aten.split.Tensor, disable_meta=True) def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: input_sizes = self.shape dim_size = input_sizes[dim] @@ -798,6 +823,7 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = return out return beta * self + out + # This computes the mean and variance along the specifized normalization dims, # then normalizes along those dims. Finally, it returns the mean and variance of # the normalized dims. Note that it intentionally leaves outputs upcasted. @@ -811,7 +837,7 @@ def normalize(input, norm_dims, eps): mean = torch.mean(input_acc, dim=norm_dims, keepdim=True) rstd = torch.rsqrt(biased_var + eps) - out = ((input - mean) * rstd) + out = (input - mean) * rstd return out, mean, rstd @@ -913,9 +939,7 @@ def native_layer_norm_backward( if output_mask[1] and weight_cast is not None: if len(outer_dim_indices) > 0: - d_weight = torch.sum( - grad_out_cast * x_hat, outer_dim_indices, False - ) + d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False) else: d_weight = grad_out_cast * x_hat @@ -925,7 +949,11 @@ def native_layer_norm_backward( else: d_bias = grad_out_cast - return _maybe_cast(d_input, input.dtype), _maybe_cast(d_weight, input.dtype), _maybe_cast(d_bias, input.dtype) + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + _maybe_cast(d_bias, input.dtype), + ) @register_decomposition(aten.native_batch_norm) @@ -953,7 +981,9 @@ def native_batch_norm( # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose # numerics probably don't matter. - unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (n / (n - 1)) + unbiased_var = torch.var(input, reduction_dims, unbiased=False) * ( + n / (n - 1) + ) running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var) else: assert running_mean is not None and running_var is not None @@ -970,7 +1000,7 @@ def native_batch_norm( save_rstd = input.new_zeros((0,)) mean = _unsqueeze_to_dim(mean, input.dim() - 1) invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) - output = ((input - mean) * invstd) + output = (input - mean) * invstd if weight is None: weight = input.new_ones(()) @@ -981,22 +1011,12 @@ def native_batch_norm( weight = _unsqueeze_to_dim(weight, input.dim() - 1) bias = _unsqueeze_to_dim(bias, input.dim() - 1) output = output * weight + bias - if input.device.type == 'cpu': + if input.device.type == "cpu": save_mean = save_mean.to(dtype=input.dtype) save_rstd = save_rstd.to(dtype=input.dtype) return output.to(dtype=input.dtype), save_mean, save_rstd -@register_decomposition(aten.clamp_min) -def clamp_min(self: Tensor, min: float): - return torch.clamp(self, min=min) - - -@register_decomposition(aten.clamp_max) -def clamp_max(self: Tensor, max: float): - return torch.clamp(self, max=max) - - @register_decomposition(aten._fused_dropout) @pw_cast_for_opmath def _fused_dropout_decomposition(input, p, generator=None): @@ -1008,11 +1028,15 @@ def _fused_dropout_decomposition(input, p, generator=None): @register_decomposition(aten.xlogy.Tensor) @pw_cast_for_int_to_real def xlogy(self: Tensor, other: Tensor) -> Tensor: - return aten.where(aten.isnan(self), - self, - aten.where(self == aten.new_zeros(self, ()), - aten.new_zeros(self, ()), - self * aten.log(other))) + return aten.where( + aten.isnan(self), + self, + aten.where( + self == aten.new_zeros(self, ()), + aten.new_zeros(self, ()), + self * aten.log(other), + ), + ) @register_decomposition(aten.var.correction) @@ -1067,8 +1091,10 @@ def std_decomposition( # Questionable decompositions # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. # Note that this decomposition causes issues with in-place ops -@register_decomposition(aten.detach, disable_meta=True) -def detach_decomposition(x): +@register_decomposition( + [aten.detach, aten.lift, aten.lift_fresh, aten.alias], disable_meta=True +) +def nop_decomposition(x): return x @@ -1096,7 +1122,106 @@ def cudnn_batch_norm( # Cudnn return running mean and variance when training is True if training: return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) - return (a, input.new_zeros((0,)), input.new_zeros((0,)), input.new_zeros((0,), dtype=torch.uint8)) + return ( + a, + input.new_zeros((0,)), + input.new_zeros((0,)), + input.new_zeros((0,), dtype=torch.uint8), + ) + + +@register_decomposition(aten.native_batch_norm_backward) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_dtype = input.dtype + computation_dtype = utils.get_computation_dtype(input.dtype) + ( + grad_out_cast, + input_cast, + weight_cast, + running_mean_cast, + running_var_cast, + save_mean_cast, + save_invstd_cast, + ) = [ + x.to(computation_dtype) if x is not None else x + for x in ( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + ) + ] + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(list(input_shape)) / input_shape[axis] + mean = save_mean_cast + invstd = save_invstd_cast + if train: + assert save_mean_cast is not None and save_invstd_cast is not None + else: + assert running_mean_cast is not None and running_var_cast is not None + mean = running_mean_cast + invstd = torch.rsqrt(running_var_cast + eps) + + broadcast_mask: List[int] = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) # type: ignore[arg-type] + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] + dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator] + + if weight_cast is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] + else: + grad_scale = torch.reshape(invstd * weight_cast, broadcast_mask) + + if train: + proj = (input_cast - mean) * proj_scale + grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out_cast * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + else: + grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp + + return ( + grad_input.to(input_dtype), + _maybe_cast(grad_weight, input_dtype), + _maybe_cast(grad_bias, input_dtype), + ) @register_decomposition(aten.cudnn_batch_norm_backward) @@ -1125,7 +1250,7 @@ def cudnn_batch_norm_backward( ) -@register_decomposition(aten.transpose.int) +@register_decomposition(aten.transpose.int, disable_meta=True) def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor: dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1)) # type: ignore[misc] @@ -1156,14 +1281,16 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor: return torch.sum(torch.exp(self), dim, keepdim).log() maxes = torch.amax(self, dim, keepdim=True) maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim) - maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0) + maxes_squeezed = torch.masked_fill( + maxes_squeezed, maxes_squeezed.abs() == float("inf"), 0 + ) result = torch.sum(torch.exp(self - maxes), dim, keepdim) return result.log().add(maxes_squeezed) # nb: Should use acc_t, not op_math @register_decomposition(aten.log_sigmoid_forward) -@out_wrapper_multi('output', 'buffer') +@out_wrapper("output", "buffer") @pw_cast_for_opmath def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: min = torch.minimum(self.new_zeros(()), self) @@ -1174,33 +1301,88 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: buffer = z return min - torch.log1p(z), buffer -# The implementation matches torch.ops.aten.norm -# torch.ops.aten.norm only supports numeric p, does not support Frobenius norm or nuclear norm -# For 2-norm and -2 matrix norm, it doesn't compute the singular values, it just compute the norm the same as when p > 2. -@register_decomposition([aten.norm.Scalar, aten.norm.ScalarOpt_dim]) + +@register_decomposition(aten.norm) +@out_wrapper() @reduction_complex_to_real -def norm(self: Tensor, p: float = 2, dim: List[int] = None, keepdim: bool = False): - if dim is None: - dim = [] - - if p == 0: - return (self != 0).sum(dim, keepdim=keepdim) - elif p == float('inf'): - return self.abs().amax(dim, keepdim=keepdim) - elif p == -float('inf'): - return self.abs().amin(dim, keepdim=keepdim) - - def fast_pow(x, ord): - if ord == 1.0: - return x - elif ord == 2.0: - return x.square() - elif ord == 0.5: - return x.sqrt() +def norm( + self: Tensor, + p: Optional[float] = None, + dim: List[int] = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, +): + if p is None: + p = 2.0 + return torch.linalg.vector_norm(self, p, dim, keepdim, dtype=dtype) + + +@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) +@pw_cast_for_opmath +def upsample_bilinear2d_vec( + input: Tensor, + output_size: Optional[List[int]], + align_corners: bool, + scale_factors: Optional[List[float]], +) -> Tensor: + # get dimensions of original image + n_batch, n_channels, in_h, in_w = input.shape + + if output_size is not None: + out_h = float(output_size[0]) + out_w = float(output_size[1]) + elif scale_factors is not None: + out_h = in_h * scale_factors[0] + out_w = in_w * scale_factors[1] + + # Calculate horizontal and vertical scaling factor + if out_h > 1: + if align_corners: + h_scale_factor = (in_h - 1) / (int(out_h) - 1) else: - return x.pow(ord) + h_scale_factor = in_h / out_h + else: + h_scale_factor = 0.0 - if not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype)): - self = self.abs() + if out_w > 1: + if align_corners: + w_scale_factor = (in_w - 1) / (int(out_w) - 1) + else: + w_scale_factor = in_w / out_w + else: + w_scale_factor = 0.0 + + i = torch.arange(int(out_h), dtype=input.dtype, device=input.device) + j = torch.arange(int(out_w), dtype=input.dtype, device=input.device) + + if align_corners: + x = h_scale_factor * i + y = w_scale_factor * j + else: + x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0) + y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0) + + x_floor = torch.floor(x).to(torch.int64) + x_ceil = torch.ceil(x).clamp(max=in_h - 1).to(torch.int64) + y_floor = torch.floor(y).to(torch.int64) + y_ceil = torch.ceil(y).clamp(max=in_w - 1).to(torch.int64) - return fast_pow(fast_pow(self, p).sum(dim, keepdim=keepdim), 1.0 / p) + x_view = x.unsqueeze(1) + x_floor_view = x_floor.unsqueeze(1) + x_ceil_view = x_ceil.unsqueeze(1) + + v1 = input[:, :, x_floor_view, y_floor] + v2 = input[:, :, x_ceil_view, y_floor] + v3 = input[:, :, x_floor_view, y_ceil] + v4 = input[:, :, x_ceil_view, y_ceil] + + xscale2 = x_view - x_floor_view + xscale1 = 1.0 - xscale2 + + yscale2 = y - y_floor + yscale1 = 1.0 - yscale2 + + q1 = torch.mul(v1, xscale1) + torch.mul(v2, xscale2) + q2 = torch.mul(v3, xscale1) + torch.mul(v4, xscale2) + result = torch.mul(q1, yscale1) + torch.mul(q2, yscale2) + return result diff --git a/torch/_deploy.py b/torch/_deploy.py index 4cdb6f6f92e10..53769538b6c11 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -1,10 +1,12 @@ import io + import torch +from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer from torch.package._package_pickler import create_pickler from torch.package._package_unpickler import PackageUnpickler -from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer from torch.serialization import _maybe_decode_ascii + def _save_storages(importer, obj): serialized_storages = [] serialized_dtypes = [] @@ -17,8 +19,8 @@ def _save_storages(importer, obj): importers = sys_importer def persistent_id(obj): - if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage): - if isinstance(obj, torch.storage._TypedStorage): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case storage = obj._storage @@ -29,7 +31,7 @@ def persistent_id(obj): serialized_storages.append(obj) serialized_dtypes.append(dtype) - return ('storage', len(serialized_storages) - 1) + return ("storage", len(serialized_storages) - 1) if hasattr(obj, "__reduce_deploy__"): if _serialized_reduces.get(id(obj)) is None: @@ -48,25 +50,30 @@ def persistent_id(obj): pickler.persistent_id = persistent_id pickler.dump(obj) data_value = data_buf.getvalue() - return data_value, serialized_storages, serialized_dtypes, importer.zip_reader if importer else None + return ( + data_value, + serialized_storages, + serialized_dtypes, + importer.zip_reader if importer else None, + ) -def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): +def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): def persistent_load(saved_id): assert isinstance(saved_id, tuple) typename = _maybe_decode_ascii(saved_id[0]) data = saved_id[1:] - if typename == 'storage': + if typename == "storage": # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage + # stop wrapping with TypedStorage storage = serialized_storages[data[0]] dtype = serialized_dtypes[data[0]] - return torch.storage._TypedStorage( - wrap_storage=storage._untyped(), - dtype=dtype) + return torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype + ) - if typename == 'reduce_deploy': + if typename == "reduce_deploy": reduce_id, func, args = data if reduce_id not in _loaded_reduces: _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) @@ -74,7 +81,6 @@ def persistent_load(saved_id): return None - importer: Importer if zip_reader is not None: importer = OrderedImporter(_get_package(zip_reader), sys_importer) @@ -86,6 +92,7 @@ def persistent_load(saved_id): result = _deploy_objects[id] = unpickler.load() return result + def _get_package(zip_reader): if zip_reader not in _raw_packages: _raw_packages[zip_reader] = PackageImporter(zip_reader) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index b3c61edaf83ae..223d96aa84abd 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -4,30 +4,43 @@ circular dependency problems """ -import contextlib +import ast +import builtins import collections +import contextlib import enum import inspect -import ast -import weakref -import warnings -from textwrap import dedent -import torch -import sys -import builtins -import typing import io import pickle +import sys import threading +import typing +import warnings +import weakref +from textwrap import dedent +from typing import ( # noqa: F401 + Any, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import torch + # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. # Explicitly ask to import `torch.distributed.__init__` first. # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. import torch.distributed.rpc +import torch.package._mangling as package_mangling from torch._C import Future as CFuture -from torch._sources import get_source_lines_and_file, parse_def, fake_range +from torch._sources import fake_range, get_source_lines_and_file, parse_def from torch.futures import Future -import torch.package._mangling as package_mangling -from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union # noqa: F401 if sys.version_info[:2] > (3, 7): from typing import Final @@ -37,21 +50,24 @@ LockType: Type try: import _thread + LockType = _thread.LockType except ImportError: import _dummy_thread + LockType = _dummy_thread.LockType # Wrapper functions that can call either of 2 functions depending on a boolean # argument -boolean_dispatched: 'weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]' = weakref.WeakKeyDictionary() # noqa: T484 +boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = ( + weakref.WeakKeyDictionary() +) # noqa: T484 -FAKE_FILENAME_PREFIX = '__torch_jit_dataclass' +FAKE_FILENAME_PREFIX = "__torch_jit_dataclass" class SourceLoader: - def __init__(self): self.content = {} @@ -74,11 +90,12 @@ def createResolutionCallbackFromEnv(lookup_base): You should not use this directly, it should only be used from the other createResolutionCallbackFrom* functions. """ + def lookupInModule(qualified_name, module): - if '.' in qualified_name: - parts = qualified_name.split('.') + if "." in qualified_name: + parts = qualified_name.split(".") base = parts[0] - remaining_pieces = '.'.join(parts[1:]) + remaining_pieces = ".".join(parts[1:]) module_value = getattr(module, base) return lookupInModule(remaining_pieces, module_value) else: @@ -86,22 +103,22 @@ def lookupInModule(qualified_name, module): def parseNestedExpr(expr, module) -> Tuple[Any, int]: i = 0 - while i < len(expr) and expr[i] not in (',', '[', ']'): + while i < len(expr) and expr[i] not in (",", "[", "]"): i += 1 # Special case logic for the empty Tuple as a subscript (used # in the type annotation `Tuple[()]`) - if expr[:i] == '()': + if expr[:i] == "()": return (), i base = lookupInModule(expr[:i].strip(), module) assert base is not None, f"Unresolvable type {expr[:i]}" - if i == len(expr) or expr[i] != '[': + if i == len(expr) or expr[i] != "[": return base, i - assert expr[i] == '[' + assert expr[i] == "[" parts = [] - while expr[i] != ']': + while expr[i] != "]": part_len = 0 i += 1 part, part_len = parseNestedExpr(expr[i:], module) @@ -115,7 +132,9 @@ def parseNestedExpr(expr, module) -> Tuple[Any, int]: def parseExpr(expr, module): try: value, len_parsed = parseNestedExpr(expr, module) - assert len_parsed == len(expr), "whole expression was not parsed, falling back to c++ parser" + assert len_parsed == len( + expr + ), "whole expression was not parsed, falling back to c++ parser" return value except Exception: """ @@ -191,6 +210,7 @@ def get_closure(fn): return captures + # [local resolution in python] # Depending on where a variable is defined, and where it is used, we may # or may not be able to recover its value when recursively compiling a @@ -236,7 +256,6 @@ def get_closure(fn): # This could be worked around by manualy adding it to `global()` dictionary. - def createResolutionCallbackFromClosure(fn): """ Create a resolutionCallback by introspecting the function instead of @@ -271,8 +290,12 @@ def can_compile_class(cls) -> bool: return False names = cls.__dict__ - fns = [getattr(cls, name) for name in names if inspect.isroutine(getattr(cls, name, None))] - has_code = [hasattr(fn, '__code__') for fn in fns] + fns = [ + getattr(cls, name) + for name in names + if inspect.isroutine(getattr(cls, name, None)) + ] + has_code = [hasattr(fn, "__code__") for fn in fns] return all(has_code) @@ -315,14 +338,16 @@ def get_annotation_str(annotation): if isinstance(annotation, ast.Name): return annotation.id elif isinstance(annotation, ast.Attribute): - return '.'.join([get_annotation_str(annotation.value), annotation.attr]) + return ".".join([get_annotation_str(annotation.value), annotation.attr]) elif isinstance(annotation, ast.Subscript): # In Python3.9+ subscript indicies are not wrapped in ast.Index subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore[attr-defined] return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" elif isinstance(annotation, ast.Tuple): - return ','.join([get_annotation_str(elt) for elt in annotation.elts]) - elif isinstance(annotation, ast.Constant) or isinstance(annotation, ast.NameConstant): + return ",".join([get_annotation_str(elt) for elt in annotation.elts]) + elif isinstance(annotation, ast.Constant) or isinstance( + annotation, ast.NameConstant + ): return f"{annotation.value}" # If an AST node is not handled here, it's probably handled in ScriptTypeParser. @@ -358,7 +383,8 @@ def get_type_hint_captures(fn): name_to_type = { name: parameter.annotation for name, parameter in signature.parameters.items() - if parameter.annotation is not inspect.Parameter.empty and not isinstance(parameter.annotation, str) + if parameter.annotation is not inspect.Parameter.empty + and not isinstance(parameter.annotation, str) } # Then, get the literal type annotations from the function declaration @@ -377,7 +403,9 @@ def get_type_hint_captures(fn): for arg in f.args.args: # Get the source type annotation string for this argument if possible. - arg_annotation_str = get_annotation_str(arg.annotation) if arg.annotation else None + arg_annotation_str = ( + get_annotation_str(arg.annotation) if arg.annotation else None + ) # If the argument has no annotation or get_annotation_str cannot convert it to a string, # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle @@ -398,7 +426,10 @@ def get_type_hint_captures(fn): literal_return_annotation = get_annotation_str(f.returns) valid_literal_annotation = literal_return_annotation is not None return_annotation = signature.return_annotation - valid_return_annotation_type = return_annotation is not inspect.Parameter.empty and not isinstance(return_annotation, str) + valid_return_annotation_type = ( + return_annotation is not inspect.Parameter.empty + and not isinstance(return_annotation, str) + ) if valid_literal_annotation and valid_return_annotation_type: annotation_to_type[literal_return_annotation] = return_annotation @@ -412,7 +443,11 @@ def createResolutionCallbackForClassMethods(cls): """ # cls is a type here, so `ismethod` is false since the methods on the type # aren't bound to anything, so Python treats them as regular functions - fns = [getattr(cls, name) for name in cls.__dict__ if inspect.isroutine(getattr(cls, name))] + fns = [ + getattr(cls, name) + for name in cls.__dict__ + if inspect.isroutine(getattr(cls, name)) + ] captures = {} for fn in fns: @@ -428,12 +463,15 @@ def lookup_in_class(key): return lookup_in_class -def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name): +def boolean_dispatch( + arg_name, arg_index, default, if_true, if_false, module_name, func_name +): """ Dispatches to either of 2 script functions based on a boolean argument. In TorchScript, the boolean argument must be constant so that the correct function to use can be determined at compile time. """ + def fn(*args, **kwargs): dispatch_flag = False if arg_name in kwargs: @@ -469,7 +507,7 @@ def fn(*args, **kwargs): "if_false": if_false, "index": arg_index, "default": default, - "arg_name": arg_name + "arg_name": arg_name, } return fn @@ -479,12 +517,14 @@ class FunctionModifiers(object): Used to denote the behavior of a function in TorchScript. See export() and ignore() for details. """ + UNUSED = "unused (ignored and replaced with raising of an exception)" IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" EXPORT = "export (compile this function even if nothing calls it)" DEFAULT = "default (compile if called from a exported function / forward)" - COPY_TO_SCRIPT_WRAPPER = \ + COPY_TO_SCRIPT_WRAPPER = ( "if this method is not scripted, copy the python method onto the scripted model" + ) def export(fn): @@ -572,16 +612,21 @@ def forward(self, x): """ if isinstance(fn, property): prop = fn - setattr(prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + setattr( # noqa: B010 + prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED + ) if prop.fset: - setattr(prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + setattr( # noqa: B010 + prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED + ) return prop fn._torchscript_modifier = FunctionModifiers.UNUSED return fn + # No op context manager from python side class _IgnoreContextManager(contextlib.AbstractContextManager): def __init__(self, **kwargs): @@ -590,6 +635,7 @@ def __init__(self, **kwargs): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: pass + def ignore(drop=False, **kwargs): """ This decorator indicates to the compiler that a function or method should @@ -661,19 +707,27 @@ def forward(self, x): return fn if not isinstance(drop, bool): - raise RuntimeError("Argument to @torch.jit.ignore must be a bool or " - f"a function but got {drop}") + raise RuntimeError( + "Argument to @torch.jit.ignore must be a bool or " + f"a function but got {drop}" + ) # for backwards compat drop_on_export = kwargs.pop("drop_on_export", None) if drop_on_export: - warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " - "call on compilation. Use torch.jit.unused now. {}", category=FutureWarning) + warnings.warn( + "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) drop = drop_on_export elif drop: - warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the function " - "call on compilation. Use torch.jit.unused now. {}", category=FutureWarning) + warnings.warn( + "ignore(True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) def decorator(fn): if drop: @@ -681,6 +735,7 @@ def decorator(fn): else: fn._torchscript_modifier = FunctionModifiers.IGNORE return fn + return decorator @@ -688,6 +743,7 @@ def _copy_to_script_wrapper(fn): fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER return fn + def module_has_exports(mod): for name in dir(mod): if hasattr(mod, name): @@ -716,6 +772,7 @@ def is_ignored_fn(fn) -> bool: def is_static_fn(cls, fn) -> bool: return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod) + def get_static_fn(cls, fn): return inspect.getattr_static(cls, fn).__func__ @@ -723,9 +780,10 @@ def get_static_fn(cls, fn): def get_torchscript_modifier(fn): if not callable(fn): return None - if hasattr(fn, '__func__'): + if hasattr(fn, "__func__"): fn = fn.__func__ - return getattr(fn, '_torchscript_modifier', FunctionModifiers.DEFAULT) + return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT) + def copy_torchscript_modifier(orig, new) -> None: attr = get_torchscript_modifier(orig) @@ -733,15 +791,16 @@ def copy_torchscript_modifier(orig, new) -> None: return new._torchscript_modifier = attr + # overloading registration # overloads get registered in this file, and compiled in torch/jit/__init__.py # so that they can be imported in nn/functional.py without an import cycle # qualified_name => list[overload_functions] -_overloaded_fns : Dict[str, List[Callable]] = {} # noqa: T484 +_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484 -_OVERLOAD_EXAMPLE = ''' +_OVERLOAD_EXAMPLE = """ Example usage of overload function: @torch.jit._overload def my_function(x: type0) -> type0: # decl 1 @@ -756,23 +815,30 @@ def my_function(x): # implementation return x elif isinstance(x, type1): return x -''' +""" + def get_overload_no_implementation_error_message(kind, obj): sourcelines, file_lineno, filename = get_source_lines_and_file(obj) return ( f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make ' - f'sure a definition is provided and defined after all overload declarations.\n' - f'File "{filename}", line {file_lineno}:\n' + ''.join(sourcelines) + "\n" + _OVERLOAD_EXAMPLE + f"sure a definition is provided and defined after all overload declarations.\n" + f'File "{filename}", line {file_lineno}:\n' + + "".join(sourcelines) + + "\n" + + _OVERLOAD_EXAMPLE ) + def _check_overload_body(func): try: parsed_def = parse_def(func) except OSError as e: # Parsing the function definition can raise an OSError if source is unavailable. # Since this is just an initial check, just raise a warning if this is the case. - warnings.warn(f"Unable to retrieve source for @torch.jit._overload function: {func}.") + warnings.warn( + f"Unable to retrieve source for @torch.jit._overload function: {func}." + ) return body = parsed_def.ast.body[0].body @@ -784,11 +850,14 @@ def is_ellipsis(x): return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis) if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])): - msg = "Only `pass` statement or `...` can be the body of overload declaration:\n" - msg += '\n'.join(parsed_def.source.split("\n")[:3]) + msg = ( + "Only `pass` statement or `...` can be the body of overload declaration:\n" + ) + msg += "\n".join(parsed_def.source.split("\n")[:3]) msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE raise RuntimeError(msg) + def _overload(func): _check_overload_body(func) qual_name = _qualified_name(func) @@ -800,18 +869,23 @@ def _overload(func): fn_overload_list.append(func) return func + def _get_fn_overloads(qual_name): return _overloaded_fns.get(qual_name) + def _clear_fn_overloads(qual_name) -> None: del _overloaded_fns[qual_name] + def get_class_name_lineno(method) -> Tuple[str, int]: current_frame = inspect.currentframe() # one for the get_class_name call, one for _overload_method call for i in range(2): - assert current_frame is not None # assert current frame is not an Optional[FrameType] + assert ( + current_frame is not None + ) # assert current frame is not an Optional[FrameType] current_frame = current_frame.f_back assert current_frame is not None # same here @@ -819,6 +893,7 @@ def get_class_name_lineno(method) -> Tuple[str, int]: line_no = current_frame.f_code.co_firstlineno return class_name, line_no + # At the the point the decorator is applied to class methods the method # has no reference to its owning class. _qualified_name would not include # the class it is defined in, so any methods with the same name in the same file @@ -829,12 +904,13 @@ def get_class_name_lineno(method) -> Tuple[str, int]: # when modules of the same name are in the same file # qualified_name => class name => list[overload_functions] -_overloaded_methods : Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484 +_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484 # (qualified_name, class name) => class_fileno _overloaded_method_class_fileno = {} + def _overload_method(func): _check_overload_body(func) qual_name = _qualified_name(func) @@ -853,12 +929,15 @@ def _overload_method(func): else: existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)] if existing_lineno != line_no: - raise RuntimeError("Cannot currently overload the same method name in two different" - " classes with the same name in the same module") + raise RuntimeError( + "Cannot currently overload the same method name in two different" + " classes with the same name in the same module" + ) method_overloads.append(func) return func + def _get_overloaded_methods(method, mod_class): # TODO: __name__ not set for submodules in recursive script if not hasattr(method, "__name__"): @@ -875,7 +954,10 @@ def _get_overloaded_methods(method, mod_class): mod_class_fileno = get_source_lines_and_file(mod_class)[1] mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): - raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method)) + raise Exception( + "Overloads are not useable when a module is redeclared within the same file: " + + str(method) + ) return overloads @@ -884,48 +966,59 @@ def is_tuple(ann) -> bool: raise_error_container_parameter_missing("Tuple") # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule - if not hasattr(ann, '__module__'): + if not hasattr(ann, "__module__"): return False - return ann.__module__ == 'typing' and \ - (getattr(ann, '__origin__', None) is Tuple or - getattr(ann, '__origin__', None) is tuple) + return ann.__module__ == "typing" and ( + getattr(ann, "__origin__", None) is Tuple + or getattr(ann, "__origin__", None) is tuple + ) + def is_list(ann) -> bool: if ann is List: raise_error_container_parameter_missing("List") - if not hasattr(ann, '__module__'): + if not hasattr(ann, "__module__"): return False - return ann.__module__ == 'typing' and \ - (getattr(ann, '__origin__', None) is List or - getattr(ann, '__origin__', None) is list) + return ann.__module__ == "typing" and ( + getattr(ann, "__origin__", None) is List + or getattr(ann, "__origin__", None) is list + ) + def is_dict(ann) -> bool: if ann is Dict: raise_error_container_parameter_missing("Dict") - if not hasattr(ann, '__module__'): + if not hasattr(ann, "__module__"): return False - return ann.__module__ == 'typing' and \ - (getattr(ann, '__origin__', None) is Dict or - getattr(ann, '__origin__', None) is dict) + return ann.__module__ == "typing" and ( + getattr(ann, "__origin__", None) is Dict + or getattr(ann, "__origin__", None) is dict + ) + def is_union(ann): if ann is Union: raise_error_container_parameter_missing("Union") - return (hasattr(ann, '__module__') and - ann.__module__ == 'typing' and - (getattr(ann, '__origin__', None) is Union)) + return ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and (getattr(ann, "__origin__", None) is Union) + ) + def is_optional(ann): if ann is Optional: raise_error_container_parameter_missing("Optional") def is_optional_as_optional(ann): - return (hasattr(ann, '__module__') and - ann.__module__ == 'typing' and - (getattr(ann, '__origin__', None) is Optional)) + return ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and (getattr(ann, "__origin__", None) is Optional) + ) def is_union_as_optional(ann): ann_args = ann.__args__ @@ -933,6 +1026,7 @@ def is_union_as_optional(ann): return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) + def is_future(ann) -> bool: if ann is Future: raise RuntimeError( @@ -942,9 +1036,10 @@ def is_future(ann) -> bool: ) return getattr(ann, "__origin__", None) is Future + if torch.distributed.rpc.is_available(): - from torch.distributed.rpc import RRef from torch._C._distributed_rpc import PyRRef + from torch.distributed.rpc import RRef def is_rref(ann) -> bool: if ann is RRef: @@ -959,19 +1054,24 @@ def is_rref_instance(obj) -> bool: return isinstance(obj, PyRRef) else: + def is_rref_instance(obj) -> bool: # If the RPC module doesn't exist then RRefs don't exist either. return False + def is_final(ann) -> bool: - return ann.__module__ in {'typing', 'typing_extensions'} and \ - (getattr(ann, '__origin__', None) is Final or isinstance(ann, type(Final))) + return ann.__module__ in {"typing", "typing_extensions"} and ( + getattr(ann, "__origin__", None) is Final or isinstance(ann, type(Final)) + ) + # allows BroadcastingList instance to be subscriptable class BroadcastingListCls(object): def __getitem__(self, types): return + # mypy doesn't support parameters on types, so we have to explicitly type each # list size BroadcastingList1 = BroadcastingListCls() @@ -1010,7 +1110,7 @@ def _qualified_name(obj, mangle_name=True) -> str: # its qualname so it appears correctly in the TorchScript system. This, # we set '_jit_override_qualname' with the original traced module's # qualified name, which is picked up here - if hasattr(obj, '_jit_override_qualname'): + if hasattr(obj, "_jit_override_qualname"): return obj._jit_override_qualname # short-circuit in cases where the object already has a known qualified name if isinstance(obj, torch._C.ScriptFunction): @@ -1024,9 +1124,8 @@ def _qualified_name(obj, mangle_name=True) -> str: else: raise RuntimeError("Could not get name of python class object") - - if name == '': - name = '_lambda' # make name a valid identifier + if name == "": + name = "_lambda" # make name a valid identifier module_name = obj.__module__ @@ -1037,8 +1136,10 @@ def _qualified_name(obj, mangle_name=True) -> str: # The Python docs are very clear that `__module__` can be None, but I can't # figure out when it actually would be. if module_name is None: - raise RuntimeError(f"Could not get qualified name for class '{name}': " - "__module__ can't be None.") + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + "__module__ can't be None." + ) # if getattr(sys.modules[module_name], name) is not obj: # raise RuntimeError(f"Could not get qualified name for class '{name}': " @@ -1063,8 +1164,10 @@ def _qualified_name(obj, mangle_name=True) -> str: module_name = "__torch__." + module_name if "." in name: - raise RuntimeError(f"Could not get qualified name for class '{name}': " - f"'{name}' is not a valid identifier") + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + f"'{name}' is not a valid identifier" + ) return module_name + "." + name @@ -1076,33 +1179,48 @@ def _try_get_dispatched_fn(fn): def _get_named_tuple_properties(obj): - assert issubclass(obj, tuple) and hasattr(obj, '_fields') + assert issubclass(obj, tuple) and hasattr(obj, "_fields") if hasattr(obj, "_field_defaults"): - defaults = [obj._field_defaults[field] - for field in obj._fields - if field in obj._field_defaults] + defaults = [ + obj._field_defaults[field] + for field in obj._fields + if field in obj._field_defaults + ] else: defaults = [] + # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function + # Also, annotations from base class are not inherited so they need to be queried explicitly + if sys.version_info[:2] < (3, 10): + obj_annotations = getattr(obj, "__annotations__", {}) + else: + obj_annotations = inspect.get_annotations(obj) + if len(obj_annotations) == 0 and hasattr(obj, "__base__"): + obj_annotations = inspect.get_annotations(obj.__base__) + annotations = [] - has_annotations = hasattr(obj, '__annotations__') for field in obj._fields: - if has_annotations and field in obj.__annotations__: - the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) + if field in obj_annotations: + the_type = torch.jit.annotations.ann_to_type( + obj_annotations[field], fake_range() + ) annotations.append(the_type) else: annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, obj._fields, annotations, defaults -def _create_named_tuple(t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...]): +def _create_named_tuple( + t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...] +): # mypy: namedtuple() expects a string literal as the first argument if sys.version_info < (3, 7, 0): TupleType = collections.namedtuple(unqual_name, field_names) # type: ignore[no-redef, misc] - TupleType.__new__.__defaults__ = defaults # type: ignore[attr-defined] + TupleType.__new__.__defaults__ = defaults # type: ignore[attr-defined] else: TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] return TupleType(*t) + @contextlib.contextmanager def _disable_emit_hooks(): hooks = torch._C._jit_get_emit_hooks() @@ -1119,13 +1237,15 @@ def __enter__(self) -> None: def __exit__(self, *args) -> None: torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) + def _is_exception(obj) -> bool: if not inspect.isclass(obj): return False return issubclass(obj, Exception) + def raise_error_container_parameter_missing(target_type) -> None: - if target_type == 'Dict': + if target_type == "Dict": raise RuntimeError( "Attempted to use Dict without " "contained types. Please add contained type, e.g. " @@ -1159,11 +1279,13 @@ def check_args_exist(target_type) -> None: def check_empty_containers(obj) -> None: if obj == [] or obj == {} or obj == (): - warnings.warn("The inner type of a container is lost when " - "calling torch.jit.isinstance in eager mode. For " - "example, List[int] would become list and " - "therefore falsely return True for List[float] or" - " List[str].") + warnings.warn( + "The inner type of a container is lost when " + "calling torch.jit.isinstance in eager mode. For " + "example, List[int] would become list and " + "therefore falsely return True for List[float] or" + " List[str]." + ) # supports List/Dict/Tuple and Optional types @@ -1223,7 +1345,7 @@ def container_checker(obj, target_type) -> bool: inner_types = get_args(target_type) for t in inner_types: t_origin = get_origin(t) - if (t_origin): + if t_origin: return container_checker(obj, t) elif isinstance(obj, t): return True @@ -1233,9 +1355,11 @@ def container_checker(obj, target_type) -> bool: def _isinstance(obj, target_type) -> bool: if isinstance(target_type, collections.abc.Container): if not isinstance(target_type, tuple): - raise RuntimeError("The second argument to " - "`torch.jit.isinstance` must be a type " - "or a tuple of types") + raise RuntimeError( + "The second argument to " + "`torch.jit.isinstance` must be a type " + "or a tuple of types" + ) for t_type in target_type: if _isinstance(obj, t_type): return True diff --git a/torch/_lazy/__init__.py b/torch/_lazy/__init__.py index ff4e90c0edf23..953f6a83ffacd 100644 --- a/torch/_lazy/__init__.py +++ b/torch/_lazy/__init__.py @@ -1,7 +1,7 @@ import torch._C._lazy -def mark_step(device: str = "lazy:0", wait=False): +def mark_step(device: str = "", wait=False): """Triggers a mark step, which amounts to - collecting a group of 'live' lazy tensors to index into the compilation cache (lowering/compiling their IR graphs if not cached) @@ -11,6 +11,7 @@ def mark_step(device: str = "lazy:0", wait=False): # TODO(whc) expand this to include backend hooks and align with XLA backend needs torch._C._lazy._mark_step(device, [], wait=wait) + def wait_device_ops(devices=None): """Waits for all the async operations on the given devices to complete. Args: @@ -21,6 +22,7 @@ def wait_device_ops(devices=None): devices = [] torch._C._lazy._wait_device_ops(devices=devices) + def sync_multi(tensors, devices): """ Sync the list of lazy tensors so there IR get lowered for the activate backend @@ -28,6 +30,7 @@ def sync_multi(tensors, devices): """ torch._C._lazy._sync_multi(tensors, devices) + def get_tensor_id(tensor): """Return a unique id of the lazy tensor maintained by LTC""" return torch._C._lazy._get_tensor_id(tensor) diff --git a/torch/_lazy/computation.py b/torch/_lazy/computation.py index 7dd57cd7238d4..27b73c42e5c0d 100644 --- a/torch/_lazy/computation.py +++ b/torch/_lazy/computation.py @@ -1,23 +1,26 @@ import torch._C._lazy import torch._C._lazy_ts_backend + def get_tensors_ts_device_data_node(tensors): """Return tensor ids and eager tensors for DeviceData nodes in the - IR for the passed in lazy tensors. + IR for the passed in lazy tensors. - TODO: This API is currently ts backend specific. We are working on - generalizing it to all backends including XLA. + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. """ return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + def get_graph_hash(tensors): """Return the graph hash for the passed in lazy tensors""" return torch._C._lazy._get_graph_hash(tensors) + def run_cached_graph(hash_str, graph_inputs): """Running the cached computation graph with the given inputs - TODO: This API is currently ts backend specific. We are working on - generalizing it to all backends including XLA. + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. """ return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/torch/_lazy/config.py b/torch/_lazy/config.py index c2e72bd7d60b9..e7a4d1dd24f8d 100644 --- a/torch/_lazy/config.py +++ b/torch/_lazy/config.py @@ -1,13 +1,16 @@ import torch._C._lazy + def get_force_fallback(): """Get the config used to force LTC fallback""" return torch._C._lazy._get_force_fallback() + def set_force_fallback(configval): """Set the config used to force LTC fallback""" torch._C._lazy._set_force_fallback(configval) + def set_reuse_ir(val: bool): """Set the config to reuse IR nodes for faster tracing""" torch._C._lazy._set_reuse_ir(val) diff --git a/torch/_lazy/debug.py b/torch/_lazy/debug.py index 882056ca9c0f3..286aa049280c9 100644 --- a/torch/_lazy/debug.py +++ b/torch/_lazy/debug.py @@ -3,14 +3,15 @@ def render_ir_graph(tensors): """Return a text dump of the LTC IR graph in dot format for the tensors. - The text can be processed by tools like dot to be rendered in pdf,png etc.""" + The text can be processed by tools like dot to be rendered in pdf,png etc.""" return torch._C._lazy._get_tensors_dot(tensors) + def dump_ir(tensors, ir_format): """Return a dump of the tensors in the specified format. - Valid format are - - text: for LTC IR - - backend: for the activate backend IR + Valid format are + - text: for LTC IR + - backend: for the activate backend IR """ if ir_format == "text": return torch._C._lazy._get_tensors_text(tensors) diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py index 37d0e67f31f3f..440539f52453b 100644 --- a/torch/_lazy/extract_compiled_graph.py +++ b/torch/_lazy/extract_compiled_graph.py @@ -1,18 +1,19 @@ -import torch._lazy.metrics as metrics -from torch._lazy.tensor_factory_functions import tensor_factory_functions -from torch._lazy import computation -from torch._lazy import debug as lazy_debug -import torch._lazy as lazy -import dataclasses -from typing import List, Dict, Any, Callable import copy -from torch import fx -import torch +import dataclasses import itertools import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions debug = os.environ.get("debug_extract_compiled_graph") is not None + @dataclasses.dataclass class GraphInputMatcher: """ @@ -24,6 +25,7 @@ class GraphInputMatcher: graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the TS/XLA graph inputs. """ + tensor_id_to_arg_idx: Dict[int, int] graph_input_tensor_ids: List[int] # there are 2 categories of graph_input_tensors. @@ -36,7 +38,9 @@ class GraphInputMatcher: # get the real graph input tensors def __call__(self, args): real_input = [] - for tensor_id, traced_ivalue in zip(self.graph_input_tensor_ids, self.graph_input_ivalues): + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) if arg_idx is None: inp = traced_ivalue @@ -45,6 +49,7 @@ def __call__(self, args): real_input.append(inp) return real_input + class ReturnValueHandler: r""" When ltc_sync_multi is called on multi tensors, the compiled graph @@ -62,6 +67,7 @@ def forward(self, a): This class dedup the lazy tensors first to get the index that will be used to duplicate the eager tensors later. """ + def __init__(self, lazy_out_list): self.index: List[List[int]] = [] self.total_count = len(lazy_out_list) @@ -85,19 +91,24 @@ def duplicate_eager_tensors(self, eager_tensor_list): duplicated_list[dup_idx] = eager_tensor return duplicated_list + def force_lazy_device(model: fx.GraphModule): """ Factory methods in a Fx graph may create tensors for a specific eager devices. If we take no actions, those eager tensors will be mixed with lazy tensors and cause crash. This method overwrite those eager device to lazy device. """ + def tolazydevice(dev): if isinstance(dev, torch.device): return torch.device("lazy", index=dev.index) return dev def hasDeviceArg(args, kwargs): - return any(isinstance(arg, torch.device) for arg in itertools.chain(args, kwargs.values())) + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) for nd in model.graph.nodes: nd.args = tuple(tolazydevice(arg) for arg in nd.args) @@ -114,13 +125,16 @@ def hasDeviceArg(args, kwargs): # # TODO: This solution is no ideal since we may miss some factory methods. In future # when we support lazy mode, this method can be replaced by that. - if nd.target in tensor_factory_functions and not hasDeviceArg(nd.args, nd.kwargs): + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. kwargs["device"] = torch.device("lazy") nd.kwargs = kwargs model.recompile() + def get_fallback_ops(): fallback_ops = [] for opname in metrics.counter_names(): @@ -132,6 +146,7 @@ def get_fallback_ops(): return fallback_ops + def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: """ Optimize an eager model with LTC and returns a wrapper to execute the @@ -152,7 +167,9 @@ def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: metrics.reset() if len(fallback_ops) > 0: - raise RuntimeError(f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}") + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) if not isinstance(lazy_out, (tuple, list)): lazy_out = (lazy_out,) @@ -165,9 +182,14 @@ def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: # TODO: this part is TS backend specific for now and will be generalized to # support XLA - graph_input_tensor_ids, graph_input_ivalues = computation.get_tensors_ts_device_data_node(args_and_out) + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) assert len(graph_input_tensor_ids) == len(graph_input_ivalues) - graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) graph_hash = computation.get_graph_hash(args_and_out) @@ -185,7 +207,9 @@ def optimized_mod(*args): if len(args_and_out) == 0: return () graph_input = graph_input_matcher(args) - res = return_value_handler.duplicate_eager_tensors(computation.run_cached_graph(graph_hash, graph_input)) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) assert len(res) == len(args_and_out) for i, arg in enumerate(args): @@ -194,6 +218,6 @@ def optimized_mod(*args): arg.copy_(res[i]) # skip the args - return res[len(args):] + return res[len(args) :] return optimized_mod diff --git a/torch/_lazy/ir_cache.py b/torch/_lazy/ir_cache.py index 04f1f103d286f..4270684d29434 100644 --- a/torch/_lazy/ir_cache.py +++ b/torch/_lazy/ir_cache.py @@ -1,9 +1,11 @@ import torch._C._lazy + def dump(dot_file_name: str): """Dump TrieCache in the dot format""" return torch._C._lazy._dump_ir_cache(dot_file_name) + def reset(): """Clear TrieCache. This is needed in testing to avoid node reusing between different tests. diff --git a/torch/_lazy/metrics.py b/torch/_lazy/metrics.py index 043db981bb71e..2d7db73055677 100644 --- a/torch/_lazy/metrics.py +++ b/torch/_lazy/metrics.py @@ -1,13 +1,21 @@ import torch._C._lazy + def reset(): """Resets all metric counters.""" torch._C._lazy._reset_metrics() + def counter_names(): """Retrieves all the currently active counter names.""" return torch._C._lazy._counter_names() + def counter_value(name: str): """Return the value of the counter with the speficied name""" return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/torch/_lazy/ts_backend.py b/torch/_lazy/ts_backend.py index 118de2dbefca0..184223771932d 100644 --- a/torch/_lazy/ts_backend.py +++ b/torch/_lazy/ts_backend.py @@ -1,5 +1,6 @@ import torch._C._lazy_ts_backend + def init(): """Initializes the lazy Torchscript backend""" torch._C._lazy_ts_backend._init() diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index faa79f7f0cdb8..d7f6798dd9d7f 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -2,11 +2,11 @@ """ -from torch import Tensor -import torch - from typing import Optional, Tuple +import torch +from torch import Tensor + def is_sparse(A): """Check if tensor A is a sparse tensor""" @@ -18,6 +18,7 @@ def is_sparse(A): error_str += " but got {}".format(type(A)) raise TypeError(error_str) + def get_floating_dtype(A): """Return the floating point dtype of tensor A. @@ -53,33 +54,28 @@ def conjugate(A): def transpose(A): - """Return transpose of a matrix or batches of matrices. - """ + """Return transpose of a matrix or batches of matrices.""" ndim = len(A.shape) return A.transpose(ndim - 1, ndim - 2) def transjugate(A): - """Return transpose conjugate of a matrix or batches of matrices. - """ + """Return transpose conjugate of a matrix or batches of matrices.""" return conjugate(transpose(A)) def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: - """Return bilinear form of matrices: :math:`X^T A Y`. - """ + """Return bilinear form of matrices: :math:`X^T A Y`.""" return matmul(transpose(X), matmul(A, Y)) def qform(A: Optional[Tensor], S: Tensor): - """Return quadratic form :math:`S^T A S`. - """ + """Return quadratic form :math:`S^T A S`.""" return bform(S, A, S) def basis(A): - """Return orthogonal basis of A columns. - """ + """Return orthogonal basis of A columns.""" if A.is_cuda: # torch.orgqr is not available in CUDA Q = torch.linalg.qr(A).Q @@ -89,17 +85,17 @@ def basis(A): def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: - """Return eigenpairs of A with specified ordering. - """ + """Return eigenpairs of A with specified ordering.""" if largest is None: largest = False - E, Z = torch.linalg.eigh(A, UPLO='U') + E, Z = torch.linalg.eigh(A, UPLO="U") # assuming that E is ordered if largest: E = torch.flip(E, dims=(-1,)) Z = torch.flip(Z, dims=(-1,)) return E, Z + # This function was deprecated and removed # This nice error message can be removed in version 1.13+ def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index cb7a6723683ab..273c93d038158 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -3,30 +3,27 @@ # Author: Pearu Peterson # Created: February 2020 -from typing import Dict, Tuple, Optional +from typing import Dict, Optional, Tuple import torch from torch import Tensor from . import _linalg_utils as _utils -from .overrides import has_torch_function, handle_torch_function +from .overrides import handle_torch_function, has_torch_function -__all__ = ['lobpcg'] +__all__ = ["lobpcg"] + def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U): # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0 F = D.unsqueeze(-2) - D.unsqueeze(-1) - F.diagonal(dim1=-2, dim2=-1).fill_(float('inf')) + F.diagonal(dim1=-2, dim2=-1).fill_(float("inf")) F.pow_(-1) # A.grad = U (D.grad + (U^T U.grad * F)) U^T Ut = U.mT.contiguous() res = torch.matmul( - U, - torch.matmul( - torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, - Ut - ) + U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut) ) return res @@ -66,7 +63,9 @@ def _polynomial_coefficients_given_roots(roots): # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1) - out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(-1, poly_order - i + 1, i + 1) + out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow( + -1, poly_order - i + 1, i + 1 + ) poly_coeffs = poly_coeffs_new return poly_coeffs.narrow(-1, 1, poly_order + 1) @@ -102,6 +101,7 @@ def _polynomial_value(poly, x, zero_power, transition): res = transition(res, x, poly[..., k]) return res + def _matrix_polynomial_value(poly, x, zero_power=None): """ Evaluates `poly(x)` for the (batched) matrix input `x`. @@ -115,11 +115,13 @@ def transition(curr_poly_val, x, poly_coeff): return res if zero_power is None: - zero_power = torch.eye(x.size(-1), x.size(-1), dtype=x.dtype, device=x.device) \ - .view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) + zero_power = torch.eye( + x.size(-1), x.size(-1), dtype=x.dtype, device=x.device + ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) return _polynomial_value(poly, x, zero_power, transition) + def _vector_polynomial_value(poly, x, zero_power=None): """ Evaluates `poly(x)` for the (batched) vector input `x`. @@ -136,6 +138,7 @@ def transition(curr_poly_val, x, poly_coeff): return _polynomial_value(poly, x, zero_power, transition) + def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): # compute a projection operator onto an orthogonal subspace spanned by the # columns of U defined as (I - UU^T) @@ -156,7 +159,7 @@ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): (*A.shape[:-1], A.size(-1) - D.size(-1)), dtype=A.dtype, device=A.device, - generator=gen + generator=gen, ) ) U_ortho_t = U_ortho.mT.contiguous() @@ -212,11 +215,7 @@ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t chr_poly_D_at_A_to_U_ortho = torch.matmul( - U_ortho_t, - torch.matmul( - chr_poly_D_at_A, - U_ortho - ) + U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho) ) # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its # Cholesky decomposition and then use `torch.cholesky_solve` for better stability. @@ -233,51 +232,47 @@ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): ) # compute the gradient part in span(U) - res = _symeig_backward_complete_eigenspace( - D_grad, U_grad, A, D, U - ) + res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) # incorporate the Sylvester equation solution into the full gradient # it resides in span(U_ortho) res -= U_ortho.matmul( - chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve( - U_ortho_t.matmul(series_acc), - chr_poly_D_at_A_to_U_ortho_L + chr_poly_D_at_A_to_U_ortho_sign + * torch.cholesky_solve( + U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L ) ).matmul(Ut) return res + def _symeig_backward(D_grad, U_grad, A, D, U, largest): # if `U` is square, then the columns of `U` is a complete eigenspace if U.size(-1) == U.size(-2): - return _symeig_backward_complete_eigenspace( - D_grad, U_grad, A, D, U - ) + return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) else: - return _symeig_backward_partial_eigenspace( - D_grad, U_grad, A, D, U, largest - ) + return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest) -class LOBPCGAutogradFunction(torch.autograd.Function): +class LOBPCGAutogradFunction(torch.autograd.Function): @staticmethod - def forward(ctx, # type: ignore[override] - A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, - tracker: None = None, - ortho_iparams: Optional[Dict[str, int]] = None, - ortho_fparams: Optional[Dict[str, float]] = None, - ortho_bparams: Optional[Dict[str, bool]] = None - ) -> Tuple[Tensor, Tensor]: + def forward( # type: ignore[override] + ctx, + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, + ) -> Tuple[Tensor, Tensor]: # makes sure that input is contiguous for efficiency. # Note: autograd does not support dense gradients for sparse input yet. @@ -286,9 +281,20 @@ def forward(ctx, # type: ignore[override] B = B.contiguous() if (not B.is_sparse) else B D, U = _lobpcg( - A, k, B, X, - n, iK, niter, tol, largest, method, tracker, - ortho_iparams, ortho_fparams, ortho_bparams + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, ) ctx.save_for_backward(A, B, D, U) @@ -307,18 +313,21 @@ def backward(ctx, D_grad, U_grad): # lobpcg.backward has some limitations. Checks for unsupported input if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]): raise ValueError( - 'lobpcg.backward does not support sparse input yet.' - 'Note that lobpcg.forward does though.' + "lobpcg.backward does not support sparse input yet." + "Note that lobpcg.forward does though." ) - if A.dtype in (torch.complex64, torch.complex128) or \ - B is not None and B.dtype in (torch.complex64, torch.complex128): + if ( + A.dtype in (torch.complex64, torch.complex128) + or B is not None + and B.dtype in (torch.complex64, torch.complex128) + ): raise ValueError( - 'lobpcg.backward does not support complex input yet.' - 'Note that lobpcg.forward does though.' + "lobpcg.backward does not support complex input yet." + "Note that lobpcg.forward does though." ) if B is not None: raise ValueError( - 'lobpcg.backward does not support backward with B != I yet.' + "lobpcg.backward does not support backward with B != I yet." ) if largest is None: @@ -326,9 +335,7 @@ def backward(ctx, D_grad, U_grad): # symeig backward if B is None: - A_grad = _symeig_backward( - D_grad, U_grad, A, D, U, largest - ) + A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest) # A has index 0 grads[0] = A_grad @@ -337,21 +344,22 @@ def backward(ctx, D_grad, U_grad): return tuple(grads) -def lobpcg(A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, - tracker: None = None, - ortho_iparams: Optional[Dict[str, int]] = None, - ortho_fparams: Optional[Dict[str, float]] = None, - ortho_bparams: Optional[Dict[str, bool]] = None - ) -> Tuple[Tensor, Tensor]: +def lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, +) -> Tuple[Tensor, Tensor]: """Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive definite generalized @@ -499,14 +507,27 @@ def lobpcg(A: Tensor, if not torch.jit.is_scripting(): tensor_ops = (A, B, X, iK) - if (not set(map(type, tensor_ops)).issubset((torch.Tensor, type(None))) and has_torch_function(tensor_ops)): + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): return handle_torch_function( - lobpcg, tensor_ops, A, k=k, - B=B, X=X, n=n, iK=iK, niter=niter, tol=tol, - largest=largest, method=method, tracker=tracker, + lobpcg, + tensor_ops, + A, + k=k, + B=B, + X=X, + n=n, + iK=iK, + niter=niter, + tol=tol, + largest=largest, + method=method, + tracker=tracker, ortho_iparams=ortho_iparams, ortho_fparams=ortho_fparams, - ortho_bparams=ortho_bparams) + ortho_bparams=ortho_bparams, + ) if not torch._jit_internal.is_scripting(): if A.requires_grad or (B is not None and B.requires_grad): @@ -520,38 +541,63 @@ def lobpcg(A: Tensor, B_sym = (B + B.mT) / 2 if (B is not None) else None return LOBPCGAutogradFunction.apply( - A_sym, k, B_sym, X, n, iK, niter, tol, largest, - method, tracker, ortho_iparams, ortho_fparams, ortho_bparams + A_sym, + k, + B_sym, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, ) else: if A.requires_grad or (B is not None and B.requires_grad): raise RuntimeError( - 'Script and require grads is not supported atm.' - 'If you just want to do the forward, use .detach()' - 'on A and B before calling into lobpcg' + "Script and require grads is not supported atm." + "If you just want to do the forward, use .detach()" + "on A and B before calling into lobpcg" ) return _lobpcg( - A, k, B, X, - n, iK, niter, tol, largest, method, tracker, - ortho_iparams, ortho_fparams, ortho_bparams + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, ) -def _lobpcg(A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, - tracker: None = None, - ortho_iparams: Optional[Dict[str, int]] = None, - ortho_fparams: Optional[Dict[str, float]] = None, - ortho_bparams: Optional[Dict[str, bool]] = None - ) -> Tuple[Tensor, Tensor]: + +def _lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, +) -> Tuple[Tensor, Tensor]: # A must be square: assert A.shape[-2] == A.shape[-1], A.shape @@ -562,50 +608,47 @@ def _lobpcg(A: Tensor, dtype = _utils.get_floating_dtype(A) device = A.device if tol is None: - feps = {torch.float32: 1.2e-07, - torch.float64: 2.23e-16}[dtype] - tol = feps ** 0.5 + feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype] + tol = feps**0.5 m = A.shape[-1] k = (1 if X is None else X.shape[-1]) if k is None else k n = (k if n is None else n) if X is None else X.shape[-1] - if (m < 3 * n): + if m < 3 * n: raise ValueError( - 'LPBPCG algorithm is not applicable when the number of A rows (={})' - ' is smaller than 3 x the number of requested eigenpairs (={})' - .format(m, n)) + "LPBPCG algorithm is not applicable when the number of A rows (={})" + " is smaller than 3 x the number of requested eigenpairs (={})".format(m, n) + ) - method = 'ortho' if method is None else method + method = "ortho" if method is None else method iparams = { - 'm': m, - 'n': n, - 'k': k, - 'niter': 1000 if niter is None else niter, + "m": m, + "n": n, + "k": k, + "niter": 1000 if niter is None else niter, } fparams = { - 'tol': tol, + "tol": tol, } - bparams = { - 'largest': True if largest is None else largest - } + bparams = {"largest": True if largest is None else largest} - if method == 'ortho': + if method == "ortho": if ortho_iparams is not None: iparams.update(ortho_iparams) if ortho_fparams is not None: fparams.update(ortho_fparams) if ortho_bparams is not None: bparams.update(ortho_bparams) - iparams['ortho_i_max'] = iparams.get('ortho_i_max', 3) - iparams['ortho_j_max'] = iparams.get('ortho_j_max', 3) - fparams['ortho_tol'] = fparams.get('ortho_tol', tol) - fparams['ortho_tol_drop'] = fparams.get('ortho_tol_drop', tol) - fparams['ortho_tol_replace'] = fparams.get('ortho_tol_replace', tol) - bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False) + iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3) + iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3) + fparams["ortho_tol"] = fparams.get("ortho_tol", tol) + fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol) + fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol) + bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False) if not torch.jit.is_scripting(): LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[assignment] @@ -621,9 +664,11 @@ def _lobpcg(A: Tensor, for i in range(N): A_ = bA[i] B_ = bB[i] if bB is not None else None - X_ = torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i] + X_ = ( + torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i] + ) assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n)) - iparams['batch_index'] = i + iparams["batch_index"] = i worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker) worker.run() bE[i] = worker.E[:k] @@ -648,20 +693,20 @@ def _lobpcg(A: Tensor, class LOBPCG(object): - """Worker class of LOBPCG methods. - """ - - def __init__(self, - A: Optional[Tensor], - B: Optional[Tensor], - X: Tensor, - iK: Optional[Tensor], - iparams: Dict[str, int], - fparams: Dict[str, float], - bparams: Dict[str, bool], - method: str, - tracker: None - ) -> None: + """Worker class of LOBPCG methods.""" + + def __init__( + self, + A: Optional[Tensor], + B: Optional[Tensor], + X: Tensor, + iK: Optional[Tensor], + iparams: Dict[str, int], + fparams: Dict[str, float], + bparams: Dict[str, bool], + method: str, + tracker: None, + ) -> None: # constant parameters self.A = A @@ -672,64 +717,62 @@ def __init__(self, self.bparams = bparams self.method = method self.tracker = tracker - m = iparams['m'] - n = iparams['n'] + m = iparams["m"] + n = iparams["n"] # variable parameters self.X = X - self.E = torch.zeros((n, ), dtype=X.dtype, device=X.device) + self.E = torch.zeros((n,), dtype=X.dtype, device=X.device) self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device) self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device) self.tvars: Dict[str, Tensor] = {} - self.ivars: Dict[str, int] = {'istep': 0} - self.fvars: Dict[str, float] = {'_': 0.0} - self.bvars: Dict[str, bool] = {'_': False} + self.ivars: Dict[str, int] = {"istep": 0} + self.fvars: Dict[str, float] = {"_": 0.0} + self.bvars: Dict[str, bool] = {"_": False} def __str__(self): - lines = ['LOPBCG:'] - lines += [' iparams={}'.format(self.iparams)] - lines += [' fparams={}'.format(self.fparams)] - lines += [' bparams={}'.format(self.bparams)] - lines += [' ivars={}'.format(self.ivars)] - lines += [' fvars={}'.format(self.fvars)] - lines += [' bvars={}'.format(self.bvars)] - lines += [' tvars={}'.format(self.tvars)] - lines += [' A={}'.format(self.A)] - lines += [' B={}'.format(self.B)] - lines += [' iK={}'.format(self.iK)] - lines += [' X={}'.format(self.X)] - lines += [' E={}'.format(self.E)] - r = '' + lines = ["LOPBCG:"] + lines += [" iparams={}".format(self.iparams)] + lines += [" fparams={}".format(self.fparams)] + lines += [" bparams={}".format(self.bparams)] + lines += [" ivars={}".format(self.ivars)] + lines += [" fvars={}".format(self.fvars)] + lines += [" bvars={}".format(self.bvars)] + lines += [" tvars={}".format(self.tvars)] + lines += [" A={}".format(self.A)] + lines += [" B={}".format(self.B)] + lines += [" iK={}".format(self.iK)] + lines += [" X={}".format(self.X)] + lines += [" E={}".format(self.E)] + r = "" for line in lines: - r += line + '\n' + r += line + "\n" return r def update(self): - """Set and update iteration variables. - """ - if self.ivars['istep'] == 0: + """Set and update iteration variables.""" + if self.ivars["istep"] == 0: X_norm = float(torch.norm(self.X)) - iX_norm = X_norm ** -1 + iX_norm = X_norm**-1 A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm - self.fvars['X_norm'] = X_norm - self.fvars['A_norm'] = A_norm - self.fvars['B_norm'] = B_norm - self.ivars['iterations_left'] = self.iparams['niter'] - self.ivars['converged_count'] = 0 - self.ivars['converged_end'] = 0 - - if self.method == 'ortho': + self.fvars["X_norm"] = X_norm + self.fvars["A_norm"] = A_norm + self.fvars["B_norm"] = B_norm + self.ivars["iterations_left"] = self.iparams["niter"] + self.ivars["converged_count"] = 0 + self.ivars["converged_end"] = 0 + + if self.method == "ortho": self._update_ortho() else: self._update_basic() - self.ivars['iterations_left'] = self.ivars['iterations_left'] - 1 - self.ivars['istep'] = self.ivars['istep'] + 1 + self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1 + self.ivars["istep"] = self.ivars["istep"] + 1 def update_residual(self): - """Update residual R from A, B, X, E. - """ + """Update residual R from A, B, X, E.""" mm = _utils.matmul self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E @@ -740,12 +783,15 @@ def update_converged_count(self): Users may redefine this method for custom convergence criteria. """ # (...) -> int - prev_count = self.ivars['converged_count'] - tol = self.fparams['tol'] - A_norm = self.fvars['A_norm'] - B_norm = self.fvars['B_norm'] + prev_count = self.ivars["converged_count"] + tol = self.fparams["tol"] + A_norm = self.fvars["A_norm"] + B_norm = self.fvars["B_norm"] E, X, R = self.E, self.X, self.R - rerr = torch.norm(R, 2, (0, )) * (torch.norm(X, 2, (0, )) * (A_norm + E[:X.shape[-1]] * B_norm)) ** -1 + rerr = ( + torch.norm(R, 2, (0,)) + * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1 + ) converged = rerr < tol count = 0 for b in converged: @@ -754,10 +800,12 @@ def update_converged_count(self): # strict ordering of eigenpairs break count += 1 - assert count >= prev_count, 'the number of converged eigenpairs ' \ - '(was {}, got {}) cannot decrease'.format(prev_count, count) - self.ivars['converged_count'] = count - self.tvars['rerr'] = rerr + assert count >= prev_count, ( + "the number of converged eigenpairs " + "(was {}, got {}) cannot decrease".format(prev_count, count) + ) + self.ivars["converged_count"] = count + self.tvars["rerr"] = rerr return count def stop_iteration(self): @@ -766,9 +814,11 @@ def stop_iteration(self): Note that tracker (if defined) can force-stop iterations by setting ``worker.bvars['force_stop'] = True``. """ - return (self.bvars.get('force_stop', False) - or self.ivars['iterations_left'] == 0 - or self.ivars['converged_count'] >= self.iparams['k']) + return ( + self.bvars.get("force_stop", False) + or self.ivars["iterations_left"] == 0 + or self.ivars["converged_count"] >= self.iparams["k"] + ) def run(self): """Run LOBPCG iterations. @@ -807,12 +857,12 @@ def _update_basic(self): Update or initialize iteration variables when `method == "basic"`. """ mm = torch.matmul - ns = self.ivars['converged_end'] - nc = self.ivars['converged_count'] - n = self.iparams['n'] - largest = self.bparams['largest'] + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] - if self.ivars['istep'] == 0: + if self.ivars["istep"] == 0: Ri = self._get_rayleigh_ritz_transform(self.X) M = _utils.qform(_utils.qform(self.A, self.X), Ri) E, Z = _utils.symeig(M, largest) @@ -824,38 +874,38 @@ def _update_basic(self): self.S[..., :n] = self.X W = _utils.matmul(self.iK, self.R) - self.ivars['converged_end'] = ns = n + np + W.shape[-1] - self.S[:, n + np:ns] = W + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W else: S_ = self.S[:, nc:ns] Ri = self._get_rayleigh_ritz_transform(S_) M = _utils.qform(_utils.qform(self.A, S_), Ri) E_, Z = _utils.symeig(M, largest) - self.X[:, nc:] = mm(S_, mm(Ri, Z[:, :n - nc])) - self.E[nc:] = E_[:n - nc] - P = mm(S_, mm(Ri, Z[:, n:2 * n - nc])) + self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc])) + self.E[nc:] = E_[: n - nc] + P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc])) np = P.shape[-1] self.update_residual() nc = self.update_converged_count() self.S[..., :n] = self.X - self.S[:, n:n + np] = P + self.S[:, n : n + np] = P W = _utils.matmul(self.iK, self.R[:, nc:]) - self.ivars['converged_end'] = ns = n + np + W.shape[-1] - self.S[:, n + np:ns] = W + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W def _update_ortho(self): """ Update or initialize iteration variables when `method == "ortho"`. """ mm = torch.matmul - ns = self.ivars['converged_end'] - nc = self.ivars['converged_count'] - n = self.iparams['n'] - largest = self.bparams['largest'] + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] - if self.ivars['istep'] == 0: + if self.ivars["istep"] == 0: Ri = self._get_rayleigh_ritz_transform(self.X) M = _utils.qform(_utils.qform(self.A, self.X), Ri) E, Z = _utils.symeig(M, largest) @@ -865,8 +915,8 @@ def _update_ortho(self): nc = self.update_converged_count() self.S[:, :n] = self.X W = self._get_ortho(self.R, self.X) - ns = self.ivars['converged_end'] = n + np + W.shape[-1] - self.S[:, n + np:ns] = W + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W else: S_ = self.S[:, nc:ns] @@ -874,9 +924,15 @@ def _update_ortho(self): E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest) # Update E, X, P - self.X[:, nc:] = mm(S_, Z[:, :n - nc]) - self.E[nc:] = E_[:n - nc] - P = mm(S_, mm(Z[:, n - nc:], _utils.basis(_utils.transpose(Z[:n - nc, n - nc:])))) + self.X[:, nc:] = mm(S_, Z[:, : n - nc]) + self.E[nc:] = E_[: n - nc] + P = mm( + S_, + mm( + Z[:, n - nc :], + _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])), + ), + ) np = P.shape[-1] # check convergence @@ -885,10 +941,10 @@ def _update_ortho(self): # update S self.S[:, :n] = self.X - self.S[:, n:n + np] = P - W = self._get_ortho(self.R[:, nc:], self.S[:, :n + np]) - ns = self.ivars['converged_end'] = n + np + W.shape[-1] - self.S[:, n + np:ns] = W + self.S[:, n : n + np] = P + W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np]) + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W def _get_rayleigh_ritz_transform(self, S): """Return a transformation matrix that is used in Rayleigh-Ritz @@ -942,13 +998,13 @@ def _get_rayleigh_ritz_transform(self, S): d_col = d_row.reshape(d_row.shape[0], 1) # TODO use torch.linalg.cholesky_solve once it is implemented R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True) - return torch.linalg.solve_triangular(R, d_row.diag_embed(), upper=True, left=False) + return torch.linalg.solve_triangular( + R, d_row.diag_embed(), upper=True, left=False + ) - def _get_svqb(self, - U: Tensor, # Tensor - drop: bool, # bool - tau: float # float - ) -> Tensor: + def _get_svqb( + self, U: Tensor, drop: bool, tau: float # Tensor # bool # float + ) -> Tensor: """Return B-orthonormal U. .. note:: When `drop` is `False` then `svqb` is based on the @@ -994,7 +1050,7 @@ def _get_svqb(self, assert len(nz[0]) == len(d) # The original algorithm 4 from [DuerschPhD2015]. - d_col = (d ** -0.5).reshape(d.shape[0], 1) + d_col = (d**-0.5).reshape(d.shape[0], 1) DUBUD = (UBU * d_col) * _utils.transpose(d_col) E, Z = _utils.symeig(DUBUD) t = tau * abs(E).max() @@ -1007,7 +1063,7 @@ def _get_svqb(self, else: E[(torch.where(E < t))[0]] = t - return torch.matmul(U * _utils.transpose(d_col), Z * E ** -0.5) + return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5) def _get_ortho(self, U, V): """Return B-orthonormal U with columns are B-orthogonal to V. @@ -1036,28 +1092,28 @@ def _get_ortho(self, U, V): """ mm = torch.matmul mm_B = _utils.matmul - m = self.iparams['m'] - tau_ortho = self.fparams['ortho_tol'] - tau_drop = self.fparams['ortho_tol_drop'] - tau_replace = self.fparams['ortho_tol_replace'] - i_max = self.iparams['ortho_i_max'] - j_max = self.iparams['ortho_j_max'] + m = self.iparams["m"] + tau_ortho = self.fparams["ortho_tol"] + tau_drop = self.fparams["ortho_tol_drop"] + tau_replace = self.fparams["ortho_tol_replace"] + i_max = self.iparams["ortho_i_max"] + j_max = self.iparams["ortho_j_max"] # when use_drop==True, enable dropping U columns that have # small contribution to the `span([U, V])`. - use_drop = self.bparams['ortho_use_drop'] + use_drop = self.bparams["ortho_use_drop"] # clean up variables from the previous call for vkey in list(self.fvars.keys()): - if vkey.startswith('ortho_') and vkey.endswith('_rerr'): + if vkey.startswith("ortho_") and vkey.endswith("_rerr"): self.fvars.pop(vkey) - self.ivars.pop('ortho_i', 0) - self.ivars.pop('ortho_j', 0) + self.ivars.pop("ortho_i", 0) + self.ivars.pop("ortho_j", 0) BV_norm = torch.norm(mm_B(self.B, V)) BU = mm_B(self.B, U) VBU = mm(_utils.transpose(V), BU) i = j = 0 - stats = '' + stats = "" for i in range(i_max): U = U - mm(V, VBU) drop = False @@ -1071,20 +1127,18 @@ def _get_ortho(self, U, V): U = self._get_svqb(U, False, tau_replace) if torch.numel(U) == 0: # all initial U columns are B-collinear to V - self.ivars['ortho_i'] = i - self.ivars['ortho_j'] = j + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j return U BU = mm_B(self.B, U) UBU = mm(_utils.transpose(U), BU) U_norm = torch.norm(U) BU_norm = torch.norm(BU) - R = UBU - torch.eye(UBU.shape[-1], - device=UBU.device, - dtype=UBU.dtype) + R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) R_norm = torch.norm(R) # https://github.com/pytorch/pytorch/issues/33810 workaround: rerr = float(R_norm) * float(BU_norm * U_norm) ** -1 - vkey = 'ortho_UBUmI_rerr[{}, {}]'.format(i, j) + vkey = "ortho_UBUmI_rerr[{}, {}]".format(i, j) self.fvars[vkey] = rerr if rerr < tau_ortho: break @@ -1092,7 +1146,7 @@ def _get_ortho(self, U, V): VBU_norm = torch.norm(VBU) U_norm = torch.norm(U) rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 - vkey = 'ortho_VBU_rerr[{}]'.format(i) + vkey = "ortho_VBU_rerr[{}]".format(i) self.fvars[vkey] = rerr if rerr < tau_ortho: break @@ -1102,16 +1156,20 @@ def _get_ortho(self, U, V): B = self.B assert B is not None raise ValueError( - 'Overdetermined shape of U:' - ' #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold' - .format(B.shape[-1], U.shape[-1], V.shape[-1])) - self.ivars['ortho_i'] = i - self.ivars['ortho_j'] = j + "Overdetermined shape of U:" + " #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold".format( + B.shape[-1], U.shape[-1], V.shape[-1] + ) + ) + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j return U # Calling tracker is separated from LOBPCG definitions because # TorchScript does not support user-defined callback arguments: LOBPCG_call_tracker_orig = LOBPCG.call_tracker + + def LOBPCG_call_tracker(self): self.tracker(self) diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 89fddb6888266..0c55a566ba86d 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -1,20 +1,19 @@ """Implement various linear algebra algorithms for low rank matrices. """ -__all__ = ['svd_lowrank', 'pca_lowrank'] +__all__ = ["svd_lowrank", "pca_lowrank"] + +from typing import Optional, Tuple -from torch import Tensor import torch +from torch import Tensor from . import _linalg_utils as _utils -from .overrides import has_torch_function, handle_torch_function +from .overrides import handle_torch_function, has_torch_function -from typing import Optional, Tuple -def get_approximate_basis(A: Tensor, - q: int, - niter: Optional[int] = 2, - M: Optional[Tensor] = None - ) -> Tensor: +def get_approximate_basis( + A: Tensor, q: int, niter: Optional[int] = 2, M: Optional[Tensor] = None +) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` @@ -82,8 +81,12 @@ def get_approximate_basis(A: Tensor, return Q -def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, - M: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: +def svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that :math:`A \approx U diag(S) V^T`. In case :math:`M` is given, then @@ -125,13 +128,21 @@ def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, """ if not torch.jit.is_scripting(): tensor_ops = (A, M) - if (not set(map(type, tensor_ops)).issubset((torch.Tensor, type(None))) and has_torch_function(tensor_ops)): - return handle_torch_function(svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M) + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M + ) return _svd_lowrank(A, q=q, niter=niter, M=M) -def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, - M: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: +def _svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: q = 6 if q is None else q m, n = A.shape[-2:] matmul = _utils.matmul @@ -177,8 +188,9 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, return U, S, V -def pca_lowrank(A: Tensor, q: Optional[int] = None, center: bool = True, - niter: int = 2) -> Tuple[Tensor, Tensor, Tensor]: +def pca_lowrank( + A: Tensor, q: Optional[int] = None, center: bool = True, niter: int = 2 +) -> Tuple[Tensor, Tensor, Tensor]: r"""Performs linear Principal Component Analysis (PCA) on a low-rank matrix, batches of such matrices, or sparse matrix. @@ -241,19 +253,21 @@ def pca_lowrank(A: Tensor, q: Optional[int] = None, center: bool = True, if not torch.jit.is_scripting(): if type(A) is not torch.Tensor and has_torch_function((A,)): - return handle_torch_function(pca_lowrank, (A,), A, q=q, center=center, niter=niter) + return handle_torch_function( + pca_lowrank, (A,), A, q=q, center=center, niter=niter + ) (m, n) = A.shape[-2:] if q is None: q = min(6, m, n) elif not (q >= 0 and q <= min(m, n)): - raise ValueError('q(={}) must be non-negative integer' - ' and not greater than min(m, n)={}' - .format(q, min(m, n))) + raise ValueError( + "q(={}) must be non-negative integer" + " and not greater than min(m, n)={}".format(q, min(m, n)) + ) if not (niter >= 0): - raise ValueError('niter(={}) must be non-negative integer' - .format(niter)) + raise ValueError("niter(={}) must be non-negative integer".format(niter)) dtype = _utils.get_floating_dtype(A) @@ -262,16 +276,20 @@ def pca_lowrank(A: Tensor, q: Optional[int] = None, center: bool = True, if _utils.is_sparse(A): if len(A.shape) != 2: - raise ValueError('pca_lowrank input is expected to be 2-dimensional tensor') + raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor") c = torch.sparse.sum(A, dim=(-2,)) / m # reshape c column_indices = c.indices()[0] - indices = torch.zeros(2, len(column_indices), - dtype=column_indices.dtype, - device=column_indices.device) + indices = torch.zeros( + 2, + len(column_indices), + dtype=column_indices.dtype, + device=column_indices.device, + ) indices[0] = column_indices C_t = torch.sparse_coo_tensor( - indices, c.values(), (n, 1), dtype=dtype, device=A.device) + indices, c.values(), (n, 1), dtype=dtype, device=A.device + ) ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t)) diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py index e4ea3cd638e0d..ac49fc95baf7e 100644 --- a/torch/_masked/__init__.py +++ b/torch/_masked/__init__.py @@ -1,16 +1,17 @@ # -*- coding: utf-8 -*- -from typing import Optional, Tuple, List, Union, Any - import warnings + +# A workaround to support both TorchScript and MyPy: +from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union + import torch from torch import Tensor from . import _docs -# A workaround to support both TorchScript and MyPy: -from typing import TYPE_CHECKING if TYPE_CHECKING: from torch.types import _dtype as DType + DimOrDims = Optional[Union[int, Tuple[int], List[int]]] else: # The JIT doesn't understand Union, nor torch.dtype here @@ -31,12 +32,13 @@ def _apply_docstring_templates(func): and returns the function instance. """ - doc_string = getattr(_docs, f'{func.__name__}_docstring', None) + doc_string = getattr(_docs, f"{func.__name__}_docstring", None) if doc_string is None: warnings.warn( - f'No documentation string available for {func.__name__}.' - ' PyTorch team should run `python tools/update_masked_docs.py`' - ' to generate the missing docstrings.') + f"No documentation string available for {func.__name__}." + " PyTorch team should run `python tools/update_masked_docs.py`" + " to generate the missing docstrings." + ) else: func.__doc__ = doc_string @@ -51,14 +53,14 @@ def _generate_docstring(func): script to update the module torch._masked._docs.py """ docstring_templates = dict( - reduction_signature='''\ -{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor''', - reduction_descr='''\ + reduction_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + reduction_descr="""\ Returns {operation name} of all the elements in the :attr:`input` tensor along the given dimension(s) :attr:`dim` while the :attr:`input` elements are masked out according to the boolean tensor -:attr:`mask`.''', - reduction_args='''\ +:attr:`mask`.""", + reduction_args="""\ If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see @@ -92,8 +94,8 @@ def _generate_docstring(func): {args_declarations} Keyword args: - {kwargs_declarations}''', - reduction_example='''\ + {kwargs_declarations}""", + reduction_example="""\ Example:: >>> input = {example_input} @@ -104,22 +106,22 @@ def _generate_docstring(func): {indent_example_mask} >>> {full_function_name}(input, {example_args}, mask=mask) {indent_example_output} -''', - reduction_identity='''\ -The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.''', - reduction_identity_dtype='''\ +""", + reduction_identity="""\ +The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""", + reduction_identity_dtype="""\ The identity value of {operation name} operation, which is used to start the reduction, depends on input dtype. For instance, for float32, uint8, -and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.''', - normalization_signature='''\ -{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor''', - normalization_descr='''\ +and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""", + normalization_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + normalization_descr="""\ Returns {operation name} of all the slices in the :attr:`input` tensor along :attr:`dim` while the :attr:`input` elements are masked out according to the boolean tensor :attr:`mask`. -{definition}''', - normalization_args='''\ +{definition}""", + normalization_args="""\ The boolean tensor :attr:`mask` defines the "validity" of :attr:`input` tensor elements: if :attr:`mask` element is True then the corresponding element in :attr:`input` tensor will be included in @@ -143,8 +145,8 @@ def _generate_docstring(func): {args_declarations} Keyword args: - {kwargs_declarations}''', - normalization_example='''\ + {kwargs_declarations}""", + normalization_example="""\ Example:: >>> input = {example_input} @@ -155,109 +157,126 @@ def _generate_docstring(func): {indent_example_mask} >>> {full_function_name}(input, {example_args}, mask=mask) {indent_example_output} -''') +""", + ) args_and_kwargs = dict( # argument name sufficies separated by double underscore will # be removed in the final documentation string. - sum=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - prod=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - cumsum=(('dim__as_int',), ('dtype=None', 'mask=None')), - cumprod=(('dim__as_int',), ('dtype=None', 'mask=None')), - amin=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - amax=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - argmin=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')), - argmax=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')), - mean=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - median=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')), - norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')), - std=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')), - logsumexp=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')), - softmax=(('dim__as_int',), ('dtype=None', 'mask=None')), - log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')), - softmin=(('dim__as_int',), ('dtype=None', 'mask=None')), - normalize=(('ord__required', 'dim__as_int',), ('eps=1e-12', 'dtype=None', 'mask=None')), + sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), + cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), + amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + norm=( + ( + "ord", + "dim", + ), + ("keepdim=False", "dtype=None", "mask=None"), + ), + var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + softmin=(("dim__as_int",), ("dtype=None", "mask=None")), + normalize=( + ( + "ord__required", + "dim__as_int", + ), + ("eps=1e-12", "dtype=None", "mask=None"), + ), ) argument_declarations = dict( - dim='''\ + dim="""\ dim (int or tuple of ints, optional): the dimension or dimensions to reduce. - Default: None that is equivalent to ``tuple(range(input.ndim))``.''', - dim__as_int='''\ -dim (int): the dimension along which {operation name} is computed.''', - ord='''\ + Default: None that is equivalent to ``tuple(range(input.ndim))``.""", + dim__as_int="""\ +dim (int): the dimension along which {operation name} is computed.""", + ord="""\ ord (int, float, optional): the order of vector norm. Default: 2. - See :func:`torch.linalg.vector_norm` for a list of supported norms.''', - ord__required='''\ + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + ord__required="""\ ord (int, float): the order of vector norm. Default: 2. - See :func:`torch.linalg.vector_norm` for a list of supported norms.''', - unbiased='''\ + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + unbiased="""\ unbiased (bool): when True, use Bessel’s correction, otherwise, compute - the uncorrected sample variance.''', - eps='''\ -eps (float, optional): small value to avoid division by zero. Default: {default}.''', - keepdim='''\ + the uncorrected sample variance.""", + eps="""\ +eps (float, optional): small value to avoid division by zero. Default: {default}.""", + keepdim="""\ keepdim (bool, optional): whether the output tensor has - :attr:`dim` retained or not. Default: {default}.''', - dtype='''\ + :attr:`dim` retained or not. Default: {default}.""", + dtype="""\ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is - performed. Default: {default}.''', - mask='''\ + performed. Default: {default}.""", + mask="""\ mask (:class:`torch.Tensor`, optional): the boolean tensor containing the binary mask of validity of input tensor elements. - Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.''') + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", + ) definitions = dict( - softmax='''\ + softmax="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. Softmax of i-th element in ``x`` is -defined as ``exp(x[i])/sum(exp(x))``.''', - log_softmax='''\ +defined as ``exp(x[i])/sum(exp(x))``.""", + log_softmax="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is -defined as ``log(exp(x[i])/sum(exp(x)))``.''', - softmin='''\ +defined as ``log(exp(x[i])/sum(exp(x)))``.""", + softmin="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. Softmin of i-th element in ``x`` is -defined as ``exp(-x[i])/sum(exp(-x))``.''', - normalize='''\ +defined as ``exp(-x[i])/sum(exp(-x))``.""", + normalize="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. Normalize of i-th element in ``x`` is -defined as ``x[i]/max(norm(x, p), eps)``.''', - cumsum='''\ +defined as ``x[i]/max(norm(x, p), eps)``.""", + cumsum="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is -defined as ``sum(x[:i])``.''', - cumprod='''\ +defined as ``sum(x[:i])``.""", + cumprod="""\ Let ``x`` be a sequence of unmasked elements of one-dimensional slice of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is -defined as ``prod(x[:i])``.''') +defined as ``prod(x[:i])``.""", + ) reduction_names = dict( - sum='sum', - prod='product', - amax='maximum', - amin='minimum', - argmax='argmax', - argmin='argmin', - mean='mean', - median='median', - norm='norm', - var='variance', - std='standard_deviation', - logsumexp='logsumexp') + sum="sum", + prod="product", + amax="maximum", + amin="minimum", + argmax="argmax", + argmin="argmin", + mean="mean", + median="median", + norm="norm", + var="variance", + std="standard_deviation", + logsumexp="logsumexp", + ) normalization_names = dict( - softmax='softmax', - log_softmax='log_softmax', - softmin='softmin', - normalize='normalize', - cumsum='cumulative_sum', - cumprod='cumulative_prod') + softmax="softmax", + log_softmax="log_softmax", + softmin="softmin", + normalize="normalize", + cumsum="cumulative_sum", + cumprod="cumulative_prod", + ) operation_names = dict() operation_names.update(reduction_names) @@ -268,12 +287,12 @@ def _generate_docstring(func): example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]]) example_mask = torch.tensor([[True, False, True], [False, False, False]]) example_args: Tuple[Any, ...] - if func.__name__ in {'norm', 'normalize'}: + if func.__name__ in {"norm", "normalize"}: example_args = (2.0, example_dim) example_input = example_input.to(dtype=torch.float32) - elif func.__name__ in {'var', 'std'}: + elif func.__name__ in {"var", "std"}: example_args = (example_dim, False) - elif func.__name__ == 'median': + elif func.__name__ == "median": example_args = (example_dim,) example_input = example_input.to(dtype=torch.float32) else: @@ -283,66 +302,94 @@ def _generate_docstring(func): operation_kwargs: Tuple[str, ...] operation_args, operation_kwargs = args_and_kwargs[func.__name__] arg_declarations = [ - '\n '.join(argument_declarations - .get(a, f'{a.split("__", 1)[0]}: TBD.') - .splitlines()) - for a in operation_args] + "\n ".join( + argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() + ) + for a in operation_args + ] kwarg_declarations = [ - '\n '.join(argument_declarations - .get(a.split('=', 1)[0], f'{a.split("__", 1)[0]}: TBD.') - .format(default=a.split('=', 1)[1]) - .splitlines()) - for a in operation_kwargs] + "\n ".join( + argument_declarations.get( + a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' + ) + .format(default=a.split("=", 1)[1]) + .splitlines() + ) + for a in operation_kwargs + ] if func.__name__ in reduction_names: - op_kind = 'reduction' - doc_sections = ['signature', 'descr', 'identity', 'args', 'example'] + op_kind = "reduction" + doc_sections = ["signature", "descr", "identity", "args", "example"] elif func.__name__ in normalization_names: - op_kind = 'normalization' - doc_sections = ['signature', 'descr', 'args', 'example'] + op_kind = "normalization" + doc_sections = ["signature", "descr", "args", "example"] example_input = example_input.to(dtype=torch.float32) else: assert 0 # add function name to operation names dictionaries example_output = func(example_input, *example_args, mask=example_mask) - template_data = {'function_name': func.__name__, - 'full_function_name': func.__module__ + '.' + func.__name__, - 'operation name': operation_names[func.__name__], - 'operation_args': ', '.join(a.split('__', 1)[0] for a in operation_args), - 'operation_kwargs': ', '.join(a.split('__', 1)[0] for a in operation_kwargs), - # one-line representation of a tensor: - 'example_input': ' '.join(str(example_input).split()), - 'example_args': ', '.join(map(str, example_args)), - 'example_mask': ' '.join(str(example_mask).split()), - # multi-line representation of a tensor with indent - 'indent_example_input': ('\n ').join(str(example_input).splitlines()), - 'indent_example_mask': ('\n ').join(str(example_mask).splitlines()), - 'indent_example_output': ('\n ').join(str(example_output).splitlines())} + template_data = { + "function_name": func.__name__, + "full_function_name": func.__module__ + "." + func.__name__, + "operation name": operation_names[func.__name__], + "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args), + "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs), + # one-line representation of a tensor: + "example_input": " ".join(str(example_input).split()), + "example_args": ", ".join(map(str, example_args)), + "example_mask": " ".join(str(example_mask).split()), + # multi-line representation of a tensor with indent + "indent_example_input": ("\n ").join(str(example_input).splitlines()), + "indent_example_mask": ("\n ").join(str(example_mask).splitlines()), + "indent_example_output": ("\n ").join(str(example_output).splitlines()), + } if func.__name__ in reduction_names: template_data.update( - identity_uint8=_reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)), - identity_int32=_reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)), - identity_float32=_reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32))) - if func.__name__ == 'norm': + identity_uint8=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.uint8) + ), + identity_int32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.int32) + ), + identity_float32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32) + ), + ) + if func.__name__ == "norm": template_data.update( identity_ord_ninf=_reduction_identity( - func.__name__, torch.tensor(0, dtype=torch.float32), float('-inf'))) + func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf") + ) + ) elif func.__name__ in normalization_names: template_data.update(definition=definitions[func.__name__]) else: assert 0 # add function name to operation names dictionaries - template_data.update(args_declarations=('\n '.join(arg_declarations)).format_map(template_data)) - template_data.update(kwargs_declarations=('\n '.join(kwarg_declarations)).format_map(template_data)) + template_data.update( + args_declarations=("\n ".join(arg_declarations)).format_map(template_data) + ) + template_data.update( + kwargs_declarations=("\n ".join(kwarg_declarations)).format_map( + template_data + ) + ) # Apply function name info to docstring templates: - templates = dict((k, v.format_map(template_data)) - for k, v in docstring_templates.items() if k.startswith(op_kind)) - templates.update((k, v.format_map(template_data) if isinstance(v, str) else v) for k, v in template_data.items()) + templates = dict( + (k, v.format_map(template_data)) + for k, v in docstring_templates.items() + if k.startswith(op_kind) + ) + templates.update( + (k, v.format_map(template_data) if isinstance(v, str) else v) + for k, v in template_data.items() + ) # Apply docstring templates to function doctring: if func.__doc__ is None: - doc_template = '\n\n'.join([f'{{{op_kind}_{sec}}}' for sec in doc_sections]) + doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections]) else: doc_template = func.__doc__ return doc_template.format_map(templates) @@ -364,47 +411,46 @@ def _reduction_identity(op_name: str, input: Tensor, *args): """ dtype: DType = input.dtype device = input.device - op_name = op_name.rsplit('.', 1)[-1] # lstrip module name when present - if op_name in {'sum', 'cumsum'}: + op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present + if op_name in {"sum", "cumsum"}: return torch.tensor(0, dtype=dtype, device=device) - elif op_name in {'prod', 'cumprod'}: + elif op_name in {"prod", "cumprod"}: return torch.tensor(1, dtype=dtype, device=device) - elif op_name in {'amax', 'argmax', 'logsumexp'}: + elif op_name in {"amax", "argmax", "logsumexp"}: if torch.is_floating_point(input): return torch.tensor(-torch.inf, dtype=dtype, device=device) elif torch.is_signed(input) or dtype == torch.uint8: return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) - elif op_name in {'amin', 'argmin'}: + elif op_name in {"amin", "argmin"}: if torch.is_floating_point(input): return torch.tensor(torch.inf, dtype=dtype, device=device) elif torch.is_signed(input) or dtype == torch.uint8: return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) - elif op_name == 'mean': + elif op_name == "mean": # Strictly speaking, the identity value of the mean operation # is the mean of the input. Since the mean value depends on # the dim argument and it may be a non-scalar tensor, we # consider the identity value of the mean operation ambiguous. # Moreover, the mean value of empty input is undefined. return None - elif op_name == 'norm': + elif op_name == "norm": ord = args[0] if args else 2 - if ord == float('-inf'): + if ord == float("-inf"): assert torch.is_floating_point(input), input.dtype return torch.tensor(torch.inf, dtype=dtype, device=device) return torch.tensor(0, dtype=dtype, device=device) - elif op_name == 'median': + elif op_name == "median": # We use NaN for now because the implementation is currently using torch.nanmedian # and NaN is the identity for that function since it gets ignored dtype = input.dtype if torch.is_floating_point(input) else torch.float return torch.tensor(torch.nan, dtype=dtype, device=device) - elif op_name in {'var', 'std'}: + elif op_name in {"var", "std"}: return None - raise NotImplementedError(f'identity of {op_name} on {dtype} input') + raise NotImplementedError(f"identity of {op_name} on {dtype} input") def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: - """Return dim argument as a tuple of sorted dim values. - """ + """Return dim argument as a tuple of sorted dim values.""" dims: List[int] = [] if dim == (): # Currently, `dim=()` in reductions operations means "reduce @@ -418,9 +464,11 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: dim_ = (dim,) if isinstance(dim, int) else dim for d in dim_: if d in dims: - raise RuntimeError(f'dim={d} appears multiple times in the list of dims') + raise RuntimeError(f"dim={d} appears multiple times in the list of dims") if d >= ndim or d < -ndim: - raise IndexError(f'Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})') + raise IndexError( + f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})" + ) dims.append(d % ndim) return tuple(sorted(dims)) @@ -477,12 +525,18 @@ def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor # For set operations on sparse tensor indices, we'll convert # multi-dimensional indices to 1-D indices for efficiency. - input_flat_indices = _sparse_coo_flatten_indices(input.indices(), input.shape[:input.sparse_dim()]) - mask_flat_indices = _sparse_coo_flatten_indices(mask.indices(), mask.shape[:mask.sparse_dim()]) + input_flat_indices = _sparse_coo_flatten_indices( + input.indices(), input.shape[: input.sparse_dim()] + ) + mask_flat_indices = _sparse_coo_flatten_indices( + mask.indices(), mask.shape[: mask.sparse_dim()] + ) # the set of mask flat indices that define masked-in elements: if mask.dense_dim() > 0: - mask_values = _any(mask.values(), tuple(range(1, input.sparse_dim() + 1)), False) + mask_values = _any( + mask.values(), tuple(range(1, input.sparse_dim() + 1)), False + ) else: mask_values = mask.values() maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]] @@ -500,7 +554,9 @@ def _apply(a): return obj[w] # the set of input flat indices of specified and masked-in elements: - maskin_input_flat_indices = _apply(intersection(maskin_flat_indices, input_flat_indices)) + maskin_input_flat_indices = _apply( + intersection(maskin_flat_indices, input_flat_indices) + ) _, w = intersection(input_flat_indices, maskin_input_flat_indices) # the indices and values of masked-in elements @@ -511,11 +567,14 @@ def _apply(a): # apply mask to the dense part of the input values: _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices) where_mask_values = mask.values()[w1] - where_input_values = torch.where(where_mask_values, where_input_values, - where_input_values.new_full([], fill_value.item())) + where_input_values = torch.where( + where_mask_values, where_input_values, fill_value + ) # the set of flat indices of unspecified input and masked-in elements: - maskin_zero_flat_indices = _apply(minus(maskin_flat_indices, maskin_input_flat_indices)) + maskin_zero_flat_indices = _apply( + minus(maskin_flat_indices, maskin_input_flat_indices) + ) # the indices of masked-in zero elements _, w = intersection(mask_flat_indices, maskin_zero_flat_indices) @@ -526,26 +585,37 @@ def _apply(a): if n == 0: # the input is coalesced, hence input_flat_indices are ordered # and the result is guaranteed to be coalesced: - result = torch.sparse_coo_tensor(where_input_indices, where_input_values, input.shape) + result = torch.sparse_coo_tensor( + where_input_indices, where_input_values, input.shape + ) return result._coalesced_(True) where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1) - where_values = torch.cat([where_input_values, where_input_values.new_zeros((n,) + where_input_values.shape[1:])]) + where_values = torch.cat( + [ + where_input_values, + where_input_values.new_zeros((n,) + where_input_values.shape[1:]), + ] + ) result = torch.sparse_coo_tensor(where_indices, where_values, input.shape) # appending zero elements leads to uncoalesced sparse tensor return result.coalesce() -def _sparse_coo_scatter_reduction_helper(op, - mask_input: Tensor, - dims: Tuple[int, ...], - keepdim: bool, - dtype: Optional[DType] = None) -> Tensor: +def _sparse_coo_scatter_reduction_helper( + op, + mask_input: Tensor, + dims: Tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: reduce = op.__name__ - valid_reductions = ['sum', 'prod', 'amax', 'amin'] + valid_reductions = ["sum", "prod", "amax", "amin"] if reduce not in valid_reductions: - raise ValueError(f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead") + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) output_dtype = dtype values, indices = mask_input._values(), mask_input._indices() @@ -560,12 +630,16 @@ def _sparse_coo_scatter_reduction_helper(op, values = values.to(output_dtype) if keepdim: - output_shape = tuple(1 if i in dims else si for (i, si) in enumerate(mask_input.shape)) + output_shape = tuple( + 1 if i in dims else si for (i, si) in enumerate(mask_input.shape) + ) else: - output_shape = tuple(si for (i, si) in enumerate(mask_input.shape) if i not in dims) + output_shape = tuple( + si for (i, si) in enumerate(mask_input.shape) if i not in dims + ) for d in dims: - if (d >= input_dims): + if d >= input_dims: continue if d < num_sparse_dims: @@ -586,14 +660,14 @@ def _sparse_coo_scatter_reduction_helper(op, # Reduce sparse dimensions if len(reduced_sparse_dims) == num_sparse_dims: - if reduce in {'amax', 'amin'} and new_values.size(0) == 0: + if reduce in {"amax", "amin"} and new_values.size(0) == 0: # IndexError: amax(): Expected reduction dim 0 to have non-zero size. # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not # See https://github.com/pytorch/pytorch/issues/61901 new_values = _reduction_identity(reduce, new_values) else: new_values = op(new_values, dim=0) - if (keepdim): + if keepdim: for _ in range(num_sparse_dims): new_values = new_values.unsqueeze(0) return new_values.to(dtype=output_dtype).to_sparse() @@ -605,14 +679,22 @@ def _sparse_coo_scatter_reduction_helper(op, new_indices[reduced_sparse_dims, :] = 0 else: # remove reduced sparse dimensions if keepdim = False - if (len(reduced_sparse_dims) > 0): - retained_sparse_dims = [i for i in range(num_sparse_dims) if i not in set(reduced_sparse_dims)] - new_indices = new_indices.index_select(0, torch.tensor(retained_sparse_dims).to(mask_input.device)) + if len(reduced_sparse_dims) > 0: + retained_sparse_dims = [ + i + for i in range(num_sparse_dims) + if i not in set(reduced_sparse_dims) + ] + new_indices = new_indices.index_select( + 0, torch.tensor(retained_sparse_dims).to(mask_input.device) + ) # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices - if (new_indices.numel() > 0): + if new_indices.numel() > 0: # lexsort indices and get index tensor for scatter reduction - new_indices, inverse_indices = torch.unique(new_indices, return_inverse=True, dim=1) + new_indices, inverse_indices = torch.unique( + new_indices, return_inverse=True, dim=1 + ) out_shape = list(new_values.shape) out_shape[0] = new_indices.shape[1] for _ in range(new_values.ndim - 1): @@ -622,20 +704,120 @@ def _sparse_coo_scatter_reduction_helper(op, if output_dtype in {torch.bfloat16, torch.float16}: new_values = new_values.to(torch.float) out = new_values.new_empty(out_shape) - new_values = out.scatter_reduce_(0, scatter_indices, new_values, reduce=reduce, include_self=False) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) new_values = new_values.to(dtype=output_dtype) else: out = new_values.new_empty(out_shape) - new_values = out.scatter_reduce_(0, scatter_indices, new_values, reduce=reduce, include_self=False) + new_values = out.scatter_reduce_( + 0, scatter_indices, new_values, reduce=reduce, include_self=False + ) + + return torch.sparse_coo_tensor( + new_indices, + new_values, + output_shape, + dtype=output_dtype, + device=mask_input.device, + ) - return torch.sparse_coo_tensor(new_indices, new_values, output_shape, dtype=output_dtype, device=mask_input.device) + +def _sparse_csr_segment_reduction_helper( + op, + mask_input: Tensor, + dims: Tuple[int, ...], + keepdim: bool, + dtype: Optional[DType] = None, +) -> Tensor: + # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True + # FIXME: when dense dimensions are implemented for CSR tensors + assert ( + keepdim + ), "reduction operations on CSR tensors with keepdim=False is unsupported" + reduce = op.__name__ + valid_reductions = ["sum", "prod", "mean", "amax", "amin"] + if reduce not in valid_reductions: + raise ValueError( + f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead" + ) + device = mask_input.device + output_dtype = dtype + values, crow_indices, col_indices = ( + mask_input.values(), + mask_input.crow_indices(), + mask_input.col_indices(), + ) + + # promote dtype if specified + if values.dtype != output_dtype: + values = values.to(output_dtype) + + if len(dims) == 0: + return mask_input + if len(dims) == 1: + if dims[0] == 0: + new_col_indices, scatter_indices = torch.unique( + col_indices, return_inverse=True + ) + new_nnz = new_col_indices.shape[0] + new_crow_indices = torch.tensor([0, new_nnz]) + new_values = values.new_empty(new_col_indices.shape) + new_values.scatter_reduce_( + 0, scatter_indices, values, reduce, include_self=False + ) + new_shape = [1, mask_input.size(1)] + else: + assert ( + dims[0] == 1 + ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 + # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 + new_crow_indices = torch.cat( + ( + crow_indices.new_zeros(1), + torch.cumsum(torch.diff(crow_indices) != 0, 0), + ), + 0, + ) + new_nnz = new_crow_indices[-1] + new_col_indices = col_indices.new_zeros(new_nnz) + # segment_reduce takes 'max'/'min' rather than 'amax'/'amin', changing this would be BC-breaking + if reduce in ["amax", "amin"]: + reduce = reduce[1:] + new_values = torch.segment_reduce(values, reduce, offsets=crow_indices) + new_shape = [mask_input.size(0), 1] + else: + assert len(dims) == 2 + nnz = min(1, values.numel()) + if nnz == 1: + op_kwargs = {"keepdim": True, "dtype": output_dtype} + # amax and amin do not support dtype kwarg + if reduce in ["amax", "amin"]: + del op_kwargs["dtype"] + new_values = op(values, 0, **op_kwargs) + else: + new_values = torch.empty(0, dtype=output_dtype) + new_col_indices = col_indices.new_zeros(nnz) + new_crow_indices = torch.tensor([0, nnz]) + new_shape = [1, nnz] + + return torch.sparse_csr_tensor( + new_crow_indices, + new_col_indices, + new_values, + new_shape, + dtype=output_dtype, + device=device, + ) def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: - """Sparse variant of torch.where. Supports sparse CSR tensors. - """ + """Sparse variant of torch.where. Supports sparse CSR tensors.""" # TODO: implement sparse CSR specific where operator for efficiency - return _sparse_coo_where(mask.to_sparse_coo(), input.to_sparse_coo(), fill_value).to_sparse_csr() + return _sparse_coo_where( + mask.to_sparse_coo(), input.to_sparse_coo(), fill_value + ).to_sparse_csr() def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: @@ -663,22 +845,15 @@ def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: - all unspecified elements correspond to masked-out elements. """ if mask.layout == torch.strided: - if fill_value.dtype == torch.bool: - # Workaround internal assert failure in - # test_nvfuser_correctness__masked_mean_cuda_bool: We - # don't have an op for aten::new_full but it isn't a - # special case. Argument types: Tensor, int[], bool, int, - # int, Device, bool - fill = input.new_full([], int(fill_value.item())).to(dtype=torch.bool) - else: - fill = input.new_full([], fill_value.item()) - return torch.where(mask, input, fill) + return torch.where(mask, input, fill_value) elif mask.layout == torch.sparse_coo: return _sparse_coo_where(mask, input, fill_value) elif mask.layout == torch.sparse_csr: return _sparse_csr_where(mask, input, fill_value) else: - raise ValueError(f'_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}') + raise ValueError( + f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}" + ) def _input_mask(input: Tensor, *args, **kwargs) -> Tensor: @@ -715,18 +890,22 @@ def _input_mask(input: Tensor, *args, **kwargs) -> Tensor: """ if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: - raise ValueError(f'_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}') + raise ValueError( + f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}" + ) - mask = kwargs.get('mask') + mask = kwargs.get("mask") # default mask if mask is None: - raise ValueError('_input_mask requires explicit mask') + raise ValueError("_input_mask requires explicit mask") # mask shape must match with input shape if mask.shape != input.shape: if mask.ndim > input.ndim: - raise IndexError("_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)") + raise IndexError( + "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)" + ) if mask.layout == torch.strided: mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool) elif mask.layout == torch.sparse_coo: @@ -735,7 +914,9 @@ def _input_mask(input: Tensor, *args, **kwargs) -> Tensor: assert mask.layout == torch.sparse_csr # Broadcasting of CSR tensors is not implemented. Working # around by using COO layout. - mask = torch._sparse_broadcast_to(mask.to_sparse(), input.shape).to_sparse_csr() + mask = torch._sparse_broadcast_to( + mask.to_sparse(), input.shape + ).to_sparse_csr() # mask layout must match with input layout if mask.layout != input.layout: @@ -761,32 +942,53 @@ def _input_mask(input: Tensor, *args, **kwargs) -> Tensor: def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: - """Return output mask of masked operation applied to given arguments. - """ + """Return output mask of masked operation applied to given arguments.""" if callable(op): - is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', - 'argmax', 'argmin', 'mean', 'median', 'norm', 'var', 'std', 'logsumexp'} - is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize', 'cumsum', 'cumprod'} + is_reduction = op.__name__ in { + "sum", + "prod", + "amax", + "amin", + "argmax", + "argmin", + "mean", + "median", + "norm", + "var", + "std", + "logsumexp", + } + is_normalization = op.__name__ in { + "softmax", + "log_softmax", + "softmin", + "normalize", + "cumsum", + "cumprod", + } if is_reduction: - if op.__name__ == 'norm': + if op.__name__ == "norm": if args: args = args[1:] # lstrip ord argument - dim = args[0] if args else kwargs.get('dim') + dim = args[0] if args else kwargs.get("dim") outmask = _input_mask(input, *args, **kwargs) - keepdim = kwargs.get('keepdim', False) + keepdim = kwargs.get("keepdim", False) dim_ = _canonical_dim(dim, input.ndim) return _any(outmask, dim_, bool(keepdim)) elif is_normalization: return _input_mask(input, *args, **kwargs) else: - raise ValueError(f'_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})') + raise ValueError( + f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})" + ) else: - raise ValueError(f'_output_mask expected masked operation (got {type(op).__name__} object)') + raise ValueError( + f"_output_mask expected masked operation (got {type(op).__name__} object)" + ) def _combine_input_and_mask(op, input: Tensor, mask, *args) -> Tensor: - """Return input with masked-out elements eliminated for the given operations. - """ + """Return input with masked-out elements eliminated for the given operations.""" if mask is None: return input canonical_mask = _input_mask(input, mask=mask) @@ -794,21 +996,31 @@ def _combine_input_and_mask(op, input: Tensor, mask, *args) -> Tensor: fill_value = _reduction_identity(op.__name__, input, *args) return _where(canonical_mask, input, fill_value) else: - raise ValueError(f'_combine_input_and_mask expected masked operation (got {type(op).__name__} object)') + raise ValueError( + f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)" + ) @_apply_docstring_templates -def sum(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def sum( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: # promote integer types to int64 when output dtype is not specified if input.layout == torch.sparse_csr: - if input.dtype in {torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32}: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: # csr.to(dtype=torch.int64) is not implemented, so # using coo.to on input to ensure the promoted dtype input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() @@ -816,32 +1028,52 @@ def sum(input: Tensor, dtype = input.dtype else: dtype = input.dtype - if input.dtype in {torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32}: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: dtype = torch.int64 dim_ = _canonical_dim(dim, input.ndim) mask_input = _combine_input_and_mask(sum, input, mask) if input.layout == torch.strided: return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype) elif input.layout == torch.sparse_coo: - return _sparse_coo_scatter_reduction_helper(torch.sum, mask_input, dim_, bool(keepdim), dtype) + return _sparse_coo_scatter_reduction_helper( + torch.sum, mask_input, dim_, bool(keepdim), dtype + ) elif input.layout == torch.sparse_csr: - return torch._sparse_csr_sum(mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype) + return torch._sparse_csr_sum( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) else: - raise ValueError(f'masked sum expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)') + raise ValueError( + f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def prod(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def prod( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: # promote integer types to int64 when output dtype is not specified if input.layout == torch.sparse_csr: - if input.dtype in {torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32}: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: # csr.to(dtype=torch.int64) is not implemented, so # using coo.to on input to ensure the promoted dtype input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr() @@ -849,7 +1081,13 @@ def prod(input: Tensor, dtype = input.dtype else: dtype = input.dtype - if input.dtype in {torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32}: + if input.dtype in { + torch.uint8, + torch.bool, + torch.int8, + torch.int16, + torch.int32, + }: dtype = torch.int64 dim_ = _canonical_dim(dim, input.ndim) mask_input = _combine_input_and_mask(prod, input, mask) @@ -863,8 +1101,12 @@ def prod(input: Tensor, elif input.layout == torch.sparse_coo: if mask is None: # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors - raise ValueError('masked prod expects explicit mask for sparse_coo tensor input') - return _sparse_coo_scatter_reduction_helper(torch.prod, mask_input, dim_, bool(keepdim), dtype) + raise ValueError( + "masked prod expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.prod, mask_input, dim_, bool(keepdim), dtype + ) elif input.layout == torch.sparse_csr: if mask is None: # mask is None corresponds to all-True mask. The @@ -877,18 +1119,26 @@ def prod(input: Tensor, # # but that requires implementing `all` and `nonzero` # support for sparse csr tensors. - raise ValueError('masked prod expects explicit mask for sparse_csr tensor input') - return torch._sparse_csr_prod(mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype) + raise ValueError( + "masked prod expects explicit mask for sparse_csr tensor input" + ) + return torch._sparse_csr_prod( + mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype + ) else: - raise ValueError(f'masked prod expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)') + raise ValueError( + f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def cumsum(input: Tensor, - dim: int, - *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def cumsum( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -896,15 +1146,19 @@ def cumsum(input: Tensor, if input.layout == torch.strided: return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype) else: - raise ValueError(f'masked cumsum expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked cumsum expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def cumprod(input: Tensor, - dim: int, - *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def cumprod( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -912,16 +1166,20 @@ def cumprod(input: Tensor, if input.layout == torch.strided: return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype) else: - raise ValueError(f'masked cumprod expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked cumprod expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def amax(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def amax( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} @@ -943,19 +1201,35 @@ def amax(input: Tensor, if mask is None: # See comment in the sparse_csr branch of prod, a similar issue arises here # where unspecified elements along a dimension may need to be reduced with the result - raise ValueError('masked amax expects explicit mask for sparse_coo tensor input') - return _sparse_coo_scatter_reduction_helper(torch.amax, mask_input, dim_, bool(keepdim), dtype) + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) + elif input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amax expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amax, mask_input, dim_, bool(keepdim), dtype + ) else: - raise ValueError(f'masked amax expects strided or sparse_coo tensor (got {input.layout} tensor)') + raise ValueError( + f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def amin(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def amin( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} @@ -977,19 +1251,35 @@ def amin(input: Tensor, if mask is None: # See comment in the sparse_csr branch of prod, a similar issue arises here # where unspecified elements along a dimension may need to be reduced with the result - raise ValueError('masked amax expects explicit mask for sparse_coo tensor input') - return _sparse_coo_scatter_reduction_helper(torch.amin, mask_input, dim_, bool(keepdim), dtype) + raise ValueError( + "masked amax expects explicit mask for sparse_coo tensor input" + ) + return _sparse_coo_scatter_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) + elif input.layout == torch.sparse_csr: + if mask is None: + raise ValueError( + "masked amin expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.amin, mask_input, dim_, bool(keepdim), dtype + ) else: - raise ValueError(f'masked amin expects strided or sparse_coo tensor (got {input.layout} tensor)') + raise ValueError( + f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def argmax(input: Tensor, - dim: int = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def argmax( + input: Tensor, + dim: int = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} {reduction_descr} @@ -1002,16 +1292,20 @@ def argmax(input: Tensor, if input.layout == torch.strided: return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype) else: - raise ValueError(f'masked argmax expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked argmax expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def argmin(input: Tensor, - dim: int = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def argmin( + input: Tensor, + dim: int = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} {reduction_descr} @@ -1024,16 +1318,20 @@ def argmin(input: Tensor, if input.layout == torch.strided: return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype) else: - raise ValueError(f'masked argmin expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked argmin expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def mean(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def mean( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} @@ -1054,25 +1352,47 @@ def mean(input: Tensor, if input.layout == torch.strided: if mask is None: # TODO: compute count analytically - count = sum(torch.ones(input.shape, dtype=torch.int64, device=input.device), dim, keepdim=keepdim) + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=keepdim, + ) total = sum(input, dim, keepdim=keepdim, dtype=dtype) else: inmask = _input_mask(input, mask=mask) - count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask) + count = sum( + inmask.new_ones(input.shape, dtype=torch.int64), + dim, + keepdim=keepdim, + mask=inmask, + ) total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask) return total / count + elif input.layout == torch.sparse_csr: + mask_input = _combine_input_and_mask(mean, input, mask) + dim_ = _canonical_dim(dim, mask_input.ndim) + if mask is None: + raise ValueError( + "masked mean expects explicit mask for sparse_csr tensor input" + ) + return _sparse_csr_segment_reduction_helper( + torch.mean, mask_input, dim_, bool(keepdim), dtype + ) else: - raise ValueError(f'masked sum expects strided tensor (got {input.layout} tensor)') - + raise ValueError( + f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def median(input: Tensor, - dim: int = -1, - *, - keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def median( + input: Tensor, + dim: int = -1, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} @@ -1099,18 +1419,24 @@ def median(input: Tensor, elif not is_float and not torch.isnan(output).any(): return output.to(dtype=dtype) else: - raise ValueError("masked median expects no fully masked out rows if dtype is not floating point") + raise ValueError( + "masked median expects no fully masked out rows if dtype is not floating point" + ) else: - raise ValueError(f'masked median expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked median expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def logsumexp(input: Tensor, - dim: DimOrDims = None, - *, - keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def logsumexp( + input: Tensor, + dim: DimOrDims = None, + *, + keepdim: bool = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim) @@ -1118,17 +1444,21 @@ def logsumexp(input: Tensor, if input.layout == torch.strided: return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype) else: - raise ValueError(f'masked logsumexp expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked logsumexp expects strided tensor (got {input.layout} tensor)" + ) # TODO: Add docstring; currently they're only set up for reductions and normalizations # @_apply_docstring_templates -def logaddexp(input: Tensor, - other: Tensor, - *, - dtype: Optional[DType] = None, - input_mask: Optional[Tensor] = None, - other_mask: Optional[Tensor] = None) -> Tensor: +def logaddexp( + input: Tensor, + other: Tensor, + *, + dtype: Optional[DType] = None, + input_mask: Optional[Tensor] = None, + other_mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype if input.layout == torch.strided and other.layout == torch.strided: @@ -1137,17 +1467,20 @@ def logaddexp(input: Tensor, return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) else: raise ValueError( - f'masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)') + f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)" + ) @_apply_docstring_templates -def norm(input: Tensor, - ord: Optional[float] = 2.0, - dim: DimOrDims = None, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def norm( + input: Tensor, + ord: Optional[float] = 2.0, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} @@ -1165,19 +1498,25 @@ def norm(input: Tensor, mask_input = _combine_input_and_mask(norm, input, mask, ord) if input.layout == torch.strided: dim_ = _canonical_dim(dim, input.ndim) - return torch.linalg.vector_norm(mask_input, ord, dim_, bool(keepdim), dtype=dtype) + return torch.linalg.vector_norm( + mask_input, ord, dim_, bool(keepdim), dtype=dtype + ) else: - raise ValueError(f'masked norm expects strided tensor (got {input.layout} tensor)') - - -def std_var(input: Tensor, - dim: DimOrDims = None, - unbiased: Optional[bool] = False, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, - take_sqrt: Optional[bool] = False) -> Tensor: + raise ValueError( + f"masked norm expects strided tensor (got {input.layout} tensor)" + ) + + +def std_var( + input: Tensor, + dim: DimOrDims = None, + unbiased: Optional[bool] = False, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, + take_sqrt: Optional[bool] = False, +) -> Tensor: if dtype is None: dtype = input.dtype if not (dtype.is_floating_point or dtype.is_complex): @@ -1188,11 +1527,20 @@ def std_var(input: Tensor, if input.layout == torch.strided: if mask is None: # TODO: compute count analytically - count = sum(torch.ones(input.shape, dtype=torch.int64, device=input.device), dim, keepdim=True) + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=True, + ) sample_total = sum(input, dim, keepdim=True, dtype=dtype) else: inmask = _input_mask(input, mask=mask) - count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=True, mask=inmask) + count = sum( + inmask.new_ones(input.shape, dtype=torch.int64), + dim, + keepdim=True, + mask=inmask, + ) sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask) # TODO: replace torch.subtract/divide/square/maximum with # masked subtract/divide/square/maximum when these will be @@ -1202,7 +1550,9 @@ def std_var(input: Tensor, if mask is None: total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) else: - total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask) + total = sum( + x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask + ) if not keepdim: count = count.reshape(total.shape) if unbiased: @@ -1213,17 +1563,21 @@ def std_var(input: Tensor, output = torch.sqrt(output) return output else: - raise ValueError(f'masked std/var expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked std/var expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def var(input: Tensor, - dim: DimOrDims = None, - unbiased: Optional[bool] = False, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def var( + input: Tensor, + dim: DimOrDims = None, + unbiased: Optional[bool] = False, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} {reduction_descr} @@ -1244,13 +1598,15 @@ def var(input: Tensor, @_apply_docstring_templates -def std(input: Tensor, - dim: DimOrDims = None, - unbiased: Optional[bool] = False, - *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def std( + input: Tensor, + dim: DimOrDims = None, + unbiased: Optional[bool] = False, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: """\ {reduction_signature} {reduction_descr} @@ -1266,16 +1622,18 @@ def std(input: Tensor, keepdim=keepdim, dtype=dtype, mask=mask, - take_sqrt=True + take_sqrt=True, ) @_apply_docstring_templates -def softmax(input: Tensor, - dim: int, - *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def softmax( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -1283,15 +1641,19 @@ def softmax(input: Tensor, if input.layout == torch.strided: return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype) else: - raise ValueError(f'masked softmax expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked softmax expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def log_softmax(input: Tensor, - dim: int, - *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def log_softmax( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -1299,15 +1661,19 @@ def log_softmax(input: Tensor, if input.layout == torch.strided: return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype) else: - raise ValueError(f'masked log_softmax expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked log_softmax expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def softmin(input: Tensor, - dim: int, - *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def softmin( + input: Tensor, + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -1315,17 +1681,21 @@ def softmin(input: Tensor, if input.layout == torch.strided: return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype) else: - raise ValueError(f'masked softmin expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked softmin expects strided tensor (got {input.layout} tensor)" + ) @_apply_docstring_templates -def normalize(input: Tensor, - ord: float, - dim: int, - *, - eps: float = 1e-12, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None) -> Tensor: +def normalize( + input: Tensor, + ord: float, + dim: int, + *, + eps: float = 1e-12, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: if dtype is None: dtype = input.dtype dim_ = _canonical_dim(dim, input.ndim)[0] @@ -1338,4 +1708,6 @@ def normalize(input: Tensor, # TODO: replace torch.divide with masked divide when available. return torch.divide(mask_input, denom) else: - raise ValueError(f'masked normalize expects strided tensor (got {input.layout} tensor)') + raise ValueError( + f"masked normalize expects strided tensor (got {input.layout} tensor)" + ) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 9061f6c7cbf87..9f5574575fd7a 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,17 +1,44 @@ +from typing import List, Optional, Union + import torch +import torch._prims_common as utils from torch import Tensor -from torch._prims import utils -from torch._prims.utils import ( - ELEMENTWISE_TYPE_PROMOTION_KIND, +from torch._prims_common import ( check, + corresponding_complex_dtype, + corresponding_real_dtype, elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, ) -from torch._prims.wrappers import out_wrapper_multi, out_wrapper -from typing import List, Optional +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes +from torch.utils._pytree import tree_map + +aten = torch.ops.aten meta_lib = torch.library.Library("aten", "IMPL", "Meta") +meta_table = {} + + +def register_meta(op, register_dispatcher=True): + def wrapper(f): + def add_func(op): + meta_table[op] = f + if register_dispatcher: + name = ( + op.__name__ + if op._overloadname != "default" + else op.overloadpacket.__name__ + ) + meta_lib.impl(name, f) + + tree_map(add_func, op) + return f + + return wrapper + def toRealValueType(dtype): from_complex = { @@ -22,13 +49,13 @@ def toRealValueType(dtype): return from_complex.get(dtype, dtype) -@torch.library.impl(meta_lib, "_fft_c2c") +@register_meta(aten._fft_c2c.default) def meta_fft_c2c(self, dim, normalization, forward): assert self.dtype.is_complex return self.new_empty(self.size()) -@torch.library.impl(meta_lib, "_fft_r2c") +@register_meta(aten._fft_r2c.default) def meta_fft_r2c(self, dim, normalization, onesided): assert self.dtype.is_floating_point output_sizes = list(self.size()) @@ -43,9 +70,8 @@ def meta_fft_r2c(self, dim, normalization, onesided): ) -@torch.library.impl(meta_lib, "_fft_c2r.out") -@torch.library.impl(meta_lib, "_fft_c2r") -@out_wrapper +@register_meta([aten._fft_c2r.default, aten._fft_c2r.out]) +@out_wrapper() def meta_fft_c2r(self, dim, normalization, lastdim): assert self.dtype.is_complex output_sizes = list(self.size()) @@ -53,13 +79,8 @@ def meta_fft_c2r(self, dim, normalization, lastdim): return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype)) -@torch.library.impl(meta_lib, "conj_physical.out") -def meta_conj_physical_out(self, out): - return torch._resize_output_(out, self.size(), self.device) - - # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py -@torch.library.impl(meta_lib, "index_select") +@register_meta(aten.index_select.default) def meta_index_select(self, dim, index): result_size = list(self.size()) if self.dim() > 0: @@ -67,31 +88,29 @@ def meta_index_select(self, dim, index): return self.new_empty(result_size) -@torch.library.impl(meta_lib, "index_select.out") +@register_meta(aten.index_select.out) def meta_index_select_out(self, dim, index, out): torch._resize_output_(out, self.size(), self.device) return out.copy_(torch.index_select(self, dim, index)) -@torch.library.impl(meta_lib, "max") +@register_meta([aten.max.default, aten.min.default]) def meta_max(self): return self.new_empty(()) -@torch.library.impl(meta_lib, "min") -def meta_min(self): - return self.new_empty(()) - - -@torch.library.impl(meta_lib, "angle") +@register_meta(aten.angle.default) def meta_angle(self): - _, result_dtype = elementwise_dtypes( - self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT - ) + if self.is_complex(): + result_dtype = corresponding_real_dtype(self.dtype) + else: + _, result_dtype = elementwise_dtypes( + self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) return self.new_empty(self.size(), dtype=result_dtype) -@torch.library.impl(meta_lib, "angle.out") +@register_meta(aten.angle.out) def meta_angle_out(self, out): torch._resize_output_(out, self.size(), self.device) return out.copy_(torch.angle(self)) @@ -113,9 +132,7 @@ def checkUplo(uplo: str): ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}" -# Keeping this meta impl around, but we don't want to register it directly to the meta key -# because `aten::linalg_eigh` is composite. -# `_linalg_eigh` is implemented internally as a structured kernel, so we have meta support. +# @register_meta(aten.linalg_eigh.default) def meta_linalg_eigh(self, uplo="L"): squareCheckInputs(self, "linalg_eigh") checkUplo(uplo) @@ -127,7 +144,7 @@ def meta_linalg_eigh(self, uplo="L"): return (values, vectors) -@torch.library.impl(meta_lib, "reflection_pad2d") +@register_meta(aten.reflection_pad2d.default) def meta_pad2d(self, padding): valid_dims = self.size(1) != 0 and self.size(2) != 0 check( @@ -152,12 +169,16 @@ def meta_pad2d(self, padding): return self.new_empty((nbatch, nplane, output_h, output_w)) -@torch.library.impl(meta_lib, "dot") -def meta_dot(self, tensor): +def dot_check(self, other): check( - self.dim() == 1 and tensor.dim() == 1, - lambda: f"1D tensors expected, but got {self.dim()}D and {tensor.dim()}D tensors", + self.dim() == 1 and other.dim() == 1, + lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", ) + + +@register_meta(aten.dot.default) +def meta_dot(self, tensor): + dot_check(self, tensor) return self.new_empty(()) @@ -168,16 +189,7 @@ def _compute_reduction_shape(self, dims, keepdim): return utils.compute_reduction_output_shape(self.shape, dims) -@torch.library.impl(meta_lib, "var_mean.correction") -def meta_var_mean_correction(self, dim, *, correction, keepdim=False): - dim = utils.reduction_dims(self.shape, dim) - output_shape = _compute_reduction_shape(self, dim, keepdim) - result1 = self.new_empty(output_shape, dtype=toRealValueType(self.dtype)) - result2 = self.new_empty(output_shape) - return result1, result2 - - -@torch.library.impl(meta_lib, "inverse") +@register_meta(aten.inverse.default) def meta_inverse(self): # Bug: https://github.com/pytorch/pytorch/issues/77498 if self.numel() == 0: @@ -193,7 +205,144 @@ def meta_bernoulli(self, *, generator=None, out): return out -@torch.library.impl(meta_lib, "_adaptive_avg_pool2d") +@register_meta(aten.convolution.default) +def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, +): + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. + See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + def calc_conv_nd_return_shape( + dims: torch.Size, + kernel_size: torch.Size, + stride: Union[List[int], int], + padding: Union[List[int], int], + dilation: Union[List[int], int], + output_padding: Optional[Union[List[int], int]] = None, + ): + ret_shape = [] + if isinstance(stride, int): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, int): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, int): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[List[int]] = None + if output_padding: + if isinstance(output_padding, int): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + ) + ) + else: + ret_shape.append( + _formula( + dims[i], padding[i], dilation[i], kernel_size[i], stride[i] + ) + ) + return ret_shape + + def pick_memory_format(): + if input_tensor.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + elif input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + + shape_out = calc_conv_nd_return_shape( + dims, + kernel_size, + stride, + padding, + dilation, + output_padding, + ) + + else: + out_channels = weight.shape[0] + if weight.shape[1] != input_tensor.shape[1] / groups: + raise RuntimeError("Invalid channel dimensions") + shape_out = calc_conv_nd_return_shape( + dims, kernel_size, stride, padding, dilation + ) + out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) + mem_fmt = pick_memory_format() + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + return out + + +@register_meta(aten._adaptive_avg_pool2d.default) def meta_adaptive_avg_pool2d(self, output_size): check( self.ndim == 3 or self.ndim == 4, @@ -202,7 +351,7 @@ def meta_adaptive_avg_pool2d(self, output_size): return self.new_empty(self.shape[:-2] + tuple(output_size)) -@torch.library.impl(meta_lib, "_adaptive_avg_pool3d") +@register_meta(aten._adaptive_avg_pool3d.default) def meta_adaptive_avg_pool3d(self, output_size): check( self.ndim == 4 or self.ndim == 5, @@ -211,17 +360,45 @@ def meta_adaptive_avg_pool3d(self, output_size): return self.new_empty(self.shape[:-3] + tuple(output_size)) -@torch.library.impl(meta_lib, "repeat_interleave.Tensor") +@register_meta(aten.repeat_interleave.Tensor) def meta_repeat_interleave_Tensor(repeats, output_size=None): if output_size is None: raise RuntimeError("cannot repeat_interleave a meta tensor without output_size") return repeats.new_empty(output_size) +@torch.library.impl(meta_lib, "complex") +@torch.library.impl(meta_lib, "complex.out") +@out_wrapper() +def meta_complex(real, imag): + assert real.dtype.is_floating_point + assert imag.dtype.is_floating_point + out_shape = _broadcast_shapes(real.shape, imag.shape) + return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) + + +@torch.library.impl(meta_lib, "vdot") +def vdot(self, other): + if not self.is_complex: + return torch.dot(self, other) + + if self.is_conj(): + if other.is_conj(): + return torch.vdot(other.conj(), self.conj()) + else: + return torch.dot(self.conj(), other) + elif other.is_conj(): + return torch.dot(self, other.conj()).conj() + + dot_check(self, other) + return self.new_empty(()) + + # Leaving this function around because a python implementation # of indexing shape inference is useful, # but not registering it to the dispatcher because we already # get shape inference through structured kernels +@register_meta(aten.index.Tensor, register_dispatcher=False) def meta_index_Tensor(self, indices): check(indices, lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation @@ -323,34 +500,8 @@ def meta_index_Tensor(self, indices): return self.new_empty(before_shape + replacement_shape + after_shape) -@out_wrapper_multi("L", "info") -def meta_linalg_cholesky_ex(input, upper=False, check_errors=False): - check( - input.ndim >= 2, - lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor", - ) - check( - utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), - lambda: f"expected float or complex tensor, but got {input.dtype}", - ) - check( - input.size(-1) == input.size(-2), - lambda: f"expected square matrix but got {input.shape}", - ) - L = input.new_empty(input.size()) - L.transpose_(-2, -1) - info_sizes = input.size()[:-2] - info = input.new_empty(info_sizes, dtype=torch.int) - return L, info - - -torch.library.impl(meta_lib, "linalg_cholesky_ex")(meta_linalg_cholesky_ex) -torch.library.impl(meta_lib, "linalg_cholesky_ex.L")(meta_linalg_cholesky_ex) - - -@torch.library.impl(meta_lib, "addbmm") -@torch.library.impl(meta_lib, "addbmm.out") -@out_wrapper +@register_meta([aten.addbmm.default, aten.addbmm.out]) +@out_wrapper() def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(1) dim2 = batch2.size(2) @@ -504,9 +655,8 @@ def is_fast_path(src, scale, output, padding_idx): return output, offset2bag, bag_size, max_indices -@torch.library.impl(meta_lib, "diag") -@torch.library.impl(meta_lib, "diag.out") -@out_wrapper +@register_meta([aten.diag.default, aten.diag.out]) +@out_wrapper() def meta_diag(self, dim=0): check(self.dim() in (1, 2), lambda: "matrix or a vector expected") if self.dim() == 1: @@ -544,9 +694,8 @@ def _get_reduction_dtype(input, dtype, promote_int_to_long=True): return input.dtype -@torch.library.impl(meta_lib, "nansum") -@torch.library.impl(meta_lib, "nansum.out") -@out_wrapper +@register_meta([aten.nansum.default, aten.nansum.out]) +@out_wrapper() def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True) dims = utils.reduction_dims(input.shape, dims) @@ -554,7 +703,7 @@ def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): return input.new_empty(output_shape, dtype=output_dtype) -@torch.library.impl(meta_lib, "nanmedian") +@register_meta(aten.nanmedian.default) def meta_nanmedian(input): output_shape = utils.compute_reduction_output_shape( input.shape, tuple(range(input.dim())) @@ -562,27 +711,24 @@ def meta_nanmedian(input): return input.new_empty(output_shape) -@torch.library.impl(meta_lib, "nanmedian.dim_values") -@torch.library.impl(meta_lib, "nanmedian.dim") -@out_wrapper_multi("values", "indices") +@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values]) +@out_wrapper("values", "indices") def meta_nanmedian_dim(input, dim=-1, keepdim=False): dim = utils.reduction_dims(input.shape, (dim,)) output_shape = _compute_reduction_shape(input, dim, keepdim) - return input.new_empty(output_shape), input.new_empty( - output_shape, dtype=torch.long + return ( + input.new_empty(output_shape), + input.new_empty(output_shape, dtype=torch.long), ) -@torch.library.impl(meta_lib, "nan_to_num") -def meta_nan_to_num(self, nan=None, posinf=None, neginf=None): - return self.new_empty(self.shape) - - -@torch.library.impl(meta_lib, "remainder.Scalar_Tensor") -def meta_remainder_scalar(scalar, other): - return other % scalar - - @torch.library.impl(meta_lib, "logical_not_") def meta_logical_not_(self): return self + + +# We must also trigger meta registrations from PrimTorch ref +# decompositions +import torch._refs +import torch._refs.nn.functional +import torch._refs.special diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py index e20b08937ab5b..0c422295d897c 100644 --- a/torch/_namedtensor_internals.py +++ b/torch/_namedtensor_internals.py @@ -11,55 +11,64 @@ def check_serializing_named_tensor(tensor): if tensor.has_names(): raise RuntimeError( "NYI: Named tensors don't support serialization. Please drop " - "names via `tensor = tensor.rename(None)` before serialization.") + "names via `tensor = tensor.rename(None)` before serialization." + ) def build_dim_map(tensor): """Returns a map of { dim: dim_name } where dim is a name if the dim is named and the dim index otherwise.""" - return OrderedDict([(idx if name is None else name, name) - for idx, name in enumerate(tensor.names)]) + return OrderedDict( + [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)] + ) def unzip_namedshape(namedshape): if isinstance(namedshape, OrderedDict): namedshape = namedshape.items() - if not hasattr(namedshape, '__iter__') and not isinstance(namedshape, tuple): + if not hasattr(namedshape, "__iter__") and not isinstance(namedshape, tuple): raise RuntimeError( - 'Expected namedshape to be OrderedDict or iterable of tuples, got: {}' - .format(type(namedshape))) + "Expected namedshape to be OrderedDict or iterable of tuples, got: {}".format( + type(namedshape) + ) + ) if len(namedshape) == 0: - raise RuntimeError('Expected namedshape to non-empty.') + raise RuntimeError("Expected namedshape to non-empty.") return zip(*namedshape) def namer_api_name(inplace): if inplace: - return 'rename_' + return "rename_" else: - return 'rename' + return "rename" def is_ellipsis(item): - return item == Ellipsis or item == '...' + return item == Ellipsis or item == "..." + def single_ellipsis_index(names, fn_name): ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] if len(ellipsis_indices) >= 2: - raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names (' - '{}). This function supports up to one Ellipsis.' - .format(fn_name, names)) + raise RuntimeError( + "{}: More than one Ellipsis ('...') found in names (" + "{}). This function supports up to one Ellipsis.".format(fn_name, names) + ) if len(ellipsis_indices) == 1: return ellipsis_indices[0] return None + def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): - return names[numel_pre_glob:len(names) - numel_post_glob] + return names[numel_pre_glob : len(names) - numel_post_glob] def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names): - globbed_names = expand_single_ellipsis(ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names) - return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:] + globbed_names = expand_single_ellipsis( + ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names + ) + return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :] def resolve_ellipsis(names, tensor_names, fn_name): @@ -78,7 +87,8 @@ def update_names_with_list(tensor, names, inplace): return tensor._update_names(None, inplace) return tensor._update_names( - resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace) + resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace + ) def update_names_with_mapping(tensor, rename_map, inplace): @@ -88,10 +98,17 @@ def update_names_with_mapping(tensor, rename_map, inplace): if old_dim in dim_map.keys(): dim_map[old_dim] = new_dim else: - raise RuntimeError(('{api_name}: Tried to rename dim \'{old_dim}\' to dim ' - '{new_dim} in Tensor[{dims}] but dim \'{old_dim}\' does not exist') - .format(old_dim=old_dim, new_dim=new_dim, dims=tensor.names, - api_name=namer_api_name(inplace))) + raise RuntimeError( + ( + "{api_name}: Tried to rename dim '{old_dim}' to dim " + "{new_dim} in Tensor[{dims}] but dim '{old_dim}' does not exist" + ).format( + old_dim=old_dim, + new_dim=new_dim, + dims=tensor.names, + api_name=namer_api_name(inplace), + ) + ) return tensor._update_names(tuple(dim_map.values()), inplace) @@ -128,10 +145,12 @@ def update_names(tensor, names, rename_map, inplace): has_names = len(names) > 0 has_rename_pairs = bool(rename_map) if has_names and has_rename_pairs: - raise RuntimeError('{api_name}: This function takes either positional ' - 'args or keyword args, but not both. Use tensor.{api_name}(*names) ' - 'to name dims and tensor.{api_name}(**rename_map) to rename ' - 'dims.'.format(api_name=namer_api_name(inplace))) + raise RuntimeError( + "{api_name}: This function takes either positional " + "args or keyword args, but not both. Use tensor.{api_name}(*names) " + "to name dims and tensor.{api_name}(**rename_map) to rename " + "dims.".format(api_name=namer_api_name(inplace)) + ) # Special case for tensor.rename(*[]), which is valid for a 0 dim tensor. if not has_names and not has_rename_pairs: diff --git a/torch/_ops.py b/torch/_ops.py index a325d93e51782..b1a43734b440f 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1,15 +1,15 @@ -import torch._C - import contextlib import ctypes import sys import types +import torch._C + import torch.jit -import torch._utils_internal as torch_utils_internal +from torch import _utils_internal # Query `hasattr` only once. -_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags') +_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") @contextlib.contextmanager @@ -25,6 +25,7 @@ def dl_open_guard(): if _SET_GLOBAL_FLAGS: sys.setdlopenflags(old_flags) + # Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. # You can obtain an OpOverload object through attribute query on OpOverloadPacket. class OpOverload: @@ -33,8 +34,12 @@ def __init__(self, overloadpacket, op, schema, tags): self._schema = schema self._overloadpacket = overloadpacket self._tags = tags - self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name - self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname) + self._overloadname = ( + "default" if schema.overload_name == "" else schema.overload_name + ) + self.__name__ = "{}.{}".format( + self._schema.name.split("::")[1], self._overloadname + ) self.__module__ = overloadpacket.__module__ op.__module__ = overloadpacket.__module__ @@ -43,7 +48,9 @@ def __deepcopy__(self, memo=None): return self def __repr__(self): - return "".format(*self._schema.name.split("::"), self._overloadname) + return "".format( + *self._schema.name.split("::"), self._overloadname + ) def __call__(self, *args, **kwargs): return self._op(*args, **kwargs or {}) @@ -72,6 +79,7 @@ def tags(self): # TODO: add more methods to expose information about input and output arguments + # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator # You can obtain an OpOverload object through attribute query. class OpOverloadPacket: @@ -88,7 +96,9 @@ def __deepcopy__(self, memo=None): return self def __repr__(self): - return "".format(*self._qualified_op_name.split("::")) + return "".format( + *self._qualified_op_name.split("::") + ) def __hash__(self): return hash(self._op) @@ -102,8 +112,8 @@ def op(self): def __getattr__(self, key): # It is not a valid op_name when __file__ is passed in - if key == '__file__': - return 'torch.ops' + if key == "__file__": + return "torch.ops" # ensure that query for dunder attributes that does not exist on # opoverloadpacket but instead exists on the self._op object does not unnecessarily call @@ -113,23 +123,27 @@ def __getattr__(self, key): # opoverloadpacket. # This is ok since we are guaranteed that an overload name for an aten op can't start with '__' try: - if key.startswith('__'): + if key.startswith("__"): return getattr(self._op, key) except AttributeError: # for consistency because it seems weird to # throw an attribute error with a message containing # an object name different from the one the attribute # query was performed on. - raise AttributeError("'{}' can't have an overload name beginning with '__' and the " - "underlying op {} has no attribute {} either." - .format(str(self), str(self._op), key)) from None + raise AttributeError( + "'{}' can't have an overload name beginning with '__' and the " + "underlying op {} has no attribute {} either.".format( + str(self), str(self._op), key + ) + ) from None try: # This is ok since we are guaranteed that an overload name for an aten op can't be 'default' - use_key = '' if key == 'default' else key + use_key = "" if key == "default" else key # TODO: disallow access to overloads registered by JIT op_, tags = torch._C._get_operation_overload( - self._qualified_op_name, use_key) + self._qualified_op_name, use_key + ) schema = torch._C._get_schema(self._qualified_op_name, use_key) overload = OpOverload(self, op_, schema, tags) # cache the overload object @@ -137,7 +151,9 @@ def __getattr__(self, key): return overload except RuntimeError: raise AttributeError( - "The underlying op of '{}' has no overload name '{}'".format(str(self), key) + "The underlying op of '{}' has no overload name '{}'".format( + str(self), key + ) ) from None def __call__(self, *args, **kwargs): @@ -151,6 +167,7 @@ def __call__(self, *args, **kwargs): def overloads(self): return [n if n else "default" for n in self._overload_names] + # Resolution of torch.fn is different from torch.ops.aten.fn # torch.fn uses the Python argparser, matches with the # appropriate schema, and calls into the unboxed version of the method @@ -186,31 +203,36 @@ class _OpNamespace(types.ModuleType): and subsequent accesses will incur no further lookup (the namespace and operation will already exist). """ + def __init__(self, name): - super(_OpNamespace, self).__init__('torch.ops.' + name) + super(_OpNamespace, self).__init__("torch.ops." + name) self.name = name def __getattr__(self, op_name): # It is not a valid op_name when __file__ is passed in - if op_name == '__file__': - return 'torch.ops' + if op_name == "__file__": + return "torch.ops" # Get the op `my_namespace::my_op` if available. This will also check # for overloads and raise an exception if there are more than one. namespace_name = self.name - qualified_op_name = '{}::{}'.format(namespace_name, op_name) + qualified_op_name = "{}::{}".format(namespace_name, op_name) try: op, overload_names = torch._C._jit_get_operation(qualified_op_name) except RuntimeError as e: # Turn this into AttributeError so getattr(obj, key, default) # works (this is called by TorchScript with __origin__) - raise AttributeError(f"'_OpNamespace' object has no attribute '{op_name}'") from e + raise AttributeError( + f"'_OpNamespace' object has no attribute '{op_name}'" + ) from e # let the script frontend know that op is identical to the builtin op # with qualified_op_name torch.jit._builtins._register_builtin(op, qualified_op_name) op.__module__ = self.__module__ + "." + namespace_name - opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op, overload_names) + opoverloadpacket = OpOverloadPacket( + qualified_op_name, op_name, op, overload_names + ) opoverloadpacket.__module__ = self.__module__ + "." + namespace_name # cache the opoverloadpacket to ensure that each op corresponds to # a unique OpOverloadPacket object @@ -219,10 +241,10 @@ def __getattr__(self, op_name): class _Ops(types.ModuleType): - __file__ = '_ops.py' + __file__ = "_ops.py" def __init__(self): - super(_Ops, self).__init__('torch.ops') + super(_Ops, self).__init__("torch.ops") self.loaded_libraries = set() def __getattr__(self, name): @@ -252,7 +274,7 @@ def load_library(self, path): if sys.executable == "torch_deploy": return - path = torch_utils_internal.resolve_library_path(path) + path = _utils_internal.resolve_library_path(path) with dl_open_guard(): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index f34542c8c0955..5f9507804ebae 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,33 +1,38 @@ +import contextlib +import itertools +import math +import operator +import weakref +from enum import Enum +from functools import partial, reduce +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + import torch -from torch import Tensor, _TypedStorage -import torch._prims.utils as utils -from torch._prims.utils import ( - TensorLike, - TensorLikeType, - ShapeType, - getnvFuserDtype, - DimsType, +import torch._prims_common as utils +import torch.library +from torch import Tensor, TypedStorage +from torch._C import _get_default_device +from torch._prims_common import ( + check, DimsSequenceType, - StrideType, + DimsType, + getnvFuserDtype, Number, NumberType, - TensorMeta, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + type_to_dtype, ) -from torch.overrides import has_torch_function, handle_torch_function -import torch.library -from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten -from torch._subclasses.fake_tensor import FakeTensor - -import contextlib -from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type -from functools import reduce, partial -from enum import Enum -import operator -import math +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.overrides import handle_torch_function, has_torch_function +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten prim = torch.library.Library("prims", "DEF") prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd") +prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect") prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd") prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta") @@ -45,7 +50,9 @@ "acos", "acosh", "asin", + "asinh", "atan", + "atanh", "cos", "cosh", "bessel_i0", @@ -55,6 +62,7 @@ "bitwise_not", "cbrt", "ceil", + "conj_physical", "digamma", "erf", "erf_inv", @@ -64,14 +72,15 @@ "exp2", "fill", "floor", + "imag", "isfinite", - "is_infinite", "lgamma", "log", "log1p", "log2", "log10", "neg", + "real", "reciprocal", "round", "sign", @@ -81,6 +90,7 @@ "sqrt", "tan", "tanh", + "trunc", # # Elementwise binary prims # @@ -98,6 +108,7 @@ "gcd", "ge", "gt", + "hypot", "igamma", "igammac", "le", @@ -121,6 +132,7 @@ "as_strided", "broadcast_in_dim", "collapse_view", + "conj", "expand_dims", "slice", "slice_in_dim", # implemented using slice -- make this a ref? @@ -142,7 +154,6 @@ # # Data conversion and movement prims # - "clone", "convert_element_type", "device_put", "item", @@ -167,12 +178,104 @@ # Tensor Creation Prims # "empty_strided", + "scalar_tensor", + "arange", + # + # Linear algebra (linalg) Prims + # + "svd", # # Randomness Prims # "uniform", + # + # FFT prims + # + "fft_r2c", + "fft_c2c", + "fft_c2r", ] + +# In order to keep things like aliasing relationships and storage +# consistent wrt/meta tensors, FakeTensors own a FakeTensorMode +# which caches conversions to Meta Tensors. We would like to use +# one consistent mode among along FakeTensors, which we store here. +# We store a weakref, so that when all previous FakeTensors are +# the present mode will also deallocate. FakeTensorMode holds onto +# tensors that are converted to Meta so we don't want to persist it +# longer than necessary.x +prim_fake_mode_ref = None + + +def get_prim_fake_mode(): + global prim_fake_mode_ref + if prim_fake_mode_ref is None or prim_fake_mode_ref() is None: + mode = FakeTensorMode() + prim_fake_mode_ref = weakref.ref(mode) + return mode + else: + return prim_fake_mode_ref() + + +def TensorMeta( + tensorlike: Optional[Union[NumberType, torch.Tensor]] = None, + *, + shape: Optional[ShapeType] = None, + strides: Optional[StrideType] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, +): + if isinstance(tensorlike, Number): + assert not shape and (shape is None or isinstance(shape, Sequence)) + assert not strides and (strides is None or isinstance(strides, Sequence)) + inferred_shape: Tuple[int, ...] = () + inferred_strides: Tuple[int, ...] = () + inferred_dtype = type_to_dtype(type(tensorlike)) + inferred_device = torch.device("cpu") + # TODO: This looks wrong, a number that is wrapped into a tensor + # needs to behave differently than a scalar tensor for type + # promotion purposes + elif tensorlike is not None: + assert isinstance(tensorlike, torch.Tensor) + inferred_shape = tuple(tensorlike.shape) + inferred_strides = tuple(tensorlike.stride()) + inferred_dtype = tensorlike.dtype + inferred_device = tensorlike.device + else: + # If no tensorlike "example" is given then all metadata + # must be provided explicitly + assert shape is not None + assert strides is not None + assert dtype is not None + assert device is not None + + shape = inferred_shape if shape is None else tuple(shape) + strides = inferred_strides if strides is None else tuple(strides) + dtype = inferred_dtype if dtype is None else dtype + device = inferred_device if device is None else device + + if isinstance(device, str): + device = torch.device(device) + + if isinstance(tensorlike, FakeTensor): + mode = tensorlike.fake_mode + else: + mode = get_prim_fake_mode() + + if device.type == "meta": + return torch.empty_strided(shape, strides, dtype=dtype, device="meta") + else: + # SymInt doesnt support empty_strided yet + if any( + isinstance(inp, torch.SymIntNode) for inp in itertools.chain(shape, strides) + ): + meta_t = torch.empty(shape, dtype=dtype, device="meta") + else: + meta_t = torch.empty_strided(shape, strides, dtype=dtype, device="meta") + return FakeTensor(mode, meta_t, device) + + # # Common datastructures and helpers # @@ -182,6 +285,7 @@ "acos", "asin", "atan", + "atanh", "cos", "cosh", "bitwise_not", @@ -191,6 +295,7 @@ "exp", "expm1", "floor", + "imag", "isfinite", "lgamma", "log", @@ -199,6 +304,7 @@ "log10", "reciprocal", "neg", + "real", "round", "rsqrt", "sin", @@ -213,7 +319,7 @@ def _assert_nvfuser_op_exists(fname: str): try: from torch._C._nvfuser import FusionDefinition as fd # type: ignore[import] - assert getattr(fd.Ops, fname) + assert getattr(fd.Operators, fname) except ImportError: # Not all PyTorch builds have nvfuser pass @@ -226,7 +332,7 @@ def _assert_nvfuser_op_exists(fname: str): _assert_nvfuser_op_exists("{fname}") def _{fname}_nvfuser(fd: Any, a: TensorLikeType): - return fd.Ops.{fname}(a) # type: ignore[attr-defined] + return fd.ops.{fname}(a) # type: ignore[attr-defined] """ ) @@ -256,7 +362,7 @@ def _{fname}_nvfuser(fd: Any, a: TensorLikeType): _assert_nvfuser_op_exists("{fname}") def _{fname}_nvfuser(fd: Any, a: TensorLikeType, b: TensorLikeType): - return fd.Ops.{fname}(a, b) # type: ignore[attr-defined] + return fd.ops.{fname}(a, b) # type: ignore[attr-defined] """ ) @@ -271,7 +377,7 @@ def _{fname}_nvfuser(fd: Any, a: TensorLikeType, b: TensorLikeType): _assert_nvfuser_op_exists("{fname}") def _{fname}_nvfuser(fd: Any, a: TensorLikeType, b: TensorLikeType, c: TensorLikeType): - return fd.Ops.{fname}(a, b, c) # type: ignore[attr-defined] + return fd.ops.{fname}(a, b, c) # type: ignore[attr-defined] """ ) @@ -295,7 +401,7 @@ def wrap(t): and not isinstance(t, FakeTensor) and not t.device.type == "meta" ): - return FakeTensor.from_tensor(t, utils.get_prim_fake_mode()) + return FakeTensor.from_tensor(t, get_prim_fake_mode()) else: return t @@ -310,7 +416,7 @@ def wrapper(*args, **kwargs): def _make_prim( *, schema: str, - return_type: RETURN_TYPE, + return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]], meta: Callable, impl_aten: Callable, impl_nvfuser: Optional[Callable] = None, @@ -352,14 +458,27 @@ def _autograd_impl(*args, **kwargs): flat_args, args_spec = tree_flatten((args, kwargs)) return BackwardsNotSupported.apply(args_spec, *flat_args) + _meta_impl = _wrap_tensor_meta(meta) + + def _backend_select_impl(*args, **kwargs): + if kwargs.get("device") and kwargs["device"].type == "meta": + return _meta_impl(*args, **kwargs) + else: + return _prim_impl(*args, **kwargs) + name = schema.split("(")[0] prim_impl.impl(name, _prim_impl) prim_autograd_impl.impl(name, _autograd_impl) - prim_meta_impl.impl(name, _wrap_tensor_meta(meta)) + prim_meta_impl.impl(name, _meta_impl) _prim_packet = getattr(torch.ops.prims, name) _prim = _prim_packet.default + from torch._subclasses.fake_tensor import contains_tensor_types + + if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments): + prim_backend_select_impl.impl(name, _backend_select_impl) + for p in (_prim_packet, _prim): p.__doc__ = doc p.impl_nvfuser = impl_nvfuser # type: ignore[attr-defined] @@ -414,6 +533,9 @@ def _elementwise_meta( elif isinstance(arg, Number): scalar_type = type(arg) + if dtype is None and scalar_type is not None: + dtype = utils.type_to_dtype(scalar_type) + # Acquires the device (if it exists) or number device = None number = None @@ -449,6 +571,13 @@ def _elementwise_meta( return TensorMeta(number) +def _complex_only_elementwise_meta(*args, **kwargs): + utils.check( + utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported" + ) + return _elementwise_meta(*args, **kwargs) + + def _make_elementwise_unary_prim( name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs ): @@ -519,6 +648,13 @@ def _not_impl(*args, **kwargs): type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) +asinh = _make_elementwise_unary_prim( + "asinh", + impl_aten=torch.asinh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + atan = _make_elementwise_unary_prim( "atan", impl_aten=torch.atan, @@ -527,6 +663,14 @@ def _not_impl(*args, **kwargs): type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) +atanh = _make_elementwise_unary_prim( + "atanh", + impl_aten=torch.atanh, + impl_nvfuser=_atanh_nvfuser, # type: ignore[name-defined] + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + cos = _make_elementwise_unary_prim( "cos", impl_aten=torch.cos, @@ -580,8 +724,18 @@ def _not_impl(*args, **kwargs): ) -def _cbrt_aten(a: torch.Tensor): - return pow(a, (1 / 3)) +def _cbrt_aten(a: torch.Tensor) -> Tensor: + utils.check( + not a.is_complex(), + lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)", + ) + # Returns the real cubic root of the number. + # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number + # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i} + # which is a complex number. + # For more info see the section Note in + # https://en.cppreference.com/w/cpp/numeric/math/cbrt + return torch.copysign(torch.pow(a.abs(), 1 / 3), a) cbrt = _make_elementwise_unary_prim( @@ -599,6 +753,23 @@ def _cbrt_aten(a: torch.Tensor): type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) + +def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType: + if not input.dtype.is_complex: + raise RuntimeError("prims.conj_physical is only defined for complex dtypes") + + strides = utils.compute_elementwise_output_strides(input) + return TensorMeta(input, strides=strides) + + +conj_physical = _make_prim( + schema="conj_physical(Tensor self) -> Tensor", + meta=_conj_physical_meta, + impl_aten=torch._conj_physical, + doc="Returns the physical conjugation of a complex tensor", + return_type=RETURN_TYPE.NEW, +) + digamma = _make_elementwise_unary_prim( "digamma", impl_aten=torch.digamma, @@ -684,6 +855,18 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor: type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) +imag = _make_prim( + schema="imag(Tensor self) -> Tensor", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.imag, + impl_nvfuser=_imag_nvfuser, # type: ignore[name-defined] + doc="", +) + isfinite = _make_elementwise_unary_prim( "isfinite", impl_aten=torch.isfinite, @@ -692,13 +875,6 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor: type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, ) -is_infinite = _make_elementwise_unary_prim( - "is_infinite", - impl_aten=torch.isinf, - doc="", - type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) - lgamma = _make_elementwise_unary_prim( "lgamma", impl_aten=torch.lgamma, @@ -739,6 +915,18 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor: type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) +real = _make_prim( + schema="real(Tensor self) -> Tensor", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.real, + impl_nvfuser=_real_nvfuser, # type: ignore[name-defined] + doc="", +) + reciprocal = _make_elementwise_unary_prim( "reciprocal", impl_aten=torch.reciprocal, @@ -825,6 +1013,19 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor: type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, ) + +def _trunc_nvfuser(fd: Any, a: TensorLikeType): + return fd.ops.trunc(a) # type: ignore[attr-defined] + + +trunc = _make_elementwise_unary_prim( + "trunc", + impl_aten=torch.trunc, + impl_nvfuser=_trunc_nvfuser, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + # # Elementwise binary operations # @@ -946,6 +1147,13 @@ def _div_aten(a, b): type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, ) +hypot = _make_elementwise_binary_prim( + "hypot", + impl_aten=torch.hypot, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + igamma = _make_elementwise_binary_prim( "igamma", impl_aten=torch.special.gammainc, @@ -977,25 +1185,14 @@ def _div_aten(a, b): ) -def _wrap_scalar(a: NumberType, *, dtype: torch.dtype = None) -> torch.Tensor: - """ - Wraps a Number into a Tensor of corresponding dtype. - - Note: this should not generally be used, but some torch functions don't - accept scalars, so it's necessary for their prims to do so. - """ - dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) - return torch.tensor(a, dtype=dtype) - - # Note: the following impls are because torch.maximum and torch.mininum do not support scalar inputs def _maximum_aten( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ) -> TensorLikeType: if isinstance(a, TensorLike) and isinstance(b, Number): - b = _wrap_scalar(b, dtype=a.dtype) + b = scalar_tensor(b, dtype=a.dtype, device=a.device) elif isinstance(b, TensorLike) and isinstance(a, Number): - a = _wrap_scalar(a, dtype=b.dtype) + a = scalar_tensor(a, dtype=b.dtype, device=b.device) return torch.maximum(a, b) # type: ignore[arg-type] @@ -1012,9 +1209,9 @@ def _minimum_aten( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ) -> TensorLikeType: if isinstance(a, TensorLike) and isinstance(b, Number): - b = _wrap_scalar(b, dtype=a.dtype) + b = scalar_tensor(b, dtype=a.dtype, device=a.device) elif isinstance(b, TensorLike) and isinstance(a, Number): - a = _wrap_scalar(a, dtype=b.dtype) + a = scalar_tensor(a, dtype=b.dtype, device=b.device) return torch.minimum(a, b) # type: ignore[arg-type] @@ -1059,7 +1256,7 @@ def _minimum_aten( def _remainder_nvfuser(fd: Any, a: TensorLikeType, b: TensorLikeType): - return fd.Ops.remainder(a, b) # type: ignore[attr-defined] + return fd.ops.remainder(a, b) # type: ignore[attr-defined] remainder = _make_elementwise_binary_prim( @@ -1211,7 +1408,7 @@ def _broadcast_in_dim_nvfuser( shape: ShapeType, broadcast_dimensions: ShapeType, ): - return fd.Ops.broadcast_in_dim(a, shape, broadcast_dimensions) # type: ignore[attr-defined] + return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions) # type: ignore[attr-defined] _broadcast_in_dim_doc = """ @@ -1226,7 +1423,7 @@ def _broadcast_in_dim_nvfuser( """ broadcast_in_dim = _make_prim( - schema="broadcast_in_dim(Tensor(a) a, int[] shape, int[] broadcast_dimensions) -> Tensor(a)", + schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)", meta=_broadcast_in_dim_meta, impl_aten=_broadcast_in_dim_aten, impl_nvfuser=_broadcast_in_dim_nvfuser, @@ -1345,6 +1542,25 @@ def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor: ) +def _conj_meta(a: TensorLikeType) -> TensorLikeType: + if not a.dtype.is_complex: + raise RuntimeError("Expected complex dtype in prims.conj") + return TensorMeta(a) + + +_conj_doc = """ +Returns a conjugated view of the original tensor +""" + +conj = _make_prim( + schema="conj(Tensor(a) a) -> Tensor(a)", + meta=_conj_meta, + impl_aten=torch.conj, + return_type=RETURN_TYPE.VIEW, + doc=_conj_doc, +) + + def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType: """ Creates a view of a with a.ndim + len(dimensions) dimensions, with new @@ -1644,10 +1860,8 @@ def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType: def _squeeze_aten(a: Tensor, dimensions: Sequence) -> Tensor: - squeezes = 0 - for idx in dimensions: - a = torch.squeeze(a, dim=(idx - squeezes)) - squeezes = squeezes + 1 + for idx in reversed(sorted(dimensions)): + a = torch.squeeze(a, dim=idx) return a @@ -1872,32 +2086,6 @@ def _where_meta( # # Type conversions # -# TODO: model memory format on TensorMeta -# TODO: make clone a reference following its implementation in TensorFactories.cpp -def _clone_meta( - a: TensorLikeType, *, memory_format: torch.memory_format -) -> TensorLikeType: - strides = utils.compute_elementwise_output_strides(a) - return TensorMeta(a, strides=strides) - - -def _clone_aten(a: Tensor, *, memory_format: torch.memory_format) -> Tensor: - return torch.clone(a, memory_format=memory_format) - - -_clone_doc = """ - Creates a copy of a tensors. -""" - -clone = _make_prim( - schema="clone(Tensor a, *, MemoryFormat memory_format) -> Tensor", - meta=_clone_meta, - impl_aten=_clone_aten, - return_type=RETURN_TYPE.NEW, - doc=_clone_doc, -) - - def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: # Type checks assert isinstance(a, TensorLike) @@ -1929,7 +2117,7 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor: nvfuser_dtype = getnvFuserDtype(dtype) - return fd.Ops.cast(nvfuser_dtype, a) # type: ignore[attr-defined] + return fd.ops.cast(a, nvfuser_dtype) # type: ignore[attr-defined] _convert_element_type_doc = """ @@ -1952,7 +2140,7 @@ def _device_put_meta( assert isinstance(a, TensorLike) assert isinstance(device, (str, torch.device)) - return TensorMeta(a, device=utils.wrap_device(device)) + return TensorMeta(a, device=utils.canonicalize_device(device)) def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor: @@ -2123,7 +2311,7 @@ def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor: def _resize_meta(a: TensorLikeType, shape: ShapeType): - return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape)) + return a.resize_(shape) def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor: @@ -2223,7 +2411,7 @@ def _sum_nvfuser( ): keep_dims = False output_dtype = torch._C._nvfuser.DataType.Null - return fd.Ops.sum(a, dims, keep_dims, output_dtype) + return fd.ops.sum(a, dims, keep_dims, output_dtype) sum = _make_reduction_prim( @@ -2264,7 +2452,25 @@ def _var_nvfuser( correction: int, ): keep_dims = False - return fd.Ops.var(a, dims, correction, keep_dims) + return fd.ops.var(a, dims, correction, keep_dims) + + +def _amax_nvfuser( + fd: Any, + a: TensorLikeType, + dims: DimsSequenceType, +): + keep_dims = False + return fd.ops.max(a, dims, keep_dims) + + +def _amin_nvfuser( + fd: Any, + a: TensorLikeType, + dims: DimsSequenceType, +): + keep_dims = False + return fd.ops.min(a, dims, keep_dims) var = _make_var_reduction_prim( @@ -2277,15 +2483,97 @@ def _var_nvfuser( amax = _make_reduction_prim( name="amax", impl_aten=torch.amax, + impl_nvfuser=_amax_nvfuser, doc=_amax_doc, ) amin = _make_reduction_prim( name="amin", impl_aten=torch.amin, + impl_nvfuser=_amin_nvfuser, doc=_amin_doc, ) + +_arange_doc = """ + Constructs a 1-D tensor with values from the interval [start, end) taken + with common difference `step` beginning from `start`. +""" + + +# TODO: layout, pin_memory, memory_format +# TODO: model requires_grad on TensorMeta +def _arange_meta( + start: NumberType, + end: NumberType, + step: NumberType, + *, + dtype: Optional[torch.dtype], + device: Optional[torch.device], + requires_grad: bool, +) -> TensorLikeType: + assert not ( + isinstance(start, complex) + and isinstance(end, complex) + and isinstance(step, complex) + ) + utils.check( + step != 0, + lambda: "step must be nonzero", + ) + utils.check( + math.isfinite(start) and math.isfinite(end), + lambda: f"unsupported range: {start} -> {end}", + ) + utils.check( + (step > 0 and end >= start) or (step < 0 and end <= start), + lambda: "upper bound and lower bound inconsistent with step sign", + ) + if dtype is not None: + pass + elif all(isinstance(arg, int) for arg in (start, end, step)): + dtype = torch.int64 + else: + dtype = torch.get_default_dtype() + device = _get_default_device() if device is None else device + shape = (math.ceil((end - start) / step),) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _arange_aten( + start: NumberType, + end: NumberType, + step: NumberType, + *, + dtype: Optional[torch.dtype], + device: Optional[torch.device], + requires_grad: bool, +) -> TensorLikeType: + # mypy: Not all union combinations were tried because there are too many unions + return torch.arange( # type: ignore[call-overload, misc] + start, + end, + step, + dtype=dtype, + device=device, + layout=torch.strided, + pin_memory=False, + requires_grad=requires_grad, + ) + + +# TODO: maybe prims should not have requires_grad arg +# see: https://github.com/pytorch/pytorch/pull/77542/files#r873943255 +arange = _make_prim( + schema="arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype, Device? device, bool requires_grad) -> Tensor", # noqa: B950 + return_type=RETURN_TYPE.NEW, + meta=_arange_meta, + impl_aten=_arange_aten, + doc=_arange_doc, +) + + # TODO: layout, pin_memory, memory_format # TODO: model requires_grad on TensorMeta def _empty_meta( @@ -2387,7 +2675,7 @@ def _full_like_meta( device: torch.device, requires_grad: bool, ) -> TensorLikeType: - strides = strides = utils.compute_elementwise_output_strides(a) + strides = utils.compute_elementwise_output_strides(a) if a.numel() == 0: strides = a.stride() @@ -2422,6 +2710,105 @@ def _full_like_aten( doc=_full_like_doc, ) + +def _scalar_tensor_meta( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> TensorLikeType: + shape: ShapeType = [] + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device) + + +def _scalar_tensor_aten( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> Tensor: + if isinstance(scalar, complex) and ( + dtype is None or not utils.is_complex_dtype(dtype) + ): + raise TypeError("Complex scalar requires complex tensor dtype.") + # Note that Mypy thinks torch.scalar can't accept a complex scalar + return torch.scalar_tensor(scalar, dtype=dtype, device=device) # type: ignore[arg-type] + + +_scalar_tensor_doc = """ + Wraps a Number into a Tensor with the specified dtype and device. +""" + +# TODO: add layout and pin_memory support +scalar_tensor = _make_prim( + schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor", + meta=_scalar_tensor_meta, + impl_aten=_scalar_tensor_aten, + return_type=RETURN_TYPE.NEW, + doc=_scalar_tensor_doc, +) + + +# +# Linear algebra (linalg) prims +# + + +def _svd_meta( + A: TensorLikeType, *, full_matrices: bool +) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]: + utils.check_is_matrix(A, "linalg.svd") + utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False) + + A_shape = A.shape + batch = A_shape[:-2] + m, n = A_shape[-2:] + k = min(m, n) + + shape_U = batch + (m, m if full_matrices else k) + strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False) + U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device) + + shape_S = batch + (k,) + strides_S = utils.make_contiguous_strides_for(shape_S) + S = TensorMeta( + shape=shape_S, + strides=strides_S, + dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype, + device=A.device, + ) + + shape_Vh = batch + (n if full_matrices else k, n) + # The CPU backend returns V, but the cuSolver backend returns V^H + # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend + is_cuda = A.device.type == "cuda" + strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda) + Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device) + return U, S, Vh + + +def _svd_aten( + A: TensorLikeType, *, full_matrices: bool +) -> Tuple[Tensor, Tensor, Tensor]: + return torch.linalg.svd(A, full_matrices=full_matrices) + + +_svd_doc = """ + Returns the SVD of a matrix or batch of matrices. + + The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned. +""" + +svd = _make_prim( + schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)", + meta=_svd_meta, + impl_aten=_svd_aten, + return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW), + doc=_svd_doc, +) + + # # Randomness Prims # @@ -2466,3 +2853,126 @@ def _uniform_aten( impl_aten=_uniform_aten, doc=_uniform_doc, ) + + +def _fft_r2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + if onesided: + last_dim = dim[-1] + shape[last_dim] = shape[last_dim] // 2 + 1 + + dtype = utils.corresponding_complex_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_r2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_r2c(input, dim, normalization, onesided) + + +_fft_r2c_doc = """ + Performs a real to complex Fast Fourier Transform +""" + + +fft_r2c = _make_prim( + schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor", + meta=_fft_r2c_meta, + impl_aten=_fft_r2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_r2c_doc, +) + + +def _fft_c2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = input.shape + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta( + shape=shape, strides=strides, dtype=input.dtype, device=input.device + ) + + +def _fft_c2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2c(input, dim, normalization, forward) + + +_fft_c2c_doc = """ + Performs either a Fast Fourier Transform, or its inverse +""" + + +fft_c2c = _make_prim( + schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor", + meta=_fft_c2c_meta, + impl_aten=_fft_c2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2c_doc, +) + + +def _fft_c2r_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + shape[dim[-1]] = last_dim_size + dtype = utils.corresponding_real_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_c2r_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2r(input, dim, normalization, last_dim_size) + + +_fft_c2r_doc = """ + Performs a complex to real Inverse Fast Fourier Transform +""" + + +fft_c2r = _make_prim( + schema="fft_c2r(Tensor self, *, int[] dim, int last_dim_size) -> Tensor", + meta=_fft_c2r_meta, + impl_aten=_fft_c2r_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2r_doc, +) diff --git a/torch/_prims/context.py b/torch/_prims/context.py index c17b44efce890..c6ef474eec2d3 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,18 +1,18 @@ -from typing import Callable, Sequence, Any, Dict import functools - +from typing import Any, Callable, Dict, Sequence import torch -import torch.overrides -from torch._prims.utils import torch_function_passthrough +import torch._prims import torch._refs import torch._refs.nn import torch._refs.nn.functional import torch._refs.special +import torch.overrides -import torch._prims +from torch._prims_common import torch_function_passthrough +from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule @functools.lru_cache(None) @@ -26,6 +26,8 @@ def torch_to_refs_map(): (torch.nn, torch._refs.nn), (torch.nn.functional, torch._refs.nn.functional), (torch.special, torch._refs.special), + (torch.fft, torch._refs.fft), + (torch.linalg, torch._refs.linalg), ] r: Dict[Any, Any] = { torch.Tensor.__invert__: torch._refs.bitwise_not, @@ -33,6 +35,15 @@ def torch_to_refs_map(): torch.Tensor.__and__: torch._refs.bitwise_and, torch.Tensor.__or__: torch._refs.bitwise_or, torch.Tensor.__eq__: torch._refs.eq, + torch.Tensor.new_empty: torch._refs.new_empty, + torch.Tensor.new_full: torch._refs.new_full, + torch.Tensor.new_zeros: torch._refs.new_zeros, + torch.Tensor.new_ones: torch._refs.new_ones, + torch.Tensor.fill_: torch._refs.fill_, + torch.Tensor.zero_: torch._refs.zero_, + # TODO: Should these methods be mapped some other way? + torch.Tensor.copy_: torch._prims.copy_to, + torch.Tensor.resize: torch._prims.resize, } for mod_torch, mod_refs in modules: for s in mod_refs.__all__: # type: ignore[attr-defined] @@ -58,15 +69,18 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode): Switches the interpretation of torch.* functions and Tensor methods to use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.) - >>> with TorchRefsMode.push(): + >>> with TorchRefsMode(): ... torch.add(x, y) # calls torch._refs.add(x, y) By default, this context manager will fall back on the torch.* if the ref does not exist; set strict=True to error if this occurs. + If the ref exists we still would like to fall back on the torch.* sometimes, + this behavior can be customized by passing a function to should_fallback_fn. """ - def __init__(self, strict=False): + def __init__(self, strict=False, should_fallback_fn=lambda *_: False): self.strict = strict + self.should_fallback_fn = should_fallback_fn def __torch_function__( self, @@ -83,9 +97,39 @@ def __torch_function__( mapping = torch_to_refs_map() func = mapping.get(orig_func, None) if func is not None: - return func(*args, **kwargs) + # If the ref exists query whether we should use it or not + if self.should_fallback_fn(self, func, args, kwargs): + return orig_func(*args, **kwargs) + # torch calls inside func should be interpreted as refs calls + with torch.overrides.enable_torch_function_mode(self, replace=self.inner): + return func(*args, **kwargs) if self.strict: raise RuntimeError( f"no _refs support for {torch.overrides.resolve_name(orig_func)}" ) return orig_func(*args, **kwargs) + + +def _is_node_supported_nvfuser(node): + return ( + node.op == "call_function" + and getattr(node.target, "impl_nvfuser", None) is not None + ) + + +def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs): + with torch.overrides.enable_torch_function_mode( + torch_function_mode, replace=torch_function_mode.inner + ): + gm = get_isolated_graphmodule(func, args, kwargs) + + call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes) + any_unsupported = any( + not _is_node_supported_nvfuser(node) for node in call_function_nodes + ) + return any_unsupported + + +TorchRefsNvfuserCapabilityMode = functools.partial( + TorchRefsMode, should_fallback_fn=_is_func_unsupported_nvfuser +) diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py index 3baccd649f31c..ca8d07d2f2031 100644 --- a/torch/_prims/executor.py +++ b/torch/_prims/executor.py @@ -1,16 +1,10 @@ from typing import Callable -import torch +from torch._prims.context import TorchRefsMode +from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned from torch.fx import GraphModule from torch.fx.experimental.proxy_tensor import make_fx -from torch._prims.utils import getnvFuserDtype, Number -from torch._prims.context import TorchRefsMode -import torch.overrides -from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten - -if torch.cuda.is_available(): - from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import] def execute(gm: GraphModule, *args, executor: str = "aten"): @@ -23,55 +17,9 @@ def execute(gm: GraphModule, *args, executor: str = "aten"): if executor == "aten": return gm.forward(*args) elif executor == "nvfuser": - if not torch.cuda.is_available(): - raise RuntimeError( - "Attempting to use nvFuser trace executor but CUDA is not available!" - ) - - # PROTOTYPE nvfuser executor - # Everything in the graph must support nvfuser - - fusion = Fusion() - with FusionDefinition(fusion) as fd: - - def _to_nvfuser_constant(arg): - if isinstance(arg, Number): - return fd.define_constant(arg) - else: - return arg - - class FusionInterpreter(torch.fx.Interpreter): - def call_function(self, target, args, kwargs): - args = tuple(map(_to_nvfuser_constant, args)) - target = target.impl_nvfuser - args = (fd,) + args - return target(*args, **kwargs) - - def to_nv(arg): - if isinstance(arg, torch.Tensor): - x = fd.define_tensor( - arg.size(), arg.stride(), getnvFuserDtype(arg.dtype) - ) - fd.add_input(x) - return x - else: - return arg - - # Transforms graph to call nvfuser lowerings - # Note, this doesn't handle nested structures in the args, TODO: add tree_flatten - nv_args = tree_map(to_nv, args) - out = FusionInterpreter(gm).run(*nv_args) - flat_out, unflatten_spec = tree_flatten(out) - for o in flat_out: - fd.add_output(o) - assert len(args) == 1 - args = args[0] # we are passing a packed list of args - return tree_unflatten( - fusion.execute( - tuple(arg for arg in args if isinstance(arg, torch.Tensor)) - ), - unflatten_spec, - ) + return nvfuser_execute_partitioned(gm, *args) + elif executor == "strictly_nvfuser": + return nvfuser_execute(gm, *args) msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format( executor @@ -118,7 +66,7 @@ def wrapped(args): kwargs = dict(zip(kwargs_keys, args[nargs:])) return fn(*fn_args, **kwargs) - with TorchRefsMode.push(): + with TorchRefsMode(): gm = make_fx(wrapped)(all_args) return execute(gm, all_args, executor=executor) diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py new file mode 100644 index 0000000000000..93de1f6bc43ce --- /dev/null +++ b/torch/_prims/nvfuser_executor.py @@ -0,0 +1,186 @@ +from copy import deepcopy +from dataclasses import dataclass +from functools import lru_cache +from warnings import warn + +import torch +import torch.overrides +from torch._prims_common import getnvFuserDtype, Number + +from torch.fx import GraphModule +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +if torch.cuda.is_available(): + from torch._C._nvfuser import ( # type: ignore[import] + DataType, + Fusion, + FusionDefinition, + ) +else: + DataType = None + + +# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects +# for cached construction of the nvFuser's Fusion +# TODO: change what is stored in the cache for nvFuser's Tensor objects +# https://github.com/pytorch/pytorch/issues/80551 +@dataclass(frozen=True) +class nvFuserTensorTemplate: + size: tuple + stride: tuple + dtype: DataType + + +@dataclass(frozen=True) +class nvFuserScalarTemplate: + dtype: DataType + + +def to_nvfuser_template_args(args): + def to_nvfuser(arg): + if isinstance(arg, torch.Tensor): + return nvFuserTensorTemplate( + arg.size(), arg.stride(), getnvFuserDtype(arg.dtype) + ) + elif isinstance(arg, Number): + return nvFuserScalarTemplate(getnvFuserDtype(type(arg))) + else: + return arg + + return tree_map(to_nvfuser, args) + + +# MyPy bug: https://github.com/python/mypy/issues/5107 +@lru_cache(maxsize=1024) # type: ignore[arg-type] +def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): + # PROTOTYPE nvfuser executor + # Everything in the graph must support nvfuser + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and getattr(node.target, "impl_nvfuser", None) is None + ): + raise ValueError( + "All call_function nodes in the graph must support nvfuser. " + f"Node {node} with target {node.target} does not support nvfuser" + ) + + fusion = Fusion() + with FusionDefinition(fusion) as fd: + + def _to_nvfuser_constant(arg): + if isinstance(arg, Number): + return fd.define_constant(arg) + else: + return arg + + class FusionInterpreter(torch.fx.Interpreter): + def call_function(self, target, args, kwargs): + args = tuple(map(_to_nvfuser_constant, args)) + target = target.impl_nvfuser + args = (fd,) + args + return target(*args, **kwargs) + + def templates_to_nvfuser_inputs(arg): + if isinstance(arg, nvFuserTensorTemplate): + x = fd.define_tensor(arg.size, arg.stride, arg.dtype) + return x + elif isinstance(arg, nvFuserScalarTemplate): + x = fd.define_scalar(arg.dtype) + return x + else: + return arg + + # Transforms graph to call nvfuser lowerings + nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates)) + out = FusionInterpreter(gm).run(*nv_args) + flat_out, unflatten_spec = tree_flatten(out) + for o in flat_out: + fd.add_output(o) + + return fusion, unflatten_spec + + +def nvfuser_execute(gm: GraphModule, *args): + if not torch.cuda.is_available(): + raise RuntimeError( + "Attempting to use nvFuser trace executor but CUDA is not available!" + ) + + flat_args, _ = tree_flatten(args) + + # Construction of the fusion is expensive and cached based on the GraphModule + # and symbolic nvFuser args. + nv_template_args = to_nvfuser_template_args(flat_args) + fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc] + + # Inputs to fusion.execute correspond to the same template/symbolic inputs + # marked with `define_tensor/scalar` + concrete_fusion_inputs = tuple( + arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) + ) + + return tree_unflatten( + fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] + unflatten_spec, # type: ignore[has-type] + ) + + +class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and getattr(node.target, "impl_nvfuser", None) is not None + ) + + +class PartitionedInterpreter(torch.fx.Interpreter): + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + assert len(kwargs) == 0 + submod = self.fetch_attr(target) + # CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id + if target.startswith("fused_"): + return nvfuser_execute(submod, *args) + else: + return super().call_module(target, args, kwargs) + + +# MyPy bug: https://github.com/python/mypy/issues/5107 +@lru_cache() # type: ignore[arg-type] +def maybe_partition_graph(gm: GraphModule): + supported_ops = NvfuserPrimOperatorSupport() + call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes) + # the graph is partitioned only if at least one node is not supported by nvFuser + any_unsupported = any( + not supported_ops.is_node_supported(None, node) for node in call_function_nodes + ) + if any_unsupported: + # CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph + gm = deepcopy(gm) + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) + partitions = partitioner.propose_partitions() + if len(partitions) == 0: + warn( + "No partition found for the graph. " + + "This is likely because the graph is not supported by nvFuser. " + + "Please use the eager ATen mode to execute the graph.", + category=RuntimeWarning, + ) + partitioned_graph = partitioner.fuse_partitions(partitions) + return partitioned_graph, any_unsupported + else: + return gm, any_unsupported + + +def nvfuser_execute_partitioned(gm: GraphModule, *args): + # When possible it's better to use nvfuser_execute directly + # because it avoids PartitionedInterpreter's overhead + gm, is_partitioned = maybe_partition_graph(gm) + if is_partitioned: + return PartitionedInterpreter(gm).run(*args) + else: + return nvfuser_execute(gm, *args) diff --git a/torch/_prims/utils.py b/torch/_prims_common/__init__.py similarity index 77% rename from torch/_prims/utils.py rename to torch/_prims_common/__init__.py index 14301ee2bb5d9..49ae9e806a947 100644 --- a/torch/_prims/utils.py +++ b/torch/_prims_common/__init__.py @@ -1,16 +1,15 @@ from __future__ import annotations -from typing import Any, Union, Sequence, Optional, Tuple, List, Callable, Type +from typing import Any, Union, Sequence, Optional, Tuple, List, Callable, Type, overload from enum import Enum from functools import reduce, cmp_to_key import operator import weakref -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode import torch -# nvFuser imports are conditional on CUDA being available -if torch.cuda.is_available(): +# nvFuser imports are conditional on being compiled with CUDA +if hasattr(torch._C, "_nvfuser"): from torch._C._nvfuser import DataType # type: ignore[import] _torch_dtype_to_nvfuser_dtype_map = { @@ -23,12 +22,17 @@ torch.long: DataType.Int, torch.int: DataType.Int32, torch.bool: DataType.Bool, + # Python scalars + complex: DataType.ComplexDouble, + float: DataType.Double, + int: DataType.Int, + bool: DataType.Bool, } else: _torch_dtype_to_nvfuser_dtype_map = {} -def getnvFuserDtype(dtype: torch.dtype): +def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]): """ Translates from torch.dtype to nvFuser's DataType enum """ @@ -39,9 +43,11 @@ def getnvFuserDtype(dtype: torch.dtype): StrideType = Union[List[int], Tuple[int, ...]] DimsType = Union[int, List[int], Tuple[int, ...]] DimsSequenceType = Union[List[int], Tuple[int, ...]] +NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] NumberType = Union[bool, int, float, complex] Number = (bool, int, float, complex) DeviceLikeType = Union[str, torch.device] +Tensor = torch.Tensor torch_function_passthrough = { @@ -49,11 +55,15 @@ def getnvFuserDtype(dtype: torch.dtype): torch.Tensor.numel, torch.Tensor.stride, torch.Tensor.dtype.__get__, # type: ignore[attr-defined] + torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] torch.Tensor.shape.__get__, # type: ignore[attr-defined] torch.Tensor.device.__get__, # type: ignore[attr-defined] + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.layout.__get__, # type: ignore[attr-defined] # For TorchRefsMode only torch.Tensor.__format__, torch.Tensor.__repr__, + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] } @@ -63,82 +73,6 @@ def getnvFuserDtype(dtype: torch.dtype): TensorOrNumberLikeType = Union[TensorLikeType, NumberType] -# In order to keep things like aliasing relationships and storage -# consistent wrt/meta tensors, FakeTensors own a FakeTensorMode -# which caches conversions to Meta Tensors. We would like to use -# one consistent mode among along FakeTensors, which we store here. -# We store a weakref, so that when all previous FakeTensors are -# the present mode will also deallocate. FakeTensorMode holds onto -# tensors that are converted to Meta so we don't want to persist it -# longer than necessary.x -prim_fake_mode_ref = None - - -def get_prim_fake_mode(): - global prim_fake_mode_ref - if prim_fake_mode_ref is None or prim_fake_mode_ref() is None: - mode = FakeTensorMode() - prim_fake_mode_ref = weakref.ref(mode) - return mode - else: - return prim_fake_mode_ref() - - -def TensorMeta( - tensorlike: Optional[Union[NumberType, torch.Tensor]] = None, - *, - shape: Optional[ShapeType] = None, - strides: Optional[StrideType] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, -): - if isinstance(tensorlike, Number): - assert not shape and (shape is None or isinstance(shape, Sequence)) - assert not strides and (strides is None or isinstance(strides, Sequence)) - inferred_shape: Tuple[int, ...] = () - inferred_strides: Tuple[int, ...] = () - inferred_dtype = type_to_dtype(type(tensorlike)) - inferred_device = torch.device("cpu") - # TODO: This looks wrong, a number that is wrapped into a tensor - # needs to behave differently than a scalar tensor for type - # promotion purposes - elif tensorlike is not None: - assert isinstance(tensorlike, torch.Tensor) - inferred_shape = tuple(tensorlike.shape) - inferred_strides = tuple(tensorlike.stride()) - inferred_dtype = tensorlike.dtype - inferred_device = tensorlike.device - else: - # If no tensorlike "example" is given then all metadata - # must be provided explicitly - assert shape is not None - assert strides is not None - assert dtype is not None - assert device is not None - - shape = inferred_shape if shape is None else tuple(shape) - strides = inferred_strides if strides is None else tuple(strides) - dtype = inferred_dtype if dtype is None else dtype - device = inferred_device if device is None else device - - if isinstance(device, str): - device = torch.device(device) - - if isinstance(tensorlike, FakeTensor): - mode = tensorlike.fake_mode - else: - mode = get_prim_fake_mode() - - if device.type == "meta": - return torch.empty_strided(shape, strides, dtype=dtype, device="meta") - else: - return FakeTensor( - mode, - torch.empty_strided(shape, strides, dtype=dtype, device="meta"), - device, - ) - - def same_shape(a: ShapeType, b: ShapeType) -> bool: if len(a) != len(b): return False @@ -152,7 +86,7 @@ def same_shape(a: ShapeType, b: ShapeType) -> bool: # TODO: look at using torch.testing.assert_close instead with an option # to just compare metadata -def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType): +def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=False): """ Checks that two tensor likes have the same shape, dtype and device. @@ -183,12 +117,13 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType): raise AssertionError(msg) # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 - # same_strides, idx = check_significant_strides(a, b) - # if not same_strides: - # msg = "Stride mismatch! Strides are {0} and {1} (mismatched at {2})!".format( - # a.stride(), b.stride(), idx - # ) - # raise RuntimeError(msg) + if check_strides: + same_strides, idx = check_significant_strides(a, b) + if not same_strides: + msg = "Stride mismatch! Strides are {0} and {1} (mismatched at {2})!".format( + a.stride(), b.stride(), idx + ) + raise RuntimeError(msg) def check_significant_strides( @@ -206,14 +141,15 @@ def check_significant_strides( return True, None +# This function is equivalent to compute_contiguous() from TensorImpl.cpp def is_contiguous(a: TensorLikeType) -> bool: """ Tests whether a tensor is contiguous or not. Tensors are contiguous when they have no elements, - or when they have "nested" strides. + one element, or when they have "nested" strides. """ - if a.numel() == 0: + if a.numel() < 2: return True expected_stride = 1 @@ -229,6 +165,140 @@ def is_contiguous(a: TensorLikeType) -> bool: return True +# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp +def is_channels_last_contiguous_2d(a: Tensor) -> bool: + # NHWC or not channels last 2D contiguous + if a.ndim != 4: + return False + + expected_stride = 1 + for idx in (1, 3, 2, 0): + + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def is_channels_last_contiguous_3d(a: Tensor) -> bool: + # NDHWC or not channels last 3D contiguous + if a.ndim != 5: + return False + + expected_stride = 1 + for idx in (1, 4, 3, 2, 0): + + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +_memory_formats = set( + ( + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, + ) +) + + +def validate_memory_format(memory_format: torch.memory_format): + check( + memory_format in _memory_formats, + lambda: f"Received unknown memory format {memory_format}!", + ) + + +def is_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + validate_memory_format(memory_format) + + if memory_format == torch.contiguous_format: + return is_contiguous(a) + if memory_format == torch.channels_last: + return is_channels_last_contiguous_2d(a) + if memory_format == torch.channels_last_3d: + return is_channels_last_contiguous_3d(a) + + check( + False, + lambda: f"is_contiguous received unsupported memory format {memory_format}", + ) + + +# NOTE: that tensors with no elements and channels last is ??? +def is_channels_last_contiguous(a: Tensor) -> bool: + """ + True when a tensor is channels-last contiguous. + + This requires that: + + - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions + - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the + stride of the 'C' dimension (Cs) is 1 and the strides corresponding to + each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are + "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, + for example. + """ + return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if is_contiguous(a) or is_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted( + tuple(zip(a.shape, a.stride())), key=operator.itemgetter(1) + ) + + expected_stride = 1 + for length, stride in lengths_and_strides: + + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + # NOTE: Based on the implementation in TensorIterator.cpp, but note that # the note [Computing output strides] is incorrect, because it # says that strides will be preserved even if they are not @@ -319,7 +389,6 @@ def validate_dim_length(length: int): dimension length. """ - assert isinstance(length, int) assert length >= 0 @@ -397,7 +466,17 @@ def canonicalize_dim(rank: int, idx: int) -> int: # Takes a dimension or sequence of dimensions and "wraps" them, # mapping negative offsets to positive ones -def canonicalize_dims(rank: int, indices: DimsType) -> DimsType: +@overload +def canonicalize_dims(rank: int, indices: Sequence[int]) -> Tuple[int, ...]: + pass + + +@overload +def canonicalize_dims(rank: int, indices: int) -> int: + pass + + +def canonicalize_dims(rank, indices): if isinstance(indices, int): return canonicalize_dim(rank, indices) @@ -471,7 +550,7 @@ def check_same_device(*args, allow_cpu_scalar_tensors): raise RuntimeError(msg) -def canonicalize_device(device: Union[str, torch.device]) -> torch.device: +def canonicalize_device(device: DeviceLikeType) -> torch.device: if isinstance(device, torch.device): return device @@ -514,6 +593,8 @@ def check_same_shape(*args, allow_cpu_scalar_tensors: bool): raise RuntimeError(msg) +# Acquires a common shape, if it exists, from one or more tensor arguments, +# filtering number arguments def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: shape = None scalar_shape = None @@ -559,7 +640,7 @@ def extract_shape_from_varargs( """ # Handles tuple unwrapping - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], Sequence): shape = shape[0] validate_shape(shape) # type: ignore[arg-type] @@ -567,6 +648,7 @@ def extract_shape_from_varargs( _integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) _float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64) _complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) @@ -581,6 +663,11 @@ def is_integer_dtype(dtype: torch.dtype) -> bool: return dtype in _integer_dtypes +def is_low_precision_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _low_precision_dtypes + + def is_float_dtype(dtype: torch.dtype) -> bool: assert isinstance(dtype, torch.dtype) return dtype in _float_dtypes @@ -639,24 +726,52 @@ def dtype_to_type(dtype: torch.dtype) -> type: raise ValueError("Invalid dtype!") -_type_to_dtype_map = { - bool: torch.bool, - int: torch.int64, - float: torch.float64, - complex: torch.complex128, -} - - def type_to_dtype(typ: type) -> torch.dtype: """ Computes the corresponding dtype for a Number type. """ - return _type_to_dtype_map[typ] + + assert isinstance(typ, type) + + if typ is bool: + return torch.bool + if typ is int: + return torch.long + if typ is float: + return torch.get_default_dtype() + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError("Invalid type!") _ordered_types = (bool, int, float, complex) +def check_fp_or_complex( + dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True +): + """ + Checks whether the input is floating point or complex. + If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 + """ + check( + is_float_dtype(dtype) or is_complex_dtype(dtype), + lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", + ) + check( + allow_low_precision_dtypes or not is_low_precision_dtype(dtype), + lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", + ) + + +def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): + check( + len(A.shape) >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + def get_higher_type(a: type, b: type) -> type: """ Returns the higher of the two given Number types. @@ -1083,21 +1198,14 @@ def reduction_dtypes( return computation_dtype, result_dtype -def wrap_device(d: Union[str, torch.device]) -> torch.device: +def make_contiguous_strides_for( + shape: ShapeType, row_major: bool = True +) -> Tuple[int, ...]: """ - Wraps strings into torch.device objects. - - Given torch.device objects are returned unmodified. + Returns the strides of a contriguous tensor if row_major + If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices + This is often used when calling external libraries like BLAS/LAPACK/cuSolver... """ - - assert isinstance(d, (str, torch.device)) - if isinstance(d, str): - return torch.device(d) - - return d - - -def make_contiguous_strides_for(shape: ShapeType) -> Tuple[int, ...]: validate_shape(shape) if not shape: return () @@ -1105,14 +1213,65 @@ def make_contiguous_strides_for(shape: ShapeType) -> Tuple[int, ...]: multiplier = 1 strides = [] for l in reversed(shape): + strides.append(multiplier) if l != 0: - strides.append(multiplier) - multiplier = l * multiplier - else: - strides.append(multiplier) + multiplier *= l result = tuple(reversed(strides)) - return result + + if row_major: + return result + else: + if len(shape) < 2: + return result + return result[:-2] + (1, max(shape[-2], 1)) + + +def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? + check( + len(shape) == 4, + lambda: "Only tensors of rank 4 can use the channels_last memory format", + ) + + multiplier = 1 + strides = [0] * 4 + for idx in (1, -1, -2, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + check( + len(shape) == 5, + lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", + ) + + multiplier = 1 + strides = [0] * 5 + for idx in (1, -1, -2, -3, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: + ndim = len(shape) if isinstance(shape, Sequence) else 1 + if ndim == 4: + return make_channels_last_2d_strides_for(shape) + elif ndim == 5: + return make_channels_last_3d_strides_for(shape) + else: + raise RuntimeError( + f"no channels last format strides exist in {ndim} dimensions" + ) def compute_reduction_output_shape( @@ -1145,7 +1304,7 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ... def check_in_bounds_for_storage( - a: torch._TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int + a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int ): """ Determines if the given shape, strides, and offset are valid for the given storage. @@ -1182,3 +1341,46 @@ def check( """ if not b: raise exc_type(s()) + + +# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in +# c10/core/MemoryFormat.h into one function +def are_strides_like_channels_last( + shape: Sequence[int], strides: Sequence[int] +) -> bool: + ndim = len(shape) + + if ndim == 4: + # Check for channels_last_2d + dim_order = [1, 3, 2, 0] + elif ndim == 5: + # Check for channels_last_3d + dim_order = [1, 4, 3, 2, 0] + else: + return False + + if strides[1] == 0: + return False + + min = 0 + for d in dim_order: + if shape[d] == 0: + return False + if strides[d] < min: + return False + if d == 0 and min == strides[1]: + return False + min = strides[d] + if strides[d] > 1: + min *= shape[d] + return True + + +def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: + if x.layout != torch.strided: + return torch.contiguous_format + + if are_strides_like_channels_last(x.shape, x.stride()): + return torch.channels_last if x.ndim == 4 else torch.channels_last_3d + + return torch.contiguous_format diff --git a/torch/_prims/wrappers.py b/torch/_prims_common/wrappers.py similarity index 57% rename from torch/_prims/wrappers.py rename to torch/_prims_common/wrappers.py index 27270ac1da859..a3199356ea1b3 100644 --- a/torch/_prims/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -1,15 +1,15 @@ import torch -from torch._prims.utils import ( +from torch._prims_common import ( Number, NumberType, TensorLike, TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND, ) -import torch._prims.utils as utils +import torch._prims_common as utils from torch.utils._pytree import tree_flatten -from typing import Callable, Sequence, Union +from typing import Callable, Sequence, Union, Tuple, NamedTuple import inspect from functools import wraps, reduce import operator @@ -20,6 +20,7 @@ def _maybe_convert_to_dtype( a: Union[TensorLikeType, NumberType, Sequence], dtype: torch.dtype ) -> Union[TensorLikeType, NumberType, Sequence]: + import torch._prims as prims if isinstance(a, TensorLike): if a.dtype != dtype: # NOTE: this is incorrect on the CPU @@ -124,7 +125,7 @@ def _fn(*args, **kwargs): # TODO: handle tuples of tensors def _maybe_resize_out(out: TensorLikeType, shape): if out.numel() == 0: - return prims.resize(out, shape) + return out.resize_(shape) if out.numel() != reduce(operator.mul, shape, 1): msg = ( @@ -137,12 +138,14 @@ def _maybe_resize_out(out: TensorLikeType, shape): ) ) warnings.warn(msg) - return prims.resize(out, shape) + return out.resize_(shape) return out -def _safe_copy_out(*, copy_from: TensorLikeType, copy_to: TensorLikeType): +def _safe_copy_out( + *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False +): # Checks same device if copy_from.device != copy_to.device: msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format( @@ -151,94 +154,118 @@ def _safe_copy_out(*, copy_from: TensorLikeType, copy_to: TensorLikeType): raise RuntimeError(msg) # Checks safe cast - if not utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype): - msg = "Attempting to cast from {0} to out tensor with dtype {1}, but this can't be cast because it is not safe!".format( - copy_from.dtype, copy_to.dtype + if exact_dtype: + utils.check( + copy_from.dtype == copy_to.dtype, + lambda: f"Expected out tensor to have dtype {copy_from.dtype} " + "but got {copy_to.dtype} instead", + ) + else: + utils.check( + utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), + lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " + "but this can't be cast because it is not safe!", ) - raise RuntimeError(msg) - - return prims.copy_to(copy_to, copy_from) + return copy_to.copy_(copy_from) -# FIXME: only supports out parameter that is literally called "out" -def out_wrapper(fn: Callable) -> Callable: - """ - Adds the out parameter to a Python reference. - Note that this currently only supports operations that return a single tensor. - """ - - @wraps(fn) - def _fn(*args, out=None, **kwargs): - result = fn(*args, **kwargs) - if out is not None: - assert isinstance(out, TensorLike) - out = _maybe_resize_out(out, result.shape) - return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] - return result +def out_wrapper(*out_names: str, exact_dtype: bool = False): + is_tensor = len(out_names) == 0 + assert is_tensor or len(out_names) >= 2 - sig = inspect.signature(fn) - out_param = inspect.Parameter( - "out", - kind=inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=TensorLikeType, - ) - params = chain(sig.parameters.values(), (out_param,)) - _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] - parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] - ) - _fn.__annotations__ = fn.__annotations__ - _fn.__annotations__["out"] = TensorLikeType - return _fn + def _out_wrapper(fn: Callable) -> Callable: + """ + Adds the out parameter to a Python reference. + """ + out_type = ( + TensorLikeType + if is_tensor + else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))] + ) + return_type = ( + TensorLikeType + if is_tensor + else NamedTuple( + f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] + ) + ) + sig = inspect.signature(fn) + factory_kwargs = ("device", "dtype") + is_factory_fn = all(p in sig.parameters for p in factory_kwargs) -def out_wrapper_multi(*out_names): - def go(fn: Callable) -> Callable: @wraps(fn) - def _fn(*args, **kwargs): - out_kwargs = {} - has_out_kwargs = None - for o in out_names: - out_kwargs[o] = kwargs.pop(o, None) - # Either all of the out kwargs are set or none of them - if has_out_kwargs is None: - has_out_kwargs = out_kwargs[o] is not None - else: - assert has_out_kwargs == (out_kwargs[o] is not None) - result = fn(*args, **kwargs) - assert isinstance(result, tuple) - if has_out_kwargs: - final_result = [] - for i, o in enumerate(out_names): - out = out_kwargs[o] - assert isinstance(out, TensorLike) - out = _maybe_resize_out(out, result[i].shape) - final_result.append(_safe_copy_out(copy_from=result[i], copy_to=out)) # type: ignore[arg-type] - return tuple(final_result) - return result + def _fn(*args, out=None, **kwargs): + if is_factory_fn and out is not None: + for k in factory_kwargs: + out_attr = getattr(out, k) + if k not in kwargs: + kwargs[k] = out_attr - sig = inspect.signature(fn) - out_params = [] - for o in out_names: - out_params.append( - inspect.Parameter( - o, - kind=inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=TensorLikeType, - ) + result = fn(*args, **kwargs) + assert ( + isinstance(result, TensorLike) + and is_tensor + or isinstance(result, Tuple) # type: ignore[arg-type] + and len(result) == len(out_names) ) - params = chain(sig.parameters.values(), out_params) + if out is not None: + # Naively you might expect this assert to be true, but + # it's not: + # + # assert type(out) == type(result) + # + # The reason is that functions under this wrapper can + # get registered to the Meta dispatch key, and that + # means they can be executed in a context where tensor + # subclasses are disabled (with no_dispatch), which is a + # handy way for an is-a tensor subclass (e.g., + # FakeTensor) to have the normal meta backend create a + # meta tensor, to be wrapped once it gets returned. + # In this situation, you will get a FakeTensor as + # the output tensor, but not the result--which will + # be a normal meta tensor, but this is perfectly + # harmless. + if is_tensor: + assert isinstance(out, TensorLike) + # These two operations are done in-place + _maybe_resize_out(out, result.shape) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + assert isinstance(out, Tuple) # type: ignore[arg-type] + utils.check( + len(out) == len(result), + lambda: f"expected tuple of {len(result)} elements but got {len(out)}", + TypeError, + ) + for r, o in zip(result, out): + # These two operations are done in-place + _maybe_resize_out(o, r.shape) + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + out = result + # mypy does not see through the definition of out_type given that it's in a different scope + return out if is_tensor else return_type(*out) # type: ignore[operator] + + out_param = inspect.Parameter( + "out", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_type, + ) + # Mark that the function now returns a tuple + assert sig.return_annotation in (sig.empty, out_type) + params = chain(sig.parameters.values(), (out_param,)) _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] - parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + parameters=params, return_annotation=return_type # type: ignore[arg-type] ) _fn.__annotations__ = fn.__annotations__ - for o in out_names: - _fn.__annotations__[o] = TensorLikeType + _fn.__annotations__["out"] = out_type + _fn.__annotations__["return"] = return_type return _fn - return go + return _out_wrapper # TODO: when tracing this will add torch tensors and not TensorMeta objects @@ -264,7 +291,3 @@ def _fn(*args, **kwargs): _fn.__signature__ = sig # type: ignore[attr-defined] return _fn - - -# avoid mypy import cycle -import torch._prims as prims diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index ee2c7d279458e..fbad612ab8d37 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -1,4 +1,5 @@ import re + import torch._C as C @@ -48,15 +49,19 @@ This file only provides the simplified API for developers, revelant test code is located in test/test_dispatch.py """ + + class PythonDispatcher: namespace = "__test__" name = "foo" + # fmt: off runtime_keys = [ "CPU", "AutogradCPU", "FPGA", "AutogradOther", "XLA", "AutogradXLA", "Lazy", "AutogradLazy", ] + # fmt: on alias_keys = [ "CompositeExplicitAutograd", "Autograd", @@ -73,6 +78,7 @@ def __init__(self): Returns a list of dispatch keys supported by PythonDispatcher. You can register kernels to these keys. """ + def keys(self): return self.supported_keys @@ -83,27 +89,39 @@ def keys(self): this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is automatically generated and registered. """ + def register(self, dispatchKeys): # Overriden is not supported and triggers a warning in C++ dispatcher. if len(set(dispatchKeys)) != len(dispatchKeys): - raise RuntimeError(f"Overriden is not allowed but found duplicates in {dispatchKeys}.") + raise RuntimeError( + f"Overriden is not allowed but found duplicates in {dispatchKeys}." + ) # We currently forbid this in codegen instead of C++ dispatcher. - if 'CompositeImplicitAutograd' in dispatchKeys and 'CompositeExplicitAutograd' in dispatchKeys: - raise RuntimeError("Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed.") + if ( + "CompositeImplicitAutograd" in dispatchKeys + and "CompositeExplicitAutograd" in dispatchKeys + ): + raise RuntimeError( + "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed." + ) for key in dispatchKeys: if key not in self.supported_keys: - raise RuntimeError(f"{key} is not supported, please select a dispatch key in {self.supported_keys}.") + raise RuntimeError( + f"{key} is not supported, please select a dispatch key in {self.supported_keys}." + ) self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key) """ Helper function to format (key, kernel). """ + def _format_line(self, key, kernel): return "{:<15} {}\n".format(key, kernel) """ Helper function to print a table header. """ + def _format_header(self, header): s = f""" {header} @@ -116,6 +134,7 @@ def _format_header(self, header): Returns raw output of all registration info for debugging only. Use registrations() for a simplified version. """ + def rawRegistrations(self): return C._dispatch_dump("{}::{}".format(self.namespace, self.name)) # type: ignore[attr-defined] @@ -123,6 +142,7 @@ def rawRegistrations(self): Returns raw output of computed dispatch table for debugging only. Use dispatchTable() for a simplified version. """ + def rawDispatchTable(self): return C._dispatch_dump_table("{}::{}".format(self.namespace, self.name)) # type: ignore[attr-defined] @@ -130,10 +150,11 @@ def rawDispatchTable(self): Returns a table(str) including all the registrations from users. Note this includes registrations to both runtime keys and alias keys. """ + def registrations(self): output = self._format_header("Registered Kernels") state = self.rawRegistrations() - state_entries = state.split('\n') + state_entries = state.split("\n") for line in state_entries: first = line.split(":")[0] if any(first.startswith(k) for k in self.supported_keys): @@ -146,14 +167,15 @@ def registrations(self): runtime keys, registrations to alias keys have been decoded to their mapped runtime keys. """ + def dispatchTable(self): output = self._format_header("Computed Dispatch Table") table = self.rawDispatchTable() - table_entries = table.split('\n') + table_entries = table.split("\n") regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)") for line in table_entries: k = line.split(":")[0] if k in self.runtime_keys: - entry = regex.sub('[', line) + entry = regex.sub("[", line) output += self._format_line(k, entry.split(": ")[1]) return output diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 48a25353ef346..bb7b65031b8b1 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1,43 +1,45 @@ +import builtins +import collections +import math +import operator +import warnings + +from collections.abc import Iterable +from enum import Enum +from functools import partial, reduce, wraps +from typing import Callable, List, Optional, overload, Sequence, Tuple, Union + import torch import torch._prims as prims -import torch._prims.utils as utils -from torch._prims.utils import ( +import torch._prims_common as utils +from torch._prims_common import ( check, + DeviceLikeType, + DimsSequenceType, DimsType, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + is_weakly_lesser_type, + Number, + NumberType, + REDUCTION_OUTPUT_TYPE_KIND, ShapeType, StrideType, TensorLike, TensorLikeType, - DeviceLikeType, TensorOrNumberLikeType, - DimsSequenceType, TensorSequenceType, - Number, - NumberType, - ELEMENTWISE_TYPE_PROMOTION_KIND, - REDUCTION_OUTPUT_TYPE_KIND, - is_weakly_lesser_type, - dtype_to_type, ) -from torch._prims.wrappers import ( - elementwise_type_promotion_wrapper, - out_wrapper, +from torch._prims_common.wrappers import ( _maybe_convert_to_dtype, _maybe_resize_out, - elementwise_unary_scalar_wrapper, _safe_copy_out, + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, ) -from collections.abc import Iterable -from functools import reduce, partial, wraps -from typing import Sequence, Optional, Union, Callable, List, Tuple -import operator -import warnings -import math -from enum import Enum -import collections - # Experimental module containing prototype Python references for existing # PyTorch operations. @@ -48,11 +50,14 @@ "abs", "acos", "acosh", + "asinh", "asin", "atan", + "atanh", "bitwise_not", # "cbrt", # No corresponding torch operation "ceil", + "conj_physical", "cos", "cosh", "digamma", @@ -89,6 +94,7 @@ "tan", "tanh", "trace", + "trunc", # # Elementwise Binary References # @@ -100,19 +106,19 @@ "bitwise_right_shift", "bitwise_xor", # "complex", - # 'copysign', # where - # 'div', # need to implement all rounding modes first + "copysign", + "div", "eq", "float_power", - # 'floor_divide', # requires floor + "floor_divide", "fmax", "fmin", "fmod", "gcd", "ge", "gt", - # 'heaviside', - # 'hypot', + "heaviside", + "hypot", "igamma", "igammac", "isclose", @@ -134,11 +140,12 @@ # 'polar', # abs, cos, sin "pow", "remainder", - # 'rsub', # unblocked + "rsub", # # special.xlog1py # # special.zeta "sub", "true_divide", + "trunc_divide", # 'xlogy', # where?, log, mul # # Elementwise Ternary References @@ -185,8 +192,12 @@ "cat", "chunk", "column_stack", + "conj", + "constant_pad_nd", + "contiguous", "dsplit", "dstack", + "expand", "flatten", "flip", "fliplr", @@ -221,8 +232,12 @@ "full_like", "ones", "ones_like", + "scalar_tensor", "zeros", "zeros_like", + "arange", + "linspace", + "logspace", # # Randomness References # @@ -230,6 +245,7 @@ # # Test-related functions # + "allclose", "equal", # TODO: add OpInfo ] @@ -325,7 +341,7 @@ def inner(prim: Callable): nonlocal aten_op @wraps(prim) - @out_wrapper + @out_wrapper() @elementwise_unary_scalar_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("a",), @@ -372,11 +388,21 @@ def asin(a): return prims.asin(a) +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asinh(a): + return prims.asinh(a) + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def atan(a): return prims.atan(a) +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atanh(a): + return prims.atanh(a) + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) def bitwise_not(a): return prims.bitwise_not(a) @@ -387,6 +413,14 @@ def ceil(a): return prims.ceil(a) +@register_decomposition(torch.ops.aten.conj_physical) +@out_wrapper() +def conj_physical(input: TensorLikeType): + if not input.dtype.is_complex: + return input + return prims.conj_physical(input) + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def cos(a): return prims.cos(a) @@ -433,7 +467,8 @@ def exp2(a): # Fill has its own implementation because it has a value parameter -@out_wrapper +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a,"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, @@ -453,6 +488,18 @@ def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType: return prims.fill(a, value) +def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType: + r = prims.fill(a, value) + prims.copy_to(a, r) + return a + + +def zero_(a: TensorLikeType) -> TensorLikeType: + r = prims.fill(a, 0) + prims.copy_to(a, r) + return a + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) def floor(a): return prims.floor(a) @@ -464,6 +511,15 @@ def frac(x: TensorLikeType) -> TensorLikeType: return sub(x, trunc_x) +# imag does not use _make_elementwise_unary_reference because it does not support out +def imag(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + utils.check( + utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." + ) + return prims.imag(a) + + @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=None, # CompositeImplicitAutograd @@ -477,14 +533,30 @@ def isfinite(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) def isinf(a: TensorLikeType) -> TensorLikeType: - # TODO Add complex tensor support to remove is_infinite prim - # if utils.is_complex_dtype(a): - # return bitwise_or(_isinf(real(a), _isinf(imag(a)) - # else: - # return bitwise_not(bitwise_or(isnan(a), isfinite(a))) - if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype): - return prims.is_infinite(a) + if utils.is_complex_dtype(a.dtype): + return logical_or(isinf(real(a)), isinf(imag(a))) + return logical_not(logical_or(isnan(a), isfinite(a))) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isposinf(a: TensorLikeType) -> TensorLikeType: + utils.check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return eq(a, float("inf")) + return zeros_like(a, dtype=torch.bool) + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isneginf(a: TensorLikeType) -> TensorLikeType: + utils.check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return eq(a, float("-inf")) return zeros_like(a, dtype=torch.bool) @@ -526,7 +598,7 @@ def log10(a): return prims.log10(a) -@out_wrapper +@out_wrapper() def log_softmax( a: TensorLikeType, dim: int, @@ -539,7 +611,7 @@ def log_softmax( return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value] -@out_wrapper +@out_wrapper() def logsumexp( a: TensorLikeType, dim: DimsType, @@ -562,7 +634,7 @@ def logsumexp( @register_decomposition(torch.ops.aten.nan_to_num) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a,"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -610,6 +682,7 @@ def neg(a): # positive does not use _make_elementwise_unary_reference because it does not support out +# CompositeImplicitAutograd - don't register decomp def positive(a: TensorLikeType) -> TensorLikeType: assert isinstance(a, TensorLike) if a.dtype is torch.bool: @@ -618,6 +691,14 @@ def positive(a: TensorLikeType) -> TensorLikeType: return a +# real does not use _make_elementwise_unary_reference because it does not support out +def real(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if utils.is_complex_dtype(a.dtype): + return prims.real(a) + return a + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def reciprocal(a): return prims.reciprocal(a) @@ -685,6 +766,11 @@ def tanh(a): return prims.tanh(a) +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def trunc(a): + return prims.trunc(a) + + def _make_elementwise_binary_reference( prim: Callable, *, @@ -723,7 +809,7 @@ def _ref( return prim(a, b) if has_out: - _ref = out_wrapper(_ref) + _ref = out_wrapper()(_ref) if aten_op is infer_aten_op: aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0]) @@ -735,7 +821,7 @@ def _ref( # Add has its own implementation because it has an alpha argument @register_decomposition(torch.ops.aten.add) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a", "b"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -812,9 +898,57 @@ def add( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) + +def _copysign( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + if isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + return where(signbit(b), neg(abs(a)), abs(a)) + + +# TODO: add docstring +copysign = _make_elementwise_binary_reference( + _copysign, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + aten_op=torch.ops.aten.copysign, +) + # TODO: add docstring # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + +@register_decomposition(torch.ops.aten.div) +@out_wrapper() +def div( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + rounding_mode: Optional[str] = None, +): + """ + Reference implementation of torch.div + """ + if rounding_mode is None: + return true_divide(a, b) + elif rounding_mode == "trunc": + return trunc_divide(a, b) + elif rounding_mode == "floor": + return floor_divide(a, b) + else: + msg = ( + "div expected rounding_mode to be one of None, 'trunc', or 'floor' " + "but found {0}.".format(rounding_mode) + ) + raise ValueError(msg) + + # TODO: add docstring eq = _make_elementwise_binary_reference( prims.eq, @@ -822,10 +956,34 @@ def add( supports_lhs_python_scalar=False, ) + +def _pow( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> TensorLikeType: + assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType) + + if isinstance(b, Number): + if b == 1.0: + return a.clone() # type: ignore[return-value,union-attr] + elif b == 2.0: + return a * a # type: ignore[return-value] + elif b == 0.5: + return torch.sqrt(a) # type: ignore[arg-type] + return prims.pow(a, b) + + +# TODO: add docstring +pow = _make_elementwise_binary_reference( + _pow, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, + aten_op=torch.ops.aten.pow, +) + # TODO: add docstring # Float power has its own implementation because it has unique type promotion. # NB: aten_op not registered because CompositeExplicitAutograd -@out_wrapper +@out_wrapper() def float_power( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType], @@ -852,7 +1010,85 @@ def float_power( b = prims.to_dtype(b, dtype) a, b = _maybe_broadcast(a, b) - return prims.pow(a, b) + return pow(a, b) + + +# >>> a = torch.tensor(-0.2500, dtype=torch.float64) +# tensor(-0.250000000000000, dtype=torch.float64) +# +# >>> b = torch.tensor(-0.0010, dtype=torch.float64) +# tensor(-0.001000000000000, dtype=torch.float64) +# +# Note: In this case, casting float to double will expand the float mantissa with zeros, +# while creating a double generates a distinct mantissa. +# >>> torch.tensor(-0.001).to(dtype=torch.float64) +# tensor(-0.001000000047497, dtype=torch.float64) +# +# Floor Division +# The difference is caused because torch.remainder(a, b) = -0.001. +# +# >>> torch.floor(torch.true_divide(a, b)) +# tensor(250., dtype=torch.float64) +# +# >>> torch.div(a, b, rounding_mode='floor') +# tensor(249., dtype=torch.float64) +# +# Definition: a // b = (a - remainder(a, b)) / b +# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b) +# tensor(249., dtype=torch.float64) +# +# For reference, see CPython's implementation: +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 +def _floor_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + # Wrap scalars because some references only accept tensor arguments. + if isinstance(a, Number) and isinstance(b, Number): + a = scalar_tensor(a) + b = scalar_tensor(b) + elif isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Number) and isinstance(b, Tensor): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + if a.device == torch.device("cpu"): + msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + else: + b = prims.device_put(b, device=a.device) + + mod = fmod(a, b) + div = true_divide(sub(a, mod), b) + + # Ensure that the remainder has the same sign as denominator + different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0)) + non_zero_remainder = ne(mod, 0) + mask = bitwise_and(non_zero_remainder, different_signed_inputs) + div = where(mask, sub(div, 1), div) + + # Map quotient to nearest integer value + floor_div = floor(div) + mask = gt(sub(div, floor_div), 0.5) + floor_div = where(mask, add(floor_div, 1), floor_div) + + basic_div = true_divide(a, b) + zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device) + + # If quotient is zero, copy signbit from true_divide quotient + floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div)) + + # If denominator is zero, then follow true_divide behavior + return where(ne(b, 0), floor_div, basic_div) + + +# TODO: add docstring +floor_divide = _make_elementwise_binary_reference( + _floor_divide, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=torch.ops.aten.floor_divide, +) # TODO: add docstring @@ -860,6 +1096,8 @@ def float_power( prims.fmax, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, aten_op=torch.ops.aten.fmax, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, ) # TODO: add docstring @@ -867,6 +1105,8 @@ def float_power( prims.fmin, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, aten_op=torch.ops.aten.fmin, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, ) # TODO: add docstring @@ -874,6 +1114,8 @@ def float_power( prims.fmod, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, aten_op=torch.ops.aten.fmod, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=True, ) # TODO: add docstring @@ -899,6 +1141,30 @@ def float_power( supports_lhs_python_scalar=False, ) + +def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: + input_eq_zero = eq(input, 0) + input_lt_zero = logical_or(lt(input, 0), isnan(input)) + zeros_and_ones = where(input_lt_zero, 0, 1) + output = where(input_eq_zero, values, zeros_and_ones) + return output + + +heaviside = _make_elementwise_binary_reference( + _heaviside, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, + aten_op=torch.ops.aten.heaviside, +) + +hypot = _make_elementwise_binary_reference( + prims.hypot, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) + igamma = _make_elementwise_binary_reference( prims.igamma, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, @@ -914,33 +1180,44 @@ def float_power( ) -def isclose( +def _check_close_args( + name: str, a: TensorLikeType, b: TensorLikeType, - rtol: float = 1e-05, - atol: float = 1e-08, - equal_nan: bool = False, -) -> TensorLikeType: + rtol: float, + atol: float, +) -> None: check( a.dtype == b.dtype, - lambda: "torch.isclose: Attempting to compare tensors of different dtypes {0} and {1}!".format( - a.dtype, b.dtype + lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format( + name, a.dtype, b.dtype ), ValueError, ) check( rtol >= 0, - lambda: "torch.isclose: rtol must be greater than or equal to zero, but got {0}!".format( - rtol + lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format( + name, rtol ), ) check( atol >= 0, - lambda: "torch.isclose: atol must be greater than or equal to zero, but got {0}!".format( - atol + lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format( + name, atol ), ) + +# CompositeImplicitAutograd - don't register decomp +def isclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> TensorLikeType: + _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol) + close = eq(a, b) if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)): close = logical_or(close, logical_and(isnan(a), isnan(b))) @@ -1027,7 +1304,7 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType): a = a != 0 if not utils.is_boolean_dtype(b.dtype): b = b != 0 - return a | b + return bitwise_or(a, b) logical_or = _make_elementwise_binary_reference( @@ -1093,12 +1370,6 @@ def _logical_xor(a: TensorLikeType, b: TensorLikeType): supports_rhs_python_scalar=False, ) -# TODO: add docstring -pow = _make_elementwise_binary_reference( - prims.pow, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, -) - # TODO: add docstring remainder = _make_elementwise_binary_reference( prims.remainder, @@ -1106,11 +1377,24 @@ def _logical_xor(a: TensorLikeType, b: TensorLikeType): aten_op=torch.ops.aten.remainder, ) +# reverse sub +def rsub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: Optional[NumberType] = None, +): + if isinstance(a, Number): + msg = "Received a Number for the first argument, but expected a Tensor" + raise ValueError(msg) + return sub(b, a, alpha=alpha) + + # TODO: add docstring # TODO: consider refactoring this with add impl # sub has its own implementation because it has an alpha argument @register_decomposition(torch.ops.aten.sub) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a", "b"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -1122,7 +1406,7 @@ def sub( alpha: Optional[NumberType] = None, ): """ - Reference implementation of torch.add + Reference implementation of torch.sub """ if isinstance(a, Number) and isinstance(b, Number): @@ -1154,12 +1438,27 @@ def sub( aten_op=None, # CompositeImplicitAutograd ) + +def _trunc_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + return trunc(true_divide(a, b)) + + +# TODO: add docstring +trunc_divide = _make_elementwise_binary_reference( + _trunc_divide, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd +) + # # Elementwise Ternary References # -@out_wrapper +@register_decomposition(torch.ops.aten.clamp) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a", "min", "max"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -1169,17 +1468,43 @@ def clamp( min: Optional[TensorOrNumberLikeType] = None, max: Optional[TensorOrNumberLikeType] = None, ) -> TensorLikeType: - a, min, max = _maybe_broadcast(a, min, max) - - if min is not None and max is not None: - return minimum(maximum(a, min), max) + # NOTE: grad behavior with implementation `where` is not consistent on `nan` + if min is None and max is None: + msg = "clamp called but both min and max are none!" + raise ValueError(msg) if min is not None: - return maximum(a, min) + a_isnan = torch.isnan(a) + condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type] + # we should also propagate `nan` coming from boundaries. However, that's + # not necessary since `ge` would already `False` when either operands has + # a `nan`. So this line below is redundant + # `condition = bitwise_and(condition, bitwise_not(isnan(min)))` + a = torch.where(condition, a, min) # type: ignore[arg-type] if max is not None: - return minimum(a, max) + a_isnan = torch.isnan(a) + # same as above, no need to adjust `nan` from `max` + condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type] + a = torch.where(condition, a, max) # type: ignore[arg-type] + + return a + + +@register_decomposition(torch.ops.aten.clamp_min) +@out_wrapper() +def clamp_min( + self: TensorLikeType, + min: TensorOrNumberLikeType = None, +) -> TensorLikeType: + return torch.clamp(self, min=min) # type: ignore[arg-type] + - msg = "clamp called but both min and max are none!" - raise ValueError(msg) +@register_decomposition(torch.ops.aten.clamp_max) +@out_wrapper() +def clamp_max( + self: TensorLikeType, + max: TensorOrNumberLikeType = None, +) -> TensorLikeType: + return torch.clamp(self, max=max) # type: ignore[arg-type] # @@ -1189,7 +1514,7 @@ def clamp( # https://pytorch.org/docs/stable/generated/torch.where.html # TODO: implement alternate where @register_decomposition(torch.ops.aten.where) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a", "b"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, @@ -1205,7 +1530,10 @@ def where( raise NotImplementedError utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) - assert pred.dtype is torch.bool + check( + pred.dtype is torch.bool, + lambda: f"expected predicate to be bool, got {pred.dtype}", + ) pred, a, b = _maybe_broadcast(pred, a, b) return prims.where(pred, a, b) @@ -1214,12 +1542,14 @@ def where( # # Data Movement References # -# TODO: Turn this into a decomposition (currently fails on reshape meta tests) def clone( a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format ) -> TensorLikeType: - - return prims.clone(a, memory_format=memory_format) + result = torch.empty_like( + a, requires_grad=a.requires_grad, memory_format=memory_format + ) + copy_to(result, a) + return result def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): @@ -1232,6 +1562,7 @@ def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): return prims.copy_to(a, b) +@register_decomposition(torch.ops.aten.item) def item(a: TensorLikeType) -> NumberType: if a.numel() != 1: msg = f"Can't convert a tensor with {a.numel()} elements to a number!" @@ -1317,7 +1648,8 @@ def _reduction( py_all = all -@out_wrapper +@register_decomposition(torch.ops.aten.all) +@out_wrapper() def all( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1339,7 +1671,12 @@ def all( return result -@out_wrapper +# Saves Python any +py_any = any + + +@register_decomposition(torch.ops.aten.any) +@out_wrapper() def any( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1411,6 +1748,7 @@ def prod( ) +@register_decomposition(torch.ops.aten.amin) def amin( a: TensorLikeType, dim: Union[Optional[int], Optional[List[int]]] = None, @@ -1434,6 +1772,7 @@ def amin( ) +@register_decomposition(torch.ops.aten.amax) def amax( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1474,7 +1813,7 @@ def _set_correction( return correction -@out_wrapper +@out_wrapper() def var( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1501,7 +1840,7 @@ def var( return result -@out_wrapper +@out_wrapper() def std( a: TensorLikeType, dim: Union[Optional[int], Optional[List[int]]] = None, @@ -1533,6 +1872,7 @@ def std( return _maybe_convert_to_dtype(result, dtype) # type: ignore[return-value,arg-type] +@register_decomposition(torch.ops.aten.mean) def mean( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1588,6 +1928,7 @@ def std_mean( return s, m +@register_decomposition(torch.ops.aten.var_mean) def var_mean( a: TensorLikeType, dim: Optional[DimsType] = None, @@ -1602,7 +1943,7 @@ def var_mean( @register_decomposition(torch.ops.aten.addr) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("self", "vec1", "vec2"), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -1656,14 +1997,15 @@ def addr( return beta * self + alpha * torch.outer(vec1, vec2) +# CompositeImplicitAutograd - don't register decomp def atleast_1d( arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_1d`.""" - if not args and isinstance(arg, collections.Sequence): + if not args and isinstance(arg, collections.abc.Sequence): args_ = arg else: - assert not isinstance(arg, collections.Sequence) + assert not isinstance(arg, collections.abc.Sequence) args_ = (arg,) + args res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) return res if len(res) > 1 else res[0] @@ -1679,28 +2021,30 @@ def _unsqueeze_atleast( return unsqueeze(arg_, dim) +# CompositeImplicitAutograd - don't register decomp def atleast_2d( arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_2d`.""" - if not args and isinstance(arg, collections.Sequence): + if not args and isinstance(arg, collections.abc.Sequence): args_ = arg else: - assert not isinstance(arg, collections.Sequence) + assert not isinstance(arg, collections.abc.Sequence) args_ = (arg,) + args unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) return res if len(res) > 1 else res[0] +# CompositeImplicitAutograd - don't register decomp def atleast_3d( arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: """Reference implementation of :func:`torch.atleast_3d`.""" - if not args and isinstance(arg, collections.Sequence): + if not args and isinstance(arg, collections.abc.Sequence): args_ = arg else: - assert not isinstance(arg, collections.Sequence) + assert not isinstance(arg, collections.abc.Sequence) args_ = (arg,) + args unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) @@ -1721,6 +2065,7 @@ def broadcast_tensors(*tensors) -> List[TensorLikeType]: return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) +# CompositeImplicitAutograd - don't register decomp def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: start = len(size) - len(a.shape) dims = tuple(range(start, len(a.shape) + start)) @@ -1728,7 +2073,7 @@ def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: @register_decomposition(torch.ops.aten.cat) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("tensors",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, @@ -1762,7 +2107,8 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: return prims.cat(filtered, dim) -@out_wrapper +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() def column_stack(tensors: TensorSequenceType) -> TensorLikeType: aligned_tensors = tuple( x if x.ndim > 1 else prims.expand_dims(x, list(range(x.ndim, 2))) @@ -1771,13 +2117,145 @@ def column_stack(tensors: TensorSequenceType) -> TensorLikeType: return cat(aligned_tensors, 1) -@out_wrapper +def conj(input: TensorLikeType) -> TensorLikeType: + if not input.dtype.is_complex: + return input + if input.is_sparse: + return torch.conj_physical(input) + return prims.conj(input) + + +# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp +@register_decomposition(torch.ops.aten.constant_pad_nd) +def constant_pad_nd( + input: TensorLikeType, pad: List[int], value: NumberType = 0 +) -> TensorLikeType: + check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + c_input = input + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] < 0: + c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) + + if pad[pad_idx + 1] < 0: + c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) + + # if none of the pads are positive we can just return the result + if builtins.all(p <= 0 for p in pad): + return c_input.clone() + + new_shape = list(input_sizes[:l_diff]) + + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + check( + new_dim > 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + memory_format = utils.suggest_memory_format(input) + output = torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=memory_format, + ) + + if value == 0 and input.dtype == torch.bool: + value = False + # torch.fill isn't typed to allow complex values + output = torch.fill(output, value) # type: ignore[arg-type] + + c_output = output + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] > 0: + c_output = c_output.narrow( + i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] + ) + if pad[pad_idx + 1] > 0: + c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) + + prims.copy_to(c_output, c_input) + return output + + +def contiguous( + a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format +) -> Tensor: + check( + memory_format != torch.preserve_format, + lambda: "preserve memory format is unsupported by the contiguous operator", + ) + + if utils.is_contiguous_for_memory_format(a, memory_format=memory_format): + return a + + return torch.clone(a, memory_format=memory_format) + + +@out_wrapper() def dstack(tensors: TensorSequenceType) -> TensorLikeType: check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") aligned_tensors = atleast_3d(*tensors) return cat(aligned_tensors, 2) +@register_decomposition(torch.ops.aten.expand, disable_meta=True) +def expand(a: Tensor, *shape) -> Tensor: + # NOTE: cannot use utils.extract_shape_from_varargs here + # because that also validates the shape, but the shape + # given to expand may be "invalid" + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = tuple(shape[0]) + + check( + len(shape) >= len(a.shape), + lambda: "expand: the requested shape has too few dimensions!", + ) + + offset = len(shape) - len(a.shape) + shape_ = list(shape) + for idx, x in enumerate(a.shape): + offset_idx = idx + offset + requested_length = shape[offset_idx] + check( + requested_length == x or x == 1 or requested_length == -1, + lambda: f"expand: attempting to expand a dimension of length {x}!", + ) + + shape_[offset_idx] = requested_length if requested_length != -1 else x + + # At this point shape must be valid + utils.validate_shape(shape_) + + return prims.broadcast_in_dim( + a, shape_, tuple(range(offset, len(a.shape) + offset)) + ) + + def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]: if chunks <= 0: msg = "Expected at least one chunk, but got {0}!".format(chunks) @@ -1802,6 +2280,7 @@ def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, # Note: flatten, unlike prim.collapse and prim.collapse_view has an inclusive end_dim # Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless # a 0D tensor is flattened, in which case it's returned in 1D) +# CompositeImplicitAutograd - don't register decomp def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: start_dim = utils.canonicalize_dim(a.ndim, start_dim) end_dim = utils.canonicalize_dim(a.ndim, end_dim) @@ -1829,6 +2308,7 @@ def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: return prims.rev(a, dims) +# CompositeImplicitAutograd - don't register decomp def fliplr(a: TensorLikeType) -> TensorLikeType: if a.ndim < 2: raise RuntimeError("Input must be >= 2-d.") @@ -1836,6 +2316,7 @@ def fliplr(a: TensorLikeType) -> TensorLikeType: return flip(a, (1,)) +# CompositeImplicitAutograd - don't register decomp def flipud(a: TensorLikeType) -> TensorLikeType: if a.ndim < 1: raise RuntimeError("Input must be >= 1-d.") @@ -1843,6 +2324,7 @@ def flipud(a: TensorLikeType) -> TensorLikeType: return flip(a, (0,)) +# CompositeImplicitAutograd - don't register decomp def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: dim = utils.canonicalize_dim(a.ndim, dim) return prims.slice_in_dim(a, start, start + length, axis=dim) @@ -2083,6 +2565,7 @@ def _reshape_view_helper( # TODO: Turn this into a decomposition (currently fails on reshape meta tests) +# CompositeImplicitAutograd - don't register decomp def reshape(a: TensorLikeType, shape: ShapeType) -> TensorLikeType: return _reshape_view_helper(a, shape, allow_copy=True) @@ -2173,7 +2656,7 @@ def _check_stack_inputs(tensors: TensorSequenceType) -> None: @register_decomposition(torch.ops.aten.stack) -@out_wrapper +@out_wrapper() def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: assert len(tensors) > 0, "stack expects a non-empty TensorList" wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) @@ -2189,7 +2672,7 @@ def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) -@out_wrapper +@out_wrapper() def softmax( a: TensorLikeType, dim: int, @@ -2207,7 +2690,8 @@ def softmax( ) # type: ignore[return-value] -@out_wrapper +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() def hstack(tensors: TensorSequenceType) -> TensorLikeType: check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") aligned_tensors = atleast_1d(*tensors) @@ -2216,7 +2700,8 @@ def hstack(tensors: TensorSequenceType) -> TensorLikeType: return cat(aligned_tensors, 1) -@out_wrapper +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() def vstack(tensors: TensorSequenceType) -> TensorLikeType: check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") aligned_tensors = atleast_2d(*tensors) @@ -2224,6 +2709,7 @@ def vstack(tensors: TensorSequenceType) -> TensorLikeType: # Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(torch.ops.aten.squeeze, disable_meta=True) def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType: if dim is not None: dim = utils.canonicalize_dim(a.ndim, dim) @@ -2242,6 +2728,7 @@ def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType: # Note: does not work with TensorMetas because of data-dependent control-flow +# CompositeImplicitAutograd - don't register decomp def tensor_split( a: TensorLikeType, indices_or_sections: Union[Tensor, DimsType], @@ -2254,7 +2741,7 @@ def tensor_split( # If indices_or_sections is a tensor, it must be a CPU Long tensor if isinstance(indices_or_sections, TensorLike): - if indices_or_sections.device != torch.device("cpu"): + if not indices_or_sections.device.type == "cpu": msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format( indices_or_sections.device ) @@ -2318,6 +2805,7 @@ def tensor_split( return tuple(splits) +# CompositeImplicitAutograd - don't register decomp def hsplit( a: TensorLikeType, indices_or_sections: DimsType ) -> Tuple[TensorLikeType, ...]: @@ -2360,6 +2848,7 @@ def hsplit( return tensor_split(a, split_sizes, dim) +# CompositeImplicitAutograd - don't register decomp def vsplit( a: TensorLikeType, indices_or_sections: DimsType ) -> Tuple[TensorLikeType, ...]: @@ -2400,6 +2889,7 @@ def vsplit( return tensor_split(a, split_sizes, 0) +# CompositeImplicitAutograd - don't register decomp def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: if a.ndim < 3: raise RuntimeError( @@ -2413,7 +2903,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: return tensor_split(a, sections, 2) -@register_decomposition(torch.ops.aten.t.default) +@register_decomposition(torch.ops.aten.t.default, disable_meta=True) def t(a: TensorLikeType): # TODO: Add sparse support # if a.is_sparse: @@ -2447,7 +2937,7 @@ def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: swap_axes = transpose -@register_decomposition(torch.ops.aten.unsqueeze) +@register_decomposition(torch.ops.aten.unsqueeze, disable_meta=True) def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: # Note that unsqueeze canonicalizes with rank + 1 because it allows # a new innermost dimension to be specified @@ -2456,26 +2946,128 @@ def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: # TODO: Turn this into a decomposition (currently fails on reshape meta tests) +@register_decomposition(torch.ops.aten.view, disable_meta=True) def view(a: TensorLikeType, shape: ShapeType) -> TensorLikeType: return _reshape_view_helper(a, shape, allow_copy=False) +# CompositeImplicitAutograd - don't register decomp def ravel(a: TensorLikeType) -> TensorLikeType: return reshape(a, (-1,)) -@out_wrapper +@out_wrapper() def empty( *shape, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, + layout: Optional[torch.layout] = None, requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, ) -> TensorLikeType: + check( + memory_format != torch.preserve_format, + lambda: "torch.empty: the Preserve memory format is not supported", + ) + shape = utils.extract_shape_from_varargs(shape) - strides = utils.make_contiguous_strides_for(shape) - return empty_strided( - shape, strides, dtype=dtype, device=device, requires_grad=requires_grad + + if memory_format == torch.contiguous_format: + strides = utils.make_contiguous_strides_for(shape) + elif memory_format == torch.channels_last_3d: + strides = utils.make_channels_last_3d_strides_for(shape) + else: # memory_format == torch.channels_last + check( + memory_format == torch.channels_last, + lambda: f"torch.empty: received an unknown memory format {memory_format}!", + ) + strides = utils.make_channels_last_2d_strides_for(shape) + + return torch.empty_strided( + shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(torch.ops.aten.new_empty) +def new_empty( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, +) -> TensorLikeType: + + dtype = a.dtype if dtype is None else dtype + device = a.device if device is None else device + + return torch.empty( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(torch.ops.aten.new_zeros) +def new_zeros( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, +) -> TensorLikeType: + r = a.new_empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + r.zero_() + return r + + +@register_decomposition(torch.ops.aten.new_ones) +def new_ones( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, +) -> TensorLikeType: + r = a.new_empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory ) + r.fill_(1) + return r + + +@register_decomposition(torch.ops.aten.new_full) +def new_full( + a: TensorLikeType, + size: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, +) -> TensorLikeType: + r = a.new_empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + r.fill_(fill_value) # type: ignore[arg-type] + return r def empty_like( @@ -2483,43 +3075,271 @@ def empty_like( *, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, + layout: Optional[torch.layout] = None, requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.preserve_format, ) -> TensorLikeType: dtype = a.dtype if dtype is None else dtype device = a.device if device is None else device strides: Tuple[int, ...] - if a.numel() == 0: - strides = a.stride() + + if memory_format != torch.preserve_format: + return torch.empty( + a.shape, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + strides = utils.compute_elementwise_output_strides(a) + return torch.empty_strided( + a.shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@overload +def arange( + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + pass + + +@overload +def arange( + start: NumberType, + end: NumberType, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + pass + + +# See https://github.com/pytorch/pytorch/issues/82364 +# @register_decomposition(torch.ops.aten.arange) +# @out_wrapper() +@register_decomposition( + [ + torch.ops.aten.arange.default, + torch.ops.aten.arange.start, + torch.ops.aten.arange.start_step, + ] +) +def arange( + a: Optional[NumberType] = None, + b: Optional[NumberType] = None, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + assert (a is not None and b is not None) or (a is not None and b is None) + if a is not None and b is not None: + return prims.arange( + a, + b, + step, + dtype=dtype, + device=device, + # layout=layout, + # pin_memory=pin_memory, + requires_grad=requires_grad, + ) + elif a is not None and b is None: + return prims.arange( + 0, + a, + step, + dtype=dtype, + device=device, + # layout=layout, + # pin_memory=pin_memory, + requires_grad=requires_grad, + ) else: - strides = utils.compute_elementwise_output_strides(a) + raise AssertionError() + + +@register_decomposition(torch.ops.aten.linspace) +@out_wrapper() +def linspace( + start: NumberType, + end: NumberType, + steps: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy actually doesn't do this cast, but for this ref, I'd rather have this + # cast than not, because it allows us to always go into the precise path + # if dtype is integral and not worry about whether start/end are float + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + end = int(end) + + if py_any(isinstance(arg, complex) for arg in (start, end, steps)): + raise NotImplementedError + assert not isinstance(start, complex) and not isinstance(end, complex) # for mypy - return empty_strided( - a.shape, strides, dtype=dtype, device=device, requires_grad=requires_grad + check( + isinstance(steps, int), + lambda: "steps must be int, not float", + exc_type=TypeError, ) + assert isinstance(steps, int) # for mypy + check(steps >= 0, lambda: "number of steps must be non-negative") + + factory_kwargs = { + "device": device, + # "layout":layout, + # "pin_memory":pin_memory, + "requires_grad": requires_grad, + } + if steps == 0: + ret = torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] + elif steps == 1: + ret = torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] + elif start == end: + ret = torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] + else: + if prims.utils.is_integer_dtype(dtype): + # We need to cast to int, so to avoid off-by-one issues + # do the entire computation with ints when we can + assert isinstance(start, int) and isinstance(end, int) + step_size_x_denom = end - start + eps = 1 if end > start else -1 + denom = steps - 1 + ret = prims.to_dtype( + torch.arange( + start * denom, + end * denom + eps, + step_size_x_denom, + dtype=torch.int64, + **factory_kwargs, # type: ignore[arg-type] + ) + / denom, + dtype, + ) + else: + step_size = (end - start) / (steps - 1) + eps = step_size / 2 + ret = prims.to_dtype( + torch.arange( # type: ignore[call-overload] + start, end + eps, step_size, dtype=torch.float64, **factory_kwargs + ), + dtype, + ) + + return ret + + +@register_decomposition(torch.ops.aten.logspace) +@out_wrapper() +def logspace( + start: NumberType, + end: NumberType, + steps: NumberType, + base: NumberType = 10, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy doesn't have this cast + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + end = int(end) + + assert not isinstance(base, complex) # for mypy + if base < 0: + raise NotImplementedError + ret = torch.linspace( + start, + end, + steps, + dtype=torch.float64, + device=device, + layout=layout, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return prims.to_dtype(torch.pow(base, ret), dtype) # NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints +@register_decomposition(torch.ops.aten.empty_strided) def empty_strided( shape: Union[ShapeType, Tuple[ShapeType]], strides: StrideType, *, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, + layout: Optional[torch.layout] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> TensorLikeType: + if pin_memory: + raise NotImplementedError("PrimTorch doesn't support pinned memory") + if layout is not None and layout is not torch.strided: + raise NotImplementedError(f"PrimTorch doesn't support layout={layout}") + shape = utils.extract_shape_from_varargs(shape) dtype = torch.get_default_dtype() if dtype is None else dtype device = torch.device("cpu") if device is None else device return prims.empty_strided( - shape, strides, dtype=dtype, device=device, requires_grad=requires_grad + shape, + strides, + dtype=dtype, + device=device, + requires_grad=requires_grad, ) -@out_wrapper +# TODO: missing kwargs (e.g. layout) +@out_wrapper() def full( shape: ShapeType, fill_value: NumberType, @@ -2548,11 +3368,25 @@ def full_like( ones_like = partial(full_like, fill_value=True) + +# TODO: missing kwargs (e.g. layout) +def scalar_tensor( + a: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +) -> TensorLikeType: + dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) + device = device if device is not None else torch.device("cpu") + return prims.scalar_tensor(a, dtype=dtype, device=device) + + zeros = partial(full, fill_value=False) zeros_like = partial(full_like, fill_value=False) +@register_decomposition(torch.ops.aten.uniform) def uniform( shape: ShapeType, low: Union[bool, int, float] = 0.0, @@ -2607,6 +3441,24 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi return where(mask, prims.to_dtype(value, a.dtype), a) +# CompositeImplicitAutograd - don't register decomp +def allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + return bool( + torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + ) + + # TODO: add OpInfo for torch.equal and refs.equal def equal(a: TensorLikeType, b: TensorLikeType) -> bool: utils.check_same_device(a, b, allow_cpu_scalar_tensors=False) @@ -2627,6 +3479,32 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool: return item(all(eq(a, b))) # type: ignore[return-value] +@out_wrapper(exact_dtype=True) +def norm( + input: TensorLikeType, + p: Optional[Union[float, str]] = "fro", + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # In these cases we compute the "Frobenius norm" + if ( + p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2) + ) or p is None: + p = 2 + if isinstance(dim, int): + dim = [dim] + if isinstance(p, str): + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if dim is None: + dim = tuple(range(input.ndim)) + return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype) + + @register_decomposition(torch.ops.aten.trace) def trace(self: TensorLikeType) -> TensorLikeType: utils.check( @@ -2635,5 +3513,7 @@ def trace(self: TensorLikeType) -> TensorLikeType: return torch.sum(torch.diag(self, 0)) +import torch._refs.fft +import torch._refs.linalg import torch._refs.nn.functional import torch._refs.special diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py new file mode 100644 index 0000000000000..d92ef6914c2d1 --- /dev/null +++ b/torch/_refs/fft.py @@ -0,0 +1,571 @@ +import math + +from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple, Union + +from typing_extensions import Literal + +import torch +import torch._prims as prims +import torch._prims_common as utils +from torch._decomp import register_decomposition +from torch._prims_common import check, DimsType, ShapeType, TensorLikeType +from torch._prims_common.wrappers import out_wrapper + +__all__ = [ + # Transforms + "fft", + "fft2", + "fftn", + "hfft", + "hfft2", + "hfftn", + "rfft", + "rfft2", + "rfftn", + "ifft", + "ifft2", + "ifftn", + "ihfft", + "ihfft2", + "ihfftn", + "irfft", + "irfft2", + "irfftn", + # Helpers + "fftshift", + "ifftshift", +] + +NormType = Union[None, Literal["forward"], Literal["backward"], Literal["ortho"]] +_NORM_VALUES = {None, "forward", "backward", "ortho"} + + +def _apply_norm( + x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool +) -> TensorLikeType: + """Apply normalization to the un-normalized FFT result""" + check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + + if norm == "ortho": + return x * (1 / math.sqrt(signal_numel)) + + normalize = (not forward and (norm is None or norm == "backward")) or ( + forward and norm == "forward" + ) + return x * (1 / signal_numel) if normalize else x + + +def _promote_type_fft(dtype: torch.dtype, require_complex: bool) -> torch.dtype: + """Helper to promote a dtype to one supported by the FFT primitives""" + if dtype.is_complex: + return dtype + + # Promote integral to default float type + if not dtype.is_floating_point: + dtype = torch.get_default_dtype() + + if require_complex: + dtype = utils.corresponding_complex_dtype(dtype) + + return dtype + + +def _maybe_promote_tensor_fft( + t: TensorLikeType, require_complex: bool = False +) -> TensorLikeType: + """Helper to promote a tensor to a dtype supported by the FFT primitives""" + cur_type = t.dtype + new_type = _promote_type_fft(cur_type, require_complex) + if cur_type == new_type: + return t + return prims.convert_element_type(t, new_type) + + +def _resize_fft_input( + x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...] +) -> TensorLikeType: + """ + Fixes the shape of x such that x.size(dims[i]) == sizes[i], + either by zero-padding, or by slicing x starting from 0. + """ + assert len(dims) == len(sizes) + must_copy = False + x_sizes = x.shape + pad_amount = [0] * len(x_sizes) * 2 + for i in range(len(dims)): + if sizes[i] == -1: + continue + + if x_sizes[dims[i]] < sizes[i]: + must_copy = True + pad_idx = len(pad_amount) - 2 * dims[i] - 1 + pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] + + if x_sizes[dims[i]] > sizes[i]: + x = x.narrow(dims[i], 0, sizes[i]) + + return torch.constant_pad_nd(x, pad_amount) if must_copy else x + + +def _fft_c2r( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to real FFT (irfft or hfft)""" + input = _maybe_promote_tensor_fft(input, require_complex=True) + dims = (utils.canonicalize_dim(input.ndim, dim),) + last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) + check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified") + + if n is not None: + input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) + + if forward: + input = torch.conj(input) + + output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size) + return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward) + + +def _fft_r2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, + onesided: bool, +) -> TensorLikeType: + """Common code for performing any real to complex FFT (rfft or ihfft)""" + check( + not input.dtype.is_complex, + lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", + ) + input = _maybe_promote_tensor_fft(input) + dims = (utils.canonicalize_dim(input.ndim, dim),) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_r2c(input, dim=dims, onesided=onesided) + ret = _apply_norm(ret, norm, input.shape[dim], forward) + return ret if forward else torch.conj(ret) + + +def _fft_c2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to complex FFT (fft or ifft)""" + check( + input.dtype.is_complex, + lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", + ) + dims = (utils.canonicalize_dim(input.ndim, dim),) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_c2c(input, dim=dims, forward=forward) + return _apply_norm(ret, norm, input.shape[dim], forward) + + +@register_decomposition(torch.ops.aten.fft_fft) +@out_wrapper() +def fft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("fft", input, n, dim, norm, forward=True) + else: + return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False) + + +@register_decomposition(torch.ops.aten.fft_ifft) +@out_wrapper() +def ifft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("ifft", input, n, dim, norm, forward=False) + else: + return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False) + + +@register_decomposition(torch.ops.aten.fft_rfft) +@out_wrapper() +def rfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) + + +@register_decomposition(torch.ops.aten.fft_irfft) +@out_wrapper() +def irfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("irfft", input, n, dim, norm, forward=False) + + +@register_decomposition(torch.ops.aten.fft_hfft) +@out_wrapper() +def hfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("hfft", input, n, dim, norm, forward=True) + + +@register_decomposition(torch.ops.aten.fft_ihfft) +@out_wrapper() +def ihfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) + + +class _ShapeAndDims(NamedTuple): + shape: Tuple[int, ...] + dims: Tuple[int, ...] + + +def _canonicalize_fft_shape_and_dim_args( + input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] +) -> _ShapeAndDims: + """Convert the shape and dim arguments into a canonical form where neither are optional""" + input_dim = input.ndim + input_sizes = input.shape + + if dim is not None: + if not isinstance(dim, Sequence): + dim = (dim,) + ret_dims = utils.canonicalize_dims(input_dim, dim) + + # Check dims are unique + check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique") + + if shape is not None: + if not isinstance(shape, Sequence): + shape = (shape,) + + # Has shape, might have dim + check( + dim is None or len(dim) == len(shape), + lambda: "When given, dim and shape arguments must have the same length", + ) + transform_ndim = len(shape) + + check( + transform_ndim <= input_dim, + lambda: f"Got shape with {transform_ndim} values but input tensor " + f"only has {input_dim} dimensions.", + ) + + # If shape is given, dims defaults to the last len(shape) dimensions + if dim is None: + ret_dims = tuple(range(input_dim - transform_ndim, input_dim)) + + # Translate any -1 values in shape to the default length + ret_shape = tuple( + s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) + ) + elif dim is None: + # No shape, no dim + ret_dims = tuple(range(input_dim)) + ret_shape = tuple(input_sizes) + else: + # No shape, has dim + ret_shape = tuple(input_sizes[d] for d in ret_dims) + + for n in ret_shape: + check(n > 0, lambda: f"Invalid number of data points ({n}) specified") + + return _ShapeAndDims(shape=ret_shape, dims=ret_dims) + + +def _prod(xs: Iterable[int]) -> int: + """Compute product of a list""" + prod = 1 + for x in xs: + prod *= x + return prod + + +def _fftn_c2c( + function_name: str, + input: TensorLikeType, + shape: Tuple[int, ...], + dim: Tuple[int, ...], + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" + check( + input.dtype.is_complex, + lambda: f"{function_name} expects a complex input tensor, " + f"but got {input.dtype}", + ) + x = _resize_fft_input(input, dim, shape) + output = prims.fft_c2c(x, dim=dim, forward=forward) + return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward) + + +@register_decomposition(torch.ops.aten.fft_fftn) +@out_wrapper() +def fftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("fftn", x, shape, dim, norm, forward=True) + + +@register_decomposition(torch.ops.aten.fft_ifftn) +@out_wrapper() +def ifftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) + + +@register_decomposition(torch.ops.aten.fft_rfftn) +@out_wrapper() +def rfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + check( + not input.dtype.is_complex, + lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_r2c(input, dim=dim, onesided=True) + return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) + + +@register_decomposition(torch.ops.aten.fft_ihfftn) +@out_wrapper() +def ihfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + check( + not input.dtype.is_complex, + lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) + + if len(dim) == 1: + tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) + return prims.conj(tmp) + + tmp = prims.conj_physical(tmp) + tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) + return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) + + +class _CanonicalizeC2rReturn(NamedTuple): + shape: Tuple[int, ...] + dim: Tuple[int, ...] + last_dim_size: int + + +def _canonicalize_fft_c2r_shape_and_dim_args( + fname: str, + input: TensorLikeType, + s: Optional[ShapeType], + dim: Optional[DimsType], +) -> _CanonicalizeC2rReturn: + """Canonicalize shape and dim arguments for n-dimensional c2r transforms, + as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") + + if s is None or s[-1] == -1: + last_dim_size = 2 * (input.shape[dim[-1]] - 1) + else: + last_dim_size = shape[-1] + + check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + shape_list = list(shape) + shape_list[-1] = last_dim_size // 2 + 1 + return _CanonicalizeC2rReturn( + shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size + ) + + +@register_decomposition(torch.ops.aten.fft_irfftn) +@out_wrapper() +def irfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "irfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size) + return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False) + + +@register_decomposition(torch.ops.aten.fft_hfftn) +@out_wrapper() +def hfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "hfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input + tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) + tmp = prims.conj_physical(tmp) + out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) + return _apply_norm(out, norm, last_dim_size, forward=True) + + +@register_decomposition(torch.ops.aten.fft_fft2) +@out_wrapper() +def fft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.fftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(torch.ops.aten.fft_ifft2) +@out_wrapper() +def ifft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(torch.ops.aten.fft_rfft2) +@out_wrapper() +def rfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(torch.ops.aten.fft_irfft2) +@out_wrapper() +def irfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(torch.ops.aten.fft_hfft2) +@out_wrapper() +def hfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.hfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(torch.ops.aten.fft_ihfft2) +@out_wrapper() +def ihfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) + + +def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]: + """Convert Optional[DimsType] to a simple list, defaulting to all dimensions""" + if dim is None: + return list(range(x.ndim)) + elif not isinstance(dim, Sequence): + return [dim] + else: + return list(dim) + + +@register_decomposition(torch.ops.aten.fft_fftshift) +def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [input.shape[d] // 2 for d in dims] + return torch.roll(input, shift, dims) + + +@register_decomposition(torch.ops.aten.fft_ifftshift) +def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [(input.shape[d] + 1) // 2 for d in dims] + return torch.roll(input, shift, dims) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000..c3b8a3c603524 --- /dev/null +++ b/torch/_refs/linalg/__init__.py @@ -0,0 +1,255 @@ +from functools import partial + +from typing import List, Optional, Tuple, Union + +import torch + +import torch._prims as prims + +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check, + check_fp_or_complex, + check_is_matrix, + DimsType, + NumberType, + TensorLikeType, +) +from torch._prims_common.wrappers import out_wrapper + +__all__ = [ + "svd", + "vector_norm", + "matrix_norm", + "norm", +] + + +def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + "without narrowing to the specified dtype ({dtype})", + ) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + + +@register_decomposition(torch.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: float = 2.0, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + # Checks + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, int): + dim = [dim] # type: ignore[assignment] + elif not isinstance(dim, List) and dim is not None: + # refs.amin just accepts List rather than DimType (Tuple) + dim = list(dim) # type: ignore[assignment] + + if x.numel() == 0 and (ord < 0.0 or ord == float("inf")): + check( + dim is not None and len(dim) != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + shape = x.shape + assert dim is not None # mypy does not seem to be able to see through check? + for d in dim: + check( + shape[d] != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim)) + elif ord == float("-inf"): + return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim)) + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = prims.convert_element_type(x, computation_dtype) + reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim) + + if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) + + +def backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, int): + dim = (dim,) # type: ignore[assignment] + check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}") + check( + dim[0] != dim[1], + lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + check( + ord in ("fro", "nuc"), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = prims.convert_element_type(A, dtype) + perm = backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + check( + abs_ord in (2, 1, float("inf")), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + if abs_ord == 2.0: + if dtype is not None: + A = prims.convert_element_type(A, dtype) + perm = backshift_permutation(dim[0], dim[1], A.ndim) + result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return max_min( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, int): + dim = (dim,) # type: ignore[assignment] + check( + len(dim) in (1, 2), + lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + check( + A.ndim in (1, 2), + lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index d80570242cac0..702589e72ad2e 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,27 +1,29 @@ +from typing import Optional, Union + import torch -import torch._prims.utils as utils -from torch._prims.utils import ( +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + check, + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, ShapeType, TensorLike, TensorLikeType, - NumberType, - ELEMENTWISE_TYPE_PROMOTION_KIND, ) -import torch._refs as refs -from torch._decomp import register_decomposition -from torch._prims.wrappers import ( +from torch._prims_common.wrappers import ( elementwise_type_promotion_wrapper, elementwise_unary_scalar_wrapper, out_wrapper, ) from torch._refs import ( - _make_elementwise_unary_reference, _make_elementwise_binary_reference, + _make_elementwise_unary_reference, ) -from typing import Optional - __all__ = [ "celu", "dropout", @@ -29,13 +31,18 @@ "hardshrink", "hardtanh", "hinge_embedding_loss", + "l1_loss", "margin_ranking_loss", "mish", + "mse_loss", + "prelu", "relu", + "relu6", "selu", "softplus", "softshrink", "tanhshrink", + "threshold", ] Tensor = torch.Tensor @@ -75,6 +82,7 @@ def celu( # TODO: should we allow the user to set a different dtype for the mask generation? +@register_decomposition(torch.ops.aten.dropout) def dropout( a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False ) -> TensorLikeType: @@ -201,6 +209,7 @@ def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: return a * torch.tanh(torch.nn.functional.softplus(a)) +@register_decomposition(torch.ops.aten.selu) @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -222,7 +231,7 @@ def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: # softplus is implemented specially because it has beta and threshold arguments @register_decomposition(torch.ops.aten.softplus) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -258,21 +267,27 @@ def softplus( return torch.where(scaled_input > threshold, a, rhs) -@out_wrapper +@register_decomposition(torch.ops.aten.hardshrink) +@out_wrapper() def hardshrink(a: TensorLikeType, lambd: float = 0.5): # Formula for reference, # hardshrink(x) = x if x > lambd # = x if x < -lambd # = 0 otherwise - return refs.where(abs(a) > abs(lambd), a, 0) + return refs.where(refs.logical_and(a >= -lambd, a <= lambd), 0, a) -@out_wrapper +@register_decomposition(torch.ops.aten.softshrink) +@out_wrapper() def softshrink(a: TensorLikeType, lambd: float = 0.5): # Formula for reference, # softshrink(x) = x - lambd if x > lambd # = x + lambd if x < -lambd # = 0 otherwise + check( + lambd >= 0, + lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", + ) ge_mask = a > lambd le_mask = a < -lambd zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask)) @@ -296,6 +311,42 @@ def _check_reduction_value(reduction: str): raise ValueError("{} is not a valid value for reduction".format(reduction)) +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +@register_decomposition(torch.ops.aten.l1_loss) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(torch.ops.aten.margin_ranking_loss) def margin_ranking_loss( input1: TensorLikeType, input2: TensorLikeType, @@ -323,6 +374,23 @@ def margin_ranking_loss( return _apply_loss_reduction(loss, reduction) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(torch.ops.aten.hinge_embedding_loss) def hinge_embedding_loss( input: TensorLikeType, target: TensorLikeType, @@ -357,6 +425,27 @@ def tanhshrink(a: TensorLikeType) -> TensorLikeType: return refs.sub(a, refs.tanh(a)) +@register_decomposition(torch.ops.aten.threshold) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + @register_decomposition(torch.ops.aten.hardtanh) @elementwise_unary_scalar_wrapper @elementwise_type_promotion_wrapper( @@ -389,7 +478,7 @@ def hardtanh( @register_decomposition(torch.ops.aten.gelu) -@out_wrapper +@out_wrapper() @elementwise_unary_scalar_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("a",), @@ -417,3 +506,56 @@ def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: return a * 0.5 * (1 + torch.erf(a * kAlpha)) else: raise RuntimeError("approximate argument must be either none or tanh.") + + +@register_decomposition(torch.ops.aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + weight = prims.broadcast_in_dim( + weight, a.shape, tuple() if weight.ndim == 0 else (1,) + ) + + return refs.where(a > 0, a, a * weight) + + +@register_decomposition(torch.ops.aten.relu6) +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return refs.nn.functional.hardtanh(a, 0, 6) diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 4aa2a278f112f..1c75522d9be17 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -1,17 +1,18 @@ -import torch - -from torch import Tensor from typing import Optional + +import torch import torch._prims as prims -import torch._prims.utils as utils +import torch._prims_common as utils import torch._refs as refs -from torch._prims.utils import TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims.wrappers import out_wrapper, elementwise_type_promotion_wrapper + +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, TensorLikeType +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper from torch._refs import ( - _make_elementwise_unary_reference, _make_elementwise_binary_reference, + _make_elementwise_unary_reference, ) -from torch._decomp import register_decomposition __all__ = [ @@ -45,7 +46,7 @@ def i1e(a): @register_decomposition(torch.ops.aten.logit) -@out_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("self",), type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, diff --git a/torch/_six.py b/torch/_six.py index e288a1199b771..7ccc12f6bc5dd 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -24,13 +24,13 @@ nan = math.nan string_classes = (str, bytes) + def with_metaclass(meta: type, *bases) -> type: """Create a base class with a metaclass.""" # This requires a bit of explanation: the basic idea is to make a dummy # metaclass for one level of class instantiation that replaces itself with # the actual metaclass. class metaclass(meta): # type: ignore[misc, valid-type] - def __new__(cls, name, this_bases, d): return meta(name, bases, d) @@ -38,4 +38,4 @@ def __new__(cls, name, this_bases, d): def __prepare__(cls, name, this_bases): return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) + return type.__new__(metaclass, "temporary_class", (), {}) diff --git a/torch/_sources.py b/torch/_sources.py index a7a87481d3e5a..23d7338114dc9 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -2,10 +2,12 @@ import functools import inspect from textwrap import dedent -from typing import Any, Optional, Tuple, List, NamedTuple +from typing import Any, List, NamedTuple, Optional, Tuple + from torch._C import ErrorReport from torch._C._jit_tree_views import SourceRangeFactory + def get_source_lines_and_file( obj: Any, error_msg: Optional[str] = None, @@ -20,11 +22,13 @@ def get_source_lines_and_file( filename = inspect.getsourcefile(obj) sourcelines, file_lineno = inspect.getsourcelines(obj) except OSError as e: - msg = (f"Can't get source for {obj}. TorchScript requires source access in " - "order to carry out compilation, make sure original .py files are " - "available.") + msg = ( + f"Can't get source for {obj}. TorchScript requires source access in " + "order to carry out compilation, make sure original .py files are " + "available." + ) if error_msg: - msg += '\n' + error_msg + msg += "\n" + error_msg raise OSError(msg) from e return sourcelines, file_lineno, filename @@ -45,7 +49,7 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]: """ def remove_prefix(text, prefix): - return text[text.startswith(prefix) and len(prefix):] + return text[text.startswith(prefix) and len(prefix) :] # Find the line and line number containing the function definition idx = None @@ -65,8 +69,12 @@ def remove_prefix(text, prefix): whitespace = fn_def.split("def")[0] # Add this leading whitespace to all lines before and after the `def` - aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]] - aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]] + aligned_prefix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx] + ] + aligned_suffix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :] + ] # Put it together again aligned_prefix.append(fn_def) @@ -76,8 +84,18 @@ def remove_prefix(text, prefix): # Thin wrapper around SourceRangeFactory to store extra metadata # about the function-to-be-compiled. class SourceContext(SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True, funcname=None): - super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len) + def __init__( + self, + source, + filename, + file_lineno, + leading_whitespace_len, + uses_true_division=True, + funcname=None, + ): + super(SourceContext, self).__init__( + source, filename, file_lineno, leading_whitespace_len + ) self.uses_true_division = uses_true_division self.filename = filename self.funcname = funcname @@ -89,7 +107,7 @@ def make_source_context(*args): def fake_range(): - return SourceContext('', None, 0, 0).make_raw_range(0, 1) + return SourceContext("", None, 0, 0).make_raw_range(0, 1) class ParsedDef(NamedTuple): @@ -99,14 +117,23 @@ class ParsedDef(NamedTuple): filename: Optional[str] file_lineno: int + def parse_def(fn): - sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack()) + sourcelines, file_lineno, filename = get_source_lines_and_file( + fn, ErrorReport.call_stack() + ) sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) + source = "".join(sourcelines) dedent_src = dedent(source) py_ast = ast.parse(dedent_src) if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): - raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") - leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) - ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True, fn.__name__) + raise RuntimeError( + f"Expected a single top-level function: {filename}:{file_lineno}" + ) + leading_whitespace_len = len(source.split("\n", 1)[0]) - len( + dedent_src.split("\n", 1)[0] + ) + ctx = make_source_context( + source, filename, file_lineno, leading_whitespace_len, True, fn.__name__ + ) return ParsedDef(py_ast, ctx, source, filename, file_lineno) diff --git a/torch/_storage_docs.py b/torch/_storage_docs.py index 15461d263b534..a00ffc2c6f360 100644 --- a/torch/_storage_docs.py +++ b/torch/_storage_docs.py @@ -5,7 +5,7 @@ storage_classes = [ - 'StorageBase', + "StorageBase", ] @@ -18,8 +18,9 @@ def add_docstr_all(method, docstr): pass -add_docstr_all('from_file', - """ +add_docstr_all( + "from_file", + """ from_file(filename, shared=False, size=0) -> Storage If `shared` is `True`, then memory is shared between all processes. @@ -35,4 +36,5 @@ def add_docstr_all(method, docstr): filename (str): file name to map shared (bool): whether to share memory size (int): number of elements in the storage -""") +""", +) diff --git a/torch/_subclasses/__init__.py b/torch/_subclasses/__init__.py index 2aea697b1e3b6..85ea330182872 100644 --- a/torch/_subclasses/__init__.py +++ b/torch/_subclasses/__init__.py @@ -1,6 +1,11 @@ import torch -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, UnsupportedFakeTensorException, DynamicOutputShapeException +from torch._subclasses.fake_tensor import ( + DynamicOutputShapeException, + FakeTensor, + FakeTensorMode, + UnsupportedFakeTensorException, +) __all__ = [ "FakeTensor", diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 1cd53722f5957..e4dbc6c70aa4a 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,18 +1,21 @@ -import torch - -from torch.utils._pytree import tree_map, tree_flatten -from functools import partial -from torch.fx.operator_schemas import normalize_function -from torch.utils._mode_utils import no_dispatch -from torch._subclasses.meta_utils import MetaConverter -from typing import Union, Callable -from torch._ops import OpOverload -from torch.overrides import TorchFunctionMode -from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode -import weakref +import contextlib import functools import itertools +import weakref from dataclasses import dataclass +from functools import partial +from typing import Callable, Union + +import torch +import torch.fx.experimental.symbolic_shapes as symbolic_shapes +from torch._ops import OpOverload +from torch._subclasses.meta_utils import MetaConverter, WeakTensorRefKey +from torch.fx.operator_schemas import normalize_function +from torch.overrides import TorchFunctionMode +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode + +from torch.utils._pytree import tree_flatten, tree_map aten = torch.ops.aten @@ -22,6 +25,7 @@ class UnsupportedFakeTensorException(RuntimeError): reason: str + @dataclass class DynamicOutputShapeException(RuntimeError): func: OpOverload @@ -35,12 +39,12 @@ class DynamicOutputShapeException(RuntimeError): aten.to.device, aten.to.prim_Device, aten._pin_memory.default, - aten._resize_output.functional, + aten._resize_output.default, aten._resize_output.out, ) # this op is never actually used -_non_kwarg_device_constructors = (torch.ops.aten._list_to_tensor,) +_non_kwarg_device_constructors = (aten._list_to_tensor,) def contains_tensor_types(type): @@ -82,8 +86,8 @@ def _is_tensor_constructor(func: OpOverload): # Similar to `MetaConverter`, this is a class for converting # multiple tensors into fake tensors which share the same view/storage -# structure. Like `MetaConverter`, it will keep alive all -# tensors that are converted to FakeTensors. +# structure. Like `MetaConverter`, it uses `WeakTensorRefKey` to +# hold a weak reference for all memoized tensors. class FakeTensorConverter(object): tensor_memo: weakref.WeakValueDictionary meta_converter: MetaConverter @@ -96,20 +100,35 @@ def __init__(self): self.meta_converter = MetaConverter() def _get_memo(self, t): - if t in self.tensor_memo: - out = self.tensor_memo[t] + if WeakTensorRefKey(t) in self.tensor_memo: + out = self.tensor_memo[WeakTensorRefKey(t)] out._fix_weakref() return out return None + def set_tensor_memo(self, t, v): + th = WeakTensorRefKey(t) + + # hold a weak ref to self, otherwise it will be kept alive + # by the del_ten closure + self_weak_ref = weakref.ref(self) + + def del_ten(): + self_ref = self_weak_ref() + if self_ref is None: + return + # on shutdown, th may not be in memo + self_ref.tensor_memo.pop(th, None) + + weakref.finalize(t, del_ten) + self.tensor_memo[th] = v + def from_real_tensor(self, fake_mode, t): maybe_memo = self._get_memo(t) if maybe_memo is not None: return maybe_memo existing_device = t.device # not yet supported in metatensors - if t.is_complex(): - raise UnsupportedFakeTensorException("complex nyi in meta tensors") if t.is_sparse: raise UnsupportedFakeTensorException("sparse nyi in meta tensors") if t.is_quantized: @@ -118,7 +137,9 @@ def from_real_tensor(self, fake_mode, t): out = FakeTensor(fake_mode, self.meta_converter(t), existing_device) if type(t) is torch.nn.Parameter: out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment] - self.tensor_memo[t] = out + if t.grad is not None: + out.grad = self.from_real_tensor(fake_mode, t.grad) + self.set_tensor_memo(t, out) return out def from_meta_and_device(self, fake_mode, t, device): @@ -126,7 +147,7 @@ def from_meta_and_device(self, fake_mode, t, device): if maybe_memo is not None: return maybe_memo out = FakeTensor(fake_mode, t, device) - self.tensor_memo[t] = out + self.set_tensor_memo(t, out) return out def __call__(self, fake_mode, t, device=None): @@ -152,8 +173,11 @@ def impl_decorator(op_impl): return impl_decorator -@register_op_impl(lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors)) -def contructors(fake_mode, func, *args, **kwargs): + +@register_op_impl( + lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors) +) +def constructors(fake_mode, func, *args, **kwargs): assert func not in _non_kwarg_device_constructors _, new_kwargs = normalize_function( func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True @@ -172,6 +196,7 @@ def contructors(fake_mode, func, *args, **kwargs): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) + @register_op_impl(lambda func: func in (aten.to.prim_Device, aten.to.device)) def non_kwarg_to(fake_mode, func, *args, **kwargs): _, new_kwargs = normalize_function( @@ -193,7 +218,7 @@ def resize_as_(fake_mode, func, *args, **kwargs): # _to_copy fails when run with FakeTensors to cuda device # TODO: debug -@register_op_impl(torch.ops.aten._to_copy.default) +@register_op_impl(aten._to_copy.default) def to_copy(fake_mode, func, *args, **kwargs): _, new_kwargs = normalize_function( func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True @@ -203,23 +228,26 @@ def to_copy(fake_mode, func, *args, **kwargs): out_device = input_device if input_device else new_kwargs["input"].device with no_dispatch(): input = new_kwargs.pop("input").to("meta") - return FakeTensor( - fake_mode, torch.ops.aten._to_copy(input, **new_kwargs), out_device - ) + return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device) + -@register_op_impl(torch.ops.aten.clone.default) +@register_op_impl(aten.clone.default) def clone(fake_mode, func, input, memory_format=None): out_device = input.device with no_dispatch(): - out = torch.ops.aten._to_copy(input.to("meta"), memory_format=memory_format) + out = aten._to_copy(input.to("meta"), memory_format=memory_format) return FakeTensor(fake_mode, out, out_device) + # index.Tensor data-dependent in only some conditions -@register_op_impl(lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined] - and func != aten.index.Tensor) +@register_op_impl( + lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined] + and func != aten.index.Tensor +) def data_dep_op(fake_mode, func, *args, **kwargs): raise DynamicOutputShapeException(func) + # Bool Indices get Expanded as Masks # See: IndexingUtils.h:expandTensors def check_no_bool_index_tensors(func, self, indices): @@ -227,6 +255,49 @@ def check_no_bool_index_tensors(func, self, indices): if index is not None and index.dtype in (torch.bool, torch.uint8): raise DynamicOutputShapeException(func) + +def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + + return FakeTensor(fake_mode, out, out_device) + + +# Dont default to default device handling, +# Since op can take in non-zero sized cpu +# index tensors with cuda self +@register_op_impl(aten.index.Tensor) +def index_tensor(fake_mode, func, *args, **kwargs): + # dynamic shape op if indices are bool/uint8 + check_no_bool_index_tensors(func, *args, **kwargs) + + return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + + +# takes in multiple-devices, dont default to default device handling +@register_op_impl(aten.index_put.default) +def index_put(fake_mode, func, *args, **kwargs): + return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + + +# same with index_put, but return the input +@register_op_impl(aten.index_put_.default) +def index_put_(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + return new_kwargs["input"] + + # Meta tensors give you the ability to run PyTorch code without having to # actually do computation through tensors allocated on a `meta` device. # Because the device is `meta`, meta tensors do not model device propagation. @@ -234,23 +305,57 @@ def check_no_bool_index_tensors(func, self, indices): # which tracks devices that would have been used. +@contextlib.contextmanager +def in_kernel_invocation_manager(fake_mode): + fake_mode.in_kernel_invocation = True + # See: note [Fake Tensor Dispatch Keys] + torch._C._add_meta_to_tls_dispatch_include() + try: + yield + finally: + fake_mode.in_kernel_invocation = False + torch._C._remove_meta_from_tls_dispatch_include() + + class FakeTensor(torch.Tensor): fake_device: torch.device fake_mode: "FakeTensorMode" + has_sym_ints: bool + + # Note: [Fake Tensor Dispatch Keys] + # In order to model the behavior of device-specific autocast + # and autograd logic, we update the dispatch keys of FakeTensors + # to reflect their fake device. This includes the BackendComponent + # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent + # related Autocast and Autograd keys. __torch__dispatch__ sits below + # Autocast and Autograd, and is only invoked when we are at the + # kernel for the BackendComponent. Then, we add Meta to the + # thread-local dispatch include set to hit the meta kernel + # instead of the kernel of the BackendComponent for the fake device. + # The `device_for_backend_keys` does that below @staticmethod def __new__(cls, fake_mode, elem, device): return torch.Tensor._make_subclass( - cls, elem, elem.requires_grad, dispatch_device=True + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, ) def __init__(self, fake_mode, elem, device: Union[torch.device, str]): # elem does not need to be recorded, because FakeTensor *is a* elem - assert elem.device.type == "meta" + assert elem.device.type == "meta", elem device = device if isinstance(device, torch.device) else torch.device(device) + # normalize cuda device + if device.type == "cuda" and device.index is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") assert device.type != "meta" + self.fake_device = device self.fake_mode = fake_mode + self.has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem) @staticmethod def from_tensor(t, fake_mode): @@ -261,24 +366,32 @@ def from_tensor(t, fake_mode): def __repr__(self): return f"FakeTensor({self.fake_device}, {self.size()}, {self.dtype})" + def stride(self): + if self.has_sym_ints: + # TODO: As we currently don't support symbolic strides, we'll assume contiguous strides + # The reason this needs to be here instead of __torch_dispatch__ is that + # when aten.stride goes into __torch_dispatch__, it expects a list of + # concrete ints to be returned. So we need to short-circuit that entirely + return symbolic_shapes.create_contiguous(self.shape) + return super().stride() + def new(self, *args, **kwargs): # torch.Tensor.new does not go through the normal dispatcher pattern # so in order to use the same pattern as normal invocation of # returning meta device within the kernel we need to intercept # the call here + # because it doesn't go through the dispatcher, we run into errors + # when attempting to compute an output in meta, so + # we compute the real tensor then convert to meta out_device = self.fake_device - if "device" in kwargs: - kwarg_device = kwargs.pop("device") - out_device = kwarg_device if kwarg_device else out_device - kwargs["device"] = "meta" - self.in_kernel_invocation = True - try: - with no_dispatch(): - meta_out = super().new(*args, **kwargs) - finally: - self.in_kernel_invocation = False + with no_dispatch(): + real_out = super().new(*args, **kwargs) + + assert not isinstance(real_out, FakeTensor), real_out + assert real_out.device.type != "meta", real_out.device with no_dispatch(): + meta_out = MetaConverter()(real_out) return FakeTensor(self.fake_mode, meta_out, out_device) @classmethod @@ -292,6 +405,14 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): else: return args[0].fake_device + # Because fake mode can return NotImplemented (if it sees a subclass + # it doesn't know how to deal with), this test here is important + # because the next dispatch after a fake mode will attempt to use + # subclasses of tensors to dispatch, and any FakeTensor arguments + # will be considered eligible. + if any(not issubclass(t, FakeTensor) and t is not torch.Tensor for t in types): + return NotImplemented + fake_mode = None for arg in itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0]): if isinstance(arg, FakeTensor): @@ -344,7 +465,7 @@ def merge_devices(t): # mismatching devices of non-zero dim tensors, throw # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as - raise Exception( + raise RuntimeError( f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}" ) @@ -368,8 +489,8 @@ def merge_devices(t): class FakeTensorMode(TorchDispatchMode): - def __init__(self, allow_cpu_fallback=True): - self.allow_cpu_fallback = allow_cpu_fallback + def __init__(self, allow_fallback_kernels=True): + self.allow_fallback_kernels = allow_fallback_kernels self.fake_tensor_converter = FakeTensorConverter() # [in_kernel_invocation] @@ -393,122 +514,205 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return torch.device("meta") else: return args[0].fake_device + flat_arg_tensors = [ + i for i in tree_flatten((args, kwargs))[0] if isinstance(i, FakeTensor) + ] + has_symbolic_sizes = any([i.has_sym_ints for i in flat_arg_tensors]) + if has_symbolic_sizes: + # TODO: Find better approach for this + # Avoid circular import + from torch._decomp import decomposition_table + from torch._meta_registrations import meta_table + + # TODO: hack, doesn't actually work. + # see https://github.com/pytorch/pytorch/pull/81598#issuecomment-1192030435 + with enable_torch_dispatch_mode( + self + ), torch.overrides.enable_reentrant_dispatch(): + if func in meta_table: + r = meta_table[func](*args, **kwargs) + return r + if func in decomposition_table: + return decomposition_table[func](*args, **kwargs) + + with no_dispatch(): + if symbolic_shapes.is_symbolic_op(func): + return symbolic_shapes.handle_symbolic_op(func, args, kwargs) + if func == aten.size.default: + raise RuntimeError( + "Trying to call aten.size on a tensor with symbolic shapes. " + "It's likely that this is from calling tensor.shape in C++" + ) # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them - if "prims::" in func._schema.name: - with no_dispatch(): - return func(*args, **kwargs) + # and ensure that Meta kernels are dispatched to (see) + # Fake Tensor Dispatch Keys - with no_dispatch(): - # TODO: apply as no_dispatch decorator - converter = self.fake_tensor_converter + if "prims::" in func._schema.name and len(flat_arg_tensors) != 0: + try: + torch._C._add_meta_to_tls_dispatch_include() + with no_dispatch(): + return func(*args, **kwargs) + finally: + torch._C._remove_meta_from_tls_dispatch_include() - # this is generated from torch.tensor(), which does not use the - # dispatcher, to allow wrapper subclasses to wrap the new tensor - # we need to handle before error checking - if func == torch.ops.aten.lift.default: - assert ( - len(kwargs) == 0 - and len(args) == 1 - and type(args[0]) is torch.Tensor + if has_symbolic_sizes: + constructors = [aten.empty.SymInt] + if func not in constructors: + raise RuntimeError( + f"{func} - couldn't find symbolic meta function/decomposition" ) - with no_dispatch(): - return converter(self, args[0]) - def wrap(e, device=None): - if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor): - return converter(self, e, device) - else: - return e + with no_dispatch(): + # TODO: apply as no_dispatch decorator + converter = self.fake_tensor_converter # if we are in the dispatch mode, we will enter this function even if the inputs # are not FakeTensors. For now, throw if any non-Fake Tensor inputs # and just support constructors. TODO: extend more broadly conversion_made = False + subclass_seen = False def check_non_fake_tensor(x): - nonlocal conversion_made + nonlocal conversion_made, subclass_seen conversion_made = conversion_made or ( isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor) ) + subclass_seen = subclass_seen or ( + isinstance(x, torch.Tensor) + and not isinstance(x, FakeTensor) + and type(x) is not torch.Tensor + and type(x) is not torch.nn.Parameter + ) tree_map(check_non_fake_tensor, args) tree_map(check_non_fake_tensor, kwargs) + # Suppose we enable fake tensor mode. This means that fake tensor + # mode will run first. But what if we do an operation that + # involves a tensor subclass that will desugar into normal tensor + # operations? Without this line, fake tensor mode will run first, + # decide that a conversion was made (since there was a non fake + # tensor argument), and report an error that converting non + # fake tensor is not supported. What we actually wanted to happen + # was to give the subclass a chance to figure out what it wants to + # before erroring out. Returning NotImplemented here allows this. + # + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + if subclass_seen: + return NotImplemented + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + # we need to handle before error checking + if func in [ + aten.lift_fresh.default, + aten.lift_fresh_copy.default, + ]: + assert ( + len(kwargs) == 0 + and len(args) == 1 + and type(args[0]) is torch.Tensor + ), f"{args} {kwargs}" + with no_dispatch(): + return converter(self, args[0]) + if conversion_made: raise Exception( "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. " - f"Please convert all Tensors to FakeTensors first. Found in {func}" + f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})" ) for run_impl_check, op_impl in op_implementations: if run_impl_check(func): return op_impl(self, func, *args, **kwargs) - if func == aten.index.Tensor: - check_no_bool_index_tensors(func, *args, **kwargs) - - self.in_kernel_invocation = True try: - r = func(*args, **kwargs) + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) except NotImplementedError as not_implemented_error: - if not self.allow_cpu_fallback: + if not self.allow_fallback_kernels: raise not_implemented_error - r = run_cpu_fallback(func, args, kwargs, not_implemented_error) - finally: - self.in_kernel_invocation = False + r = run_fallback_kernel(func, args, kwargs, not_implemented_error) # TODO: handle non-kwarg devices assert func not in _device_not_kwarg_ops, f"NYI: {func}" + # Lazily initialized, in case there are no tensor returns + common_device = None + + def wrap(e, device=None): + nonlocal common_device + if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor): + if common_device is None: + common_device = FakeTensor._find_common_device( + func, args, kwargs + ) + return converter(self, e, device or common_device) + else: + return e + # if device is specified, use that if kwargs.get("device", None): return tree_map(partial(wrap, device=kwargs["device"]), r) - common_device = FakeTensor._find_common_device(func, args, kwargs) - - return tree_map(partial(wrap, device=common_device), r) + return tree_map(partial(wrap), r) def from_tensor(self, tensor): return self.fake_tensor_converter(self, tensor) -def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception): + +def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception): + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + with no_dispatch(): - def to_cpu(e): + inp_impls = {} + + def to_real_tensor(e): if isinstance(e, FakeTensor): - return torch.zeros_like(e, device="cpu") + out = torch.zeros_like(e, device=e.fake_device) + inp_impls[id(out)] = e + return out return e - try: - args = tree_map(to_cpu, args) - kwargs = tree_map(to_cpu, kwargs) + args = tree_map(to_real_tensor, args) + kwargs = tree_map(to_real_tensor, kwargs) - r = func(*args, **kwargs) - except Exception as new_exception: - raise orig_not_implemented_exception from new_exception + r = func(*args, **kwargs) tensor_impls = set() storages = set() for e in tree_flatten((args, kwargs))[0]: if isinstance(e, torch.Tensor): - tensor_impls.add(e) storages.add(e.storage()._cdata) # TODO: also check metadata change on inputs # proper aliasing/metadata relationship between outputs and inputs will - # not be set up, bc of conversion to cpu, error on reused impls + # not be set up, bc of conversion to device, unless we can reuse an + # input impl for e in tree_flatten(r)[0]: - if e in tensor_impls or ( + if id(e) not in inp_impls and ( isinstance(e, torch.Tensor) and e.storage()._cdata in storages ): raise orig_not_implemented_exception - # we're only converting these to MetaTensors now, not Fake Tensors, - # and the cpu inputs should be temporary. just convert outputs to meta - # and continue - return tree_map(MetaConverter(), r) + # the outputs which are are not reused from impls will be converted + # to fake tensors later + meta_converter = MetaConverter() + + def map_out(e): + return inp_impls.get(id(e), meta_converter(e)) + + return tree_map(map_out, r) # Just for use to allow copying a module to fake tensors, diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 0508d8f47743d..5e554fbf5f40f 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,6 +1,10 @@ +import weakref + import torch +from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._mode_utils import no_dispatch + def safe_is_leaf(t): try: return t.is_leaf @@ -9,36 +13,135 @@ def safe_is_leaf(t): return False +# torch.Tensors cannot be used as a key in a dictionary +# because they define a custom __eq__ function which when used +# to resolve hash collisions will throw when comparing tensors: +# "RuntimeError: bool value of Tensor with more than one value is ambiguous." +# To avoid that, we use an object which will hold a Tensor and use +# its id for both hashing and equality. +# In order to use this as a weak key reference, we cannot +# simply use weakref.WeakKeyDictionary because the newly constructed +# WeakTensorRefKey only use would be a dictionary so it would have no strong +# references. +# To get around this issue, we can use it as a normal key, and then set +# `weakref.finalize` to delete the key when its contained tensor dies. + + +class WeakTensorRefKey(object): + def __init__(self, ten): + self.ten = weakref.ref(ten) + # store id since as soon as ten is deallocated + # the old id will no longer be recoverable, and + # we need to be able to remove the WeakTensorRefKey + # from the dictionary by hashing it to the same + # value it had when ten was alive + self.id = id(self.ten()) + + def __hash__(self): + return self.id + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.id == other.id + + # This is a class for converting multiple tensors into meta tensors which # share the same view/storage structure. The operation model is you allocate # one of these, and then call it repeatedly on all the tensors you want to # convert. It's important to use the same object for tensors you want to # share storage because this is how we correlate shared storages to the same -# meta storages; similarly, it's important NOT to use the same object for -# unrelated groups of tensors because this class will remember all the -# tensors/storages its seen and therefore leak memory. +# meta storages. This class will hold weak references to cached tenosrs +# and tensor storages. class MetaConverter: def __init__(self): self.storage_memo = {} self.tensor_memo = {} + self.maybe_storages_to_delete = [] + self.check_expired_frequency = 128 + self.check_expired_count = 0 self.hit = 0 self.miss = 0 + self.del_hook = None def successful(self): return self.hit > 0 and self.miss == 0 + def check_for_expired_weak_storages(self): + new_li = [] + stor_to_delete = [] + for obj in self.maybe_storages_to_delete: + if not obj.expired(): + new_li.append(obj) + else: + stor_to_delete.append(obj) + for obj in stor_to_delete: + self.storage_memo.pop(obj, None) + self.maybe_storages_to_delete = new_li + + # if for some reason we have aquired many storages which have not expired + # even though a tensor with their storage has expired (aliasing or otherwise) + # check for expired storages less often so as to bound the amount of work we + # do checking for expired storages + self.check_expired_frequency = max( + self.check_expired_frequency, len(self.maybe_storages_to_delete) + ) + + def get_tensor_memo(self, t): + return self.tensor_memo.get(WeakTensorRefKey(t), None) + + def set_tensor_memo(self, t, v): + # hold a weak ref to self, otherwise it will be kept alive + # by the del_ten closure + self_weak_ref = weakref.ref(self) + weak_st = StorageWeakRef(t.storage()) + tensor_ref_key = WeakTensorRefKey(t) + + def del_ten(): + # tensor outlives the converter + self_ref = self_weak_ref() + if self_ref is None: + return + # on shutdown, tensor_ref_key may not be in memo + self_ref.tensor_memo.pop(tensor_ref_key, None) + if weak_st and weak_st.expired(): + self_ref.storage_memo.pop(weak_st, None) + else: + # [expired-storages] + # NB: even though the tensor has died, + # the deallocation of its storage can take longer, + # even when the storage has no other uses/views. + # In this case, the StorageWeakRef object will be kept alive + # longer than it needs to be, however the storage itself + # will be deallocated. We retain the possibly dead storages + # and periodically check if any of them are expired and + # can be freed. + self_ref.maybe_storages_to_delete.append(weak_st) + + weakref.finalize(t, del_ten) + self.tensor_memo[tensor_ref_key] = v + # NB: doesn't actually return a storage, because meta storage is # not supported def meta_storage(self, s): # NB: TypedStorage is freshly allocated and cannot be used as hash # key index. - if s._cdata not in self.storage_memo: - self.storage_memo[s._cdata] = torch.empty(s.size(), dtype=s.dtype, device='meta') - return self.storage_memo[s._cdata] + + # Use a Weak Ref to s in order to not leak memory + swr = StorageWeakRef(s) + if swr not in self.storage_memo: + self.storage_memo[swr] = torch.empty(s.size(), dtype=s.dtype, device="meta") + return self.storage_memo[swr] # This function assumes that it's possible to do the conversion def meta_tensor(self, t): - if t not in self.tensor_memo: + # see expired-storages + self.check_expired_count += 1 + if self.check_expired_count >= self.check_expired_frequency: + self.check_for_expired_weak_storages() + self.check_expired_count = 0 + + if self.get_tensor_memo(t) is None: with torch.inference_mode(t.is_inference()): if t._is_view(): # Construct views in two steps: recursively meta-fy their @@ -49,8 +152,11 @@ def meta_tensor(self, t): base = self.meta_tensor(t._base) def is_c_of_r(complex_dtype, real_dtype): - return utils.is_complex_dtype(complex_dtype) and \ - utils.corresponding_real_dtype(complex_dtype) == real_dtype + return ( + utils.is_complex_dtype(complex_dtype) + and utils.corresponding_real_dtype(complex_dtype) + == real_dtype + ) if base.dtype == t.dtype: pass @@ -70,7 +176,9 @@ def is_c_of_r(complex_dtype, real_dtype): is_leaf = safe_is_leaf(t) # Fake up some autograd history. if t.requires_grad: - r = torch.empty((0,), dtype=t.dtype, device='meta', requires_grad=True) + r = torch.empty( + (0,), dtype=t.dtype, device="meta", requires_grad=True + ) if not is_leaf: with torch.enable_grad(): # The backward function here will be wrong, but @@ -81,7 +189,7 @@ def is_c_of_r(complex_dtype, real_dtype): # sort of unsupported grad_fn here r = r.clone() else: - r = torch.empty((0,), dtype=t.dtype, device='meta') + r = torch.empty((0,), dtype=t.dtype, device="meta") # As long as meta storage is not supported, need to prevent # redispatching on set_(Storage, ...) which will choke with # meta storage @@ -92,27 +200,32 @@ def is_c_of_r(complex_dtype, real_dtype): torch._C._set_conj(r, t.is_conj()) torch._C._set_neg(r, t.is_neg()) - self.tensor_memo[t] = r + self.set_tensor_memo(t, r) - return self.tensor_memo[t] + return self.get_tensor_memo(t) def __call__(self, t): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: - if any([ - t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized, - t.is_nested, torch._is_functional_tensor(t), - # these are supported in meta conversion but the fallbacks - # don't work - t.is_neg(), t.is_conj(), - # conjugate fallback does not support meta tensors - t.dtype in (torch.complex128, torch.complex64, torch.complex32), - t.device.type in ("lazy", "meta"), - # We need a way to test if a tensor is batched but there - # is no official APi to do it - # torch._C._is_batched(t), - ]): + if any( + [ + t.is_sparse_csr, + t.is_sparse, + t.is_mkldnn, + t.is_quantized, + t.is_nested, + torch._is_functional_tensor(t), + # these are supported in meta conversion but the fallbacks + # don't work + t.is_neg(), + t.is_conj(), + t.device.type in ("lazy", "meta"), + # We need a way to test if a tensor is batched but there + # is no official APi to do it + # torch._C._is_batched(t), + ] + ): # TODO: sparse should support meta # NB technically to('meta') does work but our logging # instrumentation will see the meta conversions and the @@ -138,4 +251,5 @@ def __call__(self, t): # non-Tensor types don't count as hit or miss return t -import torch._prims.utils as utils + +import torch._prims_common as utils diff --git a/torch/_tensor.py b/torch/_tensor.py index 9ee17dc47c807..472ecc512e771 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1,26 +1,35 @@ -from collections import OrderedDict +import copyreg import enum import functools -from numbers import Number -from typing import Any, Dict, Optional, Tuple, Union import warnings -import copyreg +from collections import OrderedDict from copy import deepcopy +from numbers import Number +from typing import Any, Dict, Optional, Tuple, Union import torch import torch._C as _C +import torch.utils.hooks as hooks from torch._namedtensor_internals import ( - update_names, check_serializing_named_tensor, resolve_ellipsis, - unzip_namedshape, single_ellipsis_index, is_ellipsis) + check_serializing_named_tensor, + is_ellipsis, + resolve_ellipsis, + single_ellipsis_index, + unzip_namedshape, + update_names, +) from torch.overrides import ( - has_torch_function, has_torch_function_unary, has_torch_function_variadic, - handle_torch_function, get_default_nowrap_functions) -import torch.utils.hooks as hooks + get_default_nowrap_functions, + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): # functools.wraps doesn't work well with methods in python 2 - method_assignments = ('__name__', '__doc__') + method_assignments = ("__name__", "__doc__") assigned = functools.WRAPPER_ASSIGNMENTS @functools.wraps(f, assigned=assigned) @@ -32,8 +41,10 @@ def wrapped(*args, **kwargs): return f(*args, **kwargs) except TypeError: return NotImplemented + return wrapped + # Should not be used, this is kept only for BC of loading old serialized Tensor subclasses def _rebuild_from_type(func, type, args, dict): if type is Tensor: @@ -43,6 +54,7 @@ def _rebuild_from_type(func, type, args, dict): ret.__dict__ = dict return ret + def _rebuild_from_type_v2(func, new_type, args, state): if new_type is Tensor: return func(*args) @@ -53,7 +65,10 @@ def _rebuild_from_type_v2(func, new_type, args, state): # Tensor does define __setstate__ even though it doesn't define # __getstate__. So only use __setstate__ if it is NOT the one defined # on Tensor - if getattr(ret.__class__, "__setstate__", Tensor.__setstate__) is not Tensor.__setstate__: + if ( + getattr(ret.__class__, "__setstate__", Tensor.__setstate__) + is not Tensor.__setstate__ + ): ret.__setstate__(state) else: if isinstance(state, tuple): @@ -86,8 +101,10 @@ def __deepcopy__(self, memo): if has_torch_function_unary(self): return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) if not self.is_leaf: - raise RuntimeError("Only Tensors created explicitly by the user " - "(graph leaves) support the deepcopy protocol at the moment") + raise RuntimeError( + "Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment" + ) if id(self) in memo: return memo[id(self)] with torch.no_grad(): @@ -97,57 +114,83 @@ def __deepcopy__(self, memo): # https://github.com/pytorch/pytorch/issues/47442 # Update the test in test_serialization if you remove 'meta' from here - if self.is_sparse or self.device.type in ['lazy', 'xla', 'mps', 'ort', 'meta', 'hpu'] or \ - (type(self) is not Tensor and self.data_ptr() == 0): + if ( + self.is_sparse + or self.device.type in ["lazy", "xla", "mps", "ort", "meta", "hpu"] + or (type(self) is not Tensor and self.data_ptr() == 0) + ): new_tensor = self.clone() if type(new_tensor) is not type(self): - raise RuntimeError("The default implementation of __deepcopy__() for wrapper subclasses " - "only works for subclass types that implement clone() and for which " - "cloning returns another instance of the same subclass. You should either " - "properly implement clone() for your subclass or override __deepcopy__() " - "if it is intended behavior for clone() to return an instance of a " - "different type.") + raise RuntimeError( + "The default implementation of __deepcopy__() for wrapper subclasses " + "only works for subclass types that implement clone() and for which " + "cloning returns another instance of the same subclass. You should either " + "properly implement clone() for your subclass or override __deepcopy__() " + "if it is intended behavior for clone() to return an instance of a " + "different type." + ) else: new_storage = self.storage().__deepcopy__(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute - quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]] + quantizer_params: Union[ + Tuple[torch.qscheme, float, int], + Tuple[torch.qscheme, Tensor, Tensor, int], + ] if self.qscheme() == torch.per_tensor_affine: - quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point() - elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): - quantizer_params = self.qscheme(), \ - self.q_per_channel_scales(), \ - self.q_per_channel_zero_points(), \ - self.q_per_channel_axis() + quantizer_params = ( + self.qscheme(), + self.q_scale(), + self.q_zero_point(), + ) + elif self.qscheme() in ( + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ): + quantizer_params = ( + self.qscheme(), + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis(), + ) else: - raise RuntimeError(f"Unsupported qscheme {self.qscheme()} in deepcopy") + raise RuntimeError( + f"Unsupported qscheme {self.qscheme()} in deepcopy" + ) # TODO: Once we decide to break serialization FC, no longer - # need to wrap with _TypedStorage + # need to wrap with TypedStorage new_tensor = torch._utils._rebuild_qtensor( - torch.storage._TypedStorage( - wrap_storage=new_storage._untyped(), - dtype=self.dtype), + torch.storage.TypedStorage( + wrap_storage=new_storage.untyped(), dtype=self.dtype + ), self.storage_offset(), self.size(), self.stride(), quantizer_params, self.requires_grad, - self._backward_hooks) + self._backward_hooks, + ) if type(new_tensor) is not type(self): - raise RuntimeError("The default implementation of __deepcopy__() for quantized tensors " - "expects the tensor returned by torch._utils._rebuild_qtensor() to " - "match the type of the instance being copied. If you encounter this, " - "please open an issue on PyTorch's GitHub.") + raise RuntimeError( + "The default implementation of __deepcopy__() for quantized tensors " + "expects the tensor returned by torch._utils._rebuild_qtensor() to " + "match the type of the instance being copied. If you encounter this, " + "please open an issue on PyTorch's GitHub." + ) else: new_tensor = self.new_empty([]) if type(new_tensor) is not type(self): - raise RuntimeError("The default implementation of __deepcopy__() for non-wrapper subclasses " - "only works for subclass types that implement new_empty() and for which " - "that function returns another instance of the same subclass. You should " - "either properly implement new_empty() for your subclass or override " - "__deepcopy__() if it is intended behavior for new_empty() to return " - "an instance of a different type.") - new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) + raise RuntimeError( + "The default implementation of __deepcopy__() for non-wrapper subclasses " + "only works for subclass types that implement new_empty() and for which " + "that function returns another instance of the same subclass. You should " + "either properly implement new_empty() for your subclass or override " + "__deepcopy__() if it is intended behavior for new_empty() to return " + "an instance of a different type." + ) + new_tensor.set_( + new_storage, self.storage_offset(), self.size(), self.stride() + ) if self.is_conj(): new_tensor = new_tensor.conj_physical() if self.is_neg(): @@ -159,8 +202,10 @@ def __deepcopy__(self, memo): if not type(self) is Tensor: if type(new_tensor) is not type(self): - raise RuntimeError("Type of deepcopy result does not match the type of the source tensor. " - "If you encounter this, please open an issue on PyTorch's GitHub.") + raise RuntimeError( + "Type of deepcopy result does not match the type of the source tensor. " + "If you encounter this, please open an issue on PyTorch's GitHub." + ) # Plain Tensors don't have slots slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] @@ -189,7 +234,14 @@ def __reduce_ex__(self, proto): else: slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] if slots_to_save: - state = (self.__dict__, {name: getattr(self, name) for name in slots_to_save if hasattr(self, name)}) + state = ( + self.__dict__, + { + name: getattr(self, name) + for name in slots_to_save + if hasattr(self, name) + }, + ) else: state = self.__dict__ return (_rebuild_from_type_v2, (func, type(self), args, state)) @@ -203,7 +255,7 @@ def storage(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage, (self,), self) - return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) + return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) def _reduce_ex_internal(self, proto): check_serializing_named_tensor(self) @@ -220,16 +272,20 @@ def _reduce_ex_internal(self, proto): # 2. Python list is not a good fit due to performance reason. # `tolist()` converts every single element in the tensor into python objects # and serialize them one by one. - if self.device.type in ['xla', 'ort', 'hpu']: + if self.device.type in ["xla", "ort", "hpu"]: # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. - numpy_tensor = self.cpu().numpy() if self.dtype != torch.bfloat16 else self.cpu().to(torch.float32).numpy() - return (torch._utils._rebuild_device_tensor_from_numpy, (numpy_tensor, - self.dtype, - str(self.device), - self.requires_grad)) - if self.device.type == 'meta': + numpy_tensor = ( + self.cpu().numpy() + if self.dtype != torch.bfloat16 + else self.cpu().to(torch.float32).numpy() + ) + return ( + torch._utils._rebuild_device_tensor_from_numpy, + (numpy_tensor, self.dtype, str(self.device), self.requires_grad), + ) + if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. arg_meta = ( @@ -241,56 +297,78 @@ def _reduce_ex_internal(self, proto): return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: # quantizer_params can be different type based on torch attribute - quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]] + quantizer_params: Union[ + Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] + ] if self.qscheme() == torch.per_tensor_affine: - quantizer_params = (torch.per_tensor_affine, - self.q_scale(), - self.q_zero_point()) - elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + quantizer_params = ( + torch.per_tensor_affine, + self.q_scale(), + self.q_zero_point(), + ) + elif self.qscheme() in ( + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ): # convert scales and zero points to tuple to avoid recursive calls # when/if we get multi-axis quantized tensors in the future, the shape # is recoverable from the main tensor shape - quantizer_params = (torch.per_channel_affine, - self.q_per_channel_scales(), - self.q_per_channel_zero_points(), - self.q_per_channel_axis()) + quantizer_params = ( + torch.per_channel_affine, + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis(), + ) else: - raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}") + raise RuntimeError( + f"Serialization is not supported for tensors of type {self.qscheme()}" + ) # TODO: Once we decide to break serialization FC, no longer - # need to wrap with _TypedStorage + # need to wrap with TypedStorage args_qtensor = ( - torch.storage._TypedStorage( - wrap_storage=self.storage()._untyped(), - dtype=self.dtype), + torch.storage.TypedStorage( + wrap_storage=self.storage().untyped(), dtype=self.dtype + ), self.storage_offset(), tuple(self.size()), self.stride(), quantizer_params, self.requires_grad, - backward_hooks) + backward_hooks, + ) return (torch._utils._rebuild_qtensor, args_qtensor) elif self.is_sparse: if self.layout == torch.sparse_coo: - args_sparse = (self.layout, - (self._indices(), - self._values(), - self.size())) + args_sparse = ( + self.layout, + (self._indices(), self._values(), self.size()), + ) else: raise NotImplementedError( - 'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout)) + "sparse tensor __reduce_ex__ for layout `%s`" % (self.layout) + ) return (torch._utils._rebuild_sparse_tensor, args_sparse) elif self.is_sparse_csr: if self.layout == torch.sparse_csr: - args_sparse_csr = (self.layout, - (self.crow_indices(), - self.col_indices(), - self.values(), - self.size())) + args_sparse_csr = ( + self.layout, + ( + self.crow_indices(), + self.col_indices(), + self.values(), + self.size(), + ), + ) else: raise NotImplementedError( - 'sparse csr tensor __reduce_ex__ for layout `%s`' % (self.layout)) + "sparse csr tensor __reduce_ex__ for layout `%s`" % (self.layout) + ) return (torch._utils._rebuild_sparse_csr_tensor, args_sparse_csr) - elif self.data_ptr() == 0 and type(self) is not torch.Tensor: + elif ( + self.data_ptr() == 0 + and type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + ): arg_wrapper_subclass = ( type(self), self.dtype, @@ -299,21 +377,22 @@ def _reduce_ex_internal(self, proto): self.storage_offset(), self.layout, self.device, - self.requires_grad + self.requires_grad, ) return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) else: # TODO: Once we decide to break serialization FC, no longer - # need to wrap with _TypedStorage + # need to wrap with TypedStorage args = ( - torch.storage._TypedStorage( - wrap_storage=self.storage()._untyped(), - dtype=self.dtype), + torch.storage.TypedStorage( + wrap_storage=self.storage().untyped(), dtype=self.dtype + ), self.storage_offset(), tuple(self.size()), self.stride(), self.requires_grad, - backward_hooks) # previously was self._backward_hooks + backward_hooks, + ) # previously was self._backward_hooks return (torch._utils._rebuild_tensor_v2, args) def __setstate__(self, state): @@ -322,7 +401,7 @@ def __setstate__(self, state): # Warning: this method is NOT called when you torch.load() a tensor; # that is managed by _rebuild_tensor_v2 if not self.is_leaf: - raise RuntimeError('__setstate__ can be only called on leaf Tensors') + raise RuntimeError("__setstate__ can be only called on leaf Tensors") if len(state) == 4: # legacy serialization of Tensor self.set_(*state) @@ -337,12 +416,15 @@ def __setstate__(self, state): def __repr__(self, *, tensor_contents=None): if has_torch_function_unary(self): - return handle_torch_function(Tensor.__repr__, (self,), self, - tensor_contents=tensor_contents) + return handle_torch_function( + Tensor.__repr__, (self,), self, tensor_contents=tensor_contents + ) # All strings are unicode in Python 3. return torch._tensor_str._str(self, tensor_contents=tensor_contents) - def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None): + def backward( + self, gradient=None, retain_graph=None, create_graph=False, inputs=None + ): r"""Computes the gradient of current tensor w.r.t. graph leaves. The graph is differentiated using the chain rule. If the tensor is @@ -397,8 +479,11 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs= gradient=gradient, retain_graph=retain_graph, create_graph=create_graph, - inputs=inputs) - torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) + inputs=inputs, + ) + torch.autograd.backward( + self, gradient, retain_graph, create_graph, inputs=inputs + ) def register_hook(self, hook): r"""Registers a backward hook. @@ -432,8 +517,9 @@ def register_hook(self, hook): if has_torch_function_unary(self): return handle_torch_function(Tensor.register_hook, (self,), self, hook) if not self.requires_grad: - raise RuntimeError("cannot register a hook on a tensor that " - "doesn't require gradient") + raise RuntimeError( + "cannot register a hook on a tensor that " "doesn't require gradient" + ) if self._backward_hooks is None: self._backward_hooks = OrderedDict() if self.grad_fn is not None: @@ -444,9 +530,11 @@ def register_hook(self, hook): def reinforce(self, reward): def trim(str): - return '\n'.join([line.strip() for line in str.split('\n')]) + return "\n".join([line.strip() for line in str.split("\n")]) - raise RuntimeError(trim(r"""reinforce() was removed. + raise RuntimeError( + trim( + r"""reinforce() was removed. Use torch.distributions instead. See https://pytorch.org/docs/master/distributions.html @@ -467,9 +555,13 @@ def trim(str): next_state, reward = env.step(action) loss = -m.log_prob(action) * reward loss.backward() - """)) + """ + ) + ) - detach = _C._add_docstr(_C._TensorBase.detach, r""" + detach = _C._add_docstr( + _C._TensorBase.detach, + r""" Returns a new Tensor, detached from the current graph. The result will never require gradient. @@ -490,15 +582,19 @@ def trim(str): In-place indices / values changes (such as `zero_` / `copy_` / `add_`) to the returned tensor will not update the original tensor anymore, and will instead trigger an error. - """) + """, + ) - detach_ = _C._add_docstr(_C._TensorBase.detach_, r""" + detach_ = _C._add_docstr( + _C._TensorBase.detach_, + r""" Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place. This method also affects forward mode AD gradients and the result will never have forward mode AD gradients. - """) + """, + ) def is_shared(self): r"""Checks if tensor is in shared memory. @@ -532,29 +628,44 @@ def __reversed__(self): def norm(self, p="fro", dim=None, keepdim=False, dtype=None): r"""See :func:`torch.norm`""" if has_torch_function_unary(self): - return handle_torch_function(Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) + return handle_torch_function( + Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype + ) return torch.norm(self, p, dim, keepdim, dtype=dtype) def solve(self, other): from ._linalg_utils import solve + return solve(self, other) def lu(self, pivot=True, get_infos=False): r"""See :func:`torch.lu`""" # If get_infos is True, then we don't need to check for errors and vice versa if has_torch_function_unary(self): - return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos) + return handle_torch_function( + Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos + ) - LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) + LU, pivots, infos = torch._lu_with_info( + self, pivot=pivot, check_errors=(not get_infos) + ) if get_infos: return LU, pivots, infos else: return LU, pivots - def stft(self, n_fft: int, hop_length: Optional[int] = None, - win_length: Optional[int] = None, window: 'Optional[Tensor]' = None, - center: bool = True, pad_mode: str = 'reflect', normalized: bool = False, - onesided: Optional[bool] = None, return_complex: Optional[bool] = None): + def stft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + ): r"""See :func:`torch.stft` .. warning:: @@ -563,33 +674,79 @@ def stft(self, n_fft: int, hop_length: Optional[int] = None, """ if has_torch_function_unary(self): return handle_torch_function( - Tensor.stft, (self,), self, n_fft, hop_length=hop_length, - win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, - onesided=onesided, return_complex=return_complex + Tensor.stft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + return_complex=return_complex, ) - return torch.stft(self, n_fft, hop_length, win_length, window, center, - pad_mode, normalized, onesided, return_complex=return_complex) - - def istft(self, n_fft: int, hop_length: Optional[int] = None, - win_length: Optional[int] = None, window: 'Optional[Tensor]' = None, - center: bool = True, normalized: bool = False, - onesided: Optional[bool] = None, length: Optional[int] = None, - return_complex: bool = False): + return torch.stft( + self, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + return_complex=return_complex, + ) + + def istft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex: bool = False, + ): r"""See :func:`torch.istft`""" if has_torch_function_unary(self): return handle_torch_function( - Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, normalized=normalized, onesided=onesided, length=length, - return_complex=return_complex + Tensor.istft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + return_complex=return_complex, ) - return torch.istft(self, n_fft, hop_length, win_length, window, center, - normalized, onesided, length, return_complex=return_complex) + return torch.istft( + self, + n_fft, + hop_length, + win_length, + window, + center, + normalized, + onesided, + length, + return_complex=return_complex, + ) def resize(self, *sizes): if has_torch_function_unary(self): return handle_torch_function(Tensor.resize, (self,), self, *sizes) warnings.warn("non-inplace resize is deprecated") from torch.autograd._functions import Resize + return Resize.apply(self, sizes) def resize_as(self, tensor): @@ -597,13 +754,15 @@ def resize_as(self, tensor): return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) warnings.warn("non-inplace resize_as is deprecated") from torch.autograd._functions import Resize + return Resize.apply(self, tensor.size()) def split(self, split_size, dim=0): - r"""See :func:`torch.split` - """ + r"""See :func:`torch.split`""" if has_torch_function_unary(self): - return handle_torch_function(Tensor.split, (self,), self, split_size, dim=dim) + return handle_torch_function( + Tensor.split, (self,), self, split_size, dim=dim + ) if isinstance(split_size, int): return super(Tensor, self).split(split_size, dim) elif isinstance(split_size, Tensor): @@ -622,10 +781,21 @@ def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=Non """ if has_torch_function_unary(self): return handle_torch_function( - Tensor.unique, (self,), self, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts, dim=dim + Tensor.unique, + (self,), + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, ) - return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + return torch.unique( + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -634,10 +804,16 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None """ if has_torch_function_unary(self): return handle_torch_function( - Tensor.unique_consecutive, (self,), self, return_inverse=return_inverse, - return_counts=return_counts, dim=dim + Tensor.unique_consecutive, + (self,), + self, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, ) - return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + return torch.unique_consecutive( + self, return_inverse=return_inverse, return_counts=return_counts, dim=dim + ) @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rsub__(self, other): @@ -650,8 +826,12 @@ def __rdiv__(self, other): __rtruediv__ = __rdiv__ __itruediv__ = _C._TensorBase.__idiv__ - __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(_C._TensorBase.pow) - __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(_C._TensorBase.pow_) + __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C._TensorBase.pow + ) + __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C._TensorBase.pow_ + ) @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rmod__(self, other): @@ -660,7 +840,7 @@ def __rmod__(self, other): def __format__(self, format_spec): if has_torch_function_unary(self): return handle_torch_function(Tensor.__format__, (self,), self, format_spec) - if self.dim() == 0 and not self.is_meta: + if self.dim() == 0 and not self.is_meta and type(self) is Tensor: return self.item().__format__(format_spec) return object.__format__(self, format_spec) @@ -699,10 +879,14 @@ def __len__(self): if self.dim() == 0: raise TypeError("len() of a 0-d tensor") if torch._C._get_tracing_state(): - warnings.warn('Using len to get tensor shape might cause the trace to be incorrect. ' - 'Recommended usage would be tensor.shape[0]. ' - 'Passing a tensor of different shape might lead to errors or silently give ' - 'incorrect results.', category=torch.jit.TracerWarning, stacklevel=2) + warnings.warn( + "Using len to get tensor shape might cause the trace to be incorrect. " + "Recommended usage would be tensor.shape[0]. " + "Passing a tensor of different shape might lead to errors or silently give " + "incorrect results.", + category=torch.jit.TracerWarning, + stacklevel=2, + ) return self.shape[0] def __iter__(self): @@ -715,12 +899,16 @@ def __iter__(self): # NB: We have intentionally skipped __torch_function__ dispatch here. # See gh-54457 if self.dim() == 0: - raise TypeError('iteration over a 0-d tensor') + raise TypeError("iteration over a 0-d tensor") if torch._C._get_tracing_state(): - warnings.warn('Iterating over a tensor might cause the trace to be incorrect. ' - 'Passing a tensor of different shape won\'t change the number of ' - 'iterations executed (and might lead to errors or silently give ' - 'incorrect results).', category=torch.jit.TracerWarning, stacklevel=2) + warnings.warn( + "Iterating over a tensor might cause the trace to be incorrect. " + "Passing a tensor of different shape won't change the number of " + "iterations executed (and might lead to errors or silently give " + "incorrect results).", + category=torch.jit.TracerWarning, + stacklevel=2, + ) return iter(self.unbind(0)) def __hash__(self): @@ -732,7 +920,7 @@ def __dir__(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.__dir__, (self,), self) tensor_methods = dir(self.__class__) - tensor_methods.remove('volatile') # deprecated + tensor_methods.remove("volatile") # deprecated attrs = list(self.__dict__.keys()) keys = tensor_methods + attrs @@ -743,7 +931,7 @@ def __dir__(self): return sorted(keys) # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` - __array_priority__ = 1000 # prefer Tensor ops over numpy ones + __array_priority__ = 1000 # prefer Tensor ops over numpy ones def __array__(self, dtype=None): if has_torch_function_unary(self): @@ -757,10 +945,12 @@ def __array__(self, dtype=None): # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` def __array_wrap__(self, array): if has_torch_function_unary(self): - return handle_torch_function(Tensor.__array_wrap__, (self,), self, array=array) + return handle_torch_function( + Tensor.__array_wrap__, (self,), self, array=array + ) if array.dtype == bool: # Workaround, torch has no built-in bool tensor - array = array.astype('uint8') + array = array.astype("uint8") return torch.from_numpy(array) def __contains__(self, element): @@ -777,8 +967,8 @@ def __contains__(self, element): return (element == self).any().item() # type: ignore[union-attr] raise RuntimeError( - "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." % - type(element) + "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." + % type(element) ) @property @@ -797,15 +987,15 @@ def __cuda_array_interface__(self): if not self.is_cuda: raise AttributeError( "Can't get __cuda_array_interface__ on non-CUDA tensor type: %s " - "If CUDA data is required use tensor.cuda() to copy tensor to device memory." % - self.type() + "If CUDA data is required use tensor.cuda() to copy tensor to device memory." + % self.type() ) if self.is_sparse: raise AttributeError( "Can't get __cuda_array_interface__ on sparse type: %s " - "Use Tensor.to_dense() to convert to a dense tensor first." % - self.type() + "Use Tensor.to_dense() to convert to a dense tensor first." + % self.type() ) # RuntimeError, matching tensor.__array__() behavior. @@ -896,7 +1086,7 @@ def refine_names(self, *names): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.refine_names, (self,), self, *names) - names = resolve_ellipsis(names, self.names, 'refine_names') + names = resolve_ellipsis(names, self.names, "refine_names") return super(Tensor, self).refine_names(names) def align_to(self, *names): @@ -937,42 +1127,18 @@ def align_to(self, *names): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.align_to, (self,), self, *names) - ellipsis_idx = single_ellipsis_index(names, 'align_to') + ellipsis_idx = single_ellipsis_index(names, "align_to") if ellipsis_idx is None: return super(Tensor, self).align_to(names) return super(Tensor, self).align_to( - [name for name in names if not is_ellipsis(name)], - ellipsis_idx) + [name for name in names if not is_ellipsis(name)], ellipsis_idx + ) def unflatten(self, dim, sizes): - r"""Expands the dimension :attr:`dim` of the :attr:`self` tensor over multiple dimensions - of sizes given by :attr:`sizes`. - - * :attr:`sizes` is the new shape of the unflattened dimension and it can be a `Tuple[int]` as well - as `torch.Size` if :attr:`self` is a `Tensor`, or `namedshape` (Tuple[(name: str, size: int)]) - if :attr:`self` is a `NamedTensor`. The total number of elements in sizes must match the number - of elements in the original dim being unflattened. - - Args: - dim (Union[int, str]): Dimension to unflatten - sizes (Union[Tuple[int] or torch.Size, Tuple[Tuple[str, int]]]): New shape of the unflattened dimension - - Examples: - >>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape - torch.Size([3, 2, 2, 1]) - >>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1 - torch.Size([3, 2, 2, 1]) - >>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))) - tensor([[[-1.1772, 0.0180], - [ 0.2412, 0.1431]], - [[-1.1819, -0.8899], - [ 1.5813, 0.2274]]], names=('A', 'B1', 'B2')) - >>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1))) - tensor([[-0.8591], - [ 0.3100]], names=('B1', 'B2')) + r""" + unflatten(dim, sizes) -> Tensor - .. warning:: - The named tensor API is experimental and subject to change. + See :func:`torch.unflatten`. """ if has_torch_function_unary(self): @@ -982,16 +1148,21 @@ def unflatten(self, dim, sizes): raise RuntimeError("unflatten: sizes must be non-empty") names = None - if isinstance(sizes, OrderedDict) or (isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))): + if isinstance(sizes, OrderedDict) or ( + isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list)) + ): names, sizes = unzip_namedshape(sizes) - return super(Tensor, self).unflatten(dim, sizes, names) - + return super(Tensor, self).unflatten(dim, sizes, names) + else: + return super(Tensor, self).unflatten(dim, sizes) def rename_(self, *names, **rename_map): """In-place version of :meth:`~Tensor.rename`.""" if has_torch_function_unary(self): - return handle_torch_function(Tensor.rename_, (self,), self, *names, **rename_map) + return handle_torch_function( + Tensor.rename_, (self,), self, *names, **rename_map + ) # Note [rename_ / rename API] # The Python API for these is different from the C++ API. In Python: @@ -1035,27 +1206,31 @@ def rename(self, *names, **rename_map): """ if has_torch_function_unary(self): - return handle_torch_function(Tensor.rename, (self,), self, *names, **rename_map) + return handle_torch_function( + Tensor.rename, (self,), self, *names, **rename_map + ) # See Note [rename_ / rename API] return update_names(self, names, rename_map, inplace=False) def to_sparse_coo(self): - """ Convert a tensor to :ref:`coordinate format `. + """Convert a tensor to :ref:`coordinate format `. - Examples:: + Examples:: - >>> dense = torch.randn(5, 5) - >>> sparse = dense.to_sparse_coo() - >>> sparse._nnz() - 25 + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_coo() + >>> sparse._nnz() + 25 - """ + """ return self.to_sparse() def _update_names(self, names, inplace): if has_torch_function_unary(self): - return handle_torch_function(Tensor._update_names, (self,), self, names, inplace) + return handle_torch_function( + Tensor._update_names, (self,), self, names, inplace + ) # See Note [rename_ / rename API] if inplace: @@ -1116,18 +1291,22 @@ def __dlpack__(self, stream=None): # so we prohibit exporting tensors that would lose their properties like # requires_grad and having the conjugate bit set. if self.requires_grad: - raise RuntimeError('Can\'t export tensors that require gradient, use tensor.detach()') + raise RuntimeError( + "Can't export tensors that require gradient, use tensor.detach()" + ) if self.is_conj(): - raise RuntimeError('Can\'t export tensors with the conjugate bit set') + raise RuntimeError("Can't export tensors with the conjugate bit set") if self.layout != torch.strided: - raise RuntimeError('Can\'t export tensors with layout other than torch.strided') + raise RuntimeError( + "Can't export tensors with layout other than torch.strided" + ) if stream is not None and type(stream) is not int: # Stream pointers in CUDA/ROCm are uniquely numbered and can # be retrieved from their integer value. - raise TypeError('stream must be ``int`` or ``none``') + raise TypeError("stream must be ``int`` or ``none``") elif stream is not None and stream != -1: - if self.device.type == 'cuda': + if self.device.type == "cuda": stream = torch.cuda.ExternalStream(stream) # Only synchronize on different streams if stream != torch.cuda.current_stream: @@ -1139,22 +1318,26 @@ def __dlpack__(self, stream=None): def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: # Avoid circular import from torch.utils.dlpack import DLDeviceType + if has_torch_function_unary(self): return handle_torch_function(Tensor.__dlpack_device__, (self,), self) idx = self.device.index if self.device.index is not None else 0 - if self.device.type == 'cuda' and torch.version.hip is not None: + if self.device.type == "cuda" and torch.version.hip is not None: device_type = DLDeviceType.kDLROCM - elif self.device.type == 'cpu' and self.is_pinned(): + elif self.device.type == "cpu" and self.is_pinned(): device_type = DLDeviceType.kDLCPUPinned - elif self.device.type == 'cuda': + elif self.device.type == "cuda": device_type = DLDeviceType.kDLGPU - elif self.device.type == 'cpu': + elif self.device.type == "cpu": device_type = DLDeviceType.kDLCPU else: - raise ValueError('Unknown device type {} for Dlpack'.format(self.device.type)) + raise ValueError( + "Unknown device type {} for Dlpack".format(self.device.type) + ) return (device_type, idx) - __module__ = 'torch' + __module__ = "torch" + def _convert(ret, cls): if cls is Tensor: diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 10e4286c03909..3380942c02877 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2,19 +2,22 @@ import torch._C from torch._C import _add_docstr as add_docstr -from ._torch_docs import parse_kwargs -from ._torch_docs import reproducibility_notes +from ._torch_docs import parse_kwargs, reproducibility_notes def add_docstr_all(method, docstr): add_docstr(getattr(torch._C._TensorBase, method), docstr) -common_args = parse_kwargs(""" + +common_args = parse_kwargs( + """ memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned Tensor. Default: ``torch.preserve_format``. -""") +""" +) -new_common_args = parse_kwargs(""" +new_common_args = parse_kwargs( + """ size (int...): a list, tuple, or :class:`torch.Size` of integers defining the shape of the output tensor. dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. @@ -25,10 +28,12 @@ def add_docstr_all(method, docstr): returned tensor. Default: ``False``. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: ``False``. -""") +""" +) -add_docstr_all('new_tensor', - r""" +add_docstr_all( + "new_tensor", + r""" new_tensor(data, dtype=None, device=None, requires_grad=False) -> Tensor Returns a new Tensor with :attr:`data` as the tensor data. @@ -64,10 +69,14 @@ def add_docstr_all(method, docstr): tensor([[ 0, 1], [ 2, 3]], dtype=torch.int8) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('new_full', - r""" +add_docstr_all( + "new_full", + r""" new_full(size, fill_value, dtype=None, device=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. @@ -88,10 +97,14 @@ def add_docstr_all(method, docstr): [ 3.1416, 3.1416, 3.1416, 3.1416], [ 3.1416, 3.1416, 3.1416, 3.1416]], dtype=torch.float64) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('new_empty', - r""" +add_docstr_all( + "new_empty", + r""" new_empty(size, dtype=None, device=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` filled with uninitialized data. @@ -110,10 +123,14 @@ def add_docstr_all(method, docstr): tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('new_empty_strided', - r""" +add_docstr_all( + "new_empty_strided", + r""" new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with @@ -132,10 +149,14 @@ def add_docstr_all(method, docstr): tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('new_ones', - r""" +add_docstr_all( + "new_ones", + r""" new_ones(size, dtype=None, device=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` filled with ``1``. @@ -156,10 +177,14 @@ def add_docstr_all(method, docstr): tensor([[ 1, 1, 1], [ 1, 1, 1]], dtype=torch.int32) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('new_zeros', - r""" +add_docstr_all( + "new_zeros", + r""" new_zeros(size, dtype=None, device=None, requires_grad=False) -> Tensor Returns a Tensor of size :attr:`size` filled with ``0``. @@ -180,91 +205,123 @@ def add_docstr_all(method, docstr): tensor([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=torch.float64) -""".format(**new_common_args)) +""".format( + **new_common_args + ), +) -add_docstr_all('abs', - r""" +add_docstr_all( + "abs", + r""" abs() -> Tensor See :func:`torch.abs` -""") +""", +) -add_docstr_all('abs_', - r""" +add_docstr_all( + "abs_", + r""" abs_() -> Tensor In-place version of :meth:`~Tensor.abs` -""") +""", +) -add_docstr_all('absolute', - r""" +add_docstr_all( + "absolute", + r""" absolute() -> Tensor Alias for :func:`abs` -""") +""", +) -add_docstr_all('absolute_', - r""" +add_docstr_all( + "absolute_", + r""" absolute_() -> Tensor In-place version of :meth:`~Tensor.absolute` Alias for :func:`abs_` -""") +""", +) -add_docstr_all('acos', - r""" +add_docstr_all( + "acos", + r""" acos() -> Tensor See :func:`torch.acos` -""") +""", +) -add_docstr_all('acos_', - r""" +add_docstr_all( + "acos_", + r""" acos_() -> Tensor In-place version of :meth:`~Tensor.acos` -""") +""", +) -add_docstr_all('arccos', r""" +add_docstr_all( + "arccos", + r""" arccos() -> Tensor See :func:`torch.arccos` -""") +""", +) -add_docstr_all('arccos_', r""" +add_docstr_all( + "arccos_", + r""" arccos_() -> Tensor In-place version of :meth:`~Tensor.arccos` -""") +""", +) -add_docstr_all('acosh', - r""" +add_docstr_all( + "acosh", + r""" acosh() -> Tensor See :func:`torch.acosh` -""") +""", +) -add_docstr_all('acosh_', - r""" +add_docstr_all( + "acosh_", + r""" acosh_() -> Tensor In-place version of :meth:`~Tensor.acosh` -""") +""", +) -add_docstr_all('arccosh', r""" +add_docstr_all( + "arccosh", + r""" acosh() -> Tensor See :func:`torch.arccosh` -""") +""", +) -add_docstr_all('arccosh_', r""" +add_docstr_all( + "arccosh_", + r""" acosh_() -> Tensor In-place version of :meth:`~Tensor.arccosh` -""") +""", +) -add_docstr_all('add', - r""" +add_docstr_all( + "add", + r""" add(other, *, alpha=1) -> Tensor Add a scalar or tensor to :attr:`self` tensor. If both :attr:`alpha` @@ -276,115 +333,147 @@ def add_docstr_all(method, docstr): tensor See :func:`torch.add` -""") +""", +) -add_docstr_all('add_', - r""" +add_docstr_all( + "add_", + r""" add_(other, *, alpha=1) -> Tensor In-place version of :meth:`~Tensor.add` -""") +""", +) -add_docstr_all('addbmm', - r""" +add_docstr_all( + "addbmm", + r""" addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addbmm` -""") +""", +) -add_docstr_all('addbmm_', - r""" +add_docstr_all( + "addbmm_", + r""" addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addbmm` -""") +""", +) -add_docstr_all('addcdiv', - r""" +add_docstr_all( + "addcdiv", + r""" addcdiv(tensor1, tensor2, *, value=1) -> Tensor See :func:`torch.addcdiv` -""") +""", +) -add_docstr_all('addcdiv_', - r""" +add_docstr_all( + "addcdiv_", + r""" addcdiv_(tensor1, tensor2, *, value=1) -> Tensor In-place version of :meth:`~Tensor.addcdiv` -""") +""", +) -add_docstr_all('addcmul', - r""" +add_docstr_all( + "addcmul", + r""" addcmul(tensor1, tensor2, *, value=1) -> Tensor See :func:`torch.addcmul` -""") +""", +) -add_docstr_all('addcmul_', - r""" +add_docstr_all( + "addcmul_", + r""" addcmul_(tensor1, tensor2, *, value=1) -> Tensor In-place version of :meth:`~Tensor.addcmul` -""") +""", +) -add_docstr_all('addmm', - r""" +add_docstr_all( + "addmm", + r""" addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addmm` -""") +""", +) -add_docstr_all('addmm_', - r""" +add_docstr_all( + "addmm_", + r""" addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addmm` -""") +""", +) -add_docstr_all('addmv', - r""" +add_docstr_all( + "addmv", + r""" addmv(mat, vec, *, beta=1, alpha=1) -> Tensor See :func:`torch.addmv` -""") +""", +) -add_docstr_all('addmv_', - r""" +add_docstr_all( + "addmv_", + r""" addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addmv` -""") +""", +) -add_docstr_all('sspaddmm', - r""" +add_docstr_all( + "sspaddmm", + r""" sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor See :func:`torch.sspaddmm` -""") +""", +) -add_docstr_all('smm', - r""" +add_docstr_all( + "smm", + r""" smm(mat) -> Tensor See :func:`torch.smm` -""") +""", +) -add_docstr_all('addr', - r""" +add_docstr_all( + "addr", + r""" addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addr` -""") +""", +) -add_docstr_all('addr_', - r""" +add_docstr_all( + "addr_", + r""" addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addr` -""") +""", +) -add_docstr_all('align_as', - r""" +add_docstr_all( + "align_as", + r""" align_as(other) -> Tensor Permutes the dimensions of the :attr:`self` tensor to match the dimension order @@ -428,38 +517,48 @@ def add_docstr_all(method, docstr): .. warning:: The named tensor API is experimental and subject to change. -""") +""", +) -add_docstr_all('all', - r""" +add_docstr_all( + "all", + r""" all(dim=None, keepdim=False) -> Tensor See :func:`torch.all` -""") +""", +) -add_docstr_all('allclose', - r""" +add_docstr_all( + "allclose", + r""" allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor See :func:`torch.allclose` -""") +""", +) -add_docstr_all('angle', - r""" +add_docstr_all( + "angle", + r""" angle() -> Tensor See :func:`torch.angle` -""") +""", +) -add_docstr_all('any', - r""" +add_docstr_all( + "any", + r""" any(dim=None, keepdim=False) -> Tensor See :func:`torch.any` -""") +""", +) -add_docstr_all('apply_', - r""" +add_docstr_all( + "apply_", + r""" apply_(callable) -> Tensor Applies the function :attr:`callable` to each element in the tensor, replacing @@ -469,152 +568,219 @@ def add_docstr_all(method, docstr): This function only works with CPU tensors and should not be used in code sections that require high performance. -""") +""", +) -add_docstr_all('asin', r""" +add_docstr_all( + "asin", + r""" asin() -> Tensor See :func:`torch.asin` -""") +""", +) -add_docstr_all('asin_', - r""" +add_docstr_all( + "asin_", + r""" asin_() -> Tensor In-place version of :meth:`~Tensor.asin` -""") +""", +) -add_docstr_all('arcsin', r""" +add_docstr_all( + "arcsin", + r""" arcsin() -> Tensor See :func:`torch.arcsin` -""") +""", +) -add_docstr_all('arcsin_', r""" +add_docstr_all( + "arcsin_", + r""" arcsin_() -> Tensor In-place version of :meth:`~Tensor.arcsin` -""") +""", +) -add_docstr_all('asinh', r""" +add_docstr_all( + "asinh", + r""" asinh() -> Tensor See :func:`torch.asinh` -""") +""", +) -add_docstr_all('asinh_', - r""" +add_docstr_all( + "asinh_", + r""" asinh_() -> Tensor In-place version of :meth:`~Tensor.asinh` -""") +""", +) -add_docstr_all('arcsinh', r""" +add_docstr_all( + "arcsinh", + r""" arcsinh() -> Tensor See :func:`torch.arcsinh` -""") +""", +) -add_docstr_all('arcsinh_', r""" +add_docstr_all( + "arcsinh_", + r""" arcsinh_() -> Tensor In-place version of :meth:`~Tensor.arcsinh` -""") +""", +) -add_docstr_all('as_strided', r""" +add_docstr_all( + "as_strided", + r""" as_strided(size, stride, storage_offset=None) -> Tensor See :func:`torch.as_strided` -""") +""", +) -add_docstr_all('atan', r""" +add_docstr_all( + "atan", + r""" atan() -> Tensor See :func:`torch.atan` -""") +""", +) -add_docstr_all('atan_', r""" +add_docstr_all( + "atan_", + r""" atan_() -> Tensor In-place version of :meth:`~Tensor.atan` -""") +""", +) -add_docstr_all('arctan', r""" +add_docstr_all( + "arctan", + r""" arctan() -> Tensor See :func:`torch.arctan` -""") +""", +) -add_docstr_all('arctan_', r""" +add_docstr_all( + "arctan_", + r""" arctan_() -> Tensor In-place version of :meth:`~Tensor.arctan` -""") +""", +) -add_docstr_all('atan2', r""" +add_docstr_all( + "atan2", + r""" atan2(other) -> Tensor See :func:`torch.atan2` -""") +""", +) -add_docstr_all('atan2_', r""" +add_docstr_all( + "atan2_", + r""" atan2_(other) -> Tensor In-place version of :meth:`~Tensor.atan2` -""") +""", +) -add_docstr_all('arctan2', r""" +add_docstr_all( + "arctan2", + r""" arctan2(other) -> Tensor See :func:`torch.arctan2` -""") +""", +) -add_docstr_all('arctan2_', r""" +add_docstr_all( + "arctan2_", + r""" atan2_(other) -> Tensor In-place version of :meth:`~Tensor.arctan2` -""") +""", +) -add_docstr_all('atanh', r""" +add_docstr_all( + "atanh", + r""" atanh() -> Tensor See :func:`torch.atanh` -""") +""", +) -add_docstr_all('atanh_', r""" +add_docstr_all( + "atanh_", + r""" atanh_(other) -> Tensor In-place version of :meth:`~Tensor.atanh` -""") +""", +) -add_docstr_all('arctanh', r""" +add_docstr_all( + "arctanh", + r""" arctanh() -> Tensor See :func:`torch.arctanh` -""") +""", +) -add_docstr_all('arctanh_', r""" +add_docstr_all( + "arctanh_", + r""" arctanh_(other) -> Tensor In-place version of :meth:`~Tensor.arctanh` -""") +""", +) -add_docstr_all('baddbmm', - r""" +add_docstr_all( + "baddbmm", + r""" baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor See :func:`torch.baddbmm` -""") +""", +) -add_docstr_all('baddbmm_', - r""" +add_docstr_all( + "baddbmm_", + r""" baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.baddbmm` -""") +""", +) -add_docstr_all('bernoulli', - r""" +add_docstr_all( + "bernoulli", + r""" bernoulli(*, generator=None) -> Tensor Returns a result tensor where each :math:`\texttt{result[i]}` is independently @@ -622,10 +788,12 @@ def add_docstr_all(method, docstr): floating point ``dtype``, and the result will have the same ``dtype``. See :func:`torch.bernoulli` -""") +""", +) -add_docstr_all('bernoulli_', - r""" +add_docstr_all( + "bernoulli_", + r""" bernoulli_(p=0.5, *, generator=None) -> Tensor Fills each location of :attr:`self` with an independent sample from @@ -641,171 +809,219 @@ def add_docstr_all(method, docstr): floating point ``dtype``. See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` -""") +""", +) -add_docstr_all('bincount', - r""" +add_docstr_all( + "bincount", + r""" bincount(weights=None, minlength=0) -> Tensor See :func:`torch.bincount` -""") +""", +) -add_docstr_all('bitwise_not', - r""" +add_docstr_all( + "bitwise_not", + r""" bitwise_not() -> Tensor See :func:`torch.bitwise_not` -""") +""", +) -add_docstr_all('bitwise_not_', - r""" +add_docstr_all( + "bitwise_not_", + r""" bitwise_not_() -> Tensor In-place version of :meth:`~Tensor.bitwise_not` -""") +""", +) -add_docstr_all('bitwise_and', - r""" +add_docstr_all( + "bitwise_and", + r""" bitwise_and() -> Tensor See :func:`torch.bitwise_and` -""") +""", +) -add_docstr_all('bitwise_and_', - r""" +add_docstr_all( + "bitwise_and_", + r""" bitwise_and_() -> Tensor In-place version of :meth:`~Tensor.bitwise_and` -""") +""", +) -add_docstr_all('bitwise_or', - r""" +add_docstr_all( + "bitwise_or", + r""" bitwise_or() -> Tensor See :func:`torch.bitwise_or` -""") +""", +) -add_docstr_all('bitwise_or_', - r""" +add_docstr_all( + "bitwise_or_", + r""" bitwise_or_() -> Tensor In-place version of :meth:`~Tensor.bitwise_or` -""") +""", +) -add_docstr_all('bitwise_xor', - r""" +add_docstr_all( + "bitwise_xor", + r""" bitwise_xor() -> Tensor See :func:`torch.bitwise_xor` -""") +""", +) -add_docstr_all('bitwise_xor_', - r""" +add_docstr_all( + "bitwise_xor_", + r""" bitwise_xor_() -> Tensor In-place version of :meth:`~Tensor.bitwise_xor` -""") +""", +) -add_docstr_all('bitwise_left_shift', - r""" +add_docstr_all( + "bitwise_left_shift", + r""" bitwise_left_shift(other) -> Tensor See :func:`torch.bitwise_left_shift` -""") +""", +) -add_docstr_all('bitwise_left_shift_', - r""" +add_docstr_all( + "bitwise_left_shift_", + r""" bitwise_left_shift_(other) -> Tensor In-place version of :meth:`~Tensor.bitwise_left_shift` -""") +""", +) -add_docstr_all('bitwise_right_shift', - r""" +add_docstr_all( + "bitwise_right_shift", + r""" bitwise_right_shift(other) -> Tensor See :func:`torch.bitwise_right_shift` -""") +""", +) -add_docstr_all('bitwise_right_shift_', - r""" +add_docstr_all( + "bitwise_right_shift_", + r""" bitwise_right_shift_(other) -> Tensor In-place version of :meth:`~Tensor.bitwise_right_shift` -""") +""", +) -add_docstr_all('broadcast_to', - r""" +add_docstr_all( + "broadcast_to", + r""" broadcast_to(shape) -> Tensor See :func:`torch.broadcast_to`. -""") +""", +) -add_docstr_all('logical_and', - r""" +add_docstr_all( + "logical_and", + r""" logical_and() -> Tensor See :func:`torch.logical_and` -""") +""", +) -add_docstr_all('logical_and_', - r""" +add_docstr_all( + "logical_and_", + r""" logical_and_() -> Tensor In-place version of :meth:`~Tensor.logical_and` -""") +""", +) -add_docstr_all('logical_not', - r""" +add_docstr_all( + "logical_not", + r""" logical_not() -> Tensor See :func:`torch.logical_not` -""") +""", +) -add_docstr_all('logical_not_', - r""" +add_docstr_all( + "logical_not_", + r""" logical_not_() -> Tensor In-place version of :meth:`~Tensor.logical_not` -""") +""", +) -add_docstr_all('logical_or', - r""" +add_docstr_all( + "logical_or", + r""" logical_or() -> Tensor See :func:`torch.logical_or` -""") +""", +) -add_docstr_all('logical_or_', - r""" +add_docstr_all( + "logical_or_", + r""" logical_or_() -> Tensor In-place version of :meth:`~Tensor.logical_or` -""") +""", +) -add_docstr_all('logical_xor', - r""" +add_docstr_all( + "logical_xor", + r""" logical_xor() -> Tensor See :func:`torch.logical_xor` -""") +""", +) -add_docstr_all('logical_xor_', - r""" +add_docstr_all( + "logical_xor_", + r""" logical_xor_() -> Tensor In-place version of :meth:`~Tensor.logical_xor` -""") +""", +) -add_docstr_all('bmm', - r""" +add_docstr_all( + "bmm", + r""" bmm(batch2) -> Tensor See :func:`torch.bmm` -""") +""", +) -add_docstr_all('cauchy_', - r""" +add_docstr_all( + "cauchy_", + r""" cauchy_(median=0, sigma=1, *, generator=None) -> Tensor Fills the tensor with numbers drawn from the Cauchy distribution: @@ -813,77 +1029,104 @@ def add_docstr_all(method, docstr): .. math:: f(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2} -""") +""", +) -add_docstr_all('ceil', - r""" +add_docstr_all( + "ceil", + r""" ceil() -> Tensor See :func:`torch.ceil` -""") +""", +) -add_docstr_all('ceil_', - r""" +add_docstr_all( + "ceil_", + r""" ceil_() -> Tensor In-place version of :meth:`~Tensor.ceil` -""") +""", +) -add_docstr_all('cholesky', - r""" +add_docstr_all( + "cholesky", + r""" cholesky(upper=False) -> Tensor See :func:`torch.cholesky` -""") +""", +) -add_docstr_all('cholesky_solve', - r""" +add_docstr_all( + "cholesky_solve", + r""" cholesky_solve(input2, upper=False) -> Tensor See :func:`torch.cholesky_solve` -""") +""", +) -add_docstr_all('cholesky_inverse', - r""" +add_docstr_all( + "cholesky_inverse", + r""" cholesky_inverse(upper=False) -> Tensor See :func:`torch.cholesky_inverse` -""") +""", +) -add_docstr_all('clamp', - r""" +add_docstr_all( + "clamp", + r""" clamp(min=None, max=None) -> Tensor See :func:`torch.clamp` -""") +""", +) -add_docstr_all('clamp_', - r""" +add_docstr_all( + "clamp_", + r""" clamp_(min=None, max=None) -> Tensor In-place version of :meth:`~Tensor.clamp` -""") +""", +) -add_docstr_all('clip', r""" +add_docstr_all( + "clip", + r""" clip(min=None, max=None) -> Tensor Alias for :meth:`~Tensor.clamp`. -""") +""", +) -add_docstr_all('clip_', r""" +add_docstr_all( + "clip_", + r""" clip_(min=None, max=None) -> Tensor Alias for :meth:`~Tensor.clamp_`. -""") +""", +) -add_docstr_all('clone', r""" +add_docstr_all( + "clone", + r""" clone(*, memory_format=torch.preserve_format) -> Tensor See :func:`torch.clone` -""".format(**common_args)) - -add_docstr_all('coalesce', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "coalesce", + r""" coalesce() -> Tensor Returns a coalesced copy of :attr:`self` if :attr:`self` is an @@ -893,10 +1136,12 @@ def add_docstr_all(method, docstr): .. warning:: Throws an error if :attr:`self` is not a sparse COO tensor. -""") +""", +) -add_docstr_all('contiguous', - r""" +add_docstr_all( + "contiguous", + r""" contiguous(memory_format=torch.contiguous_format) -> Tensor Returns a contiguous in memory tensor containing the same data as :attr:`self` tensor. If @@ -906,10 +1151,12 @@ def add_docstr_all(method, docstr): Args: memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned Tensor. Default: ``torch.contiguous_format``. -""") +""", +) -add_docstr_all('copy_', - r""" +add_docstr_all( + "copy_", + r""" copy_(src, non_blocking=False) -> Tensor Copies the elements from :attr:`src` into :attr:`self` tensor and returns @@ -924,86 +1171,111 @@ def add_docstr_all(method, docstr): non_blocking (bool): if ``True`` and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. -""") +""", +) -add_docstr_all('conj', - r""" +add_docstr_all( + "conj", + r""" conj() -> Tensor See :func:`torch.conj` -""") +""", +) -add_docstr_all('conj_physical', - r""" +add_docstr_all( + "conj_physical", + r""" conj_physical() -> Tensor See :func:`torch.conj_physical` -""") +""", +) -add_docstr_all('conj_physical_', - r""" +add_docstr_all( + "conj_physical_", + r""" conj_physical_() -> Tensor In-place version of :meth:`~Tensor.conj_physical` -""") +""", +) -add_docstr_all('resolve_conj', - r""" +add_docstr_all( + "resolve_conj", + r""" resolve_conj() -> Tensor See :func:`torch.resolve_conj` -""") +""", +) -add_docstr_all('resolve_neg', - r""" +add_docstr_all( + "resolve_neg", + r""" resolve_neg() -> Tensor See :func:`torch.resolve_neg` -""") +""", +) -add_docstr_all('copysign', - r""" +add_docstr_all( + "copysign", + r""" copysign(other) -> Tensor See :func:`torch.copysign` -""") +""", +) -add_docstr_all('copysign_', r""" +add_docstr_all( + "copysign_", + r""" copysign_(other) -> Tensor In-place version of :meth:`~Tensor.copysign` -""") +""", +) -add_docstr_all('cos', - r""" +add_docstr_all( + "cos", + r""" cos() -> Tensor See :func:`torch.cos` -""") +""", +) -add_docstr_all('cos_', - r""" +add_docstr_all( + "cos_", + r""" cos_() -> Tensor In-place version of :meth:`~Tensor.cos` -""") +""", +) -add_docstr_all('cosh', - r""" +add_docstr_all( + "cosh", + r""" cosh() -> Tensor See :func:`torch.cosh` -""") +""", +) -add_docstr_all('cosh_', - r""" +add_docstr_all( + "cosh_", + r""" cosh_() -> Tensor In-place version of :meth:`~Tensor.cosh` -""") +""", +) -add_docstr_all('cpu', - r""" +add_docstr_all( + "cpu", + r""" cpu(memory_format=torch.preserve_format) -> Tensor Returns a copy of this object in CPU memory. @@ -1014,36 +1286,50 @@ def add_docstr_all(method, docstr): Args: {memory_format} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr_all('count_nonzero', - r""" +add_docstr_all( + "count_nonzero", + r""" count_nonzero(dim=None) -> Tensor See :func:`torch.count_nonzero` -""") +""", +) -add_docstr_all('cov', r""" +add_docstr_all( + "cov", + r""" cov(*, correction=1, fweights=None, aweights=None) -> Tensor See :func:`torch.cov` -""") +""", +) -add_docstr_all('corrcoef', r""" +add_docstr_all( + "corrcoef", + r""" corrcoef() -> Tensor See :func:`torch.corrcoef` -""") +""", +) -add_docstr_all('cross', - r""" +add_docstr_all( + "cross", + r""" cross(other, dim=None) -> Tensor See :func:`torch.cross` -""") +""", +) -add_docstr_all('cuda', - r""" +add_docstr_all( + "cuda", + r""" cuda(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor Returns a copy of this object in CUDA memory. @@ -1058,10 +1344,14 @@ def add_docstr_all(method, docstr): the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} -""".format(**common_args)) - -add_docstr_all('ipu', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "ipu", + r""" ipu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor Returns a copy of this object in IPU memory. @@ -1076,10 +1366,14 @@ def add_docstr_all(method, docstr): the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} -""".format(**common_args)) - -add_docstr_all('xpu', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "xpu", + r""" xpu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor Returns a copy of this object in XPU memory. @@ -1094,73 +1388,95 @@ def add_docstr_all(method, docstr): the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} -""".format(**common_args)) - -add_docstr_all('logcumsumexp', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "logcumsumexp", + r""" logcumsumexp(dim) -> Tensor See :func:`torch.logcumsumexp` -""") +""", +) -add_docstr_all('cummax', - r""" +add_docstr_all( + "cummax", + r""" cummax(dim) -> (Tensor, Tensor) See :func:`torch.cummax` -""") +""", +) -add_docstr_all('cummin', - r""" +add_docstr_all( + "cummin", + r""" cummin(dim) -> (Tensor, Tensor) See :func:`torch.cummin` -""") +""", +) -add_docstr_all('cumprod', - r""" +add_docstr_all( + "cumprod", + r""" cumprod(dim, dtype=None) -> Tensor See :func:`torch.cumprod` -""") +""", +) -add_docstr_all('cumprod_', - r""" +add_docstr_all( + "cumprod_", + r""" cumprod_(dim, dtype=None) -> Tensor In-place version of :meth:`~Tensor.cumprod` -""") +""", +) -add_docstr_all('cumsum', - r""" +add_docstr_all( + "cumsum", + r""" cumsum(dim, dtype=None) -> Tensor See :func:`torch.cumsum` -""") +""", +) -add_docstr_all('cumsum_', - r""" +add_docstr_all( + "cumsum_", + r""" cumsum_(dim, dtype=None) -> Tensor In-place version of :meth:`~Tensor.cumsum` -""") +""", +) -add_docstr_all('data_ptr', - r""" +add_docstr_all( + "data_ptr", + r""" data_ptr() -> int Returns the address of the first element of :attr:`self` tensor. -""") +""", +) -add_docstr_all('dequantize', - r""" +add_docstr_all( + "dequantize", + r""" dequantize() -> Tensor Given a quantized Tensor, dequantize it and return the dequantized float Tensor. -""") +""", +) -add_docstr_all('dense_dim', - r""" +add_docstr_all( + "dense_dim", + r""" dense_dim() -> int Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. @@ -1169,52 +1485,66 @@ def add_docstr_all(method, docstr): Throws an error if :attr:`self` is not a sparse tensor. See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. -""") +""", +) -add_docstr_all('diag', - r""" +add_docstr_all( + "diag", + r""" diag(diagonal=0) -> Tensor See :func:`torch.diag` -""") +""", +) -add_docstr_all('diag_embed', - r""" +add_docstr_all( + "diag_embed", + r""" diag_embed(offset=0, dim1=-2, dim2=-1) -> Tensor See :func:`torch.diag_embed` -""") +""", +) -add_docstr_all('diagflat', - r""" +add_docstr_all( + "diagflat", + r""" diagflat(offset=0) -> Tensor See :func:`torch.diagflat` -""") +""", +) -add_docstr_all('diagonal', - r""" +add_docstr_all( + "diagonal", + r""" diagonal(offset=0, dim1=0, dim2=1) -> Tensor See :func:`torch.diagonal` -""") +""", +) -add_docstr_all('diagonal_scatter', - r""" +add_docstr_all( + "diagonal_scatter", + r""" diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor See :func:`torch.diagonal_scatter` -""") +""", +) -add_docstr_all('as_strided_scatter', - r""" +add_docstr_all( + "as_strided_scatter", + r""" as_strided_scatter(src, size, stride, storage_offset=0) -> Tensor See :func:`torch.as_strided_scatter` -""") +""", +) -add_docstr_all('fill_diagonal_', - r""" +add_docstr_all( + "fill_diagonal_", + r""" fill_diagonal_(fill_value, wrap=False) -> Tensor Fill the main diagonal of a tensor that has at least 2-dimensions. @@ -1251,97 +1581,129 @@ def add_docstr_all(method, docstr): [0., 5., 0.], [0., 0., 5.]]) -""") +""", +) -add_docstr_all('floor_divide', - r""" +add_docstr_all( + "floor_divide", + r""" floor_divide(value) -> Tensor See :func:`torch.floor_divide` -""") +""", +) -add_docstr_all('floor_divide_', - r""" +add_docstr_all( + "floor_divide_", + r""" floor_divide_(value) -> Tensor In-place version of :meth:`~Tensor.floor_divide` -""") +""", +) -add_docstr_all('diff', - r""" +add_docstr_all( + "diff", + r""" diff(n=1, dim=-1, prepend=None, append=None) -> Tensor See :func:`torch.diff` -""") +""", +) -add_docstr_all('digamma', - r""" +add_docstr_all( + "digamma", + r""" digamma() -> Tensor See :func:`torch.digamma` -""") +""", +) -add_docstr_all('digamma_', - r""" +add_docstr_all( + "digamma_", + r""" digamma_() -> Tensor In-place version of :meth:`~Tensor.digamma` -""") +""", +) -add_docstr_all('dim', - r""" +add_docstr_all( + "dim", + r""" dim() -> int Returns the number of dimensions of :attr:`self` tensor. -""") +""", +) -add_docstr_all('dist', - r""" +add_docstr_all( + "dist", + r""" dist(other, p=2) -> Tensor See :func:`torch.dist` -""") +""", +) -add_docstr_all('div', r""" +add_docstr_all( + "div", + r""" div(value, *, rounding_mode=None) -> Tensor See :func:`torch.div` -""") +""", +) -add_docstr_all('div_', r""" +add_docstr_all( + "div_", + r""" div_(value, *, rounding_mode=None) -> Tensor In-place version of :meth:`~Tensor.div` -""") +""", +) -add_docstr_all('divide', r""" +add_docstr_all( + "divide", + r""" divide(value, *, rounding_mode=None) -> Tensor See :func:`torch.divide` -""") +""", +) -add_docstr_all('divide_', r""" +add_docstr_all( + "divide_", + r""" divide_(value, *, rounding_mode=None) -> Tensor In-place version of :meth:`~Tensor.divide` -""") +""", +) -add_docstr_all('dot', - r""" +add_docstr_all( + "dot", + r""" dot(other) -> Tensor See :func:`torch.dot` -""") +""", +) -add_docstr_all('eig', - r""" +add_docstr_all( + "eig", + r""" eig(eigenvectors=False) -> (Tensor, Tensor) See :func:`torch.eig` -""") +""", +) -add_docstr_all('element_size', - r""" +add_docstr_all( + "element_size", + r""" element_size() -> int Returns the size in bytes of an individual element. @@ -1353,115 +1715,147 @@ def add_docstr_all(method, docstr): >>> torch.tensor([], dtype=torch.uint8).element_size() 1 -""") +""", +) -add_docstr_all('eq', - r""" +add_docstr_all( + "eq", + r""" eq(other) -> Tensor See :func:`torch.eq` -""") +""", +) -add_docstr_all('eq_', - r""" +add_docstr_all( + "eq_", + r""" eq_(other) -> Tensor In-place version of :meth:`~Tensor.eq` -""") +""", +) -add_docstr_all('equal', - r""" +add_docstr_all( + "equal", + r""" equal(other) -> bool See :func:`torch.equal` -""") +""", +) -add_docstr_all('erf', - r""" +add_docstr_all( + "erf", + r""" erf() -> Tensor See :func:`torch.erf` -""") +""", +) -add_docstr_all('erf_', - r""" +add_docstr_all( + "erf_", + r""" erf_() -> Tensor In-place version of :meth:`~Tensor.erf` -""") +""", +) -add_docstr_all('erfc', - r""" +add_docstr_all( + "erfc", + r""" erfc() -> Tensor See :func:`torch.erfc` -""") +""", +) -add_docstr_all('erfc_', - r""" +add_docstr_all( + "erfc_", + r""" erfc_() -> Tensor In-place version of :meth:`~Tensor.erfc` -""") +""", +) -add_docstr_all('erfinv', - r""" +add_docstr_all( + "erfinv", + r""" erfinv() -> Tensor See :func:`torch.erfinv` -""") +""", +) -add_docstr_all('erfinv_', - r""" +add_docstr_all( + "erfinv_", + r""" erfinv_() -> Tensor In-place version of :meth:`~Tensor.erfinv` -""") +""", +) -add_docstr_all('exp', - r""" +add_docstr_all( + "exp", + r""" exp() -> Tensor See :func:`torch.exp` -""") +""", +) -add_docstr_all('exp_', - r""" +add_docstr_all( + "exp_", + r""" exp_() -> Tensor In-place version of :meth:`~Tensor.exp` -""") +""", +) -add_docstr_all('exp2', - r""" +add_docstr_all( + "exp2", + r""" exp2() -> Tensor See :func:`torch.exp2` -""") +""", +) -add_docstr_all('exp2_', - r""" +add_docstr_all( + "exp2_", + r""" exp2_() -> Tensor In-place version of :meth:`~Tensor.exp2` -""") +""", +) -add_docstr_all('expm1', - r""" +add_docstr_all( + "expm1", + r""" expm1() -> Tensor See :func:`torch.expm1` -""") +""", +) -add_docstr_all('expm1_', - r""" +add_docstr_all( + "expm1_", + r""" expm1_() -> Tensor In-place version of :meth:`~Tensor.expm1` -""") +""", +) -add_docstr_all('exponential_', - r""" +add_docstr_all( + "exponential_", + r""" exponential_(lambd=1, *, generator=None) -> Tensor Fills :attr:`self` tensor with elements drawn from the exponential distribution: @@ -1469,146 +1863,192 @@ def add_docstr_all(method, docstr): .. math:: f(x) = \lambda e^{-\lambda x} -""") +""", +) -add_docstr_all('fill_', - r""" +add_docstr_all( + "fill_", + r""" fill_(value) -> Tensor Fills :attr:`self` tensor with the specified value. -""") +""", +) -add_docstr_all('floor', - r""" +add_docstr_all( + "floor", + r""" floor() -> Tensor See :func:`torch.floor` -""") +""", +) -add_docstr_all('flip', - r""" +add_docstr_all( + "flip", + r""" flip(dims) -> Tensor See :func:`torch.flip` -""") +""", +) -add_docstr_all('fliplr', - r""" +add_docstr_all( + "fliplr", + r""" fliplr() -> Tensor See :func:`torch.fliplr` -""") +""", +) -add_docstr_all('flipud', - r""" +add_docstr_all( + "flipud", + r""" flipud() -> Tensor See :func:`torch.flipud` -""") +""", +) -add_docstr_all('roll', - r""" +add_docstr_all( + "roll", + r""" roll(shifts, dims) -> Tensor See :func:`torch.roll` -""") +""", +) -add_docstr_all('floor_', - r""" +add_docstr_all( + "floor_", + r""" floor_() -> Tensor In-place version of :meth:`~Tensor.floor` -""") +""", +) -add_docstr_all('fmod', - r""" +add_docstr_all( + "fmod", + r""" fmod(divisor) -> Tensor See :func:`torch.fmod` -""") +""", +) -add_docstr_all('fmod_', - r""" +add_docstr_all( + "fmod_", + r""" fmod_(divisor) -> Tensor In-place version of :meth:`~Tensor.fmod` -""") +""", +) -add_docstr_all('frac', - r""" +add_docstr_all( + "frac", + r""" frac() -> Tensor See :func:`torch.frac` -""") +""", +) -add_docstr_all('frac_', - r""" +add_docstr_all( + "frac_", + r""" frac_() -> Tensor In-place version of :meth:`~Tensor.frac` -""") +""", +) -add_docstr_all('frexp', - r""" +add_docstr_all( + "frexp", + r""" frexp(input) -> (Tensor mantissa, Tensor exponent) See :func:`torch.frexp` -""") +""", +) -add_docstr_all('flatten', - r""" +add_docstr_all( + "flatten", + r""" flatten(start_dim=0, end_dim=-1) -> Tensor See :func:`torch.flatten` -""") +""", +) -add_docstr_all('gather', - r""" +add_docstr_all( + "gather", + r""" gather(dim, index) -> Tensor See :func:`torch.gather` -""") +""", +) -add_docstr_all('gcd', - r""" +add_docstr_all( + "gcd", + r""" gcd(other) -> Tensor See :func:`torch.gcd` -""") +""", +) -add_docstr_all('gcd_', - r""" +add_docstr_all( + "gcd_", + r""" gcd_(other) -> Tensor In-place version of :meth:`~Tensor.gcd` -""") +""", +) -add_docstr_all('ge', r""" +add_docstr_all( + "ge", + r""" ge(other) -> Tensor See :func:`torch.ge`. -""") +""", +) -add_docstr_all('ge_', r""" +add_docstr_all( + "ge_", + r""" ge_(other) -> Tensor In-place version of :meth:`~Tensor.ge`. -""") +""", +) -add_docstr_all('greater_equal', r""" +add_docstr_all( + "greater_equal", + r""" greater_equal(other) -> Tensor See :func:`torch.greater_equal`. -""") +""", +) -add_docstr_all('greater_equal_', r""" +add_docstr_all( + "greater_equal_", + r""" greater_equal_(other) -> Tensor In-place version of :meth:`~Tensor.greater_equal`. -""") +""", +) -add_docstr_all('geometric_', - r""" +add_docstr_all( + "geometric_", + r""" geometric_(p, *, generator=None) -> Tensor Fills :attr:`self` tensor with elements drawn from the geometric distribution: @@ -1617,90 +2057,118 @@ def add_docstr_all(method, docstr): f(X=k) = p^{k - 1} (1 - p) -""") +""", +) -add_docstr_all('geqrf', - r""" +add_docstr_all( + "geqrf", + r""" geqrf() -> (Tensor, Tensor) See :func:`torch.geqrf` -""") +""", +) -add_docstr_all('ger', - r""" +add_docstr_all( + "ger", + r""" ger(vec2) -> Tensor See :func:`torch.ger` -""") +""", +) -add_docstr_all('inner', r""" +add_docstr_all( + "inner", + r""" inner(other) -> Tensor See :func:`torch.inner`. -""") +""", +) -add_docstr_all('outer', r""" +add_docstr_all( + "outer", + r""" outer(vec2) -> Tensor See :func:`torch.outer`. -""") +""", +) -add_docstr_all('hypot', - r""" +add_docstr_all( + "hypot", + r""" hypot(other) -> Tensor See :func:`torch.hypot` -""") +""", +) -add_docstr_all('hypot_', - r""" +add_docstr_all( + "hypot_", + r""" hypot_(other) -> Tensor In-place version of :meth:`~Tensor.hypot` -""") +""", +) -add_docstr_all('i0', - r""" +add_docstr_all( + "i0", + r""" i0() -> Tensor See :func:`torch.i0` -""") +""", +) -add_docstr_all('i0_', - r""" +add_docstr_all( + "i0_", + r""" i0_() -> Tensor In-place version of :meth:`~Tensor.i0` -""") +""", +) -add_docstr_all('igamma', - r""" +add_docstr_all( + "igamma", + r""" igamma(other) -> Tensor See :func:`torch.igamma` -""") +""", +) -add_docstr_all('igamma_', - r""" +add_docstr_all( + "igamma_", + r""" igamma_(other) -> Tensor In-place version of :meth:`~Tensor.igamma` -""") +""", +) -add_docstr_all('igammac', - r""" +add_docstr_all( + "igammac", + r""" igammac(other) -> Tensor See :func:`torch.igammac` -""") +""", +) -add_docstr_all('igammac_', - r""" +add_docstr_all( + "igammac_", + r""" igammac_(other) -> Tensor In-place version of :meth:`~Tensor.igammac` -""") +""", +) -add_docstr_all('indices', - r""" +add_docstr_all( + "indices", + r""" indices() -> Tensor Return the indices tensor of a :ref:`sparse COO tensor `. @@ -1713,10 +2181,12 @@ def add_docstr_all(method, docstr): .. note:: This method can only be called on a coalesced sparse tensor. See :meth:`Tensor.coalesce` for details. -""") +""", +) -add_docstr_all('get_device', - r""" +add_docstr_all( + "get_device", + r""" get_device() -> Device ordinal (Integer) For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. @@ -1728,10 +2198,12 @@ def add_docstr_all(method, docstr): >>> x.get_device() 0 >>> x.cpu().get_device() # RuntimeError: get_device is not implemented for type torch.FloatTensor -""") +""", +) -add_docstr_all('values', - r""" +add_docstr_all( + "values", + r""" values() -> Tensor Return the values tensor of a :ref:`sparse COO tensor `. @@ -1744,74 +2216,100 @@ def add_docstr_all(method, docstr): .. note:: This method can only be called on a coalesced sparse tensor. See :meth:`Tensor.coalesce` for details. -""") +""", +) -add_docstr_all('gt', r""" +add_docstr_all( + "gt", + r""" gt(other) -> Tensor See :func:`torch.gt`. -""") +""", +) -add_docstr_all('gt_', r""" +add_docstr_all( + "gt_", + r""" gt_(other) -> Tensor In-place version of :meth:`~Tensor.gt`. -""") +""", +) -add_docstr_all('greater', r""" +add_docstr_all( + "greater", + r""" greater(other) -> Tensor See :func:`torch.greater`. -""") +""", +) -add_docstr_all('greater_', r""" +add_docstr_all( + "greater_", + r""" greater_(other) -> Tensor In-place version of :meth:`~Tensor.greater`. -""") +""", +) -add_docstr_all('has_names', - r""" +add_docstr_all( + "has_names", + r""" Is ``True`` if any of this tensor's dimensions are named. Otherwise, is ``False``. -""") +""", +) -add_docstr_all('hardshrink', - r""" +add_docstr_all( + "hardshrink", + r""" hardshrink(lambd=0.5) -> Tensor See :func:`torch.nn.functional.hardshrink` -""") +""", +) -add_docstr_all('heaviside', - r""" +add_docstr_all( + "heaviside", + r""" heaviside(values) -> Tensor See :func:`torch.heaviside` -""") +""", +) -add_docstr_all('heaviside_', - r""" +add_docstr_all( + "heaviside_", + r""" heaviside_(values) -> Tensor In-place version of :meth:`~Tensor.heaviside` -""") +""", +) -add_docstr_all('histc', - r""" +add_docstr_all( + "histc", + r""" histc(bins=100, min=0, max=0) -> Tensor See :func:`torch.histc` -""") +""", +) -add_docstr_all('histogram', - r""" +add_docstr_all( + "histogram", + r""" histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) See :func:`torch.histogram` -""") +""", +) -add_docstr_all('index_add_', - r""" +add_docstr_all( + "index_add_", + r""" index_add_(dim, index, source, *, alpha=1) -> Tensor Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` @@ -1858,10 +2356,14 @@ def add_docstr_all(method, docstr): [ 1., 1., 1.], [ 1., 1., 1.], [ 1., 1., 1.]]) -""".format(**reproducibility_notes)) - -add_docstr_all('index_copy_', - r""" +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "index_copy_", + r""" index_copy_(dim, index, tensor) -> Tensor Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting @@ -1894,10 +2396,12 @@ def add_docstr_all(method, docstr): [ 7., 8., 9.], [ 0., 0., 0.], [ 4., 5., 6.]]) -""") +""", +) -add_docstr_all('index_fill_', - r""" +add_docstr_all( + "index_fill_", + r""" index_fill_(dim, index, value) -> Tensor Fills the elements of the :attr:`self` tensor with value :attr:`value` by @@ -1915,10 +2419,12 @@ def add_docstr_all(method, docstr): tensor([[-1., 2., -1.], [-1., 5., -1.], [-1., 8., -1.]]) -""") +""", +) -add_docstr_all('index_put_', - r""" +add_docstr_all( + "index_put_", + r""" index_put_(indices, values, accumulate=False) -> Tensor Puts values from the tensor :attr:`values` into the tensor :attr:`self` using @@ -1934,17 +2440,21 @@ def add_docstr_all(method, docstr): indices (tuple of LongTensor): tensors used to index into `self`. values (Tensor): tensor of same dtype as `self`. accumulate (bool): whether to accumulate into self -""") +""", +) -add_docstr_all('index_put', - r""" +add_docstr_all( + "index_put", + r""" index_put(indices, values, accumulate=False) -> Tensor Out-place version of :meth:`~Tensor.index_put_`. -""") +""", +) -add_docstr_all('index_reduce_', - r""" +add_docstr_all( + "index_reduce_", + r""" index_reduce_(dim, index, source, reduce, *, include_self=True) -> Tensor Accumulate the elements of ``source`` into the :attr:`self` @@ -2008,17 +2518,23 @@ def add_docstr_all(method, docstr): [ 7., 8., 9.], [ 2., 2., 2.], [ 4., 5., 6.]]) -""".format(**reproducibility_notes)) - -add_docstr_all('index_select', - r""" +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "index_select", + r""" index_select(dim, index) -> Tensor See :func:`torch.index_select` -""") +""", +) -add_docstr_all('sparse_mask', - r""" +add_docstr_all( + "sparse_mask", + r""" sparse_mask(mask) -> Tensor Returns a new :ref:`sparse tensor ` with values from a @@ -2060,66 +2576,84 @@ def add_docstr_all(method, docstr): [[ 0.0793, 0.0036], [-0.2569, -0.1055]]]), size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo) -""") +""", +) -add_docstr_all('inverse', - r""" +add_docstr_all( + "inverse", + r""" inverse() -> Tensor See :func:`torch.inverse` -""") +""", +) -add_docstr_all('isnan', - r""" +add_docstr_all( + "isnan", + r""" isnan() -> Tensor See :func:`torch.isnan` -""") +""", +) -add_docstr_all('isinf', - r""" +add_docstr_all( + "isinf", + r""" isinf() -> Tensor See :func:`torch.isinf` -""") +""", +) -add_docstr_all('isposinf', - r""" +add_docstr_all( + "isposinf", + r""" isposinf() -> Tensor See :func:`torch.isposinf` -""") +""", +) -add_docstr_all('isneginf', - r""" +add_docstr_all( + "isneginf", + r""" isneginf() -> Tensor See :func:`torch.isneginf` -""") +""", +) -add_docstr_all('isfinite', - r""" +add_docstr_all( + "isfinite", + r""" isfinite() -> Tensor See :func:`torch.isfinite` -""") +""", +) -add_docstr_all('isclose', - r""" +add_docstr_all( + "isclose", + r""" isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor See :func:`torch.isclose` -""") +""", +) -add_docstr_all('isreal', - r""" +add_docstr_all( + "isreal", + r""" isreal() -> Tensor See :func:`torch.isreal` -""") +""", +) -add_docstr_all('is_coalesced', - r""" +add_docstr_all( + "is_coalesced", + r""" is_coalesced() -> bool Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor @@ -2129,10 +2663,12 @@ def add_docstr_all(method, docstr): Throws an error if :attr:`self` is not a sparse COO tensor. See :meth:`coalesce` and :ref:`uncoalesced tensors `. -""") +""", +) -add_docstr_all('is_contiguous', - r""" +add_docstr_all( + "is_contiguous", + r""" is_contiguous(memory_format=torch.contiguous_format) -> bool Returns True if :attr:`self` tensor is contiguous in memory in the order specified @@ -2141,64 +2677,83 @@ def add_docstr_all(method, docstr): Args: memory_format (:class:`torch.memory_format`, optional): Specifies memory allocation order. Default: ``torch.contiguous_format``. -""") +""", +) -add_docstr_all('is_pinned', - r""" +add_docstr_all( + "is_pinned", + r""" Returns true if this tensor resides in pinned memory. -""") +""", +) -add_docstr_all('is_floating_point', - r""" +add_docstr_all( + "is_floating_point", + r""" is_floating_point() -> bool Returns True if the data type of :attr:`self` is a floating point data type. -""") +""", +) -add_docstr_all('is_complex', - r""" +add_docstr_all( + "is_complex", + r""" is_complex() -> bool Returns True if the data type of :attr:`self` is a complex data type. -""") +""", +) -add_docstr_all('is_inference', - r""" +add_docstr_all( + "is_inference", + r""" is_inference() -> bool See :func:`torch.is_inference` -""") +""", +) -add_docstr_all('is_conj', - r""" +add_docstr_all( + "is_conj", + r""" is_conj() -> bool Returns True if the conjugate bit of :attr:`self` is set to true. -""") +""", +) -add_docstr_all('is_neg', - r""" +add_docstr_all( + "is_neg", + r""" is_neg() -> bool Returns True if the negative bit of :attr:`self` is set to true. -""") +""", +) -add_docstr_all('is_signed', - r""" +add_docstr_all( + "is_signed", + r""" is_signed() -> bool Returns True if the data type of :attr:`self` is a signed data type. -""") +""", +) -add_docstr_all('is_set_to', - r""" +add_docstr_all( + "is_set_to", + r""" is_set_to(tensor) -> bool Returns True if both tensors are pointing to the exact same memory (same storage, offset, size and stride). -""") +""", +) -add_docstr_all('item', r""" +add_docstr_all( + "item", + r""" item() -> number Returns the value of this tensor as a standard Python number. This only works @@ -2212,171 +2767,228 @@ def add_docstr_all(method, docstr): >>> x.item() 1.0 -""") +""", +) -add_docstr_all('kron', - r""" +add_docstr_all( + "kron", + r""" kron(other) -> Tensor See :func:`torch.kron` -""") +""", +) -add_docstr_all('kthvalue', - r""" +add_docstr_all( + "kthvalue", + r""" kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) See :func:`torch.kthvalue` -""") +""", +) -add_docstr_all('ldexp', - r""" +add_docstr_all( + "ldexp", + r""" ldexp(other) -> Tensor See :func:`torch.ldexp` -""") +""", +) -add_docstr_all('ldexp_', - r""" +add_docstr_all( + "ldexp_", + r""" ldexp_(other) -> Tensor In-place version of :meth:`~Tensor.ldexp` -""") +""", +) -add_docstr_all('lcm', - r""" +add_docstr_all( + "lcm", + r""" lcm(other) -> Tensor See :func:`torch.lcm` -""") +""", +) -add_docstr_all('lcm_', - r""" +add_docstr_all( + "lcm_", + r""" lcm_(other) -> Tensor In-place version of :meth:`~Tensor.lcm` -""") +""", +) -add_docstr_all('le', r""" +add_docstr_all( + "le", + r""" le(other) -> Tensor See :func:`torch.le`. -""") +""", +) -add_docstr_all('le_', r""" +add_docstr_all( + "le_", + r""" le_(other) -> Tensor In-place version of :meth:`~Tensor.le`. -""") +""", +) -add_docstr_all('less_equal', r""" +add_docstr_all( + "less_equal", + r""" less_equal(other) -> Tensor See :func:`torch.less_equal`. -""") +""", +) -add_docstr_all('less_equal_', r""" +add_docstr_all( + "less_equal_", + r""" less_equal_(other) -> Tensor In-place version of :meth:`~Tensor.less_equal`. -""") +""", +) -add_docstr_all('lerp', - r""" +add_docstr_all( + "lerp", + r""" lerp(end, weight) -> Tensor See :func:`torch.lerp` -""") +""", +) -add_docstr_all('lerp_', - r""" +add_docstr_all( + "lerp_", + r""" lerp_(end, weight) -> Tensor In-place version of :meth:`~Tensor.lerp` -""") +""", +) -add_docstr_all('lgamma', - r""" +add_docstr_all( + "lgamma", + r""" lgamma() -> Tensor See :func:`torch.lgamma` -""") +""", +) -add_docstr_all('lgamma_', r""" +add_docstr_all( + "lgamma_", + r""" lgamma_() -> Tensor In-place version of :meth:`~Tensor.lgamma` -""") +""", +) -add_docstr_all('log', - r""" +add_docstr_all( + "log", + r""" log() -> Tensor See :func:`torch.log` -""") +""", +) -add_docstr_all('log_', r""" +add_docstr_all( + "log_", + r""" log_() -> Tensor In-place version of :meth:`~Tensor.log` -""") +""", +) -add_docstr_all('log10', - r""" +add_docstr_all( + "log10", + r""" log10() -> Tensor See :func:`torch.log10` -""") +""", +) -add_docstr_all('log10_', - r""" +add_docstr_all( + "log10_", + r""" log10_() -> Tensor In-place version of :meth:`~Tensor.log10` -""") +""", +) -add_docstr_all('log1p', - r""" +add_docstr_all( + "log1p", + r""" log1p() -> Tensor See :func:`torch.log1p` -""") +""", +) -add_docstr_all('log1p_', - r""" +add_docstr_all( + "log1p_", + r""" log1p_() -> Tensor In-place version of :meth:`~Tensor.log1p` -""") +""", +) -add_docstr_all('log2', - r""" +add_docstr_all( + "log2", + r""" log2() -> Tensor See :func:`torch.log2` -""") +""", +) -add_docstr_all('log2_', - r""" +add_docstr_all( + "log2_", + r""" log2_() -> Tensor In-place version of :meth:`~Tensor.log2` -""") +""", +) -add_docstr_all('logaddexp', - r""" +add_docstr_all( + "logaddexp", + r""" logaddexp(other) -> Tensor See :func:`torch.logaddexp` -""") +""", +) -add_docstr_all('logaddexp2', - r""" +add_docstr_all( + "logaddexp2", + r""" logaddexp2(other) -> Tensor See :func:`torch.logaddexp2` -""") +""", +) -add_docstr_all('log_normal_', r""" +add_docstr_all( + "log_normal_", + r""" log_normal_(mean=1, std=2, *, generator=None) Fills :attr:`self` tensor with numbers samples from the log-normal distribution @@ -2388,55 +3000,75 @@ def add_docstr_all(method, docstr): .. math:: f(x) = \dfrac{1}{x \sigma \sqrt{2\pi}}\ e^{-\frac{(\ln x - \mu)^2}{2\sigma^2}} -""") +""", +) -add_docstr_all('logsumexp', - r""" +add_docstr_all( + "logsumexp", + r""" logsumexp(dim, keepdim=False) -> Tensor See :func:`torch.logsumexp` -""") +""", +) -add_docstr_all('lstsq', - r""" +add_docstr_all( + "lstsq", + r""" lstsq(A) -> (Tensor, Tensor) See :func:`torch.lstsq` -""") +""", +) -add_docstr_all('lt', r""" +add_docstr_all( + "lt", + r""" lt(other) -> Tensor See :func:`torch.lt`. -""") +""", +) -add_docstr_all('lt_', r""" +add_docstr_all( + "lt_", + r""" lt_(other) -> Tensor In-place version of :meth:`~Tensor.lt`. -""") +""", +) -add_docstr_all('less', r""" +add_docstr_all( + "less", + r""" lt(other) -> Tensor See :func:`torch.less`. -""") +""", +) -add_docstr_all('less_', r""" +add_docstr_all( + "less_", + r""" less_(other) -> Tensor In-place version of :meth:`~Tensor.less`. -""") +""", +) -add_docstr_all('lu_solve', - r""" +add_docstr_all( + "lu_solve", + r""" lu_solve(LU_data, LU_pivots) -> Tensor See :func:`torch.lu_solve` -""") +""", +) -add_docstr_all('map_', - r""" +add_docstr_all( + "map_", + r""" map_(tensor, callable) Applies :attr:`callable` for each element in :attr:`self` tensor and the given @@ -2446,10 +3078,12 @@ def add_docstr_all(method, docstr): The :attr:`callable` should have the signature:: def callable(a, b) -> number -""") +""", +) -add_docstr_all('masked_scatter_', - r""" +add_docstr_all( + "masked_scatter_", + r""" masked_scatter_(mask, source) Copies elements from :attr:`source` into :attr:`self` tensor at positions where @@ -2466,10 +3100,12 @@ def callable(a, b) -> number The :attr:`mask` operates on the :attr:`self` tensor, not on the given :attr:`source` tensor. -""") +""", +) -add_docstr_all('masked_fill_', - r""" +add_docstr_all( + "masked_fill_", + r""" masked_fill_(mask, value) Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is @@ -2480,219 +3116,293 @@ def callable(a, b) -> number Args: mask (BoolTensor): the boolean mask value (float): the value to fill in with -""") +""", +) -add_docstr_all('masked_select', - r""" +add_docstr_all( + "masked_select", + r""" masked_select(mask) -> Tensor See :func:`torch.masked_select` -""") +""", +) -add_docstr_all('matrix_power', r""" +add_docstr_all( + "matrix_power", + r""" matrix_power(n) -> Tensor .. note:: :meth:`~Tensor.matrix_power` is deprecated, use :func:`torch.linalg.matrix_power` instead. Alias for :func:`torch.linalg.matrix_power` -""") +""", +) -add_docstr_all('matrix_exp', - r""" +add_docstr_all( + "matrix_exp", + r""" matrix_exp() -> Tensor See :func:`torch.matrix_exp` -""") +""", +) -add_docstr_all('max', - r""" +add_docstr_all( + "max", + r""" max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) See :func:`torch.max` -""") +""", +) -add_docstr_all('amax', - r""" +add_docstr_all( + "amax", + r""" amax(dim=None, keepdim=False) -> Tensor See :func:`torch.amax` -""") +""", +) -add_docstr_all('maximum', - r""" +add_docstr_all( + "maximum", + r""" maximum(other) -> Tensor See :func:`torch.maximum` -""") +""", +) -add_docstr_all('fmax', - r""" +add_docstr_all( + "fmax", + r""" fmax(other) -> Tensor See :func:`torch.fmax` -""") +""", +) -add_docstr_all('argmax', - r""" +add_docstr_all( + "argmax", + r""" argmax(dim=None, keepdim=False) -> LongTensor See :func:`torch.argmax` -""") +""", +) -add_docstr_all('argwhere', - r""" +add_docstr_all( + "argwhere", + r""" argwhere() -> Tensor See :func:`torch.argwhere` -""") +""", +) -add_docstr_all('mean', r""" +add_docstr_all( + "mean", + r""" mean(dim=None, keepdim=False, *, dtype=None) -> Tensor See :func:`torch.mean` -""") +""", +) -add_docstr_all('nanmean', r""" +add_docstr_all( + "nanmean", + r""" nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor See :func:`torch.nanmean` -""") +""", +) -add_docstr_all('median', - r""" +add_docstr_all( + "median", + r""" median(dim=None, keepdim=False) -> (Tensor, LongTensor) See :func:`torch.median` -""") +""", +) -add_docstr_all('nanmedian', - r""" +add_docstr_all( + "nanmedian", + r""" nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) See :func:`torch.nanmedian` -""") +""", +) -add_docstr_all('min', - r""" +add_docstr_all( + "min", + r""" min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) See :func:`torch.min` -""") +""", +) -add_docstr_all('amin', - r""" +add_docstr_all( + "amin", + r""" amin(dim=None, keepdim=False) -> Tensor See :func:`torch.amin` -""") +""", +) -add_docstr_all('minimum', - r""" +add_docstr_all( + "minimum", + r""" minimum(other) -> Tensor See :func:`torch.minimum` -""") +""", +) -add_docstr_all('aminmax', r""" +add_docstr_all( + "aminmax", + r""" aminmax(*, dim=None, keepdim=False) -> (Tensor min, Tensor max) See :func:`torch.aminmax` -""") +""", +) -add_docstr_all('fmin', - r""" +add_docstr_all( + "fmin", + r""" fmin(other) -> Tensor See :func:`torch.fmin` -""") +""", +) -add_docstr_all('argmin', - r""" +add_docstr_all( + "argmin", + r""" argmin(dim=None, keepdim=False) -> LongTensor See :func:`torch.argmin` -""") +""", +) -add_docstr_all('mm', - r""" +add_docstr_all( + "mm", + r""" mm(mat2) -> Tensor See :func:`torch.mm` -""") +""", +) -add_docstr_all('mode', - r""" +add_docstr_all( + "mode", + r""" mode(dim=None, keepdim=False) -> (Tensor, LongTensor) See :func:`torch.mode` -""") +""", +) -add_docstr_all('movedim', r""" +add_docstr_all( + "movedim", + r""" movedim(source, destination) -> Tensor See :func:`torch.movedim` -""") +""", +) -add_docstr_all('moveaxis', r""" +add_docstr_all( + "moveaxis", + r""" moveaxis(source, destination) -> Tensor See :func:`torch.moveaxis` -""") +""", +) -add_docstr_all('mul', r""" +add_docstr_all( + "mul", + r""" mul(value) -> Tensor See :func:`torch.mul`. -""") +""", +) -add_docstr_all('mul_', r""" +add_docstr_all( + "mul_", + r""" mul_(value) -> Tensor In-place version of :meth:`~Tensor.mul`. -""") +""", +) -add_docstr_all('multiply', r""" +add_docstr_all( + "multiply", + r""" multiply(value) -> Tensor See :func:`torch.multiply`. -""") +""", +) -add_docstr_all('multiply_', r""" +add_docstr_all( + "multiply_", + r""" multiply_(value) -> Tensor In-place version of :meth:`~Tensor.multiply`. -""") +""", +) -add_docstr_all('multinomial', - r""" +add_docstr_all( + "multinomial", + r""" multinomial(num_samples, replacement=False, *, generator=None) -> Tensor See :func:`torch.multinomial` -""") +""", +) -add_docstr_all('mv', - r""" +add_docstr_all( + "mv", + r""" mv(vec) -> Tensor See :func:`torch.mv` -""") +""", +) -add_docstr_all('mvlgamma', - r""" +add_docstr_all( + "mvlgamma", + r""" mvlgamma(p) -> Tensor See :func:`torch.mvlgamma` -""") +""", +) -add_docstr_all('mvlgamma_', - r""" +add_docstr_all( + "mvlgamma_", + r""" mvlgamma_(p) -> Tensor In-place version of :meth:`~Tensor.mvlgamma` -""") +""", +) -add_docstr_all('narrow', - r""" +add_docstr_all( + "narrow", + r""" narrow(dimension, start, length) -> Tensor See :func:`torch.narrow` @@ -2707,10 +3417,12 @@ def callable(a, b) -> number tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) -""") +""", +) -add_docstr_all('narrow_copy', - r""" +add_docstr_all( + "narrow_copy", + r""" narrow_copy(dimension, start, length) -> Tensor Same as :meth:`Tensor.narrow` except returning a copy rather @@ -2718,129 +3430,173 @@ def callable(a, b) -> number do not have a shared-storage narrow method. Calling ``narrow_copy`` with ``dimemsion > self.sparse_dim()`` will return a copy with the relevant dense dimension narrowed, and ``self.shape`` updated accordingly. -""") +""", +) -add_docstr_all('ndimension', - r""" +add_docstr_all( + "ndimension", + r""" ndimension() -> int Alias for :meth:`~Tensor.dim()` -""") +""", +) -add_docstr_all('nan_to_num', r""" +add_docstr_all( + "nan_to_num", + r""" nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor See :func:`torch.nan_to_num`. -""") +""", +) -add_docstr_all('nan_to_num_', r""" +add_docstr_all( + "nan_to_num_", + r""" nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor In-place version of :meth:`~Tensor.nan_to_num`. -""") +""", +) -add_docstr_all('ne', r""" +add_docstr_all( + "ne", + r""" ne(other) -> Tensor See :func:`torch.ne`. -""") +""", +) -add_docstr_all('ne_', r""" +add_docstr_all( + "ne_", + r""" ne_(other) -> Tensor In-place version of :meth:`~Tensor.ne`. -""") +""", +) -add_docstr_all('not_equal', r""" +add_docstr_all( + "not_equal", + r""" not_equal(other) -> Tensor See :func:`torch.not_equal`. -""") +""", +) -add_docstr_all('not_equal_', r""" +add_docstr_all( + "not_equal_", + r""" not_equal_(other) -> Tensor In-place version of :meth:`~Tensor.not_equal`. -""") +""", +) -add_docstr_all('neg', - r""" +add_docstr_all( + "neg", + r""" neg() -> Tensor See :func:`torch.neg` -""") +""", +) -add_docstr_all('negative', - r""" +add_docstr_all( + "negative", + r""" negative() -> Tensor See :func:`torch.negative` -""") +""", +) -add_docstr_all('neg_', - r""" +add_docstr_all( + "neg_", + r""" neg_() -> Tensor In-place version of :meth:`~Tensor.neg` -""") +""", +) -add_docstr_all('negative_', - r""" +add_docstr_all( + "negative_", + r""" negative_() -> Tensor In-place version of :meth:`~Tensor.negative` -""") +""", +) -add_docstr_all('nelement', - r""" +add_docstr_all( + "nelement", + r""" nelement() -> int Alias for :meth:`~Tensor.numel` -""") +""", +) -add_docstr_all('nextafter', - r""" +add_docstr_all( + "nextafter", + r""" nextafter(other) -> Tensor See :func:`torch.nextafter` -""") +""", +) -add_docstr_all('nextafter_', - r""" +add_docstr_all( + "nextafter_", + r""" nextafter_(other) -> Tensor In-place version of :meth:`~Tensor.nextafter` -""") +""", +) -add_docstr_all('nonzero', - r""" +add_docstr_all( + "nonzero", + r""" nonzero() -> LongTensor See :func:`torch.nonzero` -""") +""", +) -add_docstr_all('norm', - r""" +add_docstr_all( + "norm", + r""" norm(p=2, dim=None, keepdim=False) -> Tensor See :func:`torch.norm` -""") +""", +) -add_docstr_all('normal_', - r""" +add_docstr_all( + "normal_", + r""" normal_(mean=0, std=1, *, generator=None) -> Tensor Fills :attr:`self` tensor with elements samples from the normal distribution parameterized by :attr:`mean` and :attr:`std`. -""") +""", +) -add_docstr_all('numel', - r""" +add_docstr_all( + "numel", + r""" numel() -> int See :func:`torch.numel` -""") +""", +) -add_docstr_all('numpy', - r""" +add_docstr_all( + "numpy", + r""" numpy(*, force=False) -> numpy.ndarray Returns the tensor as a NumPy :class:`ndarray`. @@ -2861,87 +3617,111 @@ def callable(a, b) -> number Args: force (bool): if ``True``, the ndarray may be a copy of the tensor instead of always sharing memory, defaults to ``False``. -""") +""", +) -add_docstr_all('orgqr', - r""" +add_docstr_all( + "orgqr", + r""" orgqr(input2) -> Tensor See :func:`torch.orgqr` -""") +""", +) -add_docstr_all('ormqr', - r""" +add_docstr_all( + "ormqr", + r""" ormqr(input2, input3, left=True, transpose=False) -> Tensor See :func:`torch.ormqr` -""") +""", +) -add_docstr_all('permute', - r""" +add_docstr_all( + "permute", + r""" permute(*dims) -> Tensor See :func:`torch.permute` -""") +""", +) -add_docstr_all('polygamma', - r""" +add_docstr_all( + "polygamma", + r""" polygamma(n) -> Tensor See :func:`torch.polygamma` -""") +""", +) -add_docstr_all('polygamma_', - r""" +add_docstr_all( + "polygamma_", + r""" polygamma_(n) -> Tensor In-place version of :meth:`~Tensor.polygamma` -""") +""", +) -add_docstr_all('positive', - r""" +add_docstr_all( + "positive", + r""" positive() -> Tensor See :func:`torch.positive` -""") +""", +) -add_docstr_all('pow', - r""" +add_docstr_all( + "pow", + r""" pow(exponent) -> Tensor See :func:`torch.pow` -""") +""", +) -add_docstr_all('pow_', - r""" +add_docstr_all( + "pow_", + r""" pow_(exponent) -> Tensor In-place version of :meth:`~Tensor.pow` -""") +""", +) -add_docstr_all('float_power', - r""" +add_docstr_all( + "float_power", + r""" float_power(exponent) -> Tensor See :func:`torch.float_power` -""") +""", +) -add_docstr_all('float_power_', - r""" +add_docstr_all( + "float_power_", + r""" float_power_(exponent) -> Tensor In-place version of :meth:`~Tensor.float_power` -""") +""", +) -add_docstr_all('prod', - r""" +add_docstr_all( + "prod", + r""" prod(dim=None, keepdim=False, dtype=None) -> Tensor See :func:`torch.prod` -""") +""", +) -add_docstr_all('put_', - r""" +add_docstr_all( + "put_", + r""" put_(index, source, accumulate=False) -> Tensor Copies the elements from :attr:`source` into the positions specified by @@ -2967,88 +3747,112 @@ def callable(a, b) -> number >>> src.put_(torch.tensor([1, 3]), torch.tensor([9, 10])) tensor([[ 4, 9, 5], [ 10, 7, 8]]) -""") +""", +) -add_docstr_all('put', - r""" +add_docstr_all( + "put", + r""" put(input, index, source, accumulate=False) -> Tensor Out-of-place version of :meth:`torch.Tensor.put_`. `input` corresponds to `self` in :meth:`torch.Tensor.put_`. -""") +""", +) -add_docstr_all('qr', - r""" +add_docstr_all( + "qr", + r""" qr(some=True) -> (Tensor, Tensor) See :func:`torch.qr` -""") +""", +) -add_docstr_all('qscheme', - r""" +add_docstr_all( + "qscheme", + r""" qscheme() -> torch.qscheme Returns the quantization scheme of a given QTensor. -""") +""", +) -add_docstr_all('quantile', r""" +add_docstr_all( + "quantile", + r""" quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor See :func:`torch.quantile` -""") +""", +) -add_docstr_all('nanquantile', r""" +add_docstr_all( + "nanquantile", + r""" nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor See :func:`torch.nanquantile` -""") +""", +) -add_docstr_all('q_scale', - r""" +add_docstr_all( + "q_scale", + r""" q_scale() -> float Given a Tensor quantized by linear(affine) quantization, returns the scale of the underlying quantizer(). -""") +""", +) -add_docstr_all('q_zero_point', - r""" +add_docstr_all( + "q_zero_point", + r""" q_zero_point() -> int Given a Tensor quantized by linear(affine) quantization, returns the zero_point of the underlying quantizer(). -""") +""", +) -add_docstr_all('q_per_channel_scales', - r""" +add_docstr_all( + "q_per_channel_scales", + r""" q_per_channel_scales() -> Tensor Given a Tensor quantized by linear (affine) per-channel quantization, returns a Tensor of scales of the underlying quantizer. It has the number of elements that matches the corresponding dimensions (from q_per_channel_axis) of the tensor. -""") +""", +) -add_docstr_all('q_per_channel_zero_points', - r""" +add_docstr_all( + "q_per_channel_zero_points", + r""" q_per_channel_zero_points() -> Tensor Given a Tensor quantized by linear (affine) per-channel quantization, returns a tensor of zero_points of the underlying quantizer. It has the number of elements that matches the corresponding dimensions (from q_per_channel_axis) of the tensor. -""") +""", +) -add_docstr_all('q_per_channel_axis', - r""" +add_docstr_all( + "q_per_channel_axis", + r""" q_per_channel_axis() -> int Given a Tensor quantized by linear (affine) per-channel quantization, returns the index of dimension on which per-channel quantization is applied. -""") +""", +) -add_docstr_all('random_', - r""" +add_docstr_all( + "random_", + r""" random_(from=0, to=None, *, generator=None) -> Tensor Fills :attr:`self` tensor with numbers sampled from the discrete uniform @@ -3057,59 +3861,75 @@ def callable(a, b) -> number types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` will be uniform in ``[0, 2^53]``. -""") +""", +) -add_docstr_all('rad2deg', - r""" +add_docstr_all( + "rad2deg", + r""" rad2deg() -> Tensor See :func:`torch.rad2deg` -""") +""", +) -add_docstr_all('rad2deg_', - r""" +add_docstr_all( + "rad2deg_", + r""" rad2deg_() -> Tensor In-place version of :meth:`~Tensor.rad2deg` -""") +""", +) -add_docstr_all('deg2rad', - r""" +add_docstr_all( + "deg2rad", + r""" deg2rad() -> Tensor See :func:`torch.deg2rad` -""") +""", +) -add_docstr_all('deg2rad_', - r""" +add_docstr_all( + "deg2rad_", + r""" deg2rad_() -> Tensor In-place version of :meth:`~Tensor.deg2rad` -""") +""", +) -add_docstr_all('ravel', - r""" +add_docstr_all( + "ravel", + r""" ravel() -> Tensor see :func:`torch.ravel` -""") +""", +) -add_docstr_all('reciprocal', - r""" +add_docstr_all( + "reciprocal", + r""" reciprocal() -> Tensor See :func:`torch.reciprocal` -""") +""", +) -add_docstr_all('reciprocal_', - r""" +add_docstr_all( + "reciprocal_", + r""" reciprocal_() -> Tensor In-place version of :meth:`~Tensor.reciprocal` -""") +""", +) -add_docstr_all('record_stream', - r""" +add_docstr_all( + "record_stream", + r""" record_stream(stream) Ensures that the tensor memory is not reused for another tensor until all @@ -3124,38 +3944,48 @@ def callable(a, b) -> number unexpectedly. Calling this method lets the allocator know which streams have used the tensor. -""") +""", +) -add_docstr_all('remainder', - r""" +add_docstr_all( + "remainder", + r""" remainder(divisor) -> Tensor See :func:`torch.remainder` -""") +""", +) -add_docstr_all('remainder_', - r""" +add_docstr_all( + "remainder_", + r""" remainder_(divisor) -> Tensor In-place version of :meth:`~Tensor.remainder` -""") +""", +) -add_docstr_all('renorm', - r""" +add_docstr_all( + "renorm", + r""" renorm(p, dim, maxnorm) -> Tensor See :func:`torch.renorm` -""") +""", +) -add_docstr_all('renorm_', - r""" +add_docstr_all( + "renorm_", + r""" renorm_(p, dim, maxnorm) -> Tensor In-place version of :meth:`~Tensor.renorm` -""") +""", +) -add_docstr_all('repeat', - r""" +add_docstr_all( + "repeat", + r""" repeat(*sizes) -> Tensor Repeats this tensor along the specified dimensions. @@ -3184,17 +4014,21 @@ def callable(a, b) -> number [ 1, 2, 3, 1, 2, 3]]) >>> x.repeat(4, 2, 1).size() torch.Size([4, 2, 3]) -""") +""", +) -add_docstr_all('repeat_interleave', - r""" +add_docstr_all( + "repeat_interleave", + r""" repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor See :func:`torch.repeat_interleave`. -""") +""", +) -add_docstr_all('requires_grad_', - r""" +add_docstr_all( + "requires_grad_", + r""" requires_grad_(requires_grad=True) -> Tensor Change if autograd should record operations on this tensor: sets this tensor's @@ -3227,10 +4061,12 @@ def callable(a, b) -> number >>> weights.grad tensor([-1.1007, 0.9853, -4.2316, -1.6606]) -""") +""", +) -add_docstr_all('reshape', - r""" +add_docstr_all( + "reshape", + r""" reshape(*shape) -> Tensor Returns a tensor with the same data and number of elements as :attr:`self` @@ -3243,10 +4079,12 @@ def callable(a, b) -> number Args: shape (tuple of ints or int...): the desired shape -""") +""", +) -add_docstr_all('reshape_as', - r""" +add_docstr_all( + "reshape_as", + r""" reshape_as(other) -> Tensor Returns this tensor as the same shape as :attr:`other`. @@ -3259,10 +4097,12 @@ def callable(a, b) -> number Args: other (:class:`torch.Tensor`): The result tensor has the same shape as :attr:`other`. -""") +""", +) -add_docstr_all('resize_', - r""" +add_docstr_all( + "resize_", + r""" resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor Resizes :attr:`self` tensor to the specified size. If the number of elements is @@ -3292,10 +4132,12 @@ def callable(a, b) -> number >>> x.resize_(2, 2) tensor([[ 1, 2], [ 3, 4]]) -""") +""", +) -add_docstr_all('resize_as_', - r""" +add_docstr_all( + "resize_as_", + r""" resize_as_(tensor, memory_format=torch.contiguous_format) -> Tensor Resizes the :attr:`self` tensor to be the same size as the specified @@ -3306,45 +4148,57 @@ def callable(a, b) -> number Tensor. Default: ``torch.contiguous_format``. Note that memory format of :attr:`self` is going to be unaffected if ``self.size()`` matches ``tensor.size()``. -""") +""", +) -add_docstr_all('rot90', - r""" +add_docstr_all( + "rot90", + r""" rot90(k, dims) -> Tensor See :func:`torch.rot90` -""") +""", +) -add_docstr_all('round', - r""" +add_docstr_all( + "round", + r""" round(decimals=0) -> Tensor See :func:`torch.round` -""") +""", +) -add_docstr_all('round_', - r""" +add_docstr_all( + "round_", + r""" round_(decimals=0) -> Tensor In-place version of :meth:`~Tensor.round` -""") +""", +) -add_docstr_all('rsqrt', - r""" +add_docstr_all( + "rsqrt", + r""" rsqrt() -> Tensor See :func:`torch.rsqrt` -""") +""", +) -add_docstr_all('rsqrt_', - r""" +add_docstr_all( + "rsqrt_", + r""" rsqrt_() -> Tensor In-place version of :meth:`~Tensor.rsqrt` -""") +""", +) -add_docstr_all('scatter_', - r""" +add_docstr_all( + "scatter_", + r""" scatter_(dim, index, src, reduce=None) -> Tensor Writes all values from the tensor :attr:`src` into :attr:`self` at the indices @@ -3433,13 +4287,15 @@ def callable(a, b) -> number tensor([[2.0000, 2.0000, 3.2300, 2.0000], [2.0000, 2.0000, 2.0000, 3.2300]]) -""") +""", +) -add_docstr_all('scatter_add_', - r""" +add_docstr_all( + "scatter_add_", + r""" scatter_add_(dim, index, src) -> Tensor -Adds all values from the tensor :attr:`other` into :attr:`self` at the indices +Adds all values from the tensor :attr:`src` into :attr:`self` at the indices specified in the :attr:`index` tensor in a similar fashion as :meth:`~torch.Tensor.scatter_`. For each value in :attr:`src`, it is added to an index in :attr:`self` which is specified by its index in :attr:`src` @@ -3485,9 +4341,14 @@ def callable(a, b) -> number [0., 2., 0., 0., 0.], [0., 0., 2., 1., 1.]]) -""".format(**reproducibility_notes)) +""".format( + **reproducibility_notes + ), +) -add_docstr_all('scatter_reduce_', r""" +add_docstr_all( + "scatter_reduce_", + r""" scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor Reduces all values from the :attr:`src` tensor to the indices specified in @@ -3548,31 +4409,41 @@ def callable(a, b) -> number tensor([3., 6., 5., 2.]) -""".format(**reproducibility_notes)) +""".format( + **reproducibility_notes + ), +) -add_docstr_all('select', - r""" +add_docstr_all( + "select", + r""" select(dim, index) -> Tensor See :func:`torch.select` -""") +""", +) -add_docstr_all('select_scatter', - r""" +add_docstr_all( + "select_scatter", + r""" select_scatter(src, dim, index) -> Tensor See :func:`torch.select_scatter` -""") +""", +) -add_docstr_all('slice_scatter', - r""" +add_docstr_all( + "slice_scatter", + r""" slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor See :func:`torch.slice_scatter` -""") +""", +) -add_docstr_all('set_', - r""" +add_docstr_all( + "set_", + r""" set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, @@ -3588,115 +4459,147 @@ def callable(a, b) -> number storage_offset (int, optional): the offset in the storage size (torch.Size, optional): the desired size. Defaults to the size of the source. stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. -""") +""", +) -add_docstr_all('sigmoid', - r""" +add_docstr_all( + "sigmoid", + r""" sigmoid() -> Tensor See :func:`torch.sigmoid` -""") +""", +) -add_docstr_all('sigmoid_', - r""" +add_docstr_all( + "sigmoid_", + r""" sigmoid_() -> Tensor In-place version of :meth:`~Tensor.sigmoid` -""") +""", +) -add_docstr_all('logit', - r""" +add_docstr_all( + "logit", + r""" logit() -> Tensor See :func:`torch.logit` -""") +""", +) -add_docstr_all('logit_', - r""" +add_docstr_all( + "logit_", + r""" logit_() -> Tensor In-place version of :meth:`~Tensor.logit` -""") +""", +) -add_docstr_all('sign', - r""" +add_docstr_all( + "sign", + r""" sign() -> Tensor See :func:`torch.sign` -""") +""", +) -add_docstr_all('sign_', - r""" +add_docstr_all( + "sign_", + r""" sign_() -> Tensor In-place version of :meth:`~Tensor.sign` -""") +""", +) -add_docstr_all('signbit', - r""" +add_docstr_all( + "signbit", + r""" signbit() -> Tensor See :func:`torch.signbit` -""") +""", +) -add_docstr_all('sgn', - r""" +add_docstr_all( + "sgn", + r""" sgn() -> Tensor See :func:`torch.sgn` -""") +""", +) -add_docstr_all('sgn_', - r""" +add_docstr_all( + "sgn_", + r""" sgn_() -> Tensor In-place version of :meth:`~Tensor.sgn` -""") +""", +) -add_docstr_all('sin', - r""" +add_docstr_all( + "sin", + r""" sin() -> Tensor See :func:`torch.sin` -""") +""", +) -add_docstr_all('sin_', - r""" +add_docstr_all( + "sin_", + r""" sin_() -> Tensor In-place version of :meth:`~Tensor.sin` -""") +""", +) -add_docstr_all('sinc', - r""" +add_docstr_all( + "sinc", + r""" sinc() -> Tensor See :func:`torch.sinc` -""") +""", +) -add_docstr_all('sinc_', - r""" +add_docstr_all( + "sinc_", + r""" sinc_() -> Tensor In-place version of :meth:`~Tensor.sinc` -""") +""", +) -add_docstr_all('sinh', - r""" +add_docstr_all( + "sinh", + r""" sinh() -> Tensor See :func:`torch.sinh` -""") +""", +) -add_docstr_all('sinh_', - r""" +add_docstr_all( + "sinh_", + r""" sinh_() -> Tensor In-place version of :meth:`~Tensor.sinh` -""") +""", +) -add_docstr_all('size', - r""" +add_docstr_all( + "size", + r""" size(dim=None) -> torch.Size or int Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, @@ -3714,31 +4617,39 @@ def callable(a, b) -> number >>> t.size(dim=1) 4 -""") +""", +) -add_docstr_all('sort', - r""" +add_docstr_all( + "sort", + r""" sort(dim=-1, descending=False) -> (Tensor, LongTensor) See :func:`torch.sort` -""") +""", +) -add_docstr_all('msort', - r""" +add_docstr_all( + "msort", + r""" msort() -> Tensor See :func:`torch.msort` -""") +""", +) -add_docstr_all('argsort', - r""" +add_docstr_all( + "argsort", + r""" argsort(dim=-1, descending=False) -> LongTensor See :func:`torch.argsort` -""") +""", +) -add_docstr_all('sparse_dim', - r""" +add_docstr_all( + "sparse_dim", + r""" sparse_dim() -> int Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. @@ -3747,10 +4658,12 @@ def callable(a, b) -> number Throws an error if :attr:`self` is not a sparse tensor. See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. -""") +""", +) -add_docstr_all('sparse_resize_', - r""" +add_docstr_all( + "sparse_resize_", + r""" sparse_resize_(size, sparse_dim, dense_dim) -> Tensor Resizes :attr:`self` :ref:`sparse tensor ` to the desired @@ -3777,10 +4690,12 @@ def callable(a, b) -> number original size. sparse_dim (int): the number of sparse dimensions dense_dim (int): the number of dense dimensions -""") +""", +) -add_docstr_all('sparse_resize_and_clear_', - r""" +add_docstr_all( + "sparse_resize_and_clear_", + r""" sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor Removes all specified elements from a :ref:`sparse tensor @@ -3794,52 +4709,66 @@ def callable(a, b) -> number size (torch.Size): the desired size. sparse_dim (int): the number of sparse dimensions dense_dim (int): the number of dense dimensions -""") +""", +) -add_docstr_all('sqrt', - r""" +add_docstr_all( + "sqrt", + r""" sqrt() -> Tensor See :func:`torch.sqrt` -""") +""", +) -add_docstr_all('sqrt_', - r""" +add_docstr_all( + "sqrt_", + r""" sqrt_() -> Tensor In-place version of :meth:`~Tensor.sqrt` -""") +""", +) -add_docstr_all('square', - r""" +add_docstr_all( + "square", + r""" square() -> Tensor See :func:`torch.square` -""") +""", +) -add_docstr_all('square_', - r""" +add_docstr_all( + "square_", + r""" square_() -> Tensor In-place version of :meth:`~Tensor.square` -""") +""", +) -add_docstr_all('squeeze', - r""" +add_docstr_all( + "squeeze", + r""" squeeze(dim=None) -> Tensor See :func:`torch.squeeze` -""") +""", +) -add_docstr_all('squeeze_', - r""" +add_docstr_all( + "squeeze_", + r""" squeeze_(dim=None) -> Tensor In-place version of :meth:`~Tensor.squeeze` -""") +""", +) -add_docstr_all('std', - r""" +add_docstr_all( + "std", + r""" std(dim, unbiased=True, keepdim=False) -> Tensor See :func:`torch.std` @@ -3848,10 +4777,12 @@ def callable(a, b) -> number :noindex: See :func:`torch.std` -""") +""", +) -add_docstr_all('storage_offset', - r""" +add_docstr_all( + "storage_offset", + r""" storage_offset() -> int Returns :attr:`self` tensor's offset in the underlying storage in terms of @@ -3865,10 +4796,12 @@ def callable(a, b) -> number >>> x[3:].storage_offset() 3 -""") +""", +) -add_docstr_all('stride', - r""" +add_docstr_all( + "stride", + r""" stride(dim) -> tuple or int Returns the stride of :attr:`self` tensor. @@ -3891,109 +4824,147 @@ def callable(a, b) -> number >>> x.stride(-1) 1 -""") +""", +) -add_docstr_all('sub', r""" +add_docstr_all( + "sub", + r""" sub(other, *, alpha=1) -> Tensor See :func:`torch.sub`. -""") +""", +) -add_docstr_all('sub_', - r""" +add_docstr_all( + "sub_", + r""" sub_(other, *, alpha=1) -> Tensor In-place version of :meth:`~Tensor.sub` -""") +""", +) -add_docstr_all('subtract', r""" +add_docstr_all( + "subtract", + r""" subtract(other, *, alpha=1) -> Tensor See :func:`torch.subtract`. -""") +""", +) -add_docstr_all('subtract_', r""" +add_docstr_all( + "subtract_", + r""" subtract_(other, *, alpha=1) -> Tensor In-place version of :meth:`~Tensor.subtract`. -""") +""", +) -add_docstr_all('sum', - r""" +add_docstr_all( + "sum", + r""" sum(dim=None, keepdim=False, dtype=None) -> Tensor See :func:`torch.sum` -""") +""", +) -add_docstr_all('nansum', - r""" +add_docstr_all( + "nansum", + r""" nansum(dim=None, keepdim=False, dtype=None) -> Tensor See :func:`torch.nansum` -""") +""", +) -add_docstr_all('svd', - r""" +add_docstr_all( + "svd", + r""" svd(some=True, compute_uv=True) -> (Tensor, Tensor, Tensor) See :func:`torch.svd` -""") +""", +) -add_docstr_all('symeig', - r""" +add_docstr_all( + "symeig", + r""" symeig(eigenvectors=False, upper=True) -> (Tensor, Tensor) See :func:`torch.symeig` -""") +""", +) -add_docstr_all('swapdims', r""" +add_docstr_all( + "swapdims", + r""" swapdims(dim0, dim1) -> Tensor See :func:`torch.swapdims` -""") +""", +) -add_docstr_all('swapdims_', - r""" +add_docstr_all( + "swapdims_", + r""" swapdims_(dim0, dim1) -> Tensor In-place version of :meth:`~Tensor.swapdims` -""") +""", +) -add_docstr_all('swapaxes', r""" +add_docstr_all( + "swapaxes", + r""" swapaxes(axis0, axis1) -> Tensor See :func:`torch.swapaxes` -""") +""", +) -add_docstr_all('swapaxes_', r""" +add_docstr_all( + "swapaxes_", + r""" swapaxes_(axis0, axis1) -> Tensor In-place version of :meth:`~Tensor.swapaxes` -""") +""", +) -add_docstr_all('t', - r""" +add_docstr_all( + "t", + r""" t() -> Tensor See :func:`torch.t` -""") +""", +) -add_docstr_all('t_', - r""" +add_docstr_all( + "t_", + r""" t_() -> Tensor In-place version of :meth:`~Tensor.t` -""") +""", +) -add_docstr_all('tile', - r""" +add_docstr_all( + "tile", + r""" tile(*reps) -> Tensor See :func:`torch.tile` -""") +""", +) -add_docstr_all('to', - r""" +add_docstr_all( + "to", + r""" to(*args, **kwargs) -> Tensor Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are @@ -4060,191 +5031,261 @@ def callable(a, b) -> number >>> tensor.to(other, non_blocking=True) tensor([[-0.5044, 0.0005], [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') -""".format(**common_args)) - -add_docstr_all('byte', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "byte", + r""" byte(memory_format=torch.preserve_format) -> Tensor ``self.byte()`` is equivalent to ``self.to(torch.uint8)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('bool', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "bool", + r""" bool(memory_format=torch.preserve_format) -> Tensor ``self.bool()`` is equivalent to ``self.to(torch.bool)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('char', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "char", + r""" char(memory_format=torch.preserve_format) -> Tensor ``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('bfloat16', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "bfloat16", + r""" bfloat16(memory_format=torch.preserve_format) -> Tensor ``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('double', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "double", + r""" double(memory_format=torch.preserve_format) -> Tensor ``self.double()`` is equivalent to ``self.to(torch.float64)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('float', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "float", + r""" float(memory_format=torch.preserve_format) -> Tensor ``self.float()`` is equivalent to ``self.to(torch.float32)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('cdouble', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "cdouble", + r""" cdouble(memory_format=torch.preserve_format) -> Tensor ``self.cdouble()`` is equivalent to ``self.to(torch.complex128)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('cfloat', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "cfloat", + r""" cfloat(memory_format=torch.preserve_format) -> Tensor ``self.cfloat()`` is equivalent to ``self.to(torch.complex64)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('chalf', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "chalf", + r""" chalf(memory_format=torch.preserve_format) -> Tensor ``self.chalf()`` is equivalent to ``self.to(torch.complex32)``. See :func:`to`. Args: {memory_format} - """.format(**common_args)) - -add_docstr_all('half', - r""" + """.format( + **common_args + ), +) + +add_docstr_all( + "half", + r""" half(memory_format=torch.preserve_format) -> Tensor ``self.half()`` is equivalent to ``self.to(torch.float16)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('int', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "int", + r""" int(memory_format=torch.preserve_format) -> Tensor ``self.int()`` is equivalent to ``self.to(torch.int32)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('int_repr', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "int_repr", + r""" int_repr() -> Tensor Given a quantized Tensor, ``self.int_repr()`` returns a CPU Tensor with uint8_t as data type that stores the underlying uint8_t values of the given Tensor. -""") +""", +) -add_docstr_all('long', - r""" +add_docstr_all( + "long", + r""" long(memory_format=torch.preserve_format) -> Tensor ``self.long()`` is equivalent to ``self.to(torch.int64)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('short', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "short", + r""" short(memory_format=torch.preserve_format) -> Tensor ``self.short()`` is equivalent to ``self.to(torch.int16)``. See :func:`to`. Args: {memory_format} -""".format(**common_args)) - -add_docstr_all('take', - r""" +""".format( + **common_args + ), +) + +add_docstr_all( + "take", + r""" take(indices) -> Tensor See :func:`torch.take` -""") +""", +) -add_docstr_all('take_along_dim', - r""" +add_docstr_all( + "take_along_dim", + r""" take_along_dim(indices, dim) -> Tensor See :func:`torch.take_along_dim` -""") +""", +) -add_docstr_all('tan', - r""" +add_docstr_all( + "tan", + r""" tan() -> Tensor See :func:`torch.tan` -""") +""", +) -add_docstr_all('tan_', - r""" +add_docstr_all( + "tan_", + r""" tan_() -> Tensor In-place version of :meth:`~Tensor.tan` -""") +""", +) -add_docstr_all('tanh', - r""" +add_docstr_all( + "tanh", + r""" tanh() -> Tensor See :func:`torch.tanh` -""") +""", +) -add_docstr_all('tanh_', - r""" +add_docstr_all( + "tanh_", + r""" tanh_() -> Tensor In-place version of :meth:`~Tensor.tanh` -""") +""", +) -add_docstr_all('tolist', - r""" +add_docstr_all( + "tolist", + r""" tolist() -> list or number Returns the tensor as a (nested) list. For scalars, a standard @@ -4261,23 +5302,24 @@ def callable(a, b) -> number [-0.08909505605697632, 0.7729271650314331]] >>> a[0,0].tolist() 0.012766935862600803 -""") +""", +) -add_docstr_all('topk', - r""" +add_docstr_all( + "topk", + r""" topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor) See :func:`torch.topk` -""") +""", +) -add_docstr_all('to_dense', - r""" +add_docstr_all( + "to_dense", + r""" to_dense() -> Tensor -Creates a strided copy of :attr:`self`. - -.. warning:: - Throws an error if :attr:`self` is a strided tensor. +Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. Example:: @@ -4290,11 +5332,14 @@ def callable(a, b) -> number tensor([[ 0, 0, 0], [ 9, 0, 10], [ 0, 0, 0]]) -""") +""", +) -add_docstr_all('to_sparse', - r""" +add_docstr_all( + "to_sparse", + r""" to_sparse(sparseDims) -> Tensor + Returns a sparse copy of the tensor. PyTorch supports sparse tensors in :ref:`coordinate format `. @@ -4317,12 +5362,15 @@ def callable(a, b) -> number tensor(indices=tensor([[1]]), values=tensor([[ 9, 0, 10]]), size=(3, 3), nnz=1, layout=torch.sparse_coo) -""") +""", +) -add_docstr_all('to_sparse_csr', - r""" +add_docstr_all( + "to_sparse_csr", + r""" to_sparse_csr() -> Tensor -Convert a tensor to compressed row storage format. Only works with 2D tensors. + +Convert a tensor to compressed row storage format (CSR). Only works with 2D tensors. Example:: @@ -4331,11 +5379,31 @@ def callable(a, b) -> number >>> sparse._nnz() 25 -""") +""", +) + +add_docstr_all( + "to_sparse_csc", + r""" +to_sparse_csc() -> Tensor + +Convert a tensor to compressed column storage (CSC) format. Only works with 2D tensors. + +Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csc() + >>> sparse._nnz() + 25 + +""", +) -add_docstr_all('to_sparse_bsr', - r""" +add_docstr_all( + "to_sparse_bsr", + r""" to_sparse_bsr(blocksize) -> Tensor + Convert a CSR tensor to a block sparse row (BSR) storage format of given blocksize. Example:: @@ -4346,115 +5414,165 @@ def callable(a, b) -> number >>> sparse_bsr.col_indices() tensor([0, 1, 0, 1]) -""") +""", +) + +add_docstr_all( + "to_sparse_bsc", + r""" +to_sparse_bsc(blocksize) -> Tensor + +Convert a CSR tensor to a block sparse column (BSC) storage format of given blocksize. + +Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsc = sparse.to_sparse_bsc((5, 5)) + >>> sparse_bsc.row_indices() + tensor([0, 1, 0, 1]) + +""", +) -add_docstr_all('to_mkldnn', - r""" +add_docstr_all( + "to_mkldnn", + r""" to_mkldnn() -> Tensor Returns a copy of the tensor in ``torch.mkldnn`` layout. -""") +""", +) -add_docstr_all('trace', - r""" +add_docstr_all( + "trace", + r""" trace() -> Tensor See :func:`torch.trace` -""") +""", +) -add_docstr_all('transpose', - r""" +add_docstr_all( + "transpose", + r""" transpose(dim0, dim1) -> Tensor See :func:`torch.transpose` -""") +""", +) -add_docstr_all('transpose_', - r""" +add_docstr_all( + "transpose_", + r""" transpose_(dim0, dim1) -> Tensor In-place version of :meth:`~Tensor.transpose` -""") +""", +) -add_docstr_all('triangular_solve', - r""" +add_docstr_all( + "triangular_solve", + r""" triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) See :func:`torch.triangular_solve` -""") +""", +) -add_docstr_all('tril', - r""" +add_docstr_all( + "tril", + r""" tril(diagonal=0) -> Tensor See :func:`torch.tril` -""") +""", +) -add_docstr_all('tril_', - r""" +add_docstr_all( + "tril_", + r""" tril_(diagonal=0) -> Tensor In-place version of :meth:`~Tensor.tril` -""") +""", +) -add_docstr_all('triu', - r""" +add_docstr_all( + "triu", + r""" triu(diagonal=0) -> Tensor See :func:`torch.triu` -""") +""", +) -add_docstr_all('triu_', - r""" +add_docstr_all( + "triu_", + r""" triu_(diagonal=0) -> Tensor In-place version of :meth:`~Tensor.triu` -""") +""", +) -add_docstr_all('true_divide', - r""" +add_docstr_all( + "true_divide", + r""" true_divide(value) -> Tensor See :func:`torch.true_divide` -""") +""", +) -add_docstr_all('true_divide_', - r""" +add_docstr_all( + "true_divide_", + r""" true_divide_(value) -> Tensor In-place version of :meth:`~Tensor.true_divide_` -""") +""", +) -add_docstr_all('trunc', - r""" +add_docstr_all( + "trunc", + r""" trunc() -> Tensor See :func:`torch.trunc` -""") +""", +) -add_docstr_all('fix', - r""" +add_docstr_all( + "fix", + r""" fix() -> Tensor See :func:`torch.fix`. -""") +""", +) -add_docstr_all('trunc_', - r""" +add_docstr_all( + "trunc_", + r""" trunc_() -> Tensor In-place version of :meth:`~Tensor.trunc` -""") +""", +) -add_docstr_all('fix_', - r""" +add_docstr_all( + "fix_", + r""" fix_() -> Tensor In-place version of :meth:`~Tensor.fix` -""") +""", +) -add_docstr_all('type', - r""" +add_docstr_all( + "type", + r""" type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor Returns the type if `dtype` is not provided, else casts this object to the specified type. @@ -4470,10 +5588,12 @@ def callable(a, b) -> number has no effect. **kwargs: For compatibility, may contain the key ``async`` in place of the ``non_blocking`` argument. The ``async`` arg is deprecated. -""") +""", +) -add_docstr_all('type_as', - r""" +add_docstr_all( + "type_as", + r""" type_as(tensor) -> Tensor Returns this tensor cast to the type of the given tensor. @@ -4483,10 +5603,12 @@ def callable(a, b) -> number Args: tensor (Tensor): the tensor which has the desired type -""") +""", +) -add_docstr_all('unfold', - r""" +add_docstr_all( + "unfold", + r""" unfold(dimension, size, step) -> Tensor Returns a view of the original tensor which contains all slices of size :attr:`size` from @@ -4521,10 +5643,12 @@ def callable(a, b) -> number tensor([[ 1., 2.], [ 3., 4.], [ 5., 6.]]) -""") +""", +) -add_docstr_all('uniform_', - r""" +add_docstr_all( + "uniform_", + r""" uniform_(from=0, to=1) -> Tensor Fills :attr:`self` tensor with numbers sampled from the continuous uniform @@ -4532,24 +5656,30 @@ def callable(a, b) -> number .. math:: P(x) = \dfrac{1}{\text{to} - \text{from}} -""") +""", +) -add_docstr_all('unsqueeze', - r""" +add_docstr_all( + "unsqueeze", + r""" unsqueeze(dim) -> Tensor See :func:`torch.unsqueeze` -""") +""", +) -add_docstr_all('unsqueeze_', - r""" +add_docstr_all( + "unsqueeze_", + r""" unsqueeze_(dim) -> Tensor In-place version of :meth:`~Tensor.unsqueeze` -""") +""", +) -add_docstr_all('var', - r""" +add_docstr_all( + "var", + r""" var(dim, unbiased=True, keepdim=False) -> Tensor See :func:`torch.var` @@ -4558,17 +5688,21 @@ def callable(a, b) -> number :noindex: See :func:`torch.var` -""") +""", +) -add_docstr_all('vdot', - r""" +add_docstr_all( + "vdot", + r""" vdot(other) -> Tensor See :func:`torch.vdot` -""") +""", +) -add_docstr_all('view', - r""" +add_docstr_all( + "view", + r""" view(*shape) -> Tensor Returns a new tensor with the same data as the :attr:`self` tensor but of a @@ -4704,10 +5838,12 @@ def callable(a, b) -> number 28, 191]], dtype=torch.uint8) >>> x.view(torch.uint8).size() torch.Size([4, 16]) -""") +""", +) -add_docstr_all('view_as', - r""" +add_docstr_all( + "view_as", + r""" view_as(other) -> Tensor View this tensor as the same size as :attr:`other`. @@ -4718,10 +5854,12 @@ def callable(a, b) -> number Args: other (:class:`torch.Tensor`): The result tensor has the same size as :attr:`other`. -""") +""", +) -add_docstr_all('expand', - r""" +add_docstr_all( + "expand", + r""" expand(*sizes) -> Tensor Returns a new view of the :attr:`self` tensor with singleton dimensions expanded @@ -4763,10 +5901,12 @@ def callable(a, b) -> number tensor([[ 1, 1, 1, 1], [ 2, 2, 2, 2], [ 3, 3, 3, 3]]) -""") +""", +) -add_docstr_all('expand_as', - r""" +add_docstr_all( + "expand_as", + r""" expand_as(other) -> Tensor Expand this tensor to the same size as :attr:`other`. @@ -4777,10 +5917,12 @@ def callable(a, b) -> number Args: other (:class:`torch.Tensor`): The result tensor has the same size as :attr:`other`. -""") +""", +) -add_docstr_all('sum_to_size', - r""" +add_docstr_all( + "sum_to_size", + r""" sum_to_size(*size) -> Tensor Sum ``this`` tensor to :attr:`size`. @@ -4788,231 +5930,295 @@ def callable(a, b) -> number Args: size (int...): a sequence of integers defining the shape of the output tensor. -""") +""", +) -add_docstr_all('zero_', - r""" +add_docstr_all( + "zero_", + r""" zero_() -> Tensor Fills :attr:`self` tensor with zeros. -""") +""", +) -add_docstr_all('matmul', - r""" +add_docstr_all( + "matmul", + r""" matmul(tensor2) -> Tensor See :func:`torch.matmul` -""") +""", +) -add_docstr_all('chunk', - r""" +add_docstr_all( + "chunk", + r""" chunk(chunks, dim=0) -> List of Tensors See :func:`torch.chunk` -""") +""", +) -add_docstr_all('unsafe_chunk', - r""" +add_docstr_all( + "unsafe_chunk", + r""" unsafe_chunk(chunks, dim=0) -> List of Tensors See :func:`torch.unsafe_chunk` -""") +""", +) -add_docstr_all('unsafe_split', - r""" +add_docstr_all( + "unsafe_split", + r""" unsafe_split(split_size, dim=0) -> List of Tensors See :func:`torch.unsafe_split` -""") +""", +) -add_docstr_all('tensor_split', - r""" +add_docstr_all( + "tensor_split", + r""" tensor_split(indices_or_sections, dim=0) -> List of Tensors See :func:`torch.tensor_split` -""") +""", +) -add_docstr_all('hsplit', - r""" +add_docstr_all( + "hsplit", + r""" hsplit(split_size_or_sections) -> List of Tensors See :func:`torch.hsplit` -""") +""", +) -add_docstr_all('vsplit', - r""" +add_docstr_all( + "vsplit", + r""" vsplit(split_size_or_sections) -> List of Tensors See :func:`torch.vsplit` -""") +""", +) -add_docstr_all('dsplit', - r""" +add_docstr_all( + "dsplit", + r""" dsplit(split_size_or_sections) -> List of Tensors See :func:`torch.dsplit` -""") +""", +) -add_docstr_all('stft', - r""" +add_docstr_all( + "stft", + r""" stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor See :func:`torch.stft` -""") +""", +) -add_docstr_all('istft', - r""" +add_docstr_all( + "istft", + r""" istft(n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=True, length=None) -> Tensor See :func:`torch.istft` -""") +""", +) -add_docstr_all('det', - r""" +add_docstr_all( + "det", + r""" det() -> Tensor See :func:`torch.det` -""") +""", +) -add_docstr_all('where', - r""" +add_docstr_all( + "where", + r""" where(condition, y) -> Tensor ``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. See :func:`torch.where` -""") +""", +) -add_docstr_all('logdet', - r""" +add_docstr_all( + "logdet", + r""" logdet() -> Tensor See :func:`torch.logdet` -""") +""", +) -add_docstr_all('slogdet', - r""" +add_docstr_all( + "slogdet", + r""" slogdet() -> (Tensor, Tensor) See :func:`torch.slogdet` -""") +""", +) -add_docstr_all('unbind', - r""" +add_docstr_all( + "unbind", + r""" unbind(dim=0) -> seq See :func:`torch.unbind` -""") +""", +) -add_docstr_all('pin_memory', - r""" +add_docstr_all( + "pin_memory", + r""" pin_memory() -> Tensor Copies the tensor to pinned memory, if it's not already pinned. -""") +""", +) -add_docstr_all('pinverse', - r""" +add_docstr_all( + "pinverse", + r""" pinverse() -> Tensor See :func:`torch.pinverse` -""") +""", +) -add_docstr_all('index_add', - r""" +add_docstr_all( + "index_add", + r""" index_add(dim, index, source, *, alpha=1) -> Tensor Out-of-place version of :meth:`torch.Tensor.index_add_`. -""") +""", +) -add_docstr_all('index_copy', - r""" +add_docstr_all( + "index_copy", + r""" index_copy(dim, index, tensor2) -> Tensor Out-of-place version of :meth:`torch.Tensor.index_copy_`. -""") +""", +) -add_docstr_all('index_fill', - r""" +add_docstr_all( + "index_fill", + r""" index_fill(dim, index, value) -> Tensor Out-of-place version of :meth:`torch.Tensor.index_fill_`. -""") +""", +) -add_docstr_all('scatter', - r""" +add_docstr_all( + "scatter", + r""" scatter(dim, index, src) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_` -""") +""", +) -add_docstr_all('scatter_add', - r""" +add_docstr_all( + "scatter_add", + r""" scatter_add(dim, index, src) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_add_` -""") +""", +) -add_docstr_all('scatter_reduce', - r""" +add_docstr_all( + "scatter_reduce", + r""" scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` -""") +""", +) -add_docstr_all('masked_scatter', - r""" +add_docstr_all( + "masked_scatter", + r""" masked_scatter(mask, tensor) -> Tensor Out-of-place version of :meth:`torch.Tensor.masked_scatter_` -""") +""", +) -add_docstr_all('xlogy', - r""" +add_docstr_all( + "xlogy", + r""" xlogy(other) -> Tensor See :func:`torch.xlogy` -""") +""", +) -add_docstr_all('xlogy_', - r""" +add_docstr_all( + "xlogy_", + r""" xlogy_(other) -> Tensor In-place version of :meth:`~Tensor.xlogy` -""") +""", +) -add_docstr_all('masked_fill', - r""" +add_docstr_all( + "masked_fill", + r""" masked_fill(mask, value) -> Tensor Out-of-place version of :meth:`torch.Tensor.masked_fill_` -""") +""", +) -add_docstr_all('grad', - r""" +add_docstr_all( + "grad", + r""" This attribute is ``None`` by default and becomes a Tensor the first time a call to :func:`backward` computes gradients for ``self``. The attribute will then contain the gradients computed and future calls to :func:`backward` will accumulate (add) gradients into it. -""") +""", +) -add_docstr_all('retain_grad', - r""" +add_docstr_all( + "retain_grad", + r""" retain_grad() -> None Enables this Tensor to have their :attr:`grad` populated during :func:`backward`. This is a no-op for leaf tensors. -""") +""", +) -add_docstr_all('retains_grad', - r""" +add_docstr_all( + "retains_grad", + r""" Is ``True`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be populated during :func:`backward`, ``False`` otherwise. -""") +""", +) -add_docstr_all('requires_grad', - r""" +add_docstr_all( + "requires_grad", + r""" Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise. .. note:: @@ -5020,10 +6226,12 @@ def callable(a, b) -> number The fact that gradients need to be computed for a Tensor do not mean that the :attr:`grad` attribute will be populated, see :attr:`is_leaf` for more details. -""") +""", +) -add_docstr_all('is_leaf', - r""" +add_docstr_all( + "is_leaf", + r""" All Tensors that have :attr:`requires_grad` which is ``False`` will be leaf Tensors by convention. For Tensors that have :attr:`requires_grad` which is ``True``, they will be leaf Tensors if they were @@ -5060,10 +6268,12 @@ def callable(a, b) -> number # f requires grad, has no operation creating it -""") +""", +) -add_docstr_all('names', - r""" +add_docstr_all( + "names", + r""" Stores names for each of this tensor's dimensions. ``names[idx]`` corresponds to the name of tensor dimension ``idx``. @@ -5078,66 +6288,90 @@ def callable(a, b) -> number .. warning:: The named tensor API is experimental and subject to change. -""") +""", +) -add_docstr_all('is_cuda', - r""" +add_docstr_all( + "is_cuda", + r""" Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_cpu', - r""" +add_docstr_all( + "is_cpu", + r""" Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_ipu', - r""" +add_docstr_all( + "is_ipu", + r""" Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_xpu', - r""" +add_docstr_all( + "is_xpu", + r""" Is ``True`` if the Tensor is stored on the XPU, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_quantized', - r""" +add_docstr_all( + "is_quantized", + r""" Is ``True`` if the Tensor is quantized, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_meta', - r""" +add_docstr_all( + "is_meta", + r""" Is ``True`` if the Tensor is a meta tensor, ``False`` otherwise. Meta tensors are like normal tensors, but they carry no data. -""") +""", +) -add_docstr_all('is_mps', - r""" +add_docstr_all( + "is_mps", + r""" Is ``True`` if the Tensor is stored on the MPS device, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_sparse', - r""" +add_docstr_all( + "is_sparse", + r""" Is ``True`` if the Tensor uses sparse storage layout, ``False`` otherwise. -""") +""", +) -add_docstr_all('is_sparse_csr', - r""" +add_docstr_all( + "is_sparse_csr", + r""" Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise. -""") +""", +) -add_docstr_all('device', - r""" +add_docstr_all( + "device", + r""" Is the :class:`torch.device` where this Tensor is. -""") +""", +) -add_docstr_all('ndim', - r""" +add_docstr_all( + "ndim", + r""" Alias for :meth:`~Tensor.dim()` -""") +""", +) -add_docstr_all('T', - r""" +add_docstr_all( + "T", + r""" Returns a view of this tensor with its dimensions reversed. If ``n`` is the number of dimensions in ``x``, @@ -5148,10 +6382,12 @@ def callable(a, b) -> number is deprecated and it will throw an error in a future release. Consider :attr:`~.Tensor.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. -""") +""", +) -add_docstr_all('H', - r""" +add_docstr_all( + "H", + r""" Returns a view of a matrix (2-D tensor) conjugated and transposed. ``x.H`` is equivalent to ``x.transpose(0, 1).conj()`` for complex matrices and @@ -5160,29 +6396,37 @@ def callable(a, b) -> number .. seealso:: :attr:`~.Tensor.mH`: An attribute that also works on batches of matrices. -""") +""", +) -add_docstr_all('mT', - r""" +add_docstr_all( + "mT", + r""" Returns a view of this tensor with the last two dimensions transposed. ``x.mT`` is equivalent to ``x.transpose(-2, -1)``. -""") +""", +) -add_docstr_all('mH', - r""" +add_docstr_all( + "mH", + r""" Accessing this property is equivalent to calling :func:`adjoint`. -""") +""", +) -add_docstr_all('adjoint', - r""" +add_docstr_all( + "adjoint", + r""" adjoint() -> Tensor Alias for :func:`adjoint` -""") +""", +) -add_docstr_all('real', - r""" +add_docstr_all( + "real", + r""" Returns a new tensor containing real values of the :attr:`self` tensor for a complex-valued input tensor. The returned tensor and :attr:`self` share the same underlying storage. @@ -5195,10 +6439,12 @@ def callable(a, b) -> number >>> x.real tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) -""") +""", +) -add_docstr_all('imag', - r""" +add_docstr_all( + "imag", + r""" Returns a new tensor containing imaginary values of the :attr:`self` tensor. The returned tensor and :attr:`self` share the same underlying storage. @@ -5212,19 +6458,23 @@ def callable(a, b) -> number >>> x.imag tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) -""") +""", +) -add_docstr_all('as_subclass', - r""" +add_docstr_all( + "as_subclass", + r""" as_subclass(cls) -> Tensor Makes a ``cls`` instance with the same data pointer as ``self``. Changes in the output mirror changes in ``self``, and the output stays attached to the autograd graph. ``cls`` must be a subclass of ``Tensor``. -""") +""", +) -add_docstr_all('crow_indices', - r""" +add_docstr_all( + "crow_indices", + r""" crow_indices() -> IntTensor Returns the tensor containing the compressed row indices of the :attr:`self` @@ -5239,10 +6489,12 @@ def callable(a, b) -> number >>> csr.crow_indices() tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32) -""") +""", +) -add_docstr_all('col_indices', - r""" +add_docstr_all( + "col_indices", + r""" col_indices() -> IntTensor Returns the tensor containing the column indices of the :attr:`self` @@ -5257,10 +6509,12 @@ def callable(a, b) -> number >>> csr.col_indices() tensor([0, 1, 2, 3, 4], dtype=torch.int32) -""") +""", +) -add_docstr_all('to_padded_tensor', - r""" +add_docstr_all( + "to_padded_tensor", + r""" to_padded_tensor(padding, output_size=None) -> Tensor Returns a new (non-nested) Tensor by padding the nested tensor. @@ -5307,4 +6561,5 @@ def callable(a, b) -> number >>> pt_small = nt.to_padded_tensor(2.0, (2, 2, 2)) RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported. -""") +""", +) diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index a9b83f31b9398..0308d028bdd05 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -1,7 +1,8 @@ import math +from typing import Optional + import torch from torch._six import inf -from typing import Optional class __PrinterOptions(object): @@ -17,12 +18,12 @@ class __PrinterOptions(object): # We could use **kwargs, but this will give better docs def set_printoptions( - precision=None, - threshold=None, - edgeitems=None, - linewidth=None, - profile=None, - sci_mode=None + precision=None, + threshold=None, + edgeitems=None, + linewidth=None, + profile=None, + sci_mode=None, ): r"""Set options for printing. Items shamelessly taken from NumPy @@ -80,10 +81,12 @@ def set_printoptions( PRINT_OPTS.linewidth = linewidth PRINT_OPTS.sci_mode = sci_mode + def tensor_totype(t): dtype = torch.float if t.is_mps else torch.double return t.to(dtype=dtype) + class _Formatter(object): def __init__(self, tensor): self.floating_dtype = tensor.dtype.is_floating_point @@ -96,11 +99,13 @@ def __init__(self, tensor): if not self.floating_dtype: for value in tensor_view: - value_str = '{}'.format(value) + value_str = "{}".format(value) self.max_width = max(self.max_width, len(value_str)) else: - nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)) + nonzero_finite_vals = torch.masked_select( + tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0) + ) if nonzero_finite_vals.numel() == 0: # no valid number, do nothing @@ -119,27 +124,38 @@ def __init__(self, tensor): if self.int_mode: # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites # to indicate that the tensor is of floating type. add 1 to the len to account for this. - if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8: + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + ): self.sci_mode = True for value in nonzero_finite_vals: - value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value) + value_str = ( + ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value) + ) self.max_width = max(self.max_width, len(value_str)) else: for value in nonzero_finite_vals: - value_str = ('{:.0f}').format(value) + value_str = ("{:.0f}").format(value) self.max_width = max(self.max_width, len(value_str) + 1) else: # Check if scientific representation should be used. - if nonzero_finite_max / nonzero_finite_min > 1000.\ - or nonzero_finite_max > 1.e8\ - or nonzero_finite_min < 1.e-4: + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + ): self.sci_mode = True for value in nonzero_finite_vals: - value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value) + value_str = ( + ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value) + ) self.max_width = max(self.max_width, len(value_str)) else: for value in nonzero_finite_vals: - value_str = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value) + value_str = ( + ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value) + ) self.max_width = max(self.max_width, len(value_str)) if PRINT_OPTS.sci_mode is not None: @@ -151,16 +167,20 @@ def width(self): def format(self, value): if self.floating_dtype: if self.sci_mode: - ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value) + ret = ( + ("{{:{}.{}e}}") + .format(self.max_width, PRINT_OPTS.precision) + .format(value) + ) elif self.int_mode: - ret = '{:.0f}'.format(value) + ret = "{:.0f}".format(value) if not (math.isinf(value) or math.isnan(value)): - ret += '.' + ret += "." else: - ret = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value) + ret = ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value) else: - ret = '{}'.format(value) - return (self.max_width - len(ret)) * ' ' + ret + ret = "{}".format(value) + return (self.max_width - len(ret)) * " " + ret def _scalar_str(self, formatter1, formatter2=None): @@ -168,13 +188,14 @@ def _scalar_str(self, formatter1, formatter2=None): real_str = _scalar_str(self.real, formatter1) imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip() # handles negative numbers, +0.0, -0.0 - if imag_str[0] == '+' or imag_str[0] == '-': + if imag_str[0] == "+" or imag_str[0] == "-": return real_str + imag_str else: return real_str + "+" + imag_str else: return formatter1.format(self.item()) + def _vector_str(self, indent, summarize, formatter1, formatter2=None): # length includes spaces and comma between elements element_length = formatter1.width() + 2 @@ -182,7 +203,9 @@ def _vector_str(self, indent, summarize, formatter1, formatter2=None): # width for imag_formatter + an extra j for complex element_length += formatter2.width() + 1 - elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))) + elements_per_line = max( + 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) + ) char_per_line = element_length * elements_per_line def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): @@ -190,7 +213,7 @@ def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): real_str = formatter1.format(val.real) imag_str = (formatter2.format(val.imag) + "j").lstrip() # handles negative numbers, +0.0, -0.0 - if imag_str[0] == '+' or imag_str[0] == '-': + if imag_str[0] == "+" or imag_str[0] == "-": return real_str + imag_str else: return real_str + "+" + imag_str @@ -198,15 +221,20 @@ def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): return formatter1.format(val) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: - data = ([_val_formatter(val) for val in self[:PRINT_OPTS.edgeitems].tolist()] + - [' ...'] + - [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()]) + data = ( + [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()] + + [" ..."] + + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()] + ) else: data = [_val_formatter(val) for val in self.tolist()] - data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)] - lines = [', '.join(line) for line in data_lines] - return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']' + data_lines = [ + data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line) + ] + lines = [", ".join(line) for line in data_lines] + return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]" + # formatter2 is only used for printing complex tensors. # For complex tensors, formatter1 and formatter2 are the formatters for tensor.real @@ -221,21 +249,36 @@ def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=N return _vector_str(self, indent, summarize, formatter1, formatter2) if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: - slices = ([_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) - for i in range(0, PRINT_OPTS.edgeitems)] + - ['...'] + - [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) - for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]) + slices = ( + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, PRINT_OPTS.edgeitems) + ] + + ["..."] + + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self)) + ] + ) else: - slices = [_tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1, formatter2) - for i in range(0, self.size(0))] + slices = [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, self.size(0)) + ] + + tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices) + return "[" + tensor_str + "]" - tensor_str = (',' + '\n' * (dim - 1) + ' ' * (indent + 1)).join(slices) - return '[' + tensor_str + ']' def _tensor_str(self, indent): if self.numel() == 0: - return '[]' + return "[]" if self.has_names(): # There are two main codepaths (possibly more) that tensor printing goes through: @@ -263,27 +306,34 @@ def _tensor_str(self, indent): if self.dtype.is_complex: # handle the conjugate bit self = self.resolve_conj() - real_formatter = _Formatter(get_summarized_data(self.real) if summarize else self.real) - imag_formatter = _Formatter(get_summarized_data(self.imag) if summarize else self.imag) - return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter) + real_formatter = _Formatter( + get_summarized_data(self.real) if summarize else self.real + ) + imag_formatter = _Formatter( + get_summarized_data(self.imag) if summarize else self.imag + ) + return _tensor_str_with_formatter( + self, indent, summarize, real_formatter, imag_formatter + ) else: formatter = _Formatter(get_summarized_data(self) if summarize else self) return _tensor_str_with_formatter(self, indent, summarize, formatter) + def _add_suffixes(tensor_str, suffixes, indent, force_newline): tensor_strs = [tensor_str] - last_line_len = len(tensor_str) - tensor_str.rfind('\n') + 1 + last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1 for suffix in suffixes: suffix_len = len(suffix) if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: - tensor_strs.append(',\n' + ' ' * indent + suffix) + tensor_strs.append(",\n" + " " * indent + suffix) last_line_len = indent + suffix_len force_newline = False else: - tensor_strs.append(', ' + suffix) + tensor_strs.append(", " + suffix) last_line_len += suffix_len + 2 - tensor_strs.append(')') - return ''.join(tensor_strs) + tensor_strs.append(")") + return "".join(tensor_strs) def get_summarized_data(self): @@ -292,23 +342,25 @@ def get_summarized_data(self): return self if dim == 1: if self.size(0) > 2 * PRINT_OPTS.edgeitems: - return torch.cat((self[:PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems:])) + return torch.cat( + (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) + ) else: return self if self.size(0) > 2 * PRINT_OPTS.edgeitems: start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] - end = ([self[i] - for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]) + end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))] return torch.stack([get_summarized_data(x) for x in (start + end)]) else: return torch.stack([get_summarized_data(x) for x in self]) + def _str_intern(inp, *, tensor_contents=None): is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter if inp.is_nested: prefix = "nested_tensor(" elif is_plain_tensor: - prefix = 'tensor(' + prefix = "tensor(" else: prefix = f"{type(inp).__name__}(" indent = len(prefix) @@ -329,42 +381,67 @@ def _str_intern(inp, *, tensor_contents=None): # torch._C._get_default_device() only returns either cpu or cuda. # In other cases, we don't have a way to set them as default yet, # and we should always print out device for them. - if self.device.type != torch._C._get_default_device()\ - or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index)\ - or (self.device.type == 'mps'): - suffixes.append('device=\'' + str(self.device) + '\'') + if ( + self.device.type != torch._C._get_default_device() + or ( + self.device.type == "cuda" + and torch.cuda.current_device() != self.device.index + ) + or (self.device.type == "mps") + ): + suffixes.append("device='" + str(self.device) + "'") # Tensor printing performs tensor operations like slice, indexing, etc to make it in a # representable format. These operations on ipu/xla/lazy tensor results in compilations. Hence, # to avoid compilations, copying the tensor to cpu before printing. - if self.device.type in ['xla', 'lazy', 'ipu']: - self = self.to('cpu') + if self.device.type in ["xla", "lazy", "ipu"]: + self = self.to("cpu") # TODO: add an API to map real -> complex dtypes - _default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat - has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool) + _default_complex_dtype = ( + torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat + ) + has_default_dtype = self.dtype in ( + torch.get_default_dtype(), + _default_complex_dtype, + torch.int64, + torch.bool, + ) if self.is_sparse: - suffixes.append('size=' + str(tuple(self.shape))) - suffixes.append('nnz=' + str(self._nnz())) + suffixes.append("size=" + str(tuple(self.shape))) + suffixes.append("nnz=" + str(self._nnz())) if not has_default_dtype: - suffixes.append('dtype=' + str(self.dtype)) + suffixes.append("dtype=" + str(self.dtype)) if not custom_contents_provided: - indices_prefix = 'indices=tensor(' + indices_prefix = "indices=tensor(" indices = self._indices().detach() indices_str = _tensor_str(indices, indent + len(indices_prefix)) if indices.numel() == 0: - indices_str += ', size=' + str(tuple(indices.shape)) - values_prefix = 'values=tensor(' + indices_str += ", size=" + str(tuple(indices.shape)) + values_prefix = "values=tensor(" values = self._values().detach() values_str = _tensor_str(values, indent + len(values_prefix)) if values.numel() == 0: - values_str += ', size=' + str(tuple(values.shape)) - tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')' - elif self.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: - suffixes.append('size=' + str(tuple(self.shape))) - suffixes.append('nnz=' + str(self._nnz())) + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + indices_prefix + + indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + suffixes.append("size=" + str(tuple(self.shape))) + suffixes.append("nnz=" + str(self._nnz())) if not has_default_dtype: - suffixes.append('dtype=' + str(self.dtype)) + suffixes.append("dtype=" + str(self.dtype)) if not custom_contents_provided: compressed_indices_method, plain_indices_method = { torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), @@ -373,72 +450,102 @@ def _str_intern(inp, *, tensor_contents=None): torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), }[self.layout] if self.layout in {torch.sparse_csr, torch.sparse_bsr}: - cdimname, pdimname = 'row', 'column' + cdimname, pdimname = "row", "column" else: - cdimname, pdimname = 'column', 'row' - compressed_indices_prefix = f'c{cdimname[:3]}_indices=tensor(' + cdimname, pdimname = "column", "row" + compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor(" compressed_indices = compressed_indices_method(self).detach() - compressed_indices_str = _tensor_str(compressed_indices, indent + len(compressed_indices_prefix)) + compressed_indices_str = _tensor_str( + compressed_indices, indent + len(compressed_indices_prefix) + ) if compressed_indices.numel() == 0: - compressed_indices_str += ', size=' + str(tuple(compressed_indices.shape)) - plain_indices_prefix = f'{pdimname[:3]}_indices=tensor(' + compressed_indices_str += ", size=" + str( + tuple(compressed_indices.shape) + ) + plain_indices_prefix = f"{pdimname[:3]}_indices=tensor(" plain_indices = plain_indices_method(self).detach() - plain_indices_str = _tensor_str(plain_indices, indent + len(plain_indices_prefix)) + plain_indices_str = _tensor_str( + plain_indices, indent + len(plain_indices_prefix) + ) if plain_indices.numel() == 0: - plain_indices_str += ', size=' + str(tuple(plain_indices.shape)) - values_prefix = 'values=tensor(' + plain_indices_str += ", size=" + str(tuple(plain_indices.shape)) + values_prefix = "values=tensor(" values = self.values().detach() values_str = _tensor_str(values, indent + len(values_prefix)) if values.numel() == 0: - values_str += ', size=' + str(tuple(values.shape)) - tensor_str = compressed_indices_prefix + compressed_indices_str + '),\n' + ' ' * indent +\ - plain_indices_prefix + plain_indices_str + '),\n' + ' ' * indent +\ - values_prefix + values_str + ')' + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + compressed_indices_prefix + + compressed_indices_str + + "),\n" + + " " * indent + + plain_indices_prefix + + plain_indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) elif self.is_quantized: - suffixes.append('size=' + str(tuple(self.shape))) + suffixes.append("size=" + str(tuple(self.shape))) if not has_default_dtype: - suffixes.append('dtype=' + str(self.dtype)) - suffixes.append('quantization_scheme=' + str(self.qscheme())) - if self.qscheme() == torch.per_tensor_affine or self.qscheme() == torch.per_tensor_symmetric: - suffixes.append('scale=' + str(self.q_scale())) - suffixes.append('zero_point=' + str(self.q_zero_point())) - elif self.qscheme() == torch.per_channel_affine or self.qscheme() == torch.per_channel_symmetric \ - or self.qscheme() == torch.per_channel_affine_float_qparams: - suffixes.append('scale=' + str(self.q_per_channel_scales())) - suffixes.append('zero_point=' + str(self.q_per_channel_zero_points())) - suffixes.append('axis=' + str(self.q_per_channel_axis())) + suffixes.append("dtype=" + str(self.dtype)) + suffixes.append("quantization_scheme=" + str(self.qscheme())) + if ( + self.qscheme() == torch.per_tensor_affine + or self.qscheme() == torch.per_tensor_symmetric + ): + suffixes.append("scale=" + str(self.q_scale())) + suffixes.append("zero_point=" + str(self.q_zero_point())) + elif ( + self.qscheme() == torch.per_channel_affine + or self.qscheme() == torch.per_channel_symmetric + or self.qscheme() == torch.per_channel_affine_float_qparams + ): + suffixes.append("scale=" + str(self.q_per_channel_scales())) + suffixes.append("zero_point=" + str(self.q_per_channel_zero_points())) + suffixes.append("axis=" + str(self.q_per_channel_axis())) if not custom_contents_provided: tensor_str = _tensor_str(self.dequantize(), indent) elif self.is_nested: if not custom_contents_provided: + def indented_str(s, indent): return "\n".join(f" {line}" for line in s.split("\n")) - strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0)) + + strs = ",\n".join( + indented_str(str(t), indent + 1) + for t in torch.ops.aten.unbind.int(self, 0) + ) tensor_str = f"[\n{strs}\n]" + elif torch._is_functional_tensor(self): + prefix = "_to_functional_tensor(" + tensor_str = repr(torch._from_functional_tensor(self)) else: if self.is_meta: - suffixes.append('size=' + str(tuple(self.shape))) + suffixes.append("size=" + str(tuple(self.shape))) if self.dtype != torch.get_default_dtype(): - suffixes.append('dtype=' + str(self.dtype)) + suffixes.append("dtype=" + str(self.dtype)) # TODO: This implies that ellipses is valid syntax for allocating # a meta tensor, which it could be, but it isn't right now if not custom_contents_provided: - tensor_str = '...' + tensor_str = "..." else: if self.numel() == 0 and not self.is_sparse: # Explicitly print the shape if it is not (0,), to match NumPy behavior if self.dim() != 1: - suffixes.append('size=' + str(tuple(self.shape))) + suffixes.append("size=" + str(tuple(self.shape))) # In an empty tensor, there are no elements to infer if the dtype # should be int64, so it must be shown explicitly. if self.dtype != torch.get_default_dtype(): - suffixes.append('dtype=' + str(self.dtype)) + suffixes.append("dtype=" + str(self.dtype)) if not custom_contents_provided: - tensor_str = '[]' + tensor_str = "[]" else: if not has_default_dtype: - suffixes.append('dtype=' + str(self.dtype)) + suffixes.append("dtype=" + str(self.dtype)) if not custom_contents_provided: if self.layout != torch.strided: @@ -447,25 +554,27 @@ def indented_str(s, indent): tensor_str = _tensor_str(self, indent) if self.layout != torch.strided: - suffixes.append('layout=' + str(self.layout)) + suffixes.append("layout=" + str(self.layout)) # Use inp here to get the original grad_fn and not the one generated by the forward grad # unpacking. if inp.grad_fn is not None: name = type(inp.grad_fn).__name__ - if name == 'CppFunction': - name = inp.grad_fn.name().rsplit('::', 1)[-1] - suffixes.append('grad_fn=<{}>'.format(name)) + if name == "CppFunction": + name = inp.grad_fn.name().rsplit("::", 1)[-1] + suffixes.append("grad_fn=<{}>".format(name)) elif inp.requires_grad: - suffixes.append('requires_grad=True') + suffixes.append("requires_grad=True") if self.has_names(): - suffixes.append('names={}'.format(self.names)) + suffixes.append("names={}".format(self.names)) if tangent is not None: - suffixes.append('tangent={}'.format(tangent)) + suffixes.append("tangent={}".format(tangent)) - string_repr = _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse) + string_repr = _add_suffixes( + prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse + ) # Check if this instance is flagged as a parameter and change the repr accordingly. # Unfortunately, this function has to be aware of this detail. @@ -476,6 +585,7 @@ def indented_str(s, indent): return string_repr + def _str(self, *, tensor_contents=None): with torch.no_grad(): return _str_intern(self, tensor_contents=tensor_contents) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 508d2131f3863..958d9814d7cf4 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -21,45 +21,77 @@ def parse_kwargs(desc): regx = re.compile(r"\n\s{4}(?!\s)") kwargs = [section.strip() for section in regx.split(desc)] kwargs = [section for section in kwargs if len(section) > 0] - return {desc.split(' ')[0]: desc for desc in kwargs} + return {desc.split(" ")[0]: desc for desc in kwargs} def merge_dicts(*dicts): return {x: d[x] for d in dicts for x in d} -common_args = parse_kwargs(""" +common_args = parse_kwargs( + """ input (Tensor): the input tensor. generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling out (Tensor, optional): the output tensor. memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned tensor. Default: ``torch.preserve_format``. -""") +""" +) -reduceops_common_args = merge_dicts(common_args, parse_kwargs(""" +reduceops_common_args = merge_dicts( + common_args, + parse_kwargs( + """ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. keepdim (bool): whether the output tensor has :attr:`dim` retained or not. -""")) - -multi_dim_common = merge_dicts(reduceops_common_args, parse_kwargs(""" +""" + ), +) + +multi_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ dim (int or tuple of ints): the dimension or dimensions to reduce. -"""), {'keepdim_details': """ +""" + ), + { + "keepdim_details": """ If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s). -"""}) - -single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs(""" +""" + }, + { + "opt_dim": """ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. +""" + }, +) + +single_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ dim (int): the dimension to reduce. -"""), {'keepdim_details': """If :attr:`keepdim` is ``True``, the output tensor is of the same size +""" + ), + { + "keepdim_details": """If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in -the output tensor having 1 fewer dimension than :attr:`input`."""}) - -factory_common_args = merge_dicts(common_args, parse_kwargs(""" +the output tensor having 1 fewer dimension than :attr:`input`.""" + }, +) + +factory_common_args = merge_dicts( + common_args, + parse_kwargs( + """ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. @@ -74,9 +106,12 @@ def merge_dicts(*dicts): the pinned memory. Works only for CPU tensors. Default: ``False``. memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned Tensor. Default: ``torch.contiguous_format``. -""")) +""" + ), +) -factory_like_common_args = parse_kwargs(""" +factory_like_common_args = parse_kwargs( + """ input (Tensor): the size of :attr:`input` will determine size of the output tensor. layout (:class:`torch.layout`, optional): the desired layout of returned tensor. Default: if ``None``, defaults to the layout of :attr:`input`. @@ -90,9 +125,11 @@ def merge_dicts(*dicts): the pinned memory. Works only for CPU tensors. Default: ``False``. memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned Tensor. Default: ``torch.preserve_format``. -""") +""" +) -factory_data_common_args = parse_kwargs(""" +factory_data_common_args = parse_kwargs( + """ data (array_like): Initial data for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types. dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. @@ -105,7 +142,8 @@ def merge_dicts(*dicts): returned tensor. Default: ``False``. pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: ``False``. -""") +""" +) tf32_notes = { "tf32_note": """This operator supports :ref:`TensorFloat32`.""" @@ -125,17 +163,20 @@ def merge_dicts(*dicts): and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is \ undesirable, you can try to make the operation deterministic (potentially at \ a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. \ -See :doc:`/notes/randomness` for more information.""" +See :doc:`/notes/randomness` for more information.""", } -add_docstr(torch.abs, r""" +add_docstr( + torch.abs, + r""" abs(input, *, out=None) -> Tensor Computes the absolute value of each element in :attr:`input`. .. math:: \text{out}_{i} = |\text{input}_{i}| -""" + r""" +""" + + r""" Args: {input} @@ -146,23 +187,31 @@ def merge_dicts(*dicts): >>> torch.abs(torch.tensor([-1, -2, 3])) tensor([ 1, 2, 3]) -""".format(**common_args)) - -add_docstr(torch.absolute, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.absolute, + r""" absolute(input, *, out=None) -> Tensor Alias for :func:`torch.abs` -""") +""", +) -add_docstr(torch.acos, r""" +add_docstr( + torch.acos, + r""" acos(input, *, out=None) -> Tensor Computes the inverse cosine of each element in :attr:`input`. .. math:: \text{out}_{i} = \cos^{-1}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -176,15 +225,23 @@ def merge_dicts(*dicts): tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) >>> torch.acos(a) tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) -""".format(**common_args)) - -add_docstr(torch.arccos, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arccos, + r""" arccos(input, *, out=None) -> Tensor Alias for :func:`torch.acos`. -""") +""", +) -add_docstr(torch.acosh, r""" +add_docstr( + torch.acosh, + r""" acosh(input, *, out=None) -> Tensor Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. @@ -195,7 +252,8 @@ def merge_dicts(*dicts): Note: The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. -""" + r""" +""" + + r""" Args: {input} @@ -209,40 +267,58 @@ def merge_dicts(*dicts): tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) >>> torch.acosh(a) tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) -""".format(**common_args)) - -add_docstr(torch.arccosh, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arccosh, + r""" arccosh(input, *, out=None) -> Tensor Alias for :func:`torch.acosh`. -""") +""", +) -add_docstr(torch.index_add, r""" +add_docstr( + torch.index_add, + r""" index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor See :meth:`~Tensor.index_add_` for function description. -""") +""", +) -add_docstr(torch.index_copy, r""" +add_docstr( + torch.index_copy, + r""" index_copy(input, dim, index, source, *, out=None) -> Tensor See :meth:`~Tensor.index_add_` for function description. -""") +""", +) -add_docstr(torch.index_reduce, r""" +add_docstr( + torch.index_reduce, + r""" index_reduce(input, dim, index, source, reduce, *, include_self=True, out=None) -> Tensor See :meth:`~Tensor.index_reduce_` for function description. -""") +""", +) -add_docstr(torch.add, r""" +add_docstr( + torch.add, + r""" add(input, other, *, alpha=1, out=None) -> Tensor Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. .. math:: \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i -""" + r""" +""" + + r""" Supports :ref:`broadcasting to a common shape `, :ref:`type promotion `, and integer, float, and complex inputs. @@ -277,10 +353,14 @@ def merge_dicts(*dicts): [-18.6971, -18.0736, -17.0994, -17.3216], [ -6.7845, -6.1610, -5.1868, -5.4090], [ -8.9902, -8.3667, -7.3925, -7.6147]]) -""".format(**common_args)) - -add_docstr(torch.addbmm, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.addbmm, + r""" addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor Performs a batch matrix-matrix product of matrices stored @@ -302,7 +382,8 @@ def merge_dicts(*dicts): If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. -""" + r""" +""" + + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -329,9 +410,14 @@ def merge_dicts(*dicts): tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) - -add_docstr(torch.addcdiv, r""" +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.addcdiv, + r""" addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, @@ -348,7 +434,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} -""" + r""" +""" + + r""" The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be :ref:`broadcastable `. @@ -374,10 +461,14 @@ def merge_dicts(*dicts): tensor([[-0.2312, -3.6496, 0.1312], [-1.0428, 3.4292, -0.1030], [-0.5369, -0.9829, 0.0430]]) -""".format(**common_args)) - -add_docstr(torch.addcmul, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.addcmul, + r""" addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor Performs the element-wise multiplication of :attr:`tensor1` @@ -386,7 +477,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i -""" + r""" +""" + + r""" The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be :ref:`broadcastable `. @@ -411,10 +503,14 @@ def merge_dicts(*dicts): tensor([[-0.8635, -0.6391, 1.6174], [-0.7617, -0.5879, 1.7388], [-0.8353, -0.6249, 1.6511]]) -""".format(**common_args)) - -add_docstr(torch.addmm, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.addmm, + r""" addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. @@ -433,7 +529,8 @@ def merge_dicts(*dicts): If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. -""" + r""" +""" + + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -459,10 +556,14 @@ def merge_dicts(*dicts): >>> torch.addmm(M, mat1, mat2) tensor([[-4.8716, 1.4671, -1.3746], [ 0.7573, -3.9555, -2.8681]]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) - -add_docstr(torch.adjoint, - r""" +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.adjoint, + r""" adjoint(Tensor) -> Tensor Returns a view of the tensor conjugated and with the last two dimensions transposed. @@ -480,10 +581,12 @@ def merge_dicts(*dicts): [1.-1.j, 3.-3.j]]) >>> (A.adjoint() == A.mH).all() tensor(True) -""") +""", +) -add_docstr(torch.sspaddmm, - r""" +add_docstr( + torch.sspaddmm, + r""" sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor @@ -501,10 +604,14 @@ def merge_dicts(*dicts): beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) {out} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.smm, - r""" +add_docstr( + torch.smm, + r""" smm(input, mat) -> Tensor Performs a matrix multiplication of the sparse matrix :attr:`input` @@ -513,10 +620,12 @@ def merge_dicts(*dicts): Args: input (Tensor): a sparse matrix to be matrix multiplied mat (Tensor): a dense matrix to be matrix multiplied -""") +""", +) -add_docstr(torch.addmv, - r""" +add_docstr( + torch.addmv, + r""" addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor Performs a matrix-vector product of the matrix :attr:`mat` and @@ -536,7 +645,8 @@ def merge_dicts(*dicts): If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. -""" + r""" +""" + + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers @@ -557,10 +667,14 @@ def merge_dicts(*dicts): >>> vec = torch.randn(3) >>> torch.addmv(M, mat, vec) tensor([-0.3768, -5.5565]) -""".format(**common_args)) - -add_docstr(torch.addr, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.addr, + r""" addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` @@ -575,7 +689,8 @@ def merge_dicts(*dicts): If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. -""" + r""" +""" + + r""" If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector of size `m`, then :attr:`input` must be :ref:`broadcastable ` with a matrix of size @@ -601,17 +716,22 @@ def merge_dicts(*dicts): tensor([[ 1., 2.], [ 2., 4.], [ 3., 6.]]) -""".format(**common_args)) - -add_docstr(torch.allclose, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.allclose, + r""" allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool This function checks if all :attr:`input` and :attr:`other` satisfy the condition: .. math:: \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert -""" + r""" +""" + + r""" elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to `numpy.allclose `_ @@ -632,10 +752,12 @@ def merge_dicts(*dicts): False >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) True -""") +""", +) -add_docstr(torch.all, - r""" +add_docstr( + torch.all, + r""" all(input) -> Tensor Tests if all elements in :attr:`input` evaluate to `True`. @@ -685,10 +807,14 @@ def merge_dicts(*dicts): tensor([ True, False, True, True], dtype=torch.bool) >>> torch.all(a, dim=0) tensor([ True, False], dtype=torch.bool) -""".format(**single_dim_common)) - -add_docstr(torch.any, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.any, + r""" any(input) -> Tensor Tests if any element in :attr:`input` evaluates to `True`. @@ -738,17 +864,22 @@ def merge_dicts(*dicts): tensor([ True, True, True, False]) >>> torch.any(a, 0) tensor([True, True]) -""".format(**single_dim_common)) - -add_docstr(torch.angle, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.angle, + r""" angle(input, *, out=None) -> Tensor Computes the element-wise angle (in radians) of the given :attr:`input` tensor. .. math:: \text{out}_{i} = angle(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -764,10 +895,14 @@ def merge_dicts(*dicts): >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 tensor([ 135., 135, -45]) -""".format(**common_args)) - -add_docstr(torch.as_strided, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.as_strided, + r""" as_strided(input, size, stride, storage_offset=None) -> Tensor Create a view of an existing `torch.Tensor` :attr:`input` with specified @@ -803,10 +938,14 @@ def merge_dicts(*dicts): >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) tensor([[0.6291, 0.1586], [1.0795, 2.1939]]) -""".format(**common_args)) - -add_docstr(torch.as_tensor, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.as_tensor, + r""" as_tensor(data, dtype=None, device=None) -> Tensor Converts data into a tensor, sharing data and preserving autograd @@ -850,16 +989,22 @@ def merge_dicts(*dicts): >>> t[0] = -1 >>> a array([1, 2, 3]) -""".format(**factory_data_common_args)) - -add_docstr(torch.asin, r""" +""".format( + **factory_data_common_args + ), +) + +add_docstr( + torch.asin, + r""" asin(input, *, out=None) -> Tensor Returns a new tensor with the arcsine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sin^{-1}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -873,23 +1018,31 @@ def merge_dicts(*dicts): tensor([-0.5962, 1.4985, -0.4396, 1.4525]) >>> torch.asin(a) tensor([-0.6387, nan, -0.4552, nan]) -""".format(**common_args)) - -add_docstr(torch.arcsin, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arcsin, + r""" arcsin(input, *, out=None) -> Tensor Alias for :func:`torch.asin`. -""") +""", +) -add_docstr(torch.asinh, - r""" +add_docstr( + torch.asinh, + r""" asinh(input, *, out=None) -> Tensor Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -903,22 +1056,31 @@ def merge_dicts(*dicts): tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) >>> torch.asinh(a) tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) -""".format(**common_args)) - -add_docstr(torch.arcsinh, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arcsinh, + r""" arcsinh(input, *, out=None) -> Tensor Alias for :func:`torch.asinh`. -""") +""", +) -add_docstr(torch.atan, r""" +add_docstr( + torch.atan, + r""" atan(input, *, out=None) -> Tensor Returns a new tensor with the arctangent of the elements of :attr:`input`. .. math:: \text{out}_{i} = \tan^{-1}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -932,16 +1094,23 @@ def merge_dicts(*dicts): tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) >>> torch.atan(a) tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) -""".format(**common_args)) - -add_docstr(torch.arctan, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arctan, + r""" arctan(input, *, out=None) -> Tensor Alias for :func:`torch.atan`. -""") +""", +) -add_docstr(torch.atan2, - r""" +add_docstr( + torch.atan2, + r""" atan2(input, other, *, out=None) -> Tensor Element-wise arctangent of :math:`\text{{input}}_{{i}} / \text{{other}}_{{i}}` @@ -968,15 +1137,22 @@ def merge_dicts(*dicts): tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) >>> torch.atan2(a, torch.randn(4)) tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) -""".format(**common_args)) - -add_docstr(torch.arctan2, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arctan2, + r""" arctan2(input, other, *, out=None) -> Tensor Alias for :func:`torch.atan2`. -""") +""", +) -add_docstr(torch.atanh, r""" +add_docstr( + torch.atanh, + r""" atanh(input, *, out=None) -> Tensor Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. @@ -988,7 +1164,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -1002,16 +1179,23 @@ def merge_dicts(*dicts): tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) >>> torch.atanh(a) tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) -""".format(**common_args)) - -add_docstr(torch.arctanh, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.arctanh, + r""" arctanh(input, *, out=None) -> Tensor Alias for :func:`torch.atanh`. -""") +""", +) -add_docstr(torch.asarray, - r""" +add_docstr( + torch.asarray, + r""" asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor Converts :attr:`obj` to a tensor. @@ -1109,10 +1293,12 @@ def merge_dicts(*dicts): >>> t2 = torch.asarray(array, dtype=torch.float32) >>> array.__array_interface__['data'][0] == t1.data_ptr() False -""") +""", +) -add_docstr(torch.baddbmm, - r""" +add_docstr( + torch.baddbmm, + r""" baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor Performs a batch matrix-matrix product of matrices in :attr:`batch1` @@ -1134,7 +1320,8 @@ def merge_dicts(*dicts): If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in it will not be propagated. -""" + r""" +""" + + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -1159,10 +1346,14 @@ def merge_dicts(*dicts): >>> batch2 = torch.randn(10, 4, 5) >>> torch.baddbmm(M, batch1, batch2).size() torch.Size([10, 3, 5]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) - -add_docstr(torch.bernoulli, - r""" +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.bernoulli, + r""" bernoulli(input, *, generator=None, out=None) -> Tensor Draws binary random numbers (0 or 1) from a Bernoulli distribution. @@ -1178,7 +1369,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) -""" + r""" +""" + + r""" The returned :attr:`out` tensor only has values 0 or 1 and is of the same shape as :attr:`input`. @@ -1214,10 +1406,14 @@ def merge_dicts(*dicts): tensor([[ 0., 0., 0.], [ 0., 0., 0.], [ 0., 0., 0.]]) -""".format(**common_args)) - -add_docstr(torch.bincount, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bincount, + r""" bincount(input, weights=None, minlength=0) -> Tensor Count the frequency of each value in an array of non-negative ints. @@ -1256,10 +1452,14 @@ def merge_dicts(*dicts): >>> input.bincount(weights) tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) -""".format(**reproducibility_notes)) - -add_docstr(torch.bitwise_not, - r""" +""".format( + **reproducibility_notes + ), +) + +add_docstr( + torch.bitwise_not, + r""" bitwise_not(input, *, out=None) -> Tensor Computes the bitwise NOT of the given input tensor. The input tensor must be of @@ -1275,10 +1475,14 @@ def merge_dicts(*dicts): >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) tensor([ 0, 1, -4], dtype=torch.int8) -""".format(**common_args)) - -add_docstr(torch.bmm, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bmm, + r""" bmm(input, mat2, *, out=None) -> Tensor Performs a batch matrix-matrix product of matrices stored in :attr:`input` @@ -1293,7 +1497,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i -""" + r""" +""" + + r""" {tf32_note} {rocm_fp16_note} @@ -1315,10 +1520,14 @@ def merge_dicts(*dicts): >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) - -add_docstr(torch.bitwise_and, - r""" +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.bitwise_and, + r""" bitwise_and(input, other, *, out=None) -> Tensor Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of @@ -1337,10 +1546,14 @@ def merge_dicts(*dicts): tensor([1, 0, 3], dtype=torch.int8) >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) tensor([ False, True, False]) -""".format(**common_args)) - -add_docstr(torch.bitwise_or, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_or, + r""" bitwise_or(input, other, *, out=None) -> Tensor Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of @@ -1359,10 +1572,14 @@ def merge_dicts(*dicts): tensor([-1, -2, 3], dtype=torch.int8) >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) tensor([ True, True, False]) -""".format(**common_args)) - -add_docstr(torch.bitwise_xor, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_xor, + r""" bitwise_xor(input, other, *, out=None) -> Tensor Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of @@ -1381,10 +1598,14 @@ def merge_dicts(*dicts): tensor([-2, -2, 0], dtype=torch.int8) >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) tensor([ True, False, False]) -""".format(**common_args)) - -add_docstr(torch.bitwise_left_shift, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_left_shift, + r""" bitwise_left_shift(input, other, *, out=None) -> Tensor Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. @@ -1408,10 +1629,14 @@ def merge_dicts(*dicts): >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) tensor([-2, -2, 24], dtype=torch.int8) -""".format(**common_args)) - -add_docstr(torch.bitwise_right_shift, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_right_shift, + r""" bitwise_right_shift(input, other, *, out=None) -> Tensor Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. @@ -1435,10 +1660,14 @@ def merge_dicts(*dicts): >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) tensor([-1, -7, 3], dtype=torch.int8) -""".format(**common_args)) - -add_docstr(torch.broadcast_to, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.broadcast_to, + r""" broadcast_to(input, shape) -> Tensor Broadcasts :attr:`input` to the shape :attr:`\shape`. @@ -1455,10 +1684,14 @@ def merge_dicts(*dicts): tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) -""".format(**common_args)) - -add_docstr(torch.stack, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.stack, + r""" stack(tensors, dim=0, *, out=None) -> Tensor Concatenates a sequence of tensors along a new dimension. @@ -1472,10 +1705,14 @@ def merge_dicts(*dicts): Keyword args: {out} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.hstack, - r""" +add_docstr( + torch.hstack, + r""" hstack(tensors, *, out=None) -> Tensor Stack tensors in sequence horizontally (column wise). @@ -1501,10 +1738,14 @@ def merge_dicts(*dicts): [2, 5], [3, 6]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.vstack, - r""" +add_docstr( + torch.vstack, + r""" vstack(tensors, *, out=None) -> Tensor Stack tensors in sequence vertically (row wise). @@ -1535,10 +1776,14 @@ def merge_dicts(*dicts): [6]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.dstack, - r""" +add_docstr( + torch.dstack, + r""" dstack(tensors, *, out=None) -> Tensor Stack tensors in sequence depthwise (along third axis). @@ -1567,10 +1812,14 @@ def merge_dicts(*dicts): [[3, 6]]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.tensor_split, - r""" +add_docstr( + torch.tensor_split, + r""" tensor_split(input, indices_or_sections, dim=0) -> List of Tensors Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, @@ -1629,10 +1878,12 @@ def merge_dicts(*dicts): [ 8, 9, 10, 11, 12]]), tensor([[ 6], [13]])) -""") +""", +) -add_docstr(torch.chunk, - r""" +add_docstr( + torch.chunk, + r""" chunk(input, chunks, dim=0) -> List of Tensors Attempts to split a tensor into the specified number of chunks. Each chunk is a view of @@ -1680,10 +1931,12 @@ def merge_dicts(*dicts): tensor([6, 7, 8]), tensor([ 9, 10, 11]), tensor([12])) -""") +""", +) -add_docstr(torch.unsafe_chunk, - r""" +add_docstr( + torch.unsafe_chunk, + r""" unsafe_chunk(input, chunks, dim=0) -> List of Tensors Works like :func:`torch.chunk` but without enforcing the autograd restrictions @@ -1695,10 +1948,12 @@ def merge_dicts(*dicts): responsibility to ensure that is the case. If both the input and one or more of the outputs are modified inplace, gradients computed by autograd will be silently incorrect. -""") +""", +) -add_docstr(torch.unsafe_split, - r""" +add_docstr( + torch.unsafe_split, + r""" unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors Works like :func:`torch.split` but without enforcing the autograd restrictions @@ -1710,10 +1965,12 @@ def merge_dicts(*dicts): responsibility to ensure that is the case. If both the input and one or more of the outputs are modified inplace, gradients computed by autograd will be silently incorrect. -""") +""", +) -add_docstr(torch.hsplit, - r""" +add_docstr( + torch.hsplit, + r""" hsplit(input, indices_or_sections) -> List of Tensors Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors @@ -1760,10 +2017,12 @@ def merge_dicts(*dicts): [15.]]), tensor([], size=(4, 0))) -""") +""", +) -add_docstr(torch.vsplit, - r""" +add_docstr( + torch.vsplit, + r""" vsplit(input, indices_or_sections) -> List of Tensors Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors @@ -1799,10 +2058,12 @@ def merge_dicts(*dicts): tensor([[12., 13., 14., 15.]]), tensor([], size=(0, 4))) -""") +""", +) -add_docstr(torch.dsplit, - r""" +add_docstr( + torch.dsplit, + r""" dsplit(input, indices_or_sections) -> List of Tensors Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors @@ -1847,10 +2108,12 @@ def merge_dicts(*dicts): [15.]]]), tensor([], size=(2, 2, 0))) -""") +""", +) -add_docstr(torch.can_cast, - r""" +add_docstr( + torch.can_cast, + r""" can_cast(from, to) -> bool Determines if a type conversion is allowed under PyTorch casting rules @@ -1866,9 +2129,12 @@ def merge_dicts(*dicts): True >>> torch.can_cast(torch.float, torch.int) False -""") +""", +) -add_docstr(torch.corrcoef, r""" +add_docstr( + torch.corrcoef, + r""" corrcoef(input) -> Tensor Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, @@ -1910,9 +2176,12 @@ def merge_dicts(*dicts): [0.3582, 1.0000]]) >>> torch.corrcoef(x[0]) tensor(1.) -""") +""", +) -add_docstr(torch.cov, r""" +add_docstr( + torch.cov, + r""" cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are @@ -1983,10 +2252,12 @@ def merge_dicts(*dicts): >>> torch.cov(x, fweights=fw, aweights=aw) tensor([[ 0.4169, -0.4169], [-0.4169, 0.4169]]) -""") +""", +) -add_docstr(torch.cat, - r""" +add_docstr( + torch.cat, + r""" cat(tensors, dim=0, *, out=None) -> Tensor Concatenates the given sequence of :attr:`seq` tensors in the given dimension. @@ -2025,17 +2296,23 @@ def merge_dicts(*dicts): -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]]) -""".format(**common_args)) - -add_docstr(torch.concat, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.concat, + r""" concat(tensors, dim=0, *, out=None) -> Tensor Alias of :func:`torch.cat`. -""") +""", +) -add_docstr(torch.ceil, - r""" +add_docstr( + torch.ceil, + r""" ceil(input, *, out=None) -> Tensor Returns a new tensor with the ceil of the elements of :attr:`input`, @@ -2043,7 +2320,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil -""" + r""" +""" + + r""" Args: {input} @@ -2057,10 +2335,14 @@ def merge_dicts(*dicts): tensor([-0.6341, -1.4208, -1.0900, 0.5826]) >>> torch.ceil(a) tensor([-0., -1., -1., 1.]) -""".format(**common_args)) - -add_docstr(torch.real, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.real, + r""" real(input) -> Tensor Returns a new tensor containing real values of the :attr:`self` tensor. @@ -2077,10 +2359,14 @@ def merge_dicts(*dicts): >>> x.real tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.imag, - r""" +add_docstr( + torch.imag, + r""" imag(input) -> Tensor Returns a new tensor containing imaginary values of the :attr:`self` tensor. @@ -2100,10 +2386,14 @@ def merge_dicts(*dicts): >>> x.imag tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.view_as_real, - r""" +add_docstr( + torch.view_as_real, + r""" view_as_real(input) -> Tensor Returns a view of :attr:`input` as a real tensor. For an input complex tensor of @@ -2127,10 +2417,14 @@ def merge_dicts(*dicts): [-0.2098, -0.6699], [ 0.3470, -0.9451], [-0.5174, -1.3136]]) -""".format(**common_args)) - -add_docstr(torch.view_as_complex, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.view_as_complex, + r""" view_as_complex(input) -> Tensor Returns a view of :attr:`input` as a complex tensor. For an input complex @@ -2159,10 +2453,14 @@ def merge_dicts(*dicts): [-0.6561, -1.6623]]) >>> torch.view_as_complex(x) tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) -""".format(**common_args)) - -add_docstr(torch.reciprocal, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.reciprocal, + r""" reciprocal(input, *, out=None) -> Tensor Returns a new tensor with the reciprocal of the elements of :attr:`input` @@ -2174,7 +2472,8 @@ def merge_dicts(*dicts): Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral inputs to reciprocal are automatically :ref:`promoted ` to the default scalar type. -""" + r""" +""" + + r""" Args: {input} @@ -2188,9 +2487,14 @@ def merge_dicts(*dicts): tensor([-0.4595, -2.1219, -1.4314, 0.7298]) >>> torch.reciprocal(a) tensor([-2.1763, -0.4713, -0.6986, 1.3702]) -""".format(**common_args)) - -add_docstr(torch.cholesky, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.cholesky, + r""" cholesky(input, upper=False, *, out=None) -> Tensor Computes the Cholesky decomposition of a symmetric positive-definite @@ -2267,9 +2571,12 @@ def merge_dicts(*dicts): >>> z = l @ l.mT >>> torch.dist(z, a) tensor(2.3842e-07) -""") +""", +) -add_docstr(torch.cholesky_solve, r""" +add_docstr( + torch.cholesky_solve, + r""" cholesky_solve(input, input2, upper=False, *, out=None) -> Tensor Solves a linear system of equations with a positive semidefinite @@ -2328,9 +2635,12 @@ def merge_dicts(*dicts): tensor([[ -8.1626, 19.6097], [ -5.8398, 14.2387], [ -4.3771, 10.4173]]) -""") +""", +) -add_docstr(torch.cholesky_inverse, r""" +add_docstr( + torch.cholesky_inverse, + r""" cholesky_inverse(input, upper=False, *, out=None) -> Tensor Computes the inverse of a symmetric positive-definite matrix :math:`A` using its @@ -2385,9 +2695,12 @@ def merge_dicts(*dicts): >>> z = l @ l.mT >>> torch.dist(z, a) tensor(3.5894e-07) -""") +""", +) -add_docstr(torch.clone, r""" +add_docstr( + torch.clone, + r""" clone(input, *, memory_format=torch.preserve_format) -> Tensor Returns a copy of :attr:`input`. @@ -2403,9 +2716,14 @@ def merge_dicts(*dicts): Keyword args: {memory_format} -""".format(**common_args)) - -add_docstr(torch.clamp, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.clamp, + r""" clamp(input, min=None, max=None, *, out=None) -> Tensor Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. @@ -2416,7 +2734,8 @@ def merge_dicts(*dicts): If :attr:`min` is ``None``, there is no lower bound. Or, if :attr:`max` is ``None`` there is no upper bound. -""" + r""" +""" + + r""" .. note:: If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` @@ -2442,16 +2761,23 @@ def merge_dicts(*dicts): >>> torch.clamp(a, min=min) tensor([-1.0000, 0.1734, 0.3333, 1.0000]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.clip, r""" +add_docstr( + torch.clip, + r""" clip(input, min=None, max=None, *, out=None) -> Tensor Alias for :func:`torch.clamp`. -""") +""", +) -add_docstr(torch.column_stack, - r""" +add_docstr( + torch.column_stack, + r""" column_stack(tensors, *, out=None) -> Tensor Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. @@ -2482,10 +2808,14 @@ def merge_dicts(*dicts): [3, 6, 7, 6, 7], [4, 8, 9, 8, 9]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.complex, - r""" +add_docstr( + torch.complex, + r""" complex(real, imag, *, out=None) -> Tensor Constructs a complex tensor with its real part equal to :attr:`real` and its @@ -2511,10 +2841,12 @@ def merge_dicts(*dicts): >>> z.dtype torch.complex64 -""") +""", +) -add_docstr(torch.polar, - r""" +add_docstr( + torch.polar, + r""" polar(abs, angle, *, out=None) -> Tensor Constructs a complex tensor whose elements are Cartesian coordinates @@ -2532,7 +2864,8 @@ def merge_dicts(*dicts): The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is infinite. -""" + r""" +""" + + r""" Args: abs (Tensor): The absolute value the complex tensor. Must be float or double. angle (Tensor): The angle of the complex tensor. Must be same dtype as @@ -2551,10 +2884,12 @@ def merge_dicts(*dicts): >>> z = torch.polar(abs, angle) >>> z tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) -""") +""", +) -add_docstr(torch.conj_physical, - r""" +add_docstr( + torch.conj_physical, + r""" conj_physical(input, *, out=None) -> Tensor Computes the element-wise conjugate of the given :attr:`input` tensor. @@ -2569,7 +2904,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = conj(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -2580,10 +2916,14 @@ def merge_dicts(*dicts): >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) tensor([-1 - 1j, -2 - 2j, 3 + 3j]) -""".format(**common_args)) - -add_docstr(torch.conj, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.conj, + r""" conj(input) -> Tensor Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, @@ -2608,10 +2948,14 @@ def merge_dicts(*dicts): >>> y = torch.conj(x) >>> y.is_conj() True -""".format(**common_args)) - -add_docstr(torch.resolve_conj, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.resolve_conj, + r""" resolve_conj(input) -> Tensor Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, @@ -2631,10 +2975,14 @@ def merge_dicts(*dicts): tensor([-1 - 1j, -2 - 2j, 3 + 3j]) >>> z.is_conj() False -""".format(**common_args)) - -add_docstr(torch.resolve_neg, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.resolve_neg, + r""" resolve_neg(input) -> Tensor Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, @@ -2655,10 +3003,14 @@ def merge_dicts(*dicts): >>> out.is_neg() False -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.copysign, - r""" +add_docstr( + torch.copysign, + r""" copysign(input, other, *, out=None) -> Tensor Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. @@ -2668,7 +3020,8 @@ def merge_dicts(*dicts): -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ \end{cases} -""" + r""" +""" + + r""" Supports :ref:`broadcasting to a common shape `, and integer and float inputs. @@ -2710,17 +3063,22 @@ def merge_dicts(*dicts): copysign handles signed zeros. If the other argument has a negative zero (-0), the corresponding output value will be negative. -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.cos, - r""" +add_docstr( + torch.cos, + r""" cos(input, *, out=None) -> Tensor Returns a new tensor with the cosine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \cos(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -2734,10 +3092,14 @@ def merge_dicts(*dicts): tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) >>> torch.cos(a) tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) -""".format(**common_args)) - -add_docstr(torch.cosh, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.cosh, + r""" cosh(input, *, out=None) -> Tensor Returns a new tensor with the hyperbolic cosine of the elements of @@ -2745,7 +3107,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \cosh(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -2764,10 +3127,14 @@ def merge_dicts(*dicts): When :attr:`input` is on the CPU, the implementation of torch.cosh may use the Sleef library, which rounds very large results to infinity or negative infinity. See `here `_ for details. -""".format(**common_args)) - -add_docstr(torch.cross, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.cross, + r""" cross(input, other, dim=None, *, out=None) -> Tensor @@ -2820,10 +3187,14 @@ def merge_dicts(*dicts): [-2.4490, -1.5687, 1.9792], [-0.8304, -1.3037, 0.5650], [-1.2329, 1.9883, 1.0551]]) -""".format(**common_args)) - -add_docstr(torch.logcumsumexp, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.logcumsumexp, + r""" logcumsumexp(input, dim, *, out=None) -> Tensor Returns the logarithm of the cumulative summation of the exponentiation of elements of :attr:`input` in the dimension :attr:`dim`. @@ -2846,10 +3217,14 @@ def merge_dicts(*dicts): >>> torch.logcumsumexp(a, dim=0) tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) -""".format(**reduceops_common_args)) - -add_docstr(torch.cummax, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cummax, + r""" cummax(input, dim, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index @@ -2876,10 +3251,14 @@ def merge_dicts(*dicts): values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, 1.9946, 1.9946]), indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) -""".format(**reduceops_common_args)) - -add_docstr(torch.cummin, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cummin, + r""" cummin(input, dim, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index @@ -2906,10 +3285,14 @@ def merge_dicts(*dicts): values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, -1.3298, -1.3298]), indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) -""".format(**reduceops_common_args)) - -add_docstr(torch.cumprod, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cumprod, + r""" cumprod(input, dim, *, dtype=None, out=None) -> Tensor Returns the cumulative product of elements of :attr:`input` in the dimension @@ -2943,10 +3326,14 @@ def merge_dicts(*dicts): >>> torch.cumprod(a, dim=0) tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, 0.0000, -0.0000, -0.0000]) -""".format(**reduceops_common_args)) - -add_docstr(torch.cumsum, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cumsum, + r""" cumsum(input, dim, *, dtype=None, out=None) -> Tensor Returns the cumulative sum of elements of :attr:`input` in the dimension @@ -2975,10 +3362,14 @@ def merge_dicts(*dicts): >>> torch.cumsum(a, dim=0) tensor([-0.8286, -1.3175, -0.8020, 0.0423, 0.2289, 0.0537, -2.0058, -1.8209, -2.9780, -3.4022]) -""".format(**reduceops_common_args)) - -add_docstr(torch.count_nonzero, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.count_nonzero, + r""" count_nonzero(input, dim=None) -> Tensor Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. @@ -3000,10 +3391,14 @@ def merge_dicts(*dicts): tensor(3) >>> torch.count_nonzero(x, dim=0) tensor([0, 1, 2]) -""".format(**reduceops_common_args)) - -add_docstr(torch.dequantize, - r""" +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.dequantize, + r""" dequantize(tensor) -> Tensor Returns an fp32 Tensor by dequantizing a quantized Tensor @@ -3018,10 +3413,12 @@ def merge_dicts(*dicts): Args: tensors (sequence of Tensors): A list of quantized Tensors -""") +""", +) -add_docstr(torch.diag, - r""" +add_docstr( + torch.diag, + r""" diag(input, diagonal=0, *, out=None) -> Tensor - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor @@ -3077,10 +3474,14 @@ def merge_dicts(*dicts): tensor([-0.4264,-0.2429,-1.6300]) >>> torch.diag(a, 1) tensor([ 0.0255, 0.1374]) -""".format(**common_args)) - -add_docstr(torch.diag_embed, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.diag_embed, + r""" diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor Creates a tensor whose diagonals of certain 2D planes (specified by @@ -3138,11 +3539,15 @@ def merge_dicts(*dicts): [[ 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000]]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.diagflat, - r""" +add_docstr( + torch.diagflat, + r""" diagflat(input, offset=0) -> Tensor - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor @@ -3185,10 +3590,14 @@ def merge_dicts(*dicts): [ 0.0000, -0.3018, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.1516, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.9342]]) -""".format(**common_args)) - -add_docstr(torch.diagonal, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.diagonal, + r""" diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor Returns a partial view of :attr:`input` with the its diagonal elements @@ -3241,10 +3650,14 @@ def merge_dicts(*dicts): [[-1.7325, -0.3081, 0.6166, 0.2335], [ 1.0500, 0.7336, -0.3836, -1.1015]]]) -""".format(**common_args)) - -add_docstr(torch.diagonal_scatter, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.diagonal_scatter, + r""" diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor Embeds the values of the :attr:`src` tensor into :attr:`input` along @@ -3293,10 +3706,14 @@ def merge_dicts(*dicts): tensor([[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]]) -""".format(**common_args)) - -add_docstr(torch.as_strided_scatter, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.as_strided_scatter, + r""" as_strided_scatter(input, src, size, stride, storage_offset=0) -> Tensor Embeds the values of the :attr:`src` tensor into :attr:`input` along @@ -3334,9 +3751,14 @@ def merge_dicts(*dicts): [4., 0., 0.], [0., 0., 0.]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.diff, r""" +add_docstr( + torch.diff, + r""" diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor Computes the n-th forward difference along the given dimension. @@ -3371,16 +3793,23 @@ def merge_dicts(*dicts): >>> torch.diff(c, dim=1) tensor([[1, 1], [1, 1]]) -""".format(**common_args)) - -add_docstr(torch.digamma, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.digamma, + r""" digamma(input, *, out=None) -> Tensor Alias for :func:`torch.special.digamma`. -""") +""", +) -add_docstr(torch.dist, - r""" +add_docstr( + torch.dist, + r""" dist(input, other, p=2) -> Tensor Returns the p-norm of (:attr:`input` - :attr:`other`) @@ -3406,12 +3835,17 @@ def merge_dicts(*dicts): >>> torch.dist(x, y, 3) tensor(1.6973) >>> torch.dist(x, y, 0) - tensor(inf) + tensor(4.) >>> torch.dist(x, y, 1) tensor(2.6537) -""".format(**common_args)) - -add_docstr(torch.div, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.div, + r""" div(input, other, *, rounding_mode=None, out=None) -> Tensor Divides each element of the input ``input`` by the corresponding element of @@ -3474,16 +3908,23 @@ def merge_dicts(*dicts): [ 0., 4., -1., 5.], [-1., -1., -2., 6.]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.divide, r""" +add_docstr( + torch.divide, + r""" divide(input, other, *, rounding_mode=None, out=None) -> Tensor Alias for :func:`torch.div`. -""") +""", +) -add_docstr(torch.dot, - r""" +add_docstr( + torch.dot, + r""" dot(input, other, *, out=None) -> Tensor Computes the dot product of two 1D tensors. @@ -3504,27 +3945,46 @@ def merge_dicts(*dicts): >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) tensor(7) -""".format(**common_args)) - -add_docstr(torch.vdot, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.vdot, + r""" vdot(input, other, *, out=None) -> Tensor -Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers -differently than dot(a, b). If the first argument is complex, the complex conjugate of the -first argument is used for the calculation of the dot product. +Computes the dot product of two 1D vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. .. note:: Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product of two 1D tensors with the same number of elements. +.. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + Args: input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. other (Tensor): second tensor in the dot product, must be 1D. Keyword args: - {out} +""" + + rf""" +.. note:: {common_args["out"]} +""" + + r""" Example:: @@ -3536,10 +3996,12 @@ def merge_dicts(*dicts): tensor([16.+1.j]) >>> torch.vdot(b, a) tensor([16.-1.j]) -""".format(**common_args)) +""", +) -add_docstr(torch.eig, - r""" +add_docstr( + torch.eig, + r""" eig(input, eigenvectors=False, *, out=None) -> (Tensor, Tensor) Computes the eigenvalues and eigenvectors of a real square matrix. @@ -3620,9 +4082,12 @@ def merge_dicts(*dicts): [0., 1., 0.], [0., 0., 1.]], dtype=torch.float64) -""") +""", +) -add_docstr(torch.eq, r""" +add_docstr( + torch.eq, + r""" eq(input, other, *, out=None) -> Tensor Computes element-wise equality @@ -3645,10 +4110,14 @@ def merge_dicts(*dicts): >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[ True, False], [False, True]]) -""".format(**common_args)) - -add_docstr(torch.equal, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.equal, + r""" equal(input, other) -> bool ``True`` if two tensors have the same size and elements, ``False`` otherwise. @@ -3657,31 +4126,39 @@ def merge_dicts(*dicts): >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) True -""") +""", +) -add_docstr(torch.erf, - r""" +add_docstr( + torch.erf, + r""" erf(input, *, out=None) -> Tensor Alias for :func:`torch.special.erf`. -""") +""", +) -add_docstr(torch.erfc, - r""" +add_docstr( + torch.erfc, + r""" erfc(input, *, out=None) -> Tensor Alias for :func:`torch.special.erfc`. -""") +""", +) -add_docstr(torch.erfinv, - r""" +add_docstr( + torch.erfinv, + r""" erfinv(input, *, out=None) -> Tensor Alias for :func:`torch.special.erfinv`. -""") +""", +) -add_docstr(torch.exp, - r""" +add_docstr( + torch.exp, + r""" exp(input, *, out=None) -> Tensor Returns a new tensor with the exponential of the elements @@ -3689,7 +4166,8 @@ def merge_dicts(*dicts): .. math:: y_{i} = e^{x_{i}} -""" + r""" +""" + + r""" Args: {input} @@ -3700,24 +4178,32 @@ def merge_dicts(*dicts): >>> torch.exp(torch.tensor([0, math.log(2.)])) tensor([ 1., 2.]) -""".format(**common_args)) - -add_docstr(torch.exp2, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.exp2, + r""" exp2(input, *, out=None) -> Tensor Alias for :func:`torch.special.exp2`. -""") +""", +) -add_docstr(torch.expm1, - r""" +add_docstr( + torch.expm1, + r""" expm1(input, *, out=None) -> Tensor Alias for :func:`torch.special.expm1`. -""") +""", +) -add_docstr(torch.eye, - r""" +add_docstr( + torch.eye, + r""" eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -3742,10 +4228,14 @@ def merge_dicts(*dicts): tensor([[ 1., 0., 0.], [ 0., 1., 0.], [ 0., 0., 1.]]) -""".format(**factory_common_args)) - -add_docstr(torch.floor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.floor, + r""" floor(input, *, out=None) -> Tensor Returns a new tensor with the floor of the elements of :attr:`input`, @@ -3753,7 +4243,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor -""" + r""" +""" + + r""" Args: {input} @@ -3767,9 +4258,14 @@ def merge_dicts(*dicts): tensor([-0.8166, 1.5308, -0.2530, -0.2091]) >>> torch.floor(a) tensor([-1., 1., -1., -1.]) -""".format(**common_args)) - -add_docstr(torch.floor_divide, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.floor_divide, + r""" floor_divide(input, other, *, out=None) -> Tensor .. note:: @@ -3784,7 +4280,8 @@ def merge_dicts(*dicts): .. math:: \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) -""" + r""" +""" + + r""" Supports broadcasting to a common shape, type promotion, and integer and float inputs. @@ -3803,10 +4300,14 @@ def merge_dicts(*dicts): tensor([2.0, 1.0]) >>> torch.floor_divide(a, 1.4) tensor([2.0, 2.0]) -""".format(**common_args)) - -add_docstr(torch.fmod, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.fmod, + r""" fmod(input, other, *, out=None) -> Tensor Applies C++'s `std::fmod `_ entrywise. @@ -3852,10 +4353,14 @@ def merge_dicts(*dicts): >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.frac, - r""" +add_docstr( + torch.frac, + r""" frac(input, *, out=None) -> Tensor Computes the fractional portion of each element in :attr:`input`. @@ -3867,10 +4372,12 @@ def merge_dicts(*dicts): >>> torch.frac(torch.tensor([1, 2.5, -3.2])) tensor([ 0.0000, 0.5000, -0.2000]) -""") +""", +) -add_docstr(torch.frexp, - r""" +add_docstr( + torch.frexp, + r""" frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) Decomposes :attr:`input` into mantissa and exponent tensors @@ -3897,10 +4404,12 @@ def merge_dicts(*dicts): tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) >>> torch.ldexp(mantissa, exponent) tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) -""") +""", +) -add_docstr(torch.from_numpy, - r""" +add_docstr( + torch.from_numpy, + r""" from_numpy(ndarray) -> Tensor Creates a :class:`Tensor` from a :class:`numpy.ndarray`. @@ -3926,10 +4435,12 @@ def merge_dicts(*dicts): >>> t[0] = -1 >>> a array([-1, 2, 3]) -""") +""", +) -add_docstr(torch.frombuffer, - r""" +add_docstr( + torch.frombuffer, + r""" frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor Creates a 1-dimensional :class:`Tensor` from an object that implements @@ -3998,10 +4509,14 @@ def merge_dicts(*dicts): >>> a = array.array('b', [-1, 0, 0, 0]) >>> torch.frombuffer(a, dtype=torch.int32) tensor([255], dtype=torch.int32) -""".format(**factory_common_args)) - -add_docstr(torch.flatten, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.flatten, + r""" flatten(input, start_dim=0, end_dim=-1) -> Tensor Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` @@ -4032,10 +4547,49 @@ def merge_dicts(*dicts): >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) -""".format(**common_args)) +""".format( + **common_args + ), +) + +add_docstr( + torch.unflatten, + r""" +unflatten(input, dim, sizes) -> Tensor + +Expands a dimension of the input tensor over multiple dimensions. -add_docstr(torch.gather, - r""" +.. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + +Args: + {input} + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + +Returns: + A View of input with the specified dimension unflattened. + +Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -1, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) +""".format( + **common_args + ), +) + +add_docstr( + torch.gather, + r""" gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor Gathers values along an axis specified by `dim`. @@ -4066,11 +4620,13 @@ def merge_dicts(*dicts): >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]]) -""") +""", +) -add_docstr(torch.gcd, - r""" +add_docstr( + torch.gcd, + r""" gcd(input, other, *, out=None) -> Tensor Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. @@ -4096,13 +4652,19 @@ def merge_dicts(*dicts): >>> c = torch.tensor([3]) >>> torch.gcd(a, c) tensor([1, 1, 3]) -""".format(**common_args)) - -add_docstr(torch.ge, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.ge, + r""" ge(input, other, *, out=None) -> Tensor Computes :math:`\text{input} \geq \text{other}` element-wise. -""" + r""" +""" + + r""" The second argument can be a number or a tensor whose shape is :ref:`broadcastable ` with the first argument. @@ -4121,16 +4683,23 @@ def merge_dicts(*dicts): >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[True, True], [False, True]]) -""".format(**common_args)) - -add_docstr(torch.greater_equal, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.greater_equal, + r""" greater_equal(input, other, *, out=None) -> Tensor Alias for :func:`torch.ge`. -""") +""", +) -add_docstr(torch.gradient, - r""" +add_docstr( + torch.gradient, + r""" gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in @@ -4250,10 +4819,12 @@ def merge_dicts(*dicts): tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], [ 3.3333, 5.0000, 10.0000, 13.3333]])) -""") +""", +) -add_docstr(torch.geqrf, - r""" +add_docstr( + torch.geqrf, + r""" geqrf(input, *, out=None) -> (Tensor, Tensor) This is a low-level function for calling LAPACK's geqrf directly. This function @@ -4284,9 +4855,12 @@ def merge_dicts(*dicts): .. _LAPACK documentation for geqrf: http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html -""") +""", +) -add_docstr(torch.inner, r""" +add_docstr( + torch.inner, + r""" inner(input, other, *, out=None) -> Tensor Computes the dot product for 1D tensors. For higher dimensions, sums the product @@ -4342,9 +4916,12 @@ def merge_dicts(*dicts): >>> torch.inner(a, torch.tensor(2)) tensor([[1.6347, 2.1748, 2.3567], [0.6558, 0.2469, 5.5787]]) -""") +""", +) -add_docstr(torch.outer, r""" +add_docstr( + torch.outer, + r""" outer(input, vec2, *, out=None) -> Tensor Outer product of :attr:`input` and :attr:`vec2`. @@ -4369,10 +4946,12 @@ def merge_dicts(*dicts): [ 2., 4., 6.], [ 3., 6., 9.], [ 4., 8., 12.]]) -""") +""", +) -add_docstr(torch.ger, - r""" +add_docstr( + torch.ger, + r""" ger(input, vec2, *, out=None) -> Tensor Alias of :func:`torch.outer`. @@ -4380,10 +4959,12 @@ def merge_dicts(*dicts): .. warning:: This function is deprecated and will be removed in a future PyTorch release. Use :func:`torch.outer` instead. -""") +""", +) -add_docstr(torch.get_default_dtype, - r""" +add_docstr( + torch.get_default_dtype, + r""" get_default_dtype() -> torch.dtype Get the current default floating point :class:`torch.dtype`. @@ -4399,28 +4980,36 @@ def merge_dicts(*dicts): >>> torch.get_default_dtype() # changed to torch.float32, the dtype for torch.FloatTensor torch.float32 -""") +""", +) -add_docstr(torch.get_num_threads, - r""" +add_docstr( + torch.get_num_threads, + r""" get_num_threads() -> int Returns the number of threads used for parallelizing CPU operations -""") +""", +) -add_docstr(torch.get_num_interop_threads, - r""" +add_docstr( + torch.get_num_interop_threads, + r""" get_num_interop_threads() -> int Returns the number of threads used for inter-op parallelism on CPU (e.g. in JIT interpreter) -""") +""", +) -add_docstr(torch.gt, r""" +add_docstr( + torch.gt, + r""" gt(input, other, *, out=None) -> Tensor Computes :math:`\text{input} > \text{other}` element-wise. -""" + r""" +""" + + r""" The second argument can be a number or a tensor whose shape is :ref:`broadcastable ` with the first argument. @@ -4439,16 +5028,23 @@ def merge_dicts(*dicts): >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[False, True], [False, False]]) -""".format(**common_args)) - -add_docstr(torch.greater, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.greater, + r""" greater(input, other, *, out=None) -> Tensor Alias for :func:`torch.gt`. -""") +""", +) -add_docstr(torch.histc, - r""" +add_docstr( + torch.histc, + r""" histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor Computes the histogram of a tensor. @@ -4475,10 +5071,14 @@ def merge_dicts(*dicts): >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) tensor([ 0., 2., 1., 0.]) -""".format(**common_args)) - -add_docstr(torch.histogram, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.histogram, + r""" histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) Computes a histogram of the values in a tensor. @@ -4518,10 +5118,14 @@ def merge_dicts(*dicts): (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) -""".format(**common_args)) - -add_docstr(torch.histogramdd, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.histogramdd, + r""" histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) Computes a multi-dimensional histogram of the values in a tensor. @@ -4608,12 +5212,14 @@ def merge_dicts(*dicts): bin_edges=(tensor([0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 1.0000]))) -""") +""", +) # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 torch.histogramdd.__module__ = "torch" -add_docstr(torch.hypot, - r""" +add_docstr( + torch.hypot, + r""" hypot(input, other, *, out=None) -> Tensor Given the legs of a right triangle, return its hypotenuse. @@ -4623,7 +5229,8 @@ def merge_dicts(*dicts): The shapes of ``input`` and ``other`` must be :ref:`broadcastable `. -""" + r""" +""" + + r""" Args: input (Tensor): the first input tensor other (Tensor): the second input tensor @@ -4636,31 +5243,41 @@ def merge_dicts(*dicts): >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) tensor([5.0000, 5.6569, 6.4031]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.i0, - r""" +add_docstr( + torch.i0, + r""" i0(input, *, out=None) -> Tensor Alias for :func:`torch.special.i0`. -""") +""", +) -add_docstr(torch.igamma, - r""" +add_docstr( + torch.igamma, + r""" igamma(input, other, *, out=None) -> Tensor Alias for :func:`torch.special.gammainc`. -""") +""", +) -add_docstr(torch.igammac, - r""" +add_docstr( + torch.igammac, + r""" igammac(input, other, *, out=None) -> Tensor Alias for :func:`torch.special.gammaincc`. -""") +""", +) -add_docstr(torch.index_select, - r""" +add_docstr( + torch.index_select, + r""" index_select(input, dim, index, *, out=None) -> Tensor Returns a new tensor which indexes the :attr:`input` tensor along dimension @@ -4698,15 +5315,23 @@ def merge_dicts(*dicts): tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]]) -""".format(**common_args)) - -add_docstr(torch.inverse, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.inverse, + r""" inverse(input, *, out=None) -> Tensor Alias for :func:`torch.linalg.inv` -""") +""", +) -add_docstr(torch.isin, r""" +add_docstr( + torch.isin, + r""" isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns @@ -4733,9 +5358,12 @@ def merge_dicts(*dicts): >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) tensor([[False, True], [ True, False]]) -""") +""", +) -add_docstr(torch.isinf, r""" +add_docstr( + torch.isinf, + r""" isinf(input) -> Tensor Tests if each element of :attr:`input` is infinite @@ -4755,10 +5383,14 @@ def merge_dicts(*dicts): >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) tensor([False, True, False, True, False]) -""".format(**common_args)) - -add_docstr(torch.isposinf, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.isposinf, + r""" isposinf(input, *, out=None) -> Tensor Tests if each element of :attr:`input` is positive infinity or not. @@ -4773,10 +5405,14 @@ def merge_dicts(*dicts): >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) >>> torch.isposinf(a) tensor([False, True, False]) -""".format(**common_args)) - -add_docstr(torch.isneginf, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.isneginf, + r""" isneginf(input, *, out=None) -> Tensor Tests if each element of :attr:`input` is negative infinity or not. @@ -4791,9 +5427,14 @@ def merge_dicts(*dicts): >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) >>> torch.isneginf(a) tensor([ True, False, False]) -""".format(**common_args)) - -add_docstr(torch.isclose, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.isclose, + r""" isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor Returns a new tensor with boolean elements representing if each element of @@ -4802,7 +5443,8 @@ def merge_dicts(*dicts): .. math:: \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert -""" + r""" +""" + + r""" where :attr:`input` and :attr:`other` are finite. Where :attr:`input` and/or :attr:`other` are nonfinite they are close if and only if @@ -4822,9 +5464,12 @@ def merge_dicts(*dicts): tensor([ True, False, False]) >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) tensor([True, True]) -""") +""", +) -add_docstr(torch.isfinite, r""" +add_docstr( + torch.isfinite, + r""" isfinite(input) -> Tensor Returns a new tensor with boolean elements representing if each element is `finite` or not. @@ -4842,9 +5487,14 @@ def merge_dicts(*dicts): >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) tensor([True, False, True, False, False]) -""".format(**common_args)) - -add_docstr(torch.isnan, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.isnan, + r""" isnan(input) -> Tensor Returns a new tensor with boolean elements representing if each element of :attr:`input` @@ -4861,9 +5511,14 @@ def merge_dicts(*dicts): >>> torch.isnan(torch.tensor([1, float('nan'), 2])) tensor([False, True, False]) -""".format(**common_args)) - -add_docstr(torch.isreal, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.isreal, + r""" isreal(input) -> Tensor Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. @@ -4879,9 +5534,14 @@ def merge_dicts(*dicts): >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) tensor([True, False, True]) -""".format(**common_args)) - -add_docstr(torch.is_floating_point, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.is_floating_point, + r""" is_floating_point(input) -> (bool) Returns True if the data type of :attr:`input` is a floating point data type i.e., @@ -4889,9 +5549,14 @@ def merge_dicts(*dicts): Args: {input} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.is_complex, r""" +add_docstr( + torch.is_complex, + r""" is_complex(input) -> (bool) Returns True if the data type of :attr:`input` is a complex data type i.e., @@ -4899,21 +5564,36 @@ def merge_dicts(*dicts): Args: {input} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.is_grad_enabled, r""" +add_docstr( + torch.is_grad_enabled, + r""" is_grad_enabled() -> (bool) Returns True if grad mode is currently enabled. -""".format(**common_args)) - -add_docstr(torch.is_inference_mode_enabled, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.is_inference_mode_enabled, + r""" is_inference_mode_enabled() -> (bool) Returns True if inference mode is currently enabled. -""".format(**common_args)) - -add_docstr(torch.is_inference, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.is_inference, + r""" is_inference(input) -> (bool) Returns True if :attr:`input` is an inference tensor. @@ -4927,18 +5607,28 @@ def merge_dicts(*dicts): Args: {input} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.is_conj, r""" +add_docstr( + torch.is_conj, + r""" is_conj(input) -> (bool) Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. Args: {input} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.is_nonzero, r""" +add_docstr( + torch.is_nonzero, + r""" is_nonzero(input) -> (bool) Returns True if the :attr:`input` is a single element tensor which is not equal to zero @@ -4969,10 +5659,14 @@ def merge_dicts(*dicts): Traceback (most recent call last): ... RuntimeError: bool value of Tensor with no values is ambiguous -""".format(**common_args)) - -add_docstr(torch.kron, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.kron, + r""" kron(input, other, *, out=None) -> Tensor Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. @@ -5027,10 +5721,12 @@ def merge_dicts(*dicts): [3., 4., 0., 0.], [0., 0., 1., 2.], [0., 0., 3., 4.]]) -""") +""", +) -add_docstr(torch.kthvalue, - r""" +add_docstr( + torch.kthvalue, + r""" kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th @@ -5074,10 +5770,14 @@ def merge_dicts(*dicts): [ 4., 5., 6.]]) >>> torch.kthvalue(x, 2, 0, True) torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) -""".format(**single_dim_common)) - -add_docstr(torch.lcm, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.lcm, + r""" lcm(input, other, *, out=None) -> Tensor Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. @@ -5103,16 +5803,22 @@ def merge_dicts(*dicts): >>> c = torch.tensor([3]) >>> torch.lcm(a, c) tensor([15, 30, 15]) -""".format(**common_args)) - -add_docstr(torch.ldexp, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.ldexp, + r""" ldexp(input, other, *, out=None) -> Tensor Multiplies :attr:`input` by 2**:attr:`other`. .. math:: \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i -""" + r""" +""" + + r""" Typically this function is used to construct floating point numbers by multiplying mantissas in :attr:`input` with integral powers of two created from the exponents @@ -5133,13 +5839,19 @@ def merge_dicts(*dicts): tensor([ 2., 4., 8., 16.]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.le, r""" +add_docstr( + torch.le, + r""" le(input, other, *, out=None) -> Tensor Computes :math:`\text{input} \leq \text{other}` element-wise. -""" + r""" +""" + + r""" The second argument can be a number or a tensor whose shape is :ref:`broadcastable ` with the first argument. @@ -5159,16 +5871,23 @@ def merge_dicts(*dicts): >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[True, False], [True, True]]) -""".format(**common_args)) - -add_docstr(torch.less_equal, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.less_equal, + r""" less_equal(input, other, *, out=None) -> Tensor Alias for :func:`torch.le`. -""") +""", +) -add_docstr(torch.lerp, - r""" +add_docstr( + torch.lerp, + r""" lerp(input, end, weight, *, out=None) Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based @@ -5176,7 +5895,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) -""" + r""" +""" + + r""" The shapes of :attr:`start` and :attr:`end` must be :ref:`broadcastable `. If :attr:`weight` is a tensor, then the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. @@ -5201,17 +5921,22 @@ def merge_dicts(*dicts): tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) >>> torch.lerp(start, end, torch.full_like(start, 0.5)) tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) -""".format(**common_args)) - -add_docstr(torch.lgamma, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.lgamma, + r""" lgamma(input, *, out=None) -> Tensor Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. .. math:: \text{out}_{i} = \ln \Gamma(|\text{input}_{i}|) -""" + """ +""" + + """ Args: {input} @@ -5223,9 +5948,14 @@ def merge_dicts(*dicts): >>> a = torch.arange(0.5, 2, 0.5) >>> torch.lgamma(a) tensor([ 0.5724, 0.0000, -0.1208]) -""".format(**common_args)) - -add_docstr(torch.linspace, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.linspace, + r""" linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly @@ -5237,7 +5967,8 @@ def merge_dicts(*dicts): \ldots, \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, \text{end}) -""" + """ +""" + + """ From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. @@ -5267,10 +5998,14 @@ def merge_dicts(*dicts): tensor([-10., -5., 0., 5., 10.]) >>> torch.linspace(start=-10, end=10, steps=1) tensor([-10.]) -""".format(**factory_common_args)) - -add_docstr(torch.log, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.log, + r""" log(input, *, out=None) -> Tensor Returns a new tensor with the natural logarithm of the elements @@ -5278,7 +6013,8 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{e} (x_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -5293,10 +6029,14 @@ def merge_dicts(*dicts): tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) >>> torch.log(a) tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) -""".format(**common_args)) - -add_docstr(torch.log10, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.log10, + r""" log10(input, *, out=None) -> Tensor Returns a new tensor with the logarithm to the base 10 of the elements @@ -5304,7 +6044,8 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{10} (x_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -5322,17 +6063,22 @@ def merge_dicts(*dicts): >>> torch.log10(a) tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.log1p, - r""" +add_docstr( + torch.log1p, + r""" log1p(input, *, out=None) -> Tensor Returns a new tensor with the natural logarithm of (1 + :attr:`input`). .. math:: y_i = \log_{e} (x_i + 1) -""" + r""" +""" + + r""" .. note:: This function is more accurate than :func:`torch.log` for small values of :attr:`input` @@ -5349,10 +6095,14 @@ def merge_dicts(*dicts): tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) >>> torch.log1p(a) tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) -""".format(**common_args)) - -add_docstr(torch.log2, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.log2, + r""" log2(input, *, out=None) -> Tensor Returns a new tensor with the logarithm to the base 2 of the elements @@ -5360,7 +6110,8 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{2} (x_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -5378,10 +6129,14 @@ def merge_dicts(*dicts): >>> torch.log2(a) tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.logaddexp, - r""" +add_docstr( + torch.logaddexp, + r""" logaddexp(input, other, *, out=None) -> Tensor Logarithm of the sum of exponentiations of the inputs. @@ -5410,10 +6165,14 @@ def merge_dicts(*dicts): tensor([-1., -2., -3.]) >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) -""".format(**common_args)) - -add_docstr(torch.logaddexp2, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.logaddexp2, + r""" logaddexp2(input, other, *, out=None) -> Tensor Logarithm of the sum of exponentiations of the inputs in base-2. @@ -5427,17 +6186,23 @@ def merge_dicts(*dicts): Keyword arguments: {out} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.xlogy, - r""" +add_docstr( + torch.xlogy, + r""" xlogy(input, other, *, out=None) -> Tensor Alias for :func:`torch.special.xlogy`. -""") +""", +) -add_docstr(torch.logical_and, - r""" +add_docstr( + torch.logical_and, + r""" logical_and(input, other, *, out=None) -> Tensor Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are @@ -5464,10 +6229,14 @@ def merge_dicts(*dicts): tensor([False, False, True, False]) >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) tensor([False, False, True, False]) -""".format(**common_args)) - -add_docstr(torch.logical_not, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_not, + r""" logical_not(input, *, out=None) -> Tensor Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool @@ -5489,10 +6258,14 @@ def merge_dicts(*dicts): tensor([ True, False, False]) >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) tensor([1, 0, 0], dtype=torch.int16) -""".format(**common_args)) - -add_docstr(torch.logical_or, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_or, + r""" logical_or(input, other, *, out=None) -> Tensor Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are @@ -5519,10 +6292,14 @@ def merge_dicts(*dicts): tensor([ True, True, True, False]) >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) tensor([ True, True, True, False]) -""".format(**common_args)) - -add_docstr(torch.logical_xor, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_xor, + r""" logical_xor(input, other, *, out=None) -> Tensor Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are @@ -5549,12 +6326,18 @@ def merge_dicts(*dicts): tensor([ True, True, False, False]) >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) tensor([ True, True, False, False]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.logspace, """ +add_docstr( + torch.logspace, + """ logspace(start, end, steps, base=10.0, *, \ out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to @@ -5567,7 +6350,8 @@ def merge_dicts(*dicts): \ldots, \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \text{base}^{\text{end}}) -""" + """ +""" + + """ From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. @@ -5598,10 +6382,14 @@ def merge_dicts(*dicts): tensor([1.2589]) >>> torch.logspace(start=2, end=2, steps=1, base=2) tensor([4.0]) -""".format(**factory_common_args)) - -add_docstr(torch.logsumexp, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.logsumexp, + r""" logsumexp(input, dim, keepdim=False, *, out=None) Returns the log of summed exponentials of each row of the :attr:`input` @@ -5617,7 +6405,7 @@ def merge_dicts(*dicts): Args: {input} - {dim} + {opt_dim} {keepdim} Keyword args: @@ -5630,10 +6418,14 @@ def merge_dicts(*dicts): tensor([1.4907, 1.0593, 1.5696]) >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) tensor(1.6859e-07) -""".format(**multi_dim_common)) - -add_docstr(torch.lstsq, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.lstsq, + r""" lstsq(input, A, *, out=None) -> (Tensor, Tensor) Computes the solution to the least squares and least norm problems for a full @@ -5718,13 +6510,17 @@ def merge_dicts(*dicts): [ 1.0000, 2.0000], [ 10.9635, 4.8501], [ 8.9332, 5.2418]]) -""") +""", +) -add_docstr(torch.lt, r""" +add_docstr( + torch.lt, + r""" lt(input, other, *, out=None) -> Tensor Computes :math:`\text{input} < \text{other}` element-wise. -""" + r""" +""" + + r""" The second argument can be a number or a tensor whose shape is :ref:`broadcastable ` with the first argument. @@ -5743,9 +6539,14 @@ def merge_dicts(*dicts): >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[False, False], [True, False]]) -""".format(**common_args)) - -add_docstr(torch.lu_unpack, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.lu_unpack, + r""" lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. @@ -5790,16 +6591,23 @@ def merge_dicts(*dicts): >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) True -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.less, r""" +add_docstr( + torch.less, + r""" less(input, other, *, out=None) -> Tensor Alias for :func:`torch.lt`. -""") +""", +) -add_docstr(torch.lu_solve, - r""" +add_docstr( + torch.lu_solve, + r""" lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted @@ -5839,10 +6647,14 @@ def merge_dicts(*dicts): >>> torch.dist(A @ x, b) tensor(1.00000e-07 * 2.8312) -""".format(**common_args)) - -add_docstr(torch.masked_select, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.masked_select, + r""" masked_select(input, mask, *, out=None) -> Tensor Returns a new 1-D tensor which indexes the :attr:`input` tensor according to @@ -5875,9 +6687,14 @@ def merge_dicts(*dicts): [False, False, False, True]]) >>> torch.masked_select(x, mask) tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) -""".format(**common_args)) - -add_docstr(torch.matrix_rank, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.matrix_rank, + r""" matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor Returns the numerical rank of a 2-D tensor. The method to compute the @@ -5915,22 +6732,32 @@ def merge_dicts(*dicts): >>> b[0, 0] = 0 >>> torch.matrix_rank(b) tensor(9) -""".format(**common_args)) - -add_docstr(torch.matrix_power, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.matrix_power, + r""" matrix_power(input, n, *, out=None) -> Tensor Alias for :func:`torch.linalg.matrix_power` -""") +""", +) -add_docstr(torch.matrix_exp, r""" +add_docstr( + torch.matrix_exp, + r""" matrix_exp(A) -> Tensor Alias for :func:`torch.linalg.matrix_exp`. -""") +""", +) -add_docstr(torch.max, - r""" +add_docstr( + torch.max, + r""" max(input) -> Tensor Returns the maximum value of all elements in the ``input`` tensor. @@ -5989,9 +6816,14 @@ def merge_dicts(*dicts): See :func:`torch.maximum`. -""".format(**single_dim_common)) +""".format( + **single_dim_common + ), +) -add_docstr(torch.maximum, r""" +add_docstr( + torch.maximum, + r""" maximum(input, other, *, out=None) -> Tensor Computes the element-wise maximum of :attr:`input` and :attr:`other`. @@ -6013,9 +6845,14 @@ def merge_dicts(*dicts): >>> b = torch.tensor((3, 0, 4)) >>> torch.maximum(a, b) tensor([3, 2, 4]) -""".format(**common_args)) - -add_docstr(torch.fmax, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.fmax, + r""" fmax(input, other, *, out=None) -> Tensor Computes the element-wise maximum of :attr:`input` and :attr:`other`. @@ -6042,10 +6879,14 @@ def merge_dicts(*dicts): >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) >>> torch.fmax(a, b) tensor([9.7000, 0.5000, 3.1000, nan]) -""".format(**common_args)) - -add_docstr(torch.amax, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.amax, + r""" amax(input, dim, keepdim=False, *, out=None) -> Tensor Returns the maximum value of each slice of the :attr:`input` tensor in the given @@ -6079,10 +6920,14 @@ def merge_dicts(*dicts): [ 1.9700, 1.1106, -1.0318, -1.0816]]) >>> torch.amax(a, 1) tensor([1.4878, 2.0992, 0.0164, 1.9700]) -""".format(**multi_dim_common)) - -add_docstr(torch.argmax, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.argmax, + r""" argmax(input) -> LongTensor Returns the indices of the maximum value of all elements in the :attr:`input` tensor. @@ -6129,10 +6974,14 @@ def merge_dicts(*dicts): [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a, dim=1) tensor([ 0, 2, 0, 1]) -""".format(**single_dim_common)) - -add_docstr(torch.argwhere, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.argwhere, + r""" argwhere(input) -> Tensor Returns a tensor containing the indices of all non-zero elements of @@ -6164,9 +7013,12 @@ def merge_dicts(*dicts): [0, 2], [1, 1], [1, 2]]) -""") +""", +) -add_docstr(torch.mean, r""" +add_docstr( + torch.mean, + r""" mean(input, *, dtype=None) -> Tensor Returns the mean value of all elements in the :attr:`input` tensor. @@ -6222,9 +7074,14 @@ def merge_dicts(*dicts): [-0.5085], [-0.4599], [ 0.1807]]) -""".format(**multi_dim_common)) - -add_docstr(torch.nanmean, r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.nanmean, + r""" nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor Computes the mean of all `non-NaN` elements along the specified dimensions. @@ -6264,10 +7121,14 @@ def merge_dicts(*dicts): # If all elements in the reduced dimensions are NaN then the result is NaN >>> torch.tensor([torch.nan]).nanmean() tensor(nan) -""".format(**multi_dim_common)) - -add_docstr(torch.median, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.median, + r""" median(input) -> Tensor Returns the median of the values in :attr:`input`. @@ -6337,10 +7198,14 @@ def merge_dicts(*dicts): [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) >>> torch.median(a, 1) torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) -""".format(**single_dim_common)) - -add_docstr(torch.nanmedian, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.nanmedian, + r""" nanmedian(input) -> Tensor Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. @@ -6392,9 +7257,14 @@ def merge_dicts(*dicts): torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) >>> a.nanmedian(0) torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) -""".format(**single_dim_common)) - -add_docstr(torch.quantile, r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.quantile, + r""" quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. @@ -6423,7 +7293,7 @@ def merge_dicts(*dicts): {keepdim} Keyword arguments: - interpolation (string): interpolation method to use when the desired quantile lies between two data points. + interpolation (str): interpolation method to use when the desired quantile lies between two data points. Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. Default is ``linear``. {out} @@ -6461,9 +7331,14 @@ def merge_dicts(*dicts): tensor(2.) >>> torch.quantile(a, 0.4, interpolation='nearest') tensor(1.) -""".format(**single_dim_common)) - -add_docstr(torch.nanquantile, r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.nanquantile, + r""" nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, @@ -6478,7 +7353,7 @@ def merge_dicts(*dicts): {keepdim} Keyword arguments: - interpolation (string): interpolation method to use when the desired quantile lies between two data points. + interpolation (str): interpolation method to use when the desired quantile lies between two data points. Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. Default is ``linear``. {out} @@ -6498,10 +7373,14 @@ def merge_dicts(*dicts): tensor([1., 2.]) >>> t.nanquantile(0.5, dim=1) tensor([ nan, 1.5000]) -""".format(**single_dim_common)) - -add_docstr(torch.min, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.min, + r""" min(input) -> Tensor Returns the minimum value of all elements in the :attr:`input` tensor. @@ -6559,9 +7438,14 @@ def merge_dicts(*dicts): :noindex: See :func:`torch.minimum`. -""".format(**single_dim_common)) - -add_docstr(torch.minimum, r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.minimum, + r""" minimum(input, other, *, out=None) -> Tensor Computes the element-wise minimum of :attr:`input` and :attr:`other`. @@ -6583,9 +7467,14 @@ def merge_dicts(*dicts): >>> b = torch.tensor((3, 0, 4)) >>> torch.minimum(a, b) tensor([1, 0, -1]) -""".format(**common_args)) - -add_docstr(torch.fmin, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.fmin, + r""" fmin(input, other, *, out=None) -> Tensor Computes the element-wise minimum of :attr:`input` and :attr:`other`. @@ -6612,10 +7501,14 @@ def merge_dicts(*dicts): >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) >>> torch.fmin(a, b) tensor([-9.3000, 0.1000, 2.1000, nan]) -""".format(**common_args)) - -add_docstr(torch.amin, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.amin, + r""" amin(input, dim, keepdim=False, *, out=None) -> Tensor Returns the minimum value of each slice of the :attr:`input` tensor in the given @@ -6649,9 +7542,14 @@ def merge_dicts(*dicts): [ 0.9023, 0.4853, 0.9075, -1.6165]]) >>> torch.amin(a, 1) tensor([-1.3312, -0.5744, -1.7268, -1.6165]) -""".format(**multi_dim_common)) - -add_docstr(torch.aminmax, r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.aminmax, + r""" aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) Computes the minimum and maximum values of the :attr:`input` tensor. @@ -6710,10 +7608,12 @@ def merge_dicts(*dicts): torch.return_types.aminmax( min=tensor([[0, 1, 2, 3, 4]]), max=tensor([[5, 6, 7, 8, 9]])) -""") +""", +) -add_docstr(torch.argmin, - r""" +add_docstr( + torch.argmin, + r""" argmin(input, dim=None, keepdim=False) -> LongTensor Returns the indices of the minimum value(s) of the flattened tensor or along a dimension @@ -6745,10 +7645,14 @@ def merge_dicts(*dicts): [1], [3], [1]]) -""".format(**single_dim_common)) - -add_docstr(torch.mm, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.mm, + r""" mm(input, mat2, *, out=None) -> Tensor Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. @@ -6780,10 +7684,14 @@ def merge_dicts(*dicts): >>> torch.mm(mat1, mat2) tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) - -add_docstr(torch.hspmm, - r""" +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.hspmm, + r""" hspmm(mat1, mat2, *, out=None) -> Tensor Performs a matrix multiplication of a :ref:`sparse COO matrix @@ -6797,10 +7705,14 @@ def merge_dicts(*dicts): Keyword args: {out} -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.matmul, - r""" +add_docstr( + torch.matmul, + r""" matmul(input, other, *, out=None) -> Tensor Matrix product of two tensors. @@ -6873,10 +7785,14 @@ def merge_dicts(*dicts): >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) -""".format(**common_args, **tf32_notes, **rocm_fp16_notes)) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) -add_docstr(torch.mode, - r""" +add_docstr( + torch.mode, + r""" mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the mode @@ -6909,9 +7825,14 @@ def merge_dicts(*dicts): >>> b = a + (torch.randn(50, 1) * 5).long() >>> torch.mode(b, 0) torch.return_types.mode(values=tensor([6, 5, 1, 0, 2]), indices=tensor([2, 2, 2, 2, 2])) -""".format(**single_dim_common)) - -add_docstr(torch.mul, r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.mul, + r""" mul(input, other, *, out=None) -> Tensor Multiplies :attr:`input` by :attr:`other`. @@ -6919,7 +7840,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{input}_i \times \text{other}_i -""" + r""" +""" + + r""" Supports :ref:`broadcasting to a common shape `, :ref:`type promotion `, and integer, float, and complex inputs. @@ -6953,16 +7875,23 @@ def merge_dicts(*dicts): [-0.1614, -0.0382, 0.1645, -0.7021], [ 0.0360, 0.0085, -0.0367, 0.1567], [ 0.4312, 0.1019, -0.4394, 1.8753]]) -""".format(**common_args)) - -add_docstr(torch.multiply, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.multiply, + r""" multiply(input, other, *, out=None) Alias for :func:`torch.mul`. -""") +""", +) -add_docstr(torch.multinomial, - r""" +add_docstr( + torch.multinomial, + r""" multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor Returns a tensor where each row contains :attr:`num_samples` indices sampled @@ -7011,10 +7940,14 @@ def merge_dicts(*dicts): not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320 >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1]) -""".format(**common_args)) - -add_docstr(torch.mv, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.mv, + r""" mv(input, vec, *, out=None) -> Tensor Performs a matrix-vector product of the matrix :attr:`input` and the vector @@ -7038,16 +7971,23 @@ def merge_dicts(*dicts): >>> vec = torch.randn(3) >>> torch.mv(mat, vec) tensor([ 1.0404, -0.6361]) -""".format(**common_args)) - -add_docstr(torch.mvlgamma, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.mvlgamma, + r""" mvlgamma(input, p, *, out=None) -> Tensor Alias for :func:`torch.special.multigammaln`. -""") +""", +) -add_docstr(torch.movedim, r""" +add_docstr( + torch.movedim, + r""" movedim(input, source, destination) -> Tensor Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` @@ -7089,9 +8029,14 @@ def merge_dicts(*dicts): tensor([[[-0.3362, -0.9627, 0.5173]], [[-0.8437, 0.1727, -0.1398]]]) -""".format(**common_args)) - -add_docstr(torch.moveaxis, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.moveaxis, + r""" moveaxis(input, source, destination) -> Tensor Alias for :func:`torch.movedim`. @@ -7126,9 +8071,14 @@ def merge_dicts(*dicts): tensor([[[-0.3362, -0.9627, 0.5173]], [[-0.8437, 0.1727, -0.1398]]]) -""".format(**common_args)) - -add_docstr(torch.swapdims, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.swapdims, + r""" swapdims(input, dim0, dim1) -> Tensor Alias for :func:`torch.transpose`. @@ -7156,9 +8106,14 @@ def merge_dicts(*dicts): [[1, 5], [3, 7]]]) -""".format(**common_args)) - -add_docstr(torch.swapaxes, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.swapaxes, + r""" swapaxes(input, axis0, axis1) -> Tensor Alias for :func:`torch.transpose`. @@ -7186,10 +8141,14 @@ def merge_dicts(*dicts): [[1, 5], [3, 7]]]) -""".format(**common_args)) - -add_docstr(torch.narrow, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.narrow, + r""" narrow(input, dim, start, length) -> Tensor Returns a new tensor that is a narrowed version of :attr:`input` tensor. The @@ -7212,10 +8171,12 @@ def merge_dicts(*dicts): tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) -""") +""", +) -add_docstr(torch.nan_to_num, - r""" +add_docstr( + torch.nan_to_num, + r""" nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` @@ -7247,13 +8208,19 @@ def merge_dicts(*dicts): >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.ne, r""" +add_docstr( + torch.ne, + r""" ne(input, other, *, out=None) -> Tensor Computes :math:`\text{input} \neq \text{other}` element-wise. -""" + r""" +""" + + r""" The second argument can be a number or a tensor whose shape is :ref:`broadcastable ` with the first argument. @@ -7272,23 +8239,31 @@ def merge_dicts(*dicts): >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) tensor([[False, True], [True, False]]) -""".format(**common_args)) - -add_docstr(torch.not_equal, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.not_equal, + r""" not_equal(input, other, *, out=None) -> Tensor Alias for :func:`torch.ne`. -""") +""", +) -add_docstr(torch.neg, - r""" +add_docstr( + torch.neg, + r""" neg(input, *, out=None) -> Tensor Returns a new tensor with the negative of the elements of :attr:`input`. .. math:: \text{out} = -1 \times \text{input} -""" + r""" +""" + + r""" Args: {input} @@ -7302,17 +8277,23 @@ def merge_dicts(*dicts): tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) >>> torch.neg(a) tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) -""".format(**common_args)) - -add_docstr(torch.negative, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.negative, + r""" negative(input, *, out=None) -> Tensor Alias for :func:`torch.neg` -""") +""", +) -add_docstr(torch.nextafter, - r""" +add_docstr( + torch.nextafter, + r""" nextafter(input, other, *, out=None) -> Tensor Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. @@ -7333,10 +8314,14 @@ def merge_dicts(*dicts): >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) tensor([True, True]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.nonzero, - r""" +add_docstr( + torch.nonzero, + r""" nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors .. note:: @@ -7413,10 +8398,14 @@ def merge_dicts(*dicts): (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) >>> torch.nonzero(torch.tensor(5), as_tuple=True) (tensor([0]),) -""".format(**common_args)) - -add_docstr(torch.normal, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.normal, + r""" normal(mean, std, *, generator=None, out=None) -> Tensor Returns a tensor of random numbers drawn from separate normal distributions @@ -7505,10 +8494,14 @@ def merge_dicts(*dicts): >>> torch.normal(2, 3, size=(1, 4)) tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) -""".format(**common_args)) - -add_docstr(torch.numel, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.numel, + r""" numel(input) -> int Returns the total number of elements in the :attr:`input` tensor. @@ -7525,10 +8518,14 @@ def merge_dicts(*dicts): >>> torch.numel(a) 16 -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.ones, - r""" +add_docstr( + torch.ones, + r""" ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with the scalar value `1`, with the shape defined @@ -7554,10 +8551,14 @@ def merge_dicts(*dicts): >>> torch.ones(5) tensor([ 1., 1., 1., 1., 1.]) -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.ones_like, - r""" +add_docstr( + torch.ones_like, + r""" ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor filled with the scalar value `1`, with the same size as @@ -7585,17 +8586,23 @@ def merge_dicts(*dicts): >>> torch.ones_like(input) tensor([[ 1., 1., 1.], [ 1., 1., 1.]]) -""".format(**factory_like_common_args)) - -add_docstr(torch.orgqr, - r""" +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.orgqr, + r""" orgqr(input, tau) -> Tensor Alias for :func:`torch.linalg.householder_product`. -""") +""", +) -add_docstr(torch.ormqr, - r""" +add_docstr( + torch.ormqr, + r""" ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. @@ -7630,17 +8637,19 @@ def merge_dicts(*dicts): .. _Representation of Orthogonal or Unitary Matrices: https://www.netlib.org/lapack/lug/node128.html -""") +""", +) -add_docstr(torch.permute, - r""" +add_docstr( + torch.permute, + r""" permute(input, dims) -> Tensor Returns a view of the original tensor :attr:`input` with its dimensions permuted. Args: {input} - dims (tuple of ints): The desired ordering of dimensions + dims (tuple of int): The desired ordering of dimensions Example: >>> x = torch.randn(2, 3, 5) @@ -7648,10 +8657,14 @@ def merge_dicts(*dicts): torch.Size([2, 3, 5]) >>> torch.permute(x, (2, 0, 1)).size() torch.Size([5, 2, 3]) -""".format(**common_args)) - -add_docstr(torch.poisson, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.poisson, + r""" poisson(input, generator=None) -> Tensor Returns a tensor of the same size as :attr:`input` with each element @@ -7675,22 +8688,29 @@ def merge_dicts(*dicts): [8., 6., 6., 0.], [0., 4., 5., 3.], [2., 1., 4., 2.]]) -""".format(**common_args)) - -add_docstr(torch.polygamma, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.polygamma, + r""" polygamma(n, input, *, out=None) -> Tensor Alias for :func:`torch.special.polygamma`. -""") +""", +) -add_docstr(torch.positive, - r""" +add_docstr( + torch.positive, + r""" positive(input) -> Tensor Returns :attr:`input`. Throws a runtime error if :attr:`input` is a bool tensor. -""" + r""" +""" + + r""" Args: {input} @@ -7701,10 +8721,14 @@ def merge_dicts(*dicts): tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) >>> torch.positive(t) tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) -""".format(**common_args)) - -add_docstr(torch.pow, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.pow, + r""" pow(input, exponent, *, out=None) -> Tensor Takes the power of each element in :attr:`input` with :attr:`exponent` and @@ -7722,7 +8746,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = x_i ^ {\text{exponent}_i} -""" + r""" +""" + + r""" When :attr:`exponent` is a tensor, the shapes of :attr:`input` and :attr:`exponent` must be :ref:`broadcastable `. @@ -7774,10 +8799,14 @@ def merge_dicts(*dicts): >>> base = 2 >>> torch.pow(base, exp) tensor([ 2., 4., 8., 16.]) -""".format(**common_args)) - -add_docstr(torch.float_power, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.float_power, + r""" float_power(input, exponent, *, out=None) -> Tensor Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. @@ -7814,10 +8843,14 @@ def merge_dicts(*dicts): tensor([ 2, -3, 4, -5]) >>> torch.float_power(a, exp) tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) -""".format(**common_args)) - -add_docstr(torch.prod, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.prod, + r""" prod(input, *, dtype=None) -> Tensor Returns the product of all elements in the :attr:`input` tensor. @@ -7862,10 +8895,14 @@ def merge_dicts(*dicts): [ 1.1131, -1.0629]]) >>> torch.prod(a, 1) tensor([-0.2018, -0.2962, -0.0821, -1.1831]) -""".format(**single_dim_common)) - -add_docstr(torch.promote_types, - r""" +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.promote_types, + r""" promote_types(type1, type2) -> dtype Returns the :class:`torch.dtype` with the smallest size and scalar kind that is @@ -7883,10 +8920,12 @@ def merge_dicts(*dicts): torch.float32 >>> torch.promote_types(torch.uint8, torch.long) torch.long -""") +""", +) -add_docstr(torch.qr, - r""" +add_docstr( + torch.qr, + r""" qr(input, some=True, *, out=None) -> (Tensor, Tensor) Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, @@ -7965,10 +9004,12 @@ def merge_dicts(*dicts): True >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) True -""") +""", +) -add_docstr(torch.rad2deg, - r""" +add_docstr( + torch.rad2deg, + r""" rad2deg(input, *, out=None) -> Tensor Returns a new tensor with each of the elements of :attr:`input` @@ -7988,10 +9029,14 @@ def merge_dicts(*dicts): [ 359.9894, -359.9894], [ 89.9544, -89.9544]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.deg2rad, - r""" +add_docstr( + torch.deg2rad, + r""" deg2rad(input, *, out=None) -> Tensor Returns a new tensor with each of the elements of :attr:`input` @@ -8011,10 +9056,14 @@ def merge_dicts(*dicts): [ 6.2832, -6.2832], [ 1.5708, -1.5708]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.heaviside, - r""" +add_docstr( + torch.heaviside, + r""" heaviside(input, values, *, out=None) -> Tensor Computes the Heaviside step function for each element in :attr:`input`. @@ -8026,7 +9075,8 @@ def merge_dicts(*dicts): values, & \text{if input == 0}\\ 1, & \text{if input > 0} \end{cases} -""" + r""" +""" + + r""" Args: {input} @@ -8045,10 +9095,14 @@ def merge_dicts(*dicts): >>> torch.heaviside(input, values) tensor([0., -2., 1.]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.rand, - r""" +add_docstr( + torch.rand, + r""" rand(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a uniform distribution @@ -8075,10 +9129,14 @@ def merge_dicts(*dicts): >>> torch.rand(2, 3) tensor([[ 0.8237, 0.5781, 0.6879], [ 0.3816, 0.7249, 0.0998]]) -""".format(**factory_common_args)) - -add_docstr(torch.rand_like, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.rand_like, + r""" rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` that is filled with @@ -8096,10 +9154,14 @@ def merge_dicts(*dicts): {requires_grad} {memory_format} -""".format(**factory_like_common_args)) +""".format( + **factory_like_common_args + ), +) -add_docstr(torch.randint, - """ +add_docstr( + torch.randint, + """ randint(low=0, high, size, \\*, generator=None, out=None, \ dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -8142,10 +9204,14 @@ def merge_dicts(*dicts): [6, 7]]) -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.randint_like, - """ +add_docstr( + torch.randint_like, + """ randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ memory_format=torch.preserve_format) -> Tensor @@ -8169,10 +9235,14 @@ def merge_dicts(*dicts): {requires_grad} {memory_format} -""".format(**factory_like_common_args)) +""".format( + **factory_like_common_args + ), +) -add_docstr(torch.randn, - r""" +add_docstr( + torch.randn, + r""" randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a normal distribution @@ -8203,10 +9273,14 @@ def merge_dicts(*dicts): >>> torch.randn(2, 3) tensor([[ 1.5954, 2.8929, -1.0923], [ 1.1719, -0.4709, -0.1996]]) -""".format(**factory_common_args)) - -add_docstr(torch.randn_like, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.randn_like, + r""" randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` that is filled with @@ -8224,13 +9298,18 @@ def merge_dicts(*dicts): {requires_grad} {memory_format} -""".format(**factory_like_common_args)) +""".format( + **factory_like_common_args + ), +) -add_docstr(torch.randperm, - """ +add_docstr( + torch.randperm, + """ randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, \ device=None, requires_grad=False, pin_memory=False) -> Tensor -""" + r""" +""" + + r""" Returns a random permutation of integers from ``0`` to ``n - 1``. Args: @@ -8250,10 +9329,14 @@ def merge_dicts(*dicts): >>> torch.randperm(4) tensor([2, 1, 0, 3]) -""".format(**factory_common_args)) - -add_docstr(torch.tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.tensor, + r""" tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. @@ -8303,10 +9386,14 @@ def merge_dicts(*dicts): >>> torch.tensor([]) # Create an empty tensor (of size (0,)) tensor([]) -""".format(**factory_data_common_args)) - -add_docstr(torch.range, - r""" +""".format( + **factory_data_common_args + ), +) + +add_docstr( + torch.range, + r""" range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` @@ -8315,7 +9402,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i+1} = \text{out}_i + \text{step}. -""" + r""" +""" + + r""" .. warning:: This function is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). @@ -8342,10 +9430,14 @@ def merge_dicts(*dicts): tensor([ 1., 2., 3., 4.]) >>> torch.range(1, 4, 0.5) tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) -""".format(**factory_common_args)) - -add_docstr(torch.arange, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.arange, + r""" arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` @@ -8358,7 +9450,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{{i+1}} = \text{out}_{i} + \text{step} -""" + r""" +""" + + r""" Args: start (Number): the starting value for the set of points. Default: ``0``. end (Number): the ending value for the set of points @@ -8383,10 +9476,14 @@ def merge_dicts(*dicts): tensor([ 1, 2, 3]) >>> torch.arange(1, 2.5, 0.5) tensor([ 1.0000, 1.5000, 2.0000]) -""".format(**factory_common_args)) - -add_docstr(torch.ravel, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.ravel, + r""" ravel(input) -> Tensor Return a contiguous flattened tensor. A copy is made only if needed. @@ -8402,10 +9499,14 @@ def merge_dicts(*dicts): ... [7, 8]]]) >>> torch.ravel(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) -""".format(**common_args)) - -add_docstr(torch.remainder, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.remainder, + r""" remainder(input, other, *, out=None) -> Tensor Computes @@ -8445,10 +9546,14 @@ def merge_dicts(*dicts): tensor([ 1., 0., 1., 1., 0., 1.]) >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) -""".format(**common_args)) - -add_docstr(torch.renorm, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.renorm, + r""" renorm(input, p, dim, maxnorm, *, out=None) -> Tensor Returns a tensor where each sub-tensor of :attr:`input` along dimension @@ -8481,10 +9586,14 @@ def merge_dicts(*dicts): tensor([[ 1.0000, 1.0000, 1.0000], [ 1.6667, 1.6667, 1.6667], [ 1.6667, 1.6667, 1.6667]]) -""".format(**common_args)) - -add_docstr(torch.reshape, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.reshape, + r""" reshape(input, shape) -> Tensor Returns a tensor with the same data and number of elements as :attr:`input`, @@ -8500,7 +9609,7 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to be reshaped - shape (tuple of ints): the new shape + shape (tuple of int): the new shape Example:: @@ -8511,11 +9620,13 @@ def merge_dicts(*dicts): >>> b = torch.tensor([[0, 1], [2, 3]]) >>> torch.reshape(b, (-1,)) tensor([ 0, 1, 2, 3]) -""") +""", +) -add_docstr(torch.result_type, - r""" +add_docstr( + torch.result_type, + r""" result_type(tensor1, tensor2) -> dtype Returns the :class:`torch.dtype` that would result from performing an arithmetic @@ -8532,16 +9643,21 @@ def merge_dicts(*dicts): torch.float32 >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) torch.uint8 -""") +""", +) -add_docstr(torch.row_stack, - r""" +add_docstr( + torch.row_stack, + r""" row_stack(tensors, *, out=None) -> Tensor Alias of :func:`torch.vstack`. -""") +""", +) -add_docstr(torch.round, r""" +add_docstr( + torch.round, + r""" round(input, *, decimals=0, out=None) -> Tensor Rounds elements of :attr:`input` to the nearest integer. @@ -8588,10 +9704,14 @@ def merge_dicts(*dicts): >>> # A negative decimals argument rounds to the left of the decimal >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) tensor([1000.]) -""".format(**common_args)) - -add_docstr(torch.rsqrt, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.rsqrt, + r""" rsqrt(input, *, out=None) -> Tensor Returns a new tensor with the reciprocal of the square-root of each of @@ -8599,7 +9719,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} -""" + r""" +""" + + r""" Args: {input} @@ -8613,30 +9734,41 @@ def merge_dicts(*dicts): tensor([-0.0370, 0.2970, 1.5420, -0.9105]) >>> torch.rsqrt(a) tensor([ nan, 1.8351, 0.8053, nan]) -""".format(**common_args)) - -add_docstr(torch.scatter, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.scatter, + r""" scatter(input, dim, index, src) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_` -""") +""", +) -add_docstr(torch.scatter_add, - r""" +add_docstr( + torch.scatter_add, + r""" scatter_add(input, dim, index, src) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_add_` -""") +""", +) -add_docstr(torch.scatter_reduce, r""" +add_docstr( + torch.scatter_reduce, + r""" scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` -""") +""", +) -add_docstr(torch.select, - r""" +add_docstr( + torch.select, + r""" select(input, dim, index) -> Tensor Slices the :attr:`input` tensor along the selected dimension at the given index. @@ -8652,10 +9784,14 @@ def merge_dicts(*dicts): :meth:`select` is equivalent to slicing. For example, ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. -""".format(**common_args)) - -add_docstr(torch.select_scatter, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.select_scatter, + r""" select_scatter(input, src, dim, index) -> Tensor Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. @@ -8681,10 +9817,14 @@ def merge_dicts(*dicts): >>> a.select_scatter(b, 0, 0) tensor([[1., 1.], [0., 0.]]) -""".format(**common_args)) - -add_docstr(torch.slice_scatter, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.slice_scatter, + r""" slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor Embeds the values of the :attr:`src` tensor into :attr:`input` at the given @@ -8724,10 +9864,14 @@ def merge_dicts(*dicts): [0., 0., 1., 0., 1., 0., 0., 0.], [0., 0., 1., 0., 1., 0., 0., 0.], [0., 0., 1., 0., 1., 0., 0., 0.]]) -""".format(**common_args)) - -add_docstr(torch.set_flush_denormal, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.set_flush_denormal, + r""" set_flush_denormal(mode) -> bool Disables denormal floating numbers on CPU. @@ -8750,9 +9894,12 @@ def merge_dicts(*dicts): >>> torch.tensor([1e-323], dtype=torch.float64) tensor(9.88131e-324 * [ 1.0000], dtype=torch.float64) -""") +""", +) -add_docstr(torch.set_num_threads, r""" +add_docstr( + torch.set_num_threads, + r""" set_num_threads(int) Sets the number of threads used for intraop parallelism on CPU. @@ -8760,9 +9907,12 @@ def merge_dicts(*dicts): .. warning:: To ensure that the correct number of threads is used, set_num_threads must be called before running eager, JIT or autograd code. -""") +""", +) -add_docstr(torch.set_num_interop_threads, r""" +add_docstr( + torch.set_num_interop_threads, + r""" set_num_interop_threads(int) Sets the number of threads used for interop parallelism @@ -8771,30 +9921,38 @@ def merge_dicts(*dicts): .. warning:: Can only be called once and before any inter-op parallel work is started (e.g. JIT execution). -""") +""", +) -add_docstr(torch.sigmoid, r""" +add_docstr( + torch.sigmoid, + r""" sigmoid(input, *, out=None) -> Tensor Alias for :func:`torch.special.expit`. -""") +""", +) -add_docstr(torch.logit, - r""" +add_docstr( + torch.logit, + r""" logit(input, eps=None, *, out=None) -> Tensor Alias for :func:`torch.special.logit`. -""") +""", +) -add_docstr(torch.sign, - r""" +add_docstr( + torch.sign, + r""" sign(input, *, out=None) -> Tensor Returns a new tensor with the signs of the elements of :attr:`input`. .. math:: \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -8808,10 +9966,14 @@ def merge_dicts(*dicts): tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) >>> torch.sign(a) tensor([ 1., -1., 0., 1.]) -""".format(**common_args)) - -add_docstr(torch.signbit, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.signbit, + r""" signbit(input, *, out=None) -> Tensor Tests if each element of :attr:`input` has its sign bit set or not. @@ -8834,10 +9996,14 @@ def merge_dicts(*dicts): .. note:: signbit handles signed zeros, so negative zero (-0) returns True. -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.sgn, - r""" +add_docstr( + torch.sgn, + r""" sgn(input, *, out=None) -> Tensor This function is an extension of torch.sign() to complex tensors. @@ -8852,7 +10018,8 @@ def merge_dicts(*dicts): \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} \end{cases} -""" + r""" +""" + + r""" Args: {input} @@ -8864,17 +10031,22 @@ def merge_dicts(*dicts): >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) >>> t.sgn() tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) -""".format(**common_args)) - -add_docstr(torch.sin, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.sin, + r""" sin(input, *, out=None) -> Tensor Returns a new tensor with the sine of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sin(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -8888,17 +10060,23 @@ def merge_dicts(*dicts): tensor([-0.5461, 0.1347, -2.7266, -0.2746]) >>> torch.sin(a) tensor([-0.5194, 0.1343, -0.4032, -0.2711]) -""".format(**common_args)) - -add_docstr(torch.sinc, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.sinc, + r""" sinc(input, *, out=None) -> Tensor Alias for :func:`torch.special.sinc`. -""") +""", +) -add_docstr(torch.sinh, - r""" +add_docstr( + torch.sinh, + r""" sinh(input, *, out=None) -> Tensor Returns a new tensor with the hyperbolic sine of the elements of @@ -8906,7 +10084,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \sinh(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -8925,10 +10104,14 @@ def merge_dicts(*dicts): When :attr:`input` is on the CPU, the implementation of torch.sinh may use the Sleef library, which rounds very large results to infinity or negative infinity. See `here `_ for details. -""".format(**common_args)) - -add_docstr(torch.sort, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.sort, + r""" sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) Sorts the elements of the :attr:`input` tensor along a given dimension @@ -8988,10 +10171,14 @@ def merge_dicts(*dicts): torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) -""".format(**common_args)) - -add_docstr(torch.argsort, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.argsort, + r""" argsort(input, dim=-1, descending=False, stable=False) -> Tensor Returns the indices that sort a tensor along a given dimension in ascending @@ -9025,10 +10212,14 @@ def merge_dicts(*dicts): [3, 2, 1, 0], [2, 1, 0, 3], [3, 2, 1, 0]]) -""".format(**common_args)) - -add_docstr(torch.msort, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.msort, + r""" msort(input, *, out=None) -> Tensor Sorts the elements of the :attr:`input` tensor along its first dimension @@ -9054,40 +10245,46 @@ def merge_dicts(*dicts): tensor([[-2.0527, -1.1250, -1.2631, -1.1289], [-0.1321, -0.1259, -0.5495, 0.3077], [-0.0881, 0.4370, 0.2275, 1.0284]]) -""".format(**common_args)) +""".format( + **common_args + ), +) -add_docstr(torch.sparse_compressed_tensor, - r""" -sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, - *, dtype=None, layout=None, device=None, requires_grad=False) -> Tensor +add_docstr( + torch.sparse_compressed_tensor, + r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """ + r"""*, dtype=None, layout=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, -CSC, BSR, or BSC - ` with specified values at the -given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse +CSC, BSR, or BSC - ` with specified values at +the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse matrix multiplication operations in Compressed Sparse format are typically faster than that for sparse tensors in COO format. Make you have a look at :ref:`the note on the data type of the indices -`. +`. Args: - compressed_indices (array_like): One-dimensional array of size - size[cdim] + 1 where cdim is 0 or 1 depending on the layout. - The last element is the number of non-zeros. This tensor - encodes the index in values and plain_indices depending on - where the given compressed dimension (row or column) - starts. Each successive number in the tensor subtracted by the - number before it denotes the number of elements in a given - compressed dimension. + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. plain_indices (array_like): Plain dimension (column or row) - co-ordinates of each element in values. Strictly one - dimensional tensor with the same length as values. + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, - tuple, NumPy ``ndarray``, scalar, and other types. For block - sparse formats, the dimensionality of values must be two plus - the dimensionality of plain_indices. + tuple, NumPy ``ndarray``, scalar, and other types. that + represents a (1+K)-dimensional or (1+2+K)-dimensional tensor + where ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the - sparse tensor. If not provided, the size will be inferred as - the minimum size big enough to hold all non-zero elements. + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of @@ -9116,10 +10313,14 @@ def merge_dicts(*dicts): col_indices=tensor([0, 1, 0, 1]), values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) -""".format(**factory_common_args)) - -add_docstr(torch.sparse_csr_tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_csr_tensor, + r""" sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified @@ -9128,20 +10329,24 @@ def merge_dicts(*dicts): at :ref:`the note on the data type of the indices `. Args: - crow_indices (array_like): One-dimensional array of size size[0] + 1. - The last element is the number of non-zeros. This tensor - encodes the index in values and col_indices depending on where - the given row starts. Each successive number in the tensor - subtracted by the number before it denotes the number of - elements in a given row. + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. col_indices (array_like): Column co-ordinates of each element in - values. Strictly one dimensional tensor with the same length + values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, - tuple, NumPy ``ndarray``, scalar, and other types. + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensonal tensor where ``K`` is the number + of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the - sparse tensor. If not provided, the size will be inferred as - the minimum size big enough to hold all non-zero elements. + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of @@ -9166,34 +10371,42 @@ def merge_dicts(*dicts): col_indices=tensor([0, 1, 0, 1]), values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) -""".format(**factory_common_args)) - -add_docstr(torch.sparse_csc_tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_csc_tensor, + r""" sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) -` with specified values at the given +` with specified values at the given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix multiplication operations in CSC format are typically faster than that for sparse tensors in COO format. Make you have a look at :ref:`the -note on the data type of the indices `. +note on the data type of the indices `. Args: - ccol_indices (array_like): One-dimensional array of size size[1] + 1. - The last element is the number of non-zeros. This tensor - encodes the index in values and row_indices depending on where - the given column starts. Each successive number in the tensor - subtracted by the number before it denotes the number of - elements in a given column. + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. row_indices (array_like): Row co-ordinates of each element in - values. Strictly one dimensional tensor with the same length - as values. + values. (B+1)-dimensional tensor with the same length as + values. values (array_list): Initial values for the tensor. Can be a list, - tuple, NumPy ``ndarray``, scalar, and other types. + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensonal tensor where ``K`` is the number + of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the - sparse tensor. If not provided, the size will be inferred as - the minimum size big enough to hold all non-zero elements. + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of @@ -9218,36 +10431,44 @@ def merge_dicts(*dicts): row_indices=tensor([0, 1, 0, 1]), values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) -""".format(**factory_common_args)) - -add_docstr(torch.sparse_bsr_tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_bsr_tensor, + r""" sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) -` with specified 2-dimensional blocks at the given +` with specified 2-dimensional blocks at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations in BSR format are typically faster than that for sparse tensors in COO format. Make you have a look at :ref:`the -note on the data type of the indices `. +note on the data type of the indices `. Args: - crow_indices (array_like): One-dimensional array of size size[0] + - 1. The last element is the number of non-zeros. This tensor - encodes the index in values and col_indices depending on where - the given row starts. Each successive number in the tensor + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor subtracted by the number before it denotes the number of blocks in a given row. - col_indices (array_like): Column co-ordinates of each block in - values. Strictly one dimensional tensor with the same length - as values. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. values (array_list): Initial values for the tensor. Can be a list, - tuple, NumPy ``ndarray``, scalar, and other types. The - dimensionality of values must be two plus the dimensionality - of col_indices. + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensonal tensor where ``K`` is the + number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the - sparse tensor. If not provided, the size will be inferred as - the minimum size big enough to hold all non-zero blocks. + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of @@ -9272,39 +10493,46 @@ def merge_dicts(*dicts): col_indices=tensor([0, 1]), values=tensor([[[1., 2.], [3., 4.]], - [[5., 6.], [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, layout=torch.sparse_bsr) -""".format(**factory_common_args)) - -add_docstr(torch.sparse_bsc_tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_bsc_tensor, + r""" sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse -Column)) ` with specified 2-dimensional blocks at the +Column)) ` with specified 2-dimensional blocks at the given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix multiplication operations in BSC format are typically faster than that for sparse tensors in COO format. Make you have a look at :ref:`the -note on the data type of the indices `. - -Args: - ccol_indices (array_like): One-dimensional array of size size[1] + - 1. The last element is the number of non-zeros. This tensor - encodes the index in values and row_indices depending on where - the given column starts. Each successive number in the tensor - subtracted by the number before it denotes the number of - elements in a given column. - row_indices (array_like): Row co-ordinates of each element in - values. Strictly one dimensional tensor with the same length +note on the data type of the indices `. + +Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial blocks for the tensor. Can be a list, - tuple, NumPy ``ndarray``, and other types. The dimensionality - of values must be two plus the dimensionality of row_indices. + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensonal tensor where ``K`` is the + number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the - sparse tensor. If not provided, the size will be inferred as - the minimum size big enough to hold all non-zero blocks. + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of @@ -9329,14 +10557,17 @@ def merge_dicts(*dicts): row_indices=tensor([0, 1]), values=tensor([[[1., 2.], [3., 4.]], - [[5., 6.], [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, layout=torch.sparse_bsc) -""".format(**factory_common_args)) - -add_docstr(torch.sparse_coo_tensor, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_coo_tensor, + r""" sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor Constructs a :ref:`sparse tensor in COO(rdinate) format @@ -9415,17 +10646,22 @@ def merge_dicts(*dicts): size=(1, 2), nnz=0, layout=torch.sparse_coo) .. _torch.sparse: https://pytorch.org/docs/stable/sparse.html -""".format(**factory_common_args)) - -add_docstr(torch.sqrt, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sqrt, + r""" sqrt(input, *, out=None) -> Tensor Returns a new tensor with the square-root of the elements of :attr:`input`. .. math:: \text{out}_{i} = \sqrt{\text{input}_{i}} -""" + r""" +""" + + r""" Args: {input} @@ -9439,10 +10675,14 @@ def merge_dicts(*dicts): tensor([-2.0755, 1.0226, 0.0831, 0.4806]) >>> torch.sqrt(a) tensor([ nan, 1.0112, 0.2883, 0.6933]) -""".format(**common_args)) - -add_docstr(torch.square, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.square, + r""" square(input, *, out=None) -> Tensor Returns a new tensor with the square of the elements of :attr:`input`. @@ -9460,10 +10700,14 @@ def merge_dicts(*dicts): tensor([-2.0755, 1.0226, 0.0831, 0.4806]) >>> torch.square(a) tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) -""".format(**common_args)) - -add_docstr(torch.squeeze, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.squeeze, + r""" squeeze(input, dim=None, *, out=None) -> Tensor Returns a tensor with all the dimensions of :attr:`input` of size `1` removed. @@ -9506,9 +10750,14 @@ def merge_dicts(*dicts): >>> y = torch.squeeze(x, 1) >>> y.size() torch.Size([2, 2, 1, 2]) -""".format(**common_args)) - -add_docstr(torch.std, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.std, + r""" std(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor If :attr:`unbiased` is ``True``, Bessel's correction will be used. @@ -9541,10 +10790,14 @@ def merge_dicts(*dicts): >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) >>> torch.std(a, unbiased=False) tensor(0.4188) -""".format(**multi_dim_common)) - -add_docstr(torch.std_mean, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.std_mean, + r""" std_mean(input, dim, unbiased, keepdim=False, *, out=None) -> (Tensor, Tensor) If :attr:`unbiased` is ``True``, Bessel's correction will be used to calculate @@ -9584,16 +10837,22 @@ def merge_dicts(*dicts): >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) >>> torch.std_mean(a, unbiased=False) (tensor(0.4188), tensor(-0.8509)) -""".format(**multi_dim_common)) - -add_docstr(torch.sub, r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.sub, + r""" sub(input, other, *, alpha=1, out=None) -> Tensor Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. .. math:: \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i -""" + r""" +""" + + r""" Supports :ref:`broadcasting to a common shape `, :ref:`type promotion `, and integer, float, and complex inputs. @@ -9612,16 +10871,23 @@ def merge_dicts(*dicts): >>> b = torch.tensor((0, 1)) >>> torch.sub(a, b, alpha=2) tensor([1, 0]) -""".format(**common_args)) - -add_docstr(torch.subtract, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.subtract, + r""" subtract(input, other, *, alpha=1, out=None) -> Tensor Alias for :func:`torch.sub`. -""") +""", +) -add_docstr(torch.sum, - r""" +add_docstr( + torch.sum, + r""" sum(input, *, dtype=None) -> Tensor Returns the sum of all elements in the :attr:`input` tensor. @@ -9651,7 +10917,7 @@ def merge_dicts(*dicts): Args: {input} - {dim} + {opt_dim} {keepdim} Keyword args: @@ -9670,10 +10936,14 @@ def merge_dicts(*dicts): >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) >>> torch.sum(b, (2, 1)) tensor([ 435., 1335., 2235., 3135.]) -""".format(**multi_dim_common)) - -add_docstr(torch.nansum, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.nansum, + r""" nansum(input, *, dtype=None) -> Tensor Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. @@ -9718,10 +10988,14 @@ def merge_dicts(*dicts): tensor([4., 2.]) >>> torch.nansum(a, dim=1) tensor([3., 3.]) -""".format(**multi_dim_common)) - -add_docstr(torch.svd, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.svd, + r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) Computes the singular value decomposition of either a matrix or batch of @@ -9850,9 +11124,12 @@ def merge_dicts(*dicts): .. _the resulting vectors will span the same subspace: (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) -""") +""", +) -add_docstr(torch.symeig, r""" +add_docstr( + torch.symeig, + r""" symeig(input, eigenvectors=False, upper=True, *, out=None) -> (Tensor, Tensor) This function returns eigenvalues and eigenvectors @@ -9908,7 +11185,7 @@ def merge_dicts(*dicts): input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric or Hermitian matrices. eigenvectors(bool, optional): controls whether eigenvectors have to be computed - upper(boolean, optional): controls whether to consider upper-triangular or lower-triangular region + upper(bool, optional): controls whether to consider upper-triangular or lower-triangular region Keyword args: out (tuple, optional): the output tuple of (Tensor, Tensor) @@ -9946,10 +11223,12 @@ def merge_dicts(*dicts): >>> e, v = a_big.symeig(eigenvectors=True) >>> torch.allclose(torch.matmul(v, torch.matmul(e.diag_embed(), v.mT)), a_big) True -""") +""", +) -add_docstr(torch.t, - r""" +add_docstr( + torch.t, + r""" t(input) -> Tensor Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 @@ -9983,10 +11262,14 @@ def merge_dicts(*dicts): [-0.5872, 0.6932]]) See also :func:`torch.transpose`. -""".format(**common_args)) - -add_docstr(torch.flip, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.flip, + r""" flip(input, dims) -> Tensor Reverse the order of a n-D tensor along given axis in dims. @@ -10015,10 +11298,14 @@ def merge_dicts(*dicts): [[ 2, 3], [ 0, 1]]]) -""".format(**common_args)) - -add_docstr(torch.fliplr, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.fliplr, + r""" fliplr(input) -> Tensor Flip tensor in the left/right direction, returning a new tensor. @@ -10046,10 +11333,14 @@ def merge_dicts(*dicts): >>> torch.fliplr(x) tensor([[1, 0], [3, 2]]) -""".format(**common_args)) - -add_docstr(torch.flipud, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.flipud, + r""" flipud(input) -> Tensor Flip tensor in the up/down direction, returning a new tensor. @@ -10077,10 +11368,14 @@ def merge_dicts(*dicts): >>> torch.flipud(x) tensor([[2, 3], [0, 1]]) -""".format(**common_args)) - -add_docstr(torch.roll, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.roll, + r""" roll(input, shifts, dims=None) -> Tensor Roll the tensor :attr:`input` along the given dimension(s). Elements that are @@ -10124,10 +11419,14 @@ def merge_dicts(*dicts): [8, 7], [2, 1], [4, 3]]) -""".format(**common_args)) - -add_docstr(torch.rot90, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.rot90, + r""" rot90(input, k, dims) -> Tensor Rotate a n-D tensor by 90 degrees in the plane specified by dims axis. @@ -10161,10 +11460,14 @@ def merge_dicts(*dicts): [[5, 7], [4, 6]]]) -""".format(**common_args)) - -add_docstr(torch.take, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.take, + r""" take(input, index) -> Tensor Returns a new tensor with the elements of :attr:`input` at the given indices. @@ -10181,10 +11484,14 @@ def merge_dicts(*dicts): ... [6, 7, 8]]) >>> torch.take(src, torch.tensor([0, 2, 5])) tensor([ 4, 5, 8]) -""".format(**common_args)) - -add_docstr(torch.take_along_dim, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.take_along_dim, + r""" take_along_dim(input, indices, dim, *, out=None) -> Tensor Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. @@ -10214,17 +11521,22 @@ def merge_dicts(*dicts): >>> torch.take_along_dim(t, sorted_idx, dim=1) tensor([[10, 20, 30], [40, 50, 60]]) -""".format(**common_args)) - -add_docstr(torch.tan, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.tan, + r""" tan(input, *, out=None) -> Tensor Returns a new tensor with the tangent of the elements of :attr:`input`. .. math:: \text{out}_{i} = \tan(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -10238,10 +11550,14 @@ def merge_dicts(*dicts): tensor([-1.2027, -1.7687, 0.4412, -1.3856]) >>> torch.tan(a) tensor([-2.5930, 4.9859, 0.4722, -5.3366]) -""".format(**common_args)) - -add_docstr(torch.tanh, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.tanh, + r""" tanh(input, *, out=None) -> Tensor Returns a new tensor with the hyperbolic tangent of the elements @@ -10249,7 +11565,8 @@ def merge_dicts(*dicts): .. math:: \text{out}_{i} = \tanh(\text{input}_{i}) -""" + r""" +""" + + r""" Args: {input} @@ -10263,10 +11580,14 @@ def merge_dicts(*dicts): tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) >>> torch.tanh(a) tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) -""".format(**common_args)) - -add_docstr(torch.topk, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.topk, + r""" topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) Returns the :attr:`k` largest elements of the given :attr:`input` tensor along @@ -10303,10 +11624,14 @@ def merge_dicts(*dicts): tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) -""".format(**common_args)) - -add_docstr(torch.trace, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.trace, + r""" trace(input) -> Tensor Returns the sum of the elements of the diagonal of the input 2-D matrix. @@ -10320,10 +11645,12 @@ def merge_dicts(*dicts): [ 7., 8., 9.]]) >>> torch.trace(x) tensor(15.) -""") +""", +) -add_docstr(torch.transpose, - r""" +add_docstr( + torch.transpose, + r""" transpose(input, dim0, dim1) -> Tensor Returns a tensor that is a transposed version of :attr:`input`. @@ -10354,10 +11681,14 @@ def merge_dicts(*dicts): [ 0.5809, 0.4942]]) See also :func:`torch.t`. -""".format(**common_args)) - -add_docstr(torch.triangular_solve, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.triangular_solve, + r""" triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` @@ -10426,10 +11757,12 @@ def merge_dicts(*dicts): [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]])) -""") +""", +) -add_docstr(torch.tril, - r""" +add_docstr( + torch.tril, + r""" tril(input, diagonal=0, *, out=None) -> Tensor Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices @@ -10445,7 +11778,8 @@ def merge_dicts(*dicts): the main diagonal. The main diagonal are the set of indices :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where :math:`d_{1}, d_{2}` are the dimensions of the matrix. -""" + r""" +""" + + r""" Args: {input} diagonal (int, optional): the diagonal to consider @@ -10481,12 +11815,16 @@ def merge_dicts(*dicts): [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) -""".format(**common_args)) +""".format( + **common_args + ), +) # docstr is split in two parts to avoid format mis-captureing :math: braces '{}' # as common args. -add_docstr(torch.tril_indices, - r""" +add_docstr( + torch.tril_indices, + r""" tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor Returns the indices of the lower triangular part of a :attr:`row`-by- @@ -10508,7 +11846,8 @@ def merge_dicts(*dicts): .. note:: When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to prevent overflow during calculation. -""" + r""" +""" + + r""" Args: row (``int``): number of rows in the 2-D matrix. col (``int``): number of columns in the 2-D matrix. @@ -10537,10 +11876,14 @@ def merge_dicts(*dicts): >>> a tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) -""".format(**factory_common_args)) - -add_docstr(torch.triu, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.triu, + r""" triu(input, diagonal=0, *, out=None) -> Tensor Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices @@ -10556,7 +11899,8 @@ def merge_dicts(*dicts): the main diagonal. The main diagonal are the set of indices :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where :math:`d_{1}, d_{2}` are the dimensions of the matrix. -""" + r""" +""" + + r""" Args: {input} diagonal (int, optional): the diagonal to consider @@ -10600,12 +11944,16 @@ def merge_dicts(*dicts): [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) -""".format(**common_args)) +""".format( + **common_args + ), +) # docstr is split in two parts to avoid format mis-capturing :math: braces '{}' # as common args. -add_docstr(torch.triu_indices, - r""" +add_docstr( + torch.triu_indices, + r""" triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor Returns the indices of the upper triangular part of a :attr:`row` by @@ -10627,7 +11975,8 @@ def merge_dicts(*dicts): .. note:: When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to prevent overflow during calculation. -""" + r""" +""" + + r""" Args: row (``int``): number of rows in the 2-D matrix. col (``int``): number of columns in the 2-D matrix. @@ -10656,16 +12005,23 @@ def merge_dicts(*dicts): >>> a tensor([[0, 0, 1], [1, 2, 2]]) -""".format(**factory_common_args)) - -add_docstr(torch.true_divide, r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.true_divide, + r""" true_divide(dividend, divisor, *, out) -> Tensor Alias for :func:`torch.div` with ``rounding_mode=None``. -""") +""", +) -add_docstr(torch.trunc, - r""" +add_docstr( + torch.trunc, + r""" trunc(input, *, out=None) -> Tensor Returns a new tensor with the truncated integer values of @@ -10684,10 +12040,14 @@ def merge_dicts(*dicts): tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) >>> torch.trunc(a) tensor([ 3., 0., -0., -0.]) -""".format(**common_args)) - -add_docstr(torch.fake_quantize_per_tensor_affine, - r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.fake_quantize_per_tensor_affine, + r""" fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, @@ -10721,10 +12081,12 @@ def merge_dicts(*dicts): tensor([0.1000, 1.0000, 0.4000, 0.0000]) >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) tensor([0.6000, 0.4000, 0.0000, 0.0000]) -""") +""", +) -add_docstr(torch.fake_quantize_per_channel_affine, - r""" +add_docstr( + torch.fake_quantize_per_channel_affine, + r""" fake_quantize_per_channel_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, @@ -10771,17 +12133,21 @@ def merge_dicts(*dicts): [[0.0000, 1.6134], [0.6323, 0.0000]]]) -""") +""", +) -add_docstr(torch.fix, - r""" +add_docstr( + torch.fix, + r""" fix(input, *, out=None) -> Tensor Alias for :func:`torch.trunc` -""") +""", +) -add_docstr(torch.unsqueeze, - r""" +add_docstr( + torch.unsqueeze, + r""" unsqueeze(input, dim) -> Tensor Returns a new tensor with a dimension of size one inserted at the @@ -10807,9 +12173,14 @@ def merge_dicts(*dicts): [ 2], [ 3], [ 4]]) -""".format(**common_args)) - -add_docstr(torch.var, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.var, + r""" var(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor If :attr:`unbiased` is ``True``, Bessel's correction will be used. @@ -10841,10 +12212,14 @@ def merge_dicts(*dicts): >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) >>> torch.var(a, unbiased=False) tensor(0.1754) -""".format(**multi_dim_common)) - -add_docstr(torch.var_mean, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.var_mean, + r""" var_mean(input, dim, unbiased, keepdim=False, *, out=None) -> (Tensor, Tensor) If :attr:`unbiased` is ``True``, Bessel's correction will be used to calculate @@ -10884,10 +12259,14 @@ def merge_dicts(*dicts): >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) >>> torch.var_mean(a, unbiased=False) (tensor(0.1754), tensor(-0.8509)) -""".format(**multi_dim_common)) - -add_docstr(torch.zeros, - r""" +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.zeros, + r""" zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with the scalar value `0`, with the shape defined @@ -10912,10 +12291,14 @@ def merge_dicts(*dicts): >>> torch.zeros(5) tensor([ 0., 0., 0., 0., 0.]) -""".format(**factory_common_args)) - -add_docstr(torch.zeros_like, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.zeros_like, + r""" zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor filled with the scalar value `0`, with the same size as @@ -10943,10 +12326,14 @@ def merge_dicts(*dicts): >>> torch.zeros_like(input) tensor([[ 0., 0., 0.], [ 0., 0., 0.]]) -""".format(**factory_like_common_args)) +""".format( + **factory_like_common_args + ), +) -add_docstr(torch.empty, - """ +add_docstr( + torch.empty, + """ empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, \ memory_format=torch.contiguous_format) -> Tensor @@ -10971,10 +12358,14 @@ def merge_dicts(*dicts): >>> torch.empty((2,3), dtype=torch.int64) tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) -""".format(**factory_common_args)) - -add_docstr(torch.empty_like, - r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.empty_like, + r""" empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns an uninitialized tensor with the same size as :attr:`input`. @@ -10997,10 +12388,14 @@ def merge_dicts(*dicts): >>> torch.empty_like(a) tensor([[0, 0, 0], [0, 0, 0]], device='cuda:0', dtype=torch.int32) -""".format(**factory_like_common_args)) - -add_docstr(torch.empty_strided, - r""" +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.empty_strided, + r""" empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. @@ -11010,8 +12405,8 @@ def merge_dicts(*dicts): in memory) its behavior is undefined. Args: - size (tuple of ints): the shape of the output tensor - stride (tuple of ints): the strides of the output tensor + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor Keyword args: {dtype} @@ -11030,9 +12425,14 @@ def merge_dicts(*dicts): (1, 2) >>> a.size() torch.Size([2, 3]) -""".format(**factory_common_args)) - -add_docstr(torch.full, r""" +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.full, + r""" full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The @@ -11055,10 +12455,14 @@ def merge_dicts(*dicts): >>> torch.full((2, 3), 3.141592) tensor([[ 3.1416, 3.1416, 3.1416], [ 3.1416, 3.1416, 3.1416]]) -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.full_like, - """ +add_docstr( + torch.full_like, + """ full_like(input, fill_value, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ memory_format=torch.preserve_format) -> Tensor @@ -11076,16 +12480,23 @@ def merge_dicts(*dicts): {device} {requires_grad} {memory_format} -""".format(**factory_like_common_args)) - -add_docstr(torch.det, r""" +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.det, + r""" det(input) -> Tensor Alias for :func:`torch.linalg.det` -""") +""", +) -add_docstr(torch.where, - r""" +add_docstr( + torch.where, + r""" where(condition, x, y) -> Tensor Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`. @@ -11139,17 +12550,18 @@ def merge_dicts(*dicts): .. note:: See also :func:`torch.nonzero`. -""") +""", +) -add_docstr(torch.logdet, - r""" +add_docstr( + torch.logdet, + r""" logdet(input) -> Tensor Calculates log determinant of a square matrix or batches of square matrices. -.. note:: - Result is ``-inf`` if :attr:`input` has zero log determinant, and is ``nan`` if - :attr:`input` has negative determinant. +It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has +a negative determinant. .. note:: Backward through :meth:`logdet` internally uses SVD results when :attr:`input` @@ -11157,6 +12569,11 @@ def merge_dicts(*dicts): be unstable in when :attr:`input` doesn't have distinct singular values. See :func:`torch.linalg.svd` for details. +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + Arguments: input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more batch dimensions. @@ -11181,25 +12598,34 @@ def merge_dicts(*dicts): tensor([1.1990, 0.4099, 0.7386]) >>> A.det().log() tensor([ 0.1815, -0.8917, -0.3031]) -""") +""", +) -add_docstr(torch.slogdet, r""" +add_docstr( + torch.slogdet, + r""" slogdet(input) -> (Tensor, Tensor) Alias for :func:`torch.linalg.slogdet` -""") +""", +) -add_docstr(torch.pinverse, r""" +add_docstr( + torch.pinverse, + r""" pinverse(input, rcond=1e-15) -> Tensor Alias for :func:`torch.linalg.pinv` -""") +""", +) -add_docstr(torch.hann_window, - """ +add_docstr( + torch.hann_window, + """ hann_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Hann window function. .. math:: @@ -11219,7 +12645,8 @@ def merge_dicts(*dicts): .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. -""" + r""" +""" + + r""" Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic @@ -11235,14 +12662,19 @@ def merge_dicts(*dicts): Returns: Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.hamming_window, - """ +add_docstr( + torch.hamming_window, + """ hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Hamming window function. .. math:: @@ -11264,7 +12696,8 @@ def merge_dicts(*dicts): .. note:: This is a generalized version of :meth:`torch.hann_window`. -""" + r""" +""" + + r""" Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic @@ -11282,14 +12715,19 @@ def merge_dicts(*dicts): Returns: Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.bartlett_window, - """ +add_docstr( + torch.bartlett_window, + """ bartlett_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Bartlett window function. .. math:: @@ -11311,7 +12749,8 @@ def merge_dicts(*dicts): .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. -""" + r""" +""" + + r""" Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic @@ -11327,14 +12766,19 @@ def merge_dicts(*dicts): Returns: Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.blackman_window, - """ +add_docstr( + torch.blackman_window, + """ blackman_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Blackman window function. .. math:: @@ -11353,7 +12797,8 @@ def merge_dicts(*dicts): .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. -""" + r""" +""" + + r""" Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic @@ -11369,13 +12814,19 @@ def merge_dicts(*dicts): Returns: Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.kaiser_window, """ +add_docstr( + torch.kaiser_window, + """ kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor -""" + r""" +""" + + r""" Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and @@ -11393,7 +12844,8 @@ def merge_dicts(*dicts): .. note:: If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. -""" + r""" +""" + + r""" Args: window_length (int): length of the window. periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. @@ -11407,13 +12859,18 @@ def merge_dicts(*dicts): {device} {requires_grad} -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.vander, - """ +add_docstr( + torch.vander, + """ vander(x, N=None, increasing=False) -> Tensor -""" + r""" +""" + + r""" Generates a Vandermonde matrix. The columns of the output matrix are elementwise powers of the input vector :math:`x^{{(N-1)}}, x^{{(N-2)}}, ..., x^0`. @@ -11451,11 +12908,15 @@ def merge_dicts(*dicts): [ 1, 3, 9], [ 1, 5, 25]]) -""".format(**factory_common_args)) +""".format( + **factory_common_args + ), +) -add_docstr(torch.unbind, - r""" +add_docstr( + torch.unbind, + r""" unbind(input, dim=0) -> seq Removes a tensor dimension. @@ -11472,11 +12933,13 @@ def merge_dicts(*dicts): >>> [4, 5, 6], >>> [7, 8, 9]])) (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) -""") +""", +) -add_docstr(torch.combinations, - r""" +add_docstr( + torch.combinations, + r""" combinations(input, r=2, with_replacement=False) -> seq Compute combinations of length :math:`r` of the given tensor. The behavior is similar to @@ -11486,7 +12949,7 @@ def merge_dicts(*dicts): Arguments: input (Tensor): 1D vector. r (int, optional): number of elements to combine - with_replacement (boolean, optional): whether to allow duplication in combination + with_replacement (bool, optional): whether to allow duplication in combination Returns: Tensor: A tensor equivalent to converting all the input tensors into lists, do @@ -11517,10 +12980,12 @@ def merge_dicts(*dicts): [2, 3], [3, 3]]) -""") +""", +) -add_docstr(torch.trapezoid, - r""" +add_docstr( + torch.trapezoid, + r""" trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor Computes the `trapezoidal rule `_ along @@ -11626,17 +13091,21 @@ def merge_dicts(*dicts): >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) >>> torch.trapezoid(y, x) array([2., 4., 6.]) -""") +""", +) -add_docstr(torch.trapz, - r""" +add_docstr( + torch.trapz, + r""" trapz(y, x, *, dim=-1) -> Tensor Alias for :func:`torch.trapezoid`. -""") +""", +) -add_docstr(torch.cumulative_trapezoid, - r""" +add_docstr( + torch.cumulative_trapezoid, + r""" cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor Cumulatively computes the `trapezoidal rule `_ @@ -11720,10 +13189,12 @@ def merge_dicts(*dicts): tensor([[1., 2.], [2., 4.], [3., 6.]]) -""") +""", +) -add_docstr(torch.repeat_interleave, - r""" +add_docstr( + torch.repeat_interleave, + r""" repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor Repeat elements of a tensor. @@ -11774,9 +13245,14 @@ def merge_dicts(*dicts): If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, `1` appears `n2` times, `2` appears `n3` times, etc. -""".format(**common_args)) - -add_docstr(torch.tile, r""" +""".format( + **common_args + ), +) + +add_docstr( + torch.tile, + r""" tile(input, dims) -> Tensor Constructs a tensor by repeating the elements of :attr:`input`. @@ -11814,10 +13290,12 @@ def merge_dicts(*dicts): [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]) -""") +""", +) -add_docstr(torch.quantize_per_tensor, - r""" +add_docstr( + torch.quantize_per_tensor, + r""" quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor Converts a float tensor to a quantized tensor with given scale and zero point. @@ -11848,10 +13326,12 @@ def merge_dicts(*dicts): >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) -""") +""", +) -add_docstr(torch.quantize_per_tensor_dynamic, - r""" +add_docstr( + torch.quantize_per_tensor_dynamic, + r""" quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor Converts a float tensor to a quantized tensor with scale and zero_point calculated @@ -11876,10 +13356,12 @@ def merge_dicts(*dicts): zero_point=85) >>> t.int_repr() tensor([ 0, 85, 170, 255], dtype=torch.uint8) -""") +""", +) -add_docstr(torch.quantize_per_channel, - r""" +add_docstr( + torch.quantize_per_channel, + r""" quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor Converts a float tensor to a per-channel quantized tensor with given scales and zero points. @@ -11907,11 +13389,13 @@ def merge_dicts(*dicts): >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() tensor([[ 0, 10], [100, 200]], dtype=torch.uint8) -""") +""", +) -add_docstr(torch.quantized_batch_norm, - r""" +add_docstr( + torch.quantized_batch_norm, + r""" quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor Applies batch normalization on a 4D (NCHW) quantized tensor. @@ -11950,11 +13434,13 @@ def merge_dicts(*dicts): [[ 0.6000, -0.4000], [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2) -""") +""", +) -add_docstr(torch.quantized_max_pool1d, - r""" +add_docstr( + torch.quantized_max_pool1d, + r""" quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor Applies a 1D max pooling over an input quantized tensor composed of several input planes. @@ -11979,11 +13465,13 @@ def merge_dicts(*dicts): tensor([[0.0000], [1.5000]], size=(2, 1), dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) -""") +""", +) -add_docstr(torch.quantized_max_pool2d, - r""" +add_docstr( + torch.quantized_max_pool2d, + r""" quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor Applies a 2D max pooling over an input quantized tensor composed of several input planes. @@ -12014,11 +13502,13 @@ def merge_dicts(*dicts): [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) -""") +""", +) -add_docstr(torch.Generator, - r""" +add_docstr( + torch.Generator, + r""" Generator(device='cpu') -> Generator Creates and returns a generator object that manages the state of the algorithm which @@ -12035,11 +13525,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cuda = torch.Generator(device='cuda') -""") +""", +) -add_docstr(torch.Generator.set_state, - r""" +add_docstr( + torch.Generator.set_state, + r""" Generator.set_state(new_state) -> void Sets the Generator state. @@ -12052,11 +13544,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu_other = torch.Generator() >>> g_cpu.set_state(g_cpu_other.get_state()) -""") +""", +) -add_docstr(torch.Generator.get_state, - r""" +add_docstr( + torch.Generator.get_state, + r""" Generator.get_state() -> Tensor Returns the Generator state as a ``torch.ByteTensor``. @@ -12069,11 +13563,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu.get_state() -""") +""", +) -add_docstr(torch.Generator.manual_seed, - r""" +add_docstr( + torch.Generator.manual_seed, + r""" Generator.manual_seed(seed) -> Generator Sets the seed for generating random numbers. Returns a `torch.Generator` object. @@ -12093,11 +13589,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu.manual_seed(2147483647) -""") +""", +) -add_docstr(torch.Generator.initial_seed, - r""" +add_docstr( + torch.Generator.initial_seed, + r""" Generator.initial_seed() -> int Returns the initial seed for generating random numbers. @@ -12107,11 +13605,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu.initial_seed() 2147483647 -""") +""", +) -add_docstr(torch.Generator.seed, - r""" +add_docstr( + torch.Generator.seed, + r""" Generator.seed() -> int Gets a non-deterministic random number from std::random_device or the current @@ -12122,11 +13622,13 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu.seed() 1516516984916 -""") +""", +) -add_docstr(torch.Generator.device, - r""" +add_docstr( + torch.Generator.device, + r""" Generator.device -> device Gets the current device of the generator. @@ -12136,10 +13638,12 @@ def merge_dicts(*dicts): >>> g_cpu = torch.Generator() >>> g_cpu.device device(type='cpu') -""") +""", +) -add_docstr(torch._assert_async, - r""" +add_docstr( + torch._assert_async, + r""" _assert_async(tensor) -> void Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, @@ -12154,10 +13658,12 @@ def merge_dicts(*dicts): tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero elements (including False for boolean tensors) cause an assertion failure to be raised. -""") +""", +) -add_docstr(torch.searchsorted, - r""" +add_docstr( + torch.searchsorted, + r""" searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side='left', out=None, sorter=None) -> Tensor Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the @@ -12236,10 +13742,12 @@ def merge_dicts(*dicts): >>> torch.searchsorted(sorted_sequence_1d, values) tensor([[1, 3, 4], [1, 3, 4]]) -""") +""", +) -add_docstr(torch.bucketize, - r""" +add_docstr( + torch.bucketize, + r""" bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the @@ -12289,142 +13797,189 @@ def merge_dicts(*dicts): >>> torch.bucketize(v, boundaries, right=True) tensor([[2, 3, 5], [2, 3, 5]]) -""") +""", +) -add_docstr(torch.view_as_real_copy, - r""" +add_docstr( + torch.view_as_real_copy, + r""" Performs the same operation as :func:`torch.view_as_real`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.view_as_complex_copy, - r""" +add_docstr( + torch.view_as_complex_copy, + r""" Performs the same operation as :func:`torch.view_as_complex`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.as_strided_copy, - r""" +add_docstr( + torch.as_strided_copy, + r""" Performs the same operation as :func:`torch.as_strided`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.diagonal_copy, - r""" +add_docstr( + torch.diagonal_copy, + r""" Performs the same operation as :func:`torch.diagonal`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.expand_copy, - r""" +add_docstr( + torch.expand_copy, + r""" Performs the same operation as :func:`torch.expand`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.permute_copy, - r""" +add_docstr( + torch.permute_copy, + r""" Performs the same operation as :func:`torch.permute`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.select_copy, - r""" +add_docstr( + torch.select_copy, + r""" Performs the same operation as :func:`torch.select`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.detach_copy, - r""" +add_docstr( + torch.detach_copy, + r""" Performs the same operation as :func:`torch.detach`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.slice_copy, - r""" +add_docstr( + torch.slice_copy, + r""" Performs the same operation as :func:`torch.slice`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.split_copy, - r""" +add_docstr( + torch.split_copy, + r""" Performs the same operation as :func:`torch.split`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.split_with_sizes_copy, - r""" +add_docstr( + torch.split_with_sizes_copy, + r""" Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.squeeze_copy, - r""" +add_docstr( + torch.squeeze_copy, + r""" Performs the same operation as :func:`torch.squeeze`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.t_copy, - r""" +add_docstr( + torch.t_copy, + r""" Performs the same operation as :func:`torch.t`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.transpose_copy, - r""" +add_docstr( + torch.transpose_copy, + r""" Performs the same operation as :func:`torch.transpose`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.unsqueeze_copy, - r""" +add_docstr( + torch.unsqueeze_copy, + r""" Performs the same operation as :func:`torch.unsqueeze`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.indices_copy, - r""" +add_docstr( + torch.indices_copy, + r""" Performs the same operation as :func:`torch.indices`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.values_copy, - r""" +add_docstr( + torch.values_copy, + r""" Performs the same operation as :func:`torch.values`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.crow_indices_copy, - r""" +add_docstr( + torch.crow_indices_copy, + r""" Performs the same operation as :func:`torch.crow_indices`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.col_indices_copy, - r""" +add_docstr( + torch.col_indices_copy, + r""" Performs the same operation as :func:`torch.col_indices`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.unbind_copy, - r""" +add_docstr( + torch.unbind_copy, + r""" Performs the same operation as :func:`torch.unbind`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.view_copy, - r""" +add_docstr( + torch.view_copy, + r""" Performs the same operation as :func:`torch.view`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.unfold_copy, - r""" +add_docstr( + torch.unfold_copy, + r""" Performs the same operation as :func:`torch.unfold`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) -add_docstr(torch.alias_copy, - r""" +add_docstr( + torch.alias_copy, + r""" Performs the same operation as :func:`torch.alias`, but all output tensors are freshly created instead of aliasing the input. -""") +""", +) diff --git a/torch/_utils.py b/torch/_utils.py index e3152dc528d0d..8a539d75f5657 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,10 +1,10 @@ -import torch -from typing import Optional, List, DefaultDict, Any -import warnings -from collections import defaultdict import sys import traceback +import warnings +from collections import defaultdict +from typing import Any, DefaultDict, List, Optional +import torch def _type(self, dtype=None, non_blocking=False, **kwargs): @@ -23,9 +23,9 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): **kwargs: For compatibility, may contain the key ``async`` in place of the ``non_blocking`` argument. The ``async`` arg is deprecated. """ - non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs) + non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs) if dtype is None: - return self.__module__ + '.' + self.__class__.__name__ + return self.__module__ + "." + self.__class__.__name__ if isinstance(dtype, str): dtype = _import_dotted_name(dtype) @@ -34,11 +34,13 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): if self.is_sparse: if not dtype.is_sparse: raise RuntimeError("Cannot cast sparse tensor to dense tensor") - new_module_name = dtype.__module__.replace('.sparse', '') - new_values_type_name = new_module_name + '.' + dtype.__name__ + new_module_name = dtype.__module__.replace(".sparse", "") + new_values_type_name = new_module_name + "." + dtype.__name__ new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) - new_indices_type_name = new_module_name + '.LongTensor' - new_indices = torch.Tensor._indices(self).type(new_indices_type_name, non_blocking) + new_indices_type_name = new_module_name + ".LongTensor" + new_indices = torch.Tensor._indices(self).type( + new_indices_type_name, non_blocking + ) return dtype(new_indices, new_values, self.size()) if dtype.is_sparse: raise RuntimeError("Cannot cast dense tensor to sparse tensor") @@ -59,7 +61,7 @@ def _cuda(self, device=None, non_blocking=False, **kwargs): **kwargs: For compatibility, may contain the key ``async`` in place of the ``non_blocking`` argument. """ - non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs) + non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs) if self.is_cuda: if device is None: device = torch.cuda.current_device() @@ -75,11 +77,15 @@ def _cuda(self, device=None, non_blocking=False, **kwargs): values = torch.Tensor._values(self).cuda(device, non_blocking) return new_type(indices, values, self.size()) else: - return torch._UntypedStorage(self.size(), device=torch.device('cuda')).copy_(self, non_blocking) + untyped_storage = torch.UntypedStorage( + self.size(), device=torch.device("cuda") + ) + untyped_storage.copy_(self, non_blocking) + return untyped_storage def _get_async_or_non_blocking(function_name, non_blocking, kwargs): - """ Return the non-blocking flag given the function name and kwargs. + """Return the non-blocking flag given the function name and kwargs. Args: function_name (str): the name of the function being used. @@ -88,12 +94,12 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): """ if not kwargs: return non_blocking - if len(kwargs) != 1 or 'async' not in kwargs: + if len(kwargs) != 1 or "async" not in kwargs: message = "{}() got an unexpected keyword argument '{}'" argument = list(kwargs.keys()).pop() raise TypeError(message.format(function_name, argument)) warnings.warn("'async' is deprecated; use 'non_blocking'") - return kwargs['async'] + return kwargs["async"] # Note [Don't serialize hooks] @@ -134,14 +140,16 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): # TODO: Once we decide to break serialization FC, `storage` no longer needs to -# be a _TypedStorage +# be a TypedStorage def _rebuild_tensor(storage, storage_offset, size, stride): # first construct a tensor with the correct dtype/device - t = torch.tensor([], dtype=storage.dtype, device=storage._untyped().device) - return t.set_(storage._untyped(), storage_offset, size, stride) + t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device) + return t.set_(storage.untyped(), storage_offset, size, stride) -def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): +def _rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks +): tensor = _rebuild_tensor(storage, storage_offset, size, stride) tensor.requires_grad = requires_grad # NB: This line exists only for backwards compatibility; the @@ -167,20 +175,24 @@ def _validate_loaded_sparse_tensors(): try: for t in _sparse_tensors_to_validate: if t.is_sparse: - torch._validate_sparse_coo_tensor_args(t._indices(), t._values(), - t.size()) + torch._validate_sparse_coo_tensor_args( + t._indices(), t._values(), t.size() + ) elif t.is_sparse_csr: # TODO: Validation currently involves an expensive traversal # on CPU, which may include a device transfer. - torch._validate_sparse_csr_tensor_args(t.crow_indices(), t.col_indices(), - t.values(), t.size()) + torch._validate_sparse_csr_tensor_args( + t.crow_indices(), t.col_indices(), t.values(), t.size() + ) else: raise NotImplementedError( - '_validate_loaded_sparse_tensors for layout `%s`' % (t.layout)) + "_validate_loaded_sparse_tensors for layout `%s`" % (t.layout) + ) finally: _sparse_tensors_to_validate.clear() + def _rebuild_sparse_tensor(layout, data): """ Rebuilds a sparse tensor from its sparse storage representation. @@ -197,10 +209,13 @@ def _rebuild_sparse_tensor(layout, data): raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout)) + def _rebuild_sparse_csr_tensor(layout, data): if layout == torch.sparse_csr: crow_indices, col_indices, values, size = data - result = torch._sparse_csr_tensor_unsafe(crow_indices, col_indices, values, size) + result = torch._sparse_csr_tensor_unsafe( + crow_indices, col_indices, values, size + ) _sparse_tensors_to_validate.append(result) return result @@ -218,35 +233,71 @@ def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): - return torch.empty_strided(size, stride, dtype=dtype, device='meta', requires_grad=requires_grad) + return torch.empty_strided( + size, stride, dtype=dtype, device="meta", requires_grad=requires_grad + ) -def _rebuild_wrapper_subclass(cls, dtype, size, stride, storage_offset, layout, device, requires_grad): +def _rebuild_wrapper_subclass( + cls, dtype, size, stride, storage_offset, layout, device, requires_grad +): return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, size, strides=stride, storage_offset=storage_offset, layout=layout, - device=device, requires_grad=requires_grad) + cls, + size, + strides=stride, + storage_offset=storage_offset, + layout=layout, + device=device, + requires_grad=requires_grad, + ) # TODO: Once we decide to break serialization FC, `storage` no longer needs to -# be a _TypedStorage -def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks): +# be a TypedStorage +def _rebuild_qtensor( + storage, + storage_offset, + size, + stride, + quantizer_params, + requires_grad, + backward_hooks, +): qscheme = quantizer_params[0] if qscheme == torch.per_tensor_affine: _, scale, zero_point = quantizer_params - tensor = torch._empty_affine_quantized(size, scale=scale, zero_point=zero_point, dtype=storage.dtype, device=storage.device) + tensor = torch._empty_affine_quantized( + size, + scale=scale, + zero_point=zero_point, + dtype=storage.dtype, + device=storage.device, + ) elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): _, scales, zero_points, axis = quantizer_params if type(scales) is list and type(zero_points) is list: if qscheme == torch.per_channel_affine: scales = torch.tensor(scales, dtype=torch.double, device=storage.device) - zero_points = torch.tensor(zero_points, dtype=torch.long, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.long, device=storage.device + ) else: scales = torch.tensor(scales, dtype=torch.float, device=storage.device) - zero_points = torch.tensor(zero_points, dtype=torch.float, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.float, device=storage.device + ) tensor = torch._empty_per_channel_affine_quantized( - size, scales=scales, zero_points=zero_points, axis=axis, dtype=storage.dtype, device=storage.device) + size, + scales=scales, + zero_points=zero_points, + axis=axis, + dtype=storage.dtype, + device=storage.device, + ) else: - raise RuntimeError("Can't deserialize quantized tensor with qscheme {}".format(qscheme)) + raise RuntimeError( + "Can't deserialize quantized tensor with qscheme {}".format(qscheme) + ) tensor.set_(storage, storage_offset, size, stride) tensor.requires_grad = requires_grad # NB: This line exists only for backwards compatibility; the @@ -255,6 +306,7 @@ def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, re tensor._backward_hooks = backward_hooks return tensor + def _rebuild_parameter(data, requires_grad, backward_hooks): param = torch.nn.Parameter(data, requires_grad) # NB: This line exists only for backwards compatibility; the @@ -266,7 +318,7 @@ def _rebuild_parameter(data, requires_grad, backward_hooks): def _import_dotted_name(name): - components = name.split('.') + components = name.split(".") obj = __import__(components[0]) for component in components[1:]: obj = getattr(obj, component) @@ -275,7 +327,7 @@ def _import_dotted_name(name): # Taken from python 3.5 docs def _accumulate(iterable, fn=lambda x, y: x + y): - 'Return running totals' + "Return running totals" # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 it = iter(iterable) @@ -317,8 +369,12 @@ def _flatten_sparse_tensors(tensors): A tuple of two contiguous 1D buffers, one containing input tensors' indices and the other containing the values. """ - flat_indices = torch._C._nn.flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors]) - flat_values = torch._C._nn.flatten_dense_tensors([torch.Tensor._values(t) for t in tensors]) + flat_indices = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._indices(t) for t in tensors] + ) + flat_values = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._values(t) for t in tensors] + ) return flat_indices, flat_values @@ -354,8 +410,12 @@ def _unflatten_sparse_tensors(flat, tensors): flat. """ flat_indices, flat_values = flat - indices = torch._C._nn.unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors]) - values = torch._C._nn.unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors]) + indices = torch._C._nn.unflatten_dense_tensors( + flat_indices, [torch.Tensor._indices(t) for t in tensors] + ) + values = torch._C._nn.unflatten_dense_tensors( + flat_values, [torch.Tensor._values(t) for t in tensors] + ) outputs = [] for t, i, v in zip(tensors, indices, values): outputs.append(t.new(i, v, t.size())) @@ -402,7 +462,10 @@ def _take_tensors(tensors, size_limit): if tensor.is_sparse: indices = torch.Tensor._indices(tensor) values = torch.Tensor._values(tensor) - size = indices.numel() * indices.element_size() + values.numel() * values.element_size() + size = ( + indices.numel() * indices.element_size() + + values.numel() * values.element_size() + ) else: size = tensor.numel() * tensor.element_size() buf_and_size = buf_dict[t] @@ -421,8 +484,9 @@ def _take_tensors(tensors, size_limit): def annotate(ret, **kwargs): def dec(fun): fun.__annotations__ = dict(kwargs) - fun.__annotations__['return'] = ret + fun.__annotations__["return"] = ret return fun + return dec @@ -434,14 +498,17 @@ def dec(fun): # and the frame (which holds reference to all the object in its temporary scope) # holding reference the traceback. + class KeyErrorMessage(str): r"""str subclass that returns itself in repr""" + def __repr__(self): return self class ExceptionWrapper(object): r"""Wraps an exception plus traceback to communicate across threads""" + def __init__(self, exc_info=None, where="in background"): # It is important that we don't store exc_info, see # NOTE [ Python Traceback Reference Cycle Problem ] @@ -456,7 +523,8 @@ def reraise(self): # Format a message such as: "Caught ValueError in DataLoader worker # process 2. Original Traceback:", followed by the traceback. msg = "Caught {} {}.\nOriginal {}".format( - self.exc_type.__name__, self.where, self.exc_msg) + self.exc_type.__name__, self.where, self.exc_msg + ) if self.exc_type == KeyError: # KeyError calls repr() on its argument (usually a dict key). This # makes stack traces unreadable. It will not be changed in Python @@ -508,6 +576,7 @@ def _get_devices_properties(device_ids): # all device properties return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids] + def get_current_device_index() -> int: r"""Checks if there are CUDA devices available and returns the device index of the current default CUDA device. @@ -518,7 +587,10 @@ def get_current_device_index() -> int: return torch.cuda.current_device() return -1 -def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = False) -> int: + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: r"""Gets the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. @@ -538,9 +610,9 @@ def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = Fal device = torch.device(device) device_idx: Optional[int] = None if isinstance(device, torch.device): - if not allow_cpu and device.type == 'cpu': - raise ValueError('Expected a non cpu device, but got: {}'.format(device)) - device_idx = -1 if device.type == 'cpu' else device.index + if not allow_cpu and device.type == "cpu": + raise ValueError("Expected a non cpu device, but got: {}".format(device)) + device_idx = -1 if device.type == "cpu" else device.index if isinstance(device, int): device_idx = device if device_idx is None: @@ -555,8 +627,10 @@ def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = Fal else: device_idx = _get_current_device_index() else: - raise ValueError('Expected a torch.device with a specified index ' - 'or an integer, but got:{}'.format(device)) + raise ValueError( + "Expected a torch.device with a specified index " + "or an integer, but got:{}".format(device) + ) return device_idx @@ -565,15 +639,20 @@ def _handle_complex(tensor): Returns a real view of a tensor if complex dtype else just the tensor need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule """ - return torch.view_as_real(tensor) if not isinstance(tensor, - torch.nn.UninitializedParameter) and tensor.is_complex() else tensor + return ( + torch.view_as_real(tensor) + if not isinstance(tensor, torch.nn.UninitializedParameter) + and tensor.is_complex() + else tensor + ) + def _element_size(dtype): """ Returns the element size for a dtype, in bytes """ if not isinstance(dtype, torch.dtype): - raise RuntimeError(f'expected torch.dtype, but got {type(dtype)}') + raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}") if dtype.is_complex: return torch.finfo(dtype).bits >> 2 @@ -585,6 +664,7 @@ def _element_size(dtype): else: return torch.iinfo(dtype).bits >> 3 + class _ClassPropertyDescriptor: def __init__(self, fget, fset=None): self.fget = fget @@ -594,6 +674,7 @@ def __get__(self, instance, owner=None): owner = type(instance) return self.fget.__get__(instance, owner)() + def classproperty(func): if not isinstance(func, (classmethod, staticmethod)): func = classmethod(func) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 1bbfe3297fe3b..fed4bfad77afb 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -8,17 +8,18 @@ # use is the FB build environment, where this source file is replaced # by an equivalent. -if sys.executable == 'torch_deploy': +if sys.executable == "torch_deploy": # __file__ is meaningless in the context of frozen torch used in torch deploy. # setting empty torch_parent should allow below functions to operate without crashing, # but it's unclear if there is a valid use case for them in the context of deploy. torch_parent = "" else: - if os.path.basename(os.path.dirname(__file__)) == 'shared': + if os.path.basename(os.path.dirname(__file__)) == "shared": torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) else: torch_parent = os.path.dirname(os.path.dirname(__file__)) + def get_file_path(*path_components: str) -> str: return os.path.join(torch_parent, *path_components) @@ -33,7 +34,6 @@ def get_writable_path(path: str) -> str: return tempfile.mkdtemp(suffix=os.path.basename(path)) - def prepare_multiprocessing_environment(path: str) -> None: pass @@ -42,7 +42,7 @@ def resolve_library_path(path: str) -> str: return os.path.realpath(path) -TEST_MASTER_ADDR = '127.0.0.1' +TEST_MASTER_ADDR = "127.0.0.1" TEST_MASTER_PORT = 29500 # USE_GLOBAL_DEPS controls whether __init__.py tries to load # libtorch_global_deps, see Note [Global dependencies] diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 1c87554626c3b..5bb88c06ed7d2 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,99 +1,124 @@ -import torch import functools -from torch import Tensor -from typing import Any, Callable, Optional, Tuple, Union, List -from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten import warnings +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten in_dims_t = Union[int, Tuple] out_dims_t = Union[int, Tuple[int, ...]] # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - flat_in_dims: List[Optional[int]], - flat_args: List) -> int: - batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args) - if in_dim is not None] + flat_in_dims: List[Optional[int]], flat_args: List +) -> int: + batch_sizes = [ + arg.size(in_dim) + for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None + ] if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]): raise ValueError( - f'vmap: Expected all tensors to have the same size in the mapped ' - f'dimension, got sizes {batch_sizes} for the mapped dimension') + f"vmap: Expected all tensors to have the same size in the mapped " + f"dimension, got sizes {batch_sizes} for the mapped dimension" + ) return batch_sizes[0] + def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) return 1 + # If value is a tuple, check it has length `num_elements`. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times -def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple: +def _as_tuple( + value: Any, num_elements: int, error_message_lambda: Callable[[], str] +) -> Tuple: if not isinstance(value, tuple): return (value,) * num_elements if len(value) != num_elements: raise ValueError(error_message_lambda()) return value + # Creates BatchedTensors for every Tensor in arg that should be batched. # Returns the (potentially) batched arguments and the batch_size. def _create_batched_inputs( - in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]: + in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable +) -> Tuple[Tuple, int]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' - f'expected `in_dims` to be int or a (potentially nested) tuple ' - f'matching the structure of inputs, got: {type(in_dims)}.') + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"expected `in_dims` to be int or a (potentially nested) tuple " + f"matching the structure of inputs, got: {type(in_dims)}." + ) if len(args) == 0: raise ValueError( - f'vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add ' - f'inputs, or you are trying to vmap over a function with no inputs. ' - f'The latter is unsupported.') + f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " + f"inputs, or you are trying to vmap over a function with no inputs. " + f"The latter is unsupported." + ) flat_args, args_spec = tree_flatten(args) flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) if flat_in_dims is None: raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' - f'in_dims is not compatible with the structure of `inputs`. ' - f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs ' - f'has structure {args_spec}.') + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"in_dims is not compatible with the structure of `inputs`. " + f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " + f"has structure {args_spec}." + ) for arg, in_dim in zip(flat_args, flat_in_dims): if not isinstance(in_dim, int) and in_dim is not None: raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' - f'Got in_dim={in_dim} for an input but in_dim must be either ' - f'an integer dimension or None.') + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but in_dim must be either " + f"an integer dimension or None." + ) if isinstance(in_dim, int) and not isinstance(arg, Tensor): raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' - f'Got in_dim={in_dim} for an input but the input is of type ' - f'{type(arg)}. We cannot vmap over non-Tensor arguments, ' - f'please use None as the respective in_dim') + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but the input is of type " + f"{type(arg)}. We cannot vmap over non-Tensor arguments, " + f"please use None as the respective in_dim" + ) if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()): raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' - f'Got in_dim={in_dim} for some input, but that input is a Tensor ' - f'of dimensionality {arg.dim()} so expected in_dim to satisfy ' - f'0 <= in_dim < {arg.dim()}.') + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for some input, but that input is a Tensor " + f"of dimensionality {arg.dim()} so expected in_dim to satisfy " + f"0 <= in_dim < {arg.dim()}." + ) batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] - batched_inputs = [arg if in_dim is None else - torch._add_batch_dim(arg, in_dim, vmap_level) - for in_dim, arg in zip(flat_in_dims, flat_args)] + batched_inputs = [ + arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args) + ] return tree_unflatten(batched_inputs, args_spec), batch_size + # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( - batched_outputs: Union[Tensor, Tuple[Tensor, ...]], - out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable, - allow_none_pass_through: bool = False) -> Tuple: + batched_outputs: Union[Tensor, Tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, + batch_size: int, + func: Callable, + allow_none_pass_through: bool = False, +) -> Tuple: num_outputs = _num_outputs(batched_outputs) out_dims_as_tuple = _as_tuple( - out_dims, num_outputs, - lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must ' - f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.') + out_dims, + num_outputs, + lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must " + f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.", + ) # NOTE [Ignored _remove_batch_dim, _add_batch_dim] # There is something wrong with our type bindings for functions that begin @@ -102,11 +127,20 @@ def _unwrap_batched( out_dim = out_dims_as_tuple[0] return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value] if allow_none_pass_through: - return tuple((torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) if out is not None else None) - for out, out_dim in zip(batched_outputs, out_dims_as_tuple)) + return tuple( + ( + torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) + if out is not None + else None + ) + for out, out_dim in zip(batched_outputs, out_dims_as_tuple) + ) else: - return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) - for out, out_dim in zip(batched_outputs, out_dims_as_tuple)) + return tuple( + torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) + for out, out_dim in zip(batched_outputs, out_dims_as_tuple) + ) + # Checks that `fn` returned one or more Tensors and nothing else. # NB: A python function that return multiple arguments returns a single tuple, @@ -116,26 +150,34 @@ def _validate_outputs(outputs: Any, func: Callable) -> None: if isinstance(outputs, Tensor): return if not isinstance(outputs, tuple): - raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return ' - f'Tensors, got type {type(outputs)} as the return.') + raise ValueError( + f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " + f"Tensors, got type {type(outputs)} as the return." + ) for idx, output in enumerate(outputs): if isinstance(output, Tensor): continue - raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return ' - f'Tensors, got type {type(output)} for return {idx}.') + raise ValueError( + f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " + f"Tensors, got type {type(output)} for return {idx}." + ) + def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None: if isinstance(out_dims, int): return - if not isinstance(out_dims, tuple) or \ - not all([isinstance(out_dim, int) for out_dim in out_dims]): + if not isinstance(out_dims, tuple) or not all( + [isinstance(out_dim, int) for out_dim in out_dims] + ): raise ValueError( - f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be ' - f'an int or a tuple of int representing where in the outputs the ' - f'vmapped dimension should appear.') + f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " + f"an int or a tuple of int representing where in the outputs the " + f"vmapped dimension should appear." + ) + def _get_name(func: Callable): - if hasattr(func, '__name__'): + if hasattr(func, "__name__"): return func.__name__ # Not all callables have __name__, in fact, only static functions/methods do. @@ -143,6 +185,7 @@ def _get_name(func: Callable): # examples, don't have a __name__. return repr(func) + # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, # sends those into func, and then unwraps the output BatchedTensors. Operations # on BatchedTensors perform the batched operations that the user is asking for. @@ -248,16 +291,23 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca sequences out of the box. """ warnings.warn( - 'Please use functorch.vmap instead of torch.vmap ' - '(https://github.com/pytorch/functorch). ' - 'We\'ve moved development on torch.vmap over to functorch; ' - 'functorch\'s vmap has a multitude of significant performance and ' - 'functionality improvements.', - stacklevel=2) + "Please use functorch.vmap instead of torch.vmap " + "(https://github.com/pytorch/functorch). " + "We've moved development on torch.vmap over to functorch; " + "functorch's vmap has a multitude of significant performance and " + "functionality improvements.", + stacklevel=2, + ) return _vmap(func, in_dims, out_dims) + # A version of vmap but without the initial "experimental prototype" warning -def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0, allow_none_pass_through: bool = False) -> Callable: +def _vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + allow_none_pass_through: bool = False, +) -> Callable: # The `allow_none_pass_through` argument is a temporary workaround may be removed. # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine, # which may return None if any of the inputs are unused. See the issue discussing this: @@ -267,12 +317,21 @@ def wrapped(*args): _check_out_dims_is_int_or_int_tuple(out_dims, func) vmap_level = torch._C._vmapmode_increment_nesting() try: - batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func) + batched_inputs, batch_size = _create_batched_inputs( + in_dims, args, vmap_level, func + ) batched_outputs = func(*batched_inputs) if not allow_none_pass_through: _validate_outputs(batched_outputs, func) - return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func, - allow_none_pass_through=allow_none_pass_through) + return _unwrap_batched( + batched_outputs, + out_dims, + vmap_level, + batch_size, + func, + allow_none_pass_through=allow_none_pass_through, + ) finally: torch._C._vmapmode_decrement_nesting() + return wrapped diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 22d92649775e9..fd6ce5e7693dd 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -5,6 +5,8 @@ from typing import Any, Optional from torch.types import _dtype +__all__ = ['autocast_decorator', 'autocast'] + def autocast_decorator(autocast_instance, func): @functools.wraps(func) def decorate_autocast(*args, **kwargs): @@ -168,10 +170,12 @@ def forward(self, x): (see :ref:`Working with Multiple GPUs`). Args: - device_type(string, required): Whether to use 'cuda' or 'cpu' device - enabled(bool, optional, default=True): Whether autocasting should be enabled in the region. + device_type(str, required): Whether to use 'cuda' or 'cpu' device + enabled(bool, optional): Whether autocasting should be enabled in the region. + Default: ``True`` dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16. - cache_enabled(bool, optional, default=True): Whether the weight cache inside autocast should be enabled. + cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. + Default: ``True`` """ def __init__(self, device_type : str, dtype : Optional[_dtype] = None, diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index 067a42a504b19..e1742d7ed1097 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -7,6 +7,7 @@ import torch.nn.intrinsic as nni from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr +__all__ = ['Linear'] class Linear(torch.nn.Module): r""" @@ -32,7 +33,9 @@ def __init__(self, in_features, out_features, row_block_size, col_block_size, bi qweight = torch._empty_affine_quantized([out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8) - self._packed_params = linear.LinearPackedParams(dtype) + self._packed_params = linear.LinearPackedParams(row_block_size=row_block_size, + col_block_size=col_block_size, + dtype=dtype) self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size) def _get_name(self): @@ -86,6 +89,8 @@ def bias(self): def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], row_block_size: Optional[int], col_block_size: Optional[int]) -> None: assert row_block_size is not None and col_block_size is not None + self.out_features = w.shape[0] + self.in_features = w.shape[1] self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) @classmethod diff --git a/torch/ao/nn/sparse/quantized/linear.py b/torch/ao/nn/sparse/quantized/linear.py index c57122fbf4112..81c224666fb18 100644 --- a/torch/ao/nn/sparse/quantized/linear.py +++ b/torch/ao/nn/sparse/quantized/linear.py @@ -3,23 +3,20 @@ import torch from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr +__all__ = ['LinearPackedParams', 'Linear'] + # TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430) class LinearPackedParams(torch.nn.Module): _version = 1 def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8): super().__init__() - self.prepack_op = torch.ops.sparse.qlinear_prepack - self.unpack_op = torch.ops.sparse.qlinear_unpack if dtype != torch.qint8: raise NotImplementedError("Linear prepacking only supports QINT8") self.dtype = dtype wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8) self.set_weight_bias(wq, None, row_block_size, col_block_size) - # Hack to make torch.jit.script/torch.jit.load work - # Once we have self.unpack_op working we wont need this. - self.__annotations__['bias'] = Optional[torch.Tensor] def _get_name(self): return "SparseQuantizedLinearPackedParams" @@ -28,18 +25,12 @@ def _get_name(self): def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor], row_block_size: Optional[int], col_block_size: Optional[int]) -> None: assert row_block_size is not None and col_block_size is not None - self._packed_params = self.prepack_op(weight, bias, row_block_size, col_block_size) - # TODO: We will save the original weight and bias, because the unpacking is not yet there. - self.weight = weight - self.bias = bias - self.row_block_size = row_block_size - self.col_block_size = col_block_size + self._packed_params = torch.ops.sparse.qlinear_prepack(weight, bias, row_block_size, col_block_size) @torch.jit.export def _weight_bias(self): - # TODO: The unpacking is not yet implemented - # return self.unpack_op(self._packed_params) - return self.weight, self.bias, self.row_block_size, self.col_block_size + (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack(self._packed_params) + return (weight, bias, block_sizes[0], block_sizes[1]) def forward(self, x): return x @@ -63,14 +54,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, @torch.jit.export def __getstate__(self): - qweight, bias, row_block_size, col_block_size = self._weight_bias() - return qweight, bias, row_block_size, col_block_size, self.training, self.dtype + return self._packed_params, self.training, self.dtype @torch.jit.export def __setstate__(self, state): - self.set_weight_bias(state[0], state[1], state[2], state[3]) - self.training = state[4] - self.dtype = state[5] + (self._packed_params, self.training, self.dtype) = state def __repr__(self): return self._weight_bias().__repr__() @@ -99,7 +87,9 @@ def __init__(self, in_features, out_features, row_block_size, col_block_size, bi qweight = torch._empty_affine_quantized([out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8) - self._packed_params = LinearPackedParams(dtype) + self._packed_params = LinearPackedParams(row_block_size=row_block_size, + col_block_size=col_block_size, + dtype=dtype) self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size) self.scale = 1.0 self.zero_point = 0 diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 2a54535678b27..143265a8f1749 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -467,13 +467,14 @@ def prepare_model_outputs( qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) float_module.qconfig = qconfig_debug # type: ignore[assignment] - prepare(float_module, inplace=True, allow_list=allow_list) + prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}) q_module.qconfig = qconfig_debug # type: ignore[assignment] prepare( q_module, inplace=True, allow_list=allow_list, observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + prepare_custom_config_dict={} ) diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index fc53a24fc53cc..7bcd0bab10cb3 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -310,6 +310,16 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: set([ nn.Softmax, ]), + # PReLU + set([ + nn.PReLU, + nnq.PReLU, + ]), + # F.prelu + set([ + F.prelu, + toq.prelu, + ]), ] # for each floating point op, add versions of the op added by @@ -468,6 +478,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: operator.mul, torch.mul, torch.sum, + F.prelu, ]) FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set() @@ -488,6 +499,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: toq.layer_norm, toq.leaky_relu, toq.dropout, + toq.prelu, # TODO(future PR): implement shadowing for binary ops and # uncomment below # toq.add, @@ -568,6 +580,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nn.SiLU, nn.Mish, nn.Softmax, + nn.PReLU, nni.BNReLU2d, nni.BNReLU3d, nni.ConvReLU1d, @@ -613,6 +626,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nnq.EmbeddingBag, nnq.Dropout, nnq.Softmax, + nnq.PReLU, nniq.BNReLU2d, nniq.BNReLU3d, nniq.ConvReLU1d, diff --git a/torch/ao/quantization/backend_config/__init__.py b/torch/ao/quantization/backend_config/__init__.py index e5ce1734d1799..1f62594a722a4 100644 --- a/torch/ao/quantization/backend_config/__init__.py +++ b/torch/ao/quantization/backend_config/__init__.py @@ -1,5 +1,7 @@ -from .tensorrt import get_tensorrt_backend_config_dict +from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig from .native import get_native_backend_config_dict +from .observation_type import ObservationType +from .tensorrt import get_tensorrt_backend_config_dict # TODO: add more validations def validate_backend_config_dict(backend_config_dict): diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py new file mode 100644 index 0000000000000..aea6945aca466 --- /dev/null +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -0,0 +1,390 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.ao.quantization.backend_config.observation_type import ObservationType +from torch.ao.quantization.observer import _PartialWrapper +from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", +] + + +# DTypeConfig dict keys +INPUT_DTYPE_DICT_KEY = "input_dtype" +OUTPUT_DTYPE_DICT_KEY = "output_dtype" +WEIGHT_DTYPE_DICT_KEY = "weight_dtype" +BIAS_DTYPE_DICT_KEY = "bias_dtype" +IS_DYNAMIC_DICT_KEY = "is_dynamic" + +# BackendConfig dict keys +NAME_DICT_KEY = "name" +CONFIGS_DICT_KEY = "configs" + +# BackendPatternConfig dict keys +PATTERN_DICT_KEY = "pattern" +OBSERVATION_TYPE_DICT_KEY = "observation_type" +DTYPE_CONFIGS_DICT_KEY = "dtype_configs" +ROOT_MODULE_DICT_KEY = "root_module" +QAT_MODULE_DICT_KEY = "qat_module" +REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" +FUSED_MODULE_DICT_KEY = "fused_module" +FUSER_METHOD_DICT_KEY = "fuser_method" +ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" +EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" +NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" +INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" +INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed" +OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize" +OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer" + + +@dataclass +class DTypeConfig: + """ + Config for the set of supported input/output activation, weight, and bias data types for the + patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`. + """ + input_dtype: Optional[torch.dtype] = None + output_dtype: Optional[torch.dtype] = None + weight_dtype: Optional[torch.dtype] = None + bias_dtype: Optional[torch.dtype] = None + is_dynamic: Optional[bool] = None + + @classmethod + def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig: + """ + Create a `DTypeConfig` from a dictionary with the following items (all optional): + + "input_dtype": torch.dtype + "output_dtype": torch.dtype + "weight_dtype": torch.dtype + "bias_type": torch.dtype + "is_dynamic": bool + """ + input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None) + output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None) + weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None) + bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None) + is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None) + return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this `DTypeConfig` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. + """ + dtype_config_dict: Dict[str, Any] = {} + if self.input_dtype is not None: + dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype + if self.output_dtype is not None: + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype + if self.weight_dtype is not None: + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype + if self.bias_dtype is not None: + dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype + if self.is_dynamic is not None: + dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic + return dtype_config_dict + + +class BackendConfig: + # TODO: refer to NativeBackendConfig once that is implemented + """ + Config that defines the set of patterns that can be quantized on a given backend, and how reference + quantized models can be produced from these patterns. + + A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph + of the above. Each pattern supported on the target backend can be individually configured through + :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: + (1) The supported input/output activation, weight, and bias data types + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + (3) (Optionally) Fusion, QAT, and reference module mappings. + + The format of the patterns is described in: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md + + Example usage:: + + import torch + from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType + from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 + + weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_type=torch.float) + + linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.nn.qat.Linear) \ + .set_reference_quantized_module(torch.nn.quantized._reference.Linear) + + conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d)) + + backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) + """ + def __init__(self, name: str = ""): + self.name = name + self.configs: Dict[Pattern, BackendPatternConfig] = {} + + def set_name(self, name: str) -> BackendConfig: + """ + Set the name of the target backend. + """ + self.name = name + return self + + def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: + """ + Set the config for an op that can be run on the target backend. + This overrides any existing config for the given op. + """ + self.configs[config.pattern] = config + return self + + @classmethod + def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig: + """ + Create a `BackendConfig` from a dictionary with the following items: + + "name": the name of the target backend + "configs": a list of dictionaries that each represents a `BackendPatternConfig` + """ + for dict_key in [NAME_DICT_KEY, CONFIGS_DICT_KEY]: + if dict_key not in backend_config_dict: + raise ValueError("backend_config_dict must contain '%s'" % dict_key) + conf = cls(backend_config_dict[NAME_DICT_KEY]) + for d in backend_config_dict[CONFIGS_DICT_KEY]: + if isinstance(d, BackendPatternConfig): + conf.set_backend_pattern_config(d) + elif isinstance(d, Dict): + conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) + else: + raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this `BackendConfig` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. + """ + return { + NAME_DICT_KEY: self.name, + CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()], + } + + +class BackendPatternConfig: + """ + Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`. + + The user can configure how a operator pattern graph is handled on a given backend using the following methods: + `set_observation_type`: sets how observers should be inserted for this pattern. + See :class:`~torch.ao.quantization.backend_config.ObservationType` + `add_dtype_config`: add a set of supported data types for this pattern + `set_root_module`: sets the module that represents the root for this pattern + `set_qat_module`: sets the module that represents the QAT implementation for this pattern + `set_reference_quantized_module`: sets the module that represents the reference quantized + implementation for this pattern's root module. + `set_fused_module`: sets the module that represents the fused implementation for this pattern + `set_fuser_method`: sets the function that specifies how to fuse the pattern for this pattern + + For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. + """ + def __init__(self, pattern: Pattern): + self.pattern = pattern + self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + self.dtype_configs: List[DTypeConfig] = [] + self.root_module: Optional[torch.nn.Module] = None + self.qat_module: Optional[torch.nn.Module] = None + self.reference_quantized_module: Optional[torch.nn.Module] = None + self.fused_module: Optional[torch.nn.Module] = None + self.fuser_method: Optional[Callable] = None + + # Temporary/internal configs + self._root_node_getter: Optional[Callable] = None + self._extra_inputs_getter: Optional[Callable] = None + self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {} + self._input_type_to_index: Dict[str, int] = {} + self._input_output_observed: Optional[bool] = None + self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None + self._overwrite_output_observer: Optional[_PartialWrapper] = None + + def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig: + """ + Set how observers should be inserted for this pattern. + """ + self.observation_type = observation_type + return self + + def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: + """ + Register a set of supported input/output activation, weight, and bias data types for this pattern. + """ + self.dtype_configs.append(dtype_config) + return self + + def set_root_module(self, root_module: torch.nn.Module) -> BackendPatternConfig: + """ + Set the module that represents the root for this pattern. + For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`. + """ + self.root_module = root_module + return self + + def set_qat_module(self, qat_module: torch.nn.Module) -> BackendPatternConfig: + """ + Set the module that represents the QAT implementation for this pattern. + """ + self.qat_module = qat_module + return self + + def set_reference_quantized_module(self, reference_quantized_module: torch.nn.Module) -> BackendPatternConfig: + """ + Set the module that represents the reference quantized implementation for this pattern's root module. + """ + self.reference_quantized_module = reference_quantized_module + return self + + def set_fused_module(self, fused_module: torch.nn.Module) -> BackendPatternConfig: + """ + Set the module that represents the fused implementation for this pattern. + """ + self.fused_module = fused_module + return self + + def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: + """ + Set the function that specifies how to fuse the pattern for this pattern. + """ + self.fuser_method = fuser_method + return self + + def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: + self._root_node_getter = root_node_getter + return self + + def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig: + self._extra_inputs_getter = extra_inputs_getter + return self + + def _set_num_tensor_args_to_observation_type( + self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig: + self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type + return self + + def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig: + self._input_type_to_index = input_type_to_index + return self + + def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig: + self._input_output_observed = input_output_observed + return self + + def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig: + self._overwrite_output_fake_quantize = overwrite_output_fake_quantize + return self + + def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig: + self._overwrite_output_observer = overwrite_output_observer + return self + + @classmethod + def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig: + """ + Create a `BackendPatternConfig` from a dictionary with the following items: + + "pattern": the pattern being configured + "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig`s + "root_module": a :class:`torch.nn.Module` that represents the root for this pattern + "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern + "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized + implementation for this pattern's root module. + "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern + "fuser_method": a function that specifies how to fuse the pattern for this pattern + """ + def _get_dtype_config(obj: Any) -> DTypeConfig: + """ + Convert the given object into a `DTypeConfig` if possible, else throw an exception. + """ + if isinstance(obj, DTypeConfig): + return obj + if isinstance(obj, Dict): + return DTypeConfig.from_dict(obj) + raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" % + (DTYPE_CONFIGS_DICT_KEY, type(obj))) + + if PATTERN_DICT_KEY not in backend_pattern_config_dict: + raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY) + conf = cls(backend_pattern_config_dict[PATTERN_DICT_KEY]) + if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: + conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY]) + for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): + conf.add_dtype_config(_get_dtype_config(d)) + conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) + conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)) + conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)) + conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)) + conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)) + conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)) + conf._set_num_tensor_args_to_observation_type( + backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {})) + conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})) + conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None)) + conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None)) + conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None)) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this `BackendPatternConfig` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. + """ + backend_pattern_config_dict: Dict[str, Any] = { + PATTERN_DICT_KEY: self.pattern, + OBSERVATION_TYPE_DICT_KEY: self.observation_type, + DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], + } + if self.root_module is not None: + backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module + if self.qat_module is not None: + backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module + if self.reference_quantized_module is not None: + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module + if self.fused_module is not None: + backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module + if self.fuser_method is not None: + backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method + if self._root_node_getter is not None: + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter + if self._extra_inputs_getter is not None: + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter + if len(self._num_tensor_args_to_observation_type) > 0: + backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type + if len(self._input_type_to_index) > 0: + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index + if self._input_output_observed is not None: + backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed + if self._overwrite_output_fake_quantize is not None: + backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize + if self._overwrite_output_observer is not None: + backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer + return backend_pattern_config_dict diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py index 4a952ec6f545d..bf259b875aaeb 100644 --- a/torch/ao/quantization/backend_config/native.py +++ b/torch/ao/quantization/backend_config/native.py @@ -10,14 +10,11 @@ _get_share_qparams_op_configs, ) from .observation_type import ObservationType -from ..observer import ( - default_fixed_qparams_range_0to1_observer, - default_fixed_qparams_range_neg1to1_observer, -) from ..fake_quantize import FixedQParamsFakeQuantize from ..fuser_method_mappings import ( reverse_sequential_wrapper2, ) +from ..qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER # =================== # | DTYPE CONFIGS | @@ -109,6 +106,7 @@ def _get_default_op_backend_config(op, dtype_configs): torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.Dropout, + torch.nn.PReLU, torch.nn.functional.elu, torch.nn.functional.hardswish, torch.nn.functional.instance_norm, @@ -119,21 +117,7 @@ def _get_default_op_backend_config(op, dtype_configs): def _get_fixed_qparams_op_configs(dtype_configs): fixed_qparams_op_configs = [] - for fixed_qparam_op, output_observer in [ - (torch.nn.Hardsigmoid, default_fixed_qparams_range_0to1_observer), - (torch.nn.functional.hardsigmoid, default_fixed_qparams_range_0to1_observer), - ("hardsigmoid", default_fixed_qparams_range_0to1_observer), - ("hardsigmoid_", default_fixed_qparams_range_0to1_observer), - (torch.nn.Sigmoid, default_fixed_qparams_range_0to1_observer), - (torch.sigmoid, default_fixed_qparams_range_0to1_observer), - ("sigmoid", default_fixed_qparams_range_0to1_observer), - ("sigmoid_", default_fixed_qparams_range_0to1_observer), - (torch.nn.Tanh, default_fixed_qparams_range_neg1to1_observer), - (torch.tanh, default_fixed_qparams_range_neg1to1_observer), - ("tanh", default_fixed_qparams_range_neg1to1_observer), - ("tanh_", default_fixed_qparams_range_neg1to1_observer), - (torch.nn.Softmax, default_fixed_qparams_range_0to1_observer), - ]: + for fixed_qparam_op, output_observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): fixed_qparams_op_configs.append({ "pattern": fixed_qparam_op, "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, diff --git a/torch/ao/quantization/backend_config/observation_type.py b/torch/ao/quantization/backend_config/observation_type.py index be394eaed43f8..9a25f1dbc70f1 100644 --- a/torch/ao/quantization/backend_config/observation_type.py +++ b/torch/ao/quantization/backend_config/observation_type.py @@ -1,5 +1,7 @@ from enum import Enum +__all__ = ['ObservationType'] + class ObservationType(Enum): # this means input and output are observed with different observers, based # on qconfig.activation diff --git a/torch/ao/quantization/experimental/APoT_tensor.py b/torch/ao/quantization/experimental/APoT_tensor.py index e3a6b7dcf4dba..f780e20415414 100644 --- a/torch/ao/quantization/experimental/APoT_tensor.py +++ b/torch/ao/quantization/experimental/APoT_tensor.py @@ -1,14 +1,14 @@ import torch -from torch import Tensor +from torch.ao.quantization.experimental.quantizer import APoTQuantizer # class to store APoT quantized tensor -class TensorAPoT(torch.Tensor): - @staticmethod - def quantize_APoT(tensor2quantize: Tensor) -> Tensor: - raise NotImplementedError +class TensorAPoT(): + quantizer: APoTQuantizer + data: torch.Tensor - def dequantize(self) -> Tensor: - raise NotImplementedError + def __init__(self, quantizer: APoTQuantizer, apot_data: torch.Tensor): + self.quantizer = quantizer + self.data = apot_data - def q_apot_alpha(self) -> float: - raise NotImplementedError + def int_repr(self): + return self.data diff --git a/torch/ao/quantization/experimental/apot_utils.py b/torch/ao/quantization/experimental/apot_utils.py index bac48c85d5d89..ad7a7bed1fbe6 100644 --- a/torch/ao/quantization/experimental/apot_utils.py +++ b/torch/ao/quantization/experimental/apot_utils.py @@ -5,10 +5,16 @@ import math -r"""Converts floating point input into int4 APoT2 number +r"""Converts floating point input into APoT number based on quantization levels """ -def float_to_apot(x, levels, indices): +def float_to_apot(x, levels, indices, alpha): + # clip values based on alpha + if x < -alpha: + return -alpha + elif x > alpha: + return alpha + levels_lst = list(levels) indices_lst = list(indices) @@ -27,7 +33,7 @@ def float_to_apot(x, levels, indices): reduced precision floating point value based on quantization levels """ -def float_to_reduced_precision(x, levels, indices): +def quant_dequant_util(x, levels, indices): levels_lst = list(levels) indices_lst = list(indices) @@ -42,7 +48,7 @@ def float_to_reduced_precision(x, levels, indices): return best_fp -r"""Converts int4 APoT2 input into floating point number +r"""Converts APoT input into floating point number based on quantization levels """ def apot_to_float(x_apot, levels, indices): diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py new file mode 100644 index 0000000000000..7541106a61c8d --- /dev/null +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -0,0 +1,38 @@ +import torch +from torch import Tensor +from torch.ao.quantization.experimental.observer import APoTObserver +from torch.ao.quantization.fake_quantize import FakeQuantizeBase +from torch.ao.quantization.experimental.fake_quantize_function import fake_quantize_function + +class APoTFakeQuantize(FakeQuantizeBase): + alpha: Tensor + gamma: Tensor + quantization_levels: Tensor + level_indices: Tensor + + def __init__(self, observer=APoTObserver, **observer_kwargs): + super().__init__() + self.activation_post_process = observer(**observer_kwargs) + self.dtype = self.activation_post_process.dtype + + def calculate_qparams(self, signed=False): # type: ignore[override] + return self.activation_post_process.calculate_qparams(signed=signed) + + def forward(self, X: torch.Tensor): # type: ignore[override] + if self.observer_enabled[0] == 1: + self.activation_post_process.forward(X) + result = self.activation_post_process.calculate_qparams(signed=False) + self.alpha = result[0] + self.gamma = result[1] + self.quantization_levels = result[2] + self.level_indices = result[3] + + if self.fake_quant_enabled[0] == 1: + assert (self.alpha is not None + and self.gamma is not None + and self.quantization_levels is not None + and self.level_indices is not None), "Must set qparams for fake quant" + + X = fake_quantize_function.apply(X, self.alpha, self.gamma, self.quantization_levels, self.level_indices) + + return X diff --git a/torch/ao/quantization/experimental/fake_quantize_function.py b/torch/ao/quantization/experimental/fake_quantize_function.py new file mode 100644 index 0000000000000..cac01fd8c0027 --- /dev/null +++ b/torch/ao/quantization/experimental/fake_quantize_function.py @@ -0,0 +1,27 @@ +import torch +from torch import Tensor +from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT + +class fake_quantize_function(torch.autograd.Function): + @staticmethod + def forward(ctx, # type: ignore[override] + x: Tensor, + alpha: Tensor, + gamma: Tensor, + quantization_levels: Tensor, + level_indices: Tensor) -> Tensor: + quantized_result = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices) + + # calculate mask tensor + mask = x.detach().apply_(lambda x: (x <= alpha and x >= -alpha)) + + result = dequantize_APoT(quantized_result) + + ctx.save_for_backward(mask) + + return result + + @staticmethod + def backward(ctx, grad_output: Tensor) -> Tensor: # type: ignore[override] + mask = ctx.saved_tensors + return grad_output * mask diff --git a/torch/ao/quantization/experimental/linear.py b/torch/ao/quantization/experimental/linear.py new file mode 100644 index 0000000000000..cc98ffd08dc34 --- /dev/null +++ b/torch/ao/quantization/experimental/linear.py @@ -0,0 +1,161 @@ +import torch +import numpy as np + +from torch.nn.quantized.modules.utils import WeightedQuantizedModule +from torch.ao.quantization.experimental.observer import APoTObserver +from torch.ao.quantization.experimental.quantizer import quantize_APoT + +class LinearAPoT(WeightedQuantizedModule): + r""" + A quantized linear module with quantized tensor as inputs and outputs + to support APoT quantization. + We adopt the same interface as `torch.nn.Linear`, see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`~torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + alpha: `alpha` qparam of output Quantized Tensor, type: Tensor + gamma: `gamma` qparam of output Quantized Tensor, type: Tensor + quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor + level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor + weight: APoT quantized tensor from weight2quantize + weight_transposed: transposed weight tensor, used in linear transformation calculation (y = x * A^T + b) + """ + + def __init__(self, weight2quantize: torch.Tensor, b: int, k: int): + assert weight2quantize.dim() == 2 + assert b % k == 0 + + super().__init__() + + self.b = b + self.k = k + self.n = self.b // self.k + + observer = APoTObserver(b=self.b, k=self.k) + + observer(weight2quantize) + + self.alpha, self.gamma, self.quantization_levels, self.level_indices = observer.calculate_qparams(signed=False) + + quantized_weight = quantize_APoT(weight2quantize, self.alpha, self.gamma, self.quantization_levels, self.level_indices) + self.weight = quantized_weight.data + self.weight_transposed = torch.transpose(self.weight, 0, 1) + + def decompose_APoT(self, x): + r""" + Decompose binary representation of APoT values into list of k-sized blocks + Args: + x (Tensor): binary representation of APoT quantized tensor + """ + # remove "0b" prefix from binary representation + x = x[2:] + + # initialize list of blocks + blocks = [] + + while x: + blocks.append(x[0:self.k]) + x = x[self.k:] + + return blocks + + def bitshift_mul(self, weight_val, r): + r""" + Compute multiplication of weight_val * r using bitshifting + method discussed in APoT paper: https://arxiv.org/pdf/1909.13144.pdf + Args: + weight_val: list of binary digits representing APoT quantized weight value + r: int representing uniformly quantized activation value + """ + product = 0 + + idx = len(weight_val) - 1 + place = 0 + + while idx >= 0: + block = weight_val[idx] + + # reverse digits in block + block = block[::-1] + + curr_block_result = 0 + + for ele in block: + if int(ele): + curr_block_result += r << place + place += 1 + + idx -= 1 + product += curr_block_result + + return product + + + def matmul(self, decomposed_weight, activation): + r""" + Perform matrix multiplication between decomposed_weight and + activation by calling bitshift_mul function for each value + Args: + decomposed_weight (Tensor): APoT quantized weight decomposed into binary + activation (Tensor): uniformly quantized activation + """ + rows1 = activation.size(dim=0) + cols1 = activation.size(dim=1) + + rows2 = decomposed_weight.shape[0] + cols2 = decomposed_weight.shape[1] + + result = torch.zeros(rows1, cols2) + + # compute matrix multiplication with bitshifts + for i in range(rows1): + for j in range(cols2): + for k in range(rows2): + weight_val = decomposed_weight[k][j] + r = int(activation[i][k]) + + product = self.bitshift_mul(weight_val, r) + + result[i][j] += product + + return result + + def forward(self, activation: torch.Tensor) -> torch.FloatTensor: + r""" + Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8) + with bitshifting instead of matrix multiplication. + Result has dtype torch.float32 + Args: + activation (Tensor): uniformly quantized activation tensor + """ + assert activation.dim() == 2 + + weight_rows = self.weight_transposed.size()[0] + weight_cols = self.weight_transposed.size()[1] + + decomposed_weight = np.empty(shape=(weight_rows, weight_cols), dtype=object) + for row in range(weight_rows): + for col in range(weight_cols): + decomposed_weight[row][col] = self.decompose_APoT(bin(self.weight_transposed[row][col])) + + rows1 = self.weight_transposed.size(dim=0) + cols1 = self.weight_transposed.size(dim=1) + + rows2 = activation.size(dim=0) + cols2 = activation.size(dim=1) + + result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor) + + return result + + @classmethod + def from_reference(cls, # type: ignore[override] + ref_qlinear, + alpha: torch.Tensor, + gamma: torch.Tensor, + quantization_levels: torch.Tensor, + level_indices: torch.Tensor): + raise NotImplementedError diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 85313f646ce67..244975bdd4caf 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -13,40 +13,48 @@ # when more than one non-uniform method is implemented class APoTObserver(ObserverBase): - max_val: float b: int k: int n: int - alpha: float - gamma: float - level_indices: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor def __init__( self, - max_val, b, k, dtype=torch.quint8) -> None: super().__init__(dtype) - self.max_val = max_val self.b = b self.k = k + self.min_val = torch.tensor([]) + self.max_val = torch.tensor([]) + + # min_val and max_val are optional args to override + # the min_val and max_val observed by forward def calculate_qparams(self, signed): - return self._calculate_qparams(signed) + return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: https://arxiv.org/pdf/1909.13144.pdf. Arg: signed: specifies whether to include signed values in quantization level calculations + min_val: optional arg that can override min_val internal attribute + max_val: optional arg that can override max_val internal attribute Returns: gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range quantization_levels: non-uniform quantization levels (fp representation) level_indices: int representation of quantization_levels indices """ - def _calculate_qparams(self, signed): + def _calculate_qparams(self, signed: bool, min_val=None, max_val=None): + if min_val is not None: + self.min_val = min_val + if max_val is not None: + self.max_val = max_val + # compute alpha - self.alpha = self.max_val + alpha = torch.max(-self.min_val, self.max_val) # check for valid inputs of b, k assert(self.k and self.k != 0) @@ -62,7 +70,7 @@ def _calculate_qparams(self, signed): for i in range(0, self.n): p_curr = torch.tensor([0]) - for j in range(0, 2 ** (self.k - 1) + 1): + for j in range(0, (2 ** self.k - 2) + 1): curr_ele = 2 ** (- (i + j * self.n)) p_append = torch.tensor([curr_ele]) p_curr = torch.cat((p_curr, p_append)) @@ -90,7 +98,7 @@ def _calculate_qparams(self, signed): p_sum += float(tens[1]) # assign gamma - self.gamma = self.alpha / p_sum + gamma = alpha / p_sum # calculate cartesian product cartesian_product = list(itertools.product(*p_all)) @@ -99,21 +107,30 @@ def _calculate_qparams(self, signed): # calculate sum of each row for row in cartesian_product: - sum = 0 + sum = 0.0 for ele in row: sum += ele quantization_levels_list.append(sum) - quantization_levels_gamma = [self.gamma * ele for ele in quantization_levels_list] + quantization_levels_gamma = [float(gamma) * ele for ele in quantization_levels_list] quantization_levels = torch.tensor(quantization_levels_gamma) level_indices = torch.tensor([]) - quantization_levels, self.level_indices = quantization_levels.sort() + quantization_levels, level_indices = quantization_levels.sort() - return (self.gamma, quantization_levels, self.level_indices) + return (alpha, gamma, quantization_levels, level_indices) + r"""Records the running minimum and maximum of ``x``.""" def forward(self, x_orig): - r"""Records the running maximum of ``x``.""" - max_val = self.max_val + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + min_val, max_val = torch.aminmax(x) + if self.min_val.numel(): + min_val = torch.min(min_val, self.min_val) + if self.max_val.numel(): + max_val = torch.max(max_val, self.max_val) + self.min_val = min_val + self.max_val = max_val return x_orig def quant_levels_visualization(self, obs_result, filename): diff --git a/torch/ao/quantization/experimental/qconfig.py b/torch/ao/quantization/experimental/qconfig.py new file mode 100644 index 0000000000000..f9397d18d6f27 --- /dev/null +++ b/torch/ao/quantization/experimental/qconfig.py @@ -0,0 +1,46 @@ +import torch +from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization import MinMaxObserver +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize + +""" +Default symmetric fake_quant for activations. +""" +default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, + dtype=torch.quint8) + +""" +Default symmetric fake_quant for weights. +""" +default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, + dtype=torch.qint8) + +# uniform activation and weight, b=8 k=2 +uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant, + weight=default_weight_symmetric_fake_quant.with_args) + +# uniform activation, APoT weight, b=8 k=2 +apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args, + weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8)) + +# APoT activation and uniform weight, b=8 k=2 +apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8), + weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8)) + +# uniform activation and weight, b=4 k=2 +uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0, + quant_max=15), + weight=default_weight_symmetric_fake_quant.with_args(quant_min=0, + quant_max=15)) + +# uniform activation, APoT weight, b=4 k=2 +apot_weight_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0, + quant_max=15), + weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8)) + +# APoT activation and uniform weight, b=4 k=2 +apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8), + weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8)) diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py new file mode 100644 index 0000000000000..1d8845cd2b654 --- /dev/null +++ b/torch/ao/quantization/experimental/quantizer.py @@ -0,0 +1,136 @@ +import torch +from torch import Tensor +import numpy as np +from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util + +# class to store APoT quantizer and +# implement quantize and dequantize +class APoTQuantizer(): + alpha: torch.Tensor + gamma: torch.Tensor + quantization_levels: torch.Tensor + level_indices: torch.Tensor + + def __init__( + self, + alpha: torch.Tensor, + gamma: torch.Tensor, + quantization_levels: torch.Tensor, + level_indices: torch.Tensor) -> None: + self.alpha = alpha + self.gamma = gamma + self.quantization_levels = quantization_levels + self.level_indices = level_indices + + r""" Quantizes fp Tensor to integer APoT representation. + Conversion is based on the qparams from a specified APoT non-uniform observer. + The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf. + Args: + tensor2quantize: fp Tensor + Returns: + result: APoT Tensor representation of tensor2quantize + """ + def quantize(self, tensor2quantize: Tensor): + result = torch.tensor([]) + + # map float_to_apot over tensor2quantize elements + tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x, + self.quantization_levels, + self.level_indices, + self.alpha)) + + # convert to APoT int representation for dtype + tensor2quantize = tensor2quantize.int() + + from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT + + result = TensorAPoT(self, tensor2quantize) + + return result + + r""" Dequantizes integer Tensor to floating point (fp) representation + based on the calculated quantization levels from a specified APoT non-uniform observer. + The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf. + Args: + tensor2quantize: fp Tensor + Returns: + result: fp reduced precision representation of input Tensor + """ + def dequantize(self, apot_tensor) -> Tensor: + orig_size = apot_tensor.data.size() + apot_tensor_data = apot_tensor.data.flatten() + + print(apot_tensor_data) + + # map apot_to_float over tensor2quantize elements + result_temp = np.empty(shape=apot_tensor_data.size()) + for i in range(len(apot_tensor_data)): + new_ele = apot_to_float(apot_tensor_data[i], self.quantization_levels, self.level_indices) + result_temp[i] = new_ele + + result = torch.from_numpy(result_temp).reshape(orig_size) + + return result + + r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision) + based on the calculated quantization levels from a specified APoT non-uniform observer. + The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf. + Args: + apot_tensor: quantized APoT Tensor to dequantize + Returns: + result: fp representation of input Tensor + """ + def quant_dequant(self, tensor2quantize: Tensor) -> Tensor: + levels_lst = list(self.quantization_levels) + + result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) + + return result + + def q_apot_alpha(self) -> float: + raise NotImplementedError + +r""" Global method to create quantizer and call quantizer quantize_APoT + Args: + tensor2quantize: fp Tensor to quantize + alpha: Tensor qparam alpha (clipping level) + gamma: Tensor qparam gamma (scale factor for quantization levels) + quantization levels: Tensor with fp quantization levels + level indices: Tensor with integer quantization level indices + Returns: + result: ApoT Tensor representation of tensor2quantize +""" +def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor): + quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices) + result = quantizer.quantize(tensor2quantize) + return result + +r""" Global method to create quantizer and call quantizer dequantize_APoT + Args: + apot_tensor: APoT Tensor to dequantize + Returns: + result: fp Tensor dequantized from apot_tensor +""" +def dequantize_APoT(apot_tensor) -> Tensor: + quantizer = apot_tensor.quantizer + result = quantizer.dequantize(apot_tensor) + return result + +r""" Global method to create quantizer and call quantizer quant_dequant + Args: + tensor2quantize: fp Tensor to quantize + alpha: Tensor qparam alpha (clipping level) + gamma: Tensor qparam gamma (scale factor for quantization levels) + quantization levels: Tensor with fp quantization levels + level indices: Tensor with integer quantization level indices + Returns: + result: fp reduced precision Tensor from tensor2quantize +""" +def quant_dequant_APoT(tensor2quantize: Tensor, + alpha: Tensor, + gamma: Tensor, + quantization_levels: Tensor, + level_indices: Tensor) -> Tensor: + quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices) + result = quantizer.quant_dequant(tensor2quantize) + return result diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 41fbb366934ef..02fdc76ff6ac5 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -83,6 +83,7 @@ def is_default_node(node, modules): torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.Dropout, + torch.nn.PReLU, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.intrinsic.BNReLU2d, @@ -238,6 +239,7 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigA nn.LayerNorm: nnq.LayerNorm, nn.Dropout: nnq.Dropout, nn.Softmax: nnq.Softmax, + nn.PReLU: nnq.PReLU, nni.BNReLU2d: nniq.BNReLU2d, nni.BNReLU3d: nniq.BNReLU3d, } diff --git a/torch/ao/quantization/fx/_model_report/README.md b/torch/ao/quantization/fx/_model_report/README.md new file mode 100644 index 0000000000000..03c0118655f03 --- /dev/null +++ b/torch/ao/quantization/fx/_model_report/README.md @@ -0,0 +1,147 @@ +ModelReport +======== + +## Model Report Class in Fx Workflow + + > ⚠️ *While the example below uses the Fx Workflow, the use of the ModelReport class **does not depend** on the Fx Workflow to work*. + The requirements are detector dependent. + Most detectors require a **traceable GraphModule**, but some (ex. `PerChannelDetector`) require just a `nn.Module`. + +#### Typical Fx Workflow +- Initialize model → Prepare model → Callibrate model → Convert model → ... + +#### Fx Workflow with ModelReport +- Initialize model → Prepare model → **Add detector observers** → Callibrate model → **Generate report** → **Remove detector observers** → Convert model → ... + + > ⚠️ **You can only prepare and remove observers once with a given ModelReport Instance**: Be very careful here! + +## Usage + +This snippet should be ready to copy, paste, and use with the exception of a few small parts denoted in `#TODO` comments + +```python +# prep model +q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping() # alternatively use your own qconfig mapping if you alredy have one +model = Model() # TODO define model +example_input = model.get_example_data()[0] # get example data for your model +prepared_model = quantize_fx.prepare_fx(model, q_config_mapping, example_input) + +# create ModelReport instance and insert observers +detector_set = set([PerChannelDetector(), InputWeightDetector(0.5), DynamicStaticDetector(), OutlierDetector()]) # TODO add all desired detectors +model_report = ModelReport(prepared_model, detector_set) +ready_for_callibrate = model_report.prepare_detailed_callibration() + +# TODO run callibration of model with relavent data + +# generate reports for your model and remove observers if desired +reports = model_report.generate_model_report(remove_inserted_observers=True) +for report_name in report.keys(): + text_report, report_dict = reports[report_name] + print(text_report, report_dict) + +# TODO update q_config_mapping based on feedback from reports +``` + +There is a tutorial in the works that will walk through a full usage of the ModelReport API. +This tutorial will show the ModelReport API being used on toy model in both an Fx Graph Mode workflow and an alterative workflow with just a traceable model. +This README will be updated with a link to the tutorial upon completion of the tutorial. + +# Key Modules Overview + +## ModelReport Overview + +The `ModelReport` class is the primary class the user will be interacting with in the ModelReport workflow. +There are three primary methods to be familiar with when using the ModelReport class: + +- `__init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase])` constructor that takes in instances of the model we wish to generate report for (must be traceable GraphModule) and desired detectors and stores them. +This is so that we can keep track of where we want to insert observers on a detector by detector basis and also keep track of which detectors to generate reports for. +- `prepare_detailed_calibration(self)` → `GraphModule` inserts observers into the locations specified by each detector in the model. +It then returns the GraphModule with the detectors inserted into both the regular module structure as well as the node structure. +- `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses callibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: + - A string-based report that is easily digestable and actionable explaining the data collected by relavent observers for that detector + - A dictionary containing statistics collected by the relavent observers and values calculated by the detector for futher analysis or plotting + +## Detector Overview + +The main way to add functionality to the ModelReport API is to add more Detectors. +Detectors each have a specific focus in terms of the type of information they collect. +For example, the `DynamicStaticDetector` figures out whether Dynamic or Static Quantization is appropriate for different layers. +Meanwhile, the `InputWeightEqualizationDetector` determines whether Input-Weight Equalization should be applied for each layer. + + +### Requirements to Implement A Detector +All Detectors inherit from the `DetectorBase` class, and all of them (including any custom detectors you create) will need to implement 3 methods: +- `determine_observer_insert_points(self, model)` -> `Dict`: determines which observers you want to insert into a model to gather statistics and where in the model. +All of them return a dictionary mapping unique observer fully qualified names (fqns), which is where we want to insert them, to a dictionary of location and argument information in the format: + +```python +return_dict = { + "[unique_observer_fqn_of_insert_location]" : + { + "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node), + "insert_observer" -> the intialized observer we wish to insert (ObserverBase), + "insert_post" -> True if this is meant to be a post-observer for target_node, False if pre-observer, + "observer_args" -> The arguments that are meant to be passed into the observer, + } +} +``` +- `get_detector_name(self)` -> `str`: returns the name of the detector. +You should give your detector a unique name different from exisiting detectors. +- `generate_detector_report(self, model)` -> `Tuple[str, Dict[str, Any]]`: generates a report based on the information the detector is trying to collect. +This report consists of both a text-based report as well as a dictionary of collected and calculated statistics. +This report is returned to the `ModelReport` instance, which will then compile all the reports of all the Detectors requested by the user. + +## ModelReportObserver Overview + +As seen in the [requirments to implement a detector section](#requirements-to-implement-a-detector), one of the key parts of implementing a detector is to specify what `Observer` we are trying to insert. +All the detectors in the ModelReport API use the [`ModelReportObserver`](https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/_model_report/model_report_observer.py). +While the core purpose of many observers in PyTorch's Quantization API is to collect min / max information to help determine quantization parameters, the `ModelReportObserver` collects additional statistics. + +The statistics collected by the `ModelReportObserver` include: +- Average batch activation range +- Epoch level activation range +- Per-channel min / max values +- Ratio of 100th percentile to some *n*th percentile +- Number of constant value batches to pass through each channel + +After the `ModelReportObserver` collects the statistics above during the callibration process, the detectors then extract the information they need to generate their reports from the relavent observers. + +### Using Your Own Observer + +If you wish to implement your own custom Observer to use with the ModelReport API for your own custom detector, there are a few things to keep in mind. +- Make sure your detector inherits from [`torch.ao.quantization.observer.ObserverBase`](https://www.internalfb.com/code/fbsource/[20eb160510847bd24bf21a5b95092c160642155f]/fbcode/caffe2/torch/ao/quantization/observer.py?lines=122) +- In the custom detector class, come up with a descriptive and unique `PRE_OBSERVER_NAME` (and/or `POST_OBSERVER_NAME`) so that you can generate a fully qualified name (fqn) for each observer that acts a key in the returned dictionary described [here](#requirements-to-implement-a-detector) + - [Code Example](https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/_model_report/detector.py#L958) +- In the `determine_observer_insert_points()` method in your detector, initialize your custom Observer and add it to the returned dictionary described [here](#requirements-to-implement-a-detector) + - [Code Example](https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/_model_report/detector.py#L1047) + +Since you are also implementing your own detector in this case, it is up to you to determine where your observers should be placed in the model, and what type of information you wish to extract from them to generate your report. + +# Folder Structure + +./: the main folder all the model report code is under +- `__init__.py`: File to mark ModelReport as package directory +- `detector.py`: File containing Detector classes + - Contains `DetectorBase` class which all detectors inherit from + - Contains several implemented detectors including: + - `PerChannelDetector` + - `DynamicStaticDetector` + - `InputWeightEqualizationDetector` + - `OutlierDetector` +- `model_report_observer.py`: File containing the `ModelReportObserver` class + - Primary observer inserted by Detectors to collect necessary information to generate reports +- `model_report.py`: File containing the `ModelReport` class + - Main class users are interacting with to go through the ModelReport worflow + - API described in detail in [Overview section](#modelreport-overview) + +# Tests + +Tests for the ModelReport API are found in the `test_model_report_fx.py` file found [here](https://github.com/pytorch/pytorch/blob/master/test/quantization/fx/test_model_report_fx.py). + +These tests include: +- Test class for the `ModelReportObserver` +- Test class for the `ModelReport` class +- Test class for **each** of the implemented Detectors + +If you wish to add a Detector, make sure to create a test class modeled after one of the exisiting classes and test your detector. +Because users will be interacting with the Detectors through the `ModelReport` class and not directly, ensure that the tests follow this as well. diff --git a/torch/ao/sparsity/experimental/__init__.py b/torch/ao/quantization/fx/_model_report/__init__.py similarity index 100% rename from torch/ao/sparsity/experimental/__init__.py rename to torch/ao/quantization/fx/_model_report/__init__.py diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 48bdf4abc2113..56ed9ff367ff2 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Set, Tuple, Union +from typing import Any, Dict, Set, Tuple, Callable, List import torch import torch.nn as nn @@ -7,7 +7,15 @@ from torch.ao.quantization.fake_quantize import FakeQuantize from torch.ao.quantization.fx.graph_module import GraphModule from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization.quantize import is_activation_post_process + +# Names for observer insert keys +DETECTOR_TARGET_NODE_KEY = "target_node" +DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert" +DETECTOR_IS_POST_OBS_KEY = "is_post_observer" +DETECTOR_OBS_ARGS_KEY = "observer_args" # Adding base class for detectors class DetectorBase(ABC): @@ -25,14 +33,13 @@ def __init__(self): super().__init__() @abstractmethod - def determine_observer_insert_points(self, model) -> Union[None, Tuple[Set[str], Any]]: + def determine_observer_insert_points(self, model) -> Dict: r""" Args model (nn.Module or subclass): model to find observer insertion points - Returns a Tuple of two elements: - Set[str] of observer fqns denoting where to insert observers - ObserverBase (or subclass): the class (not an instance) of the observer to insert + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict. + This dict maps string keys to detector specific information """ pass @@ -41,6 +48,41 @@ def get_detector_name(self) -> str: r""" Returns the name of the current detector """ pass + def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn. + + If it's not found, it means it is most likely inside a fused layer + We just go one layer up in terms of the fqn we are searching for until we find parent node + If we get to empty string, then we know that it doesn't exist + + The reason for the recursion is that if the model that we are looking for got fused, + we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module, + which would have fqn as x.linear so they will not match. + To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear, + or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module + even in cases with fusion + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + target_fqn (str): The fqn of the layer we are trying to target + + Returns the node object we are trying to add observers around + """ + for node in prepared_fx_model.graph.nodes: + # if the node's target is our target, return it + if node.target == target_fqn: + return node + + # getting here means node not found + # if no "." we are already at base and failed + parent_fqn_sep_index = target_fqn.rfind(".") + if parent_fqn_sep_index == -1: + raise ValueError("passed in target_fqn not found in graph's targets.") + else: + # recursively call it with parent fqn + return self._get_targeting_node(prepared_fx_model, target_fqn[:parent_fqn_sep_index]) + @abstractmethod def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]: r""" @@ -65,6 +107,11 @@ class PerChannelDetector(DetectorBase): Default value is current torch.backends.quantized.engine """ + # Keys for return dictionary + BACKEND_KEY = "backend" + PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported" + PER_CHAN_USED_KEY = "per_channel_quantization_used" + # Default map for representing supported per channel quantization modules for different backends DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = { "fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), @@ -72,7 +119,7 @@ class PerChannelDetector(DetectorBase): "onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), } - def __init__(self, backend=torch.backends.quantized.engine): + def __init__(self, backend: str = torch.backends.quantized.engine): super().__init__() # store the backend information @@ -87,14 +134,16 @@ def get_detector_name(self) -> str: r""" returns the string name of this detector""" return "per_channel_detector" - def determine_observer_insert_points(self, model): + def determine_observer_insert_points(self, model: nn.Module) -> Dict: r""" - There is no observers inserted for the PerChannelDetector + There is no observers inserted for the PerChannelDetector. + + Returns an empty dictionary since no observers are added or needed """ - raise NotImplementedError("No observers are inserted in the PerChannelDetector.") + return {} - def _detect_per_channel_helper(self, model: nn.Module, per_channel_info: Dict): + def _detect_per_channel_helper(self, model: nn.Module): r""" determines if per_channel quantization is supported in modules and submodules. @@ -106,13 +155,11 @@ def _detect_per_channel_helper(self, model: nn.Module, per_channel_info: Dict): Returns dictionary mapping fqns to if per_channel quantization is possible """ - for named_mod in model.named_modules(): - - # get the fully qualified name and check if in list of modules to include and list of modules to ignore - fqn, module = named_mod + # create dict we will return + per_channel_info: Dict = {} - # asserts for MyPy - assert isinstance(fqn, str) and isinstance(per_channel_info["per_channel_status"], dict) + # get the fully qualified name and check if in list of modules to include and list of modules to ignore + for fqn, module in model.named_modules(): is_in_include_list = sum(list(map(lambda x: isinstance(module, x), self.supported_modules))) > 0 @@ -145,9 +192,10 @@ def _detect_per_channel_helper(self, model: nn.Module, per_channel_info: Dict): else: raise ValueError("Should be either observer or fake quant") - per_channel_info["per_channel_status"][fqn] = { - "per_channel_supported": per_channel_supported, - "per_channel_used": per_channel_used, + per_channel_info[fqn] = { + self.PER_CHAN_SUPPORTED_KEY: per_channel_supported, + self.PER_CHAN_USED_KEY: per_channel_used, + self.BACKEND_KEY: self.backend_chosen } return per_channel_info @@ -169,22 +217,16 @@ def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any if it is being utilized in the current model """ - # store information on submodules and if per_channel quantization is supported and used as well as qconfig information - per_channel_info = {"backend": self.backend_chosen, "per_channel_status": {}} - # run the helper function to populate the dictionary - per_channel_info = self._detect_per_channel_helper(model, per_channel_info) + per_channel_info = self._detect_per_channel_helper(model) # String to let the user know of further optimizations further_optims_str = "Further Optimizations for backend {}: \n".format(self.backend_chosen) - # assert for MyPy check - assert isinstance(per_channel_info["per_channel_status"], dict) - optimizations_possible = False - for fqn in per_channel_info["per_channel_status"]: - fqn_dict = per_channel_info["per_channel_status"][fqn] - if fqn_dict["per_channel_supported"] and not fqn_dict["per_channel_used"]: + for fqn in per_channel_info: + fqn_dict = per_channel_info[fqn] + if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]: optimizations_possible = True further_optims_str += "Module {module_fqn} can be configured to use per_channel quantization.\n".format( module_fqn=fqn @@ -223,20 +265,28 @@ class DynamicStaticDetector(DetectorBase): DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer" # naming conventions for stationary vs non-stationary data - DEFAULT_STATIONARY = "stationary" - DEFAULT_NON_STATIONARY = "non-stationary" + STATIONARY_STR = "stationary" + NON_STATIONARY_STR = "non-stationary" + + # naming for activation + INPUT_ACTIVATION_PREFIX = "input_activation_" + OUTPUT_ACTIVATION_PREFIX = "output_activation_" # naming conventions for the keys of the return module info - DEFAULT_TOLERANCE_KEY = "tolerance" + TOLERANCE_KEY = "dynamic_static_tolerance" DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended" - DEFAULT_PRE_OBS_COMP_STAT_KEY = "pre_observer_comp_stat" - DEFAULT_POST_OBS_COMP_STAT_KEY = "post_observer_comp_stat" - DEFAULT_PRE_OBS_DATA_DIST_KEY = "pre_observer_data_dist" - DEFAULT_POST_OBS_DATA_DIST_KEY = "post_observer_data_dist" + PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported" # modules that are supported both dynamic and static for this report function DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear]) + # modules that will be supported soon for both + DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = set([nn.Conv1d, nn.Conv2d, nn.Conv3d]) + def __init__(self, tolerance=0.5): super().__init__() @@ -244,14 +294,88 @@ def __init__(self, tolerance=0.5): self.tolerance = tolerance self.useful_observer_fqns: Set[str] = set([]) - def determine_observer_insert_points(self, model) -> Tuple[Set[str], Any]: + def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]: + r""" + Determines where observers need to be inserted for the Dynamic vs Static detector. + For this detector, we want to place observers on either side of linear layers in the model. + + Currently inserts observers for: + linear layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: Dict[str, Dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # make sure module is supported + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args + } + + # add entry for post-observer + post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME - raise NotImplementedError("Will be implemented in a future commit") + obs_fqn_to_info[post_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: True, + DETECTOR_OBS_ARGS_KEY: (targeted_node,) + } + + return obs_fqn_to_info def get_detector_name(self) -> str: r""" returns the string name of this detector""" return "dynamic_vs_static_detector" + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0 + + # check if it will be supported + future_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED))) > 0 + + # supported + supported = is_supported_type or future_supported_type + + # this is check for observer insertion + if insert: + return supported + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME) + return supported and has_obs + def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]: r""" Helper function for generate_detector_report that does the generation of the dictionary. @@ -267,6 +391,7 @@ def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]: their S metric of output of module whether output of module is stationary or non-stationary the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future """ # store modules dynamic vs static information module_dynamic_static_info = {} @@ -277,16 +402,8 @@ def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]: # loop through all submodules included nested ones for fqn, module in model.named_modules(): - - # check to see if module is of a supported type - is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0 - # if module is Linear has the ModelReportObserver attached to it - if ( - is_supported_type - and hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) - and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME) - ): + if self._is_supported(module): # get pre and post observers for the module pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME) @@ -300,17 +417,21 @@ def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]: dynamic_recommended = post_stat <= self.tolerance # specify the classifications for whether data distributions considered stationary or non-stationary - pre_obs_dist_classif = self.DEFAULT_STATIONARY if pre_stat > self.tolerance else self.DEFAULT_NON_STATIONARY - post_obs_dist_classif = self.DEFAULT_STATIONARY if post_stat > self.tolerance else self.DEFAULT_NON_STATIONARY + pre_obs_dist_classif = self.STATIONARY_STR if pre_stat > self.tolerance else self.NON_STATIONARY_STR + post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR + + # check if current support or future support + is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0 # store the set of important information for this module module_info = { - self.DEFAULT_TOLERANCE_KEY: self.tolerance, + self.TOLERANCE_KEY: self.tolerance, self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended, - self.DEFAULT_PRE_OBS_COMP_STAT_KEY: pre_stat, - self.DEFAULT_PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif, - self.DEFAULT_POST_OBS_COMP_STAT_KEY: post_stat, - self.DEFAULT_POST_OBS_DATA_DIST_KEY: post_obs_dist_classif, + self.PRE_OBS_COMP_STAT_KEY: pre_stat, + self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif, + self.POST_OBS_COMP_STAT_KEY: post_stat, + self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif, + self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type, } module_dynamic_static_info[fqn] = module_info @@ -345,6 +466,7 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A their S metric of output of module whether output of module is stationary or non-stationary the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future """ # get the dictionary of the information to format the string report @@ -352,21 +474,26 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n" + modules_added: bool = False # check to make sure at least 1 module added. + + dynamic_benefit = " You will get more accurate results if you use dynamic quantization" + static_benefit = " You can increase model efficiency if you use static quantization" + future_support_str = ". This layer is not yet supported for dynamic quantization" # This for loop goes through the information collected in module_dynamic_static_info and: # Populates the string based report with the information from module_dynamic_static_info # Compiles the complete report by appending relavent formatted strings for module_fqn in module_dynamic_static_info.keys(): + # there is at least 1 module for suggestion + modules_added = True module_info = module_dynamic_static_info[module_fqn] suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n" # decide what string formatting values will be quantization_type = "" - quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}." - dynamic_benefit = " You will get more accurate results if you use dynamic quantization" - static_benefit = " You can increase model efficiency if you use static quantization" + benefit_str = "" # strings for if dynamic quantized per tensor is needed @@ -374,13 +501,16 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A rec_lay_to_add = "dynamic quantize per tensor layer" dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add) dynamic_per_tensor_reasoning_string = ( - " This is because the input to this module has a non-stationary distribution." + " This is because the input to this module has a non-stationary distribution" ) # start composing explanation if module_info[self.DEFAULT_DYNAMIC_REC_KEY]: quantization_type = "dynamic" + # check if currently supported or future supported benefit_str = dynamic_benefit + if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]: + benefit_str += future_support_str else: quantization_type = "static" benefit_str = static_benefit @@ -388,7 +518,7 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A # now set the quantization explanation string quantization_reasoning = ( quantization_reasoning.format( - module_fqn, module_info[self.DEFAULT_PRE_OBS_DATA_DIST_KEY], module_info[self.DEFAULT_POST_OBS_DATA_DIST_KEY] + module_fqn, module_info[self.PRE_OBS_DATA_DIST_KEY], module_info[self.POST_OBS_DATA_DIST_KEY] ) + benefit_str ) @@ -396,8 +526,8 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A # if we have a non-stationary input -> linear -> stationary we suggested static # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made if ( - module_info[self.DEFAULT_PRE_OBS_DATA_DIST_KEY] == self.DEFAULT_NON_STATIONARY - and module_info[self.DEFAULT_POST_OBS_DATA_DIST_KEY] == self.DEFAULT_STATIONARY + module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR + and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR ): quantization_reasoning = ( quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string @@ -411,5 +541,793 @@ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, A # append to overall suggestion dynamic_vs_static_string += module_suggestion_string + if not modules_added: + dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n" + # return the string as well as the dictionary of information return (dynamic_vs_static_string, module_dynamic_static_info) + + +class InputWeightEqualizationDetector(DetectorBase): + r""" + Determines whether input-weight equalization can help improve quantization for certain modules. + + Specifically, this list of modules includes: + linear + conv + + Determines whether input-weight equalization is recommended based on the comp stat: + s_c = sqrt(w_c/W)/sqrt(i_c/I) + where: + w_c is range of weight for channel c, W is range of weight over all channels + i_c is range of input for channel c, I is range of input over all channels + + if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization + + Args: + ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is sugggested + Should be between 0 and 1 (both non-inclusive) + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is sugggested + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization + + * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + SUPPORTED_MODULES: Set[Callable] = set( + [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d] + ) + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # weight / activation prefix for each of the below info + WEIGHT_PREFIX = "weight_" + ACTIVATION_PREFIX = "input_activation_" + + # string names for keys of info dictionaries + PER_CHANNEL_MAX_KEY = "per_channel_max" + PER_CHANNEL_MIN_KEY = "per_channel_min" + GLOBAL_MAX_KEY = "global_max" + GLOBAL_MIN_KEY = "global_min" + + # keys for return dict of recommendations + RECOMMENDED_KEY = "input_weight_equalization_recommended" + COMP_METRIC_KEY = "input_weight_channel_comparison_metrics" + THRESHOLD_KEY = "input_weight_threshold" + CHANNEL_KEY = "input_weight_channel_axis" + + # default weight and info strings + WEIGHT_STR = "weight" + INPUT_STR = "input" + + # default for what ratio we recommend input weight + DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.5 + + def __init__(self, ratio_threshold: float, ch_axis: int = 1): + # ensure passed in inputs are valid + if ratio_threshold <= 0 or ratio_threshold >= 1: + raise ValueError("Make sure threshold is > 0 and < 1") + + # intialize attributes based on args + self.ratio_threshold: float = ratio_threshold + self.ch_axis: int = ch_axis + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = sum(list(map(lambda x: type(module) is x, self.SUPPORTED_MODULES))) > 0 + + # this is check for observer insertion + if insert: + return is_supported_type + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + return is_supported_type and has_obs + + def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]: + r"""Determines where observers need to be inserted for the Input Weight Equalization Detector. + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + linear layers + conv layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: Dict[str, Dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "input_weight_equalization_detector" + + def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]: + r""" + Takes in a callibrated GraphModule and then finds the relevant observers. + It then extracts the input information for each observer returns it + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relavent module fqns (str) to a dict with keys: + "input_activation_per_channel_max" : maps to the per_channel max values + "input_activation_per_channel_min" : maps to the per_channel min values + "input_activation_global_max" : maps to the global max recorded + "input_activation_global_min" : maps to the global min recorded + """ + + # return dictionary mapping observer fqns to desired info + input_info: Dict[str, Dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # get pre observer for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + input_info[fqn] = { + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val, + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val, + self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val), + self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val), + } + + return input_info + + def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]: + r""" + Takes in a callibrated GraphModule and then finds the relavent observers. + It then extracts the weight information for each layer an observer is attached to. + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping module fqns (str) to a dict with keys: + "per_channel_max" : maps to the per_channel max values + "per_channel_min" : maps to the per_channel min values + "global_max" : maps to the global max recorded + "global_min" : maps to the global min recorded + """ + # return dictionary mapping observer fqns to desired info + weight_info: Dict[str, Dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # we don't need actual observer, just the module weights + # calculate min and max vals + min_val: torch.Tensor = torch.tensor([float('inf')]) + max_val: torch.Tensor = torch.tensor([float('-inf')]) + x_copy = module.weight + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + weight_info[fqn] = { + self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val, + self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val, + self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val), + self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val), + } + + return weight_info + + def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor: + r""" + Takes in an info dict and calculates the s_c matrix. + + Args: + info_dict (dict): A dictionary of either input or weight range info + info_str (str): A str describing whether currently looking at weight or input info + Either "weight" or "input" + module_fqn (str): The fqn of the module we are looking at + + Returns a tensor of values, where each value is the s_c stat for a different channel + """ + # calculate the ratios of the info + # get the prefix str + prefix_str = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX + + per_channel_range = info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY] + global_range = info_dict[prefix_str + self.GLOBAL_MAX_KEY] - info_dict[prefix_str + self.GLOBAL_MIN_KEY] + + if global_range == 0: + range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information." + raise ValueError( + "The range of the {} data for module {} is 0, which means you have a constant value channel. {}".format( + info_str, module_fqn, range_zero_explanation + ) + ) + + ratio = per_channel_range / global_range + + return ratio + + def _generate_comparision_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]: + r""" + Takes in the information on the min and max values of the inputs and weights and: + Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I) + + Args: + input_info (dict): A dict mapping each observer to input range information + weight_info (dict): A dict mapping each observer to weight range information + + Returns a dict mapping relavent observer fqns (str) to a 1-D tensor. + Each value is a different s_c value for a different channel + """ + # create return dictionary for each observer + module_fqn_to_channel: Dict[str, torch.Tensor] = {} + + # for each module (both passed in dicts should have same keys) + for module_fqn in input_info: + + # raise error if not in weight info + if module_fqn not in weight_info: + raise KeyError("Unable to find weight range stats for module {}".format(module_fqn)) + + # calculate the ratios of the weight info and input info + weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn) + input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn) + + # if mismatched size, because of grouping, we want to replicate weight enough times + weight_channels = len(weight_ratio) + input_channels = len(input_ratio) + if weight_channels != input_channels: + # we try to replicate + assert input_channels % weight_channels == 0, "input channels should be divisible by weight channels." + # get replication factor + rep_factor: int = input_channels // weight_channels + + # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n + weight_ratio = weight_ratio.repeat(rep_factor) + + # calculate the s metric per channel + s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio) + module_fqn_to_channel[module_fqn] = s + + # return compiled observer ratios + return module_fqn_to_channel + + def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + input_info (dict): A dict mapping each module to input range information + weight_info (dict): A dict mapping each module to weight range information + comp_stats (dict): A dict mapping each module to its corresponding comp stat + + Returns a dictionary mapping each module with relavent ModelReportObservers around them to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + # store modules input weight equalization info + input_weight_equalization_info: Dict[str, Dict] = {} + + # for each module we add seperate set of suggestions + for module_fqn in input_info: + + # get relavent info for this module + mod_input_info: Dict = input_info[module_fqn] + mod_weight_info: Dict = weight_info[module_fqn] + mod_comp_stat: Dict = comp_stats[module_fqn] + + # decide if each channel should have input weight equalization or not + channel_rec_vals: list = [] + + for val in mod_comp_stat: + float_rep: float = val.item() + + # decide if recommending input weight equalization + recommended: bool = float_rep >= self.ratio_threshold and float_rep <= 1 / self.ratio_threshold + channel_rec_vals.append(recommended) + + # build the return dict input + # also unpack input and weight dicts into it + input_weight_equalization_info[module_fqn] = { + self.RECOMMENDED_KEY: channel_rec_vals, + self.COMP_METRIC_KEY: mod_comp_stat, + self.THRESHOLD_KEY: self.ratio_threshold, + self.CHANNEL_KEY: self.ch_axis, + **mod_input_info, + **mod_weight_info, + } + + # return our compiled info for each module + return input_weight_equalization_info + + def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records per channel information of input range + It then uses the passed in weight info inconjunction to compute the desired ratio + Finally, it gives suggestions based on this information for each module of interest + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + weight_info (Dict): Maps modules of interest to information on their weights to be analyzed + + Returns a tuple with two elements: + String report of of whether input weight equalization is recommended for certain modules + Dictionary mapping modules of interest to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + + # find the range of inputs + input_values: Dict[str, Dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: Dict[str, Dict] = self._extract_weight_info(model) + + # calculate per_channel comparision statistic s_c + comp_stats: Dict[str, torch.Tensor] = self._generate_comparision_values(input_values, weight_values) + + # generate the return dictionary + input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats) + + # now we can generate report based on this information + input_weight_string = "Input-Weight Equalization suggestions: \n" + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tWe suggest {} input weight equalization because {}\n" + use_str = "to use" + no_use_str = "to not use" + input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error." + input_weight_non_benefit_reasoning = "{}/{} channels benefitting from input-weight equalization being applied." + input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}" + + # added module check + added_module: bool = False + + # compile the suggestion string + for module_fqn in input_weight_equalization_info: + # we added at least 1 module + added_module = True + # add the module level description + input_weight_string += module_suggestion_str.format(module_fqn, self.ch_axis) + + mod_info: Dict[str, Any] = input_weight_equalization_info[module_fqn] + + # gather info on how many channels would benefit from input weight and + recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY] + num_recs = sum(recommendation_per_channel) + + if num_recs / len(recommendation_per_channel) >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO: + input_benefit_formatted = input_weight_benefit_str.format(num_recs, len(recommendation_per_channel)) + channel_str = channel_suggestion_str.format(use_str, input_benefit_formatted) + input_weight_string += channel_str + else: + non_benefit_reason_formatted = input_weight_non_benefit_reasoning.format(num_recs, len(recommendation_per_channel)) + non_benefit_str = input_weight_non_benefit_str.format(non_benefit_reason_formatted) + channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str) + input_weight_string += channel_str + + # if no modules looked at, amend return string + if not added_module: + input_weight_string += "No applicable layers for suggestions. Only linear and conv valid.\n" + + # return a tuple with the string explanation and the compiled dict info + return (input_weight_string, input_weight_equalization_info) + + +class OutlierDetector(DetectorBase): + r""" + Determines whether there are significant outliers in activation data around a certain layer. + + This is ideally used in conjunction with information on stationary vs. non-stationary distribution: + If the data is stationary, and there are significant outliers, then we want to flag them + We want to do this on a per channel basis for detecting outliers + + Determines whether activation data is flagged as outlier based on if data is stationary and: + p_r = avg(100th percentile / "reference_percentile"th percentile) + where: + p_r is average percentile ratio across all batches in the epoch + reference_percentile is a percentile values between 0 and 100 exclusive + + if p_r is above some threshold, then we consider the activations to have significant outliers + + Args: + ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations + Should be >= 1 + Default: 3.5 + reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile + Should be between 0 and 1 + Default: 0.975 + fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier + If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user + regardless of whether we detected outliers or not in channel to take a closer look at channel results + Should be between 0 and 1 + Default: 0.95 + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations + The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold + If it is significantly greater, then we consider it an outlier + This threshold was calculated based on the ratio of the percentiles in a normal distribution + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile + Should be between 0 and 1 + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this + Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine outliers + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + # names for the pre observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # pre activation prefix + INPUT_ACTIVATION_PREFIX = "input_activation_" + + # names for dict keys + OUTLIER_KEY = "outliers_detected" + NUM_BATCHES_KEY = "outlier_detection_batches_used" + IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches" + COMP_METRIC_KEY = "outlier_detection_percentile_ratios" + RATIO_THRES_KEY = "outlier_detection_ratio_threshold" + REF_PERCENTILE_KEY = "outlier_detection_reference_percentile" + CHANNEL_AXIS_KEY = "outlier_detection_channel_axis" + MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max" + CONSTANT_COUNTS_KEY = "constant_batch_counts" + + def __init__( + self, + ratio_threshold: float = 3.5, + reference_percentile: float = 0.975, + fraction_batches_used_threshold: float = 0.95, + ch_axis: int = 1, + ): + # initialize the variables of interest + self.ratio_threshold = ratio_threshold + + # make sure passed in percentile is valid + assert reference_percentile >= 0 and reference_percentile <= 1 + assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1 + self.reference_percentile = reference_percentile + self.fraction_batches_used_threshold = fraction_batches_used_threshold + self.ch_axis = ch_axis + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "outlier_detector" + + def _supports_insertion(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for observers insertion + + Any module that doesn't have children and isn't an observer itself is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + # case for insertion of module + # check if the module has any children and isn't observer + num_children = len(list(module.children())) + return num_children == 0 and not is_activation_post_process(module) + + def _supports_report_gen(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for report generation + + Any module that has a model report pre-observer is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]: + r""" Determines where observers need to be inserted for the Outlier Detector. + + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + all layers that do not have children (leaf level layers) + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: Dict[str, Dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._supports_insertion(module): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis, comp_percentile=self.reference_percentile), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def _calculate_outlier_info( + self, + percentile_ratios: torch.Tensor, + counted_batches: torch.Tensor, + total_batches: int, + ) -> Dict[str, List[bool]]: + r""" + Gives info on whether the percentile ratios cacluated would be considered outliers + Also gives information on whether the collected data is statistically significant to make this claim + + Args: + percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer + counted_batches (torch.Tensor): The number of batches used for average calculation per tensor + total_batches (int): The total number of batches that passed through observer in this epoch + + Returns a dictionary mapping: + "outliers_detected" : list of bools per channel that are true if it is considered an outlier + "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold: + where o_r = counted_batches / total_batches + """ + outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []} + + # get both as flattened lists for easy mapping + ratios_list: List = percentile_ratios.tolist() + num_batches_list: List = counted_batches.tolist() + + # calculate whether channels were statistically significant + significant_size = [ + batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list + ] + outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size + + # calculate for each channel whether it's an outlier or not based on ratio + outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list] + outlier_dict[self.OUTLIER_KEY] = outlier_detected + + # return the dictionary with the two lists + return outlier_dict + + def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relavent module fqns to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # return dictionary mapping observer fqns to desired info + info_dict: Dict[str, Dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._supports_report_gen(module): + # get pre observer for the module + pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + # get the number of batches and calculated ratio thresholds + num_batches: torch.Tensor = pre_obs.percentile_batches_tracked + average_ratios: torch.Tensor = pre_obs.average_percentile_ratio + channel_batch_cnts: torch.Tensor = pre_obs.constant_channels + total_batches: int = pre_obs.num_batches_tracked + + # also get the max values + max_vals: torch.Tensor = pre_obs.max_val + + # we have to specifically modify how we are recording negative ratio for pre-relu layers + for index, ratio_val in enumerate(average_ratios): + # check if we have a negative ratio + # a ratio might be negative if we have a situation where the 100th percentile is + # > 0 while the nth percentile is < 0, in which case this would not be detected + # as an outlier. Since we care more about magnitude, we make it positive. + if ratio_val.item() < 0: + # first make it positive + average_ratios[index] = -ratio_val + + if ratio_val.item() < 1: + # if it's less than 1 we have the flip it as well + average_ratios[index] = 1 / ratio_val + + outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches) + + # calculate whether ratios were outliers + info_dict[fqn] = { + self.CHANNEL_AXIS_KEY: self.ch_axis, + self.REF_PERCENTILE_KEY: self.reference_percentile, + self.RATIO_THRES_KEY: self.ratio_threshold, + self.COMP_METRIC_KEY: average_ratios, + self.NUM_BATCHES_KEY: num_batches, + self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY], + self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY], + self.CONSTANT_COUNTS_KEY: channel_batch_cnts, + self.MAX_VALS_KEY: max_vals + } + + return info_dict + + def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records the relavent percentile information + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether there are outliers in the activations around certain modules + Dictionary mapping modules of interest to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # generate the information dictionary of outlier information + info_dict = self._generate_info_dict(model) + + # now we can generate report based on this information + outlier_string = "Outlier detection report: \n" + + # added module check + added_module: bool = False + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n" + channel_max_value_str = "a max value across all batches of {}" + note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results." + note_distribution = "stationary distributions" + note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary" + + # suggestion for constant batch check since that can make it no outliers + constant_str = "\tFor channel {}, we found {} constant value batches. {}\n" + constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occured and why." + + # compile the suggestion string + for module_fqn in info_dict: + # get module specific info + mod_info: Dict[str, Any] = info_dict[module_fqn] + # check to see if we already added high level model desc + added_model_desc = False + # look at each individual channel and add a suggestion + for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]): + if outlier_detected: + # we found at least 1 outlier + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis) + added_model_desc = True + + # we mark that we found at least one outlier + added_module = True + max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index]) + channel_str = channel_suggestion_str.format(index, max_value_found_str) + outlier_string += channel_str + + # also check if we found constant batch + if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0: + # make sure we add a module level highlight. + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis) + added_model_desc = True + + constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index] + formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion) + outlier_string += formatted_str + # we also added at least one thing to description + added_module = True + + + # if found outlier, give suggestion, else give default response + if added_module: + # compose the note string + note_composed = note_string.format(note_distribution, note_rec) + outlier_string += note_composed + else: + outlier_string += "There were no outliers found in the activations.\n" + + return (outlier_string, info_dict) diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py new file mode 100644 index 0000000000000..25be9e4edc664 --- /dev/null +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -0,0 +1,417 @@ +from typing import Any, Dict, Set, Tuple +from collections import OrderedDict +import torch +from torch.ao.quantization.fx._model_report.detector import ( + DetectorBase, + DETECTOR_OBS_ARGS_KEY, + DETECTOR_OBS_TO_INSERT_KEY, + DETECTOR_IS_POST_OBS_KEY, + DETECTOR_TARGET_NODE_KEY +) +from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ObserverBase + + +class ModelReport: + r""" + The ModelReport class aims to provide users an easy way to diagnose issues that they run into + with their models. The class works with all traceable GraphModules to help diagnose issues, + though the requirements on the type of model more-so depends on the specific report the user + is trying to generate. With respect to the reports, the ModelReport class is intialized with + a set of Detector classes, each of which generate reports on quantization configuration + issues a use might have. + + Currently supports generating reports on: + - Suggestions for per-channel vs. per-tensor quantization (nn.Module) + - Suggestions for dynamic vs static quantization for linear layers (Graph Modules) + - Suggestions for input-weight equalization for linear and conv layers (Graph Modules) + - Suggestions for outlier detection for all layers (Graph Modules) + + The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) + where needed for each detector to gather the information it needs, and then after callibration, the ModelReport + class compiles the report generated by each Detector class into a single report to return to the user. It also + has the capability to remove all the observers it inserted as well. + + * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule + + * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class + Make sure that these are all unique types of detectors [do not have more than 1 of the same class] + + * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors. + This set is generated by calling the get_detector_name() of each detector + + * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest + The purpose of this is to keep track of what observers were inserted for each detector, so that they + can be removed at the end if desired + + * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not + This is to ensure we only insert observers once with the ModelReport instance + + * :attr:`_removed_observers` A boolean to track if we have removed observers already + The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport + instance. This also allows the functionality where we can generate the report multiple times + as long as we haven't removed the observers yet. + + Note: + This class was initially designed to work with the Fx Graph Mode workflow in mind. However, + full functionality is available as long as there is a traceable GraphModule that is being used. + One method to get a traceable GraphModule without going through the Fx workflow is to use + the QuantizationTracer class. + + General Flow for Fx workflow: + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration to add relavent observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + Optional + 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance + 7.) To help in parsing report information and debugging, view report info as a: + - Table + - Histogram + - Line plot + + Example (with QuantizationTracer): + >>> # get the necessary qconfig + >>> config = PrepareCustomConfig() + >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False) + + >>> # initialize our model and get GraphModule + >>> model = SomeModel() + >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) + >>> graph_module = GraphModule(model, tracer.trace(model)) + + >>> # get our set of detectors and ModelReport instance + >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)]) + >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) + + >>> # now we insert the observers and callibrate the model + >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() + >>> for i in range(num_callibration_batches): + >>> example_input = get_callibration_input() + >>> tracer_model_with_observers(example_input) + + >>> # finally we generate the reports and optionally remove the observers we inserted + >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True) + + >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired + >>> model_report_visualizer = tracer_reporter.generate_visualizer() + + """ + + def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]): + + if len(desired_report_detectors) == 0: + raise ValueError("Should include at least 1 desired report") + + # keep track of the model we wish to generate report for + self._model: GraphModule = model + + # keep the reports private so they can't be modified + self._desired_report_detectors = desired_report_detectors + self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors]) + + # keep a mapping of desired reports to observers of interest + # this is to get the readings, and to remove them, can create a large set + # this set can then be used to traverse the graph and remove added observers + self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {} + + # initialize each report to have empty set of observers of interest + for desired_report in self._desired_detector_names: + self._detector_name_to_observer_fqns[desired_report] = set([]) + + # flags to ensure that we can only prepare and remove observers once + self._prepared_flag = False + self._removed_observers = False + + # store the reports that we generated for visualization purposes + # intially empty since no reports generated + self._generated_reports: Dict[str, Dict] = {} + + def get_desired_reports_names(self) -> Set[str]: + """ Returns a copy of the desired reports for viewing """ + return self._desired_detector_names.copy() + + def get_observers_of_interest(self) -> Dict[str, Set[str]]: + """ Returns a copy of the observers of interest for viewing """ + return self._detector_name_to_observer_fqns.copy() + + def prepare_detailed_calibration(self) -> GraphModule: + r""" + Takes in a graph model and inserts the following observers: + - ModelReportObserver + + Each observer is inserted based on the desired_reports into the relavent locations + + Right now, each report in self._desired_detector_names has independent insertions + However, if a module already has a Observer of the same type, the insertion will not occur + This is because all of the same type of Observer collect same information, so redundant + + Returns the same GraphModule with the observers inserted + """ + + # if already prepared once, cannot prepare again + if self._prepared_flag: + raise ValueError("Already ran preparing detailed callibration. Run the report generation next after callibration.") + + # loop through each detector, find where placements should be, and keep track + insert_observers_fqns: Dict[str, Any] = {} + + for detector in self._desired_report_detectors: + # determine observer points for each detector + obs_fqn_to_info = detector.determine_observer_insert_points(self._model) + # map each insert point to the observer to use + insert_observers_fqns.update(obs_fqn_to_info) + # update the set of observers this report cares about + self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys()) + + # now insert all the observers at their desired locations + for observer_fqn in insert_observers_fqns: + target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY] + insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY] + insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY] + observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY] + self._insert_observer_around_module( + observer_fqn, target_node, insert_obs, observer_args, insert_post + ) + + self._prepared_flag = True + + return self._model + + def _insert_observer_around_module( + self, + obs_fqn: str, + target_node: torch.fx.node.Node, + obs_to_insert: ObserverBase, + observer_args: Tuple, + insert_post: bool + ): + r""" + Helper function that inserts the observer into both the graph structure and the module of the model + + Args + node_fqn (str): The fully qualified name of the observer we want to insert + target_node (torch.fx.node.Node): The node in model we are inserting observers around + obs_to_insert (ObserverBase): The observer we are inserting around target_node + observer_args (Tuple): The arguments we want to pass into the observer + insert_post (bool): whether this is meant to be a post observer for this node + """ + # if we are inserting post, then our target node is the next node + if insert_post: + target_node = target_node.next + + with self._model.graph.inserting_before(target_node): + self._model.add_submodule(obs_fqn, obs_to_insert) + self._model.graph.create_node(op="call_module", target=obs_fqn, args=observer_args) + + # recompile model after inserts are made + self._model.recompile() + + def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a node fqn and returns the node based on the fqn + + Args + node_fqn (str): The fully qualified name of the node we want to find in model + + Returns the Node object of the given node_fqn otherwise returns None + """ + node_to_return = None + for node in self._model.graph.nodes: + # if the target matches the fqn, it's the node we are looking for + if node.target == node_fqn: + node_to_return = node + break + + if node_to_return is None: + raise ValueError("The node_fqn is was not found within the module.") + + # assert for MyPy + assert isinstance(node_to_return, torch.fx.node.Node) + + return node_to_return + + def generate_model_report( + self, remove_inserted_observers: bool + ) -> Dict[str, Tuple[str, Dict]]: + r""" + Generates all the requested reports. + + Note: + You should have callibrated the model with relavent data before calling this + + The reports generated are specified by the desired_reports specified in desired_reports + + Can optionally remove all the observers inserted by the ModelReport instance + + Args: + remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance + + Returns a mapping of each desired report name to a tuple with: + The textual summary of that report information + A dictionary containing relavent statistics or information for that report + + Note: + Throws exception if we try to generate report on model we already removed observers from + Throws exception if we try to generate report without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception("Cannot generate report without preparing model for callibration") + + # if we already removed the observers, we cannot generate report + if self._removed_observers: + raise Exception("Cannot generate report on model you already removed observers from") + + # keep track of all the reports of interest and their outputs + reports_of_interest = {} + + for detector in self._desired_report_detectors: + # generate the individual report for the detector + report_output = detector.generate_detector_report(self._model) + reports_of_interest[detector.get_detector_name()] = report_output + + # if user wishes to remove inserted observers, go ahead and remove + if remove_inserted_observers: + self._removed_observers = True + # get the set of all Observers inserted by this instance of ModelReport + all_observers_of_interest: Set[str] = set([]) + for desired_report in self._detector_name_to_observer_fqns: + observers_of_interest = self._detector_name_to_observer_fqns[desired_report] + all_observers_of_interest.update(observers_of_interest) + + # go through all_observers_of_interest and remove them from the graph and model + for observer_fqn in all_observers_of_interest: + # remove the observer from the model + self._model.delete_submodule(observer_fqn) + + # remove the observer from the graph structure + node_obj = self._get_node_from_fqn(observer_fqn) + + if node_obj: + self._model.graph.erase_node(node_obj) + else: + raise ValueError("Node no longer exists in GraphModule structure") + + # remember to recompile the model + self._model.recompile() + + # save the generated reports for visualization purposes + saved_reports: Dict[str, Dict] = { + report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items() + } + + self._generated_reports = saved_reports + + # return the reports of interest + return reports_of_interest + + def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool: + r""" + Takes in two dictionaries and ensures that any common keys between the two have the same + values. + + Args: + info_dict_a (Dict): First dictionary we wish to compare + info_dict_b (Dict): Second dictionary we wish to compare + + Returns True if all shared keys have same values, false otherwise + """ + # get the set of keys for both + dict_a_keys: Set = set(info_dict_a.keys()) + dict_b_keys: Set = set(info_dict_b.keys()) + + # get the insersection keys and check if same value for both dicts + intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys) + + for key in intersecting_keys: + dict_a_val = info_dict_a[key] + dict_b_val = info_dict_b[key] + + # if it's a tensor we have to handle seperately + if type(dict_a_val) == torch.Tensor: + # if dict_b_val not tensor, automatically false + if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0: + return False + else: + # for non-tensor vals + if dict_a_val != dict_b_val: + return False + + # if no non matching shared keys found, return true + return True + + def _reformat_reports_for_visualizer(self) -> OrderedDict: + r""" + Takes the generated reports and reformats them into the format that is desired by the + ModelReportVisualizer + + Returns an OrderedDict mapping module_fqns to their features + """ + # we want to reorder and reformat the information so it is ordered in terms of order + # found in the model + + # first create new dict with all modules as keys and features under respective module + module_fqns_to_features: Dict[str, Dict] = {} + + for report_name in self._generated_reports: + # get mod -> feature dict and go through + module_info = self._generated_reports[report_name] + + for module_fqn in module_info: + # check if already in our accumulation dict + if module_fqn in module_fqns_to_features: + # we merge all the features together + new_info: Dict = module_info[module_fqn] + present_info: Dict = module_fqns_to_features[module_fqn] + + # merge them together into the new unioned dict + # same features keys -> same info, so okay if override + + # do safety check to make sure shared keys have same info + if self._is_same_info_for_same_key(new_info, present_info): + module_fqns_to_features[module_fqn] = {**new_info, **present_info} + else: + error_str = "You have the same key with different values across detectors. " + error_str += "Someone incorrectly implemented a detector with conflicting keys to exisiting detectors." + raise ValueError(error_str) + else: + # we just set it + module_fqns_to_features[module_fqn] = module_info[module_fqn] + + # our ordered dict so that modules can be ordered in order of how they appear in model + features_by_module: OrderedDict[str, Dict] = OrderedDict() + + # we loop through modules in graph in order + for fqn, module in self._model.named_modules(): + # find that fqn in fqns_to_features + if fqn in module_fqns_to_features: + # add it to our ordered dict + features_by_module[fqn] = module_fqns_to_features[fqn] + + # return the ordered dict of info we created + return features_by_module + + def generate_visualizer(self) -> ModelReportVisualizer: + r""" + Generates a ModelReportVisualizer instance using the reports generated + by the generate_model_report() method. + + Returns the generated ModelReportVisualizer instance initialized + + Note: + Throws exception if attempt to get visualizers without generating report + """ + # check if user has generated reports at least once + if len(self._generated_reports) == 0: + raise Exception("Unable to generate visualizers without first generating reports") + + # get the ordered dict mapping modules to their full set of collected features / stats + module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer() + + # create and return ModelReportVisualizer instance + visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features) + + return visualizer diff --git a/torch/ao/quantization/fx/_model_report/model_report_observer.py b/torch/ao/quantization/fx/_model_report/model_report_observer.py index 1a5262188b5d1..500b1e654edcd 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_observer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -10,6 +10,13 @@ class ModelReportObserver(ObserverBase): Dynamic or Static Quantization is more appropriate for their model given the general distributions of their data. + Args: + ch_axis (int, optional): The channel axis for which the range and outlier stats are computed + Default: 1 + comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers + Should be between 0 and 1 exclusive + Default: 0.9 + * :attr:`num_batches_tracked` specifies number of batches passed through the observer * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through @@ -18,21 +25,63 @@ class ModelReportObserver(ObserverBase): * :attr:`epoch_activation_max` defines the maximum value passed through the observer + * :attr:`ch_axis` defines the channel being used to compute per channel min max stats + + * :attr:`min_val` defines the per channel minimum values passed through + + * :attr:`max_val` defines the per channel maximum values passed through + + * :attr:`comp_percentile` defines comparison percentile to find outliers + + * :attr:`average_percentile_ratio` defines the per channel average percentile ratios + + * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel + + * :attr:`constant_channels` defines the number of batches that aren't constant channels per channel + Note: this tool is meant for FX Graph Mode Quantization """ - def __init__(self): + def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9): super().__init__(torch.qint8) self.num_batches_tracked = 0 # keep track of the min and mix of the range for average batch and epoch as a whole - self.average_batch_activation_range = torch.tensor(float(0)) + self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0)) self.epoch_activation_min = torch.tensor(float("inf")) self.epoch_activation_max = torch.tensor(float("-inf")) + # keep track of per channel min max information using the given channel + self.ch_axis: int = ch_axis + self.min_val: torch.Tensor = torch.tensor([]) + self.max_val: torch.Tensor = torch.tensor([]) + + # keep track of percentile ratio information per channel + self.comp_percentile: torch.Tensor = torch.tensor([comp_percentile]) + self.average_percentile_ratio: torch.Tensor = torch.tensor([]) + self.percentile_batches_tracked: torch.Tensor = torch.tensor([]) + self.constant_channels: torch.Tensor = torch.tensor([]) + def forward(self, x): x_copy = x.detach() # avoid keeping autograd tape x_copy = x_copy.to(self.epoch_activation_min.dtype) + + x_copy = self._calculate_range_stats(x_copy) + x_copy = self._calculate_min_max_stats(x_copy) + x_copy = self._calculate_percentile_stats(x_copy) + + # return the passed in the value + return x + + def _calculate_range_stats(self, x_copy): + r"""Calculates and stores range stats with forward values. + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the min, max values of the data min_val_cur, max_val_cur = torch.aminmax(x_copy) # calculate new epoch range values @@ -52,8 +101,127 @@ def forward(self, x): self.average_batch_activation_range = new_range self.num_batches_tracked += 1 # new batch was processed - # return the passed in the value - return x + return x_copy + + def _calculate_min_max_stats(self, x_copy): + r"""Calculates and stores the per_channel min, max stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the current min and max vals + min_val = self.min_val + max_val = self.max_val + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x_copy + + def _calculate_percentile_stats(self, x_copy): + r"""Calculates and stores the per_channel percentile stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the dimension of the copy + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + y = y.to(self.min_val.dtype) + + # find the percentile values along the axis + # we want both 100th percentile and comp_percentile + # we also want to find 0th quartile to see if we have constant channel + quantiles_list = [0, self.comp_percentile, 1.00] + quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype) + + # find the quantiles + desired_quantiles = torch.quantile(y, quantiles_to_find, dim=self.ch_axis, interpolation="lower") + zero_quantile = desired_quantiles[0] + comp_quantile = desired_quantiles[1] + hundreth_quartile = desired_quantiles[2] + + # if any of the channels have 0s, we ignore that channel for this calculation + any_non_zero_quantile_value: torch.Tensor = (comp_quantile != torch.tensor([0])) | (hundreth_quartile != torch.tensor([0])) + any_non_zero_quantile_value = any_non_zero_quantile_value.int() # transform boolean values to int values + + # we also check if we have a constant channel + any_constant_channels: torch.Tensor = (hundreth_quartile - zero_quantile) == torch.tensor([0]) + any_constant_channels = any_constant_channels.int() # transform boolean values to int values + + # possibilities to get nan as an answer + # will ignore any of these three cases with 0s and just not deal with them for now + # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative + # case (2) 0 in denominator: is possible unless case 3, we just ignore + # case (3) 0 in both: not outlier, channel just kinda useless, ignore + + # get the ratio and get rid of nan values + quantile_ratios = hundreth_quartile / comp_quantile + quantile_ratios = torch.nan_to_num(quantile_ratios) + # update averages, remembering to only update if didn't have zeros + ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios + + # if num_batches and average_ratio are not initialized, we want to initialize them + if self.percentile_batches_tracked.shape[0] == 0 or self.average_percentile_ratio.shape[0] == 0: + self.percentile_batches_tracked = torch.zeros_like(any_non_zero_quantile_value) + self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero) + + # also initialize the constant channel var if that is not initialized seperately + if self.constant_channels.shape[0] == 0: + self.constant_channels = torch.zeros_like(any_constant_channels) + + # get current num batches and average ratio + num_batches = self.percentile_batches_tracked + average_ratio = self.average_percentile_ratio + + # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches + new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value + new_ratios: torch.Tensor = ((average_ratio * num_batches) + ratio_if_not_zero) / new_number_of_batches + new_ratios = torch.nan_to_num(new_ratios) + + # update the number of non-constant channels + new_constant_count: torch.Tensor = self.constant_channels + any_constant_channels + + # update the values locally + self.percentile_batches_tracked.copy_(new_number_of_batches) + self.average_percentile_ratio.copy_(new_ratios) + self.constant_channels.copy_(new_constant_count) + + return x_copy + + @torch.jit.export def get_batch_to_epoch_ratio(self): @@ -75,6 +243,11 @@ def reset_batch_and_epoch_values(self): self.average_batch_activation_range = torch.tensor(float(0)) self.epoch_activation_min = torch.tensor(float("inf")) self.epoch_activation_max = torch.tensor(float("-inf")) + self.min_val = torch.tensor([]) + self.max_val = torch.tensor([]) + self.average_percentile_ratio = torch.tensor([]) + self.percentile_batches_tracked = torch.tensor([]) + self.constant_channels = torch.tensor([]) @torch.jit.export def calculate_qparams(self): diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py new file mode 100644 index 0000000000000..855c71261488e --- /dev/null +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -0,0 +1,648 @@ +import torch +from typing import Any, Set, Dict, List, Tuple, OrderedDict +from collections import OrderedDict as OrdDict + +# try to import tablate +got_tabulate = True +try: + from tabulate import tabulate +except ImportError: + got_tabulate = False + + +# var to see if we could import matplotlib +got_matplotlib = True +try: + import matplotlib.pyplot as plt +except ImportError: + got_matplotlib = False + +class ModelReportVisualizer: + r""" + The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics + that were generated by the ModelReport API. However, at a higher level, the class aims to provide + some level of visualization of statistics to PyTorch in order to make it easier to parse data and + diagnose any potential issues with data or a specific model. With respect to the visualizations, + the ModelReportVisualizer class currently supports several methods of visualizing data. + + Supported Visualization Methods Include: + - Table format + - Plot format (line graph) + - Histogram format + + For all of the existing visualization methods, there is the option to filter data based on: + - A module fqn prefix + - Feature [required for the plot and histogram] + + * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below + Ensure sure that features that are the same across different report contain the same name + Ensure that objects representing the same features are the same type / dimension (where applicable) + + Note: + Currently, the ModelReportVisualizer class supports visualization of data generated by the + ModelReport class. However, this structure is extensible and should allow the visualization of + other information as long as the information is structured in the following general format: + + Report Structure + -- module_fqn [module with attached detectors] + | + -- feature keys [not every detector extracts same information] + [same collected info has same keys, unless can be specific to detector] + + + The goal behind the class is that the generated visualizations can be used in conjunction with the generated + report for people to get a better understanding of issues and what the fix might be. It is also just to provide + a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as + that grows in size. + + General Use Flow Expected + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration on your model to add relavent observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance + 7.) Use instance to view different views of data as desired, applying filters as needed + 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram + + """ + + # keys for table dict + TABLE_TENSOR_KEY = "tensor_level_info" + TABLE_CHANNEL_KEY = "channel_level_info" + + # Constants for header vals + NUM_NON_FEATURE_TENSOR_HEADERS = 2 + NUM_NON_FEATURE_CHANNEL_HEADERS = 3 + + # Constants for row index in header + CHANNEL_NUM_INDEX = 2 + + def __init__(self, generated_reports: OrderedDict[str, Any]): + r""" + Initializes the ModelReportVisualizer instance with the necessary reports. + + Args: + generated_reports (Dict[str, Any]): The reports generated by the ModelReport class + can also be a dictionary generated in another manner, as long as format is same + """ + self.generated_reports = generated_reports + + def get_all_unique_module_fqns(self) -> Set[str]: + r""" + The purpose of this method is to provide a user the set of all module_fqns so that if + they wish to use some of the filtering capabilities of the ModelReportVisualizer class, + they don't need to manually parse the generated_reports dictionary to get this information. + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + # returns the keys of the ordered dict + return set(self.generated_reports.keys()) + + def get_all_unique_feature_names(self, plottable_features_only: bool = True) -> Set[str]: + r""" + The purpose of this method is to provide a user the set of all feature names so that if + they wish to use the filtering capabilities of the generate_table_view(), or use either of + the generate_plot_view() or generate_histogram_view(), they don't need to manually parse + the generated_reports dictionary to get this information. + + Args: + plottable_features_only (bool): True if the user is only looking for plottable features, + False otherwise + plottable features are those that are tensor values + Default: True (only return those feature names that are plottable) + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + unique_feature_names = set() + for module_fqn in self.generated_reports: + # get dict of the features + feature_dict: Dict[str, Any] = self.generated_reports[module_fqn] + + # loop through features + for feature_name in feature_dict: + # if we need plottable, ensure type of val is tensor + if not plottable_features_only or type(feature_dict[feature_name]) == torch.Tensor: + unique_feature_names.add(feature_name) + + # return our compiled set of unique feature names + return unique_feature_names + + def _get_filtered_data(self, feature_filter: str, module_fqn_filter: str) -> OrderedDict[str, Any]: + r""" + Filters the data and returns it in the same ordered dictionary format so the relavent views can be displayed. + + Args: + feature_filter (str): The feature filter, if we want to filter the set of data to only include + a certain set of features that include feature_filter + If feature = "", then we do not filter based on any features + module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with + this prefix will be included + If module_fqn_filter = "" we do not filter based on module fqn, and include all modules + + First, the data is filtered based on module_fqn, and then filtered based on feature + Returns an OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + """ + # create return dict + filtered_dict: OrderedDict[str, Any] = OrdDict() + + for module_fqn in self.generated_reports: + # first filter based on module + if module_fqn_filter == "" or module_fqn_filter in module_fqn: + # create entry for module and loop through features + filtered_dict[module_fqn] = {} + module_reports = self.generated_reports[module_fqn] + for feature_name in module_reports: + # check if filtering on features and do so if desired + if feature_filter == "" or feature_filter in feature_name: + filtered_dict[module_fqn][feature_name] = module_reports[feature_name] + + # we have populated the filtered dict, and must return it + + return filtered_dict + + def _generate_tensor_table( + self, + filtered_data: OrderedDict[str, Dict[str, Any]], + tensor_features: List[str] + ) -> Tuple[List, List]: + r""" + Takes in the filtered data and features list and generates the tensor headers and table + + Currently meant to generate the headers and table for both the tensor information. + + Args: + filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + tensor_features (List[str]): A list of the tensor level features + + Returns a tuple with: + A list of the headers of the tensor table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the tensor information table + tensor_table: List[List[Any]] = [] + tensor_headers: List[str] = [] + + # append the table row to the table only if we have features + if len(tensor_features) > 0: + # now we add all the data + for index, module_fqn in enumerate(filtered_data): + # we make a new row for the tensor table + tensor_table_row = [index, module_fqn] + for feature in tensor_features: + # we iterate in same order of added features + + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if isinstance(feature_val, torch.Tensor): + feature_val = feature_val.item() + + # we add to our list of values + tensor_table_row.append(feature_val) + + tensor_table.append(tensor_table_row) + + # add row of headers of we actually have something, otherwise just empty + if len(tensor_table) != 0: + tensor_headers = ["idx", "layer_fqn"] + tensor_features + + return (tensor_headers, tensor_table) + + def _generate_channels_table( + self, + filtered_data: OrderedDict[str, Any], + channel_features: List[str], + num_channels: int + ) -> Tuple[List, List]: + r""" + Takes in the filtered data and features list and generates the channels headers and table + + Currently meant to generate the headers and table for both the channels information. + + Args: + filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + channel_features (List[str]): A list of the channel level features + num_channels (int): Number of channels in the channel data + + Returns a tuple with: + A list of the headers of the channel table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the table for the channel information table + channel_table: List[List[Any]] = [] + channel_headers: List[str] = [] + + # counter to keep track of number of entries in + channel_table_entry_counter: int = 0 + + if len(channel_features) > 0: + # now we add all channel data + for index, module_fqn in enumerate(filtered_data): + # we iterate over all channels + for channel in range(num_channels): + # we make a new row for the channel + new_channel_row = [channel_table_entry_counter, module_fqn, channel] + for feature in channel_features: + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature][channel] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if type(feature_val) is torch.Tensor: + feature_val = feature_val.item() + + # add value to channel specific row + new_channel_row.append(feature_val) + + # add to table and increment row index counter + channel_table.append(new_channel_row) + channel_table_entry_counter += 1 + + # add row of headers of we actually have something, otherwise just empty + if len(channel_table) != 0: + channel_headers = ["idx", "layer_fqn", "channel"] + channel_features + + return (channel_headers, channel_table) + + def generate_filtered_tables(self, feature_filter: str = "", module_fqn_filter: str = "") -> Dict[str, Tuple[List, List]]: + r""" + Takes in optional filter values and generates two tables with desired information. + + The generated tables are presented in both a list-of-lists format + + The reason for the two tables are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Returns a dictionary with two keys: + (Dict[str, Tuple[List, List]]) A dict containing two keys: + "tensor_level_info", "channel_level_info" + Each key maps to a tuple with: + A list of the headers of each table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + + Example Use: + >>> mod_report_visualizer.generate_filtered_tables( + feature_filter = "per_channel_min", + module_fqn_filter = "block1" + ) # generates table with per_channel_min info for all modules in block 1 of the model + """ + # first get the filtered data + filtered_data: OrderedDict[str, Any] = self._get_filtered_data(feature_filter, module_fqn_filter) + + # now we split into tensor and per-channel data + tensor_features: Set[str] = set() + channel_features: Set[str] = set() + + # keep track of the number of channels we have + num_channels: int = 0 + + for module_fqn in filtered_data: + for feature_name in filtered_data[module_fqn]: + # get the data for that specific feature + feature_data = filtered_data[module_fqn][feature_name] + + # check if not zero dim tensor + is_tensor: bool = isinstance(feature_data, torch.Tensor) + is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0 + + if is_not_zero_dim or isinstance(feature_data, list): + # works means per channel + channel_features.add(feature_name) + num_channels = len(feature_data) + else: + # means is per-tensor + tensor_features.add(feature_name) + + # we make them lists for iteration purposes + tensor_features_list: List[str] = sorted(list(tensor_features)) + channel_features_list: List[str] = sorted(list(channel_features)) + + # get the tensor info + tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list) + + # get the channel info + channel_headers, channel_table = self._generate_channels_table( + filtered_data, channel_features_list, num_channels + ) + + # let's now create the dictionary to return + table_dict = { + self.TABLE_TENSOR_KEY : (tensor_headers, tensor_table), + self.TABLE_CHANNEL_KEY : (channel_headers, channel_table) + } + + # return the two tables + return table_dict + + def generate_table_visualization(self, feature_filter: str = "", module_fqn_filter: str = ""): + r""" + Takes in optional filter values and prints out formatted tables of the information. + + The reason for the two tables printed out instead of one large one are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> mod_report_visualizer.generate_table_visualization( + feature_filter = "per_channel_min", + module_fqn_filter = "block1" + ) + # prints out neatly formatted table with per_channel_min info for + all modules in block 1 of the model + """ + # see if we got tabulate + if not got_tabulate: + print("Make sure to install tabulate and try again.") + return None + + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # get the table string and print it out + # now we have populated the tables for each one + # let's create the strings to be returned + table_str = "" + # the tables will have some headers columns that are non-feature + # ex. table index, module name, channel index, etc. + # we want to look at header columns for features, that come after those headers + if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS: + # if we have at least one tensor level feature to be addded we add tensor table + table_str += "Tensor Level Information \n" + table_str += tabulate(tensor_table, headers=tensor_headers) + if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS: + # if we have at least one channel level feature to be addded we add tensor table + table_str += "\n\n Channel Level Information \n" + table_str += tabulate(channel_table, headers=channel_headers) + + # if no features at all, let user know + if table_str == "": + table_str = "No data points to generate table with." + + print(table_str) + + def _get_plottable_data(self, feature_filter: str, module_fqn_filter: str) -> Tuple[List, List[List], bool]: + r""" + Takes in the feature filters and module filters and outputs the x and y data for plotting + + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str): Only includes modules that contains this string + + Returns a tuple of three elements + The first is a list containing relavent x-axis data + The second is a list containing the corresponding y-axis data + If the data is per channel + """ + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # make sure it is only 1 feature that is being plotted + # get the number of features in each of these + tensor_info_features_count = len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + channel_info_features_count = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + + # see if valid tensor or channel plot + is_valid_per_tensor_plot: bool = tensor_info_features_count == 1 + is_valid_per_channel_plot: bool = channel_info_features_count == 1 + + # offset should either be one of tensor or channel table or neither + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + if is_valid_per_channel_plot: + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + + # keep track of per_channel or not + data_is_per_channel: bool = False + + x_data: List = [] + y_data: List[List] = [] + # the feature will either be a tensor feature or channel feature + if is_valid_per_tensor_plot or is_valid_per_channel_plot: + # extra setup for y_data if per channel + if is_valid_per_channel_plot: + # gather the x_data and multiple y_data + # calculate the number of channels + num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in channel_table) + 1 + for channel in range(num_channels): + y_data.append([]) # seperate data list per channel + + for table_row_num, row in enumerate(tensor_table): + # get x_value to append + x_val_to_append = table_row_num + current_channel: int = -1 # intially chose current channel + # if new module we are looking at, add it's index to x_data + if is_valid_per_channel_plot and row[self.CHANNEL_NUM_INDEX] == 0: + new_module_index: int = table_row_num // num_channels + x_val_to_append = new_module_index + current_channel = row[self.CHANNEL_NUM_INDEX] + + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if not type(row_value) == str: + x_data.append(x_val_to_append) + # how we append y value depends on if per tensor or not + if is_valid_per_channel_plot: + y_data[current_channel].append(row_value) + else: + y_data.append(row_value) + else: + # more than one feature was chosen + error_str = "Make sure to pick only a single feature with your filter to plot a graph." + error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names." + error_str += " Pick one of those features to plot." + raise ValueError(error_str) + + # return x, y values, and if data is per-channel + return (x_data, y_data, data_is_per_channel) + + def generate_plot_visualization(self, feature_filter: str, module_fqn_filter: str = ""): + r""" + Takes in a feature and optional module_filter and plots of the desired data. + + Note: + Only features in the report that have tensor value data are plottable by this class + When the tensor information is plotted, it will plot: + idx as the x val, feature value as the y_val + When the channel information is plotted, it will plot: + the first idx of each module as the x val, feature value as the y_val [for each channel] + The reason for this is that we want to be able to compare values across the + channels for same layer, and it will be hard if values are staggered by idx + This means each module is represented by only 1 x value + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> mod_report_visualizer.generate_plot_visualization( + feature_filter = "per_channel_min", + module_fqn_filter = "block1" + ) + # outputs line plot of per_channel_min information for all modules in block1 of model + each channel gets it's own line, and it's plotted across the in-order modules + on the x-axis + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter) + + # plot based on whether data is per channel or not + fig = plt.figure() + ax = plt.subplot() + ax.set_ylabel(feature_filter) + ax.set_title(feature_filter + " Plot") + plt.xticks(x_data) # only show ticks for actual points + + if data_per_channel: + ax.set_xlabel("First idx of module") + # set the legend as well + # plot a seperate line for each channel + for index, channel_info in enumerate(y_data): + ax.plot(x_data, channel_info, label="Channel {}".format(index)) + + ax.legend(loc='upper right') + else: + ax.set_xlabel("idx") + ax.plot(x_data, y_data) + + # actually show the plot + plt.show() + + def generate_histogram_visualization(self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10): + r""" + Takes in a feature and optional module_filter and plots the histogram of desired data. + + Note: + Only features in the report that have tensor value data can be viewed as a histogram + If you want to plot a histogram from all the channel values of a specific feature for + a specific model, make sure to specify both the model and the feature properly + in the filters and you should be able to see a distribution of the channel data + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + num_bins (int, optional): The number of bins to create the histogram with + Default = 10, the values will be split into 10 equal sized bins + + Example Use: + >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( + feature_filter = "per_channel_min", + module_fqn_filter = "block1" + ) + # outputs histogram of per_channel_min information for all modules in block1 of model + information is gathered across all channels for all modules in block 1 for the + per_channel_min and is displayed in a histogram of equally sized bins + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter) + + # for histogram, we just care about plotting the y data + # plot based on whether data is per channel or not + fig = plt.figure() + ax = plt.subplot() + ax.set_xlabel(feature_filter) + ax.set_ylabel("Frequency") + ax.set_title(feature_filter + " Histogram") + + if data_per_channel: + # set the legend as well + # combine all the data + all_data = [] + for index, channel_info in enumerate(y_data): + all_data.extend(channel_info) + val, bins, _ = plt.hist( + all_data, + bins=num_bins, + stacked=True, + rwidth=0.8, + ) + ax.legend(loc='upper right') + plt.xticks(bins) + else: + val, bins, _ = plt.hist( + y_data, + bins=num_bins, + stacked=False, + rwidth=0.8, + ) + plt.xticks(bins) + + plt.show() diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 72f83df65a70e..2fde17ac413fc 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional, Set, Callable, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type +from torch.ao.quantization.quant_type import QuantType import torch import copy import warnings @@ -44,6 +45,7 @@ is_observed_standalone_module, ) from ._equalize import update_obs_for_equalization, convert_eq_obs +from torch.nn.utils.parametrize import type_before_parametrizations from .utils import ( get_custom_module_class_keys, get_quantize_node_info, @@ -316,7 +318,11 @@ def convert_standalone_module( produce a reference model or a fbgemm/qnnpack model - backend_config_dict: backend configuration of the target backend of quantization """ - convert = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] # We know that observed standalone module is a GraphModule since # it's produced by us observed_standalone_module : GraphModule = modules[str(node.target)] # type: ignore[assignment] @@ -349,10 +355,8 @@ def convert_standalone_module( # TODO: allow convert_custom_config to override backend_config_dict # for standalone module - # TODO: think about how to handle `is_reference` here - quantized_standalone_module = convert( + quantized_standalone_module = convert_fn( observed_standalone_module, - is_reference=is_reference, backend_config_dict=backend_config_dict) parent_name, name = _parent_name(node.target) # update the modules dict @@ -458,8 +462,10 @@ def convert_weighted_module( # root_module_to_quantized_reference_module: module mapping from root (floating point) module class # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict) - ref_qmodule_cls = root_module_to_quantized_reference_module.get(type(float_module), None) - assert ref_qmodule_cls is not None, f"No reference quantized module class configured for {type(float_module)}" + ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None) + assert ( + ref_qmodule_cls is not None + ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] if fused_module is not None: fused_module[0] = ref_qmodule # type: ignore[operator] @@ -471,7 +477,7 @@ def convert_custom_module( node: Node, graph: Graph, modules: Dict[str, torch.nn.Module], - custom_module_class_mapping: Dict[Callable, Callable], + custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]], statically_quantized_custom_module_nodes: Set[Node]): """ Converts an observed custom module to a quantized custom module based on `custom_module_class_mapping` @@ -578,14 +584,6 @@ def convert( node_name_to_scope, prepare_custom_config, observed_node_names = restore_state(model) qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment] - # TODO this should be removed now that gpu support for quantization is being supported. - # however in practice, as of 7/22/2021, certain functions that get called by convert expect - # only cpu arguments. - # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda', - # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend. - if not is_reference: - model.cpu() - # mapping from fully qualified module name to module instance # for example, # { @@ -762,16 +760,18 @@ def replace_observer_with_dequantize_node(node: Node, graph: Graph): elif is_observed_standalone_module(modules[node.target]): convert_standalone_module( node, modules, model, is_reference, backend_config_dict) - elif type(modules[node.target]) in set( + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(modules[node.target]) in set( root_module_classes).union(qat_module_classes).union(fused_module_classes): # extra check for fused module classes to make sure they are fused module classes # of target modules - if type(modules[node.target]) in fused_module_classes and \ - type(modules[node.target][0]) not in root_module_classes: + if type_before_parametrizations(modules[node.target]) in fused_module_classes and \ + type_before_parametrizations(modules[node.target][0]) not in root_module_classes: continue convert_weighted_module( node, modules, observed_node_names, qconfig_map, backend_config_dict) - elif type(modules[node.target]) in custom_module_classes: + elif type_before_parametrizations(modules[node.target]) in custom_module_classes: convert_custom_module( node, model.graph, modules, custom_module_class_mapping, statically_quantized_custom_module_nodes) diff --git a/torch/ao/quantization/fx/fusion_patterns.py b/torch/ao/quantization/fx/fusion_patterns.py index 7d0cda498eb16..29d4217699b09 100644 --- a/torch/ao/quantization/fx/fusion_patterns.py +++ b/torch/ao/quantization/fx/fusion_patterns.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union, List from .custom_config import FuseCustomConfig from .match_utils import MatchAllNode - +from torch.nn.utils.parametrize import type_before_parametrizations __all__ = [ "DefaultFuseHandler", @@ -91,7 +91,7 @@ def get_matched_types(m): if isinstance(m, tuple): return tuple(map(get_matched_types, m)) if isinstance(m, torch.nn.Module): - return type(m) + return type_before_parametrizations(m) return m matched_module_types = get_matched_types(matched_modules) diff --git a/torch/ao/quantization/fx/lower_to_fbgemm.py b/torch/ao/quantization/fx/lower_to_fbgemm.py index c8c413cacfee7..e08efc3104c3d 100644 --- a/torch/ao/quantization/fx/lower_to_fbgemm.py +++ b/torch/ao/quantization/fx/lower_to_fbgemm.py @@ -3,6 +3,8 @@ from ..qconfig import QConfigAny from typing import Dict, Tuple +__all__ = ['lower_to_fbgemm'] + def lower_to_fbgemm( model: QuantizedGraphModule, qconfig_map: Dict[str, QConfigAny], diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index 02fc822ef46bf..d9bd4ff8459f6 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -17,7 +17,7 @@ from .graph_module import ( is_observed_standalone_module, ) - +from torch.nn.utils.parametrize import type_before_parametrizations from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set @@ -52,13 +52,13 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): return True - if len(node.users) > max_uses: + if not isinstance(node, Node) or len(node.users) > max_uses: return False if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): if node.op != 'call_module': return False - if not type(modules[node.target]) == self_match: + if not type_before_parametrizations(modules[node.target]) == self_match: return False elif callable(self_match): if node.op != 'call_function' or node.target is not self_match: diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 73be20ca8e8d7..274bfe95200da 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -17,8 +17,17 @@ from ..observer import ( ObserverBase, ) -from ..qconfig import QConfigAny, is_reuse_input_qconfig -from ..qconfig_mapping import QConfigMapping +from ..qconfig import ( + obs_or_fq_ctr_equals, + float16_dynamic_qconfig, + float16_static_qconfig, + is_reuse_input_qconfig, + QConfigAny, +) +from ..qconfig_mapping import ( + _FIXED_QPARAMS_OP_TO_OBSERVER, + QConfigMapping, +) from ..qconfig_mapping_utils import ( get_flattened_qconfig_dict, update_qconfig_for_qat, @@ -37,6 +46,8 @@ NodePattern, ) +from torch.ao.quantization import FixedQParamsFakeQuantize + from ._equalize import ( is_equalization_observer, node_supports_equalization, @@ -1316,6 +1327,42 @@ def insert_observers_for_model( return results_node +def _validate_fixed_qparams_qconfigs(model: GraphModule, qconfig_map: Dict[str, QConfigAny]): + """ + Validate whether the correct observers are configured for fixed qparams ops in the model, if any. + """ + # TODO: handle fp16 qconfigs properly + allowed_observer_ctrs = [ + float16_dynamic_qconfig.activation, + float16_static_qconfig.activation, + ] + named_modules = dict(model.named_modules(remove_duplicate=False)) + for node in model.graph.nodes: + if node.op == "call_function": + module_type_or_function_or_method = node.target + elif node.op == "call_module": + module_type_or_function_or_method = type(named_modules[node.target]) + else: + module_type_or_function_or_method = None + + if module_type_or_function_or_method in _FIXED_QPARAMS_OP_TO_OBSERVER: + bad_observer = True + qconfig = qconfig_map.get(node.name, None) + if qconfig is None: + bad_observer = False + else: + for observer_ctr in allowed_observer_ctrs + [_FIXED_QPARAMS_OP_TO_OBSERVER[module_type_or_function_or_method]]: + if obs_or_fq_ctr_equals( + qconfig.activation, + FixedQParamsFakeQuantize.with_args(observer=observer_ctr)) or \ + obs_or_fq_ctr_equals(qconfig.activation, observer_ctr): + bad_observer = False + if bad_observer: + raise ValueError("QConfigMapping must specify fixed qparams observer for fixed qparams op " + "'%s' type: '%s'. Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping" + " instead." % (node.format_node(), module_type_or_function_or_method)) + def run_prepare_fx_on_standalone_modules( model: torch.nn.Module, is_qat: bool, @@ -1387,7 +1434,7 @@ def prepare( node_name_to_scope: Dict[str, Tuple[str, type]], example_inputs: Tuple[Any, ...], prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None, - equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None, + _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None, backend_config_dict: Optional[Dict[str, Any]] = None, is_standalone_module: bool = False) -> ObservedGraphModule: """ standalone_module means it a submodule that is not inlined in @@ -1412,8 +1459,8 @@ def prepare( """ if prepare_custom_config is None: prepare_custom_config = PrepareCustomConfig() - if equalization_config is None: - equalization_config = QConfigMapping() + if _equalization_config is None: + _equalization_config = QConfigMapping() if isinstance(qconfig_mapping, Dict): warnings.warn( @@ -1421,11 +1468,11 @@ def prepare( "in a future version. Please pass in a QConfigMapping instead.") qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) - if isinstance(equalization_config, Dict): + if isinstance(_equalization_config, Dict): warnings.warn( "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " "be supported in a future version. Please pass in a QConfigMapping instead.") - equalization_config = QConfigMapping.from_dict(equalization_config) + _equalization_config = QConfigMapping.from_dict(_equalization_config) if isinstance(prepare_custom_config, Dict): warnings.warn( @@ -1434,9 +1481,9 @@ def prepare( prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) assert(isinstance(qconfig_mapping, QConfigMapping)) - assert(isinstance(equalization_config, QConfigMapping)) + assert(isinstance(_equalization_config, QConfigMapping)) qconfig_mapping = copy.deepcopy(qconfig_mapping) - equalization_config = copy.deepcopy(equalization_config) + _equalization_config = copy.deepcopy(_equalization_config) # mapping from a tuple of nodes in reverse order to uninitialized # QuantizeHandler subclass. For example, @@ -1477,7 +1524,7 @@ def prepare( get_fusion_pattern_to_root_node_getter(backend_config_dict) update_qconfig_for_fusion(model, qconfig_mapping) - update_qconfig_for_fusion(model, equalization_config) + update_qconfig_for_fusion(model, _equalization_config) flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping) # TODO: support regex as well propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) @@ -1498,8 +1545,9 @@ def prepare( # fill qconfig_map, a map from node name to qconfig, used in find_matches equalization_qconfig_map = generate_qconfig_map( - model, modules, model.graph, equalization_config, node_name_to_scope) + model, modules, model.graph, _equalization_config, node_name_to_scope) qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_mapping, node_name_to_scope) + _validate_fixed_qparams_qconfigs(model, qconfig_map) # match the patterns that will get quantized standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) diff --git a/torch/ao/quantization/fx/tracer.py b/torch/ao/quantization/fx/tracer.py new file mode 100644 index 0000000000000..732c8b9575550 --- /dev/null +++ b/torch/ao/quantization/fx/tracer.py @@ -0,0 +1,119 @@ +import torch +from torch.fx._symbolic_trace import Tracer +from torch.fx.node import Target, Node, Argument +from torch.nn.intrinsic import _FusedModule +from typing import List, Callable, Tuple, Any, Dict, Optional + +__all__ = [ + "QuantizationTracer", +] + +class Scope(object): + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +class ScopeContextManager(object): + """ A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + + def __init__( + self, scope: Scope, current_module: torch.nn.Module, current_module_path: str + ): + super().__init__() + self.prev_module_type = scope.module_type + self.prev_module_path = scope.module_path + self.scope = scope + self.scope.module_path = current_module_path + self.scope.module_type = type(current_module) + + def __enter__(self): + return + + def __exit__(self, *args): + self.scope.module_path = self.prev_module_path + self.scope.module_type = self.prev_module_type + return + +class QuantizationTracer(Tracer): + def __init__( + self, skipped_module_names: List[str], skipped_module_classes: List[Callable] + ): + super().__init__() + self.skipped_module_names = skipped_module_names + self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + self.record_stack_traces = True + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + return ( + ( + m.__module__.startswith("torch.nn") + and not isinstance(m, torch.nn.Sequential) + ) + or module_qualified_name in self.skipped_module_names + or type(m) in self.skipped_module_classes + or isinstance(m, _FusedModule) + ) + + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + module_qualified_name = self.path_of_module(m) + # Creating scope with information of current module + # scope will be restored automatically upon exit + with ScopeContextManager(self.scope, m, module_qualified_name): + return super().call_module(m, forward, args, kwargs) + + def create_node( + self, + kind: str, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + return node diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 3279a273c751d..cbf4b77e20686 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -1,3 +1,4 @@ +import copy import re import torch import torch.nn as nn @@ -11,6 +12,7 @@ Graph, Node, ) +from .custom_config import PrepareCustomConfig from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type from collections import namedtuple @@ -47,6 +49,7 @@ "quantize_node", "return_arg_list", "WEIGHT_INDEX_DICT", + "get_skipped_module_name_and_classes", ] @@ -404,7 +407,9 @@ def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str get_new_attr_name = get_new_attr_name_with_prefix(prefix) attr_name = get_new_attr_name(module) device = assert_and_get_unique_device(module) - module.register_buffer(attr_name, torch.tensor(value, device=device)) + new_value = value.clone().detach() if isinstance(value, torch.Tensor) \ + else torch.tensor(value, device=device) + module.register_buffer(attr_name, new_value) # Create get_attr with value attr_node = graph.create_node("get_attr", attr_name) return attr_node @@ -624,3 +629,16 @@ def create_node_from_old_node_preserve_meta( new_node = quantized_graph.create_node(*create_node_args) new_node.stack_trace = old_node.stack_trace return new_node + +def get_skipped_module_name_and_classes( + prepare_custom_config: PrepareCustomConfig, + is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]: + skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) + skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes) + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + skipped_module_names += list(prepare_custom_config.standalone_module_names.keys()) + skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys()) + skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping) + + return skipped_module_names, skipped_module_classes diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 7e86a39f1b177..f407a505c429a 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -15,6 +15,40 @@ from torch.ao.quantization.utils import check_min_max_valid, calculate_qmin_qmax +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + "FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", +] + + class _PartialWrapper(object): def __init__(self, p): self.p = p diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 363b01b9fc9e4..a4766ee8bdaee 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Optional, Any +from typing import Optional, Any, Union import torch import torch.nn as nn @@ -21,6 +21,7 @@ ) from .observer import ( + _PartialWrapper, HistogramObserver, MovingAverageMinMaxObserver, NoopObserver, @@ -37,9 +38,48 @@ weight_observer_range_neg_127_to_127, per_channel_weight_observer_range_neg_127_to_127, default_reuse_input_observer, + ObserverBase, ) import warnings - +import copy + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "assert_valid_qconfig", + "add_module_to_qconfig_obs_ctr", + "QConfigAny", + "obs_or_fq_ctr_equals", + "qconfig_equals", + "activation_is_memoryless", + "is_reuse_input_qconfig", +] class QConfig(namedtuple('QConfig', ['activation', 'weight'])): """ @@ -415,16 +455,33 @@ def configure_constructor_to_put_obs_on_module_device(original_constructor): return QConfig(activation, weight) -def qconfig_equals(q1: QConfigAny, q2: QConfigAny): +_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, ObserverBase, FakeQuantizeBase] + +def obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): """ - Returns `True` if `q1` equals `q2`, and `False` otherwise. + Return whether the two partial wrappers are equal, """ # functools.partial has no __eq__ operator defined so '==' defaults to 'is' - def partial_equals(p1, p2): - same = p1.func == p2.func - same = same and p1.args == p2.args - return same and p1.keywords == p2.keywords + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + keywords_equal = keywords_equal and obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"]) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ if q1 is None or q2 is None: return q1 == q2 else: @@ -433,15 +490,8 @@ def partial_equals(p1, p2): # Qconfig weight and activation can be either a partial wrapper, # or an observer class. Special handling is required (above) for # comparing partial wrappers. - if(isinstance(q1.activation, torch.ao.quantization.observer._PartialWrapper)): - activation_same = partial_equals(q1.activation.p, q2.activation.p) - else: - activation_same = q1.activation == q2.activation - if(isinstance(q1.weight, torch.ao.quantization.observer._PartialWrapper)): - weight_same = partial_equals(q1.weight.p, q2.weight.p) - else: - weight_same = q1.weight == q2.weight - + activation_same = obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = obs_or_fq_ctr_equals(q1.weight, q2.weight) return activation_same and weight_same except AttributeError: return q1 == q2 diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 7caa6f9216cf3..b459722286907 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -4,8 +4,16 @@ import torch -from .fake_quantize import default_weight_fake_quant -from .observer import default_weight_observer +from .fake_quantize import ( + default_weight_fake_quant, + FixedQParamsFakeQuantize, +) +from .observer import ( + _PartialWrapper, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + default_weight_observer, +) from .qconfig import ( default_reuse_input_qconfig, get_default_qconfig, @@ -29,8 +37,24 @@ MODULE_NAME_DICT_KEY = "module_name" MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" - -def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int): +_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = { + torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, + torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, + "hardsigmoid": default_fixed_qparams_range_0to1_observer, + "hardsigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer, + torch.sigmoid: default_fixed_qparams_range_0to1_observer, + "sigmoid": default_fixed_qparams_range_0to1_observer, + "sigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Softmax: default_fixed_qparams_range_0to1_observer, + torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer, + torch.tanh: default_fixed_qparams_range_neg1to1_observer, + "tanh": default_fixed_qparams_range_neg1to1_observer, + "tanh_": default_fixed_qparams_range_neg1to1_observer, +} + + +def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping: """ Return the default QConfigMapping for the given quantization type and backend. """ @@ -38,18 +62,18 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int): qconfig = get_default_qat_qconfig(backend, version) else: qconfig = get_default_qconfig(backend, version) + default_weight = default_weight_fake_quant if is_qat else default_weight_observer # default_per_channel_weight_observer is not currently compatible with fbgemm backend # so we have to modify the weight observer to default_weight_observer or another # per tensor supported observer. # see https://github.com/pytorch/pytorch/issues/47535 if backend == "fbgemm": - default_weight = default_weight_fake_quant if is_qat else default_weight_observer qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight) else: qconfig_transpose = qconfig - return QConfigMapping() \ + qconfig_mapping = QConfigMapping() \ .set_global(qconfig) \ .set_object_type("reshape", default_reuse_input_qconfig) \ .set_object_type(torch.nn.Conv1d, qconfig) \ @@ -73,13 +97,29 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int): .set_object_type(torch.nn.BatchNorm2d, qconfig) \ .set_object_type(torch.nn.BatchNorm3d, qconfig) -def get_default_qconfig_mapping(backend="fbgemm", version=0): + # Use special observers for ops with fixed qparams + fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer] + else: + if is_qat: + activation = FixedQParamsFakeQuantize.with_args(observer=observer) + else: + activation = observer + fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight) + fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig) + + return qconfig_mapping + +def get_default_qconfig_mapping(backend="fbgemm", version=0) -> QConfigMapping: """ Return the default QConfigMapping for post training quantization. """ return _get_default_qconfig_mapping(False, backend, version) -def get_default_qat_qconfig_mapping(backend="fbgemm", version=1): +def get_default_qat_qconfig_mapping(backend="fbgemm", version=1) -> QConfigMapping: """ Return the default QConfigMapping for quantization aware training. """ diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index ebaa693c74775..8b192b6ffd670 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -71,6 +71,7 @@ nn.Linear: nnq.Linear, nn.ReLU6: nnq.ReLU6, nn.Dropout: nnq.Dropout, + nn.PReLU: nnq.PReLU, # Wrapper Modules: nnq.FloatFunctional: nnq.QFunctional, # Intrinsic modules: diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index f5aa195c94dd9..8a2d0679aa1c6 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -28,6 +28,22 @@ activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations +_DEFAULT_CUSTOM_CONFIG_DICT = { + 'float_to_observed_custom_module_class': { + nn.LSTM: nn.quantizable.LSTM, + nn.MultiheadAttention: nn.quantizable.MultiheadAttention, + }, + 'observed_to_quantized_custom_module_class': { + nn.quantizable.LSTM: nn.quantized.LSTM, + nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, + } +} + +def get_default_custom_config_dict(): + r"""Defines the default custom config dict. + """ + return _DEFAULT_CUSTOM_CONFIG_DICT + def is_activation_post_process(module): return (isinstance(module, torch.ao.quantization.ObserverBase) or isinstance(module, torch.ao.quantization.FakeQuantizeBase)) @@ -261,7 +277,7 @@ def prepare(model, inplace=False, allow_list=None, """ torch._C._log_api_usage_once("quantization_api.quantize.prepare") if prepare_custom_config_dict is None: - prepare_custom_config_dict = {} + prepare_custom_config_dict = get_default_custom_config_dict() custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) if not inplace: @@ -543,7 +559,7 @@ def _convert( mapping = get_default_static_quant_reference_module_mappings() if is_reference \ else get_default_static_quant_module_mappings() if convert_custom_config_dict is None: - convert_custom_config_dict = {} + convert_custom_config_dict = get_default_custom_config_dict() custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {}) if not inplace: diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 71ceedb04e426..fd09a49cca8bc 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -1,12 +1,9 @@ -import copy -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union import warnings import torch from torch.fx import GraphModule -from torch.fx._symbolic_trace import Tracer -from torch.fx.node import Target, Node, Argument -from torch.nn.intrinsic import _FusedModule +from .fx.tracer import QuantizationTracer from .fx import fuse # noqa: F401 from .fx import prepare # noqa: F401 from .fx.convert import convert @@ -19,6 +16,7 @@ ) from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 +from .fx.utils import get_skipped_module_name_and_classes from .qconfig_mapping import QConfigMapping def _check_is_graph_module(model: torch.nn.Module) -> None: @@ -118,76 +116,19 @@ def __exit__(self, *args): return -class QuantizationTracer(Tracer): - def __init__( - self, skipped_module_names: List[str], skipped_module_classes: List[Callable] - ): - super().__init__() - self.skipped_module_names = skipped_module_names - self.skipped_module_classes = skipped_module_classes - # NB: initialized the module_type of top level module to None - # we are assuming people won't configure the model with the type of top level - # module here, since people can use "" for global config - # We can change this if there is a use case that configures - # qconfig using top level module type - self.scope = Scope("", None) - self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} - self.record_stack_traces = True - - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - return ( - ( - m.__module__.startswith("torch.nn") - and not isinstance(m, torch.nn.Sequential) - ) - or module_qualified_name in self.skipped_module_names - or type(m) in self.skipped_module_classes - or isinstance(m, _FusedModule) - ) - - def call_module( - self, - m: torch.nn.Module, - forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - ) -> Any: - module_qualified_name = self.path_of_module(m) - # Creating scope with information of current module - # scope will be restored automatically upon exit - with ScopeContextManager(self.scope, m, module_qualified_name): - return super().call_module(m, forward, args, kwargs) - - def create_node( - self, - kind: str, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - node = super().create_node(kind, target, args, kwargs, name, type_expr) - self.node_name_to_scope[node.name] = ( - self.scope.module_path, - self.scope.module_type, - ) - return node - - def _prepare_fx( model: torch.nn.Module, qconfig_mapping: Union[QConfigMapping, Dict[str, Any]], is_qat: bool, example_inputs: Tuple[Any, ...], prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None, - equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None, + _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, is_standalone_module: bool = False, ) -> ObservedGraphModule: r""" Internal helper function for prepare_fx Args: - `model`, `qconfig_mapping`, `prepare_custom_config`, `equalization_config`: + `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`: see docs for :func:`~torch.ao.quantization.prepare_fx` `is_standalone_module`: a boolean flag indicates whether we are quantizing a standalone module or not, a standalone module @@ -198,8 +139,8 @@ def _prepare_fx( """ if prepare_custom_config is None: prepare_custom_config = PrepareCustomConfig() - if equalization_config is None: - equalization_config = QConfigMapping() + if _equalization_config is None: + _equalization_config = QConfigMapping() if isinstance(prepare_custom_config, Dict): warnings.warn( @@ -207,20 +148,13 @@ def _prepare_fx( "in a future version. Please pass in a PrepareCustomConfig instead.") prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) - skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) - skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes) - # swap FloatFunctional with FXFloatFunctional _swap_ff_with_fxff(model) - # symbolically trace the model - if not is_standalone_module: - # standalone module and custom module config are applied in top level module - skipped_module_names += list(prepare_custom_config.standalone_module_names.keys()) - skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys()) - skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping) - + skipped_module_names, skipped_module_classes = \ + get_skipped_module_name_and_classes(prepare_custom_config, is_standalone_module) preserved_attributes = prepare_custom_config.preserved_attributes + # symbolically trace the model tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type] graph_module = GraphModule(model, tracer.trace(model)) for attr_name in preserved_attributes: @@ -238,7 +172,7 @@ def _prepare_fx( tracer.node_name_to_scope, example_inputs=example_inputs, prepare_custom_config=prepare_custom_config, - equalization_config=equalization_config, + _equalization_config=_equalization_config, backend_config_dict=backend_config_dict, is_standalone_module=is_standalone_module, ) # type: ignore[operator] @@ -337,7 +271,7 @@ def prepare_fx( qconfig_mapping: Union[QConfigMapping, Dict[str, Any]], example_inputs: Tuple[Any, ...], prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None, - equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None, + _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, ) -> ObservedGraphModule: r""" Prepare a model for post training static quantization @@ -379,7 +313,7 @@ def prepare_fx( .set_output_quantized_indexes([0]) \ .set_preserved_attributes(["attr1", "attr2"]) - * `equalization_config`: config for specifying how to perform equalization on the model + * `_equalization_config`: config for specifying how to perform equalization on the model * `backend_config_dict`: a dictionary that specifies how operators are quantized in a backend, this includes how the operaetors are observed, @@ -417,7 +351,7 @@ def calibrate(model, data_loader): False, # is_qat example_inputs, prepare_custom_config, - equalization_config, + _equalization_config, backend_config_dict, ) @@ -512,7 +446,6 @@ def _convert_fx( def convert_fx( graph_module: GraphModule, - is_reference: bool = False, convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, _remove_qconfig: bool = True, qconfig_mapping: Union[QConfigMapping, Dict[str, Any]] = None, @@ -568,8 +501,54 @@ def convert_fx( torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") return _convert_fx( graph_module, - is_reference, - convert_custom_config, + is_reference=False, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config_dict=backend_config_dict, + ) + + +def convert_to_reference_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any]] = None, + backend_config_dict: Dict[str, Any] = None, +) -> torch.nn.Module: + r""" Convert a calibrated or trained model to a reference quantized model, a common interface + between PyTorch quantization with other backends like accelerators. Callers should additionally + lower the returned reference model to the target backend before using the model for inference. + + Args: + * `graph_module`: A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config`: custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + + * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping`: config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + + * `backend_config_dict`: A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + + Return: + A reference quantized model (GraphModule) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + reference_model = convert_to_reference_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, _remove_qconfig=_remove_qconfig, qconfig_mapping=qconfig_mapping, backend_config_dict=backend_config_dict, diff --git a/torch/ao/sparsity/__init__.py b/torch/ao/sparsity/__init__.py index 2aac11434b0bf..715eaa6b8c6dd 100644 --- a/torch/ao/sparsity/__init__.py +++ b/torch/ao/sparsity/__init__.py @@ -15,18 +15,4 @@ from .sparsifier.utils import FakeSparsity from .sparsifier.utils import module_to_fqn from .sparsifier.utils import fqn_to_module - -# === Experimental === - -# Parametrizations -from .experimental.pruner.parametrization import PruningParametrization -from .experimental.pruner.parametrization import ZeroesParametrization -from .experimental.pruner.parametrization import ActivationReconstruction -from .experimental.pruner.parametrization import BiasHook - -# Pruner -from .experimental.pruner.base_pruner import BasePruner - -# Data Sparsifier -from .experimental.data_sparsifier.base_data_sparsifier import BaseDataSparsifier -from .experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier +from .sparsifier.utils import get_arg_info_from_tensor_fqn diff --git a/torch/ao/sparsity/experimental/data_sparsifier/__init__.py b/torch/ao/sparsity/_experimental/__init__.py similarity index 100% rename from torch/ao/sparsity/experimental/data_sparsifier/__init__.py rename to torch/ao/sparsity/_experimental/__init__.py diff --git a/torch/ao/sparsity/_experimental/activation_sparsifier/README.md b/torch/ao/sparsity/_experimental/activation_sparsifier/README.md new file mode 100644 index 0000000000000..3c2514c2f116b --- /dev/null +++ b/torch/ao/sparsity/_experimental/activation_sparsifier/README.md @@ -0,0 +1,106 @@ +# Activation Sparsifier + +## Introduction +Activation sparsifier attaches itself to a layer(s) in the model and prunes the activations passing through them. **Note that the layer weights are not pruned here.** + +## How does it work? +The idea is to compute a mask to prune the activations. To compute the mask, we need a representative tensor that generalizes activations coming from all the batches in the dataset. + +There are 3 main steps involved: +1. **Aggregation**: The activations coming from inputs across all the batches are aggregated using a user-defined `aggregate_fn`. +A simple example is the add function. +2. **Reduce**: The aggregated activations are then reduced using a user-defined `reduce_fn`. A simple example is average. +3. **Masking**: The reduced activations are then passed into a user-defined `mask_fn` to compute the mask. + +Essentially, the high level idea of computing the mask is + +``` +>>> aggregated_tensor = aggregate_fn([activation for activation in all_activations]) +>>> reduced_tensor = reduce_fn(aggregated_tensor) +>>> mask = mask_fn(reduced_tensor) +``` + +*The activation sparsifier also supports per-feature/channel sparsity. This means that a desired set of features in an activation can be also pruned. The mask will be stored per feature.* + +``` +>>> # when features = None, mask is a tensor computed on the entire activation tensor +>>> # otherwise, mask is a list of tensors of length = len(features), computed on each feature of activations +>>> +>>> # On a high level, this is how the mask is computed if features is not None +>>> for i in range(len(features)): +>>> aggregated_tensor_feature = aggregate_fn([activation[features[i]] for activation in all_activations]) +>>> mask[i] = mask_fn(reduce_fn(aggregated_tensor_feature)) +``` + +## Implementation Details +The activation sparsifier attaches itself to a set of layers in a model and then attempts to sparsify the activations flowing through them. *Attach* means registering a [`forward_pre_hook()`](https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#register_forward_pre_hook) to the layer. + +Let's go over the 3 steps again - +1. **Aggregation**: The activation of aggregation happens by attaching a hook to the layer that specifically applies and stores the aggregated data. The aggregation happens per feature, if the features are specified, otherwise it happens on the entire tensor. +The `aggregate_fn` should accept two input tensors and return an aggregated tensor. Example: +``` +def aggregate_fn(tensor1, tensor2): + return tensor1 + tensor2 +``` + +2. **Reduce**: This is initiated once the `step()` is called. The `reduce_fn()` is called on the aggregated tensor. The goal is to squash the aggregated tensor. +The `reduce_fn` should accept one tensor as argument and return a reduced tensor. Example: +``` +def reduce_fn(agg_tensor): + return agg_tensor.mean(dim=0) +``` + +3. **Masking**: The computation of the mask happens immediately after the reduce operation. The `mask_fn()` is applied on the reduced tensor. Again, this happens per-feature, if the features are specified. +The `mask_fn` should accept a tensor (reduced) and sparse config as arguments and return a mask (computed using tensor according to the config). Example: +``` +def mask_fn(tensor, threshold): # threshold is the sparse config here + mask = torch.ones_like(tensor) + mask[torch.abs(tensor) < threshold] = 0.0 + return mask +``` + +## API Design +`ActivationSparsifier`: Attaches itself to a model layer and sparsifies the activation flowing through that layer. The user can pass in the default `aggregate_fn`, `reduce_fn` and `mask_fn`. Additionaly, `features` and `feature_dim` are also accepted. + +`register_layer`: Registers a layer for sparsification. Specifically, registers `forward_pre_hook()` that performs aggregation. + +`step`: For each registered layer, applies the `reduce_fn` on aggregated activations and then applies `mask_fn` after reduce operation. + +`squash_mask`: Unregisters aggregate hook that was applied earlier and registers sparsification hooks if `attach_sparsify_hook=True`. Sparsification hooks applies the computed mask to the activations before it flows into the registered layer. + +## Example + +``` +# Fetch model +model = SomeModel() + +# define some aggregate, reduce and mask functions +def aggregate_fn(tensor1, tensor2): + return tensor1 + tensor2 + +def reduce_fn(tensor): + return tensor.mean(dim=0) + +def mask_fn(data, threshold): + mask = torch.ones_like(tensor) + mask[torch.abs(tensor) < threshold] = 0.0 + return mask) + +# sparse config +default_sparse_config = {"threshold": 0.5} + +# define activation sparsifier +act_sparsifier = ActivationSparsifier(model=model, aggregate_fn=aggregate_fn, reduce_fn=reduce_fn, mask_fn=mask_fn, **threshold) + +# register some layer to sparsify their activations +act_sparsifier.register_layer(model.some_layer, threshold=0.8) # custom sparse config + +for epoch in range(EPOCHS): + for input, target in dataset: + ... + out = model(input) + ... + act_sparsifier.step() # mask is computed + +act_sparsifier.squash_mask(attach_sparsify_hook=True) # activations are multiplied with the computed mask before flowing through the layer +``` diff --git a/torch/ao/sparsity/experimental/pruner/__init__.py b/torch/ao/sparsity/_experimental/activation_sparsifier/__init__.py similarity index 100% rename from torch/ao/sparsity/experimental/pruner/__init__.py rename to torch/ao/sparsity/_experimental/activation_sparsifier/__init__.py diff --git a/torch/ao/sparsity/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/sparsity/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000..ae9522ce8cc77 --- /dev/null +++ b/torch/ao/sparsity/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,416 @@ +from typing import Dict, Any, List +import torch +from collections import defaultdict +from torch import nn +import copy +from ...sparsifier.utils import fqn_to_module, module_to_fqn +import warnings + +__all__ = ['ActivationSparsifier'] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. + The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor. + Example + >>> def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + >>> def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is passed in the + register_layer(). + Note that the mask_fn() definition should contain the sparse arguments that is passed in sparse_config + arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + >>> mask = [mask_fn(reduce_fn(aggregated_fn(input[feature])) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + with the mask_fn() + + Example: + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn) + >>> + >>> # start training process + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + def __init__(self, model: nn.Module, aggregate_fn=None, reduce_fn=None, mask_fn=None, + features=None, feature_dim=None, **sparse_config): + self.model = model + self.defaults: Dict[str, Any] = defaultdict() + self.defaults['sparse_config'] = sparse_config + + # functions + self.defaults['aggregate_fn'] = aggregate_fn + self.defaults['reduce_fn'] = reduce_fn + self.defaults['mask_fn'] = mask_fn + + # default feature and feature_dim + self.defaults['features'] = features + self.defaults['feature_dim'] = feature_dim + + self.data_groups: Dict[str, Dict] = defaultdict(dict) # contains all relevant info w.r.t each registered layer + + self.state: Dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly + """ + + # if features are not None, then feature_dim must not be None + features, feature_dim = args['features'], args['feature_dim'] + if features is not None: + assert feature_dim is not None, "need feature dim to select features" + + # all the *_fns should be callable + fn_keys = ['aggregate_fn', 'reduce_fn', 'mask_fn'] + for key in fn_keys: + fn = args[key] + assert callable(fn), 'function should be callable' + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through. + """ + + # gather some data + feature_dim = self.data_groups[name]['feature_dim'] + features = self.data_groups[name]['features'] + agg_fn = self.data_groups[name]['aggregate_fn'] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get('data') # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]['mask'] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [0 for _ in range(0, len(features))] # create one incase of 1st forward + self.state[name]['mask'] = [0 for _ in range(0, len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = torch.Tensor([features[feature_idx]]).long().to(input_data.device) + data_feature = torch.index_select(input_data, feature_dim, feature_tensor) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]['mask'][feature_idx] = torch.ones_like(data_feature) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]['data'] = out_data + return hook + + def register_layer(self, layer: nn.Module, aggregate_fn=None, reduce_fn=None, + mask_fn=None, features=None, feature_dim=None, **sparse_config): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. + + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the model" # satisfy mypy + + if name in self.data_groups: # unregister layer if already present + warnings.warn("layer already attached to the sparsifier, deregistering the layer and registering with new config") + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + 'aggregate_fn': aggregate_fn, + 'reduce_fn': reduce_fn, + 'mask_fn': mask_fn, + 'features': features, + 'feature_dim': feature_dim, + 'layer': layer + } + local_args.update((arg, val) for arg, val in update_dict.items() if val is not None) + local_args['sparse_config'].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]['mask'] = None # mask will be created when model forward is called. + + # attach agg hook + self.data_groups[name]['hook'] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: str = None, layer: nn.Module = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + assert name is not None or layer is not None, "Need at least name or layer obj to retrieve mask" + + if name is None: + assert layer is not None + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the specified model" + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get('mask', None) + + if mask is None: + raise ValueError("Error: shape unknown, call layer() routine at least once to infer mask") + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer + """ + + # detach any hooks attached + self.data_groups[name]['hook'].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer + """ + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs['data'] + self.update_mask(name, data, configs) + + self.data_groups[name].pop('data') # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs['sparse_config'] + features = configs['features'] + reduce_fn = configs['reduce_fn'] + mask_fn = configs['mask_fn'] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer + """ + mask = self.get_mask(name) + features = self.data_groups[name]['features'] + feature_dim = self.data_groups[name]['feature_dim'] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(0, len(features)): + feature = torch.Tensor([features[feature_idx]]).long().to(input_data.device) + sparsified = torch.index_select(input_data, feature_dim, feature) * mask[feature_idx] + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggreagate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs['hook'].remove() + configs.pop('hook') + self.data_groups[name]['hook_state'] = "None" + if attach_sparsify_hook: + configs['hook'] = configs['layer'].register_forward_pre_hook(self._sparsify_hook(name)) + configs['hook_state'] = "sparsify" # signals that sparsify hook is now attached + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: Dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = {key: value for key, value in config.items() if key not in ['hook', 'layer']} + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for _, state in states.items(): + if state['mask'] is not None: + if isinstance(state['mask'], List): + for idx in range(len(state['mask'])): + if sparse_coo: + state['mask'][idx] = state['mask'][idx].to_sparse_coo() + else: + state['mask'][idx] = state['mask'][idx].to_dense() + else: + if sparse_coo: + state['mask'] = state['mask'].to_sparse_coo() + else: + state['mask'] = state['mask'].to_dense() + return states + + def state_dict(self) -> Dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + 'state': state, + 'data_groups': data_groups, + 'defaults': self.defaults + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict['state'] + data_groups, defaults = state_dict['data_groups'], state_dict['defaults'] + + self.__set_state__({'state': state, 'data_groups': data_groups, 'defaults': defaults}) + + def __get_state__(self) -> Dict[str, Any]: + + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + 'defaults': self.defaults, + 'state': state, + 'data_groups': data_groups, + } + + def __set_state__(self, state: Dict[str, Any]) -> None: + state['state'] = self._convert_mask(state['state'], sparse_coo=False) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + assert layer is not None # satisfy mypy + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config['hook_state'] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config['layer'] = layer + config['hook'] = hook + + def __repr__(self): + format_string = self.__class__.__name__ + ' (' + for name, config in self.data_groups.items(): + format_string += '\n' + format_string += '\tData Group\n' + format_string += f'\t name: {name}\n' + for key in sorted(config.keys()): + if key in ['data', 'hook', 'reduce_fn', 'mask_fn', 'aggregate_fn']: + continue + format_string += f'\t {key}: {config[key]}\n' + format_string += ')' + return format_string diff --git a/torch/ao/sparsity/_experimental/data_scheduler/README.md b/torch/ao/sparsity/_experimental/data_scheduler/README.md new file mode 100644 index 0000000000000..03e68c7253873 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_scheduler/README.md @@ -0,0 +1,66 @@ +# Data Scheduler +## Intro +The data scheduler is used to control the update of the data sparsification parameters and works specifically with the data sparsifier class. +This class controls a specific config param (specified by the `schedule_param` argument) of +the data sparsifier class and varies it across the training process (or across time). + +## API details +`BaseDataScheduler`: base class with abstract method `get_schedule_param` that computes the data sparsification parameter for all the data. The constructor accepts +1. `data_sparsifier`: The data sparsifier object whose parameter will be scheduled. +2. `schedule_param` : a specific config of the passed data sparsifier that needs to be scheduled/varied. + +`get_last_param`: gets the last scheduled parameter. Basically, a dictionary of name (of data) to schedule_param value mapping. + +`step`: Applies the `get_schedule_param` logic every epoch/step depending on when it is called. This should always be called after the `sparsifier.step()` has been called. + +## Write your own data scheduler +The custom data scheduler must be inherit from the `BaseDataScheduler` class and should have the `get_schedule_param()` function implemented. For example, that gradually multiplies the sparsity level by `gamma` every epoch. +It also takes an argument `threshold_sl` which when reached does not increase further. + +``` +class GammaScheduler(BaseDataScheduler): + def __init__(self, data_sparsifier, gamma, threshold_sl): + super().__init__(data_sparsifier, "sparsity_level") + self.gamma = gamma + self.threshold_sl = threshold_sl + + def get_schedule_param(self): + if self.last_epoch > 0: + return {name: min(self.threshold_sl, config["sparsity_level"] * self.gamma) for name, config in self.data_sparsifier.data_groups.items()} + else: + return {name: 0.0 for name, config in self.data_sparsifier.data_groups.items()} +``` + +## Using data scheduler with data sparsifier +Suppose the need is to vary data sparsity levels (or any sparsity `param`) during training, then a custom data scheduler can be implemented and used along with the data sparsifier. + +Example: + +``` +model = SomeModel() +optimizer = SomeOptimizer(model.parameters(), lr=...) +data_sparsifier = SomeDataSparsifier(...) + + +data_scheduler = SomeDataScheduler(data_sparsifier, ...) + + +data_name = 'train_data' + +for epoch in range(EPOCHS): + for input, target in dataset: + input = data_sparsifier.add_data(name=data_name, data=input) + + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + data_sparsifier.step() + + data_scheduler.step() +``` + +### Note: +1. `get_schedule_param()` should return a dictionary wherein the keys are the names of the data and the values are the corresponding values of the `schedule_param` for the next step. +2. It is the responsibility of the `BaseDataScheduler` to call the `get_schedule_param()` when necessary. diff --git a/torch/ao/sparsity/_experimental/data_scheduler/__init__.py b/torch/ao/sparsity/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000..4f7a6f98e2d12 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_scheduler/__init__.py @@ -0,0 +1,5 @@ +from .base_data_scheduler import BaseDataScheduler + +__all__ = [ + "BaseDataScheduler", +] diff --git a/torch/ao/sparsity/_experimental/data_scheduler/base_data_scheduler.py b/torch/ao/sparsity/_experimental/data_scheduler/base_data_scheduler.py new file mode 100644 index 0000000000000..7d4859743ef84 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_scheduler/base_data_scheduler.py @@ -0,0 +1,180 @@ +from functools import wraps +import weakref +import abc +import warnings + +from ..data_sparsifier import BaseDataSparsifier + +__all__ = ['BaseDataScheduler'] + + +class BaseDataScheduler(object): + r""" + The BaseDataScheduler is the abstract scheduler class specifically for the + BaseDataSparsifier class. This class controls a specific hyperparameter of + the sparsifier class and varies it across the training process (or across time). + + Args: + data_sparsifier (instance of BaseDataSparsifier) + Implemented class data sparsifier class wherein the update_mask is implemented + schedule_param (str) + A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied + last_epoch (int, default=-1) + This is specifically is passed when training needs to be resumed from a particular + point. + verbose (bool, default=False) + Verbosity of the BaseDataScheduler + + The *get_hyperparam()* function needs to be implemented by the user. + """ + def __init__(self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False): + # Attach sparsifier + if not isinstance(data_sparsifier, BaseDataSparsifier): + raise TypeError('{} is not an instance of torch.ao.sparsity.BaseDataSparsifier'.format( + type(data_sparsifier).__name__)) + self.data_sparsifier = data_sparsifier + self.schedule_param = schedule_param + + # Initialize epoch and base hyper-params + self.base_param = { + name: config.get(schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.data_sparsifier.step = with_counter(self.data_sparsifier.step) # type: ignore[assignment] + self.data_sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sp_called_within_step: bool = False # sp -> schedule parameter + self.step() + + @abc.abstractmethod + def get_schedule_param(self): + r""" + Abstract method that needs to be implemented by the child class. + The expected return type should is a dictionary of name to schedule_param value + The returned values will be updated in sparsifier when the scheduler step() function + is called. + + Example: + >>> def get_schedule_param(self): + new_param = {} + for name in self.sparsifier.data_groups.keys(): + new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + return new_param + + When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] + would be halved + """ + raise NotImplementedError + + def __repr__(self): + format_string = self.__class__.__name__ + ' (' + format_string += '\n' + format_string += 'Data Sparsifier {0}\n'.format(self.data_sparsifier) + format_string += ' {0}: {1}\n'.format(self.schedule_param, self.base_param) + format_string += ')' + return format_string + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + + Note: + The scheduler class does not track the state of the data_sparsifier. + Make sure to store the state of the sparsifier before storing the + state of the scheduler + """ + return {key: value for key, value in self.__dict__.items() if key != 'data_sparsifier'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Note: + Remember to restore the state of the data_sparsifier before the scheduler. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_param(self): + return self._last_param + + def step(self): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.data_sparsifier.step, "_with_counter"): + warnings.warn("Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `data_sparsifier.step()` before " + "`scheduler.step()`.", UserWarning) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn("Detected call of `scheduler.step()` before `data_sparsifier.step()`. " + "You have to make sure you run the data_sparsifier.step() BEFORE any " + "calls to the scheduer.step().", UserWarning) + self._step_count += 1 + + class _enable_get_sp_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sp_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sp_called_within_step = False + + with _enable_get_sp_call(self): + self.last_epoch += 1 + updated_scheduler_params = self.get_schedule_param() + + for name, param in updated_scheduler_params.items(): + self.data_sparsifier.data_groups[name][self.schedule_param] = param + if self.verbose: + print(f"Adjusting {self.schedule_param} for group {name} to {param}") + + self._last_param = { + name: config.get(self.schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + self.data_sparsifier.enable_mask_update = True diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/README.md b/torch/ao/sparsity/_experimental/data_sparsifier/README.md new file mode 100644 index 0000000000000..c6fc99b36c8c4 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/README.md @@ -0,0 +1,145 @@ +# Data Sparsifier +## Intro +The data sparsifier inherits from the `BaseSparsifier` class. It attempts to sparsify data tensors in general (trainable and non-trainable). + +## Implementation Details +The data sparsifier does not receive a model or a layer to sparsify. Hence, the mask needs to be owned by the data sparsifier. This is acheived by introducing a private container model that registers the data as a parametrized buffer. + +The BaseDataSparsifier handles all the housekeeping while allowing the user to just implement the `update_mask` logic in their implementation. + +## Supported data +1. torch tensors (torch.Tensor) +2. parameters (nn.Parameter) +3. embedding and embedding bags (nn.Embeddings / nn.EmbeddingBag) + +## API details +`BaseDataSparsifier`: base class with abstract method `update_mask` that computes the new mask for all the data. + +`add_data`: Accepts name, data tuple and registers the data as a parametrized buffer inside the container model. Note that the data is always associated to a name. A custom sparse config can be provided along with the name, data pair. If not provided, the default config will be applied while doing the sparsification. +If the named data already exists, then it is replaced with the new data. The config and mask will be retained for the new data unless not specified to. +To not the old mask, set `reuse_mask=False`. If the `config` is explicitly passed in, it will be updated. + +**Note**: name containing '.' is not a valid name for the data sparsifier + +``` +data_sparsifier = ImplementedDataSparsifier() +data_sparsifier.add_data(name=name, data=data, **some_config) +``` + +`step`: applies the update_mask() logic to all the data. + +``` +data_sparsifier.step() +``` + +`get_mask`: retrieves the mask given the name of the data. + +`get_data`: retrieves the data given the `name` argument. Accepts additional argument `return_original` which when set to `True` does not apply the mask while returning +the data tensor. Example: + +``` +original_data = data_sparsifier.get_data(name=name, return_original=True) # returns data with no mask applied +sparsified_data = data_sparsifier.get_data(name=name, return_original=False) # returns data * mask +``` + +`squash_mask`: removes the parametrizations on the data and applies mask to the data when `leave_parametrized=True`.Also, accepts list of strings to squash mask for. If none, squashes mask for all the keys. +``` +data_sparsifier.squash_mask() +``` + +`state_dict`: Returns dictionary that can be serialized. + +## Write your own data sparsifier. +The custom data sparsifier should be inherited from the BaseDataSparsifier class and the `update_mask()` should be implemented. For example, the following data sparsifier zeros out all entries of the tensor smaller than some threshold value. + +``` +class ImplementedDataSparsifier(BaseDataSparsifier): + def __init__(self, threshold): + super().__init__(threshold=threshold) + + def update_mask(self, name, data, threshold): + mask = self.get_mask(name) + mask[torch.abs(data) < threshold] = 0.0 +``` + +## Using Data Sparsifier +### Simple example + +``` +tensor1 = torch.randn(100, 100) +param1 = nn.Parameter(torch.randn(200, 32)) + +my_sparsifier = ImplementedDataSparsifier(threshold=0.2) +my_sparsifier.add_data(name='tensor1', data=tensor1, threshold=0.5) +my_sparsifier.add_data(name='param1', data=param1) + +my_sparsifier.step() # computes mask + +my_sparsifier.squash_mask() # applies and removes mask +``` + +### Sparsifying model embeddings + +``` +class Model(nn.Module): + def __init__(self, feature_dim, emb_dim, num_classes): + self.emb = nn.EmbeddingBag(feature_dim, emb_dim) + self.linear1 = nn.Linear(emb_dim, 32) + self.linear2 = nn.Linear(32, num_classes) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.emb(x) + out = self.relu(self.linear1(out)) + out = self.linear2(out) + return out + +model = Model(100, 32, 10) +my_sparsifier = ImplementedDataSparsifier(threshold=0.5) +my_sparsifier.add_data(name='emb', data=model.emb) + +... +# Train model +... + +my_sparsifier.step() # creates mask for embeddings + +my_sparsifier.squash_mask() # applies and removes mask +``` + +### Using in the context of training data +Sometimes if the input data can be sparsified before sending it to the model, then we can do so by using the data sparsifier. + +The batched input data needs to be attached to the data sparsified before sending it to the model. + +``` +model = SomeModel() + +data_sparsifier = ImplementedDataSparsifier(threshold=0.2) + +data_name = 'train_data' + +for x, y in train_data_loader: + x = data_sparsifier.add_data(name=data_name, data=x) + ... + y_out = model(x) + ... + data_sparsifier.step() + +``` + + +**Note**: +1. It is the responsibility of the `BaseDataSparsifier` to call the `self.update_mask` when appropriate. +2. The mask should be modified in place. + + Some valid inplace operations are: + 1. Change a portion of a mask: `mask[:10] = torch.zeros(10)` + 2. Use an inplace operator: `mask *= another_mask` + 3. Change the underlying data: `mask.data = torch.zeros_like(mask)` + + Non-inplace operations are not valid, and might lead to bugs. For example: + + 1. Reassignment of a mask: `mask = torch.zeros_like(mask)` + 2. Non-inplace arithmetic operations: `mask = mask * another_mask` +3. Data sparsifier `name` argument cannot have a '.' in it. diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/__init__.py b/torch/ao/sparsity/_experimental/data_sparsifier/__init__.py new file mode 100644 index 0000000000000..1701f1eee08e0 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/__init__.py @@ -0,0 +1,7 @@ +from .base_data_sparsifier import BaseDataSparsifier +from .data_norm_sparsifier import DataNormSparsifier + +__all__ = [ + "BaseDataSparsifier", + "DataNormSparsifier", +] diff --git a/torch/ao/sparsity/experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/sparsity/_experimental/data_sparsifier/base_data_sparsifier.py similarity index 77% rename from torch/ao/sparsity/experimental/data_sparsifier/base_data_sparsifier.py rename to torch/ao/sparsity/_experimental/data_sparsifier/base_data_sparsifier.py index d66d33c655290..d59b672e59e1d 100644 --- a/torch/ao/sparsity/experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/sparsity/_experimental/data_sparsifier/base_data_sparsifier.py @@ -4,10 +4,15 @@ from ...sparsifier import base_sparsifier from collections import defaultdict from torch import nn -import warnings import copy from ...sparsifier import utils from torch.nn.utils import parametrize +import sys +import warnings + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") __all__ = ['BaseDataSparsifier'] @@ -68,41 +73,54 @@ def prepare(self): raise NotImplementedError("this function is undefined for this class") def _extract_weight(self, data): - if isinstance(data, torch.Tensor): + # extract the weight parameter instead of underlying data + if type(data) in [torch.Tensor, nn.Parameter]: return data - elif isinstance(data, nn.Parameter): - return data.data elif type(data) in EMBEDDING_TYPES: - return data.weight.data + return data.weight - def add_data(self, name: str, data, **config): - r""" Configures and parametrizes the internal container model with name and data + def add_data(self, name: str, data, reuse_mask=True, **config): + r""" Configures and parametrizes the internal container model with name and data. + + **Note**: + 1. If the data with name already exists, it replaces the data. + 2. While replacing, the old mask is reused when `reuse_mask=True` + 3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data. + 4. By default, the config of the replaced data is used as config for the replacing data, unless something + is specified in the config dictionary. """ assert type(data) in SUPPORTED_TYPES, \ "specified data type not supported at the moment" local_args = copy.deepcopy(self.defaults) local_args.update(config) - self.data_groups[name] = local_args - weight = self._extract_weight(data) # Bookkeeping in the container class mask = local_args.get('mask', torch.ones_like(weight)) - param_class = local_args.get('parametrization', utils.FakeSparsity) # change once public_api for utils is fixed! - param = nn.Parameter(weight, requires_grad=False) + param_class = local_args.get('parametrization', utils.FakeSparsity) if name in self.state: # If the named data already exists - replace warnings.warn("Replacing existing data of the same name. - Did you mean a different name?") - # check if parametrized - if parametrize.is_parametrized(self._container, name): - # If parametrized, squash mask - self.squash_mask(names=[name], leave_parametrized=False) - self._container.get_parameter(name).data = weight # overwrite the data - else: - setattr(self._container, name, param) + + # reuse old config + old_args = self.data_groups[name] + local_args = copy.deepcopy(old_args) + local_args.update(config) + + if reuse_mask: + current_data = self.get_data(name=name) + assert weight.shape == current_data.shape, \ + "to retain the old mask, the shape of the new data must be the same as the previous one" + mask = self.get_mask(name=name) # reuse mask instead of creating a new one + + self._delete_data(name=name) + + # parameter creates a deepcopy of the weight inside, so create a buffer + self._container.register_buffer(name=name, tensor=weight) parametrize.register_parametrization(self._container, name, param_class(mask)) self.state[name]['mask'] = mask + self.data_groups[name] = local_args return getattr(self._container, name) def get_data(self, name: str, return_original: bool = True): @@ -123,6 +141,18 @@ def get_data(self, name: str, return_original: bool = True): else: return getattr(self._container, name) + def _convert_mask(self, states, sparse_coo=True): + r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument. + """ + states = copy.deepcopy(states) + for _, state in states.items(): + if sparse_coo: + state['mask'] = state['mask'].to_sparse_coo() + else: + state['mask'] = state['mask'].to_dense() + + return states + def state_dict(self): r"""Returns the state of the optimizer as a :class:`dict`. @@ -133,8 +163,9 @@ def state_dict(self): * container_state_dict - the state dictionary of the internal container model used for sparsification """ + state = self._convert_mask(self.state) return { - 'state': self.state, + 'state': state, 'data_groups': self.data_groups, '_container': self._container.state_dict() } @@ -166,8 +197,7 @@ def _load_container_from_state(self, states, data_groups, container_state_dict): else: raise RuntimeError(f"Error loading {name}") - param = nn.Parameter(data, requires_grad=False) - setattr(self._container, name, param) + self._container.register_buffer(name=name, tensor=data) if parametrized: # register parameter if parametrized @@ -187,6 +217,8 @@ def load_state_dict(self, state_dict, strict=True): states = copy.deepcopy(state_dict['state']) data_groups = copy.deepcopy(state_dict['data_groups']) container_state_dict = copy.deepcopy(state_dict['_container']) + + states = self._convert_mask(states, sparse_coo=False) # convert sparse coo mask to dense if strict: # if strict load -> then reset container self._container = _Container() @@ -203,14 +235,16 @@ def __setstate__(self, state): if '_container' in state: # If container object is in state then load model container_dict = state.pop('_container') self._container = _Container() + state['state'] = self._convert_mask(state['state'], sparse_coo=False) # convert sparse coo mask to dense self._load_container_from_state(state['state'], state['data_groups'], container_dict) self.__dict__.update(state) def __getstate__(self): + state = self._convert_mask(self.state) return { 'defaults': self.defaults, - 'state': self.state, + 'state': state, 'data_groups': self.data_groups, '_container': self._container.state_dict() } @@ -259,3 +293,18 @@ def step(self): @abc.abstractmethod def update_mask(self, name, data, **kwargs): pass + + def _delete_data(self, name): + """Detaches some data from the sparsifier. + + Args: + name (str) + Name of the data to be removed from the sparsifier + + Note: + Currently private. Kind of used as a helper function when replacing data of the same name + """ + self.squash_mask(names=[name], leave_parametrized=False) # do not apply the mask while deleting + delattr(self._container, name) + self.state.pop(name) + self.data_groups.pop(name) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/README.md b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/README.md new file mode 100644 index 0000000000000..b39e951efec5d --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/README.md @@ -0,0 +1,97 @@ +# Data Sparsifier Benchmarking using the DLRM Model + +## Introduction +The objective of this exercise is to use the data sparsifier to prune the embedding bags of the [DLRM Model](https://github.com/facebookresearch/dlrm) and observe the following - + +1. **Disk usage savings**: Savings in model size after pruning. +2. **Model Quality**: How and by how much does performance deteriorate after pruning the embedding bags? +3. **Model forward time**: Can we speed up the model forward time by utilizing the sparsity? Specificially, can we introduce torch.sparse interim to reduce number of computations. + +## Scope +The [DataNormSparsifier](https://github.com/pytorch/pytorch/blob/master/torch/ao/sparsity/_experimental/data_sparsifier/data_norm_sparsifier.py) is used to sparsify the embeddings of the DLRM model. The model is sparsified for all the combinations of - +1. Sparsity Levels: [0.0, 0.1, 0.2, ... 0.9, 0.91, 0.92, ... 0.99, 1.0] +2. Sparse Block shapes: (1,1) and (1,4) +3. Norm: L1 and L2 + +## Dataset +The benchmarks are created for the dlrm model on the Kaggle CriteoDataset which can be downloaded from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1). + +## Results +1. **Disk Usage**: Introducing sparsity in the embeddings reduces file size after compression. The compressed model size goes down from 1.9 GB to 150 MB after 100% sparsity. + + + + +2. **Model Quality**: The model accuracy decreases slowly with sparsity levels. Even at 90% sparsity levels, the model accuracy decreases only by 2%. + + +3. **Model forward time**: Sparse coo tensors are introduced on the features before feeding into the top layer of the dlrm model. Post that, we perform a sparse ```torch.mm``` with the first linear weight of the top layer. +The takeaway is that the dlrm model with sparse coo tensor is slower (roughly 2x). This is because even though the sparsity levels are high in the embedding weights, the interaction step between the dense and sparse features increases the sparsity levels. Hence, creating sparse coo tensor on this not so sparse features actually slows down the model. + + + + +## Setup +The benchmark codes depend on the [DLRM codebase](https://github.com/facebookresearch/dlrm). +1. Clone the dlrm git repository +2. Download the dataset from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1) +3. The DLRM model can be trained using the following script +``` +# Make sure you go into the file and make sure that the path to dataset is correct. + +./bench/dlrm_s_criteo_kaggle.sh --save-model=./models/criteo_model.ckpt [--use-gpu] + +# This should also dump kaggleAdDisplayChallenge_processed.npz in the path where data is present +``` + +4. Copy the scripts data sparsifier benchmark scripts into to the dlrm directory. + +## Scripts to run each experiment. + +### **Disk savings** +``` +python evaluate_disk_savings.py --model_path= --sparsified_model_dump_path= +``` + +Running this script should dump +* sparsified model checkpoints: model is sparsified for all the + combinations of sparsity levels, block shapes and norms and dumped. + +* ```sparse_model_metadata.csv```: This contains the compressed file size and path info for all the sparsified models. This file will be used for other experiments + + +### **Model Quality** +``` +python evaluate_model_metrics.py --raw_data_file= --processed_data_file= --sparse_model_metadata= +``` +Running this script should dump ```sparse_model_metrics.csv``` that contains evaluation metrics for all sparsified models. + +### **Model forward time**: +``` +python evaluate_forward_time.py --raw_data_file= --processed_data_file= --sparse_model_metadata= +``` +Running this script should dump ```dlrm_forward_time_info.csv``` that contains forward time for all sparsified models with and without torch.sparse in the forward pass. + +## Requirements +pytorch (latest) + +scikit-learn + +numpy + +pandas + +## Machine specs to create benchmark +AI AWS was used to run everything i.e. training the dlrm model and running data sparsifier benchmarks. + +Machine: AI AWS + +Instance Type: p4d.24xlarge + +GPU: A100 + + +## Future work +1. **Evaluate memory savings**: The idea is to use torch.sparse tensors to store weights of the embedding bags so that the model memory consumption improves. This will be possible once the embedding bags starts supporting torch.sparse backend. + +2. **Sparsifying activations**: Use activation sparsifier to sparsify the activations of the dlrm model. The idea is to sparsify the features before feeding to the top dense layer (sparsify ```z``` [here](https://github.com/facebookresearch/dlrm/blob/11afc52120c5baaf0bfe418c610bc5cccb9c5777/dlrm_s_pytorch.py#L595)). diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/dlrm_utils.py new file mode 100644 index 0000000000000..20919c140a4db --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -0,0 +1,146 @@ +import torch +from dlrm_s_pytorch import DLRM_Net # type: ignore[import] +import numpy as np # type: ignore[import] +from dlrm_data_pytorch import CriteoDataset, collate_wrapper_criteo_offset # type: ignore[import] +import zipfile +import os + + +class SparseDLRM(DLRM_Net): + """The SparseDLRM model is a wrapper around the DLRM_Net model that tries + to use torch.sparse tensors for the features obtained after the ```interact_features()``` + call. The idea is to do a simple torch.mm() with the weight matrix of the first linear + layer of the top layer. + """ + def __init__(self, **args): + super().__init__(**args) + + def forward(self, dense_x, lS_o, lS_i): + x = self.apply_mlp(dense_x, self.bot_l) # dense features + ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag + z = self.interact_features(x, ly) + + z = z.to_sparse_coo() + z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias) + for layer in self.top_l[1:]: + z = layer(z) + + return z + + +def get_valid_name(name): + """Replaces '.' with '_' as names with '.' are invalid in data sparsifier + """ + return name.replace('.', '_') + + +def get_dlrm_model(sparse_dlrm=False): + """Obtain dlrm model. The configs specified are based on the script in + bench/dlrm_s_criteo_kaggle.sh. The same config is used to train the model + for benchmarking on data sparsifier. + """ + dlrm_model_config = { + 'm_spa': 16, + 'ln_emb': np.array([1460, 583, 10131227, 2202608, 305, 24, + 12517, 633, 3, 93145, 5683, 8351593, + 3194, 27, 14992, 5461306, 10, 5652, + 2173, 4, 7046547, 18, 15, 286181, + 105, 142572], dtype=np.int32), + 'ln_bot': np.array([13, 512, 256, 64, 16]), + 'ln_top': np.array([367, 512, 256, 1]), + 'arch_interaction_op': 'dot', + 'arch_interaction_itself': False, + 'sigmoid_bot': -1, + 'sigmoid_top': 2, + 'sync_dense_params': True, + 'loss_threshold': 0.0, + 'ndevices': 1, + 'qr_flag': False, + 'qr_operation': 'mult', + 'qr_collisions': 4, + 'qr_threshold': 200, + 'md_flag': False, + 'md_threshold': 200, + 'weighted_pooling': None, + 'loss_function': 'bce' + } + if sparse_dlrm: + dlrm_model = SparseDLRM(**dlrm_model_config) + else: + dlrm_model = DLRM_Net(**dlrm_model_config) + return dlrm_model + + +def dlrm_wrap(X, lS_o, lS_i, device, ndevices=1): + """Rewritten simpler version of ```dlrm_wrap()``` found in dlrm_s_pytorch.py. + This function simply moves the input tensors into the device and without the forward pass + """ + if ndevices == 1: + lS_i = ( + [S_i.to(device) for S_i in lS_i] + if isinstance(lS_i, list) + else lS_i.to(device) + ) + lS_o = ( + [S_o.to(device) for S_o in lS_o] + if isinstance(lS_o, list) + else lS_o.to(device) + ) + return X.to(device), lS_o, lS_i + + +def make_test_data_loader(raw_data_file_path, processed_data_file): + """Function to create dataset and dataloaders for the test dataset. + Rewritten simpler version of ```make_criteo_and_loaders()``` from the dlrm_data_pytorch.py + that makes the test dataset and dataloaders only for the ***kaggle criteo dataset*** + """ + test_data = CriteoDataset( + "kaggle", + -1, + 0.0, + "total", + "test", + raw_data_file_path, + processed_data_file, + False, + False, + ) + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=16384, + shuffle=False, + num_workers=7, + collate_fn=collate_wrapper_criteo_offset, + pin_memory=False, + drop_last=False, + ) + return test_loader + + +def fetch_model(model_path, device, sparse_dlrm=False): + """This function unzips the zipped model checkpoint (if zipped) and returns a + model object + + Args: + model_path (str) + path pointing to the zipped/raw model checkpoint file that was dumped in evaluate disk savings + device (torch.device) + device to which model needs to be loaded to + """ + if zipfile.is_zipfile(model_path): + with zipfile.ZipFile(model_path, 'r', zipfile.ZIP_DEFLATED) as zip_ref: + zip_ref.extractall(os.path.dirname(model_path)) + unzip_path = model_path.replace('.zip', '.ckpt') + else: + unzip_path = model_path + + model = get_dlrm_model(sparse_dlrm=sparse_dlrm) + model.load_state_dict(torch.load(unzip_path, map_location=device)) + model = model.to(device) + model.eval() + + # If there was a zip file, clean up the unzipped files + if zipfile.is_zipfile(model_path): + os.remove(unzip_path) + + return model diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py new file mode 100644 index 0000000000000..750bb91536b1a --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py @@ -0,0 +1,159 @@ +from typing import Dict, List +import torch +import time +from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier +import os +from dlrm_utils import get_dlrm_model, get_valid_name # type: ignore[import] +import copy +import zipfile +from zipfile import ZipFile +import pandas as pd # type: ignore[import] +import argparse + + +def create_attach_sparsifier(model, **sparse_config): + """Create a DataNormSparsifier and the attach it to the model embedding layers + + Args: + model (nn.Module) + layer of the model that needs to be attached to the sparsifier + sparse_config (Dict) + Config to the DataNormSparsifier. Should contain the following keys: + - sparse_block_shape + - norm + - sparsity_level + """ + data_norm_sparsifier = DataNormSparsifier(**sparse_config) + for name, parameter in model.named_parameters(): + if 'emb_l' in name: + valid_name = get_valid_name(name) + data_norm_sparsifier.add_data(name=valid_name, data=parameter) + return data_norm_sparsifier + + +def save_model_states(state_dict, sparsified_model_dump_path, save_file_name, sparse_block_shape, norm, zip=True): + """Dumps the state_dict() of the model. + + Args: + state_dict (Dict) + The state_dict() as dumped by dlrm_s_pytorch.py. Only the model state will be extracted + from this dictionary. This corresponds to the 'state_dict' key in the state_dict dictionary. + >>> model_state = state_dict['state_dict'] + save_file_name (str) + The filename (not path) when saving the model state dictionary + sparse_block_shape (Tuple) + The block shape corresponding to the data norm sparsifier. **Used for creating save directory** + norm (str) + type of norm (L1, L2) for the datanorm sparsifier. **Used for creating save directory** + zip (bool) + if True, the file is zip-compressed. + """ + folder_name = os.path.join(sparsified_model_dump_path, str(norm)) + + # save model only states + folder_str = f"config_{sparse_block_shape}" + model_state = state_dict['state_dict'] + model_state_path = os.path.join(folder_name, folder_str, save_file_name) + + if not os.path.exists(os.path.dirname(model_state_path)): + os.makedirs(os.path.dirname(model_state_path)) + torch.save(model_state, model_state_path) + + if zip: + zip_path = model_state_path.replace('.ckpt', '.zip') + with ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zip: + zip.write(model_state_path, save_file_name) + os.remove(model_state_path) # store it as zip, remove uncompressed + model_state_path = zip_path + + model_state_path = os.path.abspath(model_state_path) + file_size = os.path.getsize(model_state_path) + file_size = file_size >> 20 # size in mb + return model_state_path, file_size + + +def sparsify_model(path_to_model, sparsified_model_dump_path): + """Sparsifies the embedding layers of the dlrm model for different sparsity levels, norms and block shapes + using the DataNormSparsifier. + The function tracks the step time of the sparsifier and the size of the compressed checkpoint and collates + it into a csv. + + Note:: + This function dumps a csv sparse_model_metadata.csv in the current directory. + + Args: + path_to_model (str) + path to the trained criteo model ckpt file + sparsity_levels (List of float) + list of sparsity levels to be sparsified on + norms (List of str) + list of norms to be sparsified on + sparse_block_shapes (List of tuples) + List of sparse block shapes to be sparsified on + """ + sparsity_levels = [sl / 10 for sl in range(0, 10)] + sparsity_levels += [0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0] + + norms = ["L1", "L2"] + sparse_block_shapes = [(1, 1), (1, 4)] + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + print("Running for sparsity levels - ", sparsity_levels) + print("Running for sparse block shapes - ", sparse_block_shapes) + print("Running for norms - ", norms) + + orig_model = get_dlrm_model() + saved_state = torch.load(path_to_model, map_location=device) + orig_model.load_state_dict(saved_state['state_dict']) + + orig_model = orig_model.to(device) + step_time_dict = {} + + stat_dict: Dict[str, List] = {'norm': [], 'sparse_block_shape': [], 'sparsity_level': [], + 'step_time_sec': [], 'zip_file_size': [], 'path': []} + for norm in norms: + for sbs in sparse_block_shapes: + if norm == "L2" and sbs == (1, 1): + continue + for sl in sparsity_levels: + model = copy.deepcopy(orig_model) + sparsifier = create_attach_sparsifier(model, sparse_block_shape=sbs, norm=norm, sparsity_level=sl) + + t1 = time.time() + sparsifier.step() + t2 = time.time() + + step_time = t2 - t1 + norm_sl = f"{norm}_{sbs}_{sl}" + print(f"Step Time for {norm_sl}=: {step_time} s") + + step_time_dict[norm_sl] = step_time + + sparsifier.squash_mask() + + saved_state['state_dict'] = model.state_dict() + file_name = f'criteo_model_norm={norm}_sl={sl}.ckpt' + state_path, file_size = save_model_states(saved_state, sparsified_model_dump_path, file_name, sbs, norm=norm) + + stat_dict['norm'].append(norm) + stat_dict['sparse_block_shape'].append(sbs) + stat_dict['sparsity_level'].append(sl) + stat_dict['step_time_sec'].append(step_time) + stat_dict['zip_file_size'].append(file_size) + stat_dict['path'].append(state_path) + + df = pd.DataFrame(stat_dict) + filename = 'sparse_model_metadata.csv' + df.to_csv(filename, index=False) + + print(f"Saved sparsified metadata file in {filename}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str) + parser.add_argument('--sparsified_model_dump_path', type=str) + args = parser.parse_args() + + sparsify_model(args.model_path, args.sparsified_model_dump_path) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py new file mode 100644 index 0000000000000..4435365c2efc6 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py @@ -0,0 +1,108 @@ +from typing import Dict, List +import torch +from dlrm_s_pytorch import unpack_batch # type: ignore[import] +import numpy as np # type: ignore[import] +import time +from dlrm_utils import make_test_data_loader, fetch_model, dlrm_wrap # type: ignore[import] +import pandas as pd # type: ignore[import] +import argparse + + +def run_forward(model, **batch): + """The purpose of this function is to time the forward run of the model. + The model forward happens a 100 times and each pass is timed. The average + of this 100 runs is returned as avg_time. + """ + time_list = [] + X, lS_o, lS_i = batch['X'], batch['lS_o'], batch['lS_i'] + for _ in range(100): + start = time.time() + with torch.no_grad(): + model(X, lS_o, lS_i) + end = time.time() + time_taken = end - start + time_list.append(time_taken) + avg_time = np.mean(time_list[1:]) + return avg_time + + +def make_sample_test_batch(raw_data_path, processed_data_path, device): + """Create the test_data_loader and sample a batch from it. This batch will be used + to measure the forward pass of the model throughout this experiment. + """ + test_data_loader = make_test_data_loader(raw_data_path, processed_data_path) + + test_iter = iter(test_data_loader) + + test_batch = next(test_iter) + + X_test, lS_o_test, lS_i_test, _, _, _ = unpack_batch(test_batch) + + X, lS_o, lS_i = dlrm_wrap(X_test, lS_o_test, lS_i_test, device) + batch = { + 'X': X, + 'lS_o': lS_o, + 'lS_i': lS_i + } + + return batch + +def measure_forward_pass(sparse_model_metadata, device, sparse_dlrm, **batch): + """Measures and tracks the forward pass of the model for all the sparsity levels, block shapes and norms + available in sparse_model_metadata file. + If sparse_dlrm=True, then the SparseDLRM model is loaded, otherwise the standard one is. + """ + time_taken_dict: Dict[str, List] = { + "norm": [], + "sparse_block_shape": [], + "sparsity_level": [], + "time_taken": [], + } + + metadata = pd.read_csv(sparse_model_metadata) + + for _, row in metadata.iterrows(): + norm, sbs, sl = row['norm'], row['sparse_block_shape'], row['sparsity_level'] + model_path = row['path'] + model = fetch_model(model_path, device, sparse_dlrm=sparse_dlrm) + time_taken = run_forward(model, **batch) + out_str = f"{norm}_{sbs}_{sl}={time_taken}" + print(out_str) + time_taken_dict["norm"].append(norm) + time_taken_dict["sparse_block_shape"].append(sbs) + time_taken_dict["sparsity_level"].append(sl) + time_taken_dict["time_taken"].append(time_taken) + + time_df = pd.DataFrame(time_taken_dict) + + if sparse_dlrm: + time_df['dlrm_type'] = 'with_torch_sparse' + else: + time_df['dlrm_type'] = 'without_torch_sparse' + + return time_df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--raw_data_file', type=str) + parser.add_argument('--processed_data_file', type=str) + parser.add_argument('--sparse_model_metadata', type=str) + + args = parser.parse_args() + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + print(device) + + batch = make_sample_test_batch(args.raw_data_file, args.processed_data_file, device) + + print("Forward Time for Sparse DLRM") + sparse_dlrm_time_df = measure_forward_pass(args.sparse_model_metadata, device, sparse_dlrm=True, **batch) + print(sparse_dlrm_time_df) + + print("Forward Time for Normal DLRM") + norm_dlrm_time_df = measure_forward_pass(args.sparse_model_metadata, device, sparse_dlrm=False, **batch) + print(norm_dlrm_time_df) + + forward_time_all = pd.concat([sparse_dlrm_time_df, norm_dlrm_time_df]) + forward_time_all.to_csv('dlrm_forward_time_info.csv', index=False) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py new file mode 100644 index 0000000000000..05246d545ba74 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py @@ -0,0 +1,132 @@ +from typing import Dict, List +import torch +from dlrm_s_pytorch import unpack_batch # type: ignore[import] +import numpy as np # type: ignore[import] +import sklearn # type: ignore[import] +from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model +import pandas as pd # type: ignore[import] +import argparse + + +def inference_and_evaluation(dlrm, test_dataloader, device): + """Perform inference and evaluation on the test dataset. + The function returns the dictionary that contains evaluation metrics such as accuracy, f1, auc, + precision, recall. + Note: This function is a rewritten version of ```inference()``` present in dlrm_s_pytorch.py + + Args: + dlrm (nn.Module) + dlrm model object + test_data_loader (torch dataloader): + dataloader for the test dataset + device (torch.device) + device on which the inference happens + """ + nbatches = len(test_dataloader) + scores = [] + targets = [] + + for i, testBatch in enumerate(test_dataloader): + # early exit if nbatches was set by the user and was exceeded + if nbatches > 0 and i >= nbatches: + break + + X_test, lS_o_test, lS_i_test, T_test, _, _ = unpack_batch( + testBatch + ) + # forward pass + X_test, lS_o_test, lS_i_test = dlrm_wrap(X_test, lS_o_test, lS_i_test, device, ndevices=1) + + Z_test = dlrm(X_test, lS_o_test, lS_i_test) + S_test = Z_test.detach().cpu().numpy() # numpy array + T_test = T_test.detach().cpu().numpy() # numpy array + scores.append(S_test) + targets.append(T_test) + + scores = np.concatenate(scores, axis=0) + targets = np.concatenate(targets, axis=0) + metrics = { + "recall": lambda y_true, y_score: sklearn.metrics.recall_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "precision": lambda y_true, y_score: sklearn.metrics.precision_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "f1": lambda y_true, y_score: sklearn.metrics.f1_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "ap": sklearn.metrics.average_precision_score, + "roc_auc": sklearn.metrics.roc_auc_score, + "accuracy": lambda y_true, y_score: sklearn.metrics.accuracy_score( + y_true=y_true, y_pred=np.round(y_score) + ), + "log_loss": lambda y_true, y_score: sklearn.metrics.log_loss( + y_true=y_true, y_pred=y_score + ) + } + + all_metrics = {} + for metric_name, metric_function in metrics.items(): + all_metrics[metric_name] = round(metric_function(targets, scores), 3) + + return all_metrics + + +def evaluate_metrics(test_dataloader, sparse_model_metadata): + """Evaluates the metrics the sparsified metrics for the dlrm model on various sparsity levels, + block shapes and norms. This function evaluates the model on the test dataset and dumps + evaluation metrics in a csv file [model_performance.csv] + """ + metadata = pd.read_csv(sparse_model_metadata) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + metrics_dict: Dict[str, List] = { + "norm": [], + "sparse_block_shape": [], + "sparsity_level": [], + "precision": [], + "recall": [], + "f1": [], + "roc_auc": [], + "accuracy": [], + "log_loss": [] + } + + for _, row in metadata.iterrows(): + norm, sbs, sl = row['norm'], row['sparse_block_shape'], row['sparsity_level'] + model_path = row['path'] + model = fetch_model(model_path, device) + + model_metrics = inference_and_evaluation(model, test_dataloader, device) + key = f"{norm}_{sbs}_{sl}" + print(key, "=", model_metrics) + + metrics_dict['norm'].append(norm) + metrics_dict['sparse_block_shape'].append(sbs) + metrics_dict['sparsity_level'].append(sl) + + for key, value in model_metrics.items(): + if key in metrics_dict: + metrics_dict[key].append(value) + + sparse_model_metrics = pd.DataFrame(metrics_dict) + print(sparse_model_metrics) + + filename = 'sparse_model_metrics.csv' + sparse_model_metrics.to_csv(filename, index=False) + print(f"Model metrics file saved to {filename}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--raw_data_file', type=str) + parser.add_argument('--processed_data_file', type=str) + parser.add_argument('--sparse_model_metadata', type=str) + + args = parser.parse_args() + + # Fetch test data loader + test_dataloader = make_test_data_loader(args.raw_data_file, args.processed_data_file) + + # Evaluate metrics + evaluate_metrics(test_dataloader, args.sparse_model_metadata) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/accuracy.png b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/accuracy.png new file mode 100644 index 0000000000000..6b094aded25d3 Binary files /dev/null and b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/accuracy.png differ diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/disk_savings.png b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/disk_savings.png new file mode 100644 index 0000000000000..da55fab863044 Binary files /dev/null and b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/disk_savings.png differ diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/forward_time.png b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/forward_time.png new file mode 100644 index 0000000000000..24ade831a5c1e Binary files /dev/null and b/torch/ao/sparsity/_experimental/data_sparsifier/benchmarks/images/forward_time.png differ diff --git a/torch/ao/sparsity/experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/sparsity/_experimental/data_sparsifier/data_norm_sparsifier.py similarity index 94% rename from torch/ao/sparsity/experimental/data_sparsifier/data_norm_sparsifier.py rename to torch/ao/sparsity/_experimental/data_sparsifier/data_norm_sparsifier.py index 095cb3bb1c657..8fb13b2f4a239 100644 --- a/torch/ao/sparsity/experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/sparsity/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -2,6 +2,7 @@ from torch.nn import functional as F from functools import reduce from typing import Tuple, Any, List + from .base_data_sparsifier import BaseDataSparsifier __all__ = ['DataNormSparsifier'] @@ -80,7 +81,7 @@ def __get_block_level_mask(self, data, mask = self.__get_scatter_folded_mask(data=unfolded_data, dim=1, indices=sorted_idx, output_size=padded_data.shape, sparse_block_shape=sparse_block_shape) - mask = mask.squeeze()[:height, :width].contiguous() # remove padding and make contiguous + mask = mask.squeeze(0).squeeze(0)[:height, :width].contiguous() # remove padding and make contiguous return mask def __get_data_level_mask(self, data, sparsity_level, @@ -109,7 +110,7 @@ def __get_data_level_mask(self, data, sparsity_level, output_size=(height + dh, width + dw), sparse_block_shape=sparse_block_shape) - mask = mask.squeeze()[:height, :width] + mask = mask.squeeze(0).squeeze(0)[:height, :width] # squeeze only the first 2 dimension return mask def update_mask(self, name, data, sparsity_level, @@ -127,9 +128,12 @@ def update_mask(self, name, data, sparsity_level, else: data_norm = (data * data).squeeze() # square every element for L2 - if len(data_norm.shape) != 2: # only supports 2 dimenstional data at the moment + if len(data_norm.shape) > 2: # only supports 2 dimenstional data at the moment raise ValueError("only supports 2-D at the moment") + elif len(data_norm.shape) == 1: # in case the data is bias (or 1D) + data_norm = data_norm[None, :] + mask = self.get_mask(name) if sparsity_level <= 0 or zeros_per_block == 0: mask.data = torch.ones_like(mask) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/__init__.py b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/README.md b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/README.md new file mode 100644 index 0000000000000..f36342edf0b4a --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/README.md @@ -0,0 +1,77 @@ +# Lightning callbacks for data sparsifier and scheduler + +**These are callback scripts for lightning and does not introduce pytorch lightning dependency on PyTorch.** + +## Introduction +Callbacks for PytorchLightning that specifies on when and how to to sparsify the data weights of the model. + +## Types of Data Sparsity Callbacks +There are 2 types of data sparsity callbacks +1. **Post Training data sparsifier callback**: Sparsification of the model parameters *post* training. + +2. **Training Aware data sparsifier callback**: Sparsification of the model parameters *during* training. + +## API Design +1. `PostTrainingDataSparsity`: callback class that sparsifies the model parameters post training. Accepts + 1. `data_sparsifier_class`: class/type of data sparsifier that needs to be used. Only the class should be passed, the data sparsifier object + will be created internally and will be attached to the model by the callback whenever necessary. + 2. `data_sparsifier_args`: the arguments/config for the data sparsifier constructor that will be used while creating the object. + + Example: + ``` + from data_sparsity import PostTrainingDataSparsity + sparsifier_args = { + 'sparsity_level': 0.5, + 'sparse_block_shape': (1, 4), + 'zeros_per_block': 4 + } + pt_callback = PostTrainingDataSparsity(data_sparsifier_class=DataNormSparsifier, data_sparsifier_args=sparsifier_args) + ``` + +2. `TrainingAwareDataSparsity`: callback class to sparsify model during training. In addition to `data_sparsifier_class` and `data_sparsifier_args`, + also accepts + 1. `data_scheduler_class`: class/type of data scheduler to schedule the sparsity levels during training. Only the class should be passed, the object + will be created internally whenever necessary. + 2. `data_scheduler_args`: the arguments/config for the data scheduler constructor that will be used while creating the object. + + Example: + + ``` + from data_sparsity import TrainingAwareDataSparsity + sparsifier_args = { + 'sparsity_level': 0.5, + 'sparse_block_shape': (1, 4), + 'zeros_per_block': 4 + } + scheduler_args = { + 'gamma': 2, + 'step_size': 1 + } + + ta_callback = TrainingAwareDataSparsity( + data_sparsifier_class=DataNormSparsifier, + data_sparsifier_args=sparsifier_args, + data_scheduler_class=StepSLScheduler, + data_scheduler_args=scheduler_args + ) + ``` + +**Note:** +1. The model is copied and then sparsified, so the existing model is not modified. +2. The sparsified model can be accessed using `sparsified` attribute and can be used for comparison with the original version. +3. The data sparsifier/scheduler object will be created internally and will be attached to the model by the callback whenever necessary. + +## Usage +``` +pl_module = SomePLModule() # pl_module.model should specify the pytorch model + +ds_callback = SomeDataSparsifierCallback(data_sparsifier_class=..., data_sparsifier_args=..., ...) # add scheduler if TrainingAwareDataSparsifier +trainer = Trainer(callbacks=[ds_callback]) + +trainer.fit(pl_module, train_data_loader, val_data_loader) + +# NOTE: pl_module.model is not sparsified + +# access sparsified model +sparsified_model = ds_callback.sparsified +``` diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py new file mode 100644 index 0000000000000..ba0b105d34bae --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -0,0 +1,39 @@ +import logging +from torch.ao.sparsity._experimental.data_sparsifier.base_data_sparsifier import SUPPORTED_TYPES + +logger: logging.Logger = logging.getLogger(__name__) + + +def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None): + """Attaches a data sparsifier to all the layers of the module. + Essentialy, loop over all the weight parameters in the module and + attach it to the data sparsifier. + Note:: + The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below) + before attaching to the sparsifier. This is because, the data + sparsifier uses a dummy model inside to store the weight parameters. + """ + if config is None: + config = {} + for name, parameter in module.named_parameters(): + if type(parameter) in SUPPORTED_TYPES: + valid_name = _get_valid_name(name) + # will be defaulted to default configs + data_sparsifier.add_data(name=valid_name, data=parameter, **config.get(valid_name, {})) + + +def _get_valid_name(name): + return name.replace('.', '_') # . is not allowed as a name + + +def _log_sparsified_level(model, data_sparsifier) -> None: + # Show the level of sparsity AFTER step: + for name, parameter in model.named_parameters(): + if not (type(parameter) in SUPPORTED_TYPES): + continue + valid_name = _get_valid_name(name) + mask = data_sparsifier.get_mask(name=valid_name) + sparsity_level = 1.0 - mask.float().mean() + logger.info( + f"Sparsity in layer {name} = {sparsity_level: .2%}" + ) diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py new file mode 100644 index 0000000000000..c36c35bcf5241 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -0,0 +1,165 @@ +from collections import defaultdict +from copy import deepcopy +import torch +from typing import Any, Optional, Dict +import pytorch_lightning as pl # type: ignore[import] + +from ._data_sparstity_utils import ( + _attach_model_to_data_sparsifier, + _log_sparsified_level, + _get_valid_name +) + + +class PostTrainingDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables post-training sparsity. + + This callback aims to sparsify the model inside lightning module after training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + once the training completes. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + Hooks implemented: + on_fit_end() + 1. copies the model and attaches it to the sparsifier + 2. sparsier step() is called + 3. squashes the mask() + """ + def __init__(self, data_sparsifier_class, data_sparsifier_args): + super().__init__() + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + self.data_sparsifier: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + def on_fit_end(self, trainer, pl_module) -> None: + self.sparsified = deepcopy(pl_module.model).eval() + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) + + self.data_sparsifier.step() + + self.data_sparsifier.squash_mask() # currently squashes params for all mask + + _log_sparsified_level(self.sparsified, self.data_sparsifier) + + +class TrainingAwareDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables in-training sparsity. + + This callback aims to sparsify the model inside lightning module during training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + when the training starts. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + data_scheduler_class (some implemented class of BaseDataScheduler) + The data scheduler of this class is created when the training starts + Note: Objects should not be passed in here as they are created + when the training starts. + + data_scheduler_args(Dict) + Dictionary of args to be passed to the data scheduler. + **Note: data_sparsifier arg should be ignored as the recipe + creates and pass sparsifier object into the class** + + Hooks implemented: + on_train_start() + Data sparsifier and scheduler objects are created. + Pytorch model attached to the sparsifier + + on_train_epoch_start() + Loads the state_dict of the data sparsifier + + on_train_epoch_end() + 1. Copies the model and attaches it to the sparsifier + 2. sparsifier step() and scheduler step() + 3. Dump state_dict of the current sparsifier + + on_train_end() + squash mask + """ + def __init__(self, data_sparsifier_class, data_sparsifier_args, + data_scheduler_class, data_scheduler_args): + super().__init__() + # data sparsifier objects + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + + # scheduler objects + self.data_scheduler_class = data_scheduler_class + self.data_scheduler_args = data_scheduler_args + + # fields + self.data_sparsifier: Any = None + self.data_scheduler: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + self.data_sparsifier_state_dict: Any = None + + def on_train_start(self, trainer, pl_module) -> None: + # create sparsifier + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + self.sparsified = deepcopy(pl_module.model) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) # just to populate the base_sl in the scheduler + + # create scheduler + args = deepcopy(self.data_scheduler_args) + args['data_sparsifier'] = self.data_sparsifier + self.data_scheduler = self.data_scheduler_class(**args) + + def on_train_epoch_start(self, trainer, pl_module): + if self.data_sparsifier_state_dict is None: + return # probably first epoch + + # load the existing config for each data + self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict) + + def __create_config_based_on_state(self, pl_module): + config: Dict = defaultdict() + if self.data_sparsifier_state_dict is None: + return config + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + config[valid_name] = self.data_sparsifier.data_groups[valid_name] + + return config + + def on_train_epoch_end(self, trainer, pl_module): + self.sparsified = deepcopy(pl_module.model) + config = self.__create_config_based_on_state(pl_module) + + # attach model to the data sparsifier + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier, config=config) + self.data_sparsifier.step() + self.data_scheduler.step() + + self.data_sparsifier_state_dict = self.data_sparsifier.state_dict() + + def on_train_end(self, trainer, pl_module): + self.data_sparsifier.squash_mask() diff --git a/torch/ao/sparsity/_experimental/data_sparsifier/lightning/tests/test_callbacks.py b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/tests/test_callbacks.py new file mode 100644 index 0000000000000..76909dc48b9b8 --- /dev/null +++ b/torch/ao/sparsity/_experimental/data_sparsifier/lightning/tests/test_callbacks.py @@ -0,0 +1,275 @@ +from torch.ao.sparsity._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier +from torch.ao.sparsity._experimental.data_scheduler.base_data_scheduler import BaseDataScheduler +import torch +import torch.nn as nn +from typing import List +from torch.ao.sparsity._experimental.data_sparsifier.lightning.callbacks.data_sparsity import ( + PostTrainingDataSparsity, + TrainingAwareDataSparsity +) +from torch.ao.sparsity._experimental.data_sparsifier.lightning.callbacks._data_sparstity_utils import _get_valid_name +from torch.ao.sparsity._experimental.data_sparsifier.base_data_sparsifier import SUPPORTED_TYPES +from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import run_tests +import importlib +import unittest +import warnings +import math +from torch.nn.utils.parametrize import is_parametrized + + +class DummyModel(nn.Module): + def __init__(self, iC: int, oC: List[int]): + super().__init__() + self.linears = nn.Sequential() + i = iC + for idx, c in enumerate(oC): + self.linears.append(nn.Linear(i, c, bias=False)) + if idx < len(oC) - 1: + self.linears.append(nn.ReLU()) + i = c + + +def _make_lightning_module(iC: int, oC: List[int]): + import pytorch_lightning as pl # type: ignore[import] + + class DummyLightningModule(pl.LightningModule): + def __init__(self, ic: int, oC: List[int]): + super().__init__() + self.model = DummyModel(iC, oC) + + def forward(self): + pass + + return DummyLightningModule(iC, oC) + + + +class StepSLScheduler(BaseDataScheduler): + """The sparsity param of each data group is multiplied by gamma every step_size epochs. + """ + def __init__(self, data_sparsifier, schedule_param='sparsity_level', + step_size=1, gamma=2, last_epoch=-1, verbose=False): + + self.gamma = gamma + self.step_size = step_size + super().__init__(data_sparsifier, schedule_param, last_epoch, verbose) + + def get_schedule_param(self): + if not self._get_sp_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + data_groups = self.data_sparsifier.data_groups + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return {name: config[self.schedule_param] for name, config in data_groups.items()} + + return {name: config[self.schedule_param] * self.gamma for name, config in data_groups.items()} + + +class TestPostTrainingCallback(TestCase): + def _check_on_fit_end(self, pl_module, callback, sparsifier_args): + """Makes sure that each component of is working as expected while calling the + post-training callback. + Specifically, check the following - + 1. sparsifier config is the same as input config + 2. data sparsifier is correctly attached to the model + 3. sparsity is achieved after .step() + 4. non-sparsified values are the same as original values + """ + callback.on_fit_end(42, pl_module) # 42 is a dummy value + + # check sparsifier config + for key, value in sparsifier_args.items(): + assert callback.data_sparsifier.defaults[key] == value + + # assert that the model is correctly attached to the sparsifier + for name, param in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + if type(param) not in SUPPORTED_TYPES: + assert valid_name not in callback.data_sparsifier.state + assert valid_name not in callback.data_sparsifier.data_groups + continue + assert valid_name in callback.data_sparsifier.data_groups + assert valid_name in callback.data_sparsifier.state + + mask = callback.data_sparsifier.get_mask(name=valid_name) + + # assert that some level of sparsity is achieved + assert (1.0 - mask.float().mean()) > 0.0 + + # make sure that non-zero values in data after squash mask are equal to original values + sparsified_data = callback.data_sparsifier.get_data(name=valid_name, return_original=False) + assert torch.all(sparsified_data[sparsified_data != 0] == param[sparsified_data != 0]) + + @unittest.skipIf(not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning") + def test_post_training_callback(self): + sparsifier_args = { + 'sparsity_level': 0.5, + 'sparse_block_shape': (1, 4), + 'zeros_per_block': 4 + } + callback = PostTrainingDataSparsity(DataNormSparsifier, sparsifier_args) + pl_module = _make_lightning_module(100, [128, 256, 16]) + + self._check_on_fit_end(pl_module, callback, sparsifier_args) + + +class TestTrainingAwareCallback(TestCase): + """Class to test in-training version of lightning callback + Simulates model training and makes sure that each hook is doing what is expected + """ + def _check_on_train_start(self, pl_module, callback, sparsifier_args, scheduler_args): + """Makes sure that the data_sparsifier and data_scheduler objects are being created + correctly. + Basically, confirms that the input args and sparsifier/scheduler args are in-line. + """ + + callback.on_train_start(42, pl_module) # 42 is a dummy value + + # sparsifier and scheduler instantiated + assert callback.data_scheduler is not None and callback.data_sparsifier is not None + + # data sparsifier args are correct + for key, value in sparsifier_args.items(): + callback.data_sparsifier.defaults[key] == value + + # data scheduler args are correct + for key, value in scheduler_args.items(): + assert getattr(callback.data_scheduler, key) == value + + def _simulate_update_param_model(self, pl_module): + """This function might not be needed as the model is being copied + during train_epoch_end() but good to have if things change in the future + """ + for _, param in pl_module.model.named_parameters(): + param.data = param + 1 + + def _check_on_train_epoch_start(self, pl_module, callback): + """Basically ensures that the sparsifier's state is correctly being restored. + The state_dict() comparison is needed. Consider the flow - + + **Epoch: 1** + 1. on_train_epoch_start(): Nothing happens (for now) + 2. on_train_epoch_end(): + a) the model is copied into the data_sparsifier + b) .step() is called + c) internally, the state of each layer of the model inside + data sparsifier changes + + **Epoch: 2** + 1. on_train_epoch_start(): Assume nothing happens + 2. on_train_epoch_end(): + a) the model is copied into the data_sparsifier. + But wait! you need the config to attach layer + of the module to the sparsifier. If config is None, + the data_sparsifier uses the default config which we + do not want as the config of each layer changes after + .step() + + Hence, we need to dump and restore the state_dict() everytime because we're + copying the model after each epoch. + Hence, it is essential to make sure that the sparsifier's state_dict() is being + correctly dumped and restored. + + """ + # check if each component of state dict is being loaded correctly + callback.on_train_epoch_start(42, pl_module) + if callback.data_sparsifier_state_dict is None: + return + + data_sparsifier_state_dict = callback.data_sparsifier.state_dict() + + # compare container objects + container_obj1 = data_sparsifier_state_dict['_container'] + container_obj2 = callback.data_sparsifier_state_dict['_container'] + assert len(container_obj1) == len(container_obj2) + for key, value in container_obj2.items(): + assert key in container_obj1 + assert torch.all(value == container_obj1[key]) + + # compare state objects + state_obj1 = data_sparsifier_state_dict['state'] + state_obj2 = callback.data_sparsifier_state_dict['state'] + assert len(state_obj1) == len(state_obj2) + for key, value in state_obj2.items(): + assert key in state_obj1 + assert 'mask' in value and 'mask' in state_obj1[key] + assert torch.all(value['mask'] == state_obj1[key]['mask']) + + # compare data_groups dict + data_grp1 = data_sparsifier_state_dict['data_groups'] + data_grp2 = callback.data_sparsifier_state_dict['data_groups'] + assert len(data_grp1) == len(data_grp2) + for key, value in data_grp2.items(): + assert key in data_grp1 + assert value == data_grp1[key] + + def _check_on_train_epoch_end(self, pl_module, callback): + """Checks the following - + 1. sparsity is correctly being achieved after .step() + 2. scheduler and data_sparsifier sparsity levels are in-line + """ + callback.on_train_epoch_end(42, pl_module) + data_scheduler = callback.data_scheduler + base_sl = data_scheduler.base_param + + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + mask = callback.data_sparsifier.get_mask(name=valid_name) + + # check sparsity levels + assert (1.0 - mask.float().mean()) > 0 # some sparsity level achieved + + last_sl = data_scheduler.get_last_param() + last_epoch = data_scheduler.last_epoch + + # check sparsity levels of scheduler + log_last_sl = math.log(last_sl[valid_name]) + log_actual_sl = math.log(base_sl[valid_name] * (data_scheduler.gamma ** last_epoch)) + assert log_last_sl == log_actual_sl + + def _check_on_train_end(self, pl_module, callback): + """Confirms that the mask is squashed after the training ends + This is achieved by making sure that each parameter in the internal container + are not parametrized. + """ + callback.on_train_end(42, pl_module) + + # check that the masks have been squashed + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + assert not is_parametrized(callback.data_sparsifier._continer, valid_name) + + @unittest.skipIf(not importlib.util.find_spec("pytorch_lightning"), "No pytorch_lightning") + def test_train_aware_callback(self): + sparsifier_args = { + 'sparsity_level': 0.5, + 'sparse_block_shape': (1, 4), + 'zeros_per_block': 4 + } + scheduler_args = { + 'gamma': 2, + 'step_size': 1 + } + + callback = TrainingAwareDataSparsity( + data_sparsifier_class=DataNormSparsifier, + data_sparsifier_args=sparsifier_args, + data_scheduler_class=StepSLScheduler, + data_scheduler_args=scheduler_args + ) + + pl_module = _make_lightning_module(100, [128, 256, 16]) + + # simulate the training process and check all steps + self._check_on_train_start(pl_module, callback, sparsifier_args, scheduler_args) + + num_epochs = 5 + for _ in range(0, num_epochs): + self._check_on_train_epoch_start(pl_module, callback) + self._simulate_update_param_model(pl_module) + self._check_on_train_epoch_end(pl_module, callback) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/ao/sparsity/experimental/pruner/README.md b/torch/ao/sparsity/_experimental/pruner/README.md similarity index 100% rename from torch/ao/sparsity/experimental/pruner/README.md rename to torch/ao/sparsity/_experimental/pruner/README.md diff --git a/torch/ao/sparsity/_experimental/pruner/__init__.py b/torch/ao/sparsity/_experimental/pruner/__init__.py new file mode 100644 index 0000000000000..c496e555930a2 --- /dev/null +++ b/torch/ao/sparsity/_experimental/pruner/__init__.py @@ -0,0 +1,15 @@ +from .base_pruner import BasePruner +from .parametrization import ( + ActivationReconstruction, + BiasHook, + PruningParametrization, + ZeroesParametrization, +) + +__all__ = [ + "ActivationReconstruction", + "BasePruner", + "BiasHook", + "PruningParametrization", + "ZeroesParametrization", +] diff --git a/torch/ao/sparsity/experimental/pruner/base_pruner.py b/torch/ao/sparsity/_experimental/pruner/base_pruner.py similarity index 95% rename from torch/ao/sparsity/experimental/pruner/base_pruner.py rename to torch/ao/sparsity/_experimental/pruner/base_pruner.py index 65f5a279d8a89..cbcdec5d49e6e 100644 --- a/torch/ao/sparsity/experimental/pruner/base_pruner.py +++ b/torch/ao/sparsity/_experimental/pruner/base_pruner.py @@ -121,7 +121,9 @@ def make_config_from_model(self, model, SUPPORTED_MODULES=SUPPORTED_MODULES, NEE module = stack.pop() for name, child in module.named_children(): if type(child) in SUPPORTED_MODULES: - self.config.append({'tensor_fqn': module_to_fqn(model, child) + '.weight'}) + child_fqn = module_to_fqn(model, child) + assert isinstance(child_fqn, str) # for mypy + self.config.append({'tensor_fqn': child_fqn + '.weight'}) else: if NEEDS_ZEROS is not None and type(child) in NEEDS_ZEROS and hasattr(self, "prune_bias") and self.prune_bias: # only useful for Pruner @@ -153,6 +155,7 @@ def prepare(self, model, config): if type(module_config) is tuple: first_layer, next_layer = module_config assert isinstance(first_layer, nn.Conv2d) and isinstance(next_layer, nn.BatchNorm2d) + assert isinstance(module_config, tuple) # for mypy module_config = {'module': module_config} local_args = copy.deepcopy(self.defaults) local_args.update(module_config) @@ -174,7 +177,7 @@ def prepare(self, model, config): local_args['tensor_name'] = tensor_name_list else: if isinstance(module_config, nn.Module): - module_config = {'module': module_config} + module_config = {'module': module_config} # type: ignore[dict-item] local_args = copy.deepcopy(self.defaults) local_args.update(module_config) @@ -182,6 +185,7 @@ def prepare(self, model, config): # now that we're working with a dict, does it have the new format? if local_args.get('tensor_fqn', None) is not None: tensor_fqn = local_args.get('tensor_fqn') + assert isinstance(tensor_fqn, str) # for mypy info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) for key in info_from_tensor_fqn.keys(): @@ -199,6 +203,7 @@ def prepare(self, model, config): module_fqn = module_fqn[1:] local_args['module_fqn'] = module_fqn local_args['tensor_name'] = "weight" + assert isinstance(module_fqn, str) # for mypy local_args['tensor_fqn'] = module_fqn + ".weight" self.groups.append(local_args) diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_1.png b/torch/ao/sparsity/_experimental/pruner/images/prune_1.png similarity index 100% rename from torch/ao/sparsity/experimental/pruner/images/prune_1.png rename to torch/ao/sparsity/_experimental/pruner/images/prune_1.png diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_2.png b/torch/ao/sparsity/_experimental/pruner/images/prune_2.png similarity index 100% rename from torch/ao/sparsity/experimental/pruner/images/prune_2.png rename to torch/ao/sparsity/_experimental/pruner/images/prune_2.png diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_3.png b/torch/ao/sparsity/_experimental/pruner/images/prune_3.png similarity index 100% rename from torch/ao/sparsity/experimental/pruner/images/prune_3.png rename to torch/ao/sparsity/_experimental/pruner/images/prune_3.png diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_4.png b/torch/ao/sparsity/_experimental/pruner/images/prune_4.png similarity index 100% rename from torch/ao/sparsity/experimental/pruner/images/prune_4.png rename to torch/ao/sparsity/_experimental/pruner/images/prune_4.png diff --git a/torch/ao/sparsity/experimental/pruner/parametrization.py b/torch/ao/sparsity/_experimental/pruner/parametrization.py similarity index 96% rename from torch/ao/sparsity/experimental/pruner/parametrization.py rename to torch/ao/sparsity/_experimental/pruner/parametrization.py index 94e9eafc6c565..77c86a22e175a 100644 --- a/torch/ao/sparsity/experimental/pruner/parametrization.py +++ b/torch/ao/sparsity/_experimental/pruner/parametrization.py @@ -2,6 +2,7 @@ from torch import nn from typing import Any, List +__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook'] class PruningParametrization(nn.Module): def __init__(self, original_outputs): diff --git a/torch/ao/sparsity/sparsifier/base_sparsifier.py b/torch/ao/sparsity/sparsifier/base_sparsifier.py index a9388e262bfa8..0f8147621d8e9 100644 --- a/torch/ao/sparsity/sparsifier/base_sparsifier.py +++ b/torch/ao/sparsity/sparsifier/base_sparsifier.py @@ -1,14 +1,17 @@ import abc import copy -import warnings from collections import defaultdict -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Set, Tuple, List, Type import torch from torch import nn from torch.nn.utils import parametrize -from .utils import FakeSparsity, module_to_fqn, fqn_to_module, get_arg_info_from_tensor_fqn +from .utils import ( + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) __all__ = ["BaseSparsifier"] @@ -16,7 +19,9 @@ nn.Linear } -KEYS_NOT_IN_STATE_DICT = ['module', 'module_fqn', 'tensor_name'] +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + +__all__ = ["BaseSparsifier"] # TODO update desc with new config args class BaseSparsifier(abc.ABC): @@ -43,24 +48,22 @@ class BaseSparsifier(abc.ABC): >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) >>> sparsifier = BaseSparsifier(config, defaults) """ - def __init__(self, defaults): + def __init__(self, defaults: Optional[Dict[str, Any]] = None): super().__init__() - self.defaults = defaults - if self.defaults is None: - self.defaults = dict() + self.defaults: Dict[str, Any] = defaults or dict() self.state: Dict[str, Dict] = defaultdict(dict) - self.groups = [] + self.groups: List[Dict[str, Any]] = [] self.enable_mask_update = True - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { 'defaults': self.defaults, 'state': self.state, 'groups': self.groups, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: self.__dict__.update(state) def __repr__(self): @@ -68,16 +71,16 @@ def __repr__(self): for i, sparse_args in enumerate(self.groups): module = sparse_args['module'] format_string += '\n' - format_string += f'\tModule Group {i}\n' + format_string += f'\tGroup {i}\n' format_string += f'\t module: {module}\n' for key in sorted(sparse_args.keys()): - if key == 'module': + if key == "module": continue - format_string += f'\t {key}: {sparse_args[key]}\n' - format_string += ')' + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" return format_string - def state_dict(self): + def state_dict(self) -> Dict[str, Any]: r"""Returns the state of the optimizer as a :class:`dict`. It contains: @@ -89,7 +92,7 @@ def state_dict(self): """ - groups = [ + groups: List[Dict[str, Any]] = [ dict(filter(lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT , mg.items())) for mg in self.groups ] @@ -99,15 +102,15 @@ def state_dict(self): 'groups': groups, } - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True): groups = copy.deepcopy(state_dict['groups']) states = state_dict['state'] for tensor_fqn, s in states.items(): arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) - module = arg_info['module'] - tensor_name = arg_info['tensor_name'] + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] if strict and module is None: - raise RuntimeError(f'Error loading {tensor_fqn} into the model') + raise RuntimeError(f"Error loading {tensor_fqn} into the model") found = False for p in module.parametrizations[tensor_name]: @@ -117,23 +120,31 @@ def load_state_dict(self, state_dict, strict=True): if not found: p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) parametrize.register_parametrization(module, tensor_name, p) - if s.get('mask', None) is not None: - mask = s.pop('mask') + if s.get("mask", None) is not None: + mask = s.pop("mask") p.mask = mask for mg in groups: - if mg['tensor_fqn'] == tensor_fqn: + if mg["tensor_fqn"] == tensor_fqn: mg.update(arg_info) - self.__setstate__({'state': states, 'groups': groups}) + self.__setstate__({"state": states, "groups": groups}) - def make_config_from_model(self, model, SUPPORTED_MODULES=SUPPORTED_MODULES, NEEDS_ZEROS=None): + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES, + ) -> None: self.config = [] stack = [model] while stack: module = stack.pop() for name, child in module.named_children(): if type(child) in SUPPORTED_MODULES: - self.config.append({'tensor_fqn': module_to_fqn(model, child) + '.weight'}) + module_fqn = module_to_fqn(model, child) + assert isinstance(module_fqn, str) # for mypy + self.config.append( + {"tensor_fqn": module_fqn + ".weight"} + ) else: stack.append(child) @@ -154,57 +165,40 @@ def prepare(self, model, config): # TODO: Remove the configuration by reference ('module') for module_config in self.config: - if isinstance(module_config, nn.Module): - warnings.warn("config elements should be dicts not modules") - module_config = {'module': module_config} + assert isinstance(module_config, dict), ( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + assert isinstance(self.defaults, Dict) # for mypy local_args = copy.deepcopy(self.defaults) local_args.update(module_config) - # Make sure there is at least one way of handling the model - tensor_fqn = local_args.get('tensor_fqn', None) - - if tensor_fqn is None: - warnings.warn( - "tensor_fqn is a required argument in the sparsity config" - "and support for `module` and `module_fqn` will be deprecated" - ) - module = local_args.get('module', None) - module_fqn = local_args.get('module_fqn', None) - - if module is None and module_fqn is None: - # No module given for this group - raise ValueError('Either `tensor_fqn` or `module` or `module_fqn` must be specified!') - elif module is None: - # FQN is given - module = fqn_to_module(model, module_fqn) - elif module_fqn is None: - # Module is given - module_fqn = module_to_fqn(model, module) - else: - # Both Module and FQN are given - module_from_fqn = fqn_to_module(model, module_fqn) - assert module is module_from_fqn, \ - 'Given both `module` and `fqn`, it is expected them to ' \ - 'refer to the same thing!' - if module_fqn and module_fqn[0] == '.': - module_fqn = module_fqn[1:] - local_args['module_fqn'] = module_fqn - local_args['module'] = module - local_args['tensor_fqn'] = module_fqn + '.weight' - local_args['tensor_name'] = 'weight' - else: - info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) - - # check that whatever was put into local_args agrees with what was obtained - # from tensor_fqn - for key in info_from_tensor_fqn.keys(): - if key in local_args: - # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that - assert key == 'tensor_fqn' or info_from_tensor_fqn[key] == local_args[key], ( - "Given both `{}` and `tensor_fqn`, it is expected them to " - "agree!".format(key) - ) - local_args.update(info_from_tensor_fqn) + tensor_fqn = local_args.get("tensor_fqn", None) + assert tensor_fqn is not None, ( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn.keys(): + if key in local_args: + assert ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ), ( + "Given both `{}` and `tensor_fqn` in the config, it is expected them to " + "agree!".format(key) + ) + local_args.update(info_from_tensor_fqn) self.groups.append(local_args) self._prepare() @@ -284,19 +278,22 @@ def squash_mask(self, global_params = {k: config[k] for k in params_to_keep} sparse_params.update(global_params) if params_to_keep_per_layer is not None: - params = params_to_keep_per_layer.get(config['module_fqn'], None) + params = params_to_keep_per_layer.get(config["module_fqn"], None) if params is not None: per_layer_params = {k: config[k] for k in params} sparse_params.update(per_layer_params) if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? module.sparse_params = sparse_params def convert(self): # TODO: Call the torch.ao.utils.convert in here - raise NotImplementedError('`convert` is not implemented. Please, use ' - '`torch.ao.utils.convert` instead.') + raise NotImplementedError( + "`convert` is not implemented. Please, use " + "`torch.ao.utils.convert` instead." + ) - def step(self, use_path=True): + def step(self, use_path: bool = True) -> None: if not self.enable_mask_update: return with torch.no_grad(): @@ -304,5 +301,5 @@ def step(self, use_path=True): self.update_mask(**config) @abc.abstractmethod - def update_mask(self, module, tensor_name, **kwargs): + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): pass diff --git a/torch/ao/sparsity/sparsifier/utils.py b/torch/ao/sparsity/sparsifier/utils.py index 850822996b579..ee0791a91dce3 100644 --- a/torch/ao/sparsity/sparsifier/utils.py +++ b/torch/ao/sparsity/sparsifier/utils.py @@ -1,48 +1,59 @@ +from typing import Any, Dict, Optional + from torch import nn -__all__ = ["module_to_fqn", "fqn_to_module", "get_arg_info_from_tensor_fqn", "FakeSparsity"] +__all__ = [ + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + -def module_to_fqn(model, module, prefix=''): +def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> Optional[str]: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" for name, child in model.named_children(): - new_name = prefix + '.' + name - if child is module: - return new_name - child_path = module_to_fqn(child, module, prefix=new_name) - if child_path is not None: - return child_path + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn return None -def fqn_to_module(model, path): - path = path.split('.') - for name in path: - model = getattr(model, name, None) - if model is None: - return None + +def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. + """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) return model -def get_arg_info_from_tensor_fqn(model, tensor_fqn): - # remove starting '.' from tensor_fqn if it exists - if tensor_fqn[0] == '.': - tensor_fqn = tensor_fqn[1:] +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ # string manip to split tensor_fqn into module_fqn and tensor_name # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' - tensor_name = tensor_fqn.split('.')[-1] - module_fqn = tensor_fqn[:-len(tensor_name) - ('.' in tensor_fqn)] - + tensor_name = tensor_fqn.split(".")[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] module = fqn_to_module(model, module_fqn) - if module is None: # handling for module_fqn='' - module = model return { - 'module_fqn': module_fqn, - 'module': module, - 'tensor_name': tensor_name, - 'tensor_fqn': tensor_fqn, + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, } + # Parametrizations class FakeSparsity(nn.Module): r"""Parametrization for the weights. Should be attached to the 'weight' or @@ -56,7 +67,7 @@ class FakeSparsity(nn.Module): """ def __init__(self, mask): super().__init__() - self.register_buffer('mask', mask) + self.register_buffer("mask", mask) def forward(self, x): assert self.mask.shape == x.shape diff --git a/torch/ao/sparsity/sparsifier/weight_norm_sparsifier.py b/torch/ao/sparsity/sparsifier/weight_norm_sparsifier.py index 5f8a1452369f9..2aca6d4663365 100644 --- a/torch/ao/sparsity/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/sparsity/sparsifier/weight_norm_sparsifier.py @@ -33,9 +33,16 @@ class WeightNormSparsifier(BaseSparsifier): Args: sparsity_level: The target level of sparsity - sparse_block_shape: The shape of a sparse block + sparse_block_shape: The shape of a sparse block (see note below) zeros_per_block: Number of zeros in a sparse block + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + Note:: All arguments to the WeightNormSparsifier constructor are "default" arguments and could be overriden by the configuration provided in the @@ -48,22 +55,119 @@ def __init__(self, if zeros_per_block is None: zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape) defaults = { - 'sparsity_level': sparsity_level, - 'sparse_block_shape': sparse_block_shape, - 'zeros_per_block': zeros_per_block + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, } super().__init__(defaults=defaults) + def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape, + mask=None, input_shape=None, device=None): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + assert input_shape is not None + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape) + return mask + + def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h, w, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce((lambda x, y: x * y), sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = int(round(sparsity_level * num_blocks)) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, output_shape=(h + dh, w + dw), + indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. + """ + if mask is None: + mask = torch.ones(data.shape, device=data.device) + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce((lambda x, y: x * y), sparse_block_shape) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False) + + self._scatter_fold_block_mask( + dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs): values_per_block = reduce((lambda x, y: x * y), sparse_block_shape) if zeros_per_block > values_per_block: - raise ValueError("Number of zeros per block cannot be more than " - "the total number of elements in that block.") + raise ValueError( + "Number of zeros per block cannot be more than " "the total number of elements in that block." + ) if zeros_per_block < 0: raise ValueError("Number of zeros per block should be positive.") - # TODO: Add support for multiple parametrizations for the same weight mask = getattr(module.parametrizations, tensor_name)[0].mask if sparsity_level <= 0 or zeros_per_block == 0: mask.data = torch.ones_like(mask) @@ -71,39 +175,11 @@ def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, mask.data = torch.zeros_like(mask) else: ww = getattr(module, tensor_name)**2 - ww_reshaped = ww.reshape(1, *ww.shape) - ww_pool = F.avg_pool2d(ww_reshaped, kernel_size=sparse_block_shape, - stride=sparse_block_shape, ceil_mode=True) - ww_pool_flat = ww_pool.flatten() - _, sorted_idx = torch.sort(ww_pool_flat) - threshold_idx = int(round(sparsity_level * len(sorted_idx))) - sorted_idx = sorted_idx[:threshold_idx] - rows, cols = _flat_idx_to_2d(sorted_idx, ww_pool.shape[1:]) - rows *= sparse_block_shape[0] - cols *= sparse_block_shape[1] - - new_mask = torch.ones(ww.shape, device=getattr(module, tensor_name).device) - for row, col in zip(rows, cols): - submask = new_mask[row:row + sparse_block_shape[0], - col:col + sparse_block_shape[1]] - subweight = getattr(module, tensor_name)[row:row + sparse_block_shape[0], - col:col + sparse_block_shape[1]] - self._update_block(submask, subweight, - zeros_per_block, values_per_block) - mask.data = new_mask - - def _update_block(self, - submask: torch.Tensor, - subweight: torch.Tensor, - zeros_per_block: int, - values_per_block: int): - r"""Updates a single sparse block""" - if zeros_per_block == values_per_block: - submask[:] = 0 - else: - w = torch.abs(subweight) - w_flat = w.flatten() - _, sorted_idx = torch.sort(w_flat) - sorted_idx = sorted_idx[:zeros_per_block] - rows, cols = _flat_idx_to_2d(sorted_idx, submask.shape) - submask[rows, cols] = 0 + tensor_mask = self._make_tensor_mask( + data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 755b097f1a740..38b102743f350 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -22,6 +22,17 @@ def save_for_backward(self, *tensors: torch.Tensor): incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`. + Note that if intermediary tensors, tensors that are neither inputs + nor outputs of :func:`forward`, are saved for backward, your custom Function + may not support double backward. + Custom Functions that do not support double backward should decorate their + :func:`backward` method with ``@once_differentiable`` so that performing + double backward raises an error. If you'd like to support double backward, + you can either recompute intermediaries based on the inputs during backward + or return the intermediaries as the outputs of the custom Function. See the + `double backward tutorial `_ + for more details. + In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors` attribute. Before returning them to the user, a check is made to ensure they weren't used in any in-place operation that modified their content. @@ -34,18 +45,19 @@ def save_for_backward(self, *tensors: torch.Tensor): >>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): - >>> w = x * y * z - >>> out = x * y + y * z + w + >>> w = x * z + >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z is not a tensor >>> return out >>> >>> @staticmethod + >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) - >>> gy = grad_out * (x + z + x * z) + >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index eb8c46f8f124a..0299f79f3e474 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,19 +1,34 @@ -from torch.autograd.profiler_util import ( - EventList, FunctionEvent, MemRecordsAcc, MEMORY_EVENT_NAME, - _filter_name, _filter_stack_entry, _rewrite_name -) +from typing import Any, Dict, List, Optional +from warnings import warn -from torch.autograd import ( - DeviceType, ProfilerActivity, ProfilerConfig, ProfilerState, - kineto_available, _ProfilerResult, _disable_profiler, _enable_profiler, - _prepare_profiler, _supported_activities, _kineto_step, -) -from torch._C._autograd import _ExperimentalConfig import torch import torch.cuda +from torch._C._autograd import _ExperimentalConfig + +from torch.autograd import ( + _disable_profiler, + _enable_profiler, + _kineto_step, + _prepare_profiler, + _ProfilerResult, + _supported_activities, + DeviceType, + kineto_available, + ProfilerActivity, + ProfilerConfig, + ProfilerState, +) +from torch.autograd.profiler_util import ( + _filter_name, + _filter_stack_entry, + _rewrite_name, + EventList, + FunctionEvent, + MEMORY_EVENT_NAME, + MemRecordsAcc, + OUT_OF_MEMORY_EVENT_NAME, +) from torch.futures import Future -from typing import Any, Dict, List, Optional -from warnings import warn try: @@ -282,6 +297,7 @@ def _parse_kineto_results(self, result): trace_start_us = result.trace_start_us() mem_records = [[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME] + oom_records = [evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME] mem_records_acc = MemRecordsAcc(mem_records) def _cpu_memory_usage(mem_record): @@ -370,31 +386,41 @@ def _cuda_memory_usage(mem_record): # parents and children f_evt.thread = fe.thread + + def createFunctionEventForMemoryEvents(evt): + rel_start_us = evt.start_us() - trace_start_us + fe = FunctionEvent( + id=max_evt_id, + name=evt.name(), + trace_name=None, # not outputting in the trace + thread=evt.start_thread_id(), + start_us=rel_start_us, + end_us=rel_start_us, # no duration + fwd_thread=evt.start_thread_id(), + input_shapes=[], + stack=[], + scope=0, # RecordScope::FUNCTION + cpu_memory_usage=_cpu_memory_usage(evt), + cuda_memory_usage=_cuda_memory_usage(evt), + is_async=False, + sequence_nr=-1, + device_type=DeviceType.CPU, + device_index=0, + ) + return fe + # output top-level memory events for mem_record in mem_records: if not mem_record[1]: - rel_start_us = mem_record[0].start_us() - trace_start_us max_evt_id += 1 - fe = FunctionEvent( - id=max_evt_id, - name=MEMORY_EVENT_NAME, - trace_name=None, # not outputting in the trace - thread=mem_record[0].start_thread_id(), - start_us=rel_start_us, - end_us=rel_start_us, # no duration - fwd_thread=mem_record[0].start_thread_id(), - input_shapes=[], - stack=[], - scope=0, # RecordScope::FUNCTION - cpu_memory_usage=_cpu_memory_usage(mem_record[0]), - cuda_memory_usage=_cuda_memory_usage(mem_record[0]), - is_async=False, - sequence_nr=-1, - device_type=DeviceType.CPU, - device_index=0, - ) + fe = createFunctionEventForMemoryEvents(mem_record[0]) function_events.append(fe) + for oom_record in oom_records: + max_evt_id += 1 + fe = createFunctionEventForMemoryEvents(oom_record) + function_events.append(fe) + function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) return function_events @@ -479,6 +505,71 @@ def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]: return profiled_future +class emit_itt(object): + """Context manager that makes every autograd operation emit an ITT range. + + It is useful when running the program under Intel(R) VTune Profiler:: + + vtune <--vtune_flags> + + The Instrumentation and Tracing Technology (ITT) API enables your application to generate and + control the collection of trace data during its execution across different Intel tools. + This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager, + you will be able to see labled ranges in Intel(R) VTune Profiler GUI. + + .. warning: + This context manager should not be called recursively, i.e. at most one + instance should be enabled at any given time. + + Args: + enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op. + Default: ``True``. + record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping + each autograd op will append information about the sizes of Tensor arguments received + by that op, in the following format: + ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]`` + Non-tensor arguments will be represented by ``[]``. + Arguments will be listed in the order they are received by the backend op. + Please note that this order may not match the order in which those arguments were passed + on the Python side. Also note that shape recording may increase the overhead of itt range creation. + Default: ``False`` + + Example: + >>> with torch.autograd.profiler.emit_itt(): + ... model(x) + + """ + def __init__(self, enabled=True, record_shapes=False): + self.enabled = enabled + self.entered = False + self.record_shapes = record_shapes + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("ITT annotation context manager is not reentrant") + self.entered = True + _enable_profiler( + ProfilerConfig( + ProfilerState.ITT, + self.record_shapes, + False, + False, + False, + False, + _ExperimentalConfig()), + set() + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + _disable_profiler() + return False + + class emit_nvtx(object): """Context manager that makes every autograd operation emit an NVTX range. @@ -498,9 +589,9 @@ class emit_nvtx(object): instance should be enabled at any given time. Args: - enabled (bool, optional, default=True): Setting ``enabled=False`` makes this context manager a no-op. + enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op. Default: ``True``. - record_shapes (bool, optional, default=False): If ``record_shapes=True``, the nvtx range wrapping + record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping each autograd op will append information about the sizes of Tensor arguments received by that op, in the following format: ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]`` @@ -508,6 +599,7 @@ class emit_nvtx(object): Arguments will be listed in the order they are received by the backend op. Please note that this order may not match the order in which those arguments were passed on the Python side. Also note that shape recording may increase the overhead of nvtx range creation. + Default: ``False`` Example: >>> with torch.cuda.profiler.profile(): diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index dc505fbc210aa..49c181f73de93 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -636,11 +636,13 @@ def _filter_stack_entry(entry): return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) MEMORY_EVENT_NAME = "[memory]" +OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]" def _filter_name(name): # ignoring the following utility ops filtered_out_names = [ MEMORY_EVENT_NAME, # used only for the top-level memory events + OUT_OF_MEMORY_EVENT_NAME, "profiler::_record_function_enter", "profiler::_record_function_enter_new", "profiler::_record_function_exit", diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index d89049b5f3ca1..e187d6d26aed8 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -1,4 +1,5 @@ import sys +import os import torch import warnings from contextlib import contextmanager @@ -36,9 +37,26 @@ def _init(): else: cudnn_compatible = runtime_minor >= compile_minor if not cudnn_compatible: - raise RuntimeError( - 'cuDNN version incompatibility: PyTorch was compiled against {} ' - 'but linked against {}'.format(compile_version, runtime_version)) + base_error_msg = (f'cuDNN version incompatibility: ' + f'PyTorch was compiled against {compile_version} ' + f'but found runtime version {runtime_version}. ' + f'PyTorch already comes bundled with cuDNN. ' + f'One option to resolving this error is to ensure PyTorch ' + f'can find the bundled cuDNN.') + + if 'LD_LIBRARY_PATH' in os.environ: + ld_library_path = os.environ.get('LD_LIBRARY_PATH', '') + if any(substring in ld_library_path for substring in ['cuda', 'cudnn']): + raise RuntimeError(f'{base_error_msg}' + f'Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn' + f'Please either remove it from the path or install cudnn {compile_version}') + else: + raise RuntimeError(f'{base_error_msg}' + f'one possibility is that there is a ' + f'conflicting cuDNN in LD_LIBRARY_PATH.') + else: + raise RuntimeError(base_error_msg) + return True else: def _init(): @@ -84,15 +102,18 @@ def is_acceptable(tensor): return True -def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None): +def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None): orig_flags = (torch._C._get_cudnn_enabled(), torch._C._get_cudnn_benchmark(), + None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), torch._C._get_cudnn_deterministic(), torch._C._get_cudnn_allow_tf32()) if _enabled is not None: torch._C._set_cudnn_enabled(_enabled) if _benchmark is not None: torch._C._set_cudnn_benchmark(_benchmark) + if _benchmark_limit is not None and is_available(): + torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit) if _deterministic is not None: torch._C._set_cudnn_deterministic(_deterministic) if _allow_tf32 is not None: @@ -101,9 +122,9 @@ def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=N @contextmanager -def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True): +def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True): with __allow_nonbracketed_mutation(): - orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32) + orig_flags = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32) try: yield finally: @@ -123,6 +144,9 @@ def __init__(self, m, name): enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic) benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark) + benchmark_limit = None + if is_available(): + benchmark_limit = ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit) allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32) # This is the sys.modules replacement trick, see @@ -134,3 +158,4 @@ def __init__(self, m, name): deterministic: bool benchmark: bool allow_tf32: bool +benchmark_limit: int diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index f3d27d1fa08c8..25c11ea10515e 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -1,6 +1,49 @@ import torch - def is_available(): r"""Returns whether PyTorch is built with MKL support.""" return torch._C.has_mkl + +VERBOSE_OFF = 0 +VERBOSE_ON = 1 +class verbose(object): + """ + On-demand oneMKL verbosing functionality + To make it easier to debug performance issues, oneMKL can dump verbose + messages containing execution information like duration while executing + the kernel. The verbosing functionality can be invoked via an environment + variable named `MKL_VERBOSE`. However, this methodology dumps messages in + all steps. Those are a large amount of verbose messages. Moreover, for + investigating the performance issues, generally taking verbose messages + for one single iteration is enough. This on-demand verbosing functionality + makes it possible to control scope for verbose message dumping. In the + following example, verbose messages will be dumped out for the second + inference only. + + .. highlight:: python + .. code-block:: python + + import torch + model(data) + with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + """ + + def __init__(self, enable): + self.enable = enable + + def __enter__(self): + if self.enable == VERBOSE_OFF: + return + st = torch._C._verbose.mkl_set_verbose(self.enable) + assert st, "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkl_set_verbose(VERBOSE_OFF) + return False diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index c62af5df3b048..00b22cee15e09 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -7,6 +7,52 @@ def is_available(): r"""Returns whether PyTorch is built with MKL-DNN support.""" return torch._C.has_mkldnn +VERBOSE_OFF = 0 +VERBOSE_ON = 1 +VERBOSE_ON_CREATION = 2 +class verbose(object): + """ + On-demand oneDNN (former MKL-DNN) verbosing functionality + To make it easier to debug performance issues, oneDNN can dump verbose + messages containing information like kernel size, input data size and + execution duration while executing the kernel. The verbosing functionality + can be invoked via an environment variable named `DNNL_VERBOSE`. However, + this methodology dumps messages in all steps. Those are a large amount of + verbose messages. Moreover, for investigating the performance issues, + generally taking verbose messages for one single iteration is enough. + This on-demand verbosing functionality makes it possible to control scope + for verbose message dumping. In the following example, verbose messages + will be dumped out for the second inference only. + + .. highlight:: python + .. code-block:: python + + import torch + model(data) + with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation + """ + + def __init__(self, level): + self.level = level + + def __enter__(self): + if self.level == VERBOSE_OFF: + return + st = torch._C._verbose.mkldnn_set_verbose(self.level) + assert st, "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF) + return False + def set_flags(_enabled): orig_flags = (torch._C._get_mkldnn_enabled(),) torch._C._set_mkldnn_enabled(_enabled) diff --git a/torch/backends/xeon/__init__.py b/torch/backends/xeon/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py new file mode 100644 index 0000000000000..c056af9644789 --- /dev/null +++ b/torch/backends/xeon/run_cpu.py @@ -0,0 +1,659 @@ +""" +This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations. +Single instance inference, multi-instance inference are enabled. + +Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes +multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this +context. + +Illustrated as below: + +:: + + +-----------------------------+----------------------+-------+ + | process | thread | core | + +=============================+======================+=======+ + | torch.backends.xeon.run_cpu | instance 0: thread 0 | 0 | + | | thread 1 | 1 | + | +----------------------+-------+ + | | instance 1: thread 0 | 2 | + | | thread 1 | 3 | + | +----------------------+-------+ + | | ... | ... | + | +----------------------+-------+ + | | instance N: thread 0 | M | + | | thread 1 | M+1 | + +-----------------------------+----------------------+-------+ + +To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory +management. For thread management, the script configures thread affinity and the preload of Intel OMP library. +For memory management, it configures NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc). + +Environment variables that will be set by this script: + ++------------------+-------------------------------------------------------------------------------------------------+ +| Environ Variable | Value | ++==================+=================================================================================================+ +| LD_PRELOAD | Depending on knobs you set, /libiomp5.so, /libjemalloc.so, /libtcmalloc.so might | +| | be appended to LD_PRELOAD. | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_AFFINITY | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0". | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_BLOCKTIME | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1". | ++------------------+-------------------------------------------------------------------------------------------------+ +| OMP_NUM_THREADS | value of ncores_per_instance | ++------------------+-------------------------------------------------------------------------------------------------+ +| MALLOC_CONF | If libjemalloc.so is preloaded, MALLOC_CONF will be set to | +| | "oversize_threshold:1,background_thread:true,metadata_thp:auto". | ++------------------+-------------------------------------------------------------------------------------------------+ + +*Note*: This script respects environment variables set preliminarily. I.e. If you set the environment variables +mentioned above before running the script, the script will not overwrite the values in the script. + +How to use this module: +~~~~~~~~~~~~~~~~~~~~~~~ + +Single instance inference +------------------------- + +1. Run single-instance inference on a single node with all CPU nodes. + +:: + + >>> python -m torch.backends.xeon.run_cpu --throughput_mode script.py args + +2. Run single-instance inference on a single CPU node. + +:: + + >>> python -m torch.backends.xeon.run_cpu --node_id 1 script.py args + +Multi-instance inference +------------------------ + +1. Multi-instance + By default this tool runs one process per node. If you want to set the instance numbers and core per instance, + --ninstances and --ncores_per_instance should be set. + +:: + + >>> python -m torch.backends.xeon.run_cpu -- python_script args + + eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instance, 4 cores per instance + +:: + + >>> python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores_per_instance 4 python_script args + +2. Run single-instance inference among multiple instances. + By default, runs all ninstances. If you want to independently run a single instance among ninstances, specify rank. + + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 0-27) + +:: + + >>> python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args + + eg: run 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 28-55) + +:: + + >>> python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args + + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance, 2 cores per instance, + first four cores (i.e., numactl -C 0-1) + +:: + + >>> python -m torch.backends.xeon.run_cpu --core_list "0, 1, 2, 3" --ninstances 2 --ncores_per_instance 2 + --rank 0 python_script args + +3. To look up what optional arguments this module offers: + +:: + + >>> python -m torch.backends.xeon.run_cpu --help + +Memory allocator +---------------- + +"--enable_tcmalloc" and "--enable_jemalloc" can be used to enable different memory allcator. + +""" + +import sys +import platform +import subprocess +import os +from os.path import expanduser +import re +import glob +from argparse import ArgumentParser, REMAINDER +from argparse import RawTextHelpFormatter +import logging +from torch.distributed.elastic.multiprocessing import Std, start_processes +from typing import List, Dict + +format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.INFO, format=format_str) +logger = logging.getLogger(__name__) + +class _CPUinfo(): + """ + Get CPU inforamation, such as cores list and NUMA information. + """ + def __init__(self, test_input=""): + + self.cpuinfo = [] + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + elif platform.system() == "Linux": + # Sample output of: `lscpu --parse=CPU,Core,Socket,Node` + # + # # The following is the parsable format, which can be fed to other + # # programs. Each different item in every column has an unique ID + # # starting from zero. + # # CPU,Core,Socket,Node + # 0,0,0,0 + # 1,1,0,0 + # ... + if test_input == "": + lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"] + lscpu_info = subprocess.check_output(lscpu_cmd, universal_newlines=True).split("\n") + else: + lscpu_info = test_input.split("\n") + + # Get information about cpu, core, socket and node + for line in lscpu_info: + pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)" + regex_out = re.search(pattern, line) + if regex_out: + self.cpuinfo.append(regex_out.group(1).strip().split(",")) + + # physical cores := core column in lscpu output + # logical cores := cPU column in lscpu output + self.node_nums = int(max([line[3] for line in self.cpuinfo])) + 1 + self.node_physical_cores: List[List[int]] = [] # node_id is index + self.node_logical_cores: List[List[int]] = [] # node_id is index + self.physical_core_node_map = {} # phyical core to numa node id + self.logical_core_node_map = {} # logical core to numa node id + + for node_id in range(self.node_nums): + cur_node_physical_core = [] + cur_node_logical_core = [] + for cpuinfo in self.cpuinfo: + nid = cpuinfo[3] if cpuinfo[3] != "" else "0" + if node_id == int(nid): + if int(cpuinfo[1]) not in cur_node_physical_core: + cur_node_physical_core.append(int(cpuinfo[1])) + self.physical_core_node_map[int(cpuinfo[1])] = int(node_id) + cur_node_logical_core.append(int(cpuinfo[0])) + self.logical_core_node_map[int(cpuinfo[0])] = int(node_id) + self.node_physical_cores.append(cur_node_physical_core) + self.node_logical_cores.append(cur_node_logical_core) + + def _physical_core_nums(self): + return len(self.node_physical_cores) * len(self.node_physical_cores[0]) + + def _logical_core_nums(self): + return len(self.node_logical_cores) * len(self.node_logical_cores[0]) + + def get_node_physical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError(f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}") + return self.node_physical_cores[node_id] + + def get_node_logical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError(f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}") + return self.node_logical_cores[node_id] + + def get_all_physical_cores(self): + all_cores = [] + for cores in self.node_physical_cores: + all_cores.extend(cores) + return all_cores + + def get_all_logical_cores(self): + all_cores = [] + for cores in self.node_logical_cores: + all_cores.extend(cores) + return all_cores + + def numa_aware_check(self, core_list): + """ + Check whether all cores in core_list are in the same NUMA node. cross NUMA will reduce perforamnce. + We strongly advice to not use cores on different nodes. + """ + cores_numa_map = self.logical_core_node_map + numa_ids = [] + for core in core_list: + numa_id = cores_numa_map[core] + if numa_id not in numa_ids: + numa_ids.append(numa_id) + if len(numa_ids) > 1: + logger.warning(f"Numa Aware: cores:{str(core_list)} on different NUMA nodes:{str(numa_ids)}. To avoid \ +this behavior, please use --ncores_per_instance knob to make sure number of cores is divisible by --ncores_per_\ +instance. Alternatively, please use --skip_cross_node_cores knob.") + if len(numa_ids) == 0: + raise RuntimeError("invalid number of NUMA nodes; please make sure numa_ids >= 1") + return numa_ids + +class _Launcher(): + r""" + Class for launcher + """ + + msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ +or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ +{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." + + def __init__(self): + self.cpuinfo = _CPUinfo() + + def add_lib_preload(self, lib_type): + """ + Enale TCMalloc/JeMalloc/intel OpenMP + """ + library_paths = [] + if "CONDA_PREFIX" in os.environ: + library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib") + if "VIRTUAL_ENV" in os.environ: + library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib") + + library_paths += [f"{expanduser('~')}/.local/lib", "/usr/local/lib", + "/usr/local/lib64", "/usr/lib", "/usr/lib64"] + + lib_find = False + lib_set = False + for item in os.getenv("LD_PRELOAD", "").split(":"): + if item.endswith(f"lib{lib_type}.so"): + lib_set = True + break + if not lib_set: + for lib_path in library_paths: + library_file = os.path.join(lib_path, f"lib{lib_type}.so") + matches = glob.glob(library_file) + if len(matches) > 0: + ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")] + os.environ["LD_PRELOAD"] = os.pathsep.join([p.strip(os.pathsep) for p in ld_preloads if p]) + lib_find = True + break + return lib_set or lib_find + + + def set_memory_allocator(self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False): + """ + Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc. + By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better + memory resue and reduce page fault to improve performance. + """ + if enable_tcmalloc and enable_jemalloc: + raise RuntimeError("Unable to enable TCMalloc and JEMalloc at the same time.") + + if enable_tcmalloc: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if not find_tc: + msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge gperftools\" to install {{0}}" + logger.warning(msg.format("TCmalloc", "tcmalloc")) + else: + logger.info("Use TCMalloc memory allocator") + + elif enable_jemalloc: + find_je = self.add_lib_preload(lib_type="jemalloc") + if not find_je: + msg = f"{self.msg_lib_notfound} you can use \"conda install -c conda-forge jemalloc\" to install {{0}}" + logger.warning(msg.format("Jemalloc", "jemalloc")) + else: + logger.info("Use JeMalloc memory allocator") + self.set_env("MALLOC_CONF", "oversize_threshold:1,background_thread:true,metadata_thp:auto") + + elif use_default_allocator: + pass + + else: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if find_tc: + logger.info("Use TCMalloc memory allocator") + return + find_je = self.add_lib_preload(lib_type="jemalloc") + if find_je: + logger.info("Use JeMalloc memory allocator") + return + logger.warning(f"""Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib + or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or + {expanduser("~")}/.local/lib/ so the LD_PRELOAD environment variable will not be set. + This may drop the performance""") + + def log_env_var(self, env_var_name=""): + if env_var_name in os.environ: + logger.info(f"{env_var_name}={os.environ[env_var_name]}") + + def set_env(self, env_name, env_value): + if not env_value: + logger.warning(f"{env_name} is None") + if env_name not in os.environ: + os.environ[env_name] = env_value + elif os.environ[env_name] != env_value: + logger.warning(f"Overriding value with the one set in environment variable: {env_name}. \ +Value applied: {os.environ[env_name]}. Value ignored: {env_value}") + self.log_env_var(env_name) + + # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not. + # In scenario that use all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores. + # In this case, KMP_AFFINITY should not be set. + def set_multi_thread_and_allocator(self, ncores_per_instance, + disable_iomp=False, + set_kmp_affinity=True, + enable_tcmalloc=True, + enable_jemalloc=False, + use_default_allocator=False): + """ + Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc. + By default, GNU openMP and PTMalloc are used in PyTorch. but Intel openMP and TCMalloc/JeMalloc are better alternatives + to get performance benifit. + """ + self.set_memory_allocator(enable_tcmalloc, enable_jemalloc, use_default_allocator) + self.set_env("OMP_NUM_THREADS", str(ncores_per_instance)) + if not disable_iomp: + find_iomp = self.add_lib_preload(lib_type="iomp5") + if not find_iomp: + msg = f"{self.msg_lib_notfound} you can use \"conda install mkl\" to install {{0}}" + logger.warning(msg.format("iomp", "iomp5")) + else: + logger.info("Using Intel OpenMP") + if set_kmp_affinity: + self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + self.set_env("KMP_BLOCKTIME", "1") + self.log_env_var("LD_PRELOAD") + + r""" + Launcher for single instance and multi-instance + """ + def launch(self, args): + cores = [] + set_kmp_affinity = True + if args.core_list: # user specify what cores will be used by params + cores = [int(x) for x in args.core_list.split(",")] + if args.ncores_per_instance == -1: + raise RuntimeError("please specify the \"--ncores_per_instance\" if you have pass the --core_list params") + elif args.ninstances > 1 and args.ncores_per_instance * args.ninstances < len(cores): + logger.warning(f"only first {args.ncores_per_instance * args.ninstances} cores will be used, \ +but you specify {len(cores)} cores in core_list") + else: + args.ninstances = len(cores) // args.ncores_per_instance + + else: + if args.use_logical_core: + if args.node_id != -1: + cores = self.cpuinfo.get_node_logical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_logical_cores() + # When using all cores on all nodes, including logical cores, + # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set. + set_kmp_affinity = False + else: + if args.node_id != -1: + cores = self.cpuinfo.get_node_physical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_physical_cores() + if not args.multi_instance and args.ninstances == -1 and args.ncores_per_instance == -1: + args.ninstances = 1 + args.ncores_per_instance = len(cores) + elif args.multi_instance and args.ninstances == -1 and args.ncores_per_instance == -1: + args.throughput_mode = True + elif args.ncores_per_instance == -1 and args.ninstances != -1: + if args.ninstances > len(cores): + raise RuntimeError(f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \ +please make sure ninstances <= total_cores)") + else: + args.ncores_per_instance = len(cores) // args.ninstances + elif args.ncores_per_instance != -1 and args.ninstances == -1: + if not args.skip_cross_node_cores: + args.ninstances = len(cores) // args.ncores_per_instance + else: + ncore_per_node = len(self.cpuinfo.node_physical_cores[0]) + num_leftover_cores = ncore_per_node % args.ncores_per_instance + if args.ncores_per_instance > ncore_per_node: + # too many ncores_per_instance to skip cross-node cores + logger.warning("there are {} core(s) per socket, but you specify {} ncores_per_instance and \ +skip_cross_node_cores. Please make sure --ncores_per_instance < core(s) per \ +socket".format(ncore_per_node, args.ncores_per_instance)) + exit(-1) + elif num_leftover_cores == 0: + # aren't any cross-node cores + logger.info('--skip_cross_node_cores is set, but there are no cross-node cores.') + args.ninstances = len(cores) // args.ncores_per_instance + else: + # skip cross-node cores + if args.ninstances != -1: + logger.warning('--skip_cross_node_cores is exclusive to --ninstances. --ninstances \ +won\'t take effect even if it is set explicitly.') + + i = 1 + leftover_cores = set() + while ncore_per_node * i <= len(cores): + leftover_cores.update(cores[ncore_per_node * i - num_leftover_cores : ncore_per_node * i]) + i += 1 + cores = list(set(cores) - leftover_cores) + assert len(cores) % args.ncores_per_instance == 0 + args.ninstances = len(cores) // args.ncores_per_instance + else: + if args.ninstances * args.ncores_per_instance > len(cores): + raise RuntimeError("Please make sure ninstances * ncores_per_instance <= total_cores") + if args.latency_mode: + logger.warning("--latency_mode is exclusive to --ninstances, --ncores_per_instance, --node_id and \ +--use_logical_core. They won't take effect even they are set explicitly.") + args.ncores_per_instance = 4 + cores = self.cpuinfo.get_all_physical_cores() + args.ninstances = len(cores) // args.ncores_per_instance + + if args.throughput_mode: + logger.warning("--throughput_mode is exclusive to --ninstances, --ncores_per_instance, --node_id and \ +--use_logical_core. They won't take effect even they are set explicitly.") + args.ninstances = self.cpuinfo.node_nums + cores = self.cpuinfo.get_all_physical_cores() + args.ncores_per_instance = len(cores) // args.ninstances + + if args.ninstances > 1 and args.rank != -1: + logger.info(f"assigning {args.ncores_per_instance} cores for instance {args.rank}") + + self.set_multi_thread_and_allocator(args.ncores_per_instance, + args.disable_iomp, + set_kmp_affinity, + args.enable_tcmalloc, + args.enable_jemalloc, + args.use_default_allocator) + entrypoint = "" + launch_args = {} + launch_envs: Dict[int, Dict] = {} + launch_tee = {} + for i in range(args.ninstances): + cmd = [] + cur_process_cores = "" + if not args.disable_numactl: + cmd = ["numactl"] + cores = sorted(cores) + if args.rank == -1: # sequentially assign ncores_per_instance to ninstances + core_list = cores[i * args.ncores_per_instance : (i + 1) * args.ncores_per_instance] + else: # assign ncores_per_instance from rank + core_list = cores[args.rank * args.ncores_per_instance + : (args.rank + 1) * args.ncores_per_instance] + + core_ranges: List[Dict] = [] + for core in core_list: + if len(core_ranges) == 0: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + else: + if core - core_ranges[-1]["end"] == 1: + core_ranges[-1]["end"] = core + else: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + for r in core_ranges: + cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']}," + cur_process_cores = cur_process_cores[:-1] + numa_params = f"-C {cur_process_cores} " + numa_ids = ",".join([str(numa_id) for numa_id in self.cpuinfo.numa_aware_check(core_list)]) + numa_params += f"-m {numa_ids}" + cmd.extend(numa_params.split()) + with_python = not args.no_python + if with_python: + cmd.append(sys.executable) + cmd.append("-u") + if args.module: + cmd.append("-m") + cmd.append(args.program) + cmd.extend(args.program_args) + cmd_s = " ".join(cmd) + logger.info(cmd_s) + if entrypoint == "": + entrypoint = cmd[0] + del cmd[0] + launch_args[i] = tuple(cmd) + launch_envs[i] = {} + launch_tee[i] = Std.ALL + + if args.rank != -1: # launches single instance, rank, only + break + + ctx = start_processes(name=args.log_file_prefix, + entrypoint=entrypoint, + args=launch_args, + envs=launch_envs, + log_dir=args.log_path, + tee=launch_tee) + ctx.wait() + + +def _add_memory_allocator_params(parser): + + group = parser.add_argument_group("Memory Allocator Parameters") + # allocator control + group.add_argument("--enable_tcmalloc", action="store_true", default=False, + help="Enable tcmalloc allocator") + group.add_argument("--enable_jemalloc", action="store_true", default=False, + help="Enable jemalloc allocator") + group.add_argument("--use_default_allocator", action="store_true", default=False, + help="Use default memory allocator") + +def _add_multi_instance_params(parser): + + group = parser.add_argument_group("Multi-instance Parameters") + # multi-instance control + group.add_argument("--ncores_per_instance", metavar="\b", default=-1, type=int, + help="Cores per instance") + group.add_argument("--ninstances", metavar="\b", default=-1, type=int, + help="For multi-instance, you should give the cores number you used for per instance.") + group.add_argument("--skip_cross_node_cores", action='store_true', default=False, + help="If specified --ncores_per_instance, skips cross-node cores.") + group.add_argument("--rank", metavar="\b", default="-1", type=int, + help="Specify instance index to assign ncores_per_instance for rank; \ +otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \ +https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md") + group.add_argument("--latency_mode", action="store_true", default=False, + help="By detault 4 core per instance and use all physical cores") + group.add_argument("--throughput_mode", action="store_true", default=False, + help="By default one instance per node and use all physical cores") + group.add_argument("--node_id", metavar="\b", default=-1, type=int, + help="node id for multi-instance, by default all nodes will be used") + group.add_argument("--use_logical_core", action="store_true", default=False, + help="Whether only use physical cores") + group.add_argument("--disable_numactl", action="store_true", default=False, + help="Disable numactl") + group.add_argument("--core_list", metavar="\b", default=None, type=str, + help="Specify the core list as \"core_id, core_id, ....\", otherwise, all the cores will be used.") + group.add_argument("--log_path", metavar="\b", default="logs", type=str, + help="The log file directory. Default path is "", which means disable logging to files.") + group.add_argument("--log_file_prefix", metavar="\b", default="run", type=str, + help="log file prefix") + +def _add_kmp_iomp_params(parser): + + group = parser.add_argument_group("IOMP Parameters") + group.add_argument("--disable_iomp", action="store_true", default=False, + help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD") + +def create_args(parser=None): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser.add_argument("--multi_instance", action="store_true", default=False, + help="Enable multi-instance, by default one instance per node") + + parser.add_argument("-m", "--module", default=False, action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + "\"python -m\".") + + parser.add_argument("--no_python", default=False, action="store_true", + help="Do not prepend the --program script with \"python\" - just exec " + "it directly. Useful when the script is not a Python script.") + + _add_memory_allocator_params(parser) + _add_kmp_iomp_params(parser) + + _add_multi_instance_params(parser) + # positional + parser.add_argument("program", type=str, + help="The full path to the proram/script to be launched. " + "followed by all the arguments for the script") + + # rest from the training program + parser.add_argument("program_args", nargs=REMAINDER) + +def main(args): + env_before = set(os.environ.keys()) + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + + if args.log_path: + os.makedirs(args.log_path, exist_ok=True) + + if args.latency_mode and args.throughput_mode: + raise RuntimeError("Either args.latency_mode or args.throughput_mode should be set") + + if not args.no_python and not args.program.endswith(".py"): + raise RuntimeError("For non Python script, you should use \"--no_python\" parameter.") + + # Verify LD_PRELOAD + if "LD_PRELOAD" in os.environ: + lst_valid = [] + tmp_ldpreload = os.environ["LD_PRELOAD"] + for item in tmp_ldpreload.split(":"): + matches = glob.glob(item) + if len(matches) > 0: + lst_valid.append(item) + else: + logger.warning(f"{item} doesn't exist. Removing it from LD_PRELOAD.") + if len(lst_valid) > 0: + os.environ["LD_PRELOAD"] = ":".join(lst_valid) + else: + os.environ["LD_PRELOAD"] = "" + + launcher = _Launcher() + launcher.launch(args) + for x in sorted(set(os.environ.keys()) - env_before): + logger.debug("{x}={os.environ[x]}") + +if __name__ == "__main__": + parser = ArgumentParser(description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable " + "Processors with optimal configurations. Single instance inference, " + "multi-instance inference are enable. To get the peak performance on Intel(R) " + "Xeon(R) Scalable Processors, the script optimizes the configuration " + "of thread and memory management. For thread management, the script configures thread " + "affinity and the preload of Intel OMP library. For memory management, it configures " + "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) " + "\n################################# Basic usage ############################# \n" + "\n 1. single instance\n" + "\n >>> python -m torch.backends.xeon.run_cpu python_script args \n" + "\n2. multi-instance \n" + "\n >>> python -m torch.backends.xeon.run_cpu --ninstances xxx " + "--ncores_per_instance xx python_script args\n" + "\n############################################################################# \n", + formatter_class=RawTextHelpFormatter) + create_args(parser) + args = parser.parse_args() + main(args) diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index b1b8d35a511d7..94b2f64f78912 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -57,7 +57,7 @@ def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph): state (GraphExecutor or GraphExecutorState): GraphExecutor to display. name_prefix (str): Name prefix of the containing subgraph. pb_graph (GraphDef): graph to append to. - inline_graph (callable): a function that handles setting up a value_map, + inline_graph (Callable): a function that handles setting up a value_map, so that some graphs in here can be inlined. This is necessary, because this will simply be `visualize` for the top-level GraphExecutor, or `inline_graph` for all nested ones. diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index 674c1704daaf9..c481c27c92f73 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -94,10 +94,10 @@ PyTypeObject* loadTypedStorageTypeObject() { TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module)); PyObject* typed_storage_obj = - PyObject_GetAttrString(storage_module, "_TypedStorage"); + PyObject_GetAttrString(storage_module, "TypedStorage"); TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj)); return reinterpret_cast( - PyObject_GetAttrString(storage_module, "_TypedStorage")); + PyObject_GetAttrString(storage_module, "TypedStorage")); } PyTypeObject* getTypedStorageTypeObject() { @@ -125,7 +125,7 @@ at::Storage createStorageGetType( if (is_typed_storage) { // NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and // `_storage`, so we must decrement them. The refcounts will still stay - // nonzero since the `_TypedStorage` maintains a reference. + // nonzero since the `TypedStorage` maintains a reference. PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype"); TORCH_INTERNAL_ASSERT(dtype_obj); Py_DECREF(dtype_obj); diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 7ae3aa62f9b73..06541ad9e2607 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -15,6 +15,7 @@ #include #include #include +#include #if defined(USE_DISTRIBUTED) && defined(USE_C10D) #include diff --git a/torch/csrc/MemoryFormat.cpp b/torch/csrc/MemoryFormat.cpp index cb4a4e387feab..698eea7730be8 100644 --- a/torch/csrc/MemoryFormat.cpp +++ b/torch/csrc/MemoryFormat.cpp @@ -28,6 +28,17 @@ PyObject* THPMemoryFormat_repr(THPMemoryFormat* self) { return THPUtils_packString(self->name); } +PyObject* THPMemoryFormat_reduce(PyObject* _self, PyObject* noargs) { + auto* self = (THPMemoryFormat*)_self; + return THPUtils_packString(self->name); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) +static PyMethodDef THPMemoryFormat_methods[] = { + {"__reduce__", THPMemoryFormat_reduce, METH_NOARGS, nullptr}, + {nullptr} /* Sentinel */ +}; + PyTypeObject THPMemoryFormatType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch.memory_format", /* tp_name */ sizeof(THPMemoryFormat), /* tp_basicsize */ @@ -55,7 +66,7 @@ PyTypeObject THPMemoryFormatType = { 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ + THPMemoryFormat_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ac2f753d89ee9..18fb5b105ef39 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -15,12 +15,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -910,6 +912,18 @@ void initModule(PyObject* module); } // namespace torch #endif +#ifdef USE_ITT +namespace torch { +namespace profiler { +void initIttBindings(PyObject* module); +} // namespace profiler +} // namespace torch +#endif + +namespace torch { +void initVerboseBindings(PyObject* module); +} // namespace torch + static std::vector methods; // In Python we can't use the trick of C10_LOG_API_USAGE_ONCE @@ -1008,9 +1022,13 @@ PyObject* initModule() { torch::autograd::init_legacy_variable(module); torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); +#ifdef USE_ITT + torch::profiler::initIttBindings(module); +#endif #ifdef USE_CUDA torch::cuda::initModule(module); #endif + torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); #ifdef USE_CUDA @@ -1243,6 +1261,26 @@ Call this whenever a new thread is created in order to propagate values from return toString(x.key_set()); }); + py_module.def("_add_meta_to_tls_dispatch_include", []() { + auto local_keyset = c10::impl::tls_local_dispatch_key_set(); + c10::DispatchKeySet key_set({at::DispatchKey::Meta}); + local_keyset.included_ = local_keyset.included_ | key_set; + c10::impl::_force_tls_local_dispatch_key_set(local_keyset); + }); + py_module.def("_remove_meta_from_tls_dispatch_include", []() { + auto local_keyset = c10::impl::tls_local_dispatch_key_set(); + c10::DispatchKeySet key_set({at::DispatchKey::Meta}); + auto k = key_set.highestBackendKey(); + local_keyset.included_ = local_keyset.included_.remove_backend(k); + c10::impl::_force_tls_local_dispatch_key_set(local_keyset); + }); + + py_module.def("_dump_local_tls_set", []() { + auto local_keyset = c10::impl::tls_local_dispatch_key_set(); + std::cout << "Included: " << toString(local_keyset.included_) << "\n"; + std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n"; + }); + const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 966a82a65b317..36419f20eccd0 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -11,6 +12,7 @@ #include #include +#include struct THPSize { PyTupleObject tuple; @@ -57,7 +59,9 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) { TORCH_CHECK( !torch::jit::tracer::isTracing(), "JIT Tracing of SymInts isn't supported"); - auto py_symint = py::cast(si.toSymbolicIntNode()).release().ptr(); + auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + if (!py_symint) + throw python_error(); PyTuple_SET_ITEM(ret.get(), i, py_symint); } else { if (torch::jit::tracer::isTracing()) { @@ -67,7 +71,8 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) { throw python_error(); PyTuple_SET_ITEM(ret.get(), i, py_size_tensor); } else { - PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(si.data())); + PyTuple_SET_ITEM( + ret.get(), i, THPUtils_packInt64(si.as_int_unchecked())); } } } diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 4e1a74b262b6b..64d5cd785e52c 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -172,10 +172,9 @@ static PyObject* THPStorage_pynew( THPUtils_setError( THPStorageStr "(): tried to construct a storage from a sequence (%s), " - "but one of the items was of type %s instead of %s", + "but one of the items was of type %s instead of int", THPUtils_typename(sequence), - THPUtils_typename(item.get()), - THPUtils_typeTraits::python_type_str); + THPUtils_typename(item.get())); return nullptr; } return (PyObject*)self.release(); @@ -262,9 +261,8 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { HANDLE_TH_ERRORS if (!THPByteUtils_checkReal(value)) { THPUtils_setError( - "can only set storage content with a %s, but got " + "can only set storage content with a int types, but got " "%s instead", - THPUtils_typeTraits::python_type_str, THPUtils_typename(value)); return -1; } @@ -389,7 +387,7 @@ bool THPStorage_init(PyObject* module) { } void THPStorage_postInit(PyObject* module) { - THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage"); + THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage"); if (!THPStorageClass) throw python_error(); } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 593a96921baf1..827caea2a62f6 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -3,7 +3,7 @@ #include -#define THPStorageStr "torch._UntypedStorage" +#define THPStorageStr "torch.UntypedStorage" #define THPStorageBaseStr "StorageBase" struct THPStorage { diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index cf6fca98c213c..2b74c8a2fd290 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -138,7 +138,7 @@ static PyObject* THPStorage_resize_(PyObject* _self, PyObject* number_arg) { } else { TORCH_CHECK( false, - "_UntypedStorage.resize_: got unexpected device type ", + "UntypedStorage.resize_: got unexpected device type ", device_type); } Py_INCREF(self); @@ -151,9 +151,8 @@ static PyObject* THPStorage_fill_(PyObject* _self, PyObject* number_arg) { auto self = (THPStorage*)_self; THPUtils_assert( THPByteUtils_checkReal(number_arg), - "fill_ expects %s, " + "fill_ expects int, " "but got %s", - THPUtils_typeTraits::python_type_str, THPUtils_typename(number_arg)); storage_fill( at::unsafeStorageFromTH(self->cdata, /*retain=*/true), diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp index 7d0cd473af94d..1cd680c76de7a 100644 --- a/torch/csrc/Stream.cpp +++ b/torch/csrc/Stream.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -107,7 +108,7 @@ PyTypeObject THPStreamType = { void THPStream_init(PyObject* module) { THPStreamClass = &THPStreamType; - Py_TYPE(&THPStreamType) = &PyType_Type; + Py_SET_TYPE(&THPStreamType, &PyType_Type); if (PyType_Ready(&THPStreamType) < 0) { throw python_error(); } diff --git a/torch/csrc/THCGenerateByteType.h b/torch/csrc/THCGenerateByteType.h deleted file mode 100644 index 23648de8025ff..0000000000000 --- a/torch/csrc/THCGenerateByteType.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef THC_GENERIC_FILE -#error "You must define THC_GENERIC_FILE before including THCGenerateByteType.h" -#endif - -#define scalar_t uint8_t -#define accreal int64_t -#define Real Byte -#define CReal CudaByte -#define THC_REAL_IS_BYTE -#line 1 THC_GENERIC_FILE -#include THC_GENERIC_FILE -#undef scalar_t -#undef accreal -#undef Real -#undef CReal -#undef THC_REAL_IS_BYTE - -#ifndef THCGenerateAllTypes -#undef THC_GENERIC_FILE -#endif diff --git a/torch/csrc/THGenerateByteType.h b/torch/csrc/THGenerateByteType.h deleted file mode 100644 index aee357422bb89..0000000000000 --- a/torch/csrc/THGenerateByteType.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef TH_GENERIC_FILE -#error "You must define TH_GENERIC_FILE before including THGenerateByteType.h" -#endif - -#define scalar_t uint8_t -#define accreal int64_t -#define Real Byte -#define TH_REAL_IS_BYTE -#line 1 TH_GENERIC_FILE -#include TH_GENERIC_FILE -#undef scalar_t -#undef accreal -#undef Real -#undef TH_REAL_IS_BYTE - -#ifndef THGenerateManyTypes -#undef TH_GENERIC_FILE -#endif diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 20abf1572f283..3dd59c9f12f87 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -327,6 +327,25 @@ inline std::tuple qr_out( return torch::linalg_qr_out(Q, R, input, mode); } +inline std::tuple solve_ex( + const Tensor& input, + const Tensor& other, + bool left, + bool check_errors) { + return torch::linalg_solve_ex(input, other, left, check_errors); +} + +inline std::tuple solve_ex_out( + Tensor& result, + Tensor& info, + const Tensor& input, + const Tensor& other, + bool left, + bool check_errors) { + return torch::linalg_solve_ex_out( + result, info, input, other, left, check_errors); +} + inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { return torch::linalg_solve(input, other, left); } @@ -889,6 +908,27 @@ inline Tensor& ldl_solve_out( return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian); } +/// Solves a system linear system AX = B +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_ex +inline std::tuple solve_ex( + const Tensor& input, + const Tensor& other, + bool left, + bool check_errors) { + return detail::solve_ex(input, other, left, check_errors); +} + +inline std::tuple solve_ex_out( + Tensor& result, + Tensor& info, + const Tensor& input, + const Tensor& other, + bool left, + bool check_errors) { + return detail::solve_ex_out(result, info, input, other, left, check_errors); +} + /// Computes a tensor `x` such that `matmul(input, x) = other`. /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h index 6da380d490970..15902a026cf59 100644 --- a/torch/csrc/api/include/torch/python.h +++ b/torch/csrc/api/include/torch/python.h @@ -8,8 +8,12 @@ #include #include #include +#include +#include #include #include +#include +#include #include #include diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h index 6d3c3df33efbe..12e3439130af5 100644 --- a/torch/csrc/api/include/torch/special.h +++ b/torch/csrc/api/include/torch/special.h @@ -615,6 +615,25 @@ inline Tensor softmax( return torch::special_softmax(self, dim, dtype); } +/// Airy function Ai. +/// +/// See https://pytorch.org/docs/master/special.html#torch.special.airy_ai. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::airy_ai(x); +/// ``` +inline Tensor airy_ai(const Tensor& x) { + return torch::special_airy_ai(x); +} + +inline Tensor& airy_ai_out(Tensor& y, const Tensor& x) { + return torch::special_airy_ai_out(y, x); +} + /// Bessel function of the first kind of order 0. /// /// See https://pytorch.org/docs/master/special.html#torch.special.bessel_j0. @@ -1139,6 +1158,46 @@ inline Tensor& modified_bessel_k1_out(Tensor& result, const Tensor& self) { return torch::special_modified_bessel_k1_out(result, self); } +/// Scaled modified Bessel function of the second kind of order 0. +/// +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.scaled_modified_bessel_k0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::scaled_modified_bessel_k0(x); +/// ``` +inline Tensor scaled_modified_bessel_k0(const Tensor& x) { + return torch::special_scaled_modified_bessel_k0(x); +} + +inline Tensor& scaled_modified_bessel_k0_out(Tensor& y, const Tensor& x) { + return torch::special_scaled_modified_bessel_k0_out(y, x); +} + +/// Scaled modified Bessel function of the second kind of order 1. +/// +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.scaled_modified_bessel_k1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::scaled_modified_bessel_k1(x); +/// ``` +inline Tensor scaled_modified_bessel_k1(const Tensor& x) { + return torch::special_scaled_modified_bessel_k1(x); +} + +inline Tensor& scaled_modified_bessel_k1_out(Tensor& y, const Tensor& x) { + return torch::special_scaled_modified_bessel_k1_out(y, x); +} + /// Shifted Chebyshev polynomial of the first kind. /// /// See @@ -1323,5 +1382,24 @@ inline Tensor& shifted_chebyshev_polynomial_w_out( return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); } +/// Spherical Bessel function of the first kind of order 0. +/// +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.spherical_bessel_j0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::spherical_bessel_j0(x); +/// ``` +inline Tensor spherical_bessel_j0(const Tensor& x) { + return torch::special_spherical_bessel_j0(x); +} + +inline Tensor& spherical_bessel_j0_out(Tensor& y, const Tensor& x) { + return torch::special_spherical_bessel_j0_out(y, x); +} } // namespace special } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index 56ff49eff189c..2d8d4f4697d11 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -126,7 +126,7 @@ Tensor UnflattenImpl::forward(const Tensor& input) { } return input.unflatten(dimname, sizes, names); } - return input.unflatten(options.dim(), options.sizes(), torch::nullopt); + return input.unflatten(options.dim(), options.sizes()); } // ============================================================================ diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index d02b6cdd9cfff..749c09f5e5dbe 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ namespace details { using at::areAnyTensorSubclassLike; using at::IntArrayRef; +using at::OptionalIntArrayRef; using at::Scalar; using at::Tensor; using at::TensorList; @@ -187,6 +189,16 @@ Tensor scale_grad_by_count( return (grad / mask.sum(dims, true)) * mask; } +Tensor amaxamin_jvp( + const Tensor& x, + const Tensor& dx, + const Tensor& result, + IntArrayRef dim, + bool keepdim) { + auto mask = x == restore_reduced_dims(result, dim, keepdim); + return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim); +} + std::tuple _euclidean_dist_backward( const Tensor& grad, const Tensor& x1, @@ -235,11 +247,11 @@ Tensor norm_backward( } if (p == 0.0) { - return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + return {}; } else if (p == 1.0) { return self.sgn() * grad; } else if (p == 2.0) { - return self * (grad / norm).masked_fill_(norm == 0, 0); + return grad * (self / norm).masked_fill_(norm == 0, 0); } else if (std::isinf(p)) { // Derivative of amax(abs(self), dim, keepdim) but respecting nans // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's @@ -475,17 +487,13 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor& self, int64_t p) { return grad * args.digamma_().sum(-1); } -Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { - if (self.is_complex()) { - auto abs = at::abs(self); - // C -> C - // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 - return at::where( - abs == 0.0, - at::zeros({}, grad.options()), - (grad / abs - (at::real(grad / self) * result))); +Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) { + if (x.is_complex()) { + auto abs = x.abs(); + return ((gx - (sgn * sgn) * gx.conj()) / (2. * abs)) + .masked_fill_(abs == 0., 0.); } else { - return at::zeros_like(self, at::MemoryFormat::Preserve); + return at::_efficientzerotensor(sgn.sizes(), sgn.options()); } } @@ -550,8 +558,21 @@ Tensor deg2rad_backward(const Tensor& grad) { return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180))); } -Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); +Tensor unsqueeze_multiple( + const Tensor& t, + OptionalIntArrayRef opt_dim, + size_t n_dims) { + if (opt_dim.has_value()) { + IntArrayRef dim = opt_dim.value(); + auto dim_size = dim.size(); + // Optimisation for two common cases + if (dim_size == 0) { + return t; + } else if (dim_size == 1) { + return t.unsqueeze(dim[0]); + } + } + auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims); Tensor res = t; for (const auto i : c10::irange(n_dims)) { if (dims_to_unsqueeze[i]) { @@ -564,17 +585,28 @@ Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) { Tensor sum_backward( const Tensor& grad, IntArrayRef sizes, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim) { if (!keepdim && sizes.size() > 0) { - if (dims.size() == 1) { - return grad.unsqueeze(dims[0]).expand(sizes); - } else { - Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); - return res.expand(sizes); + if (opt_dims.has_value() && opt_dims.value().size() > 0) { + return unsqueeze_multiple(grad, opt_dims, sizes.size()).expand(sizes); } + } + return grad.expand(sizes); +} + +Tensor sum_backward( + const Tensor& grad, + c10::SymIntArrayRef sizes, + c10::SymIntArrayRef dims, + bool keepdim) { + if (!keepdim && sizes.size() > 0 && dims.size() > 0) { + // we are only using `keepdim=true` path for SymInts for now + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Only the keepdim=true path is implemented to support symints in autograd"); } else { - return grad.expand(sizes); + return grad.expand_symint(sizes); } } @@ -583,17 +615,19 @@ Tensor nansum_backward( const Tensor& self, IntArrayRef dims, bool keepdim) { - auto sizes = self.sizes(); - if (!keepdim && sizes.size() > 0) { - if (dims.size() == 1) { - return grad.unsqueeze(dims[0]).expand(sizes) * self.isnan().logical_not(); - } else { - Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); - return res.expand(sizes) * self.isnan().logical_not(); - } - } else { - return grad.expand(sizes) * self.isnan().logical_not(); - } + return sum_backward(grad, self.sizes(), dims, keepdim) * + self.isnan().logical_not(); +} + +Tensor mean_backward( + const Tensor& grad, + IntArrayRef shape, + OptionalIntArrayRef opt_dim, + int64_t numel, + bool keepdim) { + bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().size() == 0; + auto n = is_all_reduce ? numel : _safe_size(shape, opt_dim.value()); + return sum_backward(grad, shape, opt_dim, keepdim) / n; } std::vector reverse_list(const IntArrayRef list) { @@ -649,6 +683,10 @@ Tensor prod_backward( if (input.dim() == 0) { return grad; } + if (input.is_meta()) { + return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0) + .view_as(input); + } Tensor zero_idx = (input == 0).nonzero(); if (zero_idx.numel() == 0) { return grad * (result / input).conj(); @@ -674,6 +712,9 @@ Tensor prod_backward( grad = grad.unsqueeze(dim); result = result.unsqueeze(dim); } + if (input.is_meta()) { + return prod_safe_zeros_backward(grad, input, dim); + } Tensor zero_mask = (input == 0); Tensor slice_zero_count = zero_mask.sum(dim, true); @@ -1270,25 +1311,19 @@ Tensor sparse_sparse_matmul_backward( Tensor renorm_backward( const Tensor& grad, const Tensor& self, - const Scalar& p_s, + const Scalar& p, int64_t dim, const Scalar& maxnorm) { - auto self_sizes = self.sizes(); - dim = c10::maybe_wrap_dim(dim, self_sizes.size()); - at::DimVector reduce_dims(self_sizes.size()); + auto n = self.dim(); + dim = c10::maybe_wrap_dim(dim, n); + auto reduce_dims = at::DimVector(n); std::iota(reduce_dims.begin(), reduce_dims.end(), 0); reduce_dims.erase(reduce_dims.begin() + dim); - auto dtype = self.scalar_type(); - auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true); - const auto p = p_s.toDouble(); - - Tensor norm; - if (acc_type != dtype) { - norm = at::linalg_vector_norm( - self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type); - } else { - norm = at::linalg_vector_norm(self, p, reduce_dims, /*keepdim=*/true); - } + + auto acc_type = + at::toAccumulateType(self.scalar_type(), /*is_cuda=*/self.is_cuda()); + auto norm = at::linalg_vector_norm( + self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type); const auto real_acc_type = c10::toRealValueType(acc_type); auto grad_output = (self.conj() * grad); @@ -1331,47 +1366,32 @@ Tensor repeat_backward( // tensor. Then, sum up gradients over repeated tensors along 'dim', and // reduce shape from 'repeat * dimsize/repeat' to 'dimsize/repeat' // ('input_dimsize'). Example: - // Size(3, 2) Size(6, 2) - // [[v1_0, - // v1_1], - // [v1_2, - // v1_3], - // [[v0, v1], repeat(2, 1) [v1_4, - // v1_5], - // [v2, v3], -------------> [v2_0, - // v2_1], [v4, v5]] [v2_2, v2_3], - // [v2_4, - // v2_5]] + // Size(3, 2) Size(6, 2) + // [[v1_0, v1_1], + // [v1_2, v1_3], + // [[v0, v1], repeat(2, 1) [v1_4, v1_5], + // [v2, v3], -------------> [v2_0, v2_1], + // [v4, v5]] [v2_2, v2_3], + // [v2_4, v2_5]] + // + // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) + // [[[g1_0, g1_1], [[g1_0, g1_1], + // [g1_2, g1_3], [g1_2, g1_3], + // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], + // [g1_2+g2_2, g1_3+g2_3], [g2_0, g2_1], [[g2_0, g2_1], + // [g1_4+g2_4, g1_5+g2_5]] [g2_2, g2_3], [g2_2, g2_3], + // [g2_4, g2_5]] [g2_4, g2_5]]] // - // input grad (3, 2) reshape (2, 3, 2) output grad (6, - // 2) - // [[[g1_0, g1_1], [[g1_0, - // g1_1], - // [g1_2, g1_3], [g1_2, - // g1_3], - // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, - // g1_5], - // [g1_0+g2_0, g1_1+g2_1], [g2_0, - // g2_1], [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], - // [g2_2, g2_3], [g2_4, - // g2_5]] [g2_4, g2_5]]] // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and // then sum over 'dim+1'. The gradient for input is not correctly aligned // with input. Example: - // input grad (3, 2) reshape (3, 2, 2) output grad (6, - // 2) - // [[[g1_0, g1_1], - // [g1_2, g1_3]], [[g1_0, - // g1_1], - // [g1_2, - // g1_3], - // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, - // g1_5], - // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, - // g2_1], [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], - // [[g2_2, g2_3], [g2_4, - // g2_5]] - // [g2_4, g2_5]]] + // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) + // [[[g1_0, g1_1], [[g1_0, g1_1], + // [g1_2, g1_3]], [g1_2, g1_3], + // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], + // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], + // [g2_2+g2_4, g2_3+g2_5]] [[g2_2, g2_3], [g2_2, g2_3], + // [g2_4, g2_5]]] [g2_4, g2_5]] if (repeat != 1) { grad_size.push_back(repeat); sum_dims.push_back(grad_size.size() - 1); @@ -1460,14 +1480,6 @@ Tensor evenly_read_jvp( return at::sum(mask * grad_output); } -static Tensor var_backward( - const Tensor& grad, - const Tensor& self, - int64_t correction) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - return (2.0 / (self.numel() - correction)) * grad * (self - self.mean()); -} - Tensor var_backward( Tensor grad, const Tensor& self, @@ -1476,7 +1488,13 @@ Tensor var_backward( bool keepdim) { auto correction = correction_opt.value_or(1); if (self.dim() == 0 || !dim_opt.has_value()) { - return var_backward(grad, self, correction); + // To apease ASAN + auto n = self.numel(); + if (n == correction) { + return INFINITY * grad; + } else { + return (2.0 / (self.numel() - correction)) * grad * (self - self.mean()); + } } auto dim = dim_opt.value(); if (!keepdim && self.dim() > 1) { @@ -1487,28 +1505,6 @@ Tensor var_backward( return (2.0 / dof) * grad * (self - self.mean(dim, /*keepdim=*/true)); } -Tensor var_jvp( - const Tensor& self_t, - const Tensor& self_p, - const Tensor& result, - at::OptionalIntArrayRef dim_opt, - c10::optional correction_opt, - bool keepdim) { - auto correction = correction_opt.value_or(1); - if (self_p.dim() == 0 || !dim_opt.has_value()) { - return var_backward(self_t.conj(), self_p, correction) - .sum() - .expand_as(result) - .conj(); - } - auto dim = dim_opt.value(); - const int64_t dof = _safe_size(self_p.sizes(), dim) - correction; - return ((2.0 / dof) * self_t.conj() * - (self_p - self_p.mean(dim, /*keepdim=*/true))) - .sum(dim, keepdim) - .conj(); -} - Tensor std_backward( const Tensor& result, const Tensor& grad, @@ -1520,51 +1516,53 @@ Tensor std_backward( return var_backward(grad_var, self, dim, correction, keepdim); } -Tensor mean_backward( - Tensor grad, - const IntArrayRef sizes, - IntArrayRef dim, - bool keepdim) { - return sum_backward(grad, sizes, dim, keepdim) / _safe_size(sizes, dim); -} - -Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int64_t numel) { - return grad.expand(sizes) / numel; -} - -static Tensor mean_backward( - const Tensor& grad, - const IntArrayRef sizes, - int64_t numel, - at::OptionalIntArrayRef dim, +Tensor var_mean_backward( + const Tensor& gvar, + const Tensor& gmean, + const Tensor& self, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, bool keepdim) { - if (dim.has_value()) { - return mean_backward(grad, sizes, *dim, keepdim); - } else { - return mean_backward(grad, sizes, numel); + auto correction = correction_opt.value_or(1); + Tensor gself; + if (gvar.defined()) { + gself = var_backward(gvar, self, dim_opt, correction, keepdim); } + if (gmean.defined()) { + auto aux = mean_backward( + gmean, + self.sizes(), + dim_opt.value_or(IntArrayRef({})), + self.numel(), + keepdim); + gself = gself.defined() ? gself + aux : aux; + } + return gself; } -Tensor var_std_mean_backward( - const variable_list& grads, +Tensor std_mean_backward( + const Tensor& gstd, + const Tensor& gmean, const Tensor& self, - const Tensor& r1, - const Tensor& r2, - at::OptionalIntArrayRef dim, - c10::optional correction, - bool keepdim, - bool is_std) { - Tensor grad; - if (grads[0].defined()) { - grad = is_std ? std_backward(r1, grads[0], self, dim, correction, keepdim) - : var_backward(grads[0], self, dim, correction, keepdim); - } - if (grads[1].defined()) { - Tensor mean_grad = - mean_backward(grads[1], self.sizes(), self.numel(), dim, keepdim); - grad = grad.defined() ? grad + mean_grad : mean_grad; + const Tensor& std, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, + bool keepdim) { + auto correction = correction_opt.value_or(1); + Tensor gself; + if (gstd.defined()) { + gself = std_backward(std, gstd, self, dim_opt, correction, keepdim); } - return grad; + if (gmean.defined()) { + auto aux = mean_backward( + gmean, + self.sizes(), + dim_opt.value_or(IntArrayRef({})), + self.numel(), + keepdim); + gself = gself.defined() ? gself + aux : aux; + } + return gself; } Tensor masked_scatter_backward( @@ -1896,59 +1894,6 @@ Tensor infinitely_differentiable_logit_backward( } } -Tensor kl_div_double_backward_grad_output( - const Tensor& grad, - const Tensor& input, - const Tensor& target, - int64_t reduction, - bool log_target) { - auto result = - kl_div_backward(grad, input, target, at::Reduction::None, log_target); - if (reduction == at::Reduction::Mean) { - return result.mean(); - } else if (reduction == at::Reduction::Sum) { - return result.sum(); - } - return result; -} - -// Compute derivatives for targets. -Tensor kl_div_target_backward( - Tensor grad_output, - Tensor self, - Tensor target, - int64_t reduction, - bool log_target) { - Tensor grad_target; - if (!log_target) { - if (!areAnyTensorSubclassLike({self, target}) && - !grad_output._is_zerotensor()) { - grad_target = grad_output.mul(target.log().add_(1).sub_(self)) - .masked_fill_(target == 0, 0.); - } else { - grad_target = grad_output.mul(target.log().add(1).sub(self)) - .masked_fill(target == 0, 0.); - } - } else { - if (!areAnyTensorSubclassLike({self, target})) { - grad_target = - grad_output.mul(target.add(1).sub_(self).mul_(target.exp())); - } else { - grad_target = grad_output.mul(target.add(1).sub(self).mul_(target.exp())); - } - } - - if (reduction == at::Reduction::Mean) { - if (!grad_target._is_zerotensor()) { - grad_target.div_(target.numel()); - } else { - grad_target.div(target.numel()); - } - } - - return grad_target; -} - Tensor binary_cross_entropy_target_backward( const Tensor& grad, const Tensor& self, @@ -2011,6 +1956,57 @@ Tensor binary_cross_entropy_double_backward_target( return res; } +Tensor binary_cross_entropy_with_logits_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + const c10::optional& weight, + const c10::optional& pos_weight, + int64_t reduction) { + // Trivial case + if (grad._is_zerotensor()) { + return at::_efficientzerotensor(input.sizes(), input.options()); + } + + // -w * [ pos * y * (1 -sigmoid(x)) - (1 - y) sigmoid(x)] * grad + + // If there are subclassed tensors use the out of place version + Tensor grad_input; + if (isDefined(pos_weight)) { + // pos_weight might need to be broadcasted, thus mul(target) is not inplace. + auto t = pos_weight->mul(target); + grad_input = at::areAnyTensorSubclassLike({input, target}) || + at::GradMode::is_enabled() + ? t.add(1).sub(target).mul(input.sigmoid()).sub(t) + : t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t); + } else { + grad_input = at::areAnyTensorSubclassLike({input, target}) || + at::GradMode::is_enabled() + ? input.sigmoid().sub(target) + : input.sigmoid().sub_(target); + } + + if (at::isTensorSubclassLike(grad) || at::GradMode::is_enabled()) { + grad_input = grad_input.mul(grad); + } else { + grad_input.mul_(grad); + } + + if (isDefined(weight)) { + if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) { + grad_input = grad_input.mul(*weight); + } else { + grad_input.mul_(*weight); + } + } + + if (reduction == at::Reduction::Mean) { + grad_input.div_(input.numel()); + } + + return grad_input; +} + Tensor binary_cross_entropy_with_logits_target_backward( const Tensor& grad_output, const Tensor& self, @@ -2018,28 +2014,30 @@ Tensor binary_cross_entropy_with_logits_target_backward( const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction) { + if (grad_output._is_zerotensor()) { + return at::_efficientzerotensor(target.sizes(), target.options()); + } + Tensor grad_target; if (isDefined(pos_weight)) { - if (!areAnyTensorSubclassLike({*pos_weight, grad_output})) { - grad_target = (1. - self.sigmoid()) - .log_() - .sub_(pos_weight->mul(self.sigmoid().log_())) - .mul_(grad_output); - } else { - grad_target = (1. - self.sigmoid()) - .log_() - .sub(pos_weight->mul(self.sigmoid().log_())) + if (areAnyTensorSubclassLike({*pos_weight, grad_output})) { + grad_target = at::log_sigmoid(-self) + .sub(at::log_sigmoid(self).mul(*pos_weight)) .mul(grad_output); + } else { + grad_target = at::log_sigmoid(-self) + .sub_(at::log_sigmoid(self).mul_(*pos_weight)) + .mul_(grad_output); } } else { - grad_target = self.mul(-grad_output); + grad_target = -self * grad_output; } if (isDefined(weight)) { - if (!isTensorSubclassLike(*weight)) { - grad_target.mul_(*weight); - } else { + if (at::isTensorSubclassLike(*weight)) { grad_target = grad_target.mul(*weight); + } else { + grad_target.mul_(*weight); } } @@ -3991,10 +3989,22 @@ Tensor differential_analytic_matrix_function( meta_grad_sizes[A.dim() - 1] *= 2; auto n = A.size(-1); - auto meta_grad = at::zeros(meta_grad_sizes, grad.options()); - meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(A); - meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(A); - meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad); + Tensor meta_grad; + // For Composite Compliance, we can't copy a Subclass into a Regular Tensor, + // so we use out-of-place ops with equivalent output. + // NOTE: We can't use `new_zeros` directly as both `A` and `grad` can + // be Tensor Subclass and we don't want to make assumption about which + // one to choose for creating output buffer. + // eg. if both are BatchedTensor at different level. + if (areAnyTensorSubclassLike({A, grad})) { + meta_grad = at::cat( + {at::cat({A, grad}, -1), at::cat({at::zeros_like(A), A}, -1)}, -2); + } else { + meta_grad = at::zeros(meta_grad_sizes, grad.options()); + meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(A); + meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(A); + meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad); + } return matrix_function(meta_grad).narrow(-2, 0, n).narrow(-1, n, n); } @@ -4009,277 +4019,146 @@ Tensor linalg_matrix_exp_differential( self, grad, at::linalg_matrix_exp, /* adjoint */ adjoint); } -Tensor det_backward(const Tensor& grad, const Tensor& self, const Tensor& det) { - if (self.numel() == 0) { - return at::empty_like(self); - } - - auto det_backward_nonsingular = - [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { - // Derived from Jacobi's formula for partial derivative, which can be found - // at https://en.wikipedia.org/wiki/Jacobi%27s_formula - // i.e. if A is the input matrix, then - // A_grad = A^{-H} (grad * det.conj()) I, where - // A^{-H} = (A^{-1}).T.conj() - - // create a matrix d := (grad * det.conj()) I - auto d = at::zeros_like(self); - d.diagonal(0, -2, -1).copy_((grad * det.conj()).unsqueeze(-1)); - return at::linalg_solve(self.mH(), d); - }; - - auto det_backward_singular = - [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { - // Derived from the gradient formula that would be used if `self`'s - // determinant is calculated using SVD, like so: - // u, s, vh = svd(self) - // det(self) = det(u) * prod(s) * det(vh) - // - // This formula should be correct even if `self` is nonsingular. - Tensor u, s, vh; - std::tie(u, s, vh) = at::linalg_svd(self); - auto u_det = at::linalg_det(u); - auto s_prod = s.prod(-1); - auto vh_det = at::linalg_det(vh); - - auto u_det_grad = grad * (vh_det * s_prod).conj(); - auto u_grad = det_backward_nonsingular(u_det_grad, u, u_det); - - auto s_prod_grad = - handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj()); - auto s_grad = prod_backward(s_prod_grad, s, s_prod, -1, false); - - auto vh_det_grad = grad * (u_det * s_prod).conj(); - auto vh_grad = det_backward_nonsingular(vh_det_grad, vh, vh_det); - - return svd_backward(u_grad, s_grad, vh_grad, u, s, vh); - }; - - auto eps = at::native::_get_epsilon(c10::toRealValueType(self.scalar_type())); - auto singular_det_cutoff = eps * at::linalg_matrix_norm(self); - - if (self.dim() == 2) { - if (det.abs().lt(singular_det_cutoff).item()) { - return det_backward_singular(grad, self, det); - } else { - return det_backward_nonsingular(grad, self, det); - } - } else { - auto nonzero_det_mask = det.abs().ge(singular_det_cutoff); - if (nonzero_det_mask.all().item()) { - return det_backward_nonsingular(grad, self, det); - } - - auto zero_det_mask = nonzero_det_mask.logical_not(); - if (zero_det_mask.all().item()) { - return det_backward_singular(grad, self, det); - } - - Tensor self_grad = self.new_empty(self.sizes(), self.options()); - - auto nonzero_det_list = - at::native::toListOfOptionalTensors(nonzero_det_mask); - self_grad.index_put_( - /*indices=*/nonzero_det_list, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - det_backward_nonsingular( - grad.index(nonzero_det_list), - self.index(nonzero_det_list), - det.index(nonzero_det_list))); - - auto zero_det_list = at::native::toListOfOptionalTensors(zero_det_mask); - self_grad.index_put_( - /*indices=*/zero_det_list, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - det_backward_singular( - grad.index(zero_det_list), - self.index(zero_det_list), - det.index(zero_det_list))); - - return self_grad; - } -} - -// The backward for this function is just a specialized version of -// lu.backward, which is implemented in /torch/_autograd_functions.py -Tensor _det_lu_based_helper_backward( - const Tensor& det_grad, +Tensor linalg_det_backward( + const Tensor& grad, const Tensor& det, - const Tensor& self, - const Tensor& lu, - const Tensor& pivs) { - if (!self.numel()) { - return at::zeros_like(self, at::MemoryFormat::Contiguous); - } - if (!det_grad.defined()) { - return Tensor(); + const Tensor& A, + const Tensor& LU, + const Tensor& pivots) { + at::NoTF32Guard disable_tf32; + if (!grad.defined()) { + return {}; } - // run det_backward only if backward is run on _det_lu_based_helper_backward. - // _det_lu_based_helper_backward is more stable for forward det computing - // functions, but it fails with double backward gradient checks - // (gradgradcheck). det_backward, on the other hand, is less stable (due to - // restrictions on svd_backward, namely, svd_backward requries distinct - // singular values which are sufficiently different from each other), yet, if - // its computation is stable, so is its double backward. Hence, if only single - // backward is run, we use _det_lu_based_helper_backward, for the double - // backward case we use det_backward. The latter approach could produce - // unstable gradients, therefore we DO NOT recommend double backpropagation - // through det computing functions. - if (at::GradMode::is_enabled()) { - return det_backward(det_grad, self, det); + // The gradient G is the matrix solving + // A.mH G = det(A).conj() * grad * I + auto d_diag = grad * det.conj(); + // Optimisation, Make it F-transposed as it's what lu_solve expects + auto d = at::diag_embed(d_diag.unsqueeze(-1).expand_as(pivots)).mT(); + + if (!at::GradMode::is_enabled()) { + // The formula is given by the solution of AX = det.conj() * det * I when A + // is invertible det is C^1, so if it's not invertible, we can wiggle the LU + // decomposition a bit and use the resulting matrix as a decent + // approximation + auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type())); + auto LU_ = + LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.)); + auto use_A_T = A.is_contiguous() && !A.is_complex(); + return at::linalg_lu_solve( + LU_, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T); + } else { + // If we want to compute further gradients, we need to recompute the LU + // decomposition so that autograd computes the correct gradients wrt to A + // (cf. solve_backward) + + // TODO When the user wants higher derivatives, the trick above just does + // not cut it The proper way of doing this is doing `auto mask = det == 0.;` + // and then if any determinant is zero, use an SVD decomposition to compute + // the derivative in those inputs (not all inputs). The derivative may be + // then computed explicitly by noting that the gradient of the derivative of + // the determinant is given in terms of the adjugate of a matrix. Then, the + // adjugate of a singular matrix may be computed as per + // https://nhigham.com/2020/06/16/what-is-the-adjugate-of-a-matrix/ + // The code may be implemented as follows: + // + // Tensor U, S, Vh; + // std::tie(U, S, Vh) = at::linalg_svd(A); + // auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad; + // auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1); + // return (U * D.unsqueeze(-2)).matmul(Vh); + // + // The issue with this code is that the derivative given by autograd of + // prod_safe_zeros_backward is not the second derivative of the product. + // It is not clear to me how to implement the second derivative of the + // product efficently. Note that this is also currently a problem when we + // compute higher derivatives of `x.prod()` and `x` has more than one zero. + return at::linalg_solve(A.mH(), d); } - - // we use a sequence of kernels to avoid memory copies and checks, - // as in the implementation of this function we are 100% sure that - // `lu` and `pivs` are coming from a LAPACK routine. - return at::_det_lu_based_helper_backward_helper( - det_grad, det, self, lu, pivs); } -Tensor logdet_backward( - const Tensor& grad, - const Tensor& self, - const Tensor& logdet) { - auto singular_case_backward = [&](const Tensor& grad, - const Tensor& self) -> Tensor { - Tensor u, sigma, vh; - std::tie(u, sigma, vh) = at::linalg_svd(self, false); - // logdet = \sum log(sigma) - auto gsigma = grad.unsqueeze(-1).div(sigma); - return svd_backward({}, gsigma, {}, u, sigma, vh); - }; - - auto nonsingular_case_backward = [&](const Tensor& grad, - const Tensor& self) -> Tensor { - return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().mT(); - }; - - if (self.dim() == 2) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - if (logdet.item() != -INFINITY) { - return nonsingular_case_backward(grad, self); - } else { - return singular_case_backward(grad, self); - } +std::tuple slogdet_jvp( + const Tensor& LU, + const Tensor& pivots, + const Tensor& dA, + const Tensor& sign, + const bool use_A_T) { + // No need to handle the singular case separately as we do in det since + // this function is not differentiable on singular matrices + auto trAinvE = at::linalg_lu_solve(LU, pivots, dA, /*left*/ true, use_A_T) + .diagonal(0, -2, -1) + .sum(-1); + if (LU.is_complex()) { + auto i = c10::complex{0.0, 1.0}; + return std::make_tuple(at::imag(trAinvE) * (i * sign), at::real(trAinvE)); } else { - auto finite_logdet_indices = - at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); - c10::optional first_finite_logdet_index = finite_logdet_indices[0]; - - if (first_finite_logdet_index->size(0) == - logdet.numel()) { // all log determinants are finite (non-singular) - return nonsingular_case_backward(grad, self); - } - - auto neginf_logdet_indices = - at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); - c10::optional first_neginf_logdet_index = neginf_logdet_indices[0]; - - if (first_neginf_logdet_index->size(0) == - logdet.numel()) { // all log determinants are -inf (singular) - return singular_case_backward(grad, self); - } - - Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - // invertible case - grad_logdet.index_put_( - /*indices=*/finite_logdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - nonsingular_case_backward( - grad.index(finite_logdet_indices), - self.index(finite_logdet_indices))); - - // non-invertible case, uses SVD - grad_logdet.index_put_( - /*indices=*/neginf_logdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - singular_case_backward( - grad.index(neginf_logdet_indices), - self.index(neginf_logdet_indices))); - - return grad_logdet; + return std::make_tuple( + at::_efficientzerotensor(sign.sizes(), sign.options()), trAinvE); } } Tensor slogdet_backward( + const Tensor& grad_sign, const Tensor& grad_logabsdet, - const Tensor& self, + const Tensor& A, const Tensor& signdet, - const Tensor& logabsdet) { - auto singular_case_backward = [&](const Tensor& grad_logabsdet, - const Tensor& self) -> Tensor { - Tensor u, sigma, vh; - std::tie(u, sigma, vh) = at::linalg_svd(self, false); - Tensor v = vh.mH(); - // sigma has all non-negative entries (also with at least one zero entry) - // so logabsdet = \sum log(abs(sigma)) - // but det = 0, so backward logabsdet = \sum log(sigma) - auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma); - return svd_backward({}, gsigma, {}, u, sigma, vh); - }; - - auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, - const Tensor& self) -> Tensor { - // TODO: replace self.inverse with linalg_inverse - return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * - self.inverse().mH(); - }; + const Tensor& LU, + const Tensor& pivots) { + // We compute the complex case, as the real case follows from it + // Forward AD + // d (logabsdet)_A(E) = Re(tr(A^{-1}E)) + // d (signdet)_A(E) = sgn * Im(tr(A^{-1}E)) * i + // So + // d (logabsdet)*_A(g) = gA^{-H} + // Now, to compute the adjoint of d(signdet), note that + // Re(z * Im(w)) = Re(-Re(z)iw) + // So, let g \in C, + // = Re(g.conj() * sgn * i * Im(A^{-1}E)) + // = Re(Re(g.conj() * sgn * i) * -i * A^{-1}E) + // = Re(Im(g.conj() * sgn) * i * A^{-1}E) + // = + // As such, + // (d slogabs)*_A(g_sign, g_abs) = (g_abs - g_sign.conj() * sgn) * A^{-H} + + if (!grad_sign.defined() && !grad_logabsdet.defined()) { + return {}; + } - if (self.dim() == 2) { - bool is_singular = self.is_complex() ? signdet.abs().item() == 0 - : signdet.item() == 0; - if (is_singular) { - return singular_case_backward(grad_logabsdet, self); - } else { - return nonsingular_case_backward(grad_logabsdet, self); - } - } else { - auto nonzero_signdet_indices = at::native::toListOfOptionalTensors( - self.is_complex() ? at::where(signdet.abs()) : at::where(signdet)); - c10::optional first_nonzero_signdet_index = - nonzero_signdet_indices[0]; - - if (first_nonzero_signdet_index->size(0) == - logabsdet.numel()) { // all log determinants are finite (non-singular) - return nonsingular_case_backward(grad_logabsdet, self); - } + auto is_complex = A.is_complex(); - auto zero_signdet_indices = - at::native::toListOfOptionalTensors(at::where(signdet == 0)); - c10::optional first_zero_signdet_index = zero_signdet_indices[0]; + // In the real case grad_sign is always zero + if (!is_complex && !grad_logabsdet.defined()) { + return {}; + } - if (first_zero_signdet_index->size(0) == - logabsdet.numel()) { // all log determinants are -inf (singular) - return singular_case_backward(grad_logabsdet, self); + auto g = grad_logabsdet; + if (is_complex) { + if (grad_sign.defined()) { + auto i = c10::complex{0.0, 1.0}; + if (g.defined()) { + g = g - i * at::imag(grad_sign.conj() * signdet); + } else { + g = -i * at::imag(grad_sign.conj() * signdet); + } + } else { + // Cast to complex explicitly + g = g.to(A.scalar_type()); } + } - Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - // invertible case - grad_slogdet.index_put_( - /*indices=*/nonzero_signdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - nonsingular_case_backward( - grad_logabsdet.index(nonzero_signdet_indices), - self.index(nonzero_signdet_indices))); - - // non-invertible case, uses SVD - grad_slogdet.index_put_( - /*indices=*/zero_signdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/ - singular_case_backward( - grad_logabsdet.index(zero_signdet_indices), - self.index(zero_signdet_indices))); - - return grad_slogdet; + // No need to handle the singular case separately here (as we do in det) + // since this function is not differentiable on singular matrices + // Optimisation, Make it F-transposed as it's what lu_solve expects + auto d = at::diag_embed(g.unsqueeze(-1).expand_as(pivots)).mT(); + if (!at::GradMode::is_enabled()) { + auto use_A_T = A.is_contiguous() && !A.is_complex(); + return at::linalg_lu_solve( + LU, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T); + } else { + // If we want to compute further gradients, we need to recompute the LU + // decomposition so that autograd computes the correct gradients wrt to A + // (cf. solve_backward) + return at::linalg_solve(A.mH(), d); } } @@ -4908,27 +4787,23 @@ infinitely_differentiable_native_group_norm_backward( c = -b * mean_tensor - c * rstd_tensor * s; dX = a * dY_tensor + b * X_tensor + c; if (dmean.defined() && drstd.defined()) { - dX += var_std_mean_backward( - {dvar, dmean.view({N, G, 1, 1})}, + dX += var_mean_backward( + dvar, + dmean.view({N, G, 1, 1}), X_tensor, - var, - mean_tensor, IntArrayRef{2, 3}, 0, - true, - false); + true); } dX = dX.reshape_as(X); } else if (dmean.defined() && drstd.defined()) { - dX = var_std_mean_backward( - {dvar, dmean.view({N, G, 1, 1})}, + dX = var_mean_backward( + dvar, + dmean.view({N, G, 1, 1}), X_tensor, - var, - mean_tensor, IntArrayRef{2, 3}, 0, - true, - false) + true) .reshape_as(X); } } @@ -6696,6 +6571,21 @@ std::tuple index_reduce_backward( return std::make_tuple(grad_self, grad_src); } +Tensor take_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& indices) { + Tensor grad_self = at::zeros_like(self); + // For Composite Compliance, + // if `grad` and `indices` are CCT but `self` is not + // then we use the out-of-place variant of `put`. + if (!isTensorSubclassLike(self) && + areAnyTensorSubclassLike({grad, indices})) { + return grad_self.put(indices, grad, true); + } + return grad_self.put_(indices, grad, true); +} + } // namespace details } // namespace generated } // namespace autograd diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 9283a7375f93a..376490fb6f45e 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -150,12 +150,17 @@ at::Tensor rad2deg_backward(const at::Tensor& grad); at::Tensor deg2rad_backward(const at::Tensor& grad); at::Tensor unsqueeze_multiple( const at::Tensor& t, - at::IntArrayRef dim, + at::OptionalIntArrayRef opt_dim, size_t n_dims); at::Tensor sum_backward( const at::Tensor& grad, at::IntArrayRef sizes, - at::IntArrayRef dims, + at::OptionalIntArrayRef opt_dims, + bool keepdim); +at::Tensor sum_backward( + const at::Tensor& grad, + c10::SymIntArrayRef sizes, + c10::SymIntArrayRef dims, bool keepdim); at::Tensor nansum_backward( const at::Tensor& grad, @@ -301,7 +306,7 @@ at::Tensor evenly_distribute_backward( at::Tensor grad, const at::Tensor& input, const at::Tensor& value); -at::Tensor sgn_backward(Tensor result, Tensor grad, Tensor self); +Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn); at::Tensor var_backward( at::Tensor grad, const at::Tensor& self, @@ -322,24 +327,27 @@ at::Tensor std_backward( at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim); -at::Tensor mean_backward( - at::Tensor grad, - const at::IntArrayRef sizes, - at::IntArrayRef dim, +Tensor mean_backward( + const Tensor& grad, + IntArrayRef shape, + at::OptionalIntArrayRef opt_dim, + int64_t numel, + bool keepdim); +Tensor var_mean_backward( + const Tensor& gvar, + const Tensor& gmean, + const Tensor& self, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, + bool keepdim); +Tensor std_mean_backward( + const Tensor& gstd, + const Tensor& gmean, + const Tensor& self, + const Tensor& std, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, bool keepdim); -at::Tensor mean_backward( - at::Tensor grad, - const at::IntArrayRef sizes, - int64_t numel); -at::Tensor var_std_mean_backward( - const variable_list& grads, - const at::Tensor& self, - const at::Tensor& r1, - const at::Tensor& r2, - at::OptionalIntArrayRef dim, - c10::optional correction, - bool keepdim, - bool is_std); at::Tensor masked_scatter_backward( const at::Tensor& grad, const at::Tensor& mask, @@ -399,12 +407,6 @@ Tensor infinitely_differentiable_logit_backward( const Tensor& grad, const Tensor& self, c10::optional eps); -at::Tensor kl_div_double_backward_grad_output( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& target, - int64_t reduction, - bool log_target); Tensor binary_cross_entropy_target_backward( const Tensor& grad, const Tensor& self, @@ -418,6 +420,13 @@ Tensor binary_cross_entropy_double_backward_target( const Tensor& target, const c10::optional& weight, int64_t reduction); +Tensor binary_cross_entropy_with_logits_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + const c10::optional& pos_weight_opt, + int64_t reduction); at::Tensor binary_cross_entropy_with_logits_target_backward( const at::Tensor& grad_output, const at::Tensor& self, @@ -485,15 +494,19 @@ at::Tensor softplus_double_backward( const at::Tensor& input, const at::Scalar& beta, const at::Scalar& threshold); -at::Tensor logdet_backward( - const at::Tensor& grad, - const at::Tensor& self, - const at::Tensor& logdet); +std::tuple slogdet_jvp( + const at::Tensor& LU, + const at::Tensor& pivots, + const at::Tensor& dA, + const at::Tensor& sign, + const bool use_A_T); at::Tensor slogdet_backward( + const at::Tensor& grad_sign, const at::Tensor& grad_logabsdet, - const at::Tensor& self, + const at::Tensor& A, const at::Tensor& signdet, - const at::Tensor& logabsdet); + const at::Tensor& LU, + const at::Tensor& pivots); at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward( @@ -624,10 +637,6 @@ Tensor linalg_matrix_exp_differential( const Tensor& self, const Tensor& grad, bool adjoint); -Tensor linalg_det_backward( - const Tensor& grad, - const Tensor& self, - const Tensor& det); std::tuple batchnorm_double_backward( const Tensor& input, const c10::optional& gamma, @@ -647,12 +656,6 @@ std::tuple _euclidean_dist_backward( const Tensor& x1, const Tensor& x2, const Tensor& res); -Tensor kl_div_target_backward( - Tensor grad_output, - Tensor self, - Tensor target, - int64_t reduction, - bool log_target); Tensor fft_backward( const Tensor& self, const Tensor& grad, @@ -748,6 +751,12 @@ std::tuple atan2_backward( const Tensor& self, const Tensor& other, std::array output_mask); +Tensor amaxamin_jvp( + const Tensor& x, + const Tensor& dx, + const Tensor& result, + IntArrayRef dim, + bool keepdim); std::tuple layer_norm_double_backward( const Tensor& input, const c10::optional& gamma, @@ -819,13 +828,12 @@ Tensor lu_unpack_backward( const int64_t m, const int64_t n); -Tensor _det_lu_based_helper_backward( - const Tensor& det_grad, +Tensor linalg_det_backward( + const Tensor& grad, const Tensor& det, - const Tensor& self, - const Tensor& lu, - const Tensor& pivs); - + const Tensor& A, + const Tensor& LU, + const Tensor& pivots); std::tuple linalg_lstsq_backward( const Tensor& grad, const Tensor& A, @@ -993,6 +1001,11 @@ std::tuple index_reduce_backward( bool include_self, const Tensor& result); +Tensor take_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& indices); + } // namespace details } // namespace generated } // namespace autograd diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index e3c0db4eb6a8f..db00d67576d3b 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -83,14 +83,35 @@ using at::Tensor; // the base if needed. namespace { -// Check if two Tensor have the same storage offset, sizes and strides + +// Enforcing that the metadata between the primal and tangent are same has two +// goals: +// - When properties of the primal are checked in composite op's to determine +// control flow, the code path decided upon is also reasonable for the tangent +// - Make sure that when the same as_strided is applied to both primal and +// and tangent, it behaves similarly. +// +// We do that by checking: +// 1) the storages have same properties: size and conj/neg-ness +// 2) the same indices refer to the same elements in storage +// (we are more strict than necessary here to satisfy the goal 1) bool has_same_meta(const Variable& base, const Variable& other) { if (!base.defined() || !other.defined()) { return false; } - if (base.storage_offset() != other.storage_offset()) { + // 1) The storages have the same properties + if (!at::_has_same_storage_numel(base, other)) { + return false; + } + if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) { return false; } + + // Technically dim and size belong as part of (2), so we shouldn't really care + // if a zero-numel tensor violates these. But since these properties + // (unlike offset and strides) often determine control flow in composite ops + // it is useful to enforce that they match for primal and tangent here so + // nothing funny happens later (See goal 1). if (base.dim() != other.dim()) { return false; } @@ -98,17 +119,24 @@ bool has_same_meta(const Variable& base, const Variable& other) { if (base.sizes()[i] != other.sizes()[i]) { return false; } + } + + // The check below will always be vacuously true for 0-element tensors + if (base.numel() == 0 && other.numel() == 0) { + return true; + } + + // 2) The same indices refer to the same elements in storage + if (base.storage_offset() != other.storage_offset()) { + return false; + } + + for (const auto i : c10::irange(base.dim())) { if (base.strides()[i] != other.strides()[i] && base.sizes()[i] != 1 && base.sizes()[i] != 0) { return false; } } - if (!at::_has_same_storage_numel(base, other)) { - return false; - } - if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) { - return false; - } return true; } } // anonymous namespace diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index 65a47766acb78..92d234e7fb287 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -53,44 +53,40 @@ void autogradNotImplementedFallbackImpl( // See gen_variable_type.py const auto& schema = op.schema(); const auto& op_name = schema.operator_name().name; - const auto& arguments = schema.arguments(); - const auto& returns = schema.returns(); - const auto num_arguments = arguments.size(); - const auto num_returns = returns.size(); + const auto num_arguments = schema.arguments().size(); + const auto num_returns = schema.returns().size(); const auto stack_start = stack->size() - num_arguments; const bool grad_mode = GradMode::is_enabled(); std::vector tensors_requiring_grad_on_stack; // Keep track of which outputs are output of in-place modification // so we can rebase_history if necessary - std::vector is_inplace_output; + std::vector is_inplace_output(num_returns, false); bool any_is_inplace_output = false; - std::vector is_aliased_output; - is_inplace_output.reserve(num_returns); - is_aliased_output.reserve(num_returns); - - for (const auto i : c10::irange(num_returns)) { - const at::AliasInfo* alias_info = returns[i].alias_info(); - is_inplace_output.push_back(alias_info != nullptr && alias_info->isWrite()); - any_is_inplace_output |= alias_info != nullptr && alias_info->isWrite(); - is_aliased_output.push_back(alias_info != nullptr); - } - int aliased_input_idx = -1; + std::vector is_aliased_output(num_returns, false); int aliased_output_idx = -1; + for (const auto i : c10::irange(num_returns)) { - const at::AliasInfo* alias_info = returns[i].alias_info(); - if (alias_info != nullptr && !alias_info->isWrite()) { - TORCH_CHECK( - aliased_output_idx == -1, - "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple outputs are aliased with inputs aren't supported." - "Please rewrite your function as a composite function."); - aliased_output_idx = i; + if (schema.is_aliasing({c10::SchemaArgType::output, i})) { + if (schema.is_mutable({c10::SchemaArgType::output, i})) { + is_inplace_output[i] = true; + any_is_inplace_output = true; + } else { + TORCH_CHECK( + aliased_output_idx == -1, + "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple outputs are aliased with inputs aren't supported." + "Please rewrite your function as a composite function."); + aliased_output_idx = i; + } + is_aliased_output[i] = true; } } + + int aliased_input_idx = -1; for (const auto i : c10::irange(num_arguments)) { - const at::AliasInfo* alias_info = arguments[i].alias_info(); - if (alias_info != nullptr && !alias_info->isWrite()) { + if (schema.is_aliasing({c10::SchemaArgType::input, i}) && + !schema.is_mutable({c10::SchemaArgType::input, i})) { TORCH_CHECK( aliased_input_idx == -1, "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " @@ -121,8 +117,7 @@ void autogradNotImplementedFallbackImpl( _foreach_tensor( [&](size_t _, size_t i, const at::Tensor& t) { - const at::AliasInfo* alias_info = arguments[i].alias_info(); - if (alias_info != nullptr && alias_info->isWrite()) { + if (schema.is_mutable({c10::SchemaArgType::input, i})) { check_inplace(t, any_requires_grad); } }, @@ -273,18 +268,16 @@ void autogradNotImplementedInplaceOrViewFallbackImpl( // that is not allowed in the gen_inplace_or_view logic const auto& schema = op.schema(); const auto& op_name = schema.operator_name().name; - const auto& arguments = schema.arguments(); - const auto& returns = schema.returns(); - const auto num_arguments = arguments.size(); - const auto num_returns = returns.size(); + const auto num_arguments = schema.arguments().size(); + const auto num_returns = schema.returns().size(); const auto stack_start = stack->size() - num_arguments; at::Tensor aliased_input; int64_t aliased_output_idx = -1; for (const auto i : c10::irange(num_returns)) { - const at::AliasInfo* alias_info = returns[i].alias_info(); - if (alias_info != nullptr && !alias_info->isWrite()) { + if (schema.is_aliasing({c10::SchemaArgType::output, i}) && + !schema.is_mutable({c10::SchemaArgType::output, i})) { TORCH_CHECK( aliased_output_idx == -1, "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a " @@ -297,25 +290,22 @@ void autogradNotImplementedInplaceOrViewFallbackImpl( int64_t aliased_input_idx = -1; for (const auto i : c10::irange(num_arguments)) { - const at::AliasInfo* alias_info = arguments[i].alias_info(); - if (alias_info != nullptr) { - if (!alias_info->isWrite()) { - TORCH_CHECK( - aliased_input_idx == -1, - "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a " - "non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " - "Please rewrite your function as a composite function."); - aliased_input_idx = i; - const c10::IValue& aliased_input_iv = - (*stack)[stack_start + i]; // get a reference to an ivalue on the - // stack - TORCH_CHECK(aliased_input_iv.isTensor()); - aliased_input = - aliased_input_iv - .toTensor(); // TODO: Can we avoid saving this tensor and - // incurring the refcount bump? - } + if (schema.is_aliasing({c10::SchemaArgType::input, i}) && + !schema.is_mutable({c10::SchemaArgType::input, i})) { + TORCH_CHECK( + aliased_input_idx == -1, + "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a " + "non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " + "Please rewrite your function as a composite function."); + aliased_input_idx = i; + const c10::IValue& aliased_input_iv = + (*stack)[stack_start + i]; // get a reference to an ivalue on the + // stack + TORCH_CHECK(aliased_input_iv.isTensor()); + aliased_input = + aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor + // and incurring the refcount bump? } } // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above @@ -334,8 +324,7 @@ void autogradNotImplementedInplaceOrViewFallbackImpl( } for (const auto i : c10::irange(num_returns)) { - const at::AliasInfo* alias_info = returns[i].alias_info(); - if (alias_info->isWrite()) { + if (schema.is_mutable({c10::SchemaArgType::output, i})) { increment_version((*stack)[stack->size() - num_returns + i].toTensor()); } } diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index dd0162441934e..408daf3b0250a 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -54,8 +54,8 @@ static void forked_autograd_child() { // Should be called before unsafe for forks (thread pool) calls static void track_bad_autograd_forks() { #if !defined(WIN32) - static std::once_flag flag; - std::call_once( + static c10::once_flag flag; + c10::call_once( flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); }); #endif } @@ -405,11 +405,6 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { // backwards, user thread), this function is expected to exit once that // graph_task complete. -#ifdef USE_ROCM - // Keep track of backward pass for rocblas. - at::ROCmBackwardPassGuard in_backward; -#endif - // local_ready_queue should already been initialized when we get into // thread_main TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr); @@ -1109,7 +1104,7 @@ void Engine::initialize_device_threads_pool() { !in_bad_autograd_fork, "Unable to handle autograd's threading in combination with fork-based multiprocessing. " "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork"); - std::call_once( + c10::call_once( start_device_threads_flag_, &Engine::start_device_threads, this); } diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index e436e95b1b669..e86fd8339a63e 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -14,6 +14,8 @@ #include #include +#include + #include #include #include @@ -391,7 +393,7 @@ struct TORCH_API Engine { // Ensures device_ready_queues_ are initialized only once // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::once_flag start_device_threads_flag_; + c10::once_flag start_device_threads_flag_; // Safe to read device_ready_queues_ without synchronization after // initialization // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index f98da972db47f..9f9da69381cc6 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -150,6 +150,11 @@ struct TORCH_API Node : std::enable_shared_from_this { // probably operate with names. at::NoNamesGuard no_names_guard; +#ifdef USE_ROCM + // Keep track of backward pass for rocblas. + at::ROCmBackwardPassGuard in_backward; +#endif + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); if (C10_UNLIKELY(step_callbacks.has_value())) { @@ -186,11 +191,11 @@ struct TORCH_API Node : std::enable_shared_from_this { /// of the new input. uint32_t add_input_metadata( const at::TensorOptions& options, - at::IntArrayRef shape, + c10::SymIntArrayRef shape, bool is_tensor_subclass) noexcept { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t input_nr = input_metadata_.size(); - auto meta_shape = MetadataShape{c10::in_place_type, shape}; + auto meta_shape = MetadataShape{c10::in_place_type, shape}; input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass); return input_nr; } diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 7a02ce308fb08..013e4fe493fc8 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #ifdef USE_DISTRIBUTED #include #endif @@ -161,8 +162,9 @@ void THPAutograd_initFunctions() { if (!c_module) throw python_error(); - Py_INCREF(module); + Py_INCREF(module.get()); if (PyModule_AddObject(c_module, "_functions", module) < 0) { + Py_DECREF(module.get()); throw python_error(); } } diff --git a/torch/csrc/autograd/functions/pybind.h b/torch/csrc/autograd/functions/pybind.h index 5a97e6c4e9f97..94b3c9c679969 100644 --- a/torch/csrc/autograd/functions/pybind.h +++ b/torch/csrc/autograd/functions/pybind.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index caccffa0ae45e..50ddde49575b1 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -21,7 +21,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -29,6 +31,18 @@ #include #include +namespace { + +struct DisableFuncTorch { + DisableFuncTorch() + : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode), + back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {} + c10::impl::ExcludeDispatchKeyGuard front_guard_; + c10::impl::ExcludeDispatchKeyGuard back_guard_; +}; + +} // namespace + PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { using namespace torch::autograd::profiler; using namespace torch::profiler::impl; @@ -71,6 +85,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .value("CPU", ProfilerState::CPU) .value("CUDA", ProfilerState::CUDA) .value("NVTX", ProfilerState::NVTX) + .value("ITT", ProfilerState::ITT) .value("KINETO", ProfilerState::KINETO) .value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK); @@ -255,16 +270,52 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }); { + using torch::profiler::impl::PyFrameState; using torch::profiler::impl::Result; - py::class_>(m, "_ExtraFields_TorchOp"); + py::enum_(m, "_EventType") + .value("TorchOp", EventType::TorchOp) + .value("Backend", EventType::Backend) + .value("Allocation", EventType::Allocation) + .value("PyCall", EventType::PyCall) + .value("PyCCall", EventType::PyCCall) + .value("Kineto", EventType::Kineto); + py::class_>(m, "_ExtraFields_TorchOp") + .def_readonly("inputs", &ExtraFields::inputs_) + .def_readonly( + "allow_tf32_cublas", + &ExtraFields::allow_tf32_cublas_); + py::class_(m, "_Inputs") + .def_readonly("shapes", &Inputs::shapes_) + .def_readonly("tensor_metadata", &Inputs::tensor_metadata_) + .def_readonly("dtypes", &Inputs::dtypes_); + + py::class_(m, "_TensorMetadata") + .def_property_readonly("layout", [](const TensorMetadata& metadata) { + PyObject* layout_obj = torch::autograd::utils::wrap(metadata.layout_); + return py::reinterpret_borrow(layout_obj); + }); + py::class_>(m, "_ExtraFields_Backend"); py::class_>( m, "_ExtraFields_Allocation"); - py::class_>(m, "_ExtraFields_PyCall"); - py::class_>(m, "_ExtraFields_PyCCall"); + py::class_>(m, "_ExtraFields_PyCall") + .def_readonly("callsite", &ExtraFields::callsite_) + .def_readonly("caller", &ExtraFields::caller_); + py::class_>(m, "_ExtraFields_PyCCall") + .def_readonly("caller", &ExtraFields::caller_); + py::class_(m, "_PyFrameState") + .def_readonly("line_number", &PyFrameState::line_no_) + .def_property_readonly( + "file_name", + [](const PyFrameState& s) { return s.filename_.str(); }) + .def_property_readonly("function_name", [](const PyFrameState& s) { + return s.funcname_.str(); + }); + py::class_>(m, "_ExtraFields_Kineto"); py::class_>(m, "_ProfilerEvent") .def("name", &Result::name) + .def_property_readonly("tag", &Result::tag) .def_readonly("extra_fields", &Result::extra_fields_) .def_property_readonly( "id", @@ -275,6 +326,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { "parent", [](const Result& r) { return r.parent_.lock(); }) .def_readonly("children", &Result::children_) .def_readonly("start_time_ns", &Result::start_time_ns_) + .def_readonly("start_tid", &Result::start_tid_) .def_property_readonly("correlation_id", &Result::correlationID) .def_property_readonly("end_time_ns", &Result::endTimeNS) .def_property_readonly("duration_time_ns", [](const Result& r) { @@ -416,6 +468,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { // TODO: line up this binding with DisableTorchFunction py::class_(_C_m, "_DisableTorchDispatch") .def(py::init<>()); + py::class_(_C_m, "_DisableFuncTorch").def(py::init<>()); py::class_(m, "SavedTensor") .def(py::init([]() -> torch::autograd::SavedVariable { diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 917acc292c718..7cb9e8aedb195 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -6,9 +6,12 @@ #include #include #include +#include #include #include +#include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -23,7 +26,8 @@ namespace torch { namespace autograd { -using MetadataShape = c10::variant; +using SymIntSmallVec = c10::SmallVector; +using MetadataShape = c10::variant; /** * Records TensorOptions, shape of the tensor, whether or not the Python @@ -81,7 +85,7 @@ struct InputMetadata { TORCH_CHECK( !is_nested_tensor(), "Zeros is not currently supported for nested tensors.") - return at::zeros(shape_as_dim_vector(), options_); + return at::zeros_symint(shape_as_dim_vector(), options_); } bool is_same_shape(const at::Tensor& grad) const { @@ -92,7 +96,7 @@ struct InputMetadata { return at::native::get_nested_size_tensor(grad).is_same_size( shape_as_tensor()); } - return grad.sizes().equals(shape_as_dim_vector()); + return grad.sym_sizes().equals(shape_as_dim_vector()); } bool is_expandable_to_shape(const at::Tensor& grad) const { // Currently NestedTensors are not expandable. If this support is added then @@ -102,7 +106,7 @@ struct InputMetadata { "Both grad and InputMetadata need to be either nested or non nested tensors.") return grad.is_nested() ? false - : at::is_expandable_to(shape_as_dim_vector(), grad.sizes()); + : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes()); } at::Tensor reduce_grad(at::Tensor& grad) const { @@ -127,7 +131,7 @@ struct InputMetadata { if (is_nested_tensor()) { ss << shape_as_tensor(); } else { - ss << shape_as_dim_vector(); + ss << c10::asIntArrayRefSlow(shape_as_dim_vector()); } return ss; } @@ -141,12 +145,14 @@ struct InputMetadata { auto nested_size = at::native::get_nested_size_tensor(input); return MetadataShape{c10::in_place_type, nested_size}; } - return MetadataShape{c10::in_place_type, input.sizes()}; + return MetadataShape{c10::in_place_type, input.sym_sizes()}; } - at::DimVector shape_as_dim_vector() const { - return c10::get(shape_); + c10::SymIntArrayRef shape_as_dim_vector() const { + const auto& dim_shape = c10::get(shape_); + return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size()); } + at::Tensor shape_as_tensor() const { return c10::get(shape_); } diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 41612bd0a6cfc..5afac0002eba8 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include #include @@ -61,24 +63,28 @@ using torch::profiler::impl::ProfilerThreadLocalStateBase; using torch::profiler::impl::Result; using torch::profiler::impl::shapesToStr; using torch::profiler::impl::stacksToStr; -using torch::profiler::impl::kineto::annotation_t; -using torch::profiler::impl::kineto::KinetoActivityType; struct EventFieldsVisitor { EventFieldsVisitor( std::shared_ptr& result, KinetoEvent& kineto_event, const post_process_t& post_process) - : kineto_event_{kineto_event}, post_process_{post_process} { + : kineto_activity_{result->kineto_activity_}, + kineto_event_{kineto_event}, + post_process_{post_process} { + c10::guts::if_constexpr([&](auto _) { + kineto_event.deviceIndex(_(result->kineto_info_).device); + kineto_event.deviceResourceId(_(result->kineto_info_).resource); + }); + pushPythonMetadata(result->parent_.lock()); - c10::visit(*this, result->extra_fields_); + result->visit(*this); handleStack(result->parent_); } void operator()(ExtraFields& op_event) { handleJIT(op_event); kineto_event_.get() - .endThreadId(op_event.end_tid_) .scope((int8_t)op_event.scope_) .debugHandle(op_event.debug_handle_) .setAsync(op_event.is_async_); @@ -86,13 +92,13 @@ struct EventFieldsVisitor { auto& shapes = op_event.inputs_.shapes_; if (!shapes.empty()) { kineto_event_.get().shapes(shapes); - annotations_.emplace_back("Input Dims", shapesToStr(shapes)); + addMetadata("Input Dims", shapesToStr(shapes)); } auto& dtypes = op_event.inputs_.dtypes_; if (!dtypes.empty()) { kineto_event_.get().dtypes(dtypes); - annotations_.emplace_back("Input type", dtypesToStr(dtypes)); + addMetadata("Input type", dtypesToStr(dtypes)); } if (!op_event.extra_args_.empty()) { @@ -110,24 +116,20 @@ struct EventFieldsVisitor { kineto_event_.get() .sequenceNr(op_event.sequence_number_) .fwdThreadId(op_event.forward_tid_); - annotations_.emplace_back( - "Fwd thread id", std::to_string(op_event.forward_tid_)); - annotations_.emplace_back( - "Sequence number", std::to_string(op_event.sequence_number_)); + addMetadata("Fwd thread id", std::to_string(op_event.forward_tid_)); + addMetadata("Sequence number", std::to_string(op_event.sequence_number_)); } } void operator()(ExtraFields& backend_event) { handleJIT(backend_event); kineto_event_.get() - .endThreadId(kineto_event_.get().startThreadId()) .scope((int8_t)backend_event.scope_) .debugHandle(backend_event.debug_handle_) .backend(backend_event.backend_); if (!backend_event.backend_.empty()) { - annotations_.emplace_back( - "Backend", "\"" + backend_event.backend_ + "\""); + addMetadata("Backend", "\"" + backend_event.backend_ + "\""); } } @@ -136,18 +138,31 @@ struct EventFieldsVisitor { .deviceIndex(alloc.device_index_) .nBytes(alloc.alloc_size_); - annotations_ = { - {"Device Type", std::to_string((int8_t)alloc.device_type_)}, - {"Device Id", std::to_string(alloc.device_index_)}, - {"Addr", std::to_string(reinterpret_cast(alloc.ptr_))}, - {"Bytes", std::to_string(alloc.alloc_size_)}}; + addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_)); + addMetadata("Device Id", std::to_string(alloc.device_index_)); + addMetadata("Addr", std::to_string(reinterpret_cast(alloc.ptr_))); + addMetadata("Bytes", std::to_string(alloc.alloc_size_)); + if (alloc.total_allocated_ >= 0) { + addMetadata("Total Allocated", std::to_string(alloc.total_allocated_)); + } + if (alloc.total_reserved_ >= 0) { + addMetadata("Total Reserved", std::to_string(alloc.total_reserved_)); + } + } + + void operator()(const ExtraFields& alloc) { + kineto_event_.get() + .deviceIndex(alloc.device_index_) + .nBytes(alloc.alloc_size_); + + addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_)); + addMetadata("Device Id", std::to_string(alloc.device_index_)); + addMetadata("Bytes", std::to_string(alloc.alloc_size_)); if (alloc.total_allocated_ >= 0) { - annotations_.emplace_back( - "Total Allocated", std::to_string(alloc.total_allocated_)); + addMetadata("Total Allocated", std::to_string(alloc.total_allocated_)); } if (alloc.total_reserved_ >= 0) { - annotations_.emplace_back( - "Total Reserved", std::to_string(alloc.total_reserved_)); + addMetadata("Total Reserved", std::to_string(alloc.total_reserved_)); } } @@ -162,13 +177,13 @@ struct EventFieldsVisitor { // NB: This is only for the JIT stack. The python stack (if applicable) // is constructed later. kineto_event_.get().stack(jit_stack); - annotations_.emplace_back( + addMetadata( "Call stack", torch::profiler::impl::stacksToStr(jit_stack, ";")); } if (!jit_modules.empty()) { kineto_event_.get().moduleHierarchy(jit_modules); - annotations_.emplace_back( + addMetadata( "Module Hierarchy", torch::profiler::impl::stacksToStr(jit_modules, ".")); } @@ -177,8 +192,7 @@ struct EventFieldsVisitor { void operator()(const ExtraFields& py_call) { addPythonAnnotations(py_call); if (py_call.module_.has_value()) { - annotations_.emplace_back( - "Python module id", std::to_string(py_call.module_->id_)); + addMetadata("Python module id", std::to_string(py_call.module_->id_)); } } @@ -186,6 +200,14 @@ struct EventFieldsVisitor { addPythonAnnotations(py_call); } + void operator()(const ExtraFields& e) { + TORCH_INTERNAL_ASSERT(kineto_activity_ == nullptr); + const auto linked = e.linked_activity_.lock(); + if (linked) { + kineto_event_.get().linkedCorrelationId(linked->correlationID()); + } + } + void pushPythonMetadata(std::shared_ptr parent) { auto push = [&](const auto& i) { c10::guts::if_constexprextra_fields_); + parent->visit(push); parent = parent->parent_.lock(); } } template void addPythonAnnotations(T& t) { - annotations_.emplace_back("Python id", std::to_string(t.id_)); - annotations_.emplace_back( + addMetadata("Python id", std::to_string(t.id_)); + addMetadata( "Python parent id", !py_metadata_.empty() ? std::to_string(py_metadata_.at(0).id_) : "null"); - annotations_.emplace_back("Python thread", std::to_string(t.python_tid_)); + addMetadata("Python thread", std::to_string(t.python_tid_)); } void handleStack(std::weak_ptr parent) { @@ -223,22 +245,28 @@ struct EventFieldsVisitor { } if (kineto_event_.get().hasStack()) { - annotations_.emplace_back( + addMetadata( "Call stack", torch::profiler::impl::stacksToStr(kineto_event_.get().stack(), ";")); } } + void addMetadata(const std::string& key, const std::string& value) { + if (kineto_activity_) { + torch::profiler::impl::kineto::addMetadata(kineto_activity_, key, value); + } + } + struct PythonMetadata { size_t id_; size_t python_tid_; std::string name_; }; + const torch::profiler::impl::kineto::activity_t* kineto_activity_; std::reference_wrapper kineto_event_; std::reference_wrapper post_process_; std::vector py_metadata_; - annotation_t annotations_; }; // Assumption: Total threads number will not exceed 2^16-1, and total ops will @@ -253,8 +281,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { std::set activities) : ProfilerThreadLocalStateBase(config), start_time_(getTimeUs()), - record_queue_(config, activities), - cpu_trace_(start_time_, "PyTorch Profiler") {} + record_queue_(config, activities) {} ~KinetoThreadLocalState() override = default; static KinetoThreadLocalState* getTLS() { @@ -286,6 +313,22 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { } } + void reportOutOfMemory( + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + c10::Device device) override { + if (config_.profile_memory && config_.state != ProfilerState::Disabled) { + record_queue_.getSubqueue()->emplace_ooms_event( + torch::profiler::impl::getApproximateTime(), + alloc_size, + total_allocated, + total_reserved, + device.type(), + device.index()); + } + } + const post_process_t& getEventPostProcessingCallback() const { return event_post_process_cb_; } @@ -294,12 +337,19 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { event_post_process_cb_ = std::move(cb); } - torch::profiler::impl::kineto::ActivityTraceWrapper finalizeTrace() { + std::unique_ptr + finalizeTrace() { auto end_time = getTimeUs(); record_queue_.stop(); - materializeOpEvents(); - finalizeCPUTrace(cpu_trace_.get()); + std::lock_guard guard(state_mutex_); + auto converter = clock_converter_.makeConverter(); + auto records_and_trace = + record_queue_.getRecords(converter, start_time_, end_time); + + materializeOpEvents(records_and_trace.first); + + // finalizeCPUTrace(cpu_trace_.get()); // `kineto_events_` does not include Python events. Instead it exposes them // via the `stacks` property. @@ -310,55 +360,33 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { [](const auto& i) { return i.is_python_function_; }), kineto_events_.end()); - { - std::lock_guard guard(state_mutex_); - cpu_trace_.transferCpuTrace(end_time); - } - - if (config().state != ProfilerState::KINETO_ONDEMAND) { - auto trace = torch::profiler::impl::kineto::stopTrace(); - TORCH_CHECK(trace || !torch::profiler::kKinetoAvailable); - addTraceEvents(trace); - return trace; - } else { - return torch::profiler::impl::kineto::ActivityTraceWrapper(); - } + return std::move(records_and_trace.second); } - void materializeOpEvents() { - std::lock_guard guard(state_mutex_); - auto converter = clock_converter_.makeConverter(); - - for (auto& e : record_queue_.getRecords(converter)) { + void materializeOpEvents(std::vector>& events) { + for (auto& e : events) { if (e->parent_.expired()) { event_tree_.push_back(e); } if (e->finished_) { - int64_t start_us = e->start_time_ns_ / 1000; - int64_t end_us = e->endTimeNS() / 1000; kineto_events_.emplace_back( - e->kinetoType() == KinetoActivityType::PYTHON_FUNCTION); + e->kinetoType() == libkineto::ActivityType::PYTHON_FUNCTION); kineto_events_.back() .name(e->name()) - .startUs(start_us) - .durationUs(end_us - start_us) + .startUs(e->start_time_ns_ / 1000) + .durationUs((e->endTimeNS() - e->start_time_ns_) / 1000) .correlationId(e->correlationID()) .deviceType(e->deviceType()) - .startThreadId(e->start_tid_); + .startThreadId(e->start_tid_) + .endThreadId(e->endTID()) + .activityType((uint8_t)e->kinetoType()); - // NB: also sets fields on `kineto_events_.back()`. - auto visitor = EventFieldsVisitor( + EventFieldsVisitor set_fields_and_metadata( e, kineto_events_.back(), getEventPostProcessingCallback()); - cpu_trace_.addCPUActivity( - e->name(), - e->kinetoType(), - e->kineto_info_, - e->correlationID(), - start_us, - end_us, - visitor.annotations_); + // It is not safe to use the activity after post processing. + e->kineto_activity_ = nullptr; } } } @@ -439,42 +467,9 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { } #endif // USE_KINETO - void addTraceEvents( - torch::profiler::impl::kineto::ActivityTraceWrapper& trace) { -#ifdef USE_KINETO - const auto& events = *(trace.get()->activities()); - for (const auto& ev_ptr : events) { - if (ev_ptr == nullptr) { - continue; - } - const auto& activity = *ev_ptr; - // These events are already processed - if (activity.type() != libkineto::ActivityType::CPU_OP && - activity.type() != libkineto::ActivityType::CPU_INSTANT_EVENT && - activity.type() != libkineto::ActivityType::USER_ANNOTATION && - activity.type() != libkineto::ActivityType::PYTHON_FUNCTION) { - kineto_events_.emplace_back(); - auto& kineto_event = kineto_events_.back(); - kineto_event.name(activity.name()) - .deviceIndex(activity.deviceId()) - .deviceResourceId(activity.resourceId()) - .startUs(activity.timestamp()) - .durationUs(activity.duration()) - .activityType((uint8_t)activity.type()); - if (activity.linkedActivity()) { - kineto_event.linkedCorrelationId( - activity.linkedActivity()->correlationId()); - } - kineto_event.deviceType(deviceTypeFromActivity(activity.type())); - } - } -#endif // USE_KINETO - } - uint64_t start_time_; torch::profiler::impl::ApproximateClockToUnixTimeConverter clock_converter_; torch::profiler::impl::RecordQueue record_queue_; - torch::profiler::impl::kineto::TraceWrapper cpu_trace_; std::vector kineto_events_; std::vector event_tree_; // Optional, if event post-processing is enabled. @@ -623,7 +618,8 @@ void reportBackendEventToActiveKinetoProfiler( void prepareProfiler( const torch::profiler::impl::ProfilerConfig& config, const std::set& activities) { - if (config.state == ProfilerState::NVTX) { + if (config.state == ProfilerState::NVTX || + config.state == ProfilerState::ITT) { return; } TORCH_CHECK( @@ -642,6 +638,9 @@ void enableProfilerWithEventPostProcess( TORCH_CHECK( config.state != ProfilerState::NVTX, "NVTX does not support post processing callback."); + TORCH_CHECK( + config.state != ProfilerState::ITT, + "ITT does not support post processing callback."); TORCH_INTERNAL_ASSERT( GlobalStateManager::get() == nullptr, "On-demand profiling does not support post processing callback"); @@ -659,6 +658,9 @@ void enableProfiler( if (config.state == ProfilerState::NVTX) { torch::profiler::impl::pushNVTXCallbacks(config, scopes); return; + } else if (config.state == ProfilerState::ITT) { + torch::profiler::impl::pushITTCallbacks(config, scopes); + return; } TORCH_CHECK( @@ -702,7 +704,8 @@ std::unique_ptr disableProfiler() { (config.state == ProfilerState::KINETO || config.state == ProfilerState::KINETO_GPU_FALLBACK || config.state == ProfilerState::KINETO_ONDEMAND || - config.state == ProfilerState::NVTX), + config.state == ProfilerState::NVTX || + config.state == ProfilerState::ITT), "Can't disable Kineto profiler when it's not running"); if (state_ptr->hasCallbackHandle()) { @@ -754,7 +757,8 @@ int64_t KinetoEvent::cudaElapsedUs() const { ProfilerResult::ProfilerResult( uint64_t start_time, std::vector events, - torch::profiler::impl::kineto::ActivityTraceWrapper trace, + std::unique_ptr&& + trace, std::vector&& event_tree) : trace_start_us_(start_time), events_(std::move(events)), @@ -764,7 +768,7 @@ ProfilerResult::ProfilerResult() = default; ProfilerResult::~ProfilerResult() = default; void ProfilerResult::save(const std::string& path) { - trace_.save(path); + trace_->save(path); } } // namespace profiler diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 8c5a0cdfa2ad8..68c803d1e76a6 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -4,11 +4,17 @@ #include #include -#include -#include #include namespace torch { +namespace profiler { +namespace impl { +struct Result; +namespace kineto { +struct ActivityTraceWrapper; +} // namespace kineto +} // namespace impl +} // namespace profiler namespace autograd { namespace profiler { using experimental_event_t = std::shared_ptr; @@ -273,8 +279,8 @@ struct TORCH_API KinetoEvent { int64_t debug_handle_{-1}; std::string backend_; - torch::profiler::impl::CUDAEventStub cuda_event_start_ = nullptr; - torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr; + torch::profiler::impl::ProfilerEventStub cuda_event_start_ = nullptr; + torch::profiler::impl::ProfilerEventStub cuda_event_end_ = nullptr; bool is_python_function_; }; @@ -286,7 +292,8 @@ struct TORCH_API ProfilerResult { ProfilerResult( uint64_t start_time, std::vector events, - torch::profiler::impl::kineto::ActivityTraceWrapper trace, + std::unique_ptr&& + trace, std::vector&& event_tree); ~ProfilerResult(); @@ -307,7 +314,7 @@ struct TORCH_API ProfilerResult { private: uint64_t trace_start_us_ = 0; std::vector events_; - torch::profiler::impl::kineto::ActivityTraceWrapper trace_; + std::unique_ptr trace_; std::vector event_tree_; }; diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 22d9c5ecf9902..a7422335c4425 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -197,7 +197,7 @@ void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) { return; } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { - torch::profiler::impl::cudaStubs()->nvtxMarkA(name.c_str()); + torch::profiler::impl::cudaStubs()->mark(name.c_str()); } else { LegacyEvent evt( EventKind::Mark, @@ -229,7 +229,7 @@ void ProfilerLegacyThreadLocalState::pushRange( return; } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { - torch::profiler::impl::cudaStubs()->nvtxRangePushA( + torch::profiler::impl::cudaStubs()->rangePush( torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes) .c_str()); } else { @@ -277,7 +277,7 @@ void ProfilerLegacyThreadLocalState::popRange( return; } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { - torch::profiler::impl::cudaStubs()->nvtxRangePop(); + torch::profiler::impl::cudaStubs()->rangePop(); } else { // In some cases RecordFunction (and popRange) may be // called on a different thread than pushRange diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index a2b7626aa9d17..e12beca806099 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -266,7 +266,7 @@ struct TORCH_API LegacyEvent { int64_t cpu_memory_usage_ = 0; int64_t cuda_memory_usage_ = 0; int device_ = -1; - torch::profiler::impl::CUDAEventStub cuda_event = nullptr; + torch::profiler::impl::ProfilerEventStub cuda_event = nullptr; int node_id_ = 0; bool is_remote_ = false; int64_t cuda_us_ = -1; diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 0301aa65a64e5..34aab7e8e6709 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace py = pybind11; @@ -31,31 +32,40 @@ namespace impl { namespace { enum CallType { PyCall = 0, PyModuleCall, PyCCall }; static constexpr size_t CallTypeSize = 3; +using no_ephemeral_t = std::tuple<>; // ============================================================================ // == Miscellaneous structs and utils ========================================= // ============================================================================ struct CodeLocation { CodeLocation() = default; - explicit CodeLocation(const PyFrameObject* frame) - : code_{frame->f_code}, lasti_{frame->f_lasti} {} + explicit CodeLocation(PyFrameObject* frame) + : line_number_{PyFrame_GetLineNumber(frame)} { + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + filename_ = THPUtils_unpackStringView(code->co_filename).data(); + name_ = THPUtils_unpackStringView(code->co_name).data(); + } bool operator==(const CodeLocation& other) const { - return code_ == other.code_ && lasti_ == other.lasti_; + return filename_ == other.filename_ && name_ == other.name_ && + line_number_ == other.line_number_; } - PyCodeObject* code_{nullptr}; - int lasti_{0}; + const char* filename_{nullptr}; + const char* name_{nullptr}; + int line_number_{0}; }; -PyObject* nnModuleCode() { +PyCodeObject* nnModuleCode() { static auto module_call_code = []() { pybind11::gil_scoped_acquire gil; - return py::module::import("torch.nn") - .attr("Module") - .attr("__call__") - .attr("__code__") - .ptr(); + auto res = py::module::import("torch.nn") + .attr("Module") + .attr("__call__") + .attr("__code__") + .ptr(); + TORCH_INTERNAL_ASSERT(PyCode_Check(res)); + return (PyCodeObject*)res; }(); return module_call_code; } @@ -68,7 +78,7 @@ PyObject* nnModuleCode() { template <> struct std::hash { size_t operator()(const torch::profiler::impl::CodeLocation& x) { - return c10::get_hash(x.code_, x.lasti_); + return c10::get_hash(x.filename_, x.name_, x.line_number_); } }; @@ -141,8 +151,21 @@ class CallTypeHelper final { // // To add a new event type to the cache: // 1) Add an entry to the `CallType` enum. -// 2) Add a specialization of Config which defined key_t and cache_t. +// 2) Add a specialization of Config which defined key_t, ephemeral_t and +// cache_t. // 3) Add a specialization of ValueCache::store and ValueCache::load. +// +// ------------------------- +// -- Ephemeral arguments -- +// ------------------------- +// The value cache mechanism assumes that `key_t` is enough to specify the +// correct value. However it may not be possible to materialize a value using +// only an instance of `key_t`. As a result, the cache also accepts "ephemeral" +// inputs which can be used to populate the value cache. Ephemeral inputs come +// with two caveats: +// 1) They are NOT safe to save, and cannot be used after `ValueCache::store`. +// 2) They should be used to access data that is not expect to change from +// call to call, such as the name of a function. template struct Config; @@ -150,6 +173,7 @@ struct Config; template <> struct Config { using key_t = CodeLocation; + using ephemeral_t = no_ephemeral_t; using cache_t = ska::flat_hash_map; static constexpr EventType event_type = EventType::PyCall; }; @@ -157,6 +181,7 @@ struct Config { template <> struct Config { using key_t = PyModuleSelf; + using ephemeral_t = PyFrameObject*; struct cache_t { c10::optional module_forward_; ska::flat_hash_map modules_; @@ -167,7 +192,8 @@ struct Config { template <> struct Config { - using key_t = torch::profiler::impl::PyCFunction; + using key_t = PyMethod; + using ephemeral_t = PyObject*; using cache_t = ska::flat_hash_map; static constexpr EventType event_type = EventType::PyCCall; }; @@ -186,8 +212,7 @@ class Callsite { "Key should be trivial, as it is passed by value."); template - Callsite(U value, const PyFrameObject* f_back) - : value_(value), caller_(f_back) {} + Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {} bool operator==(const Callsite& other) const { return value_ == other.value_ && caller_ == other.caller_; @@ -200,7 +225,7 @@ class Callsite { class ValueCache { public: template - void store(const typename Config::key_t&); + void store(const typename Config::key_t&, typename Config::ephemeral_t); template auto load(const Callsite& callsite, size_t python_tid) const { @@ -234,14 +259,13 @@ using PyModuleCallKey = Config::key_t; using PyCCallKey = Config::key_t; template <> -void ValueCache::store(const PyCallKey& key) { +void ValueCache::store(const PyCallKey& key, no_ephemeral_t) { auto& locations = std::get(state_); if (C10_UNLIKELY(locations.find(key) == locations.end())) { - TORCH_INTERNAL_ASSERT(key.code_ != nullptr); locations[key] = { - PyCode_Addr2Line(key.code_, key.lasti_), - at::StringView(THPUtils_unpackString(key.code_->co_filename)), - at::StringView(THPUtils_unpackString(key.code_->co_name))}; + key.line_number_, + at::StringView(key.filename_), + at::StringView(key.name_)}; } } @@ -252,14 +276,16 @@ ExtraFields::args_t ValueCache::load( } template <> -void ValueCache::store(const PyModuleCallKey& key) { +void ValueCache::store( + const PyModuleCallKey& key, + Config::ephemeral_t frame) { auto& cache = std::get(state_); if (C10_UNLIKELY(cache.modules_.find(key) == cache.modules_.end())) { if (C10_UNLIKELY(!cache.module_forward_.has_value())) { - auto frame = PyEval_GetFrame(); - TORCH_INTERNAL_ASSERT((PyObject*)(frame->f_code) == nnModuleCode()); + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + TORCH_INTERNAL_ASSERT(code.get() == nnModuleCode()); cache.module_forward_ = PyCallKey(frame); - store(*cache.module_forward_); + store(*cache.module_forward_, no_ephemeral_t()); } auto cls_handle = py::handle((PyObject*)key).attr("__class__"); auto cls = PyModuleCls(cls_handle.ptr()); @@ -283,10 +309,12 @@ ExtraFields::args_t ValueCache::load( } template <> -void ValueCache::store(const PyCCallKey& key) { +void ValueCache::store( + const PyCCallKey& key, + Config::ephemeral_t arg) { auto& names = std::get(state_); if (C10_UNLIKELY(names.find(key) == names.end())) { - names[key] = at::StringView(py::repr((PyObject*)key)); + names[key] = at::StringView(py::repr(arg)); } } @@ -298,9 +326,12 @@ ExtraFields::args_t ValueCache::load( // TODO: Use re2. void ValueCache::trimPrefixes() { - static auto prefixes = py::module::import("torch.profiler.python_tracer") - .attr("_prefix_regex")() - .cast>(); + static const auto prefixes = []() { + pybind11::gil_scoped_acquire gil; + return py::module::import("torch.profiler.python_tracer") + .attr("_prefix_regex")() + .cast>(); + }(); for (auto& it : std::get(state_)) { std::string filename = it.second.filename_.str(); @@ -332,11 +363,14 @@ struct TraceKeyCacheState { } }; - TraceKey intern(Callsite callsite, ValueCache& value_cache) { + TraceKey intern( + Callsite callsite, + typename Config::ephemeral_t ephemeral, + ValueCache& value_cache) { auto it = state_.find(callsite); if (C10_UNLIKELY(it == state_.end())) { - value_cache.store(callsite.value_); - value_cache.store(callsite.caller_); + value_cache.store(callsite.value_, ephemeral); + value_cache.store(callsite.caller_, no_ephemeral_t()); it = state_.insert({callsite, nextKey()}).first; } return it->second; @@ -426,13 +460,13 @@ struct ThreadLocalResults { Py_DECREF((PyObject*)ctx_); } - template - TraceKey intern(Args... args) { + template + TraceKey intern(Ephemeral ephemeral, Args... args) { static_assert( Config::event_type == E, "ThreadLocalResults.intern called from the wrong typed context."); - return std::get(trace_keys_) - .intern(Callsite(std::forward(args)...), *value_cache_); + auto callsite = Callsite(std::forward(args)...); + return std::get(trace_keys_).intern(callsite, ephemeral, *value_cache_); } static constexpr size_t BLOCK_SIZE = 1024; @@ -474,7 +508,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PyObject* arg); torch::profiler::impl::RecordQueue* queue_; - PyObject* module_call_code_; + PyCodeObject* module_call_code_; std::deque thread_local_results_; ValueCache value_cache_; @@ -528,14 +562,16 @@ void PythonTracer::start(torch::profiler::impl::RecordQueue* queue) { // to all the prior frames onto our event stack. (We stop at depth=128) std::vector current_stack; auto frame = PyEval_GetFrame(); + Py_INCREF(frame); size_t depth = 0; // Make sure we can't infinite loop. while (frame != nullptr && depth <= 128) { current_stack.push_back(frame); - frame = frame->f_back; + frame = PyFrame_GetBack(frame); depth++; } for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { recordPyCall(thread_local_results_.back(), *it); + Py_DECREF(*it); } // Note: @@ -572,7 +608,8 @@ void PythonTracer::clear() { void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) { static constexpr auto E = EventType::PyCall; auto get_key = [&]() -> TraceKey { - if ((PyObject*)(frame->f_code) == module_call_code_) { + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + if (code.get() == module_call_code_) { // By default, CPython stores locals in a "fast" format, with an array // of names and an array of values. Consequently, frame->f_locals is // NULL since the interpreter has no need to populate it. @@ -582,15 +619,17 @@ void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) { // not stable across versions. As a result, we are forced to call // `PyFrame_FastToLocals` which forces the interpreter to materialize // the full dict of locals. - PyFrame_FastToLocals(frame); - auto self = PyDict_GetItemString(frame->f_locals, "self"); - PyFrame_LocalsToFast(frame, 0); - TORCH_INTERNAL_ASSERT(frame->f_back != nullptr); - return tls.intern(self, frame->f_back); - + auto locals = THPObjectPtr(PyFrame_GetLocals(frame)); + auto self = THPObjectPtr(PyDict_GetItemString(locals, "self")); + Py_INCREF(self.get()); + auto back = THPFrameObjectPtr(PyFrame_GetBack(frame)); + TORCH_INTERNAL_ASSERT(back != nullptr); + return tls.intern( + frame, self.get(), back.get()); } else { - auto f_back = frame->f_back != nullptr ? frame->f_back : frame; - return tls.intern(frame, f_back); + auto back = THPFrameObjectPtr(PyFrame_GetBack(frame)); + auto f_back = (back.get() != nullptr) ? back.get() : frame; + return tls.intern(no_ephemeral_t(), frame, f_back); } }; queue_->getSubqueue()->emplace_py_call(get_key(), getApproximateTime()); @@ -600,9 +639,13 @@ void PythonTracer::recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, PyObject* arg) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(Py_TYPE(arg) == &PyCFunction_Type); + auto fn = reinterpret_cast(arg); + // NB: For C calls a new frame is not created, so we use `frame` rather than // `frame->f_back`. - auto key = tls.intern(arg, frame); + auto key = tls.intern( + arg, (void*)(fn->m_ml), frame); queue_->getSubqueue()->emplace_py_call(key, getApproximateTime()); } diff --git a/torch/csrc/autograd/python_anomaly_mode.cpp b/torch/csrc/autograd/python_anomaly_mode.cpp index 4cde43b6dae9a..ec5dfe1b0995c 100644 --- a/torch/csrc/autograd/python_anomaly_mode.cpp +++ b/torch/csrc/autograd/python_anomaly_mode.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/python_anomaly_mode.h b/torch/csrc/autograd/python_anomaly_mode.h index d58efd34ebaa4..6032940bfbaf2 100644 --- a/torch/csrc/autograd/python_anomaly_mode.h +++ b/torch/csrc/autograd/python_anomaly_mode.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace torch { namespace autograd { diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 3c4805fcc18bb..474bf14119b0c 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 203ccea416f2a..93e5441a1917c 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #ifndef _WIN32 diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 492585396ae22..d2bcd4974712e 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp index 44d5871135713..4ce4904d5b060 100644 --- a/torch/csrc/autograd/python_hook.cpp +++ b/torch/csrc/autograd/python_hook.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/python_saved_variable_hooks.h b/torch/csrc/autograd/python_saved_variable_hooks.h index 835b2ed04cef1..4962a4a827d25 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.h +++ b/torch/csrc/autograd/python_saved_variable_hooks.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index bcb7bb6467261..57c9161e27717 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -321,154 +321,6 @@ static PyObject* THPVariable_full( END_HANDLE_TH_ERRORS } -inline Tensor dispatch_randint( - int64_t high, - IntArrayRef size, - c10::optional generator, - Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, high, size, generator); -} -inline Tensor dispatch_randint( - int64_t high, - IntArrayRef size, - c10::optional generator, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(high, size, generator, options); -} -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, high, size); -} -inline Tensor dispatch_randint( - int64_t high, - IntArrayRef size, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(high, size, options); -} -inline Tensor dispatch_randint( - int64_t low, - int64_t high, - IntArrayRef size, - c10::optional generator, - Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, low, high, size, generator); -} -inline Tensor dispatch_randint( - int64_t low, - int64_t high, - IntArrayRef size, - c10::optional generator, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(low, high, size, generator, options); -} -inline Tensor dispatch_randint( - int64_t low, - int64_t high, - IntArrayRef size, - Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, low, high, size); -} -inline Tensor dispatch_randint( - int64_t low, - int64_t high, - IntArrayRef size, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(low, high, size, options); -} - -static PyObject* THPVariable_randint( - PyObject* self_, - PyObject* args, - PyObject* kwargs) { - HANDLE_TH_ERRORS - static PythonArgParser parser( - { - "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - }, - /*traceable=*/false); - - ParsedArgs<9> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if (r.has_torch_function()) { - return handle_torch_function( - r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - if (r.idx == 0) { - if (r.isNone(3)) { - auto high = r.toInt64(0); - auto size = r.intlist(1); - auto generator = r.generator(2); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - auto dtype = r.scalartypeWithDefault(4, at::ScalarType::Long); - auto device = r.device(6); - const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(5)) - .requires_grad(r.toBool(7)); - return wrap(dispatch_randint(high, size, generator, options)); - } else { - check_out_type_matches( - r.tensor(3), - r.scalartype(4), - r.isNone(4), - r.layout(5), - r.device(6), - r.isNone(6)); - return wrap(dispatch_randint( - r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3)) - .set_requires_grad(r.toBool(7))); - } - } else if (r.idx == 1) { - if (r.isNone(4)) { - auto low = r.toInt64(0); - auto high = r.toInt64(1); - auto size = r.intlist(2); - auto generator = r.generator(3); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - auto dtype = r.scalartypeWithDefault(5, at::ScalarType::Long); - auto device = r.device(7); - const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(6)) - .requires_grad(r.toBool(8)); - return wrap(dispatch_randint(low, high, size, generator, options)); - } else { - check_out_type_matches( - r.tensor(4), - r.scalartype(5), - r.isNone(5), - r.layout(6), - r.device(7), - r.isNone(7)); - return wrap(dispatch_randint( - r.toInt64(0), - r.toInt64(1), - r.intlist(2), - r.generator(3), - r.tensor(4)) - .set_requires_grad(r.toBool(8))); - } - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - // implemented on python object to allow torch.as_tensor to be constructed with // arbitrarily nested python objects - list, tuple, np array, scalar, etc. static PyObject* THPVariable_as_tensor( @@ -1062,10 +914,6 @@ static PyMethodDef torch_functions_manual[] = { castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"randint", - castPyCFunctionWithKeywords(THPVariable_randint), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, - nullptr}, {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 53cfca8653b60..99ae57e263987 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,7 @@ #include +#include #include #include #include @@ -238,8 +240,7 @@ c10::intrusive_ptr concrete_detach_fn( void concrete_dispatch_fn( const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type); + torch::jit::Stack* stack); bool concrete_is_contiguous_fn( const c10::impl::PyInterpreter*, const c10::TensorImpl* self); @@ -255,6 +256,12 @@ c10::IntArrayRef concrete_strides_fn( c10::IntArrayRef concrete_sizes_fn( const c10::impl::PyInterpreter*, const c10::TensorImpl* self); +c10::SymIntArrayRef concrete_sym_sizes_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); +c10::Layout concrete_layout_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); class PyInterpreterHolder { public: @@ -268,7 +275,9 @@ class PyInterpreterHolder { &concrete_device_fn, &concrete_dim_fn, &concrete_strides_fn, - &concrete_sizes_fn)) {} + &concrete_sizes_fn, + &concrete_sym_sizes_fn, + &concrete_layout_fn)) {} // NB: intentionally leaks the memory ~PyInterpreterHolder() { impl_->disarm(); @@ -567,6 +576,12 @@ static int THPVariable_clear(THPVariable* self) { return 0; } +int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) { + TORCH_INTERNAL_ASSERT( + false, "Tensor tp_traverse function was not overriden properly"); + return 0; +} + PyObject* THPVariable_pynew( PyTypeObject* type, PyObject* args, @@ -608,9 +623,9 @@ static PyObject* THPVariable_make_subclass( PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ - "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)", + "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, Device? device_for_backend_keys=None)", }); - ParsedArgs<5> parsed_args{}; + ParsedArgs<7> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); if (!PyType_Check(cls)) { @@ -639,6 +654,13 @@ static PyObject* THPVariable_make_subclass( if (r.toBool(4)) { data.unsafeGetTensorImpl()->set_custom_device(true); } + if (r.toBool(5)) { + data.unsafeGetTensorImpl()->set_custom_layout(true); + } + if (!r.isNone(6)) { + data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); + } + return THPVariable_NewWithVar( (PyTypeObject*)cls, std::move(data), @@ -657,13 +679,13 @@ static PyObject* THPVariable_make_wrapper_subclass( "_make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, " "int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, " "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, " - "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)", + "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False)", "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, " "int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, " "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, " - "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)", + "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False)", }); - ParsedArgs<12> parsed_args{}; + ParsedArgs<13> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); @@ -755,6 +777,9 @@ static PyObject* THPVariable_make_wrapper_subclass( if (r.toBool(11)) { tensor.unsafeGetTensorImpl()->set_custom_device(true); } + if (r.toBool(12)) { + tensor.unsafeGetTensorImpl()->set_custom_layout(true); + } return THPVariable_NewWithVar( (PyTypeObject*)cls, @@ -774,45 +799,79 @@ PyObject* THPVariable_get_python_dispatch(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } -PyObject* THPVariable_get_T(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "T"); +// CRTP base class to implement the python bindings for a Tensor property in +// PyTorch A class that implements a property is expected to have: +// - static constexpr const char* name; +// - This variable should hold the Python name of the property +// - static Tensor fn(const Tensor&); +// - This function calls the relevant ATen on the tensor +template +struct GetterBase { + static PyObject* getter(THPVariable* self, void* /*unused*/) { + HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject*)self)) { + return handle_torch_function_getter(self, T::name); + } + return THPVariable_Wrap(T::fn(THPVariable_Unpack(self))); + END_HANDLE_TH_ERRORS } - const auto& var = THPVariable_Unpack(self); - return THPVariable_Wrap(var.numpy_T()); - END_HANDLE_TH_ERRORS -} +}; -PyObject* THPVariable_get_H(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "H"); +struct PropertyT : GetterBase { + static constexpr const char* name = "T"; + static Tensor fn(const Tensor& t) { + return t.numpy_T(); } - const auto& var = THPVariable_Unpack(self); - return THPVariable_Wrap(var.matrix_H()); - END_HANDLE_TH_ERRORS -} +}; -PyObject* THPVariable_get_mT(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "mT"); +struct PropertyH : GetterBase { + static constexpr const char* name = "H"; + static Tensor fn(const Tensor& t) { + return t.matrix_H(); } - const auto& var = THPVariable_Unpack(self); - return THPVariable_Wrap(var.mT()); - END_HANDLE_TH_ERRORS -} +}; -PyObject* THPVariable_get_mH(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "mH"); +struct PropertymT : GetterBase { + static constexpr const char* name = "mT"; + static Tensor fn(const Tensor& t) { + return t.mT(); } - const auto& var = THPVariable_Unpack(self); - return THPVariable_Wrap(var.mH()); - END_HANDLE_TH_ERRORS -} +}; + +struct PropertymH : GetterBase { + static constexpr const char* name = "mH"; + static Tensor fn(const Tensor& t) { + return t.mH(); + } +}; + +struct PropertyData : GetterBase { + static constexpr const char* name = "data"; + static Tensor fn(const Tensor& t) { + return t.variable_data(); + } +}; + +struct PropertyGrad : GetterBase { + static constexpr const char* name = "grad"; + static Tensor fn(const Tensor& t) { + return t.grad(); + } +}; + +struct PropertyReal : GetterBase { + static constexpr const char* name = "real"; + static Tensor fn(const Tensor& t) { + return at::real(t); + } +}; + +struct PropertyImag : GetterBase { + static constexpr const char* name = "imag"; + static Tensor fn(const Tensor& t) { + return at::imag(t); + } +}; PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) { HANDLE_TH_ERRORS @@ -872,16 +931,6 @@ static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } -static PyObject* THPVariable_get_data(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "data"); - } - const auto& var = THPVariable_Unpack(self).variable_data(); - return THPVariable_Wrap(var); - END_HANDLE_TH_ERRORS -} - int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function((PyObject*)self)) { @@ -899,15 +948,6 @@ int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) { END_HANDLE_TH_ERRORS_RET(-1) } -PyObject* THPVariable_get_grad(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "grad"); - } - return THPVariable_Wrap(THPVariable_Unpack(self).grad()); - END_HANDLE_TH_ERRORS -} - int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function((PyObject*)self)) { @@ -1205,8 +1245,7 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) { if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "shape"); } - // return THPSize_NewFromSymSizes(THPVariable_Unpack(self)); - return THPSize_New(THPVariable_Unpack(self)); + return THPSize_NewFromSymSizes(THPVariable_Unpack(self)); END_HANDLE_TH_ERRORS } @@ -1379,28 +1418,6 @@ static PyObject* THPVariable_device(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } -PyObject* THPVariable_get_real(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "real"); - } - auto& self_ = THPVariable_Unpack(self); - auto real = at::real(self_); - return THPVariable_Wrap(real); - END_HANDLE_TH_ERRORS -} - -PyObject* THPVariable_get_imag(THPVariable* self, void* unused) { - HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { - return handle_torch_function_getter(self, "imag"); - } - auto& self_ = THPVariable_Unpack(self); - auto imag = at::imag(self_); - return THPVariable_Wrap(imag); - END_HANDLE_TH_ERRORS -} - int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) { HANDLE_TH_ERRORS auto& self_ = THPVariable_Unpack(self); @@ -1436,10 +1453,10 @@ static struct PyGetSetDef THPVariable_properties[] = { nullptr, nullptr, nullptr}, - {"T", (getter)THPVariable_get_T, nullptr, nullptr, nullptr}, - {"H", (getter)THPVariable_get_H, nullptr, nullptr, nullptr}, - {"mT", (getter)THPVariable_get_mT, nullptr, nullptr, nullptr}, - {"mH", (getter)THPVariable_get_mH, nullptr, nullptr, nullptr}, + {"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr}, + {"H", (getter)PropertyH::getter, nullptr, nullptr, nullptr}, + {"mT", (getter)PropertymT::getter, nullptr, nullptr, nullptr}, + {"mH", (getter)PropertymH::getter, nullptr, nullptr, nullptr}, {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr}, {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr}, {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr}, @@ -1455,17 +1472,17 @@ static struct PyGetSetDef THPVariable_properties[] = { nullptr, nullptr}, {"data", - (getter)THPVariable_get_data, + (getter)PropertyData::getter, (setter)THPVariable_set_data, nullptr, nullptr}, {"_grad", - (getter)THPVariable_get_grad, + (getter)PropertyGrad::getter, (setter)THPVariable_set_grad, nullptr, nullptr}, // Allows the python class to override .grad {"grad", - (getter)THPVariable_get_grad, + (getter)PropertyGrad::getter, (setter)THPVariable_set_grad, nullptr, nullptr}, @@ -1520,12 +1537,12 @@ static struct PyGetSetDef THPVariable_properties[] = { nullptr, nullptr}, {"real", - (getter)THPVariable_get_real, + (getter)PropertyReal::getter, (setter)THPVariable_set_real, nullptr, nullptr}, {"imag", - (getter)THPVariable_get_imag, + (getter)PropertyImag::getter, (setter)THPVariable_set_imag, nullptr, nullptr}, @@ -1639,7 +1656,7 @@ PyTypeObject THPVariableType = { Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ // Also set by metaclass - nullptr, /* tp_traverse */ + (traverseproc)THPFunction_traverse, /* tp_traverse */ (inquiry)THPVariable_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ @@ -1686,7 +1703,7 @@ static void clear_slots(PyTypeObject* type, PyObject* self) { PyMemberDef* mp; n = Py_SIZE(type); - mp = PyHeapType_GET_MEMBERS((PyHeapTypeObject*)type); + mp = type->tp_members; for (i = 0; i < n; i++, mp++) { if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) { char* addr = (char*)self + mp->offset; @@ -1898,7 +1915,7 @@ static int traverse_slots( PyMemberDef* mp; n = Py_SIZE(type); - mp = PyHeapType_GET_MEMBERS((PyHeapTypeObject*)type); + mp = type->tp_members; for (i = 0; i < n; i++, mp++) { if (mp->type == T_OBJECT_EX) { char* addr = (char*)self + mp->offset; @@ -2108,19 +2125,10 @@ py::object torchDispatchFromTensorImpl( TorchFunctionName::TorchDispatch)); } -// NOTE [dispatch_fn's type argument] -// `type` is nullable and represents the TorchDispatchMode going on. -// Right now we only support a single TorchDispatchMode, but in the future we -// could change this to a stack of TorchDispatchModes. -// -// If `type` isn't null, then we consider the type for dispatch by prepending -// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser` -// is responsible for doing overload resolution. void concrete_dispatch_fn( const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) { + torch::jit::Stack* stack) { const auto& schema = op.schema(); const auto num_arguments = schema.arguments().size(); auto arguments = torch::jit::pop(*stack, num_arguments); @@ -2156,10 +2164,6 @@ void concrete_dispatch_fn( } std::string module_name_str = "torch.ops." + ns_str; - if (type) { - append_overloaded_type(&overloaded_args, type->ptr(getPyInterpreter())); - } - // Find overloaded tensors for (const auto idx : c10::irange(arguments.size())) { const auto& ivalue = arguments[idx]; @@ -2305,17 +2309,24 @@ c10::IntArrayRef concrete_strides_fn( auto out = torchDispatchFromTensorImpl( self, "stride", - py::module::import("torch").attr("ops").attr("aten").attr("stride").ptr(), + py::module::import("torch") + .attr("ops") + .attr("aten") + .attr("stride") + .attr("default") + .ptr(), "torch.ops.aten"); if (out == Py_None) { + TORCH_CHECK( + !self->has_symbolic_sizes_strides(), + "Cannot call sizes on a tensor with symbolic shapes/strides"); return self->strides_default(); } py::object values = py::reinterpret_steal(out.ptr()); - c10::TensorImpl* ptr = const_cast(self); - c10::optional mb_obj = ptr->check_pyobj(getPyInterpreter()); + c10::optional mb_obj = self->check_pyobj(getPyInterpreter()); TORCH_CHECK( mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); PyObject* subclass = *mb_obj; @@ -2333,6 +2344,22 @@ c10::IntArrayRef concrete_strides_fn( return c10::IntArrayRef(start, len); } +static std::vector values_from_buffer( + const c10::TensorImpl* self, + py::handle values) { + c10::TensorImpl* ptr = const_cast(self); + c10::optional mb_obj = ptr->check_pyobj(getPyInterpreter()); + TORCH_CHECK( + mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); + + py::object os = py::module_::import("torch").attr("overrides"); + py::function get_buffer = + py::reinterpret_borrow(os.attr("get_buffer")); + auto buffer = get_buffer(py::handle(*mb_obj), values, "size"); + auto result = THPUtils_unpackLongs(buffer.ptr()); + return result; +} + c10::IntArrayRef concrete_sizes_fn( const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { @@ -2351,28 +2378,87 @@ c10::IntArrayRef concrete_sizes_fn( "torch.ops.aten"); if (out == Py_None) { + TORCH_CHECK( + !self->has_symbolic_sizes_strides(), + "Cannot call sizes on a tensor with symbolic shapes/strides"); return self->sizes_default(); } py::object values = py::reinterpret_steal(out.ptr()); - - c10::TensorImpl* ptr = const_cast(self); - c10::optional mb_obj = ptr->check_pyobj(getPyInterpreter()); - TORCH_CHECK( - mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); - PyObject* subclass = *mb_obj; - Py_INCREF(subclass); - py::object sub = py::reinterpret_steal(subclass); - - py::object os = py::module_::import("torch").attr("overrides"); - py::function get_buffer = - py::reinterpret_borrow(os.attr("get_buffer")); - auto buffer = get_buffer(sub, values, "size"); - auto result = THPUtils_unpackLongs(buffer.ptr()); + auto result = values_from_buffer(self, values); int64_t* start = (int64_t*)result[0]; int64_t len = result[1]; return c10::IntArrayRef(start, len); } +c10::SymIntArrayRef concrete_sym_sizes_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { + pybind11::gil_scoped_acquire gil; + at::impl::MaybeSetTLSOnEntryGuard guard; + HANDLE_TH_ERRORS + auto out = torchDispatchFromTensorImpl( + self, + "sym_size", + py::module::import("torch") + .attr("ops") + .attr("aten") + .attr("sym_size") + .attr("default") + .ptr(), + "torch.ops.aten"); + + if (out == Py_None) { + return self->sym_sizes_default(); + } + // We need to squeeze SymIntNodes and ints into `SymInts` + // since it's a format `sym_sizes()` are stored in + TORCH_CHECK( + py::isinstance(out) || py::isinstance(out), + "Symshape must be a list or a tuple"); + py::list symints; + for (auto it = out.begin(); it != out.end(); it++) { + auto elm = *it; + auto si = torch::is_symint_node(elm) + ? elm.cast()->toSymInt() + : c10::SymInt{py::cast(elm)}; + // TODO: the buffer will need to be made owning later + symints.append(si.as_int_unchecked()); + } + + auto result = values_from_buffer(self, symints); + c10::SymInt* start = (c10::SymInt*)result[0]; + int64_t len = result[1]; + + return c10::SymIntArrayRef(start, len); + END_HANDLE_TH_ERRORS_PYBIND +} + +c10::Layout concrete_layout_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { + pybind11::gil_scoped_acquire gil; + at::impl::MaybeSetTLSOnEntryGuard guard; + + auto out = torchDispatchFromTensorImpl( + self, + "layout", + py::module::import("torch") + .attr("ops") + .attr("prim") + .attr("layout") + .attr("default") + .ptr(), + "torch.ops.prim"); + + TORCH_CHECK( + THPLayout_Check(out.ptr()), + "layout returned invalid type ", + py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), + ", expected Layout"); + + return toLayout(out.ptr()); +} + } // anonymous namespace diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 92d00be2da33d..2e9a83f617a26 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 4dc62f007a55b..912f165de8357 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -105,22 +105,28 @@ inline Variable valueToTensor( } at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; + Scalar scalar; if (THPUtils_checkLong(value) || PyBool_Check(value)) { - return at::indexing::scalarToTensor( - Scalar(THPUtils_unpackLong(value)), options, device); - } - if (PyFloat_Check(value)) { - return at::indexing::scalarToTensor( - Scalar(THPUtils_unpackDouble(value)), options, device); - } - if (PyComplex_Check(value)) { - return at::indexing::scalarToTensor( - Scalar(THPUtils_unpackComplexDouble(value)), options, device); + scalar = Scalar(THPUtils_unpackLong(value)); + } else if (PyFloat_Check(value)) { + scalar = Scalar(THPUtils_unpackDouble(value)); + } else if (PyComplex_Check(value)) { + scalar = Scalar(THPUtils_unpackComplexDouble(value)); + } else { + throw TypeError( + "can't assign a %s to a %s", + Py_TYPE(value)->tp_name, + torch::utils::options_to_string(options).c_str()); + } + // lift_fresh is supposed to be used in situations where you are guaranteed to + // get a plain Tensor which is not true for cpu device but not for non cpu + // device + if (device == at::kCPU) { + return at::lift_fresh( + at::indexing::scalarToTensor(scalar, options, device)); + } else { + return at::indexing::scalarToTensor(scalar, options, device); } - throw TypeError( - "can't assign a %s to a %s", - Py_TYPE(value)->tp_name, - torch::utils::options_to_string(options).c_str()); } static inline void checkUnpackSlice( diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 373c26aebab0b..aa3f17f1adfa9 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -154,6 +154,40 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { return get_autograd_meta(self); } +void update_cpp_hooks_on_new_gradfn( + const at::TensorBase& self, + const std::shared_ptr& new_fn) { + // This function is called whenever the grad_fn of the tensor is + // changed. We assume here that new_fn does not yet have hooks of + // its own + // + // This function does two things: + const auto& meta = impl::get_autograd_meta(self); + TORCH_INTERNAL_ASSERT(meta); + TORCH_INTERNAL_ASSERT(new_fn); + if (!self.retains_grad()) { + // (1) reset the list when grad_fn is updated, so new hooks don't + // get erroneously registered to the old grad_fn. + // Note that the old cpp_hooks_list_ is still kept alive by the + // old grad_fn so hooks registered to the older version of the tensor + // will continue to be active. + meta->cpp_hooks_list_ = nullptr; + return; + } + // (2) If there is a retains_grad hook registered, move that from the + // old cpp_hooks_list_ to the new one + auto idx = meta->retains_grad_; + auto new_list = std::make_shared(); + new_list->push_back(std::move((*meta->cpp_hooks_list_)[idx])); + (*meta->cpp_hooks_list_)[idx] = nullptr; + meta->cpp_hooks_list_ = new_list; + // Since this is a new list, 0 is the index of the retains_grad hook + meta->retains_grad_ = 0; + std::unique_ptr hook_ptr( + new CppFunctionPreHook(meta->cpp_hooks_list_, self.output_nr())); + new_fn->add_pre_hook(std::move(hook_ptr)); +} + void rebase_history(const Variable& self, Edge gradient_edge) { TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); auto diff_view_meta = get_view_autograd_meta(self); @@ -181,6 +215,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) { } set_gradient_edge(self, std::move(gradient_edge)); + // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly + torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn()); } void create_cpp_hook(const at::TensorBase& self) { @@ -495,7 +531,7 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const { if (self.is_leaf()) { // no-op for leaves return; } - if (impl::get_autograd_meta(self)->retains_grad_) { + if (impl::get_autograd_meta(self)->retains_grad_ != -1) { return; } c10::weak_intrusive_ptr weak_self(self.getIntrusivePtr()); @@ -517,13 +553,13 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const { } }; - at::OptionalTensorRef(self)->register_hook(retain_grad_hook); - impl::get_autograd_meta(self)->retains_grad_ = true; + auto idx = at::OptionalTensorRef(self)->register_hook(retain_grad_hook); + impl::get_autograd_meta(self)->retains_grad_ = idx; } bool VariableHooks::retains_grad(const at::TensorBase& self) const { if (impl::get_autograd_meta(self)) { - return impl::get_autograd_meta(self)->retains_grad_; + return impl::get_autograd_meta(self)->retains_grad_ != -1; } else { return false; } @@ -656,11 +692,15 @@ const std::shared_ptr& VariableHooks::grad_fn( torch::autograd::collect_next_edges(view_info.base_)); fn->add_input_metadata( view_info.base_.options(), - self.sizes(), // Note: sizes(), not base_.sizes(), is intentional + self.sym_sizes(), // Note: sizes(), not base_.sizes(), is + // intentional self.unsafeGetTensorImpl()->is_python_dispatch()); diff_view_meta->grad_fn_ = std::move(fn); } diff_view_meta->set_attr_version(current_version); + + torch::autograd::impl::update_cpp_hooks_on_new_gradfn( + self, diff_view_meta->grad_fn_); } return diff_view_meta->grad_fn_; } diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index c20f6b8d80f7e..b9603696dce24 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -222,8 +222,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { // Only meaningful on leaf variables (must be false otherwise) bool requires_grad_; - // Only meaningful on non-leaf variables (must be false otherwise) - bool retains_grad_; + // Only meaningful on non-leaf variables (must be -1 otherwise) + // The value of retains_grad_ indicates the index of it in cpp_hooks_list_ + // A value of -1 indicates that the tensor does not retain grad + int64_t retains_grad_; bool is_view_; @@ -281,7 +283,7 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { Edge gradient_edge = Edge()) { grad_fn_ = std::move(gradient_edge.function); requires_grad_ = false; - retains_grad_ = false; + retains_grad_ = -1; is_view_ = false; output_nr_ = gradient_edge.input_nr; diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index da81879bccebc..72b740cecfe19 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 86b7418dfa86b..6b80e5a30f18d 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -1,4 +1,10 @@ #include +#include +#if AT_CUDNN_ENABLED() + +#include + +#endif #include #include #include @@ -11,6 +17,7 @@ #ifdef USE_NCCL #include #endif +#include #include #include @@ -41,7 +48,7 @@ static bool in_bad_fork = false; // True for children forked after cuda init // Called in the forked child if cuda has already been initialized static void forked_child() { in_bad_fork = true; - torch::utils::set_run_yet_variable_to_false(); + torch::utils::set_requires_cuda_init(true); } #endif @@ -50,8 +57,8 @@ static void forked_child() { // has some working functions (e.g. device_count) but cannot fully initialize. static void poison_fork() { #ifndef WIN32 - static std::once_flag flag; - std::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); + static c10::once_flag flag; + c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); #endif } @@ -726,6 +733,32 @@ static PyObject* THCPModule_isCurrentStreamCapturing_wrap( END_HANDLE_TH_ERRORS } +PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "set_benchmark_limit_cudnn expects an int, " + "but got %s", + THPUtils_typename(arg)); + auto benchmark_limit = static_cast(THPUtils_unpackLong(arg)); +#if defined(USE_ROCM) + TORCH_WARN_ONCE( + "cuDNN Benchmark limit is not supported in MIOpen and will have no effect."); +#endif +#if AT_CUDNN_ENABLED() +#if HAS_CUDNN_V8() + at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit); +#else + TORCH_WARN_ONCE( + "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect."); +#endif +#endif + Py_RETURN_NONE; +} + +PyObject* THCPModule_benchmarkLimitCuDNN(PyObject* _unused, PyObject* noargs) { + return THPUtils_packInt32(at::globalContext().benchmarkLimitCuDNN()); +} + // NOLINTNEXTLINE(modernize-avoid-c-arrays, // cppcoreguidelines-avoid-non-const-global-variables, // cppcoreguidelines-avoid-c-arrays) @@ -813,6 +846,14 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_cudaJiteratorCompileAndLaunchKernel, METH_VARARGS, nullptr}, + {"_cuda_get_cudnn_benchmark_limit", + THCPModule_benchmarkLimitCuDNN, + METH_NOARGS, + nullptr}, + {"_cuda_set_cudnn_benchmark_limit", + THCPModule_setBenchmarkLimitCuDNN, + METH_O, + nullptr}, #ifdef USE_NCCL {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr}, {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr}, diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index b604529f035c9..e5b970bf16058 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/cuda/THCP.h b/torch/csrc/cuda/THCP.h index 12b7042ebb49d..697a66dc3ee91 100644 --- a/torch/csrc/cuda/THCP.h +++ b/torch/csrc/cuda/THCP.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #endif diff --git a/torch/csrc/cuda/Tensor.cpp b/torch/csrc/cuda/Tensor.cpp index 4d6bd08ef0402..beb81f187a6e2 100644 --- a/torch/csrc/cuda/Tensor.cpp +++ b/torch/csrc/cuda/Tensor.cpp @@ -11,12 +11,6 @@ #include #include -#include #include #include - -// generic_include THC torch/csrc/generic/Tensor.cpp - -#include -#include // clang-format on diff --git a/torch/csrc/cuda/override_macros.h b/torch/csrc/cuda/override_macros.h deleted file mode 100644 index ffd1efd26e46a..0000000000000 --- a/torch/csrc/cuda/override_macros.h +++ /dev/null @@ -1,34 +0,0 @@ -#include - -#define THPTensorPtr THCPTensorPtr - -#define THWTensor THCTensor -#define THWTensor_(NAME) THCTensor_(NAME) - -#define THPTensor_(NAME) TH_CONCAT_4(THCP, Real, Tensor_, NAME) -#define THPTensor_stateless_(NAME) \ - TH_CONCAT_4(THCP, Real, Tensor_stateless_, NAME) -#define THPTensor THCPTensor -#define THPTensorStr THCPTensorStr -#define THPTensorBaseStr THCPTensorBaseStr -#define THPTensorClass THCPTensorClass -#define THPTensorType THCPTensorType - -#define THPTensorStatelessType THCPTensorStatelessType -#define THPTensorStateless THCPTensorStateless - -#define THSPTensorPtr THCSPTensorPtr - -#define THSPTensor_(NAME) TH_CONCAT_4(THCSP, Real, Tensor_, NAME) -#define THSPTensor_stateless_(NAME) \ - TH_CONCAT_4(THCSP, Real, Tensor_stateless_, NAME) -#define THSPTensor THCSPTensor -#define THSPTensorStr THCSPTensorStr -#define THSPTensorBaseStr THCSPTensorBaseStr -#define THSPTensorClass THCSPTensorClass -#define THSPTensorType THCSPTensorType - -#define THSPTensorStatelessType THCSPTensorStatelessType -#define THSPTensorStateless THCSPTensorStateless - -#define TH_GENERIC_FILE THC_GENERIC_FILE diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 0eab79dc7a800..9e520c24e246b 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/cuda/restore_macros.h b/torch/csrc/cuda/restore_macros.h deleted file mode 100644 index e76c600f12e5e..0000000000000 --- a/torch/csrc/cuda/restore_macros.h +++ /dev/null @@ -1,12 +0,0 @@ -#include - -#define THWTensor TH_CONCAT_3(TH, Real, Tensor) -#define THWTensor_(NAME) TH_CONCAT_4(TH, Real, Tensor_, NAME) - -#define THPTensor TH_CONCAT_3(THP, Real, Tensor) -#define THPTensorStr TH_CONCAT_STRING_3(torch., Real, Tensor) -#define THPTensorClass TH_CONCAT_3(THP, Real, TensorClass) -#define THPTensor_(NAME) TH_CONCAT_4(THP, Real, Tensor_, NAME) - -#define THWTensorPtr TH_CONCAT_3(TH, Real, TensorPtr) -#define THPTensorPtr TH_CONCAT_3(THP, Real, TensorPtr) diff --git a/torch/csrc/cuda/undef_macros.h b/torch/csrc/cuda/undef_macros.h deleted file mode 100644 index e69b2d374ebda..0000000000000 --- a/torch/csrc/cuda/undef_macros.h +++ /dev/null @@ -1,30 +0,0 @@ -#undef TH_GENERIC_FILE - -#undef THPTensor_ -#undef THPTensor_stateless_ -#undef THPTensor -#undef THPTensorStr -#undef THPTensorBaseStr -#undef THPTensorClass - -#undef THPTensorStatelessType -#undef THPTensorStateless -#undef THPTensorType - -#undef THPTensorPtr - -#undef THWTensor -#undef THWTensor_ - -#undef THSPTensor_ -#undef THSPTensor_stateless_ -#undef THSPTensor -#undef THSPTensorStr -#undef THSPTensorBaseStr -#undef THSPTensorClass - -#undef THSPTensorStatelessType -#undef THSPTensorStateless -#undef THSPTensorType - -#undef THSPTensorPtr diff --git a/torch/csrc/cuda/utils.cpp b/torch/csrc/cuda/utils.cpp index b3b33230f3639..fd2891e323953 100644 --- a/torch/csrc/cuda/utils.cpp +++ b/torch/csrc/cuda/utils.cpp @@ -3,8 +3,6 @@ #include #include -#include - #ifdef USE_CUDA // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use // whatever the current stream of the device the input is associated with was. diff --git a/torch/csrc/cuda/utils.h b/torch/csrc/cuda/utils.h deleted file mode 100644 index fa29394b563ab..0000000000000 --- a/torch/csrc/cuda/utils.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef THCP_UTILS_H -#define THCP_UTILS_H - -#define THCPUtils_(NAME) TH_CONCAT_4(THCP, Real, Utils_, NAME) - -#define THCTensorPtr TH_CONCAT_3(THC, Real, TensorPtr) -#define THCPStoragePtr TH_CONCAT_3(THCP, Real, StoragePtr) -#define THCPTensorPtr TH_CONCAT_3(THCP, Real, TensorPtr) - -#define THCSTensorPtr TH_CONCAT_3(THCS, Real, TensorPtr) -#define THCSPTensorPtr TH_CONCAT_3(THCSP, Real, TensorPtr) - -#include - -#define THC_GENERIC_FILE "torch/csrc/generic/utils.h" -#include -#endif diff --git a/torch/csrc/deploy/interpreter/interpreter_impl.cpp b/torch/csrc/deploy/interpreter/interpreter_impl.cpp index 2af33582aa6df..bca5e54083fff 100644 --- a/torch/csrc/deploy/interpreter/interpreter_impl.cpp +++ b/torch/csrc/deploy/interpreter/interpreter_impl.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/deploy/test_deploy_lib.cpp b/torch/csrc/deploy/test_deploy_lib.cpp index a6089353d6435..cac0b539c0434 100644 --- a/torch/csrc/deploy/test_deploy_lib.cpp +++ b/torch/csrc/deploy/test_deploy_lib.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include diff --git a/torch/csrc/deploy/test_deploy_python_ext.cpp b/torch/csrc/deploy/test_deploy_python_ext.cpp index 28d964f490317..2c7748b6f46b1 100644 --- a/torch/csrc/deploy/test_deploy_python_ext.cpp +++ b/torch/csrc/deploy/test_deploy_python_ext.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 279ca57ab5072..fa41318ecde1a 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -69,7 +69,7 @@ namespace c10d { namespace { template -typename std::result_of::type syscall(F fn) { +typename c10::invoke_result_t syscall(F fn) { while (true) { auto rv = fn(); if (rv == -1) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index c262cedc9da8b..0e4a7df03dc72 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,5 +1,7 @@ #include +#include + #ifdef USE_C10D_NCCL #include @@ -24,10 +26,10 @@ ncclComm_t NCCLComm::getNcclComm() { } std::string getNcclVersion() { - static std::once_flag ncclGetVersionFlag; + static c10::once_flag ncclGetVersionFlag; static std::string versionString; - std::call_once(ncclGetVersionFlag, []() { + c10::call_once(ncclGetVersionFlag, []() { int version; ncclResult_t status = ncclGetVersion(&version); // can't compute the version if call did not return successfully or version diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 18218a4cb1854..7ee87ad9ef771 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -5,12 +5,12 @@ namespace c10d { namespace { -c10::intrusive_ptr broadcast( - const c10::intrusive_ptr& process_group, +c10::intrusive_ptr broadcast_( at::TensorList tensors, - int64_t root_rank = 0, - int64_t root_tensor = 0, - int64_t timeout = -1) { + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t root_tensor, + int64_t timeout) { auto tensor_vec = tensors.vec(); return process_group->broadcast( tensor_vec, @@ -18,6 +18,126 @@ c10::intrusive_ptr broadcast( root_rank, root_tensor, std::chrono::milliseconds(timeout)}); } +c10::intrusive_ptr allreduce_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + return process_group->allreduce( + tensor_vec, + AllreduceOptions{ + static_cast(reduce_op), + std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr allgather_( + const std::vector>& output_tensors, + const std::vector& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + return process_group->allgather( + const_cast>&>(output_tensors), + const_cast&>(input_tensors), + AllgatherOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr reduce_scatter_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t reduce_op, + int64_t timeout) { + return process_group->reduce_scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ReduceScatterOptions{ + static_cast(reduce_op), + std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr reduce_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t reduce_op, + int64_t root_rank, + int64_t root_tensor, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + return process_group->reduce( + tensor_vec, + ReduceOptions{ + static_cast(reduce_op), + root_rank, + root_tensor, + std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr gather_( + const std::vector>& output_tensors, + const std::vector& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + return process_group->gather( + const_cast>&>(output_tensors), + const_cast&>(input_tensors), + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr scatter_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + return process_group->scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ScatterOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr alltoall_( + at::TensorList output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + return process_group->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr barrier( + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->barrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr send( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t dstRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->send( + tensor_vec, static_cast(dstRank), static_cast(tag)); +} + +c10::intrusive_ptr recv_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t srcRank, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->recv( + tensor_vec, static_cast(srcRank), static_cast(tag)); +} + TORCH_LIBRARY(c10d, m) { // The following ProcessGroup and Work definations are more like declarations. // They don't expose the details of the two classes into TorchScript. @@ -27,8 +147,34 @@ TORCH_LIBRARY(c10d, m) { // enable // __torch_dispatch__. m.def( - "broadcast", - dispatch(c10::DispatchKey::CompositeExplicitAutograd, broadcast)); + "broadcast_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, broadcast_)); + m.def( + "allreduce_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_)); + m.def( + "allgather_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); + m.def( + "reduce_scatter_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); + m.def( + "reduce_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_)); + m.def( + "gather_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, gather_)); + m.def( + "scatter_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, scatter_)); + m.def( + "alltoall_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_)); + m.def( + "barrier", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier)); + m.def("send", dispatch(c10::DispatchKey::CompositeExplicitAutograd, send)); + m.def("recv_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_)); } } // namespace @@ -38,24 +184,201 @@ c10::intrusive_ptr broadcast( const c10::intrusive_ptr& process_group, at::TensorList tensors, const BroadcastOptions& opts) { - auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::broadcast", "") - .typed( - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - at::TensorList, - int64_t, - int64_t, - int64_t)>(); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::broadcast_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t, + int64_t)>(); // It's awakward to unbox the opts here and box them again in the custom C++ // op. But it's also complicated to make opts as a CustomClassHolder. Leave it // as it is now. return op.call( + tensors, + process_group, + opts.rootRank, + opts.rootTensor, + opts.timeout.count()); +} + +c10::intrusive_ptr allreduce( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allreduce_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call( + tensors, process_group, + static_cast(opts.reduceOp), + opts.timeout.count()); +} + +c10::intrusive_ptr allgather( + const c10::intrusive_ptr& process_group, + const std::vector>& output_tensors, + const std::vector& input_tensors, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_", "") + .typed( + const std::vector>&, + const std::vector&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t)>(); + return op.call( + output_tensors, input_tensors, process_group, opts.timeout.count()); +} + +c10::intrusive_ptr reduce_scatter( + const c10::intrusive_ptr& process_group, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ReduceScatterOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::reduce_scatter_", "") + .typed( + const std::vector&, + const std::vector>&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call( + output_tensors, + input_tensors, + process_group, + static_cast(opts.reduceOp), + opts.timeout.count()); +} + +c10::intrusive_ptr reduce( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const ReduceOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::reduce_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t, + int64_t, + int64_t)>(); + return op.call( tensors, + process_group, + static_cast(opts.reduceOp), opts.rootRank, opts.rootTensor, opts.timeout.count()); } +c10::intrusive_ptr gather( + const c10::intrusive_ptr& process_group, + const std::vector>& output_tensors, + const std::vector& input_tensors, + const GatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::gather_", "") + .typed( + const std::vector>&, + const std::vector&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call( + output_tensors, + input_tensors, + process_group, + opts.rootRank, + opts.timeout.count()); +} + +c10::intrusive_ptr scatter( + const c10::intrusive_ptr& process_group, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ScatterOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::scatter_", "") + .typed( + const std::vector&, + const std::vector>&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call( + output_tensors, + input_tensors, + process_group, + opts.rootRank, + opts.timeout.count()); +} + +c10::intrusive_ptr alltoall( + const c10::intrusive_ptr& process_group, + at::TensorList output_tensors, + at::TensorList input_tensors, + const AllToAllOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::alltoall_", "") + .typed( + at::TensorList, + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t)>(); + return op.call( + output_tensors, input_tensors, process_group, opts.timeout.count()); +} + +c10::intrusive_ptr barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::barrier", "") + .typed( + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const std::vector&, + int64_t)>(); + return op.call(process_group, opts.device_ids, opts.timeout.count()); +} + +c10::intrusive_ptr send( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t dstRank, + int64_t tag) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::send", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call(tensors, process_group, dstRank, tag); +} + +c10::intrusive_ptr recv( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t srcRank, + int64_t tag) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::recv_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t, + int64_t)>(); + return op.call(tensors, process_group, srcRank, tag); +} + } // namespace ops } // namespace c10d diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index 2c1a807911806..1d2b7b343c0f4 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -6,9 +6,71 @@ namespace c10d { namespace ops { -// Belows are essentially ProcessGroup's corresponding ops but routed to the dispatcher. -TORCH_API c10::intrusive_ptr broadcast(const c10::intrusive_ptr& process_group, - at::TensorList tensors, const BroadcastOptions& opts = {}); +// Below are essentially ProcessGroup's corresponding ops but routed to the +// dispatcher. To be noted, it's a convention to use at::TensorList to represent +// const std::vector&. However, const std::vector& is +// used whenever the API accepts std::vector>& to keep +// consistency. +TORCH_API c10::intrusive_ptr broadcast( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const BroadcastOptions& opts = {}); + +TORCH_API c10::intrusive_ptr allreduce( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceOptions& opts = {}); + +TORCH_API c10::intrusive_ptr allgather( + const c10::intrusive_ptr& process_group, + const std::vector>& output_tensors, + const std::vector& input_tensors, + const AllgatherOptions& opts = {}); + +TORCH_API c10::intrusive_ptr reduce_scatter( + const c10::intrusive_ptr& process_group, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ReduceScatterOptions& opts = {}); + +TORCH_API c10::intrusive_ptr reduce( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const ReduceOptions& opts = {}); + +TORCH_API c10::intrusive_ptr gather( + const c10::intrusive_ptr& process_group, + const std::vector>& output_tensors, + const std::vector& input_tensors, + const GatherOptions& opts = {}); + +TORCH_API c10::intrusive_ptr scatter( + const c10::intrusive_ptr& process_group, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ScatterOptions& opts = {}); + +TORCH_API c10::intrusive_ptr alltoall( + const c10::intrusive_ptr& process_group, + at::TensorList output_tensors, + at::TensorList input_tensors, + const AllToAllOptions& opts = {}); + +TORCH_API c10::intrusive_ptr barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts = {}); + +TORCH_API c10::intrusive_ptr send( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t dstRank, + int64_t tag); + +TORCH_API c10::intrusive_ptr recv( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t srcRank, + int64_t tag); } // namespace ops } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index f31a7d8c8153b..fde76d9f50399 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -191,4 +191,56 @@ void ProcessGroup::init() { fmt::format("c10d.process_group_{}", getBackendName())); } +class FutureWrappingWork : public ProcessGroup::Work { + public: + FutureWrappingWork(c10::intrusive_ptr fut) + : Work(), _fut(fut) {} + + ~FutureWrappingWork() {} + + bool isCompleted() override { + return _fut->completed(); + } + + bool isSuccess() const override { + return _fut->hasValue(); + } + + std::exception_ptr exception() const override { + return _fut->exception_ptr(); + } + + int sourceRank() const override { + TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented"); + } + + std::vector result() override { + return _fut->value().toPyObjectHolder()->extractTensors(); + } + + bool wait(std::chrono::milliseconds timeout) override { + // FIXME + TORCH_CHECK( + timeout == kNoTimeout, + "FutureWrappingWork::wait() with finite timeout not implemented"); + _fut->wait(); + return true; + } + + void abort() override { + TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented"); + } + + c10::intrusive_ptr getFuture() override { + return _fut; + } + + private: + c10::intrusive_ptr _fut; +}; + +c10::intrusive_ptr ProcessGroup::Work::create_from_future( + c10::intrusive_ptr future) { + return c10::make_intrusive(future); +} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index b6787092c88e4..f81358a87becc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -149,6 +149,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { OpType retrieveOpType(); + static c10::intrusive_ptr create_from_future(c10::intrusive_ptr); + protected: // Completes the work object and optionally sets the exception in a // thread-safe manner. Notifies all waiting condition variables as well. diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index e2f28507a8db6..c1a5c83c48c93 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -221,7 +221,7 @@ void ProcessGroupMPI::AsyncWork::populateException() { int ProcessGroupMPI::mpiThreadSupport_ = 0; std::mutex ProcessGroupMPI::pgGlobalMutex_; // We only want to initialize once -std::once_flag ProcessGroupMPI::onceFlagInitMPI; +c10::once_flag ProcessGroupMPI::onceFlagInitMPI; void ProcessGroupMPI::mpiExit() { std::unique_lock globalLock(pgGlobalMutex_); @@ -230,7 +230,7 @@ void ProcessGroupMPI::mpiExit() { void ProcessGroupMPI::initMPIOnce() { // Initialize MPI environment - std::call_once(onceFlagInitMPI, []() { + c10::call_once(onceFlagInitMPI, []() { MPI_CHECK(MPI_Init_thread( nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_)); if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp index 8067b13b7bf6d..93bb3113f00c2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp @@ -17,6 +17,8 @@ #include #include +#include + #include namespace c10d { @@ -256,7 +258,7 @@ class TORCH_API ProcessGroupMPI : public ProcessGroup { // Global states static void initMPIOnce(); static void mpiExit(); - static std::once_flag onceFlagInitMPI; + static c10::once_flag onceFlagInitMPI; static std::mutex pgGlobalMutex_; static int mpiThreadSupport_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 6aee5c0a9ca49..ae04f9bc702d3 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -600,8 +601,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( << options_->is_high_priority_stream; #ifdef USE_NCCL_WITH_UCC - static std::once_flag initialize_ucc_lib_flag; - std::call_once(initialize_ucc_lib_flag, [&] { + static c10::once_flag initialize_ucc_lib_flag; + c10::call_once(initialize_ucc_lib_flag, [&] { uccLib_ = loadTorchUCC(); if (uccLib_ != nullptr) { LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded"; diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp new file mode 100644 index 0000000000000..191ba4b2ddd75 --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp @@ -0,0 +1,1870 @@ +#ifdef USE_C10D_UCC + +#include +#include +#include +#include +#include + +namespace c10d { + +namespace { +constexpr int64_t kBusyWaitMillis = 10; + +const std::map ucs_mtype_map = { + {c10::kCPU, UCS_MEMORY_TYPE_HOST}, + {c10::kCUDA, UCS_MEMORY_TYPE_CUDA}, +}; + +ucs_memory_type_t to_ucs_memType(c10::DeviceType _c10_type) { + if (ucs_mtype_map.find(_c10_type) != ucs_mtype_map.end()) + return ucs_mtype_map.at(_c10_type); + else + return UCS_MEMORY_TYPE_UNKNOWN; +} + +const std::map ucc_mtype_map = { + {c10::kCPU, UCC_MEMORY_TYPE_HOST}, + {c10::kCUDA, UCC_MEMORY_TYPE_CUDA}, +}; + +ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) { + if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end()) + return ucc_mtype_map.at(_c10_type); + else + return UCC_MEMORY_TYPE_UNKNOWN; +} + +const std::map ucc_dtype_map = { + {at::kByte, UCC_DT_UINT8}, + {at::kChar, UCC_DT_INT8}, + {at::kHalf, UCC_DT_FLOAT16}, + {at::kBFloat16, UCC_DT_BFLOAT16}, + {at::kDouble, UCC_DT_FLOAT64}, + {at::kFloat, UCC_DT_FLOAT32}, + {at::kInt, UCC_DT_INT32}, + {at::kLong, UCC_DT_INT64}, + {at::kBool, UCC_DT_UINT8}, +}; + +ucc_datatype_t to_ucc_dType(at::Tensor _tensor) { + if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) { + TORCH_CHECK( + false, "Size of Boolean type larger than 1 is not supported in UCC"); + } + try { + return ucc_dtype_map.at(_tensor.scalar_type()); + } catch (const std::out_of_range& e) { + TORCH_CHECK(false, "Not supported data type for UCC"); + } +} + +const std::map ucc_op_map = { + {ReduceOp::SUM, UCC_OP_SUM}, + {ReduceOp::PRODUCT, UCC_OP_PROD}, + {ReduceOp::MIN, UCC_OP_MIN}, + {ReduceOp::MAX, UCC_OP_MAX}, + {ReduceOp::BAND, UCC_OP_BAND}, + {ReduceOp::BOR, UCC_OP_BOR}, + {ReduceOp::BXOR, UCC_OP_BXOR}, + {ReduceOp::AVG, UCC_OP_AVG}, +}; + +ucc_reduction_op_t to_ucc_reduceOp( + const ReduceOp _op, + const at::ScalarType _dt) { + if (_dt == at::kBool) { + if (_op == ReduceOp::SUM) { + // bitwise or + return UCC_OP_MAX; + } else if (_op == ReduceOp::PRODUCT) { + // bitwise and + return UCC_OP_MIN; + } else if (_op == ReduceOp::AVG) { + TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs"); + } + } + + try { + return ucc_op_map.at(_op); + } catch (const std::out_of_range& e) { + TORCH_CHECK(false, "Not supported ReduceOp for UCC"); + } +} + +struct torch_ucc_config_t { + c10::once_flag flag; + std::array blocking_wait; + bool enable_profiling; + bool enable_comms_logger; + bool use_future; + // Sharing UCC communicator among multiple PGs to save resource. + bool shared_comm; + // Using allgatherv to achieve allgather, without flattening the list of + // (potentially non-contiguous) tensors. + bool use_allgatherv; + bool enable_health_check; +} torch_ucc_config; + +// TODO: support UCC_BLOCKING_WAIT that applies to all collectives. +std::map torch_ucc_envs_map = { + {"TORCH_UCC_ALLGATHER_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_ALLGATHER_BASE_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_ALLREDUCE_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_ALLTOALL_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_BCAST_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_GATHER_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_REDUCE_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_REDUCE_SCATTER_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_SCATTER_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_SEND_BLOCKING_WAIT", "0"}, + {"TORCH_UCC_RECV_BLOCKING_WAIT", "0"}, + + {"TORCH_UCC_USE_FUTURE", "1"}, + {"TORCH_UCC_PROFILING_ENABLE", "0"}, + {"TORCH_UCC_SHARED_COMM", "1"}, + {"TORCH_UCC_USE_ALLGATHERV", "0"}, + {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"}, + {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"}, +}; + +} // namespace + +void read_confg() { + // default configuration + torch_ucc_config.blocking_wait.fill(true); + torch_ucc_config.enable_profiling = false; + torch_ucc_config.use_future = true; + torch_ucc_config.shared_comm = false; + torch_ucc_config.use_allgatherv = false; + torch_ucc_config.enable_health_check = false; + torch_ucc_config.enable_comms_logger = false; + + // read all torch_ucc env. variables and update the map + char* env; + for (auto& torch_ucc_env : torch_ucc_envs_map) { + env = std::getenv(torch_ucc_env.first.c_str()); + if (env) { + torch_ucc_envs_map[torch_ucc_env.first] = std::string(env); + } + } + +#define BUILD_BLOCKING_CFG(op, str) \ + (torch_ucc_config.blocking_wait[(std::uint8_t)op] = \ + std::stoi(torch_ucc_envs_map.at(str))) + + BUILD_BLOCKING_CFG(OpType::ALLGATHER, "TORCH_UCC_ALLGATHER_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG( + OpType::_ALLGATHER_BASE, "TORCH_UCC_ALLGATHER_BASE_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::ALLREDUCE, "TORCH_UCC_ALLREDUCE_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::ALLTOALL_BASE, "TORCH_UCC_ALLTOALL_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::BROADCAST, "TORCH_UCC_BCAST_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::GATHER, "TORCH_UCC_GATHER_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::REDUCE, "TORCH_UCC_REDUCE_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG( + OpType::REDUCE_SCATTER, "TORCH_UCC_REDUCE_SCATTER_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::SCATTER, "TORCH_UCC_SCATTER_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::SEND, "TORCH_UCC_SEND_BLOCKING_WAIT"); + BUILD_BLOCKING_CFG(OpType::RECV, "TORCH_UCC_RECV_BLOCKING_WAIT"); +#undef BUILD_BLOCKING_CFG + + torch_ucc_config.use_future = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE")); + torch_ucc_config.enable_profiling = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_PROFILING_ENABLE")); + torch_ucc_config.shared_comm = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM")); + torch_ucc_config.use_allgatherv = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV")); + torch_ucc_config.enable_health_check = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK")); + torch_ucc_config.enable_comms_logger = + std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER")); +} + +void check_device(c10::Device dev1, c10::Device dev2) { + if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) { + throw std::runtime_error("ProcessGroupUCC multidevice is not supported"); + } +} + +void check_tensor(const std::vector& tensors) { + if (tensors.size() != 1) { + throw std::runtime_error( + "ProcessGroupUCC takes 1 tensor. Got " + + std::to_string(tensors.size()) + ". "); + } + if (!tensors[0].is_contiguous()) { + throw std::runtime_error( + "ProcessGroupUCC input tensor has to be contiguous"); + } + if (tensors[0].is_sparse()) { + throw std::runtime_error("ProcessGroupUCC input tensor has to be dense"); + } + // TODO: check cuda case +} + +ProcessGroupUCC::WorkUCC::~WorkUCC() { +#ifdef USE_CUDA + if (fence && ep) { + std::lock_guard lock(ep->event_pool_mutex); + ep->event_pool.push(std::move(fence)); + } +#endif +} + +void ProcessGroupUCC::WorkUCC::setException() { + if (exception() || !entry_) { + return; + } + exception_ = entry_->eptr_; +} + +void ProcessGroupUCC::WorkUCC::setAndThrowException() { + setException(); + if (exception()) { + std::rethrow_exception(exception()); + } +} + +bool ProcessGroupUCC::WorkUCC::isCompleted() { + if (!entry_) { + return true; + } + setException(); + // status_ <= 0 to avoid listing all possible status codes. The main thread + // needs to be unblocked when UCC (in progress thread) returns success (== 0) + // or any error code (< 0). + return exception() || entry_->status_ <= 0; +} + +bool ProcessGroupUCC::WorkUCC::isSuccess() const { + if (!entry_) { + return true; + } + return !exception() && entry_->status_ == 0; +} + +bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { + if (torch_ucc_config.enable_comms_logger && logger_) { + logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_); + } +#ifdef USE_CUDA + if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { + // block user stream + setAndThrowException(); + fence->block(at::cuda::getCurrentCUDAStream()); + return true; + } +#endif + // wait for complete. For blocking case, the main thread will be blocked in + // this loop until the progress thread changes the status of this request. + // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The + // main thread will throw out the exception then. There is no "abort" + // function in UCC currently. + while (!isCompleted()) + ; + setAndThrowException(); + // manually call profiling end callbacks if they are set, + // since progress thread does not own WorkUCC + if (ProcessGroup::Work::recordFunctionEndCallback_) { + ProcessGroup::Work::recordFunctionEndCallback_(); + ProcessGroup::Work::recordFunctionEndCallback_ = nullptr; + } + return true; +} + +c10::intrusive_ptr ProcessGroupUCC::WorkUCC::getFuture() { + return future_; +} + +std::vector ProcessGroupUCC::WorkUCC::result() { + return *outputs_; +} + +void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) { + ucc_status_t status = UCC_OK; + + if (request_ != nullptr) { + status = request_->status; + comm_->free_request(request_); + } + if (eptr) { + eptr_ = eptr; + } else { + status_ = status; + } + if (future_) { + if (eptr) { + future_->setError(eptr); + } else { + future_->markCompleted( + c10::IValue(data ? data->dst : std::vector())); + } + } +} + +Comm::Comm( + const c10::intrusive_ptr& logger_, + std::shared_ptr oob_, + c10::Device dev, + bool is_health_check) + : logger(logger_), + oob(oob_), + ucx_comm(oob->size, logger), + ucc_comm(oob, logger), + finalize_phase( + is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE), + cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) { + if (dev.is_cuda()) { + cuda_device_index = dev.index(); + } + stop_progress_loop = false; + collective_inprogress = false; + progress_thread = std::thread(&Comm::progress_loop, this); +#ifdef _GNU_SOURCE + pthread_setname_np(progress_thread.native_handle(), "ucc-progress"); +#endif +} + +Comm::~Comm() { + std::unique_lock lock(mutex); + queue_consume_cv.wait( + lock, [&] { return progress_queue.empty() && !collective_inprogress; }); + stop_progress_loop = true; + lock.unlock(); + queue_produce_cv.notify_all(); + progress_thread.join(); +} + +std::shared_ptr Comm::get_comm( + uint32_t& id, + c10::Device dev, + std::shared_ptr oob, + const c10::intrusive_ptr& logger, + bool is_health_check) { + static std::mutex m; + static std::weak_ptr comm; + static uint32_t comm_id; + + std::lock_guard lock(m); + id = (comm_id % TORCH_UCX_MAX_COMM); + + std::string group_id = "group_id"; + if (is_health_check) { + group_id = c10::str(dev.type()) + "/" + group_id; + } + + std::vector remote_comm_id; + oob->store->deleteKey(group_id + std::to_string(0)); + if (oob->rank != 0) { + std::vector val = std::vector( + reinterpret_cast(&id), + reinterpret_cast(&id) + sizeof(id)); + oob->store->set(group_id + std::to_string(oob->rank), val); + } else { + for (int i = 1; i < oob->size; i++) { + remote_comm_id = oob->store->get(group_id + std::to_string(i)); + oob->store->deleteKey(group_id + std::to_string(i)); + // Find the highest id. + id = std::max(id, *(reinterpret_cast(remote_comm_id.data()))); + } + std::vector val = std::vector( + reinterpret_cast(&id), + reinterpret_cast(&id) + sizeof(id)); + oob->store->set(group_id + std::to_string(oob->rank), val); + } + remote_comm_id = oob->store->get(group_id + std::to_string(0)); + oob->comm_id = *(reinterpret_cast(remote_comm_id.data())); + // Prepare comm_id (static variable) to the next id. + comm_id = oob->comm_id + 1; + + if (torch_ucc_config.shared_comm) { + std::shared_ptr shared_comm = comm.lock(); + if (!shared_comm) { + shared_comm = std::make_shared(logger, oob, dev, is_health_check); + comm = shared_comm; + } else { + if (dev.is_cuda() && !is_health_check) { + if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && + (shared_comm->cuda_device_index != dev.index())) { + TORCH_UCC_LOG_ERROR( + is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT, + "ucc communicator was initialized with different cuda device," + "multi device is not supported"); + throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); + } + shared_comm->cuda_device_index = dev.index(); + } + } + return shared_comm; + } else { + return std::make_shared(logger, oob, dev, is_health_check); + } +} + +void Comm::ucx_connect_eps( + std::vector& eps, + std::shared_ptr oob) { + ucp_address_t* local_addr; + size_t local_addr_len; + std::vector peer_addr; + + TORCH_UCX_CHECK( + ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len), + "failed to get worker address"); + + std::vector val = std::vector( + reinterpret_cast(local_addr), + reinterpret_cast(local_addr) + local_addr_len); + oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val); + ucp_worker_release_address(ucx_comm.worker, local_addr); + eps.resize(oob->size); + for (int i = 0; i < oob->size; i++) { + peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i))); + ucp_ep_params_t ep_params; + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + ep_params.address = reinterpret_cast(peer_addr.data()); + TORCH_UCX_CHECK( + ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])), + c10::str("failed to create endpoint with rank ", i)); + } +} + +void Comm::ucx_disconnect_eps( + std::vector& eps, + std::shared_ptr oob) { + ucs_status_t st; + + for (ucp_ep_h& ep : eps) { + ucs_status_ptr_t close_req = ucp_ep_close_nb(ep, UCP_EP_CLOSE_MODE_FLUSH); + if (UCS_PTR_IS_ERR(close_req)) { + TORCH_UCC_LOG_ERROR( + finalize_phase, "failed to close endpoint, ignore and continue..."); + return; + } + if (UCS_PTR_IS_PTR(close_req)) { + do { + ucp_worker_progress(ucx_comm.worker); + st = ucp_request_check_status(close_req); + } while (st != UCS_OK); + ucp_request_free(close_req); + } + } + if (!eps.size()) { + return; + } + try { + auto sz = (size_t)oob->store->add(oob->getKey("epclosed"), 1); + while (sz != eps.size()) { + ucp_worker_progress(ucx_comm.worker); + std::this_thread::sleep_for(std::chrono::milliseconds(kBusyWaitMillis)); + sz = (size_t)oob->store->add(oob->getKey("epclosed"), 0); + } + } catch (std::exception& ex) { + LOG(ERROR) << "(disconnect_eps) Caught error in Store Operation .. " + << "[" << ex.what() << "]"; + } +} + +ucc_coll_req_h Comm::send_nb( + ucp_ep_h ep, + void* data, + ucs_memory_type_t mtype, + size_t size, + ucp_tag_t ucp_tag) { + ucs_status_ptr_t st; + ucp_request_param_t params; + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE; + params.datatype = ucp_dt_make_contig(size); + params.memory_type = mtype; + params.cb.send = [](void* request, ucs_status_t status, void* user_data) { + static_cast(request)->status = UCC_OK; + }; + st = ucp_tag_send_nbx(ep, data, 1, ucp_tag, ¶ms); + if (UCS_PTR_IS_ERR(st)) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "failed to send message: ", ucs_status_string(UCS_PTR_STATUS(st)))); + throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st))); + } + return reinterpret_cast(st); +} + +ucc_coll_req_h Comm::recv_nb( + void* data, + ucs_memory_type_t mtype, + size_t size, + ucp_tag_t ucp_tag, + ucp_tag_t ucp_tag_mask) { + ucs_status_ptr_t st; + ucp_request_param_t params; + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE; + params.datatype = ucp_dt_make_contig(size); + params.cb.recv = [](void* request, + ucs_status_t status, + const ucp_tag_recv_info_t* info, + void* user_data) { + static_cast(request)->status = UCC_OK; + }; + params.memory_type = mtype; + st = ucp_tag_recv_nbx( + ucx_comm.worker, data, 1, ucp_tag, ucp_tag_mask, ¶ms); + if (UCS_PTR_IS_ERR(st)) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "failed to recv message: ", ucs_status_string(UCS_PTR_STATUS(st)))); + throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st))); + } + return reinterpret_cast(st); +} + +void Comm::ucc_create_team( + ucc_team_h& team, + std::shared_ptr oob) { + ucc_status_t st; + ucc_team_params_t team_params; + team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | + UCC_TEAM_PARAM_FIELD_OOB; + team_params.oob.allgather = oob_allgather; + team_params.oob.req_test = oob_allgather_test; + team_params.oob.req_free = oob_allgather_free; + team_params.oob.coll_info = oob.get(); + team_params.oob.n_oob_eps = oob->size; + team_params.oob.oob_ep = oob->rank; + team_params.ep = oob->rank; + team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; + TORCH_UCC_CHECK( + ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team), + "failed to post team create"); + do { + st = ucc_team_create_test(team); + ucc_context_progress(ucc_comm.context); + } while (st == UCC_INPROGRESS); + TORCH_UCC_CHECK(st, "failed to create UCC team"); +} + +void Comm::ucc_destroy_team(ucc_team_h& team) { + std::unique_lock lock(mutex); + queue_consume_cv.wait( + lock, [&] { return progress_queue.empty() && !collective_inprogress; }); + + ucc_status_t status; + while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) { + if (UCC_OK != status) { + TORCH_UCC_LOG_ERROR( + finalize_phase, + c10::str("ucc team destroy error: ", ucc_status_string(status))); + break; + } + } + + lock.unlock(); +} + +c10::intrusive_ptr Comm::enqueue_p2p( + OpType opType, + ucc_coll_req_h request, + const char* prof_title) { + auto work = + c10::make_intrusive(opType, prof_title, logger); + if (torch_ucc_config.use_future) { + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + } + if (request == nullptr) { + // p2p2 request completed immediately don't save it to progress queue + // and mark future completed immediately + if (torch_ucc_config.use_future) { + work->future_->markCompleted(c10::IValue(std::vector())); + } + return work; + } + auto entry = + std::make_shared(&ucx_comm, request); + work->entry_ = entry; + std::unique_lock lock(mutex); + progress_queue.push_back(entry); + lock.unlock(); + queue_produce_cv.notify_one(); + return work; +} + +void Comm::enqueue_collective( + std::unique_ptr data, + c10::intrusive_ptr work, + ucc_coll_args_t& coll, + ucc_team_h team) { + ucc_coll_req_h request; + TORCH_UCC_CHECK( + ucc_collective_init(&coll, &request, team), "failed to init collective"); + TORCH_UCC_CHECK(ucc_collective_post(request), "failed to post collective"); + + auto entry = + std::make_shared(&ucc_comm, request); + entry->data = std::move(data); + entry->future_ = work->getFuture(); + work->entry_ = entry; + std::unique_lock lock(mutex); + progress_queue.push_back(entry); + lock.unlock(); + queue_produce_cv.notify_one(); +} + +#ifdef USE_CUDA +void Comm::enqueue_cuda_collective( + std::unique_ptr data, + c10::intrusive_ptr work, + ucc_coll_args_t& coll, + ucc_team_h team, + ucc_ee_h ee) { + ucc_coll_req_h request; + TORCH_UCC_CHECK( + ucc_collective_init(&coll, &request, team), + "failed to init cuda collective"); + ucc_ev_t comp_ev, *post_ev; + comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; + comp_ev.ev_context = nullptr; + comp_ev.ev_context_size = 0; + comp_ev.req = request; + TORCH_UCC_CHECK( + ucc_collective_triggered_post(ee, &comp_ev), + "failed to post triggered collective"); + ucc_status_t st = ucc_ee_get_event(ee, &post_ev); + TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); + ucc_ee_ack_event(ee, post_ev); + auto entry = + std::make_shared(&ucc_comm, request); + entry->data = std::move(data); + work->entry_ = entry; + std::unique_lock lock(mutex); + progress_queue.push_back(entry); + lock.unlock(); + queue_produce_cv.notify_one(); +} +#endif + +void Comm::progress_loop() { + std::unique_lock lock(mutex); +#ifdef USE_CUDA + bool device_set = false; +#endif + while (!stop_progress_loop) { + if (progress_queue.empty()) { + queue_produce_cv.wait(lock); + continue; + } + collective_inprogress = true; + auto work = progress_queue.front(); + progress_queue.pop_front(); + lock.unlock(); +#ifdef USE_CUDA + if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { + c10::cuda::set_device(cuda_device_index); + device_set = true; + } +#endif + std::exception_ptr eptr; + try { + while (work->request_->status > 0) { + ucc_comm.progress(); + ucx_comm.progress(); + } + if (work->request_->status < 0) { + eptr = std::make_exception_ptr( + std::runtime_error(ucc_status_string(work->request_->status))); + std::string err_log = c10::str( + "Failed to progress communication", // TODO: report exact op type or + // id? + ucc_status_string(work->request_->status)); + TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log); + } + } catch (...) { + eptr = std::current_exception(); + } + work->finalize(eptr); + work = nullptr; + collective_inprogress = false; + queue_consume_cv.notify_one(); + lock.lock(); + } +} + +ProcessGroupUCC::ProcessGroupUCC( + const c10::intrusive_ptr& store, + int rank, + int size, + std::chrono::duration timeout) + : ProcessGroup(rank, size), timeout_(timeout) { + c10::call_once(torch_ucc_config.flag, read_confg); + oob = std::make_shared(); + oob->rank = rank; + oob->size = size; + oob->store = store; + comm = nullptr; + cuda_ee = nullptr; + static uint32_t id = 0; + uint32_t pg_id = (id++ % TORCH_UCX_MAX_COMM); + + logger = c10::make_intrusive( + c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"), + TORCH_UCC_INIT); + TORCH_UCC_LOG_INFO( + TORCH_UCC_INIT, + c10::str( + "Created ProcessGroupUCC with ", + size, + " ranks, with timeout ", + timeout_.count(), + " secs")); + std::string envs = ""; + for (auto& torch_ucc_env : torch_ucc_envs_map) { + envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second); + } + TORCH_UCC_LOG_INFO( + TORCH_UCC_INIT, + c10::str( + "Successfully read and set ProcessGroupUCC env. variables as followings", + envs)); + + if (torch_ucc_config.enable_health_check) { + // Perform health check by initializing dummy communicators and destroying + // them. This will help indicate any UCC/UCX-related issues prior to the + // first collective. Run it in a separate thread and wait on CV to handle + // timeouts so that if there are hangs, the main thread can still run + // correctly. + runHealthCheck(); + } + if (torch_ucc_config.enable_comms_logger) { + logger->initCommsTracer(); + } +} + +ProcessGroupUCC::~ProcessGroupUCC() { + if (torch_ucc_config.enable_comms_logger) { + logger->flushComms(this->getRank(), this->getSize()); + } + if (comm) { + logger->setPhase(TORCH_UCC_FINALIZE); + comm->ucc_destroy_team(team); + TORCH_UCC_LOG_INFO( + TORCH_UCC_FINALIZE, "Successfully destroyed UCC library"); + comm->ucx_disconnect_eps(eps, oob); + TORCH_UCC_LOG_INFO( + TORCH_UCC_FINALIZE, "Successfully destroyed UCX library"); + try { + if (cuda_ee) { + ucc_ee_destroy(cuda_ee); + } + if ((size_t)oob->store->add(oob->getKey("ucc_pg_closed"), 1) == + eps.size()) { + std::vector val = {1}; + oob->store->set(oob->getKey("ucc_pg_finished"), val); + } else { + oob->store->wait({oob->getKey("ucc_pg_finished")}); + } + } catch (std::exception& ex) { + TORCH_UCC_LOG_INFO( + TORCH_UCC_FINALIZE, + c10::str( + "(~ProcessGroupUCC) Caught error in Store Operation .. ", + "[", + ex.what(), + "]")); + } + comm = nullptr; + } +} + +#ifdef USE_CUDA +// Return CUDA device with ordinal given by input rank. +c10::Device getCUDADeviceForRank(int rank) { + TORCH_CHECK(rank >= 0, "Invalid rank ", rank); + auto numGPUs = at::cuda::getNumGPUs(); + auto deviceIdx = static_cast(rank % numGPUs); + return c10::Device(c10::DeviceType::CUDA, deviceIdx); +} +#endif + +void ProcessGroupUCC::runHealthCheck() { + // Run health check in a separate thread and wait on CV to handle timeouts. + // This design allows us to handle hangs. + + // When size_ is 1, there is no need to do any communication at all. + if (size_ == 1) + return; + + struct HealthCheckData { + std::mutex healthCheckMutex; + std::condition_variable healthCheckCv; + bool ucxHealthCheckSuccess = false; + bool uccHealthCheckSuccess = false; + std::exception_ptr healthCheckException; + } healthCheckData; + + auto t = std::thread([&healthCheckData, this]() { + std::list devices{c10::kCPU}; +#ifdef USE_CUDA + c10::cuda::OptionalCUDAGuard gpuGuard; + if (at::cuda::is_available()) { + devices.emplace_front(getCUDADeviceForRank(rank_)); + } +#endif + for (auto device : devices) { + bool is_last_device = (device == devices.back()); + try { + auto oob = std::make_shared(); + oob->rank = this->oob->rank; + oob->size = this->oob->size; + oob->store = this->oob->store; + + std::vector eps; + ucc_team_h team = nullptr; + uint32_t comm_id; +#ifdef USE_CUDA + if (device.is_cuda()) { + gpuGuard.set_index(device.index()); + } +#endif + auto comm = Comm::get_comm(comm_id, device, oob, logger, true); + comm->ucx_connect_eps(eps, oob); + comm->ucx_disconnect_eps(eps, oob); + TORCH_UCC_LOG_INFO( + TORCH_UCC_HEALTH_CHECK, + c10::str( + "UCX library health check succeed for device ", + c10::DeviceTypeName(device.type()))); + // Mark ucx health check as complete. + if (is_last_device) { + std::lock_guard lk(healthCheckData.healthCheckMutex); + healthCheckData.ucxHealthCheckSuccess = true; + } + + comm->ucc_create_team(team, oob); + comm->ucc_destroy_team(team); + TORCH_UCC_LOG_INFO( + TORCH_UCC_HEALTH_CHECK, + c10::str( + "UCC library health check succeed for device ", + c10::DeviceTypeName(device.type()))); + // Mark ucc health check as complete. + if (is_last_device) { + std::lock_guard lk(healthCheckData.healthCheckMutex); + healthCheckData.uccHealthCheckSuccess = true; + } + + comm = nullptr; + oob = nullptr; + // Notify main thread the health check is complete. + if (is_last_device) { + healthCheckData.healthCheckCv.notify_one(); + } + } catch (const std::exception& e) { + // Populate exception ptr. + healthCheckData.healthCheckException = std::current_exception(); + // Unblock waiting main thread which will report exception. + healthCheckData.healthCheckCv.notify_one(); + } // Unknown exceptions will just cause the program to terminate. + } + }); + // We don't need to join the thread, just need to verify health check via the + // CV. Hence we detach the thread here. + t.detach(); // NOLINT + TORCH_UCC_LOG_INFO( + TORCH_UCC_HEALTH_CHECK, + c10::str( + "will wait up to ", + timeout_.count(), + " msec for UCC health check to complete.")); + std::unique_lock lock(healthCheckData.healthCheckMutex); + healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() { + return healthCheckData.ucxHealthCheckSuccess && + healthCheckData.uccHealthCheckSuccess; + }); + + if (healthCheckData.healthCheckException) { + std::rethrow_exception(healthCheckData.healthCheckException); + } + // If there is no exception, the likely culprit is a timeout/hang + TORCH_CHECK( + healthCheckData.ucxHealthCheckSuccess, + "ProcessGroupUCC: Health check failure: Failed to initialize UCX on rank ", + rank_); + TORCH_CHECK( + healthCheckData.uccHealthCheckSuccess, + "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ", + rank_); +} + +void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) { + args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; + args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT; + args.timeout = timeout_.count(); +} + +#ifdef USE_CUDA +std::unique_ptr ProcessGroupUCC::getPooledEvent() { + std::unique_ptr ev; + std::lock_guard lock(ep.event_pool_mutex); + if (ep.event_pool.empty()) { + ev = std::make_unique(); + } else { + ev = std::move(ep.event_pool.front()); + ep.event_pool.pop(); + } + return ev; +} +#endif + +template +c10::intrusive_ptr ProcessGroupUCC::collective_post( + OpType opType, + PreProcess preproc, + PostProcess postproc, + ucc_coll_args_t& coll, + std::unique_ptr data, + c10::Device dev, + std::vector& inputTensors, + std::vector& outputTensors, + const char* prof_title) { + set_timeout(coll); + auto work = c10::make_intrusive( + opType, torch_ucc_config.enable_profiling ? prof_title : nullptr, logger); + + RECORD_COMMS_TRACE( + logger->trace_generator, + work, + opType, + this->getRank(), + this->getSize(), + inputTensors, + outputTensors); + + // Store references to outputs to be used by result + work->outputs_ = std::make_shared>(outputTensors); + switch (dev.type()) { + case c10::DeviceType::CPU: { + if (torch_ucc_config.use_future) { + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + } + comm->enqueue_collective(std::move(data), work, coll, team); + return work; + } +#ifdef USE_CUDA + case c10::DeviceType::CUDA: { + auto cuda_ev = getPooledEvent(); + cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index())); + cuda_ev->block(*stream); + at::cuda::CUDAStreamGuard guard(*stream); + preproc(); + comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee); + postproc(); + cuda_ev->record(*stream); + work->fence = std::move(cuda_ev); + work->ep = &ep; + if (torch_ucc_config.use_future) { + c10::cuda::CUDAMultiStreamGuard streamGuard(*stream); + std::vector devList{dev}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devList); + // Add a callback that runs profiling end callbacks + if (work->recordFunctionEndCallback_) { + work->future_->addCallback([work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }); + } + + work->future_->markCompleted(c10::IValue(outputTensors)); + } + return work; + } +#endif // #ifdef USE_CUDA + default: { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str())); + throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); + } + } +} + +c10::intrusive_ptr ProcessGroupUCC::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& /* unused */) { + auto& tensor = inputTensors[0]; + check_device(tensor.device(), outputTensors[0][0].device()); + initComm(tensor.device()); + + if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) { + AllgathervWorkData* data = new AllgathervWorkData(size_); + for (int i = 0; i < size_; i++) { + data->recv_lengths[i] = tensor.element_size() * tensor.numel(); + data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr(); + } + ucc_coll_args_t coll; + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = + UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + coll.coll_type = UCC_COLL_TYPE_ALLGATHERV; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.element_size() * tensor.numel(); + coll.src.info.datatype = UCC_DT_UINT8; + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.dst.info_v.buffer = nullptr; + coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); + coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); + coll.dst.info_v.datatype = UCC_DT_UINT8; + coll.dst.info_v.mem_type = + to_ucc_memType(outputTensors[0][0].device().type()); + SAVE_TENSORS(inputTensors, data->src); + SAVE_TENSORS(outputTensors[0], data->dst); + + return collective_post( + OpType::ALLGATHER, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + inputTensors, + outputTensors[0], + "ucc:allgatherv"); + } else { + WorkData* data = new WorkData(); + std::vector flat_output(outputTensors.size()); + for (size_t i = 0; i < outputTensors.size(); i++) { + TORCH_CHECK( + outputTensors[i].size() == outputTensors.size() * size_, + "Tensor output list is not valid for the number of participants"); + flat_output[i] = c10d::newLikeFlat(outputTensors, i); + } + SAVE_TENSORS(flat_output, data->flat); + ucc_coll_args_t coll; + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_ALLGATHER; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = to_ucc_dType(tensor); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.dst.info.buffer = flat_output[0].data_ptr(); + coll.dst.info.count = flat_output[0].numel(); + coll.dst.info.datatype = to_ucc_dType(flat_output[0]); + coll.dst.info.mem_type = + to_ucc_memType(outputTensors[0][0].device().type()); + + auto copy_from_flat = [&] { + bool asyncCopy = false; +#ifdef USE_CUDA + bool isCuda = outputTensors[0][0].device().is_cuda(); + ; +#endif + for (size_t i = 0; i < outputTensors.size(); i++) { + auto inumel = inputTensors[i].numel(); + for (size_t j = 0; j < outputTensors[i].size(); j++) { + TORCH_CHECK( + (outputTensors[i][j].numel() == inumel), + "Tensor operand counts must be same"); +#ifdef USE_CUDA + if (isCuda) { + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors[i][j].storage().data_ptr(), (*stream)); + asyncCopy = true; + } +#endif + outputTensors[i][j].copy_(flat_output[i][j], asyncCopy); + } + } + }; + return collective_post( + OpType::ALLGATHER, + []() {}, + copy_from_flat, + coll, + std::unique_ptr(data), + tensor.device(), + inputTensors, + outputTensors[0], + "ucc:allgather"); + } +} + +c10::intrusive_ptr ProcessGroupUCC::_allgather_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts) { + check_tensor({outputTensor}); + check_tensor({inputTensor}); + initComm(outputTensor.device()); + + WorkData* data = new WorkData(); + + ucc_coll_args_t coll; + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_ALLGATHER; + coll.src.info.buffer = inputTensor.data_ptr(); + coll.src.info.count = inputTensor.numel(); + coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type()); + coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); + coll.dst.info.buffer = outputTensor.data_ptr(); + coll.dst.info.count = outputTensor.numel(); + coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type()); + coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); + + std::vector inputTensors = {inputTensor}; + std::vector outputTensors = {outputTensor}; + SAVE_TENSORS(inputTensors, data->src); + SAVE_TENSORS(outputTensors, data->dst); + + return collective_post( + OpType::_ALLGATHER_BASE, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + outputTensor.device(), + inputTensors, + outputTensors, + "ucc:allgather_base"); +} + +c10::intrusive_ptr ProcessGroupUCC::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + WorkData* data = new WorkData(); + + ucc_coll_args_t coll; + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; + coll.coll_type = UCC_COLL_TYPE_ALLREDUCE; + coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type()); + coll.src.info.buffer = nullptr; + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = to_ucc_dType(tensor); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.dst.info.buffer = tensor.data_ptr(); + coll.dst.info.count = tensor.numel(); + coll.dst.info.datatype = to_ucc_dType(tensor); + coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); + SAVE_TENSORS(tensors, data->dst); + return collective_post( + OpType::ALLREDUCE, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + tensors, + tensors, + "ucc:allreduce"); +} + +c10::intrusive_ptr ProcessGroupUCC::allreduce_coalesced( + std::vector& /* unused */, + const AllreduceCoalescedOptions& /* unused */) { + throw std::runtime_error( + "ProcessGroupUCC does not support allreduce_coalesced"); +} + +c10::intrusive_ptr ProcessGroupUCC::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + auto device = outputTensors[0].device(); + for (const auto r : c10::irange(outputTensors.size())) { + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + } + + initComm(device); + ucc_coll_args_t coll; + AlltoallWorkData* data; + data = new AlltoallWorkData(size_); + + /* to avoid flatten the tensors, we use alltoallv to achieve Alltoall as + follow. + 1. store addresses of each tensor directly in displacements, keep buffer + to nullptr, i.e., 0 + 2. convert datatype to UINT8, which is always 1 bytes, to avoid wrong size + calculation in UCC layer + 3. post Alltoallv + */ + for (const auto i : c10::irange(size_)) { + data->send_lengths[i] = + (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel()); + data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr(); + data->recv_lengths[i] = + (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel()); + data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr(); + } + + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = + UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; + coll.src.info_v.buffer = 0; + coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); + coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); + coll.src.info_v.datatype = UCC_DT_UINT8; + coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type()); + coll.dst.info_v.buffer = 0; + coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); + coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); + coll.dst.info_v.datatype = UCC_DT_UINT8; + coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type()); + + SAVE_TENSORS(inputTensors, data->src); + SAVE_TENSORS(outputTensors, data->dst); + + return collective_post( + OpType::ALLTOALL, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + device, + inputTensors, + outputTensors, + "ucc:alltoall"); +} + +c10::intrusive_ptr ProcessGroupUCC::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_device(inputTensor.device(), outputTensor.device()); + initComm(inputTensor.device()); + ucc_coll_args_t coll; + AlltoallWorkData* data; + + if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) { + data = new AlltoallWorkData(0); + TORCH_CHECK( + (outputTensor.size(0) % size_ == 0) && + (inputTensor.size(0) % size_ == 0), + "Tensor's dim 0 does not divide equally across group size"); + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_ALLTOALL; + coll.src.info.buffer = inputTensor.data_ptr(); + coll.src.info.count = inputTensor.element_size() * inputTensor.numel(); + coll.src.info.datatype = UCC_DT_UINT8; + coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); + coll.dst.info.buffer = outputTensor.data_ptr(); + coll.dst.info.count = outputTensor.element_size() * outputTensor.numel(); + coll.dst.info.datatype = UCC_DT_UINT8; + coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); + coll.flags = 0; + } else { + data = new AlltoallWorkData(size_); + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + computeLengthsAndOffsets( + outputSplitSizes, + outputTensor, + &data->recv_lengths, + &data->recv_offsets); + computeLengthsAndOffsets( + inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets); + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; + coll.src.info_v.buffer = inputTensor.data_ptr(); + coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); + coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); + coll.src.info_v.datatype = to_ucc_dType(inputTensor); + coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type()); + coll.dst.info_v.buffer = outputTensor.data_ptr(); + coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); + coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); + coll.dst.info_v.datatype = to_ucc_dType(outputTensor); + coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type()); + coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | + UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT | + UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + + if (torch_ucc_config.enable_comms_logger) { + logger->trace_generator->recordOptionalInfo( + outputSplitSizes, inputSplitSizes); + } + } + std::vector inputTensors = {inputTensor}; + std::vector outputTensors = {outputTensor}; + SAVE_TENSORS(inputTensors, data->src); + SAVE_TENSORS(outputTensors, data->dst); + + return collective_post( + OpType::ALLTOALL_BASE, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + inputTensor.device(), + inputTensors, + outputTensors, + "ucc:alltoall"); +} + +c10::intrusive_ptr ProcessGroupUCC::barrier( + const BarrierOptions& opts) { + c10::Device device = c10::Device(c10::DeviceType::CPU); +#ifdef USE_CUDA + auto numGPUs = c10::cuda::device_count(); + if (!opts.device_ids.empty()) { + device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front()); + } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) { + device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index); + } else if (numGPUs > 0) { + int8_t deviceIdx = static_cast(c10::cuda::current_device()); + // if current device is 0, likely the device is not set, use the best guess + if (0 == (int)deviceIdx) { + deviceIdx = static_cast(this->getRank() % numGPUs); + } + TORCH_UCC_LOG_INFO( + TORCH_UCC_COLL_POST, + c10::str( + "post barrier before specifying any GPU while there are ", + numGPUs, + " GPUs available. ", + "Not clear if GPU barrier is required, using GPU ", + (int)deviceIdx, + " to perform barrier. ", + "Specify device_ids option in barrier() to force ", + "use of a particular device")); + device = c10::Device(c10::DeviceType::CUDA, deviceIdx); + } +#endif + initComm(device); + + ucc_coll_args_t coll; + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_BARRIER; + auto dummy_tensor = std::vector(); + return collective_post( + OpType::BARRIER, + []() {}, + []() {}, + coll, + nullptr, + device, + dummy_tensor, + dummy_tensor, + "ucc:barrier"); +} + +c10::intrusive_ptr ProcessGroupUCC::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + WorkData* data = new WorkData(); + + ucc_coll_args_t coll; + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_BCAST; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = to_ucc_dType(tensor); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.root = opts.rootRank; + SAVE_TENSORS(tensors, data->dst); + + if (torch_ucc_config.enable_comms_logger) { + logger->trace_generator->recordOptionalInfo(opts.rootRank); + } + + return collective_post( + OpType::BROADCAST, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + tensors, + tensors, + "ucc:broadcast"); +} + +c10::intrusive_ptr ProcessGroupUCC::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + std::vector outputs; + auto& input = inputTensors[0]; + initComm(input.device()); + + AllgathervWorkData* data = new AllgathervWorkData(size_); + ucc_coll_args_t coll; + coll.root = opts.rootRank; + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = + UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + coll.coll_type = UCC_COLL_TYPE_GATHERV; + + /* for non-root ranks, only src is valid */ + coll.src.info.buffer = input.data_ptr(); + coll.src.info.count = (uint64_t)(input.element_size() * input.numel()); + coll.src.info.datatype = UCC_DT_UINT8; + coll.src.info.mem_type = to_ucc_memType(input.device().type()); + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "gather requires a single-element output list containing a list with ", + getSize(), + " tensors.")); + } else if (outputTensors[0].size() != static_cast(getSize())) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "Incorrect output list size ", + outputTensors[0].size(), + ". Output list size should be ", + getSize(), + ", same as size of the process group.")); + } + outputs = outputTensors[0]; + + for (int i = 0; i < size_; i++) { + data->recv_lengths[i] = + (uint64_t)(outputs[i].element_size() * outputs[i].numel()); + data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr(); + } + /* use gatherv and store non-contiguous addresses in displacements to avoid + * flatten outputTensors */ + coll.dst.info_v.buffer = nullptr; + coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); + coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); + coll.dst.info_v.datatype = UCC_DT_UINT8; + coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type()); + + SAVE_TENSORS(outputs, data->dst); + } else { + // for non-root ranks, outputTensors should be an empty list + if (outputTensors.size() != 0) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, "requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list to be used by future mark + outputs.emplace_back(); + } + + SAVE_TENSORS(inputTensors, data->src); + + return collective_post( + OpType::GATHER, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + input.device(), + inputTensors, + outputs, + "ucc:gather"); +} + +c10::intrusive_ptr ProcessGroupUCC::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + WorkData* data = new WorkData(); + + ucc_coll_args_t coll; + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; + coll.coll_type = UCC_COLL_TYPE_REDUCE; + coll.op = ucc_op_map.at(opts.reduceOp); + coll.root = opts.rootRank; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.dst.info.buffer = tensor.data_ptr(); + coll.dst.info.count = tensor.numel(); + coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); + coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); + SAVE_TENSORS(tensors, data->dst); + return collective_post( + OpType::REDUCE, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + tensors, + tensors, + "ucc:reduce"); +} + +c10::intrusive_ptr ProcessGroupUCC::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK( + (outputTensors.size() == inputTensors.size()), + "Tensor input/output list for reduce_scatter must have same size"); + check_tensor(outputTensors); + check_device(inputTensors[0][0].device(), outputTensors[0].device()); + initComm(inputTensors[0][0].device()); + auto data = std::make_unique(); + std::vector flat_input(inputTensors.size()); + for (size_t i = 0; i < inputTensors.size(); i++) { + TORCH_CHECK( + inputTensors[i].size() == inputTensors.size() * size_, + "Tensor input list is not valid for the number of participants"); + flat_input[i] = c10d::newLikeFlat(inputTensors, i); + } + SAVE_TENSORS(flat_input, data->flat); + check_tensor(flat_input); + ucc_coll_args_t coll; + coll.mask = 0; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; + coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type()); + + coll.src.info.buffer = flat_input[0].data_ptr(); + coll.src.info.count = flat_input[0].numel(); + coll.src.info.datatype = to_ucc_dType(flat_input[0]); + coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type()); + coll.dst.info.buffer = outputTensors[0].data_ptr(); + coll.dst.info.count = outputTensors[0].numel(); + coll.dst.info.datatype = to_ucc_dType(outputTensors[0]); + coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type()); + + SAVE_TENSORS(inputTensors[0], data->src); + SAVE_TENSORS(outputTensors, data->dst); + + auto copy_to_flat = [&] { + bool asyncCopy = false; + auto isize = inputTensors.size(); +#ifdef USE_CUDA + bool isCuda = inputTensors[0][0].device().is_cuda(); +#endif + for (size_t i = 0; i < isize; i++) { + auto onumel = outputTensors[i].numel(); + for (size_t j = 0; j < inputTensors[i].size(); j++) { + TORCH_CHECK( + (inputTensors[i][j].numel() == onumel), + "Tensor operand counts must be same"); +#ifdef USE_CUDA + if (isCuda) { + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors[i][j].storage().data_ptr(), (*stream)); + asyncCopy = true; + } +#endif + flat_input[i][j].copy_(inputTensors[i][j], asyncCopy); + } + } + }; + + return collective_post( + OpType::REDUCE_SCATTER, + copy_to_flat, + []() {}, + coll, + std::move(data), + inputTensors[0][0].device(), + inputTensors[0], + outputTensors, + "ucc:reduce_scatter"); +} + +c10::intrusive_ptr ProcessGroupUCC::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + auto& tensor = outputTensors[0]; + initComm(tensor.device()); + + ScattervWorkData* data = new ScattervWorkData(size_); + ucc_coll_args_t coll; + coll.root = opts.rootRank; + coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; + coll.flags = + UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; + coll.coll_type = UCC_COLL_TYPE_SCATTERV; + + if (getRank() == opts.rootRank) { + /* src is only valid at non-root rank */ + if (inputTensors.size() != 1) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "gather requires a single-element output list containing a list with ", + getSize(), + " tensors.")); + } else if (inputTensors[0].size() != static_cast(getSize())) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, + c10::str( + "Incorrect output list size ", + inputTensors[0].size(), + ". Output list size should be ", + getSize(), + ", same as size of the process group.")); + } + + for (int i = 0; i < size_; i++) { + data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel(); + data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr(); + } + /* use scatter and store non-contiguous addresses in displacements to avoid + * flatten inputTensors */ + coll.src.info_v.buffer = nullptr; + coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); + coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); + coll.src.info_v.datatype = UCC_DT_UINT8; + coll.src.info_v.mem_type = + to_ucc_memType(inputTensors[0][0].device().type()); + + SAVE_TENSORS(inputTensors[0], data->src); + } else { + // for non-root ranks, inputTensors should be an empty list + if (inputTensors.size() != 0) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_COLL_POST, "requires empty output on non-root"); + } + } + + coll.dst.info.buffer = tensor.data_ptr(); + coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel(); + coll.dst.info.datatype = UCC_DT_UINT8; + coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); + SAVE_TENSORS(outputTensors, data->dst); + + return collective_post( + OpType::SCATTER, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + inputTensors[0], + outputTensors, + "ucc:scatter"); +} + +c10::intrusive_ptr ProcessGroupUCC::send( + std::vector& tensors, + int dstRank, + int tag) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + +#ifdef USE_ACTIVE_SETS + WorkData* data = new WorkData(); + ucc_coll_args_t coll; + coll.tag = tag; + coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_BCAST; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = to_ucc_dType(tensor); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.root = getRank(); + + coll.active_set.size = 2; + coll.active_set.start = getRank(); + coll.active_set.stride = dstRank - getRank(); + SAVE_TENSORS(tensors, data->dst); + + return collective_post( + OpType::SEND, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + tensors, + tensors, + "ucc:send"); +#else + ucp_tag_t ucp_tag; + TORCH_UCX_MAKE_SEND_TAG(ucp_tag, tag, rank_, comm_id); + ucc_coll_req_h request = comm->send_nb( + eps[dstRank], + tensor.data_ptr(), + to_ucs_memType(tensor.device().type()), + tensor.numel() * tensor.element_size(), + ucp_tag); + + auto work = comm->enqueue_p2p(OpType::SEND, request, "ucc:send"); + // TODO: record src, dst ranks and tag + RECORD_COMMS_TRACE( + logger->trace_generator, + work, + OpType::SEND, + this->getRank(), + this->getSize(), + tensors, + tensors); + return work; +#endif +} + +c10::intrusive_ptr ProcessGroupUCC::recv( + std::vector& tensors, + int srcRank, + int tag) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + +#ifdef USE_ACTIVE_SETS + WorkData* data = new WorkData(); + ucc_coll_args_t coll; + coll.tag = tag; + coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; + coll.flags = 0; + coll.coll_type = UCC_COLL_TYPE_BCAST; + coll.src.info.buffer = tensor.data_ptr(); + coll.src.info.count = tensor.numel(); + coll.src.info.datatype = to_ucc_dType(tensor); + coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); + coll.root = srcRank; + + coll.active_set.size = 2; + coll.active_set.start = srcRank; + coll.active_set.stride = getRank() - srcRank; + SAVE_TENSORS(tensors, data->dst); + + return collective_post( + OpType::RECV, + []() {}, + []() {}, + coll, + std::unique_ptr(data), + tensor.device(), + tensors, + tensors, + "ucc:recv"); +#else + ucp_tag_t ucp_tag, ucp_tag_mask; + TORCH_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, srcRank, comm_id); + ucc_coll_req_h request = comm->recv_nb( + tensor.data_ptr(), + to_ucs_memType(tensor.device().type()), + tensor.numel() * tensor.element_size(), + ucp_tag, + ucp_tag_mask); + + auto work = comm->enqueue_p2p(OpType::RECV, request, "ucc:recv"); + // TODO: record src, dst ranks and tag + RECORD_COMMS_TRACE( + logger->trace_generator, + work, + OpType::RECV, + this->getRank(), + this->getSize(), + tensors, + tensors); + return work; +#endif +} + +c10::intrusive_ptr ProcessGroupUCC::recvAnysource( + std::vector& tensors, + int tag) { + check_tensor(tensors); + auto& tensor = tensors[0]; + initComm(tensor.device()); + + ucp_tag_t ucp_tag, ucp_tag_mask; + TORCH_UCX_MAKE_RECV_TAG( + ucp_tag, ucp_tag_mask, tag, TORCH_UCX_ANY_SOURCE, comm_id); + ucc_coll_req_h request = comm->recv_nb( + tensor.data_ptr(), + to_ucs_memType(tensor.device().type()), + tensor.numel() * tensor.element_size(), + ucp_tag, + ucp_tag_mask); + + auto work = comm->enqueue_p2p(OpType::RECVANYSOURCE, request, "ucc:recv"); + // TODO: record dst rank and tag + RECORD_COMMS_TRACE( + logger->trace_generator, + work, + OpType::RECVANYSOURCE, + this->getRank(), + this->getSize(), + tensors, + tensors); + return work; +} + +c10::intrusive_ptr ProcessGroupUCC::createProcessGroupUCC( + const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size, + const std::chrono::duration& timeout) { + return c10::make_intrusive(store, rank, size, timeout); +} + +void ProcessGroupUCC::initComm(c10::Device dev) { + if (!comm) { +#ifdef USE_CUDA + if (dev.is_cuda()) { + c10::cuda::set_device(dev.index()); + } +#endif + comm = Comm::get_comm(comm_id, dev, oob, logger); + comm->ucx_connect_eps(eps, oob); + TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library"); + comm->ucc_create_team(team, oob); + TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library"); + logger->setPhase(TORCH_UCC_READY); + } else { + if (dev.is_cuda()) { + if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && + (comm->cuda_device_index != dev.index())) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_INIT, + "ucc communicator was initialized with different cuda device," + "multi device is not supported"); + throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); + } + comm->cuda_device_index = dev.index(); + } + } +#ifdef USE_CUDA + // Create UCC execution engine. + if (!cuda_ee && dev.is_cuda()) { + stream = std::make_unique( + at::cuda::getStreamFromPool(true, dev.index())); + ucc_ee_params_t params; + params.ee_type = UCC_EE_CUDA_STREAM; + params.ee_context = (void*)stream->stream(); + params.ee_context_size = sizeof(cudaStream_t); + TORCH_UCC_CHECK( + ucc_ee_create(team, ¶ms, &cuda_ee), + "failed to create UCC execution engine"); + } +#endif +} + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp new file mode 100644 index 0000000000000..1209ea2324c9e --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp @@ -0,0 +1,407 @@ +#pragma once + +#ifdef USE_C10D_UCC + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#ifdef USE_CUDA +#include +#include +#endif + +namespace c10d { + +#define TORCH_UCC_DEVICE_NOT_SET -2 + +#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \ + ((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \ + (((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \ + (((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET)) + +#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \ + ((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \ + (((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \ + (((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET)) + +#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \ + do { \ + (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \ + } while (0) + +#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1) +#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK) +#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1) + +#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \ + do { \ + (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \ + if ((_rank) == TORCH_UCX_ANY_SOURCE) { \ + (_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \ + } else { \ + (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \ + } \ + } while (0) + +#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \ + do { \ + (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \ + } while (0) + +#define TORCH_UCX_MAKE_OOB_RECV_TAG( \ + _ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \ + do { \ + (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \ + (_ucp_tag_mask) = (uint64_t)-1; \ + } while (0) + +#ifdef USE_CUDA +#define SAVE_TENSORS(_TENSORS, _DATA) \ + do { \ + if ((_TENSORS)[0].device().is_cuda()) { \ + for (const auto i : c10::irange((_TENSORS).size())) { \ + c10::cuda::CUDACachingAllocator::recordStream( \ + (_TENSORS)[i].storage().data_ptr(), (*stream)); \ + } \ + } else { \ + (_DATA) = (_TENSORS); \ + } \ + } while (0) + +#else +#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS); +#endif + +constexpr const char* UCC_BACKEND_NAME = "ucc"; + +enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG }; + +struct event_pool_t { +#ifdef USE_CUDA + std::queue> event_pool; +#endif + std::mutex event_pool_mutex; +}; + +class Comm; + +// UCC does not support multiple CUDA devices per process. +class TORCH_API ProcessGroupUCC : public ProcessGroup { + private: + void set_timeout(ucc_coll_args_t& args); + + public: + class WorkData { + public: + std::vector src; + std::vector dst; + std::vector flat; + WorkData() {} + virtual ~WorkData() = default; + }; + class AlltoallWorkData : public WorkData { + public: + AlltoallWorkData(int size) + : send_lengths(size), + send_offsets(size), + recv_lengths(size), + recv_offsets(size) {} + std::vector send_lengths; + std::vector send_offsets; + std::vector recv_lengths; + std::vector recv_offsets; + }; + + class AllgathervWorkData : public WorkData { + public: + AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {} + std::vector recv_lengths; + std::vector recv_offsets; + }; + + class ScattervWorkData : public WorkData { + public: + ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {} + std::vector send_lengths; + std::vector send_offsets; + }; + + class ProgressEntry { + friend class ProcessGroupUCC; + friend class Comm; + + public: + ProgressEntry(CommBase* comm, ucc_coll_req_h request) + : status_(UCC_INPROGRESS), comm_(comm), request_(request) {} + // Finalizes UCC status or exception of collective request. + void finalize(std::exception_ptr eptr = nullptr); + ucc_status_t status_; + CommBase* comm_; + ucc_coll_req_h request_; + std::unique_ptr data; + c10::intrusive_ptr future_; + std::exception_ptr eptr_; + }; + + class WorkUCC : public ProcessGroup::Work { + friend class ProcessGroupUCC; + friend class Comm; + + public: + WorkUCC(OpType opType, const char* prof_title) + : ProcessGroup::Work(-1, opType, prof_title) {} + WorkUCC( + OpType opType, + const char* prof_title, + const c10::intrusive_ptr& logger) + : ProcessGroup::Work(-1, opType, prof_title), logger_(logger) {} + ~WorkUCC(); + void setException(); + void setAndThrowException(); + bool isCompleted() override; + bool isSuccess() const override; + bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; + c10::intrusive_ptr getFuture() override; + std::vector result() override; +#ifdef USE_CUDA + std::unique_ptr fence = nullptr; + event_pool_t* ep = nullptr; +#endif + protected: + std::shared_ptr entry_; + c10::intrusive_ptr logger_; + + private: + // The future returned by getFuture. + c10::intrusive_ptr future_; + // Store a reference to collective's outputs, used by result + std::shared_ptr> outputs_; + }; + + explicit ProcessGroupUCC( + const c10::intrusive_ptr& store, + int rank = -1, + int size = -1, + std::chrono::duration timeout = kProcessGroupDefaultTimeout); + + void initComm(c10::Device dev); + + ~ProcessGroupUCC() override; + + const std::string getBackendName() const override { + return std::string(UCC_BACKEND_NAME); + } + +#ifdef USE_CUDA + std::unique_ptr getPooledEvent(); +#endif + + // Performs a health check by initializing dummy UCC & UCX communicators and + // then destroying them. This will help indicate and signal any + // UCC/UCX-related issues prior to the first collective. The actual + // initialization and subsequent destruction is ran on a separate thread and + // the main thread is signalled about timeouts/errors to report to the + // application. + void runHealthCheck(); + + template + c10::intrusive_ptr collective_post( + OpType opType, + PreProcess preproc, + PostProcess postproc, + ucc_coll_args_t& coll, + std::unique_ptr data, + c10::Device dev, + std::vector& inputTensors, + std::vector& outputTensors, + const char* prof_title); + + c10::intrusive_ptr broadcast( + std::vector& data, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + c10::intrusive_ptr recvAnysource( + std::vector& tensors, + int tag) override; + + static c10::intrusive_ptr createProcessGroupUCC( + const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size, + const std::chrono::duration& timeout); + + protected: + const std::chrono::duration timeout_; + std::shared_ptr oob; + std::shared_ptr comm = {nullptr}; + uint32_t comm_id; + std::vector eps; + ucc_team_h team{nullptr}; + ucc_ee_h cuda_ee{nullptr}; +#ifdef USE_CUDA + std::unique_ptr stream = nullptr; + event_pool_t ep; +#endif + c10::intrusive_ptr logger; +}; + +class Comm { + c10::intrusive_ptr logger; + std::shared_ptr oob; + CommUCX ucx_comm; + CommUCC ucc_comm; + std::mutex mutex; + std::thread progress_thread; + std::condition_variable queue_produce_cv; + std::condition_variable queue_consume_cv; + std::deque> progress_queue; + bool stop_progress_loop; + bool collective_inprogress; + torch_ucc_phase_t finalize_phase; + + public: + c10::DeviceIndex cuda_device_index; + Comm( + const c10::intrusive_ptr& logger, + std::shared_ptr oob, + c10::Device dev, + bool is_health_check); + + ~Comm(); + + // Connects UCX end points. + void ucx_connect_eps( + std::vector& eps, + std::shared_ptr oob); + + // Disconnects UCX end points. + void ucx_disconnect_eps( + std::vector& eps, + std::shared_ptr oob); + + void ucc_create_team( + ucc_team_h& team, + std::shared_ptr oob); + + void ucc_destroy_team(ucc_team_h& team); + + c10::intrusive_ptr enqueue_p2p( + OpType opType, + ucc_coll_req_h request, + const char* prof_title); + +#ifdef USE_CUDA + void enqueue_cuda_collective( + std::unique_ptr data, + c10::intrusive_ptr work, + ucc_coll_args_t& coll, + ucc_team_h team, + ucc_ee_h ee); +#endif + + void enqueue_collective( + std::unique_ptr data, + c10::intrusive_ptr work, + ucc_coll_args_t& coll, + ucc_team_h team); + + static std::shared_ptr get_comm( + uint32_t& id, + c10::Device dev, + std::shared_ptr oob, + const c10::intrusive_ptr& logger, + bool is_health_check = false); + + void progress_loop(); + + ucc_coll_req_h send_nb( + ucp_ep_h ep, + void* data, + ucs_memory_type_t mtype, + size_t size, + ucp_tag_t ucp_tag); + + ucc_coll_req_h recv_nb( + void* data, + ucs_memory_type_t mtype, + size_t size, + ucp_tag_t ucp_tag, + ucp_tag_t ucp_tag_mask); +}; + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp index 359dbfd51d46c..e7463304974fa 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp @@ -29,7 +29,7 @@ struct CollectiveFingerPrint { // input tensor device types std::vector tensor_device_types_; // input tensor sizes - std::vector tensor_sizes_; + std::vector> tensor_sizes_; explicit CollectiveFingerPrint( OpType op_type, @@ -41,7 +41,7 @@ struct CollectiveFingerPrint { for (const at::Tensor& t : input_tensors) { tensor_dtypes_.push_back(static_cast(t.dtype().toScalarType())); tensor_device_types_.push_back(static_cast(t.device().type())); - tensor_sizes_.push_back(t.sizes()); + tensor_sizes_.push_back(t.sizes().vec()); } } @@ -49,10 +49,12 @@ struct CollectiveFingerPrint { CollectiveFingerPrint( OpType op_type, std::vector tensor_dtypes, - std::vector tensor_device_types) + std::vector tensor_device_types, + std::vector> tensor_sizes) : op_type_(op_type), tensor_dtypes_(tensor_dtypes), - tensor_device_types_(tensor_device_types) {} + tensor_device_types_(tensor_device_types), + tensor_sizes_(tensor_sizes) {} // Logs collective information in case of a failure. friend std::ostream& operator<<( @@ -79,19 +81,22 @@ struct CollectiveFingerPrint { // CollectiveFingerPrint::serialize_fingerprint and deserializes it back to a // CollectiveFingerPrint struct CollectiveFingerPrint deserialize_fingerprint(at::Tensor serialized_tensor) { - // TODO: Need to add asserts to validate serialized_tensor.sizes() before - // deserializing + OpType optype; + auto dtypes = std::vector(); + auto device_types = std::vector(); + auto sizes = std::vector>(); int index = 0; // 1. OpType - OpType optype = OpType(serialized_tensor[index].item()); + optype = OpType(serialized_tensor[index].item()); index++; - std::vector dtypes = std::vector(); - std::vector device_types = std::vector(); if (index < serialized_tensor.size(0)) { // 2. Num tensors int num_tensors = serialized_tensor[index].item(); index++; + dtypes.reserve(num_tensors); + device_types.reserve(num_tensors); + sizes.reserve(num_tensors); // 3. Tensor dtypes for (int i = 0; i < num_tensors; i++) { @@ -103,8 +108,22 @@ struct CollectiveFingerPrint { device_types.push_back(serialized_tensor[index].item()); index++; } + // 5. Tensor shapes + for (int i = 0; i < num_tensors; i++) { + // 5a. Shape size + int size = serialized_tensor[index].item(); + index++; + // 5b. Shape + auto shapeVec = std::vector(); + shapeVec.reserve(size); + for (int j = 0; j < size; j++) { + shapeVec.push_back(serialized_tensor[index].item()); + index++; + } + sizes.push_back(shapeVec); + } } - return CollectiveFingerPrint(optype, dtypes, device_types); + return CollectiveFingerPrint(optype, dtypes, device_types, sizes); } private: @@ -142,8 +161,8 @@ struct CollectiveFingerPrint { std::stringstream ss; ss << "Detected mismatch between collectives on ranks. Rank " << pg->getRank() << " is running collective: " << *this - << ", but Rank " << rank << " is running collective: " - << opTypeToString(rank_fingerprint.op_type_) << "."; + << ", but Rank " << rank + << " is running collective: " << rank_fingerprint << "."; TORCH_CHECK(false, ss.str()); } } @@ -169,6 +188,7 @@ struct CollectiveFingerPrint { } // 5. Shapes for (const auto& sizes : tensor_sizes_) { + data->push_back(sizes.size()); for (const auto& s : sizes) { data->push_back(s); } @@ -199,6 +219,7 @@ std::ostream& operator<<( // Convert dtype and device type info to string. std::vector dtype_strs; std::vector device_type_strs; + std::vector size_strs; for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) { dtype_strs.emplace_back( c10::toString(static_cast(tensor_dtype))); @@ -208,14 +229,20 @@ std::ostream& operator<<( device_type_strs.emplace_back( c10::toString(static_cast(tensor_device_type))); } + if (!collective_fingerprint.tensor_sizes_.empty()) { + for (const auto& single_tensor_shape_num : + collective_fingerprint.tensor_sizes_[0]) { + size_strs.emplace_back(std::to_string(single_tensor_shape_num)); + } + } collectiveInfo = c10::str( "CollectiveFingerPrint(", "OpType=", opTypeToString(collective_fingerprint.op_type_), - ", TensorShape=", - (collective_fingerprint.tensor_sizes_)[0], - ", TensorDtypes=", + ", TensorShape=[", + c10::Join(", ", size_strs), + "], TensorDtypes=", (dtype_strs), ", TensorDeviceTypes=", (device_type_strs), @@ -394,6 +421,16 @@ c10::intrusive_ptr ProcessGroupWrapper::barrier( return pg_->barrier(opts); } +c10::intrusive_ptr ProcessGroupWrapper:: + _reduce_scatter_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const ReduceScatterOptions& opts) { + runCollectiveChecks( + OpType::_REDUCE_SCATTER_BASE, {inputBuffer, outputBuffer}); + return pg_->_reduce_scatter_base(outputBuffer, inputBuffer, opts); +} + c10::intrusive_ptr ProcessGroupWrapper::getWrappedPg() const { return pg_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp index 72686fd8f7799..62ec553ff3f48 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp @@ -111,6 +111,11 @@ class TORCH_API ProcessGroupWrapper : public ProcessGroup { c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const ReduceScatterOptions& opts) override; + c10::intrusive_ptr getWrappedPg() const; private: diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 761a124dff472..22612aee12dc1 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -2,6 +2,7 @@ #include #include +#include namespace c10d { @@ -22,6 +23,22 @@ class PyProcessGroup : public ProcessGroup { wait, /* Name of function in C++ */ timeout); } + + c10::intrusive_ptr getFuture() override { + // We cannot use PYBIND11_OVERRIDE because: + // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and + // 2. The python name is get_future + pybind11::gil_scoped_acquire gil; + auto override = pybind11::get_override(static_cast(this), "get_future"); + + if (override) { + py::object o = override(); + auto futWrapper = o.cast>(); + return futWrapper->fut; + } + + return Work::getFuture(); + } }; using ProcessGroup::ProcessGroup; diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 017b2e86252d6..02759541f55e7 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -34,8 +34,8 @@ struct AllreduceCoalescedOptions : AllreduceOptions {}; struct ReduceOptions { ReduceOp reduceOp = ReduceOp::SUM; - int rootRank = 0; - int rootTensor = 0; + int64_t rootRank = 0; + int64_t rootTensor = 0; std::chrono::milliseconds timeout = kUnsetTimeout; }; @@ -44,12 +44,12 @@ struct AllgatherOptions { }; struct GatherOptions { - int rootRank = 0; + int64_t rootRank = 0; std::chrono::milliseconds timeout = kUnsetTimeout; }; struct ScatterOptions { - int rootRank = 0; + int64_t rootRank = 0; std::chrono::milliseconds timeout = kUnsetTimeout; }; @@ -63,7 +63,7 @@ struct AllToAllOptions { }; struct BarrierOptions { - std::vector device_ids; + std::vector device_ids; std::chrono::milliseconds timeout = kUnsetTimeout; }; diff --git a/torch/csrc/distributed/c10d/UCCTracing.cpp b/torch/csrc/distributed/c10d/UCCTracing.cpp new file mode 100644 index 0000000000000..d23d2b68b318b --- /dev/null +++ b/torch/csrc/distributed/c10d/UCCTracing.cpp @@ -0,0 +1,176 @@ +#ifdef USE_C10D_UCC + +#include +#include + +#include + +#include +#include +#include +#include + +#ifdef FBCODE_CAFFE2 +#include +#endif + +namespace c10d { + +void ProcessGroupUCCLogger::initCommsTracer() { + trace_generator = std::make_shared(); + initialized_CommTraceLogger = true; +} + +void ProcessGroupUCCLogger::flushComms(int rank, int world_size) { + if (!initialized_CommTraceLogger || + trace_generator->getCommsTrace().empty()) { + return; + } + + std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size); + time_t now_ = time(0); + std::tm* ltm = localtime(&now_); + if (ltm) { + dirname += c10::str( + "_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year)); + } + + std::string fullpath = "/tmp/" + dirname; + char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR"); + if (user_path) { + fullpath = user_path; + } + std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json"); + std::ofstream _outfile; + if (!_outfile.is_open()) { + if (!mkdir(fullpath.c_str(), 0777)) { + LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath; + } else if (errno != EEXIST) { + return; + } + _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc); + } + // flush the traced comms + if (_outfile.is_open()) { + _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace()) + << "\n]"; + _outfile.flush(); + _outfile.close(); + } +#ifdef FBCODE_CAFFE2 + uploadTrace_internal( + trace_filename, dirname, c10::str("rank", rank, ".json")); +#endif +} + +/* unused */ +void CommTraceLogger::setCurBlock(const std::string& name) { + curBlocks_.push_back( + c10::str("\"", name, "\"")); // add quote marks for JSON format +} + +/* unused */ +void CommTraceLogger::popBlock() { + // TODO: remove specific name + curBlocks_.pop_back(); +} + +void CommTraceLogger::recordOptionalInfo(int root) { + curRoot_ = root; +} + +void CommTraceLogger::recordOptionalInfo( + const std::vector& outputSplitSizes, + const std::vector& inputSplitSizes) { + curOutSplitSizes_ = outputSplitSizes; + curInSplitSizes_ = inputSplitSizes; +} + +void CommTraceLogger::recordComms( + const std::string& commName, + const uintptr_t workReq, + const int rank, + const int world_size, + const std::vector& inputTensors, + const std::vector& outputTensors) { + auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0; + auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0; + auto dtype = + (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte; + auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type() + : c10::DeviceType::CPU; + auto now = std::chrono::system_clock::now(); + static auto startTS = now; + int64_t time_since_begin = + std::chrono::duration_cast(now - startTS) + .count(); + + // TODO: get markers from torch profiler if enabled + + // common fields for all operations + std::string cur_trace_ = c10::str( + "\n\t\t\"markers\": [", + curBlocks_, + "]", + ",\n\t\t\"startTime_ns\": ", + time_since_begin, + ",\n\t\t\"comms\": \"", + commName, + "\"", + ",\n\t\t\"req\": ", + workReq, + ",\n\t\t\"seqnum\": ", + seqnum++, + ",\n\t\t\"world_size\": ", + world_size); + + if (inSize > 0 || outSize > 0) { + // for most collectives - append msg sizes, data type, device type + cur_trace_ = c10::str( + cur_trace_, + ",\n\t\t\"in_msg_size\": ", + inSize, + ",\n\t\t\"out_msg_size\": ", + outSize, + ",\n\t\t\"dtype\": \"", + at::toString(dtype), + "\",\n\t\t\"devType\": \"", + c10::DeviceTypeName(devType), + "\""); + } + if (curRoot_ != -1) { + // append root rank if applicable, e.g., broadcast, gather, scatter + cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_); + } + if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) { + // append input and output splits if applicable, e.g., ALLTOALL_BASE + cur_trace_ = c10::str( + cur_trace_, + ",\n\t\t\"in_split\": [", + c10::Join(",", curInSplitSizes_), + "]" + ",\n\t\t\"out_split\": [", + c10::Join(",", curOutSplitSizes_), + "]"); + } + comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}")); + + // record the trace to kineto trace if applicable + RECORD_PARAM_COMMS( + rank, + commName.c_str(), + inSize, + outSize, + dtype, + curInSplitSizes_, + curOutSplitSizes_); + + // reset optional field + curRoot_ = -1; + curInSplitSizes_ = {}; + curOutSplitSizes_ = {}; +} + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/UCCTracing.hpp b/torch/csrc/distributed/c10d/UCCTracing.hpp new file mode 100644 index 0000000000000..60d6be8775120 --- /dev/null +++ b/torch/csrc/distributed/c10d/UCCTracing.hpp @@ -0,0 +1,58 @@ +#pragma once + +#ifdef USE_C10D_UCC + +#include + +namespace c10d { + +#define RECORD_COMMS_TRACE( \ + _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \ + do { \ + if (torch_ucc_config.enable_comms_logger) { \ + _comms_tracer->recordComms( \ + opTypeToString(_opType), \ + (uintptr_t)_work.get(), \ + _rank, \ + _comm_size, \ + _inTensors, \ + _outTensors); \ + } \ + } while (0) + +// interfaces to collect communication traces +class TORCH_API CommTraceLogger : public torch::CustomClassHolder { + private: + std::vector comms_trace_; + std::vector curBlocks_; /* unused */ + std::vector curOutSplitSizes_; + std::vector curInSplitSizes_; + int curRoot_ = -1; + unsigned long seqnum = 0; + + public: + void setCurBlock(const std::string& name); /* unused */ + void popBlock(); /* unused */ + // record root info if applicable, e.g., broadcast, gather, scatter + void recordOptionalInfo(int root = -1); + // record input/output splits of Alltoallv + void recordOptionalInfo( + const std::vector& outputSplitSizes = {}, + const std::vector& inputSplitSizes = {}); + // record essential comms information + void recordComms( + const std::string& collName, + const uintptr_t workReq = 0, + const int rank = -1, + const int world_size = -1, + const std::vector& inputTensors = {}, + const std::vector& outputTensor = {}); + // return collected comms traces + std::vector& getCommsTrace() { + return comms_trace_; + } +}; + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/UCCUtils.cpp b/torch/csrc/distributed/c10d/UCCUtils.cpp new file mode 100644 index 0000000000000..37cd829122f97 --- /dev/null +++ b/torch/csrc/distributed/c10d/UCCUtils.cpp @@ -0,0 +1,285 @@ +#ifdef USE_C10D_UCC + +#include +#include + +namespace c10d { + +namespace { +// Constants for store keys. +constexpr char kTeamRank[] = "teamr"; +constexpr char kAllGatherDone[] = "ag_done"; +constexpr char kAllGatherFree[] = "ag_free"; +} // namespace + +CommUCX::CommUCX( + int comm_size, + const c10::intrusive_ptr& logger) + : CommBase(logger) { + ucp_params_t params; + ucp_config_t* config; + ucs_status_t st; + ucp_worker_params_t worker_params; + ucp_lib_attr_t ucp_attr; + + ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL; + TORCH_UCX_CHECK( + ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes"); + TORCH_CHECK( + ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI, + "ucx library wasn't initialized with multithreading support, " + "please check ucx build options"); + TORCH_UCX_CHECK( + ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config"); + + memset(¶ms, 0, sizeof(ucp_params_t)); + params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE | + UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK | + UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP; + params.request_size = sizeof(ucc_coll_req_t); + params.features = UCP_FEATURE_TAG; + params.estimated_num_eps = comm_size; + params.tag_sender_mask = TORCH_UCX_RANK_MASK; + params.request_init = [](void* request) { + static_cast(request)->status = UCC_INPROGRESS; + }; + params.request_cleanup = [](void*) {}; + TORCH_UCX_CHECK( + ucp_init(¶ms, config, &context), "failed to init UCP context"); + ucp_config_release(config); + + memset(&worker_params, 0, sizeof(ucp_worker_params_t)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + st = ucp_worker_create(context, &worker_params, &worker); + if (st != UCS_OK) { + TORCH_UCC_LOG_ERROR( + TORCH_UCC_INIT, + c10::str("UCX failed to create UCP worker:", ucs_status_string(st))); + ucp_cleanup(context); + throw std::runtime_error(ucs_status_string(st)); + } +} + +void CommUCX::progress() { + ucp_worker_progress(worker); +} + +void CommUCX::free_request(ucc_coll_req_h request) { + request->status = UCC_INPROGRESS; + ucp_request_free(request); +} + +CommUCX::~CommUCX() { + if (worker != nullptr) { + ucp_worker_destroy(worker); + } + if (context != nullptr) { + ucp_cleanup(context); + } + worker = nullptr; + context = nullptr; +} + +ucc_status_t oob_allgather( + void* sbuf, + void* rbuf, + size_t msglen, + void* coll_info, + void** req) { + auto* info = reinterpret_cast(coll_info); + TORCH_CHECK(info != nullptr); + std::vector val = std::vector( + reinterpret_cast(sbuf), + reinterpret_cast(sbuf) + msglen); + try { + info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val); + info->rbuf = rbuf; + info->msglen = msglen; + *req = coll_info; + } catch (std::exception& ex) { + LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " + << "[" << ex.what() << "]"; + return UCC_ERR_NO_MESSAGE; + } + return UCC_OK; +} + +ucc_status_t oob_allgather_test(void* req) { + auto* info = reinterpret_cast(req); + TORCH_CHECK(info != nullptr); + + try { + for (int r = 0; r < info->size; r++) { + if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) { + return UCC_INPROGRESS; + } + } + for (int r = 0; r < info->size; r++) { + std::vector data = + info->store->get(info->getKey(kTeamRank + std::to_string(r))); + memcpy( + (void*)((ptrdiff_t)info->rbuf + info->msglen * r), + data.data(), + info->msglen); + } + } catch (std::exception& ex) { + LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " + << "[" << ex.what() << "]"; + return UCC_ERR_NO_MESSAGE; + } + return UCC_OK; +} + +ucc_status_t oob_allgather_free(void* req) { + auto* info = reinterpret_cast(req); + TORCH_CHECK(info != nullptr); + try { + int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1); + if (num_done == info->size) { + info->store->deleteKey(info->getKey(kAllGatherDone)); + // Note: to avoid race condition, it's important to remove all keys in + // oob_allgather_free first and only after that signal completion to + // other ranks + for (const auto r : c10::irange(info->size)) { + info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r))); + } + for (const auto r : c10::irange(info->size)) { + info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1); + } + } else { + info->store->wait( + {info->getKey(kAllGatherFree + std::to_string(info->rank))}); + } + info->store->deleteKey( + info->getKey(kAllGatherFree + std::to_string(info->rank))); + } catch (std::exception& ex) { + LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " + << "[" << ex.what() << "]"; + return UCC_ERR_NO_MESSAGE; + } + return UCC_OK; +} + +CommUCC::CommUCC( + std::shared_ptr oob, + const c10::intrusive_ptr& logger) + : CommBase(logger) { + ucc_lib_config_h lib_config; + ucc_context_config_h context_config; + ucc_lib_params_t lib_params; + ucc_context_params_t context_params; + ucc_status_t st; + + TORCH_UCC_CHECK( + ucc_lib_config_read("TORCH", nullptr, &lib_config), + "failed to read UCC lib config"); + memset(&lib_params, 0, sizeof(ucc_lib_params_t)); + lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; + lib_params.thread_mode = UCC_THREAD_MULTIPLE; + TORCH_UCC_CHECK( + ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib"); + ucc_lib_config_release(lib_config); + ucc_lib_attr_t lib_attr; + lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE; + TORCH_UCC_CHECK( + ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr"); + TORCH_CHECK( + lib_attr.thread_mode == UCC_THREAD_MULTIPLE, + "ucc library wasn't initialized with multithreading support, " + "please check ucc build options"); + st = ucc_context_config_read(lib, NULL, &context_config); + if (st != UCC_OK) { + // FIXME: would this cause deadlock if only one rank fails? + TORCH_UCC_CHECK( + ucc_finalize(lib), + "failed to finalize UCC library when failing to read UCC context config"); + TORCH_UCC_LOG_ERROR( + TORCH_UCC_INIT, + c10::str("failed to read UCC context config: ", ucc_status_string(st))); + throw std::runtime_error(ucc_status_string(st)); + } + st = ucc_context_config_modify( + context_config, + NULL, + "ESTIMATED_NUM_EPS", + std::to_string(oob->size).c_str()); + if (st != UCC_OK) { + ucc_context_config_release(context_config); + ucc_finalize(lib); + TORCH_UCC_LOG_ERROR( + TORCH_UCC_INIT, + c10::str( + "UCC failed to modify UCC context config: ", + ucc_status_string(st))); + throw std::runtime_error(ucc_status_string(st)); + } + memset(&context_params, 0, sizeof(ucc_context_params_t)); + context_params.mask = + UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB; + context_params.type = UCC_CONTEXT_SHARED; + context_params.oob.n_oob_eps = oob->size; + context_params.oob.oob_ep = oob->rank; + context_params.oob.allgather = oob_allgather; + context_params.oob.req_test = oob_allgather_test; + context_params.oob.req_free = oob_allgather_free; + context_params.oob.coll_info = oob.get(); + st = ucc_context_create(lib, &context_params, context_config, &context); + ucc_context_config_release(context_config); + if (st != UCC_OK) { + TORCH_UCC_CHECK( + ucc_finalize(lib), + "failed to finalize UCC library when failing to creat UCC context"); + TORCH_UCC_LOG_ERROR( + TORCH_UCC_INIT, + c10::str("UCC failed to create UCC context: ", ucc_status_string(st))); + throw std::runtime_error(ucc_status_string(st)); + } +} + +void CommUCC::progress() { + TORCH_UCC_CHECK( + ucc_context_progress(context), "failed to progress UCC collective"); +} + +void CommUCC::free_request(ucc_coll_req_h request) { + TORCH_UCC_CHECK( + ucc_collective_finalize(request), "failed to release UCC request"); +} + +CommUCC::~CommUCC() { + if (context != nullptr) { + TORCH_UCC_CHECK( + ucc_context_destroy(context), "failed to destory UCC context"); + } + if (lib != nullptr) { + TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library"); + } + context = nullptr; + lib = nullptr; +} + +std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) { + // caller can override the phase stored locally + torch_ucc_phase_t phase_ = + (local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase + : local_phase; + return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]"); +} +void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) { + log_prefix = log_prefix_; +} + +ProcessGroupUCCLogger::ProcessGroupUCCLogger() { + setLogPrefix("[ProcessGroupUCC]"); +} +ProcessGroupUCCLogger::ProcessGroupUCCLogger( + std::string log_prefix, + torch_ucc_phase_t phase) + : local_phase(phase) { + setLogPrefix(log_prefix); +} + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/UCCUtils.hpp b/torch/csrc/distributed/c10d/UCCUtils.hpp new file mode 100644 index 0000000000000..70cef1d7f99a4 --- /dev/null +++ b/torch/csrc/distributed/c10d/UCCUtils.hpp @@ -0,0 +1,193 @@ +#pragma once + +#ifdef USE_C10D_UCC + +#include +#include +#include +#include + +#define TORCH_UCX_COMM_BITS 15 +#define TORCH_UCX_RANK_BITS 16 +#define TORCH_UCX_TAG_BITS 32 +#define TORCH_UCX_OOB_BITS 1 + +#define TORCH_UCX_COMM_BITS_OFFSET 0 +#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS +#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS) +#define TORCH_UCX_OOB_BITS_OFFSET \ + (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS) + +#define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1) +#define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1) +#define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1) +#define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1) + +#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET) +#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET) +#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET) +#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET) + +namespace c10d { + +// Macro to throw on a non-successful UCC return value. +#define TORCH_UCC_CHECK(_cmd, _error_msg) \ + do { \ + ucc_status_t result = _cmd; \ + if (result != UCC_OK) { \ + std::string err = c10::str( \ + "[", \ + std::string(__FILE__), \ + ":", \ + std::to_string(__LINE__), \ + "] ", \ + logger->getLogPrefix(), \ + _error_msg, \ + ", error code ", \ + result, \ + ": ", \ + ucc_status_string(result), \ + ", system error code ", \ + errno); \ + TORCH_CHECK(false, err); \ + } \ + } while (0) + +// Macro to throw on a non-successful UCX return value. +#define TORCH_UCX_CHECK(_cmd, _error_msg) \ + do { \ + ucs_status_t result = _cmd; \ + if (result != UCS_OK) { \ + std::string err = c10::str( \ + "[", \ + std::string(__FILE__), \ + ":", \ + std::to_string(__LINE__), \ + "] ", \ + logger->getLogPrefix(), \ + _error_msg, \ + ", error code ", \ + result, \ + ": ", \ + ucs_status_string(result), \ + ", system error code ", \ + errno); \ + TORCH_CHECK(false, err); \ + } \ + } while (0) + +// Macros to print logs with unified format +#define TORCH_UCC_LOG_ERROR(_phase, _msg) \ + LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg; +#define TORCH_UCC_LOG_INFO(_phase, _msg) \ + LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg; +#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \ + VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg; + +enum torch_ucc_phase_t { + TORCH_UCC_UNKNOWN = -1, + TORCH_UCC_INIT, + TORCH_UCC_HEALTH_CHECK, + TORCH_UCC_READY, + TORCH_UCC_COLL_POST, + TORCH_UCC_COLL_PROGRESS, + TORCH_UCC_FINALIZE, +}; + +const std::map ucc_phase_map = { + {TORCH_UCC_UNKNOWN, "UNKNOWN"}, + {TORCH_UCC_INIT, "INIT"}, + {TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"}, + {TORCH_UCC_READY, "READY"}, + {TORCH_UCC_COLL_POST, "COLL_POST"}, + {TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"}, + {TORCH_UCC_FINALIZE, "FINALIZE"}, +}; + +class CommTraceLogger; + +class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder { + public: + ProcessGroupUCCLogger(); + ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase); + + std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN); + void setLogPrefix(std::string log_prefix); + inline void setPhase(torch_ucc_phase_t phase) { + local_phase = phase; + } + + void initCommsTracer(); + void flushComms(int rank, int world_size); + std::shared_ptr trace_generator = nullptr; + + protected: + std::string log_prefix; + torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN; + bool initialized_CommTraceLogger = false; +}; + +struct torch_ucc_oob_coll_info_t { + c10::intrusive_ptr store; + uint32_t comm_id; + int rank; + int size; + void* rbuf; + size_t msglen; + std::string getKey(std::string key) { + return std::to_string(comm_id) + key; + } +}; + +class CommBase { + public: + CommBase(const c10::intrusive_ptr& logger_) + : logger(logger_) {} + virtual void progress() = 0; + virtual void free_request(ucc_coll_req_h request) = 0; + virtual ~CommBase() {} + c10::intrusive_ptr logger; +}; + +class CommUCX : public CommBase { + public: + ucp_context_h context{nullptr}; + ucp_worker_h worker{nullptr}; + + public: + void progress() override; + void free_request(ucc_coll_req_h request) override; + CommUCX( + int comm_size, + const c10::intrusive_ptr& logger); + ~CommUCX(); +}; + +class CommUCC : public CommBase { + public: + ucc_lib_h lib{nullptr}; + ucc_context_h context{nullptr}; + + public: + void progress() override; + CommUCC( + std::shared_ptr oob, + const c10::intrusive_ptr& logger); + void free_request(ucc_coll_req_h request) override; + ~CommUCC(); +}; + +ucc_status_t oob_allgather( + void* sbuf, + void* rbuf, + size_t msglen, + void* coll_info, + void** req); + +ucc_status_t oob_allgather_test(void* req); + +ucc_status_t oob_allgather_free(void* req); + +} // namespace c10d + +#endif // USE_C10D_UCC diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index 8c72fabd63adb..d4c26d99bb0c1 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -99,5 +99,26 @@ std::vector GradBucket::getGradients() const { } return per_parameter_tensors; } +namespace detail { + +at::Tensor parseCppCommHookResult(const c10::IValue& result) { + if (result.isPyObject()) { + std::vector tensors = + result.toPyObjectHolder()->extractTensors(); + return tensors[0]; + } + TORCH_INTERNAL_ASSERT( + result.isTensor() || result.isTensorList(), + "expected the hook result is either a Tensor or a TensorList found ", + result.tagKind()); + + if (result.isTensor()) { + return result.toTensor(); + } + + return result.toTensorVector()[0]; +} + +} // namespace detail } // namespace c10d diff --git a/torch/csrc/distributed/c10d/comm.hpp b/torch/csrc/distributed/c10d/comm.hpp index 43ffdb07d78fd..1f47826ab9cb9 100644 --- a/torch/csrc/distributed/c10d/comm.hpp +++ b/torch/csrc/distributed/c10d/comm.hpp @@ -106,18 +106,7 @@ class TORCH_API CommHookInterface { namespace detail { // This helper function is called both by CppCommHookInterface below and inside // reducer. -inline at::Tensor parseCppCommHookResult( - const c10::IValue& result) { - TORCH_INTERNAL_ASSERT( - result.isTensor() || result.isTensorList(), - "expected the hook result is either a Tensor or a TensorList"); - - if (result.isTensor()) { - return result.toTensor(); - } - - return result.toTensorVector()[0]; -} + at::Tensor parseCppCommHookResult(const c10::IValue& result); } // namespace detail // This CppCommHook interface only requires implementing runHook method that @@ -125,7 +114,7 @@ inline at::Tensor parseCppCommHookResult( template class CppCommHookInterface : public CommHookInterface { public: - explicit CppCommHookInterface(T& state) : state_(state) {} + explicit CppCommHookInterface(const T& state) : state_(state) {} ~CppCommHookInterface() override = default; diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.cpp b/torch/csrc/distributed/c10d/default_comm_hooks.cpp index e78e2d7a0fac9..6599c2c0d197e 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.cpp +++ b/torch/csrc/distributed/c10d/default_comm_hooks.cpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace c10d { @@ -13,7 +14,7 @@ c10::intrusive_ptr AllReduceCommHook::runHook( std::vector tensors = {bucket.getBufferRef()}; // Apply the division first to avoid overflow, especially for FP16. tensors[0] /= state_->getSize(); - return state_->allreduce(tensors)->getFuture(); + return ops::allreduce(state_, tensors)->getFuture(); } c10::intrusive_ptr FP16CompressCommHook::runHook( @@ -23,7 +24,7 @@ c10::intrusive_ptr FP16CompressCommHook::runHook( compressed_tensor /= state_->getSize(); std::vector tensors = {compressed_tensor}; - auto allreduce_fut = state_->allreduce(tensors)->getFuture(); + auto allreduce_fut = ops::allreduce(state_, tensors)->getFuture(); auto decompressed_tensor = bucket.getBufferRef(); auto decompress = [decompressed_tensor](c10::ivalue::Future& allreduce_fut) { auto result = allreduce_fut.value(); @@ -46,7 +47,7 @@ c10::intrusive_ptr FP16CompressCommHook::runHook( c10::intrusive_ptr _AllReduceBySumCommHook::runHook( GradBucket& bucket) { std::vector tensors = {bucket.getBufferRef()}; - return state_->allreduce(tensors)->getFuture(); + return ops::allreduce(state_, tensors)->getFuture(); } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.hpp b/torch/csrc/distributed/c10d/default_comm_hooks.hpp index 37b64a4badb36..df6c8db290542 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.hpp +++ b/torch/csrc/distributed/c10d/default_comm_hooks.hpp @@ -10,20 +10,20 @@ enum class BuiltinCommHookType { FP16_COMPRESS = 2, }; -class AllReduceCommHook : public CppCommHookInterface { +class AllReduceCommHook : public CppCommHookInterface> { public: - explicit AllReduceCommHook(ProcessGroup* state) - : CppCommHookInterface(state) {} + explicit AllReduceCommHook(const c10::intrusive_ptr& state) + : CppCommHookInterface>(state) {} ~AllReduceCommHook() override = default; c10::intrusive_ptr runHook(GradBucket& bucket) override; }; -class FP16CompressCommHook : public CppCommHookInterface { +class FP16CompressCommHook : public CppCommHookInterface> { public: - explicit FP16CompressCommHook(ProcessGroup* state) - : CppCommHookInterface(state) {} + explicit FP16CompressCommHook(const c10::intrusive_ptr& state) + : CppCommHookInterface>(state) {} ~FP16CompressCommHook() override = default; @@ -35,10 +35,10 @@ class FP16CompressCommHook : public CppCommHookInterface { // over all the input parameters, when no communication hook is provided by the user. // Only used internally and not released as a public built-in communication hook. class _AllReduceBySumCommHook - : public CppCommHookInterface { + : public CppCommHookInterface> { public: - explicit _AllReduceBySumCommHook(ProcessGroup* state) - : CppCommHookInterface(state) {} + explicit _AllReduceBySumCommHook(const c10::intrusive_ptr& state) + : CppCommHookInterface>(state) {} ~_AllReduceBySumCommHook() override = default; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 513171816cb94..af4cf05107d5a 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -24,6 +24,10 @@ #include #endif +#ifdef USE_C10D_UCC +#include +#endif + #include #include #include @@ -989,7 +993,7 @@ that adds a prefix to each key inserted to the store. "broadcast", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, const std::vector& tensors, - ::c10d::BroadcastOptions opts) { + const ::c10d::BroadcastOptions& opts) { return ::c10d::ops::broadcast(self, tensors, opts); }, py::arg("tensors"), @@ -1011,19 +1015,23 @@ that adds a prefix to each key inserted to the store. .def( "allreduce", - &::c10d::ProcessGroup::allreduce, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& tensors, + const ::c10d::AllreduceOptions& opts) { + return ::c10d::ops::allreduce(self, tensors, opts); + }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceOptions(), py::call_guard()) .def( "allreduce", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& xs, ::c10d::ReduceOp op) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; - return pg.allreduce(xs, opts); + return ::c10d::ops::allreduce(self, xs, opts); }, py::arg("tensors"), py::arg("op") = ::c10d::ReduceOp::SUM, @@ -1031,11 +1039,13 @@ that adds a prefix to each key inserted to the store. .def( "allreduce", - [](::c10d::ProcessGroup& pg, at::Tensor& x, ::c10d::ReduceOp op) { + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& x, + ::c10d::ReduceOp op) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; std::vector xs = {x}; - return pg.allreduce(xs, opts); + return ::c10d::ops::allreduce(self, xs, opts); }, py::arg("tensor"), py::arg("op") = ::c10d::ReduceOp::SUM, @@ -1043,10 +1053,10 @@ that adds a prefix to each key inserted to the store. .def( "allreduce_coalesced", - [](::c10d::ProcessGroup& pg, + [](::c10d::ProcessGroup& self, std::vector& xs, ::c10d::AllreduceCoalescedOptions opts) { - return pg.allreduce_coalesced(xs, opts); + return self.allreduce_coalesced(xs, opts); }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), @@ -1054,14 +1064,18 @@ that adds a prefix to each key inserted to the store. .def( "reduce", - &::c10d::ProcessGroup::reduce, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& tensors, + const ::c10d::ReduceOptions& opts) { + return ::c10d::ops::reduce(self, tensors, opts); + }, py::arg("tensors"), py::arg("opts") = ::c10d::ReduceOptions(), py::call_guard()) .def( "reduce", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& x, int rootRank, ::c10d::ReduceOp op) { @@ -1069,7 +1083,7 @@ that adds a prefix to each key inserted to the store. opts.reduceOp = op; opts.rootRank = rootRank; std::vector xs = {x}; - return pg.reduce(xs, opts); + return ::c10d::ops::reduce(self, xs, opts); }, py::arg("tensor"), py::arg("root"), @@ -1078,7 +1092,13 @@ that adds a prefix to each key inserted to the store. .def( "allgather", - &::c10d::ProcessGroup::allgather, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector>& output_tensors, + const std::vector& input_tensor, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::allgather( + self, output_tensors, input_tensor, opts); + }, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::AllgatherOptions(), @@ -1094,13 +1114,13 @@ that adds a prefix to each key inserted to the store. .def( "allgather", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& output, at::Tensor& input) { std::vector> outputs = {output}; std::vector inputs = {input}; - return pg.allgather( - outputs, inputs, ::c10d::AllgatherOptions()); + return ::c10d::ops::allgather( + self, outputs, inputs, ::c10d::AllgatherOptions()); }, py::arg("output_tensors"), py::arg("input_tensor"), @@ -1116,7 +1136,13 @@ that adds a prefix to each key inserted to the store. .def( "gather", - &::c10d::ProcessGroup::gather, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector>& output_tensors, + const std::vector& input_tensors, + const ::c10d::GatherOptions& opts) { + return ::c10d::ops::gather( + self, output_tensors, input_tensors, opts); + }, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::GatherOptions(), @@ -1124,7 +1150,7 @@ that adds a prefix to each key inserted to the store. .def( "gather", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& output, at::Tensor& input, int rootRank) { @@ -1132,7 +1158,7 @@ that adds a prefix to each key inserted to the store. opts.rootRank = rootRank; std::vector> outputs = {output}; std::vector inputs = {input}; - return pg.gather(outputs, inputs, opts); + return ::c10d::ops::gather(self, outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), @@ -1141,7 +1167,13 @@ that adds a prefix to each key inserted to the store. .def( "scatter", - &::c10d::ProcessGroup::scatter, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ::c10d::ScatterOptions& opts) { + return ::c10d::ops::scatter( + self, output_tensors, input_tensors, opts); + }, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::ScatterOptions(), @@ -1149,7 +1181,7 @@ that adds a prefix to each key inserted to the store. .def( "scatter", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& output, std::vector& input, int rootRank) { @@ -1157,7 +1189,7 @@ that adds a prefix to each key inserted to the store. opts.rootRank = rootRank; std::vector> inputs = {input}; std::vector outputs = {output}; - return pg.scatter(outputs, inputs, opts); + return ::c10d::ops::scatter(self, outputs, inputs, opts); }, py::arg("output_tensor"), py::arg("input_tensors"), @@ -1166,7 +1198,13 @@ that adds a prefix to each key inserted to the store. .def( "reduce_scatter", - &::c10d::ProcessGroup::reduce_scatter, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& output_tensors, + const std::vector>& input_tensors, + const ::c10d::ReduceScatterOptions& opts) { + return ::c10d::ops::reduce_scatter( + self, output_tensors, input_tensors, opts); + }, py::arg("output_tensors"), py::arg("input_tensors"), py::arg("opts") = ::c10d::ReduceScatterOptions(), @@ -1174,7 +1212,7 @@ that adds a prefix to each key inserted to the store. .def( "reduce_scatter", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& output, std::vector& input, ::c10d::ReduceOp op) { @@ -1182,7 +1220,7 @@ that adds a prefix to each key inserted to the store. std::vector> inputs = {input}; ::c10d::ReduceScatterOptions opts; opts.reduceOp = op; - return pg.reduce_scatter(outputs, inputs, opts); + return ::c10d::ops::reduce_scatter(self, outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), @@ -1209,12 +1247,12 @@ that adds a prefix to each key inserted to the store. .def( "alltoall_base", - [](::c10d::ProcessGroup& pg, + [](::c10d::ProcessGroup& self, at::Tensor& output, at::Tensor& input, std::vector outputSplitSizes, std::vector inputSplitSizes) { - return pg.alltoall_base( + return self.alltoall_base( output, input, outputSplitSizes, @@ -1229,7 +1267,13 @@ that adds a prefix to each key inserted to the store. .def( "alltoall", - &::c10d::ProcessGroup::alltoall, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& output_tensors, + const std::vector& input_tensors, + const ::c10d::AllToAllOptions& opts) { + return ::c10d::ops::alltoall( + self, output_tensors, input_tensors, opts); + }, py::arg("output_tensor"), py::arg("input_tensor"), py::arg("opts") = ::c10d::AllToAllOptions(), @@ -1237,10 +1281,11 @@ that adds a prefix to each key inserted to the store. .def( "alltoall", - [](::c10d::ProcessGroup& pg, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& output, std::vector& input) { - return pg.alltoall(output, input, ::c10d::AllToAllOptions()); + return ::c10d::ops::alltoall( + self, output, input, ::c10d::AllToAllOptions()); }, py::arg("output"), py::arg("input"), @@ -1248,12 +1293,28 @@ that adds a prefix to each key inserted to the store. .def( "send", - &::c10d::ProcessGroup::send, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& tensors, + int64_t dstRank, + int64_t tag) { + return ::c10d::ops::send(self, tensors, dstRank, tag); + }, + py::arg("tensors"), + py::arg("dstRank"), + py::arg("tag"), py::call_guard()) .def( "recv", - &::c10d::ProcessGroup::recv, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& tensors, + int64_t srcRank, + int64_t tag) { + return ::c10d::ops::recv(self, tensors, srcRank, tag); + }, + py::arg("tensors"), + py::arg("srcRank"), + py::arg("tag"), py::call_guard()) .def( @@ -1263,7 +1324,10 @@ that adds a prefix to each key inserted to the store. .def( "barrier", - &::c10d::ProcessGroup::barrier, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const ::c10d::BarrierOptions& opts) { + return ::c10d::ops::barrier(self, opts); + }, py::arg("opts") = ::c10d::BarrierOptions(), py::call_guard()) .def( @@ -1497,6 +1561,25 @@ Example:: py::call_guard()); #endif +#ifdef USE_C10D_UCC + auto processGroupUCC = + intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>( + module, "ProcessGroupUCC", processGroup) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size, + const std::chrono::milliseconds& timeout) { + return c10::make_intrusive<::c10d::ProcessGroupUCC>( + store, rank, size, timeout); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::call_guard()); +#endif + py::class_< ::c10d::ProcessGroup::Work, c10::intrusive_ptr<::c10d::ProcessGroup::Work>, @@ -1717,6 +1800,36 @@ Example:: module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout); module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout); + module.def( + "_create_work_from_future", + [](std::shared_ptr future) { + return ::c10d::ProcessGroup::Work::create_from_future(future->fut); + }, + py::arg("future"), + R"( + Arguments: + future(str): The future to wrap. + Returns: + A ``ProcessGroup::Work`` object which is associated with the completion of + the ``torch.futures.Future``. + This is the prefered way of constructing Work objects when writing a custom ProcessGroup + in python. + Example:: + >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup): + >>> def broadcast(self, tensor_list, opts): + >>> fut = torch.futures.Future() + >>> fut.set_result(tensor_list) + >>> return torch._C._distributed_c10d._create_work_from_future(fut) + .. warning :: + This API is experimental and subject to change. + The returned Work object has multiple limitations: + - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization. + - wait() ignored timeout argument. + - sourceRank() raises. + - abort() raises. + The provided Future object result must be a Tensor or a list of Tensors. + )"); + Py_RETURN_TRUE; } diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 1b4b5df3ce768..7b940f645283a 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -5,6 +5,8 @@ #include #include +#include + #ifdef USE_C10D_GLOO #include #endif @@ -52,10 +54,10 @@ Logger::Logger(std::shared_ptr reducer) { ddp_logging_data_ = std::make_unique(); } -std::once_flag log_graph_static_flag; +c10::once_flag log_graph_static_flag; void Logger::log_if_graph_static(bool is_static) { - std::call_once(log_graph_static_flag, [this, is_static]() { + c10::call_once(log_graph_static_flag, [this, is_static]() { ddp_logging_data_->ints_map["can_set_static_graph"] = is_static; // It is useful to report the iteration that training finished at. ddp_logging_data_->ints_map["iteration"] = reducer_->num_iterations_; diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 85c2aa84f5938..07bf8e0b73c28 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -69,6 +69,22 @@ class CpuTimer : public Timer { C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer); +std::vector extractTensors(const c10::IValue& result) { + if (result.isPyObject()) { + return result.toPyObjectHolder()->extractTensors(); + } + TORCH_INTERNAL_ASSERT( + result.isTensor() || result.isTensorList(), + "expected the hook result is either a Tensor or a TensorList found ", + result.tagKind()); + + if (result.isTensor()) { + return {result.toTensor()}; + } + + return result.toTensorVector(); +} + } // namespace Reducer::Reducer( @@ -494,7 +510,10 @@ void Reducer::set_divide_factor() { auto& workHandle = forwardPassWorkHandle_.workHandle; if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) { workHandle->wait(); - auto results = workHandle->result(); + // PyProcessGroup::PyWork doesn't expose value, so fetch it from the + // future + auto results = extractTensors(workHandle->getFuture()->value()); + // Guard against the results being empty TORCH_INTERNAL_ASSERT(results.size() > 0); at::Tensor& res = results.front(); @@ -535,6 +554,37 @@ void Reducer::delay_all_reduce() { } } + // To avoid confusion around why static graph is picking up + // some parameters as unused on a rank vs not, we log + // unused parameter names for each rank for better + // debugability when TORCH_DISTRIBUTED_DEBUG is set to + // INFO or DETAIL + if (ddp_debug_level_ != c10d::DebugLevel::Off) { + // construct one string to output + std::ostringstream unused_params_stream; + + for (const auto& unused_index : unused_parameters_) { + auto param_name = param_names_.find(unused_index); + TORCH_INTERNAL_ASSERT( + param_name != param_names_.end(), + "Expected to find parameter name from unused parameters map in debug mode."); + // Add the param_name + unused_params_stream << "{" << param_name->second << "," << unused_index + << "}"; + } + + // Each rank prints out all the unused parameters detected + if (unused_parameters_.size() > 0) { + LOG(INFO) << "[Rank " << process_group_->getRank() << "]: " + << "Parameter(s) (in the format of {param_name, index}): " + << unused_params_stream.str() + << " is(are) unused during first iteration. Since" + << " static_graph=True is enabled for DDP, we expect" + << " this set of unused parameters to remain consistent" + << " on this rank throughout the training."; + } + } + // launch all reduces for all buckets for (auto& bucket : buckets_) { all_reduce_bucket(bucket); @@ -678,7 +728,8 @@ void Reducer::all_reduce_local_used_map() { local_used_map_dev_.copy_(local_used_map_, true); } std::vector temp_local_used_map_dev_vec_ = {local_used_map_dev_}; - local_used_work_ = process_group_->allreduce(temp_local_used_map_dev_vec_); + local_used_work_ = + ops::allreduce(process_group_, temp_local_used_map_dev_vec_); } at::Tensor& Reducer::get_param_from_index(size_t index) { @@ -826,7 +877,7 @@ void Reducer::mark_variable_ready(size_t variable_index) { c10::intrusive_ptr Reducer::run_comm_hook( GradBucket& grad_bucket) { if (comm_hook_ == nullptr) { - _AllReduceBySumCommHook allreduce_hook(process_group_.get()); + _AllReduceBySumCommHook allreduce_hook(process_group_); return allreduce_hook.runHook(grad_bucket); } else { return comm_hook_->runHook(grad_bucket); @@ -1709,13 +1760,11 @@ void Reducer::register_builtin_comm_hook( switch (comm_hook_type) { case c10d::BuiltinCommHookType::ALLREDUCE: - comm_hook_ = - std::make_unique(process_group_.get()); + comm_hook_ = std::make_unique(process_group_); LOG(INFO) << "Built-in communication hook ALLREDUCE is registered."; break; case c10d::BuiltinCommHookType::FP16_COMPRESS: - comm_hook_ = - std::make_unique(process_group_.get()); + comm_hook_ = std::make_unique(process_group_); LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered."; break; default: @@ -2060,7 +2109,8 @@ void verify_params_across_processes( } std::vector param_size_vec{param_size_tensor}; - process_group->allgather(param_size_output_tensors, param_size_vec)->wait(); + ops::allgather(process_group, param_size_output_tensors, param_size_vec) + ->wait(); auto result_size_tensors = param_size_output_tensors.front(); for (size_t i = 0; i < world_size; ++i) { auto param_size_for_rank = result_size_tensors[i][0].item(); diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index b39d63a2eb90a..26fcc2aa088fe 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -34,6 +34,8 @@ #include #include +#include + namespace c10d { namespace detail { namespace { @@ -864,11 +866,11 @@ void SocketConnectOp::throwTimeoutError() const { void Socket::initialize() { #ifdef _WIN32 - static std::once_flag init_flag{}; + static c10::once_flag init_flag{}; // All processes that call socket functions on Windows must first initialize // the Winsock library. - std::call_once(init_flag, []() { + c10::call_once(init_flag, []() { WSADATA data{}; if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) { throw SocketError{"The initialization of Winsock has failed."}; diff --git a/torch/csrc/generic/utils.h b/torch/csrc/generic/utils.h deleted file mode 100644 index 538b265a226e2..0000000000000 --- a/torch/csrc/generic/utils.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "torch/csrc/generic/utils.h" -#else - -#if defined(TH_REAL_IS_HALF) -#define GENERATE_SPARSE 0 -#else -#define GENERATE_SPARSE 1 -#endif - -struct THPStorage; -struct THSPTensor; - -typedef class THPPointer THPStoragePtr; - -#if (!defined(THC_GENERIC_FILE)) && (!defined(THQUANTIZED)) -template <> -struct THPUtils_typeTraits { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \ - defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || \ - defined(THC_REAL_IS_HALF) - static constexpr const char* python_type_str = "float"; -#elif defined(TH_REAL_IS_COMPLEX) || defined(THC_REAL_IS_COMPLEX) - static constexpr const char* python_type_str = "complex"; -#else - static constexpr const char* python_type_str = "int"; -#endif -}; -#endif - -#undef GENERATE_SPARSE - -#endif diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp index 1f29148d72f5d..d11532890a29e 100644 --- a/torch/csrc/init_flatbuffer_module.cpp +++ b/torch/csrc/init_flatbuffer_module.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include // NOLINT #include diff --git a/torch/csrc/itt.cpp b/torch/csrc/itt.cpp new file mode 100644 index 0000000000000..e434ad09b2d98 --- /dev/null +++ b/torch/csrc/itt.cpp @@ -0,0 +1,15 @@ +#include +#include + +namespace torch { +namespace profiler { +void initIttBindings(PyObject* module) { + auto m = py::handle(module).cast(); + + auto itt = m.def_submodule("_itt", "VTune ITT bindings"); + itt.def("rangePush", itt_range_push); + itt.def("rangePop", itt_range_pop); + itt.def("mark", itt_mark); +} +} // namespace profiler +} // namespace torch diff --git a/torch/csrc/itt_wrapper.cpp b/torch/csrc/itt_wrapper.cpp new file mode 100644 index 0000000000000..a268d997c4900 --- /dev/null +++ b/torch/csrc/itt_wrapper.cpp @@ -0,0 +1,23 @@ +#include +#include + +namespace torch { +namespace profiler { +__itt_domain* _itt_domain = __itt_domain_create("PyTorch"); + +TORCH_API void itt_range_push(const char* msg) { + __itt_string_handle* hsMsg = __itt_string_handle_create(msg); + __itt_task_begin(_itt_domain, __itt_null, __itt_null, hsMsg); +} + +TORCH_API void itt_range_pop() { + __itt_task_end(_itt_domain); +} + +TORCH_API void itt_mark(const char* msg) { + __itt_string_handle* hsMsg = __itt_string_handle_create(msg); + __itt_task_begin(_itt_domain, __itt_null, __itt_null, hsMsg); + __itt_task_end(_itt_domain); +} +} // namespace profiler +} // namespace torch diff --git a/torch/csrc/itt_wrapper.h b/torch/csrc/itt_wrapper.h new file mode 100644 index 0000000000000..7460b47932b84 --- /dev/null +++ b/torch/csrc/itt_wrapper.h @@ -0,0 +1,12 @@ +#ifndef PROFILER_ITT_H +#define PROFILER_ITT_H + +namespace torch { +namespace profiler { +void itt_range_push(const char* msg); +void itt_range_pop(); +void itt_mark(const char* msg); +} // namespace profiler +} // namespace torch + +#endif // PROFILER_ITT_H diff --git a/torch/csrc/jit/JIT-AUTOCAST.md b/torch/csrc/jit/JIT-AUTOCAST.md index 00e66b77b14fb..c8e7ffac5d443 100644 --- a/torch/csrc/jit/JIT-AUTOCAST.md +++ b/torch/csrc/jit/JIT-AUTOCAST.md @@ -17,6 +17,7 @@ - [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast) - [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced) - [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script) + - [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast) - [References](#references) @@ -169,6 +170,25 @@ def traced(a, b): torch.jit.trace(traced, (x, y)) ``` +#### Disabling eager autocast with scripted autocast + +If eager-mode autocast is enabled and we try to disable autocasting from +within a scripted function, autocasting will still occur. + +```python +@torch.jit.script +def fn(a, b): + with autocast(enabled=False): + return torch.mm(a, b) + +x = torch.rand((2, 2), device='cuda', dtype=torch.float) +y = torch.rand((2, 2), device='cuda', dtype=torch.float) + +# this will print half-precision dtype +with autocast(enabled=True): + print(fn(x, y).dtype) +``` + ## References - [torch.cuda.amp Package][1] diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 553cdc9c16986..e93d047d63117 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -58,6 +58,8 @@ Sections start with a reference to the source file where the code related to the - [Derivative Preserving Optimization](#derivative-preserving-optimization) - [Post-derivative optimization](#post-derivative-optimization) - [Derivate Splitting](#derivate-splitting) + - [Fusers](#fusers) + - [Disabling Optimizations](#disabling-optimizations) - [JIT Logging](#jit-logging) - [JIT Optimization Limiter](#jit-optimization-limiter) - [DifferentiableGraphOp](#differentiablegraphop) @@ -871,7 +873,7 @@ graph(%x : Tensor, [runtime/graph_executor.cpp](runtime/graph_executor.cpp) -All program execution starts with a graph executor. Its responsible for running optimizations (potentially involving the JIT-compilation of fused kernel code), and then handing the `Graph` or subcomponents of it off to an interpreter to actually run. +All program execution starts with a graph executor. It's responsible for running optimizations (potentially involving the JIT-compilation of fused kernel code), and then handing the `Graph` or subcomponents of it off to an interpreter to actually run. In this section, we use a running example program that computes one step of an LSTM to show how the graph is transformed: @@ -1166,6 +1168,59 @@ with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *), return (%hy, %cy) ``` +### Fusers ### + +As mentioned in the [Post-derivative optimization](#post-derivative-optimization) section, one of the +available optimizations is _fusion_, which merges operator kernels and compiles new kernels. Fusion +has two benefits: first, it reduces dispatcher overhead by combining multiple operator calls into a +single call to the fused kernel; and second, on GPU it can reduce the number of reads and writes to +global GPU memory, which can be a significant portion of the runtime for pointwise operators. + +The current default fuser on NVIDIA GPUs is +[NVFuser](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/README.md), while other use cases use +[NNC](https://github.com/pytorch/pytorch/tree/master/torch/csrc/jit/tensorexpr) as a fuser. + +Since fusers rely on specialized information that is only available at runtime - such as dtype, +device, and shape - they are only applied after the first invocation of a torchscript function or +module. As a result, the first invocation of a torchscript function can sometimes behave slightly +differently from subsequent invocations. + +To enable/disable different fusers, refer to the settings below. These settings apply globally in +the process in which they are set. Different fusers may excel in different scenarios, and disabling +or switching the fuser could also provide a temporary fix in case of bugs. + +**Python APIs:** + + +| Feature | Python API | +|---|---| +| NNC enable/disable | `torch._C._jit_set_texpr_fuser_enabled()` | +| NNC on CPU | `torch._C._jit_override_can_fuse_on_cpu()` | +| NNC on GPU | `torch._C._jit_override_can_fuse_on_gpu()` | +| NNC context manager | `with torch.jit.fuser("fuser1"):` | +| NVFuser enable/disable | `torch._C._jit_set_nvfuser_enabled()` | +| NVFuser context manager | `with torch.jit.fuser("fuser2")` | + +**C++ APIs:** + +| Feature | C++ API | Header file | +|---|---|---| +| NNC enable/disable | `torch::jit::setTensorExprFuserEnabled(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/passes/tensorexpr_fuser.h#L22) | +| NNC on CPU | `torch::jit::overrideCanFuseOnCPU(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/fuser/interface.h#L28-L29) | +| NNC on GPU | `torch::jit::overrideCanFuseOnGPU(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/fuser/interface.h#L28-L29) | +| NVFuser enable/disable | `torch::jit::fuser::cuda::setEnabled(bool);` | [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/codegen/cuda/interface.h#L56) | + +### Disabling Optimizations ### + +To completely disable the runtime optimizations and only run the minimum optimizations necessary, +the following commands can be used to globally (in a process) disable the majority of runtime +optimizations. This will disable JIT autodiff (instead it will rely on the default autograd +implementation provided in eager mode) as well as the fusers and some other runtime optimizations. + +* Python: `torch._C._get_graph_executor_optimize(False)` +* C++: `torch::jit::setGraphExecutorOptimize(false);` +* C++ header: [here](https://github.com/pytorch/pytorch/blob/1a7e560adecb0192f69f4d05b990800b60dc380b/torch/csrc/jit/python/update_graph_executor_opt.h#L5) + ## JIT Logging ## [jit_log.h](jit_log.h) diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index d9dabf73e8058..308857123d25d 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/backends/coreml/cpp/context.cpp b/torch/csrc/jit/backends/coreml/cpp/context.cpp index f615c481ec114..3c63acce71134 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.cpp +++ b/torch/csrc/jit/backends/coreml/cpp/context.cpp @@ -12,10 +12,6 @@ BackendRegistrar::BackendRegistrar(ContextInterface* ctx) { g_coreml_ctx_registry.store(ctx); } -bool isCoreMLAvailable() { - auto p = g_coreml_ctx_registry.load(); - return p ? p->isCoreMLAvailable() : false; -} void setModelCacheDirectory(std::string path) { auto p = g_coreml_ctx_registry.load(); if (p) { diff --git a/torch/csrc/jit/backends/coreml/cpp/context.h b/torch/csrc/jit/backends/coreml/cpp/context.h index 4635c7f855fc6..644e3428af3e1 100644 --- a/torch/csrc/jit/backends/coreml/cpp/context.h +++ b/torch/csrc/jit/backends/coreml/cpp/context.h @@ -11,7 +11,6 @@ namespace coreml { struct ContextInterface { virtual ~ContextInterface() = default; - virtual bool isCoreMLAvailable() const = 0; virtual void setModelCacheDirectory(std::string path) = 0; }; @@ -20,7 +19,6 @@ class BackendRegistrar { explicit BackendRegistrar(ContextInterface* ctx); }; -bool isCoremlAvailable(); void setModelCacheDirectory(std::string path); } // namespace coreml diff --git a/torch/csrc/jit/backends/coreml/cpp/preprocess.cpp b/torch/csrc/jit/backends/coreml/cpp/preprocess.cpp index 39d86b55e5b8c..b124e54798898 100644 --- a/torch/csrc/jit/backends/coreml/cpp/preprocess.cpp +++ b/torch/csrc/jit/backends/coreml/cpp/preprocess.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace py = pybind11; diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm index e395326e28caf..ac7fe9febe41e 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm @@ -1,7 +1,11 @@ -#include -#include -#include -#include +#import +#import +#import +#import +#import +#import +#import +#import #import @@ -16,253 +20,200 @@ namespace mobile { namespace coreml { -static constexpr int SUPPORTED_COREML_VER = 4; +using c10::impl::GenericDict; +using c10::impl::GenericList; +using c10::IValue; -enum TensorType { - Float, - Double, - Int, - Long, - Undefined, +static const int32_t kSampleThreshold = static_cast(1.0 / 1000.0 * static_cast(RAND_MAX)); +static const int32_t kSampleEvery = 500; + +struct CoreMLConfig { + std::string backend = "CPU"; + bool allow_low_precision = true; }; -static inline c10::ScalarType scalarType(TensorType type) { - switch (type) { - case TensorType::Float: - return c10::ScalarType::Float; - case TensorType::Double: - return c10::ScalarType::Double; - case TensorType::Int: - return c10::ScalarType::Int; - case TensorType::Long: - return c10::ScalarType::Long; - case TensorType::Undefined: - return c10::ScalarType::Undefined; - default: - return c10::ScalarType::Undefined; +bool type_validity(const std::vector& specs) { + for (const TensorSpec& spec : specs) { + if (spec.dtype != c10::ScalarType::Float) { + return false; + } } + return true; } -static id parse(NSString* jsonStr) { - NSData* data = [jsonStr dataUsingEncoding:NSUTF8StringEncoding]; - NSError* error = nil; - id result = [NSJSONSerialization JSONObjectWithData:data - options:0 - error:&error]; - if (error || !result) { - TORCH_CHECK( - false, - "parsing JSON string failed!", - error.localizedDescription.UTF8String); - } - - return result; +void from_json(const nlohmann::json& j, TensorSpec& spec) { + j[0].get_to(spec.name); + std::string type_string; + j[1].get_to(type_string); + spec.dtype = scalar_type(type_string); } -struct TensorSpec { - public: - TensorSpec() = delete; - TensorSpec(NSArray* spec) { - TORCH_CHECK(spec.count == 3); - name_ = spec[0]; - dtype_ = (TensorType)spec[1].intValue; - } - NSString* name() { - return name_; +void from_json(const nlohmann::json& j, CoreMLConfig& config) { + j.at("backend").get_to(config.backend); + std::string allow_low_precision_string; + j.at("allow_low_precision").get_to(allow_low_precision_string); + if (allow_low_precision_string == "True") { + config.allow_low_precision = true; + } else { + config.allow_low_precision = false; } - TensorType dtype() { - return dtype_; - } - - private: - NSString* name_ = @""; - TensorType dtype_ = TensorType::Float; -}; +} -struct CoreMLConfig { - public: - CoreMLConfig() = delete; - CoreMLConfig(NSDictionary* dict) - : coreMLVersion_([dict[@"spec_ver"] intValue]), - backend_([dict[@"backend"] lowercaseString]), - allow_low_precision_([dict[@"allow_low_precision"] boolValue]) { - TORCH_CHECK( - coreMLVersion_ >= SUPPORTED_COREML_VER, - "Only Core ML version 4 or above are supported"); - } - int64_t coreMLVersion() const { - return coreMLVersion_; - } - NSString* backend() const { - return backend_; - } - bool allowLowPrecision() const { - return allow_low_precision_; - } +GenericList pack_outputs(const std::vector& output_specs, id outputProvider) { + c10::List outputs; + for (const TensorSpec& spec : output_specs) { + NSString *name = [NSString stringWithUTF8String:spec.name.c_str()]; + MLFeatureValue *val = [outputProvider featureValueForName:name]; + std::vector output_shape; + for (int i = 0; i < val.multiArrayValue.shape.count; ++i) { + output_shape.emplace_back(val.multiArrayValue.shape[i].integerValue); + } + auto tensor = at::empty(IntArrayRef(output_shape), spec.dtype); + int64_t count = val.multiArrayValue.count; + memcpy( + tensor.data_ptr(), + (float*)val.multiArrayValue.dataPointer, + count * sizeof(float)); + outputs.push_back(tensor); + } + return c10::impl::toList(outputs); +} - private: - int64_t coreMLVersion_ = SUPPORTED_COREML_VER; - NSString* backend_ = @"CPU"; - bool allow_low_precision_ = true; -}; +class CoreMLBackend: public torch::jit::PyTorchBackendInterface { -struct MetaData { public: - MetaData(NSDictionary* dict) - : torchVer_(dict[@"torch_ver"]), - coremltoolVer_(dict[@"coremltool_ver"]) {} - NSString* torchVer() const { - return torchVer_; - } - NSString* coremltoolVer() const { - return coremltoolVer_; - } - - private: - NSString* torchVer_ = @""; - NSString* coremltoolVer_ = @""; -}; + GenericDict compile(IValue processed, GenericDict method_compile_spec) override { + const c10::Dict model_dict = processed.toGenericDict(); + const std::string& extra = model_dict.at("extra").toStringRef(); + const std::string& model = model_dict.at("model").toStringRef(); + const std::string& sha256 = model_dict.at("hash").toStringRef(); + + const int32_t load_id = std::rand(); + const int32_t instance_key = std::rand(); + size_t mem_limit = 0; + + PTMCoreMLObserver *observer = coreMLObserverConfig().getCoreMLObserver(); + if (observer) { + mem_limit = observer->getRemainingMemory(); + observer->onEnterCompileModel(instance_key, load_id); + } -// Wrap the Objective-C executor into a C++ to be able to pack into IValue -struct API_AVAILABLE(ios(11.0), macos(10.13)) CoreMLExecutorWrapper - : public CustomClassHolder { - public: - CoreMLExecutorWrapper( - PTMCoreMLExecutor* executor, - std::vector& inputs, - std::vector& outputs, - CoreMLConfig config) - : executor_(executor), - inputs_(inputs), - outputs_(outputs), - config_(config) {} - c10::List execute(const c10::impl::GenericList& inputs) { - std::vector inputSpecs; - std::vector outputSpecs; - int inputSpecIndex = 0; - // pack the inputs - for (int i = 0; i < inputs.size(); ++i) { - auto val = inputs.get(i); - if (val.isTuple()) { - auto& tuples = val.toTupleRef().elements(); - for (auto& ival : tuples) { - TORCH_CHECK(ival.isTensor()); - auto tensor = ival.toTensor(); - PTMCoreMLFeatureSpecs spec{ - .name = inputs_[inputSpecIndex].name(), - .tensor = tensor, - }; - inputSpecs.emplace_back(spec); - ++inputSpecIndex; - } - } else { - TORCH_CHECK(val.isTensor()); - auto tensor = val.toTensor(); - PTMCoreMLFeatureSpecs spec{ - .name = inputs_[inputSpecIndex].name(), - .tensor = tensor, - }; - inputSpecs.emplace_back(spec); - ++inputSpecIndex; + CoreMLConfig config; + std::vector input_specs; + std::vector output_specs; + + try { + nlohmann::json extra_json = nlohmann::json::parse(extra); + config = extra_json["config"].get(); + input_specs = extra_json["inputs"].get>(); + output_specs = extra_json["outputs"].get>(); + } catch (std::exception& exn) { + if (observer) { + observer->onExitCompileModel(instance_key, false, true); } + TORCH_CHECK(false, "Parsing model dict failed!"); } - // pack the outputs - c10::List outputs; - id results = [executor_ forwardWithInputs:inputSpecs]; - for (auto& spec : outputs_) { - MLFeatureValue* val = [results featureValueForName:spec.name()]; - TORCH_CHECK(val.multiArrayValue); - // Currently, only Float type is supported - TORCH_CHECK(val.multiArrayValue.dataType == MLMultiArrayDataTypeFloat32); - std::vector outputShape; - for (int i = 0; i < val.multiArrayValue.shape.count; ++i) { - outputShape.emplace_back(val.multiArrayValue.shape[i].integerValue); + + if (!type_validity(input_specs) || !type_validity(output_specs)) { + if (observer) { + observer->onExitCompileModel(instance_key, false, true); } - auto tensor = - at::empty(IntArrayRef(outputShape), scalarType(spec.dtype())); - int64_t count = val.multiArrayValue.count; - memcpy( - tensor.data_ptr(), - (float*)val.multiArrayValue.dataPointer, - count * sizeof(float)); - outputs.push_back(tensor); + TORCH_CHECK(false, "Compiling model failed, only float type tensors supported"); } - return outputs; - } - private: - PTMCoreMLExecutor* executor_ = nullptr; - std::vector inputs_; - std::vector outputs_; - CoreMLConfig config_; -}; + NSURL *modelURL = [PTMCoreMLCompiler compileModel:model modelID:sha256]; + MLModel *cpuModel = modelURL ? [PTMCoreMLCompiler loadCPUModelAtURL:modelURL] : nil; -class API_AVAILABLE(ios(11.0), macos(10.13)) CoreMLBackend - : public torch::jit::PyTorchBackendInterface { - public: - c10::impl::GenericDict compile( - c10::IValue processed, - c10::impl::GenericDict method_compile_spec) override { - auto modelDict = processed.toGenericDict(); - NSString* specs = [[NSString alloc] - initWithCString:modelDict.at("extra").toStringRef().c_str() - encoding:NSUTF8StringEncoding]; - NSDictionary* dict = parse(specs); - NSArray* inputs = dict[@"inputs"]; - NSArray* outputs = dict[@"outputs"]; - std::vector inputSpecs, outputSpecs; - for (NSArray* input in inputs) { - inputSpecs.emplace_back(TensorSpec(input)); + if (!cpuModel) { + if (observer) { + observer->onExitCompileModel(instance_key, false, true); + } + TORCH_CHECK(false, "Compiling MLModel for CPU failed!"); } - for (NSArray* output in outputs) { - outputSpecs.emplace_back(TensorSpec(output)); + + NSMutableArray *orderedFeatures = [NSMutableArray array]; + for (TensorSpec& spec : input_specs) { + NSString *name = [NSString stringWithUTF8String:spec.name.c_str()]; + [orderedFeatures addObject:name]; + } + + PTMCoreMLExecutor *executor = [[PTMCoreMLExecutor alloc] initWithFeatureNames:orderedFeatures]; + executor.model = cpuModel; + [executor autorelease]; + + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + MLModel *configuredModel = [PTMCoreMLCompiler loadModelAtURL:modelURL backend:config.backend allowLowPrecision:config.allow_low_precision]; + executor.model = configuredModel ?: cpuModel; + }); + + if (observer) { + bool should_log = load_id < kSampleThreshold; + observer->onExitCompileModel(instance_key, true, should_log); } - auto config = CoreMLConfig(dict[@"config"]); - const std::string& model = modelDict.at("model").toStringRef(); - const std::string& sha256 = modelDict.at("hash").toStringRef(); - PTMCoreMLExecutor* executor = [PTMCoreMLExecutor new]; - executor.backend = config.backend(); - executor.allowLowPrecision = config.allowLowPrecision(); - executor.coreMLVersion = config.coreMLVersion(); - bool result = [executor compileMLModel:model identifier:sha256]; - TORCH_CHECK(result, "Compiling MLModel failed!"); - auto executorWrapper = c10::make_intrusive( - executor, inputSpecs, outputSpecs, config); - auto handle = IValue::make_capsule(executorWrapper); + + MLModelWrapper model_wrapper = MLModelWrapper(executor); + model_wrapper.outputs = output_specs; + model_wrapper.load_id = load_id; + model_wrapper.mem_limit = mem_limit; + + auto model_wrapper_ptr = c10::make_intrusive(model_wrapper); + auto handle = IValue::make_capsule(model_wrapper_ptr); + c10::Dict ret(StringType::get(), c10::AnyType::get()); ret.insert("forward", handle); return c10::impl::toGenericDict(ret); } - c10::impl::GenericList execute( - c10::IValue handle, - c10::impl::GenericList inputs) override { - auto executor = c10::static_intrusive_pointer_cast( - handle.toCapsule()); - auto outputs = executor->execute(inputs); - return c10::impl::toList(outputs); + GenericList execute(IValue handle, GenericList inputs) override { + const auto model_wrapper = c10::static_intrusive_pointer_cast(handle.toCapsule()); + const int32_t instance_key = std::rand(); + const int32_t load_id = model_wrapper->load_id; + const size_t mem_limit = model_wrapper->mem_limit; + int32_t inferences = model_wrapper->inferences; + + PTMCoreMLObserver *observer = coreMLObserverConfig().getCoreMLObserver(); + if (observer) { + observer->onEnterExecuteModel(instance_key, load_id, mem_limit, inferences); + } + + PTMCoreMLExecutor *executor = model_wrapper->executor; + [executor setInputs:inputs]; + + id outputsProvider = [executor forward]; + + model_wrapper->inferences = ++inferences; + + if (observer) { + // Check if this inference session is logged. If so, log every N inferences + bool succeeded = outputsProvider != nil; + bool should_log = load_id < kSampleThreshold && inferences > 1; + should_log = !succeeded || (should_log && (inferences % kSampleEvery == 0)); + observer->onExitExecuteModel(instance_key, inferences, succeeded, should_log); + } + + return pack_outputs(model_wrapper->outputs, outputsProvider); } + bool is_available() override { - return [PTMCoreMLExecutor isAvailable]; +#if TARGET_OS_IPHONE + return [UIDevice currentDevice].systemVersion.floatValue >= 12.0; +#elif TARGET_OS_MAC + NSOperatingSystemVersion supportedVer = {10, 13, 0}; + return [[NSProcessInfo processInfo] isOperatingSystemAtLeastVersion:supportedVer]; +#endif + return false; } }; -struct API_AVAILABLE(ios(11.0), macos(10.13)) ContextImpl - : public ContextInterface { - bool isCoreMLAvailable() const override { - return [PTMCoreMLExecutor isAvailable]; - } +static auto cls = torch::jit::backend("coreml"); + +struct PTMCoreMLContext : public ContextInterface { void setModelCacheDirectory(std::string dir) override { - [PTMCoreMLExecutor - setModelCacheDirectory:[NSString stringWithCString:dir.c_str()]]; + [PTMCoreMLCompiler setModelCacheDirectory:dir]; } }; -API_AVAILABLE(ios(11.0), macos(10.13)) -static auto cls = torch::jit::backend("coreml"); - -API_AVAILABLE(ios(11.0), macos(10.13)) -static BackendRegistrar g_coreml_backend(new ContextImpl()); +static BackendRegistrar g_coreml_backend(new PTMCoreMLContext()); } // namespace } diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h new file mode 100644 index 0000000000000..39a511c316e69 --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h @@ -0,0 +1,24 @@ +#import + +#include + +NS_ASSUME_NONNULL_BEGIN + +@interface PTMCoreMLCompiler : NSObject + ++ (void)setModelCacheDirectory:(const std::string&)dir; + ++ (NSString*)modelCacheDirectory; + ++ (NSURL*)compileModel:(const std::string&)modelSpecs + modelID:(const std::string&)modelID; + ++ (nullable MLModel*)loadCPUModelAtURL:(NSURL*)modelURL; + ++ (nullable MLModel*)loadModelAtURL:(NSURL*)modelURL + backend:(const std::string&)backend + allowLowPrecision:(BOOL)allowLowPrecision; + +@end + +NS_ASSUME_NONNULL_END diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm new file mode 100644 index 0000000000000..98096ad79075a --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm @@ -0,0 +1,160 @@ +#import + +#if TARGET_OS_IPHONE +#import +#endif + +@implementation PTMCoreMLCompiler + +static NSString* gModelCacheDirectory = @""; + ++ (void)setModelCacheDirectory:(const std::string&)dir { + gModelCacheDirectory = [NSString stringWithCString:dir.c_str()]; +} + ++ (nonnull NSString *)modelCacheDirectory { + BOOL isSet = gModelCacheDirectory.length != 0; + BOOL isWriteable = isSet && [[NSFileManager defaultManager] isWritableFileAtPath:gModelCacheDirectory]; + if (!isSet || !isWriteable) { + // set the default directory to tmp + gModelCacheDirectory = NSTemporaryDirectory(); + } + return gModelCacheDirectory; +} + ++ (NSURL*)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID { + NSString* modelName = [NSString stringWithCString:modelID.c_str() encoding:NSUTF8StringEncoding]; + NSURL* modelPath = [PTMCoreMLCompiler _cacheFilePath:modelName]; + NSURL* compiledModelPath = [PTMCoreMLCompiler _compiledModelFilePath:modelPath.path]; + + BOOL modelCached = [[NSFileManager defaultManager] fileExistsAtPath:modelPath.path]; + BOOL compiledModelCached = [[NSFileManager defaultManager] fileExistsAtPath:compiledModelPath.path]; + BOOL shouldRecompile = [self _shouldRecompileModel:compiledModelPath]; + + if (modelCached != compiledModelCached) { + modelCached = NO; + compiledModelCached = NO; + [PTMCoreMLCompiler _cleanupModel:modelPath compiledModel:compiledModelPath]; + } + + if (!modelCached) { + // Note that the serialized protobuf binary contains bytes not text. + // https://developers.google.com/protocol-buffers/docs/pythontutorial#parsing-and-serialization + NSData* data = [NSData dataWithBytes:modelSpecs.c_str() length:modelSpecs.length()]; + if (![data writeToFile:modelPath.path atomically:YES]) { + // If the model cannot be persisted on disk then compilation cannot proceed. + NSLog(@"Failed to save specs for MLModel!"); + [PTMCoreMLCompiler _cleanupModel:modelPath compiledModel:compiledModelPath]; + return nil; + } + } + + if (shouldRecompile || !compiledModelCached) { + NSError *error; + NSURL *temporaryURL = [MLModel compileModelAtURL:modelPath error:&error]; + if (!error) { + [PTMCoreMLCompiler _moveFileToCache:temporaryURL cacheURL:compiledModelPath error:&error]; + } + if (error) { + NSLog(@"Failed to compile MLModel!"); + [PTMCoreMLCompiler _cleanupModel:modelPath compiledModel:compiledModelPath]; + return nil; + } + } + + return compiledModelPath; +} + ++ (nullable MLModel*)loadCPUModelAtURL:(NSURL*)modelURL { + NSError *error; + MLModel *model; + if (@available(iOS 12.0, macOS 10.14, *)) { + MLModelConfiguration* config = [[MLModelConfiguration alloc] init]; + config.computeUnits = MLComputeUnitsCPUOnly; + model = [MLModel modelWithContentsOfURL:modelURL configuration:config error:&error]; + } else { + model = [MLModel modelWithContentsOfURL:modelURL error:&error]; + } + if (error) { + NSLog(@"Failed to initialize MLModel!"); + [PTMCoreMLCompiler _cleanupModel:nil compiledModel:modelURL]; + return nil; + } + return model; +} + ++ (nullable MLModel*)loadModelAtURL:(NSURL*)modelURL backend:(const std::string&)backend allowLowPrecision:(BOOL)allowLowPrecision { + NSError *error; + MLModel *model; + if (@available(iOS 12.0, macOS 10.14, *)) { + MLModelConfiguration* config = [[MLModelConfiguration alloc] init]; + MLComputeUnits computeUnits = MLComputeUnitsCPUOnly; + if (backend == "cpuandgpu") { + computeUnits = MLComputeUnitsCPUAndGPU; + } else if (backend == "all") { + computeUnits = MLComputeUnitsAll; + } + config.computeUnits = computeUnits; + config.allowLowPrecisionAccumulationOnGPU = allowLowPrecision; + model = [MLModel modelWithContentsOfURL:modelURL configuration:config error:&error]; + } else { + model = [MLModel modelWithContentsOfURL:modelURL error:&error]; + } + if (error) { + NSLog(@"Failed to initialize MLModel!"); + [PTMCoreMLCompiler _cleanupModel:nil compiledModel:modelURL]; + return nil; + } + return model; +} + ++ (void)_cleanupModel:(NSURL*)modelPath compiledModel:(NSURL*)compiledModelPath { + NSFileManager* fileManager = [NSFileManager defaultManager]; + NSError* error = nil; + if (modelPath && [fileManager fileExistsAtPath:modelPath.path]) { + [fileManager removeItemAtPath:modelPath.path error:&error]; + } + if (compiledModelPath && [fileManager fileExistsAtPath:compiledModelPath.path]) { + [fileManager removeItemAtPath:compiledModelPath.path error:&error]; + } +} + ++ (void)_moveFileToCache:(NSURL *)fileURL cacheURL:(NSURL *)cacheURL error:(NSError **)error { + if ([fileURL isEqual:cacheURL]) { + return; + } + NSFileManager *fileManager = [NSFileManager defaultManager]; + if ([fileManager fileExistsAtPath:cacheURL.path]) { + [fileManager removeItemAtURL:cacheURL error:error]; + } + [fileManager moveItemAtURL:fileURL toURL:cacheURL error:error]; +} + ++ (BOOL)_shouldRecompileModel:(NSURL *)compiledModelPath { +#if TARGET_OS_IPHONE + NSString *versionPath = [PTMCoreMLCompiler _cacheFilePath:@"version"].path; + NSString *cachedOSVer = nil; + if ([[NSFileManager defaultManager] fileExistsAtPath:versionPath]) { + NSError *error = nil; + cachedOSVer = [NSString stringWithContentsOfFile:versionPath encoding:NSUTF8StringEncoding error:&error]; + } + // Compile the model when OS version changes + NSString *currentOSVer = [UIDevice currentDevice].systemVersion; + [currentOSVer writeToFile:versionPath atomically:YES]; + return ![currentOSVer isEqualToString:cachedOSVer]; +#else + return YES; +#endif +} + ++ (NSURL *)_cacheFilePath:(NSString *)fileName { + NSString *filePath = [[PTMCoreMLCompiler modelCacheDirectory] stringByAppendingPathComponent:fileName]; + return [NSURL fileURLWithPath:filePath]; +} + ++ (NSURL *)_compiledModelFilePath:(NSString *)modelPath { + NSString *filePath = [modelPath stringByAppendingString:@".mlmodelc"]; + return [NSURL fileURLWithPath:filePath]; +} + +@end diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h index 03e1c6c4cbe2c..d38a37bf6f2f3 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h @@ -1,36 +1,19 @@ -#import -#include - -#include -#include +#import -struct PTMCoreMLFeatureSpecs { - NSString* name; - at::Tensor tensor; -}; +#import -API_AVAILABLE(ios(11.0), macos(10.13)) -@interface PTMCoreMLFeatureProvider : NSObject -- (instancetype)initWithFeatureSpecs: - (const std::vector&)specs - CoreMLVersion:(NSUInteger)ver; -@end +NS_ASSUME_NONNULL_BEGIN -API_AVAILABLE(ios(11.0), macos(10.13)) @interface PTMCoreMLExecutor : NSObject -@property(nonatomic, copy) NSString* backend; -@property(nonatomic, assign) BOOL allowLowPrecision; -@property(nonatomic, assign) NSUInteger coreMLVersion; +@property(atomic, strong) MLModel* model; + +- (instancetype)initWithFeatureNames:(NSArray*)featureNames; -+ (BOOL)isAvailable; -+ (void)setModelCacheDirectory:(NSString*)dir; -+ (NSString*)modelCacheDirectory; +- (void)setInputs:(c10::impl::GenericList)inputs; -- (BOOL)compileMLModel:(const std::string&)modelSpecs - identifier:(const std::string&)identifier; -- (id)forwardWithInputs: - (const std::vector&)inputs; -- (BOOL)cleanup; +- (id)forward; @end + +NS_ASSUME_NONNULL_END diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm index fbb7abe87b524..b393cebd52169 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm @@ -1,346 +1,49 @@ -#include -#include +#import #import -#if TARGET_OS_IPHONE -#import -#endif - -// Observer -#import - -#include -#include -#include - -// This is a utility macro that can be used to throw an exception when a CoreML -// API function produces a NSError. The exception will contain a message with -// useful info extracted from the NSError. -#define COREML_THROW_IF_ERROR(error, preamble) \ - do { \ - if C10_LIKELY(error) { \ - throw c10::Error( \ - {__func__, __FILE__, static_cast(__LINE__)}, \ - c10::str( \ - preamble, \ - " Error details: ", \ - " Localized_description: ", error.localizedDescription.UTF8String, \ - " Domain: ", error.domain.UTF8String, \ - " Code: ", error.code, \ - " User Info: ", error.userInfo.description.UTF8String)); \ - } \ - } while (false) - -@implementation PTMCoreMLFeatureProvider { - NSUInteger _coremlVersion; - std::vector _specs; -} - -@synthesize featureNames = _featureNames; - -- (instancetype)initWithFeatureSpecs: - (const std::vector&)specs - CoreMLVersion:(NSUInteger)ver { - self = [super init]; - if (self) { - _coremlVersion = ver; - _specs = specs; - NSMutableArray* names = [NSMutableArray new]; - for (auto& spec : _specs) { - [names addObject:spec.name]; - } - _featureNames = [[NSSet alloc] initWithArray:names]; - } - return self; -} - -- (nullable MLFeatureValue*)featureValueForName:(NSString*)featureName { - for (auto& spec : _specs) { - if ([spec.name isEqualToString:featureName]) { - NSMutableArray* shape = [NSMutableArray new]; - for (auto& dim : spec.tensor.sizes().vec()) { - [shape addObject:@(dim)]; - } - NSMutableArray* strides = [NSMutableArray new]; - for (auto& step : spec.tensor.strides().vec()) { - [strides addObject:@(step)]; - } - NSError* error = nil; - TORCH_CHECK(spec.tensor.dtype() == c10::kFloat); - MLMultiArray* mlArray = [[MLMultiArray alloc] - initWithDataPointer:spec.tensor.data_ptr() - shape:shape - dataType:MLMultiArrayDataTypeFloat32 - strides:strides - deallocator:(^(void* bytes){ - })error:&error]; - return [MLFeatureValue featureValueWithMultiArray:mlArray]; - } - } - return nil; -} - -@end - -static NSString* gModelCacheDirectory = @""; - @implementation PTMCoreMLExecutor { - MLModel* _mlModel; - NSURL* _modelPath; - NSURL* _compiledModelPath; - - int32_t _model_load_id; - int32_t _inferences; - - int32_t _sample_thresh; - int32_t _sample_every; - - size_t _init_mem_limit; -} - -+ (void)setModelCacheDirectory:(NSString*)dir { - gModelCacheDirectory = dir; -} - -+ (NSString*)modelCacheDirectory { - if (gModelCacheDirectory.length == 0 || - ![[NSFileManager defaultManager] - isWritableFileAtPath:gModelCacheDirectory]) { - // set the default directory to tmp - gModelCacheDirectory = NSTemporaryDirectory(); - } - return gModelCacheDirectory; -} - -+ (BOOL)isAvailable { -#if !defined(__APPLE__) - return false; -#elif TARGET_OS_IPHONE - if ([UIDevice currentDevice].systemVersion.floatValue > 14.0) { - return true; - } -#elif TARGET_OS_MAC - NSOperatingSystemVersion supportedVer = {10, 13, 0}; - if ([[NSProcessInfo processInfo] - isOperatingSystemAtLeastVersion:supportedVer]) { - return true; - } -#endif - return false; + NSArray *_featureNames; + PTMCoreMLFeatureProvider *_inputProvider; } -- (BOOL)compileMLModel:(const std::string&)modelSpecs - identifier:(const std::string&)identifier - API_AVAILABLE(ios(11.0), macos(10.13)) { - NSString* mlModelName = [NSString stringWithCString:identifier.c_str() - encoding:NSUTF8StringEncoding]; - _modelPath = [self _cacheFilePath:mlModelName]; - [self _saveModel:modelSpecs]; - NSError* error = nil; - _compiledModelPath = [self _compiledModelFilePath:_modelPath.path]; - - // Get observer and create an instance key - PTMCoreMLObserver* observer = coreMLObserverConfig().getCoreMLObserver(); - int32_t instance_key = std::rand(); - _model_load_id = std::rand(); - _inferences = 0; - - _init_mem_limit = 0; - - _sample_thresh = - static_cast(1.0 / 1000.0 * static_cast(RAND_MAX)); - _sample_every = 500; - - if (observer) { - _init_mem_limit = observer->getRemainingMemory(); - observer->onEnterCompileModel(instance_key, _model_load_id); - } - - // Compile the model when OS version changes - if ([self _shouldRecompileModel]) { - if (@available(iOS 11.0, macOS 10.13, *)) { - NSURL* temporaryFileURL = [MLModel compileModelAtURL:_modelPath - error:&error]; - if (!error) { - // move the model to the cache directory - NSFileManager* fileManager = [NSFileManager defaultManager]; - if (![temporaryFileURL isEqual:_compiledModelPath]) { - if ([fileManager fileExistsAtPath:_compiledModelPath.path]) { - [fileManager removeItemAtURL:_compiledModelPath error:&error]; - } - [fileManager moveItemAtURL:temporaryFileURL - toURL:_compiledModelPath - error:&error]; - } - } - } else { - // Always log on failure - if (observer) { - observer->onExitCompileModel(instance_key, false, true); - } - TORCH_CHECK(false, "CoreML is not available on your deivce"); - } - } - - if (error) { - // Always log on failure - if (observer) { - observer->onExitCompileModel(instance_key, false, true); - } - - // remove cached models if compalition failed. - [self cleanup]; - - COREML_THROW_IF_ERROR(error, "Error compiling the MLModel file!"); - return NO; - } - if (@available(iOS 12.0, macOS 10.14, *)) { - MLModelConfiguration* config = [MLModelConfiguration alloc]; - MLComputeUnits backend = MLComputeUnitsCPUOnly; - if ([self.backend isEqualToString:@"cpuandgpu"]) { - backend = MLComputeUnitsCPUAndGPU; - } else if ([self.backend isEqualToString:@"all"]) { - backend = MLComputeUnitsAll; - } - config.computeUnits = backend; - config.allowLowPrecisionAccumulationOnGPU = self.allowLowPrecision; - _mlModel = [MLModel modelWithContentsOfURL:_compiledModelPath - configuration:config - error:&error]; - } else { - _mlModel = [MLModel modelWithContentsOfURL:_compiledModelPath error:&error]; - } - if (error || !_mlModel) { - // Always log on failure - if (observer) { - observer->onExitCompileModel(instance_key, false, true); - } - - COREML_THROW_IF_ERROR(error, "Error loading the MLModel file!"); - } - - if (observer) { - bool should_log = _model_load_id < _sample_thresh; - observer->onExitCompileModel(instance_key, true, should_log); +- (instancetype)initWithFeatureNames:(NSArray *)featureNames { + if (self = [super init]) { + _featureNames = featureNames; + NSSet *featureNamesSet = [NSSet setWithArray:featureNames]; + _inputProvider = [[PTMCoreMLFeatureProvider alloc] initWithFeatureNames:featureNamesSet]; } - - return YES; + return self; } -- (id)forwardWithInputs: - (const std::vector&)inputs { - @autoreleasepool { - // Get observer and create an instance key - PTMCoreMLObserver* observer = coreMLObserverConfig().getCoreMLObserver(); - int32_t instance_key = std::rand(); - - if (observer) { - observer->onEnterExecuteModel( - instance_key, _model_load_id, _init_mem_limit, _inferences); - } - - NSError* error = nil; - PTMCoreMLFeatureProvider* inputFeature = [[PTMCoreMLFeatureProvider alloc] - initWithFeatureSpecs:inputs - CoreMLVersion:self.coreMLVersion]; - if (inputFeature == nil) { - return nil; - } - if (@available(iOS 11.0, macOS 10.13, *)) { - MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; - id outputFeature = - [_mlModel predictionFromFeatures:inputFeature - options:options - error:&error]; - - COREML_THROW_IF_ERROR(error, "Error running CoreML inference!"); +- (void)setInputs:(c10::impl::GenericList)inputs { + [_inputProvider clearInputTensors]; - ++_inferences; - if (observer) { - // Check if this inference session is being logged. - // If so, only log every N inferences - bool should_log = _model_load_id < _sample_thresh && _inferences > 1; - if (should_log) { - should_log = _inferences % _sample_every == 0; - } - observer->onExitExecuteModel( - instance_key, _inferences, true, should_log); + int input_count = 0; + for (int i = 0; i < inputs.size(); ++i) { + at::IValue val = inputs.get(i); + if (val.isTuple()) { + auto& tuples = val.toTupleRef().elements(); + for (auto& ival : tuples) { + [_inputProvider setInputTensor:ival.toTensor() forFeatureName:_featureNames[input_count]]; + input_count++; } - - return outputFeature; } else { - // Always log on failure - if (observer) { - observer->onExitExecuteModel(instance_key, _inferences, true, true); - } - - TORCH_CHECK(false, "Core ML is not available on your device"); - return nil; + [_inputProvider setInputTensor:val.toTensor() forFeatureName:_featureNames[input_count]]; + input_count++; } } } -- (BOOL)cleanup { - NSFileManager* fileManager = [NSFileManager defaultManager]; - NSError* error = nil; - NSString* modelPath = _modelPath.path; - NSString* compiledModelPath = _compiledModelPath.path; - if ([fileManager fileExistsAtPath:modelPath]) { - [fileManager removeItemAtPath:modelPath error:&error]; - } - if ([fileManager fileExistsAtPath:compiledModelPath]) { - [fileManager removeItemAtPath:compiledModelPath error:&error]; - } - return !error; -} - -- (void)_saveModel:(const std::string&)spec { - NSString* modelPath = _modelPath.path; - if (![[NSFileManager defaultManager] fileExistsAtPath:modelPath]) { - // Note that the serialized protobuf binary contains bytes, not text; - // see - // https://developers.google.com/protocol-buffers/docs/pythontutorial#parsing-and-serialization - NSData* data = [NSData dataWithBytes:spec.c_str() length:spec.length()]; - BOOL ret = [data writeToFile:modelPath atomically:YES]; - TORCH_CHECK(ret, "Error saving the MLModel", modelPath.UTF8String); - } -} - -- (BOOL)_shouldRecompileModel { -#if TARGET_OS_IPHONE - NSError* error = nil; - NSString* currentOSVer = [UIDevice currentDevice].systemVersion; - NSString* versionPath = [self _cacheFilePath:@"version"].path; - BOOL shouldRecompileModel = YES; - NSFileManager* fileManager = [NSFileManager defaultManager]; - if ([fileManager fileExistsAtPath:versionPath]) { - NSString* cachedOSVer = - [NSString stringWithContentsOfFile:versionPath - encoding:NSUTF8StringEncoding - error:&error]; - if ([cachedOSVer isEqualToString:currentOSVer]) { - if ([fileManager fileExistsAtPath:_compiledModelPath.path]) { - shouldRecompileModel = NO; - } - } +- (id)forward { + NSError *error; + MLPredictionOptions *options = [[MLPredictionOptions alloc] init]; + id outputs = [self.model predictionFromFeatures:_inputProvider options:options error:&error]; + if (error) { + NSLog(@"Prediction failed with error %@", error); + return nil; } - [currentOSVer writeToFile:versionPath atomically:YES]; - return shouldRecompileModel; -#else - return YES; -#endif -} - -- (NSURL*)_cacheFilePath:(NSString*)fileName { - NSString* filePath = [[[self class] modelCacheDirectory] - stringByAppendingPathComponent:fileName]; - return [NSURL fileURLWithPath:filePath]; -} - -- (NSURL*)_compiledModelFilePath:(NSString*)modelPath { - NSString* filePath = [modelPath stringByAppendingString:@".mlmodelc"]; - return [NSURL fileURLWithPath:filePath]; + return outputs; } @end diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h new file mode 100644 index 0000000000000..f0ccd7280b09d --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h @@ -0,0 +1,16 @@ +#import +#import + +NS_ASSUME_NONNULL_BEGIN + +@interface PTMCoreMLFeatureProvider : NSObject + +- (instancetype)initWithFeatureNames:(NSSet*)featureNames; + +- (void)clearInputTensors; + +- (void)setInputTensor:(const at::Tensor&)tensor forFeatureName:(NSString*)name; + +@end + +NS_ASSUME_NONNULL_END diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.mm new file mode 100644 index 0000000000000..27a2c5b4b2733 --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.mm @@ -0,0 +1,51 @@ +#import + +@implementation PTMCoreMLFeatureProvider { + NSMutableDictionary *_featureValuesForName; +} + +@synthesize featureNames = _featureNames; + +- (instancetype)initWithFeatureNames:(NSSet *)featureNames { + if (self = [super init]) { + _featureNames = featureNames; + _featureValuesForName = [NSMutableDictionary dictionary]; + } + return self; +} + +- (void)clearInputTensors { + [_featureValuesForName removeAllObjects]; +} + +- (void)setInputTensor:(const at::Tensor&)tensor forFeatureName:(NSString *)name { + NSMutableArray *shape = [NSMutableArray new]; + for (auto& dim : tensor.sizes().vec()) { + [shape addObject:@(dim)]; + } + + NSMutableArray *strides = [NSMutableArray new]; + for (auto& step : tensor.strides().vec()) { + [strides addObject:@(step)]; + } + + NSError* error = nil; + MLMultiArray *mlArray = + [[MLMultiArray alloc] + initWithDataPointer:tensor.data_ptr() + shape:shape + dataType:MLMultiArrayDataTypeFloat32 + strides:strides + deallocator:(^(void* bytes){}) + error:&error]; + MLFeatureValue *value = [MLFeatureValue featureValueWithMultiArray:mlArray]; + if (value) { + _featureValuesForName[name] = value; + } +} + +- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName { + return _featureValuesForName[featureName]; +} + +@end diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h new file mode 100644 index 0000000000000..49bad921cfa69 --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h @@ -0,0 +1,50 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace mobile { +namespace coreml { + +class MLModelWrapper : public CustomClassHolder { + public: + PTMCoreMLExecutor* executor; + std::vector outputs; + int32_t load_id = 0; + int32_t inferences = 0; + size_t mem_limit = 0; + + MLModelWrapper() = delete; + + MLModelWrapper(PTMCoreMLExecutor* executor) : executor(executor) { + [executor retain]; + } + + MLModelWrapper(const MLModelWrapper& oldObject) { + executor = oldObject.executor; + outputs = oldObject.outputs; + load_id = oldObject.load_id; + inferences = oldObject.inferences; + mem_limit = oldObject.mem_limit; + [executor retain]; + } + + MLModelWrapper(MLModelWrapper&& oldObject) { + executor = oldObject.executor; + outputs = oldObject.outputs; + load_id = oldObject.load_id; + inferences = oldObject.inferences; + mem_limit = oldObject.mem_limit; + [executor retain]; + } + + ~MLModelWrapper() { + [executor release]; + } +}; + +} +} +} +} diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h new file mode 100644 index 0000000000000..514629723047d --- /dev/null +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h @@ -0,0 +1,32 @@ +#include +#import + +#include + +namespace torch { +namespace jit { +namespace mobile { +namespace coreml { + +struct TensorSpec { + std::string name = ""; + c10::ScalarType dtype = c10::ScalarType::Float; +}; + +static inline c10::ScalarType scalar_type(const std::string& type_string) { + if (type_string == "0") { + return c10::ScalarType::Float; + } else if (type_string == "1") { + return c10::ScalarType::Double; + } else if (type_string == "2") { + return c10::ScalarType::Int; + } else if (type_string == "3") { + return c10::ScalarType::Long; + } + return c10::ScalarType::Undefined; +} + +} // namespace coreml +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp index a787ecc6cbfda..448d448f1057f 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/jit/codegen/cuda/README.md b/torch/csrc/jit/codegen/cuda/README.md index e4bfb8d700629..be8aed6c5ce44 100644 --- a/torch/csrc/jit/codegen/cuda/README.md +++ b/torch/csrc/jit/codegen/cuda/README.md @@ -214,3 +214,11 @@ There are three ways to disable nvfuser. Listed below with descending priorities - Force using NNC instead of nvfuser for GPU fusion with env variable `export PYTORCH_JIT_USE_NNC_NOT_NVFUSER=1`. - Disabling nvfuser with torch API `torch._C._jit_set_nvfuser_enabled(False)`. - Disable nvfuser with env variable `export PYTORCH_JIT_ENABLE_NVFUSER=0`. + +4. Is there any more knobs to tune nvfuser fusion? + +Some opt-out features in nvfuser are exposed via env var `PYTORCH_NVFUSER_DISABLE`. e.g. `fallback` to disable aten fallback during compilation failure and `fma` to disable fused multiply-add, you would set `export PYTORCH_NVFUSER_DISABLE="fallback,fma"`. Note that disabling fma would usually regress on performance so we strongly encourage to not disable it. + +There's also opt-in features via env var `PYTORCH_NVFUSER_ENABLE`. +- `complex` would enable complex floating type support in nvfuser (currently experimental and turned off by default to avoid functional regression); +- `linear_decomposition` enables decomposition of the bias add in linear layer. Similarly, `conv_decomposition` enables decomposition of the bias add in conv layer. In some small benchmark models, we noticed that such decompositions added more overhead in compilation that out-weighs the benefit of faster kernel. Hence we decided to change these to be opt-in instead. diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 790894d01b817..0e943931ad66d 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -528,6 +528,46 @@ TensorView* abs(TensorView* tv) { return abs(tv->as())->as(); } +// The output of real(complex_tensor) are real numbers +Val* real(Val* v) { + if (v->getDataType() == DataType::ComplexDouble) { + Val* out = newValLike(v, DataType::Double); + IrBuilder::create(UnaryOpType::Real, out, v); + return out; + } + if (v->getDataType() == DataType::ComplexFloat) { + Val* out = newValLike(v, DataType::Float); + IrBuilder::create(UnaryOpType::Real, out, v); + return out; + } + // We use UnaryOpType::Set instead of UnaryOpType::Real to support non-complex + // tensors + return unaryOp(UnaryOpType::Set, v); +} + +TensorView* real(TensorView* tv) { + return real(tv->as())->as(); +} + +// The output of imag(complex_tensor) are real numbers +Val* imag(Val* v) { + if (v->getDataType() == DataType::ComplexDouble) { + Val* out = newValLike(v, DataType::Double); + IrBuilder::create(UnaryOpType::Imag, out, v); + return out; + } + if (v->getDataType() == DataType::ComplexFloat) { + Val* out = newValLike(v, DataType::Float); + IrBuilder::create(UnaryOpType::Imag, out, v); + return out; + } + TORCH_CHECK(false, "imag not supported for non-complex tensors"); +} + +TensorView* imag(TensorView* tv) { + return imag(tv->as())->as(); +} + // UNARY FLOAT CAST OPERATIONS #define NVFUSER_DEFINE_UNARY_FLOAT_OP(op_name, op_type) \ @@ -1091,7 +1131,11 @@ TensorView* sum( TensorView* max( TensorView* v1, const std::vector& axes, - bool keep_dim /*=false*/) { + bool keep_dim /*=false*/, + DataType dtype /* DataType::Null */) { + TORCH_CHECK( + dtype == DataType::Null, + "A dtype other than Null is not currently supported."); Val* init = getMinimumValue(v1->getDataType().value()); TORCH_CHECK(init != nullptr, "Missing initial value"); return reductionOp(BinaryOpType::Max, axes, init, v1, keep_dim); @@ -1100,7 +1144,11 @@ TensorView* max( TensorView* min( TensorView* v1, const std::vector& axes, - bool keep_dim /*=false*/) { + bool keep_dim /*=false*/, + DataType dtype /* DataType::Null */) { + TORCH_CHECK( + dtype == DataType::Null, + "A dtype other than Null is not currently supported."); Val* init = getMaximumValue(v1->getDataType().value()); TORCH_CHECK(init != nullptr, "Missing initial value"); return reductionOp(BinaryOpType::Min, axes, init, v1, keep_dim); @@ -1522,13 +1570,14 @@ Val* where(Val* c, Val* v1, Val* v2) { "Condition should be of DataType Bool, not ", c->getDataType().value()); - auto cast_values = promoteValues(TypePromotion::default_op_config, {v1, v2}); + std::vector operands = {v1, v2}; + auto common_dtype = computeTypes(TypePromotion::default_op_config, operands); + auto cast_values = promoteValues(operands, common_dtype); v1 = cast_values[0]; v2 = cast_values[1]; TORCH_CHECK(c->getDataType().value() == DataType::Bool); - auto out_dtype = - promote_type(v1->getDataType().value(), v2->getDataType().value()); + auto out_dtype = common_dtype; auto out_vtype = promote_type(v1->getValType().value(), v2->getValType().value()); // Even when v1 and v2 are scalar, the output is a tensor if the diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index c6d0011c03049..7a1efee80f5dc 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -188,6 +188,9 @@ TORCH_CUDA_CU_API TensorView* neg(TensorView*); // randlike TORCH_CUDA_CU_API Val* randlike(Val*); TORCH_CUDA_CU_API TensorView* randlike(TensorView*); +// real +TORCH_CUDA_CU_API Val* real(Val*); +TORCH_CUDA_CU_API TensorView* real(TensorView*); // reciprocal TORCH_CUDA_CU_API Val* reciprocal(Val*); TORCH_CUDA_CU_API TensorView* reciprocal(TensorView*); @@ -227,6 +230,9 @@ TORCH_CUDA_CU_API TensorView* trunc(TensorView*); // bitwise_not TORCH_CUDA_CU_API Val* bitwise_not(Val*); TORCH_CUDA_CU_API TensorView* bitwise_not(TensorView*); +// imag +TORCH_CUDA_CU_API Val* imag(Val*); +TORCH_CUDA_CU_API TensorView* imag(TensorView*); // isfinite TORCH_CUDA_CU_API Val* isfinite(Val*); TORCH_CUDA_CU_API TensorView* isfinite(TensorView*); @@ -398,12 +404,14 @@ TORCH_CUDA_CU_API TensorView* sum( TORCH_CUDA_CU_API TensorView* max( TensorView* v1, const std::vector& reduction_axes, - bool keep_dim = false); + bool keep_dim = false, + DataType dtype = DataType::Null); TORCH_CUDA_CU_API TensorView* min( TensorView* v1, const std::vector& reduction_axes, - bool keep_dim = false); + bool keep_dim = false, + DataType dtype = DataType::Null); // COMPOUND OPERATIONS // add_alpha diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 78afd39b546dc..d4131d3646f15 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -777,7 +777,12 @@ kir::ExpressionEvaluator bindKernelInputs( "Something went wrong configuring launch. Inputs no longer match."); for (const auto dim : c10::irange(root_domain.size())) { - const auto extent = root_domain[dim]->extent(); + Val* extent = nullptr; + if (root_domain[dim]->hasExpandedExtent()) { + extent = root_domain[dim]->expandedExtent(); + } else { + extent = root_domain[dim]->extent(); + } const auto value = aten_tensor.sizes()[dim]; if (value == 0 && tensor_input->uses().empty()) { // If there's no uses, ignore there's a size-0 dimension. @@ -845,7 +850,12 @@ ExpressionEvaluator bindFusionInputs( aten_tensor.ndimension() == (int64_t)root_dom.size(), "Something went wrong configuring launch. Inputs do not match."); for (const auto dim : c10::irange(root_dom.size())) { - const auto extent = root_dom[dim]->extent(); + Val* extent = nullptr; + if (root_dom[dim]->hasExpandedExtent()) { + extent = root_dom[dim]->expandedExtent(); + } else { + extent = root_dom[dim]->extent(); + } const auto value = aten_tensor.sizes()[dim]; if (value == 0 && cg_tensor->uses().empty()) { // If there's no uses, ignore there's a size-0 dimension. diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index 0519da180aad9..569c3f9f4d6c5 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -2439,12 +2439,16 @@ void CudaFuseGraph(std::shared_ptr& graph) { GRAPH_DEBUG("Remove inplace operations: ", *graph); // TODO: separate passes into different file; - // TODO: restore decomposition after fusion, in case we are decomposing - // operation that can't be fused; - decomposeLinearOps(graph->block()); + if (isEnabled(EnableOption::LinearDecomposition)) { + // TODO: restore decomposition after fusion, in case we are decomposing + // operation that can't be fused; + decomposeLinearOps(graph->block()); + } GRAPH_DEBUG("After decompose Linear Ops by nvfuser: ", *graph); - decomposeConvOps(graph->block()); + if (isEnabled(EnableOption::ConvDecomposition)) { + decomposeConvOps(graph->block()); + } GRAPH_DEBUG("After decompose decompose Conv Ops by nvfuser: ", *graph); replaceAliasOpsWithCopy(graph, graph->block()); diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 0683d8e0cafd9..2c86938b9b5b3 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -34,7 +35,7 @@ static std::atomic cuda_fusion_guard_mode{true}; class NVFuserEnabler { private: c10::optional runtime_assigned_fuser_enabled_ = c10::nullopt; - std::once_flag enabled_check_flag_; + c10::once_flag enabled_check_flag_; std::mutex mutex_; public: @@ -97,7 +98,7 @@ class NVFuserEnabler { if (getCachedNNCNotNVFuser()) { return false; } - std::call_once(enabled_check_flag_, [&]() { + c10::call_once(enabled_check_flag_, [&]() { // if environment variable is setting the value, we must if (!runtime_assigned_fuser_enabled_.has_value() && getCachedFuserEnabledEnvVar().has_value()) { diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 9908b6c4e3d8b..f802d02aceaa4 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -15,6 +15,8 @@ #include +#include + #include #include @@ -851,7 +853,7 @@ class IrParser { } static void initRegistry() { - std::call_once(once_flag_, []() { + c10::call_once(once_flag_, []() { std::lock_guard lock(parser_mutex_); registerJitOperator(); }); @@ -2511,7 +2513,7 @@ class IrParser { { auto ptr_op = getOperatorForLiteral( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"); + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"); REGISTER_PARSE_RULE( ptr_op, { @@ -2576,7 +2578,7 @@ class IrParser { { auto ptr_op = getOperatorForLiteral( - "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); + "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { @@ -3559,7 +3561,7 @@ class IrParser { cached_registry_lookup_; // NOLINT // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) - static std::once_flag once_flag_; + static c10::once_flag once_flag_; }; std::unordered_set IrParser::parser_symbol_set_; // NOLINT std::unordered_set IrParser::parser_skip_set_; // NOLINT @@ -3570,7 +3572,7 @@ std::unordered_map IrParser::cached_registry_lookup_; // NOLINT // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -std::once_flag IrParser::once_flag_; +c10::once_flag IrParser::once_flag_; ProfileIValueOp* insertProfileIValueOp( Node* node, @@ -4025,7 +4027,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { static auto reduction_operator_schema = getOperatorForLiteral( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)") + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)") ->schema(); if (node->matches(reduction_operator_schema)) { switch (offset) { diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py index fbd85fa197e83..b3ce49d32d979 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/double_half_cast.py @@ -9,13 +9,10 @@ t0 = fd.define_tensor(2, DataType.Double) t1 = fd.define_tensor(2, DataType.Double) - fd.add_input(t0) - fd.add_input(t1) - - t0h = fd.Ops.cast(DataType.Half, t0) - t1h = fd.Ops.cast(DataType.Half, t1) - t2 = fd.Ops.add(t0h, t1h) - t3 = fd.Ops.relu(t2) + t0h = fd.ops.cast(t0, DataType.Half) + t1h = fd.ops.cast(t1, DataType.Half) + t2 = fd.ops.add(t0h, t1h) + t3 = fd.ops.relu(t2) fd.add_output(t3) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py index faa71fbba8ac9..d5f7070a4eeb8 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/half_double_cast.py @@ -9,11 +9,8 @@ t0 = fd.define_tensor(2, DataType.Half) t1 = fd.define_tensor(2, DataType.Double) - fd.add_input(t0) - fd.add_input(t1) - - t2 = fd.Ops.add(t0, t1) - t5 = fd.Ops.relu(t2) + t2 = fd.ops.add(t0, t1) + t5 = fd.ops.relu(t2) fd.add_output(t5) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py index ce6e490ac9971..2bd236c0cf2d0 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example.py @@ -6,22 +6,17 @@ with FusionDefinition(fusion) as fd : t0 = fd.define_tensor(3) - t1 = fd.define_tensor(1) + t1 = fd.define_tensor(3) s0 = fd.define_scalar() - fd.add_input(t0) - fd.add_input(t1) - fd.add_input(s0) - c0 = fd.define_constant(3.0) - t1_b = fd.Ops.broadcast(t1, [True, True, False]) - t2 = fd.Ops.add(t0, t1) - t3 = fd.Ops.mul(t2, c0) - t4 = fd.Ops.atan2(t3, s0) - t5 = fd.Ops.relu(t4) - t6 = fd.Ops.sum(t5, [-1], False, DataType.Float) - t7 = fd.Ops.isfinite(t6) + t2 = fd.ops.add(t0, t1) + t3 = fd.ops.mul(t2, c0) + t4 = fd.ops.atan2(t3, s0) + t5 = fd.ops.relu(t4) + t6 = fd.ops.sum(t5, [-1], False, DataType.Float) + t7 = fd.ops.isfinite(t6) fd.add_output(t6) fd.add_output(t7) @@ -30,7 +25,7 @@ # Execute Fusion input1 = torch.ones(2, 4, 8, device='cuda') -input2 = torch.ones(8, device='cuda') +input2 = torch.ones(2, 4, 8, device='cuda') # Kernel compilation should be cached for the 2nd iteration # with input tensors of the same shape diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py index dc0390521d153..06733dbd68de0 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py @@ -11,11 +11,8 @@ t0 = fd.define_tensor(1) t1 = fd.define_tensor(3) - fd.add_input(t0) - fd.add_input(t1) - - t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [1]) - t2 = fd.Ops.add(t0_b, t1) + t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1]) + t2 = fd.ops.add(t0_b, t1) fd.add_output(t2) @@ -46,11 +43,8 @@ t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride()) t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride()) - fd.add_input(t0) - fd.add_input(t1) - - t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2]) - t2 = fd.Ops.add(t0_b, t1) + t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2]) + t2 = fd.ops.add(t0_b, t1) fd.add_output(t2) @@ -76,11 +70,8 @@ t0 = fd.define_tensor([3, 1], [1, 1]) t1 = fd.define_tensor(1) - fd.add_input(t0) - fd.add_input(t1) - - t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0]) # 1 -> 0 - t2 = fd.Ops.add(t0, t1_b) + t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0]) # 1 -> 0 + t2 = fd.ops.add(t0, t1_b) fd.add_output(t2) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py index e707a863dc86e..55fc2585c22cb 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_fp16.py @@ -10,20 +10,15 @@ t1 = fd.define_tensor(1, DataType.Half) s0 = fd.define_scalar() - fd.add_input(t0) - fd.add_input(t1) - fd.add_input(s0) - c0 = fd.define_constant(3.0) - t1_b = fd.Ops.broadcast(t1, [True, True, False]) - t2 = fd.Ops.add(t0, t1) - t3 = fd.Ops.mul(t2, c0) - t4 = fd.Ops.mul(t3, s0) - t5 = fd.Ops.relu(t4) - t6 = fd.Ops.sum(t5, [-1], False, DataType.Float) + t2 = fd.ops.add(t0, t1) + t3 = fd.ops.mul(t2, c0) + t4 = fd.ops.mul(t3, s0) + t5 = fd.ops.relu(t4) + t6 = fd.ops.sum(t5, [-1], False, DataType.Float) - t7 = fd.Ops.cast(DataType.Half, t6) + t7 = fd.ops.cast(t6, DataType.Half) fd.add_output(t7) fusion.print_ir() diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp new file mode 100644 index 0000000000000..4efdc21526bab --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.cpp @@ -0,0 +1,65 @@ +#ifdef USE_CUDA +#include +#include + +namespace nvfuser { + +FusionDefinition::FusionDefinition(FusionOwner* fusion_owner) + : fusion_owner_(fusion_owner), + prev_fusion_(nullptr), + recording_(), + recording_state_(), + fusion_state_(), + ops(this) {} + +FusionDefinition* FusionDefinition::enter() { + prev_fusion_ = FusionGuard::getCurFusion(); + FusionGuard::setCurFusion(fusionPtr()); + return this; +} +void FusionDefinition::exit() { + fusion_state_.resize(recording_state_.size(), nullptr); + for (auto& record : recording_) { + auto functor = record.get(); + (*functor)(*this); + } + + FusionGuard::setCurFusion(prev_fusion_); + prev_fusion_ = nullptr; +} + +Scalar* FusionDefinition::defineScalar() { + Scalar* out = new nvfuser::Scalar(recording_state_.size()); + recording_state_.emplace_back(out); + return out; +} +Tensor* FusionDefinition::defineTensor() { + Tensor* out = new nvfuser::Tensor(recording_state_.size()); + recording_state_.emplace_back(out); + return out; +} +void FusionDefinition::defineRecord(RecordFunctor* record) { + recording_.emplace_back(record); +} + +void FusionDefinition::addInput(NvfVal* input) { + fusionPtr()->addInput(input); +} +void FusionDefinition::addOutput(NvfVal* output) { + fusionPtr()->addOutput(output); +} + +NvfVal* FusionDefinition::getFusionState(size_t index) const { + return fusion_state_.at(index); +} +void FusionDefinition::setFusionState(size_t index, NvfVal* val) { + fusion_state_.at(index) = val; +} + +Fusion* FusionDefinition::fusionPtr() { + return fusion_owner_->fusionPtr(); +} + +} // namespace nvfuser + +#endif // USE_CUDA diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h new file mode 100644 index 0000000000000..a5aca2f0d250d --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h @@ -0,0 +1,138 @@ +#pragma once +#include +#include +#include + +//! nvFuser Fusion IR Types +using NvfDataType = torch::jit::fuser::cuda::DataType; +using NvfFusion = torch::jit::fuser::cuda::Fusion; +using NvfTensorView = torch::jit::fuser::cuda::TensorView; +using NvfVal = torch::jit::fuser::cuda::Val; + +namespace nvfuser { + +struct RecordFunctor; + +//! The State, child classes Tensor and Scalar, and the StateType enum +//! are used to define state objects to encapsulate the recording of state +//! in the FusionDefinition. + +enum class StateType { + Tensor, + Scalar, + None, +}; + +struct State { + State(StateType _stype, size_t _index) : stype(_stype), index(_index) {} + + //! StateType is either: Tensor or Scalar + StateType stype; + //! A unique index to identifiy each recorded state item. + size_t index; +}; + +//! The child classes are used to define separate function signtures in +//! in the FusionDefintion to identify the appropriate Operator function. +//! +//! Example: +//! +//! add(Tensor* arg1, Tensor* arg2) -> Tensor* +//! add(Tensor* arg1, Scalar* arg2) -> Tensor* +//! add(Scalar* arg1, Scalar* arg2) -> Scalar* +struct Tensor : State { + Tensor(size_t _index) : State(StateType::Tensor, _index) {} +}; + +struct Scalar : State { + Scalar(size_t _index) : State(StateType::Scalar, _index) {} +}; + +//! FusionDefinition defines the C++ side of a Python Context manager to +//! encapsulate the definition of fusion operations. +//! +//! The FusionDefinition records the state definitions and operations prior +//! to exiting the context manager. Upon exit, the operations are queried +//! in a cache and the recorded records are used to build an nvFuser Fusion +//! object if the definition missed in the cache. +//! +//! \todo Need to implement the cache portion. Currently, the Fusion object +//! is always built. +//! +//! The nested Operators class was designed to allow the user to query all the +//! available Operators in the FusionDefinition via python help. +//! +//! Example: +//! help(FusionDefinition.Operators) +class FusionDefinition { + public: + FusionDefinition(FusionOwner* fusion_owner); + + // The copy/move/assign constructors/operators are being removed + // because it is not possible to copy the fusion_recording data member + // because that would require a virtual copy/move/assign of the + // RecordFunctor that is not supported. + FusionDefinition(const FusionDefinition& fd) = delete; + FusionDefinition(FusionDefinition&& fd) = delete; + FusionDefinition& operator=(const FusionDefinition& fd) = delete; + FusionDefinition& operator=(FusionDefinition&& fd) = delete; + + //! Enter Python Context Manager + FusionDefinition* enter(); + //! Exit Python Context Manager -- Triggers cache lookup + void exit(); + + //! These methods are used to record the FusionDefinition for cache lookup + + //! Defines a Scalar State Record + Scalar* defineScalar(); + //! Defines a Tensor State Record + Tensor* defineTensor(); + //! Defines a Record that records the operation required to + //! build the corresponding Fusion IR operation on cache miss. + void defineRecord(RecordFunctor* record); + + //! These methods are used to replay the operations for building the + //! nvFuser Fusion IR on a cache miss. + + //! Adds a Tensor/Scalar input to the Fusion object + void addInput(NvfVal* input); + //! Adds a Tensor/Scalar output to the Fusion object + void addOutput(NvfVal* output); + //! Gets a Fusion IR Tensor/Scalar object + NvfVal* getFusionState(size_t index) const; + //! Sets a Fusion IR Tensor/Scalar object + void setFusionState(size_t index, NvfVal* val); + + //! A pointer to the nvFuser Fusion IR Oject + NvfFusion* fusionPtr(); + + private: + // \todo These items will be replaced by a FusionManager instead of a cache + // for an individual fusion object + FusionOwner* fusion_owner_; + NvfFusion* prev_fusion_; + + //! A vector of record operations in the FusionDefintion + std::vector> recording_; + //! A vector of state (Tensor/Scalar) recorded in the FusionDefinition + std::vector> recording_state_; + + //! A vector of nvFuser Fusion IR TensorViews/Vals for building the Fusion + //! IR graph. + std::vector fusion_state_; + + public: + //! The Operators are not directly defined in this header. They are defined + //! in the python bindings through lambda functions so the user only needs to + //! define new operators in one place. + struct Operators { + Operators(FusionDefinition* fd) : fusion_definition(fd) {} + + FusionDefinition* fusion_definition; + }; + + Operators ops; +}; + +} // namespace nvfuser diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h new file mode 100644 index 0000000000000..dce8cc4d65d5a --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_owner.h @@ -0,0 +1,36 @@ + +#pragma once + +#include + +using namespace torch::jit::fuser::cuda; + +namespace nvfuser { + +class FusionOwner { + public: + FusionOwner() : executor_cache_(std::make_unique()) {} + + // Non-copyable + FusionOwner(const FusionOwner&) = delete; + FusionOwner& operator=(const FusionOwner&) = delete; + + std::vector execute(const at::ArrayRef& inputs) { + return executor_cache_.runFusionWithInputs(inputs); + } + Fusion* fusionPtr() { + return executor_cache_.fusion(); + } + + void printIr() { + executor_cache_.printFusion(); + } + void printKernel() { + executor_cache_.fusion()->printKernel(); + } + + private: + FusionExecutorCache executor_cache_; +}; + +} // namespace nvfuser diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h new file mode 100644 index 0000000000000..77cbb7086aa59 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h @@ -0,0 +1,365 @@ +#pragma once +#include +#include +#include +#include + +namespace nvfuser { + +//! RecordFunctor is the base class record for operations recorded by +//! the FusionDefinition. It is, in essence, a node in the graph with +//! input edges, args, and outputs edges outputs that where the stored +//! values are indices into the recorded state. +//! +//! The virual functor is the operators that is replayed on a cache +//! to build the appropriate part of the nvFuser Fusion IR for a given +//! record. + +struct RecordFunctor { + RecordFunctor(std::vector _args, std::vector _outputs) + : args(std::move(_args)), outputs(std::move(_outputs)) {} + virtual ~RecordFunctor() = default; + + //! Abstraction for an operation to build this record's nvFuser Fusion IR + //! piece if the recording has a cache miss. + virtual void operator()(FusionDefinition& fd) = 0; + + //! Inputs that are indices into the FusionDefinition's Recorded State. + std::vector args; + //! Outputs that are indices into the FusionDefinition's Recorded State. + std::vector outputs; +}; + +//! The OpRecord RecordFunctor is the most widely used child class because +//! it utilizes varidiac template arguments to represent unary, binary, +//! ternary, and other similar flavors of operations in nvFuser that have +//! a mix of Tensor and Scalar arguments only. +//! +//! The additional data memeber of this child class records the function +//! signature of the nvFuser Arith Operation to be replayed upon a cache +//! miss by the functor operator() call. + +template +struct OpRecord : RecordFunctor { + OpRecord( + std::vector _args, + std::vector _outputs, + std::function fusion_op) + : RecordFunctor(std::move(_args), std::move(_outputs)), + fusion_op_(fusion_op) {} + virtual ~OpRecord() = default; + + //! The variadic set of indices for the number of args for this op are + //! deduced by providing the index_sequence as a parameter. Similarly, + //! the tuple type is also deduced. + //! + //! The tuple type is used to decide whether to cast the input argument + //! to a Fusion IR TensorView or leave it as a Fusion IR Val (Scalar). + //! + //! A deduced binary op could look like: + //! OutType opFunc, 0, 1> + //! A deduced ternary op could look like: + //! OutTupe opFunc, 0, 1, 2> + template + OutType opFunc( + FusionDefinition& fd, + TupleType& tp, + std::index_sequence) { + return fusion_op_( + dynamic_cast::type>( + fd.getFusionState(args.at(Is)))...); + } + + void operator()(FusionDefinition& fd) final { + using arg_tuple_t = std::tuple; + auto indices = + std::make_index_sequence::value>(); + // The tuple variable is never populated, it is passed for its type. + arg_tuple_t inputs; + auto output = opFunc(fd, inputs, indices); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! An nvFuser Arith Operation function signature + std::function fusion_op_; +}; + +//! Specialized Record Functor for the FusionDefinition's broadcast_in_dim op. + +struct BroadcastOpRecord : RecordFunctor { + BroadcastOpRecord( + std::vector _args, + std::vector _outputs, + std::vector& output_shape, + std::vector& broadcast_dims) + : RecordFunctor(std::move(_args), std::move(_outputs)), + output_shape_(std::move(output_shape)), + broadcast_dims_(std::move(broadcast_dims)) {} + virtual ~BroadcastOpRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto arg = fd.getFusionState(args.at(0))->template as(); + + const auto& arg_domains_nr = arg->domain()->noReductions(); + const auto arg_ndims = arg_domains_nr.size(); + TORCH_CHECK( + output_shape_.size() >= arg_ndims, + "The new shape is expected to be greater-then-or-equal to the input", + output_shape_.size(), + arg_ndims); + TORCH_CHECK( + arg_ndims == broadcast_dims_.size(), + "The broadcast dimensions should match the input dimensions.", + arg_ndims, + broadcast_dims_.size()); + + std::vector is_broadcast_dim(output_shape_.size(), true); + std::vector is_expand_dim(output_shape_.size(), true); + for (const auto idx : c10::irange(broadcast_dims_.size())) { + if (idx > 0) { + TORCH_CHECK( + broadcast_dims_[idx - 1] < broadcast_dims_[idx], + "Broadcast dimension is not greater than the previous value."); + } + TORCH_CHECK( + broadcast_dims_[idx] < static_cast(output_shape_.size()), + "Invalid broadcast_dims value."); + is_broadcast_dim.at(broadcast_dims_[idx]) = false; + // Note: when we expand a broadcasted dimension, we need to expand it + // to a concrete size, hence the need for `is_expand_dim` flag and the + // expand operation following the broadcast. + is_expand_dim.at(broadcast_dims_[idx]) = + arg_domains_nr[idx]->isBroadcast(); + } + + std::vector output_shape_on_bcast( + output_shape_.size(), nullptr); + bool has_expand = false; + for (const auto idx : c10::irange(output_shape_.size())) { + if (is_expand_dim[idx] && output_shape_[idx] != 1 && + output_shape_[idx] != -1) { + // TODO: this would be tricky to handle on dynamic shapes, we'll + // need to pass-in a symbol instead somehow. + output_shape_on_bcast[idx] = IrBuilder::create(output_shape_[idx]); + has_expand = true; + } else { + output_shape_on_bcast[idx] = IrBuilder::create(-1); + } + } + + auto output = torch::jit::fuser::cuda::broadcast(arg, is_broadcast_dim); + if (has_expand) { + output = torch::jit::fuser::cuda::expand(output, output_shape_on_bcast); + } + fd.setFusionState(outputs.at(0), output); + } + + private: + //! Represents the tensor dimensions of the output tensor. + std::vector output_shape_; + //! Communicates which dimensions of the output the input tensor maps. + //! For instance, for output [2, 3, 4] and input [3]. This vector would + //! contain [1]. + std::vector broadcast_dims_; +}; + +template +struct CastOpRecord : RecordFunctor { + CastOpRecord( + std::vector _args, + std::vector _outputs, + std::function fusion_op, + NvfDataType dtype) + : RecordFunctor(std::move(_args), std::move(_outputs)), + fusion_op_(fusion_op), + dtype_(dtype) {} + virtual ~CastOpRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto arg = dynamic_cast(fd.getFusionState(args.at(0))); + auto output = fusion_op_(dtype_, arg); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! nvFuser arith function signature + std::function fusion_op_; + //! Type to cast to. + NvfDataType dtype_; +}; + +//! Specialized Record Functor for recording FusionDefinition constant state. + +template +struct ConstantRecord : RecordFunctor { + ConstantRecord(std::vector _outputs, ValueType val) + : RecordFunctor({}, std::move(_outputs)), value_(val) {} + virtual ~ConstantRecord() = default; + + void operator()(FusionDefinition& fd) final { + NvfVal* output = IrBuilder::create(value_); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! The constants literal value. + ValueType value_; +}; + +//! Specialized Record Functor for recording FusionDefinition input tensors. + +struct InputTensorRecord : RecordFunctor { + InputTensorRecord( + std::vector _outputs, + std::vector _symbolic_sizes, + std::vector _contiguous_info, + NvfDataType _dtype) + : RecordFunctor({}, std::move(_outputs)), + symbolic_sizes(std::move(_symbolic_sizes)), + contiguous_info(std::move(_contiguous_info)), + dtype(_dtype) {} + virtual ~InputTensorRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto tv = TensorViewBuilder() + .ndims(symbolic_sizes.size()) + .contiguity(contiguous_info) + .shape(symbolic_sizes) + .dtype(dtype) + .build(); + + fd.setFusionState(outputs.at(0), tv); + fd.addInput(tv); + } + + //! A vector of tensor dimension sizes. + //! This vector only captures sizes of -1 or 1 to indicate a symbolic + //! dimension (-1) or a broadcast dimension (1). + std::vector symbolic_sizes; + //! A vector to indicate whether the a tensor dimension is contiguous + //! with the dimension just to its right. + std::vector contiguous_info; + //! Tensor data type. + NvfDataType dtype; +}; + +//! Specialized Record Functor for recording FusionDefinition outputs. + +template +struct OutputRecord : RecordFunctor { + OutputRecord(std::vector _args) + : RecordFunctor(std::move(_args), {}) {} + virtual ~OutputRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto input = fd.getFusionState(args.at(0)); + + // With C++17, this statement should be "if constexpr" + if (std::is_same::value) { + fd.addOutput(input->template as()); + } else { + fd.addOutput(input); + } + } +}; + +//! Specialized Record Functor for the FusionDefinition's sum/min/max ops. + +struct ReductionOpRecord : RecordFunctor { + ReductionOpRecord( + std::vector _args, + std::vector _outputs, + std::function< + NvfTensorView*(NvfTensorView*, std::vector&, bool, NvfDataType)> + fusion_op, + std::vector axes, + bool keep_dim, + NvfDataType dtype) + : RecordFunctor(std::move(_args), std::move(_outputs)), + fusion_op_(fusion_op), + axes_(std::move(axes)), + keep_dim_(keep_dim), + dtype_(dtype) {} + virtual ~ReductionOpRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto arg = fd.getFusionState(args.at(0))->template as(); + auto output = fusion_op_(arg, axes_, keep_dim_, dtype_); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! nvFuser arith function signature for a given reduction operation + std::function< + NvfTensorView*(NvfTensorView*, std::vector&, bool, NvfDataType)> + fusion_op_; + //! The tensor dimensions to reduce + std::vector axes_; + //! Indicates whether to keep the reduced dimension(s). + bool keep_dim_; + //! The output data type. + NvfDataType dtype_; +}; + +//! Specialized Record Functor for recording FusionDefinition input scalars. + +struct ScalarRecord : RecordFunctor { + ScalarRecord(std::vector _outputs, NvfDataType dtype) + : RecordFunctor({}, std::move(_outputs)), dtype_(dtype) {} + virtual ~ScalarRecord() = default; + + void operator()(FusionDefinition& fd) final { + NvfVal* output = nullptr; + if (dtype_ == NvfDataType::Double) { + output = IrBuilder::create(); + } else if (dtype_ == NvfDataType::ComplexDouble) { + output = IrBuilder::create(); + } else if (dtype_ == NvfDataType::Bool) { + output = IrBuilder::create(); + } else if (dtype_ == NvfDataType::Int) { + output = IrBuilder::create(); + } else { + TORCH_CHECK(false, "Dtype is not supported:", dtype_); + } + fd.addInput(output); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! Scalar data type. + NvfDataType dtype_; +}; + +//! Specialized Record Functor for the FusionDefinition's var op. + +struct VarianceOpRecord : RecordFunctor { + VarianceOpRecord( + std::vector _args, + std::vector _outputs, + std::vector& axes, + int64_t correction, + bool keep_dim) + : RecordFunctor(std::move(_args), std::move(_outputs)), + axes_(axes), + correction_(correction), + keep_dim_(keep_dim) {} + virtual ~VarianceOpRecord() = default; + + void operator()(FusionDefinition& fd) final { + auto arg = fd.getFusionState(args.at(0))->as(); + auto output = + torch::jit::fuser::cuda::variance(arg, axes_, correction_, keep_dim_); + fd.setFusionState(outputs.at(0), output); + } + + private: + //! Dimensions of tensor to reduce for variance calculation + std::vector axes_; + //! Bessel's correction value + int64_t correction_; + //! Indicates whether to keep the reduced dimension(s). + bool keep_dim_; +}; + +} // namespace nvfuser diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index 19c5ca60abf38..fe24497276490 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -4,115 +4,45 @@ #include #include #include -#include #include #include -#include -#include +#include +#include #include #include #include -using namespace torch::jit::fuser::cuda; - -namespace { - -class PythonFusionOwner { - public: - PythonFusionOwner() : executor_cache_(std::make_unique()) {} - - // Non-copyable - PythonFusionOwner(const PythonFusionOwner&) = delete; - PythonFusionOwner& operator=(const PythonFusionOwner&) = delete; - - std::vector execute(const at::ArrayRef& inputs) { - return executor_cache_.runFusionWithInputs(inputs); - } - Fusion* fusionPtr() { - return executor_cache_.fusion(); - } - - void printIr() { - executor_cache_.printFusion(); - } - void printKernel() { - executor_cache_.fusion()->printKernel(); - } - - private: - FusionExecutorCache executor_cache_; -}; - -// Manually applying the fusion guard via a context manager -class FusionDefinitionContextManager { - public: - FusionDefinitionContextManager(PythonFusionOwner* fusion_owner) - : fusion_owner_(fusion_owner), prev_fusion_(nullptr) {} - - // Context Manager Methods - FusionDefinitionContextManager* enter() { - prev_fusion_ = FusionGuard::getCurFusion(); - FusionGuard::setCurFusion(fusionPtr()); - return this; - } - - void exit() { - FusionGuard::setCurFusion(prev_fusion_); - prev_fusion_ = nullptr; - } - - void addInput(torch::jit::fuser::cuda::Val* input) { - fusionPtr()->addInput(input); - } - void addOutput(torch::jit::fuser::cuda::Val* output) { - fusionPtr()->addOutput(output); - } - - Fusion* fusionPtr() { - return fusion_owner_->fusionPtr(); - } - - // An Empty namespace to add arith ops - struct Ops {}; - - private: - PythonFusionOwner* fusion_owner_; - Fusion* prev_fusion_; -}; - -} // namespace - namespace torch { namespace jit { void initNvFuserPythonBindings(PyObject* module) { auto m = py::handle(module).cast(); + //! Top Level nvFuser Python submodule auto nvfuser = m.def_submodule("_nvfuser"); - // DataTypes supported by NVFuser in Fusion Definition - // Types not related to values found in fusion defintions - // were purposely left out. - // NOTE: DataType was ambiguous under torch::jit without full qualification. - py::enum_(nvfuser, "DataType") - .value("Double", torch::jit::fuser::cuda::DataType::Double) - .value("Float", torch::jit::fuser::cuda::DataType::Float) - .value("Half", torch::jit::fuser::cuda::DataType::Half) - .value("Int", torch::jit::fuser::cuda::DataType::Int) - .value("Int32", torch::jit::fuser::cuda::DataType::Int32) - .value("Bool", torch::jit::fuser::cuda::DataType::Bool) - .value("BFloat16", torch::jit::fuser::cuda::DataType::BFloat16) - .value("ComplexFloat", torch::jit::fuser::cuda::DataType::ComplexFloat) - .value("ComplexDouble", torch::jit::fuser::cuda::DataType::ComplexDouble) - .value("Null", torch::jit::fuser::cuda::DataType::Null); - - // Binding an object that owns a FusionExecutorCache instance and provides an - // interface - py::class_ fusion(nvfuser, "Fusion"); + //! DataTypes supported by nvFuser in the FusionDefinition + py::enum_(nvfuser, "DataType") + .value("Double", NvfDataType::Double) + .value("Float", NvfDataType::Float) + .value("Half", NvfDataType::Half) + .value("Int", NvfDataType::Int) + .value("Int32", NvfDataType::Int32) + .value("Bool", NvfDataType::Bool) + .value("BFloat16", NvfDataType::BFloat16) + .value("ComplexFloat", NvfDataType::ComplexFloat) + .value("ComplexDouble", NvfDataType::ComplexDouble) + .value("Null", NvfDataType::Null); + + //! Binding an object that owns a FusionExecutorCache instance and provides + //! an interface + //! \todo This object will be removed when a FusionManager is added + //! containing a cache. + py::class_ fusion(nvfuser, "Fusion"); fusion.def(py::init<>()) .def( "execute", - [](PythonFusionOwner& self, const py::iterable& iter) { + [](nvfuser::FusionOwner& self, const py::iterable& iter) { std::vector inputs; for (py::handle obj : iter) { inputs.push_back(toIValue(obj, c10::AnyType::get())); @@ -120,90 +50,72 @@ void initNvFuserPythonBindings(PyObject* module) { return self.execute(inputs); }, py::return_value_policy::reference) - .def("print_ir", [](PythonFusionOwner& self) { self.printIr(); }) - .def("print_kernel", [](PythonFusionOwner& self) { self.printKernel(); }); - - // Bindings to Types required for Tensor/Scalar Creation - py::class_(nvfuser, "TensorView") - .def( - "__str__", - [](TensorView& self) -> std::string { - std::stringstream ss; - TORCH_CHECK( - self.getDataType().has_value(), - "TensorView does not have DataType?"); - ss << self.getDataType().value(); - return self.toString() + " DataType: " + ss.str() + - " Contiguity: " + self.domain()->getContiguityString(); - }, - py::return_value_policy::reference); - py::class_(nvfuser, "Val") - .def( - "__str__", - [](torch::jit::fuser::cuda::Val& self) -> std::string { - return self.toString(); - }, - py::return_value_policy::reference); - - // C++ Side of Context Manager used to mimic the FusionGuard as a way - // to programatically distinguish code used to define the Fusion instead - // of having the user mysteriously create an object prior to adding definition - // code where the object is not used. - py::class_ fusion_def( - nvfuser, "FusionDefinition"); - fusion_def.def(py::init()) + .def("print_ir", [](nvfuser::FusionOwner& self) { self.printIr(); }) + .def("print_kernel", [](nvfuser::FusionOwner& self) { + self.printKernel(); + }); + + //! These are the FusionDefinition supported object types that are either + //! defined as inputs or the output of an operation. + py::class_(nvfuser, "Tensor"); + py::class_(nvfuser, "Scalar"); + + //! The FusionDefinition is a context manager in Python where the user will + //! define the set the operations and connections between operations for + //! nvFuser to create. + py::class_ fusion_def(nvfuser, "FusionDefinition"); + fusion_def.def(py::init()) + .def_readwrite("ops", &nvfuser::FusionDefinition::ops) .def( "__enter__", - [](FusionDefinitionContextManager& self) { return self.enter(); }) + [](nvfuser::FusionDefinition& self) -> nvfuser::FusionDefinition* { + return self.enter(); + }) .def( "__exit__", - [](FusionDefinitionContextManager& self, + [](nvfuser::FusionDefinition& self, void* exc_type, void* exc_value, void* traceback) { self.exit(); }) - .def( - "add_input", - [](FusionDefinitionContextManager& self, - torch::jit::fuser::cuda::Val* input) { self.addInput(input); }) - .def( - "add_input", - [](FusionDefinitionContextManager& self, TensorView* input) { - self.addInput(input); - }) .def( "add_output", - [](FusionDefinitionContextManager& self, - torch::jit::fuser::cuda::Val* output) { self.addOutput(output); }) + [](nvfuser::FusionDefinition& self, nvfuser::Scalar* output) { + self.defineRecord( + new nvfuser::OutputRecord({output->index})); + }) .def( "add_output", - [](FusionDefinitionContextManager& self, TensorView* output) { - self.addOutput(output); + [](nvfuser::FusionDefinition& self, nvfuser::Tensor* output) { + self.defineRecord( + new nvfuser::OutputRecord({output->index})); }) .def( "define_tensor", - [](FusionDefinitionContextManager& self, + [](nvfuser::FusionDefinition& self, size_t ndims, - torch::jit::fuser::cuda::DataType dtype = - torch::jit::fuser::cuda::DataType::Float) -> TensorView* { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); + NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* { + std::vector maybe_symbolic_sizes(ndims, -1); + ; + std::vector contig_info(ndims, false); + + nvfuser::Tensor* out = self.defineTensor(); + self.defineRecord(new nvfuser::InputTensorRecord( + {out->index}, + std::move(maybe_symbolic_sizes), + std::move(contig_info), + dtype)); + + return out; }, py::arg("ndims"), py::arg("dtype") = torch::jit::fuser::cuda::DataType::Float, py::return_value_policy::reference) .def( - // TODO: Should the inernals of this function live more explicitly in - // TensorViewBuilder? "define_tensor", - [](FusionDefinitionContextManager& self, - // TODO: This should come in as int64_t not int - std::vector sizes, - std::vector strides, - torch::jit::fuser::cuda::DataType dtype = - torch::jit::fuser::cuda::DataType::Float) -> TensorView* { + [](nvfuser::FusionDefinition& self, + std::vector sizes, + std::vector strides, + NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* { TORCH_CHECK( sizes.size() == strides.size(), "The number of sizes does not match the number of strides.", @@ -240,70 +152,117 @@ void initNvFuserPythonBindings(PyObject* module) { } } - return TensorViewBuilder() - .ndims(maybe_symbolic_sizes.size()) - .contiguity(contig_info) - .shape(maybe_symbolic_sizes) - .dtype(dtype) - .build(); + nvfuser::Tensor* out = self.defineTensor(); + self.defineRecord(new nvfuser::InputTensorRecord( + {out->index}, + std::move(maybe_symbolic_sizes), + std::move(contig_info), + dtype)); + + return out; }, py::arg("sizes"), py::arg("strides"), - py::arg("dtype") = torch::jit::fuser::cuda::DataType::Float, + py::arg("dtype") = NvfDataType::Float, + py::return_value_policy::reference) + .def( + "define_constant", + [](nvfuser::FusionDefinition& self, double val) -> nvfuser::Scalar* { + nvfuser::Scalar* out = self.defineScalar(); + self.defineRecord( + new nvfuser:: + ConstantRecord( + {out->index}, val)); + return out; + }, py::return_value_policy::reference) .def( "define_constant", - [](FusionDefinitionContextManager& self, - double val) -> torch::jit::fuser::cuda::Val* { - return IrBuilder::create(val); + [](nvfuser::FusionDefinition& self, + c10::complex val) -> nvfuser::Scalar* { + nvfuser::Scalar* out = self.defineScalar(); + self.defineRecord(new nvfuser::ConstantRecord< + torch::jit::fuser::cuda::ComplexDouble, + c10::complex>({out->index}, val)); + return out; }, py::return_value_policy::reference) .def( "define_constant", - [](FusionDefinitionContextManager& self, - bool val) -> torch::jit::fuser::cuda::Val* { - return IrBuilder::create(val); + [](nvfuser::FusionDefinition& self, bool val) -> nvfuser::Scalar* { + nvfuser::Scalar* out = self.defineScalar(); + self.defineRecord( + new nvfuser:: + ConstantRecord( + {out->index}, val)); + return out; }, py::return_value_policy::reference) .def( "define_constant", - [](FusionDefinitionContextManager& self, - int64_t val) -> torch::jit::fuser::cuda::Val* { - return IrBuilder::create(val); + [](nvfuser::FusionDefinition& self, int64_t val) -> nvfuser::Scalar* { + nvfuser::Scalar* out = self.defineScalar(); + self.defineRecord( + new nvfuser:: + ConstantRecord( + {out->index}, val)); + return out; }, py::return_value_policy::reference) .def( "define_scalar", - [](FusionDefinitionContextManager& self, - torch::jit::fuser::cuda::DataType dtype = - torch::jit::fuser::cuda::DataType::Double) - -> torch::jit::fuser::cuda::Val* { - if (dtype == torch::jit::fuser::cuda::DataType::Double) { - return IrBuilder::create(); - } else if (dtype == torch::jit::fuser::cuda::DataType::Bool) { - return IrBuilder::create(); - } else if (dtype == torch::jit::fuser::cuda::DataType::Int) { - return IrBuilder::create(); - } else { - TORCH_CHECK(false, "Dtype is not supported:", dtype); - } + [](nvfuser::FusionDefinition& self, + NvfDataType dtype = torch::jit::fuser::cuda::DataType::Double) + -> nvfuser::Scalar* { + nvfuser::Scalar* out = self.defineScalar(); + self.defineRecord(new nvfuser::ScalarRecord({out->index}, dtype)); + return out; }, py::arg("dtype") = torch::jit::fuser::cuda::DataType::Double, py::return_value_policy::reference); - py::class_ nvf_ops(fusion_def, "Ops"); + //! The Operators class is a nested class of FusionDefinition to allow the + //! user to query the class for the list of operators. + //! + //! Example: + //! help(FusionDefinition.Operators) + //! + //! Additional operators are expected to be defined below as needed. They + //! may require defining a new RecordFunctor child class if they are unique. + py::class_ nvf_ops( + fusion_def, "Operators"); + nvf_ops.def(py::init()); // ******************** INSERT OP BINDINGS BELOW HERE ******************** -#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast( \ - &torch::jit::fuser::cuda::op_name), \ +#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* input) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {input->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* input) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {input->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_UNARY_OP("abs", abs) @@ -347,29 +306,72 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_UNARY_OP("isneginf", isneginf) NVFUSER_PYTHON_BINDING_UNARY_OP("isposinf", isposinf) NVFUSER_PYTHON_BINDING_UNARY_OP("isreal", isreal) + NVFUSER_PYTHON_BINDING_UNARY_OP("real", real) + NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag) #undef NVFUSER_PYTHON_BINDING_UNARY_OP #define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name) \ - nvf_ops.def_static( \ + nvf_ops.def( \ op_str, \ - py::overload_cast( \ - &torch::jit::fuser::cuda::op_name), \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*>( \ + {arg1->index, arg2->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); \ - nvf_ops.def_static( \ + nvf_ops.def( \ op_str, \ - py::overload_cast( \ - &torch::jit::fuser::cuda::op_name), \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); \ - nvf_ops.def_static( \ + nvf_ops.def( \ op_str, \ - py::overload_cast( \ - &torch::jit::fuser::cuda::op_name), \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); \ - nvf_ops.def_static( \ + nvf_ops.def( \ op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_BINARY_OP("add", add) @@ -394,284 +396,550 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift", bitwise_left_shift) #undef NVFUSER_PYTHON_BINDING_BINARY_OP -#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name) \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast( \ - &torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - TensorView*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - TensorView*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - TensorView*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ +#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); + + NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("add_alpha", add_alpha) + NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("sub_alpha", sub_alpha) +#undef NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP + +#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Tensor* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*>( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfTensorView*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Tensor* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*, \ + NvfTensorView*>( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfTensorView*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Tensor* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfVal*, \ + NvfTensorView*, \ + NvfTensorView*>( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfTensorView*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Tensor* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfVal*, NvfTensorView*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); NVFUSER_PYTHON_BINDING_TERNARY_OP("lerp", lerp) NVFUSER_PYTHON_BINDING_TERNARY_OP("where", where) #undef NVFUSER_PYTHON_BINDING_TERNARY_OP -#define NVFUSER_PYTHON_BINDING_TERNARY_ABRV1_OP(op_str, op_name) \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ +#define NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); - NVFUSER_PYTHON_BINDING_TERNARY_ABRV1_OP("clamp", clamp) - NVFUSER_PYTHON_BINDING_TERNARY_ABRV1_OP("threshold", threshold) -#undef NVFUSER_PYTHON_BINDING_TERNARY_ABRV1_OP - -#define NVFUSER_PYTHON_BINDING_TERNARY_ABRV2_OP(op_str, op_name) \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ + NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("clamp", clamp) + NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("threshold", threshold) +#undef NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP + +#define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser:: \ + OpRecord( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfVal* (*)(NvfVal*, NvfVal*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Tensor* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfTensorView*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Tensor* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*, \ + NvfTensorView*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Tensor* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfVal*, \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Tensor* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfVal*, \ + NvfVal*, \ + NvfTensorView*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfVal*, NvfTensorView*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg1, \ + nvfuser::Scalar* arg2, \ + nvfuser::Scalar* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfTensorView*, \ + NvfVal*, \ + NvfVal*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfTensorView*, NvfVal*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg1, \ + nvfuser::Tensor* arg2, \ + nvfuser::Scalar* arg3, \ + nvfuser::Scalar* arg4) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::OpRecord< \ + NvfTensorView*, \ + NvfVal*, \ + NvfTensorView*, \ + NvfVal*, \ + NvfVal*>( \ + {arg1->index, arg2->index, arg3->index, arg4->index}, \ + {output->index}, \ + static_cast< \ + NvfTensorView* (*)(NvfVal*, NvfTensorView*, NvfVal*, NvfVal*)>( \ + torch::jit::fuser::cuda::op_name))); \ + return output; \ + }, \ py::return_value_policy::reference); - NVFUSER_PYTHON_BINDING_TERNARY_ABRV2_OP("add_alpha", add_alpha) - NVFUSER_PYTHON_BINDING_TERNARY_ABRV2_OP("sub_alpha", sub_alpha) -#undef NVFUSER_PYTHON_BINDING_TERNARY_ABRV2_OP + NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul) +#undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP -#define NVFUSER_PYTHON_BINDING_QUAD_ABRV3_OP(op_str, op_name) \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - TensorView*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ +#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name) \ + nvf_ops.def( \ op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - TensorView*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ - py::return_value_policy::reference); \ - nvf_ops.def_static( \ - op_str, \ - py::overload_cast< \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*, \ - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::op_name), \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg, \ + const std::vector& axes, \ + bool keep_dim, \ + NvfDataType dtype) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord(new nvfuser::ReductionOpRecord( \ + {arg->index}, \ + {output->index}, \ + torch::jit::fuser::cuda::op_name, \ + axes, \ + keep_dim, \ + dtype)); \ + return output; \ + }, \ + py::arg("arg"), \ + py::arg("axes"), \ + py::arg("keep_dim"), \ + py::arg("dtype") = torch::jit::fuser::cuda::DataType::Null, \ + py::return_value_policy::reference); + + NVFUSER_PYTHON_BINDING_REDUCTION_OP("sum", sum) + NVFUSER_PYTHON_BINDING_REDUCTION_OP("max", max) + NVFUSER_PYTHON_BINDING_REDUCTION_OP("min", min) +#undef NVFUSER_PYTHON_BINDING_REDUCTION_OP + +#define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name) \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Tensor* arg, \ + NvfDataType dtype) -> nvfuser::Tensor* { \ + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::CastOpRecord( \ + {arg->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name), \ + dtype)); \ + return output; \ + }, \ + py::return_value_policy::reference); \ + nvf_ops.def( \ + op_str, \ + [](nvfuser::FusionDefinition::Operators& self, \ + nvfuser::Scalar* arg, \ + NvfDataType dtype) -> nvfuser::Scalar* { \ + nvfuser::Scalar* output = self.fusion_definition->defineScalar(); \ + self.fusion_definition->defineRecord( \ + new nvfuser::CastOpRecord( \ + {arg->index}, \ + {output->index}, \ + static_cast( \ + torch::jit::fuser::cuda::op_name), \ + dtype)); \ + return output; \ + }, \ py::return_value_policy::reference); - NVFUSER_PYTHON_BINDING_QUAD_ABRV3_OP("addcmul", addcmul) -#undef NVFUSER_PYTHON_BINDING_QUAD_ABRV3_OP - - // Reduction Operations - nvf_ops.def_static( - "max", &torch::jit::fuser::cuda::max, py::return_value_policy::reference); - nvf_ops.def_static( - "min", &torch::jit::fuser::cuda::min, py::return_value_policy::reference); - nvf_ops.def_static( - "sum", &torch::jit::fuser::cuda::sum, py::return_value_policy::reference); - nvf_ops.def_static( + NVFUSER_PYTHON_BINDING_CAST_OP("cast", castOp) +#undef NVFUSER_PYTHON_BINDING_CAST_OP + + nvf_ops.def( "var", - [](TensorView* input, - const std::vector& dims, + [](nvfuser::FusionDefinition::Operators& self, + nvfuser::Tensor* arg, + std::vector& axes, int64_t correction, - bool keepdim) -> TensorView* { - return torch::jit::fuser::cuda::variance( - input, dims, correction, keepdim); + bool keepdim) -> nvfuser::Tensor* { + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); + self.fusion_definition->defineRecord(new nvfuser::VarianceOpRecord( + {arg->index}, {output->index}, axes, correction, keepdim)); + return output; }, py::return_value_policy::reference); - // Broadcast operations - nvf_ops.def_static( - "broadcast", - &torch::jit::fuser::cuda::broadcast, - py::return_value_policy::reference); - // TODO: We don't have a way to realize a tensor if the operation creates - // the output of a fusion. - nvf_ops.def_static( + nvf_ops.def( "broadcast_in_dim", - [](TensorView* input, - std::vector& output_shape, - std::vector& broadcast_dims) -> TensorView* { - const auto& iter_domains = input->domain()->noReductions(); - const auto input_ndims = iter_domains.size(); - TORCH_CHECK( - output_shape.size() >= input_ndims, - "The new shape is expected to be greater-then-or-equal to the input", - output_shape.size(), - input_ndims); - TORCH_CHECK( - input_ndims == broadcast_dims.size(), - "The broadcast dimensions should match the input dimensions.", - input_ndims, - broadcast_dims.size()); - - // default all dimensions to be broadcasted - std::vector is_broadcast_dim(output_shape.size(), true); - std::vector is_expand_dim(output_shape.size(), true); - for (const auto idx : c10::irange(broadcast_dims.size())) { - if (idx > 0) { - TORCH_CHECK( - broadcast_dims[idx - 1] < broadcast_dims[idx], - "Broadcast dimension is not greater than the previous value."); - } - TORCH_CHECK( - broadcast_dims[idx] < static_cast(output_shape.size()), - "Invalid broadcast_dims value."); - is_broadcast_dim.at(broadcast_dims[idx]) = false; - // Note: when we expand a broadcasted dimension, we need to expand it - // to a concrete size, hence the need for `is_expand_dim` flag and the - // expand operation following the broadcast. - is_expand_dim.at(broadcast_dims[idx]) = - iter_domains[idx]->isBroadcast(); - } - - std::vector output_shape_on_bcast( - output_shape.size(), nullptr); - for (const auto idx : c10::irange(output_shape.size())) { - if (is_expand_dim[idx]) { - // TODO: this would be tricky to handle on dynamic shapes, we'll - // need to pass-in a symbol instead somehow. - output_shape_on_bcast[idx] = - IrBuilder::create(output_shape[idx]); - } else { - output_shape_on_bcast[idx] = IrBuilder::create(-1); - } - } - - auto bcasted_input = - torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim); - return torch::jit::fuser::cuda::expand( - bcasted_input, output_shape_on_bcast); + [](nvfuser::FusionDefinition::Operators& self, + nvfuser::Tensor* arg, + std::vector& output_shape, + std::vector& broadcast_dims) -> nvfuser::Tensor* { + nvfuser::Tensor* output = self.fusion_definition->defineTensor(); + self.fusion_definition->defineRecord(new nvfuser::BroadcastOpRecord( + {arg->index}, {output->index}, output_shape, broadcast_dims)); + return output; }, py::return_value_policy::reference); - - // Cast Operations - nvf_ops.def_static( - "cast", - py::overload_cast( - &torch::jit::fuser::cuda::castOp), - py::return_value_policy::reference); - nvf_ops.def_static( - "cast", - py::overload_cast< - torch::jit::fuser::cuda::DataType, - torch::jit::fuser::cuda::Val*>(&torch::jit::fuser::cuda::castOp), - py::return_value_policy::reference); } } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 0f431077dd863..2befe18bece6f 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -4168,6 +4168,14 @@ TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { OpTuple{at::isposinf, UnaryOpType::IsPosInf, "isposinf"}, }; + // The following ops only supports complex + std::vector ops_complex_only{ + // real is supported via UnaryOpType::Set for non-complex types, and + // UnaryOpType::Real requires input to be complex + OpTuple{at::real, UnaryOpType::Real, "real"}, + OpTuple{at::imag, UnaryOpType::Imag, "imag"}, + }; + // Complex support for the following op is not working in nvFuser yet std::vector ops_skip_complex{ // TODO: abs is actually supported in nvFuser, but it has bug!!! @@ -4199,6 +4207,9 @@ TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { ops_without_complex.end()); ops_to_test.insert( ops_to_test.end(), ops_skip_complex.begin(), ops_skip_complex.end()); + } else { + ops_to_test.insert( + ops_to_test.end(), ops_complex_only.begin(), ops_complex_only.end()); } std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) { test_op( diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index e6c7f381a1a82..03f7b1bb4c9a4 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -462,6 +462,9 @@ class AnalyzeViewTransformation { //! 2) MergeAdjacentSingletonAxes class merges or reduces any //! adjacent singleton dimensions. class MergeAxesInterface { + public: + virtual ~MergeAxesInterface() = default; + protected: // See addMergeTransform for "is_index_merge_rhs" and // "is_last_axis_rfactor" descriptions @@ -547,6 +550,8 @@ class AnalyzeViewTransformation { mtsa.handle(false /* is_index_merge_rhs */, is_last_axis_rfactor); } + virtual ~MergeThenSplitAxes() = default; + private: MergeThenSplitAxes( AnalyzeViewTransformation* avt, @@ -594,6 +599,8 @@ class AnalyzeViewTransformation { true /* is_index_merge_rhs */, true /* is_last_axis_rfactor */); } + virtual ~MergeAdjacentSingletonAxes() = default; + private: MergeAdjacentSingletonAxes( AnalyzeViewTransformation* avt, diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index c9575dd0c4d58..35f6310f2820f 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -477,6 +477,10 @@ static const char* unary_op_type2string(UnaryOpType t) { return "isposinf"; case UnaryOpType::IsReal: return "isreal"; + case UnaryOpType::Real: + return "std::real"; + case UnaryOpType::Imag: + return "std::imag"; default: TORCH_INTERNAL_ASSERT(false, "No string found for unary op type."); } diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 7c5b83f9de50a..fce051d432fff 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -158,6 +158,7 @@ enum class UnaryOpType { Floor, Frac, Gelu, + Imag, Silu, Lgamma, Log, @@ -167,6 +168,7 @@ enum class UnaryOpType { BitCast, Neg, RandLike, + Real, Reciprocal, Relu, Rsqrt, diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index ad96fcc38f4d9..5838970a1a031 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -164,7 +164,10 @@ auto parseDisableOptions() { auto parseEnableOptions() { std::unordered_map options_map = { - {EnableOption::Complex, false}, {EnableOption::KernelProfile, false}}; + {EnableOption::Complex, false}, + {EnableOption::KernelProfile, false}, + {EnableOption::LinearDecomposition, false}, + {EnableOption::ConvDecomposition, false}}; if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_ENABLE")) { c10::string_view options_view(dump_options); @@ -175,6 +178,10 @@ auto parseEnableOptions() { options_map[EnableOption::Complex] = true; } else if (token == "kernel_profile") { options_map[EnableOption::KernelProfile] = true; + } else if (token == "linear_decomposition") { + options_map[EnableOption::LinearDecomposition] = true; + } else if (token == "conv_decomposition") { + options_map[EnableOption::ConvDecomposition] = true; } else { TORCH_CHECK( false, diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index 6b67d7710bb90..13c4d65c3c59c 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -76,7 +76,9 @@ TORCH_CUDA_CU_API bool isDisabled(DisableOption option); //! enum class EnableOption { Complex, //! Enable complex support on python - KernelProfile //! Enable intra-kernel performance profiling + KernelProfile, //! Enable intra-kernel performance profiling + LinearDecomposition, //! Enable linear-bias decomposition + ConvDecomposition //! Enable conv-bias decomposition }; TORCH_CUDA_CU_API bool isEnabled(EnableOption option); diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index f12f4bc86d03d..5001daf8ab677 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -251,7 +251,7 @@ void LlgaKernel::run(Stack& stack) { // Even in case of concurrent threads, the kernel would be initialized once. // TODO: Try not using an atomic lock - std::call_once( + c10::call_once( initialized_flag, [&](const TensorArgs& inputs) { GRAPH_DEBUG("Initializing input logical tensors"); diff --git a/torch/csrc/jit/codegen/onednn/kernel.h b/torch/csrc/jit/codegen/onednn/kernel.h index a9c7b24ad8c30..6e32c8e3bc907 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.h +++ b/torch/csrc/jit/codegen/onednn/kernel.h @@ -8,6 +8,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -83,7 +85,7 @@ class LlgaKernel { ArgSpecs outputSpecs_; std::vector constantLogicalTensors_; std::string debugName_; - std::once_flag initialized_flag; + c10::once_flag initialized_flag; bool is_initialized_ = false; }; diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 7453d340596e9..4b178fff7592f 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -3565,7 +3565,9 @@ struct to_ir { case prim::enumerate: { const SourceRange& loc = apply.range(); auto inputs = apply.inputs(); - auto input_size = apply.inputs().size(); + auto input_size = inputs.size(); + auto attributes = apply.attributes(); + auto attribute_size = attributes.size(); // enumerate(x) can be rewrite as subtrees: // IterableTree(RangeValue(0, math.inf), SimpleValue(x)) Value* start_index = nullptr; @@ -3577,11 +3579,22 @@ struct to_ir { if (input_size == 2) { start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method); } - - if (input_size > 2) { + auto arg_size = input_size + attribute_size; + if (arg_size > 2) { throw ErrorReport(loc) - << "enumerate expected at most 2 arguments, got " << input_size; + << "enumerate expected at most 2 arguments, got " << arg_size; } + + if (attribute_size == 1) { + if (attributes[0].name().name() != "start") { + throw ErrorReport(loc) + << "enumerate expected kwarg name 'start', got '" + << attributes[0].name().name() << "'"; + } + start_index = + emitSugaredExpr(attributes[0].value(), 1)->asValue(loc, method); + } + std::vector range_inputs; if (start_index != nullptr) { range_inputs.emplace_back(start_index); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index b88fa36d6459f..d39c489be213e 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -815,6 +815,7 @@ void AliasDb::analyzeImpl(Node* node) { for (const auto i : c10::irange(schema.arguments().size())) { const at::AliasInfo* formal = schema.arguments()[i].alias_info(); const auto& actualValue = node->inputs().at(i); + // Skip if there's no alias annotation if (!formal) { continue; @@ -891,25 +892,22 @@ void AliasDb::analyzeImpl(Node* node) { continue; } + bool inputs_has_alias = false; for (const auto& formalAlias : formal->beforeSets()) { - // If we encounter an alias annotation that wasn't in the inputs: - if (!formalToActual.count(formalAlias)) { - // If this alias is not seen elsewhere and is the only annotation on - // the output, it's equivalent to being fresh: - // e.g. foo(Tensor(a) self) -> Tensor(b) - if (formal->beforeSets().size() == 1) { - giveFreshAlias(actual); - } - // Or it is the form of a|fresh, which we can ignore, taking the - // conservative assumption that the output must alias `a`, e.g - // aten::cuda(Tensor(a) self) -> Tensor(a|fresh) - - // Don't assign an alias set in that case. - continue; + if (formalToActual.count(formalAlias)) { + inputs_has_alias = true; + auto toAlias = formalToActual.at(formalAlias); + makePointerTo(actual, toAlias); } - - auto toAlias = formalToActual.at(formalAlias); - makePointerTo(actual, toAlias); + } + // If all the alias annotation that we encounter weren't in the inputs: + // e.g. foo(Tensor(a) self) -> Tensor(b) + // or foo(Tensor(a) self) -> Tensor(b|c) + // Otherwise it is the form of a|fresh, which we can ignore, taking the + // conservative assumption that the output must alias `a`, e.g + // aten::cuda(Tensor(a) self) -> Tensor(a|fresh) + if (!inputs_has_alias && formal->beforeSets().size()) { + giveFreshAlias(actual); } // Record writes diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index a7d968b87bd26..c365cd969189e 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -25,11 +25,11 @@ namespace jit { * is considered safe. * * There is a special alias set called the "wildcard set", which indicates that - * we're not sure what this value may alias. To be conservative, we consider - * the wildcard alias set as potentially aliasing any value within the same - * type class. Whenever a value becomes contained by another value, such as - * when a Tensor is appended to a List[Tensor], the contained element becomes - * part of the wildcard set. + * we're not sure what this value may alias. To be conservative, we consider the + * wildcard alias set as potentially aliasing any other wildcard value within + * the same type class. Whenever a value becomes contained by another value, + * such as when a Tensor is appended to a List[Tensor], the contained element + * becomes part of the wildcard set. * * Values that contain other mutable types, such as List[Tensor], are * initialized as containing the Wildcard set for all contained mutable types. @@ -314,7 +314,7 @@ class AliasDb { // Helper check that invariants over AliasDb are maintained. // Useful if you are using the AliasDb mutation API and want to check you did // the right thing. -void Lint(const AliasDb* db); +TORCH_API void Lint(const AliasDb* db); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 0ddaebfe724b5..2b61b6b0d43bb 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1144,40 +1144,25 @@ Operation Node::getOperation() const { } bool Node::isNondeterministic() const { - static const OperatorSet nondeterministic_ops = { - "aten::dropout(Tensor input, float p, bool train) -> Tensor", - "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)", - "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor", - "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", - "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor", - "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor", - "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", - "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor", - "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor", - "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor", - "aten::poisson(Tensor self, Generator? generator) -> Tensor", - "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", - "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", - "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", - "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; - - if (!isMemberOf(nondeterministic_ops)) { + const auto schema = maybeSchema(); + if (!kind().is_aten()) { return false; } - // Dropout with train = False is deterministic - if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && - is_constant(attr::train) && !get(attr::train).value()) { + // All aten ops are expecte to have a schema. However this is left as a + // warning instead of an assert to ensure that previous use cases do not + // break. + if (!schema) { + TORCH_WARN("aten Schema not found."); return false; } - return true; + torch::utils::SchemaInfo schema_info(*schema); + if (hasNamedInput("train")) { + auto value = constant_as(namedInput("train")); + if (value.has_value()) { + schema_info.addArgumentValue("train", *value); + } + } + return schema_info.is_nondeterministic(); } bool Node::hasSideEffects() const { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index ca10809b5e70c..87eed82594689 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -9,11 +9,11 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -1428,7 +1428,7 @@ struct Graph { TORCH_API void remapTypes(const std::function& type_map); private: - friend void Lint(const AliasDb* db); + friend TORCH_API void Lint(const AliasDb* db); TORCH_API void freeNode(Node* n); TORCH_API void freeValue(Value* v); TORCH_API void freeBlock(Block* b); diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index a7ca90e6bb35a..c983737d94341 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -710,10 +710,34 @@ uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) { mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) { auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content); - FlatbufferLoader loader; - loader.setShouldLoadOperators(false); - mobile::Module m = loader.parseModule(ff_module); - return mobile::get_module_info(m); + mobile::ModuleInfo minfo; + minfo.operator_version = ff_module->operator_version(); + minfo.bytecode_version = ff_module->bytecode_version(); + + uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size(); + if (mobile_ivalue_size == 0) { + mobile_ivalue_size = ff_module->ivalues()->size(); + } + + std::vector type_name_list; + for (uint32_t i = 0; i < mobile_ivalue_size; i++) { + const auto* ival = ff_module->ivalues()->Get(i); + if (const auto* func = ival->val_as_Function()) { + minfo.function_names.insert(func->qn()->str()); + for (const auto* op : *func->operators()) { + at::OperatorName opname(op->name()->str(), op->overload_name()->str()); + minfo.opname_to_num_args[mobile::operator_str(opname)] = + op->num_args_serialized(); + } + for (const auto* type_ann : *func->type_annotations()) { + type_name_list.push_back(type_ann->str()); + } + } + } + c10::TypeParser parser(type_name_list); + parser.parseList(); + minfo.type_names = parser.getContainedTypes(); + return minfo; } mobile::Module load_mobile_module_from_stream_with_copy( diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index a69e409746a6f..29a1ad01a93da 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -85,8 +85,9 @@ bool Function::initialize_operators(bool should_check_operators) { if (should_check_operators) { TORCH_CHECK( unsupported_op_names.empty(), - "Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/", - c10::Join(", ", unsupported_op_names)); + "Following ops cannot be found: [", + c10::Join(", ", unsupported_op_names), + "]. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/c/mobile/"); } code_.initialized = all_ops_supported; return all_ops_supported; diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 0c8fe54a9f6d8..9a3ce0d8f8391 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -560,11 +560,17 @@ mobile::Module _load_for_mobile( std::istream& in, c10::optional device, ExtraFilesMap& extra_files) { - std::shared_ptr data; - size_t size = 0; - std::tie(data, size) = get_stream_content(in); - return _load_mobile_from_bytes( - data, size, device, extra_files, kDefaultMobileLoadOptions); + if (getFileFormat(in) == FileFormat::FlatbufferFileFormat) { + std::shared_ptr data; + size_t size = 0; + std::tie(data, size) = get_stream_content(in); + return _load_mobile_from_bytes( + data, size, device, extra_files, kDefaultMobileLoadOptions); + } + std::unique_ptr rai = std::make_unique(&in); + auto module = _load_for_mobile_impl( + std::move(rai), device, extra_files, kDefaultMobileLoadOptions); + return module; } mobile::Module _load_for_mobile( diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp index 2b5a7bc7937f2..300716fd8e154 100644 --- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp +++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp @@ -440,7 +440,8 @@ c10::IValue preprocess( return cu.serialize(); } -static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess); +// TODO(mvz): temporarily disable NNC backend in mobile builds. +// static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess); } // namespace nnc } // namespace mobile diff --git a/torch/csrc/jit/mobile/nnc/backend.cpp b/torch/csrc/jit/mobile/nnc/backend.cpp index 77742faa5cc38..89a96428a09b0 100644 --- a/torch/csrc/jit/mobile/nnc/backend.cpp +++ b/torch/csrc/jit/mobile/nnc/backend.cpp @@ -51,7 +51,8 @@ class NNCBackend : public PyTorchBackendInterface { }; namespace { -static const auto cls = torch::jit::backend("nnc"); +// TODO(mvz): temporarily disable NNC backend in mobile builds. +// static const auto cls = torch::jit::backend("nnc"); } // namespace } // namespace nnc diff --git a/torch/csrc/jit/mobile/profiler_edge.cpp b/torch/csrc/jit/mobile/profiler_edge.cpp index 72365588c783a..d3dc596ca3dcc 100644 --- a/torch/csrc/jit/mobile/profiler_edge.cpp +++ b/torch/csrc/jit/mobile/profiler_edge.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -66,6 +67,16 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( tls_edge_profiler = this; } +void KinetoEdgeCPUProfiler::recordBackendMemoryEvent( + void* ptr, + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + c10::Device device) { + c10::reportMemoryUsageToProfiler( + ptr, alloc_size, total_allocated, total_reserved, device); +} + void KinetoEdgeCPUProfiler::recordBackendEvent( const int64_t start_time_us, const int64_t end_time_us, diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index d970f8d2207e7..52dc26d1221a7 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -67,6 +67,12 @@ class TORCH_API KinetoEdgeCPUProfiler { const int64_t debug_handle, const std::string& event_name, const std::string& backend_name); + void recordBackendMemoryEvent( + void* ptr, + int64_t alloc_size, + int64_t total_allocated, + int64_t total_reserved, + c10::Device device); ~KinetoEdgeCPUProfiler(); @@ -88,11 +94,20 @@ TORCH_API KinetoEdgeCPUProfiler* getCurrentEdgeProfiler(); mobile::getCurrentEdgeProfiler()->recordBackendEvent( \ start_time_us, end_time_us, debug_handle, event_name, backend_name); \ } + +#define RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( \ + ptr, alloc_size, total_allocated, total_reserved, device) \ + if (mobile::getCurrentEdgeProfiler()) { \ + mobile::getCurrentEdgeProfiler()->recordBackendMemoryEvent( \ + ptr, alloc_size, total_allocated, total_reserved, device); \ + } #else #define RECORD_BACKEND_EVENT_TO_EDGE_PROFILER( \ start_time_us, end_time_us, debug_handle, event_name, backend_name) +#define RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( \ + ptr, alloc_size, total_allocated, total_reserved, device) #endif } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 1a9a965cb8aa9..13d942a7b7c97 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -58,6 +58,11 @@ void size(Stack& stack) { pack(stack, t.sizes().vec()); } +void sym_size(Stack& stack) { + auto t = std::move(pop(stack)).toTensor(); + pack(stack, t.sym_sizes().vec()); +} + void device(Stack& stack) { push(stack, pop(stack).toTensor().device()); } @@ -68,6 +73,10 @@ void dtype(Stack& stack) { push(stack, static_cast(a.scalar_type())); } +void layout(Stack& stack) { + push(stack, pop(stack).toTensor().layout()); +} + void toPrimDType(Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool non_blocking; @@ -191,7 +200,7 @@ void dictIndex(Stack& stack) { push(stack, value->value()); } -static const C10_UNUSED std::array op_reg = { +static const C10_UNUSED std::array op_reg = { mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex), mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor), mobile::prim_op_fn_register("aten::format", aten_format), @@ -201,6 +210,7 @@ static const C10_UNUSED std::array op_reg = { raiseExceptionWithMessage), mobile::prim_op_fn_register("prim::device", device), mobile::prim_op_fn_register("prim::dtype", dtype), + mobile::prim_op_fn_register("prim::layout", layout), mobile::prim_op_fn_register("aten::__not__", _not), mobile::prim_op_fn_register("aten::__is__", is), mobile::prim_op_fn_register("aten::__isnot__", isNot), diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.h b/torch/csrc/jit/mobile/promoted_prim_ops.h index d823bcec878b6..a1ae3f52ba1a7 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.h +++ b/torch/csrc/jit/mobile/promoted_prim_ops.h @@ -19,10 +19,14 @@ void aten_format(Stack& stack); void size(Stack& stack); +void sym_size(Stack& stack); + void device(Stack& stack); void dtype(Stack& stack); +void layout(Stack& stack); + void toPrimDType(Stack& stack); void dim(Stack& stack); diff --git a/torch/csrc/jit/passes/add_if_then_else.cpp b/torch/csrc/jit/passes/add_if_then_else.cpp index 72a085fd021e2..303a0824398fd 100644 --- a/torch/csrc/jit/passes/add_if_then_else.cpp +++ b/torch/csrc/jit/passes/add_if_then_else.cpp @@ -13,7 +13,7 @@ bool hasNoNodes(Block* block) { bool hasTrivialSubBlocks(Node* node) { const auto blocks = node->blocks(); - DCHECK_EQ(blocks.size(), 2); + TORCH_DCHECK_EQ(blocks.size(), 2); return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]); } diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index f1cf02e050ce6..49694d57ef10a 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -545,6 +545,15 @@ bool isConstant(Value* val, const ValueToParamPairMap& valsToParamsMap) { AttributeKind::t); // Check other types? } +bool hasParamInput(Node* n, const ValueToParamPairMap& valsToParamsMap) { + for (auto input : n->inputs()) { + if (valsToParamsMap.find(input) != valsToParamsMap.end()) { + return true; + } + } + return false; +} + std::vector getValues( Node* node, const ValueToParamPairMap& valsToParamsMap) { @@ -638,14 +647,26 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) { // Constant folding is not supported for this op. Skip it. continue; } - // Create a new input to the block (prim::Param node output). Add a - // corresponding entry in valToParamMap. Replace the downstream inputs - // with this value, and disconnect all the input values of the folded node. + at::Tensor updatedVal = *updatedValWrapped; - auto newSourceNodeOutput = b->addInput(); - valsToParamsMap.insert( - {newSourceNodeOutput, - std::make_pair(newSourceNodeOutput->debugName(), updatedVal)}); + auto newSourceNodeOutput = [&]() -> Value* { + if (onnx_constant_fold::hasParamInput(node, valsToParamsMap)) { + // Create a new input to the block (prim::Param node output). Add a + // corresponding entry in valToParamMap. Replace the downstream inputs + // with this value, and disconnect all the input values of the folded + // node. + auto newSourceNodeOutput = b->addInput(); + valsToParamsMap.insert( + {newSourceNodeOutput, + std::make_pair(newSourceNodeOutput->debugName(), updatedVal)}); + return newSourceNodeOutput; + } else { + auto newSourceNode = + createONNXConstant(node->owningGraph(), node, updatedVal); + newSourceNode->copyMetadata(node); + return newSourceNode->output(); + } + }(); newSourceNodeOutput->inferTypeFrom(updatedVal); node->outputs().at(0)->replaceAllUsesWith(newSourceNodeOutput); // Next we remove the current node that has been replaced by diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 7d10f690923ae..29d8072a16bf5 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -1,11 +1,11 @@ #include -#include - #include +#include #include #include #include #include +#include namespace torch { namespace jit { @@ -205,6 +205,15 @@ c10::optional ConstantValueMap::GetShapeValue( return ConstantValueMap::getInstance().shapeValueMap[tensorName]; } +// Gets the inferredShapeData which is obtained by ONNX data propagation +ShapeDataMap& ConstantValueMap::GetInferredShapeData() { + return ConstantValueMap::getInstance().inferredShapeData; +} + +SymbolDimMap& ConstantValueMap::GetSymbolDimMap() { + return ConstantValueMap::getInstance().symbolDimMap; +} + template void UpdateStrKey( Map& map, @@ -236,6 +245,8 @@ void ConstantValueMap::UpdateValueName( ConstantValueMap::getInstance().useInferredTypeMap, old_name, new_name); UpdateStrKey( ConstantValueMap::getInstance().shapeValueMap, old_name, new_name); + UpdateStrKey( + ConstantValueMap::getInstance().inferredShapeData, old_name, new_name); } void ConstantValueMap::ClearMaps() { @@ -245,6 +256,8 @@ void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().typeReliableMap.clear(); ConstantValueMap::getInstance().useInferredTypeMap.clear(); ConstantValueMap::getInstance().shapeValueMap.clear(); + ConstantValueMap::getInstance().inferredShapeData.clear(); + ConstantValueMap::getInstance().symbolDimMap.clear(); } // For debug only. @@ -304,6 +317,34 @@ void ConstantValueMap::PrintMaps() { std::cout << std::endl; } } + std::cout << std::endl; + std::cout << "InferredShape Map:" << std::endl; + count = 0; + for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) { + std::cout << "(node " << x.first << ": "; + for (const auto& dim : x.second.dim()) { + if (dim.has_dim_param()) { + std::cout << dim.dim_param() << " "; + } else { + std::cout << dim.dim_value() << " "; + } + } + std::cout << "), "; + count++; + if (count % 10 == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl; + std::cout << "SymbolDim Map:" << std::endl; + count = 0; + for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) { + std::cout << "(" << x.first << ": " << x.second << "), "; + count++; + if (count % 10 == 0) { + std::cout << std::endl; + } + } } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index d94cd7eea7bc4..b9436c0d6fc75 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -1,16 +1,20 @@ #pragma once +#include #include +#include #include #include namespace torch { namespace jit { +using ShapeDataMap = + std::unordered_map; + class ConstantValueMap { public: static ConstantValueMap& getInstance(); - static void SetRank(const std::string& tensorName, size_t rankValue); static bool HasRank(const std::string& tensorName); static c10::optional GetRank(const std::string& tensorName); @@ -52,6 +56,10 @@ class ConstantValueMap { static c10::optional GetShapeValue( const std::string& tensorName); + static ShapeDataMap& GetInferredShapeData(); + + static SymbolDimMap& GetSymbolDimMap(); + static void UpdateValueName( const std::string& old_name, const std::string& new_name); @@ -81,6 +89,10 @@ class ConstantValueMap { // from a node. shapeValueMap stores the value of the tensor from a node when // this tensor represents a shape. std::unordered_map shapeValueMap; + // Stores earlier data propagation results so that they are accessible + // during future node-level shape inference. + ShapeDataMap inferredShapeData; + SymbolDimMap symbolDimMap; }; } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index 6c2d8809b0a8f..cf21264e4c22c 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -390,13 +390,15 @@ c10::optional FunctionExtractor::InferScope(Node* n) { } else { scope_list scopes; std::copy_if( - input_scopes.begin(), input_scopes.end(), scopes.begin(), IsValidScope); + input_scopes.begin(), + input_scopes.end(), + std::back_inserter(scopes), + IsValidScope); std::copy_if( output_scopes.begin(), output_scopes.end(), - scopes.begin(), + std::back_inserter(scopes), IsValidScope); - if (scopes.size() > 0) { auto common_ancestor = FindCommonAncestor(scopes); if (common_ancestor.has_value() && diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 2dd4da49ff34d..77b58f037d597 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -167,6 +167,16 @@ Node* createONNXUnsqueeze( return unsqueeze_node; } +Node* createONNXConstant( + Graph* graph, + Node* n_to_insert_before, + at::Tensor value) { + Node* constant_node = graph->create(onnx::Constant, 1); + constant_node->insertBefore(n_to_insert_before); + constant_node->t_(attr::value, value); + return constant_node; +} + bool isValidToTransformToONNXConcatNode(Node* lc_node) { return !lc_node->inputs().empty(); } diff --git a/torch/csrc/jit/passes/onnx/helper.h b/torch/csrc/jit/passes/onnx/helper.h index b17d5d00758b7..77eb98ba8a707 100644 --- a/torch/csrc/jit/passes/onnx/helper.h +++ b/torch/csrc/jit/passes/onnx/helper.h @@ -53,6 +53,10 @@ Node* createONNXUnsqueeze( Value* input, int axis, int opset_version); +Node* createONNXConstant( + Graph* graph, + Node* n_to_insert_before, + at::Tensor value); bool isValidToTransformToONNXConcatNode(Node* lc_node); diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 601ec87c21cdb..3571199442d4b 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -143,6 +143,35 @@ static c10::optional InferExpectedScalarType(const Node* n) { } return c10::nullopt; }; + auto emplace_type_from_scalar = + [&typesFromTensors, &typesFromScalars](at::ScalarType scalar_type) { + // Mimic PyTorch scalar type promotion logic + // from https://github.com/pytorch/pytorch/issues/9515 + // Quoting: + // A Tensor is a considered a "wrapped number" if it is + // auto-wrapped from a C++ or Python number type. Integer types are + // wrapped as 0-dim int64 tensors and floating-point types are + // wrapped as 0-dim double tensors. + auto default_scalar_type = + at::typeMetaToScalarType(at::get_default_dtype()); + switch (scalar_type) { + case at::kDouble: + // floating-point numbers wrapped as double tensors are + // considered to have default type, instead of double. + typesFromScalars.emplace_back(default_scalar_type); + break; + case at::kLong: + case at::kBool: + // bool and integer numbers remain the same type. + typesFromScalars.emplace_back(scalar_type); + break; + default: + // other types are not from wrapped numbers, + // track them as types from tensors. + typesFromTensors.emplace_back(scalar_type); + break; + } + }; std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { @@ -162,36 +191,27 @@ static c10::optional InferExpectedScalarType(const Node* n) { auto tensor = input->node()->t(attr::value); auto rank = tensor.dim(); auto scalar_type = tensor.scalar_type(); - // Mimic PyTorch scalar type promotion logic - // from https://github.com/pytorch/pytorch/issues/9515 - // Quoting: - // A Tensor is a considered a "wrapped number" if it is - // auto-wrapped from a C++ or Python number type. Integer types are - // wrapped as 0-dim int64 tensors and floating-point types are - // wrapped as 0-dim double tensors. + if (rank == 0) { - auto default_scalar_type = - at::typeMetaToScalarType(at::get_default_dtype()); - switch (scalar_type) { - case at::kDouble: - // floating-point numbers wrapped as double tensors are - // considered to have default type, instead of double. - typesFromScalars.emplace_back(default_scalar_type); - break; - case at::kLong: - case at::kBool: - // bool and integer numbers remain the same type. - typesFromScalars.emplace_back(scalar_type); - break; - default: - // other types are not from wrapped numbers, - // track them as types from tensors. - typesFromTensors.emplace_back(scalar_type); - break; - } + emplace_type_from_scalar(scalar_type); } else { typesFromTensors.emplace_back(scalar_type); } + } else if (nkind == prim::Param) { + // ONNX doesn't support scalar as graph input. When + // seeing a scalar input, we convert its expected type to tensor. + if (auto scalar_type = get_scalar_type(input)) { + auto tensor_type = input->type()->castRaw(); + // get_scalar_type returns non-null value already guranatees + // that the input has a valid tensor_type. + TORCH_INTERNAL_ASSERT(nullptr != tensor_type); + auto rank = tensor_type->dim(); + if (rank && rank.value() == 0) { + emplace_type_from_scalar(scalar_type.value()); + } else { + typesFromTensors.emplace_back(scalar_type.value()); + } + } } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } @@ -404,6 +424,7 @@ void ScalarTypeAnalysisForONNX( const std::shared_ptr& graph, bool lowprecision_cast, int opset_version) { + GRAPH_DUMP("Before ScalarTypeAnalysisForONNX: ", graph); ImplicitCastForONNX(graph->block()); if (lowprecision_cast) { LowPrecisionCastForStandardOpsONNX(graph->block(), opset_version); diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index cbdd96a24f796..04427743bb6d5 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -90,9 +90,35 @@ namespace { namespace onnx_torch = ::torch::onnx; namespace onnx = ::ONNX_NAMESPACE; +c10::ShapeSymbol ONNXDimToShapeSymbol( + const onnx::TensorShapeProto_Dimension& dim, + SymbolDimMap& symbol_dim_map) { + if (dim.has_dim_value()) { + return c10::ShapeSymbol::fromStaticSize(dim.dim_value()); + } + c10::optional sym = c10::nullopt; + if (dim.has_dim_param()) { + // If this param is already known, assign the same Symbol. + GRAPH_UPDATE("Got dim_param:", dim.dim_param()); + for (const auto& pair : symbol_dim_map) { + if (pair.second == dim.dim_param()) { + sym = pair.first; + break; + } + } + } + if (!sym) { + sym = c10::ShapeSymbol::newSymbol(); + // If dim.dim_param() is empty, no need to keep track + // because there won't be duplicates. + symbol_dim_map[sym.value()] = dim.dim_param(); + } + return sym.value(); +} + TensorTypePtr TorchTensorTypeFromONNX( const onnx::TypeProto_Tensor& onnx_tensor_type, - SymbolDimMap& symbol_map) { + SymbolDimMap& symbol_dim_map) { c10::optional scalar_type; if (onnx_tensor_type.has_elem_type()) { scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type()); @@ -109,35 +135,8 @@ TensorTypePtr TorchTensorTypeFromONNX( const auto& onnx_shape = onnx_tensor_type.shape(); for (const auto i : c10::irange(onnx_shape.dim_size())) { - auto& dim = onnx_shape.dim(i); - if (dim.has_dim_value()) { - sizes.emplace_back(c10::ShapeSymbol::fromStaticSize(dim.dim_value())); - } else { - c10::optional sym = c10::nullopt; - if (dim.has_dim_param()) { - // A specific dim param is produced. - // Search if this is already known, - // and assign the same Symbol. - GRAPH_UPDATE("Got dim_param:", dim.dim_param()); - for (const auto& pair : symbol_map) { - if (pair.second == dim.dim_param()) { - sym = pair.first; - break; - } - } - if (!sym) { - sym = c10::ShapeSymbol::newSymbol(); - symbol_map[sym.value()] = dim.dim_param(); - } - } else { - // A None dim param is produced. - // Assign a new Symbol, no need to keep track - // of it because there won't be duplicates. - sym = c10::ShapeSymbol::newSymbol(); - symbol_map[sym.value()] = ""; - } - sizes.emplace_back(sym.value()); - } + sizes.emplace_back( + ONNXDimToShapeSymbol(onnx_shape.dim(i), symbol_dim_map)); } v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {}); v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes)); @@ -154,14 +153,14 @@ TensorTypePtr TorchTensorTypeFromONNX( ListTypePtr TorchListTypeFromONNX( const onnx::TypeProto_Sequence& onnx_sequence_type, - SymbolDimMap& symbol_map) { + SymbolDimMap& symbol_dim_map) { c10::optional scalar_type; if (onnx_sequence_type.has_elem_type()) { const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type(); if (onnx_seq_elem_type.has_tensor_type()) { const auto& onnx_tensor_type = onnx_seq_elem_type.tensor_type(); const auto v_tensor_type = - TorchTensorTypeFromONNX(onnx_tensor_type, symbol_map); + TorchTensorTypeFromONNX(onnx_tensor_type, symbol_dim_map); auto v_type = ListType::create(v_tensor_type); return v_type; } @@ -172,7 +171,7 @@ ListTypePtr TorchListTypeFromONNX( void UpdateTorchValueByOnnxValueInfo( Value* v, const onnx::ValueInfoProto& p_info, - SymbolDimMap& symbol_map) { + SymbolDimMap& symbol_dim_map) { if (!p_info.has_type()) { return; } @@ -180,13 +179,13 @@ void UpdateTorchValueByOnnxValueInfo( const auto& p_type = p_info.type(); if (p_type.has_tensor_type()) { const auto torch_tensor_type = - TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map); + TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_dim_map); if (torch_tensor_type) { MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type); } } else if (p_type.has_sequence_type()) { const auto torch_list_type = - TorchListTypeFromONNX(p_type.sequence_type(), symbol_map); + TorchListTypeFromONNX(p_type.sequence_type(), symbol_dim_map); if (torch_list_type) { MergeInferredTypeAndSetMap(v, v->type(), torch_list_type); } @@ -353,15 +352,16 @@ bool IsGraphValidForInference(std::shared_ptr graph) { void ConvertGraphToONNXProto( std::shared_ptr graph, std::shared_ptr& model_proto, - SymbolDimMap& symbol_map, + SymbolDimMap& symbol_dim_map, int opset_version) { RawDataExportMap export_map; bool val_use_external_data_format; + SymbolDimMap new_symbol_dim_map; NodeNameMap node_names; std::tie( model_proto, export_map, - symbol_map, + new_symbol_dim_map, val_use_external_data_format, node_names) = export_onnx( @@ -377,6 +377,7 @@ void ConvertGraphToONNXProto( true, false, std::string()); + symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end()); for (int i = 0; i < model_proto->graph().output_size(); ++i) { model_proto->mutable_graph()->mutable_output(i)->clear_type(); } @@ -1240,10 +1241,6 @@ void ComputeConstant(Node* n, int opset_version) { 1, c10::ShapeSymbol::fromStaticSize(shape_value_size)); ::c10::SymbolicShape final_shape(final_shape_vector); UpdateShape(n->output(), final_shape); - } else if (ConstantValueMap::HasShape(n->input()->debugName())) { - ConstantValueMap::SetShapeValue( - n->output()->debugName(), - ConstantValueMap::GetShape(n->input()->debugName()).value()); } break; } @@ -1251,28 +1248,6 @@ void ComputeConstant(Node* n, int opset_version) { ProcessReshapeNode(n, opset_version); break; } - case ::c10::onnx::Gather: { - if (ConstantValueMap::HasShapeValue(n->input(0)->debugName()) && - ConstantValueMap::HasValue(n->input(1)->debugName())) { - // Special case for pattern Shape -> Gather, to propagate shape value. - // Gather input 0 is 1d tensor, Gather input 1 is scalar. - // Gather output will be scalar. - auto shape_value = - ConstantValueMap::GetShapeValue(n->input(0)->debugName()).value(); - auto idx_value = - ConstantValueMap::GetValue(n->input(1)->debugName()).value(); - // Consider the case when Gather index is a scalar. - if (idx_value.dim() == 0) { - auto idx_value_0 = idx_value.item(); - if (idx_value_0 >= 0) { - std::vector dims = {shape_value.at(idx_value_0)}; - c10::SymbolicShape symShape(dims); - ConstantValueMap::SetShapeValue(n->output()->debugName(), symShape); - } - } - } - break; - } case ::c10::onnx::Transpose: { if (n->hasAttributeS("perm")) { auto perm_v = n->is(attr::perm); @@ -1545,6 +1520,12 @@ void ProcessConstantValueMap(Node* n, int opset_version) { // For outputs, only update static shapes. For input, we update symbolic // shapes also. ONNX If can have different types on different branches, skip // here. + + // Update the shape reliability for each node before processing + // ConstantValueMap to prevent unreliable nodes from producing static + // shapes + UpdateReliable(n); + auto static_input_shape = AllGraphInputsStatic(n->owningGraph()); for (auto i : c10::irange(n->outputs().size())) { if (TensorTypePtr output_type = n->output(i)->type()->cast()) { @@ -1767,7 +1748,7 @@ void UpdateOutputTypeByONNXProto( Node* n, Node* clone_node, const onnx::ModelProto& model_proto, - SymbolDimMap& symbol_map) { + SymbolDimMap& symbol_dim_map) { const auto& graph_proto = model_proto.graph(); // get data from value_info and updated original graph. @@ -1775,7 +1756,8 @@ void UpdateOutputTypeByONNXProto( [&](const onnx::ValueInfoProto& v_info) { for (size_t i = 0; i < n->outputs().size(); ++i) { if (clone_node->output(i)->debugName() == v_info.name()) { - UpdateTorchValueByOnnxValueInfo(n->output(i), v_info, symbol_map); + UpdateTorchValueByOnnxValueInfo( + n->output(i), v_info, symbol_dim_map); } } }; @@ -1942,6 +1924,12 @@ void ONNXShapeTypeInference( Node* n, const ParamMap& params_dict, int opset_version) { + std::unordered_map torch_to_onnx_input; + std::unordered_map torch_to_onnx_output; + auto& original_shape_data = ConstantValueMap::GetInferredShapeData(); + ShapeDataMap inferred_shape_data; + auto& symbol_dim_map = ConstantValueMap::GetSymbolDimMap(); + SetGraphInputTypeReliable(n->owningGraph()); GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); @@ -1957,6 +1945,25 @@ void ONNXShapeTypeInference( n_graph->registerOutput(output); } + // Map original PyTorch graph's i/o name + // to temporal ONNX graph's i/o name for shape inference + for (size_t i = 0; i < clone_node->inputs().size(); ++i) { + torch_to_onnx_input[n->input(i)->debugName()] = + clone_node->input(i)->debugName(); + } + + for (size_t i = 0; i < clone_node->outputs().size(); ++i) { + torch_to_onnx_output[n->output(i)->debugName()] = + clone_node->output(i)->debugName(); + } + // Make inferred_shape_data use name from temporal ONNX graph + // instead of original PyTorch graph + for (const auto& gs_data : original_shape_data) { + const auto onnx_output_name = torch_to_onnx_input.find(gs_data.first); + if (onnx_output_name != torch_to_onnx_input.end()) { + inferred_shape_data[onnx_output_name->second] = gs_data.second; + } + } // Use scalar_type_analysis without low precision cast ScalarTypeAnalysisForONNX(n_graph, false, opset_version); @@ -1969,15 +1976,33 @@ void ONNXShapeTypeInference( // The conversion here is incomplete for these ops. // e.g: ListConstruct, ListUnpack, etc. std::shared_ptr model_proto; - SymbolDimMap symbol_map; - ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version); + ConvertGraphToONNXProto( + n_graph, model_proto, symbol_dim_map, opset_version); GRAPH_DEBUG( "ONNX graph to run shape inference: ", prettyPrint(*model_proto)); // infer shape try { - onnx::shape_inference::InferShapes(*model_proto); - UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map); + // TODO(#79208): Enable more operators to support data propagation + switch (n->kind()) { + case ::c10::onnx::Shape: + case ::c10::onnx::Gather: { + auto* schema_registry = onnx::OpSchemaRegistry::Instance(); + onnx::ShapeInferenceOptions options{ + /*check_type=*/false, + /*error_mode=*/false, + /*enable_data_propagation=*/true}; + onnx::shape_inference::InferShapes( + *model_proto, schema_registry, options, &inferred_shape_data); + break; + } + default: { + onnx::shape_inference::InferShapes(*model_proto); + break; + } + } + UpdateOutputTypeByONNXProto( + n, clone_node, *model_proto, symbol_dim_map); } catch (std::runtime_error& ex) { // TODO: include this as warning once we have a more consolidated // warning system. @@ -2003,6 +2028,28 @@ void ONNXShapeTypeInference( } SpecialPostProcess(n); + // Get data propagation result from ONNX shape inference + for (const auto& output : n->outputs()) { + const auto inferred_shape_pair = + inferred_shape_data.find(torch_to_onnx_output[output->debugName()]); + if (inferred_shape_pair != inferred_shape_data.end()) { + const auto& inferred_shape = inferred_shape_pair->second; + int rank = inferred_shape.dim_size(); + std::vector<::c10::ShapeSymbol> final_shape(rank); + for (int i = 0; i < rank; ++i) { + final_shape[i] = + ONNXDimToShapeSymbol(inferred_shape.dim(i), symbol_dim_map); + } + c10::SymbolicShape shape_value(final_shape); + // Store data propagation result into shapeValueMap + ConstantValueMap::SetShapeValue(output->debugName(), shape_value); + // Use original name in PyTorch graph instead of + // temporary name in intermediate ONNX graph + // Add this back to original_shape_data + original_shape_data[output->debugName()] = inferred_shape; + } + } + if (IsValidONNXNode(n)) { ProcessConstantValueMap(n, opset_version); if (n->kind() != prim::ListConstruct) { diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index 34248b351e86a..afda5b1765377 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -82,5 +82,7 @@ void UpdateReliable( torch::jit::Value* output, const std::pair& input_reliable); +void UpdateReliable(torch::jit::Node* n); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index a4a307b51b879..d04970137a20a 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1391,8 +1391,8 @@ class ShapePropagator : public PropertyPropBase { node, /*num_reduced_dim=*/0, /*upcast_integer=*/false, opt_dtype); }}; - static const auto factory_with_ndim = [](Node* node, - int dim) -> type_vec_t { + static const auto factory_with_ndim = + [](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t { at::optional maybe_layout_option = node->get(attr::layout); if (!maybe_layout_option) return {}; @@ -1408,7 +1408,7 @@ class ShapePropagator : public PropertyPropBase { if (!maybe_dtype_option) return {}; auto dtype = - (maybe_dtype_option->isNone() ? at::kDouble + (maybe_dtype_option->isNone() ? default_dtype : maybe_dtype_option->toScalarType()); return {TensorType::create( @@ -1491,12 +1491,23 @@ class ShapePropagator : public PropertyPropBase { "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + }, + [](Node* node) -> type_vec_t { + if (auto maybe_size = node->get>(attr::size)) { + return factory_with_ndim( + node, (int)maybe_size->size(), at::kDouble); + } + return {}; + }}; + + static const register_formula_for randint{ + { "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", }, [](Node* node) -> type_vec_t { if (auto maybe_size = node->get>(attr::size)) { - return factory_with_ndim(node, (int)maybe_size->size()); + return factory_with_ndim(node, (int)maybe_size->size(), at::kLong); } return {}; }}; @@ -1980,7 +1991,7 @@ class ShapePropagator : public PropertyPropBase { return true; } else if ( node->matches( - "aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor", + "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor", /*const_inputs=*/{attr::dim, attr::keepdim})) { auto& tp = tensor_types.at(0); auto sizes = tp->sizes().concrete_sizes().value(); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 8f81c8189e6cf..fe63de1a04d40 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -94,7 +94,7 @@ bool isSupported(Node* node) { static const OperatorSet supported_reduction_set{ "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor", "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", }; diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 00dd3e6ce852c..88baf90054af3 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -99,6 +99,24 @@ void insertPrePackedGruOp(std::shared_ptr& graph) { gru_rewriter.runOnGraph(graph); } +void insertPrePackedLstmOp(std::shared_ptr& graph) { + std::string lstm_pattern = R"( + graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): + %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) + return (%y.1, %hn.1, %cn.1) )"; + std::string prepacked_ops_pattern = R"( + graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): + %packed_weights_biases = vulkan_prepack::create_lstm_context( + %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first) + %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx) + %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases) + return (%y.1, %hn.1, %cn.1) )"; + + SubgraphRewriter lstm_rewriter; + lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern); + lstm_rewriter.runOnGraph(graph); +} + void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { SubgraphRewriter rewriter; @@ -188,6 +206,7 @@ void vulkanInsertPrePackedOps(std::shared_ptr& graph) { insertPrePackedLinearOp(graph); insertPrePackedConv2dOp(graph); insertPrePackedGruOp(graph); + insertPrePackedLstmOp(graph); } void vulkanInsertPrePackedOps(script::Module& module) { diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index db236a51e2f03..436684e7212cf 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -12,7 +13,7 @@ #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) #include #endif -#include +#include #include #include #include @@ -103,8 +104,6 @@ #include #include -#include - #include #include #include @@ -121,31 +120,32 @@ namespace torch { namespace jit { -using ::c10::AliasInfo; -using ::c10::Argument; -using ::c10::FunctionSchema; +using c10::AliasInfo; +using c10::Argument; +using c10::FunctionSchema; +using c10::SchemaArgType; +using c10::SchemaArgument; +using c10::SymIntNode; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::PyTorchStreamWriter; +using torch::utils::SchemaInfo; -static std::shared_ptr toSymIntNode( - std::shared_ptr a, - py::object b) { - return torch::is_symint_node(b) - ? b.cast>() - : a->wrap(b.cast()); +static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) { + return torch::is_symint_node(b) ? b.cast() + : a->wrap(b.cast()); } -class PythonSymbolicIntNode : public c10::SymbolicIntNode { +class PythonSymIntNodeImpl : public c10::SymIntNodeImpl { public: - PythonSymbolicIntNode(py::object pyobj) : c10::SymbolicIntNode() { + PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() { pyobj_ = std::make_shared( pyobj.release().ptr(), getPyInterpreter()); }; - virtual std::shared_ptr wrap(int64_t num) override { + virtual SymIntNode wrap(int64_t num) override { py::gil_scoped_acquire acquire; auto r = getPyObj().attr("wrap")(num); - return std::make_shared(r); + return c10::make_intrusive(r); } virtual bool bool_() override { @@ -163,53 +163,53 @@ class PythonSymbolicIntNode : public c10::SymbolicIntNode { return getPyObj().attr("__str__")().cast(); } - virtual std::shared_ptr dispatch_common_( + virtual SymIntNode dispatch_common_( const char* fname, - const std::shared_ptr& other) { - auto pother = std::dynamic_pointer_cast(other); + const SymIntNode& other) { + auto pother = dynamic_cast(other.get()); TORCH_CHECK(pother); py::gil_scoped_acquire acquire; auto r = getPyObj().attr(fname)(pother->getPyObj()); - return std::make_shared(r); + return c10::make_intrusive(r); + } + + virtual SymIntNode add(const SymIntNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + virtual SymIntNode sub(const SymIntNode& other) override { + return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr add( - const std::shared_ptr& other) override { + virtual SymIntNode mul(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr sub( - const std::shared_ptr& other) override { + virtual SymIntNode div(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr mul( - const std::shared_ptr& other) override { + virtual SymIntNode mod(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr div( - const std::shared_ptr& other) override { + virtual SymIntNode eq(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr mod( - const std::shared_ptr& other) override { + virtual SymIntNode gt(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr eq( - const std::shared_ptr& other) override { + virtual SymIntNode lt(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr gt( - const std::shared_ptr& other) override { + virtual SymIntNode le(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } - virtual std::shared_ptr lt( - const std::shared_ptr& other) override { + virtual SymIntNode ge(const SymIntNode& other) override { return dispatch_common_(__FUNCTION__, other); } @@ -232,6 +232,12 @@ bool loadPythonClasses() { return true; } + +bool isEmptyContainer(const py::handle self) { + bool is_empty_list = + PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr()); + return is_empty_list; +} } // anonymous namespace #if !defined(USE_ROCM) @@ -1172,100 +1178,95 @@ void initJITBindings(PyObject* module) { } }); - py::class_>( - m, "SymbolicIntNode") + py::class_(m, "SymIntNode") .def_static( "new_symint", - [](py::object obj) -> std::shared_ptr { - return std::make_shared(obj); + [](py::object obj) -> c10::SymIntNode { + return c10::make_intrusive(obj); }) .def( "get_pyobj", - [](std::shared_ptr a) -> py::object { - if (auto psn = - std::dynamic_pointer_cast(a)) { + [](c10::SymIntNode a) -> py::object { + if (auto* psn = dynamic_cast(a.get())) { return py::reinterpret_borrow(psn->getPyObj()); } return py::none(); }) .def( "__add__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->add(snb); }) .def( "__radd__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->add(snb); }) .def( "__sub__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->sub(snb); }) .def( "__mul__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->mul(snb); }) .def( "__rmul__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->mul(snb); }) .def( "__div__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->div(snb); }) .def( "__mod__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->mod(snb); }) .def( "__eq__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->eq(snb); }) .def( "__gt__", - [](std::shared_ptr a, py::object b) { + [](c10::SymIntNode a, py::object b) { auto snb = toSymIntNode(a, b); return a->gt(snb); }) .def( "__lt__", - [](std::shared_ptr a, - py::object b) -> std::shared_ptr { + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->lt(snb); }) .def( - "__bool__", - [](std::shared_ptr a) { return a->bool_(); }) + "__le__", + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { + auto snb = toSymIntNode(a, b); + return a->le(snb); + }) .def( - "__int__", - [](std::shared_ptr a) { return a->int_(); }) - .def("__str__", [](std::shared_ptr a) { - return a->str(); - }); + "__ge__", + [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { + auto snb = toSymIntNode(a, b); + return a->ge(snb); + }) + .def("__bool__", [](c10::SymIntNode a) { return a->bool_(); }) + .def("__int__", [](c10::SymIntNode a) { return a->int_(); }) + .def("__str__", [](c10::SymIntNode a) { return a->str(); }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") @@ -1339,8 +1340,13 @@ void initJITBindings(PyObject* module) { .def(py::init()) .def(py::init([](const py::object& buffer) { auto writer_func = [=](const void* data, size_t size) { - auto bytes = py::bytes(reinterpret_cast(data), size); - buffer.attr("write")(std::move(bytes)); + // Writting an empty file is a noop + if (size == 0) { + return size; + } + auto memory_view = py::memoryview::from_memory( + reinterpret_cast(data), size); + buffer.attr("write")(std::move(memory_view)); return size; }; return std::make_unique(std::move(writer_func)); @@ -1548,10 +1554,15 @@ void initJITBindings(PyObject* module) { try { auto symbol = Symbol::fromQualString(op_name); auto operations = getAllOperatorsFor(symbol); + bool allow_numbers_as_tensors = symbol.is_prims() || + (symbol.is_aten() && + torch::should_allow_numbers_as_tensors(symbol.toUnqualString())); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { - auto func = py::cpp_function( - [op, symbol](py::args args, py::kwargs kwargs) { + auto func = + py::cpp_function([op, symbol, allow_numbers_as_tensors]( + py::args args, py::kwargs kwargs) { + ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( {op}, symbol, args, kwargs, true); }); @@ -1587,8 +1598,14 @@ void initJITBindings(PyObject* module) { overload_names.append(py::str(op->schema().overload_name())); } + bool allow_numbers_as_tensors = symbol.is_prims() || + (symbol.is_aten() && + torch::should_allow_numbers_as_tensors(symbol.toUnqualString())); + auto func = py::cpp_function( - [operations, symbol](py::args args, py::kwargs kwargs) { + [operations, symbol, allow_numbers_as_tensors]( + py::args args, py::kwargs kwargs) { + ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( operations, symbol, args, kwargs, false); }, @@ -1622,7 +1639,79 @@ void initJITBindings(PyObject* module) { } return type.value(); }); - + py::enum_(m, "_SchemaArgType") + .value("input", SchemaArgType::input) + .value("output", SchemaArgType::output); + py::class_(m, "_SchemaArgument") + .def(py::init()) + .def_readwrite("type", &SchemaArgument::type) + .def_readwrite("index", &SchemaArgument::index); + py::class_(m, "_SchemaInfo") + .def(py::init()) + .def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); }) + .def( + "is_mutable", + [](SchemaInfo& self, const SchemaArgument& argument) { + return self.is_mutable(argument); + }) + .def( + "is_mutable", + [](SchemaInfo& self, const std::string& name) { + return self.is_mutable(name); + }) + .def( + "may_alias", + [](SchemaInfo& self, + const SchemaArgument& lhs, + const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); }) + .def( + "may_contain_alias", + [](SchemaInfo& self, + const SchemaArgument& lhs, + const SchemaArgument& rhs) { + return self.may_contain_alias(lhs, rhs); + }) + .def( + "add_argument_value", + [](SchemaInfo& self, + const std::string& name, + const py::object& value) { + if (isEmptyContainer(value)) { + return; + } + // For normalization purposes there is an inconsistency within + // torch.fx that turns all arguments named "self" into "input". Thus + // this check ensures that those arguments are checked correctly. + if (name == "input" && !self.hasInputArgumentNamed("input")) { + self.addArgumentValue("self", toTypeInferredIValue(value)); + } else { + self.addArgumentValue(name, toTypeInferredIValue(value)); + } + }) + .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) { + std::unordered_map value_map; + for (const auto& key_pair : values) { + IValue key = toTypeInferredIValue(key_pair.first); + if (isEmptyContainer(key_pair.second)) { + continue; + } + IValue value = toTypeInferredIValue(key_pair.second); + TORCH_INTERNAL_ASSERT( + key.isString(), + "Add argument value keys types should be strings."); + // For normalization purposes there is an inconsistency within + // torch.fx that + // turns all arguments named "self" into "input". Thus this check + // ensures that those arguments are checked correctly. + if (key.toStringRef() == "input" && + !self.hasInputArgumentNamed("input")) { + self.addArgumentValue("self", value); + } else { + value_map[key.toStringRef()] = value; + } + } + self.addArgumentValues(value_map); + }); py::class_(m, "FunctionSchema") .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) @@ -1789,8 +1878,17 @@ void initJITBindings(PyObject* module) { return nullptr; }), py::call_guard()); - m.def("_is_alias_of", [](const at::Tensor& self, const at::Tensor& other) { - return self.is_alias_of(other); + m.def("_is_alias_of", [](const py::object& self, const py::object& other) { + if (isEmptyContainer(self) || isEmptyContainer(other)) { + return false; + } + return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other)); + }); + m.def("_overlaps", [](const py::object& self, const py::object& other) { + if (isEmptyContainer(self) || isEmptyContainer(other)) { + return true; + } + return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other)); }); m.def("fork", [](const py::args& args, const py::kwargs& kwargs) { AT_ASSERT(args.size() >= 1); diff --git a/torch/csrc/jit/python/module_python.h b/torch/csrc/jit/python/module_python.h index 544c4d464aca6..35e2fc54b576a 100644 --- a/torch/csrc/jit/python/module_python.h +++ b/torch/csrc/jit/python/module_python.h @@ -2,6 +2,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 9ba524e7150f9..9dd755acf4e9f 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -8,10 +8,22 @@ #include #include +#include namespace torch { namespace jit { +static thread_local bool allow_numbers_as_tensors = false; + +ToIValueAllowNumbersAsTensors::ToIValueAllowNumbersAsTensors(bool enable) + : old_(allow_numbers_as_tensors) { + allow_numbers_as_tensors = enable; +} + +ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() { + allow_numbers_as_tensors = old_; +} + // This is a hack to remove instances deleted in C++ from the PyBind cache // C++->Python. We need this because otherwise we may get the old Python object // if C++ creates a new object at the memory location of the deleted object. @@ -38,6 +50,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { guardAgainstNamedTensor(var); return var; } else { + if (!allow_numbers_as_tensors) { + throw py::cast_error( + c10::str("Unable to cast ", py::str(obj), " to Tensor")); + } at::Scalar scalar; if (PyBool_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackBool(obj.ptr())); @@ -65,7 +81,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { return static_cast>(c_obj); } case TypeKind::SymIntType: - return py::cast(obj); + return torch::is_symint_node(obj) + ? obj.cast()->toSymInt() + : c10::SymInt{py::cast(obj)}; case TypeKind::IntType: // NB: Typically, these switches are completely dead, because // Argument::type() will always report IntType for these types. @@ -172,6 +190,17 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { } return repeated; } + case TypeKind::SymIntType: { + c10::List symints; + for (auto it = obj.begin(); it != obj.end(); it++) { + auto elm = *it; + auto si = torch::is_symint_node(elm) + ? elm.cast()->toSymInt() + : c10::SymInt{py::cast(elm)}; + symints.push_back(si); + } + return symints; + } case TypeKind::FloatType: if (!N || !py::isinstance(obj)) { return IValue(py::cast>(obj)); @@ -362,9 +391,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { throw py::cast_error( c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str())); } + case TypeKind::GeneratorType: + return py::cast(obj); case TypeKind::DynamicType: case TypeKind::FunctionType: - case TypeKind::GeneratorType: case TypeKind::QuantizerType: case TypeKind::VarType: case TypeKind::AnyListType: diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 6dee87e10d005..5386d30bc2bb7 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -66,6 +66,17 @@ TORCH_API IValue toIValue( py::object toPyObject(IValue ivalue); +// Hack to overload the behavior of toIValue to accept Python +// numbers in places where a Tensor is expected +// See also torch::should_allow_numbers_as_tensors +class ToIValueAllowNumbersAsTensors { + bool old_; + + public: + ToIValueAllowNumbersAsTensors(bool enable); + ~ToIValueAllowNumbersAsTensors(); +}; + // Wrap Python function to guard deref // NB: Need VISIBILITY_HIDDEN for silencing compiler error, // 'torch::jit::PythonFunctionGuard' declared with greater visibility than the @@ -720,6 +731,8 @@ inline py::object toPyObject(IValue ivalue) { } } else if (ivalue.isStorage()) { return py::cast(ivalue.toStorage()); + } else if (ivalue.isGenerator()) { + return py::cast(ivalue.toGenerator()); } else if (ivalue.isDouble()) { return py::cast(std::move(ivalue).toDouble()); } else if (ivalue.isComplexDouble()) { @@ -839,7 +852,7 @@ inline py::object toPyObject(IValue ivalue) { #endif } else if (ivalue.isSymInt()) { auto si = ivalue.toSymInt(); - return si.is_symbolic() ? py::cast(si.toSymbolicIntNode()) + return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()) : py::cast(si.expect_int()); } else { AT_ERROR( @@ -1221,13 +1234,6 @@ inline py::object _get_operation_for_overload_or_packet( } if (overloaded_args.size() > 0 || at::impl::PythonTorchFunctionTLS::get_mode()) { - std::vector overloaded_types; - overloaded_types.reserve(overloaded_args.size()); - for (auto& oarg : overloaded_args) { - overloaded_types.push_back( - py::reinterpret_borrow((PyObject*)Py_TYPE(oarg.ptr()))); - } - py::tuple py_types = py::cast(overloaded_types); py::object ret; std::string ns = symbol.ns().toUnqualString(); std::string method_name = symbol.toUnqualString(); diff --git a/torch/csrc/jit/python/python_dict.cpp b/torch/csrc/jit/python/python_dict.cpp index ab812a18ea732..88b72cd5d0fff 100644 --- a/torch/csrc/jit/python/python_dict.cpp +++ b/torch/csrc/jit/python/python_dict.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/jit/python/python_interpreter.cpp b/torch/csrc/jit/python/python_interpreter.cpp index 29b7929fcd690..d54e3d7a0f628 100644 --- a/torch/csrc/jit/python/python_interpreter.cpp +++ b/torch/csrc/jit/python/python_interpreter.cpp @@ -19,6 +19,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index 258afc095f908..f2433c4f12a55 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace py = pybind11; diff --git a/torch/csrc/jit/python/python_list.cpp b/torch/csrc/jit/python/python_list.cpp index 6a30c24c618d9..96243294945d8 100644 --- a/torch/csrc/jit/python/python_list.cpp +++ b/torch/csrc/jit/python/python_list.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace torch { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 53fd5c5777867..705731778dc35 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -219,7 +220,6 @@ std::shared_ptr PythonModuleValue::attr( return toSugaredValue(member, m, loc, /*is_constant=*/true); } -#if !defined(USE_ROCM) std::shared_ptr CUDAPythonModuleValue::attr( const SourceRange& loc, GraphFunction& m, @@ -259,7 +259,6 @@ std::shared_ptr CUDAPythonModuleValue::attr( // even though it is possible, though rare, for someone to mutate them return toSugaredValue(member, m, loc, /*is_constant=*/true); } -#endif Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) { return self_; @@ -1199,12 +1198,10 @@ std::shared_ptr toSugaredValue( if (auto callee = as_function(obj)) { return std::make_shared(callee->function_); } else if (py::isinstance(obj)) { -#ifndef USE_ROCM std::string obj_name = py::cast(py::getattr(obj, "__name__")); if (obj_name.compare("torch.cuda") == 0) { return std::make_shared(obj); } -#endif return std::make_shared(obj); } else if ( obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() || diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index ad393450bc8ef..b49a0cc948b9b 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -96,7 +96,6 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { // Used for desugaring uses of the torch.cuda module. All the CUDA APIs with // torch.cuda.* are resolved using CUDAPythonModuleValue. -#if !defined(USE_ROCM) struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { explicit CUDAPythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} @@ -106,7 +105,6 @@ struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { GraphFunction& m, const std::string& field) override; }; -#endif // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 4f2d983c025fd..494265e161849 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -27,16 +27,20 @@ namespace tracer { std::vector _pythonCallstack() { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); + Py_INCREF(frame); std::vector entries; while (nullptr != frame) { - size_t line = PyCode_Addr2Line(frame->f_code, frame->f_lasti); - std::string filename = THPUtils_unpackString(frame->f_code->co_filename); - std::string funcname = THPUtils_unpackString(frame->f_code->co_name); + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + size_t line = PyCode_Addr2Line(code.get(), PyFrame_GetLasti(frame)); + std::string filename = THPUtils_unpackString(code->co_filename); + std::string funcname = THPUtils_unpackString(code->co_name); auto source = std::make_shared(funcname, filename, line); entries.emplace_back( StackEntry{funcname, SourceRange(source, 0, funcname.size())}); - frame = frame->f_back; + auto new_frame = PyFrame_GetBack(frame); + Py_DECREF(frame); + frame = new_frame; } return entries; } diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 1b8484ead5df0..5b50b787f7c52 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -487,7 +488,12 @@ static void setInputTensorTypes( auto s_iter = stack.begin(); size_t list_idx = 0; if (!param_count_list.empty()) { - TORCH_INTERNAL_ASSERT(input_values.size() == param_count_list.size()); + TORCH_INTERNAL_ASSERT( + input_values.size() == param_count_list.size(), + " input_values:", + input_values.size(), + " vs param_count_list:", + param_count_list.size()); } for (auto v : input_values) { // Leave packed param types alone. This is needed for downstream passes diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index a8dc3cc7859a7..7e1f6182534cc 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -389,7 +389,7 @@ bool outputRequiresGrad(Value* output) { static ReverseDetails addReverseInline(Gradient& grad_desc) { auto& graph = *grad_desc.f; // note: reverse_node is intentionally not inserted to avoid - // accidentally acting on it (e.g. in elminate dead code), + // accidentally acting on it (e.g. in eliminate dead code), // std::cout << *reverse_node << to view its state. auto reverse_node = graph.create(prim::Reverse, 0); auto reverse_block = reverse_node->addBlock(); diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp index bd74c2d5e646a..a4a4605c7d8af 100644 --- a/torch/csrc/jit/runtime/register_cuda_ops.cpp +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -1,6 +1,5 @@ // This file registers special JIT operators used to implement the PyTorch CUDA // API in TorchScript. -#if !defined(USE_ROCM) #include #include #include @@ -167,4 +166,3 @@ RegisterOperators const reg({ } // namespace } // namespace jit } // namespace torch -#endif diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index c7b2527638bd2..e35b79611730f 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -148,14 +148,6 @@ static const std::vector opGenArgs{ push(stack, a.mH()); }, aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("prim::layout(Tensor a) -> int"), - [](Stack& stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.layout()); - }, - aliasAnalysisFromSchema()), // only used internally in range() translation OperatorGeneratorArgs( @@ -418,6 +410,17 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"), size, aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::sym_size(Tensor self) -> SymInt[]"), + sym_size, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"), + [](Stack& stack) { + at::Tensor arg = pop(stack).toTensor(); + push(stack, arg.strides()); + }, + aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"), [](Stack& stack) { @@ -470,6 +473,10 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"), dtype, aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::layout(Tensor a) -> Layout"), + layout, + aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"), _not, diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index 47e8b321512d5..50e34d2402df6 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -2115,11 +2115,15 @@ def transpose(self: List[int], return _1 )=====") -+ std::string(R"=====(def mean_dim(self: List[int], - dims: List[int], ++ std::string(R"=====(def sum_mean_dim(self: List[int], + opt_dims: Optional[List[int]], keep_dim: bool, dt: Any) -> List[int]: out = annotate(List[int], []) + if opt_dims is None: + dims:List[int] = [] + else: + dims = opt_dims for idx in range(torch.len(self)): is_mean_dim = False for _0 in range(torch.len(dims)): @@ -2748,8 +2752,8 @@ const OperatorMap& GetShapeFunctionMappings() { {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"}, {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"}, {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"}, - {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, - {"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, + {"aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"}, + {"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"}, {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"}, {"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"}, {"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"}, diff --git a/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h b/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h index c7e7014a2fb54..6e185cb11fe1f 100644 --- a/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h +++ b/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h @@ -22,7 +22,7 @@ class ProcessedNodeInputs { ProcessedNodeInputs() : ProcessedNodeInputs(0) {} explicit ProcessedNodeInputs(size_t size) { - DCHECK_LT(size, (1 << 16)); + TORCH_DCHECK_LT(size, (1 << 16)); if (size <= kMaxInlineInputs) { repr_.inline_repr_.size = size; } else { @@ -36,7 +36,7 @@ class ProcessedNodeInputs { uint16_t& operator[](uint16_t idx) { if (C10_LIKELY(repr_.is_inline())) { - DCHECK_LT(idx, repr_.inline_repr_.size); + TORCH_DCHECK_LT(idx, repr_.inline_repr_.size); return repr_.inline_repr_.inputs[idx]; } else { return repr_.outline_repr_[idx]; @@ -102,12 +102,12 @@ class ProcessedNodeInputs { } uint16_t operator[](uint16_t idx) const { - DCHECK_LT(idx, size()); + TORCH_DCHECK_LT(idx, size()); return array_[idx + 1]; } uint16_t& operator[](uint16_t idx) { - DCHECK_LT(idx, size()); + TORCH_DCHECK_LT(idx, size()); return array_[idx + 1]; } diff --git a/torch/csrc/jit/runtime/static/README.md b/torch/csrc/jit/runtime/static/README.md index 35d6f02389a05..c97b64dd07710 100644 --- a/torch/csrc/jit/runtime/static/README.md +++ b/torch/csrc/jit/runtime/static/README.md @@ -2,10 +2,19 @@ # Static Runtime -The premise of this approach is that a small subset of neural networks are well represented by a -completely flattened dataflow graph. -TorchScript supports a far more feature programming paradigm, -so many models will not work out of the box. +Static Runtime is an optimized CPU inference runtime for PyTorch models. +It can be used as a drop-in replacement for the TorchScript JIT interpreter +in either C++ or Python. + +Static Runtime is mainly useful if the following conditions are met: +1. The model has very little control flow. +2. PyTorch overhead (tensor creation, etc) accounts for +a non-trivial fraction of the model's runtime. In particular, if +tensor allocation consumes a significant amount of time, Static +Runtime can help. Memory for intermediate tensors is coalesced into +a single slab, so most dynamic allocations are avoided during +inference. +3. Inference performance is extremely important. ## Assumptions @@ -238,3 +247,76 @@ a reference to `ivalue_array[some_set_of_indices[i]]`) Each `ProcessedNode` stores a `ProcessedFunction`, which represents the actual op to execute. `ProcessedFunction`s are initialized upon `StaticModule` construction according to the out variant/native/JIT fallback lookup rules described in "Registering Ops". **Note that all `ProcessedFunction`s are shared amongst all runtime instances**, so all `ProcessedFunction`s must be thread-safe. + +### `ProcessedNodeMetadata` + +`ProcessedNodeMetadata` holds various "extra" fields on behalf of `ProcessedNode`. Typically, this field is unused. But a few ops need extra machinery to work: +* `prim::If` operations have two `BlockRunner`s for the execution of true and false sub-blocks depending upon the condition check. +* `prim::Loop` operations have a `BlockRunner` for the execution of the looping sub-block. +* `prim::fork` operations have `torch::jit::TaskLauncher` (`std::function)>`) responsible for forked graph execution. + +### `Asynchronous Execution` + +`StaticRuntime::runAsync()` API allows execution of asynchronous operations on `TaskLauncher` passed as arguments. +`StaticRuntime::runAsync()` performs inline execution of parent graph on caller thread and asynchronous operations like `prim::fork` are executed +on the launcher passed in. In the case that no launcher is provided, the execution happens on `at::launch` inter-op thread pool. + +### `prim::fork and aten::wait` + +`prim::fork` takes the callable function/method/Module (say `fn`) and arguments to that callable `args` and `kwargs`. Since the execution of forked function `fn` happens asynchronously and fork returns immediately after creating the async task, the `fn` may not have been executed by the time the line of code after the `fork` call is reached. Thus, `aten::wait` is used to wait for the async `fn` task to be completed. `prim::fork` nodes contain the sub-graph for the forked parts of the network. Each parent graph creates a separate instance of `StaticModule` for the +forked sub-graph and `StaticRuntime` instances are created on the fly during runtime as the fork nodes are executed. The forked subgraph execution +happens asynchronously on the launcher provided during `StaticRuntime::runAsync()` or by `at::launch` executor by default. `aten::wait` operator +waits on the future returned by the corresponding `prim::fork` operation + +#### Inter-op parallelism via fork/wait ops + +Sample Model with independent operations can be parallelized by inserting fork/wait nodes in the graph. + +```python +def CNNBlock(x): + out_1 = conv1(x) + out_1 = conv2(out_1) + out_1 = max_pool1(out_1) + + out_2 = conv3(x) + out_2 = max_pool2(out_2) + + out_merged = conv4(out_1 + out_2) + return out_merged +``` +The two branches of (conv,conv,pool) operations can be parallelized by inserting fork nodes such that the execution of both the branches can +happen in parallel: + +```python +def branch1(x): + out = conv1(x) + out = conv2(x) + return max_pool1(out) + +def branch2(x): + out = conv3(x) + return max_pool2(out) + +def CNNBlock(x): + fut_1 = torch.jit.fork(branch1, x) + fut_2 = torch.jit.fork(branch2, x) + + out_merged = conv4(torch.jit.wait(fut_1) + torch.jit.wait(fut_2)) + return out_merged + ``` +**Execution without fork/wait operations:** +``` +: conv1 ─> conv2 ─> max_pool1 ─> conv3 ─> max_pool2 ─> conv4 +``` + +**Execution with fork/wait operations:** +``` + : fork1 ──> fork2 ──────────> wait(fut_1) ─> wait(fut_2) ─> conv4 + | | + | | +: | conv3 ──────────────────> max_pool2 -> fut_2 + | +: conv1 ─> conv2 ─> max_pool1 ──>fut_1 +``` +More examples for fork/wait operations and inter-op parallelism in PyTorch can be found at +[Dynamic Parallelism in TorchScript](https://pytorch.org/tutorials/advanced/torch-script-parallelism.html) diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 9d9ab145b82b0..b8fd7800e5334 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -167,6 +167,10 @@ void OptimizeGraph( graph, fromQualString("fb::sigrid_transforms_torch_bind"), fromQualString("fb::variadic_sigrid_transforms_torch_bind")); + UseVariadicOp( + graph, + fromQualString("torcharrow::inference_wrapper_run_flat"), + fromQualString("torcharrow::variadic_inference_wrapper_run_flat")); // These fused ops only have out variants - we can't do the fusion when // out variants are disabled. FuseSignLog1P(graph); @@ -194,6 +198,7 @@ void OptimizeGraph( graph, /* custom_ops */ {fromQualString("fb::scale_gradient")}); AddIfThenElseOp(graph); UseSplitAndSqueeze(graph); + UseInPlaceGetRealInputsFromOptionalInputsV2(graph); GRAPH_DUMP("Final graph after optimizations: ", graph); } @@ -842,6 +847,7 @@ BlockRunner::BlockRunner( const StaticModule& sm, IValue* values, Block* block, + torch::jit::TaskLauncher* launcher, bool is_root_block) : static_module_(sm), block_info_(static_module_.block_info(block)), @@ -864,19 +870,22 @@ BlockRunner::BlockRunner( for (auto& pnode : nodes_) { auto* node = pnode.node(); + + // attach the async taskLauncher to processedNodes + pnode.set_metadata(launcher); auto blocks = node->blocks(); const auto num_blocks = blocks.size(); if (num_blocks == 0) { continue; } DCHECK(node->kind() == prim::If || node->kind() == prim::Loop); - auto block_runners = std::make_unique>(); - block_runners->reserve(num_blocks); + std::vector block_runners; + block_runners.reserve(num_blocks); for (auto* b : blocks) { - block_runners->emplace_back(sm, values_, b); + block_runners.emplace_back(sm, values_, b, launcher); } - pnode.set_block_runners(std::move(block_runners)); + pnode.set_metadata(std::move(block_runners)); } } @@ -1244,6 +1253,52 @@ c10::IValue BlockRunner::run_impl_record_functions( return run_impl(std::forward(args), kwargs); } +template +c10::intrusive_ptr BlockRunner::run_impl_async( + IValueList&& args, + const KeywordArgs& kwargs) { + // run the graph inline in the caller thread. Async ops will be + // executed on taskLauncher attached to the metadata of ProcessedNodes + c10::IValue output = run_impl(args, kwargs); + + // If the output is of type future, return it + if (output.isFuture()) { + return output.toFuture(); + } + + // wrap the output into future, mark completed and return it + TypePtr return_type; + if (block_info_.num_outputs() > 1) { + return_type = TupleType::create( + fmap(outputs(), [](const IValue* v) { return v->type(); })); + } else { + return_type = outputs().at(0)->type(); + } + c10::intrusive_ptr future = c10::make_intrusive(return_type); + future->markCompleted(output); + return future; +} + +template +c10::intrusive_ptr BlockRunner:: + run_impl_record_functions_async( + IValueList&& args, + const KeywordArgs& kwargs) { + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); + guard.needsInputs() + ? guard.before( + "forward", c10::ArrayRef(args.data(), args.size())) + : guard.before("forward"); + + return run_impl_async(std::forward(args), kwargs); + } + return run_impl_async(std::forward(args), kwargs); +} + c10::IValue BlockRunner::operator()( const std::vector& args, const KeywordArgs& kwargs) { @@ -1264,6 +1319,26 @@ c10::IValue BlockRunner::operator()( #endif } +c10::intrusive_ptr BlockRunner::runAsync( + const std::vector& args, + const KeywordArgs& kwargs) { +#ifdef PYTORCH_DISABLE_NET_PROFILING + return run_impl_async(args, kwargs); +#else + return run_impl_record_functions_async(args, kwargs); +#endif +} + +c10::intrusive_ptr BlockRunner::runAsync( + std::vector&& args, + const KeywordArgs& kwargs) { +#ifdef PYTORCH_DISABLE_NET_PROFILING + return run_impl_async(std::move(args), kwargs); +#else + return run_impl_record_functions_async(std::move(args), kwargs); +#endif +} + namespace { std::string generate_latency_json(const std::string& label, double millis) { @@ -1720,10 +1795,10 @@ bool BlockRunner::check_for_memory_leak( } } } - - auto* block_runners = pnode.block_runners(); - if (recurse_on_sub_blocks && block_runners) { - for (auto& block_runner : *block_runners) { + auto* metadata = pnode.metadata(); + if (recurse_on_sub_blocks && metadata) { + auto& block_runners = metadata->block_runners(); + for (auto& block_runner : block_runners) { block_runner.check_for_memory_leak( output_returned, recurse_on_sub_blocks); } @@ -1849,7 +1924,7 @@ ProcessedFunction::ProcessedFunction( stack.emplace_back(static_cast(size)); } node_op(stack); - DCHECK_EQ(stack.size(), pnode->num_outputs()); + TORCH_DCHECK_EQ(stack.size(), pnode->num_outputs()); for (const auto i : c10::irange(pnode->num_outputs())) { pnode->Output(i) = std::move(stack[i]); } @@ -2053,9 +2128,14 @@ void ProcessedNode::verify_and_correct_memory_overlap() { StaticRuntime::StaticRuntime(const StaticModule& sm) : values_(sm.value_buffer_size()) { std::copy(sm.constants().begin(), sm.constants().end(), values_.data()); + // default task launcher set to inter-op thread pool + async_task_launcher_ = at::launch; block_ = std::make_unique( - sm, values_.data(), sm.root_block(), /*is_root_block*/ true); - ; + sm, + values_.data(), + sm.root_block(), + &async_task_launcher_, + true /*is_root_block*/); } c10::IValue StaticRuntime::operator()( @@ -2070,6 +2150,22 @@ c10::IValue StaticRuntime::operator()( return (*block_)(std::move(args), kwargs); } +c10::intrusive_ptr StaticRuntime::runAsync( + const std::vector& args, + const KeywordArgs& kwargs, + torch::jit::TaskLauncher taskLauncher) { + async_task_launcher_ = std::move(taskLauncher); + return block_->runAsync(args, kwargs); +} + +c10::intrusive_ptr StaticRuntime::runAsync( + std::vector&& args, + const KeywordArgs& kwargs, + torch::jit::TaskLauncher taskLauncher) { + async_task_launcher_ = std::move(taskLauncher); + return block_->runAsync(std::move(args), kwargs); +} + bool StaticRuntime::check_for_memory_leak(bool output_returned) { return block_->check_for_memory_leak( output_returned, /* recurse_on_sub_blocks */ true); diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index aa979ddecb6b2..67c16ba09b7ca 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -551,6 +551,7 @@ class TORCH_API BlockRunner { const StaticModule& sm, IValue* values, Block* block, + torch::jit::TaskLauncher* launcher, bool is_root_block = false); BlockRunner(BlockRunner&&) noexcept; BlockRunner& operator=(BlockRunner&&) = delete; @@ -566,6 +567,14 @@ class TORCH_API BlockRunner { std::vector&& args, const KeywordArgs& kwargs = KeywordArgs()); + c10::intrusive_ptr runAsync( + const std::vector& args, + const KeywordArgs& kwargs); + + c10::intrusive_ptr runAsync( + std::vector&& args, + const KeywordArgs& kwargs); + void benchmark( const std::vector>& args_list, const std::vector& kwargs_list, @@ -599,7 +608,7 @@ class TORCH_API BlockRunner { // Input is readwrite IValue& Input(uint32_t i) { - DCHECK_LT(i, block_info_.num_inputs()); + TORCH_DCHECK_LT(i, block_info_.num_inputs()); return values_[i + block_info_.block_inputs_idx()]; } @@ -694,6 +703,16 @@ class TORCH_API BlockRunner { IValueList&& args, const KeywordArgs& kwargs); + template + c10::intrusive_ptr run_impl_async( + IValueList&& args, + const KeywordArgs& kwargs); + + template + c10::intrusive_ptr run_impl_record_functions_async( + IValueList&& args, + const KeywordArgs& kwargs); + // helper method for copying input args/kwargs into inputs_ template void set_inputs( @@ -833,6 +852,50 @@ class TORCH_API StaticNodeInfo { uint16_t outputs_offset_; }; +/* + ProcessedNodeMetadata class wraps the possible metadata + for ProcessedNode. Depending upon the nature of op, processedNode + can have one of the below possibilities of metadata: + - prim::If/prim::Loop ops contains block_runners_ as their metadata + - prim::fork op contains TaskLauncher (std::function) responsible for + execution of forked subgraph +*/ +class TORCH_API ProcessedNodeMetadata { + public: + ProcessedNodeMetadata( + std::vector runners, + torch::jit::TaskLauncher* launcher) + : block_runners_(std::move(runners)), launcher_(std::move(launcher)) {} + + ProcessedNodeMetadata() : launcher_(nullptr) {} + + // deleted copy ctor/assigment as standard containers (vector) always + // have copy constructors, but their instantiation is not well-formed + // if the contained type (BlockRunner) is not copyable + ProcessedNodeMetadata(const ProcessedNodeMetadata&) = delete; + ProcessedNodeMetadata& operator=(const ProcessedNodeMetadata&) = delete; + + std::vector& block_runners() { + return block_runners_; + } + + void set_block_runners(std::vector runners) { + block_runners_ = std::move(runners); + } + + void set_launcher(torch::jit::TaskLauncher* launcher) { + launcher_ = launcher; + } + + torch::jit::TaskLauncher* launcher() { + return launcher_; + } + + private: + std::vector block_runners_; + torch::jit::TaskLauncher* launcher_; +}; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API ProcessedNode { public: @@ -845,9 +908,7 @@ class TORCH_API ProcessedNode { inputs_(other.inputs_), outputs_offset_(other.outputs_offset_), values_(values), - // TODO(T105178680): For this task, we should move - // block runners out of ProcessedNode. - block_runners_(nullptr) {} + metadata_(nullptr) {} // These should be noexcept, but some Android build is failing // saying the noexcept specification doesn't match the calculated @@ -939,13 +1000,25 @@ class TORCH_API ProcessedNode { // used in debug mode bool verify_no_memory_overlap(bool force_check = false) const; - std::vector* block_runners() { - return block_runners_.get(); + // returns pointer to ProcessedNodeMetadata or nullptr if no object is owned + ProcessedNodeMetadata* metadata() { + return metadata_.get(); } - void set_block_runners( - std::unique_ptr> block_runners) { - block_runners_ = std::move(block_runners); + // attach block_runner to metadata of ProcessedNode + void set_metadata(std::vector block_runners) { + if (metadata_ == nullptr) { + metadata_ = std::make_unique(); + } + metadata_->set_block_runners(std::move(block_runners)); + } + + // attach TaskLauncher to metadata of ProcessedNode + void set_metadata(torch::jit::TaskLauncher* launcher) { + if (metadata_ == nullptr) { + metadata_ = std::make_unique(); + } + metadata_->set_launcher(launcher); } private: @@ -959,9 +1032,10 @@ class TORCH_API ProcessedNode { uint16_t outputs_offset_; bool overlap_detected_{false}; IValue* values_ = nullptr; // unowned - // For control flow; processed nodes may have sub-blocks which can - // be executed by op implementations. - std::unique_ptr> block_runners_; + // Metadata for ProcessedNode. + // 1. prim::If/Loop nodes contains sub-blocks as metadata + // 2. prim::fork nodes contains custom executor for async execution + std::unique_ptr metadata_; }; // `StaticRuntime` is the owner of the array of IValues (used for constants, @@ -983,6 +1057,20 @@ class TORCH_API StaticRuntime { std::vector&& args, const KeywordArgs& kwargs = KeywordArgs()); + // runAsync performs inline execution of graph on + // caller thread and async execution on taskLauncher + // If no custom taskLauncher is specified, execution is done + // on inter-op thread pool. + c10::intrusive_ptr runAsync( + const std::vector& args, + const KeywordArgs& kwargs = KeywordArgs(), + torch::jit::TaskLauncher taskLauncher = at::launch); + + c10::intrusive_ptr runAsync( + std::vector&& args, + const KeywordArgs& kwargs = KeywordArgs(), + torch::jit::TaskLauncher taskLauncher = at::launch); + bool check_for_memory_leak(bool output_returned = true); bool checkOutputTensorMemoryLeaks(); @@ -1052,6 +1140,8 @@ class TORCH_API StaticRuntime { }; std::unique_ptr block_; + // for execution of async operations present in graph + torch::jit::TaskLauncher async_task_launcher_; IValueArray values_; }; diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp index 778d6e2ea26f4..36b8ca8eaa2e8 100644 --- a/torch/csrc/jit/runtime/static/init.cpp +++ b/torch/csrc/jit/runtime/static/init.cpp @@ -89,6 +89,28 @@ void initStaticModuleBindings(PyObject* module) { kwargs.begin(), kwargs.end()}; return self.runtime().benchmark_individual_ops( {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs); + }) + .def( + "runAsync", + [](StaticModule& self, + const py::tuple& args, + const py::dict& kwargs) { + std::vector arg_ivalues; + for (const auto& elem : args) { + arg_ivalues.push_back( + torch::jit::toIValue(elem, c10::AnyType::get())); + } + std::unordered_map kwarg_ivalues; + for (const auto& kv : kwargs) { + kwarg_ivalues[py::cast(kv.first)] = + torch::jit::toIValue(kv.second, c10::AnyType::get()); + } + // custom executor for async op execution + auto task_launcher = [](const std::function& f) { + at::launch(f); + }; + return toPyObject(self.runtime().runAsync( + arg_ivalues, kwarg_ivalues, task_launcher)); }); m.def( "_jit_to_static_module", diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp index cf2573a67101f..3b3e69d970220 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.cpp +++ b/torch/csrc/jit/runtime/static/memory_planner.cpp @@ -89,7 +89,7 @@ std::vector assignStorageToManagedTensors( auto assignToAvailableStorageGroup = [&](const Value* value) { DCHECK(!free_storage_groups.empty()); const auto storage_group = free_storage_groups.back(); - DCHECK_LT(storage_group, managed_tensor_groups.size()); + TORCH_DCHECK_LT(storage_group, managed_tensor_groups.size()); storage_group_mapping.emplace(value, storage_group); auto* tensor_ptr = tensor_value_to_tensor.at(value); managed_tensor_groups[storage_group].addTensor(tensor_ptr); @@ -260,7 +260,7 @@ void MemoryPlanner::allocateOutputTensors() { if (tensor_size == 0) { continue; } - DCHECK_LE(offset + tensor_size, output_buffer_bytes_); + TORCH_DCHECK_LE(offset + tensor_size, output_buffer_bytes_); void* src = static_cast(start + offset); // NOTE: Populating `ctx` enables clients to take the ownership of a // tensor managed by Static Runtime. Some clients use "move" semantics to @@ -282,7 +282,7 @@ void MemoryPlanner::allocateOutputTensors() { tensor->storage().set_nbytes(tensor_size); offset += tensor_size; } - DCHECK_EQ(offset, output_buffer_bytes_); + TORCH_DCHECK_EQ(offset, output_buffer_bytes_); } void MemoryPlanner::allocate() { @@ -376,16 +376,16 @@ void StandardMemoryPlanner::allocateManagedTensors() { continue; } at::StorageImpl* storageImpl = &ms.second; - DCHECK_LE(offset + tensor_size, managed_bytes_); + TORCH_DCHECK_LE(offset + tensor_size, managed_bytes_); void* src = static_cast(start + offset); #ifndef NDEBUG - DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize()); + TORCH_DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize()); for (auto* tensor : managed_tensors_[group_idx].group()) { - DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl()); + TORCH_DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl()); } #endif - DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0); + TORCH_DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0); reused_tensors_ += managed_tensors_[group_idx].numManagedTensors() - 1; storageImpl->set_data_ptr_noswap( at::DataPtr(src, src, nullptr, c10::Device(c10::DeviceType::CPU))); @@ -394,7 +394,7 @@ void StandardMemoryPlanner::allocateManagedTensors() { offset += tensor_size; group_idx++; } - DCHECK_EQ(offset, managed_bytes_); + TORCH_DCHECK_EQ(offset, managed_bytes_); } void StandardMemoryPlanner::deallocateManagedTensors() { @@ -457,7 +457,7 @@ void StandardMemoryPlanner::deallocateManagedTensors() { &managed_tensor_storage_impls_[group_idx].second, tensors.size()))); } - DCHECK_EQ( + TORCH_DCHECK_EQ( tensor->storage().unsafeGetStorageImpl(), &managed_tensor_storage_impls_[group_idx].second); max = std::max(max, current_size); @@ -472,7 +472,8 @@ void StandardMemoryPlanner::deallocateManagedTensors() { managed_bytes_ += max; } - DCHECK_EQ(managed_tensor_storage_impls_.size(), managed_tensors_.size()); + TORCH_DCHECK_EQ( + managed_tensor_storage_impls_.size(), managed_tensors_.size()); VLOG(1) << "managed_bytes: " << managed_bytes_; } diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 5442d481d7818..16e357d8f459d 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -92,7 +92,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( [](Node* n) -> SROperator { auto dict_type = n->output()->type()->expect(); const auto num_inputs = n->inputs().size(); - DCHECK_EQ(num_inputs % 2, 0); + TORCH_DCHECK_EQ(num_inputs % 2, 0); return [dict_type = std::move(dict_type), num_inputs, dict_size = num_inputs / 2](ProcessedNode* p_node) { @@ -699,7 +699,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( // MemoryPlanner::deallocate. MemoryPlanner knows about this // and will safely clean it up by using the corresponding // destroyBorrow method. - DCHECK_NE(&assignFrom, &p_node->Output(0)); + TORCH_DCHECK_NE(&assignFrom, &p_node->Output(0)); // MemoryPlanner should have cleaned this up! DCHECK(p_node->Output(0).isNone()); p_node->Output(0) = @@ -853,7 +853,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( prim::If, prim_If, [](Node* node) -> SROperator { - DCHECK_EQ(node->blocks().size(), 2); + TORCH_DCHECK_EQ(node->blocks().size(), 2); const Block* true_block = node->blocks().at(0); const Block* false_block = node->blocks().at(1); @@ -883,10 +883,11 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( case BlockRunPlan::kRunBothBlocks: return [](ProcessedNode* p_node) { auto condition = p_node->Input(0).toBool(); - auto* block_runners = p_node->block_runners(); - DCHECK(block_runners); - DCHECK_EQ(block_runners->size(), 2); - auto& runner = (*block_runners)[!condition]; + auto* metadata = p_node->metadata(); + DCHECK(metadata); + auto& block_runners = metadata->block_runners(); + TORCH_DCHECK_EQ(block_runners.size(), 2); + auto& runner = block_runners[!condition]; auto output = runner({}); if (!output.isTuple()) { @@ -894,7 +895,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( return; } auto& elems = output.toTupleRef().elements(); - DCHECK_EQ(elems.size(), p_node->num_outputs()); + TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs()); for (const auto i : c10::irange(elems.size())) { p_node->Output(i) = elems[i]; } @@ -902,22 +903,24 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( case BlockRunPlan::kRunOnlyTrueBlock: return [](ProcessedNode* p_node) { auto condition = p_node->Input(0).toBool(); - auto* block_runners = p_node->block_runners(); - DCHECK(block_runners); - DCHECK_EQ(block_runners->size(), 2); + auto* metadata = p_node->metadata(); + DCHECK(metadata); + auto& block_runners = metadata->block_runners(); + TORCH_DCHECK_EQ(block_runners.size(), 2); if (condition) { - auto output = block_runners->front()({}); + auto output = block_runners.front()({}); DCHECK(output.isNone()); } }; case BlockRunPlan::kRunOnlyFalseBlock: return [](ProcessedNode* p_node) { auto condition = p_node->Input(0).toBool(); - auto* block_runners = p_node->block_runners(); - DCHECK(block_runners); - DCHECK_EQ(block_runners->size(), 2); + auto* metadata = p_node->metadata(); + DCHECK(metadata); + auto& block_runners = metadata->block_runners(); + TORCH_DCHECK_EQ(block_runners.size(), 2); if (!condition) { - auto output = block_runners->back()({}); + auto output = block_runners.back()({}); DCHECK(output.isNone()); } }; @@ -931,7 +934,7 @@ namespace { std::vector collectLoopSubBlockInputs(const ProcessedNode& p_node) { const auto num_inputs = p_node.num_inputs(); - DCHECK_GE(num_inputs, 2); + TORCH_DCHECK_GE(num_inputs, 2); // The first two inputs to the loop node are the max trip count // and initial condition. We don't collect them here, since those // are not inputs for the sub-block. @@ -964,16 +967,19 @@ class TORCH_API ForkedSubgraphSRLauncher { ForkedSubgraphSRLauncher( std::shared_ptr smodule, std::vector args, - c10::intrusive_ptr future) + c10::intrusive_ptr future, + TaskLauncher launcher) : smodule_(std::move(smodule)), args_(std::move(args)), - future_(std::move(future)) {} + future_(std::move(future)), + launcher_(std::move(launcher)) {} void operator()() { try { StaticRuntime runtime(*smodule_); - auto output = runtime(args_, {}); - future_->markCompleted(output); + auto future_subgraph = runtime.runAsync(args_, {}, launcher_); + future_subgraph->waitAndThrow(); + future_->markCompleted(future_subgraph->value()); } catch (const std::exception& e) { future_->setErrorIfNeeded( std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what()))); @@ -984,6 +990,7 @@ class TORCH_API ForkedSubgraphSRLauncher { std::shared_ptr smodule_; std::vector args_; c10::intrusive_ptr future_; + torch::jit::TaskLauncher launcher_; }; /* @@ -1037,9 +1044,13 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( createFutureTypeFromGraphOutput(forkedGraph); p_node->Output(0) = future; - TaskLauncher taskLauncher_ = at::launch; - ForkedSubgraphSRLauncher runtime_launcher(smodule, args, future); - taskLauncher_(std::move(runtime_launcher)); + auto* metadata = p_node->metadata(); + DCHECK(metadata); + auto* launcher = metadata->launcher(); + DCHECK(launcher); + ForkedSubgraphSRLauncher runtime_launcher( + smodule, args, future, *launcher); + (*launcher)(std::move(runtime_launcher)); }; }); /* @@ -1067,7 +1078,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( return; } auto& elems = future->value().toTupleRef().elements(); - DCHECK_EQ(elems.size(), p_node->num_outputs()); + TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs()); for (const auto i : c10::irange(elems.size())) { p_node->Output(i) = elems[i]; } @@ -1082,10 +1093,11 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( const auto max_trip_count = p_node->Input(0).toInt(); auto condition = p_node->Input(1).toBool(); - auto* block_runners = p_node->block_runners(); - DCHECK(block_runners); - DCHECK_EQ(block_runners->size(), 1); - auto& runner = (*block_runners)[0]; + auto* metadata = p_node->metadata(); + DCHECK(metadata); + auto& block_runners = metadata->block_runners(); + TORCH_DCHECK_EQ(block_runners.size(), 1); + auto& runner = block_runners[0]; auto args = collectLoopSubBlockInputs(*p_node); int64_t loop_count = 0; @@ -1107,7 +1119,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( } const auto num_outputs = p_node->num_outputs(); - DCHECK_EQ(args.size(), num_outputs + 1); + TORCH_DCHECK_EQ(args.size(), num_outputs + 1); for (const auto i : c10::irange(num_outputs)) { p_node->Output(i) = std::move(args[i + 1]); } @@ -1172,7 +1184,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( const auto num_inputs = pnode->num_inputs(); auto stack = boxInputs(*pnode); format(stack, num_inputs); - DCHECK_EQ(stack.size(), 1); + TORCH_DCHECK_EQ(stack.size(), 1); pnode->Output(0) = std::move(stack[0]); }; }); @@ -1260,7 +1272,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( const auto elem_type = pnode->Input(2).toInt(); std::vector stack{input, dim, elem_type}; toList(stack); - DCHECK_EQ(stack.size(), 1); + TORCH_DCHECK_EQ(stack.size(), 1); pnode->Output(0) = std::move(stack[0]); }; }); diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 1288f6a35b825..93a9043ed276d 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1338,19 +1338,19 @@ ToArgs extract_to_args(ProcessedNode* p_node) { const auto& other = p_node->Input(1).toTensor(); result.dtype = other.scalar_type(); result.layout = other.layout(); - DCHECK_EQ(other.device().type(), c10::DeviceType::CPU); + TORCH_DCHECK_EQ(other.device().type(), c10::DeviceType::CPU); } else { const auto& self = p_node->Input(0).toTensor(); result.dtype = p_node->Input(1).toOptional(); result.layout = self.layout(); // Static runtime only works with CPU tensors; don't need to read this. - DCHECK_EQ(self.device().type(), c10::DeviceType::CPU); + TORCH_DCHECK_EQ(self.device().type(), c10::DeviceType::CPU); result.know_to_will_alias = has_constant_non_tensor_dtype_and_flags && (!result.dtype.has_value() || result.dtype.value() == self.dtype().toScalarType()); } if (has_memory_format) { - DCHECK_EQ(p_node->num_inputs(), 5); + TORCH_DCHECK_EQ(p_node->num_inputs(), 5); result.memory_format = p_node->Input(4).toOptional(); result.know_to_will_alias = result.know_to_will_alias && (result.memory_format.value_or(c10::MemoryFormat::Preserve) == @@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator { }; } if (n->matches(torch::schema( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { return [](ProcessedNode* p_node) { const at::Tensor& self = p_node->Input(0).toTensor(); - auto dim = p_node->Input(1).toIntList().vec(); + auto dim = p_node->Input(1).toDimVector(); auto keepdim = p_node->Input(2).toBool(); auto dtype = p_node->Input(3).toOptional(); if (p_node->Output(0).isNone()) { @@ -1712,7 +1712,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::mean, aten_mean, [](Node* n) -> SROperator { if (n->matches(torch::schema( - "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { + "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { return [](ProcessedNode* p_node) { const auto& self = p_node->Input(0).toTensor(); const auto dim = p_node->Input(1).toDimVector(); @@ -1850,7 +1850,8 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { LogAndDumpSchema(n); return nullptr; } - return [](ProcessedNode* p_node) { + + return [te = createDiv()](ProcessedNode* p_node) { const auto& in0_t = p_node->Input(0).toTensor(); c10::optional rounding_mode = c10::nullopt; if (p_node->num_inputs() > 2) { @@ -1861,12 +1862,37 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar()); if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::cpu::div(in0_t, in1_t, rounding_mode); - return; + p_node->Output(0) = create_empty_from(in0_t); } auto& out_t = p_node->Output(0).toTensor(); - fastResizeToZero(out_t); - at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode); + + if (in0_t.sizes() == in1_t.sizes() && + in0_t.scalar_type() == in1_t.scalar_type() && + in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() && + in0_t.scalar_type() == at::kFloat) { + int64_t dim = in0_t.numel(); + int i_rounding_mode = 0; + if (rounding_mode && !rounding_mode.value().empty()) { + const char peek_rounding_mode = rounding_mode.value().at(0); + if (peek_rounding_mode == 't') { + // trunc after div + i_rounding_mode = 1; + } else if (peek_rounding_mode == 'f') { + // floor after div + i_rounding_mode = 2; + } + } + at::native::resize_(out_t, in0_t.sizes()); + te->call( + {out_t.data_ptr(), + in0_t.data_ptr(), + in1_t.data_ptr(), + &i_rounding_mode, + &dim}); + } else { + fastResizeToZero(out_t); + at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode); + } }; }); diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 9a4c9e811233f..7234cc3652eea 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -639,7 +639,7 @@ void ReplaceWithMaybeCopy( static const auto select_tensor_symbol = fromQualString("static_runtime::select_tensor"); auto* select_tensor_node = graph->create(select_tensor_symbol, 1); - DCHECK_EQ(new_node->outputs().size(), 2); + TORCH_DCHECK_EQ(new_node->outputs().size(), 2); select_tensor_node->addInput(n->input(0)); for (auto* output : new_node->outputs()) { select_tensor_node->addInput(output); @@ -869,6 +869,12 @@ bool shouldNotFuseListUnpackSpecialCase(const Node* node) { void FuseListUnpack(std::shared_ptr& graph) { const FastMap unfused_to_fused = { + OP_PAIR( + "torcharrow::inference_wrapper_run_flat", + "static_runtime::fused_inference_wrapper_run_flat"), + OP_PAIR( + "torcharrow::variadic_inference_wrapper_run_flat", + "static_runtime::fused_variadic_inference_wrapper_run_flat"), OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"), OP_PAIR( "fb::sigrid_transforms", "static_runtime::fused_sigrid_transforms"), @@ -1349,6 +1355,34 @@ void EliminateNoOpSlice(std::shared_ptr& graph) { } } +void UseInPlaceGetRealInputsFromOptionalInputsV2( + std::shared_ptr& graph) { +#ifdef FBCODE_CAFFE2 + const std::string original_pattern = R"IR( + graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]): + %x : (Tensor, Tensor?, Tensor?)[] = remote_collection::get_real_inputs_from_optional_inputs_v2(%optional_input, %include_last_offsets) + return (%x))IR"; + + const std::string new_pattern = R"IR( + graph(%optional_input: (Tensor, Tensor?, Tensor?)?[], %include_last_offsets: bool[]): + %x : (Tensor, Tensor?, Tensor?)[] = static_runtime::get_real_inputs_from_optional_inputs_v2_inplace(%optional_input, %include_last_offsets) + return (%x))IR"; + + auto isSingleUse = [](Value* value) { return value->uses().size() == 1; }; + + auto filter = [&isSingleUse]( + const Match& match, + const std::unordered_map& vmap) { + auto* real_node = match.nodes_map.at(vmap.at("x")->node()); + return isSingleUse(real_node->input(0)); + }; + + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(original_pattern, new_pattern); + fuse.runOnGraph(graph, filter); +#endif +} + void FuseClampNaNToNum(std::shared_ptr& graph) { #ifdef FBCODE_CAFFE2 std::string pattern = R"IR( diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index b9dfb70c8f8a7..60b87550729ac 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -82,5 +82,8 @@ TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs( TORCH_API void FuseClampNaNToNum(std::shared_ptr& graph); +TORCH_API void UseInPlaceGetRealInputsFromOptionalInputsV2( + std::shared_ptr& graph); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/processed_node_wrapper.h b/torch/csrc/jit/runtime/static/processed_node_wrapper.h index 8372c8d6fd312..e50347821ea9a 100644 --- a/torch/csrc/jit/runtime/static/processed_node_wrapper.h +++ b/torch/csrc/jit/runtime/static/processed_node_wrapper.h @@ -27,7 +27,7 @@ class ProcessedNodeWrapperBase { : container_(container), idx_(start_idx) {} ProcessedNodeWrapperBaseIter& operator++() { - DCHECK_NE(idx_, container_->size()); + TORCH_DCHECK_NE(idx_, container_->size()); ++idx_; return *this; } @@ -51,7 +51,7 @@ class ProcessedNodeWrapperBase { friend bool operator==( ProcessedNodeWrapperBaseIter lhs, ProcessedNodeWrapperBaseIter rhs) { - DCHECK_EQ(lhs.container_, rhs.container_); + TORCH_DCHECK_EQ(lhs.container_, rhs.container_); return lhs.idx_ == rhs.idx_; } diff --git a/torch/csrc/jit/runtime/static/te_wrapper.cpp b/torch/csrc/jit/runtime/static/te_wrapper.cpp index d982e652344fa..35bef91ea70ad 100644 --- a/torch/csrc/jit/runtime/static/te_wrapper.cpp +++ b/torch/csrc/jit/runtime/static/te_wrapper.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -113,6 +114,46 @@ void updateNNCCache(NodeKind kind, std::shared_ptr code) { } // namespace +std::shared_ptr createDiv() { + auto wrap = lookupNNCCache(aten::div); + if (wrap) { + return wrap; + } + wrap = std::make_shared(); + + auto dim = VarHandle("dim", kInt); + auto mode = VarHandle("mode", kInt); + BufHandle A("A", {dim}, kFloat); + BufHandle B("B", {dim}, kFloat); + + using axis = const VarHandle&; + Tensor C = Compute("C", {dim}, [&](axis x) { + auto true_div_result = A.load(x) / B.load(x); + + auto mode_default = IntImm::make(0); + auto mode_trunc = IntImm::make(1); + auto mode_floor = IntImm::make(2); + + // this is a glorified ternary choice operator train + return CompareSelect::make( + mode, + mode_default, + true_div_result, + CompareSelect::make( + mode, + mode_trunc, + trunc(true_div_result), + floor(true_div_result), + kEQ), + kEQ); + }); + + wrap = wrapTECompute(wrap, C, {A, B, mode, dim}); + + updateNNCCache(aten::div, wrap); + return wrap; +} + std::shared_ptr createLogit() { auto wrap = lookupNNCCache(aten::logit); if (wrap) { diff --git a/torch/csrc/jit/runtime/static/te_wrapper.h b/torch/csrc/jit/runtime/static/te_wrapper.h index a9f2a5553dd46..f81d9afb37c54 100644 --- a/torch/csrc/jit/runtime/static/te_wrapper.h +++ b/torch/csrc/jit/runtime/static/te_wrapper.h @@ -33,6 +33,7 @@ class TEWrapper { #endif }; +std::shared_ptr createDiv(); std::shared_ptr createLogit(); std::shared_ptr createRelu(); std::shared_ptr createTanh(); diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 6ed63a71e667d..8cf56d6696569 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -27,10 +27,12 @@ const std::vector functions = { def AD_sum_backward(grad, sizes: List[int], - dims: List[int], + dims: Optional[List[int]], keepdim: bool): if not keepdim and len(sizes) > 0: - if len(dims) == 1: + if dims is None: + return grad.expand(sizes) + elif len(dims) == 1: return grad.unsqueeze(dims[0]).expand(sizes) else: res = AD_unsqueeze_multiple(grad, dims, len(sizes)) @@ -57,7 +59,7 @@ const std::vector functions = { return torch.mean(self, dtype=dtype), backward def mean_1(self, - dim: List[int], + dim: Optional[List[int]], keepdim: bool, *, dtype: Optional[int]): @@ -93,14 +95,20 @@ const std::vector functions = { return grad * (self - self.mean()) * 2.0 / (self.numel() - correction) def AD_safe_size(sizes: List[int], - dims: List[int]): + dims: Optional[List[int]]): if len(sizes) == 0: return 1 size = 1 - for i in range(len(dims)): - d = dims[i] - size *= sizes[d] + + if dims is None: + for s in sizes: + size *= s + + else: + for i in range(len(dims)): + d = dims[i] + size *= sizes[d] return size diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index c59139daf51d7..f4273867fa61f 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -778,7 +778,7 @@ flatbuffers::DetachedBuffer save_mobile_module_to_bytes( jit_constants); } -void save_mobile_module_to_func( +static void save_mobile_module_to_func( const mobile::Module& module, const std::function& writer_func) { auto buffer = save_mobile_module_to_bytes(module); @@ -790,7 +790,11 @@ bool register_flatbuffer_serializer() { return true; } +// iOS builds are often build with -Wglobal-constructor to minimize +// startup time. So let them call register manually if needed. +#if !defined(__APPLE__) const bool kFlatbufferSerializerRegistered = register_flatbuffer_serializer(); +#endif } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp index 9c2b62a7402a9..cd91892eaf24d 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp @@ -74,7 +74,7 @@ flatbuffers::DetachedBuffer save_jit_module_to_bytes( return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants); } -void save_jit_module_to_write_func( +static void save_jit_module_to_write_func( const Module& module, const ExtraFilesMap& extra_files, bool save_mobile_debug_info, @@ -92,7 +92,9 @@ bool register_flatbuffer_all() { return true; } +#if !defined(__APPLE__) const bool kFlatbufferSerializerJitInitialized = register_flatbuffer_all(); +#endif } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index f8876128e3661..484f0161e3a49 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -314,6 +314,13 @@ Module import_ir_module( c10::optional device, ExtraFilesMap& extra_files) { in.seekg(0, in.beg); + // NOTE: Zipformat can be large files. So using stream version directly + // instead of reading the file all at once. + if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) { + auto reader = torch::make_unique(&in); + ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); + return deserializer.deserialize(device, extra_files); + } std::shared_ptr data; size_t size = 0; std::tie(data, size) = get_stream_content(in); @@ -350,6 +357,13 @@ Module import_ir_module( const std::string& filename, c10::optional device, ExtraFilesMap& extra_files) { + // NOTE: Zipformat can be large files. So using stream version directly + // instead of reading the file all at once. + if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) { + auto reader = torch::make_unique(filename); + ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); + return deserializer.deserialize(device, extra_files); + } std::shared_ptr data; size_t size = 0; std::tie(data, size) = get_file_content(filename.c_str()); @@ -378,10 +392,9 @@ Module import_ir_module( std::shared_ptr rai, c10::optional device, ExtraFilesMap& extra_files) { - std::shared_ptr data; - size_t size = 0; - std::tie(data, size) = get_rai_content(rai.get()); - return _load_jit_module_from_bytes(data, size, cu, device, extra_files); + auto reader = std::make_shared(std::move(rai)); + ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); + return deserializer.deserialize(device, extra_files); } Module load(std::istream& in, c10::optional device) { diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 89436368a5715..a77deb7add639 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -505,11 +505,12 @@ PickleOpCode Unpickler::readInstruction() { tensor = at::empty({0}, options).set_(storage); } - if (device.is_cuda() || device.is_xpu() || device.is_meta()) { + if (device.is_cuda() || device.is_xpu() || device.is_meta() || + device.is_hpu()) { tensor = tensor.to(device, tensor.scalar_type()); } else if (device.type() != DeviceType::CPU) { AT_ERROR( - "supported devices include CPU and CUDA, however got ", + "supported devices include CPU, CUDA and HPU, however got ", DeviceTypeName(device.type(), false)); } stack_.emplace_back(std::move(tensor)); @@ -685,7 +686,12 @@ void Unpickler::readGlobal( class_name, "'"); } else { - AT_ASSERT(type_resolver_); + TORCH_CHECK( + type_resolver_, + "Unpickler found unknown type ", + module_name, + ".", + class_name); at::StrongTypePtr type = type_resolver_(c10::QualifiedName(module_name, class_name)); if (auto enum_type = type.type_->cast()) { diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 7db71ed721d6c..9c84137d330d1 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -8,7 +8,13 @@ #include #include +// llvm::SCEVPredicate has virtual function but non-virtual destructor +// https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" #include +#pragma GCC diagnostic pop + #include #include #include diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index a2e576d167d32..a7b3a707db277 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -9,7 +9,12 @@ #include #include #include +// llvm::SCEVPredicate has virtual function but non-virtual destructor +// https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" #include +#pragma GCC diagnostic pop #include #include #include diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index 6d1681c422755..b4f59c8cfeb67 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() { RegisterNNCLoweringsFunction aten_sum( {"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)", - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"}, + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"}, computeSum); RegisterNNCLoweringsFunction aten_softmax( @@ -1803,7 +1803,7 @@ int nnc_lowerings_lazy_registration() { RegisterNNCLoweringsFunction aten_mean( {"aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)", - "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"}, + "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"}, computeMean); RegisterNNCLoweringsFunction aten_max_reduction( {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"}, diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 83128739aadc0..d66cd380f9761 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef USE_CUDA #include #endif diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index 13a67bfe481ae..ca19d1c42d7e8 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -10,9 +10,9 @@ namespace torch { namespace lazy { -// TODO(alanwaketan): Use the backend API to get the default device type. -// In the future, we should also get the default device ordinal. -BackendDevice::BackendDevice() : type_(std::make_shared()) {} +BackendDevice::BackendDevice() + : type_(getBackend()->GetDefaultDeviceType()), + ordinal_(getBackend()->GetDefaultDeviceOrdinal()) {} BackendDevice::BackendDevice( std::shared_ptr&& type, @@ -41,11 +41,11 @@ std::ostream& operator<<(std::ostream& os, const BackendDevice& device) { return os; } -// TODO(whc) refactor this: we need to support non-zero default ordinal for -// torch/XLA. BackendDevice atenDeviceToBackendDevice(const c10::Device& device) { TORCH_CHECK(device.type() == at::kLazy, device); - int64_t ordinal = device.has_index() ? device.index() : 0; + int64_t ordinal = device.has_index() + ? device.index() + : getBackend()->GetDefaultDeviceOrdinal(); return BackendDevice(getBackend()->GetDefaultDeviceType(), ordinal); } diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index fc911a859fe17..55d7ecdb5d3a5 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -60,7 +60,7 @@ class TORCH_API BackendDevice { // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied. std::shared_ptr type_; - int64_t ordinal_{0}; + int64_t ordinal_; }; TORCH_API std::ostream& operator<<( diff --git a/torch/csrc/lazy/backend/backend_interface.h b/torch/csrc/lazy/backend/backend_interface.h index 23eaad3ffb5a6..cf029be2bf582 100644 --- a/torch/csrc/lazy/backend/backend_interface.h +++ b/torch/csrc/lazy/backend/backend_interface.h @@ -86,7 +86,7 @@ class TORCH_API BackendImplInterface { std::vector instances) const = 0; virtual std::vector ExecuteComputation( - Computation& computation, + torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const BackendDevice& device) const = 0; @@ -95,11 +95,17 @@ class TORCH_API BackendImplInterface { * */ // Set or get the default device type. - // For backends used with virtual c10:: Devices, this configures what real + // For backends used with virtual c10::Devices, this configures what real // device type the backend should use, and matters if the backend supports // more than one type of real device. virtual std::shared_ptr GetDefaultDeviceType() const = 0; - virtual void SetDefaultDeviceType(std::string) = 0; + virtual void SetDefaultDeviceType(int8_t type) = 0; + + // Set or get the default device ordinal. + // For backends that supports multi-device, this configures what the + // default device the backend should use. + virtual int64_t GetDefaultDeviceOrdinal() const = 0; + virtual void SetDefaultDeviceOrdinal(int64_t) = 0; // Specify which aten device should be used for eager fallback // may change depending on current 'Default' DeviceType @@ -108,6 +114,10 @@ class TORCH_API BackendImplInterface { // Query all available backend devices virtual std::vector GetBackendDevices() const = 0; + virtual std::string CreateMetricReport() const { + return ""; + } + // Map a particular c10:: device to a concrete backend device // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are // virtual devices, meaning they may map to a gpu, tpu, etc. behind the diff --git a/torch/csrc/lazy/core/dynamic_ir.h b/torch/csrc/lazy/core/dynamic_ir.h index a12bd065730ea..15d5aa4a8662c 100644 --- a/torch/csrc/lazy/core/dynamic_ir.h +++ b/torch/csrc/lazy/core/dynamic_ir.h @@ -49,6 +49,7 @@ class TORCH_API DimensionNode { virtual int64_t getStaticValue() const { TORCH_CHECK(false, "NYI"); }; + virtual ~DimensionNode() = default; }; } // namespace lazy diff --git a/torch/csrc/lazy/core/internal_ops/ltc_ops.h b/torch/csrc/lazy/core/internal_ops/ltc_ops.h index 768faa1d05f4c..3f195d8b445cf 100644 --- a/torch/csrc/lazy/core/internal_ops/ltc_ops.h +++ b/torch/csrc/lazy/core/internal_ops/ltc_ops.h @@ -2,6 +2,8 @@ #include +#include + #include #include @@ -22,13 +24,13 @@ class TORCH_API OpKindWrapper { private: const OpKind& get() const { - std::call_once(once_, [this]() { op_kind_ = OpKind::Get(name_); }); + c10::call_once(once_, [this]() { op_kind_ = OpKind::Get(name_); }); return op_kind_; } const char* name_; mutable OpKind op_kind_; - mutable std::once_flag once_; + mutable c10::once_flag once_; }; const OpKindWrapper ltc_all_to_all("lazy_tensors::all_to_all"); diff --git a/torch/csrc/lazy/core/ir.cpp b/torch/csrc/lazy/core/ir.cpp index 62a0eeadf3b10..e522a23c093a5 100644 --- a/torch/csrc/lazy/core/ir.cpp +++ b/torch/csrc/lazy/core/ir.cpp @@ -13,6 +13,8 @@ C10_DEFINE_bool( namespace torch { namespace lazy { +static const torch::lazy::Output kNullOutput = torch::lazy::Output(); + size_t Output::Hasher::operator()(const Output& output) const { return StdHashCombine( reinterpret_cast(output.node), output.index); @@ -138,10 +140,17 @@ Shape Node::computeShape(const std::function& shape_fn) { const std::vector& Node::operands() const { return operands_as_outputs_; } + const Output& Node::operand(size_t i) const { return operands_as_outputs_.at(i); } +const Output& Node::nullable_operand(size_t i) const { + // We use kNullOutput instead of kNullValue here to avoid implicit casting, + // which would prevent this method from returning a reference. + return i < operands_as_outputs_.size() ? operand(i) : kNullOutput; +} + std::string Node::ToString() const { std::stringstream ss; ss << shapes() << " " << op(); @@ -156,7 +165,7 @@ std::string Node::ToString() const { } void Node::AddOperand(NodePtr node, size_t index) { - CHECK_LT(index, node->num_outputs()); + TORCH_CHECK_LT(index, node->num_outputs()); operands_.push_back(node); operands_as_outputs_.emplace_back(operands_.back().get(), index); } diff --git a/torch/csrc/lazy/core/ir.h b/torch/csrc/lazy/core/ir.h index 2e3b6891fdb83..0f40456e1bf56 100644 --- a/torch/csrc/lazy/core/ir.h +++ b/torch/csrc/lazy/core/ir.h @@ -134,6 +134,9 @@ class TORCH_API Node { virtual const Output& operand(size_t i) const; + // Gets operand at index i if index is valid, or kNullOutput otherwise. + virtual const Output& nullable_operand(size_t i) const; + // Returns the hash of the dag used to look up the compiled graph virtual hash_t hash() const = 0; diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 2245b97b98fee..0af188f0131a5 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -130,6 +130,8 @@ struct IrBuilder { virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0; virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0; virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0; + + virtual ~IrBuilder() = default; }; static inline NodePtr MakeDeviceData(const std::shared_ptr& data) { diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index c312997c49de5..9f83d7730a9cf 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -962,7 +962,7 @@ std::shared_ptr LazyGraphExecutor:: VLOG(3) << "Executing IR graph hash " << HashToString(hash) << " on device " << async->device << " ..."; auto results = getBackend()->ExecuteComputation( - *async->cached_computation->computation, + async->cached_computation->computation, async->parameters_data, async->device); VLOG(3) << "Executing IR graph hash " << HashToString(hash) diff --git a/torch/csrc/lazy/core/metrics.cpp b/torch/csrc/lazy/core/metrics.cpp index c81fed03d2494..cb8120c1d45c9 100644 --- a/torch/csrc/lazy/core/metrics.cpp +++ b/torch/csrc/lazy/core/metrics.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -346,6 +347,9 @@ std::string CreateMetricReport() { arena->ForEachCounter([&ss](const std::string& name, CounterData* data) { EmitCounterInfo(name, data, &ss); }); + + // Append the backend metrics report + ss << getBackend()->CreateMetricReport(); return ss.str(); } diff --git a/torch/csrc/lazy/core/ops/utils.cpp b/torch/csrc/lazy/core/ops/utils.cpp index aa8cc62b8614d..f94fc0700d605 100644 --- a/torch/csrc/lazy/core/ops/utils.cpp +++ b/torch/csrc/lazy/core/ops/utils.cpp @@ -68,7 +68,7 @@ Shape MakeSelectShape( int64_t GetStride(int64_t start, int64_t end, int64_t stride) { if (stride == 0) { - CHECK_EQ(start, end); + TORCH_CHECK_EQ(start, end); stride = 1; } return stride; diff --git a/torch/csrc/lazy/core/shape.cpp b/torch/csrc/lazy/core/shape.cpp index 2844cc818e235..59f6f12334b84 100644 --- a/torch/csrc/lazy/core/shape.cpp +++ b/torch/csrc/lazy/core/shape.cpp @@ -10,8 +10,13 @@ C10_DEFINE_bool( namespace torch { namespace lazy { -Shape::Shape(at::ScalarType scalar_type, c10::ArrayRef sizes) - : scalar_type_(scalar_type), sizes_(sizes.begin(), sizes.end()) {} +Shape::Shape( + at::ScalarType scalar_type, + c10::ArrayRef sizes, + c10::optional> is_symbolic) + : scalar_type_(scalar_type), + sizes_(sizes.begin(), sizes.end()), + is_symbolic_(std::move(is_symbolic)) {} std::string Shape::to_string() const { return c10::str(toString(scalar_type_), "[", c10::Join(",", sizes_), "]"); diff --git a/torch/csrc/lazy/core/shape.h b/torch/csrc/lazy/core/shape.h index 21377287be92f..1c6b4d5bb3d81 100644 --- a/torch/csrc/lazy/core/shape.h +++ b/torch/csrc/lazy/core/shape.h @@ -16,7 +16,10 @@ class TORCH_API Shape { public: Shape() = default; - Shape(at::ScalarType scalar_type, c10::ArrayRef sizes); + Shape( + at::ScalarType scalar_type, + c10::ArrayRef sizes, + c10::optional> is_symbolic = c10::nullopt); std::string to_string() const; @@ -56,12 +59,12 @@ class TORCH_API Shape { private: c10::ScalarType scalar_type_{c10::ScalarType::Undefined}; + // Sizes are the upper bound sizes for a tensor, used by XLA. + std::vector sizes_; // Stores which dimmensions are symbolic // If nullopt, either it hasn't been initialized or the symbolic // dimmensions are not calculatable c10::optional> is_symbolic_ = c10::nullopt; - // Sizes are the upper bound sizes for a tensor, used by XLA. - std::vector sizes_; }; TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape); diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index 0612e2220afc5..64bef53c59ef9 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -50,10 +50,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -62,6 +64,8 @@ #include #include #include +#include +#include #include #include @@ -418,6 +422,58 @@ std::vector compute_shape_embedding_dense_backward( Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})}; } +std::vector compute_shape_expand( + const at::Tensor& self, + at::IntArrayRef size, + bool implicit) { + TORCH_CHECK_GE(size.size(), self.dim()); + int64_t num_new_dimensions = size.size() - self.dim(); + std::vector padded_self(num_new_dimensions, 0); + padded_self.insert( + padded_self.end(), self.sizes().begin(), self.sizes().end()); + std::vector target_size(size.size()); + for (const auto idx : c10::irange(size.size())) { + target_size[idx] = size[idx] == -1 ? padded_self[idx] : size[idx]; + } + return {Shape(self.scalar_type(), target_size)}; +} + +std::vector compute_shape_expand( + const at::Tensor& self, + c10::SymIntArrayRef size, + bool implicit) { + TORCH_CHECK_GE(size.size(), self.dim()); + std::vector _sizes = ToVector(size); + int64_t num_new_dimensions = _sizes.size() - self.dim(); + std::vector padded_self(num_new_dimensions, 0); + padded_self.insert( + padded_self.end(), self.sizes().begin(), self.sizes().end()); + std::vector target_size(_sizes.size()); + for (const auto idx : c10::irange(_sizes.size())) { + if (_sizes[idx].is_symbolic()) { + c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl(); + auto* lazySymIntNode = + dynamic_cast(symbolicIntNode.get()); + TORCH_INTERNAL_ASSERT(lazySymIntNode); + auto size_node = lazySymIntNode->node_; + auto static_value = + std::dynamic_pointer_cast(size_node) + ->getStaticValue(); + target_size[idx] = static_value; + } else { + target_size[idx] = _sizes[idx].as_int_unchecked(); + if (_sizes[idx].as_int_unchecked() == -1) { + // -1 can't be specified for non-existing dimensions + TORCH_CHECK(idx >= num_new_dimensions); + target_size[idx] = padded_self[idx]; + } else { + target_size[idx] = _sizes[idx].as_int_unchecked(); + } + } + } + return {Shape(self.scalar_type(), target_size)}; +} + std::vector compute_shape_index_select( const at::Tensor& self, int64_t dim, @@ -442,14 +498,8 @@ std::vector compute_shape_inverse(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_kl_div_backward( - const at::Tensor& grad_output, - const at::Tensor& self, - const at::Tensor& target, - int64_t reduction, - bool log_target) { - // Based on definition of aten/src/ATen/native/Loss.cpp::kl_div_backward_cpu. - return {Shape(self.scalar_type(), self.sizes().vec())}; +std::vector compute_shape_isnan(const at::Tensor& self) { + return {Shape(c10::ScalarType::Bool, self.sizes().vec())}; } std::vector compute_shape_cat(at::TensorList tensors, int64_t dim) { @@ -470,6 +520,77 @@ std::vector compute_shape_cat(at::TensorList tensors, int64_t dim) { return {Shape(tensors[0].scalar_type(), out_shape)}; } +std::vector compute_shape_native_batch_norm( + const at::Tensor& input, + const c10::optional& weight, + const c10::optional& bias, + const c10::optional& running_mean, + const c10::optional& running_var, + bool training, + double momentum, + double eps) { + std::vector shapes; + shapes.reserve(3); + shapes.emplace_back(input.scalar_type(), input.sizes().vec()); + + // A separate mean and var needs to be kept for each channel. + TORCH_CHECK( + input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); + int64_t num_features = input.size(1); + + if (running_mean.has_value()) { + shapes.emplace_back( + running_mean.value().scalar_type(), running_mean.value().sizes().vec()); + } else { + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + } + + if (running_var.has_value()) { + shapes.emplace_back( + running_var.value().scalar_type(), running_var.value().sizes().vec()); + } else { + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + } + return shapes; +} + +std::vector compute_shape_native_batch_norm_backward( + const at::Tensor& grad_out, + const at::Tensor& input, + const c10::optional& weight, + const c10::optional& running_mean, + const c10::optional& running_var, + const c10::optional& save_mean, + const c10::optional& save_invstd, + bool train, + double eps, + ::std::array output_mask) { + std::vector shapes; + shapes.reserve(3); + shapes.emplace_back(input.scalar_type(), input.sizes().vec()); + + // A separate mean and var needs to be kept for each channel. + TORCH_CHECK( + input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); + int64_t num_features = input.size(1); + + // `weight` and `bias` are vectors of length C (number of channels)` + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + + return shapes; +} + std::vector compute_shape_native_layer_norm( const at::Tensor& input, at::IntArrayRef normalized_shape, @@ -528,6 +649,17 @@ std::vector compute_shape_mean( return {Shape(self.scalar_type(), {})}; } +std::vector compute_shape_new_empty_strided( + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())}; +} + std::vector compute_shape_mv( const at::Tensor& self, const at::Tensor& vec) { @@ -550,25 +682,25 @@ std::vector compute_shape_native_dropout_backward( return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; } -std::vector compute_shape_random_functional( +std::vector compute_shape_random( const at::Tensor& self, c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_random_functional( +std::vector compute_shape_random( const at::Tensor& self, int64_t to, c10::optional generator) { - return compute_shape_random_functional(self, generator); + return compute_shape_random(self, generator); } -std::vector compute_shape_random_functional( +std::vector compute_shape_random( const at::Tensor& self, int64_t from, c10::optional to, c10::optional generator) { - return compute_shape_random_functional(self, generator); + return compute_shape_random(self, generator); } std::vector compute_shape_relu(const at::Tensor& self) { @@ -596,7 +728,7 @@ std::vector compute_shape_sum( ; } -std::vector compute_shape_zero_functional(const at::Tensor& self) { +std::vector compute_shape_zero(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -641,19 +773,20 @@ std::vector compute_shape_slogdet(const at::Tensor& self) { } std::vector compute_shape_logical_and( - at::Tensor& self, + const at::Tensor& self, const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); return {Shape( c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; } -std::vector compute_shape_logical_not(at::Tensor& self) { +std::vector compute_shape_logical_not( + const at::Tensor& self) { return {Shape(c10::ScalarType::Bool, self.sizes().vec())}; } std::vector compute_shape_logical_or( - at::Tensor& self, + const at::Tensor& self, const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); return {Shape( @@ -661,7 +794,7 @@ std::vector compute_shape_logical_or( } std::vector compute_shape_logical_xor( - at::Tensor& self, + const at::Tensor& self, const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); return {Shape( @@ -843,6 +976,89 @@ std::vector compute_shape__adaptive_avg_pool2d_backward( return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape__adaptive_avg_pool3d( + const at::Tensor& self, + at::IntArrayRef output_size) { + // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` + // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp` + TORCH_CHECK( + output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3"); + TORCH_CHECK( + (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), + "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ", + "but received {", + output_size[0], + ", ", + output_size[1], + ", ", + output_size[2], + "}"); + int64_t ndim = self.ndimension(); + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK( + self.size(i) > 0, + "adaptive_avg_pool3d(): Expected self to have non-zero size for non-batch dimensions, " + "but Tensor has sizes ", + self.sizes(), + " with dimension ", + i, + " being " + "empty"); + } + TORCH_CHECK( + (ndim == 4 || ndim == 5), + "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ", + self.sizes()); + + int64_t channels = self.size(-3); + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + if (ndim == 4) { + return {Shape( + self.scalar_type(), + {channels, output_depth, output_height, output_width})}; + } else { + int64_t nbatch = self.size(0); + return {Shape( + self.scalar_type(), + {nbatch, channels, output_depth, output_height, output_width})}; + } +} + +std::vector compute_shape__adaptive_avg_pool3d_backward( + const at::Tensor& grad_output, + const at::Tensor& self) { + // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` + int64_t ndim = grad_output.ndimension(); + + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK( + grad_output.size(i) > 0, + "adaptive_avg_pool3d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, " + "but grad_output has sizes ", + grad_output.sizes(), + " with dimension ", + i, + " being " + "empty"); + } + + TORCH_CHECK( + (ndim == 4 || ndim == 5), + "adaptive_avg_pool3d_backward(): Expected 4D or 5D tensor, but got ", + self.sizes()); + TORCH_CHECK( + self.dtype() == grad_output.dtype(), + "expected dtype ", + self.dtype(), + " for `grad_output` but got dtype ", + grad_output.dtype()); + + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + std::vector compute_shape_glu_backward( const at::Tensor& grad_output, const at::Tensor& self, @@ -878,6 +1094,12 @@ std::vector compute_shape__to_copy( return {Shape(self.scalar_type(), self.sizes().vec())}; } +TORCH_API std::vector compute_shape_clone( + const at::Tensor& self, + c10::optional memory_format) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + std::vector compute_shape_stack(at::TensorList tensors, int64_t dim) { TORCH_CHECK(tensors.size() > 0, "stack expects a non-empty TensorList"); auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1); @@ -903,7 +1125,7 @@ std::vector compute_shape_stack(at::TensorList tensors, int64_t dim) { std::vector compute_shape_repeat( const at::Tensor& self, at::IntArrayRef repeats) { - CHECK_GE(repeats.size(), self.dim()); + TORCH_CHECK_GE(repeats.size(), self.dim()); int64_t num_new_dimensions = repeats.size() - self.dim(); std::vector padded_size(num_new_dimensions, 1); padded_size.insert( @@ -923,6 +1145,20 @@ std::vector compute_shape_narrow_copy_symint( return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_hardswish(const at::Tensor& self) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_hardswish_backward( + const at::Tensor& grad_output, + const at::Tensor& self) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_selu(const at::Tensor& self) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + // Non-Native Ops std::vector compute_shape_scalar( const at::Scalar& value, @@ -1037,6 +1273,106 @@ std::vector compute_shape_unsqueeze( BuildUnsqueezedDimensions(input_shape.sizes(), dim))}; } +std::vector compute_shape_select_scatter( + const at::Tensor& self, + const at::Tensor& src, + int64_t dim, + int64_t index) { + auto self_meta = at::native::empty_strided_meta( + self.sizes(), + self.strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto src_meta = at::native::empty_strided_meta( + src.sizes(), + src.strides(), + /*dtype=*/c10::make_optional(src.scalar_type()), + /*layout=*/c10::make_optional(src.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::compositeexplicitautograd::select_scatter( + self_meta, src_meta, dim, index); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_diagonal_scatter( + const at::Tensor& self, + const at::Tensor& src, + int64_t offset, + int64_t dim1, + int64_t dim2) { + auto self_meta = at::native::empty_strided_meta( + self.sizes(), + self.strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto src_meta = at::native::empty_strided_meta( + src.sizes(), + src.strides(), + /*dtype=*/c10::make_optional(src.scalar_type()), + /*layout=*/c10::make_optional(src.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::compositeexplicitautograd::diagonal_scatter( + self_meta, src_meta, offset, dim1, dim2); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_slice_scatter( + const at::Tensor& self, + const at::Tensor& src, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { + auto self_meta = at::native::empty_strided_meta( + self.sizes(), + self.strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto src_meta = at::native::empty_strided_meta( + src.sizes(), + src.strides(), + /*dtype=*/c10::make_optional(src.scalar_type()), + /*layout=*/c10::make_optional(src.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::compositeexplicitautograd::slice_scatter( + self_meta, src_meta, dim, start, end, step); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_as_strided_scatter( + const at::Tensor& self, + const at::Tensor& src, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { + auto self_meta = at::native::empty_strided_meta( + self.sizes(), + self.strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto src_meta = at::native::empty_strided_meta( + src.sizes(), + src.strides(), + /*dtype=*/c10::make_optional(src.scalar_type()), + /*layout=*/c10::make_optional(src.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::compositeexplicitautograd::as_strided_scatter( + self_meta, src_meta, size, stride, storage_offset); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + // Restore unused-parameters warnings #pragma GCC diagnostic pop diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 6b186f2dff322..77a26ec07fc0b 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -2,11 +2,15 @@ #include #include +#include +#include +#include #include #include #include #include #include +#include #include namespace torch { @@ -16,6 +20,8 @@ namespace lazy { // clang-format off TORCH_API std::vector compute_shape__adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); TORCH_API std::vector compute_shape__adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API std::vector compute_shape__adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API std::vector compute_shape__adaptive_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self); TORCH_API std::vector compute_shape_abs(const at::Tensor & self); TORCH_API std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); TORCH_API std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optional generator); @@ -24,11 +30,14 @@ TORCH_API std::vector compute_shape_binary_cross_entropy(con TORCH_API std::vector compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction); TORCH_API std::vector compute_shape_cat(at::TensorList tensors, int64_t dim); TORCH_API std::vector compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min); +TORCH_API std::vector compute_shape_clone(const at::Tensor & self, c10::optional memory_format); TORCH_API std::vector compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value); TORCH_API std::vector compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); TORCH_API std::vector compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask); TORCH_API std::vector compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse); TORCH_API std::vector compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); +TORCH_API std::vector compute_shape_expand(const at::Tensor & self, at::IntArrayRef size, bool implicit); +TORCH_API std::vector compute_shape_expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit); TORCH_API std::vector compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims); TORCH_API std::vector compute_shape_glu_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim); TORCH_API std::vector compute_shape_glu_jvp(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim); @@ -36,30 +45,33 @@ TORCH_API std::vector compute_shape_grid_sampler_2d(const at TORCH_API std::vector compute_shape_grid_sampler_2d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask); TORCH_API std::vector compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index); TORCH_API std::vector compute_shape_inverse(const at::Tensor & self); -TORCH_API std::vector compute_shape_kl_div_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, bool log_target); +TORCH_API std::vector compute_shape_isnan(const at::Tensor & self); TORCH_API std::vector compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer); TORCH_API std::vector compute_shape_log_sigmoid_forward(const at::Tensor & self); TORCH_API std::vector compute_shape_logdet(const at::Tensor & self); -TORCH_API std::vector compute_shape_logical_and(at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape_logical_not(at::Tensor & self); -TORCH_API std::vector compute_shape_logical_or(at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape_logical_xor(at::Tensor & self, const at::Tensor & other); +TORCH_API std::vector compute_shape_logical_and(const at::Tensor & self, const at::Tensor & other); +TORCH_API std::vector compute_shape_logical_not(const at::Tensor & self); +TORCH_API std::vector compute_shape_logical_or(const at::Tensor & self, const at::Tensor & other); +TORCH_API std::vector compute_shape_logical_xor(const at::Tensor & self, const at::Tensor & other); TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); TORCH_API std::vector compute_shape_max(const at::Tensor & self); TORCH_API std::vector compute_shape_mean(const at::Tensor & self, c10::optional dtype); TORCH_API std::vector compute_shape_min(const at::Tensor & self); TORCH_API std::vector compute_shape_mv(const at::Tensor & self, const at::Tensor & vec); +TORCH_API std::vector compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double momentum, double eps); +TORCH_API std::vector compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional & weight, const c10::optional & running_mean, const c10::optional & running_var, const c10::optional & save_mean, const c10::optional & save_invstd, bool train, double eps, ::std::array output_mask); TORCH_API std::vector compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional train); TORCH_API std::vector compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale); TORCH_API std::vector compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional & weight, const c10::optional & bias, double eps); TORCH_API std::vector compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional & weight, const c10::optional & bias, ::std::array output_mask); +TORCH_API std::vector compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); TORCH_API std::vector compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight); TORCH_API std::vector compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction, int64_t ignore_index); TORCH_API std::vector compute_shape_nonzero(const at::Tensor & self); -TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, c10::optional generator); -TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional generator); -TORCH_API std::vector compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator); +TORCH_API std::vector compute_shape_random(const at::Tensor & self, c10::optional generator); +TORCH_API std::vector compute_shape_random(const at::Tensor & self, int64_t to, c10::optional generator); +TORCH_API std::vector compute_shape_random(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator); TORCH_API std::vector compute_shape_relu(const at::Tensor & self); TORCH_API std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats); TORCH_API std::vector compute_shape_slogdet(const at::Tensor & self); @@ -72,8 +84,11 @@ TORCH_API std::vector compute_shape_std(const at::Tensor & s TORCH_API std::vector compute_shape_sum(const at::Tensor & self, c10::optional dtype); TORCH_API std::vector compute_shape__to_copy(const at::Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format); TORCH_API std::vector compute_shape_trace(const at::Tensor & self); -TORCH_API std::vector compute_shape_zero_functional(const at::Tensor & self); +TORCH_API std::vector compute_shape_zero(const at::Tensor & self); TORCH_API std::vector compute_shape_narrow_copy_symint(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length); +TORCH_API std::vector compute_shape_hardswish(const at::Tensor & self); +TORCH_API std::vector compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API std::vector compute_shape_selu(const at::Tensor & self); // Non-Native ops TORCH_API std::vector compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type); @@ -82,6 +97,7 @@ TORCH_API std::vector compute_shape_view(const Output& input0, const std: TORCH_API std::vector compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional& stype); // View Ops +// (Now that functionalization pass is used, we should kill these in a later PR) TORCH_API std::vector compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset); TORCH_API std::vector compute_shape_as_strided(const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset); TORCH_API std::vector compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2); @@ -94,6 +110,12 @@ TORCH_API std::vector compute_shape_select_view_update(const Output& targ TORCH_API std::vector compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride); TORCH_API std::vector compute_shape_squeeze(const Output& input, const int& dim); TORCH_API std::vector compute_shape_unsqueeze(const Output& input, const int& dim); + +TORCH_API std::vector compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index); +TORCH_API std::vector compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2); +TORCH_API std::vector compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional start, c10::optional end, int64_t step); +TORCH_API std::vector compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional start, c10::optional end, int64_t step); +TORCH_API std::vector compute_shape_as_strided_scatter(const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset); // clang-format on } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index 0fb34f2fa699b..86971dc49bcb3 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace torch { namespace lazy { namespace { @@ -482,7 +484,8 @@ torch::lazy::Value GetTensorList(c10::ArrayRef tensors) { } LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) { - auto* impl = dynamic_cast(tensor.unsafeGetTensorImpl()); + auto* impl = dynamic_cast( + maybe_unwrap_functional(tensor).unsafeGetTensorImpl()); if (impl == nullptr) { // return c10::make_intrusive(); return LazyTensorPtr(); @@ -532,5 +535,27 @@ at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) { return at::Tensor(c10::make_intrusive(std::move(ltc_tensor))); } +at::Tensor to_lazy_tensor( + const at::Tensor& self, + const c10::TensorOptions& options, + at::Device device, + bool non_blocking, + bool functionalize_output) { + TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy); + TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy); + + auto eager_tensor = + self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + auto lazy_self = torch::lazy::GetOrCreateLtcTensor( + eager_tensor, torch::lazy::atenDeviceToBackendDevice(device)); + auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self); + if (functionalize_output) { + // See Note [Lazy Tensor Functionalization] + return at::functionalization::impl::to_functional_tensor(out); + } else { + return out; + } +} + } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index 837e886df34a8..e178ac119c6b0 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -11,11 +11,10 @@ namespace torch { namespace lazy { -class TORCH_API SymbolicIntNode : public c10::SymbolicIntNode { +class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl { public: - SymbolicIntNode(NodePtr ptr) : node_(std::move(ptr)){}; - std::shared_ptr add( - const std::shared_ptr& other) override { + SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; + c10::SymIntNode add(const c10::SymIntNode& other) override { TORCH_CHECK(false, "NYI"); } NodePtr node_; @@ -242,6 +241,46 @@ TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber( TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor); +// Note [Lazy Tensor Functionalization] +// The functionalization pass is implemented by wrapping all TensorImpl +// objects in C++ with an extra FunctionalTensorWrapper object, +// that knows how to perform functionalization +// +// Certain functions in the aten API serve as entry/exit points for +// functionalization, where we need to perform the wrapping/unwrapping: +// - aten::to.device +// - aten::empty + +// Given a non-lazy tensor, this function creates a lazy tensor on the specified +// (lazy) device. The functionalize_output determines whether or not we should +// wrap the output in a "functional wrapper". +// +// How do you know whether to pass true/false for functionalize_output? +// +// Case 1: nonlazy -> lazy +// If you're implementing a function that takes in nonlazy tensors and returns +// lazy tensors, then you should think of that function as an "entrypoint" to +// functionalization, and use functionalize_output=true Examples include: +// - factory functions (the LTC kernel for at::empty) +// - CPU -> Lazy device converions (the LTC kernel for at::to_device) +// +// Case 2: lazy -> lazy +// If you're implementing a function that takes in lazy tensors and returns +// lazy tensors, +// **but** requires creating lazy tensors internally, +// then you can assume that the current function is running inside of some +// outer context where functionalization is already running, that will take +// care of doing the wrapping for you, and use functionalize_output=true +// Examples include: +// - CPU fallback (takes in lazy tensors, converts to cpu, calls kernel, +// converts returns back to lazy tensors). +TORCH_API at::Tensor to_lazy_tensor( + const at::Tensor& self, + const c10::TensorOptions& options, + at::Device device, + bool non_blocking, + bool functionalize_output); + template auto TupleAtenFromLtcTensorsImpl( const std::vector& tensors, diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp index 1434084e502aa..67cd53c442a63 100644 --- a/torch/csrc/lazy/core/tensor_impl.cpp +++ b/torch/csrc/lazy/core/tensor_impl.cpp @@ -95,7 +95,7 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor) for (auto i : c10::irange(rank)) { auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode( this->tensor_->GetIrValue(), i); - auto sn = std::make_shared(dim_node); + auto sn = c10::make_intrusive(dim_node); sym_sizes_.push_back(sn->toSymInt()); } } @@ -143,7 +143,13 @@ void LTCTensorImpl::shallow_copy_from( } c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const { - return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size()); + if (FLAGS_ltc_enable_symbolic_shapes) { + return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size()); + } + + // return upper bound + const_cast(this)->setup_size_properties(); + return TensorImpl::sym_sizes_default(); } c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const { diff --git a/torch/csrc/lazy/core/tensor_util.h b/torch/csrc/lazy/core/tensor_util.h index 1d058dd77f585..e4e6a1b7f0c26 100644 --- a/torch/csrc/lazy/core/tensor_util.h +++ b/torch/csrc/lazy/core/tensor_util.h @@ -3,6 +3,8 @@ #include #include +#include + #include #include @@ -62,5 +64,15 @@ at::Scalar MakeIntScalar(T value) { // API returns true. TORCH_API bool IsSpecialScalar(const at::Scalar& value); +// Note: returns a reference instead of a fresh tensor to avoid refcount bumps. +inline const at::Tensor& maybe_unwrap_functional(const at::Tensor& tensor) { + if (at::functionalization::impl::isFunctionalTensor(tensor)) { + return at::functionalization::impl::unsafeGetFunctionalWrapper(tensor) + ->value(); + } else { + return tensor; + } +} + } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index df331c081d669..6ba8e5247d9c8 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -46,8 +47,9 @@ std::string GetTensorsDump( std::vector nodes; std::vector values; for (auto& tensor : tensors) { + auto inner = at::functionalization::impl::from_functional_tensor(tensor); torch::lazy::LazyTensorPtr lazy_tensor = - torch::lazy::TryGetLtcTensor(tensor); + torch::lazy::TryGetLtcTensor(inner); values.push_back(lazy_tensor->GetIrValue()); nodes.push_back(values.back().node.get()); } @@ -127,6 +129,8 @@ void initLazyBindings(PyObject* module) { lazy.def( "_reset_metrics", []() { torch::lazy::MetricsArena::Get()->Reset(); }); lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); }); + lazy.def( + "_metrics_report", []() { return torch::lazy::CreateMetricReport(); }); lazy.def("_counter_value", [](const std::string& name) -> py::object { torch::lazy::CounterData* data = torch::lazy::GetCounter(name); return data != nullptr ? py::cast(data->Value()) : py::none(); diff --git a/torch/csrc/lazy/python/init.h b/torch/csrc/lazy/python/init.h index 00df90966e4ef..e9c584ead8ce1 100644 --- a/torch/csrc/lazy/python/init.h +++ b/torch/csrc/lazy/python/init.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include namespace torch { namespace lazy { diff --git a/torch/csrc/lazy/python/python_util.cpp b/torch/csrc/lazy/python/python_util.cpp index e594c752de6f4..00f93e4115ed4 100644 --- a/torch/csrc/lazy/python/python_util.cpp +++ b/torch/csrc/lazy/python/python_util.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include namespace torch { @@ -19,9 +21,10 @@ c10::optional GetPythonFrameTop() { return c10::nullopt; } SourceLocation loc; - loc.line = PyCode_Addr2Line(frame->f_code, frame->f_lasti); - loc.file = THPUtils_unpackString(frame->f_code->co_filename); - loc.function = THPUtils_unpackString(frame->f_code->co_name); + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + loc.line = PyFrame_GetLineNumber(frame); + loc.file = THPUtils_unpackString(code->co_filename); + loc.function = THPUtils_unpackString(code->co_name); return loc; } @@ -30,13 +33,17 @@ std::vector GetPythonFrames() { if (Py_IsInitialized()) { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); + Py_INCREF(frame); while (frame != nullptr) { SourceLocation loc; - loc.line = PyCode_Addr2Line(frame->f_code, frame->f_lasti); - loc.file = THPUtils_unpackString(frame->f_code->co_filename); - loc.function = THPUtils_unpackString(frame->f_code->co_name); + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + loc.line = PyFrame_GetLineNumber(frame); + loc.file = THPUtils_unpackString(code->co_filename); + loc.function = THPUtils_unpackString(code->co_name); frames.push_back(std::move(loc)); - frame = frame->f_back; + auto new_frame = PyFrame_GetBack(frame); + Py_DECREF(frame); + frame = new_frame; } } return frames; diff --git a/torch/csrc/lazy/test_mnist.py b/torch/csrc/lazy/test_mnist.py index 7c4ea654ce72c..16a023df5edde 100644 --- a/torch/csrc/lazy/test_mnist.py +++ b/torch/csrc/lazy/test_mnist.py @@ -68,7 +68,7 @@ def train(log_interval, model, device, train_loader, optimizer, epoch): 'pin_memory': True, 'shuffle': True, 'batch_size': bsz} - train_kwargs.update(cuda_kwargs) + train_kwargs.update(cuda_kwargs) transform = transforms.Compose([ transforms.ToTensor(), diff --git a/torch/csrc/lazy/ts_backend/dynamic_ir.cpp b/torch/csrc/lazy/ts_backend/dynamic_ir.cpp index aaa866af58a02..1180885ab57f7 100644 --- a/torch/csrc/lazy/ts_backend/dynamic_ir.cpp +++ b/torch/csrc/lazy/ts_backend/dynamic_ir.cpp @@ -18,7 +18,7 @@ TSOpVector SizeNode::Lower( arguments.emplace_back(index); torch::lazy::TSOpVector size_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); - CHECK_EQ(size_out.size(), 1); + TORCH_CHECK_EQ(size_out.size(), 1); return size_out; } diff --git a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp deleted file mode 100644 index 3b5ca967e4c1d..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp +++ /dev/null @@ -1,97 +0,0 @@ -#include -#include - -namespace torch { -namespace lazy { - -TSNativeBatchNormBackward::TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask) - : torch::lazy::TsNode( - torch::lazy::OpKind(at::aten::native_batch_norm_backward), - {grad_out, - input, - weight, - running_mean, - running_var, - save_mean, - save_invstd}, - {input.shape(), weight.shape(), weight.shape()}, - /*num_outputs=*/3, - torch::lazy::MHash( - training, - eps, - output_mask[0], - output_mask[1], - output_mask[2])), - training_(training), - eps_(eps), - output_mask_(output_mask) {} - -TSNativeBatchNormBackward::TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask) - : torch::lazy::TsNode( - torch::lazy::OpKind(at::aten::native_batch_norm_backward), - {grad_out, input, weight, save_mean, save_invstd}, - {input.shape(), weight.shape(), weight.shape()}, - /*num_outputs=*/3, - torch::lazy::MHash( - training, - eps, - output_mask[0], - output_mask[1], - output_mask[2])), - training_(training), - eps_(eps), - output_mask_(output_mask) {} - -std::string TSNativeBatchNormBackward::ToString() const { - std::stringstream ss; - ss << torch::lazy::TsNode::ToString() << ", training=" << training_ - << ", eps=" << eps_; - return ss.str(); -} - -TSNativeBatchNormForward::TSNativeBatchNormForward( - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& bias, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - bool training, - double momentum, - double eps) - : torch::lazy::TsNode( - torch::lazy::OpKind(at::aten::native_batch_norm), - {input, weight, bias, running_mean, running_var}, - {input.shape(), running_mean.shape(), running_var.shape()}, - /*num_outputs=*/3, - torch::lazy::MHash(training, momentum, eps)), - training_(training), - momentum_(momentum), - eps_(eps) {} - -std::string TSNativeBatchNormForward::ToString() const { - std::stringstream ss; - ss << torch::lazy::TsNode::ToString() << ", training=" << training_ - << ", momentum=" << momentum_ << ", eps=" << eps_; - return ss.str(); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h deleted file mode 100644 index 515a75f0b29c2..0000000000000 --- a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h +++ /dev/null @@ -1,156 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace lazy { - -// Node for the backward batch norm operator. -class TSNativeBatchNormBackward : public torch::lazy::TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::native_batch_norm_backward); - } - - TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask); - - TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask); - - bool CanBeReused( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask) const { - size_t i = 0; - return ( - operand(i++) == grad_out && operand(i++) == input && - operand(i++) == weight && operand(i++) == running_mean && - operand(i++) == running_var && operand(i++) == save_mean && - operand(i++) == save_invstd && training_ == training && eps_ == eps && - output_mask_ == output_mask); - } - - bool CanBeReused( - const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, - bool training, - double eps, - std::array output_mask) const { - size_t i = 0; - return ( - operand(i++) == grad_out && operand(i++) == input && - operand(i++) == weight && operand(i++) == save_mean && - operand(i++) == save_invstd && training_ == training && eps_ == eps && - output_mask_ == output_mask); - } - - std::string ToString() const override; - - bool training() const { - return training_; - } - - double eps() const { - return eps_; - } - - const std::array& output_mask() const { - return output_mask_; - } - - TSOpVector Lower( - std::shared_ptr function, - TSLoweringContext* loctx) const override; - - private: - bool training_; - double eps_; - std::array output_mask_; -}; - -class TSNativeBatchNormForward : public torch::lazy::TsNode { - public: - static OpKind ClassOpKind() { - return OpKind(at::aten::native_batch_norm); - } - - TSNativeBatchNormForward( - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& bias, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - bool training, - double momentum, - double eps); - - bool CanBeReused( - const torch::lazy::Value& input, - const torch::lazy::Value& weight, - const torch::lazy::Value& bias, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - bool training, - double momentum, - double eps) const { - size_t i = 0; - return ( - operand(i++) == input && operand(i++) == weight && - operand(i++) == bias && operand(i++) == running_mean && - operand(i++) == running_var && training_ == training && - momentum_ == momentum && eps == eps_); - } - - std::string ToString() const override; - - bool training() const { - return training_; - } - - double momentum() const { - return momentum_; - } - - double eps() const { - return eps_; - } - - TSOpVector Lower( - std::shared_ptr function, - TSLoweringContext* loctx) const override; - - private: - bool training_; - double momentum_; - double eps_; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp b/torch/csrc/lazy/ts_backend/ops/random_ops.cpp index cb3708953cd58..7c2e1f4386c9c 100644 --- a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp +++ b/torch/csrc/lazy/ts_backend/ops/random_ops.cpp @@ -38,7 +38,7 @@ torch::lazy::TSOpVector Normal::Lower( arguments.emplace_back("std", std_); torch::lazy::TSOpVector normal__out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); - CHECK_EQ(normal__out.size(), 1); + TORCH_CHECK_EQ(normal__out.size(), 1); return normal__out; } diff --git a/torch/csrc/lazy/ts_backend/ops/to_copy.h b/torch/csrc/lazy/ts_backend/ops/to_copy.h index fdc90f46e609d..4b96b1c389f78 100644 --- a/torch/csrc/lazy/ts_backend/ops/to_copy.h +++ b/torch/csrc/lazy/ts_backend/ops/to_copy.h @@ -110,7 +110,7 @@ class ToCopy : public torch::lazy::TsNode { kwarguments.emplace_back("memory_format", memory_format); torch::lazy::TSOpVector _to_copy_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); - CHECK_EQ(_to_copy_out.size(), 1); + TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; } diff --git a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp index 9d23c48fd94e8..534a9bca130db 100644 --- a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp +++ b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -40,7 +39,7 @@ torch::lazy::Value MaybeExpand( std::vector GetExpandDimensions( const torch::lazy::Shape& shape, std::vector dimensions) { - CHECK_GE(dimensions.size(), shape.dim()) << shape; + TORCH_CHECK_GE(dimensions.size(), shape.dim()) << shape; int64_t base = dimensions.size() - shape.dim(); for (size_t i = 0; i < shape.dim(); ++i) { if (dimensions[base + i] == -1) { @@ -50,27 +49,6 @@ std::vector GetExpandDimensions( return dimensions; } -// Returns a 1-D shape for batch norm weight or bias based on the input shape. -torch::lazy::Shape BatchNormFeaturesShape( - const torch::lazy::LazyTensorPtr& input) { - CHECK(input); - auto input_shape = input->shape().Get(); - return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]); -} - -// Returns the IR for the given input or the provided default value broadcasted -// to the default shape, if the input is undefined. -torch::lazy::Value GetIrValueOrDefault( - const torch::lazy::LazyTensorPtr& input, - const at::Scalar& default_value, - const torch::lazy::Shape& default_shape, - const torch::lazy::BackendDevice& device) { - return input - ? input->GetIrValue() - : torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - default_value, default_shape, device); -} - torch::lazy::ViewInfo CreateAsStridedViewInfo( const torch::lazy::Shape& input_shape, std::vector size, @@ -165,105 +143,6 @@ torch::lazy::LazyTensorPtr narrow( return input->CreateViewTensor(std::move(view_info)); } -std::tuple< - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr> -ts_native_batch_norm( - const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, - const torch::lazy::LazyTensorPtr& bias, - torch::lazy::LazyTensorPtr& running_mean, - torch::lazy::LazyTensorPtr& running_var, - bool training, - double momentum, - double eps) { - torch::lazy::Shape features_shape = BatchNormFeaturesShape(input); - torch::lazy::Value weight_value = - GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice()); - torch::lazy::Value bias_value = - GetIrValueOrDefault(bias, 0, features_shape, input->GetDevice()); - torch::lazy::Value running_mean_value = - GetIrValueOrDefault(running_mean, 0, features_shape, input->GetDevice()); - torch::lazy::Value running_var_value = - GetIrValueOrDefault(running_var, 0, features_shape, input->GetDevice()); - torch::lazy::NodePtr node = ReuseOrMakeNode( - input->GetIrValue(), - weight_value, - bias_value, - running_mean_value, - running_var_value, - training, - momentum, - eps); - torch::lazy::LazyTensorPtr output = torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 0), input->GetDevice()); - torch::lazy::LazyTensorPtr running_mean_output = - torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 1), input->GetDevice()); - torch::lazy::LazyTensorPtr running_var_output = - torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 2), input->GetDevice()); - return std::make_tuple( - std::move(output), - std::move(running_mean_output), - std::move(running_var_output)); -} - -std::tuple< - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr> -ts_native_batch_norm_backward( - const torch::lazy::LazyTensorPtr& grad_out, - const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, - const torch::lazy::LazyTensorPtr& running_mean, - const torch::lazy::LazyTensorPtr& running_var, - const torch::lazy::LazyTensorPtr& save_mean, - const torch::lazy::LazyTensorPtr& save_invstd, - bool training, - double eps, - c10::ArrayRef output_mask) { - torch::lazy::Shape features_shape = BatchNormFeaturesShape(input); - torch::lazy::Value weight_value = - GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice()); - torch::lazy::NodePtr node; - if (!running_mean && !running_var) { - node = ReuseOrMakeNode( - grad_out->GetIrValue(), - input->GetIrValue(), - weight_value, - save_mean->GetIrValue(), - save_invstd->GetIrValue(), - training, - eps, - std::array{output_mask[0], output_mask[1], output_mask[2]}); - } else { - CHECK(running_mean); - CHECK(running_var); - node = ReuseOrMakeNode( - grad_out->GetIrValue(), - input->GetIrValue(), - weight_value, - running_mean->GetIrValue(), - running_var->GetIrValue(), - save_mean->GetIrValue(), - save_invstd->GetIrValue(), - training, - eps, - std::array{output_mask[0], output_mask[1], output_mask[2]}); - } - torch::lazy::LazyTensorPtr grad_input = torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 0), input->GetDevice()); - torch::lazy::LazyTensorPtr grad_weight = torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 1), input->GetDevice()); - torch::lazy::LazyTensorPtr grad_bias = torch::lazy::LazyTensor::Create( - torch::lazy::Value(node, 2), input->GetDevice()); - return std::make_tuple( - std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); -} - torch::lazy::LazyTensorPtr permute( const torch::lazy::LazyTensorPtr& input, c10::ArrayRef dims) { diff --git a/torch/csrc/lazy/ts_backend/tensor_aten_ops.h b/torch/csrc/lazy/ts_backend/tensor_aten_ops.h index d2104d5fcdde0..0d5a49bdfbd67 100644 --- a/torch/csrc/lazy/ts_backend/tensor_aten_ops.h +++ b/torch/csrc/lazy/ts_backend/tensor_aten_ops.h @@ -38,36 +38,6 @@ torch::lazy::LazyTensorPtr narrow( int64_t start, int64_t length); -std::tuple< - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr> -ts_native_batch_norm( - const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, - const torch::lazy::LazyTensorPtr& bias, - torch::lazy::LazyTensorPtr& running_mean, - torch::lazy::LazyTensorPtr& running_var, - bool training, - double momentum, - double eps); - -std::tuple< - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr> -ts_native_batch_norm_backward( - const torch::lazy::LazyTensorPtr& grad_out, - const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, - const torch::lazy::LazyTensorPtr& running_mean, - const torch::lazy::LazyTensorPtr& running_var, - const torch::lazy::LazyTensorPtr& save_mean, - const torch::lazy::LazyTensorPtr& save_invstd, - bool training, - double eps, - c10::ArrayRef output_mask); - // Permute the dimensions of this tensor according to the given permutation. torch::lazy::LazyTensorPtr permute( const torch::lazy::LazyTensorPtr& input, diff --git a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp index 6d33ebc525032..a390ac76c1260 100644 --- a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp +++ b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace at { // This function is defined in the codegenerated RegisterDispatchKey.cpp file. @@ -39,12 +40,12 @@ struct TSBackendDeviceType : public BackendDeviceType { class TSBackendImpl : public torch::lazy::BackendImplInterface { public: - TSBackendImpl() : default_device_type_(at::kCPU) { + TSBackendImpl() { // TODO(whc) unify how all our flags are set and parsed as envs static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr; auto type = (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU; - default_device_type_ = TSBackendDeviceType(type); + default_device_type_ = std::make_shared(type); } const IrBuilder* GetIrBuilder() const override { @@ -52,6 +53,10 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { return builder; } + std::string CreateMetricReport() const override { + return "TSBackendImpl: N/A"; + } + std::unique_ptr CreateLoweringContext( const std::string& name, torch::lazy::BackendDevice device, @@ -85,9 +90,9 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { const torch::lazy::Shape& shape, const torch::lazy::BackendDevice& device) const override { at::TensorOptions options = tensor.options().device( - default_device_type_.c10Type(), device.ordinal()); - if (tensor.device().type() == default_device_type_.c10Type() && - default_device_type_.c10Type() == at::kCUDA) { + default_device_type_->c10Type(), device.ordinal()); + if (tensor.device().type() == default_device_type_->c10Type() && + default_device_type_->c10Type() == at::kCUDA) { return std::make_shared( tensor.to(options, /*non_blocking=*/true), shape, device); } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) { @@ -133,26 +138,28 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { std::vector instances) const override; std::vector ExecuteComputation( - torch::lazy::Computation& computation, + torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const override; std::shared_ptr GetDefaultDeviceType() const override { - return std::make_shared(default_device_type_); + return default_device_type_; } at::DeviceType EagerFallbackDeviceType() const override; - void SetDefaultDeviceType(std::string type) override { - default_device_type_ = TSBackendDeviceType(c10::Device(type).type()); - // The first CUDA usage could happen via lazy tensors. Initialize CUDA here - // to account for that, at::scalar_tensor constructor triggers everything we - // need. - static auto init_cuda = default_device_type_.c10Type() == at::kCUDA - ? c10::optional( - at::scalar_tensor(0, at::TensorOptions().device(at::kCUDA))) - : c10::nullopt; + void SetDefaultDeviceType(int8_t type) override { + default_device_type_ = std::make_shared( + static_cast(type)); + } + + int64_t GetDefaultDeviceOrdinal() const { + return default_device_ordinal_; + } + + virtual void SetDefaultDeviceOrdinal(int64_t ordinal) { + default_device_ordinal_ = ordinal; } std::vector GetBackendDevices() const override; @@ -173,7 +180,8 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { void PrepareToExit() const override; private: - TSBackendDeviceType default_device_type_; + std::shared_ptr default_device_type_; + int64_t default_device_ordinal_{0}; }; torch::lazy::BackendDataPtr TSBackendImpl::CreateDataPlaceholder( @@ -195,11 +203,13 @@ std::vector TSBackendImpl::Compile( } std::vector TSBackendImpl::ExecuteComputation( - torch::lazy::Computation& computation, + torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const { - torch::jit::GraphExecutor& graph_executor = - static_cast(computation).graph_executor(); + auto ts_computation = + std::dynamic_pointer_cast(computation); + TORCH_CHECK(ts_computation, "Computation isn't TSComputation"); + torch::jit::GraphExecutor& graph_executor = ts_computation->graph_executor(); std::vector stack; for (const auto& argument : arguments) { const auto ts_data = std::static_pointer_cast(argument); @@ -209,7 +219,8 @@ std::vector TSBackendImpl::ExecuteComputation( // TODO(whc) should this check be made more general? it's written somewhat // oddly CHECK( - (c10::DeviceType)default_device_type_.type != at::kCUDA || + static_cast(default_device_type_->type) != + at::kCUDA || ts_data->data().device().type() == at::kCUDA); stack.emplace_back(ts_data->data()); } diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index f3aef2cd1a568..f5352f2d5ba8d 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -337,7 +337,11 @@ void ts_eager_fallback( } else { dev_str << ""; } - TORCH_WARN( + // We should never hit this for a view op, + // because LazyTensor should provide a lowering for the + // corresponding view_copy operator. The functionalization pass will + // take care of calling the view_copy operator intead of the view. + TORCH_CHECK( false, "The operator ", op.schema().operator_name(), diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp index 1defb5993b37c..ff3d1aa07b78e 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp @@ -34,7 +34,7 @@ void TSLoweringContext::Lower(const Node* node) { // codegenned or refactored nodes TSOpVector ops = tsnode->Lower(function_, this); CHECK(!ops.empty()) << "Failed to lower: " << *node; - CHECK_EQ(node->num_outputs(), ops.size()); + TORCH_CHECK_EQ(node->num_outputs(), ops.size()); for (size_t i = 0; i < ops.size(); ++i) { AssignOutputOp(torch::lazy::Output(node, i), ops[i]); } diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index a6796884cc86d..9e837249f52df 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -19,6 +21,8 @@ #include #include +using at::Tensor; + namespace torch { namespace lazy { namespace { @@ -46,48 +50,8 @@ c10::optional GetLtcDevice( } // namespace -at::Tensor LazyNativeFunctions::alias(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return self; -} - -at::Tensor LazyNativeFunctions::as_strided( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("lazy::"); - torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); - auto xsize = torch::lazy::ToI64Vector(size); - auto xstride = torch::lazy::ToI64Vector(stride); - if (!torch::lazy::StrideIsSupported(xstride)) { - return at::native:: - call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided)>::call( - self, size, stride, storage_offset); - } - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::as_strided( - self_tensor, std::move(xsize), std::move(xstride), storage_offset)); -} - -const at::Tensor& LazyNativeFunctions::as_strided_( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - auto xsize = torch::lazy::ToI64Vector(size); - auto xstride = torch::lazy::ToI64Vector(stride); - if (!torch::lazy::StrideIsSupported(xstride)) { - return at::native:: - call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided_)>::call( - self, size, stride, storage_offset); - } - torch::lazy::as_strided_( - self_tensor, std::move(xsize), std::move(xstride), storage_offset); - return self; -} - +// clone is special in LT because we make it a no-op. +// This should be safe to do, because every operator in the LT is functional. at::Tensor LazyNativeFunctions::clone( const at::Tensor& self, c10::optional memory_format) { @@ -211,12 +175,19 @@ at::Tensor LazyNativeFunctions::_to_copy( auto lazy_self = torch::lazy::TryGetLtcTensor(self); if (!lazy_self && device && device->type() == c10::kLazy) { // Case 1: eager->lazy (we create a new lazy tensor) - - auto eager_tensor = - self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); - lazy_self = torch::lazy::GetOrCreateLtcTensor( - eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device)); - return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + // See Note [Lazy Tensor Functionalization] + // Invariant: if the functionalization key is in the exclude set, then we're + // expected to return an ordinary tensor, which will be "lifted" into a + // functional wrapper later. + bool functionalize_output = + !c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::Functionalize); + return torch::lazy::to_lazy_tensor( + self, + options, + *device, + /*non_blocking=*/non_blocking, + /*functionalize_output=*/functionalize_output); } else if (device && device->type() != c10::kLazy) { // Case 2: lazy->eager (forces a graph break since we are materializing a // tensor) @@ -298,22 +269,21 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::diagonal( - const at::Tensor& self, - int64_t offset, - int64_t dim1, - int64_t dim2) { - TORCH_LAZY_FN_COUNTER("lazy::"); - - auto input = GetLtcTensor(self); - auto input_shape = input->shape(); - dim1 = at::maybe_wrap_dim(dim1, self); - dim2 = at::maybe_wrap_dim(dim2, self); - auto diagonal_info = DiagonalInfo{offset, dim1, dim2}; - auto view_info = - ViewInfo(ViewInfo::Type::kDiagonal, input_shape, diagonal_info); - - return CreateAtenFromLtcTensor(input->CreateViewTensor(std::move(view_info))); +at::Tensor LazyNativeFunctions::empty_symint( + c10::SymIntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { + // TODO: support SymIntNodes as well + return empty( + c10::asIntArrayRefSlow(size), + dtype, + layout, + device, + pin_memory, + memory_format); } at::Tensor LazyNativeFunctions::empty( @@ -330,7 +300,18 @@ at::Tensor LazyNativeFunctions::empty( .pinned_memory(pin_memory) .dtype(dtype); auto x_result = at::empty(size, options, memory_format); - return CreateLtcTensor(x_result, GetLtcDevice(device)); + auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device)); + // See Note [Lazy Tensor Functionalization] + if (c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::Functionalize)) { + // Invariant: if the functionalization key is in the exclude set, then we're + // expected to return an ordinary tensor, which will be "lifted" into a + // functional wrapper later. + return tensor; + } else { + auto wrapped = at::functionalization::impl::to_functional_tensor(tensor); + return wrapped; + } } at::Tensor LazyNativeFunctions::empty_strided( @@ -342,16 +323,7 @@ at::Tensor LazyNativeFunctions::empty_strided( c10::optional pin_memory) { TORCH_LAZY_FN_COUNTER("lazy::"); at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt); - return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0); -} - -at::Tensor LazyNativeFunctions::expand( - const at::Tensor& self, - at::IntArrayRef size, - bool implicit) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec())); + return t.as_strided(size, stride, /*storage_offset=*/0); } at::Tensor& LazyNativeFunctions::fill_( @@ -374,78 +346,6 @@ at::Tensor LazyNativeFunctions::max_pool3d( self, kernel_size, stride, padding, dilation, ceil_mode); } -std::tuple LazyNativeFunctions:: - native_batch_norm( - const at::Tensor& input, - const c10::optional& weight, - const c10::optional& bias, - const c10::optional& running_mean, - const c10::optional& running_var, - bool training, - double momentum, - double eps) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto input_tensor = torch::lazy::TryGetLtcTensor(input); - const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device); - auto running_var_tensor = GetOrCreateLtcTensor(running_var, device); - auto outputs = ts_native_batch_norm( - torch::lazy::TryGetLtcTensor(input), - GetOrCreateLtcTensor(weight, device), - GetOrCreateLtcTensor(bias, device), - running_mean_tensor, - running_var_tensor, - training, - momentum, - eps); - return std::make_tuple( - torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)), - torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)), - torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs))); -} - -std::tuple LazyNativeFunctions:: - native_batch_norm_backward( - const at::Tensor& grad_out, - const at::Tensor& input, - const c10::optional& weight, - const c10::optional& running_mean, - const c10::optional& running_var, - const c10::optional& save_mean, - const c10::optional& save_invstd, - bool train, - double eps, - std::array output_mask) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out); - const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); - torch::lazy::LazyTensorPtr null_tensor; - bool running_stats = running_mean && running_mean->defined(); - CHECK_EQ(running_var && running_var->defined(), running_stats); - auto gradients = ts_native_batch_norm_backward( - torch::lazy::TryGetLtcTensor(grad_out), - torch::lazy::TryGetLtcTensor(input), - GetOrCreateLtcTensor(weight, device), - running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor, - running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor, - GetOrCreateLtcTensor(save_mean, device), - GetOrCreateLtcTensor(save_invstd, device), - train, - eps, - output_mask); - at::Tensor undefined; - return std::make_tuple( - output_mask[0] - ? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients)) - : undefined, - output_mask[1] - ? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients)) - : undefined, - output_mask[2] - ? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients)) - : undefined); -} - // We need to explicitly override max pooling operators and just call the // fallback for them because we've customized the autograd function for them // (backward needs saved indices from forward). @@ -484,17 +384,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward( indices); } -at::Tensor LazyNativeFunctions::narrow( - const at::Tensor& self, - int64_t dim, - int64_t start, - int64_t length) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::narrow(self_tensor, dim, start, length)); -} - at::Tensor& LazyNativeFunctions::normal_( at::Tensor& self, double mean, @@ -525,124 +414,165 @@ at::Tensor& LazyNativeFunctions::normal_( // std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self; }; -at::Tensor LazyNativeFunctions::permute( - const at::Tensor& self, - at::IntArrayRef dims) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims))); -} - -at::Tensor LazyNativeFunctions::select( - const at::Tensor& self, - int64_t dim, - int64_t index) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::select(torch::lazy::TryGetLtcTensor(self), dim, index)); -} - -at::Tensor LazyNativeFunctions::slice( +at::Tensor LazyNativeFunctions::_unsafe_view( const at::Tensor& self, - int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { - int64_t start_val = start.has_value() ? start.value() : 0; - int64_t end_val = end.has_value() ? end.value() : INT64_MAX; - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::slice( - torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step)); -} - -at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) { + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self))); + return LazyNativeFunctions::view_copy(self, size); } -at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim)); +// This is needed by the torch.tensor constructor. +// LazyTensor always opts into functionalization. +// "lifting" a tensor for functionalization means wrapping it in a +// FunctionalTensorWrapper object. +at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); } - -at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - torch::lazy::squeeze_(self_tensor); - return self; +at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + !at::functionalization::impl::isFunctionalTensor(tensor)); + return at::functionalization::impl::to_functional_tensor(tensor); } -at::Tensor& LazyNativeFunctions::squeeze_(at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - torch::lazy::squeeze_(self_tensor, dim); - return self; +// All of the below ops correspond to CompositeExplicitAutograd kernels from +// core that call into view operators internally. These are all composite ops +// that LTC can technically re-use / get for free, but we need to +// "functionalize" them to remove the view ops before we can use them. +at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { + return at::functionalization::functionalize_aten_op::call(tensors); } - -at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1)); +at::Tensor LazyNativeFunctions::new_empty_strided( + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + return at::functionalization:: + functionalize_aten_op::call( + self, size, stride, dtype, layout, device, pin_memory); } -at::Tensor& LazyNativeFunctions::t_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - torch::lazy::transpose_(self_tensor, 0, 1); - return self; +at::Tensor LazyNativeFunctions::narrow_copy( + const at::Tensor& self, + int64_t dim, + int64_t start, + int64_t length) { + return at::functionalization::functionalize_aten_op::call(self, dim, start, length); } - -at::Tensor LazyNativeFunctions::transpose( +at::Tensor LazyNativeFunctions::pixel_shuffle( const at::Tensor& self, - int64_t dim0, - int64_t dim1) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), dim0, dim1)); + int64_t upscale_factor) { + return at::functionalization::functionalize_aten_op::call(self, upscale_factor); } - -at::Tensor& LazyNativeFunctions::transpose_( - at::Tensor& self, - int64_t dim0, - int64_t dim1) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - torch::lazy::transpose_(self_tensor, dim0, dim1); - return self; +at::Tensor LazyNativeFunctions::pixel_unshuffle( + const at::Tensor& self, + int64_t downscale_factor) { + return at::functionalization::functionalize_aten_op::call(self, downscale_factor); } - -at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim)); +at::Tensor LazyNativeFunctions::select_backward( + const at::Tensor& grad_output, + at::IntArrayRef input_sizes, + int64_t dim, + int64_t index) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, index); +} +at::Tensor LazyNativeFunctions::_trilinear( + const at::Tensor& i1, + const at::Tensor& i2, + const at::Tensor& i3, + at::IntArrayRef expand1, + at::IntArrayRef expand2, + at::IntArrayRef expand3, + at::IntArrayRef sumdim, + int64_t unroll_dim) { + return at::functionalization::functionalize_aten_op:: + call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); +} +::std::tuple LazyNativeFunctions::linalg_inv_ex( + const at::Tensor& self, + bool check_errors) { + return at::functionalization::functionalize_aten_op::call(self, check_errors); } - -at::Tensor& LazyNativeFunctions::unsqueeze_(at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - torch::lazy::unsqueeze_(self_tensor, dim); - return self; +at::Tensor LazyNativeFunctions::linalg_pinv( + const at::Tensor& self, + const c10::optional& atol, + const c10::optional& rtol, + bool hermitian) { + return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); } -at::Tensor LazyNativeFunctions::view( +// functionalize_aten_op can't handle out= ops directly. +// Instead, we can call the composite kernel from core, and copy and mutations +// back to the inputs. +at::Tensor& LazyNativeFunctions::logsumexp_out( const at::Tensor& self, - at::IntArrayRef size) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size))); + at::IntArrayRef dim, + bool keepdim, + at::Tensor& out) { + auto self_wrapped = at::functionalization::impl::to_functional_tensor(self); + auto out_wrapped = at::functionalization::impl::to_functional_tensor(out); + // directly call the composite kernel from core. + // Make sure to re-enable functionalization first. + auto curr_tls = c10::impl::tls_local_dispatch_key_set(); + auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet(); + tls_reenable_functionalize.set_included(curr_tls.included_); + tls_reenable_functionalize.set_excluded( + curr_tls.excluded_.remove(c10::DispatchKey::Functionalize)); + c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize); + at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped); + auto out_unwrapped = + at::functionalization::impl::from_functional_tensor(out_wrapped); + // propagate mutations back to the inputs (including resizing) + out.resize_(out_unwrapped.sizes()); + out.copy_(out_unwrapped); + return out; +} + +at::Tensor LazyNativeFunctions::diagonal_backward( + const at::Tensor& grad_output, + at::IntArrayRef input_sizes, + int64_t offset, + int64_t dim1, + int64_t dim2) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, offset, dim1, dim2); } -at::Tensor LazyNativeFunctions::_unsafe_view( - const at::Tensor& self, - at::IntArrayRef size) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size))); +at::Tensor LazyNativeFunctions::slice_backward( + const at::Tensor& grad_output, + at::IntArrayRef input_sizes, + int64_t dim, + int64_t start, + int64_t end, + int64_t step) { + return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, start, end, step); +} + +// re-use the composite kernel from core, that way we don't need to provide a +// backwards formula for native_group_norm +std::tuple LazyNativeFunctions::native_group_norm( + const at::Tensor& input, + const c10::optional& weight, + const c10::optional& bias, + int64_t N, + int64_t C, + int64_t HxW, + int64_t group, + double eps) { + return at::native::math_group_norm( + input, weight, bias, N, C, HxW, group, eps); } void InitializeAtenBindings() {} diff --git a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp index 1fabeb5c03462..2c0598ecfe1b1 100644 --- a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include namespace torch { @@ -62,7 +61,7 @@ torch::jit::Value* GenerateClone( std::vector clone_arguments; clone_arguments.emplace_back(val); TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments); - CHECK_EQ(cloned.size(), 1); + TORCH_CHECK_EQ(cloned.size(), 1); return cloned.front(); } @@ -90,7 +89,7 @@ torch::jit::Value* GenerateSlice( arguments.emplace_back(end); arguments.emplace_back(step); TSOpVector selected = LowerBuiltin(at::aten::slice, function, arguments); - CHECK_EQ(selected.size(), 1); + TORCH_CHECK_EQ(selected.size(), 1); return selected.front(); } @@ -107,41 +106,6 @@ TSOpVector TsNode::Lower( return LowerBuiltin(this, function, arguments); } -// TS specific ops -TSOpVector TSNativeBatchNormForward::Lower( - std::shared_ptr function, - TSLoweringContext* loctx) const { - std::vector arguments; - for (size_t i = 0; i < 5; ++i) { - arguments.emplace_back(loctx->GetOutputOp(operand(i))); - } - arguments.emplace_back(training_); - arguments.emplace_back(momentum_); - arguments.emplace_back(eps_); - return LowerBuiltin(this, function, arguments); -} - -TSOpVector TSNativeBatchNormBackward::Lower( - std::shared_ptr function, - TSLoweringContext* loctx) const { - std::vector arguments; - for (size_t i = 0; i < 3; ++i) { - arguments.emplace_back(loctx->GetOutputOp(operand(i))); - } - c10::optional null_arg; - if (operands().size() == 5) { - arguments.emplace_back(null_arg); - arguments.emplace_back(null_arg); - } - for (size_t i = 3; i < operands().size(); ++i) { - arguments.emplace_back(loctx->GetOutputOp(operand(i))); - } - arguments.emplace_back(training_); - arguments.emplace_back(eps_); - arguments.emplace_back(output_mask_); - return LowerBuiltin(this, function, arguments); -} - // Non-native ops torch::lazy::TSOpVector Cast::Lower( std::shared_ptr function, @@ -177,7 +141,7 @@ torch::lazy::TSOpVector Expand::Lower( // of rank 0. This leads to false positives when checking for internal // memory overlap, because at::has_internal_overlap returns // MemOverlap::YES when a stride is set to 0. - CHECK_EQ(expand_out.size(), 1); + TORCH_CHECK_EQ(expand_out.size(), 1); return {GenerateClone(expand_out.front(), function)}; } return expand_out; @@ -204,7 +168,7 @@ torch::lazy::TSOpVector AsStrided::Lower( arguments.emplace_back(stride); arguments.emplace_back(storage_offset); TSOpVector as_strided_out = LowerBuiltin(this, function, arguments); - CHECK_EQ(as_strided_out.size(), 1); + TORCH_CHECK_EQ(as_strided_out.size(), 1); return {GenerateClone(as_strided_out.front(), function)}; } @@ -224,7 +188,7 @@ torch::lazy::TSOpVector AsStridedViewUpdate::Lower( dest_arguments.emplace_back(storage_offset); TSOpVector as_strided_out = LowerBuiltin(at::aten::as_strided, function, dest_arguments); - CHECK_EQ(as_strided_out.size(), 1); + TORCH_CHECK_EQ(as_strided_out.size(), 1); torch::jit::Value* as_strided = as_strided_out.front(); GenerateCopy(as_strided, loctx->GetOutputOp(input_op), function); return {destination}; @@ -272,8 +236,8 @@ torch::lazy::TSOpVector Narrow::Lower( const torch::lazy::Output& input = operand(0); torch::jit::Value* base = loctx->GetOutputOp(input); const torch::lazy::Shape& input_shape = input.shape(); - CHECK_EQ(sizes.size(), base_indices.size()); - CHECK_EQ(input_shape.dim(), base_indices.size()); + TORCH_CHECK_EQ(sizes.size(), base_indices.size()); + TORCH_CHECK_EQ(input_shape.dim(), base_indices.size()); for (size_t dim = 0; dim < base_indices.size(); ++dim) { int64_t start = base_indices[dim]; base = GenerateSlice( @@ -294,7 +258,7 @@ torch::lazy::TSOpVector NarrowViewUpdate::Lower( GenerateClone(loctx->GetOutputOp(operand(0)), function); const torch::lazy::Output& source_argument = operand(1); const torch::lazy::Shape& source_shape = source_argument.shape(); - CHECK_EQ(source_shape.dim(), base_indices.size()); + TORCH_CHECK_EQ(source_shape.dim(), base_indices.size()); torch::jit::Value* base = dest; for (size_t dim = 0; dim < base_indices.size(); ++dim) { int64_t start = base_indices[dim]; diff --git a/torch/csrc/lazy/tutorial.md b/torch/csrc/lazy/tutorial.md index 6704a6ccc839c..6d4e75affc38a 100644 --- a/torch/csrc/lazy/tutorial.md +++ b/torch/csrc/lazy/tutorial.md @@ -195,7 +195,7 @@ if __name__ == '__main__': 'pin_memory': True, 'shuffle': True, 'batch_size': bsz} - train_kwargs.update(cuda_kwargs) + train_kwargs.update(cuda_kwargs) transform=transforms.Compose([ transforms.ToTensor(), diff --git a/torch/csrc/multiprocessing/init.cpp b/torch/csrc/multiprocessing/init.cpp index 098cf7f97d536..b503b9f42c54d 100644 --- a/torch/csrc/multiprocessing/init.cpp +++ b/torch/csrc/multiprocessing/init.cpp @@ -1,3 +1,4 @@ +#include #include #include #include diff --git a/torch/csrc/profiler/api.cpp b/torch/csrc/profiler/api.cpp index e89955a039f8f..e1e45bbfd6d49 100644 --- a/torch/csrc/profiler/api.cpp +++ b/torch/csrc/profiler/api.cpp @@ -61,26 +61,29 @@ torch::profiler::impl::ProfilerConfig getProfilerConfig() { return state_ptr->config(); } -CUDAStubs::~CUDAStubs() = default; +ProfilerStubs::~ProfilerStubs() = default; namespace { -struct DefaultCUDAStubs : public CUDAStubs { - void record(int* /*device*/, CUDAEventStub* /*event*/, int64_t* /*cpu_ns*/) - const override { +struct DefaultCUDAStubs : public ProfilerStubs { + void record( + int* /*device*/, + ProfilerEventStub* /*event*/, + int64_t* /*cpu_ns*/) const override { fail(); } - float elapsed(const CUDAEventStub* /*event*/, const CUDAEventStub* /*event2*/) - const override { + float elapsed( + const ProfilerEventStub* /*event*/, + const ProfilerEventStub* /*event2*/) const override { fail(); return 0.f; } - void nvtxMarkA(const char* /*name*/) const override { + void mark(const char* /*name*/) const override { fail(); } - void nvtxRangePushA(const char* /*name*/) const override { + void rangePush(const char* /*name*/) const override { fail(); } - void nvtxRangePop() const override { + void rangePop() const override { fail(); } bool enabled() const override { @@ -100,25 +103,82 @@ struct DefaultCUDAStubs : public CUDAStubs { } }; -const DefaultCUDAStubs default_stubs; -constexpr const DefaultCUDAStubs* default_stubs_addr = &default_stubs; +const DefaultCUDAStubs default_cuda_stubs; +constexpr const DefaultCUDAStubs* default_cuda_stubs_addr = &default_cuda_stubs; // Constant initialization, so it is guaranteed to be initialized before // static initialization calls which may invoke registerCUDAMethods -inline const CUDAStubs*& cuda_stubs() { - static const CUDAStubs* stubs_ = - static_cast(default_stubs_addr); +inline const ProfilerStubs*& cuda_stubs() { + static const ProfilerStubs* stubs_ = + static_cast(default_cuda_stubs_addr); + return stubs_; +} + +struct DefaultITTStubs : public ProfilerStubs { + void record( + int* /*device*/, + ProfilerEventStub* /*event*/, + int64_t* /*cpu_ns*/) const override { + fail(); + } + float elapsed( + const ProfilerEventStub* /*event*/, + const ProfilerEventStub* /*event2*/) const override { + fail(); + return 0.f; + } + void mark(const char* /*name*/) const override { + fail(); + } + void rangePush(const char* /*name*/) const override { + fail(); + } + void rangePop() const override { + fail(); + } + bool enabled() const override { + return false; + } + void onEachDevice(std::function /*op*/) const override { + fail(); + } + void synchronize() const override { + fail(); + } + ~DefaultITTStubs() override = default; + + private: + void fail() const { + AT_ERROR("ITT used in profiler but not enabled."); + } +}; + +const DefaultITTStubs default_itt_stubs; +constexpr const DefaultITTStubs* default_itt_stubs_addr = &default_itt_stubs; +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerITTMethods +inline const ProfilerStubs*& itt_stubs() { + static const ProfilerStubs* stubs_ = + static_cast(default_itt_stubs_addr); return stubs_; } } // namespace -const CUDAStubs* cudaStubs() { +const ProfilerStubs* cudaStubs() { return cuda_stubs(); } -void registerCUDAMethods(CUDAStubs* stubs) { +void registerCUDAMethods(ProfilerStubs* stubs) { cuda_stubs() = stubs; } +const ProfilerStubs* ittStubs() { + return itt_stubs(); +} + +void registerITTMethods(ProfilerStubs* stubs) { + itt_stubs() = stubs; +} + } // namespace impl } // namespace profiler } // namespace torch diff --git a/torch/csrc/profiler/api.h b/torch/csrc/profiler/api.h index 53610e30d99ad..b75875d2f4ea8 100644 --- a/torch/csrc/profiler/api.h +++ b/torch/csrc/profiler/api.h @@ -23,13 +23,20 @@ enum class C10_API_ENUM ProfilerState { CPU, // CPU-only profiling CUDA, // CPU + CUDA events NVTX, // only emit NVTX markers + ITT, // only emit ITT markers KINETO, // use libkineto KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available KINETO_ONDEMAND, // run the profiler in on-demand mode NUM_PROFILER_STATES, // must be the last one }; -enum class C10_API_ENUM ActiveProfilerType { NONE = 0, LEGACY, KINETO, NVTX }; +enum class C10_API_ENUM ActiveProfilerType { + NONE = 0, + LEGACY, + KINETO, + NVTX, + ITT +}; struct TORCH_API ExperimentalConfig { explicit ExperimentalConfig( @@ -130,28 +137,31 @@ TORCH_API ActiveProfilerType profilerType(); TORCH_API ProfilerConfig getProfilerConfig(); // ---------------------------------------------------------------------------- -// -- CUDA -------------------------------------------------------------------- +// -- Annotation -------------------------------------------------------------- // ---------------------------------------------------------------------------- -using CUDAEventStub = std::shared_ptr; +using ProfilerEventStub = std::shared_ptr; -struct TORCH_API CUDAStubs { - virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) - const = 0; - virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) +struct TORCH_API ProfilerStubs { + virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) const = 0; - virtual void nvtxMarkA(const char* name) const = 0; - virtual void nvtxRangePushA(const char* name) const = 0; - virtual void nvtxRangePop() const = 0; + virtual float elapsed( + const ProfilerEventStub* event, + const ProfilerEventStub* event2) const = 0; + virtual void mark(const char* name) const = 0; + virtual void rangePush(const char* name) const = 0; + virtual void rangePop() const = 0; virtual bool enabled() const { return false; } virtual void onEachDevice(std::function op) const = 0; virtual void synchronize() const = 0; - virtual ~CUDAStubs(); + virtual ~ProfilerStubs(); }; -TORCH_API void registerCUDAMethods(CUDAStubs* stubs); -TORCH_API const CUDAStubs* cudaStubs(); +TORCH_API void registerCUDAMethods(ProfilerStubs* stubs); +TORCH_API const ProfilerStubs* cudaStubs(); +TORCH_API void registerITTMethods(ProfilerStubs* stubs); +TORCH_API const ProfilerStubs* ittStubs(); } // namespace impl } // namespace profiler diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 67ceb90820129..5f546943c001d 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -1,19 +1,33 @@ #include #include +#include +#include +#include #include +#include #include +#ifdef USE_KINETO +#include +#endif + +#include #include #include +#include #include +#include #include #include +#include namespace torch { namespace profiler { namespace impl { +using trace_ptr_t = + std::unique_ptr; void InputOutputEncoder::push(c10::ArrayRef values) { for (const auto& value : values) { @@ -45,7 +59,8 @@ void InputOutputEncoder::push(const at::Tensor& t) { tensor_metadata_.emplace_back( /*ptr_=*/(void*)t.unsafeGetTensorImpl(), /*dtype_=*/t.scalar_type(), - /*dim_=*/(uint32_t)dim); + /*dim_=*/(uint32_t)dim, + /*layout_=*/t.layout()); for (const auto i : sizes) { tensor_sizes_.emplace_back(i); @@ -72,6 +87,7 @@ auto InputOutputEncoder::getNextShapesAndDtypes() { (void)_; // Suppress unused variable warning out.shapes_.back().push_back(*tensor_size_it++); } + out.tensor_metadata_.emplace_back(md); out.dtypes_.emplace_back(scalarTypeToTypeMeta(md.dtype_).name()); } break; @@ -80,15 +96,18 @@ auto InputOutputEncoder::getNextShapesAndDtypes() { // TODO: Skip TensorLists for now. } out.dtypes_.emplace_back("TensorList"); + out.tensor_metadata_.emplace_back(); break; case Tag::Scalar: out.dtypes_.emplace_back("Scalar"); + out.tensor_metadata_.emplace_back(); break; case Tag::UndefinedTensor: case Tag::Other: out.dtypes_.emplace_back(); + out.tensor_metadata_.emplace_back(); break; case Tag::TERMINATOR: @@ -163,41 +182,7 @@ PythonTracerBase& PythonTracerBase::get() { } } // namespace python_tracer -#define OUT_T(method_name) decltype(std::declval().method_name()) -#define DEFINE_VISITOR( \ - method_name, \ - torch_op_field, \ - backend_field, \ - allocation_field, \ - py_field, \ - py_c_field) \ - OUT_T(method_name) Result::method_name() const { \ - using out_t = OUT_T(method_name); \ - return c10::visit( \ - c10::overloaded( \ - [&](const ExtraFields& e) -> out_t { \ - (void)e; \ - return torch_op_field; \ - }, \ - [&](const ExtraFields& e) -> out_t { \ - (void)e; \ - return backend_field; \ - }, \ - [&](const ExtraFields& e) -> out_t { \ - (void)e; \ - return allocation_field; \ - }, \ - [&](const ExtraFields& e) -> out_t { \ - (void)e; \ - return py_field; \ - }, \ - [&](const ExtraFields& e) -> out_t { \ - (void)e; \ - return py_c_field; \ - }), \ - extra_fields_); \ - } - +namespace { std::string toString(const ExtraFields& e) { if (e.module_.has_value()) { return fmt::format( @@ -210,53 +195,94 @@ std::string toString(const ExtraFields& e) { e.callsite_.funcname_.str()); } -using torch::profiler::impl::kineto::KinetoActivityType; -namespace { -KinetoActivityType scopeToType(at::RecordScope scope) { +auto scopeToType(at::RecordScope scope) { return scope == at::RecordScope::USER_SCOPE - ? KinetoActivityType::USER_ANNOTATION - : KinetoActivityType::CPU_OP; + ? libkineto::ActivityType::USER_ANNOTATION + : libkineto::ActivityType::CPU_OP; +} + +int64_t torchOpEndNS( + const ExtraFields& e, + const bool finished, + const std::weak_ptr& parent) { + if (finished && e.end_time_ns_ == std::numeric_limits::min()) { + auto p = parent.lock(); + if (p) { + return p->endTimeNS(); + } + } + return e.end_time_ns_; +} + +auto kinetoEventCorrelationID( + const ExtraFields& e, + const std::weak_ptr& parent) { + if (e.correlation_id_) { + return e.correlation_id_; + } + auto p = parent.lock(); + return p ? p->correlationID() : 0; } } // namespace -DEFINE_VISITOR( - name, - e.name_, - e.name_, - "[memory]", - toString(e), - e.function_name_.str()); -DEFINE_VISITOR( - kinetoType, - scopeToType(e.scope_), - scopeToType(e.scope_), - KinetoActivityType::CPU_INSTANT_EVENT, - KinetoActivityType::PYTHON_FUNCTION, - KinetoActivityType::PYTHON_FUNCTION); -DEFINE_VISITOR(correlationID, e.correlation_id_, 0, 0, 0, 0); -DEFINE_VISITOR( - endTimeNS, - e.end_time_ns_, - e.end_time_us_ * 1000, - start_time_ns_, - e.end_time_ns_, - e.end_time_ns_); -DEFINE_VISITOR( - endTID, - e.end_tid_, - start_tid_, - start_tid_, - start_tid_, - start_tid_); -DEFINE_VISITOR( - deviceType, - c10::DeviceType::CPU, - c10::DeviceType::CPU, - e.device_type_, - c10::DeviceType::CPU, - c10::DeviceType::CPU); -#undef DEFINE_VISITOR -#undef OUT_T +#define ATTRIBUTE(event_type, expr) \ + [&](const ExtraFields& e) { \ + (void)e; \ + return expr; \ + } + +std::string Result::name() const { + return visit(c10::overloaded( + ATTRIBUTE(Allocation, std::string("[memory]")), + ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")), + ATTRIBUTE(PyCall, toString(e)), + ATTRIBUTE(PyCCall, std::string(e.function_name_.str())), + [](const auto& e) -> std::string { return e.name_; })); +} + +libkineto::ActivityType Result::kinetoType() const { + return visit(c10::overloaded( + ATTRIBUTE(TorchOp, scopeToType(e.scope_)), + ATTRIBUTE(Backend, scopeToType(e.scope_)), + ATTRIBUTE(Allocation, libkineto::ActivityType::CPU_INSTANT_EVENT), + ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT), + ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION), + ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION), + ATTRIBUTE(Kineto, e.activity_type_))); +} + +uint64_t Result::correlationID() const { + return visit(c10::overloaded( + ATTRIBUTE(TorchOp, e.correlation_id_), + ATTRIBUTE(Kineto, kinetoEventCorrelationID(e, parent_)), + [&](const auto&) -> uint64_t { return 0; })); +} + +int64_t Result::endTimeNS() const { + return visit(c10::overloaded( + ATTRIBUTE(TorchOp, torchOpEndNS(e, finished_, parent_)), + ATTRIBUTE(Backend, e.end_time_us_ * 1000), + ATTRIBUTE(Allocation, start_time_ns_), + ATTRIBUTE(OutOfMemory, start_time_ns_), + ATTRIBUTE(Kineto, start_time_ns_ + e.duration_us_ * 1000), + [&](const auto& e) -> int64_t { return e.end_time_ns_; })); +} + +uint64_t Result::endTID() const { + return visit(c10::overloaded( + ATTRIBUTE(TorchOp, e.end_tid_), + [&](const auto&) -> uint64_t { return start_tid_; })); +} + +c10::DeviceType Result::deviceType() const { + using torch::autograd::profiler::deviceTypeFromActivity; + return visit(c10::overloaded( + ATTRIBUTE(Allocation, e.device_type_), + ATTRIBUTE(OutOfMemory, e.device_type_), + ATTRIBUTE(Kineto, deviceTypeFromActivity(e.activity_type_)), + [&](const auto&) { return c10::DeviceType::CPU; })); +} +#undef ATTRIBUTE template ThreadLocalSubqueue::EventBlock::EventBlock() { @@ -331,6 +357,7 @@ std::unique_ptr ThreadLocalSubqueue::begin_op( } event->start_time_ = torch::profiler::impl::getApproximateTime(); + event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS(); return out; } @@ -391,6 +418,312 @@ auto steal_or_default(T& it) { } } +void mark_finished(std::shared_ptr& r) { + TORCH_INTERNAL_ASSERT(!r->finished_, r->name()); + r->finished_ = true; + TORCH_INTERNAL_ASSERT(r->endTimeNS() >= r->start_time_ns_, r->name()); +} + +static constexpr const char* indexKey = "Profiler Event Index"; + +void passEventsToKineto( + const std::vector>& results, + uint64_t start_time_us, + uint64_t end_time_us) { + using namespace torch::profiler::impl::kineto; + TraceWrapper cpu_trace(start_time_us, "PyTorch Profiler"); + + // Generate Kineto events for each event recorded by the PyTorch profiler. + for (const auto i : c10::irange(results.size())) { + const auto& e = results[i]; + const auto* activity = cpu_trace.addCPUActivity( + e->name(), + e->kinetoType(), + e->kineto_info_, + e->correlationID(), + e->start_time_ns_ / 1000, + e->endTimeNS() / 1000); + + TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable); + if (activity) { + addMetadata(activity, indexKey, std::to_string(i)); + } + } + + // Kineto adds the events that it collected. + cpu_trace.transferCpuTrace(end_time_us); +} + +#ifdef USE_KINETO +// There are two mechanisms that we use to connect Profiler and Kineto events. +// The first is the correlation ID. The profiler pushes a unique integer at the +// start of an op and pops it at the end. Kineto then associates the events +// that it collects with that correlation ID and sets the linked activity of +// the events that it collected to point to the profiler op. +// +// However, this is not a sufficient description because it does not retain +// dependency information between kineto ops. Consider a call to `torch.add`. +// Three events will be collected: +// `aten::add` (TorchOp, collected by profiler) +// `cudaLaunchKernel` (CUDA runtime event, collected by Kineto) +// `at::vectorized_...` (GPU kernel, collected by Kineto) +// If we only relied on correlation IDs we would set both Kineto events as +// children of the `at::add`, rather than the correct +// `at::add -> cudaLaunchKernel -> at::vectorized_...` +// +// Kineto surfaces this information through a second concept called a "flow". +// In this example, the `cudaLaunchKernel` event is the start of a flow and the +// GPU kernel has the same flow id but is not a start event. Thus, when merging +// the Kineto events into the call tree we first add all events which are flow +// start nodes. We then merge the rest, trying to pair them with flow starts +// and falling back to correlation ID if necessary. For any nodes without +// linked events the caller is determined using the normal tree construction +// algorithm. +class TransferEvents { + using itrace_t = libkineto::ITraceActivity; + using activity_t = torch::profiler::impl::kineto::activity_t; + + public: + TransferEvents( + std::vector>& results, + trace_ptr_t& trace) + : results_{results} { + auto* trace_activities_ptr = trace->get()->activities(); + TORCH_INTERNAL_ASSERT(trace_activities_ptr != nullptr); + trace_activities_ = *trace_activities_ptr; + reassociate(); + extractEventsFromTrace(); + setParents(); + } + + private: + static long long extractIndex(const std::string& metadata_json) { + static const auto prefix = fmt::format("\"{}\": ", indexKey); + auto pos = metadata_json.find(prefix); + return (pos == std::string::npos) ? unmatchedIndex : [&]() { + auto end = metadata_json.find(",", pos); + end = (end == std::string::npos) ? metadata_json.size() : end; + return std::stoll(metadata_json.substr(pos + prefix.size(), end)); + }(); + } + + std::shared_ptr lookup(const itrace_t* key) { + if (key == nullptr) { + return nullptr; + } + + // First check the map. + auto it = kineto_events_.find(key); + if (it != kineto_events_.end()) { + return it->second; + } + + // Then fallback to the encoded metadata. + const auto index = extractIndex(key ? key->metadataJson() : ""); + if (index != unmatchedIndex) { + auto out = results_.get().at(index); + kineto_events_[key] = out; + return out; + } + + // And finally give up. + return nullptr; + } + + void reassociate() { + // Match profiler events with the corresponding kineto events. Kineto may + // have moved or copied the activities, so we have to recover the + // relationship between `libkineto::ITraceActivity` and `Result`. + for (const auto* activity : trace_activities_) { + TORCH_INTERNAL_ASSERT(activity != nullptr); + auto e = lookup(activity); + if (e != nullptr) { + TORCH_INTERNAL_ASSERT(e->kineto_activity_ == nullptr); + e->kineto_activity_ = static_cast(activity); + } + } + if (results_.get().size() != kineto_events_.size()) { + TORCH_WARN(fmt::format( + "Failed to recover relationship between all profiler and kineto events: " + "{} vs. {} reassociated.", + results_.get().size(), + kineto_events_.size())); + } + } + + std::shared_ptr resultFromActivity(const itrace_t* activity) { + TORCH_INTERNAL_ASSERT(activity != nullptr); + + // Kineto is inconsistent with types, so we have to cast to int32. + torch::profiler::impl::kineto::DeviceAndResource device_and_resource{ + static_cast(activity->deviceId()), + static_cast(activity->resourceId())}; + + auto event = Result::create( + activity->timestamp() * 1000, + noTID, // Placeholder + device_and_resource, + ExtraFields{ + activity->name(), + activity->duration(), + static_cast(activity->correlationId()), + activity->type(), + {/*id=*/static_cast(activity->flowId()), + /*type=*/static_cast(activity->flowType()), + /*start=*/activity->flowStart()}}); + + // NB: It's tempting to set `event->kineto_activity_`; however we can only + // guarantee that the events we passed to Kineto are of type + // `GenericTraceActivity`. Others may derive from ITraceActivity and thus + // are not safe to cast. + return event; + } + + std::shared_ptr toResult(const itrace_t* activity) { + auto e = lookup(activity); + + // Until we are very sure that we can reassociate kineto and profiler + // events we need to be very defensive. + const auto type = activity->type(); + if (e == nullptr && + (type == libkineto::ActivityType::CPU_OP || + type == libkineto::ActivityType::CPU_INSTANT_EVENT || + type == libkineto::ActivityType::USER_ANNOTATION || + type == libkineto::ActivityType::PYTHON_FUNCTION)) { + TORCH_WARN_ONCE( + "Detected an event which was likely passed to kineto by the PyTorch " + "profiler, but is not present in the set of known events: ", + activity->name(), + " This most likely means that Kineto has not " + "maintained address stability for this event. Please report this to " + "the PyTorch team."); + return nullptr; + } + + if (e == nullptr) { + e = resultFromActivity(activity); + results_.get().push_back(e); + kineto_events_[activity] = e; + } + return e; + } + + void extractEventsFromTrace() { + for (const auto* activity : trace_activities_) { + auto e = toResult(activity); + const auto* linked_activity = activity->linkedActivity(); + if (e && linked_activity) { + e->visit(c10::overloaded( + [&](ExtraFields& i) { + i.linked_activity_ = toResult(linked_activity); + }, + [](auto&) { TORCH_INTERNAL_ASSERT(false); })); + } + } + } + + void setKinetoTID( + std::shared_ptr& r, + std::shared_ptr parent) { + r->visit(c10::overloaded( + [&](ExtraFields& i) { + TORCH_INTERNAL_ASSERT(r->start_tid_ == noTID); + r->start_tid_ = parent ? parent->start_tid_ + : at::RecordFunction::currentThreadId(); + }, + [](auto&) {})); + + for (auto& child : r->children_) { + setKinetoTID(child, r); + } + } + + void setParents() { + // First pass: Collect start events and set parent to linked event. + ska::flat_hash_map> flow_map; + for (auto& e : results_.get()) { + TORCH_INTERNAL_ASSERT(e != nullptr); + e->visit(c10::overloaded( + [&](const ExtraFields& i) { + if (i.flow.type == libkineto::kLinkAsyncCpuGpu && i.flow.start) { + auto inserted = flow_map.insert({i.flow.id, e}); +#ifdef USE_ROCM + if (inserted.second) { + TORCH_WARN_ONCE( + "ROCTracer produced duplicate flow start: ", i.flow.id); + } +#else // USE_ROCM + TORCH_INTERNAL_ASSERT(inserted.second); +#endif // USE_ROCM + } + TORCH_INTERNAL_ASSERT(e->parent_.expired()); + e->parent_ = i.linked_activity_; + }, + [](const auto&) {})); + } + + // Second pass + for (auto& e : results_.get()) { + e->visit(c10::overloaded( + [&](const ExtraFields& i) { + // Flow takes priority over linked event. + const auto it = flow_map.find(i.flow.id); + if (it != flow_map.end() && + i.flow.type == libkineto::kLinkAsyncCpuGpu && !i.flow.start) { + e->parent_ = it->second; + } + + // If a parent was set we have to do some bookkeeping. + auto parent = e->parent_.lock(); + if (parent) { + parent->children_.push_back(e); + mark_finished(e); + } + }, + [](const auto&) {})); + } + + // Set TIDs now that we have established lineage. + for (auto& e : results_.get()) { + if (e->parent_.expired()) { + setKinetoTID(e, nullptr); + } + } + } + + static constexpr long long unmatchedIndex = -1; + static constexpr auto noTID = std::numeric_limits::max(); + std::reference_wrapper>> results_; + std::vector trace_activities_; + ska::flat_hash_map> kineto_events_; +}; +#else +class TransferEvents { + public: + template + TransferEvents(Args&&...) {} +}; +#endif + +trace_ptr_t addKinetoEvents( + std::vector>& results, + uint64_t start_time_us, + uint64_t end_time_us, + const ProfilerConfig& config) { + using namespace torch::profiler::impl::kineto; + passEventsToKineto(results, start_time_us, end_time_us); + + // In on demand mode kineto is directly controlled by other machinery. + if (config.state == ProfilerState::KINETO_ONDEMAND) { + return nullptr; + } + + auto trace = std::make_unique(stopTrace()); + TORCH_INTERNAL_ASSERT(trace || !kKinetoAvailable); + TransferEvents transfer{results, trace}; + return trace; +} + struct EvaluateFunctionVisitor { void operator()( ExtraFields& first, @@ -439,17 +772,27 @@ void build_tree(std::vector>& events) { end_events_; auto push_event = [&stacks, &end_events_](std::shared_ptr& event) { + // Kineto builds subtrees using correlation ids and flows, so some Kineto + // events are already marked finished before the main tree building + // algorithm. It's fine to ignore them; the root event of these subtrees + // not a Kineto op and will be handled normally. + if (c10::holds_alternative>( + event->extra_fields_) && + event->finished_) { + return; + } + TORCH_INTERNAL_ASSERT(event->parent_.expired()); - TORCH_INTERNAL_ASSERT(event->children_.empty()); + for (const auto& child : event->children_) { + TORCH_INTERNAL_ASSERT(child->finished_); + } TORCH_INTERNAL_ASSERT(!event->finished_); auto parent_it = stacks.find(event->start_tid_); if (parent_it == stacks.end()) { - auto fwd_tid = c10::visit( - c10::overloaded( - [](const op_fields& i) { return i.forward_tid_; }, - [](const auto&) -> uint64_t { return 0; }), - event->extra_fields_); + auto fwd_tid = event->visit(c10::overloaded( + [](const op_fields& i) { return i.forward_tid_; }, + [](const auto&) -> uint64_t { return 0; })); if (fwd_tid) { parent_it = stacks.find(fwd_tid); } @@ -468,11 +811,11 @@ void build_tree(std::vector>& events) { // encounter such a case we don't push to `end_events_`. stacks[event->start_tid_] = event; } else { - event->finished_ = true; + mark_finished(event); } }; - auto pop_event = [&stacks](const std::shared_ptr& event) { + auto pop_event = [&stacks](std::shared_ptr event) { if (event->finished_) { // This event was marked finished by a previous `pop_event` call. return; @@ -483,12 +826,12 @@ void build_tree(std::vector>& events) { while (frame.get() != event.get()) { TORCH_INTERNAL_ASSERT(frame != nullptr); - frame->finished_ = true; + mark_finished(frame); TORCH_INTERNAL_ASSERT(!frame->parent_.expired()); frame = frame->parent_.lock(); } - event->finished_ = true; + mark_finished(event); stacks.erase(start_tid); auto new_frame = event->parent_.lock(); if (new_frame != nullptr) { @@ -514,8 +857,13 @@ void build_tree(std::vector>& events) { } } // namespace -std::vector> RecordQueue::getRecords( - std::function time_converter) { +std::pair< + std::vector>, + std::unique_ptr> +RecordQueue::getRecords( + std::function time_converter, + uint64_t start_time_us, + uint64_t end_time_us) { auto converter = [&](approx_time_t t) { return t == std::numeric_limits::min() ? std::numeric_limits::min() @@ -557,7 +905,8 @@ std::vector> RecordQueue::getRecords( steal_or_default(jit_stack_it), steal_or_default(jit_module_it), steal_or_default(extra_args_it), - steal_or_default(gpu_fallback_it)))); + steal_or_default(gpu_fallback_it), + i.allow_tf32_cublas_))); } queue.op_events_.clear(); queue.inputs_outputs_.clear(); @@ -575,6 +924,15 @@ std::vector> RecordQueue::getRecords( /*extra_fields_=*/std::move(i))); } queue.allocations_.clear(); + for (auto& i : queue.ooms_) { + auto start_time = converter(i.start_time_); + out.emplace_back(Result::create( + start_time, + /*start_tid_=*/queue.tid(), + /*kineto_info_=*/queue.kineto_info(), + /*extra_fields_=*/std::move(i))); + } + queue.ooms_.clear(); for (auto& i : queue.py_calls_) { python_enters.push_back( @@ -590,8 +948,9 @@ std::vector> RecordQueue::getRecords( tracer.clear(); } + auto trace = addKinetoEvents(out, start_time_us, end_time_us, config_); build_tree(out); - return out; + return {out, std::move(trace)}; } } // namespace impl diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index c070d41c31c1f..5b05f923521f2 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -25,13 +25,17 @@ enum class EventType : uint8_t { TorchOp = 0, Backend, Allocation, + OutOfMemory, PyCall, - PyCCall + PyCCall, + Kineto }; template struct ExtraFields; +struct Result; + struct TorchOpBasicFields { int64_t sequence_number_; uint64_t forward_tid_; @@ -44,9 +48,17 @@ struct TorchOpBasicFields { uint64_t end_tid_{0}; }; +struct TensorMetadata { + void* ptr_; + c10::ScalarType dtype_; + uint32_t dim_; + c10::Layout layout_; +}; + struct Inputs { std::vector> shapes_; std::vector dtypes_; + std::vector> tensor_metadata_; }; using jit_stack_t = std::vector; @@ -54,8 +66,8 @@ using jit_modules_t = std::vector; using extra_args_t = std::unordered_map; struct FallbackPair { - CUDAEventStub cuda_event_start_ = nullptr; - CUDAEventStub cuda_event_end_ = nullptr; + ProfilerEventStub cuda_event_start_ = nullptr; + ProfilerEventStub cuda_event_end_ = nullptr; }; template <> @@ -68,7 +80,8 @@ struct ExtraFields : TorchOpBasicFields { jit_stack_t&& jit_stack, jit_modules_t&& jit_modules, extra_args_t&& extra_args, - FallbackPair&& gpu_fallback) + FallbackPair&& gpu_fallback, + bool allow_tf32_cublas) : TorchOpBasicFields(std::move(f)), correlation_id_{correlation_id}, end_time_ns_{end_time_ns}, @@ -76,7 +89,8 @@ struct ExtraFields : TorchOpBasicFields { jit_stack_{std::move(jit_stack)}, jit_modules_{std::move(jit_modules)}, extra_args_{std::move(extra_args)}, - gpu_fallback_{std::move(gpu_fallback)} {} + gpu_fallback_{std::move(gpu_fallback)}, + allow_tf32_cublas_{allow_tf32_cublas} {} uint64_t correlation_id_; time_t end_time_ns_; Inputs inputs_; @@ -84,6 +98,7 @@ struct ExtraFields : TorchOpBasicFields { jit_modules_t jit_modules_; extra_args_t extra_args_; FallbackPair gpu_fallback_; + bool allow_tf32_cublas_; }; template <> @@ -114,6 +129,21 @@ static_assert( std::is_pod>::value, "Non-POD member of ExtraFields."); +template <> +struct ExtraFields { + torch::profiler::impl::approx_time_t start_time_; + int64_t alloc_size_; + int64_t total_allocated_; + int64_t total_reserved_; + c10::DeviceType device_type_; + c10::DeviceIndex device_index_; +}; + +// For performance. +static_assert( + std::is_pod>::value, + "Non-POD member of ExtraFields."); + struct PyFrameState { int line_no_; at::StringView filename_; @@ -126,7 +156,7 @@ using strong_t = strong:: using PyModuleSelf = strong_t; using PyModuleCls = strong_t; -using PyCFunction = strong_t; +using PyMethod = strong_t; struct NNModuleInfo { PyModuleSelf self_; @@ -181,14 +211,48 @@ struct ExtraFields : public PyExtraFieldsBase { at::StringView function_name_; }; +template <> +struct ExtraFields { + // Mirrors `libkineto::GenericTraceActivity::Flow`. This information is used + // during post processing to properly embed Kineto events into the broader + // profiler tree structure. End users are not generally expected to use these + // fields directly, but they are available for debugging. + struct Flow { + uint32_t id{0}; + uint32_t type{0}; + uint32_t start{0}; + }; + + std::string name_; + int64_t duration_us_; + uint64_t correlation_id_; + libkineto::ActivityType activity_type_; + Flow flow; + std::weak_ptr linked_activity_{}; +}; + struct TORCH_API Result : public std::enable_shared_from_this { template [[nodiscard]] static std::shared_ptr create(Args... args) { return std::shared_ptr(new Result(std::forward(args)...)); } + template + decltype(auto) visit(T&& visitor) { + return c10::visit(std::forward(visitor), extra_fields_); + } + + template + decltype(auto) visit(T&& visitor) const { + return c10::visit(std::forward(visitor), extra_fields_); + } + + EventType tag() const { + return visit([](const auto& i) { return deduceTag(i); }); + } + std::string name() const; - torch::profiler::impl::kineto::KinetoActivityType kinetoType() const; + libkineto::ActivityType kinetoType() const; uint64_t correlationID() const; int64_t endTimeNS() const; uint64_t endTID() const; @@ -201,14 +265,18 @@ struct TORCH_API Result : public std::enable_shared_from_this { ExtraFields, ExtraFields, ExtraFields, + ExtraFields, ExtraFields, - ExtraFields> + ExtraFields, + ExtraFields> extra_fields_; std::weak_ptr parent_; std::vector> children_; bool finished_{false}; + const torch::profiler::impl::kineto::activity_t* kineto_activity_{nullptr}; + private: template Result( @@ -220,6 +288,11 @@ struct TORCH_API Result : public std::enable_shared_from_this { start_tid_{start_tid}, kineto_info_{kineto_info}, extra_fields_{std::move(extra_fields)} {} + + template + static EventType deduceTag(const ExtraFields&) { + return E; + } }; struct KinetoObserverContext : public at::ObserverContext { @@ -229,6 +302,8 @@ struct KinetoObserverContext : public at::ObserverContext { // Set in the exit callback. approx_time_t end_time_{std::numeric_limits::min()}; + + bool allow_tf32_cublas_; }; explicit KinetoObserverContext(Event* event) : event_{event} {} @@ -262,12 +337,6 @@ class InputOutputEncoder final { TERMINATOR }; - struct TensorMetadata { - void* ptr_; - c10::ScalarType dtype_; - uint32_t dim_; - }; - void push(const at::Tensor& t); AppendOnlyList tags_; @@ -336,6 +405,11 @@ class TORCH_API ThreadLocalSubqueue { allocations_.emplace_back(std::forward(args)...); } + template + void emplace_ooms_event(Args&&... args) { + ooms_.emplace_back(std::forward(args)...); + } + template void emplace_py_call(Args&&... args) { py_calls_.emplace_back(std::forward(args)...); @@ -407,6 +481,9 @@ class TORCH_API ThreadLocalSubqueue { // reportMemoryUsage AppendOnlyList, BlockSize> allocations_; + + // reportOOMs + AppendOnlyList, BlockSize> ooms_; }; class TORCH_API RecordQueue { @@ -418,8 +495,13 @@ class TORCH_API RecordQueue { void stop(); // NB: This is a destructive operation. - std::vector> getRecords( - std::function time_converter); + std::pair< + std::vector>, + std::unique_ptr> + getRecords( + std::function time_converter, + uint64_t start_time_us, + uint64_t end_time_us); private: uint32_t id_; diff --git a/torch/csrc/profiler/cuda.cpp b/torch/csrc/profiler/cuda.cpp index 9116516f4ae0b..c64eb5e9caa2a 100644 --- a/torch/csrc/profiler/cuda.cpp +++ b/torch/csrc/profiler/cuda.cpp @@ -33,8 +33,8 @@ static inline void cudaCheck(cudaError_t result, const char* file, int line) { } #define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__); -struct CUDAMethods : public CUDAStubs { - void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) +struct CUDAMethods : public ProfilerStubs { + void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) const override { if (device) { TORCH_CUDA_CHECK(cudaGetDevice(device)); @@ -52,7 +52,7 @@ struct CUDAMethods : public CUDAStubs { TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); } - float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) + float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2) const override { TORCH_CUDA_CHECK(cudaEventSynchronize(event->get())); TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get())); @@ -63,17 +63,17 @@ struct CUDAMethods : public CUDAStubs { return ms * 1000.0; } - void nvtxMarkA(const char* name) const override { + void mark(const char* name) const override { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ::nvtxMark(name); } - void nvtxRangePushA(const char* name) const override { + void rangePush(const char* name) const override { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ::nvtxRangePushA(name); } - void nvtxRangePop() const override { + void rangePop() const override { ::nvtxRangePop(); } diff --git a/torch/csrc/profiler/itt.cpp b/torch/csrc/profiler/itt.cpp new file mode 100644 index 0000000000000..3269a5784ae79 --- /dev/null +++ b/torch/csrc/profiler/itt.cpp @@ -0,0 +1,55 @@ +#include +#include +#include + +#include + +namespace torch { +namespace profiler { +namespace impl { +namespace { + +struct ITTMethods : public ProfilerStubs { + void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) + const override {} + + float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2) + const override { + return 0; + } + + void mark(const char* name) const override { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + torch::profiler::itt_mark(name); + } + + void rangePush(const char* name) const override { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + torch::profiler::itt_range_push(name); + } + + void rangePop() const override { + torch::profiler::itt_range_pop(); + } + + void onEachDevice(std::function op) const override {} + + void synchronize() const override {} + + bool enabled() const override { + return true; + } +}; + +struct RegisterITTMethods { + RegisterITTMethods() { + static ITTMethods methods; + registerITTMethods(&methods); + } +}; +RegisterITTMethods reg; + +} // namespace +} // namespace impl +} // namespace profiler +} // namespace torch diff --git a/torch/csrc/profiler/itt_observer.cpp b/torch/csrc/profiler/itt_observer.cpp new file mode 100644 index 0000000000000..3c044dcf1073c --- /dev/null +++ b/torch/csrc/profiler/itt_observer.cpp @@ -0,0 +1,72 @@ +#include + +#include + +namespace torch { +namespace profiler { +namespace impl { + +struct ITTThreadLocalState : ProfilerThreadLocalStateBase { + explicit ITTThreadLocalState(const ProfilerConfig& config) + : ProfilerThreadLocalStateBase(config) { + // Only `report_input_shapes` makes sense in this context. + TORCH_CHECK(!config.profile_memory); + TORCH_CHECK(!config.with_stack); + TORCH_CHECK(!config.with_flops); + TORCH_CHECK(!config.with_modules); + } + ~ITTThreadLocalState() override = default; + + ActiveProfilerType profilerType() override { + return ActiveProfilerType::ITT; + } + + void reportMemoryUsage(void*, int64_t, int64_t, int64_t, c10::Device) + override {} + + static ITTThreadLocalState* getTLS() { + auto tls = ProfilerThreadLocalStateBase::getTLS(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT); + return static_cast(tls); + } +}; + +template +std::unique_ptr enterITT(const at::RecordFunction& fn) { + if (ITTThreadLocalState::getTLS() != nullptr) { + torch::profiler::impl::ittStubs()->rangePush(fn.name()); + } + return nullptr; +} + +void pushITTCallbacks( + const ProfilerConfig& config, + const std::unordered_set& scopes) { + TORCH_CHECK( + torch::profiler::impl::ittStubs()->enabled(), + "Can't use ITT profiler - PyTorch was compiled without ITT"); + + c10::ThreadLocalDebugInfo::_push( + c10::DebugInfoKind::PROFILER_STATE, + std::make_shared(config)); + + auto state_ptr = ITTThreadLocalState::getTLS(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + + auto handle = at::addThreadLocalCallback( + at::RecordFunctionCallback( + state_ptr->config().report_input_shapes + ? &enterITT + : &enterITT, + [](const at::RecordFunction&, at::ObserverContext*) { + torch::profiler::impl::ittStubs()->rangePop(); + }) + .needsInputs(config.report_input_shapes) + .scopes(scopes)); + state_ptr->setCallbackHandle(handle); +} + +} // namespace impl +} // namespace profiler +} // namespace torch diff --git a/torch/csrc/profiler/itt_observer.h b/torch/csrc/profiler/itt_observer.h new file mode 100644 index 0000000000000..72d95880d9ea2 --- /dev/null +++ b/torch/csrc/profiler/itt_observer.h @@ -0,0 +1,13 @@ +#include + +namespace torch { +namespace profiler { +namespace impl { + +void pushITTCallbacks( + const ProfilerConfig& config, + const std::unordered_set& scopes); + +} // namespace impl +} // namespace profiler +} // namespace torch diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index d0eeef965ebda..5fd02d2e9374d 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -6,6 +6,8 @@ #include #endif +#include + namespace torch { namespace profiler { namespace impl { @@ -48,6 +50,17 @@ const DeviceAndResource kineto_ids() { #endif // USE_KINETO } +void addMetadata( + const activity_t* activity, + const std::string& key, + const std::string& value) { +#ifdef USE_KINETO + // ActivityTraceInterface returns const pointers, so we have to cast away the + // constness to add metadata. + const_cast(activity)->addMetadata(key, value); +#endif // USE_KINETO +} + TraceWrapper::TraceWrapper(const int64_t start_time, const std::string& name) #ifdef USE_KINETO : cpu_trace_(std::make_unique()) { @@ -60,38 +73,17 @@ TraceWrapper::TraceWrapper(const int64_t start_time, const std::string& name) } #endif // USE_KINETO -#ifdef USE_KINETO -namespace { -libkineto::ActivityType toActivityType(const KinetoActivityType type) { - switch (type) { - case KinetoActivityType::CPU_OP: - return libkineto::ActivityType::CPU_OP; - case KinetoActivityType::CPU_INSTANT_EVENT: - return libkineto::ActivityType::CPU_INSTANT_EVENT; - case KinetoActivityType::PYTHON_FUNCTION: - return libkineto::ActivityType::PYTHON_FUNCTION; - default: - TORCH_INTERNAL_ASSERT( - type == KinetoActivityType::USER_ANNOTATION, - "Invalid KinetoActivityType: ", - (int)type); - return libkineto::ActivityType::USER_ANNOTATION; - } -} -} // namespace -#endif // USE_KINETO +TraceWrapper::~TraceWrapper() = default; -void TraceWrapper::addCPUActivity( +activity_t* TraceWrapper::addCPUActivity( const std::string& name, - const KinetoActivityType kineto_type, + const libkineto::ActivityType type, const DeviceAndResource device_and_resource, const uint64_t correlation_id, const int64_t start_time, - const int64_t end_time, - const annotation_t& annotations) { + const int64_t end_time) { #ifdef USE_KINETO TORCH_CHECK((bool)(*this), "Cannot add event to non-existent trace."); - auto type = toActivityType(kineto_type); cpu_trace_->emplace_activity(cpu_trace_->span, type, name); auto& act = libkineto::CpuTraceBuffer::toRef(cpu_trace_->activities.back()); act.device = device_and_resource.device; @@ -101,9 +93,9 @@ void TraceWrapper::addCPUActivity( if (type != libkineto::ActivityType::CPU_INSTANT_EVENT) { act.endTime = end_time; } - for (const auto& i : annotations) { - act.addMetadata(i.first, i.second); - } + return cpu_trace_->activities.back().get(); +#else + return nullptr; #endif // USE_KINETO } @@ -123,7 +115,7 @@ TraceWrapper::operator bool() const { } ActivityTraceWrapper::ActivityTraceWrapper( - std::unique_ptr trace) + std::unique_ptr&& trace) : trace_(std::move(trace)), saved_{false} {} ActivityTraceWrapper::operator bool() const { @@ -290,7 +282,6 @@ void recordThreadInfo() { namespace autograd { namespace profiler { -#ifdef USE_KINETO c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { // fallthrough switch (activity_type) { @@ -309,13 +300,14 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { case libkineto::ActivityType::PYTHON_FUNCTION: return c10::DeviceType::CPU; default: { - LOG(WARNING) << "Unknown activity type (" << (uint8_t)activity_type - << "), assuming CPU device"; + TORCH_WARN( + "Unknown activity type (", + (uint8_t)activity_type, + "), assuming CPU device"); return c10::DeviceType::CPU; } } } -#endif // USE_KINETO void addMetadataJson(const std::string& key, const std::string& value) { #ifdef USE_KINETO diff --git a/torch/csrc/profiler/kineto_shim.h b/torch/csrc/profiler/kineto_shim.h index 885e2a6755a62..631569b02f6b3 100644 --- a/torch/csrc/profiler/kineto_shim.h +++ b/torch/csrc/profiler/kineto_shim.h @@ -12,13 +12,15 @@ #undef USE_KINETO #endif +#include + #include #include #ifdef USE_KINETO // Forward declarations so we don't have to include `libkineto.h` in a header. namespace libkineto { -enum class ActivityType; +class GenericTraceActivity; struct CpuTraceBuffer; class ActivityTraceInterface; } // namespace libkineto @@ -40,49 +42,44 @@ namespace kineto { // -- Interface (Does not require Kineto) ------------------------------------- // ---------------------------------------------------------------------------- struct DeviceAndResource { -#ifdef USE_KINETO int32_t device; int32_t resource; -#endif // USE_KINETO }; const DeviceAndResource kineto_ids(); #ifdef USE_KINETO using trace_t = libkineto::CpuTraceBuffer; using interface_trace_t = libkineto::ActivityTraceInterface; +using activity_t = libkineto::GenericTraceActivity; #else struct DummyTraceBuffer {}; struct DummyTraceInterface {}; using trace_t = DummyTraceBuffer; using interface_trace_t = DummyTraceBuffer; +struct activity_t; #endif // USE_KINETO -// Subset of `libkineto::ActivityType` for `addCPUActivity`. -enum class KinetoActivityType : uint8_t { - CPU_OP = 0, - CPU_INSTANT_EVENT, - USER_ANNOTATION, - PYTHON_FUNCTION -}; - -using annotation_t = std::vector>; +void addMetadata( + const activity_t* activity, + const std::string& key, + const std::string& value); // Wraps: libkineto::CpuTraceBuffer struct TraceWrapper { TraceWrapper(const int64_t start_time, const std::string& name); TraceWrapper(TraceWrapper&&) = default; TraceWrapper(const TraceWrapper&) = delete; + ~TraceWrapper(); // The caller is expected to hold a mutex when calling `addCPUActivity`. - void addCPUActivity( + activity_t* addCPUActivity( const std::string& name, - const KinetoActivityType kineto_type, + const libkineto::ActivityType type, const DeviceAndResource device_and_resource, const uint64_t correlation_id, const int64_t start_time, - const int64_t end_time, - const annotation_t& annotations); + const int64_t end_time); void transferCpuTrace(int64_t end_time); @@ -98,7 +95,7 @@ struct TraceWrapper { // Wraps libkineto::ActivityTraceInterface struct ActivityTraceWrapper { - explicit ActivityTraceWrapper(std::unique_ptr trace); + explicit ActivityTraceWrapper(std::unique_ptr&& trace); ActivityTraceWrapper() = default; ActivityTraceWrapper(ActivityTraceWrapper&&) = default; ActivityTraceWrapper(const ActivityTraceWrapper&) = delete; @@ -133,9 +130,7 @@ void recordThreadInfo(); namespace autograd { namespace profiler { -#ifdef USE_KINETO c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type); -#endif // USE_KINETO TORCH_API void addMetadataJson( const std::string& key, diff --git a/torch/csrc/profiler/nvtx_observer.cpp b/torch/csrc/profiler/nvtx_observer.cpp index ddc677d9786b5..fa091c4ef8b74 100644 --- a/torch/csrc/profiler/nvtx_observer.cpp +++ b/torch/csrc/profiler/nvtx_observer.cpp @@ -129,7 +129,7 @@ template std::unique_ptr enterNVTX(const at::RecordFunction& fn) { if (NVTXThreadLocalState::getTLS() != nullptr) { auto input_op_ids = getInputTensorOpIds(fn); - torch::profiler::impl::cudaStubs()->nvtxRangePushA( + torch::profiler::impl::cudaStubs()->rangePush( torch::profiler::impl::getNvtxStr( fn.name(), fn.seqNr(), @@ -164,7 +164,7 @@ void pushNVTXCallbacks( ? &enterNVTX : &enterNVTX, [](const at::RecordFunction& fn, at::ObserverContext* ctx) { - torch::profiler::impl::cudaStubs()->nvtxRangePop(); + torch::profiler::impl::cudaStubs()->rangePop(); updateOutputTensorTracker(fn); }) .needsInputs(config.report_input_shapes) diff --git a/torch/csrc/python_headers.h b/torch/csrc/python_headers.h index 1e5b16eebbff1..0130e41ccb46e 100644 --- a/torch/csrc/python_headers.h +++ b/torch/csrc/python_headers.h @@ -1,5 +1,7 @@ #pragma once +// workaround for https://github.com/python/cpython/pull/23326 #include +#include // workaround for Python 2 issue: https://bugs.python.org/issue17120 // NOTE: It looks like this affects Python 3 as well. #pragma push_macro("_XOPEN_SOURCE") @@ -8,11 +10,16 @@ #undef _POSIX_C_SOURCE #include +#include #include #pragma pop_macro("_XOPEN_SOURCE") #pragma pop_macro("_POSIX_C_SOURCE") +#ifdef copysign +#undef copysign +#endif + #if PY_MAJOR_VERSION < 3 #error "Python 2 has reached end-of-life and is no longer supported by PyTorch." #endif diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 13c3d63675266..23187decb10d7 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 990e094f82c3f..b2eac4b54fa1b 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -294,3 +295,58 @@ char* tensor_repr(at::Tensor tensor) { } // namespace gdb } // namespace torch + +namespace pybind11 { +namespace detail { + +bool type_caster::load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) { + value = THPVariable_Unpack(obj); + return true; + } + return false; +} + +handle type_caster::cast( + const at::Tensor& src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(THPVariable_Wrap(src)); +} + +bool type_caster::load(handle src, bool) { + PyObject* source = src.ptr(); + auto tuple = PyTuple_Check(source); + if (tuple || PyList_Check(source)) { + // NOLINTNEXTLINE(bugprone-branch-clone) + const auto size = + tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); + v_value.resize(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); + if (THPVariable_Check(obj)) { + v_value[idx] = THPVariable_Unpack(obj).item(); + } else if (PyLong_Check(obj)) { + // use THPUtils_unpackLong after it is safe to include + // python_numbers.h + v_value[idx] = THPUtils_unpackLong(obj); + } else { + return false; + } + } + value = v_value; + return true; + } + return false; +} +handle type_caster::cast( + at::IntArrayRef src, + return_value_policy /* policy */, + handle /* parent */) { + return handle(THPUtils_packInt64Array(src.size(), src.data())); +} + +} // namespace detail +} // namespace pybind11 diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index b45120b18a3d8..fe8c834077580 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -164,22 +164,8 @@ void THPUtils_addPyMethodDefs( int THPUtils_getCallable(PyObject* arg, PyObject** result); -#define THWTensorPtr TH_CONCAT_3(TH, Real, TensorPtr) -#define THPStoragePtr TH_CONCAT_2(THP, StoragePtr) -#define THPTensorPtr TH_CONCAT_3(THP, Real, TensorPtr) -#define THSPTensorPtr TH_CONCAT_3(THSP, Real, TensorPtr) - typedef THPPointer THPGeneratorPtr; - -template -struct THPUtils_typeTraits {}; - -// Disabling clang-format because the order of these includes matters. -// This is mega-sus. -// clang-format off -#include -#include -// clang-format on +typedef class THPPointer THPStoragePtr; std::vector THPUtils_unpackLongs(PyObject* arg); PyObject* THPUtils_dispatchStateless( diff --git a/torch/csrc/utils/cuda_lazy_init.cpp b/torch/csrc/utils/cuda_lazy_init.cpp index 37c682bdbd1f2..0b86cfe4b8004 100644 --- a/torch/csrc/utils/cuda_lazy_init.cpp +++ b/torch/csrc/utils/cuda_lazy_init.cpp @@ -1,14 +1,16 @@ #include -#include -#include - #include +#include #include + namespace torch { namespace utils { +namespace { + +bool is_initialized = false; -static bool run_yet = false; +} void cuda_lazy_init() { pybind11::gil_scoped_acquire g; @@ -16,20 +18,25 @@ void cuda_lazy_init() { // has a buggy implementation that deadlocks if an instance throws an // exception. In any case, call_once isn't necessary, because we // have taken a lock. - if (!run_yet) { - auto module = THPObjectPtr(PyImport_ImportModule("torch.cuda")); - if (!module) - throw python_error(); - auto res = - THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", "")); - if (!res) - throw python_error(); - run_yet = true; + if (is_initialized) { + return; } + + auto module = THPObjectPtr(PyImport_ImportModule("torch.cuda")); + if (!module) { + throw python_error(); + } + + auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", "")); + if (!res) { + throw python_error(); + } + + is_initialized = true; } -void set_run_yet_variable_to_false() { - run_yet = false; +void set_requires_cuda_init(bool value) { + is_initialized = !value; } } // namespace utils diff --git a/torch/csrc/utils/cuda_lazy_init.h b/torch/csrc/utils/cuda_lazy_init.h index e6efc1f648e52..90a8581e63ab3 100644 --- a/torch/csrc/utils/cuda_lazy_init.h +++ b/torch/csrc/utils/cuda_lazy_init.h @@ -21,7 +21,7 @@ namespace utils { // build, which is not good UX. // void cuda_lazy_init(); -void set_run_yet_variable_to_false(); +void set_requires_cuda_init(bool value); static void maybe_initialize_cuda(const at::TensorOptions& options) { if (options.device().is_cuda()) { diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index ba7d3eb162918..ac29a9157a9c1 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 8200eeb88e954..3cdc33e90681b 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -25,6 +25,7 @@ struct DisableTorchDispatch { c10::impl::ExcludeDispatchKeyGuard guard_; c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_; }; + } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 584860fbc229b..d1c94b6629cd5 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -3,6 +3,7 @@ #include #include +#include namespace torch { namespace throughput_benchmark { diff --git a/torch/csrc/utils/object_ptr.cpp b/torch/csrc/utils/object_ptr.cpp index 7faba811a1fa1..6498e0de0cc12 100644 --- a/torch/csrc/utils/object_ptr.cpp +++ b/torch/csrc/utils/object_ptr.cpp @@ -9,3 +9,19 @@ void THPPointer::free() { } template class THPPointer; + +template <> +void THPPointer::free() { + if (ptr) + Py_DECREF(ptr); +} + +template class THPPointer; + +template <> +void THPPointer::free() { + if (ptr) + Py_DECREF(ptr); +} + +template class THPPointer; diff --git a/torch/csrc/utils/object_ptr.h b/torch/csrc/utils/object_ptr.h index 5b4c1bc404457..359e177f4f58c 100644 --- a/torch/csrc/utils/object_ptr.h +++ b/torch/csrc/utils/object_ptr.h @@ -65,3 +65,5 @@ class THPPointer { * not use THPPointer in this situation. */ using THPObjectPtr = THPPointer; +using THPCodeObjectPtr = THPPointer; +using THPFrameObjectPtr = THPPointer; diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 1de53cf90ecb7..f45caf6fac9c3 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -11,9 +12,6 @@ #include #include #include -#include -#include -#include #include #include @@ -33,26 +31,17 @@ namespace detail { // torch.Tensor <-> at::Tensor conversions (without unwrapping) template <> -struct type_caster { +struct TORCH_PYTHON_API type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); - bool load(handle src, bool) { - PyObject* obj = src.ptr(); - if (THPVariable_Check(obj)) { - value = THPVariable_Unpack(obj); - return true; - } - return false; - } + bool load(handle src, bool); static handle cast( const at::Tensor& src, return_value_policy /* policy */, - handle /* parent */) { - return handle(THPVariable_Wrap(src)); - } + handle /* parent */); }; // torch._StorageBase <-> at::Storage @@ -103,43 +92,16 @@ struct type_caster { }; template <> -struct type_caster { +struct TORCH_PYTHON_API type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef")); - bool load(handle src, bool) { - PyObject* source = src.ptr(); - auto tuple = PyTuple_Check(source); - if (tuple || PyList_Check(source)) { - // NOLINTNEXTLINE(bugprone-branch-clone) - const auto size = - tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); - v_value.resize(size); - for (const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) - : PyList_GET_ITEM(source, idx); - if (THPVariable_Check(obj)) { - v_value[idx] = THPVariable_Unpack(obj).item(); - } else if (PyLong_Check(obj)) { - // use THPUtils_unpackLong after it is safe to include - // python_numbers.h - v_value[idx] = THPUtils_unpackLong(obj); - } else { - return false; - } - } - value = v_value; - return true; - } - return false; - } + bool load(handle src, bool); static handle cast( at::IntArrayRef src, return_value_policy /* policy */, - handle /* parent */) { - return handle(THPUtils_packInt64Array(src.size(), src.data())); - } + handle /* parent */); private: std::vector v_value; diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index fb77e9d41a55e..9095be4e61d9b 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -82,7 +82,7 @@ static const std::unordered_map> // If you modify this, you will need to adjust the blocklist in // tools/pyi/gen_pyi.py (and add hardcoded signatures for these // functions.) -static bool should_allow_numbers_as_tensors(const std::string& name) { +bool should_allow_numbers_as_tensors(const std::string& name) { static std::unordered_set allowed = { "add", "add_", "add_out", "div", "div_", "div_out", @@ -290,9 +290,12 @@ auto handle_torch_function_no_python_arg_parser( PyObject* mode_obj = nullptr; const bool is_torch_function = torch_function_name == TorchFunctionName::TorchFunction; - const auto& maybe_mode = is_torch_function - ? at::impl::PythonTorchFunctionTLS::get_mode() - : at::impl::TorchDispatchModeTLS::get_state(); + auto get_mode = [&]() { + return is_torch_function ? at::impl::PythonTorchFunctionTLS::get_mode() + : at::impl::TorchDispatchModeTLS::get_state(); + }; + + const auto& maybe_mode = get_mode(); if (maybe_mode) { mode_obj = maybe_mode->ptr(getPyInterpreter()); TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr); @@ -335,6 +338,9 @@ auto handle_torch_function_no_python_arg_parser( // NOLINTNEXTLINE(clang-diagnostic-writable-strings) py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str); + if (!torch_function) { + TORCH_INTERNAL_ASSERT(0); + } // See https://github.com/pytorch/pytorch/issues/63767 if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__") @@ -385,8 +391,9 @@ auto handle_torch_function_no_python_arg_parser( // If a user forcibly changes the mode in a non-lexical way // in the inner context, the mode could be invalid here. So just be // a bit safe, it doesn't cost us anything since this is error reporting - const auto& maybe_mode = at::impl::PythonTorchFunctionTLS::get_mode(); - TORCH_INTERNAL_ASSERT(mode_obj == maybe_mode->ptr(getPyInterpreter())); + const auto& maybe_mode = get_mode(); + TORCH_INTERNAL_ASSERT( + maybe_mode && mode_obj == maybe_mode->ptr(getPyInterpreter())); ss << " nor was it found on the currently active mode " << py::repr(mode_obj); } @@ -649,12 +656,29 @@ bool is_float_or_complex_list(PyObject* obj) { static bool is_int_list(PyObject* obj, int broadcast_size) { if (PyTuple_Check(obj) || PyList_Check(obj)) { - if (PySequence_Size(obj) == 0) { + auto len = PySequence_Size(obj); + if (len == 0) { return true; } auto item = py::reinterpret_steal(PySequence_GetItem(obj, 0)); + bool int_first = false; if (THPUtils_checkIndex(item.ptr())) { + // we still have to check that the rest of items are NOT symint nodes + int_first = true; + } + + // Make sure none of the later arguments are SymInt + // NB: do NOT check that the later arguments are ints, as this is + // BC-breaking for FX + for (int i = 1; i < len; i++) { + if (torch::is_symint_node( + py::reinterpret_steal(PySequence_GetItem(obj, i)))) { + return false; + } + } + + if (int_first) { return true; } @@ -1227,10 +1251,14 @@ bool FunctionSignature::parse( // if there is a single positional IntArrayRef argument, i.e. expand(..), // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as // expand((5,3)) + int int_list_overload = false; if (max_pos_args == 1 && (params[0].type_ == ParameterType::INT_LIST || params[0].type_ == ParameterType::SYM_INT_LIST)) { allow_varargs_intlist = true; + if (params[0].type_ == ParameterType::INT_LIST) { + int_list_overload = true; + } } if (nargs > max_pos_args && !allow_varargs_intlist) { @@ -1287,7 +1315,8 @@ bool FunctionSignature::parse( // should avoid having complex signatures that make use of it... } else if ( allow_varargs_intlist && arg_pos == 0 && !is_kwd && - is_int_or_symint(obj)) { + ((int_list_overload ? is_int_list(args, param.size) + : is_int_or_symint_list(args, param.size)))) { // take all positional arguments as this parameter // e.g. permute(1, 2, 3) -> permute((1, 2, 3)) dst[i++] = args; diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 935dbe25c5905..8523f055a0f0d 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -68,7 +68,7 @@ #include #include -#include +#include #include #include #include @@ -78,6 +78,8 @@ namespace torch { +bool should_allow_numbers_as_tensors(const std::string& name); + enum class ParameterType { TENSOR, SCALAR, @@ -473,7 +475,7 @@ inline std::vector PythonArgs::intlist(int i) { } inline bool is_symint_node(py::handle obj) { - auto static tp_symn = py::type::of(); + auto static tp_symn = py::type::of(); // TODO: switch this to `isinstance` if (obj.get_type().equal(tp_symn)) { TORCH_CHECK( @@ -485,9 +487,11 @@ inline bool is_symint_node(py::handle obj) { inline PyObject* toPyObject(c10::SymInt symint) { if (symint.is_symbolic()) { - return py::cast(symint.toSymbolicIntNode()).release().ptr(); + auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr(); + TORCH_INTERNAL_ASSERT(r); + return r; } else { - return THPUtils_packInt64(symint.data()); + return THPUtils_packInt64(symint.as_int_unchecked()); } } @@ -505,7 +509,7 @@ inline std::vector PythonArgs::symintlist(int i) { } if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) { - auto si = py::handle(args[i]).cast()->toSymInt(); + auto si = py::handle(args[i]).cast()->toSymInt(); return std::vector(size1, si); } @@ -520,8 +524,7 @@ inline std::vector PythonArgs::symintlist(int i) { tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); try { if (is_symint_node(py::handle(obj))) { - res.push_back( - py::handle(obj).cast()->toSymInt()); + res.push_back(py::handle(obj).cast()->toSymInt()); } else { // Elements of torch.Size are tensors during tracing, and we need to // record extra information before they are turned into an IntArrayRef @@ -671,10 +674,15 @@ inline c10::optional PythonArgs::scalartypeOptional(int i) { return scalartype(i); } +inline at::Layout toLayout(PyObject* obj) { + const auto layout = reinterpret_cast(obj); + return layout->layout; +} + inline at::Layout PythonArgs::layout(int i) { if (!args[i]) return signature.params[i].default_layout; - return reinterpret_cast(args[i])->layout; + return toLayout(args[i]); } inline at::Layout PythonArgs::layoutWithDefault( @@ -848,7 +856,7 @@ inline c10::SymInt PythonArgs::toSymInt(int i) { signature.params[i].name, idx, var, c10::IntType::get()); } if (torch::is_symint_node(py::handle(args[i]))) { - return py::handle(args[i]).cast()->toSymInt(); + return py::handle(args[i]).cast()->toSymInt(); } return c10::SymInt(THPUtils_unpackLong(args[i])); } diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index fb0eec3afb802..6fe4cccfd0119 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -68,3 +68,40 @@ inline int __PySlice_Unpack( #define THPUtils_parseSlice(SLICE, LEN, START, STOP, LENGTH, STEP) \ (PySlice_GetIndicesEx(SLICE, LEN, START, STOP, LENGTH, STEP) == 0) + +// Compat macros macros taken from +// https://docs.python.org/3.11/whatsnew/3.11.html + +#if PY_VERSION_HEX < 0x030900B1 +static inline PyCodeObject* PyFrame_GetCode(PyFrameObject* frame) { + Py_INCREF(frame->f_code); + return frame->f_code; +} + +static inline PyFrameObject* PyFrame_GetBack(PyFrameObject* frame) { + Py_XINCREF(frame->f_back); + return frame->f_back; +} +#endif + +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) +static inline void _Py_SET_TYPE(PyObject* ob, PyTypeObject* type) { + ob->ob_type = type; +} +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE((PyObject*)(ob), type) +#endif + +#if PY_VERSION_HEX < ((3 << 24) | (11 << 16) | (0 << 8) | (0xA << 4) | (4 << 0)) +static inline PyObject* PyFrame_GetLocals(PyFrameObject* frame) { + PyFrame_FastToLocals(frame); + auto res = frame->f_locals; + + // To match PyFrame_GetLocals, return a new reference + Py_INCREF(res); + return res; +} + +static inline int PyFrame_GetLasti(PyFrameObject* frame) { + return frame->f_lasti; +} +#endif diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index c3607f75cbf31..25e93bf02da79 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -12,6 +12,7 @@ #include #include +#include #include diff --git a/torch/csrc/utils/python_dispatch.h b/torch/csrc/utils/python_dispatch.h index e1abbca27758b..f05c36ac268de 100644 --- a/torch/csrc/utils/python_dispatch.h +++ b/torch/csrc/utils/python_dispatch.h @@ -1,4 +1,5 @@ #include +#include namespace torch { namespace impl { diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index 0117dc21b12df..ff766b51a371b 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -9,22 +10,36 @@ namespace torch { namespace utils { +template +inline T unpackIntegral(PyObject* obj, const char* type) { +#if PY_VERSION_HEX >= 0x030a00f0 + // In Python-3.10 floats can no longer be silently converted to integers + // Keep backward compatible behavior for now + if (PyFloat_Check(obj)) { + return c10::checked_convert(THPUtils_unpackDouble(obj), type); + } + return c10::checked_convert(THPUtils_unpackLong(obj), type); +#else + return static_cast(THPUtils_unpackLong(obj)); +#endif +} + inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { switch (scalarType) { case at::kByte: - *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); + *(uint8_t*)data = unpackIntegral(obj, "uint8"); break; case at::kChar: - *(int8_t*)data = (int8_t)THPUtils_unpackLong(obj); + *(int8_t*)data = unpackIntegral(obj, "int8"); break; case at::kShort: - *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); + *(int16_t*)data = unpackIntegral(obj, "int16"); break; case at::kInt: - *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); + *(int32_t*)data = unpackIntegral(obj, "int32"); break; case at::kLong: - *(int64_t*)data = THPUtils_unpackLong(obj); + *(int64_t*)data = unpackIntegral(obj, "int64"); break; case at::kHalf: *(at::Half*)data = diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp new file mode 100644 index 0000000000000..f3b3d1b4017d9 --- /dev/null +++ b/torch/csrc/utils/schema_info.cpp @@ -0,0 +1,432 @@ +#include +#include + +namespace torch { +namespace utils { +void SchemaInfo::addArgumentValue( + const std::string& name, + const at::IValue& value) { + c10::optional index = schema_.argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index != c10::nullopt, "Schema has no argument named ", name); + value_map_[name] = value; + alias_maps_current_ = false; +} + +void SchemaInfo::addArgumentValues( + const std::vector>& value_list) { + TORCH_INTERNAL_ASSERT( + value_list.size() <= schema_.arguments().size(), + "Schema does not have enough arguments for value list"); + + for (size_t i = 0; i < value_list.size(); i++) { + if (value_list[i] != c10::nullopt) { + value_map_[schema_.arguments()[i].name()] = *(value_list[i]); + alias_maps_current_ = false; + } + } +} + +void SchemaInfo::addArgumentValues( + const std::unordered_map& values) { + for (const auto& key_pair : values) { + addArgumentValue(key_pair.first, key_pair.second); + } +} + +bool SchemaInfo::hasInputArgumentNamed(const std::string& name) const { + return std::any_of( + schema_.arguments().begin(), + schema_.arguments().end(), + [&name](const c10::Argument& arg) { return arg.name() == name; }); +} + +bool SchemaInfo::is_mutable() { + for (size_t i = 0; i < schema_.arguments().size(); i++) { + if (is_mutable({c10::SchemaArgType::input, i})) { + return true; + } + } + return false; +} + +bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) { + TORCH_INTERNAL_ASSERT( + argument.index < schema_.getCorrectList(argument.type).size(), + "Invalid index for schema."); + if (!alias_maps_current_) { + generateAliasMaps(); + } + static const std::vector training_ops = + getTrainingOps(); + const auto& correct_map = (argument.type == c10::SchemaArgType::input) + ? input_alias_map_ + : output_alias_map_; + // Note that the training_op checks depend on index because + // of cases where either running_mean or running_var alias another input + // argument causing its alias status to change. + return std::any_of( + correct_map[argument.index].begin(), + correct_map[argument.index].end(), + [this](size_t aliasing_index) { + const auto is_training_op = std::find_if( + training_ops.begin(), + training_ops.end(), + [this](const auto& training_op) { + return this->schema_ == training_op.first; + }); + + bool special_case = (is_training_op != training_ops.end()) && + is_training_op->second.count( + this->schema_.arguments()[aliasing_index].name()); + if (special_case) { + bool has_training = (hasInputArgumentNamed("training") && + !value_map_.count("training")) || + (value_map_.count("training") && + value_map_.at("training").toBool()); + bool has_train = + (hasInputArgumentNamed("train") && !value_map_.count("train")) || + (value_map_.count("train") && value_map_.at("train").toBool()); + bool has_use_input_stats = + (hasInputArgumentNamed("use_input_stats") && + !value_map_.count("use_input_stats")) || + (value_map_.count("use_input_stats") && + value_map_.at("use_input_stats").toBool()); + return has_training || has_train || has_use_input_stats; + } else { + return this->schema_.is_mutable( + {c10::SchemaArgType::input, aliasing_index}); + } + }); +} + +bool SchemaInfo::is_mutable(c10::string_view name) { + c10::optional index = schema_.argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index != c10::nullopt, "Schema has no argument named ", name); + + return is_mutable({c10::SchemaArgType::input, static_cast(*index)}); +} + +bool SchemaInfo::is_nondeterministic() const { + static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema( + "aten::dropout(Tensor input, float p, bool train) -> Tensor"); + if (dropout_schema == schema_ && value_map_.count("train") && + !value_map_.at("train").toBool()) { + return false; + } + +#if defined C10_MOBILE + static const std::vector nondeterministic_ops = + getNonDeterministicOps(); + return std::any_of( + nondeterministic_ops.begin(), + nondeterministic_ops.end(), + [this](const c10 ::FunctionSchema& nondeterministic_op) { + return nondeterministic_op == this->schema_; + }); +#else + const auto& op = c10::Dispatcher::singleton().findOp( + c10::OperatorName(schema_.name(), schema_.overload_name())); + return op && op->hasTag(at::Tag::nondeterministic_seeded); +#endif +} + +bool SchemaInfo::may_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs) { + bool basic_check = schema_.may_alias(lhs, rhs); + if (basic_check) { + return true; + } + c10::optional lhsAliasTypeSet = + schema_.mapTypeToAliasTypeSet( + schema_.getCorrectList(lhs.type)[lhs.index].type()); + c10::optional rhsAliasTypeSet = + schema_.mapTypeToAliasTypeSet( + schema_.getCorrectList(rhs.type)[rhs.index].type()); + bool types_can_alias = + schema_.canAliasTypeSetsAlias(lhsAliasTypeSet, rhsAliasTypeSet); + if (!types_can_alias) { + return false; + } + + if (!alias_maps_current_) { + generateAliasMaps(); + } + bool wildcard_alias_check = + wildcardSet().count(lhs) && wildcardSet().count(rhs); + if (wildcard_alias_check) { + return true; + } + + if (lhs.type == c10::SchemaArgType::input && + rhs.type == c10::SchemaArgType::input) { + return input_alias_map_[lhs.index].count(rhs.index); + } else if ( + lhs.type == c10::SchemaArgType::output && + rhs.type == c10::SchemaArgType::output) { + for (size_t lhs_alias_input : output_alias_map_[lhs.index]) { + if (output_alias_map_[rhs.index].count(lhs_alias_input)) { + return true; + } + } + return false; + } else if (lhs.type == c10::SchemaArgType::output) { + return output_alias_map_[lhs.index].count(rhs.index); + } else { + return output_alias_map_[rhs.index].count(lhs.index); + } +} + +bool SchemaInfo::may_contain_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs, + bool bidirectional) { + bool basic_check = schema_.may_contain_alias(lhs, rhs) || may_alias(lhs, rhs); + if (basic_check) { + return true; + } + if (!alias_maps_current_) { + generateAliasMaps(); + } + if (bidirectional) { + return mayContainAliasImpl(lhs, rhs) || mayContainAliasImpl(rhs, lhs); + } else { + return mayContainAliasImpl(lhs, rhs); + } +} + +bool SchemaInfo::mayContainAliasImpl( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs) { + c10::optional lhsContainedAliasTypeSet = + schema_.getAliasTypeSetContainedTypes(schema_.mapTypeToAliasTypeSet( + schema_.getCorrectList(lhs.type)[lhs.index].type())); + c10::optional rhsAliasTypeSet = + schema_.mapTypeToAliasTypeSet( + schema_.getCorrectList(rhs.type)[rhs.index].type()); + bool types_can_alias = + schema_.canAliasTypeSetsAlias(lhsContainedAliasTypeSet, rhsAliasTypeSet); + return types_can_alias && containerSet().count(lhs) && + wildcardSet().count(rhs); +} + +void SchemaInfo::ensureConservativity( + const std::unordered_set& duplicates, + const std::vector& arguments_list, + c10::SchemaArgType type) { + for (size_t i = 0; i < arguments_list.size(); i++) { + if (arguments_list[i].alias_info()) { + for (const auto& set : arguments_list[i].alias_info()->afterSets()) { + if (duplicates.count(set)) { + wildcard_set_.insert({type, i}); + } + } + } + } +} + +std::vector SchemaInfo::getNonDeterministicOps() { + // This list of nondeterministic ops is copied from JIT ir.cpp. + static const std::vector nondeterministic_op_strings = { + "aten::dropout(Tensor input, float p, bool train) -> Tensor", + "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)", + "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor", + "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor", + "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", + "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor", + "aten::poisson(Tensor self, Generator? generator) -> Tensor", + "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", + "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", + "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", + "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; + + std::vector nondeterministic_ops; + nondeterministic_ops.reserve(nondeterministic_op_strings.size()); + for (const std::string& signature : nondeterministic_op_strings) { + nondeterministic_ops.push_back(torch::jit::parseSchema(signature)); + } + + return nondeterministic_ops; +} + +std::vector SchemaInfo::getTrainingOps() { + // This is a list of pairs of ops to sets of strings + // where the a boolean variable (either "training", + // "train" or "use_input_stats") affects the mutability + // of the unorderered set of strings. + static const std::vector>> training_op_pairs = + {{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", + {"running_mean", "running_var"}}, + {"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor", + {"running_mean", "running_var"}}, + {"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)", + {"running_mean", "running_var"}}, + {"aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)", + {"running_mean", "running_var"}}, + {"aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)", + {"running_mean", "running_var"}}, + {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + {"running_mean", "running_var"}}, + {"aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", + {"running_mean", "running_var"}}, + {"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", + {"noise"}}, + {"aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", + {"noise"}}, + {"rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", + {"noise"}}}; + + std::vector training_ops; + training_ops.reserve(training_op_pairs.size()); + for (const auto& signature : training_op_pairs) { + training_ops.emplace_back( + torch::jit::parseSchema(signature.first), signature.second); + } + + return training_ops; +} + +void SchemaInfo::initSchemaInfo() { + if (has_init_) { + return; + } + has_init_ = true; + + std::unordered_set duplicates; + auto init_schema_arguments = [this, &duplicates]( + const std::vector& + arguments_list, + c10::SchemaArgType type) { + std::unordered_set seen; + for (size_t i = 0; i < arguments_list.size(); i++) { + const c10::Argument& argument = arguments_list[i]; + if (argument.alias_info()) { + if (argument.alias_info()->isWildcardAfter()) { + wildcard_set_.insert({type, i}); + } else { + // This check is to ensure that the FunctionSchema will accurately + // be represented when calling may_alias and may_contain_alias + // on schemas with more than one argument within arguments_list that + // shares an alias set. + for (const auto& set : argument.alias_info()->afterSets()) { + if (seen.count(set)) { + TORCH_WARN( + set.toQualString(), + " appears twice in same argument list which will make aliasing checks more conservative."); + duplicates.insert(set); + } else { + seen.insert(set); + } + } + } + } + c10::optional contained_types = + schema_.getAliasTypeSetContainedTypes( + schema_.mapTypeToAliasTypeSet(argument.type())); + if (contained_types && contained_types->size() > 0) { + container_set_.insert({type, i}); + } + } + }; + + init_schema_arguments(schema_.arguments(), c10::SchemaArgType::input); + init_schema_arguments(schema_.returns(), c10::SchemaArgType::output); + ensureConservativity( + duplicates, schema_.arguments(), c10::SchemaArgType::input); + ensureConservativity( + duplicates, schema_.returns(), c10::SchemaArgType::output); +} + +const std::unordered_set& SchemaInfo::wildcardSet() { + initSchemaInfo(); + return wildcard_set_; +} + +const std::unordered_set& SchemaInfo::containerSet() { + initSchemaInfo(); + return container_set_; +} + +void SchemaInfo::generateAliasMaps() { + initSchemaInfo(); + + alias_maps_current_ = true; + input_alias_map_ = std::vector>( + schema_.arguments().size(), std::unordered_set()); + output_alias_map_ = std::vector>( + schema_.returns().size(), std::unordered_set()); + + // Fills input_alias_map_ + for (size_t i = 0; i < schema_.arguments().size(); i++) { + for (size_t j = i; j < schema_.arguments().size(); j++) { + if (i == j) { + input_alias_map_[i].insert(i); + } else if ( + value_map_.count(schema_.arguments()[i].name()) && + value_map_.count(schema_.arguments()[j].name())) { + if (value_map_[schema_.arguments()[i].name()].isAliasOf( + value_map_[schema_.arguments()[j].name()])) { + input_alias_map_[i].insert(j); + input_alias_map_[j].insert(i); + if (wildcard_set_.count({c10::SchemaArgType::input, i})) { + wildcard_set_.insert({c10::SchemaArgType::input, j}); + } else if (wildcard_set_.count({c10::SchemaArgType::input, j})) { + wildcard_set_.insert({c10::SchemaArgType::input, i}); + } + } + } + } + } + + // Fills wildcard_set with container created wildcards. + // For instance, given the schema: + // test(Tensor a, Tensor(*) b, Tensor[] c) -> Tensor + // where value(a) is contained in value(c), then a will be added to the + // wildcard set where it can now alias b. + for (size_t i = 0; i < schema_.arguments().size(); i++) { + for (size_t j = 0; j < schema_.arguments().size(); j++) { + // if they are already aliasing, there is no way one contains the other + if (!input_alias_map_[i].count(j) && + value_map_.count(schema_.arguments()[i].name()) && + value_map_.count(schema_.arguments()[j].name())) { + c10::IValue::HashAliasedIValues subValues; + value_map_[schema_.arguments()[i].name()].getSubValues(subValues); + if (subValues.count(value_map_[schema_.arguments()[j].name()])) { + wildcard_set_.insert({c10::SchemaArgType::input, j}); + } + } + } + } + + // Fills output_alias_map_ + for (size_t i = 0; i < schema_.arguments().size(); i++) { + for (size_t j = 0; j < schema_.returns().size(); j++) { + if (schema_.may_alias( + {c10::SchemaArgType::input, i}, + {c10::SchemaArgType::output, j})) { + if (wildcard_set_.count({c10::SchemaArgType::input, i})) { + wildcard_set_.insert({c10::SchemaArgType::output, j}); + } + output_alias_map_[j].insert( + input_alias_map_[i].begin(), input_alias_map_[i].end()); + } + } + } +} + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/schema_info.h b/torch/csrc/utils/schema_info.h new file mode 100644 index 0000000000000..a469483512fd8 --- /dev/null +++ b/torch/csrc/utils/schema_info.h @@ -0,0 +1,115 @@ +#pragma once + +#include +#include + +namespace torch { +namespace utils { + +using SchemaSpecialCasePair = + std::pair>; +/** + * class SchemaInfo + * + * FunctionSchema wrapper that publicizes argument value specific operator + * behavior (mutation, aliasing, special cases, etc...) + */ + +struct TORCH_API SchemaInfo { + public: + explicit SchemaInfo(const c10::FunctionSchema& schema) + : schema_(std::move(schema)), + alias_maps_current_(false), + has_init_(false) {} + explicit SchemaInfo(const char* signature) + : schema_(torch::jit::parseSchema(signature)), + alias_maps_current_(false), + has_init_(false) {} + + bool is_mutable(); + + bool is_mutable(const c10::SchemaArgument& argument); + + bool is_mutable(c10::string_view name); + + bool is_nondeterministic() const; + + // Returns whether lhs and rhs may alias directly. + // This does not account for cases where lhs or rhs are a container that + // may contain elements that alias the other argument. + // Besides the checks already included in FunctionSchema::may_alias, this + // method also accounts special aliasing cases causes by aliasing argument + // values supplied from addArgumentValue. + bool may_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs); + + // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a + // container that may contain elements that alias the other argument. Besides + // the checks already included in FunctionSchema::may_contain_alias, this + // method also accounts for special aliasing cases causes by aliasing argument + // values supplied from addArgumentValue. bidirectional = false only returns + // whether lhs may contain an alias of rhs while bidirectional = true returns + // both directions. + bool may_contain_alias( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs, + bool bidirectional = true); + + void addArgumentValue(const std::string& name, const at::IValue& value); + + void addArgumentValues( + const std::vector>& value_list); + + void addArgumentValues( + const std::unordered_map& values); + + bool hasInputArgumentNamed(const std::string& name) const; + + private: + // This function enforces more conservative results when the TORCH_WARN is + // triggered from above due to duplicates in an argument list + void ensureConservativity( + const std::unordered_set& duplicates, + const std::vector& arguments_list, + c10::SchemaArgType type); + + void initSchemaInfo(); + + void generateAliasMaps(); + + bool mayContainAliasImpl( + const c10::SchemaArgument& lhs, + const c10::SchemaArgument& rhs); + + static std::vector getNonDeterministicOps(); + + static std::vector getTrainingOps(); + + const std::unordered_set& wildcardSet(); + + const std::unordered_set& containerSet(); + + // Set of all wildcard arguments + std::unordered_set wildcard_set_; + + // Set of all container arguments + std::unordered_set container_set_; + + // Map of argument IValues + std::unordered_map value_map_; + + // Alias map of inputs with each other + std::vector> input_alias_map_; + + // Alias map of outputs to inputs + std::vector> output_alias_map_; + + const c10::FunctionSchema schema_; + + bool alias_maps_current_; + + bool has_init_; +}; +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h index d6c17a6516f40..cfca55bb86ec7 100644 --- a/torch/csrc/utils/six.h +++ b/torch/csrc/utils/six.h @@ -2,6 +2,7 @@ #include #include +#include #include namespace six { diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index 492500992e541..76d587f0166c6 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include using namespace at; diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 8376178eab72f..707ebeb19e846 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -336,9 +337,11 @@ Tensor internal_new_from_data( c10::DispatchKey::FuncTorchDynamicLayerFrontMode); c10::impl::ExcludeDispatchKeyGuard functorch_back_guard( c10::DispatchKey::FuncTorchDynamicLayerBackMode); - // We disable DeferredInit handler for similar reasons as functorch. - c10::impl::ExcludeDispatchKeyGuard deferred_init_guard( - c10::DispatchKey::DeferredInit); + // We disable Fake and DeferredInit handlers for similar reasons as + // functorch. + c10::impl::ExcludeDispatchKeyGuard fake_and_deferred_init_guard( + c10::DispatchKeySet{ + c10::DispatchKey::Fake, c10::DispatchKey::DeferredInit}); // Note [Functionalization <> torch.Tensor constructor] // Functionalization "lifts" the newly constructed tensor into a wrapper // using aten::lift(). @@ -359,7 +362,7 @@ Tensor internal_new_from_data( !is_typed_storage || storage_scalar_type == scalar_type, "Expected a Storage of type ", scalar_type, - " or an _UntypedStorage, but got ", + " or an UntypedStorage, but got ", storage_scalar_type); tensor = at::empty( sizes, @@ -412,8 +415,9 @@ Tensor internal_new_from_data( at::tracer::impl::NoTracerDispatchMode tracer_guard; // lift has no autograd implementation, so we need to make sure we don't try // to dispatch to it. + // TODO: arguably it should have an autograd implementation that noops at::AutoDispatchBelowADInplaceOrView guard; - return tensor.lift(); + return at::lift_fresh(tensor); } Tensor new_from_data_copy( @@ -639,7 +643,7 @@ Tensor legacy_tensor_generic_ctor_new( storage_scalar_type == scalar_type, "Expected a Storage of type ", scalar_type, - " or an _UntypedStorage, but got type ", + " or an UntypedStorage, but got type ", storage_scalar_type, " for argument 1 'storage'"); } diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 05db981d5dafd..229352215a283 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -7,7 +7,7 @@ #ifndef USE_NUMPY namespace torch { namespace utils { -PyObject* tensor_to_numpy(const at::Tensor& tensor) { +PyObject* tensor_to_numpy(const at::Tensor&, bool) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } at::Tensor tensor_from_numpy( @@ -29,6 +29,10 @@ bool is_numpy_scalar(PyObject* obj) { at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } + +void warn_numpy_not_writeable() { + throw std::runtime_error("PyTorch was compiled without NumPy support"); +} } // namespace utils } // namespace torch #else @@ -359,7 +363,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__")); TORCH_INTERNAL_ASSERT(cuda_dict); - if (!PyDict_Check(cuda_dict)) { + if (!PyDict_Check(cuda_dict.get())) { throw TypeError("`__cuda_array_interface__` must be a dict"); } diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 53566da839004..25a8ac28a8eca 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -8,6 +8,8 @@ #include #include +#include + #include #include #include @@ -62,8 +64,8 @@ std::string type_to_string(const at::DeprecatedTypeProperties& type) { at::TensorOptions options_from_string(const std::string& str) { static std::string cuda_prefix("torch.cuda."); - static std::once_flag cpu_once; - static std::once_flag cuda_once; + static c10::once_flag cpu_once; + static c10::once_flag cuda_once; static std::unordered_map cpu_map; static std::unordered_map cuda_map; @@ -81,14 +83,14 @@ at::TensorOptions options_from_string(const std::string& str) { if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()) .first == cuda_prefix.end()) { // torch.cuda. is prefix of str - std::call_once(cuda_once, []() { + c10::call_once(cuda_once, []() { for (auto type : autograd::VariableType::allCUDATypes()) { cuda_map.emplace(type_to_string(*type), type); } }); map = &cuda_map; } else { - std::call_once(cpu_once, []() { + c10::call_once(cpu_once, []() { for (auto type : autograd::VariableType::allCPUTypes()) { cpu_map.emplace(type_to_string(*type), type); } diff --git a/torch/csrc/utils/throughput_benchmark.cpp b/torch/csrc/utils/throughput_benchmark.cpp index 65bc190b8b3c8..dbd89b9f5368e 100644 --- a/torch/csrc/utils/throughput_benchmark.cpp +++ b/torch/csrc/utils/throughput_benchmark.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace torch { namespace throughput_benchmark { diff --git a/torch/csrc/utils/throughput_benchmark.h b/torch/csrc/utils/throughput_benchmark.h index f0582e400847d..56fc0657c1fd9 100644 --- a/torch/csrc/utils/throughput_benchmark.h +++ b/torch/csrc/utils/throughput_benchmark.h @@ -3,6 +3,7 @@ #include #include #include +#include #include diff --git a/torch/csrc/utils/verbose.cpp b/torch/csrc/utils/verbose.cpp new file mode 100644 index 0000000000000..0a532b3d0738e --- /dev/null +++ b/torch/csrc/utils/verbose.cpp @@ -0,0 +1,14 @@ +#include +#include + +namespace torch { + +void initVerboseBindings(PyObject* module) { + auto m = py::handle(module).cast(); + + auto verbose = m.def_submodule("_verbose", "MKL, MKLDNN verbose"); + verbose.def("mkl_set_verbose", torch::verbose::_mkl_set_verbose); + verbose.def("mkldnn_set_verbose", torch::verbose::_mkldnn_set_verbose); +} + +} // namespace torch diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index a78da529dd50d..26e44fa9afe73 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -83,11 +83,15 @@ def is_available() -> bool: return torch._C._cuda_getDeviceCount() > 0 def is_bf16_supported(): - r"""Returns a bool indicating if the current CUDA device supports dtype bfloat16""" + r"""Returns a bool indicating if the current CUDA/ROCm device supports dtype bfloat16""" + # Check for ROCm, if true return true, no ROCM_VERSION check required, + # since it is supported on AMD GPU archs. + if torch.version.hip: + return True + cu_vers = torch.version.cuda if cu_vers is not None: cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11 - else: cuda_maj_decide = False return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 and cuda_maj_decide diff --git a/torch/cuda/_dynamo_graphs.py b/torch/cuda/_dynamo_graphs.py new file mode 100644 index 0000000000000..1d8ae673edb37 --- /dev/null +++ b/torch/cuda/_dynamo_graphs.py @@ -0,0 +1,160 @@ +import torch +from torch.fx import GraphModule +from torch.nn import Module +from torch.fx.passes.backends.cudagraphs import partition_cudagraphs +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._pytree import tree_map +import torchdynamo # type: ignore[import] +from torchdynamo.optimizations.training import AOTAutogradStrategy # type: ignore[import] + +import operator +from collections import defaultdict +from typing import Set + +# TODO: maybe this should live in torchdynamo instead + +__all__ = ['aot_autograd_cudagraphs'] + +def cloner(t): + if isinstance(t, torch.Tensor): + return t.clone() + else: + return t + + +class CudaGraphModule(Module): + gm: GraphModule + mutated_inputs: Set[int] + + def __init__(self, gm, mutated_inputs): + super().__init__() + self.gm = gm + self.mutated_inputs = mutated_inputs + + warmed_up = False + + # these are all None or all filled + graph = None + static_inputs = None + static_outputs = None + + # NB: we override __call__ as we don't need any nn.Module machinery + # and to reduce overhead + def __call__(self, *args): + # TODO: once we've recorded here, we'd like to replace the __call__ + # implementation with compiled bytecode that copies into static, replays + # the cuda graph, then copies out. First condition is the hotpath, + # needs optimizing + if self.graph is not None: + assert len(args) == len(self.static_inputs) + for dst, src in zip(self.static_inputs, args): + dst.copy_(src) + self.graph.replay() + for i in self.mutated_inputs: + args[i].copy_(self.static_inputs[i]) + return tree_map(cloner, self.static_outputs) + + elif self.warmed_up: + # record + self.static_inputs = [x.clone() for x in args] + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph): + self.static_outputs = self.gm(*self.static_inputs) + # NB: recording doesn't actually run the operations, so + # now we immediately replay the graph to serve up the result + self.graph.replay() + for i in self.mutated_inputs: + args[i].copy_(self.static_inputs[i]) + return tree_map(cloner, self.static_outputs) + + else: + # warmup + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + r = self.gm(*args) + torch.cuda.current_stream().wait_stream(stream) + self.warmed_up = True + return r + + +# Interpreter versions of these passes can be found at +# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23 + + +def find_input_mutations(g): + FK = 'fake_result' + inputs = defaultdict(set) + input_idx = 0 + mutated_inputs = set() + for n in g.nodes: + if n.op == 'placeholder': + inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx) + input_idx += 1 + elif n.op == 'call_function': + if n.target is operator.getitem: + continue + schema = n.target._schema + for i, arg in enumerate(schema.arguments): + if i < len(n.args): + argument = n.args[i] + else: + if arg.name not in n.kwargs: + continue + argument = n.kwargs[arg.name] + mut_arg = False + if arg.alias_info: + if arg.alias_info.is_write: + mut_arg = True + if mut_arg: + # TODO: not correct for args that contain tensors in a struct + # like list + mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK].storage())] + # TODO: error on unrecognized nodes + return mutated_inputs + + +# Mutates input graph +def apply_cuda_graphs(gm): + for n in gm.graph.nodes: + if n.op == 'call_module': + assert not n.kwargs + submod = gm.get_submodule(n.target) + gm.delete_submodule(n.target) + mutated_inputs = find_input_mutations(submod.graph) + gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs)) + # NB: we didn't actually change the graph, no need for recompile + + +def cudagraphs(model, inputs): + model = partition_cudagraphs(model, inputs) + apply_cuda_graphs(model) + return model + + +def raw_aot_autograd_cudagraphs(model, inputs): + kwargs = { + # these are taken from memory_efficient_fusion() + "fw_compiler": cudagraphs, + "bw_compiler": cudagraphs, + "hasher_type": "StaticShapeHasher", + } + + def _wrapped_bw_compiler(*args, **kwargs): + # stop TorchDynamo from trying to compile our generated backwards pass + return torchdynamo.disable(bw_compiler(*args, **kwargs)) # type: ignore[operator] + + bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] + kwargs["bw_compiler"] = _wrapped_bw_compiler + + from functorch.compile import aot_module_simplified # type: ignore[import] + + return aot_module_simplified(model, **kwargs) + + +class AOTAutogradCudaGraphs(AOTAutogradStrategy): + def candidate(self): + return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs) + + +aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index fb4f4350be643..b39ae4106fb6d 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -102,8 +102,9 @@ class GradScaler(object): :meth:`update` if inf/NaN gradients occur in an iteration. growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` """ def __init__(self, init_scale=2.**16, diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index cf4d1ed7c1e3b..c69c5fdccdd58 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -281,6 +281,9 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3): # the safest approach is to capture all passes in the same order they'll run: # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + # Clear AMP autocast cache before capturing the graphs + torch.clear_autocast_cache() + # Capture forward graphs per_callable_static_outputs = [] per_callable_output_was_tensor = [] @@ -340,6 +343,9 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3): per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + # Clear AMP autocast cache after both forward and backward graphs are captured + torch.clear_autocast_cache() + def make_graphed_autograd_function(fwd_graph, bwd_graph, module_params, diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 51a171e83a3ae..207da5685acd6 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -86,7 +86,7 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: Jiterator-generated kernels accepts noncontiguous tensors, and supports boardcasting and type promotion. Args: - code_string (string): CUDA code string to be compiled by jiterator. The entry functor must return by value. + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value. kwargs (Dict, optional): Keyword arguments for generated function Example:: @@ -140,7 +140,7 @@ def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs. Args: - code_string (string): CUDA code string to be compiled by jiterator. The entry functor must return value by reference. + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference. num_outputs(int): number of outputs return by the kernel kwargs (Dict, optional): Keyword arguments for generated function diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index a4b7b1d956cbe..904ad305828b5 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -1,7 +1,7 @@ import collections import contextlib import warnings -from typing import Any, Dict, Union +from typing import Any, Dict, Union, Tuple import torch from . import is_initialized, _get_device_index, _lazy_init @@ -573,7 +573,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) -def mem_get_info(device: Union[Device, int] = None) -> int: +def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: r"""Returns the global free and total GPU memory occupied for a given device using cudaMemGetInfo. diff --git a/torch/cuda/nvtx.py b/torch/cuda/nvtx.py index c0db21354e25d..7e2e8715a605f 100644 --- a/torch/cuda/nvtx.py +++ b/torch/cuda/nvtx.py @@ -23,7 +23,7 @@ def range_push(msg): depth of the range that is started. Args: - msg (string): ASCII message to associate with range + msg (str): ASCII message to associate with range """ return _nvtx.rangePushA(msg) @@ -48,7 +48,7 @@ def range_start(msg) -> int: Returns: A range handle (uint64_t) that can be passed to range_end(). Args: - msg (string): ASCII message to associate with the range. + msg (str): ASCII message to associate with the range. """ return _nvtx.rangeStartA(msg) @@ -68,7 +68,7 @@ def mark(msg): Describe an instantaneous event that occurred at some point. Args: - msg (string): ASCII message to associate with the event. + msg (str): ASCII message to associate with the event. """ return _nvtx.markA(msg) @@ -81,7 +81,7 @@ def range(msg, *args, **kwargs): they are passed as arguments to msg.format(). Args: - msg (string): message to associate with the range + msg (str): message to associate with the range """ range_push(msg.format(*args, **kwargs)) yield diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 2b56557446a4b..002e2dd48b7bb 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -134,7 +134,7 @@ class ExternalStream(Stream): def __new__(cls, stream_ptr, device=None, **kwargs): with torch.cuda.device(device): - return super(Stream, cls).__new__(cls, stream_ptr=stream_ptr, **kwargs) + return super(ExternalStream, cls).__new__(cls, stream_ptr=stream_ptr, **kwargs) class Event(torch._C._CudaEventBase): diff --git a/torch/distributed/_shard/_utils.py b/torch/distributed/_shard/_utils.py index a81c2398f519d..7e347fefa27c0 100644 --- a/torch/distributed/_shard/_utils.py +++ b/torch/distributed/_shard/_utils.py @@ -1,21 +1,26 @@ import torch from torch.distributed._shard.metadata import ShardMetadata +from typing import Sequence -def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata): +def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]) -> torch.Tensor: """ - narrow the tensor according to the metadata + Narrow the tensor according to ``offsets`` and ``sizes``. """ narrowed_tensor = tensor - shard_offsets = metadata.shard_offsets - shard_sizes = metadata.shard_sizes - for idx, (offset, size) in enumerate(zip(shard_offsets, shard_sizes)): + for idx, (offset, size) in enumerate(zip(offsets, sizes)): if size < tensor.size(idx): # Reshape to get shard for this rank and we don't want autograd # recording here for the narrow op and 'local_shard' should be a # leaf variable in the autograd graph. narrowed_tensor = narrowed_tensor.narrow( idx, - shard_offsets[idx], - shard_sizes[idx] + offset, + size ) return narrowed_tensor + +def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: + """ + Narrow the tensor according to the metadata + """ + return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes) diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 1bc58400b1243..b212d6815afb2 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -119,15 +119,7 @@ def shard_parameter( st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) # Replace param with ShardedTensor. - - # Need to delete the attribute first since param_name might be - # torch.nn.Parameter and can't be replaced with ShardedTensor which is - # not torch.nn.Parameter. - delattr(module, param_name) - - # Now we can set the attribute appropriately. - setattr(module, param_name, st) - + module.register_parameter(param_name, nn.Parameter(st)) def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor: """ diff --git a/torch/distributed/_shard/checkpoint/filesystem.py b/torch/distributed/_shard/checkpoint/filesystem.py index 607cb297ed41e..6498f797efe7d 100644 --- a/torch/distributed/_shard/checkpoint/filesystem.py +++ b/torch/distributed/_shard/checkpoint/filesystem.py @@ -16,6 +16,7 @@ TensorWriteRequest, ) from .storage import StorageReader, StorageWriter +from torch.distributed._shard._utils import narrow_tensor_by_index class FileSystemWriter(StorageWriter): @@ -112,8 +113,7 @@ def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]: # During load time, we will load the Tensor (with it orignal view) # narrow it along all dimemsions, and copy_ it to the # target tensor, which will be the same size. - for dim, (start, length) in enumerate(zip(req.offsets, req.lengths)): - view_to_copy = torch.narrow(view_to_copy, dim, start, length) + view_to_copy = narrow_tensor_by_index(view_to_copy, req.offsets, req.lengths) assert ( view_to_copy.size() == req.tensor.size() diff --git a/torch/distributed/_shard/checkpoint/metadata.py b/torch/distributed/_shard/checkpoint/metadata.py index 30ed3b2af92f6..40586113b460e 100644 --- a/torch/distributed/_shard/checkpoint/metadata.py +++ b/torch/distributed/_shard/checkpoint/metadata.py @@ -1,6 +1,6 @@ import io -from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union, Optional, Sequence, Any import torch from torch.distributed._shard.sharded_tensor import ( @@ -10,6 +10,7 @@ ) TENSOR_TYPE = Union[torch.Tensor, ShardedTensor] +STATE_DICT_TYPE = Dict[str, Any] @dataclass class ShardStorageMetadata: @@ -77,3 +78,34 @@ class TensorReadRequest: # offset and length w.r.t. to the storage identified by ``storage_key`` offsets: Tuple[int, ...] lengths: Tuple[int, ...] + + +@dataclass(frozen=True) +class MetadataIndex: + """ + This class represents a lookup key for items in a state dict or Metadata. + """ + fqn: str + """Fully Qualified Name of the object""" + + offset: Optional[torch.Size] = None + """If the object is a tensor, offset into the tensor we're looking for""" + + index: Optional[int] = field(hash=False, compare=False, default=None) + """ + Index hint when searching for tensor chunk to speedup lookups (optional) + + A common representation of a sharded tensor is as a list of chunks so to + find the index in such a list you need to linear search it. + + When constructing an instance of MetadataIndex that points to that list, + one can provide the index as a hint and it will be probed first before + the linear search and thus making it significantly faster. + """ + + def __init__(self, fqn: str, offset: Optional[Sequence[int]] = None, index: Optional[int] = None): + # We must use object.__setattr__ due to frozen=True + object.__setattr__(self, "fqn", fqn) + object.__setattr__(self, "index", index) + if offset is not None: + object.__setattr__(self, "offset", torch.Size(offset)) diff --git a/torch/distributed/_shard/checkpoint/resharding.py b/torch/distributed/_shard/checkpoint/resharding.py index 6f6a71b92d10f..15be6dd7cdc55 100644 --- a/torch/distributed/_shard/checkpoint/resharding.py +++ b/torch/distributed/_shard/checkpoint/resharding.py @@ -27,6 +27,12 @@ TensorWriteRequest, ) +def _trim(tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.detach() + if tensor.storage().size() != tensor.numel(): + return tensor.clone() + return tensor + def _create_storage_key( storage_key_to_fqn: Dict[str, str], fqn: str @@ -114,13 +120,6 @@ def _compute_sharded_tensor_md( for shard_md in tensor.metadata().shards_metadata: shard_storage_key = shard_to_storage_key[_get_shard_key(shard_md)] - shard_size = 1 - for d in shard_md.shard_sizes: - shard_size *= d - - # not particularly great - storage_size = shard_size * _get_sharded_tensor_element_size(tensor) - one_smd = ShardStorageMetadata( shard_metadata=shard_md, storage_key=shard_storage_key, @@ -183,7 +182,7 @@ def _prepare_sharded_tensor_write( shard_storage_key = shard_to_storage_key[_get_shard_key(shard.metadata)] wr = TensorWriteRequest( - tensor=tensor, + tensor=_trim(tensor), storage_key=shard_storage_key, ) write_requests.append(wr) @@ -271,7 +270,7 @@ def _prepare_tensor_write( write_reqs = [ TensorWriteRequest( - tensor=tensor.detach(), + tensor=_trim(tensor), storage_key=storage_key, ) ] diff --git a/torch/distributed/_shard/checkpoint/state_dict_loader.py b/torch/distributed/_shard/checkpoint/state_dict_loader.py index 913008f68912c..3c1ccf9e59839 100644 --- a/torch/distributed/_shard/checkpoint/state_dict_loader.py +++ b/torch/distributed/_shard/checkpoint/state_dict_loader.py @@ -1,5 +1,5 @@ import io -from typing import Any, Dict, List, Tuple, Optional, cast +from typing import Any, Dict, List, Tuple, Optional from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor.shard import Shard @@ -32,10 +32,9 @@ StorageReader, ) -from .api import CheckpointException +from .utils import _DistWrapper def _create_shard_metadata(size: torch.Size) -> ShardMetadata: - rank = dist.get_rank() if dist.is_initialized() else 0 return ShardMetadata( shard_offsets=[0] * len(size), shard_sizes=list(size), @@ -164,9 +163,9 @@ def load_state_dict( is the user's responsibility to ensure that this is set so that each rank has an individual GPU, via ``torch.cuda.set_device()`` """ - is_coordinator = no_dist or dist.get_rank(process_group) == coordinator_rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - try: + def load_model(): metadata = storage_reader.read_metadata() bytes_read_requests, tensor_read_requests = _reshard_and_prepare_read_request( state_dict=state_dict, metadata_from_storage=metadata @@ -185,27 +184,8 @@ def load_state_dict( state_dict[req.fqn] = torch.load(req.bytes) tensor_futures.wait() - result = None - except BaseException as e: - result = e - - global_result: Optional[CheckpointException] = None - if not no_dist: - all_errors = [None] * dist.get_world_size(process_group) - - dist.all_gather_object( - object_list=all_errors, - obj=result, - group=process_group) - - node_failures = cast(Dict[int, BaseException], {i: err for i, err in enumerate(all_errors) if err is not None}) - if len(node_failures) > 0: - global_result = CheckpointException("failed to read checkpoint", node_failures) - elif result is not None: - global_result = CheckpointException("failed to read storage", {coordinator_rank : result}) - - if global_result is not None: - raise global_result + + distW.all_gather("checkpoint read", load_model) def _validate_sharded_tensor( diff --git a/torch/distributed/_shard/checkpoint/state_dict_saver.py b/torch/distributed/_shard/checkpoint/state_dict_saver.py index 27fd0f392e702..d60d143d48dc8 100644 --- a/torch/distributed/_shard/checkpoint/state_dict_saver.py +++ b/torch/distributed/_shard/checkpoint/state_dict_saver.py @@ -25,14 +25,14 @@ StorageWriter, ) -from .api import CheckpointException +from .utils import _DistWrapper + # -------------- private functions -------------- def _prepare( state_dict: Dict[str, Any], write_replicated_data: bool, - process_group: Optional[dist.ProcessGroup] = None, ) -> Tuple[Metadata, List[BytesWriteRequest], List[TensorWriteRequest]]: """ Build the serialization plan for a given state_dict @@ -141,29 +141,18 @@ def save_state_dict( is the user's responsibility to ensure that this is set so that each rank has an individual GPU, via ``torch.cuda.set_device()`` """ - is_coordinator = no_dist or dist.get_rank(process_group) == coordinator_rank - - exceptions: List[Optional[BaseException]] = [None] - if is_coordinator: - try: - storage_writer.prepare() - except BaseException as e: - exceptions = [e] - - # Writing can only start once prepare has finished - if not no_dist: - dist.broadcast_object_list(exceptions, group=process_group, src=coordinator_rank) + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - if exceptions[0] is not None: - raise CheckpointException("failed to prepare storage", {coordinator_rank : exceptions[0]}) + distW.broadcast("prepare", storage_writer.prepare) + metadata = None - rank_write_error: Optional[BaseException] - try: + def write_step(): + nonlocal metadata ( metadata, bytes_write_requests, tensor_write_requests, - ) = _prepare(state_dict, is_coordinator, process_group) + ) = _prepare(state_dict, distW.is_coordinator) combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = [] combined_writes.extend(tensor_write_requests) @@ -173,44 +162,9 @@ def save_state_dict( bytes_futures = storage_writer.write_bytes(bytes_write_requests) tensor_futures = storage_writer.write_tensors(tensor_write_requests) torch.futures.wait_all([bytes_futures, tensor_futures]) - rank_write_error = None - except BaseException as e: - rank_write_error = e - - all_errors: List[Optional[BaseException]] - # collect all write errors - if not no_dist: - all_errors = [None] * dist.get_world_size(process_group) - dist.gather_object( - obj=rank_write_error, - object_gather_list=all_errors if is_coordinator else None, - dst=coordinator_rank - ) - else: - all_errors = [rank_write_error] - - result: List[Optional[CheckpointException]] = [None] - if is_coordinator: - message: Optional[str] = None - # gather produces an array of arrays, flatten it - if any(all_errors): - message = "Failed to write data" - else: - try: - storage_writer.finish(metadata=metadata) - except BaseException as e: - all_errors[coordinator_rank] = e - message = "Failed to finish checkpoint" - - if message is not None: - node_failures = {i: err for i, err in enumerate(all_errors) if err is not None} - result[0] = CheckpointException(message, node_failures) - - if not no_dist: - dist.broadcast_object_list( - result, - group=process_group, - src=coordinator_rank) - - if result[0] is not None: - raise result[0] + + def finish_checkpoint(_): + assert metadata is not None + storage_writer.finish(metadata=metadata) + + distW.all_reduce("checkpoitn write", write_step, finish_checkpoint) diff --git a/torch/distributed/_shard/checkpoint/utils.py b/torch/distributed/_shard/checkpoint/utils.py new file mode 100644 index 0000000000000..86e4543c6fdcc --- /dev/null +++ b/torch/distributed/_shard/checkpoint/utils.py @@ -0,0 +1,267 @@ +from typing import List, Callable, Optional, Union, TypeVar, cast, Any +import torch.distributed as dist +from .api import CheckpointException + +import torch + +from torch.distributed._shard.sharded_tensor import ( + ShardedTensor, +) + +from torch.distributed._shard.sharded_tensor.shard import Shard + +from .metadata import ( + STATE_DICT_TYPE, + MetadataIndex, +) + +T = TypeVar('T') +R = TypeVar('R') + +class _DistWrapper: + """ + This is a wrapper around PG that provides a series of features around object collectives. + + It works without distributed initialized, where most collectives turns into nops. + + All variants that take functions are exception robust, meaning that if one or more + ranks raise errors, all ranks will observe those. + """ + def __init__(self, group: Optional[dist.ProcessGroup], use_dist: bool, coordinator_rank: int): + self.group = group + self.use_dist = use_dist + self.coordinator_rank = coordinator_rank + if self.use_dist: + self.rank = dist.get_rank(group) + self.is_coordinator = self.rank == coordinator_rank + else: + self.rank = 0 + self.is_coordinator = True + + def get_rank(self) -> int: + return self.rank + + def get_world_size(self) -> int: + if self.use_dist: + return dist.get_world_size(self.group) + return 1 + + def broadcast_object(self, object: Optional[T]) -> T: + """ + Same as c10d::broadcast_object_list but works without distributed enabled. + """ + object_list = [object] + if self.use_dist: + dist.broadcast_object_list( + object_list=object_list, + group=self.group, + src=self.coordinator_rank) + return cast(T, object_list[0]) + + def gather_object(self, object: T) -> Optional[List[T]]: + """ + Same as c10d::gather_object but works without distributed enabled. + """ + if self.use_dist: + gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) if self.is_coordinator else None + + dist.gather_object( + obj=object, + object_gather_list=gather_objs if self.is_coordinator else None, + dst=self.coordinator_rank, + group=self.group + ) + result = gather_objs + else: + result = [object] + return result + + def all_gather_object(self, object: T) -> List[T]: + """ + Same as c10d::all_gather_object but works without distributed enabled. + """ + if self.use_dist: + gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) + + dist.all_gather_object( + object_list=gather_objs, + obj=object, + group=self.group + ) + else: + gather_objs = [object] + return gather_objs + + def scatter_object(self, object_list: Optional[List[T]]) -> T: + """ + Same as c10d::scatter_object but works without distributed enabled. + """ + if self.use_dist: + gather_result = cast(List[T], [None]) + dist.scatter_object_list( + scatter_object_output_list=gather_result, + scatter_object_input_list=object_list if self.is_coordinator else None, + src=self.coordinator_rank, + group=self.group + ) + + local_reply = gather_result[0] + else: + assert object_list is not None + local_reply = object_list[0] + return local_reply + + def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[List[T]], List[R]] + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[BaseException, T] + try: + local_data = map_fun() + except BaseException as e: + local_data = e + + all_data = self.gather_object(local_data) + all_results: Optional[List[Union[R, CheckpointException]]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = {i: err for i, err in enumerate(all_data) if isinstance(err, BaseException)} + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, CheckpointException]]? + all_results = cast(List[Union[R, CheckpointException]], reduce_fun(cast(List[T], all_data))) + except BaseException as e: + node_failures[self.rank] = e + + if len(node_failures) > 0: + all_results = [CheckpointException(step, node_failures)] * self.get_world_size() + + result = self.scatter_object(all_results) + if isinstance(result, CheckpointException): + raise result + return result + + def all_reduce( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[List[T]], R] + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Broadcast the reduced value to all ranks. + """ + local_data: Union[T, BaseException] + try: + local_data = map_fun() + except BaseException as e: + local_data = e + + all_data = self.gather_object(local_data) + result: Optional[Union[R, CheckpointException]] = None + if self.is_coordinator: + assert all_data is not None + node_failures = {i: err for i, err in enumerate(all_data) if isinstance(err, BaseException)} + if len(node_failures) == 0: + try: + result = reduce_fun(cast(List[T], all_data)) + except BaseException as e: + node_failures[self.rank] = e + + if len(node_failures) > 0: + result = CheckpointException(step, node_failures) + + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(R, final_result) + + def all_gather( + self, + step: str, + map_fun: Callable[[], T], + ) -> List[T]: + """ + Compute a value on each rank, then all_gather them. + + This method operates in the following way: + Run ``map_cp`` on all ranks + all_gather the values to all ranks + """ + result: Union[T, BaseException] + try: + result = map_fun() + except BaseException as e: + result = e + + all_results = self.all_gather_object(result) + + node_failures = {i: err for i, err in enumerate(all_results) if isinstance(err, BaseException)} + if len(node_failures) > 0: + raise CheckpointException(step, node_failures) + return cast(List[T], all_results) + + def broadcast( + self, + step: str, + map_fun: Callable[[], T], + ) -> T: + """ + Compute a value on rank 0 and broadcast it. + + This method operates in the following way: + Run ``map_cp`` on rank 0 + broadcast the value + """ + result: Optional[Union[T, CheckpointException]] = None + if self.is_coordinator: + try: + result = map_fun() + except BaseException as e: + result = CheckpointException(step, {self.rank: e}) + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(T, final_result) + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError(f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided") + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if len(shards) > index.index and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset: + return shards[index.index] + + for shard in shards: + if torch.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + if isinstance(obj, ShardedTensor): + return _find_shard(obj, index).tensor + if index.offset is not None: + raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset") + return obj diff --git a/torch/distributed/_shard/sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py index e3cc7309bae5a..fbb021f7df442 100644 --- a/torch/distributed/_shard/sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -25,7 +25,7 @@ def named_params_with_sharded_tensor( are direct members of this module. Yields: - (string, Union[Tensor, ShardedTensor]): Tuple containing + (str, Union[Tensor, ShardedTensor]): Tuple containing the name and parameter (or ShardedTensor parameter) Example:: diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index 7accc82754a1b..ec4f9e6ae7491 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -20,8 +20,7 @@ def __init__( Args: named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict of parameters, where key is the parameter key, value is either - Tensor or ShardedTensor parameter. This usually used in - conjunction with :meth:`named_params_with_sharded_tensor` + Tensor or ShardedTensor parameter. optimizer_class (torch.optim.Optimizer): the Optimizer to use locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc. *optimizer_args: the arguments to initialize the optimizer. @@ -62,7 +61,7 @@ def step(self, closure=None): r"""Performs a single optimization step (parameter update). Args: - closure (callable): A closure that reevaluates the model and + closure (Callable): A closure that reevaluates the model and returns the loss. Optional for most optimizers. .. note:: diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 2d5d1cb637393..827612fee7fd6 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -12,6 +12,7 @@ _CUSTOM_SHARDED_OPS, _SHARDED_OPS, Shard, + ShardedTensorBase, ShardedTensor, ShardedTensorMetadata, TensorProperties, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py index 1bbc079f78436..1243a1d2396dd 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -3,6 +3,7 @@ import torch.distributed._shard.sharded_tensor._ops.math_ops import torch.distributed._shard.sharded_tensor._ops.matrix_ops import torch.distributed._shard.sharded_tensor._ops.tensor_ops +import torch.distributed._shard.sharded_tensor._ops.misc_ops from .binary_cmp import equal, allclose from .init import kaiming_uniform_, normal_, uniform_, constant_ diff --git a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py index fa2d30e7e36bc..fe41cc79a858c 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py @@ -1,12 +1,10 @@ import torch from torch import Tensor -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, - _sharded_op_impl -) +from torch.distributed._shard.sharded_tensor import ShardedTensor, _sharded_op_impl from torch.distributed._shard.replicated_tensor import ReplicatedTensor from torch.distributed._shard._utils import narrow_tensor + def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None): """ Handles ``__torch_function__`` dispatch for the binary math ops @@ -33,7 +31,8 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None): res, rhs.sharding_spec(), rhs.size(), # type: ignore[arg-type] - process_group=pg) + process_group=pg, + ) elif isinstance(rhs, ReplicatedTensor): assert isinstance(lhs, ShardedTensor) @@ -49,7 +48,8 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None): res, lhs.sharding_spec(), lhs.size(), # type: ignore[arg-type] - process_group=pg) + process_group=pg, + ) elif isinstance(lhs, (int, float)): assert isinstance(rhs, ShardedTensor) @@ -58,7 +58,8 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None): res, rhs.sharding_spec(), rhs.size(), # type: ignore[arg-type] - process_group=pg) + process_group=pg, + ) elif isinstance(rhs, (int, float)): assert isinstance(lhs, ShardedTensor) @@ -67,40 +68,40 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None): res, lhs.sharding_spec(), lhs.size(), # type: ignore[arg-type] - process_group=pg) + process_group=pg, + ) else: raise RuntimeError( f"torch function '{op.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported yet for ShardedTensor!") + f"kwargs: {kwargs} not supported yet for ShardedTensor!" + ) + def register_math_op(op): @_sharded_op_impl(op) def binary_math_op(types, args=(), kwargs=None, pg=None): return binary_math_op_impl(op, types, args, kwargs, pg) + binary_ops = [ # add torch.add, Tensor.add, - Tensor.add_, Tensor.__add__, Tensor.__radd__, # sub torch.sub, Tensor.sub, - Tensor.sub_, Tensor.__sub__, Tensor.__rsub__, # mul torch.mul, Tensor.mul, - Tensor.mul_, Tensor.__mul__, Tensor.__rmul__, # div torch.div, Tensor.div, - Tensor.div_, Tensor.__div__, Tensor.__rdiv__, ] diff --git a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py new file mode 100644 index 0000000000000..0e0911bb1d18c --- /dev/null +++ b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -0,0 +1,12 @@ +import torch +from torch.distributed._shard.sharded_tensor import ( + _sharded_op_impl, +) + +# This is used by `_apply()` within module.py to set new +# parameters after apply a certain method, we should follow +# the future behavior of overwriting the existing tensor +# instead of doing in-place change using `.data = `. +@_sharded_op_impl(torch._has_compatible_shallow_copy_type) +def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None): + return False diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index 84d893d6519d9..2c9d0df4d84b4 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -10,28 +10,8 @@ ) from torch.distributed._shard.common_op_utils import _register_default_op -@_sharded_op_impl(torch.Tensor.__deepcopy__) -def tensor_deepcopy(types, args=(), kwargs=None, pg=None): - # NOTE: we directly implement deepcopy magic method - # instead of using the default tensor.__deepcopy__ - # and implement clone(). This is because the default - # tensor deepcopy copies every attribute, but the - # process_group in ShardedTensor cannot be deep copied. - self_st = args[0] - # Validate types - if not isinstance(self_st, ShardedTensor): - raise TypeError("input needs to be a ShardedTensor") - - return ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=copy.deepcopy(self_st.local_shards()), - sharded_tensor_metadata=copy.deepcopy(self_st.metadata()), - process_group=self_st._process_group, - init_rrefs=self_st._init_rrefs - ) - # Tensor properties access -_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined] _register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined] _register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined] _register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined] @@ -44,6 +24,27 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None): # __reduce_ex__ to dispatch to get_state/set_state _register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl) +# autograd related properties +_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined] +# TODO: set grad with a ShardedTensor that consists of all local grads +_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr] +_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + +# device property is ambiguous as from a global prospective, +# ShardedTensor.device consists of multiple devices (might even across hosts) +# We choose to return the current device of the local tensor to represent +# the device property on each rank +@_sharded_op_impl(torch.Tensor.device.__get__) +def tensor_device(types, args=(), kwargs=None, pg=None): + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + + return self_st.local_shards()[0].tensor.device + + def sharded_type_as_check(*args, **kwargs): """ Perform extra checks for the sharded_type_as op such as the input needs to @@ -99,6 +100,7 @@ def sharded_type_as(args, kwargs, pg): customized_func=sharded_type_as, ) + def sharded_deepcopy(args, kwargs, pg): # NOTE: we directly implement deepcopy magic method # instead of using the default tensor.__deepcopy__ @@ -110,11 +112,32 @@ def sharded_deepcopy(args, kwargs, pg): new_metadata = copy.deepcopy(self_st.metadata()) return new_local_shards, new_metadata + _register_sharded_op_on_local_shards( torch.Tensor.__deepcopy__, customized_func=sharded_deepcopy, ) + +@_sharded_op_impl(torch.Tensor.copy_) +def sharded_inplace_copy(types, args, kwargs, pg): + # NOTE: inplace op don't need to rewrap + kwargs = {} if kwargs is None else kwargs + self_st = args[0] + new_st = args[1] + nonblocking = kwargs.get("non_blocking", False) + self_meta = self_st.metadata() + new_meta = new_st.metadata() + if self_meta != new_meta: + raise RuntimeError( + "inplace copy can only happen between two ShardedTensor with same metadata!" + ) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + local_shard.tensor.copy_(new_shard.tensor, nonblocking) + + return self_st + + def sharded_clone(args, kwargs, pg): self_st = args[0] desire_memory_format = kwargs.get("memory_format", None) @@ -130,11 +153,13 @@ def sharded_clone(args, kwargs, pg): new_metadata = copy.deepcopy(self_st.metadata()) return cloned_local_shards, new_metadata + _register_sharded_op_on_local_shards( torch.Tensor.clone, customized_func=sharded_clone, ) + def sharded_detach(args, kwargs, pg): self_st = args[0] detached_local_shards = [ @@ -148,25 +173,33 @@ def sharded_detach(args, kwargs, pg): new_metadata.tensor_properties.requires_grad = False return detached_local_shards, new_metadata + _register_sharded_op_on_local_shards( torch.Tensor.detach, customized_func=sharded_detach, ) + @_sharded_op_impl(torch.Tensor.requires_grad_) def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): self_st = args[0] - requires_grad = args[1] # Validate types if not isinstance(self_st, ShardedTensor): raise TypeError("input needs to be a ShardedTensor") + if kwargs is None: + kwargs = {} + + requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True) if requires_grad == self_st.requires_grad: return self_st for local_shard in self_st.local_shards(): local_shard.tensor.requires_grad_(requires_grad) + # update the wrapper class property + with torch._C.DisableTorchFunction(): + self_st.requires_grad_(requires_grad) # update the metadata in the meanwhile self_st._metadata.tensor_properties.requires_grad = requires_grad return self_st diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index a83f87c88869e..d30f965a387bf 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -7,7 +7,6 @@ Optional, Sequence, Tuple, - Union, cast, ) import copy @@ -40,7 +39,6 @@ build_metadata_from_local_shards, build_global_metadata ) -from torch.overrides import handle_torch_function from torch.distributed.remote_device import _remote_device from torch.utils._pytree import tree_map @@ -67,9 +65,186 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]] else: sharded_tensor._register_remote_shards(rrefs, rpc_rank) -class ShardedTensor(object): +class ShardedTensorBase(torch.Tensor): + _sharding_spec: shard_spec.ShardingSpec + _metadata: ShardedTensorMetadata + _local_shards: List[Shard] + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + # Use __new__ to construct a wrapper tensor, for recording tensor + # properties and logging purposes. + torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") + + # check sharding spec and build sharded tensor metadata + if not isinstance(sharding_spec, shard_spec.ShardingSpec): + raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}") + + sizes = _flatten_tensor_size(size) + dtype = kwargs["dtype"] + layout = kwargs["layout"] + pin_memory = kwargs["pin_memory"] + requires_grad = kwargs["requires_grad"] + + if dtype is None: + dtype = torch.get_default_dtype() + + tensor_properties = TensorProperties( + dtype, layout, requires_grad, pin_memory=pin_memory + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + sizes, tensor_properties=tensor_properties + ) + + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + sizes, + dtype=dtype, + layout=layout, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + # set sharding spec + r._sharding_spec = sharding_spec + # set metadata + r._metadata = sharded_tensor_metadata + # set local shards + r._local_shards = [] + return r + + def metadata(self) -> ShardedTensorMetadata: + """ + Returns a :class:`ShardedTensorMetadata` object corresponding to the + metadata for the entire tensor. + """ + return self._metadata + + def local_shards(self) -> List[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + @classmethod + def _init_from_local_shards_and_global_metadata( + cls, + local_shards: List[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + sharding_spec=None, + ) -> "ShardedTensor": + """ + Initialize a ShardedTensorBase with local shards and a global + ShardedTensorMetadata built on each rank. + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor_base = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): + tensor_property_or_metadata = ( + "tensor property" if is_property else "local ShardMetadata" + ) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property is incompatible with " + f"{tensor_property_or_metadata} on rank {rank}: " + f"{tensor_property_or_metadata} {prop_name}={expected}, " + f"local shard tensor {prop_name}={actual}." + ) + + for shard in local_shards: + shard_meta = shard.metadata + local_shard_tensor = shard.tensor + placement = shard_meta.placement + assert placement is not None, "Must specify placement for `Shard`!" + rank = placement.rank() + local_device = placement.device() + + _raise_if_mismatch( + tensor_properties.layout, + local_shard_tensor.layout, + "layout", + rank, + True, + ) + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + _raise_if_mismatch( + shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + rank, + ) + _raise_if_mismatch( + tensor_properties.pin_memory, + local_shard_tensor.is_pinned(), + "pin_memory", + rank, + True, + ) + _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank) + _raise_if_mismatch( + tensor_properties.dtype, + local_shard_tensor.dtype, + "dtype", + rank, + True, + ) + _raise_if_mismatch( + tensor_properties.requires_grad, + local_shard_tensor.requires_grad, + "requires_grad", + rank, + True, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor_base._local_shards = local_shards + return sharded_tensor_base + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise RuntimeError( + f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " + "but the there is no custom __torch_dispatch__ implementation for it." + ) + +class ShardedTensor(ShardedTensorBase): """ - ShardedTensor is an abstraction to represent Tensors that are sharded + ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded across multiple devices and multiple processes. ShardedTensor is initialized in an SPMD like fashion where each rank @@ -115,11 +290,9 @@ class ShardedTensor(object): individual GPU, via ``torch.cuda.set_device()`` """ - - def __new__(cls, *args, **kwargs): - # Use __new__ for logging purposes. - torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") - return super(ShardedTensor, cls).__new__(cls) + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + self = super(ShardedTensor, cls).__new__(cls, sharding_spec, *size, **kwargs) + return self def __init__( self, @@ -137,42 +310,25 @@ def __init__( # _process_group, _local_shards, etc. self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) - tensor_properties = TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory) - - if tensor_properties is None: - raise ValueError('tensor_properties must not be None.') - - if tensor_properties.dtype is None: - tensor_properties.dtype = torch.get_default_dtype() - - if tensor_properties.layout != torch.strided: + if layout != torch.strided: raise ValueError('Only torch.strided layout is currently supported') - if tensor_properties.memory_format != torch.contiguous_format: + if memory_format != torch.contiguous_format: raise ValueError('Only torch.contiguous_format memory_format is currently supported') - dims = _flatten_tensor_size(size) - - if not isinstance(sharding_spec, shard_spec.ShardingSpec): - raise ValueError(f'Expecting ShardingSpec but got: {type(sharding_spec)}') - - self._sharding_spec = sharding_spec - - sharded_tensor_metadata = sharding_spec.build_metadata( - dims, tensor_properties=tensor_properties) + self._metadata.tensor_properties.memory_format = memory_format current_rank = dist.get_rank(self._process_group) - for shard_metadata in sharded_tensor_metadata.shards_metadata: + for shard_metadata in self._metadata.shards_metadata: rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement) if rank == current_rank: local_tensor = _create_tensor_from_params( shard_metadata.shard_sizes, local_device=device, - tensor_properties=sharded_tensor_metadata.tensor_properties + tensor_properties=self._metadata.tensor_properties ) self._local_shards.append(Shard(local_tensor, shard_metadata)) - self._metadata = sharded_tensor_metadata # do post initialization (i.e. register sharded_tensor_id, initialize_rpc) self._post_init() @@ -187,7 +343,6 @@ def _prepare_init(self, process_group=None, init_rrefs=False): else distributed_c10d._get_default_group() ) - self._local_shards: List[Shard] = [] self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {} def _post_init(self): @@ -209,7 +364,10 @@ def __del__(self): # Clean up the global map. with _sharded_tensor_lock: global _sharded_tensor_current_id, _sharded_tensor_map - if self._sharded_tensor_id in _sharded_tensor_map: + if ( + hasattr(self, "_sharded_tensor_id") + and self._sharded_tensor_id in _sharded_tensor_map + ): _sharded_tensor_map.pop(self._sharded_tensor_id) # type: ignore[call-overload] def _init_rpc(self): @@ -266,7 +424,7 @@ def _get_preferred_device(self) -> torch.device: return torch.device(torch.cuda.current_device()) return torch.device("cpu") - def gather( + def gather( # type: ignore[override] self, dst: int = 0, out: Optional[torch.Tensor] = None, @@ -407,6 +565,148 @@ def cpu( ) return st_cpu + def cuda( + self, + device=None, + non_blocking=False, + memory_format=torch.preserve_format, + process_group=None + ) -> ShardedTensor: + """ + Returns a copy of this object in CUDA memory, if the original ShardedTensor + is on CPU, we will move the local shard to the current GPU device of each + process in a SPMD fashion. + If this ShardedTensor is already on CUDA memory and local shards on each rank are + already on current device, we still returns a new ShardedTensor object with new + metadata, but no underlying data movements are performed. + .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), + it is the user's responsiblity to explicitly pass in a new process_group that + is compatible with GPU. + """ + if memory_format != torch.preserve_format and \ + memory_format != torch.contiguous_format: + raise RuntimeError("Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!") + + if device is not None: + device = torch.device(device) if isinstance(device, str) else device + assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \ + '''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!''' + + current_device = torch.device(torch.cuda.current_device()) + # returns a copy of ShardedTensor on CUDA current device + list_shards: List[Shard] = [] + # move all local shards to current device, and change metadata + # if local shards already on the current device, there's no + # real data movement, only the metadata are copied. + for shard in self._local_shards: + cuda_tensor = shard.tensor.cuda( + device=current_device, + non_blocking=non_blocking, + memory_format=memory_format + ) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = current_device # type: ignore[union-attr] + + list_shards.append( + Shard(cuda_tensor, metadata) + ) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cuda": # type: ignore[union-attr] + meta.placement._device = current_device # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs + ) + return st_cuda + + def to(self, *args, **kwargs) -> ShardedTensor: + current_device = self._local_shards[0].tensor.device + current_dtype = self.dtype + device_to = current_device + dtype_to = current_dtype + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_to = args[0] + elif isinstance(args[0], torch.device): + device_to = args[0] + elif isinstance(args[0], (str, int)): + device_to = torch.device(args[0]) + elif isinstance(args[0], torch.Tensor): + dtype_to = args[0].dtype + device_to = args[0].device + else: + raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}") + elif len(args) == 2: + device_to, dtype_to = args + else: + dtype_to = kwargs.get("dtype", current_dtype) + device_to = kwargs.get("device", current_device) + + device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + + if device_to.type == "cuda": + # if device_to set to cuda, set to current device even + # if user specify the device index. + current_idx = torch.cuda.current_device() + if device_to.index != current_idx: + import warnings + warnings.warn("ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead.") + device_to = torch.device(current_idx) + + copy_tensor = kwargs.get("copy", False) + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", torch.preserve_format) + process_group = kwargs.get("process_group", None) + + if not copy_tensor and dtype_to == current_dtype and device_to == current_device: + # already have correct dtype and device, return itself + return self + + # returns a copy of ShardedTensor on CUDA current device + list_shards: List[Shard] = [] + + for shard in self._local_shards: + new_tensor = shard.tensor.to( # type: ignore[call-overload] + device=device_to, + dtype=dtype_to, + non_blocking=non_blocking, + copy=copy_tensor, + memory_format=memory_format + ) + metadata = copy.deepcopy(shard.metadata) + if metadata.placement is not None: + metadata.placement._device = device_to + list_shards.append(Shard(new_tensor, metadata)) + + # update metadata + st_meta = copy.deepcopy(self.metadata()) + st_meta.tensor_properties.dtype = dtype_to + for meta in st_meta.shards_metadata: + meta.placement._device = device_to # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_to = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs + ) + return st_to + + @classmethod def _init_from_local_shards( cls, @@ -446,18 +746,24 @@ def _init_from_local_shards( gathered_metadatas = [local_sharded_tensor_metadata] global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas) + tensor_properties = global_sharded_tensor_metadata.tensor_properties # STEP 3: Validation done, create the actual ShardedTensor and populate fields # prepare initialization - sharded_tensor = cls.__new__(cls) + spec = shard_spec._infer_sharding_spec_from_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + sharded_tensor = cls.__new__(cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad) sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) - # add to metadata and local_shards - sharded_tensor._metadata = global_sharded_tensor_metadata + # attach local_shards to the ShardedTensor created sharded_tensor._local_shards = local_shards - sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata( - global_sharded_tensor_metadata.shards_metadata - ) # run post initialization, i.e. map registration, rpc initialization sharded_tensor._post_init() @@ -566,7 +872,7 @@ def _init_from_local_tensor( ) @classmethod - def _init_from_local_shards_and_global_metadata( + def _init_from_local_shards_and_global_metadata( # type: ignore[override] cls, local_shards: List[Shard], sharded_tensor_metadata: ShardedTensorMetadata, @@ -590,32 +896,12 @@ def _init_from_local_shards_and_global_metadata( current_rank = dist.get_rank(process_group) shards_metadata = sharded_tensor_metadata.shards_metadata - tensor_properties = sharded_tensor_metadata.tensor_properties - - if len(shards_metadata) == 0: - raise ValueError("shards_metadata must not be empty!") - - if tensor_properties.layout != torch.strided: - raise ValueError('Only torch.strided layout is currently supported') - - sharded_tensor = cls.__new__(cls) - sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) - - sharded_tensor._metadata = sharded_tensor_metadata local_shard_metadatas = [] - def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): - tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata" - if expected != actual: - raise ValueError(f"Local shards' tensor {prop_name} property is incompatible with " - f"{tensor_property_or_metadata} on rank {rank}: " - f"{tensor_property_or_metadata} {prop_name}={expected}, " - f"local shard tensor {prop_name}={actual}.") - # collect local shard metadatas from the global sharded_tensor_metadata for shard_metadata in shards_metadata: # type: ignore[attr-defined] - rank, local_device = _parse_and_validate_remote_device(sharded_tensor._process_group, shard_metadata.placement) + rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) if current_rank == rank: local_shard_metadatas.append(shard_metadata) @@ -626,39 +912,12 @@ def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) ' f'on rank ({current_rank}) ' ) - - for shard in local_shards: - shard_meta = shard.metadata - local_shard_tensor = shard.tensor - rank, local_device = _parse_and_validate_remote_device(sharded_tensor._process_group, shard_meta.placement) - - # validate if shard_meta in the metadatas collected from sharded_tensor_metadata - assert shard_meta in local_shard_metadatas, \ - "local shard metadata not in sharded_tensor_metadata!" - - _raise_if_mismatch(tensor_properties.layout, local_shard_tensor.layout, "layout", current_rank, True) - if not local_shard_tensor.is_contiguous(): - raise ValueError('Only torch.contiguous_format memory_format is currently supported') - - _raise_if_mismatch(shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank) - _raise_if_mismatch(tensor_properties.pin_memory, local_shard_tensor.is_pinned(), "pin_memory", current_rank, True) - _raise_if_mismatch(local_device, local_shard_tensor.device, "device", current_rank) - _raise_if_mismatch(tensor_properties.dtype, local_shard_tensor.dtype, "dtype", current_rank, True) - _raise_if_mismatch( - tensor_properties.requires_grad, local_shard_tensor.requires_grad, "requires_grad", current_rank, True) - - # check if shards_metadata have overlap shards - validate_non_overlapping_shards_metadata(shards_metadata) - - # check if the shards_metadata is compatible with overall size of the sharded tensor. - check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) - - # done validation, add local_shards - sharded_tensor._local_shards = local_shards - if sharding_spec is None: - sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) - else: - sharded_tensor._sharding_spec = sharding_spec + sharded_tensor = super( + ShardedTensor, cls + )._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata, sharding_spec=sharding_spec + ) + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) # run post initialization, i.e. map registration, rpc initialization sharded_tensor._post_init() @@ -825,201 +1084,12 @@ def find_sharded_tensor(e): f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for ShardedTensor!") - def metadata(self) -> ShardedTensorMetadata: - """ - Returns a :class:`ShardedTensorMetadata` object corresponding to the - metadata for the entire tensor. - """ - return self._metadata - - def local_shards(self) -> List[Shard]: - """ - Returns a list of :class:`Shard' corresponding to the - local shards for this rank. Returns an empty list if the current rank - does not host any shards for this Tensor. - """ - return self._local_shards - - def size(self, dim: int = None) -> Union[torch.Size, int]: - """ - Returns a :Union:`[torch.Size, int]` which represents the size of the tensor. - The dimension can be specified. - - Args: - dim (int, optional): the dimension over which the size represents. - If specified, it returns the size of the given dimension. - If not, it returns a subclass of tuple. - Default: ``None`` - - Returns: - A :Union:`[torch.Size, int]` represents the size of the tensor. - """ - size = self._metadata.size - if dim is None: - return size - if dim < -len(size) or dim >= len(size): - raise ValueError( - "Argument ``dim`` must be within the range of tensor " - f"dimensions [-{len(size)}, {len(size)})" - ) - return size[dim] - - - def is_pinned(self) -> bool: + def is_pinned(self) -> bool: # type: ignore[override] """ Returns True if the sharded tensor (each local shard) resides in pinned memory. """ return self._metadata.tensor_properties.pin_memory - def is_contiguous(self) -> bool: - """ - Returns True if the sharded tensor (each local shard) is contiguous in memory - in the order specified by memory format. - """ - return self._metadata.tensor_properties.memory_format == torch.contiguous_format - - def dim(self) -> int: - """ - Returns a `int` which represents the dimension of the tensor. - - Returns: - A `int` represents the dimension of the tensor. - """ - return len(self._metadata.size) - - # TODO: This op needs further definition of what exactly its behavior will be. - def contiguous(self) -> ShardedTensor: - """ - Returns a new sharded tensor with the local tensor is made to contiguous. - """ - if self.is_contiguous(): - return self - local_shards = [] - for shard in self.local_shards(): - local_shards.append( - Shard(shard.tensor.contiguous(), shard.metadata) - ) - return ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards, - self._metadata, - process_group=self._process_group, - init_rrefs=self._init_rrefs, - ) - - def masked_fill(self, mask, value) -> ShardedTensor: - """ - Returns a new sharded tensor with each shard has been filled elements - with value where mask is True. The shape of mask must be broadcastable - with the shape of the underlying tensor. - - Args: - mask (BoolTensor): the boolean mask. - value (float): the value to fill in with. - - Returns: - A :class:`ShardedTensor` object whose shards have been applied masked_fill. - """ - return handle_torch_function( - torch.Tensor.masked_fill, (self, mask, value), self, mask, value - ) - - def type_as(self, tensor) -> ShardedTensor: - """ - Returns a new sharded tensor with each shard has been - cast to the type of the given tensor. - - Args: - tensor (Tensor): the tensor which has the desired type. - - Returns: - A :class:`ShardedTensor` object whose shards have been applied type_as. - """ - return handle_torch_function(torch.Tensor.type_as, (self, tensor), self, tensor) - - def view(self, *shape) -> ShardedTensor: - """ - Returns a new sharded tensor with the same data as the - self tensor but of a different shape for its local tensor. - - For now, we only support to pass through the view op to the local - tensor. - - Args: - shape (torch.Size or int...) – the desired size. - - Returns: - A :class:`ShardedTensor` object whose shards have been applied - with view to its local tensor. - """ - return handle_torch_function(torch.Tensor.view, (self, *shape), self, *shape) - - def transpose(self, dim0, dim1) -> ShardedTensor: - """ - Returns a new sharded tensor with the given dimensions transposed. - During the transpose, we keep the original shading dim, e.g., if the - tensor is sharded by dim 0 and if we call transpose(1, 0). The returned - tensor will be sharded by dim 1. - - Args: - dim0 (int): the first dimension to be transposed. - dim1 (int): the second dimension to be transposed. - - Returns: - A :class:`ShardedTensor` object whose dims have been transposed - specified in the input. - """ - return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1) - - def bmm(self, st2, *, out=None) -> ShardedTensor: - """ - Performs a batch matrix-matrix product of matrices stored in self and st2. - - Warning: For now we only supports the case when both tensors are sharded - by dim 0 so that no communication is needed. - - Args: - st2 (ShardedTensor) – the second batch of sharded matrices to be multiplied. - - Returns: - A :class:`ShardedTensor` object which is the result of the batch multiplication. - """ - return handle_torch_function(torch.Tensor.bmm, (self, st2, out), self, st2, out=out) - - def chunk(self, chunks, dim=0) -> List[ShardedTensor]: - """ - Attempts to split a tensor into the specified number of chunks. - Each chunk is a view of the input tensor. - - Warnings: Chunk by the sharding dim is not supported. - - Args: - chunks (int) – number of chunks to return - dim (int) – dimension along which to split the tensor - - Returns: - A List of :class:`ShardedTensor` object chunked on dims. - """ - return handle_torch_function(torch.Tensor.chunk, (self, chunks, dim), self, chunks, dim=dim) - - @property - def shape(self): - return self._metadata.size - - @property - def requires_grad(self): - return self._metadata.tensor_properties.requires_grad - - def requires_grad_(self, requires_grad=True): - return handle_torch_function(torch.Tensor.requires_grad_, (self, requires_grad), self, requires_grad) - - @property - def dtype(self): - return self._metadata.tensor_properties.dtype - - @property - def layout(self): - return self._metadata.tensor_properties.layout - def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int): self._remote_shards[rpc_rank] = remote_shards @@ -1043,45 +1113,6 @@ def __hash__(self): def __repr__(self): return f'ShardedTensor({self._metadata})' - def __add__(self, other): - return handle_torch_function(torch.Tensor.__add__, (self, other), self, other) - - def __radd__(self, other): - return handle_torch_function(torch.Tensor.__radd__, (self, other), self, other) - - def __sub__(self, other): - return handle_torch_function(torch.Tensor.__sub__, (self, other), self, other) - - def __rsub__(self, other): - return handle_torch_function(torch.Tensor.__rsub__, (self, other), self, other) - - def __mul__(self, other): - return handle_torch_function(torch.Tensor.__mul__, (self, other), self, other) - - def __rmul__(self, other): - return handle_torch_function(torch.Tensor.__rmul__, (self, other), self, other) - - def __truediv__(self, other): - return handle_torch_function(torch.Tensor.__div__, (self, other), self, other) - - def __rtruediv__(self, other): - return handle_torch_function(torch.Tensor.__rdiv__, (self, other), self, other) - - def tanh(self): - return handle_torch_function(torch.Tensor.tanh, (self,), self) - - def __getitem__(self, key): - return handle_torch_function(torch.Tensor.__getitem__, (self, key), self, key) - - def __deepcopy__(self, memo): - return handle_torch_function(torch.Tensor.__deepcopy__, (self, memo), self, memo) - - def clone(self, *, memory_format=torch.preserve_format): - return handle_torch_function(torch.Tensor.clone, (self,), self, memory_format=memory_format) - - def detach(self): - return handle_torch_function(torch.Tensor.detach, (self,), self) - @dataclass class ProcessGroupState: """ diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index 812bee02efd74..17b653e1f5c20 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -58,6 +58,15 @@ def __setstate__( self.memory_format = memory_format + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned() + ) @dataclass class ShardedTensorMetadata(object): """ diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py index 7b1a698c1778c..10433129173d5 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -1,7 +1,5 @@ # coding=utf-8 -from typing import cast - import torch import torch.distributed as dist from ._common import ( @@ -158,7 +156,7 @@ def _validate_embedding_param(args, kwargs): raise TypeError("input need to be torch.Tensor") if not isinstance(weight, ShardedTensor): raise TypeError("weight needs to be ShardedTensor") - weight_size = cast(torch.Size, weight.size()) + weight_size = weight.size() if len(weight_size) != 2: raise ValueError("Weight needs to have exactly 2 dims") if int(torch.min(input).item()) < 0: diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index c4579398a4a2a..1703f5d15dc48 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -204,7 +204,7 @@ def _validate_embedding_bag_param(args, kwargs): raise TypeError("weight needs to be ShardedTensor") if len(input.size()) > 2: raise ValueError("Input more than 2 dims not supported") - weight_size = cast(torch.Size, weight.size()) + weight_size = weight.size() if len(weight_size) != 2: raise ValueError("Weight needs to have exactly 2 dims") if int(torch.min(input).item()) < 0: diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py index ac65222e24175..c19eee6ffb809 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import List import torch import torch.distributed as dist @@ -105,14 +105,14 @@ def sharded_linear(types, args, kwargs, pg): world_size = dist.get_world_size(pg) rank = dist.get_rank(pg) - if sharding_dim == 1 and isinstance(input, torch.Tensor): - return _handle_row_wise_sharding_tensor( - input, world_size, weight, rank, local_shard_t, bias, pg - ) - elif sharding_dim == 1 and isinstance(input, ShardedTensor): + if sharding_dim == 1 and isinstance(input, ShardedTensor): return _handle_row_wise_sharding_sharded_tensor( input, world_size, weight, local_shard_t, bias, pg ) + elif sharding_dim == 1 and isinstance(input, torch.Tensor): + return _handle_row_wise_sharding_tensor( + input, world_size, weight, rank, local_shard_t, bias, pg + ) elif sharding_dim == 0: return _handle_col_wise_sharding( input, world_size, weight, rank, local_shard_t, bias, pg @@ -125,7 +125,7 @@ def sharded_linear(types, args, kwargs, pg): def _validate_linear_op_param(args, kwargs): """ - Validate input params of sharded embedding op. + Validate input params of sharded linear op. Args: input: input of the linear layer. @@ -141,13 +141,13 @@ def _validate_linear_op_param(args, kwargs): # Validate types if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor): raise TypeError("input needs to be either torch.Tensor or ShardedTensor") - if not isinstance(bias, torch.Tensor): + if type(bias) != torch.Tensor and type(bias) != torch.nn.Parameter: raise TypeError("bias needs to be torch.Tensor") if not isinstance(weight, ShardedTensor): raise TypeError("weight needs to be ShardedTensor") if len(input.size()) < 1: # type: ignore[arg-type] raise ValueError("Input needs to have at least 1 dim") - weight_size = cast(torch.Size, weight.size()) + weight_size = weight.size() if len(weight_size) != 2: raise ValueError("Weight needs to have exactly 2 dims") if len(bias.size()) != 1: diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index e373490311283..0dad32ec64448 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -5,9 +5,8 @@ from torch.autograd.graph import save_on_cpu from torch.utils.checkpoint import checkpoint from torch.distributed.utils import _replace_by_prefix -from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy import torch.nn as nn -from typing import Dict, Any +from typing import Any, Dict, Iterator, Tuple from functools import partial _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module" @@ -61,6 +60,18 @@ def forward(self, *args, **kwargs): **kwargs, ) + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Overrides :meth:`named_parameters()` to intercept parameter names and + remove all occurrences of _CHECKPOINT_PREFIX. + """ + for param_name, param in super().named_parameters(*args, **kwargs): + yield param_name.replace(f"{_CHECKPOINT_PREFIX}.", ""), param + @staticmethod def _post_state_dict_hook( module: nn.Module, @@ -156,6 +167,9 @@ def apply_activation_checkpointing_wrapper( ``True`` or ``False`` depending on whether input layer should be wrapped. Returns: None (`model` is modified inplace) """ + # TODO: Importing inside function to avoid circular import issue between FSDP and + # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. + from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy return _recursive_wrap( module=model, auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=check_fn), diff --git a/torch/distributed/algorithms/_comm_hooks/__init__.py b/torch/distributed/algorithms/_comm_hooks/__init__.py new file mode 100644 index 0000000000000..d07adc17247b7 --- /dev/null +++ b/torch/distributed/algorithms/_comm_hooks/__init__.py @@ -0,0 +1,7 @@ + +from . import default_hooks as default + +LOW_PRECISION_HOOKS = [ + default.fp16_compress_hook, + default.bf16_compress_hook, +] diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py new file mode 100644 index 0000000000000..b9ea63392eb3e --- /dev/null +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -0,0 +1,125 @@ +import functools +import torch +import torch.distributed as dist +from torch.distributed import distributed_c10d + + +class DefaultState(object): + r""" + Stores state needed to perform the default ``all_reduce`` algorithm + within a communication hook. + + Args: + process_group (ProcessGroup): The process group to be used for all-reduce. + """ + + __slots__ = [ + "process_group", + "world_size", + "gradient_predivide_factor", + "gradient_postdivide_factor" + ] + + def __init__( + self, + process_group + ): + self.process_group = process_group if process_group is not None else distributed_c10d._get_default_group() + self.world_size = dist.get_world_size(process_group) + self.gradient_predivide_factor = self._get_gradient_predivide_factor( + self.world_size + ) + self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor + + # setting two factors `self.gradient_predivide_factor` + # and `self.gradient_postdivide_factor` to avoid underflow and overflow + def _get_gradient_predivide_factor(self, world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + +class LowPrecisionState(DefaultState): + r""" + Stores state needed to perform gradient communication in a lower precision + within a communication hook. Communication hook will cast gradients back + to the original parameter precision specified by ``parameter_type`` (default: torch.float32). + Builds on top of the :class:`DefaultState`. + + Args: + parameter_type (torch.dtype): The precision of model's parameters. + Required for a hook to cast gradients back to a parameter's precision. + """ + + __slots__ = [ + "parameter_type", + ] + + def __init__( + self, + process_group, + parameter_type=torch.float32, + ): + super().__init__(process_group) + self.parameter_type = parameter_type + + +def _decompress(state: LowPrecisionState, grad: torch.Tensor): + """ + Casts gradients back to full parameter precision so that + further computation happens in full precision + """ + orig_grad_data = grad.data + grad.data = grad.data.to(state.parameter_type) + # Don't let this memory get reused until after the transfer. + orig_grad_data.record_stream(torch.cuda.current_stream()) # type: ignore[arg-type] + +def allreduce_hook(state: DefaultState, grad: torch.Tensor): + r""" + This FSDP communication hook implements ``all_reduce`` algorithm + and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks. + """ + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.all_reduce(grad, group=state.process_group) + if state.gradient_postdivide_factor > 1: + grad.div_(state.gradient_postdivide_factor) + +def lower_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor): + grad.data = grad.data.to(prec) + allreduce_hook(state, grad) + _decompress(state, grad) + +def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor): + r""" + This FSDP communication hook implements a simple gradient compression + approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.predivide_factor``, and after an allreduce step gradients are averaged by a ``state.postdivide_factor``. + Onse post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (DefaultState): State information, configures pre- and post-division factors + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + """ + fp16_hook = functools.partial(lower_precision_hook, torch.float16) + return fp16_hook(state, grad) + +def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor): + r""" + This FSDP communication hook implements a simple gradient compression + approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.predivide_factor``, and after an allreduce step gradients are averaged by a ``state.postdivide_factor``. + Onse post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (DefaultState): State information, configures pre- and post-division factors + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + """ + bf16_hook = functools.partial(lower_precision_hook, torch.bfloat16) + return bf16_hook(state, grad) diff --git a/torch/distributed/algorithms/_quantization/__init__.py b/torch/distributed/algorithms/_quantization/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/distributed/algorithms/quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py similarity index 97% rename from torch/distributed/algorithms/quantization/quantization.py rename to torch/distributed/algorithms/_quantization/quantization.py index efed9ba90a69e..961fd58e1f042 100644 --- a/torch/distributed/algorithms/quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -95,11 +95,11 @@ def auto_quantize(func, qtype, quant_loss=None): . all_gather, all_to_all collective ops Note: BFP16 only supports 2D tensors. Args: - func (callable): A function representing collective operations. + func (Callable): A function representing collective operations. qtype (QuantType): Quantization method quant_loss (float, optional): This can be used to improve accuracy in the dequantization. Returns: - (callable): the same collective as func but enables automatic quantization/dequantization. + (Callable): the same collective as func but enables automatic quantization/dequantization. """ @functools.wraps(func) def wrapper(*args, **kwargs): diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index d58c1851968cd..6e2ac880fb401 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -12,6 +12,7 @@ optimizer_overlap_hooks as optimizer_overlap, ) +__all__ = ['DDPCommHookType', 'register_ddp_comm_hook'] def _ddp_comm_hook_wrapper(comm_hook, model, state): model.register_comm_hook(state, comm_hook) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 91651cdef5bf2..53bf9efe1b3c9 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -204,7 +204,7 @@ def hook_with_zero_step( Raises: ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. - RuntimeError: if using any backend other than NCCL since currently + RuntimeError: if using any backend other than NCCL/HCCL since currently Gloo may hang. .. warning:: @@ -226,11 +226,12 @@ def hook_with_zero_step( ddp_ref = weakref.ref(ddp) # NOTE: Gloo may hang with this overlapping approach, so we require - # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 - if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr] + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " - "NCCL backend to avoid hangs" + "NCCL/HCCL backend to avoid hangs" ) if shard_buckets: @@ -385,11 +386,12 @@ def hook_with_zero_step_interleaved( ddp_ref = weakref.ref(ddp) # NOTE: Gloo may hang with this overlapping approach, so we require - # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 - if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr] + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " - "NCCL backend to avoid hangs" + "NCCL/HCCL backend to avoid hangs" ) if shard_buckets: diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index 72d40272930f1..405612e555ca2 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -30,16 +30,12 @@ def _check_valid_functional_optim(self): ) -# TODO: Add an example to use such a wrapper. def _hook_then_optimizer( hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], optimizer_state: _OptimizerHookState, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r""" Runs optimizer in a functional fashion after DDP communication hook. - - .. warning :: - This API is experimental adn subject to change. """ has_set_params = ( hasattr(optimizer_state, 'params_to_optimize') diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 17fe5cce8c673..60b058eb5a749 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist +__all__ = ['JoinHook', 'Joinable', 'Join'] class JoinHook(): r""" diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index 19292c1549e25..b903d76abd935 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -5,6 +5,7 @@ import torch.distributed as dist import torch.distributed.algorithms.model_averaging.utils as utils +__all__ = ['ModelAverager', 'PeriodicModelAverager'] class ModelAverager(ABC): r"""Base class for all model averagers. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 60a084cfe41af..0c74cd5801fb9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -39,6 +39,7 @@ _MPI_AVAILABLE = True _NCCL_AVAILABLE = True _GLOO_AVAILABLE = True +_UCC_AVAILABLE = True _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -59,6 +60,11 @@ except ImportError: _GLOO_AVAILABLE = False +try: + from torch._C._distributed_c10d import ProcessGroupUCC +except ImportError: + _UCC_AVAILABLE = False + logger = logging.getLogger(__name__) @@ -86,7 +92,7 @@ def supports_complex(reduceOp: ReduceOp) -> bool: class Backend(object): """ - An enum-like class of available backends: GLOO, NCCL, MPI, and other registered + An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered backends. The values of this class are lowercase strings, e.g., ``"gloo"``. They can @@ -105,6 +111,7 @@ class Backend(object): UNDEFINED = "undefined" GLOO = "gloo" NCCL = "nccl" + UCC = "ucc" MPI = "mpi" TCP = "tcp" _plugins: Dict[str, Callable] = {} @@ -122,7 +129,7 @@ def __new__(cls, name: str): ) elif value == Backend.UNDEFINED: raise ValueError("Invalid backend: '{}'".format(name)) - elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.MPI: + elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.UCC and value != Backend.MPI: value = name.lower() return value @@ -145,9 +152,12 @@ def register_backend(cls, name, func): .. note:: This support of 3rd party backend is experimental and subject to change. """ - assert not hasattr(Backend, name.upper()), ( - f"{name.upper()} c10d backend already exist" - ) + # Allow UCC plugin if Pytorch is not built with native support. + # TODO: remove this exception once UCC plugin is fully deprecated. + if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())): + assert not hasattr(Backend, name.upper()), ( + f"{name.upper()} c10d backend already exist" + ) assert name.upper() not in Backend._plugins, ( f"{name.upper()} c10d backend creator function already exist" ) @@ -412,6 +422,13 @@ def is_gloo_available(): return _GLOO_AVAILABLE +def is_ucc_available(): + """ + Checks if the UCC backend is available. + """ + return _UCC_AVAILABLE + + def is_initialized(): """ Checking if the default process group has been initialized @@ -511,12 +528,13 @@ def init_process_group( Args: backend (str or Backend): The backend to use. Depending on build-time configurations, valid values include ``mpi``, ``gloo``, - and ``nccl``. This field should be given as a lowercase string - (e.g., ``"gloo"``), which can also be accessed via + ``nccl``, and ``ucc``. This field should be given as a lowercase + string (e.g., ``"gloo"``), which can also be accessed via :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using multiple processes per machine with ``nccl`` backend, each process must have exclusive access to every GPU it uses, as sharing GPUs - between processes can result in deadlocks. + between processes can result in deadlocks. ``ucc`` backend is + experimental. init_method (str, optional): URL specifying how to initialize the process group. Default is "env://" if no ``init_method`` or ``store`` is specified. @@ -547,6 +565,9 @@ def init_process_group( continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. Only one of these two environment variables should be set. + For ``ucc``, blocking wait is supported similar to NCCL. However, + async error handling is done differently since with UCC we have + progress thread and not watch-dog thread. group_name (str, optional, deprecated): Group name. pg_options (ProcessGroupOptions, optional): process group options specifying what additional options need to be passed in during @@ -682,7 +703,7 @@ def _new_process_group_helper( is_default_group = len(group_ranks) == 0 backend = Backend(backend) - pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL] + pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL, ProcessGroupUCC] if backend == Backend.MPI: if not is_mpi_available(): raise RuntimeError( @@ -767,6 +788,33 @@ def _new_process_group_helper( ) _pg_map[pg] = (Backend.NCCL, store) _pg_names[pg] = group_name + elif backend == Backend.UCC and is_ucc_available(): + # TODO: once UCC plugin is fully deprecated, remove + # is_ucc_available() from above elif-condition and raise + # RuntimeError if is_ucc_available() returns false. + + pg = ProcessGroupUCC(prefix_store, rank, world_size, timeout=timeout) + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debugability. + if get_debug_level() == DebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info( + """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""" + ) + else: + pg = _create_process_group_wrapper( + wrapped_pg=pg, + store_prefix=group_name, + store=store, + rank=rank, + world_size=world_size, + timeout=timeout, + ) + _pg_map[pg] = (Backend.UCC, store) + _pg_names[pg] = group_name else: assert backend.upper() in Backend._plugins, ( f"unknown c10d backend type {backend.upper()}" @@ -1025,7 +1073,7 @@ class P2POp(object): ``batch_isend_irecv`` for point-to-point communications. Args: - op (callable): A function to send data to or receive data from a peer process. + op (Callable): A function to send data to or receive data from a peer process. The type of ``op`` is either ``torch.distributed.isend`` or ``torch.distributed.irecv``. tensor (Tensor): Tensor to send or receive. @@ -1064,7 +1112,7 @@ def batch_isend_irecv(p2p_op_list): Send or Receive a batch of tensors asynchronously and return a list of requests. Process each of the operations in ``p2p_op_list`` and return the corresponding - requests. NCCL and Gloo backend are currently supported. + requests. NCCL, Gloo, and UCC backend are currently supported. Args: p2p_op_list: A list of point-to-point operations(type of each operator is diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 1068ce08bda63..259632869d43e 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -32,7 +32,7 @@ ) from torch.distributed.elastic.utils.logging import get_logger - +__all__ = ['WorkerSpec', 'Worker', 'WorkerState', 'WorkerGroup', 'RunResult', 'ElasticAgent', 'SimpleElasticAgent'] _TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" DEFAULT_ROLE = "default" diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 8fa868398d282..ef9248d14afce 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -28,6 +28,7 @@ log = get_logger() +__all__ = ['LocalElasticAgent'] class LocalElasticAgent(SimpleElasticAgent): """ diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 350981eb33937..eea723782f723 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -11,6 +11,7 @@ from enum import Enum from typing import Dict, Union, Optional +__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent'] EventMetadataValue = Union[str, int, float, bool, None] diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 35d5c78c6ef58..0a9f297de4fbc 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -13,6 +13,10 @@ from functools import wraps from typing import Dict, Optional +__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', + 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', + 'MetricData'] + MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 16eb5ddc45b11..0ba3df411e702 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -12,8 +12,9 @@ import time import traceback import warnings -from typing import Optional +from typing import Any, Dict, Optional +__all__ = ['ErrorHandler'] log = logging.getLogger(__name__) @@ -80,6 +81,28 @@ def record_exception(self, e: BaseException) -> None: with open(file, "w") as fp: json.dump(data, fp) + def override_error_code_in_rootcause_data( + self, + rootcause_error_file: str, + rootcause_error: Dict[str, Any], + error_code: int = 0, + ): + """ + Modify the rootcause_error read from the file, to correctly set the exit code. + """ + if "message" not in rootcause_error: + log.warning( + f"child error file ({rootcause_error_file}) does not have field `message`. \n" + f"cannot override error code: {error_code}" + ) + elif isinstance(rootcause_error["message"], str): + log.warning( + f"child error file ({rootcause_error_file}) has a new message format. \n" + f"skipping error code override" + ) + else: + rootcause_error["message"]["errorCode"] = error_code + def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): """ Dumps parent error file from child process's root cause error and error code. @@ -89,19 +112,7 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): # Override error code since the child process cannot capture the error code if it # is terminated by singals like SIGSEGV. if error_code: - if "message" not in rootcause_error: - log.warning( - f"child error file ({rootcause_error_file}) does not have field `message`. \n" - f"cannot override error code: {error_code}" - ) - elif isinstance(rootcause_error["message"], str): - log.warning( - f"child error file ({rootcause_error_file}) has a new message format. \n" - f"skipping error code override" - ) - else: - rootcause_error["message"]["errorCode"] = error_code - + self.override_error_code_in_rootcause_data(rootcause_error_file, rootcause_error, error_code) log.debug( f"child error file ({rootcause_error_file}) contents:\n" f"{json.dumps(rootcause_error, indent=2)}" diff --git a/torch/distributed/elastic/multiprocessing/errors/handlers.py b/torch/distributed/elastic/multiprocessing/errors/handlers.py index 09b59a05ddafc..3071aef171178 100644 --- a/torch/distributed/elastic/multiprocessing/errors/handlers.py +++ b/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -10,6 +10,7 @@ from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler +__all__ = ['get_error_handler'] def get_error_handler(): return ErrorHandler() diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 4f64565ec4dd8..ea5ee2f795e86 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -34,6 +34,8 @@ ) from .utils import _delay, _PeriodicTimer +__all__ = ['RendezvousBackend', 'RendezvousTimeout', 'RendezvousSettings', 'DynamicRendezvousHandler', 'create_handler'] + log = logging.getLogger(__name__) @@ -46,7 +48,6 @@ def get_method_name(depth=2): Token = Any """Represents an opaque fencing token used by the rendezvous backend.""" - class RendezvousBackend(ABC): """Represents a backend that holds the rendezvous state.""" diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index 37b144c5ffa18..51b322ef81904 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -8,6 +8,7 @@ from .api import rendezvous_handler_registry as handler_registry from .dynamic_rendezvous import create_handler +__all__ = ['get_rendezvous_handler'] def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: from . import static_tcp_rendezvous diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index b4687517d0367..14158e8bc708c 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -14,6 +14,7 @@ from threading import Event, Thread from typing import Any, Callable, Dict, Optional, Tuple, Union +__all__ = ['parse_rendezvous_endpoint'] def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: """Extracts key-value pairs from a rendezvous configuration string. diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index b031a066c623d..af0a0d955ef47 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -11,6 +11,7 @@ from inspect import getframeinfo, stack from typing import Any, Dict, List, Optional, Set +__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires'] class TimerRequest: """ diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 33777403684eb..56092f8f0df7d 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -13,6 +13,7 @@ from .api import RequestQueue, TimerClient, TimerRequest, TimerServer +__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer'] class LocalTimerClient(TimerClient): """ diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index a19c6cc8c5e74..e2b39c90fe58f 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -1,12 +1,13 @@ -from .flatten_params_wrapper import FlatParameter -from .fully_sharded_data_parallel import FullyShardedDataParallel +from .flat_param import FlatParameter from .fully_sharded_data_parallel import ( - CPUOffload, BackwardPrefetch, - ShardingStrategy, - MixedPrecision, + CPUOffload, FullStateDictConfig, + FullyShardedDataParallel, LocalStateDictConfig, + MixedPrecision, + OptimStateKeyType, + ShardingStrategy, + StateDictType, ) -from .fully_sharded_data_parallel import StateDictType, OptimStateKeyType from .wrap import ParamExecOrderWrapPolicy diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index b24ccdc5d8cc0..907c0f4b06867 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -8,6 +8,7 @@ List, NamedTuple, Optional, + Sequence, Tuple, Union, ) @@ -16,7 +17,10 @@ import torch.distributed as dist # Import the entire FSDP file to avoid circular imports import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP -from torch.distributed.fsdp.flatten_params_wrapper import FlatParameter +from torch.distributed.fsdp.flat_param import ( + FlatParameter, + FlatParamHandle, +) class _ConsolidatedOptimState: @@ -146,7 +150,7 @@ def _communicate_optim_state( # If the parameter is not sharded (e.g. world size of 1), then # neither is the positive-dimension tensor state, so no need to # communicate it -- we take the target rank's value - if not flat_param._is_sharded: + if not flat_param._is_sharded: # type: ignore[attr-defined] tensor_state[state_name] = value.cpu() continue if tensor_buffer is None: @@ -155,10 +159,9 @@ def _communicate_optim_state( buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] tensor_buffer = value.new_zeros(*buffer_size) dist._all_gather_base(tensor_buffer, value, group=group) + torch.cuda.synchronize() if to_save: - assert hasattr(flat_param, "_orig_size"), \ - "Sharded flattened parameter should have `_orig_size` set" - unpadded_numel = flat_param._orig_size.numel() # type: ignore[attr-defined] + unpadded_numel = flat_param._unsharded_size.numel() # type: ignore[attr-defined] tensor_state[state_name] = tensor_buffer[:unpadded_numel].cpu() # Zero-dimension tensor state and non-tensor state: take this rank's # value directly @@ -192,7 +195,7 @@ def _unflatten_communicated_optim_state( """ unflat_param_state: List[Dict[str, Any]] = [] flat_param_views: Dict[str, Iterator] = {} - num_unflat_params = flat_param._num_unflattened_params + num_unflat_params = flat_param._num_params tensor_state, zero_dim_tensor_state, non_tensor_state = \ state.tensor_state, state.zero_dim_tensor_state, state.non_tensor_state @@ -202,11 +205,11 @@ def _unflatten_communicated_optim_state( for state_name, flat_tensor in tensor_state.items(): views_generated = state_name in flat_param_views if not views_generated: - param_views = flat_param.get_param_views(flat_tensor) - flat_param_views[state_name] = param_views + views = FlatParamHandle._get_unflat_views(flat_param, flat_tensor) + flat_param_views[state_name] = views else: - param_views = flat_param_views[state_name] - unflat_state_param[state_name] = next(param_views) + views = flat_param_views[state_name] + unflat_state_param[state_name] = next(views) # Add zero-dimension tensor state: take the target rank's value for state_name, zero_dim_tensor in zero_dim_tensor_state.items(): unflat_state_param[state_name] = zero_dim_tensor @@ -306,7 +309,7 @@ def _flatten_optim_state( assert num_unflat_params > 0, \ "Expects at least one unflattened parameter corresponding to the " \ "flattened parameter" - unflat_param_shapes = flat_param._param_shapes + unflat_param_shapes = flat_param._shapes num_unflat_param_shapes = len(unflat_param_shapes) assert num_unflat_params == num_unflat_param_shapes, \ f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" @@ -373,7 +376,9 @@ def _flatten_optim_state( if shard_state: # Shard the flattened tensor immediately to minimize max memory # usage - sharded_flat_tensor, _ = fsdp_module._get_shard(flat_tensor) + sharded_flat_tensor, _ = FlatParamHandle._get_shard( + flat_tensor, fsdp_module.rank, fsdp_module.world_size, + ) flat_state[state_name] = sharded_flat_tensor else: flat_state[state_name] = flat_tensor @@ -394,7 +399,7 @@ def _flatten_tensor_optim_state( state_name: str, pos_dim_tensors: List[torch.Tensor], unflat_param_names: List[str], - unflat_param_shapes: List[torch.Size], + unflat_param_shapes: Sequence[torch.Size], flat_param: FlatParameter, ) -> torch.Tensor: """ @@ -463,7 +468,7 @@ def _flatten_tensor_optim_state( in zip(pos_dim_tensors, unflat_param_shapes) ] flat_tensor = torch.cat(tensors) - flat_param_shape = flat_param._orig_size # type: ignore[attr-defined] + flat_param_shape = flat_param._unsharded_size # type: ignore[attr-defined] assert flat_tensor.shape == flat_param_shape, \ f"tensor optim state: {flat_tensor.shape} " \ f"flattened parameter: {flat_param_shape}" @@ -589,11 +594,9 @@ def _process_pos_dim_tensor_state( no_tensor_osd["state"][key][state_name] = value continue if key.is_flat_param: # FSDP parameter - chunk, num_to_pad = FSDP.FullyShardedDataParallel._get_chunk( - value, rank=0, world_size=world_size, - ) - assert len(chunk.shape) == 1, f"Chunk should be 1D but got {chunk.shape}" - info = _PosDimTensorInfo(torch.Size([chunk.shape[0] + num_to_pad]), chunk.dtype) + sharded_size = FlatParamHandle._get_sharded_size(value, rank=0, world_size=world_size) + assert len(sharded_size) == 1, f"{sharded_size}" + info = _PosDimTensorInfo(sharded_size, value.dtype) else: # non-FSDP parameter info = _PosDimTensorInfo(value.shape, value.dtype) no_tensor_osd["state"][key][state_name] = info @@ -710,7 +713,7 @@ def _broadcast_sharded_pos_dim_tensor_state( assert unsharded_tensor is not None, \ "Expects rank 0 to pass in the unsharded tensor" get_shard = functools.partial( - FSDP.FullyShardedDataParallel._get_shard_functional, + FlatParamHandle._get_shard, unsharded_tensor, ) for target_rank in range(1, world_size): diff --git a/torch/distributed/fsdp/_symbolic_trace.py b/torch/distributed/fsdp/_symbolic_trace.py new file mode 100644 index 0000000000000..026595fd7def0 --- /dev/null +++ b/torch/distributed/fsdp/_symbolic_trace.py @@ -0,0 +1,243 @@ +import contextlib +import functools +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch + + +__all__ = ["TracingConfig"] + + +@dataclass +class TracingConfig: + """ + Configurations used in ``ParamExecOrderWrapPolicy`` for symbolic tracing of + a model. + + Args: + tracer (torch.fx.Tracer): An instance of ``torch.fx.Tracer`` that will + be used to perform symbolic tracing. ``tracer`` is default to be + ``torch.fx.Tracer()``, but can also be instance of some child class + of ``torch.fx.Tracer``. For example, one may want to use + ``HFTracer`` for models in Transformers: .. _Transformers: + https://huggingface.co/docs/transformers/index + concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should + not be treated as ``torch.fx.Proxy`` when tracing the forward + function. ``concrete_args`` allows one to partially specialize the + forward function, including removing control flow or data + structures. ``concrete_args`` is also the argument used in + :meth:`~torch.fx.Tracer.trace`. + """ + + tracer: torch.fx.Tracer = torch.fx.Tracer() + concrete_args: Optional[Dict[str, Any]] = None + + +@dataclass +class _ExecutionInfo: + """ + Contains the execution order information in the model forward pass. + + Attributes: + current_module: record the module that is currently being traced. + + module_forward_order: a list of modules, where the ordering is based on + when their forward function is called. ``module_forward_order`` + includes the info of how many times a module is called + used to + check the forward order in different iterations. + + param_exec_order: a list of parameters ordered based on their execution + order. + + module_to_execution_infos: a dict that maps each module to a list of + tuples each containing a module and a list of named parameters. + ``module_execution_info_dict`` is used as the parameter execution + order info. For a given module, each tuple: 1. either contains this + module and part of its ``named_parameters`` that will be executed + together, 2. or contains one of its child modules and all of the + child module's ``named_parameters``. The list of tuples is ordered + based on the parameter execution order. + """ + + current_module: torch.nn.Module + module_forward_order: List[torch.nn.Module] + module_to_execution_infos: Dict[ + torch.nn.Module, + List[Tuple[torch.nn.Module, List[Tuple[str, torch.nn.Parameter]]]], + ] + param_exec_order: List[torch.nn.Parameter] = field(default_factory=list) + + +def _init_execution_info(root_module: torch.nn.Module) -> _ExecutionInfo: + """ + Create an instance of _ExecutionInfo with initialization based on + ``root_module``. + + Args: + root_module (torch.nn.Module): the module to get the execution + information via ``tracer.trace()`` inside ``_patch_tracer``. + """ + return _ExecutionInfo( + current_module=root_module, + module_forward_order=[root_module], + module_to_execution_infos={root_module: []}, + ) + + +def _patched_create_proxy( + create_proxy: Callable, + execution_info: _ExecutionInfo, + prefixed_param_name_to_param: Dict[str, torch.nn.Parameter], + kind: str, + target: torch.fx.node.Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None, +) -> torch.fx.Proxy: + """ + Override of :meth:`~torch.fx.Tracer.create_proxy`. ``Tracer.create_proxy`` + is called in symbolic tracing for each leaf function/method/module. This + override intercepts the recording of each of these operations to update + ``execution_info.module_to_execution_infos``. + + Args: + create_proxy (Callable): + The ``create_proxy`` function to be patched. + execution_info (_ExecutionInfo): + Used to record the execution information. + prefixed_param_name_to_param (Dict[str, torch.nn.Parameter]): + A dict that maps each prefixed parameter name to the parameter. + kind (str): + The type of the target method. One of 'call_function', + 'call_method', 'get_attr', 'call_module', 'placeholder', or + 'output'. The semantics of these opcodes are described in the + ``torch.fx.Graph`` docstring. This is the input to ``create_proxy``. + target (torch.fx.node.Target): + Contains the string name of the method. This is the input to + ``create_proxy``. + args (Tuple[Any, ...]): + Arguments of the method. This is the input to ``create_proxy``. + kwargs (Dict[str, Any]): + Keyword arguments of the method. This is the input to + ``create_proxy``. + name (Optional[str]): + An optional string name for the ``Node`` created in + ``create_proxy``. This is the input to ``create_proxy``. + type_expr (Optional[Any]): + An optional type annotation representing the Python type the output + of a node will have. This is the input to ``create_proxy``. + proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): + An alternative proxy constructor used in ``create_proxy``. This is + the input to ``create_proxy``. + """ + proxy = create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + + module = execution_info.current_module + if kind in ["call_function", "call_method"]: + if args is not None: + named_params: List[Tuple[str, torch.nn.Parameter]] = [] + for arg in args: + if isinstance(arg, torch.fx.Proxy) and arg.node.target in prefixed_param_name_to_param: + param = prefixed_param_name_to_param[arg.node.target] + named_params.append((arg.node.target, param)) + if param not in set(execution_info.param_exec_order): + execution_info.param_exec_order.append(param) + if named_params: + execution_info.module_to_execution_infos[module].append((module, named_params)) + elif kind == "call_module": + named_params = list(module.named_parameters()) + if named_params: + execution_info.module_to_execution_infos[module].append( + (module, named_params) + ) + for (_, p) in named_params: + if p not in set(execution_info.param_exec_order): + execution_info.param_exec_order.append(p) + return proxy + + +def _patched_call_module( + call_module: Callable, + execution_info: _ExecutionInfo, + module: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """ + Override of :meth:`~torch.fx.Tracer.call_module`. ``Tracer.call_module`` is + called in symbolic tracing for each non-root module. This override + intercepts the recording of each operation to update + ``execution_info.module_forward_order`` and + ``execution_info.module_to_execution_infos``. + + Args: + call_module (Callable): + The ``call_module`` function to be patched. + execution_info (_ExecutionInfo): + Used to repord the execution information. + module (torch.nn.Module): + The module for which a call is being emitted. + forward (Callable[..., Any]): + The ``forward()`` method of the ``torch.nn.Module`` to be invoked. + args (Tuple[Any, ...]): + ``args`` of the module callsite. + kwargs (Dict[str, Any]): + ``kwargs`` of the module callsite. + """ + execution_info.module_forward_order.append(module) + named_params = list(module.named_parameters()) + if named_params: + execution_info.module_to_execution_infos[execution_info.current_module].append( + (module, list(module.named_parameters())) + ) + # Stores away current_module for restoration later + prev_current_module = execution_info.current_module + execution_info.current_module = module + # Note that if the forward of module is called multiple times, this will record + # the execution info of the last forward pass. + execution_info.module_to_execution_infos[module] = [] + output = call_module(module, forward, args, kwargs) + execution_info.current_module = prev_current_module + return output + + +@contextlib.contextmanager +def _patch_tracer( + tracer: torch.fx.Tracer, + root_module: torch.nn.Module, + execution_info: _ExecutionInfo, +) -> Generator: + """ + Within the context manager, patches the input tracer so that during + ``tracer.trace()``, the forward order of all modules and the parameter + execution information are recorded. The patches of the input tracer will be + removed after the context manager exits. + + Args: + tracer (torch.fx.Tracer): the input ``tracer`` whose member functions + will be patched within the context manager. + root_module (torch.nn.Module): the top-level module to be traced + and should not contain any FSDP modules. + execution_info (_ExecutionInfo): used to record the execution order + information when performing ``tracer.trace()`` within the context + manager. + """ + original_call_module = tracer.call_module + original_create_proxy = tracer.create_proxy + + tracer.call_module = functools.partial( + _patched_call_module, original_call_module, execution_info + ) + prefixed_param_name_to_param = dict(root_module.named_parameters()) + tracer.create_proxy = functools.partial( + _patched_create_proxy, original_create_proxy, execution_info, prefixed_param_name_to_param + ) + try: + yield + finally: + tracer.call_module = original_call_module + tracer.create_proxy = original_create_proxy diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py new file mode 100644 index 0000000000000..8d553391a77e1 --- /dev/null +++ b/torch/distributed/fsdp/flat_param.py @@ -0,0 +1,539 @@ +import contextlib +from itertools import accumulate, chain +from typing import ( + Dict, + Generator, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +__all__ = [ + "FlatParameter", "FlatParamHandle", "FlatParamShardMetadata", + "ParamInfo", "SharedParamInfo", +] + + +class ParamInfo(NamedTuple): + """Information for an original module parameter.""" + param_name: str # unprefixed + module: nn.Module + module_name: str + + +class SharedParamInfo(NamedTuple): + """ + Additional information for a shared parameter. + + For each shared parameter, we designate one module and its parameter + variable to be the primary owner, determined as the first one encountered + in the parameter walk. These are prefixed with "prim". The primary module + and parameter do not have their own :class:`SharedParamInfo` instance. + """ + param_name: str # unprefixed + module: nn.Module + module_name: str + prim_param_name: str # unprefixed + prim_module: nn.Module + prim_module_name: str + + +class FlatParamShardMetadata(NamedTuple): + """ + This holds metadata specific to this rank's shard of the flattened + parameter. + + Attributes: + param_names (Tuple[str, ...]): Prefixed parameter names of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_numels (Tuple[int, ...]): Parameter numels of this rank's shard + of the parameters; see :class:`FlatParameter`. + param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in + units of numels) giving this rank's part of each flattened + original module parameter. + """ + param_names: Tuple[str, ...] + param_shapes: Tuple[torch.Size, ...] + param_numels: Tuple[int, ...] + param_offsets: Tuple[Tuple[int, int], ...] + + +class FlatParameter(nn.Parameter): + """ + This is the flattened parameter used by :class:`FullyShardedDataParallel`. + It is comprised of one or more original parameters, which are flattened + and concatenated to construct the flattened parameter. + + Under the current design, this parameter logically represents both the + unsharded and sharded flattened parameter, and its data changes storages + dynamically. + - In the :class:`FullyShardedDataParallel` constructor, the parameter + is initialized as unsharded and then sharded in-place. + - At runtime, the parameter is lazily (re)-initialized. The sharded + parameter data is saved in ``self._local_shard``, and a new ``Tensor`` + ``self._full_param_padded`` is created, which is the all-gather + destination and owns the unsharded parameter storage thereafter. (See + :meth:`FullyShardedDataParallel._init_param_attributes`.) + - Throughout runtime, the parameter data changes storages as needed, + e.g. to the sharded flattened parameter, reduced-precision sharded + flattened parameter, or the unsharded flattened parameter. + + Attributes: + _is_sharded (bool): Whether the flattened parameter is *ever* sharded + across ranks (not whether it is *currently* sharded). + _unsharded_size (torch.Size): Unsharded flattened parameter's size. + + _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info + entry; see :class:`ParamInfo`. + _numels (Tuple[int, ...]): Each parameter's numel. + _shapes (Tuple[torch.Size, ...]): Each parameter's shape. + _prefixed_param_names (Tuple[str, ...]): Each parameter's name prefixed + with the parent module names starting from the module passed to + construct this flattened parameter via :class:`FlatParamHandle`; + the prefixed names are guaranteed to be unique within the subtree + rooted in that module. + _num_params (int): Number of original parameters flattened into this + flattened parameter; this is the length of ``_param_infos``, + ``_numels``, ``_shapes``, and ``_prefixed_param_names``. + _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter + info entries; see :class:`SharedParamInfo`. + + _shard_param_offsets (List[Tuple[int, int])): [start, end] offsets (in + units of numel) giving this rank's part of each flattened original + module parameter; for any parameter ``p`` that is not sharded + across ranks, this will be [0, ``p.numel()``-1]. + _shard_indices (Tuple[int, int]): [start, end] indices (in units of + parameters) for this rank's shard of the original model parameters, + where the parameters follow the order in which they were originally + flattened; this indexes appropriately into any data structure that + follows the flattening order (e.g. ``_param_infos``, ``_numels``, + etc.). + _shard_numel_padded (int): Numel padded for this rank's sharded + flattened parameter. + + _local_shard (Tensor): Sharded flattened parameter with padding. + _full_param_padded (Tensor): Unsharded flattened parameter with + padding. + _shard_bwd_hook (Tuple[AccumulateGrad, RemovableHandle]): Flattened + parameter's :class:`AccumulateGrad` object and post-backward hook + handle. + _mp_shard (Tensor): Reduced-precision flattened parameter with padding. + _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. + _saved_grad_shard (Tensor): Sharded gradient with padding from previous + iterations for gradient accumulation without :meth:`no_sync`. + """ + + def init_metadata( + self, + param_infos: List[ParamInfo], + numels: List[int], + shapes: List[torch.Size], + prefixed_param_names: List[str], + shared_param_infos: List[SharedParamInfo], + ) -> None: + """ + Initializes attributes holding metadata about the original parameters + comprising the flattened parameter. + + We expose this method separate from the constructor to keep the + constructor only responsible for the flattened parameter's tensor data. + This method should only be called once per model, while the constructor + may be called multiple times, e.g. when reloading from a checkpoint, in + which case only the tensor data needs to be passed to the constructor. + Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the + metadata is correctly assumed to be unchanged. + + Args: + See the Attributes in the class docstring. + """ + assert len(param_infos) == len(numels) + assert len(param_infos) == len(shapes) + assert len(param_infos) == len(prefixed_param_names) + self._num_params = len(param_infos) + self._param_infos = tuple(param_infos) + self._numels = tuple(numels) + self._shapes = tuple(shapes) + self._prefixed_param_names = tuple(prefixed_param_names) + self._shared_param_infos = tuple(shared_param_infos) + self._is_sharded = False + self._unsharded_size = self.size() + + +class FlatParamHandle: + """ + This handle manages a flattened parameter (:class:`FlatParameter`). + + Args: + params (Sequence[nn.Parameter]): The parameters to use for the + flattened parameter. + module (nn.Module): A module that is the root of the subtree containing + all parameters in ``params``; for non-recursive wrapping, this must + be the top-level module, while for recursive wrapping, this may not + necessarily be the top-level module. + """ + def __init__( + self, + params: Sequence[nn.Parameter], + module: nn.Module, + ) -> None: + super().__init__() + self._init_flat_param(module, params) + self._unflatten(as_params=False) + + def _init_flat_param( + self, + module: nn.Module, + params: Sequence[Optional[nn.Parameter]], + ) -> None: + """ + Initializes the flattened parameter ``self.flat_param`` by flattening + the parameters in ``params`` into a single :class:`FlatParameter` and + saves relevant metadata. Shared parameters are only included in the + flattened parameter once. + + This checks that all comprising parameters have the same dtype and + ``requires_grad`` and does not support nested construction of + :class:`FlatParameter` s. + + Args: + See the Args in the class docstring. + """ + params_set = set(params) + params_set.discard(None) + assert len(params_set) > 0, \ + "Cannot initialize a `FlatParameter` from an empty parameter list" + param_infos: List[ParamInfo] = [] + numels: List[int] = [] + shapes: List[torch.Size] = [] + prefixed_param_names: List[str] = [] + shared_param_infos: List[SharedParamInfo] = [] + shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str, str]] = {} + params_to_flatten: List[nn.Parameter] = [] + dtype: Optional[torch.dtype] = None + requires_grad: Optional[bool] = None + for submodule_name, submodule in module.named_modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if param not in params_set: + continue + if param in shared_param_memo: + prim_module, prim_module_name, prim_param_name = shared_param_memo[param] + shared_param_infos.append(SharedParamInfo( + param_name, submodule, submodule_name, prim_param_name, + prim_module, prim_module_name, + )) + else: + if isinstance(param, FlatParameter): + raise ValueError("`FlatParameter` does not support nesting") + if dtype is not None and param.dtype != dtype: + raise ValueError( + "`FlatParameter` requires uniform dtype but got " + f"{dtype} and {param.dtype}" + ) + if requires_grad is not None and param.requires_grad != requires_grad: + raise ValueError("`FlatParameter` requires uniform `requires_grad`") + dtype = param.dtype + requires_grad = param.requires_grad + shared_param_memo[param] = (submodule, submodule_name, param_name) + params_to_flatten.append(param) + param_infos.append(ParamInfo(param_name, submodule, submodule_name)) + numels.append(param.numel()) + shapes.append(param.shape) + prefixed_param_name = submodule_name + "." + param_name \ + if submodule_name else param_name + prefixed_param_names.append(prefixed_param_name) + assert requires_grad is not None + self.flat_param = FlatParamHandle.flatten_params(params_to_flatten, requires_grad) + self.flat_param.init_metadata( + param_infos, numels, shapes, prefixed_param_names, shared_param_infos, + ) + + @staticmethod + def flatten_params( + params: Sequence[torch.Tensor], + requires_grad: bool, + ) -> FlatParameter: + """ + Flattens the parameters in ``params`` into a single + :class:`FlatParameter`. This should be the only way used to construct + :class:`FlatParameter` s. + + We expose this factory method for checkpointing (e.g. sharded state + dict). The flattened parameter's metadata should only be initialized + once (see :meth:`init_metadata`), but its tensor data may be reloaded. + """ + with torch.no_grad(): + flat_params = [ + p.detach().reshape(-1) if isinstance(p, nn.Parameter) + else p.reshape(-1) for p in params + ] + flat_param_data = torch.cat(flat_params, dim=0) + flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad) + return flat_param + + @staticmethod + def _get_unflat_views( + flat_param: FlatParameter, + tensor: Optional[torch.Tensor] = None, + ) -> Iterator[Tensor]: + """ + Returns unflattened ``Tensor`` views into ``tensor`` if it is not + ``None`` or ``flat_param`` otherwise, where the unflattening is based + on ``flat_param`` 's metadata. + + In other words, to get views into the unsharded flattened parameter, + pass ``tensor`` as ``None``, but to get views into tensor optimizer + state, pass ``tensor`` as the optimizer state tensor. + """ + if tensor is None: + tensor = flat_param + assert tensor.numel() == flat_param._unsharded_size.numel(), \ + f"Expects {flat_param._unsharded_size.numel()} numel but got " \ + f"{tensor.numel()} numel" + views = ( + subtensor.view(shape) for (subtensor, shape) in + zip(torch.split(tensor, flat_param._numels, dim=0), flat_param._shapes) # type: ignore[arg-type] + ) + return views + + def _unflatten(self, as_params: bool) -> None: + """ + Unflattens the unsharded flattened parameter by setting the original + module parameter variables to be views into it. + + Args: + as_params (bool): If ``True``, then registers the original + parameters as ``nn.Parameter`` s; if ``False``, then registers + the original parameters only as ``Tensor`` s. ``False`` should + be used during forward/backward computation and when hiding the + original parameters from :meth:`nn.Module.named_parameters`. + """ + views = self._get_unflat_views(self.flat_param) + for view, (param_name, module, _) in zip(views, self.flat_param._param_infos): + if hasattr(module, param_name): + delattr(module, param_name) + if as_params: + module.register_parameter(param_name, nn.Parameter(view)) + else: + setattr(module, param_name, view) + for (param_name, module, _, prim_param_name, prim_module, _) in self.flat_param._shared_param_infos: + if hasattr(module, param_name): + delattr(module, param_name) + assert hasattr(prim_module, prim_param_name) + param: Union[Tensor, nn.Parameter] = getattr(prim_module, prim_param_name) + if as_params: + assert isinstance(param, nn.Parameter) + module.register_parameter(param_name, param) + else: + setattr(module, param_name, param) + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Assumes the flattened parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flattened parameter, and after the context, restores the original + parameters as ``Tensor`` views into the flattened parameter. + """ + self._unflatten(as_params=True) + try: + yield + finally: + self._unflatten(as_params=False) + + def init_shard_metadata( + self, + sharded_flat_param_numel: int, + numel_padded: int, + rank: int, + ) -> None: + """ + Initializes shard-related metadata for this rank's shard of the + flattened parameter: ``_shard_param_offsets``, ``_shard_indices``, and + ``_shard_numel_padded``. + + Args: + sharded_flat_param_numel (int): Numel of each rank's sharded + flattened parameter with padding (i.e. including + ``numel_padded``). + numel_padded (int): Numel padded for this rank's sharded flattened + parameter. + rank (int): Caller's rank. + """ + if numel_padded > sharded_flat_param_numel: + raise ValueError( + f"Sharded flattened parameter with {sharded_flat_param_numel} " + f"numel cannot have {numel_padded} numel padded" + ) + start = sharded_flat_param_numel * rank + end = sharded_flat_param_numel * (rank + 1) - 1 # inclusive + self.flat_param._shard_param_offsets, self.flat_param._shard_indices = ( # type: ignore[attr-defined] + self._get_shard_metadata(start, end) + ) + self.flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] + + def _get_shard_metadata( + self, + start: int, + end: int, + ) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: + """ + Computes the shard metadata based on ``start`` and ``end``, which give + the closed interval of the unsharded flattened parameter specifying the + shard. + + Args: + start (int): Start index (in units of numel) of this rank's shard + of the flattened parameter. + end (int): End index (in units of numel and inclusive) of this + rank's shard of the flattened parameter. + + Return: + Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: See + ``_shard_param_offsets`` and ``_shard_indices`` in + :class:`FlatParameter` 's docstring. + """ + flat_param_offsets = self._get_flat_param_offsets() + # Indices of the original parameters in this rank's sharded flattened + # parameter + shard_param_indices_range = [] # elements will be consecutive + # [start, end] offsets giving this rank's part of the flattened + # original module parameter (which will be [0, `p.numel()`-1] for any + # parameter that is not sharded across ranks) + shard_param_offsets = [] + for i, (param_start, param_end) in enumerate(flat_param_offsets): + if start > param_end or end < param_start: + continue + if start <= param_start: + intra_param_start = 0 + else: + intra_param_start = start - param_start + intra_param_end = min(param_end, end) - param_start + shard_param_indices_range.append(i) + shard_param_offsets.append((intra_param_start, intra_param_end)) # both inclusive + if len(shard_param_indices_range) == 0: + shard_param_indices = (0, 0) + assert len(shard_param_offsets) == 0 + else: + shard_param_indices = ( + shard_param_indices_range[0], shard_param_indices_range[-1], + ) + assert len(shard_param_offsets) == \ + shard_param_indices[-1] - shard_param_indices[0] + 1 + return tuple(shard_param_offsets), shard_param_indices + + @staticmethod + def _get_unpadded_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> Tuple[Tensor, int]: + """ + Returns the shard of ``tensor`` without any padding for the given + ``rank`` and ``world_size`` and the numel to pad for that shard. + + If ``tensor`` is already flattened or may be viewed in the flattened + shape (which is true in the expected usage), then this method does not + allocate any new tensor memory. + """ + chunks = torch.flatten(tensor).chunk(world_size) + if len(chunks) < (rank + 1): + # This rank gets an empty chunk fully padded with zeros since there + # are not enough chunks across ranks + chunk = chunks[0].new_empty(0) + else: + chunk = chunks[rank] + numel_to_pad = chunks[0].numel() - chunk.numel() + assert numel_to_pad >= 0, "Chunk's size should be at most the first chunk's size" + return chunk, numel_to_pad + + @staticmethod + def _get_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> Tuple[Tensor, int]: + """ + Returns the shard of ``tensor`` with padding for the given ``rank`` and + ``world_size`` and the numel padded for that shard. + + This method allocates new memory (via :meth:`clone`) since the + unsharded ``tensor`` may be deallocated after this method returns. + """ + chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(tensor, rank, world_size) + shard = chunk.clone() + if numel_to_pad > 0: + shard = F.pad(shard, [0, numel_to_pad]) + return shard, numel_to_pad + + @staticmethod + def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: + """ + Returns the shape of ``tensor`` after sharding including padding. This + requires ``tensor`` to have 1D shape and ensures that the returned + shape is 1D. + """ + assert len(tensor.shape) == 1, f"{tensor.shape}" + unpadded_sharded_tensor, numel_to_pad = ( + FlatParamHandle._get_unpadded_shard(tensor, rank, world_size) + ) + unpadded_sharded_size = unpadded_sharded_tensor.size() + assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}" + return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) + + def _get_flat_param_offsets(self) -> List[Tuple[int, int]]: + """Returns [start, end] offsets of each original parameter's flattened + data in the unsharded flattened parameter (without padding).""" + cumulative_sum = list(accumulate(self.flat_param._numels)) + starts = [0] + cumulative_sum[:-1] + ends = [end - 1 for end in cumulative_sum] # inclusive + param_offsets = list(zip(starts, ends)) + return param_offsets + + def shard_metadata( + self, + ) -> FlatParamShardMetadata: + """Returns shard-related metadata specific to this rank's shard of the + flattened parameter.""" + assert hasattr(self.flat_param, "_shard_indices") and \ + hasattr(self.flat_param, "_shard_param_offsets"), \ + "Shard metadata has not been initialized" + shard_param_start_index = self.flat_param._shard_indices[0] # type: ignore[attr-defined] + shard_param_end_index = self.flat_param._shard_indices[1] # type: ignore[attr-defined] + sl = slice(shard_param_start_index, shard_param_end_index + 1) \ + if shard_param_start_index <= shard_param_end_index else slice(0, 0) + return FlatParamShardMetadata( + self.flat_param._prefixed_param_names[sl], + self.flat_param._shapes[sl], + self.flat_param._numels[sl], + self.flat_param._shard_param_offsets[:], # type: ignore[attr-defined] + ) + + def _get_modules(self) -> Set[nn.Module]: + """Returns a :class:`set` of the modules whose parameters are included + in this handle's flattened parameter.""" + return set(pi.module for pi in self.flat_param._param_infos).union( + set(spi.module for spi in self.flat_param._shared_param_infos) + ) + + def parameter_module_names(self) -> Iterator[Tuple[str, str]]: + shared_param_infos = [ + ParamInfo(param_name, module, module_name) + for (param_name, module, module_name, _, _, _) + in self.flat_param._shared_param_infos + ] + for param_name, _, module_name in chain( + self.flat_param._param_infos, shared_param_infos + ): + yield (param_name, module_name) diff --git a/torch/distributed/fsdp/flatten_params_wrapper.py b/torch/distributed/fsdp/flatten_params_wrapper.py index 97e086fc8435f..a4eed12d40bf7 100644 --- a/torch/distributed/fsdp/flatten_params_wrapper.py +++ b/torch/distributed/fsdp/flatten_params_wrapper.py @@ -7,31 +7,18 @@ # Licensed under the MIT License. import contextlib -from itertools import accumulate -from typing import ( - Any, - Dict, - Generator, - Iterator, - List, - NamedTuple, - Optional, - Sequence, - Tuple, -) +from typing import Any, Dict, Generator, List -import torch import torch.nn as nn -from torch import Tensor - from torch.distributed.utils import _replace_by_prefix +from .flat_param import FlatParamHandle -ParamOffset = Tuple[int, int] -SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str] FLAT_PARAM = "flat_param" FPW_MODULE = "_fpw_module" +__all__ = ["FlattenParamsWrapper"] + def _post_state_dict_hook( module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any @@ -70,407 +57,100 @@ def _pre_load_state_dict_hook( _replace_by_prefix(state_dict, k, prefix + last_part) -class ParamInfo(NamedTuple): - module_name: str - module: nn.Module - param_name: str - - -class ShardMetadata(NamedTuple): - param_names: List[str] - param_shapes: List[torch.Size] - param_numels: List[int] - param_offsets: List[ParamOffset] - - -class FlatParameter(nn.Parameter): - """ - A parameter that is initialized from a list of parameters. All the - parameters will be flattened and concatened to form the flat parameter. - - Args: - params (Sequence[nn.Parameter]) - The parameters to be flattend and concatened. - requires_grad (bool): - Set to True if gradients need to be computed for this parameter, - False otherwise. - """ - - def __new__( - cls, params: Sequence[nn.Parameter], requires_grad: bool = True - ) -> "FlatParameter": - """Make an object using the parent's __new__ function.""" - - # A empty or non-list input doesn't make sense. - if not isinstance(params, (list, tuple)) or len(params) == 0: - raise ValueError("An non-empty list or tuple argument is needed") - - # Normally, all items are Parameters. But during pickling, we will have a single - # Tensor as the input and later in __init__, the correct _param_numels and _param_shapes - # are set. - if not all(isinstance(p, (nn.Parameter, Tensor)) for p in params): - incorrect_parameters = [ - p for p in params if not isinstance(p, (nn.Parameter, Tensor)) - ] - raise ValueError( - f"List items need to be Parameter types {incorrect_parameters}" - ) - - # Flattening involves (1) making a tensor flat (i.e. single dimensional) and - # (2) making a module hierarchy flat (using a single tensor to replace a tree of - # tensors). Therefore, adding back nesting and hierarchy is counter-productive. - # If nesting is encountered in the future, the reasonable thing to do is likely - # for the top level FlatParameter to absorb the nested one and keep the result flat, - # free from hierarchy. - if any(isinstance(p, FlatParameter) for p in params): - raise ValueError("Nesting FlatParameter is not supported") - - data = torch.cat( - [ - p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1) - for p in params - ], - 0, - ) - - return super(FlatParameter, cls).__new__( - cls, data, requires_grad=requires_grad - ) # type: ignore[call-arg] - - def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True): - self._is_sharded = False - self._param_numels = [p.numel() for p in params] - # The total element numbers. This is equal to the summation of the - # ``numel()`` of all the parameters. - self.full_numel = sum(self._param_numels) - assert self.numel() <= self.full_numel, ( - "Parameter numbers mismatched. " - f"The number of elements in FlatParameter: {self.numel()} vs. " - f"the number of elements in original parameters: {self.full_numel}." - ) - # The shapes of each individual parameter. - self._param_shapes = [p.size() for p in params] - cumulative_sum = list(accumulate(self._param_numels)) - begin = [0] + cumulative_sum[:-1] - end = [e - 1 for e in cumulative_sum] - - self._param_infos: List[ParamInfo] = [] - self._shared_param_infos: List[SharedParamInfo] = [] - - # The element offsets (begin/end pair) in the flat parameter of each - # individual parameter. - self._param_offsets = list(zip(begin, end)) - # The indices (begin/end pair) of the parameters that are included in - # this FlatParameter. The default value is all the parameters because - # no sharding happen yet. - self._param_indice_in_shard = (0, len(self._param_infos) - 1) - # The offsets in each parameter that is included in the FlatParameter. - self._sharded_param_offsets: List[ParamOffset] = [ - (0, numel) for numel in self._param_numels - ] - # The number of padding elements. - self.num_padded = 0 - - def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None: - assert self._is_sharded - if start < 0 or end < 0 or end < start: - raise ValueError( - f"Shard the flatten parameter with an invalid offset pair {(start, end)}." - ) - _shard_size = end - start + 1 - self.num_padded = num_padded - if self.num_padded > _shard_size: - raise ValueError("The number of padding is larger than the shard size.") - self._sharded_param_offsets.clear() - - ranges = [] - for idx, offset in enumerate(self._param_offsets): - if start > offset[1] or end < offset[0]: - continue - if start <= offset[0]: - sharded_param_start = 0 - sharded_param_end = min(offset[1], end) - offset[0] - else: - sharded_param_start = start - offset[0] - sharded_param_end = min(offset[1], end) - offset[0] - ranges.append(idx) - self._sharded_param_offsets.append((sharded_param_start, sharded_param_end)) - if ranges: - self._param_indice_in_shard = (ranges[0], ranges[-1]) - - def _offset_to_slice(self) -> slice: - if self._param_indice_in_shard[0] > self._param_indice_in_shard[1]: - return slice(0, 0) - return slice(self._param_indice_in_shard[0], self._param_indice_in_shard[1] + 1) - - def get_param_views( - self, external_data: Optional[Tensor] = None - ) -> Iterator[Tensor]: - """Return a generator of views that map to the original parameters.""" - # Note, self.data could be sharded, so its numel is <= to the sum. - assert ( - self.data.numel() <= self.full_numel - ), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}" - data = external_data if external_data is not None else self - if data.numel() != self.full_numel: - raise ValueError( - f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}" - ) - return ( - t.view(s) - for (t, s) in zip(data.split(self._param_numels), self._param_shapes) - ) - - @property - def _num_unflattened_params(self) -> int: - """Returns the number of unflattened parameters that comprise this - flattened parameter.""" - assert hasattr(self, "_param_infos"), \ - "`_param_infos` has not been set, meaning this `FlatParameter` " \ - "has not been initialized yet" - num_unflat_params = len(self._param_infos) - assert num_unflat_params > 0, "`FlatParameter` corresponding to 0 " \ - "unflattened parameters" - return num_unflat_params - - @property - def param_info(self) -> List[ParamInfo]: - return self._param_infos - - @property - def _param_names(self): - return [".".join([m, n]) if m else n for (m, _, n) in self._param_infos] - - def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]: - """Return tuple of (names, shapes, numels) metadata for this flat parameter.""" - return self._param_names, self._param_shapes, self._param_numels - - def shard_metadata( - self, - ) -> ShardMetadata: - """ - Return tuple of (names, shapes, numels) metadata for the sharded parameter - metadata of this flat parameter. - """ - return ShardMetadata( - self._param_names[self._offset_to_slice()], - self._param_shapes[self._offset_to_slice()], - self._param_numels[self._offset_to_slice()], - self._sharded_param_offsets[:], - ) - - class FlattenParamsWrapper(nn.Module): """ - A wrapper for transparently flattening a Module's parameters. - The original implementation [1] reparameterizes a PyTorch module - that is called ReparamModule. The ReparamModule has only a flattened - parameter representing all parameters of the wrapped module. - Compared to the original implementation [1], this version: - - removes tracing - - supports shared parameters - - is renamed to FlattenParamsWrapper + This is a wrapper for flattening parameters in a ``nn.Module`` 's subtree + into a single flattened parameter and is based on [1]. This is used for + :class:`FullyShardedDataParallel` 's recursive wrapping. [1] https://github.com/SsnL/PyTorch-Reparam-Module + Args: - module (nn.Module): - The module to wrap. - param_list (List[nn.Parameter]): - Only flatten parameters appearing in the given list. - Note, if only a single param is in the list, it still gets - flattened and the original param is removed and replaced - with the flatten one. + module (nn.Module): Module to wrap. + params (List[nn.Parameter]): Parameters in ``module`` 's subtree to + flatten into a single flattened parameter. + + Attributes: + flat_param (Optional[FlatParameter]): The flattened parameter. + ``flat_param`` is ``None`` either when (1) this wrapper manages no + parameters or (2) the wrapped module's parameters are unflattened. + _fpw_module (nn.Module): The wrapped module. + _flat_param_handle (FlatParamHandle): A handle for the flattened + parameter; only present if this wrapper manages parameters. """ - - def __init__(self, module: nn.Module, param_list: List[nn.Parameter]): + def __init__( + self, + module: nn.Module, + params: List[nn.Parameter], + ) -> None: super().__init__() self._fpw_module = module - # People may test whether this module contains parameters by using - # `getattr(module, "flat_param") is None`. This is not always accurate - # as the above condition is also true if this module is unflattened. - # `no_params` explicitly shows this module has no parameters and - # is always correct regardless flattened or unflattened. - self.no_params = True self.flat_param = None - - # Register hook to be called after state_dict() to remove the - # "_fpw_module." prefix and before load_state_dict() to add it back. - # The hooks must be registered even if the target param_list is empty as - # all submodules in FlattenParamsWrapper should be pre/post processed by - # the hooks. + # Register hooks to clean parameter names for state dict (even if this + # wrapper itself manages no parameters since it must clean names from + # submodules) self._register_state_dict_hook(_post_state_dict_hook) self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook) - - if len(param_list) == 0: + if len(params) == 0: return - - # A list of parameters to be flatten - unique_param_list = set(param_list) - self.no_params = False - - # convert from list of Parameters to set of (Module, parameter_name) tuples, which - # will survive in case the Parameter instances are reset. - # it includes (m, n) that points to the same parameter. - self.param_set = set() - for m in self.modules(): - for n, p in m.named_parameters(recurse=False): - if p in unique_param_list: - self.param_set.add((m, n)) - - params, param_infos, shared_param_infos = self._init_flatten_params() - self.flat_param = FlatParameter(params, params[0].requires_grad) - self.flat_param._param_infos = param_infos - self.flat_param._shared_param_infos = shared_param_infos - - # This attribute is used to remember the flat_param inside the unflatten_params() - # context. With this attribute, FSDP can access the flat parameter metadata - # even if flat_param is temporarily deleted. - # ``orig_flat_param` is a list to avoid being tracked by ``state_dict()``. - self.orig_flat_param: List[Optional[FlatParameter]] = [None] - self._flatten_params() - - # Sanity check for the string constants. + self._flat_param_handle = FlatParamHandle(params, module) + # Defining `self.flat_param` registers the `FlatParameter` and makes it + # visible to `named_parameters()` + self.flat_param = self._flat_param_handle.flat_param assert getattr(self, FPW_MODULE) is self._fpw_module assert getattr(self, FLAT_PARAM) is self.flat_param @property - def module(self) -> Any: - """Support _fsdp_wrapped_module.module in case we are immitating DDP, which has .module - property to the underlying module. - """ - return self._fpw_module - - def _init_flatten_params( - self, - ) -> Tuple[List[nn.Parameter], List[ParamInfo], List[SharedParamInfo]]: - """Build metadata for need-to-be-flatten parameters and returns a list - contains the need-to-be-flatten parameters. - This also fills param_infos and shared_param_infos. - """ - param_infos: List[ParamInfo] = [] - shared_param_infos = [] - shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {} - params = [] - for module_name, m in self.named_modules(): - for n, p in m.named_parameters(recurse=False): - if p is not None and (m, n) in self.param_set: - if p in shared_param_memo: - mname, shared_m, shared_n = shared_param_memo[p] - shared_param_infos.append( - (module_name, mname, m, n, shared_m, shared_n) - ) - else: - shared_param_memo[p] = (module_name, m, n) - param_infos.append(ParamInfo(module_name, m, n)) - params.append(p) - del shared_param_memo - - assert ( - len(set(p.dtype for p in params)) == 1 - ), "expects all parameters to have same dtype" - assert ( - len(set(p.requires_grad for p in params)) == 1 - ), "expects all parameters to have same requires_grad" - assert len(params) == len(set(params)), "params list should not have dups" + def has_params(self) -> bool: + """Returns whether this wrapper manages any parameters.""" + return hasattr(self, "_flat_param_handle") - return params, param_infos, shared_param_infos - - def _flatten_params(self, external_data: Optional[FlatParameter] = None) -> None: - """Flatten the managed parameters and replaced the original - attributes with views to the flat param. If `external_data` - is passed, it will be used as the flat_param. - """ - # register the flatten one - assert ( - getattr(self, "flat_param", None) is not None or external_data is not None - ), "Can not flatten params when both flat_param and external_data are None." - if external_data is not None: - self.flat_param = external_data - self.register_parameter("flat_param", self.flat_param) - - assert self.flat_param is not None # avoid mypy complain. - # deregister the names as parameters - for _, m, n in self.flat_param._param_infos: - delattr(m, n) - for _, _, m, n, _, _ in self.flat_param._shared_param_infos: - delattr(m, n) - - # register the views as plain attributes - self._unflatten_params_as_views() - - def _unflatten_params_as_views(self) -> None: - """Unlike ``_unflatten_params``, this function unflatten into views and keep - self.flat_param unchanged. - """ - assert ( - self.flat_param is not None - ), "Can not unflatten params as views when flat_param is None." - ps = self._get_param_views() - for (_, m, n), p in zip(self.flat_param._param_infos, ps): - setattr(m, n, p) # This will set as plain attr - - for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos: - setattr(m, n, getattr(shared_m, shared_n)) - - def _unflatten_params(self) -> None: - """Undo flattening and create separate parameters from the already flattened - self.flat_param. - """ - assert ( - self.flat_param is not None - ), "Can not unflatten params when flat_param is None." - ps = self._get_param_views() - for (_, m, n), p in zip(self.flat_param._param_infos, ps): - if hasattr(m, n): - delattr(m, n) - m.register_parameter(n, nn.Parameter(p)) - for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos: - if hasattr(m, n): - delattr(m, n) - m.register_parameter(n, getattr(shared_m, shared_n)) + @property + def handle(self) -> FlatParamHandle: + assert hasattr(self, "_flat_param_handle"), \ + "Accessing the handle of a `FlattenParamsWrapper` that does not " \ + "manage any parameters" + return self._flat_param_handle - del self.flat_param + @property + def module(self) -> Any: + """Returns the wrapped module (like DDP).""" + return self._fpw_module @contextlib.contextmanager - def unflatten_params(self) -> Generator: + def unflatten_as_params(self) -> Generator: """ - Unflatten params. If the current instance is already unflattened, then - it will remain unflattened after the context manager exits. + Assumes that the flattened parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flattened parameter and de-registers the flattened parameter. After the + context, restores the original parameters as ``Tensor`` views into the + flattened parameter and re-registers the flattened parameter. """ if getattr(self, "flat_param", None) is None: yield else: - self.orig_flat_param[0] = self.flat_param - self._unflatten_params() - # Put yield in a try...finally in case the caller catches the exception and handles - # it. In that case, we need to properly handle the undoing of state here. + # De-register the `FlatParameter` from this wrapper to hide it from + # `named_parameters()` (though it still exists in memory) + del self.flat_param try: - yield + with self._flat_param_handle.unflatten_as_params(): + yield finally: - self._flatten_params(self.orig_flat_param[0]) - self.orig_flat_param[0] = None - - def _get_param_views( - self, external_data: Optional[Tensor] = None - ) -> Iterator[Tensor]: - """Return a generator of views that map to the original parameters.""" - assert self.flat_param is not None - return self.flat_param.get_param_views(external_data) + # Re-register the `FlatParameter` + self.flat_param = self._flat_param_handle.flat_param def __getattr__(self, name: str) -> Any: - """Forward missing attributes to wrapped module.""" + """Forward missing attributes of this wrapper to the wrapped module.""" try: - return super().__getattr__(name) # defer to nn.Module's logic + return super().__getattr__(name) # defer to `nn.Module`'s logic except AttributeError: - return getattr(self.module, name) # fallback to wrapped module + return getattr(self.module, name) # fall back to the wrapped module def __getitem__(self, key: int) -> Any: - """Forward indexing calls in case the module is a nn.Sequential.""" + """Forward indexing calls to the wrapped module in case the wrapped + module is an ``nn.Sequential``.""" return self.module.__getitem__(key) - def _unflatten_params_if_needed(self) -> None: - if self.flat_param is not None: - self._unflatten_params_as_views() - def forward(self, *inputs: Any, **kwinputs: Any) -> Any: - self._unflatten_params_if_needed() + if self.flat_param is not None: + self._flat_param_handle._unflatten(as_params=False) return self.module(*inputs, **kwinputs) diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index da4c9f9f3536f..470b189894941 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -28,6 +28,7 @@ import torch import torch.distributed as dist +import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable @@ -37,6 +38,13 @@ ShardedTensor, init_from_local_shards, ) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.algorithms._comm_hooks import ( + LOW_PRECISION_HOOKS, + default_hooks, +) from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.utils import ( _replace_by_prefix, @@ -63,17 +71,17 @@ _contains_batchnorm, _override_batchnorm_mixed_precision, ) +from .flat_param import FlatParameter, FlatParamHandle from .flatten_params_wrapper import ( FLAT_PARAM, FPW_MODULE, - FlatParameter, FlattenParamsWrapper, ) from .wrap import ( + ParamExecOrderWrapPolicy, _or_policy, _recursive_wrap, _wrap_batchnorm_individually, - ParamExecOrderWrapPolicy ) _TORCHDISTX_AVAIL = True @@ -82,6 +90,16 @@ except ImportError: _TORCHDISTX_AVAIL = False +_TORCH_FX_AVAIL = True +if not hasattr(torch, "fx"): + _TORCH_FX_AVAIL = False +if _TORCH_FX_AVAIL: + from ._symbolic_trace import ( + TracingConfig, + _init_execution_info, + _patch_tracer, + ) + __all__ = [ "FullyShardedDataParallel", "ShardingStrategy", "MixedPrecision", @@ -96,7 +114,6 @@ _PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024) - def _default_meta_device_init_fn(module): """ Default initializer for modules initialized on the meta device. @@ -116,23 +133,31 @@ def _default_meta_device_init_fn(module): class ShardingStrategy(Enum): """ - Specify which sharding strategy will be used for the distributed training. - FULL_SHARD: Shards parameters, gradients and optimizer states. This algorithm - inserts ``all_gather`` before forward and backward computation to gather - parameters, also inserts ``reduce_scatter`` after backward computation for - synchronizing and sharding gradients. Sharded optimizer states are + This specifies the sharding strategy to be used for distributed training by + :class:`FullyShardedDataParallel`. + FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For + the parameters, this algorithm all-gathers before the forward, + reshards after the forward, all-gathers before the backward + computation, and reshards after the backward computation. The + gradients are synchronized and sharded via reduce-scatter after + the backward computation. The sharded optimizer states are updated locally. - SHARD_GRAD_OP: Shard optimizer states and gradients, this algorithm inserts all_gather - before forward computation and keeps the full parameters in - GPU memory until backward computation is done. It inserts reduce_scater - after backward computation for synchronizing and sharding gradients. - Sharded optimizer states are updated locally. - NO_SHARD: This is similar to PyTorch ``DistributedDataParallel`` API. Parameters, gradients - and optimizer states are replicated among ranks, ``all_reduce`` is inserted after - backward computation is done for synchronizing gradients. Full optimizer states - are updated in each rank. - HYBRID_SHARD(future support): apply FULL_SHARD algorithm in the intra node and - apply NO_SHARD algorithm in the inter nodes. + SHARD_GRAD_OP: Gradients and optimizer states are sharded during + computation, and additionally parameters are sharded outside + computation. For the parameters, this algorithm all-gathers + before the forward, does not reshard after the forward, and + only reshards after the backward computation. The gradients + are synchronized and sharded via reduce-scatter after the + backward computation. The sharded optimizer states are + updated locally. Inside ``no_sync()``, the parameters are + not resharded after the backward computation. + NO_SHARD: Parameters, gradients, and optimizer states are not sharded but + instead replicated across ranks, similar to PyTorch's + ``DistributedDataParallel`` API. The gradients are synchronized + via all-reduce after the backward computation. The unsharded + optimizer states are updated locally. + HYBRID_SHARD(future support): Apply ``FULL_SHARD`` intra-node and + ``NO_SHARD`` inter-node. """ FULL_SHARD = auto() @@ -627,7 +652,6 @@ class FullyShardedDataParallel(nn.Module): :class:`FullStateDictConfig` for an example of this. (Default: ``False``) """ - def __init__( self, module: nn.Module, @@ -649,7 +673,7 @@ def __init__( process_group=process_group, sharding_strategy=sharding_strategy, cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy.init_policy, + auto_wrap_policy=auto_wrap_policy, backward_prefetch=backward_prefetch, mixed_precision=mixed_precision, ignored_modules=ignored_modules, @@ -661,6 +685,7 @@ def __init__( torch._C._log_api_usage_once("torch.distributed.fsdp") super().__init__() + self._handles: List[FlatParamHandle] = [] # Validate the ignored modules and derive the ignored parameters/buffers ignored_modules = self._get_ignored_modules(module, ignored_modules) self._ignored_modules = ignored_modules @@ -812,7 +837,6 @@ def _run_param_init_fn(): self.world_size / self.gradient_predivide_factor ) - self.numel_padded_per_param: List[int] = [] self.cpu_offload = cpu_offload or CPUOffload() self.backward_prefetch = backward_prefetch self.forward_prefetch = forward_prefetch @@ -857,20 +881,19 @@ def _run_param_init_fn(): src=0, ) - self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper( - module, param_list=params - ) + self._fsdp_wrapped_module = FlattenParamsWrapper(module, params) assert getattr(self, FSDP_WRAPPED_MODULE) is self._fsdp_wrapped_module - del module # free original module in case it helps garbage collection - if self._fsdp_wrapped_module.flat_param is not None: - self.params = [self._fsdp_wrapped_module.flat_param] - else: - self.params = [] + self.params = [] + if self._fsdp_wrapped_module.has_params: + self.params.append(self._fsdp_wrapped_module.flat_param) + self._register_param_handle(self._fsdp_wrapped_module.handle) # Shard module parameters in place self._shard_parameters() - # Make sure all parameters are sharded. + # Check that the sharding logic was applied to all parameters by + # checking that the original module parameters have been replaced by + # `Tensor` views and are no longer `nn.Parameter`s for n, p in self.named_parameters(): if p not in ignored_params and not isinstance(p, FlatParameter): raise RuntimeError( @@ -927,23 +950,80 @@ def _run_param_init_fn(): # For validating execution order across ranks self._exec_order_data = _ExecOrderData() + # setting communication hook to a default + self.communication_hook = self._get_default_comm_hook() + self.communication_hook_state = self._get_default_comm_hook_state() + self._hook_registered = False + def _init_param_exec_order_wrap_policy(self, *args, **kwargs) -> None: + auto_wrap_policy = kwargs["auto_wrap_policy"] + module = kwargs["module"] + assert hasattr(auto_wrap_policy, "tracing_config") + if not _TORCH_FX_AVAIL: + assert ( + auto_wrap_policy.tracing_config is None + ), "tracing_config should be None when torch.fx is not enabled" + elif isinstance( + auto_wrap_policy.tracing_config, + TracingConfig + ): + tracer = auto_wrap_policy.tracing_config.tracer + execution_info = _init_execution_info(module) + + for m in module.modules(): + assert not isinstance( + m, FullyShardedDataParallel + ), "The input module of _patch_tracer should not contain FSDP modules" + + with _patch_tracer( + tracer=tracer, + root_module=module, + execution_info=execution_info, + ): + try: + tracer.trace(module, auto_wrap_policy.tracing_config.concrete_args) + except BaseException as e: + raise RuntimeError( + "tracer.trace failed inside _init_param_exec_order_wrap_policy" + f" with the error: {e}." + ) + else: + assert ( + auto_wrap_policy.tracing_config is None + ), "tracing_config should either be an instance of TracingConfig or be None" # The initial FSDP wrapping is done with auto_wrap_policy.init_policy + kwargs["auto_wrap_policy"] = auto_wrap_policy.init_policy self.__init__(*args, **kwargs) self._param_exec_order_policy: bool = True - # self._param_exec_order_prep_stage is set to True in the first forward and backward iteration, - # and set to False otherwise. + # self._param_exec_order_prep_stage is set to True before we get the execution order self._param_exec_order_prep_stage: bool = True # A list that stores the flatten parameters and its name based on the parameter execution order self._fsdp_params_exec_order: List[FlatParameter] = [] + if _TORCH_FX_AVAIL and isinstance( + auto_wrap_policy.tracing_config, + TracingConfig + ): + # Initialize a dict that maps each module to its parent FSDP wrap + module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict() + for wrap in self.fsdp_modules(self): + module_to_fsdp[wrap.module] = wrap + # Set self._fsdp_params_exec_order based on execution_info.module_forward_order. + # TODO (linjianma): self._fsdp_params_exec_order will be set based on + # the parameter execution order rather than module_forward_order, + # once the non-recursive wrapping policy is fully implemented. + for m in execution_info.module_forward_order: + if m in module_to_fsdp: + for flat_param in module_to_fsdp[m].params: + self._fsdp_params_exec_order.append(flat_param) + self._param_exec_order_prep_stage = False + for m in self.modules(): if m is not self and isinstance(m, FullyShardedDataParallel): # Assignment by reference, so each children FSDP wrap has access to # the _fsdp_params_exec_order of the root module m._fsdp_params_exec_order = self._fsdp_params_exec_order - m._param_exec_order_policy = True - m._param_exec_order_prep_stage = True - + m._param_exec_order_policy = self._param_exec_order_policy + m._param_exec_order_prep_stage = self._param_exec_order_prep_stage def _move_module_if_needed(self, module) -> None: """ @@ -1133,6 +1213,11 @@ def _check_wrapped(cls, begin_module, check_fn, err_fn): if not check_fn(mod): raise ValueError(err_fn(mod)) + def _register_param_handle(self, handle: FlatParamHandle) -> None: + """Registers the parameter handle to this FSDP instance.""" + if handle not in self._handles: + self._handles.append(handle) + @property def module(self) -> nn.Module: """Make model.module accessible, just like DDP. Return the @@ -1266,6 +1351,15 @@ def _mixed_precision_enabled_for_reduce(self) -> bool: and self.mixed_precision.reduce_dtype is not None ) + def _low_precision_hook_enabled(self) -> bool: + """ + Wether a low precision hook is registered or not. + """ + return ( + self.communication_hook is not None + and self.communication_hook in LOW_PRECISION_HOOKS + ) + def _cast_fp_inputs_to_precision( self, dtype: torch.dtype, *args: Any, **kwargs: Any ) -> Tuple[Any, Any]: @@ -1409,7 +1503,8 @@ def _shard_parameters(self) -> None: allocate less memory for optimizer state, avoiding redundancy across data parallel workers. """ - for p in self.params: + for handle in self._handles: + p = handle.flat_param assert not p._is_sharded, "Param should have not been sharded yet." assert ( p.is_floating_point() @@ -1421,10 +1516,8 @@ def _shard_parameters(self) -> None: self.world_size > 1 and self.sharding_strategy != ShardingStrategy.NO_SHARD ) - p._orig_size = p.size() # type: ignore[attr-defined] if not p._is_sharded: # type: ignore[attr-defined] - self.numel_padded_per_param.append(0) continue # Save the original storage and free it later on. @@ -1436,77 +1529,14 @@ def _shard_parameters(self) -> None: orig_storage = p.storage() # Replace p with the relevant shard. - local_shard, num_padded = self._get_shard(p) + local_shard, numel_padded = FlatParamHandle._get_shard(p, self.rank, self.world_size) p.set_(local_shard) # type: ignore[call-overload] - p.shard_by_offsets( - self.rank * local_shard.numel(), - (self.rank + 1) * local_shard.numel() - 1, - num_padded, - ) - self.numel_padded_per_param.append(num_padded) + handle.init_shard_metadata(local_shard.numel(), numel_padded, self.rank) # Free storage that contains the original full data. if orig_storage.size() > 0: orig_storage.resize_(0) # type: ignore[attr-defined] - assert len(self.numel_padded_per_param) == len( - self.params - ), "numel_padded_per_param is not populated correctly." - - @staticmethod - def _get_chunk( - tensor: torch.Tensor, - rank: int, - world_size: int, - ) -> Tuple[torch.Tensor, int]: - """Returns the unpadded chunk as a view and the number of padding - elements of a full tensor for the given rank and world size.""" - # Shard using `torch.chunk()` to match all-gather/reduce-scatter. - chunks = torch.flatten(tensor).chunk(world_size) - if len(chunks) < (rank + 1): - # If there are not enough chunks to shard across ranks, create an - # empty chunk that will just be padded with zeros to be the - # appropriate size. - chunk = chunks[0].new_empty(0) - else: - chunk = chunks[rank] - # Determine number of padding elements. - num_to_pad = chunks[0].numel() - chunk.numel() - assert num_to_pad >= 0, \ - "Chunk's size should at most the first chunk's size" - return chunk, num_to_pad - - @staticmethod - def _get_shard_functional( - tensor: torch.Tensor, - rank: int, - world_size: int, - ) -> Tuple[torch.Tensor, int]: - """Functional version of :meth:`_get_shard`.""" - chunk, num_to_pad = FullyShardedDataParallel._get_chunk( - tensor, rank, world_size, - ) - # We always need to clone here regardless of the padding and even - # though `chunk` is a view of `tensor` because `tensor` may be - # deallocated after this method returns - shard = chunk.clone() - if num_to_pad > 0: - shard = F.pad(shard, [0, num_to_pad]) - return shard, num_to_pad - - def _get_shard( - self, - tensor: torch.Tensor, - rank: Optional[int] = None, - ) -> Tuple[torch.Tensor, int]: - """Returns the local shard and the number of padding elements of a full - tensor for the calling rank if ``rank=None`` or for the rank ``rank`` - if not ``None``.""" - rank = self.rank if rank is None else rank - return FullyShardedDataParallel._get_shard_functional( - tensor, rank, self.world_size, - ) - def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" try: @@ -1521,7 +1551,6 @@ def __getitem__(self, key: int) -> Any: def _reset_lazy_init(self) -> None: """ Reset instance so :func:`_lazy_init` will run on the next forward. - Currently this is only called in __init__ """ self._is_root: Optional[bool] = None self._streams: Dict[str, torch.cuda.Stream] = {} @@ -1539,50 +1568,57 @@ def _reset_lazy_init(self) -> None: self._init_reshard_after_forward() def _lazy_init(self) -> None: - """Initialization steps that should happen lazily, typically right - before the first forward pass. """ - # Initialize param attributes lazily, in case the param's dtype or - # device changes after __init__. - for p in self.params: - self._init_param_attributes(p) - - # Initialize _is_root and setup streams. These steps would ideally - # happen in __init__, but _is_root can only be determined after the - # entire model hierarchy is setup, thus we run it lazily. - if self._is_root is None: - # _is_root means that we are in the outermost module's forward. - self._set_is_root() - self._setup_streams() - - if self._is_root: - # Buffers stay on GPU, and don't get sharded. Since _cast_buffers - # applies recursively, we only call this from the root instance. - self._cast_buffers(recurse=True) - - # Don't free the full params for the outer-most (root) instance, - # In most cases, root instance contains params in the last layers - # or has no params. In these cases, those params will be needed - # immediately after for the backward pass. Note that this only - # applies currently when freeing parameters at end of layer's - # forward pass. - self.reshard_after_forward = False + Performs initialization lazily, typically right before the first + forward pass. The laziness is needed to ensure that the parameter + device/dtype and the FSDP hierarchy have finalized. - # Due to the use of streams, we need to make sure the previous - # ``optim.step()`` is done before we all-gather parameters. - self._wait_for_previous_optim_step() + This method's actual logic only runs on the root FSDP instance, which + performs initialization for all non-root FSDP instances to avoid + partial initialization. + """ + if self._is_root is not None: + return # no-op: already initialized + # The following logic is only run on the root FSDP instance + self._is_root = True + self._assert_state(TrainingState_.IDLE) + self._init_streams() + self._cast_buffers(recurse=True) + for param in self.params: + self._init_param_attributes(param) + # Do not reshard the root's parameters at the end of the forward pass + # with the intention that they are immediately used in the backward + # pass gradient computation (though this may not be true) + self.reshard_after_forward = False + self._exec_order_data.init(self) + # Initialize non-root FSDP instances and share attributes from the root + # to non-root instances (e.g. streams for overlapping) + for fsdp_module in self.fsdp_modules(self): + if fsdp_module is not self: + # Relax the assert for non-root FSDP instances in case the + # nested initialized module is wrapped again in FSDP later (e.g. + # after training to run inference) + assert fsdp_module._is_root is None or not fsdp_module._is_root, ( + "Non-root FSDP instance's `_is_root` should not have been " + "set yet or should have been set to `False`" + ) + fsdp_module._is_root = False + fsdp_module._streams = self._streams + fsdp_module._fsdp_graph_order = self._fsdp_graph_order + fsdp_module._exec_order_data = self._exec_order_data + for param in fsdp_module.params: + fsdp_module._init_param_attributes(param) @torch.no_grad() - def _init_param_attributes(self, p: Parameter) -> None: + def _init_param_attributes(self, p: FlatParameter) -> None: """ - We manage several attributes on each Parameter instance. The first two - are set by :func:`_shard_parameters`: + We manage several attributes on each Parameter instance. The first is + set by :func:`_shard_parameters`: ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` if the Parameter is intentionally not sharded (in which case we will all-reduce grads for this param). Currently the way `_is_sharded = False` is if world_size = 1 or sharding strategy is NO_SHARD. - ``_orig_size``: the size of the original Parameter (before sharding) A few attributes are set here: ``_local_shard``: a single shard of the parameter. This is needed to recover the shard after rebuilding full parameter in forward @@ -1596,9 +1632,7 @@ def _init_param_attributes(self, p: Parameter) -> None: ``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object and the registered post hook handle. """ - assert hasattr(p, "_is_sharded") and hasattr( - p, "_orig_size" - ), "Parameters should have been sharded during construction." + assert hasattr(p, "_is_sharded"), "Parameters should have been sharded during construction." # If _local_shard has been set in the first lazy init and # current parameter is pointed to _local_shard, no need to # set the _local_shard again. @@ -1674,38 +1708,14 @@ def _init_param_attributes(self, p: Parameter) -> None: ) _free_storage(p._full_param_padded) # type: ignore[attr-defined] - def _set_is_root(self) -> None: - """If ``True``, implies that no other :class:`FullyShardedDataParallel` - instance wraps this one. Called once by :func:`_lazy_init`. - """ - if self._is_root is not None: - return - # No FSDP instance wraps this, else _is_root would be set to False. - self._is_root = True - self._exec_order_data.init(self) - # If final backward callback is never been queued, state should be IDLE. - # If final backward callback is queued, the callback should be finished - # and the state was reset to be IDLE. - # This should be asserted at the beginning of forward pass in the root instance only. - # For children instances, if they are checkpointed, state will not be reset to - # IDLE after each inner forward/backward. - self._assert_state(TrainingState_.IDLE) - for m in self.modules(): - if m is not self and isinstance(m, FullyShardedDataParallel): - # We relax the assert for non-root instance, when the nested initialized module is wrapped - # again in FSDP later, for example after training to run inference. - assert ( - m._is_root is None or not m._is_root - ), "Non-root instance's _is_root flag should have not been set yet" \ - "or has already been set as False." - if m._is_root is None: - m._is_root = False - - def _setup_streams(self) -> None: - """Create streams to overlap data transfer and computation.""" - if len(self._streams) > 0 or not self._is_root: - return + # Track whether the `FlatParameter`'s post-backward hook has been + # called for validation in `_wait_for_post_backward()` + p._post_backward_called = False + def _init_streams(self) -> None: + """Initializes CUDA streams for overlapping data transfer and + computation. This should only be called on the root FSDP instance.""" + assert self._is_root if torch.cuda.is_available(): # Stream for all-gathering parameters. self._streams["all_gather"] = torch.cuda.Stream() @@ -1716,33 +1726,18 @@ def _setup_streams(self) -> None: if self._mixed_precision_enabled_for_params(): self._streams["mixed_precision_params"] = torch.cuda.Stream() - # We share streams with all children instances, which allows them to - # overlap transfers across the forward pass without synchronizing with - # the default stream. - for m in self.modules(): - if m is not self and isinstance(m, FullyShardedDataParallel): - m._streams = self._streams - m._fsdp_graph_order = self._fsdp_graph_order - # Give each non-root FSDP module an alias to the root's - # execution order data structure and the root's ignored - # parameters and all buffer names since only the root's names - # are fully prefixed like the state dict keys - m._exec_order_data = self._exec_order_data - def _wait_for_previous_optim_step(self) -> None: """ - The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root - instance) needs to synchronize with the default stream to ensure the - previous optimizer step is done. + The root :class:`FullyShardedDataParallel` instance needs to + synchronize with the default stream to ensure that the previous + optimizer step is done. """ - if not torch.cuda.is_available(): + if not torch.cuda.is_available() or not self._is_root: return - if self._mixed_precision_enabled_for_params(): self._streams["mixed_precision_params"].wait_stream( torch.cuda.current_stream() ) - self._streams["all_gather"].wait_stream(torch.cuda.current_stream()) def _need_prefetch_full_params(self, state: TrainingState_) -> bool: @@ -1848,6 +1843,23 @@ def state_dict_type( submodule._state_dict_type = prev_state_dict_type submodule._state_dict_config = prev_state_dict_config + @property + def _param_fqns(self) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in ( + self._fsdp_wrapped_module.handle.parameter_module_names() + ): + module_name = module_name.replace(f"{FPW_MODULE}.", "") + module_name = module_name.replace(f"{FPW_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # Activation checkpoint adds a prefix that has to be + # removed as well. + module_name = module_name.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + def _full_post_state_dict_hook( self, state_dict: Dict[str, Any], @@ -1859,48 +1871,68 @@ def _full_post_state_dict_hook( back to sharded version after _summon_full_params ends, and also remove "_fsdp_wrapped_module" prefix. """ + _replace_by_prefix(state_dict, prefix + f"{FSDP_WRAPPED_MODULE}.", prefix) self._assert_state([TrainingState_.SUMMON_FULL_PARAMS]) - # state_dict is empty for nonzero ranks if `rank0_only` was enabled. - if not state_dict: + # Return early for trivial cases + if not state_dict or not self._fsdp_wrapped_module.has_params: + return state_dict + + # If the `FlatParameter` is registered, then this rank only needed to + # participate in the all-gather but does not actually save the state + # dict (e.g. when `rank0_only=True` and `self.rank != 0`) + if hasattr(self._fsdp_wrapped_module, "flat_param"): return state_dict offload_to_cpu = self._state_dict_config.offload_to_cpu cpu_device = torch.device("cpu") - for key in state_dict: - clean_key = clean_tensor_name(key) + + # Loop only the parameters saved in self._fsdp_wrapped_module to avoid + # processing buffers. + for fqn, param_name, module_name in self._param_fqns: + fqn = f"{prefix}{fqn}" + clean_key = fqn clean_prefix = clean_tensor_name(prefix) # Strip prefix out of key if needed as buffer names and param names # do not have prefix considered as they are not computed in `state_dict` # call. if clean_key.startswith(clean_prefix): clean_key = clean_key[len(clean_prefix):] - # Do not need to clone buffers since they are not sharded - if clean_key in self._buffer_names: - # Offload the buffer to CPU if needed -- we do not do this in - # `_summon_full_params()` since without care, that would free - # the original buffer's GPU memory and require reallocating - # that memory later; this only affects the state dict's buffer - # variable and leaves the original buffer's GPU memory intact - if offload_to_cpu and state_dict[key].device != cpu_device: - state_dict[key] = state_dict[key].to(cpu_device) - continue + # Clone non-ignored parameters before exiting the # `_summon_full_params()` context + assert fqn in state_dict, ( + f"FSDP assumes {fqn} is in the state_dict but the state_dict " + f"only has {state_dict.keys()}. prefix={prefix}, " + f"module_name={module_name} param_name={param_name} rank={self.rank}." + ) if clean_key not in self._ignored_param_names and \ - not getattr(state_dict[key], "_has_been_cloned", False): + not getattr(state_dict[fqn], "_has_been_cloned", False): try: - state_dict[key] = state_dict[key].clone().detach() - state_dict[key]._has_been_cloned = True # type: ignore[attr-defined] + state_dict[fqn] = state_dict[fqn].clone().detach() + state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] except BaseException as e: warnings.warn( - f"Failed to clone() tensor with name {key}. This may mean " + f"Failed to clone() tensor with name {fqn}. This may mean " "that this state_dict entry could point to invalid memory " "regions after returning from state_dict() call if this " "parameter is managed by FSDP. Please check clone " - f"implementation of {key}. Error: {str(e)}" + f"implementation of {fqn}. Error: {str(e)}" ) - _replace_by_prefix(state_dict, prefix + f"{FSDP_WRAPPED_MODULE}.", prefix) + # Offload the buffer to CPU if needed -- we do not do this in + # `_summon_full_params()` since without care, that would free + # the original buffer's GPU memory and require reallocating + # that memory later; this only affects the state dict's buffer + # variable and leaves the original buffer's GPU memory intact + if offload_to_cpu: + for clean_key in self._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + fqn = f"{prefix}{clean_key}" + if state_dict[fqn].device != cpu_device: + state_dict[fqn] = state_dict[fqn].to(cpu_device) return state_dict def _local_post_state_dict_hook( @@ -1914,7 +1946,7 @@ def _local_post_state_dict_hook( will happen. The underlying storage is the same. """ _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix) - if self._fsdp_wrapped_module.no_params: + if not self._fsdp_wrapped_module.has_params: return state_dict # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor @@ -1923,10 +1955,10 @@ def _local_post_state_dict_hook( # to get flat_param from the FlattenParamsWrapper to get the metadata. flat_param = getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) # Construct a ShardedTensor from the flat_param. - full_numel = flat_param.full_numel + full_numel = flat_param._unsharded_size.numel() shard_offset = flat_param.numel() * self.rank - valid_data_size = flat_param.numel() - flat_param.num_padded - if valid_data_size > 0 and flat_param.num_padded > 0: + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + if valid_data_size > 0 and flat_param._shard_numel_padded > 0: flat_param = flat_param.narrow(0, 0, valid_data_size) local_shards = [ Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank) @@ -1948,17 +1980,12 @@ def _sharded_post_state_dict_hook( with a unflattened, sharded parameter (a ShardedTensor). """ _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix) - if self._fsdp_wrapped_module.no_params: + if not self._fsdp_wrapped_module.has_params: return state_dict - for module_name, _, param_name in self._fsdp_wrapped_module.orig_flat_param[0].param_info: - module_name = module_name.replace(f"{FPW_MODULE}.", "") - module_name = module_name.replace(f"{FPW_MODULE}", "") - if module_name: - module_name = f"{module_name}." - fqn = f"{prefix}{module_name}{param_name}" - + for fqn, _, _ in self._param_fqns: # Create a ShardedTensor for the unflattened, non-sharded parameter. + fqn = f"{prefix}{fqn}" param = state_dict[fqn] local_shard = param.chunk(self.world_size)[self.rank].clone() offsets = [0 for _ in param.size()] @@ -2038,7 +2065,6 @@ def state_dict(self, *args, **kwargs): # is available. if torch.cuda.is_available(): torch.cuda.synchronize() - self._lazy_init() if self._state_dict_type == StateDictType.FULL_STATE_DICT: # Get config args @@ -2173,12 +2199,12 @@ def _local_pre_load_state_dict_hook( # tensor. flat_param = self._fsdp_wrapped_module.flat_param assert flat_param is not None - if flat_param.num_padded not in (0, flat_param.numel()): + if flat_param._shard_numel_padded not in (0, flat_param.numel()): assert load_tensor.numel() < flat_param.numel(), ( f"Local shard size = {flat_param.numel()} and the tensor in " f"the state_dict is {load_tensor.numel()}." ) - load_tensor = F.pad(load_tensor, [0, flat_param.num_padded]) + load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded]) state_dict[fqn] = load_tensor def _sharded_post_load_state_dict_hook(self, *args, **kwargs) -> None: @@ -2194,7 +2220,7 @@ def _sharded_pre_load_state_dict_hook( a new FlatParameter and shards the new FlatParameter to the local chunk. """ _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.") - if self._fsdp_wrapped_module.no_params: + if not self._fsdp_wrapped_module.has_params: return if not self._fsdp_wrapped_module.flat_param._is_sharded: @@ -2208,7 +2234,7 @@ def _sharded_pre_load_state_dict_hook( # gather all the parameters in this layer. This can be achieved by # concatenated all the local shards and then append the padding. # https://github.com/pytorch/pytorch/issues/77461 - for module_name, _, param_name in self._fsdp_wrapped_module.flat_param._param_infos: + for (param_name, _, module_name) in self._fsdp_wrapped_module.handle.flat_param._param_infos: module_name = module_name.replace(f"{FPW_MODULE}.", "") module_name = module_name.replace(f"{FPW_MODULE}", "") if module_name: @@ -2236,17 +2262,19 @@ def _sharded_pre_load_state_dict_hook( # Create a new flat_param from the loaded, non-sharded tensors. flat_param = self._fsdp_wrapped_module.flat_param - loaded_flat_param = FlatParameter(nonsharded_tensors, requires_grad=False) + loaded_flat_param = FlatParamHandle.flatten_params(nonsharded_tensors, requires_grad=False) # Get the chunk from the loaded flat_param for the local rank. - loaded_flat_param, num_to_pad = self._get_shard(loaded_flat_param) + loaded_flat_param, num_to_pad = FlatParamHandle._get_shard( + loaded_flat_param, self.rank, self.world_size, + ) assert flat_param.numel() == loaded_flat_param.numel(), ( f"The loaded local chunk has different numel({flat_param.numel()}) " f"from the local chunk {flat_param.numel()}." ) - assert flat_param.num_padded == num_to_pad, ( + assert flat_param._shard_numel_padded == num_to_pad, ( f"The loaded local chunk has different padding({num_to_pad}) " - f"from the local chunk {flat_param.num_padded}." + f"from the local chunk {flat_param._shard_numel_padded}." ) state_dict[f"{prefix}_fsdp_wrapped_module.flat_param"] = loaded_flat_param @@ -2348,6 +2376,7 @@ def _load_sharded_state_dict( def forward(self, *args: Any, **kwargs: Any) -> Any: with torch.autograd.profiler.record_function("FullyShardedDataParallel.forward"): self._lazy_init() + self._wait_for_previous_optim_step() # Start of a forward pass. self.training_state = TrainingState_.FORWARD @@ -2540,11 +2569,11 @@ def _free_full_params_and_use_local_shard(params_to_free): # full parameters. with contextlib.ExitStack() as stack: # Invariant: rank == 0 or !rank0_only - stack.enter_context(self._fsdp_wrapped_module.unflatten_params()) + stack.enter_context(self._fsdp_wrapped_module.unflatten_as_params()) try: yield finally: - if offload_to_cpu: + if offload_to_cpu and (not rank0_only or my_rank == 0): for p in self.params: if p._is_sharded: with torch.no_grad(): @@ -2830,6 +2859,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: alignment is created by :func:`_shard_parameters`, which ensures that the local optimizer only sees the relevant parameter shard. """ + p_assert( + hasattr(param, '_post_backward_called'), + "Expected flag _post_backward_called to exist on param." + ) + param._post_backward_called = True with torch.autograd.profiler.record_function("FullyShardedDataParallel._post_backward_hook"): # First hook callback will see PRE state. If we have multiple params, # then subsequent hook callbacks will see POST state. @@ -2849,13 +2883,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: "FSDP only works with gradients that don't require gradients" ) - if self._require_backward_grad_sync or \ - self.sharding_strategy == ShardingStrategy.FULL_SHARD: - # We free full parameters unless we are in `no_sync()` (i.e. when - # `_require_backward_grad_sync=False`) and not using the - # `FULL_SHARD` strategy. If we are not using the `FULL_SHARD` - # strategy (e.g. instead using `SHARD_GRAD_OP`), then we keep the - # full parameters in memory and save network overhead. + if ( + self._require_backward_grad_sync + or self.sharding_strategy == ShardingStrategy.FULL_SHARD + ): self._free_full_params(cast(List[FlatParameter], [param])) if self._mixed_precision_enabled_for_params(): @@ -2889,16 +2920,21 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: with torch.cuda.stream(self._streams["post_backward"]): orig_grad_data = param.grad.data if ( - self._mixed_precision_enabled_for_reduce() + self._mixed_precision_enabled_for_reduce() and not self._low_precision_hook_enabled() ): # Cast gradient to precision in which it should be communicated. + # If a low precision hook is registered and reduce_dtype is specified + # in `MixedPrecision`, communication hook will take care of + # casting to lower precision and back. # TODO: Make this a communication hook when communication hooks # are implemented for FSDP. Note that this is a noop if the # reduce_dtype matches the param dtype. param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype) - if self.gradient_predivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. + if self.gradient_predivide_factor > 1 and self.communication_hook is None: + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. param.grad.div_(self.gradient_predivide_factor) grad = param.grad.data @@ -2925,23 +2961,12 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: output, input_flattened, group=self.process_group ) if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. output.div_(self.gradient_postdivide_factor) - # Note that we need to cast grads back to the full precision if - # 1) parameters were in reduced precision during fwd, as grads - # would thus be in this reduced precision, or - # 2) parameters did not have precision reduced, but grads - # had reduced precision for communication. - if ( - self._mixed_precision_enabled_for_params() or self._mixed_precision_enabled_for_reduce() - ): - # Cast gradients back to the full parameter precision so that - # optimizer.step() happens in full precision. - orig_param_grad_data = output - output.data = output.data.to(dtype=param.data.dtype) - # Don't let this memory get reused until after the transfer. - orig_param_grad_data.record_stream(torch.cuda.current_stream()) + self._cast_grad_to_param_dtype(output, param) # To support gradient accumulation outside `no_sync()`, we save # the gradient data to `param._saved_grad_shard` before the @@ -2974,24 +2999,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: ), "Currently the way for _is_sharded to be False is \ world_size == 1 or sharding_stratagy is set to be NO_SHARD" if self.sharding_strategy == ShardingStrategy.NO_SHARD: - dist.all_reduce(param.grad, group=self.process_group) - if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - param.grad.div_(self.gradient_postdivide_factor) - # Note that we need to cast grads back to the full precision if - # 1) parameters were in reduced precision during fwd, as grads - # would thus be in this reduced precision, or - # 2) parameters did not have precision reduced, but grads - # had reduced precision for communication. - if ( - self._mixed_precision_enabled_for_params() or self._mixed_precision_enabled_for_reduce() - ): - # Cast gradients back to the full parameter precision so that - # optimizer.step() happens in full precision. - orig_param_grad_data = param.grad.data - param.grad.data = param.grad.data.to(dtype=param.data.dtype) - # Don't let this memory get reused until after the transfer. - orig_param_grad_data.record_stream(torch.cuda.current_stream()) + # if a communication hook was not registered, + # then a default hook (`all_reduce`) will be used + self.communication_hook(self.communication_hook_state, param.grad) + + self._cast_grad_to_param_dtype(param.grad, param) # Regardless of sharding or not, offload the grad to CPU if we are # offloading params. This is so param and grad reside on same device @@ -3013,6 +3025,35 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py orig_grad_data.record_stream(self._streams["post_backward"]) + def _cast_grad_to_param_dtype( + self, + grad: torch.Tensor, + param: FlatParameter, + ): + """ + Casts gradient ``grad`` back to the full parameter dtype so that the + optimizer step runs with that dtype. This performs an actual cast if + 1. parameters were in reduced precision during the forward since then + gradients would be in that reduced precision, or + 2. parameters were not in reduced precision but gradients were in + reduced precision for communication. + However, if a low precision communication hook is registered, then this + dtype cast happens in the hook instead. + """ + self._assert_state(TrainingState_.BACKWARD_POST) + if ( + not self._low_precision_hook_enabled() + and ( + self._mixed_precision_enabled_for_params() + or self._mixed_precision_enabled_for_reduce() + ) + ): + low_prec_grad_data = grad.data + grad.data = grad.data.to(dtype=param.dtype) + # Do not let the low precision gradient memory get reused until + # the cast to full parameter precision completes + low_prec_grad_data.record_stream(torch.cuda.current_stream()) + def _queue_wait_for_post_backward(self) -> None: """Try to queue a `wait_for_post_backward` callback. Only called on root and only queue one callback at the beginning of @@ -3089,24 +3130,43 @@ def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None: p.grad = p._saved_grad_shard # type: ignore[attr-defined] else: p_assert( - not p._is_sharded, "All sharded parameters should " + not p._is_sharded or not p._post_backward_called, + "All sharded parameters that received gradient should " "use `_saved_grad_shard`" ) if hasattr(p, "_saved_grad_shard"): delattr(p, "_saved_grad_shard") + p_assert( + hasattr(p, '_post_backward_called'), + "Expected flag _post_backward_called to be set on param." + ) + # Reset _post_backward_called in preparation for the next iteration. + p._post_backward_called = False + # Update root and nested FSDP's hooks and flags. for m in self.modules(): # includes self if isinstance(m, FullyShardedDataParallel): - _finalize_params(m) - m._pre_backward_hook_has_run = False if any(p.requires_grad for p in m.parameters()): # Check if the module has params and if any of them has # the `requires_grad` field set. If `requires_grad=False` for # all the params, the post_backward hook will not fire and the # state will remain in `TrainingState_.BACKWARD_PRE`. - if any([p.requires_grad for p in m.params]): - m._assert_state(TrainingState_.BACKWARD_POST) + managed_param_requires_grad = any(p.requires_grad for p in m.params) + if managed_param_requires_grad: + p_assert( + all(hasattr(p, '_post_backward_called') for p in m.params), + "Expected all params to have flag _post_backward_called set!" + ) + post_backward_hook_called = any(p._post_backward_called for p in m.params) + if post_backward_hook_called: + m._assert_state(TrainingState_.BACKWARD_POST) + else: + # post backward hook was not called, meaning param + # did not have a gradient computed. It was either unused + # in forward, or unused in loss computation so it did + # not get gradient + m._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.IDLE]) else: m._assert_state(TrainingState_.BACKWARD_PRE) else: @@ -3118,6 +3178,9 @@ def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None: # 2. output tensors are `requires_grad==False`. In this case, # pre-backward hook is not registered, so it is in IDLE state. m._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.IDLE]) + + _finalize_params(m) + m._pre_backward_hook_has_run = False m.training_state = TrainingState_.IDLE if m._is_root: @@ -3163,7 +3226,7 @@ def _update_p_data(self, p, output_tensor: torch.Tensor) -> None: """ p.data = output_tensor # Trim any padding and reshape to match original size. - p.data = p.data[: p._orig_size.numel()].view(p._orig_size) # type: ignore[attr-defined] + p.data = p.data[:p._unsharded_size.numel()].view(p._unsharded_size) # type: ignore[attr-defined] @torch.no_grad() def _rebuild_full_params(self) -> List[Tuple[torch.Tensor, bool]]: @@ -3402,7 +3465,7 @@ def _prep_grads_for_backward(self) -> None: """Make sure p.grad has the correct size/device, otherwise set it to None.""" for p in self.params: if p.grad is not None and ( - p.grad.size() != p._orig_size # type: ignore[attr-defined] + p.grad.size() != p._unsharded_size # type: ignore[attr-defined] or p.grad.device != p.device ): offloaded: bool = p.grad.device != p.device @@ -3560,8 +3623,8 @@ def clip_grad_norm_( .. warning:: This needs to be called on all ranks, since synchronization primitives will be used. """ - # Call `_lazy_init` to ensure the stream synchronization is done appropriately. self._lazy_init() + self._wait_for_previous_optim_step() assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance" self._assert_state(TrainingState_.IDLE) @@ -4018,6 +4081,72 @@ def rekey_optim_state_dict( return new_osd return new_osd # should never reach here + def _get_default_comm_hook(self) -> Any: + r""" + Returns a default communication hook based on a sharding strategy. + """ + if self.sharding_strategy != ShardingStrategy.NO_SHARD: + return None + else: + return default_hooks.allreduce_hook + + def _get_default_comm_hook_state(self) -> Any: + r""" + Returns a default communication hook state based on a sharding strategy. + """ + if self.sharding_strategy != ShardingStrategy.NO_SHARD: + return None + else: + return default_hooks.DefaultState(process_group=self.process_group) + + def register_comm_hook(self, state: object, hook: callable): + """ + Registers a communication hook which is an enhancement that provides a + flexible hook to users where they can specify how FSDP aggregates gradients + across multiple workers. + This hook can be used to implement several algorithms like + `GossipGrad `_ and gradient compression + which involve different communication strategies for + parameter syncs while training with :class:`FullyShardedDataParallel`. + + .. warning:: + FSDP only support communication hooks for a ``NO_SHARD`` strategy at this time. + If other strategies are used, an error will be raised. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in `GossipGrad `_, etc. + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable with the following signature: + ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``. + + """ + if not self.check_is_root(): + raise AssertionError("register_comm_hook can only be called on a root instance.") + if self.sharding_strategy != ShardingStrategy.NO_SHARD: + raise NotImplementedError( + "Communication hooks are currently only available for a NO_SHARD strategy." + ) + else: + # register same hook for root and all submodules + for submodule in self.fsdp_modules(self): + assert not submodule._hook_registered, "communication hook can be only registered once" + submodule._hook_registered = True + assert submodule.communication_hook == self._get_default_comm_hook(),\ + f"communication hook should be default, but it is {submodule.communication_hook.__name__} instead" + submodule.communication_hook_state = state + submodule.communication_hook = hook def _get_default_cuda_device(module: nn.Module) -> torch.device: """Try to infer CUDA device from module parameters.""" @@ -4059,6 +4188,7 @@ def p_assert(cond: Any, s: Any) -> None: to print the error message ``s`` since otherwise, it is swallowed.""" if not cond: print(s) + traceback.print_stack() raise AssertionError def _calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor: @@ -4114,7 +4244,7 @@ def module_fn(module, prefix, param_to_unflat_param_names): if not isinstance(module, FullyShardedDataParallel): for param_name, param in module.named_parameters(recurse=False): module_prefixed_param_names = ( - param._param_names if isinstance(param, FlatParameter) + param._prefixed_param_names if isinstance(param, FlatParameter) else [param_name] ) # prefixed from `module` fully_prefixed_param_names = [ @@ -4184,4 +4314,9 @@ def clean_tensor_name(tensor_name: str) -> str: # call `replace()` twice separately tensor_name = tensor_name.replace(FSDP_WRAPPED_MODULE + ".", "") tensor_name = tensor_name.replace(FPW_MODULE + ".", "") + # TODO: Explicitly replacing checkpoint_wrapper prefix is not ideal, + # as it increases coupling between CheckpointWrapper and FSDP. This is also not + # scalable for additional wrapped modules, we should come up with a general solution + # for this issue. + tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX + ".", "") return tensor_name diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index dfeaf13ea1ee7..27ba44e6c1516 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -71,8 +71,9 @@ class ShardedGradScaler(GradScaler): :meth:`update` if inf/NaN gradients occur in an iteration. growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD): process group for sharding """ diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index ab3b9f7ad52de..e0cf445fd9b43 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -296,26 +296,45 @@ class ParamExecOrderWrapPolicy: (also called non-recursive wrapping policy). The policy contains multiple wraps. Each wrap contains original parameters that will be executed together, - and the wrap transfers these parameters into one FlattenParameter. In both forward and the backward passes, + and the wrap transfers these parameters into one ``FlattenParameter``. In both forward and the backward passes, the sharded parameters in each wrap will be gathered just before these parameters are used in the passes. These parameters will then be reshaded once they have been used. - TODO (linjianma): For now, the parameters contained in each wrap of ParamExecOrderWrapPolicy - are the parameters in each wrap of the init_policy (a recursive wrapping policy). + TODO (linjianma): For now, the parameters contained in each wrap of ``ParamExecOrderWrapPolicy`` + are the parameters in each wrap of the ``init_policy`` (a recursive wrapping policy). Later we will wrap parameters based on bucket size. Args: - init_policy (nn.Module): - The initial recursive wrapping policy used to guide the wrapping of this policy. In the first - forward and backward iteration, init_policy is used. Parameter execution order is also recorded - in the first iteration. Starting from second iteration, ParamExecOrderWrapPolicy will be used. - - The default always_wrap_policy might not be the best choice for every model. For example, for - transformer based models, setting transformer_auto_wrap_policy as the init_policy will guarantee + init_policy (Callable): + The initial recursive wrapping policy used to guide the wrapping of + this policy. If tracing_config is none, in the first forward and + backward iteration, ``init_policy`` is used to record parameter + execution order. Otherwise, init_policy is only used in FSDP + constructor for module level wrapping. + + The default ``always_wrap_policy`` might not be the best choice for every model. For example, for + transformer based models, setting ``transformer_auto_wrap_policy`` as the ``init_policy`` will guarantee wrapping each transformer layer into one FSDP unit, and can be easily combined with checkpointing within each transformer layer. + + tracing_config (Optional[TracingConfig]): + The configuration used to perform symbolic tracing at FSDP + constructor to get the module and parameter execution order. The + type of ``tracing_config`` needs to be either ``None`` or + ``TracingConfig``. If set as ``None``, then symbolic tracing is not + enabled, and one forward as well as backward iteration are needed to + get the parameter execution order. + + ..warning :: Note that not all modules can be successfully traced when + ``tracing_config`` is not None and symbolic tracing is enabled. The two + cases below may be unable to trace: 1. when there is a data-dependent + branch, 2. when the forward pass contains operators that don't support + ``torch.fx.Proxy`` as the input type (e.g. ``arange``, ``zeros``, ``ones``, + ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases, + users can set ``tracing_config = None`` to disable symbolic tracing. """ init_policy: Callable = always_wrap_policy + tracing_config: Any = None def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 66efbccd9ca78..b09a31250f2a4 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -20,6 +20,7 @@ from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger +__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent'] logger = get_logger() diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index 6460239883cd5..9f6fd38039675 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -353,7 +353,6 @@ def forward(ctx, output_tensor, input_tensor, group): @staticmethod def backward(ctx, grad_output): if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: - rank = dist.get_rank(group=ctx.group) world_size = dist.get_world_size(group=ctx.group) out_size = list(grad_output.size()) if out_size[0] % world_size != 0: diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index ddba3d5ae671b..5ea02786d00e7 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional Adadelta Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index be81acf021397..28a97817bba0b 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional Adagrad Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index d0d2a7df06b46..963b56f6b0726 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional Adam Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 33dd1669af184..a664e0df8f69b 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional Adamax Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 3114d06911384..eeaf5385bd31f 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional AdamW Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index e628e4855a8a0..c94df3e11ac7b 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional RMSprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, @@ -25,6 +27,7 @@ def __init__( momentum: float = 0.0, centered: bool = False, foreach: bool = False, + maximize: bool = False, _allow_empty_param_list: bool = False, ): self.defaults = { @@ -36,6 +39,7 @@ def __init__( } self.centered = centered self.foreach = foreach + self.maximize = maximize if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -102,4 +106,5 @@ def step(self, gradients: List[Optional[Tensor]]): weight_decay=weight_decay, momentum=momentum, centered=self.centered, - foreach=self.foreach) + foreach=self.foreach, + maximize=self.maximize) diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index ed6ebddc3d2b4..9a2f89799d6ce 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional Rprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 57cf724ad0795..b4c48159339fd 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -4,6 +4,8 @@ from torch import Tensor +__all__ : List[str] = [] + # Define a TorchScript compatible Functional SGD Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index ee4513c70d28c..0d34930e4b258 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -14,6 +14,8 @@ from collections import defaultdict from threading import Lock +__all__ = ['DistributedOptimizer'] + logger = logging.getLogger(__name__) # XXX: we define a _ScriptModuleOptimizer here to explicitly diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index 534ce9cca18c9..6ba73a7afc8bb 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -20,7 +20,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer): >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer - >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import post_localSGD_hook + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index 0f8753b094079..624d4f688dbc7 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -24,6 +24,22 @@ optim.Adamax: _FunctionalAdamax, } + +def register_functional_optim(key, optim): + """ + Interface to insert a new functional optimizer to functional_optim_map + ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key + need not be of :class:`torch.optim.Optimizer` (e.g. for custom optimizers) + Example:: + >>> # import the new functional optimizer + >>> from xyz import fn_optimizer + >>> from torch.distributed.optim.utils import register_functional_optim + >>> fn_optim_key = "XYZ_optim" + >>> register_functional_optim(fn_optim_key, fn_optimizer) + """ + if key not in functional_optim_map: + functional_optim_map[key] = optim + def as_functional_optim(optim_cls: Type, *args, **kwargs): try: functional_cls = functional_optim_map[optim_cls] diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index d726e25e9c5bf..4a3b93c1ad27c 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -984,7 +984,7 @@ def _local_step( If the argument itself is ``None``, then all parameters are updated, and the gradients are assumed to be already populated. (default: ``None``) - closure (callable): a closure that re-evaluates the model and + closure (Callable): a closure that re-evaluates the model and returns the loss; optional for most optimizers and should be ``None`` if ``gradients`` is not ``None``; (default: ``None``) Returns: @@ -1043,7 +1043,7 @@ def step( Performs a single optimizer step and syncs parameters across all ranks. Arguments: - closure (callable): a closure that re-evaluates the model and + closure (Callable): a closure that re-evaluates the model and returns the loss; optional for most optimizers. Returns: Optional loss depending on the underlying local optimizer. diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 1300a415cb726..0d2edb86e4d41 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -9,7 +9,7 @@ import os import sys from datetime import timedelta -from typing import Dict +from typing import Dict, Optional import torch._six as six from torch.distributed import FileStore, PrefixStore, Store, TCPStore @@ -50,24 +50,24 @@ def register_rendezvous_handler(scheme, handler): ) _rendezvous_handlers[scheme] = handler + # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} def _query_to_dict(query: str) -> Dict[str, str]: return dict((pair[0], pair[1]) for pair in (pair.split("=") for pair in filter(None, query.split("&")))) -def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): - if not isinstance(url, six.string_classes): - raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url)) - - if not isinstance(rank, numbers.Integral): - raise RuntimeError("`rank` must be an integer. {}".format(rank)) - if not isinstance(world_size, numbers.Integral): - raise RuntimeError("`world_size` must be an integer. {}".format(world_size)) - - # Append node-specific arguments. +def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): result = urlparse(url) - if rank != -1 or world_size != -1: + if world_size_opt is None: + world_size = -1 + if result.scheme == "env": + rank = int(os.environ.get("RANK", rank)) + # If the world_size env variable is not present then it is a dynamic group + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + else: + world_size = world_size_opt + if rank != -1 or world_size != -1 or world_size_opt is None: query_dict = _query_to_dict(result.query) assert ( "rank" not in query_dict and "world_size" not in query_dict @@ -75,10 +75,9 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): url=url ) if rank != -1: - query_dict["rank"] = rank - if world_size != -1: - query_dict["world_size"] = world_size - + query_dict["rank"] = str(rank) + if world_size != -1 or world_size_opt is None: + query_dict["world_size"] = str(world_size) result = result._replace( query="{}".format( "&".join(["{}={}".format(k, v) for k, v in query_dict.items()]) @@ -90,35 +89,25 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): raise RuntimeError("No rendezvous handler for {}://".format(result.scheme)) return _rendezvous_handlers[result.scheme](url, **kwargs) -def _create_store_from_options(backend_options, rank): - result = urlparse(backend_options.init_method) - # If using env initialization, get rank and world_size from env - world_size = -1 - if result.scheme == "env": - rank = os.environ.get("RANK", rank) - # Here, the world_size has already beeen initialized to -1 in init_rpc - # If the world_size env variable is also not present then it is a dynamic group - world_size = int(os.environ.get("WORLD_SIZE", world_size)) +def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): + if not isinstance(url, six.string_classes): + raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url)) - query_dict = _query_to_dict(result.query) - # if rank is -1 then intentionally exclude rank for the query, error will be thrown later - if rank != -1: - query_dict["rank"] = str(rank) - query_dict["world_size"] = str(world_size) - - result = result._replace( - query="{}".format( - "&".join(["{}={}".format(k, v) for k, v in query_dict.items()]) - ) - ) + if not isinstance(rank, numbers.Integral): + raise RuntimeError("`rank` must be an integer. {}".format(rank)) - url = urlunparse(result) - if result.scheme not in _rendezvous_handlers: - raise RuntimeError("No handler for {}://".format(result.scheme)) - store, _, _ = next(_rendezvous_handlers[result.scheme](url)) + if not isinstance(world_size, numbers.Integral): + raise RuntimeError("`world_size` must be an integer. {}".format(world_size)) + + return _rendezvous_helper(url, rank, world_size, **kwargs) + + +def _create_store_from_options(backend_options, rank): + store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None)) return store + def _rendezvous_error(msg): return ValueError("Error initializing torch.distributed using " + msg) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 47c983a1907b4..1f5eb76d626ae 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -528,7 +528,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): Args: to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. - func (callable): a callable function, such as Python callables, builtin + func (Callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. args (tuple): the argument tuple for the ``func`` invocation. @@ -736,7 +736,7 @@ def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): Args: to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. - func (callable): a callable function, such as Python callables, builtin + func (Callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. args (tuple): the argument tuple for the ``func`` invocation. @@ -810,7 +810,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): Args: to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. - func (callable): a callable function, such as Python callables, builtin + func (Callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. args (tuple): the argument tuple for the ``func`` invocation. diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index b8c4aec49121a..da441e12a24ee 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -262,6 +262,8 @@ def _get_device_infos(): agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) opts = agent._get_backend_options() device_count = torch.cuda.device_count() + if torch.cuda.is_available() and opts.devices: + torch.cuda.init() return device_count, opts.device_maps, opts.devices def _set_devices_and_reverse_device_map(agent): @@ -312,16 +314,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ ) ) - if torch.cuda.is_available(): - # It's necessary to initialize PyTorch CUDA states here (e.g., - # CUDACachingAllocator). If this is missing, we could hit errors like - # "allocator not initialized", because other processes might send - # CUDA-related RPC request to this process before user code in this - # process initializes its PyTorch CUDA states. - torch.cuda.init() - device_count = torch.cuda.device_count() - else: - device_count = 0 + device_count = torch.cuda.device_count() is_static_group = True if world_size else False # world_size is specified so this is a static group (ranks cannot join and leave) @@ -339,6 +332,14 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ group, ) + if torch.cuda.is_available() and devices: + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + torch.cuda.init() + # TODO: add try-except and destroy _agent in all processes if any fails. agent = TensorPipeAgent( store, diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index fba992a947311..b4bcb453dc83b 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -4,12 +4,14 @@ import torch from torch.autograd.profiler_legacy import profile +from typing import List from . import ( _disable_server_process_global_profiler, _enable_server_process_global_profiler, ) +__all__: List[str] = [] class _server_process_global_profile(profile): """ diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index c3c4111bf6042..b6e1afeca19a5 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -7,6 +7,7 @@ from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property from torch.nn.functional import binary_cross_entropy_with_logits +__all__ = ['Bernoulli'] class Bernoulli(ExponentialFamily): r""" diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index b99b3e5b1b471..af2213da450c8 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -6,6 +6,7 @@ from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all +__all__ = ['Beta'] class Beta(ExponentialFamily): r""" diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 92691b88e85bf..e892e2487caa0 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -3,6 +3,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs +__all__ = ['Binomial'] def _clamp_by_zero(x): # works like clamp(x, min=0) but has grad at 0 is 0.5 diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index fa0ba3fae49a8..b18a27fff33f9 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -4,6 +4,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property +__all__ = ['Categorical'] class Categorical(Distribution): r""" diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index f146807afe68c..3db153c6fc131 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -7,6 +7,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all +__all__ = ['Cauchy'] class Cauchy(Distribution): r""" diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index d635f005e9172..bde0c2ff1f782 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,6 +1,7 @@ from torch.distributions import constraints from torch.distributions.gamma import Gamma +__all__ = ['Chi2'] class Chi2(Gamma): r""" diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index c03f0ad02d2c6..dc0f07ff7b13e 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -98,7 +98,7 @@ def construct_transform(constraint): constraint (subclass of :class:`~torch.distributions.constraints.Constraint`): A subclass of :class:`~torch.distributions.constraints.Constraint`, or a singleton object of the desired class. - factory (callable): A callable that inputs a constraint object and returns + factory (Callable): A callable that inputs a constraint object and returns a :class:`~torch.distributions.transforms.Transform` object. """ # Support use as decorator. diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index a50b9b539d3c9..a217595725790 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -159,7 +159,7 @@ def support(self): return constraints.interval(self.low, self.high) Args: - fn (callable): The function to be decorated. + fn (Callable): The function to be decorated. is_discrete (bool): Optional value of ``.is_discrete`` in case this can be computed statically. If not provided, access to the ``.is_discrete`` attribute will raise a NotImplementedError. diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 5d3d488402030..17743592c7e97 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -7,6 +7,7 @@ from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs from torch.nn.functional import binary_cross_entropy_with_logits +__all__ = ['ContinuousBernoulli'] class ContinuousBernoulli(ExponentialFamily): r""" diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index 049bac64ba03d..6d9a1522fb88c 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -4,6 +4,7 @@ from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily +__all__ = ['Dirichlet'] # This helper is exposed for testing. def _Dirichlet_backward(x, concentration, grad_output): @@ -33,7 +34,7 @@ class Dirichlet(ExponentialFamily): Example:: >>> m = Dirichlet(torch.tensor([0.5, 0.5])) - >>> m.sample() # Dirichlet distributed with concentrarion concentration + >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5] tensor([ 0.1046, 0.8954]) Args: diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index f24650993521d..66bd158bd87b6 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -4,6 +4,7 @@ from torch.distributions.utils import lazy_property from typing import Dict, Optional, Any +__all__ = ['Distribution'] class Distribution(object): r""" diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index 7084714ee3d06..8db7075c7fffc 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,6 +1,7 @@ import torch from torch.distributions.distribution import Distribution +__all__ = ['ExponentialFamily'] class ExponentialFamily(Distribution): r""" diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index fdacbc8d07613..11d176f6fa43e 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -5,6 +5,7 @@ from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all +__all__ = ['Exponential'] class Exponential(ExponentialFamily): r""" diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 89058225e964d..220ab60899a12 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -6,6 +6,7 @@ from torch.distributions.gamma import Gamma from torch.distributions.utils import broadcast_all +__all__ = ['FisherSnedecor'] class FisherSnedecor(Distribution): r""" diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index e1009a79b1209..907117846ec80 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -5,6 +5,7 @@ from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all +__all__ = ['Gamma'] def _standard_gamma(concentration): return torch._standard_gamma(concentration) diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index af78876aade4e..195add2f13afb 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -6,6 +6,7 @@ from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property from torch.nn.functional import binary_cross_entropy_with_logits +__all__ = ['Geometric'] class Geometric(Distribution): r""" diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index ece771a6ee9fc..51abad6f3fbf2 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -7,6 +7,7 @@ from torch.distributions.transforms import AffineTransform, ExpTransform from torch.distributions.utils import broadcast_all, euler_constant +__all__ = ['Gumbel'] class Gumbel(TransformedDistribution): r""" diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index 6657209279072..eca165a15c37b 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -7,6 +7,7 @@ from torch.distributions.cauchy import Cauchy from torch.distributions.transformed_distribution import TransformedDistribution +__all__ = ['HalfCauchy'] class HalfCauchy(TransformedDistribution): r""" diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index ab27325711ecb..8dc240274f192 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -7,6 +7,7 @@ from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution +__all__ = ['HalfNormal'] class HalfNormal(TransformedDistribution): r""" diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index b089bfe9d8589..1893a045246cc 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -4,6 +4,8 @@ from torch.distributions.utils import _sum_rightmost from typing import Dict +__all__ = ['Independent'] + class Independent(Distribution): r""" Reinterprets some of the batch dims of a distribution as event dims. diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index 8a39dbe5f674a..b439f2edd0f24 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -180,10 +180,10 @@ def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor: @register_kl(Bernoulli, Bernoulli) def _kl_bernoulli_bernoulli(p, q): - t1 = p.probs * (p.probs / q.probs).log() + t1 = p.probs * (torch.nn.functional.softplus(-q.logits) - torch.nn.functional.softplus(-p.logits)) t1[q.probs == 0] = inf t1[p.probs == 0] = 0 - t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log() + t2 = (1 - p.probs) * (torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)) t2[q.probs == 1] = inf t2[p.probs == 1] = 0 return t1 + t2 diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index ed5ff3af40117..16d6886b9f996 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -6,6 +6,7 @@ from torch.distributions.transforms import AffineTransform, PowerTransform from torch.distributions.utils import broadcast_all, euler_constant +__all__ = ['Kumaraswamy'] def _moments(a, b, n): """ diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 9b870e8839b76..21310147d6ae2 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -4,6 +4,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all +__all__ = ['Laplace'] class Laplace(Distribution): r""" diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index 132b873a5bb87..273779b62c59c 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -14,6 +14,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all +__all__ = ['LKJCholesky'] class LKJCholesky(Distribution): r""" diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index 3e18c6ceb1adf..40694e2c77423 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -3,6 +3,7 @@ from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution +__all__ = ['LogNormal'] class LogNormal(TransformedDistribution): r""" diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index c1268767f4d99..9a986ccf1d80f 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -3,6 +3,7 @@ from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import StickBreakingTransform +__all__ = ['LogisticNormal'] class LogisticNormal(TransformedDistribution): r""" diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index d0cff8fb5666e..1deb34a615cea 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -6,6 +6,7 @@ from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv from torch.distributions.utils import _standard_normal, lazy_property +__all__ = ['LowRankMultivariateNormal'] def _batch_capacitance_tril(W, D): r""" diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 2589546e11e02..ee941df5ae29f 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -4,6 +4,7 @@ from torch.distributions import constraints from typing import Dict +__all__ = ['MixtureSameFamily'] class MixtureSameFamily(Distribution): r""" diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 39203b8b34bf8..854dbbd0ee5c7 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -6,6 +6,7 @@ from torch.distributions import constraints from torch.distributions.utils import broadcast_all +__all__ = ['Multinomial'] class Multinomial(Distribution): r""" diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index f6abf0b372a7b..1f2451d4ef4a1 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -5,6 +5,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import _standard_normal, lazy_property +__all__ = ['MultivariateNormal'] def _batch_mv(bmat, bvec): r""" diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index c67baf0d2da8b..20d802654e11e 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -4,6 +4,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs +__all__ = ['NegativeBinomial'] class NegativeBinomial(Distribution): r""" diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 5a2ade0881a0b..9032836ab7ad1 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -7,6 +7,7 @@ from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import _standard_normal, broadcast_all +__all__ = ['Normal'] class Normal(ExponentialFamily): r""" diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index be3553a526fbd..f884338b3d7cb 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -3,6 +3,7 @@ from torch.distributions.categorical import Categorical from torch.distributions.distribution import Distribution +__all__ = ['OneHotCategorical', 'OneHotCategoricalStraightThrough'] class OneHotCategorical(Distribution): r""" diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 15cce64bcb489..c39ca17939eee 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -4,6 +4,7 @@ from torch.distributions.transforms import AffineTransform, ExpTransform from torch.distributions.utils import broadcast_all +__all__ = ['Pareto'] class Pareto(TransformedDistribution): r""" diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index 83e5a1c9041ae..9f9c9f54143fd 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -5,6 +5,7 @@ from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all +__all__ = ['Poisson'] class Poisson(ExponentialFamily): r""" diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index cea8a7c93d3b8..c82ff3d601e1a 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -6,6 +6,7 @@ from torch.distributions.transforms import SigmoidTransform from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs +__all__ = ['LogitRelaxedBernoulli', 'RelaxedBernoulli'] class LogitRelaxedBernoulli(Distribution): r""" diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 867d8b59ad195..beefec3fcffd2 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -6,6 +6,7 @@ from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import ExpTransform +__all__ = ['ExpRelaxedCategorical', 'RelaxedOneHotCategorical'] class ExpRelaxedCategorical(Distribution): r""" diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index f89c886e4de6e..36d227455105c 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -6,6 +6,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import _standard_normal, broadcast_all +__all__ = ['StudentT'] class StudentT(Distribution): r""" diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 53382a3d4eec1..9d7bd6fbd6907 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -6,6 +6,7 @@ from torch.distributions.utils import _sum_rightmost from typing import Dict +__all__ = ['TransformedDistribution'] class TransformedDistribution(Distribution): r""" diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 045441d25f5ac..c73f33023275f 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -96,6 +96,11 @@ def __init__(self, cache_size=0): raise ValueError('cache_size must be 0 or 1') super(Transform, self).__init__() + def __getstate__(self): + state = self.__dict__.copy() + state["_inv"] = None + return state + @property def event_dim(self): if self.domain.event_dim == self.codomain.event_dim: @@ -974,6 +979,7 @@ class CatTransform(Transform): in a way compatible with :func:`torch.cat`. Example:: + x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0) x = torch.cat([x0, x0], dim=0) t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) @@ -1076,6 +1082,7 @@ class StackTransform(Transform): in a way compatible with :func:`torch.stack`. Example:: + x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1) t = StackTransform([ExpTransform(), identity_transform], dim=1) y = t(x) diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index 88f20e4a7131f..c4baa48109f68 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -6,6 +6,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all +__all__ = ['Uniform'] class Uniform(Distribution): r""" diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index ce33796829fd7..5d7c2b868d5ce 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -6,6 +6,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all, lazy_property +__all__ = ['VonMises'] def _eval_poly(y, coef): coef = list(coef) diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index fdfe884a199c2..50f001d5297e4 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -6,6 +6,7 @@ from torch.distributions.utils import broadcast_all from torch.distributions.gumbel import euler_constant +__all__ = ['Weibull'] class Weibull(TransformedDistribution): r""" diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 30e4284c9f458..058a9e51c5fb7 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -11,6 +11,8 @@ from torch.distributions.multivariate_normal import _precision_to_scale_tril +__all__ = ['Wishart'] + _log_2 = math.log(2) diff --git a/torch/functional.py b/torch/functional.py index 75d8e365140aa..f8798c32040c0 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -252,7 +252,7 @@ def einsum(*args: Any) -> Tensor: may be provided in a sublist to enable broadcasting as described in the Equation section above. Args: - equation (string): The subscripts for the Einstein summation. + equation (str): The subscripts for the Einstein summation. operands (List[Tensor]): The tensors to compute the Einstein summation of. Examples:: @@ -575,7 +575,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, center (bool, optional): whether to pad :attr:`input` on both sides so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. Default: ``True`` - pad_mode (string, optional): controls the padding method used when + pad_mode (str, optional): controls the padding method used when :attr:`center` is ``True``. Default: ``"reflect"`` normalized (bool, optional): controls whether to return the normalized STFT results Default: ``False`` diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 2c8b77ea0fca7..1795983b3f300 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -4,6 +4,8 @@ import torch +__all__ = ['Future', 'collect_all', 'wait_all'] + T = TypeVar("T") S = TypeVar("S") diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 83372a3e86ce6..5ef5c78278a77 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -3,26 +3,45 @@ import inspect import math import os -from types import CodeType, FunctionType, ModuleType -from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, Type, List, Callable, Union from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, +) + import torch -from torch._C import ScriptObject # type: ignore[attr-defined] import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] from ._compatibility import compatibility -from .node import Argument, map_aggregate, base_types -from .graph import Graph, _PyTreeInfo, _PyTreeCodeGen +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph from .graph_module import GraphModule -from .proxy import TracerBase, Proxy, ParameterProxy +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, TracerBase HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS # These need to run in global scope to handle nested calls correctly -_orig_module_call : Callable = torch.nn.Module.__call__ -_orig_module_getattr : Callable = torch.nn.Module.__getattr__ +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: Dict[Type, None] = {} + +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag -_proxyable_classes : Dict[Type, None] = {} @compatibility(is_backward_compatible=True) class ProxyableClassMeta(type): @@ -69,6 +88,7 @@ def forward(self, x : __main___TensorPair, y : torch.Tensor): defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic tracing. """ + def __init__(cls, name, bases, attrs): _proxyable_classes.setdefault(cls) super().__init__(name, bases, attrs) @@ -87,7 +107,7 @@ def check_proxy(a): if len(found_proxies) != 0: tracer = found_proxies[0].tracer - return tracer.create_proxy('call_function', cls, args, kwargs) + return tracer.create_proxy("call_function", cls, args, kwargs) else: cls.__init__(instance, *args, **kwargs) # type: ignore[misc] return instance @@ -96,25 +116,48 @@ def check_proxy(a): def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: co = fn.__code__ co_flags = co.co_flags & ~HAS_VARSTUFF - co_args : tuple + co_args: tuple if hasattr(co, "co_posonlyargcount"): co_args = ( - nargs, 0, - 0, co.co_nlocals, co.co_stacksize, - co_flags, co.co_code, co.co_consts, co.co_names, - co.co_varnames, co.co_filename, co.co_name, - co.co_firstlineno, co.co_lnotab, co.co_freevars, - co.co_cellvars + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, ) else: co_args = ( - nargs, 0, co.co_nlocals, - co.co_stacksize, co_flags, co.co_code, co.co_consts, - co.co_names, co.co_varnames, co.co_filename, - co.co_name, co.co_firstlineno, co.co_lnotab, - co.co_freevars, co.co_cellvars) + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) new_code = CodeType(*co_args) # type: ignore[arg-type] - return FunctionType(new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__) + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) # we need to insert placeholder nodes for *args and **kwargs # we can't call this function normally, otherwise it would try to unpack them @@ -126,11 +169,14 @@ class PHBase(object): """ Object representing an input placeholder to `concrete_args` """ + def __repr__(self): - return 'PH' + return "PH" + PH = PHBase() + @compatibility(is_backward_compatible=True) class Tracer(TracerBase): # Reference: https://github.com/pytorch/pytorch/issues/54354 @@ -153,9 +199,12 @@ class Tracer(TracerBase): # includes the local filepath to the `math` module, which would jitter # across machines. @compatibility(is_backward_compatible=True) - def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), - autowrap_functions: Tuple[Callable, ...] = (), - param_shapes_constant: bool = False) -> None: + def __init__( + self, + autowrap_modules: Tuple[ModuleType] = (math,), + autowrap_functions: Tuple[Callable, ...] = (), + param_shapes_constant: bool = False, + ) -> None: # This method's signature is overridden by the first line of this class' # docstring. If this method's signature is modified, the signature that # overrides it also should be modified accordingly. @@ -187,8 +236,10 @@ def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), # Functions we will eagerly wrap when we see them while tracing # this captures both `math.sqrt()` and `from math import sqrt` automatically self._autowrap_function_ids: Set[int] = { - id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) - if not name.startswith("_") and callable(value)} + id(value) + for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) + if not name.startswith("_") and callable(value) + } self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) # Python modules to apply autowrap to at the start, in addition to @@ -199,7 +250,7 @@ def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> 'Argument': + def create_arg(self, a: Any) -> "Argument": """ A method to specify the behavior of tracing when preparing values to be used as arguments to nodes in the ``Graph``. @@ -233,21 +284,21 @@ def create_arg(self, a: Any) -> 'Argument': if isinstance(a, torch.nn.Parameter): for n, p in self.root.named_parameters(): if a is p: - return self.create_node('get_attr', n, (), {}) - raise NameError('parameter is not a member of this module') + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") elif isinstance(a, torch.Tensor): for n_, p_ in self.root.named_buffers(): if a is p_: - return self.create_node('get_attr', n_, (), {}) + return self.create_node("get_attr", n_, (), {}) elif isinstance(a, torch.nn.Module): for n_, p_ in self.root.named_modules(): if a is p_: - return self.create_node('get_attr', n_, (), {}) + return self.create_node("get_attr", n_, (), {}) # For NamedTuple instances that appear literally as args, we emit # a node to construct the NamedTuple and use that Node as the argument. - if isinstance(a, tuple) and hasattr(a, '_fields'): + if isinstance(a, tuple) and hasattr(a, "_fields"): args = tuple(self.create_arg(elem) for elem in a) - return self.create_node('call_function', a.__class__, args, {}) + return self.create_node("call_function", a.__class__, args, {}) # Tensors do not have a reliable string repr() from which they can be # constructed (and we probably don't want to rely on that, either), so @@ -257,21 +308,21 @@ def create_arg(self, a: Any) -> 'Argument': # tensor value into a special attribute on the Module s.t. we can # retrieve it with a get_attr. if isinstance(a, (torch.Tensor, ScriptObject)): - qualname : Optional[str] = self.tensor_attrs.get(a) + qualname: Optional[str] = self.tensor_attrs.get(a) # Tensor was not found in the Module hierarchy, stow it away in a # special attribute and set the qualname to refer to that if not qualname: i = 0 while True: - qualname = f'_tensor_constant{i}' + qualname = f"_tensor_constant{i}" if not hasattr(self.root, qualname): break i += 1 self.tensor_attrs[a] = qualname setattr(self.root, qualname, a) - return self.create_node('get_attr', qualname, (), {}) + return self.create_node("get_attr", qualname, (), {}) if type(a) in _proxyable_classes: # This is an instance of a proxyable class for which we did not @@ -280,18 +331,18 @@ def create_arg(self, a: Any) -> 'Argument': # TODO: binary search i = 0 while True: - qualname = f'_{a.__class__.__name__}_constant_{i}' + qualname = f"_{a.__class__.__name__}_constant_{i}" if not hasattr(self.root, qualname): break i += 1 setattr(self.root, qualname, a) - return self.create_node('get_attr', qualname, (), {}) + return self.create_node("get_attr", qualname, (), {}) return super().create_arg(a) @compatibility(is_backward_compatible=True) - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: """ A method to specify whether a given ``nn.Module`` is a "leaf" module. @@ -310,10 +361,12 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo submodule ``bar``, which contains submodule ``baz``, that module will appear with the qualified name ``foo.bar.baz`` here. """ - return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) + return m.__module__.startswith("torch.nn") and not isinstance( + m, torch.nn.Sequential + ) @compatibility(is_backward_compatible=True) - def path_of_module(self, mod : torch.nn.Module) -> str: + def path_of_module(self, mod: torch.nn.Module) -> str: """ Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if ``root`` has a submodule named ``foo``, which has @@ -328,7 +381,7 @@ def path_of_module(self, mod : torch.nn.Module) -> str: if self.submodule_paths: path = self.submodule_paths.get(mod) if path is None: - raise NameError('module is not installed as a submodule') + raise NameError("module is not installed as a submodule") assert isinstance(path, str) return path # O(N^2) fallback in the case that we didn't store the submodule @@ -337,10 +390,16 @@ def path_of_module(self, mod : torch.nn.Module) -> str: for n, p in self.root.named_modules(): if mod is p: return n - raise NameError('module is not installed as a submodule') + raise NameError("module is not installed as a submodule") @compatibility(is_backward_compatible=True) - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: """ Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -370,7 +429,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): return forward(*args, **kwargs) - return self.create_proxy('call_module', module_qualified_name, args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) # This method will be refactored @compatibility(is_backward_compatible=False) @@ -389,11 +448,13 @@ def create_args_for_root(self, root_fn, is_module, concrete_args=None): total_args = co.co_argcount + co.co_kwonlyargcount orig_args = list(co.co_varnames) names_iter = iter(co.co_varnames) - args : List[Any] = [] + args: List[Any] = [] skip_arg_idx = 0 if is_module: if total_args == 0: - raise RuntimeError('``self`` argument cannot be part of *args expansion!') + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) skip_arg_idx = 1 next(names_iter) # skip self args.append(self.root) @@ -401,23 +462,39 @@ def create_args_for_root(self, root_fn, is_module, concrete_args=None): sig = inspect.signature(fn_for_analysis) def proxy_placeholder(name: str): - if concrete_args is not None and name in concrete_args : + if concrete_args is not None and name in concrete_args: cnt = 0 def replace_ph(x): nonlocal cnt cnt += 1 param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) - out = self.create_proxy('placeholder', f'{name}_{str(cnt)}', default, {}) + default = ( + () + if param.default is inspect.Parameter.empty + else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) if x == PH: return out # Union[int, bool] == bool in Python <= 3.6 - if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor: - torch._assert(out == x, f"{name} has been specialized to have value {x} but got another value") + if ( + type(x) == bool + or type(x) in base_types + and type(x) != torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) elif type(x) == type(None): - args = (out, f"{name} has been specialized to have value None but got another value") - self.create_proxy('call_function', _assert_is_none, args, {}) + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) else: torch.warnings.warn( f"Was not able to add assertion to guarantee correct input {name} to " @@ -428,27 +505,34 @@ def replace_ph(x): return x return pytree.tree_map(replace_ph, concrete_args[name]) - if name[0] == '*': + if name[0] == "*": default = () else: param = sig.parameters[name] default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] - return self.create_proxy('placeholder', name, default, {}, - type_expr=fn_for_analysis.__annotations__.get(name, None)) + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None), + ) + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] if isinstance(concrete_args, tuple): if len(arg_names) != len(concrete_args): - raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments") + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} args.extend(proxy_placeholder(names) for names in arg_names) - if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: # TODO: type annotations for *args and **kwargs if co.co_flags & inspect.CO_VARARGS: - args.append(proxy_placeholder('*' + next(names_iter))) + args.append(proxy_placeholder("*" + next(names_iter))) if co.co_flags & inspect.CO_VARKEYWORDS: - args.append(proxy_placeholder('**' + next(names_iter))) + args.append(proxy_placeholder("**" + next(names_iter))) root_fn = _patch_function(root_fn, len(args)) flat_args, in_spec = pytree.tree_flatten(tuple(args)) @@ -456,48 +540,69 @@ def replace_ph(x): # In the case that we have pytree-flattened inputs in # `concrete_args`, generate a flattening wrapper around the # original root function and return that. - self.graph._codegen = _PyTreeCodeGen(_PyTreeInfo(orig_args[:total_args], in_spec, None)) + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) def flatten_fn(*args): tree_args = pytree.tree_unflatten(list(args), in_spec) tree_out = root_fn(*tree_args) out_args, out_spec = pytree.tree_flatten(tree_out) - assert(isinstance(self.graph._codegen, _PyTreeCodeGen)) - self.graph._codegen.pytree_info = self.graph._codegen.pytree_info._replace(out_spec=out_spec) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) return out_args return flatten_fn, flat_args return root_fn, args - def _module_getattr(self, attr, attr_val, parameter_proxy_cache): - def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache): + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): for n, p in collection_to_search: if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node : ParameterProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache) + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) if maybe_buffer_proxy is not None: return maybe_buffer_proxy return attr_val @compatibility(is_backward_compatible=True) - def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` can either be an ``nn.Module`` instance or a Python callable. @@ -521,84 +626,109 @@ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: A ``Graph`` representing the semantics of the passed-in ``root``. """ - if isinstance(root, torch.nn.Module): - self.root = root - - - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" - - fn = getattr(type(root), self.traced_func_name) - self.submodule_paths = {mod: name for name, mod in root.named_modules()} - else: - self.root = torch.nn.Module() - fn = root - - tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) - self.graph = Graph(tracer_cls=tracer_cls) - - # When we encounter a Tensor value that's not a parameter, we look if it - # is some other attribute on the model. Construct a dict mapping Tensor - # values to the qualified name here for efficiency. This is used downstream - # in create_arg - self.tensor_attrs : Dict[Union[torch.Tensor, ScriptObject], str] = {} - - def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]): - for k, v in m.__dict__.items(): - if isinstance(v, (torch.Tensor, ScriptObject)): - self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) - for k, v in m.named_children(): - collect_tensor_attrs(v, prefix_atoms + [k]) - - collect_tensor_attrs(self.root, []) - - assert isinstance(fn, FunctionType) - - fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args) - - parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls - - # Method dispatch on parameters is not recorded unless it's directly used. - # Thus, we need to insert a proxy when __getattr__ requests a parameter. - @functools.wraps(_orig_module_getattr) - def module_getattr_wrapper(mod, attr): - attr_val = _orig_module_getattr(mod, attr) - return self._module_getattr(attr, attr_val, parameter_proxy_cache) - - @functools.wraps(_orig_module_call) - def module_call_wrapper(mod, *args, **kwargs): - def forward(*args, **kwargs): - return _orig_module_call(mod, *args, **kwargs) - - _autowrap_check(patcher, getattr(getattr(mod, "forward", mod), "__globals__", {}), - self._autowrap_function_ids) - return self.call_module(mod, forward, args, kwargs) - - with _Patcher() as patcher: - # allow duplicate patches to support the case of nested calls - patcher.patch_method(torch.nn.Module, "__getattr__", module_getattr_wrapper, deduplicate=False) - patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False) - _patch_wrapped_functions(patcher) - _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) - for module in self._autowrap_search: - _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids) - self.create_node('output', 'output', (self.create_arg(fn(*args)),), {}, - type_expr=fn.__annotations__.get('return', None)) - - self.submodule_paths = None - + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + self.root = root + + assert hasattr( + type(root), self.traced_func_name + ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + + fn = getattr(type(root), self.traced_func_name) + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: Dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self._module_getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag return self.graph # List of pairs of (global dict, function name) functions # to patch for the purposes of the wrap() API. -_wrapped_fns_to_patch : List[Tuple[dict, str]] = [] +_wrapped_fns_to_patch: List[Tuple[dict, str]] = [] # List of methods on classes to wrap (class type, function name) # this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch : List[Tuple[type, str]] = [] +_wrapped_methods_to_patch: List[Tuple[type, str]] = [] if os.environ.get("FX_PATCH_GETITEM") == "1": # This change is needed to trace models like PositionalEmbedding from BERT: @@ -636,8 +766,10 @@ def wrapped(*args, **kwargs): """ proxy = _find_proxy(args, kwargs) if proxy is not None: - return_proxy = proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs) - return_proxy.node.meta['is_wrapped'] = True + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True return return_proxy return orig_fn(*args, **kwargs) @@ -657,16 +789,16 @@ def wrapped(*args, **kwargs): """ proxy = _find_proxy(args, kwargs) if proxy is not None: - return proxy.tracer.create_proxy('call_method', name, args, kwargs) + return proxy.tracer.create_proxy("call_method", name, args, kwargs) return orig_fn(*args, **kwargs) return wrapped class _PatchedFn(NamedTuple): - frame_dict : Any - fn_name : str - orig_fn : Any + frame_dict: Any + fn_name: str + orig_fn: Any def revert(self): raise NotImplementedError() @@ -690,11 +822,16 @@ def revert(self): class _Patcher(object): def __init__(self): super(_Patcher, self).__init__() - self.patches_made : List[_PatchedFn] = [] - self.visited : Set[int] = set() - - def patch(self, frame_dict : Dict[str, Any], name : str, new_fn : Callable, - deduplicate : bool = True): + self.patches_made: List[_PatchedFn] = [] + self.visited: Set[int] = set() + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): """ Replace frame_dict[name] with new_fn until we exit the context manager. """ @@ -704,11 +841,14 @@ def patch(self, frame_dict : Dict[str, Any], name : str, new_fn : Callable, elif getattr(frame_dict[name], "__fx_already_patched", False): return # already patched, no need to do it again else: - self.patches_made.append(_PatchedFnSetItem(frame_dict, name, frame_dict[name])) + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name]) + ) frame_dict[name] = new_fn - def patch_method(self, cls: type, name : str, new_fn : Callable, - deduplicate : bool = True): + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): """ Replace object_or_dict.name with new_fn until we exit the context manager. """ @@ -720,7 +860,7 @@ def patch_method(self, cls: type, name : str, new_fn : Callable, setattr(cls, name, new_fn) def visit_once(self, thing: Any): - """ Return True on the first call to with thing, otherwise false """ + """Return True on the first call to with thing, otherwise false""" idx = id(thing) if idx in self.visited: return False @@ -740,7 +880,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.visited.clear() -def _patch_wrapped_functions(patcher : _Patcher): +def _patch_wrapped_functions(patcher: _Patcher): """ Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap the listed global functions in the `_create_wrapped_func` wrapper. @@ -756,19 +896,25 @@ def _patch_wrapped_functions(patcher : _Patcher): patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) -def _autowrap_check(patcher : _Patcher, frame_dict : Dict[str, Any], function_ids : Set[int]): +def _autowrap_check( + patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] +): """ Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. This method searches a scope for them and patches them if found. """ if patcher.visit_once(frame_dict): for name, value in frame_dict.items(): - if not name.startswith("_") and callable(value) and id(value) in function_ids: + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): patcher.patch(frame_dict, name, _create_wrapped_func(value)) @compatibility(is_backward_compatible=True) -def wrap(fn_or_name : Union[str, Callable]): +def wrap(fn_or_name: Union[str, Callable]): """ This function can be called at module-level scope to register fn_or_name as a "leaf function". A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being @@ -802,31 +948,38 @@ def my_custom_function(x, y): graph when it's called """ if not callable(fn_or_name) and not isinstance(fn_or_name, str): - raise RuntimeError('Unsupported type for global function! Must be either a callable or ' - 'string name') + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) - if hasattr(fn_or_name, '__code__'): + if hasattr(fn_or_name, "__code__"): assert not isinstance(fn_or_name, str) # to make mypy happy fn_name = fn_or_name.__code__.co_name else: - assert isinstance(fn_or_name, str), "fn_or_name must be a global function or string name" + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" fn_name = fn_or_name currentframe = inspect.currentframe() assert currentframe is not None f = currentframe.f_back assert f is not None - if f.f_code.co_name != '': - raise NotImplementedError('wrap must be called at the top level of a module') + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search # semantics would be slightly different, but would add support `from x import wrapped_function` _wrapped_fns_to_patch.append((f.f_globals, fn_name)) return fn_or_name + @compatibility(is_backward_compatible=True) -def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> GraphModule: """ Symbolic tracing API @@ -876,7 +1029,9 @@ def f(x): """ tracer = Tracer() graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) return GraphModule(tracer.root, graph, name) diff --git a/torch/fx/experimental/migrate_gradual_types/__init__.py b/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000..b03f8bbe644bd --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,558 @@ +# -*- coding: utf-8 -*- +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ + op_mod, op_gt, op_lt, op_neq, op_eq +from torch.fx.tensor_type import TensorType, Dyn + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjuction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f'And({self.conjucts})' + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + else: + return False + + def __repr__(self): + return f'Or({self.disjuncts})' + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f'Product({self.products})' + + +class T(Constraint): + """ + True + """ + def __init__(self): + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return 'True' + +class F(Constraint): + """ + False + """ + def __init__(self): + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return 'False' + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string reprsenting the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + else: + return False + + def __repr__(self): + return f'({self.lhs} {self.op} {self.rhs})' + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, TVar) or isinstance(lhs, TensorType) or isinstance(lhs, int) or lhs == Dyn) and \ + (isinstance(rhs, TVar) or isinstance(rhs, TensorType) or isinstance(rhs, int) or rhs == Dyn) + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the outout + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}⊔*{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}⊔{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f'can-reshape({self.src}, {self.target})' + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + outut: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'IndexSelect({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.dim_replace}, ' \ + f'{self.index})' + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return self.tensor_size == other.tensor_size and \ + self.dim_replace == other.dim_replace and \ + self.index == other.index and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class Transpose(Constraint): + + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'Transpose({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.index1}, ' \ + f'{self.index2})' + + def __eq__(self, other): + if isinstance(other, Transpose): + return self.tensor_size == other.tensor_size and \ + self.index1 == other.index1 and \ + self.index2 == other.index2 and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class GetItem(Constraint): + + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' + + def __eq__(self, other): + if isinstance(other, GetItem): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index == other.index and \ + self.input_var == other.input_var + else: + return False + +class GetItemTensor(Constraint): + + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index_tuple == other.index_tuple and \ + self.input_var == other.input_var + else: + return False + +class CalcConv(Constraint): + + def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output chanel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.conv_result} =' \ + f' calc-conv({self.input_var},' \ + f' {self.c_out}, {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcConv): + return self.conv_result == other.conv_result and self.input_var == other.input_var and \ + self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class CalcMaxPool(Constraint): + + def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.maxpool_result} =' \ + f' calc-maxpool({self.input_var},' \ + f' {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ + and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return self.res1 == other.res1 \ + and self.res2 == other.res2 \ + and self.input1 == other.input1 \ + and self.input2 == other.input2 + else: + return False + + def __repr__(self): + return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param theta: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return self.start == other.start and self.end == other.end and \ + self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened + + else: + return False + + def __repr__(self): + return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f'TV({self.tvar})' + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'DV({self.c})' + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'BV({self.c})' + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, BVar) or isinstance(constraint, Conj) or isinstance(constraint, Disj) + +def is_dim(d): + return isinstance(d, DVar) or isinstance(d, int) or d == Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000..e5955fd039b8e --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1129 @@ +import torch +import operator +from typing import Callable, Dict, Iterable + +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ + Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ + TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.operation import \ + op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul +from torch.fx.node import Target, Node +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ + gen_bvar + +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.conv import Conv2d +from torch.nn.modules.batchnorm import BatchNorm2d + +_INFERENCE_RULES: Dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +def register_inference_rule(call_target): + def register(fn): + if call_target in _INFERENCE_RULES: + raise RuntimeError(f'Inference rule already registered for {call_target}!') + _INFERENCE_RULES[call_target] = fn + return fn + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == 'device': + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError('Not yet implemented') + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq)]) + + input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) + + input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) + + consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), + *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, Node) or isinstance(arg, int) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) + + constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] + constraints += c + + return constraints, counter + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) + + # or input is a tensor and we actually do the replacement + c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq)], counter + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implemenent the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) + else: + raise NotImplementedError('Not yet implemented') + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + + embedding_dim = module_instance.embedding_dim # number + + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + + nat_constraints) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + + src_var = symbols[n.args[0]] + t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq)] + + [range_check(arg_1, i)] + nat_constraints) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retreive arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] + + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjuction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) + symbols[n] = get_item_output + + # retreive arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + + c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] # type: ignore[misc] + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError('Method not yet implemented') + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) + assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + + assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) + assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) + assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + + assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) + assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) + assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) + assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) + dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError('Method not yet implemented') + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) + assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + res.append(symbols[arg]) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError('Not yet implemented') + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj([BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq)]) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj([BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq)]) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + + op_code = None + if n.target == operator.add or n.target == torch.add: + op_code = op_add + elif n.target == operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + else: + raise NotImplementedError('Method not yet implemented') + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError('Addition not yet implemented') + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + + add_layer_norm_constraints(new_dims_rhs, list(module_instance.normalized_shape)) + + nat_constraints) + c2.append(c_tensor_i) + + + return [Disj([c1, Disj(c2)])], counter + + # return [BinConstraintT(input, output, op_eq), + # BinConstraintT(input, normalized_shape, op_consistency)], counter + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + + add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, module_instance) + + nat_constraints) + c2.append(c_tensor_i) + + + return [Disj([c1, Disj(c2)])], counter + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, module_instance): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], module_instance.in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], module_instance.out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv(my_conv, input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, + module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, 'graph') else graph + + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + # Annotate with Dyn if no type exists + for n in graph.nodes: + if n.type is None: + n.type = Dyn + + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == 'placeholder': + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + c1 = BinConstraintT(n.type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == 'call_function': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'call_module': + + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, + module_instance, + self.symbol_dict, + self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + + elif n.op == 'call_method': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'get_attr': + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = [] + for t in t.shape: + res.append(t) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == 'output': + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py new file mode 100644 index 0000000000000..120541d27baef --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -0,0 +1,1040 @@ +# mypy: ignore-errors +import copy +import itertools +from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK +from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ + Transpose +from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool +from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect +from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching +from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq +from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar +from torch.fx.tensor_type import TensorType, Dyn +from typing import Callable, Dict, List + +_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} + + +def register_transformation_rule(call_target): + def register(fn): + if call_target in _TRANSFORMATION_RULES: + raise RuntimeError(f'Transformation rule already registered for {call_target}!') + _TRANSFORMATION_RULES[call_target] = fn + return fn + return register + + +def valid_index(index, dims): + """ + Given a list of dimensions, checks if an index is valid in the list + """ + try: + dims[index] + return T() + except IndexError: + return F() + + +@register_transformation_rule(Transpose) +def transform_transpose(constraint, counter): + """ + Similar to a sequence of two index-selects + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index1 = valid_index(constraint.index1, dims) + is_valid_index2 = valid_index(constraint.index2, dims) + new_dims = copy.deepcopy(dims) + nat_constraints = gen_nat_constraints(dims) + + if is_valid_index1 == T() and is_valid_index2 == T(): + new_dims[constraint.index1] = dims[constraint.index2] + new_dims[constraint.index2] = dims[constraint.index1] + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + return transformed_constraint, counter + + +@register_transformation_rule(IndexSelect) +def transform_index_select(constraint, counter): + """ + The constraints consider the given tensor size, checks if the index is valid + and if so, generates a constraint for replacing the input dimension + with the required dimension + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index = valid_index(constraint.index, dims) + nat_constraints = gen_nat_constraints(dims) + + # if the index is valid then replace the input dimension with the new dimension + # otherwise the dimension will not be replaced and the clause will contain False + if is_valid_index == T(): + new_dims = copy.deepcopy((dims)) + new_dims[constraint.index] = constraint.dim_replace + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + + # print(constraints) + return transformed_constraint, counter + + +@register_transformation_rule(GetItem) +def transform_get_item(constraint, counter): + """ + generate an equality of the form: + t = [a1, ..., an] + then generate constraints that check if the given index is valid + given this particular tensor size. + If the index is valid, generate a constraint to get the item + Note that we already handled the Dyn input case in the previous + step. + Args: + constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) + counter: variable tracking + Returns: simplified constraints for GetItem + + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + + is_valid_index = valid_index(constraint.index, dims) + + all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index] + + # if the index is valid, we generate a constraint for getting an item + # otherwise this clause will have been UNSAT due to the wrong index + if is_valid_index == T(): + all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) + + return Conj(all_constraints), counter + +def valid_index_tensor(index, dims): + """ + if the slice instances exceed the length of the dimensions + then this is a type error so we return False + """ + slice_count = 0 + for s in index: + if isinstance(s, slice): + slice_count += 1 + if slice_count > len(dims): + return F() + else: + return T() + +@register_transformation_rule(GetItemTensor) +def transform_get_item_tensor(constraint, counter): + """ + When the index is a tuple, then the output will be a tensor + TODO: we have to check if this is the case for all HF models + + The cases we are covrering here are a tuple with one of: + - slice with default argument + - None + + None appends 1 to the input tensor dimensions + so each occurrence of 'None' increases the rank by 1 + + slice with default arguments does not change the rank + """ + assert isinstance(constraint.index_tuple, tuple) + + + # generate a result tensor of the expected size + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + # generate a place-holder list of the right rank + # where "slice" does not contribute to the rank and "None" does + none_c = constraint.index_tuple.count(None) + resulting_tensor_dims = (none_c + len(dims)) * [None] + + dim_index = 0 + for i in range(len(constraint.index_tuple)): + + # append 1 to the right location of the resulting tensor + if constraint.index_tuple[i] is None: + resulting_tensor_dims[i] = 1 + + elif constraint.index_tuple[i] == slice(None, None, None): + pass + + else: + raise NotImplementedError('Method not yet implemented') + + # append the remaining dimensions to the right location + dim_index = 0 + for i in range(len(resulting_tensor_dims)): + if resulting_tensor_dims[i] is None: + resulting_tensor_dims[i] = dims[dim_index] + dim_index += 1 + + # check if the index is valid + is_valid_index = valid_index_tensor(constraint.index_tuple, dims) + + # check if the resulting tensor is within bounds + if len(resulting_tensor_dims) > 4: + return F(), counter + + else: + constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index] + return Conj(constraints), counter + + +@register_transformation_rule(BinConstraintT) +def generate_binconstraint_t(constraint, counter): + """ + Transform binary constraints for tensors + """ + + # precision constraints + if constraint.op == op_precision: + if constraint.lhs == Dyn: + return T(), counter + elif isinstance(constraint.lhs, TensorType): + is_fully_static = all([d != Dyn for d in constraint.lhs.__args__]) + if is_fully_static: + return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter + else: + new_dims = [] + + for _ in range(len(constraint.lhs.__args__)): + dim, counter = gen_dvar(counter) + new_dims.append(dim) + + new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for + new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ + [BinConstraintD(1, new_dim, op_leq) for + new_dim in new_dims] + return Conj(new_dim_constraints), counter + + # matching + elif constraint.op == op_matching: + assert isinstance(constraint.rhs, TensorType) + d1 = constraint.rhs.__args__[0] + d2 = constraint.rhs.__args__[1] + d3 = constraint.rhs.__args__[2] + d4 = constraint.rhs.__args__[3] + + conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq)] + return Disj([Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter + + elif constraint.op == op_consistency: + c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) + [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) + + return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter + + elif constraint.op == op_leq: + assert isinstance(constraint.rhs, int) + disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] + for i in range(1, constraint.rhs + 1): + dims = [] + for j in range(1, i + 1): + dim_var, counter = gen_dvar(counter) + dims.append(dim_var) + disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) + return Disj(disj), counter + else: + return constraint, counter + + +@register_transformation_rule(BinConstraintD) +def generate_binconstraint_d(constraint, counter): + """ + Transform binary constraints for dimensions + """ + if constraint.op == op_precision: + if isinstance(constraint.lhs, int): + return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter + elif constraint.lhs == Dyn: + return T(), counter + + elif constraint.op == op_consistency: + return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter + + else: + return constraint, counter + + +@register_transformation_rule(Conj) +def generate_conj(constraint, counter): + """ + Transform conjunctions + """ + new = [] + for c in constraint.conjucts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Conj(new), counter + + +@register_transformation_rule(Disj) +def generate_disj(constraint, counter): + """ + Transform disjunctions + """ + new = [] + for c in constraint.disjuncts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Disj(new), counter + + +@register_transformation_rule(TGreatestUpperBound) +def generate_gub(constraint, counter): + """ + Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound + on dimensions + """ + c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) + + [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) + + return Disj([c1, c2, c3, c4, c5]), counter + + +@register_transformation_rule(DGreatestUpperBound) +def generate_d_gub(constraint, counter): + """ + Transform greatest upper bound for dimensions into equality constraints + """ + c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) + c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + return Disj([c1, c2, c3]), counter + + +@register_transformation_rule(CalcConv) +def generate_calc_conv(constraint, counter): + d, counter = gen_tensor_dims(4, counter) + conv_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the convolution result is a tensor of size 4 + c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) + + # the second dimension of the output is equal to the output channels + c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) + + # the input corresponds to the output in the first dimension of the convolution + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcMaxPool) +def generate_calc_maxpool(constraint, counter): + """ + Transform maxpool constraints + """ + d, counter = gen_tensor_dims(4, counter) + maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the maxpool result is a tensor of size 4 + c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) + + # the input corresponds to the output in the first and second dimension of maxpool + c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcProduct) +def generate_calc_product(constraint, counter): + """ + Transform flatten constraints + """ + start = constraint.start + end = constraint.end + dims = constraint.dims_to_flatten + flattened = constraint.flattened + n = len(constraint.dims_to_flatten) + + # this will be evaluated right here + boundary_check = (0 <= start and start < end and end <= n) + + c_boundary = T() if boundary_check else F() + + lhs = dims[0:start] + rhs = dims[end:] + mid = dims[start:end] + + all_possibilities = generate_all_int_dyn_dim_possibilities(mid) + + all_constraints = [] + + for p in all_possibilities: + p = list(p) + # this tells us there is a dynamic variable + contains_dyn = not(all([constraint.op == op_neq for constraint in p])) + if contains_dyn: + mid_var = [Dyn] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) + else: + new_var, counter = gen_dvar(counter) + mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) + mid_var = [new_var] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) + + return Conj([Disj(all_constraints), c_boundary]), counter + + +@register_transformation_rule(CanReshape) +def generate_reshape(constraint, counter): + """ + Transform reshape constraints + """ + d, counter = gen_tensor_dims(4, counter) + + d1 = d[0] + d2 = d[1] + d3 = d[2] + d4 = d[3] + + target = constraint.target.__args__ + + is_fully_static = all([d != Dyn for d in target]) + + # dynamic tensor + c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) + c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) + c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) + c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) + c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) + + d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) + d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) + + d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) + d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) + + d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + nat_d1 = BinConstraintD(0, d1, op_leq) + nat_d2 = BinConstraintD(0, d2, op_leq) + nat_d3 = BinConstraintD(0, d3, op_leq) + nat_d4 = BinConstraintD(0, d4, op_leq) + + if is_fully_static: + # size 1 tensor + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + BinConstraintD(d1, Prod(target), op_eq)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # size 2 tensor + all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) + + # size 3 tensor + all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) + + # size 4 tensor + all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + # then there must be exactly one occurrence of dyn + else: + new_target = [] + + for n in target: + if n != Dyn: + new_target.append(n) + + # tensor 1 + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + is_dim_div_by_target(new_target, d1)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # tensor 2 + c21 = Disj([d1_eq_dyn, d2_eq_dyn]) + c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) + all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) + + # tensor 3 + c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) + c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) + all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) + + # tensor 4 + c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) + c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) + all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + +@register_transformation_rule(ApplyBroadcasting) +def generate_broadcasting(constraint, counter): + """ + Transform broadcasting constraints + """ + e11, e12 = constraint.res1, constraint.res2 + e1, e2 = constraint.input1, constraint.input2 + + e1_dyn = BinConstraintT(e1, Dyn, op_eq) + e2_dyn = BinConstraintT(e2, Dyn, op_eq) + + # Introduce dimensions + e1_equal_e11 = BinConstraintT(e1, e11, op_eq) + e2_equal_e12 = BinConstraintT(e2, e12, op_eq) + + # dyn possibility + e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) + e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) + + # tensor possibility + # generate dimensions to create tensors of size 1 + final_tensor_1_constraint, _, _, nat_dims_1, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) + + # generate dimensions to create tensors of size 2 + final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ + final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ + final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ + final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj([ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2 + ]) + + return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter + + +def transform_constraint(constraint: Constraint, counter: int): + """ + Transforms a constraint into a simpler constraint. + Ex: precision and consistency are transformed to equality + Args: + constraint: constraint to be transformed + counter: for variable tracking + + Returns: Constraint + + """ + if type(constraint) in _TRANSFORMATION_RULES: + return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) + + else: + return constraint, counter + + + + +def calc_last_two_dims(constraint, d: List[DVar]): + """ + Generates constraints for the last two dimensions of a convolution or a maxpool output + Args: + constraint: CalcConv or CalcMaxPool + d: The list of output dimensions + + Returns: Constraints for calculating the last two dimensions of the output + + """ + + assert isinstance(constraint, CalcConv) or isinstance(constraint, CalcMaxPool) + + b3 = constraint.matching_constraint[2] + b4 = constraint.matching_constraint[3] + + b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) + b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) + + d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) + d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) + + # transform parameters into tuples incase they are not already + padding = (constraint.padding, constraint.padding) \ + if isinstance(constraint.padding, int) else constraint.padding + kernel = (constraint.kernel, constraint.kernel) \ + if isinstance(constraint.kernel, int) else constraint.kernel + stride = (constraint.stride, constraint.stride) \ + if isinstance(constraint.stride, int) else constraint.stride + dilation = (constraint.dilation, constraint.dilation) \ + if isinstance(constraint.dilation, int) else constraint.dilation + + f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) + f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) + f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) + f4 = BinConstraintD(f3, 1, op_add) + + c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) + + f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) + f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) + f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) + f44 = BinConstraintD(f33, 1, op_add) + + c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) + + return c4, c5 + + +def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): + """ + Generate all possibilities of being equal or not equal to dyn for my_list + Args: + my_list: List of tensor dimensions + + Returns: A list of a list of constraints. Each list of constraints corresponds to + one possibility about the values of the dimension variables + """ + # generate all possibilities of being equal or not equal to dyn for my_list + eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] + neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] + d_possibilities = [] + + for i in zip(eq_possibilities, neq_possibilities): + d_possibilities.append(list(i)) + all_possibilities = list(itertools.product(*d_possibilities)) + return all_possibilities + + +def is_target_div_by_dim(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the target dimensions are divisible by the input dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) + + +def is_dim_div_by_target(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the input dimensions is divisible by the target dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) + + +def gen_all_reshape_possibilities(list_of_dims, target): + """ + Consider all possibilities what the input dimensions could be (number or dynamic) + Then generate the appropriate constraints using multiplication or mod depending on the possibility + The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn + for the input. Target is fixed because at most one dimension could be dyn. + We have different cases for this. + + Args: + list_of_dims: The input list of dimensions + target: The tensor we want to reshape to + + Returns: A disjuncition of transformed reshape constraints + + """ + all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) + + all_constraints = [] + + for p in all_possibilities: + to_multiply = [] + + p = list(p) + + for constraint in p: + assert isinstance(constraint, BinConstraintD) + if constraint.op == op_neq: + to_multiply.append(constraint.lhs) + + if not to_multiply: + all_constraints.append(Conj(p)) + + elif len(to_multiply) < len(list_of_dims): + all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) + else: + all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), + Prod(target), op_eq)])) + + return Disj(all_constraints) + + +def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): + """ + Apply broadcasting to the 'index' dimension of tensor_input1. + Args: + tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 + tensor_input2: represents the second input + res1: broadcasted result 1 + res2: broadcasted result 2 + index: the index to broadcast + padding: If padding was used, then tensor_input1[index] does not exist + + Returns: + + """ + if tensor_input1[index] is None: + assert padding + + + if not padding: + # then the inputs are the same length so they all have dimensions at "index" + return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + else: + # we don't set the input dimension to 1, since it doesn't exist. + return Conj([BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + +def apply_padding(e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], + counter: int): + """ + We are considering the possibility where one input has less dimensions than + another input, so we apply padding to the broadcasted results + + Args: + e1_var: Variable representing the first input where padding will be + e11: constraint of the form e11 = Tensortype[d1, ..., dn] + e2: constraint of the form e2 = Tensortype[d1, ..., dn] + e12: constraint of the form e11 = Tensortype[d1, ..., dn] + d2: Tensor variables for the second input + d11: Tensor variables for the broadcasted first input + d12: Tensor variables for the broadcasted second input + counter: variable tracking + + Returns: A new constraint whose goal is to apply padding to the broadcasted result + + """ + + res = [] + + # pad the shorter input with None so we can pass it to the broadcasting helper function + for i in range(1, len(d2)): + + d1, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) + + e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) + + simulate_padding = [None] * (len(d2) - i) + + assert len(simulate_padding + d1) == len(d2) + + broadcast_padding = [] + + # for every padding size, we also consider broadcasting + for j in range((len(d2) - i)): + broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) + + # we consider the possibilities for broadcasting for every dimension. Since we already + # padded d1, we do not consider it while broadcasting + all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, + d2[(len(d2) - i):], + d11[(len(d2) - i):], + d12[(len(d2) - i):]) + # combine all constraints into a conjunction + c = Conj([e1, e11, e2, e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints + ]) + res.append(c) + + return Disj(res), counter + + +def no_broadcast_dim_with_index(d1: List[DVar], + d2: List[DVar], + d3: List[DVar], + d4: List[DVar], + i: int): + """ + Args: + d1: inpput 1 + d2: inpput 2 + d3: simulated broadcasting for input 1 + d4: simulated broadcasting for input 2 + i: the rank of the resulting tensor addition + + Returns: Constraints for when no broadcasting occurs + """ + return Conj([ + Disj([ + Conj([BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq)]), + + Conj([BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq)])]), + + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq)]) + + + +def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): + """ + Generate lists of DVar to represent tensor dimensions + Args: + num_tensors: the required number of tensors + dim_size: the number of dimensions for each tensor + counter: variable tracking + + Returns: A list of a list of tensor dimensions + + """ + res = [] + + for _ in range(num_tensors): + dims, counter = gen_tensor_dims(dim_size, counter) + res.append(dims) + + return res, counter + + +def create_equality_constraints_for_broadcasting(e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: List[DVar], + d2: List[DVar], + d11: List[DVar], + d12: List[DVar]): + """ + Create equality constraints for when no broadcasting occurs + Args: + e1: Input 1 + e2: Input 2 + e11: Broadcasted input 1 + e12: Broadcasted input 2 + d1: Variables that store dimensions for e1 + d2: Variables that store dimensions for e2 + d11: Variables that store dimensions for e11 + d12: Variables that store dimensions for e22 + + Returns: Four equality constraints + + """ + + e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) + e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) + e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) + e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) + return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] + + +def gen_consistency_constraints(constraint: Constraint, counter: int): + """ + Args: + constraint: Consistency constraint on tensors + counter: for variable tracking + + Returns: Equality and consistency constraints on dimensions + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + + [BinConstraintD(d1, d2, op_consistency) for + d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) + + all_constraints.append(c_tensor_i) + + return all_constraints, counter + + +def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): + """ + Args: + constraint: Greatest upper bound on tensors + counter: variable tracking + + Returns: A set of equality constraints and DGreatestUpperBound constraints + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + c = [] + dims1, counter = gen_tensor_dims(i, counter) + c1tensor = TensorType(dims1) + + dims2, counter = gen_tensor_dims(i, counter) + c2tensor = TensorType(dims2) + + dims3, counter = gen_tensor_dims(i, counter) + c3tensor = TensorType(dims3) + + c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq)] + \ + gen_nat_constraints(dims1 + dims2 + dims3) + + assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + for i in range(len(c3tensor.__args__)): + c.append(DGreatestUpperBound(c3tensor.__args__[i], + c1tensor.__args__[i], + c2tensor.__args__[i])) + + all_constraints.append(Conj(c)) + return all_constraints, counter + + +def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): + """ + Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. + We look at all combinations for all dimendions in d1 and d2 + Args: + d1: input1 dimensions + d2: input2 dimensions + d11: broadcasted input1 dimensions + d12: broadcasted input2 dimensions + + Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions + + """ + + size = len(d1) + + res2 = [] + + for i in range(size): + t1 = broadcast_dim(d1, d2, d11, d12, i) + t2 = broadcast_dim(d2, d1, d12, d11, i) + t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) + + res2.append(Disj([t1, t2, t3])) + + return Conj(res2) + + +def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): + """ + Simulates broadcasting on e1 and e2 and returns the results + respectively in e11 and e12. Because of gradual types, + e1 and e2 may not be equal. Similarly, e11 and e12 may not + be equal. e11 and e12 should be guaranteed to be consistent + as they represent the shapes of the tensors to be added after + broadcasting. + Args: + e1: TVar representing the type of input 1 + e2: TVar representing the type of input 2 + e11: TVar representing the representing broadcasted input 1 + e12: TVar representing the representing broadcasted input 2 + i: The rank of the resulting type of addition + counter: for variable tracking + + Returns: Simplified broadcasting constraints + + """ + dims, counter = gen_lists_of_dims(4, i, counter) + [d1, d2, d3, d4] = dims + nat_dims_i = gen_nat_constraints(list(itertools.chain(*dims))) + + initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, + d1, d2, d3, d4) + + [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints + + # without padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) + + # with padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_padding_arg1, counter = \ + apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) + + final_tensor_constraint_padding_arg2, counter = \ + apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) + + return final_tensor_constraint_no_padding, \ + final_tensor_constraint_padding_arg1, \ + final_tensor_constraint_padding_arg2, nat_dims_i, counter diff --git a/torch/fx/experimental/migrate_gradual_types/operation.py b/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000..68bba2d59a760 --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +op_add = '+' +op_sub = '-' +op_mul = '*' +op_div = '/' +op_eq = '=' +op_neq = '!=' +op_imp = '=>' +op_matching = '⊳' +op_consistency = '~' +op_precision = '⊑' +op_leq = '≤' +op_lt = '<' +op_gt = '>' +op_mod = '%' diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py new file mode 100644 index 0000000000000..1a494845caf39 --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -0,0 +1,348 @@ +from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr +from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar +from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt +from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod +from torch.fx.tensor_type import TensorType, Dyn + +try: + import z3 # type: ignore[import] + from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D + HAS_Z3 = True + + def transform_to_z3(constraint, counter, dimension_dict): + if isinstance(constraint, Conj): + conjuncts = [] + for c in constraint.conjucts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + conjuncts.append(new_c) + return z3.And(conjuncts), counter + + elif isinstance(constraint, Disj): + disjuncts = [] + for c in constraint.disjuncts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + disjuncts.append(new_c) + return z3.Or(disjuncts), counter + + elif isinstance(constraint, T): + return True, counter + + elif isinstance(constraint, F): + return False, counter + + elif isinstance(constraint, BinConstraintT): + if constraint.op == op_eq: + lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + return (lhs == rhs), counter + + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(constraint, BinConstraintD): + if constraint.op == op_eq: + + if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): + transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) + transformed_lhs = z3.Bool(constraint.lhs.c) + return transformed_lhs == transformed_rhs, counter + + elif is_dim(constraint.lhs) and is_dim(constraint.rhs): + # with dimension tranformations we consider the encoding + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + else: + # then we have an algebraic expression which means that we disregard the + # first element of the encoding + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + # The assumption here is that the LHS and RHS must be dimensions + elif constraint.op == op_neq: + assert is_dim(constraint.lhs) + assert is_dim(constraint.rhs) + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + if constraint.rhs == Dyn or constraint.lhs == Dyn: + if constraint.rhs == Dyn: + return lhs.arg(0) == 1, counter + elif constraint.lhs == Dyn: + return rhs.arg(0) == 1, counter + + # if one of the instances is a number + elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): + if isinstance(constraint.lhs, int): + return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + elif isinstance(constraint.rhs, int): + return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + else: + return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter + + + elif constraint.op == op_leq: + # if the dimensions are not dyn, this will come into effect + # there would have been another constraint specifying if a given dimension + # is dyn or not + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs <= rhs, counter + + elif constraint.op == op_gt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs > rhs, counter + + elif constraint.op == op_lt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs < rhs, counter + + else: + raise NotImplementedError('operation not yet implemented') + + else: + raise NotImplementedError('Operation not yet implemented') + + + def transform_var(tensor, counter, dimension_dict): + """ + Transforms tensor variables to a format understood by z3 + Args: + tensor: Tensor variable or a tensor type potentially with variable dimensions + Returns: Transformed variable to a z3 format + + """ + if isinstance(tensor, TensorType): + res = [] + for t in tensor.__args__: + transformed, counter = transform_dimension(t, counter, dimension_dict) + res.append(transformed) + + assert len(res) <= 4 + if len(tensor.__args__) == 1: + return tensor_type.tensor1(res[0]), counter + elif len(tensor.__args__) == 2: + return tensor_type.tensor2(res[0], res[1]), counter + elif len(tensor.__args__) == 3: + return tensor_type.tensor3(res[0], res[1], res[2]), counter + elif len(tensor.__args__) == 4: + return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + + elif tensor == Dyn: + return z3_dyn, counter + + elif isinstance(tensor, TVar): + return z3.Const(tensor.tvar, tensor_type), counter + + def transform_dimension(dimension, counter, dimension_dict): + """ + Takes a dimension variable or a number and transforms it to a tuple + according to our scheme + Args: + dimension: The dimension to be transformed + counter: variable tracking + + Returns: tuple and the current counter + + """ + if dimension == Dyn: + counter += 1 + return D(0, z3.Int(counter)), counter + elif isinstance(dimension, int): + return D(1, dimension), counter + elif isinstance(dimension, DVar): + if dimension.c in dimension_dict: + return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter + else: + counter += 1 + dimension_dict[dimension.c] = counter + return D(z3.Int(counter), z3.Int(dimension.c)), counter + + + def transform_algebraic_expression(expr, counter, dimension_dict): + """ + Transforms an algebraic expression to z3 format + Args: + expr: An expression is either a dimension variable or an algebraic-expression + + + Returns: the transformed expression + + """ + assert is_algebraic_expression(expr) or is_dim(expr) + + if is_dim(expr): + transformed, counter = transform_dimension(expr, counter, dimension_dict) + return transformed.arg(1), counter + + elif isinstance(expr, Prod): + + dims = [] + for dim in expr.products: + assert is_dim(dim) + d, counter = transform_dimension(dim, counter, dimension_dict) + dims.append(d.arg(1)) + return z3.Product(dims), counter + + elif is_algebraic_expression(expr): + + lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) + + if expr.op == op_sub: + c = lhs - rhs + + elif expr.op == op_add: + c = lhs + rhs + + elif expr.op == op_div: + c = lhs / rhs + + elif expr.op == op_mul: + c = lhs * rhs + + elif expr.op == op_mod: + c = lhs % rhs + + else: + raise NotImplementedError('operation not yet implemented') + + return c, counter + + else: + raise RuntimeError + + + def transform_all_constraints(traced, counter=0): + """ + Given a trace, generates constraints and transforms them to z3 format + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(counter) + + # print(new_constraints.conjucts[0]) + # print(*new_constraints.conjucts, sep='\n') + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + # print(new_constraints) + # print(new_constraints.conjucts) + # new_constraints.conjucts = new_constraints.conjucts[:-1] + # print(*new_constraints.conjucts, sep='\n') + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + # print(transformed) + return transformed + + def iterate_till_fixed_point(constraints, counter): + """ + Transform constraints till reaching a fixed point + """ + old_c = None + while old_c != constraints: + old_c = constraints + constraints, counter = transform_constraint(constraints, counter) + return constraints, counter + + def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + """ + Takes a node and a graph and generates two sets of constraints. + One set constraints the node's constraints and another set + constraints the negation of the node's constraints + Args: + tracer_root: the root for getting the module instances + graph: the graph so far in the tracing process + node: node that represents a conditional + counter: variable tracking + + Returns: Two sets of constraints. One with a conjunction with the + the conditional constraint and the other with a conjunction with + its negation. + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(tracer_root, graph) + new_constraints, counter = generator.generate_constraints(counter) + + condition_constraint = new_constraints.conjucts[-1] + + # we know the constraint is a conjunction where the last constraint is about the conditional + # so remove the last constraint + new_constraints.conjucts = new_constraints.conjucts[:-1] + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + + + # since the function returns a list of one element, we get the first element + # we are only interested in the RHS in this case because the LHS just stores + # the result + + # we make sure the constraint is of the form: + # c = b where b is a boolean expression + # and we consider b (constraint.rhs) for transformation + assert isinstance(condition_constraint.lhs, BVar) + assert is_bool_expr(condition_constraint.rhs) + condition_constraint_rhs = condition_constraint.rhs + + # transform the condition constraint + condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + + transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) + + negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) + + return z3.And([transformed, transformed_condition_constraint]),\ + z3.And([transformed, negation_transformed_condition_constraint]) + + + def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): + """ + Given an IR and a node representing a conditional, evaluate the conditional + and its negation + Args: + tracer_root: Tracer root for module instances + node: The node to be evaluated + + Returns: the results of evaluating the condition and the negation with + the rest of the constraints + + """ + + transformed_positive, transformed_negative = \ + transform_all_constraints_trace_time(tracer_root, graph, node, counter) + + s = z3.Solver() + s.add(transformed_positive) + if user_constraints is not None: + s.add(user_constraints) + condition = s.check() + + s = z3.Solver() + s.add(transformed_negative) + if user_constraints is not None: + s.add(user_constraints) + negation = s.check() + return condition, negation + +except ImportError: + HAS_Z3 = False diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py new file mode 100644 index 0000000000000..a43d8f3ebbe06 --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -0,0 +1,52 @@ +from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ + BVar +from torch.fx.experimental.migrate_gradual_types.operation import op_leq + + +def gen_tvar(curr): + """ + Generate a tensor variable + :param curr: The current counter + :return: a tensor variable and the updated counter + """ + curr += 1 + return TVar(curr), curr + + +def gen_dvar(curr): + """ + Generate a dimension variable + :param curr: the current counter + :return: a dimension variable and an updated counter + """ + curr += 1 + return DVar(curr), curr + +def gen_bvar(curr): + """ + Generate a boolean variable + :param curr: the current counter + :return: a boolean variable and an updated counter + """ + curr += 1 + return BVar(curr), curr + +def gen_tensor_dims(n, curr): + """ + Generate a list of tensor dimensions + :param n: the number of dimensions + :param curr: the current counter + :return: a list of dimension variables and an updated counter + """ + dims = [] + for _ in range(n): + dvar, curr = gen_dvar(curr) + dims.append(dvar) + return dims, curr + + +def gen_nat_constraints(list_of_dims): + """ + Generate natural number constraints for dimensions + """ + return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/torch/fx/experimental/migrate_gradual_types/z3_types.py b/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000..897a79d569757 --- /dev/null +++ b/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,29 @@ +try: + import z3 # type: ignore[import] + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort('Dyn') + dyn_type = z3.Const('dyn', dyn) + + # dimension + dim = z3.Datatype('dim') + dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype('TensorType') + tensor_type.declare('Dyn', ('dyn', dyn)) + tensor_type.declare('tensor1', ('0', dim)) + tensor_type.declare('tensor2', ('0', dim), ('1', dim)) + tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) + tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index abbd90cf7e97c..0bd8de42e0e9d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -3,25 +3,76 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib +import copy import functools from typing import Any, Dict, Optional, Tuple, Callable, Union import torch from torch._C import _disabled_torch_function_impl import torch.utils._pytree as pytree from torch.fx import Tracer, GraphModule +from torch._subclasses.fake_tensor import FakeTensorMode import torch.fx as fx from torch.utils._mode_utils import no_dispatch from torch.fx.passes.shape_prop import _extract_tensor_metadata -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext -from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode +from torch._subclasses import FakeTensor +from .symbolic_shapes import ShapeEnv, magic_methods, reflectable_magic_methods +import torch.fx.experimental.symbolic_shapes as symbolic_shapes -__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"] +__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict", "DecompositionInterpreter"] aten = torch.ops.aten CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {} +class ProxySymInt(object): + def __init__(self, sym_int, proxy): + assert isinstance(sym_int, torch._C.SymIntNode) or isinstance(sym_int, int) + self.sym_int = sym_int + self.proxy = proxy + + def wrap(self, num): + return ProxySymInt(num, num) + + def __str__(self): + return f"ProxySymInt({self.sym_int})" + + def __int__(self): + # Not sure how to make mypy support this lol + return int(self.sym_int) # type: ignore[arg-type] + + def __bool__(self): + return bool(self.sym_int) + +import operator + +def create_magic_impl(op): + def magic_impl(self, other): + def unwrap_proxy(x): + return x.proxy if isinstance(x, ProxySymInt) else x + out_proxy = op(unwrap_proxy(self), unwrap_proxy(other)) + + def unwrap_proxyint(x): + return x.sym_int if isinstance(x, ProxySymInt) else x + out_sym_int = op(unwrap_proxyint(self), unwrap_proxyint(other)) + return ProxySymInt(out_sym_int, out_proxy) + return magic_impl + +for method in reflectable_magic_methods: + method_name = f'{method}' + + op = getattr(operator, method_name) + setattr(ProxySymInt, f'r{method_name}', create_magic_impl(op)) + +for method in magic_methods: + method_name = f'{method}' + + op = getattr(operator, method_name) + setattr(ProxySymInt, method_name, create_magic_impl(op)) + @contextmanager def decompose(decomposition_table): @@ -39,32 +90,68 @@ def enable_strict(val): global IS_STRICT IS_STRICT = val -def wrap_output(real_out, proxy_out): - def wrap_with_proxy(e, proxy): - if type(e) == torch.Tensor: +def wrap_output(inner_res, proxy_res, *, constant, proxy_mode): + def wrap_with_proxy(e, proxy, constant): + if isinstance(e, torch.Tensor): with no_dispatch(): - return ProxyTensor(e, proxy) + return ProxyTensor(e, proxy, constant=constant, proxy_mode=proxy_mode) else: return e + def get_constant(idx): + if constant is None: + return None + else: + return constant[idx] + # Unfortunately, tree_map cannot directly be used here. As the resulting # object may be a proxy that represents a tuple, we may need to # explicitly unwrap the proxy by simulating the flattening operations. - if isinstance(real_out, tuple): - return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)) - elif isinstance(real_out, list): - return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]) - elif isinstance(real_out, torch.Tensor): - return wrap_with_proxy(real_out, proxy_out) + if isinstance(inner_res, tuple): + return tuple(wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res)) + elif isinstance(inner_res, list): + return list([wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res)]) + elif isinstance(inner_res, torch.Tensor): + return wrap_with_proxy(inner_res, proxy_res, constant) + else: + return inner_res + + +def maybe_disable_fake_tensor_mode(): + # TODO: figure out if this API generally makes sense and bake it into the + # library + mb_fake_mode = torch._C._get_torch_dispatch_mode() + if isinstance(mb_fake_mode, FakeTensorMode): + return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode) else: - return real_out + return nullcontext() + + +def unwrap_elem(e): + if isinstance(e, ProxyTensor): + return e.elem + if isinstance(e, torch._C.SymIntNode): + if isinstance(e.get_pyobj(), ProxySymInt): + return e.get_pyobj().sym_int + else: + raise RuntimeError(f"Something has gone wrong, we are trying to put SymInt {e.get_pyobj()} into the graph," + f"even though it's not a ProxySymInt. This is a bug.") + return e + +def proxy_call(proxy_mode, func_overload, args, kwargs=None): + if kwargs is None: + kwargs = {} -def proxy_call(func_overload, args, kwargs=None): func = func_overload.overloadpacket if func_overload in CURRENT_DECOMPOSITION_TABLE: return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) if func_overload == aten._local_scalar_dense.default: + t, = args + assert not kwargs + if t.constant is not None: + with maybe_disable_fake_tensor_mode(): + return t.constant.item() raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! " "It's likely that this is caused by data-dependent control flow or similar." "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check") @@ -75,51 +162,150 @@ def unwrap_proxy(e): proxy_args = pytree.tree_map(unwrap_proxy, args) proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs) - proxy_out = func(*proxy_args, **proxy_kwargs) - + proxy_res = func_overload(*proxy_args, **proxy_kwargs) # Kind of a hacky way to test if an op is in-place or not if func.__name__[-1] == "_" and func.__name__[0] != "_": - args[0].proxy = proxy_out - proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0]) + args[0].proxy = proxy_res + proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0]) + inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs)) + + # Needed to sync up metadata for in-place operators that modify metadata + # TODO: instead forward the metadata to the inner tensor so updating + # is not necessary + if torch.Tag.inplace_view in func_overload.tags: # type: ignore[attr-defined] + with no_dispatch(): + func_overload(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. + all_constant = True + any_constant = False + + def check_constant(e): + nonlocal all_constant, any_constant + if isinstance(e, ProxyTensor): + if e.constant is None: + all_constant = False + else: + any_constant = True + + pytree.tree_map(check_constant, args) + pytree.tree_map(check_constant, kwargs) - with no_dispatch(): - real_out = func_overload(*args, **kwargs) + def unwrap_constant(e): + if isinstance(e, ProxyTensor): + return e.constant + return e + + constant = None + # NB: do NOT include factories as constants + if all_constant and any_constant: + with maybe_disable_fake_tensor_mode(): + constant = func_overload( + *pytree.tree_map(unwrap_constant, args), + **pytree.tree_map(unwrap_constant, kwargs) + ) + + # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general + # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs)) + return wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=proxy_mode) - return wrap_output(real_out, proxy_out) class ProxyTensor(torch.Tensor): proxy: fx.Proxy + elem: torch.Tensor + has_sym_ints: bool + proxy_mode: "ProxyTorchDispatchMode" + @staticmethod - def __new__(cls, elem, proxy, *, requires_grad=None): - # Hack to deal with super().__new__ not working for sparse tensors - if elem.is_sparse or requires_grad is not None: - if requires_grad is None: - requires_grad = False - r = torch.Tensor._make_subclass(cls, elem, requires_grad) + def __new__(cls, elem, proxy, *, requires_grad=None, constant=None, proxy_mode): + def create_proxy_symint(sym_int, new_proxy): + return torch._C.SymIntNode.new_symint(ProxySymInt(sym_int, new_proxy)) + + has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem) + if has_sym_ints: + new_shape = [] + for idx, s in enumerate(elem.shape): + if isinstance(s, torch._C.SymIntNode): + new_shape.append(create_proxy_symint(s, proxy.size(idx))) + else: + assert isinstance(s, int) + # If it's not an existing SymIntNodeImpl, just pass the proxy as the int + # _make_wrapper_subclass requires all inputs to be SymIntNodeImpls + new_shape.append(create_proxy_symint(s, s)) + # TODO: hack, since we currently don't support symbolic strides + new_strides = symbolic_shapes.create_contiguous(new_shape) else: - r = super().__new__(cls, elem) # type: ignore[call-arg] + new_shape = elem.shape + new_strides = elem.stride() + + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + new_shape, dtype=elem.dtype, layout=elem.layout, device=elem.device, + requires_grad=requires_grad if requires_grad is not None else False, strides=new_strides, + storage_offset=elem.storage_offset() + ) + r.has_sym_ints = has_sym_ints + return r - if elem.is_sparse: + def __init__(self, elem, proxy, *, requires_grad=None, constant=None, proxy_mode): + # TODO: hack since _extract_tensor_metadata currently tries to access stride + if elem.is_sparse or self.has_sym_ints: proxy.node.meta['tensor_meta'] = {} else: - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) - r.proxy = proxy # type: ignore[attr-defined] + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self) + # This detects situations where you accidentally put a ProxyTensor + # inside a ProxyTensor for the same trace; this is a layering violation + assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer) + self.elem = elem + self.proxy = proxy + self.constant = constant + self.proxy_mode = proxy_mode - return r def __deepcopy__(self, memo): return self.clone() def __repr__(self): with no_dispatch(): - return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type] + return f"ProxyTensor({self.elem}, proxy={self.proxy})" __torch_function__ = _disabled_torch_function_impl @classmethod def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): - return proxy_call(func_overload, args, kwargs) + # Get the first proxy mode. If there are different proxy modes with + # different tracers torch.fx.Proxy would raise an error. + proxy_mode = None + for arg in pytree.tree_flatten((args, kwargs))[0]: + if isinstance(arg, ProxyTensor): + if proxy_mode is None: + proxy_mode = arg.proxy_mode + break + assert proxy_mode is not None, "At least one argument must be a ProxyTensor" + + with proxy_mode.restore(): # type: ignore[union-attr] + return func_overload(*args, **kwargs) class PythonKeyTracer(Tracer): @@ -151,25 +337,24 @@ def create_arg(self, a: Any): setattr(self.root, qualname, a) return self.create_node('get_attr', qualname, (), {}) + elif isinstance(a, torch._C.SymIntNode): + py_symint = a.get_pyobj() + assert isinstance(py_symint, ProxySymInt) + return py_symint.proxy.node return super().create_arg(a) def dispatch_trace( root: Union[torch.nn.Module, Callable], + tracer: Tracer, concrete_args: Optional[Tuple[Any, ...]] = None, - trace_factory_functions: bool = False, ) -> GraphModule: - tracer = PythonKeyTracer() - if trace_factory_functions: - with push_torch_dispatch_mode(functools.partial(ProxyTorchDispatchMode, tracer)): - graph = tracer.trace(root, concrete_args) - else: - graph = tracer.trace(root, concrete_args) + graph = tracer.trace(root, concrete_args) name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ return GraphModule(tracer.root, graph, name) -def wrap_key(f, inps): +def wrap_key(f, inps, proxy_mode): flat_inps, _ = pytree.tree_flatten(inps) @functools.wraps(f) @@ -179,9 +364,12 @@ def wrapped(*args): for idx, arg in enumerate(flat_args): if isinstance(flat_inps[idx], torch.Tensor): with no_dispatch(): - flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=( - flat_inps[idx].is_leaf and flat_inps[idx].requires_grad - )) + flat_args[idx] = ProxyTensor( + flat_inps[idx], + arg, + requires_grad=(flat_inps[idx].is_leaf and flat_inps[idx].requires_grad), + proxy_mode=proxy_mode, + ) else: flat_args[idx] = flat_inps[idx] @@ -197,33 +385,230 @@ def wrapped(*args): class ProxyTorchDispatchMode(TorchDispatchMode): - def __init__(self, tracer): + def __init__(self, tracer, trace_factory_functions=True): self.tracer = tracer + self.trace_factory_functions = trace_factory_functions def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None): + if symbolic_shapes.is_symbolic_op(func_overload): + return symbolic_shapes.handle_symbolic_op(func_overload, args, kwargs) + func = func_overload.overloadpacket + # We don't want to convert torch.tensor constants into tracing objects. + if func_overload == aten.lift.default: + return args[0] if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])): - return proxy_call(func_overload, args, kwargs) - else: - proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs, + return proxy_call(self, func_overload, args, kwargs) + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + elif self.trace_factory_functions: + if func_overload is torch.ops.aten.lift_fresh.default: + func_overload = torch.ops.aten.lift_fresh_copy.default + + proxy_res = self.tracer.create_proxy('call_function', func_overload, args, kwargs, name=self.tracer.graph._target_to_str(func.__name__)) - with no_dispatch(): - real_out = func_overload(*args, **kwargs) - - return wrap_output(real_out, proxy_out) + inner_res = func_overload(*args, **kwargs) + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default + if is_lift: + with maybe_disable_fake_tensor_mode(): + constant = args[0].clone() + else: + constant = None + return wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=self) + else: + return func_overload(*args, **kwargs) + + +class DecompositionInterpreter(torch.fx.Interpreter): + def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs): + super().__init__(module, **kwargs) + self.new_graph = new_graph + self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph) + self.decomposition_table = decomposition_table + if self.decomposition_table is None: + self.decomposition_table = {} + self.mode = ProxyTorchDispatchMode(self.tracer) + + def placeholder(self, target, args, kwargs): + out = super().placeholder(target, args, kwargs) + # TODO handle case where the first character of target is '*' + return ProxyTensor(out, torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer), proxy_mode=self.mode) + + def get_attr(self, target, args, kwargs): + out = super().get_attr(target, args, kwargs) + return ProxyTensor(out, torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer), proxy_mode=self.mode) + + # call_function, call_method, call_module get traced automatically by the ProxyTensors. + + def output(self, target, args, kwargs): + out = super().output(target, args, kwargs) + + def unwrap(e): + return e.proxy.node if isinstance(e, ProxyTensor) else e + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args, **kwargs): + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with self.mode: + pass + with decompose(self.decomposition_table): + return super().run(*args, **kwargs) + +def make_fx(f, decomposition_table=None, trace_factory_functions=True, tracing_mode="real"): + if tracing_mode != "real" and not trace_factory_functions: + raise ValueError("""\ +use_fake and not trace_factory_functions is not currently supported; if +proxy tensor is not executed as a mode, fake tensors must not be executed +as a mode either (otherwise, we will incorrectly intern fake tensors into +the traced graph module.) However, non-mode execution of fake tensors +is not currently supported (although, in principle, it could be; file +a bug if you need this)""") + + assert tracing_mode in ["real", "fake", "symbolic"] -def make_fx(f, decomposition_table=None, trace_factory_functions=True): if decomposition_table is None: decomposition_table = {} @functools.wraps(f) def wrapped(*args): - phs = pytree.tree_map(lambda x: fx.PH, args) # type: ignore[attr-defined] - with decompose(decomposition_table): - t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs), - trace_factory_functions=trace_factory_functions) + phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] + fx_tracer = PythonKeyTracer() + fake_tensor_mode: Any = nullcontext() + if tracing_mode == "real": + fake_tensor_mode = nullcontext() + elif tracing_mode == "fake": + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + elif tracing_mode == "symbolic": + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) + else: + raise AssertionError(f"Unexpected tracing type: {tracing_mode}") + + proxy_mode = ProxyTorchDispatchMode(fx_tracer, trace_factory_functions=trace_factory_functions) + + def wrap_fake_concrete(x): + if isinstance(x, torch.Tensor): + return fake_tensor_mode.from_tensor(x) # type: ignore[attr-defined] + + return x + + shape_env = ShapeEnv() + + # todo: Figure out a more informative name for symints + def wrap_fake_symbolic(x, sym_shape): + if isinstance(x, torch.Tensor): + val = FakeTensor(fake_tensor_mode, torch.empty(sym_shape, device="meta"), x.device) + return val + return x + + wrap_fn_map = { + "real": lambda x: x, + "fake": wrap_fake_concrete, + } + if tracing_mode == "symbolic": + flat_shapes = shape_env.create_shapes_for_args(args) + flat_args, spec = pytree.tree_flatten(args) + args = pytree.tree_unflatten(list(map(lambda a: wrap_fake_symbolic(a[0], a[1]), zip(flat_args, flat_shapes))), spec) + else: + args = pytree.tree_map(wrap_fn_map[tracing_mode], args) + + with decompose(decomposition_table), fake_tensor_mode, proxy_mode: # type: ignore[attr-defined] + t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs)) + + # TODO: kind of a bad way to do it, should maybe figure out a better way + t.shape_env = shape_env # type: ignore[assignment] return t return wrapped + + +def get_torch_dispatch_modes(): + modes = [torch._C._get_torch_dispatch_mode()] + if modes[-1] is None: + return list() + while modes[-1].inner is not None: + modes.append(modes[-1].inner) + return modes + + +def get_isolated_graphmodule(func, args, kwargs): + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + all_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(args): + fn_args, fn_kwargs = pytree.tree_unflatten(args, spec) + return func(*fn_args, **fn_kwargs) + + unwrapped_all_args = [unwrap_elem(a) for a in all_args] + + # Current implementation doesn't support the case when ProxyTensor is + # wrapped with another Tensor subclass + # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068 + assert all( + getattr(a, "elem", None) is None + for a in unwrapped_all_args + if isinstance(a, torch.Tensor) + ), "ProxyTensor is wrapped with another Tensor subclass" + + + with contextlib.ExitStack() as stack: + modes = get_torch_dispatch_modes() + # Disable all torch dispatch modes + for mode in modes: + stack.enter_context(enable_torch_dispatch_mode(mode.inner, replace=mode)) + assert torch._C._get_torch_dispatch_mode() is None + + # Enable all torch dispatch modes except ProxyTorchDispatchMode + for mode in reversed([m for m in modes if not isinstance(m, ProxyTorchDispatchMode)]): + mode = copy.copy(mode) + mode.inner = torch._C._get_torch_dispatch_mode() + mode.ancestors = set() if mode.inner is None else mode.inner.ancestors.union({mode.inner}) + stack.enter_context(mode.restore()) + gm = make_fx(wrapped)(unwrapped_all_args) + assert all(m1 == m2 for m1, m2 in zip(modes, get_torch_dispatch_modes())) + assert all(m1.inner == m2.inner for m1, m2 in zip(modes, get_torch_dispatch_modes())) + return gm diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000..098fff6e00ad9 --- /dev/null +++ b/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,155 @@ +import torch +import torch.utils._pytree as pytree +from typing import Dict, Any + +try: + import sympy # type: ignore[import] + HAS_SYMPY = True +except ImportError: + HAS_SYMPY = False + +aten = torch.ops.aten + +__all__ = ["has_symbolic_sizes_strides", "create_contiguous", "is_symbolic_op", "handle_symbolic_op", "PySymInt", "ShapeEnv"] + +def has_symbolic_sizes_strides(elem): + return any([isinstance(i, torch._C.SymIntNode) for i in elem.shape]) + +def create_contiguous(shape): + strides = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) + return list(reversed(strides)) + + +def is_symbolic_op(func): + return func in [aten.sym_size.default, aten.dim.default, aten.is_contiguous.default, aten.stride.default] + + +def handle_symbolic_op(func, args, kwargs): + assert is_symbolic_op(func) + if func == torch.ops.aten.sym_size.default: + return None + if func == torch.ops.aten.dim.default: + return len(args[0].shape) + # TODO: hack, need to make is_contiguous calls symbolic (probably through computing on symbolic strides) + if func == torch.ops.aten.is_contiguous.default: + return True + # TODO: hack, we don't currently support symbolic strides properly + if func == torch.ops.aten.stride.default: + return create_contiguous(args[0].shape) + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class PySymInt(object): + """ + PySymInt objects are the primary "symbolic shape" objects that flow through + our program. They're what sit under FakeTensor, and contains our primary + implementation of symbolic shapes. + """ + def __init__(self, expr, shape_env): + self.expr = expr + self.shape_env = shape_env + + def wrap(self, num): + return PySymInt(sympy.Integer(num), self.shape_env) + + def __str__(self): + return f"PySymInt({self.expr})" + + # Today we error on calling int on a symbolic shape, as this is a very accessible footgun. + # In the future we'll probably need some explicit way of allowing this + def __int__(self): + raise RuntimeError("Trying to extract a concrete int out of a symbolic int") + + def __bool__(self): + return bool(self.shape_env.evaluate_expr(self.expr)) + +# Methods that have a `__foo__` as well as `__rfoo__` +reflectable_magic_methods = { + 'add': lambda a, b: a + b, + 'sub': lambda a, b: a - b, + 'mul': lambda a, b: a * b, + 'mod': lambda a, b: a % b, +} + +magic_methods = { + **reflectable_magic_methods, + 'eq': lambda a, b: sympy.Eq(a, b), + 'gt': lambda a, b: sympy.Gt(a, b), + 'lt': lambda a, b: sympy.Lt(a, b), + 'le': lambda a, b: sympy.Le(a, b), + 'ge': lambda a, b: sympy.Ge(a, b), +} + +for method, _func in magic_methods.items(): + method_name = f'{method}' + + def _create_magic_impl(func): + def magic_impl(self, other): + if isinstance(other, PySymInt): + other = other.expr + return PySymInt(func(self.expr, other), self.shape_env) + return magic_impl + + # this should be wrapped transparently into torch._C.SymIntNode + setattr(PySymInt, method_name, _create_magic_impl(_func)) + +class ShapeEnv(object): + def __init__(self): + self.guards = [] + self.shape_env = {} + + def create_symint(self, name, val, shape_env=None): + if not HAS_SYMPY: + raise RuntimeError("Need sympy installed to create symbolic shapes") + if shape_env is None: + shape_env = self.shape_env + # Currently we don't put 0/1 specialization in guards but perhaps we should + if val == 0 or val == 1: + return val + sympy_expr = sympy.Symbol(name, positive=True) + py_sym_int = PySymInt(sympy_expr, self) + cpp_sym_int = torch._C.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined] + shape_env[sympy_expr] = val + return cpp_sym_int + + def try_constantify(self, expr): + # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values) + new_shape_env = {k: sympy.Symbol(f'shape_{idx}', positive=True) + 1 for idx, k in enumerate(self.shape_env.keys())} + new_expr = expr.subs(new_shape_env) + new_expr = new_expr.simplify() + if len(list(new_expr.free_symbols)) == 0: + return new_expr + return None + + def create_shapes_for_args(self, args, shape_env=None): + # Takes pytrees and returns a flat list + arg_cnt = 0 + + def create_shape(x): + nonlocal arg_cnt + if not isinstance(x, torch.Tensor): + return x + + out_shape = [self.create_symint(f"s_{arg_cnt}[{idx}]", sz, shape_env) for idx, sz in enumerate(x.shape)] + arg_cnt += 1 + return out_shape + return list(map(create_shape, pytree.tree_flatten(args)[0])) + + def evaluate_guards_for_args(self, *args): + env: Dict[Any, Any] = {} + _ = self.create_shapes_for_args(args, shape_env=env) + return all(guard.subs(env) == value for guard, value in self.guards) + + + def evaluate_expr(self, expr): + const_expr = self.try_constantify(expr) + if const_expr is not None: + return const_expr + + expr = expr.simplify() + concrete_val = expr.subs(self.shape_env) + self.guards.append((expr, concrete_val)) + return concrete_val diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 27a6082f7cea0..647b3520669c3 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -907,7 +907,7 @@ def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> boo "GraphModule.add_parameter to add the " "necessary Parameter, or " "nn.Module.register_buffer to add the " - "necessary buffer") + "necessary buffer", stacklevel=2) return self.create_node('get_attr', qualified_name, type_expr=type_expr) @compatibility(is_backward_compatible=True) @@ -1289,6 +1289,14 @@ def forward(self, x): def forward(self, x): return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations. """ # Lint the graph first to make sure its topologically sorted, otherwise # DCE below will not behave as expected. diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 89125185b8507..4f8bcc13d9e38 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -419,8 +419,11 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModu Path(folder).mkdir(exist_ok=True) torch.save(self.state_dict(), folder / 'state_dict.pt') tab = " " * 4 + custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()]) model_str = f""" import torch +{custom_builtins} + from torch.nn import * class {module_name}(torch.nn.Module): def __init__(self): @@ -703,6 +706,25 @@ def __deepcopy__(self, memo): def __copy__(self): return GraphModule(self, self.graph) + @compatibility(is_backward_compatible=False) + def nested_str(self) -> str: + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + module_code = self.code + module_code = module_code.lstrip('\n') + module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule in self.children(): + if isinstance(submodule, GraphModule): + submodule_code_list.append(submodule.__nested_code()) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + return module_code + submodule_code + def __str__(self) -> str: orig_str = super().__str__() return '\n'.join([orig_str, self._code]) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index cf945aa30eea0..df6404d393f9a 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -7,6 +7,8 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import inspect +__all__ = ['Interpreter', 'Transformer'] + @compatibility(is_backward_compatible=True) class Interpreter: """ diff --git a/torch/fx/node.py b/torch/fx/node.py index 6a9cfd5803e14..66a94154f9466 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from .graph import Graph +__all__ = ['Node', 'map_arg', 'map_aggregate'] + BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] diff --git a/torch/fx/passes/README.md b/torch/fx/passes/README.md new file mode 100644 index 0000000000000..a2996848713e5 --- /dev/null +++ b/torch/fx/passes/README.md @@ -0,0 +1,20 @@ +## FX Pass Infrastructure +This folder contains the pass infarstructure and passes for transforming fx.Graph. + + +## Code Structure + +* [infra](infra) - Common infrastructure, such as PassManager, PassBase + * [partitioner.py](infra/partitioner.py) - backend agnostic FX graph partitioner +* [utils](utils) - Utility classes and functions + * [common.py](utils/common.py) - common utility functions + * [fuser_utis.py](utils/fuser_utils.py) - utility functions for fusing list of nodes into a single node +* [dialect](dialect) - dialect specific passes + * [common](dialect/common) - common passes that can be shared by all dialects + * [cse_pass.py](dialect/common/cse_pass.py) - a CSE pass + * [aten](dialect/aten) - aten dialect specific passes + * [prims](dialect/prims) - prim dialect specific passes +* [backends](backends) - Backend specific passes + * [nvfuser](backends/nvfuser) - passes for nvfuser + * [operator_support.py](backends/nvfuser/operator_support.py) - nvFuser supported ops +* [conversion](conversion) - Conversion passes between dialects diff --git a/torch/fx/passes/__init__.py b/torch/fx/passes/__init__.py index c1930795b278c..9577d6c66a9ea 100644 --- a/torch/fx/passes/__init__.py +++ b/torch/fx/passes/__init__.py @@ -3,6 +3,7 @@ from . import net_min_base from . import operator_support from . import param_fetch +from . import reinplace from . import shape_prop from . import split_module from . import split_utils diff --git a/torch/fx/passes/backends/__init__.py b/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py new file mode 100644 index 0000000000000..7aa4aed45ccf4 --- /dev/null +++ b/torch/fx/passes/backends/cudagraphs.py @@ -0,0 +1,53 @@ +import torch +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.utils._pytree import tree_map + +import operator + +class CudaGraphsSupport(OperatorSupport): + # TODO: why is submodules passed here + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op not in CALLABLE_NODE_OPS: + return False + + if node.target in [torch.ops.aten.embedding_dense_backward.default]: + return False + + if node.target in [operator.getitem]: + return True + + found_not_cuda = False + + def find_not_cuda(t): + nonlocal found_not_cuda + if isinstance(t, torch.Tensor) and t.device.type != 'cuda': + found_not_cuda = True + + for n in node.all_input_nodes: + tree_map(find_not_cuda, n.meta['fake_result']) + + tree_map(find_not_cuda, node.meta['fake_result']) + + # NB: factory function is accounted for because the result would be + # cpu or cuda + + return not found_not_cuda + +def partition_cudagraphs(gm, inputs): + """ + Partition an FX graph into sub-GraphModules that can be validly run under + CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations + must involve CUDA tensors only/ + """ + + FakeTensorProp(gm).propagate(*inputs) + supported_ops = CudaGraphsSupport() + # TODO: single node partition may be wrong due to the pessimization + # from copying in and out the data. Check in benchmarks, perhaps + partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) + partitions = partitioner.propose_partitions() + fused_graph = partitioner.fuse_partitions(partitions) + return fused_graph diff --git a/torch/fx/passes/backends/nvfuser.py b/torch/fx/passes/backends/nvfuser.py new file mode 100644 index 0000000000000..8c8fcbb6e1c36 --- /dev/null +++ b/torch/fx/passes/backends/nvfuser.py @@ -0,0 +1,286 @@ +from typing import Dict + +import torch +from torch.nn import Module +from torch._ops import OpOverload + +from torch.fx import GraphModule +from torch.fx.node import Node, _get_qualified_name +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch._prims.executor import execute +from torch.fx.experimental.proxy_tensor import DecompositionInterpreter +from torch._decomp import decomposition_table + +import typing as t + +import logging + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + +def aten_to_dtype(self, dtype: torch.dtype, **kwargs): + if len(kwargs) > 0 or not dtype: + raise RuntimeError("No support for other to.dtype() formats other than to.dtype(self, dtype)") + return torch._prims.convert_element_type(self, dtype) + +# decomposition_table currently contains both aten2aten and aten2prim decomposition +# this is a hack to seperate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering +aten2aten_decomp = {} +aten2prim_decomp = {} + +for op, decomp_fn in decomposition_table.items(): + if "torch._refs" in decomp_fn.__module__: + aten2prim_decomp[op] = decomp_fn + else: + aten2aten_decomp[op] = decomp_fn + +aten2aten_decomp_skips = { + "aten.native_layer_norm_backward.default", + "aten.embedding_dense_backward.default", # This is hurting nvfuser's perf + "aten.addmm.default" +} + +for op, decomp_fn in decomposition_table.items(): + if "torch._refs" in decomp_fn.__module__: + aten2prim_decomp[op] = decomp_fn + else: + if str(op) not in aten2aten_decomp_skips: + aten2aten_decomp[op] = decomp_fn + + +aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype + + +class NvFuserOperatorSupport(OperatorSupport): + """ + Operator support for nvFuser backend. + + Currently, partitioning is based on FX ATen graph. The fused subgraph will latter be decomposed into prims. + To determine if an ATen ops is supported by nvFuser, we shall check the prim ops used in its ref decomposition. + Only if all the prim ops in the ref has a nvfuser_impl, we say this Aten op is suppported by nvFuser. + + Note: When adding a rule, please add it to the corresponding section and follow the + alphabetical order. + """ + + def __init__(self): + + # TODO: current list copied from torch/csrc/jit/codegen/cuda/parser.cpp is incorrect, + # as that file is solely for TorchScript and doesn't represent the actual status + # whether operation would be runnable by primTorch+nvFuser. + # We will iterate on this list to reflect the the reality. + support_dict = { + # =============================================================== + # call_function aten + # =============================================================== + # Following supported aten ops is copied from torch/csrc/jit/codegen/cuda/parser.cpp + # TODO: might need to update according to supported input types + "torch.ops.aten.add": None, + "torch.ops.aten.sub": None, + # "torch.ops.aten.rsub": None, # rsub decomp is supported at aten2aten level + "torch.ops.aten.div": None, + "torch.ops.aten.atan2": None, + "torch.ops.aten.mul": None, + "torch.ops.aten.max": None, + "torch.ops.aten.min": None, + "torch.ops.aten.pow": None, + "torch.ops.aten.remainder": None, + "torch.ops.aten.fmod": None, + "torch.ops.aten.bitwise_and": None, + "torch.ops.aten.__and__": None, + "torch.ops.aten.bitwise_or": None, + "torch.ops.aten.__or__": None, + "torch.ops.aten.bitwise_xor": None, + "torch.ops.aten.__xor__": None, + "torch.ops.aten.bitwise_left_shift": None, + "torch.ops.aten.__lshift__": None, + "torch.ops.aten.bitwise_right_shift": None, + "torch.ops.aten.__rshift__": None, + "torch.ops.aten.eq": None, + "torch.ops.aten.ne": None, + "torch.ops.aten.ge": None, + "torch.ops.aten.gt": None, + "torch.ops.aten.le": None, + "torch.ops.aten.lt": None, + "torch.ops.aten.abs": None, + "torch.ops.aten.bitwise_not": None, + "torch.ops.aten.ceil": None, + "torch.ops.aten.floor": None, + "torch.ops.aten.frac": None, + "torch.ops.aten.neg": None, + "torch.ops.aten.relu": None, + "torch.ops.aten.round": None, + "torch.ops.aten.silu": None, + "torch.ops.aten.trunc": None, + "torch.ops.aten.log": None, + "torch.ops.aten.log10": None, + "torch.ops.aten.log1p": None, + "torch.ops.aten.log2": None, + "torch.ops.aten.lgamma": None, + "torch.ops.aten.exp": None, + "torch.ops.aten.expm1": None, + "torch.ops.aten.erf": None, + "torch.ops.aten.erfc": None, + "torch.ops.aten.cos": None, + "torch.ops.aten.acos": None, + "torch.ops.aten.cosh": None, + "torch.ops.aten.sin": None, + "torch.ops.aten.asin": None, + "torch.ops.aten.sinh": None, + "torch.ops.aten.tan": None, + "torch.ops.aten.atan": None, + "torch.ops.aten.tanh": None, + "torch.ops.aten.atanh": None, + "torch.ops.aten.sqrt": None, + "torch.ops.aten.rsqrt": None, + "torch.ops.aten.reciprocal": None, + "torch.ops.aten.sigmoid": None, + "torch.ops.aten.isfinite": None, + "torch.ops.aten.isinf": None, + "torch.ops.aten.isnan": None, + "torch.ops.aten.isneginf": None, + "torch.ops.aten.isposinf": None, + "torch.ops.aten.isreal": None, + # "torch.ops.aten.rand_like": None, # causing Node empty_like_default does not support nvfuser + "torch.ops.aten.softplus": None, + "torch.ops.aten.threshold": None, + # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.new_zero op + # "torch.ops.aten.threshold_backward": None, + "torch.ops.aten.clamp": None, + # "torch.ops.aten.clone": None, + # Failing with where(): incompatible function arguments: \ + # [aten->prim decomp, aten2aten is using unsupported aten.div + # "torch.ops.aten.native_layer_norm_backward": None, + "torch.ops.aten.softmax.int": None, + "torch.ops.aten.log_softmax.int": None, + # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.amax + # "torch.ops.aten._softmax": None, + "torch.ops.aten._log_softmax_backward_data": None, + # "torch.ops.aten._softmax_backward_data": None, # Node _softmax_backward_data_default does not support nvfuser + # "torch.ops.aten.var.dim": None, # missing refs + "torch.ops.aten.std.dim": None, + "torch.ops.aten.sum": None, + # "torch.ops.aten.mean.dim": None, # missing refs + "torch.ops.aten._grad_sum_to_size": None, + "torch.ops.aten.sum_to_size": None, + "torch.ops.aten._autocast_to_reduced_precision": None, + "torch.ops.aten._autocast_to_full_precision": None, + # "torch.ops.aten.to.dtype": None, # causing segfault + # "torch.ops.aten.type_as": None, # missing refs + "torch.ops.aten.linear": None, + "torch.ops.aten.gelu": None, + # "torch.ops.aten.gelu_backward": None, # gelu_backward is handled at aten2aten decomp + # "torch.ops.aten.hardtanh": None, # has functional ref, using unsupported aten.clamp + "torch.ops.aten.leaky_relu": None, + "torch.ops.aten.square": None, + # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.conj_physical + "torch.ops.aten.tanh_backward": None, + # "torch.ops.aten.amax": None, # missing prim decomp + # "torch.ops.aten.amin": None, # missing prim decomp + # "torch.ops.aten.reshape": None, + # "torch.ops.aten.view": None, # missing prim decomp + "torch.ops.aten.flatten.using_ints": None, + + # =============================================================== + # call_function builtins and operator + # =============================================================== + "getattr": None, + "_operator.getitem": None, + } + + super().__init__(support_dict) + + def is_node_supported( + self, submodules: t.Mapping[str, Module], node: Node + ) -> bool: + + # nvFuser FX subgraph should be purely functional + if node.op not in CALLABLE_NODE_OPS: + return False + + # ops in supported_dict doesn't have overload name + # use overloadpacket's qualified_name for OpOverload + if isinstance(node.target, OpOverload): + target = _get_qualified_name(node.target.overloadpacket) + if target in self._support_dict: + return True + + return super().is_node_supported(submodules, node) + + +class NvFuserBackend: + def __init__(self): + self.supported_ops = NvFuserOperatorSupport() + + # TODO: this is a naive implementation of cache without proper guard + self.partitioner_cache: Dict[GraphModule, GraphModule] = {} + + # TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs + self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {} + + def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs): + # `graph_module` is an Aten-Fx graph + # "lowering to prims" and "trace execution" are grouped into this function, as they are both input dependent + + if graph_module in self.prim_decomp_cache: + logging.debug("prim_decomp_cache hit!") + prim_module = self.prim_decomp_cache[graph_module] + else: + prim_graph = torch.fx.Graph() + DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs) + prim_module = torch.fx.GraphModule(graph_module, prim_graph) + self.prim_decomp_cache[graph_module] = prim_module + + logging.debug("Lower to prims graph: ", prim_module.code) + + # invokes trace executor for running the prim graph + return execute(prim_module, *args, executor="nvfuser") + + def compile(self, graph_module: GraphModule) -> GraphModule: + # entry function for nvFuser backend + logging.debug("Compiling graph_module: ", graph_module.code) + + # FX graph based partitioning based on nvfuser supported ops + if graph_module in self.partitioner_cache: + logging.debug("partitioner_cache hit!") + fused_graph_module = self.partitioner_cache[graph_module] + else: + partitioner = CapabilityBasedPartitioner( + graph_module, self.supported_ops, allows_single_node_partition=False) + fused_graph_module = partitioner.partition_and_fuse() + + self.partitioner_cache[graph_module] = fused_graph_module + + # Overriding fused_module's __call__() function with lower_to_prims_and_execute() + for node in fused_graph_module.graph.nodes: + # TODO: use a better way to identify fused submodule + if node.op == "call_module" and "fused_" in node.name: + fused_module = getattr(fused_graph_module, node.name) + fused_module._wrapped_call = self.lower_to_prims_and_execute + + return fused_graph_module + + def __call__(self, graph_module: GraphModule, _) -> GraphModule: + # wrap self.compile as __call__ function to fit the interface for AOTAutograd's fw_compiler + return self.compile(graph_module) diff --git a/torch/fx/passes/dialect/__init__.py b/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/fx/passes/dialect/common/__init__.py b/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000..fdfdc791569b5 --- /dev/null +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,112 @@ +from typing import Dict, Tuple, Any + +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + +from torch.fx import GraphModule, Graph +from torch.fx import Node + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]) # noqa: E501 + +inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_]) # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph + hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph + token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new grpah without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs memebrs to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substition happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py new file mode 100644 index 0000000000000..427ff41761f43 --- /dev/null +++ b/torch/fx/passes/fake_tensor_prop.py @@ -0,0 +1,30 @@ +import torch.fx +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode + +__all__ = ['FakeTensorProp'] + +@compatibility(is_backward_compatible=False) +class FakeTensorProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record a fake tensor representing + the metadata for the node. Unlike ShapeProp, (1) this propagation + is cheap--it does the propagation with meta tensors which do not actually + store data, and (2) the fake tensors have much more fine grained information, + e.g., they have accurate alias information that can be consulted by looking + at the storages. + + Args: + module (GraphModule): The module to be executed + """ + + def run_node(self, n: Node): + result = super().run_node(n) + n.meta['fake_result'] = result + return result + + def propagate(self, *args): + with FakeTensorMode.push() as mode: + fake_args = [mode.from_tensor(a) for a in args] + return super().run(*fake_args) diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 045f019b587cb..42416066904d8 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -9,6 +9,7 @@ from torch.fx._compatibility import compatibility from itertools import chain +__all__ = ['FxGraphDrawer'] try: import pydot HAS_PYDOT = True @@ -65,12 +66,13 @@ def __init__( graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, skip_node_names_in_args: bool = True, ): self._name = name self._dot_graphs = { name: self._to_dot( - graph_module, name, ignore_getattr, skip_node_names_in_args + graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args ) } @@ -87,6 +89,7 @@ def __init__( leaf_node, f"{name}_{node.target}", ignore_getattr, + ignore_parameters_and_buffers, skip_node_names_in_args, ) @@ -258,10 +261,13 @@ def _to_dot( graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool, + ignore_parameters_and_buffers: bool, skip_node_names_in_args: bool, ) -> pydot.Dot: """ - Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph + Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. + If ignore_parameters_and_buffers is True, the parameters and buffers + created with the module will not be added as nodes and edges. """ dot_graph = pydot.Dot(name, rankdir="TB") @@ -296,7 +302,7 @@ def get_module_params_or_buffers(): if node.op == "call_module": leaf_module = self._get_leaf_node(graph_module, node) - if not isinstance(leaf_module, torch.fx.GraphModule): + if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): get_module_params_or_buffers() for node in graph_module.graph.nodes: diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index b6ad4066f2cfe..f6e53f0e969a1 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -1,20 +1,18 @@ -from typing import Any, Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional import torch from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import ( - _get_qualified_name, - Argument, - map_aggregate, map_arg, Node, Target, ) -from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes from torch.fx.passes.shape_prop import ShapeProp +__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', + 'get_size_of_node'] @compatibility(is_backward_compatible=False) def replace_target_nodes_with( @@ -110,360 +108,3 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: total_size = size_per_elem_bytes * total_num_of_elems output_size = size_per_elem_bytes * output_elem return size_bytes(output_size, total_size) - - -@compatibility(is_backward_compatible=False) -def serialize_shape(shape: torch.Size) -> str: - return str(list(shape)) - - -@compatibility(is_backward_compatible=False) -def serialize_stride(stride: Tuple[int]) -> str: - return str(list(stride)) - - -@compatibility(is_backward_compatible=False) -def serialize_tensor_quantization( - tensor: torch.Tensor, weights: Dict, pcq_prefix: str -) -> Tuple[Dict, Dict]: - """ - Args: - tensor: The tensor from which we try to extract quantization information. - weights: A dict that contains mapping from name to a tensor value. - pcq_prefix: A string that we would use later on as prefix for per channel quantization information. This - usually would be the key that we use to store info of `tensor`. - - Returns: - scheme: Dict that stores the quantization information of `tensor`. - per_channel_dict: Dict that stores the information of per_channel_scales and - per_channel_zero_points of `tensor`. This Will be empty if `tensor` is not - per channel quantized. - - `tensor` is per tensor quantized: - scheme: { - "qscheme": str(tensor.qscheme()), - "q_scale": tensor.q_scale(), - "q_zero_point": tensor.q_zero_point(), - } - - `tensor` is per channel quantized: - scheme: { - "qscheme": str(tensor.qscheme()), - "q_per_channel_scales": {pcq_prefix}_per_channel_scales, - "q_per_channel_zero_points": {pcq_prefix}_per_channel_zero_points, - "q_per_channel_axis": tensor.q_per_channel_axis() - } - per_channel_dict: { - {pcq_prefix}_per_channel_scales: { - "dtype": dtype, - "shape": shape, - "is_quantized": is_quantized, - "stride": stride, - } - {pcq_prefix}_per_channel_zero_points: { - "dtype": dtype, - "shape": shape, - "is_quantized": is_quantized, - "stride": stride, - } - } - weights would be updated with { - {pcq_prefix}_per_channel_scales: tensor.q_per_channel_scales().float() - {pcq_prefix}_per_channel_zero_points: tensor.q_per_channel_zero_points().int() - } - """ - scheme: Dict[str, Any] = {} - per_channel_dict: Dict[str, Dict] = {} - - if not tensor.is_quantized: - return scheme, per_channel_dict - - scheme["qscheme"] = str(tensor.qscheme()) - - # For per tensor scheme, we stores scale and zero_point. - if tensor.qscheme() in {torch.per_tensor_affine, torch.per_tensor_symmetric}: - scheme["q_scale"] = tensor.q_scale() - scheme["q_zero_point"] = tensor.q_zero_point() - - # For per channel scheme, per_channel_scales and per_channel_zero_points are tensors. - # We store their tensor value into `weights` and store the name into `scheme`. - if tensor.qscheme() in { - torch.per_channel_affine, - torch.per_channel_affine_float_qparams, - torch.per_channel_symmetric, - }: - # per_channel_scales is float64. Here we save it as float32. - weights[ - f"{pcq_prefix}_per_channel_scales" - ] = tensor.q_per_channel_scales().float() - scheme["q_per_channel_scales"] = f"{pcq_prefix}_per_channel_scales" - per_channel_dict.update( - serialize_weight( - weights[f"{pcq_prefix}_per_channel_scales"], - weights, - f"{pcq_prefix}_per_channel_scales", - ) - ) - - # per_channel_zero_point is int64. Here we save it as int32. - weights[ - f"{pcq_prefix}_per_channel_zero_points" - ] = tensor.q_per_channel_zero_points().int() - scheme["q_per_channel_zero_points"] = f"{pcq_prefix}_per_channel_zero_points" - per_channel_dict.update( - serialize_weight( - weights[f"{pcq_prefix}_per_channel_zero_points"], - weights, - f"{pcq_prefix}_per_channel_zero_points", - ) - ) - - scheme["q_per_channel_axis"] = tensor.q_per_channel_axis() - return scheme, per_channel_dict - - -@compatibility(is_backward_compatible=False) -def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict: - weight_dict: Dict[str, Dict] = {name: {}} - weight_dict[name]["dtype"] = str(tensor.dtype) - weight_dict[name]["shape"] = serialize_shape(tensor.shape) - weight_dict[name]["requires_grad"] = str(tensor.requires_grad) - weight_dict[name]["is_quantized"] = tensor.is_quantized - weight_dict[name]["stride"] = serialize_stride(tensor.stride()) - - if tensor.is_quantized: - quantization_info, per_channel_dict = serialize_tensor_quantization( - tensor, weights, name - ) - weight_dict[name].update(quantization_info) - weight_dict.update(per_channel_dict) - - return weight_dict - - -@compatibility(is_backward_compatible=False) -def serialize_leaf_module( - node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str -) -> Dict: - parameters: Dict[str, Any] = {} - - for p_name, p_value in node.attrs_for_lowering.items(): # type: ignore[attr-defined] - if isinstance(p_value, torch.Tensor): - weights_metadata.update( - serialize_weight(p_value, weights, f"{name_prefix}.{p_name}") - ) - weights[f"{name_prefix}.{p_name}"] = p_value - else: - parameters[p_name] = str(p_value) - - return parameters - - -def _update_weight_fused_dtypes(weight, name, node): - """ - For quantized embedding tables we need to update the shape/type, so we check if the - users of this get_attr node is a quantized EB and this is the weight for the EB, and - update the dtype accordingly. - """ - if len(node.users) == 0: - return - user = list(node.users)[0] - if user.op != "call_function": - return - user_target = _get_qualified_name(user.target) - if ( - user_target.endswith("acc_ops.embedding_bag_byte_rowwise_offsets") - and node == user.kwargs["weight"] - ): - weight[name]["dtype"] = "acc.uint8fused" - elif ( - user_target.endswith("acc_ops.embedding_bag_4bit_rowwise_offsets") - and node == user.kwargs["weight"] - ): - weight[name]["dtype"] = "acc.uint4fused" - - -@compatibility(is_backward_compatible=False) -def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict: - """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON. - It also adds all weights the provided weights dictionary by qualified_name. - Dictionary Schema: - MODULE - { - modules: {module_name: MODULE], - nodes: [NODE], - weights {qualified_name: WEIGHT}, - } - NODE - { - shape: [], - stride: [], - dtype: dtype, - is_quantized: bool, - target: target, - op_code: op_code, - name: name, - args: [], - kwargs: {} - } - WEIGHT - { - dtype: dtype, - is_quantized: bool, - shape: [], - QUANTIZATION, - } - QUANTIZATION - { - qscheme: qscheme, - q_scale: float, - q_zero_point: float, - q_per_channel_scales, [], - q_per_channel_zero_points: [], - q_per_channel_axis, int - } - """ - serialized_dict: Dict[str, Any] = {} - serialized_dict["modules"] = {} - serialized_dict["weights"] = {} - serialized_dict["nodes"] = [] - submodules = dict(fx_module.named_modules()) - prefix = f"{name_prefix}." if name_prefix else "" - - def get_node_info(node): - tensor_meta = get_tensor_meta(node) - node_rep = { - "shape": serialize_shape(tensor_meta.shape), - "dtype": str(tensor_meta.dtype), - "requires_grad": str(tensor_meta.requires_grad), - "stride": serialize_stride(tensor_meta.stride), - "is_quantized": tensor_meta.is_quantized, - } - - if tensor_meta.is_quantized: - node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"]) - - if tensor_meta.qparams["qscheme"] in { - torch.per_tensor_affine, - torch.per_tensor_symmetric, - }: - node_rep["q_scale"] = tensor_meta.qparams["scale"] - node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"] - - # Add all extra lowering_info that was provided in node.meta. - lowering_info = node.meta.get("lowering_info") - if lowering_info is not None: - overlapping_keys = node_rep.keys() & lowering_info.keys() - assert ( - len(overlapping_keys) == 0 - ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}" - node_rep.update(lowering_info) - - return node_rep - - # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules - # that cannot currently be symbolically traced into, e.g. batch norm. - lift_lowering_attrs_to_nodes(fx_module) - for node in fx_module.graph.nodes: - node_rep: Dict[str, Any] = {} - # Get shape/type info, currently not needed for call_module node - # whose target is a GraphModule and output node. - if ( - not ( - node.op == "call_module" - and isinstance(submodules[node.target], GraphModule) - ) - and node.op != "output" - ): - node_rep.update(get_node_info(node)) - - # Recurse down into any submodules we are calling. - if node.op == "call_module": - if isinstance(submodules[node.target], GraphModule): - serialized_module = serialize_module( - getattr(fx_module, node.target), weights, node.target - ) - serialized_dict["modules"][node.target] = serialized_module - else: - node_rep["parameters"] = serialize_leaf_module( - node, - serialized_dict["weights"], - weights, - prefix + node.target, - ) - - if node.op == "call_function": - node_rep["target"] = _get_qualified_name(node.target) - else: - node_rep["target"] = str(node.target) - - # Make sure we capture all constants. - if node.op == "get_attr": - # If we are targeting a parent constant we update the target. - if node.target.startswith("parent."): - qualname = node.target[len("parent.") :] - node.name = qualname - node_rep["target"] = qualname - else: - qualname = prefix + node.target - # Find the actual target parameter/buffer from the fx_module. - submod_path, _, target_name = node.target.rpartition(".") - submod: Optional[torch.nn.Module] = ( - fx_module.get_submodule(submod_path) if submod_path else fx_module - ) - assert submod is not None, f"submod {submod_path} not found" - target = getattr(submod, target_name, None) - assert target is not None, f"{target_name} not an attr of {submod_path}" - # Check that the target is a tensor, and that we haven't added it already from a leaf module. - if isinstance(target, torch.Tensor) and qualname not in weights: - weight = serialize_weight(target, weights, qualname) - _update_weight_fused_dtypes(weight, qualname, node) - serialized_dict["weights"].update(weight) - weights[qualname] = target - elif node.op == "placeholder": - ph_type = node.meta.get("ph_type", "") - assert ( - ph_type == "" or ph_type == "input_ph" or ph_type == "output_ph" - ), "When present, placeholder type must be 'input_ph' or 'ouput_ph'" - if ph_type == "input_ph": - node_rep["ph_type"] = "input_ph" - elif ph_type == "output_ph": - node_rep["ph_type"] = "output_ph" - - node_rep["op_code"] = node.op - node_rep["name"] = node.name - - def get_user_info(user_node: Argument) -> Any: - return {"is_node": True, "name": str(user_node)} - - def get_arg_info(arg: Argument) -> Any: - if isinstance(arg, torch.fx.Node): - return {"is_node": True, "name": str(arg)} - elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)): - return str(arg) - else: - return arg - - def get_output_arg_info(arg: Node) -> Dict[str, Any]: - node_rep: Dict[str, Any] = get_arg_info(arg) - node_rep.update(get_node_info(arg)) - return node_rep - - if node.op == "output": - node_rep["args"] = map_arg( - node.args, - get_output_arg_info, - ) - - # If there're multiple outputs then node_rep["args"][0] will be a tuple or - # list. In this case we want to unpack the tuple or list. - if isinstance(node_rep["args"][0], (tuple, list)): - node_rep["args"] = node_rep["args"][0] - else: - node_rep["args"] = map_aggregate(node.args, get_arg_info) - - node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info) - node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info) - serialized_dict["nodes"] += [node_rep] - - return serialized_dict diff --git a/torch/fx/passes/infra/__init__.py b/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000..657b6a93014f4 --- /dev/null +++ b/torch/fx/passes/infra/__init__.py @@ -0,0 +1,2 @@ + +from . import pass_manager diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000..18a665b88ede0 --- /dev/null +++ b/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,227 @@ +from typing import Dict, List, Set, Iterable, Optional + +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions +from torch.fx.passes.tools_common import NodeList + +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, _get_qualified_name +from torch.fx.passes.operator_support import OperatorSupportBase + +from collections import defaultdict +import logging +import itertools + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + +class Partition: + def __init__(self, id: int = None, nodes: Iterable[Node] = None): + self.id = id + self.nodes: Set[Node] = set(nodes) if nodes is not None else set() + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node): + self.nodes.add(node) + + def remove_node(self, node: Node): + self.nodes.remove(node) + + def size(self): + return len(self.nodes) + +class CapabilityBasedPartitioner: + + def __init__(self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + + # map of node to it's upstream dependency nodes + # if A is found in dependency_map[B], then B depends on A (or a is an upstream depedency of b) + self.dependency_map = self.__build_dependency_map() + + def __build_dependency_map(self) -> Dict[Node, Set[Node]]: + dependency_map = defaultdict(set) + + # assumptions: nodes in graph are sorted in topological order + for node in self.graph_module.graph.nodes: + for input_node in node.all_input_nodes: + # add input_node and input_node's upstream dependency + dependency_map[node].add(input_node) + dependency_map[node].update(dependency_map[input_node]) + + return dependency_map + + def __node_depends_on(self, a: Node, b: Node) -> int: + # Returns + # 1 if b depends on a (,or equivalently a is an upstream depedency of b) + # -1 if a depends on b (,or equivalently b is an upstream depedency of a) + # 0 if a and b doesn't have dependency between each other + + if a in self.dependency_map[b]: + return 1 + elif b in self.dependency_map[a]: + return -1 + else: + return 0 + + def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int: + # Returns + # 1 if b depends on a (,or equivalently a is an upstream depedency of b) + # -1 if a depends on b (,or equivalently b is an upstream depedency of a) + # 0 if a and b doesn't have dependency between each other + + # TODO: build a cache here to speedup the query + + for node_a in partition_a.nodes: + for node_b in partition_b.nodes: + dependency = self.__node_depends_on(node_a, node_b) + if dependency != 0: + return dependency + return 0 + + def __get_supported_nodes(self) -> NodeList: + logging.debug("Collecting supported nodes...") + supported_nodes = [] + for node in self.graph_module.graph.nodes: + if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node): + supported_nodes.append(node) + return supported_nodes + + def propose_partitions(self) -> List[Partition]: + candidates: NodeList = self.__get_supported_nodes() + + # assumptions: nodes in candidate list is sorted in topological order + assignment: Dict[Node, int] = {} # maping from node to partition_id + partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition + new_partition_id = itertools.count() + + def assign(node: Node, id: Optional[int] = None): + # If id is None, remove the node from original assigment + + # node has been assigned before, clean up and re-assign + if node in assignment: + original_id = assignment[node] + del assignment[node] + partitions_by_id[original_id].remove_node(node) + if partitions_by_id[original_id].size() == 0: + del partitions_by_id[original_id] + + if id is not None: + assignment[node] = id + if id not in partitions_by_id: + partitions_by_id[id] = Partition(id=id, nodes=[node]) + else: + partitions_by_id[id].add_node(node) + + logging.debug("Proposing partitions...") + + # visit candidates in reversed topological order + for node in reversed(candidates): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + user_partitions: Dict[Partition, None] = {} + for user_node in node.users: + if user_node in assignment: + id = assignment[user_node] + user_partitions[partitions_by_id[id]] = None + else: + user_partitions[Partition(nodes=[user_node])] = None + + # Filter out all the partitions that has dependency on other users + # TODO: find a better way to do this, rather than pair-wise comparision + user_partitions_list = list(user_partitions.keys()) + for i in range(len(user_partitions_list)): + for j in range(i + 1, len(user_partitions_list)): + pi = user_partitions_list[i] + pj = user_partitions_list[j] + dependency = self.__partition_depends_on(pi, pj) + if dependency == 1 and pj in user_partitions: + del user_partitions[pj] + elif dependency == -1 and pi in user_partitions: + del user_partitions[pi] + + # We use the following rules for partition assignment: + # 1. If none of the candidates has been assigned to a partition, create a new partition + # 2. If there is one partition candidate, assign to the partition + # 3. If there are more than one partition candidates, assign current node to the first partition and + # merge the other partitions with first partition, since user_partitions doesn't have depedency between + # each other. + + assigned_candidate_partition_ids = [partition.id for partition in user_partitions if partition.id is not None] + + if len(assigned_candidate_partition_ids) == 0: + # create a new partition + assign(node, next(new_partition_id)) + elif len(assigned_candidate_partition_ids) == 1: + id = assigned_candidate_partition_ids[0] + assign(node, id) + else: + # users are assigned to more than one partition, since user_partitions doesn't have + # dependency on each other, they can be fused into a single partition + id = assigned_candidate_partition_ids[0] + assign(node, id) + + reassignment: Dict[Node, int] = {} + for other_id in assigned_candidate_partition_ids[1:]: + for other_node in partitions_by_id[other_id].nodes: + reassignment[other_node] = id + for other_node in reassignment: + assign(other_node, id) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: Dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if user.op != "call_function" or \ + _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node, None) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user, None) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id + for node, id in nodes_reassignment.items(): + assign(node, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + partitions_to_remove: List[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function" and \ + _get_qualified_name(node.target) not in non_compute_ops: # type: ignore[arg-type] + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logging.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logging.debug(f"partition #{id}", [node.name for node in partition.nodes]) + + return list(partitions_by_id.values()) + + def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: + logging.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] + return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) + + def partition_and_fuse(self) -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions) + return fused_gm diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000..cb194a56c68e1 --- /dev/null +++ b/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,78 @@ +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility + + +__all__ = ['PassResult', 'PassBase'] + +@compatibility(is_backward_compatible=False) +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __init__(self) -> None: + pass + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + pass + + def requires(self, graph_module: GraphModule) -> None: + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + pass + + def ensures(self, graph_module: GraphModule) -> None: + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + pass diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000..2295c3001d068 --- /dev/null +++ b/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,296 @@ +import inspect +from queue import Queue +from functools import wraps +from typing import Callable, Dict, List + +import torch.nn as nn +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult + +__all__ = ['inplace_wrapper', 'pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] + +@compatibility(is_backward_compatible=False) +def inplace_wrapper(fn: Callable) -> Callable: + """ + Convenience wrapper for passes which modify an object inplace. This + wrapper makes them return a PassResult containing the modified object and + True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + fn(gm) + return PassResult(gm, True) + + return wrapped_fn + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + gm = fn(gm) + return PassResult(gm, True) + + return wrapped_fn + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: List[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + +def _topological_sort_passes( + passes: List[Callable], constraints: List[Callable] +) -> List[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Contruct a graph mapping nodes to a list of their users + graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} + indegree_map: Dict[Callable, int] = {p : 0 for p in passes} + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: Dict[Callable, bool] = {p : False for p in passes} + sorted_passes: List[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = f"Circular dependency detected within the following passes: {cycle_passes}" + raise RuntimeError(error) + + return sorted_passes + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [ + this_before_that_pass_constraint(pass_a, pass_b) + ] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + if a == that and b == this: + return False + return True + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: List[Callable[[nn.Module], PassResult]] = [] + constraints: List[Callable[[Callable, Callable], bool]] = [] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + if passes: + self.passes = passes + if constraints: + self.constraints = constraints + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError("PassManager check function should only take in one variable, a module") + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for fn in self.passes: + res = fn(module) + + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 1d9891588b7ce..26df7714e0663 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -19,6 +19,8 @@ ) from dataclasses import dataclass +__all__ = ['FxNetMinimizerBadModuleError', 'FxNetMinimizerRunFuncError', 'FxNetMinimizerResultMismatchError'] + _LOGGER = logging.getLogger(__name__) diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py index 5715ff00762af..733338c67a386 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -8,6 +8,8 @@ from .tools_common import get_node_target, CALLABLE_NODE_OPS +__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports'] + # fx.Node.target typename, as returned by `get_node_target()` TargetTypeName = str diff --git a/torch/fx/passes/param_fetch.py b/torch/fx/passes/param_fetch.py index 41d7599776d96..5979e29fcc6b2 100644 --- a/torch/fx/passes/param_fetch.py +++ b/torch/fx/passes/param_fetch.py @@ -5,6 +5,7 @@ from torch.fx._compatibility import compatibility +__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] # Matching method matches the attribute name of current version to the attribute name of `target_version` @compatibility(is_backward_compatible=False) diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py new file mode 100644 index 0000000000000..39f0ad1569b92 --- /dev/null +++ b/torch/fx/passes/reinplace.py @@ -0,0 +1,512 @@ +import torch +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.utils._pytree import tree_map +from torch.multiprocessing.reductions import StorageWeakRef + +import _operator +from enum import Enum +import itertools +from typing import Set, Dict +from collections import defaultdict + +__all__ = ['reinplace'] + +class _ViewType(Enum): + NonView = 0 + SingleOutputView = 1 + MultiOutputView = 2 + +def _is_view_op(tgt): + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + +def _get_view_type(tgt) -> _ViewType: + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + if first_arg.alias_info is not None and not first_arg.alias_info.is_write: + # check if op is a multi-output view + if '*' in first_arg.alias_info.after_set: + return _ViewType.MultiOutputView + else: + return _ViewType.SingleOutputView + return _ViewType.NonView + + +# Stores a bunch of metadata related to functionalization each node. +# Relevant metadata: +# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) +# The fake tensor output from running the current node +# n.meta['view_of']: Node +# If the current node n is a view of some base tensor, the 'view_of' field tells us which +# view node was used to generate the current node (a view tensor). +# This information actually makes `fake_result` redundant, but we can use `fake_result` +# to sanity check that our aliasing information is correct. +@compatibility(is_backward_compatible=False) +class _FunctionalizationMetadataProp(torch.fx.Interpreter): + + def run_node(self, node: Node): + self.node_counter += 1 + result = super().run_node(node) + node.meta['fake_result'] = result + node.meta['node_idx'] = self.node_counter + + # (1) Update metadata with the list of nodes that are used by this node + # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. + # We don't want to treat it as "being used as an input". + node_args = node.args + if node.target is torch.ops.aten.copy_.default: + node_args = node_args[1:] + + # (2) Update metadata to track aliasing information about view tensor nodes. + if node.op == 'call_function': + view_type = _get_view_type(node.target) + if view_type == _ViewType.SingleOutputView: + assert isinstance(node.args[0], Node) + node.meta['view_of'] = node.args[0] + elif view_type == _ViewType.MultiOutputView: + self.multi_output_view_nodes[node] = node.args[0] + + # Check if we returned a multi-output view, + # and we're now grabbing the individual views from the output. + # + # For multi-output views, we want to map each output view to the base, + # but this mapping involves two separate nodes in FX IR. + # e.g. "a, b = x_1.split(...)" becomes: + # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) + # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) + # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) + # And we'd like to set: + # getitem1.meta['view_of'] = x_1 + elif node.target is _operator.getitem: + list_arg = node.args[0] + maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) + if maybe_base_of_view is not None: + # Note: we could also track indexing info here for multi-output views. + # I don't think this metadata is strictly needed for de-functionalization. + assert isinstance(maybe_base_of_view, Node) + node.meta['view_of'] = maybe_base_of_view + + if 'view_of' in node.meta: + # We're linking the current node with its first argument as views. + # Assert here that this is actually the case, and their storages are the same. + assert isinstance(node.meta['fake_result'], FakeTensor) + assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) + view_storage = StorageWeakRef(node.meta['fake_result'].storage()) + base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage()) + assert view_storage == base_storage + return result + + + + def propagate(self, *args): + self.multi_output_view_nodes = {} + self.node_counter = -1 + with FakeTensorMode.push() as mode: + fake_args = [mode.from_tensor(a) for a in args] + return super().run(*fake_args) + +def _schemas_match(functional_schema, inplace_schema): + names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name + arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( + a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) + # for the inplace op, its first argument should be mutable + assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write + # and its remaining arguments shouldn't be. + assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) + return names_match and arg_types_match + +# TODO: this should be beefed up to be able to properly re-inplace with: +# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) +# - out= ops (e.g. angle -> angle.out) +# TODO: we should also figure this info out using torchgen. +def _maybe_get_inplace_op(op): + # __module__ seems broken; it returns torch._ops.aten which doesn't exist + if not isinstance(op, torch._ops.OpOverload): + return None + # Some view ops have inplace variants (as_strided_, etc), + # but we do NOT want the reinplacing pass to directly add these into the program. + # (they'll require extra special handling, aren't aren't really useful for perf anyway) + if _is_view_op(op): + return None + op_namespace = op.__module__.split(".")[-1] + op_base_name = op.overloadpacket.__name__ + maybe_namespace_module = getattr(torch.ops, op_namespace) + maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) + if maybe_inplace_op is None: + return None + + inplace_overloads = [ + getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() + ] + inplace_overloads_with_matching_schemas = [ + f + for f in inplace_overloads + if _schemas_match(op._schema, f._schema) + ] + # This is for sanity: if foo() and foo_() are both operators, + # we expect them to have compatible schemas. + # (This is asserted by codegen for ATen, but might not be true + # for other arbitrary operators). + assert len(inplace_overloads_with_matching_schemas) == 1 + inplace_op = inplace_overloads_with_matching_schemas[0] + return inplace_op + +_VIEW_INVERSE_MAP = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} + +# This function, given a set of set of (aliased) tensor nodes, +# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index +# in the node ordering. +def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): + def _add_if_tensor(x, set_): + if isinstance(x, FakeTensor): + set_.add(StorageWeakRef(x.storage())) + + nodes_used_after = set() + for t in tensor_aliases: + # get all nodes that use the current alias + usage_nodes = t.users + for n in usage_nodes: + # We only care about usages after the current node + if n.meta['node_idx'] <= op_index: + continue + # We also don't care about intermediate view ops. + # They only matter if their output is then used elsewhere + # (either in an out-of-place op, or as an output to the function). + if n in tensor_aliases: + if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: + continue + nodes_used_after.add(n) + return nodes_used_after + +# Given an op that we're trying to re-inplace, "b = foo(a)", +# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" +# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: +# If there are any aliases in the alias_set(a) that satisfy: +# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" +# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata +# as "alias" +def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: + def matching_view_metadata(a, b): + return a.size() == b.size() and \ + a.stride() == b.stride() and \ + a.storage_offset() == b.storage_offset() + + view_inverse_nodes = set() + # Go through them in node order, so we can see chains of view_scatter ops. + for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): + if n.target not in _VIEW_INVERSE_MAP: + continue + base = n.args[0] + mutated_view = n.args[1] + assert isinstance(base, Node) + assert isinstance(base.meta['fake_result'], FakeTensor) + assert isinstance(mutated_view, Node) + assert isinstance(mutated_view.meta['fake_result'], FakeTensor) + # Check that this view_inverse op actually corresponds to taking doing the inverse + # of one of our existing self_alias nodes. + original_view = _VIEW_INVERSE_MAP[n.target] + for self_alias in self_aliases: + # We're looking for some alias of the self arg, "alias", + # that was created from some op `alias = foo(base, args...)` + # such that the current _scatter op "inverts" that foo call. + # We can check that by running the original op again, and checking that the strides match. + if 'view_of' not in self_alias.meta: + continue + self_alias_base = self_alias.meta['view_of'] + try: + # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse + # of the current alias we're looking at. + view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) + expected_metadata = self_alias.meta['fake_result'] + # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. + if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ + matching_view_metadata(view_replay_metadata, expected_metadata): + view_inverse_nodes.add(n) + except Exception: + continue + + return view_inverse_nodes + + +@compatibility(is_backward_compatible=True) +def reinplace(gm, *sample_args): + """ + Given an fx.GraphModule, modifies it to perform "reinplacing", + mutating the nodes of the graph. + We look for out-of-place op call sites like `b = a.add(...)`, + and convert them to be inplace (`b = a.add_(...)`), + as long as the input to the current operator ("a") isn't re-used + anywhere later in the graph. + + This pass currently expects to operate on a **functional, ATen** graph. + This can be obtained by running `make_fx(functionalize(f))`. + + Sample inputs are needed to determine aliasing relationships of the inputs. + In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + inputs to the program. + + Given a node "b = foo(a, ...)", the algorithm for re-inplacing is as follows: + + (1) Check if foo has a mutating variant. If not, move to the next node. + + Note that we ignore view ops (we don't bother to turn `as_strided()` + into `as_strided_()`), as it complicates the algorithm and doesn't + provide meaningful speedups. + + Currently, we also only check for an inplace op, `foo_`. + Later, we should beef this up to check for out= or mutable ops. + + (2) Check if "a" is an alias of any of the program inputs. + + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. + + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like "a.copy_(...)", + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + which will later be overwritten by the copy_() call. + + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. + + (3) Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to "b = foo_(a)". + + There are a few caveats to this, explained in more detail below: + (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + (b) If "a" is a repeat argument in `foo()`, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do aten.mul_(a, a). + So we'll just ban re-inplacing in this case. + It's only a problem if "a" (or that view) is later passed + (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. + + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a) + torch.ops.aten.fill_(b, 0) + return d + + Functionalization will emit the following: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a, 0, 1) + b_updated = torch.ops.aten.fill(b, 0) + a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) + return a_updated + + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely, the expensive diagonal_scatter call, if we re-inplace the add(). + + So, for every `alias in alias_set(a)`, instead of checking + that "alias" is not used anywhere later in the graph, + we check that + EITHER: + (a) alias is not used anywhere later in the graph + OR: + (b) alias is used exactly once later on in the graph, + in the following op: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + (i) "foo_scatter" is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies: + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + (ii) "args..." are the same between the foo() and foo_scatter() calls. + + (4) Finally, after converting "b = foo(a)" into "foo_(a)", + we need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], + That maps a given tensor storage to the set of all nodes that take in that storage + as an input. + Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused + together. + + (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" + during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. + """ + _FunctionalizationMetadataProp(gm).propagate(*sample_args) + + # Useful debug printing + # def _print(x): + # if isinstance(x, FakeTensor): + # print(f'fake_result: {StorageWeakRef(x.storage()).cdata}') + + # for n in gm.graph.nodes: + # print(n.format_node()) + # if hasattr(n, 'meta'): + # print(f'node_idx: {n.meta["node_idx"]}') + # if 'fake_result' in n.meta: + # tree_map(_print, n.meta['fake_result']) + # if 'view_of' in n.meta: + # print(f'view_of: {str(n.meta["view_of"])}') + # print() + + # We need to know which nodes correspond to inputs (or their aliases) + # so we know not to re-inplace them. + # NOTE: later, we'll need to add an optimization for fully recovering performance + # on programs that mutate inputs. + input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder') + + + # We also need to know for a given node, what are all of its aliasing nodes. + storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) + for n in gm.graph.nodes: + if 'fake_result' in n.meta: + # Tree-mapping because some ops can return lists of tensors. + def _add_to_map(x): + if isinstance(x, FakeTensor): + storage_to_nodes[StorageWeakRef(x.storage())].add(n) + tree_map(_add_to_map, n.meta['fake_result']) + + # inplace-ify functional ops, subject to the constraints written below. + all_later_view_inverse_node_usages = set() + for idx, node in enumerate(gm.graph.nodes): + if node.op == 'call_function': + # Step 1: Check to see if this operator has an inplace variant. + maybe_inplace_op = _maybe_get_inplace_op(node.target) + if maybe_inplace_op is None: + continue + # This is a proxy check for ensuring that the first argument is "tensor-like" + # (This should be the case for all ops with inplace variants in ATen, + # although we technically don't have guarantees for custom ops). + assert len(node.target._schema.arguments) > 0 + assert 'Tensor' in str(node.target._schema.arguments[0].type) + + # Step 2: ensure that the op we're trying to re-inplace isn't a program input. + self_arg = node.args[0] + self_arg_name = self_arg.name + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + if self_arg_storage in input_storages: + # TODO: later, add the optimization for handling `copy_()` calls in the graph. + continue + if len([x for x in node.args if x is self_arg]) > 1: + # Step (3b) in the original description. + # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, + # so we prevent re-inplacing in this case. + continue + + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage()) + self_aliases = storage_to_nodes[self_arg_storage] + + # First, we find all later usages of any of the aliases of self_arg. + later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) + # Then, we check if any of those later usages are actually view_scatter ops + # that are safe to fully remove. + later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) + + # Step 3: Check to see if the input to the op is re-used later in the graph. + # If not (same goes for its aliases), then this op is safe to re-in place. + # This is a slightly roundabout way to check that there are no later usages of the current self argument. + # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) + can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 + if not can_reinplace: + continue + # Step 4: replace the current out-of-place op with its inplace variant. + node.target = maybe_inplace_op + # At this point, 'storage_to_nodes' will be stale. + # Now that we're inplacing `b = foo(a)`, we need to effectively + # union together the dict values for b and a's storage. + # Hmm... morally I think we also want to keep the `fake_result` metadata + # up to date here, but I'm not sure how easy it is to do. + # Maybe it's fine to wait until the end of the pass to update it. + storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) + storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) + + # Need to remember the view_scatter view nodes we found so we can remove them alter. + all_later_view_inverse_node_usages.update(later_view_inverse_node_usages) + + # Now that we've replaced b = a.foo() with a.foo_(), + # We need to replace any later usages of "b" with "a" + for old in itertools.chain([node], later_view_inverse_node_usages): + new = old.args[0] + nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] + for node_to_update in nodes_to_update: + new_args = [] + for arg_idx, a in enumerate(node_to_update.args): + if a == old: + new_args.append(new) + else: + new_args.append(a) + new_kwargs = {} + for kwarg_idx, (k, v) in enumerate(node_to_update.kwargs.items()): + if isinstance(v, Node) and v.name == old.name: + new_kwargs[k] = new + else: + new_kwargs[k] = v + node_to_update.args = tuple(new_args) + node_to_update.kwargs = new_kwargs + + old_ref = StorageWeakRef(old.meta['fake_result'].storage()) + node_ref = StorageWeakRef(node_to_update.meta['fake_result'].storage()) + if old_ref == node_ref: + # This will happen if we're updating a view op, e.g. + # e.g. replacing + # x = view(old) + # x = view(new) + # When that happens, we need to make sure to keep our + # storage mapping up to date. + new_ref = StorageWeakRef(new.meta['fake_result'].storage()) + # Technically, "old_ref" and all its aliases will remain + # in our mapping. + # That should be fine though, since we deleted "old" + # from the graph at this point. + storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) + storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) + + # Step 5: delete any _scatter nodes that we de-functionalized + # Need to take care not to delete any of these nodes until after *all* modifications + # to the graph are finished. + for to_delete in all_later_view_inverse_node_usages: + gm.graph.erase_node(to_delete) + + + gm.recompile() + return gm diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index f7feaddd207f5..9c3a036e90bf4 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -6,6 +6,7 @@ from typing import Any, Tuple, NamedTuple, Optional, Dict from torch.fx._compatibility import compatibility +__all__ = ['TensorMetadata', 'ShapeProp'] @compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 1bd5918da053b..1795de363ca6e 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -4,6 +4,8 @@ from torch.fx._compatibility import compatibility import inspect +__all__ = ['Partition', 'split_module'] + @compatibility(is_backward_compatible=True) class Partition: def __init__(self, name: str): diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index bf1514085d72a..0236d92d8b418 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -2,11 +2,12 @@ from typing import List, Optional, Dict import torch.fx -import torch.nn as nn from torch.fx.graph import map_arg from .tools_common import NodeList, NodeSet from torch.fx._compatibility import compatibility +from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule +__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags'] @compatibility(is_backward_compatible=False) def getattr_recursive(obj, name): @@ -53,19 +54,6 @@ class Component: gm: Optional[torch.fx.GraphModule] = None -@compatibility(is_backward_compatible=False) -class HolderModule(nn.Module): - """ - HolderModule is used to copy all the attributes from original module to submodules - that uses the attributes - """ - - def __init__(self, d): - super().__init__() - for k, v in d.items(): - self.add_module(k, v) - - @compatibility(is_backward_compatible=False) def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule: """ @@ -261,37 +249,7 @@ def remap_func(x): # ((output_0, output_1, ...)). comp.graph.output(outs[0] if len(outs) == 1 else outs) - # Loop through all module calls (call_module) and param fetches (get_attr) - # in this component, creating HolderModules as necessary to match the path. - # e.g. if in the original module there's a get_attr node fetches "conv.weight". - # We create a HolderModule as root -> add a HolderModule named "conv" -> - # make "weight" a attribute of "conv" HolderModule and point to conv.weight in - # the original module. - root = HolderModule({}) - for n in comp.graph.nodes: - if n.op not in ("call_module", "get_attr"): - continue - - target = n.target - assert isinstance(target, str) - target_name_parts = target.split(".") - curr = root - orig_gm = gm - - for name in target_name_parts[:-1]: - if not hasattr(curr, name): - curr.add_module(name, HolderModule({})) - - curr = getattr(curr, name) - orig_gm = getattr(orig_gm, name) - - leaf_node_name = target_name_parts[-1] - leaf_node = getattr(orig_gm, leaf_node_name) - - # Relies on custom __setattr__ magic. - setattr(curr, leaf_node_name, leaf_node) - - comp.gm = torch.fx.GraphModule(root, comp.graph) + comp.gm = lift_subgraph_as_module(gm, comp.graph) # Create a call_module node in main graph. main_node = main_g.call_module( @@ -314,6 +272,6 @@ def remap_func(x): # then we need to make sure get_attr is copied to the new graph. for x in flatten(output_node.args[0]): if x.op == "get_attr": - setattr(main_root, x.name, getattr(gm, x.target)) # type: ignore[arg-type] + setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] return torch.fx.GraphModule(main_root, main_g) diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index b1b8fab7299c3..46e5470162842 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -26,6 +26,7 @@ ) import warnings +__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] _LOGGER = logging.getLogger(__name__) @@ -58,11 +59,19 @@ def __init__(self): "we might not care about non-tensor data flow and we can set this option " "to true to disable the functionality that prevent non-tensor data flow.", ) + parser.add_argument( + "--op_lowering_disallow_list", + default="", + type=str, + help="A comma separated string which represents a disallow_list of " + "operator names." + ) args, unknown = parser.parse_known_args() self.min_acc_module_size: int = args.min_acc_module_size self.skip_fusion: bool = args.skip_fusion self.allow_non_tensor: bool = args.allow_non_tensor + self.op_lowering_disallow_list: List[str] = args.op_lowering_disallow_list.split(",") @compatibility(is_backward_compatible=False) @@ -789,6 +798,10 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph if len(subgraph.nodes) >= self.settings.min_acc_module_size: result.append(subgraph) else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) if result: result[-1].nodes.extend(subgraph.nodes) else: @@ -824,6 +837,9 @@ def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: def __call__(self) -> torch.fx.GraphModule: subgraphs = self.put_nodes_into_subgraphs() subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") self.tag(subgraphs) return self.split() diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index c9e76266fd71c..6300671a7b294 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -6,6 +6,7 @@ from torch.fx.node import _get_qualified_name from torch.fx._compatibility import compatibility +__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] TensorOrTensors = Union[torch.Tensor, Tensors] @@ -22,7 +23,7 @@ def get_acc_ops_name(k): elif k.__module__ and "acc_ops" in k.__module__: return f"acc_ops.{k.__name__}" else: - module = k.__module__ + module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f"{module if module else ''}.{k.__name__}" diff --git a/torch/fx/passes/utils/__init__.py b/torch/fx/passes/utils/__init__.py new file mode 100644 index 0000000000000..bc6b4553a356c --- /dev/null +++ b/torch/fx/passes/utils/__init__.py @@ -0,0 +1 @@ +from .common import lift_subgraph_as_module, HolderModule diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000..f14e0a428193f --- /dev/null +++ b/torch/fx/passes/utils/common.py @@ -0,0 +1,67 @@ +from torch.nn import Module + +from torch.fx.graph_module import GraphModule +from torch.fx.graph import Graph +from torch.fx._compatibility import compatibility + + +__all__ = ['HolderModule', 'lift_subgraph_as_module'] + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule: + """ + Create a GraphModule for subgraph, which copies the necessory attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. + submodule = HolderModule({}) + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py new file mode 100644 index 0000000000000..f3d5f02421690 --- /dev/null +++ b/torch/fx/passes/utils/fuser_utils.py @@ -0,0 +1,213 @@ +import copy +from queue import SimpleQueue +from typing import List, Dict, Tuple + +import torch.fx +from torch.fx.graph_module import GraphModule +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph +from torch.fx.passes.utils import lift_subgraph_as_module + +def topo_sort(nodes: NodeList) -> NodeList: + # sort nodes according to the topological order + indegree_map = {node : 0 for node in nodes} + candidates: SimpleQueue = SimpleQueue() + + for node in nodes: + for n in node.all_input_nodes: + if n in indegree_map: + indegree_map[node] += 1 + if indegree_map[node] == 0: + candidates.put(node) + + sorted_nodes: NodeList = list() + while not candidates.empty(): + node = candidates.get() + sorted_nodes.append(node) + + for n in node.users: + if n in indegree_map: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" + + return sorted_nodes + + +def validate_partition(partition: NodeList) -> bool: + # verify the partition does't form a dependency cycle in the original graph + # returns True for valid partition, False for invalid + + partition_set = set(partition) + + outputs: NodeList = list() + for node in partition_set: + for user_node in node.users: + if user_node not in partition_set: + # external user node, need to expose as an output + outputs.append(user_node) + + # perform DFS on the parition outputs + # if it reaches a node within the partition, then it found a cycle + visited: NodeSet = set() + + def dfs_find_cycle(node): + if node in partition_set: + return True # found cycle, return + + visited.add(node) + for user_node in node.users: + if user_node not in visited: + if dfs_find_cycle(user_node): + return True + return False + + for output_node in outputs: + if dfs_find_cycle(output_node): + return False + + return True + + +def fuse_as_graphmodule(gm: GraphModule, + nodes: NodeList, + module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: + + """ + Fuse nodes in graph_module into a GraphModule. + + Args: + gm (GraphModule): target graph_module + + nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted + + module_name: class name for the fused GraphModule + + Returns: + fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` + + original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` + + original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` + + """ + + # assumption: nodes are already sorted in topo order + + for node in nodes: + assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert not node._erased, f"{node} has been removed from owning graph" + assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" + + # validates partition doesn't introduce dependency circles in the graph + assert validate_partition(nodes), "Invalid partition, found dependency cycles" + + subgraph = Graph() + + node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph + + # handles inputs throught graph.node_copy's arg_transform functions + def remap_inputs(x): + if x.op == "get_attr": + # TODO: do we really need copy the get_attr node into the graph? + # do something here + pass + + if x in nodes: + # x is inside subgraph, return the copied node + # the node should have been copied aleady, as we are copying graph in the topological order + return node_map[x] + + if x not in node_to_placeholder: + # x is not in subgraph, create a new placeholder for subgraph + placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for the placeholder node + placeholder_node.meta = copy.copy(x.meta) + node_to_placeholder[x] = placeholder_node + + return node_to_placeholder[x] + + # copy nodes in topological order + for node in nodes: + new_node = subgraph.node_copy(node, remap_inputs) + node_map[node] = new_node + + # handles outputs + output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs + + for node in nodes: + for user_node in node.users: + if user_node not in nodes: + # external user node, need to expose as an output + output_mapping[node] = node_map[node] + + # outs contain nodes in the new subgraph + outs = tuple(output_mapping.values()) + + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + subgraph.output(outs[0] if len(outs) == 1 else outs) + + # lint to ensure correctness + subgraph.lint() + + fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name) + + # sub_gm's input nodes in the original module + original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) + + # sub_gm's outputs node in the original module + original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) + + return fused_gm, original_inputs, original_outputs + + +def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) + + # Create a call_module node in main graph. + module_node = gm.graph.call_module( + submodule_name, + args=orig_inputs, + kwargs=None) + + if len(orig_outputs) == 1: + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out) + return gm + +def erase_nodes(gm: GraphModule, nodes: NodeList): + + # erase original nodes in inversed topological order + for node in reversed(nodes): + gm.graph.erase_node(node) + + +def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: + for partition_id, nodes in enumerate(partitions): + sorted_nodes = topo_sort(nodes) + + submodule_name = "fused_" + str(partition_id) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) + + insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) + + erase_nodes(gm, sorted_nodes) + + # topological sort original gm with newly created sub_gm + legalize_graph(gm) + + return gm diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 83dafb48e4811..a932cc11fde2f 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -10,6 +10,8 @@ from ._compatibility import compatibility from .operator_schemas import check_for_mutable_operation +__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'Proxy', 'Attribute', 'ParameterProxy'] + @compatibility(is_backward_compatible=True) class TracerBase: graph: Graph @@ -93,7 +95,7 @@ def _find_user_frame(self): # the user code during tracing. frame = inspect.currentframe() - fx_files = ['torch/fx/proxy.py', 'torch/fx/symbolic_trace.py'] + fx_files = ['torch/fx/proxy.py', 'torch/fx/_symbolic_trace.py'] while frame: frame = frame.f_back if frame and all(not frame.f_code.co_filename.endswith(file) for file in fx_files): diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 3f07439b53afe..c651324bb142b 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -8,6 +8,8 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Set import torch +__all__ = ['Match', 'replace_pattern'] + @compatibility(is_backward_compatible=True) class Match(NamedTuple): # Node from which the match was found diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index 0840122a9b168..fd9f408c21c4f 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -28,7 +28,6 @@ def __eq__(self, other): @staticmethod def __class_getitem__(*args): - assert isinstance(args[0], tuple) return TensorType(args[0]) diff --git a/torch/hub.py b/torch/hub.py index 37cd3d099c7c1..5e1a8baba7173 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -346,7 +346,7 @@ def set_dir(d): Optionally set the Torch Hub directory used to save downloaded models & weights. Args: - d (string): path to a local folder to save downloaded models & weights. + d (str): path to a local folder to save downloaded models & weights. """ global _hub_dir _hub_dir = os.path.expanduser(d) @@ -357,7 +357,7 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None): List all callable entrypoints available in the repo specified by ``github``. Args: - github (string): a string with format "repo_owner/repo_name[:ref]" with an optional + github (str): a string with format "repo_owner/repo_name[:ref]" with an optional ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. Example: 'pytorch/vision:0.10' @@ -367,7 +367,7 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None): specified by the ``github`` argument properly belongs to the repo owner. This will make requests to the GitHub API; you can specify a non-default GitHub token by setting the ``GITHUB_TOKEN`` environment variable. Default is ``False``. - trust_repo (bool, string or None): ``"check"``, ``True``, ``False`` or ``None``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. @@ -412,18 +412,18 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No Show the docstring of entrypoint ``model``. Args: - github (string): a string with format with an optional + github (str): a string with format with an optional ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. Example: 'pytorch/vision:0.10' - model (string): a string of entrypoint name defined in repo's ``hubconf.py`` + model (str): a string of entrypoint name defined in repo's ``hubconf.py`` force_reload (bool, optional): whether to discard the existing cache and force a fresh download. Default is ``False``. skip_validation (bool, optional): if ``False``, torchhub will check that the ref specified by the ``github`` argument properly belongs to the repo owner. This will make requests to the GitHub API; you can specify a non-default GitHub token by setting the ``GITHUB_TOKEN`` environment variable. Default is ``False``. - trust_repo (bool, string or None): ``"check"``, ``True``, ``False`` or ``None``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. @@ -475,17 +475,17 @@ def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_relo path to a local directory. Args: - repo_or_dir (string): If ``source`` is 'github', + repo_or_dir (str): If ``source`` is 'github', this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified, the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. If ``source`` is 'local' then it should be a path to a local directory. - model (string): the name of a callable (entrypoint) defined in the + model (str): the name of a callable (entrypoint) defined in the repo/dir's ``hubconf.py``. *args (optional): the corresponding args for callable ``model``. - source (string, optional): 'github' or 'local'. Specifies how + source (str, optional): 'github' or 'local'. Specifies how ``repo_or_dir`` is to be interpreted. Default is 'github'. - trust_repo (bool, string or None): ``"check"``, ``True``, ``False`` or ``None``. + trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. This parameter was introduced in v1.12 and helps ensuring that users only run code from repos that they trust. @@ -546,9 +546,9 @@ def _load_local(hubconf_dir, model, *args, **kwargs): Load a model from a local directory with a ``hubconf.py``. Args: - hubconf_dir (string): path to a local directory that contains a + hubconf_dir (str): path to a local directory that contains a ``hubconf.py``. - model (string): name of an entrypoint defined in the directory's + model (str): name of an entrypoint defined in the directory's ``hubconf.py``. *args (optional): the corresponding args for callable ``model``. **kwargs (optional): the corresponding kwargs for callable ``model``. @@ -577,9 +577,9 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): r"""Download object at the given URL to a local path. Args: - url (string): URL of the object to download - dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file`` - hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. + url (str): URL of the object to download + dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` + hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. Default: None progress (bool, optional): whether or not to display a progress bar to stderr Default: True @@ -679,8 +679,8 @@ def load_state_dict_from_url( ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. Args: - url (string): URL of the object to download - model_dir (string, optional): directory in which to save the object + url (str): URL of the object to download + model_dir (str, optional): directory in which to save the object map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) progress (bool, optional): whether or not to display a progress bar to stderr. Default: True @@ -689,7 +689,7 @@ def load_state_dict_from_url( digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False - file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set. + file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. Example: >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 9d0e94542bf4a..50877b1221375 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -154,7 +154,7 @@ def script_if_tracing(fn): # for torch.jit.isinstance def isinstance(obj, target_type): """ - This function provides for conatiner type refinement in TorchScript. It can refine + This function provides for container type refinement in TorchScript. It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``, ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also refine basic types such as bools and ints that are available in TorchScript. diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index d66204901c470..15fab449eeb5f 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -26,12 +26,10 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: Args: mod (:class:`ScriptModule`): a module to be frozen - preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. - Attributes modified in preserved methods will also be preserved. - + Attributes modified in preserved methods will also be preserved. optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly - preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`. + preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`. Returns: Frozen :class:`ScriptModule`. diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py index d345f245c6b94..dd26a4e24d38f 100644 --- a/torch/jit/_fuser.py +++ b/torch/jit/_fuser.py @@ -48,8 +48,13 @@ def fuser(name): torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_nvfuser_enabled(True) + elif name == 'none': # Turn Pytorch fuser off + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(False) else: - raise Exception("unrecognized fuser option") + raise Exception(f"unrecognized fuser option (name: {name})") try: yield finally: diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 8175d14fe5dc9..5dd98c49c6afd 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -5,6 +5,7 @@ import textwrap import functools import warnings +import sys from typing import Dict, List, Set, Type import torch._jit_internal as _jit_internal @@ -25,13 +26,15 @@ "_version", "_parameters", "_buffers", - "_modules", - "_initializing", + "_non_persistent_buffers_set", "_backward_hooks", "_forward_hooks", "_forward_pre_hooks", "_state_dict_hooks", "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_modules", + "_initializing", "dump_patches", ] @@ -134,7 +137,22 @@ def infer_concrete_type_builder(nn_module, share_types=True): if isinstance(nn_module, (torch.nn.ParameterDict)): concrete_type_builder.set_parameter_dict() - class_annotations = getattr(nn_module, '__annotations__', {}) + def get_annotations(obj): + if sys.version_info < (3, 10): + return getattr(obj, '__annotations__', {}) + # In Python-3.10+ it is recommended to use inspect.get_annotations + # See https://docs.python.org/3.10/howto/annotations.html + # But also, in 3.10 annotations from base class are not inherited + # by unannotated derived one, so they must be manually extracted + annotations = inspect.get_annotations(obj) + if len(annotations) > 0: + return annotations + cls = obj if isinstance(obj, type) else type(obj) + if len(cls.__bases__) == 0: + return {} + return inspect.get_annotations(cls.__bases__[0]) + + class_annotations = get_annotations(nn_module) if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)): class_annotations = {} diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 1aad03c82d4ed..acafd9997483f 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1053,7 +1053,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None, and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. Args: - obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type, + obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, dictionary, or list to compile. example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs to annotate the arguments for a function or ``nn.Module``. diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index 89c1c40defc76..1e93f0718635b 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1003,8 +1003,8 @@ def add_bounded_compute_mapping(operator_schema: str, lower_bound_func: Callable add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view) add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand) add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused) -add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) -add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) +add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) +add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim) add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor) add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 603a4247d368d..62548ba7e2cd6 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -1020,6 +1020,11 @@ def build_ListComp(ctx, stmt): return ListComp(r, elt_expr, target_expr, iter_expr) + @staticmethod + def build_GeneratorExp(ctx, stmt): + # Convert Generator expression to ListComp + return ExprBuilder.build_ListComp(ctx, stmt) + @staticmethod def build_DictComp(ctx, stmt): r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) diff --git a/torch/lib/libshm/CMakeLists.txt b/torch/lib/libshm/CMakeLists.txt index 1022ce84c339f..6ee75c8888847 100644 --- a/torch/lib/libshm/CMakeLists.txt +++ b/torch/lib/libshm/CMakeLists.txt @@ -19,7 +19,7 @@ endif(MSVC) if(CMAKE_VERSION VERSION_LESS "3.1") set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}") else() - set(CMAKE_CXX_STANDARD 14) + set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") endif() add_library(shm SHARED core.cpp) diff --git a/torch/library.h b/torch/library.h index 94fc33a6f5d38..69175d0756622 100644 --- a/torch/library.h +++ b/torch/library.h @@ -202,16 +202,21 @@ class TORCH_API CppFunction final { CppFunction& operator=(CppFunction&&) = default; + /// \private + /// Creates a function from a type-erased boxed kernel. + static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) { + return CppFunction( + c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)), + /* cpp_signature */ c10::nullopt, // not known for boxed functions + /* schema */ nullptr); + } + /// This creates a fallthrough function. Fallthrough functions /// immediately redispatch to the next available dispatch key, /// but are implemented more efficiently than a hand written /// function done in the same way. static CppFunction makeFallthrough() { - // TODO: more user friendly API - return CppFunction( - c10::KernelFunction::makeFallthrough(), - /* cpp_signature */ c10::nullopt, // not known for fallthroughs - /* schema */ nullptr); + return makeFromBoxedKernel(c10::BoxedKernel::makeFallthrough()); } /// \private @@ -219,10 +224,7 @@ class TORCH_API CppFunction final { /// Creates a function that raises an error saying that named tensors /// are not supported when called. static CppFunction makeNamedNotSupported() { - return CppFunction( - c10::KernelFunction::makeNamedNotSupported(), - /* cpp_signature */ c10::nullopt, // not known for fallthroughs - /* schema */ nullptr); + return makeFromBoxedKernel(c10::BoxedKernel::makeNamedNotSupported()); } /// Create a function from a boxed kernel function with signature @@ -231,25 +233,19 @@ class TORCH_API CppFunction final { /// in the native C++ calling convention. Boxed functions are /// typically only used to register backend fallbacks via /// torch::Library::fallback(). - template + template static CppFunction makeFromBoxedFunction() { - // TODO: more user friendly API - return CppFunction( - c10::KernelFunction::makeFromBoxedFunction(), - /* cpp_signature */ c10::nullopt, // not known for boxed functions - /* schema */ nullptr); + return makeFromBoxedKernel( + c10::BoxedKernel::makeFromFunction()); } // Variant that takes in a boxed kernel function with a plumbed // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for // details. - template + template static CppFunction makeFromBoxedFunction() { - // TODO: more user friendly API - return CppFunction( - c10::KernelFunction::makeFromBoxedFunction(), - /* cpp_signature */ c10::nullopt, // not known for boxed functions - /* schema */ nullptr); + return makeFromBoxedKernel( + c10::BoxedKernel::makeFromFunction()); } /// Create a function from a boxed kernel functor which defines @@ -263,10 +259,8 @@ class TORCH_API CppFunction final { template static CppFunction makeFromBoxedFunctor( std::unique_ptr kernelFunctor) { - return CppFunction( - c10::KernelFunction::makeFromBoxedFunctor(std::move(kernelFunctor)), - /* cpp_signature */ c10::nullopt, // not known for boxed functions - /* schema */ nullptr); + return makeFromBoxedKernel( + c10::BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); } /// Create a function from an unboxed kernel function. diff --git a/torch/library.py b/torch/library.py index d9803850be1b5..0589286094e42 100644 --- a/torch/library.py +++ b/torch/library.py @@ -11,6 +11,9 @@ # libraries calling into kernels not intended to be called. _impls: Set[str] = set() +# prim is reserved by TorchScript interpreter +_reserved_namespaces = ['prim'] + class Library: """ A class to create libraries that can be used to register new operators or @@ -28,6 +31,10 @@ class Library: def __init__(self, ns, kind, dispatch_key=""): if kind != "IMPL" and kind != "DEF": raise ValueError("Unsupported kind: ", kind) + + if ns in _reserved_namespaces and kind == "DEF": + raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.") + frame = traceback.extract_stack(limit=3)[0] filename, lineno = frame.filename, frame.lineno self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno) @@ -39,7 +46,41 @@ def __init__(self, ns, kind, dispatch_key=""): def __repr__(self): return "Library(kind={}, ns={}, dispatch_key={})>".format(self.kind, self.ns, self.dispatch_key) + def define(self, schema, alias_analysis=""): + r'''Defines a new operator and its semantics in the ns namespace. + + Args: + schema: function schema to define a new operator. + alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be + inferred from the schema (default behavior) or not ("CONSERVATIVE"). + Returns: + name of the operator as inferred from the schema. + + Example:: + >>> my_lib = Library("foo", "DEF") + >>> my_lib.define("sum(Tensor self) -> Tensor") + ''' + # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid + # AliasAnalysis type in C++ + if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: + raise RuntimeError("Invalid alias_analysis type {}".format(alias_analysis)) + return self.m.define(schema, alias_analysis) + def impl(self, op_name, fn, dispatch_key=''): + r'''Registers the function implementation for an operator defined in the library. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + fn: function that's the operator implementation for the input dispatch key. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + + Example:: + >>> my_lib = Library("aten", "IMPL") + >>> def div_cpu(self, other): + >>> return self * (1 / other) + >>> my_lib.impl("div.Tensor", "CPU") + ''' if not callable(fn): raise TypeError("Input function is required to be a callable but found type {}".format(type(fn))) if dispatch_key == '': @@ -85,24 +126,13 @@ def impl(self, op_name, fn, dispatch_key=''): _impls.add(key) self._op_impls.add(key) - def define(self, schema, alias_analysis=""): - ''' - Takes a schema to define a new operator. - Also, optionally takes `alias_analysis` argument to indicate if the aliasing properties of the arguments - can be inferred from the schema (default behavior) or not ("CONSERVATIVE"). - - Returns the name of the operator as inferred from the schema. - ''' - # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid - # AliasAnalysis type in C++ - if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: - raise RuntimeError("Invalid alias_analysis type {}".format(alias_analysis)) - return self.m.define(schema, alias_analysis) - def __del__(self): - for key in self._op_impls: - _impls.remove(key) - del self.m + # _op_impls might not have been initialized if an error was thrown in __init__ + _op_impls_ = getattr(self, '_op_impls', None) + if _op_impls_: + for key in self._op_impls: + _impls.remove(key) + del self.m # decorator to register python functions for library ops # Note: this decorator API should remain consistent with `Library.impl` API diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index fc83d3fcdf4ba..53917eecd7d6c 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -278,6 +278,43 @@ https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem """) +solve_ex = _add_docstr(_linalg.linalg_solve_ex, r""" +linalg.solve_ex(A, B, *, left=True, check_errors=False, out=None) -> (Tensor, Tensor) + +A version of :func:`~solve` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(result, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.solve_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") + inv_ex = _add_docstr(_linalg.linalg_inv_ex, r""" linalg.inv_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor) @@ -336,16 +373,10 @@ Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. -""" + fr""" -.. note:: This function is computed using :func:`torch.linalg.lu_factor`. - {common_notes["sync_note"]} -""" + r""" - .. seealso:: :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the - absolute value (resp. modulus) of the determinant of real-valued (resp. complex) - square matrices. + absolute value of the determinant of real-valued (resp. complex) square matrices. Args: A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. @@ -372,19 +403,13 @@ For complex :attr:`A`, it returns the angle and the natural logarithm of the modulus of the determinant, that is, a logarithmic polar decomposition of the determinant. +The determinant can be recovered as `sign * exp(logabsdet)`. +When a matrix has a determinant of zero, it returns `(0, -inf)`. + Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. -""" + fr""" -.. note:: This function is computed using :func:`torch.linalg.lu_factor`. - {common_notes["sync_note"]} -""" + r""" - -.. note:: The determinant can be recovered as `sign * exp(logabsdet)`. - -.. note:: When a matrix has a determinant of zero, it returns `(0, -inf)`. - .. seealso:: :func:`torch.linalg.det` computes the determinant of square matrices. @@ -398,10 +423,10 @@ Returns: A named tuple `(sign, logabsdet)`. - `logabsdet` will always be real-valued, even when :attr:`A` is complex. - `sign` will have the same dtype as :attr:`A`. + `logabsdet` will always be real-valued, even when :attr:`A` is complex. + Examples:: >>> A = torch.randn(3, 3) @@ -411,7 +436,7 @@ [-1.6218, -0.9273, -0.0082]]) >>> torch.linalg.det(A) tensor(-0.7576) - >>> torch.linalg.logdet(A) + >>> torch.logdet(A) tensor(nan) >>> torch.linalg.slogdet(A) torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) @@ -2776,3 +2801,38 @@ [ 1, 3, 9], [ 1, 5, 25]]) """) + +vecdot = _add_docstr(_linalg.linalg_vecdot, r""" +linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor + +Computes the dot product of two batches of vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes. +It also supports broadcasting. + +Args: + x (Tensor): first batch of vectors of shape `(*, n)`. + y (Tensor): second batch of vectors of shape `(*, n)`. + +Keyword args: + dim (int): Dimension along which to compute the dot product. Default: `-1`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> v1 = torch.randn(3, 2) + >>> v2 = torch.randn(3, 2) + >>> linalg.vecdot(v1, v2) + tensor([ 0.3223, 0.2815, -0.1944]) + >>> torch.vdot(v1[0], v2[0]) + tensor(0.3223) +""") diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index b978f24e10842..403b28d6a63c6 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -37,6 +37,14 @@ def expired(self): def __del__(self): self._free_weak_ref(self.cdata) + def __hash__(self): + return self.cdata + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.cdata == other.cdata + class SharedCache(dict): """dictionary from multiprocessing handles to StorageWeakRef""" @@ -125,7 +133,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) t = torch._utils._rebuild_tensor( - torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype), + torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype), tensor_offset, tensor_size, tensor_stride) if tensor_cls == torch.nn.parameter.Parameter: @@ -290,7 +298,7 @@ def storage_from_cache(cls, key): storage_ref = shared_cache.get(key) if storage_ref is None: return None - return torch._UntypedStorage._new_with_weak_ptr(storage_ref.cdata) + return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) def rebuild_storage_fd(cls, df, size): @@ -307,15 +315,15 @@ def rebuild_storage_fd(cls, df, size): def rebuild_storage_filename(cls, manager, handle, size, dtype=None): - storage: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle) + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(cls, handle) if storage is not None: return storage._shared_decref() if dtype is None: - storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size) + storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) else: byte_size = size * torch._utils._element_size(dtype) - untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) - storage = torch._TypedStorage( + untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + storage = torch.TypedStorage( wrap_storage=untyped_storage, dtype=dtype) shared_cache[handle] = StorageWeakRef(storage) @@ -326,16 +334,16 @@ def rebuild_storage_empty(cls): return cls() def rebuild_typed_storage(storage, dtype): - return torch.storage._TypedStorage(wrap_storage=storage, dtype=dtype) + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype) -# Use for torch.storage._TypedStorage +# Use for torch.storage.TypedStorage def reduce_typed_storage(storage): return (rebuild_typed_storage, (storage._storage, storage.dtype)) def rebuild_typed_storage_child(storage, storage_type): return storage_type(wrap_storage=storage) -# Use for child classes of torch.storage._TypedStorage, like torch.FloatStorage +# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage def reduce_typed_storage_child(storage): return (rebuild_typed_storage_child, (storage._storage, type(storage))) @@ -347,7 +355,7 @@ def reduce_storage(storage): metadata = storage._share_filename_cpu_() cache_key = metadata[1] rebuild = rebuild_storage_filename - if isinstance(storage, torch._TypedStorage): + if isinstance(storage, torch.TypedStorage): metadata += (storage.dtype,) storage._shared_incref() elif storage.size() == 0: @@ -369,12 +377,12 @@ def init_reductions(): ForkingPickler.register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: - if t.__name__ == '_UntypedStorage': + if t.__name__ == 'UntypedStorage': ForkingPickler.register(t, reduce_storage) else: ForkingPickler.register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage._TypedStorage, reduce_typed_storage) + ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: ForkingPickler.register(t, reduce_tensor) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 46df2d835e09e..5b838efc75ea7 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -223,7 +223,7 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): join (bool): Perform a blocking join on all processes. daemon (bool): The spawned processes' daemon flag. If set to True, daemonic processes will be created. - start_method (string): (deprecated) this method will always use ``spawn`` + start_method (str): (deprecated) this method will always use ``spawn`` as the start method. To use a different start method use ``start_processes()``. diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index b8cabc2ecda6b..9fca305daa254 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,6 +1,10 @@ from .modules import * # noqa: F403 -from .parameter import Parameter, UninitializedParameter, UninitializedBuffer -from .parallel import DataParallel +from .parameter import ( + Parameter as Parameter, + UninitializedParameter as UninitializedParameter, + UninitializedBuffer as UninitializedBuffer, +) +from .parallel import DataParallel as DataParallel from . import init from . import functional from . import utils diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 2c668dda22e5c..ac81cfbce7125 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1651,6 +1651,14 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals :math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a learnable parameter. +.. note:: + `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D, + its size must match the number of input channels, determined by + `input.size(1)` when `input.dim() >= 2`, otherwise 1. + In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded + to the shape of `input` in a way that is not possible using normal + :ref:`broadcasting semantics`. + See :class:`~torch.nn.PReLU` for more details. """) @@ -2122,7 +2130,7 @@ def embedding( is renormalized to have norm :attr:`max_norm`. Note: this will modify :attr:`weight` in-place. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under :class:`torch.nn.Embedding` for more details regarding sparse gradients. @@ -2229,10 +2237,10 @@ def embedding_bag( Note: this will modify :attr:`weight` in-place. norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. Note: this option is not supported when ``mode="max"``. - mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. Default: ``"mean"`` sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under :class:`torch.nn.Embedding` for more details regarding sparse gradients. @@ -2584,7 +2592,7 @@ def ctc_loss( Lengths of the targets blank (int, optional): Blank label. Default :math:`0`. - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the output losses will be divided by the target lengths and then the mean over the batch is taken, ``'sum'``: the output will be @@ -2654,7 +2662,7 @@ def nll_loss( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -2721,7 +2729,7 @@ def poisson_nll_loss( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -2771,7 +2779,7 @@ def gaussian_nll_loss( in the input (heteroscedastic), or a single one (homoscedastic). full (bool, optional): include the constant term in the loss calculation. Default: ``False``. eps (float, optional): value added to var, for stability. Default: 1e-6. - reduction (string, optional): specifies the reduction to apply to the output: + reduction (str, optional): specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the output is the average of all batch member losses, ``'sum'``: the output is the sum of all batch member losses. @@ -2863,7 +2871,7 @@ def kl_div( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. ``'none'``: no reduction will be applied ``'batchmean'``: the sum of the output will be divided by the batchsize @@ -2954,7 +2962,7 @@ def cross_entropy( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -3039,7 +3047,7 @@ def binary_cross_entropy( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -3109,7 +3117,7 @@ def binary_cross_entropy_with_logits( losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -3687,7 +3695,7 @@ def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners= size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size. scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. - mode (string): algorithm used for upsampling: + mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'``. Default: ``'nearest'`` align_corners (bool, optional): Geometrically, we consider the pixels of the @@ -3772,7 +3780,7 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size. scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple, - its length has to match `input.dim()`. + its length has to match the number of spatial dimensions; `input.dim() - 2`. mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` @@ -4991,7 +4999,7 @@ def multi_head_attention_forward( - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. """ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) diff --git a/torch/nn/grad.py b/torch/nn/grad.py index 40ffc31f4b93a..0cf5fe23f5b54 100644 --- a/torch/nn/grad.py +++ b/torch/nn/grad.py @@ -2,39 +2,6 @@ import torch from .modules.utils import _single, _pair, _triple -import warnings - - -def _grad_input_padding(grad_output, input_size, stride, padding, kernel_size, dilation=None): - if dilation is None: - # For backward compatibility - warnings.warn("_grad_input_padding 'dilation' argument not provided. Default of 1 is used.") - dilation = [1] * len(stride) - - input_size = list(input_size) - k = grad_output.dim() - 2 - - if len(input_size) == k + 2: - input_size = input_size[-k:] - if len(input_size) != k: - raise ValueError("input_size must have {} elements (got {})" - .format(k + 2, len(input_size))) - - def dim_size(d): - return ((grad_output.size(d + 2) - 1) * stride[d] - 2 * padding[d] + 1 - + dilation[d] * (kernel_size[d] - 1)) - - min_sizes = [dim_size(d) for d in range(k)] - max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)] - for size, min_size, max_size in zip(input_size, min_sizes, max_sizes): - if size < min_size or size > max_size: - raise ValueError( - ("requested an input grad size of {}, but valid sizes range " - "from {} to {} (for a grad_output of {})").format( - input_size, min_sizes, max_sizes, - grad_output.size()[2:])) - - return tuple(input_size[d] - min_sizes[d] for d in range(k)) def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -62,20 +29,11 @@ def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= >>> F.grad.conv1d_input(input.shape, weight, grad_output) """ - stride = _single(stride) - padding = _single(padding) - dilation = _single(dilation) - kernel_size = [weight.shape[2]] - - if input_size is None: - raise ValueError("grad.conv1d_input requires specifying an input_size") - - grad_input_padding = _grad_input_padding(grad_output, input_size, stride, - padding, kernel_size, dilation) + input = grad_output.new_empty(1).expand(input_size) - return torch.conv_transpose1d( - grad_output, weight, None, stride, padding, grad_input_padding, groups, - dilation) + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _single(stride), _single(padding), _single(dilation), + False, [0], groups, (True, False, False))[0] def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -101,29 +59,11 @@ def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation >>> F.grad.conv1d_weight(input, weight.shape, grad_output) """ - stride = _single(stride) - padding = _single(padding) - dilation = _single(dilation) - in_channels = input.shape[1] - out_channels = grad_output.shape[1] - min_batch = input.shape[0] - - grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1) - grad_output = grad_output.contiguous().view( - grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2]) - - input = input.contiguous().view(1, input.shape[0] * input.shape[1], - input.shape[2]) + weight = grad_output.new_empty(1).expand(weight_size) - grad_weight = torch.conv1d(input, grad_output, None, dilation, padding, - stride, in_channels * min_batch) - - grad_weight = grad_weight.contiguous().view( - min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2]) - - return grad_weight.sum(dim=0).view( - in_channels // groups, out_channels, grad_weight.shape[2]).transpose( - 0, 1).narrow(2, 0, weight_size[2]) + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _single(stride), _single(padding), _single(dilation), + False, [0], groups, (False, True, False))[1] def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -151,20 +91,11 @@ def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= >>> F.grad.conv2d_input(input.shape, weight, grad_output) """ - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - kernel_size = (weight.shape[2], weight.shape[3]) - - if input_size is None: - raise ValueError("grad.conv2d_input requires specifying an input_size") + input = grad_output.new_empty(1).expand(input_size) - grad_input_padding = _grad_input_padding(grad_output, input_size, stride, - padding, kernel_size, dilation) - - return torch.conv_transpose2d( - grad_output, weight, None, stride, padding, grad_input_padding, groups, - dilation) + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _pair(stride), _pair(padding), _pair(dilation), + False, [0], groups, (True, False, False))[0] def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -190,33 +121,11 @@ def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation >>> F.grad.conv2d_weight(input, weight.shape, grad_output) """ - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - in_channels = input.shape[1] - out_channels = grad_output.shape[1] - min_batch = input.shape[0] - - grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1, - 1) - grad_output = grad_output.contiguous().view( - grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2], - grad_output.shape[3]) - - input = input.contiguous().view(1, input.shape[0] * input.shape[1], - input.shape[2], input.shape[3]) + weight = grad_output.new_empty(1).expand(weight_size) - grad_weight = torch.conv2d(input, grad_output, None, dilation, padding, - stride, in_channels * min_batch) - - grad_weight = grad_weight.contiguous().view( - min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2], - grad_weight.shape[3]) - - return grad_weight.sum(dim=0).view( - in_channels // groups, out_channels, - grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow( - 2, 0, weight_size[2]).narrow(3, 0, weight_size[3]) + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _pair(stride), _pair(padding), _pair(dilation), + False, [0], groups, (False, True, False))[1] def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -244,20 +153,11 @@ def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= >>> F.grad.conv3d_input(input.shape, weight, grad_output) """ - stride = _triple(stride) - padding = _triple(padding) - dilation = _triple(dilation) - kernel_size = (weight.shape[2], weight.shape[3], weight.shape[4]) - - if input_size is None: - raise ValueError("grad.conv3d_input requires specifying an input_size") + input = grad_output.new_empty(1).expand(input_size) - grad_input_padding = _grad_input_padding(grad_output, input_size, stride, - padding, kernel_size, dilation) - - return torch.conv_transpose3d( - grad_output, weight, None, stride, padding, grad_input_padding, groups, - dilation) + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _triple(stride), _triple(padding), _triple(dilation), + False, [0], groups, (True, False, False))[0] def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): @@ -283,31 +183,8 @@ def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation >>> F.grad.conv3d_weight(input, weight.shape, grad_output) """ - stride = _triple(stride) - padding = _triple(padding) - dilation = _triple(dilation) - in_channels = input.shape[1] - out_channels = grad_output.shape[1] - min_batch = input.shape[0] - - grad_output = grad_output.repeat(1, in_channels // groups, 1, 1, 1) - grad_output = grad_output.contiguous().view( - grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2], - grad_output.shape[3], grad_output.shape[4]) - - input = input.contiguous().view(1, input.shape[0] * input.shape[1], - input.shape[2], input.shape[3], - input.shape[4]) - - grad_weight = torch.conv3d(input, grad_output, None, dilation, padding, - stride, in_channels * min_batch) - - grad_weight = grad_weight.contiguous().view( - min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2], - grad_weight.shape[3], grad_weight.shape[4]) - - return grad_weight.sum(dim=0).view( - in_channels // groups, out_channels, grad_weight.shape[2], - grad_weight.shape[3], grad_weight.shape[4]).transpose(0, 1).narrow( - 2, 0, weight_size[2]).narrow(3, 0, weight_size[3]).narrow( - 4, 0, weight_size[4]) + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _triple(stride), _triple(padding), _triple(dilation), + False, [0], groups, (False, True, False))[1] diff --git a/torch/nn/init.py b/torch/nn/init.py index 09d233e864f3c..e9d2354497e78 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -246,7 +246,7 @@ def dirac_(tensor, groups=1): Args: tensor: a {3, 4, 5}-dimensional `torch.Tensor` - groups (optional): number of groups in the conv layer (default: 1) + groups (int, optional): number of groups in the conv layer (default: 1) Examples: >>> w = torch.empty(3, 16, 5, 5) >>> nn.init.dirac_(w) diff --git a/torch/nn/intrinsic/modules/fused.py b/torch/nn/intrinsic/modules/fused.py index b30b9a7d430c6..261142fa8fc6c 100644 --- a/torch/nn/intrinsic/modules/fused.py +++ b/torch/nn/intrinsic/modules/fused.py @@ -2,6 +2,9 @@ from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.nn.utils.parametrize import type_before_parametrizations +__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d', + 'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d', + 'LinearBn1d'] # Used for identifying intrinsic modules used in quantization class _FusedModule(torch.nn.Sequential): pass diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py index 7fde27abb9352..dc8ef829c68b7 100644 --- a/torch/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -10,6 +10,8 @@ from torch.nn.parameter import Parameter from typing import TypeVar +__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d', + 'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats'] _BN_CLASS_MAP = { 1: nn.BatchNorm1d, 2: nn.BatchNorm2d, diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index a0ee1ef875540..980036e5bd183 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -9,6 +9,10 @@ from .module import Module from .. import functional as F +__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh', + 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU', + 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink', + 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax'] class Threshold(Module): r"""Thresholds each element of the input Tensor. @@ -653,7 +657,7 @@ class GELU(Module): :math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) Args: - approximate (string, optional): the gelu approximation algorithm to use: + approximate (str, optional): the gelu approximation algorithm to use: ``'none'`` | ``'tanh'``. Default: ``'none'`` Shape: @@ -1049,7 +1053,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. .. note:: @@ -1083,10 +1087,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" - elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): - why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input" - elif not query.is_nested and key_padding_mask is not None and attn_mask is not None: - why_not_fast_path = "key_padding_mask and attn_mask were both supplied" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" if not why_not_fast_path: tensor_args = ( diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 621bcee8a1151..5f6fb08c82fef 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -11,6 +11,7 @@ from .module import Module from ..functional import log_softmax +__all__ = ['AdaptiveLogSoftmaxWithLoss'] _ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 65271ebd1b816..6d6d1d2ec3d16 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -10,6 +10,8 @@ from .lazy import LazyModuleMixin from .module import Module +__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d', + 'LazyBatchNorm3d', 'SyncBatchNorm'] class _NormBase(Module): """Common base of _InstanceNorm and _BatchNorm""" diff --git a/torch/nn/modules/channelshuffle.py b/torch/nn/modules/channelshuffle.py index 740ee6022ca2e..efaa76d2c8ab7 100644 --- a/torch/nn/modules/channelshuffle.py +++ b/torch/nn/modules/channelshuffle.py @@ -3,6 +3,7 @@ from torch import Tensor +__all__ = ['ChannelShuffle'] class ChannelShuffle(Module): r"""Divide the channels in a tensor of shape :math:`(*, C , H, W)` diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 9061669ab4ba8..9ede31035728f 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -10,6 +10,8 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union +__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict'] + T = TypeVar('T', bound=Module) @@ -115,11 +117,74 @@ def __delitem__(self, idx: Union[slice, int]) -> None: else: key = self._get_item_by_idx(self._modules.keys(), idx) delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) @_copy_to_script_wrapper def __len__(self) -> int: return len(self._modules) + def __add__(self, other) -> 'Sequential': + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError('add operator supports only objects ' + 'of Sequential class, but {} is given.'.format( + str(type(other)))) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> 'Sequential': + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError('add operator supports only objects ' + 'of Sequential class, but {} is given.'.format( + str(type(other)))) + + def __mul__(self, other: int) -> 'Sequential': + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif (other <= 0): + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> 'Sequential': + return self.__mul__(other) + + def __imul__(self, other: int) -> 'Sequential': + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif (other <= 0): + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + @_copy_to_script_wrapper def __dir__(self): keys = super(Sequential, self).__dir__() @@ -148,6 +213,26 @@ def append(self, module: Module) -> 'Sequential': self.add_module(str(len(self)), module) return self + def insert(self, index: int, module: Module) -> 'Sequential': + if not isinstance(module, Module): + raise AssertionError( + 'module should be of type: {}'.format(Module)) + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError( + 'Index out of range: {}'.format(index)) + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> 'Sequential': + for layer in sequential: + self.append(layer) + return self + class ModuleList(Module): r"""Holds submodules in a list. @@ -253,6 +338,11 @@ def append(self, module: Module) -> 'ModuleList': self.add_module(str(len(self)), module) return self + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + def extend(self, modules: Iterable[Module]) -> 'ModuleList': r"""Appends modules from a Python iterable to the end of the list. @@ -352,7 +442,7 @@ def pop(self, key: str) -> Module: r"""Remove key from the ModuleDict and return its module. Args: - key (string): key to pop from the ModuleDict + key (str): key to pop from the ModuleDict """ v = self[key] del self[key] @@ -644,7 +734,7 @@ def setdefault(self, key: str, default: Optional[Any] = None) -> Any: `default` defaults to `None`. Args: - key (string): key to set default for + key (str): key to set default for default (Any): the parameter set to the key """ @@ -662,7 +752,7 @@ def pop(self, key: str) -> Any: r"""Remove key from the ParameterDict and return its parameter. Args: - key (string): key to pop from the ParameterDict + key (str): key to pop from the ParameterDict """ v = self[key] del self[key] @@ -684,7 +774,7 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: Otherwise return default if provided, None if not. Args: - key (string): key to get from the ParameterDict + key (str): key to get from the ParameterDict default (Parameter, optional): value to return if key not present """ return self[key] if key in self else default diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 3548ed4d91a0a..d216a8c778499 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -15,6 +15,10 @@ from ..common_types import _size_1_t, _size_2_t, _size_3_t from typing import Optional, List, Tuple, Union +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', + 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', + 'LazyConvTranspose3d'] + convolution_notes = \ {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. :attr:`in_channels` and :attr:`out_channels` must both be divisible by @@ -228,7 +232,7 @@ class Conv1d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 @@ -372,7 +376,7 @@ class Conv2d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input @@ -514,7 +518,7 @@ class Conv3d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` @@ -1203,7 +1207,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 @@ -1272,7 +1276,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 @@ -1341,7 +1345,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py index 174659d3d30f0..065ad6b243972 100644 --- a/torch/nn/modules/distance.py +++ b/torch/nn/modules/distance.py @@ -3,6 +3,7 @@ from torch import Tensor +__all__ = ['PairwiseDistance', 'CosineSimilarity'] class PairwiseDistance(Module): r""" diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index 5f25aae7fa519..0b35bd546e230 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -3,6 +3,7 @@ from torch import Tensor +__all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout'] class _DropoutNd(Module): __constants__ = ['p', 'inplace'] diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index 69c788d5862e3..616b6bc690e37 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -4,6 +4,7 @@ from torch import Tensor from torch.types import _size +__all__ = ['Flatten', 'Unflatten'] class Flatten(Module): r""" diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 5c10bd21df2b0..1819469e0ad37 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -5,6 +5,7 @@ from torch import Tensor from ..common_types import _size_any_t +__all__ = ['Fold', 'Unfold'] class Fold(Module): r"""Combines an array of sliding local blocks into a large containing diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index b12714b3112f8..6d384ebb427ba 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -3,6 +3,8 @@ from .batchnorm import _LazyNormBase, _NormBase from .. import functional as F +__all__ = ['InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LazyInstanceNorm1d', + 'LazyInstanceNorm2d', 'LazyInstanceNorm3d'] class _InstanceNorm(_NormBase): def __init__( diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index dc810af2762e5..7899850ce500e 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -5,6 +5,7 @@ import torch from ..parameter import is_lazy +__all__ = ['LazyModuleMixin'] class _LazyProtocol(Protocol): """This is to avoid errors with mypy checks for diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 2d0bd65624222..b5d07843d7649 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -8,6 +8,11 @@ from torch import Tensor from typing import Callable, Optional +__all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', + 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', + 'SmoothL1Loss', 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'MultiLabelSoftMarginLoss', + 'CosineEmbeddingLoss', 'MarginRankingLoss', 'MultiMarginLoss', 'TripletMarginLoss', + 'TripletMarginWithDistanceLoss', 'CTCLoss'] class _Loss(Module): reduction: str @@ -66,7 +71,7 @@ class L1Loss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -155,7 +160,7 @@ class NLLLoss(_WeightedLoss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``None`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the weighted mean of the output is taken, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -255,7 +260,7 @@ class PoissonNLLLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -316,7 +321,7 @@ class GaussianNLLLoss(_Loss): calculation. Default: ``False``. eps (float, optional): value used to clamp ``var`` (see note below), for stability. Default: 1e-6. - reduction (string, optional): specifies the reduction to apply to the + reduction (str, optional): specifies the reduction to apply to the output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the output is the average of all batch member losses, ``'sum'``: the output is the sum of all batch member @@ -433,7 +438,7 @@ class KLDivLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per batch element instead and ignores :attr:`size_average`. Default: `True` - reduction (string, optional): Specifies the reduction to apply to the output. Default: `"mean"` + reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"` log_target (bool, optional): Specifies whether `target` is the log space. Default: `False` Shape: @@ -502,7 +507,7 @@ class MSELoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -582,7 +587,7 @@ class BCELoss(_WeightedLoss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -679,7 +684,7 @@ class BCEWithLogitsLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -753,7 +758,7 @@ class HingeEmbeddingLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -808,7 +813,7 @@ class MultiLabelMarginLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -898,7 +903,7 @@ class SmoothL1Loss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -960,7 +965,7 @@ class HuberLoss(_Loss): between the two losses. Args: - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` @@ -1000,7 +1005,7 @@ class SoftMarginLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1107,7 +1112,7 @@ class probabilities only when a single class label per minibatch item is too res losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the weighted mean of the output is taken, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1194,7 +1199,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1244,7 +1249,7 @@ class CosineEmbeddingLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1292,7 +1297,7 @@ class MarginRankingLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1364,7 +1369,7 @@ class MultiMarginLoss(_WeightedLoss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1444,7 +1449,7 @@ class TripletMarginLoss(_Loss): losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` @@ -1523,7 +1528,7 @@ class TripletMarginWithDistanceLoss(_Loss): loss for input tensors using the :math:`l_p` distance as the distance function. Args: - distance_function (callable, optional): A nonnegative, real-valued function that + distance_function (Callable, optional): A nonnegative, real-valued function that quantifies the closeness of two tensors. If not specified, `nn.PairwiseDistance` will be used. Default: ``None`` margin (float, optional): A nonnegative margin representing the minimum difference @@ -1535,7 +1540,7 @@ class TripletMarginWithDistanceLoss(_Loss): V. Balntas, E. Riba et al. If True, and if the positive example is closer to the negative example than the anchor is, swaps the positive example and the anchor in the loss computation. Default: ``False``. - reduction (string, optional): Specifies the (optional) reduction to apply to the output: + reduction (str, optional): Specifies the (optional) reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` @@ -1612,7 +1617,7 @@ class CTCLoss(_Loss): Args: blank (int, optional): blank label. Default :math:`0`. - reduction (string, optional): Specifies the reduction to apply to the output: + reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the output losses will be divided by the target lengths and then the mean over the batch is taken. Default: ``'mean'`` diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 9cb2c3c01db4b..d0bfcd9e175a3 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2,6 +2,7 @@ import itertools import warnings import functools +import weakref import torch from ..parameter import Parameter @@ -11,6 +12,9 @@ from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List from ...utils.hooks import RemovableHandle +__all__ = ['register_module_forward_pre_hook', 'register_module_forward_hook', 'register_module_backward_hook', + 'register_module_full_backward_hook', 'Module'] + _grad_t = Union[Tuple[Tensor, ...], Tensor] # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be @@ -37,6 +41,41 @@ def _addindent(s_, numSpaces): s = first + '\n' + s return s +class _WrappedHook: + def __init__(self, hook: Callable, module: Optional["Module"] = None): + self.hook: Callable = hook + functools.update_wrapper(self, hook) + + self.with_module: bool = False + + if module is not None: + self.module: weakref.ReferenceType["Module"] = weakref.ref(module) + self.with_module = True + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.with_module: + module = self.module() + if module is None: + raise RuntimeError("You are trying to call the hook of a dead Module!") + return self.hook(module, *args, **kwargs) + return self.hook(*args, **kwargs) + + def __getstate__(self) -> Dict: + result = {"hook": self.hook, "with_module": self.with_module} + if self.with_module: + result["module"] = self.module() + + return result + + def __setstate__(self, state: Dict): + self.hook = state["hook"] + self.with_module = state["with_module"] + + if self.with_module: + if state["module"] is None: + raise RuntimeError("You are trying to revive the hook of a dead Module!") + self.module = weakref.ref(state["module"]) + r"""This tracks hooks common to all modules that are executed before/after calling forward and backward. This is global state used for debugging/profiling @@ -249,7 +288,17 @@ def forward(self, x): the change.""" training: bool + _parameters: Dict[str, Optional[Parameter]] + _buffers: Dict[str, Optional[Tensor]] + _non_persistent_buffers_set: Set[str] + _backward_hooks: Dict[int, Callable] _is_full_backward_hook: Optional[bool] + _forward_hooks: Dict[int, Callable] + _forward_pre_hooks: Dict[int, Callable] + _state_dict_hooks: Dict[int, Callable] + _load_state_dict_pre_hooks: Dict[int, Callable] + _load_state_dict_post_hooks: Dict[int, Callable] + _modules: Dict[str, Optional['Module']] def __init__(self) -> None: """ @@ -257,18 +306,24 @@ def __init__(self) -> None: """ torch._C._log_api_usage_once("python.nn_module") - self.training = True - self._parameters: Dict[str, Optional[Parameter]] = OrderedDict() - self._buffers: Dict[str, Optional[Tensor]] = OrderedDict() - self._non_persistent_buffers_set: Set[str] = set() - self._backward_hooks: Dict[int, Callable] = OrderedDict() - self._is_full_backward_hook = None - self._forward_hooks: Dict[int, Callable] = OrderedDict() - self._forward_pre_hooks: Dict[int, Callable] = OrderedDict() - self._state_dict_hooks: Dict[int, Callable] = OrderedDict() - self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict() - self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict() - self._modules: Dict[str, Optional['Module']] = OrderedDict() + """ + Calls super().__setattr__('a', a) instead of the typical self.a = a + to avoid Module.__setattr__ overhead. Module's __setattr__ has special + handling for parameters, submodules, and buffers but simply calls into + super().__setattr__ for all other attributes. + """ + super().__setattr__('training', True) + super().__setattr__('_parameters', OrderedDict()) + super().__setattr__('_buffers', OrderedDict()) + super().__setattr__('_non_persistent_buffers_set', set()) + super().__setattr__('_backward_hooks', OrderedDict()) + super().__setattr__('_is_full_backward_hook', None) + super().__setattr__('_forward_hooks', OrderedDict()) + super().__setattr__('_forward_pre_hooks', OrderedDict()) + super().__setattr__('_state_dict_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) + super().__setattr__('_modules', OrderedDict()) forward: Callable[..., Any] = _forward_unimplemented @@ -287,7 +342,7 @@ def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool Buffers can be accessed as attributes using given names. Args: - name (string): name of the buffer. The buffer can be accessed + name (str): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor or None): buffer to be registered. If ``None``, then operations that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, @@ -332,7 +387,7 @@ def register_parameter(self, name: str, param: Optional[Parameter]) -> None: The parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed + name (str): name of the parameter. The parameter can be accessed from this module using the given name param (Parameter or None): parameter to be added to the module. If ``None``, then operations that run on parameters, such as :attr:`cuda`, @@ -374,7 +429,7 @@ def add_module(self, name: str, module: Optional['Module']) -> None: The module can be accessed as an attribute using the given name. Args: - name (string): name of the child module. The child module can be + name (str): name of the child module. The child module can be accessed from this module using the given name module (Module): child module to be added to the module. """ @@ -1167,9 +1222,7 @@ def _call_impl(self, *input, **kwargs): grad_fn = var.grad_fn if grad_fn is not None: for hook in non_full_backward_hooks: - wrapper = functools.partial(hook, self) - functools.update_wrapper(wrapper, hook) - grad_fn.register_hook(wrapper) + grad_fn.register_hook(_WrappedHook(hook, self)) self._maybe_warn_non_full_backward_hook(input, result, grad_fn) return result @@ -1253,7 +1306,7 @@ def remove_from(*dicts_or_sets): .format(torch.typename(value), name)) buffers[name] = value else: - object.__setattr__(self, name, value) + super().__setattr__(name, value) def __delattr__(self, name): if name in self._parameters: @@ -1264,7 +1317,7 @@ def __delattr__(self, name): elif name in self._modules: del self._modules[name] else: - object.__delattr__(self, name) + super().__delattr__(name) def _register_state_dict_hook(self, hook): r"""These hooks will be called with arguments: `self`, `state_dict`, @@ -1402,9 +1455,7 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False): instance to the hook as the first parameter. """ handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) - if with_module: - hook = functools.partial(hook, self) - self._load_state_dict_pre_hooks[handle.id] = hook + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) return handle def register_load_state_dict_post_hook(self, hook): @@ -1654,7 +1705,7 @@ def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[T are direct members of this module. Yields: - (string, Parameter): Tuple containing the name and parameter + (str, Parameter): Tuple containing the name and parameter Example:: @@ -1702,7 +1753,7 @@ def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tupl are direct members of this module. Yields: - (string, torch.Tensor): Tuple containing the name and buffer + (str, torch.Tensor): Tuple containing the name and buffer Example:: @@ -1731,7 +1782,7 @@ def named_children(self) -> Iterator[Tuple[str, 'Module']]: the name of the module as well as the module itself. Yields: - (string, Module): Tuple containing a name and child module + (str, Module): Tuple containing a name and child module Example:: @@ -1784,7 +1835,7 @@ def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', or not Yields: - (string, Module): Tuple of name and module + (str, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index b9c43c402c5fb..ce2b83253a07c 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -9,6 +9,7 @@ from torch import Tensor, Size from typing import Union, List, Tuple +__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm'] class LocalResponseNorm(Module): r"""Applies local response normalization over an input signal composed diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 79e658373e05a..6511b5aeb18f6 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -9,6 +9,8 @@ # TODO: grad_output size asserts in THNN +__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', + 'ReflectionPad3d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad2d'] class _ConstantPadNd(Module): __constants__ = ['padding', 'value'] diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index d17f5616c2e9d..eb5e48dd4b0e4 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -3,6 +3,7 @@ from torch import Tensor +__all__ = ['PixelShuffle', 'PixelUnshuffle'] class PixelShuffle(Module): r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 91c0476f6996a..5da8f0c41fe6a 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -8,6 +8,10 @@ from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t, _ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t) +__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', + 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d', + 'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', + 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d'] class _MaxPoolNd(Module): __constants__ = ['kernel_size', 'stride', 'padding', 'dilation', diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 2b7c569cd5c80..7ad9fa58d684f 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -11,15 +11,20 @@ from .. import init from ... import _VF +__all__ = ['RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell'] + _rnn_impls = { 'RNN_TANH': _VF.rnn_tanh, 'RNN_RELU': _VF.rnn_relu, } -def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: return tensor.index_select(dim, permutation) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") + return _apply_permutation(tensor, permutation, dim) class RNNBase(Module): __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', @@ -234,7 +239,7 @@ def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optiona def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): if permutation is None: return hx - return apply_permutation(hx, permutation) + return _apply_permutation(hx, permutation) def extra_repr(self) -> str: @@ -702,7 +707,7 @@ def permute_hidden(self, # type: ignore[override] ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx - return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) + return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload # type: ignore[override] diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index 4a3d8922cb0bd..2ab093c37f506 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -8,6 +8,7 @@ from .. import functional as F from .. import init +__all__ = ['Embedding', 'EmbeddingBag'] class Embedding(Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -27,7 +28,7 @@ class Embedding(Module): max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` is renormalized to have norm :attr:`max_norm`. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. @@ -182,14 +183,14 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, Args: embeddings (Tensor): FloatTensor containing weights for the Embedding. First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. - freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". max_norm (float, optional): See module initialization documentation. norm_type (float, optional): See module initialization documentation. Default ``2``. - scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. sparse (bool, optional): See module initialization documentation. Examples:: @@ -244,10 +245,10 @@ class EmbeddingBag(Module): max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` is renormalized to have norm :attr:`max_norm`. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``. Note: this option is not supported when ``mode="max"``. - mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights` into consideration. ``"mean"`` computes the average of the values in the bag, ``"max"`` computes the max value over each bag. @@ -409,12 +410,12 @@ def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Opti Args: embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'. - freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True`` max_norm (float, optional): See module initialization documentation. Default: ``None`` norm_type (float, optional): See module initialization documentation. Default ``2``. - scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. - mode (string, optional): See module initialization documentation. Default: ``"mean"`` + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + mode (str, optional): See module initialization documentation. Default: ``"mean"`` sparse (bool, optional): See module initialization documentation. Default: ``False``. include_last_offset (bool, optional): See module initialization documentation. Default: ``False``. padding_idx (int, optional): See module initialization documentation. Default: ``None``. diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 9cf8cc7078d30..5a0d8c7dcc6c2 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -12,6 +12,7 @@ from .linear import Linear from .normalization import LayerNorm +__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer'] class Transformer(Module): r"""A transformer model. User is able to modify the attributes as needed. The architecture @@ -203,39 +204,68 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma output = src convert_to_nested = False first_layer = self.layers[0] - if isinstance(first_layer, torch.nn.TransformerEncoderLayer): - if (not first_layer.norm_first and not first_layer.training and - first_layer.self_attn.batch_first and - first_layer.self_attn._qkv_same_embed_dim and first_layer.activation_relu_or_gelu and - first_layer.norm1.eps == first_layer.norm2.eps and - src.dim() == 3 and self.enable_nested_tensor) : - if src_key_padding_mask is not None and not output.is_nested and mask is None: - tensor_args = ( - src, - first_layer.self_attn.in_proj_weight, - first_layer.self_attn.in_proj_bias, - first_layer.self_attn.out_proj.weight, - first_layer.self_attn.out_proj.bias, - first_layer.norm1.weight, - first_layer.norm1.bias, - first_layer.norm2.weight, - first_layer.norm2.bias, - first_layer.linear1.weight, - first_layer.linear1.bias, - first_layer.linear2.weight, - first_layer.linear2.bias, - ) - if not torch.overrides.has_torch_function(tensor_args): - if not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]): - if output.is_cuda or 'cpu' in str(output.device): - convert_to_nested = True - output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not()) + src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = '' + str_first_layer = "self.layers[0]" + if not isinstance(first_layer, torch.nn.TransformerEncoderLayer): + why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer" + elif first_layer.norm_first : + why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True" + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif not first_layer.self_attn.batch_first: + why_not_sparsity_fast_path = f" {str_first_layer}.self_attn.batch_first was not True" + elif not first_layer.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True" + elif not first_layer.activation_relu_or_gelu: + why_not_sparsity_fast_path = f" {str_first_layer}.activation_relu_or_gelu was not True" + elif not (first_layer.norm1.eps == first_layer.norm2.eps) : + why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps" + elif not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif not self.enable_nested_tensor: + why_not_sparsity_fast_path = "enable_nested_tensor was not True" + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + elif (not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())): + why_not_sparsity_fast_path = "src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not (src.is_cuda or 'cpu' in str(src.device)): + why_not_sparsity_fast_path = "src is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not()) + src_key_padding_mask_for_layers = None for mod in self.layers: - if convert_to_nested: - output = mod(output, src_mask=mask) - else: - output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers) if convert_to_nested: output = output.to_padded_tensor(0.) @@ -337,7 +367,6 @@ class TransformerEncoderLayer(Module): argument ``requires_grad`` - training is disabled (using ``.eval()``) - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``) - - norm_first is ``False`` (this restriction may be loosened in the future) - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu`` - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed - if src is a `NestedTensor `_, neither ``src_mask`` @@ -388,9 +417,10 @@ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropou self.activation = activation def __setstate__(self, state): - if 'activation' not in state: - state['activation'] = F.relu super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, 'activation'): + self.activation = F.relu + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: @@ -406,14 +436,25 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, """ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf - - if (src.dim() == 3 and not self.norm_first and not self.training and - self.self_attn.batch_first and - self.self_attn._qkv_same_embed_dim and self.activation_relu_or_gelu and - self.norm1.eps == self.norm2.eps and - ((src_mask is None and src_key_padding_mask is None) - if src.is_nested - else (src_mask is None or src_key_padding_mask is None))): + why_not_sparsity_fast_path = '' + if not src.dim() == 3: + why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first : + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif not self.self_attn._qkv_same_embed_dim : + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif not (self.norm1.eps == self.norm2.eps): + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src_mask is not None: + why_not_sparsity_fast_path = "src_mask is not supported for fastpath" + elif src.is_nested and src_key_padding_mask is not None: + why_not_sparsity_fast_path = "src_key_padding_mask is not supported with NestedTensor input for fastpath" + + if not why_not_sparsity_fast_path: tensor_args = ( src, self.self_attn.in_proj_weight, @@ -429,11 +470,18 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.linear2.weight, self.linear2.bias, ) - if (not torch.overrides.has_torch_function(tensor_args) and - # We have to use a list comprehension here because TorchScript - # doesn't support generator expressions. - all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and - (not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))): + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + + if not why_not_sparsity_fast_path: return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, @@ -443,7 +491,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.self_attn.out_proj.weight, self.self_attn.out_proj.bias, self.activation_relu_or_gelu == 2, - False, # norm_first, currently not supported + self.norm_first, self.norm1.eps, self.norm1.weight, self.norm1.bias, @@ -453,8 +501,9 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.linear1.bias, self.linear2.weight, self.linear2.bias, - src_mask if src_mask is not None else src_key_padding_mask, + src_mask if src_mask is not None else src_key_padding_mask, # TODO: split into two args ) + x = src if self.norm_first: x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index d2aa2fb9a06ef..6e1702d36b53e 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -5,6 +5,7 @@ from typing import Optional from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t +__all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d'] class Upsample(Module): r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index dd0c7aae441a9..1b8ca39eab986 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -2,6 +2,7 @@ from itertools import repeat from typing import List, Dict, Any +__all__ = ['consume_prefix_in_state_dict_if_present'] def _ntuple(n, name="parse"): def parse(x): diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index cc4d18afa908d..77111aa67c036 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -13,6 +13,8 @@ _get_devices_properties ) +__all__ = ['DataParallel', 'data_parallel'] + def _check_balance(device_ids): imbalance_warn = """ There is an imbalance between your GPUs. You may want to exclude GPU {} which diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 69bbd8d705cd7..376ba0906cd18 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -39,6 +39,7 @@ from ._replicated_tensor_ddp_utils import _ddp_with_replicated_tensor_enabled from .scatter_gather import gather, is_namedtuple, scatter_kwargs # noqa: F401 +__all__ = ['DistributedDataParallel'] logger = logging.getLogger(__name__) @@ -1323,7 +1324,7 @@ def _register_buffer_comm_hook( _BufferCommHookLocation.POST_FORWARD means that the hook will run _after_ the forward pass. - hook (callable): Callable with the following signature: + hook (Callable): Callable with the following signature: ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: NOTE: To maximize performance, users can return a @@ -1363,7 +1364,7 @@ def register_comm_hook(self, state: object, hook: callable): It is locally stored by each worker and shared by all the gradient tensors on the worker. - hook (callable): Callable with the following signature: + hook (Callable): Callable with the following signature: ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: This function is called once the bucket is ready. The diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index 06ab69332e16a..80553fee046ad 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py @@ -45,16 +45,19 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): else: devices = [None] * len(modules) devices = [_get_device_index(x, True) for x in devices] + streams = [torch.cuda.current_stream(x) for x in devices] lock = threading.Lock() results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() - def _worker(i, module, input, kwargs, device=None): + def _worker(i, module, input, kwargs, device=None, stream=None): torch.set_grad_enabled(grad_enabled) if device is None: device = get_a_var(input).get_device() + if stream is None: + stream = torch.cuda.current_stream(device) try: - with torch.cuda.device(device), autocast(enabled=autocast_enabled): + with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input,) @@ -68,16 +71,16 @@ def _worker(i, module, input, kwargs, device=None): if len(modules) > 1: threads = [threading.Thread(target=_worker, - args=(i, module, input, kwargs, device)) - for i, (module, input, kwargs, device) in - enumerate(zip(modules, inputs, kwargs_tup, devices))] + args=(i, module, input, kwargs, device, stream)) + for i, (module, input, kwargs, device, stream) in + enumerate(zip(modules, inputs, kwargs_tup, devices, streams))] for thread in threads: thread.start() for thread in threads: thread.join() else: - _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0]) outputs = [] for i in range(len(inputs)): diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index e3ed2101e47ae..66e829f498ea5 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,7 +1,15 @@ import torch from ._functions import Scatter, Gather +import warnings + +__all__ = ['scatter', 'scatter_kwargs', 'gather'] def is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + warnings.warn("is_namedtuple is deprecated, please use the python checks instead") + return _is_namedtuple(obj) + +def _is_namedtuple(obj): # Check if type was created from collections.namedtuple or a typing.NamedTuple. return ( isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") @@ -17,7 +25,7 @@ def scatter(inputs, target_gpus, dim=0): def scatter_map(obj): if isinstance(obj, torch.Tensor): return Scatter.apply(target_gpus, None, dim, obj) - if is_namedtuple(obj): + if _is_namedtuple(obj): return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) @@ -68,7 +76,7 @@ def gather_map(outputs): raise ValueError('All dicts must have the same number of keys') return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) - if is_namedtuple(out): + if _is_namedtuple(out): return type(out)._make(map(gather_map, zip(*outputs))) return type(out)(map(gather_map, zip(*outputs))) diff --git a/torch/nn/qat/modules/embedding_ops.py b/torch/nn/qat/modules/embedding_ops.py index 29425598452b0..da7f33363742d 100644 --- a/torch/nn/qat/modules/embedding_ops.py +++ b/torch/nn/qat/modules/embedding_ops.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F +__all__ = ['Embedding', 'EmbeddingBag'] class Embedding(nn.Embedding): r""" diff --git a/torch/nn/quantizable/modules/activation.py b/torch/nn/quantizable/modules/activation.py index 081ccd2526f24..65ea5c2741359 100644 --- a/torch/nn/quantizable/modules/activation.py +++ b/torch/nn/quantizable/modules/activation.py @@ -1,7 +1,7 @@ import torch +import torch.jit # this is needed to avoid a circular import from torch import nn import torch.nn.functional as nnF -import torch.nn.quantized as nnq from torch import Tensor from typing import Optional, Tuple @@ -73,7 +73,8 @@ def __init__(self, embed_dim: int, num_heads: int, self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment] # Functionals - self.q_scaling_product = nnq.FloatFunctional() + self.q_scaling_product = torch.nn.quantized.FloatFunctional() + # note: importing torch.nn.quantized at top creates a circular import # Quant/Dequant self.quant_attn_output = torch.ao.quantization.QuantStub() @@ -221,30 +222,12 @@ def dequantize(self): @classmethod def from_observed(cls, other): - converted = torch.ao.quantization.convert(other, mapping=None, - inplace=False, - remove_qconfig=True, - convert_custom_config_dict=None) - # Remove the parameters for the bias_k and bias_v to quantize them - # TODO: This is a potential source of accuracy drop. - # quantized cat takes the scale and zp of the first - # element, which might lose the precision in the bias_k - # and the bias_v (which are cat'ed with k/v being first). - if converted.bias_k is not None: - bias_k = converted._parameters.pop('bias_k') - sc, zp = torch._choose_qparams_per_tensor(bias_k, - reduce_range=False) - bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) - setattr(converted, 'bias_k', bias_k) # noqa: B010 - - if converted.bias_v is not None: - bias_v = converted._parameters.pop('bias_v') - sc, zp = torch._choose_qparams_per_tensor(bias_k, - reduce_range=False) - bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) - setattr(converted, 'bias_v', bias_v) # noqa: B010 - - return converted + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + # See nn.quantized.MultiheadAttention + raise NotImplementedError("It looks like you are trying to prepare an " + "MHA module. Please, see " + "the examples on quantizable MHAs.") def forward(self, query: Tensor, @@ -299,7 +282,7 @@ def forward(self, E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length, - S is the source sequence length. If ``average_weights=False``, returns attention weights per + S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(N, num_heads, L, S)`. """ return self._forward_impl(query, key, value, key_padding_mask, diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py index 0f5be1ed407aa..68ca6a8db1b59 100644 --- a/torch/nn/quantizable/modules/rnn.py +++ b/torch/nn/quantizable/modules/rnn.py @@ -380,5 +380,8 @@ def from_float(cls, other, qconfig=None): @classmethod def from_observed(cls, other): - return torch.ao.quantization.convert(other, inplace=False, - remove_qconfig=True) + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + raise NotImplementedError("It looks like you are trying to convert a " + "non-quantizable LSTM module. Please, see " + "the examples on quantizable LSTMs.") diff --git a/torch/nn/quantized/_reference/modules/rnn.py b/torch/nn/quantized/_reference/modules/rnn.py index bb5ec8bdcc98d..5b84f0e48b9a0 100644 --- a/torch/nn/quantized/_reference/modules/rnn.py +++ b/torch/nn/quantized/_reference/modules/rnn.py @@ -64,16 +64,22 @@ def _init_weight_qparams_dict(self, weight_qparams_dict, device): assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}") if weight_qscheme is not None: - self.register_buffer( - key + "_scale", - torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) - self.register_buffer( - key + "_zero_point", - torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device)) + scale = weight_qparams["scale"] + scale_tensor = scale.clone().detach() \ + if isinstance(scale, torch.Tensor) else \ + torch.tensor(scale, dtype=torch.float, device=device) + self.register_buffer(key + "_scale", scale_tensor) + zp = weight_qparams["zero_point"] + zp_tensor = zp.clone().detach() \ + if isinstance(zp, torch.Tensor) else \ + torch.tensor(zp, dtype=torch.int, device=device) + self.register_buffer(key + "_zero_point", zp_tensor) if weight_qscheme == torch.per_channel_affine: - self.register_buffer( - key + "_axis", - torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + axis = weight_qparams["axis"] + axis_tensor = axis.clone().detach() \ + if isinstance(axis, torch.Tensor) else \ + torch.tensor(axis, dtype=torch.int, device=device) + self.register_buffer(key + "_axis", axis_tensor) else: # added for TorchScriptability, not used self.register_buffer( diff --git a/torch/nn/quantized/_reference/modules/utils.py b/torch/nn/quantized/_reference/modules/utils.py index 58d5cd608ffbd..f9cd0b7dcb21f 100644 --- a/torch/nn/quantized/_reference/modules/utils.py +++ b/torch/nn/quantized/_reference/modules/utils.py @@ -20,16 +20,22 @@ def _init_weight_qparams(self, weight_qparams, device): zero_point_dtype = weight_qparams["zero_point"].dtype if \ isinstance(weight_qparams["zero_point"], torch.Tensor) else \ torch.int - self.register_buffer( - "weight_scale", - torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) - self.register_buffer( - "weight_zero_point", - torch.tensor(weight_qparams["zero_point"], dtype=zero_point_dtype, device=device)) + w_scale = weight_qparams["scale"] + w_scale_tensor = w_scale.clone().detach() \ + if isinstance(w_scale, torch.Tensor) \ + else torch.tensor(w_scale, dtype=torch.float, device=device) + self.register_buffer("weight_scale", w_scale_tensor) + w_zp = weight_qparams["zero_point"] + w_zp_tensor = w_zp.clone().detach() \ + if isinstance(w_zp, torch.Tensor) \ + else torch.tensor(w_zp, dtype=zero_point_dtype, device=device) + self.register_buffer("weight_zero_point", w_zp_tensor) if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]: - self.register_buffer( - "weight_axis", - torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + w_axis = weight_qparams["axis"] + w_axis_tensor = w_axis.clone().detach() \ + if isinstance(w_axis, torch.Tensor) \ + else torch.tensor(w_axis, dtype=torch.int, device=device) + self.register_buffer("weight_axis", w_axis_tensor) else: # added for TorchScriptability, not used self.register_buffer( diff --git a/torch/nn/quantized/dynamic/modules/conv.py b/torch/nn/quantized/dynamic/modules/conv.py index fce4c0ffcfd32..ea31259fcebce 100644 --- a/torch/nn/quantized/dynamic/modules/conv.py +++ b/torch/nn/quantized/dynamic/modules/conv.py @@ -13,6 +13,7 @@ import torch.nn.quantized.modules as nnq import warnings +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] class Conv1d(nnq.Conv1d): r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index 5cba314747235..dff08ee0a24a5 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -8,9 +8,16 @@ from torch.nn.utils.rnn import PackedSequence from torch.nn.quantized.modules.utils import _quantize_weight -def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: +__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', + 'GRUCell'] + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: return tensor.index_select(dim, permutation) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") + return _apply_permutation(tensor, permutation, dim) + def pack_weight_bias(qweight, bias, dtype): if dtype == torch.qint8: @@ -207,7 +214,7 @@ def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optiona def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: if permutation is None: return hx - return apply_permutation(hx, permutation) + return _apply_permutation(hx, permutation) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -460,7 +467,7 @@ def permute_hidden( # type: ignore[override] ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx - return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) + return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) # "type: ignore" is required due to issue #43072 def check_forward_args( # type: ignore[override] @@ -703,7 +710,7 @@ def permute_hidden( ) -> Tensor: if permutation is None: return hx - return apply_permutation(hx, permutation) + return _apply_permutation(hx, permutation) @torch.jit.ignore def forward(self, input, hx=None): diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index 0f7aba8fd30e3..0be06c73777af 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -545,7 +545,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size. scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer. - mode (string): algorithm used for upsampling: + mode (str): algorithm used for upsampling: ``'nearest'`` | ``'bilinear'`` align_corners (bool, optional): Geometrically, we consider the pixels of the input and output as squares rather than points. diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index 8004b52cc6595..2ccfe1de9870b 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -1,7 +1,7 @@ import torch from torch.nn.modules.pooling import MaxPool2d -from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax +from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU from .dropout import Dropout from .batchnorm import BatchNorm2d, BatchNorm3d from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ @@ -10,6 +10,7 @@ from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .linear import Linear from .embedding_ops import Embedding, EmbeddingBag +from .rnn import LSTM from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional @@ -109,12 +110,15 @@ def from_float(mod): 'LayerNorm', 'LeakyReLU', 'Linear', + 'LSTM', 'MaxPool2d', + 'MultiheadAttention', 'Quantize', 'ReLU6', 'Sigmoid', 'Softmax', 'Dropout', + 'PReLU', # Wrapper modules 'FloatFunctional', 'FXFloatFunctional', diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py index cd581c34a8228..d1ce62b27823f 100644 --- a/torch/nn/quantized/modules/activation.py +++ b/torch/nn/quantized/modules/activation.py @@ -183,3 +183,95 @@ def from_float(mod): @classmethod def from_reference(cls, mod, scale, zero_point): return cls(mod.dim, float(scale), int(zero_point)) + +class MultiheadAttention(torch.nn.quantizable.MultiheadAttention): + _FLOAT_MODULE = torch.nn.quantizable.MultiheadAttention + + def _get_name(self): + return "QuantizedMultiheadAttention" + + @classmethod + def from_float(cls, other): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError("It looks like you are trying to convert a " + "non-observed MHA module. Please, see " + "the examples on quantizable MHAs.") + + @classmethod + def from_observed(cls, other): + converted = torch.ao.quantization.convert(other, mapping=None, + inplace=False, + remove_qconfig=True, + convert_custom_config_dict=None) + converted.__class__ = cls + # Remove the parameters for the bias_k and bias_v to quantize them + # TODO: This is a potential source of accuracy drop. + # quantized cat takes the scale and zp of the first + # element, which might lose the precision in the bias_k + # and the bias_v (which are cat'ed with k/v being first). + if converted.bias_k is not None: + bias_k = converted._parameters.pop('bias_k') + sc, zp = torch._choose_qparams_per_tensor(bias_k, + reduce_range=False) + bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) + setattr(converted, 'bias_k', bias_k) # noqa: B010 + + if converted.bias_v is not None: + bias_v = converted._parameters.pop('bias_v') + sc, zp = torch._choose_qparams_per_tensor(bias_k, + reduce_range=False) + bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) + setattr(converted, 'bias_v', bias_v) # noqa: B010 + + return converted + +class PReLU(torch.nn.Module): + r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + num_parameters: number of parameters: 1, or the number of channels at input. Default: 1 + """ + def __init__(self, output_scale: float, output_zero_point: int, + num_parameters: int = 1) -> None: + super().__init__() + self.num_parameters = num_parameters + self.scale = output_scale + self.zero_point = output_zero_point + w = torch.randn(num_parameters, dtype=torch.float) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8) + self.set_weight(qw) + + def set_weight(self, w: torch.Tensor) -> None: + self.weight = w + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedPReLU' + + @classmethod + def from_float(cls, mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8) + qprelu.set_weight(qweight) + return qprelu + + @classmethod + def from_reference(cls, mod, scale, zero_point): + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8) + qprelu.set_weight(qweight) + return qprelu diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 591948b667c2f..7453fe092c389 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -15,6 +15,8 @@ from torch.nn.quantized.modules.utils import _quantize_weight, WeightedQuantizedModule from torch.nn.utils import fuse_conv_bn_weights +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] + _SUPPORTED_PADDING = { 'zeros', 'reflect' diff --git a/torch/nn/quantized/modules/dropout.py b/torch/nn/quantized/modules/dropout.py index ae540dad2e95d..2885f5dc68464 100644 --- a/torch/nn/quantized/modules/dropout.py +++ b/torch/nn/quantized/modules/dropout.py @@ -1,6 +1,8 @@ import torch import torch.nn.quantized.functional +__all__ = ['Dropout'] + class Dropout(torch.nn.Dropout): r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`. And this is a placeholder to enable models where fp32 tensors diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index 7af12e9a72e28..6253706249e74 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -5,6 +5,8 @@ from torch.nn.quantized.modules.utils import hide_packed_params_repr from torch.nn.quantized.modules.utils import _quantize_weight +__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag'] + class EmbeddingPackedParams(torch.nn.Module): _version = 1 diff --git a/torch/nn/quantized/modules/functional_modules.py b/torch/nn/quantized/modules/functional_modules.py index d5090b47349ed..06d4f29d744ae 100644 --- a/torch/nn/quantized/modules/functional_modules.py +++ b/torch/nn/quantized/modules/functional_modules.py @@ -4,6 +4,8 @@ from torch import Tensor from torch._ops import ops +__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional'] + class FloatFunctional(torch.nn.Module): r"""State collector class for float operations. diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 1463497cafe08..267c1feee195a 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -9,6 +9,8 @@ from torch.nn.utils.parametrize import type_before_parametrizations from typing import Optional +__all__ = ['LinearPackedParams', 'Linear'] + class LinearPackedParams(torch.nn.Module): _version = 3 diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py index b695df32b7ca6..a24160dee1bc6 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -1,6 +1,8 @@ import torch import torch.nn.quantized.functional +__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] + class LayerNorm(torch.nn.LayerNorm): r"""This is the quantized version of :class:`~torch.nn.LayerNorm`. diff --git a/torch/nn/quantized/modules/rnn.py b/torch/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000..7e523ba830d22 --- /dev/null +++ b/torch/nn/quantized/modules/rnn.py @@ -0,0 +1,47 @@ +import torch + +class LSTM(torch.nn.quantizable.LSTM): + r"""A quantized long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples in :class:`~torch.nn.quantizable.LSTM` + + Examples:: + + >>> custom_module_config = { + ... 'float_to_observed_custom_module_class': { + ... nn.LSTM: nn.quantizable.LSTM, + ... }, + ... 'observed_to_quantized_custom_module_class': { + ... nn.quantizable.LSTM: nn.quantized.LSTM, + ... } + ... } + >>> tq.prepare(model, prepare_custom_module_class=custom_module_config) + >>> tq.convert(model, convert_custom_module_class=custom_module_config) + """ + _FLOAT_MODULE = torch.nn.quantizable.LSTM + + def _get_name(self): + return 'QuantizedLSTM' + + @classmethod + def from_float(cls, *args, **kwargs): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError("It looks like you are trying to convert a " + "non-observed LSTM module. Please, see " + "the examples on quantizable LSTMs.") + + @classmethod + def from_observed(cls, other): + assert type(other) == cls._FLOAT_MODULE + converted = torch.ao.quantization.convert(other, inplace=False, + remove_qconfig=True) + converted.__class__ = cls + return converted diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index bfcd72e591da2..df30619160053 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F -from .conv_utils import conv_backward, conv_args_and_kwargs +from .conv_utils import conv_backward, conv_args_and_kwargs, conv_picker from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads from .expanded_weights_utils import forward_helper @@ -17,6 +17,10 @@ def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs): expanded_args, expanded_kwargs = conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs) output = forward_helper(conv_fn, expanded_args, expanded_kwargs) input, weight = expanded_args + batched_dim_size = conv_picker(conv_fn, 3, 4, 5) + if input.dim() != batched_dim_size: + raise RuntimeError(f"Expanded Weights only support convolution with batched input, got {conv_fn} with an" + f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}") ctx.conv_fn = conv_fn diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index 16b9774325299..f1bb281228e9b 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -86,7 +86,7 @@ def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size padding=(0, padding[0]), stride=(1, stride[0])), lambda: F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride), - lambda: unfold3d(input, kernel_size, dilation, padding, stride) + lambda: unfold3d(input, kernel_size, padding, stride, dilation) ) input = unfold_func() diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 7914cf8dc1ee8..1bfb91c3360c9 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -26,13 +26,14 @@ def decorator(autograd_func): # # Needs to be a tensor subclass to allow reparamaterization class ExpandedWeight(torch.Tensor): - def __init__(self, orig_weight, batch_size): + def __init__(self, orig_weight, batch_size, loss_reduction): self.batch_size = batch_size self.orig_weight = orig_weight + self.loss_reduction = loss_reduction handled_functions = HANDLED_FUNCTIONS - def __new__(cls, orig_weight, _): + def __new__(cls, orig_weight, batch_size, loss_reduction): if not isinstance(orig_weight, torch.Tensor): raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}") if not orig_weight.requires_grad: diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index ca0fc7c9e35be..0fb3d4beeafb9 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from .expanded_weights_impl import ExpandedWeight @@ -52,18 +54,34 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got " f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}") + loss_reduction: Optional[str] = None + for arg in expanded_args + tuple(expanded_kwargs.values()): + if isinstance(arg, ExpandedWeight): + if loss_reduction is None: + loss_reduction = arg.loss_reduction + elif loss_reduction != arg.loss_reduction: + raise RuntimeError("Expected ExpandedWeights to all have the same loss_reduction argument but got one" + f"with {loss_reduction} and one with {arg.loss_reduction}") + unexpanded_args = tuple(arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for arg in expanded_args) unexpanded_kwargs = {name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for (name, arg) in expanded_kwargs.items()} return unexpanded_args, unexpanded_kwargs +def maybe_scale_by_batch_size(grad_sample, expanded_weight): + if expanded_weight.loss_reduction == "mean": + return grad_sample * expanded_weight.batch_size + else: + return grad_sample + def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) if isinstance(maybe_expanded_weight, ExpandedWeight): + grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight) if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: - unpacked.grad_sample = unpacked.grad_sample + per_sample_grad_fn(unpacked) + unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution else: - unpacked.grad_sample = per_sample_grad_fn(unpacked) + unpacked.grad_sample = grad_sample_contribution def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): if isinstance(maybe_expanded_weight, ExpandedWeight): diff --git a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py index f050a98836ffe..f3e68b9406602 100644 --- a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -37,8 +37,9 @@ def backward(ctx, grad_output): running_var_ = running_var.repeat(b) if running_var is not None else None input_reshaped = input.contiguous().view(new_shape) grad_output_reshaped = grad_output.contiguous().view(new_shape) - mean = torch.mean(input_reshaped.transpose(0, 1), tuple(range(1, input.dim())), False) - rstd = torch.var(input_reshaped.transpose(0, 1), tuple(range(1, input.dim())), keepdim=False, unbiased=False) + mean = torch.mean(input_reshaped, (0,) + tuple(range(2, input.dim())), False) + var = torch.var(input_reshaped, (0,) + tuple(range(2, input.dim())), keepdim=False, unbiased=False) + rstd = 1 / torch.sqrt(var + eps) # must use native batch norm since it supports all inputs. This may have used cuda or openmi during the forward but # it didn't save the metadata, so we don't know during the backward diff --git a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py index 53cb3fe032ea9..f2ead2d4c08fb 100644 --- a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -19,9 +19,9 @@ def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): output, mean, rstd = forward_helper(torch.native_layer_norm, expanded_args, expanded_kwargs) ctx.args = expanded_args - if input.requires_grad or isinstance(ExpandedWeight, expanded_kwargs['weight']): + if input.requires_grad or isinstance(expanded_kwargs['weight'], ExpandedWeight): ctx.weight = expanded_kwargs['weight'] - if input.requires_grad or isinstance(ExpandedWeight, expanded_kwargs['bias']): + if input.requires_grad or isinstance(expanded_kwargs['bias'], ExpandedWeight): ctx.bias = expanded_kwargs['bias'] ctx.eps = expanded_kwargs['eps'] ctx.mean, ctx.rstd = mean, rstd diff --git a/torch/nn/utils/_per_sample_grad.py b/torch/nn/utils/_per_sample_grad.py index 9d67cc014877f..a0dc7b00db529 100644 --- a/torch/nn/utils/_per_sample_grad.py +++ b/torch/nn/utils/_per_sample_grad.py @@ -1,29 +1,36 @@ +import functools + import torch -from torch.nn.utils._stateless import functional_call +from torch.nn.utils.stateless import functional_call from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight +from torch.utils._pytree import tree_flatten + # dependency on `functional_call` means that this can't be exposed in utils # without creating circular dependency -def call_for_per_sample_grads(module, batch_size, args, kwargs=None): +def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"): r""" - call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor - Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same - forward result. Then, when backward is invoked, the parameters of ``module`` - will have a ``grad_sample`` field populated with the per sample gradients - instead of the regular gradients + call_for_per_sample_grads(module, batch_size=None, loss_reduction="sum") + ``call_for_per_sample_grads`` returns a function that is invoked like the forward + function of ``module`` and will produce the same result. Then, when backward is invoked, + the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample + gradients instead of the regular gradients Args: module: The ``nn.Module`` to get per sample gradients with respect to. All trainable parameters will compute per sample gradients, located in a ``grad_sample`` field when ``backward`` is invoked - batch_size: The batch size of the input. Typically the input's first dimension - args: Tuple of positional args passed to ``module`` to perform the forward pass - kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None + batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have + the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually. + Default: None + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If + "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from + running mean across a batch. Must be "mean" or "sum". Default: "sum" Examples:: >>> model = nn.Linear(4, 3) >>> batched_input = torch.randn(5, 4) # batch size of 5 - >>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum() + >>> res = call_for_per_sample_grads(model)(batched_input).sum() >>> res.backward() >>> assert model.weight.shape == (3, 4) >>> assert model.weight.grad_sample.shape == (5, 3, 4) @@ -32,26 +39,65 @@ def call_for_per_sample_grads(module, batch_size, args, kwargs=None): >>> assert model.bias.grad_sample.shape == (5, 3) >>> assert model.bias.grad == None + An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be + if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all + grad_outputs by 1 / batch_size from cross batch interaction. + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean() + >>> res.backward() + Note:: Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom rewrites that wrap an `nn.Linear` module. See Opacus for an example """ - def maybe_build_expanded_weight(og_tensor): + + def maybe_build_expanded_weight(og_tensor, batch_size): if og_tensor.requires_grad: - return ExpandedWeight(og_tensor, batch_size) + return ExpandedWeight(og_tensor, batch_size, loss_reduction) else: return og_tensor + def compute_batch_size(*args, **kwargs): + args_and_kwargs = tree_flatten(args)[0] + tree_flatten(kwargs)[0] + batch_size = None + for arg in args_and_kwargs: + if not isinstance(arg, torch.Tensor): + continue + + arg_batch_size = arg.shape[0] # we assume batch size is the first dim + if batch_size is not None and batch_size != arg_batch_size: + raise RuntimeError("When computing batch size, found at least one input with batch size " + f"{batch_size} and one with batch size {arg_batch_size}. Please specify it " + "explicitly using the batch size kwarg in call_for_per_sample_grads") + batch_size = arg_batch_size + if batch_size is None: + raise RuntimeError("Unable to find a tensor in the passed args and kwargs. They may not be pytree-able " + "and so ExpandedWeights cannot compute the batch size from the inputs. Please specify " + "it explicitly") + return batch_size + + if loss_reduction not in ["sum", "mean"]: + raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}") + if not isinstance(module, torch.nn.Module): raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}") - if not isinstance(batch_size, int): - raise RuntimeError(f"Batch size passed must be an integer, got {type(batch_size).__name__}") - if batch_size < 1: + if not (batch_size is None or isinstance(batch_size, int)): + raise RuntimeError(f"Batch size passed must be None or an integer, got {type(batch_size).__name__}") + if batch_size is not None and batch_size < 1: raise RuntimeError(f"Batch size must be positive, got {batch_size}") for weight in module.parameters(): if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined] raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple " f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or " "post an issue to pytorch/pytorch to prioritize correct behavior") - params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()} - return functional_call(module, params, args, kwargs) + + @functools.wraps(module.forward) + def wrapper(*args, **kwargs): + wrapper_batch_size = batch_size + if wrapper_batch_size is None: + wrapper_batch_size = compute_batch_size(*args, **kwargs) + + params = {name: maybe_build_expanded_weight(value, wrapper_batch_size) for (name, value) in module.named_parameters()} + return functional_call(module, params, args, kwargs) + return wrapper diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 1be274281d021..74b0f0833a1e6 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -5,6 +5,7 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] +__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_'] def clip_grad_norm_( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 5ba4b5c874238..6595b4115f7d2 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -8,6 +8,7 @@ from typing import Optional +__all__ = ['orthogonal', 'spectral_norm'] def _is_orthogonal(Q, eps=None): n, k = Q.size(-2), Q.size(-1) diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 071787fa1b9cf..f48d18e7a8600 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -4,9 +4,14 @@ from torch import Tensor import collections +import copyreg +from copy import deepcopy from contextlib import contextmanager from typing import Union, Optional, Dict, Tuple, Sequence +__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations', + 'type_before_parametrizations', 'transfer_parametrizations_and_params'] + _cache_enabled = 0 _cache: Dict[Tuple[int, str], Optional[Tensor]] = {} @@ -254,6 +259,8 @@ def right_inverse(self, value: Tensor) -> None: original_i.set_(tensor) def forward(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') # Unpack the originals for the first parametrization if self.is_tensor: x = self[0](self.original) @@ -280,6 +287,21 @@ def _inject_new_class(module: Module) -> None: """ cls = module.__class__ + def default_deepcopy(self, memo): + # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. + obj = memo.get(id(self), None) + if obj is not None: + return obj + replica = self.__new__(self.__class__) + memo[id(self)] = replica + replica.__dict__ = deepcopy(self.__dict__, memo) + # Also save all slots if they exist. + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(replica, slot, deepcopy(getattr(self, slot), memo)) + return replica + def getstate(self): raise RuntimeError( "Serialization of parametrized modules is only " @@ -288,12 +310,16 @@ def getstate(self): "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" ) + dct = {"__getstate__": getstate} + # We don't allow serialization of parametrized modules but should still allow deepcopying. + # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. + if not hasattr(cls, "__deepcopy__"): + dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] + param_cls = type( f"Parametrized{cls.__name__}", (cls,), - { - "__getstate__": getstate, - }, + dct, ) module.__class__ = param_cls @@ -325,6 +351,8 @@ def get_cached_parametrization(parametrization) -> Tensor: return tensor def get_parametrized(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') parametrization = self.parametrizations[tensor_name] if _cache_enabled: if torch.jit.is_scripting(): @@ -341,6 +369,8 @@ def get_parametrized(self) -> Tensor: return parametrization() def set_original(self, value: Tensor) -> None: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') self.parametrizations[tensor_name].right_inverse(value) setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 5acd264a139ee..200269780a3d9 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -9,6 +9,8 @@ from typing import List, Tuple, Union, Iterable +__all__ = ['PackedSequence', 'invert_permutation', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', + 'unpad_sequence', 'pack_sequence', 'unpack_sequence'] PackedSequence_ = namedtuple('PackedSequence_', ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices']) diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index 2c71dd3e7ab43..447e498641f7f 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -6,6 +6,8 @@ from typing import Any, Optional, TypeVar from ..modules import Module +__all__ = ['SpectralNorm', 'SpectralNormLoadStateDictPreHook', 'SpectralNormStateDictHook', + 'spectral_norm', 'remove_spectral_norm'] class SpectralNorm: # Invariant before and after each forward call: diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index e6166b8bcff1d..add34d2d51498 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -64,11 +64,13 @@ def _reparametrize_module( _apply_func_submodules( _create_swap_params(parameters_and_buffers), module, name.split("."), name, (tensor,)) - yield - for name in parameters_and_buffers: - _apply_func_submodules( - _remove_swap, - module, name.split("."), name, ()) + try: + yield + finally: + for name in parameters_and_buffers: + _apply_func_submodules( + _remove_swap, + module, name.split("."), name, ()) def _apply_func_submodules( diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index c10a5f917a713..ab206a35be46d 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -6,6 +6,7 @@ from typing import Any, TypeVar from ..modules import Module +__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm'] class WeightNorm(object): name: str diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index a11436e8b75ac..5d6e7db3356f9 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -10,7 +10,8 @@ TrainingMode, ) -from . import ( +from . import ( # usort:skip. Keep the order instead of sorting lexicographically + _deprecation, errors, symbolic_caffe2, symbolic_helper, @@ -90,11 +91,10 @@ producer_version = _C_onnx.PRODUCER_VERSION +@_deprecation.deprecated( + since="1.12.0", removed_in="TBD", instructions="use `torch.onnx.export` instead" +) def _export(*args, **kwargs): - warnings.warn( - "`torch.onnx._export` is deprecated. Please use `torch.onnx.export` instead.", - DeprecationWarning, - ) return utils._export(*args, **kwargs) diff --git a/torch/onnx/_deprecation.py b/torch/onnx/_deprecation.py new file mode 100644 index 0000000000000..70b6e4d3374c0 --- /dev/null +++ b/torch/onnx/_deprecation.py @@ -0,0 +1,31 @@ +"""Utility for deprecating functions.""" + +import functools +import warnings + + +def deprecated(since: str, removed_in: str, instructions: str): + """Marks functions as deprecated. + + It will result in a warning when the function is called. + + Args: + since: The version when the function was first deprecated. + removed_in: The version when the function will be removed. + instructions: The action users should take. + """ + + def decorator(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + warnings.warn( + f"`{function.__module__}.{function.__name__}` is deprecated in version {since} and will be " + f"removed in version {removed_in}. Please {instructions}.", + category=FutureWarning, + stacklevel=2, + ) + return function(*args, **kwargs) + + return wrapper + + return decorator diff --git a/torch/onnx/onnx_supported_ops.py b/torch/onnx/_onnx_supported_ops.py similarity index 83% rename from torch/onnx/onnx_supported_ops.py rename to torch/onnx/_onnx_supported_ops.py index 7256b034e6c2a..39b51a6f32407 100644 --- a/torch/onnx/onnx_supported_ops.py +++ b/torch/onnx/_onnx_supported_ops.py @@ -63,18 +63,15 @@ def _all_aten_forward_schemas(): def _symbolic_argument_count(func): params = [] - sig = inspect.signature(func) + signature = inspect.signature(func) optional_params = [] - has_var = False - for name, p in sig.parameters.items(): - if p.kind.name == "VAR_POSITIONAL": - has_var = True - elif name == "_outputs" or name == "g": + for name, parameter in signature.parameters.items(): + if name in {"_outputs", "g"}: continue - elif p.default != inspect._empty: # type: ignore[attr-defined] - optional_params.append(p) + if parameter.default is parameter.empty: + optional_params.append(parameter) else: - params.append(str(p)) + params.append(str(parameter)) return params @@ -97,15 +94,13 @@ def onnx_supported_ops(): aten_schemas = _all_aten_forward_schemas() symbolic_schemas = _all_symbolics_schemas() torch_schemas = set(symbolic_schemas.values()) - supported_ops, unsupported_ops = list(), list() - onnx_supported_ops = list() + supported_ops = [] + onnx_supported = [] for schema in aten_schemas: if schema in torch_schemas: opname = schema.name[6:] # without "aten::" prefix opsets = symbolic_schemas[opname].opsets if schema not in supported_ops: supported_ops.append(symbolic_schemas[opname]) - onnx_supported_ops.append((opname, " ".join(str(o) for o in opsets))) - else: - unsupported_ops.append(schema) - return sorted(onnx_supported_ops, key=lambda x: x[0]) + onnx_supported.append((opname, " ".join(str(o) for o in opsets))) + return sorted(onnx_supported, key=lambda x: x[0]) diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py index 5f2b17b73735b..c406fac6982b0 100644 --- a/torch/onnx/symbolic_caffe2.py +++ b/torch/onnx/symbolic_caffe2.py @@ -1,9 +1,7 @@ import importlib import inspect -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 -from torch.onnx import symbolic_registry +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9, symbolic_registry def register_quantized_ops(domain: str, version: int): diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 6829c913c3bb6..f89265ef243d3 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1,9 +1,12 @@ +from __future__ import annotations + +import collections import enum import functools import inspect import sys import warnings -from typing import Optional, Set +from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union import torch import torch._C._onnx as _C_onnx @@ -153,7 +156,7 @@ def _get_const(value, desc, arg_name): return _parse_arg(value, desc) -def _unpack_list(list_value): +def _unpack_list(list_value: _C.Value) -> List[_C.Value]: list_node = list_value.node() assert list_node.kind() == "prim::ListConstruct" return list(list_node.inputs()) @@ -343,8 +346,9 @@ def wrapper(g, *args, **kwargs): def _scalar(x): """Convert a scalar tensor into a Python value.""" - assert x.numel() == 1 - return x.item() + if isinstance(x, torch.Tensor) and x.shape == (): + return x.item() + return None def _if_scalar_type_as(g: _C.Graph, self, tensor): @@ -533,23 +537,35 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False) return _slice10(g, input, axes, starts, ends, steps, dynamic_slice) +_ScalarAndTensorElementTypeGroup = collections.namedtuple( + "_ScalarAndTensorElementTypeGroup", ("tensor_element_types", "scalar_types") +) +_FPTypeGroup = _ScalarAndTensorElementTypeGroup( + (torch.float16, torch.float32, torch.float64, torch.bfloat16), + ("Float", "Double", "Half", "BFloat16"), +) +_BoolTypeGroup = _ScalarAndTensorElementTypeGroup((torch.bool,), ("Bool",)) + + +def _is_in_type_group(value, type_set): + if not value: + return False + if isinstance(value, torch.Tensor): + return value.dtype in type_set.tensor_element_types + scalar_type = value.type().scalarType() + if scalar_type is None: + warnings.warn( + "Type cannot be inferred, which might cause exported graph to produce incorrect results." + ) + return scalar_type in type_set.scalar_types + + def _is_fp(value): - if value: - if isinstance(value, torch.Tensor): - return value.dtype in ( - torch.float16, - torch.float32, - torch.float64, - torch.bfloat16, - ) - else: - type = value.type().scalarType() - if type is None: - warnings.warn( - "Type cannot be inferred, which might cause exported graph to produce incorrect results." - ) - return type in ("Float", "Double", "Half", "BFloat16") - return False + return _is_in_type_group(value, _FPTypeGroup) + + +def _is_bool(value): + return _is_in_type_group(value, _BoolTypeGroup) def _generate_wrapped_number(g, scalar): @@ -810,7 +826,7 @@ def symbolic_fn(g, input, output_size, *args): if interpolate_mode == "nearest" else "align_corners" if align_corners - else "pytorch_half_pixel" + else "half_pixel" ) if scales is None: @@ -880,7 +896,7 @@ def __interpolate_helper( if mode == "nearest" else "align_corners" if align_corners - else "pytorch_half_pixel" + else "half_pixel" ) if not _is_none(size): @@ -1129,13 +1145,17 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var): return weight, bias, running_mean, running_var -def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name): +def _avgpool_helper( + tuple_fn: Callable[[Any], Sequence[int]], + padding: Union[int, Sequence[int]], + kernel_size, + stride, + divisor_override, + name, +) -> Tuple[int, ...]: if divisor_override and divisor_override.node().kind() != "prim::Constant": - return _unimplemented(name, "divisor_override") - if not stride: - stride = kernel_size - padding = tuple(tuple_fn(padding)) - return padding + _unimplemented(name, "divisor_override") + return tuple(tuple_fn(padding)) def check_training_mode(op_train_mode: int, op_name: str) -> None: @@ -1211,7 +1231,11 @@ def _handle_reduce_dim_none(g, self, op_name): return g.op(op_name, self, keepdims_i=0) -def dequantize_helper(g, qtensor, qdtype=None): +def dequantize_helper( + g, + qtensor: _C.Value, + qdtype: Optional[torch.onnx.TensorProtoDataType] = None, +) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]: """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. Args: @@ -1252,7 +1276,13 @@ def dequantize_helper(g, qtensor, qdtype=None): ) -def quantize_helper(g, tensor, scale, zero_point, axis=None): +def quantize_helper( + g, + tensor: _C.Value, + scale: _C.Value, + zero_point: _C.Value, + axis: Optional[_C.Value] = None, +) -> _C.Value: """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. Args: @@ -1276,11 +1306,13 @@ def quantize_helper(g, tensor, scale, zero_point, axis=None): ) assert scale is not None - if scale.type().scalarType() != "Float": + if scale.type().scalarType() != "Float": # type: ignore[attr-defined] + # TODO(justinchuby): Remove type ignore after #81112 is checked in. scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) assert zero_point is not None - if zero_point.type().scalarType() not in ("Byte", "Char"): + if zero_point.type().scalarType() not in ("Byte", "Char"): # type: ignore[attr-defined] + # TODO(justinchuby): Remove type ignore after #81112 is checked in. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) output = g.op( "QuantizeLinear", diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index ee2358409e103..855e76b30e16d 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,15 +1,18 @@ import sys import warnings +from typing import Sequence import torch -from torch import _C import torch._C._onnx as _C_onnx import torch.onnx +from torch import _C # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import _patch_torch # noqa: F401 -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 +from torch.onnx import ( # noqa: F401 + _patch_torch, + symbolic_helper, + symbolic_opset9 as opset9, +) from torch.onnx._globals import GLOBALS # EDITING THIS FILE? READ THIS FIRST! @@ -20,6 +23,39 @@ # release on 04/24/19 +__all__ = [ + "avg_pool1d", + "avg_pool2d", + "avg_pool3d", + "dequantize", + "div", + "embedding_bag", + "fake_quantize_per_tensor_affine", + "flip", + "fmod", + "isfinite", + "isinf", + "max_pool1d_with_indices", + "max_pool1d", + "max_pool2d_with_indices", + "max_pool2d", + "max_pool3d_with_indices", + "max_pool3d", + "nan_to_num", + "quantize_per_tensor", + "Quantized", + "slice", + "sort", + "topk", + "upsample_bilinear2d", + "upsample_linear1d", + "upsample_nearest1d", + "upsample_nearest2d", + "upsample_nearest3d", + "upsample_trilinear3d", +] + + def div(g, self, other, *args): if len(args) == 0: return opset9.true_divide(g, self, other) @@ -145,12 +181,12 @@ def _avg_pool(name, tuple_fn): @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") def symbolic_fn( g, - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: Sequence[int], + ceil_mode: int, + count_include_pad: int, divisor_override=None, ): if not stride: @@ -624,3 +660,19 @@ def conv2d( ) return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + @staticmethod + @symbolic_helper.parse_args("v", "i", "v", "v") + def cat( + g, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, + ) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 71e4940d5b024..942fde5b141e7 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -2,17 +2,92 @@ import sys import warnings +from typing import Tuple, Union import torch -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 -from torch.onnx import symbolic_opset10 as opset10 -from torch.onnx import utils +from torch import _C +from torch.onnx import ( + symbolic_helper, + symbolic_opset10 as opset10, + symbolic_opset9 as opset9, + utils, +) from torch.onnx._globals import GLOBALS # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py +__all__ = [ + "add", + "append", + "arange", + "argsort", + "avg_pool1d", + "avg_pool2d", + "avg_pool3d", + "cat", + "chunk", + "clamp_max", + "clamp_min", + "clamp", + "constant_pad_nd", + "cumsum", + "Delete", + "embedding_bag", + "embedding_renorm", + "flatten", + "gather", + "hardtanh", + "im2col", + "index_fill", + "index", + "index_copy", + "index_put", + "insert", + "linalg_det", + "linalg_vector_norm", + "logdet", + "masked_scatter", + "masked_select", + "mm", + "narrow", + "normal", + "pad", + "pixel_shuffle", + "pop", + "Prim", + "reflection_pad", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "relu6", + "remainder", + "replication_pad", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "round", + "scatter", + "select", + "size", + "sort", + "split_with_sizes", + "split", + "squeeze", + "stack", + "topk", + "unbind", + "unique_dim", + "unsqueeze", + "upsample_bicubic2d", + "upsample_bilinear2d", + "upsample_linear1d", + "upsample_nearest1d", + "upsample_nearest2d", + "upsample_nearest3d", + "upsample_trilinear3d", +] + @symbolic_helper.parse_args("v", "f", "f") def hardtanh(g, self, min_val, max_val): @@ -143,7 +218,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False): if len(indices_list) > 1: for idx_ in range(len(indices_list)): - if indices_list[idx_].type().scalarType() == "Bool": + if indices_list[idx_].type().scalarType() == "Bool": # type: ignore[attr-defined] + # TODO(justinchuby): Remove type ignore after #81112 is checked in. indices_list[idx_] = g.op("NonZero", indices_list[idx_]) index = indices_list[0] @@ -198,7 +274,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False): # return (%33) index = indices_list[0] bool_inp = index - if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool": + if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool": # type: ignore[attr-defined] + # TODO(justinchuby): Remove type ignore after #81112 is checked in. rank = symbolic_helper._get_tensor_rank(values) if rank is not None and rank == 0: return opset9.masked_fill(g, self, bool_inp, values) @@ -428,12 +505,12 @@ def _avg_pool(name, tuple_fn): @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") def symbolic_fn( g, - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, + input: _C.Value, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Union[int, Tuple[int, ...]], + ceil_mode: int, + count_include_pad: int, divisor_override=None, ): padding = symbolic_helper._avgpool_helper( @@ -487,6 +564,14 @@ def sort(g, self, dim, decending, out=None): return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) +@symbolic_helper.parse_args("v", "i", "i", "none") +def argsort(g, self, dim, decending, out=None): + _, indices = symbolic_helper._sort_helper( + g, self, dim, decending=decending, out=out + ) + return indices + + def round(g, self): return g.op("Round", self) @@ -1062,12 +1147,15 @@ def linalg_vector_norm(g, self, ord, dim, keepdim, dtype): self = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) ) - keepdim = None + keepdim = 0 + cond_op = g.op( "Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))) ) cond_op = g.op( - "Cast", cond_op, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"] + "Cast", + cond_op, + to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()], ) return symbolic_helper._reducesum_helper( g, cond_op, axes_i=dim, keepdims_i=keepdim diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 647cf7634ad2e..ab2197c367cba 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,15 +1,35 @@ import sys +from typing import Optional, Tuple import torch -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 -from torch.onnx import utils +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9, utils # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py # This file exports ONNX ops for opset 12 +__all__ = [ + "argmax", + "argmin", + "binary_cross_entropy_with_logits", + "celu", + "cross_entropy_loss", + "dropout", + "einsum", + "einsum_helper", + "ge", + "le", + "native_dropout", + "nll_loss", + "nll_loss2d", + "nll_loss_nd", + "outer", + "pow", + "tensordot", + "unfold", +] + def einsum_helper(g, equation, tensors): if not tensors: @@ -47,16 +67,29 @@ def outer(g, input, other): return einsum_helper(g, "i,j->ij", [input, other]) -@symbolic_helper.parse_args("v", "f", "i") -def dropout(g, input, p, train): +def _dropout_returns_masked_input_and_mask( + g, input: torch._C.Value, p: float, train: bool +) -> Tuple[torch._C.Value, Optional[torch._C.Value]]: symbolic_helper.check_training_mode(train, "dropout") - # if train is False, dropout is no-op + # In eval mode, dropout is non-op. That is, if the node's + # train param is set to False, dropout just returns its inputs. if not train: - return input + return input, None p = g.op("Constant", value_t=torch.tensor(p)) t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) - r, _ = g.op("Dropout", input, p, t, outputs=2) - return r + r, mask = g.op("Dropout", input, p, t, outputs=2) + return r, mask + + +@symbolic_helper.parse_args("v", "f", "i") +def dropout(g, input, p, train): + masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) + return masked + + +@symbolic_helper.parse_args("v", "f", "i") +def native_dropout(g, input, p, train): + return _dropout_returns_masked_input_and_mask(g, input, p, train) def nll_loss(g, self, target, weight, reduction, ignore_index): diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index d63b07263a00c..26f3d5d0d2f7d 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -4,10 +4,12 @@ # This file exports ONNX ops for opset 13 import torch import torch._C._onnx as _C_onnx -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 -from torch.onnx import symbolic_opset11 as opset11 -from torch.onnx import utils +from torch.onnx import ( + symbolic_helper, + symbolic_opset11 as opset11, + symbolic_opset9 as opset9, + utils, +) @symbolic_helper.parse_args("v", "i", "none") diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index c19ba1d31e2bb..fcc625889f856 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -27,8 +27,7 @@ import torch from torch import _C -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 def __is_(g, self, other): diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index 932290e8d31ef..7e64933147416 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -12,8 +12,7 @@ import warnings -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 block_listed_operators = [ "scan", diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index b47530446d899..604f3de8b8981 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -33,8 +33,7 @@ import warnings import torch -from torch.onnx import symbolic_helper -from torch.onnx import symbolic_opset9 as opset9 +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 block_listed_operators = [ "nonzero", @@ -183,8 +182,12 @@ def matmul(g, self, other): def prelu(g, self, weight): self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) if self_rank is not None and self_rank > 2: weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) if symbolic_helper._try_get_scalar_type(self): old_type, self, weight = _try_cast_integer_to_float(g, self, weight) return _cast_to_type(g, g.op("PRelu", self, weight), old_type) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 20b206b7084a8..6c8d879559db6 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -8,7 +8,7 @@ import math import sys import warnings -from typing import List, Optional +from typing import List, Optional, Sequence, Tuple, Union import torch import torch._C._onnx as _C_onnx @@ -17,8 +17,7 @@ from torch import _C # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx import _patch_torch # noqa: F401 -from torch.onnx import symbolic_helper +from torch.onnx import _patch_torch, symbolic_helper # noqa: F401 from torch.onnx._exporter_states import ( SymbolicContext, # Special case class import for readability ) @@ -56,271 +55,273 @@ # __all__ = [ - "unused", - "reshape", - "reshape_as", + "abs", + "acos", + "adaptive_avg_pool1d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", "add", - "sub", - "rsub", - "mul", - "div", "addcmul", - "floor_divide", - "floordiv", - "true_divide", - "reciprocal", - "cat", - "stack", - "mm", - "bmm", - "matmul", "addmm", - "neg", - "sqrt", - "rsqrt", - "tanh", - "sin", - "cos", - "tan", + "alias", + "alpha_dropout_", + "alpha_dropout", + "amax", + "amin", + "aminmax", + "arange", + "argmax", + "argmin", + "as_strided", + "as_tensor", "asin", - "acos", "atan", - "sigmoid", - "sign", - "overload_by_arg_count", - "sum", - "mean", - "prod", - "cumsum", - "t", - "expand", - "expand_as", - "embedding", - "embedding_bag", - "size", - "transpose", - "permute", - "view", - "view_as", - "unsafe_chunk", - "split", - "unsafe_split", - "split_with_sizes", - "unsafe_split_with_sizes", - "unbind", - "select", - "square", - "squeeze", - "prelu", - "silu", - "mish", - "op_with_optional_float_cast", - "relu", - "relu6", - "ceil", - "floor", - "threshold", - "leaky_relu", - "glu", - "softmax", - "softplus", - "get_pool_ceil_padding", - "max_pool1d", - "max_pool2d", - "max_pool3d", - "max_pool1d_with_indices", - "max_pool2d_with_indices", - "max_pool3d_with_indices", "avg_pool1d", "avg_pool2d", "avg_pool3d", - "adaptive_avg_pool1d", - "adaptive_avg_pool2d", - "adaptive_avg_pool3d", - "adaptive_max_pool1d", - "adaptive_max_pool2d", - "adaptive_max_pool3d", - "constant_pad_nd", - "reflection_pad", - "replication_pad", - "reflection_pad1d", - "reflection_pad2d", - "reflection_pad3d", - "replication_pad1d", - "replication_pad2d", - "replication_pad3d", - "pad", - "upsample_nearest1d", - "upsample_nearest2d", - "upsample_nearest3d", - "upsample_linear1d", - "upsample_bilinear2d", - "upsample_trilinear3d", + "baddbmm", + "batch_norm", + "bernoulli", "bitwise_not", - "wrap_logical_op_with_cast_to", - "wrap_logical_op_with_cast_to_and_from", - "wrap_logical_op_with_negation", - "eq", - "ne", - "gt", - "gt_impl", - "lt", - "lt_impl", - "ge", - "le", - "logical_and", - "logical_or", - "logical_xor", - "where", - "log_softmax", - "conv1d", - "conv2d", - "conv3d", + "bmm", + "broadcast_tensors", + "bucketize", + "cat", + "cdist", + "ceil", + "clamp_max", + "clamp_min", + "clamp", + "clone", + "constant_pad_nd", + "contiguous", + "convolution", + "conv_tbc", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d", - "batch_norm", - "layer_norm", - "instance_norm", - "unfold", + "conv1d", + "conv2d", + "conv3d", + "cos", + "cosine_similarity", + "cross", + "cumsum", + "detach", + "dim", + "div", + "dot", + "dropout_", + "dropout", "elu", - "selu", - "index_select", - "index_put", - "index_fill", + "embedding_bag", + "embedding", + "empty_like", + "empty", + "eq", + "erf", + "exp", + "expand_as", + "expand", + "eye", + "feature_alpha_dropout_", + "feature_alpha_dropout", + "feature_dropout_", + "feature_dropout", + "fill", + "flatten", + "floor_divide", + "floor", + "floordiv", + "frobenius_norm", + "full_like", + "full", + "gather", + "ge", + "gelu", + "get_pool_ceil_padding", + "glu", + "group_norm", + "gru", + "gt_impl", + "gt", + "hann_window", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "index_add", "index_copy", - "bucketize", - "type_as", - "cosine_similarity", - "pairwise_distance", - "clone", - "abs", + "index_fill", + "index_put", + "index_select", + "index", + "instance_norm", + "is_floating_point", + "isnan", + "item", + "kl_div", + "layer_norm", + "le", + "leaky_relu", + "lerp", + "lift", + "linalg_cross", + "linalg_matrix_norm", + "linalg_norm", + "linalg_vector_norm", + "linear", + "linspace", + "log_sigmoid", + "log_softmax", "log", - "log1p", "log10", - "pow", - "clamp", - "clamp_min", - "clamp_max", + "log1p", + "log2", + "logical_and", + "logical_or", + "logical_xor", + "logsumexp", + "lstm_cell", + "lstm", + "lt_impl", + "lt", + "masked_fill", + "matmul", + "max_pool1d_with_indices", + "max_pool1d", + "max_pool2d_with_indices", + "max_pool2d", + "max_pool3d_with_indices", + "max_pool3d", "max", "maximum", + "mean", + "meshgrid", "min", "minimum", - "amax", - "amin", - "aminmax", - "exp", - "dropout", - "feature_dropout", - "alpha_dropout", - "feature_alpha_dropout", - "dropout_", - "feature_dropout_", - "alpha_dropout_", - "feature_alpha_dropout_", - "norm", - "conv_tbc", - "empty", - "empty_like", + "mish", + "mm", + "movedim", + "mul", + "multinomial", + "mv", + "narrow", + "native_layer_norm", + "ne", + "neg", "new_empty", - "scalar_tensor", - "tensor", - "as_tensor", - "zeros", - "zeros_like", - "new_zeros", - "ones", - "ones_like", - "new_ones", - "full", - "full_like", "new_full", - "eye", - "slice", - "hardtanh", - "hardswish", - "hardsigmoid", - "tanhshrink", - "hardshrink", - "softshrink", - "alias", - "unsqueeze", - "sort", + "new_ones", + "new_zeros", + "nonzero_numpy", + "nonzero", + "norm", "numel", - "topk", - "to", - "repeat", - "repeat_interleave", - "pixel_shuffle", - "pixel_unshuffle", - "lstm", - "lstm_cell", - "gru", - "rnn_tanh", - "rnn_relu", - "detach", - "contiguous", - "randn", + "numpy_T", + "one_hot", + "ones_like", + "ones", + "Onnx", + "op_with_optional_float_cast", + "overload_by_arg_count", + "pad", + "pairwise_distance", + "permute", + "pixel_shuffle", + "pixel_unshuffle", + "pow", + "prelu", + "Prim", + "prod", + "rand_like", "rand", "randn_like", - "rand_like", + "randn", + "reciprocal", + "reflection_pad", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "relu", + "relu6", + "remainder", + "repeat_interleave", + "repeat", + "replication_pad", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "reshape_as", + "reshape", + "rnn_relu", + "rnn_tanh", + "roll", "rrelu", - "bernoulli", - "log_sigmoid", - "erf", - "flatten", - "nonzero", - "nonzero_numpy", - "isnan", - "narrow", - "argmax", - "argmin", - "scatter", + "rsqrt", + "rsub", + "scalar_tensor", "scatter_add", - "log2", - "is_floating_point", - "one_hot", - "gather", - "std", - "var", - "var_mean", + "scatter", + "select", + "selu", + "sigmoid", + "sign", + "silu", + "sin", + "size", + "slice", + "softmax", + "softplus", + "softshrink", + "sort", + "split_with_sizes", + "split", + "sqrt", + "square", + "squeeze", + "stack", "std_mean", - "logsumexp", - "arange", - "linspace", - "lift", - "masked_fill", - "index", - "linalg_norm", - "linalg_vector_norm", - "linalg_matrix_norm", - "linalg_cross", - "frobenius_norm", - "multinomial", - "baddbmm", - "meshgrid", - "remainder", - "gelu", - "group_norm", - "dim", - "item", + "std", + "sub", + "sum", + "t", "take", - "kl_div", - "as_strided", - "linear", - "hann_window", - "mv", - "dot", - "movedim", - "fill", - "index_add", - "roll", - "cross", - "cdist", - "lerp", - "broadcast_tensors", - "Prim", - "Onnx", + "tan", + "tanh", + "tanhshrink", + "tensor", + "threshold", + "to", + "topk", + "transpose", + "true_divide", + "type_as", + "unbind", + "unfold", + "unsafe_chunk", + "unsafe_split_with_sizes", + "unsafe_split", + "unsqueeze", + "unused", + "upsample_bilinear2d", + "upsample_linear1d", + "upsample_nearest1d", + "upsample_nearest2d", + "upsample_nearest3d", + "upsample_trilinear3d", + "var_mean", + "var", + "view_as", + "view", + "where", + "wrap_logical_op_with_cast_to", + "wrap_logical_op_with_negation", + "zeros_like", + "zeros", ] # used to represent "missing" optional inputs @@ -354,17 +355,14 @@ def add(g, self, other, alpha=None): return symbolic_helper._onnx_opset_unsupported_detailed( "Add", 9, 11, "Add between list of tensors not supported" ) - - # default alpha arg is to allow no-alpha add (aten add st overload no alpha) if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: - return symbolic_helper._unimplemented("add", "alpha != 1") + other = g.op("Mul", other, alpha) return g.op("Add", self, other) def sub(g, self, other, alpha=None): - # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha) if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: - return symbolic_helper._unimplemented("sub", "alpha != 1") + other = g.op("Mul", other, alpha) return g.op("Sub", self, other) @@ -373,7 +371,11 @@ def rsub(g, self, other, alpha=None): def mul(g, self, other): - return g.op("Mul", self, other) + if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): + # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. + return g.op("And", self, other) + else: + return g.op("Mul", self, other) def div(g, self, other, *args): @@ -768,6 +770,12 @@ def t(g, self): return g.op("Transpose", self, perm_i=(1, 0)) +def numpy_T(g, input): + ndim = symbolic_helper._get_tensor_rank(input) + perm = list(reversed(range(0, ndim))) + return g.op("Transpose", input, perm_i=perm) + + def expand(g, self, size, implicit): size = symbolic_helper._maybe_get_const(size, "is") if not symbolic_helper._is_value(size): @@ -1063,20 +1071,19 @@ def squeeze(g, self, dim=None): def prelu(g, self, weight): self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + weight_rank = len(weight_sizes) if self_rank is not None: if self_rank > 2: # make weight unidirectional broadcastable weight = symbolic_helper._unsqueeze_helper( g, weight, list(range(1, self_rank - 1)) ) - elif self_rank == 0: - # weight is always rank 1. torch allows scalar self, and ONNX is ambiguous - # about whether this is allowed, but some implementations enforce - # rank(self) >= rank(weight), which makes sense. - self = symbolic_helper._unsqueeze_helper(g, self, [0]) - self_rank = 1 - - weight_rank = symbolic_helper._get_tensor_rank(weight) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + weight_rank = 0 + if self_rank is not None and weight_rank is not None: assert ( self_rank >= weight_rank @@ -1388,15 +1395,16 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): def _avg_pool(name, tuple_fn): + @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") def symbolic_fn( g, - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, + input: _C.Value, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Union[int, Tuple[int, ...]], + ceil_mode: int, + count_include_pad: int, divisor_override=None, ): if not stride: @@ -1404,8 +1412,7 @@ def symbolic_fn( padding = symbolic_helper._avgpool_helper( tuple_fn, padding, kernel_size, stride, divisor_override, name ) - if ceil_mode: - padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = padding if count_include_pad: input = g.op( "Pad", @@ -1414,17 +1421,20 @@ def symbolic_fn( mode_s="constant", value_f=0.0, ) - padding = (0,) * len(padding) + adjusted_padding = (0,) * len(padding) if ceil_mode: - padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) else: - padding = padding * 2 + adjusted_padding = adjusted_padding * 2 output = g.op( "AveragePool", input, kernel_shape_i=tuple_fn(kernel_size), strides_i=tuple_fn(stride), - pads_i=padding, + pads_i=adjusted_padding, ) return output @@ -1686,27 +1696,10 @@ def bitwise_not(g, inp): def wrap_logical_op_with_cast_to(to_type): - def decorator(fn): - def wrap_with_cast(g, input, other): - return g.op( - "Cast", - fn(g, input, other), - to_i=symbolic_helper.cast_pytorch_to_onnx[to_type], - ) - - return wrap_with_cast - - return decorator - - -def wrap_logical_op_with_cast_to_and_from(to_type): def decorator(fn): def wrap_with_cast(g, input, other): to_cast_func = globals()[f"_cast_{to_type}"] - from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn) - return from_cast_func( - g, to_cast_func(g, input, False), to_cast_func(g, other, False) - ) + return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) return wrap_with_cast @@ -1816,17 +1809,17 @@ def __xor_(g, input, other): ) -@wrap_logical_op_with_cast_to_and_from("Bool") +@wrap_logical_op_with_cast_to("Bool") def logical_and(g, input, other): return g.op("And", input, other) -@wrap_logical_op_with_cast_to_and_from("Bool") +@wrap_logical_op_with_cast_to("Bool") def logical_or(g, input, other): return g.op("Or", input, other) -@wrap_logical_op_with_cast_to_and_from("Bool") +@wrap_logical_op_with_cast_to("Bool") def logical_xor(g, input, other): return g.op("Xor", input, other) @@ -1926,6 +1919,13 @@ def log_softmax(g, input, dim, dtype=None): return return_op +@symbolic_helper.parse_args("v", "i", "i") +def _log_softmax(g, input, dim, half_to_float): + if half_to_float and input.type().scalarType() == "Half": + input = g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"]) + return log_softmax(g, input, dim) + + @symbolic_helper.parse_args( "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" ) @@ -1993,6 +1993,37 @@ def _convolution( return n +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") +def convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + None, + None, + None, + None, + ) + + @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i") def conv1d(g, input, weight, bias, stride, padding, dilation, groups): return _convolution( @@ -2174,8 +2205,16 @@ def batch_norm( return res -@symbolic_helper.parse_args("v", "is", "v", "v", "f", "i") -def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): +def _layer_norm_returns_normalized_input_mean_rstd( + g, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, + return_mean_rstd: bool, +) -> Tuple[_C.Value, Optional[_C.Value], Optional[_C.Value]]: if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "layer_norm", @@ -2186,7 +2225,6 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): eps_f=eps, cudnn_enable_i=cudnn_enable, ) - axes = [-i for i in range(len(normalized_shape), 0, -1)] two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) @@ -2198,14 +2236,33 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) denominator = sqrt(g, add(g, variance, eps_cst)) - layer_norm = g.op("Div", numerator, denominator) + normalized = g.op("Div", numerator, denominator) if not (weight is None or symbolic_helper._is_none(weight)): - layer_norm = mul(g, layer_norm, weight) + normalized = mul(g, normalized, weight) if not (bias is None or symbolic_helper._is_none(bias)): - layer_norm = add(g, layer_norm, bias) + normalized = add(g, normalized, bias) + + if return_mean_rstd: + # rdenominator = 1 / sqrt(variance + eps) + rdenominator = reciprocal(g, denominator) + return normalized, mean, rdenominator + return normalized, None, None + - return layer_norm +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def native_layer_norm(g, input, normalized_shape, weight, bias, eps): + return _layer_norm_returns_normalized_input_mean_rstd( + g, input, normalized_shape, weight, bias, eps, False, True + ) + + +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "i") +def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): + normalized, _, _ = _layer_norm_returns_normalized_input_mean_rstd( + g, input, normalized_shape, weight, bias, eps, cudnn_enable, False + ) + return normalized @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") @@ -2777,7 +2834,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if symbolic_helper._is_packed_list(data): if dtype is None: - dtype = symbolic_helper._unpack_list(data)[0].type().scalarType() + dtype = symbolic_helper._unpack_list(data)[0].type().scalarType() # type: ignore[attr-defined] + # TODO(justinchuby): Remove type ignore after #81112 is checked in. dtype = symbolic_helper.scalar_type_to_onnx.index( symbolic_helper.cast_pytorch_to_onnx[dtype] ) @@ -3073,27 +3131,71 @@ def tanhshrink(g, self): @symbolic_helper.parse_args("v", "f") def hardshrink(g, self, lambd): - lambd_op = g.op("Constant", value_t=torch.FloatTensor([lambd])) + dtype = self.type().scalarType() + if dtype is None: + dtype = symbolic_helper.ScalarType.FLOAT + else: + dtype = symbolic_helper.scalar_type_to_onnx.index( + symbolic_helper.cast_pytorch_to_onnx[dtype] + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor( + lambd, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype] + ), + ) cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) - return g.op("Where", cond, self, g.op("Constant", value_t=torch.FloatTensor([0]))) + return g.op( + "Where", + cond, + self, + g.op( + "Constant", + value_t=torch.tensor( + 0, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype] + ), + ), + ) @symbolic_helper.parse_args("v", "f") def softshrink(g, self, lambd): - lambd_op = g.op("Constant", value_t=torch.FloatTensor([lambd])) + dtype = self.type().scalarType() + if dtype is None: + dtype = symbolic_helper.ScalarType.FLOAT + else: + dtype = symbolic_helper.scalar_type_to_onnx.index( + symbolic_helper.cast_pytorch_to_onnx[dtype] + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor( + lambd, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype] + ), + ) gt_cond = gt(g, self, lambd_op) gt_out = g.op( "Where", gt_cond, sub(g, self, lambd_op), - g.op("Constant", value_t=torch.FloatTensor([0])), + g.op( + "Constant", + value_t=torch.tensor( + 0, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype] + ), + ), ) lt_cond = lt(g, self, neg(g, lambd_op)) lt_out = g.op( "Where", lt_cond, add(g, self, lambd_op), - g.op("Constant", value_t=torch.FloatTensor([0])), + g.op( + "Constant", + value_t=torch.tensor( + 0, dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype] + ), + ), ) return add(g, gt_out, lt_out) @@ -4079,7 +4181,7 @@ def _any(g, *args): input_sum = symbolic_helper._reducesum_helper( g, input, axes_i=dim, keepdims_i=keepdim ) - return gt(g, input_sum, g.op("Constant", value_t=torch.LongTensor([0]))) + return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) def _all(g, *args): @@ -4545,7 +4647,14 @@ def try_mask_to_index(index): @symbolic_helper.parse_args("v", "v", "is", "i", "v") -def linalg_norm(g, self, ord, dim, keepdim, dtype): +def linalg_norm( + g, + self: torch._C.Value, + ord: torch._C.Value, + dim: List[int], + keepdim: int, + dtype: torch._C.Value, +): # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html ord_value = None if dim is None: @@ -4572,11 +4681,18 @@ def linalg_norm(g, self, ord, dim, keepdim, dtype): @symbolic_helper.parse_args("v", "f", "is", "i", "v") -def linalg_vector_norm(g, self, ord, dim, keepdim, dtype): +def linalg_vector_norm( + g, + self: torch._C.Value, + ord: float, + dim: List[int], + keepdim: int, + dtype: torch._C.Value, +): # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html if dim is None: self = symbolic_helper._reshape_helper(g, self, [-1]) - keepdim = None + keepdim = 0 if ord == math.inf: result = g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim) @@ -4587,20 +4703,31 @@ def linalg_vector_norm(g, self, ord, dim, keepdim, dtype): "linalg_vector_norm", 9, 11, "ord=0 not supported" ) else: - ord_op = g.op("Constant", value_t=torch.FloatTensor([ord])) + ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) result = symbolic_helper._reducesum_helper( g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim ) result = g.op( "Pow", result, - g.op("Div", g.op("Constant", value_t=torch.FloatTensor([1])), ord_op), + g.op( + "Div", + g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), + ord_op, + ), ) return result @symbolic_helper.parse_args("v", "v", "is", "i", "v") -def linalg_matrix_norm(g, self, ord, dim, keepdim, dtype): +def linalg_matrix_norm( + g, + self: torch._C.Value, + ord: torch._C.Value, + dim: List[int], + keepdim: int, + dtype: torch._C.Value, +): # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html ord_value = symbolic_helper._parse_arg(ord, "s") if ord_value == "fro": diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 9c9b5de26b078..21fcc7ac8b87e 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -19,6 +19,7 @@ from typing import ( Any, Callable, + cast, Collection, Dict, List, @@ -125,30 +126,36 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): @contextlib.contextmanager -def disable_apex_o2_state_dict_hook(model): +def disable_apex_o2_state_dict_hook( + model: Union[torch.nn.Module, torch.jit.ScriptFunction] +): # Apex O2 hook state_dict to return fp16 weights as fp32. # Exporter cannot identify them as same tensors. # Since this hook is only used by optimizer, it is safe to # remove this hook while exporting. if not isinstance(model, torch.jit.ScriptFunction): - tmp_map = {} # type: ignore[var-annotated] + model_hooks = {} # type: ignore[var-annotated] for module in model.modules(): - for k, v in module._state_dict_hooks.items(): - if type(v).__name__ == "O2StateDictHook": - if module not in tmp_map: - tmp_map[module] = {} - tmp_map[module][k] = v - if module in tmp_map: - for k in tmp_map[module].keys(): - module._state_dict_hooks.pop(k) - try: - yield - finally: - if not isinstance(model, torch.jit.ScriptFunction): - # FIXME(justinchuby): tmp_map is possibly unbound - for module, m_map in tmp_map.items(): - for k, v in m_map.items(): - module._state_dict_hooks[k] = v + for key, hook in module._state_dict_hooks.items(): + if type(hook).__name__ == "O2StateDictHook": + if module not in model_hooks: + model_hooks[module] = {} + model_hooks[module][key] = hook + if module in model_hooks: + for key in model_hooks[module]: + module._state_dict_hooks.pop(key) + try: + yield + finally: + # Add the hooks back + for module, m_map in model_hooks.items(): + for key, hook in m_map.items(): + module._state_dict_hooks[key] = hook + else: + try: + yield + finally: + pass @contextlib.contextmanager @@ -873,51 +880,60 @@ def flatten(x): ) -def _create_jit_graph(model, args): - torch_out = None - params: Union[List, Tuple] +def _create_jit_graph( + model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any] +) -> Tuple[ + _C.Graph, + List[_C.IValue], + Optional[Any], + Optional[Union[_C.ScriptModule, _C.ScriptFunction]], +]: if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) _check_flatten_did_not_remove(args, flattened_args) - if isinstance(model, torch.jit.ScriptModule): - try: - graph = model.forward.graph - except AttributeError as e: - raise RuntimeError("'forward' method must be a script method") from e - _C._jit_pass_onnx_function_substitution(graph) - freezed_m = _C._freeze_module(model._c, preserveParameters=True) - module, params = _C._jit_onnx_list_model_parameters(freezed_m) - method_graph = module._get_method("forward").graph - args_params = tuple(args) + tuple(params) - param_count_list = _get_param_count_list(method_graph, args_params) - in_vars, _ = torch.jit._flatten(args_params) - graph = _C._propagate_and_assign_input_shapes( - method_graph, tuple(in_vars), param_count_list, False, False - ) - return graph, params, torch_out, module - elif isinstance(model, torch.jit.ScriptFunction): - params = () + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] graph = model.graph _C._jit_pass_onnx_function_substitution(graph) param_count_list = _get_param_count_list(graph, args) - # FIXME(justinchuby): flattened_args is possibly unbound graph = _C._propagate_and_assign_input_shapes( graph, flattened_args, param_count_list, False, False ) return graph, params, torch_out, None - else: - graph, torch_out = _trace_and_get_graph_from_model(model, args) - _C._jit_pass_onnx_lint(graph) - state_dict = torch.jit._unique_state_dict(model) - params = list(state_dict.values()) - graph_inputs = list(graph.inputs()) - user_input_num = len(graph_inputs) - len(state_dict) - param_names = list(state_dict.keys()) - for i, inp in enumerate(graph_inputs): - if i >= user_input_num: - inp.setDebugName(param_names[i - user_input_num]) - _C._jit_pass_onnx_function_substitution(graph) - return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None def _get_named_param_dict(graph, params): diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 2fa7b3dddd041..198910623204b 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -28,13 +28,15 @@ def sparse_adam(params: List[Tensor], eps: float, beta1: float, beta2: float, - lr: float): + lr: float, + maximize: bool): r"""Functional API that performs Sparse Adam algorithm computation. See :class:`~torch.optim.SparseAdam` for details. """ for i, param in enumerate(params): grad = grads[i] + grad = grad if not maximize else -grad grad = grad.coalesce() # the update is non-linear so indices must be unique grad_indices = grad._indices() grad_values = grad._values() diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index eb1d4e3f769a1..14fbfa94ebd70 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['Adadelta', 'adadelta'] class Adadelta(Optimizer): r"""Implements Adadelta algorithm. @@ -78,7 +79,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 59c1db6caf298..1202eeb014d61 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['Adagrad', 'adagrad'] class Adagrad(Optimizer): r"""Implements Adagrad algorithm. @@ -124,7 +125,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -188,7 +189,7 @@ def adagrad( See :class:`~torch.optim.Adagrad` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( "API has changed, `state_steps` argument must contain a list of singleton tensors" ) @@ -303,7 +304,7 @@ def _multi_tensor_adagrad( grads = torch._foreach_neg(grads) if has_sparse_grad is None: - has_sparse_grad = any([grad.is_sparse for grad in grads]) + has_sparse_grad = any(grad.is_sparse for grad in grads) if has_sparse_grad: return _single_tensor_adagrad( diff --git a/torch/optim/adam.py b/torch/optim/adam.py index e75633a786a9f..7176ea544d611 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['Adam', 'adam'] class Adam(Optimizer): r"""Implements Adam algorithm. @@ -54,7 +55,7 @@ class Adam(Optimizer): eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this + amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) foreach (bool, optional): whether foreach implementation of optimizer @@ -107,7 +108,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ self._cuda_graph_capture_health_check() @@ -195,7 +196,7 @@ def adam(params: List[Tensor], See :class:`~torch.optim.Adam` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") if foreach is None: @@ -251,8 +252,6 @@ def _single_tensor_adam(params: List[Tensor], if capturable: assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors." # update step step_t += 1 @@ -260,6 +259,12 @@ def _single_tensor_adam(params: List[Tensor], if weight_decay != 0: grad = grad.add(param, alpha=weight_decay) + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + exp_avg_sq = torch.view_as_real(exp_avg_sq) + param = torch.view_as_real(param) + # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) @@ -330,13 +335,16 @@ def _multi_tensor_adam(params: List[Tensor], if capturable: assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \ "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert all(not step.is_cuda for step in state_steps), \ - "If capturable=False, state_steps should not be CUDA tensors." if maximize: grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] + # Handle complex parameters + grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads] + exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs] + exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs] + params_ = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params] + # update steps torch._foreach_add_(state_steps, 1) @@ -386,7 +394,7 @@ def _multi_tensor_adam(params: List[Tensor], torch._foreach_reciprocal_(eps_over_step_size) denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size) - torch._foreach_addcdiv_(params, exp_avgs, denom) + torch._foreach_addcdiv_(params_, exp_avgs, denom) else: bias_correction1 = [1 - beta1 ** step.item() for step in state_steps] bias_correction2 = [1 - beta2 ** step.item() for step in state_steps] @@ -408,4 +416,4 @@ def _multi_tensor_adam(params: List[Tensor], torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) denom = torch._foreach_add(exp_avg_sq_sqrt, eps) - torch._foreach_addcdiv_(params, exp_avgs, denom, step_size) + torch._foreach_addcdiv_(params_, exp_avgs, denom, step_size) diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 00d58936fbea5..8fe74bf414143 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['Adamax', 'adamax'] class Adamax(Optimizer): r"""Implements Adamax algorithm (a variant of Adam based on infinity norm). @@ -82,7 +83,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -160,7 +161,7 @@ def adamax(params: List[Tensor], See :class:`~torch.optim.Adamax` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") if foreach is None: diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 94ddf65506e70..d678b695820aa 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['AdamW', 'adamw'] class AdamW(Optimizer): r"""Implements AdamW algorithm. @@ -54,7 +55,7 @@ class AdamW(Optimizer): eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this + amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) maximize (bool, optional): maximize the params based on the objective, instead of @@ -108,7 +109,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ self._cuda_graph_capture_health_check() @@ -200,7 +201,7 @@ def adamw(params: List[Tensor], See :class:`~torch.optim.AdamW` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") if foreach is None: @@ -255,8 +256,6 @@ def _single_tensor_adamw(params: List[Tensor], if capturable: assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors." # update step step_t += 1 @@ -334,9 +333,6 @@ def _multi_tensor_adamw(params: List[Tensor], if capturable: assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \ "If capturable=True, params and state_steps must be CUDA tensors." - else: - assert all(not step.is_cuda for step in state_steps), \ - "If capturable=False, state_steps should not be CUDA tensors." if maximize: grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 687ab508580cb..100b743a9d7ae 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -5,6 +5,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['ASGD', 'asgd'] class ASGD(Optimizer): """Implements Averaged Stochastic Gradient Descent. @@ -22,26 +23,29 @@ class ASGD(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) foreach (bool, optional): whether foreach implementation of optimizer is used (default: None) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) .. _Acceleration of stochastic approximation by averaging: https://dl.acm.org/citation.cfm?id=131098 """ def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0, - foreach: Optional[bool] = None): + foreach: Optional[bool] = None, maximize: bool = False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0, - weight_decay=weight_decay, foreach=foreach) + weight_decay=weight_decay, foreach=foreach, maximize=maximize) super(ASGD, self).__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault('foreach', None) + group.setdefault('maximize', False) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) if not step_is_tensor: @@ -61,7 +65,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -108,7 +112,8 @@ def step(self, closure=None): t0=group['t0'], alpha=group['alpha'], weight_decay=group['weight_decay'], - foreach=group['foreach']) + foreach=group['foreach'], + maximize=group['maximize']) return loss @@ -122,6 +127,7 @@ def asgd(params: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: bool = None, + maximize: bool = False, *, lambd: float, lr: float, @@ -155,7 +161,8 @@ def asgd(params: List[Tensor], lr=lr, t0=t0, alpha=alpha, - weight_decay=weight_decay) + weight_decay=weight_decay, + maximize=maximize) def _single_tensor_asgd(params: List[Tensor], @@ -169,10 +176,12 @@ def _single_tensor_asgd(params: List[Tensor], lr: float, t0: float, alpha: float, - weight_decay: float): + weight_decay: float, + maximize: bool): for i, param in enumerate(params): grad = grads[i] + grad = grad if not maximize else -grad mu = mus[i] ax = axs[i] eta = etas[i] @@ -214,11 +223,15 @@ def _multi_tensor_asgd(params: List[Tensor], lr: float, t0: float, alpha: float, - weight_decay: float): + weight_decay: float, + maximize: bool): if len(params) == 0: return + if maximize: + grads = torch._foreach_neg(grads) + # update step torch._foreach_add_(state_steps, 1) diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index d82c5ed933827..9f9336128699e 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -2,6 +2,7 @@ from functools import reduce from .optimizer import Optimizer +__all__ = ['LBFGS'] def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): # ported from https://github.com/torch/optim/blob/master/polyinterp.lua @@ -284,7 +285,7 @@ def step(self, closure): """Performs a single optimization step. Args: - closure (callable): A closure that reevaluates the model + closure (Callable): A closure that reevaluates the model and returns the loss. """ assert len(self.param_groups) == 1 diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 0cd53f8a708e7..de64c88e0f7c8 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -9,6 +9,9 @@ from .optimizer import Optimizer +__all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR', + 'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau', + 'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR'] EPOCH_DEPRECATION_WARNING = ( "The epoch parameter in `scheduler.step()` was not necessary and is being " diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 3a9d6bcf44d5d..edf4f59a6ae8d 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -4,6 +4,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['NAdam', 'nadam'] class NAdam(Optimizer): r"""Implements NAdam algorithm. @@ -92,7 +93,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -169,10 +170,10 @@ def nadam(params: List[Tensor], See :class:`~torch.optim.NAdam` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") - if not all([isinstance(t, torch.Tensor) for t in mu_products]): + if not all(isinstance(t, torch.Tensor) for t in mu_products): raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors") if foreach is None: diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 2d4ca41d43018..7b32603babac6 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -1,11 +1,11 @@ from collections import defaultdict, abc as container_abcs - import torch from copy import deepcopy from itertools import chain import warnings import functools +__all__ = ['Optimizer'] class _RequiredParameter(object): """Singleton class representing a required parameter for an Optimizer.""" @@ -15,6 +15,18 @@ def __repr__(self): required = _RequiredParameter() +def _use_grad_for_differentiable(func): + def _use_grad(self, *args, **kwargs): + prev_grad = torch.is_grad_enabled() + try: + torch.set_grad_enabled(self.defaults['differentiable']) + ret = func(self, *args, **kwargs) + finally: + torch.set_grad_enabled(prev_grad) + return ret + return _use_grad + + class Optimizer(object): r"""Base class for all optimizers. @@ -58,6 +70,7 @@ def __init__(self, params, defaults): # https://github.com/pytorch/pytorch/issues/72948 self._warned_capturable_if_run_uncaptured = True + def __getstate__(self): return { 'defaults': self.defaults, @@ -68,6 +81,7 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) self._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.defaults.setdefault('differentiable', False) def __repr__(self): format_string = self.__class__.__name__ + ' (' @@ -90,7 +104,11 @@ def _cuda_graph_capture_health_check(self): self.__class__.__name__ + " but this instance was constructed with capturable=False.") - if (not self._warned_capturable_if_run_uncaptured) and self.defaults['capturable'] and (not capturing): + if ( + (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) + and self.defaults["capturable"] + and (not capturing) + ): print("Warning: This instance was constructed with capturable=True, but step() " + "is running without CUDA graph capture. If you never intend to graph-capture this " + "instance, capturable=True can impair performance, and you should set capturable=False.") @@ -254,7 +272,7 @@ def step(self, closure): r"""Performs a single optimization step (parameter update). Args: - closure (callable): A closure that reevaluates the model and + closure (Callable): A closure that reevaluates the model and returns the loss. Optional for most optimizers. .. note:: @@ -288,7 +306,7 @@ def add_param_group(self, param_group): if not isinstance(param, torch.Tensor): raise TypeError("optimizer can only optimize Tensors, " "but one of the params is " + torch.typename(param)) - if not param.is_leaf: + if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad): raise ValueError("can't optimize a non-leaf Tensor") for name, default in self.defaults.items(): diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 837cf8a57d180..f5bd0e78ae0c4 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -5,6 +5,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['RAdam', 'radam'] class RAdam(Optimizer): r"""Implements RAdam algorithm. @@ -90,7 +91,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -160,7 +161,7 @@ def radam(params: List[Tensor], See :class:`~torch.optim.RAdam` for details. """ - if not all([isinstance(t, torch.Tensor) for t in state_steps]): + if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") if foreach is None: diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 313c4e9229553..d5aa61e9540c7 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -3,6 +3,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['RMSprop', 'rmsprop'] class RMSprop(Optimizer): r"""Implements RMSprop algorithm. @@ -61,11 +62,13 @@ class RMSprop(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) foreach (bool, optional): whether foreach implementation of optimizer is used (default: None) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) """ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, - centered=False, foreach: Optional[bool] = None): + centered=False, foreach: Optional[bool] = None, maximize: bool = False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -78,7 +81,7 @@ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, moment raise ValueError("Invalid alpha value: {}".format(alpha)) defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, - weight_decay=weight_decay, foreach=foreach) + weight_decay=weight_decay, foreach=foreach, maximize=maximize) super(RMSprop, self).__init__(params, defaults) def __setstate__(self, state): @@ -87,13 +90,14 @@ def __setstate__(self, state): group.setdefault('momentum', 0) group.setdefault('centered', False) group.setdefault('foreach', None) + group.setdefault('maximize', False) @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -149,7 +153,8 @@ def step(self, closure=None): weight_decay=group['weight_decay'], momentum=group['momentum'], centered=group['centered'], - foreach=group['foreach']) + foreach=group['foreach'], + maximize=group["maximize"]) return loss @@ -162,6 +167,7 @@ def rmsprop(params: List[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim foreach: bool = None, + maximize: bool = False, *, lr: float, alpha: float, @@ -195,7 +201,8 @@ def rmsprop(params: List[Tensor], eps=eps, weight_decay=weight_decay, momentum=momentum, - centered=centered) + centered=centered, + maximize=maximize) def _single_tensor_rmsprop(params: List[Tensor], @@ -209,10 +216,12 @@ def _single_tensor_rmsprop(params: List[Tensor], eps: float, weight_decay: float, momentum: float, - centered: bool): + centered: bool, + maximize: bool): for i, param in enumerate(params): grad = grads[i] + grad = grad if not maximize else -grad square_avg = square_avgs[i] if weight_decay != 0: @@ -246,11 +255,15 @@ def _multi_tensor_rmsprop(params: List[Tensor], eps: float, weight_decay: float, momentum: float, - centered: bool): + centered: bool, + maximize: bool): if len(params) == 0: return + if maximize: + grads = torch._foreach_neg(grads) + if weight_decay != 0: torch._foreach_add_(grads, params, alpha=weight_decay) diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index d976647f7ab4a..1b9952b26b7d3 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -3,6 +3,7 @@ from .optimizer import Optimizer from typing import List, Optional +__all__ = ['Rprop', 'rprop'] class Rprop(Optimizer): r"""Implements the resilient backpropagation algorithm. @@ -73,7 +74,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 022d3f1b0bd8f..4a941d4593900 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,8 +1,9 @@ import torch from torch import Tensor -from .optimizer import Optimizer, required +from .optimizer import Optimizer, required, _use_grad_for_differentiable from typing import List, Optional +__all__ = ['SGD', 'sgd'] class SGD(Optimizer): r"""Implements stochastic gradient descent (optionally with momentum). @@ -89,7 +90,8 @@ class SGD(Optimizer): """ def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None): + weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None, + differentiable=False): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -99,7 +101,8 @@ def __init__(self, params, lr=required, momentum=0, dampening=0, defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov, - maximize=maximize, foreach=foreach) + maximize=maximize, foreach=foreach, + differentiable=differentiable) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super(SGD, self).__init__(params, defaults) @@ -111,12 +114,12 @@ def __setstate__(self, state): group.setdefault('maximize', False) group.setdefault('foreach', None) - @torch.no_grad() + @_use_grad_for_differentiable def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -218,8 +221,8 @@ def _single_tensor_sgd(params: List[Tensor], has_sparse_grad: bool): for i, param in enumerate(params): + d_p = d_p_list[i] if not maximize else -d_p_list[i] - d_p = d_p_list[i] if weight_decay != 0: d_p = d_p.add(param, alpha=weight_decay) @@ -237,8 +240,7 @@ def _single_tensor_sgd(params: List[Tensor], else: d_p = buf - alpha = lr if maximize else -lr - param.add_(d_p, alpha=alpha) + param.add_(d_p, alpha=-lr) def _multi_tensor_sgd(params: List[Tensor], @@ -257,7 +259,10 @@ def _multi_tensor_sgd(params: List[Tensor], return if has_sparse_grad is None: - has_sparse_grad = any([grad.is_sparse for grad in grads]) + has_sparse_grad = any(grad.is_sparse for grad in grads) + + if maximize: + grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment] if weight_decay != 0: grads = torch._foreach_add(grads, params, alpha=weight_decay) @@ -292,10 +297,9 @@ def _multi_tensor_sgd(params: List[Tensor], else: grads = bufs - alpha = lr if maximize else -lr if not has_sparse_grad: - torch._foreach_add_(params, grads, alpha=alpha) + torch._foreach_add_(params, grads, alpha=-lr) else: # foreach APIs dont support sparse for i in range(len(params)): - params[i].add_(grads[i], alpha=alpha) + params[i].add_(grads[i], alpha=-lr) diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index f31dddab65bff..a32aaca92ad47 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -2,6 +2,7 @@ from . import _functional as F from .optimizer import Optimizer +__all__ = ['SparseAdam'] class SparseAdam(Optimizer): r"""Implements lazy version of Adam algorithm suitable for sparse tensors. @@ -17,12 +18,14 @@ class SparseAdam(Optimizer): running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False): if not 0.0 < lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 < eps: @@ -47,7 +50,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors" ) - defaults = dict(lr=lr, betas=betas, eps=eps) + defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize) super(SparseAdam, self).__init__(params, defaults) @torch.no_grad() @@ -55,7 +58,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model + closure (Callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None @@ -72,6 +75,7 @@ def step(self, closure=None): eps = group['eps'] lr = group['lr'] beta1, beta2 = group['betas'] + maximize = group.get('maximize', False) for p in group['params']: if p.grad is not None: @@ -106,6 +110,7 @@ def step(self, closure=None): beta1=beta1, beta2=beta2, lr=group['lr'], - eps=group['eps']) + eps=group['eps'], + maximize=maximize) return loss diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 83b62911b6aec..d44fc3fffd095 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -7,6 +7,7 @@ from torch.nn import Module from torch.optim.lr_scheduler import _LRScheduler +__all__ = ['AveragedModel', 'update_bn', 'SWALR'] class AveragedModel(Module): r"""Implements averaged model for Stochastic Weight Averaging (SWA). diff --git a/torch/overrides.py b/torch/overrides.py index 2e1d48c55f00c..f4deae2aa257b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -34,7 +34,7 @@ _has_torch_function, _has_torch_function_unary, _has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode) -from torch.utils._mode_utils import _enable_mode, _push_mode, _ModeInfo, _wrap_init, _restore_mode +from torch.utils._mode_utils import _enable_mode, _ModeInfo, _wrap_init, _restore_mode __all__ = [ "get_ignored_functions", @@ -279,8 +279,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor._is_zerotensor, Tensor._addmm_activation, Tensor._nested_tensor_layer_norm, - Tensor.to_padded_tensor, - Tensor.sym_size + Tensor.to_padded_tensor } @@ -924,6 +923,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.ravel: lambda input: -1, torch.real: lambda input, out=None: -1, torch.vdot: lambda input, other, out=None: -1, + torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1, torch.view_as_real: lambda input: -1, torch.view_as_complex: lambda input: -1, torch.reciprocal: lambda input, out=None: -1, @@ -967,6 +967,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.spmm: lambda input, mat2: -1, torch.softmax: lambda input, dim, dtype=None: -1, torch.linalg.solve: lambda A, B, left=True, out=None: -1, + torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1, torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1, torch.split: lambda tensor, split_size_or_sections, dim=0: -1, torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, @@ -990,6 +991,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, torch.swapaxes: lambda input, dim0, dim1: -1, torch.swapdims: lambda input, axis0, axis1: -1, + torch.special.airy_ai: lambda input: -1, torch.special.bessel_j0: lambda input: -1, torch.special.bessel_j1: lambda input: -1, torch.special.bessel_y0: lambda input: -1, @@ -1033,12 +1035,15 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.special.polygamma: lambda input, n, out=None: -1, torch.special.psi: lambda input: -1, torch.special.round: lambda input: -1, + torch.special.scaled_modified_bessel_k0: lambda input: -1, + torch.special.scaled_modified_bessel_k1: lambda input: -1, torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1, torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1, torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1, torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1, torch.special.sinc: lambda input: -1, torch.special.softmax: lambda input, dim, dtype=None: -1, + torch.special.spherical_bessel_j0: lambda input: -1, torch.special.xlog1py: lambda input, other, out=None: -1, torch.special.xlogy: lambda input, other, out=None: -1, torch.special.zeta: lambda self, other, out=None: -1, @@ -1068,6 +1073,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.true_divide: lambda input, other: -1, torch.trunc: lambda input, out=None: -1, torch.unbind: lambda input, dim=0: -1, + torch.unflatten: lambda input, dim, sizes, names: -1, torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1, torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1, torch.unsafe_chunk: lambda input, chunks, dim=0: -1, @@ -1825,7 +1831,7 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta): ``NotImplemented``. Independent subclasses of :class:`TorchFunctionMode` are compositional: - modes can be pushed onto a stack with :func:`push_torch_function_mode`. + modes can be pushed onto a stack using ``with MyMode():``. When you call functions in the PyTorch API inside your ``__torch_function__`` implementation, by default, they will forward on to the next mode on the mode stack. If you want recursively call back into @@ -1854,6 +1860,7 @@ def __enter__(self): else: self.ancestors = self.inner.ancestors.union({self.inner}) _set_torch_function_mode(self) + return self def __exit__(self, exc_type, exc_val, exc_tb): _set_torch_function_mode(self.inner) @@ -1864,8 +1871,9 @@ def restore(self): @classmethod def push(cls, *args, **kwargs): - return push_torch_function_mode(functools.partial(cls, *args, **kwargs)) - + warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`") + instance = cls(*args, **kwargs) + return instance class BaseTorchFunctionMode(TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): @@ -1889,8 +1897,7 @@ def _no_torch_function_mode() -> Iterator[None]: class _TorchFunctionModeInfo(_ModeInfo): def __init__(self): - super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode, - base_mode_class=BaseTorchFunctionMode) + super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode) def get_mode(self): return _get_torch_function_mode() @@ -1905,8 +1912,8 @@ def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) Context manager that sets the current :class:`TorchFunctionMode`; see the class for more information on what modes are. This function is non-compositional; if there is already an existing mode, it will raise an - error; prefer using :func:`push_torch_function_mode` if your - ``__torch_function__`` implementation can defer to an inner mode. + error; prefer using ``with MyMode():`` if your ``__torch_function__`` + implementation can defer to an inner mode. This function is safe to use inside a ``__torch_function__`` mode handler, as the mode is guaranteed to be disabled in this context. You can use @@ -1931,25 +1938,6 @@ class for more information on what modes are. This function is """ return _enable_mode(mode, _TorchFunctionModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting) -@contextlib.contextmanager -def push_torch_function_mode(ctor) -> Iterator[TorchFunctionMode]: - """ - Context manager that pushes a :class:`TorchFunctionMode` onto the current - mode stack; see the class for more information on what modes are. Stacked - modes can delegate to each other by invoking the ``__torch_function__`` - method for the ``inner`` mode. - - Args: - ctor: a function that when invoked as ``ctor(inner=...)`` produces - a :class:`TorchFunctionMode`. If your :class:`TorchFunctionMode` - has no ``__init__`` implementation, you can simply pass the class - itself (e.g., ``push_torch_function_mode(MyMode)``); otherwise, - use ``functools.partial`` to partially apply the constructor with all - non-inner arguments (e.g., - ``push_torch_function_mode(partial(MyMode, arg))``) - """ - return _push_mode(ctor, _TorchFunctionModeInfo()) - class enable_reentrant_dispatch(): def __enter__(self): self._raii_guard = torch._C._RestorePythonTLSSnapshot() @@ -1959,7 +1947,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def get_buffer(tensor_subclass, data, prefix): import ctypes - assert prefix in {"stride", "size"} + assert prefix in {"stride", "size", "sym_size"} buffer_name = f"_{prefix}_buffer" if not hasattr(tensor_subclass, buffer_name): SizeType = ctypes.c_longlong * len(data) diff --git a/torch/package/_directory_reader.py b/torch/package/_directory_reader.py index 14d20181cd3b8..30833493c4fba 100644 --- a/torch/package/_directory_reader.py +++ b/torch/package/_directory_reader.py @@ -35,7 +35,7 @@ def get_record(self, name): def get_storage_from_record(self, name, numel, dtype): filename = f"{self.directory}/{name}" nbytes = torch._utils._element_size(dtype) * numel - storage = cast(Storage, torch._UntypedStorage) + storage = cast(Storage, torch.UntypedStorage) return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes)) def has_record(self, path): diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index d8bd208a1969f..bddde3a60aae0 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -25,6 +25,10 @@ def _get_stdlib_modules(): return stdlib3_8 if sys.version_info.minor == 9: return stdlib3_9 + if sys.version_info.minor >= 10: + return sys.stdlib_module_names # type: ignore[attr-defined] + elif sys.version_info.major > 3: + return sys.stdlib_module_names # type: ignore[attr-defined] raise RuntimeError(f"Unsupported Python version: {sys.version_info}") diff --git a/torch/package/analyze/find_first_use_of_broken_modules.py b/torch/package/analyze/find_first_use_of_broken_modules.py index 88553e3238c02..1910afdd98e34 100644 --- a/torch/package/analyze/find_first_use_of_broken_modules.py +++ b/torch/package/analyze/find_first_use_of_broken_modules.py @@ -2,6 +2,8 @@ from ..package_exporter import PackagingError +__all__ = ["find_first_use_of_broken_modules"] + def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]: """ diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py index 7ee4e8ca27f15..9f882fb33481e 100644 --- a/torch/package/analyze/trace_dependencies.py +++ b/torch/package/analyze/trace_dependencies.py @@ -1,6 +1,8 @@ import sys from typing import Any, Callable, Iterable, List, Tuple +__all__ = ["trace_dependencies"] + def trace_dependencies( callable: Callable[[Any], Any], inputs: Iterable[Tuple[Any, ...]] diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 526b782f2506e..6ea69173ed3f6 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -3,6 +3,8 @@ from .glob_group import GlobGroup, GlobPattern +__all__ = ["Directory"] + class Directory: """A file structure representation. Organized as Directory nodes that have lists of diff --git a/torch/package/importer.py b/torch/package/importer.py index 4893730d41223..20ced7c7a0302 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -10,6 +10,8 @@ from ._mangling import demangle, get_mangle_prefix, is_mangled +__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] + class ObjNotFoundError(Exception): """Raised when an importer cannot find an object by searching for its name.""" diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 25dfe3974a764..81b5e650b518b 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -37,6 +37,13 @@ from .glob_group import GlobGroup, GlobPattern from .importer import Importer, OrderedImporter, sys_importer +__all__ = [ + "PackagingErrorReason", + "EmptyMatchError", + "PackagingError", + "PackageExporter", +] + _gate_torchscript_serialization = True ActionHook = Callable[["PackageExporter", str], None] @@ -876,20 +883,18 @@ def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()): ) def _persistent_id(self, obj): - if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage): - if isinstance(obj, torch.storage._TypedStorage): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case untyped_storage = obj._storage storage_type_str = obj.pickle_storage_type() storage_type = getattr(torch, storage_type_str) - dtype = obj.dtype storage_numel = obj.size() - elif isinstance(obj, torch._UntypedStorage): + elif isinstance(obj, torch.UntypedStorage): untyped_storage = obj storage_type = normalize_storage_type(type(storage)) - dtype = torch.uint8 storage_numel = storage.nbytes() else: raise RuntimeError(f"storage type not recognized: {type(obj)}") diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 2dbd0ca67a11d..6efa943f11e7e 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -1,5 +1,6 @@ import builtins import importlib +import importlib.machinery import inspect import io import linecache @@ -233,9 +234,9 @@ def persistent_load(saved_id): ) storage = loaded_storages[key] # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - return torch.storage._TypedStorage( - wrap_storage=storage._untyped(), dtype=dtype + # stop wrapping with TypedStorage + return torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype ) elif typename == "reduce_package": # to fix BC breaking change, objects on this load path @@ -422,6 +423,7 @@ def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): def _do_find_and_load(self, name): path = None parent = name.rpartition(".")[0] + module_name_no_parent = name.rpartition(".")[-1] if parent: if parent not in self.modules: self._gcd_import(parent) @@ -429,11 +431,37 @@ def _do_find_and_load(self, name): if name in self.modules: return self.modules[name] parent_module = self.modules[parent] + try: path = parent_module.__path__ # type: ignore[attr-defined] + except AttributeError: - msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent) - raise ModuleNotFoundError(msg, name=name) from None + # when we attempt to import a package only containing pybinded files, + # the parent directory isn't always a package as defined by python, + # so we search if the package is actually there or not before calling the error. + if isinstance( + parent_module.__loader__, + importlib.machinery.ExtensionFileLoader, + ): + if name not in self.extern_modules: + msg = ( + _ERR_MSG + + "; {!r} is a c extension module which was not externed. C extension modules \ + need to be externed by the PackageExporter in order to be used as we do not support interning them.}." + ).format(name, name) + raise ModuleNotFoundError(msg, name=name) from None + if not isinstance( + parent_module.__dict__.get(module_name_no_parent), + types.ModuleType, + ): + msg = ( + _ERR_MSG + + "; {!r} is a c extension package which does not contain {!r}." + ).format(name, parent, name) + raise ModuleNotFoundError(msg, name=name) from None + else: + msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent) + raise ModuleNotFoundError(msg, name=name) from None module = self._load_module(name, parent) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index baa64588abc77..44635884554f6 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -1,5 +1,5 @@ r''' -PyTorch Profiler is a tool that allows the collecton of the performance metrics during the training and inference. +PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. Profiler's context manager API can be used to better understand what model operators are the most expensive, examine their input shapes and stack traces, study device kernel activity and visualize the execution trace. @@ -16,3 +16,5 @@ __all__ = ['profile', 'schedule', 'supported_activities', 'tensorboard_trace_handler', 'ProfilerAction', 'ProfilerActivity', 'kineto_available', 'DeviceType', 'record_function', 'ExecutionGraphObserver'] + +from . import itt diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py new file mode 100644 index 0000000000000..42d1ebf7c974a --- /dev/null +++ b/torch/profiler/_pattern_matcher.py @@ -0,0 +1,547 @@ +from collections import deque +import os +import re +from typing import Dict, List, Set + +import torch +from torch.profiler import profile +import torch.utils.benchmark as benchmark +from torch.profiler._utils import index_of_first_match +from torch._C._autograd import (_ProfilerEvent, _ExtraFields_TorchOp, + _ExtraFields_PyCCall, _ExtraFields_PyCall, + _EventType) + + +class Pattern: + ''' + Base class for all patterns, subclass this class and implement match() + to define custom patterns. + + In subclass, define description and skip property. + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + self.prof = prof + self.should_benchmark = should_benchmark + self.name = "Please specify a name for pattern" + self.description = "Please specify a description for pattern" + assert prof.profiler is not None and prof.profiler.kineto_results is not None + self.event_tree = prof.profiler.kineto_results.experimental_event_tree( + ) + self.tid_root: Dict[int, List[_ProfilerEvent]] = {} + for event in self.event_tree: + self.tid_root.setdefault(event.start_tid, []).append(event) + + @property + def skip(self): + return False + + def report(self, event: _ProfilerEvent): + msg = f"{self.description}\n[Source Code Location] {source_code_location(event)}" + return msg + + def eventTreeTraversal(self): + ''' + Traverse the event tree and yield all events. + Override this method in subclass to customize the traversal. + ''' + yield from eventTreeDFS(self.event_tree) + + def summary(self, events: List[_ProfilerEvent]): + default_summary = f"{self.name}: {len(events)} events matched." + if self.should_benchmark: + summary = self.benchmark_summary(events) + # If benchmark summary is not empty, use it. + return summary if summary else default_summary + return default_summary + + def benchmark_summary(self, events: List[_ProfilerEvent]): + return "" + + def match(self, event: _ProfilerEvent): + ''' + Return True if the event matches the pattern. + This method should be overriden in subclass. + ''' + raise NotImplementedError + + def matched_events(self): + if self.skip: + return [] + matched_events = [] + for event in self.eventTreeTraversal(): + if self.match(event): + matched_events.append(event) + return matched_events + + def root_of(self, event: _ProfilerEvent): + while event.parent: + event = event.parent + return event + + def siblings_of(self, event: _ProfilerEvent): + if event.parent: + children = event.parent.children + else: + children = self.tid_root[event.start_tid] + index = children.index(event) + return children[:index], children[index + 1:] + + def next_of(self, event: _ProfilerEvent): + _, next_events = self.siblings_of(event) + return next_events[0] if next_events else None + + def prev_of(self, event: _ProfilerEvent): + prev_events, _ = self.siblings_of(event) + return prev_events[-1] if prev_events else None + + def go_up_until(self, event: _ProfilerEvent, predicate): + if not event: + return None + while event.parent and not predicate(event): + event = event.parent + return event + + +# Patterns + + +class NamePattern(Pattern): + + def __init__(self, prof: profile, name: str, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.description = f"Matched Name Event: {name}" + self.name = name + + def match(self, event: _ProfilerEvent): + return re.search(self.name, event.name()) is not None + + +class ExtraCUDACopyPattern(Pattern): + ''' + This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU. + example: torch.zeros((100, 100)).to("cuda") + + Pattern: + build-in method |build-in method + ... | aten::to + aten::fill_/aten::zero_ | aten::_to_copy + + Algorithm: + We start at node aten::to, go parent events' previous events, + and check if we have a aten::fill_/aten::zero_ as we keep going down the tree. + We always select the last child in the children list when we go down the tree. + If at any step we failed, it is not a match. + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Extra CUDA Copy Pattern" + self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initalize it on GPU." + self.init_ops = { + "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_" + } + + @property + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event): + # TODO: We should also check tensor identities + if event.name() != "aten::to": + return False + # Up one level + event = event.parent + if event is None: + return False + # Check if we have a aten::fill_ in previous leaf + event = self.prev_of(event) + if event is None: + return False + while event.children: + event = event.children[-1] + # aten::zero_ is a special optimzation case where fill_ is not called + if event.name() in self.init_ops: + return True + return event.name() in self.init_ops + # TODO: Check if tensor is reused + + def benchmark(self, events: List[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event)[0]: 0.0 for event in events} + for shape in shapes_factor_map: + to_timer = benchmark.Timer(stmt='torch.ones(shape).to("cuda")', + globals={'shape': shape}) + de_timer = benchmark.Timer(stmt='torch.ones(shape, device="cuda")', + globals={'shape': shape}) + to_time = to_timer.timeit(10).mean + de_time = de_timer.timeit(10).mean + shapes_factor_map[shape] = de_time / to_time + return shapes_factor_map + + def benchmark_summary(self, events: List[_ProfilerEvent]): + shapes_factor_map = self.benchmark(events) + original_time = sum(event.duration_time_ns for event in events) / 1e3 + new_time = sum( + shapes_factor_map[input_shapes(event)[0]] * event.duration_time_ns + for event in events) / 1e3 + return ( + f"{self.name}: {len(events)} events matched. " + f"Total Estimated Speedup: {original_time - new_time}us ({original_time/new_time}X)" + ) + + +class ForLoopIndexingPattern(Pattern): + ''' + This pattern identifies if we use a for loop to index a tensor that + can be vectorized. + example: + tensor = torch.empty((100, 100)) + for i in range(100): + tensor[i] = i + + Pattern: + aten::select | ... | aten::select | ... (Repeat) + + Algorithm: + We start at node aten::select, and we check if we can find this alternating patterns. + We also keep a dictionary to avoid duplicate match in the for loop. + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "For Loop Indexing Pattern" + self.description = "For loop indexing detected. Vectorization recommended." + self.visited: Set[int] = set() + + def eventTreeTraversal(self): + ''' + We need to use BFS traversal order to avoid duplicate match. + ''' + yield from eventTreeBFS(self.event_tree) + + def match(self, event: _ProfilerEvent): + if event.name() != "aten::select": + return False + if event.id in self.visited: + return False + repeat_count = 1 + _, next = self.siblings_of(event) + if len(next) <= 1: + return False + + # Custom event list matching + def same_ops(list1, list2): + if len(list1) != len(list2): + return False + for op1, op2 in zip(list1, list2): + if op1.name() != op2.name(): + return False + return True + + # Record the ops between two aten::select + next_select_idx = index_of_first_match( + next, lambda e: e.name() == "aten::select") + if next_select_idx is None: + return False + indexing_ops = [event] + next[:next_select_idx] + next = next[len(indexing_ops) - 1:] + for i in range(0, len(next), len(indexing_ops)): + if same_ops(indexing_ops, next[i:i + len(indexing_ops)]): + repeat_count += 1 + self.visited.add(next[i].id) + else: + break + return repeat_count >= 10 + + +class FP32MatMulPattern(Pattern): + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "FP32 MatMul Pattern" + self.description = ( + "You are currently using GPU that supports TF32. " + "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'" + ) + + @property + def skip(self): + # Anything less than sm_80 is not Ampere which doesn't support TF32 + has_tf32 = all( + int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) + return has_tf32 is False or super().skip or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): + # If we saw this pattern once, we don't need to match it again + if event.tag != _EventType.TorchOp: + return False + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + if event.name() == "aten::mm": + if event.extra_fields.allow_tf32_cublas is False: + return True + return False + + def report(self, event: _ProfilerEvent): + return self.description + + def benchmark(self, events: List[_ProfilerEvent]): + shapes_factor_map = {input_shapes(event): 0.0 for event in events} + for shape in shapes_factor_map: + matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) + matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32) + fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)', + globals={ + "matrixA": matrixA, + "matrixB": matrixB + }) + tf32_timer = benchmark.Timer( + stmt='torch.mm(matrixA, matrixB)', + setup='torch.backends.cuda.matmul.allow_tf32 = True', + globals={ + "matrixA": matrixA, + "matrixB": matrixB + }) + torch.backends.cuda.matmul.allow_tf32 = False + fp32_time = fp32_timer.timeit(10).mean + tf32_time = tf32_timer.timeit(10).mean + shapes_factor_map[shape] = tf32_time / fp32_time + return shapes_factor_map + + def benchmark_summary(self, events: List[_ProfilerEvent]): + shapes_factor_map = self.benchmark(events) + original_time = sum(event.duration_time_ns for event in events) / 1e3 + new_time = sum( + shapes_factor_map[input_shapes(event)] * event.duration_time_ns + for event in events) / 1e3 + return ( + f"{self.name}: {len(events)} events matched. " + f"Total Estimated Speedup: {original_time - new_time}us ({original_time/new_time}X)" + ) + + +class OptimizerSingleTensorPattern(Pattern): + ''' + This pattern identifies if we are using the single-tensor version of an optimizer. + example: + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when + the kernels are relatively small. + + Pattern: + XXXXX: _single_tenser_ + + Algorithm: + String match + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Optimizer Single Tensor Pattern" + self.optimizers_with_foreach = ["adam", "sgd", "adamw"] + self.description = ( + "Deteced optimizer running with single tensor implementation. " + "Please enable multi tensor implementation by passing 'foreach=True' into optimizer." + ) + + def match(self, event: _ProfilerEvent): + for optimizer in self.optimizers_with_foreach: + if event.name().endswith(f"_single_tensor_{optimizer}"): + return True + return False + + +class SynchronizedDataLoaderPattern(Pattern): + ''' + This pattern identifies if we are using num_workers=0 in DataLoader. + example: + torch.utils.data.DataLoader(dataset, batch_size=batch_size) + Add num_workers=N to the arguments. N depends on system configuration. + + Pattern: + dataloader.py(...): __iter__ + dataloader.py(...): _get_iterator + NOT dataloader.py(...): check_worker_number_rationality + + Algorithm: + If we don't see check_worker_number_rationality call in the dataloader __iter__, + It is not an asynchronous dataloader. + + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Synchronized DataLoader Pattern" + self.description = ( + "Detected DataLoader running with synchronized implementation. " + "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader." + ) + + def match(self, event: _ProfilerEvent): + def is_dataloader_function(name: str, function_name: str): + return name.startswith(os.path.join("torch", "utils", "data", "dataloader.py")) and name.endswith(function_name) + if not is_dataloader_function(event.name(), "__iter__"): + return False + if not event.children: + return False + event = event.children[0] + if not is_dataloader_function(event.name(), "_get_iterator"): + return False + if not event.children: + return False + event = event.children[0] + return not is_dataloader_function(event.name(), "check_worker_number_rationality") + # TODO: We should also check if the loader is bottleneck. + + +class GradNotSetToNonePattern(Pattern): + ''' + This pattern identifies if we are not setting grad to None in zero_grad. + example: + optimizer.zero_grad() + By setting set_to_none=True, we can gain speedup + + Pattern: + XXXXX: _zero_grad + NOT aten::zeros + aten::zero_ + + aten::zero_ is called on each parameter in the model. + We also want to make sure it is not called by aten::zeros. + + Algorithm: + String match + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Gradient Set To Zero Instead of None Pattern" + self.description = ( + "Detected gradient set to zero instead of None. " + "Please add 'set_to_none=True' when calling zero_grad().") + + def match(self, event: _ProfilerEvent): + if not event.name().endswith(": zero_grad"): + return False + if not event.children: + return False + + for sub_event in eventTreeDFS(event.children): + if sub_event.name( + ) == "aten::zero_" and sub_event.parent.name() != "aten::zeros": + return True + # TODO: We should also check if the optimizer's numerical behavior will change. + return False + + +class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): + ''' + This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d. + Bias doesn't do anything when followed by batchnorm. + Pattern: + nn.Module: Conv2d | nn.Module: BatchNorm2d + ... + aten::conv2d AND dtype of third argument is not null + The third argument is the bias + Algorithm: + String match + ''' + + def __init__(self, prof: profile, should_benchmark: bool = False): + super().__init__(prof, should_benchmark) + self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern" + self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." + + @property + def skip(self): + return self.prof.record_shapes is False or super().skip + + def match(self, event: _ProfilerEvent): + if event.name() != "aten::conv2d": + return False + if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] == "": + return False + # This means bias=True + event = self.go_up_until( + event, lambda e: e.name().startswith("nn.Module: Conv2d")) + if not event: + return False + event = self.next_of(event) + if not event: + return False + return event.name().startswith("nn.Module: BatchNorm2d") + + +def source_code_location(event: _ProfilerEvent): + while event: + if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: + assert isinstance(event.extra_fields, + _ExtraFields_PyCall) or isinstance( + event.extra_fields, _ExtraFields_PyCCall) + if not event.extra_fields.caller.file_name.startswith("torch" + os.sep): + return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" + event = event.parent + return "No source code location found" + + +def input_shapes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple([tuple(shape) for shape in event.extra_fields.inputs.shapes]) + + +def input_dtypes(event: _ProfilerEvent): + assert isinstance(event.extra_fields, _ExtraFields_TorchOp) + return tuple(t for t in event.extra_fields.inputs.dtypes) + + +def eventTreeDFS(event_tree: List[_ProfilerEvent]): + ''' + Standard DFS traversal of the event tree. + ''' + stack = deque(event_tree) + while stack: + curr_event = stack.pop() + yield curr_event + for child_event in curr_event.children: + stack.append(child_event) + + +def eventTreeBFS(event_tree: List[_ProfilerEvent]): + ''' + Standard BFS traversal of the event tree. + ''' + stack = deque(event_tree) + while stack: + curr_event = stack.popleft() + yield curr_event + for child_event in curr_event.children: + stack.append(child_event) + + +def report_all_anti_patterns(prof, should_benchmark: bool = False): + anti_patterns = [ + ExtraCUDACopyPattern(prof, should_benchmark), + ForLoopIndexingPattern(prof, should_benchmark), + FP32MatMulPattern(prof, should_benchmark), + OptimizerSingleTensorPattern(prof, should_benchmark), + SynchronizedDataLoaderPattern(prof, should_benchmark), + GradNotSetToNonePattern(prof, should_benchmark), + Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark) + ] + reported = set() + summaries = [] + message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"] + message_list.append("Matched Events:") + for anti_pattern in anti_patterns: + matched_events = anti_pattern.matched_events() + if not matched_events: + continue + summaries.append(anti_pattern.summary(matched_events)) + for event in matched_events: + report_msg = anti_pattern.report(event) + if report_msg not in reported: + message_list.append(report_msg) + reported.add(report_msg) + message_list.append("Summary:") + message_list += summaries + message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}") + print("\n".join(message_list)) diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 83c37c8f0687d..a730fc1768887 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -1,5 +1,6 @@ from collections import deque from dataclasses import dataclass +import re from typing import Dict, List from torch.profiler import DeviceType @@ -12,6 +13,7 @@ class EventMetrics: duration_time_ns: int = 0 self_time_ns: int = 0 idle_time_ns: int = 0 + queue_depth: int = 0 @property def fraction_idle_time(self): @@ -39,20 +41,38 @@ def __eq__(self, other): return self.event.id == other.event.id def __repr__(self): - return f"<{self.event.name()} id={self.event.correlation_id}>" + return f"{self.event.name()}" def intervals_overlap(self, intervals: List[Interval]): overlap_time = 0 intervals = sorted(intervals, key=lambda x: x.start) - for i, interval in enumerate(intervals): - if i + 1 < len(intervals): - assert interval.end <= intervals[ - i + 1].start, "Intervals must be disjoint" - overlap_start = max(self.event.start_time_ns, interval.start) - overlap_end = min(self.event.end_time_ns, interval.end) + + if intervals: + overlap_start = max(self.event.start_time_ns, intervals[0].start) + overlap_end = min(self.event.end_time_ns, intervals[0].end) if overlap_start < overlap_end: overlap_time += overlap_end - overlap_start + + i, j = 0, 1 + while (j < len(intervals)): + prev_interval = intervals[i] + curr_interval = intervals[j] + j += 1 + if prev_interval.end > curr_interval.start: + # Completely subsumed by previous interval + if prev_interval.end > curr_interval.end: + j += 1 + continue + else: + curr_interval.start = prev_interval.end + i = j + + overlap_start = max(self.event.start_time_ns, curr_interval.start) + overlap_end = min(self.event.end_time_ns, curr_interval.end) + if overlap_start < overlap_end: + overlap_time += overlap_end - overlap_start + return overlap_time @@ -62,6 +82,12 @@ def __init__(self, prof: profile): self.profile = prof self.metrics: Dict[EventKey, EventMetrics] = {} self.compute_self_time() + self.event_keys = sorted((e for e in self.metrics.keys()), + key=lambda x: x.event.start_time_ns) + self.events = [e.event for e in self.event_keys] + self.cuda_events: List[_KinetoEvent] = [] + self.queue_depth_list = self.compute_queue_depth() + self.compute_idle_time() def compute_self_time(self): ''' @@ -77,7 +103,6 @@ def compute_self_time(self): for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) - assert EventKey( curr_event ) not in self.metrics, f"Duplicate id: {curr_event.id}, {curr_event.name()}" @@ -88,9 +113,9 @@ def compute_self_time(self): def compute_queue_depth(self): ''' - Computes event's idle time. Idle time is defined as the time when the CUDA kernel queue depth is 0. - It also return a Time series of the queue depth data. - qd = cuda kernel queue depth + Computes queue_depth at each event. This will calculate the queue depth data for + All the events in the tree. + This will return a list of Interval of queue depth data of cuda launch and kernels. ''' assert (self.profile.kineto_results is not None) cuda_event_list = self.profile.kineto_results.events() @@ -104,10 +129,6 @@ def is_cuda_kernel(e): return e.device_type() == DeviceType.CUDA and "mem" not in e.name( ).lower() - # Record All the idle intervals - idle_interval: List[Interval] = [] - queue_depth_list: List[Interval] = [] - cuda_launch_events = sorted( (e for e in cuda_event_list if is_cuda_launch_kernel(e)), key=lambda x: x.start_us()) @@ -115,6 +136,9 @@ def is_cuda_kernel(e): (e for e in cuda_event_list if is_cuda_kernel(e)), key=lambda x: x.start_us()) + self.cuda_events = sorted(cuda_launch_events + cuda_kernel_events, + key=lambda x: x.start_us()) + kernel_mapping: Dict[_KinetoEvent, int] = {} last_mapped_kernel = 0 for cuda_launch_event in cuda_launch_events: @@ -126,36 +150,162 @@ def is_cuda_kernel(e): kernel_mapping[cuda_launch_event] = index last_mapped_kernel = index if index is not None else last_mapped_kernel - current_kernel_index = -1 - spawned_kernel_index = None - for cuda_launch_event in cuda_launch_events: + current_kernel_index = 0 + spawned_kernel_index = -1 + + all_events = cuda_launch_events + cuda_kernel_events + self.events + + def new_old_event_comparator(event): + if hasattr(event, "start_us"): + return event.start_us() * 1000 + if hasattr(event, "start_time_ns"): + return event.start_time_ns + raise Exception("Unknown Event Type") + + queue_depth_list: List[Interval] = [] + all_events.sort(key=new_old_event_comparator) + for event in all_events: # Find latest cuda kernel event - while (current_kernel_index + 1 < len(cuda_kernel_events) and - cuda_kernel_events[current_kernel_index + 1].start_us() + - cuda_kernel_events[current_kernel_index + 1].duration_us() < - cuda_launch_event.start_us() + - cuda_launch_event.duration_us()): + if hasattr(event, "start_us"): + start_time = event.start_us() * 1000 + end_time = (event.start_us() + event.duration_us()) * 1000 + # Find current spawned cuda kernel event + if event in kernel_mapping and kernel_mapping[ + event] is not None: + spawned_kernel_index = kernel_mapping[event] + elif hasattr(event, "start_time_ns"): + start_time = event.start_time_ns # type: ignore[attr-defined] + end_time = event.end_time_ns # type: ignore[attr-defined] + + while (current_kernel_index < len(cuda_kernel_events) and + (cuda_kernel_events[current_kernel_index].start_us()) * 1000 + <= start_time): current_kernel_index += 1 + current_queue_depth = spawned_kernel_index - current_kernel_index + 1 + current_queue_depth = max(current_queue_depth, 0) - # Find current spawned cuda kernel event - spawned_kernel_index = kernel_mapping[cuda_launch_event] - if spawned_kernel_index is None: - current_queue_depth = 0 - else: - current_queue_depth = spawned_kernel_index - current_kernel_index + if hasattr(event, "start_us"): + queue_depth_list.append( + Interval(start_time, end_time, current_queue_depth)) + elif hasattr(event, "start_time_ns"): + self.metrics[EventKey(event)].queue_depth = current_queue_depth - queue_depth_list.append( - Interval( - cuda_launch_event.start_us(), - cuda_launch_event.start_us() + - cuda_launch_event.duration_us(), current_queue_depth)) + return queue_depth_list + + def compute_idle_time(self): + ''' + Computes idle time of the profile. + ''' + # Based on queue_depth_list, we can calculate idle time for all the events + idle = False + idle_start = 0 + idle_intervals: List[Interval] = [] + if self.queue_depth_list and self.events: + idle_intervals += [ + Interval(self.events[0].start_time_ns, + self.queue_depth_list[0].start), + Interval(self.queue_depth_list[-1].end, + self.events[-1].end_time_ns) + ] + + for data_point in self.queue_depth_list: + if data_point.queue_depth == 0 and not idle: + idle_start = data_point.end + idle = True + if data_point.queue_depth > 0 and idle: + idle_intervals.append(Interval(idle_start, data_point.start)) + idle = False event_list = [e.event for e in self.metrics.keys()] for event in event_list: self.metrics[EventKey(event)].idle_time_ns = EventKey( - event).intervals_overlap(idle_interval) + event).intervals_overlap(idle_intervals) - return queue_depth_list + def rank_events(self, length): + ''' + Filter and Rank the events based on some heuristics: + 1) Events that are in the falling phase of the queue depth. + 2) Events that have a high idle_time, self_time difference. + + Parameters: + length: The number of events to return. + ''' + + # Find the interval when qd is falling to 0 + import torch + queue_depth_list = list(reversed(self.queue_depth_list)) + qd_values = [e.queue_depth for e in queue_depth_list] + + bottom_threashold = 0 + top_threashold = 4 + decrease_interval = [] + i = 0 + while (i < len(qd_values)): + if qd_values[i] > bottom_threashold: + i += 1 + continue + for j in range(i + 1, len(qd_values)): + # Find next zero and if the max value between them exceeds + # the threshold, then we have a falling interval + next_minimum_idx = index_of_first_match( + qd_values, lambda x: x <= bottom_threashold, start=j) + peak_idx = argmax(qd_values, start=j, end=next_minimum_idx) + + # if is a valid peak, we add to list and continue + if peak_idx is not None and qd_values[ + peak_idx] >= top_threashold: + decrease_interval.append( + Interval(queue_depth_list[peak_idx].start, + queue_depth_list[i].start)) + i = next_minimum_idx if next_minimum_idx is not None else i + break + i += 1 + # Filter out events that are not in the decrease interval + event_list = [ + event for event in self.metrics.keys() + if event.intervals_overlap(decrease_interval) + ] + if event_list: + self_time = torch.tensor( + [self.metrics[event].self_time_ns for event in event_list], + dtype=torch.float32) + idle_time = torch.tensor([ + self.metrics[event].fraction_idle_time for event in event_list + ], dtype=torch.float32) + normalized_gain = (idle_time - + torch.mean(idle_time)) / torch.std(idle_time) + normalized_self = (self_time - + torch.mean(self_time)) / torch.std(self_time) + heuristic_score_list = normalized_gain + 0.6 * normalized_self + + # Sort events by heuristic + event_list = [ + event + for _, event in sorted(zip(heuristic_score_list, event_list), + key=lambda x: x[0], + reverse=True) + ] + event_list = event_list[:length] + return event_list + + def get_optimizable_events(self, + length: int = 1, + print_enable: bool = True): + event_list = self.rank_events(length) + if not print_enable: + return event_list + output = "Optimizable events:\n" if event_list else "No events to optimize\n" + + output += "\n".join([ + f"""{'-'*80} +Event: {event} +Source code location: {source_code_location(event.event)} +Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% +{'-'*80}""" for event in event_list + ]) + if print_enable: + print(output) + return event_list def index_of_first_match(seq, predicate, start=0, end=None): @@ -165,3 +315,20 @@ def index_of_first_match(seq, predicate, start=0, end=None): if predicate(seq[i]): return i return None + + +def argmax(seq, key=lambda x: x, start=0, end=None): + seq = seq[start:end] + if len(seq) == 0: + return None + return seq.index(max(seq, key=key)) + start + + +def source_code_location(event): + while (event is not None): + match = re.search(r"\.py\(.*\)", event.name()) + if (match is None): + event = event.parent + continue + return event.name() + return "No source code location found" diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py new file mode 100644 index 0000000000000..022af77c04090 --- /dev/null +++ b/torch/profiler/itt.py @@ -0,0 +1,56 @@ +from contextlib import contextmanager + +try: + from torch._C import _itt +except ImportError: + class _ITTStub(object): + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError("ITT functions not installed. Are you sure you have a ITT build?") + + rangePush = _fail + rangePop = _fail + mark = _fail + + _itt = _ITTStub() # type: ignore[assignment] + + +__all__ = ['range_push', 'range_pop', 'mark', 'range'] + + +def range_push(msg): + """ + Arguments: + msg (str): ASCII message to associate with range + """ + return _itt.rangePush(msg) + + +def range_pop(): + """ + """ + return _itt.rangePop() + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + Arguments: + msg (str): ASCII message to associate with the event. + """ + return _itt.mark(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an ITT range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (str): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + yield + range_pop() diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 927cd2cdb6cb6..de1eb5b216ec7 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -18,6 +18,8 @@ ) from torch.autograd import ProfilerActivity, kineto_available +__all__ = ['supported_activities', 'ProfilerAction', 'schedule', 'tensorboard_trace_handler', 'profile', + 'ExecutionGraphObserver'] def supported_activities(): """ @@ -56,7 +58,7 @@ class _KinetoProfile(object): used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed. .. note:: - This API is an experimental and subject to change in future. + This API is experimental and subject to change in the future. Enabling shape and stack tracing results in additional overhead. When record_shapes=True is specified, profiler will temporarily hold references to the tensors; @@ -279,9 +281,9 @@ class profile(_KinetoProfile): activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA. - schedule (callable): callable that takes step (int) as a single parameter and returns + schedule (Callable): callable that takes step (int) as a single parameter and returns ``ProfilerAction`` value that specifies the profiler action to perform at each step. - on_trace_ready (callable): callable that is called at each step when ``schedule`` + on_trace_ready (Callable): callable that is called at each step when ``schedule`` returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. record_shapes (bool): save information about operator's input shapes. profile_memory (bool): track tensor memory allocation/deallocation. diff --git a/torch/serialization.py b/torch/serialization.py index b4a3bad4d3f6e..6e9a8a4c84661 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -157,7 +157,7 @@ def _cuda_deserialize(obj, location): device = validate_cuda_device(location) if getattr(obj, "_torch_load_uninitialized", False): with torch.cuda.device(device): - return torch._UntypedStorage(obj.nbytes(), device=torch.device(location)) + return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) else: return obj.cuda(device) @@ -171,7 +171,7 @@ def _mps_deserialize(obj, location): register_package(21, _mps_tag, _mps_deserialize) -def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]): +def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]): for _, tagger, _ in _package_registry: location = tagger(storage) if location: @@ -383,12 +383,13 @@ def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], """ _check_dill_version(pickle_module) - with _open_file_like(f, 'wb') as opened_file: - if _use_new_zipfile_serialization: - with _open_zipfile_writer(opened_file) as opened_zipfile: - _save(obj, opened_zipfile, pickle_module, pickle_protocol) - return - _legacy_save(obj, opened_file, pickle_module, pickle_protocol) + if _use_new_zipfile_serialization: + with _open_zipfile_writer(f) as opened_zipfile: + _save(obj, opened_zipfile, pickle_module, pickle_protocol) + return + else: + with _open_file_like(f, 'wb') as opened_file: + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: @@ -422,10 +423,10 @@ def persistent_id(obj: Any) -> Optional[Tuple]: "for correctness upon loading.") return ('module', obj, source_file, source) - if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj): - storage: torch._UntypedStorage + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + storage: torch.UntypedStorage - if isinstance(obj, torch.storage._TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted storage = obj._storage @@ -435,7 +436,7 @@ def persistent_id(obj: Any) -> Optional[Tuple]: dtype = obj.dtype storage_numel = obj.size() - elif isinstance(obj, torch._UntypedStorage): + elif isinstance(obj, torch.UntypedStorage): storage = obj storage_dtype = torch.uint8 storage_type = normalize_storage_type(type(obj)) @@ -475,8 +476,8 @@ def persistent_id(obj: Any) -> Optional[Tuple]: # effectively saving nbytes in this case. We'll be able to load it # and the tensor back up with no problems in _this_ and future # versions of pytorch, but in older versions, here's the problem: - # the storage will be loaded up as a _UntypedStorage, and then the - # FloatTensor will loaded and the _UntypedStorage will be assigned to + # the storage will be loaded up as a UntypedStorage, and then the + # FloatTensor will loaded and the UntypedStorage will be assigned to # it. Since the storage dtype does not match the tensor dtype, this # will cause an error. If we reverse the list, like `[tensor, # storage]`, then we will save the `tensor.storage()` as a faked @@ -484,7 +485,7 @@ def persistent_id(obj: Any) -> Optional[Tuple]: # dtype-specific numel count that old versions expect. `tensor` # will be able to load up properly in old versions, pointing to # a FloatStorage. However, `storage` is still being translated to - # a _UntypedStorage, and it will try to resolve to the same + # a UntypedStorage, and it will try to resolve to the same # FloatStorage that `tensor` contains. This will also cause an # error. It doesn't seem like there's any way around this. # Probably, we just cannot maintain FC for the legacy format if the @@ -551,9 +552,9 @@ def persistent_id(obj): # see # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 - if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): - if isinstance(obj, torch.storage._TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted storage = obj._storage @@ -816,11 +817,11 @@ def persistent_load(saved_id): args = pickle_module.load(f, **pickle_load_args) key, location, storage_type = args dtype = storage_type.dtype - obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) + obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) obj = restore_location(obj, location) # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - deserialized_objects[key] = torch.storage._TypedStorage( + # stop wrapping with TypedStorage + deserialized_objects[key] = torch.storage.TypedStorage( wrap_storage=obj, dtype=dtype) @@ -830,8 +831,8 @@ def persistent_load(saved_id): element_size = torch._utils._element_size(root.dtype) offset_bytes = offset * element_size # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - deserialized_objects[target_cdata] = torch.storage._TypedStorage( + # stop wrapping with TypedStorage + deserialized_objects[target_cdata] = torch.storage.TypedStorage( wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size], dtype=root.dtype) @@ -878,11 +879,11 @@ def persistent_load(saved_id): nbytes = numel * torch._utils._element_size(dtype) if root_key not in deserialized_objects: - obj = cast(Storage, torch._UntypedStorage(nbytes)) + obj = cast(Storage, torch.UntypedStorage(nbytes)) obj._torch_load_uninitialized = True # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - deserialized_objects[root_key] = torch.storage._TypedStorage( + # stop wrapping with TypedStorage + deserialized_objects[root_key] = torch.storage.TypedStorage( wrap_storage=restore_location(obj, location), dtype=dtype) @@ -893,8 +894,8 @@ def persistent_load(saved_id): view_size_bytes = view_size * torch._utils._element_size(dtype) if view_key not in deserialized_objects: # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - deserialized_objects[view_key] = torch.storage._TypedStorage( + # stop wrapping with TypedStorage + deserialized_objects[view_key] = torch.storage.TypedStorage( wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes], dtype=dtype) res = deserialized_objects[view_key] @@ -1004,10 +1005,10 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl def load_tensor(dtype, numel, key, location): name = f'data/{key}' - storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped() + storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped() # TODO: Once we decide to break serialization FC, we can - # stop wrapping with _TypedStorage - loaded_storages[key] = torch.storage._TypedStorage( + # stop wrapping with TypedStorage + loaded_storages[key] = torch.storage.TypedStorage( wrap_storage=restore_location(storage, location), dtype=dtype) @@ -1019,7 +1020,7 @@ def persistent_load(saved_id): assert typename == 'storage', \ f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" storage_type, key, location, numel = data - if storage_type is torch._UntypedStorage: + if storage_type is torch.UntypedStorage: dtype = torch.uint8 else: dtype = storage_type.dtype diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 29cd5510a74ba..64b396e2b452f 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -262,3 +262,97 @@ def sum(input: Tensor, dim: DimOrDims = None, performed. This is useful for preventing data type overflows. Default: None """) + + +spdiags = _add_docstr( + _sparse._spdiags, + r""" +sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor + +Creates a sparse 2D tensor by placing the values from rows of +:attr:`diagonals` along specified diagonals of the output + +The :attr:`offsets` tensor controls which diagonals are set. + +- If :attr:`offsets[i]` = 0, it is the main diagonal +- If :attr:`offsets[i]` < 0, it is below the main diagonal +- If :attr:`offsets[i]` > 0, it is above the main diagonal + +The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`, +and an offset may not be repeated. + +Args: + diagonals (Tensor): Matrix storing diagonals row-wise + offsets (Tensor): The diagonals to be set, stored as a vector + shape (2-tuple of ints): The desired shape of the result +Keyword args: + layout (:class:`torch.layout`, optional): The desired layout of the + returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr`` + are supported. Default: ``torch.sparse_coo`` + +Examples: + +Set the main and first two lower diagonals of a matrix:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3)) + >>> s + tensor(indices=tensor([[0, 1, 2, 1, 2, 2], + [0, 1, 2, 0, 1, 0]]), + values=tensor([0, 1, 2, 3, 4, 6]), + size=(3, 3), nnz=6, layout=torch.sparse_coo) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + + +Change the output layout:: + + >>> diags = torch.arange(9).reshape(3, 3) + >>> diags + tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8]) + >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr) + >>> s + tensor(crow_indices=tensor([0, 1, 3, 6]), + col_indices=tensor([0, 0, 1, 0, 1, 2]), + values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6, + layout=torch.sparse_csr) + >>> s.to_dense() + tensor([[0, 0, 0], + [3, 1, 0], + [6, 4, 2]]) + +Set partial diagonals of a large output:: + + >>> diags = torch.tensor([[1, 2], [3, 4]]) + >>> offsets = torch.tensor([0, -1]) + >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense() + tensor([[1, 0, 0, 0, 0], + [3, 2, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) + +.. note:: + + When setting the values along a given diagonal the index into the diagonal + and the index into the row of :attr:`diagonals` is taken as the + column index in the output. This has the effect that when setting a diagonal + with a positive offset `k` the first value along that diagonal will be + the value in position `k` of the row of :attr:`diagonals` + +Specifying a positive offset:: + + >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense() + tensor([[1, 2, 3, 0, 0], + [0, 2, 3, 0, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]) +""") diff --git a/torch/special/__init__.py b/torch/special/__init__.py index bd71945c95fe1..224e262c1ef60 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -3,6 +3,7 @@ from torch._torch_docs import common_args, multi_dim_common __all__ = [ + 'airy_ai', 'bessel_j0', 'bessel_j1', 'bessel_y0', @@ -50,8 +51,11 @@ 'shifted_chebyshev_polynomial_u', 'shifted_chebyshev_polynomial_v', 'shifted_chebyshev_polynomial_w', + 'scaled_modified_bessel_k0', + 'scaled_modified_bessel_k1', 'sinc', 'softmax', + 'spherical_bessel_j0', 'xlog1py', 'xlogy', 'zeta', @@ -871,6 +875,20 @@ """.format(**common_args)) +airy_ai = _add_docstr(_special.special_airy_ai, + r""" +airy_ai(input, *, out=None) -> Tensor + +Airy function :math:`\text{Ai}\left(\text{input}\right)`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + bessel_j0 = _add_docstr(_special.special_bessel_j0, r""" bessel_j0(input, *, out=None) -> Tensor @@ -1162,6 +1180,34 @@ {out} """.format(**common_args)) +scaled_modified_bessel_k0 = _add_docstr(_special.special_scaled_modified_bessel_k0, + r""" +scaled_modified_bessel_k0(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + +scaled_modified_bessel_k1 = _add_docstr(_special.special_scaled_modified_bessel_k1, + r""" +scaled_modified_bessel_k1(input, *, out=None) -> Tensor + +Scaled modified Bessel function of the second kind of order :math:`1`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) + shifted_chebyshev_polynomial_t = _add_docstr(_special.special_shifted_chebyshev_polynomial_t, r""" shifted_chebyshev_polynomial_t(input, n, *, out=None) -> Tensor @@ -1221,3 +1267,17 @@ Keyword args: {out} """.format(**common_args)) + +spherical_bessel_j0 = _add_docstr(_special.special_spherical_bessel_j0, + r""" +spherical_bessel_j0(input, *, out=None) -> Tensor + +Spherical Bessel function of the first kind of order :math:`0`. + +""" + r""" +Args: + {input} + +Keyword args: + {out} +""".format(**common_args)) diff --git a/torch/storage.py b/torch/storage.py index a6bef2074c80c..8e35973405b1b 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -13,10 +13,9 @@ except ModuleNotFoundError: np = None # type: ignore[assignment] -T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]') +T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]') class _StorageBase(object): _cdata: Any - is_cuda: bool = False is_sparse: bool = False is_sparse_csr: bool = False device: torch.device @@ -65,6 +64,12 @@ def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704 def _shared_incref(self, *args, **kwargs): ... # noqa: E704 @classmethod def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704 + @property + def is_cuda(self): ... # noqa: E704 + @classmethod + def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704 + @classmethod + def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704 def __str__(self): info_str = ( @@ -112,14 +117,14 @@ def tolist(self): def cpu(self): """Returns a CPU copy of this storage if it's not already on the CPU""" if self.device.type != 'cpu': - return torch._UntypedStorage(self.size()).copy_(self, False) + return torch.UntypedStorage(self.size()).copy_(self, False) else: return self def mps(self): """Returns a CPU copy of this storage if it's not already on the CPU""" if self.device.type != 'mps': - return torch._UntypedStorage(self.size(), device="mps").copy_(self, False) + return torch.UntypedStorage(self.size(), device="mps").copy_(self, False) else: return self @@ -217,16 +222,19 @@ def _new_shared(cls, size, *, device='cpu'): else: return cls._new_using_fd_cpu(size) - def _untyped(self): + def untyped(self): return self -class _UntypedStorage(torch._C.StorageBase, _StorageBase): +class UntypedStorage(torch._C.StorageBase, _StorageBase): def __getitem__(self, *args, **kwargs): if self.device.type == 'meta': raise NotImplementedError("Not available for 'meta' device type") return super().__getitem__(*args, **kwargs) + @property + def is_cuda(self): + return self.device.type == 'cuda' def _load_from_bytes(b): return torch.load(io.BytesIO(b)) @@ -240,9 +248,9 @@ def _load_from_bytes(b): def _dtype_to_storage_type_map(): # NOTE: We should no longer add dtypes to this map. This map # is only used for BC/FC with older PyTorch versions. Going forward, - # new dtypes of _TypedStorage should not translate to a legacy - # Storage class. Instead, new dtypes of _TypedStorage should - # be serialized as an _UntypedStorage paired with a torch.dtype + # new dtypes of TypedStorage should not translate to a legacy + # Storage class. Instead, new dtypes of TypedStorage should + # be serialized as an UntypedStorage paired with a torch.dtype return { torch.double: 'DoubleStorage', torch.float: 'FloatStorage', @@ -289,7 +297,7 @@ def _get_storage_from_sequence(sequence, dtype, device): dtype=dtype, device=device) - return tmp_tensor.storage()._untyped() + return tmp_tensor.storage().untyped() def _isint(x): if HAS_NUMPY: @@ -297,7 +305,7 @@ def _isint(x): else: return isinstance(x, int) -class _TypedStorage: +class TypedStorage: is_sparse = False dtype: torch.dtype @@ -310,7 +318,7 @@ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): if cls == torch.storage._LegacyStorage: raise RuntimeError("Only child classes of _LegacyStorage can be instantiated") - if cls == _TypedStorage: + if cls == TypedStorage: return super().__new__(cls) else: @@ -320,7 +328,7 @@ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): ' * no arguments\n' ' * (int size)\n' ' * (Sequence data)\n' - ' * (*, _UntypedStorage wrap_storage)') + ' * (*, UntypedStorage wrap_storage)') if device is not None: raise RuntimeError( @@ -343,10 +351,10 @@ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): arg_error_msg + f"\nArgument type not recognized: {type(args[0])}") - return _TypedStorage( + return TypedStorage( *args, dtype=cls.dtype, - device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu') + device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu') else: if len(args) != 0: @@ -355,10 +363,10 @@ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): "\nNo positional arguments should be given when using " "'wrap_storage'") - if not isinstance(wrap_storage, torch._UntypedStorage): + if not isinstance(wrap_storage, torch.UntypedStorage): raise TypeError( arg_error_msg + - f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}") + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}") cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu' @@ -368,19 +376,19 @@ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): f"\nDevice of 'wrap_storage' must be {cls_device}" f", but got {wrap_storage.device.type}") - return _TypedStorage( + return TypedStorage( *args, wrap_storage=wrap_storage, dtype=cls.dtype) def __init__(self, *args, device=None, dtype=None, wrap_storage=None): arg_error_msg = ( - '_TypedStorage.__init__ received an invalid combination ' + 'TypedStorage.__init__ received an invalid combination ' 'of arguments. Expected one of:\n' ' * (*, torch.device device, torch.dtype dtype)\n' ' * (int size, *, torch.device device, torch.dtype dtype)\n' ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n' - ' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)') + ' * (*, UntypedStorage wrap_storage, torch.dtype dtype)') if wrap_storage is not None: if len(args) != 0: @@ -406,10 +414,10 @@ def __init__(self, *args, device=None, dtype=None, wrap_storage=None): self.dtype = dtype - if not isinstance(wrap_storage, torch._UntypedStorage): + if not isinstance(wrap_storage, torch.UntypedStorage): raise TypeError( arg_error_msg + - f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}") + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}") self._storage = wrap_storage @@ -422,11 +430,11 @@ def __init__(self, *args, device=None, dtype=None, wrap_storage=None): raise RuntimeError("Cannot create CUDA storage with quantized dtype") if len(args) == 0: - self._storage = torch._UntypedStorage(device=device) + self._storage = torch.UntypedStorage(device=device) elif len(args) == 1: if _isint(args[0]): - self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device) + self._storage = torch.UntypedStorage(int(args[0]) * self.element_size(), device=device) elif isinstance(args[0], collections.abc.Sequence): self._storage = _get_storage_from_sequence(args[0], self.dtype, device) else: @@ -442,16 +450,17 @@ def __init__(self, *args, device=None, dtype=None, wrap_storage=None): @property def is_cuda(self): - return self._storage.device.type == 'cuda' + return self.device.type == 'cuda' - def _untyped(self): + def untyped(self): + """Returns the internal :class:`torch.UntypedStorage`""" return self._storage def _new_wrapped_storage(self, untyped_storage): - assert type(untyped_storage) == torch._UntypedStorage + assert type(untyped_storage) == torch.UntypedStorage - if type(self) == _TypedStorage: - return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype) + if type(self) == TypedStorage: + return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype) else: return type(self)(wrap_storage=untyped_storage) @@ -497,7 +506,7 @@ def __setitem__(self, idx, value): torch.qint8: torch.int8 } tmp_dtype = interpret_dtypes[self.dtype] - tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage( + tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage( wrap_storage=self._storage, dtype=tmp_dtype)) else: @@ -509,12 +518,12 @@ def __getitem__(self, idx): if self.device.type == 'meta': raise NotImplementedError("Not available for 'meta' device type") - # NOTE: Before _TypedStorage existed, indexing with a slice used to be + # NOTE: Before TypedStorage existed, indexing with a slice used to be # possible for Storage objects. However, it would return - # a storage view, which would be a hassle to implement in _TypedStorage, + # a storage view, which would be a hassle to implement in TypedStorage, # so it was disabled if isinstance(idx, slice): - raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__') + raise RuntimeError('slices are only supported in UntypedStorage.__getitem__') elif not isinstance(idx, int): raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") @@ -526,7 +535,7 @@ def __getitem__(self, idx): torch.qint32: torch.int32, torch.qint8: torch.int8 } - return _TypedStorage( + return TypedStorage( wrap_storage=self._storage, dtype=interpret_dtypes[self.dtype])[idx] @@ -535,7 +544,7 @@ def __getitem__(self, idx): return tmp_tensor[idx_wrapped].item() def copy_(self, source: T, non_blocking: bool = None): - self._storage.copy_(source._untyped(), non_blocking) + self._storage.copy_(source.untyped(), non_blocking) return self def nbytes(self): @@ -556,7 +565,7 @@ def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: def cuda(self, device=None, non_blocking=False, **kwargs) -> T: if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs) + cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs) return self._new_wrapped_storage(cuda_storage) def element_size(self): @@ -588,7 +597,7 @@ def __deepcopy__(self, memo): return self._new_wrapped_storage(copy.deepcopy(self._storage, memo)) def __sizeof__(self): - return super(_TypedStorage, self).__sizeof__() + self.nbytes() + return super(TypedStorage, self).__sizeof__() + self.nbytes() def clone(self): """Returns a copy of this storage""" @@ -623,8 +632,8 @@ def _new_shared(self, size, *, device=None): if device is None: device = 'cpu' device = torch.device(device) - untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device) - return _TypedStorage( + untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device) + return TypedStorage( wrap_storage=untyped_storage, dtype=self.dtype) @@ -658,34 +667,34 @@ def resize_(self, size): @classmethod def _free_weak_ref(cls, *args, **kwargs): - return _UntypedStorage._free_weak_ref(*args, **kwargs) + return UntypedStorage._free_weak_ref(*args, **kwargs) def _weak_ref(self, *args, **kwargs): return self._storage._weak_ref(*args, **kwargs) @classmethod def from_buffer(cls, *args, dtype=None, device=None, **kwargs): - if cls == _TypedStorage: + if cls == TypedStorage: dtype = torch.get_default_dtype() if dtype is None else dtype device = torch.device('cpu' if device is None else device) if device.type != 'cpu': - raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}') - untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs) + raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}') + untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs) else: if dtype is not None or len(args) == 5: raise RuntimeError(( "from_buffer: 'dtype' can only be specified in " - "_UntypedStorage.from_buffer and _TypedStorage.from_buffer")) + "UntypedStorage.from_buffer and TypedStorage.from_buffer")) if device is not None: raise RuntimeError(( "from_buffer: 'device' can only be specified in " - "_UntypedStorage.from_buffer and _TypedStorage.from_buffer")) + "UntypedStorage.from_buffer and TypedStorage.from_buffer")) dtype = cls.dtype - untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs) + untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs) - return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype) + return TypedStorage(wrap_storage=untyped_storage, dtype=dtype) def _to(self, dtype): if not isinstance(dtype, torch.dtype): @@ -762,9 +771,9 @@ def from_file(cls, filename, shared, size): shared (bool): whether to share memory size (int): number of elements in the storage """ - if cls == _TypedStorage: + if cls == TypedStorage: raise RuntimeError('from_file can only be called on derived classes') - untyped_storage = eval(cls.__module__)._UntypedStorage.from_file( + untyped_storage: UntypedStorage = UntypedStorage.from_file( filename, shared, size * torch._utils._element_size(cls.dtype)) @@ -773,7 +782,7 @@ def from_file(cls, filename, shared, size): @classmethod def _expired(cls, *args, **kwargs): - return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs) + return UntypedStorage._expired(*args, **kwargs) def is_pinned(self): return self._storage.is_pinned() @@ -795,7 +804,7 @@ def is_shared(self): @classmethod def _new_shared_cuda(cls, *args, **kwargs): - return torch._UntypedStorage._new_shared_cuda(*args, **kwargs) + return torch.UntypedStorage._new_shared_cuda(*args, **kwargs) def _share_filename_cpu_(self, *args, **kwargs): manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs) @@ -807,7 +816,7 @@ def _shared_decref(self): @classmethod def _release_ipc_counter(cls, *args, device=None, **kwargs): - return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) def _shared_incref(self, *args, **kwargs): return self._storage._shared_incref(*args, **kwargs) @@ -825,40 +834,40 @@ def _get_legacy_storage_class(self): if self.device.type not in ['cpu', 'cuda']: return None - module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.' + module = torch if self.device.type == 'cpu' else torch.cuda try: - return eval(module + storage_name) + return getattr(module, storage_name) except AttributeError: return None -_TypedStorage.type.__doc__ = _type.__doc__ -_TypedStorage.cuda.__doc__ = _cuda.__doc__ +TypedStorage.type.__doc__ = _type.__doc__ +TypedStorage.cuda.__doc__ = _cuda.__doc__ class _LegacyStorageMeta(type): dtype: torch.dtype def __instancecheck__(cls, instance): - if type(instance) == _TypedStorage: + if type(instance) == TypedStorage: cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu' return (cls_device == instance.device.type) and (cls.dtype == instance.dtype) return False -class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta): +class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): @classmethod def _new_shared(cls, size): """Creates a new storage in shared memory with the same data type""" - untyped_storage = torch._UntypedStorage._new_shared(size * cls().element_size()) + untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size()) return cls(wrap_storage=untyped_storage) @classmethod def _release_ipc_counter(cls, *args, **kwargs): - return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) + return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) @classmethod def _new_shared_filename(cls, manager, obj, size): bytes_size = size * torch._utils._element_size(cls.dtype) - return cls(wrap_storage=torch._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size)) + return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size)) def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): try: diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index be8f969fb2feb..8a8e538e4ea74 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -486,6 +486,25 @@ def setUpClass(cls): # Acquires the current device as the primary (test) device cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device()) +# See Note [Lazy Tensor tests in device agnostic testing] +lazy_ts_backend_init = False +class LazyTestBase(DeviceTypeTestBase): + device_type = 'lazy' + + def _should_stop_test_suite(self): + return False + + @classmethod + def setUpClass(cls): + import torch._lazy + import torch._lazy.metrics + import torch._lazy.ts_backend + global lazy_ts_backend_init + if not lazy_ts_backend_init: + # Need to connect the TS backend to lazy key before running tests + torch._lazy.ts_backend.init() + lazy_ts_backend_init = True + class MPSTestBase(DeviceTypeTestBase): device_type = 'mps' @@ -570,7 +589,7 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo # The tests in these test cases are derived from the generic tests in # generic_test_class. # See note "Generic Device Type Testing." -def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None): +def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None, include_lazy=False): # Removes the generic test class from its enclosing scope so its tests # are not discoverable. del scope[generic_test_class.__name__] @@ -592,6 +611,14 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on # Filter out the device types based on user inputs desired_device_type_test_bases = filter_desired_device_types(device_type_test_bases, except_for, only_for) + if include_lazy: + # Note [Lazy Tensor tests in device agnostic testing] + # Right now, test_view_ops.py runs with LazyTensor. + # We don't want to opt every device-agnostic test into using the lazy device, + # because many of them will fail. + # So instead, the only way to opt a specific device-agnostic test file into + # lazy tensor testing is with include_lazy=True + desired_device_type_test_bases.append(LazyTestBase) def split_if_not_empty(x: str): return x.split(",") if len(x) != 0 else [] @@ -717,6 +744,7 @@ def _parametrize_test(self, test, generic_cls, device_cls): 'context; use it with instantiate_device_type_tests() instead of ' 'instantiate_parametrized_tests()') + op = check_exhausted_iterator = object() for op in self.op_list: # Determine the set of dtypes to use. dtypes: Union[Set[torch.dtype], Set[None]] @@ -794,6 +822,9 @@ def test_wrapper(*args, **kwargs): # Provides an error message for debugging before rethrowing the exception print("Failed to instantiate {0} for op {1}!".format(test_name, op.name)) raise ex + if op is check_exhausted_iterator: + raise ValueError('An empty op_list was passed to @ops. ' + 'Note that this may result from reuse of a generator.') # Decorator that skips a test if the given condition is true. # Notes: diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index dc9ae9f36e1fa..eac2ec44bb3fe 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,57 +1,135 @@ # Owner(s): ["oncall: distributed"] +import functools +import itertools import sys +from abc import ABC, abstractmethod from contextlib import suppress from copy import deepcopy -from enum import Enum +from enum import Enum, auto from math import inf -from typing import Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from unittest import mock import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel -from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_ +from torch.distributed.fsdp import CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + MixedPrecision, + ShardingStrategy, + TrainingState_, +) +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ( + always_wrap_policy, + transformer_auto_wrap_policy, + wrap, +) +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import ( TEST_SKIPS, MultiProcessTestCase, ) -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp.wrap import wrap from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms class FSDPInitMode(Enum): - # Move model to CUDA before wrap - CUDA_BEFORE = 1 - # Move model to CUDA after wrap - CUDA_AFTER = 2 - # Don't move model to CUDA at all. - CUDA_NEVER = 3 - -def _get_full_detached_param(fsdp_model: FullyShardedDataParallel): - with FullyShardedDataParallel.summon_full_params(fsdp_model): - params = list(p.clone().detach_() for p in fsdp_model.parameters()) - - return params - -def _validate(model, process_group, assert_fn): - module_states = [param.detach().cpu() for param in model.parameters()] - module_states.extend([buffer.detach().cpu() for buffer in model.buffers()]) + # No FSDP wrapping + NO_FSDP = auto() + # FSDP recursive wrapping + RECURSIVE = auto() + # TODO: FSDP non-recursive wrapping + # NONRECURSIVE = auto() + + +class CUDAInitMode(Enum): + # Move model to CUDA before passing to the FSDP constructor + CUDA_BEFORE = auto() + # Move model to CUDA after passing to the FSDP constructor + CUDA_AFTER = auto() + # Keep on CPU + CUDA_NEVER = auto() + + +class FSDPTestModel(nn.Module, ABC): + """This defines the interface expected from all models used commonly for + FSDP unit tests.""" + @abstractmethod + def get_input(self, device) -> Tuple[torch.Tensor, ...]: + """Returns an input for the model as as tuple.""" + ... + + @abstractmethod + def get_loss(self, input, output) -> torch.Tensor: + """Returns the loss given the input and output.""" + ... + + @abstractmethod + def run_backward(self, loss) -> None: + """Runs the backward pass (e.g. including ``loss.backward()``).""" + ... + + @staticmethod + @abstractmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + *init_args: Any, + cuda_init_mode: CUDAInitMode, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + **init_kwargs: Any, + ) -> nn.Module: + """Initializes an instance of this model.""" + ... + + + +def _assert_module_states( + model: nn.Module, + process_group: dist.ProcessGroup, + assert_fn: Callable, +): + """ + All-gathers module states across ranks and calls ``assert_fn`` on each pair + of corresponding states from rank 0 and a nonzero rank. For example, if + ``assert_fn`` is ``self.assertEqual()``, then this checks that all module + states are equal across ranks. + """ + # Include names for debugging convenience + named_module_states = [ + (param_name, param.detach().cpu()) + for param_name, param in model.named_parameters() + ] + named_module_states += [ + (buffer_name, buffer.detach().cpu()) + for buffer_name, buffer in model.named_buffers() + ] world_size = dist.get_world_size(process_group) olist = [None for _ in range(world_size)] - dist.all_gather_object(olist, module_states, group=process_group) + dist.all_gather_object(olist, named_module_states, group=process_group) rank0_states = olist[0] for state in olist[1:]: - for p1, p2 in zip(rank0_states, state): + for (_, p1), (_, p2) in zip(rank0_states, state): assert_fn(p1, p2) -def _zero_model(fsdp_model: FullyShardedDataParallel): - with FullyShardedDataParallel.summon_full_params(fsdp_model): - for param in fsdp_model.parameters(): +def _zero_model( + model: nn.Module, + zero_buffers: bool = False, +): + """Zeros the parameters and optionally buffers of ``model`` in place.""" + with FSDP.summon_full_params(model): + for param in model.parameters(): with torch.no_grad(): param.zero_() + if zero_buffers: + for buffer in model.buffers(): + with torch.no_grad(): + buffer.zero_() def _get_state_dict(model, cpu_offload=False, half=False): if not cpu_offload: @@ -66,20 +144,26 @@ def subtest_name(test_name_mapping, *args): [test_name_mapping[str(s)] if s is not None else "none" for s in args] ) -# get full params of a model recursively. Note that if CPU offloading, it will -# also automatically move the parameters to GPU, due to _rebuild_full_params -# call. -def get_full_params(model, recurse=True): - with FullyShardedDataParallel.summon_full_params(model, recurse=recurse): +def get_full_params(model: nn.Module, recurse: bool = True): + """ + Returns the full unsharded parameters of ``model``. Any FSDP-managed + parameters offloaded to CPU are moved to GPU in the returned list. + + Args: + recurse (bool): If ``False``, only unshards the parameters immediate to + ``model``; if ``True``, recurses through the module hierarchy + rooted at ``model``. + """ + with FSDP.summon_full_params(model, recurse=recurse): return deepcopy(list(model.parameters())) -def _maybe_cuda(model, move_to_cuda): +def _maybe_cuda(model: nn.Module, move_to_cuda: bool): return model.cuda() if move_to_cuda else model -def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs): +def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs): return ( model if not wrap_fsdp - else FullyShardedDataParallel(model, *args, **kwargs) + else FSDP(model, *args, **kwargs) ) class DummyProcessGroup: @@ -109,28 +193,31 @@ def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)): super().__init__() # keep everything deterministic for model initialization torch.manual_seed(0) - self.inner: Union[torch.nn.Linear, FullyShardedDataParallel] = \ + self.inner: Union[torch.nn.Linear, FSDP] = \ torch.nn.Linear(2, 2).cuda() if wrap_fsdp: - self.inner = FullyShardedDataParallel(self.inner, cpu_offload=cpu_offload) + self.inner = FSDP(self.inner, cpu_offload=cpu_offload) self.outer = torch.nn.Linear(2, 2).cuda() def forward(self, x): y = self.inner(x) return self.outer(y) -class TransformerWithSharedParams(nn.Module): +class TransformerWithSharedParams(FSDPTestModel): def __init__( - self, group, *args, d_vocab=23, d_model=16, add_bn=True, - fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs + self, + group: dist.ProcessGroup, + cuda_init_mode: CUDAInitMode, + add_bn: bool, + deterministic: bool, ): super().__init__() self.rank = group.rank() self.world_size = group.size() - torch.manual_seed(0) # keep everything deterministic - assert ( - d_vocab >= 12 - ), "dim of vocab should be larger than 12, as we use torch.arange(12) as input" + if deterministic: + torch.manual_seed(0) + d_vocab = 23 + d_model = 16 self.embed_tokens = nn.Embedding(d_vocab, d_model) self.transformer = nn.Transformer( @@ -147,12 +234,17 @@ def __init__( self.register_buffer( "vocab_bias", self.embed_tokens.weight.new_ones((d_model,)) ) - self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long)) # type: ignore[arg-type] + self.register_buffer( + "long_buffer", + torch.zeros_like(self.vocab_bias, dtype=torch.long), + ) # type: ignore[arg-type] self.bs = 2 self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity() - move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE - self = _maybe_cuda(self, move_to_cuda) + if cuda_init_mode == CUDAInitMode.CUDA_BEFORE: + self = self.cuda() + if deterministic: + self.eval() def get_input(self, device): torch.manual_seed(1 + self.rank) # keep everything deterministic @@ -177,43 +269,95 @@ def get_loss(self, input, output): def run_backward(self, loss): loss.backward() + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + add_bn: bool = True, + ) -> Union[nn.Module, FSDP]: + """ + Initializes a :class:`TransformerWithSharedParams` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps with + top-level FSDP. By default, the top-level FSDP uses the + ``transformer_auto_wrap_policy()`` for encoder and decoder + layers, but a different auto wrap policy may be specified via + ``fsdp_kwargs``. + cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + add_bn (bool): Whether to include batch norm in the model. + """ + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return TransformerWithSharedParams(group, cuda_init_mode, add_bn, deterministic) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Default to the `transformer_auto_wrap_policy()` + if "auto_wrap_policy" not in fsdp_kwargs: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + TransformerEncoderLayer, + TransformerDecoderLayer, + }, + ) + else: + auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy") + fsdp_model = FSDP( + TransformerWithSharedParams(group, cuda_init_mode, add_bn, deterministic), + group, + auto_wrap_policy=auto_wrap_policy, + **fsdp_kwargs, + ) + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: + fsdp_model = fsdp_model.cuda() + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + def get_ignored_modules(self): return [self.transformer] -class NestedWrappedModule(nn.Module): - def __init__(self, group, wrap_fsdp, *args, wrap_everything=False, fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs): +class NestedWrappedModule(FSDPTestModel): + def __init__( + self, + group: dist.ProcessGroup, + wrap_fsdp: bool, + cuda_init_mode: CUDAInitMode, + deterministic: bool, + **fsdp_kwargs, + ): super().__init__() self.rank = group.rank() self.world_size = group.size() - move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE + move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE def _maybe_wrap(layer): if wrap_fsdp: - return FullyShardedDataParallel(layer, group, *args, **kwargs) + return FSDP(layer, group, **fsdp_kwargs) return layer - torch.manual_seed(0) # keep everything deterministic - - if wrap_everything: - self.module = nn.Sequential( - _maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)), - _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)), - _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)), - _maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)), - ) - else: - self.module = nn.Sequential( - _maybe_cuda(nn.Linear(8, 4), move_to_cuda), - _maybe_wrap( - nn.Sequential( - _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)), - _maybe_cuda(nn.Linear(16, 16), move_to_cuda), - ), + if deterministic: + torch.manual_seed(0) + self.module = nn.Sequential( + _maybe_cuda(nn.Linear(8, 4), move_to_cuda), + _maybe_wrap( + nn.Sequential( + _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)), + _maybe_cuda(nn.Linear(16, 16), move_to_cuda), ), - _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)), - _maybe_cuda(nn.Linear(4, 8), move_to_cuda), - ) + ), + _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)), + _maybe_cuda(nn.Linear(4, 8), move_to_cuda), + ) def get_input(self, device): torch.manual_seed(1 + self.rank) # keep everything deterministic @@ -229,9 +373,94 @@ def get_loss(self, input, output): def run_backward(self, loss): loss.backward() - -class ModuleWithDelay(nn.Module): - def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0): + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + ) -> nn.Module: + """ + Initializes a :class:`NestedWrappedModule` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps some nested + modules with FSDP but not the top-level module. The model may + later be wrapped with a top-level FSDP external to this method + if desired. + cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + """ + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return NestedWrappedModule( + group, + wrap_fsdp=False, + cuda_init_mode=cuda_init_mode, + deterministic=deterministic, + ) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Does not wrap with top-level FSDP + fsdp_model = NestedWrappedModule( + group, + wrap_fsdp=True, + cuda_init_mode=cuda_init_mode, + deterministic=deterministic, + **fsdp_kwargs, + ) + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: + fsdp_model = fsdp_model.cuda() + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + + +class AlwaysWrapNestedWrappedModule(NestedWrappedModule): + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + ): + """ + Initializes a :class:`NestedWrappedModule` instance, but unlike + :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this + wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap + policy. + """ + super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule) + model = super_.init( + group=group, + fsdp_init_mode=FSDPInitMode.NO_FSDP, + cuda_init_mode=cuda_init_mode, + fsdp_kwargs=fsdp_kwargs, + deterministic=deterministic, + ) + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return model + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs) + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: + fsdp_model = fsdp_model.cuda() + return fsdp_model + + +class ModuleWithDelay(FSDPTestModel): + """This class wraps a :class:`FSDPTestModel` to optionally add a delay + after computing the loss and/or before the gradient reduction.""" + def __init__( + self, + module: nn.Module, + delay_after_loss_ms: int, + delay_before_reduction_ms: int, + ): super().__init__() self.delay_after_loss_ms = delay_after_loss_ms self.delay_before_reduction_ms = delay_before_reduction_ms @@ -264,32 +493,53 @@ def _delayed_reduce_scatter(*args, **kwargs): ): self.module.run_backward(loss) + @staticmethod + def init( + module_class: Type[FSDPTestModel], + *model_args: Any, + delay_after_loss_ms: int, + delay_before_reduction_ms: int, + **model_kwargs: Any, + ): + """ + Args: + module_class (Type[FSDPTestModel]): Wrapped module class to which + to add delays. + model_args: Positional arguments forwarded to the ``module_class`` + ``init()``. + delay_after_loss_ms (int): Delay after computing the loss/before + the optimizer step (in ms). + delay_before_reduction_ms (int): Delay before reduce-scattering + gradients (in ms). + model_kwargs: Keyword arguments forwarded to the ``module_class`` + ``init()``. + """ + return ModuleWithDelay( + module_class.init(*model_args, **model_kwargs), + delay_after_loss_ms, + delay_before_reduction_ms, + ) class NestedWrappedModuleWithDelay(ModuleWithDelay): - def __init__( - self, - group, - wrap_fsdp, - fsdp_init_mode=FSDPInitMode.CUDA_AFTER, - cpu_offload=None, - backward_prefetch=None, - forward_prefetch=False, - sharding_strategy=None, - mixed_precision=None, - **kwargs + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + delay_after_loss_ms: int = 0, + delay_before_reduction_ms: int = 0, ): - super().__init__( - NestedWrappedModule( - group, - wrap_fsdp, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - ), - **kwargs + return super(NestedWrappedModuleWithDelay, NestedWrappedModuleWithDelay).init( + NestedWrappedModule, + group=group, + fsdp_init_mode=fsdp_init_mode, + cuda_init_mode=cuda_init_mode, + fsdp_kwargs=fsdp_kwargs, + deterministic=deterministic, + delay_after_loss_ms=delay_after_loss_ms, + delay_before_reduction_ms=delay_before_reduction_ms, ) @@ -303,14 +553,28 @@ def forward(self, *args, **kwargs): class MixtureOfExperts(NestedWrappedModule): - def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs): - super().__init__(group, wrap_fsdp) + def __init__( + self, + group: dist.ProcessGroup, + wrap_fsdp: bool, + cuda_init_mode: CUDAInitMode, + delay_before_free_ms: int, + deterministic: bool, + **fsdp_kwargs, + ): + super().__init__( + group=group, + wrap_fsdp=wrap_fsdp, + cuda_init_mode=cuda_init_mode, + deterministic=deterministic, + ) self.group = group self.delay_before_free_ms = delay_before_free_ms self.wrap_fsdp = wrap_fsdp - self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE - # "expert" params are different on each rank - torch.manual_seed(42 + group.rank()) + self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE + if deterministic: + # Give each rank different expert parameters + torch.manual_seed(42 + self.rank) d_expert = 23 d_shared = 12 d_input = 8 @@ -320,8 +584,9 @@ def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mo for p in expert.parameters(): p.expert = True # type: ignore[attr-defined] - # everything else is shared - torch.manual_seed(0) + if deterministic: + # Keep all other parameters the same across ranks + torch.manual_seed(0) shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda) @@ -330,9 +595,8 @@ def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mo expert_group = torch.distributed.new_group( [group.rank()] ) # world size 1 means no shard - expert = FullyShardedDataParallel(expert, expert_group, **kwargs) # type: ignore[assignment] - - shared = FullyShardedDataParallel(shared, group, **kwargs) # type: ignore[assignment] + expert = FSDP(expert, expert_group, **fsdp_kwargs) # type: ignore[assignment] + shared = FSDP(shared, group, **fsdp_kwargs) # type: ignore[assignment] self.module = nn.Sequential( _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda), @@ -344,7 +608,7 @@ def __init__(self, group, wrap_fsdp, *args, delay_before_free_ms=0, fsdp_init_mo def forward(self, x): if self.delay_before_free_ms > 0: expert = self.module[2] - if isinstance(expert, FullyShardedDataParallel): + if isinstance(expert, FSDP): orig_free_full_params = self.module[2]._free_full_params def _free_full_params_with_delay(*args): @@ -365,8 +629,7 @@ def _free_full_params_with_delay(*args): def run_backward(self, loss): loss.backward() - - # manually reduce gradients if not wrapped in FullyShardedDataParallel + # Manually reduce gradients if not wrapped in FullyShardedDataParallel if not self.wrap_fsdp: with torch.no_grad(): for p in self.parameters(): @@ -375,6 +638,57 @@ def run_backward(self, loss): p.grad.div_(self.world_size) torch.distributed.all_reduce(p.grad, group=self.group) + @staticmethod + def init( + group: dist.ProcessGroup, + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = False, + delay_before_free_ms: int = 0, + ): + """ + Initializes a :class:`MixtureOfExperts` instance. + + Args: + fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap + any modules with FSDP. If ``RECURSIVE``, then wraps some nested + modules with FSDP, including the expert and shared layers, but + not the top-level module. The model may later be wrapped with a + top-level FSDP external to this method if desired. + cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. + fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments + forwarded to the FSDP constructor. + deterministic (bool): Whether to make the model deterministic + across constructions. + delay_before_free_ms (int): Delay before resharding expert + parameters in the forward pass (in ms). + """ + if fsdp_kwargs is None: + fsdp_kwargs = {} + if fsdp_init_mode == FSDPInitMode.NO_FSDP: + return MixtureOfExperts( + group, + wrap_fsdp=False, + cuda_init_mode=cuda_init_mode, + delay_before_free_ms=delay_before_free_ms, + deterministic=deterministic, + ) + elif fsdp_init_mode == FSDPInitMode.RECURSIVE: + # Does not wrap with top-level FSDP + fsdp_model = MixtureOfExperts( + group, + wrap_fsdp=True, + cuda_init_mode=cuda_init_mode, + delay_before_free_ms=delay_before_free_ms, + deterministic=deterministic, + **fsdp_kwargs, + ) + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: + fsdp_model = fsdp_model.cuda() + return fsdp_model + raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}") + class FSDPTest(MultiProcessTestCase): def setUp(self): @@ -385,6 +699,10 @@ def setUp(self): def world_size(self): return torch.cuda.device_count() if torch.cuda.is_available() else 4 + @property + def process_group(self): + return dist.distributed_c10d._get_default_group() + @property def init_method(self): return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name) @@ -398,6 +716,39 @@ def _check_backward_prefetch(self, fsdp_model, backward_prefetch): def _check_forward_prefetch(self, fsdp_model, forward_prefetch): self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch) + def run_subtests( + self, + subtest_config: Dict[str, List[Any]], + test_fn: Callable, + *test_args, + **test_kwargs: Any, + ): + """ + Runs a test function given by ``test_fn`` as a subtest according to the + configurations specified by ``subtest_config``. This amortizes the + costly setup overhead (including process spawn and initializing the + process group) over the subtests. + + Args: + subtest_config (Dict[str, List[Any]]): A mapping from subtest + keyword argument name to a list of its possible values. + test_fn (Callable): A callable that runs the actual test. + test_args: Positional arguments to pass to ``test_fn``. + test_kwargs: Keyword arguments to pass to ``test_fn``. + """ + # Convert the config mapping to a list to have a fixed order + subtest_config_items: List[Tuple[str, List[Any]]] = list(subtest_config.items()) + subtest_config_keys: List[str] = [item[0] for item in subtest_config_items] + subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items] + for values in itertools.product(*subtest_config_values): + # Map keyword to chosen value + subtest_kwargs = { + kwarg: value for kwarg, value in zip(subtest_config_keys, values) + } + with self.subTest(**subtest_kwargs): + test_fn(*test_args, **test_kwargs, **subtest_kwargs) + dist.barrier() + @classmethod def _run(cls, rank, test_name, file_name, pipe): self = cls(test_name) @@ -440,16 +791,16 @@ def _run(cls, rank, test_name, file_name, pipe): def _train_for_several_steps( self, - model, - num_steps, - autocast, - lr=0.01, - fsdp_cpu_offload=None, - clip_norm=0.3, - norm_type=None, - save_model=False, - mixed_precision=None, - enable_sharded_grad_scaler=False, + model: nn.Module, + num_steps: int, + autocast: bool, + lr: float = 0.01, + fsdp_cpu_offload: Optional[CPUOffload] = None, + norm_type: Optional[Union[float, int]] = None, + save_model: bool = False, + mixed_precision: Optional[MixedPrecision] = None, + enable_sharded_grad_scaler: bool = False, + use_pure_fp16: bool = False, ): cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params @@ -463,14 +814,14 @@ def _train_for_several_steps( with torch.cuda.amp.autocast(enabled=autocast): # Inputs always cuda regardless of cpu offloading, or model.device input = model.module.get_input(torch.device("cuda")) - if mixed_precision and not isinstance(model, FullyShardedDataParallel): + if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)): if isinstance(input, torch.Tensor): input = input.half() else: input = tuple(x.half() for x in input) output = model(*input) # Post-forward, if CPU offloading model param should be on CPU. - if cpu_offload_params and isinstance(model, FullyShardedDataParallel): + if cpu_offload_params and isinstance(model, FSDP): for p in model.parameters(): # Params should always be on CPU, even if # p._is_sharded=False @@ -479,32 +830,35 @@ def _train_for_several_steps( loss = model.module.get_loss(input, output).to(model_device) loss = sharded_grad_scaler.scale(loss) - if not mixed_precision: + if not mixed_precision and not use_pure_fp16: assert ( loss.dtype == torch.float32 ), "loss data type should be float32, as the original \ parameter data type is float32." else: + if use_pure_fp16: + self.assertEqual(loss.dtype, torch.float16) # FSDP loss is fp16, DDP AMP loss is fp32 - if isinstance(model, FullyShardedDataParallel): + elif isinstance(model, FSDP): self.assertEqual(loss.dtype, mixed_precision.param_dtype) else: self.assertEqual(loss.dtype, torch.float32) model.module.run_backward(loss) if norm_type is not None: - if isinstance(model, FullyShardedDataParallel): - model.clip_grad_norm_(clip_norm, norm_type) + max_norm = 0.3 + if isinstance(model, FSDP): + model.clip_grad_norm_(max_norm, norm_type) total_norm_after_clip = _collect_total_grad_norm_fsdp( model, norm_type, self.rank ) else: - torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type) total_norm_after_clip = _collect_total_grad_norm_local( model, norm_type ) - self.assertTrue(total_norm_after_clip <= clip_norm) + self.assertTrue(total_norm_after_clip <= max_norm) # Post-backward, if CPU offloading model params should be on CPU. - if cpu_offload_params and isinstance(model, FullyShardedDataParallel): + if cpu_offload_params and isinstance(model, FSDP): for p in model.parameters(): # Params should always be on CPU, even if # p._is_sharded=False @@ -519,161 +873,157 @@ def _train_for_several_steps( # Zero params, if save/load state_dict did not work properly, this # would break the parity test with DDP. _zero_model(model) - model.load_state_dict(state_dict) - if isinstance(model, FullyShardedDataParallel): + if isinstance(model, FSDP): model._assert_state(TrainingState_.IDLE) return loss.detach() - def _test_identical_outputs( + def _test_fsdp_parity( self, - model_init_fn, - *args, - ref_ddp_fn=None, - num_steps=2, - fsdp_init_mode=FSDPInitMode.CUDA_AFTER, - lr=0.01, - cpu_offload=CPUOffload(), - backward_prefetch=None, - forward_prefetch=False, - sharding_strategy=None, - mixed_precision=None, - save_model=True, - clip_norm=0.3, - norm_type=None, - enable_sharded_grad_scaler=False, - **kwargs + model_class: Type[FSDPTestModel], + fsdp_init_mode: FSDPInitMode, + cuda_init_mode: CUDAInitMode, + ref_init_fn: Optional[Callable] = None, + num_iters: int = 2, + save_model: bool = True, + cpu_offload: CPUOffload = CPUOffload(), + backward_prefetch: Optional[BackwardPrefetch] = None, + forward_prefetch: bool = False, + sharding_strategy: Optional[ShardingStrategy] = None, + mixed_precision: Optional[MixedPrecision] = None, + enable_sharded_grad_scaler: bool = False, + use_pure_fp16: bool = False, + norm_type: Optional[Union[float, int]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + **fsdp_kwargs, ): - group = dist.distributed_c10d._get_default_group() - rank = group.rank() - # Establish reference behavior with PyTorch DDP (+ optionally autocast). - model = model_init_fn(group=group, wrap_fsdp=False).cuda() - if ref_ddp_fn is None: - model = nn.parallel.DistributedDataParallel( - model, device_ids=[rank], output_device=rank - ) + """ + Tests FSDP training against a reference, which defaults to DDP but + may be customized with ``ref_init_fn``. + + Args: + model_class (Type[FSDPTestModel]): A model class that inherits from + ``FSDPTestModel``, which defines the expected interface. + fsdp_init_mode (FSDPInitMode): The mode to initialize the + FSDP-wrapped model. This should not be ``NO_FSDP``. + ref_init_fn (Optional[Callable]): A callable to invoke that wraps a + non-wrapped model to construct the reference model, where this + wrapper should provide data parallel semantics. If ``None``, + then the callable defaults to the DDP constructor. + """ + assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP" + if init_kwargs is None: + init_kwargs = {} + lr = 1e-2 + rank = self.process_group.rank() + # Establish reference behavior with DDP + model = model_class.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + **init_kwargs, + ) + if ref_init_fn is None: + ref_model = DDP(model, device_ids=[rank], output_device=rank) else: - model = ref_ddp_fn(model) - - # DDP training + ref_model = ref_init_fn(model) + if use_pure_fp16: + ref_model = ref_model.half() ref_loss = self._train_for_several_steps( - model, num_steps, autocast=mixed_precision is not None, lr=lr, - fsdp_cpu_offload=cpu_offload, mixed_precision=mixed_precision, + ref_model, + num_iters, + autocast=mixed_precision is not None, + lr=lr, + fsdp_cpu_offload=cpu_offload, + mixed_precision=mixed_precision, + norm_type=norm_type, enable_sharded_grad_scaler=enable_sharded_grad_scaler, + use_pure_fp16=use_pure_fp16, + ) + ddp_params = list(ref_model.parameters()) + # Check against FSDP behavior + fsdp_kwargs.update( + { + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "forward_prefetch": forward_prefetch, + "sharding_strategy": sharding_strategy, + "mixed_precision": mixed_precision, + } ) - ref_full_params = list(model.parameters()) - - # Confirm we get the same behavior using FullyShardedDataParallel. try: - model = model_init_fn( - group=group, - wrap_fsdp=True, - fsdp_init_mode=fsdp_init_mode, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, + fsdp_model = model_class.init( + self.process_group, + fsdp_init_mode, + cuda_init_mode, + fsdp_kwargs, + deterministic=True, + **init_kwargs, ) except Exception as e: - raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}") - - cpu_offload = cpu_offload or CPUOffload() # disabled if not specified. - model = FullyShardedDataParallel( - model, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - forward_prefetch=forward_prefetch, - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - ) - # Call model.cuda() after init FSDP if specified. - if fsdp_init_mode == FSDPInitMode.CUDA_AFTER: - model = model.cuda() - - # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we - # expect FSDP code to raise error that we check below, in the case of - # offload params. - if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params: - for p in model.parameters(): - # Should be on CPU regardless of if param is sharded. - self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}") - - only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params - ctx = ( + raise ValueError(f"Initializing {model_class} raised error {str(e)}") + if not isinstance(fsdp_model, FSDP): + # Enforce that we wrap with top-level FSDP since we are comparing + # assuming a data parallel reference and some test models may not + # do so in their `init()` method + fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs) + if use_pure_fp16: + # Change the model parameter dtype after FSDP initialization + fsdp_model = fsdp_model.half() + if cuda_init_mode == CUDAInitMode.CUDA_AFTER: + fsdp_model = fsdp_model.cuda() + offload_params = cpu_offload is not None and cpu_offload.offload_params + # Offloading parameters with `CUDA_AFTER` should raise an error during + # lazy initialization due to the parameter devices not being CPU; + # otherwise, all parameter devices should be CPU + expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER + expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER + if expects_cpu_device: + cpu_device = torch.device("cpu") + for param in fsdp_model.parameters(): + self.assertEqual(param.device, cpu_device) + context = ( self.assertRaisesRegex(AssertionError, "Expected param to be on CPU") - if only_check_err else suppress() + if expects_device_error else suppress() ) - with ctx: - # FSDP training - shard_loss = self._train_for_several_steps( - model, num_steps, autocast=False, lr=lr, - fsdp_cpu_offload=cpu_offload, save_model=save_model, + with context: + fsdp_loss = self._train_for_several_steps( + fsdp_model, + num_iters, + autocast=False, + lr=lr, + fsdp_cpu_offload=cpu_offload, + save_model=save_model, mixed_precision=mixed_precision, + norm_type=norm_type, enable_sharded_grad_scaler=enable_sharded_grad_scaler, + use_pure_fp16=use_pure_fp16, ) - # We only check for errors in the case we have the following setup: - # model = FSDP(model, cpu_offload=True) - # model = model.cuda() - # so skip the rest of this logic. - if only_check_err: + # No need to check for parameter and loss parity if expecting an error + if expects_device_error: return - # If CPU offload, next call will change model params to GPU. Sanity - # check that params are on CPU before. - if cpu_offload.offload_params: - device_set = {p.device for p in model.parameters()} + # Check parameter devices are CPU if offloading to CPU before calling + # `get_full_params()`, which will cast the parameters to FP32 + if offload_params: + for param in fsdp_model.parameters(): + self.assertEqual(param.device, cpu_device) + fsdp_loss = fsdp_loss.cuda() + fsdp_unsharded_params = get_full_params(fsdp_model) + torch.testing.assert_allclose(ref_loss, fsdp_loss) + # Do not check for parameter parity if using mixed precision since (1) + # the DDP parameters are in FP16 (from `half()`) while the FSDP + # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs + # the optimizer in FP16 while FSDP runs it in FP32 + if mixed_precision is not None: self.assertEqual( - {torch.device("cpu")}, - device_set, - f"Got device set {device_set}" - ) - shard_full_params = get_full_params(model) - - if cpu_offload.offload_params: - shard_loss = shard_loss.cuda() - torch.testing.assert_allclose(ref_loss, shard_loss) - # Note that we don't do parameter check when testing mixed precision, - # as FSDP will bring the full param back to fp32 but we did model.half() - # for DDP so they wouldn't be equal. Further, DDP + model.half() would - # run optimizer in reduced precision versus FSDP's full precision. - if not mixed_precision: - self.assertEqual( - ref_full_params, - shard_full_params, + ddp_params, + fsdp_unsharded_params, exact_device=True, - msg="FullyShardedDataParallel didn't match PyTorch DDP", + msg="FSDP did not match DDP", ) - def _get_wrapped_model( - self, group, cuda_first=False, ignore_modules=False, config=None, - **model_kwargs, - ) -> FullyShardedDataParallel: - if config is None: - config = {} - move_to_cuda = not ( - "cpu_offload" in config and config["cpu_offload"].offload_params - ) - transformer = TransformerWithSharedParams(group, **model_kwargs) - if cuda_first and move_to_cuda: - transformer = transformer.cuda() - if ignore_modules: - assert "ignored_modules" not in config, \ - "Do not pass in `ignored_modules` via `config`" - config["ignored_modules"] = transformer.get_ignored_modules() - model = FullyShardedDataParallel(transformer, group, **config) - if not cuda_first and move_to_cuda: - model = model.cuda() - return model - - def _get_nonwrapped_model( - self, group, **model_kwargs, - ) -> torch.nn.Module: - """Returns the non-wrapped model that is wrapped in - :meth:`_get_wrapped_model`. The model used in these two methods should - be kept in sync for tests that use both for parity comparisons.""" - return TransformerWithSharedParams(group, **model_kwargs).cuda() - class SkipModule(nn.Module): def __init__(self): diff --git a/torch/testing/_internal/common_fx2trt.py b/torch/testing/_internal/common_fx2trt.py deleted file mode 100644 index 5e2593c7ab8e2..0000000000000 --- a/torch/testing/_internal/common_fx2trt.py +++ /dev/null @@ -1,273 +0,0 @@ -import unittest -from typing import Callable, List, Tuple - -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -import torch -import torch.fx -from torch_tensorrt.fx import ( - TRTInterpreter, - InputTensorSpec, - TRTModule, -) -from torch_tensorrt.fx.passes.pass_utils import chain_passes -from torch_tensorrt.fx.utils import LowerPrecision -from torch.fx.experimental.normalize import NormalizeArgs -from torch.fx.passes import shape_prop -from torch.testing._internal.common_utils import TestCase -import time - -def fetch_attr(mod, target): - """ - Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. - - Args: - target (str): The fully-qualfiied name of the attribute to fetch - - Return: - Any: The value of the attribute. - """ - target_atoms = target.split(".") - attr_itr = mod - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError( - f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" - ) - attr_itr = getattr(attr_itr, atom) - return attr_itr - - -@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") -class TRTTestCase(TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(3) - - def run_test(self, mod, inputs, expected_ops, unexpected_ops, interpreter, rtol, atol, precision=LowerPrecision.FP32): - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - if len(expected_ops): - self.assert_has_op(mod, expected_ops) - if unexpected_ops: - self.assert_unexpected_op(mod, unexpected_ops) - start = time.perf_counter() - interpreter_result = interpreter.run(lower_precision=precision) - sec = time.perf_counter() - start - print("Interpreter run time(s):", sec) - trt_mod = TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - - ref_outputs = mod(*inputs) - - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - outputs = trt_mod(*cuda_inputs) - end_event.record() - torch.cuda.synchronize() - print("TRT run time(s)=", (start_event.elapsed_time(end_event) * 1.0e-3)) - - if isinstance(outputs, torch.Tensor): - ref_outputs = [ref_outputs] - outputs = [outputs] - for out, ref in zip(outputs, ref_outputs): - if not isinstance(ref, torch.Tensor): - ref = torch.tensor([ref]) - ref = ref.cpu() # to_dtype test has cases with gpu output - torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol) - - def run_test_custom_compare_results( - self, - mod, - inputs, - expected_ops, - interpreter, - comparators: List[Tuple[Callable, List]], - fp16_mode=False, - ): - """ - Runs the test and compares the result using the provided comparators. - The size of comparators must be equal to the number of outputs from 'mod'. - - mod - a model to run. - inputs - a list of the model inputs. - expected ops - a list of ops that should be verified. - interpreter - used for converting the model to TRT. - comparators - a list of (func, args) pairs corresponding to each of - the module outputs. usage: func(x, y, *args) - - """ - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - if len(expected_ops): - self.assert_has_op(mod, expected_ops) - - interpreter_result = interpreter.run(lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32) - trt_mod = TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - res_trt = trt_mod(*cuda_inputs).cpu() - res_cpu = mod(*inputs) - assert len(res_trt) == len(res_cpu) - assert len(res_cpu) == len(comparators) - for output_trt, output_cpu, comparator in zip( - res_trt, res_cpu, comparators - ): - comp_func = comparator[0] - args = comparator[1] - self.assertTrue(comp_func(output_trt, output_cpu, *args)) - - def run_test_with_error(self, mod, inputs, interpreter, expect_error): - with self.assertRaises(expect_error): - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - interpreter.run(lower_precision=LowerPrecision.FP32) - - def assert_has_op(self, mod, ops): - ops_in_mod = set() - - for node in mod.graph.nodes: - if node.op == "call_module": - ops_in_mod.add(type(fetch_attr(mod, node.target))) - elif node.op in {"call_function", "call_method"}: - ops_in_mod.add(node.target) - - self.assertTrue( - ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" - ) - - def assert_unexpected_op(self, mod, ops): - for node in mod.graph.nodes: - if (node.op == "call_module"): - if type(fetch_attr(mod, node.target)) in ops: - return False - elif node.op in {"call_function", "call_method"}: - if node.target in ops: - return False - return True - - -class VanillaTestCase(TRTTestCase): - def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06): - mod = torch.fx.symbolic_trace(mod) - shape_prop.ShapeProp(mod).propagate(*inputs) - mod = NormalizeArgs(mod).transform() - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol) - - def run_test_custom_compare_results( - self, - mod, - inputs, - expected_ops, - interpreter, - comparators: List[Tuple[Callable, List]], - fp16_mode=False, - ): - # interpreter is ignored, we do not need this for Vanilla tests - # Note this is different from internal version, we need to fix the test case - # after we refactor the internal callsites to use this file - mod = torch.fx.symbolic_trace(mod) - shape_prop.ShapeProp(mod).propagate(*inputs) - mod = NormalizeArgs(mod).transform() - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test_custom_compare_results( - mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode - ) - - -class AccTestCase(TRTTestCase): - def run_test( - self, - mod, - inputs, - expected_ops, - unexpected_ops=None, - apply_passes=None, - test_explicit_batch_dim=True, - test_implicit_batch_dim=True, - test_explicit_precision=False, - rtol=1e-03, - atol=1e-03, - precision=LowerPrecision.FP32, - ): - mod.eval() - mod = acc_tracer.trace(mod, inputs) - - if apply_passes is not None: - pass_tracer = chain_passes(*apply_passes) - mod = pass_tracer(mod, inputs) - - if test_implicit_batch_dim: - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision) - - if test_explicit_batch_dim: - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True - ) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision) - - if test_explicit_precision: - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs), explicit_precision=test_explicit_precision) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) - - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True, explicit_precision=test_explicit_precision - ) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision) - - - def run_test_with_assert_error( - self, - mod, - inputs, - expect_error, - test_explicit_batch_dim=True, - test_implicit_batch_dim=True, - ): - mod.eval() - mod = acc_tracer.trace(mod, inputs) - - if test_implicit_batch_dim: - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test_with_error(mod, inputs, interp, expect_error) - - if test_explicit_batch_dim: - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True - ) - super().run_test_with_error(mod, inputs, interp, expect_error) - - def run_test_with_dynamic_shape( - self, - mod, - input_specs, - expected_ops, - unexpected_ops=None, - rtol=1e-03, - atol=1e-03, - ): - mod.eval() - inputs = InputTensorSpec.create_inputs_from_specs(input_specs) - mod = acc_tracer.trace(mod, inputs) - interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 89cbe327a6290..420cf8738d475 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -34,11 +34,8 @@ _get_torch_cuda_version, _get_magma_version) from torch.testing._internal.common_utils import \ (is_iterable_of_tensors, - random_symmetric_matrix, random_symmetric_psd_matrix, make_fullrank_matrices_with_distinct_singular_values, - random_symmetric_pd_matrix, make_symmetric_matrices, - make_symmetric_pd_matrices, random_square_matrix_of_rank, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, IS_X86, TEST_SCIPY, + TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL, slowTest, noncontiguous_like, freeze_rng_state) @@ -47,6 +44,7 @@ import torch._refs as refs # noqa: F401 import torch._refs.nn.functional import torch._refs.special +import torch._refs.linalg import torch._prims as prims # noqa: F401 @@ -70,6 +68,7 @@ L = 20 M = 10 S = 5 +XS = 3 # Unique value to distinguish default from anything else _NOTHING = object() @@ -798,6 +797,8 @@ class OpInfo(object): supports_expanded_weight: bool = False + is_factory_function: bool = False + def __post_init__(self): self._original_opinfo_args = asdict(self).copy() @@ -1248,6 +1249,7 @@ def sample_inputs_sparse_csr_masked_reduction(op_info, device, dtype, requires_g with sparse csr layouts. """ if op_info.supports_sparse_csr: + op_name = op_info.name.replace('_masked.', '') for sample_input in sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs): if not (sample_input.input.ndim == 2 and sample_input.kwargs.get('keepdim')): # - sparse CSR tensors are always 2-D tensors @@ -1260,7 +1262,7 @@ def sample_inputs_sparse_csr_masked_reduction(op_info, device, dtype, requires_g new_sample = SampleInput(sample_input.input.to_sparse_csr(), args=sample_input.args, kwargs=sample_input_kwargs) else: - if op_info.name.lstrip('_masked.') in ['prod']: + if op_name in ['prod', 'amax', 'amin', 'mean']: # reductions with non-zero reduction identity and # unspecified mask is not supported for sparse CSR # tensors, see torch._masked.prod implementation @@ -1352,7 +1354,7 @@ class ReductionOpInfo(OpInfo): the optional keyword parameters of the ReductionOpInfo constructor. If a reduction operator does not yet implement the full required API of - reduction operators, this should be documented by skipping the failing + reduction operators, this should be documented by xfailing the failing tests rather than adding optional parameters to ReductionOpInfo. NOTE @@ -1454,32 +1456,25 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_input((S, S, S)), args=args) -def sample_inputs_linalg_det(op_info, device, dtype, requires_grad, **kwargs): - kw = dict(device=device, dtype=dtype) - inputs = [ - make_tensor((S, S), **kw), - make_tensor((1, 1), **kw), # 1x1 - random_symmetric_matrix(S, **kw), # symmetric - random_symmetric_psd_matrix(S, **kw), # symmetric_psd - random_symmetric_pd_matrix(S, **kw), # symmetric_pd - - random_square_matrix_of_rank(S, S - 2, **kw), # dim2_null - random_square_matrix_of_rank(S, 1, **kw), # rank1 - random_square_matrix_of_rank(S, 2, **kw), # rank2 - - make_fullrank_matrices_with_distinct_singular_values(S, S, **kw), # full rank - make_tensor((3, 3, S, S), **kw), # batched - make_tensor((3, 3, 1, 1), **kw), # batched_1x1 - random_symmetric_matrix(S, 3, **kw), # batched_symmetric - random_symmetric_psd_matrix(S, 3, **kw), # batched_symmetric_psd - random_symmetric_pd_matrix(S, 3, **kw), # batched_symmetric_pd - make_fullrank_matrices_with_distinct_singular_values(S, 3, 3, **kw), # batched fullrank - make_tensor((0, 0), **kw), - make_tensor((0, S, S), **kw), - ] - for t in inputs: - t.requires_grad = requires_grad - return [SampleInput(t) for t in inputs] +def sample_inputs_linalg_det_logdet_slogdet(op_info, device, dtype, requires_grad, **kwargs): + make_fullrank = make_fullrank_matrices_with_distinct_singular_values + make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad) + batches = [(), (0, ), (3, )] + ns = [0, 1, 5] + + is_logdet = (op_info.name == "logdet") + + for batch, n, in product(batches, ns): + shape = batch + (n, n) + A = make_arg(*shape) + # Need to make the matrices in A have positive determinant for autograd + # To do so, we multiply A by its determinant to flip the sign of its determinant + if is_logdet and not A.is_complex() and A.numel() > 0: + s = torch.linalg.slogdet(A).sign + A = A * s.unsqueeze(-1).unsqueeze(-1) + A.requires_grad_(requires_grad) + yield SampleInput(A) + def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype) @@ -1652,16 +1647,19 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwar return result def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs): + low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((2, 2), (2, 3, 2)) - ords = ('fro', 'nuc', inf, -inf, 1, -1, 2, -2) + if dtype in low_precision_dtypes: + # svdvals not supported for low precision dtypes + ords = ('fro', inf, -inf, 1, -1) + else: + ords = ('fro', 'nuc', inf, -inf, 1, -1, 2, -2) dims = ((-2, -1), (-1, 0)) - inputs: List[SampleInput] = [] for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]): - t = make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad) - inputs.append(SampleInput(t, args=(ord, dim, keepdim))) - - return inputs + yield SampleInput(make_arg(size), args=(ord, dim, keepdim)) def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad, *, variant=None, **kwargs): if variant is not None and variant not in ('subgradient_at_zero',): @@ -1916,7 +1914,11 @@ def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kw for shape in cases: yield SampleInput(make_arg(shape)) -def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs): + op_kwargs = op_info.sample_kwargs(device, dtype, None)[0] + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad, + op_kwargs=op_kwargs) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) cases = ( @@ -1935,9 +1937,44 @@ def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **k channel_size = shape[1] yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),)) weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,)) yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),)) +def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs) + yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs) + +def sample_kwargs_prelu_scalar_weight(device, dtype, input): + weight = torch.rand(tuple(), device=device, dtype=dtype) + # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case + if dtype == torch.bfloat16: + weight_cpu = weight.to(dtype=torch.float32, device="cpu") + else: + weight_cpu = weight.cpu() + np_weight = weight_cpu.numpy() + return ({'weight': weight}, {'weight': np_weight}) + +def error_inputs_prelu(op, device): + # Weight has numel != 1, but self.ndim is zero-dim tensor + inp = make_tensor(tuple(), device=device, dtype=torch.float32) + weight = make_tensor((2,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Not allow zero-dim input tensor.") + + # Weight has numel != 1, but numel does not match channel size + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((9,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Mismatch of parameter numbers and input channel size.") + + # Weight is neither a scalar nor 1-D tensor + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((2, 4), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = 2") + + # src and index tensors must have the same # of dimensions def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -2091,6 +2128,7 @@ def generate_elementwise_binary_arbitrarily_strided_tensors(op, *, device, dtype ((9, 5, 2), (0, 1, 7), 3), ) + make_arg = partial( make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, exclude_zero=exclude_zero ) @@ -2381,6 +2419,9 @@ def generate_elementwise_binary_noncontiguous_tensors( # Sample inputs for elementwise binary operators, like add def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + _M = S if kwargs.get("small_inputs_only", False) else M + _S = XS if kwargs.get("small_inputs_only", False) else S + if hasattr(op, "rhs_make_tensor_kwargs"): exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) @@ -2390,14 +2431,14 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) shapes = ( ((), ()), - ((S,), ()), - ((S, 1), (S,)), - ((M, S), ()), - ((S, M, S), (M, S)), - ((S, M, S), (S, M, S)), - ((M, 1, S), (M, S)), - ((M, 1, S), (1, M, S)), - ((0, 1, 3), (0, 10, 3)), + ((_S,), ()), + ((_S, 1), (_S,)), + ((_M, _S), ()), + ((_S, _M, _S), (_M, _S)), + ((_S, _M, _S), (_S, _M, _S)), + ((_M, 1, _S), (_M, _S)), + ((_M, 1, _S), (1, _M, _S)), + ((0, 1, XS), (0, _M, XS)), ) sample_kwargs = kwargs.get("sample_kwargs", {}) @@ -2647,6 +2688,8 @@ def sample_inputs_elementwise_unary( if not op_kwargs: op_kwargs = {} + _L = S if kwargs.get("small_inputs_only", False) else L + low, high = op_info.domain low = low if low is None else low + op_info._domain_eps high = high if high is None else high - op_info._domain_eps @@ -2654,7 +2697,7 @@ def sample_inputs_elementwise_unary( # Tensors with dim=2 for sparse compressed testing yield SampleInput( make_tensor( - (L, L), + (_L, _L), device=device, dtype=dtype, low=low, @@ -2665,7 +2708,7 @@ def sample_inputs_elementwise_unary( ) else: # Creates a 1D, empty, and scalar tensor - for shape in ((L,), (1, 0, 3), ()): + for shape in ((_L,), (1, 0, 3), ()): yield SampleInput( make_tensor( shape, @@ -2984,6 +3027,107 @@ def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): else: yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False}) +def error_inputs_arange(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer') + yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range') + yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range') + +def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs): + int_samples = ( + # positive direction + (-1, 2, 2), + # negative direction + (2, -3, -1), + # start == end + (1, 1, 1), + (1, 1, -1), + # divides evenly + (0, -8, -4), + (1, 5, 2), + # bool + (False, True, True), + # default step + (0, 1, None), + # default start + (None, 3, None), + ) + + def to_float(start, end, step): + start = start + 0.1 if start is not None else None + end = end + 0.1 + step = float(step) if step is not None else None + return start, end, step + + float_samples = ( + # includes endpoint + (0., -8. - 1e-6, -4.), + (1., 5. + 1e-6, 2.), + (0., -8., -4.), + (1., 5., 2.), + *(to_float(start, end, step) for (start, end, step) in int_samples), + ) + + large_samples = ( + (0, 10000, None), + ) + + samples = int_samples + float_samples + if dtype not in (torch.int8, torch.uint8): + samples += large_samples + + for start, end, step in samples: + if start is None: + assert step is None + yield SampleInput(end, kwargs={"dtype": dtype, "device": device}) + elif step is None: + yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(2) + yield SampleInput(1, args=(3, 1)) + + +def error_inputs_linspace(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative') + yield ErrorInput(SampleInput(0, args=(3, 1.)), error_type=TypeError, error_regex='must be int, not float') + + +def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)] + for start, end, nstep in cases: + if dtype == torch.uint8 and end < 0 or start < 0: + continue + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_logpace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + for start, end, nstep, base in product(starts, ends, nsteps, bases): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1, 2.)) + + def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs): yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) @@ -3015,6 +3159,17 @@ def error_inputs_isclose(op, device, **kwargs): error_regex='atol must be greater than or equal to zero') +def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batches = ((), (0,), (1,), (5,)) + ns = (0, 1, 3, 5) + for b, n in product(batches, ns): + shape = b + (n,) + yield SampleInput(make_arg(shape), args=(make_arg(shape),)) + for i in range(len(shape)): + yield SampleInput(make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)) + + def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) return (SampleInput(make_arg((1, 2))), @@ -4577,11 +4732,11 @@ def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): ((S, S), 1, None, None), ((S, S), 0, (1, S), (2, S)), ((S, S), 0, None, (2, S)), - ((S, S, S), 1, None, None), - ((S, S, S), 2, None, None), - ((S, S, S), 1, (S, 1, S), (S, 1, S)), - ((S, S, S), 2, (S, S, 1), (S, S, 1)), - ((S, S, S), 2, (S, S, S), (S, S, S)),) + ((XS, XS, XS), 1, None, None), + ((XS, XS, XS), 2, None, None), + ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)), + ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)), + ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),) sample_inputs = [] for size, dim, size_prepend, size_append in test_cases: @@ -4595,8 +4750,8 @@ def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): sample_inputs.append(SampleInput(input_tensor, args=(n, dim, prepend, append,))) # add some samples with n > dim_size - sample_inputs.append(SampleInput(make_arg((S, S, S)), args=(S + 1, 1,))) - sample_inputs.append(SampleInput(make_arg((S, S, S)), args=(S * 3 + 2, 2, make_arg((S, S, S)), make_arg((S, S, S)),))) + sample_inputs.append(SampleInput(make_arg((XS, XS, XS)), args=(S + 1, 1,))) + sample_inputs.append(SampleInput(make_arg((XS, XS, XS)), args=(S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS)),))) return sample_inputs @@ -5669,8 +5824,10 @@ def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs): yield from _generate_reduction_inputs(device, dtype, requires_grad) - yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) - yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) + # NaN only exists for floating point numbers + if dtype.is_complex or dtype.is_floating_point: + yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) + yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) def sample_inputs_nan_reduction(supports_multiple_dims): # Generates sample inputs for reduction ops that contain the input tensor @@ -5995,6 +6152,8 @@ def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): ((), (),), ((), (0,),), ((), (0, True,),), + # Non-fused mode kernel on CUDA + ((3000,), ()), ) inputs = list((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device, low=None, high=None, @@ -6311,57 +6470,25 @@ def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs): yield SampleInput(make_inp(shape), args=(pad, mode, pad_value)) -# TODO: reconcile with torch.linalg.det and torch.linalg.slogdet -# Creates matrices with a positive nonzero determinant -def sample_inputs_logdet(op_info, device, dtype, requires_grad, **kwargs): - def make_nonzero_det(A, *, sign=1, min_singular_value=0.1, **kwargs): - u, s, vh = torch.linalg.svd(A, full_matrices=False) - s.clamp_(min=min_singular_value) - A = (u * s.unsqueeze(-2)) @ vh - det = A.det() - if sign is not None: - if A.dim() == 2: - if (det < 0) ^ (sign < 0): - A[0, :].neg_() - else: - cond = ((det < 0) ^ (sign < 0)).nonzero() - if cond.size(0) > 0: - for i in range(cond.size(0)): - A[list(cond[i])][0, :].neg_() - return A - - # cases constructed using make_tensor() - tensor_shapes = ( - (S, S), - (1, 1), - (3, 3, S, S), - (3, 3, 1, 1) - ) - - for shape in tensor_shapes: - t = make_tensor(shape, device=device, dtype=dtype) - d = make_nonzero_det(t).requires_grad_(requires_grad) - yield SampleInput(d) - - # cases constructed using: - # 1) make_symmetric_matrices - # 2) make_symmetric_pd_matrices - # 3) make_fullrank_matrices_with_distinct_singular_values - symmetric_shapes = ( - (S, S), - (3, S, S), - ) +def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs): + # Inherit sample inputs from nn.pad, but transform them to fit + # constant_pad_nd's interface + nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args, + mode='constant', **kwargs) + # NOTE: primTorch is more strict about the type of the fill value argument + # So we must cast it to the correct dtype + from torch._prims_common import dtype_to_type + scalar_type = dtype_to_type(dtype) - def _helper(constructor, *shape, **kwargs): - t = constructor(*shape, device=device, dtype=dtype) - d = make_nonzero_det(t, **kwargs).requires_grad_(requires_grad) - yield SampleInput(d) + def drop_mode_argument(input, pad, mode=None, value=None): + if value is None: + return SampleInput(input, args=(pad,)) + else: + return SampleInput(input, args=(pad, scalar_type(value))) - for shape in symmetric_shapes: - _helper(make_symmetric_matrices, *shape) - _helper(make_symmetric_pd_matrices, *shape) - _helper(make_fullrank_matrices_with_distinct_singular_values, *shape, min_singular_value=0) + for sample in nn_samples: + yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs) def np_unary_ufunc_integer_promotion_wrapper(fn): @@ -6402,17 +6529,26 @@ def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwarg # cuFFT supports powers of 2 for half and complex half precision # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two - if self.name in ['fft.hfft', 'fft.irfft']: + low = None + high = None + if self.name in ['fft.hfft', 'fft.irfft', + '_refs.fft.hfft', '_refs.fft.irfft']: shapes = ((2, 9, 9), (33,)) - elif self.name in ['fft.hfft2', 'fft.irfft2']: + elif self.name in ['fft.hfft2', 'fft.irfft2', + '_refs.fft.hfft2', '_refs.fft.irfft2']: shapes = ((2, 8, 9), (33,)) - elif self.name in ['fft.hfftn', 'fft.irfftn']: + elif self.name in ['fft.hfftn', 'fft.irfftn', + '_refs.fft.hfftn', '_refs.fft.irfftn']: shapes = ((2, 2, 33), (33,)) + # Adjusting the limits because the test would be flaky due to over-saturation of float16 + # See: https://github.com/pytorch/pytorch/pull/81416 + low = -1.0 + high = 1.0 else: shapes = ((2, 8, 16), (32,)) - nd_tensor = partial(make_tensor, shapes[0], device=device, + nd_tensor = partial(make_tensor, shapes[0], device=device, low=low, high=high, dtype=dtype, requires_grad=requires_grad) - oned_tensor = partial(make_tensor, shapes[1], device=device, + oned_tensor = partial(make_tensor, shapes[1], device=device, low=low, high=high, dtype=dtype, requires_grad=requires_grad) if self.ndimensional == SpectralFuncType.ND: @@ -6485,6 +6621,10 @@ def __init__(self, sample_inputs_func=sample_inputs_spectral_ops, decorators=None, **kwargs): + + self._original_spectral_func_args = dict(locals()).copy() + self._original_spectral_func_args.update(kwargs) + decorators = list(decorators) if decorators is not None else [] decorators += [ skipCPUIfNoFFT, @@ -6878,16 +7018,6 @@ def out_fn(output): yield sample -def sample_inputs_linalg_slogdet(op_info, device, dtype, requires_grad=False, **kwargs): - def out_fn(output): - return output[1] - - samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad) - for sample in samples: - sample.output_process_fn_grad = out_fn - yield sample - - def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs): """ This function generates input for torch.linalg.pinv with hermitian=False keyword argument. @@ -7235,7 +7365,7 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs): make_fullrank = make_fullrank_matrices_with_distinct_singular_values make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad) - is_linalg_svd = (op_info.name == "linalg.svd") + is_linalg_svd = ("linalg.svd" in op_info.name) batches = [(), (0, ), (3, )] ns = [0, 3, 5] @@ -7318,12 +7448,41 @@ def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, ** for batch, m, n in product(batches, ns, ns): yield SampleInput(make_arg(batch + (m, n))) +def error_inputs_softshrink(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}), + error_regex="lambda must be greater or equal to 0, but found to be -0.5") -def sample_inputs_softshrink_hardshrink_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): - N = 10 - tensors = [SampleInput(make_tensor((N, N), device=device, dtype=dtype, - requires_grad=requires_grad)) for _ in range(1, N)] - return tensors +def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for lbda in (0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + # Note that unlike softshrink, lambd is allowed to be negative for hardshrink + for lbda in (-0.5, 0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + + +def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of min_val and max_val beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for max_val, min_val in ((-0.5, 0.5), (0.5, -0.5), (0., 0.)): + yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) def sample_inputs_eig(op_info, device, dtype, requires_grad=False, **kwargs): eigvecs = make_tensor((S, S), device=device, dtype=dtype, @@ -7523,6 +7682,14 @@ def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sa yield SampleInput(a, args=(b, c)) +def _clamp_min_numpy(a, min=None): + return np.maximum(a, min) + + +def _clamp_max_numpy(a, max=None): + return np.minimum(a, max) + + def _clamp_numpy(a, min=None, max=None): if min is None: return np.minimum(a, max) @@ -7531,39 +7698,6 @@ def _clamp_numpy(a, min=None, max=None): return np.minimum(max, np.maximum(a, min)) - -def sample_inputs_clamp_scalar(op_info, device, dtype, requires_grad, **kwargs): - tensors = ( - make_tensor((2, 3, 2), dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), - make_tensor((2, 0, 3), dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), - ) - - if dtype is torch.uint8: - min_max_vals = ((2, 5), (3, 7)) - else: - min_max_vals = ((0, 1), (-1, 1)) - - output = [SampleInput( - tensor.clone().requires_grad_(requires_grad), - args=vals) for tensor, vals in product(tensors, min_max_vals)] - output += [ - SampleInput(tensors[0].clone().requires_grad_(requires_grad), - args=(0.5, None)), - SampleInput(tensors[0].clone().requires_grad_(requires_grad), - args=(None, 0.5))] - empty_tensor = make_tensor((), device=device, dtype=dtype, low=None, high=None, requires_grad=requires_grad) - output.append(SampleInput(empty_tensor, args=(0.0, 1.0))) - return output - -def sample_kwargs_clamp_scalar(device, dtype, input): - if dtype is torch.uint8: - min_val, max_val = (random.randint(1, 3), random.randint(4, 8)) - elif dtype.is_floating_point: - min_val, max_val = (random.uniform(-8, 0), random.uniform(1, 8)) # type: ignore[assignment] - else: - min_val, max_val = (random.randint(-8, 0), random.randint(1, 8)) - return {'min': min_val, 'max': max_val}, {'a_min': min_val, 'a_max': max_val} - def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs): sample0 = SampleInput(make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad), args=(make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),)) @@ -7673,10 +7807,10 @@ def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **k make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) # Shapes for 2D Tensors - shapes_2d = ((M, M), (3, 5), (5, 3)) + shapes_2d = ((S, S), (3, 5), (5, 3)) # Shapes for 3D Tensors - shapes_3d = ((M, M, M),) + shapes_3d = ((S, S, S),) kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1)) kwargs_3d = (dict(offset=1, dim1=1, dim2=2), @@ -8519,14 +8653,16 @@ def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(shape), args=args) -def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) yield SampleInput(make_arg((S, M, S))) yield SampleInput(make_arg(())) -def reference_inputs_clone(op, device, dtype, requires_grad, **kwargs): - yield from sample_inputs_clone(op, device, dtype, requires_grad, **kwargs) +def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs): + # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format + # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous + yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs) shapes = ( (3, 5, 6), @@ -8544,6 +8680,11 @@ def reference_inputs_clone(op, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(shape, noncontiguous=True)) yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + # shape, strides, offset strided_cases = ( ((5, 6, 2), (1, 1, 7), 2), @@ -8556,11 +8697,17 @@ def reference_inputs_clone(op, device, dtype, requires_grad, **kwargs): for shape, strides, offset in strided_cases: yield SampleInput(make_arg(500,).as_strided(shape, strides, offset)) + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format}) + # channels last 2D + yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last}) + a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last}) -def sample_inputs_contiguous(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) - yield SampleInput(make_arg((S, S))) + # channels last 3D + yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d}) + a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d}) def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs): @@ -8910,6 +9057,13 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): yield SampleInput(a, args=(c, b)) + # two python scalars + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((1,)).item() + b = make_arg((1,)).item() + + yield SampleInput(a, args=(c, b)) + # NaN propagation if dtype.is_floating_point or dtype.is_complex: if dtype.is_floating_point: @@ -8925,6 +9079,11 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): yield SampleInput(a, args=(c, b)) + # Python scalars type promotion + for scalar in (0, 0.0, 2j, False): + yield SampleInput(scalar, args=(c, b)) + yield SampleInput(a, args=(c, scalar)) + def error_inputs_where(op_info, device, **kwargs): shape = (S,) @@ -10489,12 +10648,11 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), skips=( - # Inplace abs doesn't support complex inputs - DecorateInfo(unittest.expectedFailure, 'TestGradients', + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestGradients', 'test_inplace_grad', dtypes=(torch.cdouble,)), - DecorateInfo(unittest.expectedFailure, 'TestGradients', + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestGradients', 'test_inplace_gradgrad', dtypes=(torch.cdouble,)), - DecorateInfo(unittest.expectedFailure, 'TestGradients', + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestGradients', 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -10508,12 +10666,6 @@ def error_inputs_mean(op_info, device, **kwargs): # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes', dtypes=[torch.cfloat, torch.cdouble]), - # The complex formula might be wrong - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', - dtypes=complex_types()), - # Forward-over-reverse gradgrad might be wrong for complex (see above): - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=complex_types()), ), supports_fwgrad_bwgrad=True, assert_autodiffed=True, @@ -10528,7 +10680,7 @@ def error_inputs_mean(op_info, device, **kwargs): ref=np.arccos, domain=(-1, 1), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -10561,17 +10713,14 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=[torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD', - dtypes=[torch.cdouble], active_if=IS_WINDOWS), - )), + dtypes=[torch.cdouble], active_if=IS_WINDOWS),)), # NOTE: the derivative for inplace acosh is not implemented UnaryUfuncInfo('acosh', aliases=('arccosh', ), ref=np.arccosh, domain=(1, None), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), - # "rsqrt_cuda" not implemented for 'BFloat16' - backward_dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), supports_inplace_autograd=False, supports_forward_ad=True, @@ -10595,17 +10744,10 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - # Reference: https://github.com/pytorch/pytorch/issues/50692 - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', - device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_method_grad', - device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', - device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), ), - # acosh is not defined at x < 1 (real) or |z| < 1 (complex) + # acosh is not defined at x < 1 (real) reference_numerics_filter=NumericsFilter( - condition=lambda x: (torch.abs(x) < 1 if x.is_complex() else x < 1), + condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)), safe_val=2)), BinaryUfuncInfo('add', # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate @@ -10643,6 +10785,82 @@ def error_inputs_mean(op_info, device, **kwargs): 'test_reference_numerics_extremal_values', dtypes=(torch.complex64, torch.complex128)), )), + OpInfo('arange', + dtypes=all_types_and(torch.bfloat16, torch.float16), + supports_out=True, + supports_autograd=False, + is_factory_function=True, + error_inputs_func=error_inputs_arange, + sample_inputs_func=sample_inputs_arange, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Lazy tensor failures + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + + # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608 + # We don't have an op for aten::arange but it isn't a special case. + # Argument types: bool, bool, bool, int, int, Device, boo + DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + + # Captured graph does not contain aten::arange (succeeds on complex!) + # g: graph(): + # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # return (%25) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + BinaryUfuncInfo('clamp_max', + ref=_clamp_max_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('clamp_min', + ref=_clamp_min_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), BinaryUfuncInfo('mul', aliases=('multiply',), dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), @@ -10729,6 +10947,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if (SM53OrLater and CUDA11OrLater) or TEST_WITH_ROCM else []), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=[ @@ -10752,6 +10972,8 @@ def error_inputs_mean(op_info, device, **kwargs): backward_dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [], torch.complex64, torch.complex128), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=[ @@ -10881,7 +11103,7 @@ def error_inputs_mean(op_info, device, **kwargs): aliases=('arcsinh', ), ref=np.arcsinh, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), supports_inplace_autograd=False, supports_forward_ad=True, @@ -10960,7 +11182,7 @@ def error_inputs_mean(op_info, device, **kwargs): ref=np.arctanh, domain=(-1, 1), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.bfloat16: 1e-2}),), supports_inplace_autograd=False, supports_forward_ad=True, @@ -11148,6 +11370,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('cholesky_inverse', dtypes=floating_and_complex_types(), backward_dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_fwgrad_bwgrad=True, supports_forward_ad=True, check_batched_gradgrad=True, @@ -11176,15 +11400,21 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('clone', ref=np.copy, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - sample_inputs_func=sample_inputs_clone, - reference_inputs_func=reference_inputs_clone, + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - supports_out=False), + supports_out=False, + skips=( + # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format' + # (NumPy reference needs to be extended with memory_format) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'), + ),), OpInfo('contiguous', op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - sample_inputs_func=sample_inputs_contiguous, + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, supports_forward_ad=True, supports_fwgrad_bwgrad=True, autodiff_fusible_nodes=['aten::contiguous'], @@ -11219,7 +11449,6 @@ def error_inputs_mean(op_info, device, **kwargs): device_type='mps', dtypes=[torch.float32]), ), decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off]), - # NOTE: clamp has separate opinfos for scalar min/max (unary op) vs. tensors OpInfo('clamp', aliases=('clip',), ref=_clamp_numpy, @@ -11259,6 +11488,7 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_forward_grad=False, supports_out=False), UnaryUfuncInfo('conj_physical', + decomp_aten_name='_conj_physical', ref=np.conj, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), @@ -11327,6 +11557,8 @@ def error_inputs_mean(op_info, device, **kwargs): BinaryUfuncInfo('copysign', dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True), OpInfo('corrcoef', @@ -11338,6 +11570,10 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + ), supports_out=False), UnaryUfuncInfo('cos', ref=np.cos, @@ -11408,6 +11644,12 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), # Float did not match double DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'), # Jacobian mismatch @@ -11471,8 +11713,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', - device_type='cpu'), ), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('cummin', @@ -11482,8 +11722,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', - device_type='cpu'), ), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), UnaryUfuncInfo('deg2rad', @@ -11506,6 +11744,8 @@ def error_inputs_mean(op_info, device, **kwargs): np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append) ), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diff, @@ -11518,6 +11758,8 @@ def error_inputs_mean(op_info, device, **kwargs): variant_test_name='no_rounding_mode', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, promotes_int_to_float=True, supports_fwgrad_bwgrad=True, @@ -11529,6 +11771,8 @@ def error_inputs_mean(op_info, device, **kwargs): variant_test_name='trunc_rounding', dtypes=all_types_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, promotes_int_to_float=True, supports_fwgrad_bwgrad=True, @@ -11544,6 +11788,8 @@ def error_inputs_mean(op_info, device, **kwargs): variant_test_name='floor_rounding', dtypes=all_types_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, promotes_int_to_float=True, supports_fwgrad_bwgrad=True, @@ -11622,6 +11868,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('diag_embed', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), supports_out=False, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_diag_embed), @@ -11673,6 +11921,8 @@ def error_inputs_mean(op_info, device, **kwargs): ref=np.fmod, dtypes=all_types_and(torch.float16, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, assert_autodiffed=None, @@ -11695,6 +11945,8 @@ def error_inputs_mean(op_info, device, **kwargs): ref=np.remainder, dtypes=all_types_and(torch.float16, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, assert_autodiffed=None, @@ -11736,6 +11988,7 @@ def error_inputs_mean(op_info, device, **kwargs): )), SpectralFuncInfo('fft.fft', aten_name='fft_fft', + decomp_aten_name='_fft_c2c', ref=np.fft.fft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), @@ -11743,6 +11996,8 @@ def error_inputs_mean(op_info, device, **kwargs): # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11751,12 +12006,15 @@ def error_inputs_mean(op_info, device, **kwargs): SpectralFuncInfo('fft.fft2', aten_name='fft_fft2', ref=np.fft.fft2, + decomp_aten_name='_fft_c2c', ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), # rocFFT doesn't support Half/Complex Half Precision FFT # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11766,6 +12024,7 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.fftn', aten_name='fft_fftn', + decomp_aten_name='_fft_c2c', ref=np.fft.fftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), @@ -11773,6 +12032,8 @@ def error_inputs_mean(op_info, device, **kwargs): # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11782,6 +12043,7 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.hfft', aten_name='fft_hfft', + decomp_aten_name='_fft_c2r', ref=np.fft.hfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and_complex_and(torch.bool), @@ -11789,6 +12051,8 @@ def error_inputs_mean(op_info, device, **kwargs): # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11796,6 +12060,7 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_gradgrad=False), SpectralFuncInfo('fft.hfft2', aten_name='fft_hfft2', + decomp_aten_name='_fft_c2r', ref=scipy.fft.hfft2 if has_scipy_fft else None, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and_complex_and(torch.bool), @@ -11803,6 +12068,8 @@ def error_inputs_mean(op_info, device, **kwargs): # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, @@ -11815,6 +12082,7 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.hfftn', aten_name='fft_hfftn', + decomp_aten_name='_fft_c2r', ref=scipy.fft.hfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.ND, dtypes=all_types_and_complex_and(torch.bool), @@ -11822,6 +12090,8 @@ def error_inputs_mean(op_info, device, **kwargs): # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and_complex_and( torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_gradgrad=False, @@ -11834,12 +12104,15 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.rfft', aten_name='fft_rfft', + decomp_aten_name='_fft_r2c', ref=np.fft.rfft, ndimensional=SpectralFuncType.OneD, dtypes=all_types_and(torch.bool), # rocFFT doesn't support Half/Complex Half Precision FFT # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -11848,12 +12121,15 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_gradgrad=False), SpectralFuncInfo('fft.rfft2', aten_name='fft_rfft2', + decomp_aten_name='_fft_r2c', ref=np.fft.rfft2, ndimensional=SpectralFuncType.TwoD, dtypes=all_types_and(torch.bool), # rocFFT doesn't support Half/Complex Half Precision FFT # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -11863,12 +12139,15 @@ def error_inputs_mean(op_info, device, **kwargs): ],), SpectralFuncInfo('fft.rfftn', aten_name='fft_rfftn', + decomp_aten_name='_fft_r2c', ref=np.fft.rfftn, ndimensional=SpectralFuncType.ND, dtypes=all_types_and(torch.bool), # rocFFT doesn't support Half/Complex Half Precision FFT # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs dtypesIfCUDA=all_types_and(torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half,)), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_grad=False, @@ -11878,8 +12157,11 @@ def error_inputs_mean(op_info, device, **kwargs): ],), SpectralFuncInfo('fft.ifft', aten_name='fft_ifft', + decomp_aten_name='_fft_c2c', ref=np.fft.ifft, ndimensional=SpectralFuncType.OneD, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11891,8 +12173,11 @@ def error_inputs_mean(op_info, device, **kwargs): torch.bool, *() if (TEST_WITH_ROCM or not SM53OrLater) else (torch.half, torch.complex32)),), SpectralFuncInfo('fft.ifft2', aten_name='fft_ifft2', + decomp_aten_name='_fft_c2c', ref=np.fft.ifft2, ndimensional=SpectralFuncType.TwoD, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11909,8 +12194,11 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.ifftn', aten_name='fft_ifftn', + decomp_aten_name='_fft_c2c', ref=np.fft.ifftn, ndimensional=SpectralFuncType.ND, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11927,6 +12215,7 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.ihfft', aten_name='fft_ihfft', + decomp_aten_name='_fft_r2c', ref=np.fft.ihfft, ndimensional=SpectralFuncType.OneD, supports_forward_ad=True, @@ -11942,8 +12231,11 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_grad=False), SpectralFuncInfo('fft.ihfft2', aten_name='fft_ihfft2', + decomp_aten_name='_fft_r2c', ref=scipy.fft.ihfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.TwoD, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11963,8 +12255,11 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warnings'))), SpectralFuncInfo('fft.ihfftn', aten_name='fft_ihfftn', + decomp_aten_name='_fft_r2c', ref=scipy.fft.ihfftn if has_scipy_fft else None, ndimensional=SpectralFuncType.ND, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -11986,8 +12281,11 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.irfft', aten_name='fft_irfft', + decomp_aten_name='_fft_c2r', ref=np.fft.irfft, ndimensional=SpectralFuncType.OneD, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -12000,8 +12298,11 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_gradgrad=False), SpectralFuncInfo('fft.irfft2', aten_name='fft_irfft2', + decomp_aten_name='_fft_c2r', ref=np.fft.irfft2, ndimensional=SpectralFuncType.TwoD, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -12019,8 +12320,11 @@ def error_inputs_mean(op_info, device, **kwargs): ), SpectralFuncInfo('fft.irfftn', aten_name='fft_irfftn', + decomp_aten_name='_fft_c2r', ref=np.fft.irfftn, ndimensional=SpectralFuncType.ND, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # See https://github.com/pytorch/pytorch/pull/78358 @@ -12037,14 +12341,14 @@ def error_inputs_mean(op_info, device, **kwargs): 'TestFFT', 'test_reference_nd')], ), OpInfo('fft.fftshift', - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), sample_inputs_func=sample_inputs_fftshift, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), OpInfo('fft.ifftshift', - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), sample_inputs_func=sample_inputs_fftshift, supports_out=False, supports_forward_ad=True, @@ -12058,6 +12362,8 @@ def error_inputs_mean(op_info, device, **kwargs): ], dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_stft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, @@ -12069,6 +12375,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('istft', dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_istft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, @@ -12084,6 +12392,12 @@ def error_inputs_mean(op_info, device, **kwargs): # gradcheck fails on ROCm (gh-68429) # grad is computed improperly (probably for weights tensor) DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), )), UnaryUfuncInfo('floor', ref=np.floor, @@ -12380,9 +12694,10 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.linalg.det, aliases=('det',), dtypes=floating_and_complex_types(), - backward_dtypes=floating_and_complex_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, aten_name='linalg_det', - sample_inputs_func=sample_inputs_linalg_det, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)}))], check_batched_gradgrad=False, @@ -12393,6 +12708,8 @@ def error_inputs_mean(op_info, device, **kwargs): aliases=('det',), dtypes=double_types(), backward_dtypes=double_types(), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, aten_name='linalg_det', sample_inputs_func=sample_inputs_linalg_det_singular, decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, @@ -12400,9 +12717,7 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_gradgrad=False, supports_inplace_autograd=False, skips=( - # These tests started breaking after touching the SVD. - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', device_type='cpu', - dtypes=(torch.complex128,), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), "TestGradients", 'test_fn_fwgrad_bwgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), # dtypes are tested in the suite above, no need to repeat it for singular DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), @@ -12416,10 +12731,6 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_forward_grad=False, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - skips=( - # Strides are not the same! - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - ), decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],), OpInfo('linalg.cholesky_ex', aten_name='linalg_cholesky_ex', @@ -12430,12 +12741,18 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_forward_grad=False, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - skips=( - # AssertionError: Scalars are not equal! - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - ), decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], ), + OpInfo('linalg.vecdot', + aten_name='linalg_vecdot', + ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim), + dtypes=floating_and_complex_types_and(torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, + *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), + sample_inputs_func=sample_inputs_linalg_vecdot, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), OpInfo('linalg.cond', aten_name='linalg_cond', dtypes=floating_and_complex_types(), @@ -12459,8 +12776,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( # AssertionError: Scalars are not equal! DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), - # Forward-over-reverse gradgrad might be incorrect - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -12505,9 +12820,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off], skips=( - # Forward-over-reverse gradgrad might be incorrect - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=complex_types()), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -12540,6 +12852,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.linalg.householder_product, aliases=('orgqr', ), dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, # TODO: backward uses in-place operations that vmap doesn't like check_batched_grad=False, check_batched_gradgrad=False, @@ -12587,8 +12901,6 @@ def error_inputs_mean(op_info, device, **kwargs): # we skip gradient checks for this suite as they are tested in # variant_test_name='grad_oriented' DecorateInfo(unittest.skip("Skipped!"), 'TestGradients'), - # At this time ROCm uses magma instead of rocSolver, and the test passes - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', active_if=(not TEST_WITH_ROCM)), # The values for attribute 'shape' do not match DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', @@ -12607,6 +12919,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_linalg_lstsq, error_inputs_func=error_inputs_lstsq, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_autograd=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -12615,15 +12929,13 @@ def error_inputs_mean(op_info, device, **kwargs): # tests do not work with passing lambda for op DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), - # At this time ROCm uses magma instead of rocSolver, and the test passes - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', active_if=(not TEST_WITH_ROCM)), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad', - active_if=(not TEST_WITH_ROCM)), )), OpInfo('linalg.matrix_power', aliases=('matrix_power',), aten_name='linalg_matrix_power', dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_inplace_autograd=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -12663,20 +12975,14 @@ def error_inputs_mean(op_info, device, **kwargs): )), # NB: linalg.norm has two variants so that different skips can be used for different sample inputs OpInfo('linalg.norm', + aten_name='linalg_norm', op=torch.linalg.norm, dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], sample_inputs_func=sample_inputs_linalg_norm, supports_forward_ad=True, - # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: - # Could not allocate memory to change Tensor SizesAndStrides! check_batched_forward_grad=False, - supports_fwgrad_bwgrad=True, - aten_name='linalg_norm', - skips=( - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), - )), + supports_fwgrad_bwgrad=True), OpInfo('linalg.norm', op=torch.linalg.norm, variant_test_name='subgradients_at_zero', @@ -12697,8 +13003,11 @@ def error_inputs_mean(op_info, device, **kwargs): )), OpInfo('linalg.matrix_norm', aten_name='linalg_matrix_norm', - dtypes=floating_and_complex_types(), + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, check_batched_gradgrad=False, + supports_fwgrad_bwgrad=True, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], sample_inputs_func=sample_inputs_linalg_matrix_norm), OpInfo('linalg.qr', @@ -12715,8 +13024,10 @@ def error_inputs_mean(op_info, device, **kwargs): aten_name='linalg_slogdet', op=torch.linalg.slogdet, dtypes=floating_and_complex_types(), - sample_inputs_func=sample_inputs_linalg_slogdet, - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack]), OpInfo('linalg.vander', aten_name='linalg_vander', ref=np_vander_batched, @@ -12725,6 +13036,10 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + ), sample_inputs_func=sample_inputs_linalg_vander), ReductionOpInfo( 'linalg.vector_norm', @@ -12742,32 +13057,113 @@ def error_inputs_mean(op_info, device, **kwargs): generate_args_kwargs=sample_kwargs_vector_norm, aten_name='linalg_vector_norm', skips=( - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), # FIXME: sum reduces all dimensions when dim=[] DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), )), - UnaryUfuncInfo('log', - ref=np.log, - domain=(0, None), - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), - assert_autodiffed=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - decorators=(precisionOverride({torch.bfloat16: 5e-2}),), - skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', - device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], - active_if=IS_WINDOWS), - ), - # log(z)->-inf for |z|->0 - reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), - UnaryUfuncInfo('log10', - ref=np.log10, - domain=(0, None), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logpace, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + UnaryUfuncInfo('log', + ref=np.log, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log10', + ref=np.log10, + domain=(0, None), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), assert_autodiffed=True, @@ -12817,6 +13213,8 @@ def error_inputs_mean(op_info, device, **kwargs): reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), BinaryUfuncInfo('ldexp', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_inplace_autograd=False, @@ -12891,6 +13289,9 @@ def error_inputs_mean(op_info, device, **kwargs): aten_name='linalg_lu_factor', op=torch.linalg.lu_factor, dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_linalg_lu, @@ -12899,6 +13300,8 @@ def error_inputs_mean(op_info, device, **kwargs): aten_name='linalg_lu_factor_ex', op=torch.linalg.lu_factor_ex, dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_linalg_lu, @@ -12907,6 +13310,9 @@ def error_inputs_mean(op_info, device, **kwargs): aten_name='linalg_lu', op=torch.linalg.lu, dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_linalg_lu, @@ -12914,6 +13320,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('lu_unpack', op=torch.lu_unpack, dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=(skipCPUIfNoLapack,), @@ -12921,6 +13329,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('lu', op=torch.lu, dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # https://github.com/pytorch/pytorch/issues/66357 @@ -12963,6 +13373,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.linalg.lu_solve, aten_name='linalg_lu_solve', dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, @@ -12988,8 +13400,6 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_forward_grad=False, supports_out=False, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', - 'test_non_standard_bool_values', device_type='cuda'), )), OpInfo('masked_select', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), @@ -13011,8 +13421,6 @@ def error_inputs_mean(op_info, device, **kwargs): # https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), # times out DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'), ), @@ -13026,6 +13434,8 @@ def error_inputs_mean(op_info, device, **kwargs): if (SM53OrLater and CUDA11OrLater) or TEST_WITH_ROCM else []), assert_autodiffed=True, assert_jit_shape_analysis=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, @@ -13046,6 +13456,8 @@ def error_inputs_mean(op_info, device, **kwargs): 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), ], skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), # https://github.com/pytorch/pytorch/issues/67470 DecorateInfo(unittest.skip("67470!"), 'TestCommon', 'test_noncontiguous_samples', @@ -13063,8 +13475,6 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', - device_type='cpu', active_if=not (IS_MACOS and IS_X86)), ), supports_forward_ad=True), OpInfo('max', @@ -13094,57 +13504,20 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), OpInfo('var_mean', dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False), - backward_dtypes=floating_types_and(torch.half, torch.bfloat16), - backward_dtypesIfCUDA=floating_types_and(torch.half), + sample_inputs_func=sample_inputs_std_var, # TODO: some signatures of var_mean do support out supports_out=False, supports_forward_ad=True, - supports_fwgrad_bwgrad=False, # Need: var_mean - skips=( - # var_mean does not support automatic differentiation for outputs with complex dtype - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), - # https://github.com/pytorch/pytorch/issues/67539 - DecorateInfo(unittest.skip("67539"), 'TestCommon', 'test_noncontiguous_samples', - active_if=TEST_WITH_ASAN, device_type='cpu'), - # TODO: FIXME: complex inputs requiring grad error in forward - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), - # TODO: review with var_mean tests in test_autograd.py - DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), - DecorateInfo(unittest.skip("Fails on ASAN!"), 'TestCompositeCompliance', 'test_backward'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'), - # Division by zero, may be related to above? - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad'))), + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True), OpInfo('std_mean', dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False), - backward_dtypes=floating_types_and(torch.half, torch.bfloat16), - backward_dtypesIfCUDA=floating_types_and(torch.half), + sample_inputs_func=sample_inputs_std_var, # TODO: some signatures of std_mean do support out supports_out=False, - supports_forward_ad=True, # Supports only certain variants? - supports_fwgrad_bwgrad=False, # Need: std_mean - skips=( - DecorateInfo(unittest.skip("ASAN: division by zero!"), active_if=TEST_WITH_ASAN), - # std_mean does not support forward when complex inputs require grad - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), - # https://github.com/pytorch/pytorch/issues/67539 - DecorateInfo(unittest.skip("67539"), 'TestCommon', 'test_noncontiguous_samples', - active_if=TEST_WITH_ASAN, device_type='cpu'), - # TODO: fix along with var_mean autograd tests - DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), - DecorateInfo(unittest.skip("Fails on ASAN!"), 'TestCompositeCompliance', 'test_backward'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'), - # Division by zero, may be related to above? - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad'))), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True), OpInfo('meshgrid', variant_test_name='variadic_tensors', ref=np.meshgrid, @@ -13195,8 +13568,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, supports_forward_ad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', - 'test_non_standard_bool_values', device_type='cpu'), )), OpInfo('min', variant_test_name='reduction_no_dim', @@ -13213,7 +13584,12 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_operator'), DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'), ), # See https://github.com/pytorch/pytorch/issues/66357 @@ -13226,7 +13602,12 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip('Skipped!'), 'TestCompositeCompliance', 'test_operator'), DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'), ), # See https://github.com/pytorch/pytorch/issues/66357 @@ -13482,11 +13863,7 @@ def error_inputs_mean(op_info, device, **kwargs): dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_normalize, supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), - )), + supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), dtypes=all_types_and(torch.bool), @@ -13498,10 +13875,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( # AssertionError: Resizing an out= argument with no elements threw a resize warning! DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), - DecorateInfo(unittest.skip('Fails on clang'), 'TestCommon', - 'test_non_standard_bool_values', device_type='cpu'), - # AssertionError: Shapes torch.Size([]) and torch.Size([1]) are not equal! - DecorateInfo(unittest.expectedFailure, 'TestFakeTensorNonErroring', 'test_fake'), )), OpInfo('as_strided', op=lambda x, size, stride, storage_offset=0: @@ -13519,9 +13892,6 @@ def error_inputs_mean(op_info, device, **kwargs): # AssertionError: False is not true : Scalars failed to compare as equal! DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestCommon', 'test_variant_consistency_eager'), - # RuntimeError: This operator is not Composite Compliant - DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestCompositeCompliance', 'test_backward'), - DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestCompositeCompliance', 'test_forward_ad'), # Not close DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestCommon', 'test_complex_half_reference_testing'), @@ -13567,8 +13937,9 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'), # JIT test also tries to compute double backward, which fails DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # RuntimeError: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2) - DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + # Extremal value issue on aten::native_layer_norm, which returns 'nan' for mean on 'inf' inputs + # possibly because of the welford implementation. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'), )), OpInfo('nn.functional.cosine_similarity', aten_name="cosine_similarity", @@ -13621,6 +13992,8 @@ def error_inputs_mean(op_info, device, **kwargs): # DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -13721,8 +14094,6 @@ def error_inputs_mean(op_info, device, **kwargs): 'test_variant_consistency_jit', dtypes=(torch.float32,) ), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', "test_fn_gradgrad", dtypes=(torch.float64,)), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', "test_fn_fwgrad_bwgrad", dtypes=(torch.float64,)), ), ), UnaryUfuncInfo( @@ -13783,14 +14154,23 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_conv_transpose3d, supports_forward_ad=True, supports_fwgrad_bwgrad=True, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, decorators=[ DecorateInfo( toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), DecorateInfo( - toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), - 'TestCommon', 'test_noncontiguous_samples', device_type='cuda')], + toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }), + 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }), + 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda', + active_if=TEST_CUDNN)], skips=( # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. @@ -13835,10 +14215,6 @@ def error_inputs_mean(op_info, device, **kwargs): # RuntimeError: UNSUPPORTED DTYPE: complex DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), - # Ref: https://github.com/pytorch/pytorch/issues/78077 - DecorateInfo(unittest.expectedFailure, 'TestExpandedWeightFunctional', - 'test_expanded_weight_per_sample_grad', - dtypes=(torch.float64,)), ), supports_expanded_weight=True, supports_out=False,), @@ -13850,6 +14226,8 @@ def error_inputs_mean(op_info, device, **kwargs): *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), sample_inputs_func=partial(sample_inputs_conv2d), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=( @@ -13873,10 +14251,6 @@ def error_inputs_mean(op_info, device, **kwargs): # RuntimeError: UNSUPPORTED DTYPE: complex DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), - # Ref: https://github.com/pytorch/pytorch/issues/78077 - DecorateInfo(unittest.expectedFailure, 'TestExpandedWeightFunctional', - 'test_expanded_weight_per_sample_grad', - dtypes=(torch.float64,)), ), supports_expanded_weight=True, supports_out=False,), @@ -13943,9 +14317,27 @@ def error_inputs_mean(op_info, device, **kwargs): 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), ], sample_inputs_func=sample_inputs_local_response_norm,), + OpInfo('constant_pad_nd', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=sample_inputs_constant_pad_nd, + supports_out=False, + skips=( + # bool can't be passed to Scalar arguments in JIT tracer because + # BoolType is not a subtype of ScalarType. + DecorateInfo( + unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bool,)), + DecorateInfo( + unittest.expectedFailure, 'TestCudaFuserOpInfo', + 'test_nvfuser_correctness', dtypes=(torch.bool,)), + )), OpInfo('nn.functional.pad', variant_test_name='constant', aten_name='constant_pad_nd', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), @@ -14016,6 +14408,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half), sample_inputs_func=sample_inputs_nn_unfold, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, @@ -14212,7 +14606,7 @@ def error_inputs_mean(op_info, device, **kwargs): ref=_NOTHING, supports_out=False, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_multilabel_soft_margin_loss, supports_forward_ad=True, decorators=( @@ -14314,7 +14708,8 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_max_pool), OpInfo('nn.functional.max_pool2d', aten_name='max_pool2d', - supports_autograd=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, # Vmap is not happy with non-contiguous (channels_last) inputs check_batched_gradgrad=False, supports_out=False, @@ -14328,7 +14723,8 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_max_pool), OpInfo('nn.functional.max_pool3d', aten_name='max_pool3d', - supports_autograd=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14361,6 +14757,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad', + device_type='cpu'), )), OpInfo('nn.functional.max_unpool1d', variant_test_name='grad', @@ -14393,11 +14791,13 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_forward_mode_AD'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), )), OpInfo('nn.functional.max_unpool2d', variant_test_name='grad', aten_name='max_unpool2d', - supports_autograd=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # Vmap is not happy with non-contiguous (channels_last) inputs @@ -14409,7 +14809,8 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_max_unpool_grad), OpInfo('nn.functional.max_unpool3d', aten_name='max_unpool3d', - supports_autograd=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, @@ -14427,6 +14828,7 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_forward_mode_AD'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), )), OpInfo('nn.functional.max_unpool3d', variant_test_name='grad', @@ -14458,6 +14860,8 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_forward_grad=False, supports_expanded_weight=True, decorators=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), )), @@ -14473,12 +14877,15 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False), OpInfo('nn.functional.glu', aten_name='glu', - supports_autograd=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, sample_inputs_func=sample_inputs_glu, dtypes=floating_types_and(torch.bfloat16), dtypesIfROCM=floating_types_and(torch.float16, torch.bfloat16), @@ -14515,7 +14922,7 @@ def error_inputs_mean(op_info, device, **kwargs): 'TestUnaryUfuncs', device_type='cuda', ), ], ), - OpInfo( + UnaryUfuncInfo( 'nn.functional.prelu', aten_backward_name='prelu_backward', ref=lambda x, weight: @@ -14529,7 +14936,11 @@ def error_inputs_mean(op_info, device, **kwargs): assert_autodiffed=False, supports_gradgrad=True, supports_out=False, - sample_inputs_func=sample_inputs_nn_functional_prelu, + # test_reference_numerics only tests the case when the weight tensor is a scalar + sample_kwargs=sample_kwargs_prelu_scalar_weight, + error_inputs_func=error_inputs_prelu, + sample_inputs_func=sample_inputs_prelu, + reference_inputs_func=reference_inputs_prelu, decorators=[ # FIXME: second derivative is implemented but seems to be incorrect # https://github.com/pytorch/pytorch/issues/68760 @@ -14571,19 +14982,18 @@ def error_inputs_mean(op_info, device, **kwargs): aten_backward_name='rrelu_with_noise_backward', op=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs), ref=_NOTHING, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), gradcheck_wrapper=wrapper_set_seed, supports_forward_ad=True, - supports_autograd=True, - assert_autodiffed=False, - supports_gradgrad=True, + supports_fwgrad_bwgrad=True, supports_out=False, sample_kwargs=lambda device, dtype, input: - ({'lower': 0., 'upper': 1.}, {'lower': 0., 'upper': 1.}), - inplace_variant=lambda input, *args, **kwargs: - wrapper_set_seed(partial(torch.nn.functional.rrelu, inplace=True), input, *args, **kwargs), + (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs=dict(lower=0., upper=1., training=True)), decorators=( DecorateInfo( toleranceOverride({ @@ -14595,11 +15005,17 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( # lambda impl DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), # In-place operations do not play well with forward AD # https://github.com/pytorch/pytorch/issues/77447 DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_forward_mode_AD'), - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),)), + # The noise vector that's generated in these tests is not the same elementwise + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),)), UnaryUfuncInfo( 'nn.functional.selu', ref=lambda x, inplace=False: @@ -14725,7 +15141,7 @@ def error_inputs_mean(op_info, device, **kwargs): aten_backward_name='log_sigmoid_backward', ref=reference_logsigmoid, dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_autograd=True, assert_autodiffed=False, supports_forward_ad=True, @@ -14781,8 +15197,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int, torch.int8)), - DecorateInfo(unittest.expectedFailure, 'TestGradients', - "test_fn_fwgrad_bwgrad", dtypes=(torch.complex128,)), # pytorch computes (0+nanj), numpy computes (-5e-18-1j) for input (-501.-1.0000e+20j) DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', "test_reference_numerics_large", dtypes=(torch.complex64,)),), @@ -14820,18 +15234,22 @@ def error_inputs_mean(op_info, device, **kwargs): active_if=(IS_MACOS or IS_WINDOWS)), ), ), - OpInfo( + UnaryUfuncInfo( 'nn.functional.threshold', - aten_backward_name='threshold_backward', - ref=lambda x, threshold, value: np.where(x > threshold, x, value).astype(x.dtype), + ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype), dtypes=all_types_and(torch.bfloat16), dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), - supports_autograd=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, assert_autodiffed=False, supports_gradgrad=True, supports_out=False, + sample_kwargs=lambda device, dtype, input: ({'threshold': 0.123, + 'value': -9}, + {'threshold': 0.123, + 'value': -9}), + # TODO(whc) should not need sample_inputs_func, but without it + # kwargs aren't being hooked up properly sample_inputs_func=sample_inputs_threshold, ), OpInfo( @@ -15041,46 +15459,43 @@ def error_inputs_mean(op_info, device, **kwargs): # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), # )), - OpInfo('nn.functional.softshrink', - aten_name="softshrink", - aten_backward_name='softshrink_backward', - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), - supports_autograd=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - assert_autodiffed=False, - sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh, - supports_gradgrad=True, - ), - OpInfo('nn.functional.hardshrink', - aten_name="hardshrink", - aten_backward_name='hardshrink_backward', - dtypes=floating_types_and(torch.bfloat16,), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), - supports_autograd=True, - assert_autodiffed=True, - sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh, - supports_gradgrad=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - autodiff_nonfusible_nodes=["aten::hardshrink"]), - OpInfo('nn.functional.hardtanh', - aten_name="hardtanh", - aten_backward_name='hardtanh_backward', - dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16), - backward_dtypes=all_types(), - dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.bfloat16), - backward_dtypesIfCUDA=floating_types_and(torch.float16), - supports_autograd=True, - assert_autodiffed=True, - sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh, - supports_gradgrad=True, - supports_out=False, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - autodiff_nonfusible_nodes=["aten::hardtanh"], - ), + UnaryUfuncInfo('nn.functional.softshrink', + ref=_NOTHING, + aten_name="softshrink", + aten_backward_name='softshrink_backward', + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_softshrink, + error_inputs_func=error_inputs_softshrink), + UnaryUfuncInfo('nn.functional.hardshrink', + ref=_NOTHING, + aten_name="hardshrink", + aten_backward_name='hardshrink_backward', + dtypes=floating_types_and(torch.bfloat16,), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardshrink, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardshrink"]), + UnaryUfuncInfo('nn.functional.hardtanh', + ref=_NOTHING, + aten_name="hardtanh", + aten_backward_name='hardtanh_backward', + dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16), + backward_dtypes=all_types(), + dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, + torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardtanh, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardtanh"]), OpInfo('nn.functional.gelu', aten_name="gelu", aten_backward_name='gelu_backward', @@ -15099,20 +15514,18 @@ def error_inputs_mean(op_info, device, **kwargs): # AssertionError: Tensor-likes are not close! # May not replicate in CI DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),)), - OpInfo('nn.functional.relu6', - aten_name="relu6", - dtypes=all_types_and(torch.bfloat16), - backward_dtypes=floating_types(), - dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), - backward_dtypesIfCUDA=floating_types_and(torch.float16), - supports_autograd=True, - assert_autodiffed=True, - sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh, - supports_gradgrad=True, - supports_out=False, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - autodiff_nonfusible_nodes=["aten::relu6"]), + UnaryUfuncInfo('nn.functional.relu6', + ref=_NOTHING, + aten_name="relu6", + dtypes=all_types_and(torch.bfloat16), + backward_dtypes=floating_types(), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16), + assert_autodiffed=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::relu6"]), OpInfo('mm', dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] @@ -15129,7 +15542,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( # Resized a non-empty tensor but did not warn about it DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values'), ), sample_inputs_func=sample_inputs_mode,), MvlGammaInfo(variant_test_name='mvlgamma_p_1', @@ -15191,6 +15603,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('dist', op=torch.dist, dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: @@ -15241,6 +15655,8 @@ def error_inputs_mean(op_info, device, **kwargs): # unsupported on CPU. backward_dtypes=floating_and_complex_types_and(torch.bfloat16), backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_inplace_autograd=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15281,6 +15697,8 @@ def error_inputs_mean(op_info, device, **kwargs): ref=np.float_power, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_one_python_scalar=True, @@ -15347,7 +15765,7 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo( "roll", ref=np.roll, - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), error_inputs_func=error_inputs_roll, supports_out=False, supports_forward_ad=True, @@ -15359,6 +15777,8 @@ def error_inputs_mean(op_info, device, **kwargs): "rot90", dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), error_inputs_func=error_inputs_rot90, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15537,14 +15957,6 @@ def error_inputs_mean(op_info, device, **kwargs): # Reference: https://github.com/pytorch/pytorch/issues/48486 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.complex64]), - # The complex formula might be wrong - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', - dtypes=complex_types()), - # Passes for float, but for complex - Need: _s_where - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=complex_types()), - DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD', - dtypes=complex_types()), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), )), @@ -15594,6 +16006,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), promotes_int_to_float=True, lhs_make_tensor_kwargs={'exclude_zero': True}, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, skips=( # https://github.com/pytorch/pytorch/issues/76806 @@ -15652,6 +16066,8 @@ def error_inputs_mean(op_info, device, **kwargs): if (SM53OrLater and CUDA11OrLater) or TEST_WITH_ROCM else []), assert_autodiffed=True, sample_inputs_func=sample_inputs_matmul, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15682,6 +16098,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.Tensor.__rmod__, dtypes=floating_types_and(torch.bfloat16, torch.half,), dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15752,6 +16170,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('slice_scatter', dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), sample_inputs_func=sample_inputs_slice_scatter, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False), @@ -16166,6 +16586,24 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.linalg.solve, dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_linalg_solve, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('linalg.solve_ex', + aten_name='linalg_solve_ex', + op=torch.linalg.solve_ex, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_solve, supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], @@ -16219,6 +16657,8 @@ def error_inputs_mean(op_info, device, **kwargs): aten_name='linalg_pinv', op=torch.linalg.pinv, dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, check_batched_grad=False, check_batched_gradgrad=False, supports_forward_ad=True, @@ -16286,6 +16726,10 @@ def error_inputs_mean(op_info, device, **kwargs): skipCUDAIfNoMagma, skipCPUIfNoLapack, ], + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + ), ), OpInfo('einsum', # we need this lambda because SampleInput expects tensor input as the first argument @@ -16314,6 +16758,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=torch.svd, dtypes=floating_and_complex_types(), sample_inputs_func=sample_inputs_svd, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, @@ -16322,8 +16768,6 @@ def error_inputs_mean(op_info, device, **kwargs): check_batched_gradgrad=False, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], skips=( - # Fixme, forward over backward gives a numerical error - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -16334,7 +16778,10 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('linalg.svd', op=torch.linalg.svd, aten_name='linalg_svd', + decomp_aten_name='_linalg_svd', dtypes=floating_and_complex_types(), + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_fwgrad_bwgrad=True, supports_forward_ad=True, check_batched_forward_grad=False, @@ -16344,8 +16791,6 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_svd, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], skips=( - # FIXME forward over backward gives a numerical error - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps', dtypes=[torch.float32]), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', @@ -16356,6 +16801,7 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('linalg.svdvals', op=torch.linalg.svdvals, aten_name='linalg_svdvals', + decomp_aten_name='_linalg_svd', dtypes=floating_and_complex_types(), check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, @@ -16370,6 +16816,8 @@ def error_inputs_mean(op_info, device, **kwargs): *args, **kwargs ), dtypes=floating_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, check_batched_grad=False, check_batched_gradgrad=False, @@ -16392,6 +16840,8 @@ def error_inputs_mean(op_info, device, **kwargs): *args, **kwargs ), dtypes=floating_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, check_batched_forward_grad=False, check_batched_grad=False, @@ -16560,12 +17010,6 @@ def error_inputs_mean(op_info, device, **kwargs): # See https://github.com/pytorch/pytorch/pull/78358 check_batched_forward_grad=False, sample_inputs_func=sample_inputs_ravel, - skips=( - # the stride of the tensor was modified directly without going through the PyTorch dispatcher. - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), - # the stride of the tensor was modified directly without going through the PyTorch dispatcher. - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - ) ), OpInfo('reshape', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), @@ -16610,6 +17054,8 @@ def error_inputs_mean(op_info, device, **kwargs): )), OpInfo('atleast_1d', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16627,6 +17073,8 @@ def error_inputs_mean(op_info, device, **kwargs): ), OpInfo('atleast_2d', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16640,6 +17088,8 @@ def error_inputs_mean(op_info, device, **kwargs): ), OpInfo('atleast_3d', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16718,7 +17168,7 @@ def error_inputs_mean(op_info, device, **kwargs): gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('index_select', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_index, error_inputs_func=error_inputs_index_select, supports_forward_ad=True, @@ -16726,7 +17176,7 @@ def error_inputs_mean(op_info, device, **kwargs): assert_jit_shape_analysis=True, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('index_add', - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16735,11 +17185,17 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_index, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('index_reduce', - dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypes=all_types_and(torch.float16, torch.bfloat16), supports_out=True, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + ), sample_inputs_func=sample_inputs_index_reduce), OpInfo('__getitem__', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16779,7 +17235,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values'), )), OpInfo('unique', dtypes=all_types_and(torch.bool, torch.bfloat16), @@ -16794,8 +17249,6 @@ def error_inputs_mean(op_info, device, **kwargs): # 76571 DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values', dtypes=(torch.float16, torch.float32, torch.float64)), - DecorateInfo(unittest.skip("memory access error on some platforms"), - 'TestCommon', 'test_non_standard_bool_values'), )), OpInfo('unique_consecutive', dtypes=all_types_and(torch.bool, torch.bfloat16), @@ -16810,10 +17263,6 @@ def error_inputs_mean(op_info, device, **kwargs): # 76571 DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values', dtypes=(torch.float16, torch.float32, torch.float64)), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values'), - DecorateInfo(unittest.skip("memory access error on ROCm"), 'TestCommon', - 'test_non_standard_bool_values', device_type='cuda', - active_if=TEST_WITH_ROCM), )), OpInfo('put', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), @@ -16822,12 +17271,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, check_batched_forward_grad=False, check_batched_gradgrad=False, # vmap complains of the sizes - skips=( - # Problem, needs to be fixed - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - ), sample_inputs_func=sample_inputs_put), OpInfo('take', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), @@ -16835,9 +17278,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_take, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), - ), error_inputs_func=error_inputs_take), OpInfo('scatter', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), @@ -17016,6 +17456,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'), # Empty tensor data is garbage so it's hard to make comparisons with it. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance', + 'test_operator'), # Can't find schemas for this operator for some reason DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), )), @@ -17141,6 +17583,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), # Empty tensor data is garbage so it's hard to make comparisons with it. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), # Can't find schemas for this operator for some reason DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), @@ -17171,6 +17615,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), # Empty tensor data is garbage so it's hard to make comparisons with it. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), # Can't find schemas for this operator for some reason DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), DecorateInfo(unittest.skip("Expected: empty is not comparable"), @@ -17232,8 +17678,6 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - # NotImplementedError not raised - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'), # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestGradients'),)), OpInfo('normal', @@ -17390,6 +17834,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), sample_inputs_func=sample_inputs_cat_concat, reference_inputs_func=reference_inputs_cat, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, assert_autodiffed=True, @@ -17424,6 +17870,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=lambda x, *args: x.unfold(*args), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17444,8 +17892,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_msort, skips=( - DecorateInfo(unittest.skip("segfaults on some systems"), 'TestCommon', - 'test_non_standard_bool_values', device_type='cpu'), )), OpInfo('movedim', aliases=('moveaxis',), @@ -17464,6 +17910,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=lambda x, dims: x.repeat(dims), ref=np.tile, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17503,8 +17951,6 @@ def error_inputs_mean(op_info, device, **kwargs): # JIT has issue when op is passed as lambda # AssertionError: JIT Test does not execute any logic DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), - DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'), DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'), )), @@ -17551,6 +17997,8 @@ def error_inputs_mean(op_info, device, **kwargs): ShapeFuncInfo('tile', ref=np.tile, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17611,6 +18059,8 @@ def error_inputs_mean(op_info, device, **kwargs): method_variant=None, inplace_variant=torch.Tensor.zero_, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17716,6 +18166,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('mT', op=lambda x: x.mT, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17728,6 +18180,8 @@ def error_inputs_mean(op_info, device, **kwargs): op=lambda x: x.mH, aliases=('adjoint',), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17753,6 +18207,8 @@ def error_inputs_mean(op_info, device, **kwargs): OpInfo('kron', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_inplace_autograd=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -17966,8 +18422,6 @@ def error_inputs_mean(op_info, device, **kwargs): supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - # fwgrad_bwgrad for abs is wrong - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=complex_types()), # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. DecorateInfo( @@ -18004,10 +18458,12 @@ def error_inputs_mean(op_info, device, **kwargs): reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), OpInfo( 'logdet', - dtypes=floating_types(), + dtypes=floating_and_complex_types(), supports_out=False, - sample_inputs_func=sample_inputs_logdet, - decorators=(skipCPUIfNoLapack, skipCUDAIfNoMagma)), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), # `log_softmax` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. OpInfo( @@ -18241,6 +18697,8 @@ def error_inputs_mean(op_info, device, **kwargs): "norm", sample_inputs_func=sample_inputs_norm, dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( @@ -18252,8 +18710,6 @@ def error_inputs_mean(op_info, device, **kwargs): "test_out", device_type="meta", ), - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), ), ), OpInfo('norm', @@ -18302,9 +18758,9 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, + # fast gradcheck produces NaNs + gradcheck_fast_mode=False, skips=( - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result # of dtype torch.float32 into an out= with dtype torch.long DecorateInfo( @@ -18352,7 +18808,6 @@ def error_inputs_mean(op_info, device, **kwargs): # the op dispatches to _fused_dropout (with a few more conditions) # hence, different values and this skip here DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),), - gradcheck_wrapper=wrapper_set_seed, supports_forward_ad=True, supports_fwgrad_bwgrad=True, # https://github.com/pytorch/pytorch/issues/66357 @@ -18371,20 +18826,34 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( # lambda impl DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: - # vmap: We do not yet support calling random operations inside of vmap. - # Please perform random operations outside of vmap as a workaround - DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_forward_mode_AD"), - DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_inplace_forward_mode_AD"),), - gradcheck_wrapper=wrapper_set_seed, + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, + check_batched_forward_grad=False, # As per the docs, valid input dims are (3, 4) sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)), inplace_variant=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.dropout3d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs), + ref=_NOTHING, + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases OpInfo( @@ -18404,7 +18873,8 @@ def error_inputs_mean(op_info, device, **kwargs): # Please perform random operations outside of vmap as a workaround DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_forward_mode_AD"), DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_inplace_forward_mode_AD"),), - gradcheck_wrapper=wrapper_set_seed, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False, @@ -18478,7 +18948,6 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), # Not a problem: embedding_bag does weird stuff to its input (it renormalizes) DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cpu'), ), gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, supports_out=False, @@ -18602,6 +19071,9 @@ def error_inputs_mean(op_info, device, **kwargs): ReductionOpInfo( 'amax', nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), ref=reference_reduction_numpy(np.amax), skips=( @@ -18614,6 +19086,9 @@ def error_inputs_mean(op_info, device, **kwargs): ReductionOpInfo( 'amin', nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), ref=reference_reduction_numpy(np.amin), skips=( @@ -18693,9 +19168,6 @@ def error_inputs_mean(op_info, device, **kwargs): # FIXME: mean reduces all dimensions when dim=[] DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), - # FIXME: mean does not support passing None to dim - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), # FIXME: improve precision DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', dtypes=[torch.float16]), @@ -18708,6 +19180,9 @@ def error_inputs_mean(op_info, device, **kwargs): nan_policy='omit', assert_autodiffed=True, promotes_int_to_float=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, dtypes=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), ref=reference_reduction_numpy(np.nanmean), @@ -18736,6 +19211,7 @@ def error_inputs_mean(op_info, device, **kwargs): supports_fwgrad_bwgrad=True, assert_autodiffed=True, promotes_int_to_float=True, + check_batched_forward_grad=False, dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_std_var, @@ -18766,6 +19242,7 @@ def error_inputs_mean(op_info, device, **kwargs): complex_to_real=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_std_var, @@ -18792,6 +19269,8 @@ def error_inputs_mean(op_info, device, **kwargs): identity=1, nan_policy='propagate', supports_multiple_dims=False, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18802,6 +19281,10 @@ def error_inputs_mean(op_info, device, **kwargs): sample_inputs_func=sample_inputs_prod, ref=reference_reduction_numpy(np.prod), skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), # FIXME: prod does not support passing keepdim without passing dim DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), # FIXME: prod reduces all dimensions when dim=[] @@ -18833,9 +19316,6 @@ def error_inputs_mean(op_info, device, **kwargs): # FIXME: sum reduces all dimensions when dim=[] DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), - # FIXME: sum does not support passing None to dim - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), # FIXME: improve precision DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', dtypes=[torch.float16]), @@ -18849,9 +19329,15 @@ def error_inputs_mean(op_info, device, **kwargs): nan_policy='omit', supports_out=True, promotes_int_to_int64=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), ref=reference_reduction_numpy(np.nansum), skips=( + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), # FIXME: nansum reduces all dimensions when dim=[] DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), @@ -18883,9 +19369,6 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), # RuntimeError: undefined value tensor DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), decorators=[ DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-03, rtol=1e-03)}), @@ -18905,6 +19388,8 @@ def error_inputs_mean(op_info, device, **kwargs): method_variant=None, identity=1, nan_policy='propagate', + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18916,13 +19401,14 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.bool), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.skip("Failing on some jobs"), 'TestReductions', 'test_reference_masked', dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', device_type='cuda', dtypes=(torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, @@ -18943,6 +19429,8 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16), method_variant=None, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18960,10 +19448,15 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16), method_variant=None, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), @@ -18978,6 +19471,9 @@ def error_inputs_mean(op_info, device, **kwargs): supports_out=False, dtypes=all_types_and(torch.float16, torch.bfloat16), supports_sparse=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse_csr=True, ref=reference_reduction_numpy(np.amax), skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), @@ -18987,20 +19483,25 @@ def error_inputs_mean(op_info, device, **kwargs): # RuntimeError: Unknown builtin op: aten::iinfo DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) - DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', device_type='cuda', + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', dtypes=(torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128)), ), sample_inputs_func=sample_inputs_masked_reduction, sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, gradcheck_wrapper=gradcheck_wrapper_masked_operation ), ReductionOpInfo( '_masked.amin', nan_policy='propagate', supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, dtypes=all_types_and(torch.float16, torch.bfloat16), supports_sparse=True, + supports_sparse_csr=True, ref=reference_reduction_numpy(np.amin), skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), @@ -19010,12 +19511,14 @@ def error_inputs_mean(op_info, device, **kwargs): # RuntimeError: Unknown builtin op: aten::iinfo DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs) - DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', device_type='cuda', + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', dtypes=(torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128)), ), sample_inputs_func=sample_inputs_masked_reduction, sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, gradcheck_wrapper=gradcheck_wrapper_masked_operation ), ReductionOpInfo( @@ -19060,6 +19563,7 @@ def error_inputs_mean(op_info, device, **kwargs): method_variant=None, nan_policy='propagate', supports_out=False, + supports_sparse_csr=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, promotes_int_to_float=True, @@ -19077,15 +19581,17 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), # RuntimeError: undefined value tensor DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), + # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs) + DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_mask_layout', + dtypes=(torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, + torch.int64, torch.complex64, torch.complex128)), ), decorators=[ DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}), 'TestReductions', 'test_reference_masked'), ], sample_inputs_func=sample_inputs_masked_reduction, + sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction, gradcheck_wrapper=gradcheck_wrapper_masked_operation ), OpInfo( @@ -19121,9 +19627,6 @@ def error_inputs_mean(op_info, device, **kwargs): # can't take variable number of arguments or use # keyword-only arguments with defaults DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -19149,9 +19652,6 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), # RuntimeError: undefined value tensor DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), decorators=[ DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02), @@ -19173,6 +19673,8 @@ def error_inputs_mean(op_info, device, **kwargs): ref=reference_reduction_numpy(np.std) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None, method_variant=None, nan_policy='propagate', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -19188,9 +19690,6 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), # RuntimeError: undefined value tensor DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.float16,)), ), @@ -19215,9 +19714,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), gradcheck_wrapper=gradcheck_wrapper_masked_operation, supports_forward_ad=True, @@ -19231,9 +19727,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), decorators=[ DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}), @@ -19251,9 +19744,6 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # see https://github.com/pytorch/pytorch/issues/76227 - DecorateInfo(unittest.skip("Fails on UBSAN!"), 'TestCompositeCompliance', 'test_forward_ad', - device_type='cpu'), ), gradcheck_wrapper=gradcheck_wrapper_masked_operation, supports_forward_ad=True, @@ -19266,14 +19756,13 @@ def error_inputs_mean(op_info, device, **kwargs): skips=( DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - # Prexisting issue with linalg.vector_norm - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), # RuntimeError: "clamp_min_cpu" not implemented for 'Half' DecorateInfo(unittest.expectedFailure, 'TestMasked', 'test_reference_masked', device_type='cpu', dtypes=[torch.half]), ), gradcheck_wrapper=gradcheck_wrapper_masked_operation, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=False), @@ -19385,12 +19874,18 @@ def error_inputs_mean(op_info, device, **kwargs): ref=_NOTHING, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_gaussian_nll_loss, skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), # JIT does not support variadic tensors. # RuntimeError: input->type()->kind() == TypeKind::OptionalType # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, @@ -19398,7 +19893,8 @@ def error_inputs_mean(op_info, device, **kwargs): DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), ), decorators=( - DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02), + torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), ) ), @@ -19461,14 +19957,12 @@ def error_inputs_mean(op_info, device, **kwargs): "test_variant_consistency_jit", dtypes=(torch.float32,), ), - DecorateInfo(unittest.expectedFailure, 'TestCommon', - 'test_non_standard_bool_values', device_type='cpu'), ), ), OpInfo( "repeat_interleave", dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_repeat_interleave, supports_out=False, supports_forward_ad=True, @@ -19501,8 +19995,6 @@ def error_inputs_mean(op_info, device, **kwargs): "test_variant_consistency_jit", dtypes=(torch.float32, torch.complex64), ), - DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', - dtypes=[torch.complex128]), ), ), OpInfo( @@ -19541,22 +20033,12 @@ def error_inputs_mean(op_info, device, **kwargs): "nn.functional.kl_div", sample_inputs_func=sample_inputs_kl_div, dtypes=floating_types_and(torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64), - backward_dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64), dtypesIfCUDA=floating_types_and( torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64 ), - backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64), supports_out=False, - check_batched_grad=False, supports_forward_ad=True, - skips=( - # See https://github.com/pytorch/pytorch/issues/65466 - DecorateInfo( - unittest.expectedFailure, - "TestGradients", - "test_fn_gradgrad", - ), - ), + supports_fwgrad_bwgrad=True, ), OpInfo( "diagflat", @@ -19584,8 +20066,12 @@ def error_inputs_mean(op_info, device, **kwargs): # complex not added to dtypes as complex gradients are not properly handled # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + ), ), OpInfo( 'scatter_reduce', @@ -19593,21 +20079,21 @@ def error_inputs_mean(op_info, device, **kwargs): # complex not added to dtypes as complex gradients are not properly handled # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet dtypes=all_types_and(torch.float16, torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_scatter_reduce, ), OpInfo( 'scatter_reduce', variant_test_name='amin', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_scatter_reduce, ), OpInfo( 'scatter_reduce', variant_test_name='amax', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_scatter_reduce, ), OpInfo( @@ -19648,6 +20134,27 @@ def error_inputs_mean(op_info, device, **kwargs): ), ), ), + UnaryUfuncInfo( + 'special.airy_ai', + decorators=( + precisionOverride( + { + torch.float32: 1e-03, + torch.float64: 1e-05, + }, + ), + ), + dtypes=all_types_and(torch.bool), + ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else _NOTHING, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + 'TestUnaryUfuncs', + 'test_reference_numerics_large', + ), + ), + supports_autograd=False, + ), UnaryUfuncInfo( 'special.bessel_j0', decorators=( @@ -19851,6 +20358,34 @@ def error_inputs_mean(op_info, device, **kwargs): ref=scipy.special.k1 if TEST_SCIPY else _NOTHING, supports_autograd=False, ), + UnaryUfuncInfo( + 'special.scaled_modified_bessel_k0', + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k0e if TEST_SCIPY else _NOTHING, + supports_autograd=False, + ), + UnaryUfuncInfo( + 'special.scaled_modified_bessel_k1', + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=scipy.special.k1e if TEST_SCIPY else _NOTHING, + supports_autograd=False, + ), BinaryUfuncInfo( 'special.shifted_chebyshev_polynomial_t', dtypes=all_types_and(torch.bool), @@ -19899,6 +20434,20 @@ def error_inputs_mean(op_info, device, **kwargs): supports_one_python_scalar=True, supports_autograd=False, ), + UnaryUfuncInfo( + 'special.spherical_bessel_j0', + decorators=( + toleranceOverride( + { + torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.float64: tol(atol=1e-05, rtol=1e-03), + } + ), + ), + dtypes=all_types_and(torch.bool), + ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else _NOTHING, + supports_autograd=False, + ), ] # NOTE [Python References] @@ -19923,12 +20472,12 @@ def error_inputs_mean(op_info, device, **kwargs): # construction arguments. These arguments can be overridden # by adding kwargs to the constructor. -def _find_referenced_opinfo(referenced_name): +def _find_referenced_opinfo(referenced_name, variant_name): ''' Finds the OpInfo with the given name that has no variant name. ''' for opinfo in op_db: - if opinfo.name == referenced_name and opinfo.variant_test_name == '': + if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name: return opinfo def _inherit_constructor_args(name, op, inherited, overrides): @@ -19983,12 +20532,14 @@ def __init__( *, op=None, # the function variant of the operation, populated as torch. if None torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name='', # the variant name for corresponding torch opinfo validate_view_consistency=True, supports_nvfuser=True, **kwargs): # additional kwargs override kwargs inherited from the torch opinfo self.torch_opinfo_name = torch_opinfo_name - self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name) + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name, torch_opinfo_variant_name) self.validate_view_consistency = validate_view_consistency self.supports_nvfuser = supports_nvfuser assert isinstance(self.torch_opinfo, OpInfo) @@ -20007,11 +20558,13 @@ def __init__( *, op=None, # the function variant of the operation, populated as torch. if None torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name='', # the variant name for corresponding torch opinfo supports_nvfuser=True, **kwargs): # additional kwargs override kwargs inherited from the torch opinfo self.torch_opinfo_name = torch_opinfo_name - self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name) + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name, torch_opinfo_variant_name) self.supports_nvfuser = supports_nvfuser assert isinstance(self.torch_opinfo, ReductionOpInfo) @@ -20033,11 +20586,13 @@ def __init__( *, op=None, # the function variant of the operation, populated as torch. if None torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name='', # the variant name for corresponding torch opinfo supports_nvfuser=True, **kwargs): # additional kwargs override kwargs inherited from the torch opinfo self.torch_opinfo_name = torch_opinfo_name - self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name) + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name, torch_opinfo_variant_name) self.supports_nvfuser = supports_nvfuser assert isinstance(self.torch_opinfo, UnaryUfuncInfo) @@ -20056,11 +20611,13 @@ def __init__( *, op=None, # the function variant of the operation, populated as torch. if None torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name='', # the variant name for corresponding torch opinfo supports_nvfuser=True, **kwargs): # additional kwargs override kwargs inherited from the torch opinfo self.torch_opinfo_name = torch_opinfo_name - self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name) + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name, torch_opinfo_variant_name) self.supports_nvfuser = supports_nvfuser assert isinstance(self.torch_opinfo, BinaryUfuncInfo) @@ -20069,6 +20626,30 @@ def __init__( super(ElementwiseBinaryPythonRefInfo, self).__init__(**ukwargs) +class SpectralFuncPythonRefInfo(SpectralFuncInfo): + ''' + An OpInfo for a Python reference of an elementwise unary operation. + ''' + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant='', + supports_nvfuser=True, + **kwargs): # additional kwargs override kwargs inherited from the torch opinfo + + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo = _find_referenced_opinfo(torch_opinfo_name, torch_opinfo_variant) + self.supports_nvfuser = supports_nvfuser + assert isinstance(self.torch_opinfo, SpectralFuncInfo) + + inherited = self.torch_opinfo._original_spectral_func_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) + # Separate registry for experimental Python Reference OpInfos. python_ref_db = [ @@ -20102,19 +20683,106 @@ def __init__( "_refs.asin", torch_opinfo_name="asin", ), + ElementwiseUnaryPythonRefInfo( + "_refs.asinh", + torch_opinfo_name="asinh", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.arange", + torch_opinfo_name="arange", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # See https://github.com/pytorch/pytorch/issues/82364 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + + # Prims arange does not follow aten + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int64,)), + ), + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + # returns a view of an intermediate tensor (prims.to_dtype) + validate_view_consistency=False, + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + # returns a view of an intermediate tensor (prims.to_dtype) + validate_view_consistency=False, + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.atan", torch_opinfo_name="atan", ), + ElementwiseUnaryPythonRefInfo( + "_refs.atanh", + torch_opinfo_name="atanh", + ), ElementwiseUnaryPythonRefInfo( "_refs.bitwise_not", torch_opinfo_name="bitwise_not", - supports_nvfuser=False, ), ElementwiseUnaryPythonRefInfo( "_refs.ceil", torch_opinfo_name="ceil", ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj_physical", + torch_opinfo_name="conj_physical", + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.cos", torch_opinfo_name="cos", @@ -20169,6 +20837,11 @@ def __init__( torch_opinfo_name="frac", supports_nvfuser=False, ), + ElementwiseUnaryPythonRefInfo( + "_refs.imag", + torch_opinfo_name="imag", + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.isfinite", torch_opinfo_name="isfinite", @@ -20181,6 +20854,18 @@ def __init__( supports_out=True, supports_nvfuser=False, ), + ElementwiseUnaryPythonRefInfo( + "_refs.isposinf", + torch_opinfo_name="isposinf", + supports_out=True, + supports_nvfuser=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isneginf", + torch_opinfo_name="isneginf", + supports_out=True, + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.isnan", torch_opinfo_name="isnan", @@ -20214,12 +20899,13 @@ def __init__( PythonRefInfo( "_refs.logsumexp", torch_opinfo_name="logsumexp", + # When keepdim=False logsumexp function uses squeeze operation + # that is not yet exposed in nvFuser's Python API. supports_nvfuser=False, ), PythonRefInfo( "_refs.log_softmax", torch_opinfo_name="log_softmax", - supports_nvfuser=False, ), ElementwiseUnaryPythonRefInfo( "_refs.nan_to_num", @@ -20235,6 +20921,11 @@ def __init__( torch_opinfo_name="positive", supports_nvfuser=False, ), + ElementwiseUnaryPythonRefInfo( + "_refs.real", + torch_opinfo_name="real", + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.reciprocal", torch_opinfo_name="reciprocal", @@ -20253,7 +20944,6 @@ def __init__( # Reference: https://github.com/pytorch/pytorch/issues/56012 handles_complex_extremal_values=False, handles_large_floats=False, - supports_nvfuser=False, ), ElementwiseUnaryPythonRefInfo( "_refs.sign", @@ -20276,7 +20966,6 @@ def __init__( PythonRefInfo( "_refs.softmax", torch_opinfo_name="softmax", - supports_nvfuser=False ), ElementwiseUnaryPythonRefInfo( "_refs.sqrt", @@ -20298,6 +20987,10 @@ def __init__( "_refs.tanh", torch_opinfo_name="tanh", ), + ElementwiseUnaryPythonRefInfo( + "_refs.trunc", + torch_opinfo_name="trunc", + ), # # Elementwise Unary Special OpInfos # @@ -20328,6 +21021,11 @@ def __init__( "_refs.nn.functional.celu", torch_opinfo_name="nn.functional.celu", ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.threshold", + torch_opinfo_name="nn.functional.threshold", + supports_nvfuser=False, + ), PythonRefInfo( "_refs.nn.functional.dropout", torch_opinfo_name="nn.functional.dropout", @@ -20361,7 +21059,7 @@ def __init__( "_refs.nn.functional.elu", torch_opinfo_name="nn.functional.elu", ), - PythonRefInfo( + ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.hardtanh", torch_opinfo_name="nn.functional.hardtanh", supports_nvfuser=False, @@ -20373,7 +21071,6 @@ def __init__( PythonRefInfo( "_refs.nn.functional.layer_norm", torch_opinfo_name="nn.functional.layer_norm", - supports_nvfuser=False, skips=( # Reference result was farther (3.5762786809723224e-07) from the precise computation # than the torch result was (2.5068410824946596e-07)! @@ -20385,15 +21082,22 @@ def __init__( "_refs.nn.functional.leaky_relu", torch_opinfo_name="nn.functional.leaky_relu", ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.prelu", + torch_opinfo_name="nn.functional.prelu", + ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.relu", torch_opinfo_name="nn.functional.relu", supports_nvfuser=False, ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu6", + torch_opinfo_name="nn.functional.relu6", + ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.mish", torch_opinfo_name="nn.functional.mish", - supports_nvfuser=False, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.selu", @@ -20403,11 +21107,21 @@ def __init__( "_refs.nn.functional.softplus", torch_opinfo_name="nn.functional.softplus", ), + PythonRefInfo( + "_refs.nn.functional.l1_loss", + torch_opinfo_name="nn.functional.l1_loss", + supports_nvfuser=False, + ), PythonRefInfo( "_refs.nn.functional.margin_ranking_loss", torch_opinfo_name="nn.functional.margin_ranking_loss", supports_nvfuser=False, ), + PythonRefInfo( + "_refs.nn.functional.mse_loss", + torch_opinfo_name="nn.functional.mse_loss", + supports_nvfuser=False, + ), PythonRefInfo( "_refs.nn.functional.hinge_embedding_loss", torch_opinfo_name="nn.functional.hinge_embedding_loss", @@ -20417,12 +21131,12 @@ def __init__( "_refs.nn.functional.tanhshrink", torch_opinfo_name="nn.functional.tanhshrink", ), - PythonRefInfo( + ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.hardshrink", torch_opinfo_name="nn.functional.hardshrink", supports_nvfuser=False, ), - PythonRefInfo( + ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.softshrink", torch_opinfo_name="nn.functional.softshrink", supports_nvfuser=False, @@ -20444,7 +21158,6 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.bitwise_and", torch_opinfo_name="bitwise_and", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.bitwise_left_shift", @@ -20454,17 +21167,69 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.bitwise_or", torch_opinfo_name="bitwise_or", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.bitwise_xor", torch_opinfo_name="bitwise_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.copysign", + torch_opinfo_name="copysign", + supports_nvfuser=False, + skips=( + # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu! + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="no_rounding_mode", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=False, + supports_one_python_scalar=True, + supports_nvfuser=False, + skips=( + # NotImplementedError: argument of type: + DecorateInfo( + unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128,) + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda", active_if=not TEST_WITH_ROCM + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda", active_if=not TEST_WITH_ROCM + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="trunc_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=False, + supports_one_python_scalar=True, + supports_nvfuser=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="floor_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=False, + supports_one_python_scalar=True, supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.eq", torch_opinfo_name="eq", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.float_power", @@ -20475,31 +21240,40 @@ def __init__( DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), ) ), + ElementwiseBinaryPythonRefInfo( + "_refs.floor_divide", + torch_opinfo_name="floor_divide", + rhs_make_tensor_kwargs=dict(exclude_zero=True), + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=False, + supports_one_python_scalar=True, + supports_nvfuser=False, + # bfloat16 floor_divide compared with a float32 reference works inconsistently + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,)), + ), + ), ElementwiseBinaryPythonRefInfo( "_refs.fmax", torch_opinfo_name="fmax", supports_rhs_python_scalar=False, supports_nvfuser=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - ), ), ElementwiseBinaryPythonRefInfo( "_refs.fmin", torch_opinfo_name="fmin", supports_rhs_python_scalar=False, supports_nvfuser=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - ), ), ElementwiseBinaryPythonRefInfo( "_refs.fmod", torch_opinfo_name="fmod", rhs_make_tensor_kwargs={'exclude_zero': True}, - supports_nvfuser=False, + supports_rhs_python_scalar=True, skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', dtypes=(torch.bfloat16,), device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', @@ -20514,11 +21288,21 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.ge", torch_opinfo_name="ge", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.gt", torch_opinfo_name="gt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.heaviside", + torch_opinfo_name="heaviside", + supports_rhs_python_scalar=False, + supports_nvfuser=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.hypot", + torch_opinfo_name="hypot", + supports_rhs_python_scalar=False, supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( @@ -20548,12 +21332,10 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.le", torch_opinfo_name="le", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.logical_and", torch_opinfo_name="logical_and", - supports_nvfuser=False, ), ElementwiseUnaryPythonRefInfo( "_refs.logical_not", @@ -20562,17 +21344,14 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.logical_or", torch_opinfo_name="logical_or", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.logical_xor", torch_opinfo_name="logical_xor", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.lt", torch_opinfo_name="lt", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.maximum", @@ -20596,7 +21375,6 @@ def __init__( # https://github.com/pytorch/pytorch/issues/76944 supports_two_python_scalars=False, supports_one_python_scalar=True, - supports_nvfuser=False, skips=( # Reference result was farther (0.0) from the precise computation # than the torch result was (nan)! @@ -20622,7 +21400,6 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.ne", torch_opinfo_name="ne", - supports_nvfuser=False, ), ElementwiseBinaryPythonRefInfo( "_refs.nextafter", @@ -20632,7 +21409,7 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.pow", torch_opinfo_name="pow", - supports_nvfuser=False, + supports_nvfuser=False, # clone default skips=( # Reference result was farther (inf) from the precise # computation than the torch result was (nan)! @@ -20657,7 +21434,6 @@ def __init__( ElementwiseBinaryPythonRefInfo( "_refs.remainder", torch_opinfo_name="remainder", - supports_nvfuser=False, skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', dtypes=(torch.bfloat16,), device_type='cpu'), @@ -20665,6 +21441,21 @@ def __init__( dtypes=(torch.bfloat16,), device_type='cpu'), ), ), + ElementwiseBinaryPythonRefInfo( + "_refs.rsub", + torch_opinfo_name="rsub", + # https://github.com/pytorch/pytorch/issues/76944 + skips=( + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.chalf,), device_type='cpu'), + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.chalf,), device_type='cpu'), + ), + ), ElementwiseBinaryPythonRefInfo( "_refs.sub", torch_opinfo_name="sub", @@ -20688,7 +21479,6 @@ def __init__( # https://github.com/pytorch/pytorch/issues/76944 supports_two_python_scalars=False, supports_one_python_scalar=True, - supports_nvfuser=False, skips=( # Reference result was farther (0.7433461727239705) from the precise # computation than the torch result was (nan)! @@ -20722,6 +21512,24 @@ def __init__( # # Elementwise Ternary Reference OpInfos # + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_min", + torch_opinfo_name="clamp_min", + supports_nvfuser=False, + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_max", + torch_opinfo_name="clamp_max", + supports_nvfuser=False, + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), PythonRefInfo( "_refs.clamp", torch_opinfo_name="clamp", @@ -20802,6 +21610,21 @@ def __init__( torch_opinfo_name="column_stack", supports_nvfuser=False, ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj", + torch_opinfo_name="conj", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.constant_pad_nd", + torch_opinfo_name="constant_pad_nd", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.contiguous", + torch_opinfo_name="contiguous", + supports_nvfuser=False, + ), PythonRefInfo( "_refs.dsplit", torch_opinfo_name="dsplit", @@ -20816,6 +21639,11 @@ def __init__( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), ), ), + PythonRefInfo( + "_refs.expand", + torch_opinfo_name="expand", + supports_nvfuser=False, + ), PythonRefInfo( "_refs.flatten", torch_opinfo_name="flatten", @@ -20853,7 +21681,6 @@ def __init__( PythonRefInfo( "_refs.native_layer_norm", torch_opinfo_name="native_layer_norm", - supports_nvfuser=False, ), PythonRefInfo( "_refs.permute", @@ -20950,48 +21777,39 @@ def __init__( ReductionPythonRefInfo( "_refs.all", torch_opinfo_name="all", - # RuntimeError: Tried to reduce a 0-dim tensor - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.amax", torch_opinfo_name="amax", - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.amin", torch_opinfo_name="amin", - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.any", torch_opinfo_name="any", - supports_nvfuser=False ), ReductionPythonRefInfo( "_refs.mean", torch_opinfo_name="mean", supports_out=True, - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.std", torch_opinfo_name="std", supports_out=True, - supports_nvfuser=False, ), # std_mean and var_mean are not ReductionInfos PythonRefInfo( "_refs.std_mean", torch_opinfo_name="std_mean", validate_view_consistency=False, - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.sum", torch_opinfo_name="sum", supports_out=True, - supports_nvfuser=False, ), ReductionPythonRefInfo( "_refs.prod", @@ -21003,13 +21821,11 @@ def __init__( "_refs.var", torch_opinfo_name="var", supports_out=True, - supports_nvfuser=False, ), PythonRefInfo( "_refs.var_mean", torch_opinfo_name="var_mean", validate_view_consistency=False, - supports_nvfuser=False, ), # # Linear Algebra Operators @@ -21032,6 +21848,57 @@ def __init__( DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'), ), ), + PythonRefInfo( + "_refs.norm", + torch_opinfo_name="norm", + supports_out=True, + # Uses svdvals which does not support nvfuser + supports_nvfuser=False, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + # + # torch.linalg + # + ReductionPythonRefInfo( + "_refs.linalg.vector_norm", + torch_opinfo_name="linalg.vector_norm", + supports_out=True, + supports_nvfuser=False, # clone_default + ), + PythonRefInfo( + "_refs.linalg.matrix_norm", + torch_opinfo_name="linalg.matrix_norm", + supports_out=True, + # Uses svdvals which does not support nvfuser + supports_nvfuser=False, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.linalg.norm", + torch_opinfo_name="linalg.norm", + supports_out=True, + # Uses svdvals which does not support nvfuser + supports_nvfuser=False, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.linalg.svd", + torch_opinfo_name="linalg.svd", + supports_out=True, + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.linalg.svdvals", + torch_opinfo_name="linalg.svdvals", + supports_out=True, + supports_nvfuser=False, + ), # # Tensor Creation Reference OpInfos # @@ -21094,6 +21961,51 @@ def __init__( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), ), ), + PythonRefInfo( + "_refs.new_empty", + torch_opinfo_name="new_empty", + supports_nvfuser=False, + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # Fixme: should not compare results of empty_like + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + ), + ), + PythonRefInfo( + "_refs.new_full", + torch_opinfo_name="new_full", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.new_ones", + torch_opinfo_name="new_ones", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.new_zeros", + torch_opinfo_name="new_zeros", + supports_nvfuser=False, + ), # # Conditional Reference OpInfos # @@ -21106,6 +22018,118 @@ def __init__( "_refs.where", torch_opinfo_name="where", op=lambda self, condition, other: refs.where(condition, self, other), + supports_nvfuser=False, + ), + # + # Test-related functions + # + PythonRefInfo( + "_refs.allclose", + torch_opinfo_name="allclose", + supports_nvfuser=False, + ), + # + # FFT OpInfos + # + SpectralFuncPythonRefInfo( + "_refs.fft.fft", + torch_opinfo_name="fft.fft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifft", + torch_opinfo_name="fft.ifft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfft", + torch_opinfo_name="fft.rfft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfft", + torch_opinfo_name="fft.irfft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfft", + torch_opinfo_name="fft.hfft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfft", + torch_opinfo_name="fft.ihfft", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.fftn", + torch_opinfo_name="fft.fftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifftn", + torch_opinfo_name="fft.ifftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfftn", + torch_opinfo_name="fft.rfftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfftn", + torch_opinfo_name="fft.irfftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfftn", + torch_opinfo_name="fft.hfftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfftn", + torch_opinfo_name="fft.ihfftn", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.fft2", + torch_opinfo_name="fft.fft2", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ifft2", + torch_opinfo_name="fft.ifft2", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.rfft2", + torch_opinfo_name="fft.rfft2", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.irfft2", + torch_opinfo_name="fft.irfft2", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.hfft2", + torch_opinfo_name="fft.hfft2", + supports_nvfuser=False, + ), + SpectralFuncPythonRefInfo( + "_refs.fft.ihfft2", + torch_opinfo_name="fft.ihfft2", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.fft.fftshift", + torch_opinfo_name="fft.fftshift", + supports_nvfuser=False, + ), + PythonRefInfo( + "_refs.fft.ifftshift", + torch_opinfo_name="fft.ifftshift", + supports_nvfuser=False, ), ] diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index c96e94f61ad07..7a0d60b802411 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_cuda import TEST_CUDNN from torch.testing._internal.common_dtype import floating_types from torch.testing._internal.common_device_type import ( - _TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol, + _TestParametrizer, _update_param_kwargs, toleranceOverride, tol, skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta) from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_nn import nllloss_reference, get_reduction @@ -104,25 +104,13 @@ def _parametrize_test(self, test, generic_cls, device_cls): _update_param_kwargs(param_kwargs, 'training', training) try: - active_decorators = [set_single_threaded_if_parallel_tbb] - if module_info.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype): - active_decorators.append(skipIf(True, "Skipped!")) - - if module_info.decorators is not None: - for decorator in module_info.decorators: - # Can't use isinstance as it would cause a circular import - if decorator.__class__.__name__ == 'DecorateInfo': - if decorator.is_active(generic_cls.__name__, test.__name__, - device_cls.device_type, dtype): - active_decorators += decorator.decorators - else: - active_decorators.append(decorator) @wraps(test) def test_wrapper(*args, **kwargs): return test(*args, **kwargs) - for decorator in active_decorators: + for decorator in module_info.get_decorators(generic_cls.__name__, test.__name__, + device_cls.device_type, dtype): test_wrapper = decorator(test_wrapper) yield (test_wrapper, test_name, param_kwargs) @@ -187,16 +175,22 @@ def __init__(self, ): self.module_cls = module_cls self.module_inputs_func = module_inputs_func - self.skips = skips - self.decorators = decorators + self.decorators = (*(decorators if decorators else []), *(skips if skips else [])) self.dtypes = dtypes self.supports_gradgrad = supports_gradgrad self.gradcheck_nondet_tol = gradcheck_nondet_tol self.module_memformat_affects_out = module_memformat_affects_out self.train_and_eval_differ = train_and_eval_differ - def should_skip(self, cls_name, test_name, device_type, dtype): - return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips) + def get_decorators(self, test_class, test_name, device, dtype): + result = [set_single_threaded_if_parallel_tbb] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active(test_class, test_name, device, dtype): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result @property def name(self): @@ -706,6 +700,26 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r desc='no_batch_dim' )) + def fast_path_reference_fn(module, parameters, *args, **kwargs): + assert not module.training + module = module.train(True) + output = module(*args, **kwargs) + module = module.train(False) + return output + + if not training: + for norm_first in (True, False): + samples.append( + ModuleInput( + constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first), + forward_input=FunctionInput( + make_input((2, 3, 4)), + ), + reference_fn=fast_path_reference_fn, + desc="fast_path_norm_first" if norm_first else "fast_path" + ) + ) + return samples @@ -1074,6 +1088,10 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # Failure on ROCM for float32 issue #70125 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='cuda', dtypes=[torch.float64]), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1088,6 +1106,9 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # Failure on ROCM for float32 issue #70125 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1116,6 +1137,11 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # Failure on ROCM for float32 issue #70125 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'), + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', + dtypes=[torch.float64]), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1130,6 +1156,9 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # Failure on ROCM for float32 issue #70125 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1176,6 +1205,10 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # Lazy modules don't currently play well with ModuleInfo tests on the meta device. # See https://github.com/pytorch/pytorch/issues/70505 for more info. DecorateInfo(skipMeta), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", + device_type='cuda', dtypes=[torch.float64]), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1193,6 +1226,9 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # See https://github.com/pytorch/pytorch/issues/70505 for more info. DecorateInfo(skipMeta), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1227,6 +1263,11 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # See https://github.com/pytorch/pytorch/issues/70505 for more info. DecorateInfo(skipMeta), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'), + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', + dtypes=[torch.float64]), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -1244,6 +1285,9 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train # See https://github.com/pytorch/pytorch/issues/70505 for more info. DecorateInfo(skipMeta), DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]), + # This was wrongly being skipped before and needs investigation. + # See https://github.com/pytorch/pytorch/issues/80247 + DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index a1aeb5414be80..d68ef4387b095 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -4,7 +4,7 @@ import unittest from copy import deepcopy -from functools import reduce +from functools import reduce, partial, wraps from itertools import product from operator import mul from math import pi @@ -499,22 +499,22 @@ def bce_with_logistic_no_reduce_scalar_test(): def kldivloss_with_target_no_reduce_test(): - i = torch.rand(10, 10).log() + t = torch.rand(10, 10) return dict( fullname='KLDivLoss_with_target_no_reduce', constructor=wrap_functional( - lambda t: F.kl_div(i.type_as(t), t, reduction='none')), - cpp_function_call='F::kl_div(i.to(t.options()), t, F::KLDivFuncOptions().reduction(torch::kNone))', - input_fn=lambda: torch.rand(10, 10), - cpp_var_map={'i': i, 't': '_get_input()'}, - reference_fn=lambda t, *_: - loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduction='none'), + lambda i: F.kl_div(i, t.type_as(i), reduction='none')), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False) def kldivloss_no_reduce_test(): - t = torch.randn(10, 10) + t = torch.rand(10, 10) return dict( fullname='KLDivLoss_no_reduce', constructor=wrap_functional( @@ -530,7 +530,7 @@ def kldivloss_no_reduce_test(): def kldivloss_no_reduce_scalar_test(): - t = torch.randn(()) + t = torch.rand(()) return dict( fullname='KLDivLoss_no_reduce_scalar', constructor=wrap_functional( @@ -545,22 +545,22 @@ def kldivloss_no_reduce_scalar_test(): def kldivloss_with_log_target_no_reduce_test(): - i = torch.rand(10, 10).log() + t = torch.rand(10, 10).log() return dict( fullname='KLDivLoss_with_log_target_no_reduce', constructor=wrap_functional( - lambda t: F.kl_div(i.type_as(t), t, reduction='none', log_target=True)), - cpp_function_call='F::kl_div(i.to(t.options()), t, F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', - input_fn=lambda: torch.rand(10, 10), - cpp_var_map={'i': i, 't': '_get_input()'}, - reference_fn=lambda t, *_: - loss_reference_fns['KLDivLoss_log_target'](i.type_as(t), t, reduction='none'), + lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), + cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', + input_fn=lambda: torch.rand(10, 10).log(), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False) def kldivloss_no_reduce_log_target_test(): - t = torch.randn(10, 10) + t = torch.rand(10, 10).log() return dict( fullname='KLDivLoss_no_reduce_log_target', constructor=wrap_functional( @@ -576,7 +576,7 @@ def kldivloss_no_reduce_log_target_test(): def kldivloss_no_reduce_scalar_log_target_test(): - t = torch.randn(()) + t = torch.rand(()).log() return dict( fullname='KLDivLoss_no_reduce_scalar_log_target', constructor=wrap_functional( @@ -4340,9 +4340,7 @@ def unsqueeze_inp(inp): def kldivloss_reference(input, target, reduction='mean'): - safe_target = target * (target > 0).type_as(target) - safe_target_log = (safe_target + (target <= 0).type_as(target)).log() - result = safe_target * (safe_target_log - input) + result = target * (target.log() - input) if reduction == 'mean': return result.mean() elif reduction == 'sum': @@ -4854,10 +4852,12 @@ def padding3d_circular(input, pad): ), dict( module_name='KLDivLoss', + constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)), + cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)', input_fn=lambda: torch.rand(10, 10).log(), - target_fn=lambda: torch.rand(10, 10), + target_fn=lambda: torch.rand(10, 10).log(), reference_fn=lambda i, t, m: - kldivloss_log_target_reference(i, t.log(), get_reduction(m)), + kldivloss_log_target_reference(i, t, get_reduction(m)), check_sum_reduction=True, desc='log_target', ), @@ -5451,10 +5451,12 @@ def padding3d_circular(input, pad): ), dict( module_name='KLDivLoss', + constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)), + cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)', input_fn=lambda: torch.rand(()).log(), - target_fn=lambda: torch.rand(()), + target_fn=lambda: torch.rand(()).log(), reference_fn=lambda i, t, m: - kldivloss_log_target_reference(i, t.log(), get_reduction(m)), + kldivloss_log_target_reference(i, t, get_reduction(m)), check_sum_reduction=True, desc='scalar_log_target', ), @@ -5651,7 +5653,7 @@ def flatten(xs): # Check that regression criterion work with no batch dimensions regression_criterion_no_batch = [ - 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'KLDivLoss', 'HuberLoss', 'SmoothL1Loss' + 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss' ] reductions = ['none', 'mean', 'sum'] for name, reduction in product(regression_criterion_no_batch, reductions): @@ -5666,6 +5668,18 @@ def flatten(xs): criterion_tests.append(regression_test_info) +for reduction in reductions: + regression_test_info = dict( + fullname=f"KLDivLoss_no_batch_dim_{reduction}", + constructor=lambda: nn.KLDivLoss(reduction=reduction), + input_fn=lambda: torch.rand((3,)).log(), + target_fn=lambda: torch.rand((3,)), + reference_fn=single_batch_reference_criterion_fn, + test_cpp_api_parity=False, + ) + criterion_tests.append(regression_test_info) + + # Check that classification criterion work with no batch dimensions # List of tuples of (name, input_fn, target_fn) classification_criterion_no_batch = [ diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 4835ef5431288..6c677d9c689ab 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -34,10 +34,11 @@ try: # graph mode quantization based on fx - from torch.quantization.quantize_fx import ( + from torch.ao.quantization.quantize_fx import ( prepare_fx, prepare_qat_fx, convert_fx, + convert_to_reference_fx, ) from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph from torch.fx.graph import Node @@ -883,7 +884,7 @@ def checkGraphModeFxOp( prepared_copy = copy.deepcopy(prepared) qgraph = convert_fx(copy.deepcopy(prepared)) - qgraph_reference = convert_fx(copy.deepcopy(prepared), is_reference=True) + qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared)) result = qgraph(*inputs) result_reference = qgraph_reference(*inputs) qgraph_copy = copy.deepcopy(qgraph) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index eb8bf6042ab38..b13d6678c5188 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -75,6 +75,8 @@ unregister_custom_op_symbolic) torch.backends.disable_global_flags() +PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"] + FILE_SCHEMA = "file://" if sys.platform == 'win32': FILE_SCHEMA = "file:///" @@ -91,11 +93,9 @@ DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json' SLOW_TESTS_FILE = '.pytorch-slow-tests.json' -slow_tests_dict: Optional[Dict[str, Any]] = None -disabled_tests_dict: Optional[Dict[str, Any]] = None - NATIVE_DEVICES = ('cpu', 'cuda', 'meta') + class _TestParametrizer(object): """ Decorator class for parametrizing a test function, yielding a set of new tests spawned @@ -294,7 +294,7 @@ def test_baz(self, x, y): arg_str (str): String of arg names separate by commas (e.g. "x,y"). arg_values (iterable): Iterable of arg values (e.g. range(10)) or tuples of arg values (e.g. [(1, 2), (3, 4)]). - name_fn (callable): Optional function that takes in parameters and returns subtest name. + name_fn (Callable): Optional function that takes in parameters and returns subtest name. """ def __init__(self, arg_str, arg_values, name_fn=None): self.arg_names: List[str] = [s.strip() for s in arg_str.split(',')] @@ -335,6 +335,7 @@ def _parametrize_test(self, test, generic_cls, device_cls): # Each "values" item is expected to be either: # * A tuple of values with one for each arg. For a single arg, a single item is expected. # * A subtest instance with arg_values matching the previous. + values = check_exhausted_iterator = object() for values in self.arg_values: maybe_name = None if isinstance(values, subtest): @@ -370,6 +371,10 @@ def test_wrapper(*args, **kwargs): yield (gen_test, test_name, param_kwargs) + if values is check_exhausted_iterator: + raise ValueError('An empty arg_values was passed to @parametrize. ' + 'Note that this may result from reuse of a generator.') + class ProfilingMode(Enum): LEGACY = 1 @@ -507,6 +512,8 @@ def run_unittest_help(argv): # CI Prefix path used only on CI environment CI_TEST_PREFIX = str(Path(os.getcwd())) +CI_PT_ROOT = str(Path(os.getcwd()).parent) +CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p): try: @@ -571,6 +578,15 @@ def sanitize_test_filename(filename): strip_py = re.sub(r'.py$', '', filename) return re.sub('/', r'.', strip_py) +# hack until https://github.com/pytorch/pytorch/issues/82109 is resolved +def sanitize_if_functorch_test_filename(filename): + # absolute filenames must be converted to relative paths, otherwise, + # we cannot prepend test-reports/ to it + # (e.g. test-reports\\C:\\... on windows is nonsense) + if filename.startswith(CI_FUNCTORCH_ROOT): + filename = filename[len(CI_PT_ROOT) + 1:] + return filename + def lint_test_case_extension(suite): succeed = True for test_case_or_suite in suite: @@ -589,20 +605,33 @@ def lint_test_case_extension(suite): succeed = False return succeed +def sanitize_pytest_xml(xml_file: str): + # pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml + # consider somehow modifying the XML logger in conftest to do this instead + import xml.etree.ElementTree as ET + tree = ET.parse(xml_file) + for testcase in tree.iter('testcase'): + full_classname = testcase.attrib['classname'] + regex_result = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname) + classname = regex_result.group(2) + file = regex_result.group(1).replace('.', "/") + testcase.set('classname', classname) + testcase.set('file', f"{file}.py") + tree.write(xml_file) + def run_tests(argv=UNITTEST_ARGS): # import test files. if IMPORT_SLOW_TESTS: if os.path.exists(IMPORT_SLOW_TESTS): - global slow_tests_dict with open(IMPORT_SLOW_TESTS, 'r') as fp: - slow_tests_dict = json.load(fp) + # use env vars so pytest-xdist subprocesses can still access them + os.environ['SLOW_TESTS_DICT'] = fp.read() else: print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}') if IMPORT_DISABLED_TESTS: if os.path.exists(IMPORT_DISABLED_TESTS): - global disabled_tests_dict with open(IMPORT_DISABLED_TESTS, 'r') as fp: - disabled_tests_dict = json.load(fp) + os.environ['DISABLED_TESTS_DICT'] = fp.read() else: print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}') # Determine the test launch mechanism @@ -678,17 +707,45 @@ def addSkip(self, test, reason): # it stands for `verbose_str` captured in the closure c.cell_contents = f"skip: {reason}" - test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1))) + test_filename = inspect.getfile(sys._getframe(1)) + test_filename = sanitize_if_functorch_test_filename(test_filename) + test_filename = sanitize_test_filename(test_filename) test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = os.path.join(test_report_path, test_filename) - os.makedirs(test_report_path, exist_ok=True) - verbose = '--verbose' in argv or '-v' in argv - if verbose: - print('Test results will be stored in {}'.format(test_report_path)) - unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner( - output=test_report_path, - verbosity=2 if verbose else 1, - resultclass=XMLTestResultVerbose)) + build_environment = os.environ.get("BUILD_ENVIRONMENT", "") + if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not ( + "cuda" in build_environment and "linux" in build_environment + ): + # exclude linux cuda tests because we run into memory issues when running in parallel + import pytest + os.environ["NO_COLOR"] = "1" + os.environ["USING_PYTEST"] = "1" + pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest') + os.makedirs(pytest_report_path, exist_ok=True) + # part of our xml parsing looks for grandparent folder names + pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml") + print(f'Test results will be stored in {pytest_report_path}') + # mac slower on 4 proc than 3 + num_procs = 3 if "macos" in build_environment else 4 + # f = failed + # E = error + # X = unexpected success + exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x', + '--reruns=2', '-rfEX', f'--junit-xml-reruns={pytest_report_path}']) + del os.environ["USING_PYTEST"] + sanitize_pytest_xml(f'{pytest_report_path}') + # exitcode of 5 means no tests were found, which happens since some test configs don't + # run tests from certain files + exit(0 if exit_code == 5 else exit_code) + else: + os.makedirs(test_report_path, exist_ok=True) + verbose = '--verbose' in argv or '-v' in argv + if verbose: + print(f'Test results will be stored in {test_report_path}') + unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner( + output=test_report_path, + verbosity=2 if verbose else 1, + resultclass=XMLTestResultVerbose)) elif REPEAT_COUNT > 1: for _ in range(REPEAT_COUNT): if not unittest.main(exit=False, argv=argv).result.wasSuccessful(): @@ -823,6 +880,30 @@ def __torch_function__(self, func, types, args=(), kwargs=None): r = func(*args, **kwargs) return r +# Run PyTorch tests with TorchDynamo +TEST_WITH_TORCHDYNAMO = os.getenv('PYTORCH_TEST_WITH_DYNAMO') == '1' +if TEST_WITH_TORCHDYNAMO: + import torchdynamo + # torchdynamo.config.trace = True + # torchdynamo.config.debug = True + torchdynamo.config.print_internal_exceptions = False + # TODO - Collect errors with fake tensors + torchdynamo.config.fake_tensor_propagation = False + # Do not spend time on helper functions that are called with different inputs + torchdynamo.config.cache_size_limit = 8 + + +def skipIfTorchDynamo(msg="test doesn't currently work with torchdynamo"): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_TORCHDYNAMO: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + return wrapper + return decorator + # Determine whether to enable cuda memory leak check. # CUDA mem leak check is expensive and thus we don't want to execute it on every # test case / configuration. @@ -1188,20 +1269,35 @@ def set_rng_seed(seed): np.random.seed(seed) +@contextmanager +def disable_functorch(): + guard = torch._C._DisableFuncTorch() # type: ignore[attr-defined] + try: + yield + finally: + del guard + + @contextlib.contextmanager def freeze_rng_state(): # no_dispatch needed for test_composite_compliance # Some OpInfos use freeze_rng_state for rng determinism, but # test_composite_compliance overrides dispatch for all torch functions # which we need to disable to get and set rng state - with no_dispatch(): + with no_dispatch(), disable_functorch(): rng_state = torch.get_rng_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() try: yield finally: - with no_dispatch(): + # Modes are not happy with torch.cuda.set_rng_state + # because it clones the state (which could produce a Tensor Subclass) + # and then grabs the new tensor's data pointer in generator.set_state. + # + # In the long run torch.cuda.set_rng_state should probably be + # an operator. + with no_dispatch(), disable_functorch(): if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) torch.set_rng_state(rng_state) @@ -1464,13 +1560,16 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str: def check_if_enable(test: unittest.TestCase): test_suite = str(test.__class__).split('\'')[1] + if "USING_PYTEST" in os.environ: + test_suite = f"__main__.{test_suite.split('.')[1]}" raw_test_name = f'{test._testMethodName} ({test_suite})' - if slow_tests_dict is not None and raw_test_name in slow_tests_dict: + if raw_test_name in json.loads(os.environ.get("SLOW_TESTS_DICT", "{}")): getattr(test, test._testMethodName).__dict__['slow_test'] = True if not TEST_WITH_SLOW: raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName) - if not IS_SANDCASTLE and disabled_tests_dict is not None: + if not IS_SANDCASTLE and "DISABLED_TESTS_DICT" in os.environ: + disabled_tests_dict = json.loads(os.environ["DISABLED_TESTS_DICT"]) for disabled_test, (issue_url, platforms) in disabled_tests_dict.items(): disable_test_parts = disabled_test.split() if len(disable_test_parts) > 1: @@ -1826,7 +1925,16 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re failures_before = 0 if result is None else len(result.failures) # num tests marked as failed before starting errors_before = 0 if result is None else len(result.errors) # num tests marked as errored before starting - super().run(result=result) + + if TEST_WITH_TORCHDYNAMO: + with torchdynamo.optimize("eager"): + super().run(result=result) + + # TODO - Reset for each test slows down testing significantly. + # torchdynamo.reset() + else: + super().run(result=result) + # Early terminate test if necessary. if self._should_stop_test_suite(): if result.wasSuccessful(): @@ -1881,7 +1989,7 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re def run(self, result=None): with contextlib.ExitStack() as stack: if TEST_WITH_CROSSREF: - stack.enter_context(torch.overrides.push_torch_function_mode(CrossRefMode)) + stack.enter_context(CrossRefMode()) num_runs = MAX_NUM_RETRIES + 1 if RETRY_TEST_CASES else 1 self._run_with_retry( result=result, @@ -2196,6 +2304,17 @@ def assertEqualIgnoreType(self, *args, **kwargs) -> None: # and deserves detailed investigation return self.assertEqual(*args, exact_dtype=False, **kwargs) + def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None: + r"""Tests if tensor x equals to y, if y to be broadcast to x.shape. + """ + if not isinstance(y, Iterable): + # int, float, etc. or different shape tensors + y = torch.ones_like(x) * y + if not isinstance(y, torch.Tensor): + # iterable, but not a tensor + y = torch.ones_like(x) * torch.tensor(y) + return self.assertEqual(x, y, *args, **kwargs) + def assertEqual( self, x, @@ -2250,7 +2369,7 @@ def to_list(input): ), sequence_types=( Sequence, - torch.storage._TypedStorage, + torch.storage.TypedStorage, Sequential, ModuleList, ParameterList, @@ -3015,6 +3134,14 @@ def __exit__(self, *args): # For more information see https://github.com/pytorch/pytorch/issues/56202 GRADCHECK_NONDET_TOL = 1e-12 +def is_slow_gradcheck_env() -> bool: + return os.environ.get('PYTORCH_TEST_WITH_SLOW_GRADCHECK', "0") == "1" + +skipIfSlowGradcheckEnv = unittest.skipIf( + is_slow_gradcheck_env(), + "Tests that don't use gradcheck don't need to run on slow_gradcheck CI" +) + def gradcheck(fn, inputs, **kwargs): # Wrapper around gradcheck that enables certain keys by default. # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and @@ -3027,7 +3154,7 @@ def gradcheck(fn, inputs, **kwargs): "fast_mode": True, } - if os.environ.get('PYTORCH_TEST_WITH_SLOW_GRADCHECK', "0") == "1": + if is_slow_gradcheck_env(): default_values["fast_mode"] = False for key, value in default_values.items(): @@ -3047,7 +3174,7 @@ def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs): "fast_mode": True, } - if os.environ.get('PYTORCH_TEST_WITH_SLOW_GRADCHECK', "0") == "1": + if is_slow_gradcheck_env(): default_values["fast_mode"] = False for key, value in default_values.items(): diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index 7eeefa333efaa..166c37f082cdc 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -8,6 +8,7 @@ from torch.utils._python_dispatch import enable_torch_dispatch_mode import torch.autograd.forward_ad as fwAD from torch.overrides import enable_reentrant_dispatch +from typing import Callable import re @@ -142,13 +143,19 @@ def __new__(cls, elem, *args, **kwargs): device=elem.device, requires_grad=elem.requires_grad, strides=elem.stride(), storage_offset=elem.storage_offset()) - # CompositeCompliantTensor steals the "requires_grad"-ness. if elem.requires_grad: - # Why clone? Because sometimes OpInfo shares inputs between tests... - r.elem = elem.detach().clone() + # CompositeCompliantTensor steals the "requires_grad"-ness. + # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests... + tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype, + device=elem.device, layout=elem.layout, + requires_grad=False) + tmp.copy_(elem.detach()) + r.elem = tmp else: r.elem = elem + assert r.stride() == r.elem.stride() + # Propagate conjugate bits to the wrapper tensor # Ref: https://github.com/albanD/subclass_zoo/issues/24 # Ref: https://github.com/albanD/subclass_zoo/issues/21 @@ -167,6 +174,12 @@ def unwrap(e): def wrap(e): return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e + if func == torch.ops.aten._local_scalar_dense.default: + raise RuntimeError( + ".item() is not allowed to be called inside of composite " + "functions in the PyTorch library because not all backends " + "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.") + if func.overloadpacket.__name__ in ('set_', 'resize_'): raise RuntimeError( f"{func.__name__} is not allowed to be called inside of " @@ -322,14 +335,15 @@ def raise_composite_compliance_error(err, additional_info=''): # If some composite operation does any non-compliant behavior, # CompositeCompliantTensor will raise an error. -def check_all_permutations(op, args, kwargs): +def check_all_permutations(op, args, kwargs, assert_equal_fn): CCT = generate_cct() + expected = op(*args, **kwargs) for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice try: - op(*new_args, **new_kwargs) - # NOTE: [What errors are Composite Compiance trying to catch?] + actual = op(*new_args, **new_kwargs) + # NOTE: [What errors are Composite Compliance trying to catch?] # # There's two things we want to catch: # - errors that would raise within the torch_dispatch impl @@ -349,6 +363,11 @@ def check_all_permutations(op, args, kwargs): f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" ) + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + # Checks via the usage of torch dispatch mode certain anti-patterns that # are not composite compliant. # @@ -361,21 +380,28 @@ def check_all_permutations(op, args, kwargs): # CompositeCompliantTensor wrappers. If an operator that is # Composite does any non-compliant behavior, # CompositeCompliantTensor will raise an error. -def check_with_mode(op, args, kwargs): +def check_with_mode(op, args, kwargs, assert_equal_fn): CCT = generate_cct() def wrap(e): return CCT(e) if isinstance(e, torch.Tensor) else e + expected = op(*args, **kwargs) + args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) try: with enable_torch_dispatch_mode(CCT): - op(*args, **kwargs) - # see NOTE: [What errors are Composite Compiance trying to catch?] + actual = op(*args, **kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] except RuntimeError as err: raise_composite_compliance_error(err) + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + def gather_leaf_tensors(args, kwargs): leaf_tensors = [] args, args_spec = tree_flatten(args) @@ -392,19 +418,50 @@ def gather_leaf_tensors(args, kwargs): # Checks if the backward formula is composite compliant by testing # all possible permutations of {inputs, grad_outputs} being # CompositeCompliantTensor or regular Tensors. -def check_backward_formula(op, args, kwargs, output_process_fn_grad=None): - assert op.supports_autograd +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_backward_formula to things that aren't OpInfos +# while debugging. +def check_backward_formula(op: Callable, args, kwargs, + output_process_fn_grad=None, + gradcheck_wrapper=None, assert_equal_fn=None): CCT = generate_cct() + + def compute_expected_grads(args, kwargs): + if gradcheck_wrapper is None: + results = op(*args, **kwargs) + else: + results = gradcheck_wrapper(op, *args, **kwargs) + + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + + flat_results, _ = tree_flatten(results) + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) + for r in flat_diff_results] + leaf_tensors = gather_leaf_tensors(args, kwargs) + assert len(leaf_tensors) > 0 + return torch.autograd.grad(flat_diff_results, leaf_tensors, + grads, allow_unused=True, retain_graph=True) + + expected = compute_expected_grads(args, kwargs) + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice leaf_tensors = gather_leaf_tensors(new_args, new_kwargs) assert len(leaf_tensors) > 0 try: - results = op.gradcheck_wrapper(op.get_op(), *new_args, **new_kwargs) + if gradcheck_wrapper is None: + results = op(*new_args, **new_kwargs) + else: + results = gradcheck_wrapper(op, *new_args, **new_kwargs) if output_process_fn_grad is not None: results = output_process_fn_grad(results) - # see NOTE: [What errors are Composite Compiance trying to catch?] + # see NOTE: [What errors are Composite Compliance trying to catch?] except RuntimeError as err: raise_composite_compliance_error( err, @@ -412,11 +469,6 @@ def check_backward_formula(op, args, kwargs, output_process_fn_grad=None): f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" ) - # Hack: tree_flatten doesn't handle torch.return_types yet, - # so we're gonna convert them to tuple. - # TODO: https://github.com/pytorch/pytorch/issues/74624 - if isinstance(results, tuple): - results = tuple(results) flat_results, _ = tree_flatten(results) flat_diff_results = [r for r in flat_results if r.requires_grad] assert len(flat_diff_results) > 0 @@ -426,9 +478,9 @@ def check_backward_formula(op, args, kwargs, output_process_fn_grad=None): for r in flat_diff_results] for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT): try: - torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, - allow_unused=True, retain_graph=True) - # see NOTE: [What errors are Composite Compiance trying to catch?] + actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, + allow_unused=True, retain_graph=True) + # see NOTE: [What errors are Composite Compliance trying to catch?] except RuntimeError as err: raise_composite_compliance_error( err, @@ -437,55 +489,79 @@ def check_backward_formula(op, args, kwargs, output_process_fn_grad=None): f"- wrapped_grads: {which_grad_is_batched}\n" ) + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True) + # Checks if the forward AD formula is composite compliant by testing # all possible permutations of {primals, tangents} being # CompositeCompliantTensor or regular Tensors. -def check_forward_ad_formula(op, args, kwargs): - assert op.supports_forward_ad - +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_forward_ad_formula to things that aren't OpInfos +# while debugging. +def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None): CCT = generate_cct(enable_recursive_torch_dispatch=True, autograd_view_consistency=False) - # Permutations of arg and kwargs in CCT. - for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): - new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice - def maybe_tangent(t): - assert type(t) is not CCT - # Generate `tangent` tensor - # if given object is a Tensor and requires grad is set. - if isinstance(t, torch.Tensor) and t.requires_grad: - return torch.randn_like(t) - elif is_tensorlist(t): - return list(torch.randn_like(e) if e.requires_grad else None for e in t) - return None - - tangent_args = tuple(maybe_tangent(arg) for arg in args) - flat_kwargs, spec = tree_flatten(kwargs) - flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs) - tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec) - - # Permutations tangent arg and tangent kwargs in CCT. - for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT): - new_tang_args, new_tang_kwargs, \ - which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice - - with fwAD.dual_level(): - def maybe_make_dual(dual): - # Returns dual tensor if primal is a tensor/tensor subclass - # with requires_grad set. - primal, tangent = dual - if isinstance(primal, torch.Tensor) and primal.requires_grad: - return fwAD.make_dual(primal, tangent) - elif is_tensorlist(primal): - return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri - for pri, tang in zip(primal, tangent)) - return primal + def maybe_tangent(t): + assert type(t) is not CCT + # Generate `tangent` tensor + # if given object is a Tensor and requires grad is set. + if isinstance(t, torch.Tensor) and t.requires_grad: + return torch.randn_like(t) + elif is_tensorlist(t): + return list(torch.randn_like(e) if e.requires_grad else None for e in t) + return None + + tangent_args = tuple(maybe_tangent(arg) for arg in args) + flat_kwargs, spec = tree_flatten(kwargs) + flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs) + tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec) + + with fwAD.dual_level(): + def maybe_make_dual(dual): + # Returns dual tensor if primal is a tensor/tensor subclass + # with requires_grad set. + primal, tangent = dual + if isinstance(primal, torch.Tensor) and primal.requires_grad: + return fwAD.make_dual(primal, tangent) + elif is_tensorlist(primal): + return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri + for pri, tang in zip(primal, tangent)) + return primal + + def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): + op_args = tuple(map(maybe_make_dual, zip(args, tangent_args))) + op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()} + + if gradcheck_wrapper is None: + return op(*op_args, **op_kwargs) + return gradcheck_wrapper(op, *op_args, **op_kwargs) + + expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) + expected = tree_map(fwAD.unpack_dual, expected) + expected_primals = tree_map(lambda x: x.primal, expected) + expected_tangents = tree_map(lambda x: x.tangent, expected) + + # Permutations of arg and kwargs in CCT. + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + # Permutations tangent arg and tangent kwargs in CCT. + for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT): + new_tang_args, new_tang_kwargs, \ + which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args))) op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()} try: - op.gradcheck_wrapper(op.get_op(), *op_args, **op_kwargs) - # see NOTE: [What errors are Composite Compiance trying to catch?] + if gradcheck_wrapper is None: + actual = op(*op_args, **op_kwargs) + else: + actual = gradcheck_wrapper(op, *op_args, **op_kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] except RuntimeError as err: raise_composite_compliance_error( err, @@ -494,3 +570,12 @@ def maybe_make_dual(dual): f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n" f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n" ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + actual = tree_map(fwAD.unpack_dual, actual) + actual_primals = tree_map(lambda x: unwrap(x.primal), actual) + actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual) + assert_equal_fn(actual_primals, expected_primals, equal_nan=True) + assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/torch/testing/_internal/distributed/distributed_utils.py b/torch/testing/_internal/distributed/distributed_utils.py new file mode 100644 index 0000000000000..8473077c3c7f4 --- /dev/null +++ b/torch/testing/_internal/distributed/distributed_utils.py @@ -0,0 +1,64 @@ +from contextlib import contextmanager +from datetime import timedelta +from functools import ( + partial, + wraps, +) + +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d + +class MockProcessGroup(dist.ProcessGroup): + + def __init__(self, rank, world): + super(MockProcessGroup, self).__init__(rank, world) + + def getBackendName(self): + return "mock_process_group" + +def create_mock_pg(prefix_store, rank, world_size, timeout): + return MockProcessGroup(rank, world_size) + +dist.Backend.register_backend('mock_process_group', create_mock_pg) + +def mock_init_dist(rank, world_size): + # !!! WARNING !!! + # Kids don't try this at home, this is a cute pile of hacks that + # depends on a small mountain of c10d internals + assert not dist.is_initialized() + store = dist.HashStore() + # Trick _store_based_barrier into believing everyone else already checked-in + # Zero is the group index + store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1) + dist.init_process_group( + backend="mock_process_group", + rank=rank, + world_size=world_size, + store=store, + group_name="fake", + timeout=timedelta(seconds=1)) + +@contextmanager +def with_dist(rank=0, world_size=2): + """ + Context manager that initializer c10d with a fake process group. + """ + mock_init_dist(rank=rank, world_size=world_size) + try: + yield + finally: + dist.destroy_process_group() + +def with_fake_comms(func=None, rank=0, world_size=2): + """ + Function wrapper that inits a fake process group designed for testing. + Right now only querying for world size is available + """ + if func is None: + return partial(with_fake_comms, rank=rank, world_size=world_size) + + @wraps(func) + def wrapper(self, *args, **kwargs): + with with_dist(rank, world_size): + func(self, *args, **kwargs) + return wrapper diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py index 22b33eace3ce1..649c411fc9443 100644 --- a/torch/testing/_internal/logging_tensor.py +++ b/torch/testing/_internal/logging_tensor.py @@ -4,7 +4,7 @@ import logging import contextlib import itertools -from torch.utils._python_dispatch import TorchDispatchMode, push_torch_dispatch_mode +from torch.utils._python_dispatch import TorchDispatchMode # How the chain of calls works for LoggingTensor: @@ -124,5 +124,5 @@ def capture_logs(is_mode=False) -> Iterator[List[str]]: @contextlib.contextmanager def capture_logs_with_logging_tensor_mode(): - with push_torch_dispatch_mode(LoggingTensorMode), capture_logs(True) as logs: + with LoggingTensorMode(), capture_logs(True) as logs: yield logs diff --git a/torch/testing/_internal/schema_check_mode.py b/torch/testing/_internal/schema_check_mode.py index fe124379c5e67..3aa7c400c8bd5 100644 --- a/torch/testing/_internal/schema_check_mode.py +++ b/torch/testing/_internal/schema_check_mode.py @@ -3,6 +3,18 @@ from torch.fx.operator_schemas import normalize_function from torch.testing._internal.jit_utils import clone_inputs from torch.utils._python_dispatch import TorchDispatchMode +from itertools import combinations +from collections import namedtuple +from copy import deepcopy + +# Named Tuples used within SchemaCheckMode +Mutation = namedtuple('Mutation', ['op_name', 'arg_name']) +Aliasing = namedtuple('Aliasing', ['op_name', 'arg_name', 'output_number']) + +# Simplified naming for C++ classes +SchemaArgument = torch._C._SchemaArgument +SchemaArgType = torch._C._SchemaArgType +SchemaInfo = torch._C._SchemaInfo # This TorchDispatchMode Subclass is used to verify op schemas # This TorchDispatchMode Scubclass currently: @@ -12,29 +24,42 @@ class SchemaCheckMode(TorchDispatchMode): def __init__(self): + # Information recorded for testing purposes. For example: + # - incorrect schemas + # - overly conservative schemas self.ops = [] + self.mutated = [] + self.aliasing = [] def reset_cache(self): self.ops.clear() + self.mutated.clear() + self.aliasing.clear() def display_ops(self): print(*self.ops, sep=",") def __torch_dispatch__(self, func, types, args=(), kwargs=None): - def has_mutated(before, after): - return not torch.equal(before, after) if isinstance(before, torch.Tensor) and isinstance(after, torch.Tensor) else False + def has_mutated(before, after, md): + if type(before) == torch.Tensor and type(after) == torch.Tensor: + return not ( + torch.equal(before, after) and + md[0] == after.stride() and + md[1] == after.storage()._cdata + ) + return False def has_aliased(lhs, rhs): - return torch._C._is_alias_of(lhs, rhs) if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor) else False + try: + return torch._C._overlaps(lhs, rhs) + except Exception as exception: + if str(exception).startswith("Cannot inspect value of type "): + return False + else: + raise exception - def is_mutable(alias_info): - return alias_info is not None and alias_info.is_write - - def is_aliasing(lhs_alias_info, rhs_alias_info): - if lhs_alias_info is None or rhs_alias_info is None: - return False - else: - return bool(len(lhs_alias_info.before_set & rhs_alias_info.before_set)) + def standardize_name(name): + return name if name != "self" else "input" def unwrap(e): if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: @@ -42,36 +67,70 @@ def unwrap(e): return e.elem except AttributeError as t: return e - else: - return e + return e + + def parse_metadata(e): + if isinstance(e, torch.Tensor): + if not type(e) == torch.Tensor: + try: + current = e.elem + return (deepcopy(current.stride()), current.storage()._cdata) + except AttributeError as t: + return None + else: + return (deepcopy(e.stride()), e.storage()._cdata) + return None self.ops.append(func._schema.name) - arguments = normalize_function( + + # Clone and process arguments and outputs + pre_arguments = normalize_function( func, args, kwargs, normalize_to_only_use_kwargs=True ).kwargs - cloned_arguments = dict(zip(arguments.keys(), clone_inputs(arguments.values()))) + c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values()))) + cloned_arguments = {name : tree_map(unwrap, c_p_args.get(name)) for name in c_p_args} + cloned_metadata = {name : tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments} + out = func(*args, **kwargs) + arguments = {name : tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments} + tuple_out = out if isinstance(out, tuple) else (out, ) + tuple_out = tree_map(unwrap, tuple_out) - for arg in func._schema.arguments: - name = arg.name if arg.name != "self" else "input" + schema_info = SchemaInfo(func._schema) + schema_info.add_argument_values(pre_arguments) + + # Process arguments with outputs + for i in range(len(func._schema.arguments)): + arg = func._schema.arguments[i] + name = standardize_name(arg.name) if arguments.get(name) is not None: - before = tree_flatten(cloned_arguments.get(name))[0] - after = tree_flatten(arguments.get(name))[0] - u_values = tree_map(unwrap, after) - u_out = tree_map(unwrap, out) - if (any([has_mutated(i, j) for i, j in zip(before, after)]) and not is_mutable(arg.alias_info)): - raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated") - for v in u_values: - if not isinstance(u_out, tuple): - if has_aliased(v, u_out) and not is_aliasing(arg.alias_info, func._schema.returns[0].alias_info): + before = cloned_arguments.get(name) + md = cloned_metadata.get(name) + after = arguments.get(name) + for j in range(len(tuple_out)): + if has_aliased(tuple_out[j], after): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, j), + SchemaArgument(SchemaArgType.input, i)): raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing') + else: + self.aliasing.append(Aliasing(func._schema.name, name, f"output_{j}")) + if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)): + if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)): + raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated") else: - for j in range(len(u_out)): - if has_aliased(v, u_out[j]) and not is_aliasing(arg.alias_info, func._schema.returns[j].alias_info): - raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing') + self.mutated.append(Mutation(func._schema.name, name)) + + # Aliasing between outputs + for i, j in combinations(range(len(func._schema.returns)), 2): + if has_aliased(tuple_out[i], tuple_out[j]): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, i), + SchemaArgument(SchemaArgType.output, j)): + raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly') return out diff --git a/torch/torch_version.py b/torch/torch_version.py index 1e2b1348a312c..745595f1df15b 100644 --- a/torch/torch_version.py +++ b/torch/torch_version.py @@ -1,6 +1,8 @@ from typing import Any, Iterable from .version import __version__ as internal_version +__all__ = ['TorchVersion', 'Version', 'InvalidVersion'] + class _LazyImport: """Wraps around classes lazy imported from packaging.version Output of the function v in following snippets are identical: diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py index ded57fcc5ba4d..9b2526bc08386 100644 --- a/torch/utils/_mode_utils.py +++ b/torch/utils/_mode_utils.py @@ -1,24 +1,23 @@ import functools import torch -from typing import Iterator +from typing import Iterator, TypeVar from dataclasses import dataclass from contextlib import contextmanager +T = TypeVar('T') + # This file has all the logic to dedupe logic between torch dispatch and # torch function modes # # Specifically, it has the helper functions for enable_ and push_X_mode and the # ModeInfo class, which is extended by each where they are different -# used by both TorchFunctionMode and TorchDispatchMode, this will wrap the init -# function to require an "inner" kwarg def _wrap_init(f): - undef = object() - @functools.wraps(f) - def wrapped(self, *args, inner=undef, **kwargs): - if inner is not undef: - self.inner = inner + def wrapped(self, *args, **kwargs): + if 'inner' in kwargs: + self.inner = kwargs['inner'] + del kwargs['inner'] return f(self, *args, **kwargs) return wrapped @@ -31,7 +30,6 @@ def wrapped(self, *args, inner=undef, **kwargs): class _ModeInfo: mode_name: str mode_class: type # the class related to the mode that's allowed to be passed in - base_mode_class: type # the base class of mode_class that dispatches to the original function def mode_class_name(self): return self.mode_class.__name__ @@ -51,7 +49,7 @@ def set_mode(self, mode): # shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code. # The differences between the modes are captured by `mode_info` and then queried when they're # needed during the function's invocation -def _enable_mode(mode, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[None]: +def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]: if not ( mode is None or isinstance(mode, mode_info.mode_class) or @@ -61,7 +59,7 @@ def _enable_mode(mode, mode_info: _ModeInfo, *, replace=None, ignore_preexisting f'or None as an argument got {type(mode)} instead') old = mode_info.get_mode() if old is mode: - yield + yield mode # type: ignore[misc] return if old is not None and not ignore_preexisting and old is not replace: if isinstance(mode, mode_info.mode_class): @@ -79,14 +77,15 @@ def _enable_mode(mode, mode_info: _ModeInfo, *, replace=None, ignore_preexisting ) # NB: we don't require TorchFunctionMode/PythonMode since this is intended to also # let you directly pass a Tensor subclass type to "mode-ify" it. - required_fn = "__" + mode_info.mode_name + "__" - if not hasattr(mode, required_fn): - raise ValueError( - f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}' - ) + if mode is not None: + required_fn = "__" + mode_info.mode_name + "__" + if not hasattr(mode, required_fn): + raise ValueError( + f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}' + ) mode_info.set_mode(mode) try: - yield + yield mode # type: ignore[misc] finally: mode_info.set_mode(old) @@ -104,36 +103,6 @@ def _restore_mode(mode, mode_info: _ModeInfo): mode_info.set_mode(old) -# shared version of push_torch_function/push_torch_dispatch_mode in order to deduplicate the code. -# The differences between the modes are captured by `mode_info` and then queried when they're -# needed during the function's invocation -def _push_mode(ctor, mode_info: _ModeInfo) -> Iterator[object]: - # Helper function for pushing a mode onto the stack - if isinstance(ctor, mode_info.mode_class): - raise ValueError( - f'Expected a {mode_info.mode_class_name()} constructor function, but got an ' - f'instance of {mode_info.mode_class_name()} {ctor}. Consider using ' - f'enable_{mode_info.mode_name}_mode instead.' - ) - old = mode_info.get_mode() - if old is None: - inner = mode_info.base_mode_class(inner=None) - else: - inner = old - - mode = ctor(inner=inner) - if not isinstance(mode, mode_info.mode_class): - raise ValueError( - f'The callable passed to push_{mode_info.mode_name}_mode' - f'must return a {mode_info.mode_class_name()}' - ) - mode_info.set_mode(mode) - try: - yield mode - finally: - mode_info.set_mode(old) - - # To help with non-lexical scoping, it will error if all the modes are from different scopes or haven't been used def find_outermost_mode(modes): outermost = None diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 0bf07954e6423..0e8c57d926263 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -2,7 +2,8 @@ from typing import Iterator, Set import functools -from torch.utils._mode_utils import _enable_mode, _push_mode, _ModeInfo, _wrap_init, _restore_mode +import warnings +from torch.utils._mode_utils import _enable_mode, _ModeInfo, _wrap_init, _restore_mode from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode from dataclasses import dataclass @@ -10,8 +11,7 @@ @dataclass class TorchDispatchModeInfo(_ModeInfo): def __init__(self): - super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode, - base_mode_class=BaseTorchDispatchMode) + super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode) def get_mode(self): return _get_torch_dispatch_mode() @@ -132,13 +132,13 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta): ``NotImplemented``. Independent subclasses of :class:`TorchDispatchMode` are compositional: - modes can be pushed onto a stack with :func:`push_torch_dispatch_mode`. + modes can be pushed onto a stack using ``with MyMode():``. When you call functions in the PyTorch API inside your ``__torch_dispatch__`` implementation, by default, they will forward on to the next mode on the mode stack. If you want recursively call back into your current ``__torch_dispatch__`` implementation, either explicitly invoke ``self.__torch_dispatch__(...)``, or use the context manager - ``__torch_dispatch__(self, replace=self.inner)`` to make PyTorch + ``__torch_dispatch__(self)`` to make PyTorch API self-referential (beware of infinite loops, in this case!) """ # Force metaclass to generate constructor at the base of the hierarchy @@ -159,6 +159,7 @@ def __enter__(self): else: self.ancestors = self.inner.ancestors.union({self.inner}) _set_torch_dispatch_mode(self) + return self def __exit__(self, exc_type, exc_val, exc_tb): _set_torch_dispatch_mode(self.inner) @@ -169,7 +170,9 @@ def restore(self): @classmethod def push(cls, *args, **kwargs): - return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs)) + warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`") + instance = cls(*args, **kwargs) + return instance class BaseTorchDispatchMode(TorchDispatchMode): @@ -177,7 +180,3 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs) - -@contextlib.contextmanager -def push_torch_dispatch_mode(ctor) -> Iterator[object]: - return _push_mode(ctor, mode_info=TorchDispatchModeInfo()) diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py index f5eaa109d1efa..a8bbef3bfbeb4 100644 --- a/torch/utils/benchmark/utils/common.py +++ b/torch/utils/benchmark/utils/common.py @@ -14,7 +14,7 @@ import torch -__all__ = ["TaskSpec", "Measurement", "_make_temp_dir"] +__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"] _MAX_SIGNIFICANT_FIGURES = 4 diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 7c6b568c994c4..d3713fd708cf2 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -7,7 +7,7 @@ from torch.utils.benchmark.utils import common from torch import tensor as _tensor -__all__ = ["Compare"] +__all__ = ["Colorize", "Compare"] BEST = "\033[92m" GOOD = "\033[34m" diff --git a/torch/utils/benchmark/utils/timeit_template.cpp b/torch/utils/benchmark/utils/timeit_template.cpp index d739b70f70abe..30b6f79c0b5ae 100644 --- a/torch/utils/benchmark/utils/timeit_template.cpp +++ b/torch/utils/benchmark/utils/timeit_template.cpp @@ -10,6 +10,7 @@ sections with user provided statements. #include #include +#include #include #include diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index e030896388c09..dc774b1fc0d50 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -2,6 +2,11 @@ import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple +__all__ = [ + "checkpoint", "checkpoint_sequential", "CheckpointFunction", + "check_backward_validity", "detach_variable", "get_device_states", + "set_device_states", +] def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: if isinstance(inputs, tuple): @@ -51,6 +56,16 @@ def set_device_states(devices, states) -> None: with torch.cuda.device(device): torch.cuda.set_rng_state(state) +def _get_autocast_kwargs(): + gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + + cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(), + "dtype": torch.get_autocast_cpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + + return gpu_autocast_kwargs, cpu_autocast_kwargs class CheckpointFunction(torch.autograd.Function): @@ -60,12 +75,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} - ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(), - "dtype": torch.get_autocast_cpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. @@ -212,15 +222,18 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): passed as the tuple. For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` - preserve_rng_state(bool, optional, default=True): Omit stashing and restoring + preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. - use_reentrant(bool, optional, default=True): Use checkpointing + Default: ``True`` + use_reentrant(bool, optional): Use checkpointing implementation that requires re-entrant autograd. If ``use_reentrant=False`` is specified, ``checkpoint`` will use an implementation that does not require re-entrant autograd. This allows ``checkpoint`` to support additional functionality, such as - working as expected with ``torch.autograd.grad``. Note that future + working as expected with ``torch.autograd.grad`` and support for + keyword arguments input into the checkpointed function. Note that future versions of PyTorch will default to ``use_reentrant=False``. + Default: ``True`` args: tuple containing inputs to the :attr:`function` Returns: @@ -228,7 +241,7 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) - if kwargs: + if kwargs and use_reentrant: raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) if use_reentrant: @@ -237,7 +250,8 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): return _checkpoint_without_reentrant( function, preserve, - *args + *args, + **kwargs, ) @@ -272,8 +286,9 @@ def checkpoint_sequential(functions, segments, input, **kwargs): functions (comprising the model) to run sequentially. segments: Number of chunks to create in the model input: A Tensor that is input to :attr:`functions` - preserve_rng_state(bool, optional, default=True): Omit stashing and restoring + preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. + Default: ``True`` Returns: Output of running :attr:`functions` sequentially on :attr:`*inputs` @@ -306,7 +321,7 @@ def forward(input): preserve_rng_state=preserve) return run_function(end + 1, len(functions) - 1, functions)(input) -def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args): +def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs): """Checkpointining without re-entrant autograd Args: function: describes what to run in the forward pass of the model or @@ -314,11 +329,14 @@ def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args): passed as the tuple. For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` - preserve_rng_state(bool, optional, default=True): Omit stashing and restoring + preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. + Default: ``True`` *args: Arguments to pass in to the given ``function``. + **kwargs: Keyword arguments to pass into the given ``function``. """ - had_autocast_in_fwd = torch.is_autocast_enabled() + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs() if preserve_rng_state: fwd_cpu_state = torch.get_rng_state() @@ -366,9 +384,12 @@ def inner_unpack(packed): torch.set_rng_state(fwd_cpu_state) if had_cuda_in_fwd: set_device_states(fwd_gpu_devices, fwd_gpu_states) - with torch.enable_grad(), torch.cuda.amp.autocast(had_autocast_in_fwd): - with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): - _unused = function(*args) + + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**gpu_autocast_kwargs), \ + torch.cpu.amp.autocast(**cpu_autocast_kwargs), \ + torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args, **kwargs) if x not in storage: raise RuntimeError( @@ -380,7 +401,7 @@ def inner_unpack(packed): return storage.pop(x) with torch.autograd.graph.saved_tensors_hooks(pack, unpack): - output = function(*args) + output = function(*args, **kwargs) if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd: # Cuda was not initialized before running the forward, so we didn't # stash the CUDA state. diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index edcdcd0fce047..3adfb1a4439ff 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -104,7 +104,8 @@ def _find_cuda_home() -> Optional[str]: if not os.path.exists(cuda_home): cuda_home = None if cuda_home and not torch.cuda.is_available(): - print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'") + print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'", + file=sys.stderr) return cuda_home def _find_rocm_home() -> Optional[str]: @@ -127,7 +128,8 @@ def _find_rocm_home() -> Optional[str]: if not os.path.exists(rocm_home): rocm_home = None if rocm_home and torch.version.hip is None: - print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'") + print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'", + file=sys.stderr) return rocm_home @@ -355,6 +357,56 @@ def get_compiler_abi_compatibility_and_version(compiler) -> Tuple[bool, TorchVer return (False, TorchVersion('.'.join(version))) +def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None: + if not CUDA_HOME: + raise RuntimeError(CUDA_NOT_FOUND_MESSAGE) + + nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc') + cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS) + cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) + if cuda_version is None: + return + + cuda_str_version = cuda_version.group(1) + cuda_ver = packaging.version.parse(cuda_str_version) + torch_cuda_version = packaging.version.parse(torch.version.cuda) + if cuda_ver != torch_cuda_version: + # major/minor attributes are only available in setuptools>=49.6.0 + if getattr(cuda_ver, "major", float("nan")) != getattr(torch_cuda_version, "major", float("nan")): + raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda)) + warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda)) + + if not (sys.platform.startswith('linux') and + os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and + _is_binary_build()): + return + + cuda_compiler_bounds = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS + + if cuda_str_version not in cuda_compiler_bounds: + warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}') + else: + min_compiler_version, max_compiler_version = cuda_compiler_bounds[cuda_str_version] + min_compiler_version_str = '.'.join(map(str, min_compiler_version)) + max_compiler_version_str = '.'.join(map(str, max_compiler_version)) + + version_bound_str = f'>={min_compiler_version_str}' + version_bound_str = f'{version_bound_str}, <={max_compiler_version_str}' + + if compiler_version < TorchVersion(min_compiler_version_str): + raise RuntimeError( + f'The current installed version of {compiler_name} ({compiler_version}) is less ' + f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). ' + f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' + ) + if compiler_version > TorchVersion(max_compiler_version_str): + raise RuntimeError( + f'The current installed version of {compiler_name} ({compiler_version}) is greater ' + f'than the maximum required version by CUDA {cuda_str_version} ({max_compiler_version_str}). ' + f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' + ) + + # See below for why we inherit BuildExtension from object. # https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when @@ -431,7 +483,7 @@ def build_extensions(self) -> None: extension = next(extension_iter, None) if cuda_ext and not IS_HIP_EXTENSION: - self._check_cuda_version(compiler_name, compiler_version) + _check_cuda_version(compiler_name, compiler_version) for extension in self.extensions: # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when @@ -810,50 +862,6 @@ def _check_abi(self) -> Tuple[str, TorchVersion]: raise UserWarning(msg) return compiler, version - def _check_cuda_version(self, compiler_name: str, compiler_version: TorchVersion): - if CUDA_HOME: - nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc') - cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS) - cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) - if cuda_version is not None: - cuda_str_version = cuda_version.group(1) - cuda_ver = packaging.version.parse(cuda_str_version) - torch_cuda_version = packaging.version.parse(torch.version.cuda) - if cuda_ver != torch_cuda_version: - # major/minor attributes are only available in setuptools>=49.6.0 - if getattr(cuda_ver, "major", float("nan")) != getattr(torch_cuda_version, "major", float("nan")): - raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda)) - warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda)) - if (sys.platform.startswith('linux') and - os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and - _is_binary_build()): - cuda_compiler_bounds = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS - - if cuda_str_version not in cuda_compiler_bounds: - warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}') - else: - min_compiler_version, max_compiler_version = cuda_compiler_bounds[cuda_str_version] - min_compiler_version_str = '.'.join(map(str, min_compiler_version)) - max_compiler_version_str = '.'.join(map(str, max_compiler_version)) - - version_bound_str = f'>={min_compiler_version_str}' - version_bound_str = f'{version_bound_str}, <={max_compiler_version_str}' - - if compiler_version < TorchVersion(min_compiler_version_str): - raise RuntimeError( - f'The current installed version of {compiler_name} ({compiler_version}) is less ' - f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). ' - f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' - ) - elif compiler_version > TorchVersion(max_compiler_version_str): - raise RuntimeError( - f'The current installed version of {compiler_name} ({compiler_version}) is greater ' - f'than the maximum required version by CUDA {cuda_str_version} ({max_compiler_version_str}). ' - f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' - ) - else: - raise RuntimeError(CUDA_NOT_FOUND_MESSAGE) - def _add_compile_flag(self, extension, flag): extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args) if isinstance(extension.extra_compile_args, dict): @@ -1449,7 +1457,8 @@ def _jit_compile(name, if version > 0: if version != old_version and verbose: print(f'The input conditions for extension module {name} have changed. ' + - f'Bumping to version {version} and re-building as {name}_v{version}...') + f'Bumping to version {version} and re-building as {name}_v{version}...', + file=sys.stderr) name = f'{name}_v{version}' if version != old_version: @@ -1494,10 +1503,11 @@ def _jit_compile(name, baton.wait() elif verbose: print('No modifications detected for re-loaded extension ' - f'module {name}, skipping build step...') + f'module {name}, skipping build step...', + file=sys.stderr) if verbose: - print(f'Loading extension module {name}...') + print(f'Loading extension module {name}...', file=sys.stderr) if is_standalone: return _get_exec_path(name, build_directory) @@ -1526,7 +1536,7 @@ def _write_ninja_file_and_compile_objects( with_cuda = any(map(_is_cuda_file, sources)) build_file_path = os.path.join(build_directory, 'build.ninja') if verbose: - print(f'Emitting ninja build file {build_file_path}...') + print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr) _write_ninja_file( path=build_file_path, cflags=cflags, @@ -1540,7 +1550,7 @@ def _write_ninja_file_and_compile_objects( library_target=None, with_cuda=with_cuda) if verbose: - print('Compiling objects...') + print('Compiling objects...', file=sys.stderr) _run_ninja_build( build_directory, verbose, @@ -1575,7 +1585,7 @@ def _write_ninja_file_and_build_library( is_standalone) build_file_path = os.path.join(build_directory, 'build.ninja') if verbose: - print(f'Emitting ninja build file {build_file_path}...') + print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr) # NOTE: Emitting a new ninja build file does not cause re-compilation if # the sources did not change, so it's ok to re-emit (and it's fast). _write_ninja_file_to_build_library( @@ -1590,7 +1600,7 @@ def _write_ninja_file_and_build_library( is_standalone=is_standalone) if verbose: - print(f'Building extension module {name}...') + print(f'Building extension module {name}...', file=sys.stderr) _run_ninja_build( build_directory, verbose, @@ -1669,7 +1679,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): if with_cuda: if verbose: - print('Detected CUDA files, patching ldflags') + print('Detected CUDA files, patching ldflags', file=sys.stderr) if IS_WINDOWS: extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}') extra_ldflags.append('cudart.lib') @@ -1782,7 +1792,11 @@ def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: # Allow env var to override, just like during initial cmake build. _archs = os.environ.get('PYTORCH_ROCM_ARCH', None) if not _archs: - archs = torch.cuda.get_arch_list() + archFlags = torch._C._cuda_getArchFlags() + if archFlags: + archs = archFlags.split() + else: + archs = [] else: archs = _archs.replace(' ', ';').split(';') flags = ['--amdgpu-target=%s' % arch for arch in archs] @@ -1802,12 +1816,12 @@ def _get_build_directory(name: str, verbose: bool) -> str: root_extensions_directory, build_folder) if verbose: - print(f'Using {root_extensions_directory} as PyTorch extensions root...') + print(f'Using {root_extensions_directory} as PyTorch extensions root...', file=sys.stderr) build_directory = os.path.join(root_extensions_directory, name) if not os.path.exists(build_directory): if verbose: - print(f'Creating extension directory {build_directory}...') + print(f'Creating extension directory {build_directory}...', file=sys.stderr) # This is like mkdir -p, i.e. will also create parent directories. os.makedirs(build_directory, exist_ok=True) @@ -1818,11 +1832,13 @@ def _get_num_workers(verbose: bool) -> Optional[int]: max_jobs = os.environ.get('MAX_JOBS') if max_jobs is not None and max_jobs.isdigit(): if verbose: - print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...') + print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...', + file=sys.stderr) return int(max_jobs) if verbose: print('Allowing ninja to set a default number of workers... ' - '(overridable by setting the environment variable MAX_JOBS=N)') + '(overridable by setting the environment variable MAX_JOBS=N)', + file=sys.stderr) return None diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py index 96e4629cba00b..1085e96a122a3 100644 --- a/torch/utils/data/_utils/__init__.py +++ b/torch/utils/data/_utils/__init__.py @@ -34,6 +34,20 @@ https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 """ +DATAPIPE_SHARED_SEED = "_dl_shared_seed" +r"""The key to share the same seed for shuffle DataPipe across distributed processes""" + +DATAPIPE_SHARED_SEED_COUNTER = "_dl_shared_seed_recv_cnt" +r"""The key to count the number of distributed processes that have received the shared seed""" + +DATAPIPE_SHARED_SEED_DEFAULT_TIMEOUT = 30 * 60 +r"""Timeout (in seconds) sending the shared seed from Rank 0 and sending + the signal of the shared seed received from other Ranks. + It uses the same default timeout for the distributed process group""" + +DATAPIPE_SHARED_SEED_CHECK_INTERVAL = 0.01 +r"""Interval to check if each rank has received the shared seed""" + try: import numpy diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 7acfa5d8700af..9c67d50789f1b 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,12 +5,16 @@ in `./_utils/worker.py`. """ +import functools +import itertools +import logging import os +import queue import threading -import itertools +import time import warnings -import queue -import functools + +from datetime import timedelta from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union import multiprocessing as python_multiprocessing @@ -63,6 +67,8 @@ get_worker_info = _utils.worker.get_worker_info +logger = logging.getLogger(__name__) + class _DatasetKind(object): Map = 0 @@ -140,7 +146,7 @@ class DataLoader(Generic[T_co]): num_workers (int, optional): how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process. (default: ``0``) - collate_fn (callable, optional): merges a list of samples to form a + collate_fn (Callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. pin_memory (bool, optional): If ``True``, the data loader will copy Tensors @@ -153,7 +159,7 @@ class DataLoader(Generic[T_co]): will be smaller. (default: ``False``) timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: ``0``) - worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker_init_fn (Callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``) generator (torch.Generator, optional): If not ``None``, this RNG will be used @@ -567,21 +573,44 @@ def _get_shared_seed(self): ws = dist.get_world_size() store = dist.distributed_c10d._get_default_store() if rank == 0: - store.set("_dl_shared_seed", str(_shared_seed)) + _shared_seed_str = str(_shared_seed) + store.set(_utils.DATAPIPE_SHARED_SEED, _shared_seed_str) + logger.info(f"Shared seed ({_shared_seed_str}) sent to store on rank 0") + # Use 'add' instead of 'get' since for some store implementations 'add' + # doesn't work well with 'get'. + _shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 1) + start = time.time() + while _shared_seed_recv_cnt < ws: + time.sleep(_utils.DATAPIPE_SHARED_SEED_CHECK_INTERVAL) + _shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 0) + if timedelta(seconds=(time.time() - start)) > \ + timedelta(seconds=_utils.DATAPIPE_SHARED_SEED_DEFAULT_TIMEOUT): + raise RuntimeError("Timed out receiving the signal from the distribtued store on " + "Rank 0 that all other Ranks have received the shared seed. " + f"(world_size={ws}, received={_shared_seed_recv_cnt}, " + f"timeout={_utils.DATAPIPE_SHARED_SEED_DEFAULT_TIMEOUT})") # Reset after all distributed processes have received the shared seed - store.add("_dl_shared_seed_recv_cnt", 1) - _shared_seed_recv_cnt = 1 - while _shared_seed_recv_cnt != ws: - _shared_seed_recv_cnt = int(store.get("_dl_shared_seed_recv_cnt")) - store.set("_dl_shared_seed", "") - store.add("_dl_shared_seed_recv_cnt", -ws) - assert int(store.get("_dl_shared_seed_recv_cnt")) == 0 + store.set(_utils.DATAPIPE_SHARED_SEED, "") + _shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, -ws) + assert _shared_seed_recv_cnt == 0 else: _shared_seed_str = "" - store.wait(["_dl_shared_seed"], _utils.MP_STATUS_CHECK_INTERVAL) + start = time.time() while len(_shared_seed_str) == 0: - _shared_seed_str = store.get("_dl_shared_seed") - store.add("_dl_shared_seed_recv_cnt", 1) + time.sleep(_utils.DATAPIPE_SHARED_SEED_CHECK_INTERVAL) + _shared_seed_str = store.get(_utils.DATAPIPE_SHARED_SEED) + if timedelta(seconds=(time.time() - start)) > \ + timedelta(seconds=_utils.DATAPIPE_SHARED_SEED_DEFAULT_TIMEOUT): + raise RuntimeError("Timed out receiving the shared seed from the distribtued store " + f"on Rank {rank}. (world_size={ws}, " + f"timeout={_utils.DATAPIPE_SHARED_SEED_DEFAULT_TIMEOUT})") + logger.info(f"Shared seed ({_shared_seed_str}) received from store on rank {rank}") + _shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 1) + # Exit only when all ranks received seed, otherwise we are at risk that current rank + # will reach same section of the code again while rank zero still in the previous iteration + while _shared_seed_recv_cnt > 0: + time.sleep(_utils.DATAPIPE_SHARED_SEED_CHECK_INTERVAL) + _shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 0) _shared_seed = int(_shared_seed_str) return _shared_seed else: diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py index 8963c60746059..8a8d536b79857 100644 --- a/torch/utils/data/dataloader_experimental.py +++ b/torch/utils/data/dataloader_experimental.py @@ -84,7 +84,7 @@ def __new__(cls, raise Exception( 'sampler is not yet supported by DataPipes') datapipe = dataset - datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle) + datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle) # type: ignore[assignment] if batch_outside_worker and pin_memory: raise Exception( 'pin_memory is not yet compatible with batch_outside_worker') diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index 0d4213d051c13..df27f85c03738 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -1,8 +1,22 @@ import inspect import functools +from enum import Enum + import torch.autograd +class _SnapshotState(Enum): + r""" + These are the snapshotting-related states that IterDataPipes can be in. + `NotStarted` - allows you to restore a snapshot and create an iterator without reset + `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe + `Iterating` - can restore, will reset if you create a new iterator + """ + NotStarted = 0 + Restored = 1 + Iterating = 2 + + def _simplify_obj_name(obj) -> str: """ Simplify the display strings of objects for the purpose of rendering within DataPipe error messages. @@ -94,26 +108,40 @@ def profiler_record_fn_context(): return torch.autograd.profiler.record_function(profile_name) class IteratorDecorator: - """Wrap the iterator and modifying its `__next__` method""" - def __init__(self, iterator, source_dp, iterator_id): + r""" + Wrap the iterator and modifying its `__next__` method. This decorator is applied to + DataPipes of which `__iter__` method is NOT a generator function. Those `__iter__` + method commonly returns `self` but not necessarily. + """ + def __init__(self, iterator, source_dp, iterator_id, has_next_method): self.iterator = iterator self.source_dp = source_dp self.iterator_id = iterator_id self._profiler_enabled = torch.autograd._profiler_enabled() + # Check if `__iter__` returns `self` and `DataPipe` has `__next__` + self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method def __iter__(self): return self + def _get_next(self): + r""" + Return next with logic related to iterator validity, profiler, and incrementation of samples yielded. + """ + _check_iterator_valid(self.source_dp, self.iterator_id) + result = next(self.iterator) + if not self.self_and_has_next_method: + self.source_dp._number_of_samples_yielded += 1 + return result + def __next__(self): # TODO: Add try-except to in-place reduce traceback from the Exception # See: https://github.com/pytorch/data/issues/284 if self._profiler_enabled: with profiler_record_fn_context(): - _check_iterator_valid(self.source_dp, self.iterator_id) - return next(self.iterator) + return self._get_next() else: # Decided against using `contextlib.nullcontext` for performance reasons - _check_iterator_valid(self.source_dp, self.iterator_id) - return next(self.iterator) + return self._get_next() def __getattr__(self, name): return getattr(self.iterator, name) @@ -126,6 +154,15 @@ def __getattr__(self, name): def wrap_generator(*args, **kwargs): gen = func(*args, **kwargs) datapipe = args[0] + if datapipe._fast_forward_iterator: + it = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + datapipe._snapshot_state = _SnapshotState.Iterating + while True: + try: + yield next(it) + except StopIteration: + return iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator _profiler_enabled = torch.autograd._profiler_enabled() try: @@ -136,6 +173,7 @@ def wrap_generator(*args, **kwargs): response = gen.send(None) while True: + datapipe._number_of_samples_yielded += 1 request = yield response # Pass through here every time `__next__` is called if _profiler_enabled: @@ -146,7 +184,7 @@ def wrap_generator(*args, **kwargs): _check_iterator_valid(datapipe, iterator_id) response = gen.send(request) except StopIteration as e: - return e.value + return except Exception as e: # TODO: Simplify the traceback message to skip over `response = gen.send(None)` # Part of https://github.com/pytorch/data/issues/284 @@ -172,21 +210,31 @@ def wrap_generator(*args, **kwargs): def wrap_next(*args, **kwargs): if torch.autograd._profiler_enabled(): with profiler_record_fn_context(): - return next_func(*args, **kwargs) + result = next_func(*args, **kwargs) else: - return next_func(*args, **kwargs) + result = next_func(*args, **kwargs) + datapipe = args[0] + datapipe._number_of_samples_yielded += 1 + return result namespace['__next__'] = wrap_next - # Note that if the `__next__` and `__iter__` do something completely unrelated? It may cause issue but - # the user will be violating the iterator protocol + # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but + # the user will be violating the iterator protocol. Potential issue: + # 1. Valid iterator ID may not update or checked properly + # 2. The number of samples yielded will be miscounted # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators @functools.wraps(func) def wrap_iter(*args, **kwargs): iter_ret = func(*args, **kwargs) datapipe = args[0] + datapipe._snapshot_state = _SnapshotState.Iterating + if datapipe._fast_forward_iterator: + iter_ret = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + return iter_ret iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator - return IteratorDecorator(iter_ret, datapipe, iterator_id) + return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace) namespace['__iter__'] = wrap_iter diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 8cd4300435600..1fd316cbec5ff 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -5,7 +5,8 @@ import functools import numbers import sys -from torch.utils.data.datapipes._hook_iterator import hook_iterator + +from torch.utils.data.datapipes._hook_iterator import hook_iterator, _SnapshotState from typing import (Any, Dict, Iterator, Generic, List, Set, Tuple, TypeVar, Union, get_type_hints) from typing import _eval_type, _tp_cache, _type_check, _type_repr # type: ignore[attr-defined] @@ -350,32 +351,20 @@ def __new__(cls, name, bases, namespace, **kwargs): @functools.wraps(reset_func) def conditional_reset(*args, **kwargs): r""" - Only execute DataPipe's `reset()` method if `_restored` is False. This allows recently + Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating`. This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call. """ datapipe = args[0] - if datapipe._restored is True: - datapipe._restored = False - else: + if datapipe._snapshot_state == _SnapshotState.Iterating: + # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have + # already begun iterating. + datapipe._number_of_samples_yielded = 0 + datapipe._fast_forward_iterator = None reset_func(*args, **kwargs) + datapipe._snapshot_state = _SnapshotState.Iterating namespace['reset'] = conditional_reset - if '__setstate__' in namespace: - setstate_func = namespace['__setstate__'] - - @functools.wraps(setstate_func) - def wrap_setstate(*args, **kwargs): - r""" - Set `_restored` to True during `__setstate__`, such that the next `reset()` call during - iterator creation will not actually reset the state of the DataPipe. - """ - datapipe = args[0] - datapipe._restored = True - return setstate_func(*args, **kwargs) - - namespace['__setstate__'] = wrap_setstate - if '__iter__' in namespace: hook_iterator(namespace, 'enumerate(DataPipe)#{}'.format(name)) return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py index 8bdd91ae9d518..540adc3777ebf 100644 --- a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -1,6 +1,7 @@ _pandas = None _WITH_PANDAS = None + def _try_import_pandas() -> bool: try: import pandas # type: ignore[import] @@ -18,6 +19,7 @@ def _with_pandas() -> bool: _WITH_PANDAS = _try_import_pandas() return _WITH_PANDAS + class PandasWrapper: @classmethod def create_dataframe(cls, data, columns): @@ -41,7 +43,7 @@ def is_column(cls, data): def iterate(cls, data): if not _with_pandas(): raise Exception("DataFrames prototype requires pandas to function") - for d in data: + for d in data.itertuples(index=False): yield d @classmethod @@ -54,7 +56,7 @@ def concat(cls, buffer): def get_item(cls, data, idx): if not _with_pandas(): raise Exception("DataFrames prototype requires pandas to function") - return data[idx : idx + 1] + return data[idx: idx + 1] @classmethod def get_len(cls, df): @@ -62,10 +64,17 @@ def get_len(cls, df): raise Exception("DataFrames prototype requires pandas to function") return len(df.index) + @classmethod + def get_columns(cls, df): + if not _with_pandas(): + raise Exception("DataFrames prototype requires pandas to function") + return list(df.columns.values.tolist()) + # When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class) default_wrapper = PandasWrapper + def get_df_wrapper(): return default_wrapper @@ -85,6 +94,11 @@ def is_dataframe(data): return wrapper.is_dataframe(data) +def get_columns(data): + wrapper = get_df_wrapper() + return wrapper.get_columns(data) + + def is_column(data): wrapper = get_df_wrapper() return wrapper.is_column(data) diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index 5d9028dc4217c..fcbf15328e43c 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -9,14 +9,17 @@ __all__ = [ "Capture", + "CaptureA", "CaptureAdd", "CaptureCall", + "CaptureControl", "CaptureDataFrame", "CaptureDataFrameWithDataPipeOps", "CaptureF", "CaptureGetAttr", "CaptureGetItem", "CaptureInitial", + "CaptureLikeMock", "CaptureMul", "CaptureSetItem", "CaptureSub", @@ -24,10 +27,19 @@ "CaptureVariableAssign", "DataFrameTracer", "DataFrameTracedOps", + "disable_capture", "get_val", ] +def disable_capture(): + CaptureControl.disabled = True + + +class CaptureControl(): + disabled = False + + class DataFrameTracedOps(DFIterDataPipe): def __init__(self, source_datapipe, output_var): self.source_datapipe = source_datapipe @@ -42,13 +54,14 @@ def __iter__(self): DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe', 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle'] +UNIMPLEMENTED_ATTR = ['__deepcopy__', '__setstate__', 'is_shardable', 'apply_sharding'] + class Capture(object): # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures - ctx: Dict[str, List[Any]] - def __init__(self): - self.ctx = {'operations': [], 'variables': []} + def __init__(self, schema_df=None): + self.ctx = {'operations': [], 'variables': [], 'schema_df': schema_df} def __str__(self): return self._ops_str() @@ -61,10 +74,27 @@ def _ops_str(self): res += str(op) return res + def __getstate__(self): + # TODO(VitalyFedyunin): Currently can't pickle (why?) + self.ctx['schema_df'] = None + for var in self.ctx['variables']: + var.calculated_value = None + state = {} + for item in self.__dict__: + state[item] = getattr(self, item) + return state + + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + def __getattr__(self, attrname): - if attrname == 'kwarg': + if attrname == 'kwarg' or attrname == 'kwargs': raise Exception('no kwargs!') - return CaptureGetAttr(self, attrname, ctx=self.ctx) + if attrname in ['__deepcopy__']: + raise AttributeError() + result = CaptureGetAttr(self, attrname, ctx=self.ctx) + return result def __getitem__(self, key): return CaptureGetItem(self, key, ctx=self.ctx) @@ -94,38 +124,127 @@ def __mul__(self, add_val): self.ctx['operations'].append(t) return var + def _is_context_empty(self): + return len(self.ctx['operations']) == 0 and len(self.ctx['variables']) == 0 + + def apply_ops_2(self, dataframe): + # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + self.ctx['variables'][0].calculated_value = dataframe + for op in self.ctx['operations']: + op.execute() + + @property + def columns(self): + self.apply_ops_2(self.ctx['schema_df']) + value = self.execute() + return value.columns + + # TODO(VitalyFedyunin): Add tests + # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture + + def __call__(self, *args, **kwargs): + # TODO: Check if args or kwargs have more than one different context + if self._is_context_empty(): + # TODO: Allow CaptureA to take context from mock + for arg in args: + if isinstance(arg, Capture) and not arg._is_context_empty(): + self.ctx = arg.ctx + break + if self._is_context_empty(): + for k, v in kwargs.items(): + if isinstance(k, Capture) and not k._is_context_empty(): + self.ctx = k.ctx + break + if isinstance(v, Capture) and not v._is_context_empty(): + self.ctx = v.ctx + break + + res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs) + var = CaptureVariable(None, ctx=self.ctx) + t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res) + self.ctx['operations'].append(t) + return var + class CaptureF(Capture): def __init__(self, ctx=None, **kwargs): if ctx is None: self.ctx = {'operations': [], 'variables': []} - self.ctx = ctx + else: + self.ctx = ctx self.kwargs = kwargs -class CaptureCall(CaptureF): +class CaptureA(CaptureF): + def __str__(self): + return '{name}'.format(name=self.kwargs['name']) + + def execute(self): + value = self.kwargs['real_attribute'] + return value + + +class CaptureLikeMock(): + def __init__(self, name): + import unittest.mock as mock + # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead. + get_target, attribute = mock._get_target(name) # type: ignore[attr-defined] + self.get_target = get_target + self.attribute = attribute + self.name = name + + def __enter__(self): + self.save = getattr(self.get_target(), self.attribute) + capt = CaptureA(name=self.name, real_attribute=self.save) + setattr(self.get_target(), self.attribute, capt) + + def __exit__(self, *exc_info): + setattr(self.get_target(), self.attribute, self.save) + + +class CaptureCall(Capture): + + def __init__(self, callable, ctx=None, **kwargs): + if ctx is None: + self.ctx = {'operations': [], 'variables': []} + else: + self.ctx = ctx + self.kwargs = kwargs + self.callable = callable + def __str__(self): - return "{variable}({args},{kwargs})".format(**self.kwargs) + return "{callable}({args},{kwargs})".format(callable=self.callable, **self.kwargs) def execute(self): - return (get_val(self.kwargs['variable']))(*self.kwargs['args'], **self.kwargs['kwargs']) + + # TODO: VitalyFedyunin execute kwargs and maybe nestted structures + executed_args = [] + for arg in self.kwargs['args']: + if isinstance(arg, Capture): + executed_args.append(arg.execute()) + else: + executed_args.append(arg) + left = get_val(self.callable) + return left(*executed_args, **self.kwargs['kwargs']) class CaptureVariableAssign(CaptureF): def __str__(self): - return "{variable} = {value}".format(**self.kwargs) + variable = self.kwargs['variable'] + value = self.kwargs['value'] + return "{variable} = {value}".format(variable=variable, value=value) def execute(self): self.kwargs['variable'].calculated_value = self.kwargs['value'].execute() class CaptureVariable(Capture): - value = None - name = None - calculated_value = None + # TODO(VitalyFedyunin): This should be atomic and thread safe names_idx = 0 def __init__(self, value, ctx): + if CaptureControl.disabled: + raise Exception('Attempting to create capture variable with capture off') self.ctx = ctx self.value = value self.name = 'var_%s' % CaptureVariable.names_idx @@ -147,9 +266,6 @@ def apply_ops(self, dataframe): class CaptureGetItem(Capture): - left: Capture - key: Any - def __init__(self, left, key, ctx): self.ctx = ctx self.left = left @@ -159,14 +275,11 @@ def __str__(self): return "%s[%s]" % (self.left, get_val(self.key)) def execute(self): - return (self.left.execute())[self.key] + left = self.left.execute() + return left[self.key] class CaptureSetItem(Capture): - left: Capture - key: Any - value: Capture - def __init__(self, left, key, value, ctx): self.ctx = ctx self.left = left @@ -177,14 +290,12 @@ def __str__(self): return "%s[%s] = %s" % (self.left, get_val(self.key), self.value) def execute(self): - (self.left.execute())[ - self.key] = self.value.execute() + left = self.left.execute() + value = self.value.execute() + left[self.key] = value class CaptureAdd(Capture): - left = None - right = None - def __init__(self, left, right, ctx): self.ctx = ctx self.left = left @@ -198,9 +309,6 @@ def execute(self): class CaptureMul(Capture): - left = None - right = None - def __init__(self, left, right, ctx): self.ctx = ctx self.left = left @@ -214,9 +322,6 @@ def execute(self): class CaptureSub(Capture): - left = None - right = None - def __init__(self, left, right, ctx): self.ctx = ctx self.left = left @@ -230,9 +335,6 @@ def execute(self): class CaptureGetAttr(Capture): - source = None - name: str - def __init__(self, src, name, ctx): self.ctx = ctx self.src = src @@ -256,9 +358,8 @@ def get_val(capture): class CaptureInitial(CaptureVariable): - - def __init__(self): - new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': []} + def __init__(self, schema_df=None): + new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': [], 'schema_df': schema_df} super().__init__(None, new_ctx) self.name = 'input_%s' % self.name @@ -302,7 +403,12 @@ def shuffle(self, *args, **kwargs): def filter(self, *args, **kwargs): return self._dataframes_filter(*args, **kwargs) + def collate(self, *args, **kwargs): + raise Exception("Can't collate unbatched DataFrames stream") + def __getattr__(self, attrname): # ? + if attrname in UNIMPLEMENTED_ATTR: + raise AttributeError('Attemping to get ', attrname) if attrname in DATAPIPES_OPS: return (self.as_datapipe()).__getattr__(attrname) return super().__getattr__(attrname) @@ -312,6 +418,16 @@ def __getattr__(self, attrname): # ? class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): source_datapipe = None - def __init__(self, source_datapipe): - super().__init__() + # TODO(VitalyFedyunin): Must implement all special functions of datapipes + + def set_shuffle_settings(self, *args, **kwargs): + pass + + def is_shardable(self): + return False + + def __init__(self, source_datapipe, schema_df=None): self.source_datapipe = source_datapipe + if schema_df is None: + schema_df = next(iter(self.source_datapipe)) + super().__init__(schema_df=schema_df) diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index d8f54cf1997e7..8b52f593a8fa2 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -22,7 +22,8 @@ def __init__(self, source_datapipe): def __iter__(self): for df in self.source_datapipe: - for record in df.to_records(index=False): + # for record in df.to_records(index=False): + for record in df_wrapper.iterate(df): yield record @@ -33,7 +34,8 @@ def __init__(self, source_datapipe): def __iter__(self): for df in self.source_datapipe: - for i in range(len(df.index)): + # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup + for i in range(len(df)): yield df[i:i + 1] diff --git a/torch/utils/data/datapipes/dataframe/structures.py b/torch/utils/data/datapipes/dataframe/structures.py index 003e7625604bc..f290b351ca664 100644 --- a/torch/utils/data/datapipes/dataframe/structures.py +++ b/torch/utils/data/datapipes/dataframe/structures.py @@ -1,4 +1,5 @@ from torch.utils.data.datapipes.datapipe import DataChunk +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper __all__ = ["DataChunkDF", ] @@ -11,11 +12,11 @@ class DataChunkDF(DataChunk): def __iter__(self): for df in self.items: - for record in df.to_records(index=False): + for record in df_wrapper.iterate(df): yield record def __len__(self): total_len = 0 for df in self.items: - total_len += len(df) + total_len += df_wrapper.get_len(df) return total_len diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index ec3a4d8a51b02..d6cff084e826e 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -3,6 +3,7 @@ from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta +from torch.utils.data.datapipes._hook_iterator import _SnapshotState from torch.utils.data.datapipes.utils.common import ( _deprecation_warning, _iter_deprecated_functional_names, @@ -110,7 +111,9 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta): str_hook: Optional[Callable] = None repr_hook: Optional[Callable] = None _valid_iterator_id: Optional[int] = None - _restored: bool = False + _number_of_samples_yielded: int = 0 + _snapshot_state: _SnapshotState = _SnapshotState.NotStarted + _fast_forward_iterator: Optional[Iterator] = None def __getattr__(self, attribute_name): if attribute_name in IterDataPipe.functions: @@ -185,7 +188,7 @@ def __str__(self): # Instead of showing , return the class name return str(self.__class__.__qualname__) - def reset(self): + def reset(self) -> None: r""" Reset the `IterDataPipe` to the initial state. By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities, they may want to override this method with implementations that diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index b776caf3bf315..46fb27a703370 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -4,8 +4,9 @@ # classes/objects here, even though we are not injecting extra code into them at the moment. from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta -from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar, Union +from torch.utils.data import Dataset, IterableDataset, default_collate T_co = TypeVar('T_co', covariant=True) T = TypeVar('T') @@ -38,6 +39,9 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta): getstate_hook: Optional[Callable] = ... str_hook: Optional[Callable] = ... repr_hook: Optional[Callable] = ... + _number_of_samples_yielded: int = ... + _snapshot_state: _SnapshotState = _SnapshotState.Iterating + _fast_forward_iterator: Optional[Iterator] = ... def __getattr__(self, attribute_name: Any): ... @classmethod def register_function(cls, function_name: Any, function: Any) -> None: ... diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 29c4bb694beea..682315c388b36 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -1,9 +1,13 @@ -from typing import Callable, Iterator, Sized, TypeVar +import functools +from collections import namedtuple + +from typing import Callable, Iterator, Sized, TypeVar, Optional, Union, Any, Dict, List from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data._utils.collate import default_collate +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper from torch.utils.data.datapipes.datapipe import IterDataPipe -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn __all__ = [ "CollatorIterDataPipe", @@ -64,7 +68,7 @@ def __init__( super().__init__() self.datapipe = datapipe - _check_lambda_fn(fn) + _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] self.input_col = input_col @@ -123,6 +127,44 @@ def __len__(self) -> int: ) +def _collate_helper(conversion, item): + # TODO(VitalyFedyunin): Verify that item is any sort of batch + if len(item.items) > 1: + # TODO(VitalyFedyunin): Compact all batch dataframes into one + raise Exception("Only supports one DataFrame per batch") + df = item[0] + columns_name = df_wrapper.get_columns(df) + tuple_names: List = [] + tuple_values: List = [] + + for name in conversion.keys(): + if name not in columns_name: + raise Exception("Conversion keys missmatch") + + for name in columns_name: + if name in conversion: + if not callable(conversion[name]): + raise Exception('Collate (DF)DataPipe requires callable as dict values') + collation_fn = conversion[name] + else: + # TODO(VitalyFedyunin): Add default collation into df_wrapper + try: + import torcharrow.pytorch as tap # type: ignore[import] + collation_fn = tap.rec.Default() + except Exception: + raise Exception("unable to import default collation function from the TorchArrrow") + + tuple_names.append(str(name)) + value = collation_fn(df[name]) + tuple_values.append(value) + + # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here + # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty + tpl_cls = namedtuple("CollateResult", tuple_names) # type: ignore[misc] + tuple = tpl_cls(*tuple_values) + return tuple + + @functional_datapipe("collate") class CollatorIterDataPipe(MapperIterDataPipe): r""" @@ -166,6 +208,22 @@ class CollatorIterDataPipe(MapperIterDataPipe): def __init__( self, datapipe: IterDataPipe, - collate_fn: Callable = default_collate, + conversion: Optional[ + Union[ + Callable[..., Any], + Dict[Union[str, Any], Union[Callable, Any]], + ] + ] = default_collate, + collate_fn: Optional[Callable] = None, ) -> None: - super().__init__(datapipe, fn=collate_fn) + # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]` + # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]` + if collate_fn is not None: + super().__init__(datapipe, fn=collate_fn) + else: + if callable(conversion): + super().__init__(datapipe, fn=conversion) + else: + # TODO(VitalyFedyunin): Validate passed dictionary + collate_fn = functools.partial(_collate_helper, conversion) + super().__init__(datapipe, fn=collate_fn) diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index c63e2999fe9c3..6b4a1323d8f16 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -153,6 +153,8 @@ def __getstate__(self): self.buffer_size, self._enabled, self._seed, + self._valid_iterator_id, + self._number_of_samples_yielded, self._rng.getstate(), ) return state @@ -163,6 +165,8 @@ def __setstate__(self, state): self.buffer_size, self._enabled, self._seed, + self._valid_iterator_id, + self._number_of_samples_yielded, rng_state, ) = state self._rng = random.Random() diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index cbc6b35be3435..191d5d02f7854 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -4,8 +4,9 @@ from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes._hook_iterator import _SnapshotState from torch.utils.data.datapipes.datapipe import IterDataPipe -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import StreamWrapper, _check_unpickable_fn __all__ = [ "ConcaterIterDataPipe", @@ -112,9 +113,9 @@ def __init__(self, datapipe: IterDataPipe, num_instances: int, buffer_size: int UserWarning ) self.child_pointers: List[int] = [0] * num_instances # Indicate the indices of the next element to get - self.slowest_ptr = 0 - self.leading_ptr = 0 - self.end_ptr: Optional[int] = None + self.slowest_ptr = 0 # The index to read by the slowest child + self.leading_ptr = 0 # The index to read by the fastest child + self.end_ptr: Optional[int] = None # The index to stop child def __len__(self): return len(self.main_datapipe) @@ -122,34 +123,39 @@ def __len__(self): def get_next_element_by_instance(self, instance_id: int): if self._datapipe_iterator is None: self._datapipe_iterator = iter(self.main_datapipe) - while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr: - if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr: + self._snapshot_state = _SnapshotState.Iterating + while self.end_ptr is None or self.child_pointers[instance_id] + 1 < self.end_ptr: + self.child_pointers[instance_id] += 1 + # Use buffer + if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr: + idx = self.child_pointers[instance_id] - self.slowest_ptr - 1 + return_val = self.buffer[idx] + else: # Retreive one element from main datapipe self.leading_ptr = self.child_pointers[instance_id] - if self.buffer_size >= 0 and self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size: - raise BufferError("ForkerIterDataPipe buffer overflow," + - f"buffer size {self.buffer_size} is insufficient.") try: - self.buffer.append(next(self._datapipe_iterator)) - self.child_pointers[instance_id] += 1 - yield self.buffer[-1] + return_val = next(self._datapipe_iterator) + self.buffer.append(return_val) except StopIteration: self.end_ptr = self.leading_ptr - else: # Child pointer is slower than or equal to the leading_ptr - buffer_index = self.child_pointers[instance_id] - self.slowest_ptr - return_val = self.buffer[buffer_index] - self.child_pointers[instance_id] += 1 - if self.child_pointers[instance_id] - 1 == self.slowest_ptr: - new_min = min(self.child_pointers) # Can optimize by avoiding the call to min() - if self.slowest_ptr < new_min: - self.slowest_ptr = new_min - self.buffer.popleft() - yield return_val - if self.end_ptr and self.child_pointers[instance_id] == self.end_ptr and\ - all(p == self.end_ptr for p in self.child_pointers): + continue + if self.child_pointers[instance_id] == self.slowest_ptr + 1: + new_min = min(self.child_pointers) # Can optimize by avoiding the call to min() + if self.slowest_ptr < new_min: + self.slowest_ptr = new_min + self.buffer.popleft() + if self.buffer_size >= 0 and self.leading_ptr > self.buffer_size + self.slowest_ptr: + raise BufferError("ForkerIterDataPipe buffer overflow," + + f"buffer size {self.buffer_size} is insufficient.") + yield return_val + + if all(p + 1 == self.end_ptr for p in self.child_pointers): self._datapipe_iterator = None def is_every_instance_exhausted(self) -> bool: - return all(self.end_ptr == ptr for ptr in self.child_pointers) + # Due to the implementation of `get_next_element_by_instance`, `self.end_ptr` will end up + # equaling to `len(main_datapipe) + 1`, hence the check for `self.end_ptr - 1 == ptr` below. + return self.end_ptr is not None and\ + all(self.end_ptr == ptr or self.end_ptr - 1 == ptr for ptr in self.child_pointers) def reset(self) -> None: self._datapipe_iterator = iter(self.main_datapipe) @@ -167,6 +173,8 @@ def __getstate__(self): self.main_datapipe, self.num_instances, self.buffer_size, + self._valid_iterator_id, + self._number_of_samples_yielded, ) return state @@ -175,6 +183,8 @@ def __setstate__(self, state): self.main_datapipe, self.num_instances, self.buffer_size, + self._valid_iterator_id, + self._number_of_samples_yielded, ) = state self._datapipe_iterator = None self.buffer = deque() @@ -300,7 +310,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, if num_instances < 1: raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") - _check_lambda_fn(classifier_fn) + _check_unpickable_fn(classifier_fn) # When num_instances == 1, demux can be replaced by filter, # but keep it as Demultiplexer for the sake of consistency @@ -345,6 +355,7 @@ def _find_next(self, instance_id: int) -> T_co: value = next(self._datapipe_iterator) classification = self.classifier_fn(value) if classification is None and self.drop_none: + StreamWrapper.close_streams(value) continue if classification is None or classification >= self.num_instances or classification < 0: raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " + @@ -360,6 +371,7 @@ def _find_next(self, instance_id: int) -> T_co: def get_next_element_by_instance(self, instance_id: int): if self._datapipe_iterator is None and not self.main_datapipe_exhausted: self._datapipe_iterator = iter(self.main_datapipe) + self._snapshot_state = _SnapshotState.Iterating # This is necessary for the DataPipe to reset properly. stop = False while not stop: if self.child_buffers[instance_id]: @@ -392,6 +404,8 @@ def __getstate__(self): self.buffer_size, self.classifier_fn, self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, ) return state @@ -402,6 +416,8 @@ def __setstate__(self, state): self.buffer_size, self.classifier_fn, self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, ) = state self._datapipe_iterator = None self.current_buffer_usage = 0 @@ -469,6 +485,8 @@ def __getstate__(self): state = ( self.datapipes, self.length, + self._valid_iterator_id, + self._number_of_samples_yielded, ) return state @@ -476,6 +494,8 @@ def __setstate__(self, state): ( self.datapipes, self.length, + self._valid_iterator_id, + self._number_of_samples_yielded, ) = state self.buffer = [] @@ -510,8 +530,21 @@ def __init__(self, *datapipes: IterDataPipe): self.length = None def __iter__(self) -> Iterator[Tuple[T_co]]: - for data in zip(*self.datapipes): - yield data + iterators = [iter(datapipe) for datapipe in self.datapipes] + try: + for data in zip(*iterators): + yield data + finally: + unused = [] + for iterator in iterators: + try: + unused += list(iterator) + except RuntimeError: # Some iterators may have been invalidated by single iterator constraints + pass + + # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled + for item in unused: + StreamWrapper.close_streams(item) def __len__(self) -> int: if self.length is not None: diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 45eebac3a0508..4314cb9a14970 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -20,9 +20,8 @@ class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]): Args: datapipe: Iterable datapipe that provides pathnames mode: An optional string that specifies the mode in which - the file is opened by ``open()``. It defaults to ``b`` which - means open for reading in binary mode. Another option is - to use ``t`` for text mode + the file is opened by ``open()``. It defaults to ``r``, other options are + ``b`` for reading in binary mode and ``t`` for text mode. encoding: An optional string that specifies the encoding of the underlying file. It defaults to ``None`` to match the default encoding of ``open``. length: Nominal length of the datapipe diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 37bfb18f6c84a..0c9f05fecf3d0 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -2,7 +2,7 @@ from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar __all__ = [ @@ -215,7 +215,7 @@ def __init__(self, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False): - _check_lambda_fn(group_key_fn) + _check_unpickable_fn(group_key_fn) self.datapipe = datapipe self.group_key_fn = group_key_fn @@ -290,6 +290,8 @@ def __getstate__(self): self.guaranteed_group_size, self.drop_remaining, self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, ) return state @@ -302,6 +304,8 @@ def __setstate__(self, state): self.guaranteed_group_size, self.drop_remaining, self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, ) = state self.curr_buffer_size = 0 self.buffer_elements = defaultdict(list) diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index 22d03e89c243c..5e4ec78c7091e 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -3,7 +3,12 @@ from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper -from torch.utils.data.datapipes.utils.common import _check_lambda_fn, _deprecation_warning +from torch.utils.data.datapipes.utils.common import ( + _check_unpickable_fn, + _deprecation_warning, + StreamWrapper, +) + __all__ = ["FilterIterDataPipe", ] @@ -48,7 +53,7 @@ def __init__( super().__init__() self.datapipe = datapipe - _check_lambda_fn(filter_fn) + _check_unpickable_fn(filter_fn) self.filter_fn = filter_fn # type: ignore[assignment] if drop_empty_batches is None: @@ -78,6 +83,8 @@ def __iter__(self) -> Iterator[T_co]: filtered = self._returnIfTrue(data) if self._isNonEmpty(filtered): yield filtered + else: + StreamWrapper.close_streams(data) def _returnIfTrue(self, data): condition = self._apply_filter_fn(data) diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py index 974c371089560..1ba17b6041892 100644 --- a/torch/utils/data/datapipes/iter/streamreader.py +++ b/torch/utils/data/datapipes/iter/streamreader.py @@ -32,5 +32,6 @@ def __iter__(self): while True: d = stream.read(self.chunk) if not d: + stream.close() break yield (furl, d) diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index 0056c290996b7..7744c30de45d1 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -1,4 +1,4 @@ -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from typing import Callable, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe @@ -48,7 +48,7 @@ def __init__( ) -> None: super().__init__() self.datapipe = datapipe - _check_lambda_fn(fn) + _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] def __len__(self) -> int: diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index b126a585853b9..42227bfaf5922 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -1,9 +1,13 @@ -import os import fnmatch +import inspect +import os import warnings from io import IOBase -from typing import Dict, Iterable, List, Tuple, Union, Optional + +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + from torch.utils.data._utils.serialization import DILL_AVAILABLE @@ -16,13 +20,48 @@ ] -def _check_lambda_fn(fn): - # Partial object has no attribute '__name__', but can be pickled +def _is_local_fn(fn): + # Functions or Methods + if hasattr(fn, "__code__"): + return fn.__code__.co_flags & inspect.CO_NESTED + # Callable Objects + else: + if hasattr(fn, "__qualname__"): + return "" in fn.__qualname__ + fn_type = type(fn) + if hasattr(fn_type, "__qualname__"): + return "" in fn_type.__qualname__ + return False + + +def _check_unpickable_fn(fn: Callable): + """ + Checks function is pickable or not. If it is a lambda or local function, a UserWarning + will be raised. If it's not a callable function, a TypeError will be raised. + """ + if not callable(fn): + raise TypeError(f"A callable function is expected, but {type(fn)} is provided.") + + # Extract function from partial object + # Nested partial function is automatically expanded as a single partial object + if isinstance(fn, partial): + fn = fn.func + + # Local function + if _is_local_fn(fn) and not DILL_AVAILABLE: + warnings.warn( + "Local function is not supported by pickle, please use " + "regular python function or functools.partial instead." + ) + return + + # Lambda function if hasattr(fn, "__name__") and fn.__name__ == "" and not DILL_AVAILABLE: warnings.warn( - "Lambda function is not supported for pickle, please use " + "Lambda function is not supported by pickle, please use " "regular python function or functools.partial instead." ) + return def match_masks(name : str, masks : Union[str, List[str]]) -> bool: @@ -174,28 +213,81 @@ def _deprecation_warning( class StreamWrapper: - ''' + """ StreamWrapper is introduced to wrap file handler generated by DataPipe operation like `FileOpener`. StreamWrapper would guarantee the wrapped file handler is closed when it's out of scope. - ''' - def __init__(self, file_obj): + """ + session_streams: Dict[Any, int] = {} + debug_unclosed_streams: bool = False + + def __init__(self, file_obj, parent_stream=None, name=None): self.file_obj = file_obj + self.child_counter = 0 + self.parent_stream = parent_stream + self.close_on_last_child = False + self.name = name + self.closed = False + if parent_stream is not None: + if not isinstance(parent_stream, StreamWrapper): + raise RuntimeError('Parent stream should be StreamWrapper, {} was given'.format(type(parent_stream))) + parent_stream.child_counter += 1 + self.parent_stream = parent_stream + if StreamWrapper.debug_unclosed_streams: + StreamWrapper.session_streams[self] = 1 + + @classmethod + def close_streams(cls, v, depth=0): + """ + Traverse structure and attempts to close all found StreamWrappers on best effort basis. + """ + if depth > 10: + return + if isinstance(v, StreamWrapper): + v.close() + else: + # Traverse only simple structures + if isinstance(v, dict): + for kk, vv in v.items(): + cls.close_streams(vv, depth=depth + 1) + elif isinstance(v, list) or isinstance(v, tuple): + for vv in v: + cls.close_streams(vv, depth=depth + 1) def __getattr__(self, name): file_obj = self.__dict__['file_obj'] return getattr(file_obj, name) + def close(self, *args, **kwargs): + if StreamWrapper.debug_unclosed_streams: + del StreamWrapper.session_streams[self] + if hasattr(self, "parent_stream") and self.parent_stream is not None: + self.parent_stream.child_counter -= 1 + if not self.parent_stream.child_counter and self.parent_stream.close_on_last_child: + self.parent_stream.close() + try: + self.file_obj.close(*args, **kwargs) + except AttributeError: + pass + self.closed = True + + def autoclose(self): + """ + Close steam if there is no children, or make it to be automatically closed as soon as + all child streams are closed. + """ + if self.child_counter == 0: + self.close() + self.close_on_last_child = True + def __dir__(self): attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys()) attrs += dir(self.file_obj) return list(set(list(attrs))) def __del__(self): - try: - self.file_obj.close() - except AttributeError: - pass + if not self.closed: + self.close() def __iter__(self): for line in self.file_obj: @@ -205,7 +297,10 @@ def __next__(self): return next(self.file_obj) def __repr__(self): - return f"StreamWrapper<{self.file_obj!r}>" + if self.name is None: + return f"StreamWrapper<{self.file_obj!r}>" + else: + return f"StreamWrapper<{self.name},{self.file_obj!r}>" def __getstate__(self): return self.file_obj diff --git a/torch/utils/data/datapipes/utils/snapshot.py b/torch/utils/data/datapipes/utils/snapshot.py new file mode 100644 index 0000000000000..95af98e6b9203 --- /dev/null +++ b/torch/utils/data/datapipes/utils/snapshot.py @@ -0,0 +1,58 @@ +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.graph_settings import apply_shuffle_seed + + +# TODO: Caveats +# 1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG +# 2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently +# lack the option to `set_seed`. +def _simple_graph_snapshot_restoration(datapipe: IterDataPipe, n_iterations: int, rng=None) -> None: + r""" + This function will restore a snapshot by fast-forwarding the given DataPipe by ``n_iterations``, + and in the process, fast-forward its parent DataPipes as well at the cost of re-doing every computation. + For instance, applying this function to the final DataPipe of a graph will restore the snapshot + (via fast-forward) every DataPipe within the graph. + + After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input + to this function to forward the DataPipe. + + A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration + attempts. + + Note: + This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding + methods (custom ones if necessary) are recommended. + + Args: + datapipe: IterDataPipe to be fast-forwarded + n_iterations: number of iterations to fast-forward + rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator + should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``. + """ + if datapipe._snapshot_state == _SnapshotState.Restored: + raise RuntimeError( + "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph " + "if your graph has not been restored.") + + # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to + # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`, + # the first reset will not actually reset. + datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`. + apply_shuffle_seed(datapipe, rng) + + remainder = n_iterations + it = iter(datapipe) # This always reset the DataPipe if it hasn't already. + while remainder > 0: + try: + next(it) + remainder -= 1 + except StopIteration: + raise RuntimeError(f"Fast-forward {datapipe} by {n_iterations} iterations " + "exceeds the number of samples available.") + datapipe._fast_forward_iterator = it + # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere. + + # This will prevent the DataPipe from resetting in the `iter()` call + # If another DataPipe is consuming it, it won't have to start over again + datapipe._snapshot_state = _SnapshotState.Restored diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 34be21edf1bd6..df2b81648cefa 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -4,11 +4,13 @@ from torch.utils.data import IterDataPipe, MapDataPipe from torch.utils.data._utils.serialization import DILL_AVAILABLE -from typing import Any, Dict, Set, Tuple, Type, Union +from typing import Dict, List, Set, Tuple, Type, Union -__all__ = ["traverse", ] +__all__ = ["traverse"] DataPipe = Union[IterDataPipe, MapDataPipe] +DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]] # type: ignore[misc] + reduce_ex_hook = None @@ -17,7 +19,7 @@ def _stub_unpickler(): # TODO(VitalyFedyunin): Make sure it works without dill module installed -def _list_connected_datapipes(scan_obj, only_datapipe, cache): +def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]: f = io.BytesIO() p = pickle.Pickler(f) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is if DILL_AVAILABLE: @@ -39,10 +41,11 @@ def getstate_hook(obj): return state def reduce_hook(obj): - if obj == scan_obj or obj in cache: + if obj == scan_obj or id(obj) in cache: raise NotImplementedError else: captured_connections.append(obj) + cache.add(id(obj)) return _stub_unpickler, () datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment] @@ -70,21 +73,24 @@ def reduce_hook(obj): return captured_connections -def traverse(datapipe, only_datapipe=False): - cache: Set[DataPipe] = set() +def traverse(datapipe: DataPipe, only_datapipe: bool = False) -> DataPipeGraph: + cache: Set[int] = set() return _traverse_helper(datapipe, only_datapipe, cache) # Add cache here to prevent infinite recursion on DataPipe -def _traverse_helper(datapipe, only_datapipe, cache): +def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph: if not isinstance(datapipe, (IterDataPipe, MapDataPipe)): raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe))) - cache.add(datapipe) - items = _list_connected_datapipes(datapipe, only_datapipe, cache) - d: Dict[DataPipe, Any] = {datapipe: {}} + dp_id = id(datapipe) + if dp_id in cache: + return {} + cache.add(dp_id) + items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy()) + d: DataPipeGraph = {dp_id: (datapipe, {})} for item in items: # Using cache.copy() here is to prevent recursion on a single path rather than global graph # Single DataPipe can present multiple times in different paths in graph - d[datapipe].update(_traverse_helper(item, only_datapipe, cache.copy())) + d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy())) return d diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 02fa32d2f0e92..43678a06625dc 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -1,7 +1,12 @@ -import torch.utils.data.graph -from torch.utils.data.datapipes.iter import Shuffler import warnings +from typing import Any, List, Optional, Set + +import torch +import torch.utils.data.datapipes as dp + +from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse + __all__ = [ "apply_sharding", "apply_shuffle_seed", @@ -10,18 +15,22 @@ ] -def get_all_graph_pipes(graph): - results = set() - for datapipe, sub_graph in graph.items(): - results.add(datapipe) - sub_items = get_all_graph_pipes(sub_graph) - for item in sub_items: - results.add(item) +def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]: + return _get_all_graph_pipes_helper(graph, set()) + +def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]: + results: List[DataPipe] = [] + for dp_id, (datapipe, sub_graph) in graph.items(): + if dp_id in id_cache: + continue + id_cache.add(dp_id) + results.append(datapipe) + results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache)) return results -def apply_sharding(datapipe, num_of_instances, instance_id): - graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True) +def apply_sharding(datapipe: DataPipe, num_of_instances: int, instance_id: int) -> DataPipe: + graph = traverse(datapipe, only_datapipe=True) all_pipes = get_all_graph_pipes(graph) already_applied_to = None for pipe in all_pipes: @@ -33,22 +42,23 @@ def apply_sharding(datapipe, num_of_instances, instance_id): 'Already applied to', already_applied_to, 'while trying to apply to', pipe) pipe.apply_sharding(num_of_instances, instance_id) already_applied_to = pipe + return datapipe -def apply_shuffle_settings(datapipe, shuffle): +def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool]) -> DataPipe: if shuffle is None: return datapipe - graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True) + graph = traverse(datapipe, only_datapipe=True) all_pipes = get_all_graph_pipes(graph) - shufflers = {pipe for pipe in all_pipes if isinstance(pipe, Shuffler)} + shufflers = [pipe for pipe in all_pipes if isinstance(pipe, (dp.iter.Shuffler, dp.map.Shuffler))] if not shufflers and shuffle: warnings.warn( "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. " "Be aware that the default buffer size might not be sufficient for your task." ) datapipe = datapipe.shuffle() - shufflers = {datapipe} + shufflers = [datapipe, ] # type: ignore[list-item] for shuffler in shufflers: shuffler.set_shuffle(shuffle) @@ -56,10 +66,10 @@ def apply_shuffle_settings(datapipe, shuffle): return datapipe -def apply_shuffle_seed(datapipe, rng): - graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True) +def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: + graph = traverse(datapipe, only_datapipe=True) all_pipes = get_all_graph_pipes(graph) - shufflers = {pipe for pipe in all_pipes if isinstance(pipe, Shuffler)} + shufflers = {pipe for pipe in all_pipes if isinstance(pipe, (dp.iter.Shuffler, dp.map.Shuffler))} for shuffler in shufflers: shuffle_seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item()) diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index bd3c0a739f15b..b23ed4b688e0d 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -1,4 +1,7 @@ import collections +import os +import re +import subprocess from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, @@ -24,6 +27,37 @@ supported in ROCm/HIP yet. """ +# We need to know the ROCm version so we can conditionalize some of the mappings later. +# As of ROCm 5.0, the version is found in rocm_version.h header file under /opt/rocm/include. +rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm" +try: + rocm_path = subprocess.check_output(["hipconfig", "--rocmpath"]).decode("utf-8") +except subprocess.CalledProcessError: + print(f"Warning: hipconfig --rocmpath failed, assuming {rocm_path}") +except FileNotFoundError: + # Do not print warning. This is okay. This file can also be imported for non-ROCm builds. + pass + +rocm_version = (0, 0, 0) +rocm_version_h = f"{rocm_path}/include/rocm_version.h" +# The file could be missing due to 1) ROCm version < 5.0, or 2) no ROCm install. +if os.path.isfile(rocm_version_h): + RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)") + RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)") + RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)") + major, minor, patch = 0, 0, 0 + for line in open(rocm_version_h, "r"): + match = RE_MAJOR.search(line) + if match: + major = int(match.group(1)) + match = RE_MINOR.search(line) + if match: + minor = int(match.group(1)) + match = RE_PATCH.search(line) + if match: + patch = int(match.group(1)) + rocm_version = (major, minor, patch) + # List of math functions that should be replaced inside device code only. MATH_TRANSPILATIONS = collections.OrderedDict( [ @@ -563,8 +597,8 @@ ("hip/hip_texture_types.h", CONV_INCLUDE, API_RUNTIME), ), ("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)), - ("cublas.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), - ("cublas_v2.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), + ("cublas.h", ("rocblas.h" if rocm_version < (5, 2, 0) else "rocblas/rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), + ("cublas_v2.h", ("rocblas.h" if rocm_version < (5, 2, 0) else "rocblas/rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), ("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)), ("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), @@ -8188,6 +8222,7 @@ C10_MAPPINGS = collections.OrderedDict( [ ("cuda::compat::", ("hip::compat::", API_C10)), + ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)), ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)), ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 54f1ece427194..f7d5e714c5c37 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -46,6 +46,12 @@ to their actual types.""" PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} +__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter', + 'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group', + 'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared', + 'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file', + 'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', + 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify'] class InputError(Exception): # Exception raised for errors in the input. diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 60d711e9bba17..8a230800099bc 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -84,6 +84,8 @@ DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024 +__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton', + 'burn_in_info', 'get_info_and_burn_skeleton'] def get_storage_info(storage): assert isinstance(storage, torch.utils.show_pickle.FakeObject) diff --git a/torch/utils/tensorboard/_caffe2_graph.py b/torch/utils/tensorboard/_caffe2_graph.py index 155b0ad846b01..1d371dbc21ca9 100644 --- a/torch/utils/tensorboard/_caffe2_graph.py +++ b/torch/utils/tensorboard/_caffe2_graph.py @@ -21,12 +21,12 @@ def _make_unique_name(seen: Set[str], name: str, min_version: int = 0): Args: seen (set): Set of names that have already been used (with respect to some context). - name (string): The name to make unique + name (str): The name to make unique min_version (number): Starting index. Is incremented continually until it can make the resulting name unique relative to 'seen'. Returns: - x (string): A version of name that is not in seen. + x (str): A version of name that is not in seen. """ assert name is not None i = min_version diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index 99d051b72bda9..00aca545a3337 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -1,4 +1,5 @@ from collections import OrderedDict +import contextlib from typing import Dict, Any from tensorboard.compat.proto.config_pb2 import RunMetadata @@ -112,7 +113,7 @@ def __init__(self, node_cpp): # Replace single quote which causes strange behavior in TensorBoard # TODO: See if we can remove this in the future self.attributes = str( - {k: node_cpp[k] for k in node_cpp.attributeNames()} + {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()} ).replace("'", " ") self.kind = node_cpp.kind() @@ -331,9 +332,7 @@ def graph(model, args, verbose=False, use_strict_trace=True): `torch.jit.trace`. Pass False when you want the tracer to record your mutable container types (list, dict) """ - with torch.onnx.select_model_mode_for_export( - model, torch.onnx.TrainingMode.EVAL - ): # TODO: move outside of torch.onnx? + with _set_model_to_eval(model): try: trace = torch.jit.trace(model, args, strict=use_strict_trace) graph = trace.graph @@ -362,3 +361,27 @@ def graph(model, args, verbose=False, use_strict_trace=True): return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats # The producer version has been reverse engineered from standard # TensorBoard logged data. + + +@contextlib.contextmanager +def _set_model_to_eval(model): + """A context manager to temporarily set the training mode of ``model`` to eval.""" + if not isinstance(model, torch.jit.ScriptFunction): + originally_training = model.training + model.train(False) + try: + yield + finally: + model.train(originally_training) + else: + # Do nothing for ScriptFunction + try: + yield + finally: + pass + + +def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + sel = node.kindOf(key) + return getattr(node, sel)(key) diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index afc6b8e9b8fbf..4bb8083076bfc 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -20,6 +20,9 @@ from ._convert_np import make_np from ._utils import _prepare_video, convert_to_HWC +__all__ = ['hparams', 'scalar', 'histogram_raw', 'histogram', 'make_histogram', 'image', 'image_boxes', 'draw_boxes', + 'make_image', 'video', 'make_video', 'audio', 'custom_scalars', 'text', 'pr_curve_raw', 'pr_curve', 'compute_curve', + 'mesh'] logger = logging.getLogger(__name__) diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index a4b23d0935873..70b654384ff08 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -39,6 +39,7 @@ hparams, ) +__all__ = ['FileWriter', 'SummaryWriter'] class FileWriter(object): """Writes protocol buffers to event files to be consumed by TensorBoard. @@ -187,12 +188,12 @@ def __init__( to the event file. Args: - log_dir (string): Save directory location. Default is + log_dir (str): Save directory location. Default is runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run. Use hierarchical folder structure to compare between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc. for each new experiment to compare across them. - comment (string): Comment log_dir suffix appended to the default + comment (str): Comment log_dir suffix appended to the default ``log_dir``. If ``log_dir`` is assigned, this argument has no effect. purge_step (int): When logging crashes at step :math:`T+X` and restarts at step :math:`T`, @@ -204,7 +205,7 @@ def __init__( Default is ten items. flush_secs (int): How often, in seconds, to flush the pending events and summaries to disk. Default is every two minutes. - filename_suffix (string): Suffix added to all event filenames in + filename_suffix (str): Suffix added to all event filenames in the log_dir directory. More details on filename construction in tensorboard.summary.writer.event_file_writer.EventFileWriter. @@ -356,7 +357,7 @@ def add_scalar( """Add scalar data to summary. Args: - tag (string): Data identifier + tag (str): Data identifier scalar_value (float or string/blobname): Value to save global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) @@ -393,7 +394,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None """Adds many scalar data to summary. Args: - main_tag (string): The parent name for the tags + main_tag (str): The parent name for the tags tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) @@ -449,10 +450,10 @@ def add_histogram( """Add histogram to summary. Args: - tag (string): Data identifier - values (torch.Tensor, numpy.array, or string/blobname): Values to build histogram + tag (str): Data identifier + values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram global_step (int): Global step value to record - bins (string): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find + bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html walltime (float): Optional override default walltime (time.time()) seconds after epoch of event @@ -500,15 +501,15 @@ def add_histogram_raw( """Adds histogram with raw data. Args: - tag (string): Data identifier + tag (str): Data identifier min (float or int): Min value max (float or int): Max value num (int): Number of values sum (float or int): Sum of all values sum_squares (float or int): Sum of squares for all values - bucket_limits (torch.Tensor, numpy.array): Upper value per bucket. + bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket. The number of elements of it should be the same as `bucket_counts`. - bucket_counts (torch.Tensor, numpy.array): Number of values per bucket + bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event @@ -567,12 +568,12 @@ def add_image( Note that this requires the ``pillow`` package. Args: - tag (string): Data identifier - img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event - dataformats (string): Image data format specification of the form + dataformats (str): Image data format specification of the form CHW, HWC, HW, WH, etc. Shape: img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to @@ -622,12 +623,12 @@ def add_images( Note that this requires the ``pillow`` package. Args: - tag (string): Data identifier - img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event - dataformats (string): Image data format specification of the form + dataformats (str): Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. Shape: img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be @@ -676,22 +677,22 @@ def add_image_with_boxes( """Add image and draw bounding boxes on the image. Args: - tag (string): Data identifier - img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data - box_tensor (torch.Tensor, numpy.array, or string/blobname): Box data (for detected objects) + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects) box should be represented as [x1, y1, x2, y2]. global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event rescale (float): Optional scale override - dataformats (string): Image data format specification of the form + dataformats (str): Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. labels (list of string): The label to be shown for each bounding box. Shape: img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument. e.g. CHW or HWC - box_tensor: (torch.Tensor, numpy.array, or string/blobname): NX4, where N is the number of + box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): NX4, where N is the number of boxes and each 4 elements in a row represents (xmin, ymin, xmax, ymax). """ torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes") @@ -727,7 +728,7 @@ def add_figure(self, tag, figure, global_step=None, close=True, walltime=None): Note that this requires the ``matplotlib`` package. Args: - tag (string): Data identifier + tag (str): Data identifier figure (matplotlib.pyplot.figure) or list of figures: Figure or a list of figures global_step (int): Global step value to record close (bool): Flag to automatically close the figure @@ -758,7 +759,7 @@ def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None): Note that this requires the ``moviepy`` package. Args: - tag (string): Data identifier + tag (str): Data identifier vid_tensor (torch.Tensor): Video data global_step (int): Global step value to record fps (float or int): Frames per second @@ -778,7 +779,7 @@ def add_audio( """Add audio data to summary. Args: - tag (string): Data identifier + tag (str): Data identifier snd_tensor (torch.Tensor): Sound data global_step (int): Global step value to record sample_rate (int): sample rate in Hz @@ -800,8 +801,8 @@ def add_text(self, tag, text_string, global_step=None, walltime=None): """Add text data to summary. Args: - tag (string): Data identifier - text_string (string): String to save + tag (str): Data identifier + text_string (str): String to save global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event @@ -881,11 +882,11 @@ def add_embedding( """Add embedding projector data to summary. Args: - mat (torch.Tensor or numpy.array): A matrix which each row is the feature vector of the data point + mat (torch.Tensor or numpy.ndarray): A matrix which each row is the feature vector of the data point metadata (list): A list of labels, each element will be convert to string label_img (torch.Tensor): Images correspond to each data point global_step (int): Global step value to record - tag (string): Name for the embedding + tag (str): Name for the embedding Shape: mat: :math:`(N, D)`, where N is number of data and D is feature dimension @@ -985,10 +986,10 @@ def add_pr_curve( will let you choose the threshold interactively. Args: - tag (string): Data identifier - labels (torch.Tensor, numpy.array, or string/blobname): + tag (str): Data identifier + labels (torch.Tensor, numpy.ndarray, or string/blobname): Ground truth data. Binary label for each element. - predictions (torch.Tensor, numpy.array, or string/blobname): + predictions (torch.Tensor, numpy.ndarray, or string/blobname): The probability that an element be classified as true. Value should be in [0, 1] global_step (int): Global step value to record @@ -1032,13 +1033,13 @@ def add_pr_curve_raw( """Adds precision recall curve with raw data. Args: - tag (string): Data identifier - true_positive_counts (torch.Tensor, numpy.array, or string/blobname): true positive counts - false_positive_counts (torch.Tensor, numpy.array, or string/blobname): false positive counts - true_negative_counts (torch.Tensor, numpy.array, or string/blobname): true negative counts - false_negative_counts (torch.Tensor, numpy.array, or string/blobname): false negative counts - precision (torch.Tensor, numpy.array, or string/blobname): precision - recall (torch.Tensor, numpy.array, or string/blobname): recall + tag (str): Data identifier + true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts + false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts + true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts + false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts + precision (torch.Tensor, numpy.ndarray, or string/blobname): precision + recall (torch.Tensor, numpy.ndarray, or string/blobname): recall global_step (int): Global step value to record num_thresholds (int): Number of thresholds used to draw the curve. walltime (float): Optional override default walltime (time.time()) @@ -1140,7 +1141,7 @@ def add_mesh( advanced usage. Args: - tag (string): Data identifier + tag (str): Data identifier vertices (torch.Tensor): List of the 3D coordinates of vertices. colors (torch.Tensor): Colors for each vertex faces (torch.Tensor): Indices of vertices within each triangle. (Optional) diff --git a/torch/utils/throughput_benchmark.py b/torch/utils/throughput_benchmark.py index bed054407f2c5..4008d5a23c048 100644 --- a/torch/utils/throughput_benchmark.py +++ b/torch/utils/throughput_benchmark.py @@ -136,7 +136,7 @@ def benchmark( iterations might be slightly larger. Which is reported as stats.num_iters where stats is the result of this function - profiler_output_path (string): Location to save Autograd Profiler trace. + profiler_output_path (str): Location to save Autograd Profiler trace. If not empty, Autograd Profiler will be enabled for the main benchmark execution (but not the warmup phase). The full trace will be saved into the file path provided by this argument diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 11dd831bacd33..417bba637ce0a 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -1,15 +1,10 @@ -from dataclasses import dataclass import re -from typing import Optional, Sequence, Set, List, Tuple, Match +from dataclasses import dataclass +from typing import List, Match, Optional, Sequence, Set, Tuple from torchgen.api import cpp from torchgen.api.types import Binding, NamedCType -from torchgen.model import ( - NativeFunction, - Type, - SchemaKind, - NativeFunctionsViewGroup, -) +from torchgen.model import NativeFunction, NativeFunctionsViewGroup, SchemaKind, Type from torchgen.utils import IDENT_REGEX # Represents a saved attribute involved in backward calculation. diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 44890c8827782..408c123a0952f 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -1,3 +1,35 @@ +from typing import List, Optional, Sequence, Set, Union + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + ArrayRefCType, + BaseCType, + BaseTypeToCppMapping, + Binding, + boolT, + ConstRefCType, + CType, + dimnameListT, + intArrayRefT, + ListCType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + scalarT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + TupleCType, + VectorCType, + voidT, +) from torchgen.model import ( Argument, Arguments, @@ -12,37 +44,7 @@ TensorOptionsArguments, Type, ) -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ConstRefCType, - NamedCType, - CType, - MutRefCType, - ArrayCType, - ListCType, - VectorCType, - ArrayRefCType, - OptionalCType, - TupleCType, - SpecialArgName, - boolT, - scalarT, - tensorListT, - dimnameListT, - tensorT, - voidT, - longT, - BaseTypeToCppMapping, - intArrayRefT, - optionalIntArrayRefT, - tensorOptionsT, - symIntArrayRefT, -) -from torchgen import local from torchgen.utils import assert_never -from typing import Optional, Sequence, Union, List, Set # This file describes the translation of JIT schema to the public C++ # API, which is what people use when they call functions like at::add. @@ -64,8 +66,6 @@ def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str: name = str(func.name.name) - if func.is_functional_fn(): - name += "_functional" if func.is_symint_fn(): name += "_symint" if func.is_out_fn(): @@ -155,12 +155,15 @@ def argumenttype_type( return NamedCType(binds, VectorCType(BaseCType(longT))) else: return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "SymInt": + if remove_non_owning_ref_types: + return NamedCType(binds, VectorCType(BaseCType(SymIntT))) + else: + return NamedCType(binds, BaseCType(symIntArrayRefT)) elif str(t.elem) == "Tensor": return NamedCType(binds, BaseCType(tensorListT)) elif str(t.elem) == "Scalar": return NamedCType(binds, ArrayRefCType(BaseCType(scalarT))) - elif str(t.elem) == "SymInt": - return NamedCType(binds, BaseCType(symIntArrayRefT)) elif str(t.elem) == "Dimname": return NamedCType(binds, BaseCType(dimnameListT)) elif str(t.elem) == "Tensor?": diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index ad1f17f719403..008e8c5664a47 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -1,3 +1,9 @@ +import itertools +from typing import List, Sequence, Union + +from torchgen.api import cpp + +from torchgen.api.types import ArgName, Binding, CType, NamedCType from torchgen.model import ( Argument, FunctionSchema, @@ -6,13 +12,7 @@ TensorOptionsArguments, Type, ) - -from torchgen.api.types import ArgName, Binding, NamedCType, CType -from torchgen.api import cpp -from torchgen.utils import concatMap, assert_never - -import itertools -from typing import Sequence, List, Union +from torchgen.utils import assert_never, concatMap # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 22ce2c3c4d00a..c071fd10087bf 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -1,22 +1,23 @@ -from torchgen.model import ( - FunctionSchema, - BaseTy, - BaseType, - NativeFunctionsViewGroup, - Argument, -) +from typing import List, Optional + +from torchgen.api import dispatcher from torchgen.api.types import ( + BaseCType, Binding, - NamedCType, + boolT, ConstRefCType, - BaseCType, CType, - tensorT, longT, - boolT, + NamedCType, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + NativeFunctionsViewGroup, ) -from torchgen.api import dispatcher -from typing import List, Optional # This file describes the translation of JIT schema to API's used diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index d424ae02ecb4e..6bce9db92bdb2 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -1,35 +1,36 @@ -from typing import Any, Dict, List, Union, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple, Union -from torchgen.model import ( - Type, - BaseTy, - BaseType, - OptionalType, - ListType, - OperatorName, - FunctionSchema, - Return, - TensorOptionsArguments, - Argument, -) from torchgen.api.types import ( - CType, BaseCppType, BaseCType, - OptionalCType, - NamedCType, - deviceT, - layoutT, - VectorCType, boolT, - longT, + CType, + deviceT, doubleT, + layoutT, ListCType, - stringT, + longT, + memoryFormatT, + NamedCType, + OptionalCType, scalarT, scalarTypeT, - memoryFormatT, + stringT, SymIntT, + VectorCType, +) + +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + OptionalType, + Return, + TensorOptionsArguments, + Type, ) diff --git a/torchgen/api/native.py b/torchgen/api/native.py index 47610022e55a5..16814e34867c0 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -1,35 +1,35 @@ -from torchgen.model import ( - Argument, - FunctionSchema, - Return, - SelfArgument, - TensorOptionsArguments, - Type, -) +from typing import List, Optional, Sequence, Union + +from torchgen import local +from torchgen.api import cpp from torchgen.api.types import ( ArgName, BaseCType, Binding, + boolT, ConstRefCType, - NamedCType, CType, - MutRefCType, + deviceT, + layoutT, ListCType, + MutRefCType, + NamedCType, OptionalCType, - tensorT, scalarT, - layoutT, - deviceT, - boolT, scalarTypeT, + tensorT, +) +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, ) -from torchgen.api import cpp -from torchgen import local from torchgen.utils import assert_never -from typing import Union, Sequence, List, Optional - # This file describes the translation of JIT schema to the native functions API. # This looks a lot like the C++ API (which makes historical sense, because the # idea was you wrote native functions to implement functions in the C++ API), diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 64ce1a9700f7d..fb9e7352fb5dd 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -1,13 +1,15 @@ from dataclasses import dataclass -from typing import Optional, Union, Sequence, Set, List, Dict, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union -from torchgen.api.types import Binding, CppSignature, CppSignatureGroup from torchgen.api import cpp + +from torchgen.api.types import Binding, CppSignature, CppSignatureGroup from torchgen.gen import pythonify_default from torchgen.model import ( Argument, BaseTy, BaseType, + FunctionSchema, ListType, NativeFunction, OptionalType, @@ -446,12 +448,8 @@ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[st # dedicated data model to store these extra properties. @dataclass(frozen=True) class PythonSignatureDeprecated(PythonSignature): - # We need keep the order of arguments in deprecated signature. - # Particularly, method signature might have 'self' not at the beginning, e.g.: - # addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) - # When generating lambda function signature we need follow the exact order (even for method=True): - # [](Scalar beta, const Tensor & self, const Tensor & mat1, const Tensor & mat2) -> Tensor - deprecated_args_names: Tuple[str, ...] + # Schema for the deprecated function + deprecated_schema: FunctionSchema # The deprecated signature might miss some arguments that the corresponding # C++ signature expects. We need store the constant default values to pass in. @@ -520,6 +518,35 @@ class PythonSignatureGroup: # The out variant (e.g. conv2d_out) outplace: Optional[NativeFunction] + @classmethod + def from_pairs( + cls, + functional: PythonSignatureNativeFunctionPair, + out: Optional[PythonSignatureNativeFunctionPair], + ) -> "PythonSignatureGroup": + if out is None: + return PythonSignatureGroup( + signature=functional.signature, + base=functional.function, + outplace=None, + ) + + # prefer the signature with optional out=... arguments because it's the + # superset that can be used to parse input for both base and outplace. + signature_kwargs = out.signature.__dict__.copy() + + # Out overloads in C++ don't have TensorOptions arguments, + # so take these from the functional variant + signature_kwargs[ + "tensor_options_args" + ] = functional.signature.tensor_options_args + + return PythonSignatureGroup( + signature=type(out.signature)(**signature_kwargs), + base=functional.function, + outplace=out.function, + ) + # C++ function dispatch is wrapped in a lambda function. The lambda function # has almost the same signature as the C++ function, only with some small @@ -694,22 +721,34 @@ def argument(a: Argument) -> PythonArgument: # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen def signature( f: NativeFunction, *, method: bool = False, pyi: bool = False +) -> PythonSignature: + return signature_from_schema( + f.func, category_override=f.category_override, method=method, pyi=pyi + ) + + +def signature_from_schema( + func: FunctionSchema, + *, + category_override: Optional[str], + method: bool = False, + pyi: bool = False, ) -> PythonSignature: args: List[Argument] = [] - args.extend(f.func.arguments.pre_self_positional) + args.extend(func.arguments.pre_self_positional) # Skip SelfArgument if this is method. - if not method and f.func.arguments.self_arg is not None: - args.append(f.func.arguments.self_arg.argument) - args.extend(f.func.arguments.post_self_positional) - args.extend(f.func.arguments.pre_tensor_options_kwarg_only) + if not method and func.arguments.self_arg is not None: + args.append(func.arguments.self_arg.argument) + args.extend(func.arguments.post_self_positional) + args.extend(func.arguments.pre_tensor_options_kwarg_only) # Skip TensorOptionsArguments. Python side TensorOptions # arguments are created based on different rules - see below. - args.extend(f.func.arguments.post_tensor_options_kwarg_only) - args.extend(f.func.arguments.out) + args.extend(func.arguments.post_tensor_options_kwarg_only) + args.extend(func.arguments.out) - input_arg_set = set(a.name for a in f.func.arguments.flat_positional) - kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) - out_arg_set = set(a.name for a in f.func.arguments.out) + input_arg_set = set(a.name for a in func.arguments.flat_positional) + kwarg_only_set = set(a.name for a in func.arguments.flat_kwarg_only) + out_arg_set = set(a.name for a in func.arguments.out) input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) input_kwargs = tuple( @@ -726,43 +765,61 @@ def signature( # source of drift between eager and JIT. Pull this logic out to a shared place. has_tensor_input_arg = any( - a.type.is_tensor_like() for a in f.func.arguments.flat_non_out + a.type.is_tensor_like() for a in func.arguments.flat_non_out ) - if any(a.name == "requires_grad" for a in f.func.schema_order_arguments()): + if any(a.name == "requires_grad" for a in func.schema_order_arguments()): raise ValueError( "argument named requires_grad is reserved, should not explicitly add it in the schema" ) # [old codegen] this probably won't work if one of the returns is not a tensor, # but it will produce a compile-time error that is obvious. - has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_return = any(r.type.is_tensor_like() for r in func.returns) - name: str = cpp.name(f.func) - is_factory_function = f.category_override == "factory" or ( + name: str = cpp.name(func) + is_factory_function = category_override == "factory" or ( has_tensor_return and not has_tensor_input_arg ) is_like_or_new_function = ( - f.category_override in ("new", "like") + category_override in ("new", "like") or name.startswith("new_") or name.endswith("_like") ) tensor_options_args: List[PythonArgument] = [] if is_factory_function or is_like_or_new_function: + + def topt_default_init(name: str) -> Optional[str]: + topt_args = func.arguments.tensor_options + if topt_args is None: + return None + a = getattr(topt_args, name) + if a.default is None or a.default == "None": + return None + return cpp.default_expr(a.default, a.type) + tensor_options_args.append( PythonArgument( name="dtype", type=BaseType(BaseTy.ScalarType), - default="None" if pyi else _dtype_default_type_hack(name), - default_init="self.scalar_type()" if is_like_or_new_function else None, + default="None", + default_init=( + "self.scalar_type()" + if is_like_or_new_function + else topt_default_init("dtype") + ), ) ) tensor_options_args.append( PythonArgument( name="layout", type=OptionalType(BaseType(BaseTy.Layout)), - default="strided" if pyi else "torch.strided", - default_init="self.layout()" if is_like_or_new_function else None, + default="None", + default_init=( + "self.layout()" + if is_like_or_new_function + else topt_default_init("layout") + ), ) ) tensor_options_args.append( @@ -770,7 +827,11 @@ def signature( name="device", type=BaseType(BaseTy.Device), default="None", - default_init="self.device()" if is_like_or_new_function else None, + default_init=( + "self.device()" + if is_like_or_new_function + else topt_default_init("device") + ), ) ) tensor_options_args.append( @@ -790,10 +851,10 @@ def signature( ) ) - returns = PythonReturns(returns=f.func.returns) + returns = PythonReturns(returns=func.returns) return PythonSignature( - name=str(f.func.name.name), + name=str(func.name.name), input_args=input_args, input_kwargs=input_kwargs, output_args=PythonOutArgument.from_outputs(outputs), @@ -803,16 +864,6 @@ def signature( ) -# TODO blowtorch -# note: removing this will be BC-breaking. A quick test shows that -# randperm will otherwise default its dtype to torch.float64 -def _dtype_default_type_hack(name: str) -> str: - if name.startswith("randperm") or name == "tril_indices" or name == "triu_indices": - return "torch.int64" - else: - return "None" - - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # Python Interface @@ -993,20 +1044,19 @@ def returns_str_pyi(signature: PythonSignature) -> str: def dispatch_lambda_args( ps: PythonSignature, f: NativeFunction ) -> Tuple[DispatchLambdaArgument, ...]: - # Start with cpp arguments - dispatch lambda signature always include 'self' - cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() - - # Special reorder logic for deprecated python signature if isinstance(ps, PythonSignatureDeprecated): - m: Dict[str, Binding] = dict((a.name, a) for a in cpp_args) - # reorder according to the deprecated signature - # ignore 'out' argument when binding to non-output function. - ordered_args = filter( - lambda n: n != "out" or f.func.is_out_fn(), ps.deprecated_args_names - ) - cpp_args = list(map(lambda n: m[n], ordered_args)) + schema = ps.deprecated_schema + else: + schema = f.func - out_args: Set[str] = set(a.name for a in f.func.arguments.out) + # Start with cpp arguments - dispatch lambda signature always include 'self' + cpp_args = cpp.arguments( + arguments=schema.arguments, + faithful=False, + method=False, + cpp_no_default_args=f.cpp_no_default_args, + ) + out_args: Set[str] = set(a.name for a in schema.arguments.out) # Convert from cpp argument to lambda argument def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index 2a0ecd9182928..4787adccae6b3 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -1,3 +1,25 @@ +from typing import List, Union + +from torchgen.api import cpp + +from torchgen.api.types import ( + ArgName, + ArrayRefCType, + BaseCType, + Binding, + ConstRefCType, + dimnameListT, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalarT, + tensorT, +) from torchgen.model import ( Argument, BaseTy, @@ -9,31 +31,8 @@ TensorOptionsArguments, Type, ) - -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ArrayRefCType, - ConstRefCType, - OptionalCType, - NamedCType, - tensorT, - scalarT, - intArrayRefT, - dimnameListT, - optionalTensorRefT, - optionalScalarRefT, - optionalIntArrayRefT, - iTensorListRefT, - iOptTensorListRefT, -) - -from torchgen.api import cpp from torchgen.utils import assert_never -from typing import Union, List - # This file describes the translation of JIT schema to the structured functions API. # This is similar to native API, but a number of historical problems with native # API have been fixed. diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 372350cea58a6..bee33b473dc94 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -1,33 +1,36 @@ -from typing import Dict, Sequence, List, NoReturn, Union +from typing import Dict, List, NoReturn, Sequence, Union + from torchgen.api.types import ( - ListCType, - tensorListT, BaseCType, Binding, + boolT, ConstRefCType, + deviceT, Expr, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + layoutT, + ListCType, + longT, + memoryFormatT, MutRefCType, - OptionalCType, NamedCType, - SpecialArgName, - tensorT, - memoryFormatT, - tensorOptionsT, - scalarTypeT, - boolT, - deviceT, - layoutT, + opmath_t, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, optionalTensorRefT, - iTensorListRefT, - iOptTensorListRefT, + scalar_t, scalarT, - optionalScalarRefT, + scalarTypeT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, VectorCType, - longT, - intArrayRefT, - scalar_t, - opmath_t, - optionalIntArrayRefT, ) # This file implements a small program synthesis engine that implements @@ -63,6 +66,7 @@ out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) longVec_ctype = VectorCType(BaseCType(longT)) +longSymVec_ctype = VectorCType(BaseCType(SymIntT)) optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) @@ -324,7 +328,19 @@ def direct_solve(goal: NamedCType) -> str: # We can always do translations from value types to reference types, like vector -> IntArrayRef elif goal.type == BaseCType(intArrayRefT): - return direct_solve(NamedCType(goal.name, longVec_ctype)) + try: + return direct_solve(NamedCType(goal.name, longVec_ctype)) + except UnsatError: + # We can also go SymIntArrayRef -> IntArrayRef + symIntArrayRef_type = direct_solve( + NamedCType(goal.name, BaseCType(symIntArrayRefT)) + ) + return f"c10::asIntArrayRefSlow({symIntArrayRef_type})" + elif goal.type == BaseCType(symIntArrayRefT): + return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + elif goal.type == BaseCType(longT): + symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) + return f"{symInt_type}.expectInt()" elif goal.type == BaseCType(optionalIntArrayRefT): return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) elif goal.type == BaseCType(optionalScalarRefT): @@ -345,6 +361,10 @@ def direct_solve(goal: NamedCType) -> str: intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) argname = direct_solve(intArrayRef_ctype) return f"{argname}.vec()" + if goal.type == VectorCType(BaseCType(SymIntT)): + symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) + argname = direct_solve(symIntArrayRef_ctype) + return f"{argname}.vec()" elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): optionalIntArrayRef_ctype = NamedCType( goal.name, BaseCType(optionalIntArrayRefT) diff --git a/torchgen/api/types.py b/torchgen/api/types.py index 9717133a4cdbb..6bee40a421d40 100644 --- a/torchgen/api/types.py +++ b/torchgen/api/types.py @@ -1,18 +1,19 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union + from torchgen.model import ( Argument, + BackendIndex, + BaseTy, FunctionSchema, NativeFunction, - BackendIndex, NativeFunctionsGroup, NativeFunctionsViewGroup, + ScalarType, SelfArgument, TensorOptionsArguments, - BaseTy, - ScalarType, ) -from dataclasses import dataclass -from typing import Optional, Union, Sequence, TypeVar, List, Set, Dict -from enum import Enum _T = TypeVar("_T") @@ -752,8 +753,8 @@ def kernel_signature( from torchgen.api import ( cpp, dispatcher, - native, - translate, functionalization, + native, structured, + translate, ) diff --git a/torchgen/api/ufunc.py b/torchgen/api/ufunc.py index 5836e276240ee..34384ce340d53 100644 --- a/torchgen/api/ufunc.py +++ b/torchgen/api/ufunc.py @@ -1,29 +1,28 @@ -from torchgen.model import ( - Argument, - BaseTy, - BaseType, - FunctionSchema, - NativeFunctionsGroup, - Type, - DispatchKey, -) +from dataclasses import dataclass +from typing import List, Optional import torchgen.api.types as api_types + +from torchgen.api import cpp, structured from torchgen.api.types import ( ArgName, + BaseCppType, BaseCType, Binding, ConstRefCType, + CType, NamedCType, scalarT, - CType, - BaseCppType, ) - -from torchgen.api import cpp, structured - -from dataclasses import dataclass -from typing import List, Optional +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + FunctionSchema, + NativeFunctionsGroup, + Type, +) def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index 06595353de291..b5afdc099fa9d 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -1,15 +1,15 @@ from typing import List, Tuple from torchgen.api import cpp -from torchgen.api.types import Binding, CType, CppSignatureGroup +from torchgen.api.types import Binding, CppSignatureGroup, CType from torchgen.model import ( Argument, - NativeFunction, - Type, + BaseTy, BaseType, - OptionalType, ListType, - BaseTy, + NativeFunction, + OptionalType, + Type, ) # This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the diff --git a/torchgen/code_template.py b/torchgen/code_template.py index e8241c65586ff..9f877771afe9b 100644 --- a/torchgen/code_template.py +++ b/torchgen/code_template.py @@ -1,5 +1,5 @@ import re -from typing import Match, Optional, Sequence, Mapping +from typing import Mapping, Match, Optional, Sequence # match $identifier or ${identifier} and replace with value in env # If this identifier is at the beginning of whitespace on a line diff --git a/torchgen/context.py b/torchgen/context.py index f65e3daaa8d9b..bbb8ea4d5c4c0 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -1,16 +1,17 @@ -from torchgen.utils import S, T, context +import contextlib + +import functools +from typing import Callable, Dict, Iterator, Optional, TypeVar, Union + +import torchgen.local as local from torchgen.model import ( + BackendIndex, + DispatchKey, NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, - BackendIndex, - DispatchKey, ) -import torchgen.local as local - -import functools -from typing import TypeVar, Union, Iterator, Callable, Dict, Optional -import contextlib +from torchgen.utils import context, S, T # Helper functions for defining generators on things in the model diff --git a/torchgen/dest/__init__.py b/torchgen/dest/__init__.py index 498c437a88a34..0c684fc1915cb 100644 --- a/torchgen/dest/__init__.py +++ b/torchgen/dest/__init__.py @@ -1,19 +1,19 @@ -from .lazy_ir import GenLazyIR as GenLazyIR -from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition -from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition from .lazy_ir import ( generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, -) -from .register_dispatch_key import ( - RegisterDispatchKey as RegisterDispatchKey, - gen_registration_helpers as gen_registration_helpers, - gen_registration_headers as gen_registration_headers, + GenLazyIR as GenLazyIR, + GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition, + GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition, ) from .native_functions import ( compute_native_function_declaration as compute_native_function_declaration, ) +from .register_dispatch_key import ( + gen_registration_headers as gen_registration_headers, + gen_registration_helpers as gen_registration_helpers, + RegisterDispatchKey as RegisterDispatchKey, +) from .ufunc import ( - compute_ufunc_cuda as compute_ufunc_cuda, compute_ufunc_cpu as compute_ufunc_cpu, compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, + compute_ufunc_cuda as compute_ufunc_cuda, ) diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index 55fdeb171ed51..b2dc965d8f962 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -1,30 +1,36 @@ +import itertools from abc import ABC from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union -from torchgen.context import method_with_native_function -from torchgen.model import ( - BackendIndex, - NativeFunction, - NativeFunctionsGroup, - FunctionSchema, -) -from torchgen.api.types import ( - BaseCType, - OptionalCType, - VectorCType, - kernel_signature, - deviceT, -) +from typing import Any, Dict, List, Optional, Tuple, Union + import torchgen.api.dispatcher as dispatcher from torchgen.api.lazy import ( - LazyIrProperties, - LazyIrSchema, - LazyArgument, getValueT, isValueType, + LazyArgument, + LazyIrProperties, + LazyIrSchema, tensorListValueT, ) +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + deviceT, + DispatcherSignature, + kernel_signature, + OptionalCType, + VectorCType, +) +from torchgen.context import method_with_native_function from torchgen.dest.lazy_ts_lowering import ts_lowering_body +from torchgen.model import ( + Argument, + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, +) def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: @@ -42,10 +48,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: return f"lazy_{arg.name}_tensorlist" elif arg.is_symint_or_list: cpp_type = arg.lazy_type.cpp_type() - return ( - f"{cpp_type}(std::dynamic_pointer_cast" - f"({arg.name}.toSymbolicIntNode())->node_, 0)" - ) + return f"{cpp_type}(dynamic_cast({arg.name}.toSymIntNodeImpl().get())->node_, 0)" return f"lazy_{arg.name}->GetIrValue()" elif isinstance(arg.lazy_type, OptionalCType): if arg.is_wrapped_scalar: @@ -121,6 +124,25 @@ def aten_symbol(schema: LazyIrSchema) -> str: return schema.aten_name +# converts all tensor-like arguments to meta tensors. Returns: +# (1) a string containing all of the logic that does the conversions. +# (2) a context, to be used by translate(), with all of the relevant bindings. +def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: + context: List[Binding] = [] + unwrapped_tensor_args: List[str] = [] + for arg in sig.arguments(): + if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): + unwrapped_name = f"{arg.name}_meta" + unwrapped_tensor_args.append( + f"auto {unwrapped_name} = to_meta({arg.name});" + ) + context.append(arg.with_name(unwrapped_name)) + else: + context.append(arg) + unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args) + return unwrap_tensor_args_str, context + + @dataclass(frozen=True) class GenLazyIR(ABC): backend_index: BackendIndex @@ -206,7 +228,13 @@ def gen(self, schema: LazyIrSchema) -> List[str]: node_ctor_args = ", ".join(ctor_args) scalar_initializers = ",\n ".join( - f"{a.name}({a.name})" for a in scalar_args + [ + # This code is just special casing the mapping from string_view -> strings + f"{a.name}({a.name}.has_value() ? c10::make_optional(std::string(*{a.name})) : c10::nullopt)" + if a.lazy_type.cpp_type() == "c10::optional" + else f"{a.name}({a.name})" + for a in scalar_args + ] ) if len(scalar_initializers): scalar_initializers = f",\n {scalar_initializers}" @@ -214,6 +242,8 @@ def gen(self, schema: LazyIrSchema) -> List[str]: [ f"std::string {a.name};" if a.lazy_type.cpp_type() == "c10::string_view" + else f"c10::optional {a.name};" + if a.lazy_type.cpp_type() == "c10::optional" else f"{a.lazy_type.cpp_type()} {a.name};" for a in scalar_args ] @@ -314,19 +344,20 @@ def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> s elif not schema.properties.CanBeReused: return "" value_comparison = [] - for arg in schema.positional_values: + for arg in itertools.chain(schema.positional_values, schema.keyword_values): if isinstance(arg.lazy_type, OptionalCType): value_comparison.append( - f"operand(i++) == {arg.name}.value_or(kNullValue)" + f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)" ) else: value_comparison.append(f"operand(i++) == {arg.name}") - for arg in schema.positional_scalars: - value_comparison.append(f"this->{arg.name} == {arg.name}") - for arg in schema.keyword_values: - value_comparison.append(f"operand(i++) == {arg.name}") - for arg in schema.keyword_scalars: - value_comparison.append(f"this->{arg.name} == {arg.name}") + for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))" + ) + else: + value_comparison.append(f"this->{arg.name} == {arg.name}") value_comparison_str = " &&\n ".join(value_comparison) return f"""{signature} {{ @@ -428,9 +459,20 @@ def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str: all_args = schema.filtered_args() returns_length = len(schema.returns) # call the meta kernel if it exists, to compute output shape/dtype for our IR - if func.structured or func.structured_delegate is not None: - meta_out = """std::vector shapes{ - torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};""" + # Note [Generated LTC Shape Functions] + # LTC uses meta tensors from core to do shape inference when possible, and otherwise + # we generate a shape function declaration that needs to be manually implemented. + # How do we detect which ops are eligible to use meta tensors? + # In general we should be able to use meta tensors not just on structured operators, + # but also on composite operators that are implemented in terms of structured kernels. + # We don't currently have a way of knowing at codegen time which ops are implemented that way. + # This is the case for all view and view_copy operators however, so we're going to + # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them). + is_view_copy_op = "view_copy" in func.tags + is_structured = func.structured or func.structured_delegate is not None + if is_structured or is_view_copy_op: + meta_out = """ +std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};""" if returns_length > 1: def this_shape(i: int) -> str: @@ -439,8 +481,28 @@ def this_shape(i: int) -> str: shapes_str = ",".join([this_shape(i) for i in range(returns_length)]) meta_out = "std::vector shapes{" + shapes_str + "};" - shape_str = f"""auto out_meta = at::meta::{schema.aten_name}({', '.join(str(a.name) for a in all_args)}); - {meta_out}""" + # Convert tensor args to the meta device and call it. + # (We can't pass in the input tensors directly, because they are "functional wrappers". + # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.) + # Even at::meta:: functions might redispatch, e.g. if they call into view ops. + dispatcher_sig = DispatcherSignature.from_schema(func.func) + meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) + meta_call_args = [ + e.expr + for e in translate( + meta_call_ctx, dispatcher_sig.arguments(), method=False + ) + ] + if is_view_copy_op: + # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel + assert func.has_composite_explicit_autograd_non_functional_kernel + dispatch_ns = "compositeexplicitautogradnonfunctional" + else: + dispatch_ns = "meta" + shape_str = f"""\ + {meta_conversion_str} + auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)}); + {meta_out}""" else: shape_sig = ComputeShapeSignature(metadata.kernel, func) shape_str = f""" @@ -571,13 +633,14 @@ def __call__(self, f: NativeFunction) -> List[str]: metadata = self.backend_index.get_kernel(f) assert metadata is not None - # Only generate shape/dtype fn for non-structured kernels, - # since we just use the meta function for structured kernels - if not f.structured and f.structured_delegate is None: + # See Note [Generated LTC Shape Functions] + is_view_copy_op = "view_copy" in f.tags + is_structured = f.structured or f.structured_delegate is not None + if is_structured or is_view_copy_op: + return [] + else: shape_sig = ComputeShapeSignature(metadata.kernel, f) return ["\n".join([f"{shape_sig.shape_decl};"])] - else: - return [] def generate_non_native_lazy_ir_nodes( diff --git a/torchgen/dest/lazy_ts_lowering.py b/torchgen/dest/lazy_ts_lowering.py index b84c625836bf4..bb1d69ee393a2 100644 --- a/torchgen/dest/lazy_ts_lowering.py +++ b/torchgen/dest/lazy_ts_lowering.py @@ -42,7 +42,7 @@ def ts_lowering_body(schema: LazyIrSchema) -> str: {emplace_arguments_str} {emplace_kwarguments} torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); - CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); return {schema.aten_name}_out; """ diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py index 67db9795f11ed..57a9217550d9c 100644 --- a/torchgen/dest/native_functions.py +++ b/torchgen/dest/native_functions.py @@ -1,11 +1,12 @@ -from typing import List, Union, Optional +from typing import List, Optional, Union -from torchgen.context import with_native_function_and_index -from torchgen.utils import mapMaybe -from torchgen.model import NativeFunction, NativeFunctionsGroup, BackendIndex -from torchgen.api.types import kernel_signature import torchgen.api.meta as meta import torchgen.api.structured as structured +from torchgen.api.types import kernel_signature + +from torchgen.context import with_native_function_and_index +from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.utils import mapMaybe @with_native_function_and_index diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 5a814ec10ba0f..f7a3ef7bb6448 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -1,42 +1,44 @@ -from typing import List, Optional, Tuple, Union import itertools -from typing_extensions import Literal -from dataclasses import dataclass import textwrap +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union -from torchgen.context import method_with_native_function, native_function_manager -from torchgen.utils import Target, mapMaybe, assert_never -from torchgen.model import ( - DispatchKey, - NativeFunction, - NativeFunctionsGroup, - SchemaKind, - TensorOptionsArguments, - DeviceCheckType, - Argument, - is_cuda_dispatch_key, - BackendIndex, - gets_generated_out_inplace_wrapper, -) +from typing_extensions import Literal + +import torchgen.api.cpp as cpp +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, ConstRefCType, CppSignature, CppSignatureGroup, + DispatcherSignature, Expr, - MutRefCType, kernel_signature, + MutRefCType, + NamedCType, NativeSignature, tensorT, - NamedCType, - DispatcherSignature, ) -import torchgen.api.meta as meta -import torchgen.api.cpp as cpp -import torchgen.api.structured as structured -from torchgen.api.translate import translate + +from torchgen.context import method_with_native_function, native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + DeviceCheckType, + DispatchKey, + gets_generated_out_inplace_wrapper, + is_cuda_dispatch_key, + NativeFunction, + NativeFunctionsGroup, + SchemaKind, + TensorOptionsArguments, +) from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import assert_never, mapMaybe, Target def gen_registration_headers( diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 4b81c4218f2ef..da42149c596b6 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -1,30 +1,31 @@ from dataclasses import dataclass -from typing import Union, Optional, List, Tuple, Dict, Sequence -from torchgen.api.translate import translate -from torchgen.model import ( - NativeFunctionsGroup, - ScalarType, - UfuncKey, - DispatchKey, - BaseType, - BaseTy, - Argument, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union + import torchgen.api.ufunc as ufunc -from torchgen.api.ufunc import UfunctorBindings +from torchgen.api.translate import translate from torchgen.api.types import ( - StructuredImplSignature, - scalar_t, - opmath_t, + BaseCType, Binding, CType, - BaseCType, Expr, NamedCType, - ScalarTypeToCppMapping, + opmath_t, + scalar_t, + StructuredImplSignature, VectorizedCType, ) +from torchgen.api.ufunc import UfunctorBindings from torchgen.context import with_native_function +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + NativeFunctionsGroup, + ScalarType, + UfuncKey, +) +from torchgen.utils import OrderedSet # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -169,7 +170,7 @@ def compute_ufunc_cuda_functors( # functors per dtype, which is awful, so we're not going to do it unless # someone really forces us to) ufunc_name = None - supported_dtypes = set() + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: if lk not in loops: continue @@ -287,12 +288,12 @@ def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: # Next, build the conditionals sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) dtype_cases = [] - for dtype, inner_ufunctor_sigs in ufunctor_sigs.items(): + for dtype, inner_ufunc_sigs in ufunctor_sigs.items(): dtype_cases.append( f""" -AT_PRIVATE_CASE_TYPE("{sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, +AT_DISPATCH_CASE(at::ScalarType::{dtype}, [&]() {{ - {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunctor_sigs, sig.arguments())} + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())} }} ) """ @@ -309,13 +310,9 @@ def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: {stub_sig.dispatch_decl()}; {stub_sig.kernel_defn()} {{ - at::ScalarType st = iter.common_dtype(); - RECORD_KERNEL_FUNCTION_DTYPE("{sig.name}", st); - switch (st) {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}", {dtype_cases_str} - default: - TORCH_CHECK(false, "{sig.name}", " not implemented for '", toString(st), "'"); - }} + ); }} REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}); @@ -522,7 +519,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: for dtype, inner_ufunc_sigs in ufunc_sigs.items(): dtype_cases.append( f""" -AT_PRIVATE_CASE_TYPE("{stub_sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]}, +AT_DISPATCH_CASE(at::ScalarType::{dtype}, [&]() {{ {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} }} @@ -535,13 +532,9 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: namespace {{ {stub_sig.kernel_defn()} {{ - at::ScalarType st = iter.common_dtype(); - RECORD_KERNEL_FUNCTION_DTYPE("{stub_sig.name}", st); - switch (st) {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}", {dtype_cases_str} - default: - TORCH_CHECK(false, "{stub_sig.name}", " not implemented for '", toString(st), "'"); - }} + ); }} }} // anonymous namespace diff --git a/torchgen/gen.py b/torchgen/gen.py index 87e47bdd7ebd9..377f13a445479 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -1,82 +1,88 @@ -import os -from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar -from typing_extensions import Literal -import yaml -from collections import OrderedDict, defaultdict, namedtuple import argparse -import pathlib +import functools import json +import os +import pathlib +from collections import defaultdict, namedtuple, OrderedDict from dataclasses import dataclass -import functools +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union + +import yaml +from typing_extensions import Literal + +import torchgen.api.dispatcher as dispatcher +import torchgen.api.meta as meta +import torchgen.api.native as native +import torchgen.api.structured as structured +import torchgen.dest as dest +from torchgen.api import cpp +from torchgen.api.translate import translate +from torchgen.api.types import ( + Binding, + CppSignatureGroup, + DispatcherSignature, + NamedCType, + NativeSignature, + SpecialArgName, +) +from torchgen.context import ( + method_with_native_function, + native_function_manager, + with_native_function, + with_native_function_and_indices, +) +from torchgen.gen_functionalization_type import ( + gen_composite_view_copy_kernel, + gen_functionalization_definition, + gen_functionalization_registration, + gen_functionalization_view_inverse_declaration, + gen_symint_view_copy_kernel, +) +from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing from torchgen.model import ( - STRUCTURED_DISPATCH_KEYS, Argument, + BackendIndex, + BackendMetadata, + BaseOperatorName, + DEFAULT_KERNEL_NAMESPACE, DispatchKey, FunctionSchema, + is_cuda_dispatch_key, + is_generic_dispatch_key, + is_ufunc_dispatch_key, Location, NativeFunction, NativeFunctionsGroup, + NativeFunctionsViewGroup, OperatorName, - BackendIndex, - BackendMetadata, OptionalType, SchemaKind, SelfArgument, + STRUCTURED_DISPATCH_KEYS, TensorOptionsArguments, Type, Variant, - is_cuda_dispatch_key, - is_generic_dispatch_key, - is_ufunc_dispatch_key, - NativeFunctionsViewGroup, ViewSchemaKind, - BaseOperatorName, ) from torchgen.native_function_generation import ( - pre_group_native_functions, add_generated_native_functions, gen_composite_functional_kernel, gen_composite_out_kernel, + pre_group_native_functions, ) -from torchgen.api.types import ( - Binding, - CppSignatureGroup, - DispatcherSignature, - NamedCType, - NativeSignature, - SpecialArgName, -) -from torchgen.api import cpp -import torchgen.api.dispatcher as dispatcher -import torchgen.api.native as native -import torchgen.api.meta as meta -import torchgen.api.structured as structured -from torchgen.api.translate import translate from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import ( - Target, + assert_never, concatMap, context, + FileManager, + make_file_manager, mapMaybe, + NamespaceHelper, + Target, YamlDumper, YamlLoader, - FileManager, - assert_never, - make_file_manager, -) -from torchgen.context import ( - method_with_native_function, - native_function_manager, - with_native_function_and_indices, - with_native_function, -) -import torchgen.dest as dest -from torchgen.gen_functionalization_type import ( - gen_functionalization_definition, - gen_functionalization_registration, - gen_functionalization_view_inverse_declaration, - gen_composite_view_copy_kernel, ) T = TypeVar("T") @@ -100,7 +106,7 @@ # - 'api' has conversions for how to translate JIT schema into # the various C++ APIs that the codegen interacts with. There # are in fact THREE different C++ APIs: the public C++ API, -# the dispatcher API, and the legacy disaptcher API. See each +# the dispatcher API, and the legacy dispatcher API. See each # of these respective files for more information # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -110,37 +116,6 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -class NamespaceHelper: - """A helper for constructing the namespace open and close strings for a nested set of namespaces. - - e.g. for namespace_str torch::lazy, - - prologue: - namespace torch { - namespace lazy { - - epilogue: - } // namespace lazy - } // namespace torch - """ - - def __init__(self, namespace_str: str): - # cpp_namespace can be a colon joined string such as torch::lazy - cpp_namespaces = namespace_str.split("::") - self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces]) - self.epilogue_ = "\n".join( - [f"}} // namespace {n}" for n in reversed(cpp_namespaces)] - ) - - @property - def prologue(self) -> str: - return self.prologue_ - - @property - def epilogue(self) -> str: - return self.epilogue_ - - # A custom loader for YAML to let us also keep track of line numbers # of each entry in the YAML file class LineLoader(YamlLoader): @@ -407,20 +382,26 @@ def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding def generate_static_dispatch_backend_call( f: NativeFunction, backend_index: BackendIndex, - ns: str = "at", ) -> str: name = DispatcherSignature.from_schema(f.func).name() exprs = translate_args_dispatcher_to_cpp(f) + backend_metadata = backend_index.get_kernel(f) + kernel_ns = ( + backend_metadata.cpp_namespace + if backend_metadata and backend_metadata.cpp_namespace + else DEFAULT_KERNEL_NAMESPACE + ) + ns = kernel_ns.replace("::native", "") return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});" def generate_static_dispatch_fallback_call( f: NativeFunction, backend_indices: List[BackendIndex], - ns: str = "at", ) -> str: name = DispatcherSignature.from_schema(f.func).name() exprs = translate_args_dispatcher_to_cpp(f) + ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") if f.has_composite_explicit_autograd_kernel: return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" elif f.has_composite_explicit_autograd_non_functional_kernel: @@ -435,7 +416,6 @@ def generate_static_dispatch_fallback_call( def static_dispatch( f: NativeFunction, backend_indices: List[BackendIndex], - namespace: str = "at", ) -> str: if len(backend_indices) == 0 or f.manual_kernel_registration: return "" @@ -450,9 +430,9 @@ def static_dispatch( ) ] if len(keys) == 1: - return generate_static_dispatch_backend_call(f, keys[0], namespace) + return generate_static_dispatch_backend_call(f, keys[0]) elif len(keys) == 0: - return generate_static_dispatch_fallback_call(f, backend_indices, namespace) + return generate_static_dispatch_fallback_call(f, backend_indices) sig = DispatcherSignature.from_schema(f.func) native_tensor_args = [ @@ -480,10 +460,10 @@ def static_dispatch( for index in keys: dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""") dispatch_code.append( - f"""\t{generate_static_dispatch_backend_call(f, index, namespace)};""" + f"""\t{generate_static_dispatch_backend_call(f, index)};""" ) - fallback = generate_static_dispatch_fallback_call(f, backend_indices, namespace) + fallback = generate_static_dispatch_fallback_call(f, backend_indices) connector = "\n\t\t" return f""" @@ -639,7 +619,7 @@ def generate_defn(faithful: bool) -> str: return f""" // aten::{f.func} -TORCH_API inline {sig.decl()} {{ +inline {sig.decl()} {{ return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); }} """ @@ -730,7 +710,7 @@ def generate_defn(faithful: bool) -> str: return f""" // aten::{f.func} -TORCH_API inline {sig.decl(is_redispatching_fn=True)} {{ +inline {sig.decl(is_redispatching_fn=True)} {{ return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str}); }} """ @@ -924,12 +904,14 @@ def __call__(self, f: NativeFunction) -> Optional[str]: # The first case could probably be improved though- it calls computeDispatchKeySet(), # which looks at TLS dispatch keys- there should not be any by the time we reach backend select. if native_tensor_args: + assert f.func.arguments.has_tensor_arg() tensor_args = ", ".join(a.name for a in native_tensor_args) compute_dk = f"""\ DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args}); DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect); DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);""" else: + assert not f.func.arguments.has_tensor_arg() compute_dk = ( f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});" ) @@ -1373,6 +1355,89 @@ def flatten_pre_group( ) +# Return native function declarations grouped by their namespaces. +def get_native_function_declarations( + *, + grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], + backend_indices: Dict[DispatchKey, BackendIndex], +) -> List[str]: + declarations: List[str] = [] + ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) + newline = "\n" + for f in grouped_native_functions: + native_function_namespaces = set() + dispatch_keys = set() + for dispatch_key, backend_idx in backend_indices.items(): + backend_metadata = backend_idx.get_kernel(f) + if backend_metadata: + namespace = backend_metadata.cpp_namespace + dispatch_keys.add(dispatch_key) + native_function_namespaces.add(namespace) + else: + namespace = DEFAULT_KERNEL_NAMESPACE + assert ( + len(native_function_namespaces) <= 1 + ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" + ns_grouped_kernels[namespace].extend( + dest.compute_native_function_declaration(f, backend_idx) + ) + + for namespace, kernels in ns_grouped_kernels.items(): + ns_helper = NamespaceHelper( + namespace_str=namespace, + entity_name="", + max_level=3, + ) + # Convert to a set first to remove duplicate kernel names. Backends are + # allowed to repeat kernel names; only generate the declaration once! + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + +# Return native function schema registration code for aten and other namespaces. +def get_native_function_schema_registrations( + *, + native_functions: Sequence[NativeFunction], + schema_selector: SelectiveBuilder, +) -> Tuple[List[str], str]: + ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) + for native_function in native_functions: + ns_native_functions[native_function.namespace].append(native_function) + schema_registrations = "" + aten_schema_registrations = [] + custom_namespace = None + for namespace, funcs in ns_native_functions.items(): + + schema_registrations_body = list( + mapMaybe(RegisterSchema(schema_selector), funcs) + ) + # NB: we have to separate aten namespace registration from other namespaces, + # because in the template we hardcoded an operator for ATen already. + if namespace == "aten": + aten_schema_registrations = schema_registrations_body + else: + assert custom_namespace is None or namespace == custom_namespace, ( + "Only one custom namespace (other than 'aten') is currently supported, " + f" but getting {namespace} and {custom_namespace}" + ) + custom_namespace = namespace + tab = "\t" + schema_registrations += f""" +TORCH_LIBRARY({custom_namespace}, m) {{ + {tab.join(schema_registrations_body)} +}};""" + return (aten_schema_registrations, schema_registrations) + + def gen_aggregated_headers( *, native_functions: Sequence[NativeFunction], @@ -1449,27 +1514,15 @@ def gen_aggregated_headers( ), }, ) + declarations = get_native_function_declarations( + grouped_native_functions=grouped_native_functions, + backend_indices=backend_indices, + ) cpu_fm.write( "NativeFunctions.h", lambda: { "NativeFunctions_includes": ["#include "], - "NativeFunctions_declarations": list( - concatMap( - # Convert to a set first to remove duplicate kernel names. - # Backends are allowed to repeat kernel names; only generate the declaration once! - lambda f: list( - OrderedDict.fromkeys( - concatMap( - lambda backend_idx: dest.compute_native_function_declaration( - f, backend_idx - ), - backend_indices.values(), - ) - ) - ), - grouped_native_functions, - ) - ), + "NativeFunctions_declarations": declarations, }, ) @@ -1597,7 +1650,9 @@ def gen_per_operator_headers( ), }, ) - + declarations = get_native_function_declarations( + grouped_native_functions=grouped_functions, backend_indices=backend_indices + ) ops_fm.write_with_template( f"{name}_native.h", "NativeFunction.h", @@ -1605,23 +1660,7 @@ def gen_per_operator_headers( "extra_includes": ( f"#include " if is_structured else [] ), - "native_function_declarations": list( - concatMap( - # Convert to a set first to remove duplicate kernel names. - # Backends are allowed to repeat kernel names; only generate the declaration once! - lambda f: list( - OrderedDict.fromkeys( - concatMap( - lambda backend_idx: dest.compute_native_function_declaration( - f, backend_idx - ), - backend_indices.values(), - ) - ) - ), - grouped_functions, - ) - ), + "native_function_declarations": declarations, }, ) @@ -1806,6 +1845,10 @@ def gen_headers( }, ) + cpu_fm.write( + "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions) + ) + def gen_aten_interned_strings() -> Dict[str, str]: attrs = set() # All function argument names names = set() # All ATen function names @@ -1848,7 +1891,7 @@ def gen_aten_interned_strings() -> Dict[str, str]: core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) def gen_tags_enum() -> Dict[str, str]: - return {"enum_of_valid_tags": (",\n".join([f"{tag}" for tag in valid_tags]))} + return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} core_fm.write("enum_tag.h", gen_tags_enum) @@ -2097,32 +2140,12 @@ def gen_backend_select() -> Dict[str, List[str]]: if force_schema_registration: schema_selector = SelectiveBuilder.get_nop_selector() - ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) - for native_function in native_functions: - ns_native_functions[native_function.namespace].append(native_function) - schema_registrations = "" - aten_schema_registrations = [] - custom_namespace = None - for namespace, funcs in ns_native_functions.items(): - - schema_registrations_body = list( - mapMaybe(RegisterSchema(schema_selector), funcs) - ) - # NB: we have to separate aten namespace registration from other namespaces, - # because in the template we hardcoded an operator for ATen already. - if namespace == "aten": - aten_schema_registrations = schema_registrations_body - else: - assert custom_namespace is None or namespace == custom_namespace, ( - "Only one custom namespace (other than 'aten') is currently supported, " - f" but getting {namespace} and {custom_namespace}" - ) - custom_namespace = namespace - tab = "\t" - schema_registrations += f""" -TORCH_LIBRARY({custom_namespace}, m) {{ - {tab.join(schema_registrations_body)} -}};""" + ( + aten_schema_registrations, + schema_registrations, + ) = get_native_function_schema_registrations( + native_functions=native_functions, schema_selector=schema_selector + ) cpu_fm.write( "RegisterSchema.cpp", lambda: { @@ -2281,6 +2304,29 @@ def gen_op_headers( ) }, ) + view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = [] + for g1 in view_groups: + for g2 in view_groups: + if g1.view_copy is None or g2.view_copy is None: + continue + # TODO: make this more first class in the data model + g1_base_name = str(g1.view_copy.func.name.name) + g2_base_name = str(g2.view_copy.func.name.name) + + same_base_op = ( + g1_base_name == g2_base_name + and g1.view_copy.func.arguments.symints_to_ints() + == g2.view_copy.func.arguments.symints_to_ints() + ) + op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name) + op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name) + if same_base_op and op1_not_symint and op2_symint: + view_copy_with_symint_pairs.append( + ( + g1.view_copy, + g2.view_copy, + ) + ) # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd @@ -2321,6 +2367,12 @@ def gen_op_headers( "CompositeViewCopyKernel_Definitions": list( mapMaybe(gen_composite_view_copy_kernel, view_groups) ), + "SymIntViewCopyKernel_Definitions": list( + mapMaybe( + lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]), + view_copy_with_symint_pairs, + ) + ), "GeneratedCompositeFunctional_Definitions": list( mapMaybe( gen_composite_functional_kernel, @@ -2480,8 +2532,6 @@ def main() -> None: if isinstance(g, NativeFunctionsViewGroup) ] - template_dir = os.path.join(options.source_path, "templates") - # NB: It is mandatory to NOT use os.path.join here, as the install directory # will eventually be ingested by cmake, which does not respect Windows style # path slashes. If you switch this to use os.path.join, you'll get an error @@ -2503,18 +2553,6 @@ def main() -> None: cuda_fm = make_file_manager(options=options) ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) - extra_cuda_headers = """\ -#include -#include -#include -#include """ - if options.rocm: - extra_cuda_headers = """\ -#include -#include -#include -#include """ - # Only a limited set of dispatch keys get CPUFunctions.h headers generated # for them; this is the set functions_keys = { @@ -2525,6 +2563,9 @@ def main() -> None: DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.Meta, } + if options.mps: + functions_keys.add(DispatchKey.MPS) + if options.backend_whitelist: dispatch_keys = [ k diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c5b7af9e069bc..ae346e6d3aca6 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -1,15 +1,18 @@ -import pathlib import argparse import os -import yaml +import pathlib import re -from collections import namedtuple, Counter, defaultdict -from typing import List, Dict, Union, Sequence, Optional -from torchgen.gen import ( - get_grouped_native_functions, - parse_native_yaml, - NamespaceHelper, -) +from collections import Counter, defaultdict, namedtuple +from typing import Dict, List, Optional, Sequence, Union + +import yaml + +import torchgen.api.dispatcher as dispatcher +import torchgen.dest as dest +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import native_function_manager +from torchgen.gen import get_grouped_native_functions, parse_native_yaml from torchgen.model import ( BackendIndex, BackendMetadata, @@ -19,12 +22,14 @@ OperatorName, ) from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import Target, concatMap, context, YamlLoader, FileManager -from torchgen.context import native_function_manager -from torchgen.code_template import CodeTemplate -import torchgen.dest as dest -import torchgen.api.dispatcher as dispatcher -from torchgen.api.types import DispatcherSignature +from torchgen.utils import ( + concatMap, + context, + FileManager, + NamespaceHelper, + Target, + YamlLoader, +) # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. @@ -62,6 +67,7 @@ def parse_backend_yaml( "autograd", "full_codegen", "non_native", + "ir_gen", ] backend = yaml_values.pop("backend", None) @@ -102,6 +108,9 @@ def parse_backend_yaml( # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py non_native = yaml_values.pop("non_native", {}) + # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py + _ = yaml_values.pop("ir_gen", {}) + assert ( len(yaml_values.keys()) == 0 ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \ @@ -265,7 +274,10 @@ def error_on_missing_kernels( native_f ) - kernel_defn_regex = rf"{class_name}::([\w\d]*)\([^\)]*\)\s*{{" + # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented. + # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel + # here, then we get a nicer error message. If we miss it, you get a linker error. + kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\(" actual_backend_kernel_name_counts = Counter( re.findall(kernel_defn_regex, backend_defns) ) diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 71c8c252f5fb6..54730be4b2215 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,47 +1,47 @@ +from typing import Callable, List, Optional, Tuple, Union + from torchgen.api import cpp, dispatcher +from torchgen.api.translate import translate from torchgen.api.types import ( - DispatcherSignature, + BaseCType, Binding, + CType, + DispatcherSignature, FunctionalizationLambda, - ViewInverseSignature, NativeSignature, - CType, - BaseCType, - VectorCType, tensorListT, tensorT, + VectorCType, + ViewInverseSignature, ) -from torchgen.api.translate import translate from torchgen.context import ( + native_function_manager, with_native_function, with_native_function_and, - native_function_manager, ) from torchgen.model import ( Argument, - Return, - NativeFunction, - NativeFunctionsGroup, BackendIndex, + BaseTy, + BaseType, FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + Return, SchemaKind, SelfArgument, TensorOptionsArguments, - BaseType, - BaseTy, - NativeFunctionsViewGroup, - ListType, ) from torchgen.native_function_generation import ( - OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, - MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY, + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, + OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) from torchgen.selective_build.selector import SelectiveBuilder -from typing import List, Optional, Union, Tuple, Callable - # Note: [Mutable Ops Not Using Functionalization] # Ops in this list currently do not work with functionalization and should be fixed. @@ -75,8 +75,31 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] if g.view_copy is None: return None + + # For view_copy.SymInt overloads, + # See gen_symint_view_copy_kernel. + if g.view_copy.func.name.overload_name == "SymInt": + return None + + # We can make view_copy work in more cases by using reshape() + # when a normal view call would ordinarily fail. + # This also makes LTC more efficient, because they don't need to include + # clone() calls in their graph (which is normally needed by reshape). + if str(g.view_copy.func.name) == "view_copy": + return """\ +at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) { + DimVector shape = infer_size_dv(size, self.numel()); + if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) { + return self.reshape(size); + } else { + auto output = at::_ops::view::call(self, size); + return output.clone(); + } +} +""" # view_copy is a native signature, since we're generating an at::native:: kernel view_copy_sig = NativeSignature(g.view_copy.func) + # view is a dispatcher signature, since we're calling into the at::_ops API view_sig = DispatcherSignature(g.view.func) @@ -113,6 +136,34 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] """ +# For symint view copy kernels, we want to generate them to call into +# their concrete view_copy counterparts. +@with_native_function_and +def gen_symint_view_copy_kernel( + view_copy: NativeFunction, view_copy_symint: NativeFunction +) -> str: + # view_copy.symint is a native signature, since we're generating an at::native:: kernel + view_copy_symint_sig = NativeSignature(view_copy_symint.func) + + # view_copy is a dispatcher signature, since we're calling into the at::_ops API + view_copy_sig = DispatcherSignature(view_copy.func) + + exprs = ", ".join( + [ + e.expr + for e in translate( + view_copy_symint_sig.arguments(), view_copy_sig.arguments() + ) + ] + ) + + return f""" +{view_copy_symint_sig.defn()} {{ + return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs}); +}} +""" + + def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: assert len(rets) == len(names) if len(rets) == 0: @@ -313,7 +364,8 @@ def emit_view_functionalization_body( ); {return_type} reference_tensor_output; {{ - at::AutoDispatchSkipFunctionalize guard; + at::AutoDispatchSkipFunctionalize func_guard; + c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch); {meta_conversion_str} reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)}); }} @@ -341,12 +393,16 @@ def emit_view_functionalization_body( return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)}); }} auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); - {return_type} tmp_output; {return_type} reference_tensor_output; {{ - at::AutoDispatchSkipFunctionalize guard; + at::AutoDispatchSkipFunctionalize func_guard; + c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch); {meta_conversion_str} reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)}); + }} + {return_type} tmp_output; + {{ + at::AutoDispatchSkipFunctionalize guard; if (reapply_views) {{ tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)}); }} else {{ @@ -553,8 +609,9 @@ def emit_inplace_functionalization_body( // Before converting the mutable op to its functional variant, run meta tensors through the original op. // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants. // (We can only do this for inplace ops today though, because they technicaly all support meta tensors). + at::AutoDispatchSkipFunctionalize func_guard; + c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch); {meta_conversion_str} - at::AutoDispatchSkipFunctionalize guard; at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)}); }} {unwrap_tensor_args_str} @@ -632,6 +689,9 @@ def emit_registration_helper(f: NativeFunction) -> str: if isinstance(g, NativeFunctionsViewGroup): # functionalization needs to register kernels for view + view_inplace ops + # See Note [Functionalization <> torch.Tensor constructor] + if str(g.view.func.name) == "lift_fresh": + return [] view_str = [emit_registration_helper(g.view)] if g.view_inplace is not None: assert g.view_inplace.is_view_op diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index b6a777ed1407e..a5a1fd9b2535a 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -1,45 +1,39 @@ -import pathlib import argparse import os +import pathlib import re -import yaml -from collections import namedtuple, Counter +from collections import Counter, namedtuple from typing import ( Any, - List, - Dict, - Tuple, - Union, - Sequence, - Optional, Callable, + Dict, Iterable, Iterator, + List, + Optional, + Sequence, + Tuple, Type, -) -from torchgen.api.types import BaseCppType -from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR -from torchgen.gen import ( - get_grouped_native_functions, - parse_native_yaml, - NamespaceHelper, + Union, ) +import yaml + +import torchgen.dest as dest + from torchgen.api.lazy import setValueT +from torchgen.api.types import BaseCppType +from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR +from torchgen.gen import get_grouped_native_functions, parse_native_yaml -from torchgen.model import ( - NativeFunction, - NativeFunctionsGroup, - OperatorName, -) +from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import concatMap, YamlLoader, FileManager -import torchgen.dest as dest +from torchgen.utils import concatMap, FileManager, NamespaceHelper, YamlLoader from .gen_backend_stubs import ( - parse_backend_yaml, error_on_missing_kernels, - gen_dispatchkey_nativefunc_headers, gen_dispatcher_registrations, + gen_dispatchkey_nativefunc_headers, + parse_backend_yaml, ) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -111,7 +105,7 @@ def parse_native_functions_keys( backend_yaml_path: str, grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], -) -> Tuple[List[OperatorName], List[Any]]: +) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]: native_functions_map: Dict[OperatorName, NativeFunction] = { f.func.name: f @@ -127,9 +121,13 @@ def parse_native_functions_keys( full_codegen = yaml_values.pop("full_codegen", []) non_native = yaml_values.pop("non_native", []) + ir_gen = yaml_values.pop("ir_gen", []) assert isinstance(full_codegen, list) assert isinstance(non_native, list) - return [OperatorName.parse(name) for name in full_codegen], non_native + assert isinstance(ir_gen, list) + full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen] + ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen] + return full_codegen_opnames, non_native, ir_gen_opnames def validate_shape_inference_header( @@ -162,6 +160,39 @@ def validate_shape_inference_header( ) +# Some helper functions for the codegen. +def get_ltc_helper_fns() -> str: + return """\ +at::Tensor to_meta(const at::Tensor& tensor) { + // undefined tensors can't be converted to the meta device, since they don't have sizes/strides + if (!tensor.defined()) return tensor; + auto out = at::native::empty_strided_meta(tensor.sizes(), tensor.strides(), \ +/*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \ +/*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt); + // needs to handle wrapped numbers, so dtype promotion works properly. + if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + out.unsafeGetTensorImpl()->set_wrapped_number(true); + } + return out; +} +c10::optional to_meta(const c10::optional& tensor) { + if (tensor.has_value()) { + return to_meta(*tensor); + } + return c10::nullopt; +} + +std::vector to_meta(const at::TensorList& t_list) { + std::vector outs; + outs.reserve(t_list.size()); + for (const auto& i : c10::irange(t_list.size())) { + outs.push_back(to_meta(t_list[i])); + } + return outs; +} +""" + + class default_args: node_base: str = "Node" node_base_hdr: Optional[str] = None @@ -320,6 +351,7 @@ def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: grouped_native_functions = sorted( grouped_native_functions, key=sort_native_function ) + parsed_backend_yaml = parse_backend_yaml( source_yaml, grouped_native_functions, backend_indices ) @@ -327,13 +359,19 @@ def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: autograd_key = parsed_backend_yaml.autograd_key cpp_namespace = parsed_backend_yaml.cpp_namespace backend_indices = parsed_backend_yaml.backend_indices - full_codegen, non_native = parse_native_functions_keys( + # the following 3 keys are all processed differently + # for full_codegen, we generate IR, kernels, etc + # for ir_gen, we generate only IR + # non_native is used to register kernels not declared in + # native_functions.yaml + full_codegen, non_native, ir_gen = parse_native_functions_keys( source_yaml, grouped_native_functions ) def concat_map_codegen( func: Callable[[NativeFunction], Sequence[str]], xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]], + ops_list: List[OperatorName] = full_codegen, ) -> Iterator[str]: """ We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we @@ -343,7 +381,7 @@ def concat_map_codegen( for x in xs: fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x] for f in fs: - if f.func.name in full_codegen: + if f.func.name in ops_list: for r in func(f): yield r @@ -436,6 +474,9 @@ def concat_map_codegen( tensor_class_hdr, shape_inference_hdr, "ATen/Functions.h", + "ATen/native/TensorConversions.h", + "ATen/NativeFunctions.h", + "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h", "ATen/MetaFunctions.h", "ATen/Operators.h", "ATen/native/CPUFallback.h", @@ -452,6 +493,7 @@ def concat_map_codegen( else [] ) ], + "helper_fns": get_ltc_helper_fns(), "native_functions_include": "", "namespace_prologue": ns_helper.prologue, "namespace_epilogue": ns_helper.epilogue, @@ -504,7 +546,9 @@ def concat_map_codegen( if node_base_hdr is not None else [], "ir_declarations": list( - concat_map_codegen(lazy_ir_obj, grouped_native_functions) + concat_map_codegen( + lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen + ) ), "namespace_prologue": ns_helper.prologue, "namespace_epilogue": ns_helper.epilogue, diff --git a/torchgen/gen_vmap_plumbing.py b/torchgen/gen_vmap_plumbing.py new file mode 100644 index 0000000000000..ac1413a1845b9 --- /dev/null +++ b/torchgen/gen_vmap_plumbing.py @@ -0,0 +1,266 @@ +import textwrap +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple + +from torchgen.api.translate import translate +from torchgen.api.types import DispatcherSignature +from torchgen.context import method_with_native_function +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + SchemaKind, + Type, +) +from torchgen.utils import mapMaybe + + +def is_tensor(typ: Type) -> bool: + return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor + + +def is_optional_tensor(typ: Type) -> bool: + return isinstance(typ, OptionalType) and is_tensor(typ.elem) + + +def is_tensor_list(typ: Type) -> bool: + return isinstance(typ, ListType) and is_tensor(typ.elem) + + +def unwrap_tensor(name: str, cur_level_var: str) -> List[str]: + result = f"""\ + Tensor {name}_value; + optional {name}_bdim; + std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}, {cur_level_var});""" + return textwrap.dedent(result).split("\n") + + +def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]: + result = f"""\ + optional {name}_value; + optional {name}_bdim; + if ({name}) {{ + std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var}); + }}""" + return textwrap.dedent(result).split("\n") + + +def gen_unwraps( + flat_arguments: Sequence[Argument], cur_level_var: str +) -> Tuple[str, List[str]]: + arg_names = [a.name for a in flat_arguments] + arg_types = [a.type for a in flat_arguments] + + tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)] + optional_tensors = [ + name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ) + ] + + unwraps = [] + for tensor in tensors: + unwraps += unwrap_tensor(tensor, cur_level_var) + + for opt_tensor in optional_tensors: + unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var) + unwrap_code = "\n".join(unwraps) + + unwrapped_arg_list = [] + for arg in arg_names: + if arg in tensors or arg in optional_tensors: + unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"] + else: + unwrapped_arg_list.append(arg) + return unwrap_code, unwrapped_arg_list + + +def gen_case_where_all_bdims_are_none( + schema: FunctionSchema, cur_level_var: str +) -> str: + conditions = [] + flat_args = schema.arguments.flat_all + for arg in flat_args: + if not arg.type.is_tensor_like(): + continue + conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})") + + sig = DispatcherSignature.from_schema(schema) + translated_args = ", ".join( + e.expr for e in translate(sig.arguments(), sig.arguments()) + ) + return f"""\ +if ({' && '.join(conditions)}) {{ + return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args}); +}}""" + + +def gen_returns( + returns: Tuple[Return, ...], cur_level_var: str, results_var: str +) -> str: + idx = 0 + wrapped_returns = [] + for ret in returns: + if is_tensor(ret.type): + wrapped_returns.append( + f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})" + ) + idx += 2 + elif is_tensor_list(ret.type): + wrapped_returns.append( + f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})" + ) + idx += 2 + else: + wrapped_returns.append(f"std::get<{idx}>({results_var})") + idx += 1 + if len(wrapped_returns) == 1: + result = f"return {wrapped_returns[0]};" + else: + result = f'return std::make_tuple({", ".join(wrapped_returns)});' + return result + + +def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool: + return any(a.type.is_tensor_like() for a in schema.arguments.flat_all) + + +def is_mutated_arg(argument: Argument) -> bool: + return argument.annotation is not None and argument.annotation.is_write + + +def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]: + # Assumptions: + # - only one argument is being modified in-place + # - the argument that is being modified in-place is the first argument + # - all returns are either Tensor, tuple of Tensor, or TensorList + schema = native_function.func + sig = DispatcherSignature.from_schema(schema) + returns = schema.returns + + # Check assumptions. If these are invalid we return None + # and punt the work to handle them to the future. + assert schema.kind() == SchemaKind.inplace + if not is_mutated_arg(schema.arguments.flat_all[0]): + return None + if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1: + return None + + # Only support cases where all returns are Tensors or vector + if len(returns) == 0: + return None + if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns): + return None + if not accepts_at_least_one_tensor_input(schema): + return None + + cur_level_var = "cur_level" + + unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) + bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) + + return f"""\ +template +{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t {cur_level_var} = maybe_layer->layerId(); +{textwrap.indent(bdims_all_none_case, " ")} +{textwrap.indent(unwraps, " ")} + batch_rule({', '.join(unwrapped_arg_list)}); + return {schema.arguments.flat_all[0].name}; +}}""" + + +def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str: + schema = native_function.func + sig = DispatcherSignature.from_schema(schema) + cur_level_var = "cur_level" + + unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) + bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) + + return f"""\ +template +{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t {cur_level_var} = maybe_layer->layerId(); +{textwrap.indent(bdims_all_none_case, " ")} +{textwrap.indent(unwraps, " ")} + batch_rule({', '.join(unwrapped_arg_list)}); +}}""" + + +def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]: + schema = native_function.func + sig = DispatcherSignature.from_schema(schema) + returns = schema.returns + + # Only support cases where all returns are Tensors or vector + if not accepts_at_least_one_tensor_input(schema): + return None + if len(returns) == 0: + return gen_vmap_plumbing_no_returns(native_function) + if not all(ret.type.is_tensor_like() for ret in returns): + return None + # in-place views need special handling + if "inplace_view" in native_function.tags: + return None + + if schema.kind() == SchemaKind.inplace: + return gen_vmap_inplace_plumbing(native_function) + + # Don't support these (mutable, out, scratch) + if schema.kind() != SchemaKind.functional: + return None + + results_var = "results" + cur_level_var = "cur_level" + + unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) + bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) + + wrapped_returns = gen_returns(returns, cur_level_var, results_var) + return f"""\ +template +{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{ + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t {cur_level_var} = maybe_layer->layerId(); +{textwrap.indent(bdims_all_none_case, " ")} +{textwrap.indent(unwraps, " ")} + auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)}); + {wrapped_returns} +}}""" + + +@dataclass(frozen=True) +class ComputeBatchRulePlumbing: + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + opname = str(f.func.name) + result = gen_vmap_plumbing(f) + return result + + +def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str: + body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions))) + return f""" +#pragma once +#include +#include +#include + +namespace at {{ namespace functorch {{ + +{body} + +}}}} // namespace at::functorch +""" diff --git a/torchgen/local.py b/torchgen/local.py index dd570dd8d7ee3..65efce2c3b11b 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -1,6 +1,6 @@ import threading from contextlib import contextmanager -from typing import Optional, Iterator +from typing import Iterator, Optional # Simple dynamic scoping implementation. The name "parametrize" comes # from Racket. diff --git a/torchgen/model.py b/torchgen/model.py index 784924ea8652a..c5ed79453c5ba 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -6,7 +6,7 @@ from enum import auto, Enum from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union -from torchgen.utils import assert_never +from torchgen.utils import assert_never, NamespaceHelper, OrderedSet # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -49,21 +49,23 @@ def __str__(self) -> str: DEFAULT_KERNEL_NAMESPACE = "at::native" # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h +BACKEND_COMPONENTS = "CPU CUDA HIP XLA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split() +FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"] + +# This doesn't have to be in sync with the header, it only needs to contain +# entries that we actually use in the codegen class DispatchKey(Enum): Undefined = 0 CatchAll = Undefined - Dense = auto() FPGA = auto() ORT = auto() - MPS = auto() Vulkan = auto() Metal = auto() MKLDNN = auto() OpenGL = auto() OpenCL = auto() IDEEP = auto() - Quantized = auto() CustomRNGKeyId = auto() MkldnnCPU = auto() Sparse = auto() @@ -71,7 +73,6 @@ class DispatchKey(Enum): SparseCsrCUDA = auto() ZeroTensor = auto() - Meta = auto() BackendSelect = auto() Named = auto() AutogradOther = auto() @@ -83,54 +84,84 @@ class DispatchKey(Enum): VmapMode = auto() TESTING_ONLY_GenericWrapper = auto() TESTING_ONLY_GenericMode = auto() - EndOfFunctionalityKeys = TESTING_ONLY_GenericMode + Autograd = auto() + CompositeImplicitAutograd = auto() + CompositeExplicitAutograd = auto() + CompositeExplicitAutogradNonFunctional = auto() + + # BEGIN autogenerated CPU = auto() CUDA = auto() HIP = auto() XLA = auto() - Lazy = auto() + MPS = auto() IPU = auto() XPU = auto() - NestedTensor = auto() + HPU = auto() + VE = auto() + Lazy = auto() + Meta = auto() PrivateUse1 = auto() PrivateUse2 = auto() PrivateUse3 = auto() - QuantizedCPU = auto() QuantizedCUDA = auto() + QuantizedHIP = auto() + QuantizedXLA = auto() + QuantizedMPS = auto() + QuantizedIPU = auto() QuantizedXPU = auto() - + QuantizedHPU = auto() + QuantizedVE = auto() + QuantizedLazy = auto() + QuantizedMeta = auto() + QuantizedPrivateUse1 = auto() + QuantizedPrivateUse2 = auto() + QuantizedPrivateUse3 = auto() SparseCPU = auto() SparseCUDA = auto() SparseHIP = auto() + SparseXLA = auto() + SparseMPS = auto() + SparseIPU = auto() SparseXPU = auto() - + SparseHPU = auto() + SparseVE = auto() + SparseLazy = auto() + SparseMeta = auto() + SparsePrivateUse1 = auto() + SparsePrivateUse2 = auto() + SparsePrivateUse3 = auto() NestedTensorCPU = auto() NestedTensorCUDA = auto() - + NestedTensorHIP = auto() + NestedTensorXLA = auto() + NestedTensorMPS = auto() + NestedTensorIPU = auto() + NestedTensorXPU = auto() + NestedTensorHPU = auto() + NestedTensorVE = auto() + NestedTensorLazy = auto() + NestedTensorMeta = auto() + NestedTensorPrivateUse1 = auto() + NestedTensorPrivateUse2 = auto() + NestedTensorPrivateUse3 = auto() AutogradCPU = auto() AutogradCUDA = auto() + AutogradHIP = auto() AutogradXLA = auto() - AutogradLazy = auto() - AutogradIPU = auto() AutogradMPS = auto() + AutogradIPU = auto() AutogradXPU = auto() + AutogradHPU = auto() + AutogradVE = auto() + AutogradLazy = auto() + AutogradMeta = auto() AutogradPrivateUse1 = auto() AutogradPrivateUse2 = auto() AutogradPrivateUse3 = auto() - - Autograd = auto() - CompositeImplicitAutograd = auto() - CompositeExplicitAutograd = auto() - CompositeExplicitAutogradNonFunctional = auto() - EndOfAliasKeys = CompositeExplicitAutogradNonFunctional - - CPUTensorId = CPU - CUDATensorId = CUDA - PrivateUse1_PreAutograd = AutogradPrivateUse1 - PrivateUse2_PreAutograd = AutogradPrivateUse2 - PrivateUse3_PreAutograd = AutogradPrivateUse3 + # END autogenerated def __str__(self) -> str: return self.name @@ -146,6 +177,24 @@ def parse(value: str) -> "DispatchKey": raise AssertionError(f"unknown dispatch key {value}") +def codegen_per_backend_entries() -> str: + r = [] + for fk in FUNCTIONALITY_KEYS: + for bc in BACKEND_COMPONENTS: + r.append(f" {fk}{bc} = auto()") + return "\n".join(r) + + +for fk in FUNCTIONALITY_KEYS: + for bc in BACKEND_COMPONENTS: + if not hasattr(DispatchKey, fk + bc): + r = codegen_per_backend_entries() + print(r) + raise RuntimeError( + f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}" + ) + + STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU} UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU} @@ -169,6 +218,9 @@ def parse(value: str) -> "DispatchKey": # Meta is a magic key: it is automatically generated for structured # kernels DispatchKey.Meta, + DispatchKey.SparseMeta, + DispatchKey.QuantizedMeta, + DispatchKey.NestedTensorMeta, DispatchKey.ZeroTensor, ] @@ -191,7 +243,6 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: DispatchKey.SparseCsrCUDA, DispatchKey.NestedTensorCUDA, DispatchKey.AutogradCUDA, - DispatchKey.CUDATensorId, } @@ -239,8 +290,8 @@ def parse(value: str) -> "ScalarType": return mb_r @staticmethod - def parse_set(values: str) -> Set["ScalarType"]: - dtypes: Set[ScalarType] = set() + def parse_set(values: str) -> OrderedSet["ScalarType"]: + dtypes: OrderedSet[ScalarType] = OrderedSet() for value in values.split(", "): if value in DTYPE_CLASSES: dtypes.update(DTYPE_CLASSES[value]) @@ -249,18 +300,22 @@ def parse_set(values: str) -> Set["ScalarType"]: return dtypes -DTYPE_CLASSES: Dict[str, Set[ScalarType]] = {} +DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {} # NB: Integral doesn't include boolean -DTYPE_CLASSES["Integral"] = { - ScalarType.Byte, - ScalarType.Char, - ScalarType.Int, - ScalarType.Long, - ScalarType.Short, -} +DTYPE_CLASSES["Integral"] = OrderedSet( + [ + ScalarType.Byte, + ScalarType.Char, + ScalarType.Int, + ScalarType.Long, + ScalarType.Short, + ] +) # NB: Floating doesn't include low precision types -DTYPE_CLASSES["Floating"] = {ScalarType.Float, ScalarType.Double} -DTYPE_CLASSES["Complex"] = {ScalarType.ComplexFloat, ScalarType.ComplexDouble} +DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double]) +DTYPE_CLASSES["Complex"] = OrderedSet( + [ScalarType.ComplexFloat, ScalarType.ComplexDouble] +) DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"] DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"] DTYPE_CLASSES["FloatingAndComplex"] = ( @@ -454,12 +509,11 @@ def from_yaml( funcs = e.pop("func") assert isinstance(funcs, str), f"not a str: {funcs}" # only support one level of namespace. E.g., aten::add - namespaced_funcs = funcs.split("::", 1) - if len(namespaced_funcs) == 1: - namespace = "aten" - else: - namespace = namespaced_funcs[0] - func = FunctionSchema.parse(namespaced_funcs[-1]) + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + namespace = namespace_helper.get_cpp_namespace(default="aten") + func = FunctionSchema.parse(namespace_helper.entity_name) cpp_no_default_args_list = e.pop("cpp_no_default_args", []) assert isinstance(cpp_no_default_args_list, list) @@ -579,19 +633,20 @@ def from_yaml( f"Dispatch key {dispatch_key} of kernel {v} " "is not a supported dispatch key." ) - # We only allow one level of namespace for kernels and operator. + # We only allow at most 2 levels of namespace for kernels. # We will append "native" to a custom kernel namespace. - tokens = v.split("::", 1) + namespace_helper = NamespaceHelper.from_namespaced_entity( + v, max_level=2 + ) + kernel_namespace = namespace_helper.get_cpp_namespace(default="at") # Why is 'structured' included? External backends (e.g. # XLA) opt into which ops are structured independently # of which in-tree ops are structured dispatch[dispatch_key] = BackendMetadata( - kernel=tokens[-1], + kernel=namespace_helper.entity_name, structured=structured and is_structured_dispatch_key(dispatch_key), - cpp_namespace=(tokens[0] + "::native") - if len(tokens) > 1 - else DEFAULT_KERNEL_NAMESPACE, + cpp_namespace=(kernel_namespace + "::native"), ) if ( dispatch_key is DispatchKey.CompositeImplicitAutograd @@ -613,6 +668,21 @@ def from_yaml( "name, then delete the dispatch table" ) elif not structured and structured_delegate is None: + name = str(func.name.name) + assert not ( + name.startswith("new_") + or name.endswith("_like") + # TODO: maybe it's better to test the return + or ( + func.arguments.tensor_options + and not func.arguments.has_tensor_arg() + ) + ), ( + f"expected {name} to have a CompositeExplicitAutograd " + "dispatch entry, but there was no dispatch table. Factory functions " + "should not have implicit dispatch as they should not be decomposed " + "for __torch_dispatch__" + ) dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata( cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE ) @@ -661,6 +731,11 @@ def from_yaml( # Program the BackendIndex for the implicit dispatch entry from ufunc if ufunc_inner_loop: assert structured, "ufunc must be structured" + + # Delay import ufunc here to avoid circular import issue + # See: https://github.com/pytorch/pytorch/issues/81294 + import torchgen.api.ufunc as ufunc + for dispatch_key in UFUNC_DISPATCH_KEYS: assert ( dispatch_key not in dispatch @@ -735,6 +810,9 @@ def from_yaml( backend_metadata, ) + def symints_to_ints(self) -> "NativeFunction": + return dataclasses.replace(self, func=self.func.symints_to_ints()) + def validate_unstructured(self) -> None: # TODO: probably better to accumulate these errors and report them all # at once @@ -872,6 +950,8 @@ def __post_init__(self) -> None: if self.mutable is not None: assert self.mutable.func.kind() == SchemaKind.mutable assert self.mutable.namespace == self.functional.namespace + # See Note [Overload Ambiguity With Functional Variants] + assert self.functional.func.name.name.functional_overload if self.structured: # For now, structured composite kernels are not supported (need some @@ -901,7 +981,7 @@ def __post_init__(self) -> None: raise RuntimeError( f"The codegen expects to be able to generate '{generated_fns_str}'." f" To do so, it expects a line: 'autogen: {generated_fns_str}'." - f" Instead, it found 'autogen: {generated_fns_str}'" + f" Instead, it found 'autogen: {expected_generated_fns_str}'" ) def signature(self) -> "FunctionSchema": @@ -969,7 +1049,7 @@ class BackendMetadata: @dataclass(frozen=True) class UfuncInnerLoop: name: str - supported_dtypes: Set[ScalarType] + supported_dtypes: OrderedSet[ScalarType] # key is stored here because it affects the semantics of name, # so its helpful to have them together for further processing ufunc_key: UfuncKey @@ -979,7 +1059,7 @@ def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop": name, supported_dtypes_str = value.split(" ", 1) assert supported_dtypes_str[0] == "(" assert supported_dtypes_str[-1] == ")" - supported_dtypes = set() + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() for k in supported_dtypes_str[1:-1].split(", "): supported_dtypes |= ScalarType.parse_set(k) return UfuncInnerLoop( @@ -1039,7 +1119,7 @@ def get_kernel( elif isinstance(g, NativeFunctionsGroup): f = self.primary(g) else: - assert_never(f) + assert_never(g) if f.func.name not in self.index: return None return self.index[f.func.name] @@ -1126,6 +1206,9 @@ def schema_order_arguments(self) -> Iterator["Argument"]: decl_re = re.compile(r"(?P[^\(]+)\((?P.*)\) -> (?P.*)") + def symints_to_ints(self) -> "FunctionSchema": + return dataclasses.replace(self, arguments=self.arguments.symints_to_ints()) + @staticmethod def parse(func: str) -> "FunctionSchema": # We should probably get a proper parser here @@ -1565,6 +1648,9 @@ def is_nullable(self) -> bool: def is_list_like(self) -> Optional["ListType"]: raise NotImplementedError + def symint_to_int(self) -> "Type": + raise NotImplementedError + # Base types are simple, atomic types with no further structure BaseTy = Enum( @@ -1605,6 +1691,11 @@ def is_tensor_like(self) -> bool: def is_nullable(self) -> bool: return False + def symint_to_int(self) -> "BaseType": + if self.name == BaseTy.SymInt: + return BaseType(BaseTy.int) + return self + def is_list_like(self) -> Optional["ListType"]: return None @@ -1623,6 +1714,9 @@ def is_tensor_like(self) -> bool: def is_nullable(self) -> bool: return True + def symint_to_int(self) -> "Type": + return dataclasses.replace(self, elem=self.elem.symint_to_int()) + def is_list_like(self) -> Optional["ListType"]: return self.elem.is_list_like() @@ -1649,6 +1743,9 @@ def is_tensor_like(self) -> bool: def is_nullable(self) -> bool: return self.elem.is_nullable() + def symint_to_int(self) -> "ListType": + return ListType(self.elem.symint_to_int(), self.size) + def is_list_like(self) -> Optional["ListType"]: return self @@ -1722,6 +1819,9 @@ def parse(arg: str) -> "Argument": def is_write(self) -> bool: return self.annotation is not None and self.annotation.is_write + def symint_to_int(self) -> "Argument": + return dataclasses.replace(self, type=self.type.symint_to_int()) + def __str__(self) -> str: type = f"{self.type}" if self.annotation: @@ -1911,6 +2011,40 @@ def mutable_arg_names(self) -> List[str]: if a.annotation is not None and a.annotation.is_write ] + def symints_to_ints(self) -> "Arguments": + arguments = self + + if arguments.self_arg: + arguments = dataclasses.replace( + arguments, + pre_self_positional=[ + x.symint_to_int() for x in arguments.pre_self_positional + ], + ) + + if self.tensor_options: + arguments = dataclasses.replace( + arguments, + post_tensor_options_kwarg_only=[ + x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only + ], + ) + + arguments = dataclasses.replace( + arguments, + post_self_positional=[ + x.symint_to_int() for x in arguments.post_self_positional + ], + pre_tensor_options_kwarg_only=[ + x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only + ], + ) + + return arguments + + def has_tensor_arg(self) -> bool: + return any(a.type.is_tensor_like() for a in self.flat_non_out) + def signature(self, *, strip_default: bool = False) -> "Arguments": # dataclasses.replace could be used here, but it is less # type safe so for now I've opted to type everything out @@ -2135,6 +2269,26 @@ class BaseOperatorName: base: str inplace: bool dunder_method: bool + # Note [Overload Ambiguity With Functional Variants] + # A handful of operators have both a "mutable" and a "functional" variant. + # (native_batch_norm is a good example, although this isn't the case today). + # For those operators, the mutable and functional variant take in the same set of + # arguments, but have different alias annotations. + # this makes it ambiguous when you try to resolve an OverloadPacket into an overload, + # given a set of input arguments. + # + # So instead of making the "functional" variant in this case a real overload, e.g: + # native_batch_norm (mutable variant) + # native_batch_norm.functional (functional variant) + # we make it a new base operator, + # native_batch_norm_functional (functional variant) + # + # In an ideal world, we would probably invert this so the operators were: + # native_batch_norm.mutable (mutable variant) + # native_batch_norm (functional variant) + # + # Doing that is BC-breaking though, so we're stuck with the above modeling. + functional_overload: bool = False @staticmethod def parse(op: str) -> "BaseOperatorName": @@ -2165,7 +2319,24 @@ def parse(op: str) -> "BaseOperatorName": base = base[:-1] else: inplace = False - r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method) + + # See Note [Overload Ambiguity With Functional Variants] + functional_suffix = "_functional" + if base.endswith(functional_suffix): + functional_overload = True + base = base[: -len(functional_suffix)] + # This seems complicated and unnecessary, so banning dunder methods + # for now on ops that have a functional + mutable variant (like native_batch_norm). + assert not dunder_method and not inplace + else: + functional_overload = False + + r = BaseOperatorName( + base=base, + inplace=inplace, + dunder_method=dunder_method, + functional_overload=functional_overload, + ) assert str(r) == op, f"{str(r)} != {op}" return r @@ -2174,7 +2345,13 @@ def __str__(self) -> str: i = "i" if self.inplace else "" return f"__{i}{self.base}__" else: - i = "_" if self.inplace else "" + i = ( + "_" + if self.inplace + else "_functional" + if self.functional_overload + else "" + ) return f"{self.base}{i}" @@ -2425,6 +2602,3 @@ def to_list(self) -> List[str]: replace_list.append(f"{kernel_param} -> {replacements}") return replace_list - - -import torchgen.api.ufunc as ufunc diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index fe7f7182ec05b..df34b8a7c1f5f 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -11,6 +11,7 @@ Argument, BackendIndex, BackendMetadata, + BaseOperatorName, BaseTy, BaseType, DEFAULT_KERNEL_NAMESPACE, @@ -193,14 +194,20 @@ def generate_function( # The new "functional" NativeFunction has: # - any mutable arguments have been converted into (immutable) returns. # (if a mutable argument was not also a return, it gets converted to one) - # - a "functional" overload name. + # - "_functional" appended to the base name, ONLY IF this op has a mutable variant. + # See Note [Overload Ambiguity With Functional Variants] # The default grouping logic in signature() actually already does this, # so we can piggy-back off it (but we still want return names) func = f.func.signature(keep_return_names=True).with_name( - f.func.name.remove_inplace().with_overload( - "functional" - if not f.func.name.overload_name - else f"{f.func.name.overload_name}_functional" + OperatorName( + name=BaseOperatorName( + base=f.func.name.name.base, + inplace=False, + dunder_method=f.func.name.name.dunder_method, + # See Note [Overload Ambiguity With Functional Variants] + functional_overload=f.func.kind() == SchemaKind.mutable, + ), + overload_name=f.func.name.overload_name, ) ) elif k == SchemaKind.out: diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 54c5b3a5628af..5006f4f6d89a0 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List import torch -from torchgen.code_template import CodeTemplate from torch.jit.generate_bytecode import generate_upgraders_bytecode + +from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( MOBILE_UPGRADERS_HEADER_DESCRIPTION, ) diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index ca80f5ad7f2a8..76f8b963b990e 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -1,5 +1,5 @@ -from typing import Dict, Optional, Tuple from dataclasses import dataclass +from typing import Dict, Optional, Tuple # This class holds information about a single operator used to determine # the outcome of a selective/custom PyTorch build that doesn't include diff --git a/torchgen/selective_build/selector.py b/torchgen/selective_build/selector.py index e65ecf5eaf452..dd94dd17dd0ed 100644 --- a/torchgen/selective_build/selector.py +++ b/torchgen/selective_build/selector.py @@ -1,13 +1,13 @@ -from typing import Dict, Set, Optional, Tuple, List -import yaml - from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple + +import yaml from torchgen.model import NativeFunction from torchgen.selective_build.operator import ( - SelectiveBuildOperator, merge_debug_info, merge_operator_dicts, + SelectiveBuildOperator, strip_operator_overload_name, ) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 9d1f7a75f9a55..6013c4de5350f 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import os -from pathlib import Path from itertools import chain +from pathlib import Path from torch.jit._shape_functions import ( - shape_compute_graph_mapping, bounded_compute_graph_mapping, + shape_compute_graph_mapping, ) SHAPE_HEADER = r""" diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index 90ae5ba36d250..0c4fa07880737 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -1,7 +1,7 @@ -from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup - from typing import Dict, Union +from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup + def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str: if isinstance(g, NativeFunctionsGroup): @@ -55,7 +55,7 @@ def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) ) -def is_hand_written(g: NativeFunctionsGroup) -> bool: +def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: name_base = func_name_base_str(g) return name_base in is_hand_written_ops_ @@ -102,9 +102,9 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N return if op_name == "take_along_dim": if index == 0: - arg_map["indices"] = "at::argsort(self0, 1)" + arg_map["indices"] = "at::argsort(self0, 1, true)" else: - arg_map["indices"] = "at::argsort(self1, 1)" + arg_map["indices"] = "at::argsort(self1, 1, true)" return if op_name == "masked_select": if index == 0: diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index 8608aa82401e3..6e234d33a7212 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -1,24 +1,28 @@ -from torchgen import gen -from torchgen.context import native_function_manager -from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup -from torchgen.static_runtime import generator - import argparse import itertools import os -from typing import Sequence, Union -from libfb.py.log import set_simple_logging +from typing import Sequence, TypeVar, Union + +from libfb.py.log import set_simple_logging # type: ignore[import] + +from torchgen import gen +from torchgen.context import native_function_manager +from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup +from torchgen.static_runtime import config, generator # Given a list of `grouped_native_functions` sorted by their op names, return a list of # lists each of which groups ops that share the base name. For example, `mean` and # `mean.dim` are grouped together by this function. +NativeGroupT = TypeVar( + "NativeGroupT", + bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], +) + def group_functions_by_op_name( - grouped_native_functions: Sequence[ - Union[NativeFunctionsGroup, NativeFunctionsViewGroup] - ] -) -> Sequence[Sequence[Union[NativeFunctionsGroup, NativeFunctionsViewGroup]]]: + grouped_native_functions: Sequence[NativeGroupT], +) -> Sequence[Sequence[NativeGroupT]]: if not grouped_native_functions: return [] groups = [] @@ -33,9 +37,7 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo for k, group in ( itertools.groupby( eligible_ops, - key=lambda g: g.functional.func.name.name.base - if isinstance(g, NativeFunctionsGroup) - else g.view.root_name, + key=lambda g: config.func_name_base_str(g), ) ) ] diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index 24593726056ce..22bf259f640bb 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -1,26 +1,27 @@ +import json +import logging + +import math +from typing import Dict, List, Optional, Sequence, Tuple, Union + import torchgen.api.cpp as cpp from torchgen.context import native_function_manager from torchgen.model import ( Argument, BackendIndex, BaseTy, + BaseType, FunctionSchema, + NativeFunctionsGroup, + NativeFunctionsViewGroup, OptionalType, SelfArgument, - BaseType, - NativeFunctionsGroup, TensorOptionsArguments, Type, - NativeFunctionsViewGroup, ) from torchgen.static_runtime import config -import math -import logging -import json -from typing import List, Optional, Sequence, Tuple, Union - -logger: logger = logging.getLogger() +logger: logging.Logger = logging.getLogger() def has_alias( @@ -228,7 +229,7 @@ def test_tensor_dim(op_name: str) -> int: test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' -test_tensor_shape_json = json.loads(test_tensor_shapes_string) +test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string) def test_tensor_shape(op_name: str) -> str: @@ -399,7 +400,6 @@ def generate_out_variant_call( if not g.structured: assert len(schema.arguments.out) == 1 arg_names.append(schema.arguments.out[0].name) - cpp_func_name = cpp.name(schema) cpp_arg_names = ",".join(arg_names) namespace_name = "cpu" if g.structured else "native" return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" @@ -490,7 +490,6 @@ def out_variant_op_generator( ) -> str: functional = g.functional schema = str(functional.func) - op_name = op_name_from_group(g) populated_argument = generate_arg_extraction(g.functional.func) functional_variant_call = generate_non_out_variant_call(g, backend_index) assert len(g.out.func.arguments.out) == 1 @@ -515,7 +514,6 @@ def view_op_generator( self, g: NativeFunctionsViewGroup, backend_index: BackendIndex ) -> str: schema = str(g.view.func) - op_name = config.func_name_base_str(g) populated_argument = generate_arg_extraction(g.view.func) functional_variant_call = generate_call_to_view_ops(g, backend_index) generated = f""" diff --git a/torchgen/utils.py b/torchgen/utils.py index 1067d5ace28e5..c168f186f83c3 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -3,29 +3,29 @@ import hashlib import os import re -import textwrap import sys +import textwrap from argparse import Namespace -from dataclasses import ( - fields, - is_dataclass, -) +from dataclasses import fields, is_dataclass +from enum import Enum from typing import ( - Tuple, - List, + Any, + Callable, + Dict, + Generic, Iterable, Iterator, - Callable, + List, + NoReturn, + Optional, Sequence, + Set, + Tuple, TypeVar, - Optional, - Dict, - Any, Union, - Set, - NoReturn, ) -from enum import Enum + +from typing_extensions import Literal from torchgen.code_template import CodeTemplate @@ -396,3 +396,113 @@ def _format( indent_str = " " * indent body = f", {delimiter}{curr_indent_str}".join(fields_str) return f"{start}{indent_str}{body}{end}" + + +class NamespaceHelper: + """A helper for constructing the namespace open and close strings for a nested set of namespaces. + + e.g. for namespace_str torch::lazy, + + prologue: + namespace torch { + namespace lazy { + + epilogue: + } // namespace lazy + } // namespace torch + """ + + def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2): + # cpp_namespace can be a colon joined string such as torch::lazy + cpp_namespaces = namespace_str.split("::") + assert ( + len(cpp_namespaces) <= max_level + ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}." + self.cpp_namespace_ = namespace_str + self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces]) + self.epilogue_ = "\n".join( + [f"}} // namespace {n}" for n in reversed(cpp_namespaces)] + ) + self.namespaces_ = cpp_namespaces + self.entity_name_ = entity_name + + @staticmethod + def from_namespaced_entity( + namespaced_entity: str, max_level: int = 2 + ) -> "NamespaceHelper": + """ + Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add" + """ + names = namespaced_entity.split("::") + entity_name = names[-1] + namespace_str = "::".join(names[:-1]) + return NamespaceHelper( + namespace_str=namespace_str, entity_name=entity_name, max_level=max_level + ) + + @property + def prologue(self) -> str: + return self.prologue_ + + @property + def epilogue(self) -> str: + return self.epilogue_ + + @property + def entity_name(self) -> str: + return self.entity_name_ + + # Only allow certain level of namespaces + def get_cpp_namespace(self, default: str = "") -> str: + """ + Return the namespace string from joining all the namespaces by "::" (hence no leading "::"). + Return default if namespace string is empty. + """ + return self.cpp_namespace_ if self.cpp_namespace_ else default + + +class OrderedSet(Generic[T]): + storage: Dict[T, Literal[None]] + + def __init__(self, iterable: Optional[Iterable[T]] = None): + if iterable is None: + self.storage = {} + else: + self.storage = {k: None for k in iterable} + + def __contains__(self, item: T) -> bool: + return item in self.storage + + def __iter__(self) -> Iterator[T]: + return iter(self.storage.keys()) + + def update(self, items: "OrderedSet[T]") -> None: + self.storage.update(items.storage) + + def add(self, item: T) -> None: + self.storage[item] = None + + def copy(self) -> "OrderedSet[T]": + ret: OrderedSet[T] = OrderedSet() + ret.storage = self.storage.copy() + return ret + + @staticmethod + def union(*args: "OrderedSet[T]") -> "OrderedSet[T]": + ret = args[0].copy() + for s in args[1:]: + ret.update(s) + return ret + + def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]": + return OrderedSet.union(self, other) + + def __ior__(self, other: "OrderedSet[T]") -> "OrderedSet[T]": + self.update(other) + return self + + def __eq__(self, other: object) -> bool: + if isinstance(other, OrderedSet): + return self.storage == other.storage + else: + return set(self.storage.keys()) == other
+
+
+ + + + + +
+
+